use integer_encoding::{FixedInt, FixedIntReader, VarInt, VarIntReader};
use std::io::Read;
use std::slice::Iter;
use crate::errors::{RainDBError, RainDBResult};
use crate::key::Operation;
use crate::utils::io::ReadHelpers;
/// A single operation recorded in a [`Batch`].
///
/// An element pairs an [`Operation`] with the user key it targets and, optionally, a value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BatchElement {
/// The kind of operation this element represents.
operation: Operation,
/// The user-supplied key the operation targets.
user_key: Vec<u8>,
/// The value associated with the key. Elements created through `Batch::add_delete` carry
/// `None`, and deserialization only reads a value for `Operation::Put`.
value: Option<Vec<u8>>,
/// Size estimate computed once in `BatchElement::new` as
/// `1 + user_key.len() + value length + 8`.
size: usize,
}
/// Public accessors.
impl BatchElement {
    /// Returns the operation that this element represents.
    pub fn get_operation(&self) -> Operation {
        self.operation
    }

    /// Returns the user key that this element targets.
    pub fn get_key(&self) -> &[u8] {
        self.user_key.as_slice()
    }

    /// Returns a reference to the element's value, or `None` when the element carries no value.
    pub fn get_value(&self) -> Option<&Vec<u8>> {
        self.value.as_ref()
    }
}
/// Crate-only construction and deserialization helpers.
impl BatchElement {
    /// Creates a new element and caches its size estimate.
    ///
    /// The estimate is `1 + user_key.len() + value length + 8`, where a missing value counts as
    /// zero bytes.
    pub(crate) fn new(operation: Operation, user_key: Vec<u8>, value: Option<Vec<u8>>) -> Self {
        let value_len = value.as_ref().map_or(0, Vec::len);
        let size = 1 + user_key.len() + value_len + 8;

        Self {
            operation,
            user_key,
            value,
            size,
        }
    }

    /// Returns the size estimate that was computed when the element was created.
    pub(crate) fn size(&self) -> usize {
        self.size
    }

    /// Deserializes one element from the front of `buf`.
    ///
    /// Reads a one-byte operation tag, then a length-prefixed user key and, only for
    /// [`Operation::Put`], a length-prefixed value.
    ///
    /// On success returns the element together with the number of bytes consumed from `buf`, so
    /// that callers can advance past it.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading the tag, converting the tag into an [`Operation`], or
    /// reading the length-prefixed fields.
    pub(crate) fn read_element(mut buf: &[u8]) -> RainDBResult<(BatchElement, usize)> {
        // Reading through `buf` shrinks the slice, so the difference in length before and after
        // is the number of bytes this element occupied.
        let initial_len = buf.len();

        let mut tag = [0_u8; 1];
        buf.read_exact(&mut tag)?;
        let [tag_byte] = tag;
        let operation = Operation::try_from(tag_byte)?;

        let user_key = buf.read_length_prefixed_slice()?;

        // Only put operations are followed by a value.
        let value = (operation == Operation::Put)
            .then(|| buf.read_length_prefixed_slice())
            .transpose()?;

        let consumed = initial_len - buf.len();
        Ok((BatchElement::new(operation, user_key, value), consumed))
    }
}
impl From<&BatchElement> for Vec<u8> {
    /// Serializes the element as: one operation byte, the varint-encoded key length, the key
    /// bytes and, for [`Operation::Put`] only, the varint-encoded value length followed by the
    /// value bytes.
    ///
    /// # Panics
    ///
    /// Panics if the element is a put operation that carries no value.
    fn from(batch_element: &BatchElement) -> Vec<u8> {
        let mut serialized = Vec::with_capacity(batch_element.size());

        serialized.push(batch_element.operation as u8);

        let key = &batch_element.user_key;
        serialized.extend(u32::encode_var_vec(key.len() as u32));
        serialized.extend_from_slice(key);

        if batch_element.operation == Operation::Put {
            let value = batch_element.value.as_ref().unwrap();
            serialized.extend(u32::encode_var_vec(value.len() as u32));
            serialized.extend_from_slice(value);
        }

        serialized
    }
}
impl TryFrom<&[u8]> for BatchElement {
    type Error = RainDBError;

    /// Deserializes a single element from the front of `buf`, discarding the count of bytes that
    /// were consumed.
    fn try_from(buf: &[u8]) -> Result<Self, Self::Error> {
        let (element, _bytes_read) = BatchElement::read_element(buf)?;
        Ok(element)
    }
}
/// An ordered collection of [`BatchElement`] operations with an optional starting sequence
/// number.
#[derive(PartialEq, Eq)]
pub struct Batch {
/// The starting sequence number of the batch. This is `None` for a freshly created batch and
/// must be set before the batch is serialized to bytes.
starting_seq_number: Option<u64>,
/// The operations in the batch, in the order they were added.
operations: Vec<BatchElement>,
}
/// Public API.
impl Batch {
    /// Creates an empty batch with no starting sequence number.
    pub fn new() -> Self {
        Self {
            starting_seq_number: None,
            operations: Vec::new(),
        }
    }

    /// Appends a put of `value` under `key` and returns the batch so calls can be chained.
    pub fn add_put(&mut self, key: Vec<u8>, value: Vec<u8>) -> &mut Self {
        self.add_operation(BatchElement::new(Operation::Put, key, Some(value)));
        self
    }

    /// Appends a delete of `key` and returns the batch so calls can be chained.
    pub fn add_delete(&mut self, key: Vec<u8>) -> &mut Self {
        self.add_operation(BatchElement::new(Operation::Delete, key, None));
        self
    }

    /// Removes every operation from the batch.
    ///
    /// The starting sequence number is left untouched.
    pub fn clear(&mut self) -> &mut Self {
        self.operations.clear();
        self
    }

    /// Returns an iterator over the operations in the order they were added.
    pub fn iter(&self) -> Iter<'_, BatchElement> {
        self.operations.iter()
    }

    /// Returns the number of operations in the batch.
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// Returns `true` when the batch holds no operations.
    pub fn is_empty(&self) -> bool {
        self.operations.is_empty()
    }
}
impl Default for Batch {
fn default() -> Self {
Self::new()
}
}
impl Batch {
pub(crate) fn get_approximate_size(&self) -> usize {
self.iter().map(|operation| operation.size()).sum()
}
pub(crate) fn set_starting_seq_number(&mut self, seq_number: u64) {
self.starting_seq_number = Some(seq_number);
}
pub(crate) fn get_starting_seq_number(&self) -> Option<u64> {
self.starting_seq_number
}
pub(crate) fn append_batch(&mut self, batch_to_append: &Batch) {
self.operations
.extend_from_slice(batch_to_append.iter().as_slice());
}
pub(crate) fn add_operation(&mut self, batch_element: BatchElement) {
self.operations.push(batch_element);
}
}
impl From<&Batch> for Vec<u8> {
    /// Serializes the batch as: the fixed-width encoding of the starting sequence number, the
    /// varint-encoded number of operations, and then each operation's serialized form in order.
    ///
    /// # Panics
    ///
    /// Panics if the starting sequence number has not been set on the batch.
    fn from(batch: &Batch) -> Vec<u8> {
        // `From` cannot report failure, so a missing sequence number is treated as a broken
        // invariant. State the invariant in the panic message so the failure is diagnosable.
        let starting_seq_number = batch.starting_seq_number.expect(
            "A batch must have its starting sequence number set before it can be serialized.",
        );

        let mut buf = Vec::with_capacity(batch.get_approximate_size());
        buf.extend(u64::encode_fixed_vec(starting_seq_number));
        buf.extend(u32::encode_var_vec(batch.operations.len() as u32));
        for element in batch.operations.iter() {
            buf.extend(Vec::<u8>::from(element));
        }

        buf
    }
}
impl TryFrom<&[u8]> for Batch {
type Error = RainDBError;
fn try_from(mut buf: &[u8]) -> Result<Self, Self::Error> {
let mut batch = Batch::new();
let starting_seq_num: u64 = buf.read_fixedint()?;
batch.set_starting_seq_number(starting_seq_num);
let num_operations: u32 = buf.read_varint()?;
for _ in 0..num_operations {
let (element, bytes_read) = BatchElement::read_element(buf)?;
batch.add_operation(element);
buf = &buf[bytes_read..];
}
Ok(batch)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Adding operations grows the batch and `clear` empties it again.
    #[test]
    fn can_add_and_remove_elements_on_a_batch() {
        let mut batch = Batch::new();
        assert!(batch.is_empty());

        batch.add_put(b"batmann".to_vec(), u32::encode_fixed_vec(55));
        batch.add_delete(b"robin".to_vec());
        batch.add_put(b"robin".to_vec(), u32::encode_fixed_vec(3));
        assert_eq!(batch.len(), 3);

        batch.clear();
        assert!(batch.is_empty());
    }

    /// Appending copies the other batch's operations but keeps the receiver's sequence number.
    #[test]
    fn can_append_another_batch() {
        let mut batch1 = Batch::new();
        batch1.set_starting_seq_number(43);
        batch1.add_put(b"batmann".to_vec(), u32::encode_fixed_vec(55));

        let mut batch2 = Batch::new();
        batch2.set_starting_seq_number(117);
        batch2.add_delete(b"batmann".to_vec());

        batch1.append_batch(&batch2);

        assert_eq!(batch1.len(), 2);
        assert_eq!(batch1.get_starting_seq_number().unwrap(), 43);

        let mut batch1_iter = batch1.iter();
        assert_eq!(
            *batch1_iter.next().unwrap(),
            BatchElement::new(
                Operation::Put,
                b"batmann".to_vec(),
                Some(u32::encode_fixed_vec(55))
            )
        );
        assert_eq!(
            *batch1_iter.next().unwrap(),
            BatchElement::new(Operation::Delete, b"batmann".to_vec(), None)
        );
    }

    /// A batch survives a round trip through its byte representation.
    #[test]
    fn can_be_serialized_and_deserialized() {
        let expected_elements = [
            BatchElement::new(
                Operation::Put,
                b"batmann".to_vec(),
                Some(u32::encode_fixed_vec(55)),
            ),
            BatchElement::new(Operation::Delete, b"batmann".to_vec(), None),
            BatchElement::new(
                Operation::Put,
                b"robin".to_vec(),
                Some(u32::encode_fixed_vec(3)),
            ),
            BatchElement::new(
                Operation::Put,
                b"batmann".to_vec(),
                Some(u32::encode_fixed_vec(11)),
            ),
        ];
        let mut batch = Batch::new();
        batch.set_starting_seq_number(43);
        batch.add_put(b"batmann".to_vec(), u32::encode_fixed_vec(55));
        batch.add_delete(b"batmann".to_vec());
        batch.add_put(b"robin".to_vec(), u32::encode_fixed_vec(3));
        batch.add_put(b"batmann".to_vec(), u32::encode_fixed_vec(11));

        let serialized_batch = Vec::<u8>::from(&batch);
        let deserialized_batch = Batch::try_from(serialized_batch.as_slice()).unwrap();

        assert_eq!(
            batch.get_starting_seq_number(),
            deserialized_batch.get_starting_seq_number()
        );
        assert_eq!(
            batch.get_approximate_size(),
            deserialized_batch.get_approximate_size(),
            "The batch approximate size ({batch_size}) should match the deserialized size \
            ({deserialized_size})",
            batch_size = batch.get_approximate_size(),
            deserialized_size = deserialized_batch.get_approximate_size()
        );

        // `zip` stops at the shorter iterator, so check the lengths explicitly. Otherwise a
        // deserializer that silently dropped trailing elements would still pass the loop below.
        assert_eq!(batch.len(), expected_elements.len());
        assert_eq!(deserialized_batch.len(), expected_elements.len());

        for (deserialized_element, expected_element) in
            deserialized_batch.iter().zip(expected_elements.iter())
        {
            assert_eq!(deserialized_element, expected_element);
        }
    }

    /// The approximate size starts at zero and stays within loose upper bounds as operations
    /// are added.
    #[test]
    fn approximate_sizes_scale_appropriately() {
        let mut batch = Batch::new();
        assert_eq!(batch.get_approximate_size(), 0);

        batch.add_put(b"batmann".to_vec(), u32::encode_fixed_vec(55));
        assert!(batch.get_approximate_size() <= 69);

        batch.add_delete(b"batmann".to_vec());
        assert!(batch.get_approximate_size() <= (69 + 65));

        batch.add_put(b"batmann".to_vec(), u32::encode_fixed_vec(11));
        assert!(batch.get_approximate_size() <= (69 + 65 + 69));
    }
}