grebedb 1.0.0

Lightweight embedded key-value store/database backed by files in a virtual file system interface
Documentation
use std::io::{Cursor, Read, Write};

use relative_path::RelativePath;
use rmp_serde::{Deserializer, Serializer};
use serde::{Deserialize, Serialize};

use crate::{
    error::Error,
    lru::LruVec,
    vfs::{Vfs, VfsSyncOption},
};

/// Eight-byte signature written at the start of every file and checked on read.
const MAGIC_BYTES: [u8; 8] = *b"\xFEGrebe\x00\x00";

/// Reads and writes database files through a [`Vfs`], reusing scratch buffers
/// between calls.
///
/// A file consists of the magic bytes, a one-byte compression flag, and a page
/// (optionally compressed). The page is a length-prefixed serialized payload
/// followed by a CRC-32C of that payload.
pub struct Format {
    /// Scratch buffer holding the complete file contents assembled by `write_file`.
    file_buffer: Vec<u8>,
    /// Scratch buffer holding the page: payload length, payload, checksum.
    page_buffer: Vec<u8>,
    /// Scratch buffer holding the serialized payload only.
    payload_buffer: Vec<u8>,
    /// `Some(level)` to write compressed pages at that level, `None` to write
    /// pages uncompressed.
    compression_level: Option<i32>,
    /// Recently seen parent directories, consulted to skip repeated
    /// `create_dir_all` calls when writing.
    dir_create_cache: LruVec<String>,
}

impl Default for Format {
    /// Creates a `Format` with empty buffers and a directory cache built with
    /// `LruVec::new(8)`.
    ///
    /// Compression is enabled at level 0 when the crate is built with the
    /// `zstd` feature and disabled otherwise.
    fn default() -> Self {
        let compression_level = if cfg!(feature = "zstd") {
            Some(0)
        } else {
            None
        };

        Self {
            file_buffer: Vec::new(),
            page_buffer: Vec::new(),
            payload_buffer: Vec::new(),
            compression_level,
            dir_create_cache: LruVec::new(8),
        }
    }
}

impl Format {
    /// Sets the compression level used by subsequent `write_file` calls.
    ///
    /// `Some(level)` makes `write_file` emit a compressed page at that level;
    /// `None` makes it emit the page uncompressed. Reading is unaffected,
    /// because each file records its own compression flag.
    pub fn set_compression_level(&mut self, value: Option<i32>) {
        self.compression_level = value;
    }

    /// Reads the file at `path` from `vfs` and decodes its payload into `T`.
    ///
    /// The file must start with [`MAGIC_BYTES`], followed by a one-byte
    /// compression flag (`0x00` = uncompressed page, `0x01` = compressed page)
    /// and then the page itself.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidFileFormat` if the magic bytes do not match or the
    ///   compression flag is not a known value.
    /// * Any error from `vfs.read`, from reading the header, from
    ///   decompression, or from decoding the page.
    pub fn read_file<'de, T>(&mut self, vfs: &mut dyn Vfs, path: &str) -> Result<T, Error>
    where
        T: Deserialize<'de>,
    {
        let mut file = Cursor::new(vfs.read(path)?);

        let mut magic_bytes = [0u8; 8];
        file.read_exact(&mut magic_bytes)?;

        if magic_bytes != MAGIC_BYTES {
            return Err(Error::InvalidFileFormat {
                path: path.to_string(),
                message: "not a database",
            });
        }

        let mut compression_flag = [0u8; 1];
        file.read_exact(&mut compression_flag)?;

        match compression_flag[0] {
            0x00 => {
                self.page_buffer.clear();
                file.read_to_end(&mut self.page_buffer)?;
            }
            0x01 => self.decompress_to_page_buffer(&mut file)?,
            // The writer only ever emits 0x00 or 0x01; anything else is a
            // corrupt or unsupported file and must not be parsed as raw data.
            _ => {
                return Err(Error::InvalidFileFormat {
                    path: path.to_string(),
                    message: "unknown compression flag",
                });
            }
        }

        self.deserialize_page(path)
    }

    /// Serializes `payload` and writes it to `path` on `vfs`.
    ///
    /// The written file is [`MAGIC_BYTES`], a one-byte compression flag
    /// (`0x01` when a compression level is set, `0x00` otherwise) and the page,
    /// compressed or raw accordingly. The parent directory of `path` is created
    /// with `create_dir_all` unless the directory cache reports it as already
    /// seen.
    ///
    /// # Errors
    ///
    /// Returns any error from serialization, compression, directory creation
    /// or `vfs.write`.
    pub fn write_file<T>(
        &mut self,
        vfs: &mut dyn Vfs,
        path: &str,
        payload: T,
        sync_option: VfsSyncOption,
    ) -> Result<(), Error>
    where
        T: Serialize,
    {
        self.file_buffer.clear();
        self.page_buffer.clear();
        self.payload_buffer.clear();

        // The page is needed on both the compressed and uncompressed path.
        self.serialize_page(payload)?;

        self.file_buffer.extend_from_slice(&MAGIC_BYTES);

        if self.compression_level.is_some() {
            self.file_buffer.push(0x01);
            self.write_compressed_page_to_file_buffer()?;
        } else {
            self.file_buffer.push(0x00);
            self.file_buffer.extend_from_slice(&self.page_buffer);
        }

        // `parent()` yields `None` when the path has no parent (e.g. an empty
        // path). Skip directory creation in that case instead of panicking and
        // let `vfs.write` report whatever is wrong with the path.
        //
        // NOTE(review): `is_in_dir_cache` records the directory before
        // `create_dir_all` runs, so a failed creation is not retried while the
        // entry stays cached — confirm whether that is intended.
        if let Some(dir_path) = RelativePath::new(path).parent() {
            if !self.is_in_dir_cache(dir_path) {
                vfs.create_dir_all(dir_path.as_str())?;
            }
        }

        vfs.write(path, &self.file_buffer, sync_option)?;

        Ok(())
    }

    /// Serializes `object` into `payload_buffer` and appends the framed page
    /// to `page_buffer`.
    ///
    /// Page layout: 8-byte big-endian payload length, the payload bytes, then
    /// the 4-byte big-endian CRC-32C of the payload. Both buffers are appended
    /// to, not cleared; `write_file` clears them before calling this.
    ///
    /// # Errors
    ///
    /// Returns the error from `serialize_payload` if serialization fails.
    fn serialize_page<T>(&mut self, object: T) -> Result<(), Error>
    where
        T: Serialize,
    {
        serialize_payload(object, &mut self.payload_buffer)?;

        // Widen to u64 so the length field is always 8 bytes. Calling
        // `to_be_bytes` on the `usize` directly would emit only 4 bytes on
        // 32-bit targets, while the reader always consumes 8.
        let size_bytes = (self.payload_buffer.len() as u64).to_be_bytes();
        let crc_bytes = crc32c::crc32c(&self.payload_buffer).to_be_bytes();

        self.page_buffer.extend_from_slice(&size_bytes);
        self.page_buffer.extend_from_slice(&self.payload_buffer);
        self.page_buffer.extend_from_slice(&crc_bytes);

        Ok(())
    }

    /// Appends the zstd-compressed contents of `page_buffer` to `file_buffer`.
    ///
    /// `file_buffer` is handed to the encoder as its output sink, so bytes
    /// already in it (the header) stay in front of the compressed stream.
    ///
    /// # Errors
    ///
    /// Returns encoder errors, or `Error::CompressionUnavailable` when the
    /// crate is built without the `zstd` feature.
    ///
    /// # Panics
    ///
    /// Panics if no compression level is set.
    fn write_compressed_page_to_file_buffer(&mut self) -> Result<(), Error> {
        #[cfg(feature = "zstd")]
        {
            // Move the buffer out so the encoder can own it while writing.
            let sink = std::mem::take(&mut self.file_buffer);
            let level = self.compression_level.unwrap();

            let mut encoder = zstd::Encoder::new(sink, level)?;
            encoder.write_all(&self.page_buffer)?;

            // Finishing the stream hands the buffer back.
            self.file_buffer = encoder.finish()?;

            Ok(())
        }
        #[cfg(not(feature = "zstd"))]
        {
            Err(Error::CompressionUnavailable)
        }
    }

    /// Reports whether `dir_path` is already in the directory cache.
    ///
    /// On a hit the entry is touched and `true` is returned. On a miss the
    /// path is inserted into the cache and `false` is returned.
    fn is_in_dir_cache(&mut self, dir_path: &RelativePath) -> bool {
        let key = dir_path.to_string();

        if self.dir_create_cache.touch(&key) {
            return true;
        }

        self.dir_create_cache.insert(key);
        false
    }

    /// Replaces the contents of `page_buffer` with the zstd-decompressed bytes
    /// read from `source`.
    ///
    /// `page_buffer` is cleared first, including when decompression fails or
    /// is unavailable.
    ///
    /// # Errors
    ///
    /// Returns decoder errors, or `Error::CompressionUnavailable` when the
    /// crate is built without the `zstd` feature.
    fn decompress_to_page_buffer(&mut self, source: &mut dyn Read) -> Result<(), Error> {
        self.page_buffer.clear();

        #[cfg(feature = "zstd")]
        {
            let mut decoder = zstd::Decoder::new(source)?;
            decoder.read_to_end(&mut self.page_buffer)?;

            Ok(())
        }
        #[cfg(not(feature = "zstd"))]
        {
            // Keeps `source` used when the decoder is compiled out.
            let _ = source;

            Err(Error::CompressionUnavailable)
        }
    }

    fn deserialize_page<'de, T>(&mut self, path: &str) -> Result<T, Error>
    where
        T: Deserialize<'de>,
    {
        let mut size_bytes: [u8; 8] = [0u8; 8];
        let mut data = Cursor::new(&mut self.page_buffer);

        data.read_exact(&mut size_bytes)?;
        let size = u64::from_be_bytes(size_bytes) as usize;

        let payload = deserialize_payload(&mut data)?;

        let mut crc_bytes: [u8; 4] = [0; 4];
        data.read_exact(&mut crc_bytes)?;
        let crc = u32::from_be_bytes(crc_bytes);

        let test_crc = crc32c::crc32c(&self.page_buffer[8..8 + size]);

        if crc != test_crc {
            Err(Error::BadChecksum {
                path: path.to_string(),
            })
        } else {
            Ok(payload)
        }
    }
}

fn serialize_payload<T, W>(object: T, destination: W) -> Result<(), Error>
where
    T: Serialize,
    W: Write,
{
    let mut serializer = Serializer::new(destination)
        .with_binary()
        .with_string_variants()
        .with_struct_map();

    match object.serialize(&mut serializer) {
        Ok(_) => Ok(()),
        Err(error) => Err(Error::Other(Box::new(error))),
    }
}

fn deserialize_payload<'de, T, R>(source: R) -> Result<T, Error>
where
    T: Deserialize<'de>,
    R: Read,
{
    let mut deserializer = Deserializer::new(source).with_binary();

    match Deserialize::deserialize(&mut deserializer) {
        Ok(value) => Ok(value),
        Err(error) => Err(Error::Other(Box::new(error))),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::vfs::MemoryVfs;

    /// Round-trips a string through `write_file` and `read_file` on an
    /// in-memory VFS.
    #[test]
    fn test_format() -> Result<(), Error> {
        let mut vfs = MemoryVfs::new();
        let mut format = Format::default();
        let expected = "hello world";

        format.write_file(&mut vfs, "my_file", expected, VfsSyncOption::None)?;
        let actual: String = format.read_file(&mut vfs, "my_file")?;

        assert_eq!(actual, expected);

        Ok(())
    }
}