use std::io::{self, Read, Seek, SeekFrom, Write};
use std::path::PathBuf;
use memmap2::Mmap;
use tempfile::NamedTempFile;
/// Settings that control how a `LargeBodyBuffer` stores its data.
#[derive(Debug, Clone)]
pub struct LargeBodyBufferConfig {
/// Total size, in bytes, at which `write_chunk` moves the data out of
/// memory and into a temporary file (the comparison is `>=`).
pub mmap_threshold: usize,
/// Largest total body size, in bytes, that `write_chunk` accepts; a write
/// that would push the total above this value is rejected with an error.
pub max_body_size: usize,
/// Directory in which the temporary file is created. When `None`, the
/// file is created with `NamedTempFile::new()` (the `tempfile` crate's
/// default location).
pub temp_dir: Option<PathBuf>,
}
impl Default for LargeBodyBufferConfig {
    /// Returns the stock settings: spill to a temporary file at 1 MiB, cap
    /// the body at 100 MiB, and leave the temporary directory unset.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        Self {
            temp_dir: None,
            max_body_size: 100 * MIB,
            mmap_threshold: MIB,
        }
    }
}
/// Where the buffered bytes currently live.
enum BufferStorage {
/// Bytes held in a heap-allocated vector.
Memory(Vec<u8>),
/// Bytes written to a temporary file, optionally with a cached mapping.
Mmap {
/// Temporary file that holds the bytes.
file: NamedTempFile,
/// Number of body bytes recorded as written to `file`.
len: usize,
/// Cached mapping of `file`. It is `None` until `as_slice` creates it,
/// and it is reset to `None` whenever the file is written to again.
mmap: Option<Mmap>,
},
}
/// Accumulates a body chunk by chunk, keeping it in memory while it is small
/// and moving it to a temporary file once it reaches the configured
/// threshold.
pub struct LargeBodyBuffer {
/// Thresholds and temporary-directory choice for this buffer.
config: LargeBodyBufferConfig,
/// Current backing store (in-memory vector or temporary file).
storage: BufferStorage,
/// Total number of bytes accepted by `write_chunk` so far.
total_written: usize,
}
impl LargeBodyBuffer {
pub fn new() -> Self {
Self::with_config(LargeBodyBufferConfig::default())
}
pub fn with_config(config: LargeBodyBufferConfig) -> Self {
Self {
config,
storage: BufferStorage::Memory(Vec::new()),
total_written: 0,
}
}
pub fn config(&self) -> &LargeBodyBufferConfig {
&self.config
}
pub fn len(&self) -> usize {
self.total_written
}
pub fn is_empty(&self) -> bool {
self.total_written == 0
}
pub fn is_mmap(&self) -> bool {
matches!(self.storage, BufferStorage::Mmap { .. })
}
pub fn write_chunk(&mut self, data: &[u8]) -> io::Result<()> {
if data.is_empty() {
return Ok(());
}
let new_size = self.total_written + data.len();
if new_size > self.config.max_body_size {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Body size {} exceeds maximum {}",
new_size, self.config.max_body_size
),
));
}
if new_size >= self.config.mmap_threshold {
self.ensure_mmap()?;
}
match &mut self.storage {
BufferStorage::Memory(vec) => {
vec.extend_from_slice(data);
}
BufferStorage::Mmap { file, len, mmap } => {
*mmap = None;
file.as_file_mut().seek(SeekFrom::End(0))?;
file.as_file_mut().write_all(data)?;
*len = new_size;
}
}
self.total_written = new_size;
Ok(())
}
pub fn as_slice(&mut self) -> io::Result<&[u8]> {
match &mut self.storage {
BufferStorage::Memory(vec) => Ok(vec.as_slice()),
BufferStorage::Mmap { file, len, mmap } => {
if mmap.is_none() {
file.as_file_mut().sync_all()?;
let mapped = unsafe { Mmap::map(file.as_file())? };
*mmap = Some(mapped);
}
Ok(&mmap.as_ref().unwrap()[..*len])
}
}
}
pub fn as_mut_slice(&mut self) -> io::Result<&mut [u8]> {
self.convert_mmap_to_memory()?;
if let BufferStorage::Memory(ref mut vec) = self.storage {
Ok(vec.as_mut_slice())
} else {
unreachable!("convert_mmap_to_memory should have converted to Memory")
}
}
fn convert_mmap_to_memory(&mut self) -> io::Result<()> {
if let BufferStorage::Mmap { file, len, mmap } = &mut self.storage {
*mmap = None;
file.as_file_mut().sync_all()?;
let data_len = *len;
let mut vec = Vec::with_capacity(data_len);
file.as_file_mut().seek(SeekFrom::Start(0))?;
file.as_file_mut().read_to_end(&mut vec)?;
let new_storage = BufferStorage::Memory(vec);
self.storage = new_storage;
}
Ok(())
}
pub fn clear(&mut self) {
self.storage = BufferStorage::Memory(Vec::new());
self.total_written = 0;
}
pub fn into_vec(mut self) -> io::Result<Vec<u8>> {
match &mut self.storage {
BufferStorage::Memory(vec) => Ok(std::mem::take(vec)),
BufferStorage::Mmap { file, len, mmap } => {
*mmap = None;
let mut vec = Vec::with_capacity(*len);
file.as_file_mut().seek(SeekFrom::Start(0))?;
file.as_file_mut().read_to_end(&mut vec)?;
Ok(vec)
}
}
}
fn ensure_mmap(&mut self) -> io::Result<()> {
if matches!(self.storage, BufferStorage::Mmap { .. }) {
return Ok(());
}
let temp_file = match &self.config.temp_dir {
Some(dir) => NamedTempFile::new_in(dir)?,
None => NamedTempFile::new()?,
};
if let BufferStorage::Memory(vec) = &self.storage {
temp_file.as_file().write_all(vec)?;
}
self.storage = BufferStorage::Mmap {
file: temp_file,
len: self.total_written,
mmap: None,
};
Ok(())
}
}
impl Default for LargeBodyBuffer {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for LargeBodyBuffer {
    /// Formats the buffer as its byte count, storage mode, and configuration.
    /// The buffered bytes themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let file_backed = self.is_mmap();

        let mut out = f.debug_struct("LargeBodyBuffer");
        out.field("len", &self.total_written);
        out.field("is_mmap", &file_backed);
        out.field("config", &self.config);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a buffer with the given spill threshold and size cap, using
    /// the default temporary directory.
    fn buffer_with(mmap_threshold: usize, max_body_size: usize) -> LargeBodyBuffer {
        LargeBodyBuffer::with_config(LargeBodyBufferConfig {
            mmap_threshold,
            max_body_size,
            temp_dir: None,
        })
    }

    #[test]
    fn test_config_default() {
        let LargeBodyBufferConfig {
            mmap_threshold,
            max_body_size,
            temp_dir,
        } = LargeBodyBufferConfig::default();

        assert_eq!(mmap_threshold, 1024 * 1024);
        assert_eq!(max_body_size, 100 * 1024 * 1024);
        assert!(temp_dir.is_none());
    }

    #[test]
    fn test_new_buffer() {
        let fresh = LargeBodyBuffer::new();

        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
        assert!(!fresh.is_mmap());
    }

    #[test]
    fn test_small_body_stays_in_memory() {
        let mut buffer = buffer_with(1024, 10 * 1024);

        buffer.write_chunk(b"hello world").unwrap();

        assert!(!buffer.is_mmap());
        assert_eq!(buffer.len(), 11);
        assert_eq!(buffer.as_slice().unwrap(), b"hello world");
    }

    #[test]
    fn test_large_body_uses_mmap() {
        let mut buffer = buffer_with(100, 10 * 1024);
        let payload = vec![0u8; 200];

        buffer.write_chunk(&payload).unwrap();

        assert!(buffer.is_mmap());
        assert_eq!(buffer.len(), 200);
        assert_eq!(buffer.as_slice().unwrap().len(), 200);
    }

    #[test]
    fn test_transition_to_mmap_preserves_data() {
        let mut buffer = buffer_with(50, 1024);

        // 13 bytes: still below the 50-byte threshold.
        buffer.write_chunk(b"initial data ").unwrap();
        assert!(!buffer.is_mmap());

        // 13 + 50 bytes crosses the threshold.
        let filler = vec![b'x'; 50];
        buffer.write_chunk(&filler).unwrap();
        assert!(buffer.is_mmap());

        let contents = buffer.as_slice().unwrap();
        assert!(contents.starts_with(b"initial data "));
        assert_eq!(contents.len(), 13 + 50);
    }

    #[test]
    fn test_max_body_size_enforced() {
        let mut buffer = buffer_with(1024, 100);
        let oversized = vec![0u8; 101];

        let outcome = buffer.write_chunk(&oversized);

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_into_vec_memory() {
        let mut buffer = LargeBodyBuffer::new();
        buffer.write_chunk(b"test data").unwrap();

        let bytes = buffer.into_vec().unwrap();

        assert_eq!(bytes, b"test data");
    }

    #[test]
    fn test_into_vec_mmap() {
        let mut buffer = buffer_with(10, 1024);
        let payload = b"this is some larger data that exceeds threshold";

        buffer.write_chunk(payload).unwrap();
        assert!(buffer.is_mmap());

        let bytes = buffer.into_vec().unwrap();
        assert_eq!(bytes, payload);
    }

    #[test]
    fn test_clear() {
        let mut buffer = LargeBodyBuffer::new();
        buffer.write_chunk(b"some data").unwrap();
        assert!(!buffer.is_empty());

        buffer.clear();

        assert!(buffer.is_empty());
        assert_eq!(buffer.len(), 0);
        assert!(!buffer.is_mmap());
    }

    #[test]
    fn test_multiple_chunks() {
        let mut buffer = LargeBodyBuffer::new();
        let first: &[u8] = b"chunk1 ";
        let second: &[u8] = b"chunk2 ";
        let third: &[u8] = b"chunk3";

        for part in [first, second, third] {
            buffer.write_chunk(part).unwrap();
        }

        assert_eq!(buffer.as_slice().unwrap(), b"chunk1 chunk2 chunk3");
    }

    #[test]
    fn test_empty_chunk() {
        let mut buffer = LargeBodyBuffer::new();

        buffer.write_chunk(b"").unwrap();
        assert!(buffer.is_empty());

        buffer.write_chunk(b"data").unwrap();
        buffer.write_chunk(b"").unwrap();
        assert_eq!(buffer.len(), 4);
    }
}