1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::{
4 AstGrepTool, BashTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool, GitCloneTool,
5 GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool, GrepTool,
6 LspTool, WebFetchTool, WebSearchTool,
7};
8use chrono::Datelike;
9use futures::StreamExt;
10use limit_agent::executor::{ToolCall, ToolExecutor};
11use limit_agent::registry::ToolRegistry;
12use limit_llm::providers::LlmProvider;
13use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
14use limit_llm::ProviderFactory;
15use limit_llm::ProviderResponseChunk;
16use limit_llm::TrackingDb;
17use serde_json::json;
18use tokio::sync::mpsc;
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, instrument};
21
/// Events emitted while a message is being processed, streamed to the UI via
/// the channel installed with `AgentBridge::set_event_tx`.
///
/// Every variant carries the `operation_id` that was registered with
/// `set_cancellation_token`, so consumers can correlate events with the
/// in-flight operation.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum AgentEvent {
    /// An LLM request is about to be sent (start of an agent-loop iteration).
    Thinking {
        operation_id: u64,
    },
    /// A tool invocation is starting; `args` are the parsed JSON arguments.
    ToolStart {
        operation_id: u64,
        name: String,
        args: serde_json::Value,
    },
    /// A tool invocation finished; `result` is the JSON-serialized output,
    /// or a `{"error": ...}` object when the tool failed.
    ToolComplete {
        operation_id: u64,
        name: String,
        result: String,
    },
    /// One streamed chunk of assistant text.
    ContentChunk {
        operation_id: u64,
        chunk: String,
    },
    /// The whole operation finished successfully.
    Done {
        operation_id: u64,
    },
    /// The operation was aborted via the cancellation token.
    Cancelled {
        operation_id: u64,
    },
    /// A provider error occurred; `message` is a human-readable description.
    Error {
        operation_id: u64,
        message: String,
    },
    /// Token usage reported by the provider for one LLM request.
    TokenUsage {
        operation_id: u64,
        input_tokens: u64,
        output_tokens: u64,
    },
}
59
/// Bridges the CLI to the LLM provider and the tool executor: owns the
/// provider client, drives the agent loop, and reports progress as
/// [`AgentEvent`]s.
pub struct AgentBridge {
    /// Provider-specific LLM client built from `config` by `ProviderFactory`.
    llm_client: Box<dyn LlmProvider>,
    /// Executes the tool calls requested by the model.
    executor: ToolExecutor,
    /// Names of all registered tools; used to build LLM tool definitions.
    tool_names: Vec<&'static str>,
    /// Full LLM configuration (active provider plus per-provider settings).
    config: limit_llm::Config,
    /// Optional channel for streaming [`AgentEvent`]s to a consumer.
    event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
    /// Records per-request token usage and estimated cost.
    tracking_db: TrackingDb,
    /// Token polled during streaming to abort the current operation.
    cancellation_token: Option<CancellationToken>,
    /// Identifier attached to every event of the current operation.
    operation_id: u64,
}
79
80impl AgentBridge {
81 pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
89 let llm_client = ProviderFactory::create_provider(&config)
90 .map_err(|e| CliError::ConfigError(e.to_string()))?;
91
92 let mut tool_registry = ToolRegistry::new();
93 Self::register_tools(&mut tool_registry);
94
95 let executor = ToolExecutor::new(tool_registry);
97
98 let tool_names = vec![
100 "file_read",
101 "file_write",
102 "file_edit",
103 "bash",
104 "git_status",
105 "git_diff",
106 "git_log",
107 "git_add",
108 "git_commit",
109 "git_push",
110 "git_pull",
111 "git_clone",
112 "grep",
113 "ast_grep",
114 "lsp",
115 "web_search",
116 "web_fetch",
117 ];
118
119 Ok(Self {
120 llm_client,
121 executor,
122 tool_names,
123 config,
124 event_tx: None,
125 tracking_db: TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?,
126 cancellation_token: None,
127 operation_id: 0,
128 })
129 }
130
131 pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
133 self.event_tx = Some(tx);
134 }
135
136 pub fn set_cancellation_token(&mut self, token: CancellationToken, operation_id: u64) {
138 debug!("set_cancellation_token: operation_id={}", operation_id);
139 self.cancellation_token = Some(token);
140 self.operation_id = operation_id;
141 }
142
143 pub fn clear_cancellation_token(&mut self) {
145 self.cancellation_token = None;
146 }
147
148 fn register_tools(registry: &mut ToolRegistry) {
150 registry
152 .register(FileReadTool::new())
153 .expect("Failed to register file_read");
154 registry
155 .register(FileWriteTool::new())
156 .expect("Failed to register file_write");
157 registry
158 .register(FileEditTool::new())
159 .expect("Failed to register file_edit");
160
161 registry
163 .register(BashTool::new())
164 .expect("Failed to register bash");
165
166 registry
168 .register(GitStatusTool::new())
169 .expect("Failed to register git_status");
170 registry
171 .register(GitDiffTool::new())
172 .expect("Failed to register git_diff");
173 registry
174 .register(GitLogTool::new())
175 .expect("Failed to register git_log");
176 registry
177 .register(GitAddTool::new())
178 .expect("Failed to register git_add");
179 registry
180 .register(GitCommitTool::new())
181 .expect("Failed to register git_commit");
182 registry
183 .register(GitPushTool::new())
184 .expect("Failed to register git_push");
185 registry
186 .register(GitPullTool::new())
187 .expect("Failed to register git_pull");
188 registry
189 .register(GitCloneTool::new())
190 .expect("Failed to register git_clone");
191
192 registry
194 .register(GrepTool::new())
195 .expect("Failed to register grep");
196 registry
197 .register(AstGrepTool::new())
198 .expect("Failed to register ast_grep");
199 registry
200 .register(LspTool::new())
201 .expect("Failed to register lsp");
202
203 registry
205 .register(WebSearchTool::new())
206 .expect("Failed to register web_search");
207 registry
208 .register(WebFetchTool::new())
209 .expect("Failed to register web_fetch");
210 }
211
    /// Runs the full agent loop for one user message.
    ///
    /// Appends the system prompt (on the first call only) and the user
    /// message to `_messages`, then repeatedly calls the LLM: streamed text is
    /// forwarded as [`AgentEvent::ContentChunk`]s, requested tools are
    /// executed and their results appended to the history, and the loop
    /// continues until the model stops requesting tools or `max_iterations`
    /// is reached (`max_iterations == 0` means unlimited). If the iteration
    /// cap is hit, one final tool-free call asks the model for a summary.
    ///
    /// Returns the final assistant text (also appended to `_messages`).
    ///
    /// # Errors
    /// Returns `CliError::ConfigError` on provider errors or when the
    /// operation is cancelled via the cancellation token.
    #[instrument(skip(self, _messages))]
    pub async fn process_message(
        &mut self,
        user_input: &str,
        _messages: &mut Vec<Message>,
    ) -> Result<String, CliError> {
        // Seed a fresh conversation with the system prompt.
        if _messages.is_empty() {
            let system_message = Message {
                role: Role::System,
                content: Some(SYSTEM_PROMPT.to_string()),
                tool_calls: None,
                tool_call_id: None,
            };
            _messages.push(system_message);
        }

        let user_message = Message {
            role: Role::User,
            content: Some(user_input.to_string()),
            tool_calls: None,
            tool_call_id: None,
        };
        _messages.push(user_message);

        let tool_definitions = self.get_tool_definitions();

        let mut full_response = String::new();
        let mut tool_calls: Vec<LlmToolCall> = Vec::new();
        // Per-provider iteration cap; 0 is treated as "no limit" below.
        let max_iterations = self
            .config
            .providers
            .get(&self.config.provider)
            .map(|p| p.max_iterations)
            .unwrap_or(100);
        let mut iteration = 0;

        while max_iterations == 0 || iteration < max_iterations {
            iteration += 1;
            debug!("Agent loop iteration {}", iteration);

            debug!("Sending Thinking event with operation_id={}", self.operation_id);
            self.send_event(AgentEvent::Thinking {
                operation_id: self.operation_id,
            });

            // Wall-clock timing for the tracking DB entry written on Done.
            let request_start = std::time::Instant::now();

            let mut stream = self
                .llm_client
                .send(_messages.clone(), tool_definitions.clone())
                .await
                .map_err(|e| CliError::ConfigError(e.to_string()))?;

            tool_calls.clear();
            let mut current_content = String::new();
            // Tool calls keyed by call id; later deltas with the same id
            // overwrite earlier ones.
            let mut accumulated_calls: std::collections::HashMap<
                String,
                (String, serde_json::Value),
            > = std::collections::HashMap::new();

            // Drain the response stream, checking for cancellation between
            // chunks.
            loop {
                if let Some(ref token) = self.cancellation_token {
                    if token.is_cancelled() {
                        debug!("Operation cancelled by user (pre-stream check)");
                        self.send_event(AgentEvent::Cancelled {
                            operation_id: self.operation_id,
                        });
                        return Err(CliError::ConfigError(
                            "Operation cancelled by user".to_string(),
                        ));
                    }
                }

                // Race the next chunk against cancellation (when armed).
                let chunk_result = if let Some(ref token) = self.cancellation_token {
                    tokio::select! {
                        chunk = stream.next() => chunk,
                        _ = token.cancelled() => {
                            debug!("Operation cancelled via token while waiting for stream");
                            self.send_event(AgentEvent::Cancelled {
                                operation_id: self.operation_id,
                            });
                            return Err(CliError::ConfigError("Operation cancelled by user".to_string()));
                        }
                    }
                } else {
                    stream.next().await
                };

                // Stream exhausted without an explicit Done chunk.
                let Some(chunk_result) = chunk_result else {
                    break;
                };

                match chunk_result {
                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
                        current_content.push_str(&text);
                        debug!(
                            "ContentDelta: {} chars (total: {})",
                            text.len(),
                            current_content.len()
                        );
                        self.send_event(AgentEvent::ContentChunk {
                            operation_id: self.operation_id,
                            chunk: text,
                        });
                    }
                    // Reasoning/thinking tokens are intentionally discarded.
                    Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
                    }
                    Ok(ProviderResponseChunk::ToolCallDelta {
                        id,
                        name,
                        arguments,
                    }) => {
                        debug!(
                            "ToolCallDelta: id={}, name={}, args_len={}",
                            id,
                            name,
                            arguments.to_string().len()
                        );
                        // NOTE(review): this replaces, not merges, per id —
                        // assumes each delta carries complete arguments;
                        // confirm against the provider implementations.
                        accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
                    }
                    Ok(ProviderResponseChunk::Done(usage)) => {
                        // Record usage/cost; tracking failures are ignored.
                        let duration_ms = request_start.elapsed().as_millis() as u64;
                        let cost =
                            calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
                        let _ = self.tracking_db.track_request(
                            self.model(),
                            usage.input_tokens,
                            usage.output_tokens,
                            cost,
                            duration_ms,
                        );
                        self.send_event(AgentEvent::TokenUsage {
                            operation_id: self.operation_id,
                            input_tokens: usage.input_tokens,
                            output_tokens: usage.output_tokens,
                        });
                        break;
                    }
                    Err(e) => {
                        let error_msg = format!("LLM error: {}", e);
                        self.send_event(AgentEvent::Error {
                            operation_id: self.operation_id,
                            message: error_msg.clone(),
                        });
                        return Err(CliError::ConfigError(error_msg));
                    }
                }
            }

            // Materialize accumulated deltas into concrete tool calls.
            // HashMap iteration order is arbitrary, so execution order of
            // multiple calls from one response is unspecified.
            tool_calls = accumulated_calls
                .into_iter()
                .map(|(id, (name, args))| LlmToolCall {
                    id,
                    tool_type: "function".to_string(),
                    function: limit_llm::types::FunctionCall {
                        name,
                        arguments: args.to_string(),
                    },
                })
                .collect();

            // NOTE(review): only the latest iteration's text survives; text
            // streamed in earlier tool-call iterations is overwritten here —
            // confirm that is intended.
            full_response = current_content.clone();

            debug!(
                "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
                iteration,
                current_content.len(),
                tool_calls.len(),
                full_response.len()
            );

            if tool_calls.is_empty() {
                debug!("No tool calls, breaking loop after iteration {}", iteration);
                break;
            }

            debug!(
                "Tool calls found (count={}), continuing to iteration {}",
                tool_calls.len(),
                iteration + 1
            );

            // NOTE(review): the assistant turn is recorded with content: None,
            // so any text streamed alongside the tool calls is dropped from
            // the history — verify providers accept/expect this.
            let assistant_message = Message {
                role: Role::Assistant,
                content: None,
                tool_calls: Some(tool_calls.clone()),
                tool_call_id: None,
            };
            _messages.push(assistant_message);

            // Unparseable argument strings degrade to JSON null
            // (Value::default) rather than failing the call.
            let executor_calls: Vec<ToolCall> = tool_calls
                .iter()
                .map(|tc| {
                    let args: serde_json::Value =
                        serde_json::from_str(&tc.function.arguments).unwrap_or_default();
                    ToolCall::new(&tc.id, &tc.function.name, args)
                })
                .collect();

            // Announce every tool before any of them runs.
            for tc in &tool_calls {
                let args: serde_json::Value =
                    serde_json::from_str(&tc.function.arguments).unwrap_or_default();
                self.send_event(AgentEvent::ToolStart {
                    operation_id: self.operation_id,
                    name: tc.function.name.clone(),
                    args,
                });
            }
            let results = self.executor.execute_tools(executor_calls).await;

            // Append one Role::Tool message per result, matched back to its
            // originating call by call_id. Results with an unknown call_id
            // are silently skipped.
            for result in results {
                let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
                if let Some(tool_call) = tool_call {
                    let output_json = match &result.output {
                        Ok(value) => {
                            serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
                        }
                        Err(e) => json!({ "error": e.to_string() }).to_string(),
                    };

                    self.send_event(AgentEvent::ToolComplete {
                        operation_id: self.operation_id,
                        name: tool_call.function.name.clone(),
                        result: output_json.clone(),
                    });

                    let tool_result_message = Message {
                        role: Role::Tool,
                        content: Some(output_json),
                        tool_calls: None,
                        tool_call_id: Some(result.call_id),
                    };
                    _messages.push(tool_result_message);
                }
            }
        }

        // Iteration cap hit: make one tool-free call asking the model to
        // summarize progress so the user gets a text answer.
        if max_iterations > 0 && iteration >= max_iterations && !_messages.is_empty() {
            debug!("Making final LLM call after hitting max iterations (forcing text response)");

            let constraint_message = Message {
                role: Role::User,
                content: Some(
                    "We've reached the iteration limit. Please provide a summary of:\n\
                     1. What you've completed so far\n\
                     2. What remains to be done\n\
                     3. Recommended next steps for the user to continue"
                        .to_string(),
                ),
                tool_calls: None,
                tool_call_id: None,
            };
            _messages.push(constraint_message);

            // No tools offered, so the model can only reply with text.
            let no_tools: Vec<LlmTool> = vec![];
            let mut stream = self
                .llm_client
                .send(_messages.clone(), no_tools)
                .await
                .map_err(|e| CliError::ConfigError(e.to_string()))?;

            full_response.clear();
            loop {
                if let Some(ref token) = self.cancellation_token {
                    if token.is_cancelled() {
                        debug!("Operation cancelled by user in final loop (pre-stream check)");
                        self.send_event(AgentEvent::Cancelled {
                            operation_id: self.operation_id,
                        });
                        return Err(CliError::ConfigError(
                            "Operation cancelled by user".to_string(),
                        ));
                    }
                }

                let chunk_result = if let Some(ref token) = self.cancellation_token {
                    tokio::select! {
                        chunk = stream.next() => chunk,
                        _ = token.cancelled() => {
                            debug!("Operation cancelled via token while waiting for stream");
                            self.send_event(AgentEvent::Cancelled {
                                operation_id: self.operation_id,
                            });
                            return Err(CliError::ConfigError("Operation cancelled by user".to_string()));
                        }
                    }
                } else {
                    stream.next().await
                };

                let Some(chunk_result) = chunk_result else {
                    break;
                };

                match chunk_result {
                    Ok(ProviderResponseChunk::ContentDelta(text)) => {
                        full_response.push_str(&text);
                        self.send_event(AgentEvent::ContentChunk {
                            operation_id: self.operation_id,
                            chunk: text,
                        });
                    }
                    // Usage from this forced final call is ignored: it is
                    // neither tracked nor reported as a TokenUsage event.
                    Ok(ProviderResponseChunk::Done(_)) => {
                        break;
                    }
                    // Best-effort summary: errors end the stream quietly.
                    Err(e) => {
                        debug!("Error in final LLM call: {}", e);
                        break;
                    }
                    _ => {}
                }
            }
        }

        // Persist the final text into the history: fill the last assistant
        // message if it has no content, otherwise append a new one.
        if !full_response.is_empty() {
            let last_assistant_idx = _messages.iter().rposition(|m| m.role == Role::Assistant);

            if let Some(idx) = last_assistant_idx {
                let last_assistant = &mut _messages[idx];

                if last_assistant.content.is_none()
                    || last_assistant
                        .content
                        .as_ref()
                        .map(|c| c.is_empty())
                        .unwrap_or(true)
                {
                    last_assistant.content = Some(full_response.clone());
                    debug!("Updated last assistant message with final response content");
                } else {
                    debug!("Last assistant already has content, adding new message");
                    let final_assistant_message = Message {
                        role: Role::Assistant,
                        content: Some(full_response.clone()),
                        tool_calls: None,
                        tool_call_id: None,
                    };
                    _messages.push(final_assistant_message);
                }
            } else {
                debug!("No assistant message found, adding new message");
                let final_assistant_message = Message {
                    role: Role::Assistant,
                    content: Some(full_response.clone()),
                    tool_calls: None,
                    tool_call_id: None,
                };
                _messages.push(final_assistant_message);
            }
        }

        self.send_event(AgentEvent::Done {
            operation_id: self.operation_id,
        });
        Ok(full_response)
    }
626
627 pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
629 self.tool_names
630 .iter()
631 .map(|name| {
632 let (description, parameters) = Self::get_tool_schema(name);
633 LlmTool {
634 tool_type: "function".to_string(),
635 function: limit_llm::types::ToolFunction {
636 name: name.to_string(),
637 description,
638 parameters,
639 },
640 }
641 })
642 .collect()
643 }
644
    /// Returns the `(description, parameters)` pair advertised to the LLM for
    /// the tool `name`, where `parameters` is a JSON-Schema object describing
    /// the tool's arguments. Unknown names fall through to a generic
    /// description with an empty schema.
    fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
        match name {
            // --- File tools ---
            "file_read" => (
                "Read the contents of a file".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to read"
                        }
                    },
                    "required": ["path"]
                }),
            ),
            "file_write" => (
                "Write content to a file, creating parent directories if needed".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to write"
                        },
                        "content": {
                            "type": "string",
                            "description": "Content to write to the file"
                        }
                    },
                    "required": ["path", "content"]
                }),
            ),
            "file_edit" => (
                "Replace text in a file with new text".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file to edit"
                        },
                        "old_text": {
                            "type": "string",
                            "description": "Text to find and replace"
                        },
                        "new_text": {
                            "type": "string",
                            "description": "New text to replace with"
                        }
                    },
                    "required": ["path", "old_text", "new_text"]
                }),
            ),
            // --- Shell ---
            "bash" => (
                "Execute a bash command in a shell".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "Bash command to execute"
                        },
                        "workdir": {
                            "type": "string",
                            "description": "Working directory (default: current directory)"
                        },
                        "timeout": {
                            "type": "integer",
                            "description": "Timeout in seconds (default: 60)"
                        }
                    },
                    "required": ["command"]
                }),
            ),
            // --- Git tools ---
            "git_status" => (
                "Get git repository status".to_string(),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
            "git_diff" => (
                "Get git diff".to_string(),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
            "git_log" => (
                "Get git commit log".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "count": {
                            "type": "integer",
                            "description": "Number of commits to show (default: 10)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_add" => (
                "Add files to git staging area".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "files": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "List of file paths to add"
                        }
                    },
                    "required": ["files"]
                }),
            ),
            "git_commit" => (
                "Create a git commit".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Commit message"
                        }
                    },
                    "required": ["message"]
                }),
            ),
            "git_push" => (
                "Push commits to remote repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "remote": {
                            "type": "string",
                            "description": "Remote name (default: origin)"
                        },
                        "branch": {
                            "type": "string",
                            "description": "Branch name (default: current branch)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_pull" => (
                "Pull changes from remote repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "remote": {
                            "type": "string",
                            "description": "Remote name (default: origin)"
                        },
                        "branch": {
                            "type": "string",
                            "description": "Branch name (default: current branch)"
                        }
                    },
                    "required": []
                }),
            ),
            "git_clone" => (
                "Clone a git repository".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "url": {
                            "type": "string",
                            "description": "Repository URL to clone"
                        },
                        "directory": {
                            "type": "string",
                            "description": "Directory to clone into (optional)"
                        }
                    },
                    "required": ["url"]
                }),
            ),
            // --- Search tools ---
            "grep" => (
                "Search for text patterns in files using regex".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "pattern": {
                            "type": "string",
                            "description": "Regex pattern to search for"
                        },
                        "path": {
                            "type": "string",
                            "description": "Path to search in (default: current directory)"
                        }
                    },
                    "required": ["pattern"]
                }),
            ),
            "ast_grep" => (
                "Search code using AST patterns (structural code matching)".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "pattern": {
                            "type": "string",
                            "description": "AST pattern to match"
                        },
                        "language": {
                            "type": "string",
                            "description": "Programming language (rust, typescript, python)"
                        },
                        "path": {
                            "type": "string",
                            "description": "Path to search in (default: current directory)"
                        }
                    },
                    "required": ["pattern", "language"]
                }),
            ),
            "lsp" => (
                "Perform Language Server Protocol operations (goto_definition, find_references)"
                    .to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "LSP command: goto_definition or find_references"
                        },
                        "file_path": {
                            "type": "string",
                            "description": "Path to the file"
                        },
                        "position": {
                            "type": "object",
                            "description": "Position in the file (line, character)",
                            "properties": {
                                "line": {"type": "integer"},
                                "character": {"type": "integer"}
                            },
                            "required": ["line", "character"]
                        }
                    },
                    "required": ["command", "file_path", "position"]
                }),
            ),
            // --- Web tools (descriptions embed the current year so the
            // model searches for up-to-date material) ---
            "web_search" => (
                format!("Search the web using Exa AI. Returns results with titles, URLs, and content snippets. Use for current information beyond knowledge cutoff. The current year is {} - use this year when searching for recent information.", chrono::Local::now().year()),
                json!({
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": format!("Search query. Be specific for better results (e.g., 'Rust async tutorial {}' rather than 'Rust')", chrono::Local::now().year())
                        },
                        "numResults": {
                            "type": "integer",
                            "description": "Number of results to return (default: 8, max: 20)",
                            "default": 8
                        }
                    },
                    "required": ["query"]
                }),
            ),
            "web_fetch" => (
                "Fetch content from a URL. Converts HTML to markdown format by default. Use when user provides a URL or after web_search to read full content of a specific result.".to_string(),
                json!({
                    "type": "object",
                    "properties": {
                        "url": {
                            "type": "string",
                            "description": "URL to fetch (must start with http:// or https://)"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["markdown", "text", "html"],
                            "default": "markdown",
                            "description": "Output format (default: markdown)"
                        }
                    },
                    "required": ["url"]
                }),
            ),
            // Fallback for names with no dedicated schema.
            _ => (
                format!("Tool: {}", name),
                json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            ),
        }
    }
939
940 fn send_event(&self, event: AgentEvent) {
942 if let Some(ref tx) = self.event_tx {
943 let _ = tx.send(event);
944 }
945 }
946
947 #[allow(dead_code)]
949 pub fn is_ready(&self) -> bool {
950 self.config
951 .providers
952 .get(&self.config.provider)
953 .map(|p| p.api_key_or_env(&self.config.provider).is_some())
954 .unwrap_or(false)
955 }
956
957 pub fn model(&self) -> &str {
959 self.config
960 .providers
961 .get(&self.config.provider)
962 .map(|p| p.model.as_str())
963 .unwrap_or("")
964 }
965
966 pub fn max_tokens(&self) -> u32 {
968 self.config
969 .providers
970 .get(&self.config.provider)
971 .map(|p| p.max_tokens)
972 .unwrap_or(4096)
973 }
974
975 pub fn timeout(&self) -> u64 {
977 self.config
978 .providers
979 .get(&self.config.provider)
980 .map(|p| p.timeout)
981 .unwrap_or(60)
982 }
983}
/// Estimates the USD cost of a request from its token counts.
///
/// Prices are per million tokens, `(input, output)`; models not in the table
/// are priced at zero.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    let (input_price, output_price): (f64, f64) = match model {
        "claude-3-5-sonnet-20241022" | "claude-3-5-sonnet" => (3.0, 15.0),
        "gpt-4" => (30.0, 60.0),
        "gpt-4-turbo" | "gpt-4-turbo-preview" => (10.0, 30.0),
        _ => (0.0, 0.0),
    };
    let per_million = |tokens: u64, price: f64| tokens as f64 * price / 1_000_000.0;
    per_million(input_tokens, input_price) + per_million(output_tokens, output_price)
}
999
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Builds the single-provider ("anthropic") config that every test here
    /// previously duplicated inline, parameterized only by the API key.
    fn test_config(api_key: Option<String>) -> LlmConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "anthropic".to_string(),
            ProviderConfig {
                api_key,
                model: "claude-3-5-sonnet-20241022".to_string(),
                base_url: None,
                max_tokens: 4096,
                timeout: 60,
                max_iterations: 100,
                thinking_enabled: false,
                clear_thinking: true,
            },
        );
        LlmConfig {
            provider: "anthropic".to_string(),
            providers,
        }
    }

    #[tokio::test]
    async fn test_agent_bridge_new() {
        // A configured API key makes the bridge ready for use.
        let bridge = AgentBridge::new(test_config(Some("test-key".to_string()))).unwrap();
        assert!(bridge.is_ready());
    }

    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        // Without an API key, provider creation is expected to fail.
        let result = AgentBridge::new(test_config(None));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let bridge = AgentBridge::new(test_config(Some("test-key".to_string()))).unwrap();
        let definitions = bridge.get_tool_definitions();

        // One definition per registered tool name.
        assert_eq!(definitions.len(), 17);

        let file_read = definitions
            .iter()
            .find(|d| d.function.name == "file_read")
            .unwrap();
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        let bash = definitions
            .iter()
            .find(|d| d.function.name == "bash")
            .unwrap();
        assert_eq!(bash.function.name, "bash");
        assert!(bash.function.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&"command".into()));
    }

    #[test]
    fn test_get_tool_schema() {
        let (desc, params) = AgentBridge::get_tool_schema("file_read");
        assert!(desc.contains("Read"));
        assert_eq!(params["properties"]["path"]["type"], "string");
        assert!(params["required"]
            .as_array()
            .unwrap()
            .contains(&"path".into()));

        let (desc, params) = AgentBridge::get_tool_schema("bash");
        assert!(desc.contains("bash"));
        assert_eq!(params["properties"]["command"]["type"], "string");

        // Unknown names fall back to a generic description.
        let (desc, _params) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(desc.contains("unknown_tool"));
    }

    #[test]
    fn test_is_ready() {
        let bridge = AgentBridge::new(test_config(Some("test-key".to_string()))).unwrap();
        assert!(bridge.is_ready());
    }
}