openai_magic_instantiate_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Expr, Field};
4use heck::{self, ToLowerCamelCase};
5
6#[derive(Debug, Default)]
7struct MagicAttrArgs {
8    // The #[magic(description = "...")]
9    description: Option<Expr>,
10    // The #[magic(validator = "...")]
11    validators: Vec<Expr>,
12}
13
14impl MagicAttrArgs {
15    fn merge(&mut self, other: Self) {
16        if other.description.is_some() {
17            self.description = other.description;
18        }
19        self.validators.extend(other.validators);
20    }
21}
22
23impl syn::parse::Parse for MagicAttrArgs {
24    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
25        let mut description = None;
26        let mut validators = vec![];
27        while !input.is_empty() {
28            let name: syn::Ident = input.parse()?;
29            match name.to_string().as_str() {
30                "description" => {
31                    input.parse::<syn::Token![=]>()?;
32                    let value: Expr = input.parse()?;
33                    description = Some(value);
34                },
35                "validator" => {
36                    input.parse::<syn::Token![=]>()?;
37                    let value: syn::Expr = input.parse()?;
38                    validators.push(value);
39                },
40                _ => return Err(syn::Error::new(name.span(), "Unknown attribute")),
41            }
42            if input.is_empty() {
43                break;
44            }
45            input.parse::<syn::Token![,]>()?;
46        }
47        Ok(Self { description, validators })
48    }
49}
50
51
52fn attributes<'a>(attrs: impl Iterator<Item = &'a syn::Attribute>) -> MagicAttrArgs {
53    let mut result = MagicAttrArgs::default();
54    for attr in attrs {
55        if attr.path().is_ident("magic") {
56            let attr_args: MagicAttrArgs = attr.parse_args().unwrap();
57            result.merge(attr_args);
58        }
59    }
60    result
61}
62
63fn field_attributes<'a>(fields: impl Iterator<Item = &'a Field>) -> Vec<MagicAttrArgs> {
64    let mut results = vec![];
65    for field in fields {
66        results.push(attributes(field.attrs.iter()));
67    }
68    results
69}
70
71
72/// Derive the `MagicInstantiate` trait for a struct or enum.
73/// Descriptions and validators can be added to fields using the `#[magic(description = ...)]` and `#[magic(validator = ...)]` attributes.
74#[proc_macro_derive(MagicInstantiate, attributes(magic))]
75pub fn derive_magic_instantiate(input: TokenStream) -> TokenStream {
76    let DeriveInput { ident, data, generics, attrs, .. } = parse_macro_input!(input as DeriveInput);
77
78    let attrs = attributes(attrs.iter());
79    let definition_description = attrs.description.into_iter().collect::<Vec<_>>();
80    let definition_validators = attrs
81        .validators
82        .iter()
83        .map(|v| quote! { openai_magic_instantiate::Validator::<Self>::validate(&#v, &result)?; })
84        .collect::<Vec<_>>();
85    let definition_validator_instructions = attrs
86        .validators
87        .iter()
88        .map(|v| quote! { openai_magic_instantiate::Validator::<Self>::instructions(&#v) })
89        .collect::<Vec<_>>();
90
91    let mut generics = generics.clone();
92    for generic in generics.params.iter_mut() {
93        if let syn::GenericParam::Type(type_param) = generic {
94            type_param.bounds.push(syn::parse_quote!(MagicInstantiate));
95        }
96    }
97    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
98
99    let generic_types = generics.params.iter().map(|p| {
100        match p {
101            syn::GenericParam::Type(type_param) => &type_param.ident,
102            syn::GenericParam::Lifetime(_) => panic!("Lifetime parameters are not supported"),
103            syn::GenericParam::Const(_) => panic!("Const parameters are not supported"),
104        }
105    }).collect::<Vec<_>>();
106
107    let name = quote ! {
108        let mut result = stringify!(#ident).to_string();
109        #(
110            result.push_str(&format!("{}", <#generic_types>::name()));
111        )*
112        result
113    };
114
115    match data {
116        Data::Struct(DataStruct { fields, .. }) => {
117            match &fields {
118                syn::Fields::Unit => {
119                    quote! {
120                        impl #impl_generics MagicInstantiate for #ident #ty_generics #where_clause {
121                            fn name() -> String {
122                                #name
123                            }
124
125                            fn reference() -> String {
126                                ()::reference()
127                            }
128
129                            fn definition() -> String {
130                                ()::definition()
131                            }
132
133                            fn add_dependencies(builder: &mut openai_magic_instantiate::TypeScriptAccumulator) -> String {
134                                ()::add_dependencies(builder)
135                            }
136                    
137                            fn validate(value: &openai_magic_instantiate::export::JsonValue) -> std::result::Result<Self, String> {
138                                ()::validate(value)?;
139                                Ok(Self)
140                            }
141                    
142                            fn default_if_omitted() -> Option<Self> {
143                                Some(Self)
144                            }
145
146                            fn is_object() -> bool {
147                                false
148                            }
149                        }
150                    }
151                },
152                syn::Fields::Unnamed(fields) => {
153                    let field_types = fields.unnamed.iter().map(|f| &f.ty).collect::<Vec<_>>();
154                    let field_indices = (0..field_types.len()).collect::<Vec<_>>();
155                    let field_count = field_types.len();
156                    let type_definition = if field_count == 1 {
157                        quote! {
158                            result.push_str(&format!("type {} = {};", stringify!(#ident), references[0]));
159                        }
160                    } else {
161                        quote! {
162                            result.push_str(&format!("type {} = [{}];", stringify!(#ident), references.join(", ")));
163                        }
164                    };
165                    let validate_definition = if field_count == 1 {
166                        let field_type = &field_types[0];
167                        quote! {
168                            let value = <#field_type>::validate(value)?;
169                            let result = Self(value);
170                        }
171                    } else {
172                        quote! {
173                            let openai_magic_instantiate::export::JsonValue::Array(value) = value else { return Err(format!("Expected array tuple, got {}", openai_magic_instantiate::JsonValueExt::type_str(value))) };
174                            if value.len() != #field_count {
175                                return Err(format!("Expected {} elements but got {}", #field_count, value.len()));
176                            }
177                            let result = Self(#(#field_types::validate(&value[#field_indices])?),*);
178                        }
179                    };
180                    quote! {
181                        impl #impl_generics MagicInstantiate for #ident #ty_generics #where_clause {
182                            fn name() -> String {
183                                #name
184                            }
185
186                            fn reference() -> String {
187                                Self::name()
188                            }
189
190                            fn definition() -> String {
191                                let mut result = String::new();
192                                #(
193                                    for line in #definition_description.lines() {
194                                        result.push_str(&format!("// {}\n", line));
195                                    }
196                                )*
197                                #(
198                                    for line in #definition_validator_instructions.lines() {
199                                        result.push_str(&format!("// {}\n", line));
200                                    }
201                                )*
202                                let references = vec![#(<#field_types>::reference()),*];
203                                #type_definition
204                            }
205
206                            fn add_dependencies(builder: &mut openai_magic_instantiate::TypeScriptAccumulator) {
207                                #(
208                                    builder.add::<#field_types>();
209                                )*
210                            }
211                    
212                            fn validate(value: &openai_magic_instantiate::export::JsonValue) -> std::result::Result<Self, String> {
213                                #validate_definition
214                                #(
215                                    #definition_validators
216                                )*
217                                Ok(result)
218                            }
219                    
220                            fn default_if_omitted() -> Option<Self> {
221                                Some(#ident(#(<#field_types>::default_if_omitted()?),*))
222                            }
223
224                            fn is_object() -> bool {
225                                false
226                            }
227                        }
228                    }
229                },
230                syn::Fields::Named(fields) => {
231                    let attributes = field_attributes(fields.named.iter());
232                    let field_idents = fields.named.iter().map(|f| f.ident.as_ref().unwrap()).collect::<Vec<_>>();
233                    let field_types = fields.named.iter().map(|f| &f.ty).collect::<Vec<_>>();
234                    let field_names_camel = field_idents.iter().map(|f| f.to_string().to_lower_camel_case()).collect::<Vec<_>>();
235                    let field_is_optionals = field_types.iter().map(|f| {
236                        quote! {
237                            if <#f>::default_if_omitted().is_some() { "?" } else { "" }
238                        }
239                    }).collect::<Vec<_>>();
240
241                    let descriptions = attributes.iter().map(|a| {
242                        if let Some(description) = &a.description {
243                            quote! {
244                                result.push_str(&format!("    // {}\n", #description));
245                            }
246                        } else {
247                            quote! {}
248                        }
249                    }).collect::<Vec<_>>();
250
251                    let validation_comments = attributes.iter().zip(&field_types).map(|(a, field_type)| {
252                        let validators = &a.validators;
253                        quote! {
254                            #(
255                                for line in openai_magic_instantiate::Validator::<#field_type>::instructions(&#validators).lines() {
256                                    result.push_str(&format!("    // {}\n", line));
257                                }
258                            )*
259                        }
260                    }).collect::<Vec<_>>();
261
262                    let field_validators = (0..field_types.len()).map(|i| {
263                        let field_type = &field_types[i];
264                        let validators = &attributes[i].validators;
265                        quote! {
266                            #(
267                                openai_magic_instantiate::Validator::<#field_type>::validate(&#validators, &value)?;
268                            )*
269                        }
270                    }).collect::<Vec<_>>();
271
272                    quote! {
273                        impl #impl_generics MagicInstantiate for #ident #ty_generics #where_clause {
274                            fn name() -> String {
275                                #name 
276                            }
277
278                            fn reference() -> String {
279                                Self::name()
280                            }
281
282                            fn definition() -> String {
283                                let mut result = String::new();
284                                #(
285                                    for line in #definition_description.lines() {
286                                        result.push_str(&format!("// {}\n", line));
287                                    }
288                                )*
289                                #(
290                                    for line in #definition_validator_instructions.lines() {
291                                        result.push_str(&format!("// {}\n", line));
292                                    }
293                                )*
294                                result.push_str(&format!("type {} = {{\n", Self::name()));
295                                #(
296                                    #descriptions
297                                    #validation_comments
298                                    result.push_str(&format!("    {}{}: {};\n", #field_names_camel, #field_is_optionals, <#field_types>::reference()));
299                                )*
300                                result.push_str("};");
301                                result
302                            }
303
304                            fn add_dependencies(builder: &mut openai_magic_instantiate::TypeScriptAccumulator) {
305                                #(
306                                    builder.add::<#field_types>();
307                                )*
308                            }
309                    
310                            fn validate(value: &openai_magic_instantiate::export::JsonValue) -> std::result::Result<Self, String> {
311                                let openai_magic_instantiate::export::JsonValue::Object(value) = value else { 
312                                    let expected: &[&str] = &[#(#field_names_camel),*];
313                                    return Err(format!("Expected object with fields {:?}, got {}", expected, openai_magic_instantiate::JsonValueExt::type_str(value)))
314                                };
315                                let result = Self {
316                                    #(
317                                        #field_idents: {
318                                            let value = match value.get(#field_names_camel) {
319                                                None => match <#field_types>::default_if_omitted() {
320                                                    Some(value) => value,
321                                                    None => return Err(format!("Expected field {}, but it wasn't present", #field_names_camel)),
322                                                },
323                                                Some(value) => match <#field_types>::validate(value) {
324                                                    Ok(value) => value,
325                                                    Err(error) => return Err(format!("Validation error for field {}:\n{}", #field_names_camel, error)),
326                                                }
327                                            };
328                                            #field_validators
329                                            value
330                                        },
331                                    )*
332                                };
333                                #(
334                                    #definition_validators
335                                )*
336                                Ok(result)
337                            }
338                    
339                            fn default_if_omitted() -> Option<Self> {
340                                Some(#ident {
341                                    #(
342                                        #field_idents: <#field_types>::default_if_omitted()?,
343                                    )*
344                                })
345                            }
346
347                            fn is_object() -> bool {
348                                true
349                            }
350                        }
351                    } 
352                },
353            }
354        }
355        Data::Enum(DataEnum { variants, .. }) => {
356            let mut variant_definitions = vec![];
357            let mut variant_struct_names = vec![];
358            let mut variant_struct_kinds = vec![];
359            let mut variant_struct_to_variants = vec![];
360
361            if generics.params.len() > 0 {
362                panic!("Enums with generics are not supported");
363            } 
364
365            for variant in variants {
366                let variant_attributes = variant
367                    .attrs
368                    .iter()
369                    .filter(|a| a.path().is_ident("magic"))
370                    .collect::<Vec<_>>();
371
372                let variant_ident = variant.ident;
373                let variant_struct_name = syn::Ident::new(&format!("{}{}", ident, variant_ident), proc_macro2::Span::call_site());
374                variant_struct_names.push(variant_struct_name.clone());
375
376                let variant_struct_kind = syn::Ident::new(&format!("{}{}", variant_struct_name, variant_ident), proc_macro2::Span::call_site());
377                variant_struct_kinds.push(variant_ident.clone());
378
379                let mut variant_fields = vec![
380                    quote! {
381                        kind: #variant_struct_kind,
382                    }
383                ];
384
385                match variant.fields {
386                    syn::Fields::Unit => {
387                        variant_struct_to_variants.push(quote! {
388                            Ok(Self::#variant_ident)
389                        });
390                    },
391                    syn::Fields::Unnamed(fields) => {
392                        let field_types = fields.unnamed.iter().map(|f| &f.ty).collect::<Vec<_>>();
393                        variant_fields.push(quote! {
394                            value: (#(#field_types,)*),
395                        });
396                        let field_idents = (0..field_types.len()).map(|i| syn::Ident::new(&format!("field{}", i), proc_macro2::Span::call_site())).collect::<Vec<_>>();
397                        variant_struct_to_variants.push(quote! {
398                            let (#(#field_idents,)*) = value.value;
399                            Ok(Self::#variant_ident(#(#field_idents),*))
400                        });
401                    },
402                    syn::Fields::Named(fields) => {
403                        for field in &fields.named {
404                            let field_attributes = &field.attrs;
405                            let field_name = field.ident.as_ref().unwrap();
406                            let field_type = &field.ty;
407
408                            variant_fields.push(quote! {
409                                #(#field_attributes)*
410                                #field_name: #field_type,
411                            });
412                        }
413                        let field_idents = fields.named.iter().map(|f| f.ident.as_ref().unwrap()).collect::<Vec<_>>();
414                        variant_struct_to_variants.push(quote! {
415                            Ok(Self::#variant_ident {
416                                #(#field_idents: value.#field_idents,)*
417                            })
418                        });
419                    }
420                }
421
422                variant_definitions.push(quote! {
423
424                    struct #variant_struct_kind;
425
426                    impl MagicInstantiate for #variant_struct_kind {
427                        fn name() -> String {
428                            stringify!(#variant_ident).to_string()
429                        }
430                        fn reference() -> String {
431                            format!("\"{}\"", stringify!(#variant_ident))
432                        }
433                        fn add_dependencies(builder: &mut openai_magic_instantiate::TypeScriptAccumulator) {}
434                        fn definition() -> String { "".to_string() }
435
436                        fn validate(value: &openai_magic_instantiate::export::JsonValue) -> std::result::Result<Self, String> {
437                            let expected = stringify!(#variant_ident);
438                            if value.as_str() == Some(expected.as_ref()) {
439                                Ok(Self)
440                            } else {
441                                Err(format!("Expected \"{expected}\""))
442                            }
443                        }
444                        fn default_if_omitted() -> Option<Self> { None }
445                        fn is_object() -> bool { false }
446                    }
447
448                    #[derive(MagicInstantiate)]
449                    #(#variant_attributes)*
450                    struct #variant_struct_name {
451                        #(#variant_fields)*
452                    }
453                });
454            }
455
456            quote! {
457                #(#variant_definitions)*
458
459                impl #impl_generics MagicInstantiate for #ident #ty_generics #where_clause {
460                    fn name() -> String {
461                        #name
462                    }
463
464                    fn reference() -> String {
465                        Self::name()
466                    }
467
468                    fn definition() -> String {
469                        let mut result = String::new();
470                        #(
471                            for line in #definition_description.lines() {
472                                result.push_str(&format!("// {}\n", line));
473                            }
474                        )*
475                        #(
476                            for line in #definition_validator_instructions.lines() {
477                                result.push_str(&format!("// {}\n", line));
478                            }
479                        )*
480                        result.push_str(&format!("type {} =\n", stringify!(#ident)));
481                        #(
482                            result.push_str(&format!("    | {}\n", <#variant_struct_names>::reference()));
483                        )*
484                        result.push_str(";");
485                        result
486                    }
487
488                    fn add_dependencies(builder: &mut openai_magic_instantiate::TypeScriptAccumulator) {
489                        #(
490                            builder.add::<#variant_struct_names>();
491                        )*
492                    }
493            
494                    fn validate(value: &openai_magic_instantiate::export::JsonValue) -> std::result::Result<Self, String> {
495                        let kind = value.get("kind").ok_or("Expected field 'kind'")?;
496                        let kind = kind.as_str().ok_or_else(|| format!("Expected 'kind' to be a string, got {}", openai_magic_instantiate::JsonValueExt::type_str(value)))?;
497                        let result = match kind {
498                            #(
499                                stringify!(#variant_struct_kinds) => {
500                                    let value = <#variant_struct_names>::validate(value)?;
501                                    #variant_struct_to_variants
502                                },
503                            )*
504                            _ => Err(format!("Unknown variant {}", kind)),
505                        }?;
506                        #(
507                            #definition_validators
508                        )*
509                        Ok(result)
510                    }
511            
512                    fn default_if_omitted() -> Option<Self> {
513                        None
514                    }
515
516                    fn is_object() -> bool {
517                        true
518                    }
519                }
520            }
521        },
522        Data::Union(_) => todo!(),
523    }.into()
524}
525
526#[proc_macro]
527pub fn implement_integers(_input: TokenStream) -> TokenStream {
528    let type_tokens = vec![
529        quote! { u8 },
530        quote! { u16 },
531        quote! { u32 },
532        quote! { u64 },
533        quote! { usize },
534        quote! { i8 },
535        quote! { i16 },
536        quote! { i32 },
537        quote! { i64 },
538        quote! { isize },
539    ];
540
541    let names = vec![
542        "U8",
543        "U16",
544        "U32",
545        "U64",
546        "USize",
547        "I8",
548        "I16",
549        "I32",
550        "I64",
551        "ISize",
552    ];
553
554    quote! {
555        #(
556            impl MagicInstantiate for #type_tokens {
557                fn name() -> String {
558                    #names.to_string()
559                }
560
561                fn reference() -> String {
562                    #names.to_string()
563                }
564
565                fn definition() -> String {
566                    let min = Self::MIN;
567                    let max = Self::MAX;
568                    let name = #names;
569                    format!("
570// Integer in [{min}, {max}]
571type {name} = number;
572                    ").trim().to_string()
573                }
574
575                fn add_dependencies(builder: &mut TypeScriptAccumulator) {}
576
577                fn validate(value: &JsonValue) -> Result<Self, String> {
578                    match value {
579                        JsonValue::Number(number) => {
580                            match number.as_i64() {
581                                Some(number) => {
582                                    if number >= (Self::MIN as i64) && number < (Self::MAX as i64) {
583                                        Ok(number as _)
584                                    } else {
585                                        Err(format!("Expected integer in [{}, {}], got {}", Self::MIN, Self::MAX, number))
586                                    }
587                                }
588                                None => Err(format!("Expected integer in [{}, {}], got {}", Self::MIN, Self::MAX, number)),
589                            }
590                        }
591                        _ => Err(format!("Expected integer, got {}", value.type_str())),
592                    }
593                }
594
595                fn default_if_omitted() -> Option<Self> {
596                    None
597                }
598
599                fn is_object() -> bool {
600                    false
601                }
602            }
603        )*
604    }.into()
605}
606
607#[proc_macro]
608pub fn implement_tuples(_input: TokenStream) -> TokenStream {
609    let mut results = vec![];
610
611    for i in 2..16usize {
612
613        let generic_names = (1..=i).map(|i| syn::Ident::new(&format!("T{}", i), proc_macro2::Span::call_site())).collect::<Vec<_>>();
614        let indexes = (0..i).collect::<Vec<_>>();
615
616        results.push(quote! {
617
618            impl<#(#generic_names: MagicInstantiate),*> MagicInstantiate for (#(#generic_names,)*) {
619                fn name() -> String {
620                    let names = vec![#(<#generic_names>::name()),*];
621                    format!("Tuple{}", names.join(""))
622                }
623
624                fn reference() -> String {
625                    let references = vec![#(<#generic_names>::reference()),*];
626                    format!("[{}]", references.join(", "))
627                }
628
629                fn definition () -> String { "".to_string() }
630
631                fn add_dependencies(builder: &mut TypeScriptAccumulator) {
632                    #(
633                    builder.add::<#generic_names>();
634                    )*
635                }
636
637                fn validate(value: &JsonValue) -> Result<Self, String> {
638                    let JsonValue::Array(value) = value else { return Err(format!("Expected array tuple, got {}", value.type_str())) };
639                    if value.len() != #i {
640                        return Err(format!("Expected {} elements but got {}", #i, value.len()));
641                    }
642                    Ok((#(<#generic_names>::validate(&value[#indexes])?,)*))
643                }
644
645                fn default_if_omitted() -> Option<Self> {
646                    None
647                }
648
649                fn is_object() -> bool {
650                    false
651                }
652            }
653        });
654    }
655
656    quote! {
657        #( #results )*
658    }.into()
659}