use crate::common::{Address, Amount, Calldata, TxHash};
use crate::contract::payment_vault::error::Error;
use crate::contract::payment_vault::interface::IPaymentVault;
use crate::contract::payment_vault::interface::IPaymentVault::IPaymentVaultInstance;
use crate::merkle_batch_payment::PoolHash;
use crate::retry::{GasInfo, TransactionError, send_transaction_with_retries};
use crate::transaction_config::TransactionConfig;
use alloy::network::{Network, TransactionResponse};
use alloy::providers::Provider;
use exponential_backoff::Backoff;
use std::time::Duration;
/// Handle for interacting with a payment vault contract through a provider.
///
/// Wraps an [`IPaymentVaultInstance`] bound to a provider `P` on network `N`.
pub struct PaymentVaultHandler<P, N>
where
    P: Provider<N>,
    N: Network,
{
    /// The contract instance that all calls and transactions go through.
    pub contract: IPaymentVaultInstance<P, N>,
}
impl<P, N> PaymentVaultHandler<P, N>
where
P: Provider<N>,
N: Network,
{
/// Creates a handler for the payment vault contract at `contract_address`,
/// using `provider` for all communication with the chain.
pub fn new(contract_address: Address, provider: P) -> Self {
    Self {
        contract: IPaymentVault::new(contract_address, provider),
    }
}
/// Replaces the provider used by this handler.
///
/// The contract instance is rebuilt against the same contract address, so
/// only the provider changes.
pub fn set_provider(&mut self, provider: P) {
    // The address is copied out before the assignment, so the old instance is
    // no longer borrowed when it gets replaced.
    self.contract = IPaymentVault::new(*self.contract.address(), provider);
}
pub async fn pay_for_quotes<I: IntoIterator<Item: Into<IPaymentVault::DataPayment>>>(
&self,
data_payments: I,
transaction_config: &TransactionConfig,
) -> Result<(TxHash, GasInfo), Error> {
debug!("Paying for quotes.");
let (calldata, to) = self.pay_for_quotes_calldata(data_payments)?;
send_transaction_with_retries(
self.contract.provider(),
calldata,
to,
"pay for quotes",
transaction_config,
)
.await
.map_err(Error::from)
}
/// Builds the calldata for a `payForQuotes` call without sending it.
///
/// Returns the encoded calldata together with the contract address the
/// transaction should be sent to.
pub fn pay_for_quotes_calldata<I: IntoIterator<Item: Into<IPaymentVault::DataPayment>>>(
    &self,
    data_payments: I,
) -> Result<(Calldata, Address), Error> {
    // Convert every item into the contract's payment type up front.
    let payments = data_payments
        .into_iter()
        .map(Into::into)
        .collect::<Vec<IPaymentVault::DataPayment>>();

    let call = self.contract.payForQuotes(payments);
    let encoded = call.calldata().to_owned();

    Ok((encoded, *self.contract.address()))
}
/// Pays for a Merkle tree and reports the outcome recorded on chain.
///
/// Builds the `payForMerkleTree` calldata, sends the transaction, then looks
/// up the `MerklePaymentMade` event emitted by that transaction.
///
/// Returns the winner pool hash and total amount taken from the event,
/// together with the gas information of the transaction.
///
/// # Errors
///
/// Returns an [`Error`] if the calldata cannot be built, if sending the
/// transaction fails, or if the event cannot be retrieved.
pub async fn pay_for_merkle_tree<I, T>(
    &self,
    depth: u8,
    pool_commitments: I,
    merkle_payment_timestamp: u64,
    transaction_config: &TransactionConfig,
) -> Result<(PoolHash, Amount, GasInfo), Error>
where
    I: IntoIterator<Item = T>,
    T: Into<IPaymentVault::PoolCommitment>,
{
    debug!("Paying for Merkle tree: depth={depth}, timestamp={merkle_payment_timestamp}");

    let (payload, vault_address) =
        self.pay_for_merkle_tree_calldata(depth, pool_commitments, merkle_payment_timestamp)?;

    let (payment_tx, gas_info) = self
        .send_transaction_and_handle_errors(payload, vault_address, transaction_config)
        .await?;

    // Take the event apart once instead of reading its fields one by one.
    let IPaymentVault::MerklePaymentMade {
        winnerPoolHash: winner_hash_bytes,
        depth: event_depth,
        totalAmount: total_amount,
        merklePaymentTimestamp: event_timestamp,
        ..
    } = self.get_merkle_payment_event(payment_tx).await?;
    let winner_pool_hash = winner_hash_bytes.0;

    debug!(
        "MerklePaymentMade event: winnerPoolHash={}, depth={}, totalAmount={}, timestamp={}",
        hex::encode(winner_pool_hash),
        event_depth,
        total_amount,
        event_timestamp
    );

    Ok((winner_pool_hash, total_amount, gas_info))
}
/// Builds the calldata for a `payForMerkleTree` call without sending it.
///
/// Returns the encoded calldata together with the contract address the
/// transaction should be sent to.
pub fn pay_for_merkle_tree_calldata<I, T>(
    &self,
    depth: u8,
    pool_commitments: I,
    merkle_payment_timestamp: u64,
) -> Result<(Calldata, Address), Error>
where
    I: IntoIterator<Item = T>,
    T: Into<IPaymentVault::PoolCommitment>,
{
    // Convert every item into the contract's commitment type up front.
    let commitments = pool_commitments
        .into_iter()
        .map(Into::into)
        .collect::<Vec<IPaymentVault::PoolCommitment>>();

    let call = self
        .contract
        .payForMerkleTree(depth, commitments, merkle_payment_timestamp);
    let encoded = call.calldata().to_owned();

    Ok((encoded, *self.contract.address()))
}
pub async fn get_completed_merkle_payment(
&self,
winner_pool_hash: PoolHash,
) -> Result<IPaymentVault::CompletedMerklePayment, Error> {
debug!(
"Getting completed merkle payment for pool hash: {}",
hex::encode(winner_pool_hash)
);
let info = self
.contract
.getCompletedMerklePayment(winner_pool_hash.into())
.call()
.await
.map_err(Error::Contract)?;
if info.depth == 0 {
return Err(Error::PaymentNotFound(hex::encode(winner_pool_hash)));
}
debug!(
"getCompletedMerklePayment returned: depth={}, timestamp={}, paid_nodes={}",
info.depth,
info.merklePaymentTimestamp,
info.paidNodeAddresses.len()
);
Ok(info)
}
async fn get_merkle_payment_event(
&self,
tx_hash: TxHash,
) -> Result<IPaymentVault::MerklePaymentMade, Error> {
const MAX_ATTEMPTS: u32 = 3;
const INITIAL_DELAY_MS: u64 = 500;
const MAX_DELAY_MS: u64 = 8000;
let backoff = Backoff::new(
MAX_ATTEMPTS,
Duration::from_millis(INITIAL_DELAY_MS),
Some(Duration::from_millis(MAX_DELAY_MS)),
);
let mut last_error = None;
let mut attempt = 1;
for duration_opt in backoff {
match self.try_get_merkle_payment_event(tx_hash).await {
Ok(event) => return Ok(event),
Err(e) => {
last_error = Some(e);
if let Some(duration) = duration_opt {
debug!(
"Failed to get MerklePaymentMade event (attempt {}/{}), retrying in {}ms",
attempt,
MAX_ATTEMPTS,
duration.as_millis()
);
tokio::time::sleep(duration).await;
}
attempt += 1;
}
}
}
Err(last_error.unwrap_or_else(|| {
Error::Rpc("Failed to get MerklePaymentMade event after retries".to_string())
}))
}
async fn try_get_merkle_payment_event(
&self,
tx_hash: TxHash,
) -> Result<IPaymentVault::MerklePaymentMade, Error> {
let tx = self
.contract
.provider()
.get_transaction_by_hash(tx_hash)
.await
.map_err(|e| Error::Rpc(format!("Failed to get transaction: {e}")))?
.ok_or_else(|| Error::Rpc("Transaction not found".to_string()))?;
let block_number = tx
.block_number()
.ok_or_else(|| Error::Rpc("Transaction has no block number".to_string()))?;
let events = self
.contract
.MerklePaymentMade_filter()
.from_block(block_number)
.to_block(block_number)
.query()
.await
.map_err(|e| Error::Rpc(format!("Failed to query MerklePaymentMade events: {e}")))?;
events
.into_iter()
.find(|(_, log)| log.transaction_hash == Some(tx_hash))
.map(|(event, _)| event)
.ok_or_else(|| {
Error::Rpc("MerklePaymentMade event not found in transaction".to_string())
})
}
async fn send_transaction_and_handle_errors(
&self,
calldata: Calldata,
to: Address,
transaction_config: &TransactionConfig,
) -> Result<(TxHash, GasInfo), Error> {
let tx_result = crate::retry::send_transaction_with_retries(
self.contract.provider(),
calldata,
to,
"pay for merkle tree",
transaction_config,
)
.await;
match tx_result {
Ok((hash, gas_info)) => Ok((hash, gas_info)),
Err(TransactionError::TransactionReverted {
message,
revert_data,
nonce,
}) => {
let error = self.decode_revert_error(message, revert_data, nonce);
Err(error)
}
Err(other_err) => Err(Error::from(other_err)),
}
}
fn decode_revert_error(
&self,
message: String,
revert_data: Option<alloy::primitives::Bytes>,
nonce: Option<u64>,
) -> Error {
if let Some(revert_data_bytes) = &revert_data
&& let Some(decoded_err) = Error::try_decode_revert(revert_data_bytes)
{
return decoded_err;
}
Error::Transaction(TransactionError::TransactionReverted {
message,
revert_data,
nonce,
})
}
}