fast_pull/file/
pusher.rs

1extern crate std;
2use crate::{ProgressEntry, RandPusher, SeqPusher};
3use bytes::Bytes;
4use mmap_io::{MemoryMappedFile, MmapIoError, MmapMode, flush::FlushPolicy};
5use std::{path::Path, vec::Vec};
6use thiserror::Error;
7use tokio::{
8    fs::{File, OpenOptions},
9    io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter, SeekFrom},
10};
11
12#[derive(Error, Debug)]
13pub enum FilePusherError {
14    #[error(transparent)]
15    MmapIo(#[from] MmapIoError),
16    #[error(transparent)]
17    TokioIo(#[from] io::Error),
18}
19
20#[derive(Debug)]
21pub struct SeqFilePusher {
22    buffer: BufWriter<File>,
23}
24impl SeqFilePusher {
25    pub fn new(file: File, buffer_size: usize) -> Self {
26        Self {
27            buffer: BufWriter::with_capacity(buffer_size, file),
28        }
29    }
30}
31impl SeqPusher for SeqFilePusher {
32    type Error = FilePusherError;
33    async fn push(&mut self, content: Bytes) -> Result<(), Self::Error> {
34        Ok(self.buffer.write_all(&content).await?)
35    }
36    async fn flush(&mut self) -> Result<(), Self::Error> {
37        Ok(self.buffer.flush().await?)
38    }
39}
40
41#[derive(Debug)]
42pub struct RandFilePusherMmap {
43    mmap: MemoryMappedFile,
44    downloaded: usize,
45    buffer_size: usize,
46}
47impl RandFilePusherMmap {
48    pub async fn new(
49        path: impl AsRef<Path>,
50        size: u64,
51        buffer_size: usize,
52    ) -> Result<Self, FilePusherError> {
53        let mmap_builder = MemoryMappedFile::builder(&path)
54            .mode(MmapMode::ReadWrite)
55            .flush_policy(FlushPolicy::Manual);
56        Ok(Self {
57            mmap: if path.as_ref().try_exists()? {
58                OpenOptions::new()
59                    .write(true)
60                    .open(path)
61                    .await?
62                    .set_len(size)
63                    .await?;
64                mmap_builder.open()
65            } else {
66                mmap_builder.size(size).create()
67            }?,
68            downloaded: 0,
69            buffer_size,
70        })
71    }
72}
73impl RandPusher for RandFilePusherMmap {
74    type Error = FilePusherError;
75    async fn push(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
76        self.mmap
77            .as_slice_mut(range.start, bytes.len() as u64)?
78            .as_mut()
79            .copy_from_slice(&bytes);
80        self.downloaded += bytes.len();
81        if self.downloaded >= self.buffer_size {
82            self.mmap.flush_async().await?;
83            self.downloaded = 0;
84        }
85        Ok(())
86    }
87    async fn flush(&mut self) -> Result<(), Self::Error> {
88        self.mmap.flush_async().await?;
89        Ok(())
90    }
91}
92
93#[derive(Debug)]
94pub struct RandFilePusherStd {
95    buffer: BufWriter<File>,
96    cache: Vec<(u64, Bytes)>,
97    p: u64,
98    cache_size: usize,
99    buffer_size: usize,
100}
101impl RandFilePusherStd {
102    pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, FilePusherError> {
103        file.set_len(size).await?;
104        Ok(Self {
105            buffer: BufWriter::with_capacity(buffer_size, file),
106            cache: Vec::new(),
107            p: 0,
108            cache_size: 0,
109            buffer_size,
110        })
111    }
112}
113impl RandPusher for RandFilePusherStd {
114    type Error = FilePusherError;
115    async fn push(&mut self, range: ProgressEntry, bytes: Bytes) -> Result<(), Self::Error> {
116        let pos = self.cache.partition_point(|(i, _)| i < &range.start);
117        self.cache_size += bytes.len();
118        self.cache.insert(pos, (range.start, bytes));
119        if self.cache_size >= self.buffer_size {
120            self.flush().await?;
121        }
122        Ok(())
123    }
124    async fn flush(&mut self) -> Result<(), Self::Error> {
125        for (start, bytes) in self.cache.drain(..) {
126            let len = bytes.len();
127            self.cache_size -= len;
128            if start != self.p {
129                self.buffer.seek(SeekFrom::Start(start)).await?;
130                self.p = start;
131            }
132            self.buffer.write_all(&bytes).await?;
133            self.p += len as u64;
134        }
135        self.buffer.flush().await?;
136        Ok(())
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Reads the whole file at `path` back through a fresh handle.
    async fn read_back(path: &Path) -> Vec<u8> {
        let mut content = Vec::new();
        File::open(path)
            .await
            .unwrap()
            .read_to_end(&mut content)
            .await
            .unwrap();
        content
    }

    #[tokio::test]
    async fn test_seq_file_pusher() {
        // Create a temporary file to write into.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_path_buf();

        // Initialize the SeqFilePusher.
        let mut pusher = SeqFilePusher::new(temp_file.reopen().unwrap().into(), 1024);

        // Write the data.
        let data1 = Bytes::from("Hello, ");
        let data2 = Bytes::from("world!");
        pusher.push(data1).await.unwrap();
        pusher.push(data2).await.unwrap();
        pusher.flush().await.unwrap();

        // Verify the file content.
        assert_eq!(read_back(&file_path).await, b"Hello, world!");
    }

    #[tokio::test]
    async fn test_rand_file_pusher() {
        // Create a temporary file to map.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();

        // Initialize the RandFilePusherMmap with a file size of 10 bytes.
        let mut pusher = RandFilePusherMmap::new(file_path, 10, 8 * 1024 * 1024)
            .await
            .unwrap();

        // Write the data.
        let data = Bytes::from("234");
        let range = 2..5;
        pusher.push(range, data).await.unwrap();
        pusher.flush().await.unwrap();

        // Verify the file content.
        assert_eq!(read_back(file_path).await, b"\0\x00234\0\0\0\0\0");
    }

    #[tokio::test]
    async fn test_rand_file_pusher_std() {
        // Create a temporary file to write into.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_path_buf();

        // Initialize the RandFilePusherStd with a file size of 10 bytes.
        let mut pusher =
            RandFilePusherStd::new(temp_file.reopen().unwrap().into(), 10, 8 * 1024 * 1024)
                .await
                .unwrap();

        // Push the chunks out of order; they must still land at their offsets.
        pusher.push(5..8, Bytes::from("567")).await.unwrap();
        pusher.push(2..5, Bytes::from("234")).await.unwrap();
        pusher.flush().await.unwrap();

        // Verify the file content.
        assert_eq!(read_back(&file_path).await, b"\0\x00234567\0\0");
    }
}