use std::io::BufReader;
use std::net::{SocketAddr, TcpStream};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant, SystemTime};
use crate::error::{DbError, DbResult};
use crate::shutdown::ShutdownSignal;
use super::apply::{ApplyOutcome, FixedReplicationTarget};
use super::cursor::{CURSOR_SAVE_INTERVAL, FixedReplicationCursor};
use super::protocol::*;
/// Delay, in milliseconds, before a shard worker's first reconnect attempt;
/// the worker doubles it after every failed session.
pub const RECONNECT_BASE_MS: u64 = 1_000;
/// Ceiling, in milliseconds, that the doubling reconnect delay is clamped to.
pub const RECONNECT_MAX_MS: u64 = 30 * 1_000;
/// Handle to the follower-side replication workers, one thread per shard.
///
/// Created by `FixedReplicationClient::start`. Dropping the handle signals
/// shutdown and joins every worker thread.
pub struct FixedReplicationClient {
/// Shutdown signal shared with every worker thread.
stop: ShutdownSignal,
/// Join handles of the spawned worker threads, drained when the client is dropped.
handles: Vec<JoinHandle<()>>,
}
impl FixedReplicationClient {
    /// Spawns one replication worker thread per shard reported by `target`.
    ///
    /// Each worker runs `run_shard_client` against `leader_addr` with its own
    /// clone of `target` and `signal`. The returned client keeps `signal` and
    /// the join handles so the workers can be stopped and joined later.
    ///
    /// # Errors
    ///
    /// Fails with an `InvalidInput` I/O error (converted into the crate error
    /// type) when `target` reports more than 256 shards. Worker shard ids are
    /// `u8`, so a larger count would wrap and make several workers share one
    /// shard id. No thread is spawned in that case.
    pub fn start(
        leader_addr: SocketAddr,
        target: Arc<dyn FixedReplicationTarget>,
        signal: ShutdownSignal,
    ) -> DbResult<Self> {
        let shard_count = target.shard_count();
        let max_shards = usize::from(u8::MAX) + 1;
        if shard_count > max_shards {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!(
                    "shard count {shard_count} exceeds the {max_shards} shards addressable by a u8 shard id"
                ),
            )
            .into());
        }

        let mut handles = Vec::with_capacity(shard_count);
        // Iterating a `u8` range yields correctly typed ids without a lossy
        // `as` cast; the check above guarantees `take` never cuts ids short.
        for shard_id in (0..=u8::MAX).take(shard_count) {
            let target = Arc::clone(&target);
            let stop = signal.clone();
            handles.push(thread::spawn(move || {
                run_shard_client(leader_addr, shard_id, target, stop);
            }));
        }

        Ok(Self {
            stop: signal,
            handles,
        })
    }

    /// Signals shutdown to the workers without waiting for them to exit.
    ///
    /// The threads are only joined when the client is dropped.
    pub fn stop(&self) {
        self.stop.shutdown();
    }
}
impl Drop for FixedReplicationClient {
    /// Signals shutdown, then blocks until every worker thread has exited.
    ///
    /// The result of each join is discarded, so a worker that panicked does
    /// not propagate its panic into the dropping thread.
    fn drop(&mut self) {
        self.stop.shutdown();
        self.handles
            .drain(..)
            .for_each(|worker| drop(worker.join()));
    }
}
fn run_shard_client(
leader_addr: SocketAddr,
shard_id: u8,
target: Arc<dyn FixedReplicationTarget>,
stop: ShutdownSignal,
) {
let mut backoff_ms = RECONNECT_BASE_MS;
loop {
if stop.is_shutdown() {
return;
}
match connect_and_sync(leader_addr, shard_id, &target, &stop) {
Ok(()) => {
tracing::info!(shard_id, "fixed replication stream ended cleanly");
return;
}
Err(e) => {
tracing::warn!(
shard_id,
error = %e,
backoff_ms,
"fixed replication error, reconnecting"
);
metrics::counter!(
"armdb.fixed.reconnect_count",
"shard" => shard_id.to_string()
)
.increment(1);
if stop.wait_timeout(Duration::from_millis(backoff_ms)) {
return;
}
backoff_ms = (backoff_ms * 2).min(RECONNECT_MAX_MS);
}
}
}
}
/// Opens one replication session for `shard_id` and runs it to completion.
///
/// Steps:
/// 1. Connect to `leader_addr` and send a `SyncRequest`, flagged with
///    `FLAG_EMPTY_STATE` when the local shard holds no occupied slots.
/// 2. Read the leader's `ShardInfo` and check that shard count, key/value
///    sizes and shard prefix bits match the local `target`.
/// 3. Grow the local shard to the leader's slot count and hand the
///    connection to `apply_loop`.
///
/// # Errors
///
/// Returns an error when connecting or any frame I/O fails, when the leader
/// does not answer the handshake within the timeout, when the first frame is
/// not `ShardInfo`, when the leader's layout differs from the follower's, or
/// when `grow_shard_to` / `apply_loop` fail.
fn connect_and_sync(
    leader_addr: SocketAddr,
    shard_id: u8,
    target: &Arc<dyn FixedReplicationTarget>,
    stop: &ShutdownSignal,
) -> DbResult<()> {
    /// Upper bound for establishing the connection and for each handshake read.
    const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5);

    let stream = TcpStream::connect_timeout(&leader_addr, HANDSHAKE_TIMEOUT)?;
    let _ = stream.set_nodelay(true);
    // Without a read timeout a leader that accepts the connection but never
    // replies would block this worker forever, and shutdown joins the workers.
    // `apply_loop` installs its own read timeout once the handshake is done.
    stream.set_read_timeout(Some(HANDSHAKE_TIMEOUT))?;
    let mut writer = stream.try_clone()?;
    let mut reader = BufReader::new(stream);

    let mut flags = 0u8;
    if target.shard_occupied_count(shard_id) == 0 {
        flags |= FLAG_EMPTY_STATE;
    }
    let req = SyncRequest {
        shard_id,
        protocol_version: PROTOCOL_VERSION,
        flags,
    };
    write_frame(&mut writer, &req.encode())?;

    let info_frame = read_frame(&mut reader)?;
    if info_frame.msg_type != FixedMessageType::ShardInfo {
        return Err(DbError::Replication(format!(
            "expected ShardInfo, got {:?}",
            info_frame.msg_type
        )));
    }
    let info = ShardInfo::decode(&info_frame.payload)?;

    // The follower must mirror the leader's layout exactly.
    let follower_shards = target.shard_count();
    if info.shard_count as usize != follower_shards {
        return Err(DbError::ShardCountMismatch {
            leader: info.shard_count as usize,
            follower: follower_shards,
        });
    }
    if info.key_len as usize != target.key_len() || info.value_len as usize != target.value_len() {
        return Err(DbError::Replication(format!(
            "K/V size mismatch: leader ({}, {}), follower ({}, {})",
            info.key_len,
            info.value_len,
            target.key_len(),
            target.value_len()
        )));
    }
    if info.shard_prefix_bits != target.shard_prefix_bits() {
        return Err(DbError::Replication(format!(
            "shard_prefix_bits mismatch: leader {}, follower {}",
            info.shard_prefix_bits,
            target.shard_prefix_bits()
        )));
    }

    target.grow_shard_to(shard_id, info.current_slot_count)?;
    tracing::info!(
        shard_id,
        flags,
        current_slot_count = info.current_slot_count,
        "fixed follower connected to leader"
    );
    apply_loop(reader, writer, shard_id, info, target, stop)
}
/// Streaming phase of a replication session: reads frames from the leader,
/// applies slot events to `target`, acknowledges each batch and keeps the
/// persisted replication cursor up to date.
///
/// Frame handling:
/// - `SlotBatch`: every event is applied as occupied or deleted depending on
///   its status bits, then an `Ack` with the running totals is sent.
/// - `CaughtUp`: logged only.
/// - `Heartbeat`: echoed back to the leader.
/// - `Error` and any other message type end the session with an error.
///
/// The socket read timeout is set to one second so the loop can check `stop`
/// and send its own heartbeat every `HEARTBEAT_INTERVAL_SECS` while idle.
///
/// The cursor is saved every `CURSOR_SAVE_INTERVAL` events and, best effort,
/// whenever the session ends with unsaved events, on shutdown or on error.
///
/// Returns `Ok(())` only when `stop` is shut down.
///
/// # Errors
///
/// Returns an error on I/O failure, on a frame that cannot be decoded, when
/// applying an event or a periodic cursor save fails, or when the leader
/// sends an `Error` or unexpected message.
// NOTE(review): this function takes six parameters, below clippy's default
// threshold of seven; the allow looks stale unless the project lowers that
// threshold. Confirm before removing.
#[allow(clippy::too_many_arguments)]
fn apply_loop(
    mut reader: BufReader<TcpStream>,
    mut writer: TcpStream,
    shard_id: u8,
    info: ShardInfo,
    target: &Arc<dyn FixedReplicationTarget>,
    stop: &ShutdownSignal,
) -> DbResult<()> {
    use crate::fixed::slot::{STATUS_OCCUPIED, status_of, version_of};

    let shard_dir = target.shard_dir(shard_id);
    let mut cursor = FixedReplicationCursor::load(shard_id, &shard_dir)
        .unwrap_or_else(|| FixedReplicationCursor::new(shard_id));
    let mut events_since_save: u64 = 0;
    let mut applied_total = cursor.applied_total;
    let mut max_version_seen = cursor.max_version_seen;
    let mut last_heartbeat_sent = Instant::now();
    let hb = Duration::from_secs(HEARTBEAT_INTERVAL_SECS);

    // A short read timeout turns the blocking read into a poll, so shutdown
    // and heartbeat deadlines are noticed while the stream is idle.
    let _ = reader
        .get_mut()
        .set_read_timeout(Some(Duration::from_secs(1)));

    // The session body lives in a closure so every exit path, `?` included,
    // falls through to the cursor flush below instead of returning directly.
    let mut run_session = || -> DbResult<()> {
        loop {
            if stop.is_shutdown() {
                return Ok(());
            }

            let frame = match read_frame(&mut reader) {
                Ok(f) => f,
                // NOTE(review): if `read_frame` can time out after consuming
                // part of a frame, retrying here would desynchronise the
                // stream. Confirm that `read_frame` tolerates partial reads.
                Err(ref e)
                    if e.kind() == std::io::ErrorKind::TimedOut
                        || e.kind() == std::io::ErrorKind::WouldBlock =>
                {
                    if last_heartbeat_sent.elapsed() >= hb {
                        write_frame(&mut writer, &encode_heartbeat())?;
                        last_heartbeat_sent = Instant::now();
                    }
                    continue;
                }
                Err(e) => return Err(DbError::from(e)),
            };

            match frame.msg_type {
                FixedMessageType::SlotBatch => {
                    let decoded = SlotBatchDecoder::new(
                        &frame.payload,
                        info.key_len as usize,
                        info.value_len as usize,
                    )?;
                    for event in decoded.iter() {
                        let event = event?;
                        let outcome = if status_of(event.meta) == STATUS_OCCUPIED {
                            target.apply_occupied(
                                shard_id,
                                event.slot_id,
                                event.meta,
                                event.key,
                                event.value,
                            )?
                        } else {
                            target.apply_deleted(shard_id, event.slot_id, event.meta, event.key)?
                        };
                        // Only events that changed local state advance the
                        // counters reported back to the leader.
                        if matches!(outcome, ApplyOutcome::Applied) {
                            applied_total += 1;
                            let v = version_of(event.meta);
                            if v > max_version_seen {
                                max_version_seen = v;
                            }
                        }
                        events_since_save += 1;
                    }

                    let ack = Ack {
                        shard_id,
                        applied_count: applied_total,
                        max_version_seen,
                    };
                    write_frame(&mut writer, &ack.encode())?;

                    if events_since_save >= CURSOR_SAVE_INTERVAL as u64 {
                        cursor.applied_total = applied_total;
                        cursor.max_version_seen = max_version_seen;
                        cursor.last_ack_at = SystemTime::now();
                        cursor.save(&shard_dir)?;
                        metrics::gauge!(
                            "armdb.fixed.max_version_seen",
                            "shard" => shard_id.to_string()
                        )
                        .set(max_version_seen as f64);
                        events_since_save = 0;
                    }
                }
                FixedMessageType::CaughtUp => {
                    let cu = CaughtUp::decode(&frame.payload)?;
                    tracing::info!(
                        shard_id,
                        total_scanned = cu.total_scanned,
                        "fixed catch-up complete, entering streaming"
                    );
                }
                FixedMessageType::Heartbeat => {
                    write_frame(&mut writer, &encode_heartbeat())?;
                    // The echo is itself a heartbeat; refreshing the timestamp
                    // stops the check below from sending a second one at once.
                    last_heartbeat_sent = Instant::now();
                }
                FixedMessageType::Error => {
                    let msg = decode_error(&frame.payload);
                    return Err(DbError::Replication(msg));
                }
                other => {
                    return Err(DbError::Replication(format!(
                        "unexpected message: {other:?}"
                    )));
                }
            }

            // A busy stream never hits the read timeout, so the heartbeat
            // deadline is also checked after every handled frame.
            if last_heartbeat_sent.elapsed() >= hb {
                write_frame(&mut writer, &encode_heartbeat())?;
                last_heartbeat_sent = Instant::now();
            }
        }
    };
    let outcome = run_session();

    // Best-effort flush of progress made since the last periodic save. Doing
    // it on error exits as well keeps the next session, which reloads the
    // cursor from disk, from resuming with stale counters.
    if events_since_save > 0 {
        cursor.applied_total = applied_total;
        cursor.max_version_seen = max_version_seen;
        cursor.last_ack_at = SystemTime::now();
        let _ = cursor.save(&shard_dir);
    }
    outcome
}