use crate::Value;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Shared, thread-safe predicate backing one custom operator.
///
/// The registry invokes it as `f(field_value, query_value)` and reports the
/// returned `bool` as the operator's result. It sits behind an `Arc`, so
/// cloning a value of this type bumps a reference count instead of copying
/// the closure.
pub type CustomOpFn = Arc<dyn Fn(&Value, &Value) -> bool + Sync + Send>;
/// Registry of user-defined query operators, keyed by operator name.
///
/// Each operator is a [`CustomOpFn`]. Because those live behind an `Arc`,
/// cloning a registry shares the operator closures rather than copying them.
#[derive(Clone, Default)]
pub struct OpRegistry {
    /// Names of operators that were registered with the `novalue` flag set.
    /// NOTE(review): presumably these are operators whose query value is
    /// ignored; confirm with the query parser that consumes this set.
    novalue_ops: HashSet<Box<str>>,
    /// Operator name -> predicate, invoked as `f(field_value, query_value)`.
    ops: HashMap<Box<str>, CustomOpFn>,
}
impl OpRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers (or replaces) the operator `name`.
    ///
    /// `novalue` controls whether `name` is reported by
    /// [`is_novalue`](Self::is_novalue) and [`novalue_ops`](Self::novalue_ops).
    /// Re-registering an existing name replaces both its function and its
    /// novalue flag.
    ///
    /// Returns `&mut Self` so registrations can be chained.
    pub fn register<F>(&mut self, name: impl Into<Box<str>>, novalue: bool, f: F) -> &mut Self
    where
        F: Fn(&Value, &Value) -> bool + Send + Sync + 'static,
    {
        let name = name.into();
        if novalue {
            self.novalue_ops.insert(name.clone());
        } else {
            // Drop a flag left over from an earlier registration of this name.
            self.novalue_ops.remove(&name);
        }
        self.ops.insert(name, Arc::new(f));
        self
    }

    /// Runs operator `name` as `f(field_value, query_value)`.
    ///
    /// Returns `None` when no operator called `name` is registered, otherwise
    /// `Some` of the operator's result.
    #[must_use]
    pub fn evaluate(&self, name: &str, field_value: &Value, query_value: &Value) -> Option<bool> {
        self.ops.get(name).map(|f| f(field_value, query_value))
    }

    /// Returns `true` if an operator called `name` is registered.
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        self.ops.contains_key(name)
    }

    /// Number of registered operators.
    #[must_use]
    pub fn len(&self) -> usize {
        self.ops.len()
    }

    /// Returns `true` if no operators are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.ops.is_empty()
    }

    /// Iterates over all registered operator names, in unspecified order.
    pub fn operator_names(&self) -> impl Iterator<Item = &str> {
        self.ops.keys().map(|s| s.as_ref())
    }

    /// Returns `true` if `name` was last registered with `novalue = true`.
    #[must_use]
    pub fn is_novalue(&self, name: &str) -> bool {
        self.novalue_ops.contains(name)
    }

    /// Iterates over the names of operators flagged as novalue, in
    /// unspecified order.
    pub fn novalue_ops(&self) -> impl Iterator<Item = &str> {
        self.novalue_ops.iter().map(|s| s.as_ref())
    }

    /// Moves every operator from `other` into `self`.
    ///
    /// On a name collision `other` wins for both the function and the
    /// novalue flag. The result is the same as calling
    /// [`register`](Self::register) with `other`'s definition.
    ///
    /// Returns `&mut Self` so calls can be chained.
    pub fn merge(&mut self, other: Self) -> &mut Self {
        // An operator that `other` defines as *not* novalue must not inherit
        // a stale novalue flag from `self`'s same-named operator; otherwise
        // `is_novalue` would disagree with the function actually installed.
        for name in other.ops.keys() {
            if !other.novalue_ops.contains(name) {
                self.novalue_ops.remove(name);
            }
        }
        self.ops.extend(other.ops);
        self.novalue_ops.extend(other.novalue_ops);
        self
    }
}
impl std::fmt::Debug for OpRegistry {
    /// Formats the registry by operator name; the closures themselves are opaque.
    ///
    /// Names are sorted because `HashMap`/`HashSet` iteration order is
    /// unspecified. This keeps the output stable across runs, which matters
    /// for logs and snapshot-style assertions. The novalue set is listed too,
    /// since it is otherwise invisible in diagnostics.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut operators: Vec<&str> = self.ops.keys().map(|name| &**name).collect();
        operators.sort_unstable();
        let mut novalue_ops: Vec<&str> = self.novalue_ops.iter().map(|name| &**name).collect();
        novalue_ops.sort_unstable();
        f.debug_struct("OpRegistry")
            .field("operators", &operators)
            .field("novalue_ops", &novalue_ops)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic lookup: registered names evaluate, unknown names yield `None`.
    #[test]
    fn test_register_and_evaluate() {
        let mut registry = OpRegistry::new();
        registry.register(
            "IS_POSITIVE",
            true,
            |field, _| matches!(field, Value::Int(n) if *n > 0),
        );
        assert!(registry.contains("IS_POSITIVE"));
        assert!(!registry.contains("UNKNOWN"));
        assert!(registry.is_novalue("IS_POSITIVE"));
        assert!(!registry.is_novalue("UNKNOWN"));
        let result = registry.evaluate("IS_POSITIVE", &Value::Int(42), &Value::None);
        assert_eq!(result, Some(true));
        let result = registry.evaluate("IS_POSITIVE", &Value::Int(-5), &Value::None);
        assert_eq!(result, Some(false));
        let result = registry.evaluate("UNKNOWN", &Value::Int(42), &Value::None);
        assert_eq!(result, None);
    }

    /// Re-registering a name replaces both its function and its novalue flag.
    #[test]
    fn test_reregister_replaces_novalue_flag() {
        let mut registry = OpRegistry::new();
        registry.register("EXISTS", true, |_, _| true);
        assert!(registry.is_novalue("EXISTS"));
        registry.register("EXISTS", false, |_, _| false);
        assert!(!registry.is_novalue("EXISTS"));
        assert_eq!(registry.len(), 1);
        assert_eq!(
            registry.evaluate("EXISTS", &Value::None, &Value::None),
            Some(false)
        );
    }

    /// Merging disjoint registries keeps every operator and its novalue flag.
    #[test]
    fn test_merge_disjoint_registries() {
        let mut base = OpRegistry::new();
        base.register("A", true, |_, _| true);
        let mut other = OpRegistry::new();
        other.register("B", false, |_, _| false);
        base.merge(other);
        assert_eq!(base.len(), 2);
        assert!(base.is_novalue("A"));
        assert!(!base.is_novalue("B"));
        let mut names: Vec<&str> = base.operator_names().collect();
        names.sort_unstable();
        assert_eq!(names, ["A", "B"]);
    }

    /// A two-argument range operator: inclusive bounds, and rejection of
    /// non-numeric fields, non-array queries and too-short ranges.
    #[test]
    fn test_between_operator() {
        let mut registry = OpRegistry::new();
        registry.register("BETWEEN", false, |field, query| {
            let Value::FloatArray(range) = query else {
                return false;
            };
            if range.len() < 2 {
                return false;
            }
            match field {
                Value::Int(n) => (*n as f64) >= range[0] && (*n as f64) <= range[1],
                Value::Float(n) => *n >= range[0] && *n <= range[1],
                _ => false,
            }
        });
        let range = Value::from(vec![10.0, 100.0]);
        assert_eq!(
            registry.evaluate("BETWEEN", &Value::Int(50), &range),
            Some(true)
        );
        assert_eq!(
            registry.evaluate("BETWEEN", &Value::Int(5), &range),
            Some(false)
        );
        assert_eq!(
            registry.evaluate("BETWEEN", &Value::Float(50.5), &range),
            Some(true)
        );
        // Both bounds are inclusive.
        assert_eq!(
            registry.evaluate("BETWEEN", &Value::Int(10), &range),
            Some(true)
        );
        assert_eq!(
            registry.evaluate("BETWEEN", &Value::Int(100), &range),
            Some(true)
        );
        assert_eq!(
            registry.evaluate("BETWEEN", &Value::Float(100.5), &range),
            Some(false)
        );
        // Non-numeric field, non-array query and a one-element range all fail.
        assert_eq!(
            registry.evaluate("BETWEEN", &Value::None, &range),
            Some(false)
        );
        assert_eq!(
            registry.evaluate("BETWEEN", &Value::Int(50), &Value::Int(5)),
            Some(false)
        );
        let short = Value::from(vec![10.0]);
        assert_eq!(
            registry.evaluate("BETWEEN", &Value::Int(50), &short),
            Some(false)
        );
    }
}