use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::Path;
use tempfile::NamedTempFile;
/// Size in bytes (64 KiB) of the scratch buffer used for chunked reads, and
/// the capacity given to the `BufWriter` that wraps the spill file.
const READ_CHUNK_SIZE: usize = 64 * 1024;
/// Byte buffer that starts out in memory and can be moved to a temporary file.
pub enum Buffer {
/// All bytes are held in a `Vec<u8>`.
InMemory(Vec<u8>),
/// Bytes have been written to a named temporary file.
Spilled {
/// Buffered writer wrapping the temporary spill file.
writer: BufWriter<NamedTempFile>,
/// Number of bytes written through `writer` so far.
len: u64,
},
}
impl Buffer {
    /// Creates an empty buffer that holds its contents in memory.
    pub fn new() -> Self {
        Buffer::InMemory(Vec::new())
    }

    /// Returns the number of bytes stored so far, whether they are held in
    /// memory or have been written to the spill file.
    pub fn len(&self) -> u64 {
        match self {
            Buffer::InMemory(v) => v.len() as u64,
            Buffer::Spilled { len, .. } => *len,
        }
    }

    /// Returns `true` when no bytes have been stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Reads `reader` until end-of-stream, appending everything to this buffer.
    ///
    /// Data is read in chunks of `READ_CHUNK_SIZE` bytes and handed to
    /// [`Buffer::append`], which spills to a temporary file inside `spill_dir`
    /// once the total would exceed `threshold` bytes.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `reader` (other than
    /// [`io::ErrorKind::Interrupted`], which is retried) or by
    /// [`Buffer::append`]. When the `cli` feature is enabled, an error of kind
    /// [`io::ErrorKind::Interrupted`] is returned if cancellation has been
    /// requested.
    pub fn drain_reader<R: Read>(
        &mut self,
        mut reader: R,
        threshold: usize,
        spill_dir: &Path,
    ) -> io::Result<()> {
        let mut chunk = vec![0u8; READ_CHUNK_SIZE];
        loop {
            #[cfg(feature = "cli")]
            if crate::signal::is_cancelled() {
                return Err(io::Error::new(
                    io::ErrorKind::Interrupted,
                    "rusty-sponge: cancelled by signal",
                ));
            }
            let n = match reader.read(&mut chunk) {
                Ok(0) => break,
                Ok(n) => n,
                // An interrupted read is non-fatal: go back to the top of the
                // loop so the cancellation check runs again before retrying.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            };
            self.append(&chunk[..n], threshold, spill_dir)?;
        }
        Ok(())
    }

    /// Appends `bytes` to the buffer.
    ///
    /// If the buffer is still in memory and adding `bytes` would make its
    /// length exceed `threshold`, the existing contents are first moved to a
    /// temporary file in `spill_dir`; `bytes` then go to that file as well.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating or writing the spill file.
    pub fn append(&mut self, bytes: &[u8], threshold: usize, spill_dir: &Path) -> io::Result<()> {
        let threshold_u64 = threshold as u64;
        // Saturate rather than overflow; a saturated value still exceeds any
        // threshold that fits in `usize`-sized memory.
        let projected_len = self.len().saturating_add(bytes.len() as u64);
        if matches!(self, Buffer::InMemory(_)) && projected_len > threshold_u64 {
            self.transition_to_spilled(spill_dir)?;
        }
        match self {
            Buffer::InMemory(v) => v.extend_from_slice(bytes),
            Buffer::Spilled { writer, len } => {
                writer.write_all(bytes)?;
                *len += bytes.len() as u64;
            }
        }
        Ok(())
    }

    /// Moves an in-memory buffer into a temporary file created in `spill_dir`.
    ///
    /// Calling this on a buffer that is already spilled is a no-op. If the
    /// temporary file cannot be created or written, the buffer is left
    /// untouched in its in-memory state so no data is lost.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating or writing the spill file.
    pub fn transition_to_spilled(&mut self, spill_dir: &Path) -> io::Result<()> {
        // Inspect the buffer by reference: nothing is moved out of `self`
        // until the spill file has been fully populated.
        let Buffer::InMemory(bytes) = &*self else {
            return Ok(());
        };
        let tempfile = tempfile::Builder::new()
            .prefix(".rusty-sponge-spill-")
            .tempfile_in(spill_dir)?;
        let mut writer = BufWriter::with_capacity(READ_CHUNK_SIZE, tempfile);
        writer.write_all(bytes.as_slice())?;
        let len = bytes.len() as u64;
        *self = Buffer::Spilled { writer, len };
        Ok(())
    }

    /// Consumes the buffer and writes all of its bytes to `out`.
    ///
    /// For a spilled buffer the pending writes are flushed, the temporary file
    /// is rewound to its start, and its contents are copied to `out` in chunks
    /// of `READ_CHUNK_SIZE` bytes. `out` itself is not flushed.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while flushing, seeking or reading the
    /// spill file, or while writing to `out`. A flush failure keeps the
    /// underlying error's [`io::ErrorKind`].
    pub fn write_to<W: Write>(self, mut out: W) -> io::Result<()> {
        match self {
            Buffer::InMemory(v) => out.write_all(&v),
            Buffer::Spilled { writer, .. } => {
                let mut tempfile = writer.into_inner().map_err(|e| {
                    io::Error::new(e.error().kind(), format!("BufWriter flush failed: {e}"))
                })?;
                tempfile.as_file_mut().seek(SeekFrom::Start(0))?;
                let mut chunk = vec![0u8; READ_CHUNK_SIZE];
                let mut reader = tempfile.as_file();
                loop {
                    let n = match reader.read(&mut chunk) {
                        Ok(0) => break,
                        Ok(n) => n,
                        // Interrupted reads are non-fatal; retry.
                        Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                        Err(e) => return Err(e),
                    };
                    out.write_all(&chunk[..n])?;
                }
                Ok(())
            }
        }
    }
}
impl Default for Buffer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Spill threshold comfortably above every small fixture used below.
    const ONE_MIB: usize = 1024 * 1024;

    /// Drains `input` into a fresh [`Buffer`] using `threshold`, with a newly
    /// created temporary directory as the spill location.
    ///
    /// The directory guard is returned alongside the buffer so the caller can
    /// keep it alive for as long as the buffer is in use.
    fn drained(input: &[u8], threshold: usize) -> (tempfile::TempDir, Buffer) {
        let scratch = tempfile::tempdir().unwrap();
        let mut buffer = Buffer::new();
        buffer
            .drain_reader(Cursor::new(input), threshold, scratch.path())
            .unwrap();
        (scratch, buffer)
    }

    /// Consumes `buffer` and returns everything `write_to` emitted.
    fn written(buffer: Buffer) -> Vec<u8> {
        let mut sink = Vec::new();
        buffer.write_to(&mut sink).unwrap();
        sink
    }

    #[test]
    fn empty_buffer_has_len_zero() {
        assert_eq!(Buffer::new().len(), 0);
    }

    #[test]
    fn drain_small_input_stays_in_memory() {
        let (_scratch, buffer) = drained(b"hello world\n", ONE_MIB);
        assert!(matches!(buffer, Buffer::InMemory(_)));
        assert_eq!(buffer.len(), 12);
    }

    #[test]
    fn drain_large_input_transitions_to_spilled() {
        let payload = vec![0xAAu8; 256 * 1024];
        let (_scratch, buffer) = drained(&payload, 64 * 1024);
        assert!(matches!(buffer, Buffer::Spilled { .. }));
        assert_eq!(buffer.len(), 256 * 1024);
    }

    #[test]
    fn write_to_roundtrips_in_memory() {
        let (_scratch, buffer) = drained(b"abc\n", ONE_MIB);
        assert_eq!(written(buffer), b"abc\n");
    }

    #[test]
    fn write_to_roundtrips_spilled() {
        let payload: Vec<u8> = (0u8..=255u8).cycle().take(256 * 1024).collect();
        let (_scratch, buffer) = drained(&payload, 1024);
        assert!(matches!(buffer, Buffer::Spilled { .. }));
        assert_eq!(written(buffer), payload);
    }

    #[test]
    fn binary_bytes_pass_through_unchanged() {
        // Includes NUL, high bytes and an invalid UTF-8 sequence.
        let raw: &[u8] = &[0x00, 0xFE, 0xFF, 0xC3, 0x28, 0xA0, 0xA1];
        let (_scratch, buffer) = drained(raw, ONE_MIB);
        assert_eq!(written(buffer), raw);
    }

    #[test]
    fn empty_input_writes_zero_bytes() {
        let (_scratch, buffer) = drained(&[], ONE_MIB);
        assert_eq!(written(buffer), Vec::<u8>::new());
    }
}