fast_pull/file/writer.rs

1extern crate std;
2use crate::{ProgressEntry, RandWriter, SeqWriter};
3use bytes::Bytes;
4use mmap_io::{MemoryMappedFile, MmapIoError, MmapMode};
5use std::{path::Path, vec::Vec};
6use thiserror::Error;
7use tokio::{
8    fs::File,
9    io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter},
10};
11
12#[derive(Error, Debug)]
13pub enum FileWriterError {
14    #[error(transparent)]
15    MmapIo(#[from] MmapIoError),
16    #[error(transparent)]
17    TokioIo(#[from] io::Error),
18}
19
20#[derive(Debug)]
21pub struct SeqFileWriter {
22    buffer: BufWriter<File>,
23}
24impl SeqFileWriter {
25    pub fn new(file: File, buffer_size: usize) -> Self {
26        Self {
27            buffer: BufWriter::with_capacity(buffer_size, file),
28        }
29    }
30}
31impl SeqWriter for SeqFileWriter {
32    type Error = FileWriterError;
33    async fn write(&mut self, content: Bytes) -> Result<(), Self::Error> {
34        Ok(self.buffer.write_all(&content).await?)
35    }
36    async fn flush(&mut self) -> Result<(), Self::Error> {
37        Ok(self.buffer.flush().await?)
38    }
39}
40
41#[derive(Debug)]
42pub struct RandFileWriterMmap {
43    mmap: MemoryMappedFile,
44    downloaded: usize,
45    buffer_size: usize,
46}
47impl RandFileWriterMmap {
48    pub fn new(
49        path: impl AsRef<Path>,
50        size: u64,
51        buffer_size: usize,
52    ) -> Result<Self, FileWriterError> {
53        let mmap = MemoryMappedFile::builder(path)
54            .huge_pages(true)
55            .mode(MmapMode::ReadWrite)
56            .size(size)
57            .create()?;
58        Ok(Self {
59            mmap,
60            downloaded: 0,
61            buffer_size,
62        })
63    }
64}
65impl RandWriter for RandFileWriterMmap {
66    type Error = FileWriterError;
67    async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
68        self.mmap
69            .as_slice_mut(range.start, bytes.len() as u64)?
70            .as_mut()
71            .copy_from_slice(&bytes);
72        self.downloaded += bytes.len();
73        if self.downloaded >= self.buffer_size {
74            self.mmap.flush_async().await?;
75        }
76        Ok(())
77    }
78    async fn flush(&mut self) -> Result<(), Self::Error> {
79        self.mmap.flush_async().await?;
80        Ok(())
81    }
82}
83
84#[derive(Debug)]
85pub struct RandFileWriterStd {
86    buffer: BufWriter<File>,
87    cache: Vec<(u64, Bytes)>,
88    p: u64,
89    cache_size: usize,
90    buffer_size: usize,
91}
92impl RandFileWriterStd {
93    pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, io::Error> {
94        file.set_len(size).await?;
95        Ok(Self {
96            buffer: BufWriter::with_capacity(buffer_size, file),
97            cache: Vec::new(),
98            p: 0,
99            cache_size: 0,
100            buffer_size,
101        })
102    }
103}
104impl RandWriter for RandFileWriterStd {
105    type Error = io::Error;
106    async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
107        let pos = self.cache.partition_point(|(i, _)| i < &range.start);
108        self.cache_size += bytes.len();
109        self.cache.insert(pos, (range.start, bytes));
110        if self.cache_size >= self.buffer_size {
111            self.flush().await?;
112        }
113        Ok(())
114    }
115    async fn flush(&mut self) -> Result<(), Self::Error> {
116        for (start, bytes) in self.cache.drain(..) {
117            let len = bytes.len();
118            self.cache_size -= len;
119            if start != self.p {
120                self.buffer.seek(io::SeekFrom::Start(start)).await?;
121                self.p = start;
122            }
123            self.buffer.write_all(&bytes).await?;
124            self.p += len as u64;
125        }
126        self.buffer.flush().await?;
127        Ok(())
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Reads the whole file at `path` back for verification.
    async fn read_back(path: &Path) -> Vec<u8> {
        tokio::fs::read(path).await.unwrap()
    }

    #[tokio::test]
    async fn test_seq_file_writer() {
        // Temporary file used as the write target.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_path_buf();

        let mut writer = SeqFileWriter::new(temp_file.reopen().unwrap().into(), 1024);

        writer.write(Bytes::from("Hello, ")).await.unwrap();
        writer.write(Bytes::from("world!")).await.unwrap();
        writer.flush().await.unwrap();

        assert_eq!(read_back(&file_path).await, b"Hello, world!");
    }

    #[tokio::test]
    async fn test_rand_file_writer() {
        // Temporary file used as the write target.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();

        // Map a 10-byte file.
        let mut writer = RandFileWriterMmap::new(file_path, 10, 8 * 1024 * 1024).unwrap();

        writer.write(2..5, Bytes::from("234")).await.unwrap();
        writer.flush().await.unwrap();

        assert_eq!(
            read_back(file_path).await,
            b"\x00\x00234\x00\x00\x00\x00\x00"
        );
    }

    #[tokio::test]
    async fn test_rand_file_writer_std() {
        let temp_file = NamedTempFile::new().unwrap();

        // A 4-byte threshold makes the second write trigger an automatic flush,
        // so both the threshold path and the explicit flush path are exercised.
        let mut writer = RandFileWriterStd::new(temp_file.reopen().unwrap().into(), 10, 4)
            .await
            .unwrap();

        // Out-of-order writes must each land at their own offset.
        writer.write(5..8, Bytes::from("567")).await.unwrap();
        writer.write(2..5, Bytes::from("234")).await.unwrap();
        writer.write(0..1, Bytes::from("0")).await.unwrap();
        writer.flush().await.unwrap();

        assert_eq!(read_back(temp_file.path()).await, b"0\x00234567\x00\x00");
    }
}
190}