pyro_macro/ffi/lifecycle/
reset.rs1use proc_macro2::TokenStream;
8use quote::{format_ident, quote};
9use syn::{Error, FnArg, Ident, ImplItemFn};
10
11use heck::AsSnakeCase;
12
13#[derive(Debug, Clone)]
14pub struct ResetFn {
15 pub is_async: bool,
16 pub body: syn::Block,
17 pub attrs: Vec<syn::Attribute>,
18}
19
20impl ResetFn {
21 pub fn parse(f: &ImplItemFn) -> syn::Result<Self> {
22 let sig = &f.sig;
23
24 if sig.ident != "reset" {
26 return Err(Error::new_spanned(
27 &sig.ident,
28 "Expected function named 'reset'",
29 ));
30 }
31
32 let (ok_ty, _err_ty) = crate::ffi::paths::verify_result_return_type(&sig.output)?;
34 let ok_str = quote!(#ok_ty).to_string().replace(" ", "");
35 if ok_str != "()" {
36 return Err(Error::new_spanned(
37 &sig.output,
38 "fn reset must return Result<(), CapturedError> or Result<()>",
39 ));
40 }
41
42 if sig.inputs.len() != 1 {
44 return Err(Error::new_spanned(
45 &sig.inputs,
46 "fn reset must take exactly &mut self",
47 ));
48 }
49
50 match sig.inputs.first() {
51 Some(FnArg::Receiver(r)) => {
52 if r.mutability.is_none() {
53 return Err(Error::new_spanned(
54 r,
55 "fn reset must take &mut self (not &self)",
56 ));
57 }
58 if r.reference.is_none() {
59 return Err(Error::new_spanned(
60 r,
61 "fn reset must take &mut self (not mut self)",
62 ));
63 }
64 }
65 Some(arg) => {
66 return Err(Error::new_spanned(
67 arg,
68 "fn reset must take &mut self as its only parameter",
69 ));
70 }
71 None => {
72 return Err(Error::new_spanned(sig, "fn reset must take &mut self"));
73 }
74 }
75
76 Ok(Self {
77 is_async: sig.asyncness.is_some(),
78 body: f.block.clone(),
79 attrs: f.attrs.clone(),
80 })
81 }
82
83 pub fn generate_ffi(&self, server: &Ident) -> TokenStream {
85 let server_snake = AsSnakeCase(server.to_string()).to_string();
86 let reset_name = format_ident!("p__{}__ffi_reset", server_snake);
87
88 if self.is_async {
89 quote! {
90 #[unsafe(no_mangle)]
91 pub unsafe extern "C" fn #reset_name(
92 capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
93 ) -> ::pyroduct::ffi::FuturePyroView {
94 ::pyroduct::ffi::guest::execute_safe_async(|| async move {
95 let mut state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
96 Ok(state) => state,
97 Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
98 };
99 let state = state_ptr.as_ref::<#server>();
100 match state.reset().await {
101 Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
102 Err(err) => err.encode().view(),
103 }
104 }, capability_state_ptr.object_id, 0)
105 }
106 }
107 } else {
108 quote! {
109 #[unsafe(no_mangle)]
110 pub unsafe extern "C" fn #reset_name(
111 capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
112 ) -> ::pyroduct::format::PyroViewPtr {
113 ::pyroduct::ffi::guest::execute_safe(|| {
114 let mut state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
115 Ok(state) => state,
116 Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
117 };
118 let state = state_ptr.as_ref::<#server>();
119 match state.reset() {
120 Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
121 Err(err) => err.encode().view(),
122 }
123 }, capability_state_ptr.object_id, 0)
124 }
125 }
126 }
127 }
128
129 pub fn generate_export(&self, server: &Ident) -> TokenStream {
131 let server_snake = AsSnakeCase(server.to_string()).to_string();
132 let reset_name = format_ident!("p__{}__ffi_reset", server_snake);
133
134 if self.is_async {
135 quote!(::pyroduct::ffi::ClassResetFn::Async(#reset_name))
136 } else {
137 quote!(::pyroduct::ffi::ClassResetFn::Sync(#reset_name))
138 }
139 }
140
141 pub fn generate_impl_method(&self) -> TokenStream {
143 let attrs = &self.attrs;
144 let body = &self.body;
145 let async_kw = if self.is_async {
146 quote!(async)
147 } else {
148 quote!()
149 };
150
151 quote! {
152 #(#attrs)*
153 pub #async_kw fn reset(&mut self) -> Result<(), ::pyroduct::CapturedError> #body
154 }
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use quote::{format_ident, quote};
    use syn::{ImplItemFn, parse_quote};

    /// Parses `item` as a reset fn and returns the FFI wrapper generated for
    /// a server type named `GreeterServer`.
    fn greeter_reset_ffi(item: &ImplItemFn) -> proc_macro2::TokenStream {
        let parsed = ResetFn::parse(item).expect("Failed to parse reset fn");
        parsed.generate_ffi(&format_ident!("GreeterServer"))
    }

    /// A synchronous `reset` yields a wrapper that runs under `execute_safe`
    /// and returns `PyroViewPtr`.
    #[test]
    fn test_sync_server_reset_fn() {
        let source: ImplItemFn = parse_quote! {
            fn reset(&mut self) -> Result<(), CapturedError> {
                self.count = 0;
                Ok(())
            }
        };

        let want = quote! {
            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__greeter_server__ffi_reset(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
            ) -> ::pyroduct::format::PyroViewPtr {
                ::pyroduct::ffi::guest::execute_safe(|| {
                    let mut state_ptr = match unsafe {
                        ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr)
                    } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<GreeterServer>();
                    match state.reset() {
                        Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
                        Err(err) => err.encode().view(),
                    }
                }, capability_state_ptr.object_id, 0)
            }
        };

        let got = greeter_reset_ffi(&source);
        crate::fmt::assert_code_eq_token(&got, &want);
    }

    /// An `async` `reset` yields a wrapper that runs under
    /// `execute_safe_async`, awaits the call and returns `FuturePyroView`.
    #[test]
    fn test_async_server_reset_fn() {
        let source: ImplItemFn = parse_quote! {
            async fn reset(&mut self) -> Result<(), CapturedError> {
                self.count = 0;
                Ok(())
            }
        };

        let want = quote! {
            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__greeter_server__ffi_reset(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
            ) -> ::pyroduct::ffi::FuturePyroView {
                ::pyroduct::ffi::guest::execute_safe_async(|| async move {
                    let mut state_ptr = match unsafe {
                        ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr)
                    } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<GreeterServer>();
                    match state.reset().await {
                        Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
                        Err(err) => err.encode().view(),
                    }
                }, capability_state_ptr.object_id, 0)
            }
        };

        let got = greeter_reset_ffi(&source);
        crate::fmt::assert_code_eq_token(&got, &want);
    }
}