1use darling::{ast::Data, FromDeriveInput, FromField};
2use proc_macro::TokenStream;
3use proc_macro2::Span;
4use quote::quote;
5use syn::{DeriveInput, Ident};
6
/// Parsed form of a `#[derive(Node)]` input.
///
/// darling fills this from the derive input. It reads container options from `#[node(...)]`
/// attributes and only accepts structs with named fields (`supports(struct_named)`), so any
/// other input shape is rejected by `from_derive_input`.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(node), supports(struct_named))]
struct Node {
    /// Name of the struct the derive is applied to (forwarded by darling).
    ident: Ident,
    /// Struct body; the enum variant type is `()` because enums are not supported.
    data: Data<(), NodeField>,
    /// Required `#[node(class_type = "...")]` value. It is emitted as
    /// `ClassType::CLASS_TYPE` and is the fallback name of the generated trait.
    class_type: String,
    /// Optional `#[node(trait_name = "...")]` override for the generated trait's name.
    #[darling(default)]
    trait_name: Option<String>,
}
16
/// One field of a `#[derive(Node)]` struct, with its `#[node_input(...)]` options.
#[derive(Debug, FromField)]
#[darling(attributes(node_input))]
struct NodeField {
    /// Field name. It is always `Some` here because `Node` only accepts named structs.
    ident: Option<Ident>,
    /// Declared field type; it is reused verbatim in the generated getter/setter signatures.
    ty: syn::Type,
    /// `#[node_input(skip)]`: when set, no getter/setter is generated for this field.
    #[darling(default)]
    skip: bool,
}
25
26impl Node {
27 fn ident(&self) -> &Ident {
28 &self.ident
29 }
30
31 fn trait_name_ident(&self) -> Ident {
32 let trait_name = self.trait_name.as_ref().unwrap_or(self.class_type());
33 Ident::new(trait_name.as_str(), Span::mixed_site())
34 }
35
36 fn class_type(&self) -> &String {
37 &self.class_type
38 }
39
40 fn fields(&self) -> impl Iterator<Item = &NodeField> {
41 match &self.data {
42 Data::Enum(_) => unreachable!(),
43 Data::Struct(fields) => fields.fields.iter().filter(|f| !f.skip),
44 }
45 }
46}
47
48impl NodeField {
49 fn ident(&self) -> &Ident {
50 self.ident.as_ref().expect("only named struct is supported")
51 }
52
53 fn getter_ident(&self) -> &Ident {
54 self.ident()
55 }
56
57 fn setter_ident(&self) -> Ident {
58 let name = format!("set_{}", self.ident());
59 Ident::new(name.as_str(), Span::mixed_site())
60 }
61}
62
63#[allow(unused)]
64#[proc_macro_derive(Node, attributes(node, node_input))]
65pub fn derive_node(input: TokenStream) -> TokenStream {
66 let input: DeriveInput = syn::parse(input).unwrap();
67 let node = Node::from_derive_input(&input).unwrap();
68 let node_ident = node.ident();
69 let class_type = node.class_type();
70
71 let fields_methods = node.fields().map(|field| {
72 let field_name = field.ident();
73 let field_type = &field.ty;
74 let get_field_name = field.getter_ident();
75 let set_field_name = field.setter_ident();
76 quote!(
77 fn #get_field_name(&self) -> ::cmfy::Result<#field_type>;
78 fn #set_field_name(&mut self, value: #field_type) -> ::cmfy::Result<()>;
79 )
80 });
81
82 let fields_methods_impl = node.fields().map(|field| {
83 let field_name = field.ident();
84 let field_type = &field.ty;
85 let get_field_name = field.getter_ident();
86 let set_field_name = field.setter_ident();
87 quote!(
88 fn #get_field_name(&self) -> ::cmfy::Result<#field_type> {
89 let (_, node) = self.first_by_class::<#node_ident>()?;
90 Ok(node.#field_name)
91 }
92 fn #set_field_name(&mut self, value: #field_type) -> ::cmfy::Result<()> {
93 self.change_first_by_class(move |node: &mut #node_ident| {
94 node.#field_name = value.clone();
95 })
96 }
97 )
98 });
99
100 let node_trait = node.trait_name_ident();
101 let generated = quote!(
102 impl ::cmfy::dto::ClassType for #node_ident {
103 const CLASS_TYPE: &'static str = #class_type;
104 }
105
106 pub trait #node_trait {
107 #(#fields_methods)*
108 }
109
110 impl #node_trait for ::cmfy::dto::PromptNodes {
111 #(#fields_methods_impl)*
112 }
113 );
114
115 generated.into()
116}