1use crate::data::client::Client;
8use crate::data::error::{Error, Result};
9use ant_protocol::evm::{
10 Amount, MerklePaymentCandidateNode, MerklePaymentCandidatePool, MerklePaymentProof, MerkleTree,
11 MidpointProof, PoolCommitment, CANDIDATES_PER_POOL, MAX_LEAVES,
12};
13use ant_protocol::payment::{serialize_merkle_proof, verify_merkle_candidate_signature};
14use ant_protocol::transport::PeerId;
15use ant_protocol::{
16 send_and_await_chunk_response, ChunkMessage, ChunkMessageBody, MerkleCandidateQuoteRequest,
17 MerkleCandidateQuoteResponse,
18};
19use bytes::Bytes;
20use futures::stream::{self, FuturesUnordered, StreamExt};
21use std::collections::HashMap;
22use std::time::Duration;
23use tracing::{debug, info, warn};
24use xor_name::XorName;
25
/// Chunk count at or above which [`PaymentMode::Auto`] switches from
/// per-chunk payment to a merkle batch payment.
pub const DEFAULT_MERKLE_THRESHOLD: usize = 64;
28
29#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum PaymentMode {
33 #[default]
35 Auto,
36 Merkle,
38 Single,
40}
41
/// Outcome of paying for a set of chunks with merkle batch payment.
#[derive(Debug)]
pub struct MerkleBatchPaymentResult {
    /// Tagged, serialized merkle payment proof per paid chunk, keyed by the
    /// chunk's 32-byte address.
    pub proofs: HashMap<[u8; 32], Vec<u8>>,
    /// Number of chunks this result reports as paid.
    pub chunk_count: usize,
    /// Storage cost in atto, in the textual form produced by the amount's
    /// `to_string()`.
    pub storage_cost_atto: String,
    /// Gas spent on the payment transaction(s), in wei.
    pub gas_cost_wei: u128,
}
54
55pub struct PreparedMerkleBatch {
60 pub depth: u8,
62 pub pool_commitments: Vec<PoolCommitment>,
64 pub merkle_payment_timestamp: u64,
66 candidate_pools: Vec<MerklePaymentCandidatePool>,
68 tree: MerkleTree,
70 addresses: Vec<[u8; 32]>,
72}
73
74impl std::fmt::Debug for PreparedMerkleBatch {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 f.debug_struct("PreparedMerkleBatch")
77 .field("depth", &self.depth)
78 .field("pool_commitments", &self.pool_commitments.len())
79 .field("merkle_payment_timestamp", &self.merkle_payment_timestamp)
80 .field("candidate_pools", &self.candidate_pools.len())
81 .field("addresses", &self.addresses.len())
82 .finish()
83 }
84}
85
86#[must_use]
89pub fn should_use_merkle(chunk_count: usize, mode: PaymentMode) -> bool {
90 match mode {
91 PaymentMode::Auto => chunk_count >= DEFAULT_MERKLE_THRESHOLD,
92 PaymentMode::Merkle => chunk_count >= 2,
93 PaymentMode::Single => false,
94 }
95}
96
97impl Client {
98 #[must_use]
100 pub fn should_use_merkle(&self, chunk_count: usize, mode: PaymentMode) -> bool {
101 should_use_merkle(chunk_count, mode)
102 }
103
104 pub async fn pay_for_merkle_batch(
117 &self,
118 addresses: &[[u8; 32]],
119 data_type: u32,
120 data_size: u64,
121 ) -> Result<MerkleBatchPaymentResult> {
122 let chunk_count = addresses.len();
123 if chunk_count < 2 {
124 return Err(Error::Payment(
125 "Merkle batch payment requires at least 2 chunks".to_string(),
126 ));
127 }
128
129 if chunk_count > MAX_LEAVES {
130 return self
131 .pay_for_merkle_multi_batch(addresses, data_type, data_size)
132 .await;
133 }
134
135 self.pay_for_merkle_single_batch(addresses, data_type, data_size)
136 .await
137 }
138
139 pub async fn prepare_merkle_batch_external(
145 &self,
146 addresses: &[[u8; 32]],
147 data_type: u32,
148 data_size: u64,
149 ) -> Result<PreparedMerkleBatch> {
150 let chunk_count = addresses.len();
151 let xornames: Vec<XorName> = addresses.iter().map(|a| XorName(*a)).collect();
152
153 debug!("Building merkle tree for {chunk_count} chunks");
154
155 let tree = MerkleTree::from_xornames(xornames)
157 .map_err(|e| Error::Payment(format!("Failed to build merkle tree: {e}")))?;
158
159 let depth = tree.depth();
160 let merkle_payment_timestamp = std::time::SystemTime::now()
161 .duration_since(std::time::UNIX_EPOCH)
162 .map_err(|e| Error::Payment(format!("System time error: {e}")))?
163 .as_secs();
164
165 debug!("Merkle tree: depth={depth}, leaves={chunk_count}, ts={merkle_payment_timestamp}");
166
167 let midpoint_proofs = tree
169 .reward_candidates(merkle_payment_timestamp)
170 .map_err(|e| Error::Payment(format!("Failed to generate reward candidates: {e}")))?;
171
172 debug!(
173 "Collecting candidate pools from {} midpoints (concurrent)",
174 midpoint_proofs.len()
175 );
176
177 let candidate_pools = self
179 .build_candidate_pools(
180 &midpoint_proofs,
181 data_type,
182 data_size,
183 merkle_payment_timestamp,
184 )
185 .await?;
186
187 let pool_commitments: Vec<PoolCommitment> = candidate_pools
189 .iter()
190 .map(MerklePaymentCandidatePool::to_commitment)
191 .collect();
192
193 Ok(PreparedMerkleBatch {
194 depth,
195 pool_commitments,
196 merkle_payment_timestamp,
197 candidate_pools,
198 tree,
199 addresses: addresses.to_vec(),
200 })
201 }
202
203 async fn pay_for_merkle_single_batch(
205 &self,
206 addresses: &[[u8; 32]],
207 data_type: u32,
208 data_size: u64,
209 ) -> Result<MerkleBatchPaymentResult> {
210 let wallet = self.require_wallet()?;
211 let prepared = self
212 .prepare_merkle_batch_external(addresses, data_type, data_size)
213 .await?;
214
215 info!(
216 "Submitting merkle batch payment on-chain (depth={})",
217 prepared.depth
218 );
219 let (winner_pool_hash, amount, gas_info) = wallet
220 .pay_for_merkle_tree(
221 prepared.depth,
222 prepared.pool_commitments.clone(),
223 prepared.merkle_payment_timestamp,
224 )
225 .await
226 .map_err(|e| Error::Payment(format!("Merkle batch payment failed: {e}")))?;
227
228 info!(
229 "Merkle payment succeeded: winner pool {}",
230 hex::encode(winner_pool_hash)
231 );
232
233 let mut result = finalize_merkle_batch(prepared, winner_pool_hash)?;
234 result.storage_cost_atto = amount.to_string();
235 result.gas_cost_wei = gas_info.gas_cost_wei;
236 Ok(result)
237 }
238
239 async fn pay_for_merkle_multi_batch(
241 &self,
242 addresses: &[[u8; 32]],
243 data_type: u32,
244 data_size: u64,
245 ) -> Result<MerkleBatchPaymentResult> {
246 let sub_batches: Vec<&[[u8; 32]]> = addresses.chunks(MAX_LEAVES).collect();
247 let total_sub_batches = sub_batches.len();
248 let mut all_proofs = HashMap::with_capacity(addresses.len());
249 let mut total_storage = Amount::ZERO;
250 let mut total_gas: u128 = 0;
251
252 for (i, chunk) in sub_batches.into_iter().enumerate() {
253 match self
254 .pay_for_merkle_single_batch(chunk, data_type, data_size)
255 .await
256 {
257 Ok(sub_result) => {
258 if let Ok(cost) = sub_result.storage_cost_atto.parse::<Amount>() {
259 total_storage += cost;
260 }
261 total_gas = total_gas.saturating_add(sub_result.gas_cost_wei);
262 all_proofs.extend(sub_result.proofs);
263 }
264 Err(e) => {
265 if all_proofs.is_empty() {
266 return Err(e);
268 }
269 warn!(
271 "Merkle sub-batch {}/{total_sub_batches} failed: {e}. \
272 Returning {} proofs from prior sub-batches",
273 i + 1,
274 all_proofs.len()
275 );
276 return Ok(MerkleBatchPaymentResult {
277 chunk_count: all_proofs.len(),
278 proofs: all_proofs,
279 storage_cost_atto: total_storage.to_string(),
280 gas_cost_wei: total_gas,
281 });
282 }
283 }
284 }
285
286 Ok(MerkleBatchPaymentResult {
287 chunk_count: addresses.len(),
288 proofs: all_proofs,
289 storage_cost_atto: total_storage.to_string(),
290 gas_cost_wei: total_gas,
291 })
292 }
293
294 async fn build_candidate_pools(
296 &self,
297 midpoint_proofs: &[MidpointProof],
298 data_type: u32,
299 data_size: u64,
300 merkle_payment_timestamp: u64,
301 ) -> Result<Vec<MerklePaymentCandidatePool>> {
302 let mut pool_futures = FuturesUnordered::new();
303
304 for midpoint_proof in midpoint_proofs {
305 let pool_address = midpoint_proof.address();
306 let mp = midpoint_proof.clone();
307 pool_futures.push(async move {
308 let candidate_nodes = self
309 .get_merkle_candidate_pool(
310 &pool_address.0,
311 data_type,
312 data_size,
313 merkle_payment_timestamp,
314 )
315 .await?;
316 Ok::<_, Error>(MerklePaymentCandidatePool {
317 midpoint_proof: mp,
318 candidate_nodes,
319 })
320 });
321 }
322
323 let mut pools = Vec::with_capacity(midpoint_proofs.len());
324 while let Some(result) = pool_futures.next().await {
325 pools.push(result?);
326 }
327
328 Ok(pools)
329 }
330
331 #[allow(clippy::too_many_lines)]
333 async fn get_merkle_candidate_pool(
334 &self,
335 address: &[u8; 32],
336 data_type: u32,
337 data_size: u64,
338 merkle_payment_timestamp: u64,
339 ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
340 let node = self.network().node();
341 let timeout = Duration::from_secs(self.config().quote_timeout_secs);
342
343 let query_count = CANDIDATES_PER_POOL * 2;
345 let mut remote_peers = self
346 .network()
347 .find_closest_peers(address, query_count)
348 .await?;
349
350 if remote_peers.len() < CANDIDATES_PER_POOL {
354 let connected = self.network().connected_peers().await;
355 for peer in connected {
356 if !remote_peers.iter().any(|(id, _)| *id == peer) {
357 remote_peers.push((peer, vec![]));
358 }
359 }
360 }
361
362 if remote_peers.len() < CANDIDATES_PER_POOL {
363 return Err(Error::InsufficientPeers(format!(
364 "Found {} peers, need {CANDIDATES_PER_POOL} for merkle candidate pool. \
365 Use --no-merkle or a larger network.",
366 remote_peers.len()
367 )));
368 }
369
370 let mut candidate_futures = FuturesUnordered::new();
371
372 for (peer_id, peer_addrs) in &remote_peers {
373 let request_id = self.next_request_id();
374 let request = MerkleCandidateQuoteRequest {
375 address: *address,
376 data_type,
377 data_size,
378 merkle_payment_timestamp,
379 };
380 let message = ChunkMessage {
381 request_id,
382 body: ChunkMessageBody::MerkleCandidateQuoteRequest(request),
383 };
384
385 let message_bytes = match message.encode() {
386 Ok(bytes) => bytes,
387 Err(e) => {
388 warn!("Failed to encode merkle candidate request for {peer_id}: {e}");
389 continue;
390 }
391 };
392
393 let peer_id_clone = *peer_id;
394 let addrs_clone = peer_addrs.clone();
395 let node_clone = node.clone();
396
397 let fut = async move {
398 let result = send_and_await_chunk_response(
399 &node_clone,
400 &peer_id_clone,
401 message_bytes,
402 request_id,
403 timeout,
404 &addrs_clone,
405 |body| match body {
406 ChunkMessageBody::MerkleCandidateQuoteResponse(
407 MerkleCandidateQuoteResponse::Success { candidate_node },
408 ) => {
409 match rmp_serde::from_slice::<MerklePaymentCandidateNode>(
410 &candidate_node,
411 ) {
412 Ok(node) => Some(Ok(node)),
413 Err(e) => Some(Err(Error::Serialization(format!(
414 "Failed to deserialize candidate node from {peer_id_clone}: {e}"
415 )))),
416 }
417 }
418 ChunkMessageBody::MerkleCandidateQuoteResponse(
419 MerkleCandidateQuoteResponse::Error(e),
420 ) => Some(Err(Error::Protocol(format!(
421 "Merkle quote error from {peer_id_clone}: {e}"
422 )))),
423 _ => None,
424 },
425 |e| {
426 Error::Network(format!(
427 "Failed to send merkle candidate request to {peer_id_clone}: {e}"
428 ))
429 },
430 || {
431 Error::Timeout(format!(
432 "Timeout waiting for merkle candidate from {peer_id_clone}"
433 ))
434 },
435 )
436 .await;
437
438 (peer_id_clone, result)
439 };
440
441 candidate_futures.push(fut);
442 }
443
444 self.collect_validated_candidates(&mut candidate_futures, merkle_payment_timestamp)
445 .await
446 }
447
448 async fn collect_validated_candidates(
450 &self,
451 futures: &mut FuturesUnordered<
452 impl std::future::Future<
453 Output = (
454 PeerId,
455 std::result::Result<MerklePaymentCandidateNode, Error>,
456 ),
457 >,
458 >,
459 merkle_payment_timestamp: u64,
460 ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
461 let mut candidates = Vec::with_capacity(CANDIDATES_PER_POOL);
462 let mut failures: Vec<String> = Vec::new();
463
464 while let Some((peer_id, result)) = futures.next().await {
465 match result {
466 Ok(candidate) => {
467 if !verify_merkle_candidate_signature(&candidate) {
468 warn!("Invalid ML-DSA-65 signature from merkle candidate {peer_id}");
469 failures.push(format!("{peer_id}: invalid signature"));
470 continue;
471 }
472 if candidate.merkle_payment_timestamp != merkle_payment_timestamp {
473 warn!("Timestamp mismatch from merkle candidate {peer_id}");
474 failures.push(format!("{peer_id}: timestamp mismatch"));
475 continue;
476 }
477 candidates.push(candidate);
478 if candidates.len() >= CANDIDATES_PER_POOL {
479 break;
480 }
481 }
482 Err(e) => {
483 debug!("Failed to get merkle candidate from {peer_id}: {e}");
484 failures.push(format!("{peer_id}: {e}"));
485 }
486 }
487 }
488
489 if candidates.len() < CANDIDATES_PER_POOL {
490 return Err(Error::InsufficientPeers(format!(
491 "Got {} merkle candidates, need {CANDIDATES_PER_POOL}. Failures: [{}]",
492 candidates.len(),
493 failures.join("; ")
494 )));
495 }
496
497 candidates.truncate(CANDIDATES_PER_POOL);
498 candidates
499 .try_into()
500 .map_err(|_| Error::Payment("Failed to convert candidates to fixed array".to_string()))
501 }
502
503 pub(crate) async fn merkle_upload_chunks(
513 &self,
514 chunk_contents: Vec<Bytes>,
515 addresses: Vec<[u8; 32]>,
516 batch_result: &MerkleBatchPaymentResult,
517 ) -> Result<usize> {
518 let mut stored = 0usize;
519 let mut upload_stream = stream::iter(chunk_contents.into_iter().zip(addresses).map(
520 |(content, addr)| {
521 let proof_bytes = batch_result.proofs.get(&addr).cloned();
522 async move {
523 let proof = proof_bytes.ok_or_else(|| {
524 Error::Payment(format!(
525 "Missing merkle proof for chunk {}",
526 hex::encode(addr)
527 ))
528 })?;
529 let peers = self.close_group_peers(&addr).await?;
530 self.chunk_put_to_close_group(content, proof, &peers).await
531 }
532 },
533 ))
534 .buffer_unordered(self.config().store_concurrency);
535
536 while let Some(result) = upload_stream.next().await {
537 result?;
538 stored += 1;
539 }
540
541 Ok(stored)
542 }
543}
544
545pub fn finalize_merkle_batch(
550 prepared: PreparedMerkleBatch,
551 winner_pool_hash: [u8; 32],
552) -> Result<MerkleBatchPaymentResult> {
553 let chunk_count = prepared.addresses.len();
554 let xornames: Vec<XorName> = prepared.addresses.iter().map(|a| XorName(*a)).collect();
555
556 let winner_pool = prepared
558 .candidate_pools
559 .iter()
560 .find(|pool| pool.hash() == winner_pool_hash)
561 .ok_or_else(|| {
562 Error::Payment(format!(
563 "Winner pool {} not found in candidate pools",
564 hex::encode(winner_pool_hash)
565 ))
566 })?;
567
568 info!("Generating merkle proofs for {chunk_count} chunks");
570 let mut proofs = HashMap::with_capacity(chunk_count);
571
572 for (i, xorname) in xornames.iter().enumerate() {
573 let address_proof = prepared
574 .tree
575 .generate_address_proof(i, *xorname)
576 .map_err(|e| {
577 Error::Payment(format!(
578 "Failed to generate address proof for chunk {i}: {e}"
579 ))
580 })?;
581
582 let merkle_proof = MerklePaymentProof::new(*xorname, address_proof, winner_pool.clone());
583
584 let tagged_bytes = serialize_merkle_proof(&merkle_proof)
585 .map_err(|e| Error::Serialization(format!("Failed to serialize merkle proof: {e}")))?;
586
587 proofs.insert(prepared.addresses[i], tagged_bytes);
588 }
589
590 info!("Merkle batch payment complete: {chunk_count} proofs generated");
591
592 Ok(MerkleBatchPaymentResult {
593 proofs,
594 chunk_count,
595 storage_cost_atto: "0".to_string(),
596 gas_cost_wei: 0,
597 })
598}
599
/// Compile-time check that the future returned by
/// `Client::merkle_upload_chunks` is `Send`. Nothing in this module runs.
#[cfg(test)]
mod send_assertions {
    use super::*;
    use crate::data::client::Client;

    /// Accepts only references to `Send` values; calling it is the assertion.
    fn _require_send<T: Send>(_: &T) {}

    // Never called. The body exists so the compiler type-checks the `Send`
    // bound, and `todo!()` stands in for a value that cannot be built here.
    #[allow(
        dead_code,
        unreachable_code,
        unused_variables,
        clippy::diverging_sub_expression
    )]
    async fn _merkle_upload_chunks_is_send(client: &Client) {
        let batch_result: MerkleBatchPaymentResult = todo!();
        let upload = client.merkle_upload_chunks(Vec::new(), Vec::new(), &batch_result);
        _require_send(&upload);
    }
}
620
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use ant_protocol::evm::{Amount, MerkleTree, RewardsAddress, CANDIDATES_PER_POOL};
    use ant_protocol::payment::deserialize_merkle_proof;

    /// Current Unix time in whole seconds.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds `count` deterministic chunk addresses (content hashes of `0..count`).
    fn make_test_addresses(count: usize) -> Vec<[u8; 32]> {
        (0..count)
            .map(|i| XorName::from_content(&i.to_le_bytes()).0)
            .collect()
    }

    /// Converts raw addresses into the `XorName` leaves the merkle tree takes.
    fn to_xornames(addrs: &[[u8; 32]]) -> Vec<XorName> {
        addrs.iter().map(|a| XorName(*a)).collect()
    }

    /// Builds a full pool of placeholder candidate nodes stamped with
    /// `timestamp`. Keys and signatures are filler bytes, not real signatures.
    // `i` ranges over 0..CANDIDATES_PER_POOL, and the cast only produces filler
    // bytes, so truncation would be harmless.
    #[allow(clippy::cast_possible_truncation)]
    fn make_dummy_candidate_nodes(
        timestamp: u64,
    ) -> [MerklePaymentCandidateNode; CANDIDATES_PER_POOL] {
        std::array::from_fn(|i| MerklePaymentCandidateNode {
            pub_key: vec![i as u8; 32],
            price: Amount::from(1024u64),
            reward_address: RewardsAddress::new([i as u8; 20]),
            merkle_payment_timestamp: timestamp,
            signature: vec![i as u8; 64],
        })
    }

    /// Builds a `PreparedMerkleBatch` over `make_test_addresses(count)` with
    /// dummy candidate pools, bypassing the network.
    fn make_prepared_merkle_batch(count: usize) -> PreparedMerkleBatch {
        let addrs = make_test_addresses(count);
        let tree = MerkleTree::from_xornames(to_xornames(&addrs)).unwrap();
        let timestamp = now_secs();

        let candidate_pools: Vec<MerklePaymentCandidatePool> = tree
            .reward_candidates(timestamp)
            .unwrap()
            .into_iter()
            .map(|mp| MerklePaymentCandidatePool {
                midpoint_proof: mp,
                candidate_nodes: make_dummy_candidate_nodes(timestamp),
            })
            .collect();

        let pool_commitments = candidate_pools
            .iter()
            .map(MerklePaymentCandidatePool::to_commitment)
            .collect();

        PreparedMerkleBatch {
            depth: tree.depth(),
            pool_commitments,
            merkle_payment_timestamp: timestamp,
            candidate_pools,
            tree,
            addresses: addrs,
        }
    }

    #[test]
    fn test_auto_below_threshold() {
        assert!(!should_use_merkle(1, PaymentMode::Auto));
        assert!(!should_use_merkle(10, PaymentMode::Auto));
        assert!(!should_use_merkle(63, PaymentMode::Auto));
    }

    #[test]
    fn test_auto_at_and_above_threshold() {
        assert!(should_use_merkle(64, PaymentMode::Auto));
        assert!(should_use_merkle(65, PaymentMode::Auto));
        assert!(should_use_merkle(1000, PaymentMode::Auto));
    }

    #[test]
    fn test_merkle_mode_forces_at_2() {
        assert!(!should_use_merkle(1, PaymentMode::Merkle));
        assert!(should_use_merkle(2, PaymentMode::Merkle));
        assert!(should_use_merkle(3, PaymentMode::Merkle));
    }

    #[test]
    fn test_single_mode_always_false() {
        assert!(!should_use_merkle(0, PaymentMode::Single));
        assert!(!should_use_merkle(64, PaymentMode::Single));
        assert!(!should_use_merkle(1000, PaymentMode::Single));
    }

    #[test]
    fn test_default_mode_is_auto() {
        assert_eq!(PaymentMode::default(), PaymentMode::Auto);
    }

    #[test]
    fn test_threshold_value() {
        assert_eq!(DEFAULT_MERKLE_THRESHOLD, 64);
    }

    #[test]
    fn test_tree_depth_for_known_sizes() {
        let cases = [(2, 1), (4, 2), (16, 4), (100, 7), (256, 8)];
        for (count, expected_depth) in cases {
            let tree = MerkleTree::from_xornames(to_xornames(&make_test_addresses(count))).unwrap();
            assert_eq!(
                tree.depth(),
                expected_depth,
                "depth mismatch for {count} leaves"
            );
        }
    }

    #[test]
    fn test_proof_generation_and_verification_for_all_leaves() {
        let xornames = to_xornames(&make_test_addresses(16));
        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();

        for (i, xn) in xornames.iter().enumerate() {
            let proof = tree.generate_address_proof(i, *xn).unwrap();
            assert!(proof.verify(), "proof for leaf {i} should verify");
            assert_eq!(proof.depth(), usize::from(tree.depth()));
        }
    }

    #[test]
    fn test_proof_fails_for_wrong_address() {
        let tree = MerkleTree::from_xornames(to_xornames(&make_test_addresses(8))).unwrap();

        let wrong = XorName::from_content(b"wrong");
        let proof = tree.generate_address_proof(0, wrong).unwrap();
        assert!(!proof.verify(), "proof with wrong address should fail");
    }

    #[test]
    fn test_tree_too_few_leaves() {
        let xornames = vec![XorName::from_content(b"only_one")];
        let result = MerkleTree::from_xornames(xornames);
        assert!(result.is_err());
    }

    #[test]
    fn test_tree_at_max_leaves() {
        let tree = MerkleTree::from_xornames(to_xornames(&make_test_addresses(MAX_LEAVES))).unwrap();
        assert_eq!(tree.leaf_count(), MAX_LEAVES);
    }

    #[test]
    fn test_merkle_proof_serialize_deserialize_roundtrip() {
        let xornames = to_xornames(&make_test_addresses(4));
        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();
        let timestamp = now_secs();

        let midpoint = tree
            .reward_candidates(timestamp)
            .unwrap()
            .into_iter()
            .next()
            .unwrap();

        let pool = MerklePaymentCandidatePool {
            midpoint_proof: midpoint,
            candidate_nodes: make_dummy_candidate_nodes(timestamp),
        };

        let address_proof = tree.generate_address_proof(0, xornames[0]).unwrap();
        let merkle_proof = MerklePaymentProof::new(xornames[0], address_proof, pool);

        let tagged = serialize_merkle_proof(&merkle_proof).unwrap();
        assert_eq!(
            tagged.first().copied(),
            Some(0x02),
            "tag should be PROOF_TAG_MERKLE"
        );

        let deserialized = deserialize_merkle_proof(&tagged).unwrap();
        assert_eq!(deserialized.address, merkle_proof.address);
        assert_eq!(
            deserialized.winner_pool.candidate_nodes.len(),
            CANDIDATES_PER_POOL
        );
    }

    /// Pins the comparison `collect_validated_candidates` uses to reject a
    /// candidate with the wrong timestamp. The full rejection path needs a
    /// `Client` and live peers, so it is not exercised here.
    #[test]
    fn test_candidate_wrong_timestamp_rejected() {
        let batch_timestamp = 2000;
        let nodes = make_dummy_candidate_nodes(1000);
        assert!(nodes
            .iter()
            .all(|node| node.merkle_payment_timestamp != batch_timestamp));
    }

    #[test]
    fn test_finalize_merkle_batch_with_valid_winner() {
        let prepared = make_prepared_merkle_batch(4);
        let winner_hash = prepared.candidate_pools[0].hash();

        let result = finalize_merkle_batch(prepared, winner_hash);
        assert!(
            result.is_ok(),
            "should succeed with valid winner: {result:?}"
        );

        let batch = result.unwrap();
        assert_eq!(batch.chunk_count, 4);
        assert_eq!(batch.proofs.len(), 4);
        assert!(batch.proofs.values().all(|proof_bytes| !proof_bytes.is_empty()));
    }

    #[test]
    fn test_finalize_merkle_batch_with_invalid_winner() {
        let prepared = make_prepared_merkle_batch(4);
        let bad_hash = [0xFF; 32];

        let result = finalize_merkle_batch(prepared, bad_hash);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not found in candidate pools"), "got: {err}");
    }

    #[test]
    fn test_finalize_merkle_batch_proofs_are_deserializable() {
        let prepared = make_prepared_merkle_batch(8);
        let winner_hash = prepared.candidate_pools[0].hash();

        let batch = finalize_merkle_batch(prepared, winner_hash).unwrap();

        // `make_prepared_merkle_batch` uses these same deterministic addresses.
        let expected = make_test_addresses(8);
        assert_eq!(batch.proofs.len(), expected.len());

        for addr in &expected {
            let proof_bytes = batch
                .proofs
                .get(addr)
                .unwrap_or_else(|| panic!("missing proof for {}", hex::encode(addr)));
            let proof = deserialize_merkle_proof(proof_bytes).unwrap_or_else(|e| {
                panic!("proof for {} should deserialize: {e:?}", hex::encode(addr))
            });
            assert_eq!(proof.address, XorName(*addr), "proof keyed under the wrong address");
        }
    }

    #[test]
    fn test_batch_split_calculation() {
        let addrs = make_test_addresses(MAX_LEAVES);
        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 1);

        let addrs = make_test_addresses(MAX_LEAVES + 1);
        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 2);

        let addrs = make_test_addresses(3 * MAX_LEAVES);
        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 3);
    }
}
931}