use crate::{ProgressEntry, Pusher};
use bytes::Bytes;
use std::{
collections::BTreeMap,
fs::File,
io::{BufWriter, Seek, Write},
sync::Arc,
};
use tokio::io::SeekFrom;
/// A [`Pusher`] that writes chunks into a file at their byte offsets.
///
/// Pushed chunks are held in an offset-ordered cache and written through a [`BufWriter`]
/// when enough bytes have accumulated or on flush, so chunks that arrive out of order are
/// written in ascending offset order.
#[derive(Debug)]
pub struct FilePusher {
    /// Buffered writer over a shared handle to `file`.
    buffer: BufWriter<Arc<File>>,
    /// Handle to the same file, used to `sync_all` on flush.
    file: Arc<File>,
    /// Pending chunks keyed by their start offset.
    cache: BTreeMap<u64, Bytes>,
    /// Tracked write position of `buffer`; a chunk starting anywhere else is preceded by a seek.
    p: u64,
    /// Total length in bytes of the chunks currently in `cache`.
    cache_size: usize,
    /// Capacity of `buffer`, also the `cache_size` threshold at which a push drains the cache.
    buffer_size: usize,
}
impl FilePusher {
pub async fn new(
file: tokio::fs::File,
size: u64,
buffer_size: usize,
) -> std::io::Result<Self> {
file.set_len(size).await?;
let file = Arc::new(file.into_std().await);
Ok(Self {
buffer: BufWriter::with_capacity(buffer_size, file.clone()),
cache: BTreeMap::new(),
p: 0,
cache_size: 0,
file,
buffer_size,
})
}
pub fn write(&mut self) -> Result<(), std::io::Error> {
while let Some(entry) = self.cache.first_entry() {
let start = *entry.key();
let bytes = entry.get();
let len = bytes.len();
if start != self.p {
self.buffer.seek(SeekFrom::Start(start))?;
self.p = start;
}
self.buffer.write_all(bytes)?;
entry.remove_entry();
self.cache_size -= len;
self.p += len as u64;
}
self.buffer.flush()?;
Ok(())
}
}
impl Pusher for FilePusher {
    type Error = std::io::Error;

    /// Caches `bytes` for offset `range.start`, and drains the cache to the file once at
    /// least `buffer_size` bytes are pending.
    ///
    /// Only `range.start` is used; the chunk's extent is `bytes.len()`. Empty chunks are
    /// ignored. A chunk pushed for an already-cached start offset replaces the cached one.
    /// `cache_size` is kept equal to the total length of the cached chunks.
    ///
    /// # Errors
    ///
    /// If draining fails, returns the I/O error together with this call's chunk. The chunk is
    /// removed from the cache so the caller can retry it. If the chunk had already left the
    /// cache before the failure (it was handed to the writer), empty bytes are returned.
    fn push(&mut self, range: &ProgressEntry, bytes: Bytes) -> Result<(), (Self::Error, Bytes)> {
        if bytes.is_empty() {
            return Ok(());
        }
        let start = range.start;
        self.cache_size += bytes.len();
        if let Some(replaced) = self.cache.insert(start, bytes) {
            // The replaced chunk was counted when it was pushed; it will never be written now.
            self.cache_size -= replaced.len();
        }
        if self.cache_size >= self.buffer_size {
            self.write().map_err(|e| {
                let bytes = self.cache.remove(&start).unwrap_or_default();
                // Keep the accounting exact: the chunk handed back is no longer pending here.
                self.cache_size -= bytes.len();
                (e, bytes)
            })?;
        }
        Ok(())
    }

    /// Writes every cached chunk, flushes the buffer, and calls `sync_all` on the file.
    ///
    /// # Errors
    ///
    /// Returns the first write, flush, or sync error encountered.
    fn flush(&mut self) -> Result<(), Self::Error> {
        self.write()?;
        self.file.sync_all()
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use std::vec::Vec;
    use tempfile::NamedTempFile;

    /// Reads the whole file at `path` back into memory.
    fn read_back(path: &std::path::Path) -> Vec<u8> {
        std::fs::read(path).unwrap()
    }

    #[tokio::test]
    async fn test_rand_file_pusher() {
        let temp_file = NamedTempFile::new().unwrap();
        let mut pusher = FilePusher::new(temp_file.reopen().unwrap().into(), 10, 8 * 1024 * 1024)
            .await
            .unwrap();
        pusher.push(&(2..5), Bytes::from_static(b"234")).unwrap();
        pusher.flush().unwrap();
        assert_eq!(read_back(temp_file.path()), b"\0\0234\0\0\0\0\0");
    }

    #[tokio::test]
    async fn test_out_of_order_pushes_with_small_buffer() {
        let temp_file = NamedTempFile::new().unwrap();
        // A 2-byte threshold makes every push drain the cache immediately, exercising both
        // the seek path and the contiguous (no-seek) path.
        let mut pusher = FilePusher::new(temp_file.reopen().unwrap().into(), 10, 2)
            .await
            .unwrap();
        pusher.push(&(5..8), Bytes::from_static(b"567")).unwrap();
        pusher.push(&(0..2), Bytes::from_static(b"01")).unwrap();
        pusher.push(&(2..5), Bytes::from_static(b"234")).unwrap();
        pusher.push(&(8..10), Bytes::from_static(b"89")).unwrap();
        pusher.flush().unwrap();
        assert_eq!(read_back(temp_file.path()), b"0123456789");
    }

    #[tokio::test]
    async fn test_cache_reorders_chunks_before_flush() {
        let temp_file = NamedTempFile::new().unwrap();
        let mut pusher = FilePusher::new(temp_file.reopen().unwrap().into(), 6, 1024)
            .await
            .unwrap();
        pusher.push(&(4..6), Bytes::from_static(b"ef")).unwrap();
        pusher.push(&(2..4), Bytes::from_static(b"cd")).unwrap();
        pusher.push(&(0..2), Bytes::from_static(b"ab")).unwrap();
        // Empty chunks are ignored rather than cached.
        pusher.push(&(3..3), Bytes::new()).unwrap();
        pusher.flush().unwrap();
        assert_eq!(read_back(temp_file.path()), b"abcdef");
    }
}