Skip to main content

explainable_macros/
lib.rs

1//! Procedural macro support for the `explainable` crate.
2//!
3//! This crate is an implementation detail. Use the re-exported
4//! `explainable::explainable` attribute macro rather than depending on this
5//! crate directly.
6
7use proc_macro::TokenStream;
8use proc_macro2::TokenStream as TokenStream2;
9use quote::{format_ident, quote};
10use syn::{
11    FnArg, GenericArgument, ItemTrait, Pat, PathArguments, ReturnType, TraitItem, Type,
12    TypeParamBound, WherePredicate, parse_macro_input,
13};
14
15// ─── Helpers ─────────────────────────────────────────────────────────────────
16
17/// Returns `true` when `ty` looks like a `Result`-style type.
18fn looks_like_result(ty: &Type) -> bool {
19    match ty {
20        Type::Path(tp) => tp
21            .path
22            .segments
23            .last()
24            .map(|s| {
25                let name = s.ident.to_string();
26                name == "Result" || name.ends_with("Result")
27            })
28            .unwrap_or(false),
29        _ => false,
30    }
31}
32
33/// Returns `true` when `ty` is a Result-style type whose first generic arg is `()`.
34fn result_ok_is_unit(ty: &Type) -> bool {
35    let Type::Path(tp) = ty else { return false };
36    let Some(seg) = tp.path.segments.last() else {
37        return false;
38    };
39    let PathArguments::AngleBracketed(args) = &seg.arguments else {
40        return false;
41    };
42    let Some(GenericArgument::Type(first_ty)) = args.args.first() else {
43        return false;
44    };
45    matches!(first_ty, Type::Tuple(t) if t.elems.is_empty())
46}
47
48/// Returns `true` when `ty` is a Result-style type whose first generic arg is `Self`.
49fn result_ok_is_self(ty: &Type) -> bool {
50    let Type::Path(tp) = ty else { return false };
51    let Some(seg) = tp.path.segments.last() else {
52        return false;
53    };
54    let PathArguments::AngleBracketed(args) = &seg.arguments else {
55        return false;
56    };
57    let Some(GenericArgument::Type(ok_ty)) = args.args.first() else {
58        return false;
59    };
60    match ok_ty {
61        Type::Path(p) => p.path.is_ident("Self"),
62        _ => false,
63    }
64}
65
66/// Returns `true` when the return type is chainable — i.e. the method can be
67/// included in the `FooExt` / blanket-impl chain.
68///
69/// A method is chainable when its return is one of:
70/// - `Self` — assign directly
71/// - `Result<Self, _>` / `FooResult<Self>` — unwrap and assign
72/// - `()` or `Result<(), _>` — call for side-effect, no assignment
73///
74/// Methods returning a *concrete* type (e.g. `AudioSamples<'static, Self::Sample>`)
75/// that is *not* `Self` cannot be assigned back to `self.inner: T` in the generic
76/// blanket impl and are excluded from the chain.
77fn is_chainable_return(ret: &ReturnType) -> bool {
78    match ret {
79        // `fn foo() { ... }` — void, chainable
80        ReturnType::Default => true,
81        ReturnType::Type(_, ty) => {
82            // Direct `Self`
83            if matches!(ty.as_ref(), Type::Path(p) if p.path.is_ident("Self")) {
84                return true;
85            }
86            // `()` tuple type
87            if matches!(ty.as_ref(), Type::Tuple(t) if t.elems.is_empty()) {
88                return true;
89            }
90            if looks_like_result(ty) {
91                // Result<(), _> or Result<Self, _> — both chainable
92                return result_ok_is_unit(ty) || result_ok_is_self(ty);
93            }
94            false
95        }
96    }
97}
98
99/// Returns `true` when the `FnArg` is a consuming `self` receiver.
100fn is_consuming_receiver(arg: &FnArg) -> bool {
101    matches!(arg, FnArg::Receiver(r) if r.reference.is_none())
102}
103
104/// Walk `ty` recursively and collect every `Self::X` associated-type reference,
105/// appending unique `X` idents to `found`.
106fn collect_self_assoc_in_type(ty: &Type, found: &mut Vec<syn::Ident>) {
107    match ty {
108        Type::Path(tp) if tp.qself.is_none() => {
109            let segs: Vec<_> = tp.path.segments.iter().collect();
110            // Detect bare `Self::X` paths (two segments, first is "Self").
111            if segs.len() == 2 && segs[0].ident == "Self" {
112                let name = segs[1].ident.clone();
113                if !found.iter().any(|i: &syn::Ident| *i == name) {
114                    found.push(name);
115                }
116            }
117            // Recurse into every set of angle-bracketed generic arguments.
118            for seg in &tp.path.segments {
119                if let PathArguments::AngleBracketed(args) = &seg.arguments {
120                    for ga in &args.args {
121                        if let GenericArgument::Type(inner) = ga {
122                            collect_self_assoc_in_type(inner, found);
123                        }
124                    }
125                }
126            }
127        }
128        Type::Reference(r) => collect_self_assoc_in_type(&r.elem, found),
129        Type::Slice(s) => collect_self_assoc_in_type(&s.elem, found),
130        Type::Array(a) => collect_self_assoc_in_type(&a.elem, found),
131        Type::Tuple(t) => t
132            .elems
133            .iter()
134            .for_each(|e| collect_self_assoc_in_type(e, found)),
135        _ => {}
136    }
137}
138
139// ─── Attribute macro ──────────────────────────────────────────────────────────
140
141/// Annotate an operation trait to generate the full explaining scaffolding for
142/// every method in that trait.
143///
144/// # Receiver-type handling
145///
146/// - **`self`** (consuming) — uses `std::mem::replace` to move the value out safely.
147/// - **`&self` / `&mut self`** returning `Self` or `Result<Self, E>` — calls directly.
148/// - **`&mut self`** returning `()` or `Result<(), E>` — calls for side-effect only.
149///
150/// # Associated-type propagation
151///
152/// Method parameters that reference `Self::X` (associated types from a supertrait) are
153/// handled by declaring `type X;` in the generated `FooExt` trait and setting
154/// `type X = T::X;` in the blanket impl. No substitution of method signatures is needed.
155///
156/// # `#[cfg(…)]` propagation
157///
158/// `#[cfg(…)]` attributes on individual trait methods are forwarded to all generated items.
159#[proc_macro_attribute]
160pub fn explainable(_args: TokenStream, input: TokenStream) -> TokenStream {
161    let trait_def = parse_macro_input!(input as ItemTrait);
162    let trait_name = &trait_def.ident;
163    let explain_text_trait_name = format_ident!("{}ExplainText", trait_name);
164    let ext_trait_name = format_ident!("{}Ext", trait_name);
165    let vis = &trait_def.vis;
166
167    let self_methods: Vec<_> = trait_def
168        .items
169        .iter()
170        .filter_map(|item| {
171            if let TraitItem::Fn(f) = item {
172                let has_receiver = f
173                    .sig
174                    .inputs
175                    .first()
176                    .map(|a| matches!(a, FnArg::Receiver(_)))
177                    .unwrap_or(false);
178                // Exclude methods whose return type cannot be assigned back to
179                // `self.inner: T` in the generic blanket impl (e.g. methods that
180                // return a concrete `AudioSamples<'static, Self::Sample>` rather
181                // than `Self`).
182                let chainable = is_chainable_return(&f.sig.output);
183                if has_receiver && chainable {
184                    Some(f)
185                } else {
186                    None
187                }
188            } else {
189                None
190            }
191        })
192        .collect();
193
194    // ── Collect Self::X associated-type refs from all method params ────────────
195
196    let mut assoc_idents: Vec<syn::Ident> = Vec::new();
197    for m in &self_methods {
198        for param in m.sig.inputs.iter() {
199            if let FnArg::Typed(pt) = param {
200                collect_self_assoc_in_type(&pt.ty, &mut assoc_idents);
201            }
202        }
203    }
204
205    // ── Collect bounds on Self::X from the original trait's where clause ───────
206    // Maps assoc ident → list of bounds so we can emit `type X: Bound1 + Bound2;`.
207
208    let where_bounds: Vec<Vec<&TypeParamBound>> = assoc_idents
209        .iter()
210        .map(|name| {
211            let mut bounds: Vec<&TypeParamBound> = Vec::new();
212            if let Some(wc) = &trait_def.generics.where_clause {
213                for pred in &wc.predicates {
214                    if let WherePredicate::Type(pt) = pred {
215                        // Look for `Self::Name: Bound` predicates.
216                        if let Type::Path(tp) = &pt.bounded_ty {
217                            let segs: Vec<_> = tp.path.segments.iter().collect();
218                            if segs.len() == 2 && segs[0].ident == "Self" && &segs[1].ident == name
219                            {
220                                bounds.extend(pt.bounds.iter());
221                            }
222                        }
223                    }
224                }
225            }
226            bounds
227        })
228        .collect();
229
230    // type X: Bound; declarations for FooExt (with doc so missing_docs is satisfied)
231    let ext_assoc_type_decls: Vec<TokenStream2> = assoc_idents
232        .iter()
233        .zip(where_bounds.iter())
234        .map(|(name, bounds)| {
235            let doc = format!("Associated type `{}` forwarded from the domain type.", name);
236            if bounds.is_empty() {
237                quote! {
238                    #[doc = #doc]
239                    type #name;
240                }
241            } else {
242                quote! {
243                    #[doc = #doc]
244                    type #name: #(#bounds)+*;
245                }
246            }
247        })
248        .collect();
249
250    // type X = T::X; definitions for the blanket impl
251    let ext_assoc_type_impls: Vec<TokenStream2> = assoc_idents
252        .iter()
253        .map(|name| quote! { type #name = T::#name; })
254        .collect();
255
256    // ── Companion ExplainText trait ───────────────────────────────────────────
257
258    let explain_text_methods: Vec<TokenStream2> = self_methods
259        .iter()
260        .map(|m| {
261            let method_name = &m.sig.ident;
262            let explain_fn = format_ident!("explain_text_{}", method_name);
263            let cfg_attrs: Vec<_> = m
264                .attrs
265                .iter()
266                .filter(|a| a.path().is_ident("cfg"))
267                .collect();
268            quote! {
269                #(#cfg_attrs)*
270                fn #explain_fn(before: &Self, after: &Self) -> String;
271            }
272        })
273        .collect();
274
275    // ── Extension trait — method signatures ───────────────────────────────────
276
277    let ext_method_sigs: Vec<TokenStream2> = self_methods
278        .iter()
279        .map(|m| {
280            let method_name = &m.sig.ident;
281            let cfg_attrs: Vec<_> = m
282                .attrs
283                .iter()
284                .filter(|a| a.path().is_ident("cfg"))
285                .collect();
286            let non_recv_params: Vec<_> = m
287                .sig
288                .inputs
289                .iter()
290                .filter(|a| !matches!(a, FnArg::Receiver(_)))
291                .collect();
292            quote! {
293                #(#cfg_attrs)*
294                fn #method_name(&mut self, #(#non_recv_params),*) -> &mut Self;
295            }
296        })
297        .collect();
298
299    // ── Blanket impl — method bodies ──────────────────────────────────────────
300
301    let ext_method_impls: Vec<TokenStream2> = self_methods
302        .iter()
303        .map(|m| {
304            let method_name = &m.sig.ident;
305            let explain_fn = format_ident!("explain_text_{}", method_name);
306            let cfg_attrs: Vec<_> = m
307                .attrs
308                .iter()
309                .filter(|a| a.path().is_ident("cfg"))
310                .collect();
311
312            let non_recv_params: Vec<_> = m
313                .sig
314                .inputs
315                .iter()
316                .filter(|a| !matches!(a, FnArg::Receiver(_)))
317                .collect();
318
319            let arg_idents: Vec<_> = non_recv_params
320                .iter()
321                .filter_map(|a| {
322                    if let FnArg::Typed(pt) = a {
323                        if let Pat::Ident(pi) = pt.pat.as_ref() {
324                            Some(&pi.ident)
325                        } else {
326                            None
327                        }
328                    } else {
329                        None
330                    }
331                })
332                .collect();
333
334            let consuming = m
335                .sig
336                .inputs
337                .first()
338                .map(|a| is_consuming_receiver(a))
339                .unwrap_or(false);
340
341            let (is_result, is_void) = match &m.sig.output {
342                ReturnType::Type(_, ty) => {
343                    let r = looks_like_result(ty);
344                    (r, r && result_ok_is_unit(ty))
345                }
346                ReturnType::Default => (false, true),
347            };
348
349            let update_inner = if is_void {
350                if is_result {
351                    quote! { self.inner.#method_name(#(#arg_idents),*).unwrap(); }
352                } else {
353                    quote! { self.inner.#method_name(#(#arg_idents),*); }
354                }
355            } else if consuming {
356                if is_result {
357                    quote! {
358                        let __taken = ::std::mem::replace(&mut self.inner, before.clone());
359                        self.inner = __taken.#method_name(#(#arg_idents),*).unwrap();
360                    }
361                } else {
362                    quote! {
363                        let __taken = ::std::mem::replace(&mut self.inner, before.clone());
364                        self.inner = __taken.#method_name(#(#arg_idents),*);
365                    }
366                }
367            } else if is_result {
368                quote! { self.inner = self.inner.#method_name(#(#arg_idents),*).unwrap(); }
369            } else {
370                quote! { self.inner = self.inner.#method_name(#(#arg_idents),*); }
371            };
372
373            quote! {
374                #(#cfg_attrs)*
375                fn #method_name(&mut self, #(#non_recv_params),*) -> &mut Self {
376                    let before = self.inner.clone();
377                    #update_inner
378                    let text = match self.mode {
379                        ::explainable::ExplainMode::Text
380                        | ::explainable::ExplainMode::Both => Some(
381                            <T as #explain_text_trait_name>::#explain_fn(
382                                &before,
383                                &self.inner,
384                            ),
385                        ),
386                        _ => None,
387                    };
388                    let visual = match self.mode {
389                        ::explainable::ExplainMode::Visual
390                        | ::explainable::ExplainMode::Both => Some(
391                            <T as ::explainable::RenderVisual>::render_visual(
392                                &before,
393                                &self.inner,
394                            ),
395                        ),
396                        _ => None,
397                    };
398                    self.explanations.push(::explainable::Explanation::new(
399                        self.mode,
400                        text,
401                        visual,
402                    ));
403                    self
404                }
405            }
406        })
407        .collect();
408
409    // ── Assemble output ───────────────────────────────────────────────────────
410
411    let explain_text_doc = format!(
412        "Companion text trait generated by `#[explainable]` for [`{}`].\n\n\
413         Implement one `explain_text_<method>` per operation to supply the pedagogical \
414         text explanation shown when that operation runs inside an explaining chain.",
415        trait_name
416    );
417    let ext_trait_doc = format!(
418        "Extension trait generated by `#[explainable]` for [`{}`].\n\n\
419         Bring this into scope to call `{}` operations on an \
420         [`explainable::Explaining`] chain.",
421        trait_name, trait_name
422    );
423
424    let output = quote! {
425        #trait_def
426
427        #[doc = #explain_text_doc]
428        #[allow(missing_docs)]
429        #vis trait #explain_text_trait_name:
430            ::explainable::Explainable + #trait_name
431        {
432            #(#explain_text_methods)*
433        }
434
435        #[doc = #ext_trait_doc]
436        #[allow(missing_docs)]
437        #vis trait #ext_trait_name {
438            #(#ext_assoc_type_decls)*
439            #(#ext_method_sigs)*
440        }
441
442        #[allow(missing_docs)]
443        impl<T> #ext_trait_name for ::explainable::Explaining<T>
444        where
445            T: ::explainable::Explainable + #trait_name + #explain_text_trait_name,
446        {
447            #(#ext_assoc_type_impls)*
448            #(#ext_method_impls)*
449        }
450    };
451
452    output.into()
453}