1use crate::error::CliError;
2use crate::system_prompt::SYSTEM_PROMPT;
3use crate::tools::tldr_tool_definition;
4use crate::tools::{
5 AstGrepTool, BashTool, BrowserTool, FileEditTool, FileReadTool, FileWriteTool, GitAddTool,
6 GitCloneTool, GitCommitTool, GitDiffTool, GitLogTool, GitPullTool, GitPushTool, GitStatusTool,
7 GrepTool, LspTool, TldrTool, WebFetchTool, WebSearchTool,
8};
9use chrono::Datelike;
10use futures::StreamExt;
11use limit_agent::executor::{ToolCall, ToolExecutor};
12use limit_agent::registry::ToolRegistry;
13use limit_llm::providers::LlmProvider;
14use limit_llm::types::{Message, Role, Tool as LlmTool, ToolCall as LlmToolCall};
15use limit_llm::ProviderFactory;
16use limit_llm::ProviderResponseChunk;
17use limit_llm::TrackingDb;
18use serde_json::json;
19use tokio::sync::mpsc;
20use tokio_util::sync::CancellationToken;
21use tracing::{debug, instrument, trace};
22
23#[derive(Debug, Clone)]
25#[allow(dead_code)]
26pub enum AgentEvent {
27 Thinking {
28 operation_id: u64,
29 },
30 ToolStart {
31 operation_id: u64,
32 name: String,
33 args: serde_json::Value,
34 },
35 ToolComplete {
36 operation_id: u64,
37 name: String,
38 result: String,
39 },
40 ContentChunk {
41 operation_id: u64,
42 chunk: String,
43 },
44 Done {
45 operation_id: u64,
46 },
47 Cancelled {
48 operation_id: u64,
49 },
50 Error {
51 operation_id: u64,
52 message: String,
53 },
54 TokenUsage {
55 operation_id: u64,
56 input_tokens: u64,
57 output_tokens: u64,
58 },
59}
60
61pub struct AgentBridge {
63 llm_client: Box<dyn LlmProvider>,
65 executor: ToolExecutor,
67 tool_names: Vec<&'static str>,
69 config: limit_llm::Config,
71 event_tx: Option<mpsc::UnboundedSender<AgentEvent>>,
73 tracking_db: TrackingDb,
75 cancellation_token: Option<CancellationToken>,
77 operation_id: u64,
79}
80
81impl AgentBridge {
82 pub fn new(config: limit_llm::Config) -> Result<Self, CliError> {
90 let tracking_db = TrackingDb::new().map_err(|e| CliError::ConfigError(e.to_string()))?;
91 Self::with_tracking_db(config, tracking_db)
92 }
93
/// Test-only constructor that uses an in-memory [`TrackingDb`] instead of the
/// default one, so tests do not touch persistent tracking data.
///
/// # Errors
/// Returns `CliError::ConfigError` when the in-memory database cannot be created or
/// when [`AgentBridge::with_tracking_db`] fails.
#[cfg(test)]
pub fn new_for_test(config: limit_llm::Config) -> Result<Self, CliError> {
    match TrackingDb::new_in_memory() {
        Ok(tracking_db) => Self::with_tracking_db(config, tracking_db),
        Err(e) => Err(CliError::ConfigError(e.to_string())),
    }
}
101
102 pub fn with_tracking_db(
104 config: limit_llm::Config,
105 tracking_db: TrackingDb,
106 ) -> Result<Self, CliError> {
107 let llm_client = ProviderFactory::create_provider(&config)
108 .map_err(|e| CliError::ConfigError(e.to_string()))?;
109
110 let mut tool_registry = ToolRegistry::new();
111 Self::register_tools(&mut tool_registry, &config);
112
113 let executor = ToolExecutor::new(tool_registry);
115
116 let tool_names = vec![
118 "file_read",
119 "file_write",
120 "file_edit",
121 "bash",
122 "git_status",
123 "git_diff",
124 "git_log",
125 "git_add",
126 "git_commit",
127 "git_push",
128 "git_pull",
129 "git_clone",
130 "grep",
131 "ast_grep",
132 "lsp",
133 "web_search",
134 "web_fetch",
135 "browser",
136 "tldr_analyze",
137 ];
138
139 Ok(Self {
140 llm_client,
141 executor,
142 tool_names,
143 config,
144 event_tx: None,
145 tracking_db,
146 cancellation_token: None,
147 operation_id: 0,
148 })
149 }
150
151 pub fn set_event_tx(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
153 self.event_tx = Some(tx);
154 }
155
156 pub fn set_cancellation_token(&mut self, token: CancellationToken, operation_id: u64) {
158 debug!("set_cancellation_token: operation_id={}", operation_id);
159 self.cancellation_token = Some(token);
160 self.operation_id = operation_id;
161 }
162
163 pub fn clear_cancellation_token(&mut self) {
165 self.cancellation_token = None;
166 }
167
168 fn register_tools(registry: &mut ToolRegistry, config: &limit_llm::Config) {
170 registry
172 .register(FileReadTool::new())
173 .expect("Failed to register file_read");
174 registry
175 .register(FileWriteTool::new())
176 .expect("Failed to register file_write");
177 registry
178 .register(FileEditTool::new())
179 .expect("Failed to register file_edit");
180
181 registry
183 .register(BashTool::new())
184 .expect("Failed to register bash");
185
186 registry
188 .register(GitStatusTool::new())
189 .expect("Failed to register git_status");
190 registry
191 .register(GitDiffTool::new())
192 .expect("Failed to register git_diff");
193 registry
194 .register(GitLogTool::new())
195 .expect("Failed to register git_log");
196 registry
197 .register(GitAddTool::new())
198 .expect("Failed to register git_add");
199 registry
200 .register(GitCommitTool::new())
201 .expect("Failed to register git_commit");
202 registry
203 .register(GitPushTool::new())
204 .expect("Failed to register git_push");
205 registry
206 .register(GitPullTool::new())
207 .expect("Failed to register git_pull");
208 registry
209 .register(GitCloneTool::new())
210 .expect("Failed to register git_clone");
211
212 registry
214 .register(GrepTool::new())
215 .expect("Failed to register grep");
216 registry
217 .register(AstGrepTool::new())
218 .expect("Failed to register ast_grep");
219 registry
220 .register(LspTool::new())
221 .expect("Failed to register lsp");
222
223 registry
225 .register(WebSearchTool::new())
226 .expect("Failed to register web_search");
227 registry
228 .register(WebFetchTool::new())
229 .expect("Failed to register web_fetch");
230
231 let browser_config = crate::tools::browser::BrowserConfig::from(&config.browser);
233 registry
234 .register(BrowserTool::with_config(browser_config))
235 .expect("Failed to register browser");
236
237 registry
239 .register(TldrTool::new())
240 .expect("Failed to register tldr_analyze");
241 }
242
243 #[instrument(skip(self, _messages, user_input))]
252 pub async fn process_message(
253 &mut self,
254 user_input: &str,
255 _messages: &mut Vec<Message>,
256 ) -> Result<String, CliError> {
257 if _messages.is_empty() {
260 let system_message = Message {
261 role: Role::System,
262 content: Some(SYSTEM_PROMPT.to_string()),
263 tool_calls: None,
264 tool_call_id: None,
265 };
266 _messages.push(system_message);
267 }
268
269 let user_message = Message {
271 role: Role::User,
272 content: Some(user_input.to_string()),
273 tool_calls: None,
274 tool_call_id: None,
275 };
276 _messages.push(user_message);
277
278 let tool_definitions = self.get_tool_definitions();
280
281 let mut full_response = String::new();
283 let mut tool_calls: Vec<LlmToolCall> = Vec::new();
284 let max_iterations = self
285 .config
286 .providers
287 .get(&self.config.provider)
288 .map(|p| p.max_iterations)
289 .unwrap_or(100); let mut iteration = 0;
291
292 while max_iterations == 0 || iteration < max_iterations {
293 iteration += 1;
294 debug!("Agent loop iteration {}", iteration);
295
296 debug!(
298 "Sending Thinking event with operation_id={}",
299 self.operation_id
300 );
301 self.send_event(AgentEvent::Thinking {
302 operation_id: self.operation_id,
303 });
304
305 let request_start = std::time::Instant::now();
307
308 let mut stream = self
310 .llm_client
311 .send(_messages.clone(), tool_definitions.clone())
312 .await
313 .map_err(|e| CliError::ConfigError(e.to_string()))?;
314
315 tool_calls.clear();
316 let mut current_content = String::new();
317 let mut accumulated_calls: std::collections::HashMap<
319 String,
320 (String, serde_json::Value),
321 > = std::collections::HashMap::new();
322
323 loop {
325 if let Some(ref token) = self.cancellation_token {
327 if token.is_cancelled() {
328 debug!("Operation cancelled by user (pre-stream check)");
329 self.send_event(AgentEvent::Cancelled {
330 operation_id: self.operation_id,
331 });
332 return Err(CliError::ConfigError(
333 "Operation cancelled by user".to_string(),
334 ));
335 }
336 }
337
338 let chunk_result = if let Some(ref token) = self.cancellation_token {
341 tokio::select! {
342 chunk = stream.next() => chunk,
343 _ = token.cancelled() => {
344 debug!("Operation cancelled via token while waiting for stream");
345 self.send_event(AgentEvent::Cancelled {
346 operation_id: self.operation_id,
347 });
348 return Err(CliError::ConfigError("Operation cancelled by user".to_string()));
349 }
350 }
351 } else {
352 stream.next().await
353 };
354
355 let Some(chunk_result) = chunk_result else {
356 break;
358 };
359
360 match chunk_result {
361 Ok(ProviderResponseChunk::ContentDelta(text)) => {
362 current_content.push_str(&text);
363 trace!(
364 "ContentDelta: {} chars (total: {})",
365 text.len(),
366 current_content.len()
367 );
368 self.send_event(AgentEvent::ContentChunk {
369 operation_id: self.operation_id,
370 chunk: text,
371 });
372 }
373 Ok(ProviderResponseChunk::ReasoningDelta(_)) => {
374 }
376 Ok(ProviderResponseChunk::ToolCallDelta {
377 id,
378 name,
379 arguments,
380 }) => {
381 trace!(
382 "ToolCallDelta: id={}, name={}, args_len={}",
383 id,
384 name,
385 arguments.to_string().len()
386 );
387 accumulated_calls.insert(id.clone(), (name.clone(), arguments.clone()));
389 }
390 Ok(ProviderResponseChunk::Done(usage)) => {
391 let duration_ms = request_start.elapsed().as_millis() as u64;
393 let cost =
394 calculate_cost(self.model(), usage.input_tokens, usage.output_tokens);
395 let _ = self.tracking_db.track_request(
396 self.model(),
397 usage.input_tokens,
398 usage.output_tokens,
399 cost,
400 duration_ms,
401 );
402 self.send_event(AgentEvent::TokenUsage {
404 operation_id: self.operation_id,
405 input_tokens: usage.input_tokens,
406 output_tokens: usage.output_tokens,
407 });
408 break;
409 }
410 Err(e) => {
411 let error_msg = format!("LLM error: {}", e);
412 self.send_event(AgentEvent::Error {
413 operation_id: self.operation_id,
414 message: error_msg.clone(),
415 });
416 return Err(CliError::ConfigError(error_msg));
417 }
418 }
419 }
420
421 tool_calls = accumulated_calls
423 .into_iter()
424 .map(|(id, (name, args))| LlmToolCall {
425 id,
426 tool_type: "function".to_string(),
427 function: limit_llm::types::FunctionCall {
428 name,
429 arguments: args.to_string(),
430 },
431 })
432 .collect();
433
434 full_response = current_content.clone();
439
440 trace!(
441 "After iter {}: content.len()={}, tool_calls={}, response.len()={}",
442 iteration,
443 current_content.len(),
444 tool_calls.len(),
445 full_response.len()
446 );
447
448 if tool_calls.is_empty() {
450 debug!("No tool calls, breaking loop after iteration {}", iteration);
451 break;
452 }
453
454 trace!(
455 "Tool calls found (count={}), continuing to iteration {}",
456 tool_calls.len(),
457 iteration + 1
458 );
459
460 let assistant_message = Message {
463 role: Role::Assistant,
464 content: None, tool_calls: Some(tool_calls.clone()),
466 tool_call_id: None,
467 };
468 _messages.push(assistant_message);
469
470 let executor_calls: Vec<ToolCall> = tool_calls
472 .iter()
473 .map(|tc| {
474 let args: serde_json::Value =
475 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
476 ToolCall::new(&tc.id, &tc.function.name, args)
477 })
478 .collect();
479
480 for tc in &tool_calls {
482 let args: serde_json::Value =
483 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
484 self.send_event(AgentEvent::ToolStart {
485 operation_id: self.operation_id,
486 name: tc.function.name.clone(),
487 args,
488 });
489 }
490 let results = self.executor.execute_tools(executor_calls).await;
492
493 for result in results {
495 let tool_call = tool_calls.iter().find(|tc| tc.id == result.call_id);
496 if let Some(tool_call) = tool_call {
497 let output_json = match &result.output {
498 Ok(value) => {
499 serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string())
500 }
501 Err(e) => json!({ "error": e.to_string() }).to_string(),
502 };
503
504 self.send_event(AgentEvent::ToolComplete {
505 operation_id: self.operation_id,
506 name: tool_call.function.name.clone(),
507 result: output_json.clone(),
508 });
509
510 let tool_result_message = Message {
512 role: Role::Tool,
513 content: Some(output_json),
514 tool_calls: None,
515 tool_call_id: Some(result.call_id),
516 };
517 _messages.push(tool_result_message);
518 }
519 }
520 }
521
522 if max_iterations > 0 && iteration >= max_iterations && !_messages.is_empty() {
525 debug!("Making final LLM call after hitting max iterations (forcing text response)");
526
527 let constraint_message = Message {
529 role: Role::User,
530 content: Some(
531 "We've reached the iteration limit. Please provide a summary of:\n\
532 1. What you've completed so far\n\
533 2. What remains to be done\n\
534 3. Recommended next steps for the user to continue"
535 .to_string(),
536 ),
537 tool_calls: None,
538 tool_call_id: None,
539 };
540 _messages.push(constraint_message);
541
542 let no_tools: Vec<LlmTool> = vec![];
544 let mut stream = self
545 .llm_client
546 .send(_messages.clone(), no_tools)
547 .await
548 .map_err(|e| CliError::ConfigError(e.to_string()))?;
549
550 full_response.clear();
552 loop {
553 if let Some(ref token) = self.cancellation_token {
555 if token.is_cancelled() {
556 debug!("Operation cancelled by user in final loop (pre-stream check)");
557 self.send_event(AgentEvent::Cancelled {
558 operation_id: self.operation_id,
559 });
560 return Err(CliError::ConfigError(
561 "Operation cancelled by user".to_string(),
562 ));
563 }
564 }
565
566 let chunk_result = if let Some(ref token) = self.cancellation_token {
569 tokio::select! {
570 chunk = stream.next() => chunk,
571 _ = token.cancelled() => {
572 debug!("Operation cancelled via token while waiting for stream");
573 self.send_event(AgentEvent::Cancelled {
574 operation_id: self.operation_id,
575 });
576 return Err(CliError::ConfigError("Operation cancelled by user".to_string()));
577 }
578 }
579 } else {
580 stream.next().await
581 };
582
583 let Some(chunk_result) = chunk_result else {
584 break;
586 };
587
588 match chunk_result {
589 Ok(ProviderResponseChunk::ContentDelta(text)) => {
590 full_response.push_str(&text);
591 self.send_event(AgentEvent::ContentChunk {
592 operation_id: self.operation_id,
593 chunk: text,
594 });
595 }
596 Ok(ProviderResponseChunk::Done(_)) => {
597 break;
598 }
599 Err(e) => {
600 debug!("Error in final LLM call: {}", e);
601 break;
602 }
603 _ => {}
604 }
605 }
606 }
607
608 if !full_response.is_empty() {
612 let last_assistant_idx = _messages.iter().rposition(|m| m.role == Role::Assistant);
616
617 if let Some(idx) = last_assistant_idx {
618 let last_assistant = &mut _messages[idx];
619
620 if last_assistant.content.is_none()
622 || last_assistant
623 .content
624 .as_ref()
625 .map(|c| c.is_empty())
626 .unwrap_or(true)
627 {
628 last_assistant.content = Some(full_response.clone());
629 debug!("Updated last assistant message with final response content");
630 } else {
631 debug!("Last assistant already has content, adding new message");
634 let final_assistant_message = Message {
635 role: Role::Assistant,
636 content: Some(full_response.clone()),
637 tool_calls: None,
638 tool_call_id: None,
639 };
640 _messages.push(final_assistant_message);
641 }
642 } else {
643 debug!("No assistant message found, adding new message");
645 let final_assistant_message = Message {
646 role: Role::Assistant,
647 content: Some(full_response.clone()),
648 tool_calls: None,
649 tool_call_id: None,
650 };
651 _messages.push(final_assistant_message);
652 }
653 }
654
655 self.send_event(AgentEvent::Done {
656 operation_id: self.operation_id,
657 });
658 Ok(full_response)
659 }
660
661 pub fn get_tool_definitions(&self) -> Vec<LlmTool> {
663 self.tool_names
664 .iter()
665 .map(|name| {
666 let (description, parameters) = Self::get_tool_schema(name);
667 LlmTool {
668 tool_type: "function".to_string(),
669 function: limit_llm::types::ToolFunction {
670 name: name.to_string(),
671 description,
672 parameters,
673 },
674 }
675 })
676 .collect()
677 }
678
679 fn get_tool_schema(name: &str) -> (String, serde_json::Value) {
681 match name {
682 "file_read" => (
683 "Read the contents of a file".to_string(),
684 json!({
685 "type": "object",
686 "properties": {
687 "path": {
688 "type": "string",
689 "description": "Path to the file to read"
690 }
691 },
692 "required": ["path"]
693 }),
694 ),
695 "file_write" => (
696 "Write content to a file, creating parent directories if needed".to_string(),
697 json!({
698 "type": "object",
699 "properties": {
700 "path": {
701 "type": "string",
702 "description": "Path to the file to write"
703 },
704 "content": {
705 "type": "string",
706 "description": "Content to write to the file"
707 }
708 },
709 "required": ["path", "content"]
710 }),
711 ),
712 "file_edit" => (
713 "Replace text in a file with new text".to_string(),
714 json!({
715 "type": "object",
716 "properties": {
717 "path": {
718 "type": "string",
719 "description": "Path to the file to edit"
720 },
721 "old_text": {
722 "type": "string",
723 "description": "Text to find and replace"
724 },
725 "new_text": {
726 "type": "string",
727 "description": "New text to replace with"
728 }
729 },
730 "required": ["path", "old_text", "new_text"]
731 }),
732 ),
733 "bash" => (
734 "Execute a bash command in a shell".to_string(),
735 json!({
736 "type": "object",
737 "properties": {
738 "command": {
739 "type": "string",
740 "description": "Bash command to execute"
741 },
742 "workdir": {
743 "type": "string",
744 "description": "Working directory (default: current directory)"
745 },
746 "timeout": {
747 "type": "integer",
748 "description": "Timeout in seconds (default: 60)"
749 }
750 },
751 "required": ["command"]
752 }),
753 ),
754 "git_status" => (
755 "Get git repository status".to_string(),
756 json!({
757 "type": "object",
758 "properties": {},
759 "required": []
760 }),
761 ),
762 "git_diff" => (
763 "Get git diff".to_string(),
764 json!({
765 "type": "object",
766 "properties": {},
767 "required": []
768 }),
769 ),
770 "git_log" => (
771 "Get git commit log".to_string(),
772 json!({
773 "type": "object",
774 "properties": {
775 "count": {
776 "type": "integer",
777 "description": "Number of commits to show (default: 10)"
778 }
779 },
780 "required": []
781 }),
782 ),
783 "git_add" => (
784 "Add files to git staging area".to_string(),
785 json!({
786 "type": "object",
787 "properties": {
788 "files": {
789 "type": "array",
790 "items": {"type": "string"},
791 "description": "List of file paths to add"
792 }
793 },
794 "required": ["files"]
795 }),
796 ),
797 "git_commit" => (
798 "Create a git commit".to_string(),
799 json!({
800 "type": "object",
801 "properties": {
802 "message": {
803 "type": "string",
804 "description": "Commit message"
805 }
806 },
807 "required": ["message"]
808 }),
809 ),
810 "git_push" => (
811 "Push commits to remote repository".to_string(),
812 json!({
813 "type": "object",
814 "properties": {
815 "remote": {
816 "type": "string",
817 "description": "Remote name (default: origin)"
818 },
819 "branch": {
820 "type": "string",
821 "description": "Branch name (default: current branch)"
822 }
823 },
824 "required": []
825 }),
826 ),
827 "git_pull" => (
828 "Pull changes from remote repository".to_string(),
829 json!({
830 "type": "object",
831 "properties": {
832 "remote": {
833 "type": "string",
834 "description": "Remote name (default: origin)"
835 },
836 "branch": {
837 "type": "string",
838 "description": "Branch name (default: current branch)"
839 }
840 },
841 "required": []
842 }),
843 ),
844 "git_clone" => (
845 "Clone a git repository".to_string(),
846 json!({
847 "type": "object",
848 "properties": {
849 "url": {
850 "type": "string",
851 "description": "Repository URL to clone"
852 },
853 "directory": {
854 "type": "string",
855 "description": "Directory to clone into (optional)"
856 }
857 },
858 "required": ["url"]
859 }),
860 ),
861 "grep" => (
862 "Search for text patterns in files using regex".to_string(),
863 json!({
864 "type": "object",
865 "properties": {
866 "pattern": {
867 "type": "string",
868 "description": "Regex pattern to search for"
869 },
870 "path": {
871 "type": "string",
872 "description": "Path to search in (default: current directory)"
873 }
874 },
875 "required": ["pattern"]
876 }),
877 ),
878 "ast_grep" => (
879 "Search code using AST patterns (structural code matching)".to_string(),
880 json!({
881 "type": "object",
882 "properties": {
883 "pattern": {
884 "type": "string",
885 "description": "AST pattern to match"
886 },
887 "language": {
888 "type": "string",
889 "description": "Programming language (rust, typescript, python)"
890 },
891 "path": {
892 "type": "string",
893 "description": "Path to search in (default: current directory)"
894 }
895 },
896 "required": ["pattern", "language"]
897 }),
898 ),
899 "lsp" => (
900 "Perform Language Server Protocol operations (goto_definition, find_references)"
901 .to_string(),
902 json!({
903 "type": "object",
904 "properties": {
905 "command": {
906 "type": "string",
907 "description": "LSP command: goto_definition or find_references"
908 },
909 "file_path": {
910 "type": "string",
911 "description": "Path to the file"
912 },
913 "position": {
914 "type": "object",
915 "description": "Position in the file (line, character)",
916 "properties": {
917 "line": {"type": "integer"},
918 "character": {"type": "integer"}
919 },
920 "required": ["line", "character"]
921 }
922 },
923 "required": ["command", "file_path", "position"]
924 }),
925 ),
926 "web_search" => (
927 format!("Search the web using Exa AI. Returns results with titles, URLs, and content snippets. Use for current information beyond knowledge cutoff. The current year is {} - use this year when searching for recent information.", chrono::Local::now().year()),
928 json!({
929 "type": "object",
930 "properties": {
931 "query": {
932 "type": "string",
933 "description": format!("Search query. Be specific for better results (e.g., 'Rust async tutorial {}' rather than 'Rust')", chrono::Local::now().year())
934 },
935 "numResults": {
936 "type": "integer",
937 "description": "Number of results to return (default: 8, max: 20)",
938 "default": 8
939 }
940 },
941 "required": ["query"]
942 }),
943 ),
944 "web_fetch" => (
945 "Fetch content from a URL. Converts HTML to markdown format by default. Use when user provides a URL or after web_search to read full content of a specific result.".to_string(),
946 json!({
947 "type": "object",
948 "properties": {
949 "url": {
950 "type": "string",
951 "description": "URL to fetch (must start with http:// or https://)"
952 },
953 "format": {
954 "type": "string",
955 "enum": ["markdown", "text", "html"],
956 "default": "markdown",
957 "description": "Output format (default: markdown)"
958 }
959 },
960 "required": ["url"]
961 }),
962 ),
963 "browser" => (
964 "Browser automation for testing, scraping, and web interaction. Use snapshot-ref workflow: open URL, take snapshot, use refs from snapshot for interactions. Supports Chrome and Lightpanda engines.".to_string(),
965 json!({
966 "type": "object",
967 "properties": {
968 "action": {
969 "type": "string",
970 "enum": [
971 "open", "close", "snapshot",
973 "click", "dblclick", "fill", "type", "press", "hover", "select",
975 "focus", "check", "uncheck", "scrollintoview", "drag", "upload",
976 "back", "forward", "reload",
978 "screenshot", "pdf", "eval", "get", "get_attr", "get_count", "get_box", "get_styles",
980 "find", "is", "download",
981 "wait", "wait_for_text", "wait_for_url", "wait_for_load", "wait_for_download", "wait_for_fn", "wait_for_state",
983 "tab_list", "tab_new", "tab_close", "tab_select", "dialog_accept", "dialog_dismiss",
985 "cookies", "cookies_set", "storage_get", "storage_set", "network_requests",
987 "set_viewport", "set_device", "set_geo",
989 "scroll"
991 ],
992 "description": "Browser action to perform"
993 },
994 "url": {
996 "type": "string",
997 "description": "URL to open (required for 'open' action)"
998 },
999 "selector": {
1001 "type": "string",
1002 "description": "Element selector or ref (for click, fill, type, hover, select, focus, check, uncheck, scrollintoview, get_attr, get_count, get_box, get_styles, is, download, upload)"
1003 },
1004 "text": {
1005 "type": "string",
1006 "description": "Text to input (for fill, type actions)"
1007 },
1008 "key": {
1009 "type": "string",
1010 "description": "Key to press (required for 'press' action)"
1011 },
1012 "value": {
1013 "type": "string",
1014 "description": "Value (for select, cookies_set, storage_set)"
1015 },
1016 "target": {
1017 "type": "string",
1018 "description": "Target selector (for drag action)"
1019 },
1020 "files": {
1021 "type": "array",
1022 "items": {"type": "string"},
1023 "description": "File paths to upload (for upload action)"
1024 },
1025 "path": {
1027 "type": "string",
1028 "description": "File path (for screenshot, pdf, download actions)"
1029 },
1030 "script": {
1031 "type": "string",
1032 "description": "JavaScript to evaluate (required for 'eval' and 'wait_for_fn' actions)"
1033 },
1034 "get_what": {
1035 "type": "string",
1036 "enum": ["text", "html", "value", "url", "title"],
1037 "description": "What to get (required for 'get' action)"
1038 },
1039 "attr": {
1040 "type": "string",
1041 "description": "Attribute name (for get_attr action)"
1042 },
1043 "locator_type": {
1045 "type": "string",
1046 "enum": ["role", "text", "label", "placeholder", "alt", "title", "testid", "css", "xpath"],
1047 "description": "Locator strategy (for find action)"
1048 },
1049 "locator_value": {
1050 "type": "string",
1051 "description": "Locator value (for find action)"
1052 },
1053 "find_action": {
1054 "type": "string",
1055 "enum": ["click", "fill", "text", "count", "first", "last", "nth", "hover", "focus", "check", "uncheck"],
1056 "description": "Action to perform on found element (for find action)"
1057 },
1058 "action_value": {
1059 "type": "string",
1060 "description": "Value for find action (optional)"
1061 },
1062 "wait_for": {
1064 "type": "string",
1065 "description": "Wait condition (for wait action)"
1066 },
1067 "state": {
1068 "type": "string",
1069 "enum": ["visible", "hidden", "attached", "detached", "enabled", "disabled", "networkidle", "domcontentloaded", "load"],
1070 "description": "State to wait for (for wait_for_state, wait_for_load actions)"
1071 },
1072 "what": {
1074 "type": "string",
1075 "enum": ["visible", "hidden", "enabled", "disabled", "editable"],
1076 "description": "State to check (required for 'is' action)"
1077 },
1078 "direction": {
1080 "type": "string",
1081 "enum": ["up", "down", "left", "right"],
1082 "description": "Scroll direction (for scroll action)"
1083 },
1084 "pixels": {
1085 "type": "integer",
1086 "description": "Pixels to scroll (optional for scroll action)"
1087 },
1088 "index": {
1090 "type": "integer",
1091 "description": "Tab index (for tab_close, tab_select actions)"
1092 },
1093 "dialog_text": {
1095 "type": "string",
1096 "description": "Text for prompt dialog (for dialog_accept action)"
1097 },
1098 "storage_type": {
1100 "type": "string",
1101 "enum": ["local", "session"],
1102 "description": "Storage type (for storage_get, storage_set actions)"
1103 },
1104 "key_name": {
1105 "type": "string",
1106 "description": "Storage key name (for storage_get, storage_set actions)"
1107 },
1108 "filter": {
1110 "type": "string",
1111 "description": "Network request filter (optional for network_requests action)"
1112 },
1113 "width": {
1115 "type": "integer",
1116 "description": "Viewport width (for set_viewport action)"
1117 },
1118 "height": {
1119 "type": "integer",
1120 "description": "Viewport height (for set_viewport action)"
1121 },
1122 "scale": {
1123 "type": "number",
1124 "description": "Device scale factor (optional for set_viewport action)"
1125 },
1126 "device_name": {
1127 "type": "string",
1128 "description": "Device name to emulate (for set_device action)"
1129 },
1130 "latitude": {
1131 "type": "number",
1132 "description": "Latitude (for set_geo action)"
1133 },
1134 "longitude": {
1135 "type": "number",
1136 "description": "Longitude (for set_geo action)"
1137 },
1138 "name": {
1140 "type": "string",
1141 "description": "Cookie name (for cookies_set action)"
1142 },
1143 "engine": {
1145 "type": "string",
1146 "enum": ["chrome", "lightpanda"],
1147 "default": "chrome",
1148 "description": "Browser engine to use"
1149 }
1150 },
1151 "required": ["action"]
1152 }),
1153 ),
1154 "tldr_analyze" => {
1155 let tool_def = tldr_tool_definition();
1156 (
1157 tool_def["description"].as_str().unwrap_or("").to_string(),
1158 tool_def["parameters"].clone()
1159 )
1160 },
1161 _ => (
1162 format!("Tool: {}", name),
1163 json!({
1164 "type": "object",
1165 "properties": {},
1166 "required": []
1167 }),
1168 ),
1169 }
1170 }
1171
1172 fn send_event(&self, event: AgentEvent) {
1174 if let Some(ref tx) = self.event_tx {
1175 let _ = tx.send(event);
1176 }
1177 }
1178
1179 #[allow(dead_code)]
1181 pub fn is_ready(&self) -> bool {
1182 self.config
1183 .providers
1184 .get(&self.config.provider)
1185 .map(|p| p.api_key_or_env(&self.config.provider).is_some())
1186 .unwrap_or(false)
1187 }
1188
1189 pub fn model(&self) -> &str {
1191 self.config
1192 .providers
1193 .get(&self.config.provider)
1194 .map(|p| p.model.as_str())
1195 .unwrap_or("")
1196 }
1197
1198 pub fn max_tokens(&self) -> u32 {
1200 self.config
1201 .providers
1202 .get(&self.config.provider)
1203 .map(|p| p.max_tokens)
1204 .unwrap_or(4096)
1205 }
1206
1207 pub fn timeout(&self) -> u64 {
1209 self.config
1210 .providers
1211 .get(&self.config.provider)
1212 .map(|p| p.timeout)
1213 .unwrap_or(60)
1214 }
1215}
/// Estimates the cost of one request from its token counts.
///
/// Prices are looked up by exact model name and expressed per million tokens
/// (presumably USD — confirm against the provider price lists). Models that are not
/// in the table cost `0.0`.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
    /// Number of tokens one table price refers to.
    const TOKENS_PER_PRICE_UNIT: f64 = 1_000_000.0;

    // (input price, output price) per `TOKENS_PER_PRICE_UNIT` tokens.
    let prices: (f64, f64) = match model {
        "claude-3-5-sonnet-20241022" | "claude-3-5-sonnet" => (3.0, 15.0),
        "gpt-4" => (30.0, 60.0),
        "gpt-4-turbo" | "gpt-4-turbo-preview" => (10.0, 30.0),
        // Unknown model: no price information, so the request is tracked as free.
        _ => (0.0, 0.0),
    };

    let input_cost = input_tokens as f64 * prices.0 / TOKENS_PER_PRICE_UNIT;
    let output_cost = output_tokens as f64 * prices.1 / TOKENS_PER_PRICE_UNIT;
    input_cost + output_cost
}
1231
#[cfg(test)]
mod tests {
    use super::*;
    use limit_llm::{BrowserConfigSection, Config as LlmConfig, ProviderConfig};
    use std::collections::HashMap;

    /// Builds a configuration whose only (and active) provider is `anthropic`,
    /// with the given API key.
    fn anthropic_config(api_key: Option<&str>) -> LlmConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "anthropic".to_string(),
            ProviderConfig {
                api_key: api_key.map(str::to_string),
                model: "claude-3-5-sonnet-20241022".to_string(),
                base_url: None,
                max_tokens: 4096,
                timeout: 60,
                max_iterations: 100,
                thinking_enabled: false,
                clear_thinking: true,
            },
        );
        LlmConfig {
            provider: "anthropic".to_string(),
            providers,
            browser: BrowserConfigSection::default(),
        }
    }

    /// Bridge with a dummy API key, backed by the in-memory tracking database so
    /// tests never touch the real one.
    fn test_bridge() -> AgentBridge {
        AgentBridge::new_for_test(anthropic_config(Some("test-key"))).unwrap()
    }

    #[tokio::test]
    async fn test_agent_bridge_new() {
        let bridge = test_bridge();
        assert!(bridge.is_ready());
    }

    #[tokio::test]
    async fn test_agent_bridge_new_no_api_key() {
        // The in-memory database cannot be the reason for the failure, so an error
        // here comes from constructing the bridge without a key.
        let result = AgentBridge::new_for_test(anthropic_config(None));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let bridge = test_bridge();
        let definitions = bridge.get_tool_definitions();

        assert_eq!(definitions.len(), 19);

        let file_read = definitions
            .iter()
            .find(|d| d.function.name == "file_read")
            .unwrap();
        assert_eq!(file_read.tool_type, "function");
        assert_eq!(file_read.function.name, "file_read");
        assert!(file_read.function.description.contains("Read"));

        let bash = definitions
            .iter()
            .find(|d| d.function.name == "bash")
            .unwrap();
        assert_eq!(bash.function.name, "bash");
        assert!(bash.function.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&"command".into()));
    }

    #[test]
    fn test_get_tool_schema() {
        let (desc, params) = AgentBridge::get_tool_schema("file_read");
        assert!(desc.contains("Read"));
        assert_eq!(params["properties"]["path"]["type"], "string");
        assert!(params["required"]
            .as_array()
            .unwrap()
            .contains(&"path".into()));

        let (desc, params) = AgentBridge::get_tool_schema("bash");
        assert!(desc.contains("bash"));
        assert_eq!(params["properties"]["command"]["type"], "string");

        // Unknown tools fall back to a generic description that names the tool.
        let (desc, _params) = AgentBridge::get_tool_schema("unknown_tool");
        assert!(desc.contains("unknown_tool"));
    }

    #[test]
    fn test_is_ready() {
        let bridge = test_bridge();
        assert!(bridge.is_ready());
    }

    #[test]
    fn test_calculate_cost() {
        // Prices are per million tokens.
        assert_eq!(calculate_cost("gpt-4", 1_000_000, 1_000_000), 90.0);
        assert_eq!(calculate_cost("claude-3-5-sonnet", 1_000_000, 1_000_000), 18.0);
        // Models without a price entry are tracked as free.
        assert_eq!(calculate_cost("unknown-model", 1_000, 1_000), 0.0);
    }
}