Skip to main content

wa_rs_derive/
lib.rs

1//! Derive macros for wa_rs_core protocol types.
2//!
3//! This crate provides derive macros for implementing the `ProtocolNode` trait
4//! on structs that represent WhatsApp protocol nodes.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use wa_rs_derive::{ProtocolNode, StringEnum};
10//!
11//! /// A query request node.
12//! /// Wire format: `<query request="interactive"/>`
13//! #[derive(ProtocolNode)]
14//! #[protocol(tag = "query")]
15//! pub struct QueryRequest {
16//!     #[attr(name = "request", default = "interactive")]
17//!     pub request_type: String,
18//! }
19//!
20//! /// An enum with string representation.
21//! #[derive(StringEnum)]
22//! pub enum MemberAddMode {
23//!     #[str = "admin_add"]
24//!     AdminAdd,
25//!     #[str = "all_member_add"]
26//!     AllMemberAdd,
27//! }
28//! ```
29
30use proc_macro::TokenStream;
31use quote::quote;
32use syn::{Data, DeriveInput, Fields, parse_macro_input};
33
34/// Derive macro for implementing `ProtocolNode` on structs with attributes.
35///
36/// # Attributes
37///
38/// - `#[protocol(tag = "tagname")]` - Required. Specifies the XML tag name.
39/// - `#[attr(name = "attrname")]` - Marks a String field as an XML attribute.
40/// - `#[attr(name = "attrname", default = "value")]` - Attribute with default value.
41///   For `Option<String>` fields, a default always yields `Some(default)`.
42/// - `#[attr(name = "attrname", jid)]` - Marks a Jid field as a JID attribute (required).
43/// - `#[attr(name = "attrname", jid, optional)]` - Marks an Option<Jid> field as optional.
44///
45/// # Example
46///
47/// ```ignore
48/// #[derive(ProtocolNode)]
49/// #[protocol(tag = "message")]
50/// pub struct MessageStanza {
51///     #[attr(name = "from", jid)]
52///     pub from: Jid,
53///     
54///     #[attr(name = "to", jid)]
55///     pub to: Jid,
56///     
57///     #[attr(name = "id")]
58///     pub id: String,
59///     
60///     #[attr(name = "sender_lid", jid, optional)]
61///     pub sender_lid: Option<Jid>,
62/// }
63/// ```
64#[proc_macro_derive(ProtocolNode, attributes(protocol, attr))]
65pub fn derive_protocol_node(input: TokenStream) -> TokenStream {
66    let input = parse_macro_input!(input as DeriveInput);
67
68    let name = &input.ident;
69
70    let tag = match extract_tag(&input.attrs) {
71        Ok(Some(tag)) => tag,
72        Ok(None) => {
73            return syn::Error::new_spanned(
74                &input.ident,
75                "ProtocolNode requires #[protocol(tag = \"...\")]",
76            )
77            .to_compile_error()
78            .into();
79        }
80        Err(e) => return e.to_compile_error().into(),
81    };
82
83    let fields = match &input.data {
84        Data::Struct(data) => match &data.fields {
85            Fields::Named(fields) => &fields.named,
86            Fields::Unit => return generate_empty_impl(name, &tag).into(),
87            _ => {
88                return syn::Error::new_spanned(
89                    &input.ident,
90                    "ProtocolNode only supports named fields or unit structs",
91                )
92                .to_compile_error()
93                .into();
94            }
95        },
96        _ => {
97            return syn::Error::new_spanned(
98                &input.ident,
99                "ProtocolNode can only be derived for structs",
100            )
101            .to_compile_error()
102            .into();
103        }
104    };
105
106    let mut attr_fields = Vec::new();
107    for field in fields {
108        match extract_attr_info(field) {
109            Ok(Some(attr_info)) => attr_fields.push(attr_info),
110            Ok(None) => {}
111            Err(e) => return e.to_compile_error().into(),
112        }
113    }
114
115    let attr_setters: Vec<_> = attr_fields
116        .iter()
117        .map(|info| {
118            let field_ident = &info.field_ident;
119            let attr_name = &info.attr_name;
120
121            match (&info.attr_type, info.optional) {
122                (AttrType::Jid, true) => {
123                    // Option<Jid> - only insert if Some
124                    quote! {
125                        if let Some(jid) = self.#field_ident {
126                            builder = builder.jid_attr(#attr_name, jid);
127                        }
128                    }
129                }
130                (AttrType::Jid, false) => {
131                    // Required Jid - always insert
132                    quote! {
133                        builder = builder.jid_attr(#attr_name, self.#field_ident);
134                    }
135                }
136                (AttrType::String, true) => {
137                    // Option<String> - only insert if Some
138                    quote! {
139                        if let Some(s) = self.#field_ident {
140                            builder = builder.attr(#attr_name, s);
141                        }
142                    }
143                }
144                (AttrType::String, false) => {
145                    // Required String - always insert
146                    quote! {
147                        builder = builder.attr(#attr_name, self.#field_ident);
148                    }
149                }
150            }
151        })
152        .collect();
153
154    let field_parsers: Vec<_> = attr_fields
155        .iter()
156        .map(|info| {
157            let field_ident = &info.field_ident;
158            let attr_name = &info.attr_name;
159
160            match (&info.attr_type, info.optional, &info.default) {
161                (AttrType::Jid, false, _) => {
162                    // Required Jid
163                    quote! {
164                        #field_ident: node.attrs().optional_jid(#attr_name)
165                            .ok_or_else(|| ::anyhow::anyhow!("missing required attribute '{}'", #attr_name))?
166                    }
167                }
168                (AttrType::Jid, true, _) => {
169                    // Optional Jid
170                    quote! {
171                        #field_ident: node.attrs().optional_jid(#attr_name)
172                    }
173                }
174                (AttrType::String, false, Some(default)) => {
175                    // String with default
176                    quote! {
177                        #field_ident: node.attrs().optional_string(#attr_name)
178                            .map(|s| s.to_string())
179                            .unwrap_or_else(|| #default.to_string())
180                    }
181                }
182                (AttrType::String, false, None) => {
183                    // Required String
184                    quote! {
185                        #field_ident: node.attrs().required_string(#attr_name)?.to_string()
186                    }
187                }
188                (AttrType::String, true, Some(default)) => {
189                    // Optional String with default (always Some)
190                    quote! {
191                        #field_ident: node.attrs().optional_string(#attr_name)
192                            .map(|s| s.to_string())
193                            .or_else(|| Some(#default.to_string()))
194                    }
195                }
196                (AttrType::String, true, None) => {
197                    // Optional String
198                    quote! {
199                        #field_ident: node.attrs().optional_string(#attr_name).map(|s| s.to_string())
200                    }
201                }
202            }
203        })
204        .collect();
205
206    // Only generate Default impl if all fields have defaults or are optional
207    let all_have_defaults = attr_fields
208        .iter()
209        .all(|info| info.default.is_some() || info.optional);
210
211    let default_impl = if all_have_defaults {
212        let default_fields: Vec<_> = attr_fields
213            .iter()
214            .map(|info| {
215                let field_ident = &info.field_ident;
216                match (&info.attr_type, info.optional, &info.default) {
217                    (_, true, Some(default)) => quote! { #field_ident: Some(#default.to_string()) },
218                    (_, true, None) => quote! { #field_ident: None },
219                    (AttrType::String, false, Some(default)) => {
220                        quote! { #field_ident: #default.to_string() }
221                    }
222                    _ => unreachable!("all_have_defaults check should prevent this branch"),
223                }
224            })
225            .collect();
226
227        quote! {
228            impl ::core::default::Default for #name {
229                fn default() -> Self {
230                    Self {
231                        #(#default_fields),*
232                    }
233                }
234            }
235        }
236    } else {
237        quote! {}
238    };
239
240    let expanded = quote! {
241        impl ::wa_rs_core::protocol::ProtocolNode for #name {
242            fn tag(&self) -> &'static str {
243                #tag
244            }
245
246            fn into_node(self) -> ::wa_rs_binary::node::Node {
247                let mut builder = ::wa_rs_binary::builder::NodeBuilder::new(#tag);
248                #(#attr_setters)*
249                builder.build()
250            }
251
252            fn try_from_node(node: &::wa_rs_binary::node::Node) -> ::anyhow::Result<Self> {
253                if node.tag != #tag {
254                    return Err(::anyhow::anyhow!("expected <{}>, got <{}>", #tag, node.tag));
255                }
256                Ok(Self {
257                    #(#field_parsers),*
258                })
259            }
260        }
261
262        #default_impl
263    };
264
265    expanded.into()
266}
267
268/// Derive macro for empty protocol nodes (tag only, no attributes).
269///
270/// # Attributes
271///
272/// - `#[protocol(tag = "tagname")]` - Required. Specifies the XML tag name.
273///
274/// # Example
275///
276/// ```ignore
277/// #[derive(EmptyNode)]
278/// #[protocol(tag = "participants")]
279/// pub struct ParticipantsRequest;
280/// ```
281#[proc_macro_derive(EmptyNode, attributes(protocol))]
282pub fn derive_empty_node(input: TokenStream) -> TokenStream {
283    let input = parse_macro_input!(input as DeriveInput);
284
285    let name = &input.ident;
286
287    let tag = match extract_tag(&input.attrs) {
288        Ok(Some(tag)) => tag,
289        Ok(None) => {
290            return syn::Error::new_spanned(
291                &input.ident,
292                "EmptyNode requires #[protocol(tag = \"...\")]",
293            )
294            .to_compile_error()
295            .into();
296        }
297        Err(e) => return e.to_compile_error().into(),
298    };
299
300    generate_empty_impl(name, &tag).into()
301}
302
303fn generate_empty_impl(name: &syn::Ident, tag: &str) -> proc_macro2::TokenStream {
304    quote! {
305        impl ::wa_rs_core::protocol::ProtocolNode for #name {
306            fn tag(&self) -> &'static str {
307                #tag
308            }
309
310            fn into_node(self) -> ::wa_rs_binary::node::Node {
311                ::wa_rs_binary::builder::NodeBuilder::new(#tag).build()
312            }
313
314            fn try_from_node(node: &::wa_rs_binary::node::Node) -> ::anyhow::Result<Self> {
315                if node.tag != #tag {
316                    return Err(::anyhow::anyhow!("expected <{}>, got <{}>", #tag, node.tag));
317                }
318                Ok(Self)
319            }
320        }
321
322        impl ::core::default::Default for #name {
323            fn default() -> Self {
324                Self
325            }
326        }
327    }
328}
329
/// How a `ProtocolNode` attribute field is encoded on the node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AttrType {
    /// Written with `NodeBuilder::attr`, read with `optional_string` / `required_string`.
    String,
    /// Selected by the `jid` flag; written with `jid_attr`, read with `optional_jid`.
    Jid,
}
334
335struct AttrFieldInfo {
336    field_ident: syn::Ident,
337    attr_name: String,
338    attr_type: AttrType,
339    optional: bool,
340    default: Option<String>,
341}
342
343fn extract_tag(attrs: &[syn::Attribute]) -> Result<Option<String>, syn::Error> {
344    for attr in attrs {
345        if attr.path().is_ident("protocol") {
346            let mut tag = None;
347            attr.parse_nested_meta(|meta| {
348                if meta.path.is_ident("tag") {
349                    let value: syn::LitStr = meta.value()?.parse()?;
350                    tag = Some(value.value());
351                }
352                Ok(())
353            })?;
354            if tag.is_some() {
355                return Ok(tag);
356            }
357        }
358    }
359    Ok(None)
360}
361
362fn extract_attr_info(field: &syn::Field) -> Result<Option<AttrFieldInfo>, syn::Error> {
363    let field_ident = match field.ident.clone() {
364        Some(ident) => ident,
365        None => return Ok(None),
366    };
367
368    // Check if field type is Option<T>
369    let is_optional = is_option_type(&field.ty);
370
371    for attr in &field.attrs {
372        if attr.path().is_ident("attr") {
373            let mut attr_name = None;
374            let mut default = None;
375            let mut is_jid = false;
376            let mut explicit_optional = false;
377
378            attr.parse_nested_meta(|meta| {
379                if meta.path.is_ident("name") {
380                    let value: syn::LitStr = meta.value()?.parse()?;
381                    attr_name = Some(value.value());
382                } else if meta.path.is_ident("default") {
383                    let value: syn::LitStr = meta.value()?.parse()?;
384                    default = Some(value.value());
385                } else if meta.path.is_ident("jid") {
386                    is_jid = true;
387                } else if meta.path.is_ident("optional") {
388                    explicit_optional = true;
389                }
390                Ok(())
391            })?;
392
393            match attr_name {
394                Some(name) => {
395                    let attr_type = if is_jid {
396                        AttrType::Jid
397                    } else {
398                        AttrType::String
399                    };
400
401                    // Determine if optional: either explicit marker or Option<T> type
402                    let optional = explicit_optional || is_optional;
403
404                    return Ok(Some(AttrFieldInfo {
405                        field_ident,
406                        attr_name: name,
407                        attr_type,
408                        optional,
409                        default,
410                    }));
411                }
412                None => {
413                    return Err(syn::Error::new_spanned(
414                        attr,
415                        "missing required `name` in #[attr(...)]",
416                    ));
417                }
418            }
419        }
420    }
421    Ok(None)
422}
423
424/// Check if a type is Option<T>
425fn is_option_type(ty: &syn::Type) -> bool {
426    if let syn::Type::Path(type_path) = ty
427        && let Some(segment) = type_path.path.segments.last()
428    {
429        return segment.ident == "Option";
430    }
431    false
432}
433
434/// Derive macro for enums with string representations.
435///
436/// Automatically implements:
437/// - `as_str(&self) -> &'static str`
438/// - `std::fmt::Display`
439/// - `TryFrom<&str>`
440/// - `Default` (first variant is default, or use `#[string_default]`)
441///
442/// # Attributes
443///
444/// - `#[str = "value"]` - Required on each variant. The string representation.
445/// - `#[string_default]` - Optional. Marks this variant as the default.
446///
447/// # Example
448///
449/// ```ignore
450/// #[derive(StringEnum)]
451/// pub enum MemberAddMode {
452///     #[str = "admin_add"]
453///     AdminAdd,
454///     #[string_default]
455///     #[str = "all_member_add"]
456///     AllMemberAdd,
457/// }
458///
459/// assert_eq!(MemberAddMode::AdminAdd.as_str(), "admin_add");
460/// assert_eq!(MemberAddMode::try_from("all_member_add").unwrap(), MemberAddMode::AllMemberAdd);
461/// ```
462#[proc_macro_derive(StringEnum, attributes(str, string_default))]
463pub fn derive_string_enum(input: TokenStream) -> TokenStream {
464    let input = parse_macro_input!(input as DeriveInput);
465
466    let name = &input.ident;
467
468    let variants = match &input.data {
469        Data::Enum(data) => &data.variants,
470        _ => {
471            return syn::Error::new_spanned(
472                &input.ident,
473                "StringEnum can only be derived for enums",
474            )
475            .to_compile_error()
476            .into();
477        }
478    };
479
480    let mut variant_infos = Vec::new();
481    let mut default_variant = None;
482    let mut seen_str_values: std::collections::HashMap<String, syn::Ident> =
483        std::collections::HashMap::new();
484
485    for variant in variants {
486        let variant_ident = &variant.ident;
487
488        if !matches!(variant.fields, syn::Fields::Unit) {
489            return syn::Error::new_spanned(
490                variant_ident,
491                "StringEnum only supports unit variants",
492            )
493            .to_compile_error()
494            .into();
495        }
496
497        let mut str_value = None;
498        let mut is_default = false;
499
500        for attr in &variant.attrs {
501            if attr.path().is_ident("str") {
502                if let syn::Meta::NameValue(nv) = &attr.meta
503                    && let syn::Expr::Lit(expr_lit) = &nv.value
504                    && let syn::Lit::Str(lit_str) = &expr_lit.lit
505                {
506                    str_value = Some(lit_str.value());
507                }
508            } else if attr.path().is_ident("string_default") {
509                is_default = true;
510            }
511        }
512
513        let str_val = match str_value {
514            Some(v) => v,
515            None => {
516                return syn::Error::new_spanned(
517                    variant_ident,
518                    format!(
519                        "StringEnum variant {} requires #[str = \"...\"] attribute",
520                        variant_ident
521                    ),
522                )
523                .to_compile_error()
524                .into();
525            }
526        };
527
528        if let Some(prev_variant) = seen_str_values.get(&str_val) {
529            return syn::Error::new_spanned(
530                variant_ident,
531                format!(
532                    "duplicate #[str = \"{}\"] value; already used by variant `{}`",
533                    str_val, prev_variant
534                ),
535            )
536            .to_compile_error()
537            .into();
538        }
539        seen_str_values.insert(str_val.clone(), variant_ident.clone());
540
541        if is_default {
542            if default_variant.is_some() {
543                return syn::Error::new_spanned(
544                    variant_ident,
545                    "Multiple #[string_default] attributes found; only one variant may be the default",
546                )
547                .to_compile_error()
548                .into();
549            }
550            default_variant = Some(variant_ident.clone());
551        }
552
553        variant_infos.push((variant_ident.clone(), str_val));
554    }
555
556    // Check for empty enums
557    if variant_infos.is_empty() {
558        return syn::Error::new_spanned(
559            &input.ident,
560            "StringEnum cannot be derived for empty enums",
561        )
562        .to_compile_error()
563        .into();
564    }
565
566    // If no explicit default, use first variant
567    let default_variant = default_variant.unwrap_or_else(|| variant_infos[0].0.clone());
568
569    // Generate as_str() match arms
570    let as_str_arms: Vec<_> = variant_infos
571        .iter()
572        .map(|(ident, str_val)| {
573            quote! { #name::#ident => #str_val }
574        })
575        .collect();
576
577    // Generate TryFrom match arms
578    let try_from_arms: Vec<_> = variant_infos
579        .iter()
580        .map(|(ident, str_val)| {
581            quote! { #str_val => Ok(#name::#ident) }
582        })
583        .collect();
584
585    let expanded = quote! {
586        impl #name {
587            /// Returns the string representation of this enum variant.
588            pub fn as_str(&self) -> &'static str {
589                match self {
590                    #(#as_str_arms),*
591                }
592            }
593        }
594
595        impl ::core::fmt::Display for #name {
596            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
597                f.write_str(self.as_str())
598            }
599        }
600
601        impl ::core::convert::TryFrom<&str> for #name {
602            type Error = ::anyhow::Error;
603
604            fn try_from(value: &str) -> ::core::result::Result<Self, Self::Error> {
605                match value {
606                    #(#try_from_arms),*,
607                    _ => Err(::anyhow::anyhow!("unknown {}: {}", stringify!(#name), value)),
608                }
609            }
610        }
611
612        impl ::core::default::Default for #name {
613            fn default() -> Self {
614                #name::#default_variant
615            }
616        }
617    };
618
619    expanded.into()
620}