pyro-macro 0.2.2

Derive macros for Pyroduct
Documentation
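//! Parsing, validation, and code generation for a server's `reset` method.
//!
//! Valid signatures (the error type may also be omitted, i.e. `Result<()>`):
//! - `fn reset(&mut self) -> Result<(), CapturedError>`
//! - `async fn reset(&mut self) -> Result<(), CapturedError>`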

use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{Error, FnArg, Ident, ImplItemFn};

use heck::AsSnakeCase;

/// A validated `reset` method taken from a server's impl block.
///
/// Built by [`ResetFn::parse`]; the stored pieces are re-emitted by the
/// `generate_*` methods.
#[derive(Debug, Clone)]
pub struct ResetFn {
    /// Whether the method was declared `async`; selects the sync or async FFI shim.
    pub is_async: bool,
    /// The user's method body, copied verbatim into the regenerated method.
    pub body: syn::Block,
    /// Attributes on the user's method, re-emitted on the regenerated method.
    pub attrs: Vec<syn::Attribute>,
}

impl ResetFn {
    pub fn parse(f: &ImplItemFn) -> syn::Result<Self> {
        let sig = &f.sig;

        // 1. Validate name
        if sig.ident != "reset" {
            return Err(Error::new_spanned(
                &sig.ident,
                "Expected function named 'reset'",
            ));
        }

        // 2. Validate return type is Result<(), CapturedError> or Result<()>
        let (ok_ty, _err_ty) = crate::ffi::paths::verify_result_return_type(&sig.output)?;
        let ok_str = quote!(#ok_ty).to_string().replace(" ", "");
        if ok_str != "()" {
            return Err(Error::new_spanned(
                &sig.output,
                "fn reset must return Result<(), CapturedError> or Result<()>",
            ));
        }

        // 3. Validate &mut self as first and only parameter
        if sig.inputs.len() != 1 {
            return Err(Error::new_spanned(
                &sig.inputs,
                "fn reset must take exactly &mut self",
            ));
        }

        match sig.inputs.first() {
            Some(FnArg::Receiver(r)) => {
                if r.mutability.is_none() {
                    return Err(Error::new_spanned(
                        r,
                        "fn reset must take &mut self (not &self)",
                    ));
                }
                if r.reference.is_none() {
                    return Err(Error::new_spanned(
                        r,
                        "fn reset must take &mut self (not mut self)",
                    ));
                }
            }
            Some(arg) => {
                return Err(Error::new_spanned(
                    arg,
                    "fn reset must take &mut self as its only parameter",
                ));
            }
            None => {
                return Err(Error::new_spanned(sig, "fn reset must take &mut self"));
            }
        }

        Ok(Self {
            is_async: sig.asyncness.is_some(),
            body: f.block.clone(),
            attrs: f.attrs.clone(),
        })
    }

    /// Generate the FFI reset function
    pub fn generate_ffi(&self, server: &Ident) -> TokenStream {
        let server_snake = AsSnakeCase(server.to_string()).to_string();
        let reset_name = format_ident!("p__{}__ffi_reset", server_snake);

        if self.is_async {
            quote! {
                #[unsafe(no_mangle)]
                pub unsafe extern "C" fn #reset_name(
                    capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
                ) -> ::pyroduct::ffi::FuturePyroView {
                    ::pyroduct::ffi::guest::execute_safe_async(|| async move {
                        let mut state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
                            Ok(state) => state,
                            Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                        };
                        let state = state_ptr.as_ref::<#server>();
                        match state.reset().await {
                            Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
                            Err(err) => err.encode().view(),
                        }
                    }, capability_state_ptr.object_id, 0)
                }
            }
        } else {
            quote! {
                #[unsafe(no_mangle)]
                pub unsafe extern "C" fn #reset_name(
                    capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
                ) -> ::pyroduct::format::PyroViewPtr {
                    ::pyroduct::ffi::guest::execute_safe(|| {
                        let mut state_ptr = match unsafe { ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr) } {
                            Ok(state) => state,
                            Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                        };
                        let state = state_ptr.as_ref::<#server>();
                        match state.reset() {
                            Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
                            Err(err) => err.encode().view(),
                        }
                    }, capability_state_ptr.object_id, 0)
                }
            }
        }
    }

    /// Generate the export entry for the reset function
    pub fn generate_export(&self, server: &Ident) -> TokenStream {
        let server_snake = AsSnakeCase(server.to_string()).to_string();
        let reset_name = format_ident!("p__{}__ffi_reset", server_snake);

        if self.is_async {
            quote!(::pyroduct::ffi::ClassResetFn::Async(#reset_name))
        } else {
            quote!(::pyroduct::ffi::ClassResetFn::Sync(#reset_name))
        }
    }

    /// Generate the impl method (preserves original)
    pub fn generate_impl_method(&self) -> TokenStream {
        let attrs = &self.attrs;
        let body = &self.body;
        let async_kw = if self.is_async {
            quote!(async)
        } else {
            quote!()
        };

        quote! {
            #(#attrs)*
            pub #async_kw fn reset(&mut self) -> Result<(), ::pyroduct::CapturedError> #body
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use quote::{format_ident, quote};
    use syn::{ImplItemFn, parse_quote};

    /// Run `ResetFn::parse` on `item`, expect rejection, and return the error message.
    fn parse_err(item: ImplItemFn) -> String {
        ResetFn::parse(&item)
            .expect_err("expected ResetFn::parse to reject the signature")
            .to_string()
    }

    #[test]
    fn test_sync_server_reset_fn() {
        let server_ident = format_ident!("GreeterServer");
        let item: ImplItemFn = parse_quote! {
            fn reset(&mut self) -> Result<(), CapturedError> {
                self.count = 0;
                Ok(())
            }
        };
        let reset_fn = ResetFn::parse(&item).expect("Failed to parse reset fn");
        let result = reset_fn.generate_ffi(&server_ident);
        let expected = quote! {
            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__greeter_server__ffi_reset(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
            ) -> ::pyroduct::format::PyroViewPtr {
                ::pyroduct::ffi::guest::execute_safe(|| {
                    let mut state_ptr = match unsafe {
                        ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr)
                    } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<GreeterServer>();
                    match state.reset() {
                        Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
                        Err(err) => err.encode().view(),
                    }
                }, capability_state_ptr.object_id, 0)
            }
        };

        crate::fmt::assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_async_server_reset_fn() {
        let server_ident = format_ident!("GreeterServer");
        let item: ImplItemFn = parse_quote! {
            async fn reset(&mut self) -> Result<(), CapturedError> {
                self.count = 0;
                Ok(())
            }
        };

        let reset_fn = ResetFn::parse(&item).expect("Failed to parse reset fn");
        let result = reset_fn.generate_ffi(&server_ident);
        let expected = quote! {
            #[unsafe(no_mangle)]
            pub unsafe extern "C" fn p__greeter_server__ffi_reset(
                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
            ) -> ::pyroduct::ffi::FuturePyroView {
                ::pyroduct::ffi::guest::execute_safe_async(|| async move {
                    let mut state_ptr = match unsafe {
                        ::pyroduct::ffi::PyroObjectRef::from_raw(capability_state_ptr)
                    } {
                        Ok(state) => state,
                        Err(error) => return ::pyroduct::PyroError::CodePanic(error.into()).encode().view(),
                    };
                    let state = state_ptr.as_ref::<GreeterServer>();
                    match state.reset().await {
                        Ok(()) => ::pyroduct::format::PyroVec::ok().view(),
                        Err(err) => err.encode().view(),
                    }
                }, capability_state_ptr.object_id, 0)
            }
        };

        crate::fmt::assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_generate_export_variants() {
        let server_ident = format_ident!("GreeterServer");
        let sync_item: ImplItemFn = parse_quote! {
            fn reset(&mut self) -> Result<(), CapturedError> { Ok(()) }
        };
        let async_item: ImplItemFn = parse_quote! {
            async fn reset(&mut self) -> Result<(), CapturedError> { Ok(()) }
        };

        let sync_export = ResetFn::parse(&sync_item)
            .expect("Failed to parse reset fn")
            .generate_export(&server_ident);
        let async_export = ResetFn::parse(&async_item)
            .expect("Failed to parse reset fn")
            .generate_export(&server_ident);

        assert_eq!(
            sync_export.to_string(),
            quote!(::pyroduct::ffi::ClassResetFn::Sync(p__greeter_server__ffi_reset)).to_string()
        );
        assert_eq!(
            async_export.to_string(),
            quote!(::pyroduct::ffi::ClassResetFn::Async(p__greeter_server__ffi_reset)).to_string()
        );
    }

    #[test]
    fn test_generate_impl_method_preserves_attrs_and_body() {
        let item: ImplItemFn = parse_quote! {
            #[inline]
            async fn reset(&mut self) -> Result<(), CapturedError> {
                self.count = 0;
                Ok(())
            }
        };
        let result = ResetFn::parse(&item)
            .expect("Failed to parse reset fn")
            .generate_impl_method();
        let expected = quote! {
            #[inline]
            pub async fn reset(&mut self) -> Result<(), ::pyroduct::CapturedError> {
                self.count = 0;
                Ok(())
            }
        };

        assert_eq!(result.to_string(), expected.to_string());
    }

    #[test]
    fn test_rejects_wrong_name() {
        let item: ImplItemFn = parse_quote! {
            fn clear(&mut self) -> Result<(), CapturedError> { Ok(()) }
        };
        assert_eq!(parse_err(item), "Expected function named 'reset'");
    }

    #[test]
    fn test_rejects_shared_self() {
        let item: ImplItemFn = parse_quote! {
            fn reset(&self) -> Result<(), CapturedError> { Ok(()) }
        };
        assert_eq!(parse_err(item), "fn reset must take &mut self (not &self)");
    }

    #[test]
    fn test_rejects_mut_self_by_value() {
        let item: ImplItemFn = parse_quote! {
            fn reset(mut self) -> Result<(), CapturedError> { Ok(()) }
        };
        assert_eq!(parse_err(item), "fn reset must take &mut self (not mut self)");
    }

    #[test]
    fn test_rejects_wrong_parameter_count() {
        let extra: ImplItemFn = parse_quote! {
            fn reset(&mut self, hard: bool) -> Result<(), CapturedError> { Ok(()) }
        };
        let none: ImplItemFn = parse_quote! {
            fn reset() -> Result<(), CapturedError> { Ok(()) }
        };
        assert_eq!(parse_err(extra), "fn reset must take exactly &mut self");
        assert_eq!(parse_err(none), "fn reset must take exactly &mut self");
    }

    #[test]
    fn test_rejects_non_receiver_parameter() {
        let item: ImplItemFn = parse_quote! {
            fn reset(state: u32) -> Result<(), CapturedError> { Ok(()) }
        };
        assert_eq!(
            parse_err(item),
            "fn reset must take &mut self as its only parameter"
        );
    }
}