Skip to main content

common/
bytes.rs

1//! Byte utilities for key encoding and range queries.
2
3use bytes::{Bytes, BytesMut};
4use std::ops::Bound::{Excluded, Included, Unbounded};
5use std::ops::{Bound, RangeBounds};
6
7/// Computes the lexicographic successor of a byte sequence.
8///
9/// Returns the smallest byte sequence that is strictly greater than the input.
10/// Returns `None` if no such sequence exists (i.e., input is empty or all `0xFF` bytes).
11///
12/// This is useful for computing exclusive upper bounds in range queries.
13/// For example, to query all keys with prefix "foo", use the range `["foo", lex_lex_increment("foo"))`.
14///
15/// # Algorithm
16///
17/// Starting from the rightmost byte:
18/// - If it's less than `0xFF`, increment it and return
19/// - If it's `0xFF`, remove it and try to increment the previous byte
20/// - If all bytes are `0xFF` (or input is empty), return `None`
21///
22/// # Examples
23///
24/// - `[0x61]` ("a") → `Some([0x62])` ("b")
25/// - `[0x61, 0xFF]` → `Some([0x62])`
26/// - `[0xFF]` → `None`
27/// - `[]` → `None`
28pub(crate) fn lex_increment(data: &[u8]) -> Option<Bytes> {
29    if data.is_empty() {
30        return None;
31    }
32
33    let mut result = BytesMut::from(data);
34
35    // Work backwards, looking for a byte we can increment
36    while let Some(last) = result.last_mut() {
37        if *last < 0xFF {
38            *last += 1;
39            return Some(result.freeze());
40        }
41        // Last byte is 0xFF, truncate it and try the previous byte
42        result.truncate(result.len() - 1);
43    }
44
45    // All bytes were 0xFF, no valid increment exists
46    None
47}
48
49/// A range over byte sequences, used for key range queries.
50#[derive(Clone, Debug)]
51pub struct BytesRange {
52    pub start: Bound<Bytes>,
53    pub end: Bound<Bytes>,
54}
55
56impl BytesRange {
57    pub fn new(start: Bound<Bytes>, end: Bound<Bytes>) -> Self {
58        Self { start, end }
59    }
60
61    /// Creates a range that includes all keys with the given prefix.
62    pub fn prefix(prefix: Bytes) -> Self {
63        if prefix.is_empty() {
64            Self::unbounded()
65        } else {
66            match lex_increment(&prefix) {
67                Some(end) => Self {
68                    start: Included(prefix),
69                    end: Excluded(end),
70                },
71                None => Self {
72                    start: Included(prefix),
73                    end: Unbounded,
74                },
75            }
76        }
77    }
78
79    pub fn contains(&self, k: &[u8]) -> bool {
80        (match &self.start {
81            Included(s) => k >= s,
82            Excluded(s) => k > s,
83            Unbounded => true,
84        }) && (match &self.end {
85            Included(e) => k <= e,
86            Excluded(e) => k < e,
87            Unbounded => true,
88        })
89    }
90
91    /// Creates a range that scans everything.
92    pub fn unbounded() -> Self {
93        Self {
94            start: Unbounded,
95            end: Unbounded,
96        }
97    }
98}
99
100impl RangeBounds<Bytes> for BytesRange {
101    fn start_bound(&self) -> Bound<&Bytes> {
102        self.start.as_ref()
103    }
104    fn end_bound(&self) -> Bound<&Bytes> {
105        self.end.as_ref()
106    }
107}
108
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    // Property tests for increment and prefix ranges

    proptest! {
        #[test]
        fn should_increment_produce_strictly_greater_result(data: Vec<u8>) {
            // An increment exists only if some byte is below 0xFF (this also rules out
            // the empty input).
            prop_assume!(data.iter().any(|&b| b != 0xFF));

            let incremented = lex_increment(&data).unwrap();
            prop_assert!(
                incremented.as_ref() > data.as_slice(),
                "lex_increment({:?}) = {:?} should be > input",
                data,
                incremented
            );
        }

        #[test]
        fn should_increment_produce_immediate_successor(data: Vec<u8>) {
            prop_assume!(data.iter().any(|&b| b != 0xFF));

            let incremented = lex_increment(&data).unwrap();

            // Expected value: drop the trailing run of 0xFF bytes, then bump the byte
            // that precedes it. Asserting the exact bytes (rather than only the length
            // and a prefix relation) also pins the value of the incremented byte.
            let trailing_ff = data.iter().rev().take_while(|&&b| b == 0xFF).count();
            let mut expected = data[..data.len() - trailing_ff].to_vec();
            *expected.last_mut().unwrap() += 1;

            prop_assert_eq!(incremented.as_ref(), expected.as_slice());
        }

        #[test]
        fn should_prefix_range_contain_all_prefixed_keys(prefix: Vec<u8>, suffix: Vec<u8>) {
            prop_assume!(!prefix.is_empty());

            let range = BytesRange::prefix(Bytes::from(prefix.clone()));

            // The prefix itself should be included
            prop_assert!(range.contains(&prefix));

            // Any key with this prefix should be included
            let mut extended = prefix.clone();
            extended.extend(&suffix);
            prop_assert!(range.contains(&extended));
        }
    }

    // Concrete increment tests

    #[test]
    fn should_increment_simple_byte() {
        assert_eq!(lex_increment(b"a").unwrap().as_ref(), b"b");
        assert_eq!(lex_increment(&[0x00]).unwrap().as_ref(), &[0x01]);
        assert_eq!(lex_increment(&[0xFE]).unwrap().as_ref(), &[0xFF]);
    }

    #[test]
    fn should_increment_with_trailing_ff() {
        assert_eq!(lex_increment(&[0x61, 0xFF]).unwrap().as_ref(), &[0x62]);
        assert_eq!(
            lex_increment(&[0x61, 0xFF, 0xFF]).unwrap().as_ref(),
            &[0x62]
        );
        assert_eq!(
            lex_increment(&[0x00, 0xFF, 0xFF]).unwrap().as_ref(),
            &[0x01]
        );
    }

    #[test]
    fn should_return_none_for_non_incrementable() {
        assert!(lex_increment(&[]).is_none());
        assert!(lex_increment(&[0xFF]).is_none());
        assert!(lex_increment(&[0xFF, 0xFF]).is_none());
    }

    // BytesRange tests

    #[test]
    fn should_create_prefix_range() {
        let range = BytesRange::prefix(Bytes::from("foo"));

        assert!(range.contains(b"foo"));
        assert!(range.contains(b"foobar"));
        assert!(range.contains(b"foo\x00"));
        assert!(range.contains(b"foo\xFF"));

        assert!(!range.contains(b"fo"));
        assert!(!range.contains(b"fop"));
        assert!(!range.contains(b"fop\x00"));
    }

    #[test]
    fn should_handle_prefix_with_trailing_ff() {
        let range = BytesRange::prefix(Bytes::from_static(&[0x61, 0xFF]));

        assert!(range.contains(&[0x61, 0xFF]));
        assert!(range.contains(&[0x61, 0xFF, 0x00]));
        assert!(range.contains(&[0x61, 0xFF, 0xFF]));

        assert!(!range.contains(&[0x61]));
        assert!(!range.contains(&[0x62]));
    }

    #[test]
    fn should_handle_all_ff_prefix() {
        let range = BytesRange::prefix(Bytes::from_static(&[0xFF, 0xFF]));

        // Should be unbounded on the end
        assert!(range.contains(&[0xFF, 0xFF]));
        assert!(range.contains(&[0xFF, 0xFF, 0x00]));
        assert!(range.contains(&[0xFF, 0xFF, 0xFF, 0xFF]));

        assert!(!range.contains(&[0xFF]));
        assert!(!range.contains(&[0xFE, 0xFF]));
    }

    #[test]
    fn should_handle_empty_prefix() {
        let range = BytesRange::prefix(Bytes::new());

        // Empty prefix = unbounded = matches everything
        assert!(range.contains(b""));
        assert!(range.contains(b"anything"));
        assert!(range.contains(&[0xFF, 0xFF, 0xFF]));
    }
}