1use std::rc::Rc;
17
18use heck::AsSnakeCase;
19use proc_macro2::TokenStream;
20use quote::{format_ident, quote};
21use syn::{Error, Ident, ImplItem, ItemImpl, Type, parse_quote};
22
23use crate::{
24 ffi::{
25 lifecycle::{InitFn, NewClientFn, ResetFn},
26 methods::ImplMethod,
27 paths::CapabilityIdent,
28 },
29 utils::extract_ident_from_type,
30};
31
/// Parsed form of a `#[capability]`-annotated inherent `impl` block.
///
/// Built by [`CapabilityImpl::new`]. The capability side is generated from it by
/// `expand_capability`, and the guest/client side by `expand_module`.
#[derive(Debug)]
pub struct CapabilityImpl {
    /// Names shared by all generated code: package name and version, plus the
    /// state, client, (optional) config and (optional) error types.
    pub ident: Rc<CapabilityIdent>,

    /// The parsed `fn new` lifecycle function.
    pub init_fn: InitFn,
    /// The parsed `fn reset` lifecycle function.
    pub reset_fn: ResetFn,
    /// The parsed `fn register` lifecycle function.
    pub register_fn: NewClientFn,

    /// Every other `fn` in the impl block, exposed as a capability method.
    pub methods: Vec<ImplMethod>,

    /// Items that are neither `fn` nor `type`; re-emitted verbatim into the
    /// generated server `impl`.
    pub other_items: Vec<ImplItem>,
    /// Attributes copied from the original `impl` block.
    pub attrs: Vec<syn::Attribute>,
}
50
51impl CapabilityImpl {
52 pub fn new(
53 input: ItemImpl,
54 required_docs: bool,
55 cap_name: &str,
56 cap_semver: &str,
57 ) -> syn::Result<Self> {
58 let state_tn =
60 match &*input.self_ty {
61 Type::Path(tp) => tp.path.get_ident().cloned().ok_or_else(|| {
62 Error::new_spanned(&input.self_ty, "Expected simple type name")
63 })?,
64 _ => {
65 return Err(Error::new_spanned(
66 &input.self_ty,
67 "Expected simple type name",
68 ));
69 }
70 };
71
72 if input.trait_.is_some() {
74 return Err(Error::new_spanned(
75 &input,
76 "#[capability] cannot be used on trait implementations",
77 ));
78 }
79 let attrs = input.attrs.clone();
80
81 let mut client_tn: Option<Ident> = None;
83 let mut config_tn: Option<Ident> = None;
84 let mut error_tn: Option<Type> = None;
85
86 let mut init_fn: Option<InitFn> = None;
87 let mut reset_fn: Option<ResetFn> = None;
88 let mut register_fn: Option<NewClientFn> = None;
89 let mut method_fns = Vec::new();
90 let mut other_items = Vec::new();
91
92 for item in &input.items {
93 match item {
94 ImplItem::Type(ty) => {
95 if ty.ident == "Client" {
96 client_tn = Some(extract_ident_from_type(&ty.ty)?);
97 } else if ty.ident == "Config" {
98 config_tn = Some(extract_ident_from_type(&ty.ty)?);
99 } else if ty.ident == "Error" {
100 error_tn = Some(ty.ty.clone());
101 }
102 }
105 _ => {}
106 }
107 }
108
109 let client_tn = client_tn
110 .ok_or_else(|| Error::new_spanned(&state_tn, "Missing `type Client = ...;`"))?;
111
112 let ident = Rc::new(CapabilityIdent {
114 pkg_name: cap_name.to_string(),
115 pkg_version: cap_semver.to_string(),
116 state_tn,
117 client_tn,
118 config_tn,
119 error_tn,
120 });
121
122 for item in &input.items {
123 match item {
124 ImplItem::Fn(f) => {
125 let name = f.sig.ident.to_string();
126 match name.as_str() {
127 "new" => {
128 let conf = ident.config_tn.clone().map(|t| parse_quote! { #t });
129 init_fn = Some(InitFn::parse(conf, f)?);
130 }
131 "reset" => {
132 reset_fn = Some(ResetFn::parse(f)?);
133 }
134 "register" => {
135 register_fn = Some(NewClientFn::parse(f, &ident)?);
136 }
137 _ => {
138 method_fns.push(f.clone());
140 }
141 }
142 }
143 ImplItem::Type(_) => {
144 }
147 other => other_items.push(other.clone()),
148 }
149 }
150
151 let register_fn = register_fn.ok_or_else(|| {
152 Error::new_spanned(
153 &ident.state_tn,
154 "Missing `fn register(&self, client: &Client)`",
155 )
156 })?;
157 let init_fn = init_fn.ok_or_else(|| {
158 Error::new_spanned(
159 &ident.state_tn,
160 "Missing `fn new() -> Self` or `fn new(config: &Config) -> Self`",
161 )
162 })?;
163 let reset_fn = reset_fn
164 .ok_or_else(|| Error::new_spanned(&ident.state_tn, "Missing `fn reset(&mut self)`"))?;
165
166 let methods: Result<Vec<_>, _> = method_fns
168 .iter()
169 .map(|f| ImplMethod::parse(f, &ident, required_docs))
170 .collect();
171 let methods = methods?;
172
173 Ok(Self {
174 ident,
175 init_fn,
176 reset_fn,
177 register_fn,
178 methods,
179 other_items,
180 attrs,
181 })
182 }
183
184 pub fn expand_capability(&self) -> TokenStream {
186 let server_impl = self.generate_server_impl();
187 let lifecycle_ffi = self.generate_lifecycle_ffi();
188 let method_ffis = self.generate_method_ffis();
189 let export_table = self.generate_export_table();
190
191 quote! {
192 #server_impl
193 #lifecycle_ffi
194 #method_ffis
195 #export_table
196 }
197 }
198
199 pub fn expand_module(&self) -> TokenStream {
201 let wasm_imports = self.generate_wasm_imports();
202 let client_impl = self.generate_client_impl();
203
204 quote! {
205 #client_impl
206 #wasm_imports
207 }
208 }
209
210 fn generate_server_impl(&self) -> TokenStream {
211 let server = &self.ident.state_tn;
212 let init_method = self.init_fn.generate_impl_method();
213 let reset_method = self.reset_fn.generate_impl_method();
214 let new_client_method = self.register_fn.generate_impl_method();
215 let other_items = &self.other_items;
216
217 let methods: Vec<_> = self
218 .methods
219 .iter()
220 .map(|m| m.generate_server_method())
221 .collect();
222
223 quote! {
224 impl #server {
225 #init_method
226 #reset_method
227 #new_client_method
228 #(#other_items)*
229 #(#methods)*
230 }
231 }
232 }
233
234 fn generate_client_impl(&self) -> TokenStream {
235 let client = &self.ident.client_tn;
236 let module = format_ident!("wasm");
237
238 let client_impl = self.register_fn.generate_client_impl(Some(&module));
240
241 let trait_name = format_ident!("{}Methods", client);
243
244 let trait_methods: Vec<_> = self
245 .methods
246 .iter()
247 .map(|m| {
248 let name = &m.name.0;
249 let output = &m.output.to_return_type();
250 let args: Vec<_> = m.inputs.iter().map(|(n, t)| quote!(#n: #t)).collect();
251 let docs = m.doc_attrs();
252
253 quote! {
254 #(#docs)*
255 fn #name(&self, #(#args),*) #output;
256 }
257 })
258 .collect();
259
260 let trait_def = quote! {
261 pub trait #trait_name {
262 #(#trait_methods)*
263 }
264 };
265
266 let method_impls: Vec<_> = self
268 .methods
269 .iter()
270 .map(|m| m.generate_client_method(Some(&module)))
271 .collect();
272
273 let trait_impl = quote! {
274 impl #trait_name for ::pyroduct::wasm::Client<#client> {
275 #(#method_impls)*
276 }
277 };
278
279 quote! {
280 #client_impl
281 #trait_def
282 #trait_impl
283 }
284 }
285
286 fn generate_lifecycle_ffi(&self) -> TokenStream {
287 let server = &self.ident.state_tn;
288
289 let init_ffi = self.init_fn.generate_ffi(server);
290 let reset_ffi = self.reset_fn.generate_ffi(server);
291 let register_ffi = self.register_fn.generate_capability_ffi();
292
293 quote! {
294 #init_ffi
295 #reset_ffi
296 #register_ffi
297 }
298 }
299
300 fn generate_method_ffis(&self) -> TokenStream {
301 let method_ffis: Vec<_> = self
302 .methods
303 .iter()
304 .map(|m| m.generate_server_ffi())
305 .collect();
306
307 quote! {
308 #(#method_ffis)*
309 }
310 }
311
312 fn generate_export_table(&self) -> TokenStream {
313 let cap_id = self.ident.cap_id();
314
315 let server = &self.ident.state_tn;
316 let server_snake = AsSnakeCase(server.to_string()).to_string();
317 let server_upper = server_snake.to_uppercase();
318
319 let class_name_static = format_ident!("p__{}", server_upper);
320 let class_name_string = format!("p__{}", server_snake);
321
322 let static_strs: Vec<_> = self
323 .methods
324 .iter()
325 .map(|m| {
326 let trace_name = self.ident.wasm_name(&m.name).to_string();
327 let static_name = self.ident.trace_name_static(&m.name);
328 quote! { const #static_name: &'static str = #trace_name; }
329 })
330 .collect();
331
332 let exports: Vec<_> = self
333 .methods
334 .iter()
335 .map(|ffi| ffi.generate_vtable_entry())
336 .collect();
337
338 let num_exports = exports.len();
339 let exports_array_name = format_ident!("{}__METHODS", class_name_static);
340
341 let init_export = self.init_fn.generate_export(server);
342 let reset_export = self.reset_fn.generate_export(server);
343 let register_export = self.register_fn.generate_export();
344
345 let capability_manifest_fn = quote! {
346 #[unsafe(no_mangle)]
347 pub extern "C" fn pyro_capability_manifest(
348 id: i64,
349 log_callback: ::pyroduct::ffi::LogCallback,
350 ) -> ::pyroduct::ffi::ClassExport {
351 ::pyroduct::ffi::guest::logger::init_logging(id, log_callback);
352
353 ::pyroduct::ffi::ClassExport {
354 name: CAPABILITY_NAME_VERSION.as_ptr(),
355 name_len: CAPABILITY_NAME_VERSION.len(),
356 len: #exports_array_name.len(),
357 ptr: #exports_array_name.as_ptr() as *mut _,
358 init: #init_export,
359 reset: #reset_export,
360 register: #register_export,
361 }
362 }
363 };
364
365 quote! {
366 const CAPABILITY_NAME_VERSION: &'static str = #cap_id;
367 const #class_name_static: &'static str = #class_name_string;
368 #(#static_strs)*
369
370 const #exports_array_name: [::pyroduct::ffi::MethodExport; #num_exports] = [
371 #(#exports),*
372 ];
373
374 #capability_manifest_fn
375 }
376 }
377
378 fn generate_wasm_imports(&self) -> TokenStream {
379 let cap_id = self.ident.cap_id();
380 let new_client_decl = self.register_fn.generate_client_wasm();
381
382 let method_decls: Vec<_> = self
383 .methods
384 .iter()
385 .map(|m| m.generate_client_wasm())
386 .collect();
387
388 quote! {
389 mod wasm {
390 use super::*;
391 #[link(wasm_import_module = #cap_id)]
392 unsafe extern "C" {
393 #new_client_decl
394 #(#method_decls)*
395 }
396 }
397 }
398 }
399}
400
// Unit tests for parsing `#[capability]` impl blocks and for the exact token
// output of the export-table and client-side generators.
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse2;

    // Minimal capability: only `type Client` plus `new`/`reset`/`register` and
    // one extra fn. The extra fn (`call`) must be the single parsed method, and
    // no config type is recorded.
    #[test]
    fn test_basic_capability_impl() {
        let code = quote! {
            impl StatefulServer {
                type Client = SimpleClient;

                fn new() -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, _client: &SimpleClient) {}
                fn call(&self, _client: &SimpleClient) -> f32 { 42.0 }
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        assert_eq!(cap.ident.state_tn.to_string(), "StatefulServer");
        assert_eq!(cap.ident.client_tn.to_string(), "SimpleClient");
        assert_eq!(cap.methods.len(), 1);
        assert_eq!(cap.methods[0].name.to_string(), "call");
        assert!(!cap.init_fn.is_async);
        assert!(cap.init_fn.config_type.is_none());
        assert!(cap.ident.config_tn.is_none());
    }

    // `type Config` is recorded on the ident, and `new(config: Option<MyConfig>)`
    // is accepted with a config type on the init fn.
    #[test]
    fn test_with_config() {
        let code = quote! {
            impl StatefulServer {
                type Config = MyConfig;
                type Client = SimpleClient;

                fn new(config: Option<MyConfig>) -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, client: &SimpleClient) {}
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        assert!(cap.init_fn.config_type.is_some());
        assert!(cap.ident.config_tn.is_some());

        let cfg = cap.ident.config_tn.as_ref().unwrap();
        assert_eq!(quote!(#cfg).to_string(), "MyConfig");
    }

    // A `new` whose config parameter disagrees with `type Config` must be
    // rejected. The message checked here presumably comes from `InitFn::parse`.
    #[test]
    fn test_config_mismatch() {
        let code = quote! {
            impl StatefulServer {
                type Config = MyConfig;
                type Client = SimpleClient;

                fn new(config: Option<OtherConfig>) -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, client: &SimpleClient) {}
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let err = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap_err();
        println!("{}", err);
        assert!(err.to_string().contains("Type mismatch. Expected 'Option<MyConfig>' based on macro attribute, found 'Option<OtherConfig>'"));
    }

    // `async fn new` / `async fn reset` are flagged as async on the parsed
    // lifecycle fns.
    #[test]
    fn test_async_lifecycle() {
        let code = quote! {
            impl StatefulServer {
                type Client = SimpleClient;

                async fn new() -> Self { Self }
                async fn reset(&mut self) {}
                fn register(&self, client: &SimpleClient) {}
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        assert!(cap.init_fn.is_async);
        assert!(cap.reset_fn.is_async);
    }

    // `type Error` is recorded, and a `register` returning `Result<(), MyError>`
    // exposes an error type. The fallible extra fn is still a regular method.
    #[test]
    fn test_with_error_type() {
        let code = quote! {
            impl StatefulServer {
                type Client = SimpleClient;
                type Error = MyError;

                fn new() -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, client: &SimpleClient) -> Result<(), MyError> { Ok(()) }
                fn fallible(&self, _client: &SimpleClient) -> Result<u32, MyError> { Ok(42) }
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        assert!(cap.ident.error_tn.is_some());
        assert!(cap.register_fn.error_type.is_some());
        assert_eq!(cap.methods.len(), 1);
    }

    // Pins the exact tokens of the export table and `pyro_capability_manifest`
    // for a single-method, all-sync capability.
    #[test]
    fn test_generate_export_table() {
        let code = quote! {
            impl TestServer {
                type Client = TestClient;

                fn new() -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, client: &TestClient) {}
                fn get_value(&self, client: &TestClient) -> u32 { 0 }
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        let output = cap.generate_export_table();

        let expected = quote! {
            const CAPABILITY_NAME_VERSION: &'static str = "cap_name";
            const p__TEST_SERVER: &'static str = "p__test_server";
            const p__TEST_SERVER__GET_VALUE: &'static str = "p__test_server__get_value__wasm";

            const p__TEST_SERVER__METHODS: [::pyroduct::ffi::MethodExport; 1usize] = [
                ::pyroduct::ffi::MethodExport {
                    name: p__TEST_SERVER__GET_VALUE.as_ptr(),
                    name_len: p__TEST_SERVER__GET_VALUE.len(),
                    func: ::pyroduct::ffi::Function::Sync(p__test_server__get_value__ffi),
                }
            ];

            #[unsafe(no_mangle)]
            pub extern "C" fn pyro_capability_manifest(
                id: i64,
                log_callback: ::pyroduct::ffi::LogCallback,
            ) -> ::pyroduct::ffi::ClassExport {
                ::pyroduct::ffi::guest::logger::init_logging(id, log_callback);

                ::pyroduct::ffi::ClassExport {
                    name: CAPABILITY_NAME_VERSION.as_ptr(),
                    name_len: CAPABILITY_NAME_VERSION.len(),
                    len: p__TEST_SERVER__METHODS.len(),
                    ptr: p__TEST_SERVER__METHODS.as_ptr() as *mut _,
                    init: ::pyroduct::ffi::ClassInitFn::Sync(p__test_server__ffi_init),
                    reset: ::pyroduct::ffi::ClassResetFn::Sync(p__test_server__ffi_reset),
                    register: ::pyroduct::ffi::ClientRegisterFn::Sync(p__test_server__register__ffi),
                }
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    // Pins the exact client-side tokens for infallible methods: a zero-arg
    // method passes `None`, and a single-arg method passes `Some(&data)`
    // directly.
    #[test]
    fn test_generate_client_impl_integration() {
        let code = quote! {
            impl MyState {
                type Client = MyClient;
                type Config = MyConfig;

                fn new(config: Option<MyConfig>) -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, client: &MyClient) {}
                fn get_info(&self, client: &MyClient) -> u32 { 0 }
                fn get_other_info(&self, client: &MyClient, data: f32) -> u32 { 0 }
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        let output = cap.generate_client_impl();

        let expected = quote! {
            impl MyClient {
                pub fn register(self) -> ::pyroduct::wasm::Client<Self> {
                    ::pyroduct::wasm::Client::<Self>::__register(self, |ptr| unsafe { wasm::register(ptr) })
                }
            }
            pub trait MyClientMethods {
                fn get_info(&self) -> u32;
                fn get_other_info(&self, data: f32) -> u32;
            }
            impl MyClientMethods for ::pyroduct::wasm::Client<MyClient> {
                fn get_info(&self) -> u32 {
                    self.__call_from_wasm::<(), u32, _>(None,
                        |client_state_ptr: *const u8,
                         input_ptr: *const u8| {
                            unsafe {
                                wasm::p__my_state__get_info__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        })
                }

                fn get_other_info(&self, data: f32) -> u32 {
                    self.__call_from_wasm::<
                        f32,
                        u32,
                        _,
                    >(Some(&data),
                        |client_state_ptr: *const u8,
                         input_ptr: *const u8| {
                            unsafe {
                                wasm::p__my_state__get_other_info__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        },
                    )
                }
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    // Pins the client-side tokens for a fallible `register` and a multi-arg,
    // fallible `async` method. The arguments are packed into a generated
    // `#[::pyroduct::magma]` input struct, and the call goes through
    // `__call_result_from_wasm`.
    #[test]
    fn test_generate_client_impl_with_error_and_input_structs() {
        let code = quote! {
            impl AdvancedStruct {
                type Client = AdvancedClient;
                type Error = MyError;

                fn new() -> Self { Self }
                fn reset(&mut self) {}

                fn register(&self, client: &AdvancedClient) -> Result<(), MyError> {
                    Ok(())
                }

                async fn process(&self, client: &AdvancedClient, val: u32, flag: bool) -> Result<u32, MyError> {
                    Ok(val)
                }
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        let output = cap.generate_client_impl();

        let expected = quote! {
            impl AdvancedClient {
                pub fn register(self) -> Result<::pyroduct::wasm::Client<Self>, MyError> {
                    ::pyroduct::wasm::Client::<Self>::__register_result::<MyError, _>(self, |ptr| unsafe { wasm::register(ptr) })
                }
            }
            pub trait AdvancedClientMethods {
                fn process(&self, val: u32, flag: bool) -> Result<u32, MyError>;
            }
            impl AdvancedClientMethods for ::pyroduct::wasm::Client<AdvancedClient> {
                fn process(&self, val: u32, flag: bool) -> Result<u32, MyError> {
                    #[::pyroduct::magma]
                    struct p__AdvancedStruct__Process__Input {
                        pub val: u32,
                        pub flag: bool
                    }

                    self.__call_result_from_wasm::<
                        p__AdvancedStruct__Process__Input,
                        u32,
                        MyError,
                        _
                    >(
                        Some(&p__AdvancedStruct__Process__Input { val, flag }),
                        |client_state_ptr: *const u8,
                         input_ptr: *const u8| {
                            unsafe {
                                wasm::p__advanced_struct__process__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        }
                    )
                }
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }
}