1use crate::{
3 merkle::mmr::{Location, StandardHasher},
4 qmdb::{
5 self,
6 sync::{
7 database::Config as _,
8 error::EngineError,
9 requests::{Id as RequestId, Requests},
10 resolver::{FetchResult, Resolver},
11 target::validate_update,
12 Database, Error as SyncError, Journal, Target,
13 },
14 },
15};
16use commonware_codec::Encode;
17use commonware_cryptography::Digest;
18use commonware_macros::select;
19use commonware_runtime::Metrics as _;
20use commonware_utils::{
21 channel::{
22 fallible::{AsyncFallibleExt, OneshotExt as _},
23 mpsc, oneshot,
24 },
25 NZU64,
26};
27use futures::{
28 future::{pending, Either},
29 StreamExt,
30};
31use mpsc::error::TryRecvError;
32use std::{
33 collections::{BTreeMap, HashMap, VecDeque},
34 fmt::Debug,
35 num::NonZeroU64,
36};
37
38type Error<DB, R> = qmdb::sync::Error<<R as Resolver>::Error, <DB as Database>::Digest>;
40
/// Result of advancing the engine by one step.
#[derive(Debug)]
pub(crate) enum NextStep<C, D> {
    /// Sync has not finished; holds the engine so it can be stepped again.
    Continue(C),
    /// Sync has finished; holds the database built from the synced journal.
    Complete(D),
}
49
50#[derive(Debug)]
52enum Event<Op, D: Digest, E> {
53 TargetUpdate(Target<D>),
55 BatchReceived(IndexedFetchResult<Op, D, E>),
57 UpdateChannelClosed,
59 FinishRequested,
61 FinishChannelClosed,
63}
64
65#[derive(Debug)]
67pub(super) struct IndexedFetchResult<Op, D: Digest, E> {
68 pub id: RequestId,
70 pub start_loc: Location,
72 pub result: Result<FetchResult<Op, D>, E>,
74}
75
76async fn wait_for_event<Op, D: Digest, E>(
79 update_rx: &mut Option<mpsc::Receiver<Target<D>>>,
80 finish_rx: &mut Option<mpsc::Receiver<()>>,
81 outstanding_requests: &mut Requests<Op, D, E>,
82) -> Option<Event<Op, D, E>> {
83 if outstanding_requests.len() == 0 && update_rx.is_none() && finish_rx.is_none() {
84 return None;
85 }
86
87 let target_update_fut = update_rx.as_mut().map_or_else(
88 || Either::Right(pending()),
89 |update_rx| Either::Left(update_rx.recv()),
90 );
91 let finish_fut = finish_rx.as_mut().map_or_else(
92 || Either::Right(pending()),
93 |finish_rx| Either::Left(finish_rx.recv()),
94 );
95 let batch_result_fut = if outstanding_requests.len() == 0 {
96 Either::Right(pending())
97 } else {
98 Either::Left(outstanding_requests.futures_mut().next())
99 };
100
101 select! {
102 finish = finish_fut => finish.map_or_else(
103 || Some(Event::FinishChannelClosed),
104 |_| Some(Event::FinishRequested)
105 ),
106 target = target_update_fut => target.map_or_else(
107 || Some(Event::UpdateChannelClosed),
108 |target| Some(Event::TargetUpdate(target))
109 ),
110 result = batch_result_fut => result.map(|fetch_result| Event::BatchReceived(fetch_result)),
111 }
112}
113
114pub struct Config<DB, R>
116where
117 DB: Database,
118 R: Resolver<Op = DB::Op, Digest = DB::Digest>,
119 DB::Op: Encode,
120{
121 pub context: DB::Context,
123 pub resolver: R,
125 pub target: Target<DB::Digest>,
127 pub max_outstanding_requests: usize,
129 pub fetch_batch_size: NonZeroU64,
131 pub apply_batch_size: usize,
133 pub db_config: DB::Config,
135 pub update_rx: Option<mpsc::Receiver<Target<DB::Digest>>>,
137 pub finish_rx: Option<mpsc::Receiver<()>>,
141 pub reached_target_tx: Option<mpsc::Sender<Target<DB::Digest>>>,
148 pub max_retained_roots: usize,
152}
153pub(crate) struct Engine<DB, R>
155where
156 DB: Database,
157 R: Resolver<Op = DB::Op, Digest = DB::Digest>,
158 DB::Op: Encode,
159{
160 outstanding_requests: Requests<DB::Op, DB::Digest, R::Error>,
162
163 fetched_operations: BTreeMap<Location, Vec<DB::Op>>,
169
170 pinned_nodes: Option<Vec<DB::Digest>>,
172
173 retained_roots: HashMap<Location, DB::Digest>,
179
180 retained_roots_order: VecDeque<Location>,
183
184 max_retained_roots: usize,
186
187 target: Target<DB::Digest>,
189
190 max_outstanding_requests: usize,
192
193 fetch_batch_size: NonZeroU64,
195
196 apply_batch_size: usize,
198
199 journal: DB::Journal,
201
202 resolver: R,
204
205 hasher: StandardHasher<DB::Hasher>,
207
208 context: DB::Context,
210
211 config: DB::Config,
213
214 update_rx: Option<mpsc::Receiver<Target<DB::Digest>>>,
216
217 finish_rx: Option<mpsc::Receiver<()>>,
221
222 reached_target_tx: Option<mpsc::Sender<Target<DB::Digest>>>,
229
230 finish_requested: bool,
232
233 reached_current_target_reported: bool,
235}
236
#[cfg(test)]
impl<DB: Database, R> Engine<DB, R>
where
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Test-only accessor for the journal that operations are being synced into.
    pub(crate) fn journal(&self) -> &DB::Journal {
        let journal = &self.journal;
        journal
    }
}
248
249impl<DB, R> Engine<DB, R>
250where
251 DB: Database,
252 R: Resolver<Op = DB::Op, Digest = DB::Digest>,
253 DB::Op: Encode,
254{
255 pub async fn new(config: Config<DB, R>) -> Result<Self, Error<DB, R>> {
257 if !config.target.range.end().is_valid() {
258 return Err(SyncError::Engine(EngineError::InvalidTarget {
259 lower_bound_pos: config.target.range.start(),
260 upper_bound_pos: config.target.range.end(),
261 }));
262 }
263
264 let journal = <DB::Journal as Journal>::new(
266 config.context.with_label("journal"),
267 config.db_config.journal_config(),
268 config.target.range.clone().into(),
269 )
270 .await?;
271
272 let mut engine = Self {
273 outstanding_requests: Requests::new(),
274 fetched_operations: BTreeMap::new(),
275 pinned_nodes: None,
276 retained_roots: HashMap::new(),
277 retained_roots_order: VecDeque::new(),
278 max_retained_roots: config.max_retained_roots,
279 target: config.target.clone(),
280 max_outstanding_requests: config.max_outstanding_requests,
281 fetch_batch_size: config.fetch_batch_size,
282 apply_batch_size: config.apply_batch_size,
283 journal,
284 resolver: config.resolver.clone(),
285 hasher: StandardHasher::<DB::Hasher>::new(),
286 context: config.context,
287 config: config.db_config,
288 update_rx: config.update_rx,
289 finish_rx: config.finish_rx,
290 reached_target_tx: config.reached_target_tx,
291 finish_requested: false,
292 reached_current_target_reported: false,
293 };
294 engine.schedule_requests().await?;
295 Ok(engine)
296 }
297
298 async fn schedule_requests(&mut self) -> Result<(), Error<DB, R>> {
300 let target_size = self.target.range.end();
301
302 if self.pinned_nodes.is_none()
305 && !self
306 .outstanding_requests
307 .contains(&self.target.range.start())
308 {
309 let start_loc = self.target.range.start();
310 let resolver = self.resolver.clone();
311 let (cancel_tx, cancel_rx) = oneshot::channel();
312 let id = self.outstanding_requests.next_id();
313 self.outstanding_requests.insert(
314 id,
315 start_loc,
316 cancel_tx,
317 Box::pin(async move {
318 let result = resolver
319 .get_operations(target_size, start_loc, NZU64!(1), true, cancel_rx)
320 .await;
321 IndexedFetchResult {
322 id,
323 start_loc,
324 result,
325 }
326 }),
327 );
328 }
329
330 let num_requests = self
332 .max_outstanding_requests
333 .saturating_sub(self.outstanding_requests.len());
334
335 let log_size = self.journal.size().await;
336
337 for _ in 0..num_requests {
338 let operation_counts: BTreeMap<Location, u64> = self
340 .fetched_operations
341 .iter()
342 .map(|(&start_loc, operations)| (start_loc, operations.len() as u64))
343 .collect();
344
345 let Some(gap_range) = crate::qmdb::sync::gaps::find_next(
347 Location::new(log_size)..self.target.range.end(),
348 &operation_counts,
349 self.outstanding_requests.locations(),
350 self.fetch_batch_size,
351 ) else {
352 break; };
354
355 let gap_size = *gap_range.end.checked_sub(*gap_range.start).unwrap();
357 let gap_size: NonZeroU64 = gap_size.try_into().unwrap();
358 let batch_size = self.fetch_batch_size.min(gap_size);
359
360 let resolver = self.resolver.clone();
362 let (cancel_tx, cancel_rx) = oneshot::channel();
363 let id = self.outstanding_requests.next_id();
364 self.outstanding_requests.insert(
365 id,
366 gap_range.start,
367 cancel_tx,
368 Box::pin(async move {
369 let result = resolver
370 .get_operations(target_size, gap_range.start, batch_size, false, cancel_rx)
371 .await;
372 IndexedFetchResult {
373 id,
374 start_loc: gap_range.start,
375 result,
376 }
377 }),
378 );
379 }
380
381 Ok(())
382 }
383
384 pub async fn reset_for_target_update(
391 mut self,
392 new_target: Target<DB::Digest>,
393 ) -> Result<Self, Error<DB, R>> {
394 self.journal.resize(new_target.range.start()).await?;
395 self.outstanding_requests
398 .remove_before(new_target.range.start().checked_add(1).unwrap());
399 self.fetched_operations.clear();
400 self.pinned_nodes = None;
401
402 if self.max_retained_roots > 0 {
405 let old_target_size = self.target.range.end();
406 assert!(
407 self.retained_roots
408 .insert(old_target_size, self.target.root)
409 .is_none(),
410 "duplicate retained root for tree size {old_target_size:?}"
411 );
412 self.retained_roots_order.push_back(old_target_size);
413 while self.retained_roots.len() > self.max_retained_roots {
414 if let Some(oldest) = self.retained_roots_order.pop_front() {
415 self.retained_roots.remove(&oldest);
416 }
417 }
418 }
419
420 self.target = new_target;
421 self.reached_current_target_reported = false;
422 Ok(self)
423 }
424
425 fn drain_finish_requests(&mut self) -> Result<(), Error<DB, R>> {
431 let Some(finish_rx) = self.finish_rx.as_mut() else {
432 return Ok(());
433 };
434 match finish_rx.try_recv() {
435 Ok(()) => {
436 self.accept_finish();
437 Ok(())
438 }
439 Err(TryRecvError::Empty) => Ok(()),
440 Err(TryRecvError::Disconnected) => {
441 Err(SyncError::Engine(EngineError::FinishChannelClosed))
442 }
443 }
444 }
445
446 fn accept_finish(&mut self) {
451 self.finish_requested = true;
452 self.finish_rx = None;
453 }
454
455 async fn report_reached_target(&mut self) {
463 if self.reached_current_target_reported {
464 return;
465 }
466 if let Some(sender) = self.reached_target_tx.as_ref() {
467 if !sender.send_lossy(self.target.clone()).await {
468 self.reached_target_tx = None;
469 }
470 }
471 self.reached_current_target_reported = true;
472 }
473
474 pub(crate) fn store_operations(&mut self, start_loc: Location, operations: Vec<DB::Op>) {
476 if operations.is_empty() {
477 return;
478 }
479 self.fetched_operations.insert(start_loc, operations);
480 }
481
482 pub(crate) async fn apply_operations(&mut self) -> Result<(), Error<DB, R>> {
488 let mut next_loc = self.journal.size().await;
489
490 self.fetched_operations.retain(|&start_loc, operations| {
493 assert!(!operations.is_empty());
494 let end_loc = start_loc.checked_add(operations.len() as u64 - 1).unwrap();
495 end_loc >= next_loc
496 });
497
498 loop {
499 let range_start_loc =
502 self.fetched_operations
503 .iter()
504 .find_map(|(range_start, range_ops)| {
505 assert!(!range_ops.is_empty());
506 let range_end =
507 range_start.checked_add(range_ops.len() as u64 - 1).unwrap();
508 if *range_start <= next_loc && next_loc <= range_end {
509 Some(*range_start)
510 } else {
511 None
512 }
513 });
514
515 let Some(range_start_loc) = range_start_loc else {
516 break;
518 };
519
520 let operations = self.fetched_operations.remove(&range_start_loc).unwrap();
522 assert!(!operations.is_empty());
523 let skip_count = (next_loc - *range_start_loc) as usize;
525 let operations_count = operations.len() - skip_count;
526 let remaining_operations = operations.into_iter().skip(skip_count);
527 next_loc += operations_count as u64;
528 self.apply_operations_batch(remaining_operations).await?;
529 }
530
531 Ok(())
532 }
533
534 async fn apply_operations_batch<I>(&mut self, operations: I) -> Result<(), Error<DB, R>>
536 where
537 I: IntoIterator<Item = DB::Op>,
538 {
539 for op in operations {
540 self.journal.append(op).await?;
541 }
544 Ok(())
545 }
546
547 pub async fn is_at_target(&self) -> Result<bool, Error<DB, R>> {
549 let journal_size = self.journal.size().await;
550 let target_journal_size = self.target.range.end();
551
552 if journal_size >= target_journal_size {
554 if journal_size > target_journal_size {
555 return Err(SyncError::Engine(EngineError::InvalidState));
557 }
558 return Ok(true);
559 }
560
561 Ok(false)
562 }
563
564 fn handle_fetch_result(
571 &mut self,
572 fetch_result: IndexedFetchResult<DB::Op, DB::Digest, R::Error>,
573 ) -> Result<(), Error<DB, R>> {
574 if !self.outstanding_requests.remove(fetch_result.id) {
578 return Ok(());
579 }
580
581 let start_loc = fetch_result.start_loc;
582 let FetchResult {
583 proof,
584 operations,
585 success_tx,
586 pinned_nodes,
587 } = fetch_result.result.map_err(SyncError::Resolver)?;
588
589 let operations_len = operations.len() as u64;
591 if operations_len == 0 || operations_len > self.fetch_batch_size.get() {
592 success_tx.send_lossy(false);
595 return Ok(());
596 }
597
598 let is_current = proof.leaves == self.target.range.end();
602 let target_root = if is_current {
603 &self.target.root
604 } else {
605 let Some(root) = self.retained_roots.get(&proof.leaves) else {
606 return Ok(());
610 };
611 root
612 };
613
614 let need_pinned =
618 is_current && self.pinned_nodes.is_none() && start_loc == self.target.range.start();
619 let valid = if need_pinned {
620 let nodes = pinned_nodes.as_deref().unwrap_or(&[]);
621 qmdb::verify_proof_and_pinned_nodes(
622 &self.hasher,
623 &proof,
624 start_loc,
625 &operations,
626 nodes,
627 target_root,
628 )
629 } else {
630 qmdb::verify_proof(&self.hasher, &proof, start_loc, &operations, target_root)
631 };
632
633 success_tx.send_lossy(valid);
635
636 if !valid {
637 if need_pinned {
638 tracing::warn!("boundary proof or pinned nodes failed verification, will retry");
639 }
640 return Ok(());
641 }
642
643 if need_pinned {
645 if let Some(nodes) = pinned_nodes {
646 self.pinned_nodes = Some(nodes);
647 }
648 }
649
650 self.store_operations(start_loc, operations);
652
653 Ok(())
654 }
655
656 async fn handle_event(
658 mut self,
659 event: Event<DB::Op, DB::Digest, R::Error>,
660 ) -> Result<NextStep<Self, DB>, Error<DB, R>> {
661 match event {
662 Event::TargetUpdate(new_target) => {
663 validate_update(&self.target, &new_target)?;
664
665 let mut updated_self = self.reset_for_target_update(new_target).await?;
666 updated_self.schedule_requests().await?;
667 Ok(NextStep::Continue(updated_self))
668 }
669 Event::UpdateChannelClosed => {
670 self.update_rx = None;
671 Ok(NextStep::Continue(self))
672 }
673 Event::FinishRequested => {
674 self.accept_finish();
675 Ok(NextStep::Continue(self))
676 }
677 Event::FinishChannelClosed => Err(SyncError::Engine(EngineError::FinishChannelClosed)),
678 Event::BatchReceived(fetch_result) => {
679 self.handle_fetch_result(fetch_result)?;
680 self.schedule_requests().await?;
681 self.apply_operations().await?;
682 Ok(NextStep::Continue(self))
683 }
684 }
685 }
686
687 pub(crate) async fn step(mut self) -> Result<NextStep<Self, DB>, Error<DB, R>> {
698 self.drain_finish_requests()?;
699
700 if self.is_at_target().await? {
702 self.report_reached_target().await;
703
704 if self.finish_rx.is_some() && !self.finish_requested {
705 let event = wait_for_event(
706 &mut self.update_rx,
707 &mut self.finish_rx,
708 &mut self.outstanding_requests,
709 )
710 .await
711 .ok_or(SyncError::Engine(EngineError::SyncStalled))?;
712 return self.handle_event(event).await;
713 }
714
715 self.journal.sync().await?;
716
717 let database = DB::from_sync_result(
719 self.context,
720 self.config,
721 self.journal,
722 self.pinned_nodes,
723 self.target.range.clone().into(),
724 self.apply_batch_size,
725 )
726 .await?;
727
728 let got_root = database.root();
730 let expected_root = self.target.root;
731 if got_root != expected_root {
732 return Err(SyncError::Engine(EngineError::RootMismatch {
733 expected: expected_root,
734 actual: got_root,
735 }));
736 }
737
738 return Ok(NextStep::Complete(database));
739 }
740
741 let event = wait_for_event(
743 &mut self.update_rx,
744 &mut self.finish_rx,
745 &mut self.outstanding_requests,
746 )
747 .await
748 .ok_or(SyncError::Engine(EngineError::SyncStalled))?;
749 self.handle_event(event).await
750 }
751
752 pub async fn sync(mut self) -> Result<DB, Error<DB, R>> {
757 loop {
759 match self.step().await? {
760 NextStep::Continue(new_engine) => self = new_engine,
761 NextStep::Complete(database) => return Ok(database),
762 }
763 }
764 }
765}
766
#[cfg(test)]
mod tests {
    use super::*;
    use crate::merkle::mmr::Proof;
    use commonware_cryptography::sha256;
    use commonware_utils::channel::oneshot;
    use std::{future::Future, pin::Pin};

    /// Request tracker instantiation exercised by these tests.
    type TestRequests = Requests<i32, sha256::Digest, ()>;

    /// Boxed future stored alongside each tracked request.
    type TestFuture =
        Pin<Box<dyn Future<Output = IndexedFetchResult<i32, sha256::Digest, ()>> + Send>>;

    /// Builds a request future that resolves to an empty, successful response tagged with
    /// `id` and starting at `loc`. The tests only track these futures; they never poll them.
    fn dummy_future(id: RequestId, loc: u64) -> TestFuture {
        Box::pin(async move {
            let proof = Proof {
                leaves: Location::new(0),
                digests: Vec::new(),
            };
            let result = Ok(FetchResult {
                proof,
                operations: Vec::new(),
                success_tx: oneshot::channel().0,
                pinned_nodes: None,
            });
            IndexedFetchResult {
                id,
                start_loc: Location::new(loc),
                result,
            }
        })
    }

    /// Tracks a dummy request that starts at `loc`, and returns the id it was given.
    fn track(requests: &mut TestRequests, loc: u64) -> RequestId {
        let id = requests.next_id();
        let start = Location::new(loc);
        requests.insert(id, start, oneshot::channel().0, dummy_future(id, loc));
        id
    }

    #[test]
    fn test_add_and_remove() {
        let mut requests = TestRequests::new();
        assert_eq!(requests.len(), 0);

        let id = track(&mut requests, 10);
        let loc = Location::new(10);
        assert_eq!(requests.len(), 1);
        assert!(requests.contains(&loc));

        // The first removal succeeds; repeating it is a no-op.
        assert!(requests.remove(id));
        assert!(!requests.contains(&loc));
        assert!(!requests.remove(id));
    }

    #[test]
    fn test_remove_before() {
        let mut requests = TestRequests::new();
        for loc in [5, 10, 15, 20] {
            track(&mut requests, loc);
        }
        assert_eq!(requests.len(), 4);

        // The cutoff is exclusive: only locations strictly below it are removed.
        requests.remove_before(Location::new(10));
        assert_eq!(requests.len(), 3);
        assert!(!requests.contains(&Location::new(5)));
        for loc in [10, 15, 20] {
            assert!(requests.contains(&Location::new(loc)));
        }
    }

    #[test]
    fn test_remove_before_all() {
        let mut requests = TestRequests::new();
        for loc in [5, 10] {
            track(&mut requests, loc);
        }
        assert_eq!(requests.len(), 2);

        requests.remove_before(Location::new(100));
        assert_eq!(requests.len(), 0);
    }

    #[test]
    fn test_remove_before_empty() {
        let mut requests = TestRequests::new();
        requests.remove_before(Location::new(10));
        assert_eq!(requests.len(), 0);
    }

    #[test]
    fn test_remove_before_none() {
        let mut requests = TestRequests::new();
        for loc in [10, 20] {
            track(&mut requests, loc);
        }
        assert_eq!(requests.len(), 2);

        requests.remove_before(Location::new(5));
        assert_eq!(requests.len(), 2);
        for loc in [10, 20] {
            assert!(requests.contains(&Location::new(loc)));
        }
    }

    #[test]
    fn test_superseded_request() {
        let mut requests = TestRequests::new();
        let loc = Location::new(10);

        let old_id = track(&mut requests, 10);
        assert_eq!(requests.len(), 1);

        // A second request at the same location replaces the first one.
        let new_id = track(&mut requests, 10);
        assert_eq!(requests.len(), 1);

        // The superseded id is no longer tracked...
        assert!(!requests.remove(old_id));

        // ...but the location still is, through the new id.
        assert!(requests.contains(&loc));
        assert!(requests.remove(new_id));
        assert!(!requests.contains(&loc));
    }

    #[test]
    fn test_stale_id_after_remove_before() {
        let mut requests = TestRequests::new();

        let old_id = track(&mut requests, 5);
        track(&mut requests, 15);
        requests.remove_before(Location::new(10));

        // The id dropped by `remove_before` can no longer be removed.
        assert!(!requests.remove(old_id));

        // Re-tracking the same location yields a fresh, distinct id.
        let new_id = track(&mut requests, 5);
        assert_ne!(old_id, new_id);
        assert!(requests.remove(new_id));
    }
}