1use crate::{ProgressEntry, Pusher};
2use bytes::Bytes;
3use std::{
4 collections::BTreeMap,
5 fs::File,
6 io::{BufWriter, Seek, Write},
7};
8use tokio::io::SeekFrom;
9
10#[derive(Debug)]
11pub struct FilePusher {
12 buffer: BufWriter<File>,
13 cache: BTreeMap<u64, Bytes>,
14 p: u64,
15 cache_size: usize,
16 buffer_size: usize,
17}
18impl FilePusher {
19 pub async fn new(
23 file: tokio::fs::File,
24 size: u64,
25 buffer_size: usize,
26 ) -> tokio::io::Result<Self> {
27 file.set_len(size).await?;
28 Ok(Self {
29 buffer: BufWriter::with_capacity(buffer_size, file.into_std().await),
30 cache: BTreeMap::new(),
31 p: 0,
32 cache_size: 0,
33 buffer_size,
34 })
35 }
36}
37impl Pusher for FilePusher {
38 type Error = tokio::io::Error;
39 fn push(&mut self, range: &ProgressEntry, bytes: Bytes) -> Result<(), (Self::Error, Bytes)> {
40 if bytes.is_empty() {
41 return Ok(());
42 }
43 self.cache_size += bytes.len();
44 self.cache.insert(range.start, bytes);
45 if self.cache_size >= self.buffer_size {
46 self.flush().map_err(|e| {
47 let bytes = self.cache.remove(&range.start);
48 (e, bytes.unwrap_or_default())
49 })?;
50 }
51 Ok(())
52 }
53 fn flush(&mut self) -> Result<(), Self::Error> {
54 while let Some(entry) = self.cache.first_entry() {
55 let start = *entry.key();
56 let bytes = entry.get();
57 let len = bytes.len();
58 if start != self.p {
59 self.buffer.seek(SeekFrom::Start(start))?;
60 self.p = start;
61 }
62 self.buffer.write_all(bytes)?;
63 entry.remove_entry();
64 self.cache_size -= len;
65 self.p += len as u64;
66 }
67 self.buffer.flush()?;
68 Ok(())
69 }
70}
71
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use tempfile::NamedTempFile;

    /// Writing a chunk into the middle of a pre-sized file leaves the surrounding bytes zeroed.
    #[tokio::test]
    async fn test_rand_file_pusher() {
        let temp_file = NamedTempFile::new().unwrap();
        let target = temp_file.reopen().unwrap();

        let mut pusher = FilePusher::new(target.into(), 10, 8 * 1024 * 1024)
            .await
            .unwrap();

        pusher.push(&(2..5), Bytes::from_static(b"234")).unwrap();
        pusher.flush().unwrap();

        let written = std::fs::read(temp_file.path()).unwrap();
        assert_eq!(written, b"\0\x00234\0\0\0\0\0");
    }
}