Skip to main content

fast_pull/file/
std.rs

1extern crate std;
2use crate::{ProgressEntry, Pusher};
3use bytes::Bytes;
4use std::{
5    collections::BTreeMap,
6    fs::File,
7    io::{BufWriter, Seek, Write},
8};
9use tokio::io::SeekFrom;
10
11#[derive(Debug)]
12pub struct FilePusher {
13    buffer: BufWriter<File>,
14    cache: BTreeMap<u64, Bytes>,
15    p: u64,
16    cache_size: usize,
17    buffer_size: usize,
18}
19impl FilePusher {
20    pub async fn new(
21        file: tokio::fs::File,
22        size: u64,
23        buffer_size: usize,
24    ) -> tokio::io::Result<Self> {
25        file.set_len(size).await?;
26        Ok(Self {
27            buffer: BufWriter::with_capacity(buffer_size, file.into_std().await),
28            cache: BTreeMap::new(),
29            p: 0,
30            cache_size: 0,
31            buffer_size,
32        })
33    }
34}
35impl Pusher for FilePusher {
36    type Error = tokio::io::Error;
37    fn push(&mut self, range: &ProgressEntry, bytes: Bytes) -> Result<(), (Self::Error, Bytes)> {
38        if bytes.is_empty() {
39            return Ok(());
40        }
41        self.cache_size += bytes.len();
42        self.cache.insert(range.start, bytes);
43        if self.cache_size >= self.buffer_size {
44            self.flush().map_err(|e| {
45                let bytes = self.cache.remove(&range.start);
46                (e, bytes.unwrap_or_default())
47            })?;
48        }
49        Ok(())
50    }
51    fn flush(&mut self) -> Result<(), Self::Error> {
52        while let Some(entry) = self.cache.first_entry() {
53            let start = *entry.key();
54            let bytes = entry.get();
55            let len = bytes.len();
56            if start != self.p {
57                self.buffer.seek(SeekFrom::Start(start))?;
58                self.p = start;
59            }
60            self.buffer.write_all(bytes)?;
61            entry.remove_entry();
62            self.cache_size -= len;
63            self.p += len as u64;
64        }
65        self.buffer.flush()?;
66        Ok(())
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use std::{io::Read, vec::Vec};
    use tempfile::NamedTempFile;

    /// Reads the whole file at `path` into memory.
    fn read_all(path: &std::path::Path) -> Vec<u8> {
        let mut content = Vec::new();
        File::open(path).unwrap().read_to_end(&mut content).unwrap();
        content
    }

    #[tokio::test]
    async fn test_rand_file_pusher() {
        // Temporary file backing the test.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();

        // FilePusher over a 10-byte file. The buffer is large enough that
        // push() never flushes on its own.
        let mut pusher = FilePusher::new(temp_file.reopen().unwrap().into(), 10, 8 * 1024 * 1024)
            .await
            .unwrap();

        // Write three bytes at offset 2.
        let data = b"234";
        let range = 2..5;
        pusher.push(&range, data[..].into()).unwrap();
        pusher.flush().unwrap();

        // Everything outside the written range stays zero-filled.
        let expected = [&[0u8; 2][..], &b"234"[..], &[0u8; 5][..]].concat();
        assert_eq!(read_all(file_path), expected);
    }

    #[tokio::test]
    async fn test_out_of_order_chunks_with_auto_flush() {
        let temp_file = NamedTempFile::new().unwrap();

        // A 4-byte threshold makes push() flush on its own mid-sequence, so both
        // the automatic flush and the seek between non-adjacent chunks are used.
        let mut pusher = FilePusher::new(temp_file.reopen().unwrap().into(), 8, 4)
            .await
            .unwrap();

        pusher.push(&(4..6), Bytes::from_static(b"45")).unwrap();
        pusher.push(&(0..2), Bytes::from_static(b"01")).unwrap();
        pusher.push(&(6..8), Bytes::from_static(b"67")).unwrap();
        pusher.push(&(2..4), Bytes::from_static(b"23")).unwrap();
        pusher.flush().unwrap();

        assert_eq!(read_all(temp_file.path()), b"01234567");
    }
}