Skip to main content

tele_macros/
lib.rs

1use std::collections::HashMap;
2
3use proc_macro::TokenStream;
4
5use proc_macro_crate::{FoundCrate, crate_name};
6use proc_macro2::TokenStream as TokenStream2;
7use quote::{ToTokens, format_ident, quote};
8use syn::punctuated::Punctuated;
9use syn::{
10    Data, DeriveInput, Fields, GenericArgument, LitStr, Meta, PathArguments, Token, Type, Variant,
11    parse_macro_input,
12};
13
14#[proc_macro_derive(BotCommands, attributes(command))]
15pub fn derive_bot_commands(input: TokenStream) -> TokenStream {
16    match derive_bot_commands_impl(parse_macro_input!(input as DeriveInput)) {
17        Ok(tokens) => tokens.into(),
18        Err(error) => error.to_compile_error().into(),
19    }
20}
21
22fn derive_bot_commands_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
23    let enum_name = input.ident;
24    let generics = input.generics;
25    let tele_path = tele_crate_path();
26
27    let data = match input.data {
28        Data::Enum(data) => data,
29        _ => {
30            return Err(syn::Error::new_spanned(
31                enum_name,
32                "BotCommands can only be derived for enums",
33            ));
34        }
35    };
36
37    let mut parse_arms = Vec::new();
38    let mut description_entries = Vec::new();
39    let mut known_command_variants = HashMap::<String, String>::new();
40
41    for variant in data.variants {
42        let attrs = parse_variant_attrs(&variant)?;
43        let variant_ident = variant.ident.clone();
44        let command_name = attrs
45            .rename
46            .unwrap_or_else(|| to_snake_case(&variant_ident.to_string()));
47        validate_command_name(&command_name, &variant_ident)?;
48        let description = attrs
49            .description
50            .unwrap_or_else(|| format!("{command_name} command"));
51        validate_command_description(&description, &variant_ident)?;
52
53        let mut parse_names = Vec::with_capacity(1 + attrs.aliases.len());
54        parse_names.push(command_name.clone());
55        parse_names.extend(attrs.aliases);
56        for parse_name in &parse_names {
57            validate_command_name(parse_name, &variant_ident)?;
58            if let Some(existing_variant) = known_command_variants.get(parse_name) {
59                return Err(syn::Error::new_spanned(
60                    &variant_ident,
61                    format!(
62                        "command name `{parse_name}` for variant `{variant_ident}` conflicts with variant `{existing_variant}`"
63                    ),
64                ));
65            }
66
67            known_command_variants.insert(parse_name.clone(), variant_ident.to_string());
68        }
69
70        let name_lit = LitStr::new(&command_name, variant_ident.span());
71        let desc_lit = LitStr::new(&description, variant_ident.span());
72
73        description_entries.push(quote! {
74            #tele_path::bot::CommandDescription {
75                command: #name_lit,
76                description: #desc_lit,
77            }
78        });
79
80        let parse_arm = parse_arm_for_variant(&enum_name, &variant_ident, &variant, &tele_path)?;
81        for parse_name in parse_names {
82            let parse_name_lit = LitStr::new(&parse_name, variant_ident.span());
83            let parse_arm_tokens = parse_arm.clone();
84            parse_arms.push(quote! {
85                #parse_name_lit => #parse_arm_tokens
86            });
87        }
88    }
89
90    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
91
92    Ok(quote! {
93        impl #impl_generics #tele_path::bot::BotCommands for #enum_name #ty_generics #where_clause {
94            fn parse(command: &str, args: &str) -> Option<Self> {
95                let args = args.trim();
96                match command {
97                    #(#parse_arms,)*
98                    _ => None,
99                }
100            }
101
102            fn descriptions() -> &'static [#tele_path::bot::CommandDescription] {
103                &[
104                    #(#description_entries),*
105                ]
106            }
107        }
108    })
109}
110
111fn parse_arm_for_variant(
112    enum_name: &syn::Ident,
113    variant_ident: &syn::Ident,
114    variant: &Variant,
115    tele_path: &TokenStream2,
116) -> syn::Result<TokenStream2> {
117    match &variant.fields {
118        Fields::Unit => Ok(quote! {
119            if args.is_empty() {
120                Some(#enum_name::#variant_ident)
121            } else {
122                None
123            }
124        }),
125        Fields::Unnamed(fields) => {
126            if fields.unnamed.is_empty() {
127                return Err(syn::Error::new_spanned(
128                    fields,
129                    "tuple command variants must have at least one field",
130                ));
131            }
132
133            let mut value_bindings = Vec::new();
134            let mut value_names = Vec::new();
135            let field_count = fields.unnamed.len();
136
137            for (index, field) in fields.unnamed.iter().enumerate() {
138                let value_ident = format_ident!("__arg_{index}");
139                let is_last = index + 1 == field_count;
140                let ty = &field.ty;
141                validate_field_type(ty, field)?;
142                let value_expr = parse_value_expr(ty, is_last);
143
144                value_bindings.push(quote! {
145                    let #value_ident: #ty = #value_expr;
146                });
147                value_names.push(value_ident);
148            }
149
150            Ok(quote! {
151                {
152                    let __tokens = #tele_path::bot::tokenize_command_args(args)?;
153                    let mut __cursor: usize = 0;
154                    #(#value_bindings)*
155
156                    if __cursor < __tokens.len() {
157                        None
158                    } else {
159                        Some(#enum_name::#variant_ident(#(#value_names),*))
160                    }
161                }
162            })
163        }
164        Fields::Named(fields) => {
165            if fields.named.is_empty() {
166                return Err(syn::Error::new_spanned(
167                    fields,
168                    "named command variants must have at least one field",
169                ));
170            }
171
172            let mut value_bindings = Vec::new();
173            let mut field_assignments = Vec::new();
174            let field_count = fields.named.len();
175
176            for (index, field) in fields.named.iter().enumerate() {
177                let value_ident = format_ident!("__arg_{index}");
178                let field_ident = field.ident.clone().ok_or_else(|| {
179                    syn::Error::new_spanned(field, "named field missing identifier")
180                })?;
181                let is_last = index + 1 == field_count;
182                let ty = &field.ty;
183                validate_field_type(ty, field)?;
184                let value_expr = parse_value_expr(ty, is_last);
185
186                value_bindings.push(quote! {
187                    let #value_ident: #ty = #value_expr;
188                });
189                field_assignments.push(quote! {
190                    #field_ident: #value_ident
191                });
192            }
193
194            Ok(quote! {
195                {
196                    let __tokens = #tele_path::bot::tokenize_command_args(args)?;
197                    let mut __cursor: usize = 0;
198                    #(#value_bindings)*
199
200                    if __cursor < __tokens.len() {
201                        None
202                    } else {
203                        Some(#enum_name::#variant_ident { #(#field_assignments),* })
204                    }
205                }
206            })
207        }
208    }
209}
210
211fn parse_value_expr(ty: &Type, is_last: bool) -> TokenStream2 {
212    if is_string_type(ty) {
213        if is_last {
214            return quote! {
215                if __cursor >= __tokens.len() {
216                    String::new()
217                } else {
218                    let value = __tokens[__cursor..].join(" ");
219                    __cursor = __tokens.len();
220                    value
221                }
222            };
223        }
224
225        return quote! {
226            {
227                let token = match __tokens.get(__cursor) {
228                    Some(token) => token,
229                    None => return None,
230                };
231                __cursor += 1;
232                token.clone()
233            }
234        };
235    }
236
237    if let Some(inner) = option_inner_type(ty) {
238        if is_string_type(inner) {
239            if is_last {
240                return quote! {
241                    if __cursor >= __tokens.len() {
242                        None
243                    } else {
244                        let value = __tokens[__cursor..].join(" ");
245                        __cursor = __tokens.len();
246                        Some(value)
247                    }
248                };
249            }
250
251            return quote! {
252                if __cursor >= __tokens.len() {
253                    None
254                } else {
255                    let token = __tokens[__cursor].clone();
256                    __cursor += 1;
257                    Some(token)
258                }
259            };
260        }
261
262        return quote! {
263            if __cursor >= __tokens.len() {
264                None
265            } else {
266                let token = &__tokens[__cursor];
267                __cursor += 1;
268                Some(token.parse::<#inner>().ok()?)
269            }
270        };
271    }
272
273    quote! {
274        {
275            let token = match __tokens.get(__cursor) {
276                Some(token) => token,
277                None => return None,
278            };
279            __cursor += 1;
280            token.parse::<#ty>().ok()?
281        }
282    }
283}
284
/// Options gathered from the `#[command(...)]` attributes of one enum variant.
#[derive(Default)]
struct VariantAttrs {
    /// Explicit command name from `rename = "..."`, if given.
    rename: Option<String>,
    /// Command description from `description = "..."`, if given.
    description: Option<String>,
    /// Extra accepted names from `alias = "..."` and `aliases("...", ...)`,
    /// in declaration order.
    aliases: Vec<String>,
}
291
292fn parse_variant_attrs(variant: &Variant) -> syn::Result<VariantAttrs> {
293    let mut parsed = VariantAttrs::default();
294
295    for attr in &variant.attrs {
296        if !attr.path().is_ident("command") {
297            continue;
298        }
299
300        let nested: Punctuated<Meta, Token![,]> =
301            attr.parse_args_with(Punctuated::parse_terminated)?;
302
303        for meta in nested {
304            match meta {
305                Meta::NameValue(name_value) if name_value.path.is_ident("rename") => {
306                    let literal: LitStr = syn::parse2(name_value.value.into_token_stream())?;
307                    let value = literal.value();
308                    if parsed.rename.replace(value).is_some() {
309                        return Err(syn::Error::new_spanned(
310                            name_value.path,
311                            "duplicate `rename` attribute",
312                        ));
313                    }
314                }
315                Meta::NameValue(name_value) if name_value.path.is_ident("description") => {
316                    let literal: LitStr = syn::parse2(name_value.value.into_token_stream())?;
317                    let value = literal.value();
318                    if parsed.description.replace(value).is_some() {
319                        return Err(syn::Error::new_spanned(
320                            name_value.path,
321                            "duplicate `description` attribute",
322                        ));
323                    }
324                }
325                Meta::NameValue(name_value) if name_value.path.is_ident("alias") => {
326                    let literal: LitStr = syn::parse2(name_value.value.into_token_stream())?;
327                    parsed.aliases.push(literal.value());
328                }
329                Meta::List(list) if list.path.is_ident("aliases") => {
330                    let aliases: Punctuated<LitStr, Token![,]> =
331                        list.parse_args_with(Punctuated::parse_terminated)?;
332                    if aliases.is_empty() {
333                        return Err(syn::Error::new_spanned(
334                            list.path,
335                            "`aliases(...)` requires at least one alias",
336                        ));
337                    }
338                    parsed
339                        .aliases
340                        .extend(aliases.into_iter().map(|alias| alias.value()));
341                }
342                other => {
343                    return Err(syn::Error::new_spanned(
344                        other,
345                        "unsupported command attribute, expected `rename = \"...\"`, `description = \"...\"`, `alias = \"...\"`, or `aliases(\"...\", ...)`",
346                    ));
347                }
348            }
349        }
350    }
351
352    Ok(parsed)
353}
354
355fn validate_command_name(name: &str, span: &impl ToTokens) -> syn::Result<()> {
356    if name.is_empty() {
357        return Err(syn::Error::new_spanned(
358            span,
359            "command name cannot be empty",
360        ));
361    }
362
363    if name.len() > 32 {
364        return Err(syn::Error::new_spanned(
365            span,
366            format!("command name `{name}` exceeds Telegram max length of 32"),
367        ));
368    }
369
370    let mut chars = name.chars();
371    let Some(first_char) = chars.next() else {
372        return Err(syn::Error::new_spanned(
373            span,
374            "command name cannot be empty",
375        ));
376    };
377
378    if !first_char.is_ascii_lowercase() {
379        return Err(syn::Error::new_spanned(
380            span,
381            format!("command name `{name}` must start with a lowercase ASCII letter"),
382        ));
383    }
384
385    if !name
386        .chars()
387        .all(|ch| ch.is_ascii_lowercase() || ch.is_ascii_digit() || ch == '_')
388    {
389        return Err(syn::Error::new_spanned(
390            span,
391            format!(
392                "command name `{name}` contains invalid characters; use lowercase ASCII letters, digits, and `_`"
393            ),
394        ));
395    }
396
397    Ok(())
398}
399
400fn validate_command_description(description: &str, span: &impl ToTokens) -> syn::Result<()> {
401    if description.is_empty() {
402        return Err(syn::Error::new_spanned(
403            span,
404            "command description cannot be empty",
405        ));
406    }
407
408    if description.len() > 256 {
409        return Err(syn::Error::new_spanned(
410            span,
411            format!("command description exceeds Telegram max length of 256: `{description}`"),
412        ));
413    }
414
415    Ok(())
416}
417
418fn validate_field_type(ty: &Type, span: &impl ToTokens) -> syn::Result<()> {
419    if matches!(ty, Type::Reference(_)) {
420        return Err(syn::Error::new_spanned(
421            span,
422            "borrowed command argument types are unsupported; use owned types like `String`",
423        ));
424    }
425
426    if let Some(inner) = option_inner_type(ty)
427        && matches!(inner, Type::Reference(_))
428    {
429        return Err(syn::Error::new_spanned(
430            span,
431            "borrowed command argument types inside `Option` are unsupported; use `Option<String>`",
432        ));
433    }
434
435    Ok(())
436}
437
438fn is_string_type(ty: &Type) -> bool {
439    match ty {
440        Type::Path(type_path) => type_path
441            .path
442            .segments
443            .last()
444            .is_some_and(|segment| segment.ident == "String"),
445        _ => false,
446    }
447}
448
449fn option_inner_type(ty: &Type) -> Option<&Type> {
450    let type_path = match ty {
451        Type::Path(type_path) => type_path,
452        _ => return None,
453    };
454
455    let segment = type_path.path.segments.last()?;
456    if segment.ident != "Option" {
457        return None;
458    }
459
460    let args = match &segment.arguments {
461        PathArguments::AngleBracketed(args) => args,
462        _ => return None,
463    };
464
465    if args.args.len() != 1 {
466        return None;
467    }
468
469    match args.args.first()? {
470        GenericArgument::Type(inner) => Some(inner),
471        _ => None,
472    }
473}
474
/// Converts an identifier such as `StartNow` or `HTTPServer` to snake case
/// (`start_now`, `http_server`).
///
/// An underscore is inserted before an uppercase character that is not the
/// first character when the preceding character is lowercase or the following
/// character is lowercase. No underscore is inserted when the output already
/// ends with one, so `Foo_Bar` becomes `foo_bar` rather than `foo__bar`.
/// Non-uppercase characters are copied unchanged.
fn to_snake_case(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut result = String::with_capacity(name.len() + 4);

    for (index, &ch) in chars.iter().enumerate() {
        if !ch.is_uppercase() {
            result.push(ch);
            continue;
        }

        let prev_is_lower = index > 0 && chars[index - 1].is_lowercase();
        let next_is_lower = chars.get(index + 1).is_some_and(|c| c.is_lowercase());
        let starts_new_word = index > 0 && (prev_is_lower || next_is_lower);

        // Avoid doubling a separator the identifier already contains.
        if starts_new_word && !result.ends_with('_') {
            result.push('_');
        }

        result.extend(ch.to_lowercase());
    }

    result
}
499
500fn tele_crate_path() -> TokenStream2 {
501    match crate_name("tele") {
502        Ok(FoundCrate::Itself) => quote!(::tele),
503        Ok(FoundCrate::Name(name)) => {
504            let ident = format_ident!("{name}");
505            quote!(::#ident)
506        }
507        Err(_) => quote!(::tele),
508    }
509}