1extern crate alloc;
2use crate::{ProgressEntry, RandWriter, SeqWriter};
3use alloc::vec::Vec;
4use bytes::Bytes;
5use memmap2::MmapMut;
6use tokio::{
7 fs::File,
8 io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter},
9};
10
11#[derive(Debug)]
12pub struct SeqFileWriter {
13 buffer: BufWriter<File>,
14}
15impl SeqFileWriter {
16 pub fn new(file: File, buffer_size: usize) -> Self {
17 Self {
18 buffer: BufWriter::with_capacity(buffer_size, file),
19 }
20 }
21}
22impl SeqWriter for SeqFileWriter {
23 type Error = io::Error;
24 async fn write(&mut self, content: Bytes) -> Result<(), Self::Error> {
25 self.buffer.write_all(&content).await
26 }
27 async fn flush(&mut self) -> Result<(), Self::Error> {
28 self.buffer.flush().await
29 }
30}
31
32#[derive(Debug)]
33pub struct RandFileWriterMmap {
34 mmap: MmapMut,
35 downloaded: usize,
36 buffer_size: usize,
37}
38impl RandFileWriterMmap {
39 pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, io::Error> {
40 file.set_len(size).await?;
41 Ok(Self {
42 mmap: unsafe { MmapMut::map_mut(&file) }?,
43 downloaded: 0,
44 buffer_size,
45 })
46 }
47}
48impl RandWriter for RandFileWriterMmap {
49 type Error = io::Error;
50 async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
51 self.mmap[range.start as usize..range.end as usize].copy_from_slice(&bytes);
52 self.downloaded += bytes.len();
53 if self.downloaded >= self.buffer_size {
54 self.mmap.flush()?;
55 self.downloaded = 0;
56 }
57 Ok(())
58 }
59 async fn flush(&mut self) -> Result<(), Self::Error> {
60 self.mmap.flush_async()?;
61 Ok(())
62 }
63}
64
65#[derive(Debug)]
66pub struct RandFileWriterStd {
67 buffer: BufWriter<File>,
68 cache: Vec<(u64, Bytes)>,
69 p: u64,
70 cache_size: usize,
71 buffer_size: usize,
72}
73impl RandFileWriterStd {
74 pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, io::Error> {
75 file.set_len(size).await?;
76 Ok(Self {
77 buffer: BufWriter::with_capacity(buffer_size, file),
78 cache: Vec::new(),
79 p: 0,
80 cache_size: 0,
81 buffer_size,
82 })
83 }
84}
85impl RandWriter for RandFileWriterStd {
86 type Error = io::Error;
87 async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
88 let pos = self.cache.partition_point(|(i, _)| i < &range.start);
89 self.cache_size += bytes.len();
90 self.cache.insert(pos, (range.start, bytes));
91 if self.cache_size >= self.buffer_size {
92 self.flush().await?;
93 }
94 Ok(())
95 }
96 async fn flush(&mut self) -> Result<(), Self::Error> {
97 for (start, bytes) in self.cache.drain(..) {
98 let len = bytes.len();
99 self.cache_size -= len;
100 if start != self.p {
101 self.buffer.seek(io::SeekFrom::Start(start)).await?;
102 self.p = start;
103 }
104 self.buffer.write_all(&bytes).await?;
105 self.p += len as u64;
106 }
107 self.buffer.flush().await?;
108 Ok(())
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Reads the entire file at `path` back into memory.
    async fn read_file(path: &std::path::Path) -> Result<Vec<u8>, io::Error> {
        let mut content = Vec::new();
        File::open(path).await?.read_to_end(&mut content).await?;
        Ok(content)
    }

    #[tokio::test]
    async fn test_seq_file_writer() -> Result<(), io::Error> {
        let temp_file = NamedTempFile::new()?;
        let mut writer = SeqFileWriter::new(temp_file.reopen()?.into(), 1024);

        writer.write(Bytes::from("Hello, ")).await?;
        writer.write(Bytes::from("world!")).await?;
        writer.flush().await?;

        assert_eq!(read_file(temp_file.path()).await?, b"Hello, world!");
        Ok(())
    }

    #[tokio::test]
    async fn test_rand_file_writer() -> Result<(), io::Error> {
        let temp_file = NamedTempFile::new()?;
        let mut writer =
            RandFileWriterMmap::new(temp_file.reopen()?.into(), 10, 8 * 1024 * 1024).await?;

        writer.write(2..5, Bytes::from("234")).await?;
        writer.flush().await?;

        assert_eq!(read_file(temp_file.path()).await?, b"\0\x00234\0\0\0\0\0");
        Ok(())
    }

    #[tokio::test]
    async fn test_rand_file_writer_std() -> Result<(), io::Error> {
        let temp_file = NamedTempFile::new()?;
        // A 4-byte threshold makes the second `write` trigger an automatic flush.
        let mut writer = RandFileWriterStd::new(temp_file.reopen()?.into(), 10, 4).await?;

        // Chunks arrive out of order and must each land at their own offset.
        writer.write(6..8, Bytes::from("67")).await?;
        writer.write(2..5, Bytes::from("234")).await?;
        writer.write(9..10, Bytes::from("9")).await?;
        writer.flush().await?;

        let mut expected = vec![0u8; 10];
        expected[2..5].copy_from_slice(b"234");
        expected[6..8].copy_from_slice(b"67");
        expected[9..10].copy_from_slice(b"9");
        assert_eq!(read_file(temp_file.path()).await?, expected);
        Ok(())
    }
}