Skip to main content

ant_core/data/client/
merkle.rs

1//! Merkle batch payment support for the Autonomi client.
2//!
3//! When uploading batches of 64+ chunks, merkle payments reduce gas costs
4//! by paying for the entire batch in a single on-chain transaction instead
5//! of one transaction per chunk.
6
7use crate::data::client::adaptive::observe_op;
8use crate::data::client::classify_error;
9use crate::data::client::file::UploadEvent;
10use crate::data::client::Client;
11use crate::data::error::{Error, Result};
12use ant_protocol::evm::{
13    Amount, MerklePaymentCandidateNode, MerklePaymentCandidatePool, MerklePaymentProof, MerkleTree,
14    MidpointProof, PoolCommitment, CANDIDATES_PER_POOL, MAX_LEAVES,
15};
16use ant_protocol::payment::{serialize_merkle_proof, verify_merkle_candidate_signature};
17use ant_protocol::transport::PeerId;
18use ant_protocol::{
19    send_and_await_chunk_response, ChunkMessage, ChunkMessageBody, MerkleCandidateQuoteRequest,
20    MerkleCandidateQuoteResponse,
21};
22use bytes::Bytes;
23use futures::stream::{self, FuturesUnordered, StreamExt};
24use std::collections::HashMap;
25use std::time::Duration;
26use tokio::sync::mpsc;
27use tracing::{debug, info, warn};
28use xor_name::XorName;
29
/// Batch size at which `PaymentMode::Auto` switches to merkle payments.
///
/// Batches with at least this many chunks are paid with a single merkle
/// transaction; smaller batches fall back to one payment per chunk.
pub const DEFAULT_MERKLE_THRESHOLD: usize = 64;
32
33/// Payment mode for uploads.
34#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
35#[serde(rename_all = "snake_case")]
36pub enum PaymentMode {
37    /// Automatically choose: merkle for batches >= threshold, single otherwise.
38    #[default]
39    Auto,
40    /// Force merkle batch payment regardless of batch size (min 2 chunks).
41    Merkle,
42    /// Force single-node payment (one tx per chunk).
43    Single,
44}
45
/// Outcome of paying for a batch of chunks with merkle payments.
#[derive(Debug)]
pub struct MerkleBatchPaymentResult {
    /// Serialized, tagged merkle proof for each paid chunk, keyed by the
    /// chunk's 32-byte address; the bytes are sent unchanged in PUT requests.
    pub proofs: HashMap<[u8; 32], Vec<u8>>,
    /// Number of chunks covered by this result. This can be lower than the
    /// number requested if a multi-batch payment stopped part-way.
    pub chunk_count: usize,
    /// Total storage cost in atto (the token's smallest unit), as a string.
    pub storage_cost_atto: String,
    /// Total gas cost in wei.
    pub gas_cost_wei: u128,
}
58
59/// Prepared merkle batch ready for external payment.
60///
61/// Contains everything needed to submit the on-chain merkle payment
62/// and then finalize proof generation without a wallet.
63pub struct PreparedMerkleBatch {
64    /// Merkle tree depth (needed for the on-chain call).
65    pub depth: u8,
66    /// Pool commitments for the on-chain call.
67    pub pool_commitments: Vec<PoolCommitment>,
68    /// Timestamp used for the merkle payment.
69    pub merkle_payment_timestamp: u64,
70    /// Internal: candidate pools (needed for proof generation after payment).
71    candidate_pools: Vec<MerklePaymentCandidatePool>,
72    /// Internal: the merkle tree (needed for proof generation).
73    tree: MerkleTree,
74    /// Internal: chunk addresses in order.
75    addresses: Vec<[u8; 32]>,
76}
77
78impl std::fmt::Debug for PreparedMerkleBatch {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        f.debug_struct("PreparedMerkleBatch")
81            .field("depth", &self.depth)
82            .field("pool_commitments", &self.pool_commitments.len())
83            .field("merkle_payment_timestamp", &self.merkle_payment_timestamp)
84            .field("candidate_pools", &self.candidate_pools.len())
85            .field("addresses", &self.addresses.len())
86            .finish()
87    }
88}
89
90/// Determine whether to use merkle payments for a given batch size.
91/// Free function — no Client needed.
92#[must_use]
93pub fn should_use_merkle(chunk_count: usize, mode: PaymentMode) -> bool {
94    match mode {
95        PaymentMode::Auto => chunk_count >= DEFAULT_MERKLE_THRESHOLD,
96        PaymentMode::Merkle => chunk_count >= 2,
97        PaymentMode::Single => false,
98    }
99}
100
101impl Client {
102    /// Determine whether to use merkle payments for a given batch size.
103    #[must_use]
104    pub fn should_use_merkle(&self, chunk_count: usize, mode: PaymentMode) -> bool {
105        should_use_merkle(chunk_count, mode)
106    }
107
108    /// Pay for a batch of chunks using merkle batch payment.
109    ///
110    /// Builds a merkle tree, collects candidate pools, pays on-chain in one tx,
111    /// and returns per-chunk proofs. Splits into sub-batches if > `MAX_LEAVES`.
112    ///
113    /// Does NOT pre-filter already-stored chunks (nodes handle `AlreadyExists`
114    /// gracefully on PUT). This avoids N sequential GET round-trips before payment.
115    ///
116    /// # Errors
117    ///
118    /// Returns an error if the batch is too small, candidate collection fails,
119    /// on-chain payment fails, or proof generation fails.
120    pub async fn pay_for_merkle_batch(
121        &self,
122        addresses: &[[u8; 32]],
123        data_type: u32,
124        data_size: u64,
125    ) -> Result<MerkleBatchPaymentResult> {
126        let chunk_count = addresses.len();
127        if chunk_count < 2 {
128            return Err(Error::Payment(
129                "Merkle batch payment requires at least 2 chunks".to_string(),
130            ));
131        }
132
133        if chunk_count > MAX_LEAVES {
134            return self
135                .pay_for_merkle_multi_batch(addresses, data_type, data_size)
136                .await;
137        }
138
139        self.pay_for_merkle_single_batch(addresses, data_type, data_size)
140            .await
141    }
142
143    /// Phase 1 of external-signer merkle payment: prepare batch without paying.
144    ///
145    /// Builds the merkle tree, collects candidate pools from the network,
146    /// and returns the data needed for the on-chain payment call.
147    /// Requires `EvmNetwork` but NOT a wallet.
148    pub async fn prepare_merkle_batch_external(
149        &self,
150        addresses: &[[u8; 32]],
151        data_type: u32,
152        data_size: u64,
153    ) -> Result<PreparedMerkleBatch> {
154        let chunk_count = addresses.len();
155        let xornames: Vec<XorName> = addresses.iter().map(|a| XorName(*a)).collect();
156
157        debug!("Building merkle tree for {chunk_count} chunks");
158
159        // 1. Build merkle tree
160        let tree = MerkleTree::from_xornames(xornames)
161            .map_err(|e| Error::Payment(format!("Failed to build merkle tree: {e}")))?;
162
163        let depth = tree.depth();
164        let merkle_payment_timestamp = std::time::SystemTime::now()
165            .duration_since(std::time::UNIX_EPOCH)
166            .map_err(|e| Error::Payment(format!("System time error: {e}")))?
167            .as_secs();
168
169        debug!("Merkle tree: depth={depth}, leaves={chunk_count}, ts={merkle_payment_timestamp}");
170
171        // 2. Get reward candidates (midpoint proofs)
172        let midpoint_proofs = tree
173            .reward_candidates(merkle_payment_timestamp)
174            .map_err(|e| Error::Payment(format!("Failed to generate reward candidates: {e}")))?;
175
176        debug!(
177            "Collecting candidate pools from {} midpoints (concurrent)",
178            midpoint_proofs.len()
179        );
180
181        // 3. Collect candidate pools from the network (all pools in parallel)
182        let candidate_pools = self
183            .build_candidate_pools(
184                &midpoint_proofs,
185                data_type,
186                data_size,
187                merkle_payment_timestamp,
188            )
189            .await?;
190
191        // 4. Build pool commitments for on-chain payment
192        let pool_commitments: Vec<PoolCommitment> = candidate_pools
193            .iter()
194            .map(MerklePaymentCandidatePool::to_commitment)
195            .collect();
196
197        Ok(PreparedMerkleBatch {
198            depth,
199            pool_commitments,
200            merkle_payment_timestamp,
201            candidate_pools,
202            tree,
203            addresses: addresses.to_vec(),
204        })
205    }
206
207    /// Pay for a single batch (up to `MAX_LEAVES` chunks).
208    async fn pay_for_merkle_single_batch(
209        &self,
210        addresses: &[[u8; 32]],
211        data_type: u32,
212        data_size: u64,
213    ) -> Result<MerkleBatchPaymentResult> {
214        let wallet = self.require_wallet()?;
215        let prepared = self
216            .prepare_merkle_batch_external(addresses, data_type, data_size)
217            .await?;
218
219        info!(
220            "Submitting merkle batch payment on-chain (depth={})",
221            prepared.depth
222        );
223        let (winner_pool_hash, amount, gas_info) = wallet
224            .pay_for_merkle_tree(
225                prepared.depth,
226                prepared.pool_commitments.clone(),
227                prepared.merkle_payment_timestamp,
228            )
229            .await
230            .map_err(|e| Error::Payment(format!("Merkle batch payment failed: {e}")))?;
231
232        info!(
233            "Merkle payment succeeded: winner pool {}",
234            hex::encode(winner_pool_hash)
235        );
236
237        let mut result = finalize_merkle_batch(prepared, winner_pool_hash)?;
238        result.storage_cost_atto = amount.to_string();
239        result.gas_cost_wei = gas_info.gas_cost_wei;
240        Ok(result)
241    }
242
243    /// Handle batches larger than `MAX_LEAVES` by splitting into sub-batches.
244    async fn pay_for_merkle_multi_batch(
245        &self,
246        addresses: &[[u8; 32]],
247        data_type: u32,
248        data_size: u64,
249    ) -> Result<MerkleBatchPaymentResult> {
250        let sub_batches: Vec<&[[u8; 32]]> = addresses.chunks(MAX_LEAVES).collect();
251        let total_sub_batches = sub_batches.len();
252        let mut all_proofs = HashMap::with_capacity(addresses.len());
253        let mut total_storage = Amount::ZERO;
254        let mut total_gas: u128 = 0;
255
256        for (i, chunk) in sub_batches.into_iter().enumerate() {
257            match self
258                .pay_for_merkle_single_batch(chunk, data_type, data_size)
259                .await
260            {
261                Ok(sub_result) => {
262                    if let Ok(cost) = sub_result.storage_cost_atto.parse::<Amount>() {
263                        total_storage += cost;
264                    }
265                    total_gas = total_gas.saturating_add(sub_result.gas_cost_wei);
266                    all_proofs.extend(sub_result.proofs);
267                }
268                Err(e) => {
269                    if all_proofs.is_empty() {
270                        // First sub-batch failed, nothing paid yet -- propagate directly.
271                        return Err(e);
272                    }
273                    // Return partial result so caller can still store already-paid chunks.
274                    warn!(
275                        "Merkle sub-batch {}/{total_sub_batches} failed: {e}. \
276                         Returning {} proofs from prior sub-batches",
277                        i + 1,
278                        all_proofs.len()
279                    );
280                    return Ok(MerkleBatchPaymentResult {
281                        chunk_count: all_proofs.len(),
282                        proofs: all_proofs,
283                        storage_cost_atto: total_storage.to_string(),
284                        gas_cost_wei: total_gas,
285                    });
286                }
287            }
288        }
289
290        Ok(MerkleBatchPaymentResult {
291            chunk_count: addresses.len(),
292            proofs: all_proofs,
293            storage_cost_atto: total_storage.to_string(),
294            gas_cost_wei: total_gas,
295        })
296    }
297
298    /// Build candidate pools by querying the network for each midpoint (concurrently).
299    async fn build_candidate_pools(
300        &self,
301        midpoint_proofs: &[MidpointProof],
302        data_type: u32,
303        data_size: u64,
304        merkle_payment_timestamp: u64,
305    ) -> Result<Vec<MerklePaymentCandidatePool>> {
306        let mut pool_futures = FuturesUnordered::new();
307
308        for midpoint_proof in midpoint_proofs {
309            let pool_address = midpoint_proof.address();
310            let mp = midpoint_proof.clone();
311            pool_futures.push(async move {
312                let candidate_nodes = self
313                    .get_merkle_candidate_pool(
314                        &pool_address.0,
315                        data_type,
316                        data_size,
317                        merkle_payment_timestamp,
318                    )
319                    .await?;
320                Ok::<_, Error>(MerklePaymentCandidatePool {
321                    midpoint_proof: mp,
322                    candidate_nodes,
323                })
324            });
325        }
326
327        let mut pools = Vec::with_capacity(midpoint_proofs.len());
328        while let Some(result) = pool_futures.next().await {
329            pools.push(result?);
330        }
331
332        Ok(pools)
333    }
334
335    /// Collect `CANDIDATES_PER_POOL` (16) merkle candidate quotes from the network.
336    #[allow(clippy::too_many_lines)]
337    async fn get_merkle_candidate_pool(
338        &self,
339        address: &[u8; 32],
340        data_type: u32,
341        data_size: u64,
342        merkle_payment_timestamp: u64,
343    ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
344        let node = self.network().node();
345        let timeout = Duration::from_secs(self.config().quote_timeout_secs);
346
347        // Query extra peers to handle validation failures (bad sigs, wrong type, etc.)
348        let query_count = CANDIDATES_PER_POOL * 2;
349        let mut remote_peers = self
350            .network()
351            .find_closest_peers(address, query_count)
352            .await?;
353
354        // If DHT closest-nodes didn't return enough, supplement with connected peers.
355        // On small networks the DHT iterative lookup may not discover enough peers
356        // close to a random pool address, but we know more peers via direct connections.
357        if remote_peers.len() < CANDIDATES_PER_POOL {
358            let connected = self.network().connected_peers().await;
359            for peer in connected {
360                if !remote_peers.iter().any(|(id, _)| *id == peer) {
361                    remote_peers.push((peer, vec![]));
362                }
363            }
364        }
365
366        if remote_peers.len() < CANDIDATES_PER_POOL {
367            return Err(Error::InsufficientPeers(format!(
368                "Found {} peers, need {CANDIDATES_PER_POOL} for merkle candidate pool. \
369                 Use --no-merkle or a larger network.",
370                remote_peers.len()
371            )));
372        }
373
374        let mut candidate_futures = FuturesUnordered::new();
375
376        for (peer_id, peer_addrs) in &remote_peers {
377            let request_id = self.next_request_id();
378            let request = MerkleCandidateQuoteRequest {
379                address: *address,
380                data_type,
381                data_size,
382                merkle_payment_timestamp,
383            };
384            let message = ChunkMessage {
385                request_id,
386                body: ChunkMessageBody::MerkleCandidateQuoteRequest(request),
387            };
388
389            let message_bytes = match message.encode() {
390                Ok(bytes) => bytes,
391                Err(e) => {
392                    warn!("Failed to encode merkle candidate request for {peer_id}: {e}");
393                    continue;
394                }
395            };
396
397            let peer_id_clone = *peer_id;
398            let addrs_clone = peer_addrs.clone();
399            let node_clone = node.clone();
400
401            let fut = async move {
402                let result = send_and_await_chunk_response(
403                    &node_clone,
404                    &peer_id_clone,
405                    message_bytes,
406                    request_id,
407                    timeout,
408                    &addrs_clone,
409                    |body| match body {
410                        ChunkMessageBody::MerkleCandidateQuoteResponse(
411                            MerkleCandidateQuoteResponse::Success { candidate_node },
412                        ) => {
413                            match rmp_serde::from_slice::<MerklePaymentCandidateNode>(
414                                &candidate_node,
415                            ) {
416                                Ok(node) => Some(Ok(node)),
417                                Err(e) => Some(Err(Error::Serialization(format!(
418                                    "Failed to deserialize candidate node from {peer_id_clone}: {e}"
419                                )))),
420                            }
421                        }
422                        ChunkMessageBody::MerkleCandidateQuoteResponse(
423                            MerkleCandidateQuoteResponse::Error(e),
424                        ) => Some(Err(Error::Protocol(format!(
425                            "Merkle quote error from {peer_id_clone}: {e}"
426                        )))),
427                        _ => None,
428                    },
429                    |e| {
430                        Error::Network(format!(
431                            "Failed to send merkle candidate request to {peer_id_clone}: {e}"
432                        ))
433                    },
434                    || {
435                        Error::Timeout(format!(
436                            "Timeout waiting for merkle candidate from {peer_id_clone}"
437                        ))
438                    },
439                )
440                .await;
441
442                (peer_id_clone, result)
443            };
444
445            candidate_futures.push(fut);
446        }
447
448        self.collect_validated_candidates(&mut candidate_futures, address, merkle_payment_timestamp)
449            .await
450    }
451
452    /// Collect and validate merkle candidate responses, then return the
453    /// `CANDIDATES_PER_POOL` valid responders that are XOR-closest to the
454    /// pool midpoint.
455    ///
456    /// Why distance-sort instead of "first N to respond":
457    /// the storing-node verifier re-runs a network closest-peers lookup of
458    /// the pool midpoint and rejects the pool if fewer than 13 of the 16
459    /// candidate `pub_keys` appear in that authoritative close-set. Pools
460    /// built from the fastest-to-respond quoters fail this check whenever
461    /// truly-close peers are slower (NAT/relay paths) than farther peers.
462    async fn collect_validated_candidates(
463        &self,
464        futures: &mut FuturesUnordered<
465            impl std::future::Future<
466                Output = (
467                    PeerId,
468                    std::result::Result<MerklePaymentCandidateNode, Error>,
469                ),
470            >,
471        >,
472        target_address: &[u8; 32],
473        merkle_payment_timestamp: u64,
474    ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
475        let mut valid: Vec<(PeerId, MerklePaymentCandidateNode)> = Vec::new();
476        let mut failures: Vec<String> = Vec::new();
477
478        while let Some((peer_id, result)) = futures.next().await {
479            match result {
480                Ok(candidate) => {
481                    if !verify_merkle_candidate_signature(&candidate) {
482                        warn!("Invalid ML-DSA-65 signature from merkle candidate {peer_id}");
483                        failures.push(format!("{peer_id}: invalid signature"));
484                        continue;
485                    }
486                    if candidate.merkle_payment_timestamp != merkle_payment_timestamp {
487                        warn!("Timestamp mismatch from merkle candidate {peer_id}");
488                        failures.push(format!("{peer_id}: timestamp mismatch"));
489                        continue;
490                    }
491                    valid.push((peer_id, candidate));
492                }
493                Err(e) => {
494                    debug!("Failed to get merkle candidate from {peer_id}: {e}");
495                    failures.push(format!("{peer_id}: {e}"));
496                }
497            }
498        }
499
500        if valid.len() < CANDIDATES_PER_POOL {
501            return Err(Error::InsufficientPeers(format!(
502                "Got {} merkle candidates, need {CANDIDATES_PER_POOL}. Failures: [{}]",
503                valid.len(),
504                failures.join("; ")
505            )));
506        }
507
508        let target_peer = PeerId::from_bytes(*target_address);
509        valid.sort_by_key(|(peer_id, _)| peer_id.xor_distance(&target_peer));
510
511        let candidates: Vec<MerklePaymentCandidateNode> = valid
512            .into_iter()
513            .take(CANDIDATES_PER_POOL)
514            .map(|(_, candidate)| candidate)
515            .collect();
516
517        candidates
518            .try_into()
519            .map_err(|_| Error::Payment("Failed to convert candidates to fixed array".to_string()))
520    }
521
522    /// Upload chunks using pre-computed merkle proofs from a batch payment.
523    ///
524    /// Each chunk is matched to its proof from `batch_result.proofs`,
525    /// then stored to its close group concurrently. Returns the number
526    /// of chunks successfully stored.
527    ///
528    /// # Errors
529    ///
530    /// Returns an error if any chunk is missing its proof or storage fails.
531    pub(crate) async fn merkle_upload_chunks(
532        &self,
533        chunk_contents: Vec<Bytes>,
534        addresses: Vec<[u8; 32]>,
535        batch_result: &MerkleBatchPaymentResult,
536        progress: Option<&mpsc::Sender<UploadEvent>>,
537    ) -> Result<usize> {
538        let mut stored = 0usize;
539        let store_limiter = self.controller().store.clone();
540        // Clamp fan-out to batch size — partial batches should not
541        // pay for unused slots (see PERF-RESULTS.md).
542        let batch_size = chunk_contents.len();
543        let store_concurrency = store_limiter.current().min(batch_size.max(1));
544        let mut upload_stream = stream::iter(chunk_contents.into_iter().zip(addresses).map(
545            |(content, addr)| {
546                let proof_bytes = batch_result.proofs.get(&addr).cloned();
547                let limiter = store_limiter.clone();
548                async move {
549                    let proof = proof_bytes.ok_or_else(|| {
550                        Error::Payment(format!(
551                            "Missing merkle proof for chunk {}",
552                            hex::encode(addr)
553                        ))
554                    })?;
555                    let peers = self.close_group_peers(&addr).await?;
556                    observe_op(
557                        &limiter,
558                        || async move {
559                            self.chunk_put_to_close_group(content, proof, &peers).await
560                        },
561                        classify_error,
562                    )
563                    .await
564                }
565            },
566        ))
567        .buffer_unordered(store_concurrency);
568
569        while let Some(result) = upload_stream.next().await {
570            result?;
571            stored += 1;
572            if let Some(tx) = progress {
573                let _ = tx.try_send(UploadEvent::ChunkStored {
574                    stored,
575                    total: batch_size,
576                });
577            }
578        }
579
580        Ok(stored)
581    }
582}
583
584/// Phase 2 of external-signer merkle payment: generate proofs from winner.
585///
586/// Takes the prepared batch and the winner pool hash returned by the
587/// on-chain payment transaction. Generates per-chunk merkle proofs.
588pub fn finalize_merkle_batch(
589    prepared: PreparedMerkleBatch,
590    winner_pool_hash: [u8; 32],
591) -> Result<MerkleBatchPaymentResult> {
592    let chunk_count = prepared.addresses.len();
593    let xornames: Vec<XorName> = prepared.addresses.iter().map(|a| XorName(*a)).collect();
594
595    // Find the winner pool
596    let winner_pool = prepared
597        .candidate_pools
598        .iter()
599        .find(|pool| pool.hash() == winner_pool_hash)
600        .ok_or_else(|| {
601            Error::Payment(format!(
602                "Winner pool {} not found in candidate pools",
603                hex::encode(winner_pool_hash)
604            ))
605        })?;
606
607    // Generate proofs for each chunk
608    info!("Generating merkle proofs for {chunk_count} chunks");
609    let mut proofs = HashMap::with_capacity(chunk_count);
610
611    for (i, xorname) in xornames.iter().enumerate() {
612        let address_proof = prepared
613            .tree
614            .generate_address_proof(i, *xorname)
615            .map_err(|e| {
616                Error::Payment(format!(
617                    "Failed to generate address proof for chunk {i}: {e}"
618                ))
619            })?;
620
621        let merkle_proof = MerklePaymentProof::new(*xorname, address_proof, winner_pool.clone());
622
623        let tagged_bytes = serialize_merkle_proof(&merkle_proof)
624            .map_err(|e| Error::Serialization(format!("Failed to serialize merkle proof: {e}")))?;
625
626        proofs.insert(prepared.addresses[i], tagged_bytes);
627    }
628
629    info!("Merkle batch payment complete: {chunk_count} proofs generated");
630
631    Ok(MerkleBatchPaymentResult {
632        proofs,
633        chunk_count,
634        storage_cost_atto: "0".to_string(),
635        gas_cost_wei: 0,
636    })
637}
638
/// Compile-time checks that the futures returned by merkle methods are `Send`.
///
/// Nothing in this module runs; it only has to type-check. A regression shows
/// up as a build error in test builds.
#[cfg(test)]
mod send_assertions {
    use super::*;
    use crate::data::client::Client;

    /// Accepts only references to `Send` values; calling it is the assertion.
    fn _require_send<F: Send>(_value: &F) {}

    // `todo!()` diverges, so everything after it is unreachable and never
    // executes. The compiler still type-checks it, which is all this needs.
    #[allow(
        dead_code,
        unreachable_code,
        unused_variables,
        clippy::diverging_sub_expression
    )]
    async fn _merkle_upload_chunks_is_send(client: &Client) {
        let placeholder: MerkleBatchPaymentResult = todo!();
        let upload = client.merkle_upload_chunks(Vec::new(), Vec::new(), &placeholder, None);
        _require_send(&upload);
    }
}
659
660#[cfg(test)]
661#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
662mod tests {
663    use super::*;
664    use ant_protocol::evm::{Amount, MerkleTree, RewardsAddress, CANDIDATES_PER_POOL};
665
666    // =========================================================================
667    // should_use_merkle (free function, no Client needed)
668    // =========================================================================
669
670    #[test]
671    fn test_auto_below_threshold() {
672        assert!(!should_use_merkle(1, PaymentMode::Auto));
673        assert!(!should_use_merkle(10, PaymentMode::Auto));
674        assert!(!should_use_merkle(63, PaymentMode::Auto));
675    }
676
677    #[test]
678    fn test_auto_at_and_above_threshold() {
679        assert!(should_use_merkle(64, PaymentMode::Auto));
680        assert!(should_use_merkle(65, PaymentMode::Auto));
681        assert!(should_use_merkle(1000, PaymentMode::Auto));
682    }
683
684    #[test]
685    fn test_merkle_mode_forces_at_2() {
686        assert!(!should_use_merkle(1, PaymentMode::Merkle));
687        assert!(should_use_merkle(2, PaymentMode::Merkle));
688        assert!(should_use_merkle(3, PaymentMode::Merkle));
689    }
690
691    #[test]
692    fn test_single_mode_always_false() {
693        assert!(!should_use_merkle(0, PaymentMode::Single));
694        assert!(!should_use_merkle(64, PaymentMode::Single));
695        assert!(!should_use_merkle(1000, PaymentMode::Single));
696    }
697
698    #[test]
699    fn test_default_mode_is_auto() {
700        assert_eq!(PaymentMode::default(), PaymentMode::Auto);
701    }
702
703    #[test]
704    fn test_threshold_value() {
705        assert_eq!(DEFAULT_MERKLE_THRESHOLD, 64);
706    }
707
708    // =========================================================================
709    // MerkleTree construction and proof generation (pure, no network)
710    // =========================================================================
711
712    fn make_test_addresses(count: usize) -> Vec<[u8; 32]> {
713        (0..count)
714            .map(|i| {
715                let xn = XorName::from_content(&i.to_le_bytes());
716                xn.0
717            })
718            .collect()
719    }
720
721    #[test]
722    fn test_tree_depth_for_known_sizes() {
723        let cases = [(2, 1), (4, 2), (16, 4), (100, 7), (256, 8)];
724        for (count, expected_depth) in cases {
725            let addrs = make_test_addresses(count);
726            let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
727            let tree = MerkleTree::from_xornames(xornames).unwrap();
728            assert_eq!(
729                tree.depth(),
730                expected_depth,
731                "depth mismatch for {count} leaves"
732            );
733        }
734    }
735
736    #[test]
737    fn test_proof_generation_and_verification_for_all_leaves() {
738        let addrs = make_test_addresses(16);
739        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
740        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();
741
742        for (i, xn) in xornames.iter().enumerate() {
743            let proof = tree.generate_address_proof(i, *xn).unwrap();
744            assert!(proof.verify(), "proof for leaf {i} should verify");
745            assert_eq!(proof.depth(), tree.depth() as usize);
746        }
747    }
748
749    #[test]
750    fn test_proof_fails_for_wrong_address() {
751        let addrs = make_test_addresses(8);
752        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
753        let tree = MerkleTree::from_xornames(xornames).unwrap();
754
755        let wrong = XorName::from_content(b"wrong");
756        let proof = tree.generate_address_proof(0, wrong).unwrap();
757        assert!(!proof.verify(), "proof with wrong address should fail");
758    }
759
760    #[test]
761    fn test_tree_too_few_leaves() {
762        let xornames = vec![XorName::from_content(b"only_one")];
763        let result = MerkleTree::from_xornames(xornames);
764        assert!(result.is_err());
765    }
766
767    #[test]
768    fn test_tree_at_max_leaves() {
769        let addrs = make_test_addresses(MAX_LEAVES);
770        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
771        let tree = MerkleTree::from_xornames(xornames).unwrap();
772        assert_eq!(tree.leaf_count(), MAX_LEAVES);
773    }
774
775    // =========================================================================
776    // Proof serialization round-trip
777    // =========================================================================
778
779    #[test]
780    fn test_merkle_proof_serialize_deserialize_roundtrip() {
781        use ant_protocol::evm::{Amount, MerklePaymentCandidateNode, RewardsAddress};
782        use ant_protocol::payment::{deserialize_merkle_proof, serialize_merkle_proof};
783
784        let addrs = make_test_addresses(4);
785        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
786        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();
787
788        let timestamp = std::time::SystemTime::now()
789            .duration_since(std::time::UNIX_EPOCH)
790            .unwrap()
791            .as_secs();
792
793        let candidates = tree.reward_candidates(timestamp).unwrap();
794        let midpoint = candidates.first().unwrap().clone();
795
796        // Build candidate nodes (with dummy signatures — not ML-DSA, just for serialization test)
797        #[allow(clippy::cast_possible_truncation)]
798        let candidate_nodes: [MerklePaymentCandidateNode; CANDIDATES_PER_POOL] =
799            std::array::from_fn(|i| MerklePaymentCandidateNode {
800                pub_key: vec![i as u8; 32],
801                price: Amount::from(1024u64),
802                reward_address: RewardsAddress::new([i as u8; 20]),
803                merkle_payment_timestamp: timestamp,
804                signature: vec![i as u8; 64],
805            });
806
807        let pool = MerklePaymentCandidatePool {
808            midpoint_proof: midpoint,
809            candidate_nodes,
810        };
811
812        let address_proof = tree.generate_address_proof(0, xornames[0]).unwrap();
813        let merkle_proof = MerklePaymentProof::new(xornames[0], address_proof, pool);
814
815        let tagged = serialize_merkle_proof(&merkle_proof).unwrap();
816        assert_eq!(
817            tagged.first().copied(),
818            Some(0x02),
819            "tag should be PROOF_TAG_MERKLE"
820        );
821
822        let deserialized = deserialize_merkle_proof(&tagged).unwrap();
823        assert_eq!(deserialized.address, merkle_proof.address);
824        assert_eq!(
825            deserialized.winner_pool.candidate_nodes.len(),
826            CANDIDATES_PER_POOL
827        );
828    }
829
830    // =========================================================================
831    // Candidate validation logic
832    // =========================================================================
833
834    #[test]
835    fn test_candidate_wrong_timestamp_rejected() {
836        // Simulates what collect_validated_candidates checks
837        let candidate = MerklePaymentCandidateNode {
838            pub_key: vec![0u8; 32],
839            price: ant_protocol::evm::Amount::ZERO,
840            reward_address: ant_protocol::evm::RewardsAddress::new([0u8; 20]),
841            merkle_payment_timestamp: 1000,
842            signature: vec![0u8; 64],
843        };
844
845        // Timestamp check: 1000 != 2000
846        assert_ne!(candidate.merkle_payment_timestamp, 2000);
847    }
848
849    // =========================================================================
850    // finalize_merkle_batch (external signer)
851    // =========================================================================
852
853    fn make_dummy_candidate_nodes(
854        timestamp: u64,
855    ) -> [MerklePaymentCandidateNode; CANDIDATES_PER_POOL] {
856        std::array::from_fn(|i| MerklePaymentCandidateNode {
857            pub_key: vec![i as u8; 32],
858            price: Amount::from(1024u64),
859            reward_address: RewardsAddress::new([i as u8; 20]),
860            merkle_payment_timestamp: timestamp,
861            signature: vec![i as u8; 64],
862        })
863    }
864
865    fn make_prepared_merkle_batch(count: usize) -> PreparedMerkleBatch {
866        let addrs = make_test_addresses(count);
867        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
868        let tree = MerkleTree::from_xornames(xornames).unwrap();
869
870        let timestamp = std::time::SystemTime::now()
871            .duration_since(std::time::UNIX_EPOCH)
872            .unwrap()
873            .as_secs();
874
875        let midpoints = tree.reward_candidates(timestamp).unwrap();
876
877        let candidate_pools: Vec<MerklePaymentCandidatePool> = midpoints
878            .into_iter()
879            .map(|mp| MerklePaymentCandidatePool {
880                midpoint_proof: mp,
881                candidate_nodes: make_dummy_candidate_nodes(timestamp),
882            })
883            .collect();
884
885        let pool_commitments = candidate_pools
886            .iter()
887            .map(MerklePaymentCandidatePool::to_commitment)
888            .collect();
889
890        PreparedMerkleBatch {
891            depth: tree.depth(),
892            pool_commitments,
893            merkle_payment_timestamp: timestamp,
894            candidate_pools,
895            tree,
896            addresses: addrs,
897        }
898    }
899
900    #[test]
901    fn test_finalize_merkle_batch_with_valid_winner() {
902        let prepared = make_prepared_merkle_batch(4);
903        let winner_hash = prepared.candidate_pools[0].hash();
904
905        let result = finalize_merkle_batch(prepared, winner_hash);
906        assert!(
907            result.is_ok(),
908            "should succeed with valid winner: {result:?}"
909        );
910
911        let batch = result.unwrap();
912        assert_eq!(batch.chunk_count, 4);
913        assert_eq!(batch.proofs.len(), 4);
914
915        // Every proof should be non-empty
916        for proof_bytes in batch.proofs.values() {
917            assert!(!proof_bytes.is_empty());
918        }
919    }
920
921    #[test]
922    fn test_finalize_merkle_batch_with_invalid_winner() {
923        let prepared = make_prepared_merkle_batch(4);
924        let bad_hash = [0xFF; 32];
925
926        let result = finalize_merkle_batch(prepared, bad_hash);
927        assert!(result.is_err());
928        let err = result.unwrap_err().to_string();
929        assert!(err.contains("not found in candidate pools"), "got: {err}");
930    }
931
932    #[test]
933    fn test_finalize_merkle_batch_proofs_are_deserializable() {
934        use ant_protocol::payment::deserialize_merkle_proof;
935
936        let prepared = make_prepared_merkle_batch(8);
937        let winner_hash = prepared.candidate_pools[0].hash();
938
939        let batch = finalize_merkle_batch(prepared, winner_hash).unwrap();
940
941        for (addr, proof_bytes) in &batch.proofs {
942            let proof = deserialize_merkle_proof(proof_bytes);
943            assert!(
944                proof.is_ok(),
945                "proof for {} should deserialize: {:?}",
946                hex::encode(addr),
947                proof.err()
948            );
949        }
950    }
951
952    // =========================================================================
953    // Batch splitting edge cases
954    // =========================================================================
955
956    #[test]
957    fn test_batch_split_calculation() {
958        // MAX_LEAVES chunks should fit in 1 batch
959        let addrs = make_test_addresses(MAX_LEAVES);
960        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 1);
961
962        // MAX_LEAVES + 1 should split into 2
963        let addrs = make_test_addresses(MAX_LEAVES + 1);
964        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 2);
965
966        // 3 * MAX_LEAVES should split into 3
967        let addrs = make_test_addresses(3 * MAX_LEAVES);
968        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 3);
969    }
970}