chia_sdk_utils/
coin_selection.rs

1use std::cmp::Reverse;
2
3use chia_protocol::Coin;
4use indexmap::IndexSet;
5use rand::{Rng, SeedableRng};
6use rand_chacha::ChaCha8Rng;
7use thiserror::Error;
8
9/// An error that occurs when selecting coins.
10#[derive(Debug, Clone, Copy, Error, PartialEq, Eq)]
11pub enum CoinSelectionError {
12    /// There were no spendable coins to select from.
13    #[error("no spendable coins")]
14    NoSpendableCoins,
15
16    /// There weren't enough coins to reach the amount.
17    #[error("insufficient balance {0}")]
18    InsufficientBalance(u128),
19
20    /// The selected coins exceeded the maximum.
21    #[error("exceeded max coins")]
22    ExceededMaxCoins,
23}
24
25/// Uses the knapsack algorithm to select coins.
26pub fn select_coins(
27    mut spendable_coins: Vec<Coin>,
28    amount: u128,
29) -> Result<Vec<Coin>, CoinSelectionError> {
30    let max_coins = 500;
31
32    // You cannot spend no coins.
33    if spendable_coins.is_empty() {
34        return Err(CoinSelectionError::NoSpendableCoins);
35    }
36
37    // Checks to ensure the balance is sufficient before continuing.
38    let spendable_amount = spendable_coins
39        .iter()
40        .fold(0u128, |acc, coin| acc + u128::from(coin.amount));
41
42    if spendable_amount < amount {
43        return Err(CoinSelectionError::InsufficientBalance(spendable_amount));
44    }
45
46    // Sorts by amount, descending.
47    spendable_coins.sort_unstable_by_key(|coin| Reverse(coin.amount));
48
49    // Exact coin match.
50    for coin in &spendable_coins {
51        if u128::from(coin.amount) == amount {
52            return Ok(vec![*coin]);
53        }
54    }
55
56    let mut smaller_coins = IndexSet::new();
57    let mut smaller_sum = 0;
58
59    for coin in &spendable_coins {
60        let coin_amount = u128::from(coin.amount);
61
62        if coin_amount < amount {
63            smaller_coins.insert(*coin);
64            smaller_sum += coin_amount;
65        }
66    }
67
68    // Check for an exact match.
69    if smaller_sum == amount && smaller_coins.len() < max_coins && amount != 0 {
70        return Ok(smaller_coins.into_iter().collect());
71    }
72
73    // There must be a single coin larger than the amount.
74    if smaller_sum < amount {
75        return Ok(vec![smallest_coin_above(&spendable_coins, amount).unwrap()]);
76    }
77
78    // Apply the knapsack algorithm otherwise.
79    if smaller_sum > amount {
80        if let Some(result) = knapsack_coin_algorithm(
81            &mut ChaCha8Rng::seed_from_u64(0),
82            &spendable_coins,
83            amount,
84            u128::MAX,
85            max_coins,
86        ) {
87            return Ok(result.into_iter().collect());
88        }
89
90        // Knapsack failed to select coins, so try summing the largest coins.
91        let summed_coins = sum_largest_coins(&spendable_coins, amount);
92
93        if summed_coins.len() <= max_coins {
94            return Ok(summed_coins.into_iter().collect());
95        }
96
97        return Err(CoinSelectionError::ExceededMaxCoins);
98    }
99
100    // Try to find a large coin to select.
101    if let Some(coin) = smallest_coin_above(&spendable_coins, amount) {
102        return Ok(vec![coin]);
103    }
104
105    // It would require too many coins to match the amount.
106    Err(CoinSelectionError::ExceededMaxCoins)
107}
108
109fn sum_largest_coins(coins: &[Coin], amount: u128) -> IndexSet<Coin> {
110    let mut selected_coins = IndexSet::new();
111    let mut selected_sum = 0;
112    for coin in coins {
113        selected_sum += u128::from(coin.amount);
114        selected_coins.insert(*coin);
115
116        if selected_sum >= amount {
117            return selected_coins;
118        }
119    }
120    unreachable!()
121}
122
123fn smallest_coin_above(coins: &[Coin], amount: u128) -> Option<Coin> {
124    if u128::from(coins[0].amount) < amount {
125        return None;
126    }
127    for coin in coins.iter().rev() {
128        if u128::from(coin.amount) >= amount {
129            return Some(*coin);
130        }
131    }
132    unreachable!();
133}
134
135/// Runs the knapsack algorithm on a set of coins, attempting to find an optimal set.
136pub fn knapsack_coin_algorithm(
137    rng: &mut impl Rng,
138    spendable_coins: &[Coin],
139    amount: u128,
140    max_amount: u128,
141    max_coins: usize,
142) -> Option<IndexSet<Coin>> {
143    let mut best_sum = max_amount;
144    let mut best_coins = None;
145
146    for _ in 0..1000 {
147        let mut selected_coins = IndexSet::new();
148        let mut selected_sum = 0;
149        let mut target_reached = false;
150
151        for pass in 0..2 {
152            if target_reached {
153                break;
154            }
155
156            for coin in spendable_coins {
157                let filter_first = pass != 0 || !rng.gen::<bool>();
158                let filter_second = pass != 1 || selected_coins.contains(coin);
159
160                if filter_first && filter_second {
161                    continue;
162                }
163
164                if selected_coins.len() > max_coins {
165                    break;
166                }
167
168                selected_sum += u128::from(coin.amount);
169                selected_coins.insert(*coin);
170
171                if selected_sum == amount {
172                    return Some(selected_coins);
173                }
174
175                if selected_sum > amount {
176                    target_reached = true;
177
178                    if selected_sum < best_sum {
179                        best_sum = selected_sum;
180                        best_coins = Some(selected_coins.clone());
181
182                        selected_sum -= u128::from(coin.amount);
183                        selected_coins.shift_remove(coin);
184                    }
185                }
186            }
187        }
188    }
189
190    best_coins
191}
192
#[cfg(test)]
mod tests {
    use chia_protocol::Bytes32;

    use super::*;

    macro_rules! coin_list {
        ( $( $coin:expr ),* $(,)? ) => {
            vec![$( coin($coin) ),*]
        };
    }

    /// Builds a test coin with the given amount.
    ///
    /// Every test coin uses the same parent coin id and puzzle hash, so coins built with
    /// equal amounts are identical.
    fn coin(amount: u64) -> Coin {
        Coin::new(Bytes32::from([0; 32]), Bytes32::from([0; 32]), amount)
    }

    #[test]
    fn test_select_coins() {
        let coins = coin_list![100, 200, 300, 400, 500];

        // No coin equals 700, and the coins below 700 sum to 1500. So the seeded knapsack
        // search chooses the set, and it deterministically lands on 400 + 300.
        let selected = select_coins(coins, 700).unwrap();
        let expected = coin_list![400, 300];
        assert_eq!(selected, expected);
    }

    #[test]
    fn test_exact_coin_match() {
        // A single coin with exactly the requested amount is preferred.
        let selected = select_coins(coin_list![100, 200, 300], 200).unwrap();
        assert_eq!(selected, coin_list![200]);
    }

    #[test]
    fn test_smaller_coins_exact_sum() {
        // The coins below the target add up to exactly the target, largest first.
        let selected = select_coins(coin_list![100, 200, 1000], 300).unwrap();
        assert_eq!(selected, coin_list![200, 100]);
    }

    #[test]
    fn test_single_larger_coin() {
        // The smaller coins only reach 150, so the smallest larger coin is chosen.
        let selected = select_coins(coin_list![50, 100, 1000], 500).unwrap();
        assert_eq!(selected, coin_list![1000]);
    }

    #[test]
    fn test_zero_amount_selects_smallest_coin() {
        // Even a zero target spends one coin: the smallest one.
        let selected = select_coins(coin_list![10, 20], 0).unwrap();
        assert_eq!(selected, coin_list![10]);
    }

    #[test]
    fn test_insufficient_balance() {
        let coins = coin_list![50, 250, 100_000];

        // Select an amount that is too high.
        let selected = select_coins(coins, 9_999_999);
        assert_eq!(
            selected,
            Err(CoinSelectionError::InsufficientBalance(100_300))
        );
    }

    #[test]
    fn test_no_coins() {
        // There is no amount to select from.
        let selected = select_coins(Vec::new(), 100);
        assert_eq!(selected, Err(CoinSelectionError::NoSpendableCoins));

        // Even if the amount is zero, this should fail.
        let selected = select_coins(Vec::new(), 0);
        assert_eq!(selected, Err(CoinSelectionError::NoSpendableCoins));
    }
}