// titan-api-codec 1.2.7
//
// Helpers for encoding and decoding Titan API messages.
// Documentation
use bytes::Bytes;
use thiserror::Error;

use titan_api_codec::dec::messagepack::MessagePackDecoder;
use titan_api_codec::enc::messagepack::MessagePackEncoder;
use titan_api_codec::transform;
use titan_api_codec::{
    dec::{DecodeError, Decoder},
    enc::{EncodeError, Encoder},
};
use titan_api_types::ws::v1::ServerMessage;

#[derive(Debug, Error)]
enum TestError {
    #[error("encoding to binary failed")]
    EncodeFailed(#[from] EncodeError),
    #[error("decoding from binary failed")]
    DecodeFailed(#[from] DecodeError),
    #[error("got wrong type, expected `{0}`, got `{1}`")]
    WrongType(&'static str, &'static str),
    #[error("got bad message type, expected Map with key `{0}`, got `{1}`")]
    BadMessageType(&'static str, &'static str),
    #[error("expected map with key `{0}`, but got empty map")]
    NoKeys(&'static str),
    #[error("expected map with key `{0}`, but got map with {1} values")]
    TooManyKeys(&'static str, usize),
    #[error("expected map with key `{0}`, but got map with non-string key {1}")]
    NonKeyString(&'static str, rmpv::Value),
    #[error("expected message type `{0}`, got unknown type `{1}`")]
    UnknownMessageType(&'static str, String),
}

fn error_from_other(value: rmpv::Value, expected: &'static str) -> TestError {
    match value {
        rmpv::Value::Nil => TestError::BadMessageType(expected, "Nil"),
        rmpv::Value::Boolean(_) => TestError::BadMessageType(expected, "Boolean"),
        rmpv::Value::Integer(_) => TestError::BadMessageType(expected, "Integer"),
        rmpv::Value::F32(_) => TestError::BadMessageType(expected, "F32"),
        rmpv::Value::F64(_) => TestError::BadMessageType(expected, "F64"),
        rmpv::Value::String(_) => TestError::BadMessageType(expected, "String"),
        rmpv::Value::Binary(_) => TestError::BadMessageType(expected, "Binary"),
        rmpv::Value::Array(_) => TestError::BadMessageType(expected, "Array"),
        rmpv::Value::Ext(_, _) => TestError::BadMessageType(expected, "Ext"),
        rmpv::Value::Map(values) => {
            if values.is_empty() {
                return TestError::NoKeys(expected);
            }
            if values.len() > 1 {
                return TestError::TooManyKeys(expected, values.len());
            }
            let (key, _) = values.first().unwrap();
            let key_string = match key {
                rmpv::Value::String(str) => str.to_string(),
                other => {
                    return TestError::NonKeyString(expected, other.clone());
                }
            };
            TestError::UnknownMessageType(expected, key_string)
        }
    }
}

#[test]
fn test_roundtrip_streamdata_messagepack_plain() -> Result<(), TestError> {
    let source_data = include_bytes!("data/stream_data.msgpack");
    let source_bytes = Bytes::from_static(source_data);

    let mut encoder = MessagePackEncoder::default();
    let mut decoder = MessagePackDecoder::default();

    let value: ServerMessage = decoder.decode_mut(source_bytes.clone())?;
    match value {
        ServerMessage::StreamData(_) => {}
        ServerMessage::Response(_) => {
            return Err(TestError::WrongType("StreamData", "Response"));
        }
        ServerMessage::Error(_) => {
            return Err(TestError::WrongType("StreamData", "Error"));
        }
        ServerMessage::StreamEnd(_) => {
            return Err(TestError::WrongType("StreamData", "StreamEnd"));
        }
        ServerMessage::Other(value) => {
            return Err(error_from_other(value, "StreamData"));
        }
    }

    let encoded = encoder.encode_mut(&value)?;
    let value2: ServerMessage = decoder.decode_mut(encoded.clone())?;

    // Re-encoded value won't necessarily have fields in the same order, so just make sure that
    // re-encoding and decoding results in the same decoded value.
    //
    // For example, `SwapQuotes` has a HashMap in it, which due to Rust's [`HashMap`] randomly
    // seeding the state for collision resistance, will result in a different iteration order each
    // time.
    assert_eq!(value, value2);

    Ok(())
}

#[test]
fn test_roundtrip_streamdata_messagepack_noop() -> Result<(), TestError> {
    let source_data = include_bytes!("data/stream_data.msgpack");
    let source_bytes = Bytes::from_static(source_data);

    let mut encoder = MessagePackEncoder::default().transform(transform::common::NoOpTransform);
    let mut decoder = MessagePackDecoder::default().transformed(transform::common::NoOpTransform);

    let value: ServerMessage = decoder.decode_mut(source_bytes.clone())?;
    match value {
        ServerMessage::StreamData(_) => {}
        ServerMessage::Response(_) => {
            return Err(TestError::WrongType("StreamData", "Response"));
        }
        ServerMessage::Error(_) => {
            return Err(TestError::WrongType("StreamData", "Error"));
        }
        ServerMessage::StreamEnd(_) => {
            return Err(TestError::WrongType("StreamData", "StreamEnd"));
        }
        ServerMessage::Other(value) => {
            return Err(error_from_other(value, "StreamData"));
        }
    }

    let encoded = encoder.encode_mut(&value)?;
    let value2: ServerMessage = decoder.decode_mut(encoded.clone())?;

    assert_eq!(value, value2);

    Ok(())
}

#[test]
fn test_roundtrip_streamdata_messagepack_zstd() -> Result<(), TestError> {
    // This file was compressed by the `zstd` command line utility.
    let source_data = include_bytes!("data/stream_data.msgpack.zstd");
    let source_bytes = Bytes::from_static(source_data);

    let mut encoder =
        MessagePackEncoder::default().transform(transform::zstd::ZstdCompressor::default());
    let mut decoder =
        MessagePackDecoder::default().transformed(transform::zstd::ZstdDecompressor::default());

    let value: ServerMessage = decoder.decode_mut(source_bytes.clone())?;
    match value {
        ServerMessage::StreamData(_) => {}
        ServerMessage::Response(_) => {
            return Err(TestError::WrongType("StreamData", "Response"));
        }
        ServerMessage::Error(_) => {
            return Err(TestError::WrongType("StreamData", "Error"));
        }
        ServerMessage::StreamEnd(_) => {
            return Err(TestError::WrongType("StreamData", "StreamEnd"));
        }
        ServerMessage::Other(value) => {
            return Err(error_from_other(value, "StreamData"));
        }
    }

    let encoded = encoder.encode_mut(&value)?;
    let value2: ServerMessage = decoder.decode_mut(encoded.clone())?;

    assert_eq!(value, value2);

    Ok(())
}

#[test]
fn test_roundtrip_streamdata_messagepack_brotli() -> Result<(), TestError> {
    // This file was compressed by the `brotli` command line utility.
    let source_data = include_bytes!("data/stream_data.msgpack.brotli");
    let source_bytes = Bytes::from_static(source_data);

    let mut encoder =
        MessagePackEncoder::default().transform(transform::brotli::BrotliCompressor::default());
    let mut decoder =
        MessagePackDecoder::default().transformed(transform::brotli::BrotliDecompressor::default());

    let value: ServerMessage = decoder.decode_mut(source_bytes.clone())?;
    match value {
        ServerMessage::StreamData(_) => {}
        ServerMessage::Response(_) => {
            return Err(TestError::WrongType("StreamData", "Response"));
        }
        ServerMessage::Error(_) => {
            return Err(TestError::WrongType("StreamData", "Error"));
        }
        ServerMessage::StreamEnd(_) => {
            return Err(TestError::WrongType("StreamData", "StreamEnd"));
        }
        ServerMessage::Other(value) => {
            return Err(error_from_other(value, "StreamData"));
        }
    }

    let encoded = encoder.encode_mut(&value)?;
    let value2: ServerMessage = decoder.decode_mut(encoded.clone())?;

    assert_eq!(value, value2);

    Ok(())
}

#[test]
fn test_roundtrip_streamdata_messagepack_gzip() -> Result<(), TestError> {
    // This file was compressed by the `gzip` command line utility.
    let source_data = include_bytes!("data/stream_data.msgpack.gz");
    let source_bytes = Bytes::from_static(source_data);

    let mut encoder =
        MessagePackEncoder::default().transform(transform::gzip::GzipCompressor::default());
    let mut decoder =
        MessagePackDecoder::default().transformed(transform::gzip::GzipDecompressor::default());

    let value: ServerMessage = decoder.decode_mut(source_bytes.clone())?;
    match value {
        ServerMessage::StreamData(_) => {}
        ServerMessage::Response(_) => {
            return Err(TestError::WrongType("StreamData", "Response"));
        }
        ServerMessage::Error(_) => {
            return Err(TestError::WrongType("StreamData", "Error"));
        }
        ServerMessage::StreamEnd(_) => {
            return Err(TestError::WrongType("StreamData", "StreamEnd"));
        }
        ServerMessage::Other(value) => {
            return Err(error_from_other(value, "StreamData"));
        }
    }

    let encoded = encoder.encode_mut(&value)?;
    let value2: ServerMessage = decoder.decode_mut(encoded.clone())?;

    assert_eq!(value, value2);

    Ok(())
}