//! Agent bridge (`limit_cli/agent_bridge.rs`): connects the limit-cli REPL
//! to the limit-agent tool executor and the limit-llm provider layer.
1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::{
4    AstGrepTool, BashTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool, GitCloneTool,
5    GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool, GrepTool,
6    LspTool, WebFetchTool, WebSearchTool,
7};
8use chrono::Datelike;
9use futures::StreamExt;
10use limit_agent::executor::{ToolCall, ToolExecutor};
11use limit_agent::registry::ToolRegistry;
12use limit_llm::providers::LlmProvider;
13use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
14use limit_llm::ProviderFactory;
15use limit_llm::ProviderResponseChunk;
16use limit_llm::TrackingDb;
17use serde_json::json;
18use tokio::sync::mpsc;
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, instrument};
21
/// Event types for streaming from agent to REPL
///
/// Every variant carries the `operation_id` of the agent operation it
/// belongs to, so the REPL can correlate events with in-flight requests.
#[derive(Debug, Clone)]
#[allow(dead_code)] // NOTE(review): presumably some variants are not yet consumed by the REPL — confirm
pub enum AgentEvent {
    /// The agent has dispatched a request to the LLM and is waiting.
    Thinking {
        operation_id: u64,
    },
    /// A tool is about to be executed with the given arguments.
    ToolStart {
        operation_id: u64,
        name: String,
        args: serde_json::Value,
    },
    /// A tool finished; `result` holds the JSON-serialized output (or error).
    ToolComplete {
        operation_id: u64,
        name: String,
        result: String,
    },
    /// A streamed fragment of assistant text content.
    ContentChunk {
        operation_id: u64,
        chunk: String,
    },
    /// The operation completed successfully.
    Done {
        operation_id: u64,
    },
    /// The operation was cancelled by the user.
    Cancelled {
        operation_id: u64,
    },
    /// The operation failed; `message` is a human-readable description.
    Error {
        operation_id: u64,
        message: String,
    },
    /// Token usage reported by the provider for a single LLM request.
    TokenUsage {
        operation_id: u64,
        input_tokens: u64,
        output_tokens: u64,
    },
}
59
/// Bridge connecting limit-cli REPL to limit-agent executor and limit-llm client
///
/// Owns the provider client, the tool executor, and the per-operation state
/// (event channel, cancellation token, operation id) used while streaming.
pub struct AgentBridge {
    /// LLM client for communicating with LLM providers
    llm_client: Box<dyn LlmProvider>,
    /// Tool executor for running tool calls
    executor: ToolExecutor,
    /// List of registered tool names (kept in sync with `register_tools`)
    tool_names: Vec<&'static str>,
    /// Configuration loaded from ~/.limit/config.toml
    config: limit_llm::Config,
    /// Event sender for streaming events to REPL (None until `set_event_tx`)
    event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
    /// Token usage tracking database
    tracking_db: TrackingDb,
    /// Cancellation token for aborting current operation (None when idle)
    cancellation_token: Option<CancellationToken>,
    /// Current operation ID for event tracking
    operation_id: u64,
}
79
80impl AgentBridge {
81    /// Create a new AgentBridge with the given configuration
82    ///
83    /// # Arguments
84    /// * `config` - LLM configuration (API key, model, etc.)
85    ///
86    /// # Returns
87    /// A new AgentBridge instance or an error if initialization fails
88    pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
89        let llm_client = ProviderFactory::create_provider(&config)
90            .map_err(|e| CliError::ConfigError(e.to_string()))?;
91
92        let mut tool_registry = ToolRegistry::new();
93        Self::register_tools(&mut tool_registry);
94
95        // Create executor (which takes ownership of registry as Arc)
96        let executor = ToolExecutor::new(tool_registry);
97
98        // Generate tool definitions before giving ownership to executor
99        let tool_names = vec![
100            "file_read",
101            "file_write",
102            "file_edit",
103            "bash",
104            "git_status",
105            "git_diff",
106            "git_log",
107            "git_add",
108            "git_commit",
109            "git_push",
110            "git_pull",
111            "git_clone",
112            "grep",
113            "ast_grep",
114            "lsp",
115            "web_search",
116            "web_fetch",
117        ];
118
119        Ok(Self {
120            llm_client,
121            executor,
122            tool_names,
123            config,
124            event_tx: None,
125            tracking_db: TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?,
126            cancellation_token: None,
127            operation_id: 0,
128        })
129    }
130
131    /// Set the event channel sender for streaming events
132    pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
133        self.event_tx = Some(tx);
134    }
135
136    /// Set the cancellation token and operation ID for this operation
137    pub fn set_cancellation_token(&mut self, token: CancellationToken, operation_id: u64) {
138        debug!("set_cancellation_token: operation_id={}", operation_id);
139        self.cancellation_token = Some(token);
140        self.operation_id = operation_id;
141    }
142
143    /// Clear the cancellation token
144    pub fn clear_cancellation_token(&mut self) {
145        self.cancellation_token = None;
146    }
147
148    /// Register all CLI tools into the tool registry
149    fn register_tools(registry: &mut ToolRegistry) {
150        // File tools
151        registry
152            .register(FileReadTool::new())
153            .expect("Failed to register file_read");
154        registry
155            .register(FileWriteTool::new())
156            .expect("Failed to register file_write");
157        registry
158            .register(FileEditTool::new())
159            .expect("Failed to register file_edit");
160
161        // Bash tool
162        registry
163            .register(BashTool::new())
164            .expect("Failed to register bash");
165
166        // Git tools
167        registry
168            .register(GitStatusTool::new())
169            .expect("Failed to register git_status");
170        registry
171            .register(GitDiffTool::new())
172            .expect("Failed to register git_diff");
173        registry
174            .register(GitLogTool::new())
175            .expect("Failed to register git_log");
176        registry
177            .register(GitAddTool::new())
178            .expect("Failed to register git_add");
179        registry
180            .register(GitCommitTool::new())
181            .expect("Failed to register git_commit");
182        registry
183            .register(GitPushTool::new())
184            .expect("Failed to register git_push");
185        registry
186            .register(GitPullTool::new())
187            .expect("Failed to register git_pull");
188        registry
189            .register(GitCloneTool::new())
190            .expect("Failed to register git_clone");
191
192        // Analysis tools
193        registry
194            .register(GrepTool::new())
195            .expect("Failed to register grep");
196        registry
197            .register(AstGrepTool::new())
198            .expect("Failed to register ast_grep");
199        registry
200            .register(LspTool::new())
201            .expect("Failed to register lsp");
202
203        // Web tools
204        registry
205            .register(WebSearchTool::new())
206            .expect("Failed to register web_search");
207        registry
208            .register(WebFetchTool::new())
209            .expect("Failed to register web_fetch");
210    }
211
212    /// Process a user message through the LLM and execute any tool calls
213    ///
214    /// # Arguments
215    /// * `user_input` - The user's message to process
216    /// * `messages` - The conversation history (will be updated in place)
217    ///
218    /// # Returns
219    /// The final response from the LLM or an error
220    #[instrument(skip(self, _messages))]
221    pub async fn process_message(
222        &mut self,
223        user_input: &str,
224        _messages: &mut Vec<Message>,
225    ) -> Result<String, CliError> {
226        // Add system message if this is the first message in the conversation
227        // Note: Some providers (z.ai) don't support system role, but OpenAI-compatible APIs generally do
228        if _messages.is_empty() {
229            let system_message = Message {
230                role: Role::System,
231                content: Some(SYSTEM_PROMPT.to_string()),
232                tool_calls: None,
233                tool_call_id: None,
234            };
235            _messages.push(system_message);
236        }
237
238        // Add user message to history
239        let user_message = Message {
240            role: Role::User,
241            content: Some(user_input.to_string()),
242            tool_calls: None,
243            tool_call_id: None,
244        };
245        _messages.push(user_message);
246
247        // Get tool definitions
248        let tool_definitions = self.get_tool_definitions();
249
250        // Main processing loop
251        let mut full_response = String::new();
252        let mut tool_calls: Vec<LlmToolCall> = Vec::new();
253        let max_iterations = self
254            .config
255            .providers
256            .get(&self.config.provider)
257            .map(|p| p.max_iterations)
258            .unwrap_or(100); // Allow enough iterations for complex tasks
259        let mut iteration = 0;
260
261        while max_iterations == 0 || iteration < max_iterations {
262            iteration += 1;
263            debug!("Agent loop iteration {}", iteration);
264
265            // Send thinking event
266            debug!(
267                "Sending Thinking event with operation_id={}",
268                self.operation_id
269            );
270            self.send_event(AgentEvent::Thinking {
271                operation_id: self.operation_id,
272            });
273
274            // Track timing for token usage
275            let request_start = std::time::Instant::now();
276
277            // Call LLM
278            let mut stream = self
279                .llm_client
280                .send(_messages.clone(), tool_definitions.clone())
281                .await
282                .map_err(|e| CliError::ConfigError(e.to_string()))?;
283
284            tool_calls.clear();
285            let mut current_content = String::new();
286            // Track tool calls: (id) -> (name, args)
287            let mut accumulated_calls: std::collections::HashMap<
288                String,
289                (String, serde_json::Value),
290            > = std::collections::HashMap::new();
291
292            // Process stream chunks with cancellation support
293            loop {
294                // Check for cancellation FIRST (before waiting for stream)
295                if let Some(ref token) = self.cancellation_token {
296                    if token.is_cancelled() {
297                        debug!("Operation cancelled by user (pre-stream check)");
298                        self.send_event(AgentEvent::Cancelled {
299                            operation_id: self.operation_id,
300                        });
301                        return Err(CliError::ConfigError(
302                            "Operation cancelled by user".to_string(),
303                        ));
304                    }
305                }
306
307                // Use tokio::select! to check cancellation while waiting for stream
308                // Using cancellation_token.cancelled() for immediate cancellation detection
309                let chunk_result = if let Some(ref token) = self.cancellation_token {
310                    tokio::select! {
311                        chunk = stream.next() => chunk,
312                        _ = token.cancelled() => {
313                            debug!("Operation cancelled via token while waiting for stream");
314                            self.send_event(AgentEvent::Cancelled {
315                                operation_id: self.operation_id,
316                            });
317                            return Err(CliError::ConfigError("Operation cancelled by user".to_string()));
318                        }
319                    }
320                } else {
321                    stream.next().await
322                };
323
324                let Some(chunk_result) = chunk_result else {
325                    // Stream ended
326                    break;
327                };
328
329                match chunk_result {
330                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
331                        current_content.push_str(&text);
332                        debug!(
333                            "ContentDelta: {} chars (total: {})",
334                            text.len(),
335                            current_content.len()
336                        );
337                        self.send_event(AgentEvent::ContentChunk {
338                            operation_id: self.operation_id,
339                            chunk: text,
340                        });
341                    }
342                    Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
343                        // Ignore reasoning chunks for now
344                    }
345                    Ok(ProviderResponseChunk::ToolCallDelta {
346                        id,
347                        name,
348                        arguments,
349                    }) => {
350                        debug!(
351                            "ToolCallDelta: id={}, name={}, args_len={}",
352                            id,
353                            name,
354                            arguments.to_string().len()
355                        );
356                        // Store/merge tool call arguments
357                        accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
358                    }
359                    Ok(ProviderResponseChunk::Done(usage)) => {
360                        // Track token usage
361                        let duration_ms = request_start.elapsed().as_millis() as u64;
362                        let cost =
363                            calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
364                        let _ = self.tracking_db.track_request(
365                            self.model(),
366                            usage.input_tokens,
367                            usage.output_tokens,
368                            cost,
369                            duration_ms,
370                        );
371                        // Emit token usage event for TUI display
372                        self.send_event(AgentEvent::TokenUsage {
373                            operation_id: self.operation_id,
374                            input_tokens: usage.input_tokens,
375                            output_tokens: usage.output_tokens,
376                        });
377                        break;
378                    }
379                    Err(e) => {
380                        let error_msg = format!("LLM error: {}", e);
381                        self.send_event(AgentEvent::Error {
382                            operation_id: self.operation_id,
383                            message: error_msg.clone(),
384                        });
385                        return Err(CliError::ConfigError(error_msg));
386                    }
387                }
388            }
389
390            // Convert accumulated calls to Vec<ToolCall>
391            tool_calls = accumulated_calls
392                .into_iter()
393                .map(|(id, (name, args))| LlmToolCall {
394                    id,
395                    tool_type: "function".to_string(),
396                    function: limit_llm::types::FunctionCall {
397                        name,
398                        arguments: args.to_string(),
399                    },
400                })
401                .collect();
402
403            // BUG FIX: Don't accumulate content across iterations
404            // Only store content from the current iteration
405            // If there are tool calls, we'll continue the loop and the LLM will see the tool results
406            // If there are NO tool calls, this is the final response
407            full_response = current_content.clone();
408
409            debug!(
410                "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
411                iteration,
412                current_content.len(),
413                tool_calls.len(),
414                full_response.len()
415            );
416
417            // If no tool calls, we're done
418            if tool_calls.is_empty() {
419                debug!("No tool calls, breaking loop after iteration {}", iteration);
420                break;
421            }
422
423            debug!(
424                "Tool calls found (count={}), continuing to iteration {}",
425                tool_calls.len(),
426                iteration + 1
427            );
428
429            // Execute tool calls - add assistant message with tool_calls
430            // Note: Per OpenAI API spec, when tool_calls are present, content should be null
431            let assistant_message = Message {
432                role: Role::Assistant,
433                content: None, // Don't include content when tool_calls are present
434                tool_calls: Some(tool_calls.clone()),
435                tool_call_id: None,
436            };
437            _messages.push(assistant_message);
438
439            // Convert LLM tool calls to executor tool calls
440            let executor_calls: Vec<ToolCall> = tool_calls
441                .iter()
442                .map(|tc| {
443                    let args: serde_json::Value =
444                        serde_json::from_str(&tc.function.arguments).unwrap_or_default();
445                    ToolCall::new(&tc.id, &tc.function.name, args)
446                })
447                .collect();
448
449            // Send ToolStart event for each tool BEFORE execution
450            for tc in &tool_calls {
451                let args: serde_json::Value =
452                    serde_json::from_str(&tc.function.arguments).unwrap_or_default();
453                self.send_event(AgentEvent::ToolStart {
454                    operation_id: self.operation_id,
455                    name: tc.function.name.clone(),
456                    args,
457                });
458            }
459            // Execute tools
460            let results = self.executor.execute_tools(executor_calls).await;
461
462            // Add tool results to messages (OpenAI format: role=tool, tool_call_id, content)
463            for result in results {
464                let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
465                if let Some(tool_call) = tool_call {
466                    let output_json = match &result.output {
467                        Ok(value) => {
468                            serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
469                        }
470                        Err(e) => json!({ "error": e.to_string() }).to_string(),
471                    };
472
473                    self.send_event(AgentEvent::ToolComplete {
474                        operation_id: self.operation_id,
475                        name: tool_call.function.name.clone(),
476                        result: output_json.clone(),
477                    });
478
479                    // OpenAI tool result format
480                    let tool_result_message = Message {
481                        role: Role::Tool,
482                        content: Some(output_json),
483                        tool_calls: None,
484                        tool_call_id: Some(result.call_id),
485                    };
486                    _messages.push(tool_result_message);
487                }
488            }
489        }
490
491        // If we hit max iterations, make one final request to get a response (no tools = forced text)
492        // IMPORTANT: Only do this if max_iterations > 0 (0 means unlimited, so we never "hit" the limit)
493        if max_iterations > 0 && iteration >= max_iterations && !_messages.is_empty() {
494            debug!("Making final LLM call after hitting max iterations (forcing text response)");
495
496            // Add constraint message to force text response
497            let constraint_message = Message {
498                role: Role::User,
499                content: Some(
500                    "We've reached the iteration limit. Please provide a summary of:\n\
501                    1. What you've completed so far\n\
502                    2. What remains to be done\n\
503                    3. Recommended next steps for the user to continue"
504                        .to_string(),
505                ),
506                tool_calls: None,
507                tool_call_id: None,
508            };
509            _messages.push(constraint_message);
510
511            // Send with NO tools to force text response
512            let no_tools: Vec<LlmTool> = vec![];
513            let mut stream = self
514                .llm_client
515                .send(_messages.clone(), no_tools)
516                .await
517                .map_err(|e| CliError::ConfigError(e.to_string()))?;
518
519            // BUG FIX: Replace full_response instead of appending
520            full_response.clear();
521            loop {
522                // Check for cancellation FIRST (before waiting for stream)
523                if let Some(ref token) = self.cancellation_token {
524                    if token.is_cancelled() {
525                        debug!("Operation cancelled by user in final loop (pre-stream check)");
526                        self.send_event(AgentEvent::Cancelled {
527                            operation_id: self.operation_id,
528                        });
529                        return Err(CliError::ConfigError(
530                            "Operation cancelled by user".to_string(),
531                        ));
532                    }
533                }
534
535                // Use tokio::select! to check cancellation while waiting for stream
536                // Using cancellation_token.cancelled() for immediate cancellation detection
537                let chunk_result = if let Some(ref token) = self.cancellation_token {
538                    tokio::select! {
539                        chunk = stream.next() => chunk,
540                        _ = token.cancelled() => {
541                            debug!("Operation cancelled via token while waiting for stream");
542                            self.send_event(AgentEvent::Cancelled {
543                                operation_id: self.operation_id,
544                            });
545                            return Err(CliError::ConfigError("Operation cancelled by user".to_string()));
546                        }
547                    }
548                } else {
549                    stream.next().await
550                };
551
552                let Some(chunk_result) = chunk_result else {
553                    // Stream ended
554                    break;
555                };
556
557                match chunk_result {
558                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
559                        full_response.push_str(&text);
560                        self.send_event(AgentEvent::ContentChunk {
561                            operation_id: self.operation_id,
562                            chunk: text,
563                        });
564                    }
565                    Ok(ProviderResponseChunk::Done(_)) => {
566                        break;
567                    }
568                    Err(e) => {
569                        debug!("Error in final LLM call: {}", e);
570                        break;
571                    }
572                    _ => {}
573                }
574            }
575        }
576
577        // IMPORTANT: Add final assistant response to message history for session persistence
578        // This is crucial for session export/share to work correctly
579        // Only add if we have content AND we haven't already added this response
580        if !full_response.is_empty() {
581            // Find the last assistant message and check if it has content
582            // If it has tool_calls but no content, UPDATE it instead of adding a new one
583            // This prevents accumulation of empty assistant messages in the history
584            let last_assistant_idx = _messages.iter().rposition(|m| m.role == Role::Assistant);
585
586            if let Some(idx) = last_assistant_idx {
587                let last_assistant = &mut _messages[idx];
588
589                // If the last assistant message has no content (tool_calls only), update it
590                if last_assistant.content.is_none()
591                    || last_assistant
592                        .content
593                        .as_ref()
594                        .map(|c| c.is_empty())
595                        .unwrap_or(true)
596                {
597                    last_assistant.content = Some(full_response.clone());
598                    debug!("Updated last assistant message with final response content");
599                } else {
600                    // Last assistant already has content, this shouldn't happen normally
601                    // but we add a new message to be safe
602                    debug!("Last assistant already has content, adding new message");
603                    let final_assistant_message = Message {
604                        role: Role::Assistant,
605                        content: Some(full_response.clone()),
606                        tool_calls: None,
607                        tool_call_id: None,
608                    };
609                    _messages.push(final_assistant_message);
610                }
611            } else {
612                // No assistant message found, add a new one
613                debug!("No assistant message found, adding new message");
614                let final_assistant_message = Message {
615                    role: Role::Assistant,
616                    content: Some(full_response.clone()),
617                    tool_calls: None,
618                    tool_call_id: None,
619                };
620                _messages.push(final_assistant_message);
621            }
622        }
623
624        self.send_event(AgentEvent::Done {
625            operation_id: self.operation_id,
626        });
627        Ok(full_response)
628    }
629
630    /// Get tool definitions formatted for the LLM
631    pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
632        self.tool_names
633            .iter()
634            .map(|name| {
635                let (description, parameters) = Self::get_tool_schema(name);
636                LlmTool {
637                    tool_type: "function".to_string(),
638                    function: limit_llm::types::ToolFunction {
639                        name: name.to_string(),
640                        description,
641                        parameters,
642                    },
643                }
644            })
645            .collect()
646    }
647
648    /// Get the schema (description and parameters) for a tool
649    fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
650        match name {
651            "file_read" => (
652                "Read the contents of a file".to_string(),
653                json!({
654                    "type": "object",
655                    "properties": {
656                        "path": {
657                            "type": "string",
658                            "description": "Path to the file to read"
659                        }
660                    },
661                    "required": ["path"]
662                }),
663            ),
664            "file_write" => (
665                "Write content to a file, creating parent directories if needed".to_string(),
666                json!({
667                    "type": "object",
668                    "properties": {
669                        "path": {
670                            "type": "string",
671                            "description": "Path to the file to write"
672                        },
673                        "content": {
674                            "type": "string",
675                            "description": "Content to write to the file"
676                        }
677                    },
678                    "required": ["path", "content"]
679                }),
680            ),
681            "file_edit" => (
682                "Replace text in a file with new text".to_string(),
683                json!({
684                    "type": "object",
685                    "properties": {
686                        "path": {
687                            "type": "string",
688                            "description": "Path to the file to edit"
689                        },
690                        "old_text": {
691                            "type": "string",
692                            "description": "Text to find and replace"
693                        },
694                        "new_text": {
695                            "type": "string",
696                            "description": "New text to replace with"
697                        }
698                    },
699                    "required": ["path", "old_text", "new_text"]
700                }),
701            ),
702            "bash" => (
703                "Execute a bash command in a shell".to_string(),
704                json!({
705                    "type": "object",
706                    "properties": {
707                        "command": {
708                            "type": "string",
709                            "description": "Bash command to execute"
710                        },
711                        "workdir": {
712                            "type": "string",
713                            "description": "Working directory (default: current directory)"
714                        },
715                        "timeout": {
716                            "type": "integer",
717                            "description": "Timeout in seconds (default: 60)"
718                        }
719                    },
720                    "required": ["command"]
721                }),
722            ),
723            "git_status" => (
724                "Get git repository status".to_string(),
725                json!({
726                    "type": "object",
727                    "properties": {},
728                    "required": []
729                }),
730            ),
731            "git_diff" => (
732                "Get git diff".to_string(),
733                json!({
734                    "type": "object",
735                    "properties": {},
736                    "required": []
737                }),
738            ),
739            "git_log" => (
740                "Get git commit log".to_string(),
741                json!({
742                    "type": "object",
743                    "properties": {
744                        "count": {
745                            "type": "integer",
746                            "description": "Number of commits to show (default: 10)"
747                        }
748                    },
749                    "required": []
750                }),
751            ),
752            "git_add" => (
753                "Add files to git staging area".to_string(),
754                json!({
755                    "type": "object",
756                    "properties": {
757                        "files": {
758                            "type": "array",
759                            "items": {"type": "string"},
760                            "description": "List of file paths to add"
761                        }
762                    },
763                    "required": ["files"]
764                }),
765            ),
766            "git_commit" => (
767                "Create a git commit".to_string(),
768                json!({
769                    "type": "object",
770                    "properties": {
771                        "message": {
772                            "type": "string",
773                            "description": "Commit message"
774                        }
775                    },
776                    "required": ["message"]
777                }),
778            ),
779            "git_push" => (
780                "Push commits to remote repository".to_string(),
781                json!({
782                    "type": "object",
783                    "properties": {
784                        "remote": {
785                            "type": "string",
786                            "description": "Remote name (default: origin)"
787                        },
788                        "branch": {
789                            "type": "string",
790                            "description": "Branch name (default: current branch)"
791                        }
792                    },
793                    "required": []
794                }),
795            ),
796            "git_pull" => (
797                "Pull changes from remote repository".to_string(),
798                json!({
799                    "type": "object",
800                    "properties": {
801                        "remote": {
802                            "type": "string",
803                            "description": "Remote name (default: origin)"
804                        },
805                        "branch": {
806                            "type": "string",
807                            "description": "Branch name (default: current branch)"
808                        }
809                    },
810                    "required": []
811                }),
812            ),
813            "git_clone" => (
814                "Clone a git repository".to_string(),
815                json!({
816                    "type": "object",
817                    "properties": {
818                        "url": {
819                            "type": "string",
820                            "description": "Repository URL to clone"
821                        },
822                        "directory": {
823                            "type": "string",
824                            "description": "Directory to clone into (optional)"
825                        }
826                    },
827                    "required": ["url"]
828                }),
829            ),
830            "grep" => (
831                "Search for text patterns in files using regex".to_string(),
832                json!({
833                    "type": "object",
834                    "properties": {
835                        "pattern": {
836                            "type": "string",
837                            "description": "Regex pattern to search for"
838                        },
839                        "path": {
840                            "type": "string",
841                            "description": "Path to search in (default: current directory)"
842                        }
843                    },
844                    "required": ["pattern"]
845                }),
846            ),
847            "ast_grep" => (
848                "Search code using AST patterns (structural code matching)".to_string(),
849                json!({
850                    "type": "object",
851                    "properties": {
852                        "pattern": {
853                            "type": "string",
854                            "description": "AST pattern to match"
855                        },
856                        "language": {
857                            "type": "string",
858                            "description": "Programming language (rust, typescript, python)"
859                        },
860                        "path": {
861                            "type": "string",
862                            "description": "Path to search in (default: current directory)"
863                        }
864                    },
865                    "required": ["pattern", "language"]
866                }),
867            ),
868            "lsp" => (
869                "Perform Language Server Protocol operations (goto_definition, find_references)"
870                    .to_string(),
871                json!({
872                    "type": "object",
873                    "properties": {
874                        "command": {
875                            "type": "string",
876                            "description": "LSP command: goto_definition or find_references"
877                        },
878                        "file_path": {
879                            "type": "string",
880                            "description": "Path to the file"
881                        },
882                        "position": {
883                            "type": "object",
884                            "description": "Position in the file (line, character)",
885                            "properties": {
886                                "line": {"type": "integer"},
887                                "character": {"type": "integer"}
888                            },
889                            "required": ["line", "character"]
890                        }
891                    },
892                    "required": ["command", "file_path", "position"]
893                }),
894            ),
895            "web_search" => (
896                format!("Search the web using Exa AI. Returns results with titles, URLs, and content snippets. Use for current information beyond knowledge cutoff. The current year is {} - use this year when searching for recent information.", chrono::Local::now().year()),
897                json!({
898                    "type": "object",
899                    "properties": {
900                        "query": {
901                            "type": "string",
902                            "description": format!("Search query. Be specific for better results (e.g., 'Rust async tutorial {}' rather than 'Rust')", chrono::Local::now().year())
903                        },
904                        "numResults": {
905                            "type": "integer",
906                            "description": "Number of results to return (default: 8, max: 20)",
907                            "default": 8
908                        }
909                    },
910                    "required": ["query"]
911                }),
912            ),
913            "web_fetch" => (
914                "Fetch content from a URL. Converts HTML to markdown format by default. Use when user provides a URL or after web_search to read full content of a specific result.".to_string(),
915                json!({
916                    "type": "object",
917                    "properties": {
918                        "url": {
919                            "type": "string",
920                            "description": "URL to fetch (must start with http:// or https://)"
921                        },
922                        "format": {
923                            "type": "string",
924                            "enum": ["markdown", "text", "html"],
925                            "default": "markdown",
926                            "description": "Output format (default: markdown)"
927                        }
928                    },
929                    "required": ["url"]
930                }),
931            ),
932            _ => (
933                format!("Tool: {}", name),
934                json!({
935                    "type": "object",
936                    "properties": {},
937                    "required": []
938                }),
939            ),
940        }
941    }
942
943    /// Send an event through the event channel
944    fn send_event(&self, event: AgentEvent) {
945        if let Some(ref tx) = self.event_tx {
946            let _ = tx.send(event);
947        }
948    }
949
950    /// Check if the bridge is ready to process messages
951    #[allow(dead_code)]
952    pub fn is_ready(&self) -> bool {
953        self.config
954            .providers
955            .get(&self.config.provider)
956            .map(|p| p.api_key_or_env(&self.config.provider).is_some())
957            .unwrap_or(false)
958    }
959
960    /// Get the current model name
961    pub fn model(&self) -> &str {
962        self.config
963            .providers
964            .get(&self.config.provider)
965            .map(|p| p.model.as_str())
966            .unwrap_or("")
967    }
968
969    /// Get the max tokens setting
970    pub fn max_tokens(&self) -> u32 {
971        self.config
972            .providers
973            .get(&self.config.provider)
974            .map(|p| p.max_tokens)
975            .unwrap_or(4096)
976    }
977
978    /// Get the timeout setting
979    pub fn timeout(&self) -> u64 {
980        self.config
981            .providers
982            .get(&self.config.provider)
983            .map(|p| p.timeout)
984            .unwrap_or(60)
985    }
986}
/// Calculate cost based on model pricing (per 1M tokens)
///
/// Prices are looked up as `(input $/1M, output $/1M)` for each known
/// model id; unrecognized models are billed at zero (no cost tracking).
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    const PER_MILLION: f64 = 1_000_000.0;
    // Per-million-token prices keyed on the exact model identifier.
    let (input_price, output_price) = match model {
        // Claude 3.5 Sonnet: $3/1M input, $15/1M output
        "claude-3-5-sonnet-20241022" | "claude-3-5-sonnet" => (3.0, 15.0),
        // GPT-4: $30/1M input, $60/1M output
        "gpt-4" => (30.0, 60.0),
        // GPT-4 Turbo: $10/1M input, $30/1M output
        "gpt-4-turbo" | "gpt-4-turbo-preview" => (10.0, 30.0),
        // Default: no cost tracking
        _ => (0.0, 0.0),
    };
    let input_cost = input_tokens as f64 * input_price / PER_MILLION;
    let output_cost = output_tokens as f64 * output_price / PER_MILLION;
    input_cost + output_cost
}
1002
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Build a single-provider ("anthropic") config with the defaults
    /// shared by every test in this module. `api_key` controls whether
    /// the resulting bridge has credentials.
    fn test_config(api_key: Option<&str>) -> LlmConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "anthropic".to_string(),
            ProviderConfig {
                api_key: api_key.map(str::to_string),
                model: "claude-3-5-sonnet-20241022".to_string(),
                base_url: None,
                max_tokens: 4096,
                timeout: 60,
                max_iterations: 100,
                thinking_enabled: false,
                clear_thinking: true,
            },
        );
        LlmConfig {
            provider: "anthropic".to_string(),
            providers,
        }
    }

    #[tokio::test]
    async fn test_agent_bridge_new() {
        let bridge = AgentBridge::new(test_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }

    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        // Without an API key, construction must fail rather than produce
        // a bridge that errors later at request time.
        let result = AgentBridge::new(test_config(None));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let bridge = AgentBridge::new(test_config(Some("test-key"))).unwrap();
        let definitions = bridge.get_tool_definitions();

        // One definition per registered tool.
        assert_eq!(definitions.len(), 17);

        // Check file_read tool definition
        let file_read = definitions
            .iter()
            .find(|d| d.function.name == "file_read")
            .unwrap();
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        // Check bash tool definition: "command" must be a required parameter.
        let bash = definitions
            .iter()
            .find(|d| d.function.name == "bash")
            .unwrap();
        assert_eq!(bash.function.name, "bash");
        assert!(bash.function.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&"command".into()));
    }

    #[test]
    fn test_get_tool_schema() {
        let (desc, params) = AgentBridge::get_tool_schema("file_read");
        assert!(desc.contains("Read"));
        assert_eq!(params["properties"]["path"]["type"], "string");
        assert!(params["required"]
            .as_array()
            .unwrap()
            .contains(&"path".into()));

        let (desc, params) = AgentBridge::get_tool_schema("bash");
        assert!(desc.contains("bash"));
        assert_eq!(params["properties"]["command"]["type"], "string");

        // Unknown tools fall through to a generic schema.
        let (desc, _params) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(desc.contains("unknown_tool"));
    }

    #[test]
    fn test_is_ready() {
        let bridge = AgentBridge::new(test_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }
}