use std::io::BufWriter;
use std::io::Write as _;
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use rtrb::{Consumer, Producer, RingBuffer};
use crate::error::{DbError, DbResult};
use crate::shutdown::ShutdownSignal;
use super::engine_access::ArcEngine;
use super::event::FixedReplicationEvent;
use super::protocol::*;
/// Capacity passed to `RingBuffer::new` for each per-shard SPSC replication queue.
pub const SPSC_CAPACITY: usize = 8192;
/// Byte budget (64 KiB) used to derive how many slots are requested per
/// `read_shard_chunk` call during the Phase-1 scan.
const SCAN_CHUNK_BYTES: usize = 64 * 1024;
/// Per-shard holder for the two halves of that shard's SPSC ring buffer until they are claimed.
///
/// Element `.0` is the producer, taken by `acceptor_loop` when the first follower connects.
/// Element `.1` is the consumer, taken by `serve_connection` for the first connection that
/// requests the shard. `None` means that half has already been handed out.
type PendingSlot = crate::sync::Mutex<(
Option<Producer<FixedReplicationEvent>>,
Option<Consumer<FixedReplicationEvent>>,
)>;
/// Leader-side server for the fixed replication protocol.
///
/// Holds the shutdown signal plus the join handles of the acceptor thread and of every
/// connection handler thread; dropping the server signals shutdown and joins all of them.
pub struct FixedReplicationServer {
/// Shutdown signal shared with the acceptor thread and every connection handler.
stop: ShutdownSignal,
/// Join handle of the acceptor thread; taken (left as `None`) when joined in `Drop`.
acceptor_handle: Option<JoinHandle<()>>,
/// Join handles of the connection handler threads, pushed by the acceptor thread.
handler_handles: Arc<crate::sync::Mutex<Vec<JoinHandle<()>>>>,
// The flag is only toggled through the clone owned by the acceptor thread; this field is
// not read anywhere in the visible code, hence the `dead_code` allowance.
#[allow(dead_code)]
/// Set to `true` once the SPSC producers have been installed into the engine.
producers_installed: Arc<AtomicBool>,
}
impl FixedReplicationServer {
    /// Starts the replication server.
    ///
    /// Creates one SPSC ring buffer (capacity [`SPSC_CAPACITY`]) per engine shard, binds a
    /// non-blocking TCP listener on `bind_addr`, and spawns the acceptor thread that hands
    /// each accepted connection to its own handler thread.
    ///
    /// `signal` is the shutdown signal observed by the acceptor and all handlers; it is also
    /// the signal triggered by [`stop`](Self::stop) and by dropping the server.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound, or if it cannot be switched to
    /// non-blocking mode.
    pub fn start(
        bind_addr: SocketAddr,
        engine: ArcEngine,
        signal: ShutdownSignal,
    ) -> DbResult<Self> {
        let shard_count = engine.shard_count();
        let mut pending: Vec<PendingSlot> = Vec::with_capacity(shard_count);
        for _ in 0..shard_count {
            let (p, c) = RingBuffer::new(SPSC_CAPACITY);
            pending.push(crate::sync::Mutex::new((Some(p), Some(c))));
        }
        let pending: Arc<Vec<PendingSlot>> = Arc::new(pending);
        let producers_installed = Arc::new(AtomicBool::new(false));
        let handler_handles = Arc::new(crate::sync::Mutex::new(Vec::new()));

        let listener = TcpListener::bind(bind_addr).map_err(DbError::from)?;
        // The acceptor loop relies on `accept()` returning `WouldBlock` so that it can poll
        // the shutdown signal. If the listener stayed blocking, the acceptor would never
        // observe shutdown and `Drop` would hang while joining it, so a failure here must
        // not be ignored.
        listener.set_nonblocking(true).map_err(DbError::from)?;

        let acceptor_handle = {
            let engine = engine.clone();
            let pending = pending.clone();
            let producers_installed = producers_installed.clone();
            let stop = signal.clone();
            let hh = handler_handles.clone();
            thread::spawn(move || {
                acceptor_loop(listener, engine, pending, producers_installed, hh, stop);
            })
        };

        Ok(Self {
            stop: signal,
            acceptor_handle: Some(acceptor_handle),
            handler_handles,
            producers_installed,
        })
    }

    /// Triggers the shutdown signal without waiting for the threads to exit.
    ///
    /// The threads are joined when the server is dropped.
    pub fn stop(&self) {
        self.stop.shutdown();
    }
}
impl Drop for FixedReplicationServer {
    /// Signals shutdown, then blocks until the acceptor thread and every registered
    /// handler thread has exited. Join results (including panics) are discarded.
    fn drop(&mut self) {
        self.stop.shutdown();

        // The acceptor is joined first: it is the only thread that pushes handler handles,
        // so once it has exited the handle list can no longer grow.
        if let Some(acceptor) = self.acceptor_handle.take() {
            drop(acceptor.join());
        }

        crate::sync::lock(&self.handler_handles)
            .drain(..)
            .for_each(|worker| {
                let _ = worker.join();
            });
    }
}
fn acceptor_loop(
listener: TcpListener,
engine: ArcEngine,
pending: Arc<Vec<PendingSlot>>,
producers_installed: Arc<AtomicBool>,
handler_handles: Arc<crate::sync::Mutex<Vec<JoinHandle<()>>>>,
stop: ShutdownSignal,
) {
while !stop.is_shutdown() {
match listener.accept() {
Ok((stream, addr)) => {
tracing::info!(%addr, "fixed follower connected");
stream.set_nodelay(true).ok();
stream.set_nonblocking(false).ok();
stream
.set_read_timeout(Some(Duration::from_secs(HEARTBEAT_INTERVAL_SECS * 2)))
.ok();
if !producers_installed.swap(true, Ordering::SeqCst) {
tracing::info!("first fixed follower — installing SPSC producers");
let mut producers = Vec::with_capacity(pending.len());
for slot in pending.iter() {
let mut guard = crate::sync::lock(slot);
producers.push(guard.0.take().expect("producer present on first install"));
}
engine.install_replication_producers(producers);
}
let engine = engine.clone();
let pending = pending.clone();
let stop = stop.clone();
let hh = handler_handles.clone();
let handle = thread::spawn(move || {
if let Err(e) = serve_connection(stream, engine, pending, stop) {
tracing::error!(error = %e, "fixed replication connection error");
}
});
crate::sync::lock(&hh).push(handle);
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
stop.wait_timeout(Duration::from_millis(50));
}
Err(e) => {
tracing::error!(error = %e, "fixed accept error");
stop.wait_timeout(Duration::from_millis(100));
}
}
}
tracing::info!("fixed replication acceptor stopped");
}
fn serve_connection(
stream: TcpStream,
engine: ArcEngine,
pending: Arc<Vec<PendingSlot>>,
stop: ShutdownSignal,
) -> DbResult<()> {
let mut reader = stream.try_clone().map_err(DbError::from)?;
let mut writer = BufWriter::new(stream);
let frame = read_frame(&mut reader).map_err(DbError::from)?;
if frame.msg_type != FixedMessageType::SyncRequest {
return Err(DbError::Replication(format!(
"expected SyncRequest, got {:?}",
frame.msg_type
)));
}
let req = SyncRequest::decode(&frame.payload).map_err(DbError::from)?;
if req.protocol_version != PROTOCOL_VERSION {
let msg = format!(
"protocol version mismatch: leader {PROTOCOL_VERSION}, follower {}",
req.protocol_version
);
write_frame(&mut writer, &encode_error(&msg)).map_err(DbError::from)?;
return Err(DbError::Replication(msg));
}
let shard_id = req.shard_id as usize;
if shard_id >= engine.shard_count() {
write_frame(&mut writer, &encode_error("invalid shard_id")).map_err(DbError::from)?;
return Err(DbError::Replication(format!("invalid shard_id {shard_id}")));
}
let info = ShardInfo {
shard_count: engine.shard_count() as u8,
key_len: engine.key_len() as u16,
value_len: engine.value_len() as u16,
slot_size: engine.slot_size(),
current_slot_count: engine.current_slot_count(shard_id),
shard_prefix_bits: engine.shard_prefix_bits(),
};
write_frame(&mut writer, &info.encode()).map_err(DbError::from)?;
let skip_deleted = (req.flags & FLAG_EMPTY_STATE) != 0;
tracing::info!(
shard_id,
skip_deleted,
protocol_version = req.protocol_version,
"fixed follower handshake complete"
);
let consumer = {
let mut guard = crate::sync::lock(&pending[shard_id]);
guard.1.take()
};
serve_shard(
reader,
writer,
engine,
shard_id,
consumer,
skip_deleted,
stop,
)
}
#[allow(clippy::too_many_arguments)]
fn serve_shard(
mut reader: TcpStream,
mut writer: BufWriter<TcpStream>,
engine: ArcEngine,
shard_id: usize,
consumer: Option<Consumer<FixedReplicationEvent>>,
skip_deleted: bool,
stop: ShutdownSignal,
) -> DbResult<()> {
let total = phase1_full_scan(
&engine,
shard_id,
&mut writer,
&mut reader,
skip_deleted,
&stop,
)?;
write_frame(
&mut writer,
&CaughtUp {
shard_id: shard_id as u8,
total_scanned: total,
}
.encode(),
)
.map_err(DbError::from)?;
tracing::info!(shard_id, total, "fixed catch-up complete");
if let Some(consumer) = consumer {
phase2_streaming(&engine, shard_id, consumer, &mut writer, &mut reader, &stop)?;
} else {
tracing::warn!(
shard_id,
"fixed SPSC consumer already taken; catch-up only mode"
);
idle_until_disconnect(&mut reader, &mut writer, &stop)?;
}
Ok(())
}
/// Phase 1: walks every slot of shard `shard_id` and ships the non-free ones to the
/// follower in acknowledged batches.
///
/// The shard is read in chunks of roughly `SCAN_CHUNK_BYTES`. Free slots are always
/// skipped; deleted slots are skipped as well when `skip_deleted` is set. A batch is
/// flushed (and its Ack awaited) as soon as it reaches `BATCH_MAX_ENTRIES` entries or
/// `BATCH_MAX_BYTES` bytes, and once more at the end if anything is left over.
///
/// Returns the number of slots added to batches. If shutdown is signalled at a chunk
/// boundary the function returns early with the count so far; a partially filled batch is
/// not sent in that case and the metric is not recorded.
///
/// # Errors
///
/// Propagates errors from `read_shard_chunk` and from `flush_and_wait_ack`.
fn phase1_full_scan(
    engine: &ArcEngine,
    shard_id: usize,
    writer: &mut BufWriter<TcpStream>,
    reader: &mut TcpStream,
    skip_deleted: bool,
    stop: &ShutdownSignal,
) -> DbResult<u64> {
    use crate::fixed::slot::{
        SLOT_HEADER_SIZE, STATUS_DELETED, STATUS_FREE, STATUS_OCCUPIED, meta_of, status_of,
    };

    let slot_size = engine.slot_size() as usize;
    let key_len = engine.key_len();
    let value_len = engine.value_len();
    let slot_count = engine.current_slot_count(shard_id);

    // At least one slot per chunk, even when a single slot exceeds the byte budget.
    let slots_per_chunk = (SCAN_CHUNK_BYTES / slot_size).max(1);
    // Slot layout: [header | key | value].
    let key_end = SLOT_HEADER_SIZE + key_len;
    let value_end = key_end + value_len;

    let mut total_scanned = 0u64;
    let mut batch = SlotBatchEncoder::new(shard_id as u8, key_len, value_len);
    let mut next_slot = 0u32;

    while next_slot < slot_count {
        if stop.is_shutdown() {
            return Ok(total_scanned);
        }

        let in_chunk = (slot_count - next_slot).min(slots_per_chunk as u32) as usize;
        let chunk = engine.read_shard_chunk(shard_id, next_slot, in_chunk)?;

        for idx in 0..in_chunk {
            let start = idx * slot_size;
            let slot_buf = &chunk[start..start + slot_size];
            let meta = meta_of(slot_buf);
            let status = status_of(meta);

            let skipped = status == STATUS_FREE || (skip_deleted && status == STATUS_DELETED);
            if skipped {
                continue;
            }

            let key = &slot_buf[SLOT_HEADER_SIZE..key_end];
            let current_slot = next_slot + idx as u32;
            if status == STATUS_OCCUPIED {
                batch.add_occupied(current_slot, meta, key, &slot_buf[key_end..value_end]);
            } else {
                batch.add_deleted(current_slot, meta, key);
            }
            total_scanned += 1;

            let batch_full =
                batch.len() as usize >= BATCH_MAX_ENTRIES || batch.bytes() >= BATCH_MAX_BYTES;
            if batch_full {
                flush_and_wait_ack(writer, reader, batch, engine, shard_id)?;
                // The flush consumed the encoder; start a fresh one for the next batch.
                batch = SlotBatchEncoder::new(shard_id as u8, key_len, value_len);
            }
        }

        next_slot += in_chunk as u32;
    }

    if !batch.is_empty() {
        flush_and_wait_ack(writer, reader, batch, engine, shard_id)?;
    }

    metrics::counter!(
        "armdb.fixed.catchup_slots_scanned",
        "shard" => shard_id.to_string()
    )
    .increment(total_scanned);

    Ok(total_scanned)
}
/// Sends `batch` to the follower and blocks until the follower acknowledges it.
///
/// While waiting, up to `MAX_HEARTBEATS` heartbeat frames are tolerated and skipped. On
/// Ack, the engine's minimum replicated version for `shard_id` is updated from the Ack's
/// `max_version_seen`.
///
/// # Errors
///
/// Returns an error on I/O or decode failure, when more than `MAX_HEARTBEATS` heartbeats
/// arrive before the Ack, or when any other frame type is received.
fn flush_and_wait_ack(
    writer: &mut BufWriter<TcpStream>,
    reader: &mut TcpStream,
    batch: SlotBatchEncoder,
    engine: &ArcEngine,
    shard_id: usize,
) -> DbResult<()> {
    const MAX_HEARTBEATS: u32 = 8;

    write_frame(writer, &batch.finish()).map_err(DbError::from)?;
    writer.flush().map_err(DbError::from)?;

    let mut heartbeats_skipped: u32 = 0;
    loop {
        let reply = read_frame(reader).map_err(DbError::from)?;
        match reply.msg_type {
            FixedMessageType::Heartbeat => {
                heartbeats_skipped += 1;
                if heartbeats_skipped <= MAX_HEARTBEATS {
                    continue;
                }
                return Err(DbError::Replication(format!(
                    "too many consecutive heartbeats ({heartbeats_skipped}) while waiting for Phase-1 Ack"
                )));
            }
            FixedMessageType::Ack => {
                let ack = Ack::decode(&reply.payload).map_err(DbError::from)?;
                engine.update_min_replicated_version(shard_id, ack.max_version_seen);
                return Ok(());
            }
            other => {
                return Err(DbError::Replication(format!(
                    "expected Ack during Phase-1 catch-up, got {other:?}"
                )));
            }
        }
    }
}
/// Phase 2: forwards live replication events for shard `shard_id` to the follower until
/// shutdown is signalled or an I/O error occurs.
///
/// Each loop iteration drains events from `consumer` into one batch (bounded by
/// `BATCH_MAX_ENTRIES` / `BATCH_MAX_BYTES`). A non-empty batch is written and flushed,
/// then a single non-blocking attempt is made to read an Ack. When there is nothing to
/// send, a heartbeat is emitted once `HEARTBEAT_INTERVAL_SECS` has elapsed since the last
/// heartbeat, and the thread sleeps for `TAIL_POLL_MS`.
///
/// Returns `Ok(())` only on shutdown; write errors and non-retryable ack-read errors are
/// returned as `Err`.
fn phase2_streaming(
engine: &ArcEngine,
shard_id: usize,
mut consumer: Consumer<FixedReplicationEvent>,
writer: &mut BufWriter<TcpStream>,
reader: &mut TcpStream,
stop: &ShutdownSignal,
) -> DbResult<()> {
use crate::fixed::slot::{SLOT_HEADER_SIZE, meta_of};
let key_len = engine.key_len();
let value_len = engine.value_len();
let slot_size = engine.slot_size() as usize;
let mut last_heartbeat = Instant::now();
let hb_interval = Duration::from_secs(HEARTBEAT_INTERVAL_SECS);
// Ack reads below must not stall the event stream, so the reader is made non-blocking.
// NOTE(review): `serve_connection` obtains `reader` via `try_clone` of the stream wrapped
// by `writer`; std documents that options set on one handle propagate to the other, so
// this presumably makes writes non-blocking too — confirm that `write_frame`/`flush`
// cannot fail with `WouldBlock` under follower backpressure.
// NOTE(review): if `read_frame` hits `WouldBlock` after consuming part of a frame, the
// remaining bytes may be misparsed on the next call — confirm how `read_frame` handles
// partial reads.
reader.set_nonblocking(true).ok();
loop {
if stop.is_shutdown() {
return Ok(());
}
let mut batch = SlotBatchEncoder::new(shard_id as u8, key_len, value_len);
// Drain queued events until the batch is full or `pop` fails (nothing available).
while (batch.len() as usize) < BATCH_MAX_ENTRIES && batch.bytes() < BATCH_MAX_BYTES {
match consumer.pop() {
Ok(FixedReplicationEvent::Write { slot_id, payload }) => {
// The payload is a raw slot image laid out as [header | key | value].
debug_assert_eq!(payload.len(), slot_size);
let meta = meta_of(&payload);
let key = &payload[SLOT_HEADER_SIZE..SLOT_HEADER_SIZE + key_len];
let value = &payload
[SLOT_HEADER_SIZE + key_len..SLOT_HEADER_SIZE + key_len + value_len];
batch.add_occupied(slot_id, meta, key, value);
}
Ok(FixedReplicationEvent::Delete { slot_id, meta, key }) => {
batch.add_deleted(slot_id, meta, &key);
}
Err(_) => break,
}
}
if !batch.is_empty() {
let frame_events = batch.len() as u64;
let frame = batch.finish();
write_frame(writer, &frame).map_err(DbError::from)?;
writer.flush().map_err(DbError::from)?;
metrics::counter!(
"armdb.fixed.streaming_events_sent",
"shard" => shard_id.to_string()
)
.increment(frame_events);
// Opportunistic ack poll: at most one frame is read per sent batch, and the absence
// of a frame (`WouldBlock` / `TimedOut`) is not an error.
match read_frame(reader) {
Ok(f) if f.msg_type == FixedMessageType::Ack => {
// An Ack that fails to decode is silently dropped.
if let Ok(ack) = Ack::decode(&f.payload) {
engine.update_min_replicated_version(shard_id, ack.max_version_seen);
}
}
// Any other frame type from the follower is ignored here.
Ok(_) => {}
Err(ref e)
if e.kind() == std::io::ErrorKind::WouldBlock
|| e.kind() == std::io::ErrorKind::TimedOut => {}
Err(e) => {
return Err(DbError::Replication(format!("ack read error: {e}")));
}
}
} else {
// Idle: send a heartbeat when one is due, then back off before polling again.
// `last_heartbeat` is only refreshed here, not when a data batch is sent.
if last_heartbeat.elapsed() >= hb_interval {
write_frame(writer, &encode_heartbeat()).map_err(DbError::from)?;
writer.flush().map_err(DbError::from)?;
last_heartbeat = Instant::now();
}
thread::sleep(Duration::from_millis(TAIL_POLL_MS));
}
}
}
fn idle_until_disconnect(
reader: &mut TcpStream,
writer: &mut BufWriter<TcpStream>,
stop: &ShutdownSignal,
) -> DbResult<()> {
reader.set_read_timeout(Some(Duration::from_secs(1))).ok();
loop {
if stop.is_shutdown() {
return Ok(());
}
match read_frame(reader) {
Ok(f) if f.msg_type == FixedMessageType::Heartbeat => {
write_frame(writer, &encode_heartbeat()).map_err(DbError::from)?;
writer.flush().map_err(DbError::from)?;
}
Ok(_) => {}
Err(ref e)
if e.kind() == std::io::ErrorKind::TimedOut
|| e.kind() == std::io::ErrorKind::WouldBlock =>
{
continue;
}
Err(_) => return Ok(()),
}
}
}