use std::io::Write;
use arrow_array::RecordBatch;
use bytes::Bytes;
use prost::Message;
use crate::{Error, Result};
pub struct CacheEntryWriter<'a> {
writer: &'a mut dyn Write,
pos: usize,
}
impl<'a> CacheEntryWriter<'a> {
pub fn new(writer: &'a mut dyn Write) -> Self {
Self { writer, pos: 0 }
}
pub(crate) fn with_pos(writer: &'a mut dyn Write, pos: usize) -> Self {
Self { writer, pos }
}
pub fn write_u8(&mut self, value: u8) -> Result<()> {
self.writer.write_all(&[value])?;
self.pos += 1;
Ok(())
}
pub fn write_header<P: Message>(&mut self, header: &P) -> Result<()> {
let bytes = header.encode_to_vec();
let len = u32::try_from(bytes.len())
.map_err(|_| Error::io(format!("cache header too large: {} bytes", bytes.len())))?;
self.writer.write_all(&len.to_le_bytes())?;
self.writer.write_all(&bytes)?;
self.pos += 4 + bytes.len();
Ok(())
}
pub fn write_ipc(&mut self, batch: &RecordBatch) -> Result<()> {
lance_arrow::ipc::write_ipc_section(self.writer, &mut self.pos, batch)
.map_err(|e| Error::io(e.to_string()))
}
pub fn write_ipc_batches<I>(&mut self, batches: I) -> Result<()>
where
I: IntoIterator<Item = RecordBatch>,
{
lance_arrow::ipc::write_ipc_section_batches(self.writer, &mut self.pos, batches)
.map_err(|e| Error::io(e.to_string()))
}
pub fn write_raw(&mut self, bytes: &[u8]) -> Result<()> {
lance_arrow::ipc::write_len_prefixed_bytes(self.writer, bytes)
.map_err(|e| Error::io(e.to_string()))?;
self.pos += 8 + bytes.len();
Ok(())
}
pub fn raw_writer(&mut self) -> &mut dyn Write {
self.writer
}
}
/// Cursor-based reader over a serialized cache entry held in a [`Bytes`]
/// buffer.
///
/// `read_u8` and `read_header` consume data at the cursor and advance it
/// only on success. The IPC and raw readers pass the cursor by mutable
/// reference to `lance_arrow::ipc` helpers, which manage it.
pub struct CacheEntryReader<'a> {
    /// Buffer that holds the serialized entry.
    data: &'a Bytes,
    /// Current cursor position within `data`, in bytes. It may exceed
    /// `data.len()` if the caller supplied such an offset to `new`.
    offset: usize,
    /// Version value supplied at construction. This type stores it as-is and
    /// does not interpret it.
    version: u32,
}

impl<'a> CacheEntryReader<'a> {
    /// Creates a reader over `data` with the cursor at `offset`.
    ///
    /// `version` is stored unchanged and returned by [`Self::version`].
    pub fn new(data: &'a Bytes, offset: usize, version: u32) -> Self {
        Self {
            data,
            offset,
            version,
        }
    }

    /// Returns the version value passed to [`Self::new`].
    pub fn version(&self) -> u32 {
        self.version
    }

    /// Reads one byte at the cursor and advances the cursor by one.
    ///
    /// # Errors
    /// Returns an I/O [`Error`] if the cursor is at or past the end of the
    /// buffer. In that case the cursor is not moved.
    pub fn read_u8(&mut self) -> Result<u8> {
        let bytes = self.data.as_ref();
        let v = *bytes
            .get(self.offset)
            .ok_or_else(|| Error::io("cache entry: truncated, missing tag byte".to_string()))?;
        self.offset += 1;
        Ok(v)
    }

    /// Reads a `prost` message framed as a little-endian `u32` length
    /// followed by that many body bytes. This is the layout that
    /// `CacheEntryWriter::write_header` produces.
    ///
    /// On success the cursor moves past the length prefix and the body.
    ///
    /// # Errors
    /// Returns an I/O [`Error`] if the prefix or body extends past the end of
    /// the buffer, or if the body fails to decode as `P`. The cursor is left
    /// unchanged on every error path.
    pub fn read_header<P: Message + Default>(&mut self) -> Result<P> {
        // Width of the `u32` length prefix.
        const LEN_PREFIX: usize = 4;

        let bytes = self.data.as_ref();
        let len_end = self
            .offset
            .checked_add(LEN_PREFIX)
            .filter(|&e| e <= bytes.len())
            .ok_or_else(|| Error::io("cache header: truncated length prefix".to_string()))?;
        let prefix: [u8; LEN_PREFIX] = bytes[self.offset..len_end]
            .try_into()
            .expect("length-prefix slice is exactly LEN_PREFIX bytes long");
        // u32 -> usize is lossless wherever usize is at least 32 bits wide.
        let len = u32::from_le_bytes(prefix) as usize;
        let data_end = len_end
            .checked_add(len)
            .filter(|&e| e <= bytes.len())
            .ok_or_else(|| Error::io("cache header: truncated body".to_string()))?;
        let msg = P::decode(&bytes[len_end..data_end])
            .map_err(|e| Error::io(format!("cache header decode failed: {e}")))?;
        self.offset = data_end;
        Ok(msg)
    }

    /// Reads one `RecordBatch` section at the cursor via
    /// `lance_arrow::ipc::read_ipc_section_at`, which receives the cursor by
    /// mutable reference.
    ///
    /// # Errors
    /// Any helper error is converted into an I/O [`Error`] using its string
    /// form.
    pub fn read_ipc(&mut self) -> Result<RecordBatch> {
        lance_arrow::ipc::read_ipc_section_at(self.data, &mut self.offset)
            .map_err(|e| Error::io(e.to_string()))
    }

    /// Reads a multi-batch section at the cursor via
    /// `lance_arrow::ipc::read_ipc_section_batches_at`, which receives the
    /// cursor by mutable reference.
    ///
    /// # Errors
    /// Any helper error is converted into an I/O [`Error`] using its string
    /// form.
    pub fn read_ipc_batches(&mut self) -> Result<Vec<RecordBatch>> {
        lance_arrow::ipc::read_ipc_section_batches_at(self.data, &mut self.offset)
            .map_err(|e| Error::io(e.to_string()))
    }

    /// Reads a length-prefixed byte blob at the cursor via
    /// `lance_arrow::ipc::read_len_prefixed_bytes_at`, which receives the
    /// cursor by mutable reference.
    ///
    /// # Errors
    /// Any helper error is converted into an I/O [`Error`] using its string
    /// form.
    pub fn read_raw(&mut self) -> Result<Bytes> {
        lance_arrow::ipc::read_len_prefixed_bytes_at(self.data, &mut self.offset)
            .map_err(|e| Error::io(e.to_string()))
    }

    /// Returns the bytes from the cursor to the end of the buffer, without
    /// advancing the cursor.
    ///
    /// If the cursor is past the end of the buffer (possible when `new` was
    /// given such an offset), an empty `Bytes` is returned. This avoids the
    /// panic `Bytes::slice` raises for an out-of-range start.
    pub fn body(&self) -> Bytes {
        let start = self.offset.min(self.data.len());
        self.data.slice(start..)
    }
}