1pub mod click;
7pub mod evaluate;
8pub mod extract;
9pub mod get_clickable_elements;
10pub mod input;
11pub mod markdown;
12pub mod navigate;
13pub mod read_links;
14pub mod screenshot;
15pub mod wait;
16
17pub use click::ClickParams;
19pub use evaluate::EvaluateParams;
20pub use extract::ExtractParams;
21pub use get_clickable_elements::GetClickableElementsParams;
22pub use input::InputParams;
23pub use markdown::GetMarkdownParams;
24pub use navigate::NavigateParams;
25pub use read_links::ReadLinksParams;
26pub use screenshot::ScreenshotParams;
27pub use wait::WaitParams;
28
29use crate::browser::BrowserSession;
30use crate::dom::DomTree;
31use crate::error::Result;
32use serde_json::Value;
33use std::collections::HashMap;
34use std::sync::Arc;
35
36pub struct ToolContext<'a> {
38 pub session: &'a BrowserSession,
40
41 pub dom_tree: Option<DomTree>,
43}
44
45impl<'a> ToolContext<'a> {
46 pub fn new(session: &'a BrowserSession) -> Self {
48 Self {
49 session,
50 dom_tree: None,
51 }
52 }
53
54 pub fn with_dom(session: &'a BrowserSession, dom_tree: DomTree) -> Self {
56 Self {
57 session,
58 dom_tree: Some(dom_tree),
59 }
60 }
61
62 pub fn get_dom(&mut self) -> Result<&DomTree> {
64 if self.dom_tree.is_none() {
65 self.dom_tree = Some(self.session.extract_dom()?);
66 }
67 Ok(self.dom_tree.as_ref().unwrap())
68 }
69}
70
71#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
73pub struct ToolResult {
74 pub success: bool,
76
77 #[serde(skip_serializing_if = "Option::is_none")]
79 pub data: Option<Value>,
80
81 #[serde(skip_serializing_if = "Option::is_none")]
83 pub error: Option<String>,
84
85 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
87 pub metadata: HashMap<String, Value>,
88}
89
90impl ToolResult {
91 pub fn success(data: Option<Value>) -> Self {
93 Self {
94 success: true,
95 data,
96 error: None,
97 metadata: HashMap::new(),
98 }
99 }
100
101 pub fn success_with<T: serde::Serialize>(data: T) -> Self {
103 Self {
104 success: true,
105 data: serde_json::to_value(data).ok(),
106 error: None,
107 metadata: HashMap::new(),
108 }
109 }
110
111 pub fn failure(error: impl Into<String>) -> Self {
113 Self {
114 success: false,
115 data: None,
116 error: Some(error.into()),
117 metadata: HashMap::new(),
118 }
119 }
120
121 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
123 self.metadata.insert(key.into(), value);
124 self
125 }
126}
127
128pub trait Tool: Send + Sync + Default {
130 type Params: serde::Serialize + for<'de> serde::Deserialize<'de> + schemars::JsonSchema;
132
133 fn name(&self) -> &str;
135
136 fn parameters_schema(&self) -> Value {
138 serde_json::to_value(schemars::schema_for!(Self::Params)).unwrap_or_default()
139 }
140
141 fn execute_typed(&self, params: Self::Params, context: &mut ToolContext) -> Result<ToolResult>;
143
144 fn execute(&self, params: Value, context: &mut ToolContext) -> Result<ToolResult> {
146 let typed_params: Self::Params = serde_json::from_value(params).map_err(|e| {
147 crate::error::BrowserError::InvalidArgument(format!("Invalid parameters: {}", e))
148 })?;
149 self.execute_typed(typed_params, context)
150 }
151}
152
153pub trait DynTool: Send + Sync {
155 fn name(&self) -> &str;
156 fn parameters_schema(&self) -> Value;
157 fn execute(&self, params: Value, context: &mut ToolContext) -> Result<ToolResult>;
158}
159
160impl<T: Tool> DynTool for T {
162 fn name(&self) -> &str {
163 Tool::name(self)
164 }
165
166 fn parameters_schema(&self) -> Value {
167 Tool::parameters_schema(self)
168 }
169
170 fn execute(&self, params: Value, context: &mut ToolContext) -> Result<ToolResult> {
171 Tool::execute(self, params, context)
172 }
173}
174
175pub struct ToolRegistry {
177 tools: HashMap<String, Arc<dyn DynTool>>,
178}
179
180impl ToolRegistry {
181 pub fn new() -> Self {
183 Self {
184 tools: HashMap::new(),
185 }
186 }
187
188 pub fn with_defaults() -> Self {
190 let mut registry = Self::new();
191
192 registry.register(navigate::NavigateTool);
194 registry.register(click::ClickTool);
195 registry.register(input::InputTool);
196 registry.register(extract::ExtractContentTool);
197 registry.register(screenshot::ScreenshotTool);
198 registry.register(evaluate::EvaluateTool);
199 registry.register(wait::WaitTool);
200 registry.register(markdown::GetMarkdownTool);
201 registry.register(read_links::ReadLinksTool);
202 registry.register(get_clickable_elements::GetClickableElementsTool);
203
204 registry
205 }
206
207 pub fn register<T: Tool + 'static>(&mut self, tool: T) {
209 let name = tool.name().to_string();
210 self.tools.insert(name, Arc::new(tool));
211 }
212
213 pub fn get(&self, name: &str) -> Option<&Arc<dyn DynTool>> {
215 self.tools.get(name)
216 }
217
218 pub fn has(&self, name: &str) -> bool {
220 self.tools.contains_key(name)
221 }
222
223 pub fn list_names(&self) -> Vec<String> {
225 self.tools.keys().cloned().collect()
226 }
227
228 pub fn all_tools(&self) -> Vec<Arc<dyn DynTool>> {
230 self.tools.values().cloned().collect()
231 }
232
233 pub fn execute(
235 &self,
236 name: &str,
237 params: Value,
238 context: &mut ToolContext,
239 ) -> Result<ToolResult> {
240 match self.get(name) {
241 Some(tool) => tool.execute(params, context),
242 None => Ok(ToolResult::failure(format!("Tool '{}' not found", name))),
243 }
244 }
245
246 pub fn count(&self) -> usize {
248 self.tools.len()
249 }
250}
251
252impl Default for ToolRegistry {
253 fn default() -> Self {
254 Self::with_defaults()
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful result carries its payload and no error.
    #[test]
    fn test_tool_result_success() {
        let payload = serde_json::json!({"url": "https://example.com"});
        let result = ToolResult::success(Some(payload));

        assert!(result.success);
        assert!(result.data.is_some());
        assert!(result.error.is_none());
    }

    /// A failed result carries its message and no payload.
    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("Test error");

        assert!(!result.success);
        assert!(result.data.is_none());
        assert_eq!(result.error.as_deref(), Some("Test error"));
    }

    /// `with_metadata` records the given key.
    #[test]
    fn test_tool_result_with_metadata() {
        let duration = serde_json::json!(100);
        let result = ToolResult::success(None).with_metadata("duration_ms", duration);

        assert!(result.metadata.contains_key("duration_ms"));
    }

    /// The default registry exposes the core tools and nothing bogus.
    #[test]
    fn test_tool_registry() {
        let registry = ToolRegistry::with_defaults();

        for name in ["navigate", "click", "input"] {
            assert!(registry.has(name), "missing default tool: {}", name);
        }
        assert!(!registry.has("nonexistent"));
        assert!(registry.count() >= 10);
    }

    /// `list_names` includes the expected built-in tool names.
    #[test]
    fn test_tool_registry_list() {
        let names = ToolRegistry::with_defaults().list_names();

        for expected in ["navigate", "click", "get_clickable_elements"] {
            assert!(names.iter().any(|n| n == expected));
        }
    }
}