use std::collections::HashSet;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
/// Returns the input types permitted when a configuration does not list any:
/// `"keyboard"`, `"mouse"` and `"touch"`, in that order.
///
/// Also used as the serde default for
/// `InputInjectionCapabilityConfig::allowed_input_types`.
pub fn default_allowed_input_types() -> Vec<String> {
    ["keyboard", "mouse", "touch"]
        .iter()
        .map(|kind| kind.to_string())
        .collect()
}
/// Serializable settings for the input-injection capability guard.
///
/// Unknown keys are rejected during deserialization (`deny_unknown_fields`);
/// every field falls back to its default when omitted.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct InputInjectionCapabilityConfig {
    /// Master switch. When `false` the guard allows every request without
    /// inspecting it. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Input types an injection request may declare; any other declared type
    /// is denied. Defaults to `default_allowed_input_types()`.
    #[serde(default = "default_allowed_input_types")]
    pub allowed_input_types: Vec<String>,
    /// When `true`, an injection request must also carry a non-empty string
    /// under `postcondition_probe_hash` / `postconditionProbeHash`.
    /// Defaults to `false`.
    #[serde(default)]
    pub require_postcondition_probe: bool,
    /// When `true`, an injection request that declares no input type is
    /// denied; when `false`, it is allowed. Defaults to `true`.
    #[serde(default = "default_true")]
    pub strict: bool,
}
/// Serde default helper: supplies `true` for the boolean config fields that
/// reference it via `#[serde(default = "default_true")]`.
fn default_true() -> bool {
    true
}
impl Default for InputInjectionCapabilityConfig {
fn default() -> Self {
Self {
enabled: true,
allowed_input_types: default_allowed_input_types(),
require_postcondition_probe: false,
strict: true,
}
}
}
/// Guard that vets input-injection tool calls against a set of allowed input
/// types and, optionally, the presence of a postcondition probe hash.
///
/// Built from an `InputInjectionCapabilityConfig`; the derives make the guard
/// printable in diagnostics and cheap to duplicate.
#[derive(Clone, Debug)]
pub struct InputInjectionCapabilityGuard {
    /// When `false`, every request is allowed without inspection.
    enabled: bool,
    /// Input types an injection request may declare.
    allowed_types: HashSet<String>,
    /// Whether an injection request must carry a non-empty probe hash.
    require_postcondition_probe: bool,
    /// Whether an injection request without a declared input type is denied.
    strict: bool,
}
impl InputInjectionCapabilityGuard {
pub fn new() -> Self {
Self::with_config(InputInjectionCapabilityConfig::default())
}
pub fn with_config(config: InputInjectionCapabilityConfig) -> Self {
Self {
enabled: config.enabled,
allowed_types: config.allowed_input_types.into_iter().collect(),
require_postcondition_probe: config.require_postcondition_probe,
strict: config.strict,
}
}
fn is_injection(tool_name: &str, arguments: &Value) -> bool {
if tool_name == "input.inject" || tool_name == "input_inject" {
return true;
}
for key in ["action_type", "actionType", "custom_type", "customType"] {
if let Some(v) = arguments.get(key).and_then(|v| v.as_str()) {
if v == "input.inject" {
return true;
}
}
}
arguments
.get("input_type")
.or_else(|| arguments.get("inputType"))
.and_then(|v| v.as_str())
.is_some()
&& (tool_name == "keyboard"
|| tool_name == "mouse"
|| tool_name == "touch"
|| tool_name == "input")
}
fn input_type(arguments: &Value) -> Option<&str> {
arguments
.get("input_type")
.or_else(|| arguments.get("inputType"))
.and_then(|v| v.as_str())
}
fn has_postcondition_probe(arguments: &Value) -> bool {
arguments
.get("postcondition_probe_hash")
.or_else(|| arguments.get("postconditionProbeHash"))
.and_then(|v| v.as_str())
.is_some_and(|s| !s.is_empty())
}
}
impl Default for InputInjectionCapabilityGuard {
fn default() -> Self {
Self::new()
}
}
impl Guard for InputInjectionCapabilityGuard {
    /// Stable identifier of this guard.
    fn name(&self) -> &str {
        "input-injection-capability"
    }

    /// Evaluates one request and always returns `Ok` with a verdict.
    ///
    /// * Guard disabled, or request not recognised as input injection: allow.
    /// * No input type declared: deny in strict mode, allow otherwise.
    /// * Declared input type not in the allowed set: deny.
    /// * Probe required but no non-empty probe hash present: deny.
    /// * Otherwise: allow.
    fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
        let arguments = &ctx.request.arguments;

        let applies = self.enabled && Self::is_injection(&ctx.request.tool_name, arguments);
        if !applies {
            return Ok(Verdict::Allow);
        }

        let Some(declared) = Self::input_type(arguments) else {
            // NOTE(review): in non-strict mode this returns Allow before the
            // postcondition-probe requirement below is consulted; confirm
            // that skipping the probe check here is intended.
            let verdict = if self.strict {
                Verdict::Deny
            } else {
                Verdict::Allow
            };
            return Ok(verdict);
        };

        let type_allowed = self.allowed_types.contains(declared);
        let probe_satisfied =
            !self.require_postcondition_probe || Self::has_postcondition_probe(arguments);

        Ok(if type_allowed && probe_satisfied {
            Verdict::Allow
        } else {
            Verdict::Deny
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Short alias for the type whose private helpers are exercised below.
    type Subject = InputInjectionCapabilityGuard;

    /// The dedicated tool name alone marks a request as injection.
    #[test]
    fn detects_explicit_input_inject_tool() {
        let arguments = json!({"input_type": "keyboard"});
        let detected = Subject::is_injection("input.inject", &arguments);
        assert!(detected);
    }

    /// An `action_type` of `input.inject` marks any tool as injection.
    #[test]
    fn detects_action_type_argument() {
        let arguments = json!({"action_type": "input.inject", "input_type": "mouse"});
        let detected = Subject::is_injection("generic", &arguments);
        assert!(detected);
    }

    /// Tools without any injection marker are left alone.
    #[test]
    fn ignores_unrelated_tools() {
        let arguments = json!({"path": "/tmp/x"});
        let detected = Subject::is_injection("read_file", &arguments);
        assert!(!detected);
    }

    /// The camelCase spelling of the input-type key is honoured.
    #[test]
    fn input_type_accepts_camel_case() {
        let arguments = json!({"inputType": "keyboard"});
        let declared = Subject::input_type(&arguments);
        assert_eq!(declared, Some("keyboard"));
    }

    /// Both spellings of the probe-hash key are recognised.
    #[test]
    fn postcondition_probe_detected_both_cases() {
        let snake_case = json!({"postcondition_probe_hash": "sha256:abc"});
        let camel_case = json!({"postconditionProbeHash": "sha256:def"});
        for arguments in [&snake_case, &camel_case] {
            assert!(Subject::has_postcondition_probe(arguments));
        }
    }

    /// An empty probe hash counts as no probe at all.
    #[test]
    fn postcondition_probe_empty_string_is_missing() {
        let arguments = json!({"postcondition_probe_hash": ""});
        let present = Subject::has_postcondition_probe(&arguments);
        assert!(!present);
    }
}