// nusb/io/read.rs

1use std::{
2    error::Error,
3    io::{BufRead, Read},
4    time::Duration,
5};
6
7#[cfg(any(feature = "tokio", feature = "smol"))]
8use std::{
9    pin::Pin,
10    task::{ready, Context, Poll},
11};
12
13use crate::{
14    transfer::{Buffer, BulkOrInterrupt, In, TransferError},
15    Endpoint,
16};
17
/// Wrapper for a Bulk or Interrupt IN [`Endpoint`](crate::Endpoint) that
/// manages transfers to provide a higher-level buffered API.
///
/// Most of the functionality of this type is provided through standard IO
/// traits; you'll want to use one of the following:
///
/// * [`std::io::Read`](std::io::Read) and [`BufRead`](std::io::BufRead) for
///   blocking IO.
/// * With the `tokio` cargo feature,
///   [`tokio::io::AsyncRead`](tokio::io::AsyncRead) and
///   [`AsyncBufRead`](tokio::io::AsyncBufRead) for async IO. Tokio also
///   provides `AsyncReadExt` and `AsyncBufReadExt` with additional methods.
/// * With the `smol` cargo feature,
///   [`futures_io::AsyncRead`](futures_io::AsyncRead) and
///   [`AsyncBufRead`](futures_io::AsyncBufRead) for async IO.
///   `futures_lite` provides `AsyncReadExt` and `AsyncBufReadExt` with
///   additional methods.
///
/// By default, this type ignores USB packet lengths and boundaries. For protocols
/// that use short or zero-length packets as delimiters, you can use the
/// [`until_short_packet()`](Self::until_short_packet) method to get an
/// [`EndpointReadUntilShortPacket`](EndpointReadUntilShortPacket) adapter
/// that observes these delimiters.
pub struct EndpointRead<EpType: BulkOrInterrupt> {
    // Underlying IN endpoint transfers are submitted to and completed from.
    endpoint: Endpoint<EpType, In>,
    // The most recently completed transfer currently being consumed, if any.
    reading: Option<ReadBuffer>,
    // Target number of concurrent transfers; see `set_num_transfers`.
    num_transfers: usize,
    // Per-transfer buffer size, rounded up to a multiple of the max packet size.
    transfer_size: usize,
    // Timeout for the blocking `read`/`fill_buf` APIs only.
    read_timeout: Duration,
}
48
// A completed transfer whose data is being consumed by the reader.
struct ReadBuffer {
    // Read position within `buf`. Set to `usize::MAX` by
    // `clear_short_packet` as a sentinel meaning the short-packet end
    // has been consumed.
    pos: usize,
    // The buffer returned by the completed transfer.
    buf: Buffer,
    // Completion status of the transfer.
    status: Result<(), TransferError>,
}
54
impl ReadBuffer {
    /// The transfer's error, if any. `Cancelled` is filtered out and treated
    /// as a normal completion rather than an error.
    #[inline]
    fn error(&self) -> Option<TransferError> {
        self.status.err().filter(|e| *e != TransferError::Cancelled)
    }

    /// True while unconsumed bytes remain, or an error has not yet been
    /// reported to the caller.
    #[inline]
    fn has_remaining(&self) -> bool {
        self.pos < self.buf.len() || self.error().is_some()
    }

    /// Like `has_remaining`, but also true when the transfer completed short
    /// of its requested length (short packet) and that end has not yet been
    /// consumed with `clear_short_packet`.
    #[inline]
    fn has_remaining_or_short_end(&self) -> bool {
        self.pos < self.buf.requested_len() || self.error().is_some()
    }

    /// Mark the end of a short packet as consumed by moving `pos` past any
    /// possible `requested_len` (sentinel value).
    #[inline]
    fn clear_short_packet(&mut self) {
        self.pos = usize::MAX
    }

    /// Unconsumed bytes of the buffer; once the data is exhausted, reports
    /// the transfer's error (if any) instead.
    #[inline]
    fn remaining(&self) -> Result<&[u8], std::io::Error> {
        let remaining = &self.buf[self.pos..];
        match (remaining.len(), self.error()) {
            (0, Some(e)) => Err(e.into()),
            _ => Ok(remaining),
        }
    }

    /// Advance `pos` by `len`; panics if `len` exceeds the unconsumed bytes.
    /// `saturating_sub` keeps the check safe with the `usize::MAX` sentinel.
    #[inline]
    fn consume(&mut self, len: usize) {
        let remaining = self.buf.len().saturating_sub(self.pos);
        assert!(len <= remaining, "consumed more than available");
        self.pos += len;
    }
}
92
/// Copy as many bytes from `src` into `dest` as fit, returning the number
/// of bytes copied (the smaller of the two lengths).
fn copy_min(dest: &mut [u8], src: &[u8]) -> usize {
    let n = src.len().min(dest.len());
    dest[..n].copy_from_slice(&src[..n]);
    n
}
98
impl<EpType: BulkOrInterrupt> EndpointRead<EpType> {
    /// Create a new `EndpointRead` wrapping the given endpoint.
    ///
    /// The `transfer_size` parameter is the size of the buffer passed to the OS
    /// for each transfer. It will be rounded up to the next multiple of the
    /// endpoint's max packet size.
    pub fn new(endpoint: Endpoint<EpType, In>, transfer_size: usize) -> Self {
        let packet_size = endpoint.max_packet_size();
        // Round up to a whole number of max-size packets, with a minimum of one.
        let transfer_size = (transfer_size.div_ceil(packet_size)).max(1) * packet_size;

        Self {
            endpoint,
            reading: None,
            num_transfers: 1,
            transfer_size,
            // `Duration::MAX`: the blocking read APIs wait indefinitely by default.
            read_timeout: Duration::MAX,
        }
    }

    /// Set the number of concurrent transfers.
    ///
    /// A value of 1 (default) means that transfers will only be submitted when
    /// calling `read()` or `fill_buf()` and the buffer is empty. To maximize
    /// throughput, a value of 2 or more is recommended for applications that
    /// stream data continuously so that the host controller can continue to
    /// receive data while the application processes the data from a completed
    /// transfer.
    ///
    /// A value of 0 means no further transfers will be submitted. Existing
    /// transfers will complete normally, and subsequent calls to `read()` and
    /// `fill_buf()` will return zero bytes (EOF).
    ///
    /// This submits more transfers when increasing the number, but does not
    /// [cancel transfers](Self::cancel_all) when decreasing it.
    pub fn set_num_transfers(&mut self, num_transfers: usize) {
        self.num_transfers = num_transfers;

        // Leave the last transfer to be submitted by `read` such that
        // a value of `1` only has transfers pending within `read` calls.
        while self.endpoint.pending() < num_transfers.saturating_sub(1) {
            let buf = self.endpoint.allocate(self.transfer_size);
            self.endpoint.submit(buf);
        }
    }

    /// Set the number of concurrent transfers.
    ///
    /// See [Self::set_num_transfers] (this version is for method chaining).
    pub fn with_num_transfers(mut self, num_transfers: usize) -> Self {
        self.set_num_transfers(num_transfers);
        self
    }

    /// Set the timeout for waiting for a transfer in the blocking `read` APIs.
    ///
    /// This affects the `std::io::Read` and `std::io::BufRead` implementations
    /// only, and not the async trait implementations.
    ///
    /// When a timeout occurs, the call fails but the transfer is not cancelled
    /// and may complete later if the read is retried.
    pub fn set_read_timeout(&mut self, timeout: Duration) {
        self.read_timeout = timeout;
    }

    /// Set the timeout for an individual transfer for the blocking `read` APIs.
    ///
    /// See [Self::set_read_timeout] -- this is for method chaining with `EndpointRead::new()`.
    pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
        self.set_read_timeout(timeout);
        self
    }

    /// Cancel all pending transfers.
    ///
    /// This sets [`num_transfers`](Self::set_num_transfers) to 0, so no further
    /// transfers will be submitted. Any data buffered before the transfers are cancelled
    /// can be read, and then the read methods will return 0 bytes (EOF).
    ///
    /// Call [`num_transfers`](Self::set_num_transfers) with a non-zero value
    /// to resume receiving data.
    pub fn cancel_all(&mut self) {
        self.num_transfers = 0;
        self.endpoint.cancel_all();
    }

    /// Destroy this `EndpointRead` and return the underlying [`Endpoint`].
    ///
    /// Any pending transfers are not cancelled.
    pub fn into_inner(self) -> Endpoint<EpType, In> {
        self.endpoint
    }

    /// Get an [`EndpointReadUntilShortPacket`] adapter that will read only until
    /// the end of a short or zero-length packet.
    ///
    /// Some USB protocols use packets shorter than the endpoint's max packet size
    /// as a delimiter marking the end of a message. By default, [`EndpointRead`]
    /// ignores packet boundaries, but this adapter allows you to observe these
    /// delimiters.
    pub fn until_short_packet(&mut self) -> EndpointReadUntilShortPacket<'_, EpType> {
        EndpointReadUntilShortPacket { reader: self }
    }

    /// Whether unconsumed data (or an unreported error) is buffered,
    /// ignoring packet boundaries.
    #[inline]
    fn has_data(&self) -> bool {
        self.reading.as_ref().is_some_and(|r| r.has_remaining())
    }

    /// Like [`Self::has_data`], but also true when positioned at the
    /// not-yet-consumed end of a short packet.
    #[inline]
    fn has_data_or_short_end(&self) -> bool {
        self.reading
            .as_ref()
            .is_some_and(|r| r.has_remaining_or_short_end())
    }

    /// Re-submit the fully-consumed completed buffer (if any) to the
    /// endpoint, reusing its allocation.
    fn resubmit(&mut self) {
        if let Some(c) = self.reading.take() {
            debug_assert!(!c.has_remaining());
            self.endpoint.submit(c.buf);
        }
    }

    /// Top up pending transfers to `num_transfers`, then report whether any
    /// transfer is in flight (i.e. whether more data can be expected).
    fn start_read(&mut self) -> bool {
        if self.endpoint.pending() < self.num_transfers {
            // Re-use the last completed buffer if available
            self.resubmit();
            while self.endpoint.pending() < self.num_transfers {
                // Allocate more buffers for any remaining transfers
                let buf = self.endpoint.allocate(self.transfer_size);
                self.endpoint.submit(buf);
            }
        }

        // If num_transfers is 0 and all transfers are complete
        self.endpoint.pending() > 0
    }

    /// Unconsumed bytes of the current buffer, or the transfer's error once
    /// the data is exhausted.
    ///
    /// Precondition: `self.reading` is `Some` (callers check `has_data*` or
    /// `wait`/`poll` first), otherwise the `unwrap` panics.
    #[inline]
    fn remaining(&self) -> Result<&[u8], std::io::Error> {
        self.reading.as_ref().unwrap().remaining()
    }

    /// Advance the read position by `len` bytes; panics if `len` exceeds the
    /// buffered data (or is nonzero with no buffer present).
    #[inline]
    fn consume(&mut self, len: usize) {
        if let Some(ref mut c) = self.reading {
            c.consume(len);
        } else {
            assert!(len == 0, "consumed more than available");
        }
    }

    /// Block until the next transfer completes (subject to `read_timeout`)
    /// and install it as the current buffer.
    ///
    /// Returns `Ok(false)` when no transfers are pending (EOF), and an
    /// `ErrorKind::TimedOut` error if the timeout elapses first.
    fn wait(&mut self) -> Result<bool, std::io::Error> {
        if self.start_read() {
            let c = self.endpoint.wait_next_complete(self.read_timeout);
            let c = c.ok_or(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "timeout waiting for read",
            ))?;
            self.reading = Some(ReadBuffer {
                pos: 0,
                buf: c.buffer,
                status: c.status,
            });
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Async analog of [`Self::wait`]: resolves to `false` when no transfers
    /// are pending (EOF).
    #[cfg(any(feature = "tokio", feature = "smol"))]
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<bool> {
        if self.start_read() {
            let c = ready!(self.endpoint.poll_next_complete(cx));
            self.reading = Some(ReadBuffer {
                pos: 0,
                buf: c.buffer,
                status: c.status,
            });
            Poll::Ready(true)
        } else {
            Poll::Ready(false)
        }
    }

    /// Async fill ignoring packet boundaries; an empty slice means EOF.
    #[cfg(any(feature = "tokio", feature = "smol"))]
    #[inline]
    fn poll_fill_buf(&mut self, cx: &mut Context<'_>) -> Poll<Result<&[u8], std::io::Error>> {
        while !self.has_data() {
            if !ready!(self.poll(cx)) {
                return Poll::Ready(Ok(&[]));
            }
        }
        Poll::Ready(self.remaining())
    }

    /// Async fill that stops at the end of a short packet; transfers ending
    /// without one produce `ErrorKind::UnexpectedEof`.
    #[cfg(any(feature = "tokio", feature = "smol"))]
    #[inline]
    fn poll_fill_buf_until_short(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<&[u8], std::io::Error>> {
        while !self.has_data_or_short_end() {
            if !ready!(self.poll(cx)) {
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "ended without short packet",
                )));
            }
        }
        Poll::Ready(self.remaining())
    }
}
311
312impl<EpType: BulkOrInterrupt> Read for EndpointRead<EpType> {
313    #[inline]
314    fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
315        let remaining = self.fill_buf()?;
316        let len = copy_min(buf, remaining);
317        self.consume(len);
318        Ok(len)
319    }
320}
321
322impl<EpType: BulkOrInterrupt> BufRead for EndpointRead<EpType> {
323    #[inline]
324    fn fill_buf(&mut self) -> Result<&[u8], std::io::Error> {
325        while !self.has_data() {
326            if !self.wait()? {
327                return Ok(&[]);
328            }
329        }
330        self.remaining()
331    }
332
333    #[inline]
334    fn consume(&mut self, len: usize) {
335        self.consume(len);
336    }
337}
338
#[cfg(feature = "tokio")]
impl<EpType: BulkOrInterrupt> tokio::io::AsyncRead for EndpointRead<EpType> {
    // Async analog of `Read::read`: fill from the buffered transfer, copy as
    // much as fits into `buf`, and mark the copied bytes consumed.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = Pin::into_inner(self);
        let remaining = ready!(this.poll_fill_buf(cx))?;
        // Copy no more than the space remaining in the caller's buffer.
        let len = remaining.len().min(buf.remaining());
        buf.put_slice(&remaining[..len]);
        this.consume(len);
        Poll::Ready(Ok(()))
    }
}
354
#[cfg(feature = "tokio")]
impl<EpType: BulkOrInterrupt> tokio::io::AsyncBufRead for EndpointRead<EpType> {
    // Delegates to the inherent `poll_fill_buf`, which ignores packet
    // boundaries (an empty slice means EOF).
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<&[u8], std::io::Error>> {
        Pin::into_inner(self).poll_fill_buf(cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        Pin::into_inner(self).consume(amt);
    }
}
368
#[cfg(feature = "smol")]
impl<EpType: BulkOrInterrupt> futures_io::AsyncRead for EndpointRead<EpType> {
    // Async analog of `Read::read`: fill from the buffered transfer, copy as
    // much as fits into `buf`, and mark the copied bytes consumed.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = Pin::into_inner(self);
        let remaining = ready!(this.poll_fill_buf(cx))?;
        let len = copy_min(buf, remaining);
        this.consume(len);
        Poll::Ready(Ok(len))
    }
}
383
#[cfg(feature = "smol")]
impl<EpType: BulkOrInterrupt> futures_io::AsyncBufRead for EndpointRead<EpType> {
    // Delegates to the inherent `poll_fill_buf`, which ignores packet
    // boundaries (an empty slice means EOF).
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<&[u8], std::io::Error>> {
        Pin::into_inner(self).poll_fill_buf(cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        Pin::into_inner(self).consume(amt);
    }
}
397
/// Adapter for [`EndpointRead`] that ends after a short or zero-length packet.
///
/// This can be obtained from [`EndpointRead::until_short_packet()`]. It does
/// not have any state other than that of the underlying [`EndpointRead`], so
/// dropping and re-creating with another call to
/// [`EndpointRead::until_short_packet()`] has no effect.
///
/// This implements the same traits as `EndpointRead` but observes packet
/// boundaries instead of ignoring them.
pub struct EndpointReadUntilShortPacket<'a, EpType: BulkOrInterrupt> {
    // Exclusive borrow of the wrapped reader; all state lives there.
    reader: &'a mut EndpointRead<EpType>,
}
410
/// Error returned by [`EndpointReadUntilShortPacket::consume_end()`]
/// when the reader is not at the end of a short packet.
#[derive(Debug)]
pub struct ExpectedShortPacket; // unit struct: carries no data
415
416impl std::fmt::Display for ExpectedShortPacket {
417    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418        write!(f, "expected short packet")
419    }
420}
421
422impl Error for ExpectedShortPacket {}
423
impl<EpType: BulkOrInterrupt> EndpointReadUntilShortPacket<'_, EpType> {
    /// Check if the endpoint has reached the end of a short packet.
    ///
    /// Upon reading the end of a short packet, the next `read()` or
    /// `fill_buf()` will return 0 bytes (EOF) and this method will return
    /// `true`. To begin reading the next message, call `consume_end()`.
    pub fn is_end(&self) -> bool {
        // At a short-packet end: all received bytes are consumed
        // (`!has_remaining`) but the transfer completed short of its
        // requested length (`has_remaining_or_short_end`).
        self.reader
            .reading
            .as_ref()
            .is_some_and(|r| !r.has_remaining() && r.has_remaining_or_short_end())
    }

    /// Consume the end of a short packet.
    ///
    /// Use this after `read()` or `fill_buf()` have returned EOF to reset the reader
    /// to read the next message.
    ///
    /// Returns an error and does nothing if the reader [is not at the end of a short packet](Self::is_end).
    pub fn consume_end(&mut self) -> Result<(), ExpectedShortPacket> {
        if self.is_end() {
            // `is_end()` returning true guarantees `reading` is `Some`.
            self.reader.reading.as_mut().unwrap().clear_short_packet();
            Ok(())
        } else {
            Err(ExpectedShortPacket)
        }
    }
}
452
453impl<EpType: BulkOrInterrupt> Read for EndpointReadUntilShortPacket<'_, EpType> {
454    #[inline]
455    fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
456        let remaining = self.fill_buf()?;
457        let len = copy_min(buf, remaining);
458        self.reader.consume(len);
459        Ok(len)
460    }
461}
462
463impl<EpType: BulkOrInterrupt> BufRead for EndpointReadUntilShortPacket<'_, EpType> {
464    #[inline]
465    fn fill_buf(&mut self) -> Result<&[u8], std::io::Error> {
466        while !self.reader.has_data_or_short_end() {
467            if !self.reader.wait()? {
468                return Err(std::io::Error::new(
469                    std::io::ErrorKind::UnexpectedEof,
470                    "ended without short packet",
471                ));
472            }
473        }
474        self.reader.remaining()
475    }
476
477    #[inline]
478    fn consume(&mut self, len: usize) {
479        if self.reader.has_data_or_short_end() {
480            assert!(len == 0, "consumed more than available");
481        } else {
482            self.reader.consume(len);
483        }
484    }
485}
486
#[cfg(feature = "tokio")]
impl<EpType: BulkOrInterrupt> tokio::io::AsyncRead for EndpointReadUntilShortPacket<'_, EpType> {
    // Async read that observes packet boundaries: uses the short-packet-aware
    // fill, so reads stop (EOF) at a short-packet delimiter.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = Pin::into_inner(self);
        let remaining = ready!(this.reader.poll_fill_buf_until_short(cx))?;
        // Copy no more than the space remaining in the caller's buffer.
        let len = remaining.len().min(buf.remaining());
        buf.put_slice(&remaining[..len]);
        this.reader.consume(len);
        Poll::Ready(Ok(()))
    }
}
502
#[cfg(feature = "tokio")]
impl<EpType: BulkOrInterrupt> tokio::io::AsyncBufRead for EndpointReadUntilShortPacket<'_, EpType> {
    /// Fill the buffer, stopping at the end of a short packet.
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<&[u8], std::io::Error>> {
        // BUG FIX: use the short-packet-aware fill so this adapter observes
        // packet boundaries, consistent with its sync `BufRead` impl and its
        // `AsyncRead` impl. Delegating to the plain `poll_fill_buf` (as
        // before) ignored short-packet delimiters entirely.
        Pin::into_inner(self).reader.poll_fill_buf_until_short(cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        Pin::into_inner(self).reader.consume(amt);
    }
}
516
#[cfg(feature = "smol")]
impl<EpType: BulkOrInterrupt> futures_io::AsyncRead for EndpointReadUntilShortPacket<'_, EpType> {
    // Async read that observes packet boundaries: uses the short-packet-aware
    // fill, so reads stop (EOF) at a short-packet delimiter.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = Pin::into_inner(self);
        let remaining = ready!(this.reader.poll_fill_buf_until_short(cx))?;
        let len = copy_min(buf, remaining);
        this.reader.consume(len);
        Poll::Ready(Ok(len))
    }
}
531
#[cfg(feature = "smol")]
impl<EpType: BulkOrInterrupt> futures_io::AsyncBufRead
    for EndpointReadUntilShortPacket<'_, EpType>
{
    /// Fill the buffer, stopping at the end of a short packet.
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<&[u8], std::io::Error>> {
        // BUG FIX: use the short-packet-aware fill so this adapter observes
        // packet boundaries, consistent with its sync `BufRead` impl and its
        // `AsyncRead` impl. Delegating to the plain `poll_fill_buf` (as
        // before) ignored short-packet delimiters entirely.
        Pin::into_inner(self).reader.poll_fill_buf_until_short(cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        Pin::into_inner(self).reader.consume(amt);
    }
}