1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::{
4 AstGrepTool, BashTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool, GitCloneTool,
5 GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool, GrepTool,
6 LspTool,
7};
8use futures::StreamExt;
9use limit_agent::executor::{ToolCall, ToolExecutor};
10use limit_agent::registry::ToolRegistry;
11use limit_llm::providers::LlmProvider;
12use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
13use limit_llm::ProviderFactory;
14use limit_llm::ProviderResponseChunk;
15use limit_llm::TrackingDb;
16use serde_json::json;
17use tokio::sync::mpsc;
18use tracing::{debug, instrument};
19
20#[derive(Debug, Clone)]
22#[allow(dead_code)]
23pub enum AgentEvent {
24 Thinking,
25 ToolStart {
26 name: String,
27 args: serde_json::Value,
28 },
29 ToolComplete {
30 name: String,
31 result: String,
32 },
33 ContentChunk(String),
34 Done,
35 Error(String),
36 TokenUsage {
37 input_tokens: u64,
38 output_tokens: u64,
39 },
40}
41
/// Connects the CLI to an LLM provider and the built-in tool set.
///
/// Holds the provider client, a [`ToolExecutor`] over the registered tools, the list of
/// tool names advertised to the model, and the [`TrackingDb`] used to record per-request
/// usage. Progress can optionally be reported to a listener as [`AgentEvent`]s.
pub struct AgentBridge {
    /// Client for the active provider, created by `ProviderFactory::create_provider`.
    llm_client: Box<dyn LlmProvider>,
    /// Runs tool calls requested by the model against the registered tools.
    executor: ToolExecutor,
    /// Names of the tools whose definitions are sent to the model on every request.
    tool_names: Vec<&'static str>,
    /// Full LLM configuration; `config.provider` selects the entry in `config.providers`.
    config: limit_llm::Config,
    /// Optional listener for [`AgentEvent`]s; `None` means events are discarded.
    event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
    /// Records model, token counts, cost, and latency for LLM requests.
    tracking_db: TrackingDb,
}
57
58impl AgentBridge {
59 pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
67 let llm_client = ProviderFactory::create_provider(&config)
68 .map_err(|e| CliError::ConfigError(e.to_string()))?;
69
70 let mut tool_registry = ToolRegistry::new();
71 Self::register_tools(&mut tool_registry);
72
73 let executor = ToolExecutor::new(tool_registry);
75
76 let tool_names = vec![
78 "file_read",
79 "file_write",
80 "file_edit",
81 "bash",
82 "git_status",
83 "git_diff",
84 "git_log",
85 "git_add",
86 "git_commit",
87 "git_push",
88 "git_pull",
89 "git_clone",
90 "grep",
91 "ast_grep",
92 "lsp",
93 ];
94
95 Ok(Self {
96 llm_client,
97 executor,
98 tool_names,
99 config,
100 event_tx: None,
101 tracking_db: TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?,
102 })
103 }
104
105 pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
107 self.event_tx = Some(tx);
108 }
109
110 fn register_tools(registry: &mut ToolRegistry) {
112 registry
114 .register(FileReadTool::new())
115 .expect("Failed to register file_read");
116 registry
117 .register(FileWriteTool::new())
118 .expect("Failed to register file_write");
119 registry
120 .register(FileEditTool::new())
121 .expect("Failed to register file_edit");
122
123 registry
125 .register(BashTool::new())
126 .expect("Failed to register bash");
127
128 registry
130 .register(GitStatusTool::new())
131 .expect("Failed to register git_status");
132 registry
133 .register(GitDiffTool::new())
134 .expect("Failed to register git_diff");
135 registry
136 .register(GitLogTool::new())
137 .expect("Failed to register git_log");
138 registry
139 .register(GitAddTool::new())
140 .expect("Failed to register git_add");
141 registry
142 .register(GitCommitTool::new())
143 .expect("Failed to register git_commit");
144 registry
145 .register(GitPushTool::new())
146 .expect("Failed to register git_push");
147 registry
148 .register(GitPullTool::new())
149 .expect("Failed to register git_pull");
150 registry
151 .register(GitCloneTool::new())
152 .expect("Failed to register git_clone");
153
154 registry
156 .register(GrepTool::new())
157 .expect("Failed to register grep");
158 registry
159 .register(AstGrepTool::new())
160 .expect("Failed to register ast_grep");
161 registry
162 .register(LspTool::new())
163 .expect("Failed to register lsp");
164 }
165
166 #[instrument(skip(self, _messages))]
175 pub async fn process_message(
176 &mut self,
177 user_input: &str,
178 _messages: &mut Vec<Message>,
179 ) -> Result<String, CliError> {
180 if _messages.is_empty() {
183 let system_message = Message {
184 role: Role::System,
185 content: Some(SYSTEM_PROMPT.to_string()),
186 tool_calls: None,
187 tool_call_id: None,
188 };
189 _messages.push(system_message);
190 }
191
192 let user_message = Message {
194 role: Role::User,
195 content: Some(user_input.to_string()),
196 tool_calls: None,
197 tool_call_id: None,
198 };
199 _messages.push(user_message);
200
201 let tool_definitions = self.get_tool_definitions();
203
204 let mut full_response = String::new();
206 let mut tool_calls: Vec<LlmToolCall> = Vec::new();
207 let max_iterations = self
208 .config
209 .providers
210 .get(&self.config.provider)
211 .map(|p| p.max_iterations)
212 .unwrap_or(100); let mut iteration = 0;
214
215 while max_iterations == 0 || iteration < max_iterations {
216 iteration += 1;
217 debug!("Agent loop iteration {}", iteration);
218
219 self.send_event(AgentEvent::Thinking);
221
222 let request_start = std::time::Instant::now();
224
225 let mut stream = self
227 .llm_client
228 .send(_messages.clone(), tool_definitions.clone())
229 .await
230 .map_err(|e| CliError::ConfigError(e.to_string()))?;
231
232 tool_calls.clear();
233 let mut current_content = String::new();
234 let mut accumulated_calls: std::collections::HashMap<
236 String,
237 (String, serde_json::Value),
238 > = std::collections::HashMap::new();
239
240 while let Some(chunk_result) = stream.next().await {
242 match chunk_result {
243 Ok(ProviderResponseChunk::ContentDelta(text)) => {
244 current_content.push_str(&text);
245 self.send_event(AgentEvent::ContentChunk(text));
246 }
247 Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
248 }
250 Ok(ProviderResponseChunk::ToolCallDelta {
251 id,
252 name,
253 arguments,
254 }) => {
255 debug!("ToolCallDelta: id={}, name={}", id, name);
256 accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
258 }
259 Ok(ProviderResponseChunk::Done(usage)) => {
260 let duration_ms = request_start.elapsed().as_millis() as u64;
262 let cost =
263 calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
264 let _ = self.tracking_db.track_request(
265 self.model(),
266 usage.input_tokens,
267 usage.output_tokens,
268 cost,
269 duration_ms,
270 );
271 self.send_event(AgentEvent::TokenUsage {
273 input_tokens: usage.input_tokens,
274 output_tokens: usage.output_tokens,
275 });
276 break;
277 }
278 Err(e) => {
279 let error_msg = format!("LLM error: {}", e);
280 self.send_event(AgentEvent::Error(error_msg.clone()));
281 return Err(CliError::ConfigError(error_msg));
282 }
283 }
284 }
285
286 tool_calls = accumulated_calls
288 .into_iter()
289 .map(|(id, (name, args))| LlmToolCall {
290 id,
291 tool_type: "function".to_string(),
292 function: limit_llm::types::FunctionCall {
293 name,
294 arguments: args.to_string(),
295 },
296 })
297 .collect();
298 full_response.push_str(¤t_content);
299
300 debug!(
301 "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
302 iteration,
303 current_content.len(),
304 tool_calls.len(),
305 full_response.len()
306 );
307
308 if tool_calls.is_empty() {
310 break;
311 }
312
313 let assistant_message = Message {
316 role: Role::Assistant,
317 content: None, tool_calls: Some(tool_calls.clone()),
319 tool_call_id: None,
320 };
321 _messages.push(assistant_message);
322
323 let executor_calls: Vec<ToolCall> = tool_calls
325 .iter()
326 .map(|tc| {
327 let args: serde_json::Value =
328 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
329 ToolCall::new(&tc.id, &tc.function.name, args)
330 })
331 .collect();
332
333 for tc in &tool_calls {
335 let args: serde_json::Value =
336 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
337 self.send_event(AgentEvent::ToolStart {
338 name: tc.function.name.clone(),
339 args,
340 });
341 }
342 let results = self.executor.execute_tools(executor_calls).await;
344
345 for result in results {
347 let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
348 if let Some(tool_call) = tool_call {
349 let output_json = match &result.output {
350 Ok(value) => {
351 serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
352 }
353 Err(e) => json!({ "error": e.to_string() }).to_string(),
354 };
355
356 self.send_event(AgentEvent::ToolComplete {
357 name: tool_call.function.name.clone(),
358 result: output_json.clone(),
359 });
360
361 let tool_result_message = Message {
363 role: Role::Tool,
364 content: Some(output_json),
365 tool_calls: None,
366 tool_call_id: Some(result.call_id),
367 };
368 _messages.push(tool_result_message);
369 }
370 }
371 }
372
373 if iteration >= max_iterations && !_messages.is_empty() {
375 debug!("Making final LLM call after hitting max iterations (forcing text response)");
376
377 let constraint_message = Message {
379 role: Role::User,
380 content: Some(
381 "We've reached the iteration limit. Please provide a summary of:\n\
382 1. What you've completed so far\n\
383 2. What remains to be done\n\
384 3. Recommended next steps for the user to continue"
385 .to_string(),
386 ),
387 tool_calls: None,
388 tool_call_id: None,
389 };
390 _messages.push(constraint_message);
391
392 let no_tools: Vec<LlmTool> = vec![];
394 let mut stream = self
395 .llm_client
396 .send(_messages.clone(), no_tools)
397 .await
398 .map_err(|e| CliError::ConfigError(e.to_string()))?;
399
400 while let Some(chunk_result) = stream.next().await {
401 match chunk_result {
402 Ok(ProviderResponseChunk::ContentDelta(text)) => {
403 full_response.push_str(&text);
404 self.send_event(AgentEvent::ContentChunk(text));
405 }
406 Ok(ProviderResponseChunk::Done(_)) => {
407 break;
408 }
409 Err(e) => {
410 debug!("Error in final LLM call: {}", e);
411 break;
412 }
413 _ => {}
414 }
415 }
416 }
417
418 self.send_event(AgentEvent::Done);
419 Ok(full_response)
420 }
421
422 pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
424 self.tool_names
425 .iter()
426 .map(|name| {
427 let (description, parameters) = Self::get_tool_schema(name);
428 LlmTool {
429 tool_type: "function".to_string(),
430 function: limit_llm::types::ToolFunction {
431 name: name.to_string(),
432 description,
433 parameters,
434 },
435 }
436 })
437 .collect()
438 }
439
440 fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
442 match name {
443 "file_read" => (
444 "Read the contents of a file".to_string(),
445 json!({
446 "type": "object",
447 "properties": {
448 "path": {
449 "type": "string",
450 "description": "Path to the file to read"
451 }
452 },
453 "required": ["path"]
454 }),
455 ),
456 "file_write" => (
457 "Write content to a file, creating parent directories if needed".to_string(),
458 json!({
459 "type": "object",
460 "properties": {
461 "path": {
462 "type": "string",
463 "description": "Path to the file to write"
464 },
465 "content": {
466 "type": "string",
467 "description": "Content to write to the file"
468 }
469 },
470 "required": ["path", "content"]
471 }),
472 ),
473 "file_edit" => (
474 "Replace text in a file with new text".to_string(),
475 json!({
476 "type": "object",
477 "properties": {
478 "path": {
479 "type": "string",
480 "description": "Path to the file to edit"
481 },
482 "old_text": {
483 "type": "string",
484 "description": "Text to find and replace"
485 },
486 "new_text": {
487 "type": "string",
488 "description": "New text to replace with"
489 }
490 },
491 "required": ["path", "old_text", "new_text"]
492 }),
493 ),
494 "bash" => (
495 "Execute a bash command in a shell".to_string(),
496 json!({
497 "type": "object",
498 "properties": {
499 "command": {
500 "type": "string",
501 "description": "Bash command to execute"
502 },
503 "workdir": {
504 "type": "string",
505 "description": "Working directory (default: current directory)"
506 },
507 "timeout": {
508 "type": "integer",
509 "description": "Timeout in seconds (default: 60)"
510 }
511 },
512 "required": ["command"]
513 }),
514 ),
515 "git_status" => (
516 "Get git repository status".to_string(),
517 json!({
518 "type": "object",
519 "properties": {},
520 "required": []
521 }),
522 ),
523 "git_diff" => (
524 "Get git diff".to_string(),
525 json!({
526 "type": "object",
527 "properties": {},
528 "required": []
529 }),
530 ),
531 "git_log" => (
532 "Get git commit log".to_string(),
533 json!({
534 "type": "object",
535 "properties": {
536 "count": {
537 "type": "integer",
538 "description": "Number of commits to show (default: 10)"
539 }
540 },
541 "required": []
542 }),
543 ),
544 "git_add" => (
545 "Add files to git staging area".to_string(),
546 json!({
547 "type": "object",
548 "properties": {
549 "files": {
550 "type": "array",
551 "items": {"type": "string"},
552 "description": "List of file paths to add"
553 }
554 },
555 "required": ["files"]
556 }),
557 ),
558 "git_commit" => (
559 "Create a git commit".to_string(),
560 json!({
561 "type": "object",
562 "properties": {
563 "message": {
564 "type": "string",
565 "description": "Commit message"
566 }
567 },
568 "required": ["message"]
569 }),
570 ),
571 "git_push" => (
572 "Push commits to remote repository".to_string(),
573 json!({
574 "type": "object",
575 "properties": {
576 "remote": {
577 "type": "string",
578 "description": "Remote name (default: origin)"
579 },
580 "branch": {
581 "type": "string",
582 "description": "Branch name (default: current branch)"
583 }
584 },
585 "required": []
586 }),
587 ),
588 "git_pull" => (
589 "Pull changes from remote repository".to_string(),
590 json!({
591 "type": "object",
592 "properties": {
593 "remote": {
594 "type": "string",
595 "description": "Remote name (default: origin)"
596 },
597 "branch": {
598 "type": "string",
599 "description": "Branch name (default: current branch)"
600 }
601 },
602 "required": []
603 }),
604 ),
605 "git_clone" => (
606 "Clone a git repository".to_string(),
607 json!({
608 "type": "object",
609 "properties": {
610 "url": {
611 "type": "string",
612 "description": "Repository URL to clone"
613 },
614 "directory": {
615 "type": "string",
616 "description": "Directory to clone into (optional)"
617 }
618 },
619 "required": ["url"]
620 }),
621 ),
622 "grep" => (
623 "Search for text patterns in files using regex".to_string(),
624 json!({
625 "type": "object",
626 "properties": {
627 "pattern": {
628 "type": "string",
629 "description": "Regex pattern to search for"
630 },
631 "path": {
632 "type": "string",
633 "description": "Path to search in (default: current directory)"
634 }
635 },
636 "required": ["pattern"]
637 }),
638 ),
639 "ast_grep" => (
640 "Search code using AST patterns (structural code matching)".to_string(),
641 json!({
642 "type": "object",
643 "properties": {
644 "pattern": {
645 "type": "string",
646 "description": "AST pattern to match"
647 },
648 "language": {
649 "type": "string",
650 "description": "Programming language (rust, typescript, python)"
651 },
652 "path": {
653 "type": "string",
654 "description": "Path to search in (default: current directory)"
655 }
656 },
657 "required": ["pattern", "language"]
658 }),
659 ),
660 "lsp" => (
661 "Perform Language Server Protocol operations (goto_definition, find_references)"
662 .to_string(),
663 json!({
664 "type": "object",
665 "properties": {
666 "command": {
667 "type": "string",
668 "description": "LSP command: goto_definition or find_references"
669 },
670 "file_path": {
671 "type": "string",
672 "description": "Path to the file"
673 },
674 "position": {
675 "type": "object",
676 "description": "Position in the file (line, character)",
677 "properties": {
678 "line": {"type": "integer"},
679 "character": {"type": "integer"}
680 },
681 "required": ["line", "character"]
682 }
683 },
684 "required": ["command", "file_path", "position"]
685 }),
686 ),
687 _ => (
688 format!("Tool: {}", name),
689 json!({
690 "type": "object",
691 "properties": {},
692 "required": []
693 }),
694 ),
695 }
696 }
697
698 fn send_event(&self, event: AgentEvent) {
700 if let Some(ref tx) = self.event_tx {
701 let _ = tx.send(event);
702 }
703 }
704
705 #[allow(dead_code)]
707 pub fn is_ready(&self) -> bool {
708 self.config
709 .providers
710 .get(&self.config.provider)
711 .map(|p| p.api_key_or_env(&self.config.provider).is_some())
712 .unwrap_or(false)
713 }
714
715 pub fn model(&self) -> &str {
717 self.config
718 .providers
719 .get(&self.config.provider)
720 .map(|p| p.model.as_str())
721 .unwrap_or("")
722 }
723
724 pub fn max_tokens(&self) -> u32 {
726 self.config
727 .providers
728 .get(&self.config.provider)
729 .map(|p| p.max_tokens)
730 .unwrap_or(4096)
731 }
732
733 pub fn timeout(&self) -> u64 {
735 self.config
736 .providers
737 .get(&self.config.provider)
738 .map(|p| p.timeout)
739 .unwrap_or(60)
740 }
741}
/// Estimates the cost of one LLM request from its token counts.
///
/// Rates are expressed per one million tokens (currency presumably USD; confirm). Each
/// table row lists the model aliases that share a `(input, output)` rate. Models missing
/// from the table are priced at zero.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    // (model aliases, input rate, output rate), all rates per million tokens.
    const PRICE_TABLE: &[(&[&str], f64, f64)] = &[
        (&["claude-3-5-sonnet-20241022", "claude-3-5-sonnet"], 3.0, 15.0),
        (&["gpt-4"], 30.0, 60.0),
        (&["gpt-4-turbo", "gpt-4-turbo-preview"], 10.0, 30.0),
    ];

    let (input_rate, output_rate) = PRICE_TABLE
        .iter()
        .find(|(aliases, _, _)| aliases.contains(&model))
        .map(|&(_, input, output)| (input, output))
        .unwrap_or((0.0, 0.0));

    let input_cost = input_tokens as f64 * input_rate / 1_000_000.0;
    let output_cost = output_tokens as f64 * output_rate / 1_000_000.0;
    input_cost + output_cost
}
757
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Builds a config with a single `anthropic` provider entry.
    ///
    /// `api_key` is stored as given, so `None` exercises the missing-key path.
    fn anthropic_config(api_key: Option<&str>) -> LlmConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "anthropic".to_string(),
            ProviderConfig {
                api_key: api_key.map(String::from),
                model: "claude-3-5-sonnet-20241022".to_string(),
                base_url: None,
                max_tokens: 4096,
                timeout: 60,
                max_iterations: 100,
                thinking_enabled: false,
                clear_thinking: true,
            },
        );
        LlmConfig {
            provider: "anthropic".to_string(),
            providers,
        }
    }

    #[tokio::test]
    async fn test_agent_bridge_new() {
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }

    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        let result = AgentBridge::new(anthropic_config(None));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        // A key is required: `test_agent_bridge_new_no_api_key` asserts that construction
        // fails without one, so `None` here would make this test panic on `unwrap`.
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        let definitions = bridge.get_tool_definitions();

        assert_eq!(definitions.len(), 15);

        let file_read = definitions
            .iter()
            .find(|d| d.function.name == "file_read")
            .unwrap();
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        let bash = definitions
            .iter()
            .find(|d| d.function.name == "bash")
            .unwrap();
        assert_eq!(bash.function.name, "bash");
        assert!(bash.function.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&"command".into()));

        // Every advertised tool must have an explicit schema rather than the generic
        // fallback, otherwise `tool_names` and `get_tool_schema` have drifted apart.
        for definition in &definitions {
            assert!(
                !definition.function.description.starts_with("Tool: "),
                "tool {} has no explicit schema",
                definition.function.name
            );
        }
    }

    #[test]
    fn test_get_tool_schema() {
        let (desc, params) = AgentBridge::get_tool_schema("file_read");
        assert!(desc.contains("Read"));
        assert_eq!(params["properties"]["path"]["type"], "string");
        assert!(params["required"]
            .as_array()
            .unwrap()
            .contains(&"path".into()));

        let (desc, params) = AgentBridge::get_tool_schema("bash");
        assert!(desc.contains("bash"));
        assert_eq!(params["properties"]["command"]["type"], "string");

        let (desc, _params) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(desc.contains("unknown_tool"));
    }

    #[test]
    fn test_is_ready() {
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }

    #[test]
    fn test_calculate_cost() {
        assert_eq!(calculate_cost("gpt-4", 1_000_000, 1_000_000), 90.0);
        assert_eq!(calculate_cost("claude-3-5-sonnet", 1_000_000, 1_000_000), 18.0);
        assert_eq!(calculate_cost("unknown-model", 1_000_000, 1_000_000), 0.0);
    }
}