use alloc::vec::Vec;
use web_time::{SystemTime, UNIX_EPOCH};
use alloy_primitives::Address;
use alloy_signer::SignerSync;
use bytes::Bytes;
use nectar_postage::{Batch, BatchId, StampIndex};
use nectar_primitives::SwarmAddress;
use thiserror::Error;
use crate::codec::RootInfo;
use crate::seal::{SealError, SealedChunk, seal_plan};
use crate::snapshot::{PublishedSequence, Snapshot};
use crate::{UsageError, usage_chunk_address};
/// Read side of the chunk transport, used by [`BatchStamper`] to recover
/// previously published usage state and to read the live root before a flush.
///
/// The `auto_impl` attribute also provides this trait for `&T`, `Arc<T>` and
/// `Box<T>` when `T` implements it. The returned future carries no `Send`
/// bound, so single-threaded transports (e.g. ones holding an `Rc`) qualify.
#[auto_impl::auto_impl(&, Arc, Box)]
pub trait SnapshotSource {
/// Transport error; [`BatchStamper`] wraps it in [`ClientError::Source`].
type Error: core::error::Error + Send + Sync + 'static;
/// Fetch the raw bytes stored at `address`.
///
/// Return `Ok(None)` when nothing is stored there. [`BatchStamper`] treats
/// `None` as "not published", and treats any `Err` as a hard failure that
/// aborts the current operation.
fn fetch(
&self,
address: &SwarmAddress,
) -> impl core::future::Future<Output = Result<Option<Bytes>, Self::Error>>;
}
/// Write side of the chunk transport, used by [`BatchStamper::flush`] to
/// publish sealed usage chunks.
///
/// The `auto_impl` attribute also provides this trait for `&T`, `Arc<T>` and
/// `Box<T>` when `T` implements it. Like [`SnapshotSource`], the returned
/// future carries no `Send` bound.
#[auto_impl::auto_impl(&, Arc, Box)]
pub trait SnapshotSink {
/// Transport error; [`BatchStamper`] wraps it in [`ClientError::Sink`].
type Error: core::error::Error + Send + Sync + 'static;
/// Publish one sealed chunk. `flush` calls this once per sealed chunk, in
/// order, and stops at the first error.
fn push(
&self,
sealed: &SealedChunk,
) -> impl core::future::Future<Output = Result<(), Self::Error>>;
}
/// Errors returned by [`BatchStamper`].
///
/// Generic over the transport error types so callers keep the concrete
/// [`SnapshotSource::Error`] and [`SnapshotSink::Error`] they supplied.
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum ClientError<SrcErr, SnkErr>
where
SrcErr: core::error::Error + Send + Sync + 'static,
SnkErr: core::error::Error + Send + Sync + 'static,
{
/// Usage-table error, e.g. `UsageError::StaleSequence` when a flush would
/// not advance past the live published sequence.
#[error(transparent)]
Usage(#[from] UsageError),
/// Error from the `seal` module.
#[error(transparent)]
Seal(#[from] SealError),
/// The [`SnapshotSource`] failed to read a chunk.
#[error("chunk source read failed")]
Source(#[source] SrcErr),
/// The [`SnapshotSink`] failed to publish a chunk.
#[error("chunk sink publish failed")]
Sink(#[source] SnkErr),
/// The published root lists a leaf that the source reports as absent.
#[error("published root commits to leaf {index} but the source reports it absent")]
MissingLeaf {
/// Usage-chunk index of the missing leaf (leaf indices start at 1; 0 is the root).
index: u16,
},
}
/// Stamps content against a postage batch, keeping the usage snapshot in
/// memory and publishing it through a [`SnapshotSink`] on [`flush`](Self::flush).
#[derive(Debug)]
pub struct BatchStamper<Sg, Src, Snk> {
/// Signs sealed usage chunks during `flush`.
signer: Sg,
/// The signer's address, captured at `open`; used to derive usage-chunk addresses.
owner: Address,
/// Id of the batch being stamped.
batch_id: BatchId,
/// Read side: used at `open` for recovery and at `flush` for the floor check.
source: Src,
/// Write side: receives every sealed chunk produced by `flush`.
sink: Snk,
/// In-memory usage state that `stamp` records into.
snapshot: Snapshot,
/// Whether a `flush` has completed successfully since `open`; `flush` skips
/// work when this is set and the snapshot is clean.
persisted_this_session: bool,
}
impl<Sg, Src, Snk> BatchStamper<Sg, Src, Snk>
where
    Sg: SignerSync + alloy_signer::Signer,
    Src: SnapshotSource,
    Snk: SnapshotSink,
{
    /// Opens a stamper for `batch`, recovering any usage state the signer has
    /// already published.
    ///
    /// The owner is the signer's address. The root usage chunk (index 0) is read
    /// first.
    /// - If the root exists, every leaf it commits to (indices `1..=leaf_count`)
    ///   is read and the snapshot is reassembled from them.
    /// - If the root is absent, a fresh snapshot is built from `batch`.
    ///
    /// # Errors
    ///
    /// - [`ClientError::Source`] if any read fails. A failed read aborts `open`
    ///   instead of silently starting from a fresh snapshot.
    /// - [`ClientError::MissingLeaf`] if the root commits to a leaf that the
    ///   source reports as absent.
    /// - Root parsing, leaf assembly, or fresh-snapshot construction failures,
    ///   converted into the matching [`ClientError`] variant.
    pub async fn open(
        signer: Sg,
        batch: &Batch,
        source: Src,
        sink: Snk,
    ) -> Result<Self, ClientError<Src::Error, Snk::Error>> {
        let owner = signer.address();
        let batch_id = batch.id();
        let root_addr = usage_chunk_address(&batch_id, &owner, 0);

        let snapshot = match source
            .fetch(&root_addr)
            .await
            .map_err(ClientError::Source)?
        {
            Some(root_bytes) => {
                let root = RootInfo::parse(&root_bytes)?;
                let mut leaves: Vec<Bytes> = Vec::with_capacity(root.leaf_count() as usize);
                // Index 0 is the root itself, so leaves occupy 1..=leaf_count.
                for leaf in 0..root.leaf_count() {
                    let index = leaf + 1;
                    let leaf_addr = usage_chunk_address(&batch_id, &owner, index);
                    let Some(bytes) = source
                        .fetch(&leaf_addr)
                        .await
                        .map_err(ClientError::Source)?
                    else {
                        return Err(ClientError::MissingLeaf { index });
                    };
                    leaves.push(bytes);
                }
                root.assemble(&leaves)?
            }
            None => Snapshot::from_batch(batch)?,
        };

        Ok(Self {
            signer,
            owner,
            batch_id,
            source,
            sink,
            snapshot,
            persisted_this_session: false,
        })
    }

    /// Records `content` in the in-memory snapshot and returns the stamp index
    /// assigned to it.
    ///
    /// Nothing is sent to the sink; call [`flush`](Self::flush) to publish.
    ///
    /// # Errors
    ///
    /// Propagates the usage-table error if the address cannot be recorded.
    pub fn stamp(
        &mut self,
        content: &SwarmAddress,
    ) -> Result<StampIndex, ClientError<Src::Error, Snk::Error>> {
        Ok(self.snapshot.issuer(self.owner).record_address(content)?)
    }

    /// Publishes the snapshot if it has unpublished changes, or if nothing has
    /// been successfully published yet in this session.
    ///
    /// The live root is read to obtain the published sequence floor. The local
    /// snapshot is revalidated against that floor, planned, and sealed with a
    /// timestamp that is never earlier than one second after the previous seal.
    /// Each sealed chunk is then pushed to the sink in order.
    ///
    /// If any step after the floor read fails, the next call retries the whole
    /// publish rather than treating the session as already persisted.
    ///
    /// # Errors
    ///
    /// - [`ClientError::Source`] if the floor read fails. Nothing is pushed in
    ///   that case.
    /// - [`ClientError::Usage`] (e.g. `StaleSequence`) if the next local sequence
    ///   does not exceed the live floor.
    /// - [`ClientError::Seal`] if sealing fails.
    /// - [`ClientError::Sink`] if a push fails. Chunks pushed before the failure
    ///   may already be published.
    pub async fn flush(&mut self) -> Result<(), ClientError<Src::Error, Snk::Error>> {
        if !self.snapshot.is_dirty() && self.persisted_this_session {
            return Ok(());
        }

        let root_addr = usage_chunk_address(&self.batch_id, &self.owner, 0);
        let floor = match self
            .source
            .fetch(&root_addr)
            .await
            .map_err(ClientError::Source)?
        {
            Some(root_bytes) => PublishedSequence::from(&RootInfo::parse(&root_bytes)?),
            None => PublishedSequence::NONE,
        };

        // The steps below mutate the snapshot (advance its sequence, seal it)
        // before anything reaches the sink, so after a failure it may already
        // read as clean. Drop the "persisted" marker now. Otherwise a failure
        // following an earlier successful flush would make the next call
        // short-circuit above and silently never publish these records.
        self.persisted_this_session = false;

        let plan = self.snapshot.revalidate(floor)?.plan_persist(&self.owner)?;

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        // Keep seal timestamps strictly increasing even if the wall clock lags.
        // `saturating_add` avoids a debug-build overflow panic on an extreme
        // restored timestamp.
        let timestamp = self
            .snapshot
            .last_seal_timestamp()
            .map_or(now, |previous| now.max(previous.saturating_add(1)));

        let sealed = seal_plan(&mut self.snapshot, &plan, timestamp, &self.signer)?;
        for chunk in &sealed {
            self.sink.push(chunk).await.map_err(ClientError::Sink)?;
        }

        self.persisted_this_session = true;
        Ok(())
    }

    /// The current in-memory usage snapshot.
    pub const fn snapshot(&self) -> &Snapshot {
        &self.snapshot
    }

    /// The owner address (the signer's address captured at `open`).
    pub const fn owner(&self) -> Address {
        self.owner
    }

    /// The id of the batch being stamped.
    pub const fn batch_id(&self) -> BatchId {
        self.batch_id
    }

    /// Whether the snapshot holds changes not yet sealed by a `flush`.
    pub const fn is_dirty(&self) -> bool {
        self.snapshot.is_dirty()
    }
}
// Unit tests for `BatchStamper`, driven through in-memory and deliberately
// failing transports.
#[cfg(test)]
mod tests {
use alloc::collections::BTreeMap;
use std::sync::Mutex;
use alloy_primitives::B256;
use alloy_signer_local::PrivateKeySigner;
use super::*;
use crate::{Mutability, UsageTable};
// In-memory chunk store. Clones share the same map through an `Arc`, so a
// single `MemNet` can act as both source and sink, and chunks pushed by one
// stamper are visible to a later `open`.
#[derive(Debug, Default, Clone)]
struct MemNet {
chunks: std::sync::Arc<Mutex<BTreeMap<SwarmAddress, Bytes>>>,
}
// Error type shared by all test transports.
#[derive(Debug, Error)]
#[error("mem net error")]
struct MemError;
// Reads return whatever was stored at the address, or `None` if absent.
impl SnapshotSource for MemNet {
type Error = MemError;
async fn fetch(&self, address: &SwarmAddress) -> Result<Option<Bytes>, Self::Error> {
Ok(self.chunks.lock().unwrap().get(address).cloned())
}
}
// Pushes store the sealed chunk's data under its chunk address.
impl SnapshotSink for MemNet {
type Error = MemError;
async fn push(&self, sealed: &SealedChunk) -> Result<(), Self::Error> {
use nectar_primitives::Chunk;
let address = *sealed.chunk.address();
let payload = Bytes::copy_from_slice(sealed.chunk.data().as_ref());
self.chunks.lock().unwrap().insert(address, payload);
Ok(())
}
}
// Transport whose every read fails; as a sink it accepts and discards pushes.
#[derive(Debug, Default, Clone)]
struct FailingSource;
impl SnapshotSource for FailingSource {
type Error = MemError;
async fn fetch(&self, _address: &SwarmAddress) -> Result<Option<Bytes>, Self::Error> {
Err(MemError)
}
}
impl SnapshotSink for FailingSource {
type Error = MemError;
async fn push(&self, _sealed: &SealedChunk) -> Result<(), Self::Error> {
Ok(())
}
}
// Holds an `Rc`, so the transport and the futures it returns (which capture a
// clone of the `Rc`) are `!Send`. Used only to show that the traits do not
// require `Send`.
struct LocalNet(std::rc::Rc<()>);
impl SnapshotSource for LocalNet {
type Error = MemError;
fn fetch(
&self,
_address: &SwarmAddress,
) -> impl core::future::Future<Output = Result<Option<Bytes>, Self::Error>> {
let hold = self.0.clone();
async move {
let _hold = &hold;
Ok(None)
}
}
}
impl SnapshotSink for LocalNet {
type Error = MemError;
fn push(
&self,
_sealed: &SealedChunk,
) -> impl core::future::Future<Output = Result<(), Self::Error>> {
let hold = self.0.clone();
async move {
let _hold = &hold;
Ok(())
}
}
}
// Compile-time check: a `!Send` transport satisfies both trait bounds.
#[test]
fn non_send_transport_satisfies_the_traits() {
fn assert_source<S: SnapshotSource>(_: &S) {}
fn assert_sink<K: SnapshotSink>(_: &K) {}
let local = LocalNet(std::rc::Rc::new(()));
assert_source(&local);
assert_sink(&local);
}
// Builds a batch with id 0x42.. owned by `signer`'s address. The positional
// 20 and 16 presumably are depth and bucket depth (they match the
// `UsageTable::new` call in `flush_rejects_stale_sequence`) — TODO confirm
// against `Batch::new`.
fn test_batch(signer: &PrivateKeySigner, immutable: bool) -> Batch {
Batch::new(
B256::repeat_byte(0x42),
0,
0,
signer.address(),
20,
16,
immutable,
)
}
// Opening against an empty network starts at sequence 0. The first flush
// publishes and advances to 1; a second flush with no new stamps is a no-op.
#[tokio::test]
async fn open_miss_starts_fresh_and_flush_publishes() {
let signer = PrivateKeySigner::random();
let batch = test_batch(&signer, true);
let net = MemNet::default();
let mut stamper = BatchStamper::open(signer, &batch, net.clone(), net.clone())
.await
.unwrap();
assert_eq!(stamper.snapshot().sequence(), 0);
let content = SwarmAddress::from(B256::repeat_byte(0x99));
stamper.stamp(&content).unwrap();
assert!(stamper.is_dirty());
stamper.flush().await.unwrap();
assert_eq!(stamper.snapshot().sequence(), 1);
assert!(!stamper.is_dirty());
stamper.flush().await.unwrap();
assert_eq!(stamper.snapshot().sequence(), 1);
}
// A second stamper opened on the same network recovers the sequence and the
// allocated slots published by the first one.
#[tokio::test]
async fn open_recovers_published_batch() {
let signer = PrivateKeySigner::random();
let owner = signer.address();
let batch = test_batch(&signer, true);
let net = MemNet::default();
{
let mut a = BatchStamper::open(signer.clone(), &batch, net.clone(), net.clone())
.await
.unwrap();
a.stamp(&SwarmAddress::from(B256::repeat_byte(0x99)))
.unwrap();
a.flush().await.unwrap();
}
let b = BatchStamper::open(signer, &batch, net.clone(), net.clone())
.await
.unwrap();
assert_eq!(
b.snapshot().sequence(),
1,
"recovered the published sequence"
);
assert_eq!(b.owner(), owner);
assert!(!b.snapshot().allocated_slots().is_empty());
}
// A read error during `open` must surface as `Source`, not fall back to a
// fresh snapshot.
#[tokio::test]
async fn open_aborts_on_source_error() {
let signer = PrivateKeySigner::random();
let batch = test_batch(&signer, true);
let result = BatchStamper::open(signer, &batch, FailingSource, FailingSource).await;
assert!(
matches!(result, Err(ClientError::Source(_))),
"a failed read must abort open, not start fresh",
);
}
// A read error on the floor check must abort `flush` before anything is
// pushed. The stamper is built directly so that `open` does not hit the
// failing source first.
#[tokio::test]
async fn flush_aborts_on_floor_read_error() {
let signer = PrivateKeySigner::random();
let owner = signer.address();
let batch = test_batch(&signer, true);
let net = MemNet::default();
let mut stamper = BatchStamper {
signer: signer.clone(),
owner,
batch_id: batch.id(),
source: FailingSource,
sink: net.clone(),
snapshot: Snapshot::from_batch(&batch).unwrap(),
persisted_this_session: false,
};
stamper
.stamp(&SwarmAddress::from(B256::repeat_byte(0x99)))
.unwrap();
let result = stamper.flush().await;
assert!(
matches!(result, Err(ClientError::Source(_))),
"a failed floor read must abort flush, not persist against NONE",
);
assert!(net.chunks.lock().unwrap().is_empty());
}
// Consecutive flushes advance the sequence while keeping the slots that the
// first flush allocated.
#[tokio::test]
async fn stamp_then_flush_advances_sequence_and_reuses_slots() {
let signer = PrivateKeySigner::random();
let net = MemNet::default();
let batch = test_batch(&signer, true);
let mut stamper = BatchStamper::open(signer, &batch, net.clone(), net.clone())
.await
.unwrap();
stamper
.stamp(&SwarmAddress::from(B256::repeat_byte(0x99)))
.unwrap();
stamper.flush().await.unwrap();
let slots_after_first = stamper.snapshot().allocated_slots().to_vec();
assert_eq!(stamper.snapshot().sequence(), 1);
stamper
.stamp(&SwarmAddress::from(B256::repeat_byte(0xab)))
.unwrap();
stamper.flush().await.unwrap();
assert_eq!(stamper.snapshot().sequence(), 2, "sequence advanced");
assert_eq!(
stamper.snapshot().allocated_slots(),
slots_after_first.as_slice(),
"the snapshot's own slots were reused, not re-allocated",
);
}
// Stamper `a` publishes up to sequence 2. Stamper `b` holds a stale snapshot
// at sequence 1, whose next sequence (2) does not exceed the live floor (2),
// so its flush must be rejected with `StaleSequence`.
#[tokio::test]
async fn flush_rejects_stale_sequence() {
let signer = PrivateKeySigner::random();
let owner = signer.address();
let batch = test_batch(&signer, true);
let net = MemNet::default();
{
let mut a = BatchStamper::open(signer.clone(), &batch, net.clone(), net.clone())
.await
.unwrap();
a.stamp(&SwarmAddress::from(B256::repeat_byte(0x01)))
.unwrap();
a.flush().await.unwrap();
a.stamp(&SwarmAddress::from(B256::repeat_byte(0x02)))
.unwrap();
a.flush().await.unwrap();
assert_eq!(a.snapshot().sequence(), 2);
}
// Build the stale snapshot by hand: a fresh table advanced once to sequence 1.
let table = UsageTable::new(batch.id(), 20, 16, Mutability::Immutable).unwrap();
let mut stale = Snapshot::new(table);
stale
.revalidate(PublishedSequence::NONE)
.unwrap()
.plan_persist(&owner)
.unwrap();
assert_eq!(stale.sequence(), 1);
let mut b = BatchStamper {
signer,
owner,
batch_id: batch.id(),
source: net.clone(),
sink: net.clone(),
snapshot: stale,
persisted_this_session: false,
};
b.stamp(&SwarmAddress::from(B256::repeat_byte(0x03)))
.unwrap();
let result = b.flush().await;
assert!(
matches!(
result,
Err(ClientError::Usage(UsageError::StaleSequence {
next: 2,
floor: 2
})),
),
"a persist whose next sequence does not exceed the live floor is rejected",
);
}
}