use std::io::Read;
use std::io::Write;
use std::io::{self};
use crate::ExtractionError;
const COPY_BUFFER_SIZE: usize = 64 * 1024;
#[derive(Debug)]
pub struct CopyBuffer {
#[allow(clippy::large_stack_arrays)]
buf: [u8; COPY_BUFFER_SIZE],
}
impl CopyBuffer {
#[inline]
#[must_use]
#[allow(clippy::large_stack_arrays)]
pub fn new() -> Self {
Self {
buf: [0u8; COPY_BUFFER_SIZE],
}
}
#[inline]
#[must_use]
pub fn size(&self) -> usize {
COPY_BUFFER_SIZE
}
}
impl Default for CopyBuffer {
fn default() -> Self {
Self::new()
}
}
#[inline]
pub fn copy_with_buffer<R: Read, W: Write>(
reader: &mut R,
writer: &mut W,
buffer: &mut CopyBuffer,
) -> Result<u64, ExtractionError> {
let mut total: u64 = 0;
loop {
let bytes_read = match reader.read(&mut buffer.buf) {
Ok(0) => break,
Ok(n) => n,
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
Err(e) => return Err(ExtractionError::Io(e)),
};
writer
.write_all(&buffer.buf[..bytes_read])
.map_err(ExtractionError::Io)?;
total = total
.checked_add(bytes_read as u64)
.ok_or(ExtractionError::QuotaExceeded {
resource: crate::QuotaResource::IntegerOverflow,
})?;
}
Ok(total)
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use std::io::Cursor;
#[test]
fn test_copy_buffer_new() {
let buffer = CopyBuffer::new();
assert_eq!(buffer.size(), 64 * 1024);
}
#[test]
fn test_copy_buffer_default() {
let buffer = CopyBuffer::default();
assert_eq!(buffer.size(), 64 * 1024);
}
#[test]
fn test_copy_empty_source() {
let mut buffer = CopyBuffer::new();
let mut input = Cursor::new(Vec::<u8>::new());
let mut output = Vec::new();
let result = copy_with_buffer(&mut input, &mut output, &mut buffer);
assert!(result.is_ok());
assert_eq!(result.unwrap(), 0);
assert_eq!(output.len(), 0);
}
#[test]
fn test_copy_small_data() {
let mut buffer = CopyBuffer::new();
let input_data = b"Hello, World!";
let mut input = Cursor::new(input_data);
let mut output = Vec::new();
let result = copy_with_buffer(&mut input, &mut output, &mut buffer);
assert!(result.is_ok());
assert_eq!(result.unwrap(), input_data.len() as u64);
assert_eq!(output, input_data);
}
#[test]
fn test_copy_large_data() {
let mut buffer = CopyBuffer::new();
let input_data = vec![0x42u8; 1024 * 1024];
let mut input = Cursor::new(&input_data);
let mut output = Vec::new();
let result = copy_with_buffer(&mut input, &mut output, &mut buffer);
assert!(result.is_ok());
assert_eq!(result.unwrap(), input_data.len() as u64);
assert_eq!(output, input_data);
}
#[test]
fn test_copy_exact_buffer_size() {
let mut buffer = CopyBuffer::new();
let input_data = vec![0xAAu8; COPY_BUFFER_SIZE];
let mut input = Cursor::new(&input_data);
let mut output = Vec::new();
let result = copy_with_buffer(&mut input, &mut output, &mut buffer);
assert!(result.is_ok());
assert_eq!(result.unwrap(), COPY_BUFFER_SIZE as u64);
assert_eq!(output, input_data);
}
#[test]
fn test_copy_multiple_chunks() {
let mut buffer = CopyBuffer::new();
let input_data = vec![0x55u8; COPY_BUFFER_SIZE * 3 + 1000];
let mut input = Cursor::new(&input_data);
let mut output = Vec::new();
let result = copy_with_buffer(&mut input, &mut output, &mut buffer);
assert!(result.is_ok());
assert_eq!(result.unwrap(), input_data.len() as u64);
assert_eq!(output, input_data);
}
#[test]
fn test_copy_reusable_buffer() {
let mut buffer = CopyBuffer::new();
let data1 = b"First copy";
let mut input1 = Cursor::new(data1);
let mut output1 = Vec::new();
let result1 = copy_with_buffer(&mut input1, &mut output1, &mut buffer);
assert!(result1.is_ok());
assert_eq!(output1, data1);
let data2 = b"Second copy with different data";
let mut input2 = Cursor::new(data2);
let mut output2 = Vec::new();
let result2 = copy_with_buffer(&mut input2, &mut output2, &mut buffer);
assert!(result2.is_ok());
assert_eq!(output2, data2);
}
#[test]
fn test_copy_byte_for_byte_correctness() {
let mut buffer = CopyBuffer::new();
let mut input_data = Vec::new();
for i in 0..=255u8 {
input_data.extend_from_slice(&[i; 256]);
}
let mut input = Cursor::new(&input_data);
let mut output = Vec::new();
let result = copy_with_buffer(&mut input, &mut output, &mut buffer);
assert!(result.is_ok());
assert_eq!(result.unwrap(), input_data.len() as u64);
assert_eq!(output, input_data);
}
#[test]
fn test_copy_with_interrupted_reads() {
use std::io::Error;
use std::io::ErrorKind;
struct InterruptedReader {
data: Vec<u8>,
position: usize,
interrupt_count: usize,
}
impl Read for InterruptedReader {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
if self.interrupt_count.is_multiple_of(3) && self.position < self.data.len() {
self.interrupt_count += 1;
return Err(Error::new(ErrorKind::Interrupted, "interrupted"));
}
self.interrupt_count += 1;
if self.position >= self.data.len() {
return Ok(0); }
let remaining = self.data.len() - self.position;
let to_read = remaining.min(buf.len());
buf[..to_read].copy_from_slice(&self.data[self.position..self.position + to_read]);
self.position += to_read;
Ok(to_read)
}
}
let test_data = vec![0x42u8; 1000];
let mut reader = InterruptedReader {
data: test_data.clone(),
position: 0,
interrupt_count: 0,
};
let mut buffer = CopyBuffer::new();
let mut output = Vec::new();
let result = copy_with_buffer(&mut reader, &mut output, &mut buffer);
assert!(result.is_ok(), "copy should handle interrupted reads");
assert_eq!(
output, test_data,
"data should be copied correctly despite interruptions"
);
}
#[test]
fn test_copy_with_write_failure() {
use std::io::Error;
use std::io::ErrorKind;
struct FailingWriter {
written: usize,
fail_after: usize,
}
impl Write for FailingWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
if self.written >= self.fail_after {
return Err(Error::other("write failed"));
}
let to_write = (self.fail_after - self.written).min(buf.len());
self.written += to_write;
Ok(to_write)
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
let test_data = vec![0x42u8; 1000];
let mut input = Cursor::new(test_data);
let mut writer = FailingWriter {
written: 0,
fail_after: 500,
};
let mut buffer = CopyBuffer::new();
let result = copy_with_buffer(&mut input, &mut writer, &mut buffer);
assert!(result.is_err(), "copy should propagate write errors");
match result {
Err(ExtractionError::Io(e)) => {
assert_eq!(e.kind(), ErrorKind::Other);
}
_ => panic!("expected IO error"),
}
}
}