// rustywallet_coinjoin/builder.rs

//! CoinJoin transaction builder.
2
3use crate::error::{CoinJoinError, Result};
4use crate::types::{FeeStrategy, InputRef, OutputDef, Participant};
5use sha2::{Digest, Sha256};
6
7/// CoinJoin transaction builder.
8///
9/// Builds CoinJoin transactions with equal output amounts
10/// and shuffled inputs/outputs for privacy.
11pub struct CoinJoinBuilder {
12    /// Participants
13    participants: Vec<Participant>,
14    /// Output amount (equal for all)
15    output_amount: Option<u64>,
16    /// Fee rate (sat/vB)
17    fee_rate: f64,
18    /// Fee distribution strategy
19    fee_strategy: FeeStrategy,
20    /// Minimum number of participants
21    min_participants: usize,
22}
23
24impl CoinJoinBuilder {
25    /// Create a new CoinJoin builder.
26    pub fn new() -> Self {
27        Self {
28            participants: Vec::new(),
29            output_amount: None,
30            fee_rate: 1.0,
31            fee_strategy: FeeStrategy::Equal,
32            min_participants: 2,
33        }
34    }
35
36    /// Add a participant.
37    pub fn add_participant(&mut self, participant: Participant) -> &mut Self {
38        self.participants.push(participant);
39        self
40    }
41
42    /// Add participant with inputs and output.
43    pub fn add_participant_simple(
44        &mut self,
45        id: impl Into<String>,
46        inputs: Vec<InputRef>,
47        output_script: Vec<u8>,
48    ) -> &mut Self {
49        self.participants
50            .push(Participant::new(id, inputs, output_script));
51        self
52    }
53
54    /// Set the equal output amount.
55    pub fn set_output_amount(&mut self, amount: u64) -> &mut Self {
56        self.output_amount = Some(amount);
57        self
58    }
59
60    /// Set fee rate in sat/vB.
61    pub fn set_fee_rate(&mut self, rate: f64) -> &mut Self {
62        self.fee_rate = rate;
63        self
64    }
65
66    /// Set fee distribution strategy.
67    pub fn set_fee_strategy(&mut self, strategy: FeeStrategy) -> &mut Self {
68        self.fee_strategy = strategy;
69        self
70    }
71
72    /// Set minimum number of participants.
73    pub fn set_min_participants(&mut self, min: usize) -> &mut Self {
74        self.min_participants = min;
75        self
76    }
77
78    /// Build the CoinJoin transaction.
79    pub fn build(&self) -> Result<CoinJoinTransaction> {
80        // Validate
81        if self.participants.len() < self.min_participants {
82            return Err(CoinJoinError::NoParticipants);
83        }
84
85        let output_amount = self
86            .output_amount
87            .ok_or_else(|| CoinJoinError::InvalidAmount("Output amount not set".into()))?;
88
89        // Calculate total inputs and fees
90        let _total_inputs: u64 = self.participants.iter().map(|p| p.total_input()).sum();
91        let estimated_size = self.estimate_tx_size();
92        let total_fee = (estimated_size as f64 * self.fee_rate) as u64;
93
94        // Calculate per-participant fee
95        let fees = self.calculate_fees(total_fee)?;
96
97        // Verify each participant has enough funds
98        for (i, participant) in self.participants.iter().enumerate() {
99            let needed = output_amount + fees[i];
100            let available = participant.total_input();
101            if available < needed {
102                return Err(CoinJoinError::InsufficientFunds { needed, available });
103            }
104        }
105
106        // Collect all inputs
107        let mut inputs: Vec<(InputRef, usize)> = Vec::new();
108        for (idx, participant) in self.participants.iter().enumerate() {
109            for input in &participant.inputs {
110                inputs.push((input.clone(), idx));
111            }
112        }
113
114        // Create equal outputs
115        let mut outputs: Vec<(OutputDef, usize)> = Vec::new();
116        for (idx, participant) in self.participants.iter().enumerate() {
117            outputs.push((
118                OutputDef::new(output_amount, participant.output_script.clone()),
119                idx,
120            ));
121        }
122
123        // Create change outputs
124        let mut change_outputs: Vec<(OutputDef, usize)> = Vec::new();
125        for (idx, participant) in self.participants.iter().enumerate() {
126            let change = participant.total_input() - output_amount - fees[idx];
127            if change > 546 {
128                // Dust threshold
129                if let Some(change_script) = &participant.change_script {
130                    change_outputs.push((OutputDef::new(change, change_script.clone()), idx));
131                }
132            }
133        }
134
135        // Shuffle inputs and outputs
136        let shuffled_inputs = shuffle_with_seed(&inputs, &self.generate_shuffle_seed());
137        let shuffled_outputs = shuffle_with_seed(&outputs, &self.generate_shuffle_seed());
138
139        Ok(CoinJoinTransaction {
140            inputs: shuffled_inputs.into_iter().map(|(i, _)| i).collect(),
141            outputs: shuffled_outputs.into_iter().map(|(o, _)| o).collect(),
142            change_outputs: change_outputs.into_iter().map(|(o, _)| o).collect(),
143            participant_count: self.participants.len(),
144            output_amount,
145            total_fee,
146        })
147    }
148
149    /// Calculate fees per participant.
150    fn calculate_fees(&self, total_fee: u64) -> Result<Vec<u64>> {
151        let n = self.participants.len();
152        if n == 0 {
153            return Err(CoinJoinError::NoParticipants);
154        }
155
156        match self.fee_strategy {
157            FeeStrategy::Equal => {
158                let per_participant = total_fee / n as u64;
159                let remainder = total_fee % n as u64;
160                let mut fees: Vec<u64> = vec![per_participant; n];
161                // First participant pays remainder
162                fees[0] += remainder;
163                Ok(fees)
164            }
165            FeeStrategy::Proportional => {
166                let total_input: u64 = self.participants.iter().map(|p| p.total_input()).sum();
167                if total_input == 0 {
168                    return Err(CoinJoinError::FeeError("No inputs".into()));
169                }
170                let fees: Vec<u64> = self
171                    .participants
172                    .iter()
173                    .map(|p| (p.total_input() as f64 / total_input as f64 * total_fee as f64) as u64)
174                    .collect();
175                Ok(fees)
176            }
177            FeeStrategy::SinglePayer(idx) => {
178                if idx >= n {
179                    return Err(CoinJoinError::FeeError("Invalid payer index".into()));
180                }
181                let mut fees = vec![0u64; n];
182                fees[idx] = total_fee;
183                Ok(fees)
184            }
185        }
186    }
187
188    /// Estimate transaction size in vBytes.
189    fn estimate_tx_size(&self) -> usize {
190        let input_count: usize = self.participants.iter().map(|p| p.inputs.len()).sum();
191        let output_count = self.participants.len() * 2; // output + change
192
193        // Rough estimate: 10 + 68*inputs + 34*outputs (for P2WPKH)
194        10 + 68 * input_count + 34 * output_count
195    }
196
197    /// Generate deterministic shuffle seed.
198    fn generate_shuffle_seed(&self) -> [u8; 32] {
199        let mut hasher = Sha256::new();
200        for participant in &self.participants {
201            hasher.update(participant.id.as_bytes());
202            for input in &participant.inputs {
203                hasher.update(input.txid);
204                hasher.update(input.vout.to_le_bytes());
205            }
206        }
207        let result = hasher.finalize();
208        let mut seed = [0u8; 32];
209        seed.copy_from_slice(&result);
210        seed
211    }
212}
213
214impl Default for CoinJoinBuilder {
215    fn default() -> Self {
216        Self::new()
217    }
218}
219
220/// Built CoinJoin transaction.
221#[derive(Debug, Clone)]
222pub struct CoinJoinTransaction {
223    /// Shuffled inputs
224    pub inputs: Vec<InputRef>,
225    /// Equal amount outputs (shuffled)
226    pub outputs: Vec<OutputDef>,
227    /// Change outputs
228    pub change_outputs: Vec<OutputDef>,
229    /// Number of participants
230    pub participant_count: usize,
231    /// Equal output amount
232    pub output_amount: u64,
233    /// Total fee
234    pub total_fee: u64,
235}
236
237impl CoinJoinTransaction {
238    /// Total input amount.
239    pub fn total_input(&self) -> u64 {
240        self.inputs.iter().map(|i| i.amount).sum()
241    }
242
243    /// Total output amount (excluding change).
244    pub fn total_output(&self) -> u64 {
245        self.outputs.iter().map(|o| o.amount).sum()
246    }
247
248    /// Total change amount.
249    pub fn total_change(&self) -> u64 {
250        self.change_outputs.iter().map(|o| o.amount).sum()
251    }
252
253    /// Verify all main outputs are equal.
254    pub fn verify_equal_outputs(&self) -> bool {
255        self.outputs.iter().all(|o| o.amount == self.output_amount)
256    }
257
258    /// Get all outputs (main + change).
259    pub fn all_outputs(&self) -> Vec<&OutputDef> {
260        self.outputs
261            .iter()
262            .chain(self.change_outputs.iter())
263            .collect()
264    }
265}
266
267/// Shuffle a vector deterministically using a seed.
268fn shuffle_with_seed<T: Clone>(items: &[T], seed: &[u8; 32]) -> Vec<T> {
269    if items.is_empty() {
270        return Vec::new();
271    }
272
273    let mut result: Vec<(T, u64)> = items
274        .iter()
275        .enumerate()
276        .map(|(i, item)| {
277            let mut hasher = Sha256::new();
278            hasher.update(seed);
279            hasher.update(i.to_le_bytes());
280            let hash = hasher.finalize();
281            let sort_key = u64::from_le_bytes(hash[0..8].try_into().unwrap());
282            (item.clone(), sort_key)
283        })
284        .collect();
285
286    result.sort_by_key(|(_, key)| *key);
287    result.into_iter().map(|(item, _)| item).collect()
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// A one-element input list whose txid is `[txid_byte; 32]`, vout 0.
    fn one_input(txid_byte: u8, amount: u64) -> Vec<InputRef> {
        vec![InputRef::from_outpoint([txid_byte; 32], 0, amount)]
    }

    #[test]
    fn test_builder_creation() {
        let fresh = CoinJoinBuilder::new();
        assert!(fresh.participants.is_empty());
        assert_eq!(fresh.min_participants, 2);
    }

    #[test]
    fn test_add_participant() {
        let mut builder = CoinJoinBuilder::new();
        builder.add_participant_simple("alice", one_input(1, 100_000), vec![0x00, 0x14]);
        assert_eq!(builder.participants.len(), 1);
    }

    #[test]
    fn test_build_coinjoin() {
        let mut builder = CoinJoinBuilder::new();
        builder
            .add_participant_simple("alice", one_input(1, 100_000), vec![0x00, 0x14, 0x01])
            .add_participant_simple("bob", one_input(2, 100_000), vec![0x00, 0x14, 0x02])
            .set_output_amount(50_000)
            .set_fee_rate(1.0);

        let tx = builder.build().expect("both participants are fully funded");

        assert_eq!(tx.participant_count, 2);
        assert_eq!(tx.inputs.len(), 2);
        assert_eq!(tx.outputs.len(), 2);
        assert!(tx.verify_equal_outputs());
    }

    #[test]
    fn test_insufficient_funds() {
        let mut builder = CoinJoinBuilder::new();
        builder
            .add_participant_simple("alice", one_input(1, 10_000), vec![0x00, 0x14])
            .add_participant_simple("bob", one_input(2, 10_000), vec![0x00, 0x14])
            .set_output_amount(50_000);

        let outcome = builder.build();
        assert!(matches!(outcome, Err(CoinJoinError::InsufficientFunds { .. })));
    }

    #[test]
    fn test_fee_strategies() {
        let mut builder = CoinJoinBuilder::new();
        builder
            .add_participant_simple("alice", one_input(1, 100_000), vec![0x00, 0x14])
            .add_participant_simple("bob", one_input(2, 200_000), vec![0x00, 0x14]);

        // Equal: an evenly divisible fee is split down the middle.
        assert_eq!(builder.calculate_fees(1000).unwrap(), vec![500, 500]);

        // Proportional: Bob contributed more input, so he pays more.
        builder.set_fee_strategy(FeeStrategy::Proportional);
        let proportional = builder.calculate_fees(1000).unwrap();
        assert!(proportional[1] > proportional[0]);

        // Single payer: Bob covers the whole fee.
        builder.set_fee_strategy(FeeStrategy::SinglePayer(1));
        assert_eq!(builder.calculate_fees(1000).unwrap(), vec![0, 1000]);
    }

    #[test]
    fn test_shuffle_deterministic() {
        let items = [1, 2, 3, 4, 5];
        let seed = [0u8; 32];
        assert_eq!(shuffle_with_seed(&items, &seed), shuffle_with_seed(&items, &seed));
    }
}