1use crate::{ProgressEntry, Pusher};
2use bytes::Bytes;
3use std::{
4 collections::BTreeMap,
5 fs::File,
6 io::{BufWriter, Seek, Write},
7};
8use tokio::io::SeekFrom;
9
/// Writes pushed chunks into a pre-sized file at arbitrary offsets.
///
/// Chunks are cached in memory keyed by their start offset and written out in ascending
/// offset order, either when the cached byte count reaches `buffer_size` during a push or
/// when `flush` is called explicitly.
#[derive(Debug)]
pub struct FilePusher {
    /// Buffered writer over the destination file.
    buffer: BufWriter<File>,
    /// Pending chunks, keyed by the offset in the file where each one starts.
    cache: BTreeMap<u64, Bytes>,
    /// Offset where the pusher believes the next write through `buffer` will land; a chunk
    /// whose start differs from this triggers a seek before it is written.
    p: u64,
    /// Running count of cached bytes: incremented on push, decremented as chunks are written.
    cache_size: usize,
    /// Cached-byte threshold at which a push triggers a flush; also the `BufWriter` capacity.
    buffer_size: usize,
}
18impl FilePusher {
19 pub async fn new(
20 file: tokio::fs::File,
21 size: u64,
22 buffer_size: usize,
23 ) -> tokio::io::Result<Self> {
24 file.set_len(size).await?;
25 Ok(Self {
26 buffer: BufWriter::with_capacity(buffer_size, file.into_std().await),
27 cache: BTreeMap::new(),
28 p: 0,
29 cache_size: 0,
30 buffer_size,
31 })
32 }
33}
34impl Pusher for FilePusher {
35 type Error = tokio::io::Error;
36 fn push(&mut self, range: &ProgressEntry, bytes: Bytes) -> Result<(), (Self::Error, Bytes)> {
37 if bytes.is_empty() {
38 return Ok(());
39 }
40 self.cache_size += bytes.len();
41 self.cache.insert(range.start, bytes);
42 if self.cache_size >= self.buffer_size {
43 self.flush().map_err(|e| {
44 let bytes = self.cache.remove(&range.start);
45 (e, bytes.unwrap_or_default())
46 })?;
47 }
48 Ok(())
49 }
50 fn flush(&mut self) -> Result<(), Self::Error> {
51 while let Some(entry) = self.cache.first_entry() {
52 let start = *entry.key();
53 let bytes = entry.get();
54 let len = bytes.len();
55 if start != self.p {
56 self.buffer.seek(SeekFrom::Start(start))?;
57 self.p = start;
58 }
59 self.buffer.write_all(bytes)?;
60 entry.remove_entry();
61 self.cache_size -= len;
62 self.p += len as u64;
63 }
64 self.buffer.flush()?;
65 Ok(())
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Opens a fresh handle to `temp_file` and wraps it in a [`FilePusher`].
    async fn pusher_for(temp_file: &NamedTempFile, size: u64, buffer_size: usize) -> FilePusher {
        FilePusher::new(temp_file.reopen().unwrap().into(), size, buffer_size)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_rand_file_pusher() {
        let temp_file = NamedTempFile::new().unwrap();
        let mut pusher = pusher_for(&temp_file, 10, 8 * 1024 * 1024).await;

        pusher.push(&(2..5), Bytes::from_static(b"234")).unwrap();
        pusher.flush().unwrap();

        let file_content = std::fs::read(temp_file.path()).unwrap();
        assert_eq!(file_content, b"\x00\x00234\x00\x00\x00\x00\x00");
    }

    #[tokio::test]
    async fn test_out_of_order_chunks_are_placed_by_offset() {
        let temp_file = NamedTempFile::new().unwrap();
        let mut pusher = pusher_for(&temp_file, 10, 8 * 1024 * 1024).await;

        pusher.push(&(5..8), Bytes::from_static(b"567")).unwrap();
        pusher.push(&(0..2), Bytes::from_static(b"01")).unwrap();
        pusher.push(&(8..10), Bytes::from_static(b"89")).unwrap();
        pusher.flush().unwrap();

        let file_content = std::fs::read(temp_file.path()).unwrap();
        assert_eq!(file_content, b"01\x00\x00\x0056789");
    }

    #[tokio::test]
    async fn test_push_flushes_once_threshold_is_reached() {
        let temp_file = NamedTempFile::new().unwrap();
        let mut pusher = pusher_for(&temp_file, 8, 4).await;

        pusher.push(&(0..3), Bytes::from_static(b"abc")).unwrap();
        assert_eq!(pusher.cache.len(), 1);
        // 3 + 3 cached bytes reach the threshold of 4, so this push writes everything out.
        pusher.push(&(3..6), Bytes::from_static(b"def")).unwrap();
        assert!(pusher.cache.is_empty());
        assert_eq!(pusher.cache_size, 0);

        let file_content = std::fs::read(temp_file.path()).unwrap();
        assert_eq!(file_content, b"abcdef\x00\x00");
    }
}