use std::collections::HashSet;
use serde::{Deserialize, Serialize};
use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
use crate::action::{extract_action, ToolAction};
use crate::external::TokenBucket;
/// Returns the action types that [`ComputerUseConfig`] allowlists when none are configured.
///
/// Covers remote-session lifecycle (connect / disconnect / reconnect), input injection,
/// and the remote clipboard, file-transfer, audio, drive-mapping, printing and
/// session-sharing channels.
pub fn default_allowed_action_types() -> Vec<String> {
    const DEFAULT_ACTIONS: [&str; 10] = [
        "remote.session.connect",
        "remote.session.disconnect",
        "remote.session.reconnect",
        "input.inject",
        "remote.clipboard",
        "remote.file_transfer",
        "remote.audio",
        "remote.drive_mapping",
        "remote.printing",
        "remote.session_share",
    ];
    DEFAULT_ACTIONS
        .iter()
        .map(|&action| String::from(action))
        .collect()
}
/// How strictly the computer-use guard in this module enforces its policy.
///
/// Serialized in snake_case: `"observe"`, `"guardrail"`, `"fail_closed"`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EnforcementMode {
    /// Every check in this guard resolves to `Verdict::Allow` (nothing is denied).
    Observe,
    /// Default. Denies blocked domains and exhausted screenshot budgets, but allows
    /// action types and domains that merely fall outside an allowlist.
    #[default]
    Guardrail,
    /// Everything `Guardrail` denies, plus any action type or domain that is not
    /// explicitly allowlisted.
    FailClosed,
}
/// Deserializable configuration for `ComputerUseGuard`.
///
/// Every field has a serde default, so any subset may be supplied; unknown
/// fields are rejected (`deny_unknown_fields`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct ComputerUseConfig {
    /// Whether the guard is active; when `false` every request is allowed.
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// `remote.*` / `input.*` action types treated as allowlisted.
    /// Defaults to [`default_allowed_action_types`].
    #[serde(default = "default_allowed_action_types")]
    pub allowed_action_types: Vec<String>,
    /// Enforcement strictness. Defaults to [`EnforcementMode::Guardrail`].
    #[serde(default)]
    pub mode: EnforcementMode,
    /// Domain patterns (exact host or `*.suffix`) whose navigation is denied in
    /// `Guardrail` and `FailClosed` modes. Defaults to empty.
    #[serde(default)]
    pub blocked_domains: Vec<String>,
    /// When non-empty, navigation to hosts matching none of these patterns is
    /// denied in `FailClosed` mode. Defaults to empty (no allowlist).
    #[serde(default)]
    pub allowed_domains: Vec<String>,
    /// Screenshot rate handed to the token bucket. `None`, zero, negative, NaN or
    /// infinite values disable screenshot rate limiting.
    #[serde(default)]
    pub screenshot_rate_per_second: Option<f64>,
    /// Token-bucket burst size for screenshots; defaults to 5 when a rate is set
    /// and is clamped to at least 1.
    #[serde(default)]
    pub screenshot_burst: Option<u32>,
}
/// Serde default for boolean fields that are on unless configured otherwise
/// (used by `ComputerUseConfig::enabled`).
fn default_true() -> bool {
    true
}
impl Default for ComputerUseConfig {
    /// Matches the serde defaults: enabled, `Guardrail` mode, the built-in action
    /// allowlist, no domain filters and no screenshot rate limit.
    fn default() -> Self {
        let allowed_action_types = default_allowed_action_types();
        Self {
            mode: EnforcementMode::Guardrail,
            enabled: true,
            allowed_action_types,
            allowed_domains: vec![],
            blocked_domains: vec![],
            screenshot_burst: None,
            screenshot_rate_per_second: None,
        }
    }
}
/// Guard that enforces a [`ComputerUseConfig`] policy on computer-use and
/// browser tool calls. Construct with `ComputerUseGuard::new` or
/// `ComputerUseGuard::with_config`.
pub struct ComputerUseGuard {
    /// When `false`, every request is allowed without further checks.
    enabled: bool,
    /// Enforcement strictness applied to every check.
    mode: EnforcementMode,
    /// Allowlisted `remote.*` / `input.*` action types.
    allowed_actions: HashSet<String>,
    /// Domain patterns whose navigation is denied (except in `Observe` mode).
    blocked_domains: Vec<String>,
    /// If non-empty, navigation to hosts outside these patterns is denied in
    /// `FailClosed` mode.
    allowed_domains: Vec<String>,
    /// Screenshot rate limiter; `None` disables screenshot rate limiting.
    screenshot_bucket: Option<TokenBucket>,
}
impl ComputerUseGuard {
pub fn new() -> Self {
Self::with_config(ComputerUseConfig::default())
}
pub fn with_config(config: ComputerUseConfig) -> Self {
let allowed_actions: HashSet<String> = config.allowed_action_types.into_iter().collect();
let screenshot_bucket = match config.screenshot_rate_per_second {
Some(rate) if rate > 0.0 && rate.is_finite() => {
let burst = config.screenshot_burst.unwrap_or(5).max(1);
Some(TokenBucket::new(rate, burst))
}
_ => None,
};
Self {
enabled: config.enabled,
mode: config.mode,
allowed_actions,
blocked_domains: config.blocked_domains,
allowed_domains: config.allowed_domains,
screenshot_bucket,
}
}
fn is_screenshot_verb(verb: &str) -> bool {
let v = verb.to_ascii_lowercase();
matches!(
v.as_str(),
"screenshot"
| "screen_capture"
| "screen_shot"
| "capture"
| "capture_screen"
| "browser_screenshot"
)
}
fn extract_cua_action_type<'a>(
tool_name: &'a str,
arguments: &'a serde_json::Value,
) -> Option<String> {
if tool_name.starts_with("remote.") || tool_name.starts_with("input.") {
return Some(tool_name.to_string());
}
for key in ["action_type", "actionType", "custom_type", "customType"] {
if let Some(value) = arguments.get(key).and_then(|v| v.as_str()) {
if value.starts_with("remote.") || value.starts_with("input.") {
return Some(value.to_string());
}
}
}
None
}
fn apply_mode(&self, in_allowlist: bool) -> Verdict {
match (self.mode, in_allowlist) {
(EnforcementMode::Observe, _) => Verdict::Allow,
(EnforcementMode::Guardrail, _) => Verdict::Allow,
(EnforcementMode::FailClosed, true) => Verdict::Allow,
(EnforcementMode::FailClosed, false) => Verdict::Deny,
}
}
fn check_navigation(&self, target: &str) -> Verdict {
if self.blocked_domains.is_empty() && self.allowed_domains.is_empty() {
return Verdict::Allow;
}
let host = match extract_host(target) {
Some(host) => host,
None => {
return Verdict::Allow;
}
};
let blocked = self
.blocked_domains
.iter()
.any(|pat| matches_domain(pat, &host));
if blocked {
return match self.mode {
EnforcementMode::Observe => Verdict::Allow,
EnforcementMode::Guardrail | EnforcementMode::FailClosed => Verdict::Deny,
};
}
if !self.allowed_domains.is_empty() {
let allowed = self
.allowed_domains
.iter()
.any(|pat| matches_domain(pat, &host));
if !allowed {
return match self.mode {
EnforcementMode::Observe | EnforcementMode::Guardrail => Verdict::Allow,
EnforcementMode::FailClosed => Verdict::Deny,
};
}
}
Verdict::Allow
}
}
impl Default for ComputerUseGuard {
fn default() -> Self {
Self::new()
}
}
impl Guard for ComputerUseGuard {
    fn name(&self) -> &str {
        "computer-use"
    }

    /// Evaluates a tool call against the computer-use policy.
    ///
    /// Checks, in order:
    /// 1. A disabled guard allows everything.
    /// 2. `remote.*` / `input.*` action types are checked against the allowlist.
    /// 3. Browser screenshot verbs consume a token from the optional rate limiter.
    /// 4. Browser navigation verbs (`navigate`, `goto`, `open`) are checked
    ///    against the domain block/allow lists.
    ///
    /// A request that triggers several checks is denied if any of them denies.
    fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
        if !self.enabled {
            return Ok(Verdict::Allow);
        }
        if let Some(action_type) =
            Self::extract_cua_action_type(&ctx.request.tool_name, &ctx.request.arguments)
        {
            let in_allowlist = self.allowed_actions.contains(&action_type);
            let verdict = self.apply_mode(in_allowlist);
            // Only a denial is final here. Returning an allow early would let any
            // request carrying an allowlisted `action_type` argument skip the
            // screenshot rate limit and the navigation domain policy below.
            if matches!(verdict, Verdict::Deny) {
                return Ok(verdict);
            }
        }
        let action = extract_action(&ctx.request.tool_name, &ctx.request.arguments);
        if let ToolAction::BrowserAction { verb, target } = &action {
            if Self::is_screenshot_verb(verb) {
                if let Some(bucket) = &self.screenshot_bucket {
                    if !bucket.try_acquire() {
                        return Ok(match self.mode {
                            EnforcementMode::Observe => Verdict::Allow,
                            EnforcementMode::Guardrail | EnforcementMode::FailClosed => {
                                Verdict::Deny
                            }
                        });
                    }
                }
                return Ok(Verdict::Allow);
            }
            if matches!(
                verb.to_ascii_lowercase().as_str(),
                "navigate" | "goto" | "open"
            ) {
                if let Some(url) = target {
                    return Ok(self.check_navigation(url));
                }
            }
        }
        Ok(Verdict::Allow)
    }
}
/// Returns whether `host` matches the domain `pattern`.
///
/// A pattern is either an exact host (`example.com`, `169.254.169.254`,
/// `fd00::1` or `[fd00::1]`) or a wildcard `*.example.com`, which matches
/// `example.com` itself and any subdomain of it.
///
/// Both sides are normalized before comparing:
/// - ASCII case is ignored and surrounding whitespace trimmed.
/// - Trailing dots are ignored (fully-qualified `example.com.`).
/// - One pair of brackets around an IPv6 literal is ignored.
///
/// Empty patterns or hosts never match.
fn matches_domain(pattern: &str, host: &str) -> bool {
    let pattern = normalize_domain(pattern);
    let host = normalize_domain(host);
    if pattern.is_empty() || host.is_empty() {
        return false;
    }
    match pattern.strip_prefix("*.") {
        // Compare suffixes in place instead of allocating `.{suffix}`; requiring
        // a `.` before the suffix keeps `evilexample.com` from matching.
        Some(suffix) => {
            host == suffix
                || matches!(host.strip_suffix(suffix), Some(prefix) if prefix.ends_with('.'))
        }
        None => pattern == host,
    }
}

/// Lowercases and trims `raw`, strips one pair of surrounding IPv6 brackets and
/// any trailing dots, so equivalent spellings of a host compare equal.
fn normalize_domain(raw: &str) -> String {
    let lowered = raw.trim().to_ascii_lowercase();
    let unbracketed = lowered
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(lowered.as_str());
    unbracketed.trim_end_matches('.').to_string()
}
/// Extracts the lowercased host from a navigation target, or `None` when the
/// target has no host this guard can check.
///
/// Accepts the following forms:
/// - absolute URLs with any `scheme:` followed by `/` or `\` separators
///   (`https://`, `wss://`, `ftp://`, `https:\\`, ...);
/// - scheme-relative targets (`//host`, `\\host`);
/// - bare `host[:port][/path]` strings.
///
/// Userinfo (`user:pass@`), ports and bracketed IPv6 literals are stripped.
///
/// Browser-style URL leniency is mirrored where ignoring it would allow a
/// policy bypass:
/// - ASCII tab/CR/LF characters are removed;
/// - `\` ends the authority just like `/`.
///
/// Returns `None` in these cases:
/// - empty targets;
/// - targets starting with `#`, `.`, `[` or a single `/`;
/// - `data:`, `javascript:`, `about:` and `file:` URLs.
fn extract_host(url: &str) -> Option<String> {
    // Browsers drop ASCII tab and newline characters anywhere in a URL, so
    // `ht\ttps://evil` navigates to `https://evil`; strip them before parsing.
    let cleaned: String = url
        .chars()
        .filter(|&c| !matches!(c, '\t' | '\n' | '\r'))
        .collect();
    let url = cleaned.trim();
    if url.is_empty() {
        return None;
    }
    // Presumably fragments, relative paths or CSS-selector-like targets rather
    // than hosts; they are never treated as hosts.
    if url.starts_with('#') || url.starts_with('.') || url.starts_with('[') {
        return None;
    }
    let lowered = url.to_ascii_lowercase();
    if lowered.starts_with("data:")
        || lowered.starts_with("javascript:")
        || lowered.starts_with("about:")
        || lowered.starts_with("file:")
    {
        return None;
    }
    let rest = strip_scheme_prefix(url);
    // The authority ends at the first path, query or fragment delimiter;
    // browsers treat `\` like `/` here for web schemes.
    let host_with_port = rest.split(['/', '\\', '?', '#']).next().unwrap_or(rest);
    let host_without_userinfo = host_with_port
        .rsplit_once('@')
        .map(|(_, host)| host)
        .unwrap_or(host_with_port);
    let host = if let Some(bracketed) = host_without_userinfo.strip_prefix('[') {
        let (host, remainder) = bracketed.split_once(']')?;
        if !remainder.is_empty() && !remainder.starts_with(':') {
            return None;
        }
        host
    } else {
        host_without_userinfo
            .rsplit_once(':')
            .map(|(h, _)| h)
            .unwrap_or(host_without_userinfo)
    }
    .trim_matches(|c: char| c == '/' || c == '.');
    if host.is_empty() {
        return None;
    }
    Some(host.to_ascii_lowercase())
}

/// Whether `c` separates URL components (`/`, or `\` as browsers accept it).
fn is_url_separator(c: char) -> bool {
    c == '/' || c == '\\'
}

/// Whether `s` is a syntactically valid URL scheme per RFC 3986:
/// `ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`.
fn is_url_scheme(s: &str) -> bool {
    let mut chars = s.chars();
    matches!(chars.next(), Some(c) if c.is_ascii_alphabetic())
        && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
}

/// Removes a leading prefix so the result starts at the authority; returns
/// `url` unchanged when there is nothing to strip.
///
/// Two prefix forms are recognized:
/// - `scheme:` followed by one or more `/` or `\` characters;
/// - a scheme-relative run of two or more separators.
///
/// `host:port` (no separator after the colon) is left intact.
fn strip_scheme_prefix(url: &str) -> &str {
    if let Some((scheme, after)) = url.split_once(':') {
        if is_url_scheme(scheme) && after.starts_with(is_url_separator) {
            return after.trim_start_matches(is_url_separator);
        }
    }
    let without_separators = url.trim_start_matches(is_url_separator);
    // Separators are ASCII, so the byte difference counts them. A single
    // leading `/` is a path-relative target, not a scheme-relative host.
    if url.len() - without_separators.len() >= 2 {
        without_separators
    } else {
        url
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fail-closed guard whose only domain policy is one blocked pattern.
    fn fail_closed_blocking(blocked: &str) -> ComputerUseGuard {
        ComputerUseGuard::with_config(ComputerUseConfig {
            mode: EnforcementMode::FailClosed,
            blocked_domains: vec![blocked.to_string()],
            ..ComputerUseConfig::default()
        })
    }

    #[test]
    fn matches_domain_exact_and_wildcard() {
        let cases = [
            ("example.com", "example.com", true),
            ("example.com", "evil.com", false),
            ("*.example.com", "api.example.com", true),
            ("*.example.com", "example.com", true),
            ("*.example.com", "example.org", false),
        ];
        for (pattern, host, expected) in cases {
            assert_eq!(
                matches_domain(pattern, host),
                expected,
                "pattern {pattern:?} vs host {host:?}"
            );
        }
    }

    #[test]
    fn extract_host_handles_common_urls() {
        let cases: [(&str, Option<&str>); 11] = [
            ("https://example.com/x", Some("example.com")),
            ("HTTPS://169.254.169.254/latest", Some("169.254.169.254")),
            ("https://user:pass@example.com:8443/x", Some("example.com")),
            ("https://user@[fd00:ec2::254]:8443/x", Some("fd00:ec2::254")),
            ("http://localhost:8080", Some("localhost")),
            ("example.com:443/y", Some("example.com")),
            ("//169.254.169.254/latest", Some("169.254.169.254")),
            ("https://blocked.example?redir=1", Some("blocked.example")),
            ("https://blocked.example#anchor", Some("blocked.example")),
            ("#submit", None),
            ("data:text/plain,hi", None),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_host(input).as_deref(), expected, "input {input:?}");
        }
    }

    #[test]
    fn check_navigation_blocks_scheme_relative_urls() {
        let guard = fail_closed_blocking("169.254.169.254");
        assert_eq!(
            guard.check_navigation("//169.254.169.254/latest"),
            Verdict::Deny
        );
    }

    #[test]
    fn check_navigation_blocks_urls_with_userinfo() {
        let guard = fail_closed_blocking("blocked.example");
        assert_eq!(
            guard.check_navigation("https://user@blocked.example/path"),
            Verdict::Deny
        );
    }

    #[test]
    fn check_navigation_blocks_bracketed_ipv6_hosts() {
        let guard = fail_closed_blocking("fd00:ec2::254");
        assert_eq!(
            guard.check_navigation("https://[fd00:ec2::254]/latest"),
            Verdict::Deny
        );
    }

    #[test]
    fn check_navigation_blocks_query_and_fragment_only_urls() {
        let guard = fail_closed_blocking("blocked.example");
        for target in [
            "https://blocked.example?redir=1",
            "https://blocked.example#anchor",
        ] {
            assert_eq!(
                guard.check_navigation(target),
                Verdict::Deny,
                "target {target:?}"
            );
        }
    }

    #[test]
    fn check_navigation_blocks_mixed_case_scheme_urls() {
        let guard = fail_closed_blocking("169.254.169.254");
        assert_eq!(
            guard.check_navigation("HTTPS://169.254.169.254/latest"),
            Verdict::Deny
        );
    }

    #[test]
    fn is_screenshot_verb_matches_common_names() {
        assert!(ComputerUseGuard::is_screenshot_verb("screenshot"));
        assert!(ComputerUseGuard::is_screenshot_verb("capture_screen"));
        assert!(!ComputerUseGuard::is_screenshot_verb("click"));
    }

    #[test]
    fn extract_cua_action_type_reads_args() {
        let args = serde_json::json!({ "action_type": "remote.clipboard" });
        let extracted = ComputerUseGuard::extract_cua_action_type("unknown", &args);
        assert_eq!(extracted.as_deref(), Some("remote.clipboard"));
    }
}