use crate::common::{Address as RewardsAddress, U256};
use crate::contract::data_type_conversion;
use crate::quoting_metrics::QuotingMetrics;
use serde::{Deserialize, Serialize};
#[cfg(any(test, feature = "test-utils"))]
use crate::common::Amount;
#[cfg(any(test, feature = "test-utils"))]
use std::path::PathBuf;
#[cfg(any(test, feature = "test-utils"))]
use thiserror::Error;
/// Error returned when a total cost unit does not fit in the bits left over
/// after one byte of the packed word is set aside for the data type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CostUnitOverflow;

impl std::fmt::Display for CostUnitOverflow {
    /// Writes a fixed, human-readable description that names the bit limit.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_fmt(format_args!(
            "total_cost_unit exceeds {TOTAL_COST_UNIT_BITS}-bit limit (top 8 bits reserved for packing)"
        ))
    }
}

impl std::error::Error for CostUnitOverflow {}
/// 32-byte hash identifying a reward pool.
pub type PoolHash = [u8; 32];
/// Number of candidate nodes carried by every pool commitment.
pub const CANDIDATES_PER_POOL: usize = 16;
/// Largest Merkle tree depth accepted by `DiskMerklePaymentContract::pay_for_merkle_tree`.
pub const MAX_MERKLE_DEPTH: u8 = 8;
/// Number of bits available to `total_cost_unit` inside a packed 256-bit word;
/// the remaining 8 bits hold the data type (see `encode_data_type_and_cost`).
const TOTAL_COST_UNIT_BITS: usize = 248;
/// Cost units per record for Solidity data type `0` (graph entry).
const COST_UNIT_GRAPH_ENTRY: u64 = 1;
/// Cost units per record for Solidity data type `1` (scratchpad).
const COST_UNIT_SCRATCHPAD: u64 = 100;
/// Cost units per record for Solidity data type `2` (chunk).
const COST_UNIT_CHUNK: u64 = 10;
/// Cost units per record for Solidity data type `3` (pointer).
const COST_UNIT_POINTER: u64 = 20;
/// Looks up the per-record cost weight for a Solidity data type code.
///
/// Codes `0..=3` map to the graph-entry, scratchpad, chunk and pointer
/// weights respectively; every other code yields zero.
fn cost_unit_for_data_type(solidity_data_type: u8) -> U256 {
    let weight = match solidity_data_type {
        0 => COST_UNIT_GRAPH_ENTRY,
        1 => COST_UNIT_SCRATCHPAD,
        2 => COST_UNIT_CHUNK,
        3 => COST_UNIT_POINTER,
        // Unrecognised codes contribute nothing to the total.
        _ => return U256::ZERO,
    };
    U256::from(weight)
}
/// Returns the number of reward pools expected for a Merkle tree of `depth`.
///
/// The count is `2^ceil(depth / 2)`.
///
/// `depth` is not validated here. When the count cannot be represented in a
/// `usize` (the halved depth reaches the bit width of `usize`), the result
/// saturates to `usize::MAX` instead of overflowing the shift, which would
/// panic in debug builds and wrap to a misleadingly small count in release
/// builds.
pub fn expected_reward_pools(depth: u8) -> usize {
    let half_depth = u32::from(depth.div_ceil(2));
    // `checked_shl` yields `None` once the shift amount is >= usize::BITS.
    1usize.checked_shl(half_depth).unwrap_or(usize::MAX)
}
/// A reward pool in its unpacked form: the pool hash plus the full quoting
/// metrics of each of its `CANDIDATES_PER_POOL` candidate nodes.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct PoolCommitment {
/// 32-byte hash identifying this pool.
pub pool_hash: PoolHash,
/// The pool's candidate nodes; the array type fixes the count.
pub candidates: [CandidateNode; CANDIDATES_PER_POOL],
}
/// A candidate node in a reward pool, carrying its full quoting metrics.
///
/// Use `to_packed` to obtain the compact `CandidateNodePacked` form.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct CandidateNode {
/// Address associated with this candidate for rewards.
pub rewards_address: RewardsAddress,
/// Quoting metrics from which the data type and total cost unit are derived.
pub metrics: QuotingMetrics,
}
/// Compact form of a `CandidateNode` in which the metrics are reduced to a
/// single 256-bit word.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct CandidateNodePacked {
/// Address associated with this candidate for rewards.
pub rewards_address: RewardsAddress,
/// Packed word produced by `encode_data_type_and_cost`: the data type in the
/// low 8 bits and the total cost unit in the bits above them.
pub data_type_and_total_cost_unit: U256,
}
/// Compact form of a `PoolCommitment`, holding packed candidates.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct PoolCommitmentPacked {
/// 32-byte hash identifying this pool.
pub pool_hash: PoolHash,
/// The pool's candidates in packed form; the array type fixes the count.
pub candidates: [CandidateNodePacked; CANDIDATES_PER_POOL],
}
/// Packs a data type and a total cost unit into one 256-bit word.
///
/// The data type occupies the low 8 bits; the total cost unit is shifted
/// left by 8 and occupies the remaining bits.
///
/// # Errors
///
/// Returns [`CostUnitOverflow`] when `total_cost_unit` is `2^TOTAL_COST_UNIT_BITS`
/// or larger, i.e. when shifting it would lose its top bits.
pub fn encode_data_type_and_cost(
    data_type: u8,
    total_cost_unit: U256,
) -> Result<U256, CostUnitOverflow> {
    // Smallest value that no longer fits below the reserved byte.
    let first_unrepresentable = U256::from(1) << TOTAL_COST_UNIT_BITS;
    if total_cost_unit < first_unrepresentable {
        let shifted_cost = total_cost_unit << 8;
        Ok(shifted_cost | U256::from(data_type))
    } else {
        Err(CostUnitOverflow)
    }
}
/// Splits a packed word back into `(data_type, total_cost_unit)`.
///
/// Inverse of `encode_data_type_and_cost`: the low 8 bits are the data type
/// and everything above them is the total cost unit. Compiled for tests only.
#[cfg(test)]
pub fn decode_data_type_and_cost(packed: U256) -> (u8, U256) {
    let low_byte = packed & U256::from(0xFF);
    (low_byte.to::<u8>(), packed >> 8)
}
/// Computes the total cost unit described by a node's quoting metrics.
///
/// Each `(data_type, count)` entry of `records_per_type` contributes
/// `count` times the cost weight of its converted data type. If that sum is
/// zero (for example because there are no entries), the result falls back to
/// the weight of `metrics.data_type` multiplied by `close_records_stored`,
/// with the record count treated as at least 1.
pub fn calculate_total_cost_unit(metrics: &QuotingMetrics) -> U256 {
    let mut weighted_total = U256::ZERO;
    for (record_type, record_count) in metrics.records_per_type.iter() {
        let unit = cost_unit_for_data_type(data_type_conversion(*record_type));
        weighted_total = weighted_total + unit * U256::from(*record_count);
    }
    if weighted_total > U256::ZERO {
        return weighted_total;
    }

    // Nothing usable in the per-type breakdown: price by the headline data
    // type, never multiplying by fewer than one record.
    let record_count = (metrics.close_records_stored as u64).max(1);
    let unit = cost_unit_for_data_type(data_type_conversion(metrics.data_type));
    unit * U256::from(record_count)
}
impl CandidateNode {
    /// Converts this candidate into its packed form.
    ///
    /// The metrics are reduced to the converted data type and the total cost
    /// unit, which are then packed into a single word.
    ///
    /// # Errors
    ///
    /// Returns [`CostUnitOverflow`] when the total cost unit is too large to
    /// be packed.
    pub fn to_packed(&self) -> Result<CandidateNodePacked, CostUnitOverflow> {
        let packed_word = encode_data_type_and_cost(
            data_type_conversion(self.metrics.data_type),
            calculate_total_cost_unit(&self.metrics),
        )?;
        Ok(CandidateNodePacked {
            rewards_address: self.rewards_address,
            data_type_and_total_cost_unit: packed_word,
        })
    }
}
impl PoolCommitment {
pub fn to_packed(&self) -> Result<PoolCommitmentPacked, CostUnitOverflow> {
let mut packed_candidates = Vec::with_capacity(CANDIDATES_PER_POOL);
for c in &self.candidates {
packed_candidates.push(c.to_packed()?);
}
let candidates: [CandidateNodePacked; CANDIDATES_PER_POOL] = packed_candidates
.try_into()
.expect("Vec length matches CANDIDATES_PER_POOL");
Ok(PoolCommitmentPacked {
pool_hash: self.pool_hash,
candidates,
})
}
}
/// Errors reported by `DiskMerklePaymentContract`.
///
/// Only compiled for tests or with the `test-utils` feature.
#[cfg(any(test, feature = "test-utils"))]
#[derive(Debug, Error)]
pub enum SmartContractError {
/// A pool did not contain the expected number of candidate nodes.
#[error("Wrong number of candidate nodes: expected {expected}, got {got}")]
WrongCandidateCount { expected: usize, got: usize },
/// The number of submitted pools did not match the count expected for the depth.
#[error("Wrong number of candidate pools: expected {expected}, got {got}")]
WrongPoolCount { expected: usize, got: usize },
/// The requested depth is greater than the supported maximum.
#[error("Depth {depth} exceeds maximum supported depth {max}")]
DepthTooLarge { depth: u8, max: u8 },
/// No payment record was found for a winner pool hash; the payload is the
/// hex-encoded hash.
#[error("Payment not found for winner pool hash: {0}")]
PaymentNotFound(String),
/// An underlying I/O operation failed (converted from `std::io::Error`).
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
/// JSON serialization or deserialization failed (converted from `serde_json::Error`).
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
}
/// Record describing a completed Merkle payment.
///
/// `DiskMerklePaymentContract` writes this as JSON when a payment is made and
/// reads it back when the payment is looked up by winner pool hash.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OnChainPaymentInfo {
/// Merkle tree depth the payment was made for.
pub depth: u8,
/// Timestamp supplied by the payer when the payment was made.
pub merkle_payment_timestamp: u64,
/// Paid nodes as `(rewards address, index of the node within the winner pool's candidates)`.
pub paid_node_addresses: Vec<(RewardsAddress, usize)>,
}
/// Helper, compiled only for tests or with the `test-utils` feature, that
/// simulates Merkle batch payments and persists each outcome as a JSON file.
#[cfg(any(test, feature = "test-utils"))]
pub struct DiskMerklePaymentContract {
/// Directory in which one `<hex pool hash>.json` file is written per payment.
storage_path: PathBuf, }
#[cfg(any(test, feature = "test-utils"))]
impl DiskMerklePaymentContract {
    /// Creates a contract that stores its payment records under `storage_path`.
    ///
    /// The directory, including any missing parents, is created if needed.
    ///
    /// # Errors
    ///
    /// Returns [`SmartContractError::IoError`] if the directory cannot be created.
    pub fn new_with_path(storage_path: PathBuf) -> Result<Self, SmartContractError> {
        std::fs::create_dir_all(&storage_path)?;
        Ok(Self { storage_path })
    }

    /// Creates a contract using the default storage location.
    ///
    /// The location is `autonomi/merkle_payments` inside the directory
    /// returned by `dirs_next::data_dir()`, or the relative path
    /// `.autonomi/merkle_payments` when no data directory is available.
    ///
    /// # Errors
    ///
    /// Returns [`SmartContractError::IoError`] if the directory cannot be created.
    pub fn new() -> Result<Self, SmartContractError> {
        let storage_path = if let Some(data_dir) = dirs_next::data_dir() {
            data_dir.join("autonomi").join("merkle_payments")
        } else {
            PathBuf::from(".autonomi").join("merkle_payments")
        };
        Self::new_with_path(storage_path)
    }

    /// Simulates a Merkle batch payment and stores the outcome on disk.
    ///
    /// A winner pool is chosen at random from `pool_commitments`, then
    /// `depth` distinct candidates of that pool are chosen at random as the
    /// paid nodes. The resulting [`OnChainPaymentInfo`] is written as pretty
    /// JSON to `<storage_path>/<hex winner pool hash>.json`. Progress is
    /// printed to stdout.
    ///
    /// Returns the winner pool hash and a placeholder amount of `2^depth`.
    ///
    /// # Errors
    ///
    /// * [`SmartContractError::DepthTooLarge`] if `depth` exceeds `MAX_MERKLE_DEPTH`.
    /// * [`SmartContractError::WrongPoolCount`] if the number of pools is not
    ///   `expected_reward_pools(depth)`.
    /// * [`SmartContractError::WrongCandidateCount`] if a pool does not hold
    ///   `CANDIDATES_PER_POOL` candidates.
    /// * [`SmartContractError::JsonError`] / [`SmartContractError::IoError`]
    ///   if the record cannot be serialized or written.
    pub fn pay_for_merkle_tree(
        &self,
        depth: u8,
        pool_commitments: Vec<PoolCommitment>,
        merkle_payment_timestamp: u64,
    ) -> Result<(PoolHash, Amount), SmartContractError> {
        if depth > MAX_MERKLE_DEPTH {
            return Err(SmartContractError::DepthTooLarge {
                depth,
                max: MAX_MERKLE_DEPTH,
            });
        }

        let expected_pools = expected_reward_pools(depth);
        if pool_commitments.len() != expected_pools {
            return Err(SmartContractError::WrongPoolCount {
                expected: expected_pools,
                got: pool_commitments.len(),
            });
        }

        // `candidates` is a fixed-size array, so this cannot fail today; it is
        // kept as a guard should the field type ever change.
        for pool in &pool_commitments {
            if pool.candidates.len() != CANDIDATES_PER_POOL {
                return Err(SmartContractError::WrongCandidateCount {
                    expected: CANDIDATES_PER_POOL,
                    got: pool.candidates.len(),
                });
            }
        }

        // The pool count equals `expected_reward_pools(depth)`, which is at
        // least 1 for every accepted depth, so the modulo is well defined.
        let winner_pool_idx = rand::random::<usize>() % pool_commitments.len();
        let winner_pool = &pool_commitments[winner_pool_idx];
        let winner_pool_hash = winner_pool.pool_hash;

        println!("\n=== MERKLE BATCH PAYMENT ===");
        println!("Depth: {depth}");
        println!("Total pools: {}", pool_commitments.len());
        println!("Nodes per pool: {CANDIDATES_PER_POOL}");
        println!("Winner pool index: {winner_pool_idx}");
        println!("Winner pool hash: {}", hex::encode(winner_pool_hash));

        // Draw until `depth` distinct indices are collected. This terminates
        // because `depth <= MAX_MERKLE_DEPTH`, which is smaller than the
        // number of candidates in a pool.
        let mut winner_node_indices = std::collections::HashSet::new();
        while winner_node_indices.len() < depth as usize {
            let idx = rand::random::<usize>() % winner_pool.candidates.len();
            winner_node_indices.insert(idx);
        }
        let winner_node_indices: Vec<usize> = winner_node_indices.into_iter().collect();

        println!(
            "\nSelected {} winner nodes from pool:",
            winner_node_indices.len()
        );
        let mut paid_node_addresses = Vec::with_capacity(winner_node_indices.len());
        for (i, &node_idx) in winner_node_indices.iter().enumerate() {
            let addr = winner_pool.candidates[node_idx].rewards_address;
            paid_node_addresses.push((addr, node_idx));
            println!(" Node {}: {addr}", i + 1);
        }
        println!(
            "\nSimulating payment to {} nodes...",
            paid_node_addresses.len()
        );
        println!("=========================\n");

        let info = OnChainPaymentInfo {
            depth,
            merkle_payment_timestamp,
            paid_node_addresses,
        };
        let file_path = self
            .storage_path
            .join(format!("{}.json", hex::encode(winner_pool_hash)));
        let json = serde_json::to_string_pretty(&info)?;
        std::fs::write(&file_path, json)?;
        println!("✓ Stored payment info to: {}", file_path.display());

        let placeholder_amount = Amount::from(2_u64.pow(depth as u32));
        Ok((winner_pool_hash, placeholder_amount))
    }

    /// Loads the payment record stored for `winner_pool_hash`.
    ///
    /// # Errors
    ///
    /// * [`SmartContractError::PaymentNotFound`] if no record file exists for
    ///   the hash.
    /// * [`SmartContractError::IoError`] if the record file exists but cannot
    ///   be read (for example because of missing permissions).
    /// * [`SmartContractError::JsonError`] if the file content is not a valid
    ///   [`OnChainPaymentInfo`] document.
    pub fn get_payment_info(
        &self,
        winner_pool_hash: PoolHash,
    ) -> Result<OnChainPaymentInfo, SmartContractError> {
        let file_path = self
            .storage_path
            .join(format!("{}.json", hex::encode(winner_pool_hash)));
        // Only a missing file means "no payment"; any other failure is a real
        // I/O problem and must not be disguised as a missing payment.
        let json = std::fs::read_to_string(&file_path).map_err(|err| {
            if err.kind() == std::io::ErrorKind::NotFound {
                SmartContractError::PaymentNotFound(hex::encode(winner_pool_hash))
            } else {
                SmartContractError::IoError(err)
            }
        })?;
        Ok(serde_json::from_str(&json)?)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A typical data type / cost pair survives an encode → decode round trip.
    #[test]
    fn test_encode_decode_data_type_and_cost() {
        let data_type: u8 = 2;
        let total_cost_unit = U256::from(1000u64);
        let packed = encode_data_type_and_cost(data_type, total_cost_unit).unwrap();
        let (decoded_type, decoded_cost) = decode_data_type_and_cost(packed);
        assert_eq!(decoded_type, data_type);
        assert_eq!(decoded_cost, total_cost_unit);
    }

    /// The largest data type byte does not bleed into the cost bits.
    #[test]
    fn test_encode_decode_boundary_values() {
        let data_type: u8 = 255;
        let total_cost_unit = U256::from(100u64);
        let packed = encode_data_type_and_cost(data_type, total_cost_unit).unwrap();
        let (decoded_type, decoded_cost) = decode_data_type_and_cost(packed);
        assert_eq!(decoded_type, data_type);
        assert_eq!(decoded_cost, total_cost_unit);
    }

    /// All-zero input packs to, and unpacks from, zero.
    #[test]
    fn test_encode_decode_zero_values() {
        let packed = encode_data_type_and_cost(0, U256::ZERO).unwrap();
        let (decoded_type, decoded_cost) = decode_data_type_and_cost(packed);
        assert_eq!(decoded_type, 0);
        assert_eq!(decoded_cost, U256::ZERO);
    }

    /// A cost that uses the reserved top bits is rejected.
    #[test]
    fn test_encode_returns_error_on_overflow() {
        let overflow_value = U256::MAX;
        let result = encode_data_type_and_cost(0, overflow_value);
        assert!(result.is_err());
    }

    /// Pins the exact limit: `2^248 - 1` is the largest accepted cost and
    /// `2^248` is the first rejected one.
    #[test]
    fn test_encode_cost_limit_boundary() {
        let first_rejected = U256::from(1) << TOTAL_COST_UNIT_BITS;
        let largest_accepted = first_rejected - U256::from(1);

        let packed = encode_data_type_and_cost(u8::MAX, largest_accepted).unwrap();
        // Every bit of the word is in use at the limit.
        assert_eq!(packed, U256::MAX);
        assert_eq!(
            decode_data_type_and_cost(packed),
            (u8::MAX, largest_accepted)
        );

        assert_eq!(
            encode_data_type_and_cost(0, first_rejected),
            Err(CostUnitOverflow)
        );
    }

    /// Costs well beyond 64 bits still round-trip.
    #[test]
    fn test_encode_decode_large_cost() {
        let data_type: u8 = 1;
        let large_cost = U256::from(u128::MAX);
        let packed = encode_data_type_and_cost(data_type, large_cost).unwrap();
        let (decoded_type, decoded_cost) = decode_data_type_and_cost(packed);
        assert_eq!(decoded_type, data_type);
        assert_eq!(decoded_cost, large_cost);
    }

    /// The total is the sum of `count * weight` over all per-type entries.
    #[test]
    fn test_calculate_total_cost_unit() {
        let metrics = QuotingMetrics {
            data_type: 0,
            data_size: 1024 * 1024,
            close_records_stored: 100,
            records_per_type: vec![(0, 10), (1, 20), (2, 5)],
            max_records: 1000,
            received_payment_count: 50,
            live_time: 3600,
            network_density: None,
            network_size: Some(1000),
        };
        let total_cost = calculate_total_cost_unit(&metrics);
        assert_eq!(total_cost, U256::from(220u64));
    }

    /// With no per-type entries and zero stored records, the fallback prices
    /// a single record of the headline data type.
    #[test]
    fn test_calculate_total_cost_unit_empty_records() {
        let metrics = QuotingMetrics {
            data_type: 0,
            data_size: 1024,
            close_records_stored: 0,
            records_per_type: vec![],
            max_records: 1000,
            received_payment_count: 0,
            live_time: 0,
            network_density: None,
            network_size: None,
        };
        let total_cost = calculate_total_cost_unit(&metrics);
        assert_eq!(total_cost, U256::from(10u64));
    }

    /// Packing a candidate keeps its address and encodes the converted data
    /// type together with the computed total cost.
    #[test]
    fn test_candidate_node_to_packed() {
        let metrics = QuotingMetrics {
            data_type: 0,
            data_size: 1024 * 1024,
            close_records_stored: 100,
            records_per_type: vec![(0, 10)],
            max_records: 1000,
            received_payment_count: 50,
            live_time: 3600,
            network_density: None,
            network_size: Some(1000),
        };
        let candidate = CandidateNode {
            rewards_address: RewardsAddress::from([0x42; 20]),
            metrics,
        };
        let packed = candidate.to_packed().unwrap();
        assert_eq!(packed.rewards_address, candidate.rewards_address);
        let (data_type, total_cost) =
            decode_data_type_and_cost(packed.data_type_and_total_cost_unit);
        // Data type 0 is expected to convert to Solidity type 2, priced at
        // 10 cost units per record: 10 records -> 100.
        assert_eq!(data_type, 2);
        assert_eq!(total_cost, U256::from(100u64));
    }

    /// Packing a pool keeps its hash, candidate count and candidate order.
    #[test]
    fn test_pool_commitment_to_packed() {
        let candidates: [CandidateNode; CANDIDATES_PER_POOL] =
            std::array::from_fn(|i| CandidateNode {
                rewards_address: RewardsAddress::from([i as u8; 20]),
                metrics: QuotingMetrics {
                    data_type: 0,
                    data_size: 1024,
                    close_records_stored: i * 10,
                    records_per_type: vec![(0, i as u32)],
                    max_records: 1000,
                    received_payment_count: i,
                    live_time: 3600,
                    network_density: None,
                    network_size: None,
                },
            });
        let pool = PoolCommitment {
            pool_hash: [0x42; 32],
            candidates,
        };
        let packed = pool.to_packed().unwrap();
        assert_eq!(packed.pool_hash, pool.pool_hash);
        assert_eq!(packed.candidates.len(), CANDIDATES_PER_POOL);
        assert_eq!(
            packed.candidates[0].rewards_address,
            pool.candidates[0].rewards_address
        );
    }
}