1use std::{
2 error::Error,
3 io::{BufRead, Read},
4 time::Duration,
5};
6
7#[cfg(any(feature = "tokio", feature = "smol"))]
8use std::{
9 pin::Pin,
10 task::{ready, Context, Poll},
11};
12
13use crate::{
14 transfer::{Buffer, BulkOrInterrupt, In, TransferError},
15 Endpoint,
16};
17
18pub struct EndpointRead<EpType: BulkOrInterrupt> {
42 endpoint: Endpoint<EpType, In>,
43 reading: Option<ReadBuffer>,
44 num_transfers: usize,
45 transfer_size: usize,
46 read_timeout: Duration,
47}
48
49struct ReadBuffer {
50 pos: usize,
51 buf: Buffer,
52 status: Result<(), TransferError>,
53}
54
55impl ReadBuffer {
56 #[inline]
57 fn error(&self) -> Option<TransferError> {
58 self.status.err().filter(|e| *e != TransferError::Cancelled)
59 }
60
61 #[inline]
62 fn has_remaining(&self) -> bool {
63 self.pos < self.buf.len() || self.error().is_some()
64 }
65
66 #[inline]
67 fn has_remaining_or_short_end(&self) -> bool {
68 self.pos < self.buf.requested_len() || self.error().is_some()
69 }
70
71 #[inline]
72 fn clear_short_packet(&mut self) {
73 self.pos = usize::MAX
74 }
75
76 #[inline]
77 fn remaining(&self) -> Result<&[u8], std::io::Error> {
78 let remaining = &self.buf[self.pos..];
79 match (remaining.len(), self.error()) {
80 (0, Some(e)) => Err(e.into()),
81 _ => Ok(remaining),
82 }
83 }
84
85 #[inline]
86 fn consume(&mut self, len: usize) {
87 let remaining = self.buf.len().saturating_sub(self.pos);
88 assert!(len <= remaining, "consumed more than available");
89 self.pos += len;
90 }
91}
92
/// Copies as many bytes as fit from `src` into the front of `dest`.
///
/// Returns the number of bytes copied, which is the smaller of the two slice lengths.
fn copy_min(dest: &mut [u8], src: &[u8]) -> usize {
    let count = usize::min(dest.len(), src.len());
    let (head, _) = dest.split_at_mut(count);
    head.copy_from_slice(&src[..count]);
    count
}
98
99impl<EpType: BulkOrInterrupt> EndpointRead<EpType> {
100 pub fn new(endpoint: Endpoint<EpType, In>, transfer_size: usize) -> Self {
106 let packet_size = endpoint.max_packet_size();
107 let transfer_size = (transfer_size.div_ceil(packet_size)).max(1) * packet_size;
108
109 Self {
110 endpoint,
111 reading: None,
112 num_transfers: 1,
113 transfer_size,
114 read_timeout: Duration::MAX,
115 }
116 }
117
118 pub fn set_num_transfers(&mut self, num_transfers: usize) {
134 self.num_transfers = num_transfers;
135
136 while self.endpoint.pending() < num_transfers.saturating_sub(1) {
139 let buf = self.endpoint.allocate(self.transfer_size);
140 self.endpoint.submit(buf);
141 }
142 }
143
144 pub fn with_num_transfers(mut self, num_transfers: usize) -> Self {
148 self.set_num_transfers(num_transfers);
149 self
150 }
151
152 pub fn set_read_timeout(&mut self, timeout: Duration) {
160 self.read_timeout = timeout;
161 }
162
163 pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
167 self.set_read_timeout(timeout);
168 self
169 }
170
171 pub fn cancel_all(&mut self) {
180 self.num_transfers = 0;
181 self.endpoint.cancel_all();
182 }
183
184 pub fn into_inner(self) -> Endpoint<EpType, In> {
188 self.endpoint
189 }
190
191 pub fn until_short_packet(&mut self) -> EndpointReadUntilShortPacket<'_, EpType> {
199 EndpointReadUntilShortPacket { reader: self }
200 }
201
202 #[inline]
203 fn has_data(&self) -> bool {
204 self.reading.as_ref().is_some_and(|r| r.has_remaining())
205 }
206
207 #[inline]
208 fn has_data_or_short_end(&self) -> bool {
209 self.reading
210 .as_ref()
211 .is_some_and(|r| r.has_remaining_or_short_end())
212 }
213
214 fn resubmit(&mut self) {
215 if let Some(c) = self.reading.take() {
216 debug_assert!(!c.has_remaining());
217 self.endpoint.submit(c.buf);
218 }
219 }
220
221 fn start_read(&mut self) -> bool {
222 if self.endpoint.pending() < self.num_transfers {
223 self.resubmit();
225 while self.endpoint.pending() < self.num_transfers {
226 let buf = self.endpoint.allocate(self.transfer_size);
228 self.endpoint.submit(buf);
229 }
230 }
231
232 self.endpoint.pending() > 0
234 }
235
236 #[inline]
237 fn remaining(&self) -> Result<&[u8], std::io::Error> {
238 self.reading.as_ref().unwrap().remaining()
239 }
240
241 #[inline]
242 fn consume(&mut self, len: usize) {
243 if let Some(ref mut c) = self.reading {
244 c.consume(len);
245 } else {
246 assert!(len == 0, "consumed more than available");
247 }
248 }
249
250 fn wait(&mut self) -> Result<bool, std::io::Error> {
251 if self.start_read() {
252 let c = self.endpoint.wait_next_complete(self.read_timeout);
253 let c = c.ok_or(std::io::Error::new(
254 std::io::ErrorKind::TimedOut,
255 "timeout waiting for read",
256 ))?;
257 self.reading = Some(ReadBuffer {
258 pos: 0,
259 buf: c.buffer,
260 status: c.status,
261 });
262 Ok(true)
263 } else {
264 Ok(false)
265 }
266 }
267
268 #[cfg(any(feature = "tokio", feature = "smol"))]
269 fn poll(&mut self, cx: &mut Context<'_>) -> Poll<bool> {
270 if self.start_read() {
271 let c = ready!(self.endpoint.poll_next_complete(cx));
272 self.reading = Some(ReadBuffer {
273 pos: 0,
274 buf: c.buffer,
275 status: c.status,
276 });
277 Poll::Ready(true)
278 } else {
279 Poll::Ready(false)
280 }
281 }
282
283 #[cfg(any(feature = "tokio", feature = "smol"))]
284 #[inline]
285 fn poll_fill_buf(&mut self, cx: &mut Context<'_>) -> Poll<Result<&[u8], std::io::Error>> {
286 while !self.has_data() {
287 if !ready!(self.poll(cx)) {
288 return Poll::Ready(Ok(&[]));
289 }
290 }
291 Poll::Ready(self.remaining())
292 }
293
294 #[cfg(any(feature = "tokio", feature = "smol"))]
295 #[inline]
296 fn poll_fill_buf_until_short(
297 &mut self,
298 cx: &mut Context<'_>,
299 ) -> Poll<Result<&[u8], std::io::Error>> {
300 while !self.has_data_or_short_end() {
301 if !ready!(self.poll(cx)) {
302 return Poll::Ready(Err(std::io::Error::new(
303 std::io::ErrorKind::UnexpectedEof,
304 "ended without short packet",
305 )));
306 }
307 }
308 Poll::Ready(self.remaining())
309 }
310}
311
312impl<EpType: BulkOrInterrupt> Read for EndpointRead<EpType> {
313 #[inline]
314 fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
315 let remaining = self.fill_buf()?;
316 let len = copy_min(buf, remaining);
317 self.consume(len);
318 Ok(len)
319 }
320}
321
322impl<EpType: BulkOrInterrupt> BufRead for EndpointRead<EpType> {
323 #[inline]
324 fn fill_buf(&mut self) -> Result<&[u8], std::io::Error> {
325 while !self.has_data() {
326 if !self.wait()? {
327 return Ok(&[]);
328 }
329 }
330 self.remaining()
331 }
332
333 #[inline]
334 fn consume(&mut self, len: usize) {
335 self.consume(len);
336 }
337}
338
#[cfg(feature = "tokio")]
impl<EpType: BulkOrInterrupt> tokio::io::AsyncRead for EndpointRead<EpType> {
    /// Async counterpart of `Read::read`: fills as much of `buf` as the current transfer
    /// allows, polling for a completion when nothing is buffered.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let reader = self.get_mut();
        let available = match reader.poll_fill_buf(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(data)) => data,
        };
        let count = buf.remaining().min(available.len());
        buf.put_slice(&available[..count]);
        reader.consume(count);
        Poll::Ready(Ok(()))
    }
}
354
#[cfg(feature = "tokio")]
impl<EpType: BulkOrInterrupt> tokio::io::AsyncBufRead for EndpointRead<EpType> {
    /// Async counterpart of `BufRead::fill_buf`. An empty slice means end of stream.
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<&[u8], std::io::Error>> {
        let reader = self.get_mut();
        reader.poll_fill_buf(cx)
    }

    /// Marks `amt` bytes of the slice returned by `poll_fill_buf` as consumed.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let reader = self.get_mut();
        reader.consume(amt);
    }
}
368
#[cfg(feature = "smol")]
impl<EpType: BulkOrInterrupt> futures_io::AsyncRead for EndpointRead<EpType> {
    /// Async counterpart of `Read::read`: copies from the current transfer into `buf`,
    /// polling for a completion when nothing is buffered. Yields `0` at end of stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let reader = self.get_mut();
        let available = match reader.poll_fill_buf(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(data)) => data,
        };
        let count = copy_min(buf, available);
        reader.consume(count);
        Poll::Ready(Ok(count))
    }
}
383
#[cfg(feature = "smol")]
impl<EpType: BulkOrInterrupt> futures_io::AsyncBufRead for EndpointRead<EpType> {
    /// Async counterpart of `BufRead::fill_buf`. An empty slice means end of stream.
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<&[u8], std::io::Error>> {
        let reader = self.get_mut();
        reader.poll_fill_buf(cx)
    }

    /// Marks `amt` bytes of the slice returned by `poll_fill_buf` as consumed.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let reader = self.get_mut();
        reader.consume(amt);
    }
}
397
398pub struct EndpointReadUntilShortPacket<'a, EpType: BulkOrInterrupt> {
408 reader: &'a mut EndpointRead<EpType>,
409}
410
/// Error returned by `EndpointReadUntilShortPacket::consume_end` when the reader is not
/// positioned at the end of a transfer that completed with a short packet.
///
/// This is a field-less marker value, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExpectedShortPacket;
415
416impl std::fmt::Display for ExpectedShortPacket {
417 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418 write!(f, "expected short packet")
419 }
420}
421
422impl Error for ExpectedShortPacket {}
423
424impl<EpType: BulkOrInterrupt> EndpointReadUntilShortPacket<'_, EpType> {
425 pub fn is_end(&self) -> bool {
431 self.reader
432 .reading
433 .as_ref()
434 .is_some_and(|r| !r.has_remaining() && r.has_remaining_or_short_end())
435 }
436
437 pub fn consume_end(&mut self) -> Result<(), ExpectedShortPacket> {
444 if self.is_end() {
445 self.reader.reading.as_mut().unwrap().clear_short_packet();
446 Ok(())
447 } else {
448 Err(ExpectedShortPacket)
449 }
450 }
451}
452
453impl<EpType: BulkOrInterrupt> Read for EndpointReadUntilShortPacket<'_, EpType> {
454 #[inline]
455 fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
456 let remaining = self.fill_buf()?;
457 let len = copy_min(buf, remaining);
458 self.reader.consume(len);
459 Ok(len)
460 }
461}
462
463impl<EpType: BulkOrInterrupt> BufRead for EndpointReadUntilShortPacket<'_, EpType> {
464 #[inline]
465 fn fill_buf(&mut self) -> Result<&[u8], std::io::Error> {
466 while !self.reader.has_data_or_short_end() {
467 if !self.reader.wait()? {
468 return Err(std::io::Error::new(
469 std::io::ErrorKind::UnexpectedEof,
470 "ended without short packet",
471 ));
472 }
473 }
474 self.reader.remaining()
475 }
476
477 #[inline]
478 fn consume(&mut self, len: usize) {
479 if self.reader.has_data_or_short_end() {
480 assert!(len == 0, "consumed more than available");
481 } else {
482 self.reader.consume(len);
483 }
484 }
485}
486
#[cfg(feature = "tokio")]
impl<EpType: BulkOrInterrupt> tokio::io::AsyncRead for EndpointReadUntilShortPacket<'_, EpType> {
    /// Async counterpart of `Read::read` that stops at the short-packet end.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = self.get_mut();
        let available = match this.reader.poll_fill_buf_until_short(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(data)) => data,
        };
        let count = buf.remaining().min(available.len());
        buf.put_slice(&available[..count]);
        this.reader.consume(count);
        Poll::Ready(Ok(()))
    }
}
502
#[cfg(feature = "tokio")]
impl<EpType: BulkOrInterrupt> tokio::io::AsyncBufRead for EndpointReadUntilShortPacket<'_, EpType> {
    /// Async counterpart of `BufRead::fill_buf` that stops at the short-packet end.
    ///
    /// Yields an empty slice at the end of the message, and `UnexpectedEof` if the stream
    /// ends first. This matches `poll_read` and the blocking `fill_buf`. The plain
    /// `poll_fill_buf` would skip over the short-packet boundary into the next transfer.
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<&[u8], std::io::Error>> {
        Pin::into_inner(self).reader.poll_fill_buf_until_short(cx)
    }

    /// Marks `amt` bytes of the slice returned by `poll_fill_buf` as consumed.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        Pin::into_inner(self).reader.consume(amt);
    }
}
516
#[cfg(feature = "smol")]
impl<EpType: BulkOrInterrupt> futures_io::AsyncRead for EndpointReadUntilShortPacket<'_, EpType> {
    /// Async counterpart of `Read::read` that stops at the short-packet end (yielding `0`).
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = self.get_mut();
        let available = match this.reader.poll_fill_buf_until_short(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(data)) => data,
        };
        let count = copy_min(buf, available);
        this.reader.consume(count);
        Poll::Ready(Ok(count))
    }
}
531
#[cfg(feature = "smol")]
impl<EpType: BulkOrInterrupt> futures_io::AsyncBufRead
    for EndpointReadUntilShortPacket<'_, EpType>
{
    /// Async counterpart of `BufRead::fill_buf` that stops at the short-packet end.
    ///
    /// Yields an empty slice at the end of the message, and `UnexpectedEof` if the stream
    /// ends first. This matches `poll_read` and the blocking `fill_buf`. The plain
    /// `poll_fill_buf` would skip over the short-packet boundary into the next transfer.
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<&[u8], std::io::Error>> {
        Pin::into_inner(self).reader.poll_fill_buf_until_short(cx)
    }

    /// Marks `amt` bytes of the slice returned by `poll_fill_buf` as consumed.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        Pin::into_inner(self).reader.consume(amt);
    }
}
546}