1use proc_macro::{Delimiter, Spacing, TokenStream, TokenTree};
9
10#[proc_macro]
57pub fn client(input: TokenStream) -> TokenStream {
58 match expand_client(input) {
59 Ok(stream) => stream,
60 Err(message) => compile_error(message),
61 }
62}
63
64#[proc_macro_derive(Depends, attributes(dep))]
94pub fn derive_depends(input: TokenStream) -> TokenStream {
95 match derive_depends_impl(input) {
96 Ok(stream) => stream,
97 Err(message) => compile_error(message),
98 }
99}
100
101fn expand_client(input: TokenStream) -> Result<TokenStream, String> {
104 let tokens = input.into_iter().collect::<Vec<_>>();
105 let struct_index = tokens
106 .iter()
107 .position(|token| is_ident(token, "struct"))
108 .ok_or_else(|| "client! expects `struct`".to_string())?;
109
110 let visibility = tokens_to_string(&tokens[..struct_index]);
111 let name = ident_at(&tokens, struct_index + 1, "a client name")?;
112
113 if !matches!(tokens.get(struct_index + 2), Some(token) if is_ident(token, "as")) {
114 return Err("client! expects `as <module_name>` after the struct name".into());
115 }
116
117 let module = ident_at(&tokens, struct_index + 3, "a module name after `as`")?;
118 let body = match tokens.get(struct_index + 4) {
119 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group.stream(),
120 _ => return Err("client! expects a braced method body".into()),
121 };
122
123 if tokens.len() != struct_index + 5 {
124 return Err("unexpected tokens after the client body".into());
125 }
126
127 let methods = parse_methods(body)?;
128 if methods.is_empty() {
129 return Err("client! requires at least one method".into());
130 }
131
132 let visibility_prefix = with_trailing_space(&visibility);
133 let field_lines = methods
134 .iter()
135 .map(Method::render_field)
136 .collect::<Vec<_>>()
137 .join("\n");
138 let method_lines = methods
139 .iter()
140 .map(Method::render_method)
141 .collect::<Vec<_>>()
142 .join("\n\n");
143 let live_lines = methods
144 .iter()
145 .map(|method| format!("{}: {}", method.name, method.render_live_initializer(&name)))
146 .collect::<Vec<_>>()
147 .join(",\n ");
148 let module_lines = methods
149 .iter()
150 .map(|method| method.render_module(&name))
151 .collect::<Vec<_>>()
152 .join("\n");
153
154 let output = format!(
155 "#[derive(Clone, Copy)]
156 {visibility_prefix}struct {name} {{
157 {field_lines}
158 }}
159
160 impl {name} {{
161 {method_lines}
162 }}
163
164 impl ::clients::Dependency for {name} {{
165 fn live() -> Self {{
166 Self {{
167 {live_lines}
168 }}
169 }}
170 }}
171
172 impl ::core::default::Default for {name} {{
173 fn default() -> Self {{
174 <Self as ::clients::Dependency>::live()
175 }}
176 }}
177
178 {visibility_prefix}mod {module} {{
179 use super::*;
180
181 pub fn get() -> super::{name} {{
182 ::clients::get::<super::{name}>()
183 }}
184
185 {module_lines}
186 }}"
187 );
188
189 output
190 .parse::<TokenStream>()
191 .map_err(|error| error.to_string())
192}
193
194#[derive(Clone)]
196struct Method {
197 name: String,
199 visibility: String,
202 arguments: Vec<Argument>,
204 return_ty: String,
206 implementation: Option<String>,
208 is_async: bool,
210}
211
/// A single `name: Type` parameter of a client method.
#[derive(Clone)]
struct Argument {
    /// Binding name: the last identifier before the `:` (so `mut x` yields `x`).
    name: String,
    /// Source text of the parameter type.
    ty: String,
}
220
221impl Method {
222 fn arity(&self) -> usize {
224 self.arguments.len()
225 }
226
227 fn eraser_name(&self) -> String {
229 if self.is_async {
230 format!("::clients::erase_async_{}", self.arity())
231 } else {
232 format!("::clients::erase_sync_{}", self.arity())
233 }
234 }
235
236 fn args_decl(&self) -> String {
239 self.arguments
240 .iter()
241 .map(|argument| format!("{}: {}", argument.name, argument.ty))
242 .collect::<Vec<_>>()
243 .join(", ")
244 }
245
246 fn args_types(&self) -> String {
248 self.arguments
249 .iter()
250 .map(|argument| argument.ty.clone())
251 .collect::<Vec<_>>()
252 .join(", ")
253 }
254
255 fn args_names(&self) -> String {
257 self.arguments
258 .iter()
259 .map(|argument| argument.name.clone())
260 .collect::<Vec<_>>()
261 .join(", ")
262 }
263
264 fn fn_pointer_return(&self) -> String {
266 if self.is_async {
267 format!("::clients::BoxFuture<{}>", self.return_ty)
268 } else {
269 self.return_ty.clone()
270 }
271 }
272
273 fn render_field(&self) -> String {
275 format!(
276 "{}: fn({}) -> {},",
277 self.name,
278 self.args_types(),
279 self.fn_pointer_return()
280 )
281 }
282
283 fn render_method(&self) -> String {
286 let visibility = with_trailing_space(&self.visibility);
287 let args_decl = self.args_decl();
288 let call_args = self.args_names();
289
290 if self.is_async {
291 format!(
292 "{visibility}async fn {}(&self{}{}) -> {} {{
293 (self.{})({}).await
294 }}",
295 self.name,
296 maybe_comma(&args_decl),
297 args_decl,
298 self.return_ty,
299 self.name,
300 call_args,
301 )
302 } else {
303 format!(
304 "{visibility}fn {}(&self{}{}) -> {} {{
305 (self.{})({})
306 }}",
307 self.name,
308 maybe_comma(&args_decl),
309 args_decl,
310 self.return_ty,
311 self.name,
312 call_args,
313 )
314 }
315 }
316
317 fn render_live_initializer(&self, client_name: &str) -> String {
323 if let Some(implementation) = &self.implementation {
324 format!("{}({implementation})", self.eraser_name())
325 } else if self.is_async {
326 format!(
327 "{{
328 fn __dep_unimplemented({}) -> ::clients::BoxFuture<{}> {{
329 ::clients::boxed(async move {{
330 ::clients::unimplemented_dependency(\"{}.{}\")
331 }})
332 }}
333
334 __dep_unimplemented
335 }}",
336 self.args_decl(),
337 self.return_ty,
338 client_name,
339 self.name,
340 )
341 } else {
342 format!(
343 "{{
344 fn __dep_unimplemented({}) -> {} {{
345 ::clients::unimplemented_dependency(\"{}.{}\")
346 }}
347
348 __dep_unimplemented
349 }}",
350 self.args_decl(),
351 self.return_ty,
352 client_name,
353 self.name,
354 )
355 }
356 }
357
358 fn render_module(&self, client_name: &str) -> String {
360 let args_types = self.args_types();
361 let fn_pointer_return = self.fn_pointer_return();
362 let eraser = self.eraser_name();
363
364 if self.is_async {
365 format!(
366 "pub mod {} {{
367 use super::*;
368
369 pub fn get() -> fn({}) -> {} {{
370 super::get().{}
371 }}
372
373 pub fn override_with<F, Fut>(builder: &mut ::clients::OverrideBuilder, implementation: F)
374 where
375 F: Fn({}) -> Fut + Copy + 'static,
376 Fut: ::core::future::Future<Output = {}> + Send + 'static,
377 {{
378 builder.update::<super::super::{client_name}, _>(|mut dependency| {{
379 dependency.{} = {}(implementation);
380 dependency
381 }});
382 }}
383 }}",
384 self.name,
385 args_types,
386 fn_pointer_return,
387 self.name,
388 args_types,
389 self.return_ty,
390 self.name,
391 eraser,
392 )
393 } else {
394 format!(
395 "pub mod {} {{
396 use super::*;
397
398 pub fn get() -> fn({}) -> {} {{
399 super::get().{}
400 }}
401
402 pub fn override_with<F>(builder: &mut ::clients::OverrideBuilder, implementation: F)
403 where
404 F: Fn({}) -> {} + Copy + 'static,
405 {{
406 builder.update::<super::super::{client_name}, _>(|mut dependency| {{
407 dependency.{} = {}(implementation);
408 dependency
409 }});
410 }}
411 }}",
412 self.name,
413 args_types,
414 fn_pointer_return,
415 self.name,
416 args_types,
417 self.return_ty,
418 self.name,
419 eraser,
420 )
421 }
422 }
423}
424
425fn parse_methods(stream: TokenStream) -> Result<Vec<Method>, String> {
427 split_top_level(stream, ';')
428 .into_iter()
429 .map(|tokens| parse_method(&tokens))
430 .collect()
431}
432
433fn parse_method(tokens: &[TokenTree]) -> Result<Method, String> {
435 if tokens.is_empty() {
436 return Err("empty method definition".into());
437 }
438
439 let fn_index = tokens
440 .iter()
441 .position(|token| is_ident(token, "fn"))
442 .ok_or_else(|| "client methods must use `fn`".to_string())?;
443
444 let mut leading = tokens[..fn_index].to_vec();
445 let is_async = matches!(
446 leading.last(),
447 Some(TokenTree::Ident(ident)) if ident.to_string() == "async"
448 );
449 if is_async {
450 leading.pop();
451 }
452
453 let visibility = tokens_to_string(&leading);
454 let name = ident_at(tokens, fn_index + 1, "a method name")?;
455 let arguments_group = match tokens.get(fn_index + 2) {
456 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis => {
457 group.stream()
458 }
459 _ => return Err(format!("method `{name}` is missing its argument list")),
460 };
461
462 let rest = &tokens[fn_index + 3..];
463 if !matches!(rest.first(), Some(TokenTree::Punct(punct)) if punct.as_char() == '-')
464 || !matches!(rest.get(1), Some(TokenTree::Punct(punct)) if punct.as_char() == '>')
465 {
466 return Err(format!("method `{name}` is missing `->`"));
467 }
468
469 let eq_index = rest
470 .iter()
471 .position(|token| matches!(token, TokenTree::Punct(punct) if punct.as_char() == '='));
472 let return_tokens = match eq_index {
473 Some(index) => &rest[2..index],
474 None => &rest[2..],
475 };
476 if return_tokens.is_empty() {
477 return Err(format!("method `{name}` is missing a return type"));
478 }
479
480 let implementation = eq_index.map(|index| tokens_to_string(&rest[index + 1..]));
481 let arguments = parse_arguments(arguments_group)?;
482 if arguments.len() > 4 {
483 return Err(format!(
484 "method `{name}` has {} arguments, but only up to 4 are supported right now",
485 arguments.len()
486 ));
487 }
488
489 Ok(Method {
490 name,
491 visibility,
492 arguments,
493 return_ty: tokens_to_string(return_tokens),
494 implementation,
495 is_async,
496 })
497}
498
499fn parse_arguments(stream: TokenStream) -> Result<Vec<Argument>, String> {
501 split_top_level(stream, ',')
502 .into_iter()
503 .map(|tokens| {
504 let colon_index = tokens
505 .iter()
506 .position(
507 |token| matches!(token, TokenTree::Punct(punct) if punct.as_char() == ':'),
508 )
509 .ok_or_else(|| "expected arguments to look like `name: Type`".to_string())?;
510
511 let name = tokens[..colon_index]
512 .iter()
513 .rev()
514 .find_map(|token| match token {
515 TokenTree::Ident(ident) => Some(ident.to_string()),
516 _ => None,
517 })
518 .ok_or_else(|| "expected an argument name".to_string())?;
519
520 let ty = tokens_to_string(&tokens[colon_index + 1..]);
521 if ty.is_empty() {
522 return Err("expected an argument type".into());
523 }
524
525 Ok(Argument { name, ty })
526 })
527 .collect()
528}
529
530fn derive_depends_impl(input: TokenStream) -> Result<TokenStream, String> {
532 let mut tokens = input.into_iter().peekable();
533
534 while let Some(token) = tokens.next() {
535 if is_ident(&token, "struct") {
536 return expand_struct(tokens);
537 }
538 }
539
540 Err("Depends can only be derived for structs".into())
541}
542
543fn expand_struct<I>(mut tokens: I) -> Result<TokenStream, String>
545where
546 I: Iterator<Item = TokenTree>,
547{
548 let name = match tokens.next() {
549 Some(TokenTree::Ident(ident)) => ident,
550 _ => return Err("expected a struct name".into()),
551 };
552
553 let fields_group = match tokens.next() {
554 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
555 Some(_) => return Err("Depends does not support generics or where clauses yet".into()),
556 None => return Err("expected a braced struct body".into()),
557 };
558
559 let fields = parse_fields(fields_group.stream())?;
560 let initializers = fields
561 .into_iter()
562 .map(|field| {
563 if field.injected {
564 format!("{}: ::clients::get::<{}>()", field.name, field.ty)
565 } else {
566 format!("{}: ::core::default::Default::default()", field.name)
567 }
568 })
569 .collect::<Vec<_>>()
570 .join(", ");
571
572 let output = format!(
573 "impl ::core::default::Default for {name} {{
574 fn default() -> Self {{
575 Self {{ {initializers} }}
576 }}
577 }}
578
579 impl {name} {{
580 #[doc = \"Constructs `Self` by resolving every `#[dep]` field from the dependency system and initializing all other fields with `Default::default()`.\"]
581 pub fn from_deps() -> Self {{
582 ::core::default::Default::default()
583 }}
584 }}",
585 );
586
587 output
588 .parse::<TokenStream>()
589 .map_err(|error| error.to_string())
590}
591
/// A named field of a struct deriving `Depends`.
///
/// Fields are listed in source order: attributes come first, then `name: Type`.
struct Field {
    /// Whether the field carries a `#[dep]` attribute and should be resolved through
    /// `::clients::get` instead of `Default::default()`.
    injected: bool,
    /// Field identifier.
    name: String,
    /// Source text of the field type.
    ty: String,
}
601
602fn parse_fields(stream: TokenStream) -> Result<Vec<Field>, String> {
604 split_top_level(stream, ',')
605 .into_iter()
606 .map(|tokens| parse_field(&tokens))
607 .collect()
608}
609
610fn parse_field(tokens: &[TokenTree]) -> Result<Field, String> {
612 let mut injected = false;
613 let mut colon_index = None;
614
615 for (index, token) in tokens.iter().enumerate() {
616 if matches_dep_attribute(tokens, index) {
617 injected = true;
618 }
619
620 if let TokenTree::Punct(punct) = token
621 && punct.as_char() == ':'
622 {
623 colon_index = Some(index);
624 break;
625 }
626 }
627
628 let colon_index = colon_index.ok_or_else(|| "expected a named struct field".to_string())?;
629
630 let name = tokens[..colon_index]
631 .iter()
632 .rev()
633 .find_map(|token| match token {
634 TokenTree::Ident(ident) => Some(ident.to_string()),
635 _ => None,
636 })
637 .ok_or_else(|| "expected a field name".to_string())?;
638
639 let ty_tokens = tokens[colon_index + 1..]
640 .iter()
641 .cloned()
642 .collect::<TokenStream>();
643 if ty_tokens.is_empty() {
644 return Err("expected a field type".into());
645 }
646
647 Ok(Field {
648 name,
649 ty: ty_tokens.to_string(),
650 injected,
651 })
652}
653
/// Splits `stream` into the token runs between top-level `separator` punctuation.
///
/// Groups (`(...)`, `[...]`, `{...}`) arrive as single tokens, so separators inside them
/// are never seen.
///
/// Commas can also appear inside generic argument lists. When splitting on `,`, the
/// function therefore tracks `<`/`>` nesting and only splits at depth 0. A `>` that
/// completes a `->` arrow (preceded by a joint `-`, as in `fn(u8) -> u8`) is not a closing
/// angle bracket and leaves the depth unchanged.
///
/// Angle tracking is skipped for every other separator. A `;` can never sit at the top
/// level of a generic argument list, whereas implementation expressions (e.g.
/// `= |n| n < 10`) may contain comparison operators. Such an operator would otherwise look
/// like an unclosed `<` and swallow the following separator.
///
/// Only separators with `Spacing::Alone` split. Empty runs, such as the one after a
/// trailing separator, are dropped.
fn split_top_level(stream: TokenStream, separator: char) -> Vec<Vec<TokenTree>> {
    let track_angles = separator == ',';
    let mut items = Vec::new();
    let mut current = Vec::new();
    let mut angle_depth = 0usize;
    // True when the previous token was a `-` joined to the next punctuation, i.e. the
    // first half of a `->` arrow.
    let mut after_joint_minus = false;

    for token in stream {
        let mut next_after_joint_minus = false;

        if let TokenTree::Punct(punct) = &token {
            let ch = punct.as_char();

            if ch == separator && punct.spacing() == Spacing::Alone && angle_depth == 0 {
                if !current.is_empty() {
                    items.push(current);
                    current = Vec::new();
                }
                after_joint_minus = false;
                continue;
            }

            if track_angles {
                match ch {
                    '<' => angle_depth += 1,
                    '>' if !after_joint_minus => angle_depth = angle_depth.saturating_sub(1),
                    _ => {}
                }
            }

            next_after_joint_minus = ch == '-' && punct.spacing() == Spacing::Joint;
        }

        after_joint_minus = next_after_joint_minus;
        current.push(token);
    }

    if !current.is_empty() {
        items.push(current);
    }

    items
}
699
/// Returns the identifier at `tokens[index]` as a string.
///
/// # Errors
/// Returns `"expected {expected}"` when the index is out of range or the token there is not
/// an identifier.
fn ident_at(tokens: &[TokenTree], index: usize, expected: &str) -> Result<String, String> {
    if let Some(TokenTree::Ident(ident)) = tokens.get(index) {
        Ok(ident.to_string())
    } else {
        Err(format!("expected {expected}"))
    }
}
707
/// Reports whether an outer attribute whose first path segment is `dep` starts at
/// `tokens[index]`.
///
/// This matches `#[dep]` as well as forms such as `#[dep(...)]`. The check requires a `#`
/// punct at `index`, immediately followed by a `[...]` group whose first token is the
/// identifier `dep`.
fn matches_dep_attribute(tokens: &[TokenTree], index: usize) -> bool {
    match tokens.get(index..) {
        Some([TokenTree::Punct(pound), TokenTree::Group(group), ..])
            if pound.as_char() == '#' && group.delimiter() == Delimiter::Bracket =>
        {
            matches!(
                group.stream().into_iter().next(),
                Some(TokenTree::Ident(ident)) if ident.to_string() == "dep"
            )
        }
        _ => false,
    }
}
728
/// Reports whether `token` is an identifier spelled exactly `expected`.
fn is_ident(token: &TokenTree, expected: &str) -> bool {
    if let TokenTree::Ident(ident) = token {
        ident.to_string() == expected
    } else {
        false
    }
}
733
/// Renders a slice of tokens back into source text.
///
/// The tokens are reassembled into a `TokenStream` and printed with its `Display` impl,
/// which honours each `Punct`'s `Spacing`. This matches how `parse_field` renders field
/// types.
///
/// Printing tokens one by one and joining them with spaces would split multi-character
/// operators: `::` becomes `: :`, `->` becomes `- >`, `'a` becomes `' a`, and `&&` becomes
/// `& &`. Paths, fn-pointer types and lifetimes in argument/return types would then fail to
/// re-parse, and operators in implementation expressions could change meaning.
///
/// Returns an empty string for an empty slice.
fn tokens_to_string(tokens: &[TokenTree]) -> String {
    tokens.iter().cloned().collect::<TokenStream>().to_string()
}
742
/// Returns `value` followed by a single space, or an empty string when `value` is empty.
///
/// This lets optional prefixes (visibility, attributes) be spliced directly in front of a
/// keyword.
fn with_trailing_space(value: &str) -> String {
    let mut prefixed = String::from(value);
    if !prefixed.is_empty() {
        prefixed.push(' ');
    }
    prefixed
}
751
/// Separator to place between `&self` and a rendered argument list.
///
/// Returns `", "` when `value` is non-empty and nothing otherwise.
fn maybe_comma(value: &str) -> &'static str {
    if !value.is_empty() {
        return ", ";
    }
    ""
}
757
/// Converts an expansion error message into tokens that report it as a compiler error.
///
/// The message is embedded through its `Debug` representation, which is a correctly
/// escaped string literal. The macro is invoked by its absolute `::core` path, like the
/// other `::core::`/`::clients::` paths in generated code, so a user-defined
/// `compile_error!` in scope at the call site cannot shadow the built-in one.
fn compile_error(message: String) -> TokenStream {
    format!("::core::compile_error!({message:?});")
        .parse()
        .expect("compile_error! should parse")
}