Skip to main content

clients_macros/
lib.rs

1//! Procedural macros that power the public `clients` API.
2//!
3//! The runtime crate intentionally keeps parsing dependencies small and local, so
4//! this crate hand-rolls just enough token inspection to support the user-facing
5//! `client!` and `#[derive(Depends)]` syntaxes without pulling in a larger macro
6//! parsing stack.
7
8use proc_macro::{Delimiter, Spacing, TokenStream, TokenTree};
9
10/// Declares a concrete dependency client backed by raw function pointers.
11///
12/// The generated type is a plain `struct`, not a trait object or a trait-based
13/// abstraction. Each declared method is stored as a function pointer field, and
14/// the generated method bodies simply call through to those pointers.
15///
16/// The macro also generates a helper module whose name comes after `as`, along
17/// with one nested helper module per declared method. Those helper modules are
18/// what power [`clients::deps!`](https://docs.rs/clients/latest/clients/macro.deps.html)
19/// and [`clients::test_deps!`](https://docs.rs/clients/latest/clients/macro.test_deps.html).
20///
21/// A method may provide a live implementation directly:
22///
23/// ```ignore
24/// client! {
25///     pub struct Clock as clock {
26///         pub fn now_millis() -> u64 = || 1234;
27///     }
28/// }
29/// ```
30///
31/// Or it may leave the live implementation unspecified, causing calls to panic
32/// unless a test override supplies one:
33///
34/// ```ignore
35/// client! {
36///     pub struct UserClient as user_client {
37///         pub fn fetch_user(id: u64) -> Result<User, DependencyError>;
38///     }
39/// }
40/// ```
41///
42/// Async methods are supported directly:
43///
44/// ```ignore
45/// client! {
46///     pub struct AsyncClock as async_clock {
47///         pub async fn now_millis() -> u64 = || async { 1234 };
48///     }
49/// }
50/// ```
51///
52/// Current limitations:
53///
54/// - method implementations must be non-capturing closures or function items
55/// - at most 4 method arguments are supported
56#[proc_macro]
57pub fn client(input: TokenStream) -> TokenStream {
58    match expand_client(input) {
59        Ok(stream) => stream,
60        Err(message) => compile_error(message),
61    }
62}
63
64/// Derives dependency-backed construction for a simple braced struct.
65///
66/// Fields marked with `#[dep]` are initialized with `::clients::get::<FieldType>()`.
67/// All other fields are initialized with `Default::default()`.
68///
69/// The derive generates:
70///
71/// - an implementation of `Default`
72/// - a `from_deps() -> Self` convenience constructor
73///
74/// Example:
75///
76/// ```ignore
77/// #[derive(Depends)]
78/// struct Greeter {
79///     #[dep]
80///     user_client: UserClient,
81///     #[dep]
82///     clock: Clock,
83///     greeting_prefix: String,
84/// }
85///
86/// let greeter = Greeter::from_deps();
87/// ```
88///
89/// Current limitations:
90///
91/// - only braced structs are supported
92/// - generics and where-clauses are not currently supported
93#[proc_macro_derive(Depends, attributes(dep))]
94pub fn derive_depends(input: TokenStream) -> TokenStream {
95    match derive_depends_impl(input) {
96        Ok(stream) => stream,
97        Err(message) => compile_error(message),
98    }
99}
100
101/// Expands a `client!` invocation into a concrete client struct plus helper
102/// modules for runtime lookup and test overrides.
103fn expand_client(input: TokenStream) -> Result<TokenStream, String> {
104    let tokens = input.into_iter().collect::<Vec<_>>();
105    let struct_index = tokens
106        .iter()
107        .position(|token| is_ident(token, "struct"))
108        .ok_or_else(|| "client! expects `struct`".to_string())?;
109
110    let visibility = tokens_to_string(&tokens[..struct_index]);
111    let name = ident_at(&tokens, struct_index + 1, "a client name")?;
112
113    if !matches!(tokens.get(struct_index + 2), Some(token) if is_ident(token, "as")) {
114        return Err("client! expects `as <module_name>` after the struct name".into());
115    }
116
117    let module = ident_at(&tokens, struct_index + 3, "a module name after `as`")?;
118    let body = match tokens.get(struct_index + 4) {
119        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group.stream(),
120        _ => return Err("client! expects a braced method body".into()),
121    };
122
123    if tokens.len() != struct_index + 5 {
124        return Err("unexpected tokens after the client body".into());
125    }
126
127    let methods = parse_methods(body)?;
128    if methods.is_empty() {
129        return Err("client! requires at least one method".into());
130    }
131
132    let visibility_prefix = with_trailing_space(&visibility);
133    let field_lines = methods
134        .iter()
135        .map(Method::render_field)
136        .collect::<Vec<_>>()
137        .join("\n");
138    let method_lines = methods
139        .iter()
140        .map(Method::render_method)
141        .collect::<Vec<_>>()
142        .join("\n\n");
143    let live_lines = methods
144        .iter()
145        .map(|method| format!("{}: {}", method.name, method.render_live_initializer(&name)))
146        .collect::<Vec<_>>()
147        .join(",\n                    ");
148    let module_lines = methods
149        .iter()
150        .map(|method| method.render_module(&name))
151        .collect::<Vec<_>>()
152        .join("\n");
153
154    let output = format!(
155        "#[derive(Clone, Copy)]
156        {visibility_prefix}struct {name} {{
157            {field_lines}
158        }}
159
160        impl {name} {{
161            {method_lines}
162        }}
163
164        impl ::clients::Dependency for {name} {{
165            fn live() -> Self {{
166                Self {{
167                    {live_lines}
168                }}
169            }}
170        }}
171
172        impl ::core::default::Default for {name} {{
173            fn default() -> Self {{
174                <Self as ::clients::Dependency>::live()
175            }}
176        }}
177
178        {visibility_prefix}mod {module} {{
179            use super::*;
180
181            pub fn get() -> super::{name} {{
182                ::clients::get::<super::{name}>()
183            }}
184
185            {module_lines}
186        }}"
187    );
188
189    output
190        .parse::<TokenStream>()
191        .map_err(|error| error.to_string())
192}
193
194/// Internal representation of one declared client method.
195#[derive(Clone)]
196struct Method {
197    /// The Rust identifier used for both the field and the wrapper method.
198    name: String,
199    /// Any leading method tokens that need to be replayed, such as `pub` or
200    /// attributes that were attached to the method declaration.
201    visibility: String,
202    /// The parsed argument list in declaration order.
203    arguments: Vec<Argument>,
204    /// The declared return type as source tokens.
205    return_ty: String,
206    /// The optional live implementation expression.
207    implementation: Option<String>,
208    /// Whether the declared method was marked `async`.
209    is_async: bool,
210}
211
/// One `name: Type` entry from a client method's argument list.
#[derive(Clone)]
struct Argument {
    /// Binding name, reused when the wrapper forwards the call.
    name: String,
    /// Declared type, rendered as source text.
    ty: String,
}
220
/// Source-rendering helpers for one parsed client method.
///
/// Each `render_*` method returns a Rust source fragment. `expand_client`
/// splices these fragments into one `format!` string and re-parses the result
/// into a `TokenStream`, so the whitespace inside these strings is cosmetic.
impl Method {
    /// Returns the number of arguments declared by the method.
    ///
    /// `parse_method` rejects declarations with more than 4 arguments, so this
    /// is at most 4 for any `Method` it produced.
    fn arity(&self) -> usize {
        self.arguments.len()
    }

    /// Chooses the correct runtime eraser helper for the method shape.
    ///
    /// The path has the form `::clients::erase_{async|sync}_{arity}`, e.g.
    /// `::clients::erase_sync_1` for a one-argument synchronous method.
    /// NOTE(review): this assumes the runtime crate exports one eraser per
    /// supported arity (0..=4) — confirm against the `clients` crate.
    fn eraser_name(&self) -> String {
        if self.is_async {
            format!("::clients::erase_async_{}", self.arity())
        } else {
            format!("::clients::erase_sync_{}", self.arity())
        }
    }

    /// Renders the argument list in declaration form, for example
    /// `id: u64, name: String`. Empty when the method takes no arguments.
    fn args_decl(&self) -> String {
        self.arguments
            .iter()
            .map(|argument| format!("{}: {}", argument.name, argument.ty))
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Renders only the argument types in declaration order, for example
    /// `u64, String`. Used for `fn(...)` pointer types and `Fn(...)` bounds.
    fn args_types(&self) -> String {
        self.arguments
            .iter()
            .map(|argument| argument.ty.clone())
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Renders only the argument names in declaration order, for example
    /// `id, name`. Used as the call arguments when forwarding.
    fn args_names(&self) -> String {
        self.arguments
            .iter()
            .map(|argument| argument.name.clone())
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Renders the function-pointer return type used by the stored field.
    ///
    /// Async methods store a pointer returning `::clients::BoxFuture<R>`;
    /// sync methods store a pointer returning the declared type `R` directly.
    fn fn_pointer_return(&self) -> String {
        if self.is_async {
            format!("::clients::BoxFuture<{}>", self.return_ty)
        } else {
            self.return_ty.clone()
        }
    }

    /// Renders the struct field that stores the underlying function pointer,
    /// e.g. `fetch_user: fn(u64) -> Result<User, DependencyError>,`.
    fn render_field(&self) -> String {
        format!(
            "{}: fn({}) -> {},",
            self.name,
            self.args_types(),
            self.fn_pointer_return()
        )
    }

    /// Renders the ergonomic wrapper method that forwards to the stored
    /// function pointer.
    ///
    /// The wrapper always takes `&self` first, followed by the declared
    /// arguments (`maybe_comma` inserts the separating `, ` only when there
    /// are any). Async wrappers are `async fn`s that `.await` the boxed future
    /// returned by the stored pointer.
    fn render_method(&self) -> String {
        let visibility = with_trailing_space(&self.visibility);
        let args_decl = self.args_decl();
        let call_args = self.args_names();

        if self.is_async {
            format!(
                "{visibility}async fn {}(&self{}{}) -> {} {{
                (self.{})({}).await
            }}",
                self.name,
                maybe_comma(&args_decl),
                args_decl,
                self.return_ty,
                self.name,
                call_args,
            )
        } else {
            format!(
                "{visibility}fn {}(&self{}{}) -> {} {{
                (self.{})({})
            }}",
                self.name,
                maybe_comma(&args_decl),
                args_decl,
                self.return_ty,
                self.name,
                call_args,
            )
        }
    }

    /// Renders the live initializer for the function-pointer field.
    ///
    /// When a live implementation is present, this selects the correct eraser
    /// helper. Otherwise it generates a panic-based placeholder used to surface
    /// missing live implementations with a readable dependency path.
    ///
    /// The placeholder is a block expression that declares a local
    /// `fn __dep_unimplemented` with the method's exact signature and then
    /// evaluates to that function item, so it can initialize the
    /// function-pointer field. Its body calls
    /// `::clients::unimplemented_dependency("Client.method")`. For async
    /// methods, that call is wrapped in `::clients::boxed(async move { .. })`
    /// to match the `BoxFuture` field type.
    fn render_live_initializer(&self, client_name: &str) -> String {
        if let Some(implementation) = &self.implementation {
            format!("{}({implementation})", self.eraser_name())
        } else if self.is_async {
            format!(
                "{{
                        fn __dep_unimplemented({}) -> ::clients::BoxFuture<{}> {{
                            ::clients::boxed(async move {{
                                ::clients::unimplemented_dependency(\"{}.{}\")
                            }})
                        }}

                        __dep_unimplemented
                    }}",
                self.args_decl(),
                self.return_ty,
                client_name,
                self.name,
            )
        } else {
            format!(
                "{{
                        fn __dep_unimplemented({}) -> {} {{
                            ::clients::unimplemented_dependency(\"{}.{}\")
                        }}

                        __dep_unimplemented
                    }}",
                self.args_decl(),
                self.return_ty,
                client_name,
                self.name,
            )
        }
    }

    /// Renders the per-method helper module used by `deps!` and `test_deps!`.
    ///
    /// The module is named after the method and contains:
    /// - `get()`: returns the stored function pointer from the client resolved
    ///   by the parent module's `get()`.
    /// - `override_with(builder, implementation)`: registers a replacement
    ///   through `OverrideBuilder::update`, converting `implementation` with
    ///   the same eraser helper used for live implementations.
    ///
    /// Paths use `super::super::{client_name}` because this module is nested
    /// two levels below the module that declares the client struct.
    fn render_module(&self, client_name: &str) -> String {
        let args_types = self.args_types();
        let fn_pointer_return = self.fn_pointer_return();
        let eraser = self.eraser_name();

        if self.is_async {
            format!(
                "pub mod {} {{
                    use super::*;

                    pub fn get() -> fn({}) -> {} {{
                        super::get().{}
                    }}

                    pub fn override_with<F, Fut>(builder: &mut ::clients::OverrideBuilder, implementation: F)
                    where
                        F: Fn({}) -> Fut + Copy + 'static,
                        Fut: ::core::future::Future<Output = {}> + Send + 'static,
                    {{
                        builder.update::<super::super::{client_name}, _>(|mut dependency| {{
                            dependency.{} = {}(implementation);
                            dependency
                        }});
                    }}
                }}",
                self.name,
                args_types,
                fn_pointer_return,
                self.name,
                args_types,
                self.return_ty,
                self.name,
                eraser,
            )
        } else {
            format!(
                "pub mod {} {{
                    use super::*;

                    pub fn get() -> fn({}) -> {} {{
                        super::get().{}
                    }}

                    pub fn override_with<F>(builder: &mut ::clients::OverrideBuilder, implementation: F)
                    where
                        F: Fn({}) -> {} + Copy + 'static,
                    {{
                        builder.update::<super::super::{client_name}, _>(|mut dependency| {{
                            dependency.{} = {}(implementation);
                            dependency
                        }});
                    }}
                }}",
                self.name,
                args_types,
                fn_pointer_return,
                self.name,
                args_types,
                self.return_ty,
                self.name,
                eraser,
            )
        }
    }
}
424
425/// Parses a method list from the body of a `client!` declaration.
426fn parse_methods(stream: TokenStream) -> Result<Vec<Method>, String> {
427    split_top_level(stream, ';')
428        .into_iter()
429        .map(|tokens| parse_method(&tokens))
430        .collect()
431}
432
433/// Parses one declared client method.
434fn parse_method(tokens: &[TokenTree]) -> Result<Method, String> {
435    if tokens.is_empty() {
436        return Err("empty method definition".into());
437    }
438
439    let fn_index = tokens
440        .iter()
441        .position(|token| is_ident(token, "fn"))
442        .ok_or_else(|| "client methods must use `fn`".to_string())?;
443
444    let mut leading = tokens[..fn_index].to_vec();
445    let is_async = matches!(
446        leading.last(),
447        Some(TokenTree::Ident(ident)) if ident.to_string() == "async"
448    );
449    if is_async {
450        leading.pop();
451    }
452
453    let visibility = tokens_to_string(&leading);
454    let name = ident_at(tokens, fn_index + 1, "a method name")?;
455    let arguments_group = match tokens.get(fn_index + 2) {
456        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis => {
457            group.stream()
458        }
459        _ => return Err(format!("method `{name}` is missing its argument list")),
460    };
461
462    let rest = &tokens[fn_index + 3..];
463    if !matches!(rest.first(), Some(TokenTree::Punct(punct)) if punct.as_char() == '-')
464        || !matches!(rest.get(1), Some(TokenTree::Punct(punct)) if punct.as_char() == '>')
465    {
466        return Err(format!("method `{name}` is missing `->`"));
467    }
468
469    let eq_index = rest
470        .iter()
471        .position(|token| matches!(token, TokenTree::Punct(punct) if punct.as_char() == '='));
472    let return_tokens = match eq_index {
473        Some(index) => &rest[2..index],
474        None => &rest[2..],
475    };
476    if return_tokens.is_empty() {
477        return Err(format!("method `{name}` is missing a return type"));
478    }
479
480    let implementation = eq_index.map(|index| tokens_to_string(&rest[index + 1..]));
481    let arguments = parse_arguments(arguments_group)?;
482    if arguments.len() > 4 {
483        return Err(format!(
484            "method `{name}` has {} arguments, but only up to 4 are supported right now",
485            arguments.len()
486        ));
487    }
488
489    Ok(Method {
490        name,
491        visibility,
492        arguments,
493        return_ty: tokens_to_string(return_tokens),
494        implementation,
495        is_async,
496    })
497}
498
499/// Parses a parenthesized argument list into named arguments.
500fn parse_arguments(stream: TokenStream) -> Result<Vec<Argument>, String> {
501    split_top_level(stream, ',')
502        .into_iter()
503        .map(|tokens| {
504            let colon_index = tokens
505                .iter()
506                .position(
507                    |token| matches!(token, TokenTree::Punct(punct) if punct.as_char() == ':'),
508                )
509                .ok_or_else(|| "expected arguments to look like `name: Type`".to_string())?;
510
511            let name = tokens[..colon_index]
512                .iter()
513                .rev()
514                .find_map(|token| match token {
515                    TokenTree::Ident(ident) => Some(ident.to_string()),
516                    _ => None,
517                })
518                .ok_or_else(|| "expected an argument name".to_string())?;
519
520            let ty = tokens_to_string(&tokens[colon_index + 1..]);
521            if ty.is_empty() {
522                return Err("expected an argument type".into());
523            }
524
525            Ok(Argument { name, ty })
526        })
527        .collect()
528}
529
530/// Parses the `Depends` derive input and routes to struct expansion.
531fn derive_depends_impl(input: TokenStream) -> Result<TokenStream, String> {
532    let mut tokens = input.into_iter().peekable();
533
534    while let Some(token) = tokens.next() {
535        if is_ident(&token, "struct") {
536            return expand_struct(tokens);
537        }
538    }
539
540    Err("Depends can only be derived for structs".into())
541}
542
543/// Expands a supported struct declaration into `Default` and `from_deps()`.
544fn expand_struct<I>(mut tokens: I) -> Result<TokenStream, String>
545where
546    I: Iterator<Item = TokenTree>,
547{
548    let name = match tokens.next() {
549        Some(TokenTree::Ident(ident)) => ident,
550        _ => return Err("expected a struct name".into()),
551    };
552
553    let fields_group = match tokens.next() {
554        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
555        Some(_) => return Err("Depends does not support generics or where clauses yet".into()),
556        None => return Err("expected a braced struct body".into()),
557    };
558
559    let fields = parse_fields(fields_group.stream())?;
560    let initializers = fields
561        .into_iter()
562        .map(|field| {
563            if field.injected {
564                format!("{}: ::clients::get::<{}>()", field.name, field.ty)
565            } else {
566                format!("{}: ::core::default::Default::default()", field.name)
567            }
568        })
569        .collect::<Vec<_>>()
570        .join(", ");
571
572    let output = format!(
573        "impl ::core::default::Default for {name} {{
574            fn default() -> Self {{
575                Self {{ {initializers} }}
576            }}
577        }}
578
579        impl {name} {{
580            #[doc = \"Constructs `Self` by resolving every `#[dep]` field from the dependency system and initializing all other fields with `Default::default()`.\"]
581            pub fn from_deps() -> Self {{
582                ::core::default::Default::default()
583            }}
584        }}",
585    );
586
587    output
588        .parse::<TokenStream>()
589        .map_err(|error| error.to_string())
590}
591
/// One named field of a struct that derives `Depends`.
struct Field {
    /// `true` when the field is annotated with `#[dep]`.
    injected: bool,
    /// Field identifier.
    name: String,
    /// Field type, rendered as source text.
    ty: String,
}
601
602/// Parses a comma-separated field list from a braced struct body.
603fn parse_fields(stream: TokenStream) -> Result<Vec<Field>, String> {
604    split_top_level(stream, ',')
605        .into_iter()
606        .map(|tokens| parse_field(&tokens))
607        .collect()
608}
609
610/// Parses one struct field and detects the `#[dep]` marker attribute.
611fn parse_field(tokens: &[TokenTree]) -> Result<Field, String> {
612    let mut injected = false;
613    let mut colon_index = None;
614
615    for (index, token) in tokens.iter().enumerate() {
616        if matches_dep_attribute(tokens, index) {
617            injected = true;
618        }
619
620        if let TokenTree::Punct(punct) = token
621            && punct.as_char() == ':'
622        {
623            colon_index = Some(index);
624            break;
625        }
626    }
627
628    let colon_index = colon_index.ok_or_else(|| "expected a named struct field".to_string())?;
629
630    let name = tokens[..colon_index]
631        .iter()
632        .rev()
633        .find_map(|token| match token {
634            TokenTree::Ident(ident) => Some(ident.to_string()),
635            _ => None,
636        })
637        .ok_or_else(|| "expected a field name".to_string())?;
638
639    let ty_tokens = tokens[colon_index + 1..]
640        .iter()
641        .cloned()
642        .collect::<TokenStream>();
643    if ty_tokens.is_empty() {
644        return Err("expected a field type".into());
645    }
646
647    Ok(Field {
648        name,
649        ty: ty_tokens.to_string(),
650        injected,
651    })
652}
653
/// Splits a token stream at a top-level punctuation separator.
///
/// Token groups already isolate parentheses, brackets, and braces. Generic
/// type arguments, however, are plain `<`/`>` punctuation. When splitting on
/// `,`, this helper tracks angle-bracket nesting so that commas inside types
/// like `Vec<Result<T, E>>` stay together. The `>` of an arrow (`->`, `=>`)
/// is not treated as a closing bracket, so `HashMap<Box<dyn Fn() -> u8>, u8>`
/// nests correctly.
///
/// For any other separator (such as the `;` between client methods), angle
/// brackets are ignored. Those separators never occur inside generic
/// arguments outside a delimited group, and comparison operators in method
/// implementations (`|x| x < 10`) must not suppress later splits.
///
/// Empty segments, such as those produced by a trailing separator, are dropped.
fn split_top_level(stream: TokenStream, separator: char) -> Vec<Vec<TokenTree>> {
    let track_angles = separator == ',';
    let mut items = Vec::new();
    let mut current: Vec<TokenTree> = Vec::new();
    let mut angle_depth = 0usize;

    for token in stream {
        if let TokenTree::Punct(punct) = &token {
            let ch = punct.as_char();

            if ch == separator && punct.spacing() == Spacing::Alone && angle_depth == 0 {
                if !current.is_empty() {
                    items.push(std::mem::take(&mut current));
                }
                continue;
            }

            if track_angles {
                // `token` has not been pushed yet, so `current.last()` is the
                // token immediately before it.
                let ends_arrow = matches!(
                    current.last(),
                    Some(TokenTree::Punct(prev))
                        if matches!(prev.as_char(), '-' | '=') && prev.spacing() == Spacing::Joint
                );
                match ch {
                    '<' => angle_depth += 1,
                    '>' if !ends_arrow => angle_depth = angle_depth.saturating_sub(1),
                    _ => {}
                }
            }
        }

        current.push(token);
    }

    if !current.is_empty() {
        items.push(current);
    }

    items
}
699
/// Returns the identifier found at `tokens[index]` as a string.
///
/// # Errors
///
/// Yields `"expected {expected}"` when `index` is out of range or the token
/// there is not an identifier.
fn ident_at(tokens: &[TokenTree], index: usize, expected: &str) -> Result<String, String> {
    if let Some(TokenTree::Ident(ident)) = tokens.get(index) {
        return Ok(ident.to_string());
    }
    Err(format!("expected {expected}"))
}
707
/// Reports whether `tokens[index]` starts a `#[dep ...]` attribute.
///
/// This requires a `#` punct immediately followed by a bracket-delimited
/// group whose first token is the identifier `dep`, so both `#[dep]` and
/// `#[dep(...)]` match.
fn matches_dep_attribute(tokens: &[TokenTree], index: usize) -> bool {
    match (tokens.get(index), tokens.get(index + 1)) {
        (Some(TokenTree::Punct(pound)), Some(TokenTree::Group(group)))
            if pound.as_char() == '#' && group.delimiter() == Delimiter::Bracket =>
        {
            group.stream().into_iter().next().is_some_and(
                |first| matches!(first, TokenTree::Ident(ident) if ident.to_string() == "dep"),
            )
        }
        _ => false,
    }
}
728
/// Reports whether `token` is an identifier whose text equals `expected`.
fn is_ident(token: &TokenTree, expected: &str) -> bool {
    match token {
        TokenTree::Ident(ident) => ident.to_string() == expected,
        _ => false,
    }
}
733
/// Serializes a token slice into a Rust source fragment that re-parses to the
/// same tokens.
///
/// The tokens are collected back into a `TokenStream` before printing, so each
/// `Punct`'s joint spacing is honored. Multi-character operators such as `::`,
/// `->` and `=>`, and lifetimes such as `'static`, are printed without an inner
/// space. Printing each `TokenTree` separately and joining with spaces would
/// produce `std : : io` or `- >`, which no longer parse as a path or an arrow.
///
/// An empty slice yields an empty string.
fn tokens_to_string(tokens: &[TokenTree]) -> String {
    tokens.iter().cloned().collect::<TokenStream>().to_string()
}
742
/// Appends a single space to `value`, unless `value` is empty.
///
/// Used to prefix optional fragments (such as a visibility) so the generated
/// source never begins with a stray space.
fn with_trailing_space(value: &str) -> String {
    let mut spaced = String::with_capacity(value.len() + 1);
    if !value.is_empty() {
        spaced.push_str(value);
        spaced.push(' ');
    }
    spaced
}
751
/// Yields the `", "` separator needed between `&self` and a rendered
/// parameter list, or an empty string when `value` (the parameter list) is
/// empty.
fn maybe_comma(value: &str) -> &'static str {
    if value.is_empty() {
        return "";
    }
    ", "
}
757
/// Builds a `compile_error!` invocation carrying `message`, so that expansion
/// failures surface as ordinary compiler errors.
///
/// The `{:?}` formatting renders the message as an escaped Rust string
/// literal, so quotes or backslashes inside it cannot break the generated code.
fn compile_error(message: String) -> TokenStream {
    let source = format!("compile_error!({:?});", message);
    source.parse().expect("compile_error! should parse")
}