use alloy_primitives::{Address, B256};
use nectar_primitives::SwarmAddress;
use crate::{StampError, StampIndex, calculate_bucket};
/// Identifier of a [`Batch`]; an alias for [`B256`].
pub type BatchId = B256;
/// Parameters describing a batch, built with [`BatchParams::new`] and optionally
/// adjusted with the [`BatchParams::immutable`] builder method.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BatchParams {
    /// Address that owns the batch.
    pub owner: Address,
    /// Depth of the batch. Presumably, together with `bucket_depth`, it determines the
    /// per-bucket capacity in the same way as [`Batch::bucket_upper_bound`] (TODO confirm).
    pub depth: u8,
    /// Bucket depth of the batch. Presumably this means `2^bucket_depth` buckets, as in
    /// [`Batch::bucket_count`] (TODO confirm).
    pub bucket_depth: u8,
    /// Whether the batch is immutable; [`BatchParams::new`] initialises this to `false`.
    pub immutable: bool,
    /// Amount associated with the batch (units are not defined in this file; TODO confirm).
    pub amount: u128,
}
impl BatchParams {
    /// Creates batch parameters with the `immutable` flag cleared.
    ///
    /// Use [`BatchParams::immutable`] afterwards to mark the batch as immutable.
    pub const fn new(owner: Address, depth: u8, bucket_depth: u8, amount: u128) -> Self {
        // New parameters always start out mutable.
        let immutable = false;
        Self {
            owner,
            depth,
            bucket_depth,
            immutable,
            amount,
        }
    }

    /// Builder-style setter that returns these parameters with `immutable` replaced.
    #[must_use]
    pub const fn immutable(self, immutable: bool) -> Self {
        let mut params = self;
        params.immutable = immutable;
        params
    }
}
/// A batch together with its sizing configuration.
///
/// Fields are private and exposed through the accessor methods on `Batch`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Batch {
    /// Identifier of the batch.
    id: BatchId,
    /// Batch value; `is_expired` reports expiry once a total amount reaches this value.
    value: u128,
    /// Start block; `is_usable` compares the current block against `start + threshold`.
    start: u64,
    /// Address that owns the batch.
    owner: Address,
    /// Depth of the batch; `bucket_upper_bound` is derived from `depth - bucket_depth`.
    depth: u8,
    /// Bucket depth; `bucket_count` is derived from it and it selects an address's bucket.
    bucket_depth: u8,
    /// Whether the batch is immutable.
    immutable: bool,
}
impl Batch {
    /// Creates a batch from its raw components.
    ///
    /// The arguments are stored exactly as given. No relationship between `depth` and
    /// `bucket_depth` is enforced here.
    #[inline]
    pub const fn new(
        id: BatchId,
        value: u128,
        start: u64,
        owner: Address,
        depth: u8,
        bucket_depth: u8,
        immutable: bool,
    ) -> Self {
        Self {
            id,
            value,
            start,
            owner,
            depth,
            bucket_depth,
            immutable,
        }
    }

    /// Returns the batch identifier.
    #[inline]
    pub const fn id(&self) -> BatchId {
        self.id
    }

    /// Returns the batch value, which [`Batch::is_expired`] compares against a total amount.
    #[inline]
    pub const fn value(&self) -> u128 {
        self.value
    }

    /// Returns the start block, which [`Batch::is_usable`] offsets by a threshold.
    #[inline]
    pub const fn start(&self) -> u64 {
        self.start
    }

    /// Returns the batch owner.
    #[inline]
    pub const fn owner(&self) -> Address {
        self.owner
    }

    /// Returns the batch depth.
    #[inline]
    pub const fn depth(&self) -> u8 {
        self.depth
    }

    /// Returns the bucket depth used to assign addresses to buckets.
    #[inline]
    pub const fn bucket_depth(&self) -> u8 {
        self.bucket_depth
    }

    /// Returns whether the batch is immutable.
    #[inline]
    pub const fn immutable(&self) -> bool {
        self.immutable
    }

    /// Returns the number of slots in each bucket, `2^(depth - bucket_depth)`.
    ///
    /// Valid per-bucket indices are `0..bucket_upper_bound()` (see [`Batch::validate_index`]).
    /// Out-of-range configurations saturate instead of overflowing:
    /// - if `bucket_depth > depth`, the exponent is clamped to `0` and the result is `1`;
    /// - if the exponent is 32 or more, the result is `u32::MAX`.
    ///
    /// A raw `u8` subtraction and `<<` would panic in debug builds. In release builds they
    /// would silently wrap into a wrong bound.
    #[inline]
    pub const fn bucket_upper_bound(&self) -> u32 {
        Self::saturating_pow2(self.depth.saturating_sub(self.bucket_depth))
    }

    /// Returns the number of buckets, `2^bucket_depth`.
    ///
    /// Saturates at `u32::MAX` when `bucket_depth >= 32` instead of overflowing the shift.
    #[inline]
    pub const fn bucket_count(&self) -> u32 {
        Self::saturating_pow2(self.bucket_depth)
    }

    /// Computes `2^exp` as a `u32`, returning `u32::MAX` when `exp >= 32`.
    #[inline]
    const fn saturating_pow2(exp: u8) -> u32 {
        // `checked_shl` rejects shift amounts >= the bit width. The plain `<<` operator
        // would instead panic (debug) or mask the shift amount (release).
        match 1u32.checked_shl(exp as u32) {
            Some(value) => value,
            None => u32::MAX,
        }
    }

    /// Replaces the batch value.
    #[inline]
    pub const fn set_value(&mut self, value: u128) {
        self.value = value;
    }

    /// Replaces the batch depth.
    ///
    /// `bucket_depth` is left unchanged and is not re-validated against the new depth.
    #[inline]
    pub const fn set_depth(&mut self, depth: u8) {
        self.depth = depth;
    }

    /// Returns `true` once `total_amount` has reached the batch value
    /// (`value <= total_amount`).
    #[inline]
    pub const fn is_expired(&self, total_amount: u128) -> bool {
        self.value <= total_amount
    }

    /// Returns `true` when `current_block` is at least `start + threshold`.
    ///
    /// The addition saturates, so a very large `threshold` cannot overflow.
    #[inline]
    pub const fn is_usable(&self, current_block: u64, threshold: u64) -> bool {
        current_block >= self.start.saturating_add(threshold)
    }

    /// Checks that `index` lies within this batch's bucket and per-bucket slot ranges.
    ///
    /// # Errors
    ///
    /// Returns [`StampError::InvalidIndex`] in either of these cases:
    /// - `index.bucket()` is not below [`Batch::bucket_count`];
    /// - `index.index()` is not below [`Batch::bucket_upper_bound`].
    pub const fn validate_index(&self, index: &StampIndex) -> Result<(), StampError> {
        if index.bucket() >= self.bucket_count() {
            return Err(StampError::InvalidIndex);
        }
        if index.index() >= self.bucket_upper_bound() {
            return Err(StampError::InvalidIndex);
        }
        Ok(())
    }

    /// Computes the bucket `address` falls into, using this batch's `bucket_depth`.
    #[inline]
    pub fn bucket_for_address(&self, address: &SwarmAddress) -> u32 {
        calculate_bucket(address, self.bucket_depth)
    }

    /// Checks that `index` refers to the bucket that `address` maps to.
    ///
    /// # Errors
    ///
    /// Returns [`StampError::BucketMismatch`] if `index.bucket()` differs from
    /// [`Batch::bucket_for_address`] for `address`.
    pub fn validate_bucket(
        &self,
        index: &StampIndex,
        address: &SwarmAddress,
    ) -> Result<(), StampError> {
        let expected_bucket = self.bucket_for_address(address);
        if index.bucket() != expected_bucket {
            return Err(StampError::BucketMismatch);
        }
        Ok(())
    }
}
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for BatchParams {
    /// Generates parameters with `1 <= bucket_depth <= depth <= 32` and `bucket_depth <= 31`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let depth: u8 = u.int_in_range(1..=32)?;
        // Cap at 31 so that a `2^bucket_depth` bucket count still fits in a `u32`.
        let bucket_depth: u8 = u.int_in_range(1..=depth.min(31))?;
        // `Unstructured::arbitrary` is an inherent method. Unlike `Address::arbitrary`, it
        // works without the `Arbitrary` trait being imported into this module.
        Ok(Self {
            owner: u.arbitrary()?,
            depth,
            bucket_depth,
            immutable: u.arbitrary()?,
            amount: u.arbitrary()?,
        })
    }
}
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Batch {
    /// Generates a batch with `1 <= bucket_depth <= depth <= 32` and `bucket_depth <= 31`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let depth: u8 = u.int_in_range(1..=32)?;
        // Cap at 31 so that `2^bucket_depth` buckets fit in a `u32`. With
        // depth == bucket_depth == 32 the shift in `bucket_count` would otherwise overflow.
        let bucket_depth: u8 = u.int_in_range(1..=depth.min(31))?;
        // `Unstructured::arbitrary` is an inherent method, so no trait import is needed.
        // Arguments are evaluated left to right: id, value, start, owner, ..., immutable.
        Ok(Self::new(
            u.arbitrary()?,
            u.arbitrary()?,
            u.arbitrary()?,
            u.arbitrary()?,
            depth,
            bucket_depth,
            u.arbitrary()?,
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Depth shared by every fixture batch in this module.
    const DEPTH: u8 = 18;
    /// Bucket depth shared by every fixture batch in this module.
    const BUCKET_DEPTH: u8 = 16;

    /// Builds a mutable batch with zero id and owner at the fixture depths.
    fn fixture(value: u128, start: u64) -> Batch {
        Batch::new(B256::ZERO, value, start, Address::ZERO, DEPTH, BUCKET_DEPTH, false)
    }

    #[test]
    fn test_batch_creation() {
        let batch = fixture(1000, 100);
        assert_eq!(batch.id(), B256::ZERO);
        assert_eq!(batch.value(), 1000);
        assert_eq!(batch.start(), 100);
        assert_eq!(batch.owner(), Address::ZERO);
        assert_eq!(batch.depth(), 18);
        assert_eq!(batch.bucket_depth(), 16);
        assert!(!batch.immutable());
    }

    #[test]
    fn test_bucket_calculations() {
        let batch = fixture(0, 0);
        // 2^(18 - 16) slots per bucket and 2^16 buckets.
        assert_eq!(batch.bucket_upper_bound(), 4);
        assert_eq!(batch.bucket_count(), 65536);
    }

    #[test]
    fn test_batch_expiry() {
        let batch = fixture(1000, 0);
        for (total, expired) in [(999, false), (1000, true), (1001, true)] {
            assert_eq!(batch.is_expired(total), expired, "total_amount = {total}");
        }
    }

    #[test]
    fn test_batch_usability() {
        let batch = fixture(1000, 100);
        // Usable from block start + threshold = 110 onwards.
        for (block, usable) in [(100, false), (109, false), (110, true), (111, true)] {
            assert_eq!(batch.is_usable(block, 10), usable, "current_block = {block}");
        }
    }

    #[test]
    fn test_batch_params_builder() {
        let BatchParams {
            owner,
            depth,
            bucket_depth,
            immutable,
            amount,
        } = BatchParams::new(Address::ZERO, 20, 16, 1000).immutable(true);
        assert_eq!(owner, Address::ZERO);
        assert_eq!(depth, 20);
        assert_eq!(bucket_depth, 16);
        assert_eq!(amount, 1000);
        assert!(immutable);
    }
}