use std::io::BufReader;
use std::net::{SocketAddr, TcpStream};
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::thread::{self, JoinHandle};
use std::time::Duration;
use crate::error::{DbError, DbResult};
use crate::shard::Shard;
use crate::shutdown::ShutdownSignal;
use super::cursor::ReplicationCursor;
use super::protocol::*;
use super::{ReplicationEntry, ReplicationRegistry};
/// Minimum number of entries applied since the last acknowledgement before an
/// `AckMessage` is written back on the stream (checked after each batch).
const ACK_INTERVAL: usize = 1000;
/// Initial delay, in milliseconds, before retrying after a replication error.
const RECONNECT_BASE_MS: u64 = 1000;
/// Upper bound, in milliseconds, on the doubling reconnect delay.
const RECONNECT_MAX_MS: u64 = 30_000;
/// Minimum number of entries applied since the last cursor save before the
/// replication cursor is saved again (checked after each batch).
const CURSOR_SAVE_INTERVAL: usize = 1000;
/// Handle to the replication worker threads, one per shard, created by
/// [`ReplicationClient::start`].
///
/// Dropping the handle signals shutdown and joins every worker thread.
pub struct ReplicationClient {
    /// Shutdown signal; each worker thread holds a clone of it.
    stop: ShutdownSignal,
    /// Join handles of the spawned worker threads.
    handles: Vec<JoinHandle<()>>,
}
impl ReplicationClient {
pub fn start(
leader_addr: SocketAddr,
shards: Arc<Vec<Shard>>,
registry: Arc<ReplicationRegistry>,
key_lens: Vec<u16>,
signal: ShutdownSignal,
) -> DbResult<Self> {
let mut handles = Vec::with_capacity(shards.len());
for shard_id in 0..shards.len() {
let shards = shards.clone();
let registry = registry.clone();
let stop = signal.clone();
let key_lens = key_lens.clone();
let handle = thread::spawn(move || {
run_shard_client(leader_addr, &shards, shard_id, ®istry, &key_lens, &stop);
});
handles.push(handle);
}
Ok(Self {
stop: signal,
handles,
})
}
pub fn stop(&self) {
self.stop.shutdown();
}
}
impl Drop for ReplicationClient {
    /// Signals shutdown, then joins every worker thread in spawn order.
    fn drop(&mut self) {
        self.stop.shutdown();
        // The join result is discarded on purpose: `drop` has no way to report
        // a worker that ended with a panic.
        std::mem::take(&mut self.handles)
            .into_iter()
            .for_each(|worker| {
                let _ = worker.join();
            });
    }
}
/// Runs the replication loop for the shard at index `shard_id` until shutdown.
///
/// Each iteration loads the persisted cursor from the shard directory (falling
/// back to a fresh cursor when none exists or loading fails) and calls
/// `connect_and_stream`. A clean return from the stream ends the worker. An
/// error is logged and followed by a wait of `backoff_ms` before the next
/// attempt; the wait is cut short and the worker exits if shutdown is
/// signalled meanwhile.
///
/// The delay starts at `RECONNECT_BASE_MS` and doubles after every failure up
/// to `RECONNECT_MAX_MS`. It is reset to the base value when the failed
/// session had lasted at least `RECONNECT_MAX_MS`.
fn run_shard_client(
    leader_addr: SocketAddr,
    shards: &[Shard],
    shard_id: usize,
    registry: &ReplicationRegistry,
    key_lens: &[u16],
    stop: &ShutdownSignal,
) {
    let shard = &shards[shard_id];
    let shard_dir = shard.dir().to_path_buf();
    let mut backoff_ms = RECONNECT_BASE_MS;
    loop {
        if stop.is_shutdown() {
            return;
        }

        // Reload on every attempt so progress saved by the previous session is
        // picked up. A load failure still falls back to a fresh cursor, but is
        // reported instead of being swallowed silently.
        let cursor = match ReplicationCursor::load(shard_id as u8, &shard_dir) {
            Ok(Some(cursor)) => cursor,
            Ok(None) => ReplicationCursor::new(shard_id as u8),
            Err(e) => {
                tracing::warn!(
                    shard_id,
                    error = %e,
                    "failed to load replication cursor, starting from a fresh one"
                );
                ReplicationCursor::new(shard_id as u8)
            }
        };

        let attempt_started = std::time::Instant::now();
        match connect_and_stream(
            leader_addr,
            shards,
            shard_id,
            &cursor,
            registry,
            key_lens,
            stop.as_flag(),
        ) {
            Ok(()) => {
                tracing::info!(shard_id, "replication stream ended cleanly");
                return;
            }
            Err(e) => {
                // Without a reset, a few early failures would pin the delay at
                // the maximum for the rest of the process lifetime. A session
                // that outlived the longest possible delay is treated as having
                // been healthy, so the next retry starts from the base delay.
                if attempt_started.elapsed() >= Duration::from_millis(RECONNECT_MAX_MS) {
                    backoff_ms = RECONNECT_BASE_MS;
                }
                tracing::error!(shard_id, error = %e, backoff_ms, "replication error, reconnecting");
                if stop.wait_timeout(Duration::from_millis(backoff_ms)) {
                    return;
                }
                backoff_ms = (backoff_ms * 2).min(RECONNECT_MAX_MS);
            }
        }
    }
}
fn connect_and_stream(
leader_addr: SocketAddr,
shards: &[Shard],
shard_id: usize,
cursor: &ReplicationCursor,
registry: &ReplicationRegistry,
key_lens: &[u16],
stop: &std::sync::atomic::AtomicBool,
) -> DbResult<()> {
let stream = TcpStream::connect_timeout(&leader_addr, Duration::from_secs(5))?;
let _ = stream.set_nodelay(true);
let mut writer = stream.try_clone()?;
let mut reader = BufReader::new(stream);
let req = SyncRequest {
shard_id: shard_id as u8,
from_gsn: cursor.last_gsn,
key_lens: key_lens.to_vec(),
};
write_frame(&mut writer, &req.encode())?;
let info_frame = read_frame(&mut reader)?;
if info_frame.msg_type != MessageType::ShardInfo {
return Err(DbError::Replication(format!(
"expected ShardInfo, got {:?}",
info_frame.msg_type
)));
}
let info = ShardInfo::decode(&info_frame.payload)?;
if info.shard_count as usize != shards.len() {
return Err(DbError::ShardCountMismatch {
leader: info.shard_count as usize,
follower: shards.len(),
});
}
tracing::info!(shard_id, from_gsn = cursor.last_gsn, "connected to leader");
let shard = &shards[shard_id];
let shard_dir = shard.dir().to_path_buf();
let mut last_gsn = cursor.last_gsn;
let mut entries_since_ack = 0usize;
let mut entries_since_save = 0usize;
loop {
if stop.load(Ordering::Relaxed) {
if entries_since_save > 0 {
let mut cursor = ReplicationCursor::new(shard_id as u8);
cursor.advance(last_gsn, 0, 0);
let _ = cursor.save(&shard_dir);
}
return Ok(());
}
let frame = read_frame(&mut reader)?;
match frame.msg_type {
MessageType::EntryBatch => {
let batch = EntryBatch::decode(&frame.payload)?;
for wire_entry in &batch.entries {
let repl_entry = ReplicationEntry {
data: wire_entry.data.clone(),
key_len: wire_entry.key_len,
};
registry.apply_streaming(shard, &repl_entry)?;
if wire_entry.gsn > last_gsn {
last_gsn = wire_entry.gsn;
}
entries_since_ack += 1;
entries_since_save += 1;
}
if entries_since_ack >= ACK_INTERVAL {
let ack = AckMessage {
shard_id: shard_id as u8,
last_gsn,
};
write_frame(&mut writer, &ack.encode())?;
entries_since_ack = 0;
}
if entries_since_save >= CURSOR_SAVE_INTERVAL {
let mut cursor = ReplicationCursor::new(shard_id as u8);
cursor.advance(last_gsn, 0, 0);
cursor.save(&shard_dir)?;
entries_since_save = 0;
}
}
MessageType::CaughtUp => {
let caught_up = CaughtUp::decode(&frame.payload)?;
tracing::info!(shard_id, leader_gsn = caught_up.leader_gsn, "caught up");
}
MessageType::Heartbeat => {
write_frame(&mut writer, &encode_heartbeat())?;
}
MessageType::Error => {
let msg = decode_error(&frame.payload);
return Err(DbError::Replication(msg));
}
other => {
return Err(DbError::Replication(format!(
"unexpected message: {other:?}"
)));
}
}
}
}