asyncio_utils/
lib.rs

1use tokio::io::{AsyncRead, AsyncSeek, SeekFrom, AsyncSeekExt, AsyncReadExt, AsyncWriteExt};
2use std::pin::Pin;
3use core::task::Poll;
4use tokio::io::ReadBuf;
5use std::error::Error;
6use std::io::Read;
7use std::io::Write;
8
/// Verdict returned by every [`StreamObserver`] hook to steer a stream copy.
///
/// The enum is a plain, field-less value, so it is `Copy` and `Hash` in
/// addition to the comparison traits; callers can store, compare and reuse a
/// decision freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ObserverDecision {
    /// Keep copying.
    Continue,
    /// Stop the copy as soon as possible.
    Abort,
}
14/// Observe a stream's operation when user code is not involved.
15/// For example: When you call EhRead.stream_to function, and you may want to calculate
16/// the checksum (SHA1, SHA256 or MD5 along the way, you can observe the stream copy using the observer)
17pub trait StreamObserver {
18    /// Your observer's begin will be called before stream copy
19    /// If you plan to reuse the observer, you can reset it here
20    fn begin(&mut self) {
21        // Default initialize does nothing
22    }
23
24    /// Before the upstream is read. If you want to abort, abort it here by
25    /// overriding the before_read()
26    fn before_read(&mut self) -> ObserverDecision {
27        return ObserverDecision::Continue;
28    }
29    /// A chunk of data had been read from upstream, about to be written now
30    /// You can intercept it by aborting it here
31    fn before_write(&mut self, _: &[u8]) -> ObserverDecision {
32        return ObserverDecision::Continue;
33    }
34
35    /// A chunk of data had been written to down stream.
36    /// You can intercept it by aborting it here
37    fn after_write(&mut self, _:&[u8]) -> ObserverDecision {
38        return ObserverDecision::Continue;
39    }
40
41    /// The copy ended and `size` bytes had been copied
42    /// If there is error, err with be Some(cause)
43    /// Note, different from Result<usize, Box<dyn Error>>, the bytes copied is always given
44    /// Even if it is zero
45    fn end(&mut self, _:usize, _:Option<Box<&dyn Error>>) {
46
47    }
48}
49
50struct DumbObserver;
51impl StreamObserver for DumbObserver {
52}
53/// Enhanced Reader for std::io::Read
54/// It provides convenient methods for exact reading without throwing error
55/// It allow you to send it to writer
56pub trait EhRead:Read {
57    /// Try to fully read to fill the buffer, similar to read_exact, 
58    /// However, this method never throw errors on error on EOF.
59    /// On EOF, it also returns Ok(size) but size might be smaller than available buffer.
60    /// When size is smaller than buffer size, it must be EOF.
61    /// 
62    /// Upon EOF, you may read again, but you will get EOF anyway with the EOF error.
63    fn try_read_exact(&mut self, buffer: &mut [u8]) -> Result<usize, Box<dyn Error>> {
64        let wanted = buffer.len();
65        let mut copied:usize = 0;
66
67        loop {
68            let rr = self.read(&mut buffer[copied..])?;
69            if rr == 0 {
70                // EOF reached, return copied bytes so far. 
71                // Caller upon seeing result shorter than expected can either:
72                // a) declare EOF
73                // b) call try_read_exact again but receive 0 bytes as result
74                return Ok(copied);
75            }
76            copied = copied + rr;
77            if copied >= wanted {
78                return Ok(copied);
79            }
80        }
81    }
82
83
84    /// Skip bytes from the reader. Return the actual size skipped or the error.
85    /// If EOF reached before skip is complete, UnexpectedEOF error is returned.
86    /// On success, the size must be equal to the input bytes
87    fn skip(&mut self, bytes: usize) -> Result<usize, Box<dyn Error>> {
88        if bytes == 0 {
89            return Ok(0);
90        }
91        let mut buffer = [0u8; 4096];
92        let mut remaining = bytes;
93        while remaining > 0 {
94            let rr = self.try_read_exact(&mut buffer[..remaining])?;
95            if rr == 0 {
96                // EOF reached
97                return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "Insufficient bytes to skip").into());
98            }
99            remaining -= rr;
100        }
101        Ok(bytes)
102    }
103
104    /// Copy all content until EOF to Write. 
105    /// Using buffer_size buffer. If not given, 4096 is used.
106    /// If buffer_size is Some(0), 4096 is used
107    /// 
108    /// Return the number of bytes copied, or any error encountered. 
109    /// 
110    /// If error is EOF, then error is not returned and size would be 0.
111    fn stream_to<W>(&mut self, w:&mut W, buffer_size:Option<usize>, observer:Option<Box<dyn StreamObserver>>) -> Result<usize, Box<dyn Error>> 
112        where W:Write + Sized
113    {
114        let mut buffer_size = buffer_size.unwrap_or(4096);
115        if buffer_size == 0 {
116            buffer_size = 4096;
117        }
118        let mut buffer = vec![0u8; buffer_size];
119        return self.stream_to_with_buffer(w, &mut buffer, observer);
120    }
121
122    /// Same as stream_to, but use externally provided buffer
123    fn stream_to_with_buffer<W>(&mut self, w:&mut W, buffer:&mut[u8], observer:Option<Box<dyn StreamObserver>>) -> Result<usize, Box<dyn Error>>
124        where W:Write+Sized {
125        let mut observer = observer;
126        let default_ob: Box<dyn StreamObserver> = Box::new(DumbObserver);
127        let mut obs = observer.take().unwrap_or(default_ob);
128        let mut copied:usize = 0;
129        loop {
130            let decision = obs.before_read();
131            if decision == ObserverDecision::Abort {
132                break;
133            }
134            let rr = self.read(buffer);
135            if rr.is_err() {
136                let err = rr.err().unwrap();
137                obs.end(copied, Some(Box::new(&err)));
138                return Err(err.into());
139            }
140            let rr = rr.unwrap();
141            if rr == 0 {
142                // EOF
143                break;
144            }
145            let decision = obs.before_write(&buffer[0..rr]);
146            if decision == ObserverDecision::Abort {
147                break;
148            }
149            let wr = w.write_all(&buffer[0..rr]);
150            if wr.is_err() {
151                let err = wr.err().unwrap();
152                obs.end(copied, Some(Box::new(&err)));
153                return Err(err.into());
154            }
155            let decision = obs.after_write(&buffer[0..rr]);
156            if decision == ObserverDecision::Abort {
157                break;
158            }
159            copied += rr;
160        }
161        return Ok(copied);
162    }
163
164
165}
166
167
168/// Blanked implementation for EhRead for all Read for free
169/// 
170/// You can use EhRead functions on all Read's implementations as long as 
171/// you use this library and import EhRead trait.
172impl <T> EhRead for T where T:Read{}
173/// Undo reader supports unread(&[u8])
174/// Useful when you are doing serialization/deserialization where you 
175/// need to put data back (undo the read)
176/// You can use UndoReader as if it is a normal AsyncRead
177/// Additionally, UndoReader supports a limit as well. It would stop reading after limit is reached (EOF)
178
179/// Example:
180/// ```
181/// // You can have rust code between fences inside the comments
182/// // If you pass --test to `rustdoc`, it will even test it for you!
183/// async fn do_test() -> Result<(), Box<dyn std::error::Error>> {
184///     use tokio::io::{AsyncRead,AsyncSeek,AsyncReadExt, AsyncSeekExt};
185///     let f = tokio::fs::File::open("input.txt").await?;
186///     let mut my_undo = crate::asyncio_utils::UndoReader::new(f, Some(10)); // only read 10 bytes
187///     let mut buff = vec![0u8; 10];
188///     let read_count = my_undo.read(&mut buff).await?;
189///     if read_count > 0 {
190///         // inspect the data read check if it is ok
191///         my_undo.unread(&mut buff[0..read_count]); // put all bytes back
192///     }
193///     let data = "XYZ".as_bytes();
194///     my_undo.unread(&data);
195///     // this should be 3 (the "XYZ" should be read here)
196///     let second_read_count = my_undo.read(&mut buff).await?;
197///     // this should be equal to read_count because it would have been reread here
198///     let third_read_count = my_undo.read(&mut buff).await?;
199///     // ...
200///     Ok(())
201/// }
202/// ```
203pub struct UndoReader<T>
204    where T:AsyncRead + Unpin
205{
206    src: T,
207    read_count: usize,
208    limit: usize,
209    buffer: Vec<Vec<u8>>
210}
211
212impl<T> UndoReader<T>
213    where T:AsyncRead + Unpin
214{
215    /// Destruct this UndoReader. 
216    /// 
217    /// Returns the buffer that has been unread but has not been consumed as well as the raw AsyncRead
218    /// Example:
219    /// ```
220    /// // initialize my_undo
221    /// async fn do_test() -> Result<(), Box<dyn std::error::Error>> {
222    ///     let f = tokio::fs::File::open("input.txt").await?;
223    ///     let my_undo = crate::asyncio_utils::UndoReader::new(f, None);
224    ///     let (remaining, raw) = my_undo.destruct();
225    ///     // ...
226    ///     Ok(())
227    /// }
228    /// // remaining is the bytes to be consumed.
229    /// // raw is the raw AsyncRead
230    /// ```
231    /// 
232    /// The UndoReader can't be used anymore after this call
233    pub fn destruct(self) -> (Vec<u8>, T) {
234        let count = self.count_unread();
235        let mut resultv = vec![0u8; count];
236        self.copy_into(&mut resultv); 
237        return (resultv, self.src)
238    }
239
240    // Copy the remaining buffer to given bytes
241    // Internal use only
242    fn copy_into(&self, buf:&mut [u8]) -> usize{
243        let mut copied = 0;
244        for i in 0.. self.buffer.len() {
245            let v = &self.buffer[ self.buffer.len() - i - 1];
246            for i in 0..v.len() {
247                buf[copied + i] = v[i];
248            }
249            copied += v.len();
250        }
251        return copied;
252    }
253
254    /// Get the limit of the UndoReader
255    /// If the limit was None, this would be the usize's max value.
256    pub fn limit(&self)->usize {
257        self.limit
258    }
259
260    /// Count the number of bytes in the unread buffer
261    pub fn count_unread(&self) -> usize {
262        let mut result:usize = 0;
263        for v in &self.buffer {
264            result += v.len();
265        }
266        return result;
267    }
268
269    /// Create new UndoReader with limitation.
270    /// If limit is None, `std::usize::MAX` will be used
271    /// If limit is Some(limit:usize), the limit will be used
272    pub fn new(src:T, limit:Option<usize>) -> UndoReader<T> {
273        UndoReader {
274            src, 
275            limit: match limit {
276                None => std::usize::MAX,
277                Some(actual) => actual
278            },
279            read_count: 0,
280            buffer: Vec::new()
281        }
282    }
283
284    /// Put data for unread so it can be read again.
285    /// 
286    /// Reading of unread data does not count towards the limit because we assume you 
287    /// unconsumed something you consumed in the first place. 
288    /// 
289    /// However, practically, you can arbitrarily unread any data. So the limit may 
290    /// break the promise in such cases
291    pub fn unread(&mut self, data:&[u8]) -> &mut Self {
292        if data.len() > 0 {
293            let mut new = vec![0u8;data.len()];
294            for (index, payload) in data.iter().enumerate() {
295                new[index] = *payload;
296            }
297            self.buffer.push(new);
298        }
299        return self;
300    }
301}
302
303
304/// Implementation of AsyncRead for UndoReader
305impl<T> AsyncRead for UndoReader<T>
306    where T:AsyncRead + Unpin
307{
308    fn poll_read(mut self: Pin<&mut Self>, ctx: &mut std::task::Context<'_>, 
309        data: &mut ReadBuf<'_>) -> Poll<Result<(), std::io::Error>> { 
310        loop {
311            let next = self.buffer.pop();
312            match next {
313                Some(bufdata) => {
314                    if bufdata.len() == 0 {
315                        continue;
316                    }
317                    // give this data out
318                    let available = bufdata.len();
319                    let remaining = data.remaining();
320                    if available <= remaining {
321                        data.put_slice(&bufdata);
322                    } else {
323                        data.put_slice(&bufdata[0..remaining]);
324                        let left_over = &bufdata[remaining..];
325                        let mut new_vec = vec![0u8;left_over.len()];
326                        for (index, payload) in left_over.iter().enumerate() {
327                            new_vec[index] = *payload;
328                        }
329                        self.buffer.push(new_vec);
330                    }
331                    return Poll::Ready(Ok(()));
332                },
333                None => {
334                    break;
335                }
336            }
337        }
338        if self.read_count >= self.limit {
339            // Mark EOF directly
340            return Poll::Ready(Ok(()));
341        }
342        let ms = &mut *self;
343        let p = Pin::new(&mut ms.src);
344        let before_filled = data.filled().len();
345        let result = p.poll_read(ctx, data);
346        let after_filled = data.filled().len();
347        let this_read = after_filled - before_filled;
348        self.read_count += this_read;
349
350        let overread = self.read_count > self.limit;
351        if overread {
352            let overread_count = self.read_count - self.limit;
353            //undo overread portion
354            data.set_filled(after_filled - overread_count);
355            self.read_count = self.limit;
356        }
357        
358        return result;
359    }
360}
361
362
363/// An AsyncRead + AsyncSeek wrapper
364/// It supports limit the bytes can be read.
365/// Typically used when you want to read a specific segments from a file
366/// 
367/// Example
368/// ```
369/// 
370/// async fn run_test() -> Result<(), Box<dyn std::error::Error>> {
371///     use tokio::io::SeekFrom;
372///     use tokio::io::{AsyncRead,AsyncSeek, AsyncReadExt, AsyncSeekExt};
373///     let f = tokio::fs::File::open("input.txt").await?;
374///     let read_from: u64 = 18; // start read from 18'th byte
375///     let mut lsr = crate::asyncio_utils::LimitSeekerReader::new(f, Some(20)); // read up to 20 bytes
376///     lsr.seek(SeekFrom::Start(read_from)); // do seek
377/// 
378///     let mut buf = vec![0u8; 1024];
379///     lsr.read(&mut buf); // read it
380///     return Ok(());
381/// }
382/// ```
383pub struct LimitSeekerReader<T>
384    where T:AsyncRead + AsyncSeek + Unpin
385{
386    src: T,
387    read_count: usize,
388    limit: usize,
389}
390
391/// Implement a limit reader for AsyncRead and AsyncSeek together. Typically a file
392/// Note that if your seek does not affect total reads. You can seek with positive/negative
393/// from current/begin of file/end of file, but it does not change the total bytes would be 
394/// read from the reader.
395/// 
396/// This is a little bit weird though. Typically what you want to do is just seek before reading.
397/// 
398/// This is useful when you want to service Http Get with Range requests.
399/// 
400/// You open tokio::fs::File 
401/// you seek the position
402/// you set limit on number of bytes to read
403/// you start reading and serving.
404/// 
405impl<T> LimitSeekerReader<T>
406    where T:AsyncRead + AsyncSeek + Unpin
407{
408    /// Destruct the LimitSeekerReader and get the bytes read so far and the original reader
409    /// Returns the size read and the original reader. 
410    /// 
411    /// You can't use the LimitSeekerReader after this call
412    pub fn destruct(self) -> (usize, T) {
413        (self.read_count, self.src)
414    }
415
416    /// Create new LimitSeekerReader from another AsyncRead + AsyncSeek (typically file)
417    /// 
418    /// Argument src is the underlying reader + seeker
419    /// limit is the byte limit. Node that the limit can be
420    ///     Some(limit)
421    ///     None -> No limit (std::usize::MAX)
422    pub fn new(src:T, limit:Option<usize>) -> LimitSeekerReader<T> {
423        LimitSeekerReader {
424            src, 
425            limit: {
426                match limit {
427                    None => std::usize::MAX,
428                    Some(actual_limit) => actual_limit
429                }
430            },
431            read_count: 0
432        }
433    }
434}
435
436/// Implementation of AsyncRead
437impl<T> AsyncRead for LimitSeekerReader<T>
438    where T:AsyncRead + AsyncSeek + Unpin
439{
440    fn poll_read(mut self: Pin<&mut Self>, ctx: &mut std::task::Context<'_>, 
441        data: &mut ReadBuf<'_>) -> Poll<Result<(), std::io::Error>> { 
442        if self.read_count >= self.limit {
443            // Mark EOF directly
444            return Poll::Ready(Ok(()));
445        }
446        let ms = &mut *self;
447        let p = Pin::new(&mut ms.src);
448        let before_filled = data.filled().len();
449        let result = p.poll_read(ctx, data);
450        let after_filled = data.filled().len();
451        let this_read = after_filled - before_filled;
452        self.read_count += this_read;
453
454        let overread = self.read_count > self.limit;
455        if overread {
456            let overread_count = self.read_count - self.limit;
457            //undo overread portion
458            data.set_filled(after_filled - overread_count);
459            self.read_count = self.limit;
460        }
461        
462        return result;
463    }
464}
465
466
467/// Implementation of AsyncSeek
468impl<T> AsyncSeek for LimitSeekerReader<T>
469    where T:AsyncRead + AsyncSeek + Unpin
470{
471    fn start_seek(mut self: Pin<&mut Self>, from: SeekFrom) -> Result<(), std::io::Error> {
472        let ms = &mut *self;
473        let p = Pin::new(&mut ms.src);
474        return p.start_seek(from);
475    }
476
477    fn poll_complete(mut self: Pin<&mut Self>, ctx: &mut std::task::Context<'_>) -> Poll<Result<u64, std::io::Error>> { 
478        let ms = &mut *self;
479        let p = Pin::new(&mut ms.src);
480        return p.poll_complete(ctx);
481    }
482}
483
484
485/// Pure implementation for LimitReader to restrict number of bytes can be read
486/// Useful when you want to read stream but want to end early no matter what
487/// 
488/// E.g. you can't accept more than 20MiB as HTTP Post body, you can limit it here
489/// Example:
490/// ```
491/// async fn do_test() -> Result<(), Box<dyn std::error::Error>> {
492///     use tokio::io::{AsyncReadExt};
493///     let mut f = tokio::fs::File::open("input.txt").await?;
494///     let mut reader = crate::asyncio_utils::LimitReader::new(f, Some(18)); // only read at most 18 bytes
495/// 
496///     let mut buf = vec![0u8; 2096];
497///     reader.read(&mut buf).await?;
498///     return Ok(());
499/// }
500/// ```
501pub struct LimitReader<T>
502    where T:AsyncRead + Unpin
503{
504    src: T,
505    read_count: usize,
506    limit: usize,
507}
508
509
510/// Implementation of AysncRead for LimitReader
511impl<T> LimitReader<T>
512    where T:AsyncRead + Unpin
513{
514    /// Create new LimitReader from another AsyncRead (typically File or Stream)
515    /// 
516    /// Argument src is the underlying reader
517    /// limit is the byte limit. Node that the limit can be
518    ///     Some(limit) -> The limit is set
519    ///     None -> No limit (std::usize::MAX)
520    pub fn new(src:T, limit:Option<usize>) -> LimitReader<T> {
521        LimitReader {
522            src, 
523            limit: {
524                match limit {
525                    None => std::usize::MAX,
526                    Some(actual_limit) => actual_limit
527                }
528            },
529            read_count: 0
530        }
531    }
532
533    /// Destruct the LimitReader and get the total read bytes and the original reader
534    /// 
535    /// Takes the ownership and You can't use the LimitReader after this call
536    pub fn destruct(self) -> (usize, T) {
537        (self.read_count, self.src)
538    }
539}
540
541/// Implementation of AsyncRead
542impl<T> AsyncRead for LimitReader<T>
543    where T:AsyncRead + Unpin
544{
545    fn poll_read(mut self: Pin<&mut Self>, ctx: &mut std::task::Context<'_>, 
546        data: &mut ReadBuf<'_>) -> Poll<Result<(), std::io::Error>> { 
547        if self.read_count >= self.limit {
548            // Mark EOF directly
549            return Poll::Ready(Ok(()));
550        }
551        let ms = &mut *self;
552        let p = Pin::new(&mut ms.src);
553        let before_filled = data.filled().len();
554        let result = p.poll_read(ctx, data);
555        let after_filled = data.filled().len();
556        let this_read = after_filled - before_filled;
557        self.read_count += this_read;
558
559        let overread = self.read_count > self.limit;
560        if overread {
561            let overread_count = self.read_count - self.limit;
562            //undo overread portion
563            data.set_filled(after_filled - overread_count);
564            self.read_count = self.limit;
565        }
566        
567        return result;
568    }
569}
570
571
// These tests read the fixture file `test.data` from the working directory.
// NOTE(review): the assertions imply it consists of repeated "123456789\n"
// lines; confirm against the checked-in fixture.
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushed-back chunks come out newest-first and do not count towards the
    /// limit; afterwards the file is read up to the limit.
    #[tokio::test]
    async fn test_undo() {
        let source = tokio::fs::File::open("test.data").await.unwrap();
        let mut reader = UndoReader::new(source, Some(10));
        let mut scratch = [0u8; 1024];

        reader.unread("XXX".as_bytes());
        reader.unread("YYY".as_bytes());

        // Unread data has priority, most recent chunk first.
        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "YYY".as_bytes());
        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "XXX".as_bytes());

        // The full limit is still available for data from the file.
        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "123456789\n".as_bytes());
    }

    /// Reads stop at the limit and report EOF afterwards.
    #[tokio::test]
    async fn test_limit() {
        let source = tokio::fs::File::open("test.data").await.unwrap();
        let mut reader = LimitReader::new(source, Some(10));
        let mut scratch = [0u8; 1024];

        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "123456789\n".as_bytes());
        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "".as_bytes());
    }

    /// Seeking moves the start position but leaves the byte budget untouched.
    #[tokio::test]
    async fn test_seek() {
        let source = tokio::fs::File::open("test.data").await.unwrap();
        let mut reader = LimitSeekerReader::new(source, Some(10));
        reader.seek(SeekFrom::Current(13)).await.unwrap();
        let mut scratch = [0u8; 1024];

        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "456789\n123".as_bytes());
        let got = reader.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[0..got], "".as_bytes());
    }
}