1extern crate std;
2use crate::{ProgressEntry, RandWriter, SeqWriter};
3use bytes::Bytes;
4use mmap_io::{MemoryMappedFile, MmapIoError, MmapMode};
5use std::{path::Path, vec::Vec};
6use thiserror::Error;
7use tokio::{
8 fs::File,
9 io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter},
10};
11
12#[derive(Error, Debug)]
13pub enum FileWriterError {
14 #[error(transparent)]
15 MmapIo(#[from] MmapIoError),
16 #[error(transparent)]
17 TokioIo(#[from] io::Error),
18}
19
20#[derive(Debug)]
21pub struct SeqFileWriter {
22 buffer: BufWriter<File>,
23}
24impl SeqFileWriter {
25 pub fn new(file: File, buffer_size: usize) -> Self {
26 Self {
27 buffer: BufWriter::with_capacity(buffer_size, file),
28 }
29 }
30}
31impl SeqWriter for SeqFileWriter {
32 type Error = FileWriterError;
33 async fn write(&mut self, content: Bytes) -> Result<(), Self::Error> {
34 Ok(self.buffer.write_all(&content).await?)
35 }
36 async fn flush(&mut self) -> Result<(), Self::Error> {
37 Ok(self.buffer.flush().await?)
38 }
39}
40
41#[derive(Debug)]
42pub struct RandFileWriterMmap {
43 mmap: MemoryMappedFile,
44 downloaded: usize,
45 buffer_size: usize,
46}
47impl RandFileWriterMmap {
48 pub fn new(
49 path: impl AsRef<Path>,
50 size: u64,
51 buffer_size: usize,
52 ) -> Result<Self, FileWriterError> {
53 let mmap_builder = MemoryMappedFile::builder(&path)
54 .huge_pages(true)
55 .mode(MmapMode::ReadWrite)
56 .size(size);
57 Ok(Self {
58 mmap: if path.as_ref().try_exists()? {
59 mmap_builder.open()
60 } else {
61 mmap_builder.create()
62 }?,
63 downloaded: 0,
64 buffer_size,
65 })
66 }
67}
68impl RandWriter for RandFileWriterMmap {
69 type Error = FileWriterError;
70 async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
71 self.mmap
72 .as_slice_mut(range.start, bytes.len() as u64)?
73 .as_mut()
74 .copy_from_slice(&bytes);
75 self.downloaded += bytes.len();
76 if self.downloaded >= self.buffer_size {
77 self.mmap.flush_async().await?;
78 }
79 Ok(())
80 }
81 async fn flush(&mut self) -> Result<(), Self::Error> {
82 self.mmap.flush_async().await?;
83 Ok(())
84 }
85}
86
87#[derive(Debug)]
88pub struct RandFileWriterStd {
89 buffer: BufWriter<File>,
90 cache: Vec<(u64, Bytes)>,
91 p: u64,
92 cache_size: usize,
93 buffer_size: usize,
94}
95impl RandFileWriterStd {
96 pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, FileWriterError> {
97 file.set_len(size).await?;
98 Ok(Self {
99 buffer: BufWriter::with_capacity(buffer_size, file),
100 cache: Vec::new(),
101 p: 0,
102 cache_size: 0,
103 buffer_size,
104 })
105 }
106}
107impl RandWriter for RandFileWriterStd {
108 type Error = FileWriterError;
109 async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
110 let pos = self.cache.partition_point(|(i, _)| i < &range.start);
111 self.cache_size += bytes.len();
112 self.cache.insert(pos, (range.start, bytes));
113 if self.cache_size >= self.buffer_size {
114 self.flush().await?;
115 }
116 Ok(())
117 }
118 async fn flush(&mut self) -> Result<(), Self::Error> {
119 for (start, bytes) in self.cache.drain(..) {
120 let len = bytes.len();
121 self.cache_size -= len;
122 if start != self.p {
123 self.buffer.seek(io::SeekFrom::Start(start)).await?;
124 self.p = start;
125 }
126 self.buffer.write_all(&bytes).await?;
127 self.p += len as u64;
128 }
129 self.buffer.flush().await?;
130 Ok(())
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Reads the entire file at `path` back into memory.
    async fn read_back(path: &Path) -> Vec<u8> {
        let mut contents = Vec::new();
        let mut file = File::open(path).await.unwrap();
        file.read_to_end(&mut contents).await.unwrap();
        contents
    }

    /// Two sequential writes followed by a flush end up concatenated on disk.
    #[tokio::test]
    async fn test_seq_file_writer() {
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_path_buf();

        let handle: File = temp_file.reopen().unwrap().into();
        let mut writer = SeqFileWriter::new(handle, 1024);

        writer.write(Bytes::from("Hello, ")).await.unwrap();
        writer.write(Bytes::from("world!")).await.unwrap();
        writer.flush().await.unwrap();

        let file_content = read_back(&file_path).await;
        assert_eq!(file_content, b"Hello, world!");
    }

    /// A write at offset 2 into a 10-byte mapping leaves the other bytes zero.
    #[tokio::test]
    async fn test_rand_file_writer() {
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();

        let mut writer = RandFileWriterMmap::new(file_path, 10, 8 * 1024 * 1024).unwrap();

        writer.write(2..5, Bytes::from("234")).await.unwrap();
        writer.flush().await.unwrap();

        let file_content = read_back(file_path).await;
        assert_eq!(file_content, b"\0\x00234\0\0\0\0\0");
    }
}