pharia_skill_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    spanned::Spanned, AttrStyle, Expr, FnArg, GenericArgument, ItemFn, Lit, PathArguments,
5    ReturnType, Type,
6};
7
8fn report_error(msg: &str, span: proc_macro2::Span) -> TokenStream {
9    syn::Error::new(span, msg).to_compile_error().into()
10}
11
/// Diagnostic emitted when the annotated function's parameter list does not have the
/// expected `(csi, input)` shape.
const ARG_MSG: &str = concat!(
    "The skill function should take two arguments: ",
    "first is `csi: &impl Csi`, ",
    "second is `input` with a type that implements `serde::Deserialize` and `schemars::JsonSchema`."
);
/// Diagnostic emitted when the annotated function does not declare a return type.
const RETURN_MSG: &str = concat!(
    "The skill function should return a value ",
    "that implements `serde::Serialize` and `schemars::JsonSchema`."
);
14
15/// Macro to define a Skill. It wraps a function that takes a single argument and returns a single value.
16///
17/// The argument should implement `serde::Deserialize` to process an incoming JSON body, and the return value should implement `serde::Serialize` to return a JSON body.
18/// Both also need to implement `schemars::JsonSchema` to generate a JSON schema for the input and output.
19/// You can use the `#[derive(schemars::JsonSchema)]` attribute to automatically implement `JsonSchema` for your types.
20///
21/// Also, the doc comment can be used to provide a description of the skill.
22#[proc_macro_attribute]
23pub fn skill(_attr: TokenStream, item: TokenStream) -> TokenStream {
24    let func = syn::parse_macro_input!(item as syn::ItemFn);
25    let func_name = &func.sig.ident;
26    let description = extract_doc_comment(&func);
27
28    let Some(input_type) = func.sig.inputs.last() else {
29        return report_error(ARG_MSG, func.span());
30    };
31    let input_type = match input_type {
32        FnArg::Typed(pat_type) => &pat_type.ty,
33        FnArg::Receiver(_) => return report_error(ARG_MSG, func.span()),
34    };
35    let output_type = extract_output_result(match &func.sig.output {
36        ReturnType::Type(_, ty) => ty,
37        ReturnType::Default => return report_error(RETURN_MSG, func.span()),
38    });
39
40    quote!(
41        #func
42
43        static __SKILL_METADATA: std::sync::LazyLock<::pharia_skill::bindings::exports::pharia::skill::skill_handler::SkillMetadata> = std::sync::LazyLock::new(|| {
44            use ::pharia_skill::bindings::{exports::pharia::skill::skill_handler::SkillMetadata, json};
45            let input_schema = json::schema_for!(#input_type);
46            let output_schema = json::schema_for!(#output_type);
47            SkillMetadata {
48                description: (!#description.is_empty()).then_some(#description.to_string()),
49                input_schema: json::to_vec(&input_schema).expect("Failed to serialize input schema"),
50                output_schema: json::to_vec(&output_schema).expect("Failed to serialize output schema"),
51            }
52        });
53
54        mod __pharia_skill {
55            use ::pharia_skill::bindings::{
56                export,
57                exports::pharia::skill::skill_handler::{Error, Guest, SkillMetadata},
58                json, HandlerResult, WitCsi,
59            };
60
61            pub struct Skill;
62
63            impl Guest for Skill {
64                fn run(input: Vec<u8>) -> Result<Vec<u8>, Error> {
65                    let input = json::from_slice(&input)?;
66                    let output = super::#func_name(&WitCsi, input);
67                    HandlerResult::from(output).into()
68                }
69
70                fn metadata() -> SkillMetadata {
71                    super::__SKILL_METADATA.clone()
72                }
73            }
74
75            export!(Skill);
76        }
77
78    )
79    .into()
80}
81
82// Pull out the type from a Result type if the user used one.
83fn extract_output_result(output_type: &Type) -> &Type {
84    match output_type {
85        Type::Path(path) => path
86            .path
87            .segments
88            .last()
89            // Special case for Result types
90            .and_then(|segment| (segment.ident == "Result").then_some(&segment.arguments))
91            // We only want angle bracket generics
92            .and_then(|args| match args {
93                PathArguments::AngleBracketed(args) => Some(args),
94                PathArguments::Parenthesized(_) | PathArguments::None => None,
95            })
96            // Get the first one that is a type
97            .and_then(|args| {
98                args.args.iter().find_map(|arg| match arg {
99                    GenericArgument::Type(ty) => Some(ty),
100                    _ => None,
101                })
102            }),
103        _ => None,
104    }
105    .unwrap_or(output_type)
106}
107
108fn extract_doc_comment(func: &ItemFn) -> String {
109    func.attrs
110        .iter()
111        // Only grab attributes that are outer doc comments
112        .filter(|attr| matches!(attr.style, AttrStyle::Outer) && attr.path().is_ident("doc"))
113        // All doc comments should be NameValues
114        .filter_map(|attr| attr.meta.require_name_value().ok())
115        // Pull out the literal value of the line
116        .filter_map(|meta_name_value| match &meta_name_value.value {
117            Expr::Lit(lit) => Some(&lit.lit),
118            _ => None,
119        })
120        // Get the string, and trim the extra whitespace
121        .filter_map(|lit| {
122            if let Lit::Str(s) = lit {
123                Some(s.value().trim().to_owned())
124            } else {
125                None
126            }
127        })
128        .collect::<Vec<_>>()
129        .join("\n")
130}