use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use nectar_postage::{Batch, BatchId, StampDigest, StampError, StampIndex, calculate_bucket};
use nectar_primitives::SwarmAddress;
#[cfg(feature = "parallel")]
use {
crate::error::SigningError,
alloy_primitives::B256,
alloy_signer::Signature,
nectar_postage::{Stamp, current_timestamp},
};
/// Number of shards requested by [`ShardedIssuer::new`]. It is clamped to the
/// number of buckets by [`ShardedIssuer::with_shard_count`].
const DEFAULT_SHARD_COUNT: usize = 16;
/// A contiguous range of buckets whose per-bucket issue counters are stored together.
///
/// The shard covers buckets `base_bucket..base_bucket + indices.len()`. The counter at
/// local offset `i` holds how many positions have been handed out for bucket
/// `base_bucket + i`.
#[derive(Debug)]
struct BucketShard {
    /// First (global) bucket number covered by this shard.
    base_bucket: u32,
    /// Next free position for each bucket, indexed by local bucket offset.
    indices: Vec<AtomicU32>,
}

impl BucketShard {
    /// Creates a shard covering `bucket_count` buckets starting at `base_bucket`,
    /// with every counter initialised to zero.
    fn new(base_bucket: u32, bucket_count: u32) -> Self {
        let indices = (0..bucket_count).map(|_| AtomicU32::new(0)).collect();
        Self {
            base_bucket,
            indices,
        }
    }

    /// Translates a global bucket number into an offset into `indices`.
    ///
    /// Only meaningful for buckets inside this shard's range. Callers are expected
    /// to route each bucket to the shard that owns it.
    #[inline]
    const fn local_index(&self, bucket: u32) -> usize {
        (bucket - self.base_bucket) as usize
    }

    /// Reserves the next free position in `bucket`.
    ///
    /// Returns the reserved zero-based position (always `< bucket_capacity`). Returns
    /// `None` when the bucket already holds `bucket_capacity` positions.
    ///
    /// The counter is advanced with a compare-and-swap loop that only commits when
    /// there is room. This has two consequences:
    /// - The counter never exceeds `bucket_capacity`, so concurrent readers of
    ///   [`BucketShard::utilization`] cannot observe an over-full bucket.
    /// - Rejected calls leave the counter completely untouched.
    ///
    /// Uniqueness of the handed-out positions follows from all read-modify-write
    /// operations on one atomic being totally ordered, so `Relaxed` is sufficient.
    #[inline]
    fn allocate(&self, bucket: u32, bucket_capacity: u32) -> Option<u32> {
        self.indices[self.local_index(bucket)]
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                if current < bucket_capacity {
                    Some(current + 1)
                } else {
                    None
                }
            })
            // On success `fetch_update` yields the previous value, which is exactly
            // the position that was just reserved.
            .ok()
    }

    /// Returns how many positions of `bucket` have been reserved so far.
    #[inline]
    fn utilization(&self, bucket: u32) -> u32 {
        self.indices[self.local_index(bucket)].load(Ordering::Relaxed)
    }
}
/// Concurrent stamp-index allocator for a single postage batch.
///
/// The `2^bucket_depth` buckets are split into `shards.len()` contiguous ranges. Each
/// range is a [`BucketShard`] holding one atomic counter per bucket, so stamps for
/// different buckets never update the same counter.
#[derive(Debug)]
pub struct ShardedIssuer {
    /// Batch the issued stamps belong to.
    batch_id: BatchId,
    /// Batch depth; together with `bucket_depth` it determines the bucket capacity.
    depth: u8,
    /// Number of address bits used to select a bucket.
    bucket_depth: u8,
    /// Positions available per bucket: `2^(depth - bucket_depth)`.
    bucket_capacity: u32,
    /// Shard `i` owns buckets `[i * per_shard, (i + 1) * per_shard)`.
    shards: Vec<BucketShard>,
    /// `shard_count - 1`, applied after shifting to pick the owning shard.
    shard_mask: u32,
    /// Number of low bucket bits that select a bucket *within* a shard.
    shard_shift: u32,
    /// Highest per-bucket utilisation observed so far.
    max_utilization: AtomicU32,
    /// Total number of successfully prepared stamps.
    stamps_issued: AtomicU64,
}

impl ShardedIssuer {
    /// Creates an issuer using [`DEFAULT_SHARD_COUNT`] shards.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`ShardedIssuer::with_shard_count`].
    pub fn new(batch_id: BatchId, depth: u8, bucket_depth: u8) -> Self {
        Self::with_shard_count(batch_id, depth, bucket_depth, DEFAULT_SHARD_COUNT)
    }

    /// Creates an issuer whose buckets are spread over `shard_count` shards.
    ///
    /// `shard_count` is clamped to the number of buckets (`2^bucket_depth`), so every
    /// shard covers at least one bucket.
    ///
    /// # Panics
    ///
    /// - if `shard_count` is not a power of two (including zero);
    /// - if `bucket_depth > depth`;
    /// - if `bucket_depth` or `depth - bucket_depth` is 32 or more, because the
    ///   bucket count and per-bucket capacity are stored as `u32`.
    pub fn with_shard_count(
        batch_id: BatchId,
        depth: u8,
        bucket_depth: u8,
        shard_count: usize,
    ) -> Self {
        assert!(
            shard_count.is_power_of_two(),
            "shard_count must be a power of 2"
        );
        assert!(
            bucket_depth <= depth,
            "bucket_depth ({bucket_depth}) must not exceed depth ({depth})"
        );
        assert!(
            u32::from(bucket_depth) < u32::BITS,
            "bucket_depth ({bucket_depth}) must be less than 32"
        );
        assert!(
            u32::from(depth - bucket_depth) < u32::BITS,
            "depth - bucket_depth ({}) must be less than 32",
            depth - bucket_depth
        );

        let total_buckets = 1u32 << bucket_depth;
        let shard_count = shard_count.min(total_buckets as usize);
        let buckets_per_shard = total_buckets / shard_count as u32;
        let bucket_capacity = 1u32 << (depth - bucket_depth);

        // Both the bucket count and the shard count are powers of two. The owning
        // shard of a bucket is therefore just its top `shard_bits` bits.
        let shard_bits = (shard_count as u32).trailing_zeros();
        let shard_shift = u32::from(bucket_depth) - shard_bits;
        let shard_mask = (shard_count - 1) as u32;

        let shards: Vec<_> = (0..shard_count)
            .map(|i| BucketShard::new(i as u32 * buckets_per_shard, buckets_per_shard))
            .collect();

        Self {
            batch_id,
            depth,
            bucket_depth,
            bucket_capacity,
            shards,
            shard_mask,
            shard_shift,
            max_utilization: AtomicU32::new(0),
            stamps_issued: AtomicU64::new(0),
        }
    }

    /// Creates an issuer for `batch` using its id, depth and bucket depth.
    pub fn from_batch(batch: &Batch) -> Self {
        Self::new(batch.id(), batch.depth(), batch.bucket_depth())
    }

    /// Index of the shard that owns `bucket`.
    #[inline]
    const fn shard_index(&self, bucket: u32) -> usize {
        ((bucket >> self.shard_shift) & self.shard_mask) as usize
    }

    /// Reserves the next stamp index for `address` and returns the digest to be signed.
    ///
    /// # Errors
    ///
    /// Returns [`StampError::BucketFull`] when the address's bucket already holds
    /// `bucket_capacity` stamps. In that case the issuer's counters are left unchanged.
    pub fn prepare_stamp(
        &self,
        address: &SwarmAddress,
        timestamp: u64,
    ) -> Result<StampDigest, StampError> {
        let bucket = calculate_bucket(address, self.bucket_depth);
        let shard = &self.shards[self.shard_index(bucket)];
        let position =
            shard
                .allocate(bucket, self.bucket_capacity)
                .ok_or(StampError::BucketFull {
                    bucket,
                    capacity: self.bucket_capacity,
                })?;

        self.stamps_issued.fetch_add(1, Ordering::Relaxed);
        // `position` is zero-based, so the bucket now holds `position + 1` stamps.
        self.max_utilization
            .fetch_max(position + 1, Ordering::Relaxed);

        let index = StampIndex::new(bucket, position);
        Ok(StampDigest::new(*address, self.batch_id, index, timestamp))
    }

    /// Batch the issued stamps belong to.
    pub const fn batch_id(&self) -> BatchId {
        self.batch_id
    }

    /// Depth of the batch.
    pub const fn batch_depth(&self) -> u8 {
        self.depth
    }

    /// Number of address bits used to select a bucket.
    pub const fn bucket_depth(&self) -> u8 {
        self.bucket_depth
    }

    /// Highest number of stamps issued into any single bucket so far.
    pub fn max_bucket_utilization(&self) -> u32 {
        self.max_utilization.load(Ordering::Relaxed)
    }

    /// Number of stamps issued so far into `bucket`.
    ///
    /// # Panics
    ///
    /// Panics if `bucket >= 2^bucket_depth`.
    pub fn bucket_utilization(&self, bucket: u32) -> u32 {
        self.shards[self.shard_index(bucket)].utilization(bucket)
    }

    /// Total number of stamps successfully prepared by this issuer.
    pub fn stamps_issued(&self) -> u64 {
        self.stamps_issued.load(Ordering::Relaxed)
    }

    /// Number of stamps each bucket can hold: `2^(depth - bucket_depth)`.
    pub const fn bucket_capacity(&self) -> u32 {
        self.bucket_capacity
    }

    /// Number of shards actually in use after clamping to the bucket count.
    pub const fn shard_count(&self) -> usize {
        self.shards.len()
    }
}
#[cfg(feature = "parallel")]
/// Outcome of issuing and signing a stamp for one address in [`sign_stamps_parallel`].
#[derive(Debug)]
pub struct StampResult {
    /// The address the stamp was requested for.
    pub address: SwarmAddress,
    /// The signed stamp, or the error raised while preparing or signing it.
    pub result: Result<Stamp, SigningError>,
}
#[cfg(feature = "parallel")]
/// Issues and signs one stamp per entry of `addresses`, using rayon's thread pool.
///
/// Every address produces exactly one [`StampResult`]. A failure for one address is
/// recorded in its `result` and does not stop the others. The results are collected
/// from an indexed parallel iterator, so they appear in the same order as `addresses`.
pub fn sign_stamps_parallel<S, E>(
    issuer: &ShardedIssuer,
    signer: &S,
    addresses: &[SwarmAddress],
) -> Vec<StampResult>
where
    S: Fn(&B256) -> Result<Signature, E> + Sync,
    E: Into<SigningError>,
{
    use rayon::prelude::*;

    addresses
        .par_iter()
        .map(|&address| StampResult {
            address,
            result: sign_stamp_internal(issuer, signer, &address),
        })
        .collect()
}
#[cfg(feature = "parallel")]
/// Reserves a stamp index for `address`, stamped with the current time, and signs
/// its prehash with `signer`.
///
/// The index is reserved before `signer` runs. If signing fails, the reserved
/// position is not handed back to its bucket.
fn sign_stamp_internal<S, E>(
    issuer: &ShardedIssuer,
    signer: &S,
    address: &SwarmAddress,
) -> Result<Stamp, SigningError>
where
    S: Fn(&B256) -> Result<Signature, E>,
    E: Into<SigningError>,
{
    let digest = issuer.prepare_stamp(address, current_timestamp())?;
    let signature = signer(&digest.to_prehash()).map_err(Into::<SigningError>::into)?;
    Ok(stamp_from_signature(&digest, signature))
}
#[cfg(feature = "parallel")]
/// Builds a [`Stamp`] from a prepared digest's batch id, index and timestamp
/// plus the signature over that digest.
#[inline]
const fn stamp_from_signature(digest: &StampDigest, sig: Signature) -> Stamp {
    let (batch, stamp_index, issued_at) = (digest.batch_id, digest.index, digest.timestamp);
    Stamp::with_index(batch, stamp_index, issued_at, sig)
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::B256;

    #[test]
    fn test_sharded_issuer_basic() {
        let issuer = ShardedIssuer::new(B256::ZERO, 20, 16);
        assert_eq!(issuer.batch_id(), B256::ZERO);
        assert_eq!(issuer.batch_depth(), 20);
        assert_eq!(issuer.bucket_depth(), 16);
        // 2^(20 - 16) slots per bucket.
        assert_eq!(issuer.bucket_capacity(), 16);
        assert_eq!(issuer.shard_count(), DEFAULT_SHARD_COUNT);
    }

    #[test]
    fn test_shard_count_clamped_to_bucket_count() {
        // bucket_depth 2 yields only 4 buckets, so at most 4 shards can exist.
        let issuer = ShardedIssuer::with_shard_count(B256::ZERO, 4, 2, 64);
        assert_eq!(issuer.shard_count(), 4);
        assert_eq!(issuer.bucket_capacity(), 4);
        for bucket in 0..4 {
            assert_eq!(issuer.bucket_utilization(bucket), 0);
        }
    }

    #[test]
    fn test_sharded_issuer_prepare_stamp() {
        let issuer = ShardedIssuer::new(B256::ZERO, 20, 16);
        let address = SwarmAddress::from(B256::random());
        let digest = issuer.prepare_stamp(&address, 12345).unwrap();
        assert_eq!(digest.batch_id, B256::ZERO);
        assert_eq!(digest.timestamp, 12345);
        assert_eq!(issuer.stamps_issued(), 1);
    }

    #[test]
    fn test_sharded_issuer_bucket_full() {
        // depth == bucket_depth leaves exactly one slot per bucket.
        let issuer = ShardedIssuer::new(B256::ZERO, 16, 16);
        assert_eq!(issuer.bucket_capacity(), 1);
        let address = SwarmAddress::from(B256::random());
        let bucket = calculate_bucket(&address, issuer.bucket_depth());

        issuer.prepare_stamp(&address, 0).unwrap();
        assert!(matches!(
            issuer.prepare_stamp(&address, 0),
            Err(StampError::BucketFull { bucket: b, capacity: 1, .. }) if b == bucket
        ));

        // The rejected attempt must not be counted anywhere.
        assert_eq!(issuer.bucket_utilization(bucket), 1);
        assert_eq!(issuer.max_bucket_utilization(), 1);
        assert_eq!(issuer.stamps_issued(), 1);
    }

    #[test]
    fn test_sharded_issuer_concurrent_access() {
        use std::sync::Arc;
        use std::thread;
        let issuer = Arc::new(ShardedIssuer::new(B256::ZERO, 24, 16));
        let num_threads = 8;
        let stamps_per_thread = 1000;
        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let issuer = Arc::clone(&issuer);
                thread::spawn(move || {
                    for _ in 0..stamps_per_thread {
                        let addr = SwarmAddress::from(B256::random());
                        issuer.prepare_stamp(&addr, 0).unwrap();
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(
            issuer.stamps_issued(),
            (num_threads * stamps_per_thread) as u64
        );
        assert!(issuer.max_bucket_utilization() <= issuer.bucket_capacity());
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_signing() {
        use crate::error::SigningError;
        use alloy_signer::SignerSync;
        use alloy_signer_local::PrivateKeySigner;
        let issuer = ShardedIssuer::new(B256::ZERO, 24, 16);
        let signer = PrivateKeySigner::random();
        let addresses: Vec<_> = (0..100)
            .map(|_| SwarmAddress::from(B256::random()))
            .collect();
        let sign_fn = |prehash: &B256| -> Result<Signature, SigningError> {
            Ok(signer
                .sign_message_sync(prehash.as_slice())
                .map_err(alloy_signer::Error::other)?)
        };
        let results = sign_stamps_parallel(&issuer, &sign_fn, &addresses);
        assert_eq!(results.len(), 100);
        for result in &results {
            assert!(result.result.is_ok());
        }
        assert_eq!(issuer.stamps_issued(), 100);
    }
}