Skip to main content

rialo_cli_representable/
lib.rs

1// Copyright (c) Subzero Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4//! # Rialo CLI Representable Derive Macro
5//!
6//! This crate provides a procedural macro for deriving the `Representable` and `HumanReadable` traits
7//! for structs in the Rialo CLI system.
8//!
9//! ## Usage
10//!
11//! To use this macro, add the following to your struct:
12//!
13//! ```text
14//! #[derive(Representable)]
15//! #[representable(human_readable = "my_human_readable_fn")]
16//! struct MyStruct {
17//!     pub field1: String,
18//!     pub field2: u64,
19//! }
20//!
21//! fn my_human_readable_fn(data: &MyStruct) -> String {
22//!     format!("Field1: {}, Field2: {}", data.field1, data.field2)
23//! }
24//! ```
25//!
26//! ## Attributes
27//!
28//! The `#[representable]` attribute supports the following options:
29//!
30//! - `human_readable = "function_name"`: Specifies a custom function to use for human-readable output.
31//!   The function should take a reference to the struct and return a `String`.
32//!
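//! If no `human_readable` function is specified, the generated method falls back to
//! `serde_json::to_string(self)` and returns `"{}"` if serialization fails. In that case the
//! type must implement `serde::Serialize` and `serde_json` must be resolvable where the derive is used.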
34//!
35//! ## Generated Code
36//!
37//! The macro generates implementations for:
38//!
39//! - `Representable`: A marker trait for CLI-representable types
40//! - `HumanReadable`: Provides a `human_readable()` method that returns a user-friendly string representation
41//!
42//! ## Example with Custom Display Function
43//!
44//! ```text
45//! #[derive(serde::Serialize, Clone, Representable)]
46//! #[representable(human_readable = "balance_display")]
47//! pub struct BalanceResult {
48//!     pub amount: f64,
49//!     pub currency: String,
50//! }
51//!
52//! fn balance_display(result: &BalanceResult) -> String {
53//!     format!("Balance: {} {}", result.amount, result.currency)
54//! }
55//! ```
56
57use proc_macro::TokenStream;
58use quote::quote;
59use syn::{parse_macro_input, DeriveInput, Error, Ident, Meta};
60
61/// Derives the `Representable` and `HumanReadable` traits for a struct.
62///
63/// This macro automatically implements:
64/// - `Representable`: A marker trait indicating the type can be represented in CLI output
65/// - `HumanReadable`: Provides a `human_readable()` method for user-friendly string representation
66///
67/// # Attributes
68///
69/// ## `#[representable(human_readable = "function_name")]`
70///
71/// Specifies a custom function to use for human-readable output. The function should:
72/// - Take a reference to the struct (`&Self`)
73/// - Return a `String`
74/// - Be defined in the same module as the struct
75///
76/// # Examples
77///
78/// With a custom human-readable function:
79///
80/// ```text
81/// #[derive(Representable)]
82/// #[representable(human_readable = "custom_display")]
83/// struct MyData {
84///     value: u64,
85/// }
86///
87/// fn custom_display(data: &MyData) -> String {
88///     format!("Value: {}", data.value)
89/// }
90/// ```
91///
92/// Without specifying a function (defaults to JSON):
93///
94/// ```text
95/// #[derive(Representable)]
96/// struct MyData {
97///     value: u64,
98/// }
99/// // Will use serde_json::to_string(self) by default
100/// ```
101#[proc_macro_derive(Representable, attributes(representable))]
102pub fn derive_representable(input: TokenStream) -> TokenStream {
103    let input = parse_macro_input!(input as DeriveInput);
104    let name = input.ident;
105
106    match parse_human_readable_fn(&input.attrs, &name) {
107        Ok(human_readable_fn) => {
108            let expanded = quote! {
109                impl ::rialo_cli_representation::HumanReadable for #name {
110                    fn human_readable(&self) -> String {
111                        #human_readable_fn
112                    }
113                }
114
115                impl ::rialo_cli_representation::Representable for #name {}
116            };
117
118            TokenStream::from(expanded)
119        }
120        Err(err) => {
121            let error = err.to_compile_error();
122            TokenStream::from(error)
123        }
124    }
125}
126
127/// Validates a function name to ensure it's safe for code generation.
128///
129/// This function performs several security checks:
130/// - Ensures the function name is a valid Rust identifier
131/// - Prevents path traversal attempts
132/// - Checks for potentially dangerous characters
133/// - Validates length constraints
134/// - Checks against reserved keywords using syn's built-in detection
135///
136/// # Arguments
137///
138/// * `function_name` - The function name to validate
139/// * `span` - The span for error reporting
140///
141/// # Returns
142///
143/// Returns `Ok(())` if the function name is valid, or an error with details.
144fn validate_function_name(function_name: &str, span: proc_macro2::Span) -> Result<(), Error> {
145    // Check for empty or whitespace-only names
146    if function_name.trim().is_empty() {
147        return Err(Error::new(
148            span,
149            "Function name cannot be empty or whitespace-only",
150        ));
151    }
152
153    // Check length constraints (reasonable limits for function names)
154    if function_name.len() > 100 {
155        return Err(Error::new(
156            span,
157            "Function name is too long (maximum 100 characters)",
158        ));
159    }
160
161    // Check for path separators and other dangerous characters
162    let dangerous_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|', '\0'];
163    if let Some(ch) = function_name
164        .chars()
165        .find(|&c| dangerous_chars.contains(&c))
166    {
167        return Err(Error::new(
168            span,
169            format!("Function name contains invalid character '{ch}'. Function names must be valid Rust identifiers.")
170        ));
171    }
172
173    // Check for control characters
174    if function_name.chars().any(|c| c.is_control()) {
175        return Err(Error::new(
176            span,
177            "Function name contains control characters",
178        ));
179    }
180
181    // Check for valid Rust identifier start (must start with letter or underscore)
182    if let Some(first_char) = function_name.chars().next() {
183        if !first_char.is_alphabetic() && first_char != '_' {
184            return Err(Error::new(
185                span,
186                "Function name must start with a letter or underscore",
187            ));
188        }
189    }
190
191    // Check that all characters are valid for Rust identifiers
192    if !function_name
193        .chars()
194        .all(|c| c.is_alphanumeric() || c == '_')
195    {
196        return Err(Error::new(
197            span,
198            "Function name must contain only letters, numbers, and underscores",
199        ));
200    }
201
202    // Check for reserved keywords using syn's built-in detection
203    // This is more robust than maintaining a hardcoded list
204    if is_reserved_keyword(function_name) {
205        return Err(Error::new(
206            span,
207            format!("Function name '{function_name}' is a reserved Rust keyword"),
208        ));
209    }
210
211    Ok(())
212}
213
/// Checks whether a function name is on the list of rejected Rust keywords.
///
/// The comparison is an exact, case-sensitive match against a hardcoded list. The list
/// covers the strict keywords (including `true` and `false`), the keywords reserved for
/// future use (`abstract`, `become`, `typeof`, ...), and the 2018-edition keywords
/// (`async`, `await`, `dyn`, `try`). None of these can be used as a plain function name.
///
/// `union` and `macro_rules` are only contextual keywords, but they are also rejected here;
/// that keeps the behavior callers of this function have always seen.
///
/// # Arguments
///
/// * `function_name` - The function name to check
///
/// # Returns
///
/// Returns `true` if the function name is one of the listed keywords, `false` otherwise.
fn is_reserved_keyword(function_name: &str) -> bool {
    const REJECTED_KEYWORDS: &[&str] = &[
        // Items and declarations
        "fn",
        "struct",
        "enum",
        "trait",
        "impl",
        "mod",
        "use",
        "extern",
        "crate",
        "type",
        "const",
        "static",
        "let",
        "mut",
        "ref",
        "move",
        "dyn",
        "async",
        "await",
        // Control flow
        "if",
        "else",
        "match",
        "loop",
        "while",
        "for",
        "in",
        "return",
        "break",
        "continue",
        // Visibility, safety, and misc. strict keywords
        "pub",
        "unsafe",
        "where",
        "as",
        "self",
        "Self",
        "super",
        "true",
        "false",
        // Reserved for future use
        "abstract",
        "become",
        "box",
        "do",
        "final",
        "macro",
        "override",
        "priv",
        "try",
        "typeof",
        "unsized",
        "virtual",
        "yield",
        // Contextual keywords that have always been rejected by this function
        "macro_rules",
        "union",
    ];

    REJECTED_KEYWORDS.contains(&function_name)
}
294
295/// Parses the `human_readable` attribute from the struct's attributes.
296///
297/// This function extracts the function name specified in the `#[representable(human_readable = "fn_name")]`
298/// attribute, validates it for security, and generates the appropriate function call expression.
299///
300/// # Arguments
301///
302/// * `attrs` - The struct's attributes to parse
303/// * `_struct_name` - The name of the struct (currently unused but kept for future extensibility)
304///
305/// # Returns
306///
307/// A `Result<syn::Expr, Error>` containing either:
308/// - A `syn::Expr` representing the function call to the human-readable function, or
309/// - An `Error` if validation fails
310///
311/// If no function is specified, returns a default JSON serialization expression.
312///
313/// # Examples
314///
315/// For `#[representable(human_readable = "my_fn")]`, this returns `Ok(my_fn(self))`.
316/// If no attribute is found, it returns `Ok(serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string()))`.
317fn parse_human_readable_fn(
318    attrs: &[syn::Attribute],
319    _struct_name: &Ident,
320) -> Result<syn::Expr, Error> {
321    for attr in attrs {
322        if attr.path().is_ident("representable") {
323            if let Meta::List(meta_list) = &attr.meta {
324                for nested in meta_list
325                    .parse_args_with(
326                        syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
327                    )
328                    .unwrap_or_default()
329                {
330                    if let syn::Meta::NameValue(name_value) = nested {
331                        if name_value.path.is_ident("human_readable") {
332                            if let syn::Expr::Lit(syn::ExprLit {
333                                lit: syn::Lit::Str(lit_str),
334                                ..
335                            }) = &name_value.value
336                            {
337                                let function_name = &lit_str.value();
338                                let span = lit_str.span();
339
340                                // Validate the function name for security
341                                validate_function_name(function_name, span)?;
342
343                                let fn_name = Ident::new(function_name, span);
344                                return Ok(syn::parse_quote! { #fn_name(self) });
345                            }
346                        }
347                    }
348                }
349            }
350        }
351    }
352
353    // Default to JSON output if no explicit function is provided
354    Ok(syn::parse_quote! {
355        serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string())
356    })
357}
358
/// Unit tests for function-name validation and `human_readable` attribute parsing.
#[cfg(test)]
mod tests {
    use proc_macro2::Span;
    use syn::parse_quote;

    use super::*;

    /// Ordinary snake_case / camelCase / underscore-prefixed names are accepted.
    #[test]
    fn test_validate_function_name_valid() {
        let valid_names = [
            "my_function",
            "myFunction",
            "my_function_123",
            "_private_function",
            "f",
            "a1b2c3",
        ];

        for name in &valid_names {
            assert!(
                validate_function_name(name, Span::call_site()).is_ok(),
                "Function name '{name}' should be valid"
            );
        }
    }

    /// Each explicitly listed dangerous character causes rejection.
    #[test]
    fn test_validate_function_name_invalid_characters() {
        let invalid_names = [
            ("my/function", "path separator"),
            ("my\\function", "backslash"),
            ("my:function", "colon"),
            ("my*function", "asterisk"),
            ("my?function", "question mark"),
            ("my\"function", "quote"),
            ("my<function", "less than"),
            ("my>function", "greater than"),
            ("my|function", "pipe"),
            ("my\0function", "null byte"),
        ];

        for (name, description) in &invalid_names {
            let result = validate_function_name(name, Span::call_site());
            assert!(
                result.is_err(),
                "Function name '{name}' ({description}) should be invalid"
            );
        }
    }

    /// Names must start with a letter or underscore.
    #[test]
    fn test_validate_function_name_invalid_start() {
        let invalid_names = ["1function", "123function", ".function", "-function"];

        for name in &invalid_names {
            let result = validate_function_name(name, Span::call_site());
            assert!(
                result.is_err(),
                "Function name '{name}' should be invalid (invalid start)"
            );
        }
    }

    /// Keywords are rejected as function names.
    #[test]
    fn test_validate_function_name_reserved_keywords() {
        let reserved_keywords = [
            "fn", "struct", "enum", "impl", "trait", "mod", "use", "pub", "priv", "let", "mut",
            "const", "static", "if", "else", "match", "loop", "while", "for", "in", "return",
            "break", "continue", "as", "where", "unsafe", "async", "await", "dyn", "move", "ref",
            "self", "Self", "super",
        ];

        for keyword in &reserved_keywords {
            let result = validate_function_name(keyword, Span::call_site());
            assert!(
                result.is_err(),
                "Reserved keyword '{keyword}' should be invalid"
            );
        }
    }

    /// Empty, whitespace-only, and over-long names are rejected.
    #[test]
    fn test_validate_function_name_length_constraints() {
        // Test empty and whitespace-only names
        let empty_names = ["", "   ", "\t", "\n", " \t \n "];
        for name in &empty_names {
            let result = validate_function_name(name, Span::call_site());
            assert!(
                result.is_err(),
                "Empty/whitespace name '{name}' should be invalid"
            );
        }

        // Test very long names
        let long_name = "a".repeat(101);
        let result = validate_function_name(&long_name, Span::call_site());
        assert!(result.is_err(), "Very long name should be invalid");
    }

    /// Every ASCII control character (C0 range plus DEL) causes rejection.
    #[test]
    fn test_validate_function_name_control_characters() {
        let control_chars = [
            '\x00', '\x01', '\x02', '\x03', '\x04', '\x05', '\x06', '\x07', '\x08', '\x09', '\x0A',
            '\x0B', '\x0C', '\x0D', '\x0E', '\x0F', '\x10', '\x11', '\x12', '\x13', '\x14', '\x15',
            '\x16', '\x17', '\x18', '\x19', '\x1A', '\x1B', '\x1C', '\x1D', '\x1E', '\x1F', '\x7F',
        ];

        for &ch in &control_chars {
            let name = format!("my{ch}function");
            let result = validate_function_name(&name, Span::call_site());
            assert!(
                result.is_err(),
                "Function name with control character '{ch}' should be invalid"
            );
        }
    }

    /// A valid `human_readable` option yields a call to the named function.
    #[test]
    fn test_parse_human_readable_fn_valid() {
        let attrs: Vec<syn::Attribute> = vec![parse_quote! {
            #[representable(human_readable = "my_display_fn")]
        }];
        let struct_name = Ident::new("MyStruct", Span::call_site());

        // `expect` (rather than `if let Ok`) so the assertions below can never be skipped.
        let expr = parse_human_readable_fn(&attrs, &struct_name)
            .expect("a valid function name should be accepted");

        // The exact token layout is not asserted, only that the named function is called
        // and that the JSON default was not used.
        let expr_string = quote!(#expr).to_string();
        assert!(expr_string.contains("my_display_fn"));
        assert!(!expr_string.contains("serde_json"));
    }

    /// An invalid function name in the attribute is reported as an error.
    #[test]
    fn test_parse_human_readable_fn_invalid() {
        let attrs: Vec<syn::Attribute> = vec![parse_quote! {
            #[representable(human_readable = "my/function")]
        }];
        let struct_name = Ident::new("MyStruct", Span::call_site());

        let result = parse_human_readable_fn(&attrs, &struct_name);
        assert!(result.is_err());
    }

    /// Without any attribute the JSON default expression is produced.
    #[test]
    fn test_parse_human_readable_fn_no_attribute() {
        let attrs: Vec<syn::Attribute> = Vec::new();
        let struct_name = Ident::new("MyStruct", Span::call_site());

        let expr = parse_human_readable_fn(&attrs, &struct_name)
            .expect("the default expression should always be produced");

        // The default expression must contain BOTH the serializer call and the fallback;
        // requiring only one of them would let half of the default silently disappear.
        let expr_string = quote!(#expr).to_string();
        assert!(expr_string.contains("serde_json"));
        assert!(expr_string.contains("unwrap_or_else"));
    }

    /// A keyword used as the function name is reported as an error.
    #[test]
    fn test_parse_human_readable_fn_reserved_keyword() {
        let attrs: Vec<syn::Attribute> = vec![parse_quote! {
            #[representable(human_readable = "fn")]
        }];
        let struct_name = Ident::new("MyStruct", Span::call_site());

        let result = parse_human_readable_fn(&attrs, &struct_name);
        assert!(result.is_err());
    }
}