use std::io::{
Error,
ErrorKind,
Result,
Seek,
SeekFrom,
Write,
};
use crate::buffered::DEFAULT_BUFFER_CAPACITY;
use crate::{
Buffer,
Output,
Seekable,
SeekableOutput,
};
/// Buffering adapter that stages items in memory before handing them to a
/// wrapped `Output`.
///
/// NOTE(review): `into_parts` moves fields out of `self`, which Rust only
/// permits for types without a `Drop` impl. Items still pending when a
/// `BufferedOutput` is dropped are therefore presumably discarded rather than
/// flushed, so call `flush` (or `flush_buffer`) before dropping.
#[derive(Debug)]
pub struct BufferedOutput<O: Output>
where
    O::Item: Copy + Default,
{
    /// Destination that receives items when the buffer is flushed or bypassed.
    inner: O,
    /// Items accepted by this adapter but not yet written to `inner`.
    buffer: Buffer<O::Item>,
}
impl<O> BufferedOutput<O>
where
    O: Output,
    O::Item: Copy + Default,
{
    /// Wraps `inner` in a buffer of `DEFAULT_BUFFER_CAPACITY` items.
    #[inline(always)]
    #[must_use]
    pub fn new(inner: O) -> Self {
        Self::with_capacity(inner, DEFAULT_BUFFER_CAPACITY)
    }

    /// Wraps `inner` in a buffer created with `Buffer::with_capacity(capacity)`.
    ///
    /// Writes at least as large as the buffer's capacity are not copied into
    /// the buffer. They are sent directly to `inner` after any pending items
    /// have been flushed.
    #[inline(always)]
    #[must_use]
    pub fn with_capacity(inner: O, capacity: usize) -> Self {
        Self {
            inner,
            buffer: Buffer::with_capacity(capacity),
        }
    }

    /// Returns a shared reference to the wrapped output.
    #[inline(always)]
    pub const fn inner(&self) -> &O {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped output.
    ///
    /// Items still pending in the buffer are not flushed first. Writing
    /// through this reference may therefore reorder output relative to them.
    #[inline(always)]
    pub fn inner_mut(&mut self) -> &mut O {
        &mut self.inner
    }

    /// Consumes the adapter without flushing.
    ///
    /// Returns the wrapped output and a copy of the items that were still
    /// pending in the buffer.
    #[inline(always)]
    #[must_use]
    pub fn into_parts(self) -> (O, Vec<O::Item>) {
        let pending = self.buffer.available_slice().to_vec();
        (self.inner, pending)
    }

    /// Returns the total capacity of the internal buffer, in items.
    #[inline(always)]
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.buffer.capacity()
    }

    /// Returns how many more items the buffer can accept before it must be
    /// flushed.
    #[inline(always)]
    #[must_use]
    pub fn spare_capacity(&self) -> usize {
        self.buffer.spare_capacity()
    }

    /// Returns the unused tail of the buffer for in-place writes.
    ///
    /// After filling a prefix of the slice, commit it with [`Self::advance`].
    #[inline(always)]
    #[must_use]
    pub fn spare_slice_mut(&mut self) -> &mut [O::Item] {
        self.buffer.spare_slice_mut()
    }

    /// Returns the buffer's spare region in raw form, as produced by
    /// `Buffer::spare_raw_parts_mut`.
    ///
    /// NOTE(review): the tuple is presumably `(backing slice, start index,
    /// spare length)`. Confirm against `Buffer` before relying on it.
    #[inline(always)]
    #[must_use]
    pub fn spare_raw_parts_mut(&mut self) -> (&mut [O::Item], usize, usize) {
        self.buffer.spare_raw_parts_mut()
    }

    /// Commits `count` items previously written into the spare region.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds [`Self::spare_capacity`].
    #[inline(always)]
    pub fn advance(&mut self, count: usize) {
        assert!(
            count <= self.spare_capacity(),
            "cannot advance beyond spare output buffer"
        );
        // SAFETY: the assertion above guarantees `count <= spare_capacity()`.
        unsafe {
            self.buffer.advance_unchecked(count);
        }
    }

    /// Commits `count` items previously written into the spare region,
    /// without checking bounds.
    ///
    /// # Safety
    ///
    /// The caller must guarantee `count <= self.spare_capacity()`, which is
    /// the condition [`Self::advance`] asserts. Any further requirements of
    /// `Buffer::advance_unchecked` also apply.
    #[inline(always)]
    pub unsafe fn advance_unchecked(&mut self, count: usize) {
        // SAFETY: forwarded verbatim from this function's own contract.
        unsafe {
            self.buffer.advance_unchecked(count);
        }
    }

    /// Makes sure at least `count` items of spare capacity are available.
    ///
    /// The buffer is flushed if it does not currently have that much room.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::InvalidInput`] if `count` exceeds the buffer's
    /// total capacity. Otherwise propagates any error from
    /// [`Self::flush_buffer`].
    pub fn ensure_spare_capacity(&mut self, count: usize) -> Result<()> {
        if count > self.buffer.capacity() {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "requested spare capacity exceeds buffered output capacity",
            ));
        }
        if self.spare_capacity() < count {
            self.flush_buffer()?;
        }
        Ok(())
    }

    /// Writes up to `count` items from `input[input_index..input_index + count]`
    /// and returns how many were accepted.
    ///
    /// Writes that fit strictly within the current spare capacity are copied
    /// into the buffer in full. Otherwise the buffer is flushed first if
    /// needed. Writes at least as large as the buffer's capacity then go
    /// straight to the wrapped output, which may accept fewer than `count`
    /// items.
    ///
    /// # Safety
    ///
    /// `input_index + count` must not overflow and must be at most
    /// `input.len()`. This is only checked with `debug_assert!`.
    ///
    /// # Errors
    ///
    /// Propagates errors from flushing the buffer or from the wrapped output.
    /// Returns [`ErrorKind::InvalidData`] if the wrapped output reports
    /// writing more items than it was given.
    #[inline]
    pub unsafe fn write_unchecked(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<usize> {
        debug_assert!(
            input_index
                .checked_add(count)
                .is_some_and(|end| end <= input.len()),
            "unchecked write range exceeds input buffer"
        );
        if count < self.spare_capacity() {
            // SAFETY: the caller guarantees the input range is in bounds, and
            // `count` is strictly below the buffer's spare capacity.
            unsafe {
                self.write_to_buffer_unchecked(input, input_index, count);
            }
            Ok(count)
        } else {
            // SAFETY: the caller guarantees the input range is in bounds.
            unsafe { self.write_cold(input, input_index, count) }
        }
    }

    /// Writes all `count` items from `input[input_index..input_index + count]`.
    ///
    /// This either buffers the items or sends them to the wrapped output,
    /// looping until everything has been accepted.
    ///
    /// # Safety
    ///
    /// `input_index + count` must not overflow and must be at most
    /// `input.len()`. This is only checked with `debug_assert!`.
    ///
    /// # Errors
    ///
    /// Propagates errors from flushing the buffer or from the wrapped output,
    /// except `ErrorKind::Interrupted`, which is retried. Returns
    /// [`ErrorKind::WriteZero`] if the wrapped output stops accepting items
    /// and [`ErrorKind::InvalidData`] if it over-reports a write.
    #[inline]
    pub unsafe fn write_all_unchecked(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<()> {
        debug_assert!(
            input_index
                .checked_add(count)
                .is_some_and(|end| end <= input.len()),
            "unchecked write range exceeds input buffer"
        );
        if count < self.spare_capacity() {
            // SAFETY: the caller guarantees the input range is in bounds, and
            // `count` is strictly below the buffer's spare capacity.
            unsafe {
                self.write_to_buffer_unchecked(input, input_index, count);
            }
            Ok(())
        } else {
            // SAFETY: the caller guarantees the input range is in bounds.
            unsafe { self.write_all_cold(input, input_index, count) }
        }
    }

    /// Hands every pending item to the wrapped output.
    ///
    /// The wrapped output itself is not flushed. `ErrorKind::Interrupted`
    /// errors are retried. On any other failure the buffer is compacted (via
    /// `Buffer::compact`) so the items not yet written stay pending for a
    /// later attempt. On success the buffer is cleared.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::WriteZero`] if the wrapped output accepts zero
    /// items, [`ErrorKind::InvalidData`] if it reports writing more items than
    /// offered, or any non-`Interrupted` error it returns.
    pub fn flush_buffer(&mut self) -> Result<()> {
        while !self.buffer.is_empty() {
            let position = self.buffer.position();
            let available = self.buffer.available();
            // SAFETY: relies on the (assumed) `Buffer` invariant that
            // `position() + available() <= data().len()`, i.e. the pending
            // region lies inside the backing storage.
            match unsafe {
                self.inner.write_unchecked(
                    self.buffer.data(),
                    position,
                    available,
                )
            } {
                Ok(0) => {
                    self.buffer.compact();
                    return Err(Error::new(
                        ErrorKind::WriteZero,
                        "failed to write buffered data",
                    ));
                }
                Ok(written) => {
                    if let Err(error) = validate_write_count(written, available)
                    {
                        self.buffer.compact();
                        return Err(error);
                    }
                    // SAFETY: `written <= available` was just validated.
                    unsafe {
                        self.buffer.consume_unchecked(written);
                    }
                }
                Err(error) if error.kind() == ErrorKind::Interrupted => {}
                Err(error) => {
                    self.buffer.compact();
                    return Err(error);
                }
            }
        }
        self.buffer.clear();
        Ok(())
    }

    /// Flushes pending items with [`Self::flush_buffer`], then flushes the
    /// wrapped output.
    ///
    /// # Errors
    ///
    /// Propagates the first error from either step. The wrapped output is not
    /// flushed if draining the buffer fails.
    #[inline(always)]
    pub fn flush(&mut self) -> Result<()> {
        self.flush_buffer()
            .and_then(|()| Output::flush(&mut self.inner))
    }

    /// Flushes pending items, then seeks the wrapped output.
    ///
    /// Returns the position reported by the wrapped output's seek.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::flush_buffer`] or from the wrapped
    /// output's seek. No seek happens if flushing fails.
    #[inline(always)]
    pub fn seek(&mut self, position: SeekFrom) -> Result<u64>
    where
        O: SeekableOutput,
    {
        self.flush_buffer()
            .and_then(|()| Seekable::seek(&mut self.inner, position))
    }

    /// Copies `input[input_index..input_index + count]` into the buffer's
    /// spare region.
    ///
    /// # Safety
    ///
    /// The input range must be in bounds and `count` must not exceed the
    /// spare capacity. NOTE(review): presumably these are exactly the
    /// requirements of `Buffer::copy_from_unchecked`; confirm.
    #[inline(always)]
    unsafe fn write_to_buffer_unchecked(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) {
        // SAFETY: forwarded from this function's own contract.
        unsafe {
            self.buffer.copy_from_unchecked(input, input_index, count);
        }
    }

    /// Performs a single write of the input range to the wrapped output.
    ///
    /// Rejects a reported count larger than `count`.
    ///
    /// # Safety
    ///
    /// `input[input_index..input_index + count]` must be in bounds.
    #[inline(always)]
    unsafe fn write_inner_unchecked(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<usize> {
        // SAFETY: forwarded from this function's own contract.
        let written =
            unsafe { self.inner.write_unchecked(input, input_index, count) }?;
        validate_write_count(written, count)?;
        Ok(written)
    }

    /// Writes the whole input range to the wrapped output.
    ///
    /// Loops over partial writes and retries `ErrorKind::Interrupted`.
    ///
    /// # Safety
    ///
    /// `input[input_index..input_index + count]` must be in bounds.
    unsafe fn write_all_inner_unchecked(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<()> {
        let mut written = 0;
        while written < count {
            let remaining = count - written;
            // SAFETY: `written < count`, so
            // `input_index + written .. input_index + count` is a suffix of
            // the caller-guaranteed in-bounds range.
            match unsafe {
                self.write_inner_unchecked(
                    input,
                    input_index + written,
                    remaining,
                )
            } {
                Ok(0) => {
                    return Err(Error::new(
                        ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    ));
                }
                // Named `accepted` (not `count`) so it cannot be confused
                // with the total requested `count` driving the loop.
                Ok(accepted) => written += accepted,
                Err(error) if error.kind() == ErrorKind::Interrupted => {}
                Err(error) => return Err(error),
            }
        }
        Ok(())
    }

    /// Slow path of `write_all_unchecked`.
    ///
    /// Taken when `count` does not fit strictly within the current spare
    /// capacity.
    ///
    /// # Safety
    ///
    /// `input[input_index..input_index + count]` must be in bounds.
    #[cold]
    #[inline(never)]
    unsafe fn write_all_cold(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<()> {
        if count > self.spare_capacity() {
            self.flush_buffer()?;
        }
        if count >= self.buffer.capacity() {
            // SAFETY: forwarded from this function's own contract.
            unsafe { self.write_all_inner_unchecked(input, input_index, count) }
        } else {
            // SAFETY: the range is in bounds by contract. `count` is below
            // capacity, and either it already fit in the spare capacity or
            // the buffer was just flushed (and cleared) above.
            unsafe {
                self.write_to_buffer_unchecked(input, input_index, count);
            }
            Ok(())
        }
    }

    /// Slow path of `write_unchecked`.
    ///
    /// Taken when `count` does not fit strictly within the current spare
    /// capacity.
    ///
    /// # Safety
    ///
    /// `input[input_index..input_index + count]` must be in bounds.
    #[cold]
    #[inline(never)]
    unsafe fn write_cold(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<usize> {
        if count > self.spare_capacity() {
            self.flush_buffer()?;
        }
        if count >= self.buffer.capacity() {
            // SAFETY: forwarded from this function's own contract.
            unsafe { self.write_inner_unchecked(input, input_index, count) }
        } else {
            // SAFETY: the range is in bounds by contract. `count` is below
            // capacity, and either it already fit in the spare capacity or
            // the buffer was just flushed (and cleared) above.
            unsafe {
                self.write_to_buffer_unchecked(input, input_index, count);
            }
            Ok(count)
        }
    }
}
/// Byte-oriented `std::io::Write` support for outputs whose items are `u8`.
impl<O> Write for BufferedOutput<O>
where
    O: Output<Item = u8>,
{
    /// Buffers `buf` or writes it straight through.
    ///
    /// Delegates to the inherent `write_unchecked`.
    #[inline(always)]
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let len = buf.len();
        // SAFETY: `0..len` covers exactly the extent of `buf`.
        unsafe { self.write_unchecked(buf, 0, len) }
    }

    /// Writes every byte of `buf`.
    ///
    /// Delegates to the inherent `write_all_unchecked`.
    #[inline(always)]
    fn write_all(&mut self, buf: &[u8]) -> Result<()> {
        let len = buf.len();
        // SAFETY: `0..len` covers exactly the extent of `buf`.
        unsafe { self.write_all_unchecked(buf, 0, len) }
    }

    /// Drains the buffer and then flushes the wrapped output.
    ///
    /// Delegates to the inherent `flush`.
    #[inline(always)]
    fn flush(&mut self) -> Result<()> {
        BufferedOutput::<O>::flush(self)
    }
}
/// `std::io::Seek` support for byte outputs that can also seek.
impl<O> Seek for BufferedOutput<O>
where
    O: Output<Item = u8> + Seekable<Item = u8>,
{
    /// Flushes pending bytes, then repositions the wrapped output.
    ///
    /// Delegates to the inherent `seek`.
    #[inline(always)]
    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
        BufferedOutput::<O>::seek(self, pos)
    }
}
/// Checks that a writer did not claim to have written more items than it
/// was asked to write.
///
/// The message always says "bytes", even when the item type is not `u8`.
///
/// # Errors
///
/// Returns an [`ErrorKind::InvalidData`] error when `written` exceeds
/// `requested`.
#[inline(always)]
fn validate_write_count(written: usize, requested: usize) -> Result<()> {
    if written <= requested {
        Ok(())
    } else {
        let message =
            format!("writer reported {written} bytes for a {requested}-byte buffer");
        Err(Error::new(ErrorKind::InvalidData, message))
    }
}