use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Version identifier of this client library.
pub const VERSION: &str = "0.4.0";
/// HTTP client for the Boundary engine's `/v1/evaluate` and `/health` endpoints.
///
/// Cloning is cheap: the inner `reqwest::Client` shares its connection pool
/// between clones.
#[derive(Clone)]
pub struct Client {
    /// Secret sent in the `X-Boundary-Key` header of evaluation requests.
    api_key: String,
    /// Engine root URL, stored without a trailing `/`.
    base_url: String,
    /// Identifier sent as `agent_id` in evaluation requests.
    agent_id: String,
    /// When `true`, a missing engine verdict becomes "allow"; otherwise "block".
    fail_open: bool,
    /// Underlying HTTP client (carries the request timeout).
    http: reqwest::Client,
}

impl std::fmt::Debug for Client {
    /// Formats the client without exposing the API key, since `Debug` output
    /// routinely ends up in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("agent_id", &self.agent_id)
            .field("fail_open", &self.fail_open)
            .finish_non_exhaustive()
    }
}
impl Client {
    /// Creates a client for the engine at `base_url`, authenticating with `api_key`.
    ///
    /// Trailing `/` characters are stripped from `base_url`. Defaults: agent id
    /// `"rust-agent"`, fail-closed, and a 5 second timeout per request.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be constructed (reqwest reports
    /// this when a TLS backend cannot be initialised or the resolver cannot load
    /// the system configuration).
    pub fn new(api_key: &str, base_url: &str) -> Self {
        let http = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(5))
            .build()
            .expect("failed to initialise HTTP client (TLS backend or resolver configuration)");
        Self {
            api_key: api_key.to_string(),
            base_url: base_url.trim_end_matches('/').to_string(),
            agent_id: "rust-agent".to_string(),
            fail_open: false,
            http,
        }
    }

    /// Sets the agent identifier sent with every evaluation request.
    pub fn with_agent_id(mut self, id: &str) -> Self {
        self.agent_id = id.to_string();
        self
    }

    /// Chooses the decision used when the engine yields no verdict: `true`
    /// allows the action, `false` (the default) blocks it.
    pub fn with_fail_open(mut self, fail_open: bool) -> Self {
        self.fail_open = fail_open;
        self
    }

    /// Asks the engine whether `action` may proceed.
    ///
    /// Makes up to three attempts, with 100 ms and then 200 ms of backoff between
    /// them. Transport errors, 5xx responses and `429 Too Many Requests` are
    /// retried. Any other non-success status stops retrying immediately.
    ///
    /// When no verdict is obtained, the fail-open/fail-closed setting decides the
    /// outcome. This includes a 2xx response whose body cannot be read or parsed.
    /// The decision's `reason` records the last failure.
    ///
    /// NOTE(review): non-retryable statuses such as 401/403 also fall through to
    /// the fail policy. A fail-open client therefore allows actions when its API
    /// key is rejected. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; the `Result` is kept for API compatibility.
    pub async fn evaluate(&self, action: Action) -> Result<Decision, Box<dyn std::error::Error>> {
        const MAX_ATTEMPTS: u32 = 3;
        const BASE_BACKOFF_MS: u64 = 100;

        let payload = EvalRequest {
            agent_id: self.agent_id.clone(),
            action: ActionReq {
                action_type: action.action_type,
                scope: action.scope,
                count: action.count,
                params: action.params,
            },
        };
        // The target URL is the same for every attempt, so build it once.
        let url = format!("{}/v1/evaluate", self.base_url);
        let mut last_err = String::new();

        for attempt in 0..MAX_ATTEMPTS {
            if attempt > 0 {
                // Exponential backoff: 100 ms, then 200 ms.
                let delay_ms = BASE_BACKOFF_MS * 2u64.pow(attempt - 1);
                tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
            }
            // `.json()` also sets `Content-Type: application/json`.
            let sent = self
                .http
                .post(url.as_str())
                .header("X-Boundary-Key", &self.api_key)
                .json(&payload)
                .send()
                .await;
            match sent {
                Ok(resp) if resp.status().is_success() => {
                    match resp.json::<EvalResponse>().await {
                        Ok(body) => return Ok(Self::decision_from_response(body)),
                        Err(e) => {
                            // The engine answered, but its body was unusable. Treat this
                            // as an engine failure so the fail policy applies rather than
                            // returning an `Err` that sidesteps it. Not retried.
                            last_err = format!("invalid response body: {}", e);
                            break;
                        }
                    }
                }
                Ok(resp) => {
                    let status = resp.status().as_u16();
                    last_err = format!("HTTP {}", status);
                    // Only 5xx and 429 are worth retrying.
                    if status < 500 && status != 429 {
                        break;
                    }
                }
                Err(e) => last_err = e.to_string(),
            }
        }

        Ok(self.fallback_decision(&last_err))
    }

    /// Returns `true` if `GET {base_url}/health` answers with a 2xx status.
    ///
    /// Any request error (connection failure, timeout, ...) yields `false`.
    /// No API key header is sent.
    pub async fn health(&self) -> bool {
        let url = format!("{}/health", self.base_url);
        match self.http.get(url).send().await {
            Ok(resp) => resp.status().is_success(),
            Err(_) => false,
        }
    }

    /// Maps a successful engine response onto a [`Decision`], deriving the
    /// boolean flags from the verdict string.
    fn decision_from_response(body: EvalResponse) -> Decision {
        Decision {
            allowed: body.decision == "allow",
            blocked: body.decision == "block",
            needs_confirm: body.decision == "confirm",
            decision: body.decision,
            reason: body.reason,
            boundary_rule: body.boundary_rule.unwrap_or_default(),
            eval_time_ms: body.evaluation_time_ms,
            audit_id: body.audit_id.unwrap_or_default(),
        }
    }

    /// Builds the local decision used when the engine produced no verdict,
    /// according to `fail_open`.
    fn fallback_decision(&self, last_err: &str) -> Decision {
        let (verdict, mode) = if self.fail_open {
            ("allow", "fail-open")
        } else {
            ("block", "fail-closed")
        };
        Decision {
            allowed: self.fail_open,
            blocked: !self.fail_open,
            needs_confirm: false,
            decision: verdict.into(),
            reason: format!("Engine unreachable ({}): {}", mode, last_err),
            boundary_rule: String::new(),
            eval_time_ms: 0.0,
            audit_id: String::new(),
        }
    }
}
/// An agent action to be checked by [`Client::evaluate`].
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Action {
    /// Action category, sent to the engine under the JSON key `type`
    /// (e.g. `"system.command"`).
    pub action_type: String,
    /// What the action targets, e.g. a command line or a URL.
    pub scope: String,
    /// Count forwarded to the engine as-is; its meaning is engine-defined (TODO confirm).
    pub count: i64,
    /// Extra key/value context; omitted from the request when empty.
    pub params: HashMap<String, String>,
}
/// Outcome of [`Client::evaluate`].
///
/// `allowed`, `blocked` and `needs_confirm` are convenience flags set from
/// `decision` being `"allow"`, `"block"` or `"confirm"` respectively. For any
/// other verdict string, all three are `false`.
#[derive(Debug, Clone, PartialEq)]
pub struct Decision {
    /// Verdict string from the engine, or `"allow"`/`"block"` for local
    /// fail-open/fail-closed decisions.
    pub decision: String,
    /// Human-readable explanation of the verdict.
    pub reason: String,
    /// Rule that produced the verdict; empty when none was reported.
    pub boundary_rule: String,
    /// Engine-reported evaluation time in milliseconds; `0.0` for local decisions.
    pub eval_time_ms: f64,
    /// Audit record identifier; empty when none was reported.
    pub audit_id: String,
    /// `true` when `decision == "allow"`.
    pub allowed: bool,
    /// `true` when `decision == "block"`.
    pub blocked: bool,
    /// `true` when `decision == "confirm"`.
    pub needs_confirm: bool,
}
/// JSON body POSTed to the engine's `/v1/evaluate` endpoint.
#[derive(Serialize)]
struct EvalRequest {
    /// Identifier of the calling agent.
    agent_id: String,
    /// The action being evaluated, in wire form.
    action: ActionReq,
}
/// Wire form of an action inside [`EvalRequest`].
///
/// Field order is significant: it is the order of keys in the serialized JSON.
#[derive(Serialize)]
struct ActionReq {
    /// Emitted under the JSON key `"type"`.
    #[serde(rename = "type")]
    action_type: String,
    scope: String,
    count: i64,
    /// Left out of the JSON entirely when the map is empty.
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    params: HashMap<String, String>,
}
/// JSON body returned by `/v1/evaluate` on a successful (2xx) response.
#[derive(Deserialize)]
struct EvalResponse {
    /// Verdict string; the client recognises `"allow"`, `"block"` and `"confirm"`.
    decision: String,
    /// Human-readable explanation.
    reason: String,
    /// May be absent from the JSON, in which case it is `None`.
    boundary_rule: Option<String>,
    /// Engine-side evaluation time in milliseconds.
    evaluation_time_ms: f64,
    /// May be absent from the JSON, in which case it is `None`.
    audit_id: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address used as an engine endpoint that presumably refuses connections.
    const DEAD_ENGINE: &str = "http://127.0.0.1:1";

    /// Parses an engine-response fixture, failing the test on malformed JSON.
    fn parse_response(raw: &str) -> EvalResponse {
        serde_json::from_str(raw).unwrap()
    }

    /// Applies the same response-to-decision mapping the client uses for 2xx replies.
    fn to_decision(resp: EvalResponse) -> Decision {
        let allowed = resp.decision == "allow";
        let blocked = resp.decision == "block";
        let needs_confirm = resp.decision == "confirm";
        Decision {
            decision: resp.decision,
            reason: resp.reason,
            boundary_rule: resp.boundary_rule.unwrap_or_default(),
            eval_time_ms: resp.evaluation_time_ms,
            audit_id: resp.audit_id.unwrap_or_default(),
            allowed,
            blocked,
            needs_confirm,
        }
    }

    /// Minimal action used by the network-failure tests.
    fn probe_action() -> Action {
        Action {
            action_type: "test".into(),
            scope: "test".into(),
            ..Default::default()
        }
    }

    #[test]
    fn test_client_creation() {
        let c = Client::new("bai_test_key", "https://engine.boundaryai.ai");
        assert_eq!(c.api_key, "bai_test_key");
        assert_eq!(c.base_url, "https://engine.boundaryai.ai");
        assert_eq!(c.agent_id, "rust-agent");
        assert!(!c.fail_open);
    }

    #[test]
    fn test_client_trailing_slash_stripped() {
        let c = Client::new("bai_key", "https://engine.boundaryai.ai/");
        assert_eq!(c.base_url, "https://engine.boundaryai.ai");
    }

    #[test]
    fn test_client_builder_methods() {
        let c = Client::new("bai_key", "http://localhost:8080")
            .with_agent_id("custom-agent")
            .with_fail_open(true);
        assert_eq!(c.agent_id, "custom-agent");
        assert!(c.fail_open);
    }

    #[test]
    fn test_action_default() {
        let Action {
            action_type,
            scope,
            count,
            params,
        } = Action::default();
        assert_eq!(action_type, "");
        assert_eq!(scope, "");
        assert_eq!(count, 0);
        assert!(params.is_empty());
    }

    #[test]
    fn test_action_builder() {
        let action = Action {
            action_type: "system.command".into(),
            scope: "rm -rf /tmp/test".into(),
            count: 1,
            params: HashMap::from([("cwd".to_string(), "/home/user".to_string())]),
        };
        assert_eq!(action.action_type, "system.command");
        assert_eq!(action.scope, "rm -rf /tmp/test");
        assert_eq!(action.count, 1);
        assert_eq!(action.params.get("cwd").map(String::as_str), Some("/home/user"));
    }

    #[test]
    fn test_decision_parsing_block() {
        let d = to_decision(parse_response(
            r#"{
            "decision": "block",
            "reason": "Destructive command detected",
            "boundary_rule": "rule_system_command_block",
            "evaluation_time_ms": 2.5,
            "audit_id": "aud_123"
        }"#,
        ));
        assert!(d.blocked);
        assert!(!d.allowed);
        assert!(!d.needs_confirm);
        assert_eq!(d.reason, "Destructive command detected");
        assert_eq!(d.boundary_rule, "rule_system_command_block");
        assert!((d.eval_time_ms - 2.5).abs() < f64::EPSILON);
        assert_eq!(d.audit_id, "aud_123");
    }

    #[test]
    fn test_decision_parsing_allow() {
        let d = to_decision(parse_response(
            r#"{
            "decision": "allow",
            "reason": "Safe operation",
            "evaluation_time_ms": 0.5
        }"#,
        ));
        assert!(d.allowed);
        assert!(!d.blocked);
        assert!(!d.needs_confirm);
        assert_eq!(d.reason, "Safe operation");
    }

    #[test]
    fn test_decision_parsing_confirm() {
        let d = to_decision(parse_response(
            r#"{
            "decision": "confirm",
            "reason": "Human review required",
            "boundary_rule": "rule_sensitive_action",
            "evaluation_time_ms": 1.2,
            "audit_id": "aud_456"
        }"#,
        ));
        assert!(d.needs_confirm);
        assert!(!d.allowed);
        assert!(!d.blocked);
    }

    #[test]
    fn test_decision_parsing_missing_optional_fields() {
        let resp = parse_response(
            r#"{
            "decision": "allow",
            "reason": "ok",
            "evaluation_time_ms": 0.1
        }"#,
        );
        assert!(resp.boundary_rule.is_none());
        assert!(resp.audit_id.is_none());
    }

    #[test]
    fn test_eval_request_serialization() {
        let req = EvalRequest {
            agent_id: "test-agent".into(),
            action: ActionReq {
                action_type: "system.command".into(),
                scope: "ls -la".into(),
                count: 0,
                params: HashMap::new(),
            },
        };
        let v = serde_json::to_value(&req).unwrap();
        assert_eq!(v["agent_id"], "test-agent");
        assert_eq!(v["action"]["type"], "system.command");
        assert_eq!(v["action"]["scope"], "ls -la");
        assert!(v["action"].get("params").is_none());
    }

    #[test]
    fn test_eval_request_with_params() {
        let req = EvalRequest {
            agent_id: "agent-1".into(),
            action: ActionReq {
                action_type: "api.call".into(),
                scope: "https://api.example.com".into(),
                count: 5,
                params: HashMap::from([("env".to_string(), "production".to_string())]),
            },
        };
        let v = serde_json::to_value(&req).unwrap();
        assert_eq!(v["action"]["params"]["env"], "production");
        assert_eq!(v["action"]["count"], 5);
    }

    #[test]
    fn test_version() {
        assert_eq!(VERSION, "0.4.0");
        assert!(!VERSION.is_empty());
    }

    #[tokio::test]
    async fn test_health_unreachable() {
        let client = Client::new("bai_test", DEAD_ENGINE);
        assert!(!client.health().await);
    }

    #[tokio::test]
    async fn test_evaluate_fail_closed() {
        let client = Client::new("bai_test", DEAD_ENGINE);
        let decision = client.evaluate(probe_action()).await.unwrap();
        assert!(decision.blocked);
        assert!(!decision.allowed);
        assert!(decision.reason.contains("fail-closed"));
    }

    #[tokio::test]
    async fn test_evaluate_fail_open() {
        let client = Client::new("bai_test", DEAD_ENGINE).with_fail_open(true);
        let decision = client.evaluate(probe_action()).await.unwrap();
        assert!(decision.allowed);
        assert!(!decision.blocked);
        assert!(decision.reason.contains("fail-open"));
    }
}