use std::fmt;
use std::fmt::Formatter;
use std::io;
use std::io::Cursor;
use std::io::Read;
use byteorder::BigEndian;
use byteorder::ReadBytesExt;
use byteorder::WriteBytesExt;
use codeq::Encode;
use codeq::config::CodeqConfig;
use crate::WalTypes;
use crate::types::Checksum;
/// Record type id reserved for checkpoint records.
///
/// `WALRecord::encode` writes this value as a big-endian `u32` in front of the
/// checkpoint payload, `WALRecord::decode` compares the leading four bytes
/// against it to tell checkpoints from actions, and encoding an action whose
/// own type id equals this value is rejected.
pub const CHECKPOINT_RECORD_TYPE: u32 = 5;
/// A single WAL record: either an action or a checkpoint.
///
/// Both payload types are supplied by the `WalTypes` implementation `W`.
#[derive(Clone, PartialEq, Eq)]
pub enum WALRecord<W>
where W: WalTypes
{
/// An action of type `W::Action`; its bytes are produced by the action's own codec.
Action(W::Action),
/// A checkpoint of type `W::Checkpoint`; encoded with a leading
/// `CHECKPOINT_RECORD_TYPE` and a trailing checksum.
Checkpoint(W::Checkpoint),
}
impl<W> fmt::Debug for WALRecord<W>
where W: WalTypes
{
    /// Formats the record for diagnostics.
    ///
    /// An action is rendered exactly as its own `Debug` output, without any
    /// wrapper. A checkpoint is rendered as a tuple struct labelled `State`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let state = match self {
            // Actions are transparent: hand the formatter straight to the payload.
            Self::Action(inner) => return fmt::Debug::fmt(inner, f),
            Self::Checkpoint(state) => state,
        };

        let mut tuple = f.debug_tuple("State");
        tuple.field(state);
        tuple.finish()
    }
}
impl<W> fmt::Display for WALRecord<W>
where
    W: WalTypes,
    W::Action: fmt::Display,
    W::Checkpoint: fmt::Display,
{
    /// Formats the record for human-readable output.
    ///
    /// An action is displayed exactly as its own `Display` output; a checkpoint
    /// is displayed as `Checkpoint(<checkpoint>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let state = match self {
            // Actions are forwarded with the caller's formatter untouched.
            Self::Action(inner) => return fmt::Display::fmt(inner, f),
            Self::Checkpoint(state) => state,
        };

        f.write_fmt(format_args!("Checkpoint({})", state))
    }
}
impl<W> codeq::Encode for WALRecord<W>
where W: WalTypes
{
fn encode<Wt: io::Write>(&self, mut w: Wt) -> Result<usize, io::Error> {
match self {
WALRecord::Action(action) => {
let type_id = action.type_id().ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"action encoding does not provide a leading type id",
)
})?;
if type_id == CHECKPOINT_RECORD_TYPE {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"action type id {} conflicts with checkpoint",
CHECKPOINT_RECORD_TYPE
),
));
}
action.encode(&mut w)
}
WALRecord::Checkpoint(checkpoint) => {
let mut n = 0;
let mut cw = Checksum::new_writer(&mut w);
cw.write_u32::<BigEndian>(CHECKPOINT_RECORD_TYPE)?;
n += 4;
n += checkpoint.encode(&mut cw)?;
n += cw.write_checksum()?;
Ok(n)
}
}
}
}
impl<W> codeq::Decode for WALRecord<W>
where W: WalTypes
{
    /// Reads one record from `r`.
    ///
    /// The first four bytes are read as a big-endian `u32` record type and then
    /// replayed in front of the remaining input, so the payload decoder sees
    /// the complete record:
    ///
    /// - `CHECKPOINT_RECORD_TYPE`: the record is read through a checksum
    ///   reader (type id, then checkpoint payload) and the checksum is verified.
    /// - anything else: the bytes are decoded as `W::Action`, and the decoded
    ///   action's `type_id()` must equal the record type that was read.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when a decoded action reports no type id or a
    /// different one than was read. Errors from the reader, the payload
    /// codecs and checksum verification are propagated.
    fn decode<R: io::Read>(mut r: R) -> Result<Self, io::Error> {
        let mut header = [0u8; 4];
        r.read_exact(&mut header)?;
        let encoded_id = u32::from_be_bytes(header);

        // Put the consumed header back in front of the rest of the input.
        let mut replayed = Cursor::new(header).chain(r);

        if encoded_id == CHECKPOINT_RECORD_TYPE {
            let mut cr = Checksum::new_reader(replayed);
            // Re-read the type id through the checksum reader.
            cr.read_u32::<BigEndian>()?;
            let state = W::Checkpoint::decode(&mut cr)?;
            let rec = Self::Checkpoint(state);
            cr.verify_checksum(|| "Record::decode()")?;
            return Ok(rec);
        }

        let action = W::Action::decode(&mut replayed)?;
        match action.type_id() {
            Some(decoded_id) if decoded_id == encoded_id => Ok(Self::Action(action)),
            Some(decoded_id) => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "action type id mismatch: encoded {}, decoded {}",
                    encoded_id, decoded_id
                ),
            )),
            None => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "decoded action does not provide a leading type id",
            )),
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `WALRecord` formatting and its `Encode`/`Decode` impls.

    use std::fmt;
    use std::io;
    use std::sync::mpsc::SyncSender;

    use codeq::Decode;
    use codeq::Encode;

    use crate::WalTypes;
    use crate::wal::wal_record::CHECKPOINT_RECORD_TYPE;
    use crate::wal::wal_record::WALRecord;

    /// Type id that `TestAction` writes first and reports from `type_id()`.
    const TEST_ACTION_TYPE: u32 = 1;

    /// String-backed action whose encoding starts with `TEST_ACTION_TYPE`.
    #[derive(Clone, PartialEq, Eq)]
    struct TestAction(String);

    impl fmt::Debug for TestAction {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let TestAction(text) = self;
            fmt::Debug::fmt(text, f)
        }
    }

    impl fmt::Display for TestAction {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let TestAction(text) = self;
            fmt::Display::fmt(text, f)
        }
    }

    impl Encode for TestAction {
        fn encode<Wt: io::Write>(&self, mut w: Wt) -> Result<usize, io::Error> {
            let head_len = TEST_ACTION_TYPE.encode(&mut w)?;
            let body_len = self.0.encode(&mut w)?;
            Ok(head_len + body_len)
        }

        fn type_id(&self) -> Option<u32> {
            Some(TEST_ACTION_TYPE)
        }
    }

    impl Decode for TestAction {
        fn decode<R: io::Read>(mut r: R) -> Result<Self, io::Error> {
            match u32::decode(&mut r)? {
                TEST_ACTION_TYPE => String::decode(&mut r).map(TestAction),
                other => Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("unexpected action type id {}", other),
                )),
            }
        }
    }

    /// WAL types whose action carries a leading type id.
    #[derive(Debug, Default, Clone, PartialEq, Eq)]
    struct TestWal;

    impl WalTypes for TestWal {
        type Action = TestAction;
        type Checkpoint = String;
        type Callback = SyncSender<Result<(), io::Error>>;
    }

    /// WAL types whose action is a plain `String`; used to exercise the
    /// "no type id" encode error.
    #[derive(Debug, Default, Clone, PartialEq, Eq)]
    struct NoTypeWal;

    impl WalTypes for NoTypeWal {
        type Action = String;
        type Checkpoint = String;
        type Callback = SyncSender<Result<(), io::Error>>;
    }

    /// Builds an action record carrying `v`.
    fn action(v: &str) -> WALRecord<TestWal> {
        WALRecord::Action(TestAction(String::from(v)))
    }

    /// Builds a checkpoint record carrying `v`.
    fn checkpoint(v: &str) -> WALRecord<TestWal> {
        WALRecord::Checkpoint(String::from(v))
    }

    /// Expected encoding of `checkpoint("state")`.
    fn checkpoint_state_bytes() -> Vec<u8> {
        let mut bytes: Vec<u8> = Vec::with_capacity(21);
        // Record type: CHECKPOINT_RECORD_TYPE as a big-endian u32.
        bytes.extend_from_slice(&[0, 0, 0, 5]);
        // String payload: length prefix followed by the UTF-8 bytes of "state".
        bytes.extend_from_slice(&[0, 0, 0, 5]);
        bytes.extend_from_slice(b"state");
        // Trailing checksum.
        bytes.extend_from_slice(&[0, 0, 0, 0, 220, 33, 57, 147]);
        bytes
    }

    #[test]
    fn test_action_debug_display_and_clone() {
        let rec = action("vote");
        let copy = rec.clone();

        assert_eq!(format!("{:?}", rec), "\"vote\"");
        assert_eq!(format!("{}", rec), "vote");
        assert_eq!(copy, rec);
    }

    #[test]
    fn test_checkpoint_debug_display_and_clone() {
        let rec = checkpoint("state");
        let copy = rec.clone();

        assert_eq!(format!("{:?}", rec), "State(\"state\")");
        assert_eq!(format!("{}", rec), "Checkpoint(state)");
        assert_eq!(copy, rec);
    }

    #[test]
    fn test_encode_action_delegates_to_action_codec() -> Result<(), io::Error> {
        let expected: Vec<u8> = vec![0, 0, 0, 1, 0, 0, 0, 4, 118, 111, 116, 101];

        let mut buf = Vec::<u8>::new();
        let written = action("vote").encode(&mut buf)?;

        assert_eq!(written, buf.len());
        assert_eq!(buf, expected);
        Ok(())
    }

    #[test]
    fn test_encode_action_requires_type_id() {
        let rec = WALRecord::<NoTypeWal>::Action(String::from("vote"));

        let mut sink = Vec::<u8>::new();
        let err = rec.encode(&mut sink).unwrap_err();

        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        assert!(err.to_string().contains("does not provide"));
    }

    #[test]
    fn test_encode_checkpoint_adds_type_and_checksum() -> Result<(), io::Error> {
        let mut buf = Vec::<u8>::new();
        let written = checkpoint("state").encode(&mut buf)?;

        assert_eq!(CHECKPOINT_RECORD_TYPE, 5);
        assert_eq!(written, buf.len());
        assert_eq!(buf, checkpoint_state_bytes());
        Ok(())
    }

    #[test]
    fn test_decode_action_replays_record_type_bytes() -> Result<(), io::Error> {
        let names = ["vote", "log"];

        // Encode the actions back to back into one stream.
        let mut stream = Vec::<u8>::new();
        for name in &names {
            action(name).encode(&mut stream)?;
        }

        // Decoding must consume exactly one record at a time.
        let mut rest = stream.as_slice();
        for name in &names {
            assert_eq!(WALRecord::<TestWal>::decode(&mut rest)?, action(name));
        }
        assert!(rest.is_empty());
        Ok(())
    }

    #[test]
    fn test_decode_checkpoint_verifies_checksum() -> Result<(), io::Error> {
        let bytes = checkpoint_state_bytes();
        let mut input = &bytes[..];

        let got = WALRecord::<TestWal>::decode(&mut input)?;

        assert_eq!(got, checkpoint("state"));
        Ok(())
    }

    #[test]
    fn test_decode_checkpoint_rejects_bad_checksum() {
        let mut bytes = checkpoint_state_bytes();
        // Flip one bit in the final checksum byte.
        let tail = bytes.len() - 1;
        bytes[tail] ^= 1;

        let mut input = &bytes[..];
        let err = WALRecord::<TestWal>::decode(&mut input)
            .expect_err("corrupted checkpoint checksum must fail");

        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Record::decode()"));
    }

    #[test]
    fn test_decode_rejects_short_record_type() {
        // Only three of the four record-type bytes are available.
        let short: [u8; 3] = [0, 0, 0];
        let mut input = &short[..];

        let err = WALRecord::<TestWal>::decode(&mut input)
            .expect_err("short record type must fail");

        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }
}