1use crate::{
3 mmr::{Location, StandardHasher},
4 qmdb::{
5 self,
6 sync::{
7 database::Config as _,
8 error::EngineError,
9 requests::Requests,
10 resolver::{FetchResult, Resolver},
11 target::validate_update,
12 Database, Error as SyncError, Journal, Target,
13 },
14 },
15};
16use commonware_codec::Encode;
17use commonware_cryptography::Digest;
18use commonware_macros::select;
19use commonware_runtime::Metrics as _;
20use commonware_utils::{channel::mpsc, NZU64};
21use futures::{future::Either, StreamExt};
22use std::{collections::BTreeMap, fmt::Debug, num::NonZeroU64};
23
/// Convenience alias binding the sync error type to a database `DB` and resolver `R`.
type Error<DB, R> = qmdb::sync::Error<<R as Resolver>::Error, <DB as Database>::Digest>;
26
/// Outcome of a single [`Engine::step`]: either keep stepping or finish.
#[derive(Debug)]
pub(crate) enum NextStep<C, D> {
    /// Sync is not finished; continue stepping with the contained engine.
    Continue(C),
    /// Sync completed; contains the fully-built database.
    Complete(D),
}
35
/// Events the engine can observe while waiting in [`wait_for_event`].
#[derive(Debug)]
enum Event<Op, D: Digest, E> {
    /// A new sync target arrived on the update channel.
    TargetUpdate(Target<D>),
    /// A fetch request completed (successfully or not).
    BatchReceived(IndexedFetchResult<Op, D, E>),
    /// The target-update channel was closed by its sender.
    UpdateChannelClosed,
}
46
/// A fetch result tagged with the location of the first operation requested.
#[derive(Debug)]
pub(super) struct IndexedFetchResult<Op, D: Digest, E> {
    /// Location of the first operation covered by this request.
    pub start_loc: Location,
    /// The resolver's response (proof + operations) or its error.
    pub result: Result<FetchResult<Op, D>, E>,
}
55
56async fn wait_for_event<Op, D: Digest, E>(
59 update_receiver: &mut Option<mpsc::Receiver<Target<D>>>,
60 outstanding_requests: &mut Requests<Op, D, E>,
61) -> Option<Event<Op, D, E>> {
62 let target_update_fut = update_receiver.as_mut().map_or_else(
63 || Either::Right(futures::future::pending()),
64 |update_rx| Either::Left(update_rx.recv()),
65 );
66
67 select! {
68 target = target_update_fut => target.map_or_else(
69 || Some(Event::UpdateChannelClosed),
70 |target| Some(Event::TargetUpdate(target))
71 ),
72 result = outstanding_requests.futures_mut().next() => {
73 result.map(|fetch_result| Event::BatchReceived(fetch_result))
74 },
75 }
76}
77
/// Configuration for building a sync [`Engine`].
pub struct Config<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Runtime context used to construct the journal and, later, the database.
    pub context: DB::Context,
    /// Resolver used to fetch operation batches.
    pub resolver: R,
    /// Initial sync target (operation range and expected root digest).
    pub target: Target<DB::Digest>,
    /// Maximum number of fetch requests allowed in flight at once.
    pub max_outstanding_requests: usize,
    /// Maximum number of operations requested per fetch.
    pub fetch_batch_size: NonZeroU64,
    /// Batch size passed through to `DB::from_sync_result` when sync completes.
    pub apply_batch_size: usize,
    /// Database configuration (also provides the journal configuration).
    pub db_config: DB::Config,
    /// Optional channel on which updated sync targets are received.
    pub update_rx: Option<mpsc::Receiver<Target<DB::Digest>>>,
}
/// State machine that drives a sync to a target database root.
pub(crate) struct Engine<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// In-flight fetch requests, keyed by the start location they cover.
    outstanding_requests: Requests<DB::Op, DB::Digest, R::Error>,

    /// Verified-but-unapplied operation batches, keyed by start location.
    /// Invariant: no batch is empty (enforced by `store_operations`).
    fetched_operations: BTreeMap<Location, Vec<DB::Op>>,

    /// Nodes extracted from the proof of the first batch of the range;
    /// handed to `DB::from_sync_result` on completion.
    pinned_nodes: Option<Vec<DB::Digest>>,

    /// Current sync target (operation range and expected root digest).
    target: Target<DB::Digest>,

    /// Maximum number of fetch requests allowed in flight at once.
    max_outstanding_requests: usize,

    /// Maximum number of operations requested per fetch.
    fetch_batch_size: NonZeroU64,

    /// Batch size passed through to `DB::from_sync_result` on completion.
    apply_batch_size: usize,

    /// Journal that operations are appended to as they become contiguous.
    journal: DB::Journal,

    /// Resolver used to fetch operation batches (cloned per request).
    resolver: R,

    /// Hasher used to verify fetched proofs against the target root.
    hasher: StandardHasher<DB::Hasher>,

    /// Runtime context, consumed by `DB::from_sync_result` on completion.
    context: DB::Context,

    /// Database configuration, consumed by `DB::from_sync_result`.
    config: DB::Config,

    /// Target-update channel; set to `None` once the sender closes it.
    update_receiver: Option<mpsc::Receiver<Target<DB::Digest>>>,
}
152
#[cfg(test)]
impl<DB, R> Engine<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Test-only accessor for the underlying journal.
    pub(crate) fn journal(&self) -> &DB::Journal {
        &self.journal
    }
}
164
165impl<DB, R> Engine<DB, R>
166where
167 DB: Database,
168 R: Resolver<Op = DB::Op, Digest = DB::Digest>,
169 DB::Op: Encode,
170{
171 pub async fn new(config: Config<DB, R>) -> Result<Self, Error<DB, R>> {
173 if config.target.range.is_empty() || !config.target.range.end.is_valid() {
174 return Err(SyncError::Engine(EngineError::InvalidTarget {
175 lower_bound_pos: config.target.range.start,
176 upper_bound_pos: config.target.range.end,
177 }));
178 }
179
180 let journal = <DB::Journal as Journal>::new(
182 config.context.with_label("journal"),
183 config.db_config.journal_config(),
184 config.target.range.clone(),
185 )
186 .await?;
187
188 let mut engine = Self {
189 outstanding_requests: Requests::new(),
190 fetched_operations: BTreeMap::new(),
191 pinned_nodes: None,
192 target: config.target.clone(),
193 max_outstanding_requests: config.max_outstanding_requests,
194 fetch_batch_size: config.fetch_batch_size,
195 apply_batch_size: config.apply_batch_size,
196 journal,
197 resolver: config.resolver.clone(),
198 hasher: StandardHasher::<DB::Hasher>::new(),
199 context: config.context,
200 config: config.db_config,
201 update_receiver: config.update_rx,
202 };
203 engine.schedule_requests().await?;
204 Ok(engine)
205 }
206
    /// Schedules new fetch requests to fill gaps in the operation range, up
    /// to `max_outstanding_requests` requests in flight.
    async fn schedule_requests(&mut self) -> Result<(), Error<DB, R>> {
        let target_size = self.target.range.end;

        // If pinned nodes haven't been captured yet, fetch a single operation
        // at the start of the range; its proof is used to extract them (see
        // `handle_fetch_result`).
        if self.pinned_nodes.is_none() {
            let start_loc = self.target.range.start;
            let resolver = self.resolver.clone();
            self.outstanding_requests.add(
                start_loc,
                Box::pin(async move {
                    let result = resolver
                        .get_operations(target_size, start_loc, NZU64!(1))
                        .await;
                    IndexedFetchResult { start_loc, result }
                }),
            );
        }

        // Only schedule up to the configured cap of in-flight requests.
        let num_requests = self
            .max_outstanding_requests
            .saturating_sub(self.outstanding_requests.len());

        let log_size = self.journal.size().await;

        for _ in 0..num_requests {
            // Count operations already fetched (but not yet applied), keyed
            // by start location, so gap detection can skip over them.
            let operation_counts: BTreeMap<Location, u64> = self
                .fetched_operations
                .iter()
                .map(|(&start_loc, operations)| (start_loc, operations.len() as u64))
                .collect();

            // Find the next run of operations that is neither fetched nor
            // currently requested; stop scheduling when no gap remains.
            let Some(gap_range) = crate::qmdb::sync::gaps::find_next(
                Location::new(log_size)..self.target.range.end,
                &operation_counts,
                self.outstanding_requests.locations(),
                self.fetch_batch_size,
            ) else {
                break;
            };

            // NOTE(review): assumes `find_next` never returns an empty range;
            // otherwise `gap_size` would be 0 and the NonZeroU64 conversion
            // below would panic — TODO confirm.
            let gap_size = *gap_range.end.checked_sub(*gap_range.start).unwrap();
            let gap_size: NonZeroU64 = gap_size.try_into().unwrap();
            // Never request more than one batch's worth of operations.
            let batch_size = self.fetch_batch_size.min(gap_size);

            let resolver = self.resolver.clone();
            self.outstanding_requests.add(
                gap_range.start,
                Box::pin(async move {
                    let result = resolver
                        .get_operations(target_size, gap_range.start, batch_size)
                        .await;
                    IndexedFetchResult {
                        start_loc: gap_range.start,
                        result,
                    }
                }),
            );
        }

        Ok(())
    }
275
276 pub async fn reset_for_target_update(
278 mut self,
279 new_target: Target<DB::Digest>,
280 ) -> Result<Self, Error<DB, R>> {
281 self.journal.resize(new_target.range.start).await?;
282
283 Ok(Self {
284 outstanding_requests: Requests::new(),
285 fetched_operations: BTreeMap::new(),
286 pinned_nodes: None,
287 target: new_target,
288 max_outstanding_requests: self.max_outstanding_requests,
289 fetch_batch_size: self.fetch_batch_size,
290 apply_batch_size: self.apply_batch_size,
291 journal: self.journal,
292 resolver: self.resolver,
293 hasher: self.hasher,
294 context: self.context,
295 config: self.config,
296 update_receiver: self.update_receiver,
297 })
298 }
299
300 pub(crate) fn store_operations(&mut self, start_loc: Location, operations: Vec<DB::Op>) {
302 if operations.is_empty() {
303 return;
304 }
305 self.fetched_operations.insert(start_loc, operations);
306 }
307
    /// Applies all fetched operations that are contiguous with the journal's
    /// current end, consuming them from `fetched_operations`.
    pub(crate) async fn apply_operations(&mut self) -> Result<(), Error<DB, R>> {
        // Next location the journal expects.
        let mut next_loc = self.journal.size().await;

        // Drop batches that end before `next_loc`; their operations are
        // already in the journal.
        self.fetched_operations.retain(|&start_loc, operations| {
            // `store_operations` never inserts empty batches.
            assert!(!operations.is_empty());
            let end_loc = start_loc.checked_add(operations.len() as u64 - 1).unwrap();
            end_loc >= next_loc
        });

        loop {
            // Find a fetched batch whose range covers `next_loc`, if any.
            let range_start_loc =
                self.fetched_operations
                    .iter()
                    .find_map(|(range_start, range_ops)| {
                        assert!(!range_ops.is_empty());
                        let range_end =
                            range_start.checked_add(range_ops.len() as u64 - 1).unwrap();
                        if *range_start <= next_loc && next_loc <= range_end {
                            Some(*range_start)
                        } else {
                            None
                        }
                    });

            let Some(range_start_loc) = range_start_loc else {
                // No batch covers the next location; wait for more fetches.
                break;
            };

            let operations = self.fetched_operations.remove(&range_start_loc).unwrap();
            assert!(!operations.is_empty());
            // A batch may begin before `next_loc`; skip the overlap that is
            // already applied and append only the remainder.
            let skip_count = (next_loc - *range_start_loc) as usize;
            let operations_count = operations.len() - skip_count;
            let remaining_operations = operations.into_iter().skip(skip_count);
            next_loc += operations_count as u64;
            self.apply_operations_batch(remaining_operations).await?;
        }

        Ok(())
    }
359
    /// Appends each operation in `operations` to the journal, in order.
    async fn apply_operations_batch<I>(&mut self, operations: I) -> Result<(), Error<DB, R>>
    where
        I: IntoIterator<Item = DB::Op>,
    {
        for op in operations {
            self.journal.append(op).await?;
        }
        Ok(())
    }
372
373 pub async fn is_complete(&self) -> Result<bool, Error<DB, R>> {
375 let journal_size = self.journal.size().await;
376 let target_journal_size = self.target.range.end;
377
378 if journal_size >= target_journal_size {
380 if journal_size > target_journal_size {
381 return Err(SyncError::Engine(EngineError::InvalidState));
383 }
384 return Ok(true);
385 }
386
387 Ok(false)
388 }
389
    /// Processes a completed fetch: validates the batch size, verifies the
    /// proof against the target root, captures pinned nodes when possible,
    /// and stores valid operations for later application.
    fn handle_fetch_result(
        &mut self,
        fetch_result: IndexedFetchResult<DB::Op, DB::Digest, R::Error>,
    ) -> Result<(), Error<DB, R>> {
        // The request has completed either way; stop tracking it.
        self.outstanding_requests.remove(fetch_result.start_loc);

        let start_loc = fetch_result.start_loc;
        let FetchResult {
            proof,
            operations,
            success_tx,
        } = fetch_result.result.map_err(SyncError::Resolver)?;

        // Reject empty or oversized batches: report failure to the resolver
        // and keep syncing (the uncovered gap will be re-requested later).
        let operations_len = operations.len() as u64;
        if operations_len == 0 || operations_len > self.fetch_batch_size.get() {
            let _ = success_tx.send(false);
            return Ok(());
        }

        // Verify the batch's proof against the target root.
        let proof_valid = qmdb::verify_proof(
            &mut self.hasher,
            &proof,
            start_loc,
            &operations,
            &self.target.root,
        );

        // Report proof validity; ignore send failure (receiver may be gone).
        let _ = success_tx.send(proof_valid);

        if proof_valid {
            // The proof for the first batch of the range also yields the
            // pinned nodes. Extraction failure is non-fatal: a later batch
            // starting at the range start can retry.
            if self.pinned_nodes.is_none() && start_loc == self.target.range.start {
                if let Ok(nodes) =
                    crate::qmdb::extract_pinned_nodes(&proof, start_loc, operations_len)
                {
                    self.pinned_nodes = Some(nodes);
                }
            }

            self.store_operations(start_loc, operations);
        }

        Ok(())
    }
449
    /// Performs one step of the sync state machine.
    ///
    /// If the journal has reached the target size, builds the database,
    /// checks its root against the target, and returns
    /// [`NextStep::Complete`]. Otherwise waits for one event (target update,
    /// completed fetch, or update-channel closure), handles it, and returns
    /// [`NextStep::Continue`] with the engine to keep stepping.
    pub(crate) async fn step(mut self) -> Result<NextStep<Self, DB>, Error<DB, R>> {
        if self.is_complete().await? {
            // Flush the journal before handing it to the database.
            self.journal.sync().await?;

            let database = DB::from_sync_result(
                self.context,
                self.config,
                self.journal,
                self.pinned_nodes,
                self.target.range.clone(),
                self.apply_batch_size,
            )
            .await?;

            // Final integrity check: the built database must reproduce the
            // target root.
            let got_root = database.root();
            let expected_root = self.target.root;
            if got_root != expected_root {
                return Err(SyncError::Engine(EngineError::RootMismatch {
                    expected: expected_root,
                    actual: got_root,
                }));
            }

            return Ok(NextStep::Complete(database));
        }

        // No event means no update channel and no outstanding requests:
        // the sync can never make progress again.
        let event = wait_for_event(&mut self.update_receiver, &mut self.outstanding_requests)
            .await
            .ok_or(SyncError::Engine(EngineError::SyncStalled))?;

        match event {
            Event::TargetUpdate(new_target) => {
                // Reject targets that are not valid successors of the
                // current one.
                validate_update(&self.target, &new_target)?;

                // Discard state tied to the old target and re-schedule
                // fetches for the new range.
                let mut updated_self = self.reset_for_target_update(new_target).await?;

                updated_self.schedule_requests().await?;

                return Ok(NextStep::Continue(updated_self));
            }
            Event::UpdateChannelClosed => {
                // Stop polling the closed channel on subsequent steps.
                self.update_receiver = None;
            }
            Event::BatchReceived(fetch_result) => {
                self.handle_fetch_result(fetch_result)?;

                // Refill the request pipeline, then apply whatever is now
                // contiguous with the journal.
                self.schedule_requests().await?;

                self.apply_operations().await?;
            }
        }

        Ok(NextStep::Continue(self))
    }
523
524 pub async fn sync(mut self) -> Result<DB, Error<DB, R>> {
529 loop {
531 match self.step().await? {
532 NextStep::Continue(new_engine) => self = new_engine,
533 NextStep::Complete(database) => return Ok(database),
534 }
535 }
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mmr::Proof;
    use commonware_cryptography::sha256;
    use commonware_utils::channel::oneshot;

    /// Verifies the add/remove/locations bookkeeping of `Requests`.
    #[test]
    fn test_outstanding_requests() {
        let mut requests: Requests<i32, sha256::Digest, ()> = Requests::new();
        assert_eq!(requests.len(), 0);

        // A dummy fetch future; this test never polls it.
        let fut = Box::pin(async {
            IndexedFetchResult {
                start_loc: Location::new(0),
                result: Ok(FetchResult {
                    proof: Proof {
                        leaves: Location::new(0),
                        digests: vec![],
                    },
                    operations: vec![],
                    success_tx: oneshot::channel().0,
                }),
            }
        });
        requests.add(Location::new(10), fut);
        assert_eq!(requests.len(), 1);
        assert!(requests.locations().contains(&Location::new(10)));

        requests.remove(Location::new(10));
        assert!(!requests.locations().contains(&Location::new(10)));
    }
}