1use std::rc::Rc;
11
12use proc_macro2::TokenStream;
13use quote::quote;
14use syn::{Error, FnArg, Ident, ImplItemFn, Pat, Type};
15
16use super::paths::{CapabilityIdent, FnName, FnOutput};
17use crate::{ffi::paths::InputParams, utils::extract_ident_from_type};
18
19#[derive(Debug, Clone)]
21pub struct ImplMethod {
22 pub name: FnName,
23 pub class: Rc<CapabilityIdent>,
24 pub client_param: Ident,
25 pub inputs: InputParams,
26 pub output: FnOutput,
27 pub is_async: bool,
28 pub is_mutable_self: bool,
29 pub body: syn::Block,
30 pub attrs: Vec<syn::Attribute>,
31}
32
33impl ImplMethod {
34 pub fn parse(
36 f: &ImplItemFn,
37 class: &Rc<CapabilityIdent>,
38 required_docs: bool,
39 ) -> syn::Result<Self> {
40 let sig = &f.sig;
41 let name = sig.ident.clone();
42
43 let has_docs = f.attrs.iter().any(|attr| attr.path().is_ident("doc"));
44 if !has_docs && required_docs {
45 return Err(Error::new_spanned(
46 &name,
47 "Capability methods must have documentation (///) to generate API specs.",
48 ));
49 }
50
51 let is_mutable_self = match sig.inputs.first() {
53 Some(FnArg::Receiver(r)) => {
54 if r.reference.is_none() {
55 return Err(Error::new_spanned(
56 r,
57 "Capability methods must take &self or &mut self (not value self)",
58 ));
59 }
60 r.mutability.is_some()
61 }
62 Some(arg) => {
63 return Err(Error::new_spanned(
64 arg,
65 "Capability methods must take &self or &mut self as first parameter",
66 ));
67 }
68 None => {
69 return Err(Error::new_spanned(
70 &sig,
71 "Capability methods must take &self or &mut self",
72 ));
73 }
74 };
75
76 let client_param_arg = sig.inputs.iter().nth(1);
78 let client_param_ident = match client_param_arg {
79 Some(FnArg::Typed(pt)) => {
80 let ident = if let Pat::Ident(pi) = &*pt.pat {
81 pi.ident.clone()
82 } else {
83 return Err(Error::new_spanned(&pt.pat, "Expected simple identifier"));
84 };
85
86 if let Type::Reference(r) = &*pt.ty {
87 let param_type = extract_ident_from_type(&r.elem)?;
88 if param_type != class.client_tn {
89 return Err(Error::new_spanned(
90 &pt.ty,
91 format!("Expected &{}, found &{}", class.client_tn, param_type),
92 ));
93 }
94 } else {
95 return Err(Error::new_spanned(
96 &pt.ty,
97 format!("Expected &{}", class.client_tn),
98 ));
99 }
100 ident
101 }
102 Some(arg) => {
103 return Err(Error::new_spanned(
104 arg,
105 format!("Expected client: &{}", class.client_tn),
106 ));
107 }
108 None => {
109 return Err(Error::new_spanned(
110 &sig,
111 format!(
112 "Capability methods must take client: &{} as second parameter",
113 class.client_tn
114 ),
115 ));
116 }
117 };
118
119 let mut inputs = Vec::new();
121 for arg in sig.inputs.iter().skip(2) {
122 if let FnArg::Typed(pt) = arg {
123 let arg_name = if let Pat::Ident(pi) = &*pt.pat {
124 pi.ident.clone()
125 } else {
126 return Err(Error::new_spanned(
127 &pt.pat,
128 "Method arguments must be named identifiers",
129 ));
130 };
131 inputs.push((arg_name, (*pt.ty).clone()));
132 }
133 }
134
135 let inputs = if inputs.is_empty() {
136 InputParams::None
137 } else if inputs.len() == 1 {
138 let (n, t) = inputs.pop().unwrap();
139 InputParams::One(n, t)
140 } else {
141 InputParams::Many(inputs)
142 };
143
144 let output = FnOutput::parse(&sig.output, class.error_tn.as_ref())?;
146
147 Ok(Self {
148 name: FnName(name),
149 class: class.clone(),
150 client_param: client_param_ident,
151 inputs,
152 output,
153 is_async: sig.asyncness.is_some(),
154 is_mutable_self,
155 body: f.block.clone(),
156 attrs: f.attrs.clone(),
157 })
158 }
159
160 pub fn generate_input_struct(&self) -> TokenStream {
161 self.inputs.input_struct(&self.name, Some(&self.class))
162 }
163
164 pub fn generate_server_method(&self) -> TokenStream {
167 let name = &self.name.0;
168 let attrs = &self.attrs;
169 let client_type = &self.class.client_tn;
170 let client_var = &self.client_param;
171 let body = &self.body;
172 let output = &self.output.to_return_type();
173
174 let async_kw = if self.is_async {
175 quote!(async)
176 } else {
177 quote!()
178 };
179
180 let self_arg = if self.is_mutable_self {
182 quote!(&mut self)
183 } else {
184 quote!(&self)
185 };
186
187 let args: Vec<_> = self.inputs.iter().map(|(n, t)| quote!(#n: #t)).collect();
188
189 quote! {
190 #(#attrs)*
191 pub #async_kw fn #name(#self_arg, #client_var: &#client_type, #(#args),*) #output #body
192 }
193 }
194
195 pub fn generate_server_ffi(&self) -> TokenStream {
196 let fn_ffi_name = self.class.ffi_name(&self.name);
197 let input_struct = self.inputs.input_struct(&self.name, Some(&self.class));
198 let state_tn = &self.class.state_tn;
199 let client_tn = &self.class.client_tn;
200 let mut call_args = Vec::new();
201
202 let state_retrieval = quote! {
203 let state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
204 Ok(state) => state,
205 Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
206 };
207 let state = state_ptr.as_ref::<#state_tn>();
208 };
209 let client_retrieval = quote! {
211 let client: #client_tn = match ::pyroduct::ffi::guest::deserialize_input(client_state_ptr) {
212 Ok(buf) => buf,
213 Err(err) => return err.encode().view(),
214 };
215 };
216 call_args.push(quote! { &client });
217 let input_retrieval = match &self.inputs {
218 InputParams::One(_, ty) => {
219 call_args.push(quote!(input));
220 quote! {
221 let input: #ty = match ::pyroduct::ffi::guest::deserialize_input(input_ptr) {
222 Ok(buf) => buf,
223 Err(err) => return err.encode().view(),
224 };
225 }
226 }
227 InputParams::Many(items) => {
228 let input_struct_name = self.class.input_struct(&self.name);
229 let args = items.iter().map(|(n, _)| quote!(input.#n));
230 call_args.extend(args);
231 quote! {
232 let input: #input_struct_name = match ::pyroduct::ffi::guest::deserialize_input(input_ptr) {
233 Ok(buf) => buf,
234 Err(err) => return err.encode().view(),
235 };
236 }
237 }
238 InputParams::None => quote! {},
239 };
240 let fn_name = &self.name.0;
241 let method_call = quote!(state.#fn_name(#(#call_args),*));
242
243 let (ffi_ret, body) = match (self.is_async, &self.class.error_tn) {
245 (true, Some(_)) => (
246 quote!(::pyroduct::ffi::FuturePyroView),
247 quote! {
248 ::pyroduct::ffi::guest::execute_safe_async(|| async move {
249 #state_retrieval
250 #client_retrieval
251 #input_retrieval
252 ::pyroduct::ffi::guest::serialize_result(#method_call.await)
253 }, capability_state_ptr.object_id, mux_id)
254 },
255 ),
256 (false, Some(_)) => (
257 quote!(::pyroduct::format::PyroViewPtr),
258 quote! {
259 ::pyroduct::ffi::guest::execute_safe(|| {
260 #state_retrieval
261 #client_retrieval
262 #input_retrieval
263 ::pyroduct::ffi::guest::serialize_result(#method_call)
264 }, capability_state_ptr.object_id, mux_id)
265 },
266 ),
267 (true, None) => (
268 quote!(::pyroduct::ffi::FuturePyroView),
269 quote! {
270 ::pyroduct::ffi::guest::execute_safe_async(|| async move {
271 #state_retrieval
272 #client_retrieval
273 #input_retrieval
274 ::pyroduct::ffi::guest::serialize_output(#method_call.await)
275 }, capability_state_ptr.object_id, mux_id)
276 },
277 ),
278 (false, None) => (
279 quote!(::pyroduct::format::PyroViewPtr),
280 quote! {
281 ::pyroduct::ffi::guest::execute_safe(|| {
282 #state_retrieval
283 #client_retrieval
284 #input_retrieval
285 ::pyroduct::ffi::guest::serialize_output(#method_call)
286 }, capability_state_ptr.object_id, mux_id)
287 },
288 ),
289 };
290
291 quote! {
292 #input_struct
293
294 #[unsafe(no_mangle)]
295 pub unsafe extern "C" fn #fn_ffi_name (
296 capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
297 client_state_ptr: ::pyroduct::format::PyroRefPtr,
298 input_ptr: ::pyroduct::format::PyroRefPtr,
299 ) -> #ffi_ret {
300 let (_client_id, mux_id, _class_id, _fn_id) = ::pyroduct::format::get_ref_ids(input_ptr);
301 #body
302 }
303 }
304 }
305
306 pub fn generate_vtable_entry(&self) -> TokenStream {
307 let fn_ffi_name = self.class.ffi_name(&self.name);
308 let wasm_name_ident = self.class.trace_name_static(&self.name);
309
310 let func_variant = if self.is_async {
311 quote! {
312 ::pyroduct::ffi::Function::Async(#fn_ffi_name)
313 }
314 } else {
315 quote! {
316 ::pyroduct::ffi::Function::Sync(#fn_ffi_name)
317 }
318 };
319
320 quote! {
321 ::pyroduct::ffi::MethodExport {
322 name: #wasm_name_ident.as_ptr(),
323 name_len: #wasm_name_ident.len(),
324 func: #func_variant,
325 }
326 }
327 }
328
329 pub fn generate_client_method(&self, module: Option<&Ident>) -> TokenStream {
332 let name = &self.name.0;
333 let attrs = &self.attrs;
334
335 let wasm_call = self.class.wasm_name(&self.name);
336 let wasm_call = match module {
337 Some(m) => quote! {#m::#wasm_call},
338 None => quote! {#wasm_call},
339 };
340
341 let wasm_call = quote! {
342 |client_state_ptr: *const u8,
343 input_ptr: *const u8| {
344 unsafe {
345 #wasm_call(client_state_ptr, input_ptr)
346 }
347 }
348 };
349
350 let mut args: Vec<_> = self.inputs.iter().map(|(n, t)| quote!(#n: #t)).collect();
352 args.insert(0, quote!(&self));
353
354 let i_struct = self.inputs.input_struct(&self.name, Some(&self.class));
355 let i_name = self.inputs.input_type(&self.name, Some(&self.class));
356 let i_fill = self
357 .inputs
358 .input_serialization(&self.name, Some(&self.class));
359 match &self.class.error_tn {
360 Some(err_tn) => {
361 let output_type = match &self.output {
362 FnOutput::Result(output_type, _) => output_type,
363 _ => unreachable!(),
364 };
365 let output_return = &self.output.to_return_type();
366 quote! {
367 #(#attrs)*
368 fn #name(#(#args),*) #output_return {
369 #i_struct
370
371 self.__call_result_from_wasm::<#i_name, #output_type, #err_tn, _>(#i_fill, #wasm_call)
372 }
373 }
374 }
375 None => {
376 let output_type = &self.output.ty();
377 let output_return = &self.output.to_return_type();
378 match self.output.err() {
379 Some(err) => {
380 quote! {
381 #(#attrs)*
382 fn #name(#(#args),*) #output_return {
383 #i_struct
384
385 self.__call_result_from_wasm::<#i_name, #output_type, #err, _>(#i_fill, #wasm_call)
386 }
387 }
388 }
389 None => {
390 quote! {
391 #(#attrs)*
392 fn #name(#(#args),*) #output_return {
393 #i_struct
394
395 self.__call_from_wasm::<#i_name, #output_type, _>(#i_fill, #wasm_call)
396 }
397 }
398 }
399 }
400 }
401 }
402 }
403
404 pub fn generate_client_wasm(&self) -> TokenStream {
405 let fn_wasm_name = self.class.wasm_name(&self.name);
406 quote! {
407 pub fn #fn_wasm_name(
408 cs_ptr: *const u8,
409 in_ptr: *const u8,
410 ) -> *mut u8;
411 }
412 }
413
414 pub fn doc_attrs(&self) -> Vec<&syn::Attribute> {
415 self.attrs
416 .iter()
417 .filter(|attr| attr.path().is_ident("doc"))
418 .collect()
419 }
420}
421
#[cfg(test)]
mod tests {
    use crate::fmt::assert_code_eq_token;

    use super::*;
    use quote::format_ident;
    use syn::parse_quote;

    /// Builds a capability named `cap_name` (version `0.1.0`) with state type
    /// `MyServer`, client type `MyClient`, and an optional error type parsed
    /// from `error`.
    fn mock_class(error: Option<&str>) -> Rc<CapabilityIdent> {
        Rc::new(CapabilityIdent {
            pkg_name: "cap_name".to_string(),
            pkg_version: "0.1.0".to_string(),
            config_tn: None,
            state_tn: format_ident!("MyServer"),
            client_tn: format_ident!("MyClient"),
            error_tn: error.map(|s| syn::parse_str(s).unwrap()),
        })
    }

    /// A `&mut self` receiver must survive into the generated server method,
    /// which is also made `pub`.
    #[test]
    fn test_server_method_preserves_mutability() {
        let class = mock_class(None);

        let f: ImplItemFn = parse_quote! {
            fn update(&mut self, ctx: &MyClient, val: u32) {
                self.val = val;
            }
        };

        let method = ImplMethod::parse(&f, &class, false).unwrap();
        let output = method.generate_server_method();

        let expected = quote! {
            pub fn update(&mut self, ctx: &MyClient, val: u32) {
                self.val = val;
            }
        };

        assert_code_eq_token(&output, &expected);
    }

    /// The client stub always takes `&self`, even when the server method
    /// takes `&mut self`.
    #[test]
    fn test_client_method_forces_immutability() {
        let class = mock_class(None);
        let module = format_ident!("wasm_bridge");

        let f: ImplItemFn = parse_quote! {
            fn update(&mut self, ctx: &MyClient, val: u32) { }
        };

        let method = ImplMethod::parse(&f, &class, false).unwrap();
        let output = method.generate_client_method(Some(&module));

        // Checked on the stringified token stream, hence the spaced-out tokens.
        let output_str = output.to_string();
        assert!(output_str.contains("fn update (& self"));
        assert!(!output_str.contains("& mut self"));
    }

    /// The user's chosen name for the client parameter (`c` here) is kept in
    /// the regenerated server method.
    #[test]
    fn test_parse_validates_client_arg_name_capture() {
        let class = mock_class(None);

        let f: ImplItemFn = parse_quote! {
            fn get(&self, c: &MyClient) -> u32 { 10 }
        };

        let method = ImplMethod::parse(&f, &class, false).unwrap();
        let output = method.generate_server_method();

        let expected = quote! {
            pub fn get(&self, c: &MyClient) -> u32 { 10 }
        };

        assert_code_eq_token(&output, &expected);
    }

    /// A by-value `self` receiver is rejected during parsing.
    #[test]
    fn test_reject_value_self() {
        let class = mock_class(None);

        let f: ImplItemFn = parse_quote! {
            fn consume(self, _c: &MyClient) {}
        };

        let err = ImplMethod::parse(&f, &class, false).unwrap_err();
        assert!(err.to_string().contains("not value self"));
    }

    /// Constructs an `ImplMethod` directly, bypassing `parse`.
    ///
    /// The capability uses state `MockServer` and client `MockClient`. With
    /// `has_err` the capability error type is `MockError` and the output is
    /// `Result<u32, MockError>`; otherwise there is no error type and the
    /// output is a plain `u32`. Inputs start as `InputParams::None`, the
    /// receiver is `&self`, and the body is `{ 0 }`.
    fn mock_method_base(name: &str, is_async: bool, has_err: bool) -> ImplMethod {
        let (error_tn, output) = if has_err {
            (
                Some(parse_quote!(MockError)),
                FnOutput::Result(parse_quote!(u32), parse_quote!(MockError)),
            )
        } else {
            (None, FnOutput::Single(parse_quote!(u32)))
        };
        let class = Rc::new(CapabilityIdent {
            pkg_name: "cap_name".to_string(),
            pkg_version: "0.1.0".to_string(),
            config_tn: None,
            state_tn: format_ident!("MockServer"),
            client_tn: format_ident!("MockClient"),
            error_tn,
        });

        ImplMethod {
            name: FnName(format_ident!("{}", name)),
            class,
            client_param: format_ident!("client"),
            inputs: InputParams::None,
            output,
            is_async,
            is_mutable_self: false,
            body: parse_quote!({ 0 }),
            attrs: vec![],
        }
    }

    /// Async method, no inputs, no error type: the FFI export uses
    /// `execute_safe_async` + `serialize_output` and only forwards `&client`.
    /// The client stub calls `__call_from_wasm` with `()` input and `None`.
    #[test]
    fn test_case_4_async_no_input_with_client() {
        let ffi = mock_method_base("test_async_client", true, false);

        let capability_tokens = ffi.generate_server_ffi();
        let module_tokens = ffi.generate_client_method(None);
        // Wrap the stub in an impl so it can be compared as a complete item.
        let module_tokens = quote! {
            impl Mod {
                #module_tokens
            }
        };

        let output_capability = quote! {
            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__mock_server__test_async_client__ffi(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
                client_state_ptr: ::pyroduct::format::PyroRefPtr,
                input_ptr: ::pyroduct::format::PyroRefPtr,
            ) -> ::pyroduct::ffi::FuturePyroView {
                let (_client_id, mux_id, _class_id, _fn_id) = ::pyroduct::format::get_ref_ids(input_ptr);
                ::pyroduct::ffi::guest::execute_safe_async(|| async move {
                    let state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<MockServer>();

                    let client: MockClient = match ::pyroduct::ffi::guest::deserialize_input(client_state_ptr) {
                        Ok(buf) => buf,
                        Err(err) => return err.encode().view(),
                    };
                    ::pyroduct::ffi::guest::serialize_output(state.test_async_client(&client).await)
                }, capability_state_ptr.object_id, mux_id)
            }
        };

        crate::fmt::assert_code_eq_token(&capability_tokens, &output_capability);

        let output_module = quote! {
            impl Mod {
                fn test_async_client(&self) -> u32 {
                    self.__call_from_wasm::<
                        (),
                        u32,
                        _,
                    >(
                        None,
                        |client_state_ptr: *const u8,
                        input_ptr: *const u8| {
                            unsafe {
                                p__mock_server__test_async_client__wasm(
                                    client_state_ptr,
                                    input_ptr
                                )
                            }
                        }
                    )
                }
            }
        };

        crate::fmt::assert_code_eq_token(&module_tokens, &output_module);
    }

    /// Sync method, one input, with an error type: the FFI export uses
    /// `execute_safe` + `serialize_result` and deserializes the single input
    /// directly as `i32`. The client stub calls `__call_result_from_wasm`
    /// with `Some(&x)`.
    #[test]
    fn test_case_5_sync_single_input_with_client() {
        let mut ffi = mock_method_base("test_sync_client_input", false, true);
        ffi.inputs = InputParams::One(format_ident!("x"), parse_quote!(i32));
        let capability_tokens = ffi.generate_server_ffi();
        let module_tokens = ffi.generate_client_method(None);
        let module_tokens = quote! {
            impl Mod {
                #module_tokens
            }
        };

        let output_capability = quote! {
            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__mock_server__test_sync_client_input__ffi(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
                client_state_ptr: ::pyroduct::format::PyroRefPtr,
                input_ptr: ::pyroduct::format::PyroRefPtr,
            ) -> ::pyroduct::format::PyroViewPtr {
                let (_client_id, mux_id, _class_id, _fn_id) = ::pyroduct::format::get_ref_ids(input_ptr);
                ::pyroduct::ffi::guest::execute_safe(|| {
                    let state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<MockServer>();

                    let client: MockClient = match ::pyroduct::ffi::guest::deserialize_input(client_state_ptr) {
                        Ok(buf) => buf,
                        Err(err) => return err.encode().view(),
                    };

                    let input: i32 = match ::pyroduct::ffi::guest::deserialize_input(input_ptr) {
                        Ok(buf) => buf,
                        Err(err) => return err.encode().view(),
                    };
                    ::pyroduct::ffi::guest::serialize_result(state.test_sync_client_input(&client, input))
                }, capability_state_ptr.object_id, mux_id)
            }
        };

        crate::fmt::assert_code_eq_token(&capability_tokens, &output_capability);

        let output_module = quote! {
            impl Mod {
                fn test_sync_client_input(&self, x: i32) -> Result<u32, MockError> {
                    self.__call_result_from_wasm::<
                        i32,
                        u32,
                        MockError,
                        _,
                    >(
                        Some(&x),
                        |client_state_ptr: *const u8,
                        input_ptr: *const u8| {
                            unsafe {
                                p__mock_server__test_sync_client_input__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        },
                    )
                }
            }
        };

        crate::fmt::assert_code_eq_token(&module_tokens, &output_module);
    }

    /// Async method, several inputs, no error type. A `#[::pyroduct::magma]`
    /// input struct is emitted on both sides. The FFI export deserializes it
    /// and forwards its fields in declaration order; the client stub builds it
    /// from its arguments.
    #[test]
    fn test_case_full_sci() {
        let mut ffi = mock_method_base("test_sci_multi", true, false);
        ffi.inputs = InputParams::Many(vec![
            (format_ident!("a"), parse_quote!(i32)),
            (format_ident!("b"), parse_quote!(i32)),
        ]);

        let capability_tokens = ffi.generate_server_ffi();
        let module_tokens = ffi.generate_client_method(None);
        let module_tokens = quote! {
            impl Mod {
                #module_tokens
            }
        };

        let output_capability = quote! {
            #[::pyroduct::magma]
            struct p__MockServer__TestSciMulti__Input {
                pub a: i32,
                pub b: i32,
            }

            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__mock_server__test_sci_multi__ffi(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
                client_state_ptr: ::pyroduct::format::PyroRefPtr,
                input_ptr: ::pyroduct::format::PyroRefPtr,
            ) -> ::pyroduct::ffi::FuturePyroView {
                let (_client_id, mux_id, _class_id, _fn_id) = ::pyroduct::format::get_ref_ids(input_ptr);
                ::pyroduct::ffi::guest::execute_safe_async(|| async move {
                    let state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<MockServer>();

                    let client: MockClient = match ::pyroduct::ffi::guest::deserialize_input(client_state_ptr) {
                        Ok(buf) => buf,
                        Err(err) => return err.encode().view(),
                    };

                    let input: p__MockServer__TestSciMulti__Input = match ::pyroduct::ffi::guest::deserialize_input(input_ptr) {
                        Ok(buf) => buf,
                        Err(err) => return err.encode().view(),
                    };
                    ::pyroduct::ffi::guest::serialize_output(state.test_sci_multi(&client, input.a, input.b).await)
                }, capability_state_ptr.object_id, mux_id)
            }
        };

        crate::fmt::assert_code_eq_token(&capability_tokens, &output_capability);

        let output_module = quote! {
            impl Mod {
                fn test_sci_multi(&self, a: i32, b: i32) -> u32 {
                    #[::pyroduct::magma]
                    struct p__MockServer__TestSciMulti__Input {
                        pub a: i32,
                        pub b: i32,
                    }
                    self.__call_from_wasm::<
                        p__MockServer__TestSciMulti__Input,
                        u32,
                        _,
                    >(
                        Some(
                            &p__MockServer__TestSciMulti__Input {
                                a,
                                b,
                            },
                        ),
                        |client_state_ptr: *const u8,
                        input_ptr: *const u8| {
                            unsafe {
                                p__mock_server__test_sci_multi__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        },
                    )
                }
            }
        };

        crate::fmt::assert_code_eq_token(&module_tokens, &output_module);
    }
}