Skip to main content

ant_core/data/client/
merkle.rs

1//! Merkle batch payment support for the Autonomi client.
2//!
3//! When uploading batches of 64+ chunks, merkle payments reduce gas costs
4//! by paying for the entire batch in a single on-chain transaction instead
5//! of one transaction per chunk.
6
7use crate::data::client::adaptive::observe_op;
8use crate::data::client::classify_error;
9use crate::data::client::file::UploadEvent;
10use crate::data::client::Client;
11use crate::data::error::{Error, Result};
12use ant_protocol::evm::{
13    Amount, MerklePaymentCandidateNode, MerklePaymentCandidatePool, MerklePaymentProof, MerkleTree,
14    MidpointProof, PoolCommitment, CANDIDATES_PER_POOL, MAX_LEAVES,
15};
16use ant_protocol::payment::{serialize_merkle_proof, verify_merkle_candidate_signature};
17use ant_protocol::transport::PeerId;
18use ant_protocol::{
19    send_and_await_chunk_response, ChunkMessage, ChunkMessageBody, MerkleCandidateQuoteRequest,
20    MerkleCandidateQuoteResponse,
21};
22use bytes::Bytes;
23use futures::stream::{self, FuturesUnordered, StreamExt};
24use std::collections::HashMap;
25use std::time::Duration;
26use tokio::sync::mpsc;
27use tracing::{debug, info, warn};
28use xor_name::XorName;
29
/// Chunk count at which [`PaymentMode::Auto`] switches to merkle payments.
///
/// In `Auto` mode, batches with at least this many chunks are paid with a
/// single merkle transaction; smaller batches use single-node payment.
pub const DEFAULT_MERKLE_THRESHOLD: usize = 64;
32
33/// Payment mode for uploads.
34#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
35#[serde(rename_all = "snake_case")]
36pub enum PaymentMode {
37    /// Automatically choose: merkle for batches >= threshold, single otherwise.
38    #[default]
39    Auto,
40    /// Force merkle batch payment regardless of batch size (min 2 chunks).
41    Merkle,
42    /// Force single-node payment (one tx per chunk).
43    Single,
44}
45
46/// Result of a merkle batch payment.
47///
48/// Serializable so it can be persisted across runs for resume after a
49/// partial-upload failure. See `crate::data::client::cached_merkle`.
50#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
51pub struct MerkleBatchPaymentResult {
52    /// Map of `XorName` to serialized tagged proof bytes (ready to use in PUT requests).
53    pub proofs: HashMap<[u8; 32], Vec<u8>>,
54    /// Number of chunks in the batch.
55    pub chunk_count: usize,
56    /// Total storage cost in atto (token smallest unit).
57    pub storage_cost_atto: String,
58    /// Total gas cost in wei.
59    pub gas_cost_wei: u128,
60    /// Unix timestamp (seconds) used for the on-chain merkle payment.
61    /// Persisted so resume can check whether the on-chain payment has
62    /// aged out beyond the merkle expiration window and the cached
63    /// receipt must be discarded.
64    #[serde(default)]
65    pub merkle_payment_timestamp: u64,
66}
67
68/// Prepared merkle batch ready for external payment.
69///
70/// Contains everything needed to submit the on-chain merkle payment
71/// and then finalize proof generation without a wallet.
72pub struct PreparedMerkleBatch {
73    /// Merkle tree depth (needed for the on-chain call).
74    pub depth: u8,
75    /// Pool commitments for the on-chain call.
76    pub pool_commitments: Vec<PoolCommitment>,
77    /// Timestamp used for the merkle payment.
78    pub merkle_payment_timestamp: u64,
79    /// Internal: candidate pools (needed for proof generation after payment).
80    candidate_pools: Vec<MerklePaymentCandidatePool>,
81    /// Internal: the merkle tree (needed for proof generation).
82    tree: MerkleTree,
83    /// Internal: chunk addresses in order.
84    addresses: Vec<[u8; 32]>,
85}
86
87impl std::fmt::Debug for PreparedMerkleBatch {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("PreparedMerkleBatch")
90            .field("depth", &self.depth)
91            .field("pool_commitments", &self.pool_commitments.len())
92            .field("merkle_payment_timestamp", &self.merkle_payment_timestamp)
93            .field("candidate_pools", &self.candidate_pools.len())
94            .field("addresses", &self.addresses.len())
95            .finish()
96    }
97}
98
99/// Determine whether to use merkle payments for a given batch size.
100/// Free function — no Client needed.
101#[must_use]
102pub fn should_use_merkle(chunk_count: usize, mode: PaymentMode) -> bool {
103    match mode {
104        PaymentMode::Auto => chunk_count >= DEFAULT_MERKLE_THRESHOLD,
105        PaymentMode::Merkle => chunk_count >= 2,
106        PaymentMode::Single => false,
107    }
108}
109
110impl Client {
111    /// Determine whether to use merkle payments for a given batch size.
112    #[must_use]
113    pub fn should_use_merkle(&self, chunk_count: usize, mode: PaymentMode) -> bool {
114        should_use_merkle(chunk_count, mode)
115    }
116
117    /// Pay for a batch of chunks using merkle batch payment.
118    ///
119    /// Builds a merkle tree, collects candidate pools, pays on-chain in one tx,
120    /// and returns per-chunk proofs. Splits into sub-batches if > `MAX_LEAVES`.
121    ///
122    /// Does NOT pre-filter already-stored chunks (nodes handle `AlreadyExists`
123    /// gracefully on PUT). This avoids N sequential GET round-trips before payment.
124    ///
125    /// # Errors
126    ///
127    /// Returns an error if the batch is too small, candidate collection fails,
128    /// on-chain payment fails, or proof generation fails.
129    pub async fn pay_for_merkle_batch(
130        &self,
131        addresses: &[[u8; 32]],
132        data_type: u32,
133        data_size: u64,
134    ) -> Result<MerkleBatchPaymentResult> {
135        let chunk_count = addresses.len();
136        if chunk_count < 2 {
137            return Err(Error::Payment(
138                "Merkle batch payment requires at least 2 chunks".to_string(),
139            ));
140        }
141
142        if chunk_count > MAX_LEAVES {
143            return self
144                .pay_for_merkle_multi_batch(addresses, data_type, data_size)
145                .await;
146        }
147
148        self.pay_for_merkle_single_batch(addresses, data_type, data_size)
149            .await
150    }
151
152    /// Phase 1 of external-signer merkle payment: prepare batch without paying.
153    ///
154    /// Builds the merkle tree, collects candidate pools from the network,
155    /// and returns the data needed for the on-chain payment call.
156    /// Requires `EvmNetwork` but NOT a wallet.
157    pub async fn prepare_merkle_batch_external(
158        &self,
159        addresses: &[[u8; 32]],
160        data_type: u32,
161        data_size: u64,
162    ) -> Result<PreparedMerkleBatch> {
163        let chunk_count = addresses.len();
164        let xornames: Vec<XorName> = addresses.iter().map(|a| XorName(*a)).collect();
165
166        debug!("Building merkle tree for {chunk_count} chunks");
167
168        // 1. Build merkle tree
169        let tree = MerkleTree::from_xornames(xornames)
170            .map_err(|e| Error::Payment(format!("Failed to build merkle tree: {e}")))?;
171
172        let depth = tree.depth();
173        let merkle_payment_timestamp = std::time::SystemTime::now()
174            .duration_since(std::time::UNIX_EPOCH)
175            .map_err(|e| Error::Payment(format!("System time error: {e}")))?
176            .as_secs();
177
178        debug!("Merkle tree: depth={depth}, leaves={chunk_count}, ts={merkle_payment_timestamp}");
179
180        // 2. Get reward candidates (midpoint proofs)
181        let midpoint_proofs = tree
182            .reward_candidates(merkle_payment_timestamp)
183            .map_err(|e| Error::Payment(format!("Failed to generate reward candidates: {e}")))?;
184
185        debug!(
186            "Collecting candidate pools from {} midpoints (concurrent)",
187            midpoint_proofs.len()
188        );
189
190        // 3. Collect candidate pools from the network (all pools in parallel)
191        let candidate_pools = self
192            .build_candidate_pools(
193                &midpoint_proofs,
194                data_type,
195                data_size,
196                merkle_payment_timestamp,
197            )
198            .await?;
199
200        // 4. Build pool commitments for on-chain payment
201        let pool_commitments: Vec<PoolCommitment> = candidate_pools
202            .iter()
203            .map(MerklePaymentCandidatePool::to_commitment)
204            .collect();
205
206        Ok(PreparedMerkleBatch {
207            depth,
208            pool_commitments,
209            merkle_payment_timestamp,
210            candidate_pools,
211            tree,
212            addresses: addresses.to_vec(),
213        })
214    }
215
216    /// Pay for a single batch (up to `MAX_LEAVES` chunks).
217    async fn pay_for_merkle_single_batch(
218        &self,
219        addresses: &[[u8; 32]],
220        data_type: u32,
221        data_size: u64,
222    ) -> Result<MerkleBatchPaymentResult> {
223        let wallet = self.require_wallet()?;
224        let prepared = self
225            .prepare_merkle_batch_external(addresses, data_type, data_size)
226            .await?;
227
228        info!(
229            "Submitting merkle batch payment on-chain (depth={})",
230            prepared.depth
231        );
232        let (winner_pool_hash, amount, gas_info) = wallet
233            .pay_for_merkle_tree(
234                prepared.depth,
235                prepared.pool_commitments.clone(),
236                prepared.merkle_payment_timestamp,
237            )
238            .await
239            .map_err(|e| Error::Payment(format!("Merkle batch payment failed: {e}")))?;
240
241        info!(
242            "Merkle payment succeeded: winner pool {}",
243            hex::encode(winner_pool_hash)
244        );
245
246        let mut result = finalize_merkle_batch(prepared, winner_pool_hash)?;
247        result.storage_cost_atto = amount.to_string();
248        result.gas_cost_wei = gas_info.gas_cost_wei;
249        Ok(result)
250    }
251
252    /// Handle batches larger than `MAX_LEAVES` by splitting into sub-batches.
253    async fn pay_for_merkle_multi_batch(
254        &self,
255        addresses: &[[u8; 32]],
256        data_type: u32,
257        data_size: u64,
258    ) -> Result<MerkleBatchPaymentResult> {
259        let sub_batches: Vec<&[[u8; 32]]> = addresses.chunks(MAX_LEAVES).collect();
260        let total_sub_batches = sub_batches.len();
261        let mut all_proofs = HashMap::with_capacity(addresses.len());
262        let mut total_storage = Amount::ZERO;
263        let mut total_gas: u128 = 0;
264        // Track the oldest sub-batch timestamp so the overall receipt
265        // expires when the *first* sub-batch's on-chain payment ages
266        // out (worst case for resume).
267        let mut oldest_ts: u64 = 0;
268
269        for (i, chunk) in sub_batches.into_iter().enumerate() {
270            match self
271                .pay_for_merkle_single_batch(chunk, data_type, data_size)
272                .await
273            {
274                Ok(sub_result) => {
275                    if let Ok(cost) = sub_result.storage_cost_atto.parse::<Amount>() {
276                        total_storage += cost;
277                    }
278                    total_gas = total_gas.saturating_add(sub_result.gas_cost_wei);
279                    if oldest_ts == 0
280                        || (sub_result.merkle_payment_timestamp > 0
281                            && sub_result.merkle_payment_timestamp < oldest_ts)
282                    {
283                        oldest_ts = sub_result.merkle_payment_timestamp;
284                    }
285                    all_proofs.extend(sub_result.proofs);
286                }
287                Err(e) => {
288                    if all_proofs.is_empty() {
289                        // First sub-batch failed, nothing paid yet -- propagate directly.
290                        return Err(e);
291                    }
292                    // Return partial result so caller can still store already-paid chunks.
293                    warn!(
294                        "Merkle sub-batch {}/{total_sub_batches} failed: {e}. \
295                         Returning {} proofs from prior sub-batches",
296                        i + 1,
297                        all_proofs.len()
298                    );
299                    return Ok(MerkleBatchPaymentResult {
300                        chunk_count: all_proofs.len(),
301                        proofs: all_proofs,
302                        storage_cost_atto: total_storage.to_string(),
303                        gas_cost_wei: total_gas,
304                        merkle_payment_timestamp: oldest_ts,
305                    });
306                }
307            }
308        }
309
310        Ok(MerkleBatchPaymentResult {
311            chunk_count: addresses.len(),
312            proofs: all_proofs,
313            storage_cost_atto: total_storage.to_string(),
314            gas_cost_wei: total_gas,
315            merkle_payment_timestamp: oldest_ts,
316        })
317    }
318
319    /// Build candidate pools by querying the network for each midpoint (concurrently).
320    async fn build_candidate_pools(
321        &self,
322        midpoint_proofs: &[MidpointProof],
323        data_type: u32,
324        data_size: u64,
325        merkle_payment_timestamp: u64,
326    ) -> Result<Vec<MerklePaymentCandidatePool>> {
327        let mut pool_futures = FuturesUnordered::new();
328
329        for midpoint_proof in midpoint_proofs {
330            let pool_address = midpoint_proof.address();
331            let mp = midpoint_proof.clone();
332            pool_futures.push(async move {
333                let candidate_nodes = self
334                    .get_merkle_candidate_pool(
335                        &pool_address.0,
336                        data_type,
337                        data_size,
338                        merkle_payment_timestamp,
339                    )
340                    .await?;
341                Ok::<_, Error>(MerklePaymentCandidatePool {
342                    midpoint_proof: mp,
343                    candidate_nodes,
344                })
345            });
346        }
347
348        let mut pools = Vec::with_capacity(midpoint_proofs.len());
349        while let Some(result) = pool_futures.next().await {
350            pools.push(result?);
351        }
352
353        Ok(pools)
354    }
355
356    /// Collect `CANDIDATES_PER_POOL` (16) merkle candidate quotes from the network.
357    #[allow(clippy::too_many_lines)]
358    async fn get_merkle_candidate_pool(
359        &self,
360        address: &[u8; 32],
361        data_type: u32,
362        data_size: u64,
363        merkle_payment_timestamp: u64,
364    ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
365        let node = self.network().node();
366        let timeout = Duration::from_secs(self.config().quote_timeout_secs);
367
368        // Query extra peers to handle validation failures (bad sigs, wrong type, etc.)
369        let query_count = CANDIDATES_PER_POOL * 2;
370        let mut remote_peers = self
371            .network()
372            .find_closest_peers(address, query_count)
373            .await?;
374
375        // If DHT closest-nodes didn't return enough, supplement with connected peers.
376        // On small networks the DHT iterative lookup may not discover enough peers
377        // close to a random pool address, but we know more peers via direct connections.
378        if remote_peers.len() < CANDIDATES_PER_POOL {
379            let connected = self.network().connected_peers().await;
380            for peer in connected {
381                if !remote_peers.iter().any(|(id, _)| *id == peer) {
382                    remote_peers.push((peer, vec![]));
383                }
384            }
385        }
386
387        if remote_peers.len() < CANDIDATES_PER_POOL {
388            return Err(Error::InsufficientPeers(format!(
389                "Found {} peers, need {CANDIDATES_PER_POOL} for merkle candidate pool. \
390                 Use --no-merkle or a larger network.",
391                remote_peers.len()
392            )));
393        }
394
395        let mut candidate_futures = FuturesUnordered::new();
396
397        for (peer_id, peer_addrs) in &remote_peers {
398            let request_id = self.next_request_id();
399            let request = MerkleCandidateQuoteRequest {
400                address: *address,
401                data_type,
402                data_size,
403                merkle_payment_timestamp,
404            };
405            let message = ChunkMessage {
406                request_id,
407                body: ChunkMessageBody::MerkleCandidateQuoteRequest(request),
408            };
409
410            let message_bytes = match message.encode() {
411                Ok(bytes) => bytes,
412                Err(e) => {
413                    warn!("Failed to encode merkle candidate request for {peer_id}: {e}");
414                    continue;
415                }
416            };
417
418            let peer_id_clone = *peer_id;
419            let addrs_clone = peer_addrs.clone();
420            let node_clone = node.clone();
421
422            let fut = async move {
423                let result = send_and_await_chunk_response(
424                    &node_clone,
425                    &peer_id_clone,
426                    message_bytes,
427                    request_id,
428                    timeout,
429                    &addrs_clone,
430                    |body| match body {
431                        ChunkMessageBody::MerkleCandidateQuoteResponse(
432                            MerkleCandidateQuoteResponse::Success { candidate_node },
433                        ) => {
434                            match rmp_serde::from_slice::<MerklePaymentCandidateNode>(
435                                &candidate_node,
436                            ) {
437                                Ok(node) => Some(Ok(node)),
438                                Err(e) => Some(Err(Error::Serialization(format!(
439                                    "Failed to deserialize candidate node from {peer_id_clone}: {e}"
440                                )))),
441                            }
442                        }
443                        ChunkMessageBody::MerkleCandidateQuoteResponse(
444                            MerkleCandidateQuoteResponse::Error(e),
445                        ) => Some(Err(Error::Protocol(format!(
446                            "Merkle quote error from {peer_id_clone}: {e}"
447                        )))),
448                        _ => None,
449                    },
450                    |e| {
451                        Error::Network(format!(
452                            "Failed to send merkle candidate request to {peer_id_clone}: {e}"
453                        ))
454                    },
455                    || {
456                        Error::Timeout(format!(
457                            "Timeout waiting for merkle candidate from {peer_id_clone}"
458                        ))
459                    },
460                )
461                .await;
462
463                (peer_id_clone, result)
464            };
465
466            candidate_futures.push(fut);
467        }
468
469        self.collect_validated_candidates(&mut candidate_futures, address, merkle_payment_timestamp)
470            .await
471    }
472
473    /// Collect and validate merkle candidate responses, then return the
474    /// `CANDIDATES_PER_POOL` valid responders that are XOR-closest to the
475    /// pool midpoint.
476    ///
477    /// Why distance-sort instead of "first N to respond":
478    /// the storing-node verifier re-runs a network closest-peers lookup of
479    /// the pool midpoint and rejects the pool if fewer than 13 of the 16
480    /// candidate `pub_keys` appear in that authoritative close-set. Pools
481    /// built from the fastest-to-respond quoters fail this check whenever
482    /// truly-close peers are slower (NAT/relay paths) than farther peers.
483    async fn collect_validated_candidates(
484        &self,
485        futures: &mut FuturesUnordered<
486            impl std::future::Future<
487                Output = (
488                    PeerId,
489                    std::result::Result<MerklePaymentCandidateNode, Error>,
490                ),
491            >,
492        >,
493        target_address: &[u8; 32],
494        merkle_payment_timestamp: u64,
495    ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
496        let mut valid: Vec<(PeerId, MerklePaymentCandidateNode)> = Vec::new();
497        let mut failures: Vec<String> = Vec::new();
498
499        while let Some((peer_id, result)) = futures.next().await {
500            match result {
501                Ok(candidate) => {
502                    if !verify_merkle_candidate_signature(&candidate) {
503                        warn!("Invalid ML-DSA-65 signature from merkle candidate {peer_id}");
504                        failures.push(format!("{peer_id}: invalid signature"));
505                        continue;
506                    }
507                    if candidate.merkle_payment_timestamp != merkle_payment_timestamp {
508                        warn!("Timestamp mismatch from merkle candidate {peer_id}");
509                        failures.push(format!("{peer_id}: timestamp mismatch"));
510                        continue;
511                    }
512                    valid.push((peer_id, candidate));
513                }
514                Err(e) => {
515                    debug!("Failed to get merkle candidate from {peer_id}: {e}");
516                    failures.push(format!("{peer_id}: {e}"));
517                }
518            }
519        }
520
521        if valid.len() < CANDIDATES_PER_POOL {
522            return Err(Error::InsufficientPeers(format!(
523                "Got {} merkle candidates, need {CANDIDATES_PER_POOL}. Failures: [{}]",
524                valid.len(),
525                failures.join("; ")
526            )));
527        }
528
529        let target_peer = PeerId::from_bytes(*target_address);
530        valid.sort_by_key(|(peer_id, _)| peer_id.xor_distance(&target_peer));
531
532        let candidates: Vec<MerklePaymentCandidateNode> = valid
533            .into_iter()
534            .take(CANDIDATES_PER_POOL)
535            .map(|(_, candidate)| candidate)
536            .collect();
537
538        candidates
539            .try_into()
540            .map_err(|_| Error::Payment("Failed to convert candidates to fixed array".to_string()))
541    }
542
543    /// Upload chunks using pre-computed merkle proofs from a batch payment.
544    ///
545    /// Each chunk is matched to its proof from `batch_result.proofs`,
546    /// then stored to its close group concurrently. Returns the number
547    /// of chunks successfully stored.
548    ///
549    /// # Errors
550    ///
551    /// Returns an error if any chunk is missing its proof or storage fails.
552    pub(crate) async fn merkle_upload_chunks(
553        &self,
554        chunk_contents: Vec<Bytes>,
555        addresses: Vec<[u8; 32]>,
556        batch_result: &MerkleBatchPaymentResult,
557        progress: Option<&mpsc::Sender<UploadEvent>>,
558    ) -> Result<(usize, crate::data::client::batch::WaveAggregateStats)> {
559        let mut stored = 0usize;
560        let mut stats = crate::data::client::batch::WaveAggregateStats::default();
561        let store_limiter = self.controller().store.clone();
562        // Clamp fan-out to batch size — partial batches should not
563        // pay for unused slots (see PERF-RESULTS.md).
564        let batch_size = chunk_contents.len();
565        let store_concurrency = store_limiter.current().min(batch_size.max(1));
566        let mut upload_stream = stream::iter(chunk_contents.into_iter().zip(addresses).map(
567            |(content, addr)| {
568                let proof_bytes = batch_result.proofs.get(&addr).cloned();
569                let limiter = store_limiter.clone();
570                async move {
571                    let started = std::time::Instant::now();
572                    let proof = proof_bytes.ok_or_else(|| {
573                        Error::Payment(format!(
574                            "Missing merkle proof for chunk {}",
575                            hex::encode(addr)
576                        ))
577                    })?;
578                    let peers = self.close_group_peers(&addr).await?;
579                    observe_op(
580                        &limiter,
581                        || async move {
582                            self.chunk_put_to_close_group(content, proof, &peers).await
583                        },
584                        classify_error,
585                    )
586                    .await
587                    .map(|_| started)
588                }
589            },
590        ))
591        .buffer_unordered(store_concurrency);
592
593        while let Some(result) = upload_stream.next().await {
594            let started = result?;
595            let duration_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
596            stats.store_durations_ms.push(duration_ms);
597            stats.chunk_attempts_total = stats.chunk_attempts_total.saturating_add(1);
598            stats.retries_histogram[0] = stats.retries_histogram[0].saturating_add(1);
599            stored += 1;
600            if let Some(tx) = progress {
601                let _ = tx.try_send(UploadEvent::ChunkStored {
602                    stored,
603                    total: batch_size,
604                });
605            }
606        }
607
608        Ok((stored, stats))
609    }
610}
611
612/// Phase 2 of external-signer merkle payment: generate proofs from winner.
613///
614/// Takes the prepared batch and the winner pool hash returned by the
615/// on-chain payment transaction. Generates per-chunk merkle proofs.
616pub fn finalize_merkle_batch(
617    prepared: PreparedMerkleBatch,
618    winner_pool_hash: [u8; 32],
619) -> Result<MerkleBatchPaymentResult> {
620    let chunk_count = prepared.addresses.len();
621    let xornames: Vec<XorName> = prepared.addresses.iter().map(|a| XorName(*a)).collect();
622
623    // Find the winner pool
624    let winner_pool = prepared
625        .candidate_pools
626        .iter()
627        .find(|pool| pool.hash() == winner_pool_hash)
628        .ok_or_else(|| {
629            Error::Payment(format!(
630                "Winner pool {} not found in candidate pools",
631                hex::encode(winner_pool_hash)
632            ))
633        })?;
634
635    // Generate proofs for each chunk
636    info!("Generating merkle proofs for {chunk_count} chunks");
637    let mut proofs = HashMap::with_capacity(chunk_count);
638
639    for (i, xorname) in xornames.iter().enumerate() {
640        let address_proof = prepared
641            .tree
642            .generate_address_proof(i, *xorname)
643            .map_err(|e| {
644                Error::Payment(format!(
645                    "Failed to generate address proof for chunk {i}: {e}"
646                ))
647            })?;
648
649        let merkle_proof = MerklePaymentProof::new(*xorname, address_proof, winner_pool.clone());
650
651        let tagged_bytes = serialize_merkle_proof(&merkle_proof)
652            .map_err(|e| Error::Serialization(format!("Failed to serialize merkle proof: {e}")))?;
653
654        proofs.insert(prepared.addresses[i], tagged_bytes);
655    }
656
657    info!("Merkle batch payment complete: {chunk_count} proofs generated");
658
659    Ok(MerkleBatchPaymentResult {
660        proofs,
661        chunk_count,
662        storage_cost_atto: "0".to_string(),
663        gas_cost_wei: 0,
664        merkle_payment_timestamp: prepared.merkle_payment_timestamp,
665    })
666}
667
/// Compile-time checks that the futures returned by merkle methods are `Send`.
///
/// Nothing in this module is ever called. The functions only have to
/// type-check for the assertion to hold.
#[cfg(test)]
mod send_assertions {
    use super::*;
    use crate::data::client::Client;

    /// Compiles only when `T: Send`; the argument itself is never used.
    fn _assert_send<T: Send>(_: &T) {}

    // `todo!()` stands in for a real batch result. The body is only
    // type-checked, never run, so the diverging expression and the code after
    // it are intentional. That is why these lints are allowed.
    #[allow(
        dead_code,
        unreachable_code,
        unused_variables,
        clippy::diverging_sub_expression
    )]
    async fn _merkle_upload_chunks_is_send(client: &Client) {
        let batch_result: MerkleBatchPaymentResult = todo!();
        let upload_future =
            client.merkle_upload_chunks(Vec::new(), Vec::new(), &batch_result, None);
        _assert_send(&upload_future);
    }
}
688
689#[cfg(test)]
690#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
691mod tests {
692    use super::*;
693    use ant_protocol::evm::{Amount, MerkleTree, RewardsAddress, CANDIDATES_PER_POOL};
694
695    // =========================================================================
696    // should_use_merkle (free function, no Client needed)
697    // =========================================================================
698
699    #[test]
700    fn test_auto_below_threshold() {
701        assert!(!should_use_merkle(1, PaymentMode::Auto));
702        assert!(!should_use_merkle(10, PaymentMode::Auto));
703        assert!(!should_use_merkle(63, PaymentMode::Auto));
704    }
705
706    #[test]
707    fn test_auto_at_and_above_threshold() {
708        assert!(should_use_merkle(64, PaymentMode::Auto));
709        assert!(should_use_merkle(65, PaymentMode::Auto));
710        assert!(should_use_merkle(1000, PaymentMode::Auto));
711    }
712
713    #[test]
714    fn test_merkle_mode_forces_at_2() {
715        assert!(!should_use_merkle(1, PaymentMode::Merkle));
716        assert!(should_use_merkle(2, PaymentMode::Merkle));
717        assert!(should_use_merkle(3, PaymentMode::Merkle));
718    }
719
720    #[test]
721    fn test_single_mode_always_false() {
722        assert!(!should_use_merkle(0, PaymentMode::Single));
723        assert!(!should_use_merkle(64, PaymentMode::Single));
724        assert!(!should_use_merkle(1000, PaymentMode::Single));
725    }
726
727    #[test]
728    fn test_default_mode_is_auto() {
729        assert_eq!(PaymentMode::default(), PaymentMode::Auto);
730    }
731
732    #[test]
733    fn test_threshold_value() {
734        assert_eq!(DEFAULT_MERKLE_THRESHOLD, 64);
735    }
736
737    // =========================================================================
738    // MerkleTree construction and proof generation (pure, no network)
739    // =========================================================================
740
741    fn make_test_addresses(count: usize) -> Vec<[u8; 32]> {
742        (0..count)
743            .map(|i| {
744                let xn = XorName::from_content(&i.to_le_bytes());
745                xn.0
746            })
747            .collect()
748    }
749
750    #[test]
751    fn test_tree_depth_for_known_sizes() {
752        let cases = [(2, 1), (4, 2), (16, 4), (100, 7), (256, 8)];
753        for (count, expected_depth) in cases {
754            let addrs = make_test_addresses(count);
755            let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
756            let tree = MerkleTree::from_xornames(xornames).unwrap();
757            assert_eq!(
758                tree.depth(),
759                expected_depth,
760                "depth mismatch for {count} leaves"
761            );
762        }
763    }
764
765    #[test]
766    fn test_proof_generation_and_verification_for_all_leaves() {
767        let addrs = make_test_addresses(16);
768        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
769        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();
770
771        for (i, xn) in xornames.iter().enumerate() {
772            let proof = tree.generate_address_proof(i, *xn).unwrap();
773            assert!(proof.verify(), "proof for leaf {i} should verify");
774            assert_eq!(proof.depth(), tree.depth() as usize);
775        }
776    }
777
778    #[test]
779    fn test_proof_fails_for_wrong_address() {
780        let addrs = make_test_addresses(8);
781        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
782        let tree = MerkleTree::from_xornames(xornames).unwrap();
783
784        let wrong = XorName::from_content(b"wrong");
785        let proof = tree.generate_address_proof(0, wrong).unwrap();
786        assert!(!proof.verify(), "proof with wrong address should fail");
787    }
788
789    #[test]
790    fn test_tree_too_few_leaves() {
791        let xornames = vec![XorName::from_content(b"only_one")];
792        let result = MerkleTree::from_xornames(xornames);
793        assert!(result.is_err());
794    }
795
796    #[test]
797    fn test_tree_at_max_leaves() {
798        let addrs = make_test_addresses(MAX_LEAVES);
799        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
800        let tree = MerkleTree::from_xornames(xornames).unwrap();
801        assert_eq!(tree.leaf_count(), MAX_LEAVES);
802    }
803
804    // =========================================================================
805    // Proof serialization round-trip
806    // =========================================================================
807
808    #[test]
809    fn test_merkle_proof_serialize_deserialize_roundtrip() {
810        use ant_protocol::evm::{Amount, MerklePaymentCandidateNode, RewardsAddress};
811        use ant_protocol::payment::{deserialize_merkle_proof, serialize_merkle_proof};
812
813        let addrs = make_test_addresses(4);
814        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
815        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();
816
817        let timestamp = std::time::SystemTime::now()
818            .duration_since(std::time::UNIX_EPOCH)
819            .unwrap()
820            .as_secs();
821
822        let candidates = tree.reward_candidates(timestamp).unwrap();
823        let midpoint = candidates.first().unwrap().clone();
824
825        // Build candidate nodes (with dummy signatures — not ML-DSA, just for serialization test)
826        #[allow(clippy::cast_possible_truncation)]
827        let candidate_nodes: [MerklePaymentCandidateNode; CANDIDATES_PER_POOL] =
828            std::array::from_fn(|i| MerklePaymentCandidateNode {
829                pub_key: vec![i as u8; 32],
830                price: Amount::from(1024u64),
831                reward_address: RewardsAddress::new([i as u8; 20]),
832                merkle_payment_timestamp: timestamp,
833                signature: vec![i as u8; 64],
834            });
835
836        let pool = MerklePaymentCandidatePool {
837            midpoint_proof: midpoint,
838            candidate_nodes,
839        };
840
841        let address_proof = tree.generate_address_proof(0, xornames[0]).unwrap();
842        let merkle_proof = MerklePaymentProof::new(xornames[0], address_proof, pool);
843
844        let tagged = serialize_merkle_proof(&merkle_proof).unwrap();
845        assert_eq!(
846            tagged.first().copied(),
847            Some(0x02),
848            "tag should be PROOF_TAG_MERKLE"
849        );
850
851        let deserialized = deserialize_merkle_proof(&tagged).unwrap();
852        assert_eq!(deserialized.address, merkle_proof.address);
853        assert_eq!(
854            deserialized.winner_pool.candidate_nodes.len(),
855            CANDIDATES_PER_POOL
856        );
857    }
858
859    // =========================================================================
860    // Candidate validation logic
861    // =========================================================================
862
863    #[test]
864    fn test_candidate_wrong_timestamp_rejected() {
865        // Simulates what collect_validated_candidates checks
866        let candidate = MerklePaymentCandidateNode {
867            pub_key: vec![0u8; 32],
868            price: ant_protocol::evm::Amount::ZERO,
869            reward_address: ant_protocol::evm::RewardsAddress::new([0u8; 20]),
870            merkle_payment_timestamp: 1000,
871            signature: vec![0u8; 64],
872        };
873
874        // Timestamp check: 1000 != 2000
875        assert_ne!(candidate.merkle_payment_timestamp, 2000);
876    }
877
878    // =========================================================================
879    // finalize_merkle_batch (external signer)
880    // =========================================================================
881
882    fn make_dummy_candidate_nodes(
883        timestamp: u64,
884    ) -> [MerklePaymentCandidateNode; CANDIDATES_PER_POOL] {
885        std::array::from_fn(|i| MerklePaymentCandidateNode {
886            pub_key: vec![i as u8; 32],
887            price: Amount::from(1024u64),
888            reward_address: RewardsAddress::new([i as u8; 20]),
889            merkle_payment_timestamp: timestamp,
890            signature: vec![i as u8; 64],
891        })
892    }
893
894    fn make_prepared_merkle_batch(count: usize) -> PreparedMerkleBatch {
895        let addrs = make_test_addresses(count);
896        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
897        let tree = MerkleTree::from_xornames(xornames).unwrap();
898
899        let timestamp = std::time::SystemTime::now()
900            .duration_since(std::time::UNIX_EPOCH)
901            .unwrap()
902            .as_secs();
903
904        let midpoints = tree.reward_candidates(timestamp).unwrap();
905
906        let candidate_pools: Vec<MerklePaymentCandidatePool> = midpoints
907            .into_iter()
908            .map(|mp| MerklePaymentCandidatePool {
909                midpoint_proof: mp,
910                candidate_nodes: make_dummy_candidate_nodes(timestamp),
911            })
912            .collect();
913
914        let pool_commitments = candidate_pools
915            .iter()
916            .map(MerklePaymentCandidatePool::to_commitment)
917            .collect();
918
919        PreparedMerkleBatch {
920            depth: tree.depth(),
921            pool_commitments,
922            merkle_payment_timestamp: timestamp,
923            candidate_pools,
924            tree,
925            addresses: addrs,
926        }
927    }
928
929    #[test]
930    fn test_finalize_merkle_batch_with_valid_winner() {
931        let prepared = make_prepared_merkle_batch(4);
932        let winner_hash = prepared.candidate_pools[0].hash();
933
934        let result = finalize_merkle_batch(prepared, winner_hash);
935        assert!(
936            result.is_ok(),
937            "should succeed with valid winner: {result:?}"
938        );
939
940        let batch = result.unwrap();
941        assert_eq!(batch.chunk_count, 4);
942        assert_eq!(batch.proofs.len(), 4);
943
944        // Every proof should be non-empty
945        for proof_bytes in batch.proofs.values() {
946            assert!(!proof_bytes.is_empty());
947        }
948    }
949
950    #[test]
951    fn test_finalize_merkle_batch_with_invalid_winner() {
952        let prepared = make_prepared_merkle_batch(4);
953        let bad_hash = [0xFF; 32];
954
955        let result = finalize_merkle_batch(prepared, bad_hash);
956        assert!(result.is_err());
957        let err = result.unwrap_err().to_string();
958        assert!(err.contains("not found in candidate pools"), "got: {err}");
959    }
960
961    #[test]
962    fn test_finalize_merkle_batch_proofs_are_deserializable() {
963        use ant_protocol::payment::deserialize_merkle_proof;
964
965        let prepared = make_prepared_merkle_batch(8);
966        let winner_hash = prepared.candidate_pools[0].hash();
967
968        let batch = finalize_merkle_batch(prepared, winner_hash).unwrap();
969
970        for (addr, proof_bytes) in &batch.proofs {
971            let proof = deserialize_merkle_proof(proof_bytes);
972            assert!(
973                proof.is_ok(),
974                "proof for {} should deserialize: {:?}",
975                hex::encode(addr),
976                proof.err()
977            );
978        }
979    }
980
981    // =========================================================================
982    // Batch splitting edge cases
983    // =========================================================================
984
985    #[test]
986    fn test_batch_split_calculation() {
987        // MAX_LEAVES chunks should fit in 1 batch
988        let addrs = make_test_addresses(MAX_LEAVES);
989        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 1);
990
991        // MAX_LEAVES + 1 should split into 2
992        let addrs = make_test_addresses(MAX_LEAVES + 1);
993        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 2);
994
995        // 3 * MAX_LEAVES should split into 3
996        let addrs = make_test_addresses(3 * MAX_LEAVES);
997        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 3);
998    }
999}