1extern crate std;
2use crate::{ProgressEntry, RandWriter, SeqWriter};
3use bytes::Bytes;
4use mmap_io::{MemoryMappedFile, MmapIoError, MmapMode};
5use std::{path::Path, vec::Vec};
6use thiserror::Error;
7use tokio::{
8 fs::File,
9 io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter},
10};
11
12#[derive(Error, Debug)]
13pub enum FileWriterError {
14 #[error(transparent)]
15 MmapIo(#[from] MmapIoError),
16 #[error(transparent)]
17 TokioIo(#[from] io::Error),
18}
19
20#[derive(Debug)]
21pub struct SeqFileWriter {
22 buffer: BufWriter<File>,
23}
24impl SeqFileWriter {
25 pub fn new(file: File, buffer_size: usize) -> Self {
26 Self {
27 buffer: BufWriter::with_capacity(buffer_size, file),
28 }
29 }
30}
31impl SeqWriter for SeqFileWriter {
32 type Error = FileWriterError;
33 async fn write(&mut self, content: Bytes) -> Result<(), Self::Error> {
34 Ok(self.buffer.write_all(&content).await?)
35 }
36 async fn flush(&mut self) -> Result<(), Self::Error> {
37 Ok(self.buffer.flush().await?)
38 }
39}
40
41#[derive(Debug)]
42pub struct RandFileWriterMmap {
43 mmap: MemoryMappedFile,
44 downloaded: usize,
45 buffer_size: usize,
46}
47impl RandFileWriterMmap {
48 pub fn new(
49 path: impl AsRef<Path>,
50 size: u64,
51 buffer_size: usize,
52 ) -> Result<Self, FileWriterError> {
53 let mmap = MemoryMappedFile::builder(path)
54 .huge_pages(true)
55 .mode(MmapMode::ReadWrite)
56 .size(size)
57 .create()?;
58 Ok(Self {
59 mmap,
60 downloaded: 0,
61 buffer_size,
62 })
63 }
64}
65impl RandWriter for RandFileWriterMmap {
66 type Error = FileWriterError;
67 async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
68 self.mmap
69 .as_slice_mut(range.start, bytes.len() as u64)?
70 .as_mut()
71 .copy_from_slice(&bytes);
72 self.downloaded += bytes.len();
73 if self.downloaded >= self.buffer_size {
74 self.mmap.flush_async().await?;
75 }
76 Ok(())
77 }
78 async fn flush(&mut self) -> Result<(), Self::Error> {
79 self.mmap.flush_async().await?;
80 Ok(())
81 }
82}
83
84#[derive(Debug)]
85pub struct RandFileWriterStd {
86 buffer: BufWriter<File>,
87 cache: Vec<(u64, Bytes)>,
88 p: u64,
89 cache_size: usize,
90 buffer_size: usize,
91}
92impl RandFileWriterStd {
93 pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, io::Error> {
94 file.set_len(size).await?;
95 Ok(Self {
96 buffer: BufWriter::with_capacity(buffer_size, file),
97 cache: Vec::new(),
98 p: 0,
99 cache_size: 0,
100 buffer_size,
101 })
102 }
103}
104impl RandWriter for RandFileWriterStd {
105 type Error = io::Error;
106 async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
107 let pos = self.cache.partition_point(|(i, _)| i < &range.start);
108 self.cache_size += bytes.len();
109 self.cache.insert(pos, (range.start, bytes));
110 if self.cache_size >= self.buffer_size {
111 self.flush().await?;
112 }
113 Ok(())
114 }
115 async fn flush(&mut self) -> Result<(), Self::Error> {
116 for (start, bytes) in self.cache.drain(..) {
117 let len = bytes.len();
118 self.cache_size -= len;
119 if start != self.p {
120 self.buffer.seek(io::SeekFrom::Start(start)).await?;
121 self.p = start;
122 }
123 self.buffer.write_all(&bytes).await?;
124 self.p += len as u64;
125 }
126 self.buffer.flush().await?;
127 Ok(())
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Reads the complete contents of the file at `path`, panicking on any
    /// I/O failure.
    async fn read_all(path: &Path) -> Vec<u8> {
        let mut contents = Vec::new();
        File::open(path)
            .await
            .unwrap()
            .read_to_end(&mut contents)
            .await
            .unwrap();
        contents
    }

    /// Two sequential writes followed by a flush must land back to back.
    #[tokio::test]
    async fn test_seq_file_writer() {
        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        let handle: File = tmp.reopen().unwrap().into();
        let mut writer = SeqFileWriter::new(handle, 1024);

        for part in [Bytes::from("Hello, "), Bytes::from("world!")] {
            writer.write(part).await.unwrap();
        }
        writer.flush().await.unwrap();

        assert_eq!(read_all(&path).await, b"Hello, world!");
    }

    /// A write at offset 2 into a 10-byte mapping must leave the other bytes
    /// zeroed.
    #[tokio::test]
    async fn test_rand_file_writer() {
        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path();

        let mut writer = RandFileWriterMmap::new(path, 10, 8 * 1024 * 1024).unwrap();

        writer.write(2..5, Bytes::from("234")).await.unwrap();
        writer.flush().await.unwrap();

        assert_eq!(read_all(path).await, b"\0\x00234\0\0\0\0\0");
    }
}