use crate::error::WSError;
use serde::{Deserialize, Serialize};
use wsc_attestation::TransformationAttestation;
/// Rego policy engine used to evaluate attestation inputs against policies.
///
/// Wraps a `regorus::Engine` and remembers whether any policy module has been
/// added, so that evaluation can refuse to run without a policy.
pub struct RegoEngine {
    /// Underlying regorus interpreter holding policies, data and the current input.
    engine: regorus::Engine,
    /// Becomes `true` after the first successful `add_policy` call.
    policy_loaded: bool,
}
/// Document passed to Rego policies as `input`.
///
/// `RegoEngine::evaluate` serializes this struct to JSON and converts it to a
/// regorus value, so the field names below are the names policies see
/// (e.g. `input.slsa_level`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegoInput {
    /// Attestation as JSON; `from_attestation` stores `null` here if
    /// serialization of the attestation failed.
    pub attestation: serde_json::Value,
    /// SLSA level supplied by the caller.
    pub slsa_level: u8,
    /// Evaluation timestamp; `from_attestation` fills it with an RFC 3339 UTC time.
    pub current_time: String,
    /// Free-form extra context; deserializes to `null` when the field is absent.
    #[serde(default)]
    pub context: serde_json::Value,
}
/// Outcome of `RegoEngine::evaluate`.
///
/// The derived `Default` has `allowed == false` and empty lists, so any field
/// that evaluation does not populate leaves the result in a denying state.
#[derive(Debug, Clone, Default)]
pub struct RegoResult {
    /// Boolean reading of the policy's `allow` rule; `false` if it could not be evaluated.
    pub allowed: bool,
    /// Messages produced by the policy's `violations` rule, if it evaluated.
    pub violations: Vec<String>,
    /// Messages produced by the policy's `warnings` rule, if it evaluated.
    pub warnings: Vec<String>,
    /// JSON form of the `allow` rule's value, present only when it evaluated.
    pub raw: Option<serde_json::Value>,
}
/// Errors produced while loading or evaluating Rego policies.
#[derive(Debug, Clone)]
pub enum RegoError {
    /// A policy module was rejected by the parser.
    ParseError(String),
    /// The engine failed to evaluate a rule or accept data.
    EvalError(String),
    /// Conversion between JSON and regorus values (or JSON parsing) failed.
    SerdeError(String),
    /// Reading a policy or data file failed; the message includes the path.
    IoError(String),
    /// Evaluation was requested before any policy was added.
    NoPolicyLoaded,
}

impl std::fmt::Display for RegoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoPolicyLoaded => f.write_str("No Rego policy loaded"),
            Self::ParseError(detail) => write!(f, "Rego parse error: {}", detail),
            Self::EvalError(detail) => write!(f, "Rego evaluation error: {}", detail),
            Self::SerdeError(detail) => write!(f, "Rego serialization error: {}", detail),
            Self::IoError(detail) => write!(f, "Rego I/O error: {}", detail),
        }
    }
}

impl std::error::Error for RegoError {}
impl From<RegoError> for WSError {
fn from(e: RegoError) -> Self {
WSError::InternalError(e.to_string())
}
}
impl RegoEngine {
    /// Creates an engine with no policies and no data loaded.
    ///
    /// # Errors
    /// The current implementation always returns `Ok`.
    pub fn new() -> Result<Self, RegoError> {
        let fresh = Self {
            engine: regorus::Engine::new(),
            policy_loaded: false,
        };
        Ok(fresh)
    }

    /// Parses `policy` as a Rego module and adds it to the engine under `name`.
    ///
    /// After a success, `evaluate` and `eval_rule` no longer report
    /// `NoPolicyLoaded`.
    ///
    /// # Errors
    /// `RegoError::ParseError` if regorus rejects the module; the loaded-policy
    /// flag is left untouched in that case.
    pub fn add_policy(&mut self, name: &str, policy: &str) -> Result<(), RegoError> {
        if let Err(e) = self.engine.add_policy(name.to_string(), policy.to_string()) {
            return Err(RegoError::ParseError(e.to_string()));
        }
        self.policy_loaded = true;
        Ok(())
    }

    /// Reads a Rego module from `path` and adds it, using the path as the module name.
    ///
    /// # Errors
    /// `RegoError::IoError` (message prefixed with the path) if the file cannot
    /// be read; otherwise whatever `add_policy` returns.
    pub fn add_policy_file(&mut self, path: &str) -> Result<(), RegoError> {
        match std::fs::read_to_string(path) {
            Ok(source) => self.add_policy(path, &source),
            Err(e) => Err(RegoError::IoError(format!("{}: {}", path, e))),
        }
    }

    /// Converts `data` to a regorus value and hands it to the engine's
    /// `add_data`, making it visible to policies under `data`.
    ///
    /// NOTE(review): regorus' `add_data` may merge with previously added data
    /// rather than replace it, despite this method's name — confirm.
    ///
    /// # Errors
    /// `RegoError::SerdeError` if conversion fails, or `RegoError::EvalError`
    /// if the engine rejects the data.
    pub fn set_data(&mut self, data: serde_json::Value) -> Result<(), RegoError> {
        let converted = json_to_regorus(&data)?;
        match self.engine.add_data(converted) {
            Ok(_) => Ok(()),
            Err(e) => Err(RegoError::EvalError(e.to_string())),
        }
    }

    /// Reads a JSON document from `path` and loads it via `set_data`.
    ///
    /// # Errors
    /// `RegoError::IoError` if the file cannot be read, `RegoError::SerdeError`
    /// if it is not valid JSON, otherwise whatever `set_data` returns.
    pub fn set_data_file(&mut self, path: &str) -> Result<(), RegoError> {
        let text = std::fs::read_to_string(path)
            .map_err(|e| RegoError::IoError(format!("{}: {}", path, e)))?;
        let parsed = serde_json::from_str::<serde_json::Value>(&text)
            .map_err(|e| RegoError::SerdeError(e.to_string()))?;
        self.set_data(parsed)
    }

    /// Evaluates the loaded policy against `input`.
    ///
    /// Each rule is evaluated in package `wsc.policy` first and, only if that
    /// evaluation returns an error, in package `policy`:
    /// - `allow` fills `allowed` and `raw`;
    /// - `violations` and `warnings` fill the matching string lists.
    ///
    /// A rule whose evaluation errors in both packages is skipped silently and
    /// its fields keep their defaults (`allowed == false`, empty lists,
    /// `raw == None`), so an `allow` rule that cannot be evaluated denies.
    ///
    /// # Errors
    /// `RegoError::NoPolicyLoaded` if no policy was added, or
    /// `RegoError::SerdeError` if converting the input or the `allow` value fails.
    pub fn evaluate(&mut self, input: &RegoInput) -> Result<RegoResult, RegoError> {
        if !self.policy_loaded {
            return Err(RegoError::NoPolicyLoaded);
        }

        let as_json =
            serde_json::to_value(input).map_err(|e| RegoError::SerdeError(e.to_string()))?;
        let as_regorus = json_to_regorus(&as_json)?;
        self.engine.set_input(as_regorus);

        let mut outcome = RegoResult::default();

        if let Some(allow) = self.eval_first_ok(&["data.wsc.policy.allow", "data.policy.allow"]) {
            outcome.allowed = regorus_to_bool(&allow);
            outcome.raw = Some(regorus_to_json(&allow)?);
        }
        if let Some(found) =
            self.eval_first_ok(&["data.wsc.policy.violations", "data.policy.violations"])
        {
            outcome.violations = regorus_to_string_set(&found);
        }
        if let Some(found) =
            self.eval_first_ok(&["data.wsc.policy.warnings", "data.policy.warnings"])
        {
            outcome.warnings = regorus_to_string_set(&found);
        }

        Ok(outcome)
    }

    /// Evaluates the given rule paths in order and returns the first value whose
    /// evaluation succeeded, discarding errors; `None` if every path errored.
    fn eval_first_ok(&mut self, paths: &[&str]) -> Option<regorus::Value> {
        paths
            .iter()
            .find_map(|path| self.engine.eval_rule(String::from(*path)).ok())
    }

    /// Evaluates one rule by its full path (e.g. `data.wsc.policy.allow`)
    /// against whatever input was last set by `evaluate` (if any), returning JSON.
    ///
    /// # Errors
    /// `RegoError::NoPolicyLoaded` if no policy was added, `RegoError::EvalError`
    /// if regorus fails to evaluate the rule, or `RegoError::SerdeError` if the
    /// result cannot be converted.
    pub fn eval_rule(&mut self, rule: &str) -> Result<serde_json::Value, RegoError> {
        if !self.policy_loaded {
            return Err(RegoError::NoPolicyLoaded);
        }
        match self.engine.eval_rule(rule.to_string()) {
            Ok(value) => regorus_to_json(&value),
            Err(e) => Err(RegoError::EvalError(e.to_string())),
        }
    }
}
impl Default for RegoEngine {
fn default() -> Self {
Self::new().expect("Failed to create Rego engine")
}
}
impl RegoInput {
    /// Builds an input from `attestation`, stamping `current_time` with the
    /// current UTC time (RFC 3339) and leaving `context` as `null`.
    ///
    /// If the attestation cannot be serialized to JSON, `attestation` is set to
    /// `null` instead of reporting an error, so policies must tolerate a null
    /// attestation.
    pub fn from_attestation(attestation: &TransformationAttestation, slsa_level: u8) -> Self {
        Self {
            attestation: serde_json::to_value(attestation).unwrap_or_default(),
            slsa_level,
            current_time: chrono::Utc::now().to_rfc3339(),
            context: serde_json::Value::default(),
        }
    }

    /// Returns this input with `context` replaced by the given value.
    pub fn with_context(self, context: serde_json::Value) -> Self {
        Self { context, ..self }
    }
}
/// Converts a `serde_json::Value` into the equivalent `regorus::Value`.
///
/// Integers that fit in `i64` or `u64` are handed to regorus as integers;
/// only non-integral numbers go through `f64`. Object keys become regorus
/// string values.
///
/// # Errors
/// `RegoError::SerdeError` if a number has no usable representation or the
/// target object cannot be created.
fn json_to_regorus(json: &serde_json::Value) -> Result<regorus::Value, RegoError> {
    match json {
        serde_json::Value::Null => Ok(regorus::Value::Null),
        serde_json::Value::Bool(b) => Ok(regorus::Value::Bool(*b)),
        serde_json::Value::Number(n) => {
            if let Some(i) = n.as_i64() {
                Ok(regorus::Value::from(i))
            } else if let Some(u) = n.as_u64() {
                // Integers above i64::MAX: keep them exact instead of rounding through f64.
                Ok(regorus::Value::from(u))
            } else if let Some(f) = n.as_f64() {
                Ok(regorus::Value::from(f))
            } else {
                Err(RegoError::SerdeError("Invalid number".to_string()))
            }
        }
        serde_json::Value::String(s) => Ok(regorus::Value::String(s.as_str().into())),
        serde_json::Value::Array(arr) => {
            let items = arr
                .iter()
                .map(json_to_regorus)
                .collect::<Result<Vec<_>, _>>()?;
            Ok(regorus::Value::from(items))
        }
        serde_json::Value::Object(obj) => {
            let mut map = regorus::Value::new_object();
            {
                // Fetch the backing map once rather than on every entry.
                let entries = map
                    .as_object_mut()
                    .map_err(|_| RegoError::SerdeError("Failed to create object".to_string()))?;
                for (k, v) in obj {
                    entries.insert(regorus::Value::String(k.as_str().into()), json_to_regorus(v)?);
                }
            }
            Ok(map)
        }
    }
}
/// Converts a `regorus::Value` back into a `serde_json::Value`.
///
/// `Null` and `Undefined` both map to JSON `null`, sets become arrays, and
/// non-string object keys are rendered with their `to_string()` form.
/// Integral numbers are emitted as JSON integers (so `42` round-trips as `42`,
/// not `42.0`); other numbers go through `f64`.
///
/// # Errors
/// `RegoError::SerdeError` if a number has no finite JSON representation.
fn regorus_to_json(value: &regorus::Value) -> Result<serde_json::Value, RegoError> {
    match value {
        regorus::Value::Null | regorus::Value::Undefined => Ok(serde_json::Value::Null),
        regorus::Value::Bool(b) => Ok(serde_json::Value::Bool(*b)),
        regorus::Value::String(s) => Ok(serde_json::Value::String(s.to_string())),
        regorus::Value::Number(n) => match (n.as_i64(), n.as_f64()) {
            // Use the exact integer when it agrees with the float view. Comparing
            // against the f64 also guards against an `as_i64` that truncates a
            // fractional number.
            (Some(i), Some(f)) if i as f64 == f => Ok(serde_json::Value::from(i)),
            (_, Some(f)) => serde_json::Number::from_f64(f)
                .map(serde_json::Value::Number)
                .ok_or_else(|| RegoError::SerdeError("Invalid number conversion".to_string())),
            (Some(i), None) => Ok(serde_json::Value::from(i)),
            (None, None) => Err(RegoError::SerdeError("Invalid number".to_string())),
        },
        regorus::Value::Array(arr) => {
            let items = arr
                .iter()
                .map(regorus_to_json)
                .collect::<Result<Vec<_>, _>>()?;
            Ok(serde_json::Value::Array(items))
        }
        regorus::Value::Set(set) => {
            let items = set
                .iter()
                .map(regorus_to_json)
                .collect::<Result<Vec<_>, _>>()?;
            Ok(serde_json::Value::Array(items))
        }
        regorus::Value::Object(obj) => {
            let mut map = serde_json::Map::new();
            for (k, v) in obj.iter() {
                let key = match k {
                    regorus::Value::String(s) => s.to_string(),
                    other => other.to_string(),
                };
                map.insert(key, regorus_to_json(v)?);
            }
            Ok(serde_json::Value::Object(map))
        }
    }
}
/// Interprets a rule result as an allow/deny decision.
///
/// `Bool(b)` yields `b`; a set with exactly one element yields that element's
/// interpretation; every other value (including `Undefined`) yields `false`.
fn regorus_to_bool(value: &regorus::Value) -> bool {
    match value {
        regorus::Value::Bool(b) => *b,
        regorus::Value::Set(s) if s.len() == 1 => {
            s.iter().next().map(regorus_to_bool).unwrap_or(false)
        }
        _ => false,
    }
}
/// Flattens a set- or array-valued rule result (e.g. `violations`) into strings.
///
/// String elements are taken verbatim; any other element is rendered with its
/// `to_string()` form. Values that are neither sets nor arrays (including
/// `Undefined`) yield an empty list.
fn regorus_to_string_set(value: &regorus::Value) -> Vec<String> {
    /// Renders one element, unwrapping plain strings so they are not re-quoted.
    fn render(v: &regorus::Value) -> String {
        match v {
            regorus::Value::String(s) => s.to_string(),
            other => other.to_string(),
        }
    }

    match value {
        regorus::Value::Set(set) => set.iter().map(render).collect(),
        regorus::Value::Array(arr) => arr.iter().map(render).collect(),
        _ => Vec::new(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an input with an empty-ish context for the given attestation and level.
    fn input_with(attestation: serde_json::Value, slsa_level: u8) -> RegoInput {
        RegoInput {
            attestation,
            slsa_level,
            current_time: "2025-01-01T00:00:00Z".to_string(),
            context: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_engine_creation() {
        let engine = RegoEngine::new();
        assert!(engine.is_ok());
    }

    #[test]
    fn test_simple_policy() {
        let mut engine = RegoEngine::new().unwrap();
        let policy = r#"
            package wsc.policy
            default allow := false
            allow {
                input.slsa_level >= 2
            }
            violations[msg] {
                input.slsa_level < 2
                msg := sprintf("SLSA level %d is below minimum 2", [input.slsa_level])
            }
        "#;
        engine.add_policy("test.rego", policy).unwrap();

        let result = engine.evaluate(&input_with(serde_json::json!({}), 3)).unwrap();
        assert!(result.allowed);
        assert!(result.violations.is_empty());

        let result = engine.evaluate(&input_with(serde_json::json!({}), 1)).unwrap();
        assert!(!result.allowed);
        assert!(!result.violations.is_empty());
    }

    #[test]
    fn test_policy_with_data() {
        let mut engine = RegoEngine::new().unwrap();
        let policy = r#"
            package wsc.policy
            default allow := false
            allow {
                input.attestation.tool.name == tool_name
                data.trusted_tools[tool_name]
            }
        "#;
        engine.add_policy("test.rego", policy).unwrap();
        engine
            .set_data(serde_json::json!({
                "trusted_tools": {
                    "loom": true,
                    "wac": true
                }
            }))
            .unwrap();

        let trusted = input_with(serde_json::json!({ "tool": { "name": "loom" } }), 2);
        let result = engine.evaluate(&trusted).unwrap();
        assert!(result.allowed);

        let untrusted = input_with(serde_json::json!({ "tool": { "name": "malicious-tool" } }), 2);
        let result = engine.evaluate(&untrusted).unwrap();
        assert!(!result.allowed);
    }

    #[test]
    fn test_no_policy_loaded() {
        let mut engine = RegoEngine::new().unwrap();
        let result = engine.evaluate(&input_with(serde_json::json!({}), 2));
        assert!(matches!(result, Err(RegoError::NoPolicyLoaded)));
    }

    #[test]
    fn test_eval_rule_without_policy() {
        let mut engine = RegoEngine::new().unwrap();
        let result = engine.eval_rule("data.wsc.policy.allow");
        assert!(matches!(result, Err(RegoError::NoPolicyLoaded)));
    }

    #[test]
    fn test_invalid_policy_is_parse_error() {
        let mut engine = RegoEngine::new().unwrap();
        let result = engine.add_policy("bad.rego", "this is not a rego module");
        assert!(matches!(result, Err(RegoError::ParseError(_))));
        // A failed add must not mark the engine as having a policy.
        let result = engine.evaluate(&input_with(serde_json::json!({}), 2));
        assert!(matches!(result, Err(RegoError::NoPolicyLoaded)));
    }

    #[test]
    fn test_json_conversion() {
        let json = serde_json::json!({
            "string": "hello",
            "number": 42,
            "float": 3.14,
            "bool": true,
            "null": null,
            "array": [1, 2, 3],
            "object": { "nested": "value" }
        });
        // Named `converted` so the local does not shadow the `regorus` crate.
        let converted = json_to_regorus(&json).unwrap();
        let back = regorus_to_json(&converted).unwrap();
        assert_eq!(back["string"], json["string"]);
        assert_eq!(back["bool"], json["bool"]);
        assert!(back["null"].is_null());
        assert_eq!(back["object"]["nested"], json["object"]["nested"]);
        assert_eq!(back["array"].as_array().map(Vec::len), Some(3));
    }
}