//! `limit_cli` agent bridge (`agent_bridge.rs`): connects the REPL to the
//! `limit-agent` tool executor and the `limit-llm` provider client.
1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::{
4    AstGrepTool, BashTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool, GitCloneTool,
5    GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool, GrepTool,
6    LspTool,
7};
8use futures::StreamExt;
9use limit_agent::executor::{ToolCall, ToolExecutor};
10use limit_agent::registry::ToolRegistry;
11use limit_llm::providers::LlmProvider;
12use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
13use limit_llm::ProviderFactory;
14use limit_llm::ProviderResponseChunk;
15use limit_llm::TrackingDb;
16use serde_json::json;
17use tokio::sync::mpsc;
18use tracing::{debug, instrument};
19
/// Event types for streaming from agent to REPL
#[derive(Debug, Clone)]
#[allow(dead_code)] // not every variant is constructed on every code path
pub enum AgentEvent {
    /// An LLM request has started; the agent is waiting for output.
    Thinking,
    /// A tool is about to be executed (sent before execution starts).
    ToolStart {
        name: String,
        args: serde_json::Value,
    },
    /// A tool finished; `result` is its JSON-serialized output or error.
    ToolComplete {
        name: String,
        result: String,
    },
    /// A chunk of streamed assistant text.
    ContentChunk(String),
    /// Processing of the user message finished (success path).
    Done,
    /// An error occurred; the message is already formatted for display.
    Error(String),
    /// Token counts reported by the provider for one completed request.
    TokenUsage {
        input_tokens: u64,
        output_tokens: u64,
    },
}
41
/// Bridge connecting limit-cli REPL to limit-agent executor and limit-llm client
///
/// Owns the provider client, the tool executor, and an optional event channel
/// used to stream [`AgentEvent`]s back to the REPL/TUI.
pub struct AgentBridge {
    /// LLM client for communicating with LLM providers
    llm_client: Box<dyn LlmProvider>,
    /// Tool executor for running tool calls
    executor: ToolExecutor,
    /// List of registered tool names (must match `register_tools`)
    tool_names: Vec<&'static str>,
    /// Configuration loaded from ~/.limit/config.toml
    config: limit_llm::Config,
    /// Event sender for streaming events to REPL (None until `set_event_tx`)
    event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
    /// Token usage tracking database
    tracking_db: TrackingDb,
}
57
58impl AgentBridge {
59    /// Create a new AgentBridge with the given configuration
60    ///
61    /// # Arguments
62    /// * `config` - LLM configuration (API key, model, etc.)
63    ///
64    /// # Returns
65    /// A new AgentBridge instance or an error if initialization fails
66    pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
67        let llm_client = ProviderFactory::create_provider(&config)
68            .map_err(|e| CliError::ConfigError(e.to_string()))?;
69
70        let mut tool_registry = ToolRegistry::new();
71        Self::register_tools(&mut tool_registry);
72
73        // Create executor (which takes ownership of registry as Arc)
74        let executor = ToolExecutor::new(tool_registry);
75
76        // Generate tool definitions before giving ownership to executor
77        let tool_names = vec![
78            "file_read",
79            "file_write",
80            "file_edit",
81            "bash",
82            "git_status",
83            "git_diff",
84            "git_log",
85            "git_add",
86            "git_commit",
87            "git_push",
88            "git_pull",
89            "git_clone",
90            "grep",
91            "ast_grep",
92            "lsp",
93        ];
94
95        Ok(Self {
96            llm_client,
97            executor,
98            tool_names,
99            config,
100            event_tx: None,
101            tracking_db: TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?,
102        })
103    }
104
105    /// Set the event channel sender for streaming events
106    pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
107        self.event_tx = Some(tx);
108    }
109
110    /// Register all CLI tools into the tool registry
111    fn register_tools(registry: &mut ToolRegistry) {
112        // File tools
113        registry
114            .register(FileReadTool::new())
115            .expect("Failed to register file_read");
116        registry
117            .register(FileWriteTool::new())
118            .expect("Failed to register file_write");
119        registry
120            .register(FileEditTool::new())
121            .expect("Failed to register file_edit");
122
123        // Bash tool
124        registry
125            .register(BashTool::new())
126            .expect("Failed to register bash");
127
128        // Git tools
129        registry
130            .register(GitStatusTool::new())
131            .expect("Failed to register git_status");
132        registry
133            .register(GitDiffTool::new())
134            .expect("Failed to register git_diff");
135        registry
136            .register(GitLogTool::new())
137            .expect("Failed to register git_log");
138        registry
139            .register(GitAddTool::new())
140            .expect("Failed to register git_add");
141        registry
142            .register(GitCommitTool::new())
143            .expect("Failed to register git_commit");
144        registry
145            .register(GitPushTool::new())
146            .expect("Failed to register git_push");
147        registry
148            .register(GitPullTool::new())
149            .expect("Failed to register git_pull");
150        registry
151            .register(GitCloneTool::new())
152            .expect("Failed to register git_clone");
153
154        // Analysis tools
155        registry
156            .register(GrepTool::new())
157            .expect("Failed to register grep");
158        registry
159            .register(AstGrepTool::new())
160            .expect("Failed to register ast_grep");
161        registry
162            .register(LspTool::new())
163            .expect("Failed to register lsp");
164    }
165
166    /// Process a user message through the LLM and execute any tool calls
167    ///
168    /// # Arguments
169    /// * `user_input` - The user's message to process
170    /// * `messages` - The conversation history (will be updated in place)
171    ///
172    /// # Returns
173    /// The final response from the LLM or an error
174    #[instrument(skip(self, _messages))]
175    pub async fn process_message(
176        &mut self,
177        user_input: &str,
178        _messages: &mut Vec<Message>,
179    ) -> Result<String, CliError> {
180        // Add system message if this is the first message in the conversation
181        // Note: Some providers (z.ai) don't support system role, but OpenAI-compatible APIs generally do
182        if _messages.is_empty() {
183            let system_message = Message {
184                role: Role::System,
185                content: Some(SYSTEM_PROMPT.to_string()),
186                tool_calls: None,
187                tool_call_id: None,
188            };
189            _messages.push(system_message);
190        }
191
192        // Add user message to history
193        let user_message = Message {
194            role: Role::User,
195            content: Some(user_input.to_string()),
196            tool_calls: None,
197            tool_call_id: None,
198        };
199        _messages.push(user_message);
200
201        // Get tool definitions
202        let tool_definitions = self.get_tool_definitions();
203
204        // Main processing loop
205        let mut full_response = String::new();
206        let mut tool_calls: Vec<LlmToolCall> = Vec::new();
207        let max_iterations = self
208            .config
209            .providers
210            .get(&self.config.provider)
211            .map(|p| p.max_iterations)
212            .unwrap_or(100); // Allow enough iterations for complex tasks
213        let mut iteration = 0;
214
215        while max_iterations == 0 || iteration < max_iterations {
216            iteration += 1;
217            debug!("Agent loop iteration {}", iteration);
218
219            // Send thinking event
220            self.send_event(AgentEvent::Thinking);
221
222            // Track timing for token usage
223            let request_start = std::time::Instant::now();
224
225            // Call LLM
226            let mut stream = self
227                .llm_client
228                .send(_messages.clone(), tool_definitions.clone())
229                .await
230                .map_err(|e| CliError::ConfigError(e.to_string()))?;
231
232            tool_calls.clear();
233            let mut current_content = String::new();
234            // Track tool calls: (id) -> (name, args)
235            let mut accumulated_calls: std::collections::HashMap<
236                String,
237                (String, serde_json::Value),
238            > = std::collections::HashMap::new();
239
240            // Process stream chunks
241            while let Some(chunk_result) = stream.next().await {
242                match chunk_result {
243                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
244                        current_content.push_str(&text);
245                        self.send_event(AgentEvent::ContentChunk(text));
246                    }
247                    Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
248                        // Ignore reasoning chunks for now
249                    }
250                    Ok(ProviderResponseChunk::ToolCallDelta {
251                        id,
252                        name,
253                        arguments,
254                    }) => {
255                        debug!("ToolCallDelta: id={}, name={}", id, name);
256                        // Store/merge tool call arguments
257                        accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
258                    }
259                    Ok(ProviderResponseChunk::Done(usage)) => {
260                        // Track token usage
261                        let duration_ms = request_start.elapsed().as_millis() as u64;
262                        let cost =
263                            calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
264                        let _ = self.tracking_db.track_request(
265                            self.model(),
266                            usage.input_tokens,
267                            usage.output_tokens,
268                            cost,
269                            duration_ms,
270                        );
271                        // Emit token usage event for TUI display
272                        self.send_event(AgentEvent::TokenUsage {
273                            input_tokens: usage.input_tokens,
274                            output_tokens: usage.output_tokens,
275                        });
276                        break;
277                    }
278                    Err(e) => {
279                        let error_msg = format!("LLM error: {}", e);
280                        self.send_event(AgentEvent::Error(error_msg.clone()));
281                        return Err(CliError::ConfigError(error_msg));
282                    }
283                }
284            }
285
286            // Convert accumulated calls to Vec<ToolCall>
287            tool_calls = accumulated_calls
288                .into_iter()
289                .map(|(id, (name, args))| LlmToolCall {
290                    id,
291                    tool_type: "function".to_string(),
292                    function: limit_llm::types::FunctionCall {
293                        name,
294                        arguments: args.to_string(),
295                    },
296                })
297                .collect();
298            full_response.push_str(&current_content);
299
300            debug!(
301                "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
302                iteration,
303                current_content.len(),
304                tool_calls.len(),
305                full_response.len()
306            );
307
308            // If no tool calls, we're done
309            if tool_calls.is_empty() {
310                break;
311            }
312
313            // Execute tool calls - add assistant message with tool_calls
314            // Note: Per OpenAI API spec, when tool_calls are present, content should be null
315            let assistant_message = Message {
316                role: Role::Assistant,
317                content: None, // Don't include content when tool_calls are present
318                tool_calls: Some(tool_calls.clone()),
319                tool_call_id: None,
320            };
321            _messages.push(assistant_message);
322
323            // Convert LLM tool calls to executor tool calls
324            let executor_calls: Vec<ToolCall> = tool_calls
325                .iter()
326                .map(|tc| {
327                    let args: serde_json::Value =
328                        serde_json::from_str(&tc.function.arguments).unwrap_or_default();
329                    ToolCall::new(&tc.id, &tc.function.name, args)
330                })
331                .collect();
332
333            // Send ToolStart event for each tool BEFORE execution
334            for tc in &tool_calls {
335                let args: serde_json::Value =
336                    serde_json::from_str(&tc.function.arguments).unwrap_or_default();
337                self.send_event(AgentEvent::ToolStart {
338                    name: tc.function.name.clone(),
339                    args,
340                });
341            }
342            // Execute tools
343            let results = self.executor.execute_tools(executor_calls).await;
344
345            // Add tool results to messages (OpenAI format: role=tool, tool_call_id, content)
346            for result in results {
347                let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
348                if let Some(tool_call) = tool_call {
349                    let output_json = match &result.output {
350                        Ok(value) => {
351                            serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
352                        }
353                        Err(e) => json!({ "error": e.to_string() }).to_string(),
354                    };
355
356                    self.send_event(AgentEvent::ToolComplete {
357                        name: tool_call.function.name.clone(),
358                        result: output_json.clone(),
359                    });
360
361                    // OpenAI tool result format
362                    let tool_result_message = Message {
363                        role: Role::Tool,
364                        content: Some(output_json),
365                        tool_calls: None,
366                        tool_call_id: Some(result.call_id),
367                    };
368                    _messages.push(tool_result_message);
369                }
370            }
371        }
372
373        // If we hit max iterations, make one final request to get a response (no tools = forced text)
374        if iteration >= max_iterations && !_messages.is_empty() {
375            debug!("Making final LLM call after hitting max iterations (forcing text response)");
376
377            // Add constraint message to force text response
378            let constraint_message = Message {
379                role: Role::User,
380                content: Some(
381                    "We've reached the iteration limit. Please provide a summary of:\n\
382                    1. What you've completed so far\n\
383                    2. What remains to be done\n\
384                    3. Recommended next steps for the user to continue"
385                        .to_string(),
386                ),
387                tool_calls: None,
388                tool_call_id: None,
389            };
390            _messages.push(constraint_message);
391
392            // Send with NO tools to force text response
393            let no_tools: Vec<LlmTool> = vec![];
394            let mut stream = self
395                .llm_client
396                .send(_messages.clone(), no_tools)
397                .await
398                .map_err(|e| CliError::ConfigError(e.to_string()))?;
399
400            while let Some(chunk_result) = stream.next().await {
401                match chunk_result {
402                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
403                        full_response.push_str(&text);
404                        self.send_event(AgentEvent::ContentChunk(text));
405                    }
406                    Ok(ProviderResponseChunk::Done(_)) => {
407                        break;
408                    }
409                    Err(e) => {
410                        debug!("Error in final LLM call: {}", e);
411                        break;
412                    }
413                    _ => {}
414                }
415            }
416        }
417
418        self.send_event(AgentEvent::Done);
419        Ok(full_response)
420    }
421
422    /// Get tool definitions formatted for the LLM
423    pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
424        self.tool_names
425            .iter()
426            .map(|name| {
427                let (description, parameters) = Self::get_tool_schema(name);
428                LlmTool {
429                    tool_type: "function".to_string(),
430                    function: limit_llm::types::ToolFunction {
431                        name: name.to_string(),
432                        description,
433                        parameters,
434                    },
435                }
436            })
437            .collect()
438    }
439
    /// Get the schema (description and parameters) for a tool
    ///
    /// Returns a human-readable description plus a JSON-Schema `parameters`
    /// object in the shape the OpenAI tools API expects. Unknown names fall
    /// back to a generic description with an empty schema.
    ///
    /// NOTE(review): these schemas are maintained by hand and must stay in
    /// sync with the tool implementations registered in `register_tools` —
    /// verify when adding or changing a tool.
    fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
        match name {
            // File tools
            "file_read" => (
                "Read the contents of a file".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to read"
                        }
                    },
                    "required": ["path"]
                }),
            ),
            "file_write" => (
                "Write content to a file, creating parent directories if needed".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to write"
                        },
                        "content": {
                            "type": "string",
                            "description": "Content to write to the file"
                        }
                    },
                    "required": ["path", "content"]
                }),
            ),
            "file_edit" => (
                "Replace text in a file with new text".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to edit"
                        },
                        "old_text": {
                            "type": "string",
                            "description": "Text to find and replace"
                        },
                        "new_text": {
                            "type": "string",
                            "description": "New text to replace with"
                        }
                    },
                    "required": ["path", "old_text", "new_text"]
                }),
            ),
            // Shell tool
            "bash" => (
                "Execute a bash command in a shell".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "Bash command to execute"
                        },
                        "workdir": {
                            "type": "string",
                            "description": "Working directory (default: current directory)"
                        },
                        "timeout": {
                            "type": "integer",
                            "description": "Timeout in seconds (default: 60)"
                        }
                    },
                    "required": ["command"]
                }),
            ),
            // Git tools
            "git_status" => (
                "Get git repository status".to_string(),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
            "git_diff" => (
                "Get git diff".to_string(),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
            "git_log" => (
                "Get git commit log".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "count": {
                            "type": "integer",
                            "description": "Number of commits to show (default: 10)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_add" => (
                "Add files to git staging area".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "files": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "List of file paths to add"
                        }
                    },
                    "required": ["files"]
                }),
            ),
            "git_commit" => (
                "Create a git commit".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Commit message"
                        }
                    },
                    "required": ["message"]
                }),
            ),
            "git_push" => (
                "Push commits to remote repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "remote": {
                            "type": "string",
                            "description": "Remote name (default: origin)"
                        },
                        "branch": {
                            "type": "string",
                            "description": "Branch name (default: current branch)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_pull" => (
                "Pull changes from remote repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "remote": {
                            "type": "string",
                            "description": "Remote name (default: origin)"
                        },
                        "branch": {
                            "type": "string",
                            "description": "Branch name (default: current branch)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_clone" => (
                "Clone a git repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "url": {
                            "type": "string",
                            "description": "Repository URL to clone"
                        },
                        "directory": {
                            "type": "string",
                            "description": "Directory to clone into (optional)"
                        }
                    },
                    "required": ["url"]
                }),
            ),
            // Analysis tools
            "grep" => (
                "Search for text patterns in files using regex".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "pattern": {
                            "type": "string",
                            "description": "Regex pattern to search for"
                        },
                        "path": {
                            "type": "string",
                            "description": "Path to search in (default: current directory)"
                        }
                    },
                    "required": ["pattern"]
                }),
            ),
            "ast_grep" => (
                "Search code using AST patterns (structural code matching)".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "pattern": {
                            "type": "string",
                            "description": "AST pattern to match"
                        },
                        "language": {
                            "type": "string",
                            "description": "Programming language (rust, typescript, python)"
                        },
                        "path": {
                            "type": "string",
                            "description": "Path to search in (default: current directory)"
                        }
                    },
                    "required": ["pattern", "language"]
                }),
            ),
            "lsp" => (
                "Perform Language Server Protocol operations (goto_definition, find_references)"
                    .to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "LSP command: goto_definition or find_references"
                        },
                        "file_path": {
                            "type": "string",
                            "description": "Path to the file"
                        },
                        "position": {
                            "type": "object",
                            "description": "Position in the file (line, character)",
                            "properties": {
                                "line": {"type": "integer"},
                                "character": {"type": "integer"}
                            },
                            "required": ["line", "character"]
                        }
                    },
                    "required": ["command", "file_path", "position"]
                }),
            ),
            // Fallback for names with no hand-written schema.
            _ => (
                format!("Tool: {}", name),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
        }
    }
697
698    /// Send an event through the event channel
699    fn send_event(&self, event: AgentEvent) {
700        if let Some(ref tx) = self.event_tx {
701            let _ = tx.send(event);
702        }
703    }
704
705    /// Check if the bridge is ready to process messages
706    #[allow(dead_code)]
707    pub fn is_ready(&self) -> bool {
708        self.config
709            .providers
710            .get(&self.config.provider)
711            .map(|p| p.api_key_or_env(&self.config.provider).is_some())
712            .unwrap_or(false)
713    }
714
715    /// Get the current model name
716    pub fn model(&self) -> &str {
717        self.config
718            .providers
719            .get(&self.config.provider)
720            .map(|p| p.model.as_str())
721            .unwrap_or("")
722    }
723
724    /// Get the max tokens setting
725    pub fn max_tokens(&self) -> u32 {
726        self.config
727            .providers
728            .get(&self.config.provider)
729            .map(|p| p.max_tokens)
730            .unwrap_or(4096)
731    }
732
733    /// Get the timeout setting
734    pub fn timeout(&self) -> u64 {
735        self.config
736            .providers
737            .get(&self.config.provider)
738            .map(|p| p.timeout)
739            .unwrap_or(60)
740    }
741}
/// Calculate cost based on model pricing (per 1M tokens)
///
/// Returns the USD cost for a request; models without a known price entry
/// cost 0.0 (no tracking).
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    // USD per 1M tokens as (input, output); None means "not tracked".
    let pricing: Option<(f64, f64)> = match model {
        // Claude 3.5 Sonnet: $3/1M input, $15/1M output
        "claude-3-5-sonnet-20241022" | "claude-3-5-sonnet" => Some((3.0, 15.0)),
        // GPT-4: $30/1M input, $60/1M output
        "gpt-4" => Some((30.0, 60.0)),
        // GPT-4 Turbo: $10/1M input, $30/1M output
        "gpt-4-turbo" | "gpt-4-turbo-preview" => Some((10.0, 30.0)),
        _ => None,
    };
    match pricing {
        Some((input_price, output_price)) => {
            // Multiply before dividing, matching the original rounding order.
            input_tokens as f64 * input_price / 1_000_000.0
                + output_tokens as f64 * output_price / 1_000_000.0
        }
        None => 0.0,
    }
}
757
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Build a config containing a single "anthropic" provider with the given
    /// API key and standard test defaults.
    fn anthropic_config(api_key: Option<&str>) -> LlmConfig {
        let provider = ProviderConfig {
            api_key: api_key.map(str::to_string),
            model: "claude-3-5-sonnet-20241022".to_string(),
            base_url: None,
            max_tokens: 4096,
            timeout: 60,
            max_iterations: 100,
            thinking_enabled: false,
            clear_thinking: true,
        };
        let mut providers = HashMap::new();
        providers.insert("anthropic".to_string(), provider);
        LlmConfig {
            provider: "anthropic".to_string(),
            providers,
        }
    }

    #[tokio::test]
    async fn test_agent_bridge_new() {
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }

    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        // NOTE(review): this expects construction to FAIL without an API key,
        // while `test_get_tool_definitions` below constructs a bridge with an
        // identical key-less config and unwraps it successfully. The two
        // assumptions contradict each other unless an environment variable is
        // involved — confirm the intended `ProviderFactory` behavior.
        let result = AgentBridge::new(anthropic_config(None));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let bridge = AgentBridge::new(anthropic_config(None)).unwrap();
        let definitions = bridge.get_tool_definitions();

        // One definition per registered tool name.
        assert_eq!(definitions.len(), 15);

        // file_read is present and described as a read operation.
        let file_read = definitions
            .iter()
            .find(|d| d.function.name == "file_read")
            .unwrap();
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        // bash requires a `command` argument.
        let bash = definitions
            .iter()
            .find(|d| d.function.name == "bash")
            .unwrap();
        assert_eq!(bash.function.name, "bash");
        assert!(bash.function.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&"command".into()));
    }

    #[test]
    fn test_get_tool_schema() {
        let (desc, params) = AgentBridge::get_tool_schema("file_read");
        assert!(desc.contains("Read"));
        assert_eq!(params["properties"]["path"]["type"], "string");
        assert!(params["required"]
            .as_array()
            .unwrap()
            .contains(&"path".into()));

        let (desc, params) = AgentBridge::get_tool_schema("bash");
        assert!(desc.contains("bash"));
        assert_eq!(params["properties"]["command"]["type"], "string");

        // Unknown tools fall back to a generic schema.
        let (desc, _params) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(desc.contains("unknown_tool"));
    }

    #[test]
    fn test_is_ready() {
        // NOTE(review): duplicates the assertion in `test_agent_bridge_new`.
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }
}