1use crate::{
4 mmr::{Location, StandardHasher},
5 qmdb::{
6 self,
7 sync::{
8 error::EngineError,
9 requests::Requests,
10 resolver::{FetchResult, Resolver},
11 target::validate_update,
12 Database, Error as SyncError, Journal, Target,
13 },
14 },
15};
16use commonware_codec::Encode;
17use commonware_cryptography::Digest;
18use commonware_macros::select;
19use commonware_utils::NZU64;
20use futures::{channel::mpsc, future::Either, StreamExt};
21use std::{collections::BTreeMap, fmt::Debug, num::NonZeroU64};
22
/// Error type returned by the sync engine, parameterized by the resolver's error type and the
/// database's digest type.
type Error<DB, R> = qmdb::sync::Error<<R as Resolver>::Error, <DB as Database>::Digest>;
25
/// Outcome of a single engine step: either keep going or finish.
#[derive(Debug)]
pub(crate) enum NextStep<C, D> {
    /// Sync is not finished yet; carries the value to continue with.
    Continue(C),
    /// Sync is finished; carries the final result.
    Complete(D),
}
34
/// Events the engine reacts to while waiting during a sync step.
#[derive(Debug)]
enum Event<Op, D: Digest, E> {
    /// A new sync target arrived on the update channel.
    TargetUpdate(Target<D>),
    /// One of the outstanding fetch requests resolved (successfully or not).
    BatchReceived(IndexedFetchResult<Op, D, E>),
    /// The update channel yielded `None`, i.e. no further target updates will arrive.
    UpdateChannelClosed,
}
45
/// The result of a fetch request, tagged with the location the request started at.
#[derive(Debug)]
pub(super) struct IndexedFetchResult<Op, D: Digest, E> {
    /// Location of the first operation that was requested.
    pub start_loc: Location,
    /// What the resolver returned for the request: the fetched batch, or the resolver's error.
    pub result: Result<FetchResult<Op, D>, E>,
}
54
55async fn wait_for_event<Op, D: Digest, E>(
58 update_receiver: &mut Option<mpsc::Receiver<Target<D>>>,
59 outstanding_requests: &mut Requests<Op, D, E>,
60) -> Option<Event<Op, D, E>> {
61 let target_update_fut = update_receiver.as_mut().map_or_else(
62 || Either::Right(futures::future::pending()),
63 |update_rx| Either::Left(update_rx.next()),
64 );
65
66 select! {
67 target = target_update_fut => {
68 target.map_or_else(|| Some(Event::UpdateChannelClosed), |target| Some(Event::TargetUpdate(target)))
69 },
70 result = outstanding_requests.futures_mut().next() => {
71 result.map(|fetch_result| Event::BatchReceived(fetch_result))
72 },
73 }
74}
75
/// Configuration used to construct an [`Engine`].
pub struct Config<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Context handed to the database when creating or resizing the journal and when building
    /// the final database.
    pub context: DB::Context,
    /// Source of operations and proofs.
    pub resolver: R,
    /// Initial sync target (root and operation range). Its range must not be empty.
    pub target: Target<DB::Digest>,
    /// Upper bound used when deciding how many additional fetch requests to schedule.
    pub max_outstanding_requests: usize,
    /// Maximum number of operations requested per fetch; larger responses are rejected.
    pub fetch_batch_size: NonZeroU64,
    /// Passed through to `DB::from_sync_result` once the journal is complete.
    pub apply_batch_size: usize,
    /// Database-specific configuration.
    pub db_config: DB::Config,
    /// Optional channel on which updated sync targets are delivered.
    pub update_rx: Option<mpsc::Receiver<Target<DB::Digest>>>,
}
/// State of an in-progress sync: what has been requested, what has been fetched but not yet
/// applied, and the journal being filled.
pub(crate) struct Engine<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Fetch requests that have been issued and whose results have not been handled yet.
    outstanding_requests: Requests<DB::Op, DB::Digest, R::Error>,

    /// Verified batches waiting to be appended to the journal, keyed by the location of the
    /// first operation in each batch.
    fetched_operations: BTreeMap<Location, Vec<DB::Op>>,

    /// Pinned nodes extracted from the proof of a batch starting at the target's lower bound;
    /// `None` until such a batch has been received and verified.
    pinned_nodes: Option<Vec<DB::Digest>>,

    /// The target (root and operation range) currently being synced to.
    target: Target<DB::Digest>,

    /// Upper bound used when deciding how many additional fetch requests to schedule.
    max_outstanding_requests: usize,

    /// Maximum number of operations requested per fetch; larger responses are rejected.
    fetch_batch_size: NonZeroU64,

    /// Passed through to `DB::from_sync_result` once the journal is complete.
    apply_batch_size: usize,

    /// Journal that fetched operations are appended to.
    journal: DB::Journal,

    /// Source of operations and proofs.
    resolver: R,

    /// Hasher used when verifying proofs.
    hasher: StandardHasher<DB::Hasher>,

    /// Context handed to the database when resizing the journal and building the database.
    context: DB::Context,

    /// Database-specific configuration.
    config: DB::Config,

    /// Channel on which updated sync targets arrive; set to `None` once it has closed.
    update_receiver: Option<mpsc::Receiver<Target<DB::Digest>>>,
}
146
// Test-only accessors.
#[cfg(test)]
impl<DB, R> Engine<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Borrow the journal the engine is filling, so tests can inspect it.
    pub(crate) fn journal(&self) -> &DB::Journal {
        &self.journal
    }
}
158
159impl<DB, R> Engine<DB, R>
160where
161 DB: Database,
162 R: Resolver<Op = DB::Op, Digest = DB::Digest>,
163 DB::Op: Encode,
164{
165 pub async fn new(config: Config<DB, R>) -> Result<Self, Error<DB, R>> {
167 if config.target.range.is_empty() {
168 return Err(SyncError::Engine(EngineError::InvalidTarget {
169 lower_bound_pos: config.target.range.start,
170 upper_bound_pos: config.target.range.end,
171 }));
172 }
173
174 let journal = DB::create_journal(
176 config.context.clone(),
177 &config.db_config,
178 config.target.range.clone(),
179 )
180 .await?;
181
182 let mut engine = Self {
183 outstanding_requests: Requests::new(),
184 fetched_operations: BTreeMap::new(),
185 pinned_nodes: None,
186 target: config.target.clone(),
187 max_outstanding_requests: config.max_outstanding_requests,
188 fetch_batch_size: config.fetch_batch_size,
189 apply_batch_size: config.apply_batch_size,
190 journal,
191 resolver: config.resolver.clone(),
192 hasher: StandardHasher::<DB::Hasher>::new(),
193 context: config.context,
194 config: config.db_config,
195 update_receiver: config.update_rx,
196 };
197 engine.schedule_requests().await?;
198 Ok(engine)
199 }
200
201 async fn schedule_requests(&mut self) -> Result<(), Error<DB, R>> {
203 let target_size = self.target.range.end;
204
205 if self.pinned_nodes.is_none() {
208 let start_loc = self.target.range.start;
209 let resolver = self.resolver.clone();
210 self.outstanding_requests.add(
211 start_loc,
212 Box::pin(async move {
213 let result = resolver
214 .get_operations(target_size, start_loc, NZU64!(1))
215 .await;
216 IndexedFetchResult { start_loc, result }
217 }),
218 );
219 }
220
221 let num_requests = self
223 .max_outstanding_requests
224 .saturating_sub(self.outstanding_requests.len());
225
226 let log_size = self.journal.size().await;
227
228 for _ in 0..num_requests {
229 let operation_counts: BTreeMap<Location, u64> = self
231 .fetched_operations
232 .iter()
233 .map(|(&start_loc, operations)| (start_loc, operations.len() as u64))
234 .collect();
235
236 let Some(gap_range) = crate::qmdb::sync::gaps::find_next(
238 Location::new_unchecked(log_size)..self.target.range.end,
239 &operation_counts,
240 self.outstanding_requests.locations(),
241 self.fetch_batch_size,
242 ) else {
243 break; };
245
246 let gap_size = *gap_range.end.checked_sub(*gap_range.start).unwrap();
248 let gap_size: NonZeroU64 = gap_size.try_into().unwrap();
249 let batch_size = self.fetch_batch_size.min(gap_size);
250
251 let resolver = self.resolver.clone();
253 self.outstanding_requests.add(
254 gap_range.start,
255 Box::pin(async move {
256 let result = resolver
257 .get_operations(target_size, gap_range.start, batch_size)
258 .await;
259 IndexedFetchResult {
260 start_loc: gap_range.start,
261 result,
262 }
263 }),
264 );
265 }
266
267 Ok(())
268 }
269
270 pub async fn reset_for_target_update(
272 self,
273 new_target: Target<DB::Digest>,
274 ) -> Result<Self, Error<DB, R>> {
275 let journal = DB::resize_journal(
276 self.journal,
277 self.context.clone(),
278 &self.config,
279 new_target.range.clone(),
280 )
281 .await?;
282
283 Ok(Self {
284 outstanding_requests: Requests::new(),
285 fetched_operations: BTreeMap::new(),
286 pinned_nodes: None,
287 target: new_target,
288 max_outstanding_requests: self.max_outstanding_requests,
289 fetch_batch_size: self.fetch_batch_size,
290 apply_batch_size: self.apply_batch_size,
291 journal,
292 resolver: self.resolver,
293 hasher: self.hasher,
294 context: self.context,
295 config: self.config,
296 update_receiver: self.update_receiver,
297 })
298 }
299
300 pub fn store_operations(&mut self, start_loc: Location, operations: Vec<DB::Op>) {
302 self.fetched_operations.insert(start_loc, operations);
303 }
304
305 pub async fn apply_operations(&mut self) -> Result<(), Error<DB, R>> {
311 let mut next_loc = self.journal.size().await;
312
313 self.fetched_operations.retain(|&start_loc, operations| {
316 let end_loc = start_loc.checked_add(operations.len() as u64 - 1).unwrap();
317 end_loc >= next_loc
318 });
319
320 loop {
321 let range_start_loc =
324 self.fetched_operations
325 .iter()
326 .find_map(|(range_start, range_ops)| {
327 let range_end =
328 range_start.checked_add(range_ops.len() as u64 - 1).unwrap();
329 if *range_start <= next_loc && next_loc <= range_end {
330 Some(*range_start)
331 } else {
332 None
333 }
334 });
335
336 let Some(range_start_loc) = range_start_loc else {
337 break;
339 };
340
341 let operations = self.fetched_operations.remove(&range_start_loc).unwrap();
343 let skip_count = (next_loc - *range_start_loc) as usize;
345 let operations_count = operations.len() - skip_count;
346 let remaining_operations = operations.into_iter().skip(skip_count);
347 next_loc += operations_count as u64;
348 self.apply_operations_batch(remaining_operations).await?;
349 }
350
351 Ok(())
352 }
353
354 async fn apply_operations_batch<I>(&mut self, operations: I) -> Result<(), Error<DB, R>>
356 where
357 I: IntoIterator<Item = DB::Op>,
358 {
359 for op in operations {
360 self.journal.append(op).await?;
361 }
364 Ok(())
365 }
366
367 pub async fn is_complete(&self) -> Result<bool, Error<DB, R>> {
369 let journal_size = self.journal.size().await;
370 let target_journal_size = self.target.range.end;
371
372 if journal_size >= target_journal_size {
374 if journal_size > target_journal_size {
375 return Err(SyncError::Engine(EngineError::InvalidState));
377 }
378 return Ok(true);
379 }
380
381 Ok(false)
382 }
383
384 fn handle_fetch_result(
393 &mut self,
394 fetch_result: IndexedFetchResult<DB::Op, DB::Digest, R::Error>,
395 ) -> Result<(), Error<DB, R>> {
396 self.outstanding_requests.remove(fetch_result.start_loc);
398
399 let start_loc = fetch_result.start_loc;
400 let FetchResult {
401 proof,
402 operations,
403 success_tx,
404 } = fetch_result.result.map_err(SyncError::Resolver)?;
405
406 let operations_len = operations.len() as u64;
408 if operations_len == 0 || operations_len > self.fetch_batch_size.get() {
409 let _ = success_tx.send(false);
412 return Ok(());
413 }
414
415 let proof_valid = qmdb::verify_proof(
417 &mut self.hasher,
418 &proof,
419 start_loc,
420 &operations,
421 &self.target.root,
422 );
423
424 let _ = success_tx.send(proof_valid);
426
427 if proof_valid {
428 if self.pinned_nodes.is_none() && start_loc == self.target.range.start {
430 if let Ok(nodes) =
431 crate::qmdb::extract_pinned_nodes(&proof, start_loc, operations_len)
432 {
433 self.pinned_nodes = Some(nodes);
434 }
435 }
436
437 self.store_operations(start_loc, operations);
439 }
440
441 Ok(())
442 }
443
444 pub(crate) async fn step(mut self) -> Result<NextStep<Self, DB>, Error<DB, R>> {
455 if self.is_complete().await? {
457 let database = DB::from_sync_result(
459 self.context,
460 self.config,
461 self.journal,
462 self.pinned_nodes,
463 self.target.range.clone(),
464 self.apply_batch_size,
465 )
466 .await?;
467
468 let got_root = database.root();
470 let expected_root = self.target.root;
471 if got_root != expected_root {
472 return Err(SyncError::Engine(EngineError::RootMismatch {
473 expected: expected_root,
474 actual: got_root,
475 }));
476 }
477
478 return Ok(NextStep::Complete(database));
479 }
480
481 let event = wait_for_event(&mut self.update_receiver, &mut self.outstanding_requests)
483 .await
484 .ok_or(SyncError::Engine(EngineError::SyncStalled))?;
485
486 match event {
487 Event::TargetUpdate(new_target) => {
488 validate_update(&self.target, &new_target)?;
490
491 let mut updated_self = self.reset_for_target_update(new_target).await?;
492
493 updated_self.schedule_requests().await?;
495
496 return Ok(NextStep::Continue(updated_self));
497 }
498 Event::UpdateChannelClosed => {
499 self.update_receiver = None;
500 }
501 Event::BatchReceived(fetch_result) => {
502 self.handle_fetch_result(fetch_result)?;
504
505 self.schedule_requests().await?;
507
508 self.apply_operations().await?;
510 }
511 }
512
513 Ok(NextStep::Continue(self))
514 }
515
516 pub async fn sync(mut self) -> Result<DB, Error<DB, R>> {
521 loop {
523 match self.step().await? {
524 NextStep::Continue(new_engine) => self = new_engine,
525 NextStep::Complete(database) => return Ok(database),
526 }
527 }
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mmr::{Position, Proof};
    use commonware_cryptography::sha256;
    use futures::channel::oneshot;

    /// Adding a request makes its location tracked and counted; removing it untracks the
    /// location again.
    #[test]
    fn test_outstanding_requests() {
        let loc = Location::new_unchecked(10);

        let mut tracker: Requests<i32, sha256::Digest, ()> = Requests::new();
        assert_eq!(tracker.len(), 0);

        // The future is never polled here; it only needs the right output type.
        let pending_fetch = Box::pin(async {
            IndexedFetchResult {
                start_loc: Location::new_unchecked(0),
                result: Ok(FetchResult {
                    proof: Proof {
                        size: Position::new(0),
                        digests: vec![],
                    },
                    operations: vec![],
                    success_tx: oneshot::channel().0,
                }),
            }
        });

        // Add the request and check that it is tracked.
        tracker.add(loc, pending_fetch);
        assert_eq!(tracker.len(), 1);
        assert!(tracker.locations().contains(&loc));

        // Remove the request and check that its location is gone.
        tracker.remove(loc);
        assert!(!tracker.locations().contains(&loc));
    }
}