Skip to main content

pyro_macro/module/
spec.rs

//! Generates module interface specifications using the Pyro type system.
//!
//! Parses a module source file, locates the `#[module(...)]` function, and
//! builds a [`ModuleFunc`] describing its input parameters and output schema.
//! The result is meant to be serialised to JSON and written to `module.json`
//! alongside the compiled artefact; this file only builds the value.
//!
//! Input schema
//! ------------
//! Every parameter of the annotated function that is bound to a plain
//! identifier becomes a [`PyroField`] in the input [`PyroSchema`].  Type
//! resolution is delegated to [`SchemaBuilder`], so struct parameters expand
//! into `Group(fields)`.
//!
//! Output schema
//! -------------
//! The output schema is derived from the `Ok` type of the function's `Result`
//! return type (first unwrapped from `SessionResponse<T>` for `session`
//! modules) together with the `output = …` argument:
//!
//! | `output =`    | generated `__Output` struct                                 |
//! |---------------|-------------------------------------------------------------|
//! | `field`       | `{ field: <ReturnType> }`                                   |
//! | `(f1, f2, …)` | `{ f1: T1, f2: T2, … }` (from the tuple return type)        |
//! | `StructName`  | fields of the return type's struct, looked up in the file or in the dependency interfaces; `{ output: <ReturnType> }` when no such struct is known |

24use std::borrow::Cow;
25
26use pyro_spec::{ModuleFunc, ModuleKind, PyroField, PyroSchema};
27use syn::{Attribute, Expr, FnArg, ItemFn, Lit, Meta, Pat, ReturnType, Type};
28
29use crate::struct_doc::SchemaBuilder;
30
31use super::parse::{ModuleAttrs, OutputSpec};
32
33// =============================================================================
34// Public entry point
35// =============================================================================
36
37/// Parse `content` (a module source file), locate the `#[module(...)]`
38/// function, and return a pretty-printed JSON string describing it.
39///
40/// Returns `None` when no `#[module(...)]` function is found.
41pub fn generate_module_spec(
42    content: &str,
43    dep_interfaces: &[pyro_spec::InterfaceSpec<'static>],
44) -> syn::Result<Option<ModuleFunc<'static>>> {
45    let file = syn::parse_file(content)?;
46    let builder = SchemaBuilder::from_file(&file).with_foreign_specs(dep_interfaces);
47
48    for item in &file.items {
49        if let syn::Item::Fn(item_fn) = item {
50            if !super::has_module_attr(&item_fn.attrs) {
51                continue;
52            }
53
54            let attr_tokens = super::extract_module_attr(&item_fn.attrs)?.ok_or_else(|| {
55                syn::Error::new_spanned(
56                    item_fn,
57                    "Module attribute requires arguments: #[module(output = ...)]",
58                )
59            })?;
60
61            let attrs: ModuleAttrs = syn::parse2(attr_tokens)?;
62            let spec = ModuleSpecBuilder::build(item_fn, &attrs, &builder)?;
63
64            return Ok(Some(spec));
65        }
66    }
67
68    Ok(None)
69}
70
71// =============================================================================
72// Builder
73// =============================================================================
74
75pub struct ModuleSpecBuilder;
76
77impl ModuleSpecBuilder {
78    /// Build a [`ModuleFuncSpec`] from a parsed function and its `#[module(...)]` attrs.
79    pub fn build(
80        item_fn: &ItemFn,
81        attrs: &ModuleAttrs,
82        builder: &SchemaBuilder,
83    ) -> syn::Result<ModuleFunc<'static>> {
84        let name = item_fn.sig.ident.to_string();
85        let description = extract_doc_string(&item_fn.attrs);
86
87        // ── Input schema ─────────────────────────────────────────────────────
88        let input_fields: Vec<PyroField<'static>> = item_fn
89            .sig
90            .inputs
91            .iter()
92            .filter_map(|arg| {
93                if let FnArg::Typed(pat_type) = arg
94                    && let Pat::Ident(pat_ident) = &*pat_type.pat
95                {
96                    let field_name = pat_ident.ident.to_string();
97                    let ty = &*pat_type.ty;
98                    let data_type = builder.resolve_type(ty);
99                    let nullable = SchemaBuilder::is_option(ty);
100                    let doc = extract_doc_string(&pat_type.attrs);
101                    let mut field = PyroField::new(Cow::Owned(field_name), data_type, nullable);
102                    if let Some(d) = doc {
103                        field = field.add_docstring(Cow::Owned(d));
104                    }
105                    return Some(field);
106                }
107                None
108            })
109            .collect();
110
111        let input = PyroSchema::new(input_fields);
112
113        // ── Output schema ────────────────────────────────────────────────────
114        let ok_type = extract_result_ok_type(&item_fn.sig.output)?;
115        let ok_type = if attrs.session {
116            if let Type::Path(inner_path) = ok_type
117                && let Some(seg) = inner_path.path.segments.last()
118                && seg.ident == "SessionResponse"
119                && let syn::PathArguments::AngleBracketed(inner_args) = &seg.arguments
120                && let Some(syn::GenericArgument::Type(output_ty)) = inner_args.args.first()
121            {
122                output_ty
123            } else {
124                ok_type
125            }
126        } else {
127            ok_type
128        };
129        let output = build_output_schema(&attrs.output, ok_type, builder)?;
130
131        let kind = if attrs.session {
132            let num_inputs = item_fn.sig.inputs.len();
133            if num_inputs == 2 {
134                ModuleKind::Session
135            } else if num_inputs == 3 {
136                ModuleKind::SessionDiff
137            } else {
138                ModuleKind::Normal
139            }
140        } else {
141            ModuleKind::Normal
142        };
143
144        let func = ModuleFunc {
145            name: Cow::Owned(name),
146            description: description.map(Cow::Owned),
147            input,
148            output,
149            kind,
150        };
151
152        Ok(func)
153    }
154}
155
156// =============================================================================
157// Helpers
158// =============================================================================
159
160/// Build the output [`PyroSchema`] from the `output = …` spec and the
161/// function's `Ok` return type.
162fn build_output_schema(
163    spec: &OutputSpec,
164    ok_type: &Type,
165    builder: &SchemaBuilder,
166) -> syn::Result<PyroSchema<'static>> {
167    match spec {
168        // output = single_field  →  { single_field: <ok_type> }
169        OutputSpec::SingleField(field_name) => {
170            let data_type = builder.resolve_type(ok_type);
171            let nullable = SchemaBuilder::is_option(ok_type);
172            let field = PyroField::new(Cow::Owned(field_name.to_string()), data_type, nullable);
173            Ok(PyroSchema::new(vec![field]))
174        }
175
176        // output = (f1, f2, …)  →  one field per tuple element
177        OutputSpec::TupleFields(field_names) => {
178            let tuple_types = extract_tuple_types(ok_type)?;
179
180            if tuple_types.len() != field_names.len() {
181                return Err(syn::Error::new_spanned(
182                    ok_type,
183                    format!(
184                        "output field count ({}) does not match tuple element count ({})",
185                        field_names.len(),
186                        tuple_types.len()
187                    ),
188                ));
189            }
190
191            let fields: Vec<PyroField<'static>> = field_names
192                .iter()
193                .zip(tuple_types.iter())
194                .map(|(name, ty)| {
195                    let data_type = builder.resolve_type(ty);
196                    let nullable = SchemaBuilder::is_option(ty);
197                    PyroField::new(Cow::Owned(name.to_string()), data_type, nullable)
198                })
199                .collect();
200
201            Ok(PyroSchema::new(fields))
202        }
203
204        // output = StructName  →  look up struct in the file registry
205        OutputSpec::Struct => {
206            // The return type must be a simple path — use it to look up the
207            // schema from the builder registry.
208            let schema = match ok_type {
209                Type::Path(type_path) => {
210                    if let Some(seg) = type_path.path.segments.last() {
211                        builder.schema_for(&seg.ident.to_string())
212                    } else {
213                        None
214                    }
215                }
216                _ => None,
217            };
218
219            Ok(schema.map(|s| s.into_owned()).unwrap_or_else(|| {
220                // Fallback: resolve as a single anonymous field
221                let data_type = builder.resolve_type(ok_type);
222                let nullable = SchemaBuilder::is_option(ok_type);
223                PyroSchema::new(vec![PyroField::new(
224                    Cow::Borrowed("output"),
225                    data_type,
226                    nullable,
227                )])
228            }))
229        }
230    }
231}
232
233/// Extract the `Ok` type from `Result<T, _>` or `Result<T>`.
234fn extract_result_ok_type(ret: &ReturnType) -> syn::Result<&Type> {
235    match ret {
236        ReturnType::Default => Err(syn::Error::new(
237            proc_macro2::Span::call_site(),
238            "module function must return Result<T>",
239        )),
240        ReturnType::Type(_, ty) => {
241            if let Type::Path(type_path) = &**ty
242                && let Some(seg) = type_path.path.segments.last()
243                && seg.ident == "Result"
244                && let syn::PathArguments::AngleBracketed(args) = &seg.arguments
245                && let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first()
246            {
247                return Ok(ok_ty);
248            }
249            Err(syn::Error::new_spanned(
250                &**ty,
251                "module function must return Result<T>",
252            ))
253        }
254    }
255}
256
257/// Extract element types from a tuple type `(T1, T2, …)`.
258fn extract_tuple_types(ty: &Type) -> syn::Result<Vec<&Type>> {
259    if let Type::Tuple(tuple) = ty {
260        Ok(tuple.elems.iter().collect())
261    } else {
262        Err(syn::Error::new_spanned(
263            ty,
264            "expected tuple return type for multi-field output",
265        ))
266    }
267}
268
269/// Collect `/// doc` comments from a slice of attributes into a single string.
270fn extract_doc_string(attrs: &[Attribute]) -> Option<String> {
271    let lines: Vec<String> = attrs
272        .iter()
273        .filter_map(|attr| {
274            if !attr.path().is_ident("doc") {
275                return None;
276            }
277            if let Meta::NameValue(nv) = &attr.meta
278                && let Expr::Lit(expr_lit) = &nv.value
279                && let Lit::Str(s) = &expr_lit.lit
280            {
281                return Some(s.value().trim().to_string());
282            }
283            None
284        })
285        .collect();
286
287    if lines.is_empty() {
288        None
289    } else {
290        Some(lines.join("\n"))
291    }
292}
293
294// =============================================================================
295// Tests
296// =============================================================================
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Run the generator on `src`, require it to fail, and return the error message.
    fn expect_err(src: &str) -> String {
        match generate_module_spec(src, &[]) {
            Ok(_) => panic!("expected spec generation to fail for:\n{src}"),
            Err(err) => err.to_string(),
        }
    }

    // ── Single field output ──────────────────────────────────────────────────

    #[test]
    fn test_single_field_output() {
        let src = r#"
            #[module(output = message)]
            fn call(input: &str) -> Result<String> {
                Ok(format!("hello {}", input))
            }
        "#;

        let v = generate_module_spec(src, &[]).unwrap().unwrap();

        assert_eq!(v.name, "call");
        assert!(v.description.is_none());
        assert_eq!(v.kind, ModuleKind::Normal);

        // input: exactly one field, named after the parameter
        let in_fields = &v.input.fields;
        assert_eq!(in_fields.len(), 1);
        assert_eq!(in_fields[0].name, "input");

        // output: exactly one field, named by `output = message`
        let out_fields = &v.output.fields;
        assert_eq!(out_fields.len(), 1);
        assert_eq!(out_fields[0].name, "message");
    }

    // ── Tuple field output ───────────────────────────────────────────────────

    #[test]
    fn test_tuple_output() {
        let src = r#"
            #[module(output = (score, label))]
            fn classify(text: String) -> Result<(f32, String)> {
                Ok((0.9, "positive".into()))
            }
        "#;

        let v = generate_module_spec(src, &[]).unwrap().unwrap();

        let out_fields = &v.output.fields;
        assert_eq!(out_fields.len(), 2);
        assert_eq!(out_fields[0].name, "score");
        assert_eq!(out_fields[1].name, "label");
    }

    #[test]
    fn test_tuple_output_arity_mismatch() {
        let src = r#"
            #[module(output = (score, label))]
            fn classify(text: String) -> Result<(f32, String, u32)> {
                todo!()
            }
        "#;

        let msg = expect_err(src);
        assert!(
            msg.contains("does not match tuple element count"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn test_tuple_output_requires_tuple_type() {
        let src = r#"
            #[module(output = (score, label))]
            fn classify(text: String) -> Result<f32> {
                todo!()
            }
        "#;

        let msg = expect_err(src);
        assert!(
            msg.contains("expected tuple return type"),
            "unexpected error: {msg}"
        );
    }

    // ── Return type validation ───────────────────────────────────────────────

    #[test]
    fn test_non_result_return_is_rejected() {
        let src = r#"
            #[module(output = value)]
            fn double(x: u32) -> u32 {
                x * 2
            }
        "#;

        let msg = expect_err(src);
        assert!(msg.contains("must return Result<T>"), "unexpected error: {msg}");
    }

    #[test]
    fn test_missing_return_type_is_rejected() {
        let src = r#"
            #[module(output = value)]
            fn noop(x: u32) {}
        "#;

        let msg = expect_err(src);
        assert!(msg.contains("must return Result<T>"), "unexpected error: {msg}");
    }

    // ── Struct output ────────────────────────────────────────────────────────

    #[test]
    fn test_struct_output() {
        let src = r#"
            #[config]
            struct Output {
                embedding: Vec<f32>,
                tokens: u32,
            }

            /// Embed a piece of text.
            #[module(output = Output)]
            fn embed(text: String, model: String) -> Result<Output> {
                todo!()
            }
        "#;

        let v = generate_module_spec(src, &[]).unwrap().unwrap();

        assert_eq!(v.name, "embed");
        assert_eq!(v.description.unwrap(), "Embed a piece of text.");

        // input: text and model
        let in_fields = &v.input.fields;
        assert_eq!(in_fields.len(), 2);
        assert_eq!(in_fields[0].name, "text");
        assert_eq!(in_fields[1].name, "model");

        let out_fields = &v.output.fields;
        assert_eq!(out_fields.len(), 2);
        assert_eq!(out_fields[0].name, "embedding");
        assert_eq!(out_fields[1].name, "tokens");
    }

    // ── Session module with foreign struct output ────────────────────────────

    #[test]
    fn test_session_foreign_struct() {
        use std::collections::BTreeMap;

        use pyro_spec::{InterfaceSpec, PyroType};

        let src = r#"
            #[module(session, output = ChatMessage)]
            fn process(
                prior: Vec<ChatMessage>,
                input: ChatMessageRef<'_>,
            ) -> Result<SessionResponse<ChatMessage>> {
                todo!()
            }
        "#;

        // Mock dependency interface that declares the ChatMessage struct.
        let mut structs = BTreeMap::new();
        structs.insert(
            Cow::Borrowed("ChatMessage"),
            PyroSchema::new(vec![
                PyroField::new("role", PyroType::Str, false),
                PyroField::new("content", PyroType::Str, false),
            ]),
        );

        let dep = InterfaceSpec {
            capability: Cow::Borrowed("llm"),
            description: None,
            classes: vec![],
            structs,
        };

        let v = generate_module_spec(src, &[dep]).unwrap().unwrap();

        assert_eq!(v.name, "process");
        assert_eq!(v.kind, ModuleKind::Session);

        // Check input fields
        let in_fields = &v.input.fields;
        assert_eq!(in_fields.len(), 2);
        assert_eq!(in_fields[0].name, "prior");

        // prior is Vec<ChatMessage>, which resolves to List(Group([role, content]), false)
        if let PyroType::List(inner, nullable) = &in_fields[0].data_type {
            assert!(!nullable);
            if let PyroType::Group(fields) = inner.as_ref() {
                assert_eq!(fields.len(), 2);
                assert_eq!(fields[0].name, "role");
                assert_eq!(fields[1].name, "content");
            } else {
                panic!("Expected Group inner type for prior list");
            }
        } else {
            panic!("Expected List type for prior field");
        }

        // input is ChatMessageRef, which resolves to Group([role, content])
        assert_eq!(in_fields[1].name, "input");
        if let PyroType::Group(fields) = &in_fields[1].data_type {
            assert_eq!(fields.len(), 2);
            assert_eq!(fields[0].name, "role");
            assert_eq!(fields[1].name, "content");
        } else {
            panic!("Expected Group type for input field");
        }

        // Output (output = ChatMessage): the Ok type SessionResponse<ChatMessage>
        // is unwrapped to ChatMessage. ChatMessage is not a struct in this file,
        // so the schema falls back to a single field named "output".
        let out_fields = &v.output.fields;
        assert_eq!(out_fields.len(), 1);
        assert_eq!(out_fields[0].name, "output");
        if let PyroType::Group(fields) = &out_fields[0].data_type {
            assert_eq!(fields.len(), 2);
            assert_eq!(fields[0].name, "role");
            assert_eq!(fields[1].name, "content");
        } else {
            panic!("Expected Group type for output field");
        }
    }

    // ── Session module with three parameters ─────────────────────────────────

    #[test]
    fn test_session_diff_kind() {
        let src = r#"
            #[module(session, output = reply)]
            fn respond(
                prior: Vec<String>,
                input: String,
                state: String,
            ) -> Result<SessionResponse<String>> {
                todo!()
            }
        "#;

        let v = generate_module_spec(src, &[]).unwrap().unwrap();

        assert_eq!(v.kind, ModuleKind::SessionDiff);
        assert_eq!(v.input.fields.len(), 3);

        // SessionResponse<String> is unwrapped; the single output field is `reply`.
        let out_fields = &v.output.fields;
        assert_eq!(out_fields.len(), 1);
        assert_eq!(out_fields[0].name, "reply");
    }

    // ── Function selection ───────────────────────────────────────────────────

    #[test]
    fn test_first_module_function_wins() {
        let src = r#"
            #[module(output = a)]
            fn first(x: u32) -> Result<u32> { Ok(x) }

            #[module(output = b)]
            fn second(x: u32) -> Result<u32> { Ok(x) }
        "#;

        let v = generate_module_spec(src, &[]).unwrap().unwrap();
        assert_eq!(v.name, "first");
    }

    // ── No module function ───────────────────────────────────────────────────

    #[test]
    fn test_no_module_function() {
        let src = r#"
            fn plain(x: u32) -> u32 { x }
        "#;
        let result = generate_module_spec(src, &[]).unwrap();
        assert!(result.is_none());
    }
}