rdb 0.3.0

Fast and efficient RDB parsing utility
Documentation
use super::common::utils::{
    read_blob, read_length, read_length_with_encoding, verify_magic, verify_version,
};
use super::{hash, list, set, sorted_set};
use byteorder::{BigEndian, LittleEndian, ReadBytesExt};
use std::io::Read;

use crate::constants::{encoding, encoding_type, op_code};
use crate::filter::Filter;
use crate::types::{RdbError, RdbResult, RdbValue};

/// Mutable bookkeeping carried between successive calls to
/// `process_next_operation` while walking the opcode stream.
#[derive(Default)]
pub(crate) struct DecoderState {
    /// Expiry, in milliseconds, recorded by the most recent `EXPIRETIME_MS`
    /// or `EXPIRETIME` opcode (the latter is multiplied by 1000). It is passed
    /// to `read_type` as the expiry of the value being decoded.
    pub last_expiretime: Option<u64>,
    /// Database index taken from the most recent `SELECTDB` opcode; consulted
    /// through `Filter::matches_db` to decide whether a value is decoded or skipped.
    pub current_database: u32,
    /// Set to `true` once the `EOF` opcode has been processed.
    pub reached_eof: bool,
}

/// Validates the file header by running `verify_magic` and then
/// `verify_version` against `input`.
///
/// # Errors
/// Returns the first error produced by either check; `verify_version` is not
/// run when `verify_magic` fails.
pub(crate) fn verify_header<R: Read>(input: &mut R) -> RdbResult<()> {
    match verify_magic(input) {
        Ok(_) => verify_version(input),
        Err(magic_error) => Err(magic_error.into()),
    }
}

pub(crate) fn read_type<R: Read>(
    input: &mut R,
    key: &[u8],
    value_type: u8,
    expiry: Option<u64>,
) -> RdbResult<RdbValue> {
    let result = match value_type {
        encoding_type::STRING => {
            let value = read_blob(input)?;
            RdbValue::String {
                key: key.to_vec(),
                value,
                expiry,
            }
        }
        encoding_type::LIST => list::read_linked_list(input, key, expiry)?,
        encoding_type::SET => set::read_set(input, key, expiry)?,
        encoding_type::ZSET => sorted_set::read_sorted_set(input, key, expiry, false)?,
        encoding_type::HASH => hash::read_hash(input, key, expiry)?,
        encoding_type::HASH_ZIPMAP => hash::read_hash_zipmap(input, key, expiry)?,
        encoding_type::LIST_ZIPLIST => list::read_list_ziplist(input, key, expiry)?,
        encoding_type::SET_INTSET => set::read_set_intset(input, key, expiry)?,
        encoding_type::ZSET_ZIPLIST => sorted_set::read_sorted_set_ziplist(input, key, expiry)?,
        encoding_type::HASH_ZIPLIST => hash::read_hash_ziplist(input, key, expiry)?,
        encoding_type::LIST_QUICKLIST => list::read_quicklist(input, key, expiry)?,
        encoding_type::HASH_LIST_PACK => hash::read_hash_list_pack(input, key, expiry)?,
        encoding_type::ZSET_2 => sorted_set::read_sorted_set(input, key, expiry, true)?,
        encoding_type::LIST_QUICKLIST_2 => list::read_quicklist_2(input, key, expiry)?,
        encoding_type::STREAM_LIST_PACKS => {
            todo!("read_stream_list_packs v1 not implemented");
        }
        encoding_type::STREAM_LIST_PACKS_2 => {
            todo!("read_stream_list_packs v2 not implemented");
        }
        encoding_type::STREAM_LIST_PACKS_3 => {
            todo!("read_stream_list_packs v3 not implemented");
        }
        encoding_type::ZSET_LIST_PACK => sorted_set::read_sorted_set_listpack(input, key, expiry)?,
        encoding_type::SET_LIST_PACK => set::read_set_list_pack(input, key, expiry)?,
        unknown_type => {
            skip_object(input, unknown_type)?;
            return Err(RdbError::MissingValue("skip"));
        }
    };
    Ok(result)
}

pub(crate) fn skip<R: Read>(input: &mut R, skip_bytes: usize) -> RdbResult<()> {
    let mut buf = vec![0; skip_bytes];
    input.read_exact(&mut buf).map_err(RdbError::Io)?;
    Ok(())
}

/// Steps over one length-prefixed blob without decoding it.
///
/// The prefix is read with `read_length_with_encoding`. A plain length is the
/// number of payload bytes to discard. An encoded prefix selects a fixed-width
/// integer (1, 2 or 4 bytes) or an LZF payload, for which the compressed
/// length and the real length are read and only the compressed bytes are
/// discarded.
///
/// # Errors
/// Returns `RdbError::ParsingError` for an unrecognised encoding value, and
/// propagates any error from reading the prefix or discarding the payload.
pub(crate) fn skip_blob<R: Read>(input: &mut R) -> RdbResult<()> {
    let (len, is_encoded) = read_length_with_encoding(input)?;

    // Plain string: the prefix is the payload size.
    if !is_encoded {
        return skip(input, len as usize);
    }

    let payload_size = match len {
        encoding::INT8 => 1,
        encoding::INT16 => 2,
        encoding::INT32 => 4,
        encoding::LZF => {
            let compressed_length = read_length(input)?;
            // The uncompressed size follows; it has to be consumed even
            // though nothing is decompressed here.
            let _real_length = read_length(input)?;
            compressed_length
        }
        _ => {
            return Err(RdbError::ParsingError {
                context: "skip_blob",
                message: format!("Unknown encoding value: {}", len),
            });
        }
    };

    skip(input, payload_size as usize)
}

pub(crate) fn skip_object<R: Read>(input: &mut R, enc_type: u8) -> RdbResult<()> {
    let blobs_count = match enc_type {
        encoding_type::STRING
        | encoding_type::HASH_ZIPMAP
        | encoding_type::LIST_ZIPLIST
        | encoding_type::SET_INTSET
        | encoding_type::ZSET_ZIPLIST
        | encoding_type::HASH_ZIPLIST
        | encoding_type::HASH_LIST_PACK => 1,
        encoding_type::LIST | encoding_type::SET | encoding_type::LIST_QUICKLIST => {
            read_length(input)?
        }
        encoding_type::ZSET | encoding_type::HASH => read_length(input)? * 2,
        _ => return Err(RdbError::UnknownEncoding(enc_type)),
    };

    for _ in 0..blobs_count {
        skip_blob(input)?;
    }
    Ok(())
}

/// Steps over a whole entry: first the key blob, then the value encoded as
/// `enc_type`.
///
/// # Errors
/// Propagates the first error from `skip_blob` or `skip_object`; the value is
/// not touched when skipping the key fails.
pub(crate) fn skip_key_and_object<R: Read>(input: &mut R, enc_type: u8) -> RdbResult<()> {
    skip_blob(input).and_then(|_| skip_object(input, enc_type))
}

/// Consumes opcodes from `input` until one of them yields an [`RdbValue`],
/// and returns that value.
///
/// Opcodes that only carry metadata for the entry that follows
/// (`EXPIRETIME_MS`, `EXPIRETIME`, `MODULE_AUX`, `IDLE`, `FREQ`) are consumed
/// and the scan continues with the next opcode. They are handled in a loop
/// rather than by recursion so that a long run of them in crafted input
/// cannot exhaust the stack.
///
/// * `input` - reader positioned at an opcode byte.
/// * `filter` - decides, per database / type / key, whether an entry is
///   decoded or skipped. Skipped entries are reported as
///   `RdbValue::SelectDb(current_database)`.
/// * `state` - updated in place: `current_database` on `SELECTDB`,
///   `last_expiretime` on the expiry opcodes (and cleared again once the next
///   entry has claimed it), `reached_eof` when the stream ends.
///
/// When no further opcode byte can be read, an empty `RdbValue::Checksum` is
/// returned and `reached_eof` is set, exactly as for an explicit `EOF`
/// opcode, so callers that poll `reached_eof` terminate on truncated input.
///
/// # Errors
/// Propagates I/O errors and any error from the length, blob, skip and value
/// decoders.
pub(crate) fn process_next_operation<R: Read, F: Filter>(
    input: &mut R,
    filter: &F,
    state: &mut DecoderState,
) -> RdbResult<RdbValue> {
    loop {
        let next_op = match input.read_u8() {
            Ok(op) => op,
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                // Input ran out without an EOF opcode: treat it as the end of
                // the stream and record that, so the caller stops polling.
                state.reached_eof = true;
                return Ok(RdbValue::Checksum(vec![]));
            }
            Err(e) => return Err(e.into()),
        };

        match next_op {
            op_code::SELECTDB => {
                state.current_database = read_length(input)?;
                return Ok(RdbValue::SelectDb(state.current_database));
            }
            op_code::EOF => {
                // Everything after the EOF opcode is handed back as the checksum.
                let mut checksum = Vec::new();
                input.read_to_end(&mut checksum)?;
                state.reached_eof = true;
                return Ok(RdbValue::Checksum(checksum));
            }
            op_code::EXPIRETIME_MS => {
                state.last_expiretime = Some(input.read_u64::<LittleEndian>()?);
            }
            op_code::EXPIRETIME => {
                // Stored in seconds; normalised to milliseconds.
                // NOTE(review): read as big-endian here while EXPIRETIME_MS is
                // little-endian — confirm against the RDB format.
                state.last_expiretime = Some(u64::from(input.read_u32::<BigEndian>()?) * 1000);
            }
            op_code::RESIZEDB => {
                let db_size = read_length(input)?;
                let expires_size = read_length(input)?;
                return Ok(RdbValue::ResizeDb {
                    db_size,
                    expires_size,
                });
            }
            op_code::AUX => {
                let key = read_blob(input)?;
                let value = read_blob(input)?;
                return Ok(RdbValue::AuxField { key, value });
            }
            op_code::MODULE_AUX => {
                skip_blob(input)?;
            }
            op_code::IDLE => {
                let _idle_time = read_length(input)?;
            }
            op_code::FREQ => {
                let _freq = input.read_u8()?;
            }
            value_type => {
                // An expiry opcode applies only to the entry that follows it.
                // Claim it now, whether the entry is decoded or skipped, so it
                // can never leak onto a later key that has no expiry.
                let expiry = state.last_expiretime.take();

                if !filter.matches_db(state.current_database) {
                    skip_key_and_object(input, value_type)?;
                    return Ok(RdbValue::SelectDb(state.current_database));
                }

                let key = read_blob(input)?;
                if !filter.matches_type(value_type) || !filter.matches_key(&key) {
                    skip_object(input, value_type)?;
                    return Ok(RdbValue::SelectDb(state.current_database));
                }

                return read_type(input, &key, value_type, expiry);
            }
        }
    }
}