Skip to main content

agentzero_tools/
browser.rs

1use agentzero_core::common::url_policy::UrlAccessPolicy;
2use agentzero_core::common::util::parse_http_url_with_policy;
3use agentzero_core::{Tool, ToolContext, ToolResult};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use serde::Deserialize;
7use std::process::Stdio;
8use tokio::io::AsyncReadExt;
9use tokio::process::Command;
10
/// Cap, in bytes, on how much of a single child output stream is kept;
/// `read_limited` truncates anything beyond it and appends a marker.
const MAX_OUTPUT_BYTES: usize = 65536;
/// Timeout, in milliseconds, that `BrowserConfig::default()` assigns to `timeout_ms`.
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
13
/// A browser request decoded from the tool's JSON input.
///
/// The JSON object carries the variant name in its `"action"` field, written in
/// snake_case (for example `"get_text"`), with the variant's fields alongside it.
/// The decoded value is only used to validate the request; `execute` forwards
/// the original JSON text to the agent-browser process.
#[derive(Debug, Deserialize)]
#[serde(tag = "action")]
#[serde(rename_all = "snake_case")]
// Most fields are never read in this file: they exist so deserialization
// checks the request shape, while the raw JSON is what gets forwarded.
#[allow(dead_code)]
enum BrowserAction {
    /// `url` is checked against the URL access policy and the domain
    /// allow-list before the request is forwarded.
    Navigate {
        url: String,
    },
    Snapshot,
    /// `selector` must not be blank.
    Click {
        selector: String,
    },
    /// `selector` must not be blank.
    Fill {
        selector: String,
        value: String,
    },
    /// `selector` must not be blank.
    Type {
        selector: String,
        text: String,
    },
    /// `selector` must not be blank.
    GetText {
        selector: String,
    },
    GetTitle,
    GetUrl,
    Screenshot {
        /// Optional; absent means `None`.
        #[serde(default)]
        path: Option<String>,
    },
    /// Both fields are optional and neither is validated here.
    Wait {
        #[serde(default)]
        selector: Option<String>,
        #[serde(default)]
        ms: Option<u64>,
    },
    Press {
        key: String,
    },
    /// `selector` must not be blank.
    Hover {
        selector: String,
    },
    /// `direction` is forwarded as-is; accepted values are decided by the
    /// agent-browser process, not by this file.
    Scroll {
        direction: String,
    },
    Close,
}
60
/// Settings for [`BrowserTool`].
#[derive(Debug, Clone)]
pub struct BrowserConfig {
    /// Program spawned once per action (default: `"agent-browser"`).
    pub agent_browser_command: String,
    /// Arguments placed before the `--action <json>` pair on the command line.
    pub agent_browser_extra_args: Vec<String>,
    /// How long to wait for the spawned process to exit, in milliseconds.
    pub timeout_ms: u64,
    /// Domains a `navigate` action may target (the exact host or any of its
    /// subdomains). An empty list disables this check.
    pub allowed_domains: Vec<String>,
}
68
69impl Default for BrowserConfig {
70    fn default() -> Self {
71        Self {
72            agent_browser_command: "agent-browser".to_string(),
73            agent_browser_extra_args: Vec::new(),
74            timeout_ms: DEFAULT_TIMEOUT_MS,
75            allowed_domains: Vec::new(),
76        }
77    }
78}
79
/// Tool that validates browser actions and hands them to an external
/// agent-browser process.
#[derive(Default)]
pub struct BrowserTool {
    /// Command, timeout and domain allow-list settings.
    config: BrowserConfig,
    /// Policy applied to the target URL of `navigate` actions.
    url_policy: UrlAccessPolicy,
}
85
86impl BrowserTool {
87    pub fn new(config: BrowserConfig) -> Self {
88        Self {
89            config,
90            url_policy: UrlAccessPolicy::default(),
91        }
92    }
93
94    pub fn with_url_policy(mut self, policy: UrlAccessPolicy) -> Self {
95        self.url_policy = policy;
96        self
97    }
98
99    fn validate_selector(selector: &str) -> anyhow::Result<()> {
100        if selector.trim().is_empty() {
101            return Err(anyhow!("selector must not be empty"));
102        }
103        Ok(())
104    }
105
106    fn validate_domain(&self, url: &str) -> anyhow::Result<()> {
107        if self.config.allowed_domains.is_empty() {
108            return Ok(());
109        }
110        let parsed = url::Url::parse(url).context("invalid URL")?;
111        let host = parsed.host_str().unwrap_or("");
112        if !self
113            .config
114            .allowed_domains
115            .iter()
116            .any(|d| host == d || host.ends_with(&format!(".{d}")))
117        {
118            return Err(anyhow!(
119                "domain {} is not in the allowed domains list",
120                host
121            ));
122        }
123        Ok(())
124    }
125
126    async fn send_to_agent_browser(&self, action_json: &str) -> anyhow::Result<String> {
127        let mut cmd = Command::new(&self.config.agent_browser_command);
128        cmd.args(&self.config.agent_browser_extra_args);
129        cmd.arg("--action").arg(action_json);
130        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
131
132        let mut child = cmd.spawn().with_context(|| {
133            format!(
134                "failed to spawn agent-browser command: {}",
135                self.config.agent_browser_command
136            )
137        })?;
138
139        let stdout_handle = child
140            .stdout
141            .take()
142            .context("stdout not piped on spawned child")?;
143        let stderr_handle = child
144            .stderr
145            .take()
146            .context("stderr not piped on spawned child")?;
147
148        let stdout_task = tokio::spawn(read_limited(stdout_handle));
149        let stderr_task = tokio::spawn(read_limited(stderr_handle));
150
151        let timeout = tokio::time::Duration::from_millis(self.config.timeout_ms);
152        let status = tokio::time::timeout(timeout, child.wait())
153            .await
154            .context("agent-browser timed out")?
155            .context("agent-browser command failed")?;
156
157        let stdout = stdout_task.await.context("stdout join")??;
158        let stderr = stderr_task.await.context("stderr join")??;
159
160        let mut output = format!("exit={}\n", status.code().unwrap_or(-1));
161        if !stdout.is_empty() {
162            output.push_str(&stdout);
163        }
164        if !stderr.is_empty() {
165            output.push_str("\nstderr:\n");
166            output.push_str(&stderr);
167        }
168        Ok(output)
169    }
170}
171
172async fn read_limited<R: tokio::io::AsyncRead + Unpin>(mut reader: R) -> anyhow::Result<String> {
173    let mut buf = Vec::new();
174    let mut limited = (&mut reader).take((MAX_OUTPUT_BYTES + 1) as u64);
175    limited.read_to_end(&mut buf).await?;
176    let truncated = buf.len() > MAX_OUTPUT_BYTES;
177    if truncated {
178        buf.truncate(MAX_OUTPUT_BYTES);
179    }
180    let mut s = String::from_utf8_lossy(&buf).to_string();
181    if truncated {
182        s.push_str(&format!("\n<truncated at {} bytes>", MAX_OUTPUT_BYTES));
183    }
184    Ok(s)
185}
186
187#[async_trait]
188impl Tool for BrowserTool {
189    fn name(&self) -> &'static str {
190        "browser"
191    }
192
193    fn description(&self) -> &'static str {
194        "Control a headless browser: navigate to URLs, execute JavaScript, take screenshots, and extract page content."
195    }
196
197    fn input_schema(&self) -> Option<serde_json::Value> {
198        Some(serde_json::json!({
199            "type": "object",
200            "properties": {
201                "action": {
202                    "type": "string",
203                    "description": "Browser action: navigate, execute_js, screenshot, content, close",
204                    "enum": ["navigate", "execute_js", "screenshot", "content", "close"]
205                },
206                "url": { "type": "string", "description": "URL to navigate to (for navigate action)" },
207                "script": { "type": "string", "description": "JavaScript to execute (for execute_js action)" }
208            },
209            "required": ["action"]
210        }))
211    }
212
213    async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
214        let action: BrowserAction =
215            serde_json::from_str(input).context("browser expects JSON with \"action\" field")?;
216
217        match &action {
218            BrowserAction::Navigate { url } => {
219                parse_http_url_with_policy(url, &self.url_policy)?;
220                self.validate_domain(url)?;
221            }
222            BrowserAction::Click { selector }
223            | BrowserAction::Fill { selector, .. }
224            | BrowserAction::Type { selector, .. }
225            | BrowserAction::GetText { selector }
226            | BrowserAction::Hover { selector } => {
227                Self::validate_selector(selector)?;
228            }
229            _ => {}
230        }
231
232        let output = self.send_to_agent_browser(input).await?;
233        Ok(ToolResult { output })
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool context rooted at the current directory.
    fn cwd_context() -> ToolContext {
        ToolContext::new(".".to_string())
    }

    /// Tool whose allow-list contains only `example.com`.
    fn example_com_only() -> BrowserTool {
        BrowserTool::new(BrowserConfig {
            allowed_domains: vec!["example.com".to_string()],
            ..Default::default()
        })
    }

    /// Executes `input` and returns the rendered error, panicking with `why`
    /// if the call unexpectedly succeeds.
    async fn failure_message(tool: &BrowserTool, input: &str, why: &str) -> String {
        let outcome = tool.execute(input, &cwd_context()).await;
        outcome.expect_err(why).to_string()
    }

    #[tokio::test]
    async fn browser_rejects_invalid_json() {
        let message = failure_message(
            &BrowserTool::default(),
            "not json",
            "invalid JSON should fail",
        )
        .await;
        assert!(message.contains("browser expects JSON"));
    }

    #[tokio::test]
    async fn browser_navigate_blocks_private_ip() {
        let message = failure_message(
            &BrowserTool::default(),
            r#"{"action": "navigate", "url": "http://10.0.0.1/internal"}"#,
            "private IP should be blocked",
        )
        .await;
        assert!(message.contains("URL access denied"));
    }

    #[tokio::test]
    async fn browser_navigate_blocks_unapproved_domain() {
        let message = failure_message(
            &example_com_only(),
            r#"{"action": "navigate", "url": "https://evil.example.org/page"}"#,
            "non-allowed domain should be blocked",
        )
        .await;
        assert!(message.contains("not in the allowed domains"));
    }

    #[tokio::test]
    async fn browser_click_rejects_empty_selector() {
        let message = failure_message(
            &BrowserTool::default(),
            r#"{"action": "click", "selector": ""}"#,
            "empty selector should fail",
        )
        .await;
        assert!(message.contains("selector must not be empty"));
    }

    #[test]
    fn validate_domain_allows_any_when_empty() {
        let unrestricted = BrowserTool::default();
        assert!(unrestricted.validate_domain("https://anything.com").is_ok());
    }

    #[test]
    fn validate_domain_allows_subdomain() {
        let restricted = example_com_only();
        for url in ["https://sub.example.com/page", "https://example.com"] {
            assert!(restricted.validate_domain(url).is_ok());
        }
    }
}
308}