1use std::collections::HashSet;
36
37use serde::{Deserialize, Serialize};
38
39use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
40
41use crate::action::{extract_action, ToolAction};
42use crate::external::TokenBucket;
43
/// Computer-use action types permitted out of the box.
///
/// Covers the remote-session lifecycle (connect / disconnect / reconnect),
/// input injection, and the remote clipboard, file-transfer, audio,
/// drive-mapping, printing and session-share channels.
pub fn default_allowed_action_types() -> Vec<String> {
    const DEFAULTS: [&str; 10] = [
        "remote.session.connect",
        "remote.session.disconnect",
        "remote.session.reconnect",
        "input.inject",
        "remote.clipboard",
        "remote.file_transfer",
        "remote.audio",
        "remote.drive_mapping",
        "remote.printing",
        "remote.session_share",
    ];
    DEFAULTS.iter().copied().map(String::from).collect()
}
62
63#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
65#[serde(rename_all = "snake_case")]
66pub enum EnforcementMode {
67 Observe,
69 #[default]
71 Guardrail,
72 FailClosed,
74}
75
76#[derive(Clone, Debug, Deserialize, Serialize)]
78#[serde(deny_unknown_fields)]
79pub struct ComputerUseConfig {
80 #[serde(default = "default_true")]
83 pub enabled: bool,
84 #[serde(default = "default_allowed_action_types")]
86 pub allowed_action_types: Vec<String>,
87 #[serde(default)]
89 pub mode: EnforcementMode,
90 #[serde(default)]
93 pub blocked_domains: Vec<String>,
94 #[serde(default)]
99 pub allowed_domains: Vec<String>,
100 #[serde(default)]
103 pub screenshot_rate_per_second: Option<f64>,
104 #[serde(default)]
107 pub screenshot_burst: Option<u32>,
108}
109
/// Serde default for `ComputerUseConfig::enabled`: the guard is on unless
/// the configuration explicitly turns it off.
fn default_true() -> bool {
    true
}
113
114impl Default for ComputerUseConfig {
115 fn default() -> Self {
116 Self {
117 enabled: true,
118 allowed_action_types: default_allowed_action_types(),
119 mode: EnforcementMode::Guardrail,
120 blocked_domains: Vec::new(),
121 allowed_domains: Vec::new(),
122 screenshot_rate_per_second: None,
123 screenshot_burst: None,
124 }
125 }
126}
127
128pub struct ComputerUseGuard {
132 enabled: bool,
133 mode: EnforcementMode,
134 allowed_actions: HashSet<String>,
135 blocked_domains: Vec<String>,
136 allowed_domains: Vec<String>,
137 screenshot_bucket: Option<TokenBucket>,
138}
139
140impl ComputerUseGuard {
141 pub fn new() -> Self {
143 Self::with_config(ComputerUseConfig::default())
144 }
145
146 pub fn with_config(config: ComputerUseConfig) -> Self {
148 let allowed_actions: HashSet<String> = config.allowed_action_types.into_iter().collect();
149 let screenshot_bucket = match config.screenshot_rate_per_second {
150 Some(rate) if rate > 0.0 && rate.is_finite() => {
151 let burst = config.screenshot_burst.unwrap_or(5).max(1);
152 Some(TokenBucket::new(rate, burst))
153 }
154 _ => None,
155 };
156 Self {
157 enabled: config.enabled,
158 mode: config.mode,
159 allowed_actions,
160 blocked_domains: config.blocked_domains,
161 allowed_domains: config.allowed_domains,
162 screenshot_bucket,
163 }
164 }
165
166 fn is_screenshot_verb(verb: &str) -> bool {
169 let v = verb.to_ascii_lowercase();
170 matches!(
171 v.as_str(),
172 "screenshot"
173 | "screen_capture"
174 | "screen_shot"
175 | "capture"
176 | "capture_screen"
177 | "browser_screenshot"
178 )
179 }
180
181 fn extract_cua_action_type<'a>(
188 tool_name: &'a str,
189 arguments: &'a serde_json::Value,
190 ) -> Option<String> {
191 if tool_name.starts_with("remote.") || tool_name.starts_with("input.") {
192 return Some(tool_name.to_string());
193 }
194 for key in ["action_type", "actionType", "custom_type", "customType"] {
195 if let Some(value) = arguments.get(key).and_then(|v| v.as_str()) {
196 if value.starts_with("remote.") || value.starts_with("input.") {
197 return Some(value.to_string());
198 }
199 }
200 }
201 None
202 }
203
204 fn apply_mode(&self, in_allowlist: bool) -> Verdict {
206 match (self.mode, in_allowlist) {
207 (EnforcementMode::Observe, _) => Verdict::Allow,
208 (EnforcementMode::Guardrail, _) => Verdict::Allow,
209 (EnforcementMode::FailClosed, true) => Verdict::Allow,
210 (EnforcementMode::FailClosed, false) => Verdict::Deny,
211 }
212 }
213
214 fn check_navigation(&self, target: &str) -> Verdict {
216 if self.blocked_domains.is_empty() && self.allowed_domains.is_empty() {
219 return Verdict::Allow;
220 }
221 let host = match extract_host(target) {
222 Some(host) => host,
223 None => {
224 return Verdict::Allow;
228 }
229 };
230 let blocked = self
231 .blocked_domains
232 .iter()
233 .any(|pat| matches_domain(pat, &host));
234 if blocked {
235 return match self.mode {
236 EnforcementMode::Observe => Verdict::Allow,
237 EnforcementMode::Guardrail | EnforcementMode::FailClosed => Verdict::Deny,
238 };
239 }
240 if !self.allowed_domains.is_empty() {
241 let allowed = self
242 .allowed_domains
243 .iter()
244 .any(|pat| matches_domain(pat, &host));
245 if !allowed {
246 return match self.mode {
247 EnforcementMode::Observe | EnforcementMode::Guardrail => Verdict::Allow,
248 EnforcementMode::FailClosed => Verdict::Deny,
249 };
250 }
251 }
252 Verdict::Allow
253 }
254}
255
256impl Default for ComputerUseGuard {
257 fn default() -> Self {
258 Self::new()
259 }
260}
261
262impl Guard for ComputerUseGuard {
263 fn name(&self) -> &str {
264 "computer-use"
265 }
266
267 fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
268 if !self.enabled {
269 return Ok(Verdict::Allow);
270 }
271
272 if let Some(action_type) =
274 Self::extract_cua_action_type(&ctx.request.tool_name, &ctx.request.arguments)
275 {
276 let in_allowlist = self.allowed_actions.contains(&action_type);
277 return Ok(self.apply_mode(in_allowlist));
278 }
279
280 let action = extract_action(&ctx.request.tool_name, &ctx.request.arguments);
282 if let ToolAction::BrowserAction { verb, target } = &action {
283 if Self::is_screenshot_verb(verb) {
285 if let Some(bucket) = &self.screenshot_bucket {
286 if !bucket.try_acquire() {
287 return Ok(match self.mode {
288 EnforcementMode::Observe => Verdict::Allow,
289 EnforcementMode::Guardrail | EnforcementMode::FailClosed => {
290 Verdict::Deny
291 }
292 });
293 }
294 }
295 return Ok(Verdict::Allow);
296 }
297
298 if matches!(
300 verb.to_ascii_lowercase().as_str(),
301 "navigate" | "goto" | "open"
302 ) {
303 if let Some(url) = target {
304 return Ok(self.check_navigation(url));
305 }
306 }
307 }
308
309 Ok(Verdict::Allow)
311 }
312}
313
/// Returns `true` when `host` matches the domain-list entry `pattern`.
///
/// Both sides are normalized before comparison:
/// - whitespace is trimmed;
/// - trailing dots are stripped, so the fully-qualified `example.com.` equals
///   `example.com`;
/// - surrounding IPv6 brackets are removed, so `[fd00::1]` equals `fd00::1`;
/// - text is lowercased for ASCII case-insensitive matching.
///
/// - `*.example.com` matches `example.com` itself and any subdomain such as
///   `api.example.com`, but not `badexample.com`.
/// - Any other pattern must equal the host exactly.
///
/// Empty patterns or hosts never match.
fn matches_domain(pattern: &str, host: &str) -> bool {
    // Hosts from `extract_host` already have trailing dots trimmed and brackets
    // removed. Patterns must be normalized the same way, or entries written in
    // those forms would silently never match.
    fn normalize(raw: &str) -> String {
        let trimmed = raw.trim().trim_end_matches('.');
        let unbracketed = trimmed
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(trimmed);
        unbracketed.to_ascii_lowercase()
    }

    let pattern = normalize(pattern);
    let host = normalize(host);
    if pattern.is_empty() || host.is_empty() {
        return false;
    }
    match pattern.strip_prefix("*.") {
        // Require a label boundary so `*.example.com` does not match `badexample.com`.
        Some(suffix) => {
            host == suffix
                || matches!(host.strip_suffix(suffix), Some(label) if label.ends_with('.'))
        }
        None => pattern == host,
    }
}
327
/// Extracts the lowercased network host from a navigation target.
///
/// It mirrors the parts of browser (WHATWG) URL parsing that matter for
/// domain allow/block decisions, so a target cannot name one host to the
/// guard and another to the browser:
///
/// - ASCII tab, LF and CR are removed anywhere in the input, and surrounding
///   whitespace is trimmed.
/// - `view-source:` prefixes (any case) are unwrapped, because the browser
///   loads the wrapped URL.
/// - For `http`, `https`, `ws`, `wss` and `ftp`, any run of `/` or `\` after
///   the scheme is skipped and `\` ends the authority. `https:\\host` and
///   `https://host\@other` both resolve to `host`.
/// - Inputs starting with two slashes (`//`, `\\`, `/\`, `\/`) are
///   scheme-relative.
/// - Other `scheme://authority` URLs use their authority. Anything else is
///   parsed as a bare `host[:port][/path]`.
///
/// Userinfo (`user:pass@`) and ports are stripped, bracketed IPv6 literals
/// are unwrapped, and leading/trailing `/` and `.` are trimmed from the host.
///
/// Returns `None` for:
/// - fragments (`#...`) and relative paths (`./x`, `/x`);
/// - inputs starting with `[`;
/// - `data:`, `javascript:`, `about:` and `file:` URLs;
/// - inputs whose host is empty.
///
/// Forms such as `https:host` (no slashes) are treated as naming `host`. A
/// browser on an `https` page would resolve them as relative paths instead;
/// this parser leans toward extracting a host so the domain lists still apply.
fn extract_host(url: &str) -> Option<String> {
    // Schemes that never carry a network host.
    const HOSTLESS_SCHEMES: [&str; 4] = ["data", "javascript", "about", "file"];
    // WHATWG "special" schemes with an authority, where `\` acts like `/`.
    const SPECIAL_SCHEMES: [&str; 5] = ["http", "https", "ws", "wss", "ftp"];

    let cleaned: String = url
        .chars()
        .filter(|c| !matches!(c, '\t' | '\n' | '\r'))
        .collect();
    let mut url = cleaned.trim();
    while let Some(inner) = strip_prefix_ignore_ascii_case(url, "view-source:") {
        url = inner.trim();
    }
    if url.is_empty() || url.starts_with(['#', '.', '[']) {
        return None;
    }

    // `special` selects browser-style delimiters, where `\` also ends the authority.
    let (authority, special) = match split_scheme(url) {
        Some((scheme, rest)) => {
            if HOSTLESS_SCHEMES
                .iter()
                .any(|name| scheme.eq_ignore_ascii_case(name))
            {
                return None;
            }
            if SPECIAL_SCHEMES
                .iter()
                .any(|name| scheme.eq_ignore_ascii_case(name))
            {
                (rest.trim_start_matches(['/', '\\']), true)
            } else if let Some(authority) = rest.strip_prefix("//") {
                (authority, false)
            } else {
                // Not a URL scheme we understand (e.g. `example.com:443`):
                // parse the whole input as `host[:port]`.
                (url, true)
            }
        }
        None if has_two_leading_slashes(url) => (url.trim_start_matches(['/', '\\']), true),
        None => (url, true),
    };

    let delimiters: &[char] = if special {
        &['/', '\\', '?', '#']
    } else {
        &['/', '?', '#']
    };
    let host_with_port = authority.split(delimiters).next().unwrap_or(authority);
    // Browsers use the last `@` to separate userinfo from the host.
    let host_without_userinfo = host_with_port
        .rsplit_once('@')
        .map_or(host_with_port, |(_, host)| host);
    let host = match host_without_userinfo.strip_prefix('[') {
        Some(bracketed) => {
            let (literal, remainder) = bracketed.split_once(']')?;
            if !remainder.is_empty() && !remainder.starts_with(':') {
                return None;
            }
            literal
        }
        None => host_without_userinfo
            .rsplit_once(':')
            .map_or(host_without_userinfo, |(name, _port)| name),
    };
    let host = host.trim_matches(|c: char| c == '/' || c == '.');
    if host.is_empty() {
        return None;
    }
    Some(host.to_ascii_lowercase())
}

/// Splits `url` at its first `:` when the text before it is a syntactically
/// valid URL scheme: an ASCII letter followed by ASCII letters, digits, `+`,
/// `-` or `.`.
fn split_scheme(url: &str) -> Option<(&str, &str)> {
    let (scheme, rest) = url.split_once(':')?;
    let mut chars = scheme.chars();
    let valid = matches!(chars.next(), Some(c) if c.is_ascii_alphabetic())
        && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'));
    if valid {
        Some((scheme, rest))
    } else {
        None
    }
}

/// Returns `true` when `url` starts with two characters that are each `/` or `\`.
fn has_two_leading_slashes(url: &str) -> bool {
    let mut chars = url.chars();
    matches!(
        (chars.next(), chars.next()),
        (Some('/' | '\\'), Some('/' | '\\'))
    )
}

/// Returns the rest of `s` after `prefix`, comparing ASCII case-insensitively.
fn strip_prefix_ignore_ascii_case<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
    // `get` returns `None` rather than panicking if the cut is not a char boundary.
    let head = s.get(..prefix.len())?;
    if head.eq_ignore_ascii_case(prefix) {
        s.get(prefix.len()..)
    } else {
        None
    }
}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Fail-closed guard whose only restriction is a single blocked domain pattern.
    fn fail_closed_blocking(domain: &str) -> ComputerUseGuard {
        ComputerUseGuard::with_config(ComputerUseConfig {
            mode: EnforcementMode::FailClosed,
            blocked_domains: vec![domain.to_string()],
            ..ComputerUseConfig::default()
        })
    }

    #[test]
    fn matches_domain_exact_and_wildcard() {
        let cases = [
            ("example.com", "example.com", true),
            ("example.com", "evil.com", false),
            ("*.example.com", "api.example.com", true),
            ("*.example.com", "example.com", true),
            ("*.example.com", "example.org", false),
        ];
        for (pattern, host, expected) in cases {
            assert_eq!(
                matches_domain(pattern, host),
                expected,
                "pattern {pattern:?} vs host {host:?}"
            );
        }
    }

    #[test]
    fn extract_host_handles_common_urls() {
        let with_host = [
            ("https://example.com/x", "example.com"),
            ("HTTPS://169.254.169.254/latest", "169.254.169.254"),
            ("https://user:pass@example.com:8443/x", "example.com"),
            ("https://user@[fd00:ec2::254]:8443/x", "fd00:ec2::254"),
            ("http://localhost:8080", "localhost"),
            ("example.com:443/y", "example.com"),
            ("//169.254.169.254/latest", "169.254.169.254"),
            ("https://blocked.example?redir=1", "blocked.example"),
            ("https://blocked.example#anchor", "blocked.example"),
        ];
        for (url, expected) in with_host {
            assert_eq!(extract_host(url), Some(expected.to_string()), "url {url:?}");
        }
        for url in ["#submit", "data:text/plain,hi"] {
            assert_eq!(extract_host(url), None, "url {url:?}");
        }
    }

    #[test]
    fn check_navigation_blocks_scheme_relative_urls() {
        let guard = fail_closed_blocking("169.254.169.254");
        assert_eq!(
            guard.check_navigation("//169.254.169.254/latest"),
            Verdict::Deny
        );
    }

    #[test]
    fn check_navigation_blocks_urls_with_userinfo() {
        let guard = fail_closed_blocking("blocked.example");
        assert_eq!(
            guard.check_navigation("https://user@blocked.example/path"),
            Verdict::Deny
        );
    }

    #[test]
    fn check_navigation_blocks_bracketed_ipv6_hosts() {
        let guard = fail_closed_blocking("fd00:ec2::254");
        assert_eq!(
            guard.check_navigation("https://[fd00:ec2::254]/latest"),
            Verdict::Deny
        );
    }

    #[test]
    fn check_navigation_blocks_query_and_fragment_only_urls() {
        let guard = fail_closed_blocking("blocked.example");
        for url in [
            "https://blocked.example?redir=1",
            "https://blocked.example#anchor",
        ] {
            assert_eq!(guard.check_navigation(url), Verdict::Deny, "url {url:?}");
        }
    }

    #[test]
    fn check_navigation_blocks_mixed_case_scheme_urls() {
        let guard = fail_closed_blocking("169.254.169.254");
        assert_eq!(
            guard.check_navigation("HTTPS://169.254.169.254/latest"),
            Verdict::Deny
        );
    }

    #[test]
    fn is_screenshot_verb_matches_common_names() {
        for verb in ["screenshot", "capture_screen"] {
            assert!(ComputerUseGuard::is_screenshot_verb(verb), "verb {verb:?}");
        }
        assert!(!ComputerUseGuard::is_screenshot_verb("click"));
    }

    #[test]
    fn extract_cua_action_type_reads_args() {
        let arguments = serde_json::json!({ "action_type": "remote.clipboard" });
        let extracted = ComputerUseGuard::extract_cua_action_type("unknown", &arguments);
        assert_eq!(extracted.as_deref(), Some("remote.clipboard"));
    }
}