use std::time::Duration;
use reddb_wire::topology::{
self as wire, decode_topology, Endpoint as WireEndpoint, ReplicaInfo, Topology,
};
/// Refresh interval (30 seconds) that `RefreshScheduler::new` uses when the
/// caller does not supply one.
pub const DEFAULT_REFRESH_INTERVAL: Duration = Duration::from_secs(30);
/// Endpoints the caller supplied up front (e.g. parsed from a connection URI),
/// before any topology has been advertised by the server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UriSeed {
    /// Address of the primary as given by the caller.
    pub primary: String,
    /// Replica addresses as given by the caller; may be empty.
    pub replicas: Vec<String>,
}

impl UriSeed {
    /// Builds a seed that names only a primary and no replicas.
    pub fn single(primary: impl Into<String>) -> Self {
        // A single-node seed is simply a cluster seed with an empty replica list.
        Self::cluster(primary, Vec::new())
    }

    /// Builds a seed from a primary address and an explicit replica list.
    ///
    /// The replica list is stored as passed in; no de-duplication or
    /// validation is performed.
    pub fn cluster(primary: impl Into<String>, replicas: Vec<String>) -> Self {
        let primary = primary.into();
        Self { primary, replicas }
    }
}
/// Cluster view returned by `TopologyConsumer::consume` and its wrappers.
///
/// All three fields are taken directly from the advertised `Topology`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClusterMembership {
/// Primary endpoint as advertised in the topology payload.
pub primary: WireEndpoint,
/// Replica entries as advertised in the topology payload.
pub replicas: Vec<ReplicaInfo>,
/// Epoch of the advertisement; compared by `TopologyConsumer::should_refresh`.
pub epoch: u64,
}
/// Reasons an advertised topology could not be turned into a
/// `ClusterMembership`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConsumeError {
    /// The blob ended before a complete value could be read.
    Truncated,
    /// The declared body length disagrees with the bytes actually available.
    BodyLengthMismatch { declared: u32, available: usize },
    /// A string field did not contain valid UTF-8.
    InvalidUtf8,
    /// A string field declared more bytes than remain in the body.
    StringTooLong { declared: u32, remaining: usize },
    /// The blob carries a version tag this client does not know.
    UnknownVersion,
    /// The HelloAck field was not decodable base64.
    MalformedEnvelope,
}

impl std::fmt::Display for ConsumeError {
    /// Writes the human-readable message for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants without payload have fixed messages, emitted verbatim.
            Self::Truncated => f.write_str("topology blob truncated"),
            Self::InvalidUtf8 => f.write_str("topology string field is not valid UTF-8"),
            Self::UnknownVersion => f.write_str(
                "topology wire version tag past MAX_KNOWN_TOPOLOGY_VERSION; falling back to URI-only routing",
            ),
            Self::MalformedEnvelope => f.write_str(
                "topology envelope (HelloAck base64) is malformed; falling back to URI-only routing",
            ),
            // Variants with payload interpolate their numbers into the message.
            Self::BodyLengthMismatch { declared, available } => write!(
                f,
                "topology body length mismatch (declared {declared}, available {available})"
            ),
            Self::StringTooLong { declared, remaining } => write!(
                f,
                "topology string length {declared} exceeds remaining body {remaining}"
            ),
        }
    }
}

impl std::error::Error for ConsumeError {}

impl ConsumeError {
    /// Returns `true` for the two variants whose messages announce a fallback
    /// to URI-only routing (`UnknownVersion`, `MalformedEnvelope`) and `false`
    /// for every other variant.
    pub fn is_recoverable(&self) -> bool {
        // Spelled out exhaustively so a newly added variant must be classified.
        match self {
            Self::UnknownVersion | Self::MalformedEnvelope => true,
            Self::Truncated
            | Self::BodyLengthMismatch { .. }
            | Self::InvalidUtf8
            | Self::StringTooLong { .. } => false,
        }
    }
}
impl From<wire::TopologyError> for ConsumeError {
fn from(e: wire::TopologyError) -> Self {
match e {
wire::TopologyError::Truncated => Self::Truncated,
wire::TopologyError::BodyLengthMismatch {
declared,
available,
} => Self::BodyLengthMismatch {
declared,
available,
},
wire::TopologyError::InvalidUtf8 => Self::InvalidUtf8,
wire::TopologyError::StringTooLong {
declared,
remaining,
} => Self::StringTooLong {
declared,
remaining,
},
}
}
}
#[derive(Debug, Default)]
pub struct TopologyConsumer;
impl TopologyConsumer {
pub fn consume(payload: Topology, uri_seed: Option<UriSeed>) -> ClusterMembership {
let Topology {
epoch,
primary,
replicas,
} = payload;
let _ = uri_seed;
ClusterMembership {
primary,
replicas,
epoch,
}
}
pub fn consume_bytes(
bytes: &[u8],
uri_seed: Option<UriSeed>,
) -> Result<ClusterMembership, ConsumeError> {
match decode_topology(bytes)? {
Some(t) => Ok(Self::consume(t, uri_seed)),
None => Err(ConsumeError::UnknownVersion),
}
}
pub fn consume_hello_ack(
field: &str,
uri_seed: Option<UriSeed>,
) -> Result<ClusterMembership, ConsumeError> {
match decode_base64(field) {
None => Err(ConsumeError::MalformedEnvelope),
Some(bytes) => Self::consume_bytes(&bytes, uri_seed),
}
}
pub fn should_refresh(current_epoch: u64, observed_epoch: u64) -> bool {
observed_epoch > current_epoch
}
}
/// Time source used by `RefreshScheduler`, abstracted so a fake clock can be
/// injected in place of `SystemClock`.
pub trait Clock: std::fmt::Debug {
/// Returns the current reading in milliseconds.
///
/// The origin is implementation-defined; `RefreshScheduler` only uses the
/// difference between two readings. Per the method name, readings are
/// expected not to go backwards.
fn now_monotonic_ms(&self) -> u64;
}
/// `Clock` backed by `std::time::Instant`, reporting milliseconds elapsed
/// since the clock value was created.
#[derive(Debug)]
pub struct SystemClock {
    /// Instant captured at construction; all readings are relative to it.
    origin: std::time::Instant,
}

impl Default for SystemClock {
    /// Creates a clock whose zero point is the moment of this call.
    fn default() -> Self {
        let origin = std::time::Instant::now();
        Self { origin }
    }
}

impl Clock for SystemClock {
    /// Milliseconds elapsed since this clock was constructed.
    fn now_monotonic_ms(&self) -> u64 {
        let elapsed = self.origin.elapsed();
        // `as_millis` returns u128; the narrowing only loses bits after
        // roughly 584 million years of elapsed time.
        elapsed.as_millis() as u64
    }
}
/// Tracks when the next periodic refresh is due.
///
/// A refresh is reported as due on the first query, whenever at least
/// `interval` has passed since the last `mark_refreshed`, or exactly once
/// after `force_now`.
#[derive(Debug)]
pub struct RefreshScheduler {
    /// Minimum time between two scheduled refreshes.
    interval: Duration,
    /// Time source; boxed so callers can inject their own `Clock`.
    clock: Box<dyn Clock + Send + Sync>,
    /// Clock reading taken by the last `mark_refreshed`; `None` until then.
    last_refresh_ms: Option<u64>,
    /// One-shot override armed by `force_now`.
    force: bool,
}

impl RefreshScheduler {
    /// Creates a scheduler with `DEFAULT_REFRESH_INTERVAL` and a `SystemClock`.
    pub fn new() -> Self {
        Self::with_interval(DEFAULT_REFRESH_INTERVAL)
    }

    /// Creates a scheduler with a custom interval and a `SystemClock`.
    pub fn with_interval(interval: Duration) -> Self {
        Self::with_interval_and_clock(interval, Box::new(SystemClock::default()))
    }

    /// Creates a scheduler with a custom interval and a caller-supplied clock.
    ///
    /// The scheduler starts with no recorded refresh and no pending force.
    pub fn with_interval_and_clock(
        interval: Duration,
        clock: Box<dyn Clock + Send + Sync>,
    ) -> Self {
        Self {
            interval,
            clock,
            last_refresh_ms: None,
            force: false,
        }
    }

    /// Reports whether a refresh is due right now.
    ///
    /// Returns `true` when a force is pending (and clears it), when no
    /// refresh has been recorded yet, or when at least `interval` has elapsed
    /// since the last recorded refresh. This method does not itself record a
    /// refresh; call `mark_refreshed` after refreshing.
    pub fn should_refresh_now(&mut self) -> bool {
        if self.force {
            // The override is one-shot: consume it on the query that honors it.
            self.force = false;
            return true;
        }
        let now = self.clock.now_monotonic_ms();
        // `as_millis` yields a u128. Clamp before narrowing so an interval
        // longer than u64::MAX ms means "effectively never" instead of
        // wrapping around to a short (possibly zero) interval.
        let interval_ms = self.interval.as_millis().min(u128::from(u64::MAX)) as u64;
        match self.last_refresh_ms {
            None => true,
            // Saturating: a clock reading below `last` counts as zero elapsed.
            Some(last) => now.saturating_sub(last) >= interval_ms,
        }
    }

    /// Records the current clock reading as the time of the last refresh.
    pub fn mark_refreshed(&mut self) {
        self.last_refresh_ms = Some(self.clock.now_monotonic_ms());
    }

    /// Arms a one-shot override so the next `should_refresh_now` returns
    /// `true` regardless of the interval.
    pub fn force_now(&mut self) {
        self.force = true;
    }
}

impl Default for RefreshScheduler {
    /// Equivalent to `RefreshScheduler::new`.
    fn default() -> Self {
        Self::new()
    }
}
/// Decodes a standard-alphabet base64 string (`A-Z`, `a-z`, `0-9`, `+`, `/`)
/// into bytes.
///
/// Trailing `=` padding is optional, so both padded and unpadded encodings
/// are accepted. Leftover low bits in the final group are ignored.
///
/// Returns `None` when the input cannot be a base64 encoding:
/// - it contains a character outside the alphabet (including whitespace or a
///   `=` that is not at the end),
/// - it ends in more than two `=` characters, or
/// - its unpadded length is 1 modulo 4, i.e. it ends in a lone 6-bit group
///   that cannot complete a byte (a sign the field was cut short).
fn decode_base64(input: &str) -> Option<Vec<u8>> {
    let trimmed = input.trim_end_matches('=');
    // Base64 pads a group with at most two '=' characters.
    if input.len() - trimmed.len() > 2 {
        return None;
    }
    // One dangling sextet carries 6 bits and would be silently dropped;
    // treat it as a truncated/malformed field instead.
    if trimmed.len() % 4 == 1 {
        return None;
    }
    let mut out = Vec::with_capacity(trimmed.len() * 3 / 4);
    // `buf` accumulates sextets; `bits` counts how many low bits of it are
    // not yet emitted. Older high bits shift out of the u32 harmlessly.
    let mut buf = 0u32;
    let mut bits = 0u8;
    for ch in trimmed.bytes() {
        let v: u32 = match ch {
            b'A'..=b'Z' => (ch - b'A') as u32,
            b'a'..=b'z' => (ch - b'a' + 26) as u32,
            b'0'..=b'9' => (ch - b'0' + 52) as u32,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        buf = (buf << 6) | v;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push(((buf >> bits) & 0xFF) as u8);
        }
    }
    Some(out)
}
// Unit tests for topology consumption, error classification and the refresh
// scheduler. Wire blobs are produced with the `reddb_wire` encoders or built
// by hand; scheduler tests use a fake clock so no real time passes.
#[cfg(test)]
mod tests {
use super::*;
use reddb_wire::topology::{
encode_topology, encode_topology_for_hello_ack, Endpoint as WireEndpoint, ReplicaInfo,
Topology, TOPOLOGY_HEADER_SIZE, TOPOLOGY_WIRE_VERSION_V1,
};
// Shared topology fixture: epoch 7, one primary, one healthy and one
// unhealthy replica.
fn fixture() -> Topology {
Topology {
epoch: 7,
primary: WireEndpoint {
addr: "primary.example.com:5050".into(),
region: "us-east-1".into(),
},
replicas: vec![
ReplicaInfo {
addr: "replica-a.example.com:5050".into(),
region: "us-east-1".into(),
healthy: true,
lag_ms: 12,
last_applied_lsn: 4242,
},
ReplicaInfo {
addr: "replica-b.example.com:5050".into(),
region: "us-west-2".into(),
healthy: false,
lag_ms: 999,
last_applied_lsn: 4100,
},
],
}
}
// Raw encoded bytes decode back to the fixture's epoch, primary and replicas.
#[test]
fn parse_round_trip_grpc_bytes() {
let t = fixture();
let bytes = encode_topology(&t);
let m = TopologyConsumer::consume_bytes(&bytes, None).expect("consume");
assert_eq!(m.epoch, 7);
assert_eq!(m.primary, t.primary);
assert_eq!(m.replicas, t.replicas);
}
// The HelloAck string form decodes back to the same membership.
#[test]
fn parse_round_trip_hello_ack_field() {
let t = fixture();
let field = encode_topology_for_hello_ack(&t);
let m = TopologyConsumer::consume_hello_ack(&field, None).expect("consume");
assert_eq!(m.epoch, 7);
assert_eq!(m.primary, t.primary);
assert_eq!(m.replicas, t.replicas);
}
// Encoding is deterministic, and the local base64 decoder recovers exactly
// the raw bytes from the HelloAck form.
#[test]
fn fixture_byte_stable_across_runs() {
let t = fixture();
let a = encode_topology(&t);
let b = encode_topology(&t);
assert_eq!(a, b);
let field = encode_topology_for_hello_ack(&t);
let recovered = decode_base64(&field).expect("base64");
assert_eq!(recovered, a);
}
// A replica named only in the URI seed must not appear in the membership.
#[test]
fn merge_uri_only_replicas_dropped() {
let t = fixture();
let seed = UriSeed::cluster(
"primary.example.com:5050".to_string(),
vec![
"replica-a.example.com:5050".into(),
"replica-b.example.com:5050".into(),
"replica-stale.example.com:5050".into(),
],
);
let m = TopologyConsumer::consume(t.clone(), Some(seed));
assert_eq!(m.replicas.len(), 2, "URI-only replica must be dropped");
assert!(
m.replicas
.iter()
.all(|r| r.addr != "replica-stale.example.com:5050"),
"stale URI replica must not appear in membership"
);
}
// Replicas that are advertised but absent from the seed are still included.
#[test]
fn merge_advertised_only_replicas_added() {
let t = fixture();
let seed = UriSeed::single("primary.example.com:5050");
let m = TopologyConsumer::consume(t.clone(), Some(seed));
assert_eq!(m.replicas.len(), 2);
assert_eq!(m.replicas, t.replicas);
}
// For a replica present in both seed and advertisement, the advertised
// metadata (region, health, lag, LSN) is what ends up in the membership.
#[test]
fn merge_intersection_uses_advertised_metadata() {
let t = fixture();
let seed = UriSeed::cluster(
"primary.example.com:5050".to_string(),
vec!["replica-a.example.com:5050".into()],
);
let m = TopologyConsumer::consume(t.clone(), Some(seed));
let merged_a = m
.replicas
.iter()
.find(|r| r.addr == "replica-a.example.com:5050")
.expect("advertised replica must be present");
assert_eq!(merged_a.region, "us-east-1");
assert!(merged_a.healthy);
assert_eq!(merged_a.lag_ms, 12);
assert_eq!(merged_a.last_applied_lsn, 4242);
}
// Without a seed the advertisement is passed through unchanged.
#[test]
fn merge_with_no_seed_keeps_full_advertisement() {
let t = fixture();
let m = TopologyConsumer::consume(t.clone(), None);
assert_eq!(m.primary, t.primary);
assert_eq!(m.replicas, t.replicas);
assert_eq!(m.epoch, t.epoch);
}
// Epoch comparison: equal -> no refresh.
#[test]
fn should_refresh_skips_same_epoch() {
assert!(!TopologyConsumer::should_refresh(7, 7));
}
// Epoch comparison: newer -> refresh.
#[test]
fn should_refresh_advances_on_higher_epoch() {
assert!(TopologyConsumer::should_refresh(7, 8));
}
// Epoch comparison: older -> no refresh.
#[test]
fn should_refresh_treats_lower_epoch_as_stale() {
assert!(!TopologyConsumer::should_refresh(7, 6));
}
// Two bytes are fewer than the version-byte-plus-u32-length header built in
// `oversize_string_field_returns_typed_error`; expected error is `Truncated`.
#[test]
fn malformed_truncated_blob_returns_typed_error() {
let err = TopologyConsumer::consume_bytes(&[0x01, 0x00], None).unwrap_err();
assert!(matches!(err, ConsumeError::Truncated));
assert!(!err.is_recoverable());
}
// Header declares a 0xFFFF_FFFF-byte body but only one body byte follows
// (first byte 0x01 is presumably the v1 version tag).
#[test]
fn malformed_body_length_mismatch_returns_typed_error() {
let bytes = vec![0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0x00];
let err = TopologyConsumer::consume_bytes(&bytes, None).unwrap_err();
assert!(matches!(err, ConsumeError::BodyLengthMismatch { .. }));
assert!(!err.is_recoverable());
}
// Overwriting the first byte (the version tag position used by the
// hand-built blobs below) with 0xFE must yield the recoverable
// `UnknownVersion`.
#[test]
fn unknown_version_tag_is_recoverable() {
let mut bytes = encode_topology(&fixture());
bytes[0] = 0xFE; let err = TopologyConsumer::consume_bytes(&bytes, None).unwrap_err();
assert!(matches!(err, ConsumeError::UnknownVersion));
assert!(err.is_recoverable());
}
// Characters outside the base64 alphabet are reported as the recoverable
// `MalformedEnvelope`.
#[test]
fn malformed_envelope_base64_is_recoverable() {
let err = TopologyConsumer::consume_hello_ack("@not base64@", None).unwrap_err();
assert!(matches!(err, ConsumeError::MalformedEnvelope));
assert!(err.is_recoverable());
}
// Hand-built v1 blob whose body is a u64 zero followed by a u32::MAX value
// (presumably the epoch and the first string's length prefix); the header
// is the version byte plus the little-endian u32 body length.
#[test]
fn oversize_string_field_returns_typed_error() {
let mut body = Vec::new();
body.extend_from_slice(&0u64.to_le_bytes()); body.extend_from_slice(&u32::MAX.to_le_bytes());
let mut bytes = Vec::new();
bytes.push(TOPOLOGY_WIRE_VERSION_V1);
bytes.extend_from_slice(&(body.len() as u32).to_le_bytes());
bytes.extend_from_slice(&body);
assert_eq!(bytes.len(), TOPOLOGY_HEADER_SIZE + body.len());
let err = TopologyConsumer::consume_bytes(&bytes, None).unwrap_err();
assert!(
matches!(err, ConsumeError::StringTooLong { .. }),
"expected StringTooLong, got {err:?}"
);
assert!(!err.is_recoverable());
}
// Hand-built v1 blob: u64 zero, a u32 length of 2, then the two bytes
// 0xFF 0xFE, which are not valid UTF-8 (presumably epoch + first string).
#[test]
fn invalid_utf8_string_returns_typed_error() {
let mut body = Vec::new();
body.extend_from_slice(&0u64.to_le_bytes()); body.extend_from_slice(&2u32.to_le_bytes()); body.extend_from_slice(&[0xFF, 0xFE]); let mut bytes = Vec::new();
bytes.push(TOPOLOGY_WIRE_VERSION_V1);
bytes.extend_from_slice(&(body.len() as u32).to_le_bytes());
bytes.extend_from_slice(&body);
let err = TopologyConsumer::consume_bytes(&bytes, None).unwrap_err();
assert!(
matches!(err, ConsumeError::InvalidUtf8),
"expected InvalidUtf8, got {err:?}"
);
}
// Smoke test: buffers of 0..=9 bytes of 0xAA must not panic; the result
// itself is ignored.
#[test]
fn consume_does_not_panic_on_any_random_short_buffer() {
for n in 0..10usize {
let bytes = vec![0xAAu8; n];
let _ = TopologyConsumer::consume_bytes(&bytes, None);
}
}
// Manually advanced clock; starts at 0 ms.
#[derive(Debug)]
struct FakeClock {
ms: std::sync::Mutex<u64>,
}
impl FakeClock {
fn new() -> Self {
Self {
ms: std::sync::Mutex::new(0),
}
}
// Moves the fake time forward by `by` (millisecond resolution).
fn advance(&self, by: Duration) {
*self.ms.lock().unwrap() += by.as_millis() as u64;
}
}
impl Clock for FakeClock {
fn now_monotonic_ms(&self) -> u64 {
*self.ms.lock().unwrap()
}
}
// Builds a scheduler driven by a shared `FakeClock`, so the test keeps a
// handle for advancing time while the scheduler owns a boxed wrapper.
fn scheduler_with(clock: std::sync::Arc<FakeClock>, interval: Duration) -> RefreshScheduler {
#[derive(Debug)]
struct ArcClock(std::sync::Arc<FakeClock>);
impl Clock for ArcClock {
fn now_monotonic_ms(&self) -> u64 {
self.0.now_monotonic_ms()
}
}
RefreshScheduler::with_interval_and_clock(interval, Box::new(ArcClock(clock)))
}
// A fresh scheduler reports a refresh as due on the very first query.
#[test]
fn fake_clock_first_call_refreshes_immediately() {
let clock = std::sync::Arc::new(FakeClock::new());
let mut s = scheduler_with(clock.clone(), Duration::from_secs(30));
assert!(s.should_refresh_now(), "first call must refresh");
}
// Not due at 29.999 s after the last refresh; due at 30.001 s.
#[test]
fn fake_clock_thirty_second_interval_holds_without_real_waits() {
let clock = std::sync::Arc::new(FakeClock::new());
let mut s = scheduler_with(clock.clone(), Duration::from_secs(30));
assert!(s.should_refresh_now());
s.mark_refreshed();
clock.advance(Duration::from_millis(29_999));
assert!(
!s.should_refresh_now(),
"must not refresh before interval elapsed"
);
clock.advance(Duration::from_millis(2));
assert!(
s.should_refresh_now(),
"must refresh once interval has elapsed"
);
}
// `force_now` makes the next query return true once; afterwards the
// interval applies again.
#[test]
fn fake_clock_force_now_overrides_interval() {
let clock = std::sync::Arc::new(FakeClock::new());
let mut s = scheduler_with(clock.clone(), Duration::from_secs(30));
assert!(s.should_refresh_now());
s.mark_refreshed();
clock.advance(Duration::from_millis(100));
assert!(!s.should_refresh_now());
s.force_now();
assert!(s.should_refresh_now(), "force_now must override the timer");
s.mark_refreshed();
clock.advance(Duration::from_millis(100));
assert!(!s.should_refresh_now());
}
// Pins the documented default so a change to the constant is deliberate.
#[test]
fn default_interval_is_thirty_seconds() {
assert_eq!(DEFAULT_REFRESH_INTERVAL, Duration::from_secs(30));
}
}