use std::io::BufWriter;
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::thread::{self, JoinHandle};
use std::time::Duration;
use crate::error::DbResult;
use crate::shard::Shard;
use crate::shutdown::ShutdownSignal;
use super::ReplicationEntry;
use super::log_reader::ShardLogReader;
use super::protocol::*;
/// Maximum number of entries packed into a single replication batch frame.
const BATCH_MAX_ENTRIES: usize = 256;
/// Soft cap on total payload bytes per batch (the entry that crosses the cap
/// is still included — see the `batch_bytes` loops in `serve_shard`).
const BATCH_MAX_BYTES: usize = 64 * 1024;
/// Sleep interval in milliseconds while tailing an empty ring buffer in
/// streaming mode.
const TAIL_POLL_MS: u64 = 1;
/// Shared registry of per-connection handler threads, joined on server drop.
type HandlerHandles = Arc<crate::sync::Mutex<Vec<JoinHandle<()>>>>;
/// Accepts follower connections and streams per-shard replication entries.
pub struct ReplicationServer {
    // Signal used to ask the acceptor and handler threads to exit.
    stop: ShutdownSignal,
    // Thread running the TCP accept loop; joined in `Drop`.
    acceptor_handle: Option<JoinHandle<()>>,
    // Per-connection handler threads; joined in `Drop`.
    handler_handles: HandlerHandles,
    /// Per-shard watermark of the highest GSN acknowledged by a follower;
    /// read by the compaction guard.
    pub min_replicated_gsn: Arc<Vec<AtomicU64>>,
}
impl ReplicationServer {
    /// Binds `bind_addr` and spawns the acceptor thread.
    ///
    /// One GSN watermark is created per shard. Each accepted follower is
    /// handed to `handle_connection`, which performs the handshake and spawns
    /// a dedicated handler thread.
    ///
    /// # Errors
    /// Returns an error if binding the listener or switching it to
    /// non-blocking mode fails.
    pub fn start(
        bind_addr: SocketAddr,
        shards: Arc<Vec<Shard>>,
        consumers: Vec<rtrb::Consumer<ReplicationEntry>>,
        key_lens: Vec<usize>,
        max_file_size: u64,
        signal: ShutdownSignal,
    ) -> DbResult<Self> {
        let shard_count = shards.len();
        let min_replicated_gsn: Arc<Vec<AtomicU64>> =
            Arc::new((0..shard_count).map(|_| AtomicU64::new(0)).collect());
        // Wrap each consumer in Mutex<Option<..>> so exactly one handler
        // thread can `take()` it for streaming mode.
        let consumers: Arc<Vec<crate::sync::Mutex<Option<rtrb::Consumer<ReplicationEntry>>>>> =
            Arc::new(
                consumers
                    .into_iter()
                    .map(|c| crate::sync::Mutex::new(Some(c)))
                    .collect(),
            );
        let listener = TcpListener::bind(bind_addr)?;
        // Non-blocking accept lets the loop poll the shutdown signal.
        listener.set_nonblocking(true)?;
        let stop2 = signal.clone();
        let min_gsn2 = min_replicated_gsn.clone();
        let handler_handles: HandlerHandles = Arc::new(crate::sync::Mutex::new(Vec::new()));
        let hh2 = handler_handles.clone();
        // `shards`, `consumers`, and `key_lens` are not needed on this thread
        // after this point, so they are moved into the closure rather than
        // cloned (the original cloned all three redundantly).
        let acceptor = thread::spawn(move || {
            tracing::info!(%bind_addr, "replication server started");
            while !stop2.is_shutdown() {
                match listener.accept() {
                    Ok((stream, addr)) => {
                        tracing::info!(%addr, "follower connected");
                        let _ = stream.set_nodelay(true);
                        handle_connection(
                            stream,
                            &shards,
                            &consumers,
                            &key_lens,
                            max_file_size,
                            &stop2,
                            &min_gsn2,
                            &hh2,
                        );
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        // No pending connection: back off briefly, waking
                        // early on shutdown.
                        stop2.wait_timeout(Duration::from_millis(50));
                    }
                    Err(e) => {
                        tracing::error!(error = %e, "accept error");
                        stop2.wait_timeout(Duration::from_millis(100));
                    }
                }
            }
            tracing::info!("replication server stopped");
        });
        Ok(Self {
            stop: signal,
            acceptor_handle: Some(acceptor),
            handler_handles,
            min_replicated_gsn,
        })
    }

    /// Requests shutdown of the acceptor and all handler threads.
    pub fn stop(&self) {
        self.stop.shutdown();
    }
}
impl crate::compaction::CompactionGuard for ReplicationServer {
    /// Returns the replicated-GSN watermark for `shard_id`, or `u64::MAX`
    /// when the shard id is out of range (i.e. nothing restrains compaction).
    fn min_replicated_gsn(&self, shard_id: u8) -> u64 {
        match self.min_replicated_gsn.get(usize::from(shard_id)) {
            Some(watermark) => watermark.load(Ordering::Relaxed),
            None => u64::MAX,
        }
    }
}
impl Drop for ReplicationServer {
    /// Signals shutdown, then joins the acceptor and every handler thread so
    /// no replication thread outlives the server.
    fn drop(&mut self) {
        self.stop.shutdown();
        if let Some(acceptor) = self.acceptor_handle.take() {
            let _ = acceptor.join();
        }
        let mut guard = crate::sync::lock(&self.handler_handles);
        for handler in guard.drain(..) {
            let _ = handler.join();
        }
    }
}
#[allow(clippy::too_many_arguments)]
fn handle_connection(
stream: TcpStream,
shards: &Arc<Vec<Shard>>,
consumers: &Arc<Vec<crate::sync::Mutex<Option<rtrb::Consumer<ReplicationEntry>>>>>,
key_lens: &[usize],
max_file_size: u64,
stop: &ShutdownSignal,
min_gsn: &Arc<Vec<AtomicU64>>,
handler_handles: &HandlerHandles,
) {
let mut reader = stream.try_clone().expect("clone tcp stream");
let frame = match read_frame(&mut reader) {
Ok(f) => f,
Err(e) => {
tracing::error!(error = %e, "failed to read SyncRequest");
return;
}
};
if frame.msg_type != MessageType::SyncRequest {
tracing::error!("expected SyncRequest, got {:?}", frame.msg_type);
return;
}
let req = match SyncRequest::decode(&frame.payload) {
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "invalid SyncRequest");
return;
}
};
let shard_id = req.shard_id as usize;
if shard_id >= shards.len() {
tracing::error!(shard_id, "invalid shard_id");
return;
}
let info = ShardInfo {
shard_count: shards.len() as u8,
max_file_size,
};
let mut writer = BufWriter::new(stream);
if write_frame(&mut writer, &info.encode()).is_err() {
return;
}
let consumer = {
let mut guard = crate::sync::lock(&consumers[shard_id]);
guard.take()
};
let shards = shards.clone();
let key_lens = key_lens.to_vec();
let stop = stop.clone();
let min_gsn = min_gsn.clone();
let handle = thread::spawn(move || {
if let Err(e) = serve_shard(
&mut reader,
&mut writer,
&shards,
shard_id,
req.from_gsn,
consumer,
&key_lens,
stop.as_flag(),
&min_gsn,
) {
tracing::error!(shard_id, error = %e, "shard stream error");
}
});
crate::sync::lock(handler_handles).push(handle);
}
/// Streams replication entries for one shard to a connected follower.
///
/// Runs in two phases:
/// 1. **Catch-up** — if the follower's `from_gsn` is behind the shard's
///    current GSN, replay persisted log entries via `ShardLogReader` in
///    batches, waiting for a blocking ack after each batch, then send a
///    `CaughtUp` frame.
/// 2. **Streaming** — if this connection owns the shard's live consumer,
///    tail the ring buffer, sending batches and polling for acks without
///    blocking.
///
/// Acks advance `min_gsn[shard_id]`, which gates compaction.
#[allow(clippy::too_many_arguments)]
fn serve_shard(
    reader: &mut TcpStream,
    writer: &mut BufWriter<TcpStream>,
    shards: &[Shard],
    shard_id: usize,
    from_gsn: u64,
    consumer: Option<rtrb::Consumer<ReplicationEntry>>,
    key_lens: &[usize],
    stop: &AtomicBool,
    min_gsn: &Arc<Vec<AtomicU64>>,
) -> DbResult<()> {
    let shard = &shards[shard_id];
    let shard_dir = shard.dir().to_path_buf();
    let current_gsn = shard.gsn().load(Ordering::Relaxed);
    if from_gsn < current_gsn {
        tracing::info!(shard_id, from_gsn, current_gsn, "starting catch-up");
        // Flush buffered writes so the log reader sees everything on disk.
        shard.flush_buf()?;
        let mut log_reader = ShardLogReader::new(shard_dir, from_gsn, key_lens.to_vec())?;
        let mut last_gsn = from_gsn;
        loop {
            if stop.load(Ordering::Relaxed) {
                return Ok(());
            }
            let mut batch = Vec::new();
            let mut batch_bytes = 0;
            // Fill a batch until the entry/byte caps are reached or the log
            // is drained. The byte cap is soft: the entry that crosses it is
            // still included.
            while batch.len() < BATCH_MAX_ENTRIES && batch_bytes < BATCH_MAX_BYTES {
                match log_reader.next_entry()? {
                    Some(entry) => {
                        last_gsn = entry.gsn;
                        batch_bytes += entry.data.len();
                        batch.push(WireEntry {
                            entry_len: entry.data.len() as u32,
                            key_len: entry.key_len,
                            gsn: entry.gsn,
                            data: entry.data,
                        });
                    }
                    None => break,
                }
            }
            if batch.is_empty() {
                break; }
            let msg = EntryBatch {
                shard_id: shard_id as u8,
                entries: batch,
            };
            write_frame(writer, &msg.encode())?;
            // Catch-up is ack-paced: block for the follower's ack before
            // sending the next batch.
            let ack_frame = read_frame(reader)?;
            if ack_frame.msg_type == MessageType::Ack {
                let ack = AckMessage::decode(&ack_frame.payload)?;
                // fetch_max keeps the watermark monotonically non-decreasing.
                min_gsn[shard_id].fetch_max(ack.last_gsn, Ordering::Relaxed);
            }
        }
        let caught_up = CaughtUp {
            shard_id: shard_id as u8,
            leader_gsn: last_gsn,
        };
        write_frame(writer, &caught_up.encode())?;
        tracing::info!(shard_id, last_gsn, "catch-up complete");
    }
    // Live streaming: only the connection that claimed the shard's consumer
    // tails the ring buffer; other connections for the same shard end here.
    if let Some(mut consumer) = consumer {
        tracing::info!(shard_id, "entering streaming mode");
        loop {
            if stop.load(Ordering::Relaxed) {
                return Ok(());
            }
            let mut batch = Vec::new();
            let mut batch_bytes = 0;
            while batch.len() < BATCH_MAX_ENTRIES && batch_bytes < BATCH_MAX_BYTES {
                match consumer.pop() {
                    Ok(entry) => {
                        batch_bytes += entry.data.len();
                        // Live entries carry the GSN inside the payload
                        // header, unlike catch-up entries which have it on
                        // the log record.
                        let gsn = extract_gsn(&entry.data);
                        batch.push(WireEntry {
                            entry_len: entry.data.len() as u32,
                            key_len: entry.key_len,
                            gsn,
                            data: entry.data,
                        });
                    }
                    Err(_) => break, }
            }
            if batch.is_empty() {
                // Ring buffer empty: brief sleep before polling again.
                thread::sleep(Duration::from_millis(TAIL_POLL_MS));
                continue;
            }
            let msg = EntryBatch {
                shard_id: shard_id as u8,
                entries: batch,
            };
            write_frame(writer, &msg.encode())?;
            // Opportunistic, non-blocking ack poll so a slow follower cannot
            // stall the stream; errors (including WouldBlock) are ignored.
            // NOTE(review): if read_frame can return WouldBlock after
            // consuming part of a frame, the partial frame is discarded and
            // the ack stream may desync — confirm read_frame buffers or
            // retries partial reads.
            let _ = reader.set_nonblocking(true);
            if let Ok(ack_frame) = read_frame(reader)
                && ack_frame.msg_type == MessageType::Ack
                && let Ok(ack) = AckMessage::decode(&ack_frame.payload)
            {
                min_gsn[shard_id].fetch_max(ack.last_gsn, Ordering::Relaxed);
            }
            let _ = reader.set_nonblocking(false);
        }
    }
    Ok(())
}
/// Decodes the GSN from the first 8 (native-endian) bytes of an entry
/// payload, masking off the tombstone flag bit.
///
/// Returns 0 when the payload is shorter than 8 bytes.
fn extract_gsn(data: &[u8]) -> u64 {
    match data.first_chunk::<8>() {
        Some(raw) => u64::from_ne_bytes(*raw) & !crate::entry::TOMBSTONE_BIT,
        None => 0,
    }
}