use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use bytes::Bytes;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot};
use crate::error::DownloadError;
use crate::storage::cache::WriteBackCache;
pub(crate) enum WriterCommand {
Data {
offset: u64,
data: Bytes,
piece_id: Option<usize>,
},
FlushPiece {
piece_id: usize,
ack: oneshot::Sender<()>,
},
FlushAll {
sync_data: bool,
ack: oneshot::Sender<u64>,
},
}
pub(crate) struct WriterTask {
rx: mpsc::Receiver<WriterCommand>,
file: tokio::fs::File,
written_bytes: Arc<AtomicU64>,
cache: WriteBackCache,
budget_return: Arc<tokio::sync::Semaphore>,
cache_high_watermark: usize,
}
impl WriterTask {
pub fn new(
rx: mpsc::Receiver<WriterCommand>,
file: tokio::fs::File,
written_bytes: Arc<AtomicU64>,
budget_return: Arc<tokio::sync::Semaphore>,
cache_high_watermark: usize,
) -> Self {
Self {
rx,
file,
written_bytes,
cache: WriteBackCache::new(),
budget_return,
cache_high_watermark,
}
}
pub async fn run(mut self) -> Result<(), DownloadError> {
while let Some(cmd) = self.rx.recv().await {
match cmd {
WriterCommand::Data {
offset,
data,
piece_id,
} => {
let data_len = data.len();
match piece_id {
Some(piece_id) => {
self.cache.insert(piece_id, offset, data);
}
None => {
self.write_block(offset, &data).await?;
self.budget_return.add_permits(data_len);
}
}
if self.cache.total_bytes() >= self.cache_high_watermark {
self.flush_all().await?;
}
}
WriterCommand::FlushPiece { piece_id, ack } => {
self.flush_piece(piece_id).await?;
let _ = ack.send(());
}
WriterCommand::FlushAll { sync_data, ack } => {
self.flush_all().await?;
if sync_data {
self.sync_file().await?;
}
let _ = ack.send(self.written_bytes.load(Ordering::Acquire));
}
}
}
self.flush_all().await?;
self.file.flush().await?;
self.file.sync_all().await?;
Ok(())
}
async fn flush_piece(&mut self, piece_id: usize) -> Result<(), DownloadError> {
let blocks = self.cache.drain_piece(piece_id);
let mut flushed = 0usize;
for block in blocks {
self.write_block(block.offset, &block.data).await?;
flushed += block.data.len();
}
if flushed > 0 {
self.budget_return.add_permits(flushed);
}
Ok(())
}
async fn flush_all(&mut self) -> Result<(), DownloadError> {
let blocks = self.cache.drain_all();
let mut flushed = 0usize;
for block in blocks {
self.write_block(block.offset, &block.data).await?;
flushed += block.data.len();
}
if flushed > 0 {
self.budget_return.add_permits(flushed);
}
Ok(())
}
async fn sync_file(&mut self) -> Result<(), DownloadError> {
self.file.flush().await?;
self.file.sync_data().await?;
Ok(())
}
async fn write_block(&mut self, offset: u64, data: &[u8]) -> Result<(), DownloadError> {
self.file.seek(std::io::SeekFrom::Start(offset)).await?;
self.file.write_all(data).await?;
let end = offset + data.len() as u64;
self.written_bytes.fetch_max(end, Ordering::Release);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_writer_direct_write() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("test.bin");
let file = tokio::fs::File::create(&path).await.unwrap();
let budget = Arc::new(tokio::sync::Semaphore::new(1024 * 1024));
let written = Arc::new(AtomicU64::new(0));
let (tx, rx) = mpsc::channel(16);
let writer = WriterTask::new(rx, file, written.clone(), budget.clone(), 1024 * 1024);
let handle = tokio::spawn(writer.run());
tx.send(WriterCommand::Data {
offset: 0,
data: Bytes::from(vec![0xAA; 100]),
piece_id: None,
})
.await
.unwrap();
tx.send(WriterCommand::Data {
offset: 100,
data: Bytes::from(vec![0xBB; 100]),
piece_id: None,
})
.await
.unwrap();
drop(tx);
handle.await.unwrap().unwrap();
let content = std::fs::read(&path).unwrap();
assert_eq!(content.len(), 200);
assert!(content[..100].iter().all(|&b| b == 0xAA));
assert!(content[100..].iter().all(|&b| b == 0xBB));
assert_eq!(written.load(Ordering::Acquire), 200);
}
#[tokio::test]
async fn test_writer_with_piece_cache_and_flush() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("test_piece.bin");
std::fs::write(&path, vec![0u8; 200]).unwrap();
let file = tokio::fs::OpenOptions::new()
.write(true)
.open(&path)
.await
.unwrap();
let budget = Arc::new(tokio::sync::Semaphore::new(1024 * 1024));
let written = Arc::new(AtomicU64::new(0));
let (tx, rx) = mpsc::channel(16);
let writer = WriterTask::new(rx, file, written.clone(), budget.clone(), 1024 * 1024);
let handle = tokio::spawn(writer.run());
tx.send(WriterCommand::Data {
offset: 0,
data: Bytes::from(vec![0xCC; 100]),
piece_id: Some(0),
})
.await
.unwrap();
let (ack_tx, ack_rx) = oneshot::channel();
tx.send(WriterCommand::FlushPiece {
piece_id: 0,
ack: ack_tx,
})
.await
.unwrap();
ack_rx.await.unwrap();
drop(tx);
handle.await.unwrap().unwrap();
let content = std::fs::read(&path).unwrap();
assert!(content[..100].iter().all(|&b| b == 0xCC));
}
#[tokio::test]
async fn test_writer_flush_all_with_sync() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("test_flush_all.bin");
std::fs::write(&path, vec![0u8; 200]).unwrap();
let file = tokio::fs::OpenOptions::new()
.write(true)
.open(&path)
.await
.unwrap();
let budget = Arc::new(tokio::sync::Semaphore::new(1024 * 1024));
let written = Arc::new(AtomicU64::new(0));
let (tx, rx) = mpsc::channel(16);
let writer = WriterTask::new(rx, file, written.clone(), budget.clone(), 1024 * 1024);
let handle = tokio::spawn(writer.run());
tx.send(WriterCommand::Data {
offset: 0,
data: Bytes::from(vec![0xDD; 50]),
piece_id: Some(0),
})
.await
.unwrap();
let (ack_tx, ack_rx) = oneshot::channel();
tx.send(WriterCommand::FlushAll {
sync_data: true,
ack: ack_tx,
})
.await
.unwrap();
let written_val = ack_rx.await.unwrap();
assert!(written_val >= 50);
drop(tx);
handle.await.unwrap().unwrap();
}
#[tokio::test]
async fn test_writer_high_watermark_flush() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("test_hwm.bin");
std::fs::write(&path, vec![0u8; 1024]).unwrap();
let file = tokio::fs::OpenOptions::new()
.write(true)
.open(&path)
.await
.unwrap();
let budget = Arc::new(tokio::sync::Semaphore::new(1024 * 1024));
let written = Arc::new(AtomicU64::new(0));
let (tx, rx) = mpsc::channel(16);
let writer = WriterTask::new(rx, file, written.clone(), budget.clone(), 50);
let handle = tokio::spawn(writer.run());
tx.send(WriterCommand::Data {
offset: 0,
data: Bytes::from(vec![0xEE; 100]),
piece_id: Some(0),
})
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
drop(tx);
handle.await.unwrap().unwrap();
let content = std::fs::read(&path).unwrap();
assert!(content[..100].iter().all(|&b| b == 0xEE));
}
}