Skip to main content

pyro_macro/module/
spec.rs

1//! Generates module interface specifications using the Pyro type system.
2//!
3//! Parses a module source file, locates the `#[module(...)]` function, and
4//! builds a [`ModuleFunc`] describing its input parameters and output schema.
5//! The result is serialised to JSON and written to `module.json` alongside the
6//! compiled artefact.
7//!
8//! Input schema
9//! ------------
10//! Every typed parameter of the annotated function becomes a [`PyroField`] in
11//! the input [`PyroSchema`].  Type resolution is delegated to
12//! [`SchemaBuilder`], so struct parameters expand into `Group(fields)`.
13//!
14//! Output schema
15//! -------------
16//! The output schema is derived from the `output = …` argument:
17//!
18//! | `output =`         | generated `__Output` struct                        |
19//! |--------------------|---------------------------------------------------|
20//! | `field`            | `{ field: <ReturnType> }`                          |
21//! | `(f1, f2, …)`      | `{ f1: T1, f2: T2, … }` (from tuple return type)  |
22//! | `StructName`       | fields of the named struct looked up in the file   |
23
24use std::borrow::Cow;
25
26use pyro_spec::{ModuleFunc, PyroField, PyroSchema};
27use syn::{Attribute, Expr, FnArg, ItemFn, Lit, Meta, Pat, ReturnType, Type};
28
29use crate::struct_doc::SchemaBuilder;
30
31use super::parse::{ModuleAttrs, OutputSpec};
32
33// =============================================================================
34// Public entry point
35// =============================================================================
36
37/// Parse `content` (a module source file), locate the `#[module(...)]`
38/// function, and return a pretty-printed JSON string describing it.
39///
40/// Returns `None` when no `#[module(...)]` function is found.
41pub fn generate_module_spec(content: &str) -> syn::Result<Option<ModuleFunc<'static>>> {
42    let file = syn::parse_file(content)?;
43    let builder = SchemaBuilder::from_file(&file);
44
45    for item in &file.items {
46        if let syn::Item::Fn(item_fn) = item {
47            if !super::has_module_attr(&item_fn.attrs) {
48                continue;
49            }
50
51            let attr_tokens = super::extract_module_attr(&item_fn.attrs)?.ok_or_else(|| {
52                syn::Error::new_spanned(
53                    item_fn,
54                    "Module attribute requires arguments: #[module(output = ...)]",
55                )
56            })?;
57
58            let attrs: ModuleAttrs = syn::parse2(attr_tokens)?;
59            let spec = ModuleSpecBuilder::build(item_fn, &attrs, &builder)?;
60
61            return Ok(Some(spec.into()));
62        }
63    }
64
65    Ok(None)
66}
67
68// =============================================================================
69// Builder
70// =============================================================================
71
72pub struct ModuleSpecBuilder;
73
74impl ModuleSpecBuilder {
75    /// Build a [`ModuleFuncSpec`] from a parsed function and its `#[module(...)]` attrs.
76    pub fn build(
77        item_fn: &ItemFn,
78        attrs: &ModuleAttrs,
79        builder: &SchemaBuilder,
80    ) -> syn::Result<ModuleFunc<'static>> {
81        let name = item_fn.sig.ident.to_string();
82        let description = extract_doc_string(&item_fn.attrs);
83
84        // ── Input schema ─────────────────────────────────────────────────────
85        let input_fields: Vec<PyroField<'static>> = item_fn
86            .sig
87            .inputs
88            .iter()
89            .filter_map(|arg| {
90                if let FnArg::Typed(pat_type) = arg {
91                    if let Pat::Ident(pat_ident) = &*pat_type.pat {
92                        let field_name = pat_ident.ident.to_string();
93                        let ty = &*pat_type.ty;
94                        let data_type = builder.resolve_type(ty);
95                        let nullable = SchemaBuilder::is_option(ty);
96                        let doc = extract_doc_string(&pat_type.attrs);
97                        let mut field = PyroField::new(Cow::Owned(field_name), data_type, nullable);
98                        if let Some(d) = doc {
99                            field = field.add_docstring(Cow::Owned(d));
100                        }
101                        return Some(field);
102                    }
103                }
104                None
105            })
106            .collect();
107
108        let input = PyroSchema::new(input_fields);
109
110        // ── Output schema ────────────────────────────────────────────────────
111        let ok_type = extract_result_ok_type(&item_fn.sig.output)?;
112        let output = build_output_schema(&attrs.output, &ok_type, builder)?;
113
114        let func = ModuleFunc {
115            name: Cow::Owned(name),
116            description: description.map(Cow::Owned),
117            input,
118            output,
119        };
120
121        Ok(func)
122    }
123}
124
125// =============================================================================
126// Helpers
127// =============================================================================
128
129/// Build the output [`PyroSchema`] from the `output = …` spec and the
130/// function's `Ok` return type.
131fn build_output_schema(
132    spec: &OutputSpec,
133    ok_type: &Type,
134    builder: &SchemaBuilder,
135) -> syn::Result<PyroSchema<'static>> {
136    match spec {
137        // output = single_field  →  { single_field: <ok_type> }
138        OutputSpec::SingleField(field_name) => {
139            let data_type = builder.resolve_type(ok_type);
140            let nullable = SchemaBuilder::is_option(ok_type);
141            let field = PyroField::new(Cow::Owned(field_name.to_string()), data_type, nullable);
142            Ok(PyroSchema::new(vec![field]))
143        }
144
145        // output = (f1, f2, …)  →  one field per tuple element
146        OutputSpec::TupleFields(field_names) => {
147            let tuple_types = extract_tuple_types(ok_type)?;
148
149            if tuple_types.len() != field_names.len() {
150                return Err(syn::Error::new_spanned(
151                    ok_type,
152                    format!(
153                        "output field count ({}) does not match tuple element count ({})",
154                        field_names.len(),
155                        tuple_types.len()
156                    ),
157                ));
158            }
159
160            let fields: Vec<PyroField<'static>> = field_names
161                .iter()
162                .zip(tuple_types.iter())
163                .map(|(name, ty)| {
164                    let data_type = builder.resolve_type(ty);
165                    let nullable = SchemaBuilder::is_option(ty);
166                    PyroField::new(Cow::Owned(name.to_string()), data_type, nullable)
167                })
168                .collect();
169
170            Ok(PyroSchema::new(fields))
171        }
172
173        // output = StructName  →  look up struct in the file registry
174        OutputSpec::Struct => {
175            // The return type must be a simple path — use it to look up the
176            // schema from the builder registry.
177            let schema = match ok_type {
178                Type::Path(type_path) => {
179                    if let Some(seg) = type_path.path.segments.last() {
180                        builder.schema_for(&seg.ident.to_string())
181                    } else {
182                        None
183                    }
184                }
185                _ => None,
186            };
187
188            Ok(schema.map(|s| s.into_owned()).unwrap_or_else(|| {
189                // Fallback: resolve as a single anonymous field
190                let data_type = builder.resolve_type(ok_type);
191                let nullable = SchemaBuilder::is_option(ok_type);
192                PyroSchema::new(vec![PyroField::new(
193                    Cow::Borrowed("output"),
194                    data_type,
195                    nullable,
196                )])
197            }))
198        }
199    }
200}
201
202/// Extract the `Ok` type from `Result<T, _>` or `Result<T>`.
203fn extract_result_ok_type(ret: &ReturnType) -> syn::Result<&Type> {
204    match ret {
205        ReturnType::Default => Err(syn::Error::new(
206            proc_macro2::Span::call_site(),
207            "module function must return Result<T>",
208        )),
209        ReturnType::Type(_, ty) => {
210            if let Type::Path(type_path) = &**ty {
211                if let Some(seg) = type_path.path.segments.last() {
212                    if seg.ident == "Result" {
213                        if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
214                            if let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first() {
215                                return Ok(ok_ty);
216                            }
217                        }
218                    }
219                }
220            }
221            Err(syn::Error::new_spanned(
222                &**ty,
223                "module function must return Result<T>",
224            ))
225        }
226    }
227}
228
229/// Extract element types from a tuple type `(T1, T2, …)`.
230fn extract_tuple_types(ty: &Type) -> syn::Result<Vec<&Type>> {
231    if let Type::Tuple(tuple) = ty {
232        Ok(tuple.elems.iter().collect())
233    } else {
234        Err(syn::Error::new_spanned(
235            ty,
236            "expected tuple return type for multi-field output",
237        ))
238    }
239}
240
241/// Collect `/// doc` comments from a slice of attributes into a single string.
242fn extract_doc_string(attrs: &[Attribute]) -> Option<String> {
243    let lines: Vec<String> = attrs
244        .iter()
245        .filter_map(|attr| {
246            if !attr.path().is_ident("doc") {
247                return None;
248            }
249            if let Meta::NameValue(nv) = &attr.meta {
250                if let Expr::Lit(expr_lit) = &nv.value {
251                    if let Lit::Str(s) = &expr_lit.lit {
252                        return Some(s.value().trim().to_string());
253                    }
254                }
255            }
256            None
257        })
258        .collect();
259
260    if lines.is_empty() {
261        None
262    } else {
263        Some(lines.join("\n"))
264    }
265}
266
267// =============================================================================
268// Tests
269// =============================================================================
270
#[cfg(test)]
mod tests {
    use super::*;

    // ── Single field output ──────────────────────────────────────────────────

    #[test]
    fn test_single_field_output() {
        let src = r#"
            #[module(output = message)]
            fn call(input: &str) -> Result<String> {
                Ok(format!("hello {}", input))
            }
        "#;

        let v = generate_module_spec(src).unwrap().unwrap();

        assert_eq!(v.name, "call");
        assert!(v.description.is_none());

        // input: exactly one field called `input`
        let in_fields = &v.input.fields;
        assert_eq!(in_fields.len(), 1);
        assert_eq!(in_fields[0].name, "input");

        // output: exactly one field called `message`
        let out_fields = &v.output.fields;
        assert_eq!(out_fields.len(), 1);
        assert_eq!(out_fields[0].name, "message");
    }

    // ── Tuple field output ───────────────────────────────────────────────────

    #[test]
    fn test_tuple_output() {
        let src = r#"
            #[module(output = (score, label))]
            fn classify(text: String) -> Result<(f32, String)> {
                Ok((0.9, "positive".into()))
            }
        "#;

        let v = generate_module_spec(src).unwrap().unwrap();

        let in_fields = &v.input.fields;
        assert_eq!(in_fields.len(), 1);
        assert_eq!(in_fields[0].name, "text");

        let out_fields = &v.output.fields;
        assert_eq!(out_fields.len(), 2);
        assert_eq!(out_fields[0].name, "score");
        assert_eq!(out_fields[1].name, "label");
    }

    #[test]
    fn test_tuple_output_count_mismatch_is_error() {
        let src = r#"
            #[module(output = (score, label))]
            fn classify(text: String) -> Result<(f32, String, u32)> {
                todo!()
            }
        "#;

        assert!(generate_module_spec(src).is_err());
    }

    #[test]
    fn test_tuple_output_requires_tuple_return() {
        let src = r#"
            #[module(output = (score, label))]
            fn classify(text: String) -> Result<f32> {
                todo!()
            }
        "#;

        assert!(generate_module_spec(src).is_err());
    }

    // ── Return type validation ───────────────────────────────────────────────

    #[test]
    fn test_non_result_return_is_error() {
        let src = r#"
            #[module(output = message)]
            fn call(input: String) -> String {
                input
            }
        "#;

        assert!(generate_module_spec(src).is_err());
    }

    #[test]
    fn test_missing_return_type_is_error() {
        let src = r#"
            #[module(output = message)]
            fn call(input: String) {}
        "#;

        assert!(generate_module_spec(src).is_err());
    }

    // ── Struct output ────────────────────────────────────────────────────────

    #[test]
    fn test_struct_output() {
        let src = r#"
            #[config]
            struct Output {
                embedding: Vec<f32>,
                tokens: u32,
            }

            /// Embed a piece of text.
            #[module(output = Output)]
            fn embed(text: String, model: String) -> Result<Output> {
                todo!()
            }
        "#;

        let v = generate_module_spec(src).unwrap().unwrap();

        assert_eq!(v.name, "embed");
        assert_eq!(v.description.unwrap(), "Embed a piece of text.");

        let in_fields = &v.input.fields;
        assert_eq!(in_fields.len(), 2);
        assert_eq!(in_fields[0].name, "text");
        assert_eq!(in_fields[1].name, "model");

        let out_fields = &v.output.fields;
        assert_eq!(out_fields.len(), 2);
        assert_eq!(out_fields[0].name, "embedding");
        assert_eq!(out_fields[1].name, "tokens");
    }

    // ── Documentation ────────────────────────────────────────────────────────

    #[test]
    fn test_multiline_description() {
        let src = r#"
            /// Line one.
            /// Line two.
            #[module(output = message)]
            fn call(input: String) -> Result<String> {
                todo!()
            }
        "#;

        let v = generate_module_spec(src).unwrap().unwrap();
        assert_eq!(v.description.unwrap(), "Line one.\nLine two.");
    }

    // ── Function selection ───────────────────────────────────────────────────

    #[test]
    fn test_first_module_function_wins() {
        let src = r#"
            fn helper(x: u32) -> u32 { x }

            #[module(output = a)]
            fn first(x: u32) -> Result<u32> {
                todo!()
            }

            #[module(output = b)]
            fn second(y: u32) -> Result<u32> {
                todo!()
            }
        "#;

        let v = generate_module_spec(src).unwrap().unwrap();
        assert_eq!(v.name, "first");
        assert_eq!(v.output.fields[0].name, "a");
    }

    // ── No module function ───────────────────────────────────────────────────

    #[test]
    fn test_no_module_function() {
        let src = r#"
            fn plain(x: u32) -> u32 { x }
        "#;
        let result = generate_module_spec(src).unwrap();
        assert!(result.is_none());
    }
}