1use crate::{
3 mmr::{Location, StandardHasher},
4 qmdb::{
5 self,
6 sync::{
7 database::Config as _,
8 error::EngineError,
9 requests::Requests,
10 resolver::{FetchResult, Resolver},
11 target::validate_update,
12 Database, Error as SyncError, Journal, Target,
13 },
14 },
15};
16use commonware_codec::Encode;
17use commonware_cryptography::Digest;
18use commonware_macros::select;
19use commonware_runtime::Metrics as _;
20use commonware_utils::{channel::mpsc, NZU64};
21use futures::{future::Either, StreamExt};
22use std::{collections::BTreeMap, fmt::Debug, num::NonZeroU64};
23
/// Convenience alias binding the crate's sync error type to a database `DB` and resolver `R`.
type Error<DB, R> = qmdb::sync::Error<<R as Resolver>::Error, <DB as Database>::Digest>;
26
/// Outcome of a single [`Engine::step`]: either keep stepping or hand back the finished database.
#[derive(Debug)]
pub(crate) enum NextStep<C, D> {
    /// Sync is not finished; continue stepping with the (possibly updated) engine `C`.
    Continue(C),
    /// Sync completed; `D` is the fully built database.
    Complete(D),
}
35
/// Events the engine can observe while waiting in `wait_for_event`.
#[derive(Debug)]
enum Event<Op, D: Digest, E> {
    /// A new sync target arrived on the update channel.
    TargetUpdate(Target<D>),
    /// An outstanding fetch request completed (successfully or with an error).
    BatchReceived(IndexedFetchResult<Op, D, E>),
    /// The target-update channel closed; no further target updates will arrive.
    UpdateChannelClosed,
}
46
/// A fetch outcome tagged with the location its request started at, so it can be
/// matched back to the corresponding entry in the outstanding-request set.
#[derive(Debug)]
pub(super) struct IndexedFetchResult<Op, D: Digest, E> {
    /// Location of the first operation the request asked for.
    pub start_loc: Location,
    /// The fetch result: operations plus proof on success, resolver error otherwise.
    pub result: Result<FetchResult<Op, D>, E>,
}
55
/// Waits for the next engine event: a target update, a completed fetch, or the
/// update channel closing.
///
/// Returns `None` when the outstanding-request stream yields nothing (no requests
/// in flight), which the caller (`Engine::step`) reports as a stall.
async fn wait_for_event<Op, D: Digest, E>(
    update_receiver: &mut Option<mpsc::Receiver<Target<D>>>,
    outstanding_requests: &mut Requests<Op, D, E>,
) -> Option<Event<Op, D, E>> {
    // If there is no update channel, substitute a future that never resolves so the
    // select below can only complete on a fetch result.
    let target_update_fut = update_receiver.as_mut().map_or_else(
        || Either::Right(futures::future::pending()),
        |update_rx| Either::Left(update_rx.recv()),
    );

    select! {
        // `recv()` returning `None` means the sender side was dropped.
        target = target_update_fut => target.map_or_else(
            || Some(Event::UpdateChannelClosed),
            |target| Some(Event::TargetUpdate(target))
        ),
        result = outstanding_requests.futures_mut().next() => {
            // `None` here means the futures stream is empty (nothing outstanding).
            result.map(|fetch_result| Event::BatchReceived(fetch_result))
        },
    }
}
77
/// Configuration for constructing a sync [`Engine`].
pub struct Config<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Runtime context used to open the journal and finalize the database.
    pub context: DB::Context,
    /// Resolver used to fetch operation batches and their proofs.
    pub resolver: R,
    /// Initial sync target (root digest plus operation location range).
    pub target: Target<DB::Digest>,
    /// Maximum number of fetch requests kept in flight at once.
    pub max_outstanding_requests: usize,
    /// Maximum number of operations requested per fetch.
    pub fetch_batch_size: NonZeroU64,
    /// Batch size forwarded to `DB::from_sync_result` when finalizing.
    pub apply_batch_size: usize,
    /// Database configuration; also supplies the journal configuration.
    pub db_config: DB::Config,
    /// Optional channel on which updated sync targets arrive.
    pub update_rx: Option<mpsc::Receiver<Target<DB::Digest>>>,
}
/// State machine that syncs a [`Database`] to a target root by fetching verified
/// operation batches from a [`Resolver`] and replaying them into a journal.
pub(crate) struct Engine<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Fetch requests currently in flight, keyed by start location.
    outstanding_requests: Requests<DB::Op, DB::Digest, R::Error>,

    /// Verified batches awaiting application to the journal, keyed by start location.
    fetched_operations: BTreeMap<Location, Vec<DB::Op>>,

    /// Pinned MMR nodes extracted from the proof of the batch at the target's start
    /// location; `None` until such a proof has been processed.
    pinned_nodes: Option<Vec<DB::Digest>>,

    /// Current sync target (root digest and operation range).
    target: Target<DB::Digest>,

    /// Cap on simultaneously outstanding fetch requests.
    max_outstanding_requests: usize,

    /// Maximum operations per fetch request.
    fetch_batch_size: NonZeroU64,

    /// Batch size forwarded to `DB::from_sync_result` during finalization.
    apply_batch_size: usize,

    /// Journal that receives applied operations.
    journal: DB::Journal,

    /// Source of operation batches and proofs.
    resolver: R,

    /// Hasher used to verify batch proofs against the target root.
    hasher: StandardHasher<DB::Hasher>,

    /// Runtime context, consumed when the database is finalized.
    context: DB::Context,

    /// Database configuration, consumed when the database is finalized.
    config: DB::Config,

    /// Optional channel delivering target updates; set to `None` once it closes.
    update_receiver: Option<mpsc::Receiver<Target<DB::Digest>>>,
}
148
#[cfg(test)]
impl<DB, R> Engine<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Test-only accessor exposing the engine's journal for inspection.
    pub(crate) fn journal(&self) -> &DB::Journal {
        &self.journal
    }
}
160
impl<DB, R> Engine<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Creates a new sync engine from `config`, opens the journal over the target
    /// range, and schedules the initial round of fetch requests.
    ///
    /// # Errors
    ///
    /// Returns `EngineError::InvalidTarget` when the target range is empty or its
    /// end location is not valid.
    pub async fn new(config: Config<DB, R>) -> Result<Self, Error<DB, R>> {
        if config.target.range.is_empty() || !config.target.range.end.is_valid() {
            return Err(SyncError::Engine(EngineError::InvalidTarget {
                lower_bound_pos: config.target.range.start,
                upper_bound_pos: config.target.range.end,
            }));
        }

        // Open the journal covering exactly the target range.
        let journal = <DB::Journal as Journal>::new(
            config.context.with_label("journal"),
            config.db_config.journal_config(),
            config.target.range.clone(),
        )
        .await?;

        let mut engine = Self {
            outstanding_requests: Requests::new(),
            fetched_operations: BTreeMap::new(),
            pinned_nodes: None,
            target: config.target.clone(),
            max_outstanding_requests: config.max_outstanding_requests,
            fetch_batch_size: config.fetch_batch_size,
            apply_batch_size: config.apply_batch_size,
            journal,
            resolver: config.resolver.clone(),
            hasher: StandardHasher::<DB::Hasher>::new(),
            context: config.context,
            config: config.db_config,
            update_receiver: config.update_rx,
        };
        // Kick off the first fetches before handing the engine back.
        engine.schedule_requests().await?;
        Ok(engine)
    }

    /// Tops up the pool of outstanding fetch requests.
    ///
    /// If pinned nodes have not been captured yet, a one-operation request at the
    /// target's start location is issued so its proof can supply them (see
    /// `handle_fetch_result`). Then, up to `max_outstanding_requests` total, new
    /// requests are created for uncovered gaps between the journal tip and the
    /// target's end.
    async fn schedule_requests(&mut self) -> Result<(), Error<DB, R>> {
        let target_size = self.target.range.end;

        // The proof for the first operation in the range is needed to extract the
        // pinned nodes, so fetch a single-operation batch at the start location.
        if self.pinned_nodes.is_none() {
            let start_loc = self.target.range.start;
            let resolver = self.resolver.clone();
            self.outstanding_requests.add(
                start_loc,
                Box::pin(async move {
                    let result = resolver
                        .get_operations(target_size, start_loc, NZU64!(1))
                        .await;
                    IndexedFetchResult { start_loc, result }
                }),
            );
        }

        // How many more requests may be issued without exceeding the cap.
        let num_requests = self
            .max_outstanding_requests
            .saturating_sub(self.outstanding_requests.len());

        let log_size = self.journal.size().await;

        for _ in 0..num_requests {
            // Snapshot how many operations each fetched-but-unapplied batch holds,
            // keyed by start location, so gap-finding can skip covered spans.
            let operation_counts: BTreeMap<Location, u64> = self
                .fetched_operations
                .iter()
                .map(|(&start_loc, operations)| (start_loc, operations.len() as u64))
                .collect();

            // Find the next span of operations that is neither fetched nor in flight.
            let Some(gap_range) = crate::qmdb::sync::gaps::find_next(
                Location::new_unchecked(log_size)..self.target.range.end,
                &operation_counts,
                self.outstanding_requests.locations(),
                self.fetch_batch_size,
            ) else {
                // Nothing left uncovered; stop issuing requests.
                break;
            };

            // A gap's end is at or past its start, so the subtraction cannot
            // underflow. NOTE(review): the NonZeroU64 conversion assumes
            // `find_next` never returns an empty gap — confirm its contract.
            let gap_size = *gap_range.end.checked_sub(*gap_range.start).unwrap();
            let gap_size: NonZeroU64 = gap_size.try_into().unwrap();
            let batch_size = self.fetch_batch_size.min(gap_size);

            let resolver = self.resolver.clone();
            self.outstanding_requests.add(
                gap_range.start,
                Box::pin(async move {
                    let result = resolver
                        .get_operations(target_size, gap_range.start, batch_size)
                        .await;
                    IndexedFetchResult {
                        start_loc: gap_range.start,
                        result,
                    }
                }),
            );
        }

        Ok(())
    }

    /// Consumes the engine and rebuilds it for `new_target`.
    ///
    /// Resizes the journal to the new target's start location and discards all
    /// fetch state (outstanding requests, fetched operations, pinned nodes);
    /// configuration and handles are carried over unchanged.
    pub async fn reset_for_target_update(
        mut self,
        new_target: Target<DB::Digest>,
    ) -> Result<Self, Error<DB, R>> {
        // NOTE(review): `resize` semantics come from the Journal impl; presumably
        // it re-bounds the journal to begin at the new lower bound — confirm.
        self.journal.resize(new_target.range.start).await?;

        Ok(Self {
            outstanding_requests: Requests::new(),
            fetched_operations: BTreeMap::new(),
            pinned_nodes: None,
            target: new_target,
            max_outstanding_requests: self.max_outstanding_requests,
            fetch_batch_size: self.fetch_batch_size,
            apply_batch_size: self.apply_batch_size,
            journal: self.journal,
            resolver: self.resolver,
            hasher: self.hasher,
            context: self.context,
            config: self.config,
            update_receiver: self.update_receiver,
        })
    }

    /// Records a verified batch of fetched operations, keyed by its start location.
    /// Any existing batch at the same start location is replaced.
    pub fn store_operations(&mut self, start_loc: Location, operations: Vec<DB::Op>) {
        self.fetched_operations.insert(start_loc, operations);
    }

    /// Applies contiguously-available fetched operations to the journal.
    ///
    /// First drops batches that end below the journal tip, then repeatedly finds
    /// the batch covering the next needed location and appends its not-yet-applied
    /// suffix, until no stored batch covers the tip.
    pub async fn apply_operations(&mut self) -> Result<(), Error<DB, R>> {
        // Location of the next operation the journal needs.
        let mut next_loc = self.journal.size().await;

        // Discard batches whose last operation is already in the journal. Batches
        // are never empty (`handle_fetch_result` rejects empty ones), so the
        // `len - 1` cannot underflow.
        self.fetched_operations.retain(|&start_loc, operations| {
            let end_loc = start_loc.checked_add(operations.len() as u64 - 1).unwrap();
            end_loc >= next_loc
        });

        loop {
            // Find a stored batch whose location range contains `next_loc`.
            let range_start_loc =
                self.fetched_operations
                    .iter()
                    .find_map(|(range_start, range_ops)| {
                        let range_end =
                            range_start.checked_add(range_ops.len() as u64 - 1).unwrap();
                        if *range_start <= next_loc && next_loc <= range_end {
                            Some(*range_start)
                        } else {
                            None
                        }
                    });

            let Some(range_start_loc) = range_start_loc else {
                // No batch covers the tip; wait for more fetches to arrive.
                break;
            };

            // Apply only the portion of the batch at or beyond `next_loc`.
            let operations = self.fetched_operations.remove(&range_start_loc).unwrap();
            let skip_count = (next_loc - *range_start_loc) as usize;
            let operations_count = operations.len() - skip_count;
            let remaining_operations = operations.into_iter().skip(skip_count);
            next_loc += operations_count as u64;
            self.apply_operations_batch(remaining_operations).await?;
        }

        Ok(())
    }

    /// Appends each operation in `operations` to the journal, in order.
    async fn apply_operations_batch<I>(&mut self, operations: I) -> Result<(), Error<DB, R>>
    where
        I: IntoIterator<Item = DB::Op>,
    {
        for op in operations {
            self.journal.append(op).await?;
        }
        Ok(())
    }

    /// Reports whether the journal has reached the target size.
    ///
    /// # Errors
    ///
    /// Returns `EngineError::InvalidState` if the journal has grown past the
    /// target, which indicates a broken invariant.
    pub async fn is_complete(&self) -> Result<bool, Error<DB, R>> {
        let journal_size = self.journal.size().await;
        let target_journal_size = self.target.range.end;

        if journal_size >= target_journal_size {
            if journal_size > target_journal_size {
                // The engine should never append beyond the target range.
                return Err(SyncError::Engine(EngineError::InvalidState));
            }
            return Ok(true);
        }

        Ok(false)
    }

    /// Processes one completed fetch.
    ///
    /// Removes the request from the outstanding set, validates the batch's proof
    /// against the target root, reports validity back to the resolver through
    /// `success_tx`, captures pinned nodes from the batch at the target's start
    /// (if still needed), and stores valid operations for later application.
    ///
    /// # Errors
    ///
    /// Propagates resolver errors as `SyncError::Resolver`.
    fn handle_fetch_result(
        &mut self,
        fetch_result: IndexedFetchResult<DB::Op, DB::Digest, R::Error>,
    ) -> Result<(), Error<DB, R>> {
        self.outstanding_requests.remove(fetch_result.start_loc);

        let start_loc = fetch_result.start_loc;
        let FetchResult {
            proof,
            operations,
            success_tx,
        } = fetch_result.result.map_err(SyncError::Resolver)?;

        // Reject empty or oversized batches; signal the resolver that the fetch was
        // unusable. The send result is ignored: the receiver may have been dropped.
        let operations_len = operations.len() as u64;
        if operations_len == 0 || operations_len > self.fetch_batch_size.get() {
            let _ = success_tx.send(false);
            return Ok(());
        }

        // Verify the batch's inclusion proof against the target root.
        let proof_valid = qmdb::verify_proof(
            &mut self.hasher,
            &proof,
            start_loc,
            &operations,
            &self.target.root,
        );

        let _ = success_tx.send(proof_valid);

        if proof_valid {
            // The proof for the batch at the target's start location also carries
            // the pinned nodes; capture them the first time such a batch arrives.
            if self.pinned_nodes.is_none() && start_loc == self.target.range.start {
                if let Ok(nodes) =
                    crate::qmdb::extract_pinned_nodes(&proof, start_loc, operations_len)
                {
                    self.pinned_nodes = Some(nodes);
                }
            }

            self.store_operations(start_loc, operations);
        }

        Ok(())
    }

    /// Runs one step of the sync state machine.
    ///
    /// If the journal already covers the target range, finalizes the database,
    /// checks its root against the target, and returns [`NextStep::Complete`].
    /// Otherwise waits for the next event (target update, fetch completion, or
    /// update-channel closure), reacts to it, and returns [`NextStep::Continue`].
    pub(crate) async fn step(mut self) -> Result<NextStep<Self, DB>, Error<DB, R>> {
        if self.is_complete().await? {
            // Persist the journal before handing it to the database builder.
            self.journal.sync().await?;

            let database = DB::from_sync_result(
                self.context,
                self.config,
                self.journal,
                self.pinned_nodes,
                self.target.range.clone(),
                self.apply_batch_size,
            )
            .await?;

            // Final integrity check: the rebuilt database must hash to the target root.
            let got_root = database.root();
            let expected_root = self.target.root;
            if got_root != expected_root {
                return Err(SyncError::Engine(EngineError::RootMismatch {
                    expected: expected_root,
                    actual: got_root,
                }));
            }

            return Ok(NextStep::Complete(database));
        }

        // No event means neither channel nor requests can make progress: a stall.
        let event = wait_for_event(&mut self.update_receiver, &mut self.outstanding_requests)
            .await
            .ok_or(SyncError::Engine(EngineError::SyncStalled))?;

        match event {
            Event::TargetUpdate(new_target) => {
                // Reject updates that are invalid relative to the current target
                // (exact semantics defined by `validate_update`).
                validate_update(&self.target, &new_target)?;

                // Rebuild the engine for the new target, dropping stale fetch state.
                let mut updated_self = self.reset_for_target_update(new_target).await?;

                updated_self.schedule_requests().await?;

                return Ok(NextStep::Continue(updated_self));
            }
            Event::UpdateChannelClosed => {
                // Stop polling the closed channel on subsequent steps.
                self.update_receiver = None;
            }
            Event::BatchReceived(fetch_result) => {
                self.handle_fetch_result(fetch_result)?;

                // Refill the request pool, then flush any now-contiguous operations.
                self.schedule_requests().await?;

                self.apply_operations().await?;
            }
        }

        Ok(NextStep::Continue(self))
    }

    /// Drives [`Engine::step`] to completion and returns the synced database.
    pub async fn sync(mut self) -> Result<DB, Error<DB, R>> {
        loop {
            match self.step().await? {
                NextStep::Continue(new_engine) => self = new_engine,
                NextStep::Complete(database) => return Ok(database),
            }
        }
    }
}
528
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mmr::Proof;
    use commonware_cryptography::sha256;
    use commonware_utils::channel::oneshot;

    /// Adding and removing a request updates both the request count and the set of
    /// tracked locations.
    #[test]
    fn test_outstanding_requests() {
        let mut requests: Requests<i32, sha256::Digest, ()> = Requests::new();
        assert_eq!(requests.len(), 0);

        // A dummy fetch future; it is never polled, only registered and removed.
        let pending_fetch = Box::pin(async {
            let proof = Proof {
                leaves: Location::new_unchecked(0),
                digests: vec![],
            };
            let fetch = FetchResult {
                proof,
                operations: vec![],
                success_tx: oneshot::channel().0,
            };
            IndexedFetchResult {
                start_loc: Location::new_unchecked(0),
                result: Ok(fetch),
            }
        });

        let loc = Location::new_unchecked(10);
        requests.add(loc, pending_fetch);
        assert_eq!(requests.len(), 1);
        assert!(requests.locations().contains(&loc));

        requests.remove(loc);
        assert!(!requests.locations().contains(&loc));
    }
}