use bytes::{Bytes, BytesMut};
use std::ops::Bound::{Excluded, Included, Unbounded};
use std::ops::{Bound, RangeBounds};
/// Computes the exclusive upper bound for keys that start with `data`.
///
/// The result is `data` truncated right after its last byte that is not `0xFF`,
/// with that byte incremented by one (e.g. `[0x61, 0xFF]` becomes `[0x62]`).
/// Every key that has `data` as a prefix therefore sorts strictly before it.
///
/// Returns `None` when `data` is empty or consists only of `0xFF` bytes, because no
/// finite byte string bounds those prefixes from above.
pub(crate) fn lex_increment(data: &[u8]) -> Option<Bytes> {
    // Position of the rightmost byte that can be bumped without overflowing;
    // `None` here covers both the empty and the all-0xFF input.
    let pivot = data.iter().rposition(|&byte| byte != 0xFF)?;
    // Only copy the bytes that survive; the trailing 0xFF run is dropped.
    let mut successor = BytesMut::from(&data[..=pivot]);
    successor[pivot] += 1;
    Some(successor.freeze())
}
/// A range of byte-string keys described by a start bound and an end bound.
///
/// Keys are compared lexicographically byte by byte (see `BytesRange::contains`).
/// Equality compares both bounds, including whether each is inclusive, exclusive,
/// or unbounded.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BytesRange {
    /// Lower bound of the range: inclusive, exclusive, or unbounded.
    pub start: Bound<Bytes>,
    /// Upper bound of the range: inclusive, exclusive, or unbounded.
    pub end: Bound<Bytes>,
}
impl BytesRange {
pub fn new(start: Bound<Bytes>, end: Bound<Bytes>) -> Self {
Self { start, end }
}
pub fn prefix(prefix: Bytes) -> Self {
if prefix.is_empty() {
Self::unbounded()
} else {
match lex_increment(&prefix) {
Some(end) => Self {
start: Included(prefix),
end: Excluded(end),
},
None => Self {
start: Included(prefix),
end: Unbounded,
},
}
}
}
pub fn contains(&self, k: &[u8]) -> bool {
(match &self.start {
Included(s) => k >= s,
Excluded(s) => k > s,
Unbounded => true,
}) && (match &self.end {
Included(e) => k <= e,
Excluded(e) => k < e,
Unbounded => true,
})
}
pub fn unbounded() -> Self {
Self {
start: Unbounded,
end: Unbounded,
}
}
}
impl RangeBounds<Bytes> for BytesRange {
    /// Borrows the lower bound without cloning the underlying bytes.
    fn start_bound(&self) -> Bound<&Bytes> {
        Bound::as_ref(&self.start)
    }

    /// Borrows the upper bound without cloning the underlying bytes.
    fn end_bound(&self) -> Bound<&Bytes> {
        Bound::as_ref(&self.end)
    }
}
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    proptest! {
        // Any input containing at least one byte other than 0xFF can be incremented,
        // and the result must sort strictly after the input.
        #[test]
        fn should_increment_produce_strictly_greater_result(data: Vec<u8>) {
            prop_assume!(data.iter().any(|&b| b != 0xFF));
            let incremented = lex_increment(&data).unwrap();
            prop_assert!(
                incremented.as_ref() > data.as_slice(),
                "lex_increment({:?}) = {:?} should be > input",
                data,
                incremented
            );
        }

        // Pins the exact shape of the result: the input truncated just after its last
        // non-0xFF byte, with that byte incremented by one.
        #[test]
        fn should_increment_produce_immediate_successor(data: Vec<u8>) {
            prop_assume!(data.iter().any(|&b| b != 0xFF));
            let incremented = lex_increment(&data).unwrap();
            prop_assert!(!incremented.is_empty());
            let pivot = incremented.len() - 1;
            prop_assert!(pivot < data.len());
            // Everything before the bumped byte is copied unchanged.
            prop_assert_eq!(&incremented[..pivot], &data[..pivot]);
            // The bumped byte is exactly one larger and did not overflow.
            prop_assert_eq!(Some(incremented[pivot]), data[pivot].checked_add(1));
            // Only a run of 0xFF bytes may have been dropped from the tail.
            prop_assert!(data[pivot + 1..].iter().all(|&b| b == 0xFF));
        }

        #[test]
        fn should_prefix_range_contain_all_prefixed_keys(prefix: Vec<u8>, suffix: Vec<u8>) {
            prop_assume!(!prefix.is_empty());
            let range = BytesRange::prefix(Bytes::from(prefix.clone()));
            prop_assert!(range.contains(&prefix));
            let mut extended = prefix.clone();
            extended.extend(&suffix);
            prop_assert!(range.contains(&extended));
        }

        // Membership must coincide exactly with `starts_with`, so keys that merely
        // sort near the prefix (but do not start with it) are rejected.
        #[test]
        fn should_prefix_range_contain_only_prefixed_keys(prefix: Vec<u8>, key: Vec<u8>) {
            let range = BytesRange::prefix(Bytes::from(prefix.clone()));
            prop_assert_eq!(range.contains(&key), key.starts_with(&prefix));
        }
    }

    #[test]
    fn should_increment_simple_byte() {
        assert_eq!(lex_increment(b"a").unwrap().as_ref(), b"b");
        assert_eq!(lex_increment(&[0x00]).unwrap().as_ref(), &[0x01]);
        assert_eq!(lex_increment(&[0xFE]).unwrap().as_ref(), &[0xFF]);
    }

    #[test]
    fn should_increment_with_trailing_ff() {
        assert_eq!(lex_increment(&[0x61, 0xFF]).unwrap().as_ref(), &[0x62]);
        assert_eq!(
            lex_increment(&[0x61, 0xFF, 0xFF]).unwrap().as_ref(),
            &[0x62]
        );
        assert_eq!(
            lex_increment(&[0x00, 0xFF, 0xFF]).unwrap().as_ref(),
            &[0x01]
        );
    }

    #[test]
    fn should_return_none_for_non_incrementable() {
        assert!(lex_increment(&[]).is_none());
        assert!(lex_increment(&[0xFF]).is_none());
        assert!(lex_increment(&[0xFF, 0xFF]).is_none());
    }

    #[test]
    fn should_create_prefix_range() {
        let range = BytesRange::prefix(Bytes::from("foo"));
        assert_eq!(range.start, Included(Bytes::from_static(b"foo")));
        assert_eq!(range.end, Excluded(Bytes::from_static(b"fop")));
        assert!(range.contains(b"foo"));
        assert!(range.contains(b"foobar"));
        assert!(range.contains(b"foo\x00"));
        assert!(range.contains(b"foo\xFF"));
        assert!(!range.contains(b"fo"));
        assert!(!range.contains(b"fop"));
        assert!(!range.contains(b"fop\x00"));
    }

    #[test]
    fn should_handle_prefix_with_trailing_ff() {
        let range = BytesRange::prefix(Bytes::from_static(&[0x61, 0xFF]));
        assert_eq!(range.end, Excluded(Bytes::from_static(&[0x62])));
        assert!(range.contains(&[0x61, 0xFF]));
        assert!(range.contains(&[0x61, 0xFF, 0x00]));
        assert!(range.contains(&[0x61, 0xFF, 0xFF]));
        assert!(!range.contains(&[0x61]));
        assert!(!range.contains(&[0x62]));
    }

    #[test]
    fn should_handle_all_ff_prefix() {
        let range = BytesRange::prefix(Bytes::from_static(&[0xFF, 0xFF]));
        assert_eq!(range.end, Unbounded);
        assert!(range.contains(&[0xFF, 0xFF]));
        assert!(range.contains(&[0xFF, 0xFF, 0x00]));
        assert!(range.contains(&[0xFF, 0xFF, 0xFF, 0xFF]));
        assert!(!range.contains(&[0xFF]));
        assert!(!range.contains(&[0xFE, 0xFF]));
    }

    #[test]
    fn should_handle_empty_prefix() {
        let range = BytesRange::prefix(Bytes::new());
        assert_eq!(range.start, Unbounded);
        assert_eq!(range.end, Unbounded);
        assert!(range.contains(b""));
        assert!(range.contains(b"anything"));
        assert!(range.contains(&[0xFF, 0xFF, 0xFF]));
    }
}