Skip to main content

atomr_streams/
file_io.rs

1//! FileIO — read/write files as streams of `Bytes`. akka.net: `Dsl/FileIO.cs`.
2
3use std::io;
4use std::path::{Path, PathBuf};
5
6use bytes::{Bytes, BytesMut};
7use futures::stream::StreamExt;
8use tokio::fs::File;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10
11use crate::sink::Sink;
12use crate::source::Source;
13
/// Namespace type for the file-backed stream helpers (`from_path`, `to_path`,
/// `pipe_to_path`). It carries no data; all functionality lives in associated
/// functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct FileIO;
15
16impl FileIO {
17    /// Read a file in chunks of `chunk_size` bytes. akka.net: `FileIO.FromFile`.
18    pub fn from_path(path: impl Into<PathBuf>, chunk_size: usize) -> Source<io::Result<Bytes>> {
19        let path: PathBuf = path.into();
20        let cap = chunk_size.max(512);
21        let s = futures::stream::unfold(
22            FileState { path, file: None, cap, done: false },
23            |mut state| async move {
24                if state.done {
25                    return None;
26                }
27                if state.file.is_none() {
28                    match File::open(&state.path).await {
29                        Ok(f) => state.file = Some(f),
30                        Err(e) => {
31                            state.done = true;
32                            return Some((Err(e), state));
33                        }
34                    }
35                }
36                let mut buf = BytesMut::with_capacity(state.cap);
37                buf.resize(state.cap, 0);
38                let read = state.file.as_mut().unwrap().read(&mut buf).await;
39                match read {
40                    Ok(0) => None,
41                    Ok(n) => {
42                        buf.truncate(n);
43                        Some((Ok(buf.freeze()), state))
44                    }
45                    Err(e) => {
46                        state.done = true;
47                        Some((Err(e), state))
48                    }
49                }
50            },
51        )
52        .boxed();
53        Source { inner: s }
54    }
55
56    /// Write every `Bytes` chunk to `path`, truncating any existing file.
57    /// akka.net: `FileIO.ToFile`.
58    pub async fn to_path(source: Source<Bytes>, path: impl AsRef<Path>) -> io::Result<u64> {
59        let mut file = File::create(path.as_ref()).await?;
60        let mut stream = source.into_boxed();
61        let mut written: u64 = 0;
62        while let Some(chunk) = stream.next().await {
63            file.write_all(&chunk).await?;
64            written += chunk.len() as u64;
65        }
66        file.flush().await?;
67        Ok(written)
68    }
69
70    /// Same as `to_path`, but consumes a source of `io::Result<Bytes>`.
71    pub async fn pipe_to_path(source: Source<io::Result<Bytes>>, path: impl AsRef<Path>) -> io::Result<u64> {
72        let mut file = File::create(path.as_ref()).await?;
73        let mut stream = source.into_boxed();
74        let mut written: u64 = 0;
75        while let Some(chunk) = stream.next().await {
76            let chunk = chunk?;
77            file.write_all(&chunk).await?;
78            written += chunk.len() as u64;
79        }
80        file.flush().await?;
81        Ok(written)
82    }
83}
84
/// State threaded through the `unfold` stream built by `FileIO::from_path`.
struct FileState {
    /// Path of the file to read.
    path: PathBuf,
    /// Handle to the opened file, if any; starts out as `None`.
    file: Option<File>,
    /// Size in bytes of the buffer allocated for each read.
    cap: usize,
    /// Set once an error has been emitted, so the next poll ends the stream.
    done: bool,
}
91
92#[allow(dead_code)]
93async fn _drain<T: Send + 'static>(s: Source<T>) {
94    Sink::ignore(s).await
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::{tempdir, NamedTempFile};

    /// Writes `payload` to a temp file, streams it through `from_path` into
    /// `pipe_to_path`, and returns the reported byte count together with the
    /// bytes found in the destination file.
    async fn round_trip(payload: &[u8], chunk_size: usize) -> (u64, Vec<u8>) {
        let mut src = NamedTempFile::new().unwrap();
        src.write_all(payload).unwrap();
        src.flush().unwrap();

        // The destination lives inside a temp dir so it is removed when the
        // guard drops, even if an assertion fails; nothing is left behind.
        let out_dir = tempdir().unwrap();
        let dst_path = out_dir.path().join("copy.bin");

        let read = FileIO::from_path(src.path(), chunk_size);
        let wrote = FileIO::pipe_to_path(read, &dst_path).await.unwrap();
        (wrote, std::fs::read(&dst_path).unwrap())
    }

    #[tokio::test]
    async fn round_trip_file_read_write() {
        let payload = b"hello world, this is streams";
        let (wrote, contents) = round_trip(payload, 8).await;
        assert_eq!(wrote, payload.len() as u64);
        assert_eq!(contents, payload);
    }

    #[tokio::test]
    async fn round_trip_spans_multiple_chunks() {
        // 2000 bytes with 512-byte reads cannot fit in a single chunk.
        let payload: Vec<u8> = (0..2000u32).map(|i| (i % 251) as u8).collect();
        let (wrote, contents) = round_trip(&payload, 512).await;
        assert_eq!(wrote, 2000);
        assert_eq!(contents, payload);
    }
}
121}