pyro_macro/module/
spec.rs1use std::borrow::Cow;
25
26use pyro_spec::{ModuleFunc, PyroField, PyroSchema};
27use syn::{Attribute, Expr, FnArg, ItemFn, Lit, Meta, Pat, ReturnType, Type};
28
29use crate::struct_doc::SchemaBuilder;
30
31use super::parse::{ModuleAttrs, OutputSpec};
32
33pub fn generate_module_spec(content: &str) -> syn::Result<Option<ModuleFunc<'static>>> {
42 let file = syn::parse_file(content)?;
43 let builder = SchemaBuilder::from_file(&file);
44
45 for item in &file.items {
46 if let syn::Item::Fn(item_fn) = item {
47 if !super::has_module_attr(&item_fn.attrs) {
48 continue;
49 }
50
51 let attr_tokens = super::extract_module_attr(&item_fn.attrs)?.ok_or_else(|| {
52 syn::Error::new_spanned(
53 item_fn,
54 "Module attribute requires arguments: #[module(output = ...)]",
55 )
56 })?;
57
58 let attrs: ModuleAttrs = syn::parse2(attr_tokens)?;
59 let spec = ModuleSpecBuilder::build(item_fn, &attrs, &builder)?;
60
61 return Ok(Some(spec.into()));
62 }
63 }
64
65 Ok(None)
66}
67
68pub struct ModuleSpecBuilder;
73
74impl ModuleSpecBuilder {
75 pub fn build(
77 item_fn: &ItemFn,
78 attrs: &ModuleAttrs,
79 builder: &SchemaBuilder,
80 ) -> syn::Result<ModuleFunc<'static>> {
81 let name = item_fn.sig.ident.to_string();
82 let description = extract_doc_string(&item_fn.attrs);
83
84 let input_fields: Vec<PyroField<'static>> = item_fn
86 .sig
87 .inputs
88 .iter()
89 .filter_map(|arg| {
90 if let FnArg::Typed(pat_type) = arg {
91 if let Pat::Ident(pat_ident) = &*pat_type.pat {
92 let field_name = pat_ident.ident.to_string();
93 let ty = &*pat_type.ty;
94 let data_type = builder.resolve_type(ty);
95 let nullable = SchemaBuilder::is_option(ty);
96 let doc = extract_doc_string(&pat_type.attrs);
97 let mut field = PyroField::new(Cow::Owned(field_name), data_type, nullable);
98 if let Some(d) = doc {
99 field = field.add_docstring(Cow::Owned(d));
100 }
101 return Some(field);
102 }
103 }
104 None
105 })
106 .collect();
107
108 let input = PyroSchema::new(input_fields);
109
110 let ok_type = extract_result_ok_type(&item_fn.sig.output)?;
112 let output = build_output_schema(&attrs.output, &ok_type, builder)?;
113
114 let func = ModuleFunc {
115 name: Cow::Owned(name),
116 description: description.map(Cow::Owned),
117 input,
118 output,
119 };
120
121 Ok(func)
122 }
123}
124
125fn build_output_schema(
132 spec: &OutputSpec,
133 ok_type: &Type,
134 builder: &SchemaBuilder,
135) -> syn::Result<PyroSchema<'static>> {
136 match spec {
137 OutputSpec::SingleField(field_name) => {
139 let data_type = builder.resolve_type(ok_type);
140 let nullable = SchemaBuilder::is_option(ok_type);
141 let field = PyroField::new(Cow::Owned(field_name.to_string()), data_type, nullable);
142 Ok(PyroSchema::new(vec![field]))
143 }
144
145 OutputSpec::TupleFields(field_names) => {
147 let tuple_types = extract_tuple_types(ok_type)?;
148
149 if tuple_types.len() != field_names.len() {
150 return Err(syn::Error::new_spanned(
151 ok_type,
152 format!(
153 "output field count ({}) does not match tuple element count ({})",
154 field_names.len(),
155 tuple_types.len()
156 ),
157 ));
158 }
159
160 let fields: Vec<PyroField<'static>> = field_names
161 .iter()
162 .zip(tuple_types.iter())
163 .map(|(name, ty)| {
164 let data_type = builder.resolve_type(ty);
165 let nullable = SchemaBuilder::is_option(ty);
166 PyroField::new(Cow::Owned(name.to_string()), data_type, nullable)
167 })
168 .collect();
169
170 Ok(PyroSchema::new(fields))
171 }
172
173 OutputSpec::Struct => {
175 let schema = match ok_type {
178 Type::Path(type_path) => {
179 if let Some(seg) = type_path.path.segments.last() {
180 builder.schema_for(&seg.ident.to_string())
181 } else {
182 None
183 }
184 }
185 _ => None,
186 };
187
188 Ok(schema.map(|s| s.into_owned()).unwrap_or_else(|| {
189 let data_type = builder.resolve_type(ok_type);
191 let nullable = SchemaBuilder::is_option(ok_type);
192 PyroSchema::new(vec![PyroField::new(
193 Cow::Borrowed("output"),
194 data_type,
195 nullable,
196 )])
197 }))
198 }
199 }
200}
201
202fn extract_result_ok_type(ret: &ReturnType) -> syn::Result<&Type> {
204 match ret {
205 ReturnType::Default => Err(syn::Error::new(
206 proc_macro2::Span::call_site(),
207 "module function must return Result<T>",
208 )),
209 ReturnType::Type(_, ty) => {
210 if let Type::Path(type_path) = &**ty {
211 if let Some(seg) = type_path.path.segments.last() {
212 if seg.ident == "Result" {
213 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
214 if let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first() {
215 return Ok(ok_ty);
216 }
217 }
218 }
219 }
220 }
221 Err(syn::Error::new_spanned(
222 &**ty,
223 "module function must return Result<T>",
224 ))
225 }
226 }
227}
228
229fn extract_tuple_types(ty: &Type) -> syn::Result<Vec<&Type>> {
231 if let Type::Tuple(tuple) = ty {
232 Ok(tuple.elems.iter().collect())
233 } else {
234 Err(syn::Error::new_spanned(
235 ty,
236 "expected tuple return type for multi-field output",
237 ))
238 }
239}
240
241fn extract_doc_string(attrs: &[Attribute]) -> Option<String> {
243 let lines: Vec<String> = attrs
244 .iter()
245 .filter_map(|attr| {
246 if !attr.path().is_ident("doc") {
247 return None;
248 }
249 if let Meta::NameValue(nv) = &attr.meta {
250 if let Expr::Lit(expr_lit) = &nv.value {
251 if let Lit::Str(s) = &expr_lit.lit {
252 return Some(s.value().trim().to_string());
253 }
254 }
255 }
256 None
257 })
258 .collect();
259
260 if lines.is_empty() {
261 None
262 } else {
263 Some(lines.join("\n"))
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// `output = name` produces exactly one output field with that name.
    #[test]
    fn test_single_field_output() {
        let src = r#"
            #[module(output = message)]
            fn call(input: &str) -> Result<String> {
                Ok(format!("hello {}", input))
            }
        "#;

        let v = generate_module_spec(src).unwrap().unwrap();

        assert_eq!(v.name, "call");
        assert!(v.description.is_none());

        let in_fields = &v.input.fields;
        assert_eq!(in_fields.len(), 1);
        assert_eq!(in_fields[0].name, "input");

        let out_fields = &v.output.fields;
        assert_eq!(out_fields.len(), 1);
        assert_eq!(out_fields[0].name, "message");
    }

    /// `output = (a, b)` names the elements of a tuple return type in order.
    #[test]
    fn test_tuple_output() {
        let src = r#"
            #[module(output = (score, label))]
            fn classify(text: String) -> Result<(f32, String)> {
                Ok((0.9, "positive".into()))
            }
        "#;

        let v = generate_module_spec(src).unwrap().unwrap();

        let out_fields = &v.output.fields;
        assert_eq!(out_fields.len(), 2);
        assert_eq!(out_fields[0].name, "score");
        assert_eq!(out_fields[1].name, "label");
    }

    /// A struct output reuses the struct's schema; the doc comment becomes the description.
    #[test]
    fn test_struct_output() {
        let src = r#"
            #[config]
            struct Output {
                embedding: Vec<f32>,
                tokens: u32,
            }

            /// Embed a piece of text.
            #[module(output = Output)]
            fn embed(text: String, model: String) -> Result<Output> {
                todo!()
            }
        "#;

        let v = generate_module_spec(src).unwrap().unwrap();

        assert_eq!(v.name, "embed");
        assert_eq!(v.description.unwrap(), "Embed a piece of text.");

        let in_fields = &v.input.fields;
        assert_eq!(in_fields.len(), 2);
        assert_eq!(in_fields[0].name, "text");
        assert_eq!(in_fields[1].name, "model");

        let out_fields = &v.output.fields;
        assert_eq!(out_fields[0].name, "embedding");
        assert_eq!(out_fields[1].name, "tokens");
    }

    /// A file without a module function yields `None` rather than an error.
    #[test]
    fn test_no_module_function() {
        let src = r#"
            fn plain(x: u32) -> u32 { x }
        "#;
        let result = generate_module_spec(src).unwrap();
        assert!(result.is_none());
    }

    /// The number of tuple output names must match the tuple's arity.
    #[test]
    fn test_tuple_arity_mismatch_is_error() {
        let src = r#"
            #[module(output = (a, b))]
            fn f(x: u32) -> Result<(u32, u32, u32)> {
                todo!()
            }
        "#;
        assert!(generate_module_spec(src).is_err());
    }

    /// Tuple output names require a tuple `Ok` type.
    #[test]
    fn test_tuple_output_requires_tuple_type() {
        let src = r#"
            #[module(output = (a, b))]
            fn f(x: u32) -> Result<u32> {
                todo!()
            }
        "#;
        assert!(generate_module_spec(src).is_err());
    }

    /// A module function must return `Result<T>`.
    #[test]
    fn test_non_result_return_is_error() {
        let src = r#"
            #[module(output = value)]
            fn f(x: u32) -> u32 {
                x
            }
        "#;
        assert!(generate_module_spec(src).is_err());
    }

    /// A module function without a return type is rejected.
    #[test]
    fn test_missing_return_type_is_error() {
        let src = r#"
            #[module(output = value)]
            fn f(x: u32) {}
        "#;
        assert!(generate_module_spec(src).is_err());
    }
}