// dars-macros 0.1.0
//
// Declarative Agents in Rust
// Documentation
use proc_macro2::{Span, TokenStream};
use quote::{ToTokens, format_ident, quote};
use syn::{
    Field, Ident, LitStr, Token, Type, Visibility, braced,
    parse::{Parse, ParseStream},
    spanned::Spanned,
};

use crate::util::parse_desc;

/// A field of the signature declaration that carries an `#[input(...)]` attribute.
struct InputField {
    /// Field identifier in its `Ident::to_string` form.
    name: String,
    /// The field's declared type.
    ty: Type,
    /// Optional description extracted from the `#[input(...)]` attribute by `parse_desc`.
    desc: Option<String>,
}

/// A field of the signature declaration that carries an `#[output(...)]` attribute.
struct OutputField {
    /// Field identifier in its `Ident::to_string` form.
    name: String,
    /// The field's declared type.
    ty: Type,
    /// Optional description extracted from the `#[output(...)]` attribute by `parse_desc`.
    desc: Option<String>,
}

/// Parsed form of a `<vis> struct <Name> { ... }` signature declaration.
///
/// The declaration's fields are split into inputs and outputs by their
/// `#[input(...)]` / `#[output(...)]` attributes. The `ToTokens` implementation
/// expands the signature into generated code.
pub struct Signature {
    /// Visibility of the declaration, reused on the generated items.
    vis: Visibility,
    /// Name of the declared struct; also the prefix of `<Name>Input` / `<Name>Output`.
    name: Ident,
    /// Task instruction. Parsing leaves it `None`; `with_instruction` sets it.
    instruction: Option<String>,
    /// Fields tagged `#[input(...)]`, in declaration order.
    inputs: Vec<InputField>,
    /// Fields tagged `#[output(...)]`, in declaration order.
    outputs: Vec<OutputField>,
}

impl Signature {
    /// Consumes the signature and returns it with its instruction replaced.
    ///
    /// Accepts either a `String` or an `Option<String>`. Passing `None` clears any
    /// instruction set earlier.
    pub(crate) fn with_instruction(mut self, instruction: impl Into<Option<String>>) -> Self {
        self.instruction = instruction.into();
        self
    }
}

impl Parse for Signature {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let vis = input.parse::<Visibility>()?;
        let _ = input.parse::<Token![struct]>()?;
        let name: Ident = input.parse()?;

        // Extract the content of the struct
        let content;
        braced!(content in input);

        // Parse input/output fields
        let fields = content.parse_terminated(Field::parse_named, Token![,])?;

        let mut inputs = Vec::new();
        let mut outputs = Vec::new();
        for field in fields {
            let name = field.ident.expect("Missing field name").to_string();

            if field.attrs.is_empty() {
                panic!("Missing input/output attribute on field {}", name);
            }

            for attr in field.attrs {
                if attr.path().is_ident("input") {
                    inputs.push(InputField {
                        name,
                        ty: field.ty,
                        desc: parse_desc(&attr)?,
                    });
                    break;
                }
                if attr.path().is_ident("output") {
                    outputs.push(OutputField {
                        name,
                        ty: field.ty,
                        desc: parse_desc(&attr)?,
                    });
                    break;
                }
                return Err(syn::Error::new(
                    attr.span(),
                    format!("Unknown attribute on field {name}"),
                ));
            }
        }

        Ok(Signature {
            vis,
            name,
            instruction: None,
            inputs,
            outputs,
        })
    }
}

impl ToTokens for Signature {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let instruction = self
            .instruction
            .as_ref()
            .map(|s| s.trim().to_string())
            .unwrap_or_default();

        let name = &self.name;
        let vis = &self.vis;

        // Input fields
        let input_struct = format_ident!("{}Input", self.name);
        let inputs = self.inputs.iter().map(|input| {
            let name = Ident::new(&input.name, Span::call_site());
            let ty = input.ty.clone();
            match &input.desc {
                Some(desc) => {
                    let desc = LitStr::new(desc, Span::call_site());
                    quote! {
                        #[field(desc = #desc)]
                        pub #name: #ty
                    }
                }
                None => {
                    quote! {
                        pub #name: #ty
                    }
                }
            }
        });

        // Output fields
        let output_struct = format_ident!("{}Output", self.name);
        let outputs = self.outputs.iter().map(|output| {
            let name = Ident::new(&output.name, Span::call_site());
            let ty = output.ty.clone();
            match &output.desc {
                Some(desc) => {
                    let desc = LitStr::new(desc, Span::call_site());
                    quote! {
                        #[field(desc = #desc)]
                        pub #name: #ty
                    }
                }
                None => {
                    quote! {
                        pub #name: #ty
                    }
                }
            }
        });

        let fields = self
            .inputs
            .iter()
            .map(|f| (&f.name, &f.ty))
            .chain(self.outputs.iter().map(|f| (&f.name, &f.ty)))
            .map(|(name, ty)| {
                let name = name.as_str();
                let ty = ty.clone();
                quote! {
                    (#name.to_string(), da_rs::schemars::schema_for!(#ty))
                }
            });

        let expanded = quote! {
            // Input model struct
            #[Model]
            #vis struct #input_struct {
                #(#inputs,)*
            }

            // Output model struct
            #[Model]
            #vis struct #output_struct {
                #(#outputs,)*
            }

            // Base signature struct
            #[derive(Debug)]
            #vis struct #name {
                instruction: String,
                fields: std::collections::HashMap<String, da_rs::schemars::Schema>,
            }

            impl #name {
                #vis fn new() -> Self {
                    Self {
                        instruction: #instruction.into(),
                        fields: std::collections::HashMap::from_iter([
                            #(#fields,)*
                        ]),
                    }
                }
            }

            impl da_rs::Signature for #name {
                type Input = #input_struct;
                type Output = #output_struct;

                #[inline(always)]
                fn instruction(&self) -> &str {
                    &self.instruction
                }

                #[inline]
                fn input_fields(&self) -> &[da_rs::Field] {
                    <#input_struct as da_rs::Model>::fields()
                }

                #[inline]
                fn output_fields(&self) -> &[da_rs::Field] {
                    <#output_struct as da_rs::Model>::fields()
                }

                #[inline]
                fn field(&self, name: &str) -> Option<&da_rs::schemars::Schema> {
                    self.fields.get(name)
                }
            }
        };
        tokens.extend(expanded);
    }
}