Skip to main content

fast_cache/replication/
batcher.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
3use std::thread;
4use std::time::{Duration, Instant};
5
6use crossbeam_channel::{Receiver, Select, Sender, TrySendError, bounded};
7use parking_lot::Mutex;
8
9use crate::config::{ReplicationConfig, ReplicationSendPolicy};
10use crate::{FastCacheError, Result};
11
12use super::ReplicationFrameBytes;
13use super::backlog::{BacklogCatchUp, ReplicationBacklog};
14use super::metrics::{ReplicationMetrics, ReplicationMetricsSnapshot};
15use super::protocol::{
16    BorrowedReplicationMutation, FRAME_HEADER_LEN, FrameKind, ReplicationCompressionMode,
17    ReplicationFrame, ReplicationMutation, ShardWatermarks, borrowed_mutation_record_payload_len,
18    decode_frame, encode_frame, encode_mutation_batch_frame_with_payload_len,
19    write_borrowed_mutation_payload_record, write_uncompressed_frame_header_at,
20};
21
/// Sending half of a bounded frame channel: used for each per-shard
/// subscriber queue and for a subscriber's merged fan-in output.
type SubscriberTx = Sender<ReplicationFrameBytes>;
/// Shared handle to one subscriber's registration on a shard.
type SubscriberRef = Arc<ReplicationShardSubscriber>;
/// Lock-protected list of the subscribers registered on one shard.
type SubscriberList = Mutex<Vec<SubscriberRef>>;
25
/// Primary-side replication exporter: encodes shard mutation batches into
/// frames, retains them in per-shard backlogs, and publishes them to
/// subscribers.
#[derive(Debug)]
pub struct ReplicationPrimary {
    /// Replication settings (compression, batching limits, backlog size).
    config: ReplicationConfig,
    /// Number of shards; `start` guarantees this is at least 1.
    shard_count: usize,
    /// Per-shard counters handed out by `next_sequence`.
    sequences: Vec<AtomicU64>,
    /// Counters for emitted batches, drops, and backpressure.
    metrics: ReplicationMetrics,
    /// One export lane (backlog, subscribers, watermark) per shard.
    shards: Vec<ReplicationShardExport>,
}
34
/// Export state owned by a single shard.
#[derive(Debug)]
struct ReplicationShardExport {
    /// Recently published frames, consulted by `catch_up_since`.
    backlog: Mutex<ReplicationBacklog>,
    /// Subscribers that currently receive this shard's frames.
    subscribers: SubscriberList,
    /// Highest sequence published for this shard so far (updated with `fetch_max`).
    emitted_watermark: AtomicU64,
}
41
/// One subscriber's registration on a single shard.
#[derive(Debug)]
struct ReplicationShardSubscriber {
    /// Bounded queue drained by the subscriber's fan-in thread.
    tx: SubscriberTx,
    /// Liveness flag shared by every shard registration of the subscriber and
    /// by its fan-in thread; cleared on backpressure, disconnect, or when the
    /// fan-in thread exits.
    active: Arc<AtomicBool>,
}
47
/// An owned batch of mutations ready to be encoded and exported.
#[derive(Debug)]
pub(crate) struct ReplicationBatch {
    /// Mutations in push order. The batch's shard is taken from the first
    /// entry, so all entries are expected to share one shard.
    pending: Vec<ReplicationMutation>,
    /// Saturating sum of `estimated_uncompressed_len` over `pending`.
    pending_bytes: usize,
}
53
/// Accumulates owned mutations until a record-count, byte, or delay limit
/// (or the immediate send policy) says the batch should be flushed.
#[derive(Debug)]
pub(crate) struct ReplicationBatchBuilder {
    config: ReplicationConfig,
    /// Mutations accumulated since the last flush.
    pending: Vec<ReplicationMutation>,
    /// Saturating sum of the pending mutations' estimated encoded sizes.
    pending_bytes: usize,
    /// Arrival time of the first pending mutation; only set when `track_delay`.
    first_pending_at: Option<Instant>,
    /// Whether the builder reads the clock to enforce `batch_max_delay_us`.
    track_delay: bool,
}
62
/// A sealed batch whose records are already serialized into an uncompressed
/// frame.
#[derive(Debug)]
pub(crate) struct EncodedReplicationBatch {
    shard_id: usize,
    /// Smallest sequence among the batch's records.
    min_sequence: u64,
    /// Largest sequence among the batch's records.
    max_sequence: u64,
    record_count: usize,
    /// Payload length written into the frame header.
    uncompressed_len: usize,
    /// Frame header, 4-byte little-endian record count, then the records.
    frame: Vec<u8>,
}
72
/// Serializes borrowed mutations for one shard directly into a frame buffer,
/// avoiding an owned intermediate `Vec` of mutations.
#[derive(Debug)]
pub(crate) struct EncodedReplicationBatchBuilder {
    config: ReplicationConfig,
    shard_id: usize,
    /// Frame under construction: zeroed header and record-count placeholders
    /// followed by the records appended so far.
    frame: Vec<u8>,
    record_count: usize,
    /// Sum of the appended records' payload lengths.
    pending_bytes: usize,
    /// Sequence range of appended records; `u64::MAX` / `0` while empty.
    min_sequence: u64,
    max_sequence: u64,
    /// Arrival time of the first pending record; only set when `track_delay`.
    first_pending_at: Option<Instant>,
    track_delay: bool,
}
85
86impl ReplicationBatch {
87    fn single(mutation: ReplicationMutation) -> Self {
88        let pending_bytes = mutation.estimated_uncompressed_len();
89        Self {
90            pending: vec![mutation],
91            pending_bytes,
92        }
93    }
94
95    fn shard_id(&self) -> Option<usize> {
96        self.pending.first().map(|mutation| mutation.shard_id)
97    }
98
99    fn payload_len(&self) -> usize {
100        4 + self.pending_bytes
101    }
102
103    fn is_empty(&self) -> bool {
104        self.pending.is_empty()
105    }
106}
107
108impl ReplicationBatchBuilder {
109    pub(crate) fn new(config: ReplicationConfig) -> Self {
110        Self::with_delay_tracking(config, true)
111    }
112
113    pub(crate) fn new_clockless(config: ReplicationConfig) -> Self {
114        Self::with_delay_tracking(config, false)
115    }
116
117    fn with_delay_tracking(config: ReplicationConfig, track_delay: bool) -> Self {
118        let capacity = batch_record_capacity(&config);
119        Self {
120            config,
121            pending: Vec::with_capacity(capacity),
122            pending_bytes: 0,
123            first_pending_at: None,
124            track_delay,
125        }
126    }
127
128    pub(crate) fn push(&mut self, mutation: ReplicationMutation) -> Option<ReplicationBatch> {
129        if self.track_delay && self.pending.is_empty() {
130            self.first_pending_at = Some(Instant::now());
131        }
132        self.pending_bytes = self
133            .pending_bytes
134            .saturating_add(mutation.estimated_uncompressed_len());
135        self.pending.push(mutation);
136        self.should_flush_after_push()
137            .then(|| self.flush())
138            .flatten()
139    }
140
141    pub(crate) fn flush_due(&mut self) -> Option<ReplicationBatch> {
142        self.should_flush_due().then(|| self.flush()).flatten()
143    }
144
145    pub(crate) fn flush(&mut self) -> Option<ReplicationBatch> {
146        if self.pending.is_empty() {
147            return None;
148        }
149        let capacity = batch_record_capacity(&self.config);
150        let pending = std::mem::replace(&mut self.pending, Vec::with_capacity(capacity));
151        let pending_bytes = self.pending_bytes;
152        self.pending_bytes = 0;
153        self.first_pending_at = None;
154        Some(ReplicationBatch {
155            pending,
156            pending_bytes,
157        })
158    }
159
160    pub(crate) fn next_timeout(&self) -> Option<Duration> {
161        match (self.pending.is_empty(), self.first_pending_at) {
162            (true, _) => None,
163            (false, None) => Some(Duration::ZERO),
164            (false, Some(start)) => Some(
165                self.max_delay()
166                    .checked_sub(start.elapsed())
167                    .unwrap_or_default(),
168            ),
169        }
170    }
171
172    fn should_flush_after_push(&self) -> bool {
173        if self.pending.is_empty() {
174            return false;
175        }
176        if self.config.send_policy == ReplicationSendPolicy::Immediate {
177            return true;
178        }
179        self.pending.len() >= self.config.batch_max_records
180            || self.pending_bytes >= self.config.batch_max_bytes
181    }
182
183    fn should_flush_due(&self) -> bool {
184        if self.pending.is_empty() {
185            return false;
186        }
187        if self.config.send_policy == ReplicationSendPolicy::Immediate {
188            return true;
189        }
190        if !self.track_delay {
191            return true;
192        }
193        self.first_pending_at
194            .is_some_and(|start| start.elapsed() >= self.max_delay())
195    }
196
197    fn max_delay(&self) -> Duration {
198        Duration::from_micros(self.config.batch_max_delay_us)
199    }
200}
201
202impl EncodedReplicationBatch {
203    fn shard_id(&self) -> usize {
204        self.shard_id
205    }
206
207    fn min_sequence(&self) -> u64 {
208        self.min_sequence
209    }
210
211    fn max_sequence(&self) -> u64 {
212        self.max_sequence
213    }
214
215    fn record_count(&self) -> usize {
216        self.record_count
217    }
218
219    fn uncompressed_len(&self) -> usize {
220        self.uncompressed_len
221    }
222
223    fn into_frame(
224        self,
225        compression: ReplicationCompressionMode,
226        zstd_level: i32,
227    ) -> Result<Vec<u8>> {
228        match compression {
229            ReplicationCompressionMode::None => Ok(self.frame),
230            ReplicationCompressionMode::Zstd => encode_frame(
231                FrameKind::MutationBatch,
232                compression,
233                zstd_level,
234                &self.frame[FRAME_HEADER_LEN..],
235            ),
236        }
237    }
238}
239
240impl EncodedReplicationBatchBuilder {
241    pub(crate) fn new_clockless(config: ReplicationConfig, shard_id: usize) -> Self {
242        Self::with_delay_tracking(config, shard_id, false)
243    }
244
245    pub(crate) fn shard_id(&self) -> usize {
246        self.shard_id
247    }
248
249    fn with_delay_tracking(config: ReplicationConfig, shard_id: usize, track_delay: bool) -> Self {
250        Self {
251            frame: empty_encoded_frame(&config),
252            config,
253            shard_id,
254            record_count: 0,
255            pending_bytes: 0,
256            min_sequence: u64::MAX,
257            max_sequence: 0,
258            first_pending_at: None,
259            track_delay,
260        }
261    }
262
263    pub(crate) fn push(
264        &mut self,
265        mutation: BorrowedReplicationMutation<'_>,
266    ) -> Option<EncodedReplicationBatch> {
267        if self.track_delay && self.record_count == 0 {
268            self.first_pending_at = Some(Instant::now());
269        }
270        self.pending_bytes = self
271            .pending_bytes
272            .saturating_add(borrowed_mutation_record_payload_len(mutation));
273        self.record_count += 1;
274        self.min_sequence = self.min_sequence.min(mutation.sequence);
275        self.max_sequence = self.max_sequence.max(mutation.sequence);
276        write_borrowed_mutation_payload_record(&mut self.frame, mutation);
277        self.should_flush_after_push()
278            .then(|| self.flush())
279            .flatten()
280    }
281
282    pub(crate) fn flush_due(&mut self) -> Option<EncodedReplicationBatch> {
283        self.should_flush_due().then(|| self.flush()).flatten()
284    }
285
286    pub(crate) fn flush(&mut self) -> Option<EncodedReplicationBatch> {
287        if self.record_count == 0 {
288            return None;
289        }
290
291        let record_count = self.record_count;
292        let pending_bytes = self.pending_bytes;
293        let min_sequence = self.min_sequence;
294        let max_sequence = self.max_sequence;
295        let mut frame = std::mem::replace(&mut self.frame, empty_encoded_frame(&self.config));
296        let uncompressed_len = 4 + pending_bytes;
297        write_uncompressed_frame_header_at(&mut frame, FrameKind::MutationBatch, uncompressed_len);
298        frame[FRAME_HEADER_LEN..FRAME_HEADER_LEN + 4]
299            .copy_from_slice(&(record_count as u32).to_le_bytes());
300
301        self.record_count = 0;
302        self.pending_bytes = 0;
303        self.min_sequence = u64::MAX;
304        self.max_sequence = 0;
305        self.first_pending_at = None;
306        Some(EncodedReplicationBatch {
307            shard_id: self.shard_id,
308            min_sequence,
309            max_sequence,
310            record_count,
311            uncompressed_len,
312            frame,
313        })
314    }
315
316    fn should_flush_after_push(&self) -> bool {
317        if self.record_count == 0 {
318            return false;
319        }
320        if self.config.send_policy == ReplicationSendPolicy::Immediate {
321            return true;
322        }
323        self.record_count >= self.config.batch_max_records
324            || self.pending_bytes >= self.config.batch_max_bytes
325    }
326
327    fn should_flush_due(&self) -> bool {
328        if self.record_count == 0 {
329            return false;
330        }
331        if self.config.send_policy == ReplicationSendPolicy::Immediate {
332            return true;
333        }
334        if !self.track_delay {
335            return true;
336        }
337        self.first_pending_at
338            .is_some_and(|start| start.elapsed() >= self.max_delay())
339    }
340
341    fn max_delay(&self) -> Duration {
342        Duration::from_micros(self.config.batch_max_delay_us)
343    }
344}
345
346fn empty_encoded_frame(config: &ReplicationConfig) -> Vec<u8> {
347    let payload_capacity = config.batch_max_bytes.clamp(4, 64 * 1024);
348    let mut frame = Vec::with_capacity(FRAME_HEADER_LEN + payload_capacity);
349    frame.resize(FRAME_HEADER_LEN, 0);
350    frame.extend_from_slice(&0_u32.to_le_bytes());
351    frame
352}
353
354fn batch_record_capacity(config: &ReplicationConfig) -> usize {
355    match config.send_policy {
356        ReplicationSendPolicy::Immediate => 1,
357        ReplicationSendPolicy::Batch => config.batch_max_records.clamp(1, 1024),
358    }
359}
360
361fn start_subscriber_fan_in(
362    out_tx: SubscriberTx,
363    shard_receivers: Vec<Receiver<ReplicationFrameBytes>>,
364    active: Arc<AtomicBool>,
365) {
366    let thread_active = Arc::clone(&active);
367    match thread::Builder::new()
368        .name("fast-cache-replication-subscriber-fan-in".into())
369        .spawn(move || run_subscriber_fan_in(out_tx, shard_receivers, thread_active))
370    {
371        Ok(_) => {}
372        Err(error) => {
373            active.store(false, Ordering::Relaxed);
374            tracing::error!("failed to start replication subscriber fan-in thread: {error}");
375        }
376    }
377}
378
379fn run_subscriber_fan_in(
380    out_tx: SubscriberTx,
381    mut shard_receivers: Vec<Receiver<ReplicationFrameBytes>>,
382    active: Arc<AtomicBool>,
383) {
384    while active.load(Ordering::Relaxed) && !shard_receivers.is_empty() {
385        let mut select = Select::new();
386        for receiver in &shard_receivers {
387            select.recv(receiver);
388        }
389
390        let operation = select.select();
391        let index = operation.index();
392        match operation.recv(&shard_receivers[index]) {
393            Ok(frame) => match out_tx.send(frame) {
394                Ok(()) => {}
395                Err(_) => {
396                    active.store(false, Ordering::Relaxed);
397                    break;
398                }
399            },
400            Err(_) => {
401                shard_receivers.swap_remove(index);
402            }
403        }
404    }
405    active.store(false, Ordering::Relaxed);
406}
407
408impl ReplicationPrimary {
409    pub fn start(shard_count: usize, config: ReplicationConfig) -> Result<Self> {
410        if !config.enabled {
411            return Err(FastCacheError::Config(
412                "replication primary requires replication.enabled = true".into(),
413            ));
414        }
415        let shard_count = shard_count.max(1);
416        let metrics = ReplicationMetrics::default();
417        let shards = (0..shard_count)
418            .map(|_| ReplicationShardExport {
419                backlog: Mutex::new(ReplicationBacklog::new(config.backlog_bytes, shard_count)),
420                subscribers: Mutex::new(Vec::new()),
421                emitted_watermark: AtomicU64::new(0),
422            })
423            .collect();
424        Ok(Self {
425            config,
426            shard_count,
427            sequences: (0..shard_count).map(|_| AtomicU64::new(0)).collect(),
428            metrics,
429            shards,
430        })
431    }
432
433    pub fn shard_count(&self) -> usize {
434        self.shard_count
435    }
436
437    /// Allocates the next sequence for `shard_id` from the primary-owned
438    /// counter. Callers using an external sequence source (such as the engine
439    /// shard worker, which already mints sequences for the WAL) should pass
440    /// fully-formed mutations to [`Self::emit`] instead.
441    pub fn next_sequence(&self, shard_id: usize) -> u64 {
442        debug_assert!(
443            shard_id < self.sequences.len(),
444            "shard_id {shard_id} out of range for {} primary sequences",
445            self.sequences.len()
446        );
447        let Some(slot) = self.sequences.get(shard_id) else {
448            return 0;
449        };
450        slot.fetch_add(1, Ordering::Relaxed) + 1
451    }
452
453    /// Enqueues a mutation for callers that do not own a shard-local batch
454    /// buffer.
455    ///
456    /// The sharded engine path uses `ReplicationBatchBuilder` so it can
457    /// append mutations to an ordered per-shard `Vec` and hand off ready
458    /// batches. This fallback keeps direct embedded callers correct without
459    /// adding shared batching state.
460    pub fn emit(&self, mutation: ReplicationMutation) {
461        if mutation.shard_id >= self.shards.len() {
462            self.metrics.record_drop();
463            tracing::warn!(
464                "dropping replication mutation for shard {} outside configured shard_count {}",
465                mutation.shard_id,
466                self.shard_count
467            );
468            return;
469        }
470        self.export_batch_direct(ReplicationBatch::single(mutation));
471    }
472
473    /// Encodes and publishes a shard-owned batch on the caller's thread.
474    ///
475    /// Storage shard workers already own ordering and batching, so routing
476    /// those batches through a second exporter thread only adds channel and
477    /// scheduler overhead. Direct export still keeps socket writes out of the
478    /// storage shard; it publishes immutable frames into subscriber queues.
479    pub(crate) fn export_batch_direct(&self, batch: ReplicationBatch) {
480        if batch.is_empty() {
481            return;
482        }
483        let Some(shard_id) = batch.shard_id() else {
484            return;
485        };
486        if shard_id >= self.shards.len() {
487            self.metrics.record_drop();
488            tracing::warn!(
489                "dropping replication batch for shard {} outside configured shard_count {}",
490                shard_id,
491                self.shard_count
492            );
493            return;
494        }
495        let payload_len = batch.payload_len();
496        let record_count = batch.pending.len();
497        let compression = ReplicationCompressionMode::from(self.config.compression);
498        let compression_started = Instant::now();
499        let (frame, uncompressed_len) = match encode_mutation_batch_frame_with_payload_len(
500            &batch.pending,
501            payload_len,
502            compression,
503            self.config.zstd_level,
504        ) {
505            Ok(encoded) => encoded,
506            Err(error) => {
507                tracing::error!("failed to encode replication batch: {error}");
508                return;
509            }
510        };
511        let compression_ns = compression_started.elapsed().as_nanos() as u64;
512        self.metrics.record_emit_count(record_count, 0);
513        self.metrics
514            .record_batch(record_count, uncompressed_len, frame.len(), compression_ns);
515        let frame = ReplicationFrameBytes::from(frame);
516        self.shards[shard_id]
517            .backlog
518            .lock()
519            .push_encoded(frame.clone(), &batch.pending);
520        self.observe_emitted_watermarks(&batch.pending);
521        self.broadcast(shard_id, frame);
522    }
523
524    pub(crate) fn export_encoded_batch_direct(&self, batch: EncodedReplicationBatch) {
525        let shard_id = batch.shard_id();
526        if shard_id >= self.shards.len() {
527            self.metrics.record_drop();
528            tracing::warn!(
529                "dropping encoded replication batch for shard {} outside configured shard_count {}",
530                shard_id,
531                self.shard_count
532            );
533            return;
534        }
535
536        let record_count = batch.record_count();
537        let min_sequence = batch.min_sequence();
538        let max_sequence = batch.max_sequence();
539        let uncompressed_len = batch.uncompressed_len();
540        let compression = ReplicationCompressionMode::from(self.config.compression);
541        let compression_started = Instant::now();
542        let frame = match batch.into_frame(compression, self.config.zstd_level) {
543            Ok(frame) => frame,
544            Err(error) => {
545                tracing::error!("failed to encode direct replication batch: {error}");
546                return;
547            }
548        };
549        let compression_ns = compression_started.elapsed().as_nanos() as u64;
550        self.metrics.record_emit_count(record_count, 0);
551        self.metrics
552            .record_batch(record_count, uncompressed_len, frame.len(), compression_ns);
553        let frame = ReplicationFrameBytes::from(frame);
554        self.shards[shard_id].backlog.lock().push_encoded_span(
555            frame.clone(),
556            shard_id,
557            min_sequence,
558            max_sequence,
559        );
560        self.shards[shard_id]
561            .emitted_watermark
562            .fetch_max(max_sequence, Ordering::Relaxed);
563        self.broadcast(shard_id, frame);
564    }
565
566    pub fn queue_depths(&self) -> Vec<usize> {
567        vec![0; self.shard_count]
568    }
569
570    pub fn max_queue_depth(&self) -> usize {
571        0
572    }
573
574    pub fn total_queue_depth(&self) -> usize {
575        0
576    }
577
578    pub fn per_shard_export_enabled(&self) -> bool {
579        self.shards.len() == self.shard_count
580    }
581
582    pub fn lane_count(&self) -> usize {
583        self.shards.len()
584    }
585
586    pub fn shutdown(&self) -> Result<()> {
587        Ok(())
588    }
589
590    pub fn subscribe(&self, channel_capacity: usize) -> Receiver<ReplicationFrameBytes> {
591        let channel_capacity = channel_capacity.max(1);
592        let (out_tx, out_rx) = bounded(channel_capacity);
593        let active = Arc::new(AtomicBool::new(true));
594        let mut shard_receivers = Vec::with_capacity(self.shards.len());
595        for shard in &self.shards {
596            let (tx, rx) = bounded(channel_capacity);
597            shard
598                .subscribers
599                .lock()
600                .push(Arc::new(ReplicationShardSubscriber {
601                    tx,
602                    active: Arc::clone(&active),
603                }));
604            shard_receivers.push(rx);
605        }
606        start_subscriber_fan_in(out_tx, shard_receivers, active);
607        out_rx
608    }
609
610    pub fn catch_up_since(&self, watermarks: &ShardWatermarks) -> Result<BacklogCatchUp> {
611        let mut frames = Vec::new();
612        for shard in &self.shards {
613            match shard.backlog.lock().catch_up_since(watermarks)? {
614                BacklogCatchUp::Available(mut shard_frames) => frames.append(&mut shard_frames),
615                BacklogCatchUp::NeedsSnapshot => return Ok(BacklogCatchUp::NeedsSnapshot),
616            }
617        }
618        Ok(BacklogCatchUp::Available(frames))
619    }
620
621    /// Returns the watermarks for batches that have been emitted to subscribers
622    /// and the backlog. Mutations that are still pending in shard-local batch
623    /// builders or ready-batch channels are not reflected here.
624    pub fn current_watermarks(&self) -> ShardWatermarks {
625        ShardWatermarks::from_vec(
626            self.shards
627                .iter()
628                .map(|shard| shard.emitted_watermark.load(Ordering::Relaxed))
629                .collect(),
630        )
631    }
632
633    pub fn latest_backlog_watermarks(&self) -> ShardWatermarks {
634        self.current_watermarks()
635    }
636
637    pub fn metrics_snapshot(&self) -> ReplicationMetricsSnapshot {
638        self.metrics.snapshot()
639    }
640
641    pub fn decode_subscriber_frame(bytes: &[u8]) -> Result<ReplicationFrame> {
642        decode_frame(bytes)
643    }
644
645    fn observe_emitted_watermarks(&self, mutations: &[ReplicationMutation]) {
646        for mutation in mutations {
647            if let Some(shard) = self.shards.get(mutation.shard_id) {
648                shard
649                    .emitted_watermark
650                    .fetch_max(mutation.sequence, Ordering::Relaxed);
651            }
652        }
653    }
654
655    fn broadcast(&self, shard_id: usize, frame: ReplicationFrameBytes) {
656        let Some(shard) = self.shards.get(shard_id) else {
657            return;
658        };
659        let mut subscribers = shard.subscribers.lock();
660        subscribers.retain(|subscriber| {
661            subscriber.active.load(Ordering::Relaxed)
662                && match subscriber.tx.try_send(frame.clone()) {
663                    Ok(()) => true,
664                    Err(TrySendError::Full(_)) => {
665                        subscriber.active.store(false, Ordering::Relaxed);
666                        self.metrics.record_backpressure();
667                        self.metrics.record_drop();
668                        false
669                    }
670                    Err(TrySendError::Disconnected(_)) => {
671                        subscriber.active.store(false, Ordering::Relaxed);
672                        self.metrics.record_drop();
673                        false
674                    }
675                }
676        });
677    }
678}
679
680impl Drop for ReplicationPrimary {
681    fn drop(&mut self) {
682        let _ = self.shutdown();
683    }
684}