1extern crate std;
2use crate::{ProgressEntry, RandPusher, SeqPusher};
3use bytes::Bytes;
4use mmap_io::{MemoryMappedFile, MmapIoError, MmapMode, flush::FlushPolicy};
5use std::{path::Path, vec::Vec};
6use thiserror::Error;
7use tokio::{
8 fs::{File, OpenOptions},
9 io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter, SeekFrom},
10};
11
/// Error type shared by the file-backed pushers defined in this file.
///
/// Both variants are `transparent`: their `Display` output and error source
/// are forwarded from the wrapped error, and `#[from]` lets `?` convert the
/// wrapped error types automatically.
#[derive(Error, Debug)]
pub enum FilePusherError {
    /// Failure reported by the memory-mapping layer (`MmapIoError`).
    #[error(transparent)]
    MmapIo(#[from] MmapIoError),
    /// I/O failure (`io::Error`) from file operations.
    #[error(transparent)]
    TokioIo(#[from] io::Error),
}
19
20#[derive(Debug)]
21pub struct SeqFilePusher {
22 buffer: BufWriter<File>,
23}
24impl SeqFilePusher {
25 pub fn new(file: File, buffer_size: usize) -> Self {
26 Self {
27 buffer: BufWriter::with_capacity(buffer_size, file),
28 }
29 }
30}
31impl SeqPusher for SeqFilePusher {
32 type Error = FilePusherError;
33 async fn push(&mut self, content: Bytes) -> Result<(), Self::Error> {
34 Ok(self.buffer.write_all(&content).await?)
35 }
36 async fn flush(&mut self) -> Result<(), Self::Error> {
37 Ok(self.buffer.flush().await?)
38 }
39}
40
41#[derive(Debug)]
42pub struct RandFilePusherMmap {
43 mmap: MemoryMappedFile,
44 downloaded: usize,
45 buffer_size: usize,
46}
47impl RandFilePusherMmap {
48 pub async fn new(
49 path: impl AsRef<Path>,
50 size: u64,
51 buffer_size: usize,
52 ) -> Result<Self, FilePusherError> {
53 let mmap_builder = MemoryMappedFile::builder(&path)
54 .mode(MmapMode::ReadWrite)
55 .flush_policy(FlushPolicy::Manual);
56 Ok(Self {
57 mmap: if path.as_ref().try_exists()? {
58 OpenOptions::new()
59 .write(true)
60 .open(path)
61 .await?
62 .set_len(size)
63 .await?;
64 mmap_builder.open()
65 } else {
66 mmap_builder.size(size).create()
67 }?,
68 downloaded: 0,
69 buffer_size,
70 })
71 }
72}
73impl RandPusher for RandFilePusherMmap {
74 type Error = FilePusherError;
75 async fn push(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
76 self.mmap
77 .as_slice_mut(range.start, bytes.len() as u64)?
78 .as_mut()
79 .copy_from_slice(&bytes);
80 self.downloaded += bytes.len();
81 if self.downloaded >= self.buffer_size {
82 self.mmap.flush_async().await?;
83 self.downloaded = 0;
84 }
85 Ok(())
86 }
87 async fn flush(&mut self) -> Result<(), Self::Error> {
88 self.mmap.flush_async().await?;
89 Ok(())
90 }
91}
92
93#[derive(Debug)]
94pub struct RandFilePusherStd {
95 buffer: BufWriter<File>,
96 cache: Vec<(u64, Bytes)>,
97 p: u64,
98 cache_size: usize,
99 buffer_size: usize,
100}
101impl RandFilePusherStd {
102 pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, FilePusherError> {
103 file.set_len(size).await?;
104 Ok(Self {
105 buffer: BufWriter::with_capacity(buffer_size, file),
106 cache: Vec::new(),
107 p: 0,
108 cache_size: 0,
109 buffer_size,
110 })
111 }
112}
113impl RandPusher for RandFilePusherStd {
114 type Error = FilePusherError;
115 async fn push(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
116 let pos = self.cache.partition_point(|(i, _)| i < &range.start);
117 self.cache_size += bytes.len();
118 self.cache.insert(pos, (range.start, bytes));
119 if self.cache_size >= self.buffer_size {
120 self.flush().await?;
121 }
122 Ok(())
123 }
124 async fn flush(&mut self) -> Result<(), Self::Error> {
125 for (start, bytes) in self.cache.drain(..) {
126 let len = bytes.len();
127 self.cache_size -= len;
128 if start != self.p {
129 self.buffer.seek(SeekFrom::Start(start)).await?;
130 self.p = start;
131 }
132 self.buffer.write_all(&bytes).await?;
133 self.p += len as u64;
134 }
135 self.buffer.flush().await?;
136 Ok(())
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Reads the full contents of the file at `path`, panicking on any error.
    async fn read_back(path: &Path) -> Vec<u8> {
        let mut content = Vec::new();
        let mut file = File::open(path).await.unwrap();
        file.read_to_end(&mut content).await.unwrap();
        content
    }

    /// Two sequential pushes followed by a flush end up concatenated on disk.
    #[tokio::test]
    async fn test_seq_file_pusher() {
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_path_buf();

        let handle = File::from_std(temp_file.reopen().unwrap());
        let mut pusher = SeqFilePusher::new(handle, 1024);

        for chunk in [Bytes::from("Hello, "), Bytes::from("world!")] {
            pusher.push(chunk).await.unwrap();
        }
        pusher.flush().await.unwrap();

        let file_content = read_back(&file_path).await;
        assert_eq!(file_content, b"Hello, world!");
    }

    /// A chunk pushed at offset 2 of a 10-byte mapped file lands at that
    /// offset, with the remaining bytes zero.
    #[tokio::test]
    async fn test_rand_file_pusher() {
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();

        let buffer_size = 8 * 1024 * 1024;
        let mut pusher = RandFilePusherMmap::new(file_path, 10, buffer_size)
            .await
            .unwrap();

        pusher.push(2..5, Bytes::from("234")).await.unwrap();
        pusher.flush().await.unwrap();

        let file_content = read_back(file_path).await;
        assert_eq!(file_content, b"\0\x00234\0\0\0\0\0");
    }
}