openai_magic_instantiate_derive/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Expr, Field};
4use heck::{self, ToLowerCamelCase};
5
6#[derive(Debug, Default)]
7struct MagicAttrArgs {
8 description: Option<Expr>,
10 validators: Vec<Expr>,
12}
13
14impl MagicAttrArgs {
15 fn merge(&mut self, other: Self) {
16 if other.description.is_some() {
17 self.description = other.description;
18 }
19 self.validators.extend(other.validators);
20 }
21}
22
23impl syn::parse::Parse for MagicAttrArgs {
24 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
25 let mut description = None;
26 let mut validators = vec![];
27 while !input.is_empty() {
28 let name: syn::Ident = input.parse()?;
29 match name.to_string().as_str() {
30 "description" => {
31 input.parse::<syn::Token![=]>()?;
32 let value: Expr = input.parse()?;
33 description = Some(value);
34 },
35 "validator" => {
36 input.parse::<syn::Token![=]>()?;
37 let value: syn::Expr = input.parse()?;
38 validators.push(value);
39 },
40 _ => return Err(syn::Error::new(name.span(), "Unknown attribute")),
41 }
42 if input.is_empty() {
43 break;
44 }
45 input.parse::<syn::Token![,]>()?;
46 }
47 Ok(Self { description, validators })
48 }
49}
50
51
52fn attributes<'a>(attrs: impl Iterator<Item = &'a syn::Attribute>) -> MagicAttrArgs {
53 let mut result = MagicAttrArgs::default();
54 for attr in attrs {
55 if attr.path().is_ident("magic") {
56 let attr_args: MagicAttrArgs = attr.parse_args().unwrap();
57 result.merge(attr_args);
58 }
59 }
60 result
61}
62
63fn field_attributes<'a>(fields: impl Iterator<Item = &'a Field>) -> Vec<MagicAttrArgs> {
64 let mut results = vec![];
65 for field in fields {
66 results.push(attributes(field.attrs.iter()));
67 }
68 results
69}
70
71
72#[proc_macro_derive(MagicInstantiate, attributes(magic))]
75pub fn derive_magic_instantiate(input: TokenStream) -> TokenStream {
76 let DeriveInput { ident, data, generics, attrs, .. } = parse_macro_input!(input as DeriveInput);
77
78 let attrs = attributes(attrs.iter());
79 let definition_description = attrs.description.into_iter().collect::<Vec<_>>();
80 let definition_validators = attrs
81 .validators
82 .iter()
83 .map(|v| quote! { openai_magic_instantiate::Validator::<Self>::validate(&#v, &result)?; })
84 .collect::<Vec<_>>();
85 let definition_validator_instructions = attrs
86 .validators
87 .iter()
88 .map(|v| quote! { openai_magic_instantiate::Validator::<Self>::instructions(&#v) })
89 .collect::<Vec<_>>();
90
91 let mut generics = generics.clone();
92 for generic in generics.params.iter_mut() {
93 if let syn::GenericParam::Type(type_param) = generic {
94 type_param.bounds.push(syn::parse_quote!(MagicInstantiate));
95 }
96 }
97 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
98
99 let generic_types = generics.params.iter().map(|p| {
100 match p {
101 syn::GenericParam::Type(type_param) => &type_param.ident,
102 syn::GenericParam::Lifetime(_) => panic!("Lifetime parameters are not supported"),
103 syn::GenericParam::Const(_) => panic!("Const parameters are not supported"),
104 }
105 }).collect::<Vec<_>>();
106
107 let name = quote ! {
108 let mut result = stringify!(#ident).to_string();
109 #(
110 result.push_str(&format!("{}", <#generic_types>::name()));
111 )*
112 result
113 };
114
115 match data {
116 Data::Struct(DataStruct { fields, .. }) => {
117 match &fields {
118 syn::Fields::Unit => {
119 quote! {
120 impl #impl_generics MagicInstantiate for #ident #ty_generics #where_clause {
121 fn name() -> String {
122 #name
123 }
124
125 fn reference() -> String {
126 ()::reference()
127 }
128
129 fn definition() -> String {
130 ()::definition()
131 }
132
133 fn add_dependencies(builder: &mut openai_magic_instantiate::TypeScriptAccumulator) -> String {
134 ()::add_dependencies(builder)
135 }
136
137 fn validate(value: &openai_magic_instantiate::export::JsonValue) -> std::result::Result<Self, String> {
138 ()::validate(value)?;
139 Ok(Self)
140 }
141
142 fn default_if_omitted() -> Option<Self> {
143 Some(Self)
144 }
145
146 fn is_object() -> bool {
147 false
148 }
149 }
150 }
151 },
152 syn::Fields::Unnamed(fields) => {
153 let field_types = fields.unnamed.iter().map(|f| &f.ty).collect::<Vec<_>>();
154 let field_indices = (0..field_types.len()).collect::<Vec<_>>();
155 let field_count = field_types.len();
156 let type_definition = if field_count == 1 {
157 quote! {
158 result.push_str(&format!("type {} = {};", stringify!(#ident), references[0]));
159 }
160 } else {
161 quote! {
162 result.push_str(&format!("type {} = [{}];", stringify!(#ident), references.join(", ")));
163 }
164 };
165 let validate_definition = if field_count == 1 {
166 let field_type = &field_types[0];
167 quote! {
168 let value = <#field_type>::validate(value)?;
169 let result = Self(value);
170 }
171 } else {
172 quote! {
173 let openai_magic_instantiate::export::JsonValue::Array(value) = value else { return Err(format!("Expected array tuple, got {}", openai_magic_instantiate::JsonValueExt::type_str(value))) };
174 if value.len() != #field_count {
175 return Err(format!("Expected {} elements but got {}", #field_count, value.len()));
176 }
177 let result = Self(#(#field_types::validate(&value[#field_indices])?),*);
178 }
179 };
180 quote! {
181 impl #impl_generics MagicInstantiate for #ident #ty_generics #where_clause {
182 fn name() -> String {
183 #name
184 }
185
186 fn reference() -> String {
187 Self::name()
188 }
189
190 fn definition() -> String {
191 let mut result = String::new();
192 #(
193 for line in #definition_description.lines() {
194 result.push_str(&format!("// {}\n", line));
195 }
196 )*
197 #(
198 for line in #definition_validator_instructions.lines() {
199 result.push_str(&format!("// {}\n", line));
200 }
201 )*
202 let references = vec![#(<#field_types>::reference()),*];
203 #type_definition
204 }
205
206 fn add_dependencies(builder: &mut openai_magic_instantiate::TypeScriptAccumulator) {
207 #(
208 builder.add::<#field_types>();
209 )*
210 }
211
212 fn validate(value: &openai_magic_instantiate::export::JsonValue) -> std::result::Result<Self, String> {
213 #validate_definition
214 #(
215 #definition_validators
216 )*
217 Ok(result)
218 }
219
220 fn default_if_omitted() -> Option<Self> {
221 Some(#ident(#(<#field_types>::default_if_omitted()?),*))
222 }
223
224 fn is_object() -> bool {
225 false
226 }
227 }
228 }
229 },
230 syn::Fields::Named(fields) => {
231 let attributes = field_attributes(fields.named.iter());
232 let field_idents = fields.named.iter().map(|f| f.ident.as_ref().unwrap()).collect::<Vec<_>>();
233 let field_types = fields.named.iter().map(|f| &f.ty).collect::<Vec<_>>();
234 let field_names_camel = field_idents.iter().map(|f| f.to_string().to_lower_camel_case()).collect::<Vec<_>>();
235 let field_is_optionals = field_types.iter().map(|f| {
236 quote! {
237 if <#f>::default_if_omitted().is_some() { "?" } else { "" }
238 }
239 }).collect::<Vec<_>>();
240
241 let descriptions = attributes.iter().map(|a| {
242 if let Some(description) = &a.description {
243 quote! {
244 result.push_str(&format!(" // {}\n", #description));
245 }
246 } else {
247 quote! {}
248 }
249 }).collect::<Vec<_>>();
250
251 let validation_comments = attributes.iter().zip(&field_types).map(|(a, field_type)| {
252 let validators = &a.validators;
253 quote! {
254 #(
255 for line in openai_magic_instantiate::Validator::<#field_type>::instructions(&#validators).lines() {
256 result.push_str(&format!(" // {}\n", line));
257 }
258 )*
259 }
260 }).collect::<Vec<_>>();
261
262 let field_validators = (0..field_types.len()).map(|i| {
263 let field_type = &field_types[i];
264 let validators = &attributes[i].validators;
265 quote! {
266 #(
267 openai_magic_instantiate::Validator::<#field_type>::validate(&#validators, &value)?;
268 )*
269 }
270 }).collect::<Vec<_>>();
271
272 quote! {
273 impl #impl_generics MagicInstantiate for #ident #ty_generics #where_clause {
274 fn name() -> String {
275 #name
276 }
277
278 fn reference() -> String {
279 Self::name()
280 }
281
282 fn definition() -> String {
283 let mut result = String::new();
284 #(
285 for line in #definition_description.lines() {
286 result.push_str(&format!("// {}\n", line));
287 }
288 )*
289 #(
290 for line in #definition_validator_instructions.lines() {
291 result.push_str(&format!("// {}\n", line));
292 }
293 )*
294 result.push_str(&format!("type {} = {{\n", Self::name()));
295 #(
296 #descriptions
297 #validation_comments
298 result.push_str(&format!(" {}{}: {};\n", #field_names_camel, #field_is_optionals, <#field_types>::reference()));
299 )*
300 result.push_str("};");
301 result
302 }
303
304 fn add_dependencies(builder: &mut openai_magic_instantiate::TypeScriptAccumulator) {
305 #(
306 builder.add::<#field_types>();
307 )*
308 }
309
310 fn validate(value: &openai_magic_instantiate::export::JsonValue) -> std::result::Result<Self, String> {
311 let openai_magic_instantiate::export::JsonValue::Object(value) = value else {
312 let expected: &[&str] = &[#(#field_names_camel),*];
313 return Err(format!("Expected object with fields {:?}, got {}", expected, openai_magic_instantiate::JsonValueExt::type_str(value)))
314 };
315 let result = Self {
316 #(
317 #field_idents: {
318 let value = match value.get(#field_names_camel) {
319 None => match <#field_types>::default_if_omitted() {
320 Some(value) => value,
321 None => return Err(format!("Expected field {}, but it wasn't present", #field_names_camel)),
322 },
323 Some(value) => match <#field_types>::validate(value) {
324 Ok(value) => value,
325 Err(error) => return Err(format!("Validation error for field {}:\n{}", #field_names_camel, error)),
326 }
327 };
328 #field_validators
329 value
330 },
331 )*
332 };
333 #(
334 #definition_validators
335 )*
336 Ok(result)
337 }
338
339 fn default_if_omitted() -> Option<Self> {
340 Some(#ident {
341 #(
342 #field_idents: <#field_types>::default_if_omitted()?,
343 )*
344 })
345 }
346
347 fn is_object() -> bool {
348 true
349 }
350 }
351 }
352 },
353 }
354 }
355 Data::Enum(DataEnum { variants, .. }) => {
356 let mut variant_definitions = vec![];
357 let mut variant_struct_names = vec![];
358 let mut variant_struct_kinds = vec![];
359 let mut variant_struct_to_variants = vec![];
360
361 if generics.params.len() > 0 {
362 panic!("Enums with generics are not supported");
363 }
364
365 for variant in variants {
366 let variant_attributes = variant
367 .attrs
368 .iter()
369 .filter(|a| a.path().is_ident("magic"))
370 .collect::<Vec<_>>();
371
372 let variant_ident = variant.ident;
373 let variant_struct_name = syn::Ident::new(&format!("{}{}", ident, variant_ident), proc_macro2::Span::call_site());
374 variant_struct_names.push(variant_struct_name.clone());
375
376 let variant_struct_kind = syn::Ident::new(&format!("{}{}", variant_struct_name, variant_ident), proc_macro2::Span::call_site());
377 variant_struct_kinds.push(variant_ident.clone());
378
379 let mut variant_fields = vec![
380 quote! {
381 kind: #variant_struct_kind,
382 }
383 ];
384
385 match variant.fields {
386 syn::Fields::Unit => {
387 variant_struct_to_variants.push(quote! {
388 Ok(Self::#variant_ident)
389 });
390 },
391 syn::Fields::Unnamed(fields) => {
392 let field_types = fields.unnamed.iter().map(|f| &f.ty).collect::<Vec<_>>();
393 variant_fields.push(quote! {
394 value: (#(#field_types,)*),
395 });
396 let field_idents = (0..field_types.len()).map(|i| syn::Ident::new(&format!("field{}", i), proc_macro2::Span::call_site())).collect::<Vec<_>>();
397 variant_struct_to_variants.push(quote! {
398 let (#(#field_idents,)*) = value.value;
399 Ok(Self::#variant_ident(#(#field_idents),*))
400 });
401 },
402 syn::Fields::Named(fields) => {
403 for field in &fields.named {
404 let field_attributes = &field.attrs;
405 let field_name = field.ident.as_ref().unwrap();
406 let field_type = &field.ty;
407
408 variant_fields.push(quote! {
409 #(#field_attributes)*
410 #field_name: #field_type,
411 });
412 }
413 let field_idents = fields.named.iter().map(|f| f.ident.as_ref().unwrap()).collect::<Vec<_>>();
414 variant_struct_to_variants.push(quote! {
415 Ok(Self::#variant_ident {
416 #(#field_idents: value.#field_idents,)*
417 })
418 });
419 }
420 }
421
422 variant_definitions.push(quote! {
423
424 struct #variant_struct_kind;
425
426 impl MagicInstantiate for #variant_struct_kind {
427 fn name() -> String {
428 stringify!(#variant_ident).to_string()
429 }
430 fn reference() -> String {
431 format!("\"{}\"", stringify!(#variant_ident))
432 }
433 fn add_dependencies(builder: &mut openai_magic_instantiate::TypeScriptAccumulator) {}
434 fn definition() -> String { "".to_string() }
435
436 fn validate(value: &openai_magic_instantiate::export::JsonValue) -> std::result::Result<Self, String> {
437 let expected = stringify!(#variant_ident);
438 if value.as_str() == Some(expected.as_ref()) {
439 Ok(Self)
440 } else {
441 Err(format!("Expected \"{expected}\""))
442 }
443 }
444 fn default_if_omitted() -> Option<Self> { None }
445 fn is_object() -> bool { false }
446 }
447
448 #[derive(MagicInstantiate)]
449 #(#variant_attributes)*
450 struct #variant_struct_name {
451 #(#variant_fields)*
452 }
453 });
454 }
455
456 quote! {
457 #(#variant_definitions)*
458
459 impl #impl_generics MagicInstantiate for #ident #ty_generics #where_clause {
460 fn name() -> String {
461 #name
462 }
463
464 fn reference() -> String {
465 Self::name()
466 }
467
468 fn definition() -> String {
469 let mut result = String::new();
470 #(
471 for line in #definition_description.lines() {
472 result.push_str(&format!("// {}\n", line));
473 }
474 )*
475 #(
476 for line in #definition_validator_instructions.lines() {
477 result.push_str(&format!("// {}\n", line));
478 }
479 )*
480 result.push_str(&format!("type {} =\n", stringify!(#ident)));
481 #(
482 result.push_str(&format!(" | {}\n", <#variant_struct_names>::reference()));
483 )*
484 result.push_str(";");
485 result
486 }
487
488 fn add_dependencies(builder: &mut openai_magic_instantiate::TypeScriptAccumulator) {
489 #(
490 builder.add::<#variant_struct_names>();
491 )*
492 }
493
494 fn validate(value: &openai_magic_instantiate::export::JsonValue) -> std::result::Result<Self, String> {
495 let kind = value.get("kind").ok_or("Expected field 'kind'")?;
496 let kind = kind.as_str().ok_or_else(|| format!("Expected 'kind' to be a string, got {}", openai_magic_instantiate::JsonValueExt::type_str(value)))?;
497 let result = match kind {
498 #(
499 stringify!(#variant_struct_kinds) => {
500 let value = <#variant_struct_names>::validate(value)?;
501 #variant_struct_to_variants
502 },
503 )*
504 _ => Err(format!("Unknown variant {}", kind)),
505 }?;
506 #(
507 #definition_validators
508 )*
509 Ok(result)
510 }
511
512 fn default_if_omitted() -> Option<Self> {
513 None
514 }
515
516 fn is_object() -> bool {
517 true
518 }
519 }
520 }
521 },
522 Data::Union(_) => todo!(),
523 }.into()
524}
525
526#[proc_macro]
527pub fn implement_integers(_input: TokenStream) -> TokenStream {
528 let type_tokens = vec![
529 quote! { u8 },
530 quote! { u16 },
531 quote! { u32 },
532 quote! { u64 },
533 quote! { usize },
534 quote! { i8 },
535 quote! { i16 },
536 quote! { i32 },
537 quote! { i64 },
538 quote! { isize },
539 ];
540
541 let names = vec![
542 "U8",
543 "U16",
544 "U32",
545 "U64",
546 "USize",
547 "I8",
548 "I16",
549 "I32",
550 "I64",
551 "ISize",
552 ];
553
554 quote! {
555 #(
556 impl MagicInstantiate for #type_tokens {
557 fn name() -> String {
558 #names.to_string()
559 }
560
561 fn reference() -> String {
562 #names.to_string()
563 }
564
565 fn definition() -> String {
566 let min = Self::MIN;
567 let max = Self::MAX;
568 let name = #names;
569 format!("
570// Integer in [{min}, {max}]
571type {name} = number;
572 ").trim().to_string()
573 }
574
575 fn add_dependencies(builder: &mut TypeScriptAccumulator) {}
576
577 fn validate(value: &JsonValue) -> Result<Self, String> {
578 match value {
579 JsonValue::Number(number) => {
580 match number.as_i64() {
581 Some(number) => {
582 if number >= (Self::MIN as i64) && number < (Self::MAX as i64) {
583 Ok(number as _)
584 } else {
585 Err(format!("Expected integer in [{}, {}], got {}", Self::MIN, Self::MAX, number))
586 }
587 }
588 None => Err(format!("Expected integer in [{}, {}], got {}", Self::MIN, Self::MAX, number)),
589 }
590 }
591 _ => Err(format!("Expected integer, got {}", value.type_str())),
592 }
593 }
594
595 fn default_if_omitted() -> Option<Self> {
596 None
597 }
598
599 fn is_object() -> bool {
600 false
601 }
602 }
603 )*
604 }.into()
605}
606
607#[proc_macro]
608pub fn implement_tuples(_input: TokenStream) -> TokenStream {
609 let mut results = vec![];
610
611 for i in 2..16usize {
612
613 let generic_names = (1..=i).map(|i| syn::Ident::new(&format!("T{}", i), proc_macro2::Span::call_site())).collect::<Vec<_>>();
614 let indexes = (0..i).collect::<Vec<_>>();
615
616 results.push(quote! {
617
618 impl<#(#generic_names: MagicInstantiate),*> MagicInstantiate for (#(#generic_names,)*) {
619 fn name() -> String {
620 let names = vec![#(<#generic_names>::name()),*];
621 format!("Tuple{}", names.join(""))
622 }
623
624 fn reference() -> String {
625 let references = vec![#(<#generic_names>::reference()),*];
626 format!("[{}]", references.join(", "))
627 }
628
629 fn definition () -> String { "".to_string() }
630
631 fn add_dependencies(builder: &mut TypeScriptAccumulator) {
632 #(
633 builder.add::<#generic_names>();
634 )*
635 }
636
637 fn validate(value: &JsonValue) -> Result<Self, String> {
638 let JsonValue::Array(value) = value else { return Err(format!("Expected array tuple, got {}", value.type_str())) };
639 if value.len() != #i {
640 return Err(format!("Expected {} elements but got {}", #i, value.len()));
641 }
642 Ok((#(<#generic_names>::validate(&value[#indexes])?,)*))
643 }
644
645 fn default_if_omitted() -> Option<Self> {
646 None
647 }
648
649 fn is_object() -> bool {
650 false
651 }
652 }
653 });
654 }
655
656 quote! {
657 #( #results )*
658 }.into()
659}