1use std::io;
4use std::path::{Path, PathBuf};
5
6use bytes::{Bytes, BytesMut};
7use futures::stream::StreamExt;
8use tokio::fs::File;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10
11use crate::sink::Sink;
12use crate::source::Source;
13
14pub struct FileIO;
15
16impl FileIO {
17 pub fn from_path(path: impl Into<PathBuf>, chunk_size: usize) -> Source<io::Result<Bytes>> {
19 let path: PathBuf = path.into();
20 let cap = chunk_size.max(512);
21 let s = futures::stream::unfold(
22 FileState { path, file: None, cap, done: false },
23 |mut state| async move {
24 if state.done {
25 return None;
26 }
27 if state.file.is_none() {
28 match File::open(&state.path).await {
29 Ok(f) => state.file = Some(f),
30 Err(e) => {
31 state.done = true;
32 return Some((Err(e), state));
33 }
34 }
35 }
36 let mut buf = BytesMut::with_capacity(state.cap);
37 buf.resize(state.cap, 0);
38 let read = state.file.as_mut().unwrap().read(&mut buf).await;
39 match read {
40 Ok(0) => None,
41 Ok(n) => {
42 buf.truncate(n);
43 Some((Ok(buf.freeze()), state))
44 }
45 Err(e) => {
46 state.done = true;
47 Some((Err(e), state))
48 }
49 }
50 },
51 )
52 .boxed();
53 Source { inner: s }
54 }
55
56 pub async fn to_path(source: Source<Bytes>, path: impl AsRef<Path>) -> io::Result<u64> {
58 let mut file = File::create(path.as_ref()).await?;
59 let mut stream = source.into_boxed();
60 let mut written: u64 = 0;
61 while let Some(chunk) = stream.next().await {
62 file.write_all(&chunk).await?;
63 written += chunk.len() as u64;
64 }
65 file.flush().await?;
66 Ok(written)
67 }
68
69 pub async fn pipe_to_path(source: Source<io::Result<Bytes>>, path: impl AsRef<Path>) -> io::Result<u64> {
71 let mut file = File::create(path.as_ref()).await?;
72 let mut stream = source.into_boxed();
73 let mut written: u64 = 0;
74 while let Some(chunk) = stream.next().await {
75 let chunk = chunk?;
76 file.write_all(&chunk).await?;
77 written += chunk.len() as u64;
78 }
79 file.flush().await?;
80 Ok(written)
81 }
82}
83
84struct FileState {
85 path: PathBuf,
86 file: Option<File>,
87 cap: usize,
88 done: bool,
89}
90
91#[allow(dead_code)]
92async fn _drain<T: Send + 'static>(s: Source<T>) {
93 Sink::ignore(s).await
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[tokio::test]
    async fn round_trip_file_read_write() {
        // Larger than the 512-byte minimum read size, so the reader yields several
        // chunks and the concatenation logic is actually exercised.
        let expected: Vec<u8> = (0..2000u32).map(|i| (i % 251) as u8).collect();
        let mut src = NamedTempFile::new().unwrap();
        src.write_all(&expected).unwrap();

        // Write inside a temp directory so the output file is removed when `dir` drops.
        let dir = tempfile::tempdir().unwrap();
        let dst_path = dir.path().join("out.bin");

        let read = FileIO::from_path(src.path(), 8);
        let wrote = FileIO::pipe_to_path(read, &dst_path).await.unwrap();
        assert_eq!(wrote, expected.len() as u64);
        assert_eq!(std::fs::read(&dst_path).unwrap(), expected);
    }

    #[tokio::test]
    async fn missing_file_yields_single_error() {
        let dir = tempfile::tempdir().unwrap();
        let mut stream = FileIO::from_path(dir.path().join("does-not-exist"), 8).into_boxed();
        let first = stream
            .next()
            .await
            .expect("stream should yield the open error");
        assert_eq!(first.unwrap_err().kind(), io::ErrorKind::NotFound);
        assert!(stream.next().await.is_none());
    }
}