use std::collections::HashSet;
use std::sync::Arc;
use serde_json::Value;
use super::BoxFuture;
use crate::state::State;
/// Signature of a user-supplied predicate for [`WatchPredicate::Custom`].
///
/// Invoked as `f(old, new)` with the value before and after a change;
/// returning `true` makes the watcher fire.
pub type PredicateFn = Arc<dyn Fn(&Value, &Value) -> bool + Send + Sync>;

/// Condition deciding whether a watcher fires for an `old -> new` transition
/// of its key.
///
/// Cloning copies any JSON payload and shares (rather than duplicates) a
/// custom closure through its `Arc`.
#[derive(Clone)]
pub enum WatchPredicate {
    /// Fires for every transition it is asked about.
    Changed,
    /// Fires when the new value equals the given value.
    ChangedTo(Value),
    /// Fires when the old value equals the given value.
    ChangedFrom(Value),
    /// Fires when a numeric value moves from below the threshold to at or
    /// above it. Non-numeric old or new values never match.
    CrossedAbove(f64),
    /// Fires when a numeric value moves from at or above the threshold to
    /// below it. Non-numeric old or new values never match.
    CrossedBelow(f64),
    /// Fires when the value becomes JSON `true` from anything else.
    BecameTrue,
    /// Fires when the value was JSON `true` and is now anything else.
    BecameFalse,
    /// Fires when the wrapped function returns `true` for `(old, new)`.
    Custom(PredicateFn),
}
impl std::fmt::Debug for WatchPredicate {
    /// Writes the variant name, followed by its payload in parentheses.
    ///
    /// JSON payloads and thresholds are rendered with their `Display` form
    /// (so a `ChangedTo` holding the JSON string "on" prints as
    /// `ChangedTo("on")`). A custom closure cannot be inspected, so it prints
    /// as `Custom(<fn>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Payload-free variants: emit the bare name.
            Self::Changed => f.write_str("Changed"),
            Self::BecameTrue => f.write_str("BecameTrue"),
            Self::BecameFalse => f.write_str("BecameFalse"),
            Self::Custom(_) => f.write_str("Custom(<fn>)"),
            // Variants carrying a value or threshold.
            Self::ChangedTo(target) => write!(f, "ChangedTo({target})"),
            Self::ChangedFrom(source) => write!(f, "ChangedFrom({source})"),
            Self::CrossedAbove(limit) => write!(f, "CrossedAbove({limit})"),
            Self::CrossedBelow(limit) => write!(f, "CrossedBelow({limit})"),
        }
    }
}
impl WatchPredicate {
    /// Reports whether the transition `old -> new` satisfies this predicate.
    ///
    /// * Threshold variants compare numerically, and only when both sides are
    ///   JSON numbers; any other pairing is a non-match. The threshold itself
    ///   counts as "above", so reaching it exactly triggers `CrossedAbove`, and
    ///   dropping below it from exactly on it triggers `CrossedBelow`.
    /// * `BecameTrue`/`BecameFalse` look only for the literal JSON `true`.
    ///   Every other value (`false`, `null`, `0`, `"true"`, ...) counts as
    ///   "not true".
    /// * `Changed` matches unconditionally. Whether `old` and `new` actually
    ///   differ is not re-checked here.
    fn matches(&self, old: &Value, new: &Value) -> bool {
        match self {
            Self::Changed => true,
            Self::ChangedTo(target) => target == new,
            Self::ChangedFrom(source) => source == old,
            Self::CrossedAbove(limit) => matches!(
                (as_f64(old), as_f64(new)),
                (Some(before), Some(after)) if before < *limit && after >= *limit
            ),
            Self::CrossedBelow(limit) => matches!(
                (as_f64(old), as_f64(new)),
                (Some(before), Some(after)) if before >= *limit && after < *limit
            ),
            Self::BecameTrue => {
                !matches!(old, Value::Bool(true)) && matches!(new, Value::Bool(true))
            }
            Self::BecameFalse => {
                matches!(old, Value::Bool(true)) && !matches!(new, Value::Bool(true))
            }
            Self::Custom(predicate) => predicate(old, new),
        }
    }
}
/// Callback run when a [`Watcher`] fires.
///
/// It is called with the old value, the new value and a `State`, and returns
/// the boxed future that performs the watcher's work.
pub type WatchAction = Arc<dyn Fn(Value, Value, State) -> BoxFuture<()> + Send + Sync>;

/// A callback bound to one state key and gated by a [`WatchPredicate`].
pub struct Watcher {
    /// State key whose diffs this watcher is matched against.
    pub key: String,
    /// Condition the `old -> new` transition must satisfy for `action` to run.
    pub predicate: WatchPredicate,
    /// Callback producing the future to run when the watcher fires.
    pub action: WatchAction,
    /// When `true`, the action's future is placed in the "blocking" vector
    /// returned by `WatcherRegistry::evaluate`; otherwise it goes in the
    /// "concurrent" one. How each group is driven is up to the caller
    /// (NOTE(review): presumably blocking futures are awaited before
    /// proceeding; confirm with the caller).
    pub blocking: bool,
}
impl std::fmt::Debug for Watcher {
    /// Shows the watcher's key, predicate and blocking flag.
    ///
    /// The action closure has no useful representation, so it is left out and
    /// the output ends with `..` to mark the omission.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self {
            key,
            predicate,
            blocking,
            action: _,
        } = self;
        let mut out = f.debug_struct("Watcher");
        out.field("key", key);
        out.field("predicate", predicate);
        out.field("blocking", blocking);
        out.finish_non_exhaustive()
    }
}
pub struct WatcherRegistry {
watchers: Vec<Watcher>,
observed_keys: HashSet<String>,
}
impl Default for WatcherRegistry {
fn default() -> Self {
Self::new()
}
}
impl WatcherRegistry {
    /// Creates a registry with no watchers and no observed keys.
    #[must_use]
    pub fn new() -> Self {
        Self {
            watchers: Vec::new(),
            observed_keys: HashSet::new(),
        }
    }

    /// Registers `watcher` and records its key as observed.
    ///
    /// Several watchers may share a key; they are evaluated in the order they
    /// were added.
    pub fn add(&mut self, watcher: Watcher) {
        self.observed_keys.insert(watcher.key.clone());
        self.watchers.push(watcher);
    }

    /// Returns the set of keys watched by at least one registered watcher.
    pub fn observed_keys(&self) -> &HashSet<String> {
        &self.observed_keys
    }

    /// Matches each `(key, old, new)` diff against the registered watchers and
    /// collects the futures of every watcher whose key and predicate match.
    ///
    /// Returns `(blocking, concurrent)`: futures from watchers with
    /// `blocking == true` go in the first vector, all others in the second.
    /// Within each vector, futures are ordered by diff, then by watcher
    /// registration order. Each action is called with clones of `old`, `new`
    /// and `state`.
    ///
    /// Each action closure is invoked here to build its future, but the work
    /// inside that future only happens when the caller polls it. Dropping the
    /// result therefore silently skips the watchers' work, hence `#[must_use]`.
    #[must_use = "watcher actions do not run until the returned futures are awaited"]
    pub fn evaluate(
        &self,
        diffs: &[(String, Value, Value)],
        state: &State,
    ) -> (Vec<BoxFuture<()>>, Vec<BoxFuture<()>>) {
        let mut blocking = Vec::new();
        let mut concurrent = Vec::new();
        for (key, old, new) in diffs {
            // A set lookup lets diffs for unwatched keys skip the watcher scan.
            if !self.observed_keys.contains(key) {
                continue;
            }
            for watcher in self.watchers.iter().filter(|w| w.key == *key) {
                if !watcher.predicate.matches(old, new) {
                    continue;
                }
                let fut = (watcher.action)(old.clone(), new.clone(), state.clone());
                if watcher.blocking {
                    blocking.push(fut);
                } else {
                    concurrent.push(fut);
                }
            }
        }
        (blocking, concurrent)
    }
}
fn as_f64(v: &Value) -> Option<f64> {
match v {
Value::Number(n) => n.as_f64(),
_ => None,
}
}
/// Unit tests for predicate matching and registry evaluation.
///
/// Async tests await the returned futures to observe the actions' side
/// effects. Synchronous tests only inspect which futures were produced, or
/// call `WatchPredicate::matches` directly.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds a watcher whose action increments `counter` each time it runs.
    fn counting_watcher(
        key: &str,
        predicate: WatchPredicate,
        counter: Arc<AtomicU32>,
        blocking: bool,
    ) -> Watcher {
        Watcher {
            key: key.to_string(),
            predicate,
            action: Arc::new(move |_old, _new, _state| {
                let c = counter.clone();
                Box::pin(async move {
                    c.fetch_add(1, Ordering::SeqCst);
                })
            }),
            blocking,
        }
    }

    /// Builds a watcher whose action stores the old/new values it received in
    /// the state under `recorded_old` / `recorded_new`.
    fn recording_watcher(key: &str, predicate: WatchPredicate, blocking: bool) -> Watcher {
        Watcher {
            key: key.to_string(),
            predicate,
            action: Arc::new(|old, new, state| {
                Box::pin(async move {
                    state.set("recorded_old", old);
                    state.set("recorded_new", new);
                })
            }),
            blocking,
        }
    }

    #[tokio::test]
    async fn changed_fires_on_any_diff() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("x".to_string(), json!(1), json!(2))];
        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert_eq!(concurrent.len(), 1);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn changed_to_fires_when_new_value_matches() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "status",
            WatchPredicate::ChangedTo(json!("active")),
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("status".to_string(), json!("inactive"), json!("active"))];
        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn changed_to_does_not_fire_when_new_value_differs() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "status",
            WatchPredicate::ChangedTo(json!("active")),
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("status".to_string(), json!("inactive"), json!("pending"))];
        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn changed_from_fires_when_old_value_matches() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "mode",
            WatchPredicate::ChangedFrom(json!("draft")),
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("mode".to_string(), json!("draft"), json!("published"))];
        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        let diffs2 = vec![("mode".to_string(), json!("published"), json!("archived"))];
        let (b, c) = registry.evaluate(&diffs2, &state);
        assert!(b.is_empty());
        assert!(c.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn crossed_above_fires_on_upward_crossing() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "temp",
            WatchPredicate::CrossedAbove(100.0),
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("temp".to_string(), json!(95.0), json!(105.0))];
        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn crossed_above_does_not_fire_when_both_above() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "temp",
            WatchPredicate::CrossedAbove(100.0),
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("temp".to_string(), json!(110.0), json!(120.0))];
        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn crossed_below_fires_on_downward_crossing() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "battery",
            WatchPredicate::CrossedBelow(20.0),
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("battery".to_string(), json!(25.0), json!(15.0))];
        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn crossed_below_does_not_fire_when_both_below() {
        assert!(!WatchPredicate::CrossedBelow(20.0).matches(&json!(15), &json!(10)));
    }

    #[test]
    fn threshold_itself_counts_as_above() {
        // Reaching the threshold exactly is an upward crossing...
        assert!(WatchPredicate::CrossedAbove(100.0).matches(&json!(99), &json!(100)));
        // ...but starting on it means the value was never below.
        assert!(!WatchPredicate::CrossedAbove(100.0).matches(&json!(100), &json!(101)));
        // Leaving the threshold downwards is a downward crossing...
        assert!(WatchPredicate::CrossedBelow(20.0).matches(&json!(20), &json!(19.5)));
        // ...and arriving on it from below is not.
        assert!(!WatchPredicate::CrossedBelow(20.0).matches(&json!(19), &json!(20)));
    }

    #[tokio::test]
    async fn became_true_fires_on_false_to_true() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "flag",
            WatchPredicate::BecameTrue,
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("flag".to_string(), json!(false), json!(true))];
        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn became_false_fires_on_true_to_false() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "flag",
            WatchPredicate::BecameFalse,
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("flag".to_string(), json!(true), json!(false))];
        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn became_predicates_treat_non_true_values_as_not_true() {
        assert!(WatchPredicate::BecameTrue.matches(&Value::Null, &json!(true)));
        assert!(WatchPredicate::BecameFalse.matches(&json!(true), &Value::Null));
        assert!(!WatchPredicate::BecameFalse.matches(&json!(false), &json!(0)));
    }

    #[tokio::test]
    async fn custom_predicate_fires_when_fn_returns_true() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "score",
            WatchPredicate::Custom(Arc::new(|old, new| {
                match (as_f64(old), as_f64(new)) {
                    (Some(o), Some(n)) => (n - o * 2.0).abs() < f64::EPSILON,
                    _ => false,
                }
            })),
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("score".to_string(), json!(5.0), json!(10.0))];
        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        let diffs2 = vec![("score".to_string(), json!(5.0), json!(11.0))];
        let (b, c) = registry.evaluate(&diffs2, &state);
        assert!(b.is_empty());
        assert!(c.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn evaluate_separates_blocking_and_concurrent() {
        let blocking_counter = Arc::new(AtomicU32::new(0));
        let concurrent_counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            blocking_counter.clone(),
            true,
        ));
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            concurrent_counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("x".to_string(), json!(1), json!(2))];
        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(blocking.len(), 1);
        assert_eq!(concurrent.len(), 1);
        for fut in blocking {
            fut.await;
        }
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(blocking_counter.load(Ordering::SeqCst), 1);
        assert_eq!(concurrent_counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn evaluate_with_no_matching_diffs_returns_empty() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("y".to_string(), json!(1), json!(2))];
        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn observed_keys_tracks_added_watcher_keys() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        assert!(registry.observed_keys().is_empty());
        registry.add(counting_watcher(
            "alpha",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));
        registry.add(counting_watcher(
            "beta",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));
        registry.add(counting_watcher(
            "alpha",
            WatchPredicate::BecameTrue,
            counter.clone(),
            true,
        ));
        let keys = registry.observed_keys();
        assert_eq!(keys.len(), 2);
        assert!(keys.contains("alpha"));
        assert!(keys.contains("beta"));
    }

    #[tokio::test]
    async fn multiple_watchers_on_same_key() {
        let counter_a = Arc::new(AtomicU32::new(0));
        let counter_b = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter_a.clone(),
            false,
        ));
        registry.add(counting_watcher(
            "x",
            WatchPredicate::ChangedTo(json!(42)),
            counter_b.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("x".to_string(), json!(1), json!(42))];
        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 2);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter_a.load(Ordering::SeqCst), 1);
        assert_eq!(counter_b.load(Ordering::SeqCst), 1);
        let diffs2 = vec![("x".to_string(), json!(42), json!(99))];
        let (_, concurrent2) = registry.evaluate(&diffs2, &state);
        assert_eq!(concurrent2.len(), 1);
        for fut in concurrent2 {
            fut.await;
        }
        assert_eq!(counter_a.load(Ordering::SeqCst), 2);
        assert_eq!(counter_b.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn action_receives_old_new_and_state() {
        let mut registry = WatcherRegistry::new();
        registry.add(recording_watcher("val", WatchPredicate::Changed, false));
        let state = State::new();
        let diffs = vec![("val".to_string(), json!("before"), json!("after"))];
        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(state.get_raw("recorded_old"), Some(json!("before")));
        assert_eq!(state.get_raw("recorded_new"), Some(json!("after")));
    }

    #[test]
    fn crossed_above_with_non_numeric_values_does_not_fire() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::CrossedAbove(10.0),
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("x".to_string(), json!("low"), json!("high"))];
        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn became_true_does_not_fire_on_non_bool() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::BecameTrue,
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs = vec![("x".to_string(), json!(0), json!("true"))];
        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn empty_diffs_produce_no_futures() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));
        let state = State::new();
        let diffs: Vec<(String, Value, Value)> = vec![];
        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn default_creates_empty_registry() {
        let registry = WatcherRegistry::default();
        assert!(registry.observed_keys().is_empty());
    }

    #[test]
    fn predicate_debug_output_is_readable() {
        assert_eq!(
            format!("{:?}", WatchPredicate::ChangedTo(json!("on"))),
            "ChangedTo(\"on\")"
        );
        assert_eq!(
            format!("{:?}", WatchPredicate::CrossedAbove(1.5)),
            "CrossedAbove(1.5)"
        );
        assert_eq!(
            format!(
                "{:?}",
                WatchPredicate::Custom(Arc::new(|_: &Value, _: &Value| true))
            ),
            "Custom(<fn>)"
        );
    }
}