use std::error::Error;
use std::fmt;
use std::io;
#[cfg(feature = "cbor")]
use serde::{de::DeserializeOwned, Deserialize, Serialize};
/// Error returned by the CBOR and Zstd wire helpers in this module.
///
/// Each variant only exists when its codec feature is enabled; with neither
/// `cbor` nor `zstd` enabled this enum has no variants at all.
#[derive(Debug)]
pub enum WireError {
/// CBOR serialization failed (produced by `to_cbor` and `to_cbor_zstd`).
#[cfg(feature = "cbor")]
CborEncode(ciborium::ser::Error<io::Error>),
/// CBOR deserialization failed (produced by `from_cbor` and `from_cbor_zstd`).
#[cfg(feature = "cbor")]
CborDecode(ciborium::de::Error<io::Error>),
/// Zstd compression or decompression failed (produced by `zstd_compress`,
/// `zstd_decompress`, and the combined CBOR+Zstd helpers).
#[cfg(feature = "zstd")]
Zstd(io::Error),
}
/// Formats a short, fixed description of which codec step failed.
///
/// The wrapped codec error is not included in the message; it is returned
/// from `Error::source` instead, so chain-walking reporters print it once.
impl fmt::Display for WireError {
    // With no codec feature enabled there are no variants to format, so
    // `formatter` is unused in that configuration (and only in that one).
    #[cfg_attr(
        not(any(feature = "cbor", feature = "zstd")),
        allow(unused_variables)
    )]
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Match on `*self`, not `self`: when no feature is enabled `WireError`
        // has no variants, and an empty `match` is only accepted on the enum
        // value itself. A reference to an empty enum is not treated as
        // uninhabited, so `match self {}` fails with E0004.
        match *self {
            #[cfg(feature = "cbor")]
            Self::CborEncode(_) => formatter.write_str("CBOR encoding failed"),
            #[cfg(feature = "cbor")]
            Self::CborDecode(_) => formatter.write_str("CBOR decoding failed"),
            #[cfg(feature = "zstd")]
            Self::Zstd(_) => formatter.write_str("Zstd operation failed"),
        }
    }
}
/// Exposes the wrapped codec error as the error source.
impl Error for WireError {
    /// Returns the underlying ciborium or I/O error carried by the variant.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Match on `*self` (binding fields with `ref`) rather than `self`.
        // When neither `cbor` nor `zstd` is enabled `WireError` has no
        // variants, and an empty `match` on `&WireError` is rejected (E0004)
        // because references to empty enums are not treated as uninhabited.
        match *self {
            #[cfg(feature = "cbor")]
            Self::CborEncode(ref error) => Some(error),
            #[cfg(feature = "cbor")]
            Self::CborDecode(ref error) => Some(error),
            #[cfg(feature = "zstd")]
            Self::Zstd(ref error) => Some(error),
        }
    }
}
/// A single keyed value stamped with an end time, in its serializable form.
///
/// Keep the field declaration order as-is: serde's derive emits fields in
/// declaration order, so reordering changes the encoded bytes.
#[cfg(feature = "cbor")]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WireDeltaPoint<K, V> {
/// Key the value belongs to (the tests use a series name such as
/// `"series-a"`; presumably a series identifier — confirm with callers).
pub key: K,
/// End timestamp of the point. NOTE(review): the unit (seconds vs.
/// milliseconds, epoch-based or relative) is not established here — confirm.
pub end_time: u64,
/// The carried value.
pub value: V,
}
/// A keyed value covering the span `start`..`end`, in its serializable form.
///
/// Keep the field declaration order as-is: serde's derive emits fields in
/// declaration order, so reordering changes the encoded bytes.
#[cfg(feature = "cbor")]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WireFinalizedBucket<K, V> {
/// Key the bucket belongs to (presumably a series identifier — confirm).
pub key: K,
/// Start of the bucket span. NOTE(review): unit not established here.
pub start: u64,
/// End of the bucket span. NOTE(review): whether `end` is inclusive or
/// exclusive is not established here — confirm with the producer.
pub end: u64,
/// The aggregated or carried value for this bucket.
pub value: V,
}
/// Serializes `value` into a newly allocated CBOR byte buffer.
///
/// # Errors
///
/// Returns [`WireError::CborEncode`] when ciborium cannot serialize `value`.
#[cfg(feature = "cbor")]
pub fn to_cbor<T>(value: &T) -> Result<Vec<u8>, WireError>
where
    T: Serialize,
{
    let mut buffer: Vec<u8> = Vec::new();
    match ciborium::ser::into_writer(value, &mut buffer) {
        Ok(()) => Ok(buffer),
        Err(encode_error) => Err(WireError::CborEncode(encode_error)),
    }
}
/// Deserializes a `T` from the CBOR data in `bytes`.
///
/// NOTE(review): nothing here checks that `bytes` is fully consumed after the
/// value is read; confirm whether trailing data should be rejected.
///
/// # Errors
///
/// Returns [`WireError::CborDecode`] when the input is truncated, malformed,
/// or does not match the shape of `T`.
#[cfg(feature = "cbor")]
pub fn from_cbor<T>(bytes: &[u8]) -> Result<T, WireError>
where
    T: DeserializeOwned,
{
    match ciborium::de::from_reader(bytes) {
        Ok(decoded) => Ok(decoded),
        Err(decode_error) => Err(WireError::CborDecode(decode_error)),
    }
}
/// Compresses `bytes` into a single Zstd stream at the given `level`.
///
/// `level` is passed through unchanged to `zstd::stream::encode_all`; see the
/// zstd crate documentation for the accepted range (`0` selects its default).
///
/// # Errors
///
/// Returns [`WireError::Zstd`] when the encoder reports an I/O error.
#[cfg(feature = "zstd")]
pub fn zstd_compress(bytes: &[u8], level: i32) -> Result<Vec<u8>, WireError> {
    match zstd::stream::encode_all(bytes, level) {
        Ok(compressed) => Ok(compressed),
        Err(io_error) => Err(WireError::Zstd(io_error)),
    }
}
/// Decompresses a Zstd stream held entirely in `bytes`.
///
/// The whole decompressed output is collected into memory, and no limit on
/// its size is applied here. NOTE(review): if `bytes` can come from an
/// untrusted peer, consider a size-bounded decoder to avoid decompression bombs.
///
/// # Errors
///
/// Returns [`WireError::Zstd`] when `bytes` is not a valid Zstd stream or
/// decoding otherwise fails.
#[cfg(feature = "zstd")]
pub fn zstd_decompress(bytes: &[u8]) -> Result<Vec<u8>, WireError> {
    let decompressed = zstd::stream::decode_all(bytes).map_err(WireError::Zstd)?;
    Ok(decompressed)
}
/// Encodes `value` as CBOR, then compresses the result with Zstd at `level`.
///
/// # Errors
///
/// Returns [`WireError::CborEncode`] if serialization fails, or
/// [`WireError::Zstd`] if compression fails.
#[cfg(all(feature = "cbor", feature = "zstd"))]
pub fn to_cbor_zstd<T>(value: &T, level: i32) -> Result<Vec<u8>, WireError>
where
    T: Serialize,
{
    let cbor_bytes = to_cbor(value)?;
    zstd_compress(&cbor_bytes, level)
}
/// Decompresses a Zstd stream and decodes the contained CBOR into a `T`.
///
/// # Errors
///
/// Returns [`WireError::Zstd`] if decompression fails, or
/// [`WireError::CborDecode`] if the decompressed bytes are not a valid `T`.
#[cfg(all(feature = "cbor", feature = "zstd"))]
pub fn from_cbor_zstd<T>(bytes: &[u8]) -> Result<T, WireError>
where
    T: DeserializeOwned,
{
    let cbor_bytes = zstd_decompress(bytes)?;
    from_cbor(&cbor_bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `WireError` must stay usable as a thread-safe boxed error
    /// (`Box<dyn Error + Send + Sync>`). This test has no feature gate, so it
    /// also exercises the variant-less configuration.
    #[test]
    fn wire_error_is_a_thread_safe_std_error() {
        fn assert_error<E: Error + Send + Sync + 'static>() {}
        assert_error::<WireError>();
    }

    #[cfg(feature = "cbor")]
    #[test]
    fn cbor_round_trips_delta_points() {
        let point = WireDeltaPoint {
            key: "series-a",
            end_time: 60,
            value: 42_u64,
        };
        let encoded = to_cbor(&point).unwrap();
        let decoded: WireDeltaPoint<String, u64> = from_cbor(&encoded).unwrap();
        assert_eq!(decoded.key, point.key);
        assert_eq!(decoded.end_time, point.end_time);
        assert_eq!(decoded.value, point.value);
    }

    /// Truncated (here: empty) CBOR input must map to `CborDecode`, use the
    /// fixed display text, and keep the codec error reachable via `source()`.
    #[cfg(feature = "cbor")]
    #[test]
    fn cbor_decode_of_empty_input_reports_source() {
        let error = from_cbor::<u64>(&[]).unwrap_err();
        assert!(matches!(error, WireError::CborDecode(_)));
        assert_eq!(error.to_string(), "CBOR decoding failed");
        assert!(error.source().is_some());
    }

    /// Input that is not a Zstd frame must map to `Zstd`, use the fixed
    /// display text, and keep the I/O error reachable via `source()`.
    #[cfg(feature = "zstd")]
    #[test]
    fn zstd_decompress_rejects_non_zstd_input() {
        let error = zstd_decompress(b"definitely not a zstd frame").unwrap_err();
        assert!(matches!(error, WireError::Zstd(_)));
        assert_eq!(error.to_string(), "Zstd operation failed");
        assert!(error.source().is_some());
    }

    #[cfg(all(feature = "cbor", feature = "zstd"))]
    #[test]
    fn cbor_zstd_round_trips_finalized_buckets() {
        let bucket = WireFinalizedBucket {
            key: "series-a",
            start: 0,
            end: 60,
            value: 42_u64,
        };
        let encoded = to_cbor_zstd(&bucket, 1).unwrap();
        let decoded: WireFinalizedBucket<String, u64> = from_cbor_zstd(&encoded).unwrap();
        assert_eq!(decoded.key, bucket.key);
        assert_eq!(decoded.start, bucket.start);
        assert_eq!(decoded.end, bucket.end);
        assert_eq!(decoded.value, bucket.value);
    }
}