asyncio_utils/lib.rs
1use tokio::io::{AsyncRead, AsyncSeek, SeekFrom, AsyncSeekExt, AsyncReadExt, AsyncWriteExt};
2use std::pin::Pin;
3use core::task::Poll;
4use tokio::io::ReadBuf;
5use std::error::Error;
6use std::io::Read;
7use std::io::Write;
8
/// Verdict returned by a stream observer hook: keep going or stop the copy.
///
/// The enum is a plain two-state flag, so it is `Copy` and `Hash` in addition
/// to being comparable and orderable (`Continue < Abort`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ObserverDecision {
    /// Carry on with the operation.
    Continue,
    /// Stop the operation at the current step.
    Abort,
}
14/// Observe a stream's operation when user code is not involved.
15/// For example: When you call EhRead.stream_to function, and you may want to calculate
16/// the checksum (SHA1, SHA256 or MD5 along the way, you can observe the stream copy using the observer)
17pub trait StreamObserver {
18 /// Your observer's begin will be called before stream copy
19 /// If you plan to reuse the observer, you can reset it here
20 fn begin(&mut self) {
21 // Default initialize does nothing
22 }
23
24 /// Before the upstream is read. If you want to abort, abort it here by
25 /// overriding the before_read()
26 fn before_read(&mut self) -> ObserverDecision {
27 return ObserverDecision::Continue;
28 }
29 /// A chunk of data had been read from upstream, about to be written now
30 /// You can intercept it by aborting it here
31 fn before_write(&mut self, _: &[u8]) -> ObserverDecision {
32 return ObserverDecision::Continue;
33 }
34
35 /// A chunk of data had been written to down stream.
36 /// You can intercept it by aborting it here
37 fn after_write(&mut self, _:&[u8]) -> ObserverDecision {
38 return ObserverDecision::Continue;
39 }
40
41 /// The copy ended and `size` bytes had been copied
42 /// If there is error, err with be Some(cause)
43 /// Note, different from Result<usize, Box<dyn Error>>, the bytes copied is always given
44 /// Even if it is zero
45 fn end(&mut self, _:usize, _:Option<Box<&dyn Error>>) {
46
47 }
48}
49
50struct DumbObserver;
51impl StreamObserver for DumbObserver {
52}
53/// Enhanced Reader for std::io::Read
54/// It provides convenient methods for exact reading without throwing error
55/// It allow you to send it to writer
56pub trait EhRead:Read {
57 /// Try to fully read to fill the buffer, similar to read_exact,
58 /// However, this method never throw errors on error on EOF.
59 /// On EOF, it also returns Ok(size) but size might be smaller than available buffer.
60 /// When size is smaller than buffer size, it must be EOF.
61 ///
62 /// Upon EOF, you may read again, but you will get EOF anyway with the EOF error.
63 fn try_read_exact(&mut self, buffer: &mut [u8]) -> Result<usize, Box<dyn Error>> {
64 let wanted = buffer.len();
65 let mut copied:usize = 0;
66
67 loop {
68 let rr = self.read(&mut buffer[copied..])?;
69 if rr == 0 {
70 // EOF reached, return copied bytes so far.
71 // Caller upon seeing result shorter than expected can either:
72 // a) declare EOF
73 // b) call try_read_exact again but receive 0 bytes as result
74 return Ok(copied);
75 }
76 copied = copied + rr;
77 if copied >= wanted {
78 return Ok(copied);
79 }
80 }
81 }
82
83
84 /// Skip bytes from the reader. Return the actual size skipped or the error.
85 /// If EOF reached before skip is complete, UnexpectedEOF error is returned.
86 /// On success, the size must be equal to the input bytes
87 fn skip(&mut self, bytes: usize) -> Result<usize, Box<dyn Error>> {
88 if bytes == 0 {
89 return Ok(0);
90 }
91 let mut buffer = [0u8; 4096];
92 let mut remaining = bytes;
93 while remaining > 0 {
94 let rr = self.try_read_exact(&mut buffer[..remaining])?;
95 if rr == 0 {
96 // EOF reached
97 return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "Insufficient bytes to skip").into());
98 }
99 remaining -= rr;
100 }
101 Ok(bytes)
102 }
103
104 /// Copy all content until EOF to Write.
105 /// Using buffer_size buffer. If not given, 4096 is used.
106 /// If buffer_size is Some(0), 4096 is used
107 ///
108 /// Return the number of bytes copied, or any error encountered.
109 ///
110 /// If error is EOF, then error is not returned and size would be 0.
111 fn stream_to<W>(&mut self, w:&mut W, buffer_size:Option<usize>, observer:Option<Box<dyn StreamObserver>>) -> Result<usize, Box<dyn Error>>
112 where W:Write + Sized
113 {
114 let mut buffer_size = buffer_size.unwrap_or(4096);
115 if buffer_size == 0 {
116 buffer_size = 4096;
117 }
118 let mut buffer = vec![0u8; buffer_size];
119 return self.stream_to_with_buffer(w, &mut buffer, observer);
120 }
121
122 /// Same as stream_to, but use externally provided buffer
123 fn stream_to_with_buffer<W>(&mut self, w:&mut W, buffer:&mut[u8], observer:Option<Box<dyn StreamObserver>>) -> Result<usize, Box<dyn Error>>
124 where W:Write+Sized {
125 let mut observer = observer;
126 let default_ob: Box<dyn StreamObserver> = Box::new(DumbObserver);
127 let mut obs = observer.take().unwrap_or(default_ob);
128 let mut copied:usize = 0;
129 loop {
130 let decision = obs.before_read();
131 if decision == ObserverDecision::Abort {
132 break;
133 }
134 let rr = self.read(buffer);
135 if rr.is_err() {
136 let err = rr.err().unwrap();
137 obs.end(copied, Some(Box::new(&err)));
138 return Err(err.into());
139 }
140 let rr = rr.unwrap();
141 if rr == 0 {
142 // EOF
143 break;
144 }
145 let decision = obs.before_write(&buffer[0..rr]);
146 if decision == ObserverDecision::Abort {
147 break;
148 }
149 let wr = w.write_all(&buffer[0..rr]);
150 if wr.is_err() {
151 let err = wr.err().unwrap();
152 obs.end(copied, Some(Box::new(&err)));
153 return Err(err.into());
154 }
155 let decision = obs.after_write(&buffer[0..rr]);
156 if decision == ObserverDecision::Abort {
157 break;
158 }
159 copied += rr;
160 }
161 return Ok(copied);
162 }
163
164
165}
166
167
168/// Blanked implementation for EhRead for all Read for free
169///
170/// You can use EhRead functions on all Read's implementations as long as
171/// you use this library and import EhRead trait.
172impl <T> EhRead for T where T:Read{}
173/// Undo reader supports unread(&[u8])
174/// Useful when you are doing serialization/deserialization where you
175/// need to put data back (undo the read)
176/// You can use UndoReader as if it is a normal AsyncRead
177/// Additionally, UndoReader supports a limit as well. It would stop reading after limit is reached (EOF)
178
179/// Example:
180/// ```
181/// // You can have rust code between fences inside the comments
182/// // If you pass --test to `rustdoc`, it will even test it for you!
183/// async fn do_test() -> Result<(), Box<dyn std::error::Error>> {
184/// use tokio::io::{AsyncRead,AsyncSeek,AsyncReadExt, AsyncSeekExt};
185/// let f = tokio::fs::File::open("input.txt").await?;
186/// let mut my_undo = crate::asyncio_utils::UndoReader::new(f, Some(10)); // only read 10 bytes
187/// let mut buff = vec![0u8; 10];
188/// let read_count = my_undo.read(&mut buff).await?;
189/// if read_count > 0 {
190/// // inspect the data read check if it is ok
191/// my_undo.unread(&mut buff[0..read_count]); // put all bytes back
192/// }
193/// let data = "XYZ".as_bytes();
194/// my_undo.unread(&data);
195/// // this should be 3 (the "XYZ" should be read here)
196/// let second_read_count = my_undo.read(&mut buff).await?;
197/// // this should be equal to read_count because it would have been reread here
198/// let third_read_count = my_undo.read(&mut buff).await?;
199/// // ...
200/// Ok(())
201/// }
202/// ```
203pub struct UndoReader<T>
204 where T:AsyncRead + Unpin
205{
206 src: T,
207 read_count: usize,
208 limit: usize,
209 buffer: Vec<Vec<u8>>
210}
211
212impl<T> UndoReader<T>
213 where T:AsyncRead + Unpin
214{
215 /// Destruct this UndoReader.
216 ///
217 /// Returns the buffer that has been unread but has not been consumed as well as the raw AsyncRead
218 /// Example:
219 /// ```
220 /// // initialize my_undo
221 /// async fn do_test() -> Result<(), Box<dyn std::error::Error>> {
222 /// let f = tokio::fs::File::open("input.txt").await?;
223 /// let my_undo = crate::asyncio_utils::UndoReader::new(f, None);
224 /// let (remaining, raw) = my_undo.destruct();
225 /// // ...
226 /// Ok(())
227 /// }
228 /// // remaining is the bytes to be consumed.
229 /// // raw is the raw AsyncRead
230 /// ```
231 ///
232 /// The UndoReader can't be used anymore after this call
233 pub fn destruct(self) -> (Vec<u8>, T) {
234 let count = self.count_unread();
235 let mut resultv = vec![0u8; count];
236 self.copy_into(&mut resultv);
237 return (resultv, self.src)
238 }
239
240 // Copy the remaining buffer to given bytes
241 // Internal use only
242 fn copy_into(&self, buf:&mut [u8]) -> usize{
243 let mut copied = 0;
244 for i in 0.. self.buffer.len() {
245 let v = &self.buffer[ self.buffer.len() - i - 1];
246 for i in 0..v.len() {
247 buf[copied + i] = v[i];
248 }
249 copied += v.len();
250 }
251 return copied;
252 }
253
254 /// Get the limit of the UndoReader
255 /// If the limit was None, this would be the usize's max value.
256 pub fn limit(&self)->usize {
257 self.limit
258 }
259
260 /// Count the number of bytes in the unread buffer
261 pub fn count_unread(&self) -> usize {
262 let mut result:usize = 0;
263 for v in &self.buffer {
264 result += v.len();
265 }
266 return result;
267 }
268
269 /// Create new UndoReader with limitation.
270 /// If limit is None, `std::usize::MAX` will be used
271 /// If limit is Some(limit:usize), the limit will be used
272 pub fn new(src:T, limit:Option<usize>) -> UndoReader<T> {
273 UndoReader {
274 src,
275 limit: match limit {
276 None => std::usize::MAX,
277 Some(actual) => actual
278 },
279 read_count: 0,
280 buffer: Vec::new()
281 }
282 }
283
284 /// Put data for unread so it can be read again.
285 ///
286 /// Reading of unread data does not count towards the limit because we assume you
287 /// unconsumed something you consumed in the first place.
288 ///
289 /// However, practically, you can arbitrarily unread any data. So the limit may
290 /// break the promise in such cases
291 pub fn unread(&mut self, data:&[u8]) -> &mut Self {
292 if data.len() > 0 {
293 let mut new = vec![0u8;data.len()];
294 for (index, payload) in data.iter().enumerate() {
295 new[index] = *payload;
296 }
297 self.buffer.push(new);
298 }
299 return self;
300 }
301}
302
303
304/// Implementation of AsyncRead for UndoReader
305impl<T> AsyncRead for UndoReader<T>
306 where T:AsyncRead + Unpin
307{
308 fn poll_read(mut self: Pin<&mut Self>, ctx: &mut std::task::Context<'_>,
309 data: &mut ReadBuf<'_>) -> Poll<Result<(), std::io::Error>> {
310 loop {
311 let next = self.buffer.pop();
312 match next {
313 Some(bufdata) => {
314 if bufdata.len() == 0 {
315 continue;
316 }
317 // give this data out
318 let available = bufdata.len();
319 let remaining = data.remaining();
320 if available <= remaining {
321 data.put_slice(&bufdata);
322 } else {
323 data.put_slice(&bufdata[0..remaining]);
324 let left_over = &bufdata[remaining..];
325 let mut new_vec = vec![0u8;left_over.len()];
326 for (index, payload) in left_over.iter().enumerate() {
327 new_vec[index] = *payload;
328 }
329 self.buffer.push(new_vec);
330 }
331 return Poll::Ready(Ok(()));
332 },
333 None => {
334 break;
335 }
336 }
337 }
338 if self.read_count >= self.limit {
339 // Mark EOF directly
340 return Poll::Ready(Ok(()));
341 }
342 let ms = &mut *self;
343 let p = Pin::new(&mut ms.src);
344 let before_filled = data.filled().len();
345 let result = p.poll_read(ctx, data);
346 let after_filled = data.filled().len();
347 let this_read = after_filled - before_filled;
348 self.read_count += this_read;
349
350 let overread = self.read_count > self.limit;
351 if overread {
352 let overread_count = self.read_count - self.limit;
353 //undo overread portion
354 data.set_filled(after_filled - overread_count);
355 self.read_count = self.limit;
356 }
357
358 return result;
359 }
360}
361
362
363/// An AsyncRead + AsyncSeek wrapper
364/// It supports limit the bytes can be read.
365/// Typically used when you want to read a specific segments from a file
366///
367/// Example
368/// ```
369///
370/// async fn run_test() -> Result<(), Box<dyn std::error::Error>> {
371/// use tokio::io::SeekFrom;
372/// use tokio::io::{AsyncRead,AsyncSeek, AsyncReadExt, AsyncSeekExt};
373/// let f = tokio::fs::File::open("input.txt").await?;
374/// let read_from: u64 = 18; // start read from 18'th byte
375/// let mut lsr = crate::asyncio_utils::LimitSeekerReader::new(f, Some(20)); // read up to 20 bytes
376/// lsr.seek(SeekFrom::Start(read_from)); // do seek
377///
378/// let mut buf = vec![0u8; 1024];
379/// lsr.read(&mut buf); // read it
380/// return Ok(());
381/// }
382/// ```
383pub struct LimitSeekerReader<T>
384 where T:AsyncRead + AsyncSeek + Unpin
385{
386 src: T,
387 read_count: usize,
388 limit: usize,
389}
390
391/// Implement a limit reader for AsyncRead and AsyncSeek together. Typically a file
392/// Note that if your seek does not affect total reads. You can seek with positive/negative
393/// from current/begin of file/end of file, but it does not change the total bytes would be
394/// read from the reader.
395///
396/// This is a little bit weird though. Typically what you want to do is just seek before reading.
397///
398/// This is useful when you want to service Http Get with Range requests.
399///
400/// You open tokio::fs::File
401/// you seek the position
402/// you set limit on number of bytes to read
403/// you start reading and serving.
404///
405impl<T> LimitSeekerReader<T>
406 where T:AsyncRead + AsyncSeek + Unpin
407{
408 /// Destruct the LimitSeekerReader and get the bytes read so far and the original reader
409 /// Returns the size read and the original reader.
410 ///
411 /// You can't use the LimitSeekerReader after this call
412 pub fn destruct(self) -> (usize, T) {
413 (self.read_count, self.src)
414 }
415
416 /// Create new LimitSeekerReader from another AsyncRead + AsyncSeek (typically file)
417 ///
418 /// Argument src is the underlying reader + seeker
419 /// limit is the byte limit. Node that the limit can be
420 /// Some(limit)
421 /// None -> No limit (std::usize::MAX)
422 pub fn new(src:T, limit:Option<usize>) -> LimitSeekerReader<T> {
423 LimitSeekerReader {
424 src,
425 limit: {
426 match limit {
427 None => std::usize::MAX,
428 Some(actual_limit) => actual_limit
429 }
430 },
431 read_count: 0
432 }
433 }
434}
435
436/// Implementation of AsyncRead
437impl<T> AsyncRead for LimitSeekerReader<T>
438 where T:AsyncRead + AsyncSeek + Unpin
439{
440 fn poll_read(mut self: Pin<&mut Self>, ctx: &mut std::task::Context<'_>,
441 data: &mut ReadBuf<'_>) -> Poll<Result<(), std::io::Error>> {
442 if self.read_count >= self.limit {
443 // Mark EOF directly
444 return Poll::Ready(Ok(()));
445 }
446 let ms = &mut *self;
447 let p = Pin::new(&mut ms.src);
448 let before_filled = data.filled().len();
449 let result = p.poll_read(ctx, data);
450 let after_filled = data.filled().len();
451 let this_read = after_filled - before_filled;
452 self.read_count += this_read;
453
454 let overread = self.read_count > self.limit;
455 if overread {
456 let overread_count = self.read_count - self.limit;
457 //undo overread portion
458 data.set_filled(after_filled - overread_count);
459 self.read_count = self.limit;
460 }
461
462 return result;
463 }
464}
465
466
467/// Implementation of AsyncSeek
468impl<T> AsyncSeek for LimitSeekerReader<T>
469 where T:AsyncRead + AsyncSeek + Unpin
470{
471 fn start_seek(mut self: Pin<&mut Self>, from: SeekFrom) -> Result<(), std::io::Error> {
472 let ms = &mut *self;
473 let p = Pin::new(&mut ms.src);
474 return p.start_seek(from);
475 }
476
477 fn poll_complete(mut self: Pin<&mut Self>, ctx: &mut std::task::Context<'_>) -> Poll<Result<u64, std::io::Error>> {
478 let ms = &mut *self;
479 let p = Pin::new(&mut ms.src);
480 return p.poll_complete(ctx);
481 }
482}
483
484
485/// Pure implementation for LimitReader to restrict number of bytes can be read
486/// Useful when you want to read stream but want to end early no matter what
487///
488/// E.g. you can't accept more than 20MiB as HTTP Post body, you can limit it here
489/// Example:
490/// ```
491/// async fn do_test() -> Result<(), Box<dyn std::error::Error>> {
492/// use tokio::io::{AsyncReadExt};
493/// let mut f = tokio::fs::File::open("input.txt").await?;
494/// let mut reader = crate::asyncio_utils::LimitReader::new(f, Some(18)); // only read at most 18 bytes
495///
496/// let mut buf = vec![0u8; 2096];
497/// reader.read(&mut buf).await?;
498/// return Ok(());
499/// }
500/// ```
501pub struct LimitReader<T>
502 where T:AsyncRead + Unpin
503{
504 src: T,
505 read_count: usize,
506 limit: usize,
507}
508
509
510/// Implementation of AysncRead for LimitReader
511impl<T> LimitReader<T>
512 where T:AsyncRead + Unpin
513{
514 /// Create new LimitReader from another AsyncRead (typically File or Stream)
515 ///
516 /// Argument src is the underlying reader
517 /// limit is the byte limit. Node that the limit can be
518 /// Some(limit) -> The limit is set
519 /// None -> No limit (std::usize::MAX)
520 pub fn new(src:T, limit:Option<usize>) -> LimitReader<T> {
521 LimitReader {
522 src,
523 limit: {
524 match limit {
525 None => std::usize::MAX,
526 Some(actual_limit) => actual_limit
527 }
528 },
529 read_count: 0
530 }
531 }
532
533 /// Destruct the LimitReader and get the total read bytes and the original reader
534 ///
535 /// Takes the ownership and You can't use the LimitReader after this call
536 pub fn destruct(self) -> (usize, T) {
537 (self.read_count, self.src)
538 }
539}
540
541/// Implementation of AsyncRead
542impl<T> AsyncRead for LimitReader<T>
543 where T:AsyncRead + Unpin
544{
545 fn poll_read(mut self: Pin<&mut Self>, ctx: &mut std::task::Context<'_>,
546 data: &mut ReadBuf<'_>) -> Poll<Result<(), std::io::Error>> {
547 if self.read_count >= self.limit {
548 // Mark EOF directly
549 return Poll::Ready(Ok(()));
550 }
551 let ms = &mut *self;
552 let p = Pin::new(&mut ms.src);
553 let before_filled = data.filled().len();
554 let result = p.poll_read(ctx, data);
555 let after_filled = data.filled().len();
556 let this_read = after_filled - before_filled;
557 self.read_count += this_read;
558
559 let overread = self.read_count > self.limit;
560 if overread {
561 let overread_count = self.read_count - self.limit;
562 //undo overread portion
563 data.set_filled(after_filled - overread_count);
564 self.read_count = self.limit;
565 }
566
567 return result;
568 }
569}
570
571
#[cfg(test)]
mod tests {
    use super::*;

    // Opens the fixture file shared by all tests in this module. The
    // assertions below expect it to start with "123456789\n" and to hold
    // "456789\n123" at byte offsets 13..23.
    async fn open_fixture() -> tokio::fs::File {
        tokio::fs::File::open("test.data").await.unwrap()
    }

    #[tokio::test]
    async fn test_undo() {
        let mut reader = UndoReader::new(open_fixture().await, Some(10));
        let mut scratch = [0u8; 1024];
        reader.unread("XXX".as_bytes());
        reader.unread("YYY".as_bytes());

        // Unread bytes are served first, newest first, and do not count as
        // actual consumption against the limit.
        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "YYY".as_bytes());
        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "XXX".as_bytes());

        // Only now does the file itself get read, up to the 10-byte limit.
        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "123456789\n".as_bytes());
    }

    #[tokio::test]
    async fn test_limit() {
        let mut reader = LimitReader::new(open_fixture().await, Some(10));
        let mut scratch = [0u8; 1024];

        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "123456789\n".as_bytes());

        // The limit is used up, so the next read reports EOF.
        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "".as_bytes());
    }

    #[tokio::test]
    async fn test_seek() {
        let mut reader = LimitSeekerReader::new(open_fixture().await, Some(10));
        reader.seek(SeekFrom::Current(13)).await.unwrap();
        let mut scratch = [0u8; 1024];

        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "456789\n123".as_bytes());

        // The limit is used up, so the next read reports EOF.
        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "".as_bytes());
    }
}