use bytes::Bytes;
use dragonfly_api::common::v2::Range;
use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::{Error, Result};
use lru_cache::LruCache;
use std::cmp::{max, min};
use std::collections::HashMap;
use std::io::Cursor;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncRead, BufReader};
use tokio::sync::RwLock;
use tracing::{error, info};
pub mod lru_cache;
/// A task held in the cache: the content length it was registered with, plus
/// the piece contents written for it so far, keyed by piece id.
#[derive(Clone, Debug)]
struct Task {
    /// Content length the task was registered with; the cache uses it for
    /// size accounting.
    content_length: u64,

    /// Piece contents keyed by piece id. Held behind an `Arc`, so clones of a
    /// `Task` share the same piece map.
    pieces: Arc<RwLock<HashMap<String, Bytes>>>,
}

impl Task {
    /// Creates a task with no pieces that accounts for `content_length` bytes.
    fn new(content_length: u64) -> Self {
        let pieces = Arc::new(RwLock::new(HashMap::new()));
        Task {
            content_length,
            pieces,
        }
    }

    /// Stores `piece` under `id`, overwriting any content already stored there.
    async fn write_piece(&self, id: &str, piece: Bytes) {
        self.pieces.write().await.insert(id.to_owned(), piece);
    }

    /// Returns a clone of the content stored under `id`, or `None` if absent.
    async fn read_piece(&self, id: &str) -> Option<Bytes> {
        self.pieces.read().await.get(id).cloned()
    }

    /// Reports whether any content has been stored under `id`.
    async fn contains(&self, id: &str) -> bool {
        self.pieces.read().await.contains_key(id)
    }

    /// Returns the content length this task was registered with.
    fn content_length(&self) -> u64 {
        self.content_length
    }
}
/// In-memory cache of task pieces, bounded by the total content length of the
/// tasks registered in it.
///
/// Cloning a `Cache` is cheap. The size counter and the task map live behind
/// `Arc`s, so all clones share the same cached state.
#[derive(Clone)]
pub struct Cache {
    /// Daemon configuration; `storage.read_buffer_size` sizes the readers
    /// returned by [`Cache::read_piece`].
    config: Arc<Config>,

    /// Sum of the content lengths of the tasks currently held in `tasks`.
    size: Arc<AtomicU64>,

    /// Upper bound, in bytes, for `size` (taken from `storage.cache_capacity`).
    capacity: u64,

    /// Cached tasks keyed by task id. The map is created with an entry limit of
    /// `usize::MAX`; eviction is driven by `size`/`capacity` in
    /// [`Cache::put_task`] instead.
    tasks: Arc<RwLock<LruCache<String, Task>>>,
}

impl Cache {
    /// Creates an empty cache whose capacity is `config.storage.cache_capacity`.
    pub fn new(config: Arc<Config>) -> Self {
        Cache {
            config: config.clone(),
            size: Arc::new(AtomicU64::new(0)),
            capacity: config.storage.cache_capacity.as_u64(),
            tasks: Arc::new(RwLock::new(LruCache::new(usize::MAX))),
        }
    }

    /// Returns a reader over the cached content of `piece_id` in `task_id`.
    ///
    /// Without `range`, the first `piece.length` bytes of the stored piece are
    /// returned. With `range`, only the part of the piece overlapping
    /// `[range.start, range.start + range.length)` is returned. `range.start`
    /// and `piece.offset` are expressed in the same coordinates.
    ///
    /// # Errors
    ///
    /// * [`Error::TaskNotFound`] if the task is not cached.
    /// * [`Error::PieceNotFound`] if the piece has not been written.
    /// * [`Error::InvalidParameter`] if `range` does not overlap the piece, or
    ///   if the selected bytes fall outside the stored content.
    pub async fn read_piece(
        &self,
        task_id: &str,
        piece_id: &str,
        piece: super::metadata::Piece,
        range: Option<Range>,
    ) -> Result<impl AsyncRead> {
        // `get` is called on a write guard (presumably because it refreshes the
        // entry's LRU position). The guard is dropped once the piece content
        // has been cloned out.
        let mut tasks = self.tasks.write().await;
        let Some(task) = tasks.get(task_id) else {
            return Err(Error::TaskNotFound(task_id.to_string()));
        };

        let Some(piece_content) = task.read_piece(piece_id).await else {
            return Err(Error::PieceNotFound(piece_id.to_string()));
        };
        drop(tasks);

        let (target_offset, target_length) = match range {
            Some(range) => {
                // Intersect the piece's half-open span with the requested one.
                // Saturating adds plus an explicit emptiness check avoid the u64
                // underflow that inclusive-end (`offset + length - 1`) arithmetic
                // hits on zero-length or non-overlapping ranges.
                let piece_end = piece.offset.saturating_add(piece.length);
                let range_end = range.start.saturating_add(range.length);
                let start = max(piece.offset, range.start);
                let stop = min(piece_end, range_end);
                if start >= stop {
                    error!(
                        "range start {} length {} does not overlap piece {} in task {}: offset {}, length {}",
                        range.start, range.length, piece_id, task_id, piece.offset, piece.length
                    );
                    return Err(Error::InvalidParameter);
                }

                (start - piece.offset, stop - start)
            }
            None => (0, piece.length),
        };

        // Validate in u64 so oversized offsets/lengths cannot overflow or
        // truncate when converted to `usize` for slicing.
        let content_length = piece_content.len() as u64;
        let begin = target_offset;
        let end = target_offset.saturating_add(target_length);
        if begin >= content_length || end > content_length {
            error!(
                "invalid range for piece {} in task {}: begin {}, end {}, piece length {}",
                piece_id,
                task_id,
                begin,
                end,
                piece_content.len()
            );
            return Err(Error::InvalidParameter);
        }

        // Both bounds are within `piece_content`, so these casts are lossless.
        let content = piece_content.slice(begin as usize..end as usize);
        let reader =
            BufReader::with_capacity(self.config.storage.read_buffer_size, Cursor::new(content));
        Ok(reader)
    }

    /// Stores `content` as piece `piece_id` of `task_id`.
    ///
    /// If the piece is already cached, the call is a no-op and the existing
    /// content is kept.
    ///
    /// # Errors
    ///
    /// * [`Error::TaskNotFound`] if the task is not cached.
    pub async fn write_piece(&self, task_id: &str, piece_id: &str, content: Bytes) -> Result<()> {
        // Holding the map's write guard for the whole call serializes writers,
        // so the contains-then-write sequence below cannot interleave.
        let mut tasks = self.tasks.write().await;
        let Some(task) = tasks.get(task_id) else {
            return Err(Error::TaskNotFound(task_id.to_string()));
        };

        if task.contains(piece_id).await {
            return Ok(());
        }

        task.write_piece(piece_id, content).await;
        Ok(())
    }

    /// Registers `task_id` with a budget of `content_length` bytes.
    ///
    /// Tasks are evicted via `pop_lru` until the new task fits. Zero-length
    /// tasks and tasks larger than the capacity are ignored. If `task_id` is
    /// already registered, its previous entry (including any cached pieces) is
    /// replaced.
    pub async fn put_task(&mut self, task_id: &str, content_length: u64) {
        if content_length == 0 {
            return;
        }

        if content_length > self.capacity {
            info!(
                "task {} is too large and cannot be cached: {}",
                task_id, content_length
            );
            return;
        }

        let mut tasks = self.tasks.write().await;

        // Release the size accounted to a previous entry for the same id.
        // Otherwise replacing it would double-count and permanently inflate
        // `size`.
        if let Some((_, existing)) = tasks.pop(task_id) {
            self.size
                .fetch_sub(existing.content_length(), Ordering::Relaxed);
        }

        while self.size.load(Ordering::Relaxed) + content_length > self.capacity {
            match tasks.pop_lru() {
                Some((_, task)) => {
                    self.size
                        .fetch_sub(task.content_length(), Ordering::Relaxed);
                }
                None => {
                    break;
                }
            }
        }

        let task = Task::new(content_length);
        tasks.put(task_id.to_string(), task);
        self.size.fetch_add(content_length, Ordering::Relaxed);
    }

    /// Removes `task_id` from the cache and releases its accounted size.
    ///
    /// # Errors
    ///
    /// * [`Error::TaskNotFound`] if the task is not cached.
    pub async fn delete_task(&mut self, task_id: &str) -> Result<()> {
        let mut tasks = self.tasks.write().await;
        let Some((_, task)) = tasks.pop(task_id) else {
            return Err(Error::TaskNotFound(task_id.to_string()));
        };

        self.size
            .fetch_sub(task.content_length(), Ordering::Relaxed);
        Ok(())
    }

    /// Reports whether `id` is registered in the cache (read lock only).
    pub async fn contains_task(&self, id: &str) -> bool {
        let tasks = self.tasks.read().await;
        tasks.contains(id)
    }

    /// Reports whether `piece_id` has been written for `task_id`.
    ///
    /// Looks the task up with `peek` under a read lock (presumably without
    /// touching its LRU position). Returns `false` for unknown tasks.
    pub async fn contains_piece(&self, task_id: &str, piece_id: &str) -> bool {
        let tasks = self.tasks.read().await;
        if let Some(task) = tasks.peek(task_id) {
            task.contains(piece_id).await
        } else {
            false
        }
    }
}
// Unit tests for the in-memory piece cache. Most tests build a `Cache` with a
// small `cache_capacity` so that eviction and error paths can be observed.
#[cfg(test)]
mod tests {
    use super::super::metadata::Piece;
    use super::*;
    use bytesize::ByteSize;
    use dragonfly_api::common::v2::Range;
    use dragonfly_client_config::dfdaemon::Storage;
    use tokio::io::AsyncReadExt;

    // `Cache::new` starts with size 0 and takes its capacity from
    // `storage.cache_capacity` (this test expects 64 MiB for the default config).
    #[tokio::test]
    async fn test_new() {
        let test_cases = vec![
            (Config::default(), 0, ByteSize::mib(64).as_u64()),
            (
                Config {
                    storage: Storage {
                        cache_capacity: ByteSize::mib(100),
                        ..Default::default()
                    },
                    ..Default::default()
                },
                0,
                ByteSize::mib(100).as_u64(),
            ),
            (
                Config {
                    storage: Storage {
                        cache_capacity: ByteSize::b(0),
                        ..Default::default()
                    },
                    ..Default::default()
                },
                0,
                0,
            ),
        ];
        for (config, expected_size, expected_capacity) in test_cases {
            let cache = Cache::new(Arc::new(config));
            assert_eq!(cache.size.load(Ordering::Relaxed), expected_size);
            assert_eq!(cache.capacity, expected_capacity);
        }
    }

    // Drives the underlying LRU map directly, bypassing `put_task` and its size
    // accounting, and checks `contains_task` after each step. The "remove" step
    // calls `pop_lru`, which removes `task1` because it is the only entry then.
    #[tokio::test]
    async fn test_contains_task() {
        let config = Config {
            storage: Storage {
                cache_capacity: ByteSize::mib(10),
                ..Default::default()
            },
            ..Default::default()
        };
        let cache = Cache::new(Arc::new(config));
        let test_cases = vec![
            ("check", "non_existent", 0, false),
            ("add", "task1", ByteSize::mib(1).as_u64(), true),
            ("check", "task1", 0, true),
            ("remove", "task1", 0, false),
            ("check", "task1", 0, false),
            ("add", "task1", ByteSize::mib(1).as_u64(), true),
            ("add", "task2", ByteSize::mib(2).as_u64(), true),
            ("check", "task1", 0, true),
            ("check", "task2", 0, true),
            ("check", "task3", 0, false),
        ];
        for (operation, task_id, content_length, expected_result) in test_cases {
            match operation {
                "check" => {
                    assert_eq!(cache.contains_task(task_id).await, expected_result);
                }
                "add" => {
                    let task = Task::new(content_length);
                    cache.tasks.write().await.put(task_id.to_string(), task);
                    assert_eq!(cache.contains_task(task_id).await, expected_result);
                }
                "remove" => {
                    cache.tasks.write().await.pop_lru();
                    assert_eq!(cache.contains_task(task_id).await, expected_result);
                }
                _ => panic!("Unknown operation."),
            }
        }
    }

    // A task exactly as large as the capacity is accepted and one byte larger
    // is rejected. The size-0 row is skipped by the `size > 0` guard, so
    // `put_task` is never called for it. Each row only checks the task it names.
    #[tokio::test]
    async fn test_put_task() {
        let config = Config {
            storage: Storage {
                cache_capacity: ByteSize::mib(10),
                ..Default::default()
            },
            ..Default::default()
        };
        let mut cache = Cache::new(Arc::new(config));
        let test_cases = vec![
            ("empty_task", 0, false),
            ("equal_capacity", ByteSize::mib(10).as_u64(), true),
            ("exceed_capacity", ByteSize::mib(10).as_u64() + 1, false),
            ("normal_task", ByteSize::mib(1).as_u64(), true),
        ];
        for (task_id, size, should_exist) in test_cases {
            if size > 0 {
                cache.put_task(task_id, size).await;
            }
            assert_eq!(cache.contains_task(task_id).await, should_exist);
        }
    }

    // With a 5 MiB capacity, inserting a third 2 MiB task evicts the first one
    // (`lru_task_1`). The trailing size-0 rows only check membership.
    #[tokio::test]
    async fn test_put_task_lru() {
        let config = Config {
            storage: Storage {
                cache_capacity: ByteSize::mib(5),
                ..Default::default()
            },
            ..Default::default()
        };
        let mut cache = Cache::new(Arc::new(config));
        let test_cases = vec![
            ("lru_task_1", ByteSize::mib(2).as_u64(), true),
            ("lru_task_2", ByteSize::mib(2).as_u64(), true),
            ("lru_task_3", ByteSize::mib(2).as_u64(), true),
            ("lru_task_1", 0, false),
            ("lru_task_2", 0, true),
            ("lru_task_3", 0, true),
        ];
        for (task_id, size, should_exist) in test_cases {
            if size > 0 {
                cache.put_task(task_id, size).await;
            }
            assert_eq!(cache.contains_task(task_id).await, should_exist);
        }
    }

    // Deleting a registered task succeeds; deleting an unknown id (including
    // the empty string) returns an error. Either way the id is absent afterwards.
    #[tokio::test]
    async fn test_delete_task() {
        let config = Config {
            storage: Storage {
                cache_capacity: ByteSize::mib(10),
                ..Default::default()
            },
            ..Default::default()
        };
        let mut cache = Cache::new(Arc::new(config));
        cache.put_task("task1", ByteSize::mib(1).as_u64()).await;
        cache.put_task("task2", ByteSize::mib(1).as_u64()).await;
        cache.put_task("task3", ByteSize::mib(1).as_u64()).await;
        let test_cases = vec![
            ("task1", true),
            ("task2", true),
            ("task3", true),
            ("nonexistent", false),
            ("", false),
            ("large_task", false),
        ];
        for (task_id, exists) in test_cases {
            assert_eq!(cache.contains_task(task_id).await, exists);
            let result = cache.delete_task(task_id).await;
            if exists {
                assert!(result.is_ok());
            } else {
                assert!(result.is_err());
            }
            assert!(!cache.contains_task(task_id).await);
        }
        assert!(!cache.contains_task("task1").await);
        assert!(!cache.contains_task("task2").await);
        assert!(!cache.contains_task("task3").await);
        assert!(!cache.contains_task("nonexistent").await);
        assert!(!cache.contains_task("").await);
        assert!(!cache.contains_task("large_task").await);
    }

    // `contains_piece` is false for unknown tasks and for pieces not yet
    // written, and true after `write_piece`, including for ids with symbols.
    #[tokio::test]
    async fn test_contains_piece() {
        let config = Config {
            storage: Storage {
                cache_capacity: ByteSize::mib(10),
                ..Default::default()
            },
            ..Default::default()
        };
        let mut cache = Cache::new(Arc::new(config));
        let test_cases = vec![
            ("check", "non_existent", "piece1", "", false),
            ("check", "non_existent", "", "", false),
            ("add_task", "task1", "", "", true),
            ("check", "task1", "piece1", "", false),
            ("add_piece", "task1", "piece1", "test data", true),
            ("check", "task1", "piece1", "", true),
            ("check", "task1", "", "", false),
            ("check", "task1", "non_existent_piece", "", false),
            ("add_piece", "task1", "piece#$%^&*", "test data", true),
            ("check", "task1", "piece#$%^&*", "", true),
        ];
        for (operation, task_id, piece_id, content, expected_result) in test_cases {
            match operation {
                "check" => {
                    assert_eq!(
                        cache.contains_piece(task_id, piece_id).await,
                        expected_result
                    );
                }
                "add_task" => {
                    cache.put_task(task_id, 1000).await;
                    assert!(cache.contains_task(task_id).await);
                }
                "add_piece" => {
                    cache
                        .write_piece(task_id, piece_id, Bytes::from(content))
                        .await
                        .unwrap();
                    assert_eq!(
                        cache.contains_piece(task_id, piece_id).await,
                        expected_result
                    );
                }
                _ => panic!("Unknown operation."),
            }
        }
    }

    // Writing to an unknown task fails with `TaskNotFound`. The first loop
    // round-trips each piece through `read_piece`. The second loop rewrites
    // every piece with new content and asserts the original bytes are still
    // returned, i.e. repeated writes to an existing piece id keep the first content.
    #[tokio::test]
    async fn test_write_piece() {
        let config = Config {
            storage: Storage {
                cache_capacity: ByteSize::mib(10),
                ..Default::default()
            },
            ..Default::default()
        };
        let mut cache = Cache::new(Arc::new(config));
        let test_data = b"test data".to_vec();
        let result = cache
            .write_piece("non_existent", "piece1", Bytes::from(test_data))
            .await;
        assert!(matches!(result, Err(Error::TaskNotFound(_))));
        cache.put_task("task1", ByteSize::mib(1).as_u64()).await;
        assert!(cache.contains_task("task1").await);
        let test_cases = vec![
            ("piece1", b"hello world".to_vec()),
            ("piece2", b"rust programming".to_vec()),
            ("piece3", b"dragonfly cache".to_vec()),
            ("piece4", b"unit testing".to_vec()),
            ("piece5", b"async await".to_vec()),
            ("piece6", b"error handling".to_vec()),
            ("piece7", vec![0u8; 1024]),
            ("piece8", vec![1u8; 2048]),
        ];
        for (piece_id, content) in &test_cases {
            let result = cache
                .write_piece("task1", piece_id, Bytes::copy_from_slice(content))
                .await;
            assert!(result.is_ok());
            assert!(cache.contains_piece("task1", piece_id).await);
            let piece = Piece {
                number: 0,
                offset: 0,
                length: content.len() as u64,
                digest: "".to_string(),
                parent_id: None,
                uploading_count: 0,
                uploaded_count: 0,
                updated_at: chrono::Utc::now().naive_utc(),
                created_at: chrono::Utc::now().naive_utc(),
                finished_at: None,
            };
            let mut reader = cache
                .read_piece("task1", piece_id, piece, None)
                .await
                .unwrap();
            let mut buffer = Vec::new();
            reader.read_to_end(&mut buffer).await.unwrap();
            assert_eq!(buffer, *content);
        }
        for (piece_id, original_content) in &test_cases {
            let new_content = format!("updated content for {}", piece_id);
            let result = cache
                .write_piece("task1", piece_id, Bytes::from(new_content))
                .await;
            assert!(result.is_ok());
            let piece = Piece {
                number: 0,
                offset: 0,
                length: original_content.len() as u64,
                digest: "".to_string(),
                parent_id: None,
                uploading_count: 0,
                uploaded_count: 0,
                updated_at: chrono::Utc::now().naive_utc(),
                created_at: chrono::Utc::now().naive_utc(),
                finished_at: None,
            };
            let mut reader = cache
                .read_piece("task1", piece_id, piece, None)
                .await
                .unwrap();
            let mut buffer = Vec::new();
            reader.read_to_end(&mut buffer).await.unwrap();
            assert_eq!(buffer, *original_content);
        }
    }

    // Covers the `TaskNotFound` / `PieceNotFound` errors and whole/ranged reads.
    // Range starts use the same coordinates as `piece.offset`: `piece2` sits at
    // offset 11, so a range starting at 11 yields its first bytes. The
    // `large_piece` case writes and reads a 50 MiB piece.
    #[tokio::test]
    async fn test_read_piece() {
        let config = Config {
            storage: Storage {
                cache_capacity: ByteSize::mib(100),
                ..Default::default()
            },
            ..Default::default()
        };
        let mut cache = Cache::new(Arc::new(config));
        let piece = Piece {
            number: 0,
            offset: 0,
            length: 11,
            digest: "".to_string(),
            parent_id: None,
            uploading_count: 0,
            uploaded_count: 0,
            updated_at: chrono::Utc::now().naive_utc(),
            created_at: chrono::Utc::now().naive_utc(),
            finished_at: None,
        };
        let result = cache
            .read_piece("non_existent", "piece1", piece.clone(), None)
            .await;
        assert!(matches!(result, Err(Error::TaskNotFound(_))));
        cache.put_task("task1", ByteSize::mib(50).as_u64()).await;
        let result = cache
            .read_piece("task1", "non_existent", piece.clone(), None)
            .await;
        assert!(matches!(result, Err(Error::PieceNotFound(_))));
        // Each entry: (piece id, stored content, piece metadata, [(range, expected bytes)]).
        let test_pieces = vec![
            (
                "piece1",
                b"hello world".to_vec(),
                Piece {
                    number: 0,
                    offset: 0,
                    length: 11,
                    digest: "".to_string(),
                    parent_id: None,
                    uploading_count: 0,
                    uploaded_count: 0,
                    updated_at: chrono::Utc::now().naive_utc(),
                    created_at: chrono::Utc::now().naive_utc(),
                    finished_at: None,
                },
                vec![
                    (None, b"hello world".to_vec()),
                    (
                        Some(Range {
                            start: 0,
                            length: 5,
                        }),
                        b"hello".to_vec(),
                    ),
                ],
            ),
            (
                "piece2",
                b"rust lang".to_vec(),
                Piece {
                    number: 1,
                    offset: 11,
                    length: 9,
                    digest: "".to_string(),
                    parent_id: None,
                    uploading_count: 0,
                    uploaded_count: 0,
                    updated_at: chrono::Utc::now().naive_utc(),
                    created_at: chrono::Utc::now().naive_utc(),
                    finished_at: None,
                },
                vec![
                    (None, b"rust lang".to_vec()),
                    (
                        Some(Range {
                            start: 11,
                            length: 4,
                        }),
                        b"rust".to_vec(),
                    ),
                ],
            ),
            (
                "piece3",
                b"unit test".to_vec(),
                Piece {
                    number: 2,
                    offset: 20,
                    length: 9,
                    digest: "".to_string(),
                    parent_id: None,
                    uploading_count: 0,
                    uploaded_count: 0,
                    updated_at: chrono::Utc::now().naive_utc(),
                    created_at: chrono::Utc::now().naive_utc(),
                    finished_at: None,
                },
                vec![
                    (None, b"unit test".to_vec()),
                    (
                        Some(Range {
                            start: 20,
                            length: 4,
                        }),
                        b"unit".to_vec(),
                    ),
                ],
            ),
            (
                "large_piece",
                {
                    let size = ByteSize::mib(50).as_u64();
                    (0..size).map(|i| (i % 256) as u8).collect()
                },
                Piece {
                    number: 2,
                    offset: 0,
                    length: ByteSize::mib(50).as_u64(),
                    digest: "".to_string(),
                    parent_id: None,
                    uploading_count: 0,
                    uploaded_count: 0,
                    updated_at: chrono::Utc::now().naive_utc(),
                    created_at: chrono::Utc::now().naive_utc(),
                    finished_at: None,
                },
                vec![
                    (
                        None,
                        (0..ByteSize::mib(50).as_u64())
                            .map(|i| (i % 256) as u8)
                            .collect(),
                    ),
                    (
                        Some(Range {
                            start: 0,
                            length: ByteSize::mib(1).as_u64(),
                        }),
                        (0..ByteSize::mib(1).as_u64())
                            .map(|i| (i % 256) as u8)
                            .collect(),
                    ),
                    (
                        Some(Range {
                            start: ByteSize::mib(49).as_u64(),
                            length: ByteSize::mib(1).as_u64(),
                        }),
                        (ByteSize::mib(49).as_u64()..ByteSize::mib(50).as_u64())
                            .map(|i| (i % 256) as u8)
                            .collect(),
                    ),
                ],
            ),
        ];
        for (id, content, _, _) in &test_pieces {
            cache
                .write_piece("task1", id, Bytes::copy_from_slice(content))
                .await
                .unwrap();
        }
        for (id, _, piece, ranges) in &test_pieces {
            for (range, expected_content) in ranges {
                let mut reader = cache
                    .read_piece("task1", id, piece.clone(), *range)
                    .await
                    .unwrap();
                let mut buffer = Vec::new();
                reader.read_to_end(&mut buffer).await.unwrap();
                assert_eq!(&buffer, expected_content);
            }
        }
    }

    // 50 concurrent readers of one piece through a shared `Arc<Cache>`;
    // even-numbered readers read the whole piece, odd-numbered ones its first 5 bytes.
    #[tokio::test]
    async fn test_concurrent_read_same_piece() {
        let config = Config {
            storage: Storage {
                cache_capacity: ByteSize::mib(10),
                ..Default::default()
            },
            ..Default::default()
        };
        let mut cache = Cache::new(Arc::new(config));
        cache.put_task("task1", ByteSize::mib(1).as_u64()).await;
        let content = b"test data for concurrent read".to_vec();
        cache
            .write_piece("task1", "piece1", Bytes::from(content.clone()))
            .await
            .unwrap();
        let cache_arc = Arc::new(cache);
        let mut join_set = tokio::task::JoinSet::new();
        for i in 0..50 {
            let cache_clone = cache_arc.clone();
            let expected_content = content.clone();
            join_set.spawn(async move {
                let piece = Piece {
                    number: 0,
                    offset: 0,
                    length: expected_content.len() as u64,
                    digest: "".to_string(),
                    parent_id: None,
                    uploading_count: 0,
                    uploaded_count: 0,
                    updated_at: chrono::Utc::now().naive_utc(),
                    created_at: chrono::Utc::now().naive_utc(),
                    finished_at: None,
                };
                let range = if i % 2 == 0 {
                    None
                } else {
                    Some(Range {
                        start: 0,
                        length: 5,
                    })
                };
                let mut reader = cache_clone
                    .read_piece("task1", "piece1", piece, range)
                    .await
                    .unwrap_or_else(|e| panic!("Reader {} failed: {:?}.", i, e));
                let mut buffer = Vec::new();
                reader.read_to_end(&mut buffer).await.unwrap();
                if let Some(range) = range {
                    assert_eq!(buffer, &expected_content[..range.length as usize]);
                } else {
                    assert_eq!(buffer, expected_content);
                }
            });
        }
        while let Some(result) = join_set.join_next().await {
            assert!(result.is_ok());
        }
    }

    // 50 concurrent tasks each write a distinct piece and read it back.
    #[tokio::test]
    async fn test_concurrent_write_different_pieces() {
        let config = Config {
            storage: Storage {
                cache_capacity: ByteSize::mib(10),
                ..Default::default()
            },
            ..Default::default()
        };
        let mut cache = Cache::new(Arc::new(config));
        cache.put_task("task1", ByteSize::mib(1).as_u64()).await;
        let cache_arc = Arc::new(cache);
        let mut join_set = tokio::task::JoinSet::new();
        for i in 0..50 {
            let cache_clone = cache_arc.clone();
            let content = format!("content for piece {}", i).into_bytes();
            join_set.spawn(async move {
                let piece_id = format!("piece{}", i);
                let result = cache_clone
                    .write_piece("task1", &piece_id, Bytes::from(content.clone()))
                    .await;
                assert!(result.is_ok());
                let piece = Piece {
                    number: 0,
                    offset: 0,
                    length: content.len() as u64,
                    digest: "".to_string(),
                    parent_id: None,
                    uploading_count: 0,
                    uploaded_count: 0,
                    updated_at: chrono::Utc::now().naive_utc(),
                    created_at: chrono::Utc::now().naive_utc(),
                    finished_at: None,
                };
                let mut reader = cache_clone
                    .read_piece("task1", &piece_id, piece, None)
                    .await
                    .unwrap();
                let mut buffer = Vec::new();
                reader.read_to_end(&mut buffer).await.unwrap();
                assert_eq!(buffer, content);
            });
        }
        while let Some(result) = join_set.join_next().await {
            assert!(result.is_ok());
        }
    }

    // 50 concurrent writers target an already-written piece; every write
    // succeeds and the piece still holds its original content afterwards.
    #[tokio::test]
    async fn test_concurrent_write_same_piece() {
        let config = Config {
            storage: Storage {
                cache_capacity: ByteSize::mib(10),
                ..Default::default()
            },
            ..Default::default()
        };
        let mut cache = Cache::new(Arc::new(config));
        cache.put_task("task1", ByteSize::mib(1).as_u64()).await;
        let original_content = b"original content".to_vec();
        cache
            .write_piece("task1", "piece1", Bytes::from(original_content.clone()))
            .await
            .unwrap();
        let cache_arc = Arc::new(cache);
        let mut join_set = tokio::task::JoinSet::new();
        for i in 0..50 {
            let cache_clone = cache_arc.clone();
            let new_content = format!("new content from writer {}", i).into_bytes();
            join_set.spawn(async move {
                let result = cache_clone
                    .write_piece("task1", "piece1", Bytes::from(new_content))
                    .await;
                assert!(result.is_ok());
            });
        }
        while let Some(result) = join_set.join_next().await {
            assert!(result.is_ok());
        }
        let piece = Piece {
            number: 0,
            offset: 0,
            length: original_content.len() as u64,
            digest: "".to_string(),
            parent_id: None,
            uploading_count: 0,
            uploaded_count: 0,
            updated_at: chrono::Utc::now().naive_utc(),
            created_at: chrono::Utc::now().naive_utc(),
            finished_at: None,
        };
        let mut reader = cache_arc
            .read_piece("task1", "piece1", piece, None)
            .await
            .unwrap();
        let mut buffer = Vec::new();
        reader.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, original_content);
    }
}