Skip to main content

fast_cache/replication/
transport.rs

1//! TCP wire transport for native FCRP replication.
2//!
3//! The transport is intentionally minimal and synchronous: one accept thread
4//! per primary, one per-replica worker thread for streaming, and one connect
5//! thread per replica. Frames on the wire are encoded by [`encode_frame`] in
6//! [`super::protocol`], so capture/replay tools that already understand FCRP
7//! frames work unchanged.
8
9use std::io::{self, Read, Write};
10use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::thread::{self, JoinHandle};
14use std::time::Duration;
15
16use crossbeam_channel::{RecvTimeoutError, TryRecvError};
17use parking_lot::Mutex;
18
19use crate::config::ReplicationConfig;
20use crate::storage::StoredEntry;
21use crate::{FastCacheError, Result};
22
23use super::ReplicationFrameBytes;
24use super::backlog::BacklogCatchUp;
25use super::batcher::ReplicationPrimary;
26use super::embedded::{ReplicatedEmbeddedStore, ReplicationReplica};
27use super::protocol::{
28    FCRP_VERSION, FrameKind, HelloRole, ReplicationCompressionMode, ReplicationHello,
29    ReplicationSnapshotChunk, ShardWatermarks, decode_ack, decode_error, decode_frame,
30    decode_frame_payload_bytes, decode_hello, decode_snapshot_chunk, encode_ack, encode_error,
31    encode_frame, encode_hello, encode_snapshot_chunk,
32};
33
34#[cfg(all(target_os = "linux", feature = "monoio"))]
35mod monoio_transport;
36
/// Size in bytes of the fixed FCRP frame header read before each payload.
/// Bytes 8..12 of the header carry the payload length as a little-endian u32.
const FRAME_HEADER_LEN: usize = 16;
/// Largest payload length (256 MiB) accepted from a peer; larger frames are
/// rejected before any payload buffer is allocated.
const MAX_FRAME_BYTES: usize = 256 * 1024 * 1024;
39
/// Provides consistent snapshots to the replication transport.
///
/// The primary server asks for a snapshot when a replica cannot be caught up
/// from the backlog and must be bootstrapped from a full copy. Providers are
/// shared across worker threads, hence `Send + Sync + 'static`.
pub trait SnapshotProvider: Send + Sync + 'static {
    /// Returns a consistent snapshot together with the watermarks captured at
    /// the same logical point.
    fn snapshot(&self) -> super::protocol::ReplicationSnapshot;
}
46
/// Lets a [`ReplicatedEmbeddedStore`] act as the snapshot source for a
/// [`ReplicationPrimaryServer`].
impl SnapshotProvider for ReplicatedEmbeddedStore {
    fn snapshot(&self) -> super::protocol::ReplicationSnapshot {
        // Path-qualified call so it targets the store's own `snapshot`
        // (presumably an inherent method) rather than this trait method.
        ReplicatedEmbeddedStore::snapshot(self)
    }
}
52
/// Handle to a primary's TCP listener thread.
///
/// Dropping the handle calls [`ReplicationPrimaryServer::shutdown`].
#[derive(Debug)]
pub struct ReplicationPrimaryServer {
    // Set to `true` to ask the listener thread (and its workers) to exit.
    stop: Arc<AtomicBool>,
    // Listener thread handle; taken (left `None`) by the first `shutdown`.
    join: Mutex<Option<JoinHandle<()>>>,
}
59
60impl ReplicationPrimaryServer {
61    /// Binds a TCP listener and serves replicas using `primary` and `snapshots`.
62    pub fn start(
63        config: ReplicationConfig,
64        primary: Arc<ReplicationPrimary>,
65        snapshots: Arc<dyn SnapshotProvider>,
66    ) -> Result<Self> {
67        if !config.enabled {
68            return Err(FastCacheError::Config(
69                "replication primary server requires replication.enabled = true".into(),
70            ));
71        }
72        #[cfg(all(target_os = "linux", feature = "monoio"))]
73        if monoio_transport::should_use() {
74            return monoio_transport::start_primary(config, primary, snapshots);
75        }
76
77        let listener = TcpListener::bind(&config.bind_addr).map_err(|error| {
78            FastCacheError::Config(format!(
79                "replication primary failed to bind {}: {error}",
80                config.bind_addr
81            ))
82        })?;
83        listener.set_nonblocking(true).map_err(|error| {
84            FastCacheError::Config(format!(
85                "replication primary set_nonblocking failed: {error}"
86            ))
87        })?;
88        let stop = Arc::new(AtomicBool::new(false));
89        let stop_clone = Arc::clone(&stop);
90        let cfg = config;
91        let join = thread::Builder::new()
92            .name("fast-cache-replication-listener".into())
93            .spawn(move || run_listener(listener, cfg, primary, snapshots, stop_clone))
94            .map_err(|error| {
95                FastCacheError::Config(format!("failed to start replication listener: {error}"))
96            })?;
97        Ok(Self {
98            stop,
99            join: Mutex::new(Some(join)),
100        })
101    }
102
103    #[cfg(all(target_os = "linux", feature = "monoio"))]
104    fn from_join(stop: Arc<AtomicBool>, join: JoinHandle<()>) -> Self {
105        Self {
106            stop,
107            join: Mutex::new(Some(join)),
108        }
109    }
110
111    pub fn shutdown(&self) -> Result<()> {
112        self.stop.store(true, Ordering::SeqCst);
113        if let Some(join) = self.join.lock().take() {
114            join.join()
115                .map_err(|_| FastCacheError::TaskJoin("replication listener panicked".into()))?;
116        }
117        Ok(())
118    }
119}
120
121impl Drop for ReplicationPrimaryServer {
122    fn drop(&mut self) {
123        let _ = self.shutdown();
124    }
125}
126
127fn run_listener(
128    listener: TcpListener,
129    config: ReplicationConfig,
130    primary: Arc<ReplicationPrimary>,
131    snapshots: Arc<dyn SnapshotProvider>,
132    stop: Arc<AtomicBool>,
133) {
134    let active = Arc::new(parking_lot::Mutex::new(Vec::<JoinHandle<()>>::new()));
135    while !stop.load(Ordering::SeqCst) {
136        match listener.accept() {
137            Ok((stream, peer)) => {
138                // Cap simultaneous replicas. Drain finished workers first.
139                let mut handles = active.lock();
140                handles.retain(|h| !h.is_finished());
141                if handles.len() >= config.max_replicas {
142                    drop(handles);
143                    tracing::warn!(
144                        "rejecting replication client {peer}: max_replicas {} reached",
145                        config.max_replicas
146                    );
147                    let _ = stream.shutdown(std::net::Shutdown::Both);
148                    continue;
149                }
150                let cfg = config.clone();
151                let primary = Arc::clone(&primary);
152                let snapshots = Arc::clone(&snapshots);
153                let stop = Arc::clone(&stop);
154                let handle = thread::Builder::new()
155                    .name(format!("fast-cache-replication-worker-{peer}"))
156                    .spawn(move || {
157                        if let Err(error) =
158                            serve_replica(stream, peer, cfg, primary, snapshots, stop)
159                        {
160                            tracing::warn!("replication worker for {peer} terminated: {error}");
161                        }
162                    });
163                match handle {
164                    Ok(h) => handles.push(h),
165                    Err(error) => tracing::warn!("failed to spawn replication worker: {error}"),
166                }
167            }
168            Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
169                thread::sleep(Duration::from_millis(10));
170            }
171            Err(error) => {
172                tracing::warn!("replication listener accept failed: {error}");
173                thread::sleep(Duration::from_millis(50));
174            }
175        }
176    }
177    // Best-effort join of workers; they observe the stop flag and exit.
178    let mut handles = active.lock();
179    for h in handles.drain(..) {
180        let _ = h.join();
181    }
182}
183
184fn serve_replica(
185    mut stream: TcpStream,
186    peer: SocketAddr,
187    config: ReplicationConfig,
188    primary: Arc<ReplicationPrimary>,
189    snapshots: Arc<dyn SnapshotProvider>,
190    stop: Arc<AtomicBool>,
191) -> Result<()> {
192    stream.set_nodelay(true).ok();
193    stream
194        .set_read_timeout(Some(Duration::from_millis(
195            config.connect_timeout_ms.max(1),
196        )))
197        .ok();
198    stream
199        .set_write_timeout(Some(Duration::from_millis(config.write_timeout_ms.max(1))))
200        .ok();
201
202    let hello_frame = match read_frame_bytes_interruptible(&mut stream, &stop)? {
203        Some(bytes) => bytes,
204        None => return Ok(()),
205    };
206    let frame = decode_frame(&hello_frame)?;
207    if frame.kind != FrameKind::Hello {
208        send_error(&mut stream, "expected Hello frame")?;
209        return Err(FastCacheError::Protocol(format!(
210            "replica {peer} sent {:?} before Hello",
211            frame.kind
212        )));
213    }
214    let hello = decode_hello(&frame.payload)?;
215    if hello.version != FCRP_VERSION {
216        send_error(&mut stream, "unsupported FCRP version")?;
217        return Err(FastCacheError::Protocol(format!(
218            "replica {peer} requested FCRP version {}",
219            hello.version
220        )));
221    }
222    if !auth_ok(config.auth_token.as_deref(), hello.auth_token.as_deref()) {
223        send_error(&mut stream, "invalid auth token")?;
224        return Err(FastCacheError::Protocol(format!(
225            "replica {peer} sent invalid auth token"
226        )));
227    }
228    // After authentication clear the read timeout so the worker can wait
229    // indefinitely on the (silent) reverse channel without spuriously dying.
230    stream.set_read_timeout(None).ok();
231    let ack = ReplicationHello {
232        version: FCRP_VERSION,
233        role: HelloRole::Replica,
234        auth_token: None,
235        since: Some(primary.current_watermarks()),
236    };
237    write_full_frame(
238        &mut stream,
239        FrameKind::Hello,
240        ReplicationCompressionMode::None,
241        0,
242        &encode_hello(&ack),
243    )?;
244
245    // Subscribe BEFORE deciding whether to snapshot, so we don't drop frames
246    // that flush during the snapshot window.
247    let subscription = primary.subscribe(config.subscriber_channel_capacity);
248
249    let since = hello
250        .since
251        .clone()
252        .unwrap_or_else(|| ShardWatermarks::new(primary.shard_count()));
253    let live_start = match primary.catch_up_since(&since)? {
254        BacklogCatchUp::Available(frames) => {
255            for frame in &frames {
256                write_raw(&mut stream, frame.as_ref())?;
257            }
258            primary.current_watermarks()
259        }
260        BacklogCatchUp::NeedsSnapshot => {
261            let snapshot = snapshots.snapshot();
262            stream_snapshot(&mut stream, &snapshot, &config)?;
263            snapshot.watermarks
264        }
265    };
266
267    // Drain any frames that were broadcast while we were sending the
268    // snapshot, then enter the steady-state forwarding loop.
269    drain_buffered(&mut stream, &subscription, &live_start, &primary)?;
270
271    while !stop.load(Ordering::SeqCst) {
272        match subscription.recv_timeout(Duration::from_millis(100)) {
273            Ok(frame) => write_raw(&mut stream, frame.as_ref())?,
274            Err(RecvTimeoutError::Timeout) => {}
275            Err(RecvTimeoutError::Disconnected) => break,
276        }
277    }
278    Ok(())
279}
280
281fn drain_buffered(
282    stream: &mut TcpStream,
283    subscription: &crossbeam_channel::Receiver<ReplicationFrameBytes>,
284    bootstrap_high: &ShardWatermarks,
285    primary: &Arc<ReplicationPrimary>,
286) -> Result<()> {
287    loop {
288        match subscription.try_recv() {
289            Ok(frame) => write_raw(stream, frame.as_ref())?,
290            Err(TryRecvError::Empty) => break,
291            Err(TryRecvError::Disconnected) => break,
292        }
293    }
294    // The subscriber channel may have started filling AFTER we read the
295    // bootstrap watermarks; the inflight backlog covers that gap.
296    if let BacklogCatchUp::Available(frames) = primary.catch_up_since(bootstrap_high)? {
297        for frame in frames {
298            // De-duplication happens on the replica side via per-shard
299            // watermark comparison, so re-sending these frames is safe.
300            write_raw(stream, frame.as_ref())?;
301        }
302    }
303    Ok(())
304}
305
306fn stream_snapshot(
307    stream: &mut TcpStream,
308    snapshot: &super::protocol::ReplicationSnapshot,
309    config: &ReplicationConfig,
310) -> Result<()> {
311    write_full_frame(
312        stream,
313        FrameKind::SnapshotBegin,
314        ReplicationCompressionMode::None,
315        0,
316        &encode_ack(&snapshot.watermarks),
317    )?;
318
319    let target = config.snapshot_chunk_bytes.max(4 * 1024);
320    let mut chunk_index = 0u64;
321    let mut buffer: Vec<crate::storage::StoredEntry> = Vec::new();
322    let mut buffer_bytes = 0usize;
323    let total = snapshot.entries.len();
324    let compression = ReplicationCompressionMode::from(config.compression);
325
326    for (idx, entry) in snapshot.entries.iter().enumerate() {
327        let entry_bytes = entry.key.len() + entry.value.len() + 32;
328        buffer.push(entry.clone());
329        buffer_bytes = buffer_bytes.saturating_add(entry_bytes);
330        let is_last_entry = idx + 1 == total;
331        if buffer_bytes >= target || is_last_entry {
332            let chunk = ReplicationSnapshotChunk {
333                watermarks: snapshot.watermarks.clone(),
334                chunk_index,
335                is_last: is_last_entry,
336                entries: std::mem::take(&mut buffer),
337            };
338            buffer_bytes = 0;
339            chunk_index += 1;
340            let payload = encode_snapshot_chunk(&chunk);
341            write_full_frame(
342                stream,
343                FrameKind::SnapshotChunk,
344                compression,
345                config.zstd_level,
346                &payload,
347            )?;
348        }
349    }
350    if total == 0 {
351        let chunk = ReplicationSnapshotChunk {
352            watermarks: snapshot.watermarks.clone(),
353            chunk_index: 0,
354            is_last: true,
355            entries: Vec::new(),
356        };
357        let payload = encode_snapshot_chunk(&chunk);
358        write_full_frame(
359            stream,
360            FrameKind::SnapshotChunk,
361            ReplicationCompressionMode::None,
362            0,
363            &payload,
364        )?;
365    }
366    write_full_frame(
367        stream,
368        FrameKind::SnapshotEnd,
369        ReplicationCompressionMode::None,
370        0,
371        &encode_ack(&snapshot.watermarks),
372    )?;
373    Ok(())
374}
375
376fn send_error(stream: &mut TcpStream, message: &str) -> Result<()> {
377    write_full_frame(
378        stream,
379        FrameKind::Error,
380        ReplicationCompressionMode::None,
381        0,
382        &encode_error(message),
383    )
384}
385
/// Checks a replica's presented auth token against the configured one.
///
/// With no token configured, every replica is accepted. Otherwise the
/// presented token must match exactly, and a missing token is rejected. The
/// byte comparison does not stop at the first mismatch, so response timing
/// does not reveal how much of a guessed token was correct. This is
/// best-effort: token *length* can still be inferred, and that is not
/// treated as secret.
fn auth_ok(expected: Option<&str>, presented: Option<&str>) -> bool {
    match (expected, presented) {
        (None, _) => true,
        (Some(want), Some(got)) => constant_time_eq(want.as_bytes(), got.as_bytes()),
        (Some(_), None) => false,
    }
}

/// Compares two byte strings, folding every byte pair instead of returning
/// at the first difference.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    if left.len() != right.len() {
        return false;
    }
    let diff = left
        .iter()
        .zip(right)
        .fold(0_u8, |acc, (a, b)| acc | (a ^ b));
    // Best-effort barrier so the optimizer does not reintroduce an early exit.
    std::hint::black_box(diff) == 0
}
393
394fn write_full_frame(
395    stream: &mut TcpStream,
396    kind: FrameKind,
397    compression: ReplicationCompressionMode,
398    zstd_level: i32,
399    payload: &[u8],
400) -> Result<()> {
401    let frame = encode_frame(kind, compression, zstd_level, payload)?;
402    write_raw(stream, &frame)
403}
404
405fn write_raw(stream: &mut TcpStream, bytes: &[u8]) -> Result<()> {
406    stream.write_all(bytes).map_err(FastCacheError::Io)
407}
408
409fn read_frame_bytes(stream: &mut TcpStream) -> Result<Vec<u8>> {
410    read_frame_inner(stream, None).and_then(|opt| {
411        opt.ok_or_else(|| {
412            FastCacheError::Io(io::Error::new(
413                io::ErrorKind::UnexpectedEof,
414                "FCRP stream closed before frame completed",
415            ))
416        })
417    })
418}
419
420fn read_frame_bytes_interruptible(
421    stream: &mut TcpStream,
422    stop: &Arc<AtomicBool>,
423) -> Result<Option<Vec<u8>>> {
424    read_frame_inner(stream, Some(stop))
425}
426
427fn read_frame_inner(
428    stream: &mut TcpStream,
429    stop: Option<&Arc<AtomicBool>>,
430) -> Result<Option<Vec<u8>>> {
431    let mut header = [0_u8; FRAME_HEADER_LEN];
432    match read_fully(stream, &mut header, stop)? {
433        ReadResult::Done => {}
434        ReadResult::Stopped => return Ok(None),
435        ReadResult::Eof => {
436            return Err(FastCacheError::Io(io::Error::new(
437                io::ErrorKind::UnexpectedEof,
438                "FCRP stream closed mid-header",
439            )));
440        }
441    }
442    let payload_len = u32::from_le_bytes(header[8..12].try_into().unwrap()) as usize;
443    if payload_len > MAX_FRAME_BYTES {
444        return Err(FastCacheError::Protocol(format!(
445            "FCRP frame payload exceeds limit ({payload_len} bytes)"
446        )));
447    }
448    let mut frame = Vec::with_capacity(FRAME_HEADER_LEN + payload_len);
449    frame.extend_from_slice(&header);
450    frame.resize(FRAME_HEADER_LEN + payload_len, 0);
451    match read_fully(stream, &mut frame[FRAME_HEADER_LEN..], stop)? {
452        ReadResult::Done => Ok(Some(frame)),
453        ReadResult::Stopped => Ok(None),
454        ReadResult::Eof => Err(FastCacheError::Io(io::Error::new(
455            io::ErrorKind::UnexpectedEof,
456            "FCRP stream closed mid-payload",
457        ))),
458    }
459}
460
/// Outcome of [`read_fully`].
enum ReadResult {
    /// The buffer was filled completely.
    Done,
    /// A read timed out and the caller-supplied stop flag was set.
    Stopped,
    /// The peer closed the connection before the buffer was filled.
    Eof,
}
466
467fn read_fully(
468    stream: &mut TcpStream,
469    buffer: &mut [u8],
470    stop: Option<&Arc<AtomicBool>>,
471) -> Result<ReadResult> {
472    let mut filled = 0;
473    while filled < buffer.len() {
474        match stream.read(&mut buffer[filled..]) {
475            Ok(0) => return Ok(ReadResult::Eof),
476            Ok(n) => filled += n,
477            Err(error) if is_timeout(&error) => match stop {
478                Some(stop) => {
479                    if stop.load(Ordering::SeqCst) {
480                        return Ok(ReadResult::Stopped);
481                    }
482                    continue;
483                }
484                None => return Err(FastCacheError::Io(error)),
485            },
486            Err(error) => return Err(FastCacheError::Io(error)),
487        }
488    }
489    Ok(ReadResult::Done)
490}
491
/// Whether `error` is a socket read/write timeout. Depending on the platform,
/// timeouts surface as either `WouldBlock` or `TimedOut`.
fn is_timeout(error: &io::Error) -> bool {
    let kind = error.kind();
    kind == io::ErrorKind::WouldBlock || kind == io::ErrorKind::TimedOut
}
498
/// Handle to a replica's TCP connector thread.
///
/// Dropping the handle calls [`ReplicationReplicaClient::shutdown`].
#[derive(Debug)]
pub struct ReplicationReplicaClient {
    // Set to `true` to ask the connector thread to exit.
    stop: Arc<AtomicBool>,
    // Connector thread handle; taken (left `None`) by the first `shutdown`.
    join: Mutex<Option<JoinHandle<()>>>,
    // Replica state shared with the connector thread, which applies
    // snapshots and mutation batches to it.
    state: Arc<Mutex<ReplicationReplica>>,
}
506
507impl ReplicationReplicaClient {
508    /// Starts a replica that connects to `config.replica_of`, bootstraps via
509    /// snapshot or backlog, and streams live mutations.
510    pub fn start(config: ReplicationConfig) -> Result<Self> {
511        if !config.enabled {
512            return Err(FastCacheError::Config(
513                "replication replica requires replication.enabled = true".into(),
514            ));
515        }
516        let upstream = config.replica_of.clone().ok_or_else(|| {
517            FastCacheError::Config("replication.replica_of is required for replica role".into())
518        })?;
519        #[cfg(all(target_os = "linux", feature = "monoio"))]
520        if monoio_transport::should_use() {
521            return monoio_transport::start_replica(upstream, config);
522        }
523
524        let stop = Arc::new(AtomicBool::new(false));
525        let state = Arc::new(Mutex::new(ReplicationReplica::new(1)));
526        let cfg = config;
527        let stop_clone = Arc::clone(&stop);
528        let state_clone = Arc::clone(&state);
529        let join = thread::Builder::new()
530            .name("fast-cache-replication-replica".into())
531            .spawn(move || run_replica_client(upstream, cfg, state_clone, stop_clone))
532            .map_err(|error| {
533                FastCacheError::Config(format!("failed to start replica client: {error}"))
534            })?;
535        Ok(Self {
536            stop,
537            join: Mutex::new(Some(join)),
538            state,
539        })
540    }
541
542    #[cfg(all(target_os = "linux", feature = "monoio"))]
543    fn from_join(
544        stop: Arc<AtomicBool>,
545        join: JoinHandle<()>,
546        state: Arc<Mutex<ReplicationReplica>>,
547    ) -> Self {
548        Self {
549            stop,
550            join: Mutex::new(Some(join)),
551            state,
552        }
553    }
554
555    /// Returns the live replica handle. Holds a mutex while in use, so prefer
556    /// short, read-only operations.
557    pub fn replica(&self) -> Arc<Mutex<ReplicationReplica>> {
558        Arc::clone(&self.state)
559    }
560
561    pub fn shutdown(&self) -> Result<()> {
562        self.stop.store(true, Ordering::SeqCst);
563        if let Some(join) = self.join.lock().take() {
564            join.join()
565                .map_err(|_| FastCacheError::TaskJoin("replication replica panicked".into()))?;
566        }
567        Ok(())
568    }
569}
570
571impl Drop for ReplicationReplicaClient {
572    fn drop(&mut self) {
573        let _ = self.shutdown();
574    }
575}
576
577fn run_replica_client(
578    upstream: String,
579    config: ReplicationConfig,
580    state: Arc<Mutex<ReplicationReplica>>,
581    stop: Arc<AtomicBool>,
582) {
583    while !stop.load(Ordering::SeqCst) {
584        match connect_and_stream(&upstream, &config, &state, &stop) {
585            Ok(()) => {}
586            Err(error) => {
587                tracing::warn!("replication replica disconnected: {error}");
588            }
589        }
590        if stop.load(Ordering::SeqCst) {
591            break;
592        }
593        let backoff = Duration::from_millis(config.reconnect_backoff_ms.max(1));
594        let step = Duration::from_millis(25);
595        let mut slept = Duration::ZERO;
596        while slept < backoff && !stop.load(Ordering::SeqCst) {
597            let chunk = step.min(backoff.saturating_sub(slept));
598            thread::sleep(chunk);
599            slept = slept.saturating_add(chunk);
600        }
601    }
602}
603
604fn connect_and_stream(
605    upstream: &str,
606    config: &ReplicationConfig,
607    state: &Arc<Mutex<ReplicationReplica>>,
608    stop: &Arc<AtomicBool>,
609) -> Result<()> {
610    let addr = upstream
611        .to_socket_addrs()
612        .map_err(|error| {
613            FastCacheError::Config(format!("replica address {upstream} unresolvable: {error}"))
614        })?
615        .next()
616        .ok_or_else(|| {
617            FastCacheError::Config(format!("replica address {upstream} had no entries"))
618        })?;
619    let mut stream = TcpStream::connect_timeout(
620        &addr,
621        Duration::from_millis(config.connect_timeout_ms.max(1)),
622    )?;
623    stream.set_nodelay(true).ok();
624    stream
625        .set_read_timeout(Some(Duration::from_millis(
626            config.connect_timeout_ms.max(1),
627        )))
628        .ok();
629    stream
630        .set_write_timeout(Some(Duration::from_millis(config.write_timeout_ms.max(1))))
631        .ok();
632
633    let since = state.lock().watermarks().clone();
634    let hello = ReplicationHello {
635        version: FCRP_VERSION,
636        role: HelloRole::Replica,
637        auth_token: config.auth_token.clone(),
638        since: Some(since),
639    };
640    write_full_frame(
641        &mut stream,
642        FrameKind::Hello,
643        ReplicationCompressionMode::None,
644        0,
645        &encode_hello(&hello),
646    )?;
647
648    // Read Hello-ack.
649    let ack_bytes = read_frame_bytes(&mut stream)?;
650    let ack = decode_frame(&ack_bytes)?;
651    match ack.kind {
652        FrameKind::Hello => {}
653        FrameKind::Error => {
654            let message = decode_error(&ack.payload).unwrap_or_else(|_| "unknown".to_string());
655            return Err(FastCacheError::Protocol(format!(
656                "primary rejected handshake: {message}"
657            )));
658        }
659        other => {
660            return Err(FastCacheError::Protocol(format!(
661                "expected Hello ack, got {other:?}"
662            )));
663        }
664    }
665
666    // Use a short read timeout so the loop polls the stop flag without
667    // blocking indefinitely.
668    stream
669        .set_read_timeout(Some(Duration::from_millis(200)))
670        .ok();
671    let mut pending_snapshot: Option<PendingSnapshot> = None;
672    while !stop.load(Ordering::SeqCst) {
673        let bytes = match read_frame_bytes_interruptible(&mut stream, stop) {
674            Ok(Some(bytes)) => bytes,
675            Ok(None) => return Ok(()),
676            Err(FastCacheError::Io(error))
677                if error.kind() == io::ErrorKind::UnexpectedEof
678                    || error.kind() == io::ErrorKind::ConnectionReset =>
679            {
680                return Ok(());
681            }
682            Err(error) => return Err(error),
683        };
684        let frame = decode_frame_payload_bytes(bytes::Bytes::from(bytes))?;
685        match frame.kind {
686            FrameKind::MutationBatch => {
687                let mut replica = state.lock();
688                replica.apply_frame_bytes_payload(frame)?;
689            }
690            FrameKind::SnapshotBegin => {
691                let watermarks = decode_ack(frame.payload.as_ref())?;
692                pending_snapshot = Some(PendingSnapshot {
693                    watermarks,
694                    entries: Vec::new(),
695                });
696            }
697            FrameKind::SnapshotChunk => {
698                let chunk = decode_snapshot_chunk(frame.payload.as_ref())?;
699                let Some(slot) = pending_snapshot.as_mut() else {
700                    return Err(FastCacheError::Protocol(
701                        "SnapshotChunk arrived without SnapshotBegin".into(),
702                    ));
703                };
704                slot.entries.extend(chunk.entries);
705            }
706            FrameKind::SnapshotEnd => {
707                let Some(snapshot) = pending_snapshot.take() else {
708                    return Err(FastCacheError::Protocol(
709                        "SnapshotEnd arrived without SnapshotBegin".into(),
710                    ));
711                };
712                let mut replica = state.lock();
713                replica.replace_with_snapshot(super::protocol::ReplicationSnapshot {
714                    entries: snapshot.entries,
715                    watermarks: snapshot.watermarks,
716                });
717            }
718            FrameKind::Hello => {
719                // Ignore unexpected mid-stream Hello frames.
720            }
721            FrameKind::Ack => {
722                // Replica doesn't act on Ack frames today.
723            }
724            FrameKind::Error => {
725                let message =
726                    decode_error(frame.payload.as_ref()).unwrap_or_else(|_| "unknown".to_string());
727                return Err(FastCacheError::Protocol(format!(
728                    "primary error frame: {message}"
729                )));
730            }
731        }
732    }
733    Ok(())
734}
735
/// Snapshot being assembled on the replica between `SnapshotBegin` and
/// `SnapshotEnd`.
struct PendingSnapshot {
    // Watermarks decoded from the `SnapshotBegin` frame.
    watermarks: ShardWatermarks,
    // Entries accumulated from every `SnapshotChunk` received so far.
    entries: Vec<StoredEntry>,
}
740
#[cfg(test)]
mod tests {
    use std::net::TcpListener;
    use std::time::Duration;

    use crate::config::{
        ReplicationCompression, ReplicationConfig, ReplicationRole, ReplicationSendPolicy,
    };

    use super::*;

    /// Returns a loopback address whose port was free a moment ago. The probe
    /// listener is dropped before returning, so the port could in principle be
    /// taken by someone else before the test binds it.
    fn ephemeral_addr() -> String {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind ephemeral");
        let addr = listener.local_addr().expect("local_addr");
        drop(listener);
        addr.to_string()
    }

    /// Primary config: immediate single-record batches, no compression, and
    /// the minimum snapshot chunk size.
    fn primary_config(addr: &str, auth_token: Option<&str>) -> ReplicationConfig {
        ReplicationConfig {
            enabled: true,
            role: ReplicationRole::Primary,
            bind_addr: addr.to_string(),
            replica_of: None,
            auth_token: auth_token.map(str::to_string),
            compression: ReplicationCompression::None,
            send_policy: ReplicationSendPolicy::Immediate,
            batch_max_records: 1,
            batch_max_delay_us: 1_000,
            snapshot_chunk_bytes: 4 * 1024,
            ..ReplicationConfig::default()
        }
    }

    /// Replica config pointing at `upstream`, uncompressed.
    fn replica_config(upstream: &str, auth_token: Option<&str>) -> ReplicationConfig {
        ReplicationConfig {
            enabled: true,
            role: ReplicationRole::Replica,
            bind_addr: String::new(),
            replica_of: Some(upstream.to_string()),
            auth_token: auth_token.map(str::to_string),
            compression: ReplicationCompression::None,
            ..ReplicationConfig::default()
        }
    }

    /// Polls the replica every 10 ms until `key` has a value or `deadline`
    /// elapses.
    fn await_value(
        client: &ReplicationReplicaClient,
        key: &[u8],
        deadline: Duration,
    ) -> Option<Vec<u8>> {
        let start = std::time::Instant::now();
        while start.elapsed() < deadline {
            if let Some(value) = client.replica().lock().get(key) {
                return Some(value);
            }
            thread::sleep(Duration::from_millis(10));
        }
        None
    }

    #[test]
    fn live_streaming_round_trip() {
        let addr = ephemeral_addr();
        let primary = Arc::new(
            ReplicatedEmbeddedStore::new(2, primary_config(&addr, None)).expect("primary"),
        );
        let server = ReplicationPrimaryServer::start(
            primary_config(&addr, None),
            primary.primary(),
            Arc::clone(&primary) as Arc<dyn SnapshotProvider>,
        )
        .expect("server");
        let client = ReplicationReplicaClient::start(replica_config(&addr, None)).expect("replica");

        primary.set(b"alpha".to_vec(), b"one".to_vec(), None);
        primary.set(b"beta".to_vec(), b"two".to_vec(), None);
        assert_eq!(
            await_value(&client, b"alpha", Duration::from_secs(3)),
            Some(b"one".to_vec())
        );
        assert_eq!(
            await_value(&client, b"beta", Duration::from_secs(3)),
            Some(b"two".to_vec())
        );
        client.shutdown().ok();
        server.shutdown().ok();
    }

    #[test]
    fn snapshot_bootstrap_when_backlog_empty() {
        let addr = ephemeral_addr();
        let mut tight_cfg = primary_config(&addr, None);
        tight_cfg.backlog_bytes = 1; // force snapshot path
        let primary =
            Arc::new(ReplicatedEmbeddedStore::new(2, tight_cfg.clone()).expect("primary"));
        // Populate before the replica connects.
        for i in 0..32 {
            primary.set(format!("key-{i}").into_bytes(), b"v".to_vec(), None);
        }
        let server = ReplicationPrimaryServer::start(
            tight_cfg,
            primary.primary(),
            Arc::clone(&primary) as Arc<dyn SnapshotProvider>,
        )
        .expect("server");
        let client = ReplicationReplicaClient::start(replica_config(&addr, None)).expect("replica");

        for i in 0..32 {
            let key = format!("key-{i}").into_bytes();
            assert_eq!(
                await_value(&client, &key, Duration::from_secs(5)),
                Some(b"v".to_vec()),
                "missing {i}"
            );
        }
        client.shutdown().ok();
        server.shutdown().ok();
    }

    #[test]
    fn auth_token_required_when_configured() {
        let addr = ephemeral_addr();
        let primary = Arc::new(
            ReplicatedEmbeddedStore::new(2, primary_config(&addr, Some("secret")))
                .expect("primary"),
        );
        let server = ReplicationPrimaryServer::start(
            primary_config(&addr, Some("secret")),
            primary.primary(),
            Arc::clone(&primary) as Arc<dyn SnapshotProvider>,
        )
        .expect("server");
        // Wrong token — replica connect_and_stream will return Err and the
        // client retries; we just confirm no data leaks.
        let client = ReplicationReplicaClient::start(replica_config(&addr, Some("wrong")))
            .expect("client-start");
        primary.set(b"alpha".to_vec(), b"one".to_vec(), None);
        thread::sleep(Duration::from_millis(200));
        assert!(client.replica().lock().get(b"alpha").is_none());
        client.shutdown().ok();
        server.shutdown().ok();
    }
}