1use crate::data::client::adaptive::observe_op;
8use crate::data::client::classify_error;
9use crate::data::client::file::UploadEvent;
10use crate::data::client::Client;
11use crate::data::error::{Error, Result};
12use ant_protocol::evm::{
13 Amount, MerklePaymentCandidateNode, MerklePaymentCandidatePool, MerklePaymentProof, MerkleTree,
14 MidpointProof, PoolCommitment, CANDIDATES_PER_POOL, MAX_LEAVES,
15};
16use ant_protocol::payment::{serialize_merkle_proof, verify_merkle_candidate_signature};
17use ant_protocol::transport::PeerId;
18use ant_protocol::{
19 send_and_await_chunk_response, ChunkMessage, ChunkMessageBody, MerkleCandidateQuoteRequest,
20 MerkleCandidateQuoteResponse,
21};
22use bytes::Bytes;
23use futures::stream::{self, FuturesUnordered, StreamExt};
24use std::collections::HashMap;
25use std::time::Duration;
26use tokio::sync::mpsc;
27use tracing::{debug, info, warn};
28use xor_name::XorName;
29
30pub const DEFAULT_MERKLE_THRESHOLD: usize = 64;
32
33#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
35#[serde(rename_all = "snake_case")]
36pub enum PaymentMode {
37 #[default]
39 Auto,
40 Merkle,
42 Single,
44}
45
/// Outcome of a merkle batch payment: one serialized proof per paid chunk
/// plus the aggregated costs.
#[derive(Debug)]
pub struct MerkleBatchPaymentResult {
    /// Serialized, tagged merkle payment proof for each paid chunk, keyed by
    /// the chunk's 32-byte address.
    pub proofs: HashMap<[u8; 32], Vec<u8>>,
    /// Number of chunks this result covers.
    pub chunk_count: usize,
    /// Storage cost paid, as the string form of the paid `Amount` (atto units).
    pub storage_cost_atto: String,
    /// Gas cost of the payment transaction(s), in wei.
    pub gas_cost_wei: u128,
}
58
59pub struct PreparedMerkleBatch {
64 pub depth: u8,
66 pub pool_commitments: Vec<PoolCommitment>,
68 pub merkle_payment_timestamp: u64,
70 candidate_pools: Vec<MerklePaymentCandidatePool>,
72 tree: MerkleTree,
74 addresses: Vec<[u8; 32]>,
76}
77
78impl std::fmt::Debug for PreparedMerkleBatch {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("PreparedMerkleBatch")
81 .field("depth", &self.depth)
82 .field("pool_commitments", &self.pool_commitments.len())
83 .field("merkle_payment_timestamp", &self.merkle_payment_timestamp)
84 .field("candidate_pools", &self.candidate_pools.len())
85 .field("addresses", &self.addresses.len())
86 .finish()
87 }
88}
89
90#[must_use]
93pub fn should_use_merkle(chunk_count: usize, mode: PaymentMode) -> bool {
94 match mode {
95 PaymentMode::Auto => chunk_count >= DEFAULT_MERKLE_THRESHOLD,
96 PaymentMode::Merkle => chunk_count >= 2,
97 PaymentMode::Single => false,
98 }
99}
100
101impl Client {
102 #[must_use]
104 pub fn should_use_merkle(&self, chunk_count: usize, mode: PaymentMode) -> bool {
105 should_use_merkle(chunk_count, mode)
106 }
107
108 pub async fn pay_for_merkle_batch(
121 &self,
122 addresses: &[[u8; 32]],
123 data_type: u32,
124 data_size: u64,
125 ) -> Result<MerkleBatchPaymentResult> {
126 let chunk_count = addresses.len();
127 if chunk_count < 2 {
128 return Err(Error::Payment(
129 "Merkle batch payment requires at least 2 chunks".to_string(),
130 ));
131 }
132
133 if chunk_count > MAX_LEAVES {
134 return self
135 .pay_for_merkle_multi_batch(addresses, data_type, data_size)
136 .await;
137 }
138
139 self.pay_for_merkle_single_batch(addresses, data_type, data_size)
140 .await
141 }
142
143 pub async fn prepare_merkle_batch_external(
149 &self,
150 addresses: &[[u8; 32]],
151 data_type: u32,
152 data_size: u64,
153 ) -> Result<PreparedMerkleBatch> {
154 let chunk_count = addresses.len();
155 let xornames: Vec<XorName> = addresses.iter().map(|a| XorName(*a)).collect();
156
157 debug!("Building merkle tree for {chunk_count} chunks");
158
159 let tree = MerkleTree::from_xornames(xornames)
161 .map_err(|e| Error::Payment(format!("Failed to build merkle tree: {e}")))?;
162
163 let depth = tree.depth();
164 let merkle_payment_timestamp = std::time::SystemTime::now()
165 .duration_since(std::time::UNIX_EPOCH)
166 .map_err(|e| Error::Payment(format!("System time error: {e}")))?
167 .as_secs();
168
169 debug!("Merkle tree: depth={depth}, leaves={chunk_count}, ts={merkle_payment_timestamp}");
170
171 let midpoint_proofs = tree
173 .reward_candidates(merkle_payment_timestamp)
174 .map_err(|e| Error::Payment(format!("Failed to generate reward candidates: {e}")))?;
175
176 debug!(
177 "Collecting candidate pools from {} midpoints (concurrent)",
178 midpoint_proofs.len()
179 );
180
181 let candidate_pools = self
183 .build_candidate_pools(
184 &midpoint_proofs,
185 data_type,
186 data_size,
187 merkle_payment_timestamp,
188 )
189 .await?;
190
191 let pool_commitments: Vec<PoolCommitment> = candidate_pools
193 .iter()
194 .map(MerklePaymentCandidatePool::to_commitment)
195 .collect();
196
197 Ok(PreparedMerkleBatch {
198 depth,
199 pool_commitments,
200 merkle_payment_timestamp,
201 candidate_pools,
202 tree,
203 addresses: addresses.to_vec(),
204 })
205 }
206
207 async fn pay_for_merkle_single_batch(
209 &self,
210 addresses: &[[u8; 32]],
211 data_type: u32,
212 data_size: u64,
213 ) -> Result<MerkleBatchPaymentResult> {
214 let wallet = self.require_wallet()?;
215 let prepared = self
216 .prepare_merkle_batch_external(addresses, data_type, data_size)
217 .await?;
218
219 info!(
220 "Submitting merkle batch payment on-chain (depth={})",
221 prepared.depth
222 );
223 let (winner_pool_hash, amount, gas_info) = wallet
224 .pay_for_merkle_tree(
225 prepared.depth,
226 prepared.pool_commitments.clone(),
227 prepared.merkle_payment_timestamp,
228 )
229 .await
230 .map_err(|e| Error::Payment(format!("Merkle batch payment failed: {e}")))?;
231
232 info!(
233 "Merkle payment succeeded: winner pool {}",
234 hex::encode(winner_pool_hash)
235 );
236
237 let mut result = finalize_merkle_batch(prepared, winner_pool_hash)?;
238 result.storage_cost_atto = amount.to_string();
239 result.gas_cost_wei = gas_info.gas_cost_wei;
240 Ok(result)
241 }
242
243 async fn pay_for_merkle_multi_batch(
245 &self,
246 addresses: &[[u8; 32]],
247 data_type: u32,
248 data_size: u64,
249 ) -> Result<MerkleBatchPaymentResult> {
250 let sub_batches: Vec<&[[u8; 32]]> = addresses.chunks(MAX_LEAVES).collect();
251 let total_sub_batches = sub_batches.len();
252 let mut all_proofs = HashMap::with_capacity(addresses.len());
253 let mut total_storage = Amount::ZERO;
254 let mut total_gas: u128 = 0;
255
256 for (i, chunk) in sub_batches.into_iter().enumerate() {
257 match self
258 .pay_for_merkle_single_batch(chunk, data_type, data_size)
259 .await
260 {
261 Ok(sub_result) => {
262 if let Ok(cost) = sub_result.storage_cost_atto.parse::<Amount>() {
263 total_storage += cost;
264 }
265 total_gas = total_gas.saturating_add(sub_result.gas_cost_wei);
266 all_proofs.extend(sub_result.proofs);
267 }
268 Err(e) => {
269 if all_proofs.is_empty() {
270 return Err(e);
272 }
273 warn!(
275 "Merkle sub-batch {}/{total_sub_batches} failed: {e}. \
276 Returning {} proofs from prior sub-batches",
277 i + 1,
278 all_proofs.len()
279 );
280 return Ok(MerkleBatchPaymentResult {
281 chunk_count: all_proofs.len(),
282 proofs: all_proofs,
283 storage_cost_atto: total_storage.to_string(),
284 gas_cost_wei: total_gas,
285 });
286 }
287 }
288 }
289
290 Ok(MerkleBatchPaymentResult {
291 chunk_count: addresses.len(),
292 proofs: all_proofs,
293 storage_cost_atto: total_storage.to_string(),
294 gas_cost_wei: total_gas,
295 })
296 }
297
298 async fn build_candidate_pools(
300 &self,
301 midpoint_proofs: &[MidpointProof],
302 data_type: u32,
303 data_size: u64,
304 merkle_payment_timestamp: u64,
305 ) -> Result<Vec<MerklePaymentCandidatePool>> {
306 let mut pool_futures = FuturesUnordered::new();
307
308 for midpoint_proof in midpoint_proofs {
309 let pool_address = midpoint_proof.address();
310 let mp = midpoint_proof.clone();
311 pool_futures.push(async move {
312 let candidate_nodes = self
313 .get_merkle_candidate_pool(
314 &pool_address.0,
315 data_type,
316 data_size,
317 merkle_payment_timestamp,
318 )
319 .await?;
320 Ok::<_, Error>(MerklePaymentCandidatePool {
321 midpoint_proof: mp,
322 candidate_nodes,
323 })
324 });
325 }
326
327 let mut pools = Vec::with_capacity(midpoint_proofs.len());
328 while let Some(result) = pool_futures.next().await {
329 pools.push(result?);
330 }
331
332 Ok(pools)
333 }
334
335 #[allow(clippy::too_many_lines)]
337 async fn get_merkle_candidate_pool(
338 &self,
339 address: &[u8; 32],
340 data_type: u32,
341 data_size: u64,
342 merkle_payment_timestamp: u64,
343 ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
344 let node = self.network().node();
345 let timeout = Duration::from_secs(self.config().quote_timeout_secs);
346
347 let query_count = CANDIDATES_PER_POOL * 2;
349 let mut remote_peers = self
350 .network()
351 .find_closest_peers(address, query_count)
352 .await?;
353
354 if remote_peers.len() < CANDIDATES_PER_POOL {
358 let connected = self.network().connected_peers().await;
359 for peer in connected {
360 if !remote_peers.iter().any(|(id, _)| *id == peer) {
361 remote_peers.push((peer, vec![]));
362 }
363 }
364 }
365
366 if remote_peers.len() < CANDIDATES_PER_POOL {
367 return Err(Error::InsufficientPeers(format!(
368 "Found {} peers, need {CANDIDATES_PER_POOL} for merkle candidate pool. \
369 Use --no-merkle or a larger network.",
370 remote_peers.len()
371 )));
372 }
373
374 let mut candidate_futures = FuturesUnordered::new();
375
376 for (peer_id, peer_addrs) in &remote_peers {
377 let request_id = self.next_request_id();
378 let request = MerkleCandidateQuoteRequest {
379 address: *address,
380 data_type,
381 data_size,
382 merkle_payment_timestamp,
383 };
384 let message = ChunkMessage {
385 request_id,
386 body: ChunkMessageBody::MerkleCandidateQuoteRequest(request),
387 };
388
389 let message_bytes = match message.encode() {
390 Ok(bytes) => bytes,
391 Err(e) => {
392 warn!("Failed to encode merkle candidate request for {peer_id}: {e}");
393 continue;
394 }
395 };
396
397 let peer_id_clone = *peer_id;
398 let addrs_clone = peer_addrs.clone();
399 let node_clone = node.clone();
400
401 let fut = async move {
402 let result = send_and_await_chunk_response(
403 &node_clone,
404 &peer_id_clone,
405 message_bytes,
406 request_id,
407 timeout,
408 &addrs_clone,
409 |body| match body {
410 ChunkMessageBody::MerkleCandidateQuoteResponse(
411 MerkleCandidateQuoteResponse::Success { candidate_node },
412 ) => {
413 match rmp_serde::from_slice::<MerklePaymentCandidateNode>(
414 &candidate_node,
415 ) {
416 Ok(node) => Some(Ok(node)),
417 Err(e) => Some(Err(Error::Serialization(format!(
418 "Failed to deserialize candidate node from {peer_id_clone}: {e}"
419 )))),
420 }
421 }
422 ChunkMessageBody::MerkleCandidateQuoteResponse(
423 MerkleCandidateQuoteResponse::Error(e),
424 ) => Some(Err(Error::Protocol(format!(
425 "Merkle quote error from {peer_id_clone}: {e}"
426 )))),
427 _ => None,
428 },
429 |e| {
430 Error::Network(format!(
431 "Failed to send merkle candidate request to {peer_id_clone}: {e}"
432 ))
433 },
434 || {
435 Error::Timeout(format!(
436 "Timeout waiting for merkle candidate from {peer_id_clone}"
437 ))
438 },
439 )
440 .await;
441
442 (peer_id_clone, result)
443 };
444
445 candidate_futures.push(fut);
446 }
447
448 self.collect_validated_candidates(&mut candidate_futures, address, merkle_payment_timestamp)
449 .await
450 }
451
452 async fn collect_validated_candidates(
463 &self,
464 futures: &mut FuturesUnordered<
465 impl std::future::Future<
466 Output = (
467 PeerId,
468 std::result::Result<MerklePaymentCandidateNode, Error>,
469 ),
470 >,
471 >,
472 target_address: &[u8; 32],
473 merkle_payment_timestamp: u64,
474 ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
475 let mut valid: Vec<(PeerId, MerklePaymentCandidateNode)> = Vec::new();
476 let mut failures: Vec<String> = Vec::new();
477
478 while let Some((peer_id, result)) = futures.next().await {
479 match result {
480 Ok(candidate) => {
481 if !verify_merkle_candidate_signature(&candidate) {
482 warn!("Invalid ML-DSA-65 signature from merkle candidate {peer_id}");
483 failures.push(format!("{peer_id}: invalid signature"));
484 continue;
485 }
486 if candidate.merkle_payment_timestamp != merkle_payment_timestamp {
487 warn!("Timestamp mismatch from merkle candidate {peer_id}");
488 failures.push(format!("{peer_id}: timestamp mismatch"));
489 continue;
490 }
491 valid.push((peer_id, candidate));
492 }
493 Err(e) => {
494 debug!("Failed to get merkle candidate from {peer_id}: {e}");
495 failures.push(format!("{peer_id}: {e}"));
496 }
497 }
498 }
499
500 if valid.len() < CANDIDATES_PER_POOL {
501 return Err(Error::InsufficientPeers(format!(
502 "Got {} merkle candidates, need {CANDIDATES_PER_POOL}. Failures: [{}]",
503 valid.len(),
504 failures.join("; ")
505 )));
506 }
507
508 let target_peer = PeerId::from_bytes(*target_address);
509 valid.sort_by_key(|(peer_id, _)| peer_id.xor_distance(&target_peer));
510
511 let candidates: Vec<MerklePaymentCandidateNode> = valid
512 .into_iter()
513 .take(CANDIDATES_PER_POOL)
514 .map(|(_, candidate)| candidate)
515 .collect();
516
517 candidates
518 .try_into()
519 .map_err(|_| Error::Payment("Failed to convert candidates to fixed array".to_string()))
520 }
521
522 pub(crate) async fn merkle_upload_chunks(
532 &self,
533 chunk_contents: Vec<Bytes>,
534 addresses: Vec<[u8; 32]>,
535 batch_result: &MerkleBatchPaymentResult,
536 progress: Option<&mpsc::Sender<UploadEvent>>,
537 ) -> Result<usize> {
538 let mut stored = 0usize;
539 let store_limiter = self.controller().store.clone();
540 let batch_size = chunk_contents.len();
543 let store_concurrency = store_limiter.current().min(batch_size.max(1));
544 let mut upload_stream = stream::iter(chunk_contents.into_iter().zip(addresses).map(
545 |(content, addr)| {
546 let proof_bytes = batch_result.proofs.get(&addr).cloned();
547 let limiter = store_limiter.clone();
548 async move {
549 let proof = proof_bytes.ok_or_else(|| {
550 Error::Payment(format!(
551 "Missing merkle proof for chunk {}",
552 hex::encode(addr)
553 ))
554 })?;
555 let peers = self.close_group_peers(&addr).await?;
556 observe_op(
557 &limiter,
558 || async move {
559 self.chunk_put_to_close_group(content, proof, &peers).await
560 },
561 classify_error,
562 )
563 .await
564 }
565 },
566 ))
567 .buffer_unordered(store_concurrency);
568
569 while let Some(result) = upload_stream.next().await {
570 result?;
571 stored += 1;
572 if let Some(tx) = progress {
573 let _ = tx.try_send(UploadEvent::ChunkStored {
574 stored,
575 total: batch_size,
576 });
577 }
578 }
579
580 Ok(stored)
581 }
582}
583
584pub fn finalize_merkle_batch(
589 prepared: PreparedMerkleBatch,
590 winner_pool_hash: [u8; 32],
591) -> Result<MerkleBatchPaymentResult> {
592 let chunk_count = prepared.addresses.len();
593 let xornames: Vec<XorName> = prepared.addresses.iter().map(|a| XorName(*a)).collect();
594
595 let winner_pool = prepared
597 .candidate_pools
598 .iter()
599 .find(|pool| pool.hash() == winner_pool_hash)
600 .ok_or_else(|| {
601 Error::Payment(format!(
602 "Winner pool {} not found in candidate pools",
603 hex::encode(winner_pool_hash)
604 ))
605 })?;
606
607 info!("Generating merkle proofs for {chunk_count} chunks");
609 let mut proofs = HashMap::with_capacity(chunk_count);
610
611 for (i, xorname) in xornames.iter().enumerate() {
612 let address_proof = prepared
613 .tree
614 .generate_address_proof(i, *xorname)
615 .map_err(|e| {
616 Error::Payment(format!(
617 "Failed to generate address proof for chunk {i}: {e}"
618 ))
619 })?;
620
621 let merkle_proof = MerklePaymentProof::new(*xorname, address_proof, winner_pool.clone());
622
623 let tagged_bytes = serialize_merkle_proof(&merkle_proof)
624 .map_err(|e| Error::Serialization(format!("Failed to serialize merkle proof: {e}")))?;
625
626 proofs.insert(prepared.addresses[i], tagged_bytes);
627 }
628
629 info!("Merkle batch payment complete: {chunk_count} proofs generated");
630
631 Ok(MerkleBatchPaymentResult {
632 proofs,
633 chunk_count,
634 storage_cost_atto: "0".to_string(),
635 gas_cost_wei: 0,
636 })
637}
638
/// Compile-time checks that futures returned by merkle helpers are `Send`, so
/// they can be spawned on a multi-threaded runtime. Nothing here is executed.
#[cfg(test)]
mod send_assertions {
    use super::*;
    use crate::data::client::Client;

    /// Accepts only `Send` values; calling it is the assertion.
    fn _require_send<T: Send>(_value: &T) {}

    #[allow(
        dead_code,
        unreachable_code,
        unused_variables,
        clippy::diverging_sub_expression
    )]
    async fn _merkle_upload_chunks_is_send(client: &Client) {
        // Never run: `todo!()` only supplies a value of the right type.
        let batch_result: MerkleBatchPaymentResult = todo!();
        let upload = client.merkle_upload_chunks(Vec::new(), Vec::new(), &batch_result, None);
        _require_send(&upload);
    }
}
659
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use ant_protocol::evm::{Amount, MerkleTree, RewardsAddress, CANDIDATES_PER_POOL};
    use ant_protocol::payment::deserialize_merkle_proof;

    // ---------------------------------------------------------------------
    // should_use_merkle / PaymentMode
    // ---------------------------------------------------------------------

    #[test]
    fn test_auto_below_threshold() {
        assert!(!should_use_merkle(1, PaymentMode::Auto));
        assert!(!should_use_merkle(10, PaymentMode::Auto));
        assert!(!should_use_merkle(63, PaymentMode::Auto));
    }

    #[test]
    fn test_auto_at_and_above_threshold() {
        assert!(should_use_merkle(64, PaymentMode::Auto));
        assert!(should_use_merkle(65, PaymentMode::Auto));
        assert!(should_use_merkle(1000, PaymentMode::Auto));
    }

    #[test]
    fn test_merkle_mode_forces_at_2() {
        assert!(!should_use_merkle(1, PaymentMode::Merkle));
        assert!(should_use_merkle(2, PaymentMode::Merkle));
        assert!(should_use_merkle(3, PaymentMode::Merkle));
    }

    #[test]
    fn test_single_mode_always_false() {
        assert!(!should_use_merkle(0, PaymentMode::Single));
        assert!(!should_use_merkle(64, PaymentMode::Single));
        assert!(!should_use_merkle(1000, PaymentMode::Single));
    }

    #[test]
    fn test_default_mode_is_auto() {
        assert_eq!(PaymentMode::default(), PaymentMode::Auto);
    }

    #[test]
    fn test_threshold_value() {
        assert_eq!(DEFAULT_MERKLE_THRESHOLD, 64);
    }

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Returns `count` deterministic 32-byte addresses: the XOR names of the
    /// little-endian encodings of `0..count`.
    fn make_test_addresses(count: usize) -> Vec<[u8; 32]> {
        (0..count)
            .map(|i| XorName::from_content(&i.to_le_bytes()).0)
            .collect()
    }

    /// Wraps raw addresses as `XorName`s, preserving order.
    fn to_xornames(addrs: &[[u8; 32]]) -> Vec<XorName> {
        addrs.iter().map(|a| XorName(*a)).collect()
    }

    /// Current Unix time in whole seconds.
    fn unix_now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// A full pool of placeholder candidates stamped with `timestamp`.
    ///
    /// Keys, reward addresses and signatures are filler bytes derived from the
    /// candidate index.
    // The `as u8` casts only produce filler bytes, so truncation is harmless.
    #[allow(clippy::cast_possible_truncation)]
    fn make_dummy_candidate_nodes(
        timestamp: u64,
    ) -> [MerklePaymentCandidateNode; CANDIDATES_PER_POOL] {
        std::array::from_fn(|i| MerklePaymentCandidateNode {
            pub_key: vec![i as u8; 32],
            price: Amount::from(1024u64),
            reward_address: RewardsAddress::new([i as u8; 20]),
            merkle_payment_timestamp: timestamp,
            signature: vec![i as u8; 64],
        })
    }

    /// Builds a `PreparedMerkleBatch` over `count` test addresses.
    ///
    /// The batch has one dummy candidate pool per reward midpoint, so it can
    /// be finalized without a network.
    fn make_prepared_merkle_batch(count: usize) -> PreparedMerkleBatch {
        let addrs = make_test_addresses(count);
        let tree = MerkleTree::from_xornames(to_xornames(&addrs)).unwrap();
        let timestamp = unix_now_secs();

        let candidate_pools: Vec<MerklePaymentCandidatePool> = tree
            .reward_candidates(timestamp)
            .unwrap()
            .into_iter()
            .map(|mp| MerklePaymentCandidatePool {
                midpoint_proof: mp,
                candidate_nodes: make_dummy_candidate_nodes(timestamp),
            })
            .collect();

        let pool_commitments = candidate_pools
            .iter()
            .map(MerklePaymentCandidatePool::to_commitment)
            .collect();

        PreparedMerkleBatch {
            depth: tree.depth(),
            pool_commitments,
            merkle_payment_timestamp: timestamp,
            candidate_pools,
            tree,
            addresses: addrs,
        }
    }

    // ---------------------------------------------------------------------
    // Merkle tree behaviour
    // ---------------------------------------------------------------------

    #[test]
    fn test_tree_depth_for_known_sizes() {
        let cases = [(2, 1), (4, 2), (16, 4), (100, 7), (256, 8)];
        for (count, expected_depth) in cases {
            let tree = MerkleTree::from_xornames(to_xornames(&make_test_addresses(count)))
                .unwrap();
            assert_eq!(
                tree.depth(),
                expected_depth,
                "depth mismatch for {count} leaves"
            );
        }
    }

    #[test]
    fn test_proof_generation_and_verification_for_all_leaves() {
        let xornames = to_xornames(&make_test_addresses(16));
        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();

        for (i, xn) in xornames.iter().enumerate() {
            let proof = tree.generate_address_proof(i, *xn).unwrap();
            assert!(proof.verify(), "proof for leaf {i} should verify");
            assert_eq!(proof.depth(), tree.depth() as usize);
        }
    }

    #[test]
    fn test_proof_fails_for_wrong_address() {
        let tree = MerkleTree::from_xornames(to_xornames(&make_test_addresses(8))).unwrap();

        let wrong = XorName::from_content(b"wrong");
        let proof = tree.generate_address_proof(0, wrong).unwrap();
        assert!(!proof.verify(), "proof with wrong address should fail");
    }

    #[test]
    fn test_tree_too_few_leaves() {
        let xornames = vec![XorName::from_content(b"only_one")];
        let result = MerkleTree::from_xornames(xornames);
        assert!(result.is_err());
    }

    #[test]
    fn test_tree_at_max_leaves() {
        let tree =
            MerkleTree::from_xornames(to_xornames(&make_test_addresses(MAX_LEAVES))).unwrap();
        assert_eq!(tree.leaf_count(), MAX_LEAVES);
    }

    // ---------------------------------------------------------------------
    // Proof serialization
    // ---------------------------------------------------------------------

    #[test]
    fn test_merkle_proof_serialize_deserialize_roundtrip() {
        let xornames = to_xornames(&make_test_addresses(4));
        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();
        let timestamp = unix_now_secs();

        let candidates = tree.reward_candidates(timestamp).unwrap();
        let midpoint = candidates.first().unwrap().clone();

        let pool = MerklePaymentCandidatePool {
            midpoint_proof: midpoint,
            candidate_nodes: make_dummy_candidate_nodes(timestamp),
        };

        let address_proof = tree.generate_address_proof(0, xornames[0]).unwrap();
        let merkle_proof = MerklePaymentProof::new(xornames[0], address_proof, pool);

        let tagged = serialize_merkle_proof(&merkle_proof).unwrap();
        assert_eq!(
            tagged.first().copied(),
            Some(0x02),
            "tag should be PROOF_TAG_MERKLE"
        );

        let deserialized = deserialize_merkle_proof(&tagged).unwrap();
        assert_eq!(deserialized.address, merkle_proof.address);
        assert_eq!(
            deserialized.winner_pool.candidate_nodes.len(),
            CANDIDATES_PER_POOL
        );
    }

    /// `collect_validated_candidates` drops any candidate whose
    /// `merkle_payment_timestamp` differs from the requested one.
    ///
    /// Exercising that method needs a live `Client`, so this test only pins
    /// the data the check compares: a candidate keeps the timestamp it was
    /// built with, which differs from another request timestamp.
    #[test]
    fn test_candidate_wrong_timestamp_rejected() {
        let requested_timestamp = 2000;
        let candidate = MerklePaymentCandidateNode {
            pub_key: vec![0u8; 32],
            price: Amount::ZERO,
            reward_address: RewardsAddress::new([0u8; 20]),
            merkle_payment_timestamp: 1000,
            signature: vec![0u8; 64],
        };

        assert_eq!(candidate.merkle_payment_timestamp, 1000);
        assert_ne!(candidate.merkle_payment_timestamp, requested_timestamp);
    }

    // ---------------------------------------------------------------------
    // finalize_merkle_batch
    // ---------------------------------------------------------------------

    #[test]
    fn test_finalize_merkle_batch_with_valid_winner() {
        let prepared = make_prepared_merkle_batch(4);
        let winner_hash = prepared.candidate_pools[0].hash();

        let result = finalize_merkle_batch(prepared, winner_hash);
        assert!(
            result.is_ok(),
            "should succeed with valid winner: {result:?}"
        );

        let batch = result.unwrap();
        assert_eq!(batch.chunk_count, 4);
        assert_eq!(batch.proofs.len(), 4);

        for proof_bytes in batch.proofs.values() {
            assert!(!proof_bytes.is_empty());
        }
    }

    #[test]
    fn test_finalize_merkle_batch_with_invalid_winner() {
        let prepared = make_prepared_merkle_batch(4);
        let bad_hash = [0xFF; 32];

        let result = finalize_merkle_batch(prepared, bad_hash);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not found in candidate pools"), "got: {err}");
    }

    #[test]
    fn test_finalize_merkle_batch_proofs_are_deserializable() {
        let prepared = make_prepared_merkle_batch(8);
        let winner_hash = prepared.candidate_pools[0].hash();

        let batch = finalize_merkle_batch(prepared, winner_hash).unwrap();

        for (addr, proof_bytes) in &batch.proofs {
            let proof = deserialize_merkle_proof(proof_bytes);
            assert!(
                proof.is_ok(),
                "proof for {} should deserialize: {:?}",
                hex::encode(addr),
                proof.err()
            );
        }
    }

    /// Pins how `slice::chunks` groups addresses into `MAX_LEAVES`-sized
    /// batches.
    ///
    /// In the `MAX_LEAVES + 1` case, plain chunking leaves a one-element
    /// trailing group. That group is too small to build a merkle tree from
    /// (see `test_tree_too_few_leaves`).
    #[test]
    fn test_batch_split_calculation() {
        let addrs = make_test_addresses(MAX_LEAVES);
        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 1);

        let addrs = make_test_addresses(MAX_LEAVES + 1);
        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 2);
        assert_eq!(addrs.chunks(MAX_LEAVES).next_back().unwrap().len(), 1);

        let addrs = make_test_addresses(3 * MAX_LEAVES);
        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 3);
    }
}