Skip to main content

commonware_consensus/marshal/resolver/
handler.rs

1use crate::types::{Height, Round};
2use bytes::{Buf, BufMut, Bytes};
3use commonware_codec::{EncodeSize, Error as CodecError, Read, ReadExt, Write};
4use commonware_cryptography::Digest;
5use commonware_resolver::{p2p::Producer, Consumer};
6use commonware_utils::{
7    channel::{mpsc, oneshot},
8    Span,
9};
10use std::{
11    fmt::{Debug, Display},
12    hash::{Hash, Hasher},
13};
14use tracing::error;
15
/// Wire tag for [`Request::Block`]; written as the first byte of an encoded request.
const BLOCK_REQUEST: u8 = 0;
/// Wire tag for [`Request::Finalized`]; written as the first byte of an encoded request.
const FINALIZED_REQUEST: u8 = 1;
/// Wire tag for [`Request::Notarized`]; written as the first byte of an encoded request.
const NOTARIZED_REQUEST: u8 = 2;
20
21/// Messages sent from the resolver's [Consumer]/[Producer] implementation
22/// to the marshal actor.
23pub enum Message<D: Digest> {
24    /// A request to deliver a value for a given key.
25    Deliver {
26        /// The key of the value being delivered.
27        key: Request<D>,
28        /// The value being delivered.
29        value: Bytes,
30        /// A channel to send the result of the delivery (true for success).
31        response: oneshot::Sender<bool>,
32    },
33    /// A request to produce a value for a given key.
34    Produce {
35        /// The key of the value to produce.
36        key: Request<D>,
37        /// A channel to send the produced value.
38        response: oneshot::Sender<Bytes>,
39    },
40}
41
42/// A handler that forwards requests from the resolver to the marshal actor.
43///
44/// This struct implements the [Consumer] and [Producer] traits from the
45/// resolver, and acts as a bridge to the main actor loop.
46#[derive(Clone)]
47pub struct Handler<D: Digest> {
48    sender: mpsc::Sender<Message<D>>,
49}
50
51impl<D: Digest> Handler<D> {
52    /// Creates a new handler.
53    pub const fn new(sender: mpsc::Sender<Message<D>>) -> Self {
54        Self { sender }
55    }
56}
57
58impl<D: Digest> Consumer for Handler<D> {
59    type Key = Request<D>;
60    type Value = Bytes;
61    type Failure = ();
62
63    async fn deliver(&mut self, key: Self::Key, value: Self::Value) -> bool {
64        let (response, receiver) = oneshot::channel();
65        if self
66            .sender
67            .send(Message::Deliver {
68                key,
69                value,
70                response,
71            })
72            .await
73            .is_err()
74        {
75            error!("failed to send deliver message to actor: receiver dropped");
76            return false;
77        }
78        receiver.await.unwrap_or(false)
79    }
80
81    async fn failed(&mut self, _: Self::Key, _: Self::Failure) {
82        // We don't need to do anything on failure, the resolver will retry.
83    }
84}
85
86impl<D: Digest> Producer for Handler<D> {
87    type Key = Request<D>;
88
89    async fn produce(&mut self, key: Self::Key) -> oneshot::Receiver<Bytes> {
90        let (response, receiver) = oneshot::channel();
91        if self
92            .sender
93            .send(Message::Produce { key, response })
94            .await
95            .is_err()
96        {
97            error!("failed to send produce message to actor: receiver dropped");
98        }
99        receiver
100    }
101}
102
103/// A request for backfilling data.
104#[derive(Clone)]
105pub enum Request<D: Digest> {
106    Block(D),
107    Finalized { height: Height },
108    Notarized { round: Round },
109}
110
111impl<D: Digest> Request<D> {
112    /// The subject of the request.
113    const fn subject(&self) -> u8 {
114        match self {
115            Self::Block(_) => BLOCK_REQUEST,
116            Self::Finalized { .. } => FINALIZED_REQUEST,
117            Self::Notarized { .. } => NOTARIZED_REQUEST,
118        }
119    }
120
121    /// The predicate to use when pruning subjects related to this subject.
122    ///
123    /// Specifically, any subjects unrelated will be left unmodified. Any related
124    /// subjects will be pruned if they are "less than or equal to" this subject.
125    pub fn predicate(&self) -> impl Fn(&Self) -> bool + Send + 'static {
126        let cloned = self.clone();
127        move |s| match (&cloned, &s) {
128            (Self::Block(_), _) => unreachable!("we should never retain by block"),
129            (Self::Finalized { height: mine }, Self::Finalized { height: theirs }) => {
130                *theirs > *mine
131            }
132            (Self::Finalized { .. }, _) => true,
133            (Self::Notarized { round: mine }, Self::Notarized { round: theirs }) => *theirs > *mine,
134            (Self::Notarized { .. }, _) => true,
135        }
136    }
137}
138
139impl<D: Digest> Write for Request<D> {
140    fn write(&self, buf: &mut impl BufMut) {
141        self.subject().write(buf);
142        match self {
143            Self::Block(digest) => digest.write(buf),
144            Self::Finalized { height } => height.write(buf),
145            Self::Notarized { round } => round.write(buf),
146        }
147    }
148}
149
150impl<D: Digest> Read for Request<D> {
151    type Cfg = ();
152
153    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
154        let request = match u8::read(buf)? {
155            BLOCK_REQUEST => Self::Block(D::read(buf)?),
156            FINALIZED_REQUEST => Self::Finalized {
157                height: Height::read(buf)?,
158            },
159            NOTARIZED_REQUEST => Self::Notarized {
160                round: Round::read(buf)?,
161            },
162            i => return Err(CodecError::InvalidEnum(i)),
163        };
164        Ok(request)
165    }
166}
167
168impl<D: Digest> EncodeSize for Request<D> {
169    fn encode_size(&self) -> usize {
170        1 + match self {
171            Self::Block(block) => block.encode_size(),
172            Self::Finalized { height } => height.encode_size(),
173            Self::Notarized { round } => round.encode_size(),
174        }
175    }
176}
177
178impl<D: Digest> Span for Request<D> {}
179
180impl<D: Digest> PartialEq for Request<D> {
181    fn eq(&self, other: &Self) -> bool {
182        match (&self, &other) {
183            (Self::Block(a), Self::Block(b)) => a == b,
184            (Self::Finalized { height: a }, Self::Finalized { height: b }) => a == b,
185            (Self::Notarized { round: a }, Self::Notarized { round: b }) => a == b,
186            _ => false,
187        }
188    }
189}
190
191impl<D: Digest> Eq for Request<D> {}
192
193impl<D: Digest> Ord for Request<D> {
194    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
195        match (&self, &other) {
196            (Self::Block(a), Self::Block(b)) => a.cmp(b),
197            (Self::Finalized { height: a }, Self::Finalized { height: b }) => a.cmp(b),
198            (Self::Notarized { round: a }, Self::Notarized { round: b }) => a.cmp(b),
199            (a, b) => a.subject().cmp(&b.subject()),
200        }
201    }
202}
203
204impl<D: Digest> PartialOrd for Request<D> {
205    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
206        Some(self.cmp(other))
207    }
208}
209
210impl<D: Digest> Hash for Request<D> {
211    fn hash<H: Hasher>(&self, state: &mut H) {
212        self.subject().hash(state);
213        match self {
214            Self::Block(digest) => digest.hash(state),
215            Self::Finalized { height } => height.hash(state),
216            Self::Notarized { round } => round.hash(state),
217        }
218    }
219}
220
221impl<D: Digest> Display for Request<D> {
222    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223        match self {
224            Self::Block(digest) => write!(f, "Block({digest:?})"),
225            Self::Finalized { height } => write!(f, "Finalized({height:?})"),
226            Self::Notarized { round } => write!(f, "Notarized({round:?})"),
227        }
228    }
229}
230
231impl<D: Digest> Debug for Request<D> {
232    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233        match self {
234            Self::Block(digest) => write!(f, "Block({digest:?})"),
235            Self::Finalized { height } => write!(f, "Finalized({height:?})"),
236            Self::Notarized { round } => write!(f, "Notarized({round:?})"),
237        }
238    }
239}
240
#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Request<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Chooses a variant with `int_in_range(0..=2)`, then draws that variant's
    /// payload from the remaining input.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let request = match u.int_in_range(0..=2)? {
            0 => Self::Block(u.arbitrary()?),
            1 => Self::Finalized {
                height: u.arbitrary()?,
            },
            2 => Self::Notarized {
                round: u.arbitrary()?,
            },
            _ => unreachable!(),
        };
        Ok(request)
    }
}
260
/// Unit tests for [`Request`] encoding, ordering, hashing, formatting and
/// retention predicates.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Epoch, View};
    use commonware_codec::{Encode, ReadExt};
    use commonware_cryptography::{
        sha256::{Digest as Sha256Digest, Sha256},
        Hasher as _,
    };
    use std::collections::BTreeSet;

    type D = Sha256Digest;

    #[test]
    fn test_cross_variant_hash_differs() {
        use std::{
            collections::hash_map::DefaultHasher,
            hash::{Hash, Hasher},
        };

        fn hash_of<T: Hash>(t: &T) -> u64 {
            let mut h = DefaultHasher::new();
            t.hash(&mut h);
            h.finish()
        }

        let finalized = Request::<D>::Finalized {
            height: Height::new(1),
        };
        let notarized = Request::<D>::Notarized {
            round: Round::new(Epoch::new(0), View::new(1)),
        };
        assert_ne!(hash_of(&finalized), hash_of(&notarized));
    }

    #[test]
    fn test_subject_block_encoding() {
        let digest = Sha256::hash(b"test");
        let request = Request::<D>::Block(digest);

        // Test encoding
        let encoded = request.encode();
        assert_eq!(encoded.len(), 33); // 1 byte for enum variant + 32 bytes for digest
        assert_eq!(encoded[0], 0); // Block variant

        // Test decoding
        let mut buf = encoded.as_ref();
        let decoded = Request::<D>::read(&mut buf).unwrap();
        assert_eq!(request, decoded);
        assert_eq!(decoded, Request::Block(digest));
    }

    #[test]
    fn test_subject_finalized_encoding() {
        let height = Height::new(12345u64);
        let request = Request::<D>::Finalized { height };

        // Test encoding
        let encoded = request.encode();
        assert_eq!(encoded[0], 1); // Finalized variant

        // Test decoding
        let mut buf = encoded.as_ref();
        let decoded = Request::<D>::read(&mut buf).unwrap();
        assert_eq!(request, decoded);
        assert_eq!(decoded, Request::Finalized { height });
    }

    #[test]
    fn test_subject_notarized_encoding() {
        let round = Round::new(Epoch::new(67890), View::new(12345));
        let request = Request::<D>::Notarized { round };

        // Test encoding
        let encoded = request.encode();
        assert_eq!(encoded[0], 2); // Notarized variant

        // Test decoding
        let mut buf = encoded.as_ref();
        let decoded = Request::<D>::read(&mut buf).unwrap();
        assert_eq!(request, decoded);
        assert_eq!(decoded, Request::Notarized { round });
    }

    #[test]
    fn test_decode_invalid_subject() {
        // Any tag other than the three known subjects must be rejected.
        for tag in [NOTARIZED_REQUEST + 1, u8::MAX] {
            let bytes = [tag];
            let mut buf = &bytes[..];
            let result = Request::<D>::read(&mut buf);
            assert!(matches!(result, Err(CodecError::InvalidEnum(t)) if t == tag));
        }
    }

    #[test]
    fn test_decode_truncated_input() {
        // No subject tag at all.
        let mut buf: &[u8] = &[];
        assert!(Request::<D>::read(&mut buf).is_err());

        // Finalized subject with no height following it.
        let bytes = [FINALIZED_REQUEST];
        let mut buf = &bytes[..];
        assert!(Request::<D>::read(&mut buf).is_err());

        // Block subject followed by a digest that is too short.
        let mut block = vec![BLOCK_REQUEST];
        block.extend_from_slice(&[0u8; 16]);
        let mut buf = block.as_slice();
        assert!(Request::<D>::read(&mut buf).is_err());
    }

    #[test]
    fn test_subject_hash() {
        use std::collections::HashSet;

        let r1 = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let r2 = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let r3 = Request::<D>::Finalized {
            height: Height::new(200),
        };

        let mut set = HashSet::new();
        set.insert(r1);
        assert!(!set.insert(r2)); // Should not insert duplicate
        assert!(set.insert(r3)); // Should insert different value
    }

    #[test]
    fn test_subject_predicate() {
        let r1 = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let r2 = Request::<D>::Finalized {
            height: Height::new(200),
        };
        let r3 = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(150)),
        };

        let predicate = r1.predicate();
        assert!(predicate(&r2)); // r2.height > r1.height
        assert!(predicate(&r3)); // Different variant (notarized)

        let r1_same = Request::<D>::Finalized {
            height: Height::new(100),
        };
        assert!(!predicate(&r1_same)); // Same height, should not pass
    }

    #[test]
    fn test_finalized_predicate_prunes_lower_and_keeps_blocks() {
        let predicate = Request::<D>::Finalized {
            height: Height::new(100),
        }
        .predicate();

        // Lower heights are pruned.
        assert!(!predicate(&Request::Finalized {
            height: Height::new(99)
        }));
        assert!(!predicate(&Request::Finalized {
            height: Height::new(0)
        }));

        // Unrelated variants are retained.
        assert!(predicate(&Request::Block(Sha256::hash(b"block"))));
    }

    #[test]
    fn test_notarized_predicate() {
        let anchor = Request::<D>::Notarized {
            round: Round::new(Epoch::new(1), View::new(10)),
        };
        let predicate = anchor.predicate();

        let later = Request::<D>::Notarized {
            round: Round::new(Epoch::new(1), View::new(11)),
        };
        let same = Request::<D>::Notarized {
            round: Round::new(Epoch::new(1), View::new(10)),
        };
        let earlier = Request::<D>::Notarized {
            round: Round::new(Epoch::new(1), View::new(9)),
        };
        assert!(predicate(&later));
        assert!(!predicate(&same));
        assert!(!predicate(&earlier));

        // Unrelated variants are always retained.
        assert!(predicate(&Request::Finalized {
            height: Height::new(0)
        }));
        assert!(predicate(&Request::Block(Sha256::hash(b"unrelated"))));
    }

    #[test]
    #[should_panic(expected = "we should never retain by block")]
    fn test_block_predicate_panics_when_invoked() {
        let predicate = Request::<D>::Block(Sha256::hash(b"block")).predicate();
        let _ = predicate(&Request::Finalized {
            height: Height::new(1),
        });
    }

    #[test]
    fn test_encode_size() {
        let digest = Sha256::hash(&[0u8; 32]);
        let r1 = Request::<D>::Block(digest);
        let r2 = Request::<D>::Finalized {
            height: Height::new(u64::MAX),
        };
        let r3 = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(0)),
        };

        // Verify encode_size matches actual encoded length
        assert_eq!(r1.encode_size(), r1.encode().len());
        assert_eq!(r2.encode_size(), r2.encode().len());
        assert_eq!(r3.encode_size(), r3.encode().len());
    }

    #[test]
    fn test_display_matches_debug() {
        let requests = [
            Request::<D>::Block(Sha256::hash(b"display")),
            Request::<D>::Finalized {
                height: Height::new(7),
            },
            Request::<D>::Notarized {
                round: Round::new(Epoch::new(1), View::new(2)),
            },
        ];
        for request in &requests {
            assert_eq!(format!("{request}"), format!("{request:?}"));
        }
        assert!(format!("{}", requests[0]).starts_with("Block("));
        assert!(format!("{}", requests[1]).starts_with("Finalized("));
        assert!(format!("{}", requests[2]).starts_with("Notarized("));
    }

    #[test]
    fn test_request_ord_same_variant() {
        // Test ordering within the same variant
        let digest1 = Sha256::hash(b"test1");
        let digest2 = Sha256::hash(b"test2");
        let block1 = Request::<D>::Block(digest1);
        let block2 = Request::<D>::Block(digest2);

        // Block ordering depends on digest ordering
        if digest1 < digest2 {
            assert!(block1 < block2);
            assert!(block2 > block1);
        } else {
            assert!(block1 > block2);
            assert!(block2 < block1);
        }

        // Finalized ordering by height
        let fin1 = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let fin2 = Request::<D>::Finalized {
            height: Height::new(200),
        };
        let fin3 = Request::<D>::Finalized {
            height: Height::new(200),
        };

        assert!(fin1 < fin2);
        assert!(fin2 > fin1);
        assert_eq!(fin2.cmp(&fin3), std::cmp::Ordering::Equal);

        // Notarized ordering by round
        let not1 = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(50)),
        };
        let not2 = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(150)),
        };
        let not3 = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(150)),
        };

        assert!(not1 < not2);
        assert!(not2 > not1);
        assert_eq!(not2.cmp(&not3), std::cmp::Ordering::Equal);
    }

    #[test]
    fn test_request_ord_cross_variant() {
        let digest = Sha256::hash(b"test");
        let block = Request::<D>::Block(digest);
        let finalized = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let notarized = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(200)),
        };

        // Block < Finalized < Notarized
        assert!(block < finalized);
        assert!(block < notarized);
        assert!(finalized < notarized);

        assert!(finalized > block);
        assert!(notarized > block);
        assert!(notarized > finalized);

        // Test all combinations
        assert_eq!(block.cmp(&finalized), std::cmp::Ordering::Less);
        assert_eq!(block.cmp(&notarized), std::cmp::Ordering::Less);
        assert_eq!(finalized.cmp(&notarized), std::cmp::Ordering::Less);
        assert_eq!(finalized.cmp(&block), std::cmp::Ordering::Greater);
        assert_eq!(notarized.cmp(&block), std::cmp::Ordering::Greater);
        assert_eq!(notarized.cmp(&finalized), std::cmp::Ordering::Greater);
    }

    #[test]
    fn test_request_partial_ord() {
        let digest1 = Sha256::hash(b"test1");
        let digest2 = Sha256::hash(b"test2");
        let block1 = Request::<D>::Block(digest1);
        let block2 = Request::<D>::Block(digest2);
        let finalized = Request::<D>::Finalized {
            height: Height::new(100),
        };
        let notarized = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(200)),
        };

        // PartialOrd should always return Some
        assert!(block1.partial_cmp(&block2).is_some());
        assert!(block1.partial_cmp(&finalized).is_some());
        assert!(finalized.partial_cmp(&notarized).is_some());

        // Verify consistency with Ord
        assert_eq!(
            block1.partial_cmp(&finalized),
            Some(std::cmp::Ordering::Less)
        );
        assert_eq!(
            finalized.partial_cmp(&notarized),
            Some(std::cmp::Ordering::Less)
        );
        assert_eq!(
            notarized.partial_cmp(&block1),
            Some(std::cmp::Ordering::Greater)
        );
    }

    #[test]
    fn test_request_ord_sorting() {
        let digest1 = Sha256::hash(b"a");
        let digest2 = Sha256::hash(b"b");
        let digest3 = Sha256::hash(b"c");

        let requests = vec![
            Request::<D>::Notarized {
                round: Round::new(Epoch::new(333), View::new(300)),
            },
            Request::<D>::Block(digest2),
            Request::<D>::Finalized {
                height: Height::new(200),
            },
            Request::<D>::Block(digest1),
            Request::<D>::Notarized {
                round: Round::new(Epoch::new(333), View::new(250)),
            },
            Request::<D>::Finalized {
                height: Height::new(100),
            },
            Request::<D>::Block(digest3),
        ];

        // Sort using BTreeSet (uses Ord)
        let sorted: Vec<_> = requests
            .into_iter()
            .collect::<BTreeSet<_>>()
            .into_iter()
            .collect();

        // Verify order: all Blocks first (sorted by digest), then Finalized (by height), then Notarized (by round)
        assert_eq!(sorted.len(), 7);

        // Check that all blocks come first
        assert!(matches!(sorted[0], Request::<D>::Block(_)));
        assert!(matches!(sorted[1], Request::<D>::Block(_)));
        assert!(matches!(sorted[2], Request::<D>::Block(_)));

        // Check that finalized come next
        assert_eq!(
            sorted[3],
            Request::<D>::Finalized {
                height: Height::new(100)
            }
        );
        assert_eq!(
            sorted[4],
            Request::<D>::Finalized {
                height: Height::new(200)
            }
        );

        // Check that notarized come last
        assert_eq!(
            sorted[5],
            Request::<D>::Notarized {
                round: Round::new(Epoch::new(333), View::new(250))
            }
        );
        assert_eq!(
            sorted[6],
            Request::<D>::Notarized {
                round: Round::new(Epoch::new(333), View::new(300))
            }
        );
    }

    #[test]
    fn test_request_ord_edge_cases() {
        // Test with extreme values
        let min_finalized = Request::<D>::Finalized {
            height: Height::new(0),
        };
        let max_finalized = Request::<D>::Finalized {
            height: Height::new(u64::MAX),
        };
        let min_notarized = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(0)),
        };
        let max_notarized = Request::<D>::Notarized {
            round: Round::new(Epoch::new(333), View::new(u64::MAX)),
        };

        assert!(min_finalized < max_finalized);
        assert!(min_notarized < max_notarized);
        assert!(max_finalized < min_notarized);

        // Test self-comparison
        let digest = Sha256::hash(b"self");
        let block = Request::<D>::Block(digest);
        assert_eq!(block.cmp(&block), std::cmp::Ordering::Equal);
        assert_eq!(min_finalized.cmp(&min_finalized), std::cmp::Ordering::Equal);
        assert_eq!(max_notarized.cmp(&max_notarized), std::cmp::Ordering::Equal);
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Request<D>>
        }
    }
}