Skip to main content

qubit_io/util/
streams.rs

1/*******************************************************************************
2 *
3 *    Copyright (c) 2026 Haixing Hu.
4 *
5 *    SPDX-License-Identifier: Apache-2.0
6 *
7 *    Licensed under the Apache License, Version 2.0.
8 *
9 ******************************************************************************/
10use std::cmp::Ordering;
11use std::io::{
12    Error,
13    ErrorKind,
14    Read,
15    Result,
16    Write,
17    copy,
18};
19use std::string::FromUtf8Error;
20
21use super::allocation::try_reserve_vec;
22use crate::{
23    Leb128DecodeError,
24    ReadExt,
25};
26
/// Length in bytes (16 KiB) of the scratch buffer used by stream copy
/// operations.
const COPY_BUFFER_SIZE: usize = 16 << 10;

/// Length in bytes (16 KiB) of each scratch buffer used by stream comparison
/// operations.
const COMPARE_BUFFER_SIZE: usize = 16 << 10;
32
33/// Stream utility namespace.
34///
35/// This type is an uninstantiable namespace for operations involving one or
36/// more [`Read`] or [`Write`] values. The methods do not close or flush the
37/// supplied streams unless the underlying standard-library operation documents
38/// otherwise.
39///
40/// # Examples
41/// ```
42/// use qubit_io::Streams;
43/// use std::io::Cursor;
44///
45/// let mut input = Cursor::new(b"abcdef".to_vec());
46/// let mut output = Vec::new();
47///
48/// let copied = Streams::copy_at_most(&mut input, &mut output, 4)?;
49///
50/// assert_eq!(4, copied);
51/// assert_eq!(b"abcd", output.as_slice());
52/// # Ok::<(), std::io::Error>(())
53/// ```
54pub enum Streams {}
55
56impl Streams {
57    /// Copies all remaining bytes from `reader` to `writer`.
58    ///
59    /// This is a namespace-style wrapper around [`std::io::copy`]. It preserves
60    /// the standard-library behavior, including platform-specific optimized
61    /// copy paths when available.
62    ///
63    /// # Parameters
64    /// - `reader`: Source reader.
65    /// - `writer`: Destination writer.
66    ///
67    /// # Returns
68    /// The number of bytes copied.
69    ///
70    /// # Errors
71    /// Returns the first read or write error reported by the underlying
72    /// streams, using the same error behavior as [`std::io::copy`].
73    #[inline]
74    pub fn copy<R, W>(reader: &mut R, writer: &mut W) -> Result<u64>
75    where
76        R: Read + ?Sized,
77        W: Write + ?Sized,
78    {
79        copy(reader, writer)
80    }
81
82    /// Copies at most `max_bytes` bytes from `reader` to `writer`.
83    ///
84    /// This method stops successfully when either EOF is reached or
85    /// `max_bytes` bytes have been copied. It does not close or flush either
86    /// stream.
87    ///
88    /// # Parameters
89    /// - `reader`: Source reader.
90    /// - `writer`: Destination writer.
91    /// - `max_bytes`: Maximum number of bytes to copy.
92    ///
93    /// # Returns
94    /// The number of bytes copied.
95    ///
96    /// # Errors
97    /// Returns the first non-interrupted read error or write error reported by
98    /// the underlying streams. Interrupted reads are retried.
99    #[inline]
100    pub fn copy_at_most<R, W>(reader: &mut R, writer: &mut W, max_bytes: u64) -> Result<u64>
101    where
102        R: Read + ?Sized,
103        W: Write + ?Sized,
104    {
105        let mut reader = reader;
106        let mut writer = writer;
107        copy_at_most_impl(&mut reader, &mut writer, max_bytes)
108    }
109
110    /// Copies the remaining input if its total length is at most `max_bytes`.
111    ///
112    /// This method copies from the current reader position until EOF. If EOF is
113    /// not reached within `max_bytes` bytes, it returns
114    /// [`std::io::ErrorKind::InvalidData`]. Detecting oversized input consumes
115    /// one excess byte from `reader`; that excess byte is not written to
116    /// `writer`.
117    ///
118    /// # Parameters
119    /// - `reader`: Source reader.
120    /// - `writer`: Destination writer.
121    /// - `max_bytes`: Maximum accepted number of bytes in the remaining input.
122    ///
123    /// # Returns
124    /// The number of bytes copied when EOF is reached within the limit.
125    ///
126    /// # Errors
127    /// Returns [`std::io::ErrorKind::InvalidData`] when the remaining input is
128    /// longer than `max_bytes`. Returns the first non-interrupted read error or
129    /// write error reported by the underlying streams. Interrupted reads are
130    /// retried.
131    #[inline]
132    pub fn copy_to_end_limited<R, W>(reader: &mut R, writer: &mut W, max_bytes: u64) -> Result<u64>
133    where
134        R: Read + ?Sized,
135        W: Write + ?Sized,
136    {
137        let mut reader = reader;
138        let mut writer = writer;
139        copy_to_end_limited_impl(&mut reader, &mut writer, max_bytes)
140    }
141
142    /// Tests whether two readable streams have equal remaining contents.
143    ///
144    /// The comparison starts at each reader's current position and consumes
145    /// both streams until a difference or EOF is found.
146    ///
147    /// # Parameters
148    /// - `left`: First stream.
149    /// - `right`: Second stream.
150    ///
151    /// # Returns
152    /// `true` when both streams produce the same bytes until EOF.
153    ///
154    /// # Errors
155    /// Returns the first read error reported by either stream.
156    #[inline]
157    pub fn content_eq(left: &mut dyn Read, right: &mut dyn Read) -> Result<bool> {
158        Ok(Self::compare_content(left, right)? == Ordering::Equal)
159    }
160
161    /// Lexicographically compares the remaining contents of two readable
162    /// streams.
163    ///
164    /// The comparison starts at each reader's current position and consumes
165    /// both streams until a difference or EOF is found.
166    ///
167    /// # Parameters
168    /// - `left`: First stream.
169    /// - `right`: Second stream.
170    ///
171    /// # Returns
172    /// The lexicographic ordering of the remaining bytes.
173    ///
174    /// # Errors
175    /// Returns the first read error reported by either stream.
176    pub fn compare_content(left: &mut dyn Read, right: &mut dyn Read) -> Result<Ordering> {
177        let mut left_buffer = [0; COMPARE_BUFFER_SIZE];
178        let mut right_buffer = [0; COMPARE_BUFFER_SIZE];
179        loop {
180            let left_count = left.read_exact_or_eof(&mut left_buffer)?;
181            let right_count = right.read_exact_or_eof(&mut right_buffer)?;
182            let n = left_count.min(right_count);
183            for index in 0..n {
184                match left_buffer[index].cmp(&right_buffer[index]) {
185                    Ordering::Equal => {}
186                    ordering => return Ok(ordering),
187                }
188            }
189            match left_count.cmp(&right_count) {
190                Ordering::Equal if left_count == 0 => return Ok(Ordering::Equal),
191                Ordering::Equal => {}
192                ordering => return Ok(ordering),
193            }
194        }
195    }
196}
197
198/// Copies at most `max_bytes` bytes using trait-object I/O endpoints.
199///
200/// # Parameters
201/// - `reader`: Source reader.
202/// - `writer`: Destination writer.
203/// - `max_bytes`: Maximum number of bytes to copy.
204///
205/// # Returns
206/// The number of bytes copied.
207///
208/// # Errors
209/// Returns the first non-interrupted read error or write error reported by the
210/// underlying streams. Interrupted reads are retried.
211fn copy_at_most_impl(reader: &mut dyn Read, writer: &mut dyn Write, max_bytes: u64) -> Result<u64> {
212    let mut buffer = [0; COPY_BUFFER_SIZE];
213    let mut remaining = max_bytes;
214    let mut copied = 0;
215    while remaining > 0 {
216        let requested = remaining.min(COPY_BUFFER_SIZE as u64) as usize;
217        match reader.read(&mut buffer[..requested]) {
218            Ok(0) => break,
219            Ok(count) => {
220                writer.write_all(&buffer[..count])?;
221                let count = count as u64;
222                remaining -= count;
223                copied += count;
224            }
225            Err(error) => {
226                if error.kind() == ErrorKind::Interrupted {
227                    continue;
228                }
229                return Err(error);
230            }
231        }
232    }
233    Ok(copied)
234}
235
236/// Copies the remaining input through trait-object endpoints when it fits.
237///
238/// # Parameters
239/// - `reader`: Source reader.
240/// - `writer`: Destination writer.
241/// - `max_bytes`: Maximum accepted number of bytes in the remaining input.
242///
243/// # Returns
244/// The number of bytes copied when EOF is reached within the limit.
245///
246/// # Errors
247/// Returns [`ErrorKind::InvalidData`] when the remaining input is longer than
248/// `max_bytes`. Returns the first non-interrupted read error or write error
249/// reported by the underlying streams. Interrupted reads are retried.
250fn copy_to_end_limited_impl(reader: &mut dyn Read, writer: &mut dyn Write, max_bytes: u64) -> Result<u64> {
251    let copied = copy_at_most_impl(reader, writer, max_bytes)?;
252    if copied < max_bytes {
253        return Ok(copied);
254    }
255    if has_more_input(reader)? {
256        return Err(Error::new(
257            ErrorKind::InvalidData,
258            format!("input exceeds maximum length of {max_bytes} bytes"),
259        ));
260    }
261    Ok(copied)
262}
263
/// Probes `reader` for at least one further byte.
///
/// The probe is destructive: when a byte is available it is read from
/// `reader` and discarded.
///
/// # Parameters
/// - `reader`: Stream to probe.
///
/// # Returns
/// `true` when one byte could be read, `false` when the reader reported EOF.
///
/// # Errors
/// The first read error whose kind is not [`ErrorKind::Interrupted`];
/// interrupted reads are retried.
fn has_more_input(reader: &mut dyn Read) -> Result<bool> {
    let mut probe = [0u8; 1];
    loop {
        match reader.read(&mut probe) {
            // A zero-length read means EOF; anything else means a byte arrived.
            Ok(received) => return Ok(received != 0),
            // Nothing was read; fall through and try again.
            Err(error) if error.kind() == ErrorKind::Interrupted => {}
            Err(error) => return Err(error),
        }
    }
}
289
290/// Reads one terminated LEB128 payload from a byte stream.
291///
292/// The function fills a fixed-size stack buffer one byte at a time until a
293/// terminating byte is found or the buffer is full, then delegates decoding to
294/// `decode`.
295///
296/// # Parameters
297///
298/// - `reader`: Source reader.
299/// - `decode`: Decoder for the populated stack buffer.
300///
301/// # Returns
302///
303/// Returns the decoded value.
304///
305/// # Errors
306///
307/// Returns an I/O error reported by `reader`, or [`ErrorKind::InvalidData`] when
308/// `decode` rejects the payload.
309#[inline]
310pub(crate) fn read_leb128_payload<const N: usize, T, R, F>(reader: &mut R, decode: F) -> Result<T>
311where
312    R: Read + ?Sized,
313    F: FnOnce(&[u8]) -> std::result::Result<(T, usize), Leb128DecodeError>,
314{
315    let mut bytes = [0u8; N];
316    for index in 0..N {
317        let target = one_byte_slice(&mut bytes, index);
318        reader.read_exact(target)?;
319        if bytes[index] & 0x80 == 0 {
320            return decode(&bytes)
321                .map(|(value, _)| value)
322                .map_err(|error| Error::new(ErrorKind::InvalidData, error));
323        }
324    }
325    decode(&bytes)
326        .map(|(value, _)| value)
327        .map_err(|error| Error::new(ErrorKind::InvalidData, error))
328}
329
/// Creates a mutable one-byte slice at `index`.
///
/// The slice is produced with ordinary bounds-checked indexing, so this
/// function needs no `unsafe` and cannot hand out memory outside `bytes`.
///
/// # Parameters
///
/// - `bytes`: Buffer to borrow from.
/// - `index`: Byte index inside `bytes`.
///
/// # Returns
///
/// Returns a mutable slice containing exactly `bytes[index]`.
///
/// # Panics
///
/// Panics when `index` is not a valid index into `bytes`.
#[inline]
fn one_byte_slice(bytes: &mut [u8], index: usize) -> &mut [u8] {
    // An inclusive one-element range yields the same `&mut [u8]` of length 1
    // as the raw-pointer construction did, but an out-of-range `index` now
    // panics instead of being undefined behaviour.
    &mut bytes[index..=index]
}
345
346/// Reads a UTF-8 payload after its length has already been decoded.
347///
348/// # Parameters
349///
350/// - `reader`: Reader that provides the UTF-8 payload bytes.
351/// - `len`: Payload length in bytes.
352/// - `max_len`: Maximum accepted payload length in bytes.
353///
354/// # Returns
355///
356/// Returns the decoded UTF-8 string.
357///
358/// # Errors
359///
360/// Returns [`ErrorKind::InvalidData`] when `len` exceeds `max_len`, an
361/// allocation error when reserving the output buffer fails, an I/O error from
362/// `reader`, or [`ErrorKind::InvalidData`] when the payload is not valid UTF-8.
363pub(crate) fn read_utf8_payload<R>(reader: &mut R, len: usize, max_len: usize) -> Result<String>
364where
365    R: Read + ?Sized,
366{
367    if len > max_len {
368        return Err(length_exceeded_error(len, max_len));
369    }
370    let mut bytes = Vec::new();
371    try_reserve_vec(&mut bytes, len)?;
372    bytes.resize(len, 0);
373    reader.read_exact(&mut bytes)?;
374    String::from_utf8(bytes).map_err(invalid_utf8_error)
375}
376
/// Writes the UTF-8 bytes of `value` to `writer`, with no length prefix.
///
/// # Parameters
///
/// - `writer`: Stream that receives the bytes.
/// - `value`: String slice whose bytes are written.
///
/// # Errors
///
/// Returns the I/O error reported by `writer`.
pub(crate) fn write_utf8_payload<W>(writer: &mut W, value: &str) -> Result<()>
where
    W: Write + ?Sized,
{
    let payload: &[u8] = value.as_bytes();
    writer.write_all(payload)
}
393
394/// Writes a UTF-8 string after a `u16` byte-length prefix.
395///
396/// # Parameters
397///
398/// - `writer`: Destination writer.
399/// - `value`: String slice to write.
400/// - `write_len`: Callback that writes the encoded `u16` length.
401///
402/// # Errors
403///
404/// Returns [`ErrorKind::InvalidInput`] when the UTF-8 byte length does not fit
405/// into `u16`, or an I/O error from the underlying writer.
406pub(crate) fn write_utf8_string_with_u16_len<W, F>(writer: &mut W, value: &str, write_len: F) -> Result<()>
407where
408    W: Write + ?Sized,
409    F: FnOnce(&mut W, u16) -> Result<()>,
410{
411    let bytes = value.as_bytes();
412    write_len(writer, checked_u16_len(bytes.len())?)?;
413    writer.write_all(bytes)
414}
415
416/// Writes a UTF-8 string after a `u32` byte-length prefix.
417///
418/// # Parameters
419///
420/// - `writer`: Destination writer.
421/// - `value`: String slice to write.
422/// - `write_len`: Callback that writes the encoded `u32` length.
423///
424/// # Errors
425///
426/// Returns [`ErrorKind::InvalidInput`] when the UTF-8 byte length does not fit
427/// into `u32`, or an I/O error from the underlying writer.
428pub(crate) fn write_utf8_string_with_u32_len<W, F>(writer: &mut W, value: &str, write_len: F) -> Result<()>
429where
430    W: Write + ?Sized,
431    F: FnOnce(&mut W, u32) -> Result<()>,
432{
433    let bytes = value.as_bytes();
434    write_len(writer, checked_u32_len(bytes.len())?)?;
435    writer.write_all(bytes)
436}
437
/// Narrows a payload length to the `u16` used as its length prefix.
///
/// # Parameters
///
/// - `len`: Payload length in bytes.
///
/// # Returns
///
/// Returns `len` as a `u16` when it fits.
///
/// # Errors
///
/// Returns [`ErrorKind::InvalidInput`] when `len` is larger than `u16::MAX`.
pub(crate) fn checked_u16_len(len: usize) -> Result<u16> {
    match u16::try_from(len) {
        Ok(prefix) => Ok(prefix),
        Err(_) => Err(Error::new(
            ErrorKind::InvalidInput,
            format!("string length {len} exceeds maximum encodable u16 length"),
        )),
    }
}
459
/// Narrows a payload length to the `u32` used as its length prefix.
///
/// Uses a checked conversion, matching `checked_u16_len`, rather than a
/// hand-written range test followed by an `as` cast.
///
/// # Parameters
///
/// - `len`: Payload length in bytes.
///
/// # Returns
///
/// Returns `len` as a `u32` when it fits.
///
/// # Errors
///
/// Returns [`ErrorKind::InvalidInput`] when `len` is larger than `u32::MAX`.
pub(crate) fn checked_u32_len(len: usize) -> Result<u32> {
    u32::try_from(len).map_err(|_| {
        Error::new(
            ErrorKind::InvalidInput,
            format!("string length {len} exceeds maximum encodable u32 length"),
        )
    })
}
483
/// Creates the error reported when a decoded UTF-8 payload length is above
/// the accepted limit.
///
/// # Parameters
///
/// - `len`: Decoded payload length.
/// - `max_len`: Largest accepted payload length.
///
/// # Returns
///
/// Returns an [`ErrorKind::InvalidData`] error whose message names both
/// values.
fn length_exceeded_error(len: usize, max_len: usize) -> Error {
    let message = format!("string length {len} exceeds maximum length of {max_len} bytes");
    Error::new(ErrorKind::InvalidData, message)
}
500
/// Wraps a UTF-8 conversion failure in an I/O error.
///
/// # Parameters
///
/// - `error`: Error returned by the failed UTF-8 conversion.
///
/// # Returns
///
/// Returns an [`ErrorKind::InvalidData`] error whose message embeds the
/// display text of `error`.
fn invalid_utf8_error(error: FromUtf8Error) -> Error {
    let message = format!("length-prefixed string is not valid UTF-8: {error}");
    Error::new(ErrorKind::InvalidData, message)
}
516}