chains_sdk/threshold/musig2/
nested.rs1use crate::crypto;
17use crate::error::SignerError;
18use k256::elliptic_curve::group::GroupEncoding;
19use k256::elliptic_curve::ops::Reduce;
20use k256::elliptic_curve::sec1::ToEncodedPoint;
21use k256::{AffinePoint, ProjectivePoint, Scalar};
22
23use super::signing::{
24 self, compute_nonce_coeff, AggNonce, KeyAggContext, PartialSignature, PubNonce,
25};
26
27pub fn verify_partial_sig(
51 partial_sig: &PartialSignature,
52 pub_nonce: &PubNonce,
53 signer_pubkey: &[u8; 33],
54 key_agg_ctx: &KeyAggContext,
55 agg_nonce: &AggNonce,
56 msg: &[u8],
57) -> Result<bool, SignerError> {
58 let b = compute_nonce_coeff(agg_nonce, &key_agg_ctx.x_only_pubkey, msg);
60
61 let r_group = ProjectivePoint::from(agg_nonce.r1) + ProjectivePoint::from(agg_nonce.r2) * b;
63 let r_affine = r_group.to_affine();
64 let r_encoded = r_affine.to_encoded_point(true);
65 let r_bytes = r_encoded.as_bytes();
66 let nonce_negated = r_bytes[0] == 0x03;
67 let mut r_x = [0u8; 32];
68 r_x.copy_from_slice(&r_bytes[1..33]);
69
70 let mut challenge_data = Vec::new();
72 challenge_data.extend_from_slice(&r_x);
73 challenge_data.extend_from_slice(&key_agg_ctx.x_only_pubkey);
74 challenge_data.extend_from_slice(msg);
75 let e = tagged_hash_scalar(b"BIP0340/challenge", &challenge_data);
76
77 let signer_idx = key_agg_ctx
79 .pubkeys
80 .iter()
81 .position(|pk| pk == signer_pubkey);
82 let signer_idx = match signer_idx {
83 Some(idx) => idx,
84 None => {
85 return Err(SignerError::SigningFailed(
86 "signer not in key_agg context".into(),
87 ))
88 }
89 };
90 let a_i = key_agg_ctx.coefficients[signer_idx];
91
92 let pk_ct = AffinePoint::from_bytes(signer_pubkey.into());
94 if !bool::from(pk_ct.is_some()) {
95 return Ok(false);
96 }
97 #[allow(clippy::unwrap_used)]
98 let pk_point = ProjectivePoint::from(pk_ct.unwrap());
99
100 let ri = ProjectivePoint::from(pub_nonce.r1) + ProjectivePoint::from(pub_nonce.r2) * b;
102 let effective_ri = if nonce_negated { -ri } else { ri };
103
104 let effective_pk = if key_agg_ctx.parity {
106 -pk_point
107 } else {
108 pk_point
109 };
110
111 let lhs = ProjectivePoint::GENERATOR * partial_sig.s;
113
114 let rhs = effective_ri + effective_pk * (e * a_i);
116
117 Ok(lhs == rhs)
118}
119
/// A node of a nested key-aggregation tree.
///
/// A tree is either a single participant key or a group of sub-trees whose
/// effective keys are aggregated together.
#[derive(Clone, Debug)]
pub enum KeyTreeNode {
    /// A single participant, identified by a 33-byte public key.
    Leaf([u8; 33]),
    /// A group of child nodes that are aggregated into one effective key.
    Internal(Vec<KeyTreeNode>),
}
132
133impl KeyTreeNode {
134 pub fn effective_pubkey(&self) -> Result<[u8; 33], SignerError> {
139 match self {
140 KeyTreeNode::Leaf(pk) => Ok(*pk),
141 KeyTreeNode::Internal(children) => {
142 let child_keys: Result<Vec<[u8; 33]>, _> = children
143 .iter()
144 .map(|child| child.effective_pubkey())
145 .collect();
146 let child_keys = child_keys?;
147 let ctx = signing::key_agg(&child_keys)?;
148 let agg_enc = ctx.aggregate_key.to_encoded_point(true);
150 let mut out = [0u8; 33];
151 out.copy_from_slice(agg_enc.as_bytes());
152 Ok(out)
153 }
154 }
155 }
156
157 pub fn key_agg_context(&self) -> Result<KeyAggContext, SignerError> {
159 match self {
160 KeyTreeNode::Leaf(_) => Err(SignerError::ParseError(
161 "leaf nodes don't have a key_agg context".into(),
162 )),
163 KeyTreeNode::Internal(children) => {
164 let child_keys: Result<Vec<[u8; 33]>, _> = children
165 .iter()
166 .map(|child| child.effective_pubkey())
167 .collect();
168 signing::key_agg(&child_keys?)
169 }
170 }
171 }
172
173 #[must_use]
175 pub fn leaf_count(&self) -> usize {
176 match self {
177 KeyTreeNode::Leaf(_) => 1,
178 KeyTreeNode::Internal(children) => children.iter().map(|c| c.leaf_count()).sum(),
179 }
180 }
181
182 #[must_use]
184 pub fn depth(&self) -> usize {
185 match self {
186 KeyTreeNode::Leaf(_) => 0,
187 KeyTreeNode::Internal(children) => {
188 1 + children.iter().map(|c| c.depth()).max().unwrap_or(0)
189 }
190 }
191 }
192}
193
194pub fn flat_key_tree(pubkeys: &[[u8; 33]]) -> KeyTreeNode {
198 KeyTreeNode::Internal(pubkeys.iter().map(|pk| KeyTreeNode::Leaf(*pk)).collect())
199}
200
201pub fn grouped_key_tree(groups: &[Vec<[u8; 33]>]) -> KeyTreeNode {
212 KeyTreeNode::Internal(
213 groups
214 .iter()
215 .map(|group| {
216 if group.len() == 1 {
217 KeyTreeNode::Leaf(group[0])
218 } else {
219 KeyTreeNode::Internal(group.iter().map(|pk| KeyTreeNode::Leaf(*pk)).collect())
220 }
221 })
222 .collect(),
223 )
224}
225
226fn tagged_hash_scalar(tag: &[u8], data: &[u8]) -> Scalar {
230 let hash = crypto::tagged_hash(tag, data);
231 let wide = k256::U256::from_be_slice(&hash);
232 <Scalar as Reduce<k256::U256>>::reduce(wide)
233}
234
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is itself a test failure.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::super::signing;
    use super::*;

    /// Two fixed secret keys together with their public keys.
    fn make_keys() -> ([u8; 32], [u8; 32], [u8; 33], [u8; 33]) {
        let sk1 = [0x11u8; 32];
        let sk2 = [0x22u8; 32];
        let pk1 = signing::individual_pubkey(&sk1).unwrap();
        let pk2 = signing::individual_pubkey(&sk2).unwrap();
        (sk1, sk2, pk1, pk2)
    }

    /// Public key of the secret key whose 32 bytes all equal `fill`.
    fn pubkey_from_fill(fill: u8) -> [u8; 33] {
        signing::individual_pubkey(&[fill; 32]).unwrap()
    }

    // --- Partial signature verification ---

    #[test]
    fn test_partial_sig_verify_valid() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"partial sig verify";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();

        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        let valid = verify_partial_sig(&psig1, &pub1, &pk1, &ctx, &agg_nonce, msg).unwrap();
        assert!(valid, "valid partial sig must verify");
    }

    #[test]
    fn test_partial_sig_verify_tampered() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"tampered partial sig";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();

        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        let tampered = PartialSignature {
            s: psig1.s + Scalar::ONE,
        };
        let valid = verify_partial_sig(&tampered, &pub1, &pk1, &ctx, &agg_nonce, msg).unwrap();
        assert!(!valid, "tampered partial sig must fail");
    }

    #[test]
    fn test_partial_sig_wrong_key_fails() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"wrong key";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();
        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        // pk2 is a member of the context, so this is a mismatch, not an error.
        let valid = verify_partial_sig(&psig1, &pub1, &pk2, &ctx, &agg_nonce, msg).unwrap();
        assert!(!valid, "partial sig verified with wrong key must fail");
    }

    #[test]
    fn test_partial_sig_unknown_signer_errors() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let outsider = pubkey_from_fill(0x33);
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"unknown signer";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();
        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        let result = verify_partial_sig(&psig1, &pub1, &outsider, &ctx, &agg_nonce, msg);
        assert!(result.is_err(), "signer outside the context must be an error");
    }

    // --- Key trees ---

    #[test]
    fn test_flat_key_tree() {
        let (_, _, pk1, pk2) = make_keys();
        let tree = flat_key_tree(&[pk1, pk2]);
        assert_eq!(tree.leaf_count(), 2);
        assert_eq!(tree.depth(), 1);

        let effective = tree.effective_pubkey().unwrap();
        let direct_ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let direct_enc = direct_ctx.aggregate_key.to_encoded_point(true);
        let mut direct_bytes = [0u8; 33];
        direct_bytes.copy_from_slice(direct_enc.as_bytes());
        assert_eq!(effective, direct_bytes, "flat tree should match direct agg");
    }

    #[test]
    fn test_nested_key_tree() {
        let pk3 = pubkey_from_fill(0x33);
        let (_, _, pk1, pk2) = make_keys();

        let tree = KeyTreeNode::Internal(vec![
            KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)]),
            KeyTreeNode::Leaf(pk3),
        ]);

        assert_eq!(tree.leaf_count(), 3);
        assert_eq!(tree.depth(), 2);

        let effective = tree.effective_pubkey().unwrap();
        assert_ne!(effective, [0u8; 33]);

        // The root aggregates its two direct children, not the three leaves.
        let ctx = tree.key_agg_context().unwrap();
        assert_eq!(ctx.pubkeys.len(), 2);
    }

    #[test]
    fn test_grouped_key_tree() {
        let (_, _, pk1, pk2) = make_keys();
        let pk3 = pubkey_from_fill(0x33);
        let pk4 = pubkey_from_fill(0x44);

        let tree = grouped_key_tree(&[vec![pk1, pk2], vec![pk3, pk4]]);
        assert_eq!(tree.leaf_count(), 4);
        assert_eq!(tree.depth(), 2);

        let effective = tree.effective_pubkey().unwrap();
        assert_ne!(effective, [0u8; 33]);
    }

    #[test]
    fn test_leaf_key_agg_context_error() {
        let (_, _, pk1, _) = make_keys();
        let leaf = KeyTreeNode::Leaf(pk1);
        assert!(leaf.key_agg_context().is_err());
    }

    #[test]
    fn test_nested_tree_deterministic() {
        let (_, _, pk1, pk2) = make_keys();
        let pk3 = pubkey_from_fill(0x33);

        let tree1 = KeyTreeNode::Internal(vec![
            KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)]),
            KeyTreeNode::Leaf(pk3),
        ]);
        let tree2 = KeyTreeNode::Internal(vec![
            KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)]),
            KeyTreeNode::Leaf(pk3),
        ]);

        assert_eq!(
            tree1.effective_pubkey().unwrap(),
            tree2.effective_pubkey().unwrap()
        );
    }
}