1use std::borrow::Cow;
25
26use pyro_spec::{ModuleFunc, ModuleKind, PyroField, PyroSchema};
27use syn::{Attribute, Expr, FnArg, ItemFn, Lit, Meta, Pat, ReturnType, Type};
28
29use crate::struct_doc::SchemaBuilder;
30
31use super::parse::{ModuleAttrs, OutputSpec};
32
33pub fn generate_module_spec(
42 content: &str,
43 dep_interfaces: &[pyro_spec::InterfaceSpec<'static>],
44) -> syn::Result<Option<ModuleFunc<'static>>> {
45 let file = syn::parse_file(content)?;
46 let builder = SchemaBuilder::from_file(&file).with_foreign_specs(dep_interfaces);
47
48 for item in &file.items {
49 if let syn::Item::Fn(item_fn) = item {
50 if !super::has_module_attr(&item_fn.attrs) {
51 continue;
52 }
53
54 let attr_tokens = super::extract_module_attr(&item_fn.attrs)?.ok_or_else(|| {
55 syn::Error::new_spanned(
56 item_fn,
57 "Module attribute requires arguments: #[module(output = ...)]",
58 )
59 })?;
60
61 let attrs: ModuleAttrs = syn::parse2(attr_tokens)?;
62 let spec = ModuleSpecBuilder::build(item_fn, &attrs, &builder)?;
63
64 return Ok(Some(spec));
65 }
66 }
67
68 Ok(None)
69}
70
71pub struct ModuleSpecBuilder;
76
77impl ModuleSpecBuilder {
78 pub fn build(
80 item_fn: &ItemFn,
81 attrs: &ModuleAttrs,
82 builder: &SchemaBuilder,
83 ) -> syn::Result<ModuleFunc<'static>> {
84 let name = item_fn.sig.ident.to_string();
85 let description = extract_doc_string(&item_fn.attrs);
86
87 let input_fields: Vec<PyroField<'static>> = item_fn
89 .sig
90 .inputs
91 .iter()
92 .filter_map(|arg| {
93 if let FnArg::Typed(pat_type) = arg
94 && let Pat::Ident(pat_ident) = &*pat_type.pat
95 {
96 let field_name = pat_ident.ident.to_string();
97 let ty = &*pat_type.ty;
98 let data_type = builder.resolve_type(ty);
99 let nullable = SchemaBuilder::is_option(ty);
100 let doc = extract_doc_string(&pat_type.attrs);
101 let mut field = PyroField::new(Cow::Owned(field_name), data_type, nullable);
102 if let Some(d) = doc {
103 field = field.add_docstring(Cow::Owned(d));
104 }
105 return Some(field);
106 }
107 None
108 })
109 .collect();
110
111 let input = PyroSchema::new(input_fields);
112
113 let ok_type = extract_result_ok_type(&item_fn.sig.output)?;
115 let ok_type = if attrs.session {
116 if let Type::Path(inner_path) = ok_type
117 && let Some(seg) = inner_path.path.segments.last()
118 && seg.ident == "SessionResponse"
119 && let syn::PathArguments::AngleBracketed(inner_args) = &seg.arguments
120 && let Some(syn::GenericArgument::Type(output_ty)) = inner_args.args.first()
121 {
122 output_ty
123 } else {
124 ok_type
125 }
126 } else {
127 ok_type
128 };
129 let output = build_output_schema(&attrs.output, ok_type, builder)?;
130
131 let kind = if attrs.session {
132 let num_inputs = item_fn.sig.inputs.len();
133 if num_inputs == 2 {
134 ModuleKind::Session
135 } else if num_inputs == 3 {
136 ModuleKind::SessionDiff
137 } else {
138 ModuleKind::Normal
139 }
140 } else {
141 ModuleKind::Normal
142 };
143
144 let func = ModuleFunc {
145 name: Cow::Owned(name),
146 description: description.map(Cow::Owned),
147 input,
148 output,
149 kind,
150 };
151
152 Ok(func)
153 }
154}
155
156fn build_output_schema(
163 spec: &OutputSpec,
164 ok_type: &Type,
165 builder: &SchemaBuilder,
166) -> syn::Result<PyroSchema<'static>> {
167 match spec {
168 OutputSpec::SingleField(field_name) => {
170 let data_type = builder.resolve_type(ok_type);
171 let nullable = SchemaBuilder::is_option(ok_type);
172 let field = PyroField::new(Cow::Owned(field_name.to_string()), data_type, nullable);
173 Ok(PyroSchema::new(vec![field]))
174 }
175
176 OutputSpec::TupleFields(field_names) => {
178 let tuple_types = extract_tuple_types(ok_type)?;
179
180 if tuple_types.len() != field_names.len() {
181 return Err(syn::Error::new_spanned(
182 ok_type,
183 format!(
184 "output field count ({}) does not match tuple element count ({})",
185 field_names.len(),
186 tuple_types.len()
187 ),
188 ));
189 }
190
191 let fields: Vec<PyroField<'static>> = field_names
192 .iter()
193 .zip(tuple_types.iter())
194 .map(|(name, ty)| {
195 let data_type = builder.resolve_type(ty);
196 let nullable = SchemaBuilder::is_option(ty);
197 PyroField::new(Cow::Owned(name.to_string()), data_type, nullable)
198 })
199 .collect();
200
201 Ok(PyroSchema::new(fields))
202 }
203
204 OutputSpec::Struct => {
206 let schema = match ok_type {
209 Type::Path(type_path) => {
210 if let Some(seg) = type_path.path.segments.last() {
211 builder.schema_for(&seg.ident.to_string())
212 } else {
213 None
214 }
215 }
216 _ => None,
217 };
218
219 Ok(schema.map(|s| s.into_owned()).unwrap_or_else(|| {
220 let data_type = builder.resolve_type(ok_type);
222 let nullable = SchemaBuilder::is_option(ok_type);
223 PyroSchema::new(vec![PyroField::new(
224 Cow::Borrowed("output"),
225 data_type,
226 nullable,
227 )])
228 }))
229 }
230 }
231}
232
233fn extract_result_ok_type(ret: &ReturnType) -> syn::Result<&Type> {
235 match ret {
236 ReturnType::Default => Err(syn::Error::new(
237 proc_macro2::Span::call_site(),
238 "module function must return Result<T>",
239 )),
240 ReturnType::Type(_, ty) => {
241 if let Type::Path(type_path) = &**ty
242 && let Some(seg) = type_path.path.segments.last()
243 && seg.ident == "Result"
244 && let syn::PathArguments::AngleBracketed(args) = &seg.arguments
245 && let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first()
246 {
247 return Ok(ok_ty);
248 }
249 Err(syn::Error::new_spanned(
250 &**ty,
251 "module function must return Result<T>",
252 ))
253 }
254 }
255}
256
257fn extract_tuple_types(ty: &Type) -> syn::Result<Vec<&Type>> {
259 if let Type::Tuple(tuple) = ty {
260 Ok(tuple.elems.iter().collect())
261 } else {
262 Err(syn::Error::new_spanned(
263 ty,
264 "expected tuple return type for multi-field output",
265 ))
266 }
267}
268
269fn extract_doc_string(attrs: &[Attribute]) -> Option<String> {
271 let lines: Vec<String> = attrs
272 .iter()
273 .filter_map(|attr| {
274 if !attr.path().is_ident("doc") {
275 return None;
276 }
277 if let Meta::NameValue(nv) = &attr.meta
278 && let Expr::Lit(expr_lit) = &nv.value
279 && let Lit::Str(s) = &expr_lit.lit
280 {
281 return Some(s.value().trim().to_string());
282 }
283 None
284 })
285 .collect();
286
287 if lines.is_empty() {
288 None
289 } else {
290 Some(lines.join("\n"))
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `generate_module_spec` on `src` and unwraps the spec that every
    /// caller expects to be present.
    fn spec_of(src: &str, deps: &[pyro_spec::InterfaceSpec<'static>]) -> ModuleFunc<'static> {
        generate_module_spec(src, deps)
            .expect("module source should parse")
            .expect("source should contain a #[module] function")
    }

    #[test]
    fn test_single_field_output() {
        let src = r#"
            #[module(output = message)]
            fn call(input: &str) -> Result<String> {
                Ok(format!("hello {}", input))
            }
        "#;

        let spec = spec_of(src, &[]);

        assert_eq!(spec.name, "call");
        assert!(spec.description.is_none());
        assert_eq!(spec.input.fields[0].name, "input");
        assert_eq!(spec.output.fields[0].name, "message");
    }

    #[test]
    fn test_tuple_output() {
        let src = r#"
            #[module(output = (score, label))]
            fn classify(text: String) -> Result<(f32, String)> {
                Ok((0.9, "positive".into()))
            }
        "#;

        let spec = spec_of(src, &[]);
        let outputs = &spec.output.fields;

        assert_eq!(outputs[0].name, "score");
        assert_eq!(outputs[1].name, "label");
    }

    #[test]
    fn test_struct_output() {
        let src = r#"
            #[config]
            struct Output {
                embedding: Vec<f32>,
                tokens: u32,
            }

            /// Embed a piece of text.
            #[module(output = Output)]
            fn embed(text: String, model: String) -> Result<Output> {
                todo!()
            }
        "#;

        let spec = spec_of(src, &[]);

        assert_eq!(spec.name, "embed");
        assert_eq!(spec.description.unwrap(), "Embed a piece of text.");

        let inputs = &spec.input.fields;
        assert_eq!(inputs.len(), 2);
        assert_eq!(inputs[0].name, "text");
        assert_eq!(inputs[1].name, "model");

        let outputs = &spec.output.fields;
        assert_eq!(outputs[0].name, "embedding");
        assert_eq!(outputs[1].name, "tokens");
    }

    #[test]
    fn test_session_foreign_struct() {
        use pyro_spec::{InterfaceSpec, PyroField, PyroSchema, PyroType};
        use std::collections::BTreeMap;

        let src = r#"
            #[module(session, output = ChatMessage)]
            fn process(
                prior: Vec<ChatMessage>,
                input: ChatMessageRef<'_>,
            ) -> Result<SessionResponse<ChatMessage>> {
                todo!()
            }
        "#;

        // `ChatMessage` is only known through the dependency interface.
        let chat_message = PyroSchema::new(vec![
            PyroField::new("role", PyroType::Str, false),
            PyroField::new("content", PyroType::Str, false),
        ]);
        let structs = BTreeMap::from([(Cow::Borrowed("ChatMessage"), chat_message)]);
        let llm = InterfaceSpec {
            capability: Cow::Borrowed("llm"),
            description: None,
            classes: vec![],
            structs,
        };

        let spec = spec_of(src, &[llm]);

        assert_eq!(spec.name, "process");
        assert_eq!(spec.kind, ModuleKind::Session);

        let inputs = &spec.input.fields;
        assert_eq!(inputs.len(), 2);
        assert_eq!(inputs[0].name, "prior");

        let PyroType::List(item_type, nullable) = &inputs[0].data_type else {
            panic!("Expected List type for prior field");
        };
        assert!(!nullable);
        let PyroType::Group(prior_fields) = item_type.as_ref() else {
            panic!("Expected Group inner type for prior list");
        };
        assert_eq!(prior_fields.len(), 2);
        assert_eq!(prior_fields[0].name, "role");
        assert_eq!(prior_fields[1].name, "content");

        assert_eq!(inputs[1].name, "input");
        let PyroType::Group(input_fields) = &inputs[1].data_type else {
            panic!("Expected Group type for input field");
        };
        assert_eq!(input_fields.len(), 2);
        assert_eq!(input_fields[0].name, "role");
        assert_eq!(input_fields[1].name, "content");

        let outputs = &spec.output.fields;
        assert_eq!(outputs.len(), 1);
        assert_eq!(outputs[0].name, "output");
        let PyroType::Group(output_fields) = &outputs[0].data_type else {
            panic!("Expected Group type for output field");
        };
        assert_eq!(output_fields.len(), 2);
        assert_eq!(output_fields[0].name, "role");
        assert_eq!(output_fields[1].name, "content");
    }

    #[test]
    fn test_no_module_function() {
        let src = r#"
            fn plain(x: u32) -> u32 { x }
        "#;

        let spec = generate_module_spec(src, &[]).unwrap();
        assert!(spec.is_none());
    }
}