// fast_pull/file/std.rs

use crate::{ProgressEntry, Pusher};
use bytes::Bytes;
use std::{
    collections::BTreeMap,
    fs::File,
    io::{BufWriter, Seek, Write},
};
use tokio::io::SeekFrom;

10#[derive(Debug)]
11pub struct FilePusher {
12    buffer: BufWriter<File>,
13    cache: BTreeMap<u64, Bytes>,
14    p: u64,
15    cache_size: usize,
16    buffer_size: usize,
17}
18impl FilePusher {
19    /// # Errors
20    /// 1. 当 `fs::set_len` 失败时返回错误。
21    /// 2. 当 `BufWriter::with_capacity` 失败时返回错误。
22    pub async fn new(
23        file: tokio::fs::File,
24        size: u64,
25        buffer_size: usize,
26    ) -> tokio::io::Result<Self> {
27        file.set_len(size).await?;
28        Ok(Self {
29            buffer: BufWriter::with_capacity(buffer_size, file.into_std().await),
30            cache: BTreeMap::new(),
31            p: 0,
32            cache_size: 0,
33            buffer_size,
34        })
35    }
36}
37impl Pusher for FilePusher {
38    type Error = tokio::io::Error;
39    fn push(&mut self, range: &ProgressEntry, bytes: Bytes) -> Result<(), (Self::Error, Bytes)> {
40        if bytes.is_empty() {
41            return Ok(());
42        }
43        self.cache_size += bytes.len();
44        self.cache.insert(range.start, bytes);
45        if self.cache_size >= self.buffer_size {
46            self.flush().map_err(|e| {
47                let bytes = self.cache.remove(&range.start);
48                (e, bytes.unwrap_or_default())
49            })?;
50        }
51        Ok(())
52    }
53    fn flush(&mut self) -> Result<(), Self::Error> {
54        while let Some(entry) = self.cache.first_entry() {
55            let start = *entry.key();
56            let bytes = entry.get();
57            let len = bytes.len();
58            if start != self.p {
59                self.buffer.seek(SeekFrom::Start(start))?;
60                self.p = start;
61            }
62            self.buffer.write_all(bytes)?;
63            entry.remove_entry();
64            self.cache_size -= len;
65            self.p += len as u64;
66        }
67        self.buffer.flush()?;
68        Ok(())
69    }
70}
71
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use tempfile::NamedTempFile;

    /// Pushes one chunk into the middle of a 10-byte file and checks the
    /// resulting file contents byte for byte.
    #[tokio::test]
    async fn test_rand_file_pusher() {
        // Scratch file that backs the pusher for this test.
        let scratch = NamedTempFile::new().unwrap();
        let handle = scratch.reopen().unwrap();

        // File size is 10 bytes; the 8 MiB buffer is far larger than the
        // payload, so nothing reaches the file until the explicit flush.
        let mut pusher = FilePusher::new(handle.into(), 10, 8 * 1024 * 1024)
            .await
            .unwrap();

        // Write three bytes at offsets 2..5.
        let payload = b"234";
        let range = 2..5;
        pusher.push(&range, payload[..].into()).unwrap();
        pusher.flush().unwrap();

        // Read the file back through its path and compare the whole content.
        let written = std::fs::read(scratch.path()).unwrap();
        assert_eq!(written, b"\0\0234\0\0\0\0\0");
    }
}
104}