Skip to main content

pyro_macro/ffi/
paths.rs

1//! Path and naming utilities for capability FFI generation
2//!
3//! This module centralizes all naming conventions used throughout the capability system
4//! to ensure consistency between client and server sides.
5
6use std::{ops::Deref, slice::Iter};
7
8use heck::{AsSnakeCase, AsUpperCamelCase};
9use proc_macro2::TokenStream;
10use quote::{format_ident, quote};
11use syn::{
12    Error, GenericArgument, Ident, PathArguments, ReturnType, Type, parse_quote, token::RArrow,
13};
14
15/// Identity of the capability (State, Client, Error)
16#[derive(Debug, Clone, PartialEq, Eq, Hash)]
17pub struct CapabilityIdent {
18    pub pkg_name: String,
19    pub pkg_version: String,
20    /// The struct being implemented (e.g., "MyStruct")
21    pub state_tn: Ident,
22    /// The client type identifier (e.g., "MyClient")
23    pub client_tn: Ident,
24    /// The config type identifier (e.g., "MyConfig")
25    pub config_tn: Option<Ident>,
26}
27
28impl CapabilityIdent {
29    // ========================================================================
30    // Method Paths
31    // ========================================================================
32
33    /// Library identifier for a method (e.g., __my_trait__my_state__method_name)
34    pub fn cap_id(&self) -> String {
35        self.pkg_name.to_string()
36    }
37
38    /// Library identifier for a method (e.g., __my_trait__my_state__method_name)
39    pub fn trace_name(&self, name: &FnName) -> Ident {
40        let state_snake = AsSnakeCase(self.state_tn.to_string()).to_string();
41        let snake = AsSnakeCase(name.0.to_string()).to_string();
42        format_ident!("p__{}__{}", state_snake, snake)
43    }
44
45    /// Library identifier for a method (e.g., __my_trait__my_state__method_name)
46    pub fn class_name(&self) -> String {
47        AsSnakeCase(self.state_tn.to_string()).to_string()
48    }
49
50    /// Library identifier for a method (e.g., __my_trait__my_state__method_name)
51    pub fn class_name_static(&self) -> Ident {
52        let state_snake = AsSnakeCase(self.state_tn.to_string())
53            .to_string()
54            .to_uppercase();
55        format_ident!("p__{}", state_snake)
56    }
57
58    /// Library identifier for a method (e.g., __my_trait__my_state__method_name)
59    pub fn trace_name_static(&self, name: &FnName) -> Ident {
60        let state_snake = AsSnakeCase(self.state_tn.to_string())
61            .to_string()
62            .to_uppercase();
63        let snake = AsSnakeCase(name.0.to_string()).to_string().to_uppercase();
64        format_ident!("p__{}__{}", state_snake, snake)
65    }
66
67    /// FFI function name for a method (e.g., __my_trait__my_state__name__ffi)
68    pub fn ffi_name(&self, name: &FnName) -> Ident {
69        let state_snake = AsSnakeCase(self.state_tn.to_string()).to_string();
70        let snake = AsSnakeCase(name.0.to_string()).to_string();
71        format_ident!("p__{}__{}__ffi", state_snake, snake)
72    }
73
74    /// WASM import name for a method (e.g., __my_trait__my_state__name__wasm)
75    pub fn wasm_name(&self, name: &FnName) -> Ident {
76        let state_snake = AsSnakeCase(self.state_tn.to_string()).to_string();
77        let snake = AsSnakeCase(name.0.to_string()).to_string();
78        format_ident!("p__{}__{}__wasm", state_snake, snake)
79    }
80
81    /// Input struct name for a method with multiple parameters
82    pub fn input_struct(&self, name: &FnName) -> Ident {
83        let state_snake = AsUpperCamelCase(self.state_tn.to_string()).to_string();
84        let snake = AsUpperCamelCase(name.0.to_string()).to_string();
85        format_ident!("p__{}__{}__Input", state_snake, snake)
86    }
87}
88
89#[derive(Debug, Clone)]
90pub struct FnName(pub Ident);
91
92impl FnName {
93    pub fn trace_name(&self) -> Ident {
94        format_ident!("p__{}", AsSnakeCase(self.0.to_string()).to_string())
95    }
96
97    pub fn trace_name_static(&self) -> Ident {
98        format_ident!(
99            "p__{}",
100            AsSnakeCase(self.0.to_string()).to_string().to_uppercase()
101        )
102    }
103
104    /// Get the FFI function name
105    pub fn fn_ffi_name(&self) -> Ident {
106        format_ident!("p__{}__ffi", AsSnakeCase(self.0.to_string()).to_string())
107    }
108
109    /// Get the WASM import name
110    pub fn fn_wasm_name(&self) -> Ident {
111        format_ident!("p__{}__wasm", AsSnakeCase(self.0.to_string()).to_string())
112    }
113
114    /// Get the input struct name (if multiple parameters)
115    pub fn input_struct_name(&self) -> Ident {
116        format_ident!(
117            "p__{}__Input",
118            AsUpperCamelCase(self.0.to_string()).to_string()
119        )
120    }
121}
122
123impl Deref for FnName {
124    type Target = Ident;
125
126    fn deref(&self) -> &Self::Target {
127        &self.0
128    }
129}
130
131#[derive(Debug, Clone, PartialEq, Eq)]
132pub enum InputParams {
133    None,
134    One(Ident, Box<Type>),
135    Many(Vec<(Ident, Type)>),
136}
137
138pub enum InputParamsIter<'a> {
139    None,
140    One(Option<(&'a Ident, &'a Type)>),
141    Many(Iter<'a, (Ident, Type)>),
142}
143
144impl<'a> Iterator for InputParamsIter<'a> {
145    type Item = (&'a Ident, &'a Type);
146
147    fn next(&mut self) -> Option<Self::Item> {
148        match self {
149            InputParamsIter::None => None,
150            InputParamsIter::One(t) => t.take(),
151            InputParamsIter::Many(params) => params.next().map(|(i, t)| (i, t)),
152        }
153    }
154}
155
156impl InputParams {
157    pub fn is_empty(&self) -> bool {
158        match self {
159            InputParams::None => true,
160            InputParams::One(_, _) => false,
161            InputParams::Many(_) => false,
162        }
163    }
164
165    pub fn iter(&self) -> InputParamsIter<'_> {
166        match self {
167            InputParams::None => InputParamsIter::None,
168            InputParams::One(i, t) => InputParamsIter::One(Some((i, t))),
169            InputParams::Many(params) => InputParamsIter::Many(params.iter()),
170        }
171    }
172
173    pub fn input_type(&self, fn_name: &FnName, class: Option<&CapabilityIdent>) -> TokenStream {
174        match &self {
175            InputParams::Many(_) => {
176                let input_struct_name = class
177                    .map(|c| c.input_struct(fn_name))
178                    .unwrap_or(fn_name.input_struct_name());
179                quote!(#input_struct_name)
180            }
181            InputParams::One(_, param_ty) => quote!(#param_ty),
182            InputParams::None => quote!(()),
183        }
184    }
185
186    pub fn input_serialization(
187        &self,
188        fn_name: &FnName,
189        class: Option<&CapabilityIdent>,
190    ) -> TokenStream {
191        match &self {
192            InputParams::Many(params) => {
193                let input_struct_name = class
194                    .map(|c| c.input_struct(fn_name))
195                    .unwrap_or(fn_name.input_struct_name());
196                let args = params.iter().map(|(n, _)| quote!(#n));
197                quote!(Some(&#input_struct_name { #(#args),* }))
198            }
199            InputParams::One(param_name, _) => quote!(Some(&#param_name)),
200            InputParams::None => quote!(None),
201        }
202    }
203
204    pub fn input_args(&self) -> Vec<TokenStream> {
205        match &self {
206            InputParams::Many(params) => params.iter().map(|(n, _)| quote!(input.#n)).collect(),
207            InputParams::One(..) => vec![quote!(input)],
208            InputParams::None => Vec::new(),
209        }
210    }
211
212    pub fn input_struct(&self, fn_name: &FnName, class: Option<&CapabilityIdent>) -> TokenStream {
213        match &self {
214            InputParams::Many(params) => {
215                let input_struct_name = class
216                    .map(|c| c.input_struct(fn_name))
217                    .unwrap_or(fn_name.input_struct_name());
218                let fields: Vec<_> = params.iter().map(|(n, t)| quote! { pub #n: #t }).collect();
219                quote! {
220                    #[::pyroduct::magma]
221                    struct #input_struct_name {
222                        #(#fields),*
223                    }
224                }
225            }
226            InputParams::One(_, _) => quote! {},
227            InputParams::None => quote! {},
228        }
229    }
230}
231
232pub fn is_captured_error(ty: &Type) -> bool {
233    let ty_str = quote!(#ty).to_string().replace(" ", "");
234    ty_str == "CapturedError" || ty_str == "pyroduct::CapturedError" || ty_str == "::pyroduct::CapturedError"
235}
236
237pub fn verify_result_return_type(ret: &ReturnType) -> syn::Result<(Type, Type)> {
238    match ret {
239        ReturnType::Type(_, ty) => {
240            let ty = ty.as_ref();
241            if let Type::Path(type_path) = ty {
242                if let Some(segment) = type_path.path.segments.last()
243                    && segment.ident == "Result"
244                    && let PathArguments::AngleBracketed(args) = &segment.arguments
245                {
246                    if args.args.len() == 2 {
247                        let mut iter = args.args.iter();
248                        if let (
249                            Some(GenericArgument::Type(t)),
250                            Some(GenericArgument::Type(e)),
251                        ) = (iter.next(), iter.next())
252                        {
253                            if !is_captured_error(e) {
254                                let actual_err_str = quote!(#e).to_string().replace(" ", "");
255                                return Err(Error::new_spanned(
256                                    e,
257                                    format!(
258                                        "Invalid error type. Expected 'CapturedError', found '{}'",
259                                        actual_err_str
260                                    ),
261                                ));
262                            }
263                            let err_ty: Type = parse_quote!(::pyroduct::CapturedError);
264                            return Ok((t.clone(), err_ty));
265                        }
266                    } else if args.args.len() == 1 {
267                        let mut iter = args.args.iter();
268                        if let Some(GenericArgument::Type(t)) = iter.next() {
269                            let err_ty: Type = parse_quote!(::pyroduct::CapturedError);
270                            return Ok((t.clone(), err_ty));
271                        }
272                    }
273                }
274            }
275        }
276        ReturnType::Default => {}
277    }
278
279    Err(Error::new_spanned(
280        ret,
281        "Function must return Result<T, CapturedError> or Result<T>",
282    ))
283}
284
285#[derive(Debug, Clone)]
286pub struct FnOutput {
287    pub ok_type: Type,
288    pub err_type: Type,
289}
290
291impl FnOutput {
292    pub fn parse(ret: &ReturnType) -> syn::Result<FnOutput> {
293        let (ok_type, err_type) = verify_result_return_type(ret)?;
294        Ok(FnOutput { ok_type, err_type })
295    }
296
297    pub fn to_return_type(&self) -> ReturnType {
298        let ok = &self.ok_type;
299        let err = &self.err_type;
300        let result_ty: Type = parse_quote!(Result<#ok, #err>);
301        ReturnType::Type(RArrow::default(), Box::new(result_ty))
302    }
303
304    pub fn ty(&self) -> Box<Type> {
305        Box::new(self.ok_type.clone())
306    }
307
308    pub fn err(&self) -> Option<&Type> {
309        Some(&self.err_type)
310    }
311}