use std::cmp::Ordering;
/// Distributes `total_sats` among recipients in proportion to `splits`.
///
/// Returns one amount per entry of `splits`, in the same order. Whenever
/// `splits` is non-empty and `total_sats > 0`, the amounts sum to exactly
/// `total_sats`.
///
/// Rules, in order:
/// * Empty `splits` or `total_sats == 0`: every recipient gets 0.
/// * All splits are zero: `total_sats` is divided evenly and the first
///   `total_sats % n` recipients receive one extra sat.
/// * Otherwise each recipient first gets `floor(split * total_sats / sum(splits))`,
///   raised to at least 1 sat (this includes recipients whose split is 0).
///   - A rounding shortfall is handed out one sat at a time to the largest
///     splits (ties: lower index first).
///   - If the 1-sat minimums pushed the total above `total_sats`, the excess
///     is taken back. First, a proportional amount (at least 1 sat) is taken
///     from recipients that can spare it, smallest splits first. Then any
///     remainder is taken one sat at a time. When `total_sats` is smaller
///     than the number of recipients, some recipients end up with 0.
pub fn compute_sat_recipients(splits: &[u64], total_sats: u64) -> Vec<u64> {
    let num_recipients = splits.len();
    if splits.is_empty() || total_sats == 0 {
        return vec![0; num_recipients];
    }
    let total_split: u128 = splits.iter().map(|&x| x as u128).sum();
    if total_split == 0 {
        // No weights at all: split evenly, the remainder going to the first recipients.
        let base_amount = total_sats / num_recipients as u64;
        let remainder = total_sats % num_recipients as u64;
        let mut result = vec![base_amount; num_recipients];
        for result_item in result.iter_mut().take(remainder as usize) {
            *result_item += 1;
        }
        return result;
    }
    // Proportional share rounded down; u128 so `split * total_sats` cannot overflow.
    let mut sat_amounts: Vec<u64> = splits
        .iter()
        .map(|&split| ((split as u128 * total_sats as u128) / total_split) as u64)
        .collect();
    // Minimum of 1 sat per recipient (may be undone below if the total is too small).
    for amount in sat_amounts.iter_mut() {
        if *amount == 0 {
            *amount = 1;
        }
    }
    // Signed gap between the target and what is assigned so far. The assigned sum can
    // exceed u64::MAX after the 1-sat bumps, so both are widened to 128 bits.
    let assigned: u128 = sat_amounts.iter().map(|&a| a as u128).sum();
    let mut balance: i128 = total_sats as i128 - assigned as i128;
    let mut indexed_splits: Vec<(usize, u128)> = splits
        .iter()
        .map(|&s| s as u128)
        .enumerate()
        .collect();
    match balance.cmp(&0) {
        Ordering::Less => {
            // Over-assigned (only possible because of the 1-sat minimums): take sats
            // back, smallest splits first, ties broken by higher index first.
            indexed_splits.sort_by(|&(i1, s1), &(i2, s2)| s1.cmp(&s2).then(i2.cmp(&i1)));
            let initial_deficit = balance.unsigned_abs();
            for &(index, split) in &indexed_splits {
                // Proportional share of the initial deficit (at least 1 sat), capped at
                // what is still owed. Without the cap the balance could overshoot past
                // zero and the result would no longer sum to `total_sats`.
                let amount_to_remove = (split * initial_deficit / total_split)
                    .max(1)
                    .min(balance.unsigned_abs());
                if balance < 0 && sat_amounts[index] as u128 > amount_to_remove {
                    sat_amounts[index] -= amount_to_remove as u64;
                    balance += amount_to_remove as i128;
                }
                if balance == 0 {
                    break;
                }
            }
            if balance < 0 {
                // Still over-assigned: remove one sat at a time. Recipients holding at
                // least 2 sats come first, then those holding exactly 1 (smallest split
                // first within each group); any holding 0 go last, higher index first.
                indexed_splits.sort_by(|&(i1, s1), &(i2, s2)| {
                    let (a1, a2) = (sat_amounts[i1], sat_amounts[i2]);
                    match (a1 >= 2, a2 >= 2) {
                        (true, true) => s1.cmp(&s2),
                        (true, false) => Ordering::Less,
                        (false, true) => Ordering::Greater,
                        (false, false) => match (a1 >= 1, a2 >= 1) {
                            (true, true) => s1.cmp(&s2),
                            (true, false) => Ordering::Less,
                            (false, true) => Ordering::Greater,
                            (false, false) => i2.cmp(&i1),
                        },
                    }
                });
                for &(index, _) in &indexed_splits {
                    if balance < 0 && sat_amounts[index] > 0 {
                        sat_amounts[index] -= 1;
                        balance += 1;
                    }
                    if balance == 0 {
                        break;
                    }
                }
            }
        }
        Ordering::Equal => {}
        Ordering::Greater => {
            // Rounding shortfall: hand out one sat each to the largest splits
            // (stable sort, so ties go to the lower index).
            indexed_splits.sort_by_key(|&(_, split)| std::cmp::Reverse(split));
            for &(index, _) in &indexed_splits {
                if balance > 0 {
                    sat_amounts[index] += 1;
                    balance -= 1;
                } else {
                    break;
                }
            }
        }
    }
    sat_amounts
}
/// Variant of `compute_sat_recipients` for any type exposing a split weight
/// through [`HasSplit`].
///
/// Each value's weight is read with `get_split`. The result holds one sat amount
/// per element of `values`, in the same order. `values` is not modified.
pub fn compute_sat_recipients_generic<T: HasSplit + Clone>(
    values: &[T],
    total_sats: u64,
) -> Vec<u64> {
    compute_sat_recipients(
        &values.iter().map(T::get_split).collect::<Vec<u64>>(),
        total_sats,
    )
}
/// A payment recipient described either by a share count or by a fixed percentage.
///
/// In `fee_recipients_to_splits`, percentage-based recipients take their
/// percentage off the top. Share-based recipients divide whatever percentage
/// remains, in proportion to their share counts.
pub enum GenericRecipient {
    /// Weighted by `num_shares` relative to the other share-based recipients.
    ShareBased { num_shares: u64 },
    /// A fixed percentage (out of 100) of the total.
    PercentageBased { percentage: u64 },
}
/// Greatest common divisor of `a` and `b`, using Euclid's algorithm.
///
/// `gcd(x, 0) == x`, so folding a sequence from an initial 0 yields the GCD of
/// every element. `gcd(0, 0) == 0`.
fn gcd(mut a: u128, mut b: u128) -> u128 {
    while b != 0 {
        let rest = a % b;
        a = b;
        b = rest;
    }
    a
}
/// Reasons why `fee_recipients_to_splits` can reject a recipient list.
#[derive(Clone, PartialEq, Eq)]
pub enum RecipientsToSplitsError {
    /// The percentage-based recipients add up to more than 100%.
    TotalFeeExceeds100,
    /// The percentage-based recipients add up to exactly 100%, leaving nothing
    /// for the share-based recipients that are also present.
    FeeIs100ButNonFeeRecipientsExist,
}
impl std::fmt::Display for RecipientsToSplitsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RecipientsToSplitsError::TotalFeeExceeds100 => {
                write!(f, "Total fees exceeds 100%")
            }
            RecipientsToSplitsError::FeeIs100ButNonFeeRecipientsExist => {
                write!(f, "Total fees equal 100%, but non-fee recipients exist")
            }
        }
    }
}
// Debug renders the same human-readable message as Display.
impl std::fmt::Debug for RecipientsToSplitsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self)
    }
}
/// Implements the standard error trait, so this type can be boxed as
/// `dyn std::error::Error` and propagated with `?` into generic error types.
impl std::error::Error for RecipientsToSplitsError {}
pub fn fee_recipients_to_splits(
recipients: &[GenericRecipient],
) -> Result<Vec<u64>, RecipientsToSplitsError> {
let total_percentage: u128 = recipients
.iter()
.filter_map(|r| match r {
GenericRecipient::PercentageBased { percentage } => Some(*percentage as u128),
_ => None,
})
.sum();
if total_percentage > 100 {
return Err(RecipientsToSplitsError::TotalFeeExceeds100);
}
let share_recipients: Vec<&GenericRecipient> = recipients
.iter()
.filter(|r| matches!(r, GenericRecipient::ShareBased { .. }))
.collect();
if total_percentage == 100 && !share_recipients.is_empty() {
return Err(RecipientsToSplitsError::FeeIs100ButNonFeeRecipientsExist);
}
let remaining_percentage: u128 = 100 - total_percentage;
let total_shares: u128 = share_recipients
.iter()
.filter_map(|r| match r {
GenericRecipient::ShareBased { num_shares } => Some(*num_shares as u128),
_ => None,
})
.sum();
let mut result = Vec::with_capacity(recipients.len());
for recipient in recipients {
match recipient {
GenericRecipient::ShareBased { num_shares } => {
let value = (*num_shares as u128) * remaining_percentage;
result.push(value);
}
GenericRecipient::PercentageBased { percentage } => {
let value = if share_recipients.is_empty() {
*percentage as u128
} else {
(*percentage as u128) * total_shares
};
result.push(value);
}
}
}
let gcd_value = result
.iter()
.filter(|&&x| x != 0)
.fold(0u128, |acc, &x| gcd(acc, x));
if gcd_value > 1 {
result = result
.into_iter()
.map(|x| if x == 0 { 0 } else { x / gcd_value })
.collect();
}
let max_value = *result.iter().max().unwrap_or(&0);
if max_value > u64::MAX as u128 {
let scale_factor = (u64::MAX as f64) / (max_value as f64);
Ok(result
.into_iter()
.map(|x| (x as f64 * scale_factor).round() as u64)
.collect())
} else {
Ok(result.into_iter().map(|x| x as u64).collect())
}
}
/// Applies `fee_recipients_to_splits` to any recipient type that converts into
/// [`GenericRecipient`].
///
/// Returns clones of the inputs, in order, with each split replaced by the
/// computed weight. Errors from the conversion are returned unchanged.
pub fn fee_recipients_to_splits_generic<T: Into<GenericRecipient> + HasSplit + Clone>(
    recipients: &[T],
) -> Result<Vec<T>, RecipientsToSplitsError> {
    let as_generic: Vec<GenericRecipient> = recipients.iter().cloned().map(Into::into).collect();
    let computed = fee_recipients_to_splits(&as_generic)?;
    Ok(recipients
        .iter()
        .zip(computed)
        .map(|(original, split)| {
            let mut updated = original.clone();
            updated.set_split(split);
            updated
        })
        .collect())
}
/// Merges a local and a remote split list so that the remote recipients together
/// receive `remote_percentage`% of the weight and the local recipients the rest.
/// Relative proportions within each list are preserved.
///
/// * `remote_percentage` is clamped to 100.
/// * The returned vectors have the same lengths and order as `local_splits` and
///   `remote_splits`.
/// * The weights are divided by their greatest common divisor.
/// * If either side sums to zero, no percentage is applied. Both lists are only
///   reduced by their common GCD, so the side with weight effectively gets
///   everything.
///
/// When the exact integer weights do not fit in `u64`, or their intermediate
/// products would overflow `u128`, the weights are approximated in `f64`. They
/// are scaled so the largest becomes `u64::MAX`, and very small weights may
/// round to 0.
pub fn use_remote_splits(
    local_splits: &[u64],
    remote_splits: &[u64],
    remote_percentage: u64,
) -> (Vec<u64>, Vec<u64>) {
    let remote_percentage = remote_percentage.min(100) as u128;
    let local_percentage = 100u128 - remote_percentage;
    let total_local: u128 = local_splits.iter().map(|&x| x as u128).sum();
    let total_remote: u128 = remote_splits.iter().map(|&x| x as u128).sum();

    if total_local == 0 || total_remote == 0 {
        // One side carries no weight, so the percentage cannot be applied;
        // just reduce the inputs by their common GCD.
        let gcd_value = local_splits
            .iter()
            .chain(remote_splits)
            .filter(|&&x| x != 0)
            .fold(0u128, |acc, &x| gcd(acc, x as u128));
        let reduce = |splits: &[u64]| -> Vec<u64> {
            splits
                .iter()
                .map(|&x| if x == 0 { 0 } else { (x as u128 / gcd_value) as u64 })
                .collect()
        };
        return (reduce(local_splits), reduce(remote_splits));
    }

    // Cross-multiply so that sum(local) : sum(remote) == local% : remote%.
    // Dividing both multipliers by gcd(total_local, total_remote) keeps the ratio,
    // and therefore the GCD-reduced result, unchanged. It also shrinks the
    // intermediate products.
    let common = gcd(total_local, total_remote);
    let local_factor = local_percentage * (total_remote / common);
    let remote_factor = remote_percentage * (total_local / common);

    let scale_exact = |splits: &[u64], factor: u128| -> Option<Vec<u128>> {
        splits.iter().map(|&s| (s as u128).checked_mul(factor)).collect()
    };
    if let (Some(scaled_local), Some(scaled_remote)) = (
        scale_exact(local_splits, local_factor),
        scale_exact(remote_splits, remote_factor),
    ) {
        let divisor = scaled_local
            .iter()
            .chain(&scaled_remote)
            .filter(|&&x| x != 0)
            .fold(0u128, |acc, &x| gcd(acc, x))
            .max(1);
        let reduce = |values: Vec<u128>| -> Vec<u128> {
            values.into_iter().map(|x| x / divisor).collect()
        };
        let reduced_local = reduce(scaled_local);
        let reduced_remote = reduce(scaled_remote);
        let fits = reduced_local
            .iter()
            .chain(&reduced_remote)
            .all(|&x| x <= u64::MAX as u128);
        if fits {
            let to_u64 =
                |values: Vec<u128>| -> Vec<u64> { values.into_iter().map(|x| x as u64).collect() };
            return (to_u64(reduced_local), to_u64(reduced_remote));
        }
    }

    // The exact weights overflow (previously they were silently truncated by
    // `as u64`). Approximate them in f64 and rescale so the largest becomes
    // u64::MAX. At least one weight is positive here, so `max_weight > 0`.
    let weigh = |splits: &[u64], factor: u128| -> Vec<f64> {
        splits.iter().map(|&s| s as f64 * factor as f64).collect()
    };
    let local_weights = weigh(local_splits, local_factor);
    let remote_weights = weigh(remote_splits, remote_factor);
    let max_weight = local_weights
        .iter()
        .chain(&remote_weights)
        .copied()
        .fold(0.0f64, f64::max);
    let to_u64 = |weights: Vec<f64>| -> Vec<u64> {
        weights
            .into_iter()
            .map(|w| (w / max_weight * u64::MAX as f64).round() as u64)
            .collect()
    };
    (to_u64(local_weights), to_u64(remote_weights))
}
/// Read/write access to the integer split weight carried by a recipient type.
/// It is used by the `*_generic` helpers in this module.
pub trait HasSplit {
    /// Returns the current split weight.
    fn get_split(&self) -> u64;
    /// Replaces the split weight with `split`.
    fn set_split(&mut self, split: u64);
}
/// Applies `use_remote_splits` to values carrying their weight via [`HasSplit`].
///
/// Returns clones of the local values followed by clones of the remote values,
/// each with its split replaced by the merged weight. The inputs are not modified.
pub fn use_remote_splits_generic<T: HasSplit + Clone>(
    local_values: &[T],
    remote_values: &[T],
    remote_percentage: u64,
) -> Vec<T> {
    let local_weights: Vec<u64> = local_values.iter().map(T::get_split).collect();
    let remote_weights: Vec<u64> = remote_values.iter().map(T::get_split).collect();
    let (merged_local, merged_remote) =
        use_remote_splits(&local_weights, &remote_weights, remote_percentage);

    local_values
        .iter()
        .zip(merged_local)
        .chain(remote_values.iter().zip(merged_remote))
        .map(|(value, split)| {
            let mut updated = value.clone();
            updated.set_split(split);
            updated
        })
        .collect()
}