riglr_macros/lib.rs
1/*!
2# riglr-macros
3
4Procedural macros for riglr - dramatically reducing boilerplate when creating blockchain tools.
5
The `#[tool]` macro is the cornerstone of riglr's developer experience. It turns simple async
functions, synchronous functions, and structs into full-featured blockchain tools with automatic error handling, JSON
schema generation, and seamless `rig` framework integration.
9
10## Overview
11
12The `#[tool]` macro automatically implements the `Tool` trait for both async and sync functions, as well as structs,
13eliminating the need to write ~30 lines of boilerplate code per tool. It generates:
14
151. **Parameter struct** with proper JSON schema and serde annotations
162. **Tool trait implementation** with error handling and type conversion
173. **Documentation extraction** from doc comments for AI model consumption
184. **SignerContext integration** for secure blockchain operations
195. **Convenience constructors** for easy instantiation
20
21## Code Generation Process
22
23When you apply `#[tool]` to a function, the macro performs the following transformations:
24
25### 1. Parameter Extraction and Struct Generation
26
27```rust,ignore
28// Your function:
29#[tool]
30async fn swap_tokens(
31 /// Source token mint address
32 from_mint: String,
33 /// Destination token mint address
34 to_mint: String,
35 /// Amount to swap in base units
36 amount: u64,
37 /// Optional slippage tolerance (default: 0.5%)
38 #[serde(default = "default_slippage")]
39 slippage_bps: Option<u16>,
40) -> Result<String, SwapError> { ... }
41
42// Generated args struct:
43#[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema, Debug, Clone)]
44#[serde(rename_all = "camelCase")]
45pub struct SwapTokensArgs {
46 /// Source token mint address
47 pub from_mint: String,
48 /// Destination token mint address
49 pub to_mint: String,
50 /// Amount to swap in base units
51 pub amount: u64,
52 /// Optional slippage tolerance (default: 0.5%)
53 #[serde(default = "default_slippage")]
54 pub slippage_bps: Option<u16>,
55}
56```
57
58### 2. Tool Struct and Trait Implementation Generation
59
60```rust,ignore
61// Generated tool struct:
62#[derive(Clone)]
63pub struct SwapTokensTool;
64
65impl SwapTokensTool {
66 pub fn new() -> Self { Self }
67}
68
69#[async_trait::async_trait]
70impl riglr_core::Tool for SwapTokensTool {
71 async fn execute(&self, params: serde_json::Value, context: &riglr_core::provider::ApplicationContext) -> Result<riglr_core::JobResult, riglr_core::ToolError> {
        // 1. Parse parameters, converting failures into ToolError::InvalidInput
        let args: SwapTokensArgs = serde_json::from_value(params)
            .map_err(|e| riglr_core::ToolError::invalid_input_with_source(e, "Failed to parse tool parameters"))?;
75
76 // 2. Call your original function
77 let result = swap_tokens(args.from_mint, args.to_mint, args.amount, args.slippage_bps).await;
78
79 // 3. Convert results to standardized JobResult format
80 match result {
81 Ok(value) => Ok(riglr_core::JobResult::Success {
82 value: serde_json::to_value(value)?,
83 tx_hash: None,
84 }),
85 Err(error) => {
                // 4. Structured error handling - the full ToolError (with its
                //    retriable/permanent classification) is preserved in the failure
                let tool_error: riglr_core::ToolError = error.into();
                Ok(riglr_core::JobResult::Failure {
                    error: tool_error,
                })
110 }
111 }
112 }
113
114 fn name(&self) -> &str {
115 "swap_tokens"
116 }
117}
118
119// Convenience constructor
120pub fn swap_tokens_tool() -> std::sync::Arc<dyn riglr_core::Tool> {
121 std::sync::Arc::new(SwapTokensTool::new())
122}
123```
124
125### 3. Documentation Processing and Description Attribute
126
127The macro extracts documentation from three sources and wires them into the Tool implementation:
128
129- **Function docstrings** → Tool descriptions for AI models
130- **Parameter docstrings** → JSON schema field descriptions
131- **Type annotations** → JSON schema type information
132
133You can also provide an explicit AI-facing description using the attribute:
134
135```rust,ignore
136#[tool(description = "Fetches the URL and returns the body as text.")]
137async fn fetch(url: String) -> Result<String, Error> { ... }
138```
139
140Priority logic for the generated `Tool::description()` method:
141- If `description = "..."` attribute is present, that string is used
142- Else, the item's rustdoc comments are used
143- Else, an empty string is returned
144
145This enables AI models to understand exactly what each tool does and how to use it properly.
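
For example (a minimal sketch; `PriceError` is a hypothetical error type assumed to implement `Into<ToolError>`), an explicit attribute takes precedence over any rustdoc comment:

```rust,ignore
/// Internal notes that are NOT sent to the model when an attribute is present.
#[tool(description = "Returns the current SOL price in USD.")]
async fn sol_price() -> Result<f64, PriceError> { ... }

// Tool::description() yields "Returns the current SOL price in USD."
// Without the attribute, it would yield the rustdoc text instead.
```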
146
147## Constraints and Requirements
148
149### Function Requirements
150
1511. **Return Type**: Must be `Result<T, E>` where `E: Into<riglr_core::ToolError>`
152 ```rust,ignore
153 // ✅ Valid - custom error type with derive
154 #[derive(Error, Debug, IntoToolError)]
155 enum MyError { NetworkError(String), InvalidInput(String) }
156 async fn valid_tool() -> Result<String, MyError> { ... }
157
158 // ❌ Invalid - not a Result
159 async fn invalid_tool() -> String { ... }
160
161 // ❌ Invalid - std::io::Error doesn't implement Into<ToolError>
162 async fn bad_error() -> Result<String, std::io::Error> { ... }
163
164 // ✅ Valid - wrap std library errors in custom types
165 #[derive(Error, Debug, IntoToolError)]
166 enum FileError {
167 #[error("IO error: {0}")]
168 Io(#[from] std::io::Error)
169 }
170 async fn good_file_tool() -> Result<String, FileError> { ... }
171 ```
172
1732. **Parameters**: All parameters must implement `serde::Deserialize + schemars::JsonSchema`
174 ```rust,ignore
175 // ✅ Valid - standard types implement these automatically
176 async fn good_params(address: String, amount: u64) -> Result<(), ToolError> { ... }
177
178 // ❌ Invalid - custom types need derives
179 struct CustomType { field: String }
180 async fn bad_params(custom: CustomType) -> Result<(), ToolError> { ... }
181 ```
182
1833. **Function Type**: The macro supports both async and synchronous functions
184 ```rust,ignore
185 // ✅ Valid - async function
186 #[tool]
187 async fn async_tool() -> Result<String, ToolError> { ... }
188
189 // ✅ Valid - sync function (executed within async context)
190 #[tool]
191 fn sync_tool() -> Result<String, ToolError> { ... }
192 ```
193
   Synchronous functions are adapted to the async `Tool` trait automatically: the generated
   `execute` method calls them directly (without `.await`), so they run inline on the async
   executor's thread.
196
1974. **Documentation**: Function and parameters should have doc comments for AI consumption
198 ```rust,ignore
199 /// This description helps AI models understand the tool's purpose
200 #[tool]
201 async fn documented_tool(
202 /// This helps the AI understand this parameter
203 param: String,
204 ) -> Result<String, ToolError> { ... }
205 ```
206
207### Struct Requirements
208
209For struct-based tools, additional requirements apply:
210
2111. **Execute Method**: Must have an async `execute` method returning `Result<T, E>`
2122. **Serde Traits**: Must derive `Serialize`, `Deserialize`, and `JsonSchema`
2133. **Clone**: Must be `Clone` for multi-use scenarios
214
215```rust,ignore
216#[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema, Clone)]
217#[tool]
218struct MyStructTool {
219 config: String,
220}
221
222impl MyStructTool {
223 pub async fn execute(&self) -> Result<String, ToolError> {
224 // Implementation
225 Ok(format!("Processed: {}", self.config))
226 }
227}
228```
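
The macro also generates an `as_tool` convenience method on the struct, so a configured instance can be turned into a shared tool handle (a minimal sketch):

```rust,ignore
// Convert an instance into Arc<dyn Tool> via the generated helper
let tool: std::sync::Arc<dyn riglr_core::Tool> = MyStructTool {
    config: "mainnet".to_string(),
}
.as_tool();

assert_eq!(tool.name(), "MyStructTool");
```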
229
230## Complex Usage Examples
231
232### Synchronous Function Example
233
234The macro supports both async and sync functions. Sync functions are useful for
235computational tools that don't require I/O operations:
236
237```rust,ignore
238use riglr_core::ToolError;
239
240/// Calculate compound interest for a given principal, rate, and time
241///
242/// This is a computational tool that doesn't require async operations,
243/// so it's implemented as a synchronous function that runs efficiently
244/// within the async Tool framework.
245#[tool]
246fn calculate_compound_interest(
247 /// Principal amount in dollars
248 principal: f64,
249 /// Annual interest rate as a decimal (e.g., 0.05 for 5%)
250 annual_rate: f64,
251 /// Time period in years
252 years: f64,
253 /// Number of times interest is compounded per year
254 compounds_per_year: u32,
255) -> Result<f64, ToolError> {
256 if principal <= 0.0 {
257 return Err(ToolError::invalid_input_string("Principal must be positive"));
258 }
259 if annual_rate < 0.0 {
260 return Err(ToolError::invalid_input_string("Interest rate cannot be negative"));
261 }
262 if years < 0.0 {
263 return Err(ToolError::invalid_input_string("Time period cannot be negative"));
264 }
265 if compounds_per_year == 0 {
266 return Err(ToolError::invalid_input_string("Compounds per year must be at least 1"));
267 }
268
269 let rate_per_compound = annual_rate / compounds_per_year as f64;
270 let total_compounds = compounds_per_year as f64 * years;
271 let final_amount = principal * (1.0 + rate_per_compound).powf(total_compounds);
272
273 Ok(final_amount)
274}
275```
276
277#### Important Note on CPU-Intensive Sync Functions
278
279The `#[tool]` macro executes synchronous functions directly within the async executor's thread.
280This is fine for quick computations, but **CPU-intensive operations can block the async runtime**.
281
282For CPU-intensive work, wrap your function in `tokio::task::spawn_blocking` **before** applying
283the `#[tool]` macro:
284
285```rust,ignore
286use riglr_core::ToolError;
287
288/// CPU-intensive cryptographic operation
289///
290/// This uses spawn_blocking to avoid blocking the async runtime
291#[tool]
292async fn compute_hash(
293 /// Data to hash
294 data: Vec<u8>,
295 /// Number of iterations
296 iterations: u32,
297) -> Result<String, ToolError> {
298 // Move CPU-intensive work to a blocking thread pool
299 tokio::task::spawn_blocking(move || {
300 // Simulate expensive computation
301 let mut hash = data;
302 for _ in 0..iterations {
303 hash = sha256::digest(&hash).into_bytes();
304 }
305 Ok(hex::encode(hash))
306 })
307 .await
308 .map_err(|e| ToolError::permanent_string(format!("Task failed: {}", e)))?
309}
310```
311
312**Guidelines for choosing between sync and async with spawn_blocking:**
313- **Use sync functions** for quick calculations (< 1ms), simple data transformations, or validation
314- **Use async + spawn_blocking** for CPU-intensive work like cryptography, complex parsing, or heavy computation
315- **Use regular async** for I/O operations like network requests or database queries
316
317### Generic Parameters and Type Constraints
318
319```rust,ignore
320use serde::{Serialize, Deserialize};
321use schemars::JsonSchema;
322
323/// Generic tool that can process any serializable data
324#[tool]
325async fn process_data<T>(
326 /// The data to process (must be JSON-serializable)
327 data: T,
328 /// Processing options
329 options: ProcessingOptions,
330) -> Result<ProcessedData, ProcessingError>
331where
    T: Serialize + serde::de::DeserializeOwned + JsonSchema + Send + Sync,
333{
    // Note: support for complex generic constraints is limited (see Macro Limitations below)
335 let serialized = serde_json::to_string(&data)?;
336 // ... processing logic
337 Ok(ProcessedData::new(serialized))
338}
339```
340
341### SignerContext Integration
342
343Tools automatically have access to the current blockchain signer:
344
345```rust,ignore
346use riglr_core::signer::SignerContext;
347
348/// Swap tokens on Solana using Jupiter aggregator
349///
350/// This tool automatically accesses the current signer from the context,
351/// eliminating the need to pass signing credentials explicitly.
352#[tool]
353async fn jupiter_swap(
354 /// Input token mint address
355 input_mint: String,
356 /// Output token mint address
357 output_mint: String,
358 /// Amount to swap in base units
359 amount: u64,
360 /// Maximum slippage in basis points
361 max_slippage_bps: u16,
362) -> Result<String, SwapError> {
363 // Access the current signer automatically
364 let signer = SignerContext::current().await?;
365
366 // Derive RPC client from signer
367 let rpc_client = signer.rpc_client();
368
369 // Get quote from Jupiter
370 let quote = get_jupiter_quote(&input_mint, &output_mint, amount, max_slippage_bps).await?;
371
372 // Build and sign transaction
373 let tx = build_swap_transaction(quote, &signer.pubkey()).await?;
374 let signed_tx = signer.sign_transaction(tx).await?;
375
376 // Send transaction
377 let signature = rpc_client.send_and_confirm_transaction(&signed_tx).await?;
378
379 Ok(signature.to_string())
380}
381```
382
383### Multi-Chain Tool with Dynamic Signer Selection
384
385```rust,ignore
386use riglr_core::signer::{SignerContext, ChainType};
387
388/// Bridge tokens between different blockchains
389///
390/// Automatically detects the source chain from the current signer
391/// and handles cross-chain bridging operations.
392#[tool]
393async fn bridge_tokens(
394 /// Source token address
395 source_token: String,
396 /// Destination chain identifier
397 dest_chain: String,
398 /// Destination token address
399 dest_token: String,
400 /// Amount to bridge in base units
401 amount: u64,
402 /// Recipient address on destination chain
403 recipient: String,
404) -> Result<BridgeResult, BridgeError> {
405 let signer = SignerContext::current().await?;
406
407 // Dynamic chain detection
408 let bridge_operation = match signer.chain_type() {
409 ChainType::Solana => {
410 SolanaBridge::new(signer).bridge_to_evm(
411 source_token, dest_chain, dest_token, amount, recipient
412 ).await?
413 },
414 ChainType::Ethereum => {
415 EthereumBridge::new(signer).bridge_to_solana(
416 source_token, dest_token, amount, recipient
417 ).await?
418 },
419 ChainType::Polygon => {
420 PolygonBridge::new(signer).bridge_cross_chain(
421 source_token, dest_chain, dest_token, amount, recipient
422 ).await?
423 },
424 _ => return Err(BridgeError::UnsupportedChain),
425 };
426
427 Ok(bridge_operation)
428}
429```
430
431### Error Handling and Retry Logic
432
433The macro automatically integrates with riglr's structured error handling.
434
**IMPORTANT REQUIREMENT:** The `#[tool]` macro requires that every error type implements `Into<ToolError>`.
There is no built-in conversion for standard library or third-party error types such as `std::io::Error` or `reqwest::Error`.
You must define custom error types that provide proper classification and context.
438
439#### Recommended Pattern: Custom Error Types with `#[derive(IntoToolError)]`
440
The recommended approach is to use the `IntoToolError` derive macro for automatic error classification and conversion:
442
443```rust,ignore
444use riglr_macros::IntoToolError;
445use thiserror::Error;
446
447#[derive(Error, Debug, IntoToolError)]
448enum SwapError {
449 #[error("Insufficient balance: need {required}, have {available}")]
450 InsufficientBalance { required: u64, available: u64 },
451
452 #[error("Network congestion, retry in {retry_after_seconds}s")]
453 #[tool_error(retriable)] // Override default classification
454 NetworkCongestion { retry_after_seconds: u64 },
455
456 #[error("Slippage too high: expected {expected}%, got {actual}%")]
457 SlippageTooHigh { expected: f64, actual: f64 },
458
459 #[error("Invalid token mint: {mint}")]
460 InvalidToken { mint: String },
461}
462
463// The IntoToolError derive macro automatically generates the From<SwapError> for ToolError impl
464```
465
466See the `trybuild` tests in `riglr-macros/tests/trybuild/` for examples:
467- `pass/custom_error_into.rs` - Correct usage with custom error types
468- `fail/unconvertible_error.rs` - What happens when error types don't implement Into<ToolError>
469
470#### Alternative: Manual Implementation
471
472If you need more control, you can manually implement the conversion:
473
474```rust,ignore
475use riglr_core::ToolError;
476
477impl From<SwapError> for ToolError {
478 fn from(error: SwapError) -> Self {
479 match error {
480 SwapError::NetworkCongestion { .. } => ToolError::Retriable(error.to_string()),
481 SwapError::InsufficientBalance { .. } => ToolError::Permanent(error.to_string()),
482 SwapError::SlippageTooHigh { .. } => ToolError::Permanent(error.to_string()),
483 SwapError::InvalidToken { .. } => ToolError::Permanent(error.to_string()),
484 }
485 }
486}
487
488/// Advanced token swap with detailed error handling
489#[tool]
490async fn advanced_swap(
491 input_mint: String,
492 output_mint: String,
493 amount: u64,
494) -> Result<SwapResult, SwapError> {
495 let signer = SignerContext::current().await?;
496
497 // Check balance first
498 let balance = get_token_balance(&signer, &input_mint).await?;
499 if balance < amount {
500 return Err(SwapError::InsufficientBalance {
501 required: amount,
502 available: balance,
503 });
504 }
505
506 // Attempt swap with retries for transient failures
507 match attempt_swap(&signer, &input_mint, &output_mint, amount).await {
508 Err(SwapError::NetworkCongestion { .. }) => {
            // The `From` impl above classifies this variant as retriable
510 Err(SwapError::NetworkCongestion { retry_after_seconds: 10 })
511 },
512 result => result,
513 }
514}
515```
516
517### Testing Tool Implementations
518
519The macro-generated code is designed to be easily testable:
520
521```rust,ignore
522#[cfg(test)]
523mod tests {
524 use super::*;
525 use riglr_core::signer::{MockSigner, SignerContext};
526 use serde_json::json;
527
528 #[tokio::test]
529 async fn test_swap_tool_execution() {
530 // Create mock signer with expected behavior
531 let mock_signer = MockSigner::new()
532 .with_token_balance("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", 1000000) // USDC
533 .expect_transaction("swap")
534 .returns_signature("5j7s2Hz2UnknownTxHash");
535
536 // Test the generated tool
537 let tool = SwapTokensTool::new();
538
539 let result = SignerContext::new(&mock_signer).execute(async {
540 tool.execute(json!({
541 "fromMint": "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",
542 "toMint": "So11111111111111111111111111111111111111112",
543 "amount": 1000000,
544 "slippageBps": 50
545 })).await
546 }).await;
547
548 assert!(result.is_ok());
549 mock_signer.verify_all_expectations();
550 }
551}
552```
553
554## Best Practices
555
556### 1. Parameter Design
557- Use descriptive parameter names that clearly indicate their purpose
558- Provide comprehensive doc comments for each parameter
559- Use appropriate default values with `#[serde(default)]` where applicable
560- Group related parameters into structs for complex operations
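
A sketch of the last two points; `default_slippage`, `SwapRoute`, and `SwapError` are hypothetical:

```rust,ignore
fn default_slippage() -> u16 { 50 }

/// Options grouped into a single, reusable struct
#[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema, Debug, Clone)]
struct SwapOptions {
    /// Maximum slippage in basis points (defaults to 50)
    #[serde(default = "default_slippage")]
    slippage_bps: u16,
    /// Skip transaction simulation before sending
    #[serde(default)]
    skip_preflight: bool,
}

#[tool]
async fn plan_swap(
    /// Source token mint address
    from_mint: String,
    /// Destination token mint address
    to_mint: String,
    /// Swap options (grouped rather than passed as loose parameters)
    options: SwapOptions,
) -> Result<SwapRoute, SwapError> { ... }
```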
561
562### 2. Error Handling
563- Define custom error types that implement `Into<ToolError>`
564- Use structured errors that provide actionable information
565- Distinguish between retriable and permanent errors appropriately
566- Include relevant context in error messages
567
568### 3. Documentation
569- Write clear, concise function descriptions that explain the tool's purpose
570- Document any side effects or state changes
571- Include examples in doc comments where helpful
572- Explain any complex parameters or return values
573
574### 4. Performance Considerations
575- Use `Arc<dyn Tool>` for tools that will be shared across threads
576- Implement `Clone` efficiently for struct-based tools
577- Consider caching for expensive operations that don't change frequently
578- Use appropriate timeouts for network operations
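
For the first point, a minimal sketch of sharing one tool handle across tasks (assumes a tokio runtime and that the tool trait object is `Send + Sync`):

```rust,ignore
let tool = swap_tokens_tool(); // std::sync::Arc<dyn riglr_core::Tool>

for _ in 0..4 {
    let tool = std::sync::Arc::clone(&tool);
    tokio::spawn(async move {
        // Dispatch work to the shared instance instead of constructing a new tool per request
        println!("worker using tool: {}", tool.name());
    });
}
```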
579
580### 5. Security and Business Logic Validation
581
582**⚠️ IMPORTANT SECURITY NOTE:** While the `#[tool]` macro and `serde` automatically handle parameter *format* validation (JSON schema, type conversion, required fields), your tool implementation is still responsible for all *business logic* validation and security checks.
583
584#### Critical Business Logic Validations:
585
586**Financial Operations:**
587```rust,ignore
588#[tool]
589async fn transfer_tokens(
590 to_address: String,
591 amount: f64,
592 slippage_percent: f64,
593) -> Result<String, ToolError> {
594 // ✅ Business logic validation (your responsibility)
595 if amount <= 0.0 {
596 return Err(ToolError::invalid_input_string(
597 "Transfer amount must be positive"
598 ));
599 }
600
601 if slippage_percent >= 5.0 {
602 return Err(ToolError::invalid_input_string(
603 "Slippage tolerance too high (max 5%). Consider if this is intentional"
604 ));
605 }
606
607 // ✅ Address validation
608 if !is_valid_address(&to_address) {
609 return Err(ToolError::invalid_input_string(
610 "Invalid recipient address format"
611 ));
612 }
613
614 // ✅ Balance check before executing
615 let balance = get_current_balance().await?;
616 if balance < amount {
617 return Err(ToolError::permanent_string(
618 format!("Insufficient balance: {} < {}", balance, amount)
619 ));
620 }
621
622 // Proceed with transfer...
623}
624```
625
626**Smart Contract Interactions:**
627```rust,ignore
628#[tool]
629async fn execute_contract_call(
630 contract_address: String,
631 function_name: String,
632 gas_limit: u64,
633) -> Result<String, ToolError> {
634 // ✅ Contract address validation
635 if !is_trusted_contract(&contract_address) {
636 return Err(ToolError::permanent_string(
637 "Contract not in approved whitelist"
638 ));
639 }
640
641 // ✅ Re-entrancy protection
642 if is_contract_execution_in_progress(&contract_address) {
643 return Err(ToolError::retriable_string(
644 "Contract execution already in progress, avoiding re-entrancy"
645 ));
646 }
647
648 // ✅ Gas limit safety check
649 if gas_limit > MAX_SAFE_GAS_LIMIT {
650 return Err(ToolError::invalid_input_string(
651 "Gas limit exceeds safety threshold"
652 ));
653 }
654
655 // Proceed with contract call...
656}
657```
658
659**Data Integrity Checks:**
660```rust,ignore
661#[tool]
662async fn process_transaction_data(
663 tx_hash: String,
664 expected_amount: f64,
665) -> Result<TransactionResult, ToolError> {
666 // ✅ Transaction hash format validation
667 if tx_hash.len() != 64 || !tx_hash.chars().all(|c| c.is_ascii_hexdigit()) {
668 return Err(ToolError::invalid_input_string(
669 "Invalid transaction hash format"
670 ));
671 }
672
673 // ✅ Cross-reference with external data
674 let actual_amount = fetch_transaction_amount(&tx_hash).await?;
675 if (actual_amount - expected_amount).abs() > 0.001 {
676 return Err(ToolError::permanent_string(
677 "Transaction amount mismatch detected"
678 ));
679 }
680
681 // Proceed with processing...
682}
683```
684
685#### Remember: The Macro Handles Format, You Handle Business Logic
686- **Macro + Serde**: Validates JSON structure, types, required fields
687- **Your Code**: Validates ranges, business rules, security constraints, data relationships
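
Concretely (a sketch reusing the hypothetical `transfer_tokens` tool above, and assuming the camelCase renaming shown earlier), a payload can satisfy the generated schema yet still be rejected by your business checks:

```rust,ignore
// Passes format validation: every field is present and has the expected JSON type...
let params = serde_json::json!({
    "toAddress": "<recipient address>",
    "amount": -5.0,
    "slippagePercent": 1.0
});

// ...but the tool body still rejects it:
// ToolError::invalid_input_string("Transfer amount must be positive")
```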
688
689## Macro Limitations
690
691### Current Limitations
692
6931. **Generic Functions**: Limited support for complex generic constraints
6942. **Lifetime Parameters**: Not currently supported in tool functions
6953. **Associated Types**: Cannot use associated types in parameters
6964. **Const Generics**: No support for const generic parameters
697
698### Workarounds
699
700For complex generic scenarios, consider using trait objects or type erasure:
701
702```rust,ignore
703// Instead of:
704// #[tool]
705// async fn complex_generic<T: ComplexTrait>(data: T) -> Result<(), Error> { ... }
706
707// Use:
708#[tool]
709async fn process_complex_data(
710 /// JSON representation of the data to process
711 data: serde_json::Value,
712) -> Result<ProcessedResult, ProcessError> {
713 // Deserialize to specific types inside the function
714 let typed_data: MyType = serde_json::from_value(data)?;
715 // ... process typed_data
716}
717```
718
719## Integration with External Crates
720
721The macro is designed to work seamlessly with the broader Rust ecosystem:
722
723### Serde Integration
724- Automatic `#[serde(rename_all = "camelCase")]` for JavaScript compatibility
725- Support for all serde attributes on parameters
726- Custom serialization/deserialization via serde derives
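
For example, serde attributes written directly on a parameter are forwarded to the corresponding field of the generated args struct (a sketch; `default_timeout` and `FetchError` are hypothetical):

```rust,ignore
#[tool]
async fn fetch_page(
    /// URL to fetch
    url: String,
    /// Request timeout in seconds
    #[serde(default = "default_timeout")] // copied onto the generated args field
    timeout_secs: u64,
) -> Result<String, FetchError> { ... }
```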
727
728### JSON Schema Generation
729- Automatic schema generation via `schemars` crate
730- Support for complex nested types and enums
731- Custom schema attributes for fine-tuned control
732
733### Async Runtime Compatibility
734- Works with any async runtime (tokio, async-std, etc.)
735- Proper handling of async trait implementations
736- Support for async error handling patterns
737
738The `#[tool]` macro transforms riglr from a collection of utilities into a cohesive,
739developer-friendly framework for building sophisticated blockchain AI agents.
740*/
741
742use heck::ToPascalCase;
743use proc_macro::TokenStream;
744use quote::quote;
745use syn::{
746 parse::Parse, parse::ParseStream, parse_macro_input, Attribute, DeriveInput, FnArg, ItemFn,
747 ItemStruct, LitStr, PatType, Token,
748};
749
750/// The `#[tool]` procedural macro that converts functions and structs into Tool implementations.
751///
752/// This macro supports:
753/// - Async functions with arbitrary parameters and Result return types
754/// - Structs that have an `execute` method
755/// - Automatic JSON schema generation using `schemars`
756/// - Documentation extraction from doc comments
757/// - Parameter descriptions from doc comments on function arguments
758///
759/// Attributes supported:
760/// - description = "..."
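///
/// # Example
///
/// A minimal sketch (`BalanceError` is a hypothetical error type assumed to
/// implement `Into<ToolError>`, e.g. via `#[derive(IntoToolError)]`):
///
/// ```rust,ignore
/// use riglr_core::provider::ApplicationContext;
///
/// /// Returns the lamport balance for the given account address.
/// #[tool(description = "Fetches the lamport balance for a Solana account.")]
/// async fn get_balance(
///     /// Base58-encoded account address
///     address: String,
///     context: &ApplicationContext,
/// ) -> Result<u64, BalanceError> {
///     // ... look up the balance using `context`
/// }
///
/// // The macro also emits a `get_balance_tool()` factory returning an Arc<dyn Tool>.
/// ```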
761#[proc_macro_attribute]
762pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
763 let input = item.clone();
764
765 let tool_attrs = match syn::parse::<ToolAttr>(attr) {
766 Ok(attrs) => attrs,
767 Err(_) => ToolAttr { description: None },
768 };
769
770 // Try to parse as function first, then as struct
771 if let Ok(function) = syn::parse::<ItemFn>(input.clone()) {
772 handle_function(function, tool_attrs).into()
773 } else if let Ok(structure) = syn::parse::<ItemStruct>(input) {
774 handle_struct(structure, tool_attrs).into()
775 } else {
776 syn::Error::new_spanned(
777 proc_macro2::TokenStream::from(item),
778 "#[tool] can only be applied to async functions or structs.\n\
779 For functions: Must be async and return Result<T, E> where E: Into<ToolError>\n\
780 For structs: Must implement Clone, Serialize, Deserialize, JsonSchema and have an async execute(&self) method",
781 )
782 .to_compile_error()
783 .into()
784 }
785}
786
787#[derive(Default, Debug)]
788struct ToolAttr {
789 description: Option<String>,
790}
791
792impl Parse for ToolAttr {
793 fn parse(input: ParseStream) -> syn::Result<Self> {
794 if input.is_empty() {
795 return Ok(Self::default());
796 }
797
798 let lookahead = input.lookahead1();
799 if lookahead.peek(syn::Ident) {
800 let ident: syn::Ident = input.parse()?;
801 if ident == "description" {
802 input.parse::<Token![=]>()?;
803 let lit: LitStr = input.parse()?;
804 return Ok(Self {
805 description: Some(lit.value()),
806 });
807 } else {
808 return Err(syn::Error::new_spanned(
809 ident,
810 "Unknown attribute key. Supported: description",
811 ));
812 }
813 }
814
815 Err(syn::Error::new(
816 input.span(),
817 "Expected attribute key like: description = \"...\"",
818 ))
819 }
820}
821
822/// Helper function to check if a parameter is a context parameter (by type)
823fn is_context_param(param_type: &syn::Type) -> bool {
824 // Check if the type is &ApplicationContext or &riglr_core::provider::ApplicationContext
825 if let syn::Type::Reference(type_ref) = param_type {
826 if let syn::Type::Path(type_path) = &*type_ref.elem {
827 let path_str = type_path
828 .path
829 .segments
830 .iter()
831 .map(|segment| segment.ident.to_string())
832 .collect::<Vec<_>>()
833 .join("::");
834
835 return path_str == "ApplicationContext"
836 || path_str == "riglr_core::provider::ApplicationContext"
837 || path_str.ends_with("::ApplicationContext");
838 }
839 }
840 false
841}
842
843/// Check if a type is Result<T, E>
844fn is_result_type(ty: &syn::Type) -> bool {
845 if let syn::Type::Path(type_path) = ty {
846 if let Some(segment) = type_path.path.segments.last() {
847 let segment_name = segment.ident.to_string();
848 return segment_name == "Result"
849 && matches!(segment.arguments, syn::PathArguments::AngleBracketed(_));
850 }
851 }
852 false
853}
854
855/// Check if a type is likely serializable
856fn is_serializable_type(ty: &syn::Type) -> bool {
857 match ty {
858 // Basic serializable types
859 syn::Type::Path(type_path) => {
860 if let Some(segment) = type_path.path.segments.last() {
861 let segment_name = segment.ident.to_string();
862 match segment_name.as_str() {
863 // Primitive types
864 "String" | "str" | "bool" | "i8" | "i16" | "i32" | "i64" | "i128" | "isize"
865 | "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "f32" | "f64" | "char" => {
866 true
867 }
868
869 // Common generic types that are serializable
870 "Vec" | "Option" | "HashMap" | "BTreeMap" | "HashSet" | "BTreeSet"
871 | "VecDeque" => true,
872
873 // Common time types
874 "SystemTime" | "Duration" => true,
875
876 // Assume custom types are serializable (user responsibility)
877 _ => true,
878 }
879 } else {
880 false
881 }
882 }
883
884 // References to serializable types
885 syn::Type::Reference(type_ref) => is_serializable_type(&type_ref.elem),
886
887 // Arrays are serializable if their element type is
888 syn::Type::Array(type_array) => is_serializable_type(&type_array.elem),
889
890 // Slices are serializable if their element type is
891 syn::Type::Slice(type_slice) => is_serializable_type(&type_slice.elem),
892
893 // Tuples are serializable if all elements are
894 syn::Type::Tuple(type_tuple) => type_tuple.elems.iter().all(is_serializable_type),
895
896 // Other types - be conservative and reject
897 _ => false,
898 }
899}
900
901/// Extract the error type from a Result<T, E> return type
902fn extract_error_type(return_type: &syn::ReturnType) -> Option<syn::Type> {
903 if let syn::ReturnType::Type(_, ty) = return_type {
904 if let syn::Type::Path(type_path) = ty.as_ref() {
905 if let Some(segment) = type_path.path.segments.last() {
906 if segment.ident == "Result" {
907 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
908 // Result<T, E> - get the second type argument (E)
909 if args.args.len() == 2 {
910 if let syn::GenericArgument::Type(error_type) = &args.args[1] {
911 return Some(error_type.clone());
912 }
913 }
914 }
915 }
916 }
917 }
918 }
919 None
920}
921
922/// Generate error conversion code based on the error type
923fn generate_error_conversion(error_type: &Option<syn::Type>) -> proc_macro2::TokenStream {
924 let Some(_err_type) = error_type else {
925 // No error type specified - use standard Into conversion
926 // This relies on the user's type implementing Into<ToolError>
927 return quote! { error.into() };
928 };
929
930 // REFACTORED: Strict error handling - only use Into<ToolError> trait
931 //
932 // ERROR CONVERSION LOGIC:
933 // All error types must implement Into<ToolError> to be used with the #[tool] macro.
934 // This enforces a consistent error handling pattern and encourages users to:
935 //
936 // 1. Define custom error enums with #[derive(IntoToolError)] for automatic conversion
937 // 2. Manually implement From<MyError> for ToolError for fine-grained control
938 // 3. Wrap standard library errors in custom types that provide better context
939 //
940 // This removes the special-case handling for std::io::Error and reqwest::Error
941 // that previously existed as fallbacks. Users must now explicitly handle these
942 // error types by wrapping them in custom error enums.
943 //
944 // IMPORTANT: If your function returns Result<T, std::io::Error> or similar,
945 // the compilation will now fail with a clear error message directing you to
946 // implement Into<ToolError> for your error type.
947
948 // Use Into<ToolError> conversion for all error types
949 // If the error type doesn't implement Into<ToolError>, this will produce
950 // a compile error with a clear message about the missing trait implementation
951 quote! { error.into() }
952}
953
954fn handle_function(function: ItemFn, tool_attrs: ToolAttr) -> proc_macro2::TokenStream {
955 let fn_name = &function.sig.ident;
956 let fn_vis = &function.vis;
957
958 // Extract documentation from function
959 let description = extract_doc_comments(&function.attrs);
960 let selected_description = match tool_attrs.description {
961 Some(desc) => desc,
962 None => description,
963 };
964
965 // Partition parameters into user_params and context_params
966 let mut user_params = Vec::new();
967 let mut context_params = Vec::new();
968
969 for input in function.sig.inputs.iter() {
970 if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = input {
971 if is_context_param(ty) {
972 context_params.push((pat, ty, attrs));
973 } else {
974 user_params.push((pat, ty, attrs));
975 }
976 }
977 }
978
979 // Validate exactly one context parameter
980 if context_params.len() != 1 {
981 return syn::Error::new_spanned(
982 &function.sig,
983 "`#[tool]` functions must have exactly one parameter of type `&ApplicationContext`",
984 )
985 .to_compile_error();
986 }
987
988 // Validate function signature requirements
989 if function.sig.asyncness.is_none() {
990 return syn::Error::new_spanned(&function.sig, "`#[tool]` functions must be async")
991 .to_compile_error();
992 }
993
994 // Validate return type is Result
995 if let syn::ReturnType::Type(_, ty) = &function.sig.output {
996 if !is_result_type(ty) {
997 return syn::Error::new_spanned(
998 ty,
999 "`#[tool]` functions must return a Result<T, E> where T is serializable and E implements Into<ToolError>"
1000 ).to_compile_error();
1001 }
1002 } else {
1003 return syn::Error::new_spanned(
1004 &function.sig,
1005 "`#[tool]` functions must return a Result<T, E>",
1006 )
1007 .to_compile_error();
1008 }
1009
1010 // Validate parameter types are serializable
1011 for (pat, ty, _) in user_params.iter() {
1012 if let syn::Pat::Ident(ident) = pat.as_ref() {
1013 let param_name = &ident.ident;
1014 if !is_serializable_type(ty) {
1015 return syn::Error::new_spanned(
1016 ty,
1017 format!(
1018 "Parameter '{}' must be a serializable type. Consider using String, numbers, bool, Vec<T>, Option<T>, or custom types that implement Serialize/Deserialize",
1019 param_name
1020 )
1021 ).to_compile_error();
1022 }
1023 }
1024 }
1025
1026 // Build Args struct from user_params only
1027 let mut param_fields = Vec::new();
1028 let mut param_names = Vec::new();
1029 let mut param_docs = Vec::new();
1030
1031 for (pat, ty, attrs) in user_params.iter() {
1032 if let syn::Pat::Ident(ident) = pat.as_ref() {
1033 let param_name = &ident.ident;
1034 let param_type = ty.as_ref();
1035 let param_doc = extract_doc_comments(attrs);
1036
1037 param_names.push(param_name.clone());
1038 param_docs.push(param_doc.clone());
1039
1040 // Add documentation for the field
1041 let doc_attr = if param_doc.is_empty() {
1042 quote! { #[doc = "Parameter"] }
1043 } else {
1044 quote! { #[doc = #param_doc] }
1045 };
1046
1047 // Filter out any attributes that might cause issues
1048 // Only keep serde-related attributes
1049 let filtered_attrs: Vec<_> = attrs
1050 .iter()
1051 .filter(|attr| {
1052 if let Some(ident) = attr.path().get_ident() {
1053 let name = ident.to_string();
1054 name == "serde" || name == "schemars"
1055 } else {
1056 false
1057 }
1058 })
1059 .collect();
1060
1061 param_fields.push(quote! {
1062 #doc_attr
1063 #(#filtered_attrs)*
1064 pub #param_name: #param_type
1065 });
1066 }
1067 }
1068
1069 // Generate the struct names
1070 let tool_struct_name = syn::Ident::new(
1071 &format!("{}Tool", fn_name.to_string().to_pascal_case()),
1072 fn_name.span(),
1073 );
1074 let _args_struct_name = syn::Ident::new(&format!("{}Args", tool_struct_name), fn_name.span());
1075 let tool_fn_name = syn::Ident::new(&format!("{}_tool", fn_name), fn_name.span());
1076
1077 // Check if function is async
1078 let is_async = function.sig.asyncness.is_some();
1079 let await_token = if is_async {
1080 quote! { .await }
1081 } else {
1082 quote! {}
1083 };
1084
1085 // Build the call arguments list for the function call
1086 let mut call_args = quote! {};
1087 for input in function.sig.inputs.iter() {
1088 if let FnArg::Typed(PatType { pat, ty, .. }) = input {
1089 if is_context_param(ty) {
1090 // If it's the context param, pass the context from the execute signature
1091 call_args.extend(quote! { context, });
1092 } else if let syn::Pat::Ident(ident) = pat.as_ref() {
1093 // If it's a user param, pass it from the deserialized 'args' struct
1094 let param_name = &ident.ident;
1095 call_args.extend(quote! { args.#param_name.clone(), });
1096 }
1097 }
1098 }
1099
1100 // Generate a unique module name to avoid namespace collisions
1101 // Prefix with __riglr_tool_ to make it highly unlikely to collide with user code
1102 let module_name = syn::Ident::new(&format!("__riglr_tool_{}", fn_name), fn_name.span());
1103
1104 // Extract the error type from the function's Result<T, E>
1105 let error_type = extract_error_type(&function.sig.output);
1106
1107 // Generate the error conversion code based on the error type
1108 let error_conversion = generate_error_conversion(&error_type);
1109
1110 // Generate the error handling match arms
1111 let error_match_arms = generate_tool_error_match_arms();
1112
1113 // Generate the tool implementation with namespace
1114 quote! {
1115 // Keep the original function
1116 #function
1117
1118 // Generate all tool-related code in a module namespace
1119 #[doc = "Generated tool module containing implementation details"]
1120 #fn_vis mod #module_name {
1121 use super::*;
1122
1123 // Generate the args struct if there are parameters
1124 #[doc = "Arguments structure for the tool"]
1125 #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema, Debug, Clone)]
1126 pub struct Args {
1127 #(#param_fields),*
1128 }
1129
1130 // Generate the tool struct
1131 #[doc = "Tool implementation structure"]
1132 #[derive(Clone)]
1133 pub struct Tool;
1134
1135 impl Tool {
1136 /// Create a new instance of this tool
1137 pub fn new() -> Self {
1138 Self
1139 }
1140 }
1141
1142 impl Default for Tool {
1143 fn default() -> Self {
1144 Self::new()
1145 }
1146 }
1147
1148 // Implement the riglr_core::Tool trait
1149 #[async_trait::async_trait]
1150 impl riglr_core::Tool for Tool {
1151 /// Execute the tool with the provided parameters
1152 async fn execute(&self, params: serde_json::Value, context: &riglr_core::provider::ApplicationContext) -> Result<riglr_core::JobResult, riglr_core::ToolError> {
1153 // Parse the parameters; convert parse errors to ToolError::InvalidInput
1154 let args: Args = match serde_json::from_value(params) {
1155 Ok(v) => v,
1156 Err(e) => {
1157 // Convert parameter parsing error to ToolError and use standard error handling
1158 let tool_error = riglr_core::ToolError::invalid_input_with_source(
1159 e,
1160 "Failed to parse tool parameters"
1161 );
1162 return match tool_error {
1163 #error_match_arms
1164 };
1165 }
1166 };
1167
1168 // Call the original function with reconstructed arguments
1169 let result = super::#fn_name(#call_args)#await_token;
1170
1171 // Convert the result to JobResult
1172 match result {
1173 Ok(value) => {
1174 let json_value = serde_json::to_value(value)
1175 .map_err(|e| riglr_core::ToolError::permanent_with_source(e, "Failed to serialize result"))?;
1176 Ok(riglr_core::JobResult::Success {
1177 value: json_value,
1178 tx_hash: None,
1179 })
1180 }
1181 Err(error) => {
1182 // Convert error to ToolError using automatic mapping for known types
1183 let tool_error: riglr_core::ToolError = #error_conversion;
1184 match tool_error {
1185 #error_match_arms
1186 }
1187 }
1188 }
1189 }
1190
1191 fn name(&self) -> &str {
1192 stringify!(#fn_name)
1193 }
1194
1195 fn description(&self) -> &str {
1196 #selected_description
1197 }
1198 }
1199
1200 impl Tool {
1201 /// Get the JSON schema for this tool's parameters
1202 fn schema(&self) -> serde_json::Value {
1203 // Generate the schema for the Args struct
1204 let schema = schemars::schema_for!(Args);
1205 serde_json::to_value(schema).unwrap_or_else(|_| {
1206 // Fallback to a generic object schema if serialization fails
1207 serde_json::json!({
1208 "type": "object",
1209 "additionalProperties": true
1210 })
1211 })
1212 }
1213 }
1214
1215 // NOTE: rig::tool::Tool compatibility is handled by RigToolAdapter in riglr-agents
1216 // The adapter pattern allows us to bridge the incompatible interfaces
1217 }
1218
1219 // Create a convenience function to create an Arc<dyn Tool> using the namespaced type
1220 /// Factory function to create a new instance of the tool
1221 #fn_vis fn #tool_fn_name() -> std::sync::Arc<dyn riglr_core::Tool> {
1222 std::sync::Arc::new(#module_name::Tool::new())
1223 }
1224 }
1225}
1226
1227fn handle_struct(structure: ItemStruct, tool_attrs: ToolAttr) -> proc_macro2::TokenStream {
1228 let struct_name = &structure.ident;
1229 let struct_vis = &structure.vis;
1230
1231 // Validate that the struct meets requirements for #[tool]
1232 // Note: We can't easily validate that the struct has an execute() method at macro time
1233 // because the impl block might be defined elsewhere. Instead, we'll generate a
1234 // compile-time assertion that will fail if the method doesn't exist.
1235
1236 // Extract documentation from struct
1237 let description = extract_doc_comments(&structure.attrs);
1238 let selected_description = match tool_attrs.description {
1239 Some(desc) => desc,
1240 None => description,
1241 };
1242
1243 // Generate the error handling match arms
1244 let error_match_arms = generate_tool_error_match_arms();
1245
1246 // Generate a compile-time check for required traits
1247 let compile_time_checks = quote! {
1248 // This constant will fail to compile if the struct doesn't have the required traits
1249 const _: () = {
1250 fn assert_has_required_traits<T>()
1251 where
1252 T: Clone + serde::Serialize + serde::de::DeserializeOwned + schemars::JsonSchema,
1253 {}
1254
1255 // This will be checked when the Tool trait is implemented
1256 fn _check() {
1257 assert_has_required_traits::<#struct_name>();
1258 }
1259 };
1260 };
1261
1262 quote! {
1263 // Keep the original struct
1264 #structure
1265
1266 // Compile-time validation
1267 #compile_time_checks
1268
1269 // Implement the Tool trait
1270 #[async_trait::async_trait]
1271 impl riglr_core::Tool for #struct_name {
1272 async fn execute(&self, params: serde_json::Value, context: &riglr_core::provider::ApplicationContext) -> Result<riglr_core::JobResult, riglr_core::ToolError> {
1273 // Parse parameters into the struct; convert parse errors to ToolError::InvalidInput
1274 let args: Self = match serde_json::from_value(params) {
1275 Ok(v) => v,
1276 Err(e) => {
1277 // Convert parameter parsing error to ToolError and use standard error handling
1278 let tool_error = riglr_core::ToolError::invalid_input_with_source(
1279 e,
1280 "Failed to parse tool parameters"
1281 );
1282 return match tool_error {
1283 #error_match_arms
1284 };
1285 }
1286 };
1287
1288 // Call the execute method (expecting Result<T, ToolError>)
1289 // IMPORTANT: This will fail at compile time if the struct doesn't have an execute() method
1290 // The struct must have: pub async fn execute(&self) -> Result<T, E>
1291 // where T: Serialize and E: Into<ToolError>
1292 let result = args.execute().await;
1293
1294 // Convert the result to JobResult
1295 match result {
1296 Ok(value) => {
1297 let json_value = serde_json::to_value(value)
1298 .map_err(|e| riglr_core::ToolError::permanent_with_source(e, "Failed to serialize result"))?;
1299 Ok(riglr_core::JobResult::Success {
1300 value: json_value,
1301 tx_hash: None,
1302 })
1303 }
1304 Err(tool_error) => {
1305 // Convert any error to ToolError, then match on it
1306 let tool_error: riglr_core::ToolError = tool_error.into();
1307 match tool_error {
1308 #error_match_arms
1309 }
1310 }
1311 }
1312 }
1313
1314 fn name(&self) -> &str {
1315 stringify!(#struct_name)
1316 }
1317
1318 fn description(&self) -> &str {
1319 #selected_description
1320 }
1321 }
1322
        // Inherent impl providing the schema helper (mirrors the generated function tools)
        impl #struct_name {
            /// Get the JSON schema for this tool's parameters
            fn schema(&self) -> serde_json::Value {
                // Generate the schema for the struct itself
                let schema = schemars::schema_for!(#struct_name);
                serde_json::to_value(schema).unwrap_or_else(|_| {
                    // Fallback to a generic object schema if serialization fails
                    serde_json::json!({
                        "type": "object",
                        "additionalProperties": true
                    })
                })
            }
        }
1335
1336 // NOTE: rig::tool::Tool compatibility is handled by RigToolAdapter in riglr-agents
1337 // The adapter pattern allows us to bridge the incompatible interfaces
1338
1339 // Convenience function to create the tool
1340 impl #struct_name {
1341 #struct_vis fn as_tool(self) -> std::sync::Arc<dyn riglr_core::Tool> {
1342 std::sync::Arc::new(self)
1343 }
1344 }
1345
1346 }
1347}
1348
1349fn extract_doc_comments(attrs: &[Attribute]) -> String {
1350 let mut docs = Vec::new();
1351
1352 for attr in attrs {
1353 if attr.path().is_ident("doc") {
1354 if let syn::Meta::NameValue(meta) = &attr.meta {
1355 if let syn::Expr::Lit(syn::ExprLit {
1356 lit: syn::Lit::Str(lit_str),
1357 ..
1358 }) = &meta.value
1359 {
1360 let line = lit_str.value();
1361 // Remove leading space if present (rustdoc convention)
1362 let line = line.strip_prefix(' ').unwrap_or(&line);
1363 docs.push(line.to_string());
1364 }
1365 }
1366 }
1367 }
1368
1369 docs.join("\n").trim().to_string()
1370}
1371
1372/// Generates the common error handling match arms for ToolError to JobResult conversion
1373fn generate_tool_error_match_arms() -> proc_macro2::TokenStream {
1374 quote! {
1375 // With the new structure, we just wrap the ToolError directly
1376 // The JobResult::Failure variant now contains the full ToolError
1377 _ => Ok(riglr_core::JobResult::Failure {
1378 error: tool_error,
1379 })
1380 }
1381}
1382
1383/// Derives automatic conversion from an error enum to ToolError.
1384///
1385/// This macro generates a `From<YourError> for ToolError` implementation
1386/// that automatically classifies errors as retriable or permanent based on
1387/// naming conventions in variant names.
1388///
1389/// # Classification Rules
1390///
1391/// Errors are classified as **retriable** if their variant names contain:
1392/// - `Rpc`, `Network`, `Connection`, `Timeout`, `TooManyRequests`, `RateLimit`
1393/// - `Api` (for external API errors)
1394/// - `Http` (for HTTP-related errors)
1395///
/// Any variant that does not match a retriable pattern is classified as **permanent**
/// (the conservative default). This covers names such as:
/// - `Invalid`, `Parse`, `Serialization`, `NotFound`, `Unauthorized`
/// - `InsufficientBalance`, `InsufficientFunds`
1400///
1401/// # Best Practices
1402///
1403/// **This derive macro is the recommended way to handle custom errors for tools.** It provides:
1404/// - Automatic error classification based on variant names
1405/// - Override capabilities for fine-grained control
1406/// - Type-safe error handling
1407/// - Consistent error conversion across your codebase
1408///
1409/// Using this macro instead of string-based error handling ensures that your errors are properly
1410/// structured and can be downcast by upstream consumers for specific error handling logic.
1411///
1412/// # Custom Classification
1413///
1414/// You can override the automatic classification using attributes:
1415///
1416/// ```rust,ignore
1417/// #[derive(IntoToolError)]
1418/// enum MyError {
1419/// #[tool_error(retriable)]
1420/// CustomError(String),
1421///
1422/// #[tool_error(permanent)]
1423/// NetworkError(String), // Override default retriable classification
1424///
1425/// #[tool_error(rate_limited)]
1426/// ApiQuotaExceeded,
1427/// }
1428/// ```
1429///
1430/// # Examples
1431///
1432/// ## Recommended Usage with thiserror
1433///
1434/// ```rust,ignore
1435/// use riglr_macros::IntoToolError;
1436/// use thiserror::Error;
1437///
1438/// #[derive(Error, Debug, IntoToolError)]
1439/// enum SolanaError {
1440/// #[error("RPC error: {0}")]
1441/// RpcError(String), // Automatically retriable
1442///
1443/// #[error("Invalid address: {0}")]
1444/// InvalidAddress(String), // Automatically permanent
1445///
1446/// #[error("Network timeout")]
1447/// NetworkTimeout, // Automatically retriable
1448///
1449/// #[error("Insufficient balance")]
1450/// InsufficientBalance, // Automatically permanent
1451///
1452/// #[tool_error(retriable)]
1453/// #[error("Custom error: {0}")]
1454/// CustomError(String), // Explicitly retriable
1455/// }
1456/// ```
1457#[proc_macro_derive(IntoToolError, attributes(tool_error))]
1458pub fn derive_into_tool_error(input: TokenStream) -> TokenStream {
1459 let input = parse_macro_input!(input as DeriveInput);
1460
1461 let name = input.ident;
1462 let variants = match input.data {
1463 syn::Data::Enum(ref data) => &data.variants,
1464 _ => {
1465 return TokenStream::from(quote! {
1466 compile_error!("IntoToolError can only be derived for enums");
1467 });
1468 }
1469 };
1470
1471 let match_arms = variants.iter().map(|variant| {
1472 let variant_name = &variant.ident;
1473 let variant_name_str = variant_name.to_string();
1474
1475 // Check for explicit classification attribute
1476 let classification = variant.attrs.iter().find_map(|attr| {
1477 if attr.path().is_ident("tool_error") {
1478 attr.parse_args::<syn::Ident>().ok()
1479 } else {
1480 None
1481 }
1482 });
1483
1484 let pattern = match &variant.fields {
1485 syn::Fields::Named(_) => quote! { #name::#variant_name { .. } },
1486 syn::Fields::Unnamed(_) => quote! { #name::#variant_name(..) },
1487 syn::Fields::Unit => quote! { #name::#variant_name },
1488 };
1489
1490 let conversion = if let Some(class) = classification {
1491 match class.to_string().as_str() {
1492 "retriable" => quote! {
1493 riglr_core::ToolError::retriable_string(err.to_string())
1494 },
1495 "permanent" => quote! {
1496 riglr_core::ToolError::permanent_string(err.to_string())
1497 },
1498 "rate_limited" => quote! {
1499 riglr_core::ToolError::rate_limited_string(err.to_string())
1500 },
1501 _ => quote! {
1502 riglr_core::ToolError::permanent_string(err.to_string())
1503 },
1504 }
1505 } else {
1506 // Automatic classification based on naming conventions
1507 let retriable_patterns = [
1508 "Rpc",
1509 "Network",
1510 "Connection",
1511 "Timeout",
1512 "TooManyRequests",
1513 "RateLimit",
1514 "Api",
1515 "Http",
1516 ];
1517
1518 let is_retriable = retriable_patterns
1519 .iter()
1520 .any(|pattern| variant_name_str.contains(pattern));
1521
1522 if is_retriable {
1523 quote! { riglr_core::ToolError::retriable_string(err.to_string()) }
1524 } else {
1525 quote! { riglr_core::ToolError::permanent_string(err.to_string()) }
1526 }
1527 };
1528
1529 quote! {
1530 #pattern => #conversion
1531 }
1532 });
1533
1534 let expanded = quote! {
1535 impl From<#name> for riglr_core::ToolError {
1536 fn from(err: #name) -> Self {
1537 match err {
1538 #(#match_arms),*
1539 }
1540 }
1541 }
1542 };
1543
1544 TokenStream::from(expanded)
1545}
1546
1547#[cfg(test)]
1548mod tests {
1549 use super::*;
1550
1551 #[test]
1552 fn test_extract_doc_comments_empty() {
1553 let attrs = vec![];
1554 let result = extract_doc_comments(&attrs);
1555 assert_eq!(result, "");
1556 }
1557
1558 #[test]
1559 fn test_extract_doc_comments_with_content() {
1560 // This is a unit test for the doc comment extraction function
1561 // In a real scenario, we would need to parse actual syn::Attribute instances
1562 // For now, we test that the function handles empty attributes correctly
1563 let attrs = vec![];
1564 let result = extract_doc_comments(&attrs);
1565 assert_eq!(result, "");
1566 }
1567
1568 #[test]
1569 fn test_to_pascal_case_conversion() {
1570 // Test the heck crate functionality we use
1571 assert_eq!("test_function".to_pascal_case(), "TestFunction");
1572 assert_eq!("get_balance".to_pascal_case(), "GetBalance");
1573 assert_eq!("simple".to_pascal_case(), "Simple");
1574 }
1575
1576 // Note: Testing procedural macros typically requires integration tests
1577 // with the `trybuild` crate or similar, as unit testing proc macros
1578 // directly is challenging due to their compile-time nature.
1579 //
1580 // For comprehensive testing, we would create test files in tests/
1581 // directory that use the macro and verify compilation and behavior.
1582
1583 #[test]
1584 fn test_macro_module_exists() {
1585 // Basic test to ensure the module compiles
1586 // Compilation success is the test
1587 }
1588
1589 #[test]
1590 fn test_extract_doc_comments_single_line() {
1591 // Create a mock attribute for a single line doc comment
1592 let attr = syn::parse_quote! { #[doc = " This is a single line comment"] };
1593 let attrs = vec![attr];
1594 let result = extract_doc_comments(&attrs);
1595 assert_eq!(result, "This is a single line comment");
1596 }
1597
1598 #[test]
1599 fn test_extract_doc_comments_multiple_lines() {
1600 // Create mock attributes for multiple line doc comments
1601 let attr1 = syn::parse_quote! { #[doc = " First line"] };
1602 let attr2 = syn::parse_quote! { #[doc = " Second line"] };
1603 let attr3 = syn::parse_quote! { #[doc = " Third line"] };
1604 let attrs = vec![attr1, attr2, attr3];
1605 let result = extract_doc_comments(&attrs);
1606 assert_eq!(result, "First line\nSecond line\nThird line");
1607 }
1608
1609 #[test]
1610 fn test_extract_doc_comments_no_leading_space() {
1611 let attr = syn::parse_quote! { #[doc = "No leading space"] };
1612 let attrs = vec![attr];
1613 let result = extract_doc_comments(&attrs);
1614 assert_eq!(result, "No leading space");
1615 }
1616
1617 #[test]
1618 fn test_extract_doc_comments_mixed_with_other_attrs() {
1619 let doc_attr = syn::parse_quote! { #[doc = " Documentation comment"] };
1620 let other_attr = syn::parse_quote! { #[allow(unused)] };
1621 let attrs = vec![other_attr, doc_attr];
1622 let result = extract_doc_comments(&attrs);
1623 assert_eq!(result, "Documentation comment");
1624 }
1625
1626 #[test]
1627 fn test_extract_doc_comments_empty_doc() {
1628 let attr = syn::parse_quote! { #[doc = ""] };
1629 let attrs = vec![attr];
1630 let result = extract_doc_comments(&attrs);
1631 assert_eq!(result, "");
1632 }
1633
1634 #[test]
1635 fn test_extract_doc_comments_whitespace_only() {
1636 let attr = syn::parse_quote! { #[doc = " "] };
1637 let attrs = vec![attr];
1638 let result = extract_doc_comments(&attrs);
1639 assert_eq!(result, "");
1640 }
1641
1642 #[test]
1643 fn test_generate_tool_error_match_arms_compilation() {
1644 // Test that the generated match arms compile by checking their structure
1645 let match_arms = generate_tool_error_match_arms();
1646 let generated_string = match_arms.to_string();
1647
1648 // Check that the new simplified structure is used
1649 assert!(generated_string.contains("JobResult :: Failure"));
1650 assert!(generated_string.contains("error : tool_error"));
1651 // Verify it uses wildcard matching for simplified error handling
1652 assert!(generated_string.contains("_ =>"));
1653 }
1654
1655 #[test]
1656 fn test_tool_attr_default() {
1657 let default_attr = ToolAttr::default();
1658 assert!(default_attr.description.is_none());
1659 }
1660
1661 #[test]
1662 fn test_tool_attr_parse_empty() {
1663 let input = "";
1664 let result: Result<ToolAttr, _> = syn::parse_str(input);
1665 assert!(result.is_ok());
1666 let attr = result.unwrap();
1667 assert!(attr.description.is_none());
1668 }
1669
1670 #[test]
1671 fn test_tool_attr_parse_description() {
1672 let input = r#"description = "Test description""#;
1673 let result: Result<ToolAttr, _> = syn::parse_str(input);
1674 assert!(result.is_ok());
1675 let attr = result.unwrap();
1676 assert_eq!(attr.description, Some("Test description".to_string()));
1677 }
1678
1679 #[test]
1680 fn test_tool_attr_parse_invalid_key() {
1681 let input = r#"invalid_key = "value""#;
1682 let result: Result<ToolAttr, _> = syn::parse_str(input);
1683 assert!(result.is_err());
1684 let err = result.unwrap_err();
1685 assert!(err.to_string().contains("Unknown attribute key"));
1686 }
1687
1688 #[test]
1689 fn test_tool_attr_parse_missing_equals() {
1690 let input = "description";
1691 let result: Result<ToolAttr, _> = syn::parse_str(input);
1692 assert!(result.is_err());
1693 }
1694
1695 #[test]
1696 fn test_tool_attr_parse_wrong_value_type() {
1697 let input = "description = 123";
1698 let result: Result<ToolAttr, _> = syn::parse_str(input);
1699 assert!(result.is_err());
1700 }
1701
1702 #[test]
1703 fn test_heck_pascal_case_edge_cases() {
1704 assert_eq!("".to_pascal_case(), "");
1705 assert_eq!("a".to_pascal_case(), "A");
1706 assert_eq!("_test_".to_pascal_case(), "Test");
1707 assert_eq!("test__function".to_pascal_case(), "TestFunction");
1708 assert_eq!("UPPERCASE".to_pascal_case(), "Uppercase");
1709 assert_eq!("mixedCase".to_pascal_case(), "MixedCase");
1710 assert_eq!("123_numeric".to_pascal_case(), "123Numeric");
1711 }
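
// Sketch (assumption): how pascal-casing is expected to feed identifier
// construction — pascal-case the snake_case function name and append a
// suffix. The "Tool"/"Args" suffixes and `get_balance` are illustrative;
// `format_ident!` comes from the `quote` crate.
#[test]
fn test_pascal_case_ident_construction_sketch() {
    use quote::format_ident;

    let fn_name = "get_balance";
    let tool_ident = format_ident!("{}Tool", fn_name.to_pascal_case());
    let args_ident = format_ident!("{}Args", fn_name.to_pascal_case());

    assert_eq!(tool_ident.to_string(), "GetBalanceTool");
    assert_eq!(args_ident.to_string(), "GetBalanceArgs");
}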
1712
1713 // Test the pattern matching logic for derive_into_tool_error
1714 #[test]
1715 fn test_retriable_error_patterns() {
1716 let retriable_patterns = [
1717 "Rpc",
1718 "Network",
1719 "Connection",
1720 "Timeout",
1721 "TooManyRequests",
1722 "RateLimit",
1723 "Api",
1724 "Http",
1725 ];
1726
1727 // Each pattern should be detected by the substring check when embedded in a variant name
1728 for pattern in &retriable_patterns {
1729 let test_variant = format!("Test{}Error", pattern);
1730 assert!(retriable_patterns.iter().any(|p| test_variant.contains(p)));
1731 }
1732 }
1733
1734 #[test]
1735 fn test_permanent_error_patterns() {
1736 let permanent_variants = [
1737 "InvalidInput",
1738 "ParseError",
1739 "SerializationFailed",
1740 "NotFound",
1741 "Unauthorized",
1742 "InsufficientBalance",
1743 "InsufficientFunds",
1744 "CustomError",
1745 "UnknownError",
1746 ];
1747
1748 let retriable_patterns = [
1749 "Rpc",
1750 "Network",
1751 "Connection",
1752 "Timeout",
1753 "TooManyRequests",
1754 "RateLimit",
1755 "Api",
1756 "Http",
1757 ];
1758
1759 // Test that permanent patterns don't match retriable patterns
1760 for variant in &permanent_variants {
1761 let is_retriable = retriable_patterns
1762 .iter()
1763 .any(|pattern| variant.contains(pattern));
1764 assert!(!is_retriable, "Variant {} should not be retriable", variant);
1765 }
1766 }
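
// Sketch (assumption): the substring-based classification that
// `derive_into_tool_error` is assumed to apply to variant names, written as
// a plain helper so the pattern lists above read as a single rule. This
// helper exists only for the test below; it is not used by the macro.
fn is_retriable_variant_sketch(variant_name: &str) -> bool {
    const RETRIABLE_PATTERNS: [&str; 8] = [
        "Rpc", "Network", "Connection", "Timeout",
        "TooManyRequests", "RateLimit", "Api", "Http",
    ];
    RETRIABLE_PATTERNS
        .iter()
        .any(|pattern| variant_name.contains(pattern))
}

#[test]
fn test_is_retriable_variant_sketch() {
    assert!(is_retriable_variant_sketch("RpcConnectionError"));
    assert!(is_retriable_variant_sketch("NetworkTimeout"));
    assert!(!is_retriable_variant_sketch("InvalidInput"));
    assert!(!is_retriable_variant_sketch("Unauthorized"));
}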
1767
1768 #[test]
1769 fn test_error_match_arms_structure() {
1770 let match_arms = generate_tool_error_match_arms();
1771 let generated = match_arms.to_string();
1772
1773 // Verify the new simplified structure
1774 assert!(generated.contains("JobResult :: Failure"));
1775 assert!(generated.contains("error : tool_error"));
1776 // Check that it uses wildcard matching
1777 assert!(generated.contains("_ =>"));
1778 }
1779
1780 // Test compilation of procedural macro output (basic structure validation)
1781 #[test]
1782 fn test_proc_macro_token_stream_generation() {
1783 // Test that we can create basic token streams without panicking
1784 use quote::quote;
1785
1786 let test_tokens = quote! {
1787 #[derive(Clone)]
1788 pub struct TestTool;
1789
1790 impl TestTool {
1791 pub fn new() -> Self { Self }
1792 }
1793 };
1794
1795 assert!(!test_tokens.is_empty());
1796 }
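
// Sketch: goes one step beyond the emptiness check above by confirming that
// tokens of the generated shape parse back as Rust items. `syn::File` is
// used here purely for validation; the macro emits tokens directly and does
// not round-trip through `syn::File` itself.
#[test]
fn test_generated_struct_tokens_parse_sketch() {
    use quote::quote;

    let tokens = quote! {
        #[derive(Clone)]
        pub struct TestTool;

        impl TestTool {
            pub fn new() -> Self { Self }
        }
    };

    let parsed: syn::File = syn::parse2(tokens).expect("tokens should parse as items");
    assert_eq!(parsed.items.len(), 2);
}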
1797
1798 #[test]
1799 fn test_doc_comment_extraction_with_complex_content() {
1800 let attr1 = syn::parse_quote! { #[doc = " Complex content with \"quotes\""] };
1801 let attr2 = syn::parse_quote! { #[doc = " And special chars: &<>"] };
1802 let attr3 = syn::parse_quote! { #[doc = " Numbers: 123 and symbols: $%^"] };
1803 let attrs = vec![attr1, attr2, attr3];
1804 let result = extract_doc_comments(&attrs);
1805 assert_eq!(result, "Complex content with \"quotes\"\nAnd special chars: &<>\nNumbers: 123 and symbols: $%^");
1806 }
1807
1808 #[test]
1809 fn test_doc_comment_trimming() {
1810 let attr1 = syn::parse_quote! { #[doc = " Leading spaces"] };
1811 let attr2 = syn::parse_quote! { #[doc = ""] };
1812 let attr3 = syn::parse_quote! { #[doc = "Trailing spaces "] };
1813 let attrs = vec![attr1, attr2, attr3];
1814 let result = extract_doc_comments(&attrs);
1815 // The function strips the first space but preserves other leading spaces
1816 assert_eq!(result, "Leading spaces\n\nTrailing spaces");
1817 }
1818
1819 #[test]
1820 fn test_tool_attr_parse_description_with_quotes() {
1821 let input = r#"description = "Description with \"escaped quotes\"""#;
1822 let result: Result<ToolAttr, _> = syn::parse_str(input);
1823 assert!(result.is_ok());
1824 let attr = result.unwrap();
1825 assert_eq!(
1826 attr.description,
1827 Some("Description with \"escaped quotes\"".to_string())
1828 );
1829 }
1830
1831 #[test]
1832 fn test_tool_attr_parse_description_empty_string() {
1833 let input = r#"description = """#;
1834 let result: Result<ToolAttr, _> = syn::parse_str(input);
1835 assert!(result.is_ok());
1836 let attr = result.unwrap();
1837 assert_eq!(attr.description, Some("".to_string()));
1838 }
1839
1840 #[test]
1841 fn test_extract_doc_comments_only_doc_attrs() {
1842 // Test that only doc attributes are processed, others are ignored
1843 let doc_attr = syn::parse_quote! { #[doc = " Valid doc comment"] };
1844 let cfg_attr = syn::parse_quote! { #[cfg(test)] };
1845 let allow_attr = syn::parse_quote! { #[allow(dead_code)] };
1846 let derive_attr = syn::parse_quote! { #[derive(Clone)] };
1847
1848 let attrs = vec![cfg_attr, doc_attr, allow_attr, derive_attr];
1849 let result = extract_doc_comments(&attrs);
1850 assert_eq!(result, "Valid doc comment");
1851 }
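
// Sketch: end-to-end doc extraction from a parsed function, since the macro
// reads doc comments off the annotated item's attributes. The function body
// and names below are illustrative; parsing an `ItemFn` relies on syn's
// `full` feature, which the macro already needs to parse annotated functions.
#[test]
fn test_extract_doc_comments_from_parsed_fn_sketch() {
    let item_fn: syn::ItemFn = syn::parse_quote! {
        /// Returns the SOL balance for an address.
        ///
        /// The balance is reported in lamports.
        async fn get_balance(address: String) -> Result<u64, String> {
            Ok(0)
        }
    };

    let result = extract_doc_comments(&item_fn.attrs);
    assert_eq!(
        result,
        "Returns the SOL balance for an address.\n\nThe balance is reported in lamports."
    );
}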
1852
1853 #[test]
1854 fn test_proc_macro_attr_integration() {
1855 // Test that parsed attribute values are structured the way the tool-generation step expects
1856 let empty_attr = ToolAttr::default();
1857 let with_desc = ToolAttr {
1858 description: Some("Custom description".to_string()),
1859 };
1860
1861 // Test that attributes are properly structured
1862 assert!(empty_attr.description.is_none());
1863 assert_eq!(
1864 with_desc.description,
1865 Some("Custom description".to_string())
1866 );
1867 }
1868
1869 #[test]
1870 fn test_complex_pascal_case_scenarios() {
1871 // Test edge cases for function name to struct name conversion
1872 assert_eq!("get_user_profile".to_pascal_case(), "GetUserProfile");
1873 assert_eq!("fetch_api_data".to_pascal_case(), "FetchApiData");
1874 assert_eq!(
1875 "handle_websocket_connection".to_pascal_case(),
1876 "HandleWebsocketConnection"
1877 );
1878 assert_eq!(
1879 "process_json_response".to_pascal_case(),
1880 "ProcessJsonResponse"
1881 );
1882 assert_eq!(
1883 "validate_eth_address".to_pascal_case(),
1884 "ValidateEthAddress"
1885 );
1886 }
1887
1888 #[test]
1889 fn test_error_classification_comprehensive() {
1890 let test_cases = vec![
1891 ("RpcConnectionError", true), // Should be retriable
1892 ("NetworkTimeoutError", true), // Should be retriable
1893 ("ApiRateLimitError", true), // Should be retriable
1894 ("HttpRequestError", true), // Should be retriable
1895 ("InvalidInputError", false), // Should be permanent
1896 ("ParseError", false), // Should be permanent
1897 ("NotFoundError", false), // Should be permanent
1898 ("UnauthorizedError", false), // Should be permanent
1899 ("InsufficientBalanceError", false), // Should be permanent
1900 ("CustomBusinessError", false), // Should be permanent (default)
1901 ("DatabaseConnectionError", true), // Contains "Connection"
1902 ("TooManyRequestsError", true), // Contains "TooManyRequests"
1903 ];
1904
1905 let retriable_patterns = [
1906 "Rpc",
1907 "Network",
1908 "Connection",
1909 "Timeout",
1910 "TooManyRequests",
1911 "RateLimit",
1912 "Api",
1913 "Http",
1914 ];
1915
1916 for (variant_name, expected_retriable) in test_cases {
1917 let is_retriable = retriable_patterns
1918 .iter()
1919 .any(|pattern| variant_name.contains(pattern));
1920
1921 assert_eq!(
1922 is_retriable,
1923 expected_retriable,
1924 "Variant '{}' should be {} but was classified as {}",
1925 variant_name,
1926 if expected_retriable {
1927 "retriable"
1928 } else {
1929 "permanent"
1930 },
1931 if is_retriable {
1932 "retriable"
1933 } else {
1934 "permanent"
1935 }
1936 );
1937 }
1938 }
1939
1940 #[test]
1941 fn test_tool_attr_parse_malformed_syntax() {
1942 // Test various malformed syntax cases
1943 let test_cases = vec![
1944 "description =", // Missing value
1945 "= \"value\"", // Missing key
1946 "description \"value\"", // Missing equals
1947 "description = value", // Unquoted value
1948 "description == \"value\"", // Double equals
1949 ];
1950
1951 for input in test_cases {
1952 let result: Result<ToolAttr, _> = syn::parse_str(input);
1953 assert!(result.is_err(), "Input '{}' should fail to parse", input);
1954 }
1955 }
1956
1957 #[test]
1958 fn test_extract_doc_comments_with_non_string_meta() {
1959 // Exercise the string-literal filter in extract_doc_comments:
1960 // a well-formed doc attribute passes through unchanged
1961 let valid_doc = syn::parse_quote! { #[doc = "Valid comment"] };
1962 let attrs = vec![valid_doc];
1963 let result = extract_doc_comments(&attrs);
1964 assert_eq!(result, "Valid comment");
1965 }
1966
1967 #[test]
1968 fn test_doc_comment_joining_edge_cases() {
1969 // Test doc comment joining with various whitespace scenarios
1970 let attr1 = syn::parse_quote! { #[doc = "Line1"] };
1971 let attr2 = syn::parse_quote! { #[doc = " "] }; // Just a space
1972 let attr3 = syn::parse_quote! { #[doc = "Line3"] };
1973 let attrs = vec![attr1, attr2, attr3];
1974 let result = extract_doc_comments(&attrs);
1975 assert_eq!(result, "Line1\n\nLine3");
1976 }
1977
1978 #[test]
1979 fn test_pascal_case_with_unicode() {
1980 // Test pascal case conversion with unicode characters
1981 assert_eq!("café_function".to_pascal_case(), "CaféFunction");
1982 assert_eq!("测试_function".to_pascal_case(), "测试Function");
1983 }
1984
1985 #[test]
1986 fn test_tool_attr_description_priority() {
1987 // An explicit description is intended to take priority over doc comments; this checks the attribute values that feed that choice
1988 let explicit_desc = ToolAttr {
1989 description: Some("Explicit description".to_string()),
1990 };
1991 assert_eq!(
1992 explicit_desc.description,
1993 Some("Explicit description".to_string())
1994 );
1995
1996 let no_desc = ToolAttr { description: None };
1997 assert!(no_desc.description.is_none());
1998 }
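
// Sketch (assumption): the selection rule described above expressed as a
// standalone helper — an explicit `description = "..."` wins, otherwise the
// extracted doc comments are used. The macro performs this choice during
// expansion; this helper exists only to make the rule testable here.
fn select_description_sketch(attr: &ToolAttr, doc_comments: &str) -> String {
    attr.description
        .clone()
        .unwrap_or_else(|| doc_comments.to_string())
}

#[test]
fn test_select_description_sketch() {
    let explicit = ToolAttr {
        description: Some("Explicit description".to_string()),
    };
    let fallback = ToolAttr { description: None };

    assert_eq!(
        select_description_sketch(&explicit, "Doc comment text"),
        "Explicit description"
    );
    assert_eq!(
        select_description_sketch(&fallback, "Doc comment text"),
        "Doc comment text"
    );
}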
1999
2000 #[test]
2001 fn test_generate_match_arms_output_consistency() {
2002 // Test that generate_tool_error_match_arms produces consistent output
2003 let match_arms1 = generate_tool_error_match_arms();
2004 let match_arms2 = generate_tool_error_match_arms();
2005
2006 // Convert to strings and compare
2007 let output1 = match_arms1.to_string();
2008 let output2 = match_arms2.to_string();
2009 assert_eq!(
2010 output1, output2,
2011 "Match arms generation should be deterministic"
2012 );
2013 }
2014
2015 #[test]
2016 fn test_doc_comment_extract_path_verification() {
2017 // Test that extract_doc_comments filters attributes by path, reading only `doc`
2018 let doc_attr = syn::parse_quote! { #[doc = "Test"] };
2019 let not_doc_attr = syn::parse_quote! { #[deprecated] };
2020
2021 let attrs = vec![not_doc_attr, doc_attr];
2022 let result = extract_doc_comments(&attrs);
2023 assert_eq!(result, "Test");
2024 }
2025
2026 #[test]
2027 fn test_error_pattern_case_sensitivity() {
2028 // Test that error pattern matching is case-sensitive
2029 let case_sensitive_tests = vec![
2030 ("rpc_error", false), // lowercase 'rpc' should not match 'Rpc'
2031 ("RpcError", true), // Uppercase 'Rpc' should match
2032 ("network_issue", false), // lowercase 'network' should not match 'Network'
2033 ("NetworkIssue", true), // Uppercase 'Network' should match
2034 ];
2035
2036 let retriable_patterns = [
2037 "Rpc",
2038 "Network",
2039 "Connection",
2040 "Timeout",
2041 "TooManyRequests",
2042 "RateLimit",
2043 "Api",
2044 "Http",
2045 ];
2046
2047 for (variant_name, expected_match) in case_sensitive_tests {
2048 let matches = retriable_patterns
2049 .iter()
2050 .any(|pattern| variant_name.contains(pattern));
2051 assert_eq!(
2052 matches, expected_match,
2053 "Case sensitivity test failed for '{}'",
2054 variant_name
2055 );
2056 }
2057 }
2058
2059 #[test]
2060 fn test_tool_attr_parse_lookahead_logic() {
2061 // Test the lookahead logic in ToolAttr::parse
2062 let valid_input = "description = \"test\"";
2063 let result: Result<ToolAttr, _> = syn::parse_str(valid_input);
2064 assert!(result.is_ok());
2065
2066 // Test with invalid identifier that triggers lookahead error
2067 let invalid_input = "123invalid = \"test\"";
2068 let result: Result<ToolAttr, _> = syn::parse_str(invalid_input);
2069 assert!(result.is_err());
2070 }
2071
2072 #[test]
2073 fn test_comprehensive_error_variant_naming() {
2074 // Comprehensive test of error variant naming patterns
2075 let comprehensive_tests = vec![
2076 // Retriable patterns
2077 ("SolanaRpcError", true),
2078 ("EthereumNetworkTimeout", true),
2079 ("DatabaseConnectionLost", true),
2080 ("ApiRateLimitExceeded", true),
2081 ("HttpRequestFailed", true),
2082 ("WebSocketConnectionDropped", true),
2083 ("RedisConnectionTimeout", true),
2084 ("TooManyRequestsReceived", true),
2085 // Permanent patterns
2086 ("InvalidAddressFormat", false),
2087 ("ParseJsonError", false),
2088 ("SerializationFailure", false),
2089 ("UserNotFound", false),
2090 ("UnauthorizedAccess", false),
2091 ("InsufficientTokenBalance", false),
2092 ("InsufficientGasFunds", false),
2093 ("MalformedInput", false),
2094 ("ConfigurationError", false),
2095 ("BusinessLogicViolation", false),
2096 ];
2097
2098 let retriable_patterns = [
2099 "Rpc",
2100 "Network",
2101 "Connection",
2102 "Timeout",
2103 "TooManyRequests",
2104 "RateLimit",
2105 "Api",
2106 "Http",
2107 ];
2108
2109 for (variant_name, expected_retriable) in comprehensive_tests {
2110 let is_retriable = retriable_patterns
2111 .iter()
2112 .any(|pattern| variant_name.contains(pattern));
2113
2114 assert_eq!(
2115 is_retriable,
2116 expected_retriable,
2117 "Comprehensive error classification failed for '{}' - expected {}, got {}",
2118 variant_name,
2119 if expected_retriable {
2120 "retriable"
2121 } else {
2122 "permanent"
2123 },
2124 if is_retriable {
2125 "retriable"
2126 } else {
2127 "permanent"
2128 }
2129 );
2130 }
2131 }
2132
2133 #[test]
2134 fn test_empty_and_whitespace_edge_cases() {
2135 // Test various empty and whitespace scenarios
2136 let empty_attrs: Vec<syn::Attribute> = vec![];
2137 assert_eq!(extract_doc_comments(&empty_attrs), "");
2138
2139 // Test with only whitespace doc
2140 let whitespace_attr = syn::parse_quote! { #[doc = " \t\n "] };
2141 let result = extract_doc_comments(&vec![whitespace_attr]);
2142 assert_eq!(result.trim(), "");
2143
2144 // Test pascal case with empty string
2145 assert_eq!("".to_pascal_case(), "");
2146 }
2147
2148 #[test]
2149 fn test_parameter_parsing_error_handling() {
2150 // Test that parameter parsing errors are converted to ToolError::InvalidInput
2151 // and use the standard error matching logic
2152
2153 // Produce a real serde_json::Error by attempting to parse invalid JSON
2154 let invalid_json = "{ invalid json }";
2155 let parse_result: Result<serde_json::Value, serde_json::Error> =
2156 serde_json::from_str(invalid_json);
2157 assert!(parse_result.is_err());
2158
2159 let error = parse_result.unwrap_err();
2160
2161 // Verify that we can create a ToolError::InvalidInput from the serde error
2162 use riglr_core::ToolError;
2163 let tool_error =
2164 ToolError::invalid_input_with_source(error, "Failed to parse tool parameters");
2165
2166 // Verify properties of the error
2167 assert!(!tool_error.is_retriable());
2168 assert!(!tool_error.is_rate_limited());
2169 assert_eq!(tool_error.retry_after(), None);
2170
2171 // Verify the error message contains expected content
2172 let error_str = tool_error.to_string();
2173 assert!(error_str.contains("Invalid input"));
2174 assert!(error_str.contains("Failed to parse tool parameters"));
2175 }
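
// Sketch (assumption): mirrors the conversion the generated `execute` body
// is expected to perform when parameter deserialization fails. The target
// type and JSON payload are illustrative; only
// `ToolError::invalid_input_with_source` is taken from the test above.
#[test]
fn test_parameter_parse_failure_conversion_sketch() {
    use riglr_core::ToolError;
    use std::collections::HashMap;

    // A payload whose value cannot deserialize into u64.
    let params = serde_json::json!({ "amount": "not-a-number" });

    let result: Result<HashMap<String, u64>, ToolError> = serde_json::from_value(params)
        .map_err(|e| ToolError::invalid_input_with_source(e, "Failed to parse tool parameters"));

    let err = result.unwrap_err();
    assert!(!err.is_retriable());
    assert!(err.to_string().contains("Invalid input"));
}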
2176
2177 #[test]
2178 fn test_tool_error_match_arms_invalid_input_handling() {
2179 // Test that the generated match arms handle all errors with simplified structure
2180 let match_arms = generate_tool_error_match_arms();
2181 let generated = match_arms.to_string();
2182
2183 // Verify the simplified structure handles all errors uniformly
2184 assert!(generated.contains("JobResult :: Failure"));
2185 assert!(generated.contains("error : tool_error"));
2186 // Verify it uses wildcard matching for all error types
2187 assert!(generated.contains("_ =>"));
2188 }
2189}