s2-common 0.31.0

Common types shared by clients and servers for S2, the durable streams API
Documentation
use std::{fmt, str::Utf8Error};

use bytes::{BufMut, Bytes};
use compact_str::CompactString;
use enum_ordinalize::Ordinalize;

use super::{
    Encodable, FencingTokenTooLongError, MeteredSize, RecordDecodeError, fencing::FencingToken,
};
use crate::{deep_size::DeepSize, record::SeqNum};

/// Identifier of the fence command: the value [`CommandOp::to_id`] returns for
/// [`CommandOp::Fence`] and that [`CommandOp::from_id`] maps back to it.
pub const COMMAND_ID_FENCE: &[u8] = b"fence";
/// Identifier of the trim command: the value [`CommandOp::to_id`] returns for
/// [`CommandOp::Trim`] and that [`CommandOp::from_id`] maps back to it.
pub const COMMAND_ID_TRIM: &[u8] = b"trim";

/// The kind of command a [`CommandRecord`] carries.
///
/// The variant ordinal (derived via `Ordinalize`, represented as `u8`) is written as the
/// first byte of an encoded command record and read back on decode. Variant order is
/// therefore part of the encoding: append new variants, never reorder existing ones.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Ordinalize)]
#[repr(u8)]
pub enum CommandOp {
    /// Fence command; its payload is a fencing token.
    Fence,
    /// Trim command; its payload is a sequence number.
    Trim,
}

impl CommandOp {
    pub fn to_id(self) -> &'static [u8] {
        match self {
            Self::Fence => COMMAND_ID_FENCE,
            Self::Trim => COMMAND_ID_TRIM,
        }
    }

    pub fn from_id(name: &[u8]) -> Option<Self> {
        match name {
            COMMAND_ID_FENCE => Some(Self::Fence),
            COMMAND_ID_TRIM => Some(Self::Trim),
            _ => None,
        }
    }
}

impl fmt::Display for CommandOp {
    /// Writes the operation's identifier (for example `fence`).
    ///
    /// Width, fill, alignment and precision flags are honoured, e.g. `{:>8}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The identifiers are byte-string literals. A non-UTF-8 id is reported as a
        // formatting error rather than a panic.
        let name = std::str::from_utf8(self.to_id()).map_err(|_| fmt::Error)?;
        // `pad`, unlike `write_str`, applies the caller's width/fill/alignment/precision,
        // matching how `str` itself implements `Display`.
        f.pad(name)
    }
}

/// A decoded command record, one variant per [`CommandOp`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum CommandRecord {
    /// Fence command carrying a fencing token.
    Fence(FencingToken),
    /// Trim command carrying the earliest sequence number to trim to.
    Trim(SeqNum),
}

impl DeepSize for CommandRecord {
    fn deep_size(&self) -> usize {
        match self {
            Self::Fence(token) => token.deep_size(),
            Self::Trim(seq_num) => seq_num.deep_size(),
        }
    }
}

impl MeteredSize for CommandRecord {
    /// Metered size computed arithmetically, without building the payload `Bytes`.
    ///
    /// The sum is a fixed `8 + 2`, plus the op identifier length, plus the payload length.
    /// NOTE(review): the 8 and 2 presumably mirror per-record and per-header overheads of
    /// the general metering formula; confirm against the `MeteredSize` definition.
    fn metered_size(&self) -> usize {
        let op_id_len = self.op().to_id().len();
        let payload_len = match self {
            Self::Fence(token) => token.len(),
            Self::Trim(trim_point) => size_of_val(trim_point),
        };
        8 + 2 + op_id_len + payload_len
    }
}

impl CommandRecord {
    pub fn op(&self) -> CommandOp {
        match self {
            CommandRecord::Fence(_) => CommandOp::Fence,
            CommandRecord::Trim(_) => CommandOp::Trim,
        }
    }

    pub fn payload(&self) -> Bytes {
        match self {
            Self::Fence(token) => Bytes::copy_from_slice(token.as_bytes()),
            Self::Trim(trim_point) => Bytes::copy_from_slice(&trim_point.to_be_bytes()),
        }
    }

    pub fn try_from_parts(op: CommandOp, payload: &[u8]) -> Result<Self, CommandPayloadError> {
        match op {
            CommandOp::Fence => {
                let token = CompactString::from_utf8(payload)
                    .map_err(CommandPayloadError::InvalidUtf8)?
                    .try_into()?;
                Ok(Self::Fence(token))
            }
            CommandOp::Trim => {
                let trim_point = SeqNum::from_be_bytes(
                    payload
                        .try_into()
                        .map_err(|_| CommandPayloadError::TrimPointSize(payload.len()))?,
                );
                Ok(Self::Trim(trim_point))
            }
        }
    }
}

impl TryFrom<&[u8]> for CommandRecord {
    type Error = RecordDecodeError;

    /// Decodes `[ordinal byte][op-specific payload]` into a command.
    ///
    /// An empty input is reported as truncated, and an unknown ordinal as an invalid value.
    /// Payload errors are converted through `From<CommandPayloadError>`.
    fn try_from(record: &[u8]) -> Result<Self, Self::Error> {
        let Some((&ordinal, payload)) = record.split_first() else {
            return Err(RecordDecodeError::Truncated("CommandOrdinal"));
        };
        let Some(op) = CommandOp::from_ordinal(ordinal) else {
            return Err(RecordDecodeError::InvalidValue("CommandOrdinal", "unknown"));
        };
        Ok(Self::try_from_parts(op, payload)?)
    }
}

impl Encodable for CommandRecord {
    /// One ordinal byte plus the payload length.
    fn encoded_size(&self) -> usize {
        let payload_len = match self {
            Self::Fence(token) => token.len(),
            Self::Trim(trim_point) => size_of_val(trim_point),
        };
        payload_len + 1
    }

    /// Writes the op ordinal, then the payload.
    ///
    /// The payload is the token bytes for a fence, or the trim point as a big-endian `u64`.
    fn encode_into(&self, buf: &mut impl BufMut) {
        let ordinal = self.op().ordinal();
        buf.put_u8(ordinal);
        match *self {
            Self::Fence(ref token) => buf.put_slice(token.as_bytes()),
            Self::Trim(trim_point) => buf.put_u64(trim_point),
        }
    }
}

#[derive(Debug, PartialEq, thiserror::Error)]
pub enum CommandPayloadError {
    #[error("invalid UTF-8")]
    InvalidUtf8(Utf8Error),
    #[error(transparent)]
    FencingTokenTooLong(#[from] FencingTokenTooLongError),
    #[error("earliest sequence number to trim to was {0} bytes, must be 8")]
    TrimPointSize(usize),
}

impl From<CommandPayloadError> for RecordDecodeError {
    /// Maps each payload error to an `InvalidValue("CommandPayload", reason)` decode error.
    fn from(err: CommandPayloadError) -> Self {
        let reason = match err {
            CommandPayloadError::InvalidUtf8(_) => "fencing token not valid utf8",
            CommandPayloadError::FencingTokenTooLong(_) => "fencing token too long",
            CommandPayloadError::TrimPointSize(_) => "trim point size",
        };
        RecordDecodeError::InvalidValue("CommandPayload", reason)
    }
}

#[cfg(test)]
mod tests {
    use compact_str::ToCompactString;
    use enum_ordinalize::Ordinalize;
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// A 40-byte token: long enough that fencing-token conversion rejects it.
    const TOO_LONG_TOKEN: &[u8] = b"0123456789012345678901234567890123456789";

    /// Encodes `cmd`, checks the encoded length against `expected_len`, and decodes it back.
    fn roundtrip(cmd: CommandRecord, expected_len: usize) {
        assert_eq!(cmd.encoded_size(), expected_len);
        let encoded = cmd.to_bytes();
        assert_eq!(encoded.len(), expected_len);
        assert_eq!(CommandRecord::try_from(encoded.as_ref()), Ok(cmd));
    }

    #[test]
    fn command_op_names() {
        for cmd in CommandOp::VARIANTS {
            let name = cmd.to_id();
            assert_eq!(CommandOp::from_id(name), Some(*cmd));
        }
        assert_eq!(CommandOp::from_id(b""), None);
        assert_eq!(CommandOp::from_id(b"invalid"), None);
    }

    #[test]
    fn command_op_display_matches_id() {
        for op in CommandOp::VARIANTS {
            assert_eq!(op.to_string().as_bytes(), op.to_id());
        }
    }

    #[test]
    fn fencing_token_invalid_utf8() {
        assert!(matches!(
            CommandRecord::try_from_parts(CommandOp::Fence, &[0xff]),
            Err(CommandPayloadError::InvalidUtf8(_))
        ));
    }

    #[test]
    fn fencing_token_too_long() {
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, TOO_LONG_TOKEN),
            Err(CommandPayloadError::FencingTokenTooLong(
                FencingTokenTooLongError(40)
            ))
        );
    }

    #[rstest]
    #[case::empty("")]
    #[case::arbit("arbitrary")]
    #[case::full("0123456789012345")]
    fn fence_roundtrip(#[case] token: &str) {
        let cmd = CommandRecord::Fence(FencingToken::try_from(token.to_compact_string()).unwrap());
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, token.as_bytes()),
            Ok(cmd.clone())
        );
        roundtrip(cmd, 1 + token.len());
    }

    #[rstest]
    #[case::empty(b"")]
    #[case::too_small(b"0123")]
    #[case::too_big(b"0123456789")]
    fn trim_point_size(#[case] payload: &[u8]) {
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Trim, payload),
            Err(CommandPayloadError::TrimPointSize(payload.len()))
        );
    }

    #[test]
    fn metered_size_is_computed_without_materializing_payload() {
        let fence =
            CommandRecord::Fence(FencingToken::try_from("fence-me".to_compact_string()).unwrap());
        assert_eq!(
            fence.metered_size(),
            8 + 2 + CommandOp::Fence.to_id().len() + "fence-me".len()
        );

        let trim = CommandRecord::Trim(42);
        assert_eq!(
            trim.metered_size(),
            8 + 2 + CommandOp::Trim.to_id().len() + size_of_val(&42u64)
        );
    }

    #[test]
    fn payload_matches_encoding() {
        let commands = [
            CommandRecord::Fence(FencingToken::try_from("abc".to_compact_string()).unwrap()),
            CommandRecord::Trim(7),
        ];
        for cmd in commands {
            let encoded = cmd.to_bytes();
            // Encoding is the op ordinal followed by exactly the payload bytes.
            assert_eq!(encoded[0], cmd.op().ordinal());
            assert_eq!(&encoded[1..], &cmd.payload()[..]);
            assert_eq!(
                CommandRecord::try_from_parts(cmd.op(), &cmd.payload()),
                Ok(cmd.clone())
            );
        }
    }

    proptest! {
        #[test]
        fn trim_roundtrip(trim_point in any::<SeqNum>()) {
            let cmd = CommandRecord::Trim(trim_point);
            assert_eq!(
                CommandRecord::try_from_parts(CommandOp::Trim, trim_point.to_be_bytes().as_slice()),
                Ok(cmd.clone())
            );
            roundtrip(cmd, 9);
        }
    }

    #[test]
    fn decode_invalid_command() {
        let try_convert = |raw: &[u8]| CommandRecord::try_from(raw);
        assert_eq!(
            try_convert(&[]),
            Err(RecordDecodeError::Truncated("CommandOrdinal"))
        );
        assert_eq!(
            try_convert(&[0xff]),
            Err(RecordDecodeError::InvalidValue("CommandOrdinal", "unknown"))
        );
        assert_eq!(
            try_convert(&[CommandOp::Fence.ordinal(), 0xff, 0xff]),
            Err(RecordDecodeError::InvalidValue(
                "CommandPayload",
                "fencing token not valid utf8"
            ))
        );

        // Fence ordinal followed by an over-long token.
        let mut too_long = vec![CommandOp::Fence.ordinal()];
        too_long.extend_from_slice(TOO_LONG_TOKEN);
        assert_eq!(
            try_convert(too_long.as_slice()),
            Err(CommandPayloadError::FencingTokenTooLong(FencingTokenTooLongError(40)).into())
        );

        assert_eq!(
            try_convert(&[CommandOp::Trim.ordinal(), 0xff]),
            Err(CommandPayloadError::TrimPointSize(1).into())
        );
    }
}