use std::collections::HashMap;
use super::messages::{Dii, DownloadDataBlock};
/// Identifies one version of one module inside a download.
///
/// This is the `(download_id, module_id, module_version)` triple that both DII module
/// entries and `DownloadDataBlock`s carry, and it is the lookup key used while reassembling.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ModuleKey {
    /// Download identifier shared by the DII and its DDBs.
    pub download_id: u32,
    /// Module identifier within the download.
    pub module_id: u16,
    /// Module version as announced in the DII.
    pub module_version: u8,
}
/// A completely reassembled module, handed out once its last missing block arrives.
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct Module {
    /// Which module/version this payload belongs to.
    pub key: ModuleKey,
    /// The module payload; its length is the module size announced in the DII.
    pub data: Vec<u8>,
}
/// Reassembly state for one module that is still missing blocks.
struct Slot {
    /// Length of every block except possibly the last (taken from the DII).
    block_size: usize,
    /// Number of blocks the module is split into.
    n_blocks: usize,
    /// Blocks not yet received; the module is complete when this reaches zero.
    remaining: usize,
    /// Output buffer, pre-sized to the announced module size.
    data: Vec<u8>,
    /// One bit per block number, 64 blocks per word.
    received: Vec<u64>,
}

impl Slot {
    /// Maps block number `n` to its word index in `received` and the bit mask inside that word.
    fn bit_position(n: usize) -> (usize, u64) {
        (n / 64, 1u64 << (n % 64))
    }

    /// Returns `true` if block `n` has already been stored.
    ///
    /// # Panics
    /// Panics if `n / 64` is not a valid index into `received`.
    fn is_received(&self, n: usize) -> bool {
        let (word, mask) = Self::bit_position(n);
        self.received[word] & mask != 0
    }

    /// Records block `n` as stored.
    ///
    /// # Panics
    /// Panics if `n / 64` is not a valid index into `received`.
    fn mark_received(&mut self, n: usize) {
        let (word, mask) = Self::bit_position(n);
        self.received[word] |= mask;
    }
}
/// Default per-module size cap in bytes (64 MiB); larger announced modules are not tracked.
pub const DEFAULT_MAX_MODULE_SIZE: u32 = 64 << 20;
/// Default cap in bytes (256 MiB) on the combined buffer size of all in-progress modules.
pub const DEFAULT_MAX_TOTAL_BYTES: usize = 256 << 20;
/// Reassembles modules announced by DII messages from their `DownloadDataBlock` payloads.
///
/// Memory is bounded twice: once per module (`max_module_size`) and once across all
/// in-progress modules (`max_total_bytes`).
pub struct ModuleReassembler {
    /// Per-module cap in bytes; modules announced larger than this are not tracked.
    max_module_size: u32,
    /// Upper bound enforced on `total_bytes` when a new module is accepted.
    max_total_bytes: usize,
    /// Sum of the buffer lengths of every slot currently in `slots`.
    total_bytes: usize,
    /// In-progress modules, keyed by download id, module id and version.
    slots: HashMap<ModuleKey, Slot>,
}
impl Default for ModuleReassembler {
fn default() -> Self {
Self::new()
}
}
impl ModuleReassembler {
    /// Largest number of blocks a module can be delivered in.
    ///
    /// DDB block numbers are `u16`, so only blocks `0..=u16::MAX` can ever arrive. A module
    /// needing more blocks could never complete and would pin its budget forever.
    const MAX_BLOCKS_PER_MODULE: usize = u16::MAX as usize + 1;

    /// Creates a reassembler with [`DEFAULT_MAX_MODULE_SIZE`] and [`DEFAULT_MAX_TOTAL_BYTES`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_limits(DEFAULT_MAX_MODULE_SIZE, DEFAULT_MAX_TOTAL_BYTES)
    }

    /// Creates a reassembler with a custom per-module cap and the default aggregate budget.
    #[must_use]
    pub fn with_max_module_size(max_module_size: u32) -> Self {
        Self::with_limits(max_module_size, DEFAULT_MAX_TOTAL_BYTES)
    }

    /// Creates a reassembler with explicit limits.
    ///
    /// * `max_module_size`: modules announced larger than this many bytes are not tracked.
    /// * `max_total_bytes`: a module is only accepted if the summed buffer size of all
    ///   in-progress modules stays at or below this value.
    #[must_use]
    pub fn with_limits(max_module_size: u32, max_total_bytes: usize) -> Self {
        Self {
            slots: HashMap::new(),
            max_module_size,
            max_total_bytes,
            total_bytes: 0,
        }
    }

    /// Registers the modules announced by a DII so their DDBs can be collected.
    ///
    /// For every module entry in `dii`:
    /// * Slots for the same `(download_id, module_id)` with a *different* version are
    ///   dropped and their bytes released from the budget.
    /// * A module already being collected at the same version keeps its progress.
    /// * A new module is skipped (not tracked) if any of these hold:
    ///   - it is larger than `max_module_size`;
    ///   - it needs more blocks than a `u16` block number can address;
    ///   - accepting it would push the aggregate buffered size past `max_total_bytes`.
    ///
    /// A DII with `block_size == 0` is ignored entirely.
    pub fn note_dii(&mut self, dii: &Dii<'_>) {
        if dii.block_size == 0 {
            return;
        }
        let block_size = usize::from(dii.block_size);
        for m in &dii.modules {
            let key = ModuleKey {
                download_id: dii.download_id,
                module_id: m.module_id,
                module_version: m.module_version,
            };
            // Evict before filtering. The old-version slot is presumably dead once a new
            // version is announced, even if that new version is rejected below, and it would
            // otherwise keep holding part of the aggregate budget.
            self.evict_other_versions(key);
            if m.module_size > self.max_module_size || self.slots.contains_key(&key) {
                continue;
            }
            let Ok(size) = usize::try_from(m.module_size) else {
                continue;
            };
            let n_blocks = size.div_ceil(block_size).max(1);
            if n_blocks > Self::MAX_BLOCKS_PER_MODULE {
                continue;
            }
            // checked_add: a near-`usize::MAX` budget must not wrap and bypass the limit.
            let within_budget = self
                .total_bytes
                .checked_add(size)
                .is_some_and(|total| total <= self.max_total_bytes);
            if !within_budget {
                continue;
            }
            self.total_bytes += size;
            self.slots.insert(
                key,
                Slot {
                    block_size,
                    data: vec![0u8; size],
                    received: vec![0u64; n_blocks.div_ceil(64)],
                    n_blocks,
                    remaining: n_blocks,
                },
            );
        }
    }

    /// Drops every slot belonging to `key`'s `(download_id, module_id)` whose version differs
    /// from `key.module_version`, releasing its bytes from `total_bytes`.
    fn evict_other_versions(&mut self, key: ModuleKey) {
        let total_bytes = &mut self.total_bytes;
        self.slots.retain(|k, slot| {
            let stale = k.download_id == key.download_id
                && k.module_id == key.module_id
                && k.module_version != key.module_version;
            if stale {
                *total_bytes -= slot.data.len();
            }
            !stale
        });
    }

    /// Stores one DDB payload and returns the module once its last missing block arrives.
    ///
    /// Returns `None`, leaving all state unchanged, when any of these hold:
    /// * the DDB matches no module registered by [`note_dii`](Self::note_dii) at the same
    ///   version;
    /// * its block number is out of range or was already received;
    /// * its payload length differs from the expected block length (`block_size`, or the
    ///   remainder for the final block).
    ///
    /// Otherwise returns `None` until the module is complete. On completion the slot is
    /// removed, its bytes are released from the budget, and the module is returned.
    pub fn feed_ddb(&mut self, ddb: &DownloadDataBlock<'_>) -> Option<Module> {
        let key = ModuleKey {
            download_id: ddb.download_id,
            module_id: ddb.module_id,
            module_version: ddb.module_version,
        };
        let slot = self.slots.get_mut(&key)?;
        let n = usize::from(ddb.block_number);
        if n >= slot.n_blocks || slot.is_received(n) {
            return None;
        }
        // n < n_blocks = ceil(len / block_size), so offset <= len and the subtraction is safe.
        let offset = n * slot.block_size;
        let expected = (slot.data.len() - offset).min(slot.block_size);
        if ddb.block_data.len() != expected {
            return None;
        }
        slot.data[offset..offset + expected].copy_from_slice(ddb.block_data);
        slot.mark_received(n);
        slot.remaining -= 1;
        if slot.remaining > 0 {
            return None;
        }
        let slot = self
            .slots
            .remove(&key)
            .expect("slot was found by get_mut above and nothing removed it since");
        self.total_bytes -= slot.data.len();
        Some(Module {
            key,
            data: slot.data,
        })
    }

    /// Number of modules currently being reassembled.
    #[must_use]
    pub fn pending(&self) -> usize {
        self.slots.len()
    }

    /// Bytes reserved by in-progress modules (the sum of their announced sizes).
    #[must_use]
    pub fn pending_bytes(&self) -> usize {
        self.total_bytes
    }
}
#[cfg(test)]
mod tests {
    use super::super::messages::DiiModule;
    use super::*;

    /// Builds a DII for `download_id` with the given block size and module entries;
    /// every other field is zero or empty.
    fn dii(download_id: u32, block_size: u16, modules: Vec<DiiModule<'static>>) -> Dii<'static> {
        Dii {
            transaction_id: 0x8000_0002,
            adaptation: &[],
            download_id,
            block_size,
            window_size: 0,
            ack_period: 0,
            t_c_download_window: 0,
            t_c_download_scenario: 0,
            compatibility_descriptor: &[],
            modules,
            private_data: &[],
        }
    }

    /// Builds a DII module entry with no module info.
    fn module(module_id: u16, module_size: u32, module_version: u8) -> DiiModule<'static> {
        DiiModule {
            module_id,
            module_size,
            module_version,
            module_info: &[],
        }
    }

    /// Builds a DDB carrying `block_data` as block `block_number` of the given module.
    fn ddb(
        download_id: u32,
        module_id: u16,
        module_version: u8,
        block_number: u16,
        block_data: &[u8],
    ) -> DownloadDataBlock<'_> {
        DownloadDataBlock {
            download_id,
            adaptation: &[],
            module_id,
            module_version,
            block_number,
            block_data,
        }
    }

    #[test]
    fn two_block_module_completes() {
        let mut reasm = ModuleReassembler::new();
        reasm.note_dii(&dii(1, 4, vec![module(7, 6, 0)]));
        assert!(reasm.feed_ddb(&ddb(1, 7, 0, 0, &[1, 2, 3, 4])).is_none());
        let done = reasm.feed_ddb(&ddb(1, 7, 0, 1, &[5, 6])).expect("complete");
        assert_eq!(done.key.module_id, 7);
        assert_eq!(done.data, vec![1, 2, 3, 4, 5, 6]);
        assert_eq!(reasm.pending(), 0);
    }

    #[test]
    fn out_of_order_blocks_complete() {
        let mut reasm = ModuleReassembler::new();
        reasm.note_dii(&dii(1, 2, vec![module(1, 4, 0)]));
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 1, &[3, 4])).is_none());
        let done = reasm.feed_ddb(&ddb(1, 1, 0, 0, &[1, 2])).expect("complete");
        assert_eq!(done.data, vec![1, 2, 3, 4]);
    }

    #[test]
    fn ddb_before_dii_is_ignored() {
        let mut reasm = ModuleReassembler::new();
        // No DII yet, so there is no slot to put this block into.
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 0, &[1, 2])).is_none());
        reasm.note_dii(&dii(1, 2, vec![module(1, 2, 0)]));
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 0, &[1, 2])).is_some());
    }

    #[test]
    fn version_mismatch_ignored_and_new_version_restarts() {
        let mut reasm = ModuleReassembler::new();
        reasm.note_dii(&dii(1, 2, vec![module(1, 4, 0)]));
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 0, &[1, 2])).is_none());
        // A block for an unannounced version must not complete the version-0 module.
        assert!(reasm.feed_ddb(&ddb(1, 1, 3, 1, &[9, 9])).is_none());
        reasm.note_dii(&dii(1, 2, vec![module(1, 4, 3)]));
        assert_eq!(reasm.pending(), 1);
        assert!(reasm.feed_ddb(&ddb(1, 1, 3, 0, &[5, 6])).is_none());
        let done = reasm.feed_ddb(&ddb(1, 1, 3, 1, &[7, 8])).expect("complete");
        assert_eq!(done.key.module_version, 3);
        assert_eq!(done.data, vec![5, 6, 7, 8]);
    }

    #[test]
    fn repeated_dii_keeps_progress() {
        let mut reasm = ModuleReassembler::new();
        let announcement = dii(1, 2, vec![module(1, 4, 0)]);
        reasm.note_dii(&announcement);
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 0, &[1, 2])).is_none());
        // Re-announcing the same version must not reset the block already received.
        reasm.note_dii(&announcement);
        let done = reasm.feed_ddb(&ddb(1, 1, 0, 1, &[3, 4])).expect("complete");
        assert_eq!(done.data, vec![1, 2, 3, 4]);
    }

    #[test]
    fn duplicate_and_out_of_range_blocks_ignored() {
        let mut reasm = ModuleReassembler::new();
        reasm.note_dii(&dii(1, 2, vec![module(1, 4, 0)]));
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 0, &[1, 2])).is_none());
        // Duplicate of block 0.
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 0, &[1, 2])).is_none());
        // Block 9 is beyond the two blocks this module has.
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 9, &[9, 9])).is_none());
        assert_eq!(reasm.pending(), 1);
    }

    #[test]
    fn wrong_block_length_ignored() {
        let mut reasm = ModuleReassembler::new();
        reasm.note_dii(&dii(1, 4, vec![module(1, 6, 0)]));
        // Block 0 must be exactly 4 bytes and block 1 exactly 2 bytes.
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 0, &[1, 2, 3])).is_none());
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 1, &[5, 6, 7])).is_none());
        assert_eq!(reasm.pending(), 1);
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 0, &[1, 2, 3, 4])).is_none());
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 1, &[5, 6])).is_some());
    }

    #[test]
    fn oversize_module_skipped() {
        let mut reasm = ModuleReassembler::with_max_module_size(8);
        reasm.note_dii(&dii(1, 4, vec![module(1, 9, 0), module(2, 8, 0)]));
        // Only the 8-byte module fits under the cap.
        assert_eq!(reasm.pending(), 1);
    }

    #[test]
    fn zero_block_size_skipped() {
        let mut reasm = ModuleReassembler::new();
        reasm.note_dii(&dii(1, 0, vec![module(1, 4, 0)]));
        assert_eq!(reasm.pending(), 0);
    }

    #[test]
    fn aggregate_budget_bounds_total_memory() {
        let mut reasm = ModuleReassembler::with_limits(8, 10);
        reasm.note_dii(&dii(1, 4, vec![module(1, 8, 0)]));
        assert_eq!(reasm.pending_bytes(), 8);
        // A second 8-byte module would exceed the 10-byte budget.
        reasm.note_dii(&dii(2, 4, vec![module(1, 8, 0)]));
        assert_eq!(reasm.pending(), 1);
        assert_eq!(reasm.pending_bytes(), 8);
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 0, &[0; 4])).is_none());
        assert!(reasm.feed_ddb(&ddb(1, 1, 0, 1, &[0; 4])).is_some());
        assert_eq!(reasm.pending_bytes(), 0);
        // Completing the first module freed its budget, so the second now fits.
        reasm.note_dii(&dii(2, 4, vec![module(1, 8, 0)]));
        assert_eq!(reasm.pending(), 1);
        assert_eq!(reasm.pending_bytes(), 8);
    }

    #[test]
    fn version_replacement_releases_budget() {
        let mut reasm = ModuleReassembler::with_limits(8, 8);
        reasm.note_dii(&dii(1, 4, vec![module(1, 8, 0)]));
        assert_eq!(reasm.pending_bytes(), 8);
        // Version 1 only fits because version 0's bytes were released first.
        reasm.note_dii(&dii(1, 4, vec![module(1, 8, 1)]));
        assert_eq!(reasm.pending(), 1);
        assert_eq!(reasm.pending_bytes(), 8);
    }

    #[test]
    fn block_size_one_uses_bitset() {
        let mut reasm = ModuleReassembler::new();
        reasm.note_dii(&dii(1, 1, vec![module(1, 130, 0)]));
        // 130 one-byte blocks span three bitset words.
        for block in 0..129u16 {
            let payload = [block as u8];
            // The first delivery is accepted and the duplicate is ignored; neither completes it.
            assert!(reasm.feed_ddb(&ddb(1, 1, 0, block, &payload)).is_none());
            assert!(reasm.feed_ddb(&ddb(1, 1, 0, block, &payload)).is_none());
        }
        let done = reasm.feed_ddb(&ddb(1, 1, 0, 129, &[0x81])).expect("complete");
        assert_eq!(done.data.len(), 130);
        assert_eq!(done.data[129], 0x81);
    }
}