//! limit_cli/agent_bridge.rs
//!
//! Bridge between the limit-cli REPL, the limit-agent tool executor, and the
//! limit-llm provider client.
1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::{
4    AstGrepTool, BashTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool, GitCloneTool,
5    GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool, GrepTool,
6    LspTool, WebFetchTool, WebSearchTool,
7};
8use chrono::Datelike;
9use futures::StreamExt;
10use limit_agent::executor::{ToolCall, ToolExecutor};
11use limit_agent::registry::ToolRegistry;
12use limit_llm::providers::LlmProvider;
13use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
14use limit_llm::ProviderFactory;
15use limit_llm::ProviderResponseChunk;
16use limit_llm::TrackingDb;
17use serde_json::json;
18use tokio::sync::mpsc;
19use tracing::{debug, instrument};
20
/// Event types for streaming from agent to REPL
#[derive(Debug, Clone)]
// NOTE(review): dead_code allow suggests not every variant is consumed by all
// frontends yet — confirm before removing.
#[allow(dead_code)]
pub enum AgentEvent {
    /// An LLM request is in flight; the agent is waiting for the first chunk.
    Thinking,
    /// A tool is about to be executed with the given JSON arguments.
    ToolStart {
        name: String,
        args: serde_json::Value,
    },
    /// A tool finished; `result` is the JSON-serialized output (or `{"error": …}`).
    ToolComplete {
        name: String,
        result: String,
    },
    /// One streamed chunk of assistant text.
    ContentChunk(String),
    /// Processing of the current user message finished.
    Done,
    /// A fatal error occurred; payload is a human-readable message.
    Error(String),
    /// Token usage reported by the provider for one request (for TUI display).
    TokenUsage {
        input_tokens: u64,
        output_tokens: u64,
    },
}
42
/// Bridge connecting limit-cli REPL to limit-agent executor and limit-llm client
pub struct AgentBridge {
    /// LLM client for communicating with LLM providers
    llm_client: Box<dyn LlmProvider>,
    /// Tool executor for running tool calls
    executor: ToolExecutor,
    /// List of registered tool names; must stay in sync with the tools
    /// registered in `register_tools` (used to build LLM tool definitions).
    tool_names: Vec<&'static str>,
    /// Configuration loaded from ~/.limit/config.toml
    config: limit_llm::Config,
    /// Event sender for streaming events to REPL.
    /// `None` until `set_event_tx` is called.
    event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
    /// Token usage tracking database
    tracking_db: TrackingDb,
}
58
59impl AgentBridge {
60    /// Create a new AgentBridge with the given configuration
61    ///
62    /// # Arguments
63    /// * `config` - LLM configuration (API key, model, etc.)
64    ///
65    /// # Returns
66    /// A new AgentBridge instance or an error if initialization fails
67    pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
68        let llm_client = ProviderFactory::create_provider(&config)
69            .map_err(|e| CliError::ConfigError(e.to_string()))?;
70
71        let mut tool_registry = ToolRegistry::new();
72        Self::register_tools(&mut tool_registry);
73
74        // Create executor (which takes ownership of registry as Arc)
75        let executor = ToolExecutor::new(tool_registry);
76
77        // Generate tool definitions before giving ownership to executor
78        let tool_names = vec![
79            "file_read",
80            "file_write",
81            "file_edit",
82            "bash",
83            "git_status",
84            "git_diff",
85            "git_log",
86            "git_add",
87            "git_commit",
88            "git_push",
89            "git_pull",
90            "git_clone",
91            "grep",
92            "ast_grep",
93            "lsp",
94            "web_search",
95            "web_fetch",
96        ];
97
98        Ok(Self {
99            llm_client,
100            executor,
101            tool_names,
102            config,
103            event_tx: None,
104            tracking_db: TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?,
105        })
106    }
107
108    /// Set the event channel sender for streaming events
109    pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
110        self.event_tx = Some(tx);
111    }
112
113    /// Register all CLI tools into the tool registry
114    fn register_tools(registry: &mut ToolRegistry) {
115        // File tools
116        registry
117            .register(FileReadTool::new())
118            .expect("Failed to register file_read");
119        registry
120            .register(FileWriteTool::new())
121            .expect("Failed to register file_write");
122        registry
123            .register(FileEditTool::new())
124            .expect("Failed to register file_edit");
125
126        // Bash tool
127        registry
128            .register(BashTool::new())
129            .expect("Failed to register bash");
130
131        // Git tools
132        registry
133            .register(GitStatusTool::new())
134            .expect("Failed to register git_status");
135        registry
136            .register(GitDiffTool::new())
137            .expect("Failed to register git_diff");
138        registry
139            .register(GitLogTool::new())
140            .expect("Failed to register git_log");
141        registry
142            .register(GitAddTool::new())
143            .expect("Failed to register git_add");
144        registry
145            .register(GitCommitTool::new())
146            .expect("Failed to register git_commit");
147        registry
148            .register(GitPushTool::new())
149            .expect("Failed to register git_push");
150        registry
151            .register(GitPullTool::new())
152            .expect("Failed to register git_pull");
153        registry
154            .register(GitCloneTool::new())
155            .expect("Failed to register git_clone");
156
157        // Analysis tools
158        registry
159            .register(GrepTool::new())
160            .expect("Failed to register grep");
161        registry
162            .register(AstGrepTool::new())
163            .expect("Failed to register ast_grep");
164        registry
165            .register(LspTool::new())
166            .expect("Failed to register lsp");
167
168        // Web tools
169        registry
170            .register(WebSearchTool::new())
171            .expect("Failed to register web_search");
172        registry
173            .register(WebFetchTool::new())
174            .expect("Failed to register web_fetch");
175    }
176
    /// Process a user message through the LLM and execute any tool calls
    ///
    /// Runs the agent loop: send conversation + tool definitions to the LLM,
    /// execute any requested tools, feed results back, and repeat until the
    /// model answers without tool calls or the iteration cap is hit.
    ///
    /// # Arguments
    /// * `user_input` - The user's message to process
    /// * `_messages` - The conversation history (will be updated in place)
    ///
    /// # Returns
    /// The final response from the LLM or an error
    ///
    /// # Errors
    /// Returns `CliError::ConfigError` when the LLM request or stream fails.
    // NOTE(review): `_messages` is read and mutated throughout; the leading
    // underscore is misleading and worth renaming in a follow-up.
    #[instrument(skip(self, _messages))]
    pub async fn process_message(
        &mut self,
        user_input: &str,
        _messages: &mut Vec<Message>,
    ) -> Result<String, CliError> {
        // Add system message if this is the first message in the conversation
        // Note: Some providers (z.ai) don't support system role, but OpenAI-compatible APIs generally do
        if _messages.is_empty() {
            let system_message = Message {
                role: Role::System,
                content: Some(SYSTEM_PROMPT.to_string()),
                tool_calls: None,
                tool_call_id: None,
            };
            _messages.push(system_message);
        }

        // Add user message to history
        let user_message = Message {
            role: Role::User,
            content: Some(user_input.to_string()),
            tool_calls: None,
            tool_call_id: None,
        };
        _messages.push(user_message);

        // Get tool definitions
        let tool_definitions = self.get_tool_definitions();

        // Main processing loop
        let mut full_response = String::new();
        let mut tool_calls: Vec<LlmToolCall> = Vec::new();
        // Per-provider iteration cap; 0 means "unlimited" (see loop condition
        // and the post-loop forced-summary branch below).
        let max_iterations = self
            .config
            .providers
            .get(&self.config.provider)
            .map(|p| p.max_iterations)
            .unwrap_or(100); // Allow enough iterations for complex tasks
        let mut iteration = 0;

        while max_iterations == 0 || iteration < max_iterations {
            iteration += 1;
            debug!("Agent loop iteration {}", iteration);

            // Send thinking event
            self.send_event(AgentEvent::Thinking);

            // Track timing for token usage
            let request_start = std::time::Instant::now();

            // Call LLM with the full history and tool definitions each time.
            let mut stream = self
                .llm_client
                .send(_messages.clone(), tool_definitions.clone())
                .await
                .map_err(|e| CliError::ConfigError(e.to_string()))?;

            tool_calls.clear();
            let mut current_content = String::new();
            // Track tool calls: (id) -> (name, args)
            let mut accumulated_calls: std::collections::HashMap<
                String,
                (String, serde_json::Value),
            > = std::collections::HashMap::new();

            // Process stream chunks
            while let Some(chunk_result) = stream.next().await {
                match chunk_result {
                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
                        current_content.push_str(&text);
                        debug!(
                            "ContentDelta: {} chars (total: {})",
                            text.len(),
                            current_content.len()
                        );
                        self.send_event(AgentEvent::ContentChunk(text));
                    }
                    Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
                        // Ignore reasoning chunks for now
                    }
                    Ok(ProviderResponseChunk::ToolCallDelta {
                        id,
                        name,
                        arguments,
                    }) => {
                        debug!(
                            "ToolCallDelta: id={}, name={}, args_len={}",
                            id,
                            name,
                            arguments.to_string().len()
                        );
                        // Store/merge tool call arguments
                        // NOTE(review): insert() overwrites any earlier delta for
                        // this id — this assumes each delta carries the complete
                        // arguments rather than a partial fragment; confirm
                        // against the provider implementations.
                        accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
                    }
                    Ok(ProviderResponseChunk::Done(usage)) => {
                        // Track token usage
                        let duration_ms = request_start.elapsed().as_millis() as u64;
                        let cost =
                            calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
                        // Best-effort: tracking-DB failures are deliberately ignored.
                        let _ = self.tracking_db.track_request(
                            self.model(),
                            usage.input_tokens,
                            usage.output_tokens,
                            cost,
                            duration_ms,
                        );
                        // Emit token usage event for TUI display
                        self.send_event(AgentEvent::TokenUsage {
                            input_tokens: usage.input_tokens,
                            output_tokens: usage.output_tokens,
                        });
                        break;
                    }
                    Err(e) => {
                        let error_msg = format!("LLM error: {}", e);
                        self.send_event(AgentEvent::Error(error_msg.clone()));
                        return Err(CliError::ConfigError(error_msg));
                    }
                }
            }

            // Convert accumulated calls to Vec<ToolCall>
            // NOTE(review): HashMap iteration order is unspecified, so the
            // resulting tool-call order may not match the order the provider
            // emitted them — confirm whether any provider requires ordering.
            tool_calls = accumulated_calls
                .into_iter()
                .map(|(id, (name, args))| LlmToolCall {
                    id,
                    tool_type: "function".to_string(),
                    function: limit_llm::types::FunctionCall {
                        name,
                        arguments: args.to_string(),
                    },
                })
                .collect();

            // BUG FIX: Don't accumulate content across iterations
            // Only store content from the current iteration
            // If there are tool calls, we'll continue the loop and the LLM will see the tool results
            // If there are NO tool calls, this is the final response
            full_response = current_content.clone();

            debug!(
                "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
                iteration,
                current_content.len(),
                tool_calls.len(),
                full_response.len()
            );

            // If no tool calls, we're done
            if tool_calls.is_empty() {
                debug!("No tool calls, breaking loop after iteration {}", iteration);
                break;
            }

            debug!(
                "Tool calls found (count={}), continuing to iteration {}",
                tool_calls.len(),
                iteration + 1
            );

            // Execute tool calls - add assistant message with tool_calls
            // Note: Per OpenAI API spec, when tool_calls are present, content should be null
            let assistant_message = Message {
                role: Role::Assistant,
                content: None, // Don't include content when tool_calls are present
                tool_calls: Some(tool_calls.clone()),
                tool_call_id: None,
            };
            _messages.push(assistant_message);

            // Convert LLM tool calls to executor tool calls
            let executor_calls: Vec<ToolCall> = tool_calls
                .iter()
                .map(|tc| {
                    // Malformed argument JSON degrades to Value::default() (null).
                    let args: serde_json::Value =
                        serde_json::from_str(&tc.function.arguments).unwrap_or_default();
                    ToolCall::new(&tc.id, &tc.function.name, args)
                })
                .collect();

            // Send ToolStart event for each tool BEFORE execution
            for tc in &tool_calls {
                let args: serde_json::Value =
                    serde_json::from_str(&tc.function.arguments).unwrap_or_default();
                self.send_event(AgentEvent::ToolStart {
                    name: tc.function.name.clone(),
                    args,
                });
            }
            // Execute tools
            let results = self.executor.execute_tools(executor_calls).await;

            // Add tool results to messages (OpenAI format: role=tool, tool_call_id, content)
            for result in results {
                // Match each result back to its originating call by id; results
                // with no matching call are silently skipped.
                let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
                if let Some(tool_call) = tool_call {
                    let output_json = match &result.output {
                        Ok(value) => {
                            serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
                        }
                        Err(e) => json!({ "error": e.to_string() }).to_string(),
                    };

                    self.send_event(AgentEvent::ToolComplete {
                        name: tool_call.function.name.clone(),
                        result: output_json.clone(),
                    });

                    // OpenAI tool result format
                    let tool_result_message = Message {
                        role: Role::Tool,
                        content: Some(output_json),
                        tool_calls: None,
                        tool_call_id: Some(result.call_id),
                    };
                    _messages.push(tool_result_message);
                }
            }
        }

        // If we hit max iterations, make one final request to get a response (no tools = forced text)
        // IMPORTANT: Only do this if max_iterations > 0 (0 means unlimited, so we never "hit" the limit)
        if max_iterations > 0 && iteration >= max_iterations && !_messages.is_empty() {
            debug!("Making final LLM call after hitting max iterations (forcing text response)");

            // Add constraint message to force text response
            let constraint_message = Message {
                role: Role::User,
                content: Some(
                    "We've reached the iteration limit. Please provide a summary of:\n\
                    1. What you've completed so far\n\
                    2. What remains to be done\n\
                    3. Recommended next steps for the user to continue"
                        .to_string(),
                ),
                tool_calls: None,
                tool_call_id: None,
            };
            _messages.push(constraint_message);

            // Send with NO tools to force text response
            let no_tools: Vec<LlmTool> = vec![];
            let mut stream = self
                .llm_client
                .send(_messages.clone(), no_tools)
                .await
                .map_err(|e| CliError::ConfigError(e.to_string()))?;

            // BUG FIX: Replace full_response instead of appending
            full_response.clear();
            while let Some(chunk_result) = stream.next().await {
                match chunk_result {
                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
                        full_response.push_str(&text);
                        self.send_event(AgentEvent::ContentChunk(text));
                    }
                    Ok(ProviderResponseChunk::Done(_)) => {
                        break;
                    }
                    Err(e) => {
                        // Best-effort summary: keep whatever content arrived so far.
                        debug!("Error in final LLM call: {}", e);
                        break;
                    }
                    _ => {}
                }
            }
        }

        // IMPORTANT: Add final assistant response to message history for session persistence
        // This is crucial for session export/share to work correctly
        // Only add if we have content AND we haven't already added this response
        if !full_response.is_empty() {
            // Find the last assistant message and check if it has content
            // If it has tool_calls but no content, UPDATE it instead of adding a new one
            // This prevents accumulation of empty assistant messages in the history
            let last_assistant_idx = _messages.iter().rposition(|m| m.role == Role::Assistant);

            if let Some(idx) = last_assistant_idx {
                let last_assistant = &mut _messages[idx];

                // If the last assistant message has no content (tool_calls only), update it
                if last_assistant.content.is_none()
                    || last_assistant
                        .content
                        .as_ref()
                        .map(|c| c.is_empty())
                        .unwrap_or(true)
                {
                    last_assistant.content = Some(full_response.clone());
                    debug!("Updated last assistant message with final response content");
                } else {
                    // Last assistant already has content, this shouldn't happen normally
                    // but we add a new message to be safe
                    debug!("Last assistant already has content, adding new message");
                    let final_assistant_message = Message {
                        role: Role::Assistant,
                        content: Some(full_response.clone()),
                        tool_calls: None,
                        tool_call_id: None,
                    };
                    _messages.push(final_assistant_message);
                }
            } else {
                // No assistant message found, add a new one
                debug!("No assistant message found, adding new message");
                let final_assistant_message = Message {
                    role: Role::Assistant,
                    content: Some(full_response.clone()),
                    tool_calls: None,
                    tool_call_id: None,
                };
                _messages.push(final_assistant_message);
            }
        }

        self.send_event(AgentEvent::Done);
        Ok(full_response)
    }
504
505    /// Get tool definitions formatted for the LLM
506    pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
507        self.tool_names
508            .iter()
509            .map(|name| {
510                let (description, parameters) = Self::get_tool_schema(name);
511                LlmTool {
512                    tool_type: "function".to_string(),
513                    function: limit_llm::types::ToolFunction {
514                        name: name.to_string(),
515                        description,
516                        parameters,
517                    },
518                }
519            })
520            .collect()
521    }
522
523    /// Get the schema (description and parameters) for a tool
524    fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
525        match name {
526            "file_read" => (
527                "Read the contents of a file".to_string(),
528                json!({
529                    "type": "object",
530                    "properties": {
531                        "path": {
532                            "type": "string",
533                            "description": "Path to the file to read"
534                        }
535                    },
536                    "required": ["path"]
537                }),
538            ),
539            "file_write" => (
540                "Write content to a file, creating parent directories if needed".to_string(),
541                json!({
542                    "type": "object",
543                    "properties": {
544                        "path": {
545                            "type": "string",
546                            "description": "Path to the file to write"
547                        },
548                        "content": {
549                            "type": "string",
550                            "description": "Content to write to the file"
551                        }
552                    },
553                    "required": ["path", "content"]
554                }),
555            ),
556            "file_edit" => (
557                "Replace text in a file with new text".to_string(),
558                json!({
559                    "type": "object",
560                    "properties": {
561                        "path": {
562                            "type": "string",
563                            "description": "Path to the file to edit"
564                        },
565                        "old_text": {
566                            "type": "string",
567                            "description": "Text to find and replace"
568                        },
569                        "new_text": {
570                            "type": "string",
571                            "description": "New text to replace with"
572                        }
573                    },
574                    "required": ["path", "old_text", "new_text"]
575                }),
576            ),
577            "bash" => (
578                "Execute a bash command in a shell".to_string(),
579                json!({
580                    "type": "object",
581                    "properties": {
582                        "command": {
583                            "type": "string",
584                            "description": "Bash command to execute"
585                        },
586                        "workdir": {
587                            "type": "string",
588                            "description": "Working directory (default: current directory)"
589                        },
590                        "timeout": {
591                            "type": "integer",
592                            "description": "Timeout in seconds (default: 60)"
593                        }
594                    },
595                    "required": ["command"]
596                }),
597            ),
598            "git_status" => (
599                "Get git repository status".to_string(),
600                json!({
601                    "type": "object",
602                    "properties": {},
603                    "required": []
604                }),
605            ),
606            "git_diff" => (
607                "Get git diff".to_string(),
608                json!({
609                    "type": "object",
610                    "properties": {},
611                    "required": []
612                }),
613            ),
614            "git_log" => (
615                "Get git commit log".to_string(),
616                json!({
617                    "type": "object",
618                    "properties": {
619                        "count": {
620                            "type": "integer",
621                            "description": "Number of commits to show (default: 10)"
622                        }
623                    },
624                    "required": []
625                }),
626            ),
627            "git_add" => (
628                "Add files to git staging area".to_string(),
629                json!({
630                    "type": "object",
631                    "properties": {
632                        "files": {
633                            "type": "array",
634                            "items": {"type": "string"},
635                            "description": "List of file paths to add"
636                        }
637                    },
638                    "required": ["files"]
639                }),
640            ),
641            "git_commit" => (
642                "Create a git commit".to_string(),
643                json!({
644                    "type": "object",
645                    "properties": {
646                        "message": {
647                            "type": "string",
648                            "description": "Commit message"
649                        }
650                    },
651                    "required": ["message"]
652                }),
653            ),
654            "git_push" => (
655                "Push commits to remote repository".to_string(),
656                json!({
657                    "type": "object",
658                    "properties": {
659                        "remote": {
660                            "type": "string",
661                            "description": "Remote name (default: origin)"
662                        },
663                        "branch": {
664                            "type": "string",
665                            "description": "Branch name (default: current branch)"
666                        }
667                    },
668                    "required": []
669                }),
670            ),
671            "git_pull" => (
672                "Pull changes from remote repository".to_string(),
673                json!({
674                    "type": "object",
675                    "properties": {
676                        "remote": {
677                            "type": "string",
678                            "description": "Remote name (default: origin)"
679                        },
680                        "branch": {
681                            "type": "string",
682                            "description": "Branch name (default: current branch)"
683                        }
684                    },
685                    "required": []
686                }),
687            ),
688            "git_clone" => (
689                "Clone a git repository".to_string(),
690                json!({
691                    "type": "object",
692                    "properties": {
693                        "url": {
694                            "type": "string",
695                            "description": "Repository URL to clone"
696                        },
697                        "directory": {
698                            "type": "string",
699                            "description": "Directory to clone into (optional)"
700                        }
701                    },
702                    "required": ["url"]
703                }),
704            ),
705            "grep" => (
706                "Search for text patterns in files using regex".to_string(),
707                json!({
708                    "type": "object",
709                    "properties": {
710                        "pattern": {
711                            "type": "string",
712                            "description": "Regex pattern to search for"
713                        },
714                        "path": {
715                            "type": "string",
716                            "description": "Path to search in (default: current directory)"
717                        }
718                    },
719                    "required": ["pattern"]
720                }),
721            ),
722            "ast_grep" => (
723                "Search code using AST patterns (structural code matching)".to_string(),
724                json!({
725                    "type": "object",
726                    "properties": {
727                        "pattern": {
728                            "type": "string",
729                            "description": "AST pattern to match"
730                        },
731                        "language": {
732                            "type": "string",
733                            "description": "Programming language (rust, typescript, python)"
734                        },
735                        "path": {
736                            "type": "string",
737                            "description": "Path to search in (default: current directory)"
738                        }
739                    },
740                    "required": ["pattern", "language"]
741                }),
742            ),
743            "lsp" => (
744                "Perform Language Server Protocol operations (goto_definition, find_references)"
745                    .to_string(),
746                json!({
747                    "type": "object",
748                    "properties": {
749                        "command": {
750                            "type": "string",
751                            "description": "LSP command: goto_definition or find_references"
752                        },
753                        "file_path": {
754                            "type": "string",
755                            "description": "Path to the file"
756                        },
757                        "position": {
758                            "type": "object",
759                            "description": "Position in the file (line, character)",
760                            "properties": {
761                                "line": {"type": "integer"},
762                                "character": {"type": "integer"}
763                            },
764                            "required": ["line", "character"]
765                        }
766                    },
767                    "required": ["command", "file_path", "position"]
768                }),
769            ),
770            "web_search" => (
771                format!("Search the web using Exa AI. Returns results with titles, URLs, and content snippets. Use for current information beyond knowledge cutoff. The current year is {} - use this year when searching for recent information.", chrono::Local::now().year()),
772                json!({
773                    "type": "object",
774                    "properties": {
775                        "query": {
776                            "type": "string",
777                            "description": format!("Search query. Be specific for better results (e.g., 'Rust async tutorial {}' rather than 'Rust')", chrono::Local::now().year())
778                        },
779                        "numResults": {
780                            "type": "integer",
781                            "description": "Number of results to return (default: 8, max: 20)",
782                            "default": 8
783                        }
784                    },
785                    "required": ["query"]
786                }),
787            ),
788            "web_fetch" => (
789                "Fetch content from a URL. Converts HTML to markdown format by default. Use when user provides a URL or after web_search to read full content of a specific result.".to_string(),
790                json!({
791                    "type": "object",
792                    "properties": {
793                        "url": {
794                            "type": "string",
795                            "description": "URL to fetch (must start with http:// or https://)"
796                        },
797                        "format": {
798                            "type": "string",
799                            "enum": ["markdown", "text", "html"],
800                            "default": "markdown",
801                            "description": "Output format (default: markdown)"
802                        }
803                    },
804                    "required": ["url"]
805                }),
806            ),
807            _ => (
808                format!("Tool: {}", name),
809                json!({
810                    "type": "object",
811                    "properties": {},
812                    "required": []
813                }),
814            ),
815        }
816    }
817
818    /// Send an event through the event channel
819    fn send_event(&self, event: AgentEvent) {
820        if let Some(ref tx) = self.event_tx {
821            let _ = tx.send(event);
822        }
823    }
824
825    /// Check if the bridge is ready to process messages
826    #[allow(dead_code)]
827    pub fn is_ready(&self) -> bool {
828        self.config
829            .providers
830            .get(&self.config.provider)
831            .map(|p| p.api_key_or_env(&self.config.provider).is_some())
832            .unwrap_or(false)
833    }
834
835    /// Get the current model name
836    pub fn model(&self) -> &str {
837        self.config
838            .providers
839            .get(&self.config.provider)
840            .map(|p| p.model.as_str())
841            .unwrap_or("")
842    }
843
844    /// Get the max tokens setting
845    pub fn max_tokens(&self) -> u32 {
846        self.config
847            .providers
848            .get(&self.config.provider)
849            .map(|p| p.max_tokens)
850            .unwrap_or(4096)
851    }
852
853    /// Get the timeout setting
854    pub fn timeout(&self) -> u64 {
855        self.config
856            .providers
857            .get(&self.config.provider)
858            .map(|p| p.timeout)
859            .unwrap_or(60)
860    }
861}
/// Calculate cost based on model pricing (per 1M tokens)
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    // All prices below are expressed in dollars per this many tokens.
    const TOKENS_PER_PRICE_UNIT: f64 = 1_000_000.0;

    // Per-model (input, output) prices in $/1M tokens; unknown models are
    // not cost-tracked and report zero.
    let pricing: Option<(f64, f64)> = match model {
        "claude-3-5-sonnet-20241022" | "claude-3-5-sonnet" => Some((3.0, 15.0)),
        "gpt-4" => Some((30.0, 60.0)),
        "gpt-4-turbo" | "gpt-4-turbo-preview" => Some((10.0, 30.0)),
        _ => None,
    };

    match pricing {
        Some((input_price, output_price)) => {
            let input_cost = input_tokens as f64 * input_price / TOKENS_PER_PRICE_UNIT;
            let output_cost = output_tokens as f64 * output_price / TOKENS_PER_PRICE_UNIT;
            input_cost + output_cost
        }
        None => 0.0,
    }
}
877
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Build a single-provider ("anthropic") test config with the given API
    /// key. Centralizes the fixture that was previously copy-pasted in every
    /// test; all field values match the original per-test fixtures.
    fn test_config(api_key: Option<String>) -> LlmConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "anthropic".to_string(),
            ProviderConfig {
                api_key,
                model: "claude-3-5-sonnet-20241022".to_string(),
                base_url: None,
                max_tokens: 4096,
                timeout: 60,
                max_iterations: 100,
                thinking_enabled: false,
                clear_thinking: true,
            },
        );
        LlmConfig {
            provider: "anthropic".to_string(),
            providers,
        }
    }

    #[tokio::test]
    async fn test_agent_bridge_new() {
        // Construction with a configured API key succeeds and is ready.
        let bridge = AgentBridge::new(test_config(Some("test-key".to_string()))).unwrap();
        assert!(bridge.is_ready());
    }

    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        // Without an API key, bridge construction must fail.
        let result = AgentBridge::new(test_config(None));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let bridge = AgentBridge::new(test_config(Some("test-key".to_string()))).unwrap();
        let definitions = bridge.get_tool_definitions();

        // One definition per registered tool.
        assert_eq!(definitions.len(), 17);

        // Check file_read tool definition
        let file_read = definitions
            .iter()
            .find(|d| d.function.name == "file_read")
            .unwrap();
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        // Check bash tool definition: required params propagate into schema.
        let bash = definitions
            .iter()
            .find(|d| d.function.name == "bash")
            .unwrap();
        assert_eq!(bash.function.name, "bash");
        assert!(bash.function.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&"command".into()));
    }

    #[test]
    fn test_get_tool_schema() {
        // Known tool: description and JSON schema are populated.
        let (desc, params) = AgentBridge::get_tool_schema("file_read");
        assert!(desc.contains("Read"));
        assert_eq!(params["properties"]["path"]["type"], "string");
        assert!(params["required"]
            .as_array()
            .unwrap()
            .contains(&"path".into()));

        let (desc, params) = AgentBridge::get_tool_schema("bash");
        assert!(desc.contains("bash"));
        assert_eq!(params["properties"]["command"]["type"], "string");

        // Unknown tool: falls back to a generic "Tool: <name>" description.
        let (desc, _params) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(desc.contains("unknown_tool"));
    }

    #[test]
    fn test_is_ready() {
        let bridge = AgentBridge::new(test_config(Some("test-key".to_string()))).unwrap();
        assert!(bridge.is_ready());
    }
}