use std::borrow::Cow;
use pyro_spec::{ModuleFunc, ModuleKind, PyroField, PyroSchema};
use syn::{Attribute, Expr, FnArg, ItemFn, Lit, Meta, Pat, ReturnType, Type};
use crate::struct_doc::SchemaBuilder;
use super::parse::{ModuleAttrs, OutputSpec};
/// Parses `content` as a Rust source file and builds the spec of its module function.
///
/// The module function is the first top-level free function for which
/// `super::has_module_attr` reports a module attribute. Any later module functions
/// are not inspected. Struct schemas are resolved against the structs of `content`
/// plus those exported by `dep_interfaces`.
///
/// Returns `Ok(None)` when no function carries the attribute.
///
/// # Errors
///
/// Fails if `content` does not parse, if the attribute has no arguments or its
/// arguments do not parse as [`ModuleAttrs`], or if [`ModuleSpecBuilder::build`] fails.
pub fn generate_module_spec(
    content: &str,
    dep_interfaces: &[pyro_spec::InterfaceSpec<'static>],
) -> syn::Result<Option<ModuleFunc<'static>>> {
    let file = syn::parse_file(content)?;
    let builder = SchemaBuilder::from_file(&file).with_foreign_specs(dep_interfaces);

    let module_fn = file.items.iter().find_map(|item| match item {
        syn::Item::Fn(candidate) if super::has_module_attr(&candidate.attrs) => Some(candidate),
        _ => None,
    });
    let Some(module_fn) = module_fn else {
        return Ok(None);
    };

    let attr_tokens = super::extract_module_attr(&module_fn.attrs)?.ok_or_else(|| {
        syn::Error::new_spanned(
            module_fn,
            "Module attribute requires arguments: #[module(output = ...)]",
        )
    })?;
    let module_attrs: ModuleAttrs = syn::parse2(attr_tokens)?;
    ModuleSpecBuilder::build(module_fn, &module_attrs, &builder).map(Some)
}
pub struct ModuleSpecBuilder;
impl ModuleSpecBuilder {
pub fn build(
item_fn: &ItemFn,
attrs: &ModuleAttrs,
builder: &SchemaBuilder,
) -> syn::Result<ModuleFunc<'static>> {
let name = item_fn.sig.ident.to_string();
let description = extract_doc_string(&item_fn.attrs);
let input_fields: Vec<PyroField<'static>> = item_fn
.sig
.inputs
.iter()
.filter_map(|arg| {
if let FnArg::Typed(pat_type) = arg
&& let Pat::Ident(pat_ident) = &*pat_type.pat
{
let field_name = pat_ident.ident.to_string();
let ty = &*pat_type.ty;
let data_type = builder.resolve_type(ty);
let nullable = SchemaBuilder::is_option(ty);
let doc = extract_doc_string(&pat_type.attrs);
let mut field = PyroField::new(Cow::Owned(field_name), data_type, nullable);
if let Some(d) = doc {
field = field.add_docstring(Cow::Owned(d));
}
return Some(field);
}
None
})
.collect();
let input = PyroSchema::new(input_fields);
let ok_type = extract_result_ok_type(&item_fn.sig.output)?;
let ok_type = if attrs.session {
if let Type::Path(inner_path) = ok_type
&& let Some(seg) = inner_path.path.segments.last()
&& seg.ident == "SessionResponse"
&& let syn::PathArguments::AngleBracketed(inner_args) = &seg.arguments
&& let Some(syn::GenericArgument::Type(output_ty)) = inner_args.args.first()
{
output_ty
} else {
ok_type
}
} else {
ok_type
};
let output = build_output_schema(&attrs.output, ok_type, builder)?;
let kind = if attrs.session {
let num_inputs = item_fn.sig.inputs.len();
if num_inputs == 2 {
ModuleKind::Session
} else if num_inputs == 3 {
ModuleKind::SessionDiff
} else {
ModuleKind::Normal
}
} else {
ModuleKind::Normal
};
let func = ModuleFunc {
name: Cow::Owned(name),
description: description.map(Cow::Owned),
input,
output,
kind,
};
Ok(func)
}
}
fn build_output_schema(
spec: &OutputSpec,
ok_type: &Type,
builder: &SchemaBuilder,
) -> syn::Result<PyroSchema<'static>> {
match spec {
OutputSpec::SingleField(field_name) => {
let data_type = builder.resolve_type(ok_type);
let nullable = SchemaBuilder::is_option(ok_type);
let field = PyroField::new(Cow::Owned(field_name.to_string()), data_type, nullable);
Ok(PyroSchema::new(vec![field]))
}
OutputSpec::TupleFields(field_names) => {
let tuple_types = extract_tuple_types(ok_type)?;
if tuple_types.len() != field_names.len() {
return Err(syn::Error::new_spanned(
ok_type,
format!(
"output field count ({}) does not match tuple element count ({})",
field_names.len(),
tuple_types.len()
),
));
}
let fields: Vec<PyroField<'static>> = field_names
.iter()
.zip(tuple_types.iter())
.map(|(name, ty)| {
let data_type = builder.resolve_type(ty);
let nullable = SchemaBuilder::is_option(ty);
PyroField::new(Cow::Owned(name.to_string()), data_type, nullable)
})
.collect();
Ok(PyroSchema::new(fields))
}
OutputSpec::Struct => {
let schema = match ok_type {
Type::Path(type_path) => {
if let Some(seg) = type_path.path.segments.last() {
builder.schema_for(&seg.ident.to_string())
} else {
None
}
}
_ => None,
};
Ok(schema.map(|s| s.into_owned()).unwrap_or_else(|| {
let data_type = builder.resolve_type(ok_type);
let nullable = SchemaBuilder::is_option(ok_type);
PyroSchema::new(vec![PyroField::new(
Cow::Borrowed("output"),
data_type,
nullable,
)])
}))
}
}
}
fn extract_result_ok_type(ret: &ReturnType) -> syn::Result<&Type> {
match ret {
ReturnType::Default => Err(syn::Error::new(
proc_macro2::Span::call_site(),
"module function must return Result<T>",
)),
ReturnType::Type(_, ty) => {
if let Type::Path(type_path) = &**ty
&& let Some(seg) = type_path.path.segments.last()
&& seg.ident == "Result"
&& let syn::PathArguments::AngleBracketed(args) = &seg.arguments
&& let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first()
{
return Ok(ok_ty);
}
Err(syn::Error::new_spanned(
&**ty,
"module function must return Result<T>",
))
}
}
}
fn extract_tuple_types(ty: &Type) -> syn::Result<Vec<&Type>> {
if let Type::Tuple(tuple) = ty {
Ok(tuple.elems.iter().collect())
} else {
Err(syn::Error::new_spanned(
ty,
"expected tuple return type for multi-field output",
))
}
}
/// Collects the doc comments in `attrs` into a single description string.
///
/// Sources are `///` lines, `/** */` blocks, or explicit `#[doc = "..."]` attributes;
/// other `doc` forms such as `#[doc(hidden)]` are ignored.
///
/// The text is normalised the way rustdoc does it. The indentation common to all
/// non-blank lines (normally the single space after `///`) is removed, so deeper
/// markdown indentation such as indented code blocks and nested lists survives. The
/// previous per-line `trim()` stripped that indentation. Trailing whitespace and
/// leading/trailing blank lines are dropped.
///
/// Returns `None` when there are no doc attributes or they contain only whitespace.
fn extract_doc_string(attrs: &[Attribute]) -> Option<String> {
    let fragments: Vec<String> = attrs
        .iter()
        .filter(|attr| attr.path().is_ident("doc"))
        .filter_map(|attr| {
            if let Meta::NameValue(nv) = &attr.meta
                && let Expr::Lit(expr_lit) = &nv.value
                && let Lit::Str(s) = &expr_lit.lit
            {
                Some(s.value())
            } else {
                None
            }
        })
        .collect();
    normalize_doc_fragments(&fragments)
}

/// Joins raw doc-attribute values into one string, stripping their common indentation.
fn normalize_doc_fragments(fragments: &[String]) -> Option<String> {
    // A `/** */` block arrives as one fragment spanning several lines, so split on '\n'.
    // `str::lines` is not used because it yields nothing for an empty `///` line,
    // which would lose paragraph breaks.
    let lines: Vec<&str> = fragments
        .iter()
        .flat_map(|fragment| fragment.split('\n'))
        .map(str::trim_end)
        .collect();
    let first = lines.iter().position(|line| !line.is_empty())?;
    let last = lines.iter().rposition(|line| !line.is_empty())?;
    let lines = &lines[first..=last];

    // Only ASCII spaces/tabs count as indentation. Every non-blank line starts with at
    // least `indent` such single-byte characters, so slicing there is always on a
    // char boundary.
    let indent = lines
        .iter()
        .filter(|line| !line.is_empty())
        .map(|line| line.len() - line.trim_start_matches([' ', '\t']).len())
        .min()
        .unwrap_or(0);
    let body: Vec<&str> = lines
        .iter()
        .map(|line| line.get(indent..).unwrap_or(""))
        .collect();
    Some(body.join("\n"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the `name` of every field in `$fields` (anything with `.iter()` over
    /// `PyroField`s), so whole field lists can be compared in one assertion.
    macro_rules! names {
        ($fields:expr) => {
            $fields.iter().map(|f| &*f.name).collect::<Vec<&str>>()
        };
    }

    #[test]
    fn test_single_field_output() {
        let src = r#"
            #[module(output = message)]
            fn call(input: &str) -> Result<String> {
                Ok(format!("hello {}", input))
            }
        "#;
        let v = generate_module_spec(src, &[]).unwrap().unwrap();
        assert_eq!(v.name, "call");
        assert!(v.description.is_none());
        assert_eq!(v.kind, pyro_spec::ModuleKind::Normal);
        assert_eq!(names!(v.input.fields), ["input"]);
        assert_eq!(names!(v.output.fields), ["message"]);
    }

    #[test]
    fn test_tuple_output() {
        let src = r#"
            #[module(output = (score, label))]
            fn classify(text: String) -> Result<(f32, String)> {
                Ok((0.9, "positive".into()))
            }
        "#;
        let v = generate_module_spec(src, &[]).unwrap().unwrap();
        assert_eq!(v.kind, pyro_spec::ModuleKind::Normal);
        assert_eq!(names!(v.input.fields), ["text"]);
        assert_eq!(names!(v.output.fields), ["score", "label"]);
    }

    #[test]
    fn test_tuple_output_count_mismatch() {
        let src = r#"
            #[module(output = (a, b, c))]
            fn split(text: String) -> Result<(u32, u32)> {
                todo!()
            }
        "#;
        let err = generate_module_spec(src, &[])
            .err()
            .expect("a tuple with fewer elements than output names must be rejected");
        assert!(err.to_string().contains("output field count (3)"));
    }

    #[test]
    fn test_non_result_return_rejected() {
        let src = r#"
            #[module(output = value)]
            fn double(x: u32) -> u32 {
                x * 2
            }
        "#;
        let err = generate_module_spec(src, &[])
            .err()
            .expect("a module function that does not return Result must be rejected");
        assert_eq!(err.to_string(), "module function must return Result<T>");
    }

    #[test]
    fn test_struct_output() {
        let src = r#"
            #[config]
            struct Output {
                embedding: Vec<f32>,
                tokens: u32,
            }

            /// Embed a piece of text.
            #[module(output = Output)]
            fn embed(text: String, model: String) -> Result<Output> {
                todo!()
            }
        "#;
        let v = generate_module_spec(src, &[]).unwrap().unwrap();
        assert_eq!(v.name, "embed");
        assert_eq!(v.description.as_deref(), Some("Embed a piece of text."));
        assert_eq!(names!(v.input.fields), ["text", "model"]);
        assert_eq!(names!(v.output.fields), ["embedding", "tokens"]);
    }

    #[test]
    fn test_session_foreign_struct() {
        use std::collections::BTreeMap;
        use pyro_spec::{InterfaceSpec, PyroField, PyroSchema, PyroType};

        let src = r#"
            #[module(session, output = ChatMessage)]
            fn process(
                prior: Vec<ChatMessage>,
                input: ChatMessageRef<'_>,
            ) -> Result<SessionResponse<ChatMessage>> {
                todo!()
            }
        "#;
        let mut structs = BTreeMap::new();
        structs.insert(
            Cow::Borrowed("ChatMessage"),
            PyroSchema::new(vec![
                PyroField::new("role", PyroType::Str, false),
                PyroField::new("content", PyroType::Str, false),
            ]),
        );
        let dep = InterfaceSpec {
            capability: Cow::Borrowed("llm"),
            description: None,
            classes: vec![],
            structs,
        };

        let v = generate_module_spec(src, &[dep]).unwrap().unwrap();
        assert_eq!(v.name, "process");
        assert_eq!(v.kind, pyro_spec::ModuleKind::Session);

        // Both parameters resolve `ChatMessage` through the foreign interface spec.
        let in_fields = &v.input.fields;
        assert_eq!(names!(in_fields), ["prior", "input"]);

        let PyroType::List(inner, nullable) = &in_fields[0].data_type else {
            panic!("Expected List type for prior field");
        };
        assert!(!nullable);
        let PyroType::Group(prior_fields) = inner.as_ref() else {
            panic!("Expected Group inner type for prior list");
        };
        assert_eq!(names!(prior_fields), ["role", "content"]);

        let PyroType::Group(input_fields) = &in_fields[1].data_type else {
            panic!("Expected Group type for input field");
        };
        assert_eq!(names!(input_fields), ["role", "content"]);

        // `SessionResponse<ChatMessage>` is unwrapped. Because `ChatMessage` is foreign,
        // it lands in the single fallback `output` field.
        let out_fields = &v.output.fields;
        assert_eq!(names!(out_fields), ["output"]);
        let PyroType::Group(output_fields) = &out_fields[0].data_type else {
            panic!("Expected Group type for output field");
        };
        assert_eq!(names!(output_fields), ["role", "content"]);
    }

    #[test]
    fn test_no_module_function() {
        let src = r#"
            fn plain(x: u32) -> u32 { x }
        "#;
        let result = generate_module_spec(src, &[]).unwrap();
        assert!(result.is_none());
    }
}