chia_sdk_utils/
coin_selection.rs1use std::cmp::Reverse;
2
3use chia_protocol::Coin;
4use indexmap::IndexSet;
5use rand::{Rng, SeedableRng};
6use rand_chacha::ChaCha8Rng;
7use thiserror::Error;
8
9#[derive(Debug, Clone, Copy, Error, PartialEq, Eq)]
11pub enum CoinSelectionError {
12 #[error("no spendable coins")]
14 NoSpendableCoins,
15
16 #[error("insufficient balance {0}")]
18 InsufficientBalance(u128),
19
20 #[error("exceeded max coins")]
22 ExceededMaxCoins,
23}
24
25pub fn select_coins(
27 mut spendable_coins: Vec<Coin>,
28 amount: u128,
29) -> Result<Vec<Coin>, CoinSelectionError> {
30 let max_coins = 500;
31
32 if spendable_coins.is_empty() {
34 return Err(CoinSelectionError::NoSpendableCoins);
35 }
36
37 let spendable_amount = spendable_coins
39 .iter()
40 .fold(0u128, |acc, coin| acc + u128::from(coin.amount));
41
42 if spendable_amount < amount {
43 return Err(CoinSelectionError::InsufficientBalance(spendable_amount));
44 }
45
46 spendable_coins.sort_unstable_by_key(|coin| Reverse(coin.amount));
48
49 for coin in &spendable_coins {
51 if u128::from(coin.amount) == amount {
52 return Ok(vec![*coin]);
53 }
54 }
55
56 let mut smaller_coins = IndexSet::new();
57 let mut smaller_sum = 0;
58
59 for coin in &spendable_coins {
60 let coin_amount = u128::from(coin.amount);
61
62 if coin_amount < amount {
63 smaller_coins.insert(*coin);
64 smaller_sum += coin_amount;
65 }
66 }
67
68 if smaller_sum == amount && smaller_coins.len() < max_coins && amount != 0 {
70 return Ok(smaller_coins.into_iter().collect());
71 }
72
73 if smaller_sum < amount {
75 return Ok(vec![smallest_coin_above(&spendable_coins, amount).unwrap()]);
76 }
77
78 if smaller_sum > amount {
80 if let Some(result) = knapsack_coin_algorithm(
81 &mut ChaCha8Rng::seed_from_u64(0),
82 &spendable_coins,
83 amount,
84 u128::MAX,
85 max_coins,
86 ) {
87 return Ok(result.into_iter().collect());
88 }
89
90 let summed_coins = sum_largest_coins(&spendable_coins, amount);
92
93 if summed_coins.len() <= max_coins {
94 return Ok(summed_coins.into_iter().collect());
95 }
96
97 return Err(CoinSelectionError::ExceededMaxCoins);
98 }
99
100 if let Some(coin) = smallest_coin_above(&spendable_coins, amount) {
102 return Ok(vec![coin]);
103 }
104
105 Err(CoinSelectionError::ExceededMaxCoins)
107}
108
109fn sum_largest_coins(coins: &[Coin], amount: u128) -> IndexSet<Coin> {
110 let mut selected_coins = IndexSet::new();
111 let mut selected_sum = 0;
112 for coin in coins {
113 selected_sum += u128::from(coin.amount);
114 selected_coins.insert(*coin);
115
116 if selected_sum >= amount {
117 return selected_coins;
118 }
119 }
120 unreachable!()
121}
122
123fn smallest_coin_above(coins: &[Coin], amount: u128) -> Option<Coin> {
124 if u128::from(coins[0].amount) < amount {
125 return None;
126 }
127 for coin in coins.iter().rev() {
128 if u128::from(coin.amount) >= amount {
129 return Some(*coin);
130 }
131 }
132 unreachable!();
133}
134
135pub fn knapsack_coin_algorithm(
137 rng: &mut impl Rng,
138 spendable_coins: &[Coin],
139 amount: u128,
140 max_amount: u128,
141 max_coins: usize,
142) -> Option<IndexSet<Coin>> {
143 let mut best_sum = max_amount;
144 let mut best_coins = None;
145
146 for _ in 0..1000 {
147 let mut selected_coins = IndexSet::new();
148 let mut selected_sum = 0;
149 let mut target_reached = false;
150
151 for pass in 0..2 {
152 if target_reached {
153 break;
154 }
155
156 for coin in spendable_coins {
157 let filter_first = pass != 0 || !rng.gen::<bool>();
158 let filter_second = pass != 1 || selected_coins.contains(coin);
159
160 if filter_first && filter_second {
161 continue;
162 }
163
164 if selected_coins.len() > max_coins {
165 break;
166 }
167
168 selected_sum += u128::from(coin.amount);
169 selected_coins.insert(*coin);
170
171 if selected_sum == amount {
172 return Some(selected_coins);
173 }
174
175 if selected_sum > amount {
176 target_reached = true;
177
178 if selected_sum < best_sum {
179 best_sum = selected_sum;
180 best_coins = Some(selected_coins.clone());
181
182 selected_sum -= u128::from(coin.amount);
183 selected_coins.shift_remove(coin);
184 }
185 }
186 }
187 }
188 }
189
190 best_coins
191}
192
#[cfg(test)]
mod tests {
    use chia_protocol::Bytes32;

    use super::*;

    /// Builds a coin with the given amount; parent and puzzle hash are zeroed.
    fn coin(amount: u64) -> Coin {
        Coin::new(Bytes32::from([0; 32]), Bytes32::from([0; 32]), amount)
    }

    /// Builds one coin per amount, in the order given.
    fn coin_list(amounts: &[u64]) -> Vec<Coin> {
        amounts.iter().map(|&amount| coin(amount)).collect()
    }

    #[test]
    fn test_select_coins() {
        let spendable = coin_list(&[100, 200, 300, 400, 500]);

        assert_eq!(select_coins(spendable, 700), Ok(coin_list(&[400, 300])));
    }

    #[test]
    fn test_insufficient_balance() {
        let spendable = coin_list(&[50, 250, 100_000]);

        // The error reports the total that was available.
        assert_eq!(
            select_coins(spendable, 9_999_999),
            Err(CoinSelectionError::InsufficientBalance(100_300))
        );
    }

    #[test]
    fn test_no_coins() {
        // An empty coin list is rejected even when nothing is requested.
        for amount in [100, 0] {
            assert_eq!(
                select_coins(Vec::new(), amount),
                Err(CoinSelectionError::NoSpendableCoins)
            );
        }
    }
}