1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::{
4 AstGrepTool, BashTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool, GitCloneTool,
5 GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool, GrepTool,
6 LspTool,
7};
8use futures::StreamExt;
9use limit_agent::executor::{ToolCall, ToolExecutor};
10use limit_agent::registry::ToolRegistry;
11use limit_llm::providers::LlmProvider;
12use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
13use limit_llm::ProviderFactory;
14use limit_llm::ProviderResponseChunk;
15use limit_llm::TrackingDb;
16use serde_json::json;
17use tokio::sync::mpsc;
18use tracing::{debug, instrument};
19
20#[derive(Debug, Clone)]
22#[allow(dead_code)]
23pub enum AgentEvent {
24 Thinking,
25 ToolStart {
26 name: String,
27 args: serde_json::Value,
28 },
29 ToolComplete {
30 name: String,
31 result: String,
32 },
33 ContentChunk(String),
34 Done,
35 Error(String),
36 TokenUsage {
37 input_tokens: u64,
38 output_tokens: u64,
39 },
40}
41
42pub struct AgentBridge {
44 llm_client: Box<dyn LlmProvider>,
46 executor: ToolExecutor,
48 tool_names: Vec<&'static str>,
50 config: limit_llm::Config,
52 event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
54 tracking_db: TrackingDb,
56}
57
58impl AgentBridge {
59 pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
67 let llm_client = ProviderFactory::create_provider(&config)
68 .map_err(|e| CliError::ConfigError(e.to_string()))?;
69
70 let mut tool_registry = ToolRegistry::new();
71 Self::register_tools(&mut tool_registry);
72
73 let executor = ToolExecutor::new(tool_registry);
75
76 let tool_names = vec![
78 "file_read",
79 "file_write",
80 "file_edit",
81 "bash",
82 "git_status",
83 "git_diff",
84 "git_log",
85 "git_add",
86 "git_commit",
87 "git_push",
88 "git_pull",
89 "git_clone",
90 "grep",
91 "ast_grep",
92 "lsp",
93 ];
94
95 Ok(Self {
96 llm_client,
97 executor,
98 tool_names,
99 config,
100 event_tx: None,
101 tracking_db: TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?,
102 })
103 }
104
105 pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
107 self.event_tx = Some(tx);
108 }
109
110 fn register_tools(registry: &mut ToolRegistry) {
112 registry
114 .register(FileReadTool::new())
115 .expect("Failed to register file_read");
116 registry
117 .register(FileWriteTool::new())
118 .expect("Failed to register file_write");
119 registry
120 .register(FileEditTool::new())
121 .expect("Failed to register file_edit");
122
123 registry
125 .register(BashTool::new())
126 .expect("Failed to register bash");
127
128 registry
130 .register(GitStatusTool::new())
131 .expect("Failed to register git_status");
132 registry
133 .register(GitDiffTool::new())
134 .expect("Failed to register git_diff");
135 registry
136 .register(GitLogTool::new())
137 .expect("Failed to register git_log");
138 registry
139 .register(GitAddTool::new())
140 .expect("Failed to register git_add");
141 registry
142 .register(GitCommitTool::new())
143 .expect("Failed to register git_commit");
144 registry
145 .register(GitPushTool::new())
146 .expect("Failed to register git_push");
147 registry
148 .register(GitPullTool::new())
149 .expect("Failed to register git_pull");
150 registry
151 .register(GitCloneTool::new())
152 .expect("Failed to register git_clone");
153
154 registry
156 .register(GrepTool::new())
157 .expect("Failed to register grep");
158 registry
159 .register(AstGrepTool::new())
160 .expect("Failed to register ast_grep");
161 registry
162 .register(LspTool::new())
163 .expect("Failed to register lsp");
164 }
165
166 #[instrument(skip(self, _messages))]
175 pub async fn process_message(
176 &mut self,
177 user_input: &str,
178 _messages: &mut Vec<Message>,
179 ) -> Result<String, CliError> {
180 if _messages.is_empty() {
183 let system_message = Message {
184 role: Role::System,
185 content: Some(SYSTEM_PROMPT.to_string()),
186 tool_calls: None,
187 tool_call_id: None,
188 };
189 _messages.push(system_message);
190 }
191
192 let user_message = Message {
194 role: Role::User,
195 content: Some(user_input.to_string()),
196 tool_calls: None,
197 tool_call_id: None,
198 };
199 _messages.push(user_message);
200
201 let tool_definitions = self.get_tool_definitions();
203
204 let mut full_response = String::new();
206 let mut tool_calls: Vec<LlmToolCall> = Vec::new();
207 let max_iterations = 30; let mut iteration = 0;
209
210 while iteration < max_iterations {
211 iteration += 1;
212 debug!("Agent loop iteration {}", iteration);
213
214 self.send_event(AgentEvent::Thinking);
216
217 let request_start = std::time::Instant::now();
219
220 let mut stream = self
222 .llm_client
223 .send(_messages.clone(), tool_definitions.clone())
224 .await
225 .map_err(|e| CliError::ConfigError(e.to_string()))?;
226
227 tool_calls.clear();
228 let mut current_content = String::new();
229 let mut accumulated_calls: std::collections::HashMap<
231 String,
232 (String, serde_json::Value),
233 > = std::collections::HashMap::new();
234
235 while let Some(chunk_result) = stream.next().await {
237 match chunk_result {
238 Ok(ProviderResponseChunk::ContentDelta(text)) => {
239 current_content.push_str(&text);
240 self.send_event(AgentEvent::ContentChunk(text));
241 }
242 Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
243 }
245 Ok(ProviderResponseChunk::ToolCallDelta {
246 id,
247 name,
248 arguments,
249 }) => {
250 debug!("ToolCallDelta: id={}, name={}", id, name);
251 accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
253 }
254 Ok(ProviderResponseChunk::Done(usage)) => {
255 let duration_ms = request_start.elapsed().as_millis() as u64;
257 let cost =
258 calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
259 let _ = self.tracking_db.track_request(
260 self.model(),
261 usage.input_tokens,
262 usage.output_tokens,
263 cost,
264 duration_ms,
265 );
266 self.send_event(AgentEvent::TokenUsage {
268 input_tokens: usage.input_tokens,
269 output_tokens: usage.output_tokens,
270 });
271 break;
272 }
273 Err(e) => {
274 let error_msg = format!("LLM error: {}", e);
275 self.send_event(AgentEvent::Error(error_msg.clone()));
276 return Err(CliError::ConfigError(error_msg));
277 }
278 }
279 }
280
281 tool_calls = accumulated_calls
283 .into_iter()
284 .map(|(id, (name, args))| LlmToolCall {
285 id,
286 tool_type: "function".to_string(),
287 function: limit_llm::types::FunctionCall {
288 name,
289 arguments: args.to_string(),
290 },
291 })
292 .collect();
293 full_response.push_str(¤t_content);
294
295 debug!(
296 "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
297 iteration,
298 current_content.len(),
299 tool_calls.len(),
300 full_response.len()
301 );
302
303 if tool_calls.is_empty() {
305 break;
306 }
307
308 let assistant_message = Message {
311 role: Role::Assistant,
312 content: None, tool_calls: Some(tool_calls.clone()),
314 tool_call_id: None,
315 };
316 _messages.push(assistant_message);
317
318 let executor_calls: Vec<ToolCall> = tool_calls
320 .iter()
321 .map(|tc| {
322 let args: serde_json::Value =
323 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
324 ToolCall::new(&tc.id, &tc.function.name, args)
325 })
326 .collect();
327
328 for tc in &tool_calls {
330 let args: serde_json::Value =
331 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
332 self.send_event(AgentEvent::ToolStart {
333 name: tc.function.name.clone(),
334 args,
335 });
336 }
337 let results = self.executor.execute_tools(executor_calls).await;
339
340 for result in results {
342 let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
343 if let Some(tool_call) = tool_call {
344 let output_json = match &result.output {
345 Ok(value) => {
346 serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
347 }
348 Err(e) => json!({ "error": e.to_string() }).to_string(),
349 };
350
351 self.send_event(AgentEvent::ToolComplete {
352 name: tool_call.function.name.clone(),
353 result: output_json.clone(),
354 });
355
356 let tool_result_message = Message {
358 role: Role::Tool,
359 content: Some(output_json),
360 tool_calls: None,
361 tool_call_id: Some(result.call_id),
362 };
363 _messages.push(tool_result_message);
364 }
365 }
366 }
367
368 if iteration >= max_iterations && !_messages.is_empty() {
370 debug!("Making final LLM call after hitting max iterations (forcing text response)");
371
372 let constraint_message = Message {
374 role: Role::User,
375 content: Some(
376 "We've reached the iteration limit. Please provide a summary of:\n\
377 1. What you've completed so far\n\
378 2. What remains to be done\n\
379 3. Recommended next steps for the user to continue"
380 .to_string(),
381 ),
382 tool_calls: None,
383 tool_call_id: None,
384 };
385 _messages.push(constraint_message);
386
387 let no_tools: Vec<LlmTool> = vec![];
389 let mut stream = self
390 .llm_client
391 .send(_messages.clone(), no_tools)
392 .await
393 .map_err(|e| CliError::ConfigError(e.to_string()))?;
394
395 while let Some(chunk_result) = stream.next().await {
396 match chunk_result {
397 Ok(ProviderResponseChunk::ContentDelta(text)) => {
398 full_response.push_str(&text);
399 self.send_event(AgentEvent::ContentChunk(text));
400 }
401 Ok(ProviderResponseChunk::Done(_)) => {
402 break;
403 }
404 Err(e) => {
405 debug!("Error in final LLM call: {}", e);
406 break;
407 }
408 _ => {}
409 }
410 }
411 }
412
413 self.send_event(AgentEvent::Done);
414 Ok(full_response)
415 }
416
417 pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
419 self.tool_names
420 .iter()
421 .map(|name| {
422 let (description, parameters) = Self::get_tool_schema(name);
423 LlmTool {
424 tool_type: "function".to_string(),
425 function: limit_llm::types::ToolFunction {
426 name: name.to_string(),
427 description,
428 parameters,
429 },
430 }
431 })
432 .collect()
433 }
434
435 fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
437 match name {
438 "file_read" => (
439 "Read the contents of a file".to_string(),
440 json!({
441 "type": "object",
442 "properties": {
443 "path": {
444 "type": "string",
445 "description": "Path to the file to read"
446 }
447 },
448 "required": ["path"]
449 }),
450 ),
451 "file_write" => (
452 "Write content to a file, creating parent directories if needed".to_string(),
453 json!({
454 "type": "object",
455 "properties": {
456 "path": {
457 "type": "string",
458 "description": "Path to the file to write"
459 },
460 "content": {
461 "type": "string",
462 "description": "Content to write to the file"
463 }
464 },
465 "required": ["path", "content"]
466 }),
467 ),
468 "file_edit" => (
469 "Replace text in a file with new text".to_string(),
470 json!({
471 "type": "object",
472 "properties": {
473 "path": {
474 "type": "string",
475 "description": "Path to the file to edit"
476 },
477 "old_text": {
478 "type": "string",
479 "description": "Text to find and replace"
480 },
481 "new_text": {
482 "type": "string",
483 "description": "New text to replace with"
484 }
485 },
486 "required": ["path", "old_text", "new_text"]
487 }),
488 ),
489 "bash" => (
490 "Execute a bash command in a shell".to_string(),
491 json!({
492 "type": "object",
493 "properties": {
494 "command": {
495 "type": "string",
496 "description": "Bash command to execute"
497 },
498 "workdir": {
499 "type": "string",
500 "description": "Working directory (default: current directory)"
501 },
502 "timeout": {
503 "type": "integer",
504 "description": "Timeout in seconds (default: 60)"
505 }
506 },
507 "required": ["command"]
508 }),
509 ),
510 "git_status" => (
511 "Get git repository status".to_string(),
512 json!({
513 "type": "object",
514 "properties": {},
515 "required": []
516 }),
517 ),
518 "git_diff" => (
519 "Get git diff".to_string(),
520 json!({
521 "type": "object",
522 "properties": {},
523 "required": []
524 }),
525 ),
526 "git_log" => (
527 "Get git commit log".to_string(),
528 json!({
529 "type": "object",
530 "properties": {
531 "count": {
532 "type": "integer",
533 "description": "Number of commits to show (default: 10)"
534 }
535 },
536 "required": []
537 }),
538 ),
539 "git_add" => (
540 "Add files to git staging area".to_string(),
541 json!({
542 "type": "object",
543 "properties": {
544 "files": {
545 "type": "array",
546 "items": {"type": "string"},
547 "description": "List of file paths to add"
548 }
549 },
550 "required": ["files"]
551 }),
552 ),
553 "git_commit" => (
554 "Create a git commit".to_string(),
555 json!({
556 "type": "object",
557 "properties": {
558 "message": {
559 "type": "string",
560 "description": "Commit message"
561 }
562 },
563 "required": ["message"]
564 }),
565 ),
566 "git_push" => (
567 "Push commits to remote repository".to_string(),
568 json!({
569 "type": "object",
570 "properties": {
571 "remote": {
572 "type": "string",
573 "description": "Remote name (default: origin)"
574 },
575 "branch": {
576 "type": "string",
577 "description": "Branch name (default: current branch)"
578 }
579 },
580 "required": []
581 }),
582 ),
583 "git_pull" => (
584 "Pull changes from remote repository".to_string(),
585 json!({
586 "type": "object",
587 "properties": {
588 "remote": {
589 "type": "string",
590 "description": "Remote name (default: origin)"
591 },
592 "branch": {
593 "type": "string",
594 "description": "Branch name (default: current branch)"
595 }
596 },
597 "required": []
598 }),
599 ),
600 "git_clone" => (
601 "Clone a git repository".to_string(),
602 json!({
603 "type": "object",
604 "properties": {
605 "url": {
606 "type": "string",
607 "description": "Repository URL to clone"
608 },
609 "directory": {
610 "type": "string",
611 "description": "Directory to clone into (optional)"
612 }
613 },
614 "required": ["url"]
615 }),
616 ),
617 "grep" => (
618 "Search for text patterns in files using regex".to_string(),
619 json!({
620 "type": "object",
621 "properties": {
622 "pattern": {
623 "type": "string",
624 "description": "Regex pattern to search for"
625 },
626 "path": {
627 "type": "string",
628 "description": "Path to search in (default: current directory)"
629 }
630 },
631 "required": ["pattern"]
632 }),
633 ),
634 "ast_grep" => (
635 "Search code using AST patterns (structural code matching)".to_string(),
636 json!({
637 "type": "object",
638 "properties": {
639 "pattern": {
640 "type": "string",
641 "description": "AST pattern to match"
642 },
643 "language": {
644 "type": "string",
645 "description": "Programming language (rust, typescript, python)"
646 },
647 "path": {
648 "type": "string",
649 "description": "Path to search in (default: current directory)"
650 }
651 },
652 "required": ["pattern", "language"]
653 }),
654 ),
655 "lsp" => (
656 "Perform Language Server Protocol operations (goto_definition, find_references)"
657 .to_string(),
658 json!({
659 "type": "object",
660 "properties": {
661 "command": {
662 "type": "string",
663 "description": "LSP command: goto_definition or find_references"
664 },
665 "file_path": {
666 "type": "string",
667 "description": "Path to the file"
668 },
669 "position": {
670 "type": "object",
671 "description": "Position in the file (line, character)",
672 "properties": {
673 "line": {"type": "integer"},
674 "character": {"type": "integer"}
675 },
676 "required": ["line", "character"]
677 }
678 },
679 "required": ["command", "file_path", "position"]
680 }),
681 ),
682 _ => (
683 format!("Tool: {}", name),
684 json!({
685 "type": "object",
686 "properties": {},
687 "required": []
688 }),
689 ),
690 }
691 }
692
693 fn send_event(&self, event: AgentEvent) {
695 if let Some(ref tx) = self.event_tx {
696 let _ = tx.send(event);
697 }
698 }
699
700 #[allow(dead_code)]
702 pub fn is_ready(&self) -> bool {
703 self.config
704 .providers
705 .get(&self.config.provider)
706 .map(|p| p.api_key_or_env(&self.config.provider).is_some())
707 .unwrap_or(false)
708 }
709
710 pub fn model(&self) -> &str {
712 self.config
713 .providers
714 .get(&self.config.provider)
715 .map(|p| p.model.as_str())
716 .unwrap_or("")
717 }
718
719 pub fn max_tokens(&self) -> u32 {
721 self.config
722 .providers
723 .get(&self.config.provider)
724 .map(|p| p.max_tokens)
725 .unwrap_or(4096)
726 }
727
728 pub fn timeout(&self) -> u64 {
730 self.config
731 .providers
732 .get(&self.config.provider)
733 .map(|p| p.timeout)
734 .unwrap_or(60)
735 }
736}
/// Estimates the cost of one request from its token counts.
///
/// Prices are per one million tokens as `(input, output)`; they match the
/// providers' published list prices (presumably USD). Dated snapshots and
/// aliases of the same model share one price. Any model not listed is
/// priced at `0.0`, so its tracked cost under-reports rather than failing.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    const TOKENS_PER_PRICE_UNIT: f64 = 1_000_000.0;

    let (input_price, output_price) = match model {
        "claude-3-5-sonnet-20241022"
        | "claude-3-5-sonnet-20240620"
        | "claude-3-5-sonnet-latest"
        | "claude-3-5-sonnet" => (3.0, 15.0),
        "gpt-4" | "gpt-4-0613" => (30.0, 60.0),
        "gpt-4-turbo"
        | "gpt-4-turbo-preview"
        | "gpt-4-turbo-2024-04-09"
        | "gpt-4-0125-preview"
        | "gpt-4-1106-preview" => (10.0, 30.0),
        _ => (0.0, 0.0),
    };

    (input_tokens as f64 * input_price / TOKENS_PER_PRICE_UNIT)
        + (output_tokens as f64 * output_price / TOKENS_PER_PRICE_UNIT)
}
752
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Builds a config whose only (and selected) provider is `anthropic`,
    /// optionally carrying an explicit API key.
    fn anthropic_config(api_key: Option<&str>) -> LlmConfig {
        let anthropic = ProviderConfig {
            api_key: api_key.map(|key| key.to_string()),
            model: "claude-3-5-sonnet-20241022".to_string(),
            base_url: None,
            max_tokens: 4096,
            timeout: 60,
        };

        let mut providers = HashMap::new();
        providers.insert("anthropic".to_string(), anthropic);

        LlmConfig {
            provider: "anthropic".to_string(),
            providers,
        }
    }

    #[tokio::test]
    async fn test_agent_bridge_new() {
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }

    // NOTE(review): the key may also be resolved from the environment
    // (`api_key_or_env`), so this could pass/fail depending on the host — confirm.
    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        let outcome = AgentBridge::new(anthropic_config(None));
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        let definitions = bridge.get_tool_definitions();

        assert_eq!(definitions.len(), 15);

        let by_name = |wanted: &str| {
            definitions
                .iter()
                .find(|d| d.function.name == wanted)
                .unwrap()
        };

        let file_read = by_name("file_read");
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        let bash = by_name("bash");
        assert_eq!(bash.function.name, "bash");
        let required = bash.function.parameters["required"].as_array().unwrap();
        assert!(required.contains(&"command".into()));
    }

    #[test]
    fn test_get_tool_schema() {
        let (description, params) = AgentBridge::get_tool_schema("file_read");
        assert!(description.contains("Read"));
        assert_eq!(params["properties"]["path"]["type"], "string");
        let required = params["required"].as_array().unwrap();
        assert!(required.contains(&"path".into()));

        let (description, params) = AgentBridge::get_tool_schema("bash");
        assert!(description.contains("bash"));
        assert_eq!(params["properties"]["command"]["type"], "string");

        let (description, _) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(description.contains("unknown_tool"));
    }

    #[test]
    fn test_is_ready() {
        let bridge = AgentBridge::new(anthropic_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }
}