use alloc::collections::BTreeMap;
use std::sync::OnceLock;
use crypto::keys::bip44::Bip44;
use primitive_types::U256;
use super::{TransactionBuilder, TransactionBuilderError};
use crate::{
client::api::{
transaction_builder::{requirement::native_tokens::get_native_tokens_diff, Remainders},
RemainderData,
},
types::block::{
address::{Address, Ed25519Address},
output::{unlock_condition::AddressUnlockCondition, BasicOutputBuilder, ChainId, NativeToken, Output, TokenId},
},
};
impl TransactionBuilder {
    /// Recomputes `self.remainders` from the currently selected inputs and outputs.
    ///
    /// The previous remainder state is swapped out (only the configured remainder address is carried over) and
    /// rebuilt:
    /// 1. a storage deposit return output is created for every address whose storage deposit return amount in the
    ///    inputs exceeds the one in the outputs;
    /// 2. excess input amount (beyond what native-token remainder outputs need) and excess input mana are, where an
    ///    existing output for the remainder address allows it, recorded in `added_amount` / `added_mana`;
    /// 3. whatever amount, mana or native tokens are still left over are put into new remainder outputs.
    ///
    /// Returns `true` if the rebuilt remainders differ from the previous ones.
    ///
    /// # Errors
    /// Returns [`TransactionBuilderError::MissingInputWithEd25519Address`] when no remainder address can be
    /// determined, and propagates errors from output building and the amount/mana helpers.
    ///
    /// # Panics
    /// Panics if the outputs require more amount or mana than the inputs provide.
    pub(crate) fn update_remainders(&mut self) -> Result<bool, TransactionBuilderError> {
        // Start from a fresh state that only keeps the configured remainder address; keep the old state so we can
        // report whether anything changed.
        let mut old_remainders = Remainders {
            address: self.remainders.address.clone(),
            ..Default::default()
        };
        core::mem::swap(&mut self.remainders, &mut old_remainders);

        let (input_amount, output_amount, inputs_sdr, outputs_sdr) = self.amount_sums();

        // Give back any storage deposit that the outputs do not already return.
        for (address, amount) in inputs_sdr {
            let output_sdr_amount = outputs_sdr.get(&address).copied().unwrap_or(0);
            if amount > output_sdr_amount {
                let diff = amount - output_sdr_amount;
                let sdr_output = BasicOutputBuilder::new_with_amount(diff)
                    .with_unlock_conditions([AddressUnlockCondition::new(address.clone())])
                    .finish_output()?;
                log::debug!("Created storage deposit return output of {diff} for {address:?}");
                self.remainders.storage_deposit_returns.push(sdr_output);
            }
        }

        let (input_nts, output_nts) = self.get_input_output_native_tokens();
        log::debug!("input_nts: {input_nts:#?}");
        log::debug!("output_nts: {output_nts:#?}");
        let native_tokens_diff = get_native_tokens_diff(input_nts, output_nts);

        let (input_mana, output_mana) = self.mana_sums(false)?;

        let mut amount_diff = input_amount.checked_sub(output_amount).expect("amount underflow");
        let mut mana_diff = input_mana.checked_sub(output_mana).expect("mana underflow");

        // When the burn configuration burns mana, the initial mana excess is not carried into the remainder.
        if self.burn.as_ref().is_some_and(|b| b.mana()) {
            mana_diff = mana_diff.saturating_sub(self.initial_mana_excess()?);
        }

        let (remainder_address, chain) = self
            .get_remainder_address()?
            .ok_or(TransactionBuilderError::MissingInputWithEd25519Address)?;

        // Amount reserved for the remainder outputs that must hold the leftover native tokens
        // (`create_remainder_outputs` gives each native token its own output).
        let nt_min_amount = self.native_token_remainder().amount() * native_tokens_diff.len() as u64;

        if amount_diff > nt_min_amount {
            // First try to hand excess amount back to the chain outputs it came from.
            // NOTE(review): the allocation is keyed by `Some(chain_id)` even when only a generic (`None`) remainder
            // output qualifies; confirm the consumer of `added_amount` handles that case.
            for (chain_id, (input_amount, output_amount)) in self.amount_chains()? {
                if input_amount > output_amount
                    && (self.output_for_remainder_exists(Some(chain_id), &remainder_address)
                        || self.output_for_remainder_exists(None, &remainder_address))
                {
                    let amount_to_add = (amount_diff - nt_min_amount).min(input_amount - output_amount);
                    log::debug!(
                        "Allocating {amount_to_add} excess input amount for output with address {remainder_address} and chain id {chain_id}"
                    );
                    amount_diff -= amount_to_add;
                    self.remainders.added_amount.insert(Some(chain_id), amount_to_add);
                }
            }
            // Anything still above the native-token reserve goes to an existing output for the remainder address.
            if amount_diff > nt_min_amount && self.output_for_remainder_exists(None, &remainder_address) {
                let amount_to_add = amount_diff - nt_min_amount;
                log::debug!(
                    "Allocating {amount_to_add} excess input amount for output with address {remainder_address}"
                );
                amount_diff = nt_min_amount;
                self.remainders.added_amount.insert(None, amount_to_add);
            }
        }

        if mana_diff > 0 {
            // Same strategy as for the amount: chain outputs first, then a generic output for the remainder address.
            for (chain_id, (input_mana, output_mana)) in self.mana_chains()? {
                if input_mana > output_mana
                    && (self.output_for_remainder_exists(Some(chain_id), &remainder_address)
                        || self.output_for_remainder_exists(None, &remainder_address))
                {
                    let mana_to_add = mana_diff.min(input_mana - output_mana);
                    log::debug!(
                        "Allocating {mana_to_add} excess input mana for output with address {remainder_address} and chain id {chain_id}"
                    );
                    mana_diff -= mana_to_add;
                    self.remainders.added_mana.insert(Some(chain_id), mana_to_add);
                }
            }
            if mana_diff > 0 && self.output_for_remainder_exists(None, &remainder_address) {
                log::debug!("Allocating {mana_diff} excess input mana for output with address {remainder_address}");
                // `take` moves the whole rest into the map and leaves `mana_diff` at 0.
                self.remainders.added_mana.insert(None, std::mem::take(&mut mana_diff));
            }
        }

        if amount_diff == 0 && mana_diff == 0 && native_tokens_diff.is_empty() {
            log::debug!("No remainder required");
        } else {
            self.create_remainder_outputs(amount_diff, mana_diff, native_tokens_diff, remainder_address, chain)?;
        }

        Ok(self.remainders != old_remainders)
    }

    /// Determines the address remainder outputs are sent to, together with the BIP44 chain of the input owning it.
    ///
    /// - If a remainder address is configured, it is always returned; its chain is taken from the first available
    ///   or selected input whose required address equals it, otherwise `None`.
    /// - Otherwise the first selected input whose required address is backed by an Ed25519 address is used.
    ///
    /// Returns `Ok(None)` if no remainder address can be determined.
    ///
    /// # Panics
    /// Panics if an input has no required address (such inputs are expected to have been filtered out earlier).
    pub(crate) fn get_remainder_address(&self) -> Result<Option<(Address, Option<Bip44>)>, TransactionBuilderError> {
        if let Some(remainder_address) = &self.remainders.address {
            for input in self.available_inputs.iter().chain(self.selected_inputs.iter()) {
                let required_address = input
                    .output
                    .required_address(
                        self.latest_slot_commitment_id.slot_index(),
                        self.protocol_parameters.committable_age_range(),
                    )?
                    .expect("expiration unlockable outputs already filtered out");
                if &required_address == remainder_address {
                    return Ok(Some((remainder_address.clone(), input.chain)));
                }
            }
            return Ok(Some((remainder_address.clone(), None)));
        }

        for input in self.selected_inputs.iter() {
            let required_address = input
                .output
                .required_address(
                    self.latest_slot_commitment_id.slot_index(),
                    self.protocol_parameters.committable_age_range(),
                )?
                .expect("expiration unlockable outputs already filtered out");
            if let Some(&required_address) = required_address.backing_ed25519() {
                return Ok(Some((required_address.into(), input.chain)));
            }
        }

        Ok(None)
    }

    /// Returns whether `output` can absorb remainder amount/mana.
    ///
    /// With `Some(chain_id)` the output must have that chain id. With `None` it must either have no chain id or be a
    /// basic/account/NFT output whose required address is `remainder_address`. In both cases the output must carry
    /// neither an expiration nor a timelock unlock condition.
    fn is_output_for_remainder(&self, output: &Output, chain_id: Option<ChainId>, remainder_address: &Address) -> bool {
        (output.chain_id() == chain_id
            || (chain_id.is_none()
                && (output.is_basic() || output.is_account() || output.is_nft())
                && matches!(
                    output.required_address(
                        self.latest_slot_commitment_id.slot_index(),
                        self.protocol_parameters.committable_age_range(),
                    ),
                    Ok(Some(address)) if &address == remainder_address
                )))
            && output.unlock_conditions().expiration().is_none()
            && output.unlock_conditions().timelock().is_none()
    }

    /// Returns whether any added output can absorb remainder amount/mana (see `is_output_for_remainder`).
    fn output_for_remainder_exists(&self, chain_id: Option<ChainId>, remainder_address: &Address) -> bool {
        self.added_outputs
            .iter()
            .any(|o| self.is_output_for_remainder(o, chain_id, remainder_address))
    }

    /// Returns a mutable reference to the first added output that can absorb remainder amount/mana
    /// (see `is_output_for_remainder`), if any.
    pub(crate) fn get_output_for_remainder(
        &mut self,
        chain_id: Option<ChainId>,
        remainder_address: &Address,
    ) -> Option<&mut Output> {
        // Locate with a shared borrow first so the predicate can read `self`, then re-borrow mutably.
        let index = self
            .added_outputs
            .iter()
            .position(|o| self.is_output_for_remainder(o, chain_id, remainder_address))?;
        self.added_outputs.iter_mut().nth(index)
    }

    /// Estimates what the current selection needs for remainders.
    ///
    /// Returns `(remainder_amount, native_tokens_remainder, mana_remainder)`:
    /// - `remainder_amount`: minimum amount for the remainder output(s): one native-token remainder per leftover
    ///   native token, or a single basic remainder if there is none;
    /// - `native_tokens_remainder`: whether any native tokens are left over;
    /// - `mana_remainder`: whether excess mana needs its own remainder output because no existing output for the
    ///   remainder address can absorb it (assumed `true` if no remainder address is known).
    pub(crate) fn required_remainder_amount(&mut self) -> Result<(u64, bool, bool), TransactionBuilderError> {
        let (input_nts, output_nts) = self.get_input_output_native_tokens();
        let remainder_native_tokens = get_native_tokens_diff(input_nts, output_nts);
        let remainder_amount = if !remainder_native_tokens.is_empty() {
            self.native_token_remainder().amount() * remainder_native_tokens.len() as u64
        } else {
            self.basic_remainder().amount()
        };

        let (selected_mana, required_mana) = self.mana_sums(false)?;
        let remainder_address = self.get_remainder_address()?.map(|v| v.0);
        let mana_chains = self.mana_chains()?;

        let mut mana_remainder = selected_mana > required_mana
            && remainder_address.map_or(true, |remainder_address| {
                let mut mana_diff = selected_mana - required_mana;
                for (chain_id, (mana_in, mana_out)) in mana_chains {
                    if mana_in > mana_out && self.output_for_remainder_exists(Some(chain_id), &remainder_address) {
                        mana_diff -= mana_diff.min(mana_in - mana_out);
                        if mana_diff == 0 {
                            return false;
                        }
                    }
                }
                mana_diff > 0 && !self.output_for_remainder_exists(None, &remainder_address)
            });

        if self.burn.as_ref().is_some_and(|b| b.mana()) {
            let initial_excess = self.initial_mana_excess()?;
            // Same as `selected_mana > required_mana + initial_excess`, but cannot overflow.
            mana_remainder &= selected_mana.saturating_sub(required_mana) > initial_excess;
        }

        Ok((remainder_amount, !remainder_native_tokens.is_empty(), mana_remainder))
    }

    /// Creates the remainder outputs for `remainder_address` and appends them to `self.remainders.data`.
    ///
    /// Every native token except the last gets its own minimum-amount basic output. A final "catch-all" output
    /// receives the remaining amount (reduced by those outputs' amounts, saturating at 0), all of `mana_diff`, and
    /// the last native token if there is one.
    ///
    /// # Errors
    /// Propagates native token / output building errors and fails if the catch-all output does not satisfy the
    /// storage deposit requirement.
    fn create_remainder_outputs(
        &mut self,
        amount_diff: u64,
        mana_diff: u64,
        mut native_tokens: BTreeMap<TokenId, U256>,
        remainder_address: Address,
        remainder_address_chain: Option<Bip44>,
    ) -> Result<(), TransactionBuilderError> {
        let mut remaining_amount = amount_diff;
        let mut catchall_native_token = None;

        // The last native token goes into the catch-all output; each of the others gets its own output.
        if let Some((token_id, amount)) = native_tokens.pop_last() {
            catchall_native_token.replace(NativeToken::new(token_id, amount)?);
            for (token_id, amount) in native_tokens {
                let output =
                    BasicOutputBuilder::new_with_minimum_amount(self.protocol_parameters.storage_score_parameters())
                        .add_unlock_condition(AddressUnlockCondition::new(remainder_address.clone()))
                        .with_native_token(NativeToken::new(token_id, amount)?)
                        .finish_output()?;
                log::debug!(
                    "Created remainder output of amount {}, mana {} and native token ({token_id}: {amount}) for {remainder_address:?}",
                    output.amount(),
                    output.mana()
                );
                remaining_amount = remaining_amount.saturating_sub(output.amount());
                self.remainders.data.push(RemainderData {
                    output,
                    chain: remainder_address_chain,
                    address: remainder_address.clone(),
                });
            }
        }

        let mut catchall = BasicOutputBuilder::new_with_amount(remaining_amount)
            .with_mana(mana_diff)
            .add_unlock_condition(AddressUnlockCondition::new(remainder_address.clone()));
        if let Some(native_token) = catchall_native_token {
            catchall = catchall.with_native_token(native_token);
        }
        let catchall = catchall.finish_output()?;
        catchall.verify_storage_deposit(self.protocol_parameters.storage_score_parameters())?;
        log::debug!(
            "Created remainder output of amount {}, mana {} and native token {:?} for {remainder_address:?}",
            catchall.amount(),
            catchall.mana(),
            catchall.native_token(),
        );
        self.remainders.data.push(RemainderData {
            output: catchall,
            chain: remainder_address_chain,
            address: remainder_address,
        });

        Ok(())
    }

    /// Returns a process-wide `'static` copy of `template`, reusing a previously interned equal output if present.
    ///
    /// Remainder templates depend on the builder's storage score parameters, so they cannot be cached in a single
    /// global slot. Instead one copy per distinct template is interned. The leaked memory is therefore bounded by
    /// the number of distinct parameter sets seen by the process.
    fn intern_remainder_template(template: Output) -> &'static Output {
        static TEMPLATES: OnceLock<std::sync::Mutex<std::vec::Vec<&'static Output>>> = OnceLock::new();

        let mut templates = TEMPLATES
            .get_or_init(|| std::sync::Mutex::new(std::vec::Vec::new()))
            .lock()
            // The critical section cannot leave the list inconsistent, so a poisoned lock is still usable.
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        if let Some(existing) = templates.iter().copied().find(|existing| **existing == template) {
            return existing;
        }

        let interned: &'static Output = std::boxed::Box::leak(std::boxed::Box::new(template));
        templates.push(interned);
        interned
    }

    /// Returns a minimum-amount basic output (null Ed25519 address unlock condition, no native tokens) built with
    /// this builder's storage score parameters. It is used to estimate the cost of a plain remainder output.
    pub(crate) fn basic_remainder(&self) -> &'static Output {
        Self::intern_remainder_template(
            BasicOutputBuilder::new_with_minimum_amount(self.protocol_parameters.storage_score_parameters())
                .add_unlock_condition(AddressUnlockCondition::new(Ed25519Address::null()))
                .finish_output()
                .expect("a minimum-amount basic output with a single address unlock condition is valid"),
        )
    }

    /// Returns a minimum-amount basic output holding one native token (null Ed25519 address unlock condition)
    /// built with this builder's storage score parameters. It is used to estimate the cost of a native-token
    /// remainder output.
    pub(crate) fn native_token_remainder(&self) -> &'static Output {
        Self::intern_remainder_template(
            BasicOutputBuilder::new_with_minimum_amount(self.protocol_parameters.storage_score_parameters())
                .add_unlock_condition(AddressUnlockCondition::new(Ed25519Address::null()))
                .with_native_token(NativeToken::new(TokenId::null(), 1).unwrap())
                .finish_output()
                .expect("a minimum-amount basic output with one native token is valid"),
        )
    }
}