Skip to main content

fast_cache/replication/
backlog.rs

1use std::collections::VecDeque;
2
3use smallvec::SmallVec;
4
5use crate::{FastCacheError, Result};
6
7use super::ReplicationFrameBytes;
8use super::protocol::{ReplicationFrame, ReplicationMutation, ShardWatermarks, decode_frame};
9
10#[derive(Debug, Clone)]
11pub enum BacklogCatchUp {
12    Available(Vec<ReplicationFrameBytes>),
13    NeedsSnapshot,
14}
15
16#[derive(Debug)]
17pub struct ReplicationBacklog {
18    max_bytes: usize,
19    current_bytes: usize,
20    shard_count: usize,
21    entries: VecDeque<BacklogEntry>,
22    /// High watermark of mutations ever pushed, even if since trimmed. Used to
23    /// short-circuit catch-up when a caller is already ahead of everything we
24    /// ever held.
25    high_watermarks: ShardWatermarks,
26    trimmed: bool,
27}
28
29#[derive(Debug, Clone)]
30struct BacklogEntry {
31    frame: ReplicationFrameBytes,
32    spans: SmallVec<[BacklogShardSpan; 1]>,
33    byte_len: usize,
34}
35
/// Lowest and highest sequence numbers a single entry carries for one shard.
#[derive(Clone, Copy, Debug)]
struct BacklogShardSpan {
    /// Shard these sequence numbers belong to.
    shard_id: usize,
    /// Smallest sequence number the entry holds for `shard_id`.
    min_sequence: u64,
    /// Largest sequence number the entry holds for `shard_id`.
    max_sequence: u64,
}
42
43impl ReplicationBacklog {
44    pub fn new(max_bytes: usize, shard_count: usize) -> Self {
45        let shard_count = shard_count.max(1);
46        Self {
47            max_bytes: max_bytes.max(1),
48            current_bytes: 0,
49            shard_count,
50            entries: VecDeque::new(),
51            high_watermarks: ShardWatermarks::new(shard_count),
52            trimmed: false,
53        }
54    }
55
56    pub fn push(&mut self, frame: ReplicationFrameBytes, mutations: &[ReplicationMutation]) {
57        self.push_encoded(frame, mutations);
58    }
59
60    pub fn push_encoded(
61        &mut self,
62        frame: ReplicationFrameBytes,
63        mutations: &[ReplicationMutation],
64    ) {
65        let mut spans = SmallVec::<[BacklogShardSpan; 1]>::new();
66        for mutation in mutations {
67            self.ensure_shard_capacity(mutation.shard_id);
68            self.high_watermarks
69                .observe(mutation.shard_id, mutation.sequence);
70            match spans
71                .iter_mut()
72                .find(|span| span.shard_id == mutation.shard_id)
73            {
74                Some(span) => {
75                    span.min_sequence = span.min_sequence.min(mutation.sequence);
76                    span.max_sequence = span.max_sequence.max(mutation.sequence);
77                }
78                None => spans.push(BacklogShardSpan {
79                    shard_id: mutation.shard_id,
80                    min_sequence: mutation.sequence,
81                    max_sequence: mutation.sequence,
82                }),
83            }
84        }
85
86        let byte_len = frame.len();
87        self.current_bytes = self.current_bytes.saturating_add(byte_len);
88        self.entries.push_back(BacklogEntry {
89            frame,
90            spans,
91            byte_len,
92        });
93        self.trim();
94    }
95
96    pub(crate) fn push_encoded_span(
97        &mut self,
98        frame: ReplicationFrameBytes,
99        shard_id: usize,
100        min_sequence: u64,
101        max_sequence: u64,
102    ) {
103        self.ensure_shard_capacity(shard_id);
104        self.high_watermarks.observe(shard_id, max_sequence);
105
106        let byte_len = frame.len();
107        self.current_bytes = self.current_bytes.saturating_add(byte_len);
108        self.entries.push_back(BacklogEntry {
109            frame,
110            spans: SmallVec::from_buf([BacklogShardSpan {
111                shard_id,
112                min_sequence,
113                max_sequence,
114            }]),
115            byte_len,
116        });
117        self.trim();
118    }
119
120    fn ensure_shard_capacity(&mut self, shard_id: usize) {
121        if shard_id >= self.shard_count {
122            self.shard_count = shard_id + 1;
123        }
124    }
125
126    pub fn catch_up_since(&self, watermarks: &ShardWatermarks) -> Result<BacklogCatchUp> {
127        if self.caller_is_caught_up(watermarks) {
128            return Ok(BacklogCatchUp::Available(Vec::new()));
129        }
130        if self.entries.is_empty() {
131            return Ok(if self.trimmed {
132                BacklogCatchUp::NeedsSnapshot
133            } else {
134                BacklogCatchUp::Available(Vec::new())
135            });
136        }
137        let earliest_retained = self.earliest_retained_sequences();
138        for (shard_id, high) in self.high_watermarks.as_slice().iter().enumerate() {
139            if *high <= watermarks.get(shard_id) {
140                continue;
141            }
142            let earliest = earliest_retained.get(shard_id).copied().unwrap_or(0);
143            if earliest == 0 || watermarks.get(shard_id).saturating_add(1) < earliest {
144                return Ok(BacklogCatchUp::NeedsSnapshot);
145            }
146        }
147
148        let mut frames = Vec::new();
149        for entry in &self.entries {
150            let needed = entry
151                .spans
152                .iter()
153                .any(|span| span.max_sequence > watermarks.get(span.shard_id));
154            if !needed {
155                continue;
156            }
157            // Validate the frame can be decoded; a corrupt backlog entry is a
158            // hard error so the caller can fall back to a snapshot.
159            decode_frame(entry.frame.as_ref()).map_err(|error| {
160                FastCacheError::Protocol(format!("backlog frame is corrupt: {error}"))
161            })?;
162            frames.push(entry.frame.clone());
163        }
164        Ok(BacklogCatchUp::Available(frames))
165    }
166
167    pub fn decode_frames(frames: &[ReplicationFrameBytes]) -> Result<Vec<ReplicationFrame>> {
168        frames
169            .iter()
170            .map(|frame| decode_frame(frame.as_ref()))
171            .collect()
172    }
173
174    pub fn latest_watermarks(&self) -> ShardWatermarks {
175        self.high_watermarks.clone()
176    }
177
178    fn caller_is_caught_up(&self, watermarks: &ShardWatermarks) -> bool {
179        let highs = self.high_watermarks.as_slice();
180        for (shard_id, high) in highs.iter().enumerate() {
181            if *high > watermarks.get(shard_id) {
182                return false;
183            }
184        }
185        true
186    }
187
188    pub fn retained_bytes(&self) -> usize {
189        self.current_bytes
190    }
191
192    fn earliest_retained_sequences(&self) -> Vec<u64> {
193        let mut earliest = vec![0_u64; self.shard_count];
194        for entry in &self.entries {
195            for span in &entry.spans {
196                if earliest.get(span.shard_id).copied().unwrap_or(0) == 0 {
197                    if earliest.len() <= span.shard_id {
198                        earliest.resize(span.shard_id + 1, 0);
199                    }
200                    earliest[span.shard_id] = span.min_sequence;
201                }
202            }
203        }
204        earliest
205    }
206
207    fn trim(&mut self) {
208        while self.current_bytes > self.max_bytes {
209            let Some(entry) = self.entries.pop_front() else {
210                break;
211            };
212            self.trimmed = true;
213            self.current_bytes = self.current_bytes.saturating_sub(entry.byte_len);
214        }
215    }
216}
217
#[cfg(test)]
mod tests {
    use crate::replication::protocol::{
        FrameKind, ReplicationCompressionMode, ReplicationMutation, ReplicationMutationOp,
        encode_frame, encode_mutation_batch,
    };
    use bytes::Bytes as SharedBytes;

    use crate::storage::{hash_key, hash_key_tag};

    use super::*;

    /// Builds a `Set` mutation on `shard_id` whose key encodes both the shard
    /// and the sequence, so every mutation in a test is distinct.
    fn mutation(shard_id: usize, sequence: u64) -> ReplicationMutation {
        let key = format!("key-{shard_id}-{sequence}").into_bytes();
        ReplicationMutation {
            shard_id,
            sequence,
            timestamp_ms: 1,
            op: ReplicationMutationOp::Set,
            key_hash: hash_key(&key),
            key_tag: hash_key_tag(&key),
            key: SharedBytes::from(key),
            value: SharedBytes::from_static(b"v"),
            expire_at_ms: None,
        }
    }

    /// Encodes `mutations` as an uncompressed mutation-batch frame.
    fn frame_for(mutations: &[ReplicationMutation]) -> ReplicationFrameBytes {
        let payload = encode_mutation_batch(mutations);
        ReplicationFrameBytes::from(
            encode_frame(
                FrameKind::MutationBatch,
                ReplicationCompressionMode::None,
                0,
                &payload,
            )
            .expect("frame"),
        )
    }

    /// Encodes `mutations` into a frame and pushes it onto `backlog`.
    fn push_batch(backlog: &mut ReplicationBacklog, mutations: &[ReplicationMutation]) {
        backlog.push(frame_for(mutations), mutations);
    }

    /// Runs a catch-up that is expected not to hit a corrupt frame.
    fn catch_up(backlog: &ReplicationBacklog, watermarks: ShardWatermarks) -> BacklogCatchUp {
        backlog
            .catch_up_since(&watermarks)
            .expect("catch_up_since")
    }

    #[test]
    fn catches_up_when_watermark_is_retained() {
        let mut backlog = ReplicationBacklog::new(1024 * 1024, 1);
        for seq in 1..=3 {
            push_batch(&mut backlog, &[mutation(0, seq)]);
        }
        match catch_up(&backlog, ShardWatermarks::from_vec(vec![1])) {
            BacklogCatchUp::Available(frames) => assert_eq!(frames.len(), 2),
            BacklogCatchUp::NeedsSnapshot => panic!("expected backlog catch-up"),
        }
    }

    #[test]
    fn caught_up_caller_receives_no_frames() {
        let mut backlog = ReplicationBacklog::new(1024 * 1024, 1);
        for seq in 1..=3 {
            push_batch(&mut backlog, &[mutation(0, seq)]);
        }
        match catch_up(&backlog, ShardWatermarks::from_vec(vec![3])) {
            BacklogCatchUp::Available(frames) => assert!(frames.is_empty()),
            BacklogCatchUp::NeedsSnapshot => panic!("expected backlog catch-up"),
        }
    }

    #[test]
    fn empty_backlog_after_trimming_requests_snapshot() {
        let mut backlog = ReplicationBacklog::new(1, 1);
        push_batch(&mut backlog, &[mutation(0, 1)]);
        push_batch(&mut backlog, &[mutation(0, 2)]);
        match catch_up(&backlog, ShardWatermarks::from_vec(vec![0])) {
            BacklogCatchUp::NeedsSnapshot => {}
            BacklogCatchUp::Available(_) => panic!("expected NeedsSnapshot after trimming"),
        }
    }

    #[test]
    fn gap_before_earliest_retained_requests_snapshot() {
        let batches: Vec<Vec<ReplicationMutation>> =
            (1..=3).map(|seq| vec![mutation(0, seq)]).collect();
        let frames: Vec<ReplicationFrameBytes> = batches
            .iter()
            .map(|batch| frame_for(batch.as_slice()))
            .collect();
        // Budget for exactly the last two frames: the first one ends up
        // evicted regardless of small differences in encoded frame size.
        let budget = frames[1].len() + frames[2].len();
        let mut backlog = ReplicationBacklog::new(budget, 1);
        for (frame, batch) in frames.into_iter().zip(&batches) {
            backlog.push(frame, batch);
        }
        assert_eq!(backlog.retained_bytes(), budget);

        // Sequence 1 was evicted, so a replica that applied nothing has a gap.
        match catch_up(&backlog, ShardWatermarks::from_vec(vec![0])) {
            BacklogCatchUp::NeedsSnapshot => {}
            BacklogCatchUp::Available(_) => panic!("expected NeedsSnapshot for evicted gap"),
        }
        // A replica that already applied sequence 1 can still be served.
        match catch_up(&backlog, ShardWatermarks::from_vec(vec![1])) {
            BacklogCatchUp::Available(frames) => assert_eq!(frames.len(), 2),
            BacklogCatchUp::NeedsSnapshot => panic!("expected backlog catch-up"),
        }
    }

    #[test]
    fn multi_shard_catch_up_returns_relevant_frames() {
        let mut backlog = ReplicationBacklog::new(1024 * 1024, 2);
        push_batch(&mut backlog, &[mutation(0, 1)]);
        push_batch(&mut backlog, &[mutation(1, 1)]);
        push_batch(&mut backlog, &[mutation(0, 2)]);

        // Replica is caught up on shard 0 through seq 1, but knows nothing
        // about shard 1.
        match catch_up(&backlog, ShardWatermarks::from_vec(vec![1, 0])) {
            BacklogCatchUp::Available(frames) => assert_eq!(frames.len(), 2),
            BacklogCatchUp::NeedsSnapshot => panic!("expected backlog catch-up"),
        }
    }

    #[test]
    fn catch_up_resends_partially_needed_batch() {
        let mut backlog = ReplicationBacklog::new(1024 * 1024, 1);
        push_batch(&mut backlog, &[mutation(0, 1), mutation(0, 2)]);

        match catch_up(&backlog, ShardWatermarks::from_vec(vec![1])) {
            BacklogCatchUp::Available(frames) => assert_eq!(frames.len(), 1),
            BacklogCatchUp::NeedsSnapshot => panic!("expected backlog catch-up"),
        }
    }
}