chia_sdk_driver/primitives/mips/
m_of_n.rs1use std::collections::HashMap;
2
3use chia_protocol::Bytes32;
4use chia_sdk_types::{
5 puzzles::{MofNArgs, MofNSolution, NofNArgs, NofNSolution, OneOfNArgs, OneOfNSolution},
6 MerkleTree, Mod,
7};
8use clvm_traits::clvm_tuple;
9use clvm_utils::{tree_hash_atom, tree_hash_pair, TreeHash};
10use clvmr::NodePtr;
11
12use crate::{DriverError, Spend, SpendContext};
13
14use super::mips_spend::MipsSpend;
15
16#[derive(Debug, Clone)]
17pub struct MofN {
18 pub required: usize,
19 pub items: Vec<TreeHash>,
20}
21
22impl MofN {
23 pub fn new(required: usize, items: Vec<TreeHash>) -> Self {
24 Self { required, items }
25 }
26
27 pub fn inner_puzzle_hash(&self) -> TreeHash {
28 if self.required == 1 {
29 let merkle_tree = self.merkle_tree();
30 OneOfNArgs::new(merkle_tree.root()).curry_tree_hash()
31 } else if self.required == self.items.len() {
32 NofNArgs::new(self.items.clone()).curry_tree_hash()
33 } else {
34 let merkle_tree = self.merkle_tree();
35 MofNArgs::new(self.required, merkle_tree.root()).curry_tree_hash()
36 }
37 }
38
39 pub fn spend(
40 &self,
41 ctx: &mut SpendContext,
42 spend: &MipsSpend,
43 delegated_puzzle_wrappers: &mut Vec<TreeHash>,
44 ) -> Result<Spend, DriverError> {
45 if self.required == 1 {
46 let (member_hash, member_spend) = self
47 .items
48 .iter()
49 .find_map(|item| Some((*item, spend.members.get(item)?)))
50 .ok_or(DriverError::MissingSubpathSpend)?;
51
52 let member_spend = member_spend.spend(ctx, spend, delegated_puzzle_wrappers, false)?;
53
54 let merkle_tree = self.merkle_tree();
55 let merkle_proof = merkle_tree
56 .proof(member_hash.into())
57 .ok_or(DriverError::InvalidMerkleProof)?;
58
59 let puzzle = ctx.curry(OneOfNArgs::new(merkle_tree.root()))?;
60 let solution = ctx.alloc(&OneOfNSolution::new(
61 merkle_proof,
62 member_spend.puzzle,
63 member_spend.solution,
64 ))?;
65 Ok(Spend::new(puzzle, solution))
66 } else if self.required == self.items.len() {
67 let mut puzzles = Vec::with_capacity(self.items.len());
68 let mut solutions = Vec::with_capacity(self.items.len());
69
70 for item in &self.items {
71 let member = spend
72 .members
73 .get(item)
74 .ok_or(DriverError::MissingSubpathSpend)?;
75
76 let member_spend = member.spend(ctx, spend, delegated_puzzle_wrappers, false)?;
77
78 puzzles.push(member_spend.puzzle);
79 solutions.push(member_spend.solution);
80 }
81
82 let puzzle = ctx.curry(NofNArgs::new(puzzles))?;
83 let solution = ctx.alloc(&NofNSolution::new(solutions))?;
84 Ok(Spend::new(puzzle, solution))
85 } else {
86 let mut puzzle_hashes = Vec::with_capacity(self.required);
87 let mut member_spends = HashMap::with_capacity(self.required);
88
89 for &item in &self.items {
90 puzzle_hashes.push(item.into());
91
92 let Some(member) = spend.members.get(&item) else {
93 continue;
94 };
95
96 member_spends.insert(
97 item,
98 member.spend(ctx, spend, delegated_puzzle_wrappers, false)?,
99 );
100 }
101
102 if member_spends.len() < self.required {
103 return Err(DriverError::InvalidSubpathSpendCount);
104 }
105
106 let merkle_tree = self.merkle_tree();
107 let proof = m_of_n_proof(ctx, &puzzle_hashes, &member_spends)?;
108
109 let puzzle = ctx.curry(MofNArgs::new(self.required, merkle_tree.root()))?;
110 let solution = ctx.alloc(&MofNSolution::new(proof))?;
111 Ok(Spend::new(puzzle, solution))
112 }
113 }
114
115 fn merkle_tree(&self) -> MerkleTree {
116 let leaves: Vec<Bytes32> = self.items.iter().map(|&member| member.into()).collect();
117 MerkleTree::new(&leaves)
118 }
119}
120
121fn m_of_n_proof(
122 ctx: &mut SpendContext,
123 puzzle_hashes: &[Bytes32],
124 member_spends: &HashMap<TreeHash, Spend>,
125) -> Result<NodePtr, DriverError> {
126 if puzzle_hashes.len() == 1 {
127 let puzzle_hash = puzzle_hashes[0];
128
129 return if let Some(spend) = member_spends.get(&puzzle_hash.into()) {
130 ctx.alloc(&clvm_tuple!((), spend.puzzle, spend.solution))
131 } else {
132 ctx.alloc(&Bytes32::from(tree_hash_atom(&puzzle_hash)))
133 };
134 }
135
136 let mid_index = puzzle_hashes.len().div_ceil(2);
137 let first = &puzzle_hashes[..mid_index];
138 let rest = &puzzle_hashes[mid_index..];
139
140 let first_proof = m_of_n_proof(ctx, first, member_spends)?;
141 let rest_proof = m_of_n_proof(ctx, rest, member_spends)?;
142
143 if first_proof.is_pair() || rest_proof.is_pair() {
144 ctx.alloc(&(first_proof, rest_proof))
145 } else {
146 let first_hash = ctx.extract::<Bytes32>(first_proof)?;
147 let rest_hash = ctx.extract::<Bytes32>(rest_proof)?;
148 let pair_hash = Bytes32::from(tree_hash_pair(first_hash.into(), rest_hash.into()));
149 ctx.alloc(&pair_hash)
150 }
151}