use crate::constraints::ConstraintValue;
use crate::error::{Error, Result};
use std::collections::HashMap;
#[cfg(feature = "cel")]
use cel_interpreter::{Context, Program, Value};
#[cfg(feature = "cel")]
use chrono::{DateTime, Utc};
#[cfg(feature = "cel")]
use ipnetwork::IpNetwork;
#[cfg(feature = "cel")]
use moka::sync::Cache;
#[cfg(feature = "cel")]
use std::net::IpAddr;
#[cfg(feature = "cel")]
use std::sync::Arc;
/// Maximum number of compiled programs kept in [`CEL_CACHE`].
#[cfg(feature = "cel")]
const CEL_CACHE_CAPACITY: u64 = 1000;

/// Process-wide cache of compiled CEL programs, keyed by the expression source text.
///
/// Built lazily on first use; bounded by [`CEL_CACHE_CAPACITY`] entries.
#[cfg(feature = "cel")]
static CEL_CACHE: std::sync::LazyLock<Cache<String, Arc<Program>>> =
    std::sync::LazyLock::new(|| Cache::builder().max_capacity(CEL_CACHE_CAPACITY).build());
/// Compiles a CEL `expression`, reusing a previously compiled program when one is cached.
///
/// A successful compilation is stored in the process-wide cache under the exact
/// expression text. Failed compilations are not cached.
///
/// # Errors
///
/// Returns [`Error::CelError`] if the expression fails to compile.
#[cfg(feature = "cel")]
pub fn compile(expression: &str) -> Result<Arc<Program>> {
    match CEL_CACHE.get(expression) {
        Some(cached) => Ok(cached),
        None => {
            let compiled = Program::compile(expression)
                .map(Arc::new)
                .map_err(|e| Error::CelError(format!("compilation failed: {}", e)))?;
            CEL_CACHE.insert(expression.to_owned(), Arc::clone(&compiled));
            Ok(compiled)
        }
    }
}
#[cfg(not(feature = "cel"))]
pub fn compile(_expression: &str) -> Result<()> {
Err(Error::FeatureNotEnabled { feature: "cel" })
}
/// Evaluates the boolean CEL `expression` with `value` bound to the variable `value`.
///
/// Every entry of `vars` is then bound as an additional top-level variable. Entries are
/// added after `value`, so an entry named `value` is bound last. NOTE(review): presumably
/// it overrides the first binding; confirm against `Context::add_variable`. The helper
/// functions registered by [`create_context`] are available to the expression.
///
/// # Errors
///
/// Returns [`Error::CelError`] in any of these cases:
/// - compilation fails;
/// - a variable cannot be bound;
/// - execution fails;
/// - the expression yields a non-boolean value.
#[cfg(feature = "cel")]
pub fn evaluate(
    expression: &str,
    value: &ConstraintValue,
    vars: &HashMap<String, ConstraintValue>,
) -> Result<bool> {
    let program = compile(expression)?;
    let mut ctx = create_context();

    let bound_value = constraint_value_to_cel(value)?;
    ctx.add_variable("value", bound_value)
        .map_err(|e| Error::CelError(format!("failed to add variable: {}", e)))?;

    for (var_name, var_value) in vars {
        let converted = constraint_value_to_cel(var_value)?;
        ctx.add_variable(var_name, converted).map_err(|e| {
            Error::CelError(format!("failed to add variable '{}': {}", var_name, e))
        })?;
    }

    let outcome = program
        .execute(&ctx)
        .map_err(|e| Error::CelError(format!("execution failed: {}", e)))?;

    if let Value::Bool(passed) = outcome {
        Ok(passed)
    } else {
        Err(Error::CelError(format!(
            "expression must return bool, got {:?}",
            outcome
        )))
    }
}
#[cfg(not(feature = "cel"))]
pub fn evaluate(
_expression: &str,
_value: &ConstraintValue,
_vars: &HashMap<String, ConstraintValue>,
) -> Result<bool> {
Err(Error::FeatureNotEnabled { feature: "cel" })
}
/// Evaluates the boolean CEL `expression`, deriving its variables from `value`.
///
/// - If `value` is a [`ConstraintValue::Object`], each field becomes its own top-level
///   variable named after its key, and no `value` variable is bound.
/// - For any other variant, the whole `value` is bound as `value`.
///
/// The helper functions registered by [`create_context`] are available to the expression.
///
/// # Errors
///
/// Returns [`Error::CelError`] in any of these cases:
/// - compilation fails;
/// - a variable cannot be bound;
/// - execution fails;
/// - the expression yields a non-boolean value.
#[cfg(feature = "cel")]
pub fn evaluate_with_value_context(expression: &str, value: &ConstraintValue) -> Result<bool> {
    let program = compile(expression)?;
    let mut ctx = create_context();

    if let ConstraintValue::Object(fields) = value {
        for (field_name, field_value) in fields {
            let converted = constraint_value_to_cel(field_value)?;
            ctx.add_variable(field_name, converted).map_err(|e| {
                Error::CelError(format!("failed to add variable '{}': {}", field_name, e))
            })?;
        }
    } else {
        let converted = constraint_value_to_cel(value)?;
        ctx.add_variable("value", converted)
            .map_err(|e| Error::CelError(format!("failed to add variable: {}", e)))?;
    }

    let outcome = program
        .execute(&ctx)
        .map_err(|e| Error::CelError(format!("execution failed: {}", e)))?;

    match outcome {
        Value::Bool(passed) => Ok(passed),
        unexpected => Err(Error::CelError(format!(
            "expression must return bool, got {:?}",
            unexpected
        ))),
    }
}
#[cfg(not(feature = "cel"))]
pub fn evaluate_with_value_context(_expression: &str, _value: &ConstraintValue) -> Result<bool> {
Err(Error::FeatureNotEnabled { feature: "cel" })
}
/// Recursively converts a [`ConstraintValue`] into the equivalent CEL [`Value`].
///
/// Scalars map one-to-one. Lists and objects are converted element by element, in
/// iteration order; the first failing element aborts the conversion.
#[cfg(feature = "cel")]
fn constraint_value_to_cel(cv: &ConstraintValue) -> Result<Value> {
    let converted = match cv {
        ConstraintValue::String(s) => Value::String(s.clone().into()),
        ConstraintValue::Integer(i) => Value::Int(*i),
        ConstraintValue::Float(f) => Value::Float(*f),
        ConstraintValue::Boolean(b) => Value::Bool(*b),
        ConstraintValue::Null => Value::Null,
        ConstraintValue::List(items) => {
            let mut elements: Vec<Value> = Vec::with_capacity(items.len());
            for item in items {
                elements.push(constraint_value_to_cel(item)?);
            }
            Value::List(elements.into())
        }
        ConstraintValue::Object(fields) => {
            let mut entries: HashMap<String, Value> = HashMap::with_capacity(fields.len());
            for (key, field) in fields {
                entries.insert(key.clone(), constraint_value_to_cel(field)?);
            }
            Value::Map(entries.into())
        }
    };
    Ok(converted)
}
/// Invalidates every compiled program held in the process-wide CEL cache.
///
/// Later calls to [`compile`] recompile their expressions.
#[cfg(feature = "cel")]
pub fn clear_cache() {
    let cache = &*CEL_CACHE;
    cache.invalidate_all();
}
/// No-op stand-in for the cache reset when the `cel` feature is disabled.
///
/// Nothing is ever compiled in this configuration, so there is nothing to clear.
#[cfg(not(feature = "cel"))]
pub fn clear_cache() {}
/// Returns the number of entries the CEL program cache currently reports.
///
/// NOTE(review): moka documents `entry_count` as approximate. Pending invalidations or
/// concurrent inserts may not be reflected yet.
#[cfg(feature = "cel")]
pub fn cache_size() -> u64 {
    let cache = &*CEL_CACHE;
    cache.entry_count()
}
/// Stand-in for the cache size query when the `cel` feature is disabled.
///
/// There is no cache in this configuration, so this always returns zero.
#[cfg(not(feature = "cel"))]
pub fn cache_size() -> u64 {
    u64::MIN
}
/// Builds a fresh CEL [`Context`] with this module's helper functions registered.
///
/// Registered functions never raise on malformed or non-string arguments. Instead they
/// return `false` (boolean helpers) or `0` (`time_since`).
///
/// - `time_now(_)`: current UTC time as an RFC 3339 string; its single argument is ignored.
/// - `time_is_expired(ts)`: `true` if the RFC 3339 timestamp `ts` is earlier than now.
/// - `time_since(ts)`: whole seconds elapsed since the RFC 3339 timestamp `ts`.
///   The result is negative when `ts` lies in the future.
/// - `net_in_cidr(ip, cidr)`: `true` if `ip` lies inside `cidr`.
///   An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) also matches an IPv4 network
///   containing `a.b.c.d`.
/// - `net_is_private(ip)`: `true` for IPv4 private ranges (as defined by
///   [`std::net::Ipv4Addr::is_private`]) and for IPv6 unique-local addresses (`fc00::/7`).
///   IPv4-mapped IPv6 addresses are classified by their embedded IPv4 address.
#[cfg(feature = "cel")]
pub fn create_context() -> Context<'static> {
    let mut context = Context::default();

    context.add_function("time_now", |_unused: Value| -> String {
        Utc::now().to_rfc3339()
    });

    // NOTE(review): an unparseable or non-string timestamp is reported as *not* expired
    // (fail-open). Confirm this is intended wherever this function guards access.
    context.add_function("time_is_expired", |timestamp: Value| -> bool {
        let ts_str = match timestamp {
            Value::String(s) => s,
            _ => return false,
        };
        match DateTime::parse_from_rfc3339(&ts_str) {
            Ok(dt) => dt < Utc::now(),
            Err(_) => false,
        }
    });

    context.add_function("time_since", |timestamp: Value| -> i64 {
        let ts_str = match timestamp {
            Value::String(s) => s,
            _ => return 0,
        };
        match DateTime::parse_from_rfc3339(&ts_str) {
            Ok(dt) => (Utc::now() - dt.with_timezone(&Utc)).num_seconds(),
            Err(_) => 0,
        }
    });

    context.add_function("net_in_cidr", |ip: Value, cidr: Value| -> bool {
        let ip_str = match ip {
            Value::String(s) => s,
            _ => return false,
        };
        let cidr_str = match cidr {
            Value::String(s) => s,
            _ => return false,
        };
        let ip_addr: IpAddr = match ip_str.parse() {
            Ok(addr) => addr,
            Err(_) => return false,
        };
        let network: IpNetwork = match cidr_str.parse() {
            Ok(net) => net,
            Err(_) => return false,
        };
        // Also test the IPv4 form of an IPv4-mapped IPv6 address. Otherwise
        // `::ffff:169.254.169.254` would slip past a `169.254.0.0/16` deny-list check.
        // Testing the original form too keeps IPv6 CIDRs such as `::ffff:0:0/96` working.
        network.contains(ip_addr) || network.contains(ip_addr.to_canonical())
    });

    context.add_function("net_is_private", |ip: Value| -> bool {
        let ip_str = match ip {
            Value::String(s) => s,
            _ => return false,
        };
        let ip_addr: IpAddr = match ip_str.parse() {
            Ok(addr) => addr,
            Err(_) => return false,
        };
        // `to_canonical` turns `::ffff:a.b.c.d` into `a.b.c.d`, so mapped private IPv4
        // addresses are not misreported as public.
        match ip_addr.to_canonical() {
            IpAddr::V4(addr) => addr.is_private(),
            // Unique-local addresses: fc00::/7.
            IpAddr::V6(addr) => (addr.segments()[0] & 0xfe00) == 0xfc00,
        }
    });

    context
}
#[cfg(all(test, feature = "cel"))]
mod tests {
    use super::*;

    /// Evaluates `expr` against `value` with no extra variables bound.
    fn eval(expr: &str, value: &ConstraintValue) -> Result<bool> {
        evaluate(expr, value, &HashMap::new())
    }

    /// Builds a `ConstraintValue::Object` from `(key, value)` pairs.
    fn object(fields: Vec<(&str, ConstraintValue)>) -> ConstraintValue {
        ConstraintValue::Object(
            fields
                .into_iter()
                .map(|(key, val)| (key.to_string(), val))
                .collect(),
        )
    }

    #[test]
    fn test_simple_comparison() {
        let value = ConstraintValue::Integer(5000);
        assert!(eval("value < 10000", &value).unwrap());
        assert!(!eval("value > 10000", &value).unwrap());
    }

    #[test]
    fn test_string_operations() {
        let value = ConstraintValue::String("staging-web".to_string());
        assert!(eval("value.startsWith('staging')", &value).unwrap());
        assert!(!eval("value.startsWith('prod')", &value).unwrap());
    }

    #[test]
    fn test_boolean_logic() {
        let value = ConstraintValue::Integer(7500);
        assert!(eval("value > 5000 && value < 10000", &value).unwrap());
        assert!(eval("value < 1000 || value > 5000", &value).unwrap());
    }

    #[test]
    fn test_list_operations() {
        let value = ConstraintValue::List(vec![
            ConstraintValue::String("admin".to_string()),
            ConstraintValue::String("user".to_string()),
        ]);
        assert!(eval("'admin' in value", &value).unwrap());
        assert!(!eval("'superuser' in value", &value).unwrap());
    }

    #[test]
    fn test_object_context() {
        let value = object(vec![
            ("amount", ConstraintValue::Integer(5000)),
            ("currency", ConstraintValue::String("USD".to_string())),
        ]);
        assert!(evaluate_with_value_context("amount < 10000", &value).unwrap());
        assert!(evaluate_with_value_context("currency == 'USD'", &value).unwrap());
        assert!(
            evaluate_with_value_context("amount < 10000 && currency == 'USD'", &value).unwrap()
        );
    }

    #[test]
    fn test_complex_expression() {
        let value = object(vec![
            ("amount", ConstraintValue::Integer(75000)),
            (
                "approver",
                ConstraintValue::String("cfo@company.com".to_string()),
            ),
        ]);
        let expr = "amount < 10000 || (amount < 100000 && approver != '')";
        assert!(evaluate_with_value_context(expr, &value).unwrap());
    }

    #[test]
    fn test_cache_works() {
        clear_cache();
        let value = ConstraintValue::Integer(42);
        // The first call compiles; the second is served from the cache.
        // Both must agree on the result.
        assert!(eval("value == 42", &value).unwrap());
        assert!(eval("value == 42", &value).unwrap());
        assert!(eval("value > 0", &value).unwrap());
        let p1 = compile("value == 100").unwrap();
        let p2 = compile("value == 100").unwrap();
        assert!(
            std::sync::Arc::ptr_eq(&p1, &p2),
            "same expression should return same Arc"
        );
    }

    #[test]
    fn test_invalid_expression() {
        let value = ConstraintValue::Integer(42);
        let result = eval("this is not valid CEL !!!", &value);
        assert!(result.is_err());
    }

    #[test]
    fn test_non_bool_result_error() {
        let value = ConstraintValue::Integer(42);
        let result = eval("value + 1", &value);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("must return bool"));
    }
}