use crate::data::client::adaptive::observe_op;
use crate::data::client::classify_error;
use crate::data::client::file::UploadEvent;
use crate::data::client::Client;
use crate::data::error::{Error, Result};
use ant_protocol::evm::{
Amount, MerklePaymentCandidateNode, MerklePaymentCandidatePool, MerklePaymentProof, MerkleTree,
MidpointProof, PoolCommitment, CANDIDATES_PER_POOL, MAX_LEAVES,
};
use ant_protocol::payment::{serialize_merkle_proof, verify_merkle_candidate_signature};
use ant_protocol::transport::PeerId;
use ant_protocol::{
compute_address, send_and_await_chunk_response, ChunkMessage, ChunkMessageBody,
MerkleCandidateQuoteRequest, MerkleCandidateQuoteResponse,
};
use bytes::Bytes;
use futures::stream::{self, FuturesUnordered, StreamExt};
use rand::Rng;
use std::collections::{HashMap, VecDeque};
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::{debug, info, warn};
use xor_name::XorName;
pub const DEFAULT_MERKLE_THRESHOLD: usize = 64;
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PaymentMode {
#[default]
Auto,
Merkle,
Single,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MerkleBatchPaymentResult {
pub proofs: HashMap<[u8; 32], Vec<u8>>,
pub chunk_count: usize,
pub storage_cost_atto: String,
pub gas_cost_wei: u128,
#[serde(default)]
pub merkle_payment_timestamp: u64,
}
pub struct PreparedMerkleBatch {
pub depth: u8,
pub pool_commitments: Vec<PoolCommitment>,
pub merkle_payment_timestamp: u64,
candidate_pools: Vec<MerklePaymentCandidatePool>,
tree: MerkleTree,
addresses: Vec<[u8; 32]>,
}
#[derive(Debug, Clone, Default)]
pub(crate) struct MerkleUploadPlan {
pub already_stored: Vec<[u8; 32]>,
pub to_upload: Vec<[u8; 32]>,
to_upload_total_bytes: u64,
}
impl MerkleUploadPlan {
#[must_use]
pub fn to_upload_avg_size(&self) -> u64 {
if self.to_upload.is_empty() {
return 0;
}
self.to_upload_total_bytes / self.to_upload.len() as u64
}
}
impl std::fmt::Debug for PreparedMerkleBatch {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PreparedMerkleBatch")
.field("depth", &self.depth)
.field("pool_commitments", &self.pool_commitments.len())
.field("merkle_payment_timestamp", &self.merkle_payment_timestamp)
.field("candidate_pools", &self.candidate_pools.len())
.field("addresses", &self.addresses.len())
.finish()
}
}
pub(crate) fn chunk_contents_for_upload_addresses(
chunk_contents: Vec<Bytes>,
addresses: &[[u8; 32]],
) -> Result<Vec<Bytes>> {
if addresses.is_empty() {
return Ok(Vec::new());
}
let mut needed_by_address: HashMap<[u8; 32], usize> = HashMap::new();
for address in addresses {
*needed_by_address.entry(*address).or_default() += 1;
}
let mut chunks_by_address: HashMap<[u8; 32], VecDeque<Bytes>> =
HashMap::with_capacity(needed_by_address.len());
let mut remaining = addresses.len();
for chunk in chunk_contents {
let address = compute_address(&chunk);
if let Some(needed) = needed_by_address.get_mut(&address) {
if *needed > 0 {
chunks_by_address
.entry(address)
.or_default()
.push_back(chunk);
*needed -= 1;
remaining -= 1;
if remaining == 0 {
break;
}
}
}
}
for (address, needed) in &needed_by_address {
if *needed == 0 {
continue;
}
if chunks_by_address.contains_key(address) {
return Err(Error::InvalidData(format!(
"missing duplicate chunk content for merkle address {}",
hex::encode(address)
)));
}
return Err(Error::InvalidData(format!(
"missing chunk content for merkle address {}",
hex::encode(address)
)));
}
let mut selected = Vec::with_capacity(addresses.len());
for address in addresses {
let chunks = chunks_by_address.get_mut(address).ok_or_else(|| {
Error::InvalidData(format!(
"missing chunk content for merkle address {}",
hex::encode(address)
))
})?;
let chunk = chunks.pop_front().ok_or_else(|| {
Error::InvalidData(format!(
"missing duplicate chunk content for merkle address {}",
hex::encode(address)
))
})?;
selected.push(chunk);
}
Ok(selected)
}
#[must_use]
pub fn should_use_merkle(chunk_count: usize, mode: PaymentMode) -> bool {
match mode {
PaymentMode::Auto => chunk_count >= DEFAULT_MERKLE_THRESHOLD,
PaymentMode::Merkle => chunk_count >= 2,
PaymentMode::Single => false,
}
}
impl Client {
#[must_use]
pub fn should_use_merkle(&self, chunk_count: usize, mode: PaymentMode) -> bool {
should_use_merkle(chunk_count, mode)
}
pub async fn pay_for_merkle_batch(
&self,
addresses: &[[u8; 32]],
data_type: u32,
data_size: u64,
) -> Result<MerkleBatchPaymentResult> {
let chunk_count = addresses.len();
if chunk_count < 2 {
return Err(Error::Payment(
"Merkle batch payment requires at least 2 chunks".to_string(),
));
}
if chunk_count > MAX_LEAVES {
return self
.pay_for_merkle_multi_batch(addresses, data_type, data_size)
.await;
}
self.pay_for_merkle_single_batch(addresses, data_type, data_size)
.await
}
pub(crate) async fn plan_merkle_upload(
&self,
chunks: Vec<([u8; 32], u64)>,
data_type: u32,
progress: Option<&mpsc::Sender<UploadEvent>>,
) -> Result<MerkleUploadPlan> {
let total_chunks = chunks.len();
if total_chunks == 0 {
return Ok(MerkleUploadPlan::default());
}
info!("Checking {total_chunks} merkle chunks for existing storage before payment");
let quote_limiter = self.controller().quote.clone();
let quote_concurrency = quote_limiter.current().min(total_chunks.max(1));
let mut check_stream = stream::iter(chunks.into_iter().enumerate())
.map(|(index, (address, data_size))| {
let limiter = quote_limiter.clone();
async move {
let result = observe_op(
&limiter,
|| async move {
self.chunk_already_stored_for_merkle(&address, data_type, data_size)
.await
},
classify_error,
)
.await;
(index, address, data_size, result)
}
})
.buffer_unordered(quote_concurrency);
let mut already_stored: Vec<(usize, [u8; 32])> = Vec::new();
let mut to_upload: Vec<(usize, [u8; 32], u64)> = Vec::new();
let mut checked = 0usize;
while let Some((index, address, data_size, result)) = check_stream.next().await {
let is_already_stored = result?;
checked += 1;
if let Some(tx) = progress {
let _ = tx.try_send(UploadEvent::ChunkQuoted {
quoted: checked,
total: total_chunks,
});
}
if is_already_stored {
debug!(
"Merkle preflight {checked}/{total_chunks}: chunk {} already stored",
hex::encode(address)
);
already_stored.push((index, address));
if let Some(tx) = progress {
let _ = tx.try_send(UploadEvent::ChunkStored {
stored: already_stored.len(),
total: total_chunks,
});
}
} else {
debug!(
"Merkle preflight {checked}/{total_chunks}: chunk {} needs upload",
hex::encode(address)
);
to_upload.push((index, address, data_size));
}
}
already_stored.sort_by_key(|(index, _)| *index);
to_upload.sort_by_key(|(index, _, _)| *index);
let to_upload_total_bytes = to_upload.iter().fold(0u64, |acc, (_, _, data_size)| {
acc.saturating_add(*data_size)
});
let already_stored = already_stored
.into_iter()
.map(|(_, address)| address)
.collect::<Vec<_>>();
let to_upload = to_upload
.into_iter()
.map(|(_, address, _)| address)
.collect::<Vec<_>>();
info!(
"Merkle preflight complete: {} already stored, {} need upload",
already_stored.len(),
to_upload.len()
);
Ok(MerkleUploadPlan {
already_stored,
to_upload,
to_upload_total_bytes,
})
}
async fn chunk_already_stored_for_merkle(
&self,
address: &[u8; 32],
data_type: u32,
data_size: u64,
) -> Result<bool> {
match self.get_store_quotes(address, data_size, data_type).await {
Ok(_) => Ok(false),
Err(Error::AlreadyStored) => Ok(true),
Err(e) => Err(e),
}
}
pub async fn prepare_merkle_batch_external(
&self,
addresses: &[[u8; 32]],
data_type: u32,
data_size: u64,
) -> Result<PreparedMerkleBatch> {
let chunk_count = addresses.len();
let xornames: Vec<XorName> = addresses.iter().map(|a| XorName(*a)).collect();
debug!("Building merkle tree for {chunk_count} chunks");
let tree = MerkleTree::from_xornames(xornames)
.map_err(|e| Error::Payment(format!("Failed to build merkle tree: {e}")))?;
let depth = tree.depth();
let merkle_payment_timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| Error::Payment(format!("System time error: {e}")))?
.as_secs();
debug!("Merkle tree: depth={depth}, leaves={chunk_count}, ts={merkle_payment_timestamp}");
let midpoint_proofs = tree
.reward_candidates(merkle_payment_timestamp)
.map_err(|e| Error::Payment(format!("Failed to generate reward candidates: {e}")))?;
debug!(
"Collecting candidate pools from {} midpoints (concurrent)",
midpoint_proofs.len()
);
let candidate_pools = self
.build_candidate_pools(
&midpoint_proofs,
data_type,
data_size,
merkle_payment_timestamp,
)
.await?;
let pool_commitments: Vec<PoolCommitment> = candidate_pools
.iter()
.map(MerklePaymentCandidatePool::to_commitment)
.collect();
Ok(PreparedMerkleBatch {
depth,
pool_commitments,
merkle_payment_timestamp,
candidate_pools,
tree,
addresses: addresses.to_vec(),
})
}
async fn pay_for_merkle_single_batch(
&self,
addresses: &[[u8; 32]],
data_type: u32,
data_size: u64,
) -> Result<MerkleBatchPaymentResult> {
let wallet = self.require_wallet()?;
let prepared = self
.prepare_merkle_batch_external(addresses, data_type, data_size)
.await?;
info!(
"Submitting merkle batch payment on-chain (depth={})",
prepared.depth
);
let (winner_pool_hash, amount, gas_info) = wallet
.pay_for_merkle_tree(
prepared.depth,
prepared.pool_commitments.clone(),
prepared.merkle_payment_timestamp,
)
.await
.map_err(|e| Error::Payment(format!("Merkle batch payment failed: {e}")))?;
info!(
"Merkle payment succeeded: winner pool {}",
hex::encode(winner_pool_hash)
);
let mut result = finalize_merkle_batch(prepared, winner_pool_hash)?;
result.storage_cost_atto = amount.to_string();
result.gas_cost_wei = gas_info.gas_cost_wei;
Ok(result)
}
async fn pay_for_merkle_multi_batch(
&self,
addresses: &[[u8; 32]],
data_type: u32,
data_size: u64,
) -> Result<MerkleBatchPaymentResult> {
let sub_batches: Vec<&[[u8; 32]]> = addresses.chunks(MAX_LEAVES).collect();
let total_sub_batches = sub_batches.len();
let mut all_proofs = HashMap::with_capacity(addresses.len());
let mut total_storage = Amount::ZERO;
let mut total_gas: u128 = 0;
let mut oldest_ts: u64 = 0;
for (i, chunk) in sub_batches.into_iter().enumerate() {
match self
.pay_for_merkle_single_batch(chunk, data_type, data_size)
.await
{
Ok(sub_result) => {
if let Ok(cost) = sub_result.storage_cost_atto.parse::<Amount>() {
total_storage += cost;
}
total_gas = total_gas.saturating_add(sub_result.gas_cost_wei);
if oldest_ts == 0
|| (sub_result.merkle_payment_timestamp > 0
&& sub_result.merkle_payment_timestamp < oldest_ts)
{
oldest_ts = sub_result.merkle_payment_timestamp;
}
all_proofs.extend(sub_result.proofs);
}
Err(e) => {
if all_proofs.is_empty() {
return Err(e);
}
warn!(
"Merkle sub-batch {}/{total_sub_batches} failed: {e}. \
Returning {} proofs from prior sub-batches",
i + 1,
all_proofs.len()
);
return Ok(MerkleBatchPaymentResult {
chunk_count: all_proofs.len(),
proofs: all_proofs,
storage_cost_atto: total_storage.to_string(),
gas_cost_wei: total_gas,
merkle_payment_timestamp: oldest_ts,
});
}
}
}
Ok(MerkleBatchPaymentResult {
chunk_count: addresses.len(),
proofs: all_proofs,
storage_cost_atto: total_storage.to_string(),
gas_cost_wei: total_gas,
merkle_payment_timestamp: oldest_ts,
})
}
async fn build_candidate_pools(
&self,
midpoint_proofs: &[MidpointProof],
data_type: u32,
data_size: u64,
merkle_payment_timestamp: u64,
) -> Result<Vec<MerklePaymentCandidatePool>> {
let mut pool_futures = FuturesUnordered::new();
for midpoint_proof in midpoint_proofs {
let pool_address = midpoint_proof.address();
let mp = midpoint_proof.clone();
pool_futures.push(async move {
let candidate_nodes = self
.get_merkle_candidate_pool(
&pool_address.0,
data_type,
data_size,
merkle_payment_timestamp,
)
.await?;
Ok::<_, Error>(MerklePaymentCandidatePool {
midpoint_proof: mp,
candidate_nodes,
})
});
}
let mut pools = Vec::with_capacity(midpoint_proofs.len());
while let Some(result) = pool_futures.next().await {
pools.push(result?);
}
Ok(pools)
}
#[allow(clippy::too_many_lines)]
async fn get_merkle_candidate_pool(
&self,
address: &[u8; 32],
data_type: u32,
data_size: u64,
merkle_payment_timestamp: u64,
) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
let node = self.network().node();
let timeout = Duration::from_secs(self.config().quote_timeout_secs);
let query_count = CANDIDATES_PER_POOL * 2;
let mut remote_peers = self
.network()
.find_closest_peers(address, query_count)
.await?;
if remote_peers.len() < CANDIDATES_PER_POOL {
let connected = self.network().connected_peers().await;
for peer in connected {
if !remote_peers.iter().any(|(id, _)| *id == peer) {
remote_peers.push((peer, vec![]));
}
}
}
if remote_peers.len() < CANDIDATES_PER_POOL {
return Err(Error::InsufficientPeers(format!(
"Found {} peers, need {CANDIDATES_PER_POOL} for merkle candidate pool. \
Use --no-merkle or a larger network.",
remote_peers.len()
)));
}
let mut candidate_futures = FuturesUnordered::new();
for (peer_id, peer_addrs) in &remote_peers {
let request_id = self.next_request_id();
let request = MerkleCandidateQuoteRequest {
address: *address,
data_type,
data_size,
merkle_payment_timestamp,
};
let message = ChunkMessage {
request_id,
body: ChunkMessageBody::MerkleCandidateQuoteRequest(request),
};
let message_bytes = match message.encode() {
Ok(bytes) => bytes,
Err(e) => {
warn!("Failed to encode merkle candidate request for {peer_id}: {e}");
continue;
}
};
let peer_id_clone = *peer_id;
let addrs_clone = peer_addrs.clone();
let node_clone = node.clone();
let fut = async move {
let result = send_and_await_chunk_response(
&node_clone,
&peer_id_clone,
message_bytes,
request_id,
timeout,
&addrs_clone,
|body| match body {
ChunkMessageBody::MerkleCandidateQuoteResponse(
MerkleCandidateQuoteResponse::Success { candidate_node },
) => {
match rmp_serde::from_slice::<MerklePaymentCandidateNode>(
&candidate_node,
) {
Ok(node) => Some(Ok(node)),
Err(e) => Some(Err(Error::Serialization(format!(
"Failed to deserialize candidate node from {peer_id_clone}: {e}"
)))),
}
}
ChunkMessageBody::MerkleCandidateQuoteResponse(
MerkleCandidateQuoteResponse::Error(e),
) => Some(Err(Error::Protocol(format!(
"Merkle quote error from {peer_id_clone}: {e}"
)))),
_ => None,
},
|e| {
Error::Network(format!(
"Failed to send merkle candidate request to {peer_id_clone}: {e}"
))
},
|| {
Error::Timeout(format!(
"Timeout waiting for merkle candidate from {peer_id_clone}"
))
},
)
.await;
(peer_id_clone, result)
};
candidate_futures.push(fut);
}
self.collect_validated_candidates(&mut candidate_futures, address, merkle_payment_timestamp)
.await
}
async fn collect_validated_candidates(
&self,
futures: &mut FuturesUnordered<
impl std::future::Future<
Output = (
PeerId,
std::result::Result<MerklePaymentCandidateNode, Error>,
),
>,
>,
target_address: &[u8; 32],
merkle_payment_timestamp: u64,
) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
let mut valid: Vec<(PeerId, MerklePaymentCandidateNode)> = Vec::new();
let mut failures: Vec<String> = Vec::new();
while let Some((peer_id, result)) = futures.next().await {
match result {
Ok(candidate) => {
if !verify_merkle_candidate_signature(&candidate) {
warn!("Invalid ML-DSA-65 signature from merkle candidate {peer_id}");
failures.push(format!("{peer_id}: invalid signature"));
continue;
}
if candidate.merkle_payment_timestamp != merkle_payment_timestamp {
warn!("Timestamp mismatch from merkle candidate {peer_id}");
failures.push(format!("{peer_id}: timestamp mismatch"));
continue;
}
valid.push((peer_id, candidate));
}
Err(e) => {
debug!("Failed to get merkle candidate from {peer_id}: {e}");
failures.push(format!("{peer_id}: {e}"));
}
}
}
if valid.len() < CANDIDATES_PER_POOL {
return Err(Error::InsufficientPeers(format!(
"Got {} merkle candidates, need {CANDIDATES_PER_POOL}. Failures: [{}]",
valid.len(),
failures.join("; ")
)));
}
let target_peer = PeerId::from_bytes(*target_address);
valid.sort_by_key(|(peer_id, _)| peer_id.xor_distance(&target_peer));
let candidates: Vec<MerklePaymentCandidateNode> = valid
.into_iter()
.take(CANDIDATES_PER_POOL)
.map(|(_, candidate)| candidate)
.collect();
candidates
.try_into()
.map_err(|_| Error::Payment("Failed to convert candidates to fixed array".to_string()))
}
pub(crate) async fn merkle_upload_chunks(
&self,
chunk_contents: Vec<Bytes>,
addresses: Vec<[u8; 32]>,
batch_result: &MerkleBatchPaymentResult,
progress: Option<&mpsc::Sender<UploadEvent>>,
stored_offset: usize,
total_chunks: usize,
) -> Result<MerkleStoreOutcome> {
let store_limiter = self.controller().store.clone();
let batch_size = chunk_contents.len();
if batch_size != addresses.len() {
return Err(Error::InvalidData(format!(
"merkle upload has {batch_size} chunk contents but {} addresses",
addresses.len()
)));
}
let store_concurrency = store_limiter.current().min(batch_size.max(1));
let chunks: Vec<([u8; 32], Bytes)> = addresses.into_iter().zip(chunk_contents).collect();
let store_one = |addr: [u8; 32], content: Bytes| {
let limiter = store_limiter.clone();
let proof_bytes = batch_result.proofs.get(&addr).cloned();
async move {
let started = std::time::Instant::now();
let proof = proof_bytes.ok_or_else(|| {
Error::Payment(format!(
"Missing merkle proof for chunk {}",
hex::encode(addr)
))
})?;
let peers = self.close_group_peers(&addr).await?;
observe_op(
&limiter,
|| async move { self.chunk_put_to_close_group(content, proof, &peers).await },
classify_error,
)
.await
.map(|_| started)
}
};
merkle_store_with_retry(
chunks,
store_concurrency,
MERKLE_STORE_MAX_ATTEMPTS,
MERKLE_RETRY_BACKOFF,
progress,
stored_offset,
total_chunks,
store_one,
)
.await
}
}
const MERKLE_STORE_MAX_ATTEMPTS: usize = 4;
const MERKLE_RETRY_BACKOFF: Duration = Duration::from_secs(30);
const MERKLE_RETRY_JITTER: f64 = 0.1;
#[derive(Debug, Default)]
pub(crate) struct MerkleStoreOutcome {
pub stored: usize,
pub failed: usize,
pub stats: crate::data::client::batch::WaveAggregateStats,
}
#[allow(clippy::too_many_arguments)]
async fn merkle_store_with_retry<F, Fut>(
chunks: Vec<([u8; 32], Bytes)>,
store_concurrency: usize,
max_attempts: usize,
backoff: Duration,
progress: Option<&mpsc::Sender<UploadEvent>>,
stored_offset: usize,
total: usize,
store_one: F,
) -> Result<MerkleStoreOutcome>
where
F: Fn([u8; 32], Bytes) -> Fut,
Fut: std::future::Future<Output = Result<std::time::Instant>>,
{
let attempts = max_attempts.max(1);
let mut outcome = MerkleStoreOutcome {
stored: stored_offset,
..MerkleStoreOutcome::default()
};
let mut pending = chunks;
for attempt in 0..attempts {
let concurrency = store_concurrency.min(pending.len().max(1)).max(1);
let mut next_failed: Vec<([u8; 32], Bytes)> = Vec::new();
let mut upload_stream = stream::iter(pending.into_iter().map(|(addr, content)| {
let fut = store_one(addr, content.clone());
async move { (addr, content, fut.await) }
}))
.buffer_unordered(concurrency);
while let Some((addr, content, result)) = upload_stream.next().await {
outcome.stats.chunk_attempts_total =
outcome.stats.chunk_attempts_total.saturating_add(1);
match result {
Ok(started) => {
let duration_ms =
u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
outcome.stats.store_durations_ms.push(duration_ms);
let idx = attempt.min(outcome.stats.retries_histogram.len().saturating_sub(1));
outcome.stats.retries_histogram[idx] =
outcome.stats.retries_histogram[idx].saturating_add(1);
outcome.stored += 1;
if let Some(tx) = progress {
let _ = tx.try_send(UploadEvent::ChunkStored {
stored: outcome.stored,
total,
});
}
}
Err(Error::InsufficientPeers(_)) => next_failed.push((addr, content)),
Err(e) => return Err(e),
}
}
if next_failed.is_empty() {
break;
}
if attempt + 1 < attempts {
warn!(
failed = next_failed.len(),
attempt = attempt + 1,
"merkle chunks short of quorum, retrying after backoff"
);
pending = next_failed;
if backoff > Duration::ZERO {
let wait = {
let mut rng = rand::thread_rng();
let factor = 1.0 + rng.gen_range(-MERKLE_RETRY_JITTER..=MERKLE_RETRY_JITTER);
backoff.mul_f64(factor)
};
tokio::time::sleep(wait).await;
}
} else {
outcome.failed = next_failed.len();
break;
}
}
Ok(outcome)
}
pub fn finalize_merkle_batch(
prepared: PreparedMerkleBatch,
winner_pool_hash: [u8; 32],
) -> Result<MerkleBatchPaymentResult> {
let chunk_count = prepared.addresses.len();
let xornames: Vec<XorName> = prepared.addresses.iter().map(|a| XorName(*a)).collect();
let winner_pool = prepared
.candidate_pools
.iter()
.find(|pool| pool.hash() == winner_pool_hash)
.ok_or_else(|| {
Error::Payment(format!(
"Winner pool {} not found in candidate pools",
hex::encode(winner_pool_hash)
))
})?;
info!("Generating merkle proofs for {chunk_count} chunks");
let mut proofs = HashMap::with_capacity(chunk_count);
for (i, xorname) in xornames.iter().enumerate() {
let address_proof = prepared
.tree
.generate_address_proof(i, *xorname)
.map_err(|e| {
Error::Payment(format!(
"Failed to generate address proof for chunk {i}: {e}"
))
})?;
let merkle_proof = MerklePaymentProof::new(*xorname, address_proof, winner_pool.clone());
let tagged_bytes = serialize_merkle_proof(&merkle_proof)
.map_err(|e| Error::Serialization(format!("Failed to serialize merkle proof: {e}")))?;
proofs.insert(prepared.addresses[i], tagged_bytes);
}
info!("Merkle batch payment complete: {chunk_count} proofs generated");
Ok(MerkleBatchPaymentResult {
proofs,
chunk_count,
storage_cost_atto: "0".to_string(),
gas_cost_wei: 0,
merkle_payment_timestamp: prepared.merkle_payment_timestamp,
})
}
#[cfg(test)]
mod send_assertions {
use super::*;
use crate::data::client::Client;
fn _assert_send<T: Send>(_: &T) {}
#[allow(
dead_code,
unreachable_code,
unused_variables,
clippy::diverging_sub_expression
)]
async fn _merkle_upload_chunks_is_send(client: &Client) {
let batch_result: MerkleBatchPaymentResult = todo!();
let fut = client.merkle_upload_chunks(Vec::new(), Vec::new(), &batch_result, None, 0, 0);
_assert_send(&fut);
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
use ant_protocol::evm::{Amount, MerkleTree, RewardsAddress, CANDIDATES_PER_POOL};
#[test]
fn test_auto_below_threshold() {
assert!(!should_use_merkle(1, PaymentMode::Auto));
assert!(!should_use_merkle(10, PaymentMode::Auto));
assert!(!should_use_merkle(63, PaymentMode::Auto));
}
#[test]
fn test_auto_at_and_above_threshold() {
assert!(should_use_merkle(64, PaymentMode::Auto));
assert!(should_use_merkle(65, PaymentMode::Auto));
assert!(should_use_merkle(1000, PaymentMode::Auto));
}
#[test]
fn test_merkle_mode_forces_at_2() {
assert!(!should_use_merkle(1, PaymentMode::Merkle));
assert!(should_use_merkle(2, PaymentMode::Merkle));
assert!(should_use_merkle(3, PaymentMode::Merkle));
}
#[test]
fn test_single_mode_always_false() {
assert!(!should_use_merkle(0, PaymentMode::Single));
assert!(!should_use_merkle(64, PaymentMode::Single));
assert!(!should_use_merkle(1000, PaymentMode::Single));
}
#[test]
fn test_default_mode_is_auto() {
assert_eq!(PaymentMode::default(), PaymentMode::Auto);
}
#[test]
fn test_threshold_value() {
assert_eq!(DEFAULT_MERKLE_THRESHOLD, 64);
}
#[test]
fn chunk_contents_for_upload_addresses_preserves_requested_order() {
let first = Bytes::from_static(b"first");
let second = Bytes::from_static(b"second");
let first_addr = compute_address(&first);
let second_addr = compute_address(&second);
let selected = chunk_contents_for_upload_addresses(
vec![first.clone(), second.clone()],
&[second_addr, first_addr],
)
.unwrap();
assert_eq!(selected, vec![second, first]);
}
#[test]
fn chunk_contents_for_upload_addresses_preserves_duplicate_requests() {
let repeated = Bytes::from_static(b"same-content");
let other = Bytes::from_static(b"other-content");
let repeated_addr = compute_address(&repeated);
let selected = chunk_contents_for_upload_addresses(
vec![repeated.clone(), other, repeated.clone()],
&[repeated_addr, repeated_addr],
)
.unwrap();
assert_eq!(selected, vec![repeated.clone(), repeated]);
}
#[test]
fn chunk_contents_for_upload_addresses_ignores_unrequested_duplicates() {
let requested = Bytes::from_static(b"requested-content");
let unrequested = Bytes::from_static(b"unrequested-content");
let requested_addr = compute_address(&requested);
let selected = chunk_contents_for_upload_addresses(
vec![
unrequested.clone(),
requested.clone(),
unrequested.clone(),
unrequested,
],
&[requested_addr],
)
.unwrap();
assert_eq!(selected, vec![requested]);
}
#[test]
fn chunk_contents_for_upload_addresses_errors_for_missing_content() {
let present = Bytes::from_static(b"present-content");
let missing = Bytes::from_static(b"missing-content");
let missing_addr = compute_address(&missing);
let result = chunk_contents_for_upload_addresses(vec![present], &[missing_addr]);
assert!(matches!(result, Err(Error::InvalidData(_))));
}
fn make_test_addresses(count: usize) -> Vec<[u8; 32]> {
(0..count)
.map(|i| {
let xn = XorName::from_content(&i.to_le_bytes());
xn.0
})
.collect()
}
#[test]
fn test_tree_depth_for_known_sizes() {
let cases = [(2, 1), (4, 2), (16, 4), (100, 7), (256, 8)];
for (count, expected_depth) in cases {
let addrs = make_test_addresses(count);
let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
let tree = MerkleTree::from_xornames(xornames).unwrap();
assert_eq!(
tree.depth(),
expected_depth,
"depth mismatch for {count} leaves"
);
}
}
#[test]
fn test_proof_generation_and_verification_for_all_leaves() {
let addrs = make_test_addresses(16);
let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();
for (i, xn) in xornames.iter().enumerate() {
let proof = tree.generate_address_proof(i, *xn).unwrap();
assert!(proof.verify(), "proof for leaf {i} should verify");
assert_eq!(proof.depth(), tree.depth() as usize);
}
}
#[test]
fn test_proof_fails_for_wrong_address() {
let addrs = make_test_addresses(8);
let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
let tree = MerkleTree::from_xornames(xornames).unwrap();
let wrong = XorName::from_content(b"wrong");
let proof = tree.generate_address_proof(0, wrong).unwrap();
assert!(!proof.verify(), "proof with wrong address should fail");
}
#[test]
fn test_tree_too_few_leaves() {
let xornames = vec![XorName::from_content(b"only_one")];
let result = MerkleTree::from_xornames(xornames);
assert!(result.is_err());
}
#[test]
fn test_tree_at_max_leaves() {
let addrs = make_test_addresses(MAX_LEAVES);
let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
let tree = MerkleTree::from_xornames(xornames).unwrap();
assert_eq!(tree.leaf_count(), MAX_LEAVES);
}
#[test]
fn test_merkle_proof_serialize_deserialize_roundtrip() {
use ant_protocol::evm::{Amount, MerklePaymentCandidateNode, RewardsAddress};
use ant_protocol::payment::{deserialize_merkle_proof, serialize_merkle_proof};
let addrs = make_test_addresses(4);
let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let candidates = tree.reward_candidates(timestamp).unwrap();
let midpoint = candidates.first().unwrap().clone();
#[allow(clippy::cast_possible_truncation)]
let candidate_nodes: [MerklePaymentCandidateNode; CANDIDATES_PER_POOL] =
std::array::from_fn(|i| MerklePaymentCandidateNode {
pub_key: vec![i as u8; 32],
price: Amount::from(1024u64),
reward_address: RewardsAddress::new([i as u8; 20]),
merkle_payment_timestamp: timestamp,
signature: vec![i as u8; 64],
});
let pool = MerklePaymentCandidatePool {
midpoint_proof: midpoint,
candidate_nodes,
};
let address_proof = tree.generate_address_proof(0, xornames[0]).unwrap();
let merkle_proof = MerklePaymentProof::new(xornames[0], address_proof, pool);
let tagged = serialize_merkle_proof(&merkle_proof).unwrap();
assert_eq!(
tagged.first().copied(),
Some(0x02),
"tag should be PROOF_TAG_MERKLE"
);
let deserialized = deserialize_merkle_proof(&tagged).unwrap();
assert_eq!(deserialized.address, merkle_proof.address);
assert_eq!(
deserialized.winner_pool.candidate_nodes.len(),
CANDIDATES_PER_POOL
);
}
#[test]
fn test_candidate_wrong_timestamp_rejected() {
let candidate = MerklePaymentCandidateNode {
pub_key: vec![0u8; 32],
price: ant_protocol::evm::Amount::ZERO,
reward_address: ant_protocol::evm::RewardsAddress::new([0u8; 20]),
merkle_payment_timestamp: 1000,
signature: vec![0u8; 64],
};
assert_ne!(candidate.merkle_payment_timestamp, 2000);
}
fn make_dummy_candidate_nodes(
timestamp: u64,
) -> [MerklePaymentCandidateNode; CANDIDATES_PER_POOL] {
std::array::from_fn(|i| MerklePaymentCandidateNode {
pub_key: vec![i as u8; 32],
price: Amount::from(1024u64),
reward_address: RewardsAddress::new([i as u8; 20]),
merkle_payment_timestamp: timestamp,
signature: vec![i as u8; 64],
})
}
fn make_prepared_merkle_batch(count: usize) -> PreparedMerkleBatch {
let addrs = make_test_addresses(count);
let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
let tree = MerkleTree::from_xornames(xornames).unwrap();
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let midpoints = tree.reward_candidates(timestamp).unwrap();
let candidate_pools: Vec<MerklePaymentCandidatePool> = midpoints
.into_iter()
.map(|mp| MerklePaymentCandidatePool {
midpoint_proof: mp,
candidate_nodes: make_dummy_candidate_nodes(timestamp),
})
.collect();
let pool_commitments = candidate_pools
.iter()
.map(MerklePaymentCandidatePool::to_commitment)
.collect();
PreparedMerkleBatch {
depth: tree.depth(),
pool_commitments,
merkle_payment_timestamp: timestamp,
candidate_pools,
tree,
addresses: addrs,
}
}
#[test]
fn test_finalize_merkle_batch_with_valid_winner() {
let prepared = make_prepared_merkle_batch(4);
let winner_hash = prepared.candidate_pools[0].hash();
let result = finalize_merkle_batch(prepared, winner_hash);
assert!(
result.is_ok(),
"should succeed with valid winner: {result:?}"
);
let batch = result.unwrap();
assert_eq!(batch.chunk_count, 4);
assert_eq!(batch.proofs.len(), 4);
for proof_bytes in batch.proofs.values() {
assert!(!proof_bytes.is_empty());
}
}
#[test]
fn test_finalize_merkle_batch_with_invalid_winner() {
let prepared = make_prepared_merkle_batch(4);
let bad_hash = [0xFF; 32];
let result = finalize_merkle_batch(prepared, bad_hash);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("not found in candidate pools"), "got: {err}");
}
#[test]
fn test_finalize_merkle_batch_proofs_are_deserializable() {
use ant_protocol::payment::deserialize_merkle_proof;
let prepared = make_prepared_merkle_batch(8);
let winner_hash = prepared.candidate_pools[0].hash();
let batch = finalize_merkle_batch(prepared, winner_hash).unwrap();
for (addr, proof_bytes) in &batch.proofs {
let proof = deserialize_merkle_proof(proof_bytes);
assert!(
proof.is_ok(),
"proof for {} should deserialize: {:?}",
hex::encode(addr),
proof.err()
);
}
}
#[test]
fn test_batch_split_calculation() {
let addrs = make_test_addresses(MAX_LEAVES);
assert_eq!(addrs.chunks(MAX_LEAVES).count(), 1);
let addrs = make_test_addresses(MAX_LEAVES + 1);
assert_eq!(addrs.chunks(MAX_LEAVES).count(), 2);
let addrs = make_test_addresses(3 * MAX_LEAVES);
assert_eq!(addrs.chunks(MAX_LEAVES).count(), 3);
}
use std::sync::{Arc, Mutex};
fn make_chunks(count: usize) -> Vec<([u8; 32], Bytes)> {
make_test_addresses(count)
.into_iter()
.map(|addr| (addr, Bytes::from_static(b"chunk")))
.collect()
}
#[tokio::test]
async fn store_with_retry_collects_failures_instead_of_aborting() {
let chunks = make_chunks(6);
let failing: std::collections::HashSet<[u8; 32]> =
chunks.iter().take(2).map(|(a, _)| *a).collect();
let failing_for_closure = failing.clone();
let store_one = move |addr: [u8; 32], _content: Bytes| {
let fail = failing_for_closure.contains(&addr);
async move {
if fail {
Err(Error::InsufficientPeers("test shortfall".into()))
} else {
Ok(std::time::Instant::now())
}
}
};
let outcome = merkle_store_with_retry(chunks, 8, 1, Duration::ZERO, None, 0, 6, store_one)
.await
.expect("quorum shortfalls must not abort the batch");
assert_eq!(outcome.stored, 4);
assert_eq!(outcome.failed, 2);
assert_eq!(outcome.stats.retries_histogram[0], 4);
assert_eq!(outcome.stats.chunk_attempts_total, 6);
}
#[tokio::test]
async fn store_with_retry_propagates_non_quorum_errors() {
let chunks = make_chunks(3);
let store_one = |_addr: [u8; 32], _content: Bytes| async move {
Err::<std::time::Instant, _>(Error::Payment("missing proof".into()))
};
let result =
merkle_store_with_retry(chunks, 8, 3, Duration::ZERO, None, 0, 3, store_one).await;
assert!(matches!(result, Err(Error::Payment(_))));
}
#[tokio::test]
async fn store_with_retry_retries_only_the_failed_set() {
let chunks = make_chunks(5);
let total = chunks.len();
let failing: std::collections::HashSet<[u8; 32]> =
chunks.iter().take(2).map(|(a, _)| *a).collect();
let failing_for_closure = failing.clone();
let calls = Arc::new(Mutex::new(Vec::<[u8; 32]>::new()));
let calls_for_closure = calls.clone();
let store_one = move |addr: [u8; 32], _content: Bytes| {
let calls = calls_for_closure.clone();
let already_seen = calls.lock().unwrap().iter().filter(|&&a| a == addr).count();
let fail = failing_for_closure.contains(&addr) && already_seen == 0;
calls.lock().unwrap().push(addr);
async move {
if fail {
Err(Error::InsufficientPeers("round-1 shortfall".into()))
} else {
Ok(std::time::Instant::now())
}
}
};
let outcome =
merkle_store_with_retry(chunks, 8, 3, Duration::ZERO, None, 0, total, store_one)
.await
.expect("should converge after retry");
assert_eq!(outcome.stored, total);
assert_eq!(outcome.failed, 0);
let calls = calls.lock().unwrap();
assert_eq!(calls.len(), total + failing.len());
let round_two: std::collections::HashSet<[u8; 32]> =
calls[total..].iter().copied().collect();
assert_eq!(round_two, failing);
}
#[tokio::test]
async fn store_with_retry_counts_retry_success_once_in_histogram() {
let chunks = make_chunks(4);
let total = chunks.len();
let flaky_addr = chunks[0].0;
let attempts = Arc::new(Mutex::new(HashMap::<[u8; 32], usize>::new()));
let attempts_for_closure = attempts.clone();
let store_one = move |addr: [u8; 32], _content: Bytes| {
let attempts = attempts_for_closure.clone();
let n = {
let mut m = attempts.lock().unwrap();
let entry = m.entry(addr).or_insert(0);
*entry += 1;
*entry
};
let fail = addr == flaky_addr && n == 1;
async move {
if fail {
Err(Error::InsufficientPeers("transient".into()))
} else {
Ok(std::time::Instant::now())
}
}
};
let outcome =
merkle_store_with_retry(chunks, 8, 3, Duration::ZERO, None, 0, total, store_one)
.await
.expect("flaky chunk should recover on retry");
assert_eq!(outcome.stored, total);
assert_eq!(outcome.failed, 0);
assert_eq!(outcome.stats.retries_histogram[0], total - 1);
assert_eq!(outcome.stats.retries_histogram[1], 1);
assert_eq!(outcome.stats.chunk_attempts_total, total + 1);
}
#[tokio::test]
async fn store_with_retry_reports_all_failed_when_retries_exhausted() {
let chunks = make_chunks(3);
let total = chunks.len();
let store_one = |_addr: [u8; 32], _content: Bytes| async move {
Err::<std::time::Instant, _>(Error::InsufficientPeers("never converges".into()))
};
let outcome = merkle_store_with_retry(
chunks,
8,
MERKLE_STORE_MAX_ATTEMPTS,
Duration::ZERO,
None,
0,
total,
store_one,
)
.await
.expect("an exhausted retry budget is reported, not propagated as Err");
assert_eq!(outcome.stored, 0);
assert_eq!(outcome.failed, total);
assert_eq!(
outcome.stats.chunk_attempts_total,
total * MERKLE_STORE_MAX_ATTEMPTS
);
assert_eq!(outcome.stats.retries_histogram, [0; 4]);
}
}