Skip to main content

pyro_macro/ffi/
paths.rs

1//! Path and naming utilities for capability FFI generation
2//!
3//! This module centralizes all naming conventions used throughout the capability system
4//! to ensure consistency between client and server sides.
5
6use std::{ops::Deref, slice::Iter};
7
8use heck::{AsSnakeCase, AsUpperCamelCase};
9use proc_macro2::TokenStream;
10use quote::{format_ident, quote};
11use syn::{
12    Error, GenericArgument, Ident, PathArguments, ReturnType, Type, parse_quote, token::RArrow,
13};
14
15/// Identity of the capability (State, Client, Error)
16#[derive(Debug, Clone, PartialEq, Eq, Hash)]
17pub struct CapabilityIdent {
18    pub pkg_name: String,
19    pub pkg_version: String,
20    /// The struct being implemented (e.g., "MyStruct")
21    pub state_tn: Ident,
22    /// The client type identifier (e.g., "MyClient")
23    pub client_tn: Ident,
24    /// The config type identifier (e.g., "MyConfig")
25    pub config_tn: Option<Ident>,
26    /// The error type, if present (e.g., "MyError")
27    pub error_tn: Option<Type>,
28}
29
30impl CapabilityIdent {
31    // ========================================================================
32    // Method Paths
33    // ========================================================================
34
35    /// Library identifier for a method (e.g., __my_trait__my_state__method_name)
36    pub fn cap_id(&self) -> String {
37        format!("{}", self.pkg_name)
38    }
39
40    /// Library identifier for a method (e.g., __my_trait__my_state__method_name)
41    pub fn trace_name(&self, name: &FnName) -> Ident {
42        let state_snake = AsSnakeCase(self.state_tn.to_string()).to_string();
43        let snake = AsSnakeCase(name.0.to_string()).to_string();
44        format_ident!("p__{}__{}", state_snake, snake)
45    }
46
47    /// Library identifier for a method (e.g., __my_trait__my_state__method_name)
48    pub fn class_name_static(&self) -> Ident {
49        let state_snake = AsSnakeCase(self.state_tn.to_string())
50            .to_string()
51            .to_uppercase();
52        format_ident!("p__{}", state_snake)
53    }
54
55    /// Library identifier for a method (e.g., __my_trait__my_state__method_name)
56    pub fn trace_name_static(&self, name: &FnName) -> Ident {
57        let state_snake = AsSnakeCase(self.state_tn.to_string())
58            .to_string()
59            .to_uppercase();
60        let snake = AsSnakeCase(name.0.to_string()).to_string().to_uppercase();
61        format_ident!("p__{}__{}", state_snake, snake)
62    }
63
64    /// FFI function name for a method (e.g., __my_trait__my_state__name__ffi)
65    pub fn ffi_name(&self, name: &FnName) -> Ident {
66        let state_snake = AsSnakeCase(self.state_tn.to_string()).to_string();
67        let snake = AsSnakeCase(name.0.to_string()).to_string();
68        format_ident!("p__{}__{}__ffi", state_snake, snake)
69    }
70
71    /// WASM import name for a method (e.g., __my_trait__my_state__name__wasm)
72    pub fn wasm_name(&self, name: &FnName) -> Ident {
73        let state_snake = AsSnakeCase(self.state_tn.to_string()).to_string();
74        let snake = AsSnakeCase(name.0.to_string()).to_string();
75        format_ident!("p__{}__{}__wasm", state_snake, snake)
76    }
77
78    /// Input struct name for a method with multiple parameters
79    pub fn input_struct(&self, name: &FnName) -> Ident {
80        let state_snake = AsUpperCamelCase(self.state_tn.to_string()).to_string();
81        let snake = AsUpperCamelCase(name.0.to_string()).to_string();
82        format_ident!("p__{}__{}__Input", state_snake, snake)
83    }
84}
85
86#[derive(Debug, Clone)]
87pub struct FnName(pub Ident);
88
89impl FnName {
90    pub fn trace_name(&self) -> Ident {
91        format_ident!("p__{}", AsSnakeCase(self.0.to_string()).to_string())
92    }
93
94    pub fn trace_name_static(&self) -> Ident {
95        format_ident!(
96            "p__{}",
97            AsSnakeCase(self.0.to_string()).to_string().to_uppercase()
98        )
99    }
100
101    /// Get the FFI function name
102    pub fn fn_ffi_name(&self) -> Ident {
103        format_ident!("p__{}__ffi", AsSnakeCase(self.0.to_string()).to_string())
104    }
105
106    /// Get the WASM import name
107    pub fn fn_wasm_name(&self) -> Ident {
108        format_ident!("p__{}__wasm", AsSnakeCase(self.0.to_string()).to_string())
109    }
110
111    /// Get the input struct name (if multiple parameters)
112    pub fn input_struct_name(&self) -> Ident {
113        format_ident!(
114            "p__{}__Input",
115            AsUpperCamelCase(self.0.to_string()).to_string()
116        )
117    }
118}
119
120impl Deref for FnName {
121    type Target = Ident;
122
123    fn deref(&self) -> &Self::Target {
124        &self.0
125    }
126}
127
128#[derive(Debug, Clone, PartialEq, Eq)]
129pub enum InputParams {
130    None,
131    One(Ident, Type),
132    Many(Vec<(Ident, Type)>),
133}
134
135pub enum InputParamsIter<'a> {
136    None,
137    One(Option<(&'a Ident, &'a Type)>),
138    Many(Iter<'a, (Ident, Type)>),
139}
140
141impl<'a> Iterator for InputParamsIter<'a> {
142    type Item = (&'a Ident, &'a Type);
143
144    fn next(&mut self) -> Option<Self::Item> {
145        match self {
146            InputParamsIter::None => None,
147            InputParamsIter::One(t) => t.take(),
148            InputParamsIter::Many(params) => params.next().map(|(i, t)| (i, t)),
149        }
150    }
151}
152
153impl InputParams {
154    pub fn is_empty(&self) -> bool {
155        match self {
156            InputParams::None => true,
157            InputParams::One(_, _) => false,
158            InputParams::Many(_) => false,
159        }
160    }
161
162    pub fn iter(&self) -> InputParamsIter<'_> {
163        match self {
164            InputParams::None => InputParamsIter::None,
165            InputParams::One(i, t) => InputParamsIter::One(Some((i, t))),
166            InputParams::Many(params) => InputParamsIter::Many(params.iter()),
167        }
168    }
169
170    pub fn input_type(&self, fn_name: &FnName, class: Option<&CapabilityIdent>) -> TokenStream {
171        match &self {
172            InputParams::Many(_) => {
173                let input_struct_name = class
174                    .map(|c| c.input_struct(&fn_name))
175                    .unwrap_or(fn_name.input_struct_name());
176                quote!(#input_struct_name)
177            }
178            InputParams::One(_, param_ty) => quote!(#param_ty),
179            InputParams::None => quote!(()),
180        }
181    }
182
183    pub fn input_serialization(
184        &self,
185        fn_name: &FnName,
186        class: Option<&CapabilityIdent>,
187    ) -> TokenStream {
188        match &self {
189            InputParams::Many(params) => {
190                let input_struct_name = class
191                    .map(|c| c.input_struct(&fn_name))
192                    .unwrap_or(fn_name.input_struct_name());
193                let args = params.iter().map(|(n, _)| quote!(#n));
194                quote!(Some(&#input_struct_name { #(#args),* }))
195            }
196            InputParams::One(param_name, _) => quote!(Some(&#param_name)),
197            InputParams::None => quote!(None),
198        }
199    }
200
201    pub fn input_args(&self) -> Vec<TokenStream> {
202        match &self {
203            InputParams::Many(params) => params.iter().map(|(n, _)| quote!(input.#n)).collect(),
204            InputParams::One(..) => vec![quote!(input)],
205            InputParams::None => Vec::new(),
206        }
207    }
208
209    pub fn input_struct(&self, fn_name: &FnName, class: Option<&CapabilityIdent>) -> TokenStream {
210        match &self {
211            InputParams::Many(params) => {
212                let input_struct_name = class
213                    .map(|c| c.input_struct(&fn_name))
214                    .unwrap_or(fn_name.input_struct_name());
215                let fields: Vec<_> = params.iter().map(|(n, t)| quote! { pub #n: #t }).collect();
216                quote! {
217                    #[::pyroduct::magma]
218                    struct #input_struct_name {
219                        #(#fields),*
220                    }
221                }
222            }
223            InputParams::One(_, _) => quote! {},
224            InputParams::None => quote! {},
225        }
226    }
227}
228
229#[derive(Debug, Clone)]
230pub enum FnOutput {
231    None,
232    Single(Type),
233    Result(Type, Type),
234}
235
236impl FnOutput {
237    pub fn parse(ret: &ReturnType, expected_err: Option<&Type>) -> syn::Result<FnOutput> {
238        let mut output = FnOutput::None;
239        match ret {
240            // Handle "-> " (Default)
241            ReturnType::Default => {}
242
243            ReturnType::Type(_, ty) => {
244                let ty = ty.as_ref();
245                output = FnOutput::Single(ty.clone());
246                match ty {
247                    Type::Tuple(tuple) if tuple.elems.is_empty() => output = FnOutput::None,
248                    Type::Path(type_path) => {
249                        // Check if the last segment is "Result" (heuristic)
250                        if let Some(segment) = type_path.path.segments.last() {
251                            if segment.ident == "Result" {
252                                if let PathArguments::AngleBracketed(args) = &segment.arguments {
253                                    // Ensure we have exactly 2 generic arguments: <T, E>
254                                    if args.args.len() == 2 {
255                                        let mut iter = args.args.iter();
256                                        // Ensure both arguments are Types (not lifetimes or consts)
257                                        if let (
258                                            Some(GenericArgument::Type(t)),
259                                            Some(GenericArgument::Type(e)),
260                                        ) = (iter.next(), iter.next())
261                                        {
262                                            output = FnOutput::Result(t.clone(), e.clone());
263                                        }
264                                    }
265                                }
266                            }
267                        }
268                    }
269                    _ => {}
270                }
271            }
272        }
273
274        match (output, expected_err) {
275            (a @ FnOutput::None, None)
276            | (a @ FnOutput::Single(_), None)
277            | (a @ FnOutput::Result(_, _), None) => Ok(a),
278            (FnOutput::None, Some(target_error)) | (FnOutput::Single(_), Some(target_error)) => {
279                let target_err_str = quote!(#target_error).to_string().replace(" ", "");
280                Err(Error::new_spanned(
281                    ret,
282                    format!(
283                        "Expected a result with '{}' or 'Self::Error' error type",
284                        target_err_str
285                    ),
286                ))
287            }
288            (FnOutput::Result(val, err_type), Some(target_error)) => {
289                let self_err_str: Type = parse_quote!(Self::Error);
290                if &err_type != target_error && &err_type != &self_err_str {
291                    let actual_err_str = quote!(#err_type).to_string().replace(" ", "");
292                    let target_err_str = quote!(#target_error).to_string().replace(" ", "");
293                    Err(Error::new_spanned(
294                        err_type,
295                        format!(
296                            "Invalid error type. Expected '{}' or 'Self::Error', found '{}'",
297                            target_err_str, actual_err_str
298                        ),
299                    ))
300                } else {
301                    Ok(FnOutput::Result(val, err_type))
302                }
303            }
304        }
305    }
306
307    pub fn to_return_type(&self) -> ReturnType {
308        match self {
309            // Maps back to no return arrow (void)
310            FnOutput::None => ReturnType::Default,
311
312            // Maps back to "-> T"
313            FnOutput::Single(ty) => ReturnType::Type(RArrow::default(), Box::new(ty.clone())),
314
315            // Maps back to "-> Result<T, E>"
316            FnOutput::Result(ok, err) => {
317                let result_ty: Type = parse_quote!(Result<#ok, #err>);
318                ReturnType::Type(RArrow::default(), Box::new(result_ty))
319            }
320        }
321    }
322
323    pub fn ty(&self) -> Type {
324        match self {
325            // Maps back to no return arrow (void)
326            FnOutput::None => parse_quote!(()),
327
328            // Maps back to "-> T"
329            FnOutput::Single(ty) => ty.clone(),
330
331            // Maps back to "-> Result<T, E>"
332            FnOutput::Result(ok, _) => ok.clone(),
333        }
334    }
335
336    pub fn err(&self) -> Option<&Type> {
337        match self {
338            // Maps back to no return arrow (void)
339            FnOutput::None => None,
340
341            // Maps back to "-> T"
342            FnOutput::Single(_) => None,
343
344            // Maps back to "-> Result<T, E>"
345            FnOutput::Result(_, err) => Some(err),
346        }
347    }
348}