use std::collections::HashMap;
use super::{Requirement, TransactionBuilder, TransactionBuilderError};
use crate::{
client::secret::types::InputSigningData,
types::block::{
address::Address,
input::{UtxoInput, INPUT_COUNT_MAX},
output::{
unlock_condition::StorageDepositReturnUnlockCondition, AccountOutput, BasicOutput, ChainId, FoundryOutput,
NftOutput, Output,
},
slot::{SlotCommitmentId, SlotIndex},
},
};
/// Returns the storage deposit return unlock condition of `output` while it still applies at
/// `slot_index`.
///
/// Yields `None` in either of these cases:
/// - `output` has no storage deposit return unlock condition;
/// - `output` has an expiration unlock condition whose slot index is not greater than
///   `slot_index` (i.e. the expiration has been reached).
pub(crate) fn sdruc_not_expired(
    output: &Output,
    slot_index: SlotIndex,
) -> Option<&StorageDepositReturnUnlockCondition> {
    let unlock_conditions = output.unlock_conditions();
    let sdruc = unlock_conditions.storage_deposit_return()?;

    // Without an expiration condition the deposit return can never lapse.
    let still_active = match unlock_conditions.expiration() {
        Some(expiration) => slot_index < expiration.slot_index(),
        None => true,
    };

    if still_active {
        Some(sdruc)
    } else {
        None
    }
}
impl TransactionBuilder {
    /// Checks whether the selected inputs cover the required output amount and, if not, selects
    /// one additional input from `available_inputs`.
    ///
    /// After an input is selected, `Requirement::Amount` is pushed onto `self.requirements` again
    /// so the balance is re-evaluated with the new input included.
    ///
    /// # Errors
    ///
    /// - `TransactionBuilderError::AdditionalInputsRequired(Requirement::Amount)` when more amount
    ///   is needed but `allow_additional_input_selection` is `false`.
    /// - `TransactionBuilderError::InsufficientAmount` when no available input receives a score.
    /// - Any error propagated from `amount_balance` or `select_input`.
    pub(crate) fn fulfill_amount_requirement(&mut self) -> Result<(), TransactionBuilderError> {
        let (input_amount, output_amount) = self.amount_balance()?;

        if output_amount <= input_amount {
            log::debug!("Amount requirement already fulfilled");
            return Ok(());
        }

        log::debug!("Fulfilling amount requirement with input amount {input_amount}, output amount {output_amount}");

        if !self.allow_additional_input_selection {
            return Err(TransactionBuilderError::AdditionalInputsRequired(Requirement::Amount));
        }

        // Cannot underflow: `output_amount > input_amount` was established above.
        let missing_amount = output_amount - input_amount;
        let slot_commitment_id = self.latest_slot_commitment_id;

        match self.next_input_for_amount(missing_amount, slot_commitment_id) {
            Some(input) => {
                // Re-queue the requirement so the balance is checked again with this input added.
                self.requirements.push(Requirement::Amount);
                self.select_input(input)?;
                Ok(())
            }
            None => Err(TransactionBuilderError::InsufficientAmount {
                found: input_amount,
                required: output_amount,
            }),
        }
    }

    /// Sums the amounts of the selected inputs and of the non-remainder outputs.
    ///
    /// Returns `(inputs_sum, outputs_sum, inputs_sdr, outputs_sdr)`:
    /// - `inputs_sum`: total amount of `selected_inputs`.
    /// - `outputs_sum`: total amount of `non_remainder_outputs()`, plus, for every return address
    ///   in `inputs_sdr`, the part of that address's SDR amount not covered by `outputs_sdr`.
    /// - `inputs_sdr`: per return address, the summed amount of every storage deposit return
    ///   unlock condition on a selected input that is still active at the latest committed slot
    ///   (see `sdruc_not_expired`).
    /// - `outputs_sdr`: per `simple_deposit_address`, the summed amount of the basic
    ///   non-remainder outputs that report such an address.
    pub(crate) fn amount_sums(&self) -> (u64, u64, HashMap<Address, u64>, HashMap<Address, u64>) {
        let slot_index = self.latest_slot_commitment_id.slot_index();

        let mut inputs_sum: u64 = 0;
        let mut inputs_sdr: HashMap<Address, u64> = HashMap::new();
        for input in self.selected_inputs.iter() {
            inputs_sum += input.output.amount();

            if let Some(sdruc) = sdruc_not_expired(&input.output, slot_index) {
                *inputs_sdr.entry(sdruc.return_address().clone()).or_default() += sdruc.amount();
            }
        }

        let mut outputs_sum: u64 = 0;
        let mut outputs_sdr: HashMap<Address, u64> = HashMap::new();
        for output in self.non_remainder_outputs() {
            outputs_sum += output.amount();

            let deposit_address = match output {
                Output::Basic(basic) => basic.simple_deposit_address(),
                _ => None,
            };
            if let Some(address) = deposit_address {
                *outputs_sdr.entry(address.clone()).or_default() += output.amount();
            }
        }

        // Any deposit that must be returned but is not already matched by an output to the same
        // return address still has to be paid out, so it counts towards the required amount.
        outputs_sum += inputs_sdr
            .iter()
            .map(|(address, &input_sdr_amount)| {
                let covered = outputs_sdr.get(address).copied().unwrap_or(0);
                input_sdr_amount.saturating_sub(covered)
            })
            .sum::<u64>();

        (inputs_sum, outputs_sum, inputs_sdr, outputs_sdr)
    }

    /// Returns `(inputs_sum, required_outputs_sum)` from `amount_sums`.
    ///
    /// The second value is adjusted for the amount reported by `required_remainder_amount`:
    /// - when inputs exceed outputs, only the part of the remainder amount not already covered by
    ///   that surplus is added;
    /// - otherwise the full remainder amount is added, but only if a native-token or mana
    ///   remainder is flagged as required.
    ///
    /// # Errors
    ///
    /// Propagates any error from `required_remainder_amount`.
    pub(crate) fn amount_balance(&mut self) -> Result<(u64, u64), TransactionBuilderError> {
        let (inputs_sum, outputs_sum, _, _) = self.amount_sums();
        let (remainder_amount, native_tokens_remainder, mana_remainder) = self.required_remainder_amount()?;

        let additional_required = if inputs_sum > outputs_sum {
            remainder_amount.saturating_sub(inputs_sum - outputs_sum)
        } else if native_tokens_remainder || mana_remainder {
            remainder_amount
        } else {
            0
        };

        Ok((inputs_sum, outputs_sum + additional_required))
    }

    /// Maps every chain ID touched by the transaction to `(input_amount, output_amount)`.
    ///
    /// Output entries come from the chain IDs of `non_remainder_outputs()`; a later output with
    /// the same chain ID replaces an earlier one. Input amounts are then added under the input's
    /// chain ID, resolved with `or_from_output_id` against the input's output ID.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`.
    pub(crate) fn amount_chains(&self) -> Result<HashMap<ChainId, (u64, u64)>, TransactionBuilderError> {
        let mut amounts: HashMap<ChainId, (u64, u64)> = HashMap::new();

        for output in self.non_remainder_outputs() {
            if let Some(chain_id) = output.chain_id() {
                amounts.insert(chain_id, (0, output.amount()));
            }
        }

        for input in self.selected_inputs.iter() {
            if let Some(chain_id) = input.output.chain_id() {
                let resolved_id = chain_id.or_from_output_id(input.output_id());
                amounts.entry(resolved_id).or_default().0 += input.output.amount();
            }
        }

        Ok(amounts)
    }

    /// Removes and returns the available input with the highest `score_for_amount`, or `None` if
    /// no available input receives a score.
    ///
    /// Ties go to the input found last, and the input is taken out with `Vec::swap_remove`, so the
    /// order of `available_inputs` is not preserved.
    fn next_input_for_amount(
        &mut self,
        missing_amount: u64,
        slot_commitment_id: SlotCommitmentId,
    ) -> Option<InputSigningData> {
        let slot_index = slot_commitment_id.slot_index();

        let mut best: Option<(usize, usize)> = None;
        for (idx, input) in self.available_inputs.iter().enumerate() {
            if let Some(score) = self.score_for_amount(input, missing_amount, slot_index) {
                // `>=` lets a later equal score win, mirroring `Iterator::max_by_key`.
                if best.map_or(true, |(best_score, _)| score >= best_score) {
                    best = Some((score, idx));
                }
            }
        }

        best.map(|(_, idx)| self.available_inputs.swap_remove(idx))
    }

    /// Scores how useful `input` is for covering `missing_amount`; higher is better.
    ///
    /// Returns `None` if the input is not a basic, NFT, account or foundry output, or if its
    /// estimated amount gain is zero.
    ///
    /// The amount gained starts at the input's amount and is reduced (saturating at zero) by:
    /// - the amount of a still-active storage deposit return unlock condition, which also adds the
    ///   work score of `basic_remainder()`;
    /// - then exactly one of:
    ///   - the amount of the transitioned output, if `transition_input` returns `Ok(Some(_))`;
    ///     its work score is also added (errors from `transition_input` are ignored here);
    ///   - otherwise, if the input holds a native token, the amount of `native_token_remainder()`,
    ///     whose work score is added on top;
    ///   - otherwise, if the gain so far exceeds `missing_amount`, the amount of
    ///     `basic_remainder()`, whose work score is used for the remainder.
    ///
    /// The amount factor is `exp(-diff / u64::MAX)` when the gain covers `missing_amount`. Otherwise
    /// it is `exp(-diff / missing_amount)`, scaled by the fraction of the `INPUT_COUNT_MAX` input
    /// slots that are still free. The work factor is `exp(-work_score / u32::MAX)`. Their product,
    /// scaled by `usize::MAX` and rounded, is the score. The final `as` cast saturates, so negative
    /// products become `0`.
    fn score_for_amount(&self, input: &InputSigningData, missing_amount: u64, slot_index: SlotIndex) -> Option<usize> {
        let output = &input.output;
        let scorable_kinds = [
            BasicOutput::KIND,
            NftOutput::KIND,
            AccountOutput::KIND,
            FoundryOutput::KIND,
        ];
        if !scorable_kinds.contains(&output.kind()) {
            return None;
        }

        let params = &self.protocol_parameters;
        let mut work_score = params.work_score(&UtxoInput::from(*input.output_id()));
        let mut amount_gained = output.amount();
        let mut remainder_work_score = 0;

        if let Some(sdruc) = sdruc_not_expired(output, slot_index) {
            amount_gained = amount_gained.saturating_sub(sdruc.amount());
            remainder_work_score = params.work_score(self.basic_remainder())
        }

        match self.transition_input(input) {
            Ok(Some(transitioned)) => {
                amount_gained = amount_gained.saturating_sub(transitioned.amount());
                work_score += params.work_score(&transitioned);
            }
            _ if output.native_token().is_some() => {
                amount_gained = amount_gained.saturating_sub(self.native_token_remainder().amount());
                remainder_work_score += params.work_score(self.native_token_remainder());
            }
            _ if amount_gained > missing_amount => {
                amount_gained = amount_gained.saturating_sub(self.basic_remainder().amount());
                remainder_work_score = params.work_score(self.basic_remainder());
            }
            _ => {}
        }

        work_score += remainder_work_score;

        if amount_gained == 0 {
            return None;
        }

        let amount_diff = amount_gained.abs_diff(missing_amount) as f64;
        let amount_score = if amount_gained < missing_amount {
            // Short of the target: prefer larger gains, discounted by how few input slots remain.
            let shortfall_factor = (-amount_diff / missing_amount as f64).exp();
            let free_slots_factor =
                (INPUT_COUNT_MAX as f64 - self.selected_inputs.len() as f64) / INPUT_COUNT_MAX as f64;
            shortfall_factor * free_slots_factor
        } else {
            // Target covered: prefer the smallest overshoot.
            (-amount_diff / u64::MAX as f64).exp()
        };
        let work_factor = (-(work_score as f64) / u32::MAX as f64).exp();

        Some((amount_score * work_factor * usize::MAX as f64).round() as usize)
    }
}