use crate::types::{Height, Round};
use bytes::{Buf, BufMut, Bytes};
use commonware_codec::{EncodeSize, Error as CodecError, Read, ReadExt, Write};
use commonware_cryptography::Digest;
use commonware_resolver::{p2p::Producer, Consumer};
use commonware_utils::{
channel::{mpsc, oneshot},
Span,
};
use std::{
fmt::{Debug, Display},
hash::{Hash, Hasher},
};
use tracing::error;
/// Wire tag (first encoded byte) of a `Request::Block`, which carries a digest.
const BLOCK_REQUEST: u8 = 0x00;
/// Wire tag (first encoded byte) of a `Request::Finalized`, which carries a height.
const FINALIZED_REQUEST: u8 = 0x01;
/// Wire tag (first encoded byte) of a `Request::Notarized`, which carries a round.
const NOTARIZED_REQUEST: u8 = 0x02;
/// A resolver callback forwarded by [`Handler`] to the actor behind its channel.
///
/// Every variant carries a one-shot sender on which the actor posts its reply.
pub enum Message<D: Digest> {
    /// A value arrived for `key`; the actor answers with the `bool` that
    /// `Handler::deliver` hands back to the resolver.
    Deliver {
        /// The request this value answers.
        key: Request<D>,
        /// The payload bytes as received.
        value: Bytes,
        /// Reply channel for the actor's verdict on `value`.
        response: oneshot::Sender<bool>,
    },
    /// A peer asked us to serve `key`; the actor replies with the bytes to send.
    Produce {
        /// The request being served.
        key: Request<D>,
        /// Reply channel for the bytes that answer `key`.
        response: oneshot::Sender<Bytes>,
    },
}
/// Resolver-facing handle that turns `Consumer`/`Producer` callbacks into
/// [`Message`]s sent to an actor over an `mpsc` channel.
#[derive(Clone)]
pub struct Handler<D: Digest> {
    /// Mailbox of the actor that services the forwarded messages.
    sender: mpsc::Sender<Message<D>>,
}

impl<D> Handler<D>
where
    D: Digest,
{
    /// Creates a handler that forwards every callback into `sender`.
    pub const fn new(sender: mpsc::Sender<Message<D>>) -> Self {
        Self { sender }
    }
}
impl<D: Digest> Consumer for Handler<D> {
    type Key = Request<D>;
    type Value = Bytes;
    type Failure = ();

    /// Hands a fetched value to the actor and waits for its verdict.
    ///
    /// Yields `false` when the actor's mailbox is closed, or when the actor
    /// drops the reply channel without answering.
    async fn deliver(&mut self, key: Self::Key, value: Self::Value) -> bool {
        let (response, verdict) = oneshot::channel();
        let message = Message::Deliver {
            key,
            value,
            response,
        };
        let forwarded = self.sender.send(message).await.is_ok();
        if !forwarded {
            error!("failed to send deliver message to actor: receiver dropped");
            return false;
        }
        // A dropped reply sender is treated the same as a rejection.
        verdict.await.unwrap_or(false)
    }

    /// Fetch failures are intentionally not forwarded to the actor.
    async fn failed(&mut self, _: Self::Key, _: Self::Failure) {}
}
impl<D: Digest> Producer for Handler<D> {
    type Key = Request<D>;

    /// Asks the actor to serve `key` and returns the channel its reply will
    /// arrive on.
    ///
    /// If the actor's mailbox is closed, the error is logged. The reply sender
    /// is dropped together with the unsent message, so the returned receiver
    /// never yields a value.
    async fn produce(&mut self, key: Self::Key) -> oneshot::Receiver<Bytes> {
        let (response, reply) = oneshot::channel();
        let queued = self
            .sender
            .send(Message::Produce { key, response })
            .await
            .is_ok();
        if !queued {
            error!("failed to send produce message to actor: receiver dropped");
        }
        reply
    }
}
/// Key type exchanged with the resolver (used as both `Consumer::Key` and
/// `Producer::Key` by [`Handler`]).
#[derive(Clone)]
pub enum Request<D: Digest> {
    /// Lookup by digest.
    Block(D),
    /// Lookup of finalized data at `height`.
    Finalized {
        /// Height being requested.
        height: Height,
    },
    /// Lookup of notarized data at `round`.
    Notarized {
        /// Round being requested.
        round: Round,
    },
}
impl<D: Digest> Request<D> {
    /// One-byte tag written ahead of this request's payload on the wire.
    const fn subject(&self) -> u8 {
        match self {
            Self::Block(..) => BLOCK_REQUEST,
            Self::Finalized { .. } => FINALIZED_REQUEST,
            Self::Notarized { .. } => NOTARIZED_REQUEST,
        }
    }

    /// Builds a filter anchored at (a clone of) `self`.
    ///
    /// The returned closure answers as follows:
    /// - anchored at `Finalized { height }`: `true` for `Finalized` requests with
    ///   a strictly greater height, `false` for lower or equal heights, and
    ///   `true` for every other variant;
    /// - anchored at `Notarized { round }`: `true` for `Notarized` requests with
    ///   a strictly greater round, `false` for lower or equal rounds, and
    ///   `true` for every other variant.
    ///
    /// # Panics
    ///
    /// Calling the closure panics when the anchor is a `Block` request.
    pub fn predicate(&self) -> impl Fn(&Self) -> bool + Send + 'static {
        let anchor = self.clone();
        move |candidate: &Self| match &anchor {
            Self::Block(_) => unreachable!("we should never retain by block"),
            Self::Finalized { height: floor } => match candidate {
                Self::Finalized { height } => height > floor,
                _ => true,
            },
            Self::Notarized { round: floor } => match candidate {
                Self::Notarized { round } => round > floor,
                _ => true,
            },
        }
    }
}
impl<D: Digest> Write for Request<D> {
    /// Encodes the subject tag, then the payload of the active variant.
    fn write(&self, buf: &mut impl BufMut) {
        let tag = self.subject();
        tag.write(buf);
        match self {
            Self::Block(digest) => digest.write(buf),
            Self::Finalized { height } => height.write(buf),
            Self::Notarized { round } => round.write(buf),
        }
    }
}
impl<D: Digest> Read for Request<D> {
    type Cfg = ();

    /// Reads a subject tag and then decodes the payload that tag selects.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidEnum` carrying the tag when it is not a
    /// known subject. Otherwise propagates any error from decoding the payload.
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
        match u8::read(buf)? {
            BLOCK_REQUEST => D::read(buf).map(Self::Block),
            FINALIZED_REQUEST => Height::read(buf).map(|height| Self::Finalized { height }),
            NOTARIZED_REQUEST => Round::read(buf).map(|round| Self::Notarized { round }),
            unknown => Err(CodecError::InvalidEnum(unknown)),
        }
    }
}
impl<D: Digest> EncodeSize for Request<D> {
    /// Size of the payload plus one byte for the subject tag.
    fn encode_size(&self) -> usize {
        let payload = match self {
            Self::Block(digest) => digest.encode_size(),
            Self::Finalized { height } => height.encode_size(),
            Self::Notarized { round } => round.encode_size(),
        };
        payload + 1
    }
}
/// Marker impl: `Request` opts into `Span` without overriding any items.
impl<D> Span for Request<D> where D: Digest {}
impl<D: Digest> PartialEq for Request<D> {
    /// Two requests are equal only if they share a variant and equal payloads.
    fn eq(&self, other: &Self) -> bool {
        match self {
            Self::Block(mine) => matches!(other, Self::Block(theirs) if mine == theirs),
            Self::Finalized { height: mine } => {
                matches!(other, Self::Finalized { height: theirs } if mine == theirs)
            }
            Self::Notarized { round: mine } => {
                matches!(other, Self::Notarized { round: theirs } if mine == theirs)
            }
        }
    }
}

impl<D: Digest> Eq for Request<D> {}
impl<D: Digest> Ord for Request<D> {
    /// Orders by subject tag first (`Block` < `Finalized` < `Notarized`), then
    /// by payload within a variant.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.subject()
            .cmp(&other.subject())
            .then_with(|| match (self, other) {
                (Self::Block(a), Self::Block(b)) => a.cmp(b),
                (Self::Finalized { height: a }, Self::Finalized { height: b }) => a.cmp(b),
                (Self::Notarized { round: a }, Self::Notarized { round: b }) => a.cmp(b),
                // Only reached when subjects are equal, which implies equal
                // variants, so this arm is never taken.
                _ => std::cmp::Ordering::Equal,
            })
    }
}

impl<D: Digest> PartialOrd for Request<D> {
    /// Always `Some`; defers to the total order from `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl<D: Digest> Hash for Request<D> {
    /// Hashes the subject tag before the payload, so the variant is part of
    /// the hashed input.
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.subject(), state);
        match self {
            Self::Block(digest) => Hash::hash(digest, state),
            Self::Finalized { height } => Hash::hash(height, state),
            Self::Notarized { round } => Hash::hash(round, state),
        }
    }
}
impl<D: Digest> Display for Request<D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Block(digest) => write!(f, "Block({digest:?})"),
Self::Finalized { height } => write!(f, "Finalized({height:?})"),
Self::Notarized { round } => write!(f, "Notarized({round:?})"),
}
}
}
impl<D: Digest> Debug for Request<D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Block(digest) => write!(f, "Block({digest:?})"),
Self::Finalized { height } => write!(f, "Finalized({height:?})"),
Self::Notarized { round } => write!(f, "Notarized({round:?})"),
}
}
}
#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Request<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Draws a variant index in `0..=2` from `u`, then fills that variant's
    /// payload from the remaining input.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let request = match u.int_in_range(0..=2)? {
            0 => Self::Block(u.arbitrary()?),
            1 => Self::Finalized {
                height: u.arbitrary()?,
            },
            2 => Self::Notarized {
                round: u.arbitrary()?,
            },
            _ => unreachable!(),
        };
        Ok(request)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `Request`: wire encoding, equality/ordering/hashing, the
    // retention predicate, and formatting.
    use super::*;
    use crate::types::{Epoch, View};
    use commonware_codec::{Encode, ReadExt};
    use commonware_cryptography::{
        sha256::{Digest as Sha256Digest, Sha256},
        Hasher as _,
    };
    use std::collections::BTreeSet;

    type D = Sha256Digest;

    #[test]
    fn test_cross_variant_hash_differs() {
        use std::{
            collections::hash_map::DefaultHasher,
            hash::{Hash, Hasher},
        };

        fn hash_of<T: Hash>(t: &T) -> u64 {
            let mut h = DefaultHasher::new();
            t.hash(&mut h);
            h.finish()
        }

        let finalized = Request::<D>::Finalized {
            height: Height::new(1),
        };
        let notarized = Request::<D>::Notarized {
            round: Round::new(Epoch::new(0), View::new(1)),
        };
        assert_ne!(hash_of(&finalized), hash_of(&notarized));
    }

    #[test]
    fn test_subject_block_encoding() {
        let digest = Sha256::hash(b"test");
        let request = Request::<D>::Block(digest);
        let encoded = request.encode();
        // One subject byte followed by the 32-byte SHA-256 digest.
        assert_eq!(encoded.len(), 33);
        assert_eq!(encoded[0], 0);
        let mut buf = encoded.as_ref();
        let decoded = Request::<D>::read(&mut buf).unwrap();
        assert_eq!(request, decoded);
        assert_eq!(decoded, Request::Block(digest));
    }

    #[test]
    fn test_subject_finalized_encoding() {
        let height = Height::new(12345u64);
        let request = Request::<D>::Finalized { height };
        let encoded = request.encode();
        assert_eq!(encoded[0], 1);
        let mut buf = encoded.as_ref();
        let decoded = Request::<D>::read(&mut buf).unwrap();
        assert_eq!(request, decoded);
        assert_eq!(decoded, Request::Finalized { height });
    }

    #[test]
    fn test_subject_notarized_encoding() {
        let round = Round::new(Epoch::new(67890), View::new(12345));
        let request = Request::<D>::Notarized { round };
        let encoded = request.encode();
        assert_eq!(encoded[0], 2);
        let mut buf = encoded.as_ref();
        let decoded = Request::<D>::read(&mut buf).unwrap();
        assert_eq!(request, decoded);
        assert_eq!(decoded, Request::Notarized { round });
    }

    #[test]
    fn test_invalid_subject_rejected() {
        // Tag 3 is not assigned to any variant.
        let encoded = [3u8];
        let mut buf = &encoded[..];
        assert!(matches!(
            Request::<D>::read(&mut buf),
            Err(CodecError::InvalidEnum(3))
        ));
    }

    #[test]
    fn test_truncated_payload_rejected() {
        let digest = Sha256::hash(b"test");
        let encoded = Request::<D>::Block(digest).encode();
        // Drop the last digest byte so the payload is incomplete.
        let mut buf = &encoded[..encoded.len() - 1];
        assert!(Request::<D>::read(&mut buf).is_err());
    }

    #[test]
    fn test_subject_hash() {
        use std::collections::HashSet;
        let r1 = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let r2 = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let r3 = Request::<D>::Finalized {
            height: Height::new(200),
        };
        let mut set = HashSet::new();
        set.insert(r1);
        // Equal request is a duplicate.
        assert!(!set.insert(r2));
        // Different height is a new entry.
        assert!(set.insert(r3));
    }

    #[test]
    fn test_subject_predicate() {
        let r1 = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let r2 = Request::<D>::Finalized {
            height: Height::new(200),
        };
        let r3 = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(150)),
        };
        let predicate = r1.predicate();
        // Higher finalized height passes.
        assert!(predicate(&r2));
        // Other variants always pass.
        assert!(predicate(&r3));
        let r1_same = Request::<D>::Finalized {
            height: Height::new(100),
        };
        // Equal height does not pass.
        assert!(!predicate(&r1_same));
    }

    #[test]
    fn test_notarized_predicate() {
        let anchor = Request::<D>::Notarized {
            round: Round::new(Epoch::new(1), View::new(10)),
        };
        let predicate = anchor.predicate();
        assert!(predicate(&Request::Notarized {
            round: Round::new(Epoch::new(1), View::new(11)),
        }));
        assert!(!predicate(&Request::Notarized {
            round: Round::new(Epoch::new(1), View::new(10)),
        }));
        assert!(!predicate(&Request::Notarized {
            round: Round::new(Epoch::new(1), View::new(9)),
        }));
        assert!(predicate(&Request::Finalized {
            height: Height::new(0),
        }));
        assert!(predicate(&Request::Block(Sha256::hash(b"test"))));
    }

    #[test]
    #[should_panic(expected = "we should never retain by block")]
    fn test_block_predicate_panics() {
        let block = Request::<D>::Block(Sha256::hash(b"test"));
        let predicate = block.predicate();
        predicate(&block);
    }

    #[test]
    fn test_encode_size() {
        let digest = Sha256::hash(&[0u8; 32]);
        let r1 = Request::<D>::Block(digest);
        let r2 = Request::<D>::Finalized {
            height: Height::new(u64::MAX),
        };
        let r3 = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(0)),
        };
        assert_eq!(r1.encode_size(), r1.encode().len());
        assert_eq!(r2.encode_size(), r2.encode().len());
        assert_eq!(r3.encode_size(), r3.encode().len());
    }

    #[test]
    fn test_display_matches_debug() {
        let requests = [
            Request::<D>::Block(Sha256::hash(b"test")),
            Request::<D>::Finalized {
                height: Height::new(7),
            },
            Request::<D>::Notarized {
                round: Round::new(Epoch::new(1), View::new(2)),
            },
        ];
        for request in &requests {
            assert_eq!(format!("{request}"), format!("{request:?}"));
        }
    }

    #[test]
    fn test_request_ord_same_variant() {
        let digest1 = Sha256::hash(b"test1");
        let digest2 = Sha256::hash(b"test2");
        let block1 = Request::<D>::Block(digest1);
        let block2 = Request::<D>::Block(digest2);
        if digest1 < digest2 {
            assert!(block1 < block2);
            assert!(block2 > block1);
        } else {
            assert!(block1 > block2);
            assert!(block2 < block1);
        }
        let fin1 = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let fin2 = Request::<D>::Finalized {
            height: Height::new(200),
        };
        let fin3 = Request::<D>::Finalized {
            height: Height::new(200),
        };
        assert!(fin1 < fin2);
        assert!(fin2 > fin1);
        assert_eq!(fin2.cmp(&fin3), std::cmp::Ordering::Equal);
        let not1 = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(50)),
        };
        let not2 = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(150)),
        };
        let not3 = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(150)),
        };
        assert!(not1 < not2);
        assert!(not2 > not1);
        assert_eq!(not2.cmp(&not3), std::cmp::Ordering::Equal);
    }

    #[test]
    fn test_request_ord_cross_variant() {
        let digest = Sha256::hash(b"test");
        let block = Request::<D>::Block(digest);
        let finalized = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let notarized = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(200)),
        };
        assert!(block < finalized);
        assert!(block < notarized);
        assert!(finalized < notarized);
        assert!(finalized > block);
        assert!(notarized > block);
        assert!(notarized > finalized);
        assert_eq!(block.cmp(&finalized), std::cmp::Ordering::Less);
        assert_eq!(block.cmp(&notarized), std::cmp::Ordering::Less);
        assert_eq!(finalized.cmp(&notarized), std::cmp::Ordering::Less);
        assert_eq!(finalized.cmp(&block), std::cmp::Ordering::Greater);
        assert_eq!(notarized.cmp(&block), std::cmp::Ordering::Greater);
        assert_eq!(notarized.cmp(&finalized), std::cmp::Ordering::Greater);
    }

    #[test]
    fn test_request_partial_ord() {
        let digest1 = Sha256::hash(b"test1");
        let digest2 = Sha256::hash(b"test2");
        let block1 = Request::<D>::Block(digest1);
        let block2 = Request::<D>::Block(digest2);
        let finalized = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let notarized = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(200)),
        };
        assert!(block1.partial_cmp(&block2).is_some());
        assert!(block1.partial_cmp(&finalized).is_some());
        assert!(finalized.partial_cmp(&notarized).is_some());
        assert_eq!(
            block1.partial_cmp(&finalized),
            Some(std::cmp::Ordering::Less)
        );
        assert_eq!(
            finalized.partial_cmp(&notarized),
            Some(std::cmp::Ordering::Less)
        );
        assert_eq!(
            notarized.partial_cmp(&block1),
            Some(std::cmp::Ordering::Greater)
        );
    }

    #[test]
    fn test_request_ord_sorting() {
        let digest1 = Sha256::hash(b"a");
        let digest2 = Sha256::hash(b"b");
        let digest3 = Sha256::hash(b"c");
        let requests = vec![
            Request::<D>::Notarized {
                round: Round::new(Epoch::new(333), View::new(300)),
            },
            Request::<D>::Block(digest2),
            Request::<D>::Finalized {
                height: Height::new(200),
            },
            Request::<D>::Block(digest1),
            Request::<D>::Notarized {
                round: Round::new(Epoch::new(333), View::new(250)),
            },
            Request::<D>::Finalized {
                height: Height::new(100),
            },
            Request::<D>::Block(digest3),
        ];
        let sorted: Vec<_> = requests
            .into_iter()
            .collect::<BTreeSet<_>>()
            .into_iter()
            .collect();
        assert_eq!(sorted.len(), 7);
        assert!(matches!(sorted[0], Request::<D>::Block(_)));
        assert!(matches!(sorted[1], Request::<D>::Block(_)));
        assert!(matches!(sorted[2], Request::<D>::Block(_)));
        assert_eq!(
            sorted[3],
            Request::<D>::Finalized {
                height: Height::new(100)
            }
        );
        assert_eq!(
            sorted[4],
            Request::<D>::Finalized {
                height: Height::new(200)
            }
        );
        assert_eq!(
            sorted[5],
            Request::<D>::Notarized {
                round: Round::new(Epoch::new(333), View::new(250))
            }
        );
        assert_eq!(
            sorted[6],
            Request::<D>::Notarized {
                round: Round::new(Epoch::new(333), View::new(300))
            }
        );
    }

    #[test]
    fn test_request_ord_edge_cases() {
        let min_finalized = Request::<D>::Finalized {
            height: Height::new(0),
        };
        let max_finalized = Request::<D>::Finalized {
            height: Height::new(u64::MAX),
        };
        let min_notarized = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(0)),
        };
        let max_notarized = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(u64::MAX)),
        };
        assert!(min_finalized < max_finalized);
        assert!(min_notarized < max_notarized);
        assert!(max_finalized < min_notarized);
        let digest = Sha256::hash(b"self");
        let block = Request::<D>::Block(digest);
        assert_eq!(block.cmp(&block), std::cmp::Ordering::Equal);
        assert_eq!(min_finalized.cmp(&min_finalized), std::cmp::Ordering::Equal);
        assert_eq!(max_notarized.cmp(&max_notarized), std::cmp::Ordering::Equal);
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Request<D>>
        }
    }
}