fast_down/writer/file.rs

1use super::{RandWriter, SeqWriter};
2use crate::ProgressEntry;
3use bytes::Bytes;
4use tokio::{
5    fs::File,
6    io::{AsyncWriteExt, BufWriter},
7};
8
9#[derive(Debug)]
10pub struct SeqFileWriter {
11    buffer: BufWriter<File>,
12}
13
14impl SeqFileWriter {
15    pub fn new(file: File, write_buffer_size: usize) -> Self {
16        Self {
17            buffer: BufWriter::with_capacity(write_buffer_size, file),
18        }
19    }
20}
21
22impl SeqWriter for SeqFileWriter {
23    async fn write_sequentially(&mut self, bytes: &Bytes) -> Result<(), std::io::Error> {
24        self.buffer.write_all(bytes).await?;
25        Ok(())
26    }
27
28    async fn flush(&mut self) -> Result<(), std::io::Error> {
29        self.buffer.flush().await?;
30        Ok(())
31    }
32}
33
34pub mod rand_file_writer_mmap {
35    use super::*;
36    use memmap2::MmapMut;
37
38    #[derive(Debug)]
39    pub struct RandFileWriter {
40        mmap: MmapMut,
41        downloaded: usize,
42        write_buffer_size: usize,
43    }
44
45    impl RandFileWriter {
46        pub async fn new(
47            file: File,
48            size: u64,
49            write_buffer_size: usize,
50        ) -> Result<Self, std::io::Error> {
51            file.set_len(size).await?;
52            Ok(Self {
53                mmap: unsafe { MmapMut::map_mut(&file) }?,
54                downloaded: 0,
55                write_buffer_size,
56            })
57        }
58    }
59
60    impl RandWriter for RandFileWriter {
61        async fn write_randomly(
62            &mut self,
63            range: ProgressEntry,
64            bytes: &Bytes,
65        ) -> Result<(), std::io::Error> {
66            self.mmap[range.start as usize..range.end as usize].copy_from_slice(bytes);
67            self.downloaded += bytes.len();
68            if self.downloaded >= self.write_buffer_size {
69                self.mmap.flush()?;
70                self.downloaded = 0;
71            }
72            Ok(())
73        }
74
75        async fn flush(&mut self) -> Result<(), std::io::Error> {
76            self.mmap.flush_async()?;
77            Ok(())
78        }
79    }
80}
81
82pub mod rand_file_writer_std {
83    use super::*;
84    use tokio::io::AsyncSeekExt;
85
86    #[derive(Debug)]
87    pub struct RandFileWriter {
88        buffer: BufWriter<File>,
89        cache: Vec<(u64, Bytes)>,
90        p: u64,
91        cache_size: usize,
92        write_buffer_size: usize,
93    }
94
95    impl RandFileWriter {
96        pub async fn new(
97            file: File,
98            size: u64,
99            write_buffer_size: usize,
100        ) -> Result<Self, std::io::Error> {
101            file.set_len(size).await?;
102            Ok(Self {
103                buffer: BufWriter::with_capacity(write_buffer_size, file),
104                cache: Vec::new(),
105                p: 0,
106                cache_size: 0,
107                write_buffer_size,
108            })
109        }
110    }
111
112    impl RandWriter for RandFileWriter {
113        async fn write_randomly(
114            &mut self,
115            range: ProgressEntry,
116            bytes: &Bytes,
117        ) -> Result<(), std::io::Error> {
118            let pos = self.cache.partition_point(|(i, _)| i < &range.start);
119            self.cache_size += bytes.len();
120            self.cache.insert(pos, (range.start, bytes.clone()));
121            if self.cache_size >= self.write_buffer_size {
122                self.flush().await?;
123            }
124            Ok(())
125        }
126
127        async fn flush(&mut self) -> Result<(), std::io::Error> {
128            for (start, bytes) in self.cache.drain(..) {
129                let len = bytes.len();
130                self.cache_size -= len;
131                if start != self.p {
132                    self.buffer.seek(std::io::SeekFrom::Start(start)).await?;
133                    self.p = start;
134                }
135                self.buffer.write_all(&bytes).await?;
136                self.p += len as u64;
137            }
138            self.buffer.flush().await?;
139            Ok(())
140        }
141    }
142}
143
#[cfg(test)]
#[cfg(feature = "file")]
mod tests {
    use super::*;
    use crate::{RandWriter, SeqWriter};
    use bytes::Bytes;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Reads the complete contents of the file at `path`.
    async fn read_back(path: &std::path::Path) -> Result<Vec<u8>, std::io::Error> {
        let mut content = Vec::new();
        File::open(path).await?.read_to_end(&mut content).await?;
        Ok(content)
    }

    #[tokio::test]
    async fn test_seq_file_writer() -> Result<(), std::io::Error> {
        // Temporary file used as the download target.
        let temp_file = NamedTempFile::new()?;
        let file_path = temp_file.path().to_path_buf();

        let mut writer = SeqFileWriter::new(temp_file.reopen()?.into(), 1024);

        // Two sequential chunks must end up concatenated.
        let data1 = Bytes::from("Hello, ");
        let data2 = Bytes::from("world!");
        writer.write_sequentially(&data1).await?;
        writer.write_sequentially(&data2).await?;
        writer.flush().await?;

        assert_eq!(read_back(&file_path).await?, b"Hello, world!");
        Ok(())
    }

    #[tokio::test]
    async fn test_rand_file_writer() -> Result<(), std::io::Error> {
        let temp_file = NamedTempFile::new()?;
        let file_path = temp_file.path().to_path_buf();

        // Memory-mapped writer over a 10-byte file.
        let mut writer = rand_file_writer_mmap::RandFileWriter::new(
            temp_file.reopen()?.into(),
            10,
            8 * 1024 * 1024,
        )
        .await?;

        let data = Bytes::from("234");
        let range = 2..5;
        writer.write_randomly(range, &data).await?;
        writer.flush().await?;

        // Untouched bytes stay zero-filled from `set_len`.
        assert_eq!(read_back(&file_path).await?, b"\0\0234\0\0\0\0\0");
        Ok(())
    }

    #[tokio::test]
    async fn test_rand_file_writer_std() -> Result<(), std::io::Error> {
        let temp_file = NamedTempFile::new()?;
        let file_path = temp_file.path().to_path_buf();

        // Threshold of 4 bytes: the second write triggers an automatic flush, the third
        // is only written by the explicit flush below.
        let mut writer =
            rand_file_writer_std::RandFileWriter::new(temp_file.reopen()?.into(), 10, 4).await?;

        // Chunks arrive out of order and with gaps between them.
        writer.write_randomly(6..9, &Bytes::from("678")).await?;
        writer.write_randomly(0..2, &Bytes::from("01")).await?;
        writer.write_randomly(3..5, &Bytes::from("34")).await?;
        writer.flush().await?;

        assert_eq!(read_back(&file_path).await?, b"01\x0034\x00678\x00");
        Ok(())
    }
}