use std::borrow::Cow;
use std::collections::HashMap;
use std::ops::{Range, RangeBounds};
use std::time::{Duration, Instant};
use bitvec::vec::BitVec;
use bytes::{BufMut, BytesMut};
use num_integer::Integer;
use rand::prelude::{SliceRandom, ThreadRng};
use crate::entropy_bank::EntropyBank;
use crate::packet_vector::{generate_packet_vector, PacketVector};
use crate::prelude::{CryptError, SecurityLevel};
use crate::stacked_ratchet::Ratchet;
#[cfg(not(target_family = "wasm"))]
use rayon::{iter::IndexedParallelIterator, prelude::*};
/// Upper bound, in bytes, for a single group (10 MiB).
/// NOTE(review): not referenced in the visible part of this file; presumably enforced by callers — confirm.
pub const MAX_BYTES_PER_GROUP: usize = 1024 * 1024 * 10;
/// Plaintext length passed to `max_ciphertext_len` when deriving the maximum packet payload size,
/// and the amount subtracted from that payload size to obtain the per-wave overhead.
pub const MAX_WAVEFORM_PACKET_SIZE: usize = 480;
/// Value returned by `get_aes_gcm_overhead` (16 bytes).
/// NOTE(review): presumably the AES-GCM authentication tag length — confirm.
pub const AES_GCM_GHASH_OVERHEAD: usize = 16;
/// Returns the largest packet payload, in bytes, for the given algorithms and security level.
///
/// The result is `enx.max_ciphertext_len(MAX_WAVEFORM_PACKET_SIZE, sig_alg) + 8`, divided by
/// `2^e` where `e` is the security level's value capped at `SecurityLevel::Standard`'s value,
/// and is never smaller than [`get_aes_gcm_overhead`].
pub fn get_max_packet_size(
    enx: EncryptionAlgorithm,
    sig_alg: SigAlgorithm,
    security_level: SecurityLevel,
) -> usize {
    // Levels above `Standard` do not shrink the packet any further.
    let capped_level = std::cmp::min(security_level.value(), SecurityLevel::Standard.value());
    let exponent = capped_level as u32;

    // NOTE(review): the extra 8 bytes presumably cover a length/nonce prefix — confirm.
    let unscaled_size = enx.max_ciphertext_len(MAX_WAVEFORM_PACKET_SIZE, sig_alg) + 8;
    let scaled_size = unscaled_size / 2usize.pow(exponent);

    if scaled_size < get_aes_gcm_overhead() {
        get_aes_gcm_overhead()
    } else {
        scaled_size
    }
}
/// Returns [`AES_GCM_GHASH_OVERHEAD`]; used as the lower bound for the maximum packet size.
pub(crate) const fn get_aes_gcm_overhead() -> usize {
AES_GCM_GHASH_OVERHEAD
}
/// Returns the plaintext length for `ciphertext` as reported by `enx.plaintext_length`.
///
/// Thin wrapper: the result (including `None`) is whatever the algorithm's
/// `plaintext_length` yields for the given bytes.
pub fn calculate_aes_gcm_plaintext_length_from_ciphertext_length(
ciphertext: &[u8],
enx: EncryptionAlgorithm,
) -> Option<usize> {
enx.plaintext_length(ciphertext)
}
/// Builds the [`GroupReceiverConfig`] describing how `plain_text` is split into waves and packets.
///
/// A wave holds at most `max_packet_payload_size * multiport_width - overhead` plaintext bytes,
/// where `overhead = max_packet_payload_size - MAX_WAVEFORM_PACKET_SIZE`. The input is divided
/// into full waves plus, when there is a remainder, one partial wave.
///
/// # Errors
/// Returns `CryptError::Encrypt` when `plain_text` is empty.
#[allow(clippy::too_many_arguments)]
pub fn generate_scrambler_metadata<T: AsRef<[u8]>>(
    msg_drill: &EntropyBank,
    plain_text: T,
    header_size_bytes: usize,
    security_level: SecurityLevel,
    group_id: u64,
    object_id: u64,
    enx: EncryptionAlgorithm,
    sig_alg: SigAlgorithm,
    transfer_type: &TransferType,
) -> Result<GroupReceiverConfig, CryptError<String>> {
    let input_len = plain_text.as_ref().len();
    if input_len == 0 {
        return Err(CryptError::Encrypt("Empty input".to_string()));
    }

    let max_packet_payload_size = get_max_packet_size(enx, sig_alg, security_level);
    let overhead = max_packet_payload_size - MAX_WAVEFORM_PACKET_SIZE;
    let max_packets_per_wave = msg_drill.get_multiport_width();
    let max_plaintext_bytes_per_wave = (max_packet_payload_size * max_packets_per_wave) - overhead;

    // A zero remainder means the input fills whole waves exactly; the "last wave" is then a
    // full one and no partial wave exists. Otherwise the remainder forms the single partial wave
    // (this also covers inputs shorter than one wave, where the quotient is zero).
    let leftover_bytes = input_len % max_plaintext_bytes_per_wave;
    let number_of_full_waves = input_len / max_plaintext_bytes_per_wave;
    let (number_of_partial_waves, bytes_in_last_wave) = match leftover_bytes {
        0 => (0, max_plaintext_bytes_per_wave),
        partial => (1, partial),
    };

    // Only a partial wave contributes a last-wave ciphertext length here.
    let ciphertext_len_last_wave = if number_of_partial_waves == 0 {
        0
    } else {
        8 + enx.max_ciphertext_len(bytes_in_last_wave, sig_alg)
    };

    Ok(GroupReceiverConfig::new_refresh(
        group_id,
        object_id,
        header_size_bytes as u64,
        input_len as u64,
        max_packet_payload_size as u32,
        number_of_full_waves as u32,
        number_of_partial_waves,
        max_plaintext_bytes_per_wave as u64,
        bytes_in_last_wave as u64,
        max_packets_per_wave as u32,
        ciphertext_len_last_wave as u32,
        transfer_type,
    ))
}
/// Collects everything needed to scramble and encrypt `plain_text` with `hyper_ratchet`.
///
/// Returns, in order: the receiver configuration, the message entropy bank, the message
/// post-quantum container, and the scramble entropy bank. The algorithms used for the
/// configuration are read from the message container's parameters.
///
/// # Errors
/// Propagates any error from [`generate_scrambler_metadata`].
fn get_scramble_encrypt_config<'a, R: Ratchet>(
    hyper_ratchet: &'a R,
    plain_text: &'a [u8],
    header_size_bytes: usize,
    security_level: SecurityLevel,
    group_id: u64,
    object_id: u64,
    transfer_type: &TransferType,
) -> Result<
    (
        GroupReceiverConfig,
        &'a EntropyBank,
        &'a PostQuantumContainer,
        &'a EntropyBank,
    ),
    CryptError<String>,
> {
    let (msg_pqc, msg_drill) = hyper_ratchet.message_pqc_drill(None);
    let scramble_drill = hyper_ratchet.get_scramble_drill();
    let params = &msg_pqc.params;

    generate_scrambler_metadata(
        msg_drill,
        plain_text,
        header_size_bytes,
        security_level,
        group_id,
        object_id,
        params.encryption_algorithm,
        params.sig_algorithm,
        transfer_type,
    )
    .map(|cfg| (cfg, msg_drill, msg_pqc, scramble_drill))
}
/// A fully assembled outbound packet together with the vector generated for it.
#[derive(Clone)]
pub struct PacketCoordinate {
/// Packet bytes: whatever the header inscriber wrote, followed by the ciphertext chunk.
pub packet: BytesMut,
/// Vector produced by `generate_packet_vector` for this packet's true sequence number.
pub vector: PacketVector,
}
/// Encrypts `plain_text` wave by wave and packs the ciphertext into header-prefixed packets.
///
/// Steps:
/// 1. For `TransferType::RemoteEncryptedVirtualFilesystem`, the input is first encrypted locally
///    with `static_aux_ratchet` at the security level carried by the transfer type.
/// 2. A [`GroupReceiverConfig`] is derived for the (possibly pre-encrypted) input.
/// 3. The input is split into chunks of `max_plaintext_wave_length` bytes; each chunk is handed
///    to `scramble_encrypt_wave` (in parallel via rayon on non-wasm targets).
/// 4. Unless the algorithm is not Kyber *and* the transfer is a plain file transfer, the config
///    is rebuilt from the actual ciphertext length of the last wave.
///
/// `header_inscriber` is invoked once per packet with the packet vector, the scramble entropy
/// bank, `object_id`, `target_cid` and the packet buffer to write the header into.
///
/// # Errors
/// Returns an error when the local pre-encryption fails or when the configuration cannot be
/// generated (e.g. empty input).
///
/// # Panics
/// Wave encryption failures still panic inside `scramble_encrypt_wave`.
#[allow(clippy::too_many_arguments)]
pub fn par_scramble_encrypt_group<T: AsRef<[u8]>, R: Ratchet, F, const N: usize>(
    plain_text: T,
    security_level: SecurityLevel,
    hyper_ratchet: &R,
    static_aux_ratchet: &R,
    header_size_bytes: usize,
    target_cid: u64,
    object_id: u64,
    group_id: u64,
    transfer_type: TransferType,
    header_inscriber: F,
) -> Result<GroupSenderDevice<N>, CryptError<String>>
where
    F: Fn(&PacketVector, &EntropyBank, u64, u64, &mut BytesMut) + Send + Sync,
{
    let mut plain_text = Cow::Borrowed(plain_text.as_ref());

    if let TransferType::RemoteEncryptedVirtualFilesystem { security_level, .. } = &transfer_type {
        // Pre-encrypt locally; a failure is reported to the caller instead of panicking.
        let local_encrypted = static_aux_ratchet.local_encrypt(plain_text, *security_level)?;
        plain_text = Cow::Owned(local_encrypted);
    }

    let (mut cfg, msg_drill, msg_pqc, scramble_drill) = get_scramble_encrypt_config(
        hyper_ratchet,
        &plain_text,
        header_size_bytes,
        security_level,
        group_id,
        object_id,
        &transfer_type,
    )?;

    // One chunk per wave; parallel where rayon is available.
    #[cfg(not(target_family = "wasm"))]
    let chunks = plain_text.par_chunks(cfg.max_plaintext_wave_length as usize);
    #[cfg(target_family = "wasm")]
    let chunks = plain_text.chunks(cfg.max_plaintext_wave_length as usize);

    let packets = chunks
        .enumerate()
        .map(|(wave_idx, bytes_to_encrypt_for_this_wave)| {
            scramble_encrypt_wave(
                wave_idx,
                bytes_to_encrypt_for_this_wave,
                &cfg,
                msg_drill,
                msg_pqc,
                scramble_drill,
                target_cid,
                object_id,
                header_size_bytes,
                &header_inscriber,
            )
        })
        .flatten()
        .collect::<HashMap<usize, PacketCoordinate>>();

    debug_assert_ne!(cfg.last_plaintext_wave_length, 0);

    if msg_pqc.params.encryption_algorithm != EncryptionAlgorithm::Kyber
        && matches!(&transfer_type, TransferType::FileTransfer)
    {
        debug_assert_eq!(cfg.packets_needed, packets.len() as _);
    } else {
        // The estimated last-wave ciphertext length may differ from what was really produced,
        // so measure the packets of the last wave and rebuild the config from that.
        let last_wave_idx = cfg.wave_count - 1;
        let ciphertext_len: usize = packets
            .values()
            .filter_map(|r| {
                if r.vector.wave_id == last_wave_idx {
                    // NOTE(review): subtracts `N` rather than `header_size_bytes`; this assumes
                    // the inscribed header is exactly `N` bytes long — confirm with callers.
                    Some(r.packet.len() - N)
                } else {
                    None
                }
            })
            .sum();

        cfg = GroupReceiverConfig::new_refresh(
            cfg.group_id,
            cfg.object_id,
            cfg.header_size_bytes,
            plain_text.len() as u64,
            cfg.max_payload_size as u32,
            cfg.number_of_full_waves,
            cfg.number_of_partial_waves,
            cfg.max_plaintext_wave_length as u64,
            cfg.last_plaintext_wave_length as u64,
            cfg.max_packets_per_wave,
            ciphertext_len as u32,
            &transfer_type,
        );
    }

    Ok(GroupSenderDevice::new(cfg, packets))
}
/// Encrypts the plaintext of one wave and cuts the ciphertext into header-prefixed packets.
///
/// Each packet is keyed by its true sequence number,
/// `wave_idx * max_packets_per_wave + index_within_wave`. The returned list is shuffled
/// before it is handed back.
///
/// # Panics
/// Panics if `msg_drill.encrypt` fails.
#[allow(clippy::too_many_arguments)]
fn scramble_encrypt_wave(
    wave_idx: usize,
    bytes_to_encrypt_for_this_wave: &[u8],
    cfg: &GroupReceiverConfig,
    msg_drill: &EntropyBank,
    msg_pqc: &PostQuantumContainer,
    scramble_drill: &EntropyBank,
    target_cid: u64,
    object_id: u64,
    header_size_bytes: usize,
    header_inscriber: impl Fn(&PacketVector, &EntropyBank, u64, u64, &mut BytesMut) + Send + Sync,
) -> Vec<(usize, PacketCoordinate)> {
    let ciphertext = msg_drill
        .encrypt(msg_pqc, bytes_to_encrypt_for_this_wave)
        .unwrap();

    let mut wave_packets: Vec<(usize, PacketCoordinate)> = Vec::new();
    for (index_in_wave, payload) in ciphertext.chunks(cfg.max_payload_size as usize).enumerate() {
        debug_assert_ne!(payload.len(), 0);

        let true_packet_sequence = (wave_idx * cfg.max_packets_per_wave as usize) + index_in_wave;
        let vector = generate_packet_vector(true_packet_sequence, cfg.group_id, scramble_drill);

        // Header first, ciphertext payload right after it.
        let mut packet = BytesMut::with_capacity(payload.len() + header_size_bytes);
        header_inscriber(&vector, scramble_drill, object_id, target_cid, &mut packet);
        packet.put(payload);

        wave_packets.push((true_packet_sequence, PacketCoordinate { packet, vector }));
    }

    let mut rng = ThreadRng::default();
    wave_packets.shuffle(&mut rng);
    wave_packets
}
/// Wraps a single message packet into a oneshot [`GroupSenderDevice`].
///
/// The accompanying configuration describes exactly one wave containing exactly one packet
/// whose sizes all equal `plain_text.message_len()`; no transfer type is recorded.
/// This function always returns `Ok`.
pub fn oneshot_unencrypted_group_unified<const N: usize>(
    plain_text: SecureMessagePacket<N>,
    header_size_bytes: usize,
    group_id: u64,
    object_id: u64,
) -> Result<GroupSenderDevice<N>, CryptError<String>> {
    let message_len = plain_text.message_len() as u64;
    let wave_len = message_len as u32;

    let single_packet_cfg = GroupReceiverConfig {
        group_id,
        object_id,
        header_size_bytes: header_size_bytes as u64,
        transfer_type: None,
        // One wave, one packet.
        packets_needed: 1,
        wave_count: 1,
        number_of_full_waves: 0,
        number_of_partial_waves: 1,
        max_packets_per_wave: 1,
        packets_in_last_wave: 1,
        // Every size field is the message length.
        plaintext_length: message_len,
        max_payload_size: message_len,
        last_payload_size: message_len,
        max_plaintext_wave_length: wave_len,
        last_plaintext_wave_length: wave_len,
    };

    Ok(GroupSenderDevice::new_oneshot(single_packet_cfg, plain_text))
}
/// Outcome of handing a packet to `GroupReceiver::on_packet_received`.
#[derive(Debug, Eq, PartialEq)]
#[allow(non_camel_case_types)]
pub enum GroupReceiverStatus {
/// The last outstanding wave was decrypted; carries the id of the wave that completed the group.
GROUP_COMPLETE(u32),
/// The packet was rejected: unknown sequence number, no store for its wave, or a bad insertion range.
INVALID_PACKET,
/// A packet with this sequence number was already recorded.
ALREADY_RECEIVED,
/// The packet was stored, but its wave is not complete yet.
INSERT_SUCCESS,
/// All packets of a wave arrived, but the wave could not be decrypted.
CORRUPT_WAVE,
/// The wave with the carried id was decrypted while other waves are still outstanding.
WAVE_COMPLETE(u32),
/// NOTE(review): not produced anywhere in the visible part of this file; presumably
/// signals that the carried wave must be re-sent — confirm.
NEEDS_RETRANSMISSION(u32),
}
/// Reassembles a group: collects ciphertext packets per wave, decrypts each wave once all of
/// its packets have arrived, and writes the plaintext into one contiguous buffer.
#[allow(dead_code)]
pub struct GroupReceiver {
// Destination buffer for decrypted waves; allocated up front with the configured plaintext length.
unified_plaintext_slab: Vec<u8>,
// Per-wave ciphertext staging, keyed by wave id; an entry is removed once its wave is decrypted.
temp_wave_store: HashMap<u32, TempWaveStore>,
// One bit per packet, indexed by true sequence number.
packets_received_order: BitVec,
// One bit per wave, set when a wave is decrypted (except for the wave that completes the group).
waves_received: BitVec,
// Values copied from the `GroupReceiverConfig` at construction time.
packets_needed: usize,
// Time of construction, refreshed whenever a packet is stored without completing its wave.
last_packet_recv_time: Instant,
max_payload_size: usize,
last_payload_size: usize,
max_packets_per_wave: usize,
max_plaintext_wave_length: usize,
last_plaintext_wave_length: usize,
wave_count: usize,
// Highest wave id completed in sequence starting from wave 0; -1 until wave 0 completes.
lowest_sequential_wave_completed: isize,
// Id of the most recently completed wave; -1 until any wave completes.
last_complete_wave: isize,
// Timeouts as given to the constructor (in milliseconds); not consulted in the visible code.
group_timeout: Duration,
wave_timeout: Duration,
}
use crate::misc::TransferType;
use crate::secure_buffer::sec_packet::SecureMessagePacket;
use citadel_pqcrypto::algorithm_dictionary::{EncryptionAlgorithm, SigAlgorithm};
use citadel_pqcrypto::PostQuantumContainer;
use serde::{Deserialize, Serialize};
/// Describes how a group is laid out in waves and packets, so a receiver can allocate
/// buffers and know when the group is complete.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GroupReceiverConfig {
// Total number of packets in the group.
pub packets_needed: u32,
// Number of packets in every wave except possibly the last.
pub max_packets_per_wave: u32,
// Total plaintext length of the group in bytes.
pub plaintext_length: u64,
// Ciphertext payload size of a regular packet.
pub max_payload_size: u64,
// Ciphertext payload size of the final packet of the last wave.
pub last_payload_size: u64,
pub number_of_full_waves: u32,
pub number_of_partial_waves: u32,
// number_of_full_waves + number_of_partial_waves.
pub wave_count: u32,
// Plaintext bytes carried by a full wave.
pub max_plaintext_wave_length: u32,
// Plaintext bytes carried by the last wave.
pub last_plaintext_wave_length: u32,
pub packets_in_last_wave: u32,
pub header_size_bytes: u64,
pub group_id: u64,
pub object_id: u64,
// `None` for oneshot groups built by `oneshot_unencrypted_group_unified`.
pub transfer_type: Option<TransferType>,
}
/// NOTE(review): presumably the serialized length, in bytes, of an inscribed
/// `GroupReceiverConfig`; it is not used in the visible part of this file — confirm.
pub const GROUP_RECEIVER_INSCRIBE_LEN: usize = 72;
impl GroupReceiverConfig {
    /// Builds a configuration from the wave layout and the ciphertext length of the last wave.
    ///
    /// Derived values:
    /// - `wave_count` is the sum of full and partial waves;
    /// - `packets_in_last_wave` is `ciphertext_len_last_wave / max_packet_payload_size`, rounded up;
    /// - `last_payload_size` is the remainder of that division, or a full payload when the
    ///   remainder is zero;
    /// - `packets_needed` is `number_of_full_waves * max_packets_per_wave + packets_in_last_wave`.
    ///
    /// # Panics
    /// Panics if `max_packet_payload_size` is zero.
    #[allow(clippy::too_many_arguments)]
    pub fn new_refresh(
        group_id: u64,
        object_id: u64,
        header_size_bytes: u64,
        plaintext_length: u64,
        max_packet_payload_size: u32,
        number_of_full_waves: u32,
        number_of_partial_waves: u32,
        max_plaintext_bytes_per_wave: u64,
        bytes_in_last_wave: u64,
        max_packets_per_wave: u32,
        ciphertext_len_last_wave: u32,
        transfer_type: &TransferType,
    ) -> Self {
        let wave_count = number_of_full_waves + number_of_partial_waves;

        // One division yields both the number of full-size packets and the trailing bytes.
        let (full_size_packets, trailing_bytes) =
            ciphertext_len_last_wave.div_rem(&max_packet_payload_size);
        let (packets_in_last_wave, last_payload_size) = if trailing_bytes == 0 {
            // The last packet is a full one.
            (full_size_packets, max_packet_payload_size)
        } else {
            // The trailing bytes need one extra, shorter packet.
            (full_size_packets + 1, trailing_bytes)
        };

        let packets_needed = (number_of_full_waves * max_packets_per_wave) + packets_in_last_wave;

        Self {
            packets_needed,
            max_packets_per_wave,
            plaintext_length,
            max_payload_size: max_packet_payload_size as u64,
            last_payload_size: last_payload_size as u64,
            number_of_full_waves,
            number_of_partial_waves,
            wave_count,
            max_plaintext_wave_length: max_plaintext_bytes_per_wave as u32,
            last_plaintext_wave_length: bytes_in_last_wave as u32,
            packets_in_last_wave,
            header_size_bytes,
            group_id,
            object_id,
            transfer_type: Some(transfer_type.clone()),
        }
    }

    /// Returns how many packets wave `wave_id` contains: `packets_in_last_wave` for the final
    /// wave, `max_packets_per_wave` for every other wave.
    pub fn get_packet_count_in_wave(&self, wave_id: u32) -> u32 {
        let final_wave_id = self.wave_count - 1;
        if wave_id != final_wave_id {
            return self.max_packets_per_wave;
        }
        self.packets_in_last_wave
    }
}
/// Staging area for the ciphertext of a single wave while its packets arrive.
struct TempWaveStore {
// Number of packets stored so far.
packets_received: usize,
// Number of packets this wave consists of; the wave is decrypted once both counters match.
packets_in_wave: usize,
// Total payload bytes copied into `ciphertext_buffer`.
bytes_written: usize,
// Set at construction for wave 0, and updated when a packet is stored or the previous wave completes.
#[allow(dead_code)]
last_packet_recv_time: Option<Instant>,
// Zero-initialised buffer sized for all packets of the wave.
ciphertext_buffer: Vec<u8>,
}
impl GroupReceiver {
#[allow(unused_results)]
pub fn new(cfg: GroupReceiverConfig, wave_timeout_ms: usize, group_timeout_ms: usize) -> Self {
use bitvec::prelude::*;
log::trace!(target: "citadel", "Creating new group receiver. Anticipated plaintext slab length: {}", cfg.plaintext_length);
let unified_plaintext_slab = vec![0u8; cfg.plaintext_length as usize];
let packets_needed = cfg.packets_needed;
let wave_count = cfg.wave_count;
let packets_received_order = bitvec::bitvec![usize, Lsb0; 0; packets_needed as usize];
let waves_received = bitvec::bitvec![usize, Lsb0; 0; wave_count as usize];
let mut temp_wave_store = HashMap::with_capacity(cfg.wave_count as usize);
let last_packet_recv_time = Instant::now();
let max_packets_per_wave = cfg.max_packets_per_wave;
let group_timeout = Duration::from_millis(group_timeout_ms as u64);
let wave_timeout = Duration::from_millis(wave_timeout_ms as u64);
let last_complete_wave = -1;
let lowest_sequential_wave_completed = -1;
for wave_id_cur in 0..cfg.wave_count {
let (ciphertext_buffer_alloc_size_for_single_wave, packets_in_wave) =
if wave_id_cur == cfg.wave_count - 1 {
let packets_in_last_wave = cfg.get_packet_count_in_wave(wave_id_cur);
let normal_packet_count = packets_in_last_wave.saturating_sub(1);
(
(normal_packet_count as u64 * cfg.max_payload_size) + cfg.last_payload_size,
packets_in_last_wave,
)
} else {
(
cfg.max_payload_size * max_packets_per_wave as u64,
max_packets_per_wave,
)
};
let last_packet_recv_time = if wave_id_cur == 0 {
Some(Instant::now())
} else {
None
};
let ciphertext_buffer =
vec![0u8; ciphertext_buffer_alloc_size_for_single_wave as usize];
let tmp_wave_store_container = TempWaveStore {
bytes_written: 0,
packets_received: 0,
packets_in_wave: packets_in_wave as usize,
last_packet_recv_time,
ciphertext_buffer,
};
temp_wave_store.insert(wave_id_cur, tmp_wave_store_container);
}
Self {
lowest_sequential_wave_completed,
waves_received,
last_complete_wave,
wave_timeout,
group_timeout,
unified_plaintext_slab,
temp_wave_store,
packets_received_order,
packets_needed: cfg.packets_needed as usize,
last_packet_recv_time,
max_payload_size: cfg.max_payload_size as usize,
last_payload_size: cfg.last_payload_size as usize,
max_packets_per_wave: cfg.max_packets_per_wave as usize,
wave_count: cfg.wave_count as usize,
max_plaintext_wave_length: cfg.max_plaintext_wave_length as usize,
last_plaintext_wave_length: cfg.last_plaintext_wave_length as usize,
}
}
pub fn on_packet_received<T: AsRef<[u8]>, R: Ratchet>(
&mut self,
_group_id: u64,
true_sequence: usize,
wave_id: u32,
hyper_ratchet: &R,
packet: T,
) -> GroupReceiverStatus {
let packet = packet.as_ref();
let is_received =
if let Some(mut is_received) = self.packets_received_order.get_mut(true_sequence) {
let is_recv = *is_received;
if !*is_received {
*is_received = true;
}
is_recv
} else {
return GroupReceiverStatus::INVALID_PACKET;
};
if !is_received {
let wave_store = self.temp_wave_store.get_mut(&wave_id);
if wave_store.is_none() {
log::trace!(target: "citadel", "Packet {} (Parent wave: {}) does not have a wave store", true_sequence, wave_id);
return GroupReceiverStatus::INVALID_PACKET;
}
let wave_store = wave_store.unwrap();
let insert_index = Self::get_ciphertext_insertion_range(
true_sequence,
self.max_packets_per_wave,
self.packets_needed,
self.last_payload_size,
self.max_payload_size,
wave_store,
);
if !check_bounds(&wave_store.ciphertext_buffer, insert_index.clone()) {
log::error!(target: "citadel", "Bad ciphertext buffer insertion index {insert_index:?} for buf of len {}", wave_store.ciphertext_buffer.len());
return GroupReceiverStatus::INVALID_PACKET;
}
let dest_bytes = &mut wave_store.ciphertext_buffer[insert_index];
if !check_bounds(&dest_bytes, ..packet.len()) {
log::error!(target: "citadel", "Bad dest buffer insertion index {:?} for buf of len {}", ..packet.len(), dest_bytes.len());
return GroupReceiverStatus::INVALID_PACKET;
}
let dest_bytes = &mut dest_bytes[..packet.len()];
let packet_bytes = packet;
debug_assert_eq!(packet_bytes.len(), dest_bytes.len());
dest_bytes.copy_from_slice(packet_bytes);
wave_store.packets_received += 1;
wave_store.bytes_written += packet_bytes.len();
wave_store.last_packet_recv_time = Some(Instant::now());
self.packets_received_order.set(true_sequence, true);
if wave_store.packets_received == wave_store.packets_in_wave {
let ciphertext_bytes_for_this_wave =
&wave_store.ciphertext_buffer[..wave_store.bytes_written];
let (msg_pqc, msg_drill) = hyper_ratchet.message_pqc_drill(None);
match msg_drill.decrypt(msg_pqc, ciphertext_bytes_for_this_wave) {
Ok(plaintext) => {
let plaintext = plaintext.as_slice();
let plaintext_insert_index =
Self::get_plaintext_buffer_insertion_range_by_wave_id(
wave_id,
plaintext,
self.max_plaintext_wave_length,
);
let dest_bytes =
&mut self.unified_plaintext_slab[plaintext_insert_index.clone()];
debug_assert_eq!(
plaintext_insert_index.end - plaintext_insert_index.start,
dest_bytes.len()
);
dest_bytes.copy_from_slice(plaintext);
assert!(self.temp_wave_store.remove(&wave_id).is_some());
if self.temp_wave_store.is_empty() {
GroupReceiverStatus::GROUP_COMPLETE(wave_id)
} else {
if let Some(next_wave) = self.temp_wave_store.get_mut(&(wave_id + 1)) {
next_wave.last_packet_recv_time = Some(Instant::now());
}
if wave_id as isize - 1 == self.lowest_sequential_wave_completed {
self.lowest_sequential_wave_completed = wave_id as isize;
}
self.waves_received.set(wave_id as usize, true);
self.last_complete_wave = wave_id as isize;
GroupReceiverStatus::WAVE_COMPLETE(wave_id)
}
}
Err(err) => {
let sample_bytes = std::cmp::min(10, ciphertext_bytes_for_this_wave.len());
log::error!(target: "citadel", "Unable to decrypt wave {}. Reason: {} | len: {} | First bytes: {:?}", wave_id, err.into_string(), ciphertext_bytes_for_this_wave.len(), &ciphertext_bytes_for_this_wave[0..sample_bytes]);
GroupReceiverStatus::CORRUPT_WAVE
}
}
} else {
self.last_packet_recv_time = Instant::now();
GroupReceiverStatus::INSERT_SUCCESS
}
} else {
log::trace!(target: "citadel", "Packet {} (Parent Wave: {}) already received", true_sequence, wave_id);
GroupReceiverStatus::ALREADY_RECEIVED
}
}
/// Lists the waves that are still incomplete between the highest sequentially completed wave
/// (inclusive) and the most recently completed wave (exclusive).
///
/// Returns `None` when wave 0 has not completed yet or when no wave in that window is missing.
pub fn get_missing_waves(&self) -> Option<Vec<u32>> {
    if self.lowest_sequential_wave_completed < 0 {
        return None;
    }

    let first = self.lowest_sequential_wave_completed as usize;
    let last = self.last_complete_wave as usize;
    let window = &self.waves_received.as_bitslice()[first..last];

    let mut missing = Vec::new();
    for (relative_idx, finished) in window.iter().enumerate() {
        if !*finished {
            // Translate the window-relative position back into a wave id.
            missing.push((first + relative_idx) as u32);
        }
    }

    if missing.is_empty() {
        None
    } else {
        Some(missing)
    }
}
/// Returns how many packets of wave `wave_id` have not arrived yet, or `None` when the wave
/// has no staging store (for example because it was already decrypted and removed).
#[inline]
pub fn get_missing_count_in_wave(&self, wave_id: u32) -> Option<usize> {
    debug_assert!(wave_id < self.wave_count as u32);
    self.temp_wave_store
        .get(&wave_id)
        .map(|store| store.packets_in_wave - store.packets_received)
}
/// Consumes the receiver and returns the plaintext slab.
/// The slab is returned as-is; regions of waves that were never decrypted remain zeroed.
pub fn finalize(self) -> Vec<u8> {
self.unified_plaintext_slab
}
/// Computes where a packet's payload belongs inside its wave's ciphertext buffer.
///
/// The very last packet of the group (`true_sequence == packets_needed - 1`) occupies the
/// final `last_payload_size` bytes of the buffer. Every other packet occupies the
/// `max_payload_size`-sized slot selected by its position within the wave
/// (`true_sequence % max_packets_per_wave`).
///
/// The returned range is not validated here; callers are expected to bounds-check it.
fn get_ciphertext_insertion_range(
    true_sequence: usize,
    max_packets_per_wave: usize,
    packets_needed: usize,
    last_payload_size: usize,
    max_payload_size: usize,
    store: &TempWaveStore,
) -> Range<usize> {
    let packet_idx_relative_to_wave = true_sequence % max_packets_per_wave;
    if true_sequence == packets_needed - 1 {
        // Use the buffer's length, not its capacity: capacity is only guaranteed to be at
        // least as large as the length, so it is not a valid basis for indexing.
        let len = store.ciphertext_buffer.len();
        let start_idx = len - last_payload_size;
        start_idx..len
    } else {
        let start_idx = max_payload_size * packet_idx_relative_to_wave;
        let end_idx = max_payload_size + start_idx;
        start_idx..end_idx
    }
}
/// Computes the slab range for a decrypted wave: it starts at
/// `wave_id * max_plaintext_wave_length` and spans `plaintext.len()` bytes.
fn get_plaintext_buffer_insertion_range_by_wave_id(
    wave_id: u32,
    plaintext: &[u8],
    max_plaintext_wave_length: usize,
) -> Range<usize> {
    let wave_offset = wave_id as usize * max_plaintext_wave_length;
    Range {
        start: wave_offset,
        end: wave_offset + plaintext.len(),
    }
}
/// Returns the total number of waves in this group, as taken from the configuration.
pub fn get_wave_count(&self) -> usize {
self.wave_count
}
/// Returns the id of the most recently completed wave, or `None` if no wave has been
/// recorded as complete yet (the internal marker is negative in that case).
pub fn get_last_complete_wave(&self) -> Option<u32> {
    match self.last_complete_wave {
        unset if unset < 0 => None,
        wave_id => Some(wave_id as u32),
    }
}
/// Returns `true` when more than `timeout` has elapsed since `last_packet_recv_time`
/// (set at construction and refreshed when a packet is stored without completing its wave).
pub fn has_expired(&self, timeout: Duration) -> bool {
self.last_packet_recv_time.elapsed() > timeout
}
}
/// Sender-side state for one group: the packets waiting to be sent or acknowledged, plus
/// progress counters.
pub struct GroupSenderDevice<const N: usize> {
/// Outbound packets keyed by true sequence number; entries are removed when handed out or acknowledged.
pub packets_in_ram: HashMap<usize, PacketCoordinate>,
// Present only for devices created via `new_oneshot`; taken by `get_oneshot`.
oneshot: Option<SecureMessagePacket<N>>,
// Number of packets acknowledged through wave-tail acks.
packets_received: usize,
// Number of packets handed out by `get_next_packet`.
packets_sent: usize,
receiver_config: GroupReceiverConfig,
// Time of construction or of the latest wave-tail ack.
last_wave_ack_received: Instant,
}
impl<const N: usize> GroupSenderDevice<N> {
    /// Creates a sender for a scrambled group whose packets are held in `packets_in_ram`,
    /// keyed by true sequence number.
    pub fn new(
        receiver_config: GroupReceiverConfig,
        packets_in_ram: HashMap<usize, PacketCoordinate>,
    ) -> Self {
        Self {
            packets_in_ram,
            packets_received: 0,
            packets_sent: 0,
            receiver_config,
            oneshot: None,
            last_wave_ack_received: Instant::now(),
        }
    }

    /// Creates a sender that carries a single message packet and no scrambled packets.
    pub fn new_oneshot(
        receiver_config: GroupReceiverConfig,
        oneshot: SecureMessagePacket<N>,
    ) -> Self {
        Self {
            packets_in_ram: HashMap::new(),
            oneshot: Some(oneshot),
            packets_received: 0,
            packets_sent: 0,
            receiver_config,
            last_wave_ack_received: Instant::now(),
        }
    }

    /// Reports whether enough packets have been acknowledged.
    ///
    /// NOTE(review): despite the name, the comparison `received * 1.5 >= needed` holds once
    /// two thirds of the packets are acknowledged, not one half. The threshold is kept as-is
    /// because callers may depend on it — confirm the intended value.
    pub fn is_atleast_fifty_percent_done(&self) -> bool {
        self.packets_received as f32 * 1.5f32 >= self.receiver_config.packets_needed as f32
    }

    /// Hands out the packet with the next sequence number.
    ///
    /// Returns `None` once `packets_needed` packets have been handed out, and also when the
    /// next packet is no longer held in memory (oneshot devices, packets already drained via
    /// `take_all_packets`, or removed by a wave-tail ack). The sent counter only advances
    /// when a packet is actually returned.
    pub fn get_next_packet(&mut self) -> Option<PacketCoordinate> {
        if self.packets_sent == self.receiver_config.packets_needed as usize {
            return None;
        }

        // Previously unwrapped; a missing entry is a reachable state, not an invariant.
        let next_packet = self.packets_in_ram.remove(&self.packets_sent)?;
        self.packets_sent += 1;
        Some(next_packet)
    }

    /// Takes the oneshot packet, leaving `None` behind; subsequent calls return `None`.
    pub fn get_oneshot(&mut self) -> Option<SecureMessagePacket<N>> {
        self.oneshot.take()
    }

    /// Handles the acknowledgement of a whole wave.
    ///
    /// Drops every packet of wave `wave_id` that is still held in memory, refreshes the
    /// ack timestamp, and adds the wave's packet count to the received counter.
    /// Returns `true` when the counter then equals `packets_needed`.
    #[allow(unused_results)]
    pub fn on_wave_tail_ack_received(&mut self, wave_id: u32) -> bool {
        let offset = self.receiver_config.max_packets_per_wave * wave_id;
        let packets_in_this_wave = self.get_packets_in_wave(wave_id);
        let end = offset + packets_in_this_wave as u32;
        log::trace!(target: "citadel", "Wave tail received for wave {}. Removing entries from {} to {}", wave_id, offset, end);

        for idx in offset..end {
            self.packets_in_ram.remove(&(idx as usize));
        }

        self.last_wave_ack_received = Instant::now();
        self.packets_received += packets_in_this_wave;
        self.packets_received == self.receiver_config.packets_needed as usize
    }

    /// Removes and returns every packet still held in memory, in arbitrary order.
    pub fn take_all_packets(&mut self) -> Vec<PacketCoordinate> {
        self.packets_in_ram.drain().map(|(_, v)| v).collect()
    }

    /// Returns a clone of the receiver configuration for this group.
    pub fn get_receiver_config(&self) -> GroupReceiverConfig {
        self.receiver_config.clone()
    }

    /// Returns how many packets have been handed out by `get_next_packet`.
    pub fn get_packets_sent(&self) -> usize {
        self.packets_sent
    }

    /// Returns how many packets have been acknowledged through wave-tail acks.
    pub fn get_packets_received(&self) -> usize {
        self.packets_received
    }

    /// Returns the number of packets in wave `wave_id`: `packets_in_last_wave` for the final
    /// wave and `max_packets_per_wave` otherwise.
    pub fn get_packets_in_wave(&self, wave_id: u32) -> usize {
        debug_assert!(wave_id < self.receiver_config.wave_count);
        if wave_id == self.receiver_config.wave_count - 1 {
            self.receiver_config.packets_in_last_wave as usize
        } else {
            self.receiver_config.max_packets_per_wave as usize
        }
    }

    /// Returns `true` when more than `timeout` has elapsed since the last wave-tail ack
    /// (or since construction if none was received).
    pub fn has_expired(&self, timeout: Duration) -> bool {
        self.last_wave_ack_received.elapsed() > timeout
    }
}
/// Returns `true` when `range` can be used to slice `buf` without panicking.
///
/// The range is resolved to a half-open `start..end` pair (an unbounded start is `0`, an
/// unbounded end is `buf.len()`), and is accepted only if `start <= end <= buf.len()`.
/// Ranges that begin past the end of the buffer, inverted ranges, and bounds that would
/// overflow `usize` are all rejected.
fn check_bounds<T: AsRef<[u8]>, R: RangeBounds<usize>>(buf: T, range: R) -> bool {
    use std::ops::Bound;

    let len = buf.as_ref().len();

    let start = match range.start_bound() {
        Bound::Included(&start) => Some(start),
        Bound::Excluded(&start) => start.checked_add(1),
        Bound::Unbounded => Some(0),
    };
    let end = match range.end_bound() {
        Bound::Included(&end) => end.checked_add(1),
        Bound::Excluded(&end) => Some(end),
        Bound::Unbounded => Some(len),
    };

    matches!((start, end), (Some(start), Some(end)) if start <= end && end <= len)
}