Skip to main content

armdb/fixed_replication/
server.rs

1//! Leader-side replication server for FixedStore.
2//!
3//! - Spawns a TCP acceptor thread.
4//! - On the first accepted follower connection, installs SPSC producers
5//!   into every shard (lazy install — Mitigation B of spec §5/§7).
6//! - Each accepted connection spawns a per-shard serve thread that does
7//!   Phase 1 (full scan catch-up, honoring FLAG_EMPTY_STATE to skip DELETED
8//!   slots) then Phase 2 (SPSC streaming) if the per-shard consumer is
9//!   still available.
10//!
11//! Single-follower-streaming-per-shard constraint: the SPSC consumer for
12//! each shard is held by at most one connection at a time.  If a concurrent
13//! follower tries to claim a consumer that is already taken the server sends
14//! an Error frame and returns, letting the client reconnect later.  When a
15//! connection finishes (cleanly or on error), `ShardConsumerGuard::drop`
16//! returns the consumer to the pending slot so the next connection can
17//! claim it and enter Phase-2 streaming.
18
19use std::io::BufWriter;
20use std::io::Write as _;
21use std::net::{SocketAddr, TcpListener, TcpStream};
22use std::sync::Arc;
23use std::sync::atomic::{AtomicBool, Ordering};
24use std::thread::{self, JoinHandle};
25use std::time::{Duration, Instant};
26
27use rtrb::{Consumer, Producer, RingBuffer};
28
29use crate::error::{DbError, DbResult};
30use crate::shutdown::ShutdownSignal;
31
32use super::engine_access::ArcEngine;
33use super::event::FixedReplicationEvent;
34use super::protocol::*;
35
/// Capacity, in events, of each per-shard SPSC ring whose producer is
/// installed into the engine and whose consumer feeds Phase-2 streaming.
pub const SPSC_CAPACITY: usize = 8 * 1024;
/// Byte budget for one `read_shard_chunk` call during the Phase-1 scan (the
/// scan always reads at least one slot per chunk).
const SCAN_CHUNK_BYTES: usize = 64 << 10;
38
39// Per-shard slot: holds either (Some producer, Some consumer) before first install,
40// (None, Some consumer) between install and first follower accept,
41// (None, None) after consumer handed off, or (None, None) forever after.
42type PendingSlot = crate::sync::Mutex<(
43    Option<Producer<FixedReplicationEvent>>,
44    Option<Consumer<FixedReplicationEvent>>,
45)>;
46
47pub struct FixedReplicationServer {
48    stop: ShutdownSignal,
49    acceptor_handle: Option<JoinHandle<()>>,
50    handler_handles: Arc<crate::sync::Mutex<Vec<JoinHandle<()>>>>,
51    #[allow(dead_code)]
52    producers_installed: Arc<AtomicBool>,
53}
54
55impl FixedReplicationServer {
56    pub fn start(
57        bind_addr: SocketAddr,
58        engine: ArcEngine,
59        signal: ShutdownSignal,
60    ) -> DbResult<Self> {
61        let shard_count = engine.shard_count();
62        let mut pending: Vec<PendingSlot> = Vec::with_capacity(shard_count);
63        for _ in 0..shard_count {
64            let (p, c) = RingBuffer::new(SPSC_CAPACITY);
65            pending.push(crate::sync::Mutex::new((Some(p), Some(c))));
66        }
67        let pending: Arc<Vec<PendingSlot>> = Arc::new(pending);
68        let producers_installed = Arc::new(AtomicBool::new(false));
69        let handler_handles = Arc::new(crate::sync::Mutex::new(Vec::new()));
70
71        let listener = TcpListener::bind(bind_addr).map_err(DbError::from)?;
72        listener.set_nonblocking(true).ok();
73
74        let acceptor_handle = {
75            let engine = engine.clone();
76            let pending = pending.clone();
77            let producers_installed = producers_installed.clone();
78            let stop = signal.clone();
79            let hh = handler_handles.clone();
80            thread::spawn(move || {
81                acceptor_loop(listener, engine, pending, producers_installed, hh, stop);
82            })
83        };
84
85        Ok(Self {
86            stop: signal,
87            acceptor_handle: Some(acceptor_handle),
88            handler_handles,
89            producers_installed,
90        })
91    }
92
93    pub fn stop(&self) {
94        self.stop.shutdown();
95    }
96}
97
98impl Drop for FixedReplicationServer {
99    fn drop(&mut self) {
100        self.stop.shutdown();
101        if let Some(h) = self.acceptor_handle.take() {
102            let _ = h.join();
103        }
104        let mut handles = crate::sync::lock(&self.handler_handles);
105        for h in handles.drain(..) {
106            let _ = h.join();
107        }
108    }
109}
110
111fn acceptor_loop(
112    listener: TcpListener,
113    engine: ArcEngine,
114    pending: Arc<Vec<PendingSlot>>,
115    producers_installed: Arc<AtomicBool>,
116    handler_handles: Arc<crate::sync::Mutex<Vec<JoinHandle<()>>>>,
117    stop: ShutdownSignal,
118) {
119    while !stop.is_shutdown() {
120        match listener.accept() {
121            Ok((stream, addr)) => {
122                tracing::info!(%addr, "fixed follower connected");
123                stream.set_nodelay(true).ok();
124                // The listener is non-blocking for the acceptor loop; on
125                // BSD/macOS (and in some Linux configs) accepted streams
126                // inherit that flag.  We rely on blocking reads during
127                // Phase-1 Ack round-trips, so force it back.
128                stream.set_nonblocking(false).ok();
129                // Cap Phase-1 blocking reads so a hung follower can't pin
130                // a handler thread forever.  A healthy follower sends
131                // either an Ack immediately after each batch or a
132                // Heartbeat every HEARTBEAT_INTERVAL_SECS; 2× that gives
133                // a comfortable ceiling before we conclude the peer is
134                // gone.  (Phase 2 switches to non-blocking polling and
135                // overrides this setting locally.)
136                stream
137                    .set_read_timeout(Some(Duration::from_secs(HEARTBEAT_INTERVAL_SECS * 2)))
138                    .ok();
139
140                if !producers_installed.swap(true, Ordering::SeqCst) {
141                    tracing::info!("first fixed follower — installing SPSC producers");
142                    let mut producers = Vec::with_capacity(pending.len());
143                    for slot in pending.iter() {
144                        let mut guard = crate::sync::lock(slot);
145                        producers.push(guard.0.take().expect("producer present on first install"));
146                    }
147                    engine.install_replication_producers(producers);
148                }
149
150                let engine = engine.clone();
151                let pending = pending.clone();
152                let stop = stop.clone();
153                let hh = handler_handles.clone();
154                let handle = thread::spawn(move || {
155                    if let Err(e) = serve_connection(stream, engine, pending, stop) {
156                        tracing::error!(error = %e, "fixed replication connection error");
157                    }
158                });
159                let mut handles = crate::sync::lock(&hh);
160                handles.retain(|h| !h.is_finished());
161                handles.push(handle);
162            }
163            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
164                stop.wait_timeout(Duration::from_millis(50));
165            }
166            Err(e) => {
167                tracing::error!(error = %e, "fixed accept error");
168                stop.wait_timeout(Duration::from_millis(100));
169            }
170        }
171    }
172    tracing::info!("fixed replication acceptor stopped");
173}
174
175/// RAII guard that holds the SPSC consumer while a connection is active.
176///
177/// On drop (whether clean or due to an error) the consumer is returned to the
178/// `pending` slot so the next connecting follower can claim it and enter
179/// Phase-2 streaming.
180struct ShardConsumerGuard {
181    pending: Arc<Vec<PendingSlot>>,
182    shard_id: usize,
183    consumer: Option<Consumer<FixedReplicationEvent>>,
184}
185
186impl ShardConsumerGuard {
187    fn new(
188        pending: Arc<Vec<PendingSlot>>,
189        shard_id: usize,
190        consumer: Consumer<FixedReplicationEvent>,
191    ) -> Self {
192        Self {
193            pending,
194            shard_id,
195            consumer: Some(consumer),
196        }
197    }
198
199    fn take(&mut self) -> Consumer<FixedReplicationEvent> {
200        self.consumer.take().expect("consumer present")
201    }
202
203    fn put_back(&mut self, consumer: Consumer<FixedReplicationEvent>) {
204        self.consumer = Some(consumer);
205    }
206}
207
208impl Drop for ShardConsumerGuard {
209    fn drop(&mut self) {
210        if let Some(consumer) = self.consumer.take() {
211            let mut guard = crate::sync::lock(&self.pending[self.shard_id]);
212            if guard.1.is_none() {
213                guard.1 = Some(consumer);
214            }
215        }
216    }
217}
218
219fn serve_connection(
220    stream: TcpStream,
221    engine: ArcEngine,
222    pending: Arc<Vec<PendingSlot>>,
223    stop: ShutdownSignal,
224) -> DbResult<()> {
225    let mut reader = stream.try_clone().map_err(DbError::from)?;
226    let mut writer = BufWriter::new(stream);
227
228    // Read initial SyncRequest.
229    let frame = read_frame(&mut reader).map_err(DbError::from)?;
230    if frame.msg_type != FixedMessageType::SyncRequest {
231        return Err(DbError::Replication(format!(
232            "expected SyncRequest, got {:?}",
233            frame.msg_type
234        )));
235    }
236    let req = SyncRequest::decode(&frame.payload).map_err(DbError::from)?;
237    if req.protocol_version != PROTOCOL_VERSION {
238        let msg = format!(
239            "protocol version mismatch: leader {PROTOCOL_VERSION}, follower {}",
240            req.protocol_version
241        );
242        write_frame(&mut writer, &encode_error(&msg)).map_err(DbError::from)?;
243        return Err(DbError::Replication(msg));
244    }
245    let shard_id = req.shard_id as usize;
246    if shard_id >= engine.shard_count() {
247        write_frame(&mut writer, &encode_error("invalid shard_id")).map_err(DbError::from)?;
248        return Err(DbError::Replication(format!("invalid shard_id {shard_id}")));
249    }
250
251    // Send ShardInfo.
252    let info = ShardInfo {
253        shard_count: engine.shard_count() as u8,
254        key_len: engine.key_len() as u16,
255        value_len: engine.value_len() as u16,
256        slot_size: engine.slot_size(),
257        current_slot_count: engine.current_slot_count(shard_id),
258        shard_prefix_bits: engine.shard_prefix_bits(),
259    };
260    write_frame(&mut writer, &info.encode()).map_err(DbError::from)?;
261
262    let skip_deleted = (req.flags & FLAG_EMPTY_STATE) != 0;
263    tracing::info!(
264        shard_id,
265        skip_deleted,
266        protocol_version = req.protocol_version,
267        "fixed follower handshake complete"
268    );
269
270    // Take the SPSC consumer for this shard.
271    // If another connection already holds it, reject with an error frame.
272    let consumer = {
273        let mut guard = crate::sync::lock(&pending[shard_id]);
274        match guard.1.take() {
275            Some(c) => c,
276            None => {
277                let msg = "shard already streaming";
278                write_frame(&mut writer, &encode_error(msg)).map_err(DbError::from)?;
279                return Err(DbError::Replication(msg.to_string()));
280            }
281        }
282    };
283
284    // Wrap the consumer in a guard so it is returned to the pending slot on
285    // any exit path (clean shutdown, network error, or panic).
286    let mut consumer_guard = ShardConsumerGuard::new(pending.clone(), shard_id, consumer);
287
288    // Phase 1: full scan — consumer stays inside the guard so that if Phase 1
289    // fails the guard's Drop returns the consumer to the pending slot.
290    let total = phase1_full_scan(
291        &engine,
292        shard_id,
293        &mut writer,
294        &mut reader,
295        skip_deleted,
296        &stop,
297    )?;
298    write_frame(
299        &mut writer,
300        &CaughtUp {
301            shard_id: shard_id as u8,
302            total_scanned: total,
303        }
304        .encode(),
305    )
306    .map_err(DbError::from)?;
307    tracing::info!(shard_id, total, "fixed catch-up complete");
308
309    // Phase 2: streaming — take consumer only now (phase2 always returns it).
310    let consumer = consumer_guard.take();
311    let consumer = phase2_streaming(&engine, shard_id, consumer, &mut writer, &mut reader, &stop)?;
312    consumer_guard.put_back(consumer);
313    Ok(())
314}
315
316fn phase1_full_scan(
317    engine: &ArcEngine,
318    shard_id: usize,
319    writer: &mut BufWriter<TcpStream>,
320    reader: &mut TcpStream,
321    skip_deleted: bool,
322    stop: &ShutdownSignal,
323) -> DbResult<u64> {
324    use crate::fixed::slot::{
325        SLOT_HEADER_SIZE, STATUS_DELETED, STATUS_FREE, STATUS_OCCUPIED, meta_of, pack_meta,
326        read_slot, status_of, version_of,
327    };
328
329    let slot_size = engine.slot_size() as usize;
330    let key_len = engine.key_len();
331    let value_len = engine.value_len();
332    let slot_count = engine.current_slot_count(shard_id);
333    let slots_per_chunk = (SCAN_CHUNK_BYTES / slot_size).max(1);
334
335    let mut total_scanned = 0u64;
336    let mut batch = SlotBatchEncoder::new(shard_id as u8, key_len, value_len);
337
338    let mut slot_id = 0u32;
339    while slot_id < slot_count {
340        if stop.is_shutdown() {
341            return Ok(total_scanned);
342        }
343        let remaining = slot_count - slot_id;
344        let this_chunk = remaining.min(slots_per_chunk as u32) as usize;
345        let chunk = engine.read_shard_chunk(shard_id, slot_id, this_chunk)?;
346
347        for i in 0..this_chunk {
348            let off = i * slot_size;
349            let slot_buf = &chunk[off..off + slot_size];
350            let meta = meta_of(slot_buf);
351            let status = status_of(meta);
352            let current_slot = slot_id + i as u32;
353
354            match status {
355                STATUS_OCCUPIED => {
356                    if let Some((_meta, key, value)) = read_slot(slot_buf, key_len, value_len) {
357                        batch.add_occupied(current_slot, meta, key, value);
358                        total_scanned += 1;
359                    } else {
360                        let reset_meta = pack_meta(STATUS_FREE, version_of(meta));
361                        batch.add_reset(current_slot, reset_meta);
362                        total_scanned += 1;
363                    }
364                }
365                STATUS_DELETED => {
366                    if !skip_deleted {
367                        let key = &slot_buf[SLOT_HEADER_SIZE..SLOT_HEADER_SIZE + key_len];
368                        batch.add_deleted(current_slot, meta, key);
369                        total_scanned += 1;
370                    }
371                }
372                STATUS_FREE => {
373                    if version_of(meta) != 0 {
374                        let reset_meta = pack_meta(STATUS_FREE, version_of(meta));
375                        batch.add_reset(current_slot, reset_meta);
376                        total_scanned += 1;
377                    }
378                }
379                other => {
380                    return Err(DbError::Replication(format!(
381                        "fixed Phase-1 scan found unknown slot status {other} at shard {shard_id} slot {current_slot}"
382                    )));
383                }
384            }
385
386            if !batch.is_empty()
387                && (batch.len() as usize >= BATCH_MAX_ENTRIES || batch.bytes() >= BATCH_MAX_BYTES)
388            {
389                flush_and_wait_ack(writer, reader, batch, engine, shard_id)?;
390                batch = SlotBatchEncoder::new(shard_id as u8, key_len, value_len);
391            }
392        }
393
394        slot_id += this_chunk as u32;
395    }
396
397    if !batch.is_empty() {
398        flush_and_wait_ack(writer, reader, batch, engine, shard_id)?;
399    }
400
401    metrics::counter!(
402        "armdb.fixed.catchup_slots_scanned",
403        "shard" => shard_id.to_string()
404    )
405    .increment(total_scanned);
406
407    Ok(total_scanned)
408}
409
410fn flush_and_wait_ack(
411    writer: &mut BufWriter<TcpStream>,
412    reader: &mut TcpStream,
413    batch: SlotBatchEncoder,
414    engine: &ArcEngine,
415    shard_id: usize,
416) -> DbResult<()> {
417    // A well-behaved follower emits at most one Heartbeat per
418    // HEARTBEAT_INTERVAL_SECS (=5s) while idle.  The per-stream read
419    // timeout (set in `acceptor_loop`) bounds wall-clock time, but a
420    // peer that floods heartbeats at sub-timeout intervals could still
421    // keep the loop hot forever, so cap the count as defence-in-depth.
422    const MAX_HEARTBEATS: u32 = 8;
423
424    let frame = batch.finish();
425    write_frame(writer, &frame).map_err(DbError::from)?;
426    writer.flush().map_err(DbError::from)?;
427    // Loop until we consume an Ack. The follower may have queued a
428    // Heartbeat before our batch arrived (sent during its idle
429    // read-timeout branch); skip those and keep reading until Ack.
430    let mut heartbeats_skipped: u32 = 0;
431    loop {
432        let ack_frame = read_frame(reader).map_err(DbError::from)?;
433        match ack_frame.msg_type {
434            FixedMessageType::Ack => {
435                let ack = Ack::decode(&ack_frame.payload).map_err(DbError::from)?;
436                engine.update_min_replicated_version(shard_id, ack.max_version_seen);
437                return Ok(());
438            }
439            FixedMessageType::Heartbeat => {
440                // Follower-side keepalive while waiting for our batch; ignore.
441                heartbeats_skipped += 1;
442                if heartbeats_skipped > MAX_HEARTBEATS {
443                    return Err(DbError::Replication(format!(
444                        "too many consecutive heartbeats ({heartbeats_skipped}) \
445                         while waiting for Phase-1 Ack"
446                    )));
447                }
448                continue;
449            }
450            other => {
451                return Err(DbError::Replication(format!(
452                    "expected Ack during Phase-1 catch-up, got {other:?}"
453                )));
454            }
455        }
456    }
457}
458
459fn phase2_streaming(
460    engine: &ArcEngine,
461    shard_id: usize,
462    mut consumer: Consumer<FixedReplicationEvent>,
463    writer: &mut BufWriter<TcpStream>,
464    reader: &mut TcpStream,
465    stop: &ShutdownSignal,
466) -> DbResult<Consumer<FixedReplicationEvent>> {
467    use crate::fixed::slot::{SLOT_HEADER_SIZE, meta_of};
468
469    let key_len = engine.key_len();
470    let value_len = engine.value_len();
471    let slot_size = engine.slot_size() as usize;
472    let mut last_heartbeat = Instant::now();
473    let hb_interval = Duration::from_secs(HEARTBEAT_INTERVAL_SECS);
474
475    reader.set_nonblocking(true).ok();
476
477    loop {
478        if stop.is_shutdown() {
479            return Ok(consumer);
480        }
481
482        let mut batch = SlotBatchEncoder::new(shard_id as u8, key_len, value_len);
483        while (batch.len() as usize) < BATCH_MAX_ENTRIES && batch.bytes() < BATCH_MAX_BYTES {
484            match consumer.pop() {
485                Ok(FixedReplicationEvent::Write { slot_id, payload }) => {
486                    debug_assert_eq!(payload.len(), slot_size);
487                    let meta = meta_of(&payload);
488                    let key = &payload[SLOT_HEADER_SIZE..SLOT_HEADER_SIZE + key_len];
489                    let value = &payload
490                        [SLOT_HEADER_SIZE + key_len..SLOT_HEADER_SIZE + key_len + value_len];
491                    batch.add_occupied(slot_id, meta, key, value);
492                }
493                Ok(FixedReplicationEvent::Delete { slot_id, meta, key }) => {
494                    batch.add_deleted(slot_id, meta, &key);
495                }
496                Err(_) => break,
497            }
498        }
499
500        if !batch.is_empty() {
501            let frame_events = batch.len() as u64;
502            let frame = batch.finish();
503            if write_frame(writer, &frame)
504                .and_then(|_| writer.flush())
505                .is_err()
506            {
507                // Peer disconnected during batch write.
508                return Ok(consumer);
509            }
510            metrics::counter!(
511                "armdb.fixed.streaming_events_sent",
512                "shard" => shard_id.to_string()
513            )
514            .increment(frame_events);
515
516            // Non-blocking Ack check.
517            match read_frame(reader) {
518                Ok(f) if f.msg_type == FixedMessageType::Ack => {
519                    if let Ok(ack) = Ack::decode(&f.payload) {
520                        engine.update_min_replicated_version(shard_id, ack.max_version_seen);
521                    }
522                }
523                Ok(_) => {}
524                Err(ref e)
525                    if e.kind() == std::io::ErrorKind::WouldBlock
526                        || e.kind() == std::io::ErrorKind::TimedOut => {}
527                Err(e) => {
528                    // Peer disconnected or connection reset — treat as a clean
529                    // exit so the consumer can be returned to the pending slot
530                    // for the next connecting follower.
531                    tracing::debug!(shard_id, error = %e, "fixed streaming: peer disconnected");
532                    return Ok(consumer);
533                }
534            }
535        } else {
536            // Idle: no events in consumer ring buffer.
537            // Non-blocking probe to detect peer disconnect before the
538            // heartbeat timeout fires (so the consumer is returned promptly
539            // to the pending slot for the next reconnecting follower).
540            match read_frame(reader) {
541                Err(ref e)
542                    if e.kind() == std::io::ErrorKind::WouldBlock
543                        || e.kind() == std::io::ErrorKind::TimedOut => {}
544                Ok(f) if f.msg_type == FixedMessageType::Heartbeat => {
545                    // Follower keepalive; ignore in Phase 2 idle.
546                }
547                Ok(_) => {}
548                Err(e) => {
549                    // Peer disconnected — return consumer for next follower.
550                    tracing::debug!(shard_id, error = %e, "fixed streaming idle: peer disconnected");
551                    return Ok(consumer);
552                }
553            }
554            if last_heartbeat.elapsed() >= hb_interval {
555                if write_frame(writer, &encode_heartbeat())
556                    .and_then(|_| writer.flush())
557                    .is_err()
558                {
559                    // Peer disconnected during heartbeat write.
560                    return Ok(consumer);
561                }
562                last_heartbeat = Instant::now();
563            }
564            thread::sleep(Duration::from_millis(TAIL_POLL_MS));
565        }
566    }
567}