1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Data, Fields, Type, Meta};
4
5#[proc_macro_derive(LlmSchema, attributes(llmschem))]
6pub fn derive_llm_schema(input: TokenStream) -> TokenStream {
7 let ast = parse_macro_input!(input as DeriveInput);
8 let name = &ast.ident;
9
10 let schema = generate_schema(&ast);
11
12 let expanded = quote! {
13 impl #name {
14 pub fn llm_schema() -> ::serde_json::Value {
15 #schema
16 }
17 }
18 };
19
20 TokenStream::from(expanded)
21}
22
23fn generate_schema(ast: &DeriveInput) -> proc_macro2::TokenStream {
24 let name = &ast.ident;
25 let fields = match &ast.data {
26 Data::Struct(data) => match &data.fields {
27 Fields::Named(fields) => fields.named.iter(),
28 _ => unimplemented!("Only named fields are supported"),
29 },
30 _ => unimplemented!("Only structs are supported"),
31 };
32
33 let mut properties = quote! {};
34 let mut required_fields = Vec::new();
35
36 for field in fields {
37 let field_name = &field.ident;
38 let field_name_str = field_name.as_ref().unwrap().to_string();
39
40 let mut is_required = false; for attr in &field.attrs {
43 if attr.path().is_ident("llmschem") {
44 attr.parse_nested_meta(|meta| {
45 if meta.path.is_ident("require") {
46 is_required = true;
47 }
48 Ok(())
49 }).unwrap_or_default();
50 }
51 }
52
53 if let Type::Path(type_path) = &field.ty {
55 if let Some(segment) = type_path.path.segments.last() {
56 if segment.ident == "Option" {
57 is_required = false;
58 }
59 }
60 }
61
62 let type_def = get_type_definition(&field.ty);
63
64 properties.extend(quote! {
65 #field_name_str: {
66 "type": #type_def
67 },
68 });
69
70 if is_required {
71 required_fields.push(field_name_str);
72 }
73 }
74
75 let required_array = if !required_fields.is_empty() {
76 let required = required_fields.iter().map(|s| quote! { #s });
77 quote! {
78 schema["required"] = ::serde_json::json!([#(#required),*]);
79 }
80 } else {
81 quote! {}
82 };
83
84 quote! {
85 {
86 let mut schema = ::serde_json::json!({
87 "type": "object",
88 "properties": {
89 #properties
90 }
91 });
92
93 #required_array
94
95 schema
96 }
97 }
98}
99
100fn get_type_definition(ty: &Type) -> &'static str {
101 match ty {
102 Type::Path(type_path) => {
103 if let Some(segment) = type_path.path.segments.last() {
104 match segment.ident.to_string().as_str() {
105 "String" => "string",
106 "str" => "string",
107 "bool" => "boolean",
108 "f32" | "f64" | "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" => "number",
109 "Option" => "null", _ => panic!("Unsupported type for LLM schema"),
111 }
112 } else {
113 panic!("Invalid type path");
114 }
115 }
116 _ => panic!("Only simple types are supported"),
117 }
118}