1use std::collections::HashMap;
2
3use proc_macro::TokenStream;
4
5use proc_macro_crate::{FoundCrate, crate_name};
6use proc_macro2::TokenStream as TokenStream2;
7use quote::{ToTokens, format_ident, quote};
8use syn::punctuated::Punctuated;
9use syn::{
10 Data, DeriveInput, Fields, GenericArgument, LitStr, Meta, PathArguments, Token, Type, Variant,
11 parse_macro_input,
12};
13
14#[proc_macro_derive(BotCommands, attributes(command))]
15pub fn derive_bot_commands(input: TokenStream) -> TokenStream {
16 match derive_bot_commands_impl(parse_macro_input!(input as DeriveInput)) {
17 Ok(tokens) => tokens.into(),
18 Err(error) => error.to_compile_error().into(),
19 }
20}
21
22fn derive_bot_commands_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
23 let enum_name = input.ident;
24 let generics = input.generics;
25 let tele_path = tele_crate_path();
26
27 let data = match input.data {
28 Data::Enum(data) => data,
29 _ => {
30 return Err(syn::Error::new_spanned(
31 enum_name,
32 "BotCommands can only be derived for enums",
33 ));
34 }
35 };
36
37 let mut parse_arms = Vec::new();
38 let mut description_entries = Vec::new();
39 let mut known_command_variants = HashMap::<String, String>::new();
40
41 for variant in data.variants {
42 let attrs = parse_variant_attrs(&variant)?;
43 let variant_ident = variant.ident.clone();
44 let command_name = attrs
45 .rename
46 .unwrap_or_else(|| to_snake_case(&variant_ident.to_string()));
47 validate_command_name(&command_name, &variant_ident)?;
48 let description = attrs
49 .description
50 .unwrap_or_else(|| format!("{command_name} command"));
51 validate_command_description(&description, &variant_ident)?;
52
53 let mut parse_names = Vec::with_capacity(1 + attrs.aliases.len());
54 parse_names.push(command_name.clone());
55 parse_names.extend(attrs.aliases);
56 for parse_name in &parse_names {
57 validate_command_name(parse_name, &variant_ident)?;
58 if let Some(existing_variant) = known_command_variants.get(parse_name) {
59 return Err(syn::Error::new_spanned(
60 &variant_ident,
61 format!(
62 "command name `{parse_name}` for variant `{variant_ident}` conflicts with variant `{existing_variant}`"
63 ),
64 ));
65 }
66
67 known_command_variants.insert(parse_name.clone(), variant_ident.to_string());
68 }
69
70 let name_lit = LitStr::new(&command_name, variant_ident.span());
71 let desc_lit = LitStr::new(&description, variant_ident.span());
72
73 description_entries.push(quote! {
74 #tele_path::bot::CommandDescription {
75 command: #name_lit,
76 description: #desc_lit,
77 }
78 });
79
80 let parse_arm = parse_arm_for_variant(&enum_name, &variant_ident, &variant, &tele_path)?;
81 for parse_name in parse_names {
82 let parse_name_lit = LitStr::new(&parse_name, variant_ident.span());
83 let parse_arm_tokens = parse_arm.clone();
84 parse_arms.push(quote! {
85 #parse_name_lit => #parse_arm_tokens
86 });
87 }
88 }
89
90 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
91
92 Ok(quote! {
93 impl #impl_generics #tele_path::bot::BotCommands for #enum_name #ty_generics #where_clause {
94 fn parse(command: &str, args: &str) -> Option<Self> {
95 let args = args.trim();
96 match command {
97 #(#parse_arms,)*
98 _ => None,
99 }
100 }
101
102 fn descriptions() -> &'static [#tele_path::bot::CommandDescription] {
103 &[
104 #(#description_entries),*
105 ]
106 }
107 }
108 })
109}
110
111fn parse_arm_for_variant(
112 enum_name: &syn::Ident,
113 variant_ident: &syn::Ident,
114 variant: &Variant,
115 tele_path: &TokenStream2,
116) -> syn::Result<TokenStream2> {
117 match &variant.fields {
118 Fields::Unit => Ok(quote! {
119 if args.is_empty() {
120 Some(#enum_name::#variant_ident)
121 } else {
122 None
123 }
124 }),
125 Fields::Unnamed(fields) => {
126 if fields.unnamed.is_empty() {
127 return Err(syn::Error::new_spanned(
128 fields,
129 "tuple command variants must have at least one field",
130 ));
131 }
132
133 let mut value_bindings = Vec::new();
134 let mut value_names = Vec::new();
135 let field_count = fields.unnamed.len();
136
137 for (index, field) in fields.unnamed.iter().enumerate() {
138 let value_ident = format_ident!("__arg_{index}");
139 let is_last = index + 1 == field_count;
140 let ty = &field.ty;
141 validate_field_type(ty, field)?;
142 let value_expr = parse_value_expr(ty, is_last);
143
144 value_bindings.push(quote! {
145 let #value_ident: #ty = #value_expr;
146 });
147 value_names.push(value_ident);
148 }
149
150 Ok(quote! {
151 {
152 let __tokens = #tele_path::bot::tokenize_command_args(args)?;
153 let mut __cursor: usize = 0;
154 #(#value_bindings)*
155
156 if __cursor < __tokens.len() {
157 None
158 } else {
159 Some(#enum_name::#variant_ident(#(#value_names),*))
160 }
161 }
162 })
163 }
164 Fields::Named(fields) => {
165 if fields.named.is_empty() {
166 return Err(syn::Error::new_spanned(
167 fields,
168 "named command variants must have at least one field",
169 ));
170 }
171
172 let mut value_bindings = Vec::new();
173 let mut field_assignments = Vec::new();
174 let field_count = fields.named.len();
175
176 for (index, field) in fields.named.iter().enumerate() {
177 let value_ident = format_ident!("__arg_{index}");
178 let field_ident = field.ident.clone().ok_or_else(|| {
179 syn::Error::new_spanned(field, "named field missing identifier")
180 })?;
181 let is_last = index + 1 == field_count;
182 let ty = &field.ty;
183 validate_field_type(ty, field)?;
184 let value_expr = parse_value_expr(ty, is_last);
185
186 value_bindings.push(quote! {
187 let #value_ident: #ty = #value_expr;
188 });
189 field_assignments.push(quote! {
190 #field_ident: #value_ident
191 });
192 }
193
194 Ok(quote! {
195 {
196 let __tokens = #tele_path::bot::tokenize_command_args(args)?;
197 let mut __cursor: usize = 0;
198 #(#value_bindings)*
199
200 if __cursor < __tokens.len() {
201 None
202 } else {
203 Some(#enum_name::#variant_ident { #(#field_assignments),* })
204 }
205 }
206 })
207 }
208 }
209}
210
211fn parse_value_expr(ty: &Type, is_last: bool) -> TokenStream2 {
212 if is_string_type(ty) {
213 if is_last {
214 return quote! {
215 if __cursor >= __tokens.len() {
216 String::new()
217 } else {
218 let value = __tokens[__cursor..].join(" ");
219 __cursor = __tokens.len();
220 value
221 }
222 };
223 }
224
225 return quote! {
226 {
227 let token = match __tokens.get(__cursor) {
228 Some(token) => token,
229 None => return None,
230 };
231 __cursor += 1;
232 token.clone()
233 }
234 };
235 }
236
237 if let Some(inner) = option_inner_type(ty) {
238 if is_string_type(inner) {
239 if is_last {
240 return quote! {
241 if __cursor >= __tokens.len() {
242 None
243 } else {
244 let value = __tokens[__cursor..].join(" ");
245 __cursor = __tokens.len();
246 Some(value)
247 }
248 };
249 }
250
251 return quote! {
252 if __cursor >= __tokens.len() {
253 None
254 } else {
255 let token = __tokens[__cursor].clone();
256 __cursor += 1;
257 Some(token)
258 }
259 };
260 }
261
262 return quote! {
263 if __cursor >= __tokens.len() {
264 None
265 } else {
266 let token = &__tokens[__cursor];
267 __cursor += 1;
268 Some(token.parse::<#inner>().ok()?)
269 }
270 };
271 }
272
273 quote! {
274 {
275 let token = match __tokens.get(__cursor) {
276 Some(token) => token,
277 None => return None,
278 };
279 __cursor += 1;
280 token.parse::<#ty>().ok()?
281 }
282 }
283}
284
/// Options collected from a variant's `#[command(...)]` attributes.
#[derive(Default)]
struct VariantAttrs {
    /// Extra names from `alias = "..."` / `aliases(...)` that also parse to this variant.
    aliases: Vec<String>,
    /// Explicit command name from `rename = "..."`; `None` means derive it from the variant name.
    rename: Option<String>,
    /// Text from `description = "..."`; `None` means a default description is generated.
    description: Option<String>,
}
291
292fn parse_variant_attrs(variant: &Variant) -> syn::Result<VariantAttrs> {
293 let mut parsed = VariantAttrs::default();
294
295 for attr in &variant.attrs {
296 if !attr.path().is_ident("command") {
297 continue;
298 }
299
300 let nested: Punctuated<Meta, Token![,]> =
301 attr.parse_args_with(Punctuated::parse_terminated)?;
302
303 for meta in nested {
304 match meta {
305 Meta::NameValue(name_value) if name_value.path.is_ident("rename") => {
306 let literal: LitStr = syn::parse2(name_value.value.into_token_stream())?;
307 let value = literal.value();
308 if parsed.rename.replace(value).is_some() {
309 return Err(syn::Error::new_spanned(
310 name_value.path,
311 "duplicate `rename` attribute",
312 ));
313 }
314 }
315 Meta::NameValue(name_value) if name_value.path.is_ident("description") => {
316 let literal: LitStr = syn::parse2(name_value.value.into_token_stream())?;
317 let value = literal.value();
318 if parsed.description.replace(value).is_some() {
319 return Err(syn::Error::new_spanned(
320 name_value.path,
321 "duplicate `description` attribute",
322 ));
323 }
324 }
325 Meta::NameValue(name_value) if name_value.path.is_ident("alias") => {
326 let literal: LitStr = syn::parse2(name_value.value.into_token_stream())?;
327 parsed.aliases.push(literal.value());
328 }
329 Meta::List(list) if list.path.is_ident("aliases") => {
330 let aliases: Punctuated<LitStr, Token![,]> =
331 list.parse_args_with(Punctuated::parse_terminated)?;
332 if aliases.is_empty() {
333 return Err(syn::Error::new_spanned(
334 list.path,
335 "`aliases(...)` requires at least one alias",
336 ));
337 }
338 parsed
339 .aliases
340 .extend(aliases.into_iter().map(|alias| alias.value()));
341 }
342 other => {
343 return Err(syn::Error::new_spanned(
344 other,
345 "unsupported command attribute, expected `rename = \"...\"`, `description = \"...\"`, `alias = \"...\"`, or `aliases(\"...\", ...)`",
346 ));
347 }
348 }
349 }
350 }
351
352 Ok(parsed)
353}
354
355fn validate_command_name(name: &str, span: &impl ToTokens) -> syn::Result<()> {
356 if name.is_empty() {
357 return Err(syn::Error::new_spanned(
358 span,
359 "command name cannot be empty",
360 ));
361 }
362
363 if name.len() > 32 {
364 return Err(syn::Error::new_spanned(
365 span,
366 format!("command name `{name}` exceeds Telegram max length of 32"),
367 ));
368 }
369
370 let mut chars = name.chars();
371 let Some(first_char) = chars.next() else {
372 return Err(syn::Error::new_spanned(
373 span,
374 "command name cannot be empty",
375 ));
376 };
377
378 if !first_char.is_ascii_lowercase() {
379 return Err(syn::Error::new_spanned(
380 span,
381 format!("command name `{name}` must start with a lowercase ASCII letter"),
382 ));
383 }
384
385 if !name
386 .chars()
387 .all(|ch| ch.is_ascii_lowercase() || ch.is_ascii_digit() || ch == '_')
388 {
389 return Err(syn::Error::new_spanned(
390 span,
391 format!(
392 "command name `{name}` contains invalid characters; use lowercase ASCII letters, digits, and `_`"
393 ),
394 ));
395 }
396
397 Ok(())
398}
399
400fn validate_command_description(description: &str, span: &impl ToTokens) -> syn::Result<()> {
401 if description.is_empty() {
402 return Err(syn::Error::new_spanned(
403 span,
404 "command description cannot be empty",
405 ));
406 }
407
408 if description.len() > 256 {
409 return Err(syn::Error::new_spanned(
410 span,
411 format!("command description exceeds Telegram max length of 256: `{description}`"),
412 ));
413 }
414
415 Ok(())
416}
417
418fn validate_field_type(ty: &Type, span: &impl ToTokens) -> syn::Result<()> {
419 if matches!(ty, Type::Reference(_)) {
420 return Err(syn::Error::new_spanned(
421 span,
422 "borrowed command argument types are unsupported; use owned types like `String`",
423 ));
424 }
425
426 if let Some(inner) = option_inner_type(ty)
427 && matches!(inner, Type::Reference(_))
428 {
429 return Err(syn::Error::new_spanned(
430 span,
431 "borrowed command argument types inside `Option` are unsupported; use `Option<String>`",
432 ));
433 }
434
435 Ok(())
436}
437
438fn is_string_type(ty: &Type) -> bool {
439 match ty {
440 Type::Path(type_path) => type_path
441 .path
442 .segments
443 .last()
444 .is_some_and(|segment| segment.ident == "String"),
445 _ => false,
446 }
447}
448
449fn option_inner_type(ty: &Type) -> Option<&Type> {
450 let type_path = match ty {
451 Type::Path(type_path) => type_path,
452 _ => return None,
453 };
454
455 let segment = type_path.path.segments.last()?;
456 if segment.ident != "Option" {
457 return None;
458 }
459
460 let args = match &segment.arguments {
461 PathArguments::AngleBracketed(args) => args,
462 _ => return None,
463 };
464
465 if args.args.len() != 1 {
466 return None;
467 }
468
469 match args.args.first()? {
470 GenericArgument::Type(inner) => Some(inner),
471 _ => None,
472 }
473}
474
/// Converts a CamelCase identifier into snake_case.
///
/// An uppercase character is lowercased. It is preceded by `_` when it is not the first
/// character and either the previous character is lowercase (`HelpMe` -> `help_me`) or the
/// next one is (`HTTPServer` -> `http_server`). All other characters are copied unchanged.
fn to_snake_case(name: &str) -> String {
    let mut snake = String::with_capacity(name.len());
    let mut previous: Option<char> = None;
    let mut remaining = name.chars().peekable();

    while let Some(current) = remaining.next() {
        if current.is_uppercase() {
            let next_is_lower = remaining.peek().is_some_and(|next| next.is_lowercase());
            if previous.is_some_and(|before| before.is_lowercase() || next_is_lower) {
                snake.push('_');
            }
            snake.extend(current.to_lowercase());
        } else {
            snake.push(current);
        }
        previous = Some(current);
    }

    snake
}
499
500fn tele_crate_path() -> TokenStream2 {
501 match crate_name("tele") {
502 Ok(FoundCrate::Itself) => quote!(::tele),
503 Ok(FoundCrate::Name(name)) => {
504 let ident = format_ident!("{name}");
505 quote!(::#ident)
506 }
507 Err(_) => quote!(::tele),
508 }
509}