use rustc_hash::FxHashMap;
use irontide_core::Lengths;
use crate::Bitfield;
/// Tracks download progress at piece and chunk granularity.
///
/// Pieces marked verified are recorded in `have`. Pieces with some but not all
/// chunks received keep a per-chunk bitfield in `in_progress` until they are marked
/// verified or failed. Optional per-block verification state is kept in
/// `block_verified` once `enable_v2_tracking` has been called.
pub struct ChunkTracker {
    /// Piece/chunk layout used to size the per-piece bitfields.
    lengths: Lengths,
    /// One bit per piece; set by `mark_verified`, cleared by `clear_piece`/`clear`.
    have: Bitfield,
    /// Piece index -> bitfield of chunks received so far for that piece.
    in_progress: FxHashMap<u32, Bitfield>,
    /// Piece index -> bitfield of blocks marked verified; `None` while v2 tracking
    /// is disabled.
    block_verified: Option<FxHashMap<u32, Bitfield>>,
}
impl ChunkTracker {
pub fn new(lengths: Lengths) -> Self {
let have = Bitfield::new(lengths.num_pieces());
ChunkTracker {
have,
in_progress: FxHashMap::default(),
lengths,
block_verified: None,
}
}
pub fn from_bitfield(have: Bitfield, lengths: Lengths) -> Self {
ChunkTracker {
have,
in_progress: FxHashMap::default(),
lengths,
block_verified: None,
}
}
pub fn chunk_received(&mut self, piece: u32, begin: u32) -> bool {
if piece >= self.lengths.num_pieces() {
return false;
}
let num_chunks = self.lengths.chunks_in_piece(piece);
let chunk_index = begin / self.lengths.chunk_size();
let chunk_bf = self
.in_progress
.entry(piece)
.or_insert_with(|| Bitfield::new(num_chunks));
chunk_bf.set(chunk_index);
chunk_bf.all_set()
}
pub fn mark_verified(&mut self, piece: u32) {
self.have.set(piece);
self.in_progress.remove(&piece);
}
pub fn mark_failed(&mut self, piece: u32) {
self.in_progress.remove(&piece);
if let Some(ref mut bv) = self.block_verified {
bv.remove(&piece);
}
}
pub fn clear_piece(&mut self, piece: u32) {
self.have.clear(piece);
}
pub fn has_chunk(&self, piece: u32, begin: u32) -> bool {
if self.have.get(piece) {
return true;
}
let chunk_index = begin / self.lengths.chunk_size();
self.in_progress
.get(&piece)
.is_some_and(|bf| bf.get(chunk_index))
}
pub fn has_piece(&self, piece: u32) -> bool {
self.have.get(piece)
}
pub fn bitfield(&self) -> &Bitfield {
&self.have
}
pub fn missing_chunks_into(&self, piece: u32, out: &mut Vec<(u32, u32)>) {
out.clear();
if self.have.get(piece) {
return;
}
let num_chunks = self.lengths.chunks_in_piece(piece);
match self.in_progress.get(&piece) {
Some(bf) => {
out.extend(
bf.zeros()
.filter_map(|ci| self.lengths.chunk_info(piece, ci)),
);
}
None => {
out.extend((0..num_chunks).filter_map(|ci| self.lengths.chunk_info(piece, ci)));
}
}
}
pub fn missing_chunks(&self, piece: u32) -> Vec<(u32, u32)> {
let mut out = Vec::new();
self.missing_chunks_into(piece, &mut out);
out
}
pub fn clear(&mut self) {
self.have = Bitfield::new(self.have.len());
self.in_progress.clear();
if let Some(ref mut bv) = self.block_verified {
bv.clear();
}
}
pub fn enable_v2_tracking(&mut self) {
self.block_verified = Some(FxHashMap::default());
}
pub fn has_v2_tracking(&self) -> bool {
self.block_verified.is_some()
}
pub fn mark_block_verified(&mut self, piece: u32, block_index: u32) {
if let Some(ref mut bv) = self.block_verified {
let num_chunks = self.lengths.chunks_in_piece(piece);
let bf = bv.entry(piece).or_insert_with(|| Bitfield::new(num_chunks));
bf.set(block_index);
}
}
pub fn is_block_verified(&self, piece: u32, block_index: u32) -> bool {
self.block_verified
.as_ref()
.and_then(|bv| bv.get(&piece))
.is_some_and(|bf| bf.get(block_index))
}
pub fn all_blocks_verified(&self, piece: u32) -> bool {
let Some(ref bv) = self.block_verified else {
return false;
};
let num_chunks = self.lengths.chunks_in_piece(piece);
bv.get(&piece)
.is_some_and(|bf| bf.count_ones() == num_chunks)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Chunk size used by the fixture (matches the third `Lengths::new` argument).
    const CHUNK: u32 = 16384;

    /// Chunks per piece in the fixture, as the `missing_chunks` assertions rely on.
    const CHUNKS_PER_PIECE: u32 = 4;

    /// Shared fixture: `Lengths::new(100000, 50000, 16384)`.
    fn make_tracker() -> ChunkTracker {
        ChunkTracker::new(Lengths::new(100000, 50000, 16384))
    }

    /// Delivers every chunk of `piece` in order, discarding the completion flags.
    fn receive_whole_piece(tracker: &mut ChunkTracker, piece: u32) {
        for i in 0..CHUNKS_PER_PIECE {
            tracker.chunk_received(piece, i * CHUNK);
        }
    }

    #[test]
    fn new_all_missing() {
        let tracker = make_tracker();
        assert!(!tracker.has_piece(0));
        assert!(!tracker.has_piece(1));
        assert_eq!(tracker.bitfield().count_ones(), 0);
    }

    #[test]
    fn chunk_received() {
        let mut tracker = make_tracker();
        assert!(!tracker.chunk_received(0, 0));
        assert!(tracker.has_chunk(0, 0));
        assert!(!tracker.has_chunk(0, CHUNK));
    }

    #[test]
    fn piece_complete() {
        let mut tracker = make_tracker();
        let flags: Vec<bool> = (0..CHUNKS_PER_PIECE)
            .map(|i| tracker.chunk_received(0, i * CHUNK))
            .collect();
        assert_eq!(flags, [false, false, false, true]);
    }

    #[test]
    fn mark_verified() {
        let mut tracker = make_tracker();
        receive_whole_piece(&mut tracker, 0);
        tracker.mark_verified(0);
        assert!(tracker.has_piece(0));
        assert!(tracker.has_chunk(0, 0));
        assert_eq!(tracker.bitfield().count_ones(), 1);
    }

    #[test]
    fn mark_failed_resets() {
        let mut tracker = make_tracker();
        for i in 0..2 {
            tracker.chunk_received(0, i * CHUNK);
        }
        tracker.mark_failed(0);
        assert!(!tracker.has_piece(0));
        assert!(!tracker.has_chunk(0, 0));
        assert_eq!(tracker.missing_chunks(0).len(), 4);
    }

    #[test]
    fn has_chunk() {
        let mut tracker = make_tracker();
        assert!(!tracker.has_chunk(0, 0));
        tracker.chunk_received(0, 0);
        assert!(tracker.has_chunk(0, 0));
        assert!(!tracker.has_chunk(0, CHUNK));
    }

    #[test]
    fn missing_chunks() {
        let mut tracker = make_tracker();
        let initially = tracker.missing_chunks(0);
        assert_eq!(initially.len(), 4);
        assert_eq!(initially[..2], [(0, CHUNK), (CHUNK, CHUNK)]);

        tracker.chunk_received(0, 0);
        let after_first = tracker.missing_chunks(0);
        assert_eq!(after_first.len(), 3);
        assert_eq!(after_first[0], (CHUNK, CHUNK));
    }

    #[test]
    fn from_bitfield() {
        let lengths = Lengths::new(100000, 50000, 16384);
        let mut have = Bitfield::new(2);
        have.set(0);
        let tracker = ChunkTracker::from_bitfield(have, lengths);
        assert!(tracker.has_piece(0));
        assert!(!tracker.has_piece(1));
        assert!(tracker.missing_chunks(0).is_empty());
    }

    #[test]
    fn clear_piece_removes_from_have() {
        let mut tracker = make_tracker();
        tracker.mark_verified(0);
        assert!(tracker.has_piece(0));
        tracker.clear_piece(0);
        assert!(!tracker.has_piece(0));
        assert_eq!(tracker.bitfield().count_ones(), 0);
    }

    #[test]
    fn v2_tracking_disabled_by_default() {
        let tracker = make_tracker();
        assert!(!tracker.is_block_verified(0, 0));
        assert!(!tracker.all_blocks_verified(0));
    }

    #[test]
    fn enable_v2_and_mark_blocks() {
        let mut tracker = make_tracker();
        tracker.enable_v2_tracking();
        assert!(!tracker.is_block_verified(0, 0));
        tracker.mark_block_verified(0, 0);
        assert!(tracker.is_block_verified(0, 0));
        assert!(!tracker.is_block_verified(0, 1));
    }

    #[test]
    fn all_blocks_verified_complete() {
        let mut tracker = make_tracker();
        tracker.enable_v2_tracking();
        (0..CHUNKS_PER_PIECE).for_each(|block| tracker.mark_block_verified(0, block));
        assert!(tracker.all_blocks_verified(0));
    }

    #[test]
    fn all_blocks_verified_incomplete() {
        let mut tracker = make_tracker();
        tracker.enable_v2_tracking();
        tracker.mark_block_verified(0, 0);
        tracker.mark_block_verified(0, 2);
        assert!(!tracker.all_blocks_verified(0));
    }

    #[test]
    fn mark_failed_clears_v2_state() {
        let mut tracker = make_tracker();
        tracker.enable_v2_tracking();
        tracker.mark_block_verified(0, 0);
        tracker.mark_block_verified(0, 1);
        tracker.mark_failed(0);
        assert!(!tracker.is_block_verified(0, 0));
    }

    #[test]
    fn clear_resets_all_state() {
        let mut tracker = make_tracker();
        tracker.enable_v2_tracking();
        receive_whole_piece(&mut tracker, 0);
        tracker.mark_verified(0);
        assert!(tracker.has_piece(0));
        tracker.chunk_received(1, 0);
        tracker.mark_block_verified(1, 0);

        tracker.clear();

        assert!(!tracker.has_piece(0), "have bitfield should be cleared");
        assert_eq!(tracker.bitfield().count_ones(), 0, "no pieces should be marked");
        assert!(!tracker.has_chunk(1, 0), "in_progress should be cleared");
        assert!(
            !tracker.is_block_verified(1, 0),
            "block_verified should be cleared"
        );
        assert_eq!(tracker.missing_chunks(0).len(), 4);
        assert_eq!(tracker.missing_chunks(1).len(), 4);
    }
}