//! limit_cli/agent_bridge.rs — bridge between the CLI REPL, the tool executor, and the LLM client.
1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::{
4    AstGrepTool, BashTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool, GitCloneTool,
5    GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool, GrepTool,
6    LspTool, WebFetchTool, WebSearchTool,
7};
8use chrono::Datelike;
9use futures::StreamExt;
10use limit_agent::executor::{ToolCall, ToolExecutor};
11use limit_agent::registry::ToolRegistry;
12use limit_llm::providers::LlmProvider;
13use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
14use limit_llm::ProviderFactory;
15use limit_llm::ProviderResponseChunk;
16use limit_llm::TrackingDb;
17use serde_json::json;
18use tokio::sync::mpsc;
19use tracing::{debug, instrument};
20
/// Event types for streaming from agent to REPL
#[derive(Debug, Clone)]
// NOTE(review): allow(dead_code) suggests not every variant is consumed by all
// front-ends yet — confirm before removing.
#[allow(dead_code)]
pub enum AgentEvent {
    /// An LLM request is in flight; the agent is waiting for the first chunk.
    Thinking,
    /// A tool call is about to be executed.
    ToolStart {
        name: String,
        args: serde_json::Value,
    },
    /// A tool call finished; `result` is the JSON-serialized output (or error).
    ToolComplete {
        name: String,
        result: String,
    },
    /// One streamed chunk of assistant text.
    ContentChunk(String),
    /// Processing of the current user message is complete.
    Done,
    /// A fatal error occurred while talking to the LLM.
    Error(String),
    /// Token usage reported by the provider for one request.
    TokenUsage {
        input_tokens: u64,
        output_tokens: u64,
    },
}
42
/// Bridge connecting limit-cli REPL to limit-agent executor and limit-llm client
pub struct AgentBridge {
    /// LLM client for communicating with LLM providers
    llm_client: Box<dyn LlmProvider>,
    /// Tool executor for running tool calls
    executor: ToolExecutor,
    /// List of registered tool names, in the order they are advertised to the LLM
    tool_names: Vec<&'static str>,
    /// Configuration loaded from ~/.limit/config.toml
    config: limit_llm::Config,
    /// Event sender for streaming events to REPL; `None` until `set_event_tx` is called
    event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
    /// Token usage tracking database
    tracking_db: TrackingDb,
}
58
59impl AgentBridge {
60    /// Create a new AgentBridge with the given configuration
61    ///
62    /// # Arguments
63    /// * `config` - LLM configuration (API key, model, etc.)
64    ///
65    /// # Returns
66    /// A new AgentBridge instance or an error if initialization fails
67    pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
68        let llm_client = ProviderFactory::create_provider(&config)
69            .map_err(|e| CliError::ConfigError(e.to_string()))?;
70
71        let mut tool_registry = ToolRegistry::new();
72        Self::register_tools(&mut tool_registry);
73
74        // Create executor (which takes ownership of registry as Arc)
75        let executor = ToolExecutor::new(tool_registry);
76
77        // Generate tool definitions before giving ownership to executor
78        let tool_names = vec![
79            "file_read",
80            "file_write",
81            "file_edit",
82            "bash",
83            "git_status",
84            "git_diff",
85            "git_log",
86            "git_add",
87            "git_commit",
88            "git_push",
89            "git_pull",
90            "git_clone",
91            "grep",
92            "ast_grep",
93            "lsp",
94            "web_search",
95            "web_fetch",
96        ];
97
98        Ok(Self {
99            llm_client,
100            executor,
101            tool_names,
102            config,
103            event_tx: None,
104            tracking_db: TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?,
105        })
106    }
107
    /// Set the event channel sender for streaming events
    ///
    /// Events emitted during `process_message` (thinking, tool start/complete,
    /// content chunks, token usage, errors, done) are sent through this channel.
    pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
        self.event_tx = Some(tx);
    }
112
113    /// Register all CLI tools into the tool registry
114    fn register_tools(registry: &mut ToolRegistry) {
115        // File tools
116        registry
117            .register(FileReadTool::new())
118            .expect("Failed to register file_read");
119        registry
120            .register(FileWriteTool::new())
121            .expect("Failed to register file_write");
122        registry
123            .register(FileEditTool::new())
124            .expect("Failed to register file_edit");
125
126        // Bash tool
127        registry
128            .register(BashTool::new())
129            .expect("Failed to register bash");
130
131        // Git tools
132        registry
133            .register(GitStatusTool::new())
134            .expect("Failed to register git_status");
135        registry
136            .register(GitDiffTool::new())
137            .expect("Failed to register git_diff");
138        registry
139            .register(GitLogTool::new())
140            .expect("Failed to register git_log");
141        registry
142            .register(GitAddTool::new())
143            .expect("Failed to register git_add");
144        registry
145            .register(GitCommitTool::new())
146            .expect("Failed to register git_commit");
147        registry
148            .register(GitPushTool::new())
149            .expect("Failed to register git_push");
150        registry
151            .register(GitPullTool::new())
152            .expect("Failed to register git_pull");
153        registry
154            .register(GitCloneTool::new())
155            .expect("Failed to register git_clone");
156
157        // Analysis tools
158        registry
159            .register(GrepTool::new())
160            .expect("Failed to register grep");
161        registry
162            .register(AstGrepTool::new())
163            .expect("Failed to register ast_grep");
164        registry
165            .register(LspTool::new())
166            .expect("Failed to register lsp");
167
168        // Web tools
169        registry
170            .register(WebSearchTool::new())
171            .expect("Failed to register web_search");
172        registry
173            .register(WebFetchTool::new())
174            .expect("Failed to register web_fetch");
175    }
176
177    /// Process a user message through the LLM and execute any tool calls
178    ///
179    /// # Arguments
180    /// * `user_input` - The user's message to process
181    /// * `messages` - The conversation history (will be updated in place)
182    ///
183    /// # Returns
184    /// The final response from the LLM or an error
185    #[instrument(skip(self, _messages))]
186    pub async fn process_message(
187        &mut self,
188        user_input: &str,
189        _messages: &mut Vec<Message>,
190    ) -> Result<String, CliError> {
191        // Add system message if this is the first message in the conversation
192        // Note: Some providers (z.ai) don't support system role, but OpenAI-compatible APIs generally do
193        if _messages.is_empty() {
194            let system_message = Message {
195                role: Role::System,
196                content: Some(SYSTEM_PROMPT.to_string()),
197                tool_calls: None,
198                tool_call_id: None,
199            };
200            _messages.push(system_message);
201        }
202
203        // Add user message to history
204        let user_message = Message {
205            role: Role::User,
206            content: Some(user_input.to_string()),
207            tool_calls: None,
208            tool_call_id: None,
209        };
210        _messages.push(user_message);
211
212        // Get tool definitions
213        let tool_definitions = self.get_tool_definitions();
214
215        // Main processing loop
216        let mut full_response = String::new();
217        let mut tool_calls: Vec<LlmToolCall> = Vec::new();
218        let max_iterations = self
219            .config
220            .providers
221            .get(&self.config.provider)
222            .map(|p| p.max_iterations)
223            .unwrap_or(100); // Allow enough iterations for complex tasks
224        let mut iteration = 0;
225
226        while max_iterations == 0 || iteration < max_iterations {
227            iteration += 1;
228            debug!("Agent loop iteration {}", iteration);
229
230            // Send thinking event
231            self.send_event(AgentEvent::Thinking);
232
233            // Track timing for token usage
234            let request_start = std::time::Instant::now();
235
236            // Call LLM
237            let mut stream = self
238                .llm_client
239                .send(_messages.clone(), tool_definitions.clone())
240                .await
241                .map_err(|e| CliError::ConfigError(e.to_string()))?;
242
243            tool_calls.clear();
244            let mut current_content = String::new();
245            // Track tool calls: (id) -> (name, args)
246            let mut accumulated_calls: std::collections::HashMap<
247                String,
248                (String, serde_json::Value),
249            > = std::collections::HashMap::new();
250
251            // Process stream chunks
252            while let Some(chunk_result) = stream.next().await {
253                match chunk_result {
254                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
255                        current_content.push_str(&text);
256                        self.send_event(AgentEvent::ContentChunk(text));
257                    }
258                    Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
259                        // Ignore reasoning chunks for now
260                    }
261                    Ok(ProviderResponseChunk::ToolCallDelta {
262                        id,
263                        name,
264                        arguments,
265                    }) => {
266                        debug!("ToolCallDelta: id={}, name={}", id, name);
267                        // Store/merge tool call arguments
268                        accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
269                    }
270                    Ok(ProviderResponseChunk::Done(usage)) => {
271                        // Track token usage
272                        let duration_ms = request_start.elapsed().as_millis() as u64;
273                        let cost =
274                            calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
275                        let _ = self.tracking_db.track_request(
276                            self.model(),
277                            usage.input_tokens,
278                            usage.output_tokens,
279                            cost,
280                            duration_ms,
281                        );
282                        // Emit token usage event for TUI display
283                        self.send_event(AgentEvent::TokenUsage {
284                            input_tokens: usage.input_tokens,
285                            output_tokens: usage.output_tokens,
286                        });
287                        break;
288                    }
289                    Err(e) => {
290                        let error_msg = format!("LLM error: {}", e);
291                        self.send_event(AgentEvent::Error(error_msg.clone()));
292                        return Err(CliError::ConfigError(error_msg));
293                    }
294                }
295            }
296
297            // Convert accumulated calls to Vec<ToolCall>
298            tool_calls = accumulated_calls
299                .into_iter()
300                .map(|(id, (name, args))| LlmToolCall {
301                    id,
302                    tool_type: "function".to_string(),
303                    function: limit_llm::types::FunctionCall {
304                        name,
305                        arguments: args.to_string(),
306                    },
307                })
308                .collect();
309            full_response.push_str(&current_content);
310
311            debug!(
312                "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
313                iteration,
314                current_content.len(),
315                tool_calls.len(),
316                full_response.len()
317            );
318
319            // If no tool calls, we're done
320            if tool_calls.is_empty() {
321                break;
322            }
323
324            // Execute tool calls - add assistant message with tool_calls
325            // Note: Per OpenAI API spec, when tool_calls are present, content should be null
326            let assistant_message = Message {
327                role: Role::Assistant,
328                content: None, // Don't include content when tool_calls are present
329                tool_calls: Some(tool_calls.clone()),
330                tool_call_id: None,
331            };
332            _messages.push(assistant_message);
333
334            // Convert LLM tool calls to executor tool calls
335            let executor_calls: Vec<ToolCall> = tool_calls
336                .iter()
337                .map(|tc| {
338                    let args: serde_json::Value =
339                        serde_json::from_str(&tc.function.arguments).unwrap_or_default();
340                    ToolCall::new(&tc.id, &tc.function.name, args)
341                })
342                .collect();
343
344            // Send ToolStart event for each tool BEFORE execution
345            for tc in &tool_calls {
346                let args: serde_json::Value =
347                    serde_json::from_str(&tc.function.arguments).unwrap_or_default();
348                self.send_event(AgentEvent::ToolStart {
349                    name: tc.function.name.clone(),
350                    args,
351                });
352            }
353            // Execute tools
354            let results = self.executor.execute_tools(executor_calls).await;
355
356            // Add tool results to messages (OpenAI format: role=tool, tool_call_id, content)
357            for result in results {
358                let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
359                if let Some(tool_call) = tool_call {
360                    let output_json = match &result.output {
361                        Ok(value) => {
362                            serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
363                        }
364                        Err(e) => json!({ "error": e.to_string() }).to_string(),
365                    };
366
367                    self.send_event(AgentEvent::ToolComplete {
368                        name: tool_call.function.name.clone(),
369                        result: output_json.clone(),
370                    });
371
372                    // OpenAI tool result format
373                    let tool_result_message = Message {
374                        role: Role::Tool,
375                        content: Some(output_json),
376                        tool_calls: None,
377                        tool_call_id: Some(result.call_id),
378                    };
379                    _messages.push(tool_result_message);
380                }
381            }
382        }
383
384        // If we hit max iterations, make one final request to get a response (no tools = forced text)
385        if iteration >= max_iterations && !_messages.is_empty() {
386            debug!("Making final LLM call after hitting max iterations (forcing text response)");
387
388            // Add constraint message to force text response
389            let constraint_message = Message {
390                role: Role::User,
391                content: Some(
392                    "We've reached the iteration limit. Please provide a summary of:\n\
393                    1. What you've completed so far\n\
394                    2. What remains to be done\n\
395                    3. Recommended next steps for the user to continue"
396                        .to_string(),
397                ),
398                tool_calls: None,
399                tool_call_id: None,
400            };
401            _messages.push(constraint_message);
402
403            // Send with NO tools to force text response
404            let no_tools: Vec<LlmTool> = vec![];
405            let mut stream = self
406                .llm_client
407                .send(_messages.clone(), no_tools)
408                .await
409                .map_err(|e| CliError::ConfigError(e.to_string()))?;
410
411            while let Some(chunk_result) = stream.next().await {
412                match chunk_result {
413                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
414                        full_response.push_str(&text);
415                        self.send_event(AgentEvent::ContentChunk(text));
416                    }
417                    Ok(ProviderResponseChunk::Done(_)) => {
418                        break;
419                    }
420                    Err(e) => {
421                        debug!("Error in final LLM call: {}", e);
422                        break;
423                    }
424                    _ => {}
425                }
426            }
427        }
428
429        self.send_event(AgentEvent::Done);
430        Ok(full_response)
431    }
432
433    /// Get tool definitions formatted for the LLM
434    pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
435        self.tool_names
436            .iter()
437            .map(|name| {
438                let (description, parameters) = Self::get_tool_schema(name);
439                LlmTool {
440                    tool_type: "function".to_string(),
441                    function: limit_llm::types::ToolFunction {
442                        name: name.to_string(),
443                        description,
444                        parameters,
445                    },
446                }
447            })
448            .collect()
449    }
450
    /// Get the schema (description and parameters) for a tool
    ///
    /// Returns a human-readable description and a JSON-Schema `parameters`
    /// object in the OpenAI function-calling format. Unknown tool names fall
    /// back to a generic description with an empty parameter schema.
    fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
        match name {
            // File tools
            "file_read" => (
                "Read the contents of a file".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to read"
                        }
                    },
                    "required": ["path"]
                }),
            ),
            "file_write" => (
                "Write content to a file, creating parent directories if needed".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to write"
                        },
                        "content": {
                            "type": "string",
                            "description": "Content to write to the file"
                        }
                    },
                    "required": ["path", "content"]
                }),
            ),
            "file_edit" => (
                "Replace text in a file with new text".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to edit"
                        },
                        "old_text": {
                            "type": "string",
                            "description": "Text to find and replace"
                        },
                        "new_text": {
                            "type": "string",
                            "description": "New text to replace with"
                        }
                    },
                    "required": ["path", "old_text", "new_text"]
                }),
            ),
            // Shell tool
            "bash" => (
                "Execute a bash command in a shell".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "Bash command to execute"
                        },
                        "workdir": {
                            "type": "string",
                            "description": "Working directory (default: current directory)"
                        },
                        "timeout": {
                            "type": "integer",
                            "description": "Timeout in seconds (default: 60)"
                        }
                    },
                    "required": ["command"]
                }),
            ),
            // Git tools
            "git_status" => (
                "Get git repository status".to_string(),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
            "git_diff" => (
                "Get git diff".to_string(),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
            "git_log" => (
                "Get git commit log".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "count": {
                            "type": "integer",
                            "description": "Number of commits to show (default: 10)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_add" => (
                "Add files to git staging area".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "files": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "List of file paths to add"
                        }
                    },
                    "required": ["files"]
                }),
            ),
            "git_commit" => (
                "Create a git commit".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Commit message"
                        }
                    },
                    "required": ["message"]
                }),
            ),
            "git_push" => (
                "Push commits to remote repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "remote": {
                            "type": "string",
                            "description": "Remote name (default: origin)"
                        },
                        "branch": {
                            "type": "string",
                            "description": "Branch name (default: current branch)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_pull" => (
                "Pull changes from remote repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "remote": {
                            "type": "string",
                            "description": "Remote name (default: origin)"
                        },
                        "branch": {
                            "type": "string",
                            "description": "Branch name (default: current branch)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_clone" => (
                "Clone a git repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "url": {
                            "type": "string",
                            "description": "Repository URL to clone"
                        },
                        "directory": {
                            "type": "string",
                            "description": "Directory to clone into (optional)"
                        }
                    },
                    "required": ["url"]
                }),
            ),
            // Analysis tools
            "grep" => (
                "Search for text patterns in files using regex".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "pattern": {
                            "type": "string",
                            "description": "Regex pattern to search for"
                        },
                        "path": {
                            "type": "string",
                            "description": "Path to search in (default: current directory)"
                        }
                    },
                    "required": ["pattern"]
                }),
            ),
            "ast_grep" => (
                "Search code using AST patterns (structural code matching)".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "pattern": {
                            "type": "string",
                            "description": "AST pattern to match"
                        },
                        "language": {
                            "type": "string",
                            "description": "Programming language (rust, typescript, python)"
                        },
                        "path": {
                            "type": "string",
                            "description": "Path to search in (default: current directory)"
                        }
                    },
                    "required": ["pattern", "language"]
                }),
            ),
            "lsp" => (
                "Perform Language Server Protocol operations (goto_definition, find_references)"
                    .to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "LSP command: goto_definition or find_references"
                        },
                        "file_path": {
                            "type": "string",
                            "description": "Path to the file"
                        },
                        "position": {
                            "type": "object",
                            "description": "Position in the file (line, character)",
                            "properties": {
                                "line": {"type": "integer"},
                                "character": {"type": "integer"}
                            },
                            "required": ["line", "character"]
                        }
                    },
                    "required": ["command", "file_path", "position"]
                }),
            ),
            // Web tools. These embed the current year so the model searches
            // for up-to-date material.
            "web_search" => (
                format!("Search the web using Exa AI. Returns results with titles, URLs, and content snippets. Use for current information beyond knowledge cutoff. The current year is {} - use this year when searching for recent information.", chrono::Local::now().year()),
                json!({
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": format!("Search query. Be specific for better results (e.g., 'Rust async tutorial {}' rather than 'Rust')", chrono::Local::now().year())
                        },
                        "numResults": {
                            "type": "integer",
                            "description": "Number of results to return (default: 8, max: 20)",
                            "default": 8
                        }
                    },
                    "required": ["query"]
                }),
            ),
            "web_fetch" => (
                "Fetch content from a URL. Converts HTML to markdown format by default. Use when user provides a URL or after web_search to read full content of a specific result.".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "url": {
                            "type": "string",
                            "description": "URL to fetch (must start with http:// or https://)"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["markdown", "text", "html"],
                            "default": "markdown",
                            "description": "Output format (default: markdown)"
                        }
                    },
                    "required": ["url"]
                }),
            ),
            // Fallback for names without a hand-written schema.
            _ => (
                format!("Tool: {}", name),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
        }
    }
745
746    /// Send an event through the event channel
747    fn send_event(&self, event: AgentEvent) {
748        if let Some(ref tx) = self.event_tx {
749            let _ = tx.send(event);
750        }
751    }
752
753    /// Check if the bridge is ready to process messages
754    #[allow(dead_code)]
755    pub fn is_ready(&self) -> bool {
756        self.config
757            .providers
758            .get(&self.config.provider)
759            .map(|p| p.api_key_or_env(&self.config.provider).is_some())
760            .unwrap_or(false)
761    }
762
763    /// Get the current model name
764    pub fn model(&self) -> &str {
765        self.config
766            .providers
767            .get(&self.config.provider)
768            .map(|p| p.model.as_str())
769            .unwrap_or("")
770    }
771
772    /// Get the max tokens setting
773    pub fn max_tokens(&self) -> u32 {
774        self.config
775            .providers
776            .get(&self.config.provider)
777            .map(|p| p.max_tokens)
778            .unwrap_or(4096)
779    }
780
781    /// Get the timeout setting
782    pub fn timeout(&self) -> u64 {
783        self.config
784            .providers
785            .get(&self.config.provider)
786            .map(|p| p.timeout)
787            .unwrap_or(60)
788    }
789}
/// Calculate cost based on model pricing (per 1M tokens)
///
/// Models without a known price are charged 0.0 so usage is still tracked
/// without inventing a cost.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    // Prices in USD per one million tokens: (input, output).
    let rates: (f64, f64) = match model {
        // Claude 3.5 Sonnet: $3/1M input, $15/1M output
        "claude-3-5-sonnet-20241022" | "claude-3-5-sonnet" => (3.0, 15.0),
        // GPT-4: $30/1M input, $60/1M output
        "gpt-4" => (30.0, 60.0),
        // GPT-4 Turbo: $10/1M input, $30/1M output
        "gpt-4-turbo" | "gpt-4-turbo-preview" => (10.0, 30.0),
        // Default: no cost tracking
        _ => (0.0, 0.0),
    };
    let input_cost = input_tokens as f64 * rates.0 / 1_000_000.0;
    let output_cost = output_tokens as f64 * rates.1 / 1_000_000.0;
    input_cost + output_cost
}
805
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Build a single-provider "anthropic" config with the given API key.
    ///
    /// Centralizes the 20-line config fixture that was previously
    /// duplicated verbatim in four tests; all defaults are unchanged.
    fn anthropic_config(api_key: Option<&str>) -> LlmConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "anthropic".to_string(),
            ProviderConfig {
                api_key: api_key.map(str::to_string),
                model: "claude-3-5-sonnet-20241022".to_string(),
                base_url: None,
                max_tokens: 4096,
                timeout: 60,
                max_iterations: 100,
                thinking_enabled: false,
                clear_thinking: true,
            },
        );
        LlmConfig {
            provider: "anthropic".to_string(),
            providers,
        }
    }

    #[tokio::test]
    async fn test_agent_bridge_new() {
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }

    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        // Without an API key, construction must fail rather than produce
        // a bridge that errors later on first use.
        let result = AgentBridge::new(anthropic_config(None));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        let definitions = bridge.get_tool_definitions();

        // One definition per registered tool.
        assert_eq!(definitions.len(), 17);

        // Check file_read tool definition
        let file_read = definitions
            .iter()
            .find(|d| d.function.name == "file_read")
            .unwrap();
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        // Check bash tool definition
        let bash = definitions
            .iter()
            .find(|d| d.function.name == "bash")
            .unwrap();
        assert_eq!(bash.function.name, "bash");
        assert!(bash.function.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&"command".into()));
    }

    #[test]
    fn test_get_tool_schema() {
        let (desc, params) = AgentBridge::get_tool_schema("file_read");
        assert!(desc.contains("Read"));
        assert_eq!(params["properties"]["path"]["type"], "string");
        assert!(params["required"]
            .as_array()
            .unwrap()
            .contains(&"path".into()));

        let (desc, params) = AgentBridge::get_tool_schema("bash");
        assert!(desc.contains("bash"));
        assert_eq!(params["properties"]["command"]["type"], "string");

        // Unknown tools fall back to a generic empty-parameter schema.
        let (desc, _params) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(desc.contains("unknown_tool"));
    }

    #[test]
    fn test_is_ready() {
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }
}