use std::io::{BufReader, BufWriter};
use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread::{self, JoinHandle};
use std::time::Duration;
use crate::error::DbResult;
use crate::shard::Shard;
use crate::shutdown::ShutdownSignal;
use super::ReplicationEntry;
use super::log_reader::ShardLogReader;
use super::protocol::*;
/// Upper bound on the number of entries packed into one `EntryBatch` frame.
const BATCH_MAX_ENTRIES: usize = 256;
/// Byte budget (64 KiB of entry data) for one `EntryBatch` frame. The batching loops
/// check this before adding an entry, so the final entry may push a batch past it.
const BATCH_MAX_BYTES: usize = 64 << 10;
/// Sleep, in milliseconds, between polls of an empty ring buffer while streaming.
const TAIL_POLL_MS: u64 = 1;
/// Default number of idle seconds before the leader emits a heartbeat frame.
pub const HEARTBEAT_INTERVAL_SECS: u64 = 5;
/// Join handles of the per-follower handler threads. The acceptor thread pushes into
/// this list and `Drop for ReplicationServer` drains and joins it.
type HandlerHandles = Arc<crate::sync::Mutex<Vec<JoinHandle<()>>>>;
/// Tunables accepted by `ReplicationServer::start_with_options`.
#[derive(Debug, Clone)]
pub struct ReplicationServerOptions {
    /// Idle seconds after which the leader sends a heartbeat on a streaming
    /// connection. Read timeouts on follower sockets are derived from twice this value.
    pub heartbeat_interval_secs: u64,
}

impl Default for ReplicationServerOptions {
    /// Uses [`HEARTBEAT_INTERVAL_SECS`] as the heartbeat interval.
    fn default() -> Self {
        Self {
            heartbeat_interval_secs: HEARTBEAT_INTERVAL_SECS,
        }
    }
}
/// Leader-side replication server.
///
/// Owns a background acceptor thread that accepts follower connections and spawns one
/// handler thread per connection. Dropping the server signals shutdown and joins the
/// acceptor and the recorded handler threads.
pub struct ReplicationServer {
    // Shutdown signal shared with the acceptor and every handler thread.
    stop: ShutdownSignal,
    // Acceptor thread; taken and joined in `Drop`.
    acceptor_handle: Option<JoinHandle<()>>,
    // Handler threads spawned by the acceptor; drained and joined in `Drop`.
    handler_handles: HandlerHandles,
    /// Per-shard highest GSN acknowledged by a follower, indexed by shard id.
    /// Initialised to 0 and raised with `fetch_max` as acks arrive.
    pub min_replicated_gsn: Arc<Vec<AtomicU64>>,
}
impl ReplicationServer {
pub fn start(
bind_addr: SocketAddr,
shards: Arc<Vec<Shard>>,
consumers: Vec<rtrb::Consumer<ReplicationEntry>>,
max_file_size: u64,
signal: ShutdownSignal,
) -> DbResult<Self> {
Self::start_with_options(
bind_addr,
shards,
consumers,
max_file_size,
signal,
ReplicationServerOptions::default(),
)
}
pub fn start_with_options(
bind_addr: SocketAddr,
shards: Arc<Vec<Shard>>,
consumers: Vec<rtrb::Consumer<ReplicationEntry>>,
max_file_size: u64,
signal: ShutdownSignal,
options: ReplicationServerOptions,
) -> DbResult<Self> {
let shard_count = shards.len();
let heartbeat_secs = options.heartbeat_interval_secs;
let min_replicated_gsn: Arc<Vec<AtomicU64>> =
Arc::new((0..shard_count).map(|_| AtomicU64::new(0)).collect());
let consumers: Arc<Vec<crate::sync::Mutex<Option<rtrb::Consumer<ReplicationEntry>>>>> =
Arc::new(
consumers
.into_iter()
.map(|c| crate::sync::Mutex::new(Some(c)))
.collect(),
);
let listener = TcpListener::bind(bind_addr)?;
listener.set_nonblocking(true)?;
let stop2 = signal.clone();
let shards2 = shards.clone();
let min_gsn2 = min_replicated_gsn.clone();
let consumers2 = consumers.clone();
let handler_handles: HandlerHandles = Arc::new(crate::sync::Mutex::new(Vec::new()));
let hh2 = handler_handles.clone();
let acceptor = thread::spawn(move || {
tracing::info!(%bind_addr, "replication server started");
while !stop2.is_shutdown() {
match listener.accept() {
Ok((stream, addr)) => {
tracing::info!(%addr, "follower connected");
let _ = stream.set_nodelay(true);
let _ =
stream.set_read_timeout(Some(Duration::from_secs(2 * heartbeat_secs)));
let shards = shards2.clone();
let consumers = consumers2.clone();
let stop_handler = stop2.clone();
let min_gsn = min_gsn2.clone();
let hh = hh2.clone();
let handle = thread::spawn(move || {
if let Err(e) = handle_connection_in_thread(
stream,
&shards,
&consumers,
max_file_size,
&stop_handler,
&min_gsn,
heartbeat_secs,
) {
tracing::error!(%addr, error = %e, "handler thread error");
}
});
crate::sync::lock(&hh).push(handle);
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
stop2.wait_timeout(Duration::from_millis(50));
}
Err(e) => {
tracing::error!(error = %e, "accept error");
stop2.wait_timeout(Duration::from_millis(100));
}
}
}
tracing::info!("replication server stopped");
});
Ok(Self {
stop: signal,
acceptor_handle: Some(acceptor),
handler_handles,
min_replicated_gsn,
})
}
pub fn stop(&self) {
self.stop.shutdown();
}
}
impl crate::compaction::CompactionGuard for ReplicationServer {
    /// Returns the highest GSN a follower has acknowledged for `shard_id`.
    ///
    /// Shard ids outside the tracked range report `u64::MAX`.
    fn min_replicated_gsn(&self, shard_id: u8) -> u64 {
        match self.min_replicated_gsn.get(usize::from(shard_id)) {
            Some(acked) => acked.load(Ordering::Relaxed),
            None => u64::MAX,
        }
    }
}
impl Drop for ReplicationServer {
    /// Signals shutdown, then waits for the acceptor and every recorded handler thread.
    ///
    /// Join results are discarded, so a panicked worker does not propagate out of `drop`.
    fn drop(&mut self) {
        self.stop.shutdown();
        if let Some(acceptor) = self.acceptor_handle.take() {
            let _ = acceptor.join();
        }
        // The acceptor (the only thread that pushes handles) has been joined, so move the
        // handles out of the lock and join them without holding it.
        let handlers = std::mem::take(&mut *crate::sync::lock(&self.handler_handles));
        for handler in handlers {
            let _ = handler.join();
        }
    }
}
/// Runs one follower session.
///
/// 1. Reads the follower's `SyncRequest` and validates the requested shard.
/// 2. Replies with `ShardInfo`.
/// 3. Claims the shard's replication consumer. Only one session per shard can hold it;
///    a concurrent session receives an error frame and is closed.
/// 4. Puts the consumer back into its slot once serving ends, so a later reconnect can
///    claim it.
///
/// # Errors
///
/// Returns an error in these cases:
/// - the handshake frame cannot be read or decoded, or is not a `SyncRequest`;
/// - the request names a shard that does not exist;
/// - sending `ShardInfo` fails.
///
/// Otherwise it returns the outcome of serving the shard.
#[allow(clippy::too_many_arguments)]
fn handle_connection_in_thread(
    stream: TcpStream,
    shards: &Arc<Vec<Shard>>,
    consumers: &Arc<Vec<crate::sync::Mutex<Option<rtrb::Consumer<ReplicationEntry>>>>>,
    max_file_size: u64,
    stop: &ShutdownSignal,
    min_gsn: &Arc<Vec<AtomicU64>>,
    heartbeat_secs: u64,
) -> DbResult<()> {
    let mut reader = stream.try_clone().map_err(crate::error::DbError::Io)?;

    // Handshake: the first frame must be a SyncRequest naming an existing shard.
    let handshake = read_frame(&mut reader)?;
    if handshake.msg_type != MessageType::SyncRequest {
        return Err(crate::error::DbError::Replication(format!(
            "expected SyncRequest, got {:?}",
            handshake.msg_type
        )));
    }
    let request = SyncRequest::decode(&handshake.payload)?;
    let shard_id = request.shard_id as usize;
    if shard_id >= shards.len() {
        return Err(crate::error::DbError::Replication(format!(
            "invalid shard_id {shard_id}"
        )));
    }

    let mut writer = BufWriter::new(stream);
    let info = ShardInfo {
        shard_count: shards.len() as u8,
        max_file_size,
    };
    write_frame(&mut writer, &info.encode())?;

    // Claim exclusive ownership of this shard's consumer; the lock is released at the
    // end of the statement.
    let slot = &consumers[shard_id];
    let claimed = crate::sync::lock(slot).take();
    let consumer = match claimed {
        Some(c) => c,
        None => {
            tracing::warn!(
                shard_id,
                "shard already streaming, rejecting second connection"
            );
            let _ = write_frame(&mut writer, &encode_error("shard already streaming"));
            return Ok(());
        }
    };

    let (outcome, returned) = serve_shard(
        &mut reader,
        &mut writer,
        shards,
        shard_id,
        request.from_gsn,
        Some(consumer),
        request.key_len,
        stop,
        min_gsn,
        heartbeat_secs,
    );
    if let Some(c) = returned {
        *crate::sync::lock(slot) = Some(c);
    }
    outcome
}
/// Serves one shard to a follower and hands the consumer back to the caller.
///
/// The consumer is returned alongside the result even when serving fails, so the
/// caller can put it back into the shared slot for the next connection.
#[allow(clippy::too_many_arguments)]
fn serve_shard(
    reader: &mut TcpStream,
    writer: &mut BufWriter<TcpStream>,
    shards: &[Shard],
    shard_id: usize,
    from_gsn: u64,
    mut consumer: Option<rtrb::Consumer<ReplicationEntry>>,
    key_len: u16,
    stop: &ShutdownSignal,
    min_gsn: &Arc<Vec<AtomicU64>>,
    heartbeat_secs: u64,
) -> (DbResult<()>, Option<rtrb::Consumer<ReplicationEntry>>) {
    (
        serve_shard_inner(
            reader,
            writer,
            shards,
            shard_id,
            from_gsn,
            &mut consumer,
            key_len,
            stop,
            min_gsn,
            heartbeat_secs,
        ),
        consumer,
    )
}
/// Streams one shard's replication data to a connected follower.
///
/// First spawns an ack-reader thread on a clone of `reader`. It consumes `Ack` and
/// heartbeat-echo frames and folds acknowledged GSNs into `min_gsn[shard_id]` with
/// `fetch_max`. Then it runs two phases:
///
/// 1. **Catch-up.** If `from_gsn` is behind the shard's current GSN, the shard is
///    flushed and entries are replayed from the on-disk log via [`ShardLogReader`].
///    Batches are bounded by `BATCH_MAX_ENTRIES` / `BATCH_MAX_BYTES`, and a `CaughtUp`
///    frame finishes the phase.
/// 2. **Streaming.** If a consumer is present, the in-memory ring buffer is drained
///    into batches. A heartbeat is sent whenever nothing has been written for
///    `heartbeat_secs`, and the loop polls every `TAIL_POLL_MS` while idle.
///
/// Streaming ends on shutdown, on a write error, or once the ack reader has exited.
/// Every return after the ack reader is spawned shuts down the socket's read half and
/// joins the ack reader.
///
/// # Errors
///
/// Returns an error if configuring the ack socket, flushing the shard, reading the
/// log, or writing to the follower fails.
#[allow(clippy::too_many_arguments)]
fn serve_shard_inner(
    reader: &mut TcpStream,
    writer: &mut BufWriter<TcpStream>,
    shards: &[Shard],
    shard_id: usize,
    from_gsn: u64,
    consumer: &mut Option<rtrb::Consumer<ReplicationEntry>>,
    key_len: u16,
    stop: &ShutdownSignal,
    min_gsn: &Arc<Vec<AtomicU64>>,
    heartbeat_secs: u64,
) -> DbResult<()> {
    let shard = &shards[shard_id];
    let shard_dir = shard.dir().to_path_buf();
    let current_gsn = shard.gsn().load(Ordering::Relaxed);
    let ack_stream = reader.try_clone()?;
    // `set_read_timeout` rejects a zero duration (which a heartbeat interval of 0 would
    // produce), so clamp to at least one second; the saturating multiply avoids overflow.
    let peer_timeout = Duration::from_secs(heartbeat_secs.saturating_mul(2).max(1));
    ack_stream.set_read_timeout(Some(peer_timeout))?;
    let ack_buf_reader = BufReader::new(ack_stream);
    let stop_ack = stop.clone();
    let min_gsn_ack = min_gsn.clone();
    // Ack reader: records follower acknowledgements until shutdown, a read error
    // (EOF, reset, or timeout), or an unexpected frame type.
    let ack_handle = thread::spawn(move || {
        let mut br = ack_buf_reader;
        while !stop_ack.is_shutdown() {
            match read_frame(&mut br) {
                Ok(frame) => match frame.msg_type {
                    MessageType::Ack => {
                        if let Ok(ack) = AckMessage::decode(&frame.payload) {
                            if ack.shard_id == shard_id as u8 {
                                min_gsn_ack[shard_id].fetch_max(ack.last_gsn, Ordering::Relaxed);
                            }
                        }
                    }
                    MessageType::Heartbeat => {
                        tracing::trace!(shard_id, "received heartbeat echo");
                    }
                    other => {
                        tracing::warn!(?other, shard_id, "unexpected frame in ack reader; exiting");
                        break;
                    }
                },
                Err(e) => {
                    use std::io::ErrorKind;
                    if matches!(e.kind(), ErrorKind::TimedOut | ErrorKind::WouldBlock) {
                        tracing::warn!(shard_id, error = %e, "ack reader timed out — dead peer");
                    } else {
                        tracing::info!(shard_id, error = %e, "ack reader stream closed");
                    }
                    break;
                }
            }
        }
    });
    // Shut down the read half to wake the ack reader if it is blocked in `read_frame`,
    // then join it. Expanded on every return path below.
    macro_rules! cleanup_ack {
        () => {{
            let _ = reader.shutdown(Shutdown::Read);
            let _ = ack_handle.join();
        }};
    }
    if from_gsn < current_gsn {
        tracing::info!(shard_id, from_gsn, current_gsn, "starting catch-up");
        if let Err(e) = shard.flush_for_replication_catchup() {
            cleanup_ack!();
            return Err(e);
        }
        let mut log_reader = match ShardLogReader::new(
            shard_dir,
            from_gsn,
            key_len,
            #[cfg(feature = "encryption")]
            shard.cipher(),
        ) {
            Ok(lr) => lr,
            Err(e) => {
                cleanup_ack!();
                return Err(e);
            }
        };
        let mut last_gsn = from_gsn;
        'catchup: loop {
            if stop.is_shutdown() {
                cleanup_ack!();
                return Ok(());
            }
            let mut batch = Vec::new();
            let mut batch_bytes = 0;
            loop {
                if batch.len() >= BATCH_MAX_ENTRIES || batch_bytes >= BATCH_MAX_BYTES {
                    break;
                }
                match log_reader.next_entry() {
                    Ok(Some(entry)) => {
                        last_gsn = entry.gsn;
                        batch_bytes += entry.data.len();
                        batch.push(WireEntry {
                            entry_len: entry.data.len() as u32,
                            key_len: entry.key_len,
                            gsn: entry.gsn,
                            data: entry.data,
                        });
                    }
                    Ok(None) => break,
                    Err(e) => {
                        cleanup_ack!();
                        return Err(e);
                    }
                }
            }
            if batch.is_empty() {
                // Log exhausted: catch-up is complete.
                break 'catchup;
            }
            let msg = EntryBatch {
                shard_id: shard_id as u8,
                entries: batch,
            };
            if let Err(e) = write_frame(writer, &msg.encode()) {
                cleanup_ack!();
                return Err(e.into());
            }
        }
        let caught_up = CaughtUp {
            shard_id: shard_id as u8,
            leader_gsn: last_gsn,
        };
        if let Err(e) = write_frame(writer, &caught_up.encode()) {
            cleanup_ack!();
            return Err(e.into());
        }
        tracing::info!(shard_id, last_gsn, "catch-up complete");
    }
    if let Some(consumer) = consumer.as_mut() {
        tracing::info!(shard_id, "entering streaming mode");
        let mut last_send = std::time::Instant::now();
        loop {
            if stop.is_shutdown() {
                break;
            }
            // Once the ack reader has exited, one of three things happened: the peer
            // closed or reset the connection, it went silent past the read timeout, or it
            // sent an unexpected frame. Either way acks are no longer recorded. Stop
            // streaming so the consumer is handed back and the follower can reconnect,
            // instead of holding the shard until a write eventually fails.
            if ack_handle.is_finished() {
                tracing::info!(shard_id, "ack reader exited; ending stream");
                break;
            }
            let mut batch = Vec::new();
            let mut batch_bytes = 0;
            while batch.len() < BATCH_MAX_ENTRIES && batch_bytes < BATCH_MAX_BYTES {
                match consumer.pop() {
                    Ok(entry) => {
                        batch_bytes += entry.data.len();
                        let gsn = extract_gsn(&entry.data);
                        batch.push(WireEntry {
                            entry_len: entry.data.len() as u32,
                            key_len: entry.key_len,
                            gsn,
                            data: entry.data,
                        });
                    }
                    // Ring buffer is empty for now.
                    Err(_) => break,
                }
            }
            if batch.is_empty() {
                if last_send.elapsed().as_secs() >= heartbeat_secs {
                    if let Err(e) = write_frame(writer, &encode_heartbeat()) {
                        cleanup_ack!();
                        return Err(e.into());
                    }
                    last_send = std::time::Instant::now();
                }
                thread::sleep(Duration::from_millis(TAIL_POLL_MS));
                continue;
            }
            let msg = EntryBatch {
                shard_id: shard_id as u8,
                entries: batch,
            };
            if let Err(e) = write_frame(writer, &msg.encode()) {
                cleanup_ack!();
                return Err(e.into());
            }
            last_send = std::time::Instant::now();
        }
    }
    cleanup_ack!();
    Ok(())
}
/// Reads the GSN stored in the first eight bytes of a serialized entry.
///
/// The prefix is decoded as a native-endian `u64`, and `crate::entry::TOMBSTONE_BIT` is
/// masked off. Inputs shorter than eight bytes yield `0`.
fn extract_gsn(data: &[u8]) -> u64 {
    match data.get(..8) {
        Some(prefix) => {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(prefix);
            u64::from_ne_bytes(raw) & !crate::entry::TOMBSTONE_BIT
        }
        None => 0,
    }
}