1use std::rc::Rc;
11
12use proc_macro2::TokenStream;
13use quote::quote;
14use syn::{Error, FnArg, Ident, ImplItemFn, Pat, Type};
15
16use super::paths::{CapabilityIdent, FnName, FnOutput};
17use crate::{ffi::paths::InputParams, utils::extract_ident_from_type};
18
19#[derive(Debug, Clone)]
21pub struct ImplMethod {
22 pub name: FnName,
23 pub class: Rc<CapabilityIdent>,
24 pub client_param: Ident,
25 pub inputs: InputParams,
26 pub output: FnOutput,
27 pub is_async: bool,
28 pub is_mutable_self: bool,
29 pub body: syn::Block,
30 pub attrs: Vec<syn::Attribute>,
31}
32
33impl ImplMethod {
34 pub fn parse(
36 f: &ImplItemFn,
37 class: &Rc<CapabilityIdent>,
38 required_docs: bool,
39 ) -> syn::Result<Self> {
40 let sig = &f.sig;
41 let name = sig.ident.clone();
42
43 let has_docs = f.attrs.iter().any(|attr| attr.path().is_ident("doc"));
44 if !has_docs && required_docs {
45 return Err(Error::new_spanned(
46 &name,
47 "Capability methods must have documentation (///) to generate API specs.",
48 ));
49 }
50
51 let is_mutable_self = match sig.inputs.first() {
53 Some(FnArg::Receiver(r)) => {
54 if r.reference.is_none() {
55 return Err(Error::new_spanned(
56 r,
57 "Capability methods must take &self or &mut self (not value self)",
58 ));
59 }
60 r.mutability.is_some()
61 }
62 Some(arg) => {
63 return Err(Error::new_spanned(
64 arg,
65 "Capability methods must take &self or &mut self as first parameter",
66 ));
67 }
68 None => {
69 return Err(Error::new_spanned(
70 sig,
71 "Capability methods must take &self or &mut self",
72 ));
73 }
74 };
75
76 let client_param_arg = sig.inputs.iter().nth(1);
78 let client_param_ident = match client_param_arg {
79 Some(FnArg::Typed(pt)) => {
80 let ident = if let Pat::Ident(pi) = &*pt.pat {
81 pi.ident.clone()
82 } else {
83 return Err(Error::new_spanned(&pt.pat, "Expected simple identifier"));
84 };
85
86 if let Type::Reference(r) = &*pt.ty {
87 let param_type = extract_ident_from_type(&r.elem)?;
88 if param_type != class.client_tn {
89 return Err(Error::new_spanned(
90 &pt.ty,
91 format!("Expected &{}, found &{}", class.client_tn, param_type),
92 ));
93 }
94 } else {
95 return Err(Error::new_spanned(
96 &pt.ty,
97 format!("Expected &{}", class.client_tn),
98 ));
99 }
100 ident
101 }
102 Some(arg) => {
103 return Err(Error::new_spanned(
104 arg,
105 format!("Expected client: &{}", class.client_tn),
106 ));
107 }
108 None => {
109 return Err(Error::new_spanned(
110 sig,
111 format!(
112 "Capability methods must take client: &{} as second parameter",
113 class.client_tn
114 ),
115 ));
116 }
117 };
118
119 let mut inputs = Vec::new();
121 for arg in sig.inputs.iter().skip(2) {
122 if let FnArg::Typed(pt) = arg {
123 let arg_name = if let Pat::Ident(pi) = &*pt.pat {
124 pi.ident.clone()
125 } else {
126 return Err(Error::new_spanned(
127 &pt.pat,
128 "Method arguments must be named identifiers",
129 ));
130 };
131 inputs.push((arg_name, (*pt.ty).clone()));
132 }
133 }
134
135 let inputs = if inputs.is_empty() {
136 InputParams::None
137 } else if inputs.len() == 1 {
138 let (n, t) = inputs.pop().unwrap();
139 InputParams::One(n, t.into())
140 } else {
141 InputParams::Many(inputs)
142 };
143
144 let output = FnOutput::parse(&sig.output)?;
146
147 Ok(Self {
148 name: FnName(name),
149 class: class.clone(),
150 client_param: client_param_ident,
151 inputs,
152 output,
153 is_async: sig.asyncness.is_some(),
154 is_mutable_self,
155 body: f.block.clone(),
156 attrs: f.attrs.clone(),
157 })
158 }
159
160 pub fn generate_input_struct(&self) -> TokenStream {
161 self.inputs.input_struct(&self.name, Some(&self.class))
162 }
163
164 pub fn generate_server_method(&self) -> TokenStream {
167 let name = &self.name.0;
168 let attrs = &self.attrs;
169 let client_type = &self.class.client_tn;
170 let client_var = &self.client_param;
171 let body = &self.body;
172 let output = &self.output.to_return_type();
173
174 let async_kw = if self.is_async {
175 quote!(async)
176 } else {
177 quote!()
178 };
179
180 let self_arg = if self.is_mutable_self {
182 quote!(&mut self)
183 } else {
184 quote!(&self)
185 };
186
187 let args: Vec<_> = self.inputs.iter().map(|(n, t)| quote!(#n: #t)).collect();
188
189 quote! {
190 #(#attrs)*
191 pub #async_kw fn #name(#self_arg, #client_var: &#client_type, #(#args),*) #output #body
192 }
193 }
194
195 pub fn generate_server_ffi(&self) -> TokenStream {
196 let fn_ffi_name = self.class.ffi_name(&self.name);
197 let input_struct = self.inputs.input_struct(&self.name, Some(&self.class));
198 let state_tn = &self.class.state_tn;
199 let client_tn = &self.class.client_tn;
200 let mut call_args = Vec::new();
201
202 let state_retrieval = quote! {
203 let mut state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
204 Ok(state) => state,
205 Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
206 };
207 let state = state_ptr.as_ref::<#state_tn>();
208 };
209 let client_retrieval = quote! {
211 let client: #client_tn = match ::pyroduct::ffi::guest::deserialize_input(client_state_ptr) {
212 Ok(buf) => buf,
213 Err(err) => return err.encode().view(),
214 };
215 };
216 call_args.push(quote! { &client });
217 let input_retrieval = match &self.inputs {
218 InputParams::One(_, ty) => {
219 call_args.push(quote!(input));
220 quote! {
221 let input: #ty = match ::pyroduct::ffi::guest::deserialize_input(input_ptr) {
222 Ok(buf) => buf,
223 Err(err) => return err.encode().view(),
224 };
225 }
226 }
227 InputParams::Many(items) => {
228 let input_struct_name = self.class.input_struct(&self.name);
229 let args = items.iter().map(|(n, _)| quote!(input.#n));
230 call_args.extend(args);
231 quote! {
232 let input: #input_struct_name = match ::pyroduct::ffi::guest::deserialize_input(input_ptr) {
233 Ok(buf) => buf,
234 Err(err) => return err.encode().view(),
235 };
236 }
237 }
238 InputParams::None => quote! {},
239 };
240 let fn_name = &self.name.0;
241 let method_call = quote!(state.#fn_name(#(#call_args),*));
242
243 let (ffi_ret, body) = if self.is_async {
245 (
246 quote!(::pyroduct::ffi::FuturePyroView),
247 quote! {
248 ::pyroduct::ffi::guest::execute_safe_async(|| async move {
249 #state_retrieval
250 #client_retrieval
251 #input_retrieval
252 ::pyroduct::ffi::guest::serialize_result(#method_call.await)
253 }, capability_state_ptr.object_id, mux_id)
254 },
255 )
256 } else {
257 (
258 quote!(::pyroduct::format::PyroViewPtr),
259 quote! {
260 ::pyroduct::ffi::guest::execute_safe(|| {
261 #state_retrieval
262 #client_retrieval
263 #input_retrieval
264 ::pyroduct::ffi::guest::serialize_result(#method_call)
265 }, capability_state_ptr.object_id, mux_id)
266 },
267 )
268 };
269
270 quote! {
271 #input_struct
272
273 #[unsafe(no_mangle)]
274 pub unsafe extern "C" fn #fn_ffi_name (
275 capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
276 client_state_ptr: ::pyroduct::format::PyroRefPtr,
277 input_ptr: ::pyroduct::format::PyroRefPtr,
278 ) -> #ffi_ret {
279 let (_client_id, mux_id, _class_id, _fn_id) = ::pyroduct::format::get_ref_ids(input_ptr);
280 #body
281 }
282 }
283 }
284
285 pub fn generate_vtable_entry(&self) -> TokenStream {
286 let fn_ffi_name = self.class.ffi_name(&self.name);
287 let wasm_name_ident = self.class.trace_name_static(&self.name);
288
289 let func_variant = if self.is_async {
290 quote! {
291 ::pyroduct::ffi::Function::Async(#fn_ffi_name)
292 }
293 } else {
294 quote! {
295 ::pyroduct::ffi::Function::Sync(#fn_ffi_name)
296 }
297 };
298
299 quote! {
300 ::pyroduct::ffi::MethodExport {
301 name: #wasm_name_ident.as_ptr(),
302 name_len: #wasm_name_ident.len(),
303 func: #func_variant,
304 }
305 }
306 }
307
308 pub fn generate_client_method(&self, module: Option<&Ident>) -> TokenStream {
311 let name = &self.name.0;
312 let attrs = &self.attrs;
313
314 let wasm_call = self.class.wasm_name(&self.name);
315 let wasm_call = match module {
316 Some(m) => quote! {#m::#wasm_call},
317 None => quote! {#wasm_call},
318 };
319
320 let wasm_call = quote! {
321 |client_state_ptr: *const u8,
322 input_ptr: *const u8| {
323 unsafe {
324 #wasm_call(client_state_ptr, input_ptr)
325 }
326 }
327 };
328
329 let mut args: Vec<_> = self.inputs.iter().map(|(n, t)| quote!(#n: #t)).collect();
331 args.insert(0, quote!(&self));
332
333 let i_struct = self.inputs.input_struct(&self.name, Some(&self.class));
334 let i_name = self.inputs.input_type(&self.name, Some(&self.class));
335 let i_fill = self
336 .inputs
337 .input_serialization(&self.name, Some(&self.class));
338 let output_type = &self.output.ty();
339 let output_return = &self.output.to_return_type();
340
341 quote! {
342 #(#attrs)*
343 fn #name(#(#args),*) #output_return {
344 #i_struct
345
346 self.__call_result_from_wasm::<#i_name, #output_type, _>(#i_fill, #wasm_call)
347 }
348 }
349 }
350
351 pub fn generate_client_wasm(&self) -> TokenStream {
352 let fn_wasm_name = self.class.wasm_name(&self.name);
353 quote! {
354 pub fn #fn_wasm_name(
355 cs_ptr: *const u8,
356 in_ptr: *const u8,
357 ) -> *mut u8;
358 }
359 }
360
361 pub fn doc_attrs(&self) -> Vec<&syn::Attribute> {
362 self.attrs
363 .iter()
364 .filter(|attr| attr.path().is_ident("doc"))
365 .collect()
366 }
367}
368
#[cfg(test)]
mod tests {
    use crate::fmt::assert_code_eq_token;

    use super::*;
    use quote::format_ident;
    use syn::parse_quote;

    /// Capability identity used by the `ImplMethod::parse`-driven tests.
    fn mock_class() -> Rc<CapabilityIdent> {
        Rc::new(CapabilityIdent {
            pkg_name: "cap_name".to_string(),
            pkg_version: "0.1.0".to_string(),
            config_tn: None,
            state_tn: format_ident!("MyServer"),
            client_tn: format_ident!("MyClient"),
        })
    }

    #[test]
    fn test_server_method_preserves_mutability() {
        let class = mock_class();

        let f: ImplItemFn = parse_quote! {
            fn update(&mut self, ctx: &MyClient, val: u32) -> Result<(), CapturedError> {
                self.val = val;
                Ok(())
            }
        };

        let method = ImplMethod::parse(&f, &class, false).unwrap();
        let output = method.generate_server_method();

        let expected = quote! {
            pub fn update(&mut self, ctx: &MyClient, val: u32) -> Result<(), ::pyroduct::CapturedError> {
                self.val = val;
                Ok(())
            }
        };

        assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_client_method_forces_immutability() {
        let class = mock_class();
        let module = format_ident!("wasm_bridge");

        let f: ImplItemFn = parse_quote! {
            fn update(&mut self, ctx: &MyClient, val: u32) -> Result<(), CapturedError> { }
        };

        let method = ImplMethod::parse(&f, &class, false).unwrap();
        let output = method.generate_client_method(Some(&module));

        let output_str = output.to_string();
        assert!(output_str.contains("fn update (& self"));
        assert!(!output_str.contains("& mut self"));
    }

    #[test]
    fn test_parse_validates_client_arg_name_capture() {
        let class = mock_class();

        let f: ImplItemFn = parse_quote! {
            fn get(&self, c: &MyClient) -> Result<u32, CapturedError> { Ok(10) }
        };

        let method = ImplMethod::parse(&f, &class, false).unwrap();
        let output = method.generate_server_method();

        let expected = quote! {
            pub fn get(&self, c: &MyClient) -> Result<u32, ::pyroduct::CapturedError> { Ok(10) }
        };

        assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_reject_value_self() {
        let class = mock_class();

        let f: ImplItemFn = parse_quote! {
            fn consume(self, _c: &MyClient) -> Result<(), CapturedError> {}
        };

        let err = ImplMethod::parse(&f, &class, false).unwrap_err();
        assert!(err.to_string().contains("not value self"));
    }

    #[test]
    fn test_reject_missing_docs_when_required() {
        let class = mock_class();

        let f: ImplItemFn = parse_quote! {
            fn get(&self, c: &MyClient) -> Result<u32, CapturedError> { Ok(10) }
        };

        let err = ImplMethod::parse(&f, &class, true).unwrap_err();
        assert!(err.to_string().contains("must have documentation"));
    }

    #[test]
    fn test_accept_documented_method_when_docs_required() {
        let class = mock_class();

        let f: ImplItemFn = parse_quote! {
            /// Returns ten.
            fn get(&self, c: &MyClient) -> Result<u32, CapturedError> { Ok(10) }
        };

        let method = ImplMethod::parse(&f, &class, true).unwrap();
        assert_eq!(method.doc_attrs().len(), 1);
    }

    #[test]
    fn test_reject_missing_receiver() {
        let class = mock_class();

        let f: ImplItemFn = parse_quote! {
            fn get(c: &MyClient) -> Result<u32, CapturedError> { Ok(10) }
        };

        let err = ImplMethod::parse(&f, &class, false).unwrap_err();
        assert!(err.to_string().contains("as first parameter"));
    }

    #[test]
    fn test_reject_missing_client_param() {
        let class = mock_class();

        let f: ImplItemFn = parse_quote! {
            fn get(&self) -> Result<u32, CapturedError> { Ok(10) }
        };

        let err = ImplMethod::parse(&f, &class, false).unwrap_err();
        assert!(err.to_string().contains("as second parameter"));
    }

    #[test]
    fn test_reject_client_param_by_value() {
        let class = mock_class();

        let f: ImplItemFn = parse_quote! {
            fn get(&self, c: MyClient) -> Result<u32, CapturedError> { Ok(10) }
        };

        let err = ImplMethod::parse(&f, &class, false).unwrap_err();
        assert_eq!(err.to_string(), "Expected &MyClient");
    }

    #[test]
    fn test_parse_classifies_inputs_and_flags() {
        let class = mock_class();

        let none: ImplItemFn = parse_quote! {
            fn get(&self, c: &MyClient) -> Result<u32, CapturedError> { Ok(10) }
        };
        let none = ImplMethod::parse(&none, &class, false).unwrap();
        assert!(matches!(none.inputs, InputParams::None));
        assert!(!none.is_async);
        assert!(!none.is_mutable_self);

        let one: ImplItemFn = parse_quote! {
            fn set(&mut self, c: &MyClient, x: i32) -> Result<(), CapturedError> { Ok(()) }
        };
        let one = ImplMethod::parse(&one, &class, false).unwrap();
        assert!(matches!(&one.inputs, InputParams::One(n, _) if n == "x"));
        assert!(one.is_mutable_self);

        let many: ImplItemFn = parse_quote! {
            async fn add(&self, ctx: &MyClient, a: i32, b: i32) -> Result<i32, CapturedError> { Ok(a + b) }
        };
        let many = ImplMethod::parse(&many, &class, false).unwrap();
        assert!(matches!(&many.inputs, InputParams::Many(items) if items.len() == 2));
        assert!(many.is_async);
        assert_eq!(many.client_param, "ctx");
    }

    /// Builds an `ImplMethod` directly (bypassing `parse`) with no inputs,
    /// a `Result<u32, CapturedError>` output and a `{ 0 }` body.
    fn mock_method_base(name: &str, is_async: bool) -> ImplMethod {
        let output = FnOutput {
            ok_type: parse_quote!(u32),
            err_type: parse_quote!(::pyroduct::CapturedError),
        };
        let class = Rc::new(CapabilityIdent {
            pkg_name: "cap_name".to_string(),
            pkg_version: "0.1.0".to_string(),
            config_tn: None,
            state_tn: format_ident!("MockServer"),
            client_tn: format_ident!("MockClient"),
        });

        ImplMethod {
            name: FnName(format_ident!("{}", name)),
            class,
            client_param: format_ident!("client"),
            inputs: InputParams::None,
            output,
            is_async,
            is_mutable_self: false,
            body: parse_quote!({ 0 }),
            attrs: vec![],
        }
    }

    /// Async method, no inputs: FFI shim uses `execute_safe_async` and `.await`.
    #[test]
    fn test_case_4_async_no_input_with_client() {
        let method = mock_method_base("test_async_client", true);

        let capability_tokens = method.generate_server_ffi();
        let client_tokens = method.generate_client_method(None);
        let module_tokens = quote! {
            impl Mod {
                #client_tokens
            }
        };

        let output_capability = quote! {
            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__mock_server__test_async_client__ffi(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
                client_state_ptr: ::pyroduct::format::PyroRefPtr,
                input_ptr: ::pyroduct::format::PyroRefPtr,
            ) -> ::pyroduct::ffi::FuturePyroView {
                let (_client_id, mux_id, _class_id, _fn_id) = ::pyroduct::format::get_ref_ids(input_ptr);
                ::pyroduct::ffi::guest::execute_safe_async(|| async move {
                    let mut state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<MockServer>();

                    let client: MockClient = match ::pyroduct::ffi::guest::deserialize_input(client_state_ptr) {
                        Ok(buf) => buf,
                        Err(err) => return err.encode().view(),
                    };
                    ::pyroduct::ffi::guest::serialize_result(state.test_async_client(&client).await)
                }, capability_state_ptr.object_id, mux_id)
            }
        };

        assert_code_eq_token(&capability_tokens, &output_capability);

        let output_module = quote! {
            impl Mod {
                fn test_async_client(&self) -> Result<u32, ::pyroduct::CapturedError> {
                    self.__call_result_from_wasm::<
                        (),
                        u32,
                        _,
                    >(
                        None,
                        |client_state_ptr: *const u8,
                        input_ptr: *const u8| {
                            unsafe {
                                p__mock_server__test_async_client__wasm(
                                    client_state_ptr,
                                    input_ptr
                                )
                            }
                        }
                    )
                }
            }
        };

        assert_code_eq_token(&module_tokens, &output_module);
    }

    /// Sync method with a single input: the input is decoded directly as its type.
    #[test]
    fn test_case_5_sync_single_input_with_client() {
        let mut method = mock_method_base("test_sync_client_input", false);
        method.inputs = InputParams::One(format_ident!("x"), parse_quote!(i32));
        let capability_tokens = method.generate_server_ffi();
        let client_tokens = method.generate_client_method(None);
        let module_tokens = quote! {
            impl Mod {
                #client_tokens
            }
        };

        let output_capability = quote! {
            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__mock_server__test_sync_client_input__ffi(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
                client_state_ptr: ::pyroduct::format::PyroRefPtr,
                input_ptr: ::pyroduct::format::PyroRefPtr,
            ) -> ::pyroduct::format::PyroViewPtr {
                let (_client_id, mux_id, _class_id, _fn_id) = ::pyroduct::format::get_ref_ids(input_ptr);
                ::pyroduct::ffi::guest::execute_safe(|| {
                    let mut state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<MockServer>();

                    let client: MockClient = match ::pyroduct::ffi::guest::deserialize_input(client_state_ptr) {
                        Ok(buf) => buf,
                        Err(err) => return err.encode().view(),
                    };

                    let input: i32 = match ::pyroduct::ffi::guest::deserialize_input(input_ptr) {
                        Ok(buf) => buf,
                        Err(err) => return err.encode().view(),
                    };
                    ::pyroduct::ffi::guest::serialize_result(state.test_sync_client_input(&client, input))
                }, capability_state_ptr.object_id, mux_id)
            }
        };

        assert_code_eq_token(&capability_tokens, &output_capability);

        let output_module = quote! {
            impl Mod {
                fn test_sync_client_input(&self, x: i32) -> Result<u32, ::pyroduct::CapturedError> {
                    self.__call_result_from_wasm::<
                        i32,
                        u32,
                        _,
                    >(
                        Some(&x),
                        |client_state_ptr: *const u8,
                        input_ptr: *const u8| {
                            unsafe {
                                p__mock_server__test_sync_client_input__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        },
                    )
                }
            }
        };

        assert_code_eq_token(&module_tokens, &output_module);
    }

    /// Async method with several inputs: a generated input struct carries them.
    #[test]
    fn test_case_full_sci() {
        let mut method = mock_method_base("test_sci_multi", true);
        method.inputs = InputParams::Many(vec![
            (format_ident!("a"), parse_quote!(i32)),
            (format_ident!("b"), parse_quote!(i32)),
        ]);

        let capability_tokens = method.generate_server_ffi();
        let client_tokens = method.generate_client_method(None);
        let module_tokens = quote! {
            impl Mod {
                #client_tokens
            }
        };

        let output_capability = quote! {
            #[::pyroduct::magma]
            struct p__MockServer__TestSciMulti__Input {
                pub a: i32,
                pub b: i32,
            }

            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__mock_server__test_sci_multi__ffi(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
                client_state_ptr: ::pyroduct::format::PyroRefPtr,
                input_ptr: ::pyroduct::format::PyroRefPtr,
            ) -> ::pyroduct::ffi::FuturePyroView {
                let (_client_id, mux_id, _class_id, _fn_id) = ::pyroduct::format::get_ref_ids(input_ptr);
                ::pyroduct::ffi::guest::execute_safe_async(|| async move {
                    let mut state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<MockServer>();

                    let client: MockClient = match ::pyroduct::ffi::guest::deserialize_input(client_state_ptr) {
                        Ok(buf) => buf,
                        Err(err) => return err.encode().view(),
                    };

                    let input: p__MockServer__TestSciMulti__Input = match ::pyroduct::ffi::guest::deserialize_input(input_ptr) {
                        Ok(buf) => buf,
                        Err(err) => return err.encode().view(),
                    };
                    ::pyroduct::ffi::guest::serialize_result(state.test_sci_multi(&client, input.a, input.b).await)
                }, capability_state_ptr.object_id, mux_id)
            }
        };

        assert_code_eq_token(&capability_tokens, &output_capability);

        let output_module = quote! {
            impl Mod {
                fn test_sci_multi(&self, a: i32, b: i32) -> Result<u32, ::pyroduct::CapturedError> {
                    #[::pyroduct::magma]
                    struct p__MockServer__TestSciMulti__Input {
                        pub a: i32,
                        pub b: i32,
                    }
                    self.__call_result_from_wasm::<
                        p__MockServer__TestSciMulti__Input,
                        u32,
                        _,
                    >(
                        Some(
                            &p__MockServer__TestSciMulti__Input {
                                a,
                                b,
                            },
                        ),
                        |client_state_ptr: *const u8,
                        input_ptr: *const u8| {
                            unsafe {
                                p__mock_server__test_sci_multi__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        },
                    )
                }
            }
        };

        assert_code_eq_token(&module_tokens, &output_module);
    }
}