1use crate::data::client::adaptive::observe_op;
8use crate::data::client::classify_error;
9use crate::data::client::file::UploadEvent;
10use crate::data::client::Client;
11use crate::data::error::{Error, Result};
12use ant_protocol::evm::{
13 Amount, MerklePaymentCandidateNode, MerklePaymentCandidatePool, MerklePaymentProof, MerkleTree,
14 MidpointProof, PoolCommitment, CANDIDATES_PER_POOL, MAX_LEAVES,
15};
16use ant_protocol::payment::{serialize_merkle_proof, verify_merkle_candidate_signature};
17use ant_protocol::transport::PeerId;
18use ant_protocol::{
19 send_and_await_chunk_response, ChunkMessage, ChunkMessageBody, MerkleCandidateQuoteRequest,
20 MerkleCandidateQuoteResponse,
21};
22use bytes::Bytes;
23use futures::stream::{self, FuturesUnordered, StreamExt};
24use std::collections::HashMap;
25use std::time::Duration;
26use tokio::sync::mpsc;
27use tracing::{debug, info, warn};
28use xor_name::XorName;
29
/// Minimum number of chunks at which [`PaymentMode::Auto`] switches to a merkle
/// batch payment (see `should_use_merkle`).
pub const DEFAULT_MERKLE_THRESHOLD: usize = 64;
32
/// Selects when uploads use merkle batch payments.
///
/// Serialized in snake_case (`"auto"`, `"merkle"`, `"single"`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PaymentMode {
    /// Merkle payment once a batch has at least `DEFAULT_MERKLE_THRESHOLD`
    /// chunks. This is the default mode.
    #[default]
    Auto,
    /// Merkle payment for any batch of at least 2 chunks.
    Merkle,
    /// Never use merkle payment (`should_use_merkle` always returns `false`).
    Single,
}
45
/// Outcome of a merkle batch payment: one serialized proof per paid chunk plus
/// the reported costs.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MerkleBatchPaymentResult {
    /// Serialized merkle payment proof (output of `serialize_merkle_proof`),
    /// keyed by chunk address.
    pub proofs: HashMap<[u8; 32], Vec<u8>>,
    /// Number of chunks paid for. After a partially failed multi-batch payment
    /// this is the number of proofs gathered, not the number requested.
    pub chunk_count: usize,
    /// Storage cost paid, as the string form of an `Amount`
    /// (`"0"` when produced directly by `finalize_merkle_batch`).
    pub storage_cost_atto: String,
    /// Gas cost in wei reported by the wallet
    /// (0 when produced directly by `finalize_merkle_batch`).
    pub gas_cost_wei: u128,
    /// Unix timestamp (seconds) the payment was made for. Deserializes as 0 when
    /// absent. For multi-batch payments this is the smallest non-zero
    /// sub-batch timestamp.
    #[serde(default)]
    pub merkle_payment_timestamp: u64,
}
67
/// A merkle batch whose tree is built and whose candidate pools are collected,
/// but which has not been paid for yet.
///
/// Produced by `Client::prepare_merkle_batch_external`. After paying on-chain,
/// pass it together with the winning pool hash to [`finalize_merkle_batch`].
pub struct PreparedMerkleBatch {
    /// Depth of the merkle tree built over the batch's addresses.
    pub depth: u8,
    /// One commitment per candidate pool, as submitted to the on-chain payment.
    pub pool_commitments: Vec<PoolCommitment>,
    /// Unix timestamp (seconds) the batch was prepared with; every candidate
    /// node must echo it.
    pub merkle_payment_timestamp: u64,
    /// Candidate pools, one per reward midpoint of the tree.
    candidate_pools: Vec<MerklePaymentCandidatePool>,
    /// Tree used to generate the per-address proofs.
    tree: MerkleTree,
    /// Chunk addresses in the same order as the tree's leaves.
    addresses: Vec<[u8; 32]>,
}
86
87impl std::fmt::Debug for PreparedMerkleBatch {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("PreparedMerkleBatch")
90 .field("depth", &self.depth)
91 .field("pool_commitments", &self.pool_commitments.len())
92 .field("merkle_payment_timestamp", &self.merkle_payment_timestamp)
93 .field("candidate_pools", &self.candidate_pools.len())
94 .field("addresses", &self.addresses.len())
95 .finish()
96 }
97}
98
99#[must_use]
102pub fn should_use_merkle(chunk_count: usize, mode: PaymentMode) -> bool {
103 match mode {
104 PaymentMode::Auto => chunk_count >= DEFAULT_MERKLE_THRESHOLD,
105 PaymentMode::Merkle => chunk_count >= 2,
106 PaymentMode::Single => false,
107 }
108}
109
110impl Client {
111 #[must_use]
113 pub fn should_use_merkle(&self, chunk_count: usize, mode: PaymentMode) -> bool {
114 should_use_merkle(chunk_count, mode)
115 }
116
117 pub async fn pay_for_merkle_batch(
130 &self,
131 addresses: &[[u8; 32]],
132 data_type: u32,
133 data_size: u64,
134 ) -> Result<MerkleBatchPaymentResult> {
135 let chunk_count = addresses.len();
136 if chunk_count < 2 {
137 return Err(Error::Payment(
138 "Merkle batch payment requires at least 2 chunks".to_string(),
139 ));
140 }
141
142 if chunk_count > MAX_LEAVES {
143 return self
144 .pay_for_merkle_multi_batch(addresses, data_type, data_size)
145 .await;
146 }
147
148 self.pay_for_merkle_single_batch(addresses, data_type, data_size)
149 .await
150 }
151
152 pub async fn prepare_merkle_batch_external(
158 &self,
159 addresses: &[[u8; 32]],
160 data_type: u32,
161 data_size: u64,
162 ) -> Result<PreparedMerkleBatch> {
163 let chunk_count = addresses.len();
164 let xornames: Vec<XorName> = addresses.iter().map(|a| XorName(*a)).collect();
165
166 debug!("Building merkle tree for {chunk_count} chunks");
167
168 let tree = MerkleTree::from_xornames(xornames)
170 .map_err(|e| Error::Payment(format!("Failed to build merkle tree: {e}")))?;
171
172 let depth = tree.depth();
173 let merkle_payment_timestamp = std::time::SystemTime::now()
174 .duration_since(std::time::UNIX_EPOCH)
175 .map_err(|e| Error::Payment(format!("System time error: {e}")))?
176 .as_secs();
177
178 debug!("Merkle tree: depth={depth}, leaves={chunk_count}, ts={merkle_payment_timestamp}");
179
180 let midpoint_proofs = tree
182 .reward_candidates(merkle_payment_timestamp)
183 .map_err(|e| Error::Payment(format!("Failed to generate reward candidates: {e}")))?;
184
185 debug!(
186 "Collecting candidate pools from {} midpoints (concurrent)",
187 midpoint_proofs.len()
188 );
189
190 let candidate_pools = self
192 .build_candidate_pools(
193 &midpoint_proofs,
194 data_type,
195 data_size,
196 merkle_payment_timestamp,
197 )
198 .await?;
199
200 let pool_commitments: Vec<PoolCommitment> = candidate_pools
202 .iter()
203 .map(MerklePaymentCandidatePool::to_commitment)
204 .collect();
205
206 Ok(PreparedMerkleBatch {
207 depth,
208 pool_commitments,
209 merkle_payment_timestamp,
210 candidate_pools,
211 tree,
212 addresses: addresses.to_vec(),
213 })
214 }
215
216 async fn pay_for_merkle_single_batch(
218 &self,
219 addresses: &[[u8; 32]],
220 data_type: u32,
221 data_size: u64,
222 ) -> Result<MerkleBatchPaymentResult> {
223 let wallet = self.require_wallet()?;
224 let prepared = self
225 .prepare_merkle_batch_external(addresses, data_type, data_size)
226 .await?;
227
228 info!(
229 "Submitting merkle batch payment on-chain (depth={})",
230 prepared.depth
231 );
232 let (winner_pool_hash, amount, gas_info) = wallet
233 .pay_for_merkle_tree(
234 prepared.depth,
235 prepared.pool_commitments.clone(),
236 prepared.merkle_payment_timestamp,
237 )
238 .await
239 .map_err(|e| Error::Payment(format!("Merkle batch payment failed: {e}")))?;
240
241 info!(
242 "Merkle payment succeeded: winner pool {}",
243 hex::encode(winner_pool_hash)
244 );
245
246 let mut result = finalize_merkle_batch(prepared, winner_pool_hash)?;
247 result.storage_cost_atto = amount.to_string();
248 result.gas_cost_wei = gas_info.gas_cost_wei;
249 Ok(result)
250 }
251
252 async fn pay_for_merkle_multi_batch(
254 &self,
255 addresses: &[[u8; 32]],
256 data_type: u32,
257 data_size: u64,
258 ) -> Result<MerkleBatchPaymentResult> {
259 let sub_batches: Vec<&[[u8; 32]]> = addresses.chunks(MAX_LEAVES).collect();
260 let total_sub_batches = sub_batches.len();
261 let mut all_proofs = HashMap::with_capacity(addresses.len());
262 let mut total_storage = Amount::ZERO;
263 let mut total_gas: u128 = 0;
264 let mut oldest_ts: u64 = 0;
268
269 for (i, chunk) in sub_batches.into_iter().enumerate() {
270 match self
271 .pay_for_merkle_single_batch(chunk, data_type, data_size)
272 .await
273 {
274 Ok(sub_result) => {
275 if let Ok(cost) = sub_result.storage_cost_atto.parse::<Amount>() {
276 total_storage += cost;
277 }
278 total_gas = total_gas.saturating_add(sub_result.gas_cost_wei);
279 if oldest_ts == 0
280 || (sub_result.merkle_payment_timestamp > 0
281 && sub_result.merkle_payment_timestamp < oldest_ts)
282 {
283 oldest_ts = sub_result.merkle_payment_timestamp;
284 }
285 all_proofs.extend(sub_result.proofs);
286 }
287 Err(e) => {
288 if all_proofs.is_empty() {
289 return Err(e);
291 }
292 warn!(
294 "Merkle sub-batch {}/{total_sub_batches} failed: {e}. \
295 Returning {} proofs from prior sub-batches",
296 i + 1,
297 all_proofs.len()
298 );
299 return Ok(MerkleBatchPaymentResult {
300 chunk_count: all_proofs.len(),
301 proofs: all_proofs,
302 storage_cost_atto: total_storage.to_string(),
303 gas_cost_wei: total_gas,
304 merkle_payment_timestamp: oldest_ts,
305 });
306 }
307 }
308 }
309
310 Ok(MerkleBatchPaymentResult {
311 chunk_count: addresses.len(),
312 proofs: all_proofs,
313 storage_cost_atto: total_storage.to_string(),
314 gas_cost_wei: total_gas,
315 merkle_payment_timestamp: oldest_ts,
316 })
317 }
318
319 async fn build_candidate_pools(
321 &self,
322 midpoint_proofs: &[MidpointProof],
323 data_type: u32,
324 data_size: u64,
325 merkle_payment_timestamp: u64,
326 ) -> Result<Vec<MerklePaymentCandidatePool>> {
327 let mut pool_futures = FuturesUnordered::new();
328
329 for midpoint_proof in midpoint_proofs {
330 let pool_address = midpoint_proof.address();
331 let mp = midpoint_proof.clone();
332 pool_futures.push(async move {
333 let candidate_nodes = self
334 .get_merkle_candidate_pool(
335 &pool_address.0,
336 data_type,
337 data_size,
338 merkle_payment_timestamp,
339 )
340 .await?;
341 Ok::<_, Error>(MerklePaymentCandidatePool {
342 midpoint_proof: mp,
343 candidate_nodes,
344 })
345 });
346 }
347
348 let mut pools = Vec::with_capacity(midpoint_proofs.len());
349 while let Some(result) = pool_futures.next().await {
350 pools.push(result?);
351 }
352
353 Ok(pools)
354 }
355
356 #[allow(clippy::too_many_lines)]
358 async fn get_merkle_candidate_pool(
359 &self,
360 address: &[u8; 32],
361 data_type: u32,
362 data_size: u64,
363 merkle_payment_timestamp: u64,
364 ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
365 let node = self.network().node();
366 let timeout = Duration::from_secs(self.config().quote_timeout_secs);
367
368 let query_count = CANDIDATES_PER_POOL * 2;
370 let mut remote_peers = self
371 .network()
372 .find_closest_peers(address, query_count)
373 .await?;
374
375 if remote_peers.len() < CANDIDATES_PER_POOL {
379 let connected = self.network().connected_peers().await;
380 for peer in connected {
381 if !remote_peers.iter().any(|(id, _)| *id == peer) {
382 remote_peers.push((peer, vec![]));
383 }
384 }
385 }
386
387 if remote_peers.len() < CANDIDATES_PER_POOL {
388 return Err(Error::InsufficientPeers(format!(
389 "Found {} peers, need {CANDIDATES_PER_POOL} for merkle candidate pool. \
390 Use --no-merkle or a larger network.",
391 remote_peers.len()
392 )));
393 }
394
395 let mut candidate_futures = FuturesUnordered::new();
396
397 for (peer_id, peer_addrs) in &remote_peers {
398 let request_id = self.next_request_id();
399 let request = MerkleCandidateQuoteRequest {
400 address: *address,
401 data_type,
402 data_size,
403 merkle_payment_timestamp,
404 };
405 let message = ChunkMessage {
406 request_id,
407 body: ChunkMessageBody::MerkleCandidateQuoteRequest(request),
408 };
409
410 let message_bytes = match message.encode() {
411 Ok(bytes) => bytes,
412 Err(e) => {
413 warn!("Failed to encode merkle candidate request for {peer_id}: {e}");
414 continue;
415 }
416 };
417
418 let peer_id_clone = *peer_id;
419 let addrs_clone = peer_addrs.clone();
420 let node_clone = node.clone();
421
422 let fut = async move {
423 let result = send_and_await_chunk_response(
424 &node_clone,
425 &peer_id_clone,
426 message_bytes,
427 request_id,
428 timeout,
429 &addrs_clone,
430 |body| match body {
431 ChunkMessageBody::MerkleCandidateQuoteResponse(
432 MerkleCandidateQuoteResponse::Success { candidate_node },
433 ) => {
434 match rmp_serde::from_slice::<MerklePaymentCandidateNode>(
435 &candidate_node,
436 ) {
437 Ok(node) => Some(Ok(node)),
438 Err(e) => Some(Err(Error::Serialization(format!(
439 "Failed to deserialize candidate node from {peer_id_clone}: {e}"
440 )))),
441 }
442 }
443 ChunkMessageBody::MerkleCandidateQuoteResponse(
444 MerkleCandidateQuoteResponse::Error(e),
445 ) => Some(Err(Error::Protocol(format!(
446 "Merkle quote error from {peer_id_clone}: {e}"
447 )))),
448 _ => None,
449 },
450 |e| {
451 Error::Network(format!(
452 "Failed to send merkle candidate request to {peer_id_clone}: {e}"
453 ))
454 },
455 || {
456 Error::Timeout(format!(
457 "Timeout waiting for merkle candidate from {peer_id_clone}"
458 ))
459 },
460 )
461 .await;
462
463 (peer_id_clone, result)
464 };
465
466 candidate_futures.push(fut);
467 }
468
469 self.collect_validated_candidates(&mut candidate_futures, address, merkle_payment_timestamp)
470 .await
471 }
472
473 async fn collect_validated_candidates(
484 &self,
485 futures: &mut FuturesUnordered<
486 impl std::future::Future<
487 Output = (
488 PeerId,
489 std::result::Result<MerklePaymentCandidateNode, Error>,
490 ),
491 >,
492 >,
493 target_address: &[u8; 32],
494 merkle_payment_timestamp: u64,
495 ) -> Result<[MerklePaymentCandidateNode; CANDIDATES_PER_POOL]> {
496 let mut valid: Vec<(PeerId, MerklePaymentCandidateNode)> = Vec::new();
497 let mut failures: Vec<String> = Vec::new();
498
499 while let Some((peer_id, result)) = futures.next().await {
500 match result {
501 Ok(candidate) => {
502 if !verify_merkle_candidate_signature(&candidate) {
503 warn!("Invalid ML-DSA-65 signature from merkle candidate {peer_id}");
504 failures.push(format!("{peer_id}: invalid signature"));
505 continue;
506 }
507 if candidate.merkle_payment_timestamp != merkle_payment_timestamp {
508 warn!("Timestamp mismatch from merkle candidate {peer_id}");
509 failures.push(format!("{peer_id}: timestamp mismatch"));
510 continue;
511 }
512 valid.push((peer_id, candidate));
513 }
514 Err(e) => {
515 debug!("Failed to get merkle candidate from {peer_id}: {e}");
516 failures.push(format!("{peer_id}: {e}"));
517 }
518 }
519 }
520
521 if valid.len() < CANDIDATES_PER_POOL {
522 return Err(Error::InsufficientPeers(format!(
523 "Got {} merkle candidates, need {CANDIDATES_PER_POOL}. Failures: [{}]",
524 valid.len(),
525 failures.join("; ")
526 )));
527 }
528
529 let target_peer = PeerId::from_bytes(*target_address);
530 valid.sort_by_key(|(peer_id, _)| peer_id.xor_distance(&target_peer));
531
532 let candidates: Vec<MerklePaymentCandidateNode> = valid
533 .into_iter()
534 .take(CANDIDATES_PER_POOL)
535 .map(|(_, candidate)| candidate)
536 .collect();
537
538 candidates
539 .try_into()
540 .map_err(|_| Error::Payment("Failed to convert candidates to fixed array".to_string()))
541 }
542
543 pub(crate) async fn merkle_upload_chunks(
553 &self,
554 chunk_contents: Vec<Bytes>,
555 addresses: Vec<[u8; 32]>,
556 batch_result: &MerkleBatchPaymentResult,
557 progress: Option<&mpsc::Sender<UploadEvent>>,
558 ) -> Result<(usize, crate::data::client::batch::WaveAggregateStats)> {
559 let mut stored = 0usize;
560 let mut stats = crate::data::client::batch::WaveAggregateStats::default();
561 let store_limiter = self.controller().store.clone();
562 let batch_size = chunk_contents.len();
565 let store_concurrency = store_limiter.current().min(batch_size.max(1));
566 let mut upload_stream = stream::iter(chunk_contents.into_iter().zip(addresses).map(
567 |(content, addr)| {
568 let proof_bytes = batch_result.proofs.get(&addr).cloned();
569 let limiter = store_limiter.clone();
570 async move {
571 let started = std::time::Instant::now();
572 let proof = proof_bytes.ok_or_else(|| {
573 Error::Payment(format!(
574 "Missing merkle proof for chunk {}",
575 hex::encode(addr)
576 ))
577 })?;
578 let peers = self.close_group_peers(&addr).await?;
579 observe_op(
580 &limiter,
581 || async move {
582 self.chunk_put_to_close_group(content, proof, &peers).await
583 },
584 classify_error,
585 )
586 .await
587 .map(|_| started)
588 }
589 },
590 ))
591 .buffer_unordered(store_concurrency);
592
593 while let Some(result) = upload_stream.next().await {
594 let started = result?;
595 let duration_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
596 stats.store_durations_ms.push(duration_ms);
597 stats.chunk_attempts_total = stats.chunk_attempts_total.saturating_add(1);
598 stats.retries_histogram[0] = stats.retries_histogram[0].saturating_add(1);
599 stored += 1;
600 if let Some(tx) = progress {
601 let _ = tx.try_send(UploadEvent::ChunkStored {
602 stored,
603 total: batch_size,
604 });
605 }
606 }
607
608 Ok((stored, stats))
609 }
610}
611
612pub fn finalize_merkle_batch(
617 prepared: PreparedMerkleBatch,
618 winner_pool_hash: [u8; 32],
619) -> Result<MerkleBatchPaymentResult> {
620 let chunk_count = prepared.addresses.len();
621 let xornames: Vec<XorName> = prepared.addresses.iter().map(|a| XorName(*a)).collect();
622
623 let winner_pool = prepared
625 .candidate_pools
626 .iter()
627 .find(|pool| pool.hash() == winner_pool_hash)
628 .ok_or_else(|| {
629 Error::Payment(format!(
630 "Winner pool {} not found in candidate pools",
631 hex::encode(winner_pool_hash)
632 ))
633 })?;
634
635 info!("Generating merkle proofs for {chunk_count} chunks");
637 let mut proofs = HashMap::with_capacity(chunk_count);
638
639 for (i, xorname) in xornames.iter().enumerate() {
640 let address_proof = prepared
641 .tree
642 .generate_address_proof(i, *xorname)
643 .map_err(|e| {
644 Error::Payment(format!(
645 "Failed to generate address proof for chunk {i}: {e}"
646 ))
647 })?;
648
649 let merkle_proof = MerklePaymentProof::new(*xorname, address_proof, winner_pool.clone());
650
651 let tagged_bytes = serialize_merkle_proof(&merkle_proof)
652 .map_err(|e| Error::Serialization(format!("Failed to serialize merkle proof: {e}")))?;
653
654 proofs.insert(prepared.addresses[i], tagged_bytes);
655 }
656
657 info!("Merkle batch payment complete: {chunk_count} proofs generated");
658
659 Ok(MerkleBatchPaymentResult {
660 proofs,
661 chunk_count,
662 storage_cost_atto: "0".to_string(),
663 gas_cost_wei: 0,
664 merkle_payment_timestamp: prepared.merkle_payment_timestamp,
665 })
666}
667
// Compile-time check that the future returned by `Client::merkle_upload_chunks`
// is `Send`. Nothing here is ever executed; it only has to type-check.
#[cfg(test)]
mod send_assertions {
    use super::*;
    use crate::data::client::Client;

    /// Compiles only when `T: Send`.
    fn _assert_send<T: Send>(_: &T) {}

    // `todo!()` supplies a value of the right type without constructing one;
    // the code after it is unreachable, which is why these lints are allowed.
    #[allow(
        dead_code,
        unreachable_code,
        unused_variables,
        clippy::diverging_sub_expression
    )]
    async fn _merkle_upload_chunks_is_send(client: &Client) {
        let batch_result: MerkleBatchPaymentResult = todo!();
        let fut = client.merkle_upload_chunks(Vec::new(), Vec::new(), &batch_result, None);
        _assert_send(&fut);
    }
}
688
// Unit tests for payment-mode selection, merkle tree and proof construction,
// and `finalize_merkle_batch`. None of them need a network or a wallet.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use ant_protocol::evm::{Amount, MerkleTree, RewardsAddress, CANDIDATES_PER_POOL};

    /// Current Unix time in whole seconds, as used for merkle payment timestamps.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn test_auto_below_threshold() {
        assert!(!should_use_merkle(1, PaymentMode::Auto));
        assert!(!should_use_merkle(10, PaymentMode::Auto));
        assert!(!should_use_merkle(63, PaymentMode::Auto));
    }

    #[test]
    fn test_auto_at_and_above_threshold() {
        assert!(should_use_merkle(64, PaymentMode::Auto));
        assert!(should_use_merkle(65, PaymentMode::Auto));
        assert!(should_use_merkle(1000, PaymentMode::Auto));
    }

    #[test]
    fn test_merkle_mode_forces_at_2() {
        assert!(!should_use_merkle(1, PaymentMode::Merkle));
        assert!(should_use_merkle(2, PaymentMode::Merkle));
        assert!(should_use_merkle(3, PaymentMode::Merkle));
    }

    #[test]
    fn test_single_mode_always_false() {
        assert!(!should_use_merkle(0, PaymentMode::Single));
        assert!(!should_use_merkle(64, PaymentMode::Single));
        assert!(!should_use_merkle(1000, PaymentMode::Single));
    }

    #[test]
    fn test_default_mode_is_auto() {
        assert_eq!(PaymentMode::default(), PaymentMode::Auto);
    }

    #[test]
    fn test_threshold_value() {
        assert_eq!(DEFAULT_MERKLE_THRESHOLD, 64);
    }

    /// Deterministic, distinct 32-byte addresses derived from `0..count`.
    fn make_test_addresses(count: usize) -> Vec<[u8; 32]> {
        (0..count)
            .map(|i| XorName::from_content(&i.to_le_bytes()).0)
            .collect()
    }

    #[test]
    fn test_tree_depth_for_known_sizes() {
        let cases = [(2, 1), (4, 2), (16, 4), (100, 7), (256, 8)];
        for (count, expected_depth) in cases {
            let addrs = make_test_addresses(count);
            let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
            let tree = MerkleTree::from_xornames(xornames).unwrap();
            assert_eq!(
                tree.depth(),
                expected_depth,
                "depth mismatch for {count} leaves"
            );
        }
    }

    #[test]
    fn test_proof_generation_and_verification_for_all_leaves() {
        let addrs = make_test_addresses(16);
        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();

        for (i, xn) in xornames.iter().enumerate() {
            let proof = tree.generate_address_proof(i, *xn).unwrap();
            assert!(proof.verify(), "proof for leaf {i} should verify");
            assert_eq!(proof.depth(), usize::from(tree.depth()));
        }
    }

    #[test]
    fn test_proof_fails_for_wrong_address() {
        let addrs = make_test_addresses(8);
        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
        let tree = MerkleTree::from_xornames(xornames).unwrap();

        let wrong = XorName::from_content(b"wrong");
        let proof = tree.generate_address_proof(0, wrong).unwrap();
        assert!(!proof.verify(), "proof with wrong address should fail");
    }

    #[test]
    fn test_tree_too_few_leaves() {
        let xornames = vec![XorName::from_content(b"only_one")];
        let result = MerkleTree::from_xornames(xornames);
        assert!(result.is_err());
    }

    #[test]
    fn test_tree_at_max_leaves() {
        let addrs = make_test_addresses(MAX_LEAVES);
        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
        let tree = MerkleTree::from_xornames(xornames).unwrap();
        assert_eq!(tree.leaf_count(), MAX_LEAVES);
    }

    /// A proof serialized with `serialize_merkle_proof` carries the merkle tag
    /// and deserializes back to the same address and pool size.
    #[test]
    fn test_merkle_proof_serialize_deserialize_roundtrip() {
        use ant_protocol::payment::{deserialize_merkle_proof, serialize_merkle_proof};

        let addrs = make_test_addresses(4);
        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
        let tree = MerkleTree::from_xornames(xornames.clone()).unwrap();

        let timestamp = now_secs();

        let candidates = tree.reward_candidates(timestamp).unwrap();
        let midpoint = candidates.first().unwrap().clone();

        let pool = MerklePaymentCandidatePool {
            midpoint_proof: midpoint,
            candidate_nodes: make_dummy_candidate_nodes(timestamp),
        };

        let address_proof = tree.generate_address_proof(0, xornames[0]).unwrap();
        let merkle_proof = MerklePaymentProof::new(xornames[0], address_proof, pool);

        let tagged = serialize_merkle_proof(&merkle_proof).unwrap();
        assert_eq!(
            tagged.first().copied(),
            Some(0x02),
            "tag should be PROOF_TAG_MERKLE"
        );

        let deserialized = deserialize_merkle_proof(&tagged).unwrap();
        assert_eq!(deserialized.address, merkle_proof.address);
        assert_eq!(
            deserialized.winner_pool.candidate_nodes.len(),
            CANDIDATES_PER_POOL
        );
    }

    // NOTE(review): this only asserts the fixture's timestamp differs from the
    // compared value; the real rejection lives in
    // `Client::collect_validated_candidates`, which needs a live `Client`.
    #[test]
    fn test_candidate_wrong_timestamp_rejected() {
        let candidate = MerklePaymentCandidateNode {
            pub_key: vec![0u8; 32],
            price: Amount::ZERO,
            reward_address: RewardsAddress::new([0u8; 20]),
            merkle_payment_timestamp: 1000,
            signature: vec![0u8; 64],
        };

        assert_ne!(candidate.merkle_payment_timestamp, 2000);
    }

    /// A full pool of placeholder candidates. Node `i` uses byte `i` for its
    /// key, reward address and signature, a price of 1024, and `timestamp`.
    // `i as u8` may truncate for very large pools; harmless for placeholder bytes.
    #[allow(clippy::cast_possible_truncation)]
    fn make_dummy_candidate_nodes(
        timestamp: u64,
    ) -> [MerklePaymentCandidateNode; CANDIDATES_PER_POOL] {
        std::array::from_fn(|i| MerklePaymentCandidateNode {
            pub_key: vec![i as u8; 32],
            price: Amount::from(1024u64),
            reward_address: RewardsAddress::new([i as u8; 20]),
            merkle_payment_timestamp: timestamp,
            signature: vec![i as u8; 64],
        })
    }

    /// Builds a `PreparedMerkleBatch` over `count` test addresses, with one
    /// placeholder pool per reward midpoint and no network access.
    fn make_prepared_merkle_batch(count: usize) -> PreparedMerkleBatch {
        let addrs = make_test_addresses(count);
        let xornames: Vec<XorName> = addrs.iter().map(|a| XorName(*a)).collect();
        let tree = MerkleTree::from_xornames(xornames).unwrap();

        let timestamp = now_secs();

        let midpoints = tree.reward_candidates(timestamp).unwrap();

        let candidate_pools: Vec<MerklePaymentCandidatePool> = midpoints
            .into_iter()
            .map(|mp| MerklePaymentCandidatePool {
                midpoint_proof: mp,
                candidate_nodes: make_dummy_candidate_nodes(timestamp),
            })
            .collect();

        let pool_commitments = candidate_pools
            .iter()
            .map(MerklePaymentCandidatePool::to_commitment)
            .collect();

        PreparedMerkleBatch {
            depth: tree.depth(),
            pool_commitments,
            merkle_payment_timestamp: timestamp,
            candidate_pools,
            tree,
            addresses: addrs,
        }
    }

    #[test]
    fn test_finalize_merkle_batch_with_valid_winner() {
        let prepared = make_prepared_merkle_batch(4);
        let winner_hash = prepared.candidate_pools[0].hash();
        let expected_addresses = prepared.addresses.clone();
        let expected_timestamp = prepared.merkle_payment_timestamp;

        let result = finalize_merkle_batch(prepared, winner_hash);
        assert!(
            result.is_ok(),
            "should succeed with valid winner: {result:?}"
        );

        let batch = result.unwrap();
        assert_eq!(batch.chunk_count, 4);
        assert_eq!(batch.proofs.len(), 4);
        assert_eq!(batch.merkle_payment_timestamp, expected_timestamp);
        assert_eq!(batch.storage_cost_atto, "0");
        assert_eq!(batch.gas_cost_wei, 0);

        // Every input address must have a non-empty proof.
        for addr in &expected_addresses {
            let proof_bytes = batch
                .proofs
                .get(addr)
                .expect("finalize must produce a proof for every address");
            assert!(!proof_bytes.is_empty());
        }
    }

    #[test]
    fn test_finalize_merkle_batch_with_invalid_winner() {
        let prepared = make_prepared_merkle_batch(4);
        let bad_hash = [0xFF; 32];

        let result = finalize_merkle_batch(prepared, bad_hash);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not found in candidate pools"), "got: {err}");
    }

    #[test]
    fn test_finalize_merkle_batch_proofs_are_deserializable() {
        use ant_protocol::payment::deserialize_merkle_proof;

        let prepared = make_prepared_merkle_batch(8);
        let winner_hash = prepared.candidate_pools[0].hash();

        let batch = finalize_merkle_batch(prepared, winner_hash).unwrap();

        for (addr, proof_bytes) in &batch.proofs {
            let proof = deserialize_merkle_proof(proof_bytes);
            assert!(
                proof.is_ok(),
                "proof for {} should deserialize: {:?}",
                hex::encode(addr),
                proof.err()
            );
        }
    }

    /// The number of sub-batches for N addresses is `ceil(N / MAX_LEAVES)`.
    #[test]
    fn test_batch_split_calculation() {
        let addrs = make_test_addresses(MAX_LEAVES);
        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 1);

        let addrs = make_test_addresses(MAX_LEAVES + 1);
        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 2);

        let addrs = make_test_addresses(3 * MAX_LEAVES);
        assert_eq!(addrs.chunks(MAX_LEAVES).count(), 3);
    }
}
999}