chains_sdk/threshold/musig2/
nested.rs1use crate::error::SignerError;
17use k256::elliptic_curve::group::GroupEncoding;
18use k256::elliptic_curve::sec1::ToEncodedPoint;
19use k256::{AffinePoint, ProjectivePoint, Scalar};
20
21use super::signing::{
22 self, compute_nonce_coeff, AggNonce, KeyAggContext, PartialSignature, PubNonce,
23};
24
25pub fn verify_partial_sig(
49 partial_sig: &PartialSignature,
50 pub_nonce: &PubNonce,
51 signer_pubkey: &[u8; 33],
52 key_agg_ctx: &KeyAggContext,
53 agg_nonce: &AggNonce,
54 msg: &[u8],
55) -> Result<bool, SignerError> {
56 let b = compute_nonce_coeff(agg_nonce, &key_agg_ctx.x_only_pubkey, msg);
58
59 let r_group = ProjectivePoint::from(agg_nonce.r1) + ProjectivePoint::from(agg_nonce.r2) * b;
61 let r_affine = r_group.to_affine();
62 let r_encoded = r_affine.to_encoded_point(true);
63 let r_bytes = r_encoded.as_bytes();
64 let nonce_negated = r_bytes[0] == 0x03;
65 let mut r_x = [0u8; 32];
66 r_x.copy_from_slice(&r_bytes[1..33]);
67
68 let mut challenge_data = Vec::new();
70 challenge_data.extend_from_slice(&r_x);
71 challenge_data.extend_from_slice(&key_agg_ctx.x_only_pubkey);
72 challenge_data.extend_from_slice(msg);
73 let e = tagged_hash_scalar(b"BIP0340/challenge", &challenge_data);
74
75 let signer_idx = key_agg_ctx
77 .pubkeys
78 .iter()
79 .position(|pk| pk == signer_pubkey);
80 let signer_idx = match signer_idx {
81 Some(idx) => idx,
82 None => {
83 return Err(SignerError::SigningFailed(
84 "signer not in key_agg context".into(),
85 ))
86 }
87 };
88 let a_i = key_agg_ctx.coefficients[signer_idx];
89
90 let pk_ct = AffinePoint::from_bytes(signer_pubkey.into());
92 if !bool::from(pk_ct.is_some()) {
93 return Ok(false);
94 }
95 #[allow(clippy::unwrap_used)]
96 let pk_point = ProjectivePoint::from(pk_ct.unwrap());
97
98 let ri = ProjectivePoint::from(pub_nonce.r1) + ProjectivePoint::from(pub_nonce.r2) * b;
100 let effective_ri = if nonce_negated { -ri } else { ri };
101
102 let effective_pk = if key_agg_ctx.parity {
104 -pk_point
105 } else {
106 pk_point
107 };
108
109 let lhs = ProjectivePoint::GENERATOR * partial_sig.s;
111
112 let rhs = effective_ri + effective_pk * (e * a_i);
114
115 Ok(lhs == rhs)
116}
117
/// A node in a nested MuSig2 key tree.
///
/// Leaves hold individual signer keys. Internal nodes stand for the aggregate of their
/// children, so a whole subtree can join an outer aggregation as if it were one key.
/// Equality and hashing are structural (same shape, same keys, same order).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum KeyTreeNode {
    /// A single participant's 33-byte compressed public key.
    Leaf([u8; 33]),
    /// A group whose members' effective keys are combined with MuSig2 key aggregation.
    Internal(Vec<KeyTreeNode>),
}
130
131impl KeyTreeNode {
132 pub fn effective_pubkey(&self) -> Result<[u8; 33], SignerError> {
137 match self {
138 KeyTreeNode::Leaf(pk) => Ok(*pk),
139 KeyTreeNode::Internal(children) => {
140 let child_keys: Result<Vec<[u8; 33]>, _> = children
141 .iter()
142 .map(|child| child.effective_pubkey())
143 .collect();
144 let child_keys = child_keys?;
145 let ctx = signing::key_agg(&child_keys)?;
146 let agg_enc = ctx.aggregate_key.to_encoded_point(true);
148 let mut out = [0u8; 33];
149 out.copy_from_slice(agg_enc.as_bytes());
150 Ok(out)
151 }
152 }
153 }
154
155 pub fn key_agg_context(&self) -> Result<KeyAggContext, SignerError> {
157 match self {
158 KeyTreeNode::Leaf(_) => Err(SignerError::ParseError(
159 "leaf nodes don't have a key_agg context".into(),
160 )),
161 KeyTreeNode::Internal(children) => {
162 let child_keys: Result<Vec<[u8; 33]>, _> = children
163 .iter()
164 .map(|child| child.effective_pubkey())
165 .collect();
166 signing::key_agg(&child_keys?)
167 }
168 }
169 }
170
171 #[must_use]
173 pub fn leaf_count(&self) -> usize {
174 match self {
175 KeyTreeNode::Leaf(_) => 1,
176 KeyTreeNode::Internal(children) => children.iter().map(|c| c.leaf_count()).sum(),
177 }
178 }
179
180 #[must_use]
182 pub fn depth(&self) -> usize {
183 match self {
184 KeyTreeNode::Leaf(_) => 0,
185 KeyTreeNode::Internal(children) => {
186 1 + children.iter().map(|c| c.depth()).max().unwrap_or(0)
187 }
188 }
189 }
190}
191
192pub fn flat_key_tree(pubkeys: &[[u8; 33]]) -> KeyTreeNode {
196 KeyTreeNode::Internal(pubkeys.iter().map(|pk| KeyTreeNode::Leaf(*pk)).collect())
197}
198
199pub fn grouped_key_tree(groups: &[Vec<[u8; 33]>]) -> KeyTreeNode {
210 KeyTreeNode::Internal(
211 groups
212 .iter()
213 .map(|group| {
214 if group.len() == 1 {
215 KeyTreeNode::Leaf(group[0])
216 } else {
217 KeyTreeNode::Internal(group.iter().map(|pk| KeyTreeNode::Leaf(*pk)).collect())
218 }
219 })
220 .collect(),
221 )
222}
223
/// Module-local forwarder to the parent module's `tagged_hash_scalar`, so that
/// `verify_partial_sig` can call it without a `super::` prefix.
///
/// NOTE(review): presumably a BIP-340-style tagged hash of `data` under `tag`, reduced
/// to a scalar. The actual construction lives in the parent module; confirm there.
fn tagged_hash_scalar(tag: &[u8], data: &[u8]) -> Scalar {
    super::tagged_hash_scalar(tag, data)
}
230
/// Unit tests for partial-signature verification and key-tree construction.
#[cfg(test)]
// Panicking on an unexpected failure is the desired outcome inside tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::super::signing;
    use super::*;

    /// Two fixed test signers, returned as (sk1, sk2, pk1, pk2).
    fn make_keys() -> ([u8; 32], [u8; 32], [u8; 33], [u8; 33]) {
        let sk1 = [0x11u8; 32];
        let sk2 = [0x22u8; 32];
        let pk1 = signing::individual_pubkey(&sk1).unwrap();
        let pk2 = signing::individual_pubkey(&sk2).unwrap();
        (sk1, sk2, pk1, pk2)
    }

    /// Public key of a fixed secret key built from `byte`.
    fn pubkey_from(byte: u8) -> [u8; 33] {
        signing::individual_pubkey(&[byte; 32]).unwrap()
    }

    /// Compressed aggregate key of `keys`, computed directly with `signing::key_agg`.
    fn agg_bytes(keys: &[[u8; 33]]) -> [u8; 33] {
        let ctx = signing::key_agg(keys).unwrap();
        let enc = ctx.aggregate_key.to_encoded_point(true);
        let mut out = [0u8; 33];
        out.copy_from_slice(enc.as_bytes());
        out
    }

    #[test]
    fn test_partial_sig_verify_valid() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"partial sig verify";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();

        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        let valid = verify_partial_sig(&psig1, &pub1, &pk1, &ctx, &agg_nonce, msg).unwrap();
        assert!(valid, "valid partial sig must verify");
    }

    #[test]
    fn test_partial_sig_verify_tampered() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"tampered partial sig";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();

        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        let tampered = PartialSignature {
            s: psig1.s + Scalar::ONE,
        };
        let valid = verify_partial_sig(&tampered, &pub1, &pk1, &ctx, &agg_nonce, msg).unwrap();
        assert!(!valid, "tampered partial sig must fail");
    }

    #[test]
    fn test_partial_sig_wrong_key_fails() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"wrong key";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();
        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        // pk2 is in the context, but psig1 was produced by signer 1.
        let valid = verify_partial_sig(&psig1, &pub1, &pk2, &ctx, &agg_nonce, msg).unwrap();
        assert!(!valid, "partial sig verified with wrong key must fail");
    }

    #[test]
    fn test_partial_sig_signer_not_in_context_errors() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let outsider = pubkey_from(0x33);
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"outsider key";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();
        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        let result = verify_partial_sig(&psig1, &pub1, &outsider, &ctx, &agg_nonce, msg);
        assert!(result.is_err(), "a key outside the context must be an error");
    }

    #[test]
    fn test_flat_key_tree() {
        let (_, _, pk1, pk2) = make_keys();
        let tree = flat_key_tree(&[pk1, pk2]);
        assert_eq!(tree.leaf_count(), 2);
        assert_eq!(tree.depth(), 1);

        let effective = tree.effective_pubkey().unwrap();
        assert_eq!(
            effective,
            agg_bytes(&[pk1, pk2]),
            "flat tree should match direct agg"
        );
    }

    #[test]
    fn test_leaf_node_basics() {
        let (_, _, pk1, _) = make_keys();
        let leaf = KeyTreeNode::Leaf(pk1);
        assert_eq!(leaf.leaf_count(), 1);
        assert_eq!(leaf.depth(), 0);
        assert_eq!(leaf.effective_pubkey().unwrap(), pk1);
    }

    #[test]
    fn test_nested_key_tree() {
        let (_, _, pk1, pk2) = make_keys();
        let pk3 = pubkey_from(0x33);

        // Shape: ((pk1, pk2), pk3)
        let tree = KeyTreeNode::Internal(vec![
            KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)]),
            KeyTreeNode::Leaf(pk3),
        ]);

        assert_eq!(tree.leaf_count(), 3);
        assert_eq!(tree.depth(), 2);

        // The inner group joins the outer aggregation through its own aggregate key.
        let inner = agg_bytes(&[pk1, pk2]);
        let expected = agg_bytes(&[inner, pk3]);
        assert_eq!(tree.effective_pubkey().unwrap(), expected);

        // The root context aggregates exactly its two direct children.
        let ctx = tree.key_agg_context().unwrap();
        assert_eq!(ctx.pubkeys.len(), 2);
        assert!(ctx.pubkeys.contains(&inner));
        assert!(ctx.pubkeys.contains(&pk3));
    }

    #[test]
    fn test_grouped_key_tree() {
        let (_, _, pk1, pk2) = make_keys();
        let pk3 = pubkey_from(0x33);
        let pk4 = pubkey_from(0x44);

        let tree = grouped_key_tree(&[vec![pk1, pk2], vec![pk3, pk4]]);
        assert_eq!(tree.leaf_count(), 4);
        assert_eq!(tree.depth(), 2);

        let expected = agg_bytes(&[agg_bytes(&[pk1, pk2]), agg_bytes(&[pk3, pk4])]);
        assert_eq!(tree.effective_pubkey().unwrap(), expected);
    }

    #[test]
    fn test_grouped_key_tree_singleton_group_is_leaf() {
        let (_, _, pk1, pk2) = make_keys();
        let pk3 = pubkey_from(0x33);

        let tree = grouped_key_tree(&[vec![pk1, pk2], vec![pk3]]);
        match &tree {
            KeyTreeNode::Internal(children) => {
                assert_eq!(children.len(), 2);
                assert!(
                    matches!(&children[1], KeyTreeNode::Leaf(pk) if *pk == pk3),
                    "a one-key group must collapse to a leaf"
                );
            }
            KeyTreeNode::Leaf(_) => panic!("grouped tree root must be internal"),
        }
        assert_eq!(tree.leaf_count(), 3);
        assert_eq!(
            tree.effective_pubkey().unwrap(),
            agg_bytes(&[agg_bytes(&[pk1, pk2]), pk3])
        );
    }

    #[test]
    fn test_leaf_key_agg_context_error() {
        let (_, _, pk1, _) = make_keys();
        let leaf = KeyTreeNode::Leaf(pk1);
        assert!(leaf.key_agg_context().is_err());
    }

    #[test]
    fn test_nested_tree_deterministic() {
        let (_, _, pk1, pk2) = make_keys();
        let pk3 = pubkey_from(0x33);

        let tree1 = KeyTreeNode::Internal(vec![
            KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)]),
            KeyTreeNode::Leaf(pk3),
        ]);
        let tree2 = KeyTreeNode::Internal(vec![
            KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)]),
            KeyTreeNode::Leaf(pk3),
        ]);

        assert_eq!(
            tree1.effective_pubkey().unwrap(),
            tree2.effective_pubkey().unwrap()
        );
    }
}