1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::{
4 AstGrepTool, BashTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool, GitCloneTool,
5 GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool, GrepTool,
6 LspTool, WebFetchTool, WebSearchTool,
7};
8use chrono::Datelike;
9use futures::StreamExt;
10use limit_agent::executor::{ToolCall, ToolExecutor};
11use limit_agent::registry::ToolRegistry;
12use limit_llm::providers::LlmProvider;
13use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
14use limit_llm::ProviderFactory;
15use limit_llm::ProviderResponseChunk;
16use limit_llm::TrackingDb;
17use serde_json::json;
18use tokio::sync::mpsc;
19use tracing::{debug, instrument};
20
21#[derive(Debug, Clone)]
23#[allow(dead_code)]
24pub enum AgentEvent {
25 Thinking,
26 ToolStart {
27 name: String,
28 args: serde_json::Value,
29 },
30 ToolComplete {
31 name: String,
32 result: String,
33 },
34 ContentChunk(String),
35 Done,
36 Error(String),
37 TokenUsage {
38 input_tokens: u64,
39 output_tokens: u64,
40 },
41}
42
/// Connects the CLI to the configured LLM provider and the local tool executor.
///
/// It drives the agent loop in `process_message` and reports progress through an
/// optional [`AgentEvent`] channel.
pub struct AgentBridge {
    /// Client for the active provider, created by `ProviderFactory::create_provider`.
    llm_client: Box<dyn LlmProvider>,
    /// Runs tool calls against the tools registered by `register_tools`.
    executor: ToolExecutor,
    /// Names of the tools advertised to the model, in advertisement order.
    tool_names: Vec<&'static str>,
    /// Full LLM configuration; the active provider is selected by `config.provider`.
    config: limit_llm::Config,
    /// Optional sink for UI events; `None` until `set_event_tx` is called.
    event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
    /// Records token usage, estimated cost and duration of each LLM request.
    tracking_db: TrackingDb,
}
58
59impl AgentBridge {
60 pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
68 let llm_client = ProviderFactory::create_provider(&config)
69 .map_err(|e| CliError::ConfigError(e.to_string()))?;
70
71 let mut tool_registry = ToolRegistry::new();
72 Self::register_tools(&mut tool_registry);
73
74 let executor = ToolExecutor::new(tool_registry);
76
77 let tool_names = vec![
79 "file_read",
80 "file_write",
81 "file_edit",
82 "bash",
83 "git_status",
84 "git_diff",
85 "git_log",
86 "git_add",
87 "git_commit",
88 "git_push",
89 "git_pull",
90 "git_clone",
91 "grep",
92 "ast_grep",
93 "lsp",
94 "web_search",
95 "web_fetch",
96 ];
97
98 Ok(Self {
99 llm_client,
100 executor,
101 tool_names,
102 config,
103 event_tx: None,
104 tracking_db: TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?,
105 })
106 }
107
108 pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
110 self.event_tx = Some(tx);
111 }
112
113 fn register_tools(registry: &mut ToolRegistry) {
115 registry
117 .register(FileReadTool::new())
118 .expect("Failed to register file_read");
119 registry
120 .register(FileWriteTool::new())
121 .expect("Failed to register file_write");
122 registry
123 .register(FileEditTool::new())
124 .expect("Failed to register file_edit");
125
126 registry
128 .register(BashTool::new())
129 .expect("Failed to register bash");
130
131 registry
133 .register(GitStatusTool::new())
134 .expect("Failed to register git_status");
135 registry
136 .register(GitDiffTool::new())
137 .expect("Failed to register git_diff");
138 registry
139 .register(GitLogTool::new())
140 .expect("Failed to register git_log");
141 registry
142 .register(GitAddTool::new())
143 .expect("Failed to register git_add");
144 registry
145 .register(GitCommitTool::new())
146 .expect("Failed to register git_commit");
147 registry
148 .register(GitPushTool::new())
149 .expect("Failed to register git_push");
150 registry
151 .register(GitPullTool::new())
152 .expect("Failed to register git_pull");
153 registry
154 .register(GitCloneTool::new())
155 .expect("Failed to register git_clone");
156
157 registry
159 .register(GrepTool::new())
160 .expect("Failed to register grep");
161 registry
162 .register(AstGrepTool::new())
163 .expect("Failed to register ast_grep");
164 registry
165 .register(LspTool::new())
166 .expect("Failed to register lsp");
167
168 registry
170 .register(WebSearchTool::new())
171 .expect("Failed to register web_search");
172 registry
173 .register(WebFetchTool::new())
174 .expect("Failed to register web_fetch");
175 }
176
177 #[instrument(skip(self, _messages))]
186 pub async fn process_message(
187 &mut self,
188 user_input: &str,
189 _messages: &mut Vec<Message>,
190 ) -> Result<String, CliError> {
191 if _messages.is_empty() {
194 let system_message = Message {
195 role: Role::System,
196 content: Some(SYSTEM_PROMPT.to_string()),
197 tool_calls: None,
198 tool_call_id: None,
199 };
200 _messages.push(system_message);
201 }
202
203 let user_message = Message {
205 role: Role::User,
206 content: Some(user_input.to_string()),
207 tool_calls: None,
208 tool_call_id: None,
209 };
210 _messages.push(user_message);
211
212 let tool_definitions = self.get_tool_definitions();
214
215 let mut full_response = String::new();
217 let mut tool_calls: Vec<LlmToolCall> = Vec::new();
218 let max_iterations = self
219 .config
220 .providers
221 .get(&self.config.provider)
222 .map(|p| p.max_iterations)
223 .unwrap_or(100); let mut iteration = 0;
225
226 while max_iterations == 0 || iteration < max_iterations {
227 iteration += 1;
228 debug!("Agent loop iteration {}", iteration);
229
230 self.send_event(AgentEvent::Thinking);
232
233 let request_start = std::time::Instant::now();
235
236 let mut stream = self
238 .llm_client
239 .send(_messages.clone(), tool_definitions.clone())
240 .await
241 .map_err(|e| CliError::ConfigError(e.to_string()))?;
242
243 tool_calls.clear();
244 let mut current_content = String::new();
245 let mut accumulated_calls: std::collections::HashMap<
247 String,
248 (String, serde_json::Value),
249 > = std::collections::HashMap::new();
250
251 while let Some(chunk_result) = stream.next().await {
253 match chunk_result {
254 Ok(ProviderResponseChunk::ContentDelta(text)) => {
255 current_content.push_str(&text);
256 self.send_event(AgentEvent::ContentChunk(text));
257 }
258 Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
259 }
261 Ok(ProviderResponseChunk::ToolCallDelta {
262 id,
263 name,
264 arguments,
265 }) => {
266 debug!("ToolCallDelta: id={}, name={}", id, name);
267 accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
269 }
270 Ok(ProviderResponseChunk::Done(usage)) => {
271 let duration_ms = request_start.elapsed().as_millis() as u64;
273 let cost =
274 calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
275 let _ = self.tracking_db.track_request(
276 self.model(),
277 usage.input_tokens,
278 usage.output_tokens,
279 cost,
280 duration_ms,
281 );
282 self.send_event(AgentEvent::TokenUsage {
284 input_tokens: usage.input_tokens,
285 output_tokens: usage.output_tokens,
286 });
287 break;
288 }
289 Err(e) => {
290 let error_msg = format!("LLM error: {}", e);
291 self.send_event(AgentEvent::Error(error_msg.clone()));
292 return Err(CliError::ConfigError(error_msg));
293 }
294 }
295 }
296
297 tool_calls = accumulated_calls
299 .into_iter()
300 .map(|(id, (name, args))| LlmToolCall {
301 id,
302 tool_type: "function".to_string(),
303 function: limit_llm::types::FunctionCall {
304 name,
305 arguments: args.to_string(),
306 },
307 })
308 .collect();
309 full_response.push_str(¤t_content);
310
311 debug!(
312 "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
313 iteration,
314 current_content.len(),
315 tool_calls.len(),
316 full_response.len()
317 );
318
319 if tool_calls.is_empty() {
321 break;
322 }
323
324 let assistant_message = Message {
327 role: Role::Assistant,
328 content: None, tool_calls: Some(tool_calls.clone()),
330 tool_call_id: None,
331 };
332 _messages.push(assistant_message);
333
334 let executor_calls: Vec<ToolCall> = tool_calls
336 .iter()
337 .map(|tc| {
338 let args: serde_json::Value =
339 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
340 ToolCall::new(&tc.id, &tc.function.name, args)
341 })
342 .collect();
343
344 for tc in &tool_calls {
346 let args: serde_json::Value =
347 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
348 self.send_event(AgentEvent::ToolStart {
349 name: tc.function.name.clone(),
350 args,
351 });
352 }
353 let results = self.executor.execute_tools(executor_calls).await;
355
356 for result in results {
358 let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
359 if let Some(tool_call) = tool_call {
360 let output_json = match &result.output {
361 Ok(value) => {
362 serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
363 }
364 Err(e) => json!({ "error": e.to_string() }).to_string(),
365 };
366
367 self.send_event(AgentEvent::ToolComplete {
368 name: tool_call.function.name.clone(),
369 result: output_json.clone(),
370 });
371
372 let tool_result_message = Message {
374 role: Role::Tool,
375 content: Some(output_json),
376 tool_calls: None,
377 tool_call_id: Some(result.call_id),
378 };
379 _messages.push(tool_result_message);
380 }
381 }
382 }
383
384 if iteration >= max_iterations && !_messages.is_empty() {
386 debug!("Making final LLM call after hitting max iterations (forcing text response)");
387
388 let constraint_message = Message {
390 role: Role::User,
391 content: Some(
392 "We've reached the iteration limit. Please provide a summary of:\n\
393 1. What you've completed so far\n\
394 2. What remains to be done\n\
395 3. Recommended next steps for the user to continue"
396 .to_string(),
397 ),
398 tool_calls: None,
399 tool_call_id: None,
400 };
401 _messages.push(constraint_message);
402
403 let no_tools: Vec<LlmTool> = vec![];
405 let mut stream = self
406 .llm_client
407 .send(_messages.clone(), no_tools)
408 .await
409 .map_err(|e| CliError::ConfigError(e.to_string()))?;
410
411 while let Some(chunk_result) = stream.next().await {
412 match chunk_result {
413 Ok(ProviderResponseChunk::ContentDelta(text)) => {
414 full_response.push_str(&text);
415 self.send_event(AgentEvent::ContentChunk(text));
416 }
417 Ok(ProviderResponseChunk::Done(_)) => {
418 break;
419 }
420 Err(e) => {
421 debug!("Error in final LLM call: {}", e);
422 break;
423 }
424 _ => {}
425 }
426 }
427 }
428
429 self.send_event(AgentEvent::Done);
430 Ok(full_response)
431 }
432
433 pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
435 self.tool_names
436 .iter()
437 .map(|name| {
438 let (description, parameters) = Self::get_tool_schema(name);
439 LlmTool {
440 tool_type: "function".to_string(),
441 function: limit_llm::types::ToolFunction {
442 name: name.to_string(),
443 description,
444 parameters,
445 },
446 }
447 })
448 .collect()
449 }
450
451 fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
453 match name {
454 "file_read" => (
455 "Read the contents of a file".to_string(),
456 json!({
457 "type": "object",
458 "properties": {
459 "path": {
460 "type": "string",
461 "description": "Path to the file to read"
462 }
463 },
464 "required": ["path"]
465 }),
466 ),
467 "file_write" => (
468 "Write content to a file, creating parent directories if needed".to_string(),
469 json!({
470 "type": "object",
471 "properties": {
472 "path": {
473 "type": "string",
474 "description": "Path to the file to write"
475 },
476 "content": {
477 "type": "string",
478 "description": "Content to write to the file"
479 }
480 },
481 "required": ["path", "content"]
482 }),
483 ),
484 "file_edit" => (
485 "Replace text in a file with new text".to_string(),
486 json!({
487 "type": "object",
488 "properties": {
489 "path": {
490 "type": "string",
491 "description": "Path to the file to edit"
492 },
493 "old_text": {
494 "type": "string",
495 "description": "Text to find and replace"
496 },
497 "new_text": {
498 "type": "string",
499 "description": "New text to replace with"
500 }
501 },
502 "required": ["path", "old_text", "new_text"]
503 }),
504 ),
505 "bash" => (
506 "Execute a bash command in a shell".to_string(),
507 json!({
508 "type": "object",
509 "properties": {
510 "command": {
511 "type": "string",
512 "description": "Bash command to execute"
513 },
514 "workdir": {
515 "type": "string",
516 "description": "Working directory (default: current directory)"
517 },
518 "timeout": {
519 "type": "integer",
520 "description": "Timeout in seconds (default: 60)"
521 }
522 },
523 "required": ["command"]
524 }),
525 ),
526 "git_status" => (
527 "Get git repository status".to_string(),
528 json!({
529 "type": "object",
530 "properties": {},
531 "required": []
532 }),
533 ),
534 "git_diff" => (
535 "Get git diff".to_string(),
536 json!({
537 "type": "object",
538 "properties": {},
539 "required": []
540 }),
541 ),
542 "git_log" => (
543 "Get git commit log".to_string(),
544 json!({
545 "type": "object",
546 "properties": {
547 "count": {
548 "type": "integer",
549 "description": "Number of commits to show (default: 10)"
550 }
551 },
552 "required": []
553 }),
554 ),
555 "git_add" => (
556 "Add files to git staging area".to_string(),
557 json!({
558 "type": "object",
559 "properties": {
560 "files": {
561 "type": "array",
562 "items": {"type": "string"},
563 "description": "List of file paths to add"
564 }
565 },
566 "required": ["files"]
567 }),
568 ),
569 "git_commit" => (
570 "Create a git commit".to_string(),
571 json!({
572 "type": "object",
573 "properties": {
574 "message": {
575 "type": "string",
576 "description": "Commit message"
577 }
578 },
579 "required": ["message"]
580 }),
581 ),
582 "git_push" => (
583 "Push commits to remote repository".to_string(),
584 json!({
585 "type": "object",
586 "properties": {
587 "remote": {
588 "type": "string",
589 "description": "Remote name (default: origin)"
590 },
591 "branch": {
592 "type": "string",
593 "description": "Branch name (default: current branch)"
594 }
595 },
596 "required": []
597 }),
598 ),
599 "git_pull" => (
600 "Pull changes from remote repository".to_string(),
601 json!({
602 "type": "object",
603 "properties": {
604 "remote": {
605 "type": "string",
606 "description": "Remote name (default: origin)"
607 },
608 "branch": {
609 "type": "string",
610 "description": "Branch name (default: current branch)"
611 }
612 },
613 "required": []
614 }),
615 ),
616 "git_clone" => (
617 "Clone a git repository".to_string(),
618 json!({
619 "type": "object",
620 "properties": {
621 "url": {
622 "type": "string",
623 "description": "Repository URL to clone"
624 },
625 "directory": {
626 "type": "string",
627 "description": "Directory to clone into (optional)"
628 }
629 },
630 "required": ["url"]
631 }),
632 ),
633 "grep" => (
634 "Search for text patterns in files using regex".to_string(),
635 json!({
636 "type": "object",
637 "properties": {
638 "pattern": {
639 "type": "string",
640 "description": "Regex pattern to search for"
641 },
642 "path": {
643 "type": "string",
644 "description": "Path to search in (default: current directory)"
645 }
646 },
647 "required": ["pattern"]
648 }),
649 ),
650 "ast_grep" => (
651 "Search code using AST patterns (structural code matching)".to_string(),
652 json!({
653 "type": "object",
654 "properties": {
655 "pattern": {
656 "type": "string",
657 "description": "AST pattern to match"
658 },
659 "language": {
660 "type": "string",
661 "description": "Programming language (rust, typescript, python)"
662 },
663 "path": {
664 "type": "string",
665 "description": "Path to search in (default: current directory)"
666 }
667 },
668 "required": ["pattern", "language"]
669 }),
670 ),
671 "lsp" => (
672 "Perform Language Server Protocol operations (goto_definition, find_references)"
673 .to_string(),
674 json!({
675 "type": "object",
676 "properties": {
677 "command": {
678 "type": "string",
679 "description": "LSP command: goto_definition or find_references"
680 },
681 "file_path": {
682 "type": "string",
683 "description": "Path to the file"
684 },
685 "position": {
686 "type": "object",
687 "description": "Position in the file (line, character)",
688 "properties": {
689 "line": {"type": "integer"},
690 "character": {"type": "integer"}
691 },
692 "required": ["line", "character"]
693 }
694 },
695 "required": ["command", "file_path", "position"]
696 }),
697 ),
698 "web_search" => (
699 format!("Search the web using Exa AI. Returns results with titles, URLs, and content snippets. Use for current information beyond knowledge cutoff. The current year is {} - use this year when searching for recent information.", chrono::Local::now().year()),
700 json!({
701 "type": "object",
702 "properties": {
703 "query": {
704 "type": "string",
705 "description": format!("Search query. Be specific for better results (e.g., 'Rust async tutorial {}' rather than 'Rust')", chrono::Local::now().year())
706 },
707 "numResults": {
708 "type": "integer",
709 "description": "Number of results to return (default: 8, max: 20)",
710 "default": 8
711 }
712 },
713 "required": ["query"]
714 }),
715 ),
716 "web_fetch" => (
717 "Fetch content from a URL. Converts HTML to markdown format by default. Use when user provides a URL or after web_search to read full content of a specific result.".to_string(),
718 json!({
719 "type": "object",
720 "properties": {
721 "url": {
722 "type": "string",
723 "description": "URL to fetch (must start with http:// or https://)"
724 },
725 "format": {
726 "type": "string",
727 "enum": ["markdown", "text", "html"],
728 "default": "markdown",
729 "description": "Output format (default: markdown)"
730 }
731 },
732 "required": ["url"]
733 }),
734 ),
735 _ => (
736 format!("Tool: {}", name),
737 json!({
738 "type": "object",
739 "properties": {},
740 "required": []
741 }),
742 ),
743 }
744 }
745
746 fn send_event(&self, event: AgentEvent) {
748 if let Some(ref tx) = self.event_tx {
749 let _ = tx.send(event);
750 }
751 }
752
753 #[allow(dead_code)]
755 pub fn is_ready(&self) -> bool {
756 self.config
757 .providers
758 .get(&self.config.provider)
759 .map(|p| p.api_key_or_env(&self.config.provider).is_some())
760 .unwrap_or(false)
761 }
762
763 pub fn model(&self) -> &str {
765 self.config
766 .providers
767 .get(&self.config.provider)
768 .map(|p| p.model.as_str())
769 .unwrap_or("")
770 }
771
772 pub fn max_tokens(&self) -> u32 {
774 self.config
775 .providers
776 .get(&self.config.provider)
777 .map(|p| p.max_tokens)
778 .unwrap_or(4096)
779 }
780
781 pub fn timeout(&self) -> u64 {
783 self.config
784 .providers
785 .get(&self.config.provider)
786 .map(|p| p.timeout)
787 .unwrap_or(60)
788 }
789}
/// Estimates the cost of one request from its token counts.
///
/// Rates are expressed per million tokens (presumably USD — confirm against the
/// provider price lists). Models without a listed rate are priced at zero rather than
/// rejected.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    const TOKENS_PER_RATE_UNIT: f64 = 1_000_000.0;

    let rates: Option<(f64, f64)> = match model {
        "claude-3-5-sonnet-20241022" | "claude-3-5-sonnet" => Some((3.0, 15.0)),
        "gpt-4" => Some((30.0, 60.0)),
        "gpt-4-turbo" | "gpt-4-turbo-preview" => Some((10.0, 30.0)),
        _ => None,
    };
    let (input_rate, output_rate) = rates.unwrap_or((0.0, 0.0));

    let input_cost = input_tokens as f64 * input_rate / TOKENS_PER_RATE_UNIT;
    let output_cost = output_tokens as f64 * output_rate / TOKENS_PER_RATE_UNIT;
    input_cost + output_cost
}
805
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Builds a config whose only (and active) provider is `anthropic`.
    ///
    /// `api_key` is stored as given, so `None` yields a provider without a key.
    fn anthropic_config(api_key: Option<&str>) -> LlmConfig {
        let anthropic = ProviderConfig {
            api_key: api_key.map(String::from),
            model: "claude-3-5-sonnet-20241022".to_string(),
            base_url: None,
            max_tokens: 4096,
            timeout: 60,
            max_iterations: 100,
            thinking_enabled: false,
            clear_thinking: true,
        };

        let mut providers = HashMap::new();
        providers.insert("anthropic".to_string(), anthropic);

        LlmConfig {
            provider: "anthropic".to_string(),
            providers,
        }
    }

    #[tokio::test]
    async fn test_agent_bridge_new() {
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }

    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        let outcome = AgentBridge::new(anthropic_config(None));
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        let definitions = bridge.get_tool_definitions();

        assert_eq!(definitions.len(), 17);

        let file_read = definitions
            .iter()
            .find(|def| def.function.name == "file_read")
            .unwrap();
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        let bash = definitions
            .iter()
            .find(|def| def.function.name == "bash")
            .unwrap();
        assert_eq!(bash.function.name, "bash");
        let required = bash.function.parameters["required"].as_array().unwrap();
        assert!(required.contains(&"command".into()));
    }

    #[test]
    fn test_get_tool_schema() {
        let (description, schema) = AgentBridge::get_tool_schema("file_read");
        assert!(description.contains("Read"));
        assert_eq!(schema["properties"]["path"]["type"], "string");
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&"path".into()));

        let (description, schema) = AgentBridge::get_tool_schema("bash");
        assert!(description.contains("bash"));
        assert_eq!(schema["properties"]["command"]["type"], "string");

        let (description, _schema) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(description.contains("unknown_tool"));
    }

    #[test]
    fn test_is_ready() {
        let config_with_key = anthropic_config(Some("test-key"));
        let bridge = AgentBridge::new(config_with_key).unwrap();
        assert!(bridge.is_ready());
    }
}