use std::collections::BTreeMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use bytes::Bytes;
use dashmap::DashMap;
use sha1::{Digest, Sha1};
use sha2::Sha256;
use crate::metainfo::MerkleTree;
/// Size in bytes (16 KiB) of a transfer block; pieces are split into blocks of
/// this size for request/insert bookkeeping (conventional BitTorrent block size).
pub const BLOCK_SIZE: u32 = 16384;
/// Leaf size in bytes (16 KiB) of the v2 merkle tree — one SHA-256 per
/// `MERKLE_BLOCK_SIZE` of piece data (BEP 52 fixes this leaf size).
pub const MERKLE_BLOCK_SIZE: u32 = 16384;
/// Incremental piece-hash state: either a v1 (SHA-1) or v2 (SHA-256) digest.
#[derive(Clone)]
pub enum HashState {
    V1(Sha1),
    V2(Sha256),
}

impl HashState {
    /// Starts a fresh SHA-1 state for a BitTorrent v1 piece.
    pub fn new_v1() -> Self {
        Self::V1(Sha1::new())
    }

    /// Starts a fresh SHA-256 state for a BitTorrent v2 piece.
    pub fn new_v2() -> Self {
        Self::V2(Sha256::new())
    }

    /// Absorbs `data` into whichever digest this state wraps.
    pub fn update(&mut self, data: &[u8]) {
        match self {
            Self::V1(digest) => digest.update(data),
            Self::V2(digest) => digest.update(data),
        }
    }

    /// Consumes the state and returns the digest bytes (20 for V1, 32 for V2).
    pub fn finalize(self) -> Vec<u8> {
        match self {
            Self::V1(digest) => digest.finalize().to_vec(),
            Self::V2(digest) => digest.finalize().to_vec(),
        }
    }

    /// True when this is a v1 (SHA-1) state.
    pub fn is_v1(&self) -> bool {
        matches!(self, Self::V1(_))
    }

    /// True when this is a v2 (SHA-256) state.
    pub fn is_v2(&self) -> bool {
        matches!(self, Self::V2(_))
    }
}
/// Buffered blocks of one in-flight piece, keyed by byte offset.
struct PieceBlocks {
    blocks: BTreeMap<u32, Bytes>,
    piece_length: u32,
    bytes_hashed: u32,
    #[allow(dead_code)]
    started_at: Instant,
    block_hashes: Option<Vec<[u8; 32]>>,
}

impl PieceBlocks {
    /// New v1 piece buffer: no per-block hashes are tracked.
    fn new(piece_length: u32) -> Self {
        Self {
            blocks: BTreeMap::new(),
            piece_length,
            bytes_hashed: 0,
            started_at: Instant::now(),
            block_hashes: None,
        }
    }

    /// New v2 piece buffer: pre-sizes one SHA-256 slot per merkle leaf,
    /// initialized to all-zero placeholders.
    fn new_v2(piece_length: u32) -> Self {
        let leaf_slots = piece_length.div_ceil(MERKLE_BLOCK_SIZE) as usize;
        Self {
            block_hashes: Some(vec![[0u8; 32]; leaf_slots]),
            ..Self::new(piece_length)
        }
    }

    /// True once every expected block (piece length / BLOCK_SIZE, rounded up)
    /// has been inserted. Counts distinct offsets, not bytes.
    fn is_complete(&self) -> bool {
        self.blocks.len() == self.piece_length.div_ceil(BLOCK_SIZE) as usize
    }

    /// Sum of the sizes of all buffered blocks, in bytes.
    fn total_bytes(&self) -> usize {
        self.blocks.values().map(Bytes::len).sum()
    }

    /// Concatenates the buffered blocks in ascending offset order.
    fn assemble(&self) -> Bytes {
        let mut joined = Vec::with_capacity(self.piece_length as usize);
        self.blocks
            .values()
            .for_each(|block| joined.extend_from_slice(block));
        Bytes::from(joined)
    }
}
/// Cache key: (torrent info-hash — presumably its string/hex form, set by the
/// caller — plus piece index).
type CacheKey = (String, u32);
/// Concurrent in-memory cache of partially downloaded pieces, with incremental
/// hashing so verification work is spread across block arrivals.
pub struct BlockCache {
    // Buffered blocks for each in-flight piece.
    pieces: DashMap<CacheKey, PieceBlocks>,
    // One incremental digest per in-flight piece (v1 SHA-1 or v2 SHA-256).
    hash_states: DashMap<CacheKey, HashState>,
    // Total bytes currently buffered across all pieces.
    memory_used: AtomicUsize,
    // Soft budget in bytes; advisory only (see `is_under_pressure`).
    memory_limit: usize,
}
impl BlockCache {
    /// Creates a cache wrapped in an `Arc` for sharing across tasks.
    ///
    /// `memory_limit` is a soft budget in bytes: inserts are never rejected
    /// here, but callers can poll `is_under_pressure` to back off.
    pub fn new(memory_limit: usize) -> Arc<Self> {
        Arc::new(Self {
            pieces: DashMap::new(),
            hash_states: DashMap::new(),
            memory_used: AtomicUsize::new(0),
            memory_limit,
        })
    }

    /// Inserts one block of a piece; returns `true` when every expected block
    /// of the piece is now present.
    ///
    /// `hash_version == 2` records a per-block SHA-256 for merkle verification
    /// and seeds a SHA-256 streaming state; any other value seeds SHA-1 (v1).
    /// Re-inserting an existing offset replaces the old block and the memory
    /// counter is adjusted by the size difference. Note: bytes already folded
    /// into the streaming hash are never rewound, so replacing an
    /// already-hashed block with different data leaves the v1 hash stale.
    pub fn add_block(
        &self,
        info_hash: &str,
        piece_index: u32,
        offset: u32,
        data: Bytes,
        piece_length: u32,
        hash_version: u8,
    ) -> bool {
        let key = (info_hash.to_string(), piece_index);
        let data_len = data.len();
        let is_v2 = hash_version == 2;
        let mut piece = self.pieces.entry(key.clone()).or_insert_with(|| {
            // First block of this piece: create the matching hash state too.
            let state = if is_v2 {
                HashState::new_v2()
            } else {
                HashState::new_v1()
            };
            self.hash_states.insert(key.clone(), state);
            if is_v2 {
                PieceBlocks::new_v2(piece_length)
            } else {
                PieceBlocks::new(piece_length)
            }
        });
        if is_v2 {
            if let Some(ref mut block_hashes) = piece.block_hashes {
                // One SHA-256 leaf per MERKLE_BLOCK_SIZE of data; out-of-range
                // offsets are ignored rather than panicking.
                let block_index = (offset / MERKLE_BLOCK_SIZE) as usize;
                if block_index < block_hashes.len() {
                    let mut hasher = Sha256::new();
                    hasher.update(&data);
                    block_hashes[block_index] = hasher.finalize().into();
                }
            }
        }
        match piece.blocks.insert(offset, data) {
            None => {
                self.memory_used.fetch_add(data_len, Ordering::Relaxed);
            }
            Some(old) => {
                // Duplicate offset: only the size delta changes memory usage.
                // (Previously replacements were not accounted at all, so the
                // counter drifted whenever block sizes differed.)
                if data_len >= old.len() {
                    self.memory_used
                        .fetch_add(data_len - old.len(), Ordering::Relaxed);
                } else {
                    self.memory_used
                        .fetch_sub(old.len() - data_len, Ordering::Relaxed);
                }
            }
        }
        self.try_advance_hash(&key, &mut piece);
        piece.is_complete()
    }

    /// Feeds any newly contiguous blocks (starting at `bytes_hashed`) into the
    /// piece's incremental hash state.
    fn try_advance_hash(&self, key: &CacheKey, piece: &mut PieceBlocks) {
        if let Some(mut state) = self.hash_states.get_mut(key) {
            let mut next_offset = piece.bytes_hashed;
            while let Some(block) = piece.blocks.get(&next_offset) {
                if block.is_empty() {
                    // A zero-length block would never advance the offset,
                    // turning this into an infinite loop — stop instead.
                    break;
                }
                state.update(block);
                next_offset += block.len() as u32;
            }
            piece.bytes_hashed = next_offset;
        }
    }

    /// Finalizes the streaming (v1-style) hash for a piece and compares it to
    /// `expected`.
    ///
    /// Returns `false` without consuming the hash state while contiguous data
    /// is still missing (so a later call can finish the job); otherwise the
    /// state is removed and the comparison result returned. Also returns
    /// `false` when no hash state exists for the piece.
    pub fn finalize_and_verify(&self, info_hash: &str, piece_index: u32, expected: &[u8]) -> bool {
        let key = (info_hash.to_string(), piece_index);
        if let Some(mut piece) = self.pieces.get_mut(&key) {
            self.try_advance_hash(&key, &mut piece);
            if piece.bytes_hashed != piece.piece_length {
                return false;
            }
        }
        if let Some((_, state)) = self.hash_states.remove(&key) {
            state.finalize() == expected
        } else {
            false
        }
    }

    /// Verifies a v2 piece against its expected 32-byte merkle root.
    ///
    /// `full_piece_length` is the nominal piece length; for a short final
    /// piece the recorded block hashes are padded with all-zero leaves up to
    /// the nominal leaf count before the root is computed. Requires all blocks
    /// of the piece to be present and per-block hashes to have been tracked
    /// (i.e. the piece was added with `hash_version == 2`).
    pub fn finalize_and_verify_v2(
        &self,
        info_hash: &str,
        piece_index: u32,
        expected: &[u8; 32],
        full_piece_length: u32,
    ) -> bool {
        let key = (info_hash.to_string(), piece_index);
        if let Some(piece) = self.pieces.get(&key) {
            if !piece.is_complete() {
                return false;
            }
            if let Some(ref block_hashes) = piece.block_hashes {
                let expected_blocks =
                    (full_piece_length as usize).div_ceil(MERKLE_BLOCK_SIZE as usize);
                let mut padded_hashes = block_hashes.clone();
                // Grow (never shrink) to the nominal leaf count with zero pads.
                let target = expected_blocks.max(padded_hashes.len());
                padded_hashes.resize(target, [0u8; 32]);
                let computed_root = Self::compute_merkle_root(&padded_hashes);
                return &computed_root == expected;
            }
        }
        false
    }

    /// Dispatches on the expected-hash length: 32 bytes is treated as a v2
    /// SHA-256 merkle root, anything else (e.g. a 20-byte SHA-1) as v1.
    pub fn finalize_and_verify_auto(
        &self,
        info_hash: &str,
        piece_index: u32,
        expected: &[u8],
        piece_length: u32,
    ) -> bool {
        if expected.len() == 32 {
            let mut expected_arr = [0u8; 32];
            expected_arr.copy_from_slice(expected);
            self.finalize_and_verify_v2(info_hash, piece_index, &expected_arr, piece_length)
        } else {
            self.finalize_and_verify(info_hash, piece_index, expected)
        }
    }

    /// Merkle root of the given leaf hashes; all-zero root for empty input.
    fn compute_merkle_root(block_hashes: &[[u8; 32]]) -> [u8; 32] {
        if block_hashes.is_empty() {
            return [0u8; 32];
        }
        let tree = MerkleTree::from_piece_hashes(block_hashes.to_vec());
        tree.root().unwrap_or([0u8; 32])
    }

    /// Per-block SHA-256 hashes recorded for a v2 piece, if any.
    pub fn get_block_hashes(&self, info_hash: &str, piece_index: u32) -> Option<Vec<[u8; 32]>> {
        let key = (info_hash.to_string(), piece_index);
        self.pieces.get(&key).and_then(|p| p.block_hashes.clone())
    }

    /// Assembles whatever blocks are currently buffered, in offset order.
    /// Does not require the piece to be complete.
    pub fn get_assembled_piece(&self, info_hash: &str, piece_index: u32) -> Option<Bytes> {
        let key = (info_hash.to_string(), piece_index);
        self.pieces.get(&key).map(|p| p.assemble())
    }

    /// Removes a piece (and its hash state), returning the assembled data and
    /// releasing its bytes from the memory counter.
    pub fn remove_piece(&self, info_hash: &str, piece_index: u32) -> Option<Bytes> {
        let key = (info_hash.to_string(), piece_index);
        self.hash_states.remove(&key);
        if let Some((_, piece)) = self.pieces.remove(&key) {
            let bytes_freed = piece.total_bytes();
            self.memory_used.fetch_sub(bytes_freed, Ordering::Relaxed);
            Some(piece.assemble())
        } else {
            None
        }
    }

    /// True when any block of the piece is buffered (complete or not).
    pub fn has_piece(&self, info_hash: &str, piece_index: u32) -> bool {
        let key = (info_hash.to_string(), piece_index);
        self.pieces.contains_key(&key)
    }

    /// True when every expected block of the piece has been inserted.
    pub fn is_piece_complete(&self, info_hash: &str, piece_index: u32) -> bool {
        let key = (info_hash.to_string(), piece_index);
        self.pieces
            .get(&key)
            .map(|p| p.is_complete())
            .unwrap_or(false)
    }

    /// Bytes currently buffered across all pieces.
    pub fn memory_used(&self) -> usize {
        self.memory_used.load(Ordering::Relaxed)
    }

    /// The configured soft memory budget in bytes.
    pub fn memory_limit(&self) -> usize {
        self.memory_limit
    }

    /// True when usage exceeds 90% of the budget. Uses f64 rather than f32 so
    /// the threshold stays exact for limits beyond f32's 24-bit mantissa.
    pub fn is_under_pressure(&self) -> bool {
        self.memory_used() > (self.memory_limit as f64 * 0.9) as usize
    }

    /// Number of pieces currently tracked (complete or partial).
    pub fn pieces_count(&self) -> usize {
        self.pieces.len()
    }

    /// Drops all pieces and hash states and resets the memory counter.
    pub fn clear(&self) {
        self.pieces.clear();
        self.hash_states.clear();
        self.memory_used.store(0, Ordering::Relaxed);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metainfo::hash_block;

    /// A `Bytes` buffer holding `len` copies of `byte`.
    fn repeated(byte: u8, len: usize) -> Bytes {
        Bytes::from(vec![byte; len])
    }

    /// A deterministic pattern buffer: byte `i` is `i % 256`.
    fn patterned(len: usize) -> Bytes {
        let v: Vec<u8> = (0..len).map(|i| (i % 256) as u8).collect();
        Bytes::from(v)
    }

    #[test]
    fn test_block_cache_basic_operations() {
        let cache = BlockCache::new(64 * 1024 * 1024);
        let done = cache.add_block("test_hash", 0, 0, repeated(0, BLOCK_SIZE as usize), BLOCK_SIZE, 1);
        assert!(done);
        assert!(cache.is_piece_complete("test_hash", 0));
        assert!(cache.has_piece("test_hash", 0));
        let assembled = cache.get_assembled_piece("test_hash", 0);
        assert!(assembled.is_some());
        assert_eq!(assembled.unwrap().len(), BLOCK_SIZE as usize);
        assert!(cache.remove_piece("test_hash", 0).is_some());
        assert!(!cache.has_piece("test_hash", 0));
    }

    #[test]
    fn test_block_cache_multi_block_piece() {
        let cache = BlockCache::new(64 * 1024 * 1024);
        let piece_length = BLOCK_SIZE * 4;
        // Deliver the blocks out of order; only the last one completes the piece.
        assert!(!cache.add_block("test", 0, BLOCK_SIZE * 2, repeated(2, BLOCK_SIZE as usize), piece_length, 1));
        assert!(!cache.add_block("test", 0, 0, repeated(0, BLOCK_SIZE as usize), piece_length, 1));
        assert!(!cache.add_block("test", 0, BLOCK_SIZE * 3, repeated(3, BLOCK_SIZE as usize), piece_length, 1));
        assert!(cache.add_block("test", 0, BLOCK_SIZE, repeated(1, BLOCK_SIZE as usize), piece_length, 1));
        assert!(cache.is_piece_complete("test", 0));
        let piece = cache.get_assembled_piece("test", 0).unwrap();
        assert_eq!(piece.len(), piece_length as usize);
        // Assembly must be in offset order despite the arrival order.
        assert!(piece[0..BLOCK_SIZE as usize].iter().all(|&b| b == 0));
        assert!(piece[BLOCK_SIZE as usize..(BLOCK_SIZE * 2) as usize]
            .iter()
            .all(|&b| b == 1));
    }

    #[test]
    fn test_block_cache_v1_verification() {
        let cache = BlockCache::new(64 * 1024 * 1024);
        let payload = patterned(BLOCK_SIZE as usize);
        let expected_hash = {
            use sha1::{Digest, Sha1};
            let mut hasher = Sha1::new();
            hasher.update(&payload);
            hasher.finalize().to_vec()
        };
        cache.add_block("v1test", 0, 0, payload, BLOCK_SIZE, 1);
        assert!(cache.finalize_and_verify("v1test", 0, &expected_hash));
    }

    #[test]
    fn test_block_cache_v1_verification_fails_wrong_hash() {
        let cache = BlockCache::new(64 * 1024 * 1024);
        cache.add_block("v1test", 0, 0, repeated(42, BLOCK_SIZE as usize), BLOCK_SIZE, 1);
        let wrong_hash = vec![0u8; 20];
        assert!(!cache.finalize_and_verify("v1test", 0, &wrong_hash));
    }

    #[test]
    fn test_block_cache_v2_single_block() {
        let cache = BlockCache::new(64 * 1024 * 1024);
        let payload = patterned(BLOCK_SIZE as usize);
        // A single-block piece's merkle root is just that block's hash.
        let expected_root = hash_block(&payload);
        cache.add_block("v2test", 0, 0, payload, BLOCK_SIZE, 2);
        assert!(cache.finalize_and_verify_v2("v2test", 0, &expected_root, BLOCK_SIZE));
        assert!(cache.finalize_and_verify_auto("v2test", 0, &expected_root, BLOCK_SIZE));
    }

    #[test]
    fn test_block_cache_v2_multi_block() {
        let cache = BlockCache::new(64 * 1024 * 1024);
        let piece_length = BLOCK_SIZE * 4;
        // Four constant-filled blocks: 0x00, 0x01, 0x02, 0x03.
        let blocks: Vec<Vec<u8>> = (0..4u8).map(|v| vec![v; BLOCK_SIZE as usize]).collect();
        let leaves: Vec<[u8; 32]> = blocks.iter().map(|b| hash_block(b)).collect();
        let expected_root = MerkleTree::from_piece_hashes(leaves).root().unwrap();
        for (i, block) in blocks.into_iter().enumerate() {
            cache.add_block("v2multi", 0, BLOCK_SIZE * i as u32, Bytes::from(block), piece_length, 2);
        }
        assert!(cache.is_piece_complete("v2multi", 0));
        assert!(cache.finalize_and_verify_v2("v2multi", 0, &expected_root, piece_length));
    }

    #[test]
    fn test_block_cache_v2_partial_piece() {
        let cache = BlockCache::new(64 * 1024 * 1024);
        let full_piece_length = BLOCK_SIZE * 4;
        let actual_data_len = BLOCK_SIZE + 1000;
        let first = vec![0xAA; BLOCK_SIZE as usize];
        let tail = vec![0xBB; 1000usize];
        // Short final piece: the absent leaves are zero-padded in the tree.
        let expected_root = MerkleTree::from_piece_hashes(vec![
            hash_block(&first),
            hash_block(&tail),
            [0u8; 32],
            [0u8; 32],
        ])
        .root()
        .unwrap();
        cache.add_block("partial", 0, 0, Bytes::from(first), actual_data_len, 2);
        cache.add_block("partial", 0, BLOCK_SIZE, Bytes::from(tail), actual_data_len, 2);
        assert!(cache.is_piece_complete("partial", 0));
        assert!(cache.finalize_and_verify_v2("partial", 0, &expected_root, full_piece_length));
    }

    #[test]
    fn test_block_cache_v2_verification_fails_wrong_hash() {
        let cache = BlockCache::new(64 * 1024 * 1024);
        cache.add_block("v2wrong", 0, 0, repeated(42, BLOCK_SIZE as usize), BLOCK_SIZE, 2);
        let wrong_root = [0xFFu8; 32];
        assert!(!cache.finalize_and_verify_v2("v2wrong", 0, &wrong_root, BLOCK_SIZE));
    }

    #[test]
    fn test_block_cache_get_block_hashes() {
        let cache = BlockCache::new(64 * 1024 * 1024);
        let first = repeated(0, BLOCK_SIZE as usize);
        let second = repeated(1, BLOCK_SIZE as usize);
        let piece_length = BLOCK_SIZE * 2;
        cache.add_block("hashes", 0, 0, first.clone(), piece_length, 2);
        cache.add_block("hashes", 0, BLOCK_SIZE, second.clone(), piece_length, 2);
        let block_hashes = cache.get_block_hashes("hashes", 0);
        assert!(block_hashes.is_some());
        let block_hashes = block_hashes.unwrap();
        assert_eq!(block_hashes.len(), 2);
        assert_eq!(block_hashes[0], hash_block(&first));
        assert_eq!(block_hashes[1], hash_block(&second));
    }

    #[test]
    fn test_block_cache_memory_tracking() {
        let cache = BlockCache::new(64 * 1024 * 1024);
        assert_eq!(cache.memory_used(), 0);
        cache.add_block("mem", 0, 0, repeated(0, BLOCK_SIZE as usize), BLOCK_SIZE, 1);
        assert_eq!(cache.memory_used(), BLOCK_SIZE as usize);
        cache.remove_piece("mem", 0);
        assert_eq!(cache.memory_used(), 0);
    }

    #[test]
    fn test_block_cache_clear() {
        let cache = BlockCache::new(64 * 1024 * 1024);
        for torrent in ["p1", "p2", "p3"] {
            cache.add_block(torrent, 0, 0, repeated(0, BLOCK_SIZE as usize), BLOCK_SIZE, 1);
        }
        assert_eq!(cache.pieces_count(), 3);
        cache.clear();
        assert_eq!(cache.pieces_count(), 0);
        assert_eq!(cache.memory_used(), 0);
    }

    #[test]
    fn test_block_cache_finalize_and_verify_auto() {
        let cache_v1 = BlockCache::new(64 * 1024 * 1024);
        let cache_v2 = BlockCache::new(64 * 1024 * 1024);
        let payload = repeated(42, BLOCK_SIZE as usize);
        let v1_hash = {
            use sha1::{Digest, Sha1};
            let mut hasher = Sha1::new();
            hasher.update(&payload);
            hasher.finalize().to_vec()
        };
        let v2_hash = hash_block(&payload);
        cache_v1.add_block("auto_v1", 0, 0, payload.clone(), BLOCK_SIZE, 1);
        cache_v2.add_block("auto_v2", 0, 0, payload, BLOCK_SIZE, 2);
        // 20-byte hash routes to v1, 32-byte hash routes to v2.
        assert!(cache_v1.finalize_and_verify_auto("auto_v1", 0, &v1_hash, BLOCK_SIZE));
        assert!(cache_v2.finalize_and_verify_auto("auto_v2", 0, &v2_hash, BLOCK_SIZE));
    }

    #[test]
    fn test_compute_merkle_root_consistency() {
        // Leaves whose first byte is the leaf index; the cache's private root
        // helper must agree with MerkleTree.
        let hashes: Vec<[u8; 32]> = (0..4u8)
            .map(|i| {
                let mut leaf = [0u8; 32];
                leaf[0] = i;
                leaf
            })
            .collect();
        let tree_root = MerkleTree::from_piece_hashes(hashes.clone()).root().unwrap();
        assert_eq!(tree_root, BlockCache::compute_merkle_root(&hashes));
    }
}