use byteorder::{BigEndian, ByteOrder};
use chrono::{DateTime, NaiveDateTime, Utc};
use rust_decimal::Decimal;
use uuid::Uuid;
/// A type that can be stored as a key in a storage index.
///
/// Implementations encode a value into a caller-provided byte buffer and decode it
/// back. For unsized key types (such as `str` or `[u8]`), decoding yields the owned
/// counterpart given by [`ToOwned`] (e.g. `String` or `Vec<u8>`).
pub trait BinaryKey: ToOwned {
    /// Decodes a key from `buffer`, returning its owned form.
    fn read(buffer: &[u8]) -> Self::Owned;

    /// Returns the number of bytes [`write`](BinaryKey::write) produces for this key.
    fn size(&self) -> usize;

    /// Encodes this key into `buffer` and returns the number of bytes written.
    fn write(&self, buffer: &mut [u8]) -> usize;
}
/// The unit key has an empty encoding: nothing is written and nothing is read.
impl BinaryKey for () {
    fn size(&self) -> usize {
        0
    }

    fn write(&self, _: &mut [u8]) -> usize {
        // Nothing to encode; the buffer is left untouched.
        0
    }

    fn read(_: &[u8]) -> Self::Owned {}
}
/// A `u8` key is stored as the single byte itself, so byte order equals numeric order.
impl BinaryKey for u8 {
    fn size(&self) -> usize {
        std::mem::size_of::<u8>()
    }

    fn write(&self, buffer: &mut [u8]) -> usize {
        let written = self.size();
        buffer[0] = *self;
        written
    }

    fn read(bytes: &[u8]) -> Self::Owned {
        bytes[0]
    }
}
/// An `i8` key is stored as one byte with its sign bit flipped.
///
/// Flipping the top bit maps `-128..=127` monotonically onto `0..=255`, so comparing
/// the encoded bytes gives the same order as comparing the integers.
impl BinaryKey for i8 {
    fn size(&self) -> usize {
        1
    }

    fn write(&self, buffer: &mut [u8]) -> usize {
        // XOR with 0x80 is the same as adding 128 modulo 256.
        buffer[0] = (*self as u8) ^ 0x80;
        self.size()
    }

    fn read(buffer: &[u8]) -> Self::Owned {
        (buffer[0] ^ 0x80) as i8
    }
}
macro_rules! storage_key_for_ints {
($utype:ident, $itype:ident, $size:expr, $read_method:ident, $write_method:ident) => {
impl BinaryKey for $utype {
fn size(&self) -> usize {
$size
}
fn write(&self, buffer: &mut [u8]) -> usize {
BigEndian::$write_method(buffer, *self);
self.size()
}
fn read(buffer: &[u8]) -> Self {
BigEndian::$read_method(buffer)
}
}
impl BinaryKey for $itype {
fn size(&self) -> usize {
$size
}
fn write(&self, buffer: &mut [u8]) -> usize {
BigEndian::$write_method(buffer, self.wrapping_add(Self::min_value()) as $utype);
self.size()
}
fn read(buffer: &[u8]) -> Self {
BigEndian::$read_method(buffer).wrapping_sub(Self::min_value() as $utype) as Self
}
}
};
}
storage_key_for_ints! {u16, i16, 2, read_u16, write_u16}
storage_key_for_ints! {u32, i32, 4, read_u32, write_u32}
storage_key_for_ints! {u64, i64, 8, read_u64, write_u64}
storage_key_for_ints! {u128, i128, 16, read_u128, write_u128}
/// A byte-vector key is stored verbatim; its encoded size is its length.
impl BinaryKey for Vec<u8> {
    fn size(&self) -> usize {
        self.len()
    }

    fn write(&self, buffer: &mut [u8]) -> usize {
        let len = self.len();
        buffer[..len].copy_from_slice(self.as_slice());
        len
    }

    fn read(buffer: &[u8]) -> Self {
        Vec::from(buffer)
    }
}
/// A byte-slice key is stored verbatim; decoding yields an owned `Vec<u8>`.
impl BinaryKey for [u8] {
    fn size(&self) -> usize {
        self.len()
    }

    fn write(&self, buffer: &mut [u8]) -> usize {
        let len = self.len();
        buffer[..len].copy_from_slice(self);
        len
    }

    fn read(buffer: &[u8]) -> Self::Owned {
        // Same result as the `Vec<u8>` impl: an owned copy of the whole buffer.
        buffer.to_vec()
    }
}
/// A 32-byte array key is stored verbatim.
///
/// `write` needs a buffer of at least 32 bytes and touches only the first 32.
/// `read` requires `buffer` to be exactly 32 bytes long and panics otherwise.
impl BinaryKey for [u8; 32] {
    fn size(&self) -> usize {
        32
    }

    fn write(&self, buffer: &mut [u8]) -> usize {
        buffer[..32].copy_from_slice(&self[..]);
        32
    }

    fn read(buffer: &[u8]) -> Self::Owned {
        let mut array = [0_u8; 32];
        array[..].copy_from_slice(buffer);
        array
    }
}
/// A `String` key is stored as its raw UTF-8 bytes.
impl BinaryKey for String {
    fn size(&self) -> usize {
        self.len()
    }

    fn write(&self, buffer: &mut [u8]) -> usize {
        let bytes = self.as_bytes();
        buffer[..bytes.len()].copy_from_slice(bytes);
        bytes.len()
    }

    /// # Panics
    ///
    /// Panics if `buffer` is not valid UTF-8. The panic message points to a data
    /// schema mismatch as the probable cause.
    fn read(buffer: &[u8]) -> Self::Owned {
        const ERROR_MSG: &str = "Error reading UTF-8 string from the database. \
                                 Probable reason is data schema mismatch; for example, data was written to \
                                 `MapIndex<u64, _>` and is read as `MapIndex<str, _>`";
        let text = std::str::from_utf8(buffer).expect(ERROR_MSG);
        text.to_owned()
    }
}
impl BinaryKey for str {
fn size(&self) -> usize {
self.len()
}
fn write(&self, buffer: &mut [u8]) -> usize {
buffer[..self.size()].copy_from_slice(self.as_bytes());
self.size()
}
fn read(buffer: &[u8]) -> Self::Owned {
String::read(buffer)
}
}
impl BinaryKey for DateTime<Utc> {
fn size(&self) -> usize {
12
}
fn write(&self, buffer: &mut [u8]) -> usize {
let secs = self.timestamp();
let nanos = self.timestamp_subsec_nanos();
secs.write(&mut buffer[0..8]);
nanos.write(&mut buffer[8..12]);
self.size()
}
fn read(buffer: &[u8]) -> Self::Owned {
let secs = i64::read(&buffer[0..8]);
let nanos = u32::read(&buffer[8..12]);
Self::from_utc(NaiveDateTime::from_timestamp(secs, nanos), Utc)
}
}
/// A `Uuid` key is stored as its 16 raw bytes.
impl BinaryKey for Uuid {
    fn size(&self) -> usize {
        16
    }

    /// Writes the 16 UUID bytes into the front of `buffer` and returns 16.
    ///
    /// `buffer` may be longer than the key; only the first `size()` bytes are touched.
    /// This matches the integer and `[u8; 32]` keys in this module. Previously any
    /// buffer length other than 16 caused a panic.
    fn write(&self, buffer: &mut [u8]) -> usize {
        let size = self.size();
        buffer[..size].copy_from_slice(self.as_bytes());
        size
    }

    /// # Panics
    ///
    /// Panics if `buffer` cannot be parsed by `Uuid::from_slice`, i.e. if it is not
    /// exactly 16 bytes long.
    fn read(buffer: &[u8]) -> Self::Owned {
        Self::from_slice(buffer).expect(
            "Error reading UUID from the database: expected exactly 16 bytes. \
             Probable reason is data schema mismatch",
        )
    }
}
/// A `Decimal` key is stored as the 16 bytes produced by `Decimal::serialize`.
///
/// NOTE(review): unlike the integer and `DateTime<Utc>` encodings in this file,
/// nothing here makes byte order follow numeric order. Confirm this before relying
/// on `Decimal` keys for range iteration.
impl BinaryKey for Decimal {
    fn size(&self) -> usize {
        16
    }

    /// Writes the serialized decimal into the front of `buffer` and returns 16.
    ///
    /// `buffer` may be longer than the key; only the first `size()` bytes are touched.
    /// This matches the other fixed-size keys in this module. Previously any buffer
    /// length other than 16 caused a panic.
    fn write(&self, buffer: &mut [u8]) -> usize {
        let size = self.size();
        buffer[..size].copy_from_slice(&self.serialize());
        size
    }

    /// # Panics
    ///
    /// Panics if `buffer` is not exactly 16 bytes long.
    fn read(buffer: &[u8]) -> Self::Owned {
        let mut bytes = [0_u8; 16];
        bytes.copy_from_slice(buffer);
        Self::deserialize(bytes)
    }
}
#[cfg(test)]
mod tests {
    use super::{BinaryKey, DateTime, Decimal, Utc, Uuid};
    use crate::access::CopyAccessExt;
    use std::{fmt::Debug, str::FromStr};

    use chrono::{Duration, TimeZone};

    /// Number of random values checked by each fuzz-style test.
    const FUZZ_SAMPLES: usize = 100_000;

    /// Generates a test for an integer `BinaryKey` encoding.
    ///
    /// * `full`: exhaustively checks that every value round-trips, and that every
    ///   adjacent pair `(x, x + 1)` encodes to strictly increasing byte strings.
    /// * `fuzz`: checks `FUZZ_SAMPLES` random values (plus `MIN` and `MAX`) for
    ///   round-tripping, and distinct neighbours of a sorted random sample for
    ///   order preservation.
    macro_rules! test_storage_key_for_int_type {
        (full $type:ident, $size:expr => $test_name:ident) => {
            #[test]
            #[allow(clippy::replace_consts)]
            fn $test_name() {
                use std::iter::once;
                const MIN: $type = std::$type::MIN;
                const MAX: $type = std::$type::MAX;
                let mut buffer = [0_u8; $size];
                for x in (MIN..MAX).chain(once(MAX)) {
                    x.write(&mut buffer);
                    assert_eq!($type::read(&buffer), x);
                }
                let (mut x_buffer, mut y_buffer) = ([0_u8; $size], [0_u8; $size]);
                for x in MIN..MAX {
                    let y = x + 1;
                    x.write(&mut x_buffer);
                    y.write(&mut y_buffer);
                    assert!(x_buffer < y_buffer);
                }
            }
        };
        (fuzz $type:ident, $size:expr => $test_name:ident) => {
            #[test]
            fn $test_name() {
                use rand::{distributions::Standard, thread_rng, Rng};
                let rng = thread_rng();
                let mut buffer = [0_u8; $size];
                let handpicked_vals = vec![$type::min_value(), $type::max_value()];
                for x in rng
                    .sample_iter(&Standard)
                    .take(FUZZ_SAMPLES)
                    .chain(handpicked_vals)
                {
                    x.write(&mut buffer);
                    assert_eq!($type::read(&buffer), x);
                }
                let rng = thread_rng();
                let (mut x_buffer, mut y_buffer) = ([0_u8; $size], [0_u8; $size]);
                let mut vals: Vec<$type> = rng.sample_iter(&Standard).take(FUZZ_SAMPLES).collect();
                vals.sort_unstable();
                for w in vals.windows(2) {
                    let (x, y) = (w[0], w[1]);
                    if x == y {
                        continue;
                    }
                    x.write(&mut x_buffer);
                    y.write(&mut y_buffer);
                    assert!(x_buffer < y_buffer);
                }
            }
        };
    }

    test_storage_key_for_int_type! {full u8, 1 => test_storage_key_for_u8}
    test_storage_key_for_int_type! {full i8, 1 => test_storage_key_for_i8}
    test_storage_key_for_int_type! {full u16, 2 => test_storage_key_for_u16}
    test_storage_key_for_int_type! {full i16, 2 => test_storage_key_for_i16}
    test_storage_key_for_int_type! {fuzz u32, 4 => test_storage_key_for_u32}
    test_storage_key_for_int_type! {fuzz i32, 4 => test_storage_key_for_i32}
    test_storage_key_for_int_type! {fuzz u64, 8 => test_storage_key_for_u64}
    test_storage_key_for_int_type! {fuzz i64, 8 => test_storage_key_for_i64}
    test_storage_key_for_int_type! {fuzz u128, 16 => test_storage_key_for_u128}
    test_storage_key_for_int_type! {fuzz i128, 16 => test_storage_key_for_i128}

    #[test]
    fn test_signed_int_key_in_index() {
        use crate::{Database, MapIndex, TemporaryDB};
        let db: Box<dyn Database> = Box::new(TemporaryDB::default());
        let fork = db.fork();
        {
            let mut index: MapIndex<_, i32, u64> = fork.get_map("test_index");
            index.put(&5, 100);
            index.put(&-3, 200);
        }
        db.merge(fork.into_patch()).unwrap();
        let snapshot = db.snapshot();
        let index: MapIndex<_, i32, u64> = snapshot.get_map("test_index");
        assert_eq!(index.get(&5), Some(100));
        assert_eq!(index.get(&-3), Some(200));
        assert_eq!(
            index.iter_from(&-4).collect::<Vec<_>>(),
            vec![(-3, 200), (5, 100)]
        );
        assert_eq!(index.iter_from(&-2).collect::<Vec<_>>(), vec![(5, 100)]);
        assert_eq!(index.iter_from(&1).collect::<Vec<_>>(), vec![(5, 100)]);
        assert_eq!(index.iter_from(&6).collect::<Vec<_>>(), vec![]);
        assert_eq!(index.values().collect::<Vec<_>>(), vec![200, 100]);
    }

    #[test]
    fn test_storage_key_for_chrono_date_time_round_trip() {
        let times = [
            Utc.timestamp(0, 0),
            Utc.timestamp(13, 23),
            Utc::now(),
            Utc::now() + Duration::seconds(17) + Duration::nanoseconds(15),
            Utc.timestamp(0, 999_999_999),
            // Nanoseconds >= 1e9; presumably chrono's leap-second representation.
            Utc.timestamp(0, 1_500_000_000),
        ];
        assert_round_trip_eq(&times);
    }

    #[test]
    fn test_storage_key_for_system_time_ordering() {
        use rand::{thread_rng, Rng};
        let mut rng = thread_rng();
        let (mut buffer1, mut buffer2) = ([0_u8; 12], [0_u8; 12]);
        for _ in 0..FUZZ_SAMPLES {
            let time1 = Utc.timestamp(
                rng.gen::<i64>() % i64::from(i32::max_value()),
                rng.gen::<u32>() % 1_000_000_000,
            );
            let time2 = Utc.timestamp(
                rng.gen::<i64>() % i64::from(i32::max_value()),
                rng.gen::<u32>() % 1_000_000_000,
            );
            time1.write(&mut buffer1);
            time2.write(&mut buffer2);
            assert_eq!(time1.cmp(&time2), buffer1.cmp(&buffer2));
        }
    }

    #[test]
    fn test_system_time_key_in_index() {
        use crate::{Database, MapIndex, TemporaryDB};
        let db: Box<dyn Database> = Box::new(TemporaryDB::default());
        let x1 = Utc.timestamp(80, 0);
        let x2 = Utc.timestamp(10, 0);
        let y1 = Utc::now();
        let y2 = y1 + Duration::seconds(10);
        let fork = db.fork();
        {
            let mut index: MapIndex<_, DateTime<Utc>, DateTime<Utc>> = fork.get_map("test_index");
            index.put(&x1, y1);
            index.put(&x2, y2);
        }
        db.merge(fork.into_patch()).unwrap();
        let snapshot = db.snapshot();
        let index: MapIndex<_, DateTime<Utc>, DateTime<Utc>> = snapshot.get_map("test_index");
        assert_eq!(index.get(&x1), Some(y1));
        assert_eq!(index.get(&x2), Some(y2));
        assert_eq!(
            index.iter_from(&Utc.timestamp(0, 0)).collect::<Vec<_>>(),
            vec![(x2, y2), (x1, y1)]
        );
        assert_eq!(
            index.iter_from(&Utc.timestamp(20, 0)).collect::<Vec<_>>(),
            vec![(x1, y1)]
        );
        assert_eq!(
            index.iter_from(&Utc.timestamp(80, 0)).collect::<Vec<_>>(),
            vec![(x1, y1)]
        );
        assert_eq!(
            index.iter_from(&Utc.timestamp(90, 0)).collect::<Vec<_>>(),
            vec![]
        );
        assert_eq!(index.values().collect::<Vec<_>>(), vec![y2, y1]);
    }

    #[test]
    fn test_str_key() {
        let values = ["eee", "hello world", ""];
        for val in &values {
            let mut buffer = get_buffer(*val);
            val.write(&mut buffer);
            let new_val = str::read(&buffer);
            assert_eq!(new_val, *val);
        }
    }

    #[test]
    #[should_panic(expected = "Error reading UTF-8 string")]
    fn test_str_key_error() {
        let buffer = &[0xfe_u8, 0xfd];
        str::read(buffer);
    }

    #[test]
    fn test_u8_slice_key() {
        let values: &[&[u8]] = &[&[1, 2, 3], &[255], &[]];
        for val in values.iter() {
            let mut buffer = get_buffer(*val);
            val.write(&mut buffer);
            let new_val = <[u8] as BinaryKey>::read(&buffer);
            assert_eq!(new_val, *val);
        }
    }

    #[test]
    fn test_uuid_round_trip() {
        let uuids = [
            Uuid::nil(),
            Uuid::parse_str("936DA01F9ABD4d9d80C702AF85C822A8").unwrap(),
            Uuid::parse_str("0000002a-000c-0005-0c03-0938362b0809").unwrap(),
        ];
        assert_round_trip_eq(&uuids);
    }

    #[test]
    fn test_decimal_round_trip() {
        let decimals = [
            Decimal::from_str("3.14").unwrap(),
            Decimal::from_parts(1_102_470_952, 185_874_565, 1_703_060_790, false, 28),
            Decimal::new(9_497_628_354_687_268, 12),
            Decimal::from_str("0").unwrap(),
            Decimal::from_str("-0.000000000000000000019").unwrap(),
        ];
        assert_round_trip_eq(&decimals);
    }

    /// Asserts that each value survives a `write` into an exactly-sized buffer
    /// followed by a `read`.
    fn assert_round_trip_eq<T>(values: &[T])
    where
        T: BinaryKey + PartialEq<<T as ToOwned>::Owned> + Debug,
        <T as ToOwned>::Owned: Debug,
    {
        for original_value in values.iter() {
            let mut buffer = get_buffer(original_value);
            original_value.write(&mut buffer);
            let new_value = <T as BinaryKey>::read(&buffer);
            assert_eq!(*original_value, new_value);
        }
    }

    /// Returns a zeroed buffer whose length is exactly `key.size()`.
    fn get_buffer<T: BinaryKey + ?Sized>(key: &T) -> Vec<u8> {
        vec![0; key.size()]
    }
}