Skip to main content

pyro_macro/ffi/lifecycle/
reset.rs

1//! Reset function parsing and validation
2//!
3//! Valid signatures:
4//! - `fn reset(&mut self)`
5//! - `async fn reset(&mut self)`
6
7use proc_macro2::TokenStream;
8use quote::{format_ident, quote};
9use syn::{Error, FnArg, Ident, ImplItemFn};
10
11use heck::AsSnakeCase;
12
13#[derive(Debug, Clone)]
14pub struct ResetFn {
15    pub is_async: bool,
16    pub body: syn::Block,
17    pub attrs: Vec<syn::Attribute>,
18}
19
20impl ResetFn {
21    pub fn parse(f: &ImplItemFn) -> syn::Result<Self> {
22        let sig = &f.sig;
23
24        // 1. Validate name
25        if sig.ident != "reset" {
26            return Err(Error::new_spanned(
27                &sig.ident,
28                "Expected function named 'reset'",
29            ));
30        }
31
32        // 2. Validate return type is Result<(), CapturedError> or Result<()>
33        let (ok_ty, _err_ty) = crate::ffi::paths::verify_result_return_type(&sig.output)?;
34        let ok_str = quote!(#ok_ty).to_string().replace(" ", "");
35        if ok_str != "()" {
36            return Err(Error::new_spanned(
37                &sig.output,
38                "fn reset must return Result<(), CapturedError> or Result<()>",
39            ));
40        }
41
42        // 3. Validate &mut self as first and only parameter
43        if sig.inputs.len() != 1 {
44            return Err(Error::new_spanned(
45                &sig.inputs,
46                "fn reset must take exactly &mut self",
47            ));
48        }
49
50        match sig.inputs.first() {
51            Some(FnArg::Receiver(r)) => {
52                if r.mutability.is_none() {
53                    return Err(Error::new_spanned(
54                        r,
55                        "fn reset must take &mut self (not &self)",
56                    ));
57                }
58                if r.reference.is_none() {
59                    return Err(Error::new_spanned(
60                        r,
61                        "fn reset must take &mut self (not mut self)",
62                    ));
63                }
64            }
65            Some(arg) => {
66                return Err(Error::new_spanned(
67                    arg,
68                    "fn reset must take &mut self as its only parameter",
69                ));
70            }
71            None => {
72                return Err(Error::new_spanned(sig, "fn reset must take &mut self"));
73            }
74        }
75
76        Ok(Self {
77            is_async: sig.asyncness.is_some(),
78            body: f.block.clone(),
79            attrs: f.attrs.clone(),
80        })
81    }
82
83    /// Generate the FFI reset function
84    pub fn generate_ffi(&self, server: &Ident) -> TokenStream {
85        let server_snake = AsSnakeCase(server.to_string()).to_string();
86        let reset_name = format_ident!("p__{}__ffi_reset", server_snake);
87
88        if self.is_async {
89            quote! {
90                #[unsafe(no_mangle)]
91                pub unsafe extern "C" fn #reset_name(
92                    capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
93                ) -> ::pyroduct::ffi::FuturePyroView {
94                    ::pyroduct::ffi::guest::execute_safe_async(|| async move {
95                        let mut state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
96                            Ok(state) => state,
97                            Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
98                        };
99                        let state = state_ptr.as_ref::<#server>();
100                        match state.reset().await {
101                            Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
102                            Err(err) => err.encode().view(),
103                        }
104                    }, capability_state_ptr.object_id, 0)
105                }
106            }
107        } else {
108            quote! {
109                #[unsafe(no_mangle)]
110                pub unsafe extern "C" fn #reset_name(
111                    capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
112                ) -> ::pyroduct::format::PyroViewPtr {
113                    ::pyroduct::ffi::guest::execute_safe(|| {
114                        let mut state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
115                            Ok(state) => state,
116                            Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
117                        };
118                        let state = state_ptr.as_ref::<#server>();
119                        match state.reset() {
120                            Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
121                            Err(err) => err.encode().view(),
122                        }
123                    }, capability_state_ptr.object_id, 0)
124                }
125            }
126        }
127    }
128
129    /// Generate the export entry for the reset function
130    pub fn generate_export(&self, server: &Ident) -> TokenStream {
131        let server_snake = AsSnakeCase(server.to_string()).to_string();
132        let reset_name = format_ident!("p__{}__ffi_reset", server_snake);
133
134        if self.is_async {
135            quote!(::pyroduct::ffi::ClassResetFn::Async(#reset_name))
136        } else {
137            quote!(::pyroduct::ffi::ClassResetFn::Sync(#reset_name))
138        }
139    }
140
141    /// Generate the impl method (preserves original)
142    pub fn generate_impl_method(&self) -> TokenStream {
143        let attrs = &self.attrs;
144        let body = &self.body;
145        let async_kw = if self.is_async {
146            quote!(async)
147        } else {
148            quote!()
149        };
150
151        quote! {
152            #(#attrs)*
153            pub #async_kw fn reset(&mut self) -> Result<(), ::pyroduct::CapturedError> #body
154        }
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use quote::{format_ident, quote};
    use syn::{ImplItemFn, parse_quote};

    /// Parses `item` and returns the validation error, panicking if it was accepted.
    fn parse_err(item: ImplItemFn) -> syn::Error {
        match ResetFn::parse(&item) {
            Ok(parsed) => panic!("expected a parse error, got {parsed:?}"),
            Err(err) => err,
        }
    }

    #[test]
    fn test_sync_server_reset_fn() {
        let server_ident = format_ident!("GreeterServer");
        let item: ImplItemFn = parse_quote! {
            fn reset(&mut self) -> Result<(), CapturedError> {
                self.count = 0;
                Ok(())
            }
        };
        let reset_fn = ResetFn::parse(&item).expect("Failed to parse reset fn");
        let result = reset_fn.generate_ffi(&server_ident);
        let expected = quote! {
            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__greeter_server__ffi_reset(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
            ) -> ::pyroduct::format::PyroViewPtr {
                ::pyroduct::ffi::guest::execute_safe(|| {
                    let mut state_ptr = match unsafe {
                        ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr)
                    } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<GreeterServer>();
                    match state.reset() {
                        Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
                        Err(err) => err.encode().view(),
                    }
                }, capability_state_ptr.object_id, 0)
            }
        };

        crate::fmt::assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_async_server_reset_fn() {
        let server_ident = format_ident!("GreeterServer");
        let item: ImplItemFn = parse_quote! {
            async fn reset(&mut self) -> Result<(), CapturedError> {
                self.count = 0;
                Ok(())
            }
        };

        let reset_fn = ResetFn::parse(&item).expect("Failed to parse reset fn");
        let result = reset_fn.generate_ffi(&server_ident);
        let expected = quote! {
            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__greeter_server__ffi_reset(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
            ) -> ::pyroduct::ffi::FuturePyroView {
                ::pyroduct::ffi::guest::execute_safe_async(|| async move {
                    let mut state_ptr = match unsafe {
                        ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr)
                    } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<GreeterServer>();
                    match state.reset().await {
                        Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
                        Err(err) => err.encode().view(),
                    }
                }, capability_state_ptr.object_id, 0)
            }
        };

        crate::fmt::assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_parse_records_asyncness_and_attrs() {
        let item: ImplItemFn = parse_quote! {
            #[inline]
            async fn reset(&mut self) -> Result<(), CapturedError> {
                Ok(())
            }
        };
        let reset_fn = ResetFn::parse(&item).expect("Failed to parse reset fn");
        assert!(reset_fn.is_async);
        assert_eq!(reset_fn.attrs.len(), 1);

        let item: ImplItemFn = parse_quote! {
            fn reset(&mut self) -> Result<(), CapturedError> {
                Ok(())
            }
        };
        let reset_fn = ResetFn::parse(&item).expect("Failed to parse reset fn");
        assert!(!reset_fn.is_async);
        assert!(reset_fn.attrs.is_empty());
    }

    #[test]
    fn test_rejects_wrong_name() {
        let item: ImplItemFn = parse_quote! {
            fn clear(&mut self) -> Result<(), CapturedError> {
                Ok(())
            }
        };
        let err = parse_err(item);
        assert!(err.to_string().contains("Expected function named 'reset'"));
    }

    #[test]
    fn test_rejects_non_unit_ok_type() {
        let item: ImplItemFn = parse_quote! {
            fn reset(&mut self) -> Result<u32, CapturedError> {
                Ok(0)
            }
        };
        parse_err(item);
    }

    #[test]
    fn test_rejects_shared_receiver() {
        let item: ImplItemFn = parse_quote! {
            fn reset(&self) -> Result<(), CapturedError> {
                Ok(())
            }
        };
        let err = parse_err(item);
        assert!(err.to_string().contains("not &self"));
    }

    #[test]
    fn test_rejects_mut_by_value_receiver() {
        let item: ImplItemFn = parse_quote! {
            fn reset(mut self) -> Result<(), CapturedError> {
                Ok(())
            }
        };
        let err = parse_err(item);
        assert!(err.to_string().contains("not mut self"));
    }

    #[test]
    fn test_rejects_missing_extra_or_typed_params() {
        let no_params: ImplItemFn = parse_quote! {
            fn reset() -> Result<(), CapturedError> {
                Ok(())
            }
        };
        parse_err(no_params);

        let extra_param: ImplItemFn = parse_quote! {
            fn reset(&mut self, extra: u32) -> Result<(), CapturedError> {
                Ok(())
            }
        };
        parse_err(extra_param);

        let no_receiver: ImplItemFn = parse_quote! {
            fn reset(value: u32) -> Result<(), CapturedError> {
                Ok(())
            }
        };
        parse_err(no_receiver);
    }

    #[test]
    fn test_generate_export() {
        let server_ident = format_ident!("GreeterServer");

        let sync_item: ImplItemFn = parse_quote! {
            fn reset(&mut self) -> Result<(), CapturedError> {
                Ok(())
            }
        };
        let sync_fn = ResetFn::parse(&sync_item).expect("Failed to parse reset fn");
        assert_eq!(
            sync_fn.generate_export(&server_ident).to_string(),
            quote!(::pyroduct::ffi::ClassResetFn::Sync(p__greeter_server__ffi_reset)).to_string(),
        );

        let async_item: ImplItemFn = parse_quote! {
            async fn reset(&mut self) -> Result<(), CapturedError> {
                Ok(())
            }
        };
        let async_fn = ResetFn::parse(&async_item).expect("Failed to parse reset fn");
        assert_eq!(
            async_fn.generate_export(&server_ident).to_string(),
            quote!(::pyroduct::ffi::ClassResetFn::Async(p__greeter_server__ffi_reset)).to_string(),
        );
    }

    #[test]
    fn test_generate_impl_method_normalizes_signature() {
        let item: ImplItemFn = parse_quote! {
            async fn reset(&mut self) -> Result<(), CapturedError> {
                self.count = 0;
                Ok(())
            }
        };
        let reset_fn = ResetFn::parse(&item).expect("Failed to parse reset fn");
        let expected = quote! {
            pub async fn reset(&mut self) -> Result<(), ::pyroduct::CapturedError> {
                self.count = 0;
                Ok(())
            }
        };
        assert_eq!(
            reset_fn.generate_impl_method().to_string(),
            expected.to_string(),
        );
    }
}