1use crate::error::{CoinJoinError, Result};
4use crate::types::{FeeStrategy, InputRef, OutputDef, Participant};
5use sha2::{Digest, Sha256};
6
7pub struct CoinJoinBuilder {
12 participants: Vec<Participant>,
14 output_amount: Option<u64>,
16 fee_rate: f64,
18 fee_strategy: FeeStrategy,
20 min_participants: usize,
22}
23
24impl CoinJoinBuilder {
25 pub fn new() -> Self {
27 Self {
28 participants: Vec::new(),
29 output_amount: None,
30 fee_rate: 1.0,
31 fee_strategy: FeeStrategy::Equal,
32 min_participants: 2,
33 }
34 }
35
36 pub fn add_participant(&mut self, participant: Participant) -> &mut Self {
38 self.participants.push(participant);
39 self
40 }
41
42 pub fn add_participant_simple(
44 &mut self,
45 id: impl Into<String>,
46 inputs: Vec<InputRef>,
47 output_script: Vec<u8>,
48 ) -> &mut Self {
49 self.participants
50 .push(Participant::new(id, inputs, output_script));
51 self
52 }
53
54 pub fn set_output_amount(&mut self, amount: u64) -> &mut Self {
56 self.output_amount = Some(amount);
57 self
58 }
59
60 pub fn set_fee_rate(&mut self, rate: f64) -> &mut Self {
62 self.fee_rate = rate;
63 self
64 }
65
66 pub fn set_fee_strategy(&mut self, strategy: FeeStrategy) -> &mut Self {
68 self.fee_strategy = strategy;
69 self
70 }
71
72 pub fn set_min_participants(&mut self, min: usize) -> &mut Self {
74 self.min_participants = min;
75 self
76 }
77
78 pub fn build(&self) -> Result<CoinJoinTransaction> {
80 if self.participants.len() < self.min_participants {
82 return Err(CoinJoinError::NoParticipants);
83 }
84
85 let output_amount = self
86 .output_amount
87 .ok_or_else(|| CoinJoinError::InvalidAmount("Output amount not set".into()))?;
88
89 let _total_inputs: u64 = self.participants.iter().map(|p| p.total_input()).sum();
91 let estimated_size = self.estimate_tx_size();
92 let total_fee = (estimated_size as f64 * self.fee_rate) as u64;
93
94 let fees = self.calculate_fees(total_fee)?;
96
97 for (i, participant) in self.participants.iter().enumerate() {
99 let needed = output_amount + fees[i];
100 let available = participant.total_input();
101 if available < needed {
102 return Err(CoinJoinError::InsufficientFunds { needed, available });
103 }
104 }
105
106 let mut inputs: Vec<(InputRef, usize)> = Vec::new();
108 for (idx, participant) in self.participants.iter().enumerate() {
109 for input in &participant.inputs {
110 inputs.push((input.clone(), idx));
111 }
112 }
113
114 let mut outputs: Vec<(OutputDef, usize)> = Vec::new();
116 for (idx, participant) in self.participants.iter().enumerate() {
117 outputs.push((
118 OutputDef::new(output_amount, participant.output_script.clone()),
119 idx,
120 ));
121 }
122
123 let mut change_outputs: Vec<(OutputDef, usize)> = Vec::new();
125 for (idx, participant) in self.participants.iter().enumerate() {
126 let change = participant.total_input() - output_amount - fees[idx];
127 if change > 546 {
128 if let Some(change_script) = &participant.change_script {
130 change_outputs.push((OutputDef::new(change, change_script.clone()), idx));
131 }
132 }
133 }
134
135 let shuffled_inputs = shuffle_with_seed(&inputs, &self.generate_shuffle_seed());
137 let shuffled_outputs = shuffle_with_seed(&outputs, &self.generate_shuffle_seed());
138
139 Ok(CoinJoinTransaction {
140 inputs: shuffled_inputs.into_iter().map(|(i, _)| i).collect(),
141 outputs: shuffled_outputs.into_iter().map(|(o, _)| o).collect(),
142 change_outputs: change_outputs.into_iter().map(|(o, _)| o).collect(),
143 participant_count: self.participants.len(),
144 output_amount,
145 total_fee,
146 })
147 }
148
149 fn calculate_fees(&self, total_fee: u64) -> Result<Vec<u64>> {
151 let n = self.participants.len();
152 if n == 0 {
153 return Err(CoinJoinError::NoParticipants);
154 }
155
156 match self.fee_strategy {
157 FeeStrategy::Equal => {
158 let per_participant = total_fee / n as u64;
159 let remainder = total_fee % n as u64;
160 let mut fees: Vec<u64> = vec![per_participant; n];
161 fees[0] += remainder;
163 Ok(fees)
164 }
165 FeeStrategy::Proportional => {
166 let total_input: u64 = self.participants.iter().map(|p| p.total_input()).sum();
167 if total_input == 0 {
168 return Err(CoinJoinError::FeeError("No inputs".into()));
169 }
170 let fees: Vec<u64> = self
171 .participants
172 .iter()
173 .map(|p| (p.total_input() as f64 / total_input as f64 * total_fee as f64) as u64)
174 .collect();
175 Ok(fees)
176 }
177 FeeStrategy::SinglePayer(idx) => {
178 if idx >= n {
179 return Err(CoinJoinError::FeeError("Invalid payer index".into()));
180 }
181 let mut fees = vec![0u64; n];
182 fees[idx] = total_fee;
183 Ok(fees)
184 }
185 }
186 }
187
188 fn estimate_tx_size(&self) -> usize {
190 let input_count: usize = self.participants.iter().map(|p| p.inputs.len()).sum();
191 let output_count = self.participants.len() * 2; 10 + 68 * input_count + 34 * output_count
195 }
196
197 fn generate_shuffle_seed(&self) -> [u8; 32] {
199 let mut hasher = Sha256::new();
200 for participant in &self.participants {
201 hasher.update(participant.id.as_bytes());
202 for input in &participant.inputs {
203 hasher.update(input.txid);
204 hasher.update(input.vout.to_le_bytes());
205 }
206 }
207 let result = hasher.finalize();
208 let mut seed = [0u8; 32];
209 seed.copy_from_slice(&result);
210 seed
211 }
212}
213
214impl Default for CoinJoinBuilder {
215 fn default() -> Self {
216 Self::new()
217 }
218}
219
220#[derive(Debug, Clone)]
222pub struct CoinJoinTransaction {
223 pub inputs: Vec<InputRef>,
225 pub outputs: Vec<OutputDef>,
227 pub change_outputs: Vec<OutputDef>,
229 pub participant_count: usize,
231 pub output_amount: u64,
233 pub total_fee: u64,
235}
236
237impl CoinJoinTransaction {
238 pub fn total_input(&self) -> u64 {
240 self.inputs.iter().map(|i| i.amount).sum()
241 }
242
243 pub fn total_output(&self) -> u64 {
245 self.outputs.iter().map(|o| o.amount).sum()
246 }
247
248 pub fn total_change(&self) -> u64 {
250 self.change_outputs.iter().map(|o| o.amount).sum()
251 }
252
253 pub fn verify_equal_outputs(&self) -> bool {
255 self.outputs.iter().all(|o| o.amount == self.output_amount)
256 }
257
258 pub fn all_outputs(&self) -> Vec<&OutputDef> {
260 self.outputs
261 .iter()
262 .chain(self.change_outputs.iter())
263 .collect()
264 }
265}
266
267fn shuffle_with_seed<T: Clone>(items: &[T], seed: &[u8; 32]) -> Vec<T> {
269 if items.is_empty() {
270 return Vec::new();
271 }
272
273 let mut result: Vec<(T, u64)> = items
274 .iter()
275 .enumerate()
276 .map(|(i, item)| {
277 let mut hasher = Sha256::new();
278 hasher.update(seed);
279 hasher.update(i.to_le_bytes());
280 let hash = hasher.finalize();
281 let sort_key = u64::from_le_bytes(hash[0..8].try_into().unwrap());
282 (item.clone(), sort_key)
283 })
284 .collect();
285
286 result.sort_by_key(|(_, key)| *key);
287 result.into_iter().map(|(item, _)| item).collect()
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder pre-loaded with "alice" (txid of 0x01 bytes) and "bob" (txid of
    /// 0x02 bytes). Each contributes a single input at vout 0.
    fn alice_and_bob(
        alice_sats: u64,
        alice_script: Vec<u8>,
        bob_sats: u64,
        bob_script: Vec<u8>,
    ) -> CoinJoinBuilder {
        let mut builder = CoinJoinBuilder::new();
        builder
            .add_participant_simple(
                "alice",
                vec![InputRef::from_outpoint([1u8; 32], 0, alice_sats)],
                alice_script,
            )
            .add_participant_simple(
                "bob",
                vec![InputRef::from_outpoint([2u8; 32], 0, bob_sats)],
                bob_script,
            );
        builder
    }

    #[test]
    fn test_builder_creation() {
        let fresh = CoinJoinBuilder::new();
        assert!(fresh.participants.is_empty());
        assert_eq!(fresh.min_participants, 2);
    }

    #[test]
    fn test_add_participant() {
        let mut builder = CoinJoinBuilder::new();
        let input = InputRef::from_outpoint([1u8; 32], 0, 100_000);
        builder.add_participant_simple("alice", vec![input], vec![0x00, 0x14]);
        assert_eq!(builder.participants.len(), 1);
    }

    #[test]
    fn test_build_coinjoin() {
        let mut builder = alice_and_bob(
            100_000,
            vec![0x00, 0x14, 0x01],
            100_000,
            vec![0x00, 0x14, 0x02],
        );
        builder.set_output_amount(50_000).set_fee_rate(1.0);

        let tx = builder
            .build()
            .expect("two well-funded participants should build");

        assert_eq!(tx.participant_count, 2);
        assert_eq!(tx.inputs.len(), 2);
        assert_eq!(tx.outputs.len(), 2);
        assert!(tx.verify_equal_outputs());
    }

    #[test]
    fn test_insufficient_funds() {
        let mut builder = alice_and_bob(10_000, vec![0x00, 0x14], 10_000, vec![0x00, 0x14]);
        builder.set_output_amount(50_000);

        assert!(matches!(
            builder.build(),
            Err(CoinJoinError::InsufficientFunds { .. })
        ));
    }

    #[test]
    fn test_fee_strategies() {
        let mut builder = alice_and_bob(100_000, vec![0x00, 0x14], 200_000, vec![0x00, 0x14]);

        let equal = builder.calculate_fees(1000).unwrap();
        assert_eq!(equal[0], 500);
        assert_eq!(equal[1], 500);

        builder.set_fee_strategy(FeeStrategy::Proportional);
        let proportional = builder.calculate_fees(1000).unwrap();
        assert!(proportional[1] > proportional[0]);

        builder.set_fee_strategy(FeeStrategy::SinglePayer(1));
        let single = builder.calculate_fees(1000).unwrap();
        assert_eq!(single[0], 0);
        assert_eq!(single[1], 1000);
    }

    #[test]
    fn test_shuffle_deterministic() {
        let items = vec![1, 2, 3, 4, 5];
        let seed = [0u8; 32];

        let first = shuffle_with_seed(&items, &seed);
        let second = shuffle_with_seed(&items, &seed);
        assert_eq!(first, second);
    }
}