Skip to main content

chains_sdk/threshold/musig2/
adaptor.rs

1//! MuSig2 Adaptor Signatures and Key Aggregation Coefficient Caching.
2//!
3//! Extends MuSig2 with:
4//! - **Adaptor Signatures**: Encrypt partial signatures under an adaptor point.
5//!   The final signature reveals the adaptor secret (useful for atomic swaps).
6//! - **Key Aggregation Caching**: Cache expensive key aggregation computations
7//!   for repeated signing sessions with the same key set.
8
9use crate::error::SignerError;
10use core::fmt;
11use k256::elliptic_curve::group::GroupEncoding;
12use k256::elliptic_curve::sec1::ToEncodedPoint;
13use k256::{AffinePoint, ProjectivePoint, Scalar};
14
15use super::signing::{KeyAggContext, PartialSignature};
16
17// ═══════════════════════════════════════════════════════════════════
18// Adaptor Signatures
19// ═══════════════════════════════════════════════════════════════════
20
21/// An adaptor (pre-)signature that can be completed with the adaptor secret.
22#[derive(Clone)]
23pub struct AdaptorSignature {
24    /// The adaptor point `T` (compressed, 33 bytes).
25    pub adaptor_point: [u8; 33],
26    /// The adapted nonce point `R' = R + T` x-coordinate (32 bytes).
27    pub adapted_r: [u8; 32],
28    /// The partial adaptor signature scalar `s'`.
29    pub s_adaptor: Scalar,
30}
31
32impl fmt::Debug for AdaptorSignature {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        f.debug_struct("AdaptorSignature")
35            .field("adaptor_point", &hex::encode(self.adaptor_point))
36            .field("adapted_r", &hex::encode(self.adapted_r))
37            .field("s_adaptor", &"[REDACTED]")
38            .finish()
39    }
40}
41
42impl AdaptorSignature {
43    /// Complete the adaptor signature by revealing the adaptor secret.
44    ///
45    /// Given the adaptor secret scalar `t`, computes the final signature `s = s' + t`.
46    ///
47    /// # Returns
48    /// A 64-byte Schnorr signature `(R', s)`.
49    #[must_use]
50    pub fn complete(&self, adaptor_secret: &Scalar) -> [u8; 64] {
51        let s = self.s_adaptor + adaptor_secret;
52        let mut sig = [0u8; 64];
53        sig[..32].copy_from_slice(&self.adapted_r);
54        sig[32..].copy_from_slice(&s.to_bytes());
55        sig
56    }
57
58    /// Extract the adaptor secret from a completed signature.
59    ///
60    /// Given the completed signature `s` and the adaptor signature `s'`,
61    /// computes `t = s - s'` to learn the adaptor secret.
62    #[must_use]
63    pub fn extract_secret(&self, completed_s: &Scalar) -> Scalar {
64        *completed_s - self.s_adaptor
65    }
66}
67
68/// Create an adaptor partial signature.
69///
70/// The adaptor point `T = t * G` is publicly known. The signer creates
71/// a signature encrypted under `T`, such that completing it reveals `t`.
72///
73/// # Arguments
74/// - `partial_sig` — The standard MuSig2 partial signature scalar
75/// - `adaptor_point` — The public adaptor point (33 bytes compressed)
76/// - `agg_nonce_r_x` — The x-coordinate of the aggregated nonce R (32 bytes)
77pub fn create_adaptor_signature(
78    partial_sig: &PartialSignature,
79    adaptor_point: &[u8; 33],
80    agg_nonce_r_x: &[u8; 32],
81) -> Result<AdaptorSignature, SignerError> {
82    // Parse adaptor point
83    let t_affine = {
84        let ct = AffinePoint::from_bytes(adaptor_point.into());
85        if !bool::from(ct.is_some()) {
86            return Err(SignerError::InvalidPublicKey(
87                "invalid adaptor point".into(),
88            ));
89        }
90        #[allow(clippy::unwrap_used)]
91        ct.unwrap()
92    };
93
94    // Parse R point
95    let r_affine = {
96        let mut r_bytes = [0u8; 33];
97        r_bytes[0] = 0x02;
98        r_bytes[1..].copy_from_slice(agg_nonce_r_x);
99        let ct = AffinePoint::from_bytes((&r_bytes).into());
100        if !bool::from(ct.is_some()) {
101            return Err(SignerError::InvalidPublicKey(
102                "invalid nonce R point".into(),
103            ));
104        }
105        #[allow(clippy::unwrap_used)]
106        ct.unwrap()
107    };
108
109    // R' = R + T (adapted nonce)
110    let adapted = ProjectivePoint::from(r_affine) + ProjectivePoint::from(t_affine);
111    let adapted_affine = adapted.to_affine();
112    let adapted_encoded = adapted_affine.to_encoded_point(false);
113    let mut adapted_r = [0u8; 32];
114    if let Some(x) = adapted_encoded.x() {
115        adapted_r.copy_from_slice(&x[..]);
116    }
117
118    Ok(AdaptorSignature {
119        adaptor_point: *adaptor_point,
120        adapted_r,
121        s_adaptor: partial_sig.s,
122    })
123}
124
125// ═══════════════════════════════════════════════════════════════════
126// Key Aggregation Caching
127// ═══════════════════════════════════════════════════════════════════
128
129/// Cached key aggregation state for repeated signing with the same key set.
130///
131/// Pre-computes the expensive key aggregation coefficients and aggregate key
132/// so they can be reused across multiple signing sessions.
133#[derive(Clone)]
134pub struct CachedKeyAgg {
135    /// The underlying key aggregation context.
136    pub context: KeyAggContext,
137    /// Cache of the individual public keys.
138    pub pubkeys: Vec<[u8; 33]>,
139}
140
141impl fmt::Debug for CachedKeyAgg {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        f.debug_struct("CachedKeyAgg")
144            .field("num_keys", &self.pubkeys.len())
145            .field("aggregate_key", &hex::encode(self.context.x_only_pubkey))
146            .finish()
147    }
148}
149
150impl CachedKeyAgg {
151    /// Create a cached key aggregation from public keys.
152    ///
153    /// This performs the key aggregation once and caches the result
154    /// for reuse in multiple signing sessions.
155    pub fn new(pubkeys: &[[u8; 33]]) -> Result<Self, SignerError> {
156        let context = super::signing::key_agg(pubkeys)?;
157        Ok(Self {
158            context,
159            pubkeys: pubkeys.to_vec(),
160        })
161    }
162
163    /// Get the aggregate public key (x-only, 32 bytes).
164    #[must_use]
165    pub fn aggregate_pubkey(&self) -> [u8; 32] {
166        self.context.x_only_pubkey
167    }
168
169    /// Get the number of signers.
170    #[must_use]
171    pub fn num_signers(&self) -> usize {
172        self.pubkeys.len()
173    }
174
175    /// Check if a specific public key is part of this key aggregation.
176    #[must_use]
177    pub fn contains_key(&self, pubkey: &[u8; 33]) -> bool {
178        self.pubkeys.contains(pubkey)
179    }
180}
181
182// ═══════════════════════════════════════════════════════════════════
183// Tests
184// ═══════════════════════════════════════════════════════════════════
185
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::super::signing;
    use super::*;
    use k256::elliptic_curve::ops::Reduce;

    /// Public keys for the two fixed test secrets `[0x01; 32]` and `[0x02; 32]`.
    fn two_signer_pubkeys() -> [[u8; 33]; 2] {
        let first = signing::individual_pubkey(&[0x01u8; 32]).unwrap();
        let second = signing::individual_pubkey(&[0x02u8; 32]).unwrap();
        [first, second]
    }

    /// Compressed encoding of `k * G`.
    fn compressed_multiple_of_g(k: u64) -> [u8; 33] {
        let point = (ProjectivePoint::GENERATOR * Scalar::from(k)).to_affine();
        point.to_bytes().into()
    }

    /// x-coordinate of `k * G` (all zeros if the point has no x-coordinate).
    fn x_of_multiple_of_g(k: u64) -> [u8; 32] {
        let encoded = (ProjectivePoint::GENERATOR * Scalar::from(k))
            .to_affine()
            .to_encoded_point(false);
        let mut x_bytes = [0u8; 32];
        if let Some(x) = encoded.x() {
            x_bytes.copy_from_slice(x.as_slice());
        }
        x_bytes
    }

    #[test]
    fn test_cached_key_agg() {
        let [pk1, pk2] = two_signer_pubkeys();

        let cached = CachedKeyAgg::new(&[pk1, pk2]).unwrap();
        assert_eq!(cached.num_signers(), 2);
        for member in [&pk1, &pk2] {
            assert!(cached.contains_key(member));
        }
        assert!(!cached.contains_key(&[0xFF; 33]));

        assert_ne!(cached.aggregate_pubkey(), [0u8; 32]);
    }

    #[test]
    fn test_cached_key_agg_deterministic() {
        let keys = two_signer_pubkeys();

        let first = CachedKeyAgg::new(&keys).unwrap();
        let second = CachedKeyAgg::new(&keys).unwrap();
        assert_eq!(first.aggregate_pubkey(), second.aggregate_pubkey());
    }

    #[test]
    fn test_adaptor_complete_extract_roundtrip() {
        // T = 42*G, R = 7*G, dummy partial signature scalar 100.
        let adaptor_secret = Scalar::from(42u64);
        let adaptor_bytes = compressed_multiple_of_g(42);
        let partial = PartialSignature {
            s: Scalar::from(100u64),
        };
        let r_x = x_of_multiple_of_g(7);

        let adaptor_sig = create_adaptor_signature(&partial, &adaptor_bytes, &r_x).unwrap();
        let completed = adaptor_sig.complete(&adaptor_secret);

        // Re-parse the scalar half of the completed signature.
        let mut s_bytes = [0u8; 32];
        s_bytes.copy_from_slice(&completed[32..]);
        let completed_s = <Scalar as Reduce<k256::U256>>::reduce_bytes(&s_bytes.into());

        assert_eq!(adaptor_sig.extract_secret(&completed_s), adaptor_secret);
    }

    #[test]
    fn test_adaptor_signature_structure() {
        // T = 99*G, R = 3*G, dummy partial signature scalar 50.
        let adaptor_bytes = compressed_multiple_of_g(99);
        let partial = PartialSignature {
            s: Scalar::from(50u64),
        };
        let r_x = x_of_multiple_of_g(3);

        let adaptor = create_adaptor_signature(&partial, &adaptor_bytes, &r_x).unwrap();
        assert_eq!(adaptor.adaptor_point, adaptor_bytes);
        assert_ne!(adaptor.adapted_r, [0u8; 32]);
    }
}