use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_nats::connection::State;
use tokio::time::Instant;
/// Records when the NATS client last reported a `Connected` event.
///
/// Cloning is cheap: every clone shares the same state through the inner `Arc`.
#[derive(Clone, Debug, Default)]
pub struct Tracker {
// Shared state; the callback built by `on_event` holds its own clone of this `Arc`.
inner: Arc<TrackerInner>,
}
/// State shared by every clone of `Tracker` and by the event callbacks it hands out.
#[derive(Debug, Default)]
struct TrackerInner {
/// Time at which the most recent `async_nats::Event::Connected` was observed;
/// `None` (the `Default`) until the first such event has been recorded.
last_connected_at: Mutex<Option<Instant>>,
}
impl Tracker {
pub fn new() -> Self {
Self::default()
}
pub fn on_event(
&self,
) -> impl Fn(async_nats::Event) -> std::future::Ready<()> + Send + Sync + 'static + use<> {
let inner = self.inner.clone();
move |event| {
if matches!(event, async_nats::Event::Connected) {
if let Ok(mut g) = inner.last_connected_at.lock() {
*g = Some(Instant::now());
}
}
std::future::ready(())
}
}
pub fn staleness(&self, client: &async_nats::Client) -> Duration {
if client.connection_state() == State::Connected {
return Duration::ZERO;
}
match self.inner.last_connected_at.lock().ok().and_then(|g| *g) {
Some(t) => Instant::now().saturating_duration_since(t),
None => Duration::MAX,
}
}
}
/// Applies a staleness `policy` to an observed `staleness` value.
///
/// * `Unchecked` and `Cached` always yield [`StalenessDecision::Proceed`].
/// * `Strict` parses `max_cache_age` with `humantime::parse_duration` and
///   proceeds only when `staleness` does not exceed it (the bound is
///   inclusive); otherwise it yields [`StalenessDecision::Skip`] carrying both
///   durations.
///
/// An unparsable `max_cache_age` is treated as a zero allowance, so the
/// decision fails closed for any non-zero staleness.
pub fn decide(policy: &kanade_shared::wire::Staleness, staleness: Duration) -> StalenessDecision {
    use kanade_shared::wire::Staleness;

    // Only the strict policy carries a bound; everything else is unconditional.
    let max_cache_age = match policy {
        Staleness::Strict { max_cache_age } => max_cache_age,
        Staleness::Unchecked | Staleness::Cached => return StalenessDecision::Proceed,
    };

    let allowed = match humantime::parse_duration(max_cache_age) {
        Ok(parsed) => parsed,
        // Fail closed: a bound we cannot read allows no staleness at all.
        Err(_) => Duration::ZERO,
    };

    if staleness > allowed {
        StalenessDecision::Skip {
            observed: staleness,
            allowed,
        }
    } else {
        StalenessDecision::Proceed
    }
}
/// Outcome of applying a staleness policy via [`decide`].
///
/// The type is a plain value (at most two `Duration`s), so it is `Copy`:
/// callers can log a decision and keep using it without moving it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StalenessDecision {
    /// The observed staleness is acceptable under the policy.
    Proceed,
    /// The observed staleness exceeds what the policy allows.
    Skip {
        /// Staleness that was observed when the decision was made.
        observed: Duration,
        /// Maximum staleness the policy allowed.
        allowed: Duration,
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    use kanade_shared::wire::Staleness;

    /// Shorthand for a `Strict` policy with the given `max_cache_age` text.
    fn strict(max_cache_age: &'static str) -> Staleness {
        Staleness::Strict {
            max_cache_age: max_cache_age.into(),
        }
    }

    // `Cached` ignores the observed staleness, at either extreme.
    #[test]
    fn cached_always_proceeds() {
        for observed in [Duration::MAX, Duration::ZERO] {
            assert_eq!(
                decide(&Staleness::Cached, observed),
                StalenessDecision::Proceed,
            );
        }
    }

    // `Unchecked` proceeds even at the maximum staleness.
    #[test]
    fn unchecked_always_proceeds() {
        let decision = decide(&Staleness::Unchecked, Duration::MAX);
        assert_eq!(decision, StalenessDecision::Proceed);
    }

    // Zero staleness satisfies a zero allowance (the bound is inclusive).
    #[test]
    fn strict_zero_max_proceeds_when_currently_connected() {
        let decision = decide(&strict("0s"), Duration::ZERO);
        assert_eq!(decision, StalenessDecision::Proceed);
    }

    // Any non-zero staleness breaks a zero allowance, and both values are reported.
    #[test]
    fn strict_zero_max_skips_when_any_disconnect() {
        let (observed, allowed) = match decide(&strict("0s"), Duration::from_secs(1)) {
            StalenessDecision::Skip { observed, allowed } => (observed, allowed),
            other => panic!("expected Skip, got {other:?}"),
        };
        assert_eq!(observed, Duration::from_secs(1));
        assert_eq!(allowed, Duration::ZERO);
    }

    // Exactly at the bound proceeds; one second past it skips.
    #[test]
    fn strict_window_inclusive_boundary() {
        let five_minutes = strict("5m");
        let at_bound = decide(&five_minutes, Duration::from_secs(300));
        let past_bound = decide(&five_minutes, Duration::from_secs(301));
        assert_eq!(at_bound, StalenessDecision::Proceed);
        assert!(matches!(past_bound, StalenessDecision::Skip { .. }));
    }

    // `Duration::MAX` (what the tracker reports when it has no timestamp) always skips.
    #[test]
    fn strict_never_connected_skips() {
        let decision = decide(&strict("1h"), Duration::MAX);
        assert!(matches!(decision, StalenessDecision::Skip { .. }));
    }

    // An unparsable bound must not let stale work through.
    #[test]
    fn strict_bogus_max_cache_age_fails_closed() {
        let decision = decide(&strict("notaduration"), Duration::from_secs(1));
        assert!(matches!(decision, StalenessDecision::Skip { .. }));
    }
}