use crate::config::EncodingConfig;
use crate::decoding::{DecodingConfig, DecodingError, DecodingPipeline};
use crate::encoding::{EncodingError, EncodingPipeline};
use crate::security::AuthenticatedSymbol;
use crate::security::tag::AuthenticationTag;
use crate::transport::{SymbolSink, SymbolSinkExt, SymbolStream, SymbolStreamExt};
use crate::types::resource::{PoolConfig, SymbolPool};
use crate::types::symbol_set::SymbolSet;
use crate::types::{ObjectId, ObjectParams, Symbol, SymbolKind};
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::any::TypeId;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
pub const TYPED_SYMBOL_MAGIC: [u8; 4] = *b"TSYM";
pub const TYPED_SYMBOL_HEADER_LEN: usize = 27;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SerializationFormat {
MessagePack,
Bincode,
Json,
Custom,
}
impl SerializationFormat {
fn to_byte(self) -> u8 {
match self {
Self::MessagePack => 1,
Self::Bincode => 2,
Self::Json => 3,
Self::Custom => 255,
}
}
fn from_byte(value: u8) -> Result<Self, TypeMismatchError> {
match value {
1 => Ok(Self::MessagePack),
2 => Ok(Self::Bincode),
3 => Ok(Self::Json),
255 => Ok(Self::Custom),
_ => Err(TypeMismatchError::UnsupportedFormatByte { value }),
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum SerializationError {
#[error("serialization failed: {reason}")]
SerializationFailed {
reason: String,
},
#[error("value too large: {size} bytes exceeds {max} limit")]
ValueTooLarge {
size: usize,
max: usize,
},
#[error("unsupported type: {type_name}")]
UnsupportedType {
type_name: String,
},
}
#[derive(Debug, thiserror::Error)]
pub enum DeserializationError {
#[error("deserialization failed: {reason}")]
DeserializationFailed {
reason: String,
},
#[error("type mismatch: expected {expected}, got {actual}")]
TypeMismatch {
expected: String,
actual: String,
},
#[error("schema version mismatch: expected {expected}, got {actual}")]
SchemaMismatch {
expected: u32,
actual: u32,
},
#[error("corrupt symbol data")]
CorruptData,
}
#[derive(Debug, thiserror::Error)]
pub enum TypeMismatchError {
#[error("invalid magic number")]
InvalidMagic,
#[error("unknown type id: {type_id}")]
UnknownType {
type_id: u64,
},
#[error("unsupported serialization format byte: {value}")]
UnsupportedFormatByte {
value: u8,
},
#[error("schema hash mismatch: expected {expected}, got {actual}")]
SchemaMismatch {
expected: u64,
actual: u64,
},
}
#[derive(Debug, Clone, Copy)]
pub struct TypeDescriptor {
pub type_id: TypeId,
pub name: &'static str,
pub version: u32,
pub schema_hash: u64,
}
#[derive(Debug, Default)]
pub struct TypeRegistry {
types: HashMap<TypeId, TypeDescriptor>,
}
const DEFAULT_TYPE_REGISTRY_CAPACITY: usize = 16;
impl TypeRegistry {
#[must_use]
pub fn new() -> Self {
Self::with_capacity(DEFAULT_TYPE_REGISTRY_CAPACITY)
}
#[must_use]
pub fn with_capacity(capacity: usize) -> Self {
Self {
types: HashMap::with_capacity(capacity),
}
}
pub fn register<T: 'static>(&mut self, name: &'static str, version: u32) {
let type_id = TypeId::of::<T>();
let schema_hash = schema_hash::<T>(version);
self.types.insert(
type_id,
TypeDescriptor {
type_id,
name,
version,
schema_hash,
},
);
}
#[must_use]
pub fn is_registered<T: 'static>(&self) -> bool {
self.types.contains_key(&TypeId::of::<T>())
}
#[must_use]
pub fn get<T: 'static>(&self) -> Option<&TypeDescriptor> {
self.types.get(&TypeId::of::<T>())
}
#[must_use]
pub fn schema_hash<T: 'static>(&self) -> u64 {
self.get::<T>()
.map_or_else(|| schema_hash::<T>(0), |desc| desc.schema_hash)
}
}
pub trait Serializer<T>: Send + Sync {
fn serialize(
&self,
value: &T,
format: SerializationFormat,
) -> Result<Vec<u8>, SerializationError>;
}
pub trait Deserializer<T>: Send + Sync {
fn deserialize(
&self,
bytes: &[u8],
format: SerializationFormat,
) -> Result<T, DeserializationError>;
}
#[derive(Debug, Default, Clone, Copy)]
pub struct SerdeCodec;
impl<T: Serialize> Serializer<T> for SerdeCodec {
fn serialize(
&self,
value: &T,
format: SerializationFormat,
) -> Result<Vec<u8>, SerializationError> {
match format {
SerializationFormat::MessagePack => {
rmp_serde::to_vec(value).map_err(|err: rmp_serde::encode::Error| {
SerializationError::SerializationFailed {
reason: err.to_string(),
}
})
}
SerializationFormat::Bincode => {
bincode::serde::encode_to_vec(value, bincode::config::legacy()).map_err(
|err: bincode::error::EncodeError| SerializationError::SerializationFailed {
reason: err.to_string(),
},
)
}
SerializationFormat::Json => {
serde_json::to_vec(value).map_err(|err| SerializationError::SerializationFailed {
reason: err.to_string(),
})
}
SerializationFormat::Custom => Err(SerializationError::UnsupportedType {
type_name: std::any::type_name::<T>().to_string(),
}),
}
}
}
impl<T: DeserializeOwned> Deserializer<T> for SerdeCodec {
fn deserialize(
&self,
bytes: &[u8],
format: SerializationFormat,
) -> Result<T, DeserializationError> {
match format {
SerializationFormat::MessagePack => {
rmp_serde::from_slice(bytes).map_err(|err: rmp_serde::decode::Error| {
DeserializationError::DeserializationFailed {
reason: err.to_string(),
}
})
}
SerializationFormat::Bincode => {
bincode::serde::decode_from_slice(bytes, bincode::config::legacy())
.map(|(decoded, _)| decoded)
.map_err(|err: bincode::error::DecodeError| {
DeserializationError::DeserializationFailed {
reason: err.to_string(),
}
})
}
SerializationFormat::Json => serde_json::from_slice(bytes).map_err(|err| {
DeserializationError::DeserializationFailed {
reason: err.to_string(),
}
}),
SerializationFormat::Custom => Err(DeserializationError::DeserializationFailed {
reason: "custom format not supported".to_string(),
}),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TypedHeader {
version: u16,
type_id: u64,
format: SerializationFormat,
schema_hash: u64,
payload_len: u32,
}
impl TypedHeader {
fn new<T: 'static>(format: SerializationFormat, version: u16, payload_len: u32) -> Self {
Self {
version,
type_id: type_id_hash::<T>(),
format,
schema_hash: schema_hash::<T>(u32::from(version)),
payload_len,
}
}
fn encode(self) -> [u8; TYPED_SYMBOL_HEADER_LEN] {
let mut buf = [0u8; TYPED_SYMBOL_HEADER_LEN];
buf[..4].copy_from_slice(&TYPED_SYMBOL_MAGIC);
buf[4..6].copy_from_slice(&self.version.to_le_bytes());
buf[6..14].copy_from_slice(&self.type_id.to_le_bytes());
buf[14] = self.format.to_byte();
buf[15..23].copy_from_slice(&self.schema_hash.to_le_bytes());
buf[23..27].copy_from_slice(&self.payload_len.to_le_bytes());
buf
}
fn decode(bytes: &[u8]) -> Result<Self, TypeMismatchError> {
if bytes.len() < TYPED_SYMBOL_HEADER_LEN {
return Err(TypeMismatchError::InvalidMagic);
}
if bytes[..4] != TYPED_SYMBOL_MAGIC {
return Err(TypeMismatchError::InvalidMagic);
}
let version = u16::from_le_bytes([bytes[4], bytes[5]]);
let mut type_id_bytes = [0u8; 8];
type_id_bytes.copy_from_slice(&bytes[6..14]);
let type_id = u64::from_le_bytes(type_id_bytes);
let format = SerializationFormat::from_byte(bytes[14])?;
let mut schema_bytes = [0u8; 8];
schema_bytes.copy_from_slice(&bytes[15..23]);
let schema_hash = u64::from_le_bytes(schema_bytes);
let mut payload_bytes = [0u8; 4];
payload_bytes.copy_from_slice(&bytes[23..27]);
let payload_len = u32::from_le_bytes(payload_bytes);
Ok(Self {
version,
type_id,
format,
schema_hash,
payload_len,
})
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedSymbol<T> {
symbol: Symbol,
header: TypedHeader,
_marker: PhantomData<T>,
}
impl<T> TypedSymbol<T> {
#[must_use]
pub fn from_symbol_unchecked(symbol: Symbol) -> Self {
let header = TypedHeader::decode(symbol.data()).expect("typed symbol header missing");
Self {
symbol,
header,
_marker: PhantomData,
}
}
pub fn try_from_symbol(symbol: Symbol) -> Result<Self, TypeMismatchError>
where
T: 'static,
{
let header = TypedHeader::decode(symbol.data())?;
let expected_type = type_id_hash::<T>();
if header.type_id != expected_type {
return Err(TypeMismatchError::UnknownType {
type_id: header.type_id,
});
}
let expected_schema = schema_hash::<T>(u32::from(header.version));
if header.schema_hash != expected_schema {
return Err(TypeMismatchError::SchemaMismatch {
expected: expected_schema,
actual: header.schema_hash,
});
}
Ok(Self {
symbol,
header,
_marker: PhantomData,
})
}
#[must_use]
#[inline]
pub fn symbol(&self) -> &Symbol {
&self.symbol
}
#[must_use]
pub fn into_symbol(self) -> Symbol {
self.symbol
}
#[must_use]
pub const fn format(&self) -> SerializationFormat {
self.header.format
}
#[must_use]
pub const fn version(&self) -> u16 {
self.header.version
}
#[must_use]
pub const fn payload_len(&self) -> u32 {
self.header.payload_len
}
fn strip_header(&self) -> Result<&[u8], DeserializationError> {
let data = self.symbol.data();
let payload_len = self.header.payload_len as usize;
let end = TYPED_SYMBOL_HEADER_LEN + payload_len;
if data.len() < end {
return Err(DeserializationError::CorruptData);
}
Ok(&data[TYPED_SYMBOL_HEADER_LEN..end])
}
}
impl<T: Serialize + DeserializeOwned + 'static> TypedSymbol<T> {
pub fn from_value(value: &T, format: SerializationFormat) -> Result<Self, SerializationError> {
let codec = SerdeCodec;
let payload = codec.serialize(value, format)?;
let header = TypedHeader::new::<T>(format, 1, payload.len() as u32);
let header_bytes = header.encode();
let symbol_size = crate::types::DEFAULT_SYMBOL_SIZE;
let max_payload = symbol_size.saturating_sub(TYPED_SYMBOL_HEADER_LEN);
if payload.len() > max_payload {
return Err(SerializationError::ValueTooLarge {
size: payload.len(),
max: max_payload,
});
}
let mut data = Vec::with_capacity(TYPED_SYMBOL_HEADER_LEN + payload.len());
data.extend_from_slice(&header_bytes);
data.extend_from_slice(&payload);
let object_id = object_id_from_bytes::<T>(&payload);
let symbol = Symbol::new(
crate::types::SymbolId::new(object_id, 0, 0),
data,
SymbolKind::Source,
);
Ok(Self {
symbol,
header,
_marker: PhantomData,
})
}
pub fn into_value(self) -> Result<T, DeserializationError> {
let codec = SerdeCodec;
let payload = self.strip_header()?;
codec.deserialize(payload, self.header.format)
}
pub fn value(&self) -> Result<T, DeserializationError> {
let codec = SerdeCodec;
let payload = self.strip_header()?;
codec.deserialize(payload, self.header.format)
}
}
pub struct TypedEncoder<T> {
config: EncodingConfig,
format: SerializationFormat,
version: u16,
serializer: Box<dyn Serializer<T>>,
_marker: PhantomData<T>,
}
impl<T: Serialize + 'static> TypedEncoder<T> {
#[must_use]
pub fn new(format: SerializationFormat) -> Self {
Self::with_config(EncodingConfig::default(), format)
}
#[must_use]
pub fn with_config(config: EncodingConfig, format: SerializationFormat) -> Self {
Self::with_serializer(config, format, SerdeCodec)
}
#[must_use]
pub fn with_serializer(
config: EncodingConfig,
format: SerializationFormat,
serializer: impl Serializer<T> + 'static,
) -> Self {
Self {
config,
format,
version: 1,
serializer: Box::new(serializer),
_marker: PhantomData,
}
}
pub fn encode(
&mut self,
object_id: ObjectId,
value: &T,
) -> Result<Vec<TypedSymbol<T>>, EncodingError> {
let payload = self
.serializer
.serialize(value, self.format)
.map_err(|err| EncodingError::ComputationFailed {
details: err.to_string(),
})?;
let payload_len =
u32::try_from(payload.len()).map_err(|_| EncodingError::DataTooLarge {
size: payload.len(),
limit: u32::MAX as usize,
})?;
let header = TypedHeader::new::<T>(self.format, self.version, payload_len);
let header_bytes = header.encode();
let inner_symbol_size = inner_symbol_size(self.config.symbol_size)
.map_err(|reason| EncodingError::InvalidConfig { reason })?;
let mut inner_config = self.config.clone();
inner_config.symbol_size = inner_symbol_size;
let pool = SymbolPool::new(PoolConfig {
symbol_size: inner_symbol_size,
..PoolConfig::default()
});
let mut pipeline = EncodingPipeline::new(inner_config, pool);
let mut symbols = Vec::new();
for result in pipeline.encode(object_id, &payload) {
let symbol = result?.into_symbol();
let typed_symbol = wrap_symbol(&symbol, header, &header_bytes);
symbols.push(typed_symbol);
}
Ok(symbols)
}
pub fn encode_into(
&mut self,
object_id: ObjectId,
value: &T,
set: &mut SymbolSet,
) -> Result<usize, EncodingError> {
let symbols = self.encode(object_id, value)?;
let mut inserted = 0;
for symbol in symbols {
match set.insert(symbol.into_symbol()) {
crate::types::symbol_set::InsertResult::Inserted { .. } => inserted += 1,
crate::types::symbol_set::InsertResult::Duplicate => {}
crate::types::symbol_set::InsertResult::MemoryLimitReached
| crate::types::symbol_set::InsertResult::BlockLimitReached { .. } => {
return Err(EncodingError::ComputationFailed {
details: "symbol set rejected insert".to_string(),
});
}
}
}
Ok(inserted)
}
pub async fn encode_to_sink<S: SymbolSink + Unpin>(
&mut self,
object_id: ObjectId,
value: &T,
sink: &mut S,
) -> Result<usize, EncodingError>
where
T: Send + Sync,
{
let symbols = self.encode(object_id, value)?;
let mut count = 0;
for symbol in symbols {
let auth =
AuthenticatedSymbol::new_verified(symbol.into_symbol(), AuthenticationTag::zero());
sink.send(auth)
.await
.map_err(|err| EncodingError::ComputationFailed {
details: err.to_string(),
})?;
count += 1;
}
Ok(count)
}
}
pub struct TypedDecoder<T> {
config: DecodingConfig,
format: SerializationFormat,
deserializer: Box<dyn Deserializer<T>>,
_marker: PhantomData<T>,
}
impl<T> std::fmt::Debug for TypedEncoder<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TypedEncoder")
.field("config", &self.config)
.field("format", &self.format)
.field("version", &self.version)
.finish_non_exhaustive()
}
}
impl<T> std::fmt::Debug for TypedDecoder<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TypedDecoder")
.field("config", &self.config)
.field("format", &self.format)
.finish_non_exhaustive()
}
}
impl<T: DeserializeOwned> TypedDecoder<T> {
#[must_use]
pub fn new(format: SerializationFormat) -> Self {
Self::with_config(DecodingConfig::default(), format)
}
#[must_use]
pub fn with_config(config: DecodingConfig, format: SerializationFormat) -> Self {
Self::with_deserializer(config, format, SerdeCodec)
}
#[must_use]
pub fn with_deserializer(
config: DecodingConfig,
format: SerializationFormat,
deserializer: impl Deserializer<T> + 'static,
) -> Self {
Self {
config,
format,
deserializer: Box::new(deserializer),
_marker: PhantomData,
}
}
pub fn decode<I>(&mut self, symbols: I) -> Result<T, DecodingError>
where
I: IntoIterator<Item = TypedSymbol<T>>,
T: 'static,
{
let mut iter = symbols.into_iter();
let first = iter
.next()
.ok_or_else(|| DecodingError::InconsistentMetadata {
sbn: 0,
details: "no symbols provided".to_string(),
})?;
let header = validate_header::<T>(&first, self.format)?;
let object_id = first.symbol().object_id();
let inner_size = inner_symbol_size(self.config.symbol_size).map_err(|reason| {
DecodingError::InconsistentMetadata {
sbn: 0,
details: reason,
}
})?;
let mut pipeline = DecodingPipeline::new(inner_config(&self.config, inner_size));
pipeline.set_object_params(object_params_for_payload(
object_id,
u64::from(header.payload_len),
inner_size,
self.config.max_block_size,
))?;
feed_typed_symbol(&mut pipeline, first, inner_size)?;
for symbol in iter {
let current = validate_header::<T>(&symbol, self.format)?;
if current != header {
return Err(DecodingError::InconsistentMetadata {
sbn: symbol.symbol().sbn(),
details: "typed symbol header mismatch".to_string(),
});
}
feed_typed_symbol(&mut pipeline, symbol, inner_size)?;
}
let payload = pipeline.into_data()?;
self.deserializer
.deserialize(&payload, header.format)
.map_err(|err| DecodingError::InconsistentMetadata {
sbn: 0,
details: err.to_string(),
})
}
pub fn decode_from_set(&mut self, set: &SymbolSet) -> Result<T, DecodingError>
where
T: 'static,
{
let symbols = set
.iter()
.map(|(_, symbol)| {
TypedSymbol::try_from_symbol(symbol.clone()).map_err(|err| {
DecodingError::InconsistentMetadata {
sbn: symbol.sbn(),
details: err.to_string(),
}
})
})
.collect::<Result<Vec<_>, _>>()?;
self.decode(symbols)
}
pub async fn decode_from_stream<S: SymbolStream + Unpin>(
&mut self,
stream: &mut S,
) -> Result<T, DecodingError>
where
T: 'static,
{
let mut header = None;
let mut pipeline = None;
let mut inner_size = 0;
while let Some(result) = stream.next().await {
let symbol = result.map_err(|err| DecodingError::InconsistentMetadata {
sbn: 0,
details: err.to_string(),
})?;
let typed = TypedSymbol::try_from_symbol(symbol.into_symbol()).map_err(|err| {
DecodingError::InconsistentMetadata {
sbn: 0,
details: err.to_string(),
}
})?;
let current = validate_header::<T>(&typed, self.format)?;
if let Some(expected) = header {
if current != expected {
return Err(DecodingError::InconsistentMetadata {
sbn: typed.symbol().sbn(),
details: "typed symbol header mismatch".to_string(),
});
}
} else {
inner_size = inner_symbol_size(self.config.symbol_size).map_err(|reason| {
DecodingError::InconsistentMetadata {
sbn: 0,
details: reason,
}
})?;
let mut current_pipeline =
DecodingPipeline::new(inner_config(&self.config, inner_size));
current_pipeline.set_object_params(object_params_for_payload(
typed.symbol().object_id(),
u64::from(current.payload_len),
inner_size,
self.config.max_block_size,
))?;
header = Some(current);
pipeline = Some(current_pipeline);
}
let is_complete = {
let pipeline =
pipeline
.as_mut()
.ok_or_else(|| DecodingError::InconsistentMetadata {
sbn: typed.symbol().sbn(),
details: "typed stream pipeline not initialized".to_string(),
})?;
feed_typed_symbol(pipeline, typed, inner_size)?;
pipeline.is_complete()
};
if is_complete {
let payload = pipeline
.take()
.ok_or_else(|| DecodingError::InconsistentMetadata {
sbn: 0,
details: "typed stream pipeline disappeared before completion".to_string(),
})?
.into_data()?;
let header = header.ok_or_else(|| DecodingError::InconsistentMetadata {
sbn: 0,
details: "typed stream header missing at completion".to_string(),
})?;
return self
.deserializer
.deserialize(&payload, header.format)
.map_err(|err| DecodingError::InconsistentMetadata {
sbn: 0,
details: err.to_string(),
});
}
}
let header = header.ok_or_else(|| DecodingError::InconsistentMetadata {
sbn: 0,
details: "no symbols provided".to_string(),
})?;
let pipeline = pipeline.ok_or_else(|| DecodingError::InconsistentMetadata {
sbn: 0,
details: "typed stream pipeline missing at end of stream".to_string(),
})?;
let payload = pipeline.into_data()?;
self.deserializer
.deserialize(&payload, header.format)
.map_err(|err| DecodingError::InconsistentMetadata {
sbn: 0,
details: err.to_string(),
})
}
}
fn inner_symbol_size(symbol_size: u16) -> Result<u16, String> {
let outer = usize::from(symbol_size);
if outer <= TYPED_SYMBOL_HEADER_LEN {
return Err(format!(
"symbol_size {outer} must exceed typed header length {TYPED_SYMBOL_HEADER_LEN}"
));
}
let inner = outer - TYPED_SYMBOL_HEADER_LEN;
u16::try_from(inner).map_err(|_| "inner symbol size overflow".to_string())
}
fn inner_config(config: &DecodingConfig, inner_symbol_size: u16) -> DecodingConfig {
let mut inner = config.clone();
inner.symbol_size = inner_symbol_size;
inner
}
fn wrap_symbol<T>(
symbol: &Symbol,
header: TypedHeader,
header_bytes: &[u8; TYPED_SYMBOL_HEADER_LEN],
) -> TypedSymbol<T> {
let mut data = Vec::with_capacity(TYPED_SYMBOL_HEADER_LEN + symbol.data().len());
data.extend_from_slice(header_bytes);
data.extend_from_slice(symbol.data());
let typed_symbol = Symbol::new(symbol.id(), data, symbol.kind());
TypedSymbol {
symbol: typed_symbol,
header,
_marker: PhantomData,
}
}
fn feed_typed_symbol<T>(
pipeline: &mut DecodingPipeline,
symbol: TypedSymbol<T>,
inner_size: u16,
) -> Result<(), DecodingError> {
let raw = symbol.into_symbol();
if raw.data().len() < TYPED_SYMBOL_HEADER_LEN {
return Err(DecodingError::SymbolSizeMismatch {
expected: inner_size,
actual: raw.data().len(),
});
}
let payload = raw.data()[TYPED_SYMBOL_HEADER_LEN..].to_vec();
if payload.len() != usize::from(inner_size) {
return Err(DecodingError::SymbolSizeMismatch {
expected: inner_size,
actual: payload.len(),
});
}
let inner_symbol = Symbol::new(raw.id(), payload, raw.kind());
let auth = AuthenticatedSymbol::new_verified(inner_symbol, AuthenticationTag::zero());
let _ = pipeline.feed(auth)?;
Ok(())
}
fn validate_header<T: 'static>(
symbol: &TypedSymbol<T>,
expected_format: SerializationFormat,
) -> Result<TypedHeader, DecodingError> {
if symbol.header.format != expected_format {
return Err(DecodingError::InconsistentMetadata {
sbn: symbol.symbol().sbn(),
details: "serialization format mismatch".to_string(),
});
}
Ok(symbol.header)
}
fn object_params_for_payload(
object_id: ObjectId,
payload_len: u64,
symbol_size: u16,
max_block_size: usize,
) -> ObjectParams {
let symbol_size = usize::from(symbol_size);
let max_block_size = max_block_size.max(symbol_size);
let payload_len = payload_len as usize;
let blocks = payload_len.div_ceil(max_block_size).max(1);
let symbols_per_block = if payload_len == 0 {
0
} else {
payload_len.min(max_block_size).div_ceil(symbol_size)
};
let blocks = blocks.min(u16::MAX as usize) as u16;
let symbols_per_block = symbols_per_block.min(u16::MAX as usize) as u16;
ObjectParams::new(
object_id,
payload_len as u64,
symbol_size as u16,
blocks,
symbols_per_block,
)
}
fn type_id_hash<T: 'static>() -> u64 {
let mut hasher = crate::util::DetHasher::default();
TypeId::of::<T>().hash(&mut hasher);
hasher.finish()
}
fn schema_hash<T: 'static>(version: u32) -> u64 {
let mut hasher = crate::util::DetHasher::default();
std::any::type_name::<T>().hash(&mut hasher);
version.hash(&mut hasher);
hasher.finish()
}
fn object_id_from_bytes<T: 'static>(bytes: &[u8]) -> ObjectId {
let mut hasher = crate::util::DetHasher::default();
bytes.hash(&mut hasher);
std::any::type_name::<T>().hash(&mut hasher);
let hash = hasher.finish();
ObjectId::new(hash, hash.rotate_left(17))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::error::StreamError;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
struct Demo {
id: u64,
name: String,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
struct EmptyStruct;
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
enum TestEnum {
Unit,
Tuple(u32, String),
Struct { x: i32, y: i32 },
}
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct Nested {
inner: Box<Option<Self>>,
value: f64,
}
fn noop_waker() -> Waker {
std::task::Waker::noop().clone()
}
struct ReadyThenPendingStream {
items: std::vec::IntoIter<AuthenticatedSymbol>,
}
impl ReadyThenPendingStream {
fn new(items: Vec<AuthenticatedSymbol>) -> Self {
Self {
items: items.into_iter(),
}
}
}
impl SymbolStream for ReadyThenPendingStream {
fn poll_next(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<AuthenticatedSymbol, StreamError>>> {
self.items
.next()
.map_or(Poll::Pending, |item| Poll::Ready(Some(Ok(item))))
}
}
#[test]
fn typed_symbol_single_roundtrip() {
let value = Demo {
id: 7,
name: "alpha".to_string(),
};
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Bincode).expect("symbol");
let decoded_value = symbol.into_value().expect("decode");
assert_eq!(value, decoded_value);
}
#[test]
fn typed_encoder_decoder_roundtrip() {
let value = Demo {
id: 42,
name: "beta".to_string(),
};
let mut encoder = TypedEncoder::with_config(
EncodingConfig {
symbol_size: 64,
max_block_size: 128,
repair_overhead: 1.05,
encoding_parallelism: 1,
decoding_parallelism: 1,
},
SerializationFormat::Bincode,
);
let object_id = ObjectId::new(1, 2);
let symbols = encoder.encode(object_id, &value).expect("encode");
assert!(!symbols.is_empty());
let mut decoder: TypedDecoder<Demo> = TypedDecoder::with_config(
DecodingConfig {
symbol_size: 64,
max_block_size: 128,
repair_overhead: 1.05,
min_overhead: 0,
max_buffered_symbols: 0,
block_timeout: std::time::Duration::from_secs(1),
verify_auth: false,
},
SerializationFormat::Bincode,
);
let decoded_value = decoder.decode(symbols).expect("decode");
assert_eq!(value, decoded_value);
}
#[test]
fn object_params_for_small_payload_uses_actual_single_block_k() {
let params = object_params_for_payload(ObjectId::new_for_test(7), 8, 37, 512);
assert_eq!(params.source_blocks, 1);
assert_eq!(params.symbols_per_block, 1);
}
#[test]
fn object_params_for_payload_preserves_256_block_boundary() {
let params = object_params_for_payload(ObjectId::new_for_test(8), 256, 1, 1);
assert_eq!(params.source_blocks, 256);
assert_eq!(params.symbols_per_block, 1);
}
#[test]
#[allow(clippy::similar_names)]
fn typed_decoder_accepts_small_single_block_payload_with_large_max_block_size() {
let value: u64 = 42;
let mut encoder = TypedEncoder::with_config(
EncodingConfig {
symbol_size: 64,
max_block_size: 512,
repair_overhead: 1.05,
encoding_parallelism: 1,
decoding_parallelism: 1,
},
SerializationFormat::Bincode,
);
let mut decoder: TypedDecoder<u64> = TypedDecoder::with_config(
DecodingConfig {
symbol_size: 64,
max_block_size: 512,
repair_overhead: 1.05,
min_overhead: 0,
max_buffered_symbols: 0,
block_timeout: std::time::Duration::from_secs(1),
verify_auth: false,
},
SerializationFormat::Bincode,
);
let symbols = encoder
.encode(ObjectId::new_for_test(9), &value)
.expect("encode");
let decoded = decoder.decode(symbols).expect("decode");
assert_eq!(decoded, value);
}
#[test]
fn typed_decoder_stream_returns_once_object_is_complete_before_eof() {
let value: u64 = 42;
let mut encoder = TypedEncoder::with_config(
EncodingConfig {
symbol_size: 64,
max_block_size: 512,
repair_overhead: 1.05,
encoding_parallelism: 1,
decoding_parallelism: 1,
},
SerializationFormat::Bincode,
);
let mut decoder: TypedDecoder<u64> = TypedDecoder::with_config(
DecodingConfig {
symbol_size: 64,
max_block_size: 512,
repair_overhead: 1.05,
min_overhead: 0,
max_buffered_symbols: 0,
block_timeout: std::time::Duration::from_secs(1),
verify_auth: false,
},
SerializationFormat::Bincode,
);
let symbols = encoder
.encode(ObjectId::new_for_test(11), &value)
.expect("encode");
assert!(
!symbols.is_empty(),
"test requires at least one encoded symbol for the object"
);
let auth_symbols = symbols
.into_iter()
.map(|symbol| {
AuthenticatedSymbol::new_verified(symbol.into_symbol(), AuthenticationTag::zero())
})
.collect();
let mut stream = ReadyThenPendingStream::new(auth_symbols);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let mut future = Box::pin(decoder.decode_from_stream(&mut stream));
match Future::poll(future.as_mut(), &mut cx) {
Poll::Ready(Ok(actual)) => assert_eq!(actual, value),
other => panic!(
"stream decode should finish once the first object is complete, even if the stream stays open: {other:?}"
),
}
}
#[test]
fn test_roundtrip_primitive_types() {
let val_u64: u64 = 12_345_678_901_234;
let sym = TypedSymbol::from_value(&val_u64, SerializationFormat::Bincode).unwrap();
assert_eq!(sym.into_value().unwrap(), val_u64);
let val_int: i32 = -42;
let sym = TypedSymbol::from_value(&val_int, SerializationFormat::Bincode).unwrap();
assert_eq!(sym.into_value().unwrap(), val_int);
let val_bool = true;
let sym = TypedSymbol::from_value(&val_bool, SerializationFormat::Bincode).unwrap();
assert_eq!(sym.into_value().unwrap(), val_bool);
let val_str = "hello world".to_string();
let sym = TypedSymbol::from_value(&val_str, SerializationFormat::Bincode).unwrap();
assert_eq!(sym.into_value().unwrap(), val_str);
}
#[test]
fn test_roundtrip_struct() {
let value = Demo {
id: 999,
name: "struct_test".to_string(),
};
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Bincode).unwrap();
assert_eq!(symbol.into_value().unwrap(), value);
}
#[test]
fn test_roundtrip_enum() {
let variants = [
TestEnum::Unit,
TestEnum::Tuple(42, "tuple".to_string()),
TestEnum::Struct { x: -10, y: 20 },
];
for value in variants {
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Bincode).unwrap();
assert_eq!(symbol.into_value().unwrap(), value);
}
}
#[test]
fn test_roundtrip_vec() {
let value: Vec<u32> = vec![1, 2, 3, 4, 5];
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Bincode).unwrap();
assert_eq!(symbol.into_value().unwrap(), value);
let value: Vec<Demo> = vec![
Demo {
id: 1,
name: "a".into(),
},
Demo {
id: 2,
name: "b".into(),
},
];
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Bincode).unwrap();
assert_eq!(symbol.into_value().unwrap(), value);
}
#[test]
fn test_roundtrip_hashmap() {
let mut value: HashMap<String, i32> = HashMap::new();
value.insert("one".to_string(), 1);
value.insert("two".to_string(), 2);
value.insert("negative".to_string(), -100);
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Bincode).unwrap();
assert_eq!(symbol.into_value().unwrap(), value);
}
#[test]
fn test_messagepack_format() {
let value = Demo {
id: 1,
name: "msgpack".to_string(),
};
let symbol = TypedSymbol::from_value(&value, SerializationFormat::MessagePack).unwrap();
assert_eq!(symbol.format(), SerializationFormat::MessagePack);
assert_eq!(symbol.into_value().unwrap(), value);
}
#[test]
fn test_bincode_format() {
let value = Demo {
id: 2,
name: "bincode".to_string(),
};
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Bincode).unwrap();
assert_eq!(symbol.format(), SerializationFormat::Bincode);
assert_eq!(symbol.into_value().unwrap(), value);
}
#[test]
fn test_json_format() {
let value = Demo {
id: 3,
name: "json".to_string(),
};
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Json).unwrap();
assert_eq!(symbol.format(), SerializationFormat::Json);
assert_eq!(symbol.into_value().unwrap(), value);
}
#[test]
fn test_type_mismatch_detected() {
let value: u64 = 42;
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Bincode).unwrap();
let raw = symbol.into_symbol();
let result = TypedSymbol::<String>::try_from_symbol(raw);
assert!(result.is_err());
match result {
Err(TypeMismatchError::UnknownType { .. }) => {}
_ => panic!("Expected UnknownType error"),
}
}
#[test]
fn test_corrupt_header_detected() {
let mut data = vec![0u8; TYPED_SYMBOL_HEADER_LEN + 10];
data[0..4].copy_from_slice(b"XXXX");
let symbol = Symbol::new(
crate::types::SymbolId::new(ObjectId::new(1, 1), 0, 0),
data,
SymbolKind::Source,
);
let result = TypedSymbol::<u64>::try_from_symbol(symbol);
assert!(result.is_err());
match result {
Err(TypeMismatchError::InvalidMagic) => {}
_ => panic!("Expected InvalidMagic error"),
}
}
#[test]
fn test_unsupported_format_byte_is_reported() {
let value: u64 = 42;
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Bincode).unwrap();
let mut raw = symbol.into_symbol();
raw.data_mut()[14] = 4;
let result = TypedSymbol::<u64>::try_from_symbol(raw);
match result {
Err(TypeMismatchError::UnsupportedFormatByte { value: 4 }) => {}
other => panic!("expected UnsupportedFormatByte {{ value: 4 }}, got {other:?}"),
}
}
#[test]
fn test_type_registration() {
let mut registry = TypeRegistry::new();
assert!(!registry.is_registered::<Demo>());
registry.register::<Demo>("Demo", 1);
assert!(registry.is_registered::<Demo>());
let desc = registry.get::<Demo>().unwrap();
assert_eq!(desc.name, "Demo");
assert_eq!(desc.version, 1);
}
#[test]
fn test_type_registry_with_capacity_registration() {
let mut registry = TypeRegistry::with_capacity(1);
registry.register::<Demo>("Demo", 2);
let desc = registry.get::<Demo>().unwrap();
assert_eq!(desc.name, "Demo");
assert_eq!(desc.version, 2);
}
#[test]
fn test_schema_hash_stability() {
let hash1 = schema_hash::<Demo>(1);
let hash2 = schema_hash::<Demo>(1);
assert_eq!(hash1, hash2);
let hash3 = schema_hash::<Demo>(2);
assert_ne!(hash1, hash3);
let hash4 = schema_hash::<EmptyStruct>(1);
assert_ne!(hash1, hash4);
}
#[test]
fn test_empty_struct() {
let value = EmptyStruct;
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Bincode).unwrap();
let decoded: EmptyStruct = symbol.into_value().unwrap();
assert_eq!(decoded, value);
}
#[test]
#[allow(clippy::approx_constant)]
fn test_deeply_nested_type() {
let value = Nested {
inner: Box::new(Some(Nested {
inner: Box::new(Some(Nested {
inner: Box::new(None),
value: 3.14,
})),
value: 2.71,
})),
value: 1.41,
};
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Bincode).unwrap();
let decoded: Nested = symbol.into_value().unwrap();
assert_eq!(decoded, value);
}
#[test]
fn test_header_encode_decode_roundtrip() {
let header = TypedHeader::new::<Demo>(SerializationFormat::Json, 5, 100);
let encoded = header.encode();
let decoded = TypedHeader::decode(&encoded).unwrap();
assert_eq!(header, decoded);
}
#[test]
fn test_serialization_format_byte_roundtrip() {
for format in [
SerializationFormat::MessagePack,
SerializationFormat::Bincode,
SerializationFormat::Json,
SerializationFormat::Custom,
] {
let byte = format.to_byte();
let recovered = SerializationFormat::from_byte(byte).unwrap();
assert_eq!(format, recovered);
}
}
#[test]
fn serialization_format_debug_clone_copy_eq() {
let fmt = SerializationFormat::Json;
let dbg = format!("{fmt:?}");
assert!(dbg.contains("Json"), "{dbg}");
let copied = fmt;
let cloned = fmt;
assert_eq!(copied, cloned);
assert_ne!(fmt, SerializationFormat::Bincode);
}
#[test]
fn typed_symbol_json_snapshot_scrubbed_ids() {
let value = Demo {
id: 11,
name: "typed".to_string(),
};
let symbol = TypedSymbol::from_value(&value, SerializationFormat::Json)
.expect("create typed symbol");
insta::assert_json_snapshot!(
"typed_symbol_json_scrubbed_ids",
serde_json::json!({
"symbol_id": {
"object_id": "[OBJECT_ID]",
"sbn": symbol.symbol().sbn(),
"esi": symbol.symbol().esi(),
},
"kind": symbol.symbol().kind().to_string(),
"format": format!("{:?}", symbol.format()),
"version": symbol.version(),
"payload_len": symbol.payload_len(),
"value": symbol.value().expect("decode typed symbol"),
})
);
}
}