1use std::collections::VecDeque;
2
3use smallvec::SmallVec;
4
5use crate::{FastCacheError, Result};
6
7use super::ReplicationFrameBytes;
8use super::protocol::{ReplicationFrame, ReplicationMutation, ShardWatermarks, decode_frame};
9
10#[derive(Debug, Clone)]
11pub enum BacklogCatchUp {
12 Available(Vec<ReplicationFrameBytes>),
13 NeedsSnapshot,
14}
15
16#[derive(Debug)]
17pub struct ReplicationBacklog {
18 max_bytes: usize,
19 current_bytes: usize,
20 shard_count: usize,
21 entries: VecDeque<BacklogEntry>,
22 high_watermarks: ShardWatermarks,
26 trimmed: bool,
27}
28
29#[derive(Debug, Clone)]
30struct BacklogEntry {
31 frame: ReplicationFrameBytes,
32 spans: SmallVec<[BacklogShardSpan; 1]>,
33 byte_len: usize,
34}
35
/// Inclusive range of sequence numbers that one frame carries for one shard.
#[derive(Clone, Copy, Debug)]
struct BacklogShardSpan {
    /// Shard the range belongs to.
    shard_id: usize,
    /// Lowest sequence for `shard_id` within the frame.
    min_sequence: u64,
    /// Highest sequence for `shard_id` within the frame.
    max_sequence: u64,
}
42
43impl ReplicationBacklog {
44 pub fn new(max_bytes: usize, shard_count: usize) -> Self {
45 let shard_count = shard_count.max(1);
46 Self {
47 max_bytes: max_bytes.max(1),
48 current_bytes: 0,
49 shard_count,
50 entries: VecDeque::new(),
51 high_watermarks: ShardWatermarks::new(shard_count),
52 trimmed: false,
53 }
54 }
55
56 pub fn push(&mut self, frame: ReplicationFrameBytes, mutations: &[ReplicationMutation]) {
57 self.push_encoded(frame, mutations);
58 }
59
60 pub fn push_encoded(
61 &mut self,
62 frame: ReplicationFrameBytes,
63 mutations: &[ReplicationMutation],
64 ) {
65 let mut spans = SmallVec::<[BacklogShardSpan; 1]>::new();
66 for mutation in mutations {
67 self.ensure_shard_capacity(mutation.shard_id);
68 self.high_watermarks
69 .observe(mutation.shard_id, mutation.sequence);
70 match spans
71 .iter_mut()
72 .find(|span| span.shard_id == mutation.shard_id)
73 {
74 Some(span) => {
75 span.min_sequence = span.min_sequence.min(mutation.sequence);
76 span.max_sequence = span.max_sequence.max(mutation.sequence);
77 }
78 None => spans.push(BacklogShardSpan {
79 shard_id: mutation.shard_id,
80 min_sequence: mutation.sequence,
81 max_sequence: mutation.sequence,
82 }),
83 }
84 }
85
86 let byte_len = frame.len();
87 self.current_bytes = self.current_bytes.saturating_add(byte_len);
88 self.entries.push_back(BacklogEntry {
89 frame,
90 spans,
91 byte_len,
92 });
93 self.trim();
94 }
95
96 pub(crate) fn push_encoded_span(
97 &mut self,
98 frame: ReplicationFrameBytes,
99 shard_id: usize,
100 min_sequence: u64,
101 max_sequence: u64,
102 ) {
103 self.ensure_shard_capacity(shard_id);
104 self.high_watermarks.observe(shard_id, max_sequence);
105
106 let byte_len = frame.len();
107 self.current_bytes = self.current_bytes.saturating_add(byte_len);
108 self.entries.push_back(BacklogEntry {
109 frame,
110 spans: SmallVec::from_buf([BacklogShardSpan {
111 shard_id,
112 min_sequence,
113 max_sequence,
114 }]),
115 byte_len,
116 });
117 self.trim();
118 }
119
120 fn ensure_shard_capacity(&mut self, shard_id: usize) {
121 if shard_id >= self.shard_count {
122 self.shard_count = shard_id + 1;
123 }
124 }
125
126 pub fn catch_up_since(&self, watermarks: &ShardWatermarks) -> Result<BacklogCatchUp> {
127 if self.caller_is_caught_up(watermarks) {
128 return Ok(BacklogCatchUp::Available(Vec::new()));
129 }
130 if self.entries.is_empty() {
131 return Ok(if self.trimmed {
132 BacklogCatchUp::NeedsSnapshot
133 } else {
134 BacklogCatchUp::Available(Vec::new())
135 });
136 }
137 let earliest_retained = self.earliest_retained_sequences();
138 for (shard_id, high) in self.high_watermarks.as_slice().iter().enumerate() {
139 if *high <= watermarks.get(shard_id) {
140 continue;
141 }
142 let earliest = earliest_retained.get(shard_id).copied().unwrap_or(0);
143 if earliest == 0 || watermarks.get(shard_id).saturating_add(1) < earliest {
144 return Ok(BacklogCatchUp::NeedsSnapshot);
145 }
146 }
147
148 let mut frames = Vec::new();
149 for entry in &self.entries {
150 let needed = entry
151 .spans
152 .iter()
153 .any(|span| span.max_sequence > watermarks.get(span.shard_id));
154 if !needed {
155 continue;
156 }
157 decode_frame(entry.frame.as_ref()).map_err(|error| {
160 FastCacheError::Protocol(format!("backlog frame is corrupt: {error}"))
161 })?;
162 frames.push(entry.frame.clone());
163 }
164 Ok(BacklogCatchUp::Available(frames))
165 }
166
167 pub fn decode_frames(frames: &[ReplicationFrameBytes]) -> Result<Vec<ReplicationFrame>> {
168 frames
169 .iter()
170 .map(|frame| decode_frame(frame.as_ref()))
171 .collect()
172 }
173
174 pub fn latest_watermarks(&self) -> ShardWatermarks {
175 self.high_watermarks.clone()
176 }
177
178 fn caller_is_caught_up(&self, watermarks: &ShardWatermarks) -> bool {
179 let highs = self.high_watermarks.as_slice();
180 for (shard_id, high) in highs.iter().enumerate() {
181 if *high > watermarks.get(shard_id) {
182 return false;
183 }
184 }
185 true
186 }
187
188 pub fn retained_bytes(&self) -> usize {
189 self.current_bytes
190 }
191
192 fn earliest_retained_sequences(&self) -> Vec<u64> {
193 let mut earliest = vec![0_u64; self.shard_count];
194 for entry in &self.entries {
195 for span in &entry.spans {
196 if earliest.get(span.shard_id).copied().unwrap_or(0) == 0 {
197 if earliest.len() <= span.shard_id {
198 earliest.resize(span.shard_id + 1, 0);
199 }
200 earliest[span.shard_id] = span.min_sequence;
201 }
202 }
203 }
204 earliest
205 }
206
207 fn trim(&mut self) {
208 while self.current_bytes > self.max_bytes {
209 let Some(entry) = self.entries.pop_front() else {
210 break;
211 };
212 self.trimmed = true;
213 self.current_bytes = self.current_bytes.saturating_sub(entry.byte_len);
214 }
215 }
216}
217
#[cfg(test)]
mod tests {
    use bytes::Bytes as SharedBytes;

    use crate::replication::protocol::{
        FrameKind, ReplicationCompressionMode, ReplicationMutation, ReplicationMutationOp,
        encode_frame, encode_mutation_batch,
    };
    use crate::storage::{hash_key, hash_key_tag};

    use super::*;

    /// Builds a `Set` mutation whose key is unique per `(shard_id, sequence)`.
    fn mutation(shard_id: usize, sequence: u64) -> ReplicationMutation {
        let key = format!("key-{shard_id}-{sequence}").into_bytes();
        let key_hash = hash_key(&key);
        let key_tag = hash_key_tag(&key);
        ReplicationMutation {
            shard_id,
            sequence,
            timestamp_ms: 1,
            op: ReplicationMutationOp::Set,
            key_hash,
            key_tag,
            key: SharedBytes::from(key),
            value: SharedBytes::from_static(b"v"),
            expire_at_ms: None,
        }
    }

    /// Encodes `mutations` as an uncompressed mutation-batch frame.
    fn frame_for(mutations: &[ReplicationMutation]) -> ReplicationFrameBytes {
        let payload = encode_mutation_batch(mutations);
        let encoded = encode_frame(
            FrameKind::MutationBatch,
            ReplicationCompressionMode::None,
            0,
            &payload,
        )
        .expect("frame");
        ReplicationFrameBytes::from(encoded)
    }

    /// Encodes `mutations` and pushes them onto `backlog` as a single frame.
    fn push_batch(backlog: &mut ReplicationBacklog, mutations: &[ReplicationMutation]) {
        backlog.push(frame_for(mutations), mutations);
    }

    /// Runs a catch-up from `watermarks` and returns how many frames were offered,
    /// panicking if the backlog asks for a snapshot instead.
    fn available_frame_count(backlog: &ReplicationBacklog, watermarks: &ShardWatermarks) -> usize {
        let outcome = backlog.catch_up_since(watermarks).expect("catch_up_since");
        let BacklogCatchUp::Available(frames) = outcome else {
            panic!("expected backlog catch-up");
        };
        frames.len()
    }

    #[test]
    fn catches_up_when_watermark_is_retained() {
        let mut backlog = ReplicationBacklog::new(1024 * 1024, 1);
        for seq in 1..=3 {
            push_batch(&mut backlog, &[mutation(0, seq)]);
        }
        let count = available_frame_count(&backlog, &ShardWatermarks::from_vec(vec![1]));
        assert_eq!(count, 2);
    }

    #[test]
    fn empty_backlog_after_trimming_requests_snapshot() {
        // A one-byte budget is smaller than an encoded frame, so each push is evicted.
        let mut backlog = ReplicationBacklog::new(1, 1);
        push_batch(&mut backlog, &[mutation(0, 1)]);
        push_batch(&mut backlog, &[mutation(0, 2)]);
        let outcome = backlog
            .catch_up_since(&ShardWatermarks::from_vec(vec![0]))
            .expect("catch_up_since");
        assert!(
            matches!(outcome, BacklogCatchUp::NeedsSnapshot),
            "expected NeedsSnapshot after trimming"
        );
    }

    #[test]
    fn multi_shard_catch_up_returns_relevant_frames() {
        let mut backlog = ReplicationBacklog::new(1024 * 1024, 2);
        for (shard_id, seq) in [(0, 1), (1, 1), (0, 2)] {
            push_batch(&mut backlog, &[mutation(shard_id, seq)]);
        }
        // The follower holds shard 0 through sequence 1 and nothing from shard 1, so only
        // the last two frames qualify.
        let count = available_frame_count(&backlog, &ShardWatermarks::from_vec(vec![1, 0]));
        assert_eq!(count, 2);
    }

    #[test]
    fn catch_up_resends_partially_needed_batch() {
        let mut backlog = ReplicationBacklog::new(1024 * 1024, 1);
        push_batch(&mut backlog, &[mutation(0, 1), mutation(0, 2)]);
        // Sequence 1 is already applied, but the same batch still carries sequence 2.
        let count = available_frame_count(&backlog, &ShardWatermarks::from_vec(vec![1]));
        assert_eq!(count, 1);
    }
}