use std::io::{
Error,
ErrorKind,
Result,
Seek,
SeekFrom,
Write,
};
use crate::buffered::DEFAULT_BUFFER_CAPACITY;
use crate::{
Buffer,
Output,
};
/// Buffering adapter that stages items in a [`Buffer`] before handing them to
/// the wrapped [`Output`].
///
/// Items accepted by the write methods are copied into `buffer`; they only
/// reach `inner` when the buffer is flushed, or when a write is large enough
/// to be passed to `inner` directly.
#[derive(Debug)]
pub struct BufferedOutput<O>
where
O: Output,
O::Item: Copy + Default,
{
/// The wrapped output that ultimately receives the items.
inner: O,
/// Staging area for items that have been accepted but not yet forwarded
/// to `inner`.
buffer: Buffer<O::Item>,
}
impl<O> BufferedOutput<O>
where
    O: Output,
    O::Item: Copy + Default,
{
    /// Wraps `inner` with a staging buffer of `DEFAULT_BUFFER_CAPACITY` items.
    #[inline(always)]
    #[must_use]
    pub fn new(inner: O) -> Self {
        Self::with_capacity(inner, DEFAULT_BUFFER_CAPACITY)
    }

    /// Wraps `inner` with a staging buffer built by
    /// `Buffer::with_capacity(capacity)`.
    #[inline(always)]
    #[must_use]
    pub fn with_capacity(inner: O, capacity: usize) -> Self {
        Self {
            inner,
            buffer: Buffer::with_capacity(capacity),
        }
    }

    /// Borrows the wrapped output.
    ///
    /// Items still held in the staging buffer have not been forwarded to it.
    #[inline(always)]
    pub const fn inner(&self) -> &O {
        &self.inner
    }

    /// Mutably borrows the wrapped output.
    ///
    /// Writing through this reference bypasses the staging buffer; items that
    /// are still pending are only forwarded by a later flush.
    #[inline(always)]
    pub fn inner_mut(&mut self) -> &mut O {
        &mut self.inner
    }

    /// Returns the capacity reported by the staging buffer.
    #[inline(always)]
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.buffer.capacity()
    }

    /// Returns the spare capacity reported by the staging buffer, i.e. the
    /// largest `count` accepted by [`Self::advance`].
    #[inline(always)]
    #[must_use]
    pub fn spare_capacity(&self) -> usize {
        self.buffer.spare_capacity()
    }

    /// Returns the tail of the buffer's storage, starting at its current
    /// limit.
    ///
    /// Callers can fill this slice in place and then commit the items with
    /// [`Self::advance`].
    #[inline(always)]
    #[must_use]
    pub fn spare_buffer_mut(&mut self) -> &mut [O::Item] {
        let limit = self.buffer.limit();
        &mut self.buffer.data_mut()[limit..]
    }

    /// Returns the buffer's whole storage together with the index of the
    /// spare region (the buffer's limit) and the spare item count.
    ///
    /// This is the index/count flavour of [`Self::spare_buffer_mut`].
    #[inline(always)]
    #[must_use]
    pub fn spare_raw_parts_mut(&mut self) -> (&mut [O::Item], usize, usize) {
        let index = self.buffer.limit();
        let count = self.buffer.spare_capacity();
        (self.buffer.data_mut(), index, count)
    }

    /// Commits `count` items that were written into the spare region.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds [`Self::spare_capacity`].
    #[inline(always)]
    pub fn advance(&mut self, count: usize) {
        assert!(
            count <= self.spare_capacity(),
            "cannot advance beyond spare output buffer"
        );
        // SAFETY: the assertion above guarantees `count` fits in the spare
        // region of the buffer.
        unsafe {
            self.buffer.advance_unchecked(count);
        }
    }

    /// Commits `count` items without the bounds check done by
    /// [`Self::advance`].
    ///
    /// The bound is still verified in debug builds, matching the other
    /// `*_unchecked` entry points of this type.
    ///
    /// # Safety
    ///
    /// `count` must not exceed [`Self::spare_capacity`].
    #[inline(always)]
    pub unsafe fn advance_unchecked(&mut self, count: usize) {
        debug_assert!(
            count <= self.spare_capacity(),
            "cannot advance beyond spare output buffer"
        );
        // SAFETY: the caller guarantees `count <= self.spare_capacity()`.
        unsafe {
            self.buffer.advance_unchecked(count);
        }
    }

    /// Copies `count` items of `input`, starting at `input_index`, into the
    /// staging buffer via `Buffer::copy_from_unchecked`.
    ///
    /// # Safety
    ///
    /// `input_index + count` must lie within `input`. Every caller in this
    /// file also makes sure `count` fits the buffer's spare capacity first;
    /// presumably `Buffer::copy_from_unchecked` relies on that as well.
    #[inline(always)]
    unsafe fn write_to_buffer_unchecked(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) {
        // SAFETY: forwarded contract; see the `# Safety` section above.
        unsafe {
            self.buffer.copy_from_unchecked(input, input_index, count);
        }
    }
}
impl<O> BufferedOutput<O>
where
    O: Output,
    O::Item: Copy + Default,
{
    /// Splits the adapter into the wrapped output and a copy of the items
    /// that are still pending in the staging buffer.
    ///
    /// Nothing is flushed here; the pending items are handed to the caller.
    #[inline(always)]
    #[must_use]
    pub fn into_parts(self) -> (O, Vec<O::Item>) {
        let unflushed = self.pending_slice().to_vec();
        let Self { inner, .. } = self;
        (inner, unflushed)
    }

    /// Makes sure at least `count` items of spare capacity are available,
    /// flushing the staging buffer when the current spare room is too small.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` when `count` exceeds the buffer's
    /// total capacity, or whatever error [`Self::flush_buffer`] reports.
    pub fn ensure_spare_capacity(&mut self, count: usize) -> Result<()> {
        let capacity = self.buffer.capacity();
        if count > capacity {
            let error = Error::new(
                ErrorKind::InvalidInput,
                "requested spare capacity exceeds buffered output capacity",
            );
            return Err(error);
        }
        if count <= self.spare_capacity() {
            return Ok(());
        }
        self.flush_buffer()
    }

    /// Writes all `count` items of `input` starting at `input_index`.
    ///
    /// Items that fit strictly inside the spare capacity are only copied
    /// into the staging buffer; everything else takes the cold path, which
    /// may flush and may write straight to the inner output.
    ///
    /// # Safety
    ///
    /// `input_index + count` must not overflow and must not exceed
    /// `input.len()`. This is only checked in debug builds.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while flushing or writing to the inner
    /// output.
    #[inline]
    pub unsafe fn write_all_unchecked(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<()> {
        debug_assert!(
            matches!(
                input_index.checked_add(count),
                Some(end) if end <= input.len()
            ),
            "unchecked write range exceeds input buffer"
        );
        if count >= self.spare_capacity() {
            // SAFETY: the caller guarantees the input range is valid.
            return unsafe { self.write_all_cold(input, input_index, count) };
        }
        // SAFETY: the caller guarantees the input range is valid, and
        // `count` is strictly below the spare capacity.
        unsafe {
            self.write_to_buffer_unchecked(input, input_index, count);
        }
        Ok(())
    }

    /// Slow path of [`Self::write_all_unchecked`].
    ///
    /// Flushes when the items do not fit the spare room, then either writes
    /// everything straight to the inner output (when `count` is at least the
    /// buffer capacity) or stages the items in the buffer.
    ///
    /// # Safety
    ///
    /// `input_index + count` must lie within `input`.
    #[cold]
    #[inline(never)]
    unsafe fn write_all_cold(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<()> {
        let fits_spare = count <= self.spare_capacity();
        if !fits_spare {
            self.flush_buffer()?;
        }
        let bypass_buffer = count >= self.buffer.capacity();
        if bypass_buffer {
            // SAFETY: the caller guarantees the input range is valid.
            return unsafe {
                self.write_all_inner_unchecked(input, input_index, count)
            };
        }
        // SAFETY: the caller guarantees the input range is valid. `count`
        // either already fit the spare room or the buffer was just flushed
        // and `count` is below its capacity.
        unsafe {
            self.write_to_buffer_unchecked(input, input_index, count);
        }
        Ok(())
    }

    /// Slow path of [`Self::write_from_unchecked`].
    ///
    /// Same decisions as [`Self::write_all_cold`], except that a direct
    /// write performs a single call on the inner output and may therefore
    /// report fewer than `count` items.
    ///
    /// # Safety
    ///
    /// `input_index + count` must lie within `input`.
    #[cold]
    #[inline(never)]
    unsafe fn write_cold(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<usize> {
        let fits_spare = count <= self.spare_capacity();
        if !fits_spare {
            self.flush_buffer()?;
        }
        let bypass_buffer = count >= self.buffer.capacity();
        if bypass_buffer {
            // SAFETY: the caller guarantees the input range is valid.
            return unsafe {
                self.write_inner_unchecked(input, input_index, count)
            };
        }
        // SAFETY: same reasoning as in `write_all_cold`.
        unsafe {
            self.write_to_buffer_unchecked(input, input_index, count);
        }
        Ok(count)
    }

    /// Forwards every pending item to the inner output.
    ///
    /// Writes interrupted with `ErrorKind::Interrupted` are retried. On
    /// success the buffer is cleared. On failure the buffer is compacted
    /// before the error is returned, so the items that were not accepted
    /// remain pending.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::WriteZero` if the inner output accepts zero items while
    ///   data is pending.
    /// * `ErrorKind::InvalidData` if the inner output reports more items
    ///   than it was offered.
    /// * Any other non-`Interrupted` error from the inner output.
    pub fn flush_buffer(&mut self) -> Result<()> {
        while !self.buffer.is_empty() {
            let start = self.buffer.position();
            let pending = self.buffer.available();
            // SAFETY: `start` and `pending` describe the buffer's own
            // pending region as reported by the buffer itself.
            let outcome = unsafe {
                self.inner.write_unchecked(self.buffer.data(), start, pending)
            };
            // Successful progress and interruptions loop again; every other
            // outcome is turned into the error to report.
            let failure = match outcome {
                Ok(0) => Error::new(
                    ErrorKind::WriteZero,
                    "failed to write buffered data",
                ),
                Ok(accepted) => {
                    match validate_write_count(accepted, pending) {
                        Ok(()) => {
                            // SAFETY: `accepted <= pending` was just
                            // validated.
                            unsafe {
                                self.buffer.consume_unchecked(accepted);
                            }
                            continue;
                        }
                        Err(error) => error,
                    }
                }
                Err(error) if error.kind() == ErrorKind::Interrupted => {
                    continue;
                }
                Err(error) => error,
            };
            self.buffer.compact();
            return Err(failure);
        }
        self.buffer.clear();
        Ok(())
    }

    /// Flushes the staging buffer and then the inner output itself.
    #[inline(always)]
    fn flush_all(&mut self) -> Result<()> {
        self.flush_buffer()?;
        Output::flush(&mut self.inner)
    }

    /// Flushes all pending items and then flushes the inner output.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::flush_buffer`] or from the inner
    /// output's own flush.
    #[inline(always)]
    pub fn flush(&mut self) -> Result<()> {
        self.flush_all()
    }

    /// Writes up to `count` items of `input` starting at `input_index` and
    /// returns how many were accepted.
    ///
    /// Items that fit strictly inside the spare capacity are staged in the
    /// buffer and reported as fully accepted; otherwise the cold path
    /// decides between flushing, buffering and a direct write.
    ///
    /// # Safety
    ///
    /// `input_index + count` must not overflow and must not exceed
    /// `input.len()`. This is only checked in debug builds.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while flushing or writing to the inner
    /// output.
    #[inline]
    pub unsafe fn write_from_unchecked(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<usize> {
        debug_assert!(
            matches!(
                input_index.checked_add(count),
                Some(end) if end <= input.len()
            ),
            "unchecked write range exceeds input buffer"
        );
        if count >= self.spare_capacity() {
            // SAFETY: the caller guarantees the input range is valid.
            return unsafe { self.write_cold(input, input_index, count) };
        }
        // SAFETY: the caller guarantees the input range is valid, and
        // `count` is strictly below the spare capacity.
        unsafe {
            self.write_to_buffer_unchecked(input, input_index, count);
        }
        Ok(count)
    }

    /// Flushes pending items, then seeks the inner output to `position`.
    #[inline(always)]
    fn flush_then_seek(&mut self, position: SeekFrom) -> Result<u64>
    where
        O: Seek,
    {
        self.flush_buffer()?;
        self.inner.seek(position)
    }

    /// Returns the items between the buffer's position and its limit, i.e.
    /// those not yet forwarded to the inner output.
    #[inline(always)]
    fn pending_slice(&self) -> &[O::Item] {
        let start = self.buffer.position();
        let end = self.buffer.limit();
        &self.buffer.data()[start..end]
    }

    /// Performs one direct write on the inner output and validates the
    /// reported count against the request.
    ///
    /// # Safety
    ///
    /// `input_index + count` must lie within `input`.
    #[inline(always)]
    unsafe fn write_inner_unchecked(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<usize> {
        // SAFETY: the caller guarantees the input range is valid.
        let accepted =
            unsafe { self.inner.write_unchecked(input, input_index, count) }?;
        validate_write_count(accepted, count).map(|()| accepted)
    }

    /// Writes all `count` items directly to the inner output, looping over
    /// partial writes and retrying on `ErrorKind::Interrupted`.
    ///
    /// # Safety
    ///
    /// `input_index + count` must lie within `input`.
    unsafe fn write_all_inner_unchecked(
        &mut self,
        input: &[O::Item],
        input_index: usize,
        count: usize,
    ) -> Result<()> {
        let mut done = 0;
        while done < count {
            // SAFETY: `done < count`, so the remaining range stays inside
            // the range the caller vouched for.
            let attempt = unsafe {
                self.write_inner_unchecked(
                    input,
                    input_index + done,
                    count - done,
                )
            };
            match attempt {
                Ok(0) => {
                    return Err(Error::new(
                        ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    ));
                }
                Ok(accepted) => done += accepted,
                Err(error) if error.kind() == ErrorKind::Interrupted => {}
                Err(error) => return Err(error),
            }
        }
        Ok(())
    }
}
impl<O> Write for BufferedOutput<O>
where
    O: Output<Item = u8>,
{
    /// Buffers or forwards the bytes of `buffer`, returning how many were
    /// accepted.
    #[inline(always)]
    fn write(&mut self, buffer: &[u8]) -> Result<usize> {
        let count = buffer.len();
        // SAFETY: `0..count` is exactly the full range of `buffer`.
        unsafe { self.write_from_unchecked(buffer, 0, count) }
    }

    /// Buffers or forwards every byte of `buffer`, replacing the trait's
    /// default `write_all` with the adapter's own implementation.
    #[inline(always)]
    fn write_all(&mut self, buffer: &[u8]) -> Result<()> {
        let count = buffer.len();
        // SAFETY: `0..count` is exactly the full range of `buffer`.
        unsafe { self.write_all_unchecked(buffer, 0, count) }
    }

    /// Flushes the staging buffer and then the inner output.
    #[inline(always)]
    fn flush(&mut self) -> Result<()> {
        Self::flush_all(self)
    }
}
impl<O> Seek for BufferedOutput<O>
where
    O: Output<Item = u8> + Seek,
{
    /// Flushes all pending bytes first, then seeks the inner output, so
    /// buffered data is written before the position changes.
    #[inline(always)]
    fn seek(&mut self, position: SeekFrom) -> Result<u64> {
        Self::flush_then_seek(self, position)
    }
}
/// Checks that a writer did not claim to have written more than it was
/// offered.
///
/// Returns `Ok(())` when `written <= requested`; otherwise returns an
/// `ErrorKind::InvalidData` error describing both counts.
#[inline(always)]
fn validate_write_count(written: usize, requested: usize) -> Result<()> {
    if written <= requested {
        return Ok(());
    }
    let message = format!(
        "writer reported {written} bytes for a {requested}-byte buffer"
    );
    Err(Error::new(ErrorKind::InvalidData, message))
}