use std::fmt;
use regorus::{Engine, Value as RegoValue};
use serde_json::Value as JsonValue;
use sha2::{Digest, Sha256};
use uuid::Uuid;
use vti_common::error::AppError;
/// Module name under which every policy source is registered with the Rego engine.
pub const POLICY_MODULE_PATH: &str = "policy.rego";

/// A Rego policy that has been loaded into an engine and is ready to evaluate.
///
/// Produced by [`compile`] and consumed by [`evaluate`]. Besides the engine it
/// keeps the caller-supplied policy id and the SHA-256 digest of the exact
/// source text that was compiled.
pub struct CompiledPolicy {
    /// Identifier supplied by the caller of [`compile`].
    id: Uuid,
    /// SHA-256 digest of the raw Rego source bytes.
    source_sha256: [u8; 32],
    /// Engine with the policy module loaded; [`evaluate`] works on a clone of it.
    engine: Engine,
}

impl CompiledPolicy {
    /// Returns the identifier this policy was compiled under.
    pub fn id(&self) -> Uuid {
        self.id
    }

    /// Returns the SHA-256 digest of the Rego source this policy was compiled from.
    pub fn source_sha256(&self) -> &[u8; 32] {
        &self.source_sha256
    }
}

/// Prints the id and the hex-encoded source digest. The engine itself is left
/// out, which is why the output is marked non-exhaustive.
impl fmt::Debug for CompiledPolicy {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let digest_hex = hex::encode(self.source_sha256);
        let mut builder = f.debug_struct("CompiledPolicy");
        builder.field("id", &self.id);
        builder.field("source_sha256", &digest_hex);
        builder.finish_non_exhaustive()
    }
}
/// Compiles `rego_source` into a [`CompiledPolicy`] identified by `id`.
///
/// The source is registered with a fresh [`Engine`] under
/// [`POLICY_MODULE_PATH`]. The SHA-256 digest of its bytes is stored next to the
/// engine so callers can tell which exact text was compiled.
///
/// # Errors
///
/// Returns [`AppError::Validation`] whenever the engine rejects the source (for
/// example, a parse error). The message names `id`.
pub fn compile(rego_source: &str, id: Uuid) -> Result<CompiledPolicy, AppError> {
    let mut engine = Engine::new();
    if let Err(e) = engine.add_policy(POLICY_MODULE_PATH.to_string(), rego_source.to_string()) {
        return Err(AppError::Validation(format!(
            "rego compile failed for policy {id}: {e}"
        )));
    }
    let digest = Sha256::digest(rego_source.as_bytes());
    Ok(CompiledPolicy {
        id,
        source_sha256: digest.into(),
        engine,
    })
}
/// Evaluates `query` against `compiled`, with `input` set as the engine's input.
///
/// `compiled` is only borrowed shared, so the engine is cloned before the input
/// is set. The same `CompiledPolicy` can therefore be evaluated again with other
/// inputs.
///
/// Returns the engine's query results serialized to JSON. The tests in this
/// module read the value at `/result/0/expressions/0/value`. They also show that
/// querying an undefined rule produces no value rather than an error.
///
/// # Errors
///
/// - [`AppError::Internal`], naming the policy id, when query evaluation fails
///   (for example, a malformed query).
/// - The `AppError::from` conversion of a `serde_json` error if the results
///   cannot be serialized.
pub fn evaluate(
    compiled: &CompiledPolicy,
    query: &str,
    input: JsonValue,
) -> Result<JsonValue, AppError> {
    let mut scratch = compiled.engine.clone();
    scratch.set_input(RegoValue::from(input));
    match scratch.eval_query(query.to_string(), false) {
        Ok(results) => Ok(serde_json::to_value(results)?),
        Err(e) => Err(AppError::Internal(format!(
            "rego evaluation failed for policy {}: {e}",
            compiled.id
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Grants `allow` to any input whose `role` is "admin".
    // NOTE: continuation lines of these literals stay at column 0 on purpose;
    // any indentation would become part of the policy text (and its digest).
    const ALLOW_POLICY: &str = "\
package vtc.test
import rego.v1
default allow := false
allow if input.role == \"admin\"
";

    // Grants `allow` only when `role` is "admin" AND `context` is "prod".
    const DENY_POLICY: &str = "\
package vtc.test
import rego.v1
default allow := false
allow if {
input.role == \"admin\"
input.context == \"prod\"
}
";

    /// Fixed, recognisable id so error-message assertions are deterministic.
    fn test_id() -> Uuid {
        Uuid::from_u128(0x0102_0304_0506_0708_0900_0a0b_0c0d_0e0f)
    }

    #[test]
    fn compile_happy_path() {
        let id = test_id();
        let compiled = compile(ALLOW_POLICY, id).expect("compile should succeed");
        assert_eq!(compiled.id(), id);
        let expected: [u8; 32] = Sha256::digest(ALLOW_POLICY.as_bytes()).into();
        assert_eq!(compiled.source_sha256(), &expected);
    }

    #[test]
    fn compile_surfaces_parse_error() {
        let id = test_id();
        let err = compile("not valid rego @@@ }}}", id).expect_err("malformed source must fail");
        match err {
            AppError::Validation(msg) => {
                assert!(
                    msg.contains(&id.to_string()),
                    "error message should name the policy id: {msg}"
                );
                assert!(
                    msg.contains("rego compile failed"),
                    "error message should be a compile-failure: {msg}"
                );
            }
            other => panic!("expected Validation error, got {other:?}"),
        }
    }

    #[test]
    fn evaluate_allow_true() {
        let compiled = compile(ALLOW_POLICY, test_id()).unwrap();
        let result = evaluate(&compiled, "data.vtc.test.allow", json!({ "role": "admin" }))
            .expect("evaluate must succeed");
        let value = pluck_expression_value(&result);
        assert_eq!(value, &json!(true));
    }

    #[test]
    fn evaluate_allow_false() {
        let compiled = compile(DENY_POLICY, test_id()).unwrap();
        let result = evaluate(
            &compiled,
            "data.vtc.test.allow",
            json!({ "role": "admin", "context": "staging" }),
        )
        .expect("evaluate must succeed");
        let value = pluck_expression_value(&result);
        assert_eq!(value, &json!(false));
    }

    /// One compiled policy must be reusable: input set for one evaluation must
    /// not leak into the next (evaluate works on a cloned engine).
    #[test]
    fn evaluate_reuses_compiled_policy_across_inputs() {
        let compiled = compile(ALLOW_POLICY, test_id()).unwrap();
        let admin = evaluate(&compiled, "data.vtc.test.allow", json!({ "role": "admin" }))
            .expect("evaluate must succeed");
        assert_eq!(pluck_expression_value(&admin), &json!(true));
        let guest = evaluate(&compiled, "data.vtc.test.allow", json!({ "role": "guest" }))
            .expect("evaluate must succeed");
        assert_eq!(pluck_expression_value(&guest), &json!(false));
    }

    #[test]
    fn evaluate_undefined_rule_returns_no_value() {
        let compiled = compile(ALLOW_POLICY, test_id()).unwrap();
        let results = evaluate(&compiled, "data.vtc.test.does_not_exist", json!({}))
            .expect("undefined symbols must not surface as an error");
        let value = results.pointer("/result/0/expressions/0/value");
        assert!(
            value.is_none() || matches!(value, Some(JsonValue::Object(o)) if o.is_empty()),
            "undefined rule should yield no value, got {results}"
        );
    }

    #[test]
    fn evaluate_malformed_query_is_internal_error() {
        let compiled = compile(ALLOW_POLICY, test_id()).unwrap();
        let err = evaluate(&compiled, "@@@ not a query @@@", json!({}))
            .expect_err("malformed query must fail");
        match err {
            AppError::Internal(msg) => {
                assert!(
                    msg.contains("rego evaluation failed"),
                    "error message should be an evaluation failure: {msg}"
                );
            }
            other => panic!("expected Internal error, got {other:?}"),
        }
    }

    #[test]
    fn compile_sha_is_deterministic() {
        let a = compile(ALLOW_POLICY, Uuid::new_v4()).unwrap();
        let b = compile(ALLOW_POLICY, Uuid::new_v4()).unwrap();
        assert_eq!(a.source_sha256(), b.source_sha256());
        let c = compile(DENY_POLICY, Uuid::new_v4()).unwrap();
        assert_ne!(a.source_sha256(), c.source_sha256());
    }

    /// Extracts the first expression value from serialized query results,
    /// panicking with a descriptive message if the expected shape is absent.
    fn pluck_expression_value(results: &JsonValue) -> &JsonValue {
        results
            .pointer("/result/0/expressions/0/value")
            .expect("regorus QueryResults must carry result[0].expressions[0].value")
    }
}