1use crate::{ProgressEntry, RandWriter, SeqWriter};
2use bytes::Bytes;
3use memmap2::MmapMut;
4use tokio::{
5 fs::File,
6 io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter},
7};
8
9#[derive(Debug)]
10pub struct SeqFileWriter {
11 buffer: BufWriter<File>,
12}
13impl SeqFileWriter {
14 pub fn new(file: File, buffer_size: usize) -> Self {
15 Self {
16 buffer: BufWriter::with_capacity(buffer_size, file),
17 }
18 }
19}
20impl SeqWriter for SeqFileWriter {
21 type Error = io::Error;
22 async fn write(&mut self, content: Bytes) -> Result<(), Self::Error> {
23 self.buffer.write_all(&content).await
24 }
25 async fn flush(&mut self) -> Result<(), Self::Error> {
26 self.buffer.flush().await
27 }
28}
29
30#[derive(Debug)]
31pub struct RandFileWriterMmap {
32 mmap: MmapMut,
33 downloaded: usize,
34 buffer_size: usize,
35}
36impl RandFileWriterMmap {
37 pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, io::Error> {
38 file.set_len(size).await?;
39 Ok(Self {
40 mmap: unsafe { MmapMut::map_mut(&file) }?,
41 downloaded: 0,
42 buffer_size,
43 })
44 }
45}
46impl RandWriter for RandFileWriterMmap {
47 type Error = io::Error;
48 async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
49 self.mmap[range.start as usize..range.end as usize].copy_from_slice(&bytes);
50 self.downloaded += bytes.len();
51 if self.downloaded >= self.buffer_size {
52 self.mmap.flush()?;
53 self.downloaded = 0;
54 }
55 Ok(())
56 }
57 async fn flush(&mut self) -> Result<(), Self::Error> {
58 self.mmap.flush_async()?;
59 Ok(())
60 }
61}
62
63#[derive(Debug)]
64pub struct RandFileWriterStd {
65 buffer: BufWriter<File>,
66 cache: Vec<(u64, Bytes)>,
67 p: u64,
68 cache_size: usize,
69 buffer_size: usize,
70}
71impl RandFileWriterStd {
72 pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, io::Error> {
73 file.set_len(size).await?;
74 Ok(Self {
75 buffer: BufWriter::with_capacity(buffer_size, file),
76 cache: Vec::new(),
77 p: 0,
78 cache_size: 0,
79 buffer_size,
80 })
81 }
82}
83impl RandWriter for RandFileWriterStd {
84 type Error = io::Error;
85 async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
86 let pos = self.cache.partition_point(|(i, _)| i < &range.start);
87 self.cache_size += bytes.len();
88 self.cache.insert(pos, (range.start, bytes));
89 if self.cache_size >= self.buffer_size {
90 self.flush().await?;
91 }
92 Ok(())
93 }
94 async fn flush(&mut self) -> Result<(), Self::Error> {
95 for (start, bytes) in self.cache.drain(..) {
96 let len = bytes.len();
97 self.cache_size -= len;
98 if start != self.p {
99 self.buffer.seek(io::SeekFrom::Start(start)).await?;
100 self.p = start;
101 }
102 self.buffer.write_all(&bytes).await?;
103 self.p += len as u64;
104 }
105 self.buffer.flush().await?;
106 Ok(())
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Reads the whole file at `path` through a fresh handle.
    async fn read_file(path: &std::path::Path) -> Result<Vec<u8>, io::Error> {
        let mut content = Vec::new();
        File::open(path).await?.read_to_end(&mut content).await?;
        Ok(content)
    }

    #[tokio::test]
    async fn test_seq_file_writer() -> Result<(), io::Error> {
        let temp_file = NamedTempFile::new()?;
        let mut writer = SeqFileWriter::new(temp_file.reopen()?.into(), 1024);

        writer.write(Bytes::from("Hello, ")).await?;
        writer.write(Bytes::from("world!")).await?;
        writer.flush().await?;

        assert_eq!(read_file(temp_file.path()).await?, b"Hello, world!");
        Ok(())
    }

    #[tokio::test]
    async fn test_rand_file_writer() -> Result<(), io::Error> {
        let temp_file = NamedTempFile::new()?;
        let mut writer =
            RandFileWriterMmap::new(temp_file.reopen()?.into(), 10, 8 * 1024 * 1024).await?;

        writer.write(2..5, Bytes::from("234")).await?;
        writer.flush().await?;

        let mut expected = vec![0u8; 10];
        expected[2..5].copy_from_slice(b"234");
        assert_eq!(read_file(temp_file.path()).await?, expected);
        Ok(())
    }

    #[tokio::test]
    async fn test_rand_file_writer_std() -> Result<(), io::Error> {
        let temp_file = NamedTempFile::new()?;
        // A 4-byte threshold makes the second write trigger an automatic flush,
        // and the third chunk then forces a backwards seek on the final flush.
        let mut writer = RandFileWriterStd::new(temp_file.reopen()?.into(), 10, 4).await?;

        writer.write(6..9, Bytes::from("678")).await?;
        writer.write(0..2, Bytes::from("01")).await?;
        writer.write(2..4, Bytes::from("23")).await?;
        writer.flush().await?;

        let mut expected = vec![0u8; 10];
        expected[0..4].copy_from_slice(b"0123");
        expected[6..9].copy_from_slice(b"678");
        assert_eq!(read_file(temp_file.path()).await?, expected);
        Ok(())
    }
}