browser_use/tools/
mod.rs

1//! Browser automation tools module
2//!
3//! This module provides a framework for browser automation tools and
4//! includes implementations of common browser operations.
5
6pub mod click;
7pub mod evaluate;
8pub mod extract;
9pub mod get_clickable_elements;
10pub mod input;
11pub mod markdown;
12pub mod navigate;
13pub mod read_links;
14pub mod screenshot;
15pub mod wait;
16
17// Re-export Params types for use by MCP layer
18pub use click::ClickParams;
19pub use evaluate::EvaluateParams;
20pub use extract::ExtractParams;
21pub use get_clickable_elements::GetClickableElementsParams;
22pub use input::InputParams;
23pub use markdown::GetMarkdownParams;
24pub use navigate::NavigateParams;
25pub use read_links::ReadLinksParams;
26pub use screenshot::ScreenshotParams;
27pub use wait::WaitParams;
28
29use crate::browser::BrowserSession;
30use crate::dom::DomTree;
31use crate::error::Result;
32use serde_json::Value;
33use std::collections::HashMap;
34use std::sync::Arc;
35
36/// Tool execution context
37pub struct ToolContext<'a> {
38    /// Browser session
39    pub session: &'a BrowserSession,
40
41    /// Optional DOM tree (extracted on demand)
42    pub dom_tree: Option<DomTree>,
43}
44
45impl<'a> ToolContext<'a> {
46    /// Create a new tool context
47    pub fn new(session: &'a BrowserSession) -> Self {
48        Self {
49            session,
50            dom_tree: None,
51        }
52    }
53
54    /// Create a context with a pre-extracted DOM tree
55    pub fn with_dom(session: &'a BrowserSession, dom_tree: DomTree) -> Self {
56        Self {
57            session,
58            dom_tree: Some(dom_tree),
59        }
60    }
61
62    /// Get or extract the DOM tree
63    pub fn get_dom(&mut self) -> Result<&DomTree> {
64        if self.dom_tree.is_none() {
65            self.dom_tree = Some(self.session.extract_dom()?);
66        }
67        Ok(self.dom_tree.as_ref().unwrap())
68    }
69}
70
71/// Result of tool execution
72#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
73pub struct ToolResult {
74    /// Whether the tool execution was successful
75    pub success: bool,
76
77    /// Result data (JSON value)
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub data: Option<Value>,
80
81    /// Error message if execution failed
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub error: Option<String>,
84
85    /// Additional metadata
86    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
87    pub metadata: HashMap<String, Value>,
88}
89
90impl ToolResult {
91    /// Create a successful result
92    pub fn success(data: Option<Value>) -> Self {
93        Self {
94            success: true,
95            data,
96            error: None,
97            metadata: HashMap::new(),
98        }
99    }
100
101    /// Create a successful result with data
102    pub fn success_with<T: serde::Serialize>(data: T) -> Self {
103        Self {
104            success: true,
105            data: serde_json::to_value(data).ok(),
106            error: None,
107            metadata: HashMap::new(),
108        }
109    }
110
111    /// Create a failure result
112    pub fn failure(error: impl Into<String>) -> Self {
113        Self {
114            success: false,
115            data: None,
116            error: Some(error.into()),
117            metadata: HashMap::new(),
118        }
119    }
120
121    /// Add metadata to the result
122    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
123        self.metadata.insert(key.into(), value);
124        self
125    }
126}
127
128/// Trait for browser automation tools with associated parameter types
129pub trait Tool: Send + Sync + Default {
130    /// Associated parameter type for this tool
131    type Params: serde::Serialize + for<'de> serde::Deserialize<'de> + schemars::JsonSchema;
132
133    /// Get tool name
134    fn name(&self) -> &str;
135
136    /// Get tool parameter schema (JSON Schema)
137    fn parameters_schema(&self) -> Value {
138        serde_json::to_value(schemars::schema_for!(Self::Params)).unwrap_or_default()
139    }
140
141    /// Execute the tool with strongly-typed parameters
142    fn execute_typed(&self, params: Self::Params, context: &mut ToolContext) -> Result<ToolResult>;
143
144    /// Execute the tool with JSON parameters (default implementation)
145    fn execute(&self, params: Value, context: &mut ToolContext) -> Result<ToolResult> {
146        let typed_params: Self::Params = serde_json::from_value(params).map_err(|e| {
147            crate::error::BrowserError::InvalidArgument(format!("Invalid parameters: {}", e))
148        })?;
149        self.execute_typed(typed_params, context)
150    }
151}
152
153/// Type-erased tool trait for dynamic dispatch
154pub trait DynTool: Send + Sync {
155    fn name(&self) -> &str;
156    fn parameters_schema(&self) -> Value;
157    fn execute(&self, params: Value, context: &mut ToolContext) -> Result<ToolResult>;
158}
159
160/// Blanket implementation to convert any Tool into DynTool
161impl<T: Tool> DynTool for T {
162    fn name(&self) -> &str {
163        Tool::name(self)
164    }
165
166    fn parameters_schema(&self) -> Value {
167        Tool::parameters_schema(self)
168    }
169
170    fn execute(&self, params: Value, context: &mut ToolContext) -> Result<ToolResult> {
171        Tool::execute(self, params, context)
172    }
173}
174
175/// Tool registry for managing and accessing tools
176pub struct ToolRegistry {
177    tools: HashMap<String, Arc<dyn DynTool>>,
178}
179
180impl ToolRegistry {
181    /// Create a new empty tool registry
182    pub fn new() -> Self {
183        Self {
184            tools: HashMap::new(),
185        }
186    }
187
188    /// Create a registry with default tools
189    pub fn with_defaults() -> Self {
190        let mut registry = Self::new();
191
192        // Register default tools
193        registry.register(navigate::NavigateTool);
194        registry.register(click::ClickTool);
195        registry.register(input::InputTool);
196        registry.register(extract::ExtractContentTool);
197        registry.register(screenshot::ScreenshotTool);
198        registry.register(evaluate::EvaluateTool);
199        registry.register(wait::WaitTool);
200        registry.register(markdown::GetMarkdownTool);
201        registry.register(read_links::ReadLinksTool);
202        registry.register(get_clickable_elements::GetClickableElementsTool);
203
204        registry
205    }
206
207    /// Register a tool
208    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
209        let name = tool.name().to_string();
210        self.tools.insert(name, Arc::new(tool));
211    }
212
213    /// Get a tool by name
214    pub fn get(&self, name: &str) -> Option<&Arc<dyn DynTool>> {
215        self.tools.get(name)
216    }
217
218    /// Check if a tool exists
219    pub fn has(&self, name: &str) -> bool {
220        self.tools.contains_key(name)
221    }
222
223    /// List all tool names
224    pub fn list_names(&self) -> Vec<String> {
225        self.tools.keys().cloned().collect()
226    }
227
228    /// Get all tools
229    pub fn all_tools(&self) -> Vec<Arc<dyn DynTool>> {
230        self.tools.values().cloned().collect()
231    }
232
233    /// Execute a tool by name
234    pub fn execute(
235        &self,
236        name: &str,
237        params: Value,
238        context: &mut ToolContext,
239    ) -> Result<ToolResult> {
240        match self.get(name) {
241            Some(tool) => tool.execute(params, context),
242            None => Ok(ToolResult::failure(format!("Tool '{}' not found", name))),
243        }
244    }
245
246    /// Get the number of registered tools
247    pub fn count(&self) -> usize {
248        self.tools.len()
249    }
250}
251
252impl Default for ToolRegistry {
253    fn default() -> Self {
254        Self::with_defaults()
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// A success result carries its data and no error.
    #[test]
    fn test_tool_result_success() {
        let payload = serde_json::json!({"url": "https://example.com"});
        let outcome = ToolResult::success(Some(payload));

        assert!(outcome.success);
        assert!(outcome.data.is_some());
        assert!(outcome.error.is_none());
    }

    /// A failure result carries the message and no data.
    #[test]
    fn test_tool_result_failure() {
        let outcome = ToolResult::failure("Test error");

        assert!(!outcome.success);
        assert!(outcome.data.is_none());
        assert_eq!(outcome.error.as_deref(), Some("Test error"));
    }

    /// Metadata added through the builder method is stored on the result.
    #[test]
    fn test_tool_result_with_metadata() {
        let outcome =
            ToolResult::success(None).with_metadata("duration_ms", serde_json::json!(100));

        assert!(outcome.metadata.contains_key("duration_ms"));
    }

    /// The default registry knows the built-in tools and rejects unknown names.
    #[test]
    fn test_tool_registry() {
        let registry = ToolRegistry::with_defaults();

        for known in ["navigate", "click", "input"] {
            assert!(registry.has(known));
        }
        assert!(!registry.has("nonexistent"));

        // At least 10 default tools
        assert!(registry.count() >= 10);
    }

    /// The name listing includes the built-in tools.
    #[test]
    fn test_tool_registry_list() {
        let registry = ToolRegistry::with_defaults();
        let names = registry.list_names();

        for expected in ["navigate", "click", "get_clickable_elements"] {
            assert!(names.iter().any(|name| name.as_str() == expected));
        }
    }
}