explainable_macros/
lib.rs1use proc_macro::TokenStream;
8use proc_macro2::TokenStream as TokenStream2;
9use quote::{format_ident, quote};
10use syn::{
11 FnArg, GenericArgument, ItemTrait, Pat, PathArguments, ReturnType, TraitItem, Type,
12 TypeParamBound, WherePredicate, parse_macro_input,
13};
14
15fn looks_like_result(ty: &Type) -> bool {
19 match ty {
20 Type::Path(tp) => tp
21 .path
22 .segments
23 .last()
24 .map(|s| {
25 let name = s.ident.to_string();
26 name == "Result" || name.ends_with("Result")
27 })
28 .unwrap_or(false),
29 _ => false,
30 }
31}
32
33fn result_ok_is_unit(ty: &Type) -> bool {
35 let Type::Path(tp) = ty else { return false };
36 let Some(seg) = tp.path.segments.last() else {
37 return false;
38 };
39 let PathArguments::AngleBracketed(args) = &seg.arguments else {
40 return false;
41 };
42 let Some(GenericArgument::Type(first_ty)) = args.args.first() else {
43 return false;
44 };
45 matches!(first_ty, Type::Tuple(t) if t.elems.is_empty())
46}
47
48fn result_ok_is_self(ty: &Type) -> bool {
50 let Type::Path(tp) = ty else { return false };
51 let Some(seg) = tp.path.segments.last() else {
52 return false;
53 };
54 let PathArguments::AngleBracketed(args) = &seg.arguments else {
55 return false;
56 };
57 let Some(GenericArgument::Type(ok_ty)) = args.args.first() else {
58 return false;
59 };
60 match ok_ty {
61 Type::Path(p) => p.path.is_ident("Self"),
62 _ => false,
63 }
64}
65
66fn is_chainable_return(ret: &ReturnType) -> bool {
78 match ret {
79 ReturnType::Default => true,
81 ReturnType::Type(_, ty) => {
82 if matches!(ty.as_ref(), Type::Path(p) if p.path.is_ident("Self")) {
84 return true;
85 }
86 if matches!(ty.as_ref(), Type::Tuple(t) if t.elems.is_empty()) {
88 return true;
89 }
90 if looks_like_result(ty) {
91 return result_ok_is_unit(ty) || result_ok_is_self(ty);
93 }
94 false
95 }
96 }
97}
98
99fn is_consuming_receiver(arg: &FnArg) -> bool {
101 matches!(arg, FnArg::Receiver(r) if r.reference.is_none())
102}
103
104fn collect_self_assoc_in_type(ty: &Type, found: &mut Vec<syn::Ident>) {
107 match ty {
108 Type::Path(tp) if tp.qself.is_none() => {
109 let segs: Vec<_> = tp.path.segments.iter().collect();
110 if segs.len() == 2 && segs[0].ident == "Self" {
112 let name = segs[1].ident.clone();
113 if !found.iter().any(|i: &syn::Ident| *i == name) {
114 found.push(name);
115 }
116 }
117 for seg in &tp.path.segments {
119 if let PathArguments::AngleBracketed(args) = &seg.arguments {
120 for ga in &args.args {
121 if let GenericArgument::Type(inner) = ga {
122 collect_self_assoc_in_type(inner, found);
123 }
124 }
125 }
126 }
127 }
128 Type::Reference(r) => collect_self_assoc_in_type(&r.elem, found),
129 Type::Slice(s) => collect_self_assoc_in_type(&s.elem, found),
130 Type::Array(a) => collect_self_assoc_in_type(&a.elem, found),
131 Type::Tuple(t) => t
132 .elems
133 .iter()
134 .for_each(|e| collect_self_assoc_in_type(e, found)),
135 _ => {}
136 }
137}
138
139#[proc_macro_attribute]
160pub fn explainable(_args: TokenStream, input: TokenStream) -> TokenStream {
161 let trait_def = parse_macro_input!(input as ItemTrait);
162 let trait_name = &trait_def.ident;
163 let explain_text_trait_name = format_ident!("{}ExplainText", trait_name);
164 let ext_trait_name = format_ident!("{}Ext", trait_name);
165 let vis = &trait_def.vis;
166
167 let self_methods: Vec<_> = trait_def
168 .items
169 .iter()
170 .filter_map(|item| {
171 if let TraitItem::Fn(f) = item {
172 let has_receiver = f
173 .sig
174 .inputs
175 .first()
176 .map(|a| matches!(a, FnArg::Receiver(_)))
177 .unwrap_or(false);
178 let chainable = is_chainable_return(&f.sig.output);
183 if has_receiver && chainable {
184 Some(f)
185 } else {
186 None
187 }
188 } else {
189 None
190 }
191 })
192 .collect();
193
194 let mut assoc_idents: Vec<syn::Ident> = Vec::new();
197 for m in &self_methods {
198 for param in m.sig.inputs.iter() {
199 if let FnArg::Typed(pt) = param {
200 collect_self_assoc_in_type(&pt.ty, &mut assoc_idents);
201 }
202 }
203 }
204
205 let where_bounds: Vec<Vec<&TypeParamBound>> = assoc_idents
209 .iter()
210 .map(|name| {
211 let mut bounds: Vec<&TypeParamBound> = Vec::new();
212 if let Some(wc) = &trait_def.generics.where_clause {
213 for pred in &wc.predicates {
214 if let WherePredicate::Type(pt) = pred {
215 if let Type::Path(tp) = &pt.bounded_ty {
217 let segs: Vec<_> = tp.path.segments.iter().collect();
218 if segs.len() == 2 && segs[0].ident == "Self" && &segs[1].ident == name
219 {
220 bounds.extend(pt.bounds.iter());
221 }
222 }
223 }
224 }
225 }
226 bounds
227 })
228 .collect();
229
230 let ext_assoc_type_decls: Vec<TokenStream2> = assoc_idents
232 .iter()
233 .zip(where_bounds.iter())
234 .map(|(name, bounds)| {
235 let doc = format!("Associated type `{}` forwarded from the domain type.", name);
236 if bounds.is_empty() {
237 quote! {
238 #[doc = #doc]
239 type #name;
240 }
241 } else {
242 quote! {
243 #[doc = #doc]
244 type #name: #(#bounds)+*;
245 }
246 }
247 })
248 .collect();
249
250 let ext_assoc_type_impls: Vec<TokenStream2> = assoc_idents
252 .iter()
253 .map(|name| quote! { type #name = T::#name; })
254 .collect();
255
256 let explain_text_methods: Vec<TokenStream2> = self_methods
259 .iter()
260 .map(|m| {
261 let method_name = &m.sig.ident;
262 let explain_fn = format_ident!("explain_text_{}", method_name);
263 let cfg_attrs: Vec<_> = m
264 .attrs
265 .iter()
266 .filter(|a| a.path().is_ident("cfg"))
267 .collect();
268 quote! {
269 #(#cfg_attrs)*
270 fn #explain_fn(before: &Self, after: &Self) -> String;
271 }
272 })
273 .collect();
274
275 let ext_method_sigs: Vec<TokenStream2> = self_methods
278 .iter()
279 .map(|m| {
280 let method_name = &m.sig.ident;
281 let cfg_attrs: Vec<_> = m
282 .attrs
283 .iter()
284 .filter(|a| a.path().is_ident("cfg"))
285 .collect();
286 let non_recv_params: Vec<_> = m
287 .sig
288 .inputs
289 .iter()
290 .filter(|a| !matches!(a, FnArg::Receiver(_)))
291 .collect();
292 quote! {
293 #(#cfg_attrs)*
294 fn #method_name(&mut self, #(#non_recv_params),*) -> &mut Self;
295 }
296 })
297 .collect();
298
299 let ext_method_impls: Vec<TokenStream2> = self_methods
302 .iter()
303 .map(|m| {
304 let method_name = &m.sig.ident;
305 let explain_fn = format_ident!("explain_text_{}", method_name);
306 let cfg_attrs: Vec<_> = m
307 .attrs
308 .iter()
309 .filter(|a| a.path().is_ident("cfg"))
310 .collect();
311
312 let non_recv_params: Vec<_> = m
313 .sig
314 .inputs
315 .iter()
316 .filter(|a| !matches!(a, FnArg::Receiver(_)))
317 .collect();
318
319 let arg_idents: Vec<_> = non_recv_params
320 .iter()
321 .filter_map(|a| {
322 if let FnArg::Typed(pt) = a {
323 if let Pat::Ident(pi) = pt.pat.as_ref() {
324 Some(&pi.ident)
325 } else {
326 None
327 }
328 } else {
329 None
330 }
331 })
332 .collect();
333
334 let consuming = m
335 .sig
336 .inputs
337 .first()
338 .map(|a| is_consuming_receiver(a))
339 .unwrap_or(false);
340
341 let (is_result, is_void) = match &m.sig.output {
342 ReturnType::Type(_, ty) => {
343 let r = looks_like_result(ty);
344 (r, r && result_ok_is_unit(ty))
345 }
346 ReturnType::Default => (false, true),
347 };
348
349 let update_inner = if is_void {
350 if is_result {
351 quote! { self.inner.#method_name(#(#arg_idents),*).unwrap(); }
352 } else {
353 quote! { self.inner.#method_name(#(#arg_idents),*); }
354 }
355 } else if consuming {
356 if is_result {
357 quote! {
358 let __taken = ::std::mem::replace(&mut self.inner, before.clone());
359 self.inner = __taken.#method_name(#(#arg_idents),*).unwrap();
360 }
361 } else {
362 quote! {
363 let __taken = ::std::mem::replace(&mut self.inner, before.clone());
364 self.inner = __taken.#method_name(#(#arg_idents),*);
365 }
366 }
367 } else if is_result {
368 quote! { self.inner = self.inner.#method_name(#(#arg_idents),*).unwrap(); }
369 } else {
370 quote! { self.inner = self.inner.#method_name(#(#arg_idents),*); }
371 };
372
373 quote! {
374 #(#cfg_attrs)*
375 fn #method_name(&mut self, #(#non_recv_params),*) -> &mut Self {
376 let before = self.inner.clone();
377 #update_inner
378 let text = match self.mode {
379 ::explainable::ExplainMode::Text
380 | ::explainable::ExplainMode::Both => Some(
381 <T as #explain_text_trait_name>::#explain_fn(
382 &before,
383 &self.inner,
384 ),
385 ),
386 _ => None,
387 };
388 let visual = match self.mode {
389 ::explainable::ExplainMode::Visual
390 | ::explainable::ExplainMode::Both => Some(
391 <T as ::explainable::RenderVisual>::render_visual(
392 &before,
393 &self.inner,
394 ),
395 ),
396 _ => None,
397 };
398 self.explanations.push(::explainable::Explanation::new(
399 self.mode,
400 text,
401 visual,
402 ));
403 self
404 }
405 }
406 })
407 .collect();
408
409 let explain_text_doc = format!(
412 "Companion text trait generated by `#[explainable]` for [`{}`].\n\n\
413 Implement one `explain_text_<method>` per operation to supply the pedagogical \
414 text explanation shown when that operation runs inside an explaining chain.",
415 trait_name
416 );
417 let ext_trait_doc = format!(
418 "Extension trait generated by `#[explainable]` for [`{}`].\n\n\
419 Bring this into scope to call `{}` operations on an \
420 [`explainable::Explaining`] chain.",
421 trait_name, trait_name
422 );
423
424 let output = quote! {
425 #trait_def
426
427 #[doc = #explain_text_doc]
428 #[allow(missing_docs)]
429 #vis trait #explain_text_trait_name:
430 ::explainable::Explainable + #trait_name
431 {
432 #(#explain_text_methods)*
433 }
434
435 #[doc = #ext_trait_doc]
436 #[allow(missing_docs)]
437 #vis trait #ext_trait_name {
438 #(#ext_assoc_type_decls)*
439 #(#ext_method_sigs)*
440 }
441
442 #[allow(missing_docs)]
443 impl<T> #ext_trait_name for ::explainable::Explaining<T>
444 where
445 T: ::explainable::Explainable + #trait_name + #explain_text_trait_name,
446 {
447 #(#ext_assoc_type_impls)*
448 #(#ext_method_impls)*
449 }
450 };
451
452 output.into()
453}