1extern crate std;
2use crate::{ProgressEntry, Pusher};
3use bytes::Bytes;
4use std::{
5 collections::BTreeMap,
6 fs::File,
7 io::{BufWriter, Seek, Write},
8};
9use tokio::io::SeekFrom;
10
/// Writes chunks into a file at their byte offsets.
///
/// Chunks handed to [`Pusher::push`] are kept in an in-memory cache ordered by start offset
/// and are written out by [`Pusher::flush`]. `push` triggers that flush automatically once
/// the cached byte count reaches `buffer_size`.
#[derive(Debug)]
pub struct FilePusher {
    /// Buffered writer over the destination file (created with capacity `buffer_size`).
    buffer: BufWriter<File>,
    /// Pending chunks keyed by their start offset in the file.
    cache: BTreeMap<u64, Bytes>,
    /// Offset the writer is expected to write at next; `flush` skips the seek when the next
    /// chunk starts exactly here.
    p: u64,
    /// Byte count of cached chunks: increased on `push`, decreased as `flush` writes them.
    cache_size: usize,
    /// Cache size in bytes that makes `push` flush automatically; also the `BufWriter` capacity.
    buffer_size: usize,
}
19impl FilePusher {
20 pub async fn new(
21 file: tokio::fs::File,
22 size: u64,
23 buffer_size: usize,
24 ) -> tokio::io::Result<Self> {
25 file.set_len(size).await?;
26 Ok(Self {
27 buffer: BufWriter::with_capacity(buffer_size, file.into_std().await),
28 cache: BTreeMap::new(),
29 p: 0,
30 cache_size: 0,
31 buffer_size,
32 })
33 }
34}
35impl Pusher for FilePusher {
36 type Error = tokio::io::Error;
37 fn push(&mut self, range: &ProgressEntry, bytes: Bytes) -> Result<(), (Self::Error, Bytes)> {
38 if bytes.is_empty() {
39 return Ok(());
40 }
41 self.cache_size += bytes.len();
42 self.cache.insert(range.start, bytes);
43 if self.cache_size >= self.buffer_size {
44 self.flush().map_err(|e| {
45 let bytes = self.cache.remove(&range.start);
46 (e, bytes.unwrap_or_default())
47 })?;
48 }
49 Ok(())
50 }
51 fn flush(&mut self) -> Result<(), Self::Error> {
52 while let Some(entry) = self.cache.first_entry() {
53 let start = *entry.key();
54 let bytes = entry.get();
55 let len = bytes.len();
56 if start != self.p {
57 self.buffer.seek(SeekFrom::Start(start))?;
58 self.p = start;
59 }
60 self.buffer.write_all(bytes)?;
61 entry.remove_entry();
62 self.cache_size -= len;
63 self.p += len as u64;
64 }
65 self.buffer.flush()?;
66 Ok(())
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use std::{io::Read, vec::Vec};
    use tempfile::NamedTempFile;

    #[tokio::test]
    async fn test_rand_file_pusher() {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path();

        // A 10-byte target with an 8 MiB buffer, so nothing is written until `flush`.
        let handle = temp_file.reopen().unwrap();
        let mut pusher = FilePusher::new(handle.into(), 10, 8 * 1024 * 1024)
            .await
            .unwrap();

        // Place "234" at bytes 2..5 and force it out to the file.
        let payload = b"234";
        let span = 2..5;
        pusher.push(&span, payload[..].into()).unwrap();
        pusher.flush().unwrap();

        let mut contents = Vec::new();
        File::open(path)
            .unwrap()
            .read_to_end(&mut contents)
            .unwrap();

        // Bytes outside the pushed range keep the zero fill from `set_len`.
        let expected: [u8; 10] = [0, 0, b'2', b'3', b'4', 0, 0, 0, 0, 0];
        assert_eq!(contents, expected);
    }
}