//! Agent bridge: connects the limit-cli REPL to the limit-agent executor
//! and the limit-llm client. (limit_cli/agent_bridge.rs)
1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::{
4    AstGrepTool, BashTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool, GitCloneTool,
5    GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool, GrepTool,
6    LspTool,
7};
8use futures::StreamExt;
9use limit_agent::executor::{ToolCall, ToolExecutor};
10use limit_agent::registry::ToolRegistry;
11use limit_llm::providers::LlmProvider;
12use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
13use limit_llm::ProviderFactory;
14use limit_llm::ProviderResponseChunk;
15use limit_llm::TrackingDb;
16use serde_json::json;
17use tokio::sync::mpsc;
18use tracing::{debug, instrument};
19
/// Event types for streaming from agent to REPL
///
/// Emitted over the unbounded channel installed via `AgentBridge::set_event_tx`
/// so the REPL/TUI can render progress while `process_message` runs.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum AgentEvent {
    /// The bridge is about to call the LLM (sent at the top of each loop turn).
    Thinking,
    /// A tool is about to be executed with the given JSON arguments.
    ToolStart {
        name: String,
        args: serde_json::Value,
    },
    /// A tool finished; `result` is the JSON-serialized output (or `{"error": ...}`).
    ToolComplete {
        name: String,
        result: String,
    },
    /// A streamed chunk of assistant text.
    ContentChunk(String),
    /// Processing of the current user message has finished.
    Done,
    /// A fatal LLM error occurred; the same message is returned as `Err`.
    Error(String),
    /// Token counts reported by the provider for one completed request.
    TokenUsage {
        input_tokens: u64,
        output_tokens: u64,
    },
}
41
/// Bridge connecting limit-cli REPL to limit-agent executor and limit-llm client
///
/// Owns the LLM client, the tool executor, and the usage-tracking database.
/// Constructed once per session via `AgentBridge::new`.
pub struct AgentBridge {
    /// LLM client for communicating with LLM providers
    llm_client: Box<dyn LlmProvider>,
    /// Tool executor for running tool calls
    executor: ToolExecutor,
    /// List of registered tool names
    /// (must stay in sync with `register_tools`; used to build LLM tool schemas)
    tool_names: Vec<&'static str>,
    /// Configuration loaded from ~/.limit/config.toml
    config: limit_llm::Config,
    /// Event sender for streaming events to REPL
    /// (`None` until `set_event_tx` is called; events are silently dropped then)
    event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
    /// Token usage tracking database
    tracking_db: TrackingDb,
}
57
58impl AgentBridge {
59    /// Create a new AgentBridge with the given configuration
60    ///
61    /// # Arguments
62    /// * `config` - LLM configuration (API key, model, etc.)
63    ///
64    /// # Returns
65    /// A new AgentBridge instance or an error if initialization fails
66    pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
67        let llm_client = ProviderFactory::create_provider(&config)
68            .map_err(|e| CliError::ConfigError(e.to_string()))?;
69
70        let mut tool_registry = ToolRegistry::new();
71        Self::register_tools(&mut tool_registry);
72
73        // Create executor (which takes ownership of registry as Arc)
74        let executor = ToolExecutor::new(tool_registry);
75
76        // Generate tool definitions before giving ownership to executor
77        let tool_names = vec![
78            "file_read",
79            "file_write",
80            "file_edit",
81            "bash",
82            "git_status",
83            "git_diff",
84            "git_log",
85            "git_add",
86            "git_commit",
87            "git_push",
88            "git_pull",
89            "git_clone",
90            "grep",
91            "ast_grep",
92            "lsp",
93        ];
94
95        Ok(Self {
96            llm_client,
97            executor,
98            tool_names,
99            config,
100            event_tx: None,
101            tracking_db: TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?,
102        })
103    }
104
105    /// Set the event channel sender for streaming events
106    pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
107        self.event_tx = Some(tx);
108    }
109
110    /// Register all CLI tools into the tool registry
111    fn register_tools(registry: &mut ToolRegistry) {
112        // File tools
113        registry
114            .register(FileReadTool::new())
115            .expect("Failed to register file_read");
116        registry
117            .register(FileWriteTool::new())
118            .expect("Failed to register file_write");
119        registry
120            .register(FileEditTool::new())
121            .expect("Failed to register file_edit");
122
123        // Bash tool
124        registry
125            .register(BashTool::new())
126            .expect("Failed to register bash");
127
128        // Git tools
129        registry
130            .register(GitStatusTool::new())
131            .expect("Failed to register git_status");
132        registry
133            .register(GitDiffTool::new())
134            .expect("Failed to register git_diff");
135        registry
136            .register(GitLogTool::new())
137            .expect("Failed to register git_log");
138        registry
139            .register(GitAddTool::new())
140            .expect("Failed to register git_add");
141        registry
142            .register(GitCommitTool::new())
143            .expect("Failed to register git_commit");
144        registry
145            .register(GitPushTool::new())
146            .expect("Failed to register git_push");
147        registry
148            .register(GitPullTool::new())
149            .expect("Failed to register git_pull");
150        registry
151            .register(GitCloneTool::new())
152            .expect("Failed to register git_clone");
153
154        // Analysis tools
155        registry
156            .register(GrepTool::new())
157            .expect("Failed to register grep");
158        registry
159            .register(AstGrepTool::new())
160            .expect("Failed to register ast_grep");
161        registry
162            .register(LspTool::new())
163            .expect("Failed to register lsp");
164    }
165
166    /// Process a user message through the LLM and execute any tool calls
167    ///
168    /// # Arguments
169    /// * `user_input` - The user's message to process
170    /// * `messages` - The conversation history (will be updated in place)
171    ///
172    /// # Returns
173    /// The final response from the LLM or an error
174    #[instrument(skip(self, _messages))]
175    pub async fn process_message(
176        &mut self,
177        user_input: &str,
178        _messages: &mut Vec<Message>,
179    ) -> Result<String, CliError> {
180        // Add system message if this is the first message in the conversation
181        // Note: Some providers (z.ai) don't support system role, but OpenAI-compatible APIs generally do
182        if _messages.is_empty() {
183            let system_message = Message {
184                role: Role::System,
185                content: Some(SYSTEM_PROMPT.to_string()),
186                tool_calls: None,
187                tool_call_id: None,
188            };
189            _messages.push(system_message);
190        }
191
192        // Add user message to history
193        let user_message = Message {
194            role: Role::User,
195            content: Some(user_input.to_string()),
196            tool_calls: None,
197            tool_call_id: None,
198        };
199        _messages.push(user_message);
200
201        // Get tool definitions
202        let tool_definitions = self.get_tool_definitions();
203
204        // Main processing loop
205        let mut full_response = String::new();
206        let mut tool_calls: Vec<LlmToolCall> = Vec::new();
207        let max_iterations = 30; // Allow enough iterations for complex tasks
208        let mut iteration = 0;
209
210        while iteration < max_iterations {
211            iteration += 1;
212            debug!("Agent loop iteration {}", iteration);
213
214            // Send thinking event
215            self.send_event(AgentEvent::Thinking);
216
217            // Track timing for token usage
218            let request_start = std::time::Instant::now();
219
220            // Call LLM
221            let mut stream = self
222                .llm_client
223                .send(_messages.clone(), tool_definitions.clone())
224                .await
225                .map_err(|e| CliError::ConfigError(e.to_string()))?;
226
227            tool_calls.clear();
228            let mut current_content = String::new();
229            // Track tool calls: (id) -> (name, args)
230            let mut accumulated_calls: std::collections::HashMap<
231                String,
232                (String, serde_json::Value),
233            > = std::collections::HashMap::new();
234
235            // Process stream chunks
236            while let Some(chunk_result) = stream.next().await {
237                match chunk_result {
238                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
239                        current_content.push_str(&text);
240                        self.send_event(AgentEvent::ContentChunk(text));
241                    }
242                    Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
243                        // Ignore reasoning chunks for now
244                    }
245                    Ok(ProviderResponseChunk::ToolCallDelta {
246                        id,
247                        name,
248                        arguments,
249                    }) => {
250                        debug!("ToolCallDelta: id={}, name={}", id, name);
251                        // Store/merge tool call arguments
252                        accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
253                    }
254                    Ok(ProviderResponseChunk::Done(usage)) => {
255                        // Track token usage
256                        let duration_ms = request_start.elapsed().as_millis() as u64;
257                        let cost =
258                            calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
259                        let _ = self.tracking_db.track_request(
260                            self.model(),
261                            usage.input_tokens,
262                            usage.output_tokens,
263                            cost,
264                            duration_ms,
265                        );
266                        // Emit token usage event for TUI display
267                        self.send_event(AgentEvent::TokenUsage {
268                            input_tokens: usage.input_tokens,
269                            output_tokens: usage.output_tokens,
270                        });
271                        break;
272                    }
273                    Err(e) => {
274                        let error_msg = format!("LLM error: {}", e);
275                        self.send_event(AgentEvent::Error(error_msg.clone()));
276                        return Err(CliError::ConfigError(error_msg));
277                    }
278                }
279            }
280
281            // Convert accumulated calls to Vec<ToolCall>
282            tool_calls = accumulated_calls
283                .into_iter()
284                .map(|(id, (name, args))| LlmToolCall {
285                    id,
286                    tool_type: "function".to_string(),
287                    function: limit_llm::types::FunctionCall {
288                        name,
289                        arguments: args.to_string(),
290                    },
291                })
292                .collect();
293            full_response.push_str(&current_content);
294
295            debug!(
296                "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
297                iteration,
298                current_content.len(),
299                tool_calls.len(),
300                full_response.len()
301            );
302
303            // If no tool calls, we're done
304            if tool_calls.is_empty() {
305                break;
306            }
307
308            // Execute tool calls - add assistant message with tool_calls
309            // Note: Per OpenAI API spec, when tool_calls are present, content should be null
310            let assistant_message = Message {
311                role: Role::Assistant,
312                content: None, // Don't include content when tool_calls are present
313                tool_calls: Some(tool_calls.clone()),
314                tool_call_id: None,
315            };
316            _messages.push(assistant_message);
317
318            // Convert LLM tool calls to executor tool calls
319            let executor_calls: Vec<ToolCall> = tool_calls
320                .iter()
321                .map(|tc| {
322                    let args: serde_json::Value =
323                        serde_json::from_str(&tc.function.arguments).unwrap_or_default();
324                    ToolCall::new(&tc.id, &tc.function.name, args)
325                })
326                .collect();
327
328            // Send ToolStart event for each tool BEFORE execution
329            for tc in &tool_calls {
330                let args: serde_json::Value =
331                    serde_json::from_str(&tc.function.arguments).unwrap_or_default();
332                self.send_event(AgentEvent::ToolStart {
333                    name: tc.function.name.clone(),
334                    args,
335                });
336            }
337            // Execute tools
338            let results = self.executor.execute_tools(executor_calls).await;
339
340            // Add tool results to messages (OpenAI format: role=tool, tool_call_id, content)
341            for result in results {
342                let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
343                if let Some(tool_call) = tool_call {
344                    let output_json = match &result.output {
345                        Ok(value) => {
346                            serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
347                        }
348                        Err(e) => json!({ "error": e.to_string() }).to_string(),
349                    };
350
351                    self.send_event(AgentEvent::ToolComplete {
352                        name: tool_call.function.name.clone(),
353                        result: output_json.clone(),
354                    });
355
356                    // OpenAI tool result format
357                    let tool_result_message = Message {
358                        role: Role::Tool,
359                        content: Some(output_json),
360                        tool_calls: None,
361                        tool_call_id: Some(result.call_id),
362                    };
363                    _messages.push(tool_result_message);
364                }
365            }
366        }
367
368        // If we hit max iterations, make one final request to get a response (no tools = forced text)
369        if iteration >= max_iterations && !_messages.is_empty() {
370            debug!("Making final LLM call after hitting max iterations (forcing text response)");
371
372            // Add constraint message to force text response
373            let constraint_message = Message {
374                role: Role::User,
375                content: Some(
376                    "We've reached the iteration limit. Please provide a summary of:\n\
377                    1. What you've completed so far\n\
378                    2. What remains to be done\n\
379                    3. Recommended next steps for the user to continue"
380                        .to_string(),
381                ),
382                tool_calls: None,
383                tool_call_id: None,
384            };
385            _messages.push(constraint_message);
386
387            // Send with NO tools to force text response
388            let no_tools: Vec<LlmTool> = vec![];
389            let mut stream = self
390                .llm_client
391                .send(_messages.clone(), no_tools)
392                .await
393                .map_err(|e| CliError::ConfigError(e.to_string()))?;
394
395            while let Some(chunk_result) = stream.next().await {
396                match chunk_result {
397                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
398                        full_response.push_str(&text);
399                        self.send_event(AgentEvent::ContentChunk(text));
400                    }
401                    Ok(ProviderResponseChunk::Done(_)) => {
402                        break;
403                    }
404                    Err(e) => {
405                        debug!("Error in final LLM call: {}", e);
406                        break;
407                    }
408                    _ => {}
409                }
410            }
411        }
412
413        self.send_event(AgentEvent::Done);
414        Ok(full_response)
415    }
416
417    /// Get tool definitions formatted for the LLM
418    pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
419        self.tool_names
420            .iter()
421            .map(|name| {
422                let (description, parameters) = Self::get_tool_schema(name);
423                LlmTool {
424                    tool_type: "function".to_string(),
425                    function: limit_llm::types::ToolFunction {
426                        name: name.to_string(),
427                        description,
428                        parameters,
429                    },
430                }
431            })
432            .collect()
433    }
434
    /// Get the schema (description and parameters) for a tool
    ///
    /// Returns a human-readable description plus a JSON-Schema `object` for the
    /// tool's parameters, in the OpenAI function-calling format. Unknown tool
    /// names fall through to a generic, parameterless schema.
    fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
        match name {
            // --- File tools ---
            "file_read" => (
                "Read the contents of a file".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to read"
                        }
                    },
                    "required": ["path"]
                }),
            ),
            "file_write" => (
                "Write content to a file, creating parent directories if needed".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to write"
                        },
                        "content": {
                            "type": "string",
                            "description": "Content to write to the file"
                        }
                    },
                    "required": ["path", "content"]
                }),
            ),
            "file_edit" => (
                "Replace text in a file with new text".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to edit"
                        },
                        "old_text": {
                            "type": "string",
                            "description": "Text to find and replace"
                        },
                        "new_text": {
                            "type": "string",
                            "description": "New text to replace with"
                        }
                    },
                    "required": ["path", "old_text", "new_text"]
                }),
            ),
            // --- Shell ---
            "bash" => (
                "Execute a bash command in a shell".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "Bash command to execute"
                        },
                        "workdir": {
                            "type": "string",
                            "description": "Working directory (default: current directory)"
                        },
                        "timeout": {
                            "type": "integer",
                            "description": "Timeout in seconds (default: 60)"
                        }
                    },
                    "required": ["command"]
                }),
            ),
            // --- Git tools ---
            "git_status" => (
                "Get git repository status".to_string(),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
            "git_diff" => (
                "Get git diff".to_string(),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
            "git_log" => (
                "Get git commit log".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "count": {
                            "type": "integer",
                            "description": "Number of commits to show (default: 10)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_add" => (
                "Add files to git staging area".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "files": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "List of file paths to add"
                        }
                    },
                    "required": ["files"]
                }),
            ),
            "git_commit" => (
                "Create a git commit".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Commit message"
                        }
                    },
                    "required": ["message"]
                }),
            ),
            "git_push" => (
                "Push commits to remote repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "remote": {
                            "type": "string",
                            "description": "Remote name (default: origin)"
                        },
                        "branch": {
                            "type": "string",
                            "description": "Branch name (default: current branch)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_pull" => (
                "Pull changes from remote repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "remote": {
                            "type": "string",
                            "description": "Remote name (default: origin)"
                        },
                        "branch": {
                            "type": "string",
                            "description": "Branch name (default: current branch)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_clone" => (
                "Clone a git repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "url": {
                            "type": "string",
                            "description": "Repository URL to clone"
                        },
                        "directory": {
                            "type": "string",
                            "description": "Directory to clone into (optional)"
                        }
                    },
                    "required": ["url"]
                }),
            ),
            // --- Analysis tools ---
            "grep" => (
                "Search for text patterns in files using regex".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "pattern": {
                            "type": "string",
                            "description": "Regex pattern to search for"
                        },
                        "path": {
                            "type": "string",
                            "description": "Path to search in (default: current directory)"
                        }
                    },
                    "required": ["pattern"]
                }),
            ),
            "ast_grep" => (
                "Search code using AST patterns (structural code matching)".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "pattern": {
                            "type": "string",
                            "description": "AST pattern to match"
                        },
                        "language": {
                            "type": "string",
                            "description": "Programming language (rust, typescript, python)"
                        },
                        "path": {
                            "type": "string",
                            "description": "Path to search in (default: current directory)"
                        }
                    },
                    "required": ["pattern", "language"]
                }),
            ),
            "lsp" => (
                "Perform Language Server Protocol operations (goto_definition, find_references)"
                    .to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "LSP command: goto_definition or find_references"
                        },
                        "file_path": {
                            "type": "string",
                            "description": "Path to the file"
                        },
                        "position": {
                            "type": "object",
                            "description": "Position in the file (line, character)",
                            "properties": {
                                "line": {"type": "integer"},
                                "character": {"type": "integer"}
                            },
                            "required": ["line", "character"]
                        }
                    },
                    "required": ["command", "file_path", "position"]
                }),
            ),
            // Fallback: a generic, parameterless schema for unrecognized names.
            _ => (
                format!("Tool: {}", name),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
        }
    }
692
693    /// Send an event through the event channel
694    fn send_event(&self, event: AgentEvent) {
695        if let Some(ref tx) = self.event_tx {
696            let _ = tx.send(event);
697        }
698    }
699
700    /// Check if the bridge is ready to process messages
701    #[allow(dead_code)]
702    pub fn is_ready(&self) -> bool {
703        self.config
704            .providers
705            .get(&self.config.provider)
706            .map(|p| p.api_key_or_env(&self.config.provider).is_some())
707            .unwrap_or(false)
708    }
709
710    /// Get the current model name
711    pub fn model(&self) -> &str {
712        self.config
713            .providers
714            .get(&self.config.provider)
715            .map(|p| p.model.as_str())
716            .unwrap_or("")
717    }
718
719    /// Get the max tokens setting
720    pub fn max_tokens(&self) -> u32 {
721        self.config
722            .providers
723            .get(&self.config.provider)
724            .map(|p| p.max_tokens)
725            .unwrap_or(4096)
726    }
727
728    /// Get the timeout setting
729    pub fn timeout(&self) -> u64 {
730        self.config
731            .providers
732            .get(&self.config.provider)
733            .map(|p| p.timeout)
734            .unwrap_or(60)
735    }
736}
/// Calculate cost based on model pricing (per 1M tokens)
///
/// Returns the USD cost for a single request; unrecognized models are not
/// priced and cost 0.0.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    // Per-million-token (input, output) USD rates for known models.
    let (input_price, output_price): (f64, f64) = match model {
        // Claude 3.5 Sonnet: $3/1M input, $15/1M output
        "claude-3-5-sonnet-20241022" | "claude-3-5-sonnet" => (3.0, 15.0),
        // GPT-4: $30/1M input, $60/1M output
        "gpt-4" => (30.0, 60.0),
        // GPT-4 Turbo: $10/1M input, $30/1M output
        "gpt-4-turbo" | "gpt-4-turbo-preview" => (10.0, 30.0),
        // Default: no cost tracking
        _ => return 0.0,
    };
    let input_cost = input_tokens as f64 * input_price / 1_000_000.0;
    let output_cost = output_tokens as f64 * output_price / 1_000_000.0;
    input_cost + output_cost
}
752
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Build a single-provider ("anthropic") config for tests.
    ///
    /// `api_key: None` simulates a missing credential. Extracted because the
    /// identical fixture was previously duplicated in four tests.
    fn test_config(api_key: Option<&str>) -> LlmConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "anthropic".to_string(),
            ProviderConfig {
                api_key: api_key.map(str::to_string),
                model: "claude-3-5-sonnet-20241022".to_string(),
                base_url: None,
                max_tokens: 4096,
                timeout: 60,
            },
        );
        LlmConfig {
            provider: "anthropic".to_string(),
            providers,
        }
    }

    #[tokio::test]
    async fn test_agent_bridge_new() {
        let bridge = AgentBridge::new(test_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }

    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        // Provider creation should fail without an API key.
        let result = AgentBridge::new(test_config(None));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let bridge = AgentBridge::new(test_config(Some("test-key"))).unwrap();
        let definitions = bridge.get_tool_definitions();

        // One definition per registered tool name.
        assert_eq!(definitions.len(), 15);

        // Check file_read tool definition
        let file_read = definitions
            .iter()
            .find(|d| d.function.name == "file_read")
            .unwrap();
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        // Check bash tool definition
        let bash = definitions
            .iter()
            .find(|d| d.function.name == "bash")
            .unwrap();
        assert_eq!(bash.function.name, "bash");
        assert!(bash.function.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&"command".into()));
    }

    #[test]
    fn test_get_tool_schema() {
        let (desc, params) = AgentBridge::get_tool_schema("file_read");
        assert!(desc.contains("Read"));
        assert_eq!(params["properties"]["path"]["type"], "string");
        assert!(params["required"]
            .as_array()
            .unwrap()
            .contains(&"path".into()));

        let (desc, params) = AgentBridge::get_tool_schema("bash");
        assert!(desc.contains("bash"));
        assert_eq!(params["properties"]["command"]["type"], "string");

        // Unknown names fall back to the generic schema.
        let (desc, _params) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(desc.contains("unknown_tool"));
    }

    #[test]
    fn test_is_ready() {
        let bridge = AgentBridge::new(test_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }
}