use std::io::Write;
use std::sync::Arc;
use bytes::Bytes;
use crate::{Error, Result};
use super::{CacheEntryReader, CacheEntryWriter};
/// Four-byte marker that opens every enveloped cache entry.
pub const MAGIC: [u8; 4] = *b"LCE1";

/// Reports whether `data` begins with the envelope [`MAGIC`] bytes.
///
/// Only the magic prefix is inspected; the rest of the envelope is not
/// validated. Inputs shorter than the magic always yield `false`.
pub fn has_cache_envelope(data: &[u8]) -> bool {
    data.starts_with(&MAGIC)
}
/// Version byte of the envelope layout itself. `write_envelope` emits it right
/// after the magic, and `parse_envelope` rejects any entry carrying a different value.
const ENVELOPE_VERSION: u8 = 1;
/// Header fields recovered from an enveloped cache entry by `parse_envelope`.
///
/// Borrows `type_id` from the buffer that was parsed.
struct ParsedEnvelope<'a> {
/// Type identifier string stored in the envelope (validated as UTF-8).
type_id: &'a str,
/// Per-type format version stored in the envelope (little-endian `u32` on the wire).
type_version: u32,
/// Byte offset, from the start of the buffer, of the first byte after the envelope.
body_offset: usize,
}
/// Splits `buf` into its first `n` bytes and the remainder, or returns `None`
/// when fewer than `n` bytes are available.
fn split_prefix(buf: &[u8], n: usize) -> Option<(&[u8], &[u8])> {
    if buf.len() < n {
        None
    } else {
        Some(buf.split_at(n))
    }
}

/// Decodes the envelope header at the start of `data`.
///
/// Wire layout, in order:
/// magic (4 bytes) | envelope version (1 byte) | type_id length (`u16` LE) |
/// type_id (UTF-8) | type version (`u32` LE).
///
/// Returns `None` when the buffer is too short for any field, the magic or
/// envelope version does not match, or the type id is not valid UTF-8.
/// On success, `body_offset` is the index of the first byte after the header.
fn parse_envelope(data: &Bytes) -> Option<ParsedEnvelope<'_>> {
    let whole: &[u8] = data.as_ref();

    // Walk the header by repeatedly peeling fields off the front of `rest`.
    let rest = whole.strip_prefix(&MAGIC[..])?;

    let (envelope_version, rest) = rest.split_first()?;
    if *envelope_version != ENVELOPE_VERSION {
        return None;
    }

    let (len_field, rest) = split_prefix(rest, 2)?;
    let type_id_len = usize::from(u16::from_le_bytes(len_field.try_into().ok()?));

    let (type_id_bytes, rest) = split_prefix(rest, type_id_len)?;
    let type_id = std::str::from_utf8(type_id_bytes).ok()?;

    let (version_field, rest) = split_prefix(rest, 4)?;
    let type_version = u32::from_le_bytes(version_field.try_into().ok()?);

    // Everything consumed so far is header; what is left in `rest` is the body.
    Some(ParsedEnvelope {
        type_id,
        type_version,
        body_offset: whole.len() - rest.len(),
    })
}
fn write_envelope(writer: &mut dyn Write, type_id: &str, type_version: u32) -> Result<usize> {
let type_id_len = u16::try_from(type_id.len()).map_err(|_| {
Error::io(format!(
"cache codec type_id too long ({} bytes, max {})",
type_id.len(),
u16::MAX
))
})?;
writer.write_all(&MAGIC)?;
writer.write_all(&[ENVELOPE_VERSION])?;
writer.write_all(&type_id_len.to_le_bytes())?;
writer.write_all(type_id.as_bytes())?;
writer.write_all(&type_version.to_le_bytes())?;
Ok(4 + 1 + 2 + type_id.len() + 4)
}
/// Why `CacheCodec::deserialize` refused to produce a value from a cache entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheMissReason {
/// The envelope header could not be parsed (missing, truncated, or not matching
/// the expected magic / envelope version / UTF-8 type id).
InvalidEnvelope,
/// The envelope's type id differs from the codec's type id.
TypeMismatch,
/// The envelope's type version is greater than the codec's version.
VersionTooNew,
/// The envelope was accepted but the body deserializer returned an error.
BodyError,
}
/// Outcome of decoding a cache entry: either the decoded value or the reason
/// the entry is treated as a miss.
#[derive(Debug)]
pub enum CacheDecode<T> {
/// The entry decoded successfully.
Hit(T),
/// The entry was rejected; the payload says why.
Miss(CacheMissReason),
}
impl<T> CacheDecode<T> {
    /// Consumes the result and returns the decoded value, discarding the miss
    /// reason: `Some(value)` for `Hit`, `None` for any `Miss`.
    pub fn hit(self) -> Option<T> {
        if let Self::Hit(value) = self {
            return Some(value);
        }
        None
    }
}
/// Implemented by types that can be stored as cache entries.
///
/// `CacheCodec::from_impl` turns an implementation into a type-erased
/// `CacheCodec`, using the associated constants as the envelope's type id and
/// type version.
pub trait CacheCodecImpl: Send + Sync {
/// Identifier written into the envelope; entries carrying a different id
/// are rejected on decode.
const TYPE_ID: &'static str;
/// Version written into the envelope on serialize; entries carrying a
/// greater version are rejected on decode.
const CURRENT_VERSION: u32;
/// Writes the body of `self` (everything after the envelope) to `writer`.
fn serialize(&self, writer: &mut CacheEntryWriter<'_>) -> Result<()>;
/// Reads a value back from the body that `reader` is positioned on.
fn deserialize(reader: &mut CacheEntryReader<'_>) -> Result<Self>
where
Self: Sized;
}
/// Type-erased, shareable value handled by `CacheCodec`; the concrete type is
/// recovered by downcasting.
pub(crate) type ArcAny = Arc<dyn std::any::Any + Send + Sync>;
/// Type-erased codec for one kind of cache entry: an identity (type id and
/// version) plus plain function pointers that encode and decode the body.
#[derive(Copy, Clone)]
pub struct CacheCodec {
/// Type identifier written to, and required from, the envelope.
type_id: &'static str,
/// Version written to the envelope; also the highest version accepted on decode.
version: u32,
/// Encodes the body of a type-erased value.
serialize_body: fn(&ArcAny, &mut CacheEntryWriter<'_>) -> Result<()>,
/// Decodes a body into a type-erased value.
deserialize_body: fn(&mut CacheEntryReader<'_>) -> Result<ArcAny>,
}
impl std::fmt::Debug for CacheCodec {
    /// Formats only the identity fields (`type_id`, `version`); the function
    /// pointers are omitted and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CacheCodec");
        out.field("type_id", &self.type_id);
        out.field("version", &self.version);
        out.finish_non_exhaustive()
    }
}
/// Adapter with the `CacheCodec::serialize_body` signature: downcasts the
/// type-erased value to `T` and forwards to `T::serialize`.
///
/// # Panics
///
/// Panics when `any` does not hold a `T`.
fn serialize_via_impl<T: CacheCodecImpl + 'static>(
    any: &ArcAny,
    writer: &mut CacheEntryWriter<'_>,
) -> Result<()> {
    let Some(typed) = any.downcast_ref::<T>() else {
        panic!("CacheCodec::serialize called with wrong type (this is a bug in the cache layer)");
    };
    typed.serialize(writer)
}
/// Adapter with the `CacheCodec::deserialize_body` signature: decodes a `T`
/// via `T::deserialize` and wraps it in a type-erased `Arc`.
///
/// Any error from `T::deserialize` is returned unchanged.
fn deserialize_via_impl<T: CacheCodecImpl + 'static>(
    reader: &mut CacheEntryReader<'_>,
) -> Result<ArcAny> {
    T::deserialize(reader).map(|decoded| Arc::new(decoded) as ArcAny)
}
impl CacheCodec {
pub fn new(
type_id: &'static str,
version: u32,
serialize_body: fn(&ArcAny, &mut CacheEntryWriter<'_>) -> Result<()>,
deserialize_body: fn(&mut CacheEntryReader<'_>) -> Result<ArcAny>,
) -> Self {
Self {
type_id,
version,
serialize_body,
deserialize_body,
}
}
pub fn from_impl<T: CacheCodecImpl + 'static>() -> Self {
Self {
type_id: T::TYPE_ID,
version: T::CURRENT_VERSION,
serialize_body: serialize_via_impl::<T>,
deserialize_body: deserialize_via_impl::<T>,
}
}
pub fn serialize(&self, value: &ArcAny, writer: &mut dyn Write) -> Result<()> {
let body_offset = write_envelope(writer, self.type_id, self.version)?;
let mut entry_writer = CacheEntryWriter::with_pos(writer, body_offset);
(self.serialize_body)(value, &mut entry_writer)
}
pub fn deserialize(&self, data: &Bytes) -> CacheDecode<ArcAny> {
let Some(envelope) = parse_envelope(data) else {
log::debug!("cache entry rejected: missing or invalid envelope");
return CacheDecode::Miss(CacheMissReason::InvalidEnvelope);
};
if envelope.type_id != self.type_id {
log::debug!(
"cache entry type_id mismatch: got {:?}, expected {:?}",
envelope.type_id,
self.type_id
);
return CacheDecode::Miss(CacheMissReason::TypeMismatch);
}
if envelope.type_version > self.version {
log::debug!(
"cache entry {:?} has unsupported type_version {} (this build writes {})",
self.type_id,
envelope.type_version,
self.version
);
return CacheDecode::Miss(CacheMissReason::VersionTooNew);
}
let mut reader = CacheEntryReader::new(data, envelope.body_offset, envelope.type_version);
match (self.deserialize_body)(&mut reader) {
Ok(value) => CacheDecode::Hit(value),
Err(e) => {
log::debug!(
"cache entry {:?} v{} failed to decode: {e}",
self.type_id,
envelope.type_version
);
CacheDecode::Miss(CacheMissReason::BodyError)
}
}
}
}
// Unit tests for the envelope format and for `CacheCodec` round-trips and
// miss classification, using a minimal `Widget` type as the payload.
#[cfg(test)]
mod tests {
use super::*;
// Minimal payload type: a single `u32`.
#[derive(Debug, PartialEq)]
struct Widget {
n: u32,
}
impl CacheCodecImpl for Widget {
const TYPE_ID: &'static str = "test.Widget";
const CURRENT_VERSION: u32 = 1;
// Body is the little-endian bytes of `n`, written through `write_raw`.
fn serialize(&self, writer: &mut CacheEntryWriter<'_>) -> Result<()> {
writer.write_raw(&self.n.to_le_bytes())
}
// Reads the raw payload back and requires it to be exactly four bytes.
fn deserialize(reader: &mut CacheEntryReader<'_>) -> Result<Self> {
let bytes = reader.read_raw()?;
let n = u32::from_le_bytes(
bytes
.as_ref()
.try_into()
.map_err(|_| Error::io("bad widget".to_string()))?,
);
Ok(Self { n })
}
}
// Encodes a copy of `widget` through the `Widget` codec and returns the full entry.
fn serialize_widget(widget: &Widget) -> Bytes {
let codec = CacheCodec::from_impl::<Widget>();
let any: ArcAny = Arc::new(Widget { n: widget.n });
let mut buf = Vec::new();
codec.serialize(&any, &mut buf).unwrap();
Bytes::from(buf)
}
// Decodes `data` as a `Widget` and returns the miss reason, or `None` on a hit.
fn miss_reason(data: &Bytes) -> Option<CacheMissReason> {
match deserialize_widget(data) {
CacheDecode::Hit(_) => None,
CacheDecode::Miss(reason) => Some(reason),
}
}
// Decodes `data` through the `Widget` codec and converts a hit back into a
// concrete `Widget`.
fn deserialize_widget(data: &Bytes) -> CacheDecode<Widget> {
let codec = CacheCodec::from_impl::<Widget>();
match codec.deserialize(data) {
CacheDecode::Hit(any) => {
// The decoded `Arc` was just created and has no other owners, so
// both the downcast and `try_unwrap` are expected to succeed.
CacheDecode::Hit(Arc::try_unwrap(any.downcast::<Widget>().unwrap()).unwrap())
}
CacheDecode::Miss(reason) => CacheDecode::Miss(reason),
}
}
// A serialized entry starts with the magic and decodes back to the same value.
#[test]
fn envelope_roundtrip_hits() {
let bytes = serialize_widget(&Widget { n: 0xDEADBEEF });
assert_eq!(&bytes[..4], b"LCE1");
let decoded = deserialize_widget(&bytes).hit().unwrap();
assert_eq!(decoded, Widget { n: 0xDEADBEEF });
}
// `has_cache_envelope` accepts anything starting with the magic (including the
// bare magic) and rejects short, unrelated, and empty inputs.
#[test]
fn has_cache_envelope_detects_magic() {
let bytes = serialize_widget(&Widget { n: 1 });
assert!(has_cache_envelope(&bytes));
assert!(has_cache_envelope(&MAGIC)); assert!(!has_cache_envelope(b"LCE")); assert!(!has_cache_envelope(b"JUNK and more"));
assert!(!has_cache_envelope(&[]));
}
// Corrupting the first magic byte makes the envelope invalid.
#[test]
fn wrong_magic_is_miss() {
let mut bytes = serialize_widget(&Widget { n: 7 }).to_vec();
bytes[0] = b'X';
assert_eq!(
miss_reason(&Bytes::from(bytes)),
Some(CacheMissReason::InvalidEnvelope)
);
}
// Blobs that do not start with the magic are rejected as invalid envelopes.
// NOTE(review): the first blob (u64 length followed by payload) presumably
// mimics entries written before the envelope existed — confirm.
#[test]
fn pre_stabilization_blob_is_miss() {
let mut blob = Vec::new();
blob.extend_from_slice(&(42u64).to_le_bytes());
blob.extend_from_slice(&[0u8; 42]);
assert_eq!(
miss_reason(&Bytes::from(blob)),
Some(CacheMissReason::InvalidEnvelope)
);
assert_eq!(
miss_reason(&Bytes::from(vec![0u8, 1, 2, 3])),
Some(CacheMissReason::InvalidEnvelope)
);
}
// Byte 4 (immediately after the 4-byte magic) is the envelope version; an
// unrecognized value invalidates the envelope.
#[test]
fn unknown_envelope_version_is_miss() {
let mut bytes = serialize_widget(&Widget { n: 7 }).to_vec();
bytes[4] = 0xFF; assert_eq!(
miss_reason(&Bytes::from(bytes)),
Some(CacheMissReason::InvalidEnvelope)
);
}
// A well-formed envelope for a different type id is a type mismatch,
// regardless of what the body contains.
#[test]
fn type_id_mismatch_is_miss() {
let mut buf = Vec::new();
write_envelope(&mut buf, "some.OtherType", 1).unwrap();
buf.extend_from_slice(&(4u64).to_le_bytes());
buf.extend_from_slice(&99u32.to_le_bytes());
assert_eq!(
miss_reason(&Bytes::from(buf)),
Some(CacheMissReason::TypeMismatch)
);
}
// An entry stamped with a version above `Widget::CURRENT_VERSION` is
// rejected before the body is decoded.
#[test]
fn unsupported_future_type_version_is_miss() {
let mut buf = Vec::new();
write_envelope(&mut buf, Widget::TYPE_ID, Widget::CURRENT_VERSION + 1).unwrap();
lance_arrow::ipc::write_len_prefixed_bytes(&mut buf, &9u32.to_le_bytes()).unwrap();
assert_eq!(
miss_reason(&Bytes::from(buf)),
Some(CacheMissReason::VersionTooNew)
);
}
// Every cut point falls inside the header (empty, mid-magic, magic only,
// through the version byte, through the length field, mid type id), so the
// envelope cannot be parsed.
#[test]
fn truncated_envelope_is_miss() {
let bytes = serialize_widget(&Widget { n: 7 });
for cut in [0, 1, 4, 5, 7, 9] {
assert_eq!(
miss_reason(&bytes.slice(..cut.min(bytes.len()))),
Some(CacheMissReason::InvalidEnvelope),
"truncating to {cut} bytes should miss as InvalidEnvelope"
);
}
}
// A valid envelope followed by a body that `Widget::deserialize` cannot
// accept is reported as a body error.
// NOTE(review): assumes `read_raw` expects a u64 length prefix, making this a
// one-byte payload that fails the four-byte conversion — confirm.
#[test]
fn body_decode_error_is_miss() {
let mut buf = Vec::new();
write_envelope(&mut buf, Widget::TYPE_ID, Widget::CURRENT_VERSION).unwrap();
buf.extend_from_slice(&(1u64).to_le_bytes());
buf.push(0u8);
assert_eq!(
miss_reason(&Bytes::from(buf)),
Some(CacheMissReason::BodyError)
);
}
// A reader constructed with a given version reports it via `version()` and
// can read the length-prefixed payload starting at the body offset.
#[test]
fn reader_exposes_envelope_version() {
let mut buf = Vec::new();
write_envelope(&mut buf, Widget::TYPE_ID, 7).unwrap();
let body_off = buf.len();
lance_arrow::ipc::write_len_prefixed_bytes(&mut buf, &5u32.to_le_bytes()).unwrap();
let data = Bytes::from(buf);
let mut r = CacheEntryReader::new(&data, body_off, 7);
assert_eq!(r.version(), 7);
assert_eq!(r.read_raw().unwrap().as_ref(), 5u32.to_le_bytes());
}
}