chia_sdk_utils/
coin_selection.rs1use std::cmp::Reverse;
2
3use chia_protocol::Coin;
4use indexmap::IndexSet;
5use rand::{Rng, SeedableRng};
6use rand_chacha::ChaCha8Rng;
7use thiserror::Error;
8
9#[derive(Debug, Clone, Copy, Error, PartialEq, Eq)]
11pub enum CoinSelectionError {
12 #[error("no spendable coins")]
14 NoSpendableCoins,
15
16 #[error("insufficient balance {0}")]
18 InsufficientBalance(u64),
19
20 #[error("exceeded max coins")]
22 ExceededMaxCoins,
23}
24
25pub fn select_coins(
27 mut spendable_coins: Vec<Coin>,
28 amount: u64,
29) -> Result<Vec<Coin>, CoinSelectionError> {
30 let amount = u128::from(amount);
31 let max_coins = 500;
32
33 if spendable_coins.is_empty() {
35 return Err(CoinSelectionError::NoSpendableCoins);
36 }
37
38 let spendable_amount = spendable_coins
40 .iter()
41 .fold(0u128, |acc, coin| acc + u128::from(coin.amount));
42
43 if spendable_amount < amount {
44 return Err(CoinSelectionError::InsufficientBalance(
45 spendable_amount.try_into().expect("should fit"),
46 ));
47 }
48
49 spendable_coins.sort_unstable_by_key(|coin| Reverse(coin.amount));
51
52 for coin in &spendable_coins {
54 if u128::from(coin.amount) == amount {
55 return Ok(vec![*coin]);
56 }
57 }
58
59 let mut smaller_coins = IndexSet::new();
60 let mut smaller_sum = 0;
61
62 for coin in &spendable_coins {
63 let coin_amount = u128::from(coin.amount);
64
65 if coin_amount < amount {
66 smaller_coins.insert(*coin);
67 smaller_sum += coin_amount;
68 }
69 }
70
71 if smaller_sum == amount && smaller_coins.len() < max_coins && amount != 0 {
73 return Ok(smaller_coins.into_iter().collect());
74 }
75
76 if smaller_sum < amount {
78 return Ok(vec![smallest_coin_above(&spendable_coins, amount).unwrap()]);
79 }
80
81 if smaller_sum > amount {
83 if let Some(result) = knapsack_coin_algorithm(
84 &mut ChaCha8Rng::seed_from_u64(0),
85 &spendable_coins,
86 amount,
87 u128::MAX,
88 max_coins,
89 ) {
90 return Ok(result.into_iter().collect());
91 }
92
93 let summed_coins = sum_largest_coins(&spendable_coins, amount);
95
96 if summed_coins.len() <= max_coins {
97 return Ok(summed_coins.into_iter().collect());
98 }
99
100 return Err(CoinSelectionError::ExceededMaxCoins);
101 }
102
103 if let Some(coin) = smallest_coin_above(&spendable_coins, amount) {
105 return Ok(vec![coin]);
106 }
107
108 Err(CoinSelectionError::ExceededMaxCoins)
110}
111
112fn sum_largest_coins(coins: &[Coin], amount: u128) -> IndexSet<Coin> {
113 let mut selected_coins = IndexSet::new();
114 let mut selected_sum = 0;
115 for coin in coins {
116 selected_sum += u128::from(coin.amount);
117 selected_coins.insert(*coin);
118
119 if selected_sum >= amount {
120 return selected_coins;
121 }
122 }
123 unreachable!()
124}
125
126fn smallest_coin_above(coins: &[Coin], amount: u128) -> Option<Coin> {
127 if u128::from(coins[0].amount) < amount {
128 return None;
129 }
130 for coin in coins.iter().rev() {
131 if u128::from(coin.amount) >= amount {
132 return Some(*coin);
133 }
134 }
135 unreachable!();
136}
137
138pub fn knapsack_coin_algorithm(
140 rng: &mut impl Rng,
141 spendable_coins: &[Coin],
142 amount: u128,
143 max_amount: u128,
144 max_coins: usize,
145) -> Option<IndexSet<Coin>> {
146 let mut best_sum = max_amount;
147 let mut best_coins = None;
148
149 for _ in 0..1000 {
150 let mut selected_coins = IndexSet::new();
151 let mut selected_sum = 0;
152 let mut target_reached = false;
153
154 for pass in 0..2 {
155 if target_reached {
156 break;
157 }
158
159 for coin in spendable_coins {
160 let filter_first = pass != 0 || !rng.random::<bool>();
161 let filter_second = pass != 1 || selected_coins.contains(coin);
162
163 if filter_first && filter_second {
164 continue;
165 }
166
167 if selected_coins.len() > max_coins {
168 break;
169 }
170
171 selected_sum += u128::from(coin.amount);
172 selected_coins.insert(*coin);
173
174 if selected_sum == amount {
175 return Some(selected_coins);
176 }
177
178 if selected_sum > amount {
179 target_reached = true;
180
181 if selected_sum < best_sum {
182 best_sum = selected_sum;
183 best_coins = Some(selected_coins.clone());
184
185 selected_sum -= u128::from(coin.amount);
186 selected_coins.shift_remove(coin);
187 }
188 }
189 }
190 }
191 }
192
193 best_coins
194}
195
#[cfg(test)]
mod tests {
    use chia_protocol::Bytes32;

    use super::*;

    /// Builds a `Vec<Coin>` from a list of amounts.
    macro_rules! coin_list {
        ( $( $coin:expr ),* $(,)? ) => {
            vec![$( coin($coin) ),*]
        };
    }

    /// Creates a coin with the given amount and an all-zero parent coin info and puzzle hash.
    fn coin(amount: u64) -> Coin {
        Coin::new(Bytes32::from([0; 32]), Bytes32::from([0; 32]), amount)
    }

    #[test]
    fn test_select_coins() {
        let coins = coin_list![100, 200, 300, 400, 500];

        let selected = select_coins(coins, 700).unwrap();
        let expected = coin_list![400, 300];
        assert_eq!(selected, expected);
    }

    #[test]
    fn test_exact_single_coin_match() {
        let selected = select_coins(coin_list![100, 200, 300], 200).unwrap();
        assert_eq!(selected, coin_list![200]);
    }

    #[test]
    fn test_smaller_coins_sum_exactly() {
        let selected = select_coins(coin_list![100, 200, 1000], 300).unwrap();
        assert_eq!(selected, coin_list![200, 100]);
    }

    #[test]
    fn test_smallest_coin_above_amount() {
        let selected = select_coins(coin_list![100, 500, 1000], 400).unwrap();
        assert_eq!(selected, coin_list![500]);
    }

    #[test]
    fn test_zero_amount_selects_smallest_coin() {
        let selected = select_coins(coin_list![100, 200], 0).unwrap();
        assert_eq!(selected, coin_list![100]);
    }

    #[test]
    fn test_insufficient_balance() {
        let coins = coin_list![50, 250, 100_000];

        let selected = select_coins(coins, 9_999_999);
        assert_eq!(
            selected,
            Err(CoinSelectionError::InsufficientBalance(100_300))
        );
    }

    #[test]
    fn test_no_coins() {
        let selected = select_coins(Vec::new(), 100);
        assert_eq!(selected, Err(CoinSelectionError::NoSpendableCoins));

        let selected = select_coins(Vec::new(), 0);
        assert_eq!(selected, Err(CoinSelectionError::NoSpendableCoins));
    }
}