// chia_sdk_utils/coin_selection.rs

1use std::cmp::Reverse;
2
3use chia_protocol::Coin;
4use indexmap::IndexSet;
5use rand::{Rng, SeedableRng};
6use rand_chacha::ChaCha8Rng;
7use thiserror::Error;
8
9/// An error that occurs when selecting coins.
10#[derive(Debug, Clone, Copy, Error, PartialEq, Eq)]
11pub enum CoinSelectionError {
12    /// There were no spendable coins to select from.
13    #[error("no spendable coins")]
14    NoSpendableCoins,
15
16    /// There weren't enough coins to reach the amount.
17    #[error("insufficient balance {0}")]
18    InsufficientBalance(u64),
19
20    /// The selected coins exceeded the maximum.
21    #[error("exceeded max coins")]
22    ExceededMaxCoins,
23}
24
25/// Uses the knapsack algorithm to select coins.
26pub fn select_coins(
27    mut spendable_coins: Vec<Coin>,
28    amount: u64,
29) -> Result<Vec<Coin>, CoinSelectionError> {
30    let amount = u128::from(amount);
31    let max_coins = 500;
32
33    // You cannot spend no coins.
34    if spendable_coins.is_empty() {
35        return Err(CoinSelectionError::NoSpendableCoins);
36    }
37
38    // Checks to ensure the balance is sufficient before continuing.
39    let spendable_amount = spendable_coins
40        .iter()
41        .fold(0u128, |acc, coin| acc + u128::from(coin.amount));
42
43    if spendable_amount < amount {
44        return Err(CoinSelectionError::InsufficientBalance(
45            spendable_amount.try_into().expect("should fit"),
46        ));
47    }
48
49    // Sorts by amount, descending.
50    spendable_coins.sort_unstable_by_key(|coin| Reverse(coin.amount));
51
52    // Exact coin match.
53    for coin in &spendable_coins {
54        if u128::from(coin.amount) == amount {
55            return Ok(vec![*coin]);
56        }
57    }
58
59    let mut smaller_coins = IndexSet::new();
60    let mut smaller_sum = 0;
61
62    for coin in &spendable_coins {
63        let coin_amount = u128::from(coin.amount);
64
65        if coin_amount < amount {
66            smaller_coins.insert(*coin);
67            smaller_sum += coin_amount;
68        }
69    }
70
71    // Check for an exact match.
72    if smaller_sum == amount && smaller_coins.len() < max_coins && amount != 0 {
73        return Ok(smaller_coins.into_iter().collect());
74    }
75
76    // There must be a single coin larger than the amount.
77    if smaller_sum < amount {
78        return Ok(vec![smallest_coin_above(&spendable_coins, amount).unwrap()]);
79    }
80
81    // Apply the knapsack algorithm otherwise.
82    if smaller_sum > amount {
83        if let Some(result) = knapsack_coin_algorithm(
84            &mut ChaCha8Rng::seed_from_u64(0),
85            &spendable_coins,
86            amount,
87            u128::MAX,
88            max_coins,
89        ) {
90            return Ok(result.into_iter().collect());
91        }
92
93        // Knapsack failed to select coins, so try summing the largest coins.
94        let summed_coins = sum_largest_coins(&spendable_coins, amount);
95
96        if summed_coins.len() <= max_coins {
97            return Ok(summed_coins.into_iter().collect());
98        }
99
100        return Err(CoinSelectionError::ExceededMaxCoins);
101    }
102
103    // Try to find a large coin to select.
104    if let Some(coin) = smallest_coin_above(&spendable_coins, amount) {
105        return Ok(vec![coin]);
106    }
107
108    // It would require too many coins to match the amount.
109    Err(CoinSelectionError::ExceededMaxCoins)
110}
111
112fn sum_largest_coins(coins: &[Coin], amount: u128) -> IndexSet<Coin> {
113    let mut selected_coins = IndexSet::new();
114    let mut selected_sum = 0;
115    for coin in coins {
116        selected_sum += u128::from(coin.amount);
117        selected_coins.insert(*coin);
118
119        if selected_sum >= amount {
120            return selected_coins;
121        }
122    }
123    unreachable!()
124}
125
126fn smallest_coin_above(coins: &[Coin], amount: u128) -> Option<Coin> {
127    if u128::from(coins[0].amount) < amount {
128        return None;
129    }
130    for coin in coins.iter().rev() {
131        if u128::from(coin.amount) >= amount {
132            return Some(*coin);
133        }
134    }
135    unreachable!();
136}
137
138/// Runs the knapsack algorithm on a set of coins, attempting to find an optimal set.
139pub fn knapsack_coin_algorithm(
140    rng: &mut impl Rng,
141    spendable_coins: &[Coin],
142    amount: u128,
143    max_amount: u128,
144    max_coins: usize,
145) -> Option<IndexSet<Coin>> {
146    let mut best_sum = max_amount;
147    let mut best_coins = None;
148
149    for _ in 0..1000 {
150        let mut selected_coins = IndexSet::new();
151        let mut selected_sum = 0;
152        let mut target_reached = false;
153
154        for pass in 0..2 {
155            if target_reached {
156                break;
157            }
158
159            for coin in spendable_coins {
160                let filter_first = pass != 0 || !rng.random::<bool>();
161                let filter_second = pass != 1 || selected_coins.contains(coin);
162
163                if filter_first && filter_second {
164                    continue;
165                }
166
167                if selected_coins.len() > max_coins {
168                    break;
169                }
170
171                selected_sum += u128::from(coin.amount);
172                selected_coins.insert(*coin);
173
174                if selected_sum == amount {
175                    return Some(selected_coins);
176                }
177
178                if selected_sum > amount {
179                    target_reached = true;
180
181                    if selected_sum < best_sum {
182                        best_sum = selected_sum;
183                        best_coins = Some(selected_coins.clone());
184
185                        selected_sum -= u128::from(coin.amount);
186                        selected_coins.shift_remove(coin);
187                    }
188                }
189            }
190        }
191    }
192
193    best_coins
194}
195
#[cfg(test)]
mod tests {
    use chia_protocol::Bytes32;

    use super::*;

    macro_rules! coin_list {
        ( $( $coin:expr ),* $(,)? ) => {
            vec![$( coin($coin) ),*]
        };
    }

    /// Builds a coin that differs from others only by its amount.
    fn coin(amount: u64) -> Coin {
        Coin::new(Bytes32::from([0; 32]), Bytes32::from([0; 32]), amount)
    }

    #[test]
    fn test_select_coins() {
        let coins = coin_list![100, 200, 300, 400, 500];

        // Coins are sorted by amount, descending, before selection.
        let selected = select_coins(coins, 700).unwrap();
        let expected = coin_list![400, 300];
        assert_eq!(selected, expected);
    }

    #[test]
    fn test_exact_single_coin() {
        let coins = coin_list![100, 200, 300];

        // A coin matching the amount exactly is preferred over any combination.
        let selected = select_coins(coins, 200).unwrap();
        assert_eq!(selected, coin_list![200]);
    }

    #[test]
    fn test_all_smaller_coins_exact_sum() {
        let coins = coin_list![100, 200, 400];

        // Every coin is smaller than the amount and together they match it exactly.
        let selected = select_coins(coins, 700).unwrap();
        assert_eq!(selected, coin_list![400, 200, 100]);
    }

    #[test]
    fn test_single_larger_coin() {
        let coins = coin_list![50, 1000];

        // The small coin cannot reach the amount, so the larger coin is used alone.
        let selected = select_coins(coins, 500).unwrap();
        assert_eq!(selected, coin_list![1000]);
    }

    #[test]
    fn test_zero_amount() {
        let coins = coin_list![100, 200];

        // A zero amount still selects one coin: the smallest available.
        let selected = select_coins(coins, 0).unwrap();
        assert_eq!(selected, coin_list![100]);
    }

    #[test]
    fn test_insufficient_balance() {
        let coins = coin_list![50, 250, 100_000];

        // Select an amount that is too high.
        let selected = select_coins(coins, 9_999_999);
        assert_eq!(
            selected,
            Err(CoinSelectionError::InsufficientBalance(100_300))
        );
    }

    #[test]
    fn test_no_coins() {
        // There is no amount to select from.
        let selected = select_coins(Vec::new(), 100);
        assert_eq!(selected, Err(CoinSelectionError::NoSpendableCoins));

        // Even if the amount is zero, this should fail.
        let selected = select_coins(Vec::new(), 0);
        assert_eq!(selected, Err(CoinSelectionError::NoSpendableCoins));
    }
}