use serde::{Deserialize, Serialize};
/// Success and failure conditions for an evaluation, plus the timeout policy.
///
/// Every field carries `#[serde(default)]`, so any of them may be omitted from
/// the serialized form: the two condition lists then deserialize as empty
/// vectors and `on_timeout` as [`TimeoutBehavior::Fail`] (that enum's
/// `#[default]` variant).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EvalConditions {
/// Conditions listed under `success`.
// NOTE(review): how the entries are combined (all-of vs. any-of) is decided by
// the evaluator, which is not part of this file — confirm there.
#[serde(default)]
pub success: Vec<Condition>,
/// Conditions listed under `failure`.
// NOTE(review): presumably a match here marks the run as failed — confirm
// against the evaluator.
#[serde(default)]
pub failure: Vec<Condition>,
/// Timeout policy; falls back to [`TimeoutBehavior::Fail`] when omitted.
#[serde(default)]
pub on_timeout: TimeoutBehavior,
}
/// A single named check of the form `<metric> <op> <value>`.
///
/// All four fields are required when deserializing (none has a serde default).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Condition {
/// Label for this condition; `ConditionResult::new` copies it into the result.
pub name: String,
/// Key of the metric this condition refers to (e.g. `task.success_rate`).
// NOTE(review): the lookup that resolves this key to a value is not in this
// file — the key format is whatever that code expects.
pub metric: String,
/// Operator applied as `actual <op> value` by `Condition::evaluate`.
pub op: CompareOp,
/// Expected value, i.e. the right-hand operand of the comparison.
pub value: ConditionValue,
}
impl Condition {
    /// Builds a condition from its four parts.
    ///
    /// `name`, `metric` and `value` accept anything convertible into the
    /// stored type, so string literals and plain numbers can be passed as-is.
    pub fn new(
        name: impl Into<String>,
        metric: impl Into<String>,
        op: CompareOp,
        value: impl Into<ConditionValue>,
    ) -> Self {
        let name = name.into();
        let metric = metric.into();
        let value = value.into();
        Self {
            name,
            metric,
            op,
            value,
        }
    }

    /// Checks an observed value against this condition.
    ///
    /// `actual` is the left-hand operand and the configured `value` the
    /// right-hand one, i.e. the result is `actual <op> value`.
    pub fn evaluate(&self, actual: &ConditionValue) -> bool {
        let Self {
            op,
            value: expected,
            ..
        } = self;
        op.compare(actual, expected)
    }
}
/// Comparison operator used by a [`Condition`].
///
/// `rename_all = "snake_case"` makes the serialized names the lowercase
/// variant names: `eq`, `ne`, `gt`, `gte`, `lt`, `lte`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompareOp {
/// `actual == expected`.
Eq,
/// `actual != expected`.
Ne,
/// `actual > expected`.
Gt,
/// `actual >= expected`.
Gte,
/// `actual < expected`.
Lt,
/// `actual <= expected`.
Lte,
}
impl CompareOp {
    /// Applies this operator as `actual <op> expected`.
    ///
    /// Supported pairings:
    /// - integer vs integer: exact comparison, every operator;
    /// - any pairing that involves a float: an integer operand is widened to
    ///   `f64` and the comparison is tolerance-based (see `compare_float`);
    /// - bool vs bool and string vs string: only `Eq` and `Ne` are meaningful,
    ///   the ordering operators yield `false`;
    /// - every other pairing (e.g. bool vs integer): always `false`, for `Ne`
    ///   as well.
    pub fn compare(&self, actual: &ConditionValue, expected: &ConditionValue) -> bool {
        match (actual, expected) {
            (ConditionValue::Integer(a), ConditionValue::Integer(e)) => self.compare_ord(a, e),
            (ConditionValue::Float(a), ConditionValue::Float(e)) => self.compare_float(*a, *e),
            // `i64 as f64` rounds for magnitudes above 2^53; mixed comparisons
            // of such large integers are therefore approximate.
            (ConditionValue::Integer(a), ConditionValue::Float(e)) => {
                self.compare_float(*a as f64, *e)
            }
            (ConditionValue::Float(a), ConditionValue::Integer(e)) => {
                self.compare_float(*a, *e as f64)
            }
            (ConditionValue::Bool(a), ConditionValue::Bool(e)) => self.compare_eq(a, e),
            (ConditionValue::String(a), ConditionValue::String(e)) => self.compare_eq(a, e),
            // Mismatched types never satisfy a condition.
            _ => false,
        }
    }

    /// Equality-only comparison for types without a meaningful ordering here.
    fn compare_eq<T: PartialEq>(&self, a: &T, e: &T) -> bool {
        match self {
            CompareOp::Eq => a == e,
            CompareOp::Ne => a != e,
            CompareOp::Gt | CompareOp::Gte | CompareOp::Lt | CompareOp::Lte => false,
        }
    }

    /// Exact comparison for totally ordered operands.
    fn compare_ord<T: Ord>(&self, a: &T, e: &T) -> bool {
        match self {
            CompareOp::Eq => a == e,
            CompareOp::Ne => a != e,
            CompareOp::Gt => a > e,
            CompareOp::Gte => a >= e,
            CompareOp::Lt => a < e,
            CompareOp::Lte => a <= e,
        }
    }

    /// Float comparison with an absolute tolerance of `1e-9`.
    ///
    /// `Eq`, `Ne`, `Gte` and `Lte` honour the tolerance; `Gt` and `Lt` are
    /// strict, so two values closer than the tolerance can satisfy both `Eq`
    /// and `Gt` (or `Lt`). If either operand is NaN every operator returns
    /// `false`.
    fn compare_float(&self, a: f64, e: f64) -> bool {
        const EPSILON: f64 = 1e-9;
        match self {
            // `a == e` covers equal infinities: their difference is NaN, which
            // would otherwise fail the tolerance test even though `Gte` and
            // `Lte` both hold for the same pair.
            CompareOp::Eq => a == e || (a - e).abs() < EPSILON,
            CompareOp::Ne => (a - e).abs() >= EPSILON,
            CompareOp::Gt => a > e,
            CompareOp::Gte => a >= e - EPSILON,
            CompareOp::Lt => a < e,
            CompareOp::Lte => a <= e + EPSILON,
        }
    }
}
/// A value a condition can compare: integer, float, boolean or string.
///
/// `#[serde(untagged)]` means the serialized form is the bare value with no
/// variant tag. When deserializing, serde's untagged representation tries the
/// variants in declaration order and keeps the first that fits, so the order
/// below is significant: with `Integer` ahead of `Float`, a whole number such
/// as `5` becomes `Integer` while `0.8` becomes `Float`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ConditionValue {
/// Signed 64-bit integer.
Integer(i64),
/// 64-bit floating-point number.
Float(f64),
/// Boolean flag.
Bool(bool),
/// Owned text value.
String(String),
}
impl From<i32> for ConditionValue {
fn from(v: i32) -> Self {
Self::Integer(v as i64)
}
}
impl From<i64> for ConditionValue {
fn from(v: i64) -> Self {
Self::Integer(v)
}
}
impl From<u64> for ConditionValue {
    /// Stores an unsigned value as `Integer`, saturating at `i64::MAX`.
    ///
    /// A plain `as i64` cast wraps inputs above `i64::MAX` into negative
    /// numbers, which would invert the outcome of ordering comparisons.
    /// Clamping keeps such inputs at the largest representable value instead.
    fn from(v: u64) -> Self {
        // After clamping the value is guaranteed to fit, so the cast is exact.
        let clamped = v.min(i64::MAX as u64);
        Self::Integer(clamped as i64)
    }
}
impl From<f64> for ConditionValue {
fn from(v: f64) -> Self {
Self::Float(v)
}
}
impl From<bool> for ConditionValue {
fn from(v: bool) -> Self {
Self::Bool(v)
}
}
impl From<&str> for ConditionValue {
fn from(v: &str) -> Self {
Self::String(v.to_string())
}
}
impl From<String> for ConditionValue {
fn from(v: String) -> Self {
Self::String(v)
}
}
/// Policy selected by `EvalConditions::on_timeout`.
///
/// Defaults to [`TimeoutBehavior::Fail`]. Serialized in snake_case: `fail`,
/// `partial_success`, `milestone_score`.
// NOTE(review): the variants are only declared in this file; what each one
// does to the final outcome is implemented elsewhere — confirm before relying
// on the descriptions below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TimeoutBehavior {
/// Default policy; presumably a timeout counts as a failure.
#[default]
Fail,
/// Presumably a timeout is reported as a partial success — TODO confirm.
PartialSuccess,
/// Presumably the outcome is scored from milestones reached — TODO confirm.
MilestoneScore,
}
/// Outcome record for one [`Condition`].
///
/// Unlike the configuration types in this file it derives only `Debug` and
/// `Clone`, so it has no serde representation.
#[derive(Debug, Clone)]
pub struct ConditionResult {
/// Name of the condition (`ConditionResult::new` copies `Condition::name`).
pub name: String,
/// Whether the condition passed, as supplied by whoever built the result.
pub passed: bool,
/// Observed value, if there was one.
// NOTE(review): `None` presumably means the metric could not be obtained —
// confirm with the code that produces results.
pub actual: Option<ConditionValue>,
/// Expected value (`ConditionResult::new` copies `Condition::value`).
pub expected: ConditionValue,
}
impl ConditionResult {
    /// Records the outcome of checking `condition`.
    ///
    /// The condition's name and expected value are cloned into the result;
    /// `actual` and `passed` are stored exactly as given. Nothing is evaluated
    /// here, so the caller decides what `passed` is.
    pub fn new(condition: &Condition, actual: Option<ConditionValue>, passed: bool) -> Self {
        let name = condition.name.clone();
        let expected = condition.value.clone();
        Self {
            name,
            passed,
            actual,
            expected,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for operator semantics, construction helpers and the serde
    //! representation of the condition types.
    use super::*;

    #[test]
    fn test_compare_op_integer() {
        let a = ConditionValue::Integer(5);
        let e = ConditionValue::Integer(3);
        assert!(CompareOp::Gt.compare(&a, &e));
        assert!(CompareOp::Gte.compare(&a, &e));
        assert!(!CompareOp::Lt.compare(&a, &e));
        assert!(!CompareOp::Lte.compare(&a, &e));
        assert!(!CompareOp::Eq.compare(&a, &e));
        assert!(CompareOp::Ne.compare(&a, &e));
    }

    #[test]
    fn test_compare_op_float() {
        let a = ConditionValue::Float(0.8);
        let e = ConditionValue::Float(0.7);
        assert!(CompareOp::Gt.compare(&a, &e));
        assert!(CompareOp::Gte.compare(&a, &e));
        assert!(!CompareOp::Lt.compare(&a, &e));
        assert!(!CompareOp::Lte.compare(&a, &e));
        assert!(!CompareOp::Eq.compare(&a, &e));
        assert!(CompareOp::Ne.compare(&a, &e));
    }

    #[test]
    fn test_compare_op_float_tolerance() {
        // 0.1 + 0.2 differs from 0.3 only by rounding error, well inside the
        // comparison tolerance.
        let a = ConditionValue::Float(0.1 + 0.2);
        let e = ConditionValue::Float(0.3);
        assert!(CompareOp::Eq.compare(&a, &e));
        assert!(!CompareOp::Ne.compare(&a, &e));
        assert!(CompareOp::Gte.compare(&a, &e));
        assert!(CompareOp::Lte.compare(&a, &e));
    }

    #[test]
    fn test_compare_op_mixed() {
        let a = ConditionValue::Integer(5);
        let e = ConditionValue::Float(3.0);
        assert!(CompareOp::Gt.compare(&a, &e));

        let a = ConditionValue::Float(2.5);
        let e = ConditionValue::Integer(3);
        assert!(CompareOp::Lt.compare(&a, &e));

        let a = ConditionValue::Integer(3);
        let e = ConditionValue::Float(3.0);
        assert!(CompareOp::Eq.compare(&a, &e));
    }

    #[test]
    fn test_compare_op_bool_and_string() {
        let yes = ConditionValue::Bool(true);
        let no = ConditionValue::Bool(false);
        assert!(CompareOp::Eq.compare(&yes, &yes));
        assert!(CompareOp::Ne.compare(&yes, &no));
        // Ordering operators are not defined for booleans.
        assert!(!CompareOp::Gt.compare(&yes, &no));

        let a = ConditionValue::String("ok".to_string());
        let b = ConditionValue::String("error".to_string());
        assert!(CompareOp::Eq.compare(&a, &a));
        assert!(CompareOp::Ne.compare(&a, &b));
        // Ordering operators are not defined for strings either.
        assert!(!CompareOp::Lt.compare(&b, &a));
    }

    #[test]
    fn test_compare_op_type_mismatch() {
        let flag = ConditionValue::Bool(true);
        let text = ConditionValue::String("1".to_string());
        let one = ConditionValue::Integer(1);
        // Mismatched types never match, not even with `Ne`.
        assert!(!CompareOp::Eq.compare(&flag, &one));
        assert!(!CompareOp::Ne.compare(&flag, &one));
        assert!(!CompareOp::Eq.compare(&text, &one));
        assert!(!CompareOp::Ne.compare(&text, &one));
    }

    #[test]
    fn test_condition_evaluate() {
        let condition = Condition::new("completion_check", "task.completed", CompareOp::Gte, 5);
        assert!(condition.evaluate(&ConditionValue::Integer(5)));
        assert!(condition.evaluate(&ConditionValue::Integer(10)));
        assert!(!condition.evaluate(&ConditionValue::Integer(3)));
    }

    #[test]
    fn test_condition_deserialize() {
        let json = r#"{
"name": "success_rate",
"metric": "task.success_rate",
"op": "gte",
"value": 0.8
}"#;
        let condition: Condition = serde_json::from_str(json).unwrap();
        assert_eq!(condition.name, "success_rate");
        assert_eq!(condition.metric, "task.success_rate");
        assert_eq!(condition.op, CompareOp::Gte);
        assert!(matches!(
            condition.value,
            ConditionValue::Float(v) if (v - 0.8).abs() < f64::EPSILON
        ));
    }

    #[test]
    fn test_condition_deserialize_integer_value() {
        let json = r#"{"name": "n", "metric": "m", "op": "lt", "value": 5}"#;
        let condition: Condition = serde_json::from_str(json).unwrap();
        assert_eq!(condition.op, CompareOp::Lt);
        // Whole numbers pick the `Integer` variant, not `Float`.
        assert!(matches!(condition.value, ConditionValue::Integer(5)));
    }

    #[test]
    fn test_eval_conditions_defaults() {
        let conditions: EvalConditions = serde_json::from_str("{}").unwrap();
        assert!(conditions.success.is_empty());
        assert!(conditions.failure.is_empty());
        assert_eq!(conditions.on_timeout, TimeoutBehavior::Fail);
    }

    #[test]
    fn test_timeout_behavior_default() {
        let behavior = TimeoutBehavior::default();
        assert_eq!(behavior, TimeoutBehavior::Fail);
    }
}