1use crate::{
4 adb::{
5 self,
6 sync::{
7 requests::Requests,
8 resolver::{FetchResult, Resolver},
9 target::validate_update,
10 Database, Error, Journal, Target,
11 },
12 },
13 mmr::hasher,
14};
15use commonware_codec::Encode;
16use commonware_cryptography::Digest;
17use commonware_macros::select;
18use commonware_utils::NZU64;
19use futures::{channel::mpsc, future::Either, StreamExt};
20use std::{collections::BTreeMap, fmt::Debug, num::NonZeroU64};
21
/// Error type produced by the sync [`Engine`], combining the database's error,
/// the resolver's error, and the database's digest type.
type EngineError<DB, R> =
    Error<<DB as Database>::Error, <R as Resolver>::Error, <DB as Database>::Digest>;
25
/// Outcome of a single [`Engine::step`]: either continue stepping with the
/// (possibly updated) engine, or sync has finished and the database is ready.
#[derive(Debug)]
pub(crate) enum NextStep<C, D> {
    /// Sync is not finished; continue stepping with this engine value.
    Continue(C),
    /// Sync completed; contains the fully-built database.
    Complete(D),
}
34
/// Events the engine can observe while waiting in [`wait_for_event`].
#[derive(Debug)]
enum Event<Op, D: Digest, E> {
    /// A new sync target arrived on the update channel.
    TargetUpdate(Target<D>),
    /// An outstanding fetch request resolved (successfully or not).
    BatchReceived(IndexedFetchResult<Op, D, E>),
    /// The target-update channel was closed by the sender.
    UpdateChannelClosed,
}
45
/// A fetch result tagged with the start location it was requested for, so the
/// engine can match it against its outstanding-request bookkeeping.
#[derive(Debug)]
pub(super) struct IndexedFetchResult<Op, D: Digest, E> {
    /// Location of the first operation that was requested.
    pub start_loc: u64,
    /// The resolver's response (proof plus operations) or its error.
    pub result: Result<FetchResult<Op, D>, E>,
}
54
55async fn wait_for_event<Op, D: Digest, E>(
58 update_receiver: &mut Option<mpsc::Receiver<Target<D>>>,
59 outstanding_requests: &mut Requests<Op, D, E>,
60) -> Option<Event<Op, D, E>> {
61 let target_update_fut = match update_receiver {
62 Some(update_rx) => Either::Left(update_rx.next()),
63 None => Either::Right(futures::future::pending()),
64 };
65
66 select! {
67 target = target_update_fut => {
68 match target {
69 Some(target) => Some(Event::TargetUpdate(target)),
70 None => Some(Event::UpdateChannelClosed),
71 }
72 },
73 result = outstanding_requests.futures_mut().next() => {
74 result.map(|fetch_result| Event::BatchReceived(fetch_result))
75 },
76 }
77}
78
/// Configuration for constructing a sync [`Engine`].
pub struct Config<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Runtime context used to create/resize the journal and build the database.
    pub context: DB::Context,
    /// Source of operations and proofs to sync from.
    pub resolver: R,
    /// Initial sync target (root digest plus operation bounds).
    pub target: Target<DB::Digest>,
    /// Maximum number of fetch requests allowed in flight at once.
    pub max_outstanding_requests: usize,
    /// Maximum number of operations to request per fetch.
    pub fetch_batch_size: NonZeroU64,
    /// Batch size used when building the database from the synced journal.
    pub apply_batch_size: usize,
    /// Database-specific configuration.
    pub db_config: DB::Config,
    /// Optional channel on which updated sync targets may arrive.
    pub update_rx: Option<mpsc::Receiver<Target<DB::Digest>>>,
}
/// State machine that drives a database sync: fetches batches of proven
/// operations from a [`Resolver`], applies them to a journal, and finalizes
/// the database once the target range is fully populated.
pub(crate) struct Engine<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Fetch requests currently in flight, keyed by start location.
    outstanding_requests: Requests<DB::Op, DB::Digest, R::Error>,

    /// Verified-but-unapplied operation batches, keyed by start location.
    fetched_operations: BTreeMap<u64, Vec<DB::Op>>,

    /// Nodes extracted from the first verified proof at the target's lower
    /// bound (via `extract_pinned_nodes`); `None` until extracted.
    pinned_nodes: Option<Vec<DB::Digest>>,

    /// Current sync target (root digest plus operation bounds).
    target: Target<DB::Digest>,

    /// Maximum number of fetch requests allowed in flight at once.
    max_outstanding_requests: usize,

    /// Maximum number of operations to request per fetch.
    fetch_batch_size: NonZeroU64,

    /// Batch size forwarded to `DB::from_sync_result` when finalizing.
    apply_batch_size: usize,

    /// Journal receiving applied operations.
    journal: DB::Journal,

    /// Source of operations and proofs.
    resolver: R,

    /// Hasher used for proof verification.
    hasher: crate::mmr::hasher::Standard<DB::Hasher>,

    /// Runtime context used to (re)create the journal and build the database.
    context: DB::Context,

    /// Database-specific configuration.
    config: DB::Config,

    /// Optional channel on which updated sync targets may arrive; set to
    /// `None` once the sender closes it.
    update_receiver: Option<mpsc::Receiver<Target<DB::Digest>>>,
}
149
// Test-only accessors.
#[cfg(test)]
impl<DB, R> Engine<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Returns a reference to the engine's journal (test inspection only).
    pub(crate) fn journal(&self) -> &DB::Journal {
        &self.journal
    }
}
161
impl<DB, R> Engine<DB, R>
where
    DB: Database,
    R: Resolver<Op = DB::Op, Digest = DB::Digest>,
    DB::Op: Encode,
{
    /// Creates a new sync engine for `config.target` and schedules the
    /// initial batch of fetch requests.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidTarget`] if the target's lower bound exceeds
    /// its upper bound, or a database error if the journal cannot be created.
    pub async fn new(config: Config<DB, R>) -> Result<Self, EngineError<DB, R>> {
        if config.target.lower_bound_ops > config.target.upper_bound_ops {
            return Err(Error::InvalidTarget {
                lower_bound_pos: config.target.lower_bound_ops,
                upper_bound_pos: config.target.upper_bound_ops,
            });
        }

        // Create a journal sized for the target's bounds.
        let journal = DB::create_journal(
            config.context.clone(),
            &config.db_config,
            config.target.lower_bound_ops,
            config.target.upper_bound_ops,
        )
        .await
        .map_err(Error::database)?;

        let mut engine = Self {
            outstanding_requests: Requests::new(),
            fetched_operations: BTreeMap::new(),
            pinned_nodes: None,
            target: config.target.clone(),
            max_outstanding_requests: config.max_outstanding_requests,
            fetch_batch_size: config.fetch_batch_size,
            apply_batch_size: config.apply_batch_size,
            journal,
            resolver: config.resolver.clone(),
            hasher: hasher::Standard::<DB::Hasher>::new(),
            context: config.context,
            config: config.db_config,
            update_receiver: config.update_rx,
        };
        engine.schedule_requests().await?;
        Ok(engine)
    }

    /// Schedules fetch requests to fill gaps in the operation range, up to
    /// `max_outstanding_requests` in flight.
    ///
    /// If pinned nodes have not yet been extracted, also schedules a
    /// single-operation fetch at the target's lower bound so a proof covering
    /// that location is obtained (see `handle_fetch_result`).
    async fn schedule_requests(&mut self) -> Result<(), EngineError<DB, R>> {
        // Total number of operations in the target range [0, upper] (the
        // upper bound is inclusive).
        let target_size = self.target.upper_bound_ops + 1;

        if self.pinned_nodes.is_none() {
            // Request a single operation at the lower bound; its proof is the
            // input `extract_pinned_nodes` needs.
            let start_loc = self.target.lower_bound_ops;
            let resolver = self.resolver.clone();
            self.outstanding_requests.add(
                start_loc,
                Box::pin(async move {
                    let result = resolver
                        .get_operations(target_size, start_loc, NZU64!(1))
                        .await;
                    IndexedFetchResult { start_loc, result }
                }),
            );
        }

        // Only schedule as many new requests as the in-flight budget allows.
        let num_requests = self
            .max_outstanding_requests
            .saturating_sub(self.outstanding_requests.len());

        let log_size = self.journal.size().await.map_err(Error::database)?;

        for _ in 0..num_requests {
            // Snapshot the sizes of already-fetched batches so gap-finding
            // can skip ranges we already have in hand.
            let operation_counts: BTreeMap<u64, u64> = self
                .fetched_operations
                .iter()
                .map(|(&start_loc, operations)| (start_loc, operations.len() as u64))
                .collect();

            // Find the next uncovered gap; stop scheduling when none remain.
            let Some((start_loc, end_loc)) = crate::adb::sync::gaps::find_next(
                log_size,
                self.target.upper_bound_ops,
                &operation_counts,
                self.outstanding_requests.locations(),
                self.fetch_batch_size.get(),
            ) else {
                break;
            };

            // Never request more than the gap holds or the batch-size cap.
            let gap_size = NZU64!(end_loc - start_loc + 1);
            let batch_size = self.fetch_batch_size.min(gap_size);

            let resolver = self.resolver.clone();
            self.outstanding_requests.add(
                start_loc,
                Box::pin(async move {
                    let result = resolver
                        .get_operations(target_size, start_loc, batch_size)
                        .await;
                    IndexedFetchResult { start_loc, result }
                }),
            );
        }

        Ok(())
    }

    /// Consumes the engine and rebuilds it for `new_target`, resizing the
    /// journal to the new bounds and discarding all fetched state, pinned
    /// nodes, and outstanding requests (their results would correspond to
    /// the old target).
    pub async fn reset_for_target_update(
        self,
        new_target: Target<DB::Digest>,
    ) -> Result<Self, EngineError<DB, R>> {
        let journal = DB::resize_journal(
            self.journal,
            self.context.clone(),
            &self.config,
            new_target.lower_bound_ops,
            new_target.upper_bound_ops,
        )
        .await
        .map_err(Error::database)?;

        Ok(Self {
            outstanding_requests: Requests::new(),
            fetched_operations: BTreeMap::new(),
            pinned_nodes: None,
            target: new_target,
            max_outstanding_requests: self.max_outstanding_requests,
            fetch_batch_size: self.fetch_batch_size,
            apply_batch_size: self.apply_batch_size,
            journal,
            resolver: self.resolver,
            hasher: self.hasher,
            context: self.context,
            config: self.config,
            update_receiver: self.update_receiver,
        })
    }

    /// Records a batch of operations fetched starting at `start_loc`,
    /// replacing any previously stored batch with the same start location.
    pub fn store_operations(&mut self, start_loc: u64, operations: Vec<DB::Op>) {
        self.fetched_operations.insert(start_loc, operations);
    }

    /// Applies contiguously-available fetched operations to the journal,
    /// starting from the journal's current size, until no fetched batch
    /// covers the next needed location.
    pub async fn apply_operations(&mut self) -> Result<(), EngineError<DB, R>> {
        // Next location the journal needs.
        let mut next_loc = self.journal.size().await.map_err(Error::database)?;

        // Drop batches that end before the journal's current size — they are
        // entirely behind what has already been applied.
        self.fetched_operations.retain(|&start_loc, operations| {
            let end_loc = start_loc + operations.len() as u64 - 1;
            end_loc >= next_loc
        });

        loop {
            // Find a stored batch whose range covers `next_loc`.
            let range_start_loc =
                self.fetched_operations
                    .iter()
                    .find_map(|(range_start, range_ops)| {
                        let range_end = range_start + range_ops.len() as u64 - 1;
                        if *range_start <= next_loc && next_loc <= range_end {
                            Some(*range_start)
                        } else {
                            None
                        }
                    });

            let Some(range_start_loc) = range_start_loc else {
                // No batch covers the next needed location; wait for fetches.
                break;
            };

            // Skip any prefix of the batch the journal already holds, then
            // append the remainder.
            let operations = self.fetched_operations.remove(&range_start_loc).unwrap();
            let skip_count = (next_loc - range_start_loc) as usize;
            let operations_count = operations.len() - skip_count;
            let remaining_operations = operations.into_iter().skip(skip_count);
            next_loc += operations_count as u64;
            self.apply_operations_batch(remaining_operations).await?;
        }

        Ok(())
    }

    /// Appends each operation in `operations` to the journal, in order.
    async fn apply_operations_batch<I>(&mut self, operations: I) -> Result<(), EngineError<DB, R>>
    where
        I: IntoIterator<Item = DB::Op>,
    {
        for op in operations {
            self.journal.append(op).await.map_err(Error::database)?;
        }
        Ok(())
    }

    /// Returns whether the journal has reached the target size.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidState`] if the journal has somehow grown past
    /// the target size (an internal invariant violation).
    pub async fn is_complete(&self) -> Result<bool, EngineError<DB, R>> {
        let journal_size = self.journal.size().await.map_err(Error::database)?;

        // The journal is complete when it holds every operation up to and
        // including the (inclusive) upper bound.
        let target_journal_size = self.target.upper_bound_ops + 1;

        if journal_size >= target_journal_size {
            if journal_size > target_journal_size {
                return Err(Error::InvalidState);
            }
            return Ok(true);
        }

        Ok(false)
    }

    /// Processes a completed fetch: removes it from the outstanding set,
    /// verifies the proof against the target root, reports the verification
    /// outcome back over `success_tx`, extracts pinned nodes if this batch
    /// starts at the target's lower bound, and stores verified operations for
    /// later application.
    fn handle_fetch_result(
        &mut self,
        fetch_result: IndexedFetchResult<DB::Op, DB::Digest, R::Error>,
    ) -> Result<(), EngineError<DB, R>> {
        self.outstanding_requests.remove(fetch_result.start_loc);

        let start_loc = fetch_result.start_loc;
        let FetchResult {
            proof,
            operations,
            success_tx,
        } = fetch_result.result.map_err(Error::Resolver)?;

        // Reject empty or oversized batches without attempting verification.
        let operations_len = operations.len() as u64;
        if operations_len == 0 || operations_len > self.fetch_batch_size.get() {
            let _ = success_tx.send(false);
            return Ok(());
        }

        let proof_valid = adb::verify_proof(
            &mut self.hasher,
            &proof,
            start_loc,
            &operations,
            &self.target.root,
        );

        // Report the verification outcome; ignore send failure (the receiver
        // may have been dropped).
        let _ = success_tx.send(proof_valid);

        if proof_valid {
            // The first verified proof at the target's lower bound carries
            // the pinned nodes needed to finalize the database.
            if self.pinned_nodes.is_none() && start_loc == self.target.lower_bound_ops {
                if let Ok(nodes) =
                    crate::adb::extract_pinned_nodes(&proof, start_loc, operations_len)
                {
                    self.pinned_nodes = Some(nodes);
                }
            }

            self.store_operations(start_loc, operations);
        }

        Ok(())
    }

    /// Runs one iteration of the sync loop.
    ///
    /// If the journal has reached the target size, finalizes the database and
    /// checks its root against the target. Otherwise waits for the next event
    /// (target update, fetch result, or update-channel closure) and reacts.
    ///
    /// # Errors
    ///
    /// Returns [`Error::SyncStalled`] if there is nothing left to wait on, or
    /// [`Error::RootMismatch`] if the finalized database's root differs from
    /// the target's.
    pub(crate) async fn step(mut self) -> Result<NextStep<Self, DB>, EngineError<DB, R>> {
        if self.is_complete().await? {
            // Build the database from the fully-populated journal.
            let database = DB::from_sync_result(
                self.context,
                self.config,
                self.journal,
                self.pinned_nodes,
                self.target.lower_bound_ops,
                self.target.upper_bound_ops,
                self.apply_batch_size,
            )
            .await
            .map_err(Error::database)?;

            // Final integrity check: the synced database must reproduce the
            // target root.
            let got_root = database.root();
            let expected_root = self.target.root;
            if got_root != expected_root {
                return Err(Error::RootMismatch {
                    expected: expected_root,
                    actual: got_root,
                });
            }

            return Ok(NextStep::Complete(database));
        }

        // `None` here means there is nothing to wait on, so sync can never
        // make progress.
        let event = wait_for_event(&mut self.update_receiver, &mut self.outstanding_requests)
            .await
            .ok_or(Error::SyncStalled)?;

        match event {
            Event::TargetUpdate(new_target) => {
                // Reject invalid updates before discarding current state.
                validate_update(&self.target, &new_target)?;

                let mut updated_self = self.reset_for_target_update(new_target).await?;

                updated_self.schedule_requests().await?;

                return Ok(NextStep::Continue(updated_self));
            }
            Event::UpdateChannelClosed => {
                // Stop polling the closed channel; sync continues with the
                // current target.
                self.update_receiver = None;
            }
            Event::BatchReceived(fetch_result) => {
                self.handle_fetch_result(fetch_result)?;

                // Refill the request pipeline, then flush any batches that
                // are now contiguous with the journal.
                self.schedule_requests().await?;

                self.apply_operations().await?;
            }
        }

        Ok(NextStep::Continue(self))
    }

    /// Drives [`Self::step`] to completion, returning the synced database.
    pub async fn sync(mut self) -> Result<DB, EngineError<DB, R>> {
        loop {
            match self.step().await? {
                NextStep::Continue(new_engine) => self = new_engine,
                NextStep::Complete(database) => return Ok(database),
            }
        }
    }
}
537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mmr::verification::Proof;
    use commonware_cryptography::sha256;
    use futures::channel::oneshot;

    /// Exercises `Requests` bookkeeping: adding a future tracks its location,
    /// and removing it clears the location again.
    #[test]
    fn test_outstanding_requests() {
        let mut requests: Requests<i32, sha256::Digest, ()> = Requests::new();
        assert_eq!(requests.len(), 0);

        // Dummy fetch future; never polled by this test.
        let fut = Box::pin(async {
            IndexedFetchResult {
                start_loc: 0,
                result: Ok(FetchResult {
                    proof: Proof {
                        size: 0,
                        digests: vec![],
                    },
                    operations: vec![],
                    success_tx: oneshot::channel().0,
                }),
            }
        });
        requests.add(10, fut);
        assert_eq!(requests.len(), 1);
        assert!(requests.locations().contains(&10));

        requests.remove(10);
        assert!(!requests.locations().contains(&10));
    }
}
572}