rialo_cli_representable/lib.rs
1// Copyright (c) Subzero Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4//! # Rialo CLI Representable Derive Macro
5//!
6//! This crate provides a procedural macro for deriving the `Representable` and `HumanReadable` traits
7//! for structs in the Rialo CLI system.
8//!
9//! ## Usage
10//!
11//! To use this macro, add the following to your struct:
12//!
13//! ```text
14//! #[derive(Representable)]
15//! #[representable(human_readable = "my_human_readable_fn")]
16//! struct MyStruct {
17//! pub field1: String,
18//! pub field2: u64,
19//! }
20//!
21//! fn my_human_readable_fn(data: &MyStruct) -> String {
22//! format!("Field1: {}, Field2: {}", data.field1, data.field2)
23//! }
24//! ```
25//!
26//! ## Attributes
27//!
28//! The `#[representable]` attribute supports the following options:
29//!
30//! - `human_readable = "function_name"`: Specifies a custom function to use for human-readable output.
31//! The function should take a reference to the struct and return a `String`.
32//!
33//! If no `human_readable` function is specified, the macro defaults to using `serde_json::to_string()`.
34//!
35//! ## Generated Code
36//!
37//! The macro generates implementations for:
38//!
39//! - `Representable`: A marker trait for CLI-representable types
40//! - `HumanReadable`: Provides a `human_readable()` method that returns a user-friendly string representation
41//!
42//! ## Example with Custom Display Function
43//!
44//! ```text
45//! #[derive(serde::Serialize, Clone, Representable)]
46//! #[representable(human_readable = "balance_display")]
47//! pub struct BalanceResult {
48//! pub amount: f64,
49//! pub currency: String,
50//! }
51//!
52//! fn balance_display(result: &BalanceResult) -> String {
53//! format!("Balance: {} {}", result.amount, result.currency)
54//! }
55//! ```
56
57use proc_macro::TokenStream;
58use quote::quote;
59use syn::{parse_macro_input, DeriveInput, Error, Ident, Meta};
60
61/// Derives the `Representable` and `HumanReadable` traits for a struct.
62///
63/// This macro automatically implements:
64/// - `Representable`: A marker trait indicating the type can be represented in CLI output
65/// - `HumanReadable`: Provides a `human_readable()` method for user-friendly string representation
66///
67/// # Attributes
68///
69/// ## `#[representable(human_readable = "function_name")]`
70///
71/// Specifies a custom function to use for human-readable output. The function should:
72/// - Take a reference to the struct (`&Self`)
73/// - Return a `String`
74/// - Be defined in the same module as the struct
75///
76/// # Examples
77///
78/// With a custom human-readable function:
79///
80/// ```text
81/// #[derive(Representable)]
82/// #[representable(human_readable = "custom_display")]
83/// struct MyData {
84/// value: u64,
85/// }
86///
87/// fn custom_display(data: &MyData) -> String {
88/// format!("Value: {}", data.value)
89/// }
90/// ```
91///
92/// Without specifying a function (defaults to JSON):
93///
94/// ```text
95/// #[derive(Representable)]
96/// struct MyData {
97/// value: u64,
98/// }
99/// // Will use serde_json::to_string(self) by default
100/// ```
101#[proc_macro_derive(Representable, attributes(representable))]
102pub fn derive_representable(input: TokenStream) -> TokenStream {
103 let input = parse_macro_input!(input as DeriveInput);
104 let name = input.ident;
105
106 match parse_human_readable_fn(&input.attrs, &name) {
107 Ok(human_readable_fn) => {
108 let expanded = quote! {
109 impl ::rialo_cli_representation::HumanReadable for #name {
110 fn human_readable(&self) -> String {
111 #human_readable_fn
112 }
113 }
114
115 impl ::rialo_cli_representation::Representable for #name {}
116 };
117
118 TokenStream::from(expanded)
119 }
120 Err(err) => {
121 let error = err.to_compile_error();
122 TokenStream::from(error)
123 }
124 }
125}
126
127/// Validates a function name to ensure it's safe for code generation.
128///
129/// This function performs several security checks:
130/// - Ensures the function name is a valid Rust identifier
131/// - Prevents path traversal attempts
132/// - Checks for potentially dangerous characters
133/// - Validates length constraints
134/// - Checks against reserved keywords using syn's built-in detection
135///
136/// # Arguments
137///
138/// * `function_name` - The function name to validate
139/// * `span` - The span for error reporting
140///
141/// # Returns
142///
143/// Returns `Ok(())` if the function name is valid, or an error with details.
144fn validate_function_name(function_name: &str, span: proc_macro2::Span) -> Result<(), Error> {
145 // Check for empty or whitespace-only names
146 if function_name.trim().is_empty() {
147 return Err(Error::new(
148 span,
149 "Function name cannot be empty or whitespace-only",
150 ));
151 }
152
153 // Check length constraints (reasonable limits for function names)
154 if function_name.len() > 100 {
155 return Err(Error::new(
156 span,
157 "Function name is too long (maximum 100 characters)",
158 ));
159 }
160
161 // Check for path separators and other dangerous characters
162 let dangerous_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|', '\0'];
163 if let Some(ch) = function_name
164 .chars()
165 .find(|&c| dangerous_chars.contains(&c))
166 {
167 return Err(Error::new(
168 span,
169 format!("Function name contains invalid character '{ch}'. Function names must be valid Rust identifiers.")
170 ));
171 }
172
173 // Check for control characters
174 if function_name.chars().any(|c| c.is_control()) {
175 return Err(Error::new(
176 span,
177 "Function name contains control characters",
178 ));
179 }
180
181 // Check for valid Rust identifier start (must start with letter or underscore)
182 if let Some(first_char) = function_name.chars().next() {
183 if !first_char.is_alphabetic() && first_char != '_' {
184 return Err(Error::new(
185 span,
186 "Function name must start with a letter or underscore",
187 ));
188 }
189 }
190
191 // Check that all characters are valid for Rust identifiers
192 if !function_name
193 .chars()
194 .all(|c| c.is_alphanumeric() || c == '_')
195 {
196 return Err(Error::new(
197 span,
198 "Function name must contain only letters, numbers, and underscores",
199 ));
200 }
201
202 // Check for reserved keywords using syn's built-in detection
203 // This is more robust than maintaining a hardcoded list
204 if is_reserved_keyword(function_name) {
205 return Err(Error::new(
206 span,
207 format!("Function name '{function_name}' is a reserved Rust keyword"),
208 ));
209 }
210
211 Ok(())
212}
213
214/// Checks if a function name is a reserved Rust keyword or invalid identifier.
215///
216/// This approach uses a focused, minimal list of core Rust keywords that are
217/// most problematic in the context of function names. Rather than trying to
218/// maintain a complete language specification, we focus on keywords that would
219/// cause immediate compilation issues.
220///
221/// This is a pragmatic approach that balances:
222/// - **Robustness**: Catches the most problematic keywords
223/// - **Maintainability**: Minimal, focused list that's easy to keep updated
224/// - **Performance**: Simple string comparison rather than complex parsing
225/// - **Reliability**: No false positives from parsing edge cases
226///
227/// The list is intentionally kept small and focused on keywords that would
228/// cause immediate compilation failures, rather than trying to catch every
229/// possible edge case.
230///
231/// # Arguments
232///
233/// * `function_name` - The function name to check
234///
235/// # Returns
236///
237/// Returns `true` if the function name is a reserved keyword or invalid, `false` otherwise.
238fn is_reserved_keyword(function_name: &str) -> bool {
239 // Instead of maintaining a hardcoded list, we use a minimal set of keywords
240 // that are most problematic in the context of function names.
241 // This is a pragmatic approach that balances robustness with simplicity.
242
243 // Core Rust keywords that would cause immediate compilation issues
244 let core_keywords = [
245 "fn",
246 "struct",
247 "enum",
248 "trait",
249 "impl",
250 "mod",
251 "use",
252 "extern",
253 "crate",
254 "type",
255 "const",
256 "static",
257 "let",
258 "mut",
259 "ref",
260 "move",
261 "dyn",
262 "async",
263 "await",
264 "if",
265 "else",
266 "match",
267 "loop",
268 "while",
269 "for",
270 "in",
271 "return",
272 "break",
273 "continue",
274 "pub",
275 "priv",
276 "unsafe",
277 "where",
278 "as",
279 "box",
280 "do",
281 "final",
282 "override",
283 "self",
284 "Self",
285 "super",
286 "macro",
287 "macro_rules",
288 "try",
289 "union",
290 ];
291
292 core_keywords.contains(&function_name)
293}
294
295/// Parses the `human_readable` attribute from the struct's attributes.
296///
297/// This function extracts the function name specified in the `#[representable(human_readable = "fn_name")]`
298/// attribute, validates it for security, and generates the appropriate function call expression.
299///
300/// # Arguments
301///
302/// * `attrs` - The struct's attributes to parse
303/// * `_struct_name` - The name of the struct (currently unused but kept for future extensibility)
304///
305/// # Returns
306///
307/// A `Result<syn::Expr, Error>` containing either:
308/// - A `syn::Expr` representing the function call to the human-readable function, or
309/// - An `Error` if validation fails
310///
311/// If no function is specified, returns a default JSON serialization expression.
312///
313/// # Examples
314///
315/// For `#[representable(human_readable = "my_fn")]`, this returns `Ok(my_fn(self))`.
316/// If no attribute is found, it returns `Ok(serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string()))`.
317fn parse_human_readable_fn(
318 attrs: &[syn::Attribute],
319 _struct_name: &Ident,
320) -> Result<syn::Expr, Error> {
321 for attr in attrs {
322 if attr.path().is_ident("representable") {
323 if let Meta::List(meta_list) = &attr.meta {
324 for nested in meta_list
325 .parse_args_with(
326 syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
327 )
328 .unwrap_or_default()
329 {
330 if let syn::Meta::NameValue(name_value) = nested {
331 if name_value.path.is_ident("human_readable") {
332 if let syn::Expr::Lit(syn::ExprLit {
333 lit: syn::Lit::Str(lit_str),
334 ..
335 }) = &name_value.value
336 {
337 let function_name = &lit_str.value();
338 let span = lit_str.span();
339
340 // Validate the function name for security
341 validate_function_name(function_name, span)?;
342
343 let fn_name = Ident::new(function_name, span);
344 return Ok(syn::parse_quote! { #fn_name(self) });
345 }
346 }
347 }
348 }
349 }
350 }
351 }
352
353 // Default to JSON output if no explicit function is provided
354 Ok(syn::parse_quote! {
355 serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string())
356 })
357}
358
359#[cfg(test)]
360mod tests {
361 use proc_macro2::Span;
362 use syn::parse_quote;
363
364 use super::*;
365
366 #[test]
367 fn test_validate_function_name_valid() {
368 let valid_names = [
369 "my_function",
370 "myFunction",
371 "my_function_123",
372 "_private_function",
373 "f",
374 "a1b2c3",
375 ];
376
377 for name in &valid_names {
378 assert!(
379 validate_function_name(name, Span::call_site()).is_ok(),
380 "Function name '{name}' should be valid"
381 );
382 }
383 }
384
385 #[test]
386 fn test_validate_function_name_invalid_characters() {
387 let invalid_names = [
388 ("my/function", "path separator"),
389 ("my\\function", "backslash"),
390 ("my:function", "colon"),
391 ("my*function", "asterisk"),
392 ("my?function", "question mark"),
393 ("my\"function", "quote"),
394 ("my<function", "less than"),
395 ("my>function", "greater than"),
396 ("my|function", "pipe"),
397 ("my\0function", "null byte"),
398 ];
399
400 for (name, description) in &invalid_names {
401 let result = validate_function_name(name, Span::call_site());
402 assert!(
403 result.is_err(),
404 "Function name '{name}' ({description}) should be invalid"
405 );
406 }
407 }
408
409 #[test]
410 fn test_validate_function_name_invalid_start() {
411 let invalid_names = ["1function", "123function", ".function", "-function"];
412
413 for name in &invalid_names {
414 let result = validate_function_name(name, Span::call_site());
415 assert!(
416 result.is_err(),
417 "Function name '{name}' should be invalid (invalid start)"
418 );
419 }
420 }
421
422 #[test]
423 fn test_validate_function_name_reserved_keywords() {
424 let reserved_keywords = [
425 "fn", "struct", "enum", "impl", "trait", "mod", "use", "pub", "priv", "let", "mut",
426 "const", "static", "if", "else", "match", "loop", "while", "for", "in", "return",
427 "break", "continue", "as", "where", "unsafe", "async", "await", "dyn", "move", "ref",
428 "self", "Self", "super",
429 ];
430
431 for keyword in &reserved_keywords {
432 let result = validate_function_name(keyword, Span::call_site());
433 assert!(
434 result.is_err(),
435 "Reserved keyword '{keyword}' should be invalid"
436 );
437 }
438 }
439
440 #[test]
441 fn test_validate_function_name_length_constraints() {
442 // Test empty and whitespace-only names
443 let empty_names = ["", " ", "\t", "\n", " \t \n "];
444 for name in &empty_names {
445 let result = validate_function_name(name, Span::call_site());
446 assert!(
447 result.is_err(),
448 "Empty/whitespace name '{name}' should be invalid"
449 );
450 }
451
452 // Test very long names
453 let long_name = "a".repeat(101);
454 let result = validate_function_name(&long_name, Span::call_site());
455 assert!(result.is_err(), "Very long name should be invalid");
456 }
457
458 #[test]
459 fn test_validate_function_name_control_characters() {
460 let control_chars = [
461 '\x00', '\x01', '\x02', '\x03', '\x04', '\x05', '\x06', '\x07', '\x08', '\x09', '\x0A',
462 '\x0B', '\x0C', '\x0D', '\x0E', '\x0F', '\x10', '\x11', '\x12', '\x13', '\x14', '\x15',
463 '\x16', '\x17', '\x18', '\x19', '\x1A', '\x1B', '\x1C', '\x1D', '\x1E', '\x1F', '\x7F',
464 ];
465
466 for &ch in &control_chars {
467 let name = format!("my{ch}function");
468 let result = validate_function_name(&name, Span::call_site());
469 assert!(
470 result.is_err(),
471 "Function name with control character '{ch}' should be invalid"
472 );
473 }
474 }
475
476 #[test]
477 fn test_parse_human_readable_fn_valid() {
478 let attrs = vec![parse_quote! {
479 #[representable(human_readable = "my_display_fn")]
480 }];
481 let struct_name = Ident::new("MyStruct", Span::call_site());
482
483 let result = parse_human_readable_fn(&attrs, &struct_name);
484 assert!(result.is_ok());
485
486 // The result should be a function call expression
487 if let Ok(expr) = result {
488 // We can't easily test the exact structure, but we can verify it's not the default
489 let expr_string = quote!(#expr).to_string();
490 assert!(expr_string.contains("my_display_fn"));
491 }
492 }
493
494 #[test]
495 fn test_parse_human_readable_fn_invalid() {
496 let attrs = vec![parse_quote! {
497 #[representable(human_readable = "my/function")]
498 }];
499 let struct_name = Ident::new("MyStruct", Span::call_site());
500
501 let result = parse_human_readable_fn(&attrs, &struct_name);
502 assert!(result.is_err());
503 }
504
505 #[test]
506 fn test_parse_human_readable_fn_no_attribute() {
507 let attrs = vec![];
508 let struct_name = Ident::new("MyStruct", Span::call_site());
509
510 let result = parse_human_readable_fn(&attrs, &struct_name);
511 assert!(result.is_ok());
512
513 // Should return the default JSON serialization
514 if let Ok(expr) = result {
515 let expr_string = quote!(#expr).to_string();
516 // The default expression contains serde_json::to_string and unwrap_or_else
517 assert!(expr_string.contains("serde_json") || expr_string.contains("unwrap_or_else"));
518 }
519 }
520
521 #[test]
522 fn test_parse_human_readable_fn_reserved_keyword() {
523 let attrs = vec![parse_quote! {
524 #[representable(human_readable = "fn")]
525 }];
526 let struct_name = Ident::new("MyStruct", Span::call_site());
527
528 let result = parse_human_readable_fn(&attrs, &struct_name);
529 assert!(result.is_err());
530 }
531}