use std::io::BufReader;
use std::net::{SocketAddr, TcpStream};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread::{self, JoinHandle};
use std::time::Duration;
use crate::error::{DbError, DbResult};
use crate::shard::Shard;
use crate::shutdown::ShutdownSignal;
use super::cursor::ReplicationCursor;
use super::protocol::*;
use super::{ReplicationEntry, ReplicationRegistry};
/// Ack threshold (entries applied since the last ack) used until the leader
/// reports that this follower has caught up.
const ACK_INTERVAL_CATCHUP: usize = 1;
/// Ack threshold (entries applied since the last ack) used after catch-up.
const ACK_INTERVAL: usize = 1_000;
/// Default initial reconnect delay, in milliseconds.
const RECONNECT_BASE_MS: u64 = 1_000;
/// Default upper bound for the reconnect delay, in milliseconds.
const RECONNECT_MAX_MS: u64 = 30_000;
/// Number of applied entries after which the replication cursor is saved.
const CURSOR_SAVE_INTERVAL: usize = 1_000;
/// File name of the replication cursor inside a shard's directory.
const CURSOR_FILENAME: &str = "repl.cursor";
/// Handle to the replication worker threads, one per shard.
///
/// Dropping the client signals shutdown and joins every worker thread.
pub struct ReplicationClient {
    /// Shutdown signal; clones of it are handed to the worker threads.
    stop: ShutdownSignal,
    /// Join handles of the spawned worker threads.
    handles: Vec<JoinHandle<()>>,
}
/// Reconnect backoff settings for [`ReplicationClient::start_with_options`].
///
/// Derives the common traits so callers can log, copy, and compare
/// configurations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReplicationClientOptions {
    /// Initial delay, in milliseconds, before reconnecting after a
    /// replication error.
    pub reconnect_base_ms: u64,
    /// Upper bound, in milliseconds, that the delay is clamped to after each
    /// doubling.
    pub reconnect_max_ms: u64,
}
impl Default for ReplicationClientOptions {
fn default() -> Self {
Self {
reconnect_base_ms: RECONNECT_BASE_MS,
reconnect_max_ms: RECONNECT_MAX_MS,
}
}
}
impl ReplicationClient {
    /// Starts one replication worker thread per shard using
    /// [`ReplicationClientOptions::default`].
    ///
    /// Equivalent to calling [`ReplicationClient::start_with_options`] with
    /// the default options.
    pub fn start(
        leader_addr: SocketAddr,
        shards: Arc<Vec<Shard>>,
        registry: Arc<ReplicationRegistry>,
        key_len: u16,
        signal: ShutdownSignal,
    ) -> DbResult<Self> {
        Self::start_with_options(
            leader_addr,
            shards,
            registry,
            key_len,
            signal,
            ReplicationClientOptions::default(),
        )
    }

    /// Starts one replication worker thread per shard.
    ///
    /// Each thread runs `run_shard_client` for its shard index, sharing the
    /// shard list, the registry and a clone of `signal`. The reconnect delays
    /// are taken from `options`.
    ///
    /// # Errors
    ///
    /// The current implementation has no failing path of its own and always
    /// returns `Ok`.
    ///
    /// # Panics
    ///
    /// `std::thread::spawn` panics if the OS fails to create a thread.
    pub fn start_with_options(
        leader_addr: SocketAddr,
        shards: Arc<Vec<Shard>>,
        registry: Arc<ReplicationRegistry>,
        key_len: u16,
        signal: ShutdownSignal,
        options: ReplicationClientOptions,
    ) -> DbResult<Self> {
        let reconnect_base_ms = options.reconnect_base_ms;
        let reconnect_max_ms = options.reconnect_max_ms;

        let mut handles = Vec::with_capacity(shards.len());
        for shard_id in 0..shards.len() {
            // Every worker owns its own reference-counted handles so the
            // closure can be `'static`.
            let shards = Arc::clone(&shards);
            let registry = Arc::clone(&registry);
            let stop = signal.clone();
            let handle = thread::spawn(move || {
                run_shard_client(
                    leader_addr,
                    &shards,
                    shard_id,
                    &registry,
                    key_len,
                    &stop,
                    reconnect_base_ms,
                    reconnect_max_ms,
                );
            });
            handles.push(handle);
        }

        Ok(Self {
            stop: signal,
            handles,
        })
    }

    /// Signals shutdown to the worker threads.
    ///
    /// This does not wait for the threads; they are joined when the client is
    /// dropped.
    pub fn stop(&self) {
        self.stop.shutdown();
    }
}
impl Drop for ReplicationClient {
    /// Signals shutdown, then blocks until every worker thread has exited.
    fn drop(&mut self) {
        self.stop.shutdown();
        self.handles.drain(..).for_each(|worker| {
            // `join` returns `Err` if the worker panicked; that outcome is
            // deliberately ignored during teardown.
            let _ = worker.join();
        });
    }
}
/// Runs the replication loop for a single shard.
///
/// Repeatedly loads the shard's replication cursor and calls
/// `connect_and_stream`. The function returns when shutdown is observed
/// (before an attempt or while waiting to retry) or when
/// `connect_and_stream` returns `Ok`. After an error it waits `backoff_ms`
/// and retries; the delay starts at `reconnect_base_ms` and is doubled after
/// each failure, clamped to `reconnect_max_ms`.
///
/// # Panics
///
/// Panics if `shard_id` is out of bounds for `shards`.
#[allow(clippy::too_many_arguments)]
fn run_shard_client(
    leader_addr: SocketAddr,
    shards: &[Shard],
    shard_id: usize,
    registry: &ReplicationRegistry,
    key_len: u16,
    stop: &ShutdownSignal,
    reconnect_base_ms: u64,
    reconnect_max_ms: u64,
) {
    let shard = &shards[shard_id];
    // The cursor path never changes, so build it once rather than on every
    // reconnect attempt.
    let cursor_path = shard.dir().to_path_buf().join(CURSOR_FILENAME);
    let mut backoff_ms = reconnect_base_ms;

    loop {
        if stop.is_shutdown() {
            return;
        }

        // Reload the cursor on every attempt so a reconnect resumes from the
        // most recently saved position.
        let cursor = match ReplicationCursor::load(&cursor_path) {
            Ok(Some(cursor)) => cursor,
            Ok(None) => ReplicationCursor::default(),
            Err(e) => {
                // Still fall back to the default cursor, but leave a trace:
                // a silent fallback hides why replication restarted.
                tracing::warn!(
                    shard_id,
                    error = %e,
                    "failed to load replication cursor, using default"
                );
                ReplicationCursor::default()
            }
        };

        match connect_and_stream(
            leader_addr,
            shards,
            shard_id,
            &cursor,
            &cursor_path,
            registry,
            key_len,
            stop.as_flag(),
        ) {
            Ok(()) => {
                tracing::info!(shard_id, "replication stream ended cleanly");
                return;
            }
            Err(e) => {
                tracing::error!(shard_id, error = %e, backoff_ms, "replication error, reconnecting");
                // A `true` result from the wait ends the worker
                // (presumably shutdown was signalled during the wait).
                if stop.wait_timeout(Duration::from_millis(backoff_ms)) {
                    return;
                }
                // Saturate instead of overflowing when callers configure
                // very large delays.
                backoff_ms = backoff_ms.saturating_mul(2).min(reconnect_max_ms);
            }
        }
    }
}
/// Encodes an `AckMessage` for `shard_id` / `last_gsn` and writes it to
/// `writer` as one frame.
///
/// # Errors
///
/// Propagates any error returned by `write_frame`.
fn send_ack(writer: &mut TcpStream, shard_id: u8, last_gsn: u64) -> DbResult<()> {
    let payload = AckMessage { shard_id, last_gsn }.encode();
    write_frame(writer, &payload)?;
    Ok(())
}
/// Opens one replication session with the leader for `shard_id` and applies
/// streamed entries until shutdown is requested or an error occurs.
///
/// Steps:
/// 1. Connect to `leader_addr` (5 second connect timeout) and send a
///    `SyncRequest` asking for entries starting at `cursor.last_gsn + 1`.
/// 2. Expect a `ShardInfo` frame and verify the leader's shard count matches
///    `shards.len()`.
/// 3. Loop over incoming frames: apply `EntryBatch` entries through
///    `registry.apply_streaming`, send acks, periodically save the cursor to
///    `cursor_path`, and answer heartbeats.
///
/// Returns `Ok(())` only when `stop` is observed set between frames; in that
/// case unsaved cursor progress is written on a best-effort basis first.
///
/// # Errors
///
/// Returns an error when `shard_id` does not fit the `u8` wire field, on any
/// connect/read/write/decode/apply/cursor-save failure, when the leader sends
/// an `Error` frame or an unexpected frame type, when the shard counts
/// differ, or when a batch carries a different shard id.
///
/// # Panics
///
/// Panics if `shard_id` is out of bounds for `shards`.
#[allow(clippy::too_many_arguments)]
fn connect_and_stream(
    leader_addr: SocketAddr,
    shards: &[Shard],
    shard_id: usize,
    cursor: &ReplicationCursor,
    cursor_path: &std::path::Path,
    registry: &ReplicationRegistry,
    key_len: u16,
    stop: &std::sync::atomic::AtomicBool,
) -> DbResult<()> {
    // The wire protocol carries the shard id as a `u8`. Reject ids that do
    // not fit instead of silently truncating them to a different shard.
    let wire_shard_id = u8::try_from(shard_id).map_err(|_| {
        DbError::Replication(format!(
            "shard_id {shard_id} does not fit in the u8 wire field"
        ))
    })?;
    // Resume right after the last entry recorded in the cursor.
    let from_gsn = cursor.last_gsn.saturating_add(1);

    let stream = TcpStream::connect_timeout(&leader_addr, Duration::from_secs(5))?;
    // Best effort: a failure to set TCP_NODELAY is ignored.
    let _ = stream.set_nodelay(true);
    let mut writer = stream.try_clone()?;
    let mut reader = BufReader::new(stream);

    let req = SyncRequest {
        shard_id: wire_shard_id,
        from_gsn,
        key_len,
    };
    write_frame(&mut writer, &req.encode())?;

    let info_frame = read_frame(&mut reader)?;
    if info_frame.msg_type != MessageType::ShardInfo {
        return Err(DbError::Replication(format!(
            "expected ShardInfo, got {:?}",
            info_frame.msg_type
        )));
    }
    let info = ShardInfo::decode(&info_frame.payload)?;
    if info.shard_count as usize != shards.len() {
        return Err(DbError::ShardCountMismatch {
            leader: info.shard_count as usize,
            follower: shards.len(),
        });
    }
    tracing::info!(
        shard_id,
        last_gsn = cursor.last_gsn,
        from_gsn,
        "connected to leader"
    );

    let shard = &shards[shard_id];
    // Handed to `apply_streaming` for every entry and read back for acks and
    // cursor saves.
    let last_applied = AtomicU64::new(cursor.last_gsn);
    let mut entries_since_ack = 0usize;
    let mut entries_since_save = 0usize;
    let mut in_catchup = true;

    loop {
        // The stop flag is only checked between frames; `read_frame` below
        // blocks until the next frame arrives.
        if stop.load(Ordering::Relaxed) {
            if entries_since_save > 0 {
                let gsn = last_applied.load(Ordering::Relaxed);
                // Best effort on shutdown: a failed save is ignored.
                let _ = ReplicationCursor { last_gsn: gsn }.save(cursor_path);
            }
            return Ok(());
        }

        let frame = read_frame(&mut reader)?;
        match frame.msg_type {
            MessageType::EntryBatch => {
                let batch = EntryBatch::decode(&frame.payload)?;
                if batch.shard_id != wire_shard_id {
                    return Err(DbError::Replication("EntryBatch shard_id mismatch".into()));
                }
                for wire_entry in &batch.entries {
                    let repl_entry = ReplicationEntry {
                        data: wire_entry.data.clone(),
                        key_len: wire_entry.key_len,
                    };
                    registry.apply_streaming(shard, &repl_entry, &last_applied)?;
                    entries_since_ack += 1;
                    entries_since_save += 1;
                }

                // Thresholds are evaluated once per batch, after all of its
                // entries have been applied.
                let ack_threshold = if in_catchup {
                    ACK_INTERVAL_CATCHUP
                } else {
                    ACK_INTERVAL
                };
                if entries_since_ack >= ack_threshold {
                    send_ack(
                        &mut writer,
                        wire_shard_id,
                        last_applied.load(Ordering::Relaxed),
                    )?;
                    entries_since_ack = 0;
                }
                if entries_since_save >= CURSOR_SAVE_INTERVAL {
                    let last_gsn = last_applied.load(Ordering::Relaxed);
                    ReplicationCursor { last_gsn }.save(cursor_path)?;
                    entries_since_save = 0;
                }
            }
            MessageType::CaughtUp => {
                let caught_up = CaughtUp::decode(&frame.payload)?;
                tracing::info!(shard_id, leader_gsn = caught_up.leader_gsn, "caught up");
                // From here on the larger steady-state ack interval applies.
                in_catchup = false;
            }
            MessageType::Heartbeat => {
                // Answer a heartbeat with a heartbeat.
                write_frame(&mut writer, &encode_heartbeat())?;
            }
            MessageType::Error => {
                let msg = decode_error(&frame.payload);
                return Err(DbError::Replication(msg));
            }
            other => {
                return Err(DbError::Replication(format!(
                    "unexpected message: {other:?}"
                )));
            }
        }
    }
}