1use std::io::BufWriter;
20use std::io::Write as _;
21use std::net::{SocketAddr, TcpListener, TcpStream};
22use std::sync::Arc;
23use std::sync::atomic::{AtomicBool, Ordering};
24use std::thread::{self, JoinHandle};
25use std::time::{Duration, Instant};
26
27use rtrb::{Consumer, Producer, RingBuffer};
28
29use crate::error::{DbError, DbResult};
30use crate::shutdown::ShutdownSignal;
31
32use super::engine_access::ArcEngine;
33use super::event::FixedReplicationEvent;
34use super::protocol::*;
35
/// Number of events each per-shard SPSC ring buffer is created with.
pub const SPSC_CAPACITY: usize = 8 * 1024;
/// Byte budget used to size one `read_shard_chunk` request during the catch-up scan.
const SCAN_CHUNK_BYTES: usize = 1 << 16;
38
/// Per-shard holding slot for the two ends of that shard's SPSC ring buffer.
///
/// `.0` is the producer: it is taken once, when the first follower connects, and handed
/// to the engine. `.1` is the consumer: it is taken by the connection that streams the
/// shard and is put back when that connection ends, so `None` means "already streaming".
type PendingSlot = crate::sync::Mutex<(
    Option<Producer<FixedReplicationEvent>>,
    Option<Consumer<FixedReplicationEvent>>,
)>;
46
47pub struct FixedReplicationServer {
48 stop: ShutdownSignal,
49 acceptor_handle: Option<JoinHandle<()>>,
50 handler_handles: Arc<crate::sync::Mutex<Vec<JoinHandle<()>>>>,
51 #[allow(dead_code)]
52 producers_installed: Arc<AtomicBool>,
53}
54
55impl FixedReplicationServer {
56 pub fn start(
57 bind_addr: SocketAddr,
58 engine: ArcEngine,
59 signal: ShutdownSignal,
60 ) -> DbResult<Self> {
61 let shard_count = engine.shard_count();
62 let mut pending: Vec<PendingSlot> = Vec::with_capacity(shard_count);
63 for _ in 0..shard_count {
64 let (p, c) = RingBuffer::new(SPSC_CAPACITY);
65 pending.push(crate::sync::Mutex::new((Some(p), Some(c))));
66 }
67 let pending: Arc<Vec<PendingSlot>> = Arc::new(pending);
68 let producers_installed = Arc::new(AtomicBool::new(false));
69 let handler_handles = Arc::new(crate::sync::Mutex::new(Vec::new()));
70
71 let listener = TcpListener::bind(bind_addr).map_err(DbError::from)?;
72 listener.set_nonblocking(true).ok();
73
74 let acceptor_handle = {
75 let engine = engine.clone();
76 let pending = pending.clone();
77 let producers_installed = producers_installed.clone();
78 let stop = signal.clone();
79 let hh = handler_handles.clone();
80 thread::spawn(move || {
81 acceptor_loop(listener, engine, pending, producers_installed, hh, stop);
82 })
83 };
84
85 Ok(Self {
86 stop: signal,
87 acceptor_handle: Some(acceptor_handle),
88 handler_handles,
89 producers_installed,
90 })
91 }
92
93 pub fn stop(&self) {
94 self.stop.shutdown();
95 }
96}
97
98impl Drop for FixedReplicationServer {
99 fn drop(&mut self) {
100 self.stop.shutdown();
101 if let Some(h) = self.acceptor_handle.take() {
102 let _ = h.join();
103 }
104 let mut handles = crate::sync::lock(&self.handler_handles);
105 for h in handles.drain(..) {
106 let _ = h.join();
107 }
108 }
109}
110
111fn acceptor_loop(
112 listener: TcpListener,
113 engine: ArcEngine,
114 pending: Arc<Vec<PendingSlot>>,
115 producers_installed: Arc<AtomicBool>,
116 handler_handles: Arc<crate::sync::Mutex<Vec<JoinHandle<()>>>>,
117 stop: ShutdownSignal,
118) {
119 while !stop.is_shutdown() {
120 match listener.accept() {
121 Ok((stream, addr)) => {
122 tracing::info!(%addr, "fixed follower connected");
123 stream.set_nodelay(true).ok();
124 stream.set_nonblocking(false).ok();
129 stream
137 .set_read_timeout(Some(Duration::from_secs(HEARTBEAT_INTERVAL_SECS * 2)))
138 .ok();
139
140 if !producers_installed.swap(true, Ordering::SeqCst) {
141 tracing::info!("first fixed follower — installing SPSC producers");
142 let mut producers = Vec::with_capacity(pending.len());
143 for slot in pending.iter() {
144 let mut guard = crate::sync::lock(slot);
145 producers.push(guard.0.take().expect("producer present on first install"));
146 }
147 engine.install_replication_producers(producers);
148 }
149
150 let engine = engine.clone();
151 let pending = pending.clone();
152 let stop = stop.clone();
153 let hh = handler_handles.clone();
154 let handle = thread::spawn(move || {
155 if let Err(e) = serve_connection(stream, engine, pending, stop) {
156 tracing::error!(error = %e, "fixed replication connection error");
157 }
158 });
159 let mut handles = crate::sync::lock(&hh);
160 handles.retain(|h| !h.is_finished());
161 handles.push(handle);
162 }
163 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
164 stop.wait_timeout(Duration::from_millis(50));
165 }
166 Err(e) => {
167 tracing::error!(error = %e, "fixed accept error");
168 stop.wait_timeout(Duration::from_millis(100));
169 }
170 }
171 }
172 tracing::info!("fixed replication acceptor stopped");
173}
174
175struct ShardConsumerGuard {
181 pending: Arc<Vec<PendingSlot>>,
182 shard_id: usize,
183 consumer: Option<Consumer<FixedReplicationEvent>>,
184}
185
186impl ShardConsumerGuard {
187 fn new(
188 pending: Arc<Vec<PendingSlot>>,
189 shard_id: usize,
190 consumer: Consumer<FixedReplicationEvent>,
191 ) -> Self {
192 Self {
193 pending,
194 shard_id,
195 consumer: Some(consumer),
196 }
197 }
198
199 fn take(&mut self) -> Consumer<FixedReplicationEvent> {
200 self.consumer.take().expect("consumer present")
201 }
202
203 fn put_back(&mut self, consumer: Consumer<FixedReplicationEvent>) {
204 self.consumer = Some(consumer);
205 }
206}
207
208impl Drop for ShardConsumerGuard {
209 fn drop(&mut self) {
210 if let Some(consumer) = self.consumer.take() {
211 let mut guard = crate::sync::lock(&self.pending[self.shard_id]);
212 if guard.1.is_none() {
213 guard.1 = Some(consumer);
214 }
215 }
216 }
217}
218
219fn serve_connection(
220 stream: TcpStream,
221 engine: ArcEngine,
222 pending: Arc<Vec<PendingSlot>>,
223 stop: ShutdownSignal,
224) -> DbResult<()> {
225 let mut reader = stream.try_clone().map_err(DbError::from)?;
226 let mut writer = BufWriter::new(stream);
227
228 let frame = read_frame(&mut reader).map_err(DbError::from)?;
230 if frame.msg_type != FixedMessageType::SyncRequest {
231 return Err(DbError::Replication(format!(
232 "expected SyncRequest, got {:?}",
233 frame.msg_type
234 )));
235 }
236 let req = SyncRequest::decode(&frame.payload).map_err(DbError::from)?;
237 if req.protocol_version != PROTOCOL_VERSION {
238 let msg = format!(
239 "protocol version mismatch: leader {PROTOCOL_VERSION}, follower {}",
240 req.protocol_version
241 );
242 write_frame(&mut writer, &encode_error(&msg)).map_err(DbError::from)?;
243 return Err(DbError::Replication(msg));
244 }
245 let shard_id = req.shard_id as usize;
246 if shard_id >= engine.shard_count() {
247 write_frame(&mut writer, &encode_error("invalid shard_id")).map_err(DbError::from)?;
248 return Err(DbError::Replication(format!("invalid shard_id {shard_id}")));
249 }
250
251 let info = ShardInfo {
253 shard_count: engine.shard_count() as u8,
254 key_len: engine.key_len() as u16,
255 value_len: engine.value_len() as u16,
256 slot_size: engine.slot_size(),
257 current_slot_count: engine.current_slot_count(shard_id),
258 shard_prefix_bits: engine.shard_prefix_bits(),
259 };
260 write_frame(&mut writer, &info.encode()).map_err(DbError::from)?;
261
262 let skip_deleted = (req.flags & FLAG_EMPTY_STATE) != 0;
263 tracing::info!(
264 shard_id,
265 skip_deleted,
266 protocol_version = req.protocol_version,
267 "fixed follower handshake complete"
268 );
269
270 let consumer = {
273 let mut guard = crate::sync::lock(&pending[shard_id]);
274 match guard.1.take() {
275 Some(c) => c,
276 None => {
277 let msg = "shard already streaming";
278 write_frame(&mut writer, &encode_error(msg)).map_err(DbError::from)?;
279 return Err(DbError::Replication(msg.to_string()));
280 }
281 }
282 };
283
284 let mut consumer_guard = ShardConsumerGuard::new(pending.clone(), shard_id, consumer);
287
288 let total = phase1_full_scan(
291 &engine,
292 shard_id,
293 &mut writer,
294 &mut reader,
295 skip_deleted,
296 &stop,
297 )?;
298 write_frame(
299 &mut writer,
300 &CaughtUp {
301 shard_id: shard_id as u8,
302 total_scanned: total,
303 }
304 .encode(),
305 )
306 .map_err(DbError::from)?;
307 tracing::info!(shard_id, total, "fixed catch-up complete");
308
309 let consumer = consumer_guard.take();
311 let consumer = phase2_streaming(&engine, shard_id, consumer, &mut writer, &mut reader, &stop)?;
312 consumer_guard.put_back(consumer);
313 Ok(())
314}
315
316fn phase1_full_scan(
317 engine: &ArcEngine,
318 shard_id: usize,
319 writer: &mut BufWriter<TcpStream>,
320 reader: &mut TcpStream,
321 skip_deleted: bool,
322 stop: &ShutdownSignal,
323) -> DbResult<u64> {
324 use crate::fixed::slot::{
325 SLOT_HEADER_SIZE, STATUS_DELETED, STATUS_FREE, STATUS_OCCUPIED, meta_of, pack_meta,
326 read_slot, status_of, version_of,
327 };
328
329 let slot_size = engine.slot_size() as usize;
330 let key_len = engine.key_len();
331 let value_len = engine.value_len();
332 let slot_count = engine.current_slot_count(shard_id);
333 let slots_per_chunk = (SCAN_CHUNK_BYTES / slot_size).max(1);
334
335 let mut total_scanned = 0u64;
336 let mut batch = SlotBatchEncoder::new(shard_id as u8, key_len, value_len);
337
338 let mut slot_id = 0u32;
339 while slot_id < slot_count {
340 if stop.is_shutdown() {
341 return Ok(total_scanned);
342 }
343 let remaining = slot_count - slot_id;
344 let this_chunk = remaining.min(slots_per_chunk as u32) as usize;
345 let chunk = engine.read_shard_chunk(shard_id, slot_id, this_chunk)?;
346
347 for i in 0..this_chunk {
348 let off = i * slot_size;
349 let slot_buf = &chunk[off..off + slot_size];
350 let meta = meta_of(slot_buf);
351 let status = status_of(meta);
352 let current_slot = slot_id + i as u32;
353
354 match status {
355 STATUS_OCCUPIED => {
356 if let Some((_meta, key, value)) = read_slot(slot_buf, key_len, value_len) {
357 batch.add_occupied(current_slot, meta, key, value);
358 total_scanned += 1;
359 } else {
360 let reset_meta = pack_meta(STATUS_FREE, version_of(meta));
361 batch.add_reset(current_slot, reset_meta);
362 total_scanned += 1;
363 }
364 }
365 STATUS_DELETED => {
366 if !skip_deleted {
367 let key = &slot_buf[SLOT_HEADER_SIZE..SLOT_HEADER_SIZE + key_len];
368 batch.add_deleted(current_slot, meta, key);
369 total_scanned += 1;
370 }
371 }
372 STATUS_FREE => {
373 if version_of(meta) != 0 {
374 let reset_meta = pack_meta(STATUS_FREE, version_of(meta));
375 batch.add_reset(current_slot, reset_meta);
376 total_scanned += 1;
377 }
378 }
379 other => {
380 return Err(DbError::Replication(format!(
381 "fixed Phase-1 scan found unknown slot status {other} at shard {shard_id} slot {current_slot}"
382 )));
383 }
384 }
385
386 if !batch.is_empty()
387 && (batch.len() as usize >= BATCH_MAX_ENTRIES || batch.bytes() >= BATCH_MAX_BYTES)
388 {
389 flush_and_wait_ack(writer, reader, batch, engine, shard_id)?;
390 batch = SlotBatchEncoder::new(shard_id as u8, key_len, value_len);
391 }
392 }
393
394 slot_id += this_chunk as u32;
395 }
396
397 if !batch.is_empty() {
398 flush_and_wait_ack(writer, reader, batch, engine, shard_id)?;
399 }
400
401 metrics::counter!(
402 "armdb.fixed.catchup_slots_scanned",
403 "shard" => shard_id.to_string()
404 )
405 .increment(total_scanned);
406
407 Ok(total_scanned)
408}
409
410fn flush_and_wait_ack(
411 writer: &mut BufWriter<TcpStream>,
412 reader: &mut TcpStream,
413 batch: SlotBatchEncoder,
414 engine: &ArcEngine,
415 shard_id: usize,
416) -> DbResult<()> {
417 const MAX_HEARTBEATS: u32 = 8;
423
424 let frame = batch.finish();
425 write_frame(writer, &frame).map_err(DbError::from)?;
426 writer.flush().map_err(DbError::from)?;
427 let mut heartbeats_skipped: u32 = 0;
431 loop {
432 let ack_frame = read_frame(reader).map_err(DbError::from)?;
433 match ack_frame.msg_type {
434 FixedMessageType::Ack => {
435 let ack = Ack::decode(&ack_frame.payload).map_err(DbError::from)?;
436 engine.update_min_replicated_version(shard_id, ack.max_version_seen);
437 return Ok(());
438 }
439 FixedMessageType::Heartbeat => {
440 heartbeats_skipped += 1;
442 if heartbeats_skipped > MAX_HEARTBEATS {
443 return Err(DbError::Replication(format!(
444 "too many consecutive heartbeats ({heartbeats_skipped}) \
445 while waiting for Phase-1 Ack"
446 )));
447 }
448 continue;
449 }
450 other => {
451 return Err(DbError::Replication(format!(
452 "expected Ack during Phase-1 catch-up, got {other:?}"
453 )));
454 }
455 }
456 }
457}
458
459fn phase2_streaming(
460 engine: &ArcEngine,
461 shard_id: usize,
462 mut consumer: Consumer<FixedReplicationEvent>,
463 writer: &mut BufWriter<TcpStream>,
464 reader: &mut TcpStream,
465 stop: &ShutdownSignal,
466) -> DbResult<Consumer<FixedReplicationEvent>> {
467 use crate::fixed::slot::{SLOT_HEADER_SIZE, meta_of};
468
469 let key_len = engine.key_len();
470 let value_len = engine.value_len();
471 let slot_size = engine.slot_size() as usize;
472 let mut last_heartbeat = Instant::now();
473 let hb_interval = Duration::from_secs(HEARTBEAT_INTERVAL_SECS);
474
475 reader.set_nonblocking(true).ok();
476
477 loop {
478 if stop.is_shutdown() {
479 return Ok(consumer);
480 }
481
482 let mut batch = SlotBatchEncoder::new(shard_id as u8, key_len, value_len);
483 while (batch.len() as usize) < BATCH_MAX_ENTRIES && batch.bytes() < BATCH_MAX_BYTES {
484 match consumer.pop() {
485 Ok(FixedReplicationEvent::Write { slot_id, payload }) => {
486 debug_assert_eq!(payload.len(), slot_size);
487 let meta = meta_of(&payload);
488 let key = &payload[SLOT_HEADER_SIZE..SLOT_HEADER_SIZE + key_len];
489 let value = &payload
490 [SLOT_HEADER_SIZE + key_len..SLOT_HEADER_SIZE + key_len + value_len];
491 batch.add_occupied(slot_id, meta, key, value);
492 }
493 Ok(FixedReplicationEvent::Delete { slot_id, meta, key }) => {
494 batch.add_deleted(slot_id, meta, &key);
495 }
496 Err(_) => break,
497 }
498 }
499
500 if !batch.is_empty() {
501 let frame_events = batch.len() as u64;
502 let frame = batch.finish();
503 if write_frame(writer, &frame)
504 .and_then(|_| writer.flush())
505 .is_err()
506 {
507 return Ok(consumer);
509 }
510 metrics::counter!(
511 "armdb.fixed.streaming_events_sent",
512 "shard" => shard_id.to_string()
513 )
514 .increment(frame_events);
515
516 match read_frame(reader) {
518 Ok(f) if f.msg_type == FixedMessageType::Ack => {
519 if let Ok(ack) = Ack::decode(&f.payload) {
520 engine.update_min_replicated_version(shard_id, ack.max_version_seen);
521 }
522 }
523 Ok(_) => {}
524 Err(ref e)
525 if e.kind() == std::io::ErrorKind::WouldBlock
526 || e.kind() == std::io::ErrorKind::TimedOut => {}
527 Err(e) => {
528 tracing::debug!(shard_id, error = %e, "fixed streaming: peer disconnected");
532 return Ok(consumer);
533 }
534 }
535 } else {
536 match read_frame(reader) {
541 Err(ref e)
542 if e.kind() == std::io::ErrorKind::WouldBlock
543 || e.kind() == std::io::ErrorKind::TimedOut => {}
544 Ok(f) if f.msg_type == FixedMessageType::Heartbeat => {
545 }
547 Ok(_) => {}
548 Err(e) => {
549 tracing::debug!(shard_id, error = %e, "fixed streaming idle: peer disconnected");
551 return Ok(consumer);
552 }
553 }
554 if last_heartbeat.elapsed() >= hb_interval {
555 if write_frame(writer, &encode_heartbeat())
556 .and_then(|_| writer.flush())
557 .is_err()
558 {
559 return Ok(consumer);
561 }
562 last_heartbeat = Instant::now();
563 }
564 thread::sleep(Duration::from_millis(TAIL_POLL_MS));
565 }
566 }
567}