llm_schema/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Data, Fields, Type, Meta};
4
5#[proc_macro_derive(LlmSchema, attributes(llmschem))]
6pub fn derive_llm_schema(input: TokenStream) -> TokenStream {
7    let ast = parse_macro_input!(input as DeriveInput);
8    let name = &ast.ident;
9
10    let schema = generate_schema(&ast);
11
12    let expanded = quote! {
13        impl #name {
14            pub fn llm_schema() -> ::serde_json::Value {
15                #schema
16            }
17        }
18    };
19
20    TokenStream::from(expanded)
21}
22
23fn generate_schema(ast: &DeriveInput) -> proc_macro2::TokenStream {
24    let name = &ast.ident;
25    let fields = match &ast.data {
26        Data::Struct(data) => match &data.fields {
27            Fields::Named(fields) => fields.named.iter(),
28            _ => unimplemented!("Only named fields are supported"),
29        },
30        _ => unimplemented!("Only structs are supported"),
31    };
32
33    let mut properties = quote! {};
34    let mut required_fields = Vec::new();
35
36    for field in fields {
37        let field_name = &field.ident;
38        let field_name_str = field_name.as_ref().unwrap().to_string();
39
40        // Check for #[llmschem(require)] attribute
41        let mut is_required = false; // Changed to false by default to match previous behavior
42        for attr in &field.attrs {
43            if attr.path().is_ident("llmschem") {
44                attr.parse_nested_meta(|meta| {
45                    if meta.path.is_ident("require") {
46                        is_required = true;
47                    }
48                    Ok(())
49                }).unwrap_or_default();
50            }
51        }
52
53        // Handle Option<T> types
54        if let Type::Path(type_path) = &field.ty {
55            if let Some(segment) = type_path.path.segments.last() {
56                if segment.ident == "Option" {
57                    is_required = false;
58                }
59            }
60        }
61
62        let type_def = get_type_definition(&field.ty);
63
64        properties.extend(quote! {
65            #field_name_str: {
66                "type": #type_def
67            },
68        });
69
70        if is_required {
71            required_fields.push(field_name_str);
72        }
73    }
74
75    let required_array = if !required_fields.is_empty() {
76        let required = required_fields.iter().map(|s| quote! { #s });
77        quote! {
78            schema["required"] = ::serde_json::json!([#(#required),*]);
79        }
80    } else {
81        quote! {}
82    };
83
84    quote! {
85        {
86            let mut schema = ::serde_json::json!({
87                "type": "object",
88                "properties": {
89                    #properties
90                }
91            });
92
93            #required_array
94
95            schema
96        }
97    }
98}
99
100fn get_type_definition(ty: &Type) -> &'static str {
101    match ty {
102        Type::Path(type_path) => {
103            if let Some(segment) = type_path.path.segments.last() {
104                match segment.ident.to_string().as_str() {
105                    "String" => "string",
106                    "str" => "string",
107                    "bool" => "boolean",
108                    "f32" | "f64" | "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" => "number",
109                    "Option" => "null", // Option<T> will be handled in parent function
110                    _ => panic!("Unsupported type for LLM schema"),
111                }
112            } else {
113                panic!("Invalid type path");
114            }
115        }
116        _ => panic!("Only simple types are supported"),
117    }
118}