1use std::{ops::Deref, slice::Iter};
7
8use heck::{AsSnakeCase, AsUpperCamelCase};
9use proc_macro2::TokenStream;
10use quote::{format_ident, quote};
11use syn::{
12 Error, GenericArgument, Ident, PathArguments, ReturnType, Type, parse_quote, token::RArrow,
13};
14
/// Naming context for a capability: package metadata plus the type names from which
/// generated identifiers are derived.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CapabilityIdent {
    /// Package name; `cap_id` returns it verbatim.
    pub pkg_name: String,
    /// Package version string.
    pub pkg_version: String,
    /// Name of the capability's state type. Its snake_case / UpperCamelCase forms drive
    /// the identifiers built by `class_name`, `trace_name`, `ffi_name`, `wasm_name`,
    /// `input_struct` and their `_static` variants.
    pub state_tn: Ident,
    /// NOTE(review): presumably the name of the capability's client type — not read by
    /// the naming helpers; confirm against callers.
    pub client_tn: Ident,
    /// NOTE(review): presumably the name of the capability's config type, `None` when it
    /// has none — not read by the naming helpers; confirm against callers.
    pub config_tn: Option<Ident>,
}
27
28impl CapabilityIdent {
29 pub fn cap_id(&self) -> String {
35 self.pkg_name.to_string()
36 }
37
38 pub fn trace_name(&self, name: &FnName) -> Ident {
40 let state_snake = AsSnakeCase(self.state_tn.to_string()).to_string();
41 let snake = AsSnakeCase(name.0.to_string()).to_string();
42 format_ident!("p__{}__{}", state_snake, snake)
43 }
44
45 pub fn class_name(&self) -> String {
47 AsSnakeCase(self.state_tn.to_string()).to_string()
48 }
49
50 pub fn class_name_static(&self) -> Ident {
52 let state_snake = AsSnakeCase(self.state_tn.to_string())
53 .to_string()
54 .to_uppercase();
55 format_ident!("p__{}", state_snake)
56 }
57
58 pub fn trace_name_static(&self, name: &FnName) -> Ident {
60 let state_snake = AsSnakeCase(self.state_tn.to_string())
61 .to_string()
62 .to_uppercase();
63 let snake = AsSnakeCase(name.0.to_string()).to_string().to_uppercase();
64 format_ident!("p__{}__{}", state_snake, snake)
65 }
66
67 pub fn ffi_name(&self, name: &FnName) -> Ident {
69 let state_snake = AsSnakeCase(self.state_tn.to_string()).to_string();
70 let snake = AsSnakeCase(name.0.to_string()).to_string();
71 format_ident!("p__{}__{}__ffi", state_snake, snake)
72 }
73
74 pub fn wasm_name(&self, name: &FnName) -> Ident {
76 let state_snake = AsSnakeCase(self.state_tn.to_string()).to_string();
77 let snake = AsSnakeCase(name.0.to_string()).to_string();
78 format_ident!("p__{}__{}__wasm", state_snake, snake)
79 }
80
81 pub fn input_struct(&self, name: &FnName) -> Ident {
83 let state_snake = AsUpperCamelCase(self.state_tn.to_string()).to_string();
84 let snake = AsUpperCamelCase(name.0.to_string()).to_string();
85 format_ident!("p__{}__{}__Input", state_snake, snake)
86 }
87}
88
/// Identifier of a user function. Wraps the original `Ident`, from which the names of
/// generated items (trace, `__ffi`, `__wasm`, input struct) are derived.
#[derive(Debug, Clone)]
pub struct FnName(pub Ident);
91
92impl FnName {
93 pub fn trace_name(&self) -> Ident {
94 format_ident!("p__{}", AsSnakeCase(self.0.to_string()).to_string())
95 }
96
97 pub fn trace_name_static(&self) -> Ident {
98 format_ident!(
99 "p__{}",
100 AsSnakeCase(self.0.to_string()).to_string().to_uppercase()
101 )
102 }
103
104 pub fn fn_ffi_name(&self) -> Ident {
106 format_ident!("p__{}__ffi", AsSnakeCase(self.0.to_string()).to_string())
107 }
108
109 pub fn fn_wasm_name(&self) -> Ident {
111 format_ident!("p__{}__wasm", AsSnakeCase(self.0.to_string()).to_string())
112 }
113
114 pub fn input_struct_name(&self) -> Ident {
116 format_ident!(
117 "p__{}__Input",
118 AsUpperCamelCase(self.0.to_string()).to_string()
119 )
120 }
121}
122
123impl Deref for FnName {
124 type Target = Ident;
125
126 fn deref(&self) -> &Self::Target {
127 &self.0
128 }
129}
130
/// A function's parameter list, classified by how many parameters it has.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InputParams {
    /// No parameters; `input_type` renders `()`.
    None,
    /// A single parameter `(name, type)`; its type is used directly as the input type.
    One(Ident, Box<Type>),
    /// Several `(name, type)` pairs, bundled into a generated input struct.
    Many(Vec<(Ident, Type)>),
}
137
/// Borrowing iterator over `(name, type)` pairs, returned by `InputParams::iter`.
pub enum InputParamsIter<'a> {
    /// Yields nothing.
    None,
    /// Holds the single pair until it is yielded; afterwards holds `None`.
    One(Option<(&'a Ident, &'a Type)>),
    /// Delegates to the slice iterator over a `Many` parameter list.
    Many(Iter<'a, (Ident, Type)>),
}
143
144impl<'a> Iterator for InputParamsIter<'a> {
145 type Item = (&'a Ident, &'a Type);
146
147 fn next(&mut self) -> Option<Self::Item> {
148 match self {
149 InputParamsIter::None => None,
150 InputParamsIter::One(t) => t.take(),
151 InputParamsIter::Many(params) => params.next().map(|(i, t)| (i, t)),
152 }
153 }
154}
155
156impl InputParams {
157 pub fn is_empty(&self) -> bool {
158 match self {
159 InputParams::None => true,
160 InputParams::One(_, _) => false,
161 InputParams::Many(_) => false,
162 }
163 }
164
165 pub fn iter(&self) -> InputParamsIter<'_> {
166 match self {
167 InputParams::None => InputParamsIter::None,
168 InputParams::One(i, t) => InputParamsIter::One(Some((i, t))),
169 InputParams::Many(params) => InputParamsIter::Many(params.iter()),
170 }
171 }
172
173 pub fn input_type(&self, fn_name: &FnName, class: Option<&CapabilityIdent>) -> TokenStream {
174 match &self {
175 InputParams::Many(_) => {
176 let input_struct_name = class
177 .map(|c| c.input_struct(fn_name))
178 .unwrap_or(fn_name.input_struct_name());
179 quote!(#input_struct_name)
180 }
181 InputParams::One(_, param_ty) => quote!(#param_ty),
182 InputParams::None => quote!(()),
183 }
184 }
185
186 pub fn input_serialization(
187 &self,
188 fn_name: &FnName,
189 class: Option<&CapabilityIdent>,
190 ) -> TokenStream {
191 match &self {
192 InputParams::Many(params) => {
193 let input_struct_name = class
194 .map(|c| c.input_struct(fn_name))
195 .unwrap_or(fn_name.input_struct_name());
196 let args = params.iter().map(|(n, _)| quote!(#n));
197 quote!(Some(&#input_struct_name { #(#args),* }))
198 }
199 InputParams::One(param_name, _) => quote!(Some(&#param_name)),
200 InputParams::None => quote!(None),
201 }
202 }
203
204 pub fn input_args(&self) -> Vec<TokenStream> {
205 match &self {
206 InputParams::Many(params) => params.iter().map(|(n, _)| quote!(input.#n)).collect(),
207 InputParams::One(..) => vec![quote!(input)],
208 InputParams::None => Vec::new(),
209 }
210 }
211
212 pub fn input_struct(&self, fn_name: &FnName, class: Option<&CapabilityIdent>) -> TokenStream {
213 match &self {
214 InputParams::Many(params) => {
215 let input_struct_name = class
216 .map(|c| c.input_struct(fn_name))
217 .unwrap_or(fn_name.input_struct_name());
218 let fields: Vec<_> = params.iter().map(|(n, t)| quote! { pub #n: #t }).collect();
219 quote! {
220 #[::pyroduct::magma]
221 struct #input_struct_name {
222 #(#fields),*
223 }
224 }
225 }
226 InputParams::One(_, _) => quote! {},
227 InputParams::None => quote! {},
228 }
229 }
230}
231
232pub fn is_captured_error(ty: &Type) -> bool {
233 let ty_str = quote!(#ty).to_string().replace(" ", "");
234 ty_str == "CapturedError" || ty_str == "pyroduct::CapturedError" || ty_str == "::pyroduct::CapturedError"
235}
236
237pub fn verify_result_return_type(ret: &ReturnType) -> syn::Result<(Type, Type)> {
238 match ret {
239 ReturnType::Type(_, ty) => {
240 let ty = ty.as_ref();
241 if let Type::Path(type_path) = ty {
242 if let Some(segment) = type_path.path.segments.last()
243 && segment.ident == "Result"
244 && let PathArguments::AngleBracketed(args) = &segment.arguments
245 {
246 if args.args.len() == 2 {
247 let mut iter = args.args.iter();
248 if let (
249 Some(GenericArgument::Type(t)),
250 Some(GenericArgument::Type(e)),
251 ) = (iter.next(), iter.next())
252 {
253 if !is_captured_error(e) {
254 let actual_err_str = quote!(#e).to_string().replace(" ", "");
255 return Err(Error::new_spanned(
256 e,
257 format!(
258 "Invalid error type. Expected 'CapturedError', found '{}'",
259 actual_err_str
260 ),
261 ));
262 }
263 let err_ty: Type = parse_quote!(::pyroduct::CapturedError);
264 return Ok((t.clone(), err_ty));
265 }
266 } else if args.args.len() == 1 {
267 let mut iter = args.args.iter();
268 if let Some(GenericArgument::Type(t)) = iter.next() {
269 let err_ty: Type = parse_quote!(::pyroduct::CapturedError);
270 return Ok((t.clone(), err_ty));
271 }
272 }
273 }
274 }
275 }
276 ReturnType::Default => {}
277 }
278
279 Err(Error::new_spanned(
280 ret,
281 "Function must return Result<T, CapturedError> or Result<T>",
282 ))
283}
284
/// A function's `Result` return type, split into its success and error types.
#[derive(Debug, Clone)]
pub struct FnOutput {
    /// The success type `T` of `Result<T, E>`.
    pub ok_type: Type,
    /// The error type; `FnOutput::parse` always sets it to `::pyroduct::CapturedError`.
    pub err_type: Type,
}
290
291impl FnOutput {
292 pub fn parse(ret: &ReturnType) -> syn::Result<FnOutput> {
293 let (ok_type, err_type) = verify_result_return_type(ret)?;
294 Ok(FnOutput { ok_type, err_type })
295 }
296
297 pub fn to_return_type(&self) -> ReturnType {
298 let ok = &self.ok_type;
299 let err = &self.err_type;
300 let result_ty: Type = parse_quote!(Result<#ok, #err>);
301 ReturnType::Type(RArrow::default(), Box::new(result_ty))
302 }
303
304 pub fn ty(&self) -> Box<Type> {
305 Box::new(self.ok_type.clone())
306 }
307
308 pub fn err(&self) -> Option<&Type> {
309 Some(&self.err_type)
310 }
311}