Skip to main content

chains_sdk/threshold/musig2/
nested.rs

1//! MuSig2 Nested Key Aggregation (Key Trees) and Partial Signature Verification.
2//!
3//! Supports hierarchical key aggregation where some "leaf" keys are themselves
4//! MuSig2 aggregate keys. Also provides partial signature verification to
5//! validate individual signer contributions before aggregation.
6//!
7//! # Nested Key Example
8//! ```text
9//!    TopLevel (Q = agg(Q_a, Q_b))
10//!         /         \
11//!     Q_a = agg(pk1, pk2)    Q_b = pk3
12//! ```
13//! The group Q_a signs internally using 2-of-2 MuSig2, then their combined
14//! partial signature contributes as one signer to the top-level session.
15
16use crate::error::SignerError;
17use k256::elliptic_curve::group::GroupEncoding;
18use k256::elliptic_curve::sec1::ToEncodedPoint;
19use k256::{AffinePoint, ProjectivePoint, Scalar};
20
21use super::signing::{
22    self, compute_nonce_coeff, AggNonce, KeyAggContext, PartialSignature, PubNonce,
23};
24
25// ═══════════════════════════════════════════════════════════════════
26// Partial Signature Verification
27// ═══════════════════════════════════════════════════════════════════
28
29/// Verify a partial signature from a specific signer.
30///
31/// Checks that the partial signature `s_i` satisfies:
32/// `s_i * G == R_i + e * a_i * (g * P_i)`
33///
34/// where:
35/// - `R_i = R_{1,i} + b * R_{2,i}` (effective signer nonce)
36/// - `e` = BIP-340 challenge
37/// - `a_i` = key aggregation coefficient
38/// - `g` = parity correction (1 or -1)
39/// - `P_i` = signer's public key
40///
41/// # Arguments
42/// - `partial_sig` — The partial signature to verify
43/// - `pub_nonce` — The signer's public nonce (from round 1)
44/// - `signer_pubkey` — The signer's compressed public key (33 bytes)
45/// - `key_agg_ctx` — The key aggregation context
46/// - `agg_nonce` — The aggregated nonce
47/// - `msg` — The signed message
48pub fn verify_partial_sig(
49    partial_sig: &PartialSignature,
50    pub_nonce: &PubNonce,
51    signer_pubkey: &[u8; 33],
52    key_agg_ctx: &KeyAggContext,
53    agg_nonce: &AggNonce,
54    msg: &[u8],
55) -> Result<bool, SignerError> {
56    // Compute nonce coefficient b
57    let b = compute_nonce_coeff(agg_nonce, &key_agg_ctx.x_only_pubkey, msg);
58
59    // Effective R for the group: R = R1_agg + b * R2_agg
60    let r_group = ProjectivePoint::from(agg_nonce.r1) + ProjectivePoint::from(agg_nonce.r2) * b;
61    let r_affine = r_group.to_affine();
62    let r_encoded = r_affine.to_encoded_point(true);
63    let r_bytes = r_encoded.as_bytes();
64    let nonce_negated = r_bytes[0] == 0x03;
65    let mut r_x = [0u8; 32];
66    r_x.copy_from_slice(&r_bytes[1..33]);
67
68    // BIP-340 challenge
69    let mut challenge_data = Vec::new();
70    challenge_data.extend_from_slice(&r_x);
71    challenge_data.extend_from_slice(&key_agg_ctx.x_only_pubkey);
72    challenge_data.extend_from_slice(msg);
73    let e = tagged_hash_scalar(b"BIP0340/challenge", &challenge_data);
74
75    // Find the signer's aggregation coefficient
76    let signer_idx = key_agg_ctx
77        .pubkeys
78        .iter()
79        .position(|pk| pk == signer_pubkey);
80    let signer_idx = match signer_idx {
81        Some(idx) => idx,
82        None => {
83            return Err(SignerError::SigningFailed(
84                "signer not in key_agg context".into(),
85            ))
86        }
87    };
88    let a_i = key_agg_ctx.coefficients[signer_idx];
89
90    // Parse the signer's public key
91    let pk_ct = AffinePoint::from_bytes(signer_pubkey.into());
92    if !bool::from(pk_ct.is_some()) {
93        return Ok(false);
94    }
95    #[allow(clippy::unwrap_used)]
96    let pk_point = ProjectivePoint::from(pk_ct.unwrap());
97
98    // Effective signer nonce: R_i = R_{1,i} + b * R_{2,i}
99    let ri = ProjectivePoint::from(pub_nonce.r1) + ProjectivePoint::from(pub_nonce.r2) * b;
100    let effective_ri = if nonce_negated { -ri } else { ri };
101
102    // Effective public key (negate if aggregate key has odd y)
103    let effective_pk = if key_agg_ctx.parity {
104        -pk_point
105    } else {
106        pk_point
107    };
108
109    // LHS: s_i * G
110    let lhs = ProjectivePoint::GENERATOR * partial_sig.s;
111
112    // RHS: R_i + e * a_i * P_i
113    let rhs = effective_ri + effective_pk * (e * a_i);
114
115    Ok(lhs == rhs)
116}
117
118// ═══════════════════════════════════════════════════════════════════
119// Nested Key Aggregation (Key Trees)
120// ═══════════════════════════════════════════════════════════════════
121
122/// A node in a MuSig2 key tree.
123#[derive(Clone, Debug)]
124pub enum KeyTreeNode {
125    /// A leaf node: a single public key (33 bytes compressed).
126    Leaf([u8; 33]),
127    /// An internal node: a MuSig2 aggregation of child nodes.
128    Internal(Vec<KeyTreeNode>),
129}
130
131impl KeyTreeNode {
132    /// Compute the effective public key for this node.
133    ///
134    /// For a leaf, returns the key directly.
135    /// For an internal node, recursively aggregates child keys using MuSig2.
136    pub fn effective_pubkey(&self) -> Result<[u8; 33], SignerError> {
137        match self {
138            KeyTreeNode::Leaf(pk) => Ok(*pk),
139            KeyTreeNode::Internal(children) => {
140                let child_keys: Result<Vec<[u8; 33]>, _> = children
141                    .iter()
142                    .map(|child| child.effective_pubkey())
143                    .collect();
144                let child_keys = child_keys?;
145                let ctx = signing::key_agg(&child_keys)?;
146                // Return the aggregate key as compressed point
147                let agg_enc = ctx.aggregate_key.to_encoded_point(true);
148                let mut out = [0u8; 33];
149                out.copy_from_slice(agg_enc.as_bytes());
150                Ok(out)
151            }
152        }
153    }
154
155    /// Get the key aggregation context for this node (only valid for Internal nodes).
156    pub fn key_agg_context(&self) -> Result<KeyAggContext, SignerError> {
157        match self {
158            KeyTreeNode::Leaf(_) => Err(SignerError::ParseError(
159                "leaf nodes don't have a key_agg context".into(),
160            )),
161            KeyTreeNode::Internal(children) => {
162                let child_keys: Result<Vec<[u8; 33]>, _> = children
163                    .iter()
164                    .map(|child| child.effective_pubkey())
165                    .collect();
166                signing::key_agg(&child_keys?)
167            }
168        }
169    }
170
171    /// Count the total number of leaf keys in the tree.
172    #[must_use]
173    pub fn leaf_count(&self) -> usize {
174        match self {
175            KeyTreeNode::Leaf(_) => 1,
176            KeyTreeNode::Internal(children) => children.iter().map(|c| c.leaf_count()).sum(),
177        }
178    }
179
180    /// Get the depth of the tree.
181    #[must_use]
182    pub fn depth(&self) -> usize {
183        match self {
184            KeyTreeNode::Leaf(_) => 0,
185            KeyTreeNode::Internal(children) => {
186                1 + children.iter().map(|c| c.depth()).max().unwrap_or(0)
187            }
188        }
189    }
190}
191
192/// Build a flat MuSig2 key tree from a list of public keys.
193///
194/// All keys are leaves under a single internal aggregation node.
195pub fn flat_key_tree(pubkeys: &[[u8; 33]]) -> KeyTreeNode {
196    KeyTreeNode::Internal(pubkeys.iter().map(|pk| KeyTreeNode::Leaf(*pk)).collect())
197}
198
199/// Build a 2-level key tree where each group is a sub-aggregation.
200///
201/// # Example
202/// ```text
203/// groups = [[pk1, pk2], [pk3, pk4]]
204/// Result:
205///   TopLevel
206///     ├── agg(pk1, pk2)
207///     └── agg(pk3, pk4)
208/// ```
209pub fn grouped_key_tree(groups: &[Vec<[u8; 33]>]) -> KeyTreeNode {
210    KeyTreeNode::Internal(
211        groups
212            .iter()
213            .map(|group| {
214                if group.len() == 1 {
215                    KeyTreeNode::Leaf(group[0])
216                } else {
217                    KeyTreeNode::Internal(group.iter().map(|pk| KeyTreeNode::Leaf(*pk)).collect())
218                }
219            })
220            .collect(),
221    )
222}
223
224// ─── Helpers ────────────────────────────────────────────────────────
225
226/// Tagged hash to scalar — delegates to shared implementation.
227fn tagged_hash_scalar(tag: &[u8], data: &[u8]) -> Scalar {
228    super::tagged_hash_scalar(tag, data)
229}
230
231// ═══════════════════════════════════════════════════════════════════
232// Tests
233// ═══════════════════════════════════════════════════════════════════
234
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::super::signing;
    use super::*;

    /// Compressed public key for a fixed test secret key.
    fn pubkey_for(sk: &[u8; 32]) -> [u8; 33] {
        signing::individual_pubkey(sk).unwrap()
    }

    /// Two fixed test signers, returned as `(sk1, sk2, pk1, pk2)`.
    fn make_keys() -> ([u8; 32], [u8; 32], [u8; 33], [u8; 33]) {
        let sk1 = [0x11u8; 32];
        let sk2 = [0x22u8; 32];
        (sk1, sk2, pubkey_for(&sk1), pubkey_for(&sk2))
    }

    /// State of a 2-of-2 session after signer 1 has produced its partial signature.
    struct SignerOneSession {
        pk1: [u8; 33],
        pk2: [u8; 33],
        ctx: KeyAggContext,
        pub1: PubNonce,
        agg_nonce: AggNonce,
        psig1: PartialSignature,
    }

    /// Aggregate both test keys, exchange nonces, and have signer 1 sign `msg`.
    fn signer_one_session(msg: &[u8]) -> SignerOneSession {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();
        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();
        SignerOneSession {
            pk1,
            pk2,
            ctx,
            pub1,
            agg_nonce,
            psig1,
        }
    }

    // ─── Partial Signature Verification ─────────────────────────

    #[test]
    fn test_partial_sig_verify_valid() {
        let msg = b"partial sig verify";
        let session = signer_one_session(msg);

        let accepted = verify_partial_sig(
            &session.psig1,
            &session.pub1,
            &session.pk1,
            &session.ctx,
            &session.agg_nonce,
            msg,
        )
        .unwrap();
        assert!(accepted, "valid partial sig must verify");
    }

    #[test]
    fn test_partial_sig_verify_tampered() {
        let msg = b"tampered partial sig";
        let session = signer_one_session(msg);

        // Shift the scalar by one so the verification equation no longer holds.
        let forged = PartialSignature {
            s: session.psig1.s + Scalar::ONE,
        };
        let accepted = verify_partial_sig(
            &forged,
            &session.pub1,
            &session.pk1,
            &session.ctx,
            &session.agg_nonce,
            msg,
        )
        .unwrap();
        assert!(!accepted, "tampered partial sig must fail");
    }

    #[test]
    fn test_partial_sig_wrong_key_fails() {
        let msg = b"wrong key";
        let session = signer_one_session(msg);

        // Signer 1's signature checked against signer 2's key.
        let accepted = verify_partial_sig(
            &session.psig1,
            &session.pub1,
            &session.pk2,
            &session.ctx,
            &session.agg_nonce,
            msg,
        )
        .unwrap();
        assert!(!accepted, "partial sig verified with wrong key must fail");
    }

    // ─── Key Tree Tests ─────────────────────────────────────────

    #[test]
    fn test_flat_key_tree() {
        let (_, _, pk1, pk2) = make_keys();
        let tree = flat_key_tree(&[pk1, pk2]);
        assert_eq!(tree.leaf_count(), 2);
        assert_eq!(tree.depth(), 1);

        let via_tree = tree.effective_pubkey().unwrap();
        let direct = signing::key_agg(&[pk1, pk2]).unwrap();
        let direct_enc = direct.aggregate_key.to_encoded_point(true);
        let mut expected = [0u8; 33];
        expected.copy_from_slice(direct_enc.as_bytes());
        assert_eq!(via_tree, expected, "flat tree should match direct agg");
    }

    #[test]
    fn test_nested_key_tree() {
        let (_, _, pk1, pk2) = make_keys();
        let pk3 = pubkey_for(&[0x33u8; 32]);

        // 2-level tree: agg(agg(pk1, pk2), pk3)
        let inner = KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)]);
        let tree = KeyTreeNode::Internal(vec![inner, KeyTreeNode::Leaf(pk3)]);

        assert_eq!(tree.leaf_count(), 3);
        assert_eq!(tree.depth(), 2);
        assert_ne!(tree.effective_pubkey().unwrap(), [0u8; 33]);

        // The top-level context aggregates exactly the two direct children.
        let top_ctx = tree.key_agg_context().unwrap();
        assert_eq!(top_ctx.pubkeys.len(), 2);
    }

    #[test]
    fn test_grouped_key_tree() {
        let (_, _, pk1, pk2) = make_keys();
        let pk3 = pubkey_for(&[0x33u8; 32]);
        let pk4 = pubkey_for(&[0x44u8; 32]);

        let tree = grouped_key_tree(&[vec![pk1, pk2], vec![pk3, pk4]]);
        assert_eq!(tree.leaf_count(), 4);
        assert_eq!(tree.depth(), 2);
        assert_ne!(tree.effective_pubkey().unwrap(), [0u8; 33]);
    }

    #[test]
    fn test_leaf_key_agg_context_error() {
        let (_, _, pk1, _) = make_keys();
        assert!(KeyTreeNode::Leaf(pk1).key_agg_context().is_err());
    }

    #[test]
    fn test_nested_tree_deterministic() {
        let (_, _, pk1, pk2) = make_keys();
        let pk3 = pubkey_for(&[0x33u8; 32]);

        // Build two independent but structurally identical trees.
        let build = || {
            KeyTreeNode::Internal(vec![
                KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)]),
                KeyTreeNode::Leaf(pk3),
            ])
        };
        let first = build();
        let second = build();

        assert_eq!(
            first.effective_pubkey().unwrap(),
            second.effective_pubkey().unwrap()
        );
    }
}