use std::collections::{BTreeMap, HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use bytes::Bytes;
use tokio::sync::{mpsc, oneshot, Mutex, Semaphore};
use super::error::StorageError;
/// Relative urgency of a buffered write.
///
/// Variants carry explicit discriminants (Low < Normal < High) and the derived
/// `Ord` follows them, so priorities compare numerically.
/// NOTE(review): nothing in this file reads the priority yet — presumably the
/// I/O scheduler orders ops by it; confirm at call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum WritePriority {
/// Background writes that may be deferred.
Low = 0,
/// Default priority (via `#[default]`).
#[default]
Normal = 1,
/// Writes that should be flushed ahead of the rest.
High = 2,
}
/// A single write request targeting one file of one torrent.
#[derive(Debug, Clone)]
pub struct WriteOp {
/// Identifies the torrent; used as the grouping key in `WriteCoalescer`.
pub torrent_hash: String,
/// Index of the target file within the torrent.
pub file_index: usize,
/// Byte offset within that file where `data` is written.
pub file_offset: u64,
/// Payload; `Bytes` makes clones cheap (refcounted).
pub data: Bytes,
/// Scheduling hint for the I/O layer.
pub priority: WritePriority,
}
/// A contiguous run of bytes to write to one file, produced by coalescing
/// adjacent buffered blocks.
#[derive(Debug, Clone)]
pub struct WriteRegion {
/// Index of the target file within the torrent.
pub file_index: usize,
/// Starting byte offset of the region within that file.
pub file_offset: u64,
/// The concatenated payload for the whole region.
pub data: Bytes,
}
/// Request to persist a completed piece and verify its hash.
/// NOTE(review): no consumer is visible in this chunk — field semantics below
/// are inferred from names; confirm against the flush worker.
#[derive(Debug)]
pub struct FlushRequest {
/// Torrent the piece belongs to.
pub torrent_hash: String,
/// Piece number within the torrent.
pub piece_index: u32,
/// Coalesced file regions that make up the piece on disk.
pub regions: Vec<WriteRegion>,
/// The full piece payload (presumably used for hashing — confirm).
pub piece_data: Bytes,
/// Expected piece hash to compare against after writing.
pub expected_hash: Vec<u8>,
}
/// Outcome of processing a `FlushRequest`.
#[derive(Debug)]
pub enum FlushResult {
/// Piece written and hash matched.
Success,
/// Piece written but the computed hash differed from `expected_hash`.
HashMismatch,
/// Storage-level failure during the flush.
Error(StorageError),
}
/// Buffers small writes per torrent and merges adjacent ones into larger
/// contiguous regions before they hit disk.
pub struct WriteCoalescer {
/// torrent_hash -> blocks keyed by (file_index, file_offset); the BTreeMap
/// keeps blocks sorted so coalescing is a single forward scan.
blocks: HashMap<String, BTreeMap<(usize, u64), Bytes>>,
/// Threshold (bytes) at which `should_flush` starts returning true.
max_buffer_size: usize,
/// Total buffered bytes. NOTE(review): all mutating methods take `&mut
/// self`, so the atomic is not strictly needed for safety here — it only
/// allows lock-free reads if the type is later shared.
current_size: AtomicU64,
}
impl WriteCoalescer {
    /// Creates a coalescer that reports `should_flush` once `max_buffer_size`
    /// bytes are buffered.
    pub fn new(max_buffer_size: usize) -> Self {
        Self {
            blocks: HashMap::new(),
            max_buffer_size,
            current_size: AtomicU64::new(0),
        }
    }

    /// Buffers one block for `torrent_hash`, keyed by `(file_index, file_offset)`.
    ///
    /// A block written to the same key replaces the previous one; the size
    /// counter is adjusted by the displaced block's length so it always
    /// matches the bytes actually held.
    pub fn add_block(
        &mut self,
        torrent_hash: &str,
        file_index: usize,
        file_offset: u64,
        data: Bytes,
    ) {
        let len = data.len() as u64;
        let replaced = self
            .blocks
            .entry(torrent_hash.to_string())
            .or_default()
            .insert((file_index, file_offset), data);
        // Bug fix: `insert` returns the displaced block on a duplicate key.
        // Previously its length was never subtracted, so re-sending the same
        // block permanently inflated `current_size` (and thus `should_flush`).
        if let Some(old) = replaced {
            self.current_size.fetch_sub(old.len() as u64, Ordering::Relaxed);
        }
        self.current_size.fetch_add(len, Ordering::Relaxed);
    }

    /// True once the buffered byte total reaches the configured threshold.
    pub fn should_flush(&self) -> bool {
        self.current_size.load(Ordering::Relaxed) as usize >= self.max_buffer_size
    }

    /// Total bytes currently buffered across all torrents.
    pub fn buffered_size(&self) -> usize {
        self.current_size.load(Ordering::Relaxed) as usize
    }

    /// Removes all buffered blocks for one torrent and returns them merged
    /// into contiguous regions (empty if nothing was buffered for it).
    pub fn flush_torrent(&mut self, torrent_hash: &str) -> Vec<WriteRegion> {
        let Some(blocks) = self.blocks.remove(torrent_hash) else {
            return Vec::new();
        };
        let regions = coalesce_blocks_from_map(blocks);
        // Coalescing only concatenates payloads, so the regions' total length
        // equals exactly the bytes just removed from the buffer.
        let freed: u64 = regions.iter().map(|r| r.data.len() as u64).sum();
        self.current_size.fetch_sub(freed, Ordering::Relaxed);
        regions
    }

    /// Flushes every torrent; returns the coalesced regions keyed by hash.
    /// Torrents that produced no regions are omitted from the result.
    pub fn flush_all(&mut self) -> HashMap<String, Vec<WriteRegion>> {
        let mut result = HashMap::new();
        let keys: Vec<String> = self.blocks.keys().cloned().collect();
        for key in keys {
            let regions = self.flush_torrent(&key);
            if !regions.is_empty() {
                result.insert(key, regions);
            }
        }
        // The buffer is now empty; pinning the counter to zero is a cheap
        // guard against any residual accounting drift.
        self.current_size.store(0, Ordering::Relaxed);
        result
    }

    /// Discards all buffered blocks without writing them anywhere.
    pub fn clear(&mut self) {
        self.blocks.clear();
        self.current_size.store(0, Ordering::Relaxed);
    }
}
/// Merges sorted blocks into contiguous `WriteRegion`s.
///
/// The map iterates in `(file_index, file_offset)` order, so a single forward
/// scan suffices: a block extends the open region when it targets the same
/// file and starts exactly where the region currently ends; otherwise the open
/// region is emitted and a new one begins.
fn coalesce_blocks_from_map(blocks: BTreeMap<(usize, u64), Bytes>) -> Vec<WriteRegion> {
    let mut out = Vec::new();
    // Open region: (file index, start offset, accumulated bytes).
    let mut pending: Option<(usize, u64, Vec<u8>)> = None;
    for ((file_index, offset), data) in blocks {
        match pending.as_mut() {
            Some((file, start, buf))
                if *file == file_index && offset == *start + buf.len() as u64 =>
            {
                // Adjacent in the same file: append to the open region.
                buf.extend_from_slice(&data);
            }
            _ => {
                // Gap or file change: close the open region (if any) first.
                if let Some((file, start, buf)) = pending.take() {
                    out.push(WriteRegion {
                        file_index: file,
                        file_offset: start,
                        data: Bytes::from(buf),
                    });
                }
                pending = Some((file_index, offset, data.to_vec()));
            }
        }
    }
    // Emit the final open region.
    if let Some((file, start, buf)) = pending {
        out.push(WriteRegion {
            file_index: file,
            file_offset: start,
            data: Bytes::from(buf),
        });
    }
    out
}
/// Coalesces unordered `(file_index, offset, data)` triples into contiguous
/// write regions. Blocks are first sorted by key via a `BTreeMap`; a duplicate
/// `(file_index, offset)` keeps the last payload, matching map semantics.
pub fn coalesce_blocks(blocks: Vec<(usize, u64, Bytes)>) -> Vec<WriteRegion> {
    let mut sorted = BTreeMap::new();
    for (file_index, offset, data) in blocks {
        sorted.insert((file_index, offset), data);
    }
    coalesce_blocks_from_map(sorted)
}
/// Queued write plus the oneshot sender used to report its completion result.
type WriteQueueItem = (WriteOp, oneshot::Sender<Result<(), StorageError>>);
/// Bounded async queue of pending write operations.
#[allow(dead_code)]
pub struct IoQueue {
/// FIFO of queued writes with their completion channels.
writes: Mutex<VecDeque<WriteQueueItem>>,
/// Concurrency limiter. NOTE(review): no visible method acquires a permit
/// — presumably the (stubbed) worker will; confirm when it lands.
semaphore: Arc<Semaphore>,
/// Hard cap on queued items; `submit` rejects when reached.
max_queue_size: usize,
}
impl IoQueue {
pub fn new(max_concurrent: usize, max_queue_size: usize) -> Self {
Self {
writes: Mutex::new(VecDeque::with_capacity(max_queue_size)),
semaphore: Arc::new(Semaphore::new(max_concurrent)),
max_queue_size,
}
}
pub async fn submit(
&self,
op: WriteOp,
) -> Result<oneshot::Receiver<Result<(), StorageError>>, StorageError> {
let (tx, rx) = oneshot::channel();
let mut writes = self.writes.lock().await;
if writes.len() >= self.max_queue_size {
return Err(StorageError::Io(std::io::Error::new(
std::io::ErrorKind::WouldBlock,
"I/O queue full",
)));
}
writes.push_back((op, tx));
Ok(rx)
}
pub async fn submit_and_wait(&self, op: WriteOp) -> Result<(), StorageError> {
let rx = self.submit(op).await?;
rx.await
.map_err(|_| StorageError::Io(std::io::Error::other("channel closed")))?
}
pub async fn pending_count(&self) -> usize {
self.writes.lock().await.len()
}
}
/// Background task that drains an `IoQueue` until told to shut down.
pub struct IoWorker {
/// Receiving a `()` (or the sender being dropped) stops the worker.
shutdown_rx: mpsc::Receiver<()>,
/// Shared queue this worker services.
queue: Arc<IoQueue>,
}
impl IoWorker {
/// Creates a worker bound to `queue` that stops when `shutdown_rx` fires.
pub fn new(queue: Arc<IoQueue>, shutdown_rx: mpsc::Receiver<()>) -> Self {
Self { shutdown_rx, queue }
}
/// Worker main loop.
///
/// `biased;` makes `select!` poll branches top-down, so a pending shutdown
/// signal always wins over starting another batch. The loop also exits when
/// every shutdown sender has been dropped: `recv()` then yields `None`,
/// which the wildcard `_` binding still matches.
pub async fn run(mut self) {
loop {
tokio::select! {
biased;
_ = self.shutdown_rx.recv() => {
break;
}
_ = Self::process_batch(&self.queue) => {}
}
}
}
/// Placeholder batch step: only sleeps 10 ms so the loop does not busy-spin.
/// NOTE(review): the queue parameter is deliberately unused — real drain /
/// dispatch logic presumably lands here later; confirm intent.
async fn process_batch(_queue: &IoQueue) {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
}
#[cfg(test)]
mod tests {
use super::*;
// Three blocks in one file, each starting where the previous ends, must
// merge into a single region covering offsets 0..9.
#[test]
fn test_coalesce_adjacent_blocks() {
let blocks = vec![
(0, 0, Bytes::from(vec![1, 2, 3])),
(0, 3, Bytes::from(vec![4, 5, 6])),
(0, 6, Bytes::from(vec![7, 8, 9])),
];
let regions = coalesce_blocks(blocks);
assert_eq!(regions.len(), 1);
assert_eq!(regions[0].file_index, 0);
assert_eq!(regions[0].file_offset, 0);
assert_eq!(regions[0].data.as_ref(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
}
// A gap between offsets (3..10 unwritten) must keep the blocks separate.
#[test]
fn test_coalesce_non_adjacent_blocks() {
let blocks = vec![
(0, 0, Bytes::from(vec![1, 2, 3])),
(0, 10, Bytes::from(vec![4, 5, 6])), ];
let regions = coalesce_blocks(blocks);
assert_eq!(regions.len(), 2);
assert_eq!(regions[0].file_offset, 0);
assert_eq!(regions[1].file_offset, 10);
}
// Same offsets but different files must never merge.
#[test]
fn test_coalesce_different_files() {
let blocks = vec![
(0, 0, Bytes::from(vec![1, 2, 3])),
(1, 0, Bytes::from(vec![4, 5, 6])), ];
let regions = coalesce_blocks(blocks);
assert_eq!(regions.len(), 2);
assert_eq!(regions[0].file_index, 0);
assert_eq!(regions[1].file_index, 1);
}
// End-to-end: size accounting across two torrents, and flushing one torrent
// leaves the other's bytes buffered.
#[test]
fn test_write_coalescer() {
let mut coalescer = WriteCoalescer::new(1024 * 1024);
coalescer.add_block("hash1", 0, 0, Bytes::from(vec![1, 2, 3]));
coalescer.add_block("hash1", 0, 3, Bytes::from(vec![4, 5, 6]));
coalescer.add_block("hash2", 0, 0, Bytes::from(vec![7, 8, 9]));
assert_eq!(coalescer.buffered_size(), 9);
let regions = coalescer.flush_torrent("hash1");
assert_eq!(regions.len(), 1);
assert_eq!(regions[0].data.as_ref(), &[1, 2, 3, 4, 5, 6]);
assert_eq!(coalescer.buffered_size(), 3);
}
}