use crate::types::Round;
use commonware_utils::{channel::oneshot, sync::Mutex};
use std::{collections::HashMap, hash::Hash, sync::Arc};
/// Verification-task receivers, keyed by the round and the digest they belong to.
type VerificationTaskMap<D> = HashMap<(Round, D), oneshot::Receiver<bool>>;

/// A cheaply cloneable registry of verification tasks.
///
/// Each entry maps `(round, digest)` to the receiving half of a oneshot channel
/// carrying a `bool` outcome. Every clone shares the same map behind an
/// `Arc<Mutex<_>>`, so a task inserted through one handle can be taken through
/// any other.
pub(crate) struct VerificationTasks<D>
where
    D: Eq + Hash,
{
    inner: Arc<Mutex<VerificationTaskMap<D>>>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// spurious `D: Clone` bound, even though cloning only bumps the `Arc`
// refcount and never clones a `D`.
impl<D> Clone for VerificationTasks<D>
where
    D: Eq + Hash,
{
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<D> Default for VerificationTasks<D>
where
D: Eq + Hash,
{
fn default() -> Self {
Self::new()
}
}
impl<D> VerificationTasks<D>
where
D: Eq + Hash,
{
pub(crate) fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(HashMap::new())),
}
}
pub(crate) fn insert(&self, round: Round, digest: D, task: oneshot::Receiver<bool>) {
self.inner.lock().insert((round, digest), task);
}
pub(crate) fn take(&self, round: Round, digest: D) -> Option<oneshot::Receiver<bool>> {
self.inner.lock().remove(&(round, digest))
}
pub(crate) fn retain_after(&self, finalized_round: &Round) {
self.inner
.lock()
.retain(|(task_round, _), _| task_round > finalized_round);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Epoch, View};
    use commonware_cryptography::{sha256::Digest as Sha256Digest, Hasher, Sha256};

    type D = Sha256Digest;

    /// Builds a round in epoch zero at the given view.
    fn round(view: u64) -> Round {
        Round::new(Epoch::zero(), View::new(view))
    }

    /// Returns a receiver whose sender is dropped when this function returns.
    ///
    /// No value can ever be sent on it. These tests only check whether the
    /// registry stores and hands back receivers, never what those receivers
    /// resolve to.
    fn senderless_task() -> oneshot::Receiver<bool> {
        let (_sender, receiver) = oneshot::channel();
        receiver
    }

    #[test]
    fn test_insert_and_take_returns_task() {
        let registry = VerificationTasks::<D>::new();
        let block = Sha256::hash(b"block");

        registry.insert(round(1), block, senderless_task());

        assert!(registry.take(round(1), block).is_some());
        assert!(
            registry.take(round(1), block).is_none(),
            "taking twice should yield None"
        );
    }

    #[test]
    fn test_take_absent_key_is_none() {
        let registry: VerificationTasks<D> = VerificationTasks::new();
        let missing = Sha256::hash(b"missing");
        assert!(registry.take(round(1), missing).is_none());
    }

    #[test]
    fn test_take_distinguishes_rounds_and_digests() {
        let registry = VerificationTasks::<D>::new();
        let (a, b) = (Sha256::hash(b"a"), Sha256::hash(b"b"));
        // Same digest in different rounds, and different digests in the same
        // round, must all be tracked independently.
        let keys = [(round(1), a), (round(2), a), (round(1), b)];

        for &(r, d) in &keys {
            registry.insert(r, d, senderless_task());
        }
        for &(r, d) in &keys {
            assert!(registry.take(r, d).is_some());
        }
    }

    #[test]
    fn test_retain_after_drops_at_and_below_boundary() {
        let registry = VerificationTasks::<D>::new();
        let block = Sha256::hash(b"block");
        for view in 1..=3 {
            registry.insert(round(view), block, senderless_task());
        }

        registry.retain_after(&round(2));

        assert!(
            registry.take(round(1), block).is_none(),
            "tasks strictly below boundary should be dropped"
        );
        assert!(
            registry.take(round(2), block).is_none(),
            "tasks at boundary should be dropped"
        );
        assert!(
            registry.take(round(3), block).is_some(),
            "tasks strictly above boundary should be retained"
        );
    }

    #[test]
    fn test_retain_after_spans_epochs() {
        let registry = VerificationTasks::<D>::new();
        let block = Sha256::hash(b"block");
        // A high view in an earlier epoch versus view zero of the next epoch.
        let boundary = Round::new(Epoch::zero(), View::new(100));
        let next_epoch = Round::new(Epoch::new(1), View::zero());

        registry.insert(boundary, block, senderless_task());
        registry.insert(next_epoch, block, senderless_task());
        registry.retain_after(&boundary);

        assert!(
            registry.take(boundary, block).is_none(),
            "task at boundary must be dropped"
        );
        assert!(
            registry.take(next_epoch, block).is_some(),
            "task in later epoch must outlive an earlier boundary"
        );
    }

    #[test]
    fn test_retain_after_empty_map_is_noop() {
        let registry = VerificationTasks::<D>::new();
        registry.retain_after(&round(5));
        let probe = Sha256::hash(b"x");
        assert!(registry.take(round(5), probe).is_none());
    }

    #[test]
    fn test_default_matches_new() {
        let registry: VerificationTasks<D> = Default::default();
        let block = Sha256::hash(b"block");
        registry.insert(round(1), block, senderless_task());
        assert!(registry.take(round(1), block).is_some());
    }
}