Skip to main content

limit_cli/
agent_bridge.rs

1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::{
4    AstGrepTool, BashTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool, GitCloneTool,
5    GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool, GrepTool,
6    LspTool, WebFetchTool, WebSearchTool,
7};
8use chrono::Datelike;
9use futures::StreamExt;
10use limit_agent::executor::{ToolCall, ToolExecutor};
11use limit_agent::registry::ToolRegistry;
12use limit_llm::providers::LlmProvider;
13use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
14use limit_llm::ProviderFactory;
15use limit_llm::ProviderResponseChunk;
16use limit_llm::TrackingDb;
17use serde_json::json;
18use tokio::sync::mpsc;
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, instrument};
21
/// Event types for streaming from agent to REPL
///
/// Every variant carries the `operation_id` that was current on the bridge
/// when the event was emitted, so a receiver can tell which operation an
/// event belongs to.
#[derive(Debug, Clone)]
// NOTE(review): presumably some variants or fields are not read by every
// consumer yet, hence the blanket allow — confirm and narrow if possible.
#[allow(dead_code)]
pub enum AgentEvent {
    /// Emitted at the start of each agent-loop iteration, just before the
    /// conversation is sent to the LLM.
    Thinking {
        operation_id: u64,
    },
    /// Emitted once per requested tool call, before any tool is executed.
    ToolStart {
        operation_id: u64,
        /// Name of the tool the LLM asked for.
        name: String,
        /// Arguments parsed from the tool call's JSON argument string.
        args: serde_json::Value,
    },
    /// Emitted after a tool call has finished.
    ToolComplete {
        operation_id: u64,
        /// Name of the tool that ran.
        name: String,
        /// Tool output serialized as a JSON string; failures are encoded as
        /// `{"error": "..."}`.
        result: String,
    },
    /// A piece of assistant text, forwarded as it arrives from the stream.
    ContentChunk {
        operation_id: u64,
        chunk: String,
    },
    /// Emitted when `process_message` finishes successfully.
    Done {
        operation_id: u64,
    },
    /// Emitted when the cancellation token fires while a response is streaming.
    Cancelled {
        operation_id: u64,
    },
    /// Emitted when the LLM stream yields an error in the main agent loop.
    Error {
        operation_id: u64,
        /// Human-readable error text (prefixed with `LLM error: `).
        message: String,
    },
    /// Token counts reported by the provider when a response completes.
    TokenUsage {
        operation_id: u64,
        input_tokens: u64,
        output_tokens: u64,
    },
}
59
/// Bridge connecting limit-cli REPL to limit-agent executor and limit-llm client
///
/// One bridge owns the LLM provider, the tool executor and the usage
/// database; the event sender and cancellation token are optional and are
/// installed through their setters.
pub struct AgentBridge {
    /// LLM client for communicating with LLM providers
    llm_client: Box<dyn LlmProvider>,
    /// Tool executor for running tool calls
    executor: ToolExecutor,
    /// List of registered tool names; `get_tool_definitions` advertises
    /// exactly these names to the LLM
    tool_names: Vec<&'static str>,
    /// Configuration loaded from ~/.limit/config.toml
    config: limit_llm::Config,
    /// Event sender for streaming events to REPL (`None` until
    /// `set_event_tx` is called)
    event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
    /// Token usage tracking database
    tracking_db: TrackingDb,
    /// Cancellation token for aborting current operation; while `None`,
    /// streaming is awaited without any cancellation check
    cancellation_token: Option<CancellationToken>,
    /// Current operation ID for event tracking; stamped on every emitted event
    operation_id: u64,
}
79
80impl AgentBridge {
81    /// Create a new AgentBridge with the given configuration
82    ///
83    /// # Arguments
84    /// * `config` - LLM configuration (API key, model, etc.)
85    ///
86    /// # Returns
87    /// A new AgentBridge instance or an error if initialization fails
88    pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
89        let llm_client = ProviderFactory::create_provider(&config)
90            .map_err(|e| CliError::ConfigError(e.to_string()))?;
91
92        let mut tool_registry = ToolRegistry::new();
93        Self::register_tools(&mut tool_registry);
94
95        // Create executor (which takes ownership of registry as Arc)
96        let executor = ToolExecutor::new(tool_registry);
97
98        // Generate tool definitions before giving ownership to executor
99        let tool_names = vec![
100            "file_read",
101            "file_write",
102            "file_edit",
103            "bash",
104            "git_status",
105            "git_diff",
106            "git_log",
107            "git_add",
108            "git_commit",
109            "git_push",
110            "git_pull",
111            "git_clone",
112            "grep",
113            "ast_grep",
114            "lsp",
115            "web_search",
116            "web_fetch",
117        ];
118
119        Ok(Self {
120            llm_client,
121            executor,
122            tool_names,
123            config,
124            event_tx: None,
125            tracking_db: TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?,
126            cancellation_token: None,
127            operation_id: 0,
128        })
129    }
130
131    /// Set the event channel sender for streaming events
132    pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
133        self.event_tx = Some(tx);
134    }
135
136    /// Set the cancellation token and operation ID for this operation
137    pub fn set_cancellation_token(&mut self, token: CancellationToken, operation_id: u64) {
138        debug!("set_cancellation_token: operation_id={}", operation_id);
139        self.cancellation_token = Some(token);
140        self.operation_id = operation_id;
141    }
142
143    /// Clear the cancellation token
144    pub fn clear_cancellation_token(&mut self) {
145        self.cancellation_token = None;
146    }
147
148    /// Register all CLI tools into the tool registry
149    fn register_tools(registry: &mut ToolRegistry) {
150        // File tools
151        registry
152            .register(FileReadTool::new())
153            .expect("Failed to register file_read");
154        registry
155            .register(FileWriteTool::new())
156            .expect("Failed to register file_write");
157        registry
158            .register(FileEditTool::new())
159            .expect("Failed to register file_edit");
160
161        // Bash tool
162        registry
163            .register(BashTool::new())
164            .expect("Failed to register bash");
165
166        // Git tools
167        registry
168            .register(GitStatusTool::new())
169            .expect("Failed to register git_status");
170        registry
171            .register(GitDiffTool::new())
172            .expect("Failed to register git_diff");
173        registry
174            .register(GitLogTool::new())
175            .expect("Failed to register git_log");
176        registry
177            .register(GitAddTool::new())
178            .expect("Failed to register git_add");
179        registry
180            .register(GitCommitTool::new())
181            .expect("Failed to register git_commit");
182        registry
183            .register(GitPushTool::new())
184            .expect("Failed to register git_push");
185        registry
186            .register(GitPullTool::new())
187            .expect("Failed to register git_pull");
188        registry
189            .register(GitCloneTool::new())
190            .expect("Failed to register git_clone");
191
192        // Analysis tools
193        registry
194            .register(GrepTool::new())
195            .expect("Failed to register grep");
196        registry
197            .register(AstGrepTool::new())
198            .expect("Failed to register ast_grep");
199        registry
200            .register(LspTool::new())
201            .expect("Failed to register lsp");
202
203        // Web tools
204        registry
205            .register(WebSearchTool::new())
206            .expect("Failed to register web_search");
207        registry
208            .register(WebFetchTool::new())
209            .expect("Failed to register web_fetch");
210    }
211
212    /// Process a user message through the LLM and execute any tool calls
213    ///
214    /// # Arguments
215    /// * `user_input` - The user's message to process
216    /// * `messages` - The conversation history (will be updated in place)
217    ///
218    /// # Returns
219    /// The final response from the LLM or an error
220    #[instrument(skip(self, _messages))]
221    pub async fn process_message(
222        &mut self,
223        user_input: &str,
224        _messages: &mut Vec<Message>,
225    ) -> Result<String, CliError> {
226        // Add system message if this is the first message in the conversation
227        // Note: Some providers (z.ai) don't support system role, but OpenAI-compatible APIs generally do
228        if _messages.is_empty() {
229            let system_message = Message {
230                role: Role::System,
231                content: Some(SYSTEM_PROMPT.to_string()),
232                tool_calls: None,
233                tool_call_id: None,
234            };
235            _messages.push(system_message);
236        }
237
238        // Add user message to history
239        let user_message = Message {
240            role: Role::User,
241            content: Some(user_input.to_string()),
242            tool_calls: None,
243            tool_call_id: None,
244        };
245        _messages.push(user_message);
246
247        // Get tool definitions
248        let tool_definitions = self.get_tool_definitions();
249
250        // Main processing loop
251        let mut full_response = String::new();
252        let mut tool_calls: Vec<LlmToolCall> = Vec::new();
253        let max_iterations = self
254            .config
255            .providers
256            .get(&self.config.provider)
257            .map(|p| p.max_iterations)
258            .unwrap_or(100); // Allow enough iterations for complex tasks
259        let mut iteration = 0;
260
261        while max_iterations == 0 || iteration < max_iterations {
262            iteration += 1;
263            debug!("Agent loop iteration {}", iteration);
264
265            // Send thinking event
266            debug!("Sending Thinking event with operation_id={}", self.operation_id);
267            self.send_event(AgentEvent::Thinking {
268                operation_id: self.operation_id,
269            });
270
271            // Track timing for token usage
272            let request_start = std::time::Instant::now();
273
274            // Call LLM
275            let mut stream = self
276                .llm_client
277                .send(_messages.clone(), tool_definitions.clone())
278                .await
279                .map_err(|e| CliError::ConfigError(e.to_string()))?;
280
281            tool_calls.clear();
282            let mut current_content = String::new();
283            // Track tool calls: (id) -> (name, args)
284            let mut accumulated_calls: std::collections::HashMap<
285                String,
286                (String, serde_json::Value),
287            > = std::collections::HashMap::new();
288
289            // Process stream chunks with cancellation support
290            loop {
291                // Check for cancellation FIRST (before waiting for stream)
292                if let Some(ref token) = self.cancellation_token {
293                    if token.is_cancelled() {
294                        debug!("Operation cancelled by user (pre-stream check)");
295                        self.send_event(AgentEvent::Cancelled {
296                            operation_id: self.operation_id,
297                        });
298                        return Err(CliError::ConfigError(
299                            "Operation cancelled by user".to_string(),
300                        ));
301                    }
302                }
303
304                // Use tokio::select! to check cancellation while waiting for stream
305                // Using cancellation_token.cancelled() for immediate cancellation detection
306                let chunk_result = if let Some(ref token) = self.cancellation_token {
307                    tokio::select! {
308                        chunk = stream.next() => chunk,
309                        _ = token.cancelled() => {
310                            debug!("Operation cancelled via token while waiting for stream");
311                            self.send_event(AgentEvent::Cancelled {
312                                operation_id: self.operation_id,
313                            });
314                            return Err(CliError::ConfigError("Operation cancelled by user".to_string()));
315                        }
316                    }
317                } else {
318                    stream.next().await
319                };
320
321                let Some(chunk_result) = chunk_result else {
322                    // Stream ended
323                    break;
324                };
325
326                match chunk_result {
327                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
328                        current_content.push_str(&text);
329                        debug!(
330                            "ContentDelta: {} chars (total: {})",
331                            text.len(),
332                            current_content.len()
333                        );
334                        self.send_event(AgentEvent::ContentChunk {
335                            operation_id: self.operation_id,
336                            chunk: text,
337                        });
338                    }
339                    Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
340                        // Ignore reasoning chunks for now
341                    }
342                    Ok(ProviderResponseChunk::ToolCallDelta {
343                        id,
344                        name,
345                        arguments,
346                    }) => {
347                        debug!(
348                            "ToolCallDelta: id={}, name={}, args_len={}",
349                            id,
350                            name,
351                            arguments.to_string().len()
352                        );
353                        // Store/merge tool call arguments
354                        accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
355                    }
356                    Ok(ProviderResponseChunk::Done(usage)) => {
357                        // Track token usage
358                        let duration_ms = request_start.elapsed().as_millis() as u64;
359                        let cost =
360                            calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
361                        let _ = self.tracking_db.track_request(
362                            self.model(),
363                            usage.input_tokens,
364                            usage.output_tokens,
365                            cost,
366                            duration_ms,
367                        );
368                        // Emit token usage event for TUI display
369                        self.send_event(AgentEvent::TokenUsage {
370                            operation_id: self.operation_id,
371                            input_tokens: usage.input_tokens,
372                            output_tokens: usage.output_tokens,
373                        });
374                        break;
375                    }
376                    Err(e) => {
377                        let error_msg = format!("LLM error: {}", e);
378                        self.send_event(AgentEvent::Error {
379                            operation_id: self.operation_id,
380                            message: error_msg.clone(),
381                        });
382                        return Err(CliError::ConfigError(error_msg));
383                    }
384                }
385            }
386
387            // Convert accumulated calls to Vec<ToolCall>
388            tool_calls = accumulated_calls
389                .into_iter()
390                .map(|(id, (name, args))| LlmToolCall {
391                    id,
392                    tool_type: "function".to_string(),
393                    function: limit_llm::types::FunctionCall {
394                        name,
395                        arguments: args.to_string(),
396                    },
397                })
398                .collect();
399
400            // BUG FIX: Don't accumulate content across iterations
401            // Only store content from the current iteration
402            // If there are tool calls, we'll continue the loop and the LLM will see the tool results
403            // If there are NO tool calls, this is the final response
404            full_response = current_content.clone();
405
406            debug!(
407                "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
408                iteration,
409                current_content.len(),
410                tool_calls.len(),
411                full_response.len()
412            );
413
414            // If no tool calls, we're done
415            if tool_calls.is_empty() {
416                debug!("No tool calls, breaking loop after iteration {}", iteration);
417                break;
418            }
419
420            debug!(
421                "Tool calls found (count={}), continuing to iteration {}",
422                tool_calls.len(),
423                iteration + 1
424            );
425
426            // Execute tool calls - add assistant message with tool_calls
427            // Note: Per OpenAI API spec, when tool_calls are present, content should be null
428            let assistant_message = Message {
429                role: Role::Assistant,
430                content: None, // Don't include content when tool_calls are present
431                tool_calls: Some(tool_calls.clone()),
432                tool_call_id: None,
433            };
434            _messages.push(assistant_message);
435
436            // Convert LLM tool calls to executor tool calls
437            let executor_calls: Vec<ToolCall> = tool_calls
438                .iter()
439                .map(|tc| {
440                    let args: serde_json::Value =
441                        serde_json::from_str(&tc.function.arguments).unwrap_or_default();
442                    ToolCall::new(&tc.id, &tc.function.name, args)
443                })
444                .collect();
445
446            // Send ToolStart event for each tool BEFORE execution
447            for tc in &tool_calls {
448                let args: serde_json::Value =
449                    serde_json::from_str(&tc.function.arguments).unwrap_or_default();
450                self.send_event(AgentEvent::ToolStart {
451                    operation_id: self.operation_id,
452                    name: tc.function.name.clone(),
453                    args,
454                });
455            }
456            // Execute tools
457            let results = self.executor.execute_tools(executor_calls).await;
458
459            // Add tool results to messages (OpenAI format: role=tool, tool_call_id, content)
460            for result in results {
461                let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
462                if let Some(tool_call) = tool_call {
463                    let output_json = match &result.output {
464                        Ok(value) => {
465                            serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
466                        }
467                        Err(e) => json!({ "error": e.to_string() }).to_string(),
468                    };
469
470                    self.send_event(AgentEvent::ToolComplete {
471                        operation_id: self.operation_id,
472                        name: tool_call.function.name.clone(),
473                        result: output_json.clone(),
474                    });
475
476                    // OpenAI tool result format
477                    let tool_result_message = Message {
478                        role: Role::Tool,
479                        content: Some(output_json),
480                        tool_calls: None,
481                        tool_call_id: Some(result.call_id),
482                    };
483                    _messages.push(tool_result_message);
484                }
485            }
486        }
487
488        // If we hit max iterations, make one final request to get a response (no tools = forced text)
489        // IMPORTANT: Only do this if max_iterations > 0 (0 means unlimited, so we never "hit" the limit)
490        if max_iterations > 0 && iteration >= max_iterations && !_messages.is_empty() {
491            debug!("Making final LLM call after hitting max iterations (forcing text response)");
492
493            // Add constraint message to force text response
494            let constraint_message = Message {
495                role: Role::User,
496                content: Some(
497                    "We've reached the iteration limit. Please provide a summary of:\n\
498                    1. What you've completed so far\n\
499                    2. What remains to be done\n\
500                    3. Recommended next steps for the user to continue"
501                        .to_string(),
502                ),
503                tool_calls: None,
504                tool_call_id: None,
505            };
506            _messages.push(constraint_message);
507
508            // Send with NO tools to force text response
509            let no_tools: Vec<LlmTool> = vec![];
510            let mut stream = self
511                .llm_client
512                .send(_messages.clone(), no_tools)
513                .await
514                .map_err(|e| CliError::ConfigError(e.to_string()))?;
515
516            // BUG FIX: Replace full_response instead of appending
517            full_response.clear();
518            loop {
519                // Check for cancellation FIRST (before waiting for stream)
520                if let Some(ref token) = self.cancellation_token {
521                    if token.is_cancelled() {
522                        debug!("Operation cancelled by user in final loop (pre-stream check)");
523                        self.send_event(AgentEvent::Cancelled {
524                            operation_id: self.operation_id,
525                        });
526                        return Err(CliError::ConfigError(
527                            "Operation cancelled by user".to_string(),
528                        ));
529                    }
530                }
531
532                // Use tokio::select! to check cancellation while waiting for stream
533                // Using cancellation_token.cancelled() for immediate cancellation detection
534                let chunk_result = if let Some(ref token) = self.cancellation_token {
535                    tokio::select! {
536                        chunk = stream.next() => chunk,
537                        _ = token.cancelled() => {
538                            debug!("Operation cancelled via token while waiting for stream");
539                            self.send_event(AgentEvent::Cancelled {
540                                operation_id: self.operation_id,
541                            });
542                            return Err(CliError::ConfigError("Operation cancelled by user".to_string()));
543                        }
544                    }
545                } else {
546                    stream.next().await
547                };
548
549                let Some(chunk_result) = chunk_result else {
550                    // Stream ended
551                    break;
552                };
553
554                match chunk_result {
555                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
556                        full_response.push_str(&text);
557                        self.send_event(AgentEvent::ContentChunk {
558                            operation_id: self.operation_id,
559                            chunk: text,
560                        });
561                    }
562                    Ok(ProviderResponseChunk::Done(_)) => {
563                        break;
564                    }
565                    Err(e) => {
566                        debug!("Error in final LLM call: {}", e);
567                        break;
568                    }
569                    _ => {}
570                }
571            }
572        }
573
574        // IMPORTANT: Add final assistant response to message history for session persistence
575        // This is crucial for session export/share to work correctly
576        // Only add if we have content AND we haven't already added this response
577        if !full_response.is_empty() {
578            // Find the last assistant message and check if it has content
579            // If it has tool_calls but no content, UPDATE it instead of adding a new one
580            // This prevents accumulation of empty assistant messages in the history
581            let last_assistant_idx = _messages.iter().rposition(|m| m.role == Role::Assistant);
582
583            if let Some(idx) = last_assistant_idx {
584                let last_assistant = &mut _messages[idx];
585
586                // If the last assistant message has no content (tool_calls only), update it
587                if last_assistant.content.is_none()
588                    || last_assistant
589                        .content
590                        .as_ref()
591                        .map(|c| c.is_empty())
592                        .unwrap_or(true)
593                {
594                    last_assistant.content = Some(full_response.clone());
595                    debug!("Updated last assistant message with final response content");
596                } else {
597                    // Last assistant already has content, this shouldn't happen normally
598                    // but we add a new message to be safe
599                    debug!("Last assistant already has content, adding new message");
600                    let final_assistant_message = Message {
601                        role: Role::Assistant,
602                        content: Some(full_response.clone()),
603                        tool_calls: None,
604                        tool_call_id: None,
605                    };
606                    _messages.push(final_assistant_message);
607                }
608            } else {
609                // No assistant message found, add a new one
610                debug!("No assistant message found, adding new message");
611                let final_assistant_message = Message {
612                    role: Role::Assistant,
613                    content: Some(full_response.clone()),
614                    tool_calls: None,
615                    tool_call_id: None,
616                };
617                _messages.push(final_assistant_message);
618            }
619        }
620
621        self.send_event(AgentEvent::Done {
622            operation_id: self.operation_id,
623        });
624        Ok(full_response)
625    }
626
627    /// Get tool definitions formatted for the LLM
628    pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
629        self.tool_names
630            .iter()
631            .map(|name| {
632                let (description, parameters) = Self::get_tool_schema(name);
633                LlmTool {
634                    tool_type: "function".to_string(),
635                    function: limit_llm::types::ToolFunction {
636                        name: name.to_string(),
637                        description,
638                        parameters,
639                    },
640                }
641            })
642            .collect()
643    }
644
645    /// Get the schema (description and parameters) for a tool
646    fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
647        match name {
648            "file_read" => (
649                "Read the contents of a file".to_string(),
650                json!({
651                    "type": "object",
652                    "properties": {
653                        "path": {
654                            "type": "string",
655                            "description": "Path to the file to read"
656                        }
657                    },
658                    "required": ["path"]
659                }),
660            ),
661            "file_write" => (
662                "Write content to a file, creating parent directories if needed".to_string(),
663                json!({
664                    "type": "object",
665                    "properties": {
666                        "path": {
667                            "type": "string",
668                            "description": "Path to the file to write"
669                        },
670                        "content": {
671                            "type": "string",
672                            "description": "Content to write to the file"
673                        }
674                    },
675                    "required": ["path", "content"]
676                }),
677            ),
678            "file_edit" => (
679                "Replace text in a file with new text".to_string(),
680                json!({
681                    "type": "object",
682                    "properties": {
683                        "path": {
684                            "type": "string",
685                            "description": "Path to the file to edit"
686                        },
687                        "old_text": {
688                            "type": "string",
689                            "description": "Text to find and replace"
690                        },
691                        "new_text": {
692                            "type": "string",
693                            "description": "New text to replace with"
694                        }
695                    },
696                    "required": ["path", "old_text", "new_text"]
697                }),
698            ),
699            "bash" => (
700                "Execute a bash command in a shell".to_string(),
701                json!({
702                    "type": "object",
703                    "properties": {
704                        "command": {
705                            "type": "string",
706                            "description": "Bash command to execute"
707                        },
708                        "workdir": {
709                            "type": "string",
710                            "description": "Working directory (default: current directory)"
711                        },
712                        "timeout": {
713                            "type": "integer",
714                            "description": "Timeout in seconds (default: 60)"
715                        }
716                    },
717                    "required": ["command"]
718                }),
719            ),
720            "git_status" => (
721                "Get git repository status".to_string(),
722                json!({
723                    "type": "object",
724                    "properties": {},
725                    "required": []
726                }),
727            ),
728            "git_diff" => (
729                "Get git diff".to_string(),
730                json!({
731                    "type": "object",
732                    "properties": {},
733                    "required": []
734                }),
735            ),
736            "git_log" => (
737                "Get git commit log".to_string(),
738                json!({
739                    "type": "object",
740                    "properties": {
741                        "count": {
742                            "type": "integer",
743                            "description": "Number of commits to show (default: 10)"
744                        }
745                    },
746                    "required": []
747                }),
748            ),
749            "git_add" => (
750                "Add files to git staging area".to_string(),
751                json!({
752                    "type": "object",
753                    "properties": {
754                        "files": {
755                            "type": "array",
756                            "items": {"type": "string"},
757                            "description": "List of file paths to add"
758                        }
759                    },
760                    "required": ["files"]
761                }),
762            ),
763            "git_commit" => (
764                "Create a git commit".to_string(),
765                json!({
766                    "type": "object",
767                    "properties": {
768                        "message": {
769                            "type": "string",
770                            "description": "Commit message"
771                        }
772                    },
773                    "required": ["message"]
774                }),
775            ),
776            "git_push" => (
777                "Push commits to remote repository".to_string(),
778                json!({
779                    "type": "object",
780                    "properties": {
781                        "remote": {
782                            "type": "string",
783                            "description": "Remote name (default: origin)"
784                        },
785                        "branch": {
786                            "type": "string",
787                            "description": "Branch name (default: current branch)"
788                        }
789                    },
790                    "required": []
791                }),
792            ),
793            "git_pull" => (
794                "Pull changes from remote repository".to_string(),
795                json!({
796                    "type": "object",
797                    "properties": {
798                        "remote": {
799                            "type": "string",
800                            "description": "Remote name (default: origin)"
801                        },
802                        "branch": {
803                            "type": "string",
804                            "description": "Branch name (default: current branch)"
805                        }
806                    },
807                    "required": []
808                }),
809            ),
810            "git_clone" => (
811                "Clone a git repository".to_string(),
812                json!({
813                    "type": "object",
814                    "properties": {
815                        "url": {
816                            "type": "string",
817                            "description": "Repository URL to clone"
818                        },
819                        "directory": {
820                            "type": "string",
821                            "description": "Directory to clone into (optional)"
822                        }
823                    },
824                    "required": ["url"]
825                }),
826            ),
827            "grep" => (
828                "Search for text patterns in files using regex".to_string(),
829                json!({
830                    "type": "object",
831                    "properties": {
832                        "pattern": {
833                            "type": "string",
834                            "description": "Regex pattern to search for"
835                        },
836                        "path": {
837                            "type": "string",
838                            "description": "Path to search in (default: current directory)"
839                        }
840                    },
841                    "required": ["pattern"]
842                }),
843            ),
844            "ast_grep" => (
845                "Search code using AST patterns (structural code matching)".to_string(),
846                json!({
847                    "type": "object",
848                    "properties": {
849                        "pattern": {
850                            "type": "string",
851                            "description": "AST pattern to match"
852                        },
853                        "language": {
854                            "type": "string",
855                            "description": "Programming language (rust, typescript, python)"
856                        },
857                        "path": {
858                            "type": "string",
859                            "description": "Path to search in (default: current directory)"
860                        }
861                    },
862                    "required": ["pattern", "language"]
863                }),
864            ),
865            "lsp" => (
866                "Perform Language Server Protocol operations (goto_definition, find_references)"
867                    .to_string(),
868                json!({
869                    "type": "object",
870                    "properties": {
871                        "command": {
872                            "type": "string",
873                            "description": "LSP command: goto_definition or find_references"
874                        },
875                        "file_path": {
876                            "type": "string",
877                            "description": "Path to the file"
878                        },
879                        "position": {
880                            "type": "object",
881                            "description": "Position in the file (line, character)",
882                            "properties": {
883                                "line": {"type": "integer"},
884                                "character": {"type": "integer"}
885                            },
886                            "required": ["line", "character"]
887                        }
888                    },
889                    "required": ["command", "file_path", "position"]
890                }),
891            ),
892            "web_search" => (
893                format!("Search the web using Exa AI. Returns results with titles, URLs, and content snippets. Use for current information beyond knowledge cutoff. The current year is {} - use this year when searching for recent information.", chrono::Local::now().year()),
894                json!({
895                    "type": "object",
896                    "properties": {
897                        "query": {
898                            "type": "string",
899                            "description": format!("Search query. Be specific for better results (e.g., 'Rust async tutorial {}' rather than 'Rust')", chrono::Local::now().year())
900                        },
901                        "numResults": {
902                            "type": "integer",
903                            "description": "Number of results to return (default: 8, max: 20)",
904                            "default": 8
905                        }
906                    },
907                    "required": ["query"]
908                }),
909            ),
910            "web_fetch" => (
911                "Fetch content from a URL. Converts HTML to markdown format by default. Use when user provides a URL or after web_search to read full content of a specific result.".to_string(),
912                json!({
913                    "type": "object",
914                    "properties": {
915                        "url": {
916                            "type": "string",
917                            "description": "URL to fetch (must start with http:// or https://)"
918                        },
919                        "format": {
920                            "type": "string",
921                            "enum": ["markdown", "text", "html"],
922                            "default": "markdown",
923                            "description": "Output format (default: markdown)"
924                        }
925                    },
926                    "required": ["url"]
927                }),
928            ),
929            _ => (
930                format!("Tool: {}", name),
931                json!({
932                    "type": "object",
933                    "properties": {},
934                    "required": []
935                }),
936            ),
937        }
938    }
939
940    /// Send an event through the event channel
941    fn send_event(&self, event: AgentEvent) {
942        if let Some(ref tx) = self.event_tx {
943            let _ = tx.send(event);
944        }
945    }
946
947    /// Check if the bridge is ready to process messages
948    #[allow(dead_code)]
949    pub fn is_ready(&self) -> bool {
950        self.config
951            .providers
952            .get(&self.config.provider)
953            .map(|p| p.api_key_or_env(&self.config.provider).is_some())
954            .unwrap_or(false)
955    }
956
957    /// Get the current model name
958    pub fn model(&self) -> &str {
959        self.config
960            .providers
961            .get(&self.config.provider)
962            .map(|p| p.model.as_str())
963            .unwrap_or("")
964    }
965
966    /// Get the max tokens setting
967    pub fn max_tokens(&self) -> u32 {
968        self.config
969            .providers
970            .get(&self.config.provider)
971            .map(|p| p.max_tokens)
972            .unwrap_or(4096)
973    }
974
975    /// Get the timeout setting
976    pub fn timeout(&self) -> u64 {
977        self.config
978            .providers
979            .get(&self.config.provider)
980            .map(|p| p.timeout)
981            .unwrap_or(60)
982    }
983}
/// Estimate the dollar cost of a request from its token counts.
///
/// Prices are expressed per one million tokens and are looked up by exact
/// model name. A model that is not in the table is priced at zero, which
/// effectively disables cost tracking for it.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    /// Number of tokens each listed price applies to.
    const TOKENS_PER_PRICE_UNIT: f64 = 1_000_000.0;

    /// `(model name, input price, output price)`, prices per 1M tokens.
    const PRICE_TABLE: [(&str, f64, f64); 5] = [
        // Claude 3.5 Sonnet: $3/1M input, $15/1M output
        ("claude-3-5-sonnet-20241022", 3.0, 15.0),
        ("claude-3-5-sonnet", 3.0, 15.0),
        // GPT-4: $30/1M input, $60/1M output
        ("gpt-4", 30.0, 60.0),
        // GPT-4 Turbo: $10/1M input, $30/1M output
        ("gpt-4-turbo", 10.0, 30.0),
        ("gpt-4-turbo-preview", 10.0, 30.0),
    ];

    // Unlisted models fall back to a zero price for both directions.
    let (input_rate, output_rate) = PRICE_TABLE
        .iter()
        .find(|(known, _, _)| *known == model)
        .map_or((0.0, 0.0), |&(_, input, output)| (input, output));

    let input_cost = input_tokens as f64 * input_rate / TOKENS_PER_PRICE_UNIT;
    let output_cost = output_tokens as f64 * output_rate / TOKENS_PER_PRICE_UNIT;
    input_cost + output_cost
}
999
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Provider name that every test registers and selects.
    const PROVIDER: &str = "anthropic";

    /// Build a single-provider configuration; only the API key varies between tests.
    fn config_with_api_key(api_key: Option<&str>) -> LlmConfig {
        let provider = ProviderConfig {
            api_key: api_key.map(String::from),
            model: "claude-3-5-sonnet-20241022".to_string(),
            base_url: None,
            max_tokens: 4096,
            timeout: 60,
            max_iterations: 100,
            thinking_enabled: false,
            clear_thinking: true,
        };

        LlmConfig {
            provider: PROVIDER.to_string(),
            providers: HashMap::from([(PROVIDER.to_string(), provider)]),
        }
    }

    /// A bridge built from a configuration with an API key reports ready.
    #[tokio::test]
    async fn test_agent_bridge_new() {
        let bridge = AgentBridge::new(config_with_api_key(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }

    /// Construction is expected to fail when the provider has no API key.
    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        let outcome = AgentBridge::new(config_with_api_key(None));
        assert!(outcome.is_err());
    }

    /// The bridge exposes one definition per registered tool, with usable schemas.
    #[tokio::test]
    async fn test_get_tool_definitions() {
        let bridge = AgentBridge::new(config_with_api_key(Some("test-key"))).unwrap();
        let definitions = bridge.get_tool_definitions();

        assert_eq!(definitions.len(), 17);

        // file_read: type, name and description
        let file_read = definitions
            .iter()
            .find(|def| def.function.name == "file_read")
            .unwrap();
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        // bash: name and required parameters
        let bash = definitions
            .iter()
            .find(|def| def.function.name == "bash")
            .unwrap();
        assert_eq!(bash.function.name, "bash");
        let bash_required = bash.function.parameters["required"].as_array().unwrap();
        assert!(bash_required.contains(&"command".into()));
    }

    /// Schemas for known tools carry their parameters; unknown tools get a generic one.
    #[test]
    fn test_get_tool_schema() {
        let (file_read_desc, file_read_params) = AgentBridge::get_tool_schema("file_read");
        assert!(file_read_desc.contains("Read"));
        assert_eq!(file_read_params["properties"]["path"]["type"], "string");
        let file_read_required = file_read_params["required"].as_array().unwrap();
        assert!(file_read_required.contains(&"path".into()));

        let (bash_desc, bash_params) = AgentBridge::get_tool_schema("bash");
        assert!(bash_desc.contains("bash"));
        assert_eq!(bash_params["properties"]["command"]["type"], "string");

        let (unknown_desc, _unknown_params) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(unknown_desc.contains("unknown_tool"));
    }

    /// Same readiness check as above, but run as a plain (non-tokio) test.
    #[test]
    fn test_is_ready() {
        let config_with_key = config_with_api_key(Some("test-key"));

        let bridge = AgentBridge::new(config_with_key).unwrap();
        assert!(bridge.is_ready());
    }
}