use bytes::{Bytes, BytesMut};
use crate::kraft::types::NodeId;
/// Identifier of a snapshot, expressed as a pair of integers.
///
/// NOTE(review): presumably `(end_offset, epoch)` as in KRaft snapshot ids —
/// confirm against the code that produces these values.
pub type SnapshotId = (i64, i32);
/// Largest declared snapshot size, in bytes, that a chunk may announce (1 GiB).
const MAX_SNAPSHOT_BYTES: usize = 1 << 30;
/// In-progress reassembly of a single snapshot from sequentially fetched chunks.
///
/// Chunks are fed through `on_chunk`; accepted payload bytes accumulate in a
/// private buffer until the declared snapshot size has been reached.
#[derive(Debug)]
pub struct SnapshotFetchState {
/// Snapshot this state accepts chunks for; chunks carrying another id are rejected.
pub snapshot_id: SnapshotId,
/// Node id supplied at construction time.
/// NOTE(review): presumably the leader the snapshot is fetched from; this type
/// only stores the value and never reads it.
pub leader_id: NodeId,
/// Payload bytes accepted so far, in arrival order.
buf: BytesMut,
/// Declared total snapshot size recorded from an earlier chunk, if any; later
/// chunks must declare the same size.
size: Option<i64>,
}
/// Outcome of handing one chunk to `SnapshotFetchState::on_chunk`.
#[derive(Debug, PartialEq, Eq)]
pub enum SnapshotFetchStep {
/// The chunk was appended but the declared size has not been reached yet;
/// `next_position` is the number of bytes buffered so far.
Continue { next_position: i64 },
/// The buffered bytes reached the declared size; carries the assembled bytes.
Complete(Bytes),
/// The chunk was rejected (wrong snapshot id, unexpected position, invalid or
/// inconsistent declared size, or payload running past the declared size).
/// NOTE(review): the caller presumably restarts the fetch — confirm.
Restart,
}
impl SnapshotFetchState {
    /// Creates an empty fetch state for `snapshot_id`.
    ///
    /// `leader_id` is stored as given. No bytes are buffered and no declared
    /// size is known yet, so [`Self::next_position`] starts at `0`.
    #[must_use]
    pub fn new(snapshot_id: SnapshotId, leader_id: NodeId) -> Self {
        Self {
            snapshot_id,
            leader_id,
            buf: BytesMut::new(),
            size: None,
        }
    }

    /// Returns the position the next chunk must carry: the number of bytes
    /// buffered so far.
    #[must_use]
    pub fn next_position(&self) -> i64 {
        // Saturate rather than wrap if the length ever exceeded `i64::MAX`.
        i64::try_from(self.buf.len()).unwrap_or(i64::MAX)
    }

    /// Validates one fetched chunk and, if it is acceptable, appends it.
    ///
    /// * `id` – snapshot id of the chunk; must equal `self.snapshot_id`.
    /// * `size` – total snapshot size declared with the chunk; must lie in
    ///   `0..=MAX_SNAPSHOT_BYTES` and agree with any size recorded earlier.
    /// * `position` – byte offset of the chunk; must equal
    ///   [`Self::next_position`].
    /// * `chunk` – payload bytes; must not run past `size`.
    ///
    /// Returns [`SnapshotFetchStep::Restart`] when any of these checks fails.
    /// A rejected chunk leaves the state untouched: neither the buffer nor the
    /// recorded size changes. Otherwise the payload is appended and the result
    /// is [`SnapshotFetchStep::Complete`] with all buffered bytes once the
    /// declared size is reached (the internal buffer is left empty), or
    /// [`SnapshotFetchStep::Continue`] with the position expected next.
    pub fn on_chunk(
        &mut self,
        id: SnapshotId,
        size: i64,
        position: i64,
        chunk: &[u8],
    ) -> SnapshotFetchStep {
        if id != self.snapshot_id
            || position != self.next_position()
            || size < 0
            || size.cast_unsigned() > MAX_SNAPSHOT_BYTES as u64
        {
            return SnapshotFetchStep::Restart;
        }

        // The declared size must stay constant across chunks of one snapshot.
        if self.size.is_some_and(|known| known != size) {
            return SnapshotFetchStep::Restart;
        }

        // `size` is within `0..=MAX_SNAPSHOT_BYTES` and `position` is the
        // non-negative buffer length, so this subtraction cannot overflow.
        let remaining = size - position;
        if i64::try_from(chunk.len()).unwrap_or(i64::MAX) > remaining {
            return SnapshotFetchStep::Restart;
        }

        // Every check passed: only now commit the declared size together with
        // the payload, so a rejected chunk can never pin a size on the state.
        self.size = Some(size);
        self.buf.extend_from_slice(chunk);

        if self.next_position() >= size {
            SnapshotFetchStep::Complete(self.buf.split().freeze())
        } else {
            SnapshotFetchStep::Continue {
                next_position: self.next_position(),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;

    /// Two in-order chunks are concatenated and reported as complete.
    #[test]
    fn assembles_in_order_chunks_to_complete() {
        let id = (10, 1);
        let mut state = SnapshotFetchState::new(id, 2);
        assert!(state.next_position() == 0);

        let first = state.on_chunk(id, 6, 0, b"abc");
        assert!(first == SnapshotFetchStep::Continue { next_position: 3 });

        let assembled = match state.on_chunk(id, 6, 3, b"def") {
            SnapshotFetchStep::Complete(bytes) => bytes,
            other => panic!("expected Complete, got {other:?}"),
        };
        assert!(assembled.as_ref() == b"abcdef");
    }

    /// A chunk whose position is not the next expected one is rejected.
    #[test]
    fn out_of_order_position_restarts() {
        let id = (10, 1);
        let mut state = SnapshotFetchState::new(id, 2);
        let _ = state.on_chunk(id, 6, 0, b"abc");

        let outcome = state.on_chunk(id, 6, 99, b"def");
        assert!(outcome == SnapshotFetchStep::Restart);
    }

    /// A chunk for a different snapshot id is rejected.
    #[test]
    fn mismatched_id_restarts() {
        let mut state = SnapshotFetchState::new((10, 1), 2);

        let outcome = state.on_chunk((11, 1), 6, 0, b"abc");
        assert!(outcome == SnapshotFetchStep::Restart);
    }

    /// A payload longer than the declared size is rejected without buffering.
    #[test]
    fn chunk_overshooting_declared_size_restarts() {
        let id = (10, 1);
        let mut state = SnapshotFetchState::new(id, 2);

        let outcome = state.on_chunk(id, 3, 0, b"abcde");
        assert!(outcome == SnapshotFetchStep::Restart);
        assert!(state.next_position() == 0);
    }

    /// A declared size one byte above the cap is rejected without buffering.
    #[test]
    fn declared_size_over_cap_restarts() {
        let id = (10, 1);
        let mut state = SnapshotFetchState::new(id, 2);
        let over_cap = i64::try_from(MAX_SNAPSHOT_BYTES).unwrap() + 1;

        let outcome = state.on_chunk(id, over_cap, 0, b"abc");
        assert!(outcome == SnapshotFetchStep::Restart);
        assert!(state.next_position() == 0);
    }

    /// A single chunk covering the whole declared size completes immediately.
    #[test]
    fn single_chunk_completes() {
        let id = (5, 0);
        let mut state = SnapshotFetchState::new(id, 1);

        let assembled = match state.on_chunk(id, 3, 0, b"xyz") {
            SnapshotFetchStep::Complete(bytes) => bytes,
            other => panic!("expected Complete, got {other:?}"),
        };
        assert!(assembled.as_ref() == b"xyz");
    }
}