dyn-encoding 0.0.1

Wire-format codec abstraction (protobuf, JSON, CBOR; flatbuffers/capnp/bebop/BSON to follow) for the Riak protocol layer
Documentation
//! `application/x-protobuf` codec backed by `prost`.
//!
//! Unlike the JSON and CBOR codecs, the protobuf codec does not lean
//! on `serde`. Each registered message type provides its own
//! `prost::Message` impl (typically generated by `prost-build` from a
//! `.proto` file, or hand-derived for small fixtures). The encoder
//! calls `Message::encode_to_vec`; the decoder calls
//! `Message::decode`.

use std::collections::HashMap;

use prost::Message;

use crate::error::CodecError;
use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};

/// Type-erased encoder: takes a borrowed erased value and returns its
/// encoded bytes, or a [`CodecError`] on failure. The `Send + Sync`
/// bounds let the boxed closure be shared across threads.
type EncodeFn =
    Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
/// Type-erased decoder: takes a byte slice and returns a boxed erased
/// value, or a [`CodecError`] on failure. The `Send + Sync` bounds let
/// the boxed closure be shared across threads.
type DecodeFn =
    Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;

/// Codec that serialises [`WireValue`] types with `prost`, using
/// protobuf's plain message encoding (no length prefix is wrapped
/// around the encoded message).
///
/// The codec keeps one type-erased encoder and one type-erased decoder
/// per registered [`WireTypeId`]; values whose id has no entry are
/// rejected at encode/decode time.
#[derive(Default)]
pub struct ProtobufCodec {
    /// Encoders keyed by the wire type id of the value they accept.
    encoders: HashMap<WireTypeId, EncodeFn>,
    /// Decoders keyed by the wire type id of the value they produce.
    decoders: HashMap<WireTypeId, DecodeFn>,
}

// `Debug` cannot be derived because the maps hold boxed closures, which
// are not `Debug`. Report only how many entries each map holds so the
// codec can still appear in `{:?}` output and in `Debug`-deriving parents.
impl std::fmt::Debug for ProtobufCodec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProtobufCodec")
            .field("registered_encoders", &self.encoders.len())
            .field("registered_decoders", &self.decoders.len())
            .finish()
    }
}

impl ProtobufCodec {
    /// Construct an empty protobuf codec with no registered types.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Register a [`WireValue`] type with the codec. The type must
    /// implement `prost::Message + Default` (the standard bound for
    /// `prost`-generated message types).
    pub fn register<T>(&mut self) -> &mut Self
    where
        T: WireValue + Message + Default,
    {
        let id = T::wire_type_id();
        self.encoders.insert(
            id,
            Box::new(
                move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
                    let concrete = v
                        .as_any()
                        .downcast_ref::<T>()
                        .ok_or(CodecError::TypeMismatch { expected: id })?;
                    Ok(concrete.encode_to_vec())
                },
            ),
        );
        self.decoders.insert(
            id,
            Box::new(
                move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
                    let value = T::decode(bytes).map_err(CodecError::decode_failure)?;
                    Ok(Box::new(value))
                },
            ),
        );
        self
    }

    /// Number of message types registered with this codec.
    #[must_use]
    pub fn registered_type_count(&self) -> usize {
        self.encoders.len()
    }
}

impl WireCodec for ProtobufCodec {
    /// Content type string identifying this codec's wire format.
    fn content_type(&self) -> &'static str {
        "application/x-protobuf"
    }

    /// Encode `value` with the encoder registered for its type id.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::UnknownTypeId`] when no encoder is
    /// registered for the value's id; otherwise returns whatever the
    /// registered encoder returns.
    fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
        let wire_id = value.type_id();
        match self.encoders.get(&wire_id) {
            Some(encode_fn) => encode_fn(value),
            None => Err(CodecError::UnknownTypeId(wire_id)),
        }
    }

    /// Decode `bytes` with the decoder registered under `type_id`.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::UnknownTypeId`] when no decoder is
    /// registered for `type_id`; otherwise returns whatever the
    /// registered decoder returns.
    fn decode(
        &self,
        type_id: WireTypeId,
        bytes: &[u8],
    ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
        match self.decoders.get(&type_id) {
            Some(decode_fn) => decode_fn(bytes),
            None => Err(CodecError::UnknownTypeId(type_id)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message;

    /// Hand-derived protobuf message used purely as a test fixture.
    /// Field tags are stable so the binary layout is deterministic
    /// across test runs. The shape is deliberately tiny: a string
    /// (tag 1), an integer (tag 2), and a Vec<u8> (tag 3) -- the
    /// three field flavours the brief calls for.
    ///
    /// `Default` is generated by `prost::Message`; deriving it again
    /// would conflict.
    #[derive(Clone, Eq, PartialEq, Message)]
    struct Sample {
        #[prost(string, tag = "1")]
        name: String,
        #[prost(uint32, tag = "2")]
        seq: u32,
        #[prost(bytes = "vec", tag = "3")]
        payload: Vec<u8>,
    }

    impl WireValue for Sample {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.protobuf.Sample")
        }
    }

    /// Build the populated message shared by the tests below.
    fn fixture() -> Sample {
        Sample {
            name: "gamma".into(),
            seq: 99,
            payload: vec![0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef],
        }
    }

    /// Encoding then decoding yields a value equal to the input.
    #[test]
    fn round_trip_recovers_value() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let bytes = codec.encode(&v).expect("encode");
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        let back = back.as_any().downcast_ref::<Sample>().expect("downcast");
        assert_eq!(back, &v);
    }

    /// Repeated encodes, and a re-encode of the decoded value, all
    /// produce the same bytes.
    #[test]
    fn idempotent_encode_is_byte_equal() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let a = codec.encode(&v).expect("encode 1");
        let b = codec.encode(&v).expect("encode 2");
        assert_eq!(a, b);
        let back = codec.decode(Sample::wire_type_id(), &a).expect("decode");
        let c = codec.encode(back.as_ref()).expect("encode 3");
        assert_eq!(a, c);
    }

    /// The codec adds no framing: its output equals `prost`'s own
    /// `encode_to_vec` output for the same message.
    #[test]
    fn produced_bytes_match_prost_native_encoding() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let through_codec = codec.encode(&v).expect("encode");
        let direct = v.encode_to_vec();
        assert_eq!(through_codec, direct);
    }

    /// Encoding a value whose type was never registered is rejected.
    #[test]
    fn unregistered_type_returns_unknown_type_id() {
        let codec = ProtobufCodec::new();
        let v = fixture();
        let err = codec.encode(&v).expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    /// Decoding under a type id that was never registered is rejected,
    /// mirroring the encode-side check.
    #[test]
    fn decode_of_unregistered_type_returns_unknown_type_id() {
        let codec = ProtobufCodec::new();
        let err = codec
            .decode(Sample::wire_type_id(), &[])
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn malformed_bytes_yield_decode_failure() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        // 0x0a is the key for field 1 with wire type 2
        // (length-delimited). The following 0xff has its varint
        // continuation bit set but is the last byte of the buffer, so
        // the length prefix itself is truncated.
        let err = codec
            .decode(Sample::wire_type_id(), &[0x0a, 0xff])
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }

    #[test]
    fn overlong_length_prefix_yields_decode_failure() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        // Field 1 (length-delimited) declares five payload bytes but
        // only one follows, so the length runs past the end of the
        // buffer.
        let err = codec
            .decode(Sample::wire_type_id(), &[0x0a, 0x05, 0x61])
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }

    /// Registering the same type again replaces its entry rather than
    /// adding a second one.
    #[test]
    fn registering_same_type_twice_keeps_single_entry() {
        let mut codec = ProtobufCodec::new();
        assert_eq!(codec.registered_type_count(), 0);
        codec.register::<Sample>().register::<Sample>();
        assert_eq!(codec.registered_type_count(), 1);
    }
}