// fast_pull/file/writer.rs

1extern crate alloc;
2use crate::{ProgressEntry, RandWriter, SeqWriter};
3use alloc::vec::Vec;
4use bytes::Bytes;
5use memmap2::MmapMut;
6use tokio::{
7    fs::File,
8    io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter},
9};
10
11#[derive(Debug)]
12pub struct SeqFileWriter {
13    buffer: BufWriter<File>,
14}
15impl SeqFileWriter {
16    pub fn new(file: File, buffer_size: usize) -> Self {
17        Self {
18            buffer: BufWriter::with_capacity(buffer_size, file),
19        }
20    }
21}
22impl SeqWriter for SeqFileWriter {
23    type Error = io::Error;
24    async fn write(&mut self, content: Bytes) -> Result<(), Self::Error> {
25        self.buffer.write_all(&content).await
26    }
27    async fn flush(&mut self) -> Result<(), Self::Error> {
28        self.buffer.flush().await
29    }
30}
31
32#[derive(Debug)]
33pub struct RandFileWriterMmap {
34    mmap: MmapMut,
35    downloaded: usize,
36    buffer_size: usize,
37}
38impl RandFileWriterMmap {
39    pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, io::Error> {
40        file.set_len(size).await?;
41        Ok(Self {
42            mmap: unsafe { MmapMut::map_mut(&file) }?,
43            downloaded: 0,
44            buffer_size,
45        })
46    }
47}
48impl RandWriter for RandFileWriterMmap {
49    type Error = io::Error;
50    async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
51        self.mmap[range.start as usize..range.end as usize].copy_from_slice(&bytes);
52        self.downloaded += bytes.len();
53        if self.downloaded >= self.buffer_size {
54            self.mmap.flush()?;
55            self.downloaded = 0;
56        }
57        Ok(())
58    }
59    async fn flush(&mut self) -> Result<(), Self::Error> {
60        self.mmap.flush_async()?;
61        Ok(())
62    }
63}
64
65#[derive(Debug)]
66pub struct RandFileWriterStd {
67    buffer: BufWriter<File>,
68    cache: Vec<(u64, Bytes)>,
69    p: u64,
70    cache_size: usize,
71    buffer_size: usize,
72}
73impl RandFileWriterStd {
74    pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, io::Error> {
75        file.set_len(size).await?;
76        Ok(Self {
77            buffer: BufWriter::with_capacity(buffer_size, file),
78            cache: Vec::new(),
79            p: 0,
80            cache_size: 0,
81            buffer_size,
82        })
83    }
84}
85impl RandWriter for RandFileWriterStd {
86    type Error = io::Error;
87    async fn write(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
88        let pos = self.cache.partition_point(|(i, _)| i < &range.start);
89        self.cache_size += bytes.len();
90        self.cache.insert(pos, (range.start, bytes));
91        if self.cache_size >= self.buffer_size {
92            self.flush().await?;
93        }
94        Ok(())
95    }
96    async fn flush(&mut self) -> Result<(), Self::Error> {
97        for (start, bytes) in self.cache.drain(..) {
98            let len = bytes.len();
99            self.cache_size -= len;
100            if start != self.p {
101                self.buffer.seek(io::SeekFrom::Start(start)).await?;
102                self.p = start;
103            }
104            self.buffer.write_all(&bytes).await?;
105            self.p += len as u64;
106        }
107        self.buffer.flush().await?;
108        Ok(())
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Reads the whole file at `path` back into memory.
    async fn read_file(path: &std::path::Path) -> Result<Vec<u8>, io::Error> {
        let mut content = Vec::new();
        File::open(path).await?.read_to_end(&mut content).await?;
        Ok(content)
    }

    #[tokio::test]
    async fn test_seq_file_writer() -> Result<(), io::Error> {
        // Create a temporary file to write into.
        let temp_file = NamedTempFile::new()?;
        let file_path = temp_file.path().to_path_buf();

        // Initialise a SeqFileWriter with a 1 KiB buffer.
        let mut writer = SeqFileWriter::new(temp_file.reopen()?.into(), 1024);

        // Write two chunks back to back.
        writer.write(Bytes::from("Hello, ")).await?;
        writer.write(Bytes::from("world!")).await?;
        writer.flush().await?;

        // The chunks must appear in write order.
        assert_eq!(read_file(&file_path).await?, b"Hello, world!");
        Ok(())
    }

    #[tokio::test]
    async fn test_rand_file_writer() -> Result<(), io::Error> {
        // Create a temporary file to write into.
        let temp_file = NamedTempFile::new()?;
        let file_path = temp_file.path().to_path_buf();

        // Initialise the mmap-backed writer for a 10-byte file.
        let mut writer =
            RandFileWriterMmap::new(temp_file.reopen()?.into(), 10, 8 * 1024 * 1024).await?;

        // Write a chunk in the middle of the file.
        writer.write(2..5, Bytes::from("234")).await?;
        writer.flush().await?;

        // Untouched bytes stay zero-filled.
        assert_eq!(read_file(&file_path).await?, b"\0\0234\0\0\0\0\0");
        Ok(())
    }

    #[tokio::test]
    async fn test_rand_file_writer_std_out_of_order() -> Result<(), io::Error> {
        let temp_file = NamedTempFile::new()?;
        let file_path = temp_file.path().to_path_buf();

        // A large threshold keeps both chunks cached until the explicit flush,
        // which must write them back at their own offsets.
        let mut writer =
            RandFileWriterStd::new(temp_file.reopen()?.into(), 10, 8 * 1024 * 1024).await?;
        writer.write(5..8, Bytes::from("567")).await?;
        writer.write(2..5, Bytes::from("234")).await?;
        writer.flush().await?;

        assert_eq!(read_file(&file_path).await?, b"\0\0234567\0\0");
        Ok(())
    }

    #[tokio::test]
    async fn test_rand_file_writer_std_flushes_at_threshold() -> Result<(), io::Error> {
        let temp_file = NamedTempFile::new()?;
        let file_path = temp_file.path().to_path_buf();

        // A 1-byte threshold makes every `write` flush immediately. The second
        // chunk is contiguous with the first (no seek needed); the third is not.
        let mut writer = RandFileWriterStd::new(temp_file.reopen()?.into(), 10, 1).await?;
        writer.write(5..8, Bytes::from("567")).await?;
        writer.write(8..10, Bytes::from("89")).await?;
        writer.write(2..5, Bytes::from("234")).await?;

        // Everything is already on disk without an explicit `flush`.
        assert_eq!(read_file(&file_path).await?, b"\0\023456789");
        Ok(())
    }
}
172}