#[cfg(feature = "arbitrary")]
use crate::merkle::mmr::Family;
#[cfg(feature = "arbitrary")]
use crate::merkle::Family as _;
use crate::{
merkle::mmr::Location,
qmdb::sync::{self, error::EngineError},
};
use commonware_codec::{EncodeSize, Error as CodecError, Read, ReadExt as _, Write};
use commonware_cryptography::Digest;
use commonware_runtime::{Buf, BufMut};
use commonware_utils::range::NonEmptyRange;
/// A sync target: a root digest paired with a non-empty range of locations.
///
/// The `Read` implementation for this type additionally rejects ranges whose
/// start or end fails `is_valid`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Target<D: Digest> {
/// Root digest associated with this target (presumably the merkle root the
/// sync should converge on; confirm against the sync engine).
pub root: D,
/// Non-empty range of locations (`start < end`) covered by this target.
pub range: NonEmptyRange<Location>,
}
impl<D: Digest> Write for Target<D> {
    /// Encodes the target as its root digest followed by its location range.
    ///
    /// The `Read` implementation consumes the fields in this same order.
    fn write(&self, buf: &mut impl BufMut) {
        let Self { root, range } = self;
        root.write(buf);
        range.write(buf);
    }
}
impl<D: Digest> EncodeSize for Target<D> {
    /// Returns the number of bytes `write` emits: the encoded size of the root
    /// plus the encoded size of the range.
    fn encode_size(&self) -> usize {
        let root_len = self.root.encode_size();
        let range_len = self.range.encode_size();
        root_len + range_len
    }
}
impl<D: Digest> Read for Target<D> {
    type Cfg = ();

    /// Decodes a target as a root digest followed by a non-empty location range.
    ///
    /// # Errors
    ///
    /// Propagates any error from decoding the digest or the range, and returns
    /// `CodecError::Invalid` when either bound of the range fails `is_valid`.
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
        // Field order mirrors `write`: root first, then range.
        let root = D::read(buf)?;
        let range = NonEmptyRange::<Location>::read(buf)?;

        // A structurally well-formed range can still carry bounds that the
        // location type itself considers invalid, so both ends are checked.
        let bounds_valid = range.start().is_valid() && range.end().is_valid();
        if bounds_valid {
            Ok(Self { root, range })
        } else {
            Err(CodecError::Invalid(
                "storage::qmdb::sync::Target",
                "range bounds out of valid range",
            ))
        }
    }
}
#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Target<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Generates an arbitrary target whose range always has `start < end` and
    /// whose end never exceeds `Family::MAX_LEAVES`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Draw order (root, start, end) determines how the unstructured input
        // is consumed, so it is kept fixed.
        let root = u.arbitrary()?;

        let ceiling = *Family::MAX_LEAVES;
        // `start` stops one short of the ceiling so that `start + 1..=ceiling`
        // is never an empty interval.
        let start = u.int_in_range(0..=ceiling - 1)?;
        let end = u.int_in_range(start + 1..=ceiling)?;

        let range = commonware_utils::non_empty_range!(Location::new(start), Location::new(end));
        Ok(Self { root, range })
    }
}
pub fn validate_update<U, D>(
old_target: &Target<D>,
new_target: &Target<D>,
) -> Result<(), sync::Error<U, D>>
where
U: std::error::Error + Send + 'static,
D: Digest,
{
if !new_target.range.end().is_valid() {
return Err(sync::Error::Engine(EngineError::InvalidTarget {
lower_bound_pos: new_target.range.start(),
upper_bound_pos: new_target.range.end(),
}));
}
if new_target.range.start() < old_target.range.start()
|| new_target.range.end() <= old_target.range.end()
{
return Err(sync::Error::Engine(EngineError::SyncTargetMovedBackward {
old: old_target.clone(),
new: new_target.clone(),
}));
}
if new_target.root == old_target.root {
return Err(sync::Error::Engine(EngineError::SyncTargetRootUnchanged));
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::sha256;
    use commonware_utils::non_empty_range;
    use rstest::rstest;
    use std::io::Cursor;

    /// Builds a target with the given root covering locations `start..end`.
    fn target(root: sha256::Digest, start: u64, end: u64) -> Target<sha256::Digest> {
        Target {
            root,
            range: non_empty_range!(Location::new(start), Location::new(end)),
        }
    }

    /// A target survives a write/read round trip, and `encode_size` matches the
    /// number of bytes actually written.
    #[test]
    fn test_sync_target_serialization() {
        let target = target(sha256::Digest::from([42; 32]), 100, 500);

        let mut buffer = Vec::new();
        target.write(&mut buffer);
        assert_eq!(buffer.len(), target.encode_size());

        let mut cursor = Cursor::new(buffer);
        let deserialized = Target::read(&mut cursor).unwrap();
        assert_eq!(target, deserialized);
        assert_eq!(target.root, deserialized.root);
        assert_eq!(target.range, deserialized.range);
    }

    /// Decoding rejects both an inverted range (`start > end`) and an empty
    /// range (`start == end`).
    #[test]
    fn test_sync_target_read_invalid_bounds() {
        // Inverted range: start (100) is written before a smaller end (50).
        let mut buffer = Vec::new();
        sha256::Digest::from([42; 32]).write(&mut buffer);
        Location::new(100).write(&mut buffer);
        Location::new(50).write(&mut buffer);
        let mut cursor = Cursor::new(buffer);
        assert!(matches!(
            Target::<sha256::Digest>::read(&mut cursor),
            Err(CodecError::Invalid("Range", "start must be <= end"))
        ));

        // Empty range: start == end.
        let root = sha256::Digest::from([42; 32]);
        let mut buffer = Vec::new();
        root.write(&mut buffer);
        (Location::new(100)..Location::new(100)).write(&mut buffer);
        let mut cursor = Cursor::new(buffer);
        assert!(matches!(
            Target::<sha256::Digest>::read(&mut cursor),
            Err(CodecError::Invalid("NonEmptyRange", "start must be < end"))
        ));
    }

    type TestError = sync::Error<std::io::Error, sha256::Digest>;

    /// Table-driven check of `validate_update`: the outcome must match the
    /// expected variant, and any payload carried by the error must match too.
    #[rstest]
    #[case::valid_update(
        target(sha256::Digest::from([0; 32]), 0, 100),
        target(sha256::Digest::from([1; 32]), 50, 200),
        Ok(())
    )]
    #[case::same_start(
        target(sha256::Digest::from([0; 32]), 0, 100),
        target(sha256::Digest::from([1; 32]), 0, 200),
        Ok(())
    )]
    #[case::same_end(
        target(sha256::Digest::from([0; 32]), 0, 100),
        target(sha256::Digest::from([1; 32]), 50, 100),
        Err(TestError::Engine(EngineError::SyncTargetMovedBackward {
            old: target(sha256::Digest::from([0; 32]), 0, 100),
            new: target(sha256::Digest::from([1; 32]), 50, 100),
        }))
    )]
    #[case::moves_backward(
        target(sha256::Digest::from([0; 32]), 0, 100),
        target(sha256::Digest::from([1; 32]), 0, 50),
        Err(TestError::Engine(EngineError::SyncTargetMovedBackward {
            old: target(sha256::Digest::from([0; 32]), 0, 100),
            new: target(sha256::Digest::from([1; 32]), 0, 50),
        }))
    )]
    #[case::same_root(
        target(sha256::Digest::from([0; 32]), 0, 100),
        target(sha256::Digest::from([0; 32]), 50, 200),
        Err(TestError::Engine(EngineError::SyncTargetRootUnchanged))
    )]
    fn test_validate_update(
        #[case] old_target: Target<sha256::Digest>,
        #[case] new_target: Target<sha256::Digest>,
        #[case] expected: Result<(), TestError>,
    ) {
        let result = validate_update(&old_target, &new_target);
        match (&result, &expected) {
            (Ok(()), Ok(())) => {}
            (Ok(()), Err(expected_err)) => {
                panic!("Expected error {expected_err:?} but got success");
            }
            (Err(actual_err), Ok(())) => {
                panic!("Expected success but got error: {actual_err:?}");
            }
            (Err(actual_err), Err(expected_err)) => match (actual_err, expected_err) {
                (
                    TestError::Engine(EngineError::InvalidTarget {
                        lower_bound_pos: a_lower,
                        upper_bound_pos: a_upper,
                    }),
                    TestError::Engine(EngineError::InvalidTarget {
                        lower_bound_pos: e_lower,
                        upper_bound_pos: e_upper,
                    }),
                ) => {
                    assert_eq!(a_lower, e_lower);
                    assert_eq!(a_upper, e_upper);
                }
                (
                    TestError::Engine(EngineError::SyncTargetMovedBackward {
                        old: a_old,
                        new: a_new,
                    }),
                    TestError::Engine(EngineError::SyncTargetMovedBackward {
                        old: e_old,
                        new: e_new,
                    }),
                ) => {
                    // Compare the payload too, so a swapped or corrupted
                    // `old`/`new` in the reported error is caught.
                    assert_eq!(a_old, e_old);
                    assert_eq!(a_new, e_new);
                }
                (
                    TestError::Engine(EngineError::SyncTargetRootUnchanged),
                    TestError::Engine(EngineError::SyncTargetRootUnchanged),
                ) => {}
                _ => panic!("Error type mismatch: got {actual_err:?}, expected {expected_err:?}"),
            },
        }
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Target<sha256::Digest>>,
        }
    }
}