use std::cmp::Ordering;
use std::io::{
Error,
ErrorKind,
Read,
Result,
Write,
copy,
};
use std::string::FromUtf8Error;
use super::allocation::try_reserve_vec;
use crate::{
Leb128DecodeError,
ReadExt,
};
/// Size in bytes of the stack buffer used by `copy_at_most_impl`; also the
/// largest single `read` request it issues.
const COPY_BUFFER_SIZE: usize = 16 * 1024;
/// Size in bytes of each of the two stack buffers used by
/// `Streams::compare_content` (one per input stream).
const COMPARE_BUFFER_SIZE: usize = 16 * 1024;
/// Namespace for stream helpers built on [`Read`] and [`Write`].
///
/// `Streams` has no variants and therefore can never be instantiated; every
/// helper is exposed as an associated function such as
/// [`Streams::copy_at_most`].
pub enum Streams {}

impl Streams {
    /// Copies everything remaining in `reader` into `writer`.
    ///
    /// Thin wrapper over [`std::io::copy`]; returns the number of bytes copied.
    #[inline]
    pub fn copy<R, W>(reader: &mut R, writer: &mut W) -> Result<u64>
    where
        R: Read + ?Sized,
        W: Write + ?Sized,
    {
        copy(reader, writer)
    }

    /// Copies at most `max_bytes` bytes from `reader` into `writer`.
    ///
    /// Copying stops early when `reader` reports EOF. Reads failing with
    /// `ErrorKind::Interrupted` are retried. Returns the number of bytes copied.
    ///
    /// # Errors
    ///
    /// Any other read error, and any error from `writer`, is returned unchanged.
    #[inline]
    pub fn copy_at_most<R, W>(reader: &mut R, writer: &mut W, max_bytes: u64) -> Result<u64>
    where
        R: Read + ?Sized,
        W: Write + ?Sized,
    {
        // `R`/`W` may be unsized, so `&mut R` cannot coerce to `&mut dyn Read`
        // directly. `&mut R` itself is sized and implements `Read`/`Write`, so
        // one extra level of reference is what gets coerced to the trait object.
        copy_at_most_impl(&mut &mut *reader, &mut &mut *writer, max_bytes)
    }

    /// Copies all remaining bytes from `reader` into `writer`, provided there
    /// are no more than `max_bytes` of them.
    ///
    /// Returns the number of bytes copied.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidData` when `reader` still has data after
    /// `max_bytes` bytes were copied. In that case the first `max_bytes` bytes
    /// have already been written to `writer`, and one further byte has been
    /// consumed from `reader` by the overflow probe.
    #[inline]
    pub fn copy_to_end_limited<R, W>(reader: &mut R, writer: &mut W, max_bytes: u64) -> Result<u64>
    where
        R: Read + ?Sized,
        W: Write + ?Sized,
    {
        // Same unsized-to-`dyn` trick as in `copy_at_most`.
        copy_to_end_limited_impl(&mut &mut *reader, &mut &mut *writer, max_bytes)
    }

    /// Returns `true` when both readers yield exactly the same bytes.
    #[inline]
    pub fn content_eq(left: &mut dyn Read, right: &mut dyn Read) -> Result<bool> {
        Self::compare_content(left, right).map(Ordering::is_eq)
    }

    /// Compares the remaining contents of two readers lexicographically.
    ///
    /// Returns as soon as a differing byte is found. A stream that is a strict
    /// prefix of the other orders as [`Ordering::Less`]. Data is pulled in
    /// chunks of `COMPARE_BUFFER_SIZE` bytes into two stack buffers.
    ///
    /// NOTE(review): correctness relies on `read_exact_or_eof` filling the whole
    /// buffer unless EOF is reached. A short count in the middle of a stream
    /// would be treated as end of input. Confirm against `ReadExt`.
    pub fn compare_content(left: &mut dyn Read, right: &mut dyn Read) -> Result<Ordering> {
        let mut lhs = [0; COMPARE_BUFFER_SIZE];
        let mut rhs = [0; COMPARE_BUFFER_SIZE];
        loop {
            let lhs_len = left.read_exact_or_eof(&mut lhs)?;
            let rhs_len = right.read_exact_or_eof(&mut rhs)?;
            let common = lhs_len.min(rhs_len);
            // The shared prefix decides first. On a tie, the shorter chunk
            // (the stream that ended) sorts first.
            let order = lhs[..common]
                .cmp(&rhs[..common])
                .then(lhs_len.cmp(&rhs_len));
            // Any difference is final. Equal empty chunks mean both streams are
            // exhausted. Equal non-empty chunks mean we keep going.
            if order != Ordering::Equal || lhs_len == 0 {
                return Ok(order);
            }
        }
    }
}
/// Copies up to `max_bytes` bytes from `reader` to `writer` through a
/// `COPY_BUFFER_SIZE`-byte stack buffer.
///
/// Stops at the limit or when `reader` returns `Ok(0)`. Reads failing with
/// `ErrorKind::Interrupted` are retried. All other errors propagate.
/// Returns the number of bytes written.
fn copy_at_most_impl(reader: &mut dyn Read, writer: &mut dyn Write, max_bytes: u64) -> Result<u64> {
    let mut chunk = [0; COPY_BUFFER_SIZE];
    let mut total: u64 = 0;
    while total < max_bytes {
        // Never ask for more than the limit still allows, so the reader is not
        // advanced past `max_bytes`.
        let want = (max_bytes - total).min(COPY_BUFFER_SIZE as u64) as usize;
        let got = match reader.read(&mut chunk[..want]) {
            Ok(0) => break,
            Ok(got) => got,
            Err(error) if error.kind() == ErrorKind::Interrupted => continue,
            Err(error) => return Err(error),
        };
        writer.write_all(&chunk[..got])?;
        total += got as u64;
    }
    Ok(total)
}
/// Copies all of `reader` into `writer`, failing if it holds more than
/// `max_bytes` bytes.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidData` if, after `max_bytes` bytes were copied,
/// a one-byte probe still finds data. That probe byte is consumed and
/// discarded. I/O errors from the copy or the probe propagate unchanged.
fn copy_to_end_limited_impl(reader: &mut dyn Read, writer: &mut dyn Write, max_bytes: u64) -> Result<u64> {
    let copied = copy_at_most_impl(reader, writer, max_bytes)?;
    // A copy shorter than the limit means EOF was already reached, so the probe
    // is only needed when the limit was hit exactly.
    let hit_limit = copied >= max_bytes;
    if hit_limit && has_more_input(reader)? {
        return Err(Error::new(
            ErrorKind::InvalidData,
            format!("input exceeds maximum length of {max_bytes} bytes"),
        ));
    }
    Ok(copied)
}
/// Probes `reader` by reading a single byte.
///
/// Returns `Ok(false)` on EOF and `Ok(true)` if a byte was read; that byte is
/// consumed and discarded. `ErrorKind::Interrupted` is retried, and any other
/// error is returned.
fn has_more_input(reader: &mut dyn Read) -> Result<bool> {
    let mut probe = [0; 1];
    loop {
        match reader.read(&mut probe) {
            Ok(read) => return Ok(read != 0),
            Err(error) if error.kind() == ErrorKind::Interrupted => {}
            Err(error) => return Err(error),
        }
    }
}
/// Reads one LEB128-encoded value from `reader` and decodes it with `decode`.
///
/// Bytes are read one at a time into an `N`-byte buffer. Reading stops at the
/// first byte whose continuation bit (`0x80`) is clear, or after `N` bytes. As a
/// result, `reader` is never advanced past the terminating byte. The whole
/// buffer, with any unread tail left as zeros, is then passed to `decode`. Only
/// the decoded value is returned; the length reported by `decode` is discarded.
/// If all `N` bytes carry the continuation bit, `decode` still receives the
/// full buffer (presumably it reports that as an error — confirm).
///
/// # Errors
///
/// Propagates I/O errors from `reader` (for example `UnexpectedEof` from
/// `read_exact`). A `Leb128DecodeError` is wrapped as `ErrorKind::InvalidData`.
#[inline]
pub(crate) fn read_leb128_payload<const N: usize, T, R, F>(reader: &mut R, decode: F) -> Result<T>
where
    R: Read + ?Sized,
    F: FnOnce(&[u8]) -> std::result::Result<(T, usize), Leb128DecodeError>,
{
    let mut buffer = [0u8; N];
    for position in 0..N {
        reader.read_exact(one_byte_slice(&mut buffer, position))?;
        let has_continuation = buffer[position] & 0x80 != 0;
        if !has_continuation {
            break;
        }
    }
    let (value, _consumed) =
        decode(&buffer).map_err(|error| Error::new(ErrorKind::InvalidData, error))?;
    Ok(value)
}
/// Returns the one-byte mutable subslice `bytes[index..=index]`.
///
/// This is a safe function, so it must be sound for every argument. The slice
/// is therefore produced with a bounds-checked range index instead of raw
/// pointer arithmetic.
///
/// # Panics
///
/// Panics if `index >= bytes.len()`.
#[inline]
fn one_byte_slice(bytes: &mut [u8], index: usize) -> &mut [u8] {
    &mut bytes[index..=index]
}
/// Reads exactly `len` bytes from `reader` and returns them as a `String`.
///
/// `len` is checked against `max_len` before any allocation or read happens.
/// Capacity is reserved through `try_reserve_vec`, whose error is propagated
/// with `?`.
///
/// # Errors
///
/// - `ErrorKind::InvalidData` if `len > max_len`.
/// - Any error returned by `try_reserve_vec`.
/// - Any error from `read_exact` (for example `UnexpectedEof`).
/// - `ErrorKind::InvalidData` if the bytes are not valid UTF-8.
pub(crate) fn read_utf8_payload<R>(reader: &mut R, len: usize, max_len: usize) -> Result<String>
where
    R: Read + ?Sized,
{
    if max_len < len {
        return Err(length_exceeded_error(len, max_len));
    }
    let mut payload: Vec<u8> = Vec::new();
    try_reserve_vec(&mut payload, len)?;
    payload.resize(len, 0);
    reader.read_exact(payload.as_mut_slice())?;
    String::from_utf8(payload).map_err(invalid_utf8_error)
}
/// Writes the UTF-8 bytes of `value` to `writer`, with no length prefix.
pub(crate) fn write_utf8_payload<W>(writer: &mut W, value: &str) -> Result<()>
where
    W: Write + ?Sized,
{
    let payload: &[u8] = value.as_bytes();
    writer.write_all(payload)
}
/// Writes `value` as a `u16` byte-length prefix (emitted by `write_len`)
/// followed by its UTF-8 bytes.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` before anything is written if the byte
/// length does not fit in a `u16`. Errors from `write_len` or `writer`
/// propagate.
pub(crate) fn write_utf8_string_with_u16_len<W, F>(writer: &mut W, value: &str, write_len: F) -> Result<()>
where
    W: Write + ?Sized,
    F: FnOnce(&mut W, u16) -> Result<()>,
{
    let prefix = checked_u16_len(value.len())?;
    write_len(writer, prefix)?;
    writer.write_all(value.as_bytes())
}
/// Writes `value` as a `u32` byte-length prefix (emitted by `write_len`)
/// followed by its UTF-8 bytes.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` before anything is written if the byte
/// length does not fit in a `u32`. Errors from `write_len` or `writer`
/// propagate.
pub(crate) fn write_utf8_string_with_u32_len<W, F>(writer: &mut W, value: &str, write_len: F) -> Result<()>
where
    W: Write + ?Sized,
    F: FnOnce(&mut W, u32) -> Result<()>,
{
    let prefix = checked_u32_len(value.len())?;
    write_len(writer, prefix)?;
    writer.write_all(value.as_bytes())
}
/// Converts a byte length to `u16`.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` when `len > u16::MAX`.
pub(crate) fn checked_u16_len(len: usize) -> Result<u16> {
    match u16::try_from(len) {
        Ok(short) => Ok(short),
        Err(_) => Err(Error::new(
            ErrorKind::InvalidInput,
            format!("string length {len} exceeds maximum encodable u16 length"),
        )),
    }
}
/// Converts a byte length to `u32`.
///
/// Uses `u32::try_from` rather than a manual range check plus `as` cast, so no
/// truncating conversion appears in the code. This mirrors `checked_u16_len`.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` when `len > u32::MAX`.
pub(crate) fn checked_u32_len(len: usize) -> Result<u32> {
    u32::try_from(len).map_err(|_| {
        Error::new(
            ErrorKind::InvalidInput,
            format!("string length {len} exceeds maximum encodable u32 length"),
        )
    })
}
/// Builds the `InvalidData` error reported when a declared string length
/// exceeds the caller's maximum.
fn length_exceeded_error(len: usize, max_len: usize) -> Error {
    let message = format!("string length {len} exceeds maximum length of {max_len} bytes");
    Error::new(ErrorKind::InvalidData, message)
}
/// Wraps a UTF-8 decoding failure as an `InvalidData` I/O error. The
/// message embeds the original error's `Display` text.
fn invalid_utf8_error(error: FromUtf8Error) -> Error {
    let message = format!("length-prefixed string is not valid UTF-8: {error}");
    Error::new(ErrorKind::InvalidData, message)
}