use std::{fs::File, io::Write, path::Path};
use crate::{Header, Record, RECORD_SIZE};
/// Size in bytes of each writer's staging buffer: room for 48 Ki
/// (49 152) records of `RECORD_SIZE` bytes each.
const DEFAULT_BUFFER_SIZE: usize = RECORD_SIZE * (48 * 1024);

/// Boxed, type-erased sink; the `Send` bound lets a writer built on it move across threads.
pub type BoxedWriter = Box<dyn Write + Send>;
/// Buffered writer for fixed-size records.
///
/// Records are staged in an internal byte buffer and written to `inner` in
/// large chunks. The `W: Write` bound must live on the struct itself because
/// the `Drop` impl (which flushes) is bounded on it.
///
/// NOTE(review): the derived `Clone` also copies any staged, not-yet-flushed
/// bytes. Because `Drop` flushes, dropping both the original and a clone
/// emits those bytes once per copy.
#[derive(Clone)]
pub struct Writer<W>
where
    W: Write,
{
    /// Destination that staged bytes are flushed into.
    inner: W,
    /// Staging area; only `buffer[..pos]` holds pending data.
    buffer: Vec<u8>,
    /// Number of staged bytes currently held in `buffer`.
    pos: usize,
    /// Running count of records accepted by this writer.
    records_written: u64,
}
impl<W: Write> Writer<W> {
    /// Creates a writer that writes `header` to `inner` immediately, then
    /// stages subsequent records in a `DEFAULT_BUFFER_SIZE`-byte buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if writing the header bytes to `inner` fails.
    pub fn new(mut inner: W, header: Header) -> crate::Result<Self> {
        // The header bypasses the staging buffer so it always precedes records.
        let header_bytes: &[u8] = bytemuck::bytes_of(&header);
        inner.write_all(header_bytes)?;
        Ok(Self::new_headless(inner))
    }

    /// Creates a writer that emits only records; no header is written.
    ///
    /// Useful for scratch writers whose bytes are later merged into a headed
    /// writer with [`Writer::ingest`].
    pub fn new_headless(inner: W) -> Self {
        Self {
            inner,
            buffer: vec![0u8; DEFAULT_BUFFER_SIZE],
            pos: 0,
            records_written: 0,
        }
    }

    /// Number of records accepted so far.
    ///
    /// Records are counted as soon as they are staged, so this may include
    /// records that have not yet reached the inner writer.
    pub fn records_written(&self) -> u64 {
        self.records_written
    }

    /// Writes the staged bytes (`buffer[..pos]`) to the inner writer and resets
    /// the staging position. Does not call `flush` on the inner writer.
    fn flush_buffer(&mut self) -> crate::Result<()> {
        if self.pos > 0 {
            self.inner.write_all(&self.buffer[..self.pos])?;
            self.pos = 0;
        }
        Ok(())
    }

    /// Stages a single record, flushing the buffer first if it lacks room.
    ///
    /// # Errors
    ///
    /// Returns an error if a required flush to the inner writer fails.
    pub fn write_record(&mut self, record: &Record) -> crate::Result<()> {
        if self.pos + RECORD_SIZE > self.buffer.len() {
            self.flush_buffer()?;
        }
        let record_bytes: &[u8] = bytemuck::bytes_of(record);
        self.buffer[self.pos..self.pos + RECORD_SIZE].copy_from_slice(record_bytes);
        self.pos += RECORD_SIZE;
        self.records_written += 1;
        Ok(())
    }

    /// Writes a contiguous slice of records as one raw byte run.
    ///
    /// # Errors
    ///
    /// Returns an error if writing to the inner writer fails.
    pub fn write_batch(&mut self, records: &[Record]) -> crate::Result<()> {
        let records_bytes: &[u8] = bytemuck::cast_slice(records);
        self.write_slice(records_bytes)
    }

    /// Appends raw record bytes and counts `bytes.len() / RECORD_SIZE` records.
    ///
    /// Payloads larger than the whole staging buffer skip it. Pending bytes are
    /// flushed first to preserve order, then the payload goes straight to the
    /// inner writer. Smaller payloads are copied into the buffer, which is
    /// flushed each time it fills. Trailing bytes that do not form a whole
    /// record are still written but are not counted.
    fn write_slice(&mut self, bytes: &[u8]) -> crate::Result<()> {
        let num_records = bytes.len() / RECORD_SIZE;
        if bytes.len() > self.buffer.len() {
            self.flush_buffer()?;
            self.inner.write_all(bytes)?;
        } else {
            let mut remaining = bytes;
            while !remaining.is_empty() {
                let to_write = remaining.len().min(self.buffer.len() - self.pos);
                let (head, tail) = remaining.split_at(to_write);
                self.buffer[self.pos..self.pos + to_write].copy_from_slice(head);
                self.pos += to_write;
                remaining = tail;
                // A full buffer (including one left full by `write_record`) is
                // flushed here, so `to_write` is nonzero on the next pass.
                if self.pos >= self.buffer.len() {
                    self.flush_buffer()?;
                }
            }
        }
        self.records_written += num_records as u64;
        Ok(())
    }

    /// Stages every record yielded by `records`, one at a time.
    ///
    /// Accepts any `IntoIterator` (iterators, `Vec<Record>`, arrays, ...).
    /// On error, records staged before the failure stay staged and counted.
    ///
    /// # Errors
    ///
    /// Returns the first error from [`Writer::write_record`].
    pub fn write_iter<I>(&mut self, records: I) -> crate::Result<()>
    where
        I: IntoIterator<Item = Record>,
    {
        for record in records {
            self.write_record(&record)?;
        }
        Ok(())
    }

    /// Flushes the staged bytes, then flushes the inner writer.
    ///
    /// May be called repeatedly. Prefer calling it explicitly over relying on
    /// `Drop`, which discards any error.
    ///
    /// # Errors
    ///
    /// Returns an error if writing or flushing the inner writer fails.
    pub fn finish(&mut self) -> crate::Result<()> {
        self.flush_buffer()?;
        self.inner.flush()?;
        Ok(())
    }

    /// Moves every byte held by `other` into this writer: its staged bytes
    /// plus its inner `Vec`. Afterwards `other`'s inner `Vec` is cleared.
    ///
    /// `other` should be headless (see [`Writer::new_headless`]). Its bytes are
    /// copied verbatim and counted as records, so a header in `other` would be
    /// treated as record data. `other.records_written()` is left unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if writing to this writer fails. In that case
    /// `other`'s inner `Vec` is not cleared.
    pub fn ingest(&mut self, other: &mut Writer<Vec<u8>>) -> crate::Result<()> {
        other.flush_buffer()?;
        self.write_slice(&other.inner)?;
        other.inner.clear();
        Ok(())
    }

    /// Consumes the writer and returns the inner sink.
    ///
    /// Staged bytes are first written into the sink on a best-effort basis. As
    /// in `Drop`, a write error here is ignored. The inner writer's own `flush`
    /// is *not* called. Call [`Writer::finish`] beforehand to observe errors.
    pub fn into_inner(self) -> W {
        use std::mem::ManuallyDrop;
        // `Writer` implements `Drop`, so `inner` cannot be moved out directly.
        // Wrapping first means a panic during the flush below leaks the writer
        // instead of re-running `Drop` on a half-flushed state.
        let mut this = ManuallyDrop::new(self);
        // Without this, staged records would be silently discarded.
        let _ = this.flush_buffer();
        // `ManuallyDrop` suppresses every field destructor, so release the
        // staging buffer explicitly; otherwise its allocation leaks.
        drop(std::mem::take(&mut this.buffer));
        // SAFETY: `this` is never dropped, so `inner` is moved out exactly once
        // and is never dropped a second time. The remaining fields (`pos`,
        // `records_written`, and the now-empty `buffer`) own no resources.
        unsafe { std::ptr::read(&this.inner) }
    }
}
impl<W: Write> Drop for Writer<W> {
    /// Best-effort `finish` on drop. Any error is discarded because `drop`
    /// cannot report it; call `finish` explicitly to observe failures.
    fn drop(&mut self) {
        let _ = self.finish();
    }
}
impl Writer<BoxedWriter> {
    /// Creates (or truncates) the file at `path` and writes `header` to it.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be created or the header cannot be written.
    pub fn from_path<P: AsRef<Path>>(path: P, header: Header) -> crate::Result<Self> {
        let sink: BoxedWriter = Box::new(File::create(path)?);
        Self::new(sink, header)
    }

    /// Writes `header` to standard output and returns a writer targeting it.
    ///
    /// # Errors
    ///
    /// Fails if the header cannot be written to stdout.
    pub fn from_stdout(header: Header) -> crate::Result<Self> {
        let sink: BoxedWriter = Box::new(std::io::stdout());
        Self::new(sink, header)
    }

    /// Targets the file at `path` when one is given, otherwise standard output.
    ///
    /// # Errors
    ///
    /// Propagates the error of whichever constructor is chosen.
    pub fn from_optional_path<P: AsRef<Path>>(
        path: Option<P>,
        header: Header,
    ) -> crate::Result<Self> {
        if let Some(path) = path {
            Self::from_path(path, header)
        } else {
            Self::from_stdout(header)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Header, Reader, Record};
    use std::io::Cursor;

    /// Serialized header length. `Writer::new` writes
    /// `bytemuck::bytes_of(&header)`, which is exactly `size_of::<Header>()` bytes.
    const HEADER_SIZE: usize = std::mem::size_of::<Header>();

    /// Parses a headed byte stream back into records via `Reader`.
    fn read_back(bytes: Vec<u8>) -> Vec<Record> {
        let reader = Reader::new(Cursor::new(bytes)).unwrap();
        reader.collect::<Result<Vec<_>, _>>().unwrap()
    }

    #[test]
    fn test_on_disk_sizes() {
        // Pin the serialized layout here so a format change fails loudly in
        // one place instead of skewing every length check below.
        assert_eq!(HEADER_SIZE, 32);
        assert_eq!(RECORD_SIZE, 24);
    }

    #[test]
    fn test_writer_creation() {
        let header = Header::new(16, 12);
        let writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        assert_eq!(writer.records_written(), 0);
        // The header is written eagerly, so it is present without `finish`.
        let buffer = writer.into_inner();
        assert_eq!(buffer.len(), HEADER_SIZE);
    }

    #[test]
    fn test_writer_headless() {
        let writer = Writer::new_headless(Vec::<u8>::new());
        assert_eq!(writer.records_written(), 0);
        let buffer = writer.into_inner();
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_single_record_write() {
        let header = Header::new(16, 12);
        let mut writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        let record = Record::new(0x1234, 0x5678, 42);
        writer.write_record(&record).unwrap();
        assert_eq!(writer.records_written(), 1);
        writer.finish().unwrap();
        let buffer = writer.into_inner();
        assert_eq!(buffer.len(), HEADER_SIZE + RECORD_SIZE);
    }

    #[test]
    fn test_batch_write() {
        let header = Header::new(16, 12);
        let mut writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        let records = vec![
            Record::new(1, 2, 3),
            Record::new(4, 5, 6),
            Record::new(7, 8, 9),
        ];
        writer.write_batch(&records).unwrap();
        assert_eq!(writer.records_written(), 3);
        writer.finish().unwrap();
        let buffer = writer.into_inner();
        assert_eq!(buffer.len(), HEADER_SIZE + 3 * RECORD_SIZE);
    }

    #[test]
    fn test_iterator_write() {
        let header = Header::new(16, 12);
        let mut writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        let records = (0..100).map(|i| Record::new(i, i * 2, i * 3));
        writer.write_iter(records).unwrap();
        assert_eq!(writer.records_written(), 100);
        writer.finish().unwrap();
        assert_eq!(writer.inner.len(), HEADER_SIZE + 100 * RECORD_SIZE);
    }

    #[test]
    fn test_large_batch_direct_write() {
        let header = Header::new(16, 12);
        let mut writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        let large_batch: Vec<Record> =
            (0..100_000).map(|i| Record::new(i, i * 2, i * 3)).collect();
        // Precondition: the batch must exceed the staging buffer to take the
        // direct-write path.
        assert!(large_batch.len() * RECORD_SIZE > DEFAULT_BUFFER_SIZE);
        writer.write_batch(&large_batch).unwrap();
        assert_eq!(writer.records_written(), 100_000);
        // Direct writes bypass the staging buffer, so the bytes are already
        // in `inner` without calling `finish`.
        assert_eq!(
            writer.inner.len(),
            HEADER_SIZE + large_batch.len() * RECORD_SIZE
        );
    }

    #[test]
    fn test_writer_ingest() {
        let header = Header::new(16, 12);
        let mut main_writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        let mut aux_writer = Writer::new_headless(Vec::<u8>::new());
        aux_writer.write_record(&Record::new(1, 2, 3)).unwrap();
        aux_writer.write_record(&Record::new(4, 5, 6)).unwrap();
        main_writer.ingest(&mut aux_writer).unwrap();
        assert_eq!(main_writer.records_written(), 2);
        assert!(aux_writer.inner.is_empty());
        // The ingested records must appear, in order, after the main header.
        main_writer.finish().unwrap();
        let records = read_back(main_writer.into_inner());
        assert_eq!(records, vec![Record::new(1, 2, 3), Record::new(4, 5, 6)]);
    }

    #[test]
    fn test_writer_roundtrip() {
        let header = Header::new(20, 10);
        let original_records = vec![
            Record::new(0x12345, 0x67890, 100),
            Record::new(0xABCDE, 0xF0123, 200),
        ];
        let mut writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        writer.write_batch(&original_records).unwrap();
        writer.finish().unwrap();
        assert_eq!(read_back(writer.into_inner()), original_records);
    }

    #[test]
    fn test_buffer_flushing() {
        let header = Header::new(16, 12);
        let mut writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        // Fill the staging buffer; nothing beyond the header is flushed yet.
        let records_to_fill = DEFAULT_BUFFER_SIZE / RECORD_SIZE;
        for i in 0..records_to_fill {
            writer.write_record(&Record::new(i as u64, 0, 0)).unwrap();
        }
        assert_eq!(writer.inner.len(), HEADER_SIZE);
        // One more record does not fit, forcing the staged records out first.
        writer.write_record(&Record::new(999, 0, 0)).unwrap();
        assert_eq!(
            writer.inner.len(),
            HEADER_SIZE + records_to_fill * RECORD_SIZE
        );
    }

    #[test]
    fn test_records_written_counter() {
        let header = Header::new(16, 12);
        let mut writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        assert_eq!(writer.records_written(), 0);
        writer.write_record(&Record::new(1, 2, 3)).unwrap();
        assert_eq!(writer.records_written(), 1);
        let batch = vec![Record::new(4, 5, 6), Record::new(7, 8, 9)];
        writer.write_batch(&batch).unwrap();
        assert_eq!(writer.records_written(), 3);
        let iter_records = (10..15).map(|i| Record::new(i, i, i));
        writer.write_iter(iter_records).unwrap();
        assert_eq!(writer.records_written(), 8);
    }

    #[test]
    fn test_drop_behavior() {
        let header = Header::new(16, 12);
        // Write through `&mut Vec` so the sink outlives the writer and the
        // effect of `Drop` can be observed.
        let mut sink: Vec<u8> = Vec::new();
        let mut writer = Writer::new(&mut sink, header).unwrap();
        writer.write_record(&Record::new(1, 2, 3)).unwrap();
        // Only the header has reached the sink; the record is still staged.
        assert_eq!(writer.inner.len(), HEADER_SIZE);
        drop(writer);
        // Dropping must flush the staged record.
        assert_eq!(sink.len(), HEADER_SIZE + RECORD_SIZE);
        assert_eq!(read_back(sink), vec![Record::new(1, 2, 3)]);
    }

    #[test]
    fn test_empty_batch() {
        let header = Header::new(16, 12);
        let mut writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        let empty_batch: Vec<Record> = vec![];
        writer.write_batch(&empty_batch).unwrap();
        assert_eq!(writer.records_written(), 0);
        writer.finish().unwrap();
        assert_eq!(writer.inner.len(), HEADER_SIZE);
    }

    #[test]
    fn test_mixed_write_methods() {
        let header = Header::new(16, 12);
        let mut writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        writer.write_record(&Record::new(1, 2, 3)).unwrap();
        let batch = vec![Record::new(4, 5, 6), Record::new(7, 8, 9)];
        writer.write_batch(&batch).unwrap();
        let iter_records = (10..13).map(|i| Record::new(i, i * 2, i * 3));
        writer.write_iter(iter_records).unwrap();
        assert_eq!(writer.records_written(), 6);
        writer.finish().unwrap();
        let read_records = read_back(writer.into_inner());
        assert_eq!(read_records.len(), 6);
        assert_eq!(read_records[0], Record::new(1, 2, 3));
        assert_eq!(read_records[1], Record::new(4, 5, 6));
        assert_eq!(read_records[5], Record::new(12, 24, 36));
    }

    #[test]
    fn test_writer_clone() {
        let header = Header::new(16, 12);
        let writer = Writer::new(Vec::<u8>::new(), header).unwrap();
        let writer_clone = writer.clone();
        assert_eq!(writer.records_written(), writer_clone.records_written());
    }
}