Skip to main content

iris_crypto/
cheetah.rs

1#[cfg(feature = "alloc")]
2use alloc::{
3    boxed::Box,
4    format,
5    string::{String, ToString},
6    vec::Vec,
7};
8use arrayvec::ArrayVec;
9use iris_ztd::{
10    crypto::cheetah::{
11        ch_add, ch_neg, ch_scal_big, trunc_g_order, CheetahPoint, F6lt, A_GEN, G_ORDER,
12    },
13    tip5::hash::hash_varlen,
14    Belt, Digest, Hashable, MulMod, U256,
15};
16#[cfg(feature = "alloc")]
17use iris_ztd::{Noun, NounDecode, NounEncode};
18use serde::{Deserialize, Serialize};
19
/// A public key on the Cheetah curve: a newtype around a [`CheetahPoint`].
///
/// The 97-byte big-endian encoding is produced by [`PublicKey::to_be_bytes`]
/// and read back by [`PublicKey::from_be_bytes`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[cfg_attr(feature = "alloc", derive(NounEncode, NounDecode))]
#[iris_ztd::wasm_noun_codec]
pub struct PublicKey(pub CheetahPoint);
24
25impl core::fmt::Display for PublicKey {
26    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27        write!(f, "{}", self.0)
28    }
29}
30
31impl TryFrom<&str> for PublicKey {
32    type Error = iris_ztd::crypto::cheetah::CheetahError;
33
34    fn try_from(value: &str) -> Result<Self, Self::Error> {
35        value.try_into().map(Self)
36    }
37}
38
39#[iris_ztd::wasm_member_methods]
40impl PublicKey {
41    pub fn verify(&self, m: &Digest, sig: &Signature) -> bool {
42        if sig.c == U256::ZERO || sig.c >= G_ORDER || sig.s == U256::ZERO || sig.s >= G_ORDER {
43            return false;
44        }
45
46        // Compute scalar = s*G - c*pubkey
47        // This is equivalent to: scalar = s*G + (-c)*pubkey
48        let sg = match ch_scal_big(&sig.s, &A_GEN) {
49            Ok(pt) => pt,
50            Err(_) => return false,
51        };
52        let c_pk = match ch_scal_big(&sig.c, &self.0) {
53            Ok(pt) => pt,
54            Err(_) => return false,
55        };
56        let scalar = match ch_add(&sg, &ch_neg(&c_pk)) {
57            Ok(pt) => pt,
58            Err(_) => return false,
59        };
60        let chal = {
61            let mut transcript: ArrayVec<Belt, { 6 + 6 + 6 + 6 + 5 }> = ArrayVec::new();
62            transcript.try_extend_from_slice(&scalar.x.0).unwrap();
63            transcript.try_extend_from_slice(&scalar.y.0).unwrap();
64            transcript.try_extend_from_slice(&self.0.x.0).unwrap();
65            transcript.try_extend_from_slice(&self.0.y.0).unwrap();
66            transcript.try_extend_from_slice(&m.0).unwrap();
67            trunc_g_order(&hash_varlen(&transcript))
68        };
69
70        chal == sig.c
71    }
72
73    pub fn from_be_bytes(bytes: &[u8]) -> PublicKey {
74        let mut x = [Belt(0); 6];
75        let mut y = [Belt(0); 6];
76
77        // y-coordinate: bytes 1-48
78        for i in 0..6 {
79            let offset = 1 + i * 8;
80            let mut buf = [0u8; 8];
81            buf.copy_from_slice(&bytes[offset..offset + 8]);
82            y[5 - i] = Belt(u64::from_be_bytes(buf));
83        }
84
85        // x-coordinate: bytes 49-96
86        for i in 0..6 {
87            let offset = 49 + i * 8;
88            let mut buf = [0u8; 8];
89            buf.copy_from_slice(&bytes[offset..offset + 8]);
90            x[5 - i] = Belt(u64::from_be_bytes(buf));
91        }
92
93        PublicKey(CheetahPoint {
94            x: F6lt(x),
95            y: F6lt(y),
96            inf: false,
97        })
98    }
99
100    #[cfg(feature = "alloc")]
101    pub fn to_be_bytes_vec(&self) -> Vec<u8> {
102        self.to_be_bytes().to_vec()
103    }
104
105    #[cfg(feature = "alloc")]
106    pub fn from_hex(hex: &str) -> Option<PublicKey> {
107        let bytes = hex::decode(hex).ok()?;
108        if bytes.len() != 97 {
109            return None;
110        }
111        Some(Self::from_be_bytes(&bytes))
112    }
113
114    #[cfg(feature = "alloc")]
115    pub fn to_hex(&self) -> String {
116        hex::encode(self.to_be_bytes())
117    }
118}
119
120impl PublicKey {
121    pub fn to_be_bytes(&self) -> [u8; 97] {
122        let mut data = [0u8; 97];
123        data[0] = 0x01; // prefix byte
124        let mut offset = 1;
125        // y-coordinate: 6 belts × 8 bytes = 48 bytes
126        for belt in self.0.y.0.iter().rev() {
127            data[offset..offset + 8].copy_from_slice(&belt.0.to_be_bytes());
128            offset += 8;
129        }
130        // x-coordinate: 6 belts × 8 bytes = 48 bytes
131        for belt in self.0.x.0.iter().rev() {
132            data[offset..offset + 8].copy_from_slice(&belt.0.to_be_bytes());
133            offset += 8;
134        }
135        data
136    }
137
138    /// SLIP-10 compatible serialization (legacy 65-byte format for compatibility)
139    pub(crate) fn as_slip10_bytes(&self) -> [u8; 96] {
140        let mut data = [0u8; 96];
141        let mut offset = 0;
142        for belt in self.0.y.0.iter().rev().chain(self.0.x.0.iter().rev()) {
143            data[offset..offset + 8].copy_from_slice(&belt.0.to_be_bytes());
144            offset += 8;
145        }
146        data
147    }
148}
149
150impl core::ops::Add for &PublicKey {
151    type Output = PublicKey;
152
153    fn add(self, other: &PublicKey) -> PublicKey {
154        PublicKey(ch_add(&self.0, &other.0).unwrap())
155    }
156}
157
158impl core::ops::Add for PublicKey {
159    type Output = PublicKey;
160
161    fn add(self, other: PublicKey) -> PublicKey {
162        (&self as &PublicKey) + (&other as &PublicKey)
163    }
164}
165
166impl core::ops::AddAssign for PublicKey {
167    fn add_assign(&mut self, other: PublicKey) {
168        *self = *self + other;
169    }
170}
171
172impl core::ops::Sub for &PublicKey {
173    type Output = PublicKey;
174
175    fn sub(self, other: &PublicKey) -> PublicKey {
176        PublicKey(ch_add(&self.0, &ch_neg(&other.0)).unwrap())
177    }
178}
179
180impl core::ops::SubAssign for PublicKey {
181    fn sub_assign(&mut self, other: PublicKey) {
182        *self = &*self - &other;
183    }
184}
185
186impl core::iter::Sum<PublicKey> for PublicKey {
187    fn sum<I: Iterator<Item = PublicKey>>(iter: I) -> Self {
188        iter.fold(PublicKey(CheetahPoint::identity()), |acc, x| acc + x)
189    }
190}
191
192impl<'a> core::iter::Sum<&'a PublicKey> for PublicKey {
193    fn sum<I: Iterator<Item = &'a PublicKey>>(iter: I) -> Self {
194        iter.fold(PublicKey(CheetahPoint::identity()), |acc, x| &acc + x)
195    }
196}
197
198impl Hashable for PublicKey {
199    fn hash(&self) -> Digest {
200        self.0.hash()
201    }
202
203    fn leaf_count(&self) -> usize {
204        self.0.leaf_count()
205    }
206
207    fn hashable_pair<'a>(&'a self) -> Option<(impl Hashable + 'a, impl Hashable + 'a)> {
208        self.0.hashable_pair()
209    }
210}
211
/// A Schnorr signature over the Cheetah curve: the challenge `c` and the response
/// scalar `s`.
///
/// [`PublicKey::verify`] accepts a signature only if both `c` and `s` are non-zero
/// and strictly below `G_ORDER`. Partial signatures from `PrivateKey::sign_multi`
/// that share the same `c` can be combined via `Sum<Signature> for Option<Signature>`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[iris_ztd::wasm_noun_codec]
pub struct Signature {
    /// Challenge part in little-endian hex
    /// (exposed to TypeScript as a `string` under the `wasm` feature).
    #[cfg_attr(feature = "wasm", tsify(type = "string"))]
    pub c: U256,
    /// Signature scalar in little-endian hex
    /// (exposed to TypeScript as a `string` under the `wasm` feature).
    #[cfg_attr(feature = "wasm", tsify(type = "string"))]
    pub s: U256,
}
222
223// Aggregate signature of the same challenge
224impl core::iter::Sum<Signature> for Option<Signature> {
225    fn sum<I: Iterator<Item = Signature>>(mut iter: I) -> Self {
226        let mut c = None;
227        let s = iter.try_fold(U256::ZERO, |acc, x| {
228            if c.is_some() && c.as_ref() != Some(&x.c) {
229                return None;
230            }
231            c = Some(x.c);
232            Some(acc.add_mod(&x.s, &G_ORDER))
233        });
234        Some(Signature { c: c?, s: s? })
235    }
236}
237
#[cfg(feature = "alloc")]
impl NounEncode for Signature {
    /// Encodes the signature as a pair of belt lists `(c, s)`.
    ///
    /// Each scalar's little-endian bytes are converted with `Belt::from_bytes`.
    fn to_noun(&self) -> Noun {
        let c_belts = Belt::from_bytes(&self.c.to_le_bytes());
        let s_belts = Belt::from_bytes(&self.s.to_le_bytes());
        let pair = (c_belts.as_slice(), s_belts.as_slice());
        pair.to_noun()
    }
}
248
#[cfg(feature = "alloc")]
impl NounDecode for Signature {
    /// Decodes a pair of 8-belt arrays `(c, s)`.
    ///
    /// Each array is turned back into bytes with `Belt::to_bytes` and read as a
    /// little-endian `U256`. Returns `None` if the noun does not have that shape.
    fn from_noun(noun: &Noun) -> Option<Self> {
        let (c_belts, s_belts) = <([Belt; 8], [Belt; 8]) as NounDecode>::from_noun(noun)?;
        let c = U256::from_le_slice(&Belt::to_bytes(&c_belts));
        let s = U256::from_le_slice(&Belt::to_bytes(&s_belts));
        Some(Signature { c, s })
    }
}
263
// TODO: support signature hashing without `alloc` by implementing an allocation-free `Belt::from_bytes`.
#[cfg(feature = "alloc")]
impl Hashable for Signature {
    /// Hashes the noun produced by this type's `NounEncode` impl.
    fn hash(&self) -> Digest {
        let noun = self.to_noun();
        noun.hash()
    }

    /// A signature is reported as a single leaf.
    fn leaf_count(&self) -> usize {
        1
    }

    /// A signature never splits into a hashable pair.
    fn hashable_pair<'a>(&'a self) -> Option<(impl Hashable + 'a, impl Hashable + 'a)> {
        None::<((), ())>
    }
}
279
280#[derive(Debug, Clone)]
281pub struct PrivateKey(pub U256);
282
283impl Drop for PrivateKey {
284    fn drop(&mut self) {
285        unsafe {
286            core::ptr::write_volatile(&mut self.0, U256::ZERO);
287        }
288        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
289    }
290}
291
292impl PrivateKey {
293    pub fn public_key(&self) -> PublicKey {
294        PublicKey(ch_scal_big(&self.0, &A_GEN).unwrap())
295    }
296
297    pub fn sign(&self, m: &Digest) -> Signature {
298        self.sign_multi(m, &self.nonce_for(m), &self.public_key())
299    }
300
301    pub fn nonce_for(&self, m: &Digest) -> U256 {
302        let pubkey = self.public_key().0;
303        let nonce = {
304            let mut transcript: ArrayVec<Belt, { 6 + 6 + 5 + 8 }> = ArrayVec::new();
305            transcript.try_extend_from_slice(&pubkey.x.0).unwrap();
306            transcript.try_extend_from_slice(&pubkey.y.0).unwrap();
307            transcript.try_extend_from_slice(&m.0).unwrap();
308            self.0.to_le_bytes().chunks(4).for_each(|chunk| {
309                let mut buf = [0u8; 4];
310                buf[..chunk.len()].copy_from_slice(chunk);
311                transcript.push(Belt(u32::from_le_bytes(buf) as u64));
312            });
313            trunc_g_order(&hash_varlen(&transcript))
314        };
315        nonce
316    }
317
318    pub fn combine_nonces(nonces: &[U256]) -> U256 {
319        nonces
320            .iter()
321            .fold(U256::ZERO, |acc, x| acc.add_mod(x, &G_ORDER))
322    }
323
324    /// Perform a multiparty sign
325    ///
326    /// # Arguments
327    /// * `m` - The digest of message to sign
328    /// * `shared_nonce` - The challenge nonce. This is after taking `nonce_for(m)` on all private keys, and combining them with [`PrivateKey::combine_nonces`].
329    /// * `combined_pubkey` - The combined public key to sign against.
330    ///
331    /// # Returns
332    /// * `Signature` - The partial signature. This will be invalid until combined with other partial signatures.
333    ///
334    /// # Example
335    ///
336    /// ```
337    /// # use iris_ztd::{Digest, Belt, U256};
338    /// # use iris_crypto::cheetah::*;
339    /// let pk1 = PrivateKey(U256::from_u64(123));
340    /// let pk2 = PrivateKey(U256::from_u64(456));
341    /// let m = Digest([Belt(8), Belt(9), Belt(10), Belt(11), Belt(12)]);
342    /// let nonce1 = pk1.nonce_for(&m);
343    /// let nonce2 = pk2.nonce_for(&m);
344    /// let combined_nonce = PrivateKey::combine_nonces(&[nonce1, nonce2]);
345    /// let combined_pubkey = pk1.public_key() + pk2.public_key();
346    /// let sig1 = pk1.sign_multi(&m, &combined_nonce, &combined_pubkey);
347    /// let sig2 = pk2.sign_multi(&m, &combined_nonce, &combined_pubkey);
348    /// let sig = [sig1, sig2].into_iter().sum::<Option<Signature>>().unwrap();
349    /// assert!(combined_pubkey.verify(&m, &sig));
350    /// ```
351    pub fn sign_multi(
352        &self,
353        m: &Digest,
354        shared_nonce: &U256,
355        combined_pubkey: &PublicKey,
356    ) -> Signature {
357        let chal = {
358            // scalar = nonce * G
359            let scalar = ch_scal_big(shared_nonce, &A_GEN).unwrap();
360            let mut transcript: ArrayVec<Belt, { 6 + 6 + 6 + 6 + 5 }> = ArrayVec::new();
361            transcript.try_extend_from_slice(&scalar.x.0).unwrap();
362            transcript.try_extend_from_slice(&scalar.y.0).unwrap();
363            transcript
364                .try_extend_from_slice(&combined_pubkey.0.x.0)
365                .unwrap();
366            transcript
367                .try_extend_from_slice(&combined_pubkey.0.y.0)
368                .unwrap();
369            transcript.try_extend_from_slice(&m.0).unwrap();
370            trunc_g_order(&hash_varlen(&transcript))
371        };
372        let nonce = self.nonce_for(m);
373        let chal_mul = MulMod::mul_mod(&chal, &self.0, &G_ORDER);
374        let sig = nonce.add_mod(&chal_mul, &G_ORDER);
375        Signature { c: chal, s: sig }
376    }
377
378    pub fn to_be_bytes(&self) -> [u8; 32] {
379        self.0.to_be_bytes()
380    }
381}
382
383impl core::ops::Add for &PrivateKey {
384    type Output = PrivateKey;
385
386    fn add(self, other: &PrivateKey) -> PrivateKey {
387        PrivateKey(self.0.add_mod(&other.0, &G_ORDER))
388    }
389}
390
391impl core::ops::Add for PrivateKey {
392    type Output = PrivateKey;
393
394    fn add(self, other: PrivateKey) -> PrivateKey {
395        PrivateKey(self.0.add_mod(&other.0, &G_ORDER))
396    }
397}
398
399impl core::ops::AddAssign for PrivateKey {
400    fn add_assign(&mut self, other: PrivateKey) {
401        *self = &*self + &other;
402    }
403}
404
405impl core::ops::Sub for &PrivateKey {
406    type Output = PrivateKey;
407
408    fn sub(self, other: &PrivateKey) -> PrivateKey {
409        PrivateKey(self.0.sub_mod(&other.0, &G_ORDER))
410    }
411}
412
413impl core::ops::SubAssign for PrivateKey {
414    fn sub_assign(&mut self, other: PrivateKey) {
415        *self = &*self - &other;
416    }
417}
418
419impl core::iter::Sum<PrivateKey> for PrivateKey {
420    fn sum<I: Iterator<Item = PrivateKey>>(iter: I) -> Self {
421        iter.fold(PrivateKey(U256::ZERO), |acc, x| &acc + &x)
422    }
423}
424
425impl<'a> core::iter::Sum<&'a PrivateKey> for PrivateKey {
426    fn sum<I: Iterator<Item = &'a PrivateKey>>(iter: I) -> Self {
427        iter.fold(PrivateKey(U256::ZERO), |acc, x| &acc + x)
428    }
429}
430
#[cfg(test)]
mod tests {
    extern crate alloc;
    use super::*;
    use alloc::{vec, vec::Vec};

    #[test]
    fn mupk_test() {
        let privs = [
            U256::from_u64(123),
            U256::from_u64(124),
            G_ORDER.sub_mod(&U256::ONE, &G_ORDER),
        ]
        .map(PrivateKey);
        let pubs = privs.clone().map(|p| p.public_key());
        let pub_key: PublicKey = pubs.iter().sum();
        let priv_key: PrivateKey = privs.iter().sum();
        let pub_key_from_priv = priv_key.public_key();
        assert_eq!(pub_key, pub_key_from_priv);
    }

    #[test]
    fn musig_test() {
        let privs = [
            U256::from_u64(123),
            U256::from_u64(124),
            G_ORDER.sub_mod(&U256::ONE, &G_ORDER),
        ]
        .map(PrivateKey);
        let pubs = privs.clone().map(|p| p.public_key());
        let pub_key: PublicKey = pubs.iter().sum();
        let priv_key: PrivateKey = privs.iter().sum();

        let digest = Digest([Belt(1), Belt(2), Belt(3), Belt(4), Belt(5)]);
        let signature_all = priv_key.sign(&digest);
        // Just testing regular signing
        assert!(pub_key.verify(&digest, &signature_all));

        // Now do split signing
        let nonces = privs
            .iter()
            .map(|p| p.nonce_for(&digest))
            .collect::<Vec<_>>();
        let nonce = PrivateKey::combine_nonces(&nonces);
        let mut sigs = vec![];
        for priv_key in &privs {
            sigs.push(priv_key.sign_multi(&digest, &nonce, &pub_key));
        }
        // Combine all signatures
        let sig = sigs.into_iter().sum::<Option<Signature>>().unwrap();
        // Verify combined signature
        assert!(pub_key.verify(&digest, &sig));
    }

    #[test]
    fn test_sign_and_verify() {
        let priv_key = PrivateKey(U256::from_u64(123));
        let digest = Digest([Belt(1), Belt(2), Belt(3), Belt(4), Belt(5)]);
        let signature = priv_key.sign(&digest);
        let pubkey = priv_key.public_key();
        assert!(
            pubkey.verify(&digest, &signature),
            "Signature verification failed!"
        );

        // Corrupting digest, signature, or pubkey should all cause failure
        let mut wrong_digest = digest;
        wrong_digest.0[0] = Belt(0);
        assert!(
            !pubkey.verify(&wrong_digest, &signature),
            "Should reject wrong digest"
        );
        let mut wrong_sig = signature;
        wrong_sig.s += U256::from_u64(1);
        assert!(
            !pubkey.verify(&digest, &wrong_sig),
            "Should reject wrong signature"
        );
        let mut wrong_pubkey = pubkey;
        wrong_pubkey.0.x.0[0].0 += 1;
        assert!(
            !wrong_pubkey.verify(&digest, &signature),
            "Should reject wrong public key"
        );
    }

    #[test]
    fn test_vector() {
        // from nockchain zkvm-jetpack cheetah_jets.rs test_batch_verify_affine
        let digest = Digest([Belt(8), Belt(9), Belt(10), Belt(11), Belt(12)]);
        let pubkey = PublicKey(CheetahPoint {
            x: F6lt([
                Belt(2754611494552410273),
                Belt(8599518745794843693),
                Belt(10526511002404673680),
                Belt(4830863958577994148),
                Belt(375185138577093320),
                Belt(12938930721685970739),
            ]),
            y: F6lt([
                Belt(3062714866612034253),
                Belt(15671931273416742386),
                Belt(4071440668668521568),
                Belt(7738250649524482367),
                Belt(5259065445844042557),
                Belt(8456011930642078370),
            ]),
            inf: false,
        });
        let c_hex = "6f3cd43cd8709f4368aed04cd84292ab1c380cb645aaa7d010669d70375cbe88";
        let s_hex = "5197ab182e307a350b5cf3606d6e99a6f35b0d382c8330dde6e51fb6ef8ebb8c";
        let signature = Signature {
            c: U256::from_str_radix_vartime(c_hex, 16).unwrap(),
            s: U256::from_str_radix_vartime(s_hex, 16).unwrap(),
        };
        assert!(pubkey.verify(&digest, &signature));
    }

    #[test]
    fn test_serde() {
        let c_hex = "6f3cd43cd8709f4368aed04cd84292ab1c380cb645aaa7d010669d70375cbe88";
        let s_hex = "5197ab182e307a350b5cf3606d6e99a6f35b0d382c8330dde6e51fb6ef8ebb8c";
        let signature = Signature {
            c: U256::from_str_radix_vartime(c_hex, 16).unwrap(),
            s: U256::from_str_radix_vartime(s_hex, 16).unwrap(),
        };
        let json = serde_json::to_string(&signature).unwrap();
        let sig: Signature = serde_json::from_str(&json).unwrap();
        assert_eq!(signature, sig);
    }

    /// The 97-byte encoding must be the `0x01` prefix followed by the SLIP-10
    /// payload, and must decode back to the same key.
    #[test]
    fn test_be_bytes_roundtrip() {
        let pubkey = PrivateKey(U256::from_u64(123)).public_key();
        let bytes = pubkey.to_be_bytes();
        assert_eq!(bytes[0], 0x01);
        assert_eq!(&bytes[1..], &pubkey.as_slip10_bytes()[..]);
        assert_eq!(PublicKey::from_be_bytes(&bytes), pubkey);
    }

    /// Hex encoding must round-trip, and malformed input must be rejected.
    #[cfg(feature = "alloc")]
    #[test]
    fn test_hex_roundtrip() {
        let pubkey = PrivateKey(U256::from_u64(456)).public_key();
        let encoded = pubkey.to_hex();
        assert_eq!(encoded.len(), 2 * 97);
        assert_eq!(PublicKey::from_hex(&encoded), Some(pubkey));
        assert_eq!(PublicKey::from_hex("zz"), None);
        assert_eq!(PublicKey::from_hex("01"), None);
    }
}