openai_schema_impl/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use serde_json::{json, Value};
4use syn::{
5 parse_macro_input, Attribute, Data, DeriveInput, Fields, GenericArgument, PathArguments, Type,
6};
7
8#[proc_macro_derive(OpenAISchema)]
9pub fn openai_schema_derive(input: TokenStream) -> TokenStream {
10 let input = parse_macro_input!(input as DeriveInput);
11 let name = input.ident;
12
13 let description = get_description(&input.attrs);
14
15 let properties = match input.data {
16 Data::Struct(data) => match data.fields {
17 Fields::Named(fields) => fields
18 .named
19 .iter()
20 .map(|f| {
21 let field_name = f.ident.as_ref().unwrap().to_string();
22 let field_type = get_field_type(&f.ty);
23 (field_name, field_type)
24 })
25 .collect::<serde_json::Map<String, Value>>(),
26 Fields::Unnamed(fields) => fields
27 .unnamed
28 .iter()
29 .enumerate()
30 .map(|(i, f)| {
31 let field_name = i.to_string();
32 let field_type = get_field_type(&f.ty);
33 (field_name, field_type)
34 })
35 .collect::<serde_json::Map<String, Value>>(),
36 Fields::Unit => serde_json::Map::new(),
37 },
38 _ => panic!("Only structs are supported"),
39 };
40
41 let required: Vec<String> = properties.keys().cloned().collect();
42
43 let schema = json!({
44 "name": name.to_string(),
45 "description": description,
46 "strict": true,
47 "schema": {
48 "type": "object",
49 "properties": properties,
50 "required": required,
51 "additionalProperties": false
52 }
53 });
54
55 let schema_str = serde_json::to_string(&schema).unwrap();
56
57 let expanded = quote! {
58 impl openai_schema::OpenAISchema for #name {
59 fn openai_schema() -> openai_schema::GeneratedOpenAISchema {
60 #schema_str.into()
61 }
62 }
63 };
64
65 TokenStream::from(expanded)
66}
67
68fn get_description(attrs: &[Attribute]) -> String {
69 attrs
70 .iter()
71 .find(|attr| attr.path().is_ident("doc"))
72 .map(|attr| attr.parse_args::<syn::LitStr>().unwrap().value())
73 .unwrap_or_default()
74}
75
76fn get_field_type(ty: &Type) -> Value {
77 match ty {
78 Type::Path(type_path) => {
79 let segment = type_path.path.segments.last().unwrap();
80 let type_name = segment.ident.to_string();
81 match type_name.as_str() {
82 "String" => json!({"type": "string"}),
83 "i32" | "i64" | "f32" | "f64" => json!({"type": "number"}),
84 "bool" => json!({"type": "boolean"}),
85 "Vec" => {
86 if let PathArguments::AngleBracketed(args) = &segment.arguments {
87 if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
88 return json!({
89 "type": "array",
90 "items": get_field_type(inner_type)
91 });
92 }
93 }
94 json!({"type": "array", "items": {}})
95 }
96 "Option" => {
97 if let PathArguments::AngleBracketed(args) = &segment.arguments {
98 if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
99 let inner_schema = get_field_type(inner_type);
100 return json!({
101 "anyOf": [
102 inner_schema,
103 {"type": "null"}
104 ]
105 });
106 }
107 }
108 json!({"anyOf": [{"type": "null"}]})
109 }
110 _ => json!({"type": "object"}), }
112 }
113 _ => panic!("Unsupported type"),
114 }
115}