openai_schema_impl/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use serde_json::{json, Value};
4use syn::{
5    parse_macro_input, Attribute, Data, DeriveInput, Fields, GenericArgument, PathArguments, Type,
6};
7
8#[proc_macro_derive(OpenAISchema)]
9pub fn openai_schema_derive(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as DeriveInput);
11    let name = input.ident;
12
13    let description = get_description(&input.attrs);
14
15    let properties = match input.data {
16        Data::Struct(data) => match data.fields {
17            Fields::Named(fields) => fields
18                .named
19                .iter()
20                .map(|f| {
21                    let field_name = f.ident.as_ref().unwrap().to_string();
22                    let field_type = get_field_type(&f.ty);
23                    (field_name, field_type)
24                })
25                .collect::<serde_json::Map<String, Value>>(),
26            Fields::Unnamed(fields) => fields
27                .unnamed
28                .iter()
29                .enumerate()
30                .map(|(i, f)| {
31                    let field_name = i.to_string();
32                    let field_type = get_field_type(&f.ty);
33                    (field_name, field_type)
34                })
35                .collect::<serde_json::Map<String, Value>>(),
36            Fields::Unit => serde_json::Map::new(),
37        },
38        _ => panic!("Only structs are supported"),
39    };
40
41    let required: Vec<String> = properties.keys().cloned().collect();
42
43    let schema = json!({
44        "name": name.to_string(),
45        "description": description,
46        "strict": true,
47        "schema": {
48            "type": "object",
49            "properties": properties,
50            "required": required,
51            "additionalProperties": false
52        }
53    });
54
55    let schema_str = serde_json::to_string(&schema).unwrap();
56
57    let expanded = quote! {
58        impl openai_schema::OpenAISchema for #name {
59            fn openai_schema() -> openai_schema::GeneratedOpenAISchema {
60                #schema_str.into()
61            }
62        }
63    };
64
65    TokenStream::from(expanded)
66}
67
68fn get_description(attrs: &[Attribute]) -> String {
69    attrs
70        .iter()
71        .find(|attr| attr.path().is_ident("doc"))
72        .map(|attr| attr.parse_args::<syn::LitStr>().unwrap().value())
73        .unwrap_or_default()
74}
75
76fn get_field_type(ty: &Type) -> Value {
77    match ty {
78        Type::Path(type_path) => {
79            let segment = type_path.path.segments.last().unwrap();
80            let type_name = segment.ident.to_string();
81            match type_name.as_str() {
82                "String" => json!({"type": "string"}),
83                "i32" | "i64" | "f32" | "f64" => json!({"type": "number"}),
84                "bool" => json!({"type": "boolean"}),
85                "Vec" => {
86                    if let PathArguments::AngleBracketed(args) = &segment.arguments {
87                        if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
88                            return json!({
89                                "type": "array",
90                                "items": get_field_type(inner_type)
91                            });
92                        }
93                    }
94                    json!({"type": "array", "items": {}})
95                }
96                "Option" => {
97                    if let PathArguments::AngleBracketed(args) = &segment.arguments {
98                        if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
99                            let inner_schema = get_field_type(inner_type);
100                            return json!({
101                                "anyOf": [
102                                    inner_schema,
103                                    {"type": "null"}
104                                ]
105                            });
106                        }
107                    }
108                    json!({"anyOf": [{"type": "null"}]})
109                }
110                _ => json!({"type": "object"}), // assume custom types are objects
111            }
112        }
113        _ => panic!("Unsupported type"),
114    }
115}