1use std::rc::Rc;
17
18use heck::AsSnakeCase;
19use proc_macro2::TokenStream;
20use quote::{format_ident, quote};
21use syn::{Error, Ident, ImplItem, ItemImpl, Type, parse_quote};
22
23use crate::{
24 ffi::{
25 lifecycle::{InitFn, NewClientFn, ResetFn},
26 methods::ImplMethod,
27 paths::CapabilityIdent,
28 },
29 utils::extract_ident_from_type,
30};
31
32#[derive(Debug)]
34pub struct CapabilityImpl {
35 pub ident: Rc<CapabilityIdent>,
37
38 pub init_fn: InitFn,
40 pub reset_fn: ResetFn,
41 pub register_fn: NewClientFn,
42
43 pub methods: Vec<ImplMethod>,
45
46 pub other_items: Vec<ImplItem>,
48 pub attrs: Vec<syn::Attribute>,
49}
50
51impl CapabilityImpl {
52 pub fn new(
53 input: ItemImpl,
54 required_docs: bool,
55 cap_name: &str,
56 cap_semver: &str,
57 ) -> syn::Result<Self> {
58 let state_tn =
60 match &*input.self_ty {
61 Type::Path(tp) => tp.path.get_ident().cloned().ok_or_else(|| {
62 Error::new_spanned(&input.self_ty, "Expected simple type name")
63 })?,
64 _ => {
65 return Err(Error::new_spanned(
66 &input.self_ty,
67 "Expected simple type name",
68 ));
69 }
70 };
71
72 if input.trait_.is_some() {
74 return Err(Error::new_spanned(
75 &input,
76 "#[capability] cannot be used on trait implementations",
77 ));
78 }
79 let attrs = input.attrs.clone();
80
81 let mut client_tn: Option<Ident> = None;
83 let mut config_tn: Option<Ident> = None;
84
85 let mut init_fn: Option<InitFn> = None;
86 let mut reset_fn: Option<ResetFn> = None;
87 let mut register_fn: Option<NewClientFn> = None;
88 let mut method_fns = Vec::new();
89 let mut other_items = Vec::new();
90
91 for item in &input.items {
92 if let ImplItem::Type(ty) = item {
93 if ty.ident == "Client" {
94 client_tn = Some(extract_ident_from_type(&ty.ty)?);
95 } else if ty.ident == "Config" {
96 config_tn = Some(extract_ident_from_type(&ty.ty)?);
97 }
98 }
101 }
102
103 let client_tn = client_tn
104 .ok_or_else(|| Error::new_spanned(&state_tn, "Missing `type Client = ...;`"))?;
105
106 let ident = Rc::new(CapabilityIdent {
108 pkg_name: cap_name.to_string(),
109 pkg_version: cap_semver.to_string(),
110 state_tn,
111 client_tn,
112 config_tn,
113 });
114
115 for item in &input.items {
116 match item {
117 ImplItem::Fn(f) => {
118 let name = f.sig.ident.to_string();
119 match name.as_str() {
120 "new" => {
121 let conf = ident.config_tn.clone().map(|t| parse_quote! { #t });
122 init_fn = Some(InitFn::parse(conf, f)?);
123 }
124 "reset" => {
125 reset_fn = Some(ResetFn::parse(f)?);
126 }
127 "register" => {
128 register_fn = Some(NewClientFn::parse(f, &ident)?);
129 }
130 _ => {
131 method_fns.push(f.clone());
133 }
134 }
135 }
136 ImplItem::Type(_) => {
137 }
140 other => other_items.push(other.clone()),
141 }
142 }
143
144 let register_fn = register_fn.ok_or_else(|| {
145 Error::new_spanned(
146 &ident.state_tn,
147 "Missing `fn register(&self, client: &Client)`",
148 )
149 })?;
150 let init_fn = init_fn.ok_or_else(|| {
151 Error::new_spanned(
152 &ident.state_tn,
153 "Missing `fn new() -> Self` or `fn new(config: &Config) -> Self`",
154 )
155 })?;
156 let reset_fn = reset_fn
157 .ok_or_else(|| Error::new_spanned(&ident.state_tn, "Missing `fn reset(&mut self)`"))?;
158
159 let methods: Result<Vec<_>, _> = method_fns
161 .iter()
162 .map(|f| ImplMethod::parse(f, &ident, required_docs))
163 .collect();
164 let methods = methods?;
165
166 Ok(Self {
167 ident,
168 init_fn,
169 reset_fn,
170 register_fn,
171 methods,
172 other_items,
173 attrs,
174 })
175 }
176
177 pub fn expand_capability(&self) -> TokenStream {
179 let server_impl = self.generate_server_impl();
180 let lifecycle_ffi = self.generate_lifecycle_ffi();
181 let method_ffis = self.generate_method_ffis();
182 let export_table = self.generate_export_table();
183
184 quote! {
185 #server_impl
186 #lifecycle_ffi
187 #method_ffis
188 #export_table
189 }
190 }
191
192 pub fn expand_module(&self) -> TokenStream {
194 let wasm_imports = self.generate_wasm_imports();
195 let client_impl = self.generate_client_impl();
196
197 quote! {
198 #client_impl
199 #wasm_imports
200 }
201 }
202
203 fn generate_server_impl(&self) -> TokenStream {
204 let server = &self.ident.state_tn;
205 let init_method = self.init_fn.generate_impl_method();
206 let reset_method = self.reset_fn.generate_impl_method();
207 let new_client_method = self.register_fn.generate_impl_method();
208 let other_items = &self.other_items;
209
210 let methods: Vec<_> = self
211 .methods
212 .iter()
213 .map(|m| m.generate_server_method())
214 .collect();
215
216 quote! {
217 impl #server {
218 #init_method
219 #reset_method
220 #new_client_method
221 #(#other_items)*
222 #(#methods)*
223 }
224 }
225 }
226
227 fn generate_client_impl(&self) -> TokenStream {
228 let client = &self.ident.client_tn;
229 let module = format_ident!("wasm");
230
231 let client_impl = self.register_fn.generate_client_impl(Some(&module));
233
234 let trait_name = format_ident!("{}Methods", client);
236
237 let trait_methods: Vec<_> = self
238 .methods
239 .iter()
240 .map(|m| {
241 let name = &m.name.0;
242 let output = &m.output.to_return_type();
243 let args: Vec<_> = m.inputs.iter().map(|(n, t)| quote!(#n: #t)).collect();
244 let docs = m.doc_attrs();
245
246 quote! {
247 #(#docs)*
248 fn #name(&self, #(#args),*) #output;
249 }
250 })
251 .collect();
252
253 let trait_def = quote! {
254 pub trait #trait_name {
255 #(#trait_methods)*
256 }
257 };
258
259 let method_impls: Vec<_> = self
261 .methods
262 .iter()
263 .map(|m| m.generate_client_method(Some(&module)))
264 .collect();
265
266 let trait_impl = quote! {
267 impl #trait_name for ::pyroduct::wasm::Client<#client> {
268 #(#method_impls)*
269 }
270 };
271
272 quote! {
273 #client_impl
274 #trait_def
275 #trait_impl
276 }
277 }
278
279 fn generate_lifecycle_ffi(&self) -> TokenStream {
280 let server = &self.ident.state_tn;
281
282 let init_ffi = self.init_fn.generate_ffi(server);
283 let reset_ffi = self.reset_fn.generate_ffi(server);
284 let register_ffi = self.register_fn.generate_capability_ffi();
285
286 quote! {
287 #init_ffi
288 #reset_ffi
289 #register_ffi
290 }
291 }
292
293 fn generate_method_ffis(&self) -> TokenStream {
294 let method_ffis: Vec<_> = self
295 .methods
296 .iter()
297 .map(|m| m.generate_server_ffi())
298 .collect();
299
300 quote! {
301 #(#method_ffis)*
302 }
303 }
304
305 fn generate_export_table(&self) -> TokenStream {
306 let cap_id = self.ident.cap_id();
307
308 let server = &self.ident.state_tn;
309 let server_snake = AsSnakeCase(server.to_string()).to_string();
310 let server_upper = server_snake.to_uppercase();
311
312 let class_name_static = format_ident!("p__{}", server_upper);
313 let class_name_string = format!("{}", server_snake);
314
315 let static_strs: Vec<_> = self
316 .methods
317 .iter()
318 .map(|m| {
319 let trace_name = self.ident.wasm_name(&m.name).to_string();
320 let static_name = self.ident.trace_name_static(&m.name);
321 quote! { const #static_name: &'static str = #trace_name; }
322 })
323 .collect();
324
325 let exports: Vec<_> = self
326 .methods
327 .iter()
328 .map(|ffi| ffi.generate_vtable_entry())
329 .collect();
330
331 let num_exports = exports.len();
332 let exports_array_name = format_ident!("{}__METHODS", class_name_static);
333
334 let init_export = self.init_fn.generate_export(server);
335 let reset_export = self.reset_fn.generate_export(server);
336 let register_export = self.register_fn.generate_export();
337
338 let capability_manifest_fn = quote! {
339 #[unsafe(no_mangle)]
340 pub extern "C" fn pyro_capability_manifest(
341 id: i64,
342 log_callback: ::pyroduct::ffi::LogCallback,
343 ) -> ::pyroduct::ffi::ClassExport {
344 ::pyroduct::ffi::guest::logger::init_logging(id, log_callback);
345
346 ::pyroduct::ffi::ClassExport {
347 name: #class_name_static.as_ptr(),
348 name_len: #class_name_static.len(),
349 len: #exports_array_name.len(),
350 ptr: #exports_array_name.as_ptr() as *mut _,
351 init: #init_export,
352 reset: #reset_export,
353 register: #register_export,
354 }
355 }
356 };
357
358 quote! {
359 const CAPABILITY_NAME_VERSION: &'static str = #cap_id;
360 const #class_name_static: &'static str = #class_name_string;
361 #(#static_strs)*
362
363 const #exports_array_name: [::pyroduct::ffi::MethodExport; #num_exports] = [
364 #(#exports),*
365 ];
366
367 #capability_manifest_fn
368 }
369 }
370
371 fn generate_wasm_imports(&self) -> TokenStream {
372 let class_id = self.ident.class_name();
373 let new_client_decl = self.register_fn.generate_client_wasm();
374
375 let method_decls: Vec<_> = self
376 .methods
377 .iter()
378 .map(|m| m.generate_client_wasm())
379 .collect();
380
381 quote! {
382 mod wasm {
383 use super::*;
384 #[link(wasm_import_module = #class_id)]
385 unsafe extern "C" {
386 #new_client_decl
387 #(#method_decls)*
388 }
389 }
390 }
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse2;

    /// Parses `code` as an `impl` block and runs it through [`CapabilityImpl::new`]
    /// with the package name/version and `required_docs = false` used by every test here.
    fn parse_capability(code: TokenStream) -> syn::Result<CapabilityImpl> {
        let input: ItemImpl = parse2(code).expect("test input must be a valid impl block");
        CapabilityImpl::new(input, false, "cap_name", "0.1.0")
    }

    #[test]
    fn test_basic_capability_impl() {
        let code = quote! {
            impl StatefulServer {
                type Client = SimpleClient;

                fn new() -> Result<Self, CapturedError> { Ok(Self) }
                fn reset(&mut self) -> Result<(), CapturedError> { Ok(()) }
                fn register(&self, _client: &SimpleClient) -> Result<(), CapturedError> { Ok(()) }
                fn call(&self, _client: &SimpleClient) -> Result<f32, CapturedError> { Ok(42.0) }
            }
        };

        let cap = parse_capability(code).unwrap();

        assert_eq!(cap.ident.state_tn.to_string(), "StatefulServer");
        assert_eq!(cap.ident.client_tn.to_string(), "SimpleClient");
        assert_eq!(cap.methods.len(), 1);
        assert_eq!(cap.methods[0].name.to_string(), "call");
        assert!(!cap.init_fn.is_async);
        assert!(cap.init_fn.config_type.is_none());
        assert!(cap.ident.config_tn.is_none());
    }

    #[test]
    fn test_with_config() {
        let code = quote! {
            impl StatefulServer {
                type Config = MyConfig;
                type Client = SimpleClient;

                fn new(config: Option<MyConfig>) -> Result<Self, CapturedError> { Ok(Self) }
                fn reset(&mut self) -> Result<(), CapturedError> { Ok(()) }
                fn register(&self, client: &SimpleClient) -> Result<(), CapturedError> { Ok(()) }
            }
        };

        let cap = parse_capability(code).unwrap();

        assert!(cap.init_fn.config_type.is_some());
        assert!(cap.ident.config_tn.is_some());

        let cfg = cap.ident.config_tn.as_ref().unwrap();
        assert_eq!(quote!(#cfg).to_string(), "MyConfig");
    }

    #[test]
    fn test_config_mismatch() {
        let code = quote! {
            impl StatefulServer {
                type Config = MyConfig;
                type Client = SimpleClient;

                fn new(config: Option<OtherConfig>) -> Result<Self, CapturedError> { Ok(Self) }
                fn reset(&mut self) -> Result<(), CapturedError> { Ok(()) }
                fn register(&self, client: &SimpleClient) -> Result<(), CapturedError> { Ok(()) }
            }
        };

        let err = parse_capability(code).unwrap_err();
        assert!(err.to_string().contains(
            "Type mismatch. Expected 'Option<MyConfig>' based on macro attribute, found 'Option<OtherConfig>'"
        ));
    }

    #[test]
    fn test_async_lifecycle() {
        let code = quote! {
            impl StatefulServer {
                type Client = SimpleClient;

                async fn new() -> Result<Self, CapturedError> { Ok(Self) }
                async fn reset(&mut self) -> Result<(), CapturedError> { Ok(()) }
                fn register(&self, client: &SimpleClient) -> Result<(), CapturedError> { Ok(()) }
            }
        };

        let cap = parse_capability(code).unwrap();

        assert!(cap.init_fn.is_async);
        assert!(cap.reset_fn.is_async);
    }

    #[test]
    fn test_with_error_type_fails() {
        let code = quote! {
            impl StatefulServer {
                type Client = SimpleClient;
                type Error = MyError;

                fn new() -> Result<Self, CapturedError> { Ok(Self) }
                fn reset(&mut self) -> Result<(), CapturedError> { Ok(()) }
                fn register(&self, client: &SimpleClient) -> Result<(), MyError> { Ok(()) }
            }
        };

        let err = parse_capability(code).unwrap_err();
        assert!(
            err.to_string()
                .contains("Invalid error type. Expected 'CapturedError', found 'MyError'")
        );
    }

    #[test]
    fn test_with_captured_error() {
        let code = quote! {
            impl StatefulServer {
                type Client = SimpleClient;

                fn new() -> Result<Self, CapturedError> { Ok(Self) }
                fn reset(&mut self) -> Result<(), CapturedError> { Ok(()) }
                fn register(&self, client: &SimpleClient) -> Result<(), CapturedError> { Ok(()) }
                fn fallible(&self, _client: &SimpleClient) -> Result<u32, CapturedError> { Ok(42) }
            }
        };

        let cap = parse_capability(code).unwrap();

        assert_eq!(cap.methods.len(), 1);
    }

    #[test]
    fn test_generate_export_table() {
        let code = quote! {
            impl TestServer {
                type Client = TestClient;

                fn new() -> Result<Self, CapturedError> { Ok(Self) }
                fn reset(&mut self) -> Result<(), CapturedError> { Ok(()) }
                fn register(&self, client: &TestClient) -> Result<(), CapturedError> { Ok(()) }
                fn get_value(&self, client: &TestClient) -> Result<u32, CapturedError> { Ok(0) }
            }
        };

        let cap = parse_capability(code).unwrap();

        let output = cap.generate_export_table();

        let expected = quote! {
            const CAPABILITY_NAME_VERSION: &'static str = "cap_name";
            const p__TEST_SERVER: &'static str = "test_server";
            const p__TEST_SERVER__GET_VALUE: &'static str = "p__test_server__get_value__wasm";

            const p__TEST_SERVER__METHODS: [::pyroduct::ffi::MethodExport; 1usize] = [
                ::pyroduct::ffi::MethodExport {
                    name: p__TEST_SERVER__GET_VALUE.as_ptr(),
                    name_len: p__TEST_SERVER__GET_VALUE.len(),
                    func: ::pyroduct::ffi::Function::Sync(p__test_server__get_value__ffi),
                }
            ];

            #[unsafe(no_mangle)]
            pub extern "C" fn pyro_capability_manifest(
                id: i64,
                log_callback: ::pyroduct::ffi::LogCallback,
            ) -> ::pyroduct::ffi::ClassExport {
                ::pyroduct::ffi::guest::logger::init_logging(id, log_callback);

                ::pyroduct::ffi::ClassExport {
                    name: p__TEST_SERVER.as_ptr(),
                    name_len: p__TEST_SERVER.len(),
                    len: p__TEST_SERVER__METHODS.len(),
                    ptr: p__TEST_SERVER__METHODS.as_ptr() as *mut _,
                    init: ::pyroduct::ffi::ClassInitFn::Sync(p__test_server__ffi_init),
                    reset: ::pyroduct::ffi::ClassResetFn::Sync(p__test_server__ffi_reset),
                    register: ::pyroduct::ffi::ClientRegisterFn::Sync(p__test_server__register__ffi),
                }
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);

        let output = cap.generate_wasm_imports();

        let expected = quote! {
            mod wasm {
                use super::*;
                #[link(wasm_import_module = "test_server")]
                unsafe extern "C" {
                    pub fn register(ptr: *const u8) -> *mut u8;
                    pub fn p__test_server__get_value__wasm(
                        cs_ptr: *const u8,
                        in_ptr: *const u8,
                    ) -> *mut u8;
                }
            }
        };
        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_generate_client_impl_integration() {
        let code = quote! {
            impl MyState {
                type Client = MyClient;
                type Config = MyConfig;

                fn new(config: Option<MyConfig>) -> Result<Self, CapturedError> { Ok(Self) }
                fn reset(&mut self) -> Result<(), CapturedError> { Ok(()) }
                fn register(&self, client: &MyClient) -> Result<(), CapturedError> { Ok(()) }
                fn get_info(&self, client: &MyClient) -> Result<u32, CapturedError> { Ok(0) }
                fn get_other_info(&self, client: &MyClient, data: f32) -> Result<u32, CapturedError> { Ok(0) }
            }
        };

        let cap = parse_capability(code).unwrap();

        let output = cap.generate_client_impl();

        let expected = quote! {
            impl MyClient {
                pub fn register(self) -> Result<::pyroduct::wasm::Client<Self>, ::pyroduct::CapturedError> {
                    ::pyroduct::wasm::Client::<Self>::__register_result(self, |ptr| unsafe { wasm::register(ptr) })
                }
            }
            pub trait MyClientMethods {
                fn get_info(&self) -> Result<u32, ::pyroduct::CapturedError>;
                fn get_other_info(&self, data: f32) -> Result<u32, ::pyroduct::CapturedError>;
            }
            impl MyClientMethods for ::pyroduct::wasm::Client<MyClient> {
                fn get_info(&self) -> Result<u32, ::pyroduct::CapturedError> {
                    self.__call_result_from_wasm::<(), u32, _>(None,
                        |client_state_ptr: *const u8,
                         input_ptr: *const u8| {
                            unsafe {
                                wasm::p__my_state__get_info__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        })
                }

                fn get_other_info(&self, data: f32) -> Result<u32, ::pyroduct::CapturedError> {
                    self.__call_result_from_wasm::<
                        f32,
                        u32,
                        _,
                    >(Some(&data),
                        |client_state_ptr: *const u8,
                         input_ptr: *const u8| {
                            unsafe {
                                wasm::p__my_state__get_other_info__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        },
                    )
                }
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_generate_client_impl_with_error_and_input_structs() {
        let code = quote! {
            impl AdvancedStruct {
                type Client = AdvancedClient;

                fn new() -> Result<Self, CapturedError> { Ok(Self) }
                fn reset(&mut self) -> Result<(), CapturedError> { Ok(()) }

                fn register(&self, client: &AdvancedClient) -> Result<(), CapturedError> {
                    Ok(())
                }

                async fn process(&self, client: &AdvancedClient, val: u32, flag: bool) -> Result<u32, CapturedError> {
                    Ok(val)
                }
            }
        };

        let cap = parse_capability(code).unwrap();

        let output = cap.generate_client_impl();

        let expected = quote! {
            impl AdvancedClient {
                pub fn register(self) -> Result<::pyroduct::wasm::Client<Self>, ::pyroduct::CapturedError> {
                    ::pyroduct::wasm::Client::<Self>::__register_result(self, |ptr| unsafe { wasm::register(ptr) })
                }
            }
            pub trait AdvancedClientMethods {
                fn process(&self, val: u32, flag: bool) -> Result<u32, ::pyroduct::CapturedError>;
            }
            impl AdvancedClientMethods for ::pyroduct::wasm::Client<AdvancedClient> {
                fn process(&self, val: u32, flag: bool) -> Result<u32, ::pyroduct::CapturedError> {
                    #[::pyroduct::magma]
                    struct p__AdvancedStruct__Process__Input {
                        pub val: u32,
                        pub flag: bool
                    }

                    self.__call_result_from_wasm::<
                        p__AdvancedStruct__Process__Input,
                        u32,
                        _
                    >(
                        Some(&p__AdvancedStruct__Process__Input { val, flag }),
                        |client_state_ptr: *const u8,
                         input_ptr: *const u8| {
                            unsafe {
                                wasm::p__advanced_struct__process__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        }
                    )
                }
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }
}