1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::{
4 AstGrepTool, BashTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool, GitCloneTool,
5 GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool, GrepTool,
6 LspTool, WebFetchTool, WebSearchTool,
7};
8use chrono::Datelike;
9use futures::StreamExt;
10use limit_agent::executor::{ToolCall, ToolExecutor};
11use limit_agent::registry::ToolRegistry;
12use limit_llm::providers::LlmProvider;
13use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
14use limit_llm::ProviderFactory;
15use limit_llm::ProviderResponseChunk;
16use limit_llm::TrackingDb;
17use serde_json::json;
18use tokio::sync::mpsc;
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, instrument};
21
/// Progress events emitted by [`AgentBridge`] while processing a message, so
/// a UI can render streaming output, tool activity, and usage statistics.
/// Every variant carries the `operation_id` installed via
/// `set_cancellation_token`, letting consumers correlate events with requests.
#[derive(Debug, Clone)]
#[allow(dead_code)] // some variants may be unconstructed/unmatched in certain builds
pub enum AgentEvent {
    /// The agent is about to call the LLM (start of a loop iteration).
    Thinking {
        operation_id: u64,
    },
    /// A tool invocation is starting with the given parsed arguments.
    ToolStart {
        operation_id: u64,
        name: String,
        args: serde_json::Value,
    },
    /// A tool invocation finished; `result` is the JSON-serialized output
    /// (or an `{"error": ...}` object on failure).
    ToolComplete {
        operation_id: u64,
        name: String,
        result: String,
    },
    /// One streamed chunk of assistant-visible text.
    ContentChunk {
        operation_id: u64,
        chunk: String,
    },
    /// The whole operation completed successfully.
    Done {
        operation_id: u64,
    },
    /// The operation was cancelled via the cancellation token.
    Cancelled {
        operation_id: u64,
    },
    /// A fatal error occurred; `message` is human-readable.
    Error {
        operation_id: u64,
        message: String,
    },
    /// Token counts reported by the provider for one LLM request.
    TokenUsage {
        operation_id: u64,
        input_tokens: u64,
        output_tokens: u64,
    },
}
59
/// Bridges the CLI to an LLM provider and the tool executor: runs the agent
/// loop, streams progress events, and records token usage.
pub struct AgentBridge {
    /// Provider-agnostic LLM client built from `config`.
    llm_client: Box<dyn LlmProvider>,
    /// Executes tool calls requested by the model.
    executor: ToolExecutor,
    /// Names of the tools exposed to the LLM (kept in sync with `register_tools`).
    tool_names: Vec<&'static str>,
    /// Full provider configuration (active provider, model, limits, keys).
    config: limit_llm::Config,
    /// Optional channel for streaming `AgentEvent`s to a UI; `None` = silent.
    event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
    /// Local database recording per-request token usage and estimated cost.
    tracking_db: TrackingDb,
    /// Token polled between stream chunks to support user cancellation.
    cancellation_token: Option<CancellationToken>,
    /// Identifier attached to every event of the current operation.
    operation_id: u64,
}
79
80impl AgentBridge {
81 pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
89 let llm_client = ProviderFactory::create_provider(&config)
90 .map_err(|e| CliError::ConfigError(e.to_string()))?;
91
92 let mut tool_registry = ToolRegistry::new();
93 Self::register_tools(&mut tool_registry);
94
95 let executor = ToolExecutor::new(tool_registry);
97
98 let tool_names = vec![
100 "file_read",
101 "file_write",
102 "file_edit",
103 "bash",
104 "git_status",
105 "git_diff",
106 "git_log",
107 "git_add",
108 "git_commit",
109 "git_push",
110 "git_pull",
111 "git_clone",
112 "grep",
113 "ast_grep",
114 "lsp",
115 "web_search",
116 "web_fetch",
117 ];
118
119 Ok(Self {
120 llm_client,
121 executor,
122 tool_names,
123 config,
124 event_tx: None,
125 tracking_db: TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?,
126 cancellation_token: None,
127 operation_id: 0,
128 })
129 }
130
131 pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
133 self.event_tx = Some(tx);
134 }
135
136 pub fn set_cancellation_token(&mut self, token: CancellationToken, operation_id: u64) {
138 debug!("set_cancellation_token: operation_id={}", operation_id);
139 self.cancellation_token = Some(token);
140 self.operation_id = operation_id;
141 }
142
143 pub fn clear_cancellation_token(&mut self) {
145 self.cancellation_token = None;
146 }
147
148 fn register_tools(registry: &mut ToolRegistry) {
150 registry
152 .register(FileReadTool::new())
153 .expect("Failed to register file_read");
154 registry
155 .register(FileWriteTool::new())
156 .expect("Failed to register file_write");
157 registry
158 .register(FileEditTool::new())
159 .expect("Failed to register file_edit");
160
161 registry
163 .register(BashTool::new())
164 .expect("Failed to register bash");
165
166 registry
168 .register(GitStatusTool::new())
169 .expect("Failed to register git_status");
170 registry
171 .register(GitDiffTool::new())
172 .expect("Failed to register git_diff");
173 registry
174 .register(GitLogTool::new())
175 .expect("Failed to register git_log");
176 registry
177 .register(GitAddTool::new())
178 .expect("Failed to register git_add");
179 registry
180 .register(GitCommitTool::new())
181 .expect("Failed to register git_commit");
182 registry
183 .register(GitPushTool::new())
184 .expect("Failed to register git_push");
185 registry
186 .register(GitPullTool::new())
187 .expect("Failed to register git_pull");
188 registry
189 .register(GitCloneTool::new())
190 .expect("Failed to register git_clone");
191
192 registry
194 .register(GrepTool::new())
195 .expect("Failed to register grep");
196 registry
197 .register(AstGrepTool::new())
198 .expect("Failed to register ast_grep");
199 registry
200 .register(LspTool::new())
201 .expect("Failed to register lsp");
202
203 registry
205 .register(WebSearchTool::new())
206 .expect("Failed to register web_search");
207 registry
208 .register(WebFetchTool::new())
209 .expect("Failed to register web_fetch");
210 }
211
212 #[instrument(skip(self, _messages))]
221 pub async fn process_message(
222 &mut self,
223 user_input: &str,
224 _messages: &mut Vec<Message>,
225 ) -> Result<String, CliError> {
226 if _messages.is_empty() {
229 let system_message = Message {
230 role: Role::System,
231 content: Some(SYSTEM_PROMPT.to_string()),
232 tool_calls: None,
233 tool_call_id: None,
234 };
235 _messages.push(system_message);
236 }
237
238 let user_message = Message {
240 role: Role::User,
241 content: Some(user_input.to_string()),
242 tool_calls: None,
243 tool_call_id: None,
244 };
245 _messages.push(user_message);
246
247 let tool_definitions = self.get_tool_definitions();
249
250 let mut full_response = String::new();
252 let mut tool_calls: Vec<LlmToolCall> = Vec::new();
253 let max_iterations = self
254 .config
255 .providers
256 .get(&self.config.provider)
257 .map(|p| p.max_iterations)
258 .unwrap_or(100); let mut iteration = 0;
260
261 while max_iterations == 0 || iteration < max_iterations {
262 iteration += 1;
263 debug!("Agent loop iteration {}", iteration);
264
265 debug!(
267 "Sending Thinking event with operation_id={}",
268 self.operation_id
269 );
270 self.send_event(AgentEvent::Thinking {
271 operation_id: self.operation_id,
272 });
273
274 let request_start = std::time::Instant::now();
276
277 let mut stream = self
279 .llm_client
280 .send(_messages.clone(), tool_definitions.clone())
281 .await
282 .map_err(|e| CliError::ConfigError(e.to_string()))?;
283
284 tool_calls.clear();
285 let mut current_content = String::new();
286 let mut accumulated_calls: std::collections::HashMap<
288 String,
289 (String, serde_json::Value),
290 > = std::collections::HashMap::new();
291
292 loop {
294 if let Some(ref token) = self.cancellation_token {
296 if token.is_cancelled() {
297 debug!("Operation cancelled by user (pre-stream check)");
298 self.send_event(AgentEvent::Cancelled {
299 operation_id: self.operation_id,
300 });
301 return Err(CliError::ConfigError(
302 "Operation cancelled by user".to_string(),
303 ));
304 }
305 }
306
307 let chunk_result = if let Some(ref token) = self.cancellation_token {
310 tokio::select! {
311 chunk = stream.next() => chunk,
312 _ = token.cancelled() => {
313 debug!("Operation cancelled via token while waiting for stream");
314 self.send_event(AgentEvent::Cancelled {
315 operation_id: self.operation_id,
316 });
317 return Err(CliError::ConfigError("Operation cancelled by user".to_string()));
318 }
319 }
320 } else {
321 stream.next().await
322 };
323
324 let Some(chunk_result) = chunk_result else {
325 break;
327 };
328
329 match chunk_result {
330 Ok(ProviderResponseChunk::ContentDelta(text)) => {
331 current_content.push_str(&text);
332 debug!(
333 "ContentDelta: {} chars (total: {})",
334 text.len(),
335 current_content.len()
336 );
337 self.send_event(AgentEvent::ContentChunk {
338 operation_id: self.operation_id,
339 chunk: text,
340 });
341 }
342 Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
343 }
345 Ok(ProviderResponseChunk::ToolCallDelta {
346 id,
347 name,
348 arguments,
349 }) => {
350 debug!(
351 "ToolCallDelta: id={}, name={}, args_len={}",
352 id,
353 name,
354 arguments.to_string().len()
355 );
356 accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
358 }
359 Ok(ProviderResponseChunk::Done(usage)) => {
360 let duration_ms = request_start.elapsed().as_millis() as u64;
362 let cost =
363 calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
364 let _ = self.tracking_db.track_request(
365 self.model(),
366 usage.input_tokens,
367 usage.output_tokens,
368 cost,
369 duration_ms,
370 );
371 self.send_event(AgentEvent::TokenUsage {
373 operation_id: self.operation_id,
374 input_tokens: usage.input_tokens,
375 output_tokens: usage.output_tokens,
376 });
377 break;
378 }
379 Err(e) => {
380 let error_msg = format!("LLM error: {}", e);
381 self.send_event(AgentEvent::Error {
382 operation_id: self.operation_id,
383 message: error_msg.clone(),
384 });
385 return Err(CliError::ConfigError(error_msg));
386 }
387 }
388 }
389
390 tool_calls = accumulated_calls
392 .into_iter()
393 .map(|(id, (name, args))| LlmToolCall {
394 id,
395 tool_type: "function".to_string(),
396 function: limit_llm::types::FunctionCall {
397 name,
398 arguments: args.to_string(),
399 },
400 })
401 .collect();
402
403 full_response = current_content.clone();
408
409 debug!(
410 "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
411 iteration,
412 current_content.len(),
413 tool_calls.len(),
414 full_response.len()
415 );
416
417 if tool_calls.is_empty() {
419 debug!("No tool calls, breaking loop after iteration {}", iteration);
420 break;
421 }
422
423 debug!(
424 "Tool calls found (count={}), continuing to iteration {}",
425 tool_calls.len(),
426 iteration + 1
427 );
428
429 let assistant_message = Message {
432 role: Role::Assistant,
433 content: None, tool_calls: Some(tool_calls.clone()),
435 tool_call_id: None,
436 };
437 _messages.push(assistant_message);
438
439 let executor_calls: Vec<ToolCall> = tool_calls
441 .iter()
442 .map(|tc| {
443 let args: serde_json::Value =
444 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
445 ToolCall::new(&tc.id, &tc.function.name, args)
446 })
447 .collect();
448
449 for tc in &tool_calls {
451 let args: serde_json::Value =
452 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
453 self.send_event(AgentEvent::ToolStart {
454 operation_id: self.operation_id,
455 name: tc.function.name.clone(),
456 args,
457 });
458 }
459 let results = self.executor.execute_tools(executor_calls).await;
461
462 for result in results {
464 let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
465 if let Some(tool_call) = tool_call {
466 let output_json = match &result.output {
467 Ok(value) => {
468 serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
469 }
470 Err(e) => json!({ "error": e.to_string() }).to_string(),
471 };
472
473 self.send_event(AgentEvent::ToolComplete {
474 operation_id: self.operation_id,
475 name: tool_call.function.name.clone(),
476 result: output_json.clone(),
477 });
478
479 let tool_result_message = Message {
481 role: Role::Tool,
482 content: Some(output_json),
483 tool_calls: None,
484 tool_call_id: Some(result.call_id),
485 };
486 _messages.push(tool_result_message);
487 }
488 }
489 }
490
491 if max_iterations > 0 && iteration >= max_iterations && !_messages.is_empty() {
494 debug!("Making final LLM call after hitting max iterations (forcing text response)");
495
496 let constraint_message = Message {
498 role: Role::User,
499 content: Some(
500 "We've reached the iteration limit. Please provide a summary of:\n\
501 1. What you've completed so far\n\
502 2. What remains to be done\n\
503 3. Recommended next steps for the user to continue"
504 .to_string(),
505 ),
506 tool_calls: None,
507 tool_call_id: None,
508 };
509 _messages.push(constraint_message);
510
511 let no_tools: Vec<LlmTool> = vec![];
513 let mut stream = self
514 .llm_client
515 .send(_messages.clone(), no_tools)
516 .await
517 .map_err(|e| CliError::ConfigError(e.to_string()))?;
518
519 full_response.clear();
521 loop {
522 if let Some(ref token) = self.cancellation_token {
524 if token.is_cancelled() {
525 debug!("Operation cancelled by user in final loop (pre-stream check)");
526 self.send_event(AgentEvent::Cancelled {
527 operation_id: self.operation_id,
528 });
529 return Err(CliError::ConfigError(
530 "Operation cancelled by user".to_string(),
531 ));
532 }
533 }
534
535 let chunk_result = if let Some(ref token) = self.cancellation_token {
538 tokio::select! {
539 chunk = stream.next() => chunk,
540 _ = token.cancelled() => {
541 debug!("Operation cancelled via token while waiting for stream");
542 self.send_event(AgentEvent::Cancelled {
543 operation_id: self.operation_id,
544 });
545 return Err(CliError::ConfigError("Operation cancelled by user".to_string()));
546 }
547 }
548 } else {
549 stream.next().await
550 };
551
552 let Some(chunk_result) = chunk_result else {
553 break;
555 };
556
557 match chunk_result {
558 Ok(ProviderResponseChunk::ContentDelta(text)) => {
559 full_response.push_str(&text);
560 self.send_event(AgentEvent::ContentChunk {
561 operation_id: self.operation_id,
562 chunk: text,
563 });
564 }
565 Ok(ProviderResponseChunk::Done(_)) => {
566 break;
567 }
568 Err(e) => {
569 debug!("Error in final LLM call: {}", e);
570 break;
571 }
572 _ => {}
573 }
574 }
575 }
576
577 if !full_response.is_empty() {
581 let last_assistant_idx = _messages.iter().rposition(|m| m.role == Role::Assistant);
585
586 if let Some(idx) = last_assistant_idx {
587 let last_assistant = &mut _messages[idx];
588
589 if last_assistant.content.is_none()
591 || last_assistant
592 .content
593 .as_ref()
594 .map(|c| c.is_empty())
595 .unwrap_or(true)
596 {
597 last_assistant.content = Some(full_response.clone());
598 debug!("Updated last assistant message with final response content");
599 } else {
600 debug!("Last assistant already has content, adding new message");
603 let final_assistant_message = Message {
604 role: Role::Assistant,
605 content: Some(full_response.clone()),
606 tool_calls: None,
607 tool_call_id: None,
608 };
609 _messages.push(final_assistant_message);
610 }
611 } else {
612 debug!("No assistant message found, adding new message");
614 let final_assistant_message = Message {
615 role: Role::Assistant,
616 content: Some(full_response.clone()),
617 tool_calls: None,
618 tool_call_id: None,
619 };
620 _messages.push(final_assistant_message);
621 }
622 }
623
624 self.send_event(AgentEvent::Done {
625 operation_id: self.operation_id,
626 });
627 Ok(full_response)
628 }
629
630 pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
632 self.tool_names
633 .iter()
634 .map(|name| {
635 let (description, parameters) = Self::get_tool_schema(name);
636 LlmTool {
637 tool_type: "function".to_string(),
638 function: limit_llm::types::ToolFunction {
639 name: name.to_string(),
640 description,
641 parameters,
642 },
643 }
644 })
645 .collect()
646 }
647
648 fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
650 match name {
651 "file_read" => (
652 "Read the contents of a file".to_string(),
653 json!({
654 "type": "object",
655 "properties": {
656 "path": {
657 "type": "string",
658 "description": "Path to the file to read"
659 }
660 },
661 "required": ["path"]
662 }),
663 ),
664 "file_write" => (
665 "Write content to a file, creating parent directories if needed".to_string(),
666 json!({
667 "type": "object",
668 "properties": {
669 "path": {
670 "type": "string",
671 "description": "Path to the file to write"
672 },
673 "content": {
674 "type": "string",
675 "description": "Content to write to the file"
676 }
677 },
678 "required": ["path", "content"]
679 }),
680 ),
681 "file_edit" => (
682 "Replace text in a file with new text".to_string(),
683 json!({
684 "type": "object",
685 "properties": {
686 "path": {
687 "type": "string",
688 "description": "Path to the file to edit"
689 },
690 "old_text": {
691 "type": "string",
692 "description": "Text to find and replace"
693 },
694 "new_text": {
695 "type": "string",
696 "description": "New text to replace with"
697 }
698 },
699 "required": ["path", "old_text", "new_text"]
700 }),
701 ),
702 "bash" => (
703 "Execute a bash command in a shell".to_string(),
704 json!({
705 "type": "object",
706 "properties": {
707 "command": {
708 "type": "string",
709 "description": "Bash command to execute"
710 },
711 "workdir": {
712 "type": "string",
713 "description": "Working directory (default: current directory)"
714 },
715 "timeout": {
716 "type": "integer",
717 "description": "Timeout in seconds (default: 60)"
718 }
719 },
720 "required": ["command"]
721 }),
722 ),
723 "git_status" => (
724 "Get git repository status".to_string(),
725 json!({
726 "type": "object",
727 "properties": {},
728 "required": []
729 }),
730 ),
731 "git_diff" => (
732 "Get git diff".to_string(),
733 json!({
734 "type": "object",
735 "properties": {},
736 "required": []
737 }),
738 ),
739 "git_log" => (
740 "Get git commit log".to_string(),
741 json!({
742 "type": "object",
743 "properties": {
744 "count": {
745 "type": "integer",
746 "description": "Number of commits to show (default: 10)"
747 }
748 },
749 "required": []
750 }),
751 ),
752 "git_add" => (
753 "Add files to git staging area".to_string(),
754 json!({
755 "type": "object",
756 "properties": {
757 "files": {
758 "type": "array",
759 "items": {"type": "string"},
760 "description": "List of file paths to add"
761 }
762 },
763 "required": ["files"]
764 }),
765 ),
766 "git_commit" => (
767 "Create a git commit".to_string(),
768 json!({
769 "type": "object",
770 "properties": {
771 "message": {
772 "type": "string",
773 "description": "Commit message"
774 }
775 },
776 "required": ["message"]
777 }),
778 ),
779 "git_push" => (
780 "Push commits to remote repository".to_string(),
781 json!({
782 "type": "object",
783 "properties": {
784 "remote": {
785 "type": "string",
786 "description": "Remote name (default: origin)"
787 },
788 "branch": {
789 "type": "string",
790 "description": "Branch name (default: current branch)"
791 }
792 },
793 "required": []
794 }),
795 ),
796 "git_pull" => (
797 "Pull changes from remote repository".to_string(),
798 json!({
799 "type": "object",
800 "properties": {
801 "remote": {
802 "type": "string",
803 "description": "Remote name (default: origin)"
804 },
805 "branch": {
806 "type": "string",
807 "description": "Branch name (default: current branch)"
808 }
809 },
810 "required": []
811 }),
812 ),
813 "git_clone" => (
814 "Clone a git repository".to_string(),
815 json!({
816 "type": "object",
817 "properties": {
818 "url": {
819 "type": "string",
820 "description": "Repository URL to clone"
821 },
822 "directory": {
823 "type": "string",
824 "description": "Directory to clone into (optional)"
825 }
826 },
827 "required": ["url"]
828 }),
829 ),
830 "grep" => (
831 "Search for text patterns in files using regex".to_string(),
832 json!({
833 "type": "object",
834 "properties": {
835 "pattern": {
836 "type": "string",
837 "description": "Regex pattern to search for"
838 },
839 "path": {
840 "type": "string",
841 "description": "Path to search in (default: current directory)"
842 }
843 },
844 "required": ["pattern"]
845 }),
846 ),
847 "ast_grep" => (
848 "Search code using AST patterns (structural code matching)".to_string(),
849 json!({
850 "type": "object",
851 "properties": {
852 "pattern": {
853 "type": "string",
854 "description": "AST pattern to match"
855 },
856 "language": {
857 "type": "string",
858 "description": "Programming language (rust, typescript, python)"
859 },
860 "path": {
861 "type": "string",
862 "description": "Path to search in (default: current directory)"
863 }
864 },
865 "required": ["pattern", "language"]
866 }),
867 ),
868 "lsp" => (
869 "Perform Language Server Protocol operations (goto_definition, find_references)"
870 .to_string(),
871 json!({
872 "type": "object",
873 "properties": {
874 "command": {
875 "type": "string",
876 "description": "LSP command: goto_definition or find_references"
877 },
878 "file_path": {
879 "type": "string",
880 "description": "Path to the file"
881 },
882 "position": {
883 "type": "object",
884 "description": "Position in the file (line, character)",
885 "properties": {
886 "line": {"type": "integer"},
887 "character": {"type": "integer"}
888 },
889 "required": ["line", "character"]
890 }
891 },
892 "required": ["command", "file_path", "position"]
893 }),
894 ),
895 "web_search" => (
896 format!("Search the web using Exa AI. Returns results with titles, URLs, and content snippets. Use for current information beyond knowledge cutoff. The current year is {} - use this year when searching for recent information.", chrono::Local::now().year()),
897 json!({
898 "type": "object",
899 "properties": {
900 "query": {
901 "type": "string",
902 "description": format!("Search query. Be specific for better results (e.g., 'Rust async tutorial {}' rather than 'Rust')", chrono::Local::now().year())
903 },
904 "numResults": {
905 "type": "integer",
906 "description": "Number of results to return (default: 8, max: 20)",
907 "default": 8
908 }
909 },
910 "required": ["query"]
911 }),
912 ),
913 "web_fetch" => (
914 "Fetch content from a URL. Converts HTML to markdown format by default. Use when user provides a URL or after web_search to read full content of a specific result.".to_string(),
915 json!({
916 "type": "object",
917 "properties": {
918 "url": {
919 "type": "string",
920 "description": "URL to fetch (must start with http:// or https://)"
921 },
922 "format": {
923 "type": "string",
924 "enum": ["markdown", "text", "html"],
925 "default": "markdown",
926 "description": "Output format (default: markdown)"
927 }
928 },
929 "required": ["url"]
930 }),
931 ),
932 _ => (
933 format!("Tool: {}", name),
934 json!({
935 "type": "object",
936 "properties": {},
937 "required": []
938 }),
939 ),
940 }
941 }
942
943 fn send_event(&self, event: AgentEvent) {
945 if let Some(ref tx) = self.event_tx {
946 let _ = tx.send(event);
947 }
948 }
949
950 #[allow(dead_code)]
952 pub fn is_ready(&self) -> bool {
953 self.config
954 .providers
955 .get(&self.config.provider)
956 .map(|p| p.api_key_or_env(&self.config.provider).is_some())
957 .unwrap_or(false)
958 }
959
960 pub fn model(&self) -> &str {
962 self.config
963 .providers
964 .get(&self.config.provider)
965 .map(|p| p.model.as_str())
966 .unwrap_or("")
967 }
968
969 pub fn max_tokens(&self) -> u32 {
971 self.config
972 .providers
973 .get(&self.config.provider)
974 .map(|p| p.max_tokens)
975 .unwrap_or(4096)
976 }
977
978 pub fn timeout(&self) -> u64 {
980 self.config
981 .providers
982 .get(&self.config.provider)
983 .map(|p| p.timeout)
984 .unwrap_or(60)
985 }
986}
/// Estimates the USD cost of one request from hard-coded per-million-token
/// prices for known models; unknown models are priced at zero.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    const TOKENS_PER_PRICE_UNIT: f64 = 1_000_000.0;
    // (input $/Mtok, output $/Mtok)
    let (input_price, output_price) = match model {
        "claude-3-5-sonnet-20241022" | "claude-3-5-sonnet" => (3.0, 15.0),
        "gpt-4" => (30.0, 60.0),
        "gpt-4-turbo" | "gpt-4-turbo-preview" => (10.0, 30.0),
        _ => (0.0, 0.0),
    };
    let input_cost = input_tokens as f64 * input_price / TOKENS_PER_PRICE_UNIT;
    let output_cost = output_tokens as f64 * output_price / TOKENS_PER_PRICE_UNIT;
    input_cost + output_cost
}
1002
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Builds an "anthropic" config with the given optional API key; all other
    /// provider settings are the values shared by every test here. Extracted
    /// because the same 25-line fixture was previously duplicated four times.
    fn test_config(api_key: Option<&str>) -> LlmConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "anthropic".to_string(),
            ProviderConfig {
                api_key: api_key.map(str::to_string),
                model: "claude-3-5-sonnet-20241022".to_string(),
                base_url: None,
                max_tokens: 4096,
                timeout: 60,
                max_iterations: 100,
                thinking_enabled: false,
                clear_thinking: true,
            },
        );
        LlmConfig {
            provider: "anthropic".to_string(),
            providers,
        }
    }

    #[tokio::test]
    async fn test_agent_bridge_new() {
        let bridge = AgentBridge::new(test_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }

    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        let result = AgentBridge::new(test_config(None));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let bridge = AgentBridge::new(test_config(Some("test-key"))).unwrap();
        let definitions = bridge.get_tool_definitions();

        // One definition per entry in `tool_names`.
        assert_eq!(definitions.len(), 17);

        let file_read = definitions
            .iter()
            .find(|d| d.function.name == "file_read")
            .unwrap();
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        let bash = definitions
            .iter()
            .find(|d| d.function.name == "bash")
            .unwrap();
        assert_eq!(bash.function.name, "bash");
        assert!(bash.function.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&"command".into()));
    }

    #[test]
    fn test_get_tool_schema() {
        let (desc, params) = AgentBridge::get_tool_schema("file_read");
        assert!(desc.contains("Read"));
        assert_eq!(params["properties"]["path"]["type"], "string");
        assert!(params["required"]
            .as_array()
            .unwrap()
            .contains(&"path".into()));

        let (desc, params) = AgentBridge::get_tool_schema("bash");
        assert!(desc.contains("bash"));
        assert_eq!(params["properties"]["command"]["type"], "string");

        // Unknown names fall back to a generic description.
        let (desc, _params) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(desc.contains("unknown_tool"));
    }

    #[test]
    fn test_is_ready() {
        let bridge = AgentBridge::new(test_config(Some("test-key"))).unwrap();
        assert!(bridge.is_ready());
    }
}