Skip to main content

ant_core/data/client/
merkle.rs

1//! Merkle batch payment support for the Autonomi client.
2//!
3//! When uploading batches of 64+ chunks, merkle payments reduce gas costs
4//! by paying for the entire batch in a single on-chain transaction instead
5//! of one transaction per chunk.
6
7use crate::data::client::Client;
8use crate::data::error::{Error, Result};
9use ant_protocol::evm::{
10    Amount, MerklePaymentCandidateNode, MerklePaymentCandidatePool, MerklePaymentProof, MerkleTree,
11    MidpointProof, PoolCommitment, CANDIDATES_PER_POOL, MAX_LEAVES,
12};
13use ant_protocol::payment::{serialize_merkle_proof, verify_merkle_candidate_signature};
14use ant_protocol::transport::PeerId;
15use ant_protocol::{
16    send_and_await_chunk_response, ChunkMessage, ChunkMessageBody, MerkleCandidateQuoteRequest,
17    MerkleCandidateQuoteResponse,
18};
19use bytes::Bytes;
20use futures::stream::{self, FuturesUnordered, StreamExt};
21use std::collections::HashMap;
22use std::time::Duration;
23use tracing::{debug, info, warn};
24use xor_name::XorName;
25
/// Minimum chunk count at which [`PaymentMode::Auto`] switches to merkle batch payments.
pub const DEFAULT_MERKLE_THRESHOLD: usize = 64;
28
29/// Payment mode for uploads.
30#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum PaymentMode {
33    /// Automatically choose: merkle for batches >= threshold, single otherwise.
34    #[default]
35    Auto,
36    /// Force merkle batch payment regardless of batch size (min 2 chunks).
37    Merkle,
38    /// Force single-node payment (one tx per chunk).
39    Single,
40}
41
/// Outcome of paying for a batch of chunks with a merkle payment.
#[derive(Debug)]
pub struct MerkleBatchPaymentResult {
    /// Serialized, tagged proof bytes keyed by chunk address (the raw 32 bytes of
    /// its `XorName`), ready to attach to PUT requests.
    pub proofs: HashMap<[u8; 32], Vec<u8>>,
    /// Number of chunks in the batch.
    pub chunk_count: usize,
    /// Total storage cost in atto (token smallest unit), kept as a string.
    /// `finalize_merkle_batch` leaves this at `"0"`; callers that know the paid
    /// amount overwrite it.
    pub storage_cost_atto: String,
    /// Total gas cost in wei (`0` until filled in by the paying code path).
    pub gas_cost_wei: u128,
}
54
55/// Prepared merkle batch ready for external payment.
56///
57/// Contains everything needed to submit the on-chain merkle payment
58/// and then finalize proof generation without a wallet.
59pub struct PreparedMerkleBatch {
60    /// Merkle tree depth (needed for the on-chain call).
61    pub depth: u8,
62    /// Pool commitments for the on-chain call.
63    pub pool_commitments: Vec<PoolCommitment>,
64    /// Timestamp used for the merkle payment.
65    pub merkle_payment_timestamp: u64,
66    /// Internal: candidate pools (needed for proof generation after payment).
67    candidate_pools: Vec<MerklePaymentCandidatePool>,
68    /// Internal: the merkle tree (needed for proof generation).
69    tree: MerkleTree,
70    /// Internal: chunk addresses in order.
71    addresses: Vec<[u8; 32]>,
72}
73
74impl std::fmt::Debug for PreparedMerkleBatch {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("PreparedMerkleBatch")
77            .field("depth", &self.depth)
78            .field("pool_commitments", &self.pool_commitments.len())
79            .field("merkle_payment_timestamp", &self.merkle_payment_timestamp)
80            .field("candidate_pools", &self.candidate_pools.len())
81            .field("addresses", &self.addresses.len())
82            .finish()
83    }
84}
85
86/// Determine whether to use merkle payments for a given batch size.
87/// Free function — no Client needed.
88#[must_use]
89pub fn should_use_merkle(chunk_count: usize, mode: PaymentMode) -> bool {
90    match mode {
91        PaymentMode::Auto => chunk_count >= DEFAULT_MERKLE_THRESHOLD,
92        PaymentMode::Merkle => chunk_count >= 2,
93        PaymentMode::Single => false,
94    }
95}
96
97impl Client {
98    /// Determine whether to use merkle payments for a given batch size.
99    #[must_use]
100    pub fn should_use_merkle(&self, chunk_count: usize, mode: PaymentMode) -> bool {
101        should_use_merkle(chunk_count, mode)
102    }
103
104    /// Pay for a batch of chunks using merkle batch payment.
105    ///
106    /// Builds a merkle tree, collects candidate pools, pays on-chain in one tx,
107    /// and returns per-chunk proofs. Splits into sub-batches if > `MAX_LEAVES`.
108    ///
109    /// Does NOT pre-filter already-stored chunks (nodes handle `AlreadyExists`
110    /// gracefully on PUT). This avoids N sequential GET round-trips before payment.
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the batch is too small, candidate collection fails,
115    /// on-chain payment fails, or proof generation fails.
116    pub async fn pay_for_merkle_batch(
117        &self,
118        addresses: &[[u8; 32]],
119        data_type: u32,
120        data_size: u64,
121    ) -> Result<MerkleBatchPaymentResult> {
122        let chunk_count = addresses.len();
123        if chunk_count < 2 {
124            return Err(Error::Payment(
125                "Merkle batch payment requires at least 2 chunks".to_string(),
126            ));
127        }
128
129        if chunk_count > MAX_LEAVES {
130            return self
131                .pay_for_merkle_multi_batch(addresses, data_type, data_size)
132                .await;
133        }
134
135        self.pay_for_merkle_single_batch(addresses, data_type, data_size)
136            .await
137    }
138
139    /// Phase 1 of external-signer merkle payment: prepare batch without paying.
140    ///
141    /// Builds the merkle tree, collects candidate pools from the network,
142    /// and returns the data needed for the on-chain payment call.
143    /// Requires `EvmNetwork` but NOT a wallet.
144    pub async fn prepare_merkle_batch_external(
145        &self,
146        addresses: &[[u8; 32]],
147        data_type: u32,
148        data_size: u64,
149    ) -> Result<PreparedMerkleBatch> {
150        let chunk_count = addresses.len();
151        let xornames: Vec<XorName> = addresses.iter().map(|a| XorName(*a)).collect();
152
153        debug!("Building merkle tree for {chunk_count} chunks");
154
155        // 1. Build merkle tree
156        let tree = MerkleTree::from_xornames(xornames)
157            .map_err(|e| Error::Payment(format!("Failed to build merkle tree: {e}")))?;
158
159        let depth = tree.depth();
160        let merkle_payment_timestamp = std::time::SystemTime::now()
161            .duration_since(std::time::UNIX_EPOCH)
162            .map_err(|e| Error::Payment(format!("System time error: {e}")))?
163            .as_secs();
164
165        debug!("Merkle tree: depth={depth}, leaves={chunk_count}, ts={merkle_payment_timestamp}");
166
167        // 2. Get reward candidates (midpoint proofs)
168        let midpoint_proofs = tree
169            .reward_candidates(merkle_payment_timestamp)
170            .map_err(|e| Error::Payment(format!("Failed to generate reward candidates: {e}")))?;
171
172        debug!(
173            "Collecting candidate pools from {} midpoints (concurrent)",
174            midpoint_proofs.len()
175        );
176
177        // 3. Collect candidate pools from the network (all pools in parallel)
178        let candidate_pools = self
179            .build_candidate_pools(
180                &midpoint_proofs,
181                data_type,
182                data_size,
183                merkle_payment_timestamp,
184            )
185            .await?;
186
187        // 4. Build pool commitments for on-chain payment
188        let pool_commitments: Vec<PoolCommitment> = candidate_pools
189            .iter()
190            .map(MerklePaymentCandidatePool::to_commitment)
191            .collect();
192
193        Ok(PreparedMerkleBatch {
194            depth,
195            pool_commitments,
196            merkle_payment_timestamp,
197            candidate_pools,
198            tree,
199            addresses: addresses.to_vec(),
200        })
201    }
202
203    /// Pay for a single batch (up to `MAX_LEAVES` chunks).
204    async fn pay_for_merkle_single_batch(
205        &self,
206        addresses: &[[u8; 32]],
207        data_type: u32,
208        data_size: u64,
209    ) -> Result<MerkleBatchPaymentResult> {
210        let wallet = self.require_wallet()?;
211        let prepared = self
212            .prepare_merkle_batch_external(addresses, data_type, data_size)
213            .await?;
214
215        info!(
216            "Submitting merkle batch payment on-chain (depth={})",
217            prepared.depth
218        );
219        let (winner_pool_hash, amount, gas_info) = wallet
220            .pay_for_merkle_tree(
221                prepared.depth,
222                prepared.pool_commitments.clone(),
223                prepared.merkle_payment_timestamp,
224            )
225            .await
226            .map_err(|e| Error::Payment(format!("Merkle batch payment failed: {e}")))?;
227
228        info!(
229            "Merkle payment succeeded: winner pool {}",
230            hex::encode(winner_pool_hash)
231        );
232
233        let mut result = finalize_merkle_batch(prepared, winner_pool_hash)?;
234        result.storage_cost_atto = amount.to_string();
235        result.gas_cost_wei = gas_info.gas_cost_wei;
236        Ok(result)
237    }
238
239    /// Handle batches larger than `MAX_LEAVES` by splitting into sub-batches.
240    async fn pay_for_merkle_multi_batch(
241        &self,
242        addresses: &[[u8; 32]],
243        data_type: u32,
244        data_size: u64,
245    ) -> Result<MerkleBatchPaymentResult> {
246        let sub_batches: Vec<&[[u8; 32]]> = addresses.chunks(MAX_LEAVES).collect();
247        let total_sub_batches = sub_batches.len();
248        let mut all_proofs = HashMap::with_capacity(addresses.len());
249        let mut total_storage = Amount::ZERO;
250        let mut total_gas: u128 = 0;
251
252        for (i, chunk) in sub_batches.into_iter().enumerate() {
253            match self
254                .pay_for_merkle_single_batch(chunk, data_type, data_size)
255                .await
256            {
257                Ok(sub_result) => {
258                    if let Ok(cost) = sub_result.storage_cost_atto.parse::<Amount>() {
259                        total_storage += cost;
260                    }
261                    total_gas = total_gas.saturating_add(sub_result.gas_cost_wei);
262                    all_proofs.extend(sub_result.proofs);
263                }
264                Err(e) => {
265                    if all_proofs.is_empty() {
266                        // First sub-batch failed, nothing paid yet -- propagate directly.
267                        return Err(e);
268                    }
269                    // Return partial result so caller can still store already-paid chunks.
270                    warn!(
271                        "Merkle sub-batch {}/{total_sub_batches} failed: {e}. \
272                         Returning {} proofs from prior sub-batches",
273                        i + 1,
274                        all_proofs.len()
275                    );
276                    return Ok(MerkleBatchPaymentResult {
277                        chunk_count: all_proofs.len(),
278                        proofs: all_proofs,
279                        storage_cost_atto: total_storage.to_string(),
280                        gas_cost_wei: total_gas,
281                    });
282                }
283            }
284        }
285
286        Ok(MerkleBatchPaymentResult {
287            chunk_count: addresses.len(),
288            proofs: all_proofs,
289            storage_cost_atto: total_storage.to_string(),
290            gas_cost_wei: total_gas,
291        })
292    }
293
294    /// Build candidate pools by querying the network for each midpoint (concurrently).
295    async fn build_candidate_pools(
296        &self,
297        midpoint_proofs: &[MidpointProof],
298        data_type: u32,
299        data_size: u64,
300        merkle_payment_timestamp: u64,
301    ) -> Result<Vec<MerklePaymentCandidatePool>> {
302        let mut pool_futures = FuturesUnordered::new();
303
304        for midpoint_proof in midpoint_proofs {
305            let pool_address = midpoint_proof.address();
306            let mp = midpoint_proof.clone();
307            pool_futures.push(async move {
308                let candidate_nodes = self
309                    .get_merkle_candidate_pool(
310                        &pool_address.0,
311                        data_type,
312                        data_size,
313                        merkle_payment_timestamp,
314                    )
315                    .await?;
316                Ok::<_, Error>(MerklePaymentCandidatePool {
317                    midpoint_proof: mp,
318                    candidate_nodes,
319                })
320            });
321        }
322
323        let mut pools = Vec::with_capacity(midpoint_proofs.len());
324        while let Some(result) = pool_futures.next().await {
325            pools.push(result?);
326        }
327
328        Ok(pools)
329    }
330
331    /// Collect `CANDIDATES_PER_POOL` (16) merkle candidate quotes from the network.
332    #[allow(clippy::too_many_lines)]
333    async fn get_merkle_candidate_pool(
334        &self,
335        address: &[u8; 32],
336        data_type: u32,
337        data_size: u64,
338        merkle_payment_timestamp: u64,
339    ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
340        let node = self.network().node();
341        let timeout = Duration::from_secs(self.config().quote_timeout_secs);
342
343        // Query extra peers to handle validation failures (bad sigs, wrong type, etc.)
344        let query_count = CANDIDATES_PER_POOL * 2;
345        let mut remote_peers = self
346            .network()
347            .find_closest_peers(address, query_count)
348            .await?;
349
350        // If DHT closest-nodes didn't return enough, supplement with connected peers.
351        // On small networks the DHT iterative lookup may not discover enough peers
352        // close to a random pool address, but we know more peers via direct connections.
353        if remote_peers.len() < CANDIDATES_PER_POOL {
354            let connected = self.network().connected_peers().await;
355            for peer in connected {
356                if !remote_peers.iter().any(|(id, _)| *id == peer) {
357                    remote_peers.push((peer, vec![]));
358                }
359            }
360        }
361
362        if remote_peers.len() < CANDIDATES_PER_POOL {
363            return Err(Error::InsufficientPeers(format!(
364                "Found {} peers, need {CANDIDATES_PER_POOL} for merkle candidate pool. \
365                 Use --no-merkle or a larger network.",
366                remote_peers.len()
367            )));
368        }
369
370        let mut candidate_futures = FuturesUnordered::new();
371
372        for (peer_id, peer_addrs) in &remote_peers {
373            let request_id = self.next_request_id();
374            let request = MerkleCandidateQuoteRequest {
375                address: *address,
376                data_type,
377                data_size,
378                merkle_payment_timestamp,
379            };
380            let message = ChunkMessage {
381                request_id,
382                body: ChunkMessageBody::MerkleCandidateQuoteRequest(request),
383            };
384
385            let message_bytes = match message.encode() {
386                Ok(bytes) => bytes,
387                Err(e) => {
388                    warn!("Failed to encode merkle candidate request for {peer_id}: {e}");
389                    continue;
390                }
391            };
392
393            let peer_id_clone = *peer_id;
394            let addrs_clone = peer_addrs.clone();
395            let node_clone = node.clone();
396
397            let fut = async move {
398                let result = send_and_await_chunk_response(
399                    &node_clone,
400                    &peer_id_clone,
401                    message_bytes,
402                    request_id,
403                    timeout,
404                    &addrs_clone,
405                    |body| match body {
406                        ChunkMessageBody::MerkleCandidateQuoteResponse(
407                            MerkleCandidateQuoteResponse::Success { candidate_node },
408                        ) => {
409                            match rmp_serde::from_slice::<MerklePaymentCandidateNode>(
410                                &candidate_node,
411                            ) {
412                                Ok(node) => Some(Ok(node)),
413                                Err(e) => Some(Err(Error::Serialization(format!(
414                                    "Failed to deserialize candidate node from {peer_id_clone}: {e}"
415                                )))),
416                            }
417                        }
418                        ChunkMessageBody::MerkleCandidateQuoteResponse(
419                            MerkleCandidateQuoteResponse::Error(e),
420                        ) => Some(Err(Error::Protocol(format!(
421                            "Merkle quote error from {peer_id_clone}: {e}"
422                        )))),
423                        _ => None,
424                    },
425                    |e| {
426                        Error::Network(format!(
427                            "Failed to send merkle candidate request to {peer_id_clone}: {e}"
428                        ))
429                    },
430                    || {
431                        Error::Timeout(format!(
432                            "Timeout waiting for merkle candidate from {peer_id_clone}"
433                        ))
434                    },
435                )
436                .await;
437
438                (peer_id_clone, result)
439            };
440
441            candidate_futures.push(fut);
442        }
443
444        self.collect_validated_candidates(&mut candidate_futures, merkle_payment_timestamp)
445            .await
446    }
447
448    /// Collect and validate merkle candidate responses until we have enough.
449    async fn collect_validated_candidates(
450        &self,
451        futures: &mut FuturesUnordered<
452            impl std::future::Future<
453                Output = (
454                    PeerId,
455                    std::result::Result<MerklePaymentCandidateNode, Error>,
456                ),
457            >,
458        >,
459        merkle_payment_timestamp: u64,
460    ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
461        let mut candidates = Vec::with_capacity(CANDIDATES_PER_POOL);
462        let mut failures: Vec<String> = Vec::new();
463
464        while let Some((peer_id, result)) = futures.next().await {
465            match result {
466                Ok(candidate) => {
467                    if !verify_merkle_candidate_signature(&candidate) {
468                        warn!("Invalid ML-DSA-65 signature from merkle candidate {peer_id}");
469                        failures.push(format!("{peer_id}: invalid signature"));
470                        continue;
471                    }
472                    if candidate.merkle_payment_timestamp != merkle_payment_timestamp {
473                        warn!("Timestamp mismatch from merkle candidate {peer_id}");
474                        failures.push(format!("{peer_id}: timestamp mismatch"));
475                        continue;
476                    }
477                    candidates.push(candidate);
478                    if candidates.len() >= CANDIDATES_PER_POOL {
479                        break;
480                    }
481                }
482                Err(e) => {
483                    debug!("Failed to get merkle candidate from {peer_id}: {e}");
484                    failures.push(format!("{peer_id}: {e}"));
485                }
486            }
487        }
488
489        if candidates.len() < CANDIDATES_PER_POOL {
490            return Err(Error::InsufficientPeers(format!(
491                "Got {} merkle candidates, need {CANDIDATES_PER_POOL}. Failures: [{}]",
492                candidates.len(),
493                failures.join("; ")
494            )));
495        }
496
497        candidates.truncate(CANDIDATES_PER_POOL);
498        candidates
499            .try_into()
500            .map_err(|_| Error::Payment("Failed to convert candidates to fixed array".to_string()))
501    }
502
503    /// Upload chunks using pre-computed merkle proofs from a batch payment.
504    ///
505    /// Each chunk is matched to its proof from `batch_result.proofs`,
506    /// then stored to its close group concurrently. Returns the number
507    /// of chunks successfully stored.
508    ///
509    /// # Errors
510    ///
511    /// Returns an error if any chunk is missing its proof or storage fails.
512    pub(crate) async fn merkle_upload_chunks(
513        &self,
514        chunk_contents: Vec<Bytes>,
515        addresses: Vec<[u8; 32]>,
516        batch_result: &MerkleBatchPaymentResult,
517    ) -> Result<usize> {
518        let mut stored = 0usize;
519        let mut upload_stream = stream::iter(chunk_contents.into_iter().zip(addresses).map(
520            |(content, addr)| {
521                let proof_bytes = batch_result.proofs.get(&addr).cloned();
522                async move {
523                    let proof = proof_bytes.ok_or_else(|| {
524                        Error::Payment(format!(
525                            "Missing merkle proof for chunk {}",
526                            hex::encode(addr)
527                        ))
528                    })?;
529                    let peers = self.close_group_peers(&addr).await?;
530                    self.chunk_put_to_close_group(content, proof, &peers).await
531                }
532            },
533        ))
534        .buffer_unordered(self.config().store_concurrency);
535
536        while let Some(result) = upload_stream.next().await {
537            result?;
538            stored += 1;
539        }
540
541        Ok(stored)
542    }
543}
544
545/// Phase 2 of external-signer merkle payment: generate proofs from winner.
546///
547/// Takes the prepared batch and the winner pool hash returned by the
548/// on-chain payment transaction. Generates per-chunk merkle proofs.
549pub fn finalize_merkle_batch(
550    prepared: PreparedMerkleBatch,
551    winner_pool_hash: [u8; 32],
552) -> Result<MerkleBatchPaymentResult> {
553    let chunk_count = prepared.addresses.len();
554    let xornames: Vec<XorName> = prepared.addresses.iter().map(|a| XorName(*a)).collect();
555
556    // Find the winner pool
557    let winner_pool = prepared
558        .candidate_pools
559        .iter()
560        .find(|pool| pool.hash() == winner_pool_hash)
561        .ok_or_else(|| {
562            Error::Payment(format!(
563                "Winner pool {} not found in candidate pools",
564                hex::encode(winner_pool_hash)
565            ))
566        })?;
567
568    // Generate proofs for each chunk
569    info!("Generating merkle proofs for {chunk_count} chunks");
570    let mut proofs = HashMap::with_capacity(chunk_count);
571
572    for (i, xorname) in xornames.iter().enumerate() {
573        let address_proof = prepared
574            .tree
575            .generate_address_proof(i, *xorname)
576            .map_err(|e| {
577                Error::Payment(format!(
578                    "Failed to generate address proof for chunk {i}: {e}"
579                ))
580            })?;
581
582        let merkle_proof = MerklePaymentProof::new(*xorname, address_proof, winner_pool.clone());
583
584        let tagged_bytes = serialize_merkle_proof(&merkle_proof)
585            .map_err(|e| Error::Serialization(format!("Failed to serialize merkle proof: {e}")))?;
586
587        proofs.insert(prepared.addresses[i], tagged_bytes);
588    }
589
590    info!("Merkle batch payment complete: {chunk_count} proofs generated");
591
592    Ok(MerkleBatchPaymentResult {
593        proofs,
594        chunk_count,
595        storage_cost_atto: "0".to_string(),
596        gas_cost_wei: 0,
597    })
598}
599
/// Compile-time checks that futures returned by merkle methods are `Send`.
///
/// Nothing in this module is ever executed; it only has to type-check.
#[cfg(test)]
mod send_assertions {
    use super::*;
    use crate::data::client::Client;

    /// Accepts only references to `Send` values.
    fn _assert_send<T: Send>(_: &T) {}

    // The body diverges at `todo!()`, so everything after it is intentionally
    // unreachable; the allows below silence the resulting lints.
    #[allow(
        dead_code,
        unreachable_code,
        unused_variables,
        clippy::diverging_sub_expression
    )]
    async fn _merkle_upload_chunks_is_send(client: &Client) {
        let batch_result: MerkleBatchPaymentResult = todo!();
        let upload = client.merkle_upload_chunks(Vec::new(), Vec::new(), &batch_result);
        _assert_send(&upload);
    }
}
620
621#[cfg(test)]
622#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
623mod tests {
624    use super::*;
625    use ant_protocol::evm::{Amount, MerkleTree, RewardsAddress, CANDIDATES_PER_POOL};
626
627    // =========================================================================
628    // should_use_merkle (free function, no Client needed)
629    // =========================================================================
630
631    #[test]
632    fn test_auto_below_threshold() {
633        assert!(!should_use_merkle(1, PaymentMode::Auto));
634        assert!(!should_use_merkle(10, PaymentMode::Auto));
635        assert!(!should_use_merkle(63, PaymentMode::Auto));
636    }
637
638    #[test]
639    fn test_auto_at_and_above_threshold() {
640        assert!(should_use_merkle(64, PaymentMode::Auto));
641        assert!(should_use_merkle(65, PaymentMode::Auto));
642        assert!(should_use_merkle(1000, PaymentMode::Auto));
643    }
644
645    #[test]
646    fn test_merkle_mode_forces_at_2() {
647        assert!(!should_use_merkle(1, PaymentMode::Merkle));
648        assert!(should_use_merkle(2, PaymentMode::Merkle));
649        assert!(should_use_merkle(3, PaymentMode::Merkle));
650    }
651
652    #[test]
653    fn test_single_mode_always_false() {
654        assert!(!should_use_merkle(0, PaymentMode::Single));
655        assert!(!should_use_merkle(64, PaymentMode::Single));
656        assert!(!should_use_merkle(1000, PaymentMode::Single));
657    }
658
659    #[test]
660    fn test_default_mode_is_auto() {
661        assert_eq!(PaymentMode::default(), PaymentMode::Auto);
662    }
663
664    #[test]
665    fn test_threshold_value() {
666        assert_eq!(DEFAULT_MERKLE_THRESHOLD, 64);
667    }
668
669    // =========================================================================
670    // MerkleTree construction and proof generation (pure, no network)
671    // =========================================================================
672
673    fn make_test_addresses(count: usize) -> Vec<[u8; 32]> {
674        (0..count)
675            .map(|i| {
676                let xn = XorName::from_content(&i.to_le_bytes());
677                xn.0
678            })
679            .collect()
680    }
681
682    #[test]
683    fn test_tree_depth_for_known_sizes() {
684        let cases = [(2, 1), (4, 2), (16, 4), (100, 7), (256, 8)];
685        for (count, expected_depth) in cases {
686            let addrs = make_test_addresses(count);
687            let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
688            let tree = MerkleTree::from_xornames(xornames).unwrap();
689            assert_eq!(
690                tree.depth(),
691                expected_depth,
692                "depth mismatch for {count} leaves"
693            );
694        }
695    }
696
697    #[test]
698    fn test_proof_generation_and_verification_for_all_leaves() {
699        let addrs = make_test_addresses(16);
700        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
701        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();
702
703        for (i, xn) in xornames.iter().enumerate() {
704            let proof = tree.generate_address_proof(i, *xn).unwrap();
705            assert!(proof.verify(), "proof for leaf {i} should verify");
706            assert_eq!(proof.depth(), tree.depth() as usize);
707        }
708    }
709
710    #[test]
711    fn test_proof_fails_for_wrong_address() {
712        let addrs = make_test_addresses(8);
713        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
714        let tree = MerkleTree::from_xornames(xornames).unwrap();
715
716        let wrong = XorName::from_content(b"wrong");
717        let proof = tree.generate_address_proof(0, wrong).unwrap();
718        assert!(!proof.verify(), "proof with wrong address should fail");
719    }
720
721    #[test]
722    fn test_tree_too_few_leaves() {
723        let xornames = vec![XorName::from_content(b"only_one")];
724        let result = MerkleTree::from_xornames(xornames);
725        assert!(result.is_err());
726    }
727
728    #[test]
729    fn test_tree_at_max_leaves() {
730        let addrs = make_test_addresses(MAX_LEAVES);
731        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
732        let tree = MerkleTree::from_xornames(xornames).unwrap();
733        assert_eq!(tree.leaf_count(), MAX_LEAVES);
734    }
735
736    // =========================================================================
737    // Proof serialization round-trip
738    // =========================================================================
739
740    #[test]
741    fn test_merkle_proof_serialize_deserialize_roundtrip() {
742        use ant_protocol::evm::{Amount, MerklePaymentCandidateNode, RewardsAddress};
743        use ant_protocol::payment::{deserialize_merkle_proof, serialize_merkle_proof};
744
745        let addrs = make_test_addresses(4);
746        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
747        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();
748
749        let timestamp = std::time::SystemTime::now()
750            .duration_since(std::time::UNIX_EPOCH)
751            .unwrap()
752            .as_secs();
753
754        let candidates = tree.reward_candidates(timestamp).unwrap();
755        let midpoint = candidates.first().unwrap().clone();
756
757        // Build candidate nodes (with dummy signatures — not ML-DSA, just for serialization test)
758        #[allow(clippy::cast_possible_truncation)]
759        let candidate_nodes: [MerklePaymentCandidateNode; CANDIDATES_PER_POOL] =
760            std::array::from_fn(|i| MerklePaymentCandidateNode {
761                pub_key: vec![i as u8; 32],
762                price: Amount::from(1024u64),
763                reward_address: RewardsAddress::new([i as u8; 20]),
764                merkle_payment_timestamp: timestamp,
765                signature: vec![i as u8; 64],
766            });
767
768        let pool = MerklePaymentCandidatePool {
769            midpoint_proof: midpoint,
770            candidate_nodes,
771        };
772
773        let address_proof = tree.generate_address_proof(0, xornames[0]).unwrap();
774        let merkle_proof = MerklePaymentProof::new(xornames[0], address_proof, pool);
775
776        let tagged = serialize_merkle_proof(&merkle_proof).unwrap();
777        assert_eq!(
778            tagged.first().copied(),
779            Some(0x02),
780            "tag should be PROOF_TAG_MERKLE"
781        );
782
783        let deserialized = deserialize_merkle_proof(&tagged).unwrap();
784        assert_eq!(deserialized.address, merkle_proof.address);
785        assert_eq!(
786            deserialized.winner_pool.candidate_nodes.len(),
787            CANDIDATES_PER_POOL
788        );
789    }
790
791    // =========================================================================
792    // Candidate validation logic
793    // =========================================================================
794
795    #[test]
796    fn test_candidate_wrong_timestamp_rejected() {
797        // Simulates what collect_validated_candidates checks
798        let candidate = MerklePaymentCandidateNode {
799            pub_key: vec![0u8; 32],
800            price: ant_protocol::evm::Amount::ZERO,
801            reward_address: ant_protocol::evm::RewardsAddress::new([0u8; 20]),
802            merkle_payment_timestamp: 1000,
803            signature: vec![0u8; 64],
804        };
805
806        // Timestamp check: 1000 != 2000
807        assert_ne!(candidate.merkle_payment_timestamp, 2000);
808    }
809
810    // =========================================================================
811    // finalize_merkle_batch (external signer)
812    // =========================================================================
813
814    fn make_dummy_candidate_nodes(
815        timestamp: u64,
816    ) -> [MerklePaymentCandidateNode; CANDIDATES_PER_POOL] {
817        std::array::from_fn(|i| MerklePaymentCandidateNode {
818            pub_key: vec![i as u8; 32],
819            price: Amount::from(1024u64),
820            reward_address: RewardsAddress::new([i as u8; 20]),
821            merkle_payment_timestamp: timestamp,
822            signature: vec![i as u8; 64],
823        })
824    }
825
826    fn make_prepared_merkle_batch(count: usize) -> PreparedMerkleBatch {
827        let addrs = make_test_addresses(count);
828        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
829        let tree = MerkleTree::from_xornames(xornames).unwrap();
830
831        let timestamp = std::time::SystemTime::now()
832            .duration_since(std::time::UNIX_EPOCH)
833            .unwrap()
834            .as_secs();
835
836        let midpoints = tree.reward_candidates(timestamp).unwrap();
837
838        let candidate_pools: Vec<MerklePaymentCandidatePool> = midpoints
839            .into_iter()
840            .map(|mp| MerklePaymentCandidatePool {
841                midpoint_proof: mp,
842                candidate_nodes: make_dummy_candidate_nodes(timestamp),
843            })
844            .collect();
845
846        let pool_commitments = candidate_pools
847            .iter()
848            .map(MerklePaymentCandidatePool::to_commitment)
849            .collect();
850
851        PreparedMerkleBatch {
852            depth: tree.depth(),
853            pool_commitments,
854            merkle_payment_timestamp: timestamp,
855            candidate_pools,
856            tree,
857            addresses: addrs,
858        }
859    }
860
861    #[test]
862    fn test_finalize_merkle_batch_with_valid_winner() {
863        let prepared = make_prepared_merkle_batch(4);
864        let winner_hash = prepared.candidate_pools[0].hash();
865
866        let result = finalize_merkle_batch(prepared, winner_hash);
867        assert!(
868            result.is_ok(),
869            "should succeed with valid winner: {result:?}"
870        );
871
872        let batch = result.unwrap();
873        assert_eq!(batch.chunk_count, 4);
874        assert_eq!(batch.proofs.len(), 4);
875
876        // Every proof should be non-empty
877        for proof_bytes in batch.proofs.values() {
878            assert!(!proof_bytes.is_empty());
879        }
880    }
881
882    #[test]
883    fn test_finalize_merkle_batch_with_invalid_winner() {
884        let prepared = make_prepared_merkle_batch(4);
885        let bad_hash = [0xFF; 32];
886
887        let result = finalize_merkle_batch(prepared, bad_hash);
888        assert!(result.is_err());
889        let err = result.unwrap_err().to_string();
890        assert!(err.contains("not found in candidate pools"), "got: {err}");
891    }
892
893    #[test]
894    fn test_finalize_merkle_batch_proofs_are_deserializable() {
895        use ant_protocol::payment::deserialize_merkle_proof;
896
897        let prepared = make_prepared_merkle_batch(8);
898        let winner_hash = prepared.candidate_pools[0].hash();
899
900        let batch = finalize_merkle_batch(prepared, winner_hash).unwrap();
901
902        for (addr, proof_bytes) in &batch.proofs {
903            let proof = deserialize_merkle_proof(proof_bytes);
904            assert!(
905                proof.is_ok(),
906                "proof for {} should deserialize: {:?}",
907                hex::encode(addr),
908                proof.err()
909            );
910        }
911    }
912
913    // =========================================================================
914    // Batch splitting edge cases
915    // =========================================================================
916
917    #[test]
918    fn test_batch_split_calculation() {
919        // MAX_LEAVES chunks should fit in 1 batch
920        let addrs = make_test_addresses(MAX_LEAVES);
921        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 1);
922
923        // MAX_LEAVES + 1 should split into 2
924        let addrs = make_test_addresses(MAX_LEAVES + 1);
925        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 2);
926
927        // 3 * MAX_LEAVES should split into 3
928        let addrs = make_test_addresses(3 * MAX_LEAVES);
929        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 3);
930    }
931}