// fast_pull/file/writer.rs

use crate::{ProgressEntry, RandWriter, SeqWriter};
use bytes::Bytes;
use memmap2::MmapMut;
use tokio::{
    fs::File,
    io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter},
};

9#[derive(Debug)]
10pub struct SeqFileWriter {
11    buffer: BufWriter<File>,
12}
13impl SeqFileWriter {
14    pub fn new(file: File, buffer_size: usize) -> Self {
15        Self {
16            buffer: BufWriter::with_capacity(buffer_size, file),
17        }
18    }
19}
20impl SeqWriter for SeqFileWriter {
21    type Error = io::Error;
22    async fn write(&mut self, content: Bytes) -> Result<(), Self::Error> {
23        self.buffer.write_all(&content).await
24    }
25    async fn flush(&mut self) -> Result<(), Self::Error> {
26        self.buffer.flush().await
27    }
28}
29
30#[derive(Debug)]
31pub struct RandFileWriterMmap {
32    mmap: MmapMut,
33    downloaded: usize,
34    buffer_size: usize,
35}
36impl RandFileWriterMmap {
37    pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, io::Error> {
38        file.set_len(size).await?;
39        Ok(Self {
40            mmap: unsafe { MmapMut::map_mut(&file) }?,
41            downloaded: 0,
42            buffer_size,
43        })
44    }
45}
46impl RandWriter for RandFileWriterMmap {
47    type Error = io::Error;
48    async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
49        self.mmap[range.start as usize..range.end as usize].copy_from_slice(&bytes);
50        self.downloaded += bytes.len();
51        if self.downloaded >= self.buffer_size {
52            self.mmap.flush()?;
53            self.downloaded = 0;
54        }
55        Ok(())
56    }
57    async fn flush(&mut self) -> Result<(), Self::Error> {
58        self.mmap.flush_async()?;
59        Ok(())
60    }
61}
62
63#[derive(Debug)]
64pub struct RandFileWriterStd {
65    buffer: BufWriter<File>,
66    cache: Vec<(u64, Bytes)>,
67    p: u64,
68    cache_size: usize,
69    buffer_size: usize,
70}
71impl RandFileWriterStd {
72    pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, io::Error> {
73        file.set_len(size).await?;
74        Ok(Self {
75            buffer: BufWriter::with_capacity(buffer_size, file),
76            cache: Vec::new(),
77            p: 0,
78            cache_size: 0,
79            buffer_size,
80        })
81    }
82}
83impl RandWriter for RandFileWriterStd {
84    type Error = io::Error;
85    async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
86        let pos = self.cache.partition_point(|(i, _)| i < &range.start);
87        self.cache_size += bytes.len();
88        self.cache.insert(pos, (range.start, bytes));
89        if self.cache_size >= self.buffer_size {
90            self.flush().await?;
91        }
92        Ok(())
93    }
94    async fn flush(&mut self) -> Result<(), Self::Error> {
95        for (start, bytes) in self.cache.drain(..) {
96            let len = bytes.len();
97            self.cache_size -= len;
98            if start != self.p {
99                self.buffer.seek(io::SeekFrom::Start(start)).await?;
100                self.p = start;
101            }
102            self.buffer.write_all(&bytes).await?;
103            self.p += len as u64;
104        }
105        self.buffer.flush().await?;
106        Ok(())
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tempfile::NamedTempFile;

    #[tokio::test]
    async fn test_seq_file_writer() -> Result<(), io::Error> {
        // Create a temporary file to write into.
        let temp_file = NamedTempFile::new()?;
        let file_path = temp_file.path().to_path_buf();

        // Set up a SeqFileWriter with a 1 KiB buffer.
        let mut writer = SeqFileWriter::new(temp_file.reopen()?.into(), 1024);

        // Write two chunks back to back, then flush.
        writer.write(Bytes::from("Hello, ")).await?;
        writer.write(Bytes::from("world!")).await?;
        writer.flush().await?;

        // The chunks must appear concatenated, in order.
        assert_eq!(tokio::fs::read(&file_path).await?, b"Hello, world!");

        Ok(())
    }

    #[tokio::test]
    async fn test_rand_file_writer() -> Result<(), io::Error> {
        // Create a temporary file to write into.
        let temp_file = NamedTempFile::new()?;
        let file_path = temp_file.path().to_path_buf();

        // Map a 10-byte file. The large threshold means no periodic flush is triggered.
        let mut writer =
            RandFileWriterMmap::new(temp_file.reopen()?.into(), 10, 8 * 1024 * 1024).await?;

        // Write bytes 2..5 only.
        writer.write(2..5, Bytes::from("234")).await?;
        writer.flush().await?;

        // Everything outside 2..5 stays zero-filled from `set_len`.
        assert_eq!(
            tokio::fs::read(&file_path).await?,
            b"\x00\x00234\x00\x00\x00\x00\x00"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rand_file_writer_std() -> Result<(), io::Error> {
        // Create a temporary file to write into.
        let temp_file = NamedTempFile::new()?;
        let file_path = temp_file.path().to_path_buf();

        // 10-byte file. The large threshold keeps both chunks cached until the explicit flush.
        let mut writer =
            RandFileWriterStd::new(temp_file.reopen()?.into(), 10, 8 * 1024 * 1024).await?;

        // Submit chunks out of order: each must still land at its own offset.
        writer.write(5..8, Bytes::from("567")).await?;
        writer.write(1..3, Bytes::from("12")).await?;
        writer.flush().await?;

        assert_eq!(
            tokio::fs::read(&file_path).await?,
            b"\x0012\x00\x00567\x00\x00"
        );

        Ok(())
    }
}
170}