fast_pull/file/
std.rs

1extern crate std;
2use crate::{ProgressEntry, RandPusher, SeqPusher, file::FilePusherError};
3use std::{boxed::Box, collections::VecDeque};
4use tokio::{
5    fs::File,
6    io::{AsyncSeekExt, AsyncWriteExt, BufWriter, SeekFrom},
7};
8
9#[derive(Debug)]
10pub struct FilePusher {
11    buffer: BufWriter<File>,
12    cache: VecDeque<(u64, Box<[u8]>)>,
13    p: u64,
14    cache_size: usize,
15    buffer_size: usize,
16}
17impl FilePusher {
18    pub async fn new(file: File, size: u64, buffer_size: usize) -> Result<Self, FilePusherError> {
19        file.set_len(size).await?;
20        Ok(Self {
21            buffer: BufWriter::with_capacity(buffer_size, file),
22            cache: VecDeque::new(),
23            p: 0,
24            cache_size: 0,
25            buffer_size,
26        })
27    }
28}
29impl SeqPusher for FilePusher {
30    type Error = FilePusherError;
31    async fn push(&mut self, content: &[u8]) -> Result<(), Self::Error> {
32        Ok(self.buffer.write_all(content).await?)
33    }
34    async fn flush(&mut self) -> Result<(), Self::Error> {
35        Ok(self.buffer.flush().await?)
36    }
37}
38impl RandPusher for FilePusher {
39    type Error = FilePusherError;
40    async fn push(&mut self, range: ProgressEntry, bytes: &[u8]) -> Result<(), Self::Error> {
41        let pos = self.cache.partition_point(|(i, _)| i < &range.start);
42        self.cache_size += bytes.len();
43        self.cache.insert(pos, (range.start, bytes.into()));
44        if self.cache_size >= self.buffer_size {
45            RandPusher::flush(self).await?;
46        }
47        Ok(())
48    }
49    async fn flush(&mut self) -> Result<(), Self::Error> {
50        while let Some((start, bytes)) = self.cache.front() {
51            let len = bytes.len();
52            if *start != self.p {
53                self.buffer.seek(SeekFrom::Start(*start)).await?;
54                self.p = *start;
55            }
56            self.buffer.write_all(bytes).await?;
57            self.cache.pop_front();
58            self.cache_size -= len;
59            self.p += len as u64;
60        }
61        self.buffer.flush().await?;
62        Ok(())
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use std::vec::Vec;
    use tempfile::NamedTempFile;
    use tokio::io::AsyncReadExt;

    /// Sequential pushes followed by a flush end up concatenated in the file.
    #[tokio::test]
    async fn test_seq_file_pusher() {
        // Create a temporary file to write into.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_path_buf();

        // Set up the pusher on a second handle to the same file.
        let handle = temp_file.reopen().unwrap();
        let mut pusher = FilePusher::new(handle.into(), 0, 1024).await.unwrap();

        // Write the data in two pieces.
        let greeting: &[u8] = b"Hello, ";
        let subject: &[u8] = b"world!";
        SeqPusher::push(&mut pusher, greeting).await.unwrap();
        SeqPusher::push(&mut pusher, subject).await.unwrap();
        SeqPusher::flush(&mut pusher).await.unwrap();

        // Read the file back and verify its content.
        let mut reader = File::open(&file_path).await.unwrap();
        let mut file_content = Vec::new();
        reader.read_to_end(&mut file_content).await.unwrap();
        assert_eq!(file_content, b"Hello, world!");
    }

    /// A random-access push lands at its range start inside the pre-sized file.
    #[tokio::test]
    async fn test_rand_file_pusher() {
        // Create a temporary file to write into.
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path();

        // Set up the pusher, with the file sized to 10 bytes.
        let handle = temp_file.reopen().unwrap();
        let mut pusher = FilePusher::new(handle.into(), 10, 8 * 1024 * 1024)
            .await
            .unwrap();

        // Write three bytes at offsets 2..5.
        let payload: &[u8] = b"234";
        RandPusher::push(&mut pusher, 2..5, payload).await.unwrap();
        RandPusher::flush(&mut pusher).await.unwrap();

        // Read the file back and verify its content.
        let mut reader = File::open(&file_path).await.unwrap();
        let mut file_content = Vec::new();
        reader.read_to_end(&mut file_content).await.unwrap();
        assert_eq!(file_content, b"\0\x00234\0\0\0\0\0");
    }
}