riglr_macros/
lib.rs

1/*!
2# riglr-macros
3
4Procedural macros for riglr - dramatically reducing boilerplate when creating blockchain tools.
5
6The `#[tool]` macro is the cornerstone of riglr's developer experience, transforming simple async
7functions, synchronous functions, and structs into full-featured blockchain tools with automatic error handling, JSON
8schema generation, and seamless `rig` framework integration.
9
10## Overview
11
12The `#[tool]` macro automatically implements the `Tool` trait for both async and sync functions, as well as structs,
13eliminating the need to write ~30 lines of boilerplate code per tool. It generates:
14
151. **Parameter struct** with proper JSON schema and serde annotations
162. **Tool trait implementation** with error handling and type conversion
173. **Documentation extraction** from doc comments for AI model consumption
184. **SignerContext integration** for secure blockchain operations
195. **Convenience constructors** for easy instantiation
20
21## Code Generation Process
22
23When you apply `#[tool]` to a function, the macro performs the following transformations:
24
25### 1. Parameter Extraction and Struct Generation
26
27```rust,ignore
28// Your function:
29#[tool]
30async fn swap_tokens(
31    /// Source token mint address
32    from_mint: String,
33    /// Destination token mint address
34    to_mint: String,
35    /// Amount to swap in base units
36    amount: u64,
37    /// Optional slippage tolerance (default: 0.5%)
38    #[serde(default = "default_slippage")]
39    slippage_bps: Option<u16>,
40) -> Result<String, SwapError> { ... }
41
42// Generated args struct:
43#[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema, Debug, Clone)]
44#[serde(rename_all = "camelCase")]
45pub struct SwapTokensArgs {
46    /// Source token mint address
47    pub from_mint: String,
48    /// Destination token mint address
49    pub to_mint: String,
50    /// Amount to swap in base units
51    pub amount: u64,
52    /// Optional slippage tolerance (default: 0.5%)
53    #[serde(default = "default_slippage")]
54    pub slippage_bps: Option<u16>,
55}
56```
57
58### 2. Tool Struct and Trait Implementation Generation
59
60```rust,ignore
61// Generated tool struct:
62#[derive(Clone)]
63pub struct SwapTokensTool;
64
65impl SwapTokensTool {
66    pub fn new() -> Self { Self }
67}
68
69#[async_trait::async_trait]
70impl riglr_core::Tool for SwapTokensTool {
71    async fn execute(&self, params: serde_json::Value, context: &riglr_core::provider::ApplicationContext) -> Result<riglr_core::JobResult, riglr_core::ToolError> {
72        // 1. Parse parameters with detailed error messages
73        let args: SwapTokensArgs = serde_json::from_value(params)
74            .map_err(|e| format!("Failed to parse parameters: {}", e))?;
75
76        // 2. Call your original function
77        let result = swap_tokens(args.from_mint, args.to_mint, args.amount, args.slippage_bps).await;
78
79        // 3. Convert results to standardized JobResult format
80        match result {
81            Ok(value) => Ok(riglr_core::JobResult::Success {
82                value: serde_json::to_value(value)?,
83                tx_hash: None,
84            }),
85            Err(error) => {
86                // 4. Structured error handling with retry logic
87                let tool_error: riglr_core::ToolError = error.into();
88                match tool_error {
89                    riglr_core::ToolError::Retriable(msg) => Ok(riglr_core::JobResult::Failure {
90                        error: msg,
91                        retriable: true,
92                    }),
93                    riglr_core::ToolError::Permanent(msg) => Ok(riglr_core::JobResult::Failure {
94                        error: msg,
95                        retriable: false,
96                    }),
97                    riglr_core::ToolError::RateLimited(msg) => Ok(riglr_core::JobResult::Failure {
98                        error: format!("Rate limited: {}", msg),
99                        retriable: true,
100                    }),
101                    riglr_core::ToolError::InvalidInput(msg) => Ok(riglr_core::JobResult::Failure {
102                        error: format!("Invalid input: {}", msg),
103                        retriable: false,
104                    }),
105                    riglr_core::ToolError::SignerContext(err) => Ok(riglr_core::JobResult::Failure {
106                        error: format!("Signer error: {}", err),
107                        retriable: false,
108                    }),
109                }
110            }
111        }
112    }
113
114    fn name(&self) -> &str {
115        "swap_tokens"
116    }
117}
118
119// Convenience constructor
120pub fn swap_tokens_tool() -> std::sync::Arc<dyn riglr_core::Tool> {
121    std::sync::Arc::new(SwapTokensTool::new())
122}
123```
124
125### 3. Documentation Processing and Description Attribute
126
127The macro extracts documentation from three sources and wires them into the Tool implementation:
128
129- **Function docstrings** → Tool descriptions for AI models
130- **Parameter docstrings** → JSON schema field descriptions
131- **Type annotations** → JSON schema type information
132
133You can also provide an explicit AI-facing description using the attribute:
134
135```rust,ignore
136#[tool(description = "Fetches the URL and returns the body as text.")]
137async fn fetch(url: String) -> Result<String, Error> { ... }
138```
139
140Priority logic for the generated `Tool::description()` method:
141- If `description = "..."` attribute is present, that string is used
142- Else, the item's rustdoc comments are used
143- Else, an empty string is returned
144
145This enables AI models to understand exactly what each tool does and how to use it properly.
146
147## Constraints and Requirements
148
149### Function Requirements
150
1511. **Return Type**: Must be `Result<T, E>` where `E: Into<riglr_core::ToolError>`
152   ```rust,ignore
153   // ✅ Valid - custom error type with derive
154   #[derive(Error, Debug, IntoToolError)]
155   enum MyError { NetworkError(String), InvalidInput(String) }
156   async fn valid_tool() -> Result<String, MyError> { ... }
157
158   // ❌ Invalid - not a Result
159   async fn invalid_tool() -> String { ... }
160
161   // ❌ Invalid - std::io::Error doesn't implement Into<ToolError>
162   async fn bad_error() -> Result<String, std::io::Error> { ... }
163
164   // ✅ Valid - wrap std library errors in custom types
165   #[derive(Error, Debug, IntoToolError)]
166   enum FileError {
167       #[error("IO error: {0}")]
168       Io(#[from] std::io::Error)
169   }
170   async fn good_file_tool() -> Result<String, FileError> { ... }
171   ```
172
1732. **Parameters**: All parameters must implement `serde::Deserialize + schemars::JsonSchema`
174   ```rust,ignore
175   // ✅ Valid - standard types implement these automatically
176   async fn good_params(address: String, amount: u64) -> Result<(), ToolError> { ... }
177
178   // ❌ Invalid - custom types need derives
179   struct CustomType { field: String }
180   async fn bad_params(custom: CustomType) -> Result<(), ToolError> { ... }
181   ```
182
1833. **Function Type**: The macro supports both async and synchronous functions
184   ```rust,ignore
185   // ✅ Valid - async function
186   #[tool]
187   async fn async_tool() -> Result<String, ToolError> { ... }
188
189   // ✅ Valid - sync function (executed within async context)
190   #[tool]
191   fn sync_tool() -> Result<String, ToolError> { ... }
192   ```
193
194   Synchronous functions are automatically wrapped to work within the async Tool trait.
195   They execute synchronously within the async `execute` method.
196
1974. **Documentation**: Function and parameters should have doc comments for AI consumption
198   ```rust,ignore
199   /// This description helps AI models understand the tool's purpose
200   #[tool]
201   async fn documented_tool(
202       /// This helps the AI understand this parameter
203       param: String,
204   ) -> Result<String, ToolError> { ... }
205   ```
206
207### Struct Requirements
208
209For struct-based tools, additional requirements apply:
210
2111. **Execute Method**: Must have an async `execute` method returning `Result<T, E>`
2122. **Serde Traits**: Must derive `Serialize`, `Deserialize`, and `JsonSchema`
2133. **Clone**: Must be `Clone` for multi-use scenarios
214
215```rust,ignore
216#[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema, Clone)]
217#[tool]
218struct MyStructTool {
219    config: String,
220}
221
222impl MyStructTool {
223    pub async fn execute(&self) -> Result<String, ToolError> {
224        // Implementation
225        Ok(format!("Processed: {}", self.config))
226    }
227}
228```
229
230## Complex Usage Examples
231
232### Synchronous Function Example
233
234The macro supports both async and sync functions. Sync functions are useful for
235computational tools that don't require I/O operations:
236
237```rust,ignore
238use riglr_core::ToolError;
239
240/// Calculate compound interest for a given principal, rate, and time
241///
242/// This is a computational tool that doesn't require async operations,
243/// so it's implemented as a synchronous function that runs efficiently
244/// within the async Tool framework.
245#[tool]
246fn calculate_compound_interest(
247    /// Principal amount in dollars
248    principal: f64,
249    /// Annual interest rate as a decimal (e.g., 0.05 for 5%)
250    annual_rate: f64,
251    /// Time period in years
252    years: f64,
253    /// Number of times interest is compounded per year
254    compounds_per_year: u32,
255) -> Result<f64, ToolError> {
256    if principal <= 0.0 {
257        return Err(ToolError::invalid_input_string("Principal must be positive"));
258    }
259    if annual_rate < 0.0 {
260        return Err(ToolError::invalid_input_string("Interest rate cannot be negative"));
261    }
262    if years < 0.0 {
263        return Err(ToolError::invalid_input_string("Time period cannot be negative"));
264    }
265    if compounds_per_year == 0 {
266        return Err(ToolError::invalid_input_string("Compounds per year must be at least 1"));
267    }
268
269    let rate_per_compound = annual_rate / compounds_per_year as f64;
270    let total_compounds = compounds_per_year as f64 * years;
271    let final_amount = principal * (1.0 + rate_per_compound).powf(total_compounds);
272
273    Ok(final_amount)
274}
275```
276
277#### Important Note on CPU-Intensive Sync Functions
278
279The `#[tool]` macro executes synchronous functions directly within the async executor's thread.
280This is fine for quick computations, but **CPU-intensive operations can block the async runtime**.
281
282For CPU-intensive work, wrap your function in `tokio::task::spawn_blocking` **before** applying
283the `#[tool]` macro:
284
285```rust,ignore
286use riglr_core::ToolError;
287
288/// CPU-intensive cryptographic operation
289///
290/// This uses spawn_blocking to avoid blocking the async runtime
291#[tool]
292async fn compute_hash(
293    /// Data to hash
294    data: Vec<u8>,
295    /// Number of iterations
296    iterations: u32,
297) -> Result<String, ToolError> {
298    // Move CPU-intensive work to a blocking thread pool
299    tokio::task::spawn_blocking(move || {
300        // Simulate expensive computation
301        let mut hash = data;
302        for _ in 0..iterations {
303            hash = sha256::digest(&hash).into_bytes();
304        }
305        Ok(hex::encode(hash))
306    })
307    .await
308    .map_err(|e| ToolError::permanent_string(format!("Task failed: {}", e)))?
309}
310```
311
312**Guidelines for choosing between sync and async with spawn_blocking:**
313- **Use sync functions** for quick calculations (< 1ms), simple data transformations, or validation
314- **Use async + spawn_blocking** for CPU-intensive work like cryptography, complex parsing, or heavy computation
315- **Use regular async** for I/O operations like network requests or database queries
316
317### Generic Parameters and Type Constraints
318
319```rust,ignore
320use serde::{Serialize, Deserialize};
321use schemars::JsonSchema;
322
323/// Generic tool that can process any serializable data
324#[tool]
325async fn process_data<T>(
326    /// The data to process (must be JSON-serializable)
327    data: T,
328    /// Processing options
329    options: ProcessingOptions,
330) -> Result<ProcessedData, ProcessingError>
331where
332    T: Serialize + Deserialize + JsonSchema + Send + Sync,
333{
334    // The macro handles generic constraints properly
335    let serialized = serde_json::to_string(&data)?;
336    // ... processing logic
337    Ok(ProcessedData::new(serialized))
338}
339```
340
341### SignerContext Integration
342
343Tools automatically have access to the current blockchain signer:
344
345```rust,ignore
346use riglr_core::signer::SignerContext;
347
348/// Swap tokens on Solana using Jupiter aggregator
349///
350/// This tool automatically accesses the current signer from the context,
351/// eliminating the need to pass signing credentials explicitly.
352#[tool]
353async fn jupiter_swap(
354    /// Input token mint address
355    input_mint: String,
356    /// Output token mint address
357    output_mint: String,
358    /// Amount to swap in base units
359    amount: u64,
360    /// Maximum slippage in basis points
361    max_slippage_bps: u16,
362) -> Result<String, SwapError> {
363    // Access the current signer automatically
364    let signer = SignerContext::current().await?;
365
366    // Derive RPC client from signer
367    let rpc_client = signer.rpc_client();
368
369    // Get quote from Jupiter
370    let quote = get_jupiter_quote(&input_mint, &output_mint, amount, max_slippage_bps).await?;
371
372    // Build and sign transaction
373    let tx = build_swap_transaction(quote, &signer.pubkey()).await?;
374    let signed_tx = signer.sign_transaction(tx).await?;
375
376    // Send transaction
377    let signature = rpc_client.send_and_confirm_transaction(&signed_tx).await?;
378
379    Ok(signature.to_string())
380}
381```
382
383### Multi-Chain Tool with Dynamic Signer Selection
384
385```rust,ignore
386use riglr_core::signer::{SignerContext, ChainType};
387
388/// Bridge tokens between different blockchains
389///
390/// Automatically detects the source chain from the current signer
391/// and handles cross-chain bridging operations.
392#[tool]
393async fn bridge_tokens(
394    /// Source token address
395    source_token: String,
396    /// Destination chain identifier
397    dest_chain: String,
398    /// Destination token address
399    dest_token: String,
400    /// Amount to bridge in base units
401    amount: u64,
402    /// Recipient address on destination chain
403    recipient: String,
404) -> Result<BridgeResult, BridgeError> {
405    let signer = SignerContext::current().await?;
406
407    // Dynamic chain detection
408    let bridge_operation = match signer.chain_type() {
409        ChainType::Solana => {
410            SolanaBridge::new(signer).bridge_to_evm(
411                source_token, dest_chain, dest_token, amount, recipient
412            ).await?
413        },
414        ChainType::Ethereum => {
415            EthereumBridge::new(signer).bridge_to_solana(
416                source_token, dest_token, amount, recipient
417            ).await?
418        },
419        ChainType::Polygon => {
420            PolygonBridge::new(signer).bridge_cross_chain(
421                source_token, dest_chain, dest_token, amount, recipient
422            ).await?
423        },
424        _ => return Err(BridgeError::UnsupportedChain),
425    };
426
427    Ok(bridge_operation)
428}
429```
430
431### Error Handling and Retry Logic
432
433The macro automatically integrates with riglr's structured error handling.
434
435**IMPORTANT REQUIREMENT:** The `#[tool]` macro requires that all error types implement `Into<ToolError>`.
436There is no automatic conversion for standard library error types like `std::io::Error` or `reqwest::Error`.
437You must define custom error types that provide proper classification and context.
438
439#### Recommended Pattern: Custom Error Types with `#[derive(IntoToolError)]`
440
441The required practice is to use the `IntoToolError` derive macro for automatic error handling:
442
443```rust,ignore
444use riglr_macros::IntoToolError;
445use thiserror::Error;
446
447#[derive(Error, Debug, IntoToolError)]
448enum SwapError {
449    #[error("Insufficient balance: need {required}, have {available}")]
450    InsufficientBalance { required: u64, available: u64 },
451
452    #[error("Network congestion, retry in {retry_after_seconds}s")]
453    #[tool_error(retriable)]  // Override default classification
454    NetworkCongestion { retry_after_seconds: u64 },
455
456    #[error("Slippage too high: expected {expected}%, got {actual}%")]
457    SlippageTooHigh { expected: f64, actual: f64 },
458
459    #[error("Invalid token mint: {mint}")]
460    InvalidToken { mint: String },
461}
462
463// The IntoToolError derive macro automatically generates the From<SwapError> for ToolError impl
464```
465
466See the `trybuild` tests in `riglr-macros/tests/trybuild/` for examples:
467- `pass/custom_error_into.rs` - Correct usage with custom error types
468- `fail/unconvertible_error.rs` - What happens when error types don't implement Into<ToolError>
469
470#### Alternative: Manual Implementation
471
472If you need more control, you can manually implement the conversion:
473
474```rust,ignore
475use riglr_core::ToolError;
476
477impl From<SwapError> for ToolError {
478    fn from(error: SwapError) -> Self {
479        match error {
480            SwapError::NetworkCongestion { .. } => ToolError::Retriable(error.to_string()),
481            SwapError::InsufficientBalance { .. } => ToolError::Permanent(error.to_string()),
482            SwapError::SlippageTooHigh { .. } => ToolError::Permanent(error.to_string()),
483            SwapError::InvalidToken { .. } => ToolError::Permanent(error.to_string()),
484        }
485    }
486}
487
488/// Advanced token swap with detailed error handling
489#[tool]
490async fn advanced_swap(
491    input_mint: String,
492    output_mint: String,
493    amount: u64,
494) -> Result<SwapResult, SwapError> {
495    let signer = SignerContext::current().await?;
496
497    // Check balance first
498    let balance = get_token_balance(&signer, &input_mint).await?;
499    if balance < amount {
500        return Err(SwapError::InsufficientBalance {
501            required: amount,
502            available: balance,
503        });
504    }
505
506    // Attempt swap with retries for transient failures
507    match attempt_swap(&signer, &input_mint, &output_mint, amount).await {
508        Err(SwapError::NetworkCongestion { .. }) => {
509            // The macro will automatically mark this as retriable
510            Err(SwapError::NetworkCongestion { retry_after_seconds: 10 })
511        },
512        result => result,
513    }
514}
515```
516
517### Testing Tool Implementations
518
519The macro-generated code is designed to be easily testable:
520
521```rust,ignore
522#[cfg(test)]
523mod tests {
524    use super::*;
525    use riglr_core::signer::{MockSigner, SignerContext};
526    use serde_json::json;
527
528    #[tokio::test]
529    async fn test_swap_tool_execution() {
530        // Create mock signer with expected behavior
531        let mock_signer = MockSigner::new()
532            .with_token_balance("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", 1000000)  // USDC
533            .expect_transaction("swap")
534            .returns_signature("5j7s2Hz2UnknownTxHash");
535
536        // Test the generated tool
537        let tool = SwapTokensTool::new();
538
539        let result = SignerContext::new(&mock_signer).execute(async {
540            tool.execute(json!({
541                "fromMint": "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",
542                "toMint": "So11111111111111111111111111111111111111112",
543                "amount": 1000000,
544                "slippageBps": 50
545            })).await
546        }).await;
547
548        assert!(result.is_ok());
549        mock_signer.verify_all_expectations();
550    }
551}
552```
553
554## Best Practices
555
556### 1. Parameter Design
557- Use descriptive parameter names that clearly indicate their purpose
558- Provide comprehensive doc comments for each parameter
559- Use appropriate default values with `#[serde(default)]` where applicable
560- Group related parameters into structs for complex operations
561
562### 2. Error Handling
563- Define custom error types that implement `Into<ToolError>`
564- Use structured errors that provide actionable information
565- Distinguish between retriable and permanent errors appropriately
566- Include relevant context in error messages
567
568### 3. Documentation
569- Write clear, concise function descriptions that explain the tool's purpose
570- Document any side effects or state changes
571- Include examples in doc comments where helpful
572- Explain any complex parameters or return values
573
574### 4. Performance Considerations
575- Use `Arc<dyn Tool>` for tools that will be shared across threads
576- Implement `Clone` efficiently for struct-based tools
577- Consider caching for expensive operations that don't change frequently
578- Use appropriate timeouts for network operations
579
580### 5. Security and Business Logic Validation
581
582**⚠️ IMPORTANT SECURITY NOTE:** While the `#[tool]` macro and `serde` automatically handle parameter *format* validation (JSON schema, type conversion, required fields), your tool implementation is still responsible for all *business logic* validation and security checks.
583
584#### Critical Business Logic Validations:
585
586**Financial Operations:**
587```rust,ignore
588#[tool]
589async fn transfer_tokens(
590    to_address: String,
591    amount: f64,
592    slippage_percent: f64,
593) -> Result<String, ToolError> {
594    // ✅ Business logic validation (your responsibility)
595    if amount <= 0.0 {
596        return Err(ToolError::invalid_input_string(
597            "Transfer amount must be positive"
598        ));
599    }
600
601    if slippage_percent >= 5.0 {
602        return Err(ToolError::invalid_input_string(
603            "Slippage tolerance too high (max 5%). Consider if this is intentional"
604        ));
605    }
606
607    // ✅ Address validation
608    if !is_valid_address(&to_address) {
609        return Err(ToolError::invalid_input_string(
610            "Invalid recipient address format"
611        ));
612    }
613
614    // ✅ Balance check before executing
615    let balance = get_current_balance().await?;
616    if balance < amount {
617        return Err(ToolError::permanent_string(
618            format!("Insufficient balance: {} < {}", balance, amount)
619        ));
620    }
621
622    // Proceed with transfer...
623}
624```
625
626**Smart Contract Interactions:**
627```rust,ignore
628#[tool]
629async fn execute_contract_call(
630    contract_address: String,
631    function_name: String,
632    gas_limit: u64,
633) -> Result<String, ToolError> {
634    // ✅ Contract address validation
635    if !is_trusted_contract(&contract_address) {
636        return Err(ToolError::permanent_string(
637            "Contract not in approved whitelist"
638        ));
639    }
640
641    // ✅ Re-entrancy protection
642    if is_contract_execution_in_progress(&contract_address) {
643        return Err(ToolError::retriable_string(
644            "Contract execution already in progress, avoiding re-entrancy"
645        ));
646    }
647
648    // ✅ Gas limit safety check
649    if gas_limit > MAX_SAFE_GAS_LIMIT {
650        return Err(ToolError::invalid_input_string(
651            "Gas limit exceeds safety threshold"
652        ));
653    }
654
655    // Proceed with contract call...
656}
657```
658
659**Data Integrity Checks:**
660```rust,ignore
661#[tool]
662async fn process_transaction_data(
663    tx_hash: String,
664    expected_amount: f64,
665) -> Result<TransactionResult, ToolError> {
666    // ✅ Transaction hash format validation
667    if tx_hash.len() != 64 || !tx_hash.chars().all(|c| c.is_ascii_hexdigit()) {
668        return Err(ToolError::invalid_input_string(
669            "Invalid transaction hash format"
670        ));
671    }
672
673    // ✅ Cross-reference with external data
674    let actual_amount = fetch_transaction_amount(&tx_hash).await?;
675    if (actual_amount - expected_amount).abs() > 0.001 {
676        return Err(ToolError::permanent_string(
677            "Transaction amount mismatch detected"
678        ));
679    }
680
681    // Proceed with processing...
682}
683```
684
685#### Remember: The Macro Handles Format, You Handle Business Logic
686- **Macro + Serde**: Validates JSON structure, types, required fields
687- **Your Code**: Validates ranges, business rules, security constraints, data relationships
688
689## Macro Limitations
690
691### Current Limitations
692
6931. **Generic Functions**: Limited support for complex generic constraints
6942. **Lifetime Parameters**: Not currently supported in tool functions
6953. **Associated Types**: Cannot use associated types in parameters
6964. **Const Generics**: No support for const generic parameters
697
698### Workarounds
699
700For complex generic scenarios, consider using trait objects or type erasure:
701
702```rust,ignore
703// Instead of:
704// #[tool]
705// async fn complex_generic<T: ComplexTrait>(data: T) -> Result<(), Error> { ... }
706
707// Use:
708#[tool]
709async fn process_complex_data(
710    /// JSON representation of the data to process
711    data: serde_json::Value,
712) -> Result<ProcessedResult, ProcessError> {
713    // Deserialize to specific types inside the function
714    let typed_data: MyType = serde_json::from_value(data)?;
715    // ... process typed_data
716}
717```
718
719## Integration with External Crates
720
721The macro is designed to work seamlessly with the broader Rust ecosystem:
722
723### Serde Integration
724- Automatic `#[serde(rename_all = "camelCase")]` for JavaScript compatibility
725- Support for all serde attributes on parameters
726- Custom serialization/deserialization via serde derives
727
728### JSON Schema Generation
729- Automatic schema generation via `schemars` crate
730- Support for complex nested types and enums
731- Custom schema attributes for fine-tuned control
732
733### Async Runtime Compatibility
734- Works with any async runtime (tokio, async-std, etc.)
735- Proper handling of async trait implementations
736- Support for async error handling patterns
737
738The `#[tool]` macro transforms riglr from a collection of utilities into a cohesive,
739developer-friendly framework for building sophisticated blockchain AI agents.
740*/
741
742use heck::ToPascalCase;
743use proc_macro::TokenStream;
744use quote::quote;
745use syn::{
746    parse::Parse, parse::ParseStream, parse_macro_input, Attribute, DeriveInput, FnArg, ItemFn,
747    ItemStruct, LitStr, PatType, Token,
748};
749
750/// The `#[tool]` procedural macro that converts functions and structs into Tool implementations.
751///
752/// This macro supports:
753/// - Async functions with arbitrary parameters and Result return types
754/// - Structs that have an `execute` method
755/// - Automatic JSON schema generation using `schemars`
756/// - Documentation extraction from doc comments
757/// - Parameter descriptions from doc comments on function arguments
758///
759/// Attributes supported:
760/// - description = "..."
761#[proc_macro_attribute]
762pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
763    let input = item.clone();
764
765    let tool_attrs = match syn::parse::<ToolAttr>(attr) {
766        Ok(attrs) => attrs,
767        Err(_) => ToolAttr { description: None },
768    };
769
770    // Try to parse as function first, then as struct
771    if let Ok(function) = syn::parse::<ItemFn>(input.clone()) {
772        handle_function(function, tool_attrs).into()
773    } else if let Ok(structure) = syn::parse::<ItemStruct>(input) {
774        handle_struct(structure, tool_attrs).into()
775    } else {
776        syn::Error::new_spanned(
777            proc_macro2::TokenStream::from(item),
778            "#[tool] can only be applied to async functions or structs.\n\
779            For functions: Must be async and return Result<T, E> where E: Into<ToolError>\n\
780            For structs: Must implement Clone, Serialize, Deserialize, JsonSchema and have an async execute(&self) method",
781        )
782        .to_compile_error()
783        .into()
784    }
785}
786
787#[derive(Default, Debug)]
788struct ToolAttr {
789    description: Option<String>,
790}
791
792impl Parse for ToolAttr {
793    fn parse(input: ParseStream) -> syn::Result<Self> {
794        if input.is_empty() {
795            return Ok(Self::default());
796        }
797
798        let lookahead = input.lookahead1();
799        if lookahead.peek(syn::Ident) {
800            let ident: syn::Ident = input.parse()?;
801            if ident == "description" {
802                input.parse::<Token![=]>()?;
803                let lit: LitStr = input.parse()?;
804                return Ok(Self {
805                    description: Some(lit.value()),
806                });
807            } else {
808                return Err(syn::Error::new_spanned(
809                    ident,
810                    "Unknown attribute key. Supported: description",
811                ));
812            }
813        }
814
815        Err(syn::Error::new(
816            input.span(),
817            "Expected attribute key like: description = \"...\"",
818        ))
819    }
820}
821
822/// Helper function to check if a parameter is a context parameter (by type)
823fn is_context_param(param_type: &syn::Type) -> bool {
824    // Check if the type is &ApplicationContext or &riglr_core::provider::ApplicationContext
825    if let syn::Type::Reference(type_ref) = param_type {
826        if let syn::Type::Path(type_path) = &*type_ref.elem {
827            let path_str = type_path
828                .path
829                .segments
830                .iter()
831                .map(|segment| segment.ident.to_string())
832                .collect::<Vec<_>>()
833                .join("::");
834
835            return path_str == "ApplicationContext"
836                || path_str == "riglr_core::provider::ApplicationContext"
837                || path_str.ends_with("::ApplicationContext");
838        }
839    }
840    false
841}
842
843/// Check if a type is Result<T, E>
844fn is_result_type(ty: &syn::Type) -> bool {
845    if let syn::Type::Path(type_path) = ty {
846        if let Some(segment) = type_path.path.segments.last() {
847            let segment_name = segment.ident.to_string();
848            return segment_name == "Result"
849                && matches!(segment.arguments, syn::PathArguments::AngleBracketed(_));
850        }
851    }
852    false
853}
854
855/// Check if a type is likely serializable
856fn is_serializable_type(ty: &syn::Type) -> bool {
857    match ty {
858        // Basic serializable types
859        syn::Type::Path(type_path) => {
860            if let Some(segment) = type_path.path.segments.last() {
861                let segment_name = segment.ident.to_string();
862                match segment_name.as_str() {
863                    // Primitive types
864                    "String" | "str" | "bool" | "i8" | "i16" | "i32" | "i64" | "i128" | "isize"
865                    | "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "f32" | "f64" | "char" => {
866                        true
867                    }
868
869                    // Common generic types that are serializable
870                    "Vec" | "Option" | "HashMap" | "BTreeMap" | "HashSet" | "BTreeSet"
871                    | "VecDeque" => true,
872
873                    // Common time types
874                    "SystemTime" | "Duration" => true,
875
876                    // Assume custom types are serializable (user responsibility)
877                    _ => true,
878                }
879            } else {
880                false
881            }
882        }
883
884        // References to serializable types
885        syn::Type::Reference(type_ref) => is_serializable_type(&type_ref.elem),
886
887        // Arrays are serializable if their element type is
888        syn::Type::Array(type_array) => is_serializable_type(&type_array.elem),
889
890        // Slices are serializable if their element type is
891        syn::Type::Slice(type_slice) => is_serializable_type(&type_slice.elem),
892
893        // Tuples are serializable if all elements are
894        syn::Type::Tuple(type_tuple) => type_tuple.elems.iter().all(is_serializable_type),
895
896        // Other types - be conservative and reject
897        _ => false,
898    }
899}
900
901/// Extract the error type from a Result<T, E> return type
902fn extract_error_type(return_type: &syn::ReturnType) -> Option<syn::Type> {
903    if let syn::ReturnType::Type(_, ty) = return_type {
904        if let syn::Type::Path(type_path) = ty.as_ref() {
905            if let Some(segment) = type_path.path.segments.last() {
906                if segment.ident == "Result" {
907                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
908                        // Result<T, E> - get the second type argument (E)
909                        if args.args.len() == 2 {
910                            if let syn::GenericArgument::Type(error_type) = &args.args[1] {
911                                return Some(error_type.clone());
912                            }
913                        }
914                    }
915                }
916            }
917        }
918    }
919    None
920}
921
922/// Generate error conversion code based on the error type
923fn generate_error_conversion(error_type: &Option<syn::Type>) -> proc_macro2::TokenStream {
924    let Some(_err_type) = error_type else {
925        // No error type specified - use standard Into conversion
926        // This relies on the user's type implementing Into<ToolError>
927        return quote! { error.into() };
928    };
929
930    // REFACTORED: Strict error handling - only use Into<ToolError> trait
931    //
932    // ERROR CONVERSION LOGIC:
933    // All error types must implement Into<ToolError> to be used with the #[tool] macro.
934    // This enforces a consistent error handling pattern and encourages users to:
935    //
936    // 1. Define custom error enums with #[derive(IntoToolError)] for automatic conversion
937    // 2. Manually implement From<MyError> for ToolError for fine-grained control
938    // 3. Wrap standard library errors in custom types that provide better context
939    //
940    // This removes the special-case handling for std::io::Error and reqwest::Error
941    // that previously existed as fallbacks. Users must now explicitly handle these
942    // error types by wrapping them in custom error enums.
943    //
944    // IMPORTANT: If your function returns Result<T, std::io::Error> or similar,
945    // the compilation will now fail with a clear error message directing you to
946    // implement Into<ToolError> for your error type.
947
948    // Use Into<ToolError> conversion for all error types
949    // If the error type doesn't implement Into<ToolError>, this will produce
950    // a compile error with a clear message about the missing trait implementation
951    quote! { error.into() }
952}
953
954fn handle_function(function: ItemFn, tool_attrs: ToolAttr) -> proc_macro2::TokenStream {
955    let fn_name = &function.sig.ident;
956    let fn_vis = &function.vis;
957
958    // Extract documentation from function
959    let description = extract_doc_comments(&function.attrs);
960    let selected_description = match tool_attrs.description {
961        Some(desc) => desc,
962        None => description,
963    };
964
965    // Partition parameters into user_params and context_params
966    let mut user_params = Vec::new();
967    let mut context_params = Vec::new();
968
969    for input in function.sig.inputs.iter() {
970        if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = input {
971            if is_context_param(ty) {
972                context_params.push((pat, ty, attrs));
973            } else {
974                user_params.push((pat, ty, attrs));
975            }
976        }
977    }
978
979    // Validate exactly one context parameter
980    if context_params.len() != 1 {
981        return syn::Error::new_spanned(
982            &function.sig,
983            "`#[tool]` functions must have exactly one parameter of type `&ApplicationContext`",
984        )
985        .to_compile_error();
986    }
987
988    // Validate function signature requirements
989    if function.sig.asyncness.is_none() {
990        return syn::Error::new_spanned(&function.sig, "`#[tool]` functions must be async")
991            .to_compile_error();
992    }
993
994    // Validate return type is Result
995    if let syn::ReturnType::Type(_, ty) = &function.sig.output {
996        if !is_result_type(ty) {
997            return syn::Error::new_spanned(
998                ty,
999                "`#[tool]` functions must return a Result<T, E> where T is serializable and E implements Into<ToolError>"
1000            ).to_compile_error();
1001        }
1002    } else {
1003        return syn::Error::new_spanned(
1004            &function.sig,
1005            "`#[tool]` functions must return a Result<T, E>",
1006        )
1007        .to_compile_error();
1008    }
1009
1010    // Validate parameter types are serializable
1011    for (pat, ty, _) in user_params.iter() {
1012        if let syn::Pat::Ident(ident) = pat.as_ref() {
1013            let param_name = &ident.ident;
1014            if !is_serializable_type(ty) {
1015                return syn::Error::new_spanned(
1016                    ty,
1017                    format!(
1018                        "Parameter '{}' must be a serializable type. Consider using String, numbers, bool, Vec<T>, Option<T>, or custom types that implement Serialize/Deserialize",
1019                        param_name
1020                    )
1021                ).to_compile_error();
1022            }
1023        }
1024    }
1025
1026    // Build Args struct from user_params only
1027    let mut param_fields = Vec::new();
1028    let mut param_names = Vec::new();
1029    let mut param_docs = Vec::new();
1030
1031    for (pat, ty, attrs) in user_params.iter() {
1032        if let syn::Pat::Ident(ident) = pat.as_ref() {
1033            let param_name = &ident.ident;
1034            let param_type = ty.as_ref();
1035            let param_doc = extract_doc_comments(attrs);
1036
1037            param_names.push(param_name.clone());
1038            param_docs.push(param_doc.clone());
1039
1040            // Add documentation for the field
1041            let doc_attr = if param_doc.is_empty() {
1042                quote! { #[doc = "Parameter"] }
1043            } else {
1044                quote! { #[doc = #param_doc] }
1045            };
1046
1047            // Filter out any attributes that might cause issues
1048            // Only keep serde-related attributes
1049            let filtered_attrs: Vec<_> = attrs
1050                .iter()
1051                .filter(|attr| {
1052                    if let Some(ident) = attr.path().get_ident() {
1053                        let name = ident.to_string();
1054                        name == "serde" || name == "schemars"
1055                    } else {
1056                        false
1057                    }
1058                })
1059                .collect();
1060
1061            param_fields.push(quote! {
1062                #doc_attr
1063                #(#filtered_attrs)*
1064                pub #param_name: #param_type
1065            });
1066        }
1067    }
1068
1069    // Generate the struct names
1070    let tool_struct_name = syn::Ident::new(
1071        &format!("{}Tool", fn_name.to_string().to_pascal_case()),
1072        fn_name.span(),
1073    );
1074    let _args_struct_name = syn::Ident::new(&format!("{}Args", tool_struct_name), fn_name.span());
1075    let tool_fn_name = syn::Ident::new(&format!("{}_tool", fn_name), fn_name.span());
1076
1077    // Check if function is async
1078    let is_async = function.sig.asyncness.is_some();
1079    let await_token = if is_async {
1080        quote! { .await }
1081    } else {
1082        quote! {}
1083    };
1084
1085    // Build the call arguments list for the function call
1086    let mut call_args = quote! {};
1087    for input in function.sig.inputs.iter() {
1088        if let FnArg::Typed(PatType { pat, ty, .. }) = input {
1089            if is_context_param(ty) {
1090                // If it's the context param, pass the context from the execute signature
1091                call_args.extend(quote! { context, });
1092            } else if let syn::Pat::Ident(ident) = pat.as_ref() {
1093                // If it's a user param, pass it from the deserialized 'args' struct
1094                let param_name = &ident.ident;
1095                call_args.extend(quote! { args.#param_name.clone(), });
1096            }
1097        }
1098    }
1099
1100    // Generate a unique module name to avoid namespace collisions
1101    // Prefix with __riglr_tool_ to make it highly unlikely to collide with user code
1102    let module_name = syn::Ident::new(&format!("__riglr_tool_{}", fn_name), fn_name.span());
1103
1104    // Extract the error type from the function's Result<T, E>
1105    let error_type = extract_error_type(&function.sig.output);
1106
1107    // Generate the error conversion code based on the error type
1108    let error_conversion = generate_error_conversion(&error_type);
1109
1110    // Generate the error handling match arms
1111    let error_match_arms = generate_tool_error_match_arms();
1112
1113    // Generate the tool implementation with namespace
1114    quote! {
1115        // Keep the original function
1116        #function
1117
1118        // Generate all tool-related code in a module namespace
1119        #[doc = "Generated tool module containing implementation details"]
1120        #fn_vis mod #module_name {
1121            use super::*;
1122
1123            // Generate the args struct if there are parameters
1124            #[doc = "Arguments structure for the tool"]
1125            #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema, Debug, Clone)]
1126            pub struct Args {
1127                #(#param_fields),*
1128            }
1129
1130            // Generate the tool struct
1131            #[doc = "Tool implementation structure"]
1132            #[derive(Clone)]
1133            pub struct Tool;
1134
1135            impl Tool {
1136                /// Create a new instance of this tool
1137                pub fn new() -> Self {
1138                    Self
1139                }
1140            }
1141
1142            impl Default for Tool {
1143                fn default() -> Self {
1144                    Self::new()
1145                }
1146            }
1147
1148            // Implement the riglr_core::Tool trait
1149            #[async_trait::async_trait]
1150            impl riglr_core::Tool for Tool {
1151                /// Execute the tool with the provided parameters
1152                async fn execute(&self, params: serde_json::Value, context: &riglr_core::provider::ApplicationContext) -> Result<riglr_core::JobResult, riglr_core::ToolError> {
1153                    // Parse the parameters; convert parse errors to ToolError::InvalidInput
1154                    let args: Args = match serde_json::from_value(params) {
1155                        Ok(v) => v,
1156                        Err(e) => {
1157                            // Convert parameter parsing error to ToolError and use standard error handling
1158                            let tool_error = riglr_core::ToolError::invalid_input_with_source(
1159                                e,
1160                                "Failed to parse tool parameters"
1161                            );
1162                            return match tool_error {
1163                                #error_match_arms
1164                            };
1165                        }
1166                    };
1167
1168                    // Call the original function with reconstructed arguments
1169                    let result = super::#fn_name(#call_args)#await_token;
1170
1171                    // Convert the result to JobResult
1172                    match result {
1173                        Ok(value) => {
1174                            let json_value = serde_json::to_value(value)
1175                                .map_err(|e| riglr_core::ToolError::permanent_with_source(e, "Failed to serialize result"))?;
1176                            Ok(riglr_core::JobResult::Success {
1177                                value: json_value,
1178                                tx_hash: None,
1179                            })
1180                        }
1181                        Err(error) => {
1182                            // Convert error to ToolError using automatic mapping for known types
1183                            let tool_error: riglr_core::ToolError = #error_conversion;
1184                            match tool_error {
1185                                #error_match_arms
1186                            }
1187                        }
1188                    }
1189                }
1190
1191                fn name(&self) -> &str {
1192                    stringify!(#fn_name)
1193                }
1194
1195                fn description(&self) -> &str {
1196                    #selected_description
1197                }
1198            }
1199
1200            impl Tool {
1201                /// Get the JSON schema for this tool's parameters
1202                fn schema(&self) -> serde_json::Value {
1203                    // Generate the schema for the Args struct
1204                    let schema = schemars::schema_for!(Args);
1205                    serde_json::to_value(schema).unwrap_or_else(|_| {
1206                        // Fallback to a generic object schema if serialization fails
1207                        serde_json::json!({
1208                            "type": "object",
1209                            "additionalProperties": true
1210                        })
1211                    })
1212                }
1213            }
1214
1215            // NOTE: rig::tool::Tool compatibility is handled by RigToolAdapter in riglr-agents
1216            // The adapter pattern allows us to bridge the incompatible interfaces
1217        }
1218
1219        // Create a convenience function to create an Arc<dyn Tool> using the namespaced type
1220        /// Factory function to create a new instance of the tool
1221        #fn_vis fn #tool_fn_name() -> std::sync::Arc<dyn riglr_core::Tool> {
1222            std::sync::Arc::new(#module_name::Tool::new())
1223        }
1224    }
1225}
1226
1227fn handle_struct(structure: ItemStruct, tool_attrs: ToolAttr) -> proc_macro2::TokenStream {
1228    let struct_name = &structure.ident;
1229    let struct_vis = &structure.vis;
1230
1231    // Validate that the struct meets requirements for #[tool]
1232    // Note: We can't easily validate that the struct has an execute() method at macro time
1233    // because the impl block might be defined elsewhere. Instead, we'll generate a
1234    // compile-time assertion that will fail if the method doesn't exist.
1235
1236    // Extract documentation from struct
1237    let description = extract_doc_comments(&structure.attrs);
1238    let selected_description = match tool_attrs.description {
1239        Some(desc) => desc,
1240        None => description,
1241    };
1242
1243    // Generate the error handling match arms
1244    let error_match_arms = generate_tool_error_match_arms();
1245
1246    // Generate a compile-time check for required traits
1247    let compile_time_checks = quote! {
1248        // This constant will fail to compile if the struct doesn't have the required traits
1249        const _: () = {
1250            fn assert_has_required_traits<T>()
1251            where
1252                T: Clone + serde::Serialize + serde::de::DeserializeOwned + schemars::JsonSchema,
1253            {}
1254
1255            // This will be checked when the Tool trait is implemented
1256            fn _check() {
1257                assert_has_required_traits::<#struct_name>();
1258            }
1259        };
1260    };
1261
1262    quote! {
1263        // Keep the original struct
1264        #structure
1265
1266        // Compile-time validation
1267        #compile_time_checks
1268
1269        // Implement the Tool trait
1270        #[async_trait::async_trait]
1271        impl riglr_core::Tool for #struct_name {
1272            async fn execute(&self, params: serde_json::Value, context: &riglr_core::provider::ApplicationContext) -> Result<riglr_core::JobResult, riglr_core::ToolError> {
1273                // Parse parameters into the struct; convert parse errors to ToolError::InvalidInput
1274                let args: Self = match serde_json::from_value(params) {
1275                    Ok(v) => v,
1276                    Err(e) => {
1277                        // Convert parameter parsing error to ToolError and use standard error handling
1278                        let tool_error = riglr_core::ToolError::invalid_input_with_source(
1279                            e,
1280                            "Failed to parse tool parameters"
1281                        );
1282                        return match tool_error {
1283                            #error_match_arms
1284                        };
1285                    }
1286                };
1287
1288                // Call the execute method (expecting Result<T, ToolError>)
1289                // IMPORTANT: This will fail at compile time if the struct doesn't have an execute() method
1290                // The struct must have: pub async fn execute(&self) -> Result<T, E>
1291                // where T: Serialize and E: Into<ToolError>
1292                let result = args.execute().await;
1293
1294                // Convert the result to JobResult
1295                match result {
1296                    Ok(value) => {
1297                        let json_value = serde_json::to_value(value)
1298                            .map_err(|e| riglr_core::ToolError::permanent_with_source(e, "Failed to serialize result"))?;
1299                        Ok(riglr_core::JobResult::Success {
1300                            value: json_value,
1301                            tx_hash: None,
1302                        })
1303                    }
1304                    Err(tool_error) => {
1305                        // Convert any error to ToolError, then match on it
1306                        let tool_error: riglr_core::ToolError = tool_error.into();
1307                        match tool_error {
1308                            #error_match_arms
1309                        }
1310                    }
1311                }
1312            }
1313
1314            fn name(&self) -> &str {
1315                stringify!(#struct_name)
1316            }
1317
1318            fn description(&self) -> &str {
1319                #selected_description
1320            }
1321        }
1322
1323        /// Get the JSON schema for this tool's parameters
1324        fn schema(&self) -> serde_json::Value {
1325            // Generate the schema for the struct itself
1326            let schema = schemars::schema_for!(#struct_name);
1327            serde_json::to_value(schema).unwrap_or_else(|_| {
1328                // Fallback to a generic object schema if serialization fails
1329                serde_json::json!({
1330                    "type": "object",
1331                    "additionalProperties": true
1332                })
1333            })
1334        }
1335
1336        // NOTE: rig::tool::Tool compatibility is handled by RigToolAdapter in riglr-agents
1337        // The adapter pattern allows us to bridge the incompatible interfaces
1338
1339        // Convenience function to create the tool
1340        impl #struct_name {
1341            #struct_vis fn as_tool(self) -> std::sync::Arc<dyn riglr_core::Tool> {
1342                std::sync::Arc::new(self)
1343            }
1344        }
1345
1346    }
1347}
1348
1349fn extract_doc_comments(attrs: &[Attribute]) -> String {
1350    let mut docs = Vec::new();
1351
1352    for attr in attrs {
1353        if attr.path().is_ident("doc") {
1354            if let syn::Meta::NameValue(meta) = &attr.meta {
1355                if let syn::Expr::Lit(syn::ExprLit {
1356                    lit: syn::Lit::Str(lit_str),
1357                    ..
1358                }) = &meta.value
1359                {
1360                    let line = lit_str.value();
1361                    // Remove leading space if present (rustdoc convention)
1362                    let line = line.strip_prefix(' ').unwrap_or(&line);
1363                    docs.push(line.to_string());
1364                }
1365            }
1366        }
1367    }
1368
1369    docs.join("\n").trim().to_string()
1370}
1371
1372/// Generates the common error handling match arms for ToolError to JobResult conversion
1373fn generate_tool_error_match_arms() -> proc_macro2::TokenStream {
1374    quote! {
1375        // With the new structure, we just wrap the ToolError directly
1376        // The JobResult::Failure variant now contains the full ToolError
1377        _ => Ok(riglr_core::JobResult::Failure {
1378            error: tool_error,
1379        })
1380    }
1381}
1382
1383/// Derives automatic conversion from an error enum to ToolError.
1384///
1385/// This macro generates a `From<YourError> for ToolError` implementation
1386/// that automatically classifies errors as retriable or permanent based on
1387/// naming conventions in variant names.
1388///
1389/// # Classification Rules
1390///
1391/// Errors are classified as **retriable** if their variant names contain:
1392/// - `Rpc`, `Network`, `Connection`, `Timeout`, `TooManyRequests`, `RateLimit`
1393/// - `Api` (for external API errors)
1394/// - `Http` (for HTTP-related errors)
1395///
1396/// Errors are classified as **permanent** if their variant names contain:
1397/// - `Invalid`, `Parse`, `Serialization`, `NotFound`, `Unauthorized`
1398/// - `InsufficientBalance`, `InsufficientFunds`
1399/// - All other unmatched variants (conservative default)
1400///
1401/// # Best Practices
1402///
1403/// **This derive macro is the recommended way to handle custom errors for tools.** It provides:
1404/// - Automatic error classification based on variant names
1405/// - Override capabilities for fine-grained control
1406/// - Type-safe error handling
1407/// - Consistent error conversion across your codebase
1408///
1409/// Using this macro instead of string-based error handling ensures that your errors are properly
1410/// structured and can be downcast by upstream consumers for specific error handling logic.
1411///
1412/// # Custom Classification
1413///
1414/// You can override the automatic classification using attributes:
1415///
1416/// ```rust,ignore
1417/// #[derive(IntoToolError)]
1418/// enum MyError {
1419///     #[tool_error(retriable)]
1420///     CustomError(String),
1421///
1422///     #[tool_error(permanent)]
1423///     NetworkError(String), // Override default retriable classification
1424///
1425///     #[tool_error(rate_limited)]
1426///     ApiQuotaExceeded,
1427/// }
1428/// ```
1429///
1430/// # Examples
1431///
1432/// ## Recommended Usage with thiserror
1433///
1434/// ```rust,ignore
1435/// use riglr_macros::IntoToolError;
1436/// use thiserror::Error;
1437///
1438/// #[derive(Error, Debug, IntoToolError)]
1439/// enum SolanaError {
1440///     #[error("RPC error: {0}")]
1441///     RpcError(String),  // Automatically retriable
1442///
1443///     #[error("Invalid address: {0}")]
1444///     InvalidAddress(String),  // Automatically permanent
1445///
1446///     #[error("Network timeout")]
1447///     NetworkTimeout,  // Automatically retriable
1448///
1449///     #[error("Insufficient balance")]
1450///     InsufficientBalance,  // Automatically permanent
1451///
1452///     #[tool_error(retriable)]
1453///     #[error("Custom error: {0}")]
1454///     CustomError(String),  // Explicitly retriable
1455/// }
1456/// ```
1457#[proc_macro_derive(IntoToolError, attributes(tool_error))]
1458pub fn derive_into_tool_error(input: TokenStream) -> TokenStream {
1459    let input = parse_macro_input!(input as DeriveInput);
1460
1461    let name = input.ident;
1462    let variants = match input.data {
1463        syn::Data::Enum(ref data) => &data.variants,
1464        _ => {
1465            return TokenStream::from(quote! {
1466                compile_error!("IntoToolError can only be derived for enums");
1467            });
1468        }
1469    };
1470
1471    let match_arms = variants.iter().map(|variant| {
1472        let variant_name = &variant.ident;
1473        let variant_name_str = variant_name.to_string();
1474
1475        // Check for explicit classification attribute
1476        let classification = variant.attrs.iter().find_map(|attr| {
1477            if attr.path().is_ident("tool_error") {
1478                attr.parse_args::<syn::Ident>().ok()
1479            } else {
1480                None
1481            }
1482        });
1483
1484        let pattern = match &variant.fields {
1485            syn::Fields::Named(_) => quote! { #name::#variant_name { .. } },
1486            syn::Fields::Unnamed(_) => quote! { #name::#variant_name(..) },
1487            syn::Fields::Unit => quote! { #name::#variant_name },
1488        };
1489
1490        let conversion = if let Some(class) = classification {
1491            match class.to_string().as_str() {
1492                "retriable" => quote! {
1493                    riglr_core::ToolError::retriable_string(err.to_string())
1494                },
1495                "permanent" => quote! {
1496                    riglr_core::ToolError::permanent_string(err.to_string())
1497                },
1498                "rate_limited" => quote! {
1499                    riglr_core::ToolError::rate_limited_string(err.to_string())
1500                },
1501                _ => quote! {
1502                    riglr_core::ToolError::permanent_string(err.to_string())
1503                },
1504            }
1505        } else {
1506            // Automatic classification based on naming conventions
1507            let retriable_patterns = [
1508                "Rpc",
1509                "Network",
1510                "Connection",
1511                "Timeout",
1512                "TooManyRequests",
1513                "RateLimit",
1514                "Api",
1515                "Http",
1516            ];
1517
1518            let is_retriable = retriable_patterns
1519                .iter()
1520                .any(|pattern| variant_name_str.contains(pattern));
1521
1522            if is_retriable {
1523                quote! { riglr_core::ToolError::retriable_string(err.to_string()) }
1524            } else {
1525                quote! { riglr_core::ToolError::permanent_string(err.to_string()) }
1526            }
1527        };
1528
1529        quote! {
1530            #pattern => #conversion
1531        }
1532    });
1533
1534    let expanded = quote! {
1535        impl From<#name> for riglr_core::ToolError {
1536            fn from(err: #name) -> Self {
1537                match err {
1538                    #(#match_arms),*
1539                }
1540            }
1541        }
1542    };
1543
1544    TokenStream::from(expanded)
1545}
1546
1547#[cfg(test)]
1548mod tests {
1549    use super::*;
1550
1551    #[test]
1552    fn test_extract_doc_comments_empty() {
1553        let attrs = vec![];
1554        let result = extract_doc_comments(&attrs);
1555        assert_eq!(result, "");
1556    }
1557
1558    #[test]
1559    fn test_extract_doc_comments_with_content() {
1560        // This is a unit test for the doc comment extraction function
1561        // In a real scenario, we would need to parse actual syn::Attribute instances
1562        // For now, we test that the function handles empty attributes correctly
1563        let attrs = vec![];
1564        let result = extract_doc_comments(&attrs);
1565        assert_eq!(result, "");
1566    }
1567
1568    #[test]
1569    fn test_to_pascal_case_conversion() {
1570        // Test the heck crate functionality we use
1571        assert_eq!("test_function".to_pascal_case(), "TestFunction");
1572        assert_eq!("get_balance".to_pascal_case(), "GetBalance");
1573        assert_eq!("simple".to_pascal_case(), "Simple");
1574    }
1575
1576    // Note: Testing procedural macros typically requires integration tests
1577    // with the `trybuild` crate or similar, as unit testing proc macros
1578    // directly is challenging due to their compile-time nature.
1579    //
1580    // For comprehensive testing, we would create test files in tests/
1581    // directory that use the macro and verify compilation and behavior.
1582
1583    #[test]
1584    fn test_macro_module_exists() {
1585        // Basic test to ensure the module compiles
1586        // Compilation success is the test
1587    }
1588
1589    #[test]
1590    fn test_extract_doc_comments_single_line() {
1591        // Create a mock attribute for a single line doc comment
1592        let attr = syn::parse_quote! { #[doc = " This is a single line comment"] };
1593        let attrs = vec![attr];
1594        let result = extract_doc_comments(&attrs);
1595        assert_eq!(result, "This is a single line comment");
1596    }
1597
1598    #[test]
1599    fn test_extract_doc_comments_multiple_lines() {
1600        // Create mock attributes for multiple line doc comments
1601        let attr1 = syn::parse_quote! { #[doc = " First line"] };
1602        let attr2 = syn::parse_quote! { #[doc = " Second line"] };
1603        let attr3 = syn::parse_quote! { #[doc = " Third line"] };
1604        let attrs = vec![attr1, attr2, attr3];
1605        let result = extract_doc_comments(&attrs);
1606        assert_eq!(result, "First line\nSecond line\nThird line");
1607    }
1608
1609    #[test]
1610    fn test_extract_doc_comments_no_leading_space() {
1611        let attr = syn::parse_quote! { #[doc = "No leading space"] };
1612        let attrs = vec![attr];
1613        let result = extract_doc_comments(&attrs);
1614        assert_eq!(result, "No leading space");
1615    }
1616
1617    #[test]
1618    fn test_extract_doc_comments_mixed_with_other_attrs() {
1619        let doc_attr = syn::parse_quote! { #[doc = " Documentation comment"] };
1620        let other_attr = syn::parse_quote! { #[allow(unused)] };
1621        let attrs = vec![other_attr, doc_attr];
1622        let result = extract_doc_comments(&attrs);
1623        assert_eq!(result, "Documentation comment");
1624    }
1625
1626    #[test]
1627    fn test_extract_doc_comments_empty_doc() {
1628        let attr = syn::parse_quote! { #[doc = ""] };
1629        let attrs = vec![attr];
1630        let result = extract_doc_comments(&attrs);
1631        assert_eq!(result, "");
1632    }
1633
1634    #[test]
1635    fn test_extract_doc_comments_whitespace_only() {
1636        let attr = syn::parse_quote! { #[doc = "   "] };
1637        let attrs = vec![attr];
1638        let result = extract_doc_comments(&attrs);
1639        assert_eq!(result, "");
1640    }
1641
1642    #[test]
1643    fn test_generate_tool_error_match_arms_compilation() {
1644        // Test that the generated match arms compile by checking their structure
1645        let match_arms = generate_tool_error_match_arms();
1646        let generated_string = match_arms.to_string();
1647
1648        // Check that the new simplified structure is used
1649        assert!(generated_string.contains("JobResult :: Failure"));
1650        assert!(generated_string.contains("error : tool_error"));
1651        // Verify it uses wildcard matching for simplified error handling
1652        assert!(generated_string.contains("_ =>"));
1653    }
1654
1655    #[test]
1656    fn test_tool_attr_default() {
1657        let default_attr = ToolAttr::default();
1658        assert!(default_attr.description.is_none());
1659    }
1660
1661    #[test]
1662    fn test_tool_attr_parse_empty() {
1663        let input = "";
1664        let result: Result<ToolAttr, _> = syn::parse_str(input);
1665        assert!(result.is_ok());
1666        let attr = result.unwrap();
1667        assert!(attr.description.is_none());
1668    }
1669
1670    #[test]
1671    fn test_tool_attr_parse_description() {
1672        let input = r#"description = "Test description""#;
1673        let result: Result<ToolAttr, _> = syn::parse_str(input);
1674        assert!(result.is_ok());
1675        let attr = result.unwrap();
1676        assert_eq!(attr.description, Some("Test description".to_string()));
1677    }
1678
1679    #[test]
1680    fn test_tool_attr_parse_invalid_key() {
1681        let input = r#"invalid_key = "value""#;
1682        let result: Result<ToolAttr, _> = syn::parse_str(input);
1683        assert!(result.is_err());
1684        let err = result.unwrap_err();
1685        assert!(err.to_string().contains("Unknown attribute key"));
1686    }
1687
1688    #[test]
1689    fn test_tool_attr_parse_missing_equals() {
1690        let input = "description";
1691        let result: Result<ToolAttr, _> = syn::parse_str(input);
1692        assert!(result.is_err());
1693    }
1694
1695    #[test]
1696    fn test_tool_attr_parse_wrong_value_type() {
1697        let input = "description = 123";
1698        let result: Result<ToolAttr, _> = syn::parse_str(input);
1699        assert!(result.is_err());
1700    }
1701
1702    #[test]
1703    fn test_heck_pascal_case_edge_cases() {
1704        assert_eq!("".to_pascal_case(), "");
1705        assert_eq!("a".to_pascal_case(), "A");
1706        assert_eq!("_test_".to_pascal_case(), "Test");
1707        assert_eq!("test__function".to_pascal_case(), "TestFunction");
1708        assert_eq!("UPPERCASE".to_pascal_case(), "Uppercase");
1709        assert_eq!("mixedCase".to_pascal_case(), "MixedCase");
1710        assert_eq!("123_numeric".to_pascal_case(), "123Numeric");
1711    }
1712
1713    // Test the pattern matching logic for derive_into_tool_error
1714    #[test]
1715    fn test_retriable_error_patterns() {
1716        let retriable_patterns = [
1717            "Rpc",
1718            "Network",
1719            "Connection",
1720            "Timeout",
1721            "TooManyRequests",
1722            "RateLimit",
1723            "Api",
1724            "Http",
1725        ];
1726
1727        // Test each pattern is correctly identified
1728        for pattern in &retriable_patterns {
1729            let test_variant = format!("Test{}Error", pattern);
1730            assert!(retriable_patterns.iter().any(|p| test_variant.contains(p)));
1731        }
1732    }
1733
1734    #[test]
1735    fn test_permanent_error_patterns() {
1736        let permanent_variants = [
1737            "InvalidInput",
1738            "ParseError",
1739            "SerializationFailed",
1740            "NotFound",
1741            "Unauthorized",
1742            "InsufficientBalance",
1743            "InsufficientFunds",
1744            "CustomError",
1745            "UnknownError",
1746        ];
1747
1748        let retriable_patterns = [
1749            "Rpc",
1750            "Network",
1751            "Connection",
1752            "Timeout",
1753            "TooManyRequests",
1754            "RateLimit",
1755            "Api",
1756            "Http",
1757        ];
1758
1759        // Test that permanent patterns don't match retriable patterns
1760        for variant in &permanent_variants {
1761            let is_retriable = retriable_patterns
1762                .iter()
1763                .any(|pattern| variant.contains(pattern));
1764            assert!(!is_retriable, "Variant {} should not be retriable", variant);
1765        }
1766    }
1767
1768    #[test]
1769    fn test_error_match_arms_structure() {
1770        let match_arms = generate_tool_error_match_arms();
1771        let generated = match_arms.to_string();
1772
1773        // Verify the new simplified structure
1774        assert!(generated.contains("JobResult :: Failure"));
1775        assert!(generated.contains("error : tool_error"));
1776        // Check that it uses wildcard matching
1777        assert!(generated.contains("_ =>"));
1778    }
1779
1780    // Test compilation of procedural macro output (basic structure validation)
1781    #[test]
1782    fn test_proc_macro_token_stream_generation() {
1783        // Test that we can create basic token streams without panicking
1784        use quote::quote;
1785
1786        let test_tokens = quote! {
1787            #[derive(Clone)]
1788            pub struct TestTool;
1789
1790            impl TestTool {
1791                pub fn new() -> Self { Self }
1792            }
1793        };
1794
1795        assert!(!test_tokens.is_empty());
1796    }
1797
1798    #[test]
1799    fn test_doc_comment_extraction_with_complex_content() {
1800        let attr1 = syn::parse_quote! { #[doc = " Complex content with \"quotes\""] };
1801        let attr2 = syn::parse_quote! { #[doc = " And special chars: &<>"] };
1802        let attr3 = syn::parse_quote! { #[doc = " Numbers: 123 and symbols: $%^"] };
1803        let attrs = vec![attr1, attr2, attr3];
1804        let result = extract_doc_comments(&attrs);
1805        assert_eq!(result, "Complex content with \"quotes\"\nAnd special chars: &<>\nNumbers: 123 and symbols: $%^");
1806    }
1807
1808    #[test]
1809    fn test_doc_comment_trimming() {
1810        let attr1 = syn::parse_quote! { #[doc = "  Leading spaces"] };
1811        let attr2 = syn::parse_quote! { #[doc = ""] };
1812        let attr3 = syn::parse_quote! { #[doc = "Trailing spaces  "] };
1813        let attrs = vec![attr1, attr2, attr3];
1814        let result = extract_doc_comments(&attrs);
1815        // The function strips the first space but preserves other leading spaces
1816        assert_eq!(result, "Leading spaces\n\nTrailing spaces");
1817    }
1818
1819    #[test]
1820    fn test_tool_attr_parse_description_with_quotes() {
1821        let input = r#"description = "Description with \"escaped quotes\"""#;
1822        let result: Result<ToolAttr, _> = syn::parse_str(input);
1823        assert!(result.is_ok());
1824        let attr = result.unwrap();
1825        assert_eq!(
1826            attr.description,
1827            Some("Description with \"escaped quotes\"".to_string())
1828        );
1829    }
1830
1831    #[test]
1832    fn test_tool_attr_parse_description_empty_string() {
1833        let input = r#"description = """#;
1834        let result: Result<ToolAttr, _> = syn::parse_str(input);
1835        assert!(result.is_ok());
1836        let attr = result.unwrap();
1837        assert_eq!(attr.description, Some("".to_string()));
1838    }
1839
1840    #[test]
1841    fn test_extract_doc_comments_only_doc_attrs() {
1842        // Test that only doc attributes are processed, others are ignored
1843        let doc_attr = syn::parse_quote! { #[doc = " Valid doc comment"] };
1844        let cfg_attr = syn::parse_quote! { #[cfg(test)] };
1845        let allow_attr = syn::parse_quote! { #[allow(dead_code)] };
1846        let derive_attr = syn::parse_quote! { #[derive(Clone)] };
1847
1848        let attrs = vec![cfg_attr, doc_attr, allow_attr, derive_attr];
1849        let result = extract_doc_comments(&attrs);
1850        assert_eq!(result, "Valid doc comment");
1851    }
1852
1853    #[test]
1854    fn test_proc_macro_attr_integration() {
1855        // Test the integration between attribute parsing and tool generation
1856        let empty_attr = ToolAttr::default();
1857        let with_desc = ToolAttr {
1858            description: Some("Custom description".to_string()),
1859        };
1860
1861        // Test that attributes are properly structured
1862        assert!(empty_attr.description.is_none());
1863        assert_eq!(
1864            with_desc.description,
1865            Some("Custom description".to_string())
1866        );
1867    }
1868
1869    #[test]
1870    fn test_complex_pascal_case_scenarios() {
1871        // Test edge cases for function name to struct name conversion
1872        assert_eq!("get_user_profile".to_pascal_case(), "GetUserProfile");
1873        assert_eq!("fetch_api_data".to_pascal_case(), "FetchApiData");
1874        assert_eq!(
1875            "handle_websocket_connection".to_pascal_case(),
1876            "HandleWebsocketConnection"
1877        );
1878        assert_eq!(
1879            "process_json_response".to_pascal_case(),
1880            "ProcessJsonResponse"
1881        );
1882        assert_eq!(
1883            "validate_eth_address".to_pascal_case(),
1884            "ValidateEthAddress"
1885        );
1886    }
1887
1888    #[test]
1889    fn test_error_classification_comprehensive() {
1890        let test_cases = vec![
1891            ("RpcConnectionError", true),        // Should be retriable
1892            ("NetworkTimeoutError", true),       // Should be retriable
1893            ("ApiRateLimitError", true),         // Should be retriable
1894            ("HttpRequestError", true),          // Should be retriable
1895            ("InvalidInputError", false),        // Should be permanent
1896            ("ParseError", false),               // Should be permanent
1897            ("NotFoundError", false),            // Should be permanent
1898            ("UnauthorizedError", false),        // Should be permanent
1899            ("InsufficientBalanceError", false), // Should be permanent
1900            ("CustomBusinessError", false),      // Should be permanent (default)
1901            ("DatabaseConnectionError", true),   // Contains "Connection"
1902            ("TooManyRequestsError", true),      // Contains "TooManyRequests"
1903        ];
1904
1905        let retriable_patterns = [
1906            "Rpc",
1907            "Network",
1908            "Connection",
1909            "Timeout",
1910            "TooManyRequests",
1911            "RateLimit",
1912            "Api",
1913            "Http",
1914        ];
1915
1916        for (variant_name, expected_retriable) in test_cases {
1917            let is_retriable = retriable_patterns
1918                .iter()
1919                .any(|pattern| variant_name.contains(pattern));
1920
1921            assert_eq!(
1922                is_retriable,
1923                expected_retriable,
1924                "Variant '{}' should be {} but was classified as {}",
1925                variant_name,
1926                if expected_retriable {
1927                    "retriable"
1928                } else {
1929                    "permanent"
1930                },
1931                if is_retriable {
1932                    "retriable"
1933                } else {
1934                    "permanent"
1935                }
1936            );
1937        }
1938    }
1939
1940    #[test]
1941    fn test_tool_attr_parse_malformed_syntax() {
1942        // Test various malformed syntax cases
1943        let test_cases = vec![
1944            "description =",            // Missing value
1945            "= \"value\"",              // Missing key
1946            "description \"value\"",    // Missing equals
1947            "description = value",      // Unquoted value
1948            "description == \"value\"", // Double equals
1949        ];
1950
1951        for input in test_cases {
1952            let result: Result<ToolAttr, _> = syn::parse_str(input);
1953            assert!(result.is_err(), "Input '{}' should fail to parse", input);
1954        }
1955    }
1956
1957    #[test]
1958    fn test_extract_doc_comments_with_non_string_meta() {
1959        // Test with attributes that have doc but aren't string literals
1960        // This tests the filtering logic in extract_doc_comments
1961        let valid_doc = syn::parse_quote! { #[doc = "Valid comment"] };
1962        let attrs = vec![valid_doc];
1963        let result = extract_doc_comments(&attrs);
1964        assert_eq!(result, "Valid comment");
1965    }
1966
1967    #[test]
1968    fn test_doc_comment_joining_edge_cases() {
1969        // Test doc comment joining with various whitespace scenarios
1970        let attr1 = syn::parse_quote! { #[doc = "Line1"] };
1971        let attr2 = syn::parse_quote! { #[doc = " "] }; // Just a space
1972        let attr3 = syn::parse_quote! { #[doc = "Line3"] };
1973        let attrs = vec![attr1, attr2, attr3];
1974        let result = extract_doc_comments(&attrs);
1975        assert_eq!(result, "Line1\n\nLine3");
1976    }
1977
1978    #[test]
1979    fn test_pascal_case_with_unicode() {
1980        // Test pascal case conversion with unicode characters
1981        assert_eq!("café_function".to_pascal_case(), "CaféFunction");
1982        assert_eq!("测试_function".to_pascal_case(), "测试Function");
1983    }
1984
1985    #[test]
1986    fn test_tool_attr_description_priority() {
1987        // Test that explicit description takes priority over doc comments
1988        let explicit_desc = ToolAttr {
1989            description: Some("Explicit description".to_string()),
1990        };
1991        assert_eq!(
1992            explicit_desc.description,
1993            Some("Explicit description".to_string())
1994        );
1995
1996        let no_desc = ToolAttr { description: None };
1997        assert!(no_desc.description.is_none());
1998    }
1999
2000    #[test]
2001    fn test_generate_match_arms_output_consistency() {
2002        // Test that generate_tool_error_match_arms produces consistent output
2003        let match_arms1 = generate_tool_error_match_arms();
2004        let match_arms2 = generate_tool_error_match_arms();
2005
2006        // Convert to strings and compare
2007        let output1 = match_arms1.to_string();
2008        let output2 = match_arms2.to_string();
2009        assert_eq!(
2010            output1, output2,
2011            "Match arms generation should be deterministic"
2012        );
2013    }
2014
2015    #[test]
2016    fn test_doc_comment_extract_path_verification() {
2017        // Test that extract_doc_comments properly checks path identity
2018        let doc_attr = syn::parse_quote! { #[doc = "Test"] };
2019        let not_doc_attr = syn::parse_quote! { #[deprecated] };
2020
2021        let attrs = vec![not_doc_attr, doc_attr];
2022        let result = extract_doc_comments(&attrs);
2023        assert_eq!(result, "Test");
2024    }
2025
2026    #[test]
2027    fn test_error_pattern_case_sensitivity() {
2028        // Test that error pattern matching is case-sensitive
2029        let case_sensitive_tests = vec![
2030            ("rpc_error", false),     // lowercase 'rpc' should not match 'Rpc'
2031            ("RpcError", true),       // Uppercase 'Rpc' should match
2032            ("network_issue", false), // lowercase 'network' should not match 'Network'
2033            ("NetworkIssue", true),   // Uppercase 'Network' should match
2034        ];
2035
2036        let retriable_patterns = [
2037            "Rpc",
2038            "Network",
2039            "Connection",
2040            "Timeout",
2041            "TooManyRequests",
2042            "RateLimit",
2043            "Api",
2044            "Http",
2045        ];
2046
2047        for (variant_name, expected_match) in case_sensitive_tests {
2048            let matches = retriable_patterns
2049                .iter()
2050                .any(|pattern| variant_name.contains(pattern));
2051            assert_eq!(
2052                matches, expected_match,
2053                "Case sensitivity test failed for '{}'",
2054                variant_name
2055            );
2056        }
2057    }
2058
2059    #[test]
2060    fn test_tool_attr_parse_lookahead_logic() {
2061        // Test the lookahead logic in ToolAttr::parse
2062        let valid_input = "description = \"test\"";
2063        let result: Result<ToolAttr, _> = syn::parse_str(valid_input);
2064        assert!(result.is_ok());
2065
2066        // Test with invalid identifier that triggers lookahead error
2067        let invalid_input = "123invalid = \"test\"";
2068        let result: Result<ToolAttr, _> = syn::parse_str(invalid_input);
2069        assert!(result.is_err());
2070    }
2071
2072    #[test]
2073    fn test_comprehensive_error_variant_naming() {
2074        // Comprehensive test of error variant naming patterns
2075        let comprehensive_tests = vec![
2076            // Retriable patterns
2077            ("SolanaRpcError", true),
2078            ("EthereumNetworkTimeout", true),
2079            ("DatabaseConnectionLost", true),
2080            ("ApiRateLimitExceeded", true),
2081            ("HttpRequestFailed", true),
2082            ("WebSocketConnectionDropped", true),
2083            ("RedisConnectionTimeout", true),
2084            ("TooManyRequestsReceived", true),
2085            // Permanent patterns
2086            ("InvalidAddressFormat", false),
2087            ("ParseJsonError", false),
2088            ("SerializationFailure", false),
2089            ("UserNotFound", false),
2090            ("UnauthorizedAccess", false),
2091            ("InsufficientTokenBalance", false),
2092            ("InsufficientGasFunds", false),
2093            ("MalformedInput", false),
2094            ("ConfigurationError", false),
2095            ("BusinessLogicViolation", false),
2096        ];
2097
2098        let retriable_patterns = [
2099            "Rpc",
2100            "Network",
2101            "Connection",
2102            "Timeout",
2103            "TooManyRequests",
2104            "RateLimit",
2105            "Api",
2106            "Http",
2107        ];
2108
2109        for (variant_name, expected_retriable) in comprehensive_tests {
2110            let is_retriable = retriable_patterns
2111                .iter()
2112                .any(|pattern| variant_name.contains(pattern));
2113
2114            assert_eq!(
2115                is_retriable,
2116                expected_retriable,
2117                "Comprehensive error classification failed for '{}' - expected {}, got {}",
2118                variant_name,
2119                if expected_retriable {
2120                    "retriable"
2121                } else {
2122                    "permanent"
2123                },
2124                if is_retriable {
2125                    "retriable"
2126                } else {
2127                    "permanent"
2128                }
2129            );
2130        }
2131    }
2132
2133    #[test]
2134    fn test_empty_and_whitespace_edge_cases() {
2135        // Test various empty and whitespace scenarios
2136        let empty_attrs: Vec<syn::Attribute> = vec![];
2137        assert_eq!(extract_doc_comments(&empty_attrs), "");
2138
2139        // Test with only whitespace doc
2140        let whitespace_attr = syn::parse_quote! { #[doc = "   \t\n  "] };
2141        let result = extract_doc_comments(&vec![whitespace_attr]);
2142        assert_eq!(result.trim(), "");
2143
2144        // Test pascal case with empty string
2145        assert_eq!("".to_pascal_case(), "");
2146    }
2147
2148    #[test]
2149    fn test_parameter_parsing_error_handling() {
2150        // Test that parameter parsing errors are converted to ToolError::InvalidInput
2151        // and use the standard error matching logic
2152
2153        // Create a mock serde_json::Error by attempting to parse invalid JSON
2154        let invalid_json = "{ invalid json }";
2155        let parse_result: Result<serde_json::Value, serde_json::Error> =
2156            serde_json::from_str(invalid_json);
2157        assert!(parse_result.is_err());
2158
2159        let error = parse_result.unwrap_err();
2160
2161        // Verify that we can create a ToolError::InvalidInput from the serde error
2162        use riglr_core::ToolError;
2163        let tool_error =
2164            ToolError::invalid_input_with_source(error, "Failed to parse tool parameters");
2165
2166        // Verify properties of the error
2167        assert!(!tool_error.is_retriable());
2168        assert!(!tool_error.is_rate_limited());
2169        assert_eq!(tool_error.retry_after(), None);
2170
2171        // Verify the error message contains expected content
2172        let error_str = tool_error.to_string();
2173        assert!(error_str.contains("Invalid input"));
2174        assert!(error_str.contains("Failed to parse tool parameters"));
2175    }
2176
2177    #[test]
2178    fn test_tool_error_match_arms_invalid_input_handling() {
2179        // Test that the generated match arms handle all errors with simplified structure
2180        let match_arms = generate_tool_error_match_arms();
2181        let generated = match_arms.to_string();
2182
2183        // Verify the simplified structure handles all errors uniformly
2184        assert!(generated.contains("JobResult :: Failure"));
2185        assert!(generated.contains("error : tool_error"));
2186        // Verify it uses wildcard matching for all error types
2187        assert!(generated.contains("_ =>"));
2188    }
2189}