use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::{Field, FieldMutability, Fields, Ident, ItemStruct, Type, Visibility, parse_macro_input};
/// Parsed arguments of the `#[observable(state = ..., error = ...)]` attribute.
struct ObservableArgs {
    /// Type bound to `Observable::State`; the generated `notify` helpers take it by reference.
    state_type: Type,
    /// Type bound to `Observable::Error`; the generated `notify` returns it in the `Err` case.
    error_type: Type,
}
impl ObservableArgs {
fn parse(args: TokenStream) -> syn::Result<Self> {
let args_str = args.to_string();
let parts = args_str.split(',').map(|s| s.trim());
let mut state_type = None;
let mut error_type = None;
for part in parts {
if let Some((key, value)) = part.split_once('=') {
let key = key.trim();
let value = value.trim();
match key {
"state" => {
state_type = Some(syn::parse_str::<Type>(value)?);
}
"error" => {
error_type = Some(syn::parse_str::<Type>(value)?);
}
_ => {
return Err(syn::Error::new(
Span::call_site(),
format!("unknown parameter '{}', expected 'state' or 'error'", key),
));
}
}
} else {
return Err(syn::Error::new(
Span::call_site(),
format!(
"invalid parameter format '{}', expected 'key = value'",
part
),
));
}
}
Ok(Self {
state_type: state_type
.ok_or_else(|| syn::Error::new(Span::call_site(), "missing 'state' parameter"))?,
error_type: error_type
.ok_or_else(|| syn::Error::new(Span::call_site(), "missing 'error' parameter"))?,
})
}
}
pub fn generate(args: TokenStream, input: TokenStream) -> TokenStream {
let args = match ObservableArgs::parse(args) {
Ok(args) => args,
Err(e) => return e.to_compile_error().into(),
};
let mut input_struct = parse_macro_input!(input as ItemStruct);
let struct_name = &input_struct.ident;
let registry_field = Field {
attrs: Vec::new(),
vis: Visibility::Inherited,
mutability: FieldMutability::None,
ident: Some(Ident::new("registry", Span::call_site())),
colon_token: Some(syn::token::Colon(Span::call_site())),
ty: syn::parse_str::<Type>(&format!(
"::rust_patterns::ObserverRegistry<{}>",
struct_name
))
.unwrap(),
};
match &mut input_struct.fields {
Fields::Named(fields) => {
fields.named.push(registry_field);
}
Fields::Unnamed(_fields) => {
return syn::Error::new_spanned(
struct_name,
"#[observable] can only be applied to structs with named fields",
)
.to_compile_error()
.into();
}
Fields::Unit => {
return syn::Error::new_spanned(
struct_name,
"#[observable] can only be applied to structs with named fields",
)
.to_compile_error()
.into();
}
}
let state_type = &args.state_type;
let error_type = &args.error_type;
let expanded = quote! {
#input_struct
impl ::rust_patterns::Observable for #struct_name {
type State = #state_type;
type Error = #error_type;
fn attach(&mut self, observer: ::std::sync::Arc<dyn ::rust_patterns::Observer<Subject = Self>>) {
self.registry.attach(observer);
}
fn detach(&mut self, observer: ::std::sync::Arc<dyn ::rust_patterns::Observer<Subject = Self>>) {
self.registry.detach(observer);
}
}
impl #struct_name {
#[inline]
fn notify(&self, state: &#state_type) -> Result<(), #error_type> {
self.registry.notify(state)
}
#[inline]
fn notify_ignore_error(&self, state: &#state_type) {
self.registry.notify_ignore_error(state)
}
}
};
expanded.into()
}