Skip to main content

chains_sdk/threshold/musig2/
nested.rs

1//! MuSig2 Nested Key Aggregation (Key Trees) and Partial Signature Verification.
2//!
3//! Supports hierarchical key aggregation where some "leaf" keys are themselves
4//! MuSig2 aggregate keys. Also provides partial signature verification to
5//! validate individual signer contributions before aggregation.
6//!
7//! # Nested Key Example
8//! ```text
9//!    TopLevel (Q = agg(Q_a, Q_b))
10//!         /         \
11//!     Q_a = agg(pk1, pk2)    Q_b = pk3
12//! ```
13//! The group Q_a signs internally using 2-of-2 MuSig2, then their combined
14//! partial signature contributes as one signer to the top-level session.
15
16use crate::crypto;
17use crate::error::SignerError;
18use k256::elliptic_curve::group::GroupEncoding;
19use k256::elliptic_curve::ops::Reduce;
20use k256::elliptic_curve::sec1::ToEncodedPoint;
21use k256::{AffinePoint, ProjectivePoint, Scalar};
22
23use super::signing::{
24    self, compute_nonce_coeff, AggNonce, KeyAggContext, PartialSignature, PubNonce,
25};
26
27// ═══════════════════════════════════════════════════════════════════
28// Partial Signature Verification
29// ═══════════════════════════════════════════════════════════════════
30
31/// Verify a partial signature from a specific signer.
32///
33/// Checks that the partial signature `s_i` satisfies:
34/// `s_i * G == R_i + e * a_i * (g * P_i)`
35///
36/// where:
37/// - `R_i = R_{1,i} + b * R_{2,i}` (effective signer nonce)
38/// - `e` = BIP-340 challenge
39/// - `a_i` = key aggregation coefficient
40/// - `g` = parity correction (1 or -1)
41/// - `P_i` = signer's public key
42///
43/// # Arguments
44/// - `partial_sig` — The partial signature to verify
45/// - `pub_nonce` — The signer's public nonce (from round 1)
46/// - `signer_pubkey` — The signer's compressed public key (33 bytes)
47/// - `key_agg_ctx` — The key aggregation context
48/// - `agg_nonce` — The aggregated nonce
49/// - `msg` — The signed message
50pub fn verify_partial_sig(
51    partial_sig: &PartialSignature,
52    pub_nonce: &PubNonce,
53    signer_pubkey: &[u8; 33],
54    key_agg_ctx: &KeyAggContext,
55    agg_nonce: &AggNonce,
56    msg: &[u8],
57) -> Result<bool, SignerError> {
58    // Compute nonce coefficient b
59    let b = compute_nonce_coeff(agg_nonce, &key_agg_ctx.x_only_pubkey, msg);
60
61    // Effective R for the group: R = R1_agg + b * R2_agg
62    let r_group = ProjectivePoint::from(agg_nonce.r1) + ProjectivePoint::from(agg_nonce.r2) * b;
63    let r_affine = r_group.to_affine();
64    let r_encoded = r_affine.to_encoded_point(true);
65    let r_bytes = r_encoded.as_bytes();
66    let nonce_negated = r_bytes[0] == 0x03;
67    let mut r_x = [0u8; 32];
68    r_x.copy_from_slice(&r_bytes[1..33]);
69
70    // BIP-340 challenge
71    let mut challenge_data = Vec::new();
72    challenge_data.extend_from_slice(&r_x);
73    challenge_data.extend_from_slice(&key_agg_ctx.x_only_pubkey);
74    challenge_data.extend_from_slice(msg);
75    let e = tagged_hash_scalar(b"BIP0340/challenge", &challenge_data);
76
77    // Find the signer's aggregation coefficient
78    let signer_idx = key_agg_ctx
79        .pubkeys
80        .iter()
81        .position(|pk| pk == signer_pubkey);
82    let signer_idx = match signer_idx {
83        Some(idx) => idx,
84        None => {
85            return Err(SignerError::SigningFailed(
86                "signer not in key_agg context".into(),
87            ))
88        }
89    };
90    let a_i = key_agg_ctx.coefficients[signer_idx];
91
92    // Parse the signer's public key
93    let pk_ct = AffinePoint::from_bytes(signer_pubkey.into());
94    if !bool::from(pk_ct.is_some()) {
95        return Ok(false);
96    }
97    #[allow(clippy::unwrap_used)]
98    let pk_point = ProjectivePoint::from(pk_ct.unwrap());
99
100    // Effective signer nonce: R_i = R_{1,i} + b * R_{2,i}
101    let ri = ProjectivePoint::from(pub_nonce.r1) + ProjectivePoint::from(pub_nonce.r2) * b;
102    let effective_ri = if nonce_negated { -ri } else { ri };
103
104    // Effective public key (negate if aggregate key has odd y)
105    let effective_pk = if key_agg_ctx.parity {
106        -pk_point
107    } else {
108        pk_point
109    };
110
111    // LHS: s_i * G
112    let lhs = ProjectivePoint::GENERATOR * partial_sig.s;
113
114    // RHS: R_i + e * a_i * P_i
115    let rhs = effective_ri + effective_pk * (e * a_i);
116
117    Ok(lhs == rhs)
118}
119
120// ═══════════════════════════════════════════════════════════════════
121// Nested Key Aggregation (Key Trees)
122// ═══════════════════════════════════════════════════════════════════
123
/// A node in a MuSig2 key tree.
///
/// Equality and hashing are structural: two trees are equal when they have
/// the same shape and the same leaf keys in the same order.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum KeyTreeNode {
    /// A leaf node: a single public key (33 bytes compressed).
    Leaf([u8; 33]),
    /// An internal node: a MuSig2 aggregation of child nodes, in order.
    Internal(Vec<KeyTreeNode>),
}
132
133impl KeyTreeNode {
134    /// Compute the effective public key for this node.
135    ///
136    /// For a leaf, returns the key directly.
137    /// For an internal node, recursively aggregates child keys using MuSig2.
138    pub fn effective_pubkey(&self) -> Result<[u8; 33], SignerError> {
139        match self {
140            KeyTreeNode::Leaf(pk) => Ok(*pk),
141            KeyTreeNode::Internal(children) => {
142                let child_keys: Result<Vec<[u8; 33]>, _> = children
143                    .iter()
144                    .map(|child| child.effective_pubkey())
145                    .collect();
146                let child_keys = child_keys?;
147                let ctx = signing::key_agg(&child_keys)?;
148                // Return the aggregate key as compressed point
149                let agg_enc = ctx.aggregate_key.to_encoded_point(true);
150                let mut out = [0u8; 33];
151                out.copy_from_slice(agg_enc.as_bytes());
152                Ok(out)
153            }
154        }
155    }
156
157    /// Get the key aggregation context for this node (only valid for Internal nodes).
158    pub fn key_agg_context(&self) -> Result<KeyAggContext, SignerError> {
159        match self {
160            KeyTreeNode::Leaf(_) => Err(SignerError::ParseError(
161                "leaf nodes don't have a key_agg context".into(),
162            )),
163            KeyTreeNode::Internal(children) => {
164                let child_keys: Result<Vec<[u8; 33]>, _> = children
165                    .iter()
166                    .map(|child| child.effective_pubkey())
167                    .collect();
168                signing::key_agg(&child_keys?)
169            }
170        }
171    }
172
173    /// Count the total number of leaf keys in the tree.
174    #[must_use]
175    pub fn leaf_count(&self) -> usize {
176        match self {
177            KeyTreeNode::Leaf(_) => 1,
178            KeyTreeNode::Internal(children) => children.iter().map(|c| c.leaf_count()).sum(),
179        }
180    }
181
182    /// Get the depth of the tree.
183    #[must_use]
184    pub fn depth(&self) -> usize {
185        match self {
186            KeyTreeNode::Leaf(_) => 0,
187            KeyTreeNode::Internal(children) => {
188                1 + children.iter().map(|c| c.depth()).max().unwrap_or(0)
189            }
190        }
191    }
192}
193
194/// Build a flat MuSig2 key tree from a list of public keys.
195///
196/// All keys are leaves under a single internal aggregation node.
197pub fn flat_key_tree(pubkeys: &[[u8; 33]]) -> KeyTreeNode {
198    KeyTreeNode::Internal(pubkeys.iter().map(|pk| KeyTreeNode::Leaf(*pk)).collect())
199}
200
201/// Build a 2-level key tree where each group is a sub-aggregation.
202///
203/// # Example
204/// ```text
205/// groups = [[pk1, pk2], [pk3, pk4]]
206/// Result:
207///   TopLevel
208///     ├── agg(pk1, pk2)
209///     └── agg(pk3, pk4)
210/// ```
211pub fn grouped_key_tree(groups: &[Vec<[u8; 33]>]) -> KeyTreeNode {
212    KeyTreeNode::Internal(
213        groups
214            .iter()
215            .map(|group| {
216                if group.len() == 1 {
217                    KeyTreeNode::Leaf(group[0])
218                } else {
219                    KeyTreeNode::Internal(group.iter().map(|pk| KeyTreeNode::Leaf(*pk)).collect())
220                }
221            })
222            .collect(),
223    )
224}
225
226// ─── Helpers ────────────────────────────────────────────────────────
227
228/// Tagged hash to scalar (duplicated for module isolation).
229fn tagged_hash_scalar(tag: &[u8], data: &[u8]) -> Scalar {
230    let hash = crypto::tagged_hash(tag, data);
231    let wide = k256::U256::from_be_slice(&hash);
232    <Scalar as Reduce<k256::U256>>::reduce(wide)
233}
234
235// ═══════════════════════════════════════════════════════════════════
236// Tests
237// ═══════════════════════════════════════════════════════════════════
238
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::super::signing;
    use super::*;

    /// Two fixed secret keys and their compressed public keys.
    fn make_keys() -> ([u8; 32], [u8; 32], [u8; 33], [u8; 33]) {
        let sk1 = [0x11u8; 32];
        let sk2 = [0x22u8; 32];
        let pk1 = signing::individual_pubkey(&sk1).unwrap();
        let pk2 = signing::individual_pubkey(&sk2).unwrap();
        (sk1, sk2, pk1, pk2)
    }

    // ─── Partial Signature Verification ─────────────────────────

    #[test]
    fn test_partial_sig_verify_valid() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"partial sig verify";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();

        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        let valid = verify_partial_sig(&psig1, &pub1, &pk1, &ctx, &agg_nonce, msg).unwrap();
        assert!(valid, "valid partial sig must verify");
    }

    #[test]
    fn test_partial_sig_verify_both_signers() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"both signers verify";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2.clone()]).unwrap();

        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();
        let psig2 = signing::sign(sec2, &sk2, &ctx, &agg_nonce, msg).unwrap();

        assert!(verify_partial_sig(&psig1, &pub1, &pk1, &ctx, &agg_nonce, msg).unwrap());
        assert!(verify_partial_sig(&psig2, &pub2, &pk2, &ctx, &agg_nonce, msg).unwrap());

        // Pairing a partial sig with the other signer's public nonce must fail.
        assert!(
            !verify_partial_sig(&psig1, &pub2, &pk1, &ctx, &agg_nonce, msg).unwrap(),
            "partial sig verified against the wrong nonce must fail"
        );
    }

    #[test]
    fn test_partial_sig_verify_tampered() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"tampered partial sig";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();

        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        // Tamper
        let tampered = PartialSignature {
            s: psig1.s + Scalar::ONE,
        };
        let valid = verify_partial_sig(&tampered, &pub1, &pk1, &ctx, &agg_nonce, msg).unwrap();
        assert!(!valid, "tampered partial sig must fail");
    }

    #[test]
    fn test_partial_sig_wrong_key_fails() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"wrong key";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();
        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        // Verify with wrong key (pk2 instead of pk1)
        let valid = verify_partial_sig(&psig1, &pub1, &pk2, &ctx, &agg_nonce, msg).unwrap();
        assert!(!valid, "partial sig verified with wrong key must fail");
    }

    #[test]
    fn test_partial_sig_unknown_signer_errors() {
        let (sk1, sk2, pk1, pk2) = make_keys();
        let pk3 = signing::individual_pubkey(&[0x33u8; 32]).unwrap();
        let ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let msg = b"unknown signer";

        let (sec1, pub1) = signing::nonce_gen(&sk1, &pk1, &ctx, msg, &[]).unwrap();
        let (_sec2, pub2) = signing::nonce_gen(&sk2, &pk2, &ctx, msg, &[]).unwrap();
        let agg_nonce = signing::nonce_agg(&[pub1.clone(), pub2]).unwrap();
        let psig1 = signing::sign(sec1, &sk1, &ctx, &agg_nonce, msg).unwrap();

        // pk3 is not part of the aggregation: this is a caller error, not Ok(false).
        let result = verify_partial_sig(&psig1, &pub1, &pk3, &ctx, &agg_nonce, msg);
        assert!(result.is_err(), "signer outside the key_agg context must be an error");
    }

    // ─── Key Tree Tests ─────────────────────────────────────────

    #[test]
    fn test_flat_key_tree() {
        let (_, _, pk1, pk2) = make_keys();
        let tree = flat_key_tree(&[pk1, pk2]);
        assert_eq!(tree.leaf_count(), 2);
        assert_eq!(tree.depth(), 1);

        let effective = tree.effective_pubkey().unwrap();
        let direct_ctx = signing::key_agg(&[pk1, pk2]).unwrap();
        let direct_enc = direct_ctx.aggregate_key.to_encoded_point(true);
        let mut direct_bytes = [0u8; 33];
        direct_bytes.copy_from_slice(direct_enc.as_bytes());
        assert_eq!(effective, direct_bytes, "flat tree should match direct agg");
    }

    #[test]
    fn test_nested_key_tree() {
        let sk3 = [0x33u8; 32];
        let pk3 = signing::individual_pubkey(&sk3).unwrap();
        let (_, _, pk1, pk2) = make_keys();

        // 2-level tree: agg(agg(pk1, pk2), pk3)
        let tree = KeyTreeNode::Internal(vec![
            KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)]),
            KeyTreeNode::Leaf(pk3),
        ]);

        assert_eq!(tree.leaf_count(), 3);
        assert_eq!(tree.depth(), 2);

        let effective = tree.effective_pubkey().unwrap();
        assert_ne!(effective, [0u8; 33]);

        // The top-level context should exist
        let ctx = tree.key_agg_context().unwrap();
        assert_eq!(ctx.pubkeys.len(), 2); // 2 children at top level

        // Its keys are the inner group's aggregate key and pk3, not the raw leaves.
        let inner = KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)])
            .effective_pubkey()
            .unwrap();
        assert!(ctx.pubkeys.iter().any(|pk| *pk == inner));
        assert!(ctx.pubkeys.iter().any(|pk| *pk == pk3));
    }

    #[test]
    fn test_grouped_key_tree() {
        let (_, _, pk1, pk2) = make_keys();
        let sk3 = [0x33u8; 32];
        let sk4 = [0x44u8; 32];
        let pk3 = signing::individual_pubkey(&sk3).unwrap();
        let pk4 = signing::individual_pubkey(&sk4).unwrap();

        let tree = grouped_key_tree(&[vec![pk1, pk2], vec![pk3, pk4]]);
        assert_eq!(tree.leaf_count(), 4);
        assert_eq!(tree.depth(), 2);

        let effective = tree.effective_pubkey().unwrap();
        assert_ne!(effective, [0u8; 33]);
    }

    #[test]
    fn test_leaf_key_agg_context_error() {
        let (_, _, pk1, _) = make_keys();
        let leaf = KeyTreeNode::Leaf(pk1);
        assert!(leaf.key_agg_context().is_err());
    }

    #[test]
    fn test_nested_tree_deterministic() {
        let (_, _, pk1, pk2) = make_keys();
        let sk3 = [0x33u8; 32];
        let pk3 = signing::individual_pubkey(&sk3).unwrap();

        let tree1 = KeyTreeNode::Internal(vec![
            KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)]),
            KeyTreeNode::Leaf(pk3),
        ]);
        let tree2 = KeyTreeNode::Internal(vec![
            KeyTreeNode::Internal(vec![KeyTreeNode::Leaf(pk1), KeyTreeNode::Leaf(pk2)]),
            KeyTreeNode::Leaf(pk3),
        ]);

        assert_eq!(
            tree1.effective_pubkey().unwrap(),
            tree2.effective_pubkey().unwrap()
        );
    }
}