fast_pull/file/
writer.rs

1extern crate std;
2use crate::{ProgressEntry, RandWriter, SeqWriter};
3use bytes::Bytes;
4use mmap_io::{MemoryMappedFile, MmapIoError, MmapMode};
5use std::{path::Path, vec::Vec};
6use thiserror::Error;
7use tokio::{
8    fs::File,
9    io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter},
10};
11
12#[derive(Error, Debug)]
13pub enum FileWriterError {
14    #[error(transparent)]
15    MmapIo(#[from] MmapIoError),
16    #[error(transparent)]
17    TokioIo(#[from] io::Error),
18}
19
20#[derive(Debug)]
21pub struct SeqFileWriter {
22    buffer: BufWriter<File>,
23}
24impl SeqFileWriter {
25    pub fn new(file: File, buffer_size: usize) -> Self {
26        Self {
27            buffer: BufWriter::with_capacity(buffer_size, file),
28        }
29    }
30}
31impl SeqWriter for SeqFileWriter {
32    type Error = FileWriterError;
33    async fn write(&mut self, content: Bytes) -> Result<(), Self::Error> {
34        Ok(self.buffer.write_all(&content).await?)
35    }
36    async fn flush(&mut self) -> Result<(), Self::Error> {
37        Ok(self.buffer.flush().await?)
38    }
39}
40
41#[derive(Debug)]
42pub struct RandFileWriterMmap {
43    mmap: MemoryMappedFile,
44    downloaded: usize,
45    buffer_size: usize,
46}
47impl RandFileWriterMmap {
48    pub fn new(
49        path: impl AsRef<Path>,
50        size: u64,
51        buffer_size: usize,
52    ) -> Result<Self, FileWriterError> {
53        let mmap_builder = MemoryMappedFile::builder(&path)
54            .huge_pages(true)
55            .mode(MmapMode::ReadWrite)
56            .size(size);
57        Ok(Self {
58            mmap: if path.as_ref().try_exists()? {
59                mmap_builder.open()
60            } else {
61                mmap_builder.create()
62            }?,
63            downloaded: 0,
64            buffer_size,
65        })
66    }
67}
68impl RandWriter for RandFileWriterMmap {
69    type Error = FileWriterError;
70    async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
71        self.mmap
72            .as_slice_mut(range.start, bytes.len() as u64)?
73            .as_mut()
74            .copy_from_slice(&bytes);
75        self.downloaded += bytes.len();
76        if self.downloaded >= self.buffer_size {
77            self.mmap.flush_async().await?;
78        }
79        Ok(())
80    }
81    async fn flush(&mut self) -> Result<(), Self::Error> {
82        self.mmap.flush_async().await?;
83        Ok(())
84    }
85}
86
87#[derive(Debug)]
88pub struct RandFileWriterStd {
89    buffer: BufWriter<File>,
90    cache: Vec<(u64, Bytes)>,
91    p: u64,
92    cache_size: usize,
93    buffer_size: usize,
94}
95impl RandFileWriterStd {
96    pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, FileWriterError> {
97        file.set_len(size).await?;
98        Ok(Self {
99            buffer: BufWriter::with_capacity(buffer_size, file),
100            cache: Vec::new(),
101            p: 0,
102            cache_size: 0,
103            buffer_size,
104        })
105    }
106}
107impl RandWriter for RandFileWriterStd {
108    type Error = FileWriterError;
109    async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
110        let pos = self.cache.partition_point(|(i, _)| i < &range.start);
111        self.cache_size += bytes.len();
112        self.cache.insert(pos, (range.start, bytes));
113        if self.cache_size >= self.buffer_size {
114            self.flush().await?;
115        }
116        Ok(())
117    }
118    async fn flush(&mut self) -> Result<(), Self::Error> {
119        for (start, bytes) in self.cache.drain(..) {
120            let len = bytes.len();
121            self.cache_size -= len;
122            if start != self.p {
123                self.buffer.seek(io::SeekFrom::Start(start)).await?;
124                self.p = start;
125            }
126            self.buffer.write_all(&bytes).await?;
127            self.p += len as u64;
128        }
129        self.buffer.flush().await?;
130        Ok(())
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Reads the whole file at `path` into memory.
    async fn read_all(path: &Path) -> Vec<u8> {
        let mut file_content = Vec::new();
        File::open(path)
            .await
            .unwrap()
            .read_to_end(&mut file_content)
            .await
            .unwrap();
        file_content
    }

    #[tokio::test]
    async fn test_seq_file_writer() {
        // Create a temporary file for the test.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_path_buf();

        // Initialise the SeqFileWriter.
        let mut writer = SeqFileWriter::new(temp_file.reopen().unwrap().into(), 1024);

        // Write data.
        let data1 = Bytes::from("Hello, ");
        let data2 = Bytes::from("world!");
        writer.write(data1).await.unwrap();
        writer.write(data2).await.unwrap();
        writer.flush().await.unwrap();

        // Verify the file contents.
        assert_eq!(read_all(&file_path).await, b"Hello, world!");
    }

    #[tokio::test]
    async fn test_rand_file_writer() {
        // Create a temporary file for the test.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();

        // Initialise the mmap writer for a 10-byte file.
        let mut writer = RandFileWriterMmap::new(file_path, 10, 8 * 1024 * 1024).unwrap();

        // Write data.
        let data = Bytes::from("234");
        let range = 2..5;
        writer.write(range, data).await.unwrap();
        writer.flush().await.unwrap();

        // Verify the file contents.
        assert_eq!(read_all(file_path).await, b"\0\0234\0\0\0\0\0");
    }

    #[tokio::test]
    async fn test_rand_file_writer_mmap_small_buffer() {
        // A 1-byte threshold makes every write take the automatic flush path.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();
        let mut writer = RandFileWriterMmap::new(file_path, 10, 1).unwrap();

        writer.write(8..10, Bytes::from("yz")).await.unwrap();
        writer.write(0..2, Bytes::from("ab")).await.unwrap();
        writer.flush().await.unwrap();

        assert_eq!(read_all(file_path).await, b"ab\0\0\0\0\0\0yz");
    }

    #[tokio::test]
    async fn test_rand_file_writer_std() {
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();

        // A 4-byte threshold forces an automatic flush on the second write.
        let mut writer = RandFileWriterStd::new(temp_file.reopen().unwrap().into(), 10, 4)
            .await
            .unwrap();

        // Chunks arrive out of order and must still land at their offsets.
        writer.write(5..8, Bytes::from("567")).await.unwrap();
        writer.write(2..5, Bytes::from("234")).await.unwrap();
        writer.flush().await.unwrap();

        assert_eq!(read_all(file_path).await, b"\0\0234567\0\0");
    }
}
193}