1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use cobble::{Error, MergeOperator, Result, TimeProvider, ValueType};
3use serde::{Deserialize, Serialize};
4use serde_json::Value as JsonValue;
5use std::collections::VecDeque;
6use std::mem::size_of;
7use std::sync::Arc;
8
/// Identifier reported by the list merge operator's `id()`; also the only id
/// accepted by `list_operator_from_metadata` when rebuilding the operator.
pub(crate) const LIST_OPERATOR_ID: &str = "cobble.list.v1";
10
/// Configuration for list-valued entries handled by the list merge operator.
///
/// The operator serializes this struct as its metadata, and
/// `list_operator_from_metadata` deserializes it to rebuild the operator.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ListConfig {
    /// Maximum number of elements kept after a merge or read; `None` keeps all.
    pub max_elements: Option<usize>,
    /// Which end of the list survives when `max_elements` is exceeded.
    pub retain_mode: ListRetainMode,
    /// When `true`, every encoded element carries its own `u32` expiry
    /// timestamp and expired elements are dropped during merges and reads.
    pub preserve_element_ttl: bool,
}
17
18impl Default for ListConfig {
19 fn default() -> Self {
20 Self {
21 max_elements: None,
22 retain_mode: ListRetainMode::Last,
23 preserve_element_ttl: false,
24 }
25 }
26}
27
/// Chooses which elements survive when a list exceeds `ListConfig::max_elements`.
///
/// Serialized in snake_case, i.e. as `"first"` / `"last"`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ListRetainMode {
    /// Keep the earliest elements; elements arriving after the cap is full are dropped.
    First,
    /// Keep the most recent elements; the oldest are evicted. This is the default.
    #[default]
    Last,
}
35
/// `MergeOperator` that appends encoded list payloads, applying the retention
/// cap and optional per-element TTL described by its `ListConfig`.
#[derive(Clone)]
struct ListMergeOperator {
    /// Settings used on every merge and exported via `metadata()`.
    config: ListConfig,
}
40
41impl ListMergeOperator {
42 fn new(config: ListConfig) -> Self {
43 Self { config }
44 }
45}
46
47impl MergeOperator for ListMergeOperator {
48 fn id(&self) -> String {
49 LIST_OPERATOR_ID.to_string()
50 }
51
52 fn metadata(&self) -> Option<JsonValue> {
53 serde_json::to_value(&self.config).ok()
54 }
55
56 fn merge(
57 &self,
58 existing_value: Bytes,
59 value: Bytes,
60 time_provider: Option<&dyn TimeProvider>,
61 ) -> Result<(Bytes, Option<ValueType>)> {
62 self.merge_batch(existing_value, vec![value], time_provider)
63 }
64
65 fn merge_batch(
66 &self,
67 existing_value: Bytes,
68 operands: Vec<Bytes>,
69 time_provider: Option<&dyn TimeProvider>,
70 ) -> Result<(Bytes, Option<ValueType>)> {
71 if operands.is_empty() {
72 return Ok((existing_value, None));
73 }
74
75 if let Some(merged) = try_fast_append_batch(&existing_value, &operands, &self.config)? {
77 return Ok((merged, None));
78 }
79
80 let now_seconds = time_provider
81 .map(|provider| provider.now_seconds())
82 .unwrap_or(0);
83 if let (ListRetainMode::Last, Some(max_elements)) =
84 (self.config.retain_mode, self.config.max_elements)
85 {
86 let (elements, reached_last_cap) = collect_last_from_newest(
87 &existing_value,
88 &operands,
89 &self.config,
90 now_seconds,
91 max_elements,
92 )?;
93 let output = encode_list_payload(&elements, &self.config)?;
94 let value_type = if reached_last_cap && !self.config.preserve_element_ttl {
95 Some(ValueType::Put)
96 } else {
97 None
98 };
99 return Ok((output, value_type));
100 }
101 let mut accumulator = ListAccumulator::new(&self.config);
102 accumulator.ingest_payload(&existing_value, now_seconds)?;
103 if !accumulator.should_stop() {
104 for operand in &operands {
105 accumulator.ingest_payload(operand, now_seconds)?;
106 if accumulator.should_stop() {
107 break;
108 }
109 }
110 }
111 let (elements, reached_last_cap) = accumulator.into_parts();
112 let output = encode_list_payload(&elements, &self.config)?;
113 let value_type = if reached_last_cap
114 && self.config.retain_mode == ListRetainMode::Last
115 && !self.config.preserve_element_ttl
116 {
117 Some(ValueType::Put)
118 } else {
119 None
120 };
121 Ok((output, value_type))
122 }
123}
124
125pub(crate) fn list_operator(config: ListConfig) -> Arc<dyn MergeOperator> {
126 Arc::new(ListMergeOperator::new(config))
127}
128
129pub(crate) fn list_operator_from_metadata(
130 id: &str,
131 metadata: Option<&JsonValue>,
132) -> Option<Arc<dyn MergeOperator>> {
133 if id != LIST_OPERATOR_ID {
134 return None;
135 }
136 let config = serde_json::from_value::<ListConfig>(metadata?.clone()).ok()?;
137 Some(list_operator(config))
138}
139
140pub(crate) fn encode_list_for_write(
141 elements: Vec<Bytes>,
142 config: &ListConfig,
143 ttl_seconds: Option<u32>,
144 now_seconds: u32,
145) -> Result<Bytes> {
146 let expires_at_secs = if config.preserve_element_ttl {
147 ttl_seconds.map(|ttl| now_seconds.saturating_add(ttl))
148 } else {
149 None
150 };
151 let decoded = elements
152 .into_iter()
153 .map(|value| DecodedListElement {
154 value,
155 expires_at_secs,
156 })
157 .collect::<Vec<_>>();
158 encode_list_payload(&decoded, config)
159}
160
161pub(crate) fn decode_list_for_read(
162 raw: &Bytes,
163 config: &ListConfig,
164 now_seconds: u32,
165) -> Result<Vec<Bytes>> {
166 let mut accumulator = ListAccumulator::new(config);
167 accumulator.ingest_payload(raw, now_seconds)?;
168 let (elements, _) = accumulator.into_parts();
169 Ok(elements.into_iter().map(|element| element.value).collect())
170}
171
/// A single list element in decoded form.
#[derive(Clone)]
struct DecodedListElement {
    /// Raw element bytes.
    value: Bytes,
    /// Absolute expiry timestamp (compared against `now_seconds`), or `None`
    /// for no expiry. The element counts as expired once
    /// `expires_at_secs <= now_seconds`. On the wire, `None` is written as `0`.
    expires_at_secs: Option<u32>,
}
177
/// Forward-only decoder over one encoded list payload.
struct ListPayloadCursor {
    /// Undecoded bytes that follow the element-count header.
    remaining: Bytes,
    /// Whether each element is prefixed by a little-endian `u32` expiry timestamp.
    preserve_element_ttl: bool,
    /// Number of elements still to decode, initialised from the header count.
    remaining_elements: usize,
}
186
187impl ListPayloadCursor {
188 fn new(payload: &Bytes, preserve_element_ttl: bool) -> Result<Self> {
189 if payload.is_empty() {
190 return Ok(Self {
191 remaining: Bytes::new(),
192 preserve_element_ttl,
193 remaining_elements: 0,
194 });
195 }
196 let mut remaining = payload.clone();
197 if remaining.remaining() < size_of::<u32>() {
198 return Err(Error::FileFormatError(
199 "invalid list payload: missing element count".to_string(),
200 ));
201 }
202 let remaining_elements = remaining.get_u32_le() as usize;
203 Ok(Self {
204 remaining,
205 preserve_element_ttl,
206 remaining_elements,
207 })
208 }
209
210 fn next(&mut self) -> Result<Option<DecodedListElement>> {
211 if self.remaining_elements == 0 {
212 if self.remaining.has_remaining() {
213 return Err(Error::InvalidState(
214 "invalid list payload: trailing bytes found".to_string(),
215 ));
216 }
217 return Ok(None);
218 }
219 let expires_at_secs = if self.preserve_element_ttl {
220 if self.remaining.remaining() < size_of::<u32>() {
221 return Err(Error::InvalidState(
222 "invalid list payload: missing element ttl timestamp".to_string(),
223 ));
224 }
225 let expires_at = self.remaining.get_u32_le();
226 if expires_at == 0 {
227 None
228 } else {
229 Some(expires_at)
230 }
231 } else {
232 None
233 };
234 if self.remaining.remaining() < size_of::<u32>() {
235 return Err(Error::InvalidState(
236 "invalid list payload: missing element length".to_string(),
237 ));
238 }
239 let element_len = self.remaining.get_u32_le() as usize;
240 if self.remaining.remaining() < element_len {
241 return Err(Error::InvalidState(format!(
242 "invalid list payload: element length {} exceeds remaining {}",
243 element_len,
244 self.remaining.remaining()
245 )));
246 }
247 self.remaining_elements -= 1;
248 Ok(Some(DecodedListElement {
249 value: self.remaining.split_to(element_len),
250 expires_at_secs,
251 }))
252 }
253}
254
/// Streams decoded elements from one or more payloads into the retention
/// strategy selected by a `ListConfig`.
struct ListAccumulator {
    /// Copy of the configuration; `preserve_element_ttl` drives decoding and expiry filtering.
    config: ListConfig,
    /// Retention strategy together with the elements kept so far.
    mode: ListAccumulatorMode,
    /// Set once `Last` mode has held `max` elements; no other mode sets it.
    reached_last_cap: bool,
}
264
/// Retention strategy used by `ListAccumulator`.
enum ListAccumulatorMode {
    /// No cap: every element is kept in arrival order.
    All(Vec<DecodedListElement>),
    /// Keep the first `max` elements; later ones are ignored.
    First {
        max: usize,
        kept: Vec<DecodedListElement>,
    },
    /// Keep a sliding window of the most recent `max` elements.
    Last {
        max: usize,
        kept: VecDeque<DecodedListElement>,
    },
}
276
277impl ListAccumulator {
278 fn new(config: &ListConfig) -> Self {
279 let mode = match (config.max_elements, config.retain_mode) {
280 (Some(max), ListRetainMode::First) => ListAccumulatorMode::First {
281 max,
282 kept: Vec::with_capacity(max),
283 },
284 (Some(max), ListRetainMode::Last) => ListAccumulatorMode::Last {
285 max,
286 kept: VecDeque::with_capacity(max),
287 },
288 (None, _) => ListAccumulatorMode::All(Vec::new()),
289 };
290 Self {
291 config: config.clone(),
292 mode,
293 reached_last_cap: false,
294 }
295 }
296
297 fn ingest_payload(&mut self, payload: &Bytes, now_seconds: u32) -> Result<()> {
298 let mut cursor = ListPayloadCursor::new(payload, self.config.preserve_element_ttl)?;
299 while let Some(element) = cursor.next()? {
300 if self.config.preserve_element_ttl
301 && element
302 .expires_at_secs
303 .is_some_and(|expires_at| expires_at <= now_seconds)
304 {
305 continue;
306 }
307 match &mut self.mode {
308 ListAccumulatorMode::All(values) => {
309 values.push(element);
310 }
311 ListAccumulatorMode::First { max, kept } => {
312 if kept.len() < *max {
313 kept.push(element);
314 }
315 }
316 ListAccumulatorMode::Last { max, kept } => {
317 kept.push_back(element);
318 if kept.len() > *max {
319 let _ = kept.pop_front();
320 }
321 if kept.len() == *max {
322 self.reached_last_cap = true;
323 }
324 }
325 }
326 if self.should_stop() {
327 break;
328 }
329 }
330 Ok(())
331 }
332
333 fn should_stop(&self) -> bool {
334 match &self.mode {
335 ListAccumulatorMode::First { max, kept } => kept.len() >= *max,
336 ListAccumulatorMode::All(_) | ListAccumulatorMode::Last { .. } => false,
337 }
338 }
339
340 fn into_parts(self) -> (Vec<DecodedListElement>, bool) {
341 let elements = match self.mode {
342 ListAccumulatorMode::All(values) => values,
343 ListAccumulatorMode::First { kept, .. } => kept,
344 ListAccumulatorMode::Last { kept, .. } => kept.into_iter().collect(),
345 };
346 (elements, self.reached_last_cap)
347 }
348}
349
350fn encode_list_payload(elements: &[DecodedListElement], config: &ListConfig) -> Result<Bytes> {
351 if elements.len() > u32::MAX as usize {
352 return Err(Error::InputError(format!(
353 "too many list elements to encode: {}",
354 elements.len()
355 )));
356 }
357 let ttl_bytes = if config.preserve_element_ttl {
358 size_of::<u32>()
359 } else {
360 0
361 };
362 let total_size = size_of::<u32>()
363 + elements
364 .iter()
365 .map(|item| ttl_bytes + size_of::<u32>() + item.value.len())
366 .sum::<usize>();
367 let mut out = BytesMut::with_capacity(total_size);
368 out.put_u32_le(elements.len() as u32);
369 for element in elements {
370 if config.preserve_element_ttl {
371 out.put_u32_le(element.expires_at_secs.unwrap_or(0));
372 }
373 out.put_u32_le(element.value.len() as u32);
374 out.extend_from_slice(element.value.as_ref());
375 }
376 Ok(out.freeze())
377}
378
379fn try_fast_append_batch(
380 existing: &Bytes,
381 operands: &[Bytes],
382 config: &ListConfig,
383) -> Result<Option<Bytes>> {
384 if config.preserve_element_ttl {
385 return Ok(None);
386 }
387 let (mut total_count, existing_body) = parse_payload_body(existing)?;
388 let mut operand_bodies = Vec::with_capacity(operands.len());
389 for operand in operands {
390 let (count, body) = parse_payload_body(operand)?;
391 total_count = total_count.checked_add(count).ok_or_else(|| {
392 Error::InputError(format!(
393 "list element count overflow during merge: {} + {}",
394 total_count, count
395 ))
396 })?;
397 operand_bodies.push(body);
398 }
399 if let Some(max_elements) = config.max_elements
400 && total_count > max_elements
401 {
402 return Ok(None);
403 }
404 if total_count > u32::MAX as usize {
405 return Err(Error::InputError(format!(
406 "too many list elements to encode: {}",
407 total_count
408 )));
409 }
410 let total_body_size =
411 existing_body.len() + operand_bodies.iter().map(Bytes::len).sum::<usize>();
412 let mut out = BytesMut::with_capacity(size_of::<u32>() + total_body_size);
413 out.put_u32_le(total_count as u32);
414 out.extend_from_slice(existing_body.as_ref());
415 for body in operand_bodies {
416 out.extend_from_slice(body.as_ref());
417 }
418 Ok(Some(out.freeze()))
419}
420
421fn parse_payload_body(payload: &Bytes) -> Result<(usize, Bytes)> {
424 if payload.is_empty() {
425 return Ok((0, Bytes::new()));
426 }
427 if payload.len() < size_of::<u32>() {
428 return Err(Error::FileFormatError(
429 "invalid list payload: missing element count".to_string(),
430 ));
431 }
432 let mut header = payload.slice(..size_of::<u32>());
433 let element_count = header.get_u32_le() as usize;
434 Ok((element_count, payload.slice(size_of::<u32>()..)))
435}
436
437fn collect_last_from_newest(
441 existing_value: &Bytes,
442 operands: &[Bytes],
443 config: &ListConfig,
444 now_seconds: u32,
445 max_elements: usize,
446) -> Result<(Vec<DecodedListElement>, bool)> {
447 if max_elements == 0 {
448 return Ok((Vec::new(), true));
449 }
450 let mut newest_to_oldest = Vec::with_capacity(max_elements);
451 for payload in operands.iter().rev().chain(std::iter::once(existing_value)) {
452 if newest_to_oldest.len() >= max_elements {
453 break;
454 }
455 let needed = max_elements - newest_to_oldest.len();
456 collect_last_from_single_payload(
457 payload,
458 config,
459 now_seconds,
460 needed,
461 &mut newest_to_oldest,
462 )?;
463 }
464 let reached_last_cap = newest_to_oldest.len() >= max_elements;
465 newest_to_oldest.reverse();
466 Ok((newest_to_oldest, reached_last_cap))
467}
468
469fn collect_last_from_single_payload(
470 payload: &Bytes,
471 config: &ListConfig,
472 now_seconds: u32,
473 needed: usize,
474 out_newest_to_oldest: &mut Vec<DecodedListElement>,
475) -> Result<()> {
476 if needed == 0 {
477 return Ok(());
478 }
479 let mut cursor = ListPayloadCursor::new(payload, config.preserve_element_ttl)?;
480 let mut tail = VecDeque::with_capacity(needed);
481 while let Some(element) = cursor.next()? {
482 if config.preserve_element_ttl
483 && element
484 .expires_at_secs
485 .is_some_and(|expires_at| expires_at <= now_seconds)
486 {
487 continue;
488 }
489 tail.push_back(element);
490 if tail.len() > needed {
491 let _ = tail.pop_front();
492 }
493 }
494 while let Some(element) = tail.pop_back() {
495 out_newest_to_oldest.push(element);
496 }
497 Ok(())
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    // Encoding then decoding a list within its cap returns the same elements in order.
    #[test]
    fn test_list_round_trip() {
        let config = ListConfig {
            max_elements: Some(2),
            retain_mode: ListRetainMode::Last,
            preserve_element_ttl: false,
        };
        let encoded = encode_list_for_write(
            vec![Bytes::from_static(b"a"), Bytes::from_static(b"b")],
            &config,
            None,
            10,
        )
        .unwrap();
        let decoded = decode_list_for_read(&encoded, &config, 10).unwrap();
        assert_eq!(
            decoded,
            vec![Bytes::from_static(b"a"), Bytes::from_static(b"b")]
        );
    }

    // A 5s TTL written at t=100 expires at 105. The element is visible at 104 and
    // gone at 105, because expiry is inclusive (`expires_at <= now` drops it).
    #[test]
    fn test_list_ttl_uses_supplied_time() {
        let config = ListConfig {
            max_elements: None,
            retain_mode: ListRetainMode::Last,
            preserve_element_ttl: true,
        };
        let encoded =
            encode_list_for_write(vec![Bytes::from_static(b"a")], &config, Some(5), 100).unwrap();
        assert_eq!(
            decode_list_for_read(&encoded, &config, 104).unwrap(),
            vec![Bytes::from_static(b"a")]
        );
        assert!(
            decode_list_for_read(&encoded, &config, 105)
                .unwrap()
                .is_empty()
        );
    }

    // Combined count (3) is within the cap (4), so the byte-concatenation fast
    // path runs and must still produce a decodable payload.
    #[test]
    fn test_merge_batch_fast_append_keeps_valid_payload() {
        let config = ListConfig {
            max_elements: Some(4),
            retain_mode: ListRetainMode::Last,
            preserve_element_ttl: false,
        };
        let operator = ListMergeOperator::new(config.clone());
        let left = encode_list_for_write(
            vec![Bytes::from_static(b"a"), Bytes::from_static(b"b")],
            &config,
            None,
            0,
        )
        .unwrap();
        let right =
            encode_list_for_write(vec![Bytes::from_static(b"c")], &config, None, 0).unwrap();
        let merged = operator.merge_batch(left, vec![right], None).unwrap();
        assert_eq!(
            decode_list_for_read(&merged.0, &config, 0).unwrap(),
            vec![
                Bytes::from_static(b"a"),
                Bytes::from_static(b"b"),
                Bytes::from_static(b"c"),
            ]
        );
    }

    // Combined count (3) exceeds the cap (2); `Last` retention keeps the newest two.
    #[test]
    fn test_merge_over_cap_falls_back_to_retain_policy() {
        let config = ListConfig {
            max_elements: Some(2),
            retain_mode: ListRetainMode::Last,
            preserve_element_ttl: false,
        };
        let operator = ListMergeOperator::new(config.clone());
        let left = encode_list_for_write(vec![Bytes::from_static(b"a")], &config, None, 0).unwrap();
        let right = encode_list_for_write(
            vec![Bytes::from_static(b"b"), Bytes::from_static(b"c")],
            &config,
            None,
            0,
        )
        .unwrap();
        let merged = operator.merge_batch(left, vec![right], None).unwrap();
        assert_eq!(
            decode_list_for_read(&merged.0, &config, 0).unwrap(),
            vec![Bytes::from_static(b"b"), Bytes::from_static(b"c")]
        );
    }

    // Filling the `Last` cap without per-element TTLs upgrades the result to a Put.
    #[test]
    fn test_merge_last_cap_returns_put_value_type() {
        let config = ListConfig {
            max_elements: Some(2),
            retain_mode: ListRetainMode::Last,
            preserve_element_ttl: false,
        };
        let operator = ListMergeOperator::new(config.clone());
        let left = encode_list_for_write(vec![Bytes::from_static(b"a")], &config, None, 0).unwrap();
        let right = encode_list_for_write(
            vec![Bytes::from_static(b"b"), Bytes::from_static(b"c")],
            &config,
            None,
            0,
        )
        .unwrap();
        let merged = operator.merge_batch(left, vec![right], None).unwrap();
        assert_eq!(merged.1, Some(ValueType::Put));
    }

    // The two newest operands already fill the cap, so the malformed existing
    // value (a single byte, shorter than the u32 count header) is never decoded.
    // TTLs are enabled so the fast path is skipped.
    #[test]
    fn test_merge_last_stops_before_older_payloads() {
        let config = ListConfig {
            max_elements: Some(2),
            retain_mode: ListRetainMode::Last,
            preserve_element_ttl: true,
        };
        let operator = ListMergeOperator::new(config.clone());
        let malformed_existing = Bytes::from_static(b"\x01");
        let op1 = encode_list_for_write(vec![Bytes::from_static(b"a")], &config, None, 0).unwrap();
        let op2 = encode_list_for_write(vec![Bytes::from_static(b"b")], &config, None, 0).unwrap();
        let merged = operator
            .merge_batch(malformed_existing, vec![op1, op2], None)
            .unwrap();
        assert_eq!(
            decode_list_for_read(&merged.0, &config, 0).unwrap(),
            vec![Bytes::from_static(b"a"), Bytes::from_static(b"b")]
        );
    }

    // `First` retention keeps the oldest two elements and never reports a Put.
    #[test]
    fn test_merge_first_cap_stops_and_keeps_merge_type() {
        let config = ListConfig {
            max_elements: Some(2),
            retain_mode: ListRetainMode::First,
            preserve_element_ttl: false,
        };
        let operator = ListMergeOperator::new(config.clone());
        let left = encode_list_for_write(vec![Bytes::from_static(b"a")], &config, None, 0).unwrap();
        let right = encode_list_for_write(
            vec![
                Bytes::from_static(b"b"),
                Bytes::from_static(b"c"),
                Bytes::from_static(b"d"),
            ],
            &config,
            None,
            0,
        )
        .unwrap();
        let merged = operator.merge_batch(left, vec![right], None).unwrap();
        assert_eq!(merged.1, None);
        assert_eq!(
            decode_list_for_read(&merged.0, &config, 0).unwrap(),
            vec![Bytes::from_static(b"a"), Bytes::from_static(b"b")]
        );
    }

    // Without metadata there is no config to rebuild from, so no operator is returned.
    #[test]
    fn test_list_operator_from_metadata_requires_metadata() {
        assert!(list_operator_from_metadata(LIST_OPERATOR_ID, None).is_none());
    }
}