Skip to main content

reddb_file/
bloom_segment.rs

1//! Persisted bloom-segment frame shared by storage/index owners.
2
3use std::fmt;
4
5use crate::BLOOM_SEGMENT_MAGIC;
6
/// Size in bytes of the fixed frame header that precedes the bit payload:
/// magic (`u8`) + `num_hashes` (`u8`) + `bit_size` (`u32`) + `inserted` (`u32`).
pub const BLOOM_SEGMENT_HEADER_LEN: usize = std::mem::size_of::<u8>()
    + std::mem::size_of::<u8>()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u32>();
8
/// In-memory form of one persisted bloom-segment frame.
///
/// The decoder reads back exactly `bit_size.div_ceil(8)` payload bytes, so
/// `bits` should have that length for an encode/decode round trip to be
/// lossless.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BloomSegmentFrame {
    /// Stored as the header's second byte; presumably the number of hash
    /// functions the filter uses (TODO confirm with the filter owner).
    pub num_hashes: u8,
    /// Filter size in bits; determines the payload length
    /// (`bit_size.div_ceil(8)` bytes). Persisted big-endian.
    pub bit_size: u32,
    /// Persisted big-endian after `bit_size`; presumably a count of inserted
    /// keys (TODO confirm).
    pub inserted: u32,
    /// Raw filter bit array, written verbatim after the header.
    pub bits: Vec<u8>,
}
16
/// Reasons a byte buffer is rejected as a bloom-segment frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BloomSegmentFrameError {
    /// Fewer bytes than the fixed-size header were supplied.
    TooShort,
    /// The first byte is not the bloom-segment magic value.
    BadMagic,
    /// The buffer ends before the bit payload the header declares.
    LengthMismatch,
}

impl fmt::Display for BloomSegmentFrameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::TooShort => "bloom header too short",
            Self::BadMagic => "bloom header magic mismatch",
            Self::LengthMismatch => "bloom header length mismatch",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BloomSegmentFrameError {}
35
36pub fn encode_bloom_segment_frame(frame: &BloomSegmentFrame) -> Vec<u8> {
37    let mut out = Vec::with_capacity(BLOOM_SEGMENT_HEADER_LEN + frame.bits.len());
38    out.push(BLOOM_SEGMENT_MAGIC);
39    out.push(frame.num_hashes);
40    out.extend_from_slice(&frame.bit_size.to_be_bytes());
41    out.extend_from_slice(&frame.inserted.to_be_bytes());
42    out.extend_from_slice(&frame.bits);
43    out
44}
45
46pub fn decode_bloom_segment_frame(
47    bytes: &[u8],
48) -> Result<(BloomSegmentFrame, usize), BloomSegmentFrameError> {
49    if bytes.len() < BLOOM_SEGMENT_HEADER_LEN {
50        return Err(BloomSegmentFrameError::TooShort);
51    }
52    if bytes[0] != BLOOM_SEGMENT_MAGIC {
53        return Err(BloomSegmentFrameError::BadMagic);
54    }
55
56    let num_hashes = bytes[1];
57    let bit_size = u32::from_be_bytes(bytes[2..6].try_into().expect("len checked"));
58    let inserted = u32::from_be_bytes(bytes[6..10].try_into().expect("len checked"));
59    let byte_len = (bit_size as usize).div_ceil(8);
60    let total = BLOOM_SEGMENT_HEADER_LEN + byte_len;
61    if bytes.len() < total {
62        return Err(BloomSegmentFrameError::LengthMismatch);
63    }
64
65    Ok((
66        BloomSegmentFrame {
67            num_hashes,
68            bit_size,
69            inserted,
70            bits: bytes[BLOOM_SEGMENT_HEADER_LEN..total].to_vec(),
71        },
72        total,
73    ))
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bloom_segment_frame_round_trips_big_endian_header() {
        let frame = BloomSegmentFrame {
            num_hashes: 7,
            bit_size: 17,
            inserted: 42,
            bits: vec![0b1010_1010, 0b0101_0101, 0xFF],
        };

        let encoded = encode_bloom_segment_frame(&frame);

        assert_eq!(encoded[0], BLOOM_SEGMENT_MAGIC);
        assert_eq!(encoded[1], 7);
        assert_eq!(&encoded[2..6], &17u32.to_be_bytes());
        assert_eq!(&encoded[6..10], &42u32.to_be_bytes());

        let (decoded, consumed) = decode_bloom_segment_frame(&encoded).expect("decode frame");
        assert_eq!(decoded, frame);
        assert_eq!(consumed, encoded.len());
    }

    /// The returned `consumed` count lets a caller walk frames stored back to
    /// back in one buffer, including a zero-bit frame with an empty payload.
    #[test]
    fn bloom_segment_frame_decodes_back_to_back_frames() {
        let first = BloomSegmentFrame {
            num_hashes: 2,
            bit_size: 9,
            inserted: 5,
            bits: vec![0xAB, 0x01],
        };
        let second = BloomSegmentFrame {
            num_hashes: 4,
            bit_size: 0,
            inserted: 0,
            bits: Vec::new(),
        };
        let mut stream = encode_bloom_segment_frame(&first);
        stream.extend(encode_bloom_segment_frame(&second));

        let (decoded_first, used_first) =
            decode_bloom_segment_frame(&stream).expect("decode first frame");
        assert_eq!(decoded_first, first);
        assert_eq!(used_first, BLOOM_SEGMENT_HEADER_LEN + 2);

        let (decoded_second, used_second) =
            decode_bloom_segment_frame(&stream[used_first..]).expect("decode second frame");
        assert_eq!(decoded_second, second);
        assert_eq!(used_second, BLOOM_SEGMENT_HEADER_LEN);
        assert_eq!(used_first + used_second, stream.len());
    }

    #[test]
    fn bloom_segment_frame_rejects_bad_inputs() {
        assert_eq!(
            decode_bloom_segment_frame(&[]).unwrap_err(),
            BloomSegmentFrameError::TooShort
        );

        // The length check runs before the magic check, so a header that is
        // one byte short is reported as too short regardless of its contents.
        assert_eq!(
            decode_bloom_segment_frame(&[0u8; BLOOM_SEGMENT_HEADER_LEN - 1]).unwrap_err(),
            BloomSegmentFrameError::TooShort
        );

        let mut bad_magic = encode_bloom_segment_frame(&BloomSegmentFrame {
            num_hashes: 3,
            bit_size: 8,
            inserted: 1,
            bits: vec![1],
        });
        bad_magic[0] = 0;
        assert_eq!(
            decode_bloom_segment_frame(&bad_magic).unwrap_err(),
            BloomSegmentFrameError::BadMagic
        );

        let mut truncated = encode_bloom_segment_frame(&BloomSegmentFrame {
            num_hashes: 3,
            bit_size: 16,
            inserted: 1,
            bits: vec![1, 2],
        });
        truncated.pop();
        assert_eq!(
            decode_bloom_segment_frame(&truncated).unwrap_err(),
            BloomSegmentFrameError::LengthMismatch
        );
    }

    #[test]
    fn bloom_segment_frame_error_display_messages() {
        assert_eq!(
            BloomSegmentFrameError::TooShort.to_string(),
            "bloom header too short"
        );
        assert_eq!(
            BloomSegmentFrameError::BadMagic.to_string(),
            "bloom header magic mismatch"
        );
        assert_eq!(
            BloomSegmentFrameError::LengthMismatch.to_string(),
            "bloom header length mismatch"
        );
    }
}