fast_pull/file/pusher.rs

1extern crate std;
2use crate::{ProgressEntry, RandPusher, SeqPusher, Total};
3use mmap_io::{MemoryMappedFile, MmapIoError, MmapMode, flush::FlushPolicy};
4use std::{boxed::Box, collections::VecDeque, path::Path};
5use tokio::{
6    fs::{File, OpenOptions},
7    io::{self, AsyncSeekExt, AsyncWriteExt, BufWriter, SeekFrom},
8};
9
10#[derive(thiserror::Error, Debug)]
11pub enum FilePusherError {
12    #[error(transparent)]
13    MmapIo(#[from] MmapIoError),
14    #[error(transparent)]
15    TokioIo(#[from] io::Error),
16}
17
18#[derive(Debug)]
19pub struct SeqFilePusher {
20    buffer: BufWriter<File>,
21}
22impl SeqFilePusher {
23    pub fn new(file: File, buffer_size: usize) -> Self {
24        Self {
25            buffer: BufWriter::with_capacity(buffer_size, file),
26        }
27    }
28}
29impl SeqPusher for SeqFilePusher {
30    type Error = FilePusherError;
31    async fn push(&mut self, content: &[u8]) -> Result<(), Self::Error> {
32        Ok(self.buffer.write_all(content).await?)
33    }
34    async fn flush(&mut self) -> Result<(), Self::Error> {
35        Ok(self.buffer.flush().await?)
36    }
37}
38
39#[derive(Debug)]
40pub struct RandFilePusherMmap {
41    mmap: MemoryMappedFile,
42    downloaded: usize,
43    buffer_size: usize,
44}
45impl RandFilePusherMmap {
46    pub async fn new(
47        path: impl AsRef<Path>,
48        size: u64,
49        buffer_size: usize,
50    ) -> Result<Self, FilePusherError> {
51        let mmap_builder = MemoryMappedFile::builder(&path)
52            .mode(MmapMode::ReadWrite)
53            .huge_pages(true)
54            .flush_policy(FlushPolicy::Manual);
55        Ok(Self {
56            mmap: if path.as_ref().try_exists()? {
57                OpenOptions::new()
58                    .write(true)
59                    .open(path)
60                    .await?
61                    .set_len(size)
62                    .await?;
63                mmap_builder.open()
64            } else {
65                mmap_builder.size(size).create()
66            }?,
67            downloaded: 0,
68            buffer_size,
69        })
70    }
71}
72impl RandPusher for RandFilePusherMmap {
73    type Error = FilePusherError;
74    async fn push(&mut self, range: ProgressEntry, bytes: &[u8]) -> Result<(), Self::Error> {
75        self.mmap
76            .as_slice_mut(range.start, range.total())?
77            .as_mut()
78            .copy_from_slice(bytes);
79        self.downloaded += bytes.len();
80        if self.downloaded >= self.buffer_size {
81            self.mmap.flush_async().await?;
82            self.downloaded = 0;
83        }
84        Ok(())
85    }
86    async fn flush(&mut self) -> Result<(), Self::Error> {
87        self.mmap.flush_async().await?;
88        Ok(())
89    }
90}
91
92#[derive(Debug)]
93pub struct RandFilePusherStd {
94    buffer: BufWriter<File>,
95    cache: VecDeque<(u64, Box<[u8]>)>,
96    p: u64,
97    cache_size: usize,
98    buffer_size: usize,
99}
100impl RandFilePusherStd {
101    pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, FilePusherError> {
102        file.set_len(size).await?;
103        Ok(Self {
104            buffer: BufWriter::with_capacity(buffer_size, file),
105            cache: VecDeque::new(),
106            p: 0,
107            cache_size: 0,
108            buffer_size,
109        })
110    }
111}
112impl RandPusher for RandFilePusherStd {
113    type Error = FilePusherError;
114    async fn push(&mut self, range: ProgressEntry, bytes: &[u8]) -> Result<(), Self::Error> {
115        let pos = self.cache.partition_point(|(i, _)| i < &range.start);
116        self.cache_size += bytes.len();
117        self.cache.insert(pos, (range.start, bytes.into()));
118        if self.cache_size >= self.buffer_size {
119            self.flush().await?;
120        }
121        Ok(())
122    }
123    async fn flush(&mut self) -> Result<(), Self::Error> {
124        while let Some((start, bytes)) = self.cache.front() {
125            let len = bytes.len();
126            self.cache_size -= len;
127            if *start != self.p {
128                self.buffer.seek(SeekFrom::Start(*start)).await?;
129                self.p = *start;
130            }
131            self.buffer.write_all(bytes).await?;
132            self.p += len as u64;
133            self.cache.pop_front();
134        }
135        self.buffer.flush().await?;
136        Ok(())
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use std::vec::Vec;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Reads the entire file at `path` into memory through a fresh handle.
    async fn read_file(path: &Path) -> Vec<u8> {
        let mut content = Vec::new();
        File::open(path)
            .await
            .unwrap()
            .read_to_end(&mut content)
            .await
            .unwrap();
        content
    }

    #[tokio::test]
    async fn test_seq_file_pusher() {
        // Create a temporary file to write into.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_path_buf();

        // Initialise the SeqFilePusher with a 1 KiB buffer.
        let mut pusher = SeqFilePusher::new(temp_file.reopen().unwrap().into(), 1024);

        // Write two chunks back to back.
        let data1 = Bytes::from("Hello, ");
        let data2 = Bytes::from("world!");
        pusher.push(&data1).await.unwrap();
        pusher.push(&data2).await.unwrap();
        pusher.flush().await.unwrap();

        // The file must contain the chunks in push order.
        assert_eq!(read_file(&file_path).await, b"Hello, world!");
    }

    #[tokio::test]
    async fn test_rand_file_pusher() {
        // Create a temporary file to map.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();

        // Initialise the mmap-backed pusher for a 10-byte file.
        let mut pusher = RandFilePusherMmap::new(file_path, 10, 8 * 1024 * 1024)
            .await
            .unwrap();

        // Write "234" at offsets 2..5.
        let data = Bytes::from("234");
        let range = 2..5;
        pusher.push(range, &data).await.unwrap();
        pusher.flush().await.unwrap();

        // Everything outside the written range stays zero.
        assert_eq!(read_file(file_path).await, b"\0\x00234\0\0\0\0\0");
    }

    #[tokio::test]
    async fn test_rand_file_pusher_std() {
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();

        // Use a 4-byte threshold. The first chunk stays staged; the second triggers a flush.
        let mut pusher = RandFilePusherStd::new(temp_file.reopen().unwrap().into(), 10, 4)
            .await
            .unwrap();

        // Push out of order. The pusher must write by ascending offset, seeking between chunks.
        pusher.push(7..9, b"78").await.unwrap();
        pusher.push(2..5, b"234").await.unwrap();

        let mut expected = [0u8; 10];
        expected[2..5].copy_from_slice(b"234");
        expected[7..9].copy_from_slice(b"78");
        assert_eq!(read_file(file_path).await, expected);

        // A final explicit flush with nothing staged must leave the file unchanged.
        pusher.flush().await.unwrap();
        assert_eq!(read_file(file_path).await, expected);
    }
}
202}