use std::collections::HashMap;
use std::sync::atomic::{AtomicU16, AtomicU8, Ordering};
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use super::error::VersionError;
use super::wire::{WireVersion, CURRENT_WIRE_VERSION};
/// Registry of gated features and the wire version each one requires.
///
/// Every entry pairs a feature name with its floor. [`VersionGate::new`]
/// copies this table into a per-gate map, and
/// [`VersionGate::can_use_feature`] refuses a listed feature while the
/// cluster-wide minimum version is below that floor.
pub const FEATURE_FLOORS: &[(&str, WireVersion)] = &[
    // Floor 1.0.
    ("baseline_v1_0", WireVersion::new(1, 0)),
    ("mutation.UpsertMemory", WireVersion::new(1, 0)),
    ("mutation.UpdateMemoryPatch", WireVersion::new(1, 0)),
    ("mutation.TombstoneMemory", WireVersion::new(1, 0)),
    ("mutation.PurgeMemory", WireVersion::new(1, 0)),
    ("mutation.UpsertEntityEdge", WireVersion::new(1, 0)),
    ("mutation.DeleteEntityEdge", WireVersion::new(1, 0)),
    ("mutation.TenantConfigPatch", WireVersion::new(1, 0)),
    // Floor 1.1.
    ("mutation.UpsertMemory.materialized", WireVersion::new(1, 1)),
    // Floor 1.2.
    ("mutation.TombstoneMemory.namespaced", WireVersion::new(1, 2)),
    ("mutation.DeleteEntityEdge.namespaced", WireVersion::new(1, 2)),
];
/// Maximum age of a peer observation: `VersionGate::evict_stale_peers` drops
/// every observation whose age is not strictly below this duration.
pub const PEER_OBSERVATION_TIMEOUT: Duration = Duration::from_secs(30);
/// The most recent wire version recorded for one peer, plus when it was recorded.
struct PeerObservation {
/// Wire version supplied for the peer in the latest `observe_peer` call.
version: WireVersion,
/// Monotonic timestamp of that call; compared against
/// `PEER_OBSERVATION_TIMEOUT` when stale peers are evicted.
last_seen: Instant,
}
/// Tracks the wire versions seen across the cluster and answers whether a
/// gated feature may be used yet.
///
/// The current minimum and maximum are cached in atomics so they can be read
/// without taking the `peers` lock. Each bound is split into a separate major
/// and minor atomic, so the two halves of a bound are not updated or read as
/// one atomic unit.
pub struct VersionGate {
// Wire version of this node; always included in the min/max computation.
local_wire: WireVersion,
// Cached lowest version among the local node and all recorded peers.
cluster_min_major: AtomicU8,
cluster_min_minor: AtomicU16,
// Cached highest version among the local node and all recorded peers.
cluster_max_major: AtomicU8,
cluster_max_minor: AtomicU16,
// Latest observation per peer id.
peers: RwLock<HashMap<u32, PeerObservation>>,
// Feature name -> minimum cluster version, seeded from `FEATURE_FLOORS`.
feature_floors: HashMap<&'static str, WireVersion>,
}
impl VersionGate {
pub fn new(local: WireVersion) -> Self {
let feature_floors = FEATURE_FLOORS
.iter()
.map(|(name, ver)| (*name, *ver))
.collect();
Self {
local_wire: local,
cluster_min_major: AtomicU8::new(local.major),
cluster_min_minor: AtomicU16::new(local.minor),
cluster_max_major: AtomicU8::new(local.major),
cluster_max_minor: AtomicU16::new(local.minor),
peers: RwLock::new(HashMap::new()),
feature_floors,
}
}
pub fn for_local_build() -> Self {
Self::new(CURRENT_WIRE_VERSION)
}
pub fn local_wire(&self) -> WireVersion {
self.local_wire
}
pub fn cluster_min(&self) -> WireVersion {
WireVersion::new(
self.cluster_min_major.load(Ordering::Relaxed),
self.cluster_min_minor.load(Ordering::Relaxed),
)
}
pub fn cluster_max(&self) -> WireVersion {
WireVersion::new(
self.cluster_max_major.load(Ordering::Relaxed),
self.cluster_max_minor.load(Ordering::Relaxed),
)
}
pub fn observe_peer(&self, peer_id: u32, version: WireVersion) {
{
let mut peers = self.peers.write();
peers.insert(
peer_id,
PeerObservation {
version,
last_seen: Instant::now(),
},
);
}
self.recompute_min_max();
}
pub fn evict_stale_peers(&self) -> usize {
let now = Instant::now();
let dropped = {
let mut peers = self.peers.write();
let before = peers.len();
peers.retain(|_, obs| now.duration_since(obs.last_seen) < PEER_OBSERVATION_TIMEOUT);
before - peers.len()
};
if dropped > 0 {
self.recompute_min_max();
}
dropped
}
fn recompute_min_max(&self) {
let peers = self.peers.read();
let mut min_seen = self.local_wire;
let mut max_seen = self.local_wire;
for obs in peers.values() {
if obs.version < min_seen {
min_seen = obs.version;
}
if obs.version > max_seen {
max_seen = obs.version;
}
}
self.cluster_min_major
.store(min_seen.major, Ordering::Relaxed);
self.cluster_min_minor
.store(min_seen.minor, Ordering::Relaxed);
self.cluster_max_major
.store(max_seen.major, Ordering::Relaxed);
self.cluster_max_minor
.store(max_seen.minor, Ordering::Relaxed);
}
pub fn can_use_feature(&self, feature: &'static str) -> Result<(), VersionError> {
let floor = match self.feature_floors.get(feature).copied() {
Some(f) => f,
None => return Ok(()),
};
let current_min = self.cluster_min();
if current_min < floor {
return Err(VersionError::FeatureGated {
feature,
requires: floor,
cluster_min: current_min,
});
}
Ok(())
}
pub fn peer_count(&self) -> usize {
self.peers.read().len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With no peers, both bounds equal the local version.
    #[test]
    fn empty_gate_reports_local_as_min_and_max() {
        let g = VersionGate::new(WireVersion::new(1, 5));
        assert_eq!(g.cluster_min(), WireVersion::new(1, 5));
        assert_eq!(g.cluster_max(), WireVersion::new(1, 5));
        assert_eq!(g.local_wire(), WireVersion::new(1, 5));
    }

    /// A peer older than the local node pulls the minimum down only.
    #[test]
    fn observing_lower_peer_lowers_min() {
        let g = VersionGate::new(WireVersion::new(1, 5));
        g.observe_peer(2, WireVersion::new(1, 0));
        assert_eq!(g.cluster_min(), WireVersion::new(1, 0));
        assert_eq!(g.cluster_max(), WireVersion::new(1, 5));
    }

    /// A peer newer than the local node pushes the maximum up only.
    #[test]
    fn observing_higher_peer_raises_max() {
        let g = VersionGate::new(WireVersion::new(1, 5));
        g.observe_peer(2, WireVersion::new(1, 9));
        assert_eq!(g.cluster_min(), WireVersion::new(1, 5));
        assert_eq!(g.cluster_max(), WireVersion::new(1, 9));
    }

    /// Evicting the only old peer restores the minimum to the local version.
    #[test]
    fn evicting_stale_peer_recomputes_min() {
        // `Instant - Duration` panics when the result is not representable,
        // which happens when the monotonic clock is younger than the
        // subtracted span. Back-date with `checked_sub` and skip in that case.
        let backdated = match Instant::now()
            .checked_sub(PEER_OBSERVATION_TIMEOUT + Duration::from_secs(1))
        {
            Some(instant) => instant,
            None => {
                eprintln!("skipping: monotonic clock too young to back-date an observation");
                return;
            }
        };

        let g = VersionGate::new(WireVersion::new(1, 5));
        {
            // Insert directly so the observation can carry a stale timestamp.
            let mut peers = g.peers.write();
            peers.insert(
                42,
                PeerObservation {
                    version: WireVersion::new(1, 0),
                    last_seen: backdated,
                },
            );
        }
        g.recompute_min_max();
        assert_eq!(g.cluster_min(), WireVersion::new(1, 0));

        let dropped = g.evict_stale_peers();
        assert_eq!(dropped, 1);
        assert_eq!(g.cluster_min(), WireVersion::new(1, 5));
    }

    /// A registered feature is allowed when the cluster minimum meets its floor.
    #[test]
    fn can_use_feature_allows_known_when_floor_satisfied() {
        let g = VersionGate::new(WireVersion::new(1, 5));
        assert!(g.can_use_feature("baseline_v1_0").is_ok());
    }

    /// A feature whose floor is above the cluster minimum is rejected with a
    /// fully populated `FeatureGated` error.
    #[test]
    fn can_use_feature_rejects_when_cluster_too_old() {
        // A 1.0 peer still satisfies the 1.0 baseline floor.
        let g = VersionGate::new(WireVersion::new(1, 5));
        g.observe_peer(2, WireVersion::new(1, 0));
        assert!(g.can_use_feature("baseline_v1_0").is_ok());

        // Inject a synthetic feature with a 1.3 floor, then drag the minimum
        // below it.
        let mut g2 = VersionGate::new(WireVersion::new(1, 5));
        g2.feature_floors
            .insert("synthetic_test_feature", WireVersion::new(1, 3));
        g2.observe_peer(99, WireVersion::new(1, 0));

        let err = g2.can_use_feature("synthetic_test_feature").unwrap_err();
        match err {
            VersionError::FeatureGated {
                feature,
                requires,
                cluster_min,
            } => {
                assert_eq!(feature, "synthetic_test_feature");
                assert_eq!(requires, WireVersion::new(1, 3));
                assert_eq!(cluster_min, WireVersion::new(1, 0));
            }
            other => panic!("wrong error: {other:?}"),
        }
    }

    /// Names missing from the floor table are not gated.
    #[test]
    fn unknown_features_are_allowed() {
        let g = VersionGate::new(WireVersion::new(1, 5));
        assert!(g.can_use_feature("not_in_registry").is_ok());
    }

    /// Peers upgrade one at a time; the minimum stays at the local 1.0
    /// because the local node itself never upgrades.
    #[test]
    fn rolling_upgrade_simulation() {
        let g = VersionGate::new(WireVersion::new(1, 0));
        g.observe_peer(2, WireVersion::new(1, 0));
        g.observe_peer(3, WireVersion::new(1, 0));
        assert_eq!(g.cluster_min(), WireVersion::new(1, 0));
        assert_eq!(g.cluster_max(), WireVersion::new(1, 0));

        g.observe_peer(2, WireVersion::new(1, 5));
        assert_eq!(g.cluster_min(), WireVersion::new(1, 0));
        assert_eq!(g.cluster_max(), WireVersion::new(1, 5));

        g.observe_peer(3, WireVersion::new(1, 5));
        assert_eq!(g.cluster_min(), WireVersion::new(1, 0));
    }

    /// Re-observing a known peer replaces its entry instead of adding one.
    #[test]
    fn peer_count_tracks_observation_set() {
        let g = VersionGate::new(WireVersion::new(1, 0));
        assert_eq!(g.peer_count(), 0);
        g.observe_peer(2, WireVersion::new(1, 0));
        assert_eq!(g.peer_count(), 1);
        g.observe_peer(3, WireVersion::new(1, 0));
        assert_eq!(g.peer_count(), 2);
        g.observe_peer(2, WireVersion::new(1, 1));
        assert_eq!(g.peer_count(), 2);
    }
}