use std::io::{Read, Write};
/// Encodes a stream of `T` values into an underlying sink.
///
/// Returned by [`Codec::writer`]. Call [`CodecWriter::write`] once per item,
/// then [`CodecWriter::finish`] to consume the writer.
pub trait CodecWriter<T> {
    /// Error type returned by this writer's operations.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Encodes a single `item`.
    ///
    /// # Errors
    /// Returns `Self::Error` if the item cannot be encoded or written.
    fn write(&mut self, item: &T) -> Result<(), Self::Error>;
    /// Consumes the writer, completing the stream.
    ///
    /// NOTE(review): implementations presumably flush any buffered output here
    /// (the in-file test writer does); confirm this is part of the contract.
    fn finish(self) -> Result<(), Self::Error>;
}
/// Decodes a stream of `T` values from an underlying source.
///
/// Returned by [`Codec::reader`].
pub trait CodecReader<T> {
    /// Error type returned by [`CodecReader::read`].
    type Error: std::error::Error + Send + Sync + 'static;
    /// Decodes the next item, returning `Ok(None)` when no items remain.
    ///
    /// NOTE(review): the in-file test reader returns `Ok(None)` only at a clean
    /// record boundary and an error for a partially written record; confirm
    /// that every implementation must follow the same rule.
    fn read(&mut self) -> Result<Option<T>, Self::Error>;
}
/// A copyable description of how `T` values are serialized.
///
/// A codec is a factory. It wraps any [`Write`] sink in a [`CodecWriter`] and any
/// [`Read`] source in a [`CodecReader`], and both halves share the codec's
/// `Error` type.
pub trait Codec<T>: Copy {
    /// Error type shared by this codec's writers and readers.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Writer produced by [`Codec::writer`] for a sink of type `W`.
    type Writer<W: Write>: CodecWriter<T, Error = Self::Error>;
    /// Reader produced by [`Codec::reader`] for a source of type `R`.
    type Reader<R: Read>: CodecReader<T, Error = Self::Error>;
    /// Wraps `dest` in an encoder for `T` values.
    fn writer<W: Write>(&self, dest: W) -> Self::Writer<W>;
    /// Wraps `source` in a decoder for `T` values.
    fn reader<R: Read>(&self, source: R) -> Self::Reader<R>;
}
/// Encodes a stream of `T` records, each written together with a key of type `K`.
pub trait KeyedCodecWriter<T, K> {
    /// Error type returned by this writer's operations.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Encodes `item` together with `key`.
    ///
    /// NOTE(review): nothing in this trait checks that `key` matches
    /// `KeyedCodec::derive_key(item)`; confirm whether callers must guarantee it.
    fn write_keyed(&mut self, item: &T, key: &K) -> Result<(), Self::Error>;
    /// Consumes the writer, completing the stream.
    ///
    /// NOTE(review): presumably flushes pending output; confirm per implementation.
    fn finish(self) -> Result<(), Self::Error>;
}
/// Decodes a keyed stream of `T` records whose keys have type `K`.
///
/// NOTE(review): no implementation is visible here. The method shapes suggest a
/// cursor protocol: `next_key` moves to the next record and returns its key
/// (`Ok(None)` once the stream is exhausted), and `current_record` then decodes
/// that record. Confirm this, and confirm what `current_record` does before the
/// first `next_key` or after `Ok(None)`.
pub trait KeyedCodecReader<T, K> {
    /// Error type returned by this reader's operations.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Returns the key of the next record, or `Ok(None)` if there is none
    /// (presumed semantics; see the trait-level note).
    fn next_key(&mut self) -> Result<Option<K>, Self::Error>;
    /// Decodes the record associated with the most recent key (presumed
    /// semantics; see the trait-level note).
    fn current_record(&mut self) -> Result<T, Self::Error>;
}
/// A [`Codec`] that can also associate each `T` with a key and read or write
/// records together with their keys.
pub trait KeyedCodec<T>: Codec<T> {
    /// Key type associated with each record.
    type Key: Clone;
    /// Keyed writer produced by [`KeyedCodec::keyed_writer`]; shares the codec's error type.
    type KeyedWriter<W: Write>: KeyedCodecWriter<T, Self::Key, Error = Self::Error>;
    /// Keyed reader produced by [`KeyedCodec::keyed_reader`]; shares the codec's error type.
    type KeyedReader<R: Read>: KeyedCodecReader<T, Self::Key, Error = Self::Error>;
    /// Computes the key for `item`.
    fn derive_key(&self, item: &T) -> Self::Key;
    /// Wraps `dest` in a writer that encodes records together with their keys.
    fn keyed_writer<W: Write>(&self, dest: W) -> Self::KeyedWriter<W>;
    /// Wraps `source` in a reader that decodes records together with their keys.
    fn keyed_reader<R: Read>(&self, source: R) -> Self::KeyedReader<R>;
}
#[cfg(test)]
mod tests {
    // Exercises the codec traits with a minimal fixed-width codec: each `u64`
    // is stored as exactly 8 little-endian bytes, with no framing or header.

    use std::io::{self, BufWriter, Cursor};

    use super::*;

    /// Test-only codec for `u64` values (8 little-endian bytes per item).
    #[derive(Clone, Copy)]
    struct U64Codec;

    /// Encoder half of [`U64Codec`]; output is buffered until `finish`.
    struct U64Writer<W: Write> {
        inner: BufWriter<W>,
    }

    impl<W: Write> CodecWriter<u64> for U64Writer<W> {
        type Error = io::Error;

        fn write(&mut self, item: &u64) -> Result<(), Self::Error> {
            self.inner.write_all(&item.to_le_bytes())
        }

        fn finish(mut self) -> Result<(), Self::Error> {
            // Flush explicitly: dropping a `BufWriter` also flushes, but it
            // discards any error, which would hide a failed final write.
            self.inner.flush()
        }
    }

    /// Decoder half of [`U64Codec`].
    struct U64Reader<R: Read> {
        inner: R,
    }

    impl<R: Read> CodecReader<u64> for U64Reader<R> {
        type Error = io::Error;

        /// Returns `Ok(None)` at a clean end of input, and an
        /// `ErrorKind::UnexpectedEof` error if the input ends mid-record.
        fn read(&mut self) -> Result<Option<u64>, Self::Error> {
            let mut buf = [0u8; 8];
            // Probe a single byte first, so that "no bytes left" (end of stream)
            // can be told apart from "some bytes left" (a possibly truncated record).
            loop {
                match self.inner.read(&mut buf[..1]) {
                    Ok(0) => return Ok(None),
                    Ok(_) => break,
                    // A bare `Read::read` may fail with `Interrupted` before any
                    // data is transferred; retry, as `read_exact` does internally.
                    Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                    Err(e) => return Err(e),
                }
            }
            // If fewer than 7 bytes remain, the record is truncated; `read_exact`
            // reports that as `ErrorKind::UnexpectedEof`.
            self.inner.read_exact(&mut buf[1..])?;
            Ok(Some(u64::from_le_bytes(buf)))
        }
    }

    impl Codec<u64> for U64Codec {
        type Error = io::Error;
        type Writer<W: Write> = U64Writer<W>;
        type Reader<R: Read> = U64Reader<R>;

        fn writer<W: Write>(&self, dest: W) -> U64Writer<W> {
            U64Writer {
                inner: BufWriter::new(dest),
            }
        }

        fn reader<R: Read>(&self, source: R) -> U64Reader<R> {
            U64Reader { inner: source }
        }
    }

    /// `Read` adapter that fails with `ErrorKind::Interrupted` on its first
    /// call and then delegates to `inner`.
    struct InterruptOnce<R> {
        interrupted: bool,
        inner: R,
    }

    impl<R: Read> Read for InterruptOnce<R> {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            if !self.interrupted {
                self.interrupted = true;
                return Err(io::Error::from(io::ErrorKind::Interrupted));
            }
            self.inner.read(buf)
        }
    }

    #[test]
    fn codec_round_trips_single_item() {
        let mut buf = Vec::new();
        let mut writer = U64Codec.writer(&mut buf);
        writer.write(&42u64).expect("write should succeed");
        writer.finish().expect("finish should succeed");
        assert_eq!(buf.len(), 8, "u64 should write exactly 8 bytes");

        let mut reader = U64Codec.reader(Cursor::new(&buf));
        let item = reader
            .read()
            .expect("read should succeed")
            .expect("should find one item");
        assert_eq!(item, 42, "round-tripped value should match");
        assert!(
            reader.read().expect("read should succeed").is_none(),
            "stream should end after the only item"
        );
    }

    #[test]
    fn codec_round_trips_multiple_items() {
        let values = [1u64, 2, 3, u64::MAX, 0];
        let mut buf = Vec::new();
        let mut writer = U64Codec.writer(&mut buf);
        for v in &values {
            writer.write(v).expect("write should succeed");
        }
        writer.finish().expect("finish should succeed");

        let mut reader = U64Codec.reader(Cursor::new(&buf));
        let recovered: Vec<u64> =
            std::iter::from_fn(|| reader.read().expect("read should succeed")).collect();
        assert_eq!(
            recovered, values,
            "all round-tripped values should match in order"
        );
    }

    #[test]
    fn codec_read_empty_returns_none() {
        let buf: Vec<u8> = Vec::new();
        let mut reader = U64Codec.reader(Cursor::new(&buf));
        let result = reader.read().expect("reading empty should not error");
        assert!(
            result.is_none(),
            "reading from an empty source should return None"
        );
    }

    #[test]
    fn codec_read_truncated_returns_error() {
        let buf = [0u8; 3];
        let mut reader = U64Codec.reader(Cursor::new(&buf[..]));
        let err = reader
            .read()
            .expect_err("reading a partial record should return an error, not None");
        assert_eq!(
            err.kind(),
            io::ErrorKind::UnexpectedEof,
            "a truncated record should be reported as UnexpectedEof"
        );
    }

    #[test]
    fn codec_read_retries_interrupted() {
        let bytes = 7u64.to_le_bytes();
        let source = InterruptOnce {
            interrupted: false,
            inner: Cursor::new(&bytes[..]),
        };
        let mut reader = U64Codec.reader(source);
        let item = reader
            .read()
            .expect("an interrupted read should be retried, not reported")
            .expect("should find one item");
        assert_eq!(item, 7, "value read after an interruption should match");
        assert!(
            reader.read().expect("read should succeed").is_none(),
            "stream should end after the only item"
        );
    }
}