1extern crate std;
2use crate::{ProgressEntry, RandPusher, SeqPusher, Total};
3use mmap_io::{MemoryMappedFile, MmapIoError, MmapMode, flush::FlushPolicy};
4use std::{boxed::Box, collections::VecDeque, path::Path};
5use tokio::{
6 fs::{File, OpenOptions},
7 io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter, SeekFrom},
8};
9
10#[derive(thiserror::Error, Debug)]
11pub enum FilePusherError {
12 #[error(transparent)]
13 MmapIo(#[from] MmapIoError),
14 #[error(transparent)]
15 TokioIo(#[from] io::Error),
16}
17
18#[derive(Debug)]
19pub struct SeqFilePusher {
20 buffer: BufWriter<File>,
21}
22impl SeqFilePusher {
23 pub fn new(file: File, buffer_size: usize) -> Self {
24 Self {
25 buffer: BufWriter::with_capacity(buffer_size, file),
26 }
27 }
28}
29impl SeqPusher for SeqFilePusher {
30 type Error = FilePusherError;
31 async fn push(&mut self, content: &[u8]) -> Result<(), Self::Error> {
32 Ok(self.buffer.write_all(content).await?)
33 }
34 async fn flush(&mut self) -> Result<(), Self::Error> {
35 Ok(self.buffer.flush().await?)
36 }
37}
38
39#[derive(Debug)]
40pub struct RandFilePusherMmap {
41 mmap: MemoryMappedFile,
42 downloaded: usize,
43 buffer_size: usize,
44}
45impl RandFilePusherMmap {
46 pub async fn new(
47 path: impl AsRef<Path>,
48 size: u64,
49 buffer_size: usize,
50 ) -> Result<Self, FilePusherError> {
51 let mmap_builder = MemoryMappedFile::builder(&path)
52 .mode(MmapMode::ReadWrite)
53 .huge_pages(true)
54 .flush_policy(FlushPolicy::Manual);
55 Ok(Self {
56 mmap: if path.as_ref().try_exists()? {
57 OpenOptions::new()
58 .write(true)
59 .open(path)
60 .await?
61 .set_len(size)
62 .await?;
63 mmap_builder.open()
64 } else {
65 mmap_builder.size(size).create()
66 }?,
67 downloaded: 0,
68 buffer_size,
69 })
70 }
71}
72impl RandPusher for RandFilePusherMmap {
73 type Error = FilePusherError;
74 async fn push(&mut self, range: ProgressEntry, bytes: &[u8]) -> Result<(), Self::Error> {
75 self.mmap
76 .as_slice_mut(range.start, range.total())?
77 .as_mut()
78 .copy_from_slice(bytes);
79 self.downloaded += bytes.len();
80 if self.downloaded >= self.buffer_size {
81 self.mmap.flush_async().await?;
82 self.downloaded = 0;
83 }
84 Ok(())
85 }
86 async fn flush(&mut self) -> Result<(), Self::Error> {
87 self.mmap.flush_async().await?;
88 Ok(())
89 }
90}
91
92#[derive(Debug)]
93pub struct RandFilePusherStd {
94 buffer: BufWriter<File>,
95 cache: VecDeque<(u64, Box<[u8]>)>,
96 p: u64,
97 cache_size: usize,
98 buffer_size: usize,
99}
100impl RandFilePusherStd {
101 pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, FilePusherError> {
102 file.set_len(size).await?;
103 Ok(Self {
104 buffer: BufWriter::with_capacity(buffer_size, file),
105 cache: VecDeque::new(),
106 p: 0,
107 cache_size: 0,
108 buffer_size,
109 })
110 }
111}
112impl RandPusher for RandFilePusherStd {
113 type Error = FilePusherError;
114 async fn push(&mut self, range: ProgressEntry, bytes: &[u8]) -> Result<(), Self::Error> {
115 let pos = self.cache.partition_point(|(i, _)| i < &range.start);
116 self.cache_size += bytes.len();
117 self.cache.insert(pos, (range.start, bytes.into()));
118 if self.cache_size >= self.buffer_size {
119 self.flush().await?;
120 }
121 Ok(())
122 }
123 async fn flush(&mut self) -> Result<(), Self::Error> {
124 while let Some((start, bytes)) = self.cache.front() {
125 let len = bytes.len();
126 self.cache_size -= len;
127 if *start != self.p {
128 self.buffer.seek(SeekFrom::Start(*start)).await?;
129 self.p = *start;
130 }
131 self.buffer.write_all(bytes).await?;
132 self.p += len as u64;
133 self.cache.pop_front();
134 }
135 self.buffer.flush().await?;
136 Ok(())
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use std::vec::Vec;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Sequential pushes must appear in the file in push order after a flush.
    #[tokio::test]
    async fn test_seq_file_pusher() {
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_path_buf();

        let mut pusher = SeqFilePusher::new(temp_file.reopen().unwrap().into(), 1024);

        let data1 = Bytes::from("Hello, ");
        let data2 = Bytes::from("world!");
        pusher.push(&data1).await.unwrap();
        pusher.push(&data2).await.unwrap();
        pusher.flush().await.unwrap();

        let mut file_content = Vec::new();
        File::open(&file_path)
            .await
            .unwrap()
            .read_to_end(&mut file_content)
            .await
            .unwrap();
        assert_eq!(file_content, b"Hello, world!");
    }

    /// A chunk pushed through the mmap pusher must land at its offset in a
    /// file of the requested size, with the rest zero-filled.
    #[tokio::test]
    async fn test_rand_file_pusher() {
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();

        let mut pusher = RandFilePusherMmap::new(file_path, 10, 8 * 1024 * 1024)
            .await
            .unwrap();

        let data = Bytes::from("234");
        let range = 2..5;
        pusher.push(range, &data).await.unwrap();
        pusher.flush().await.unwrap();

        let mut file_content = Vec::new();
        File::open(&file_path)
            .await
            .unwrap()
            .read_to_end(&mut file_content)
            .await
            .unwrap();
        assert_eq!(file_content, b"\0\x00234\0\0\0\0\0");
    }

    /// Chunks pushed out of order through the buffered pusher must each land
    /// at their own offset in a file of the requested size.
    #[tokio::test]
    async fn test_rand_file_pusher_std() {
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_path_buf();

        let mut pusher = RandFilePusherStd::new(temp_file.reopen().unwrap().into(), 10, 1024)
            .await
            .unwrap();

        // Push the later chunk first; the queue keeps chunks ordered by offset.
        let tail = Bytes::from("567");
        let head = Bytes::from("234");
        pusher.push(5..8, &tail).await.unwrap();
        pusher.push(2..5, &head).await.unwrap();
        pusher.flush().await.unwrap();

        let mut file_content = Vec::new();
        File::open(&file_path)
            .await
            .unwrap()
            .read_to_end(&mut file_content)
            .await
            .unwrap();
        assert_eq!(file_content, b"\0\x00234567\0\0");
    }
}