1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Data, Fields, Lit};
4
5#[proc_macro_derive(StructuredOutput, attributes(structured_output))]
27pub fn derive_structured_output(input: TokenStream) -> TokenStream {
28 let input = parse_macro_input!(input as DeriveInput);
29
30 let name = &input.ident;
31 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32
33 let (tool_name, tool_description) = parse_attributes(&input);
35
36 let schema = generate_schema(&input.data);
38
39 let expanded = quote! {
40 impl #impl_generics struct_llm::StructuredOutput for #name #ty_generics #where_clause {
41 fn tool_name() -> &'static str {
42 #tool_name
43 }
44
45 fn tool_description() -> &'static str {
46 #tool_description
47 }
48
49 fn json_schema() -> serde_json::Value {
50 serde_json::json!(#schema)
51 }
52 }
53 };
54
55 TokenStream::from(expanded)
56}
57
58fn parse_attributes(input: &DeriveInput) -> (String, String) {
59 let mut tool_name = None;
60 let mut tool_description = None;
61
62 for attr in &input.attrs {
63 if !attr.path().is_ident("structured_output") {
64 continue;
65 }
66
67 let _ = attr.parse_nested_meta(|meta| {
68 if meta.path.is_ident("name") {
69 let value = meta.value()?;
70 let s: Lit = value.parse()?;
71 if let Lit::Str(lit_str) = s {
72 tool_name = Some(lit_str.value());
73 }
74 } else if meta.path.is_ident("description") {
75 let value = meta.value()?;
76 let s: Lit = value.parse()?;
77 if let Lit::Str(lit_str) = s {
78 tool_description = Some(lit_str.value());
79 }
80 }
81 Ok(())
82 });
83 }
84
85 let tool_name = tool_name.expect("missing #[structured_output(name = \"...\")] attribute");
86 let tool_description = tool_description.expect("missing #[structured_output(description = \"...\")] attribute");
87
88 (tool_name, tool_description)
89}
90
91fn generate_schema(data: &Data) -> proc_macro2::TokenStream {
92 match data {
93 Data::Struct(data_struct) => generate_struct_schema(&data_struct.fields),
94 Data::Enum(_) => {
95 panic!("StructuredOutput can only be derived for structs, not enums");
96 }
97 Data::Union(_) => {
98 panic!("StructuredOutput can only be derived for unions");
99 }
100 }
101}
102
103fn generate_struct_schema(fields: &Fields) -> proc_macro2::TokenStream {
104 let mut properties = Vec::new();
105 let mut required = Vec::new();
106
107 match fields {
108 Fields::Named(fields_named) => {
109 for field in &fields_named.named {
110 let field_name = field.ident.as_ref().unwrap().to_string();
111 let field_schema = generate_field_schema(&field.ty);
112
113 properties.push(quote! {
114 #field_name: #field_schema
115 });
116
117 required.push(field_name);
118 }
119 }
120 Fields::Unnamed(_) => {
121 panic!("StructuredOutput does not support tuple structs");
122 }
123 Fields::Unit => {
124 panic!("StructuredOutput does not support unit structs");
125 }
126 }
127
128 let required_fields = required.iter().map(|s| quote! { #s });
129
130 quote! {
131 {
132 "type": "object",
133 "properties": {
134 #(#properties),*
135 },
136 "required": [#(#required_fields),*]
137 }
138 }
139}
140
141fn generate_field_schema(ty: &syn::Type) -> proc_macro2::TokenStream {
142 if let syn::Type::Path(type_path) = ty {
143 if let Some(segment) = type_path.path.segments.last() {
144 let type_name = segment.ident.to_string();
145
146 if type_name == "Vec" {
148 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
149 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
150 let item_type = infer_json_type(inner_ty);
151 return quote! {
152 {
153 "type": "array",
154 "items": {
155 "type": #item_type
156 }
157 }
158 };
159 }
160 }
161 return quote! {
163 {
164 "type": "array",
165 "items": {}
166 }
167 };
168 }
169
170 if type_name == "Option" {
172 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
173 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
174 return generate_field_schema(inner_ty);
176 }
177 }
178 }
179 }
180 }
181
182 let type_str = infer_json_type(ty);
184 quote! {
185 {
186 "type": #type_str
187 }
188 }
189}
190
191fn infer_json_type(ty: &syn::Type) -> &'static str {
192 if let syn::Type::Path(type_path) = ty {
194 if let Some(segment) = type_path.path.segments.last() {
195 let type_name = segment.ident.to_string();
196
197 return match type_name.as_str() {
198 "String" | "str" => "string",
199 "i8" | "i16" | "i32" | "i64" | "i128" |
200 "u8" | "u16" | "u32" | "u64" | "u128" |
201 "isize" | "usize" => "integer",
202 "f32" | "f64" => "number",
203 "bool" => "boolean",
204 "Vec" => "array",
205 "HashMap" | "BTreeMap" => "object",
206 _ => {
207 if type_name == "Option" {
209 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
211 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
212 return infer_json_type(inner_ty);
213 }
214 }
215 }
216 "string"
218 }
219 };
220 }
221 }
222
223 "string" }