1#[cfg(feature = "alloc")]
2use alloc::{
3 boxed::Box,
4 format,
5 string::{String, ToString},
6 vec::Vec,
7};
8use arrayvec::ArrayVec;
9use iris_ztd::{
10 crypto::cheetah::{
11 ch_add, ch_neg, ch_scal_big, trunc_g_order, CheetahPoint, F6lt, A_GEN, G_ORDER,
12 },
13 tip5::hash::hash_varlen,
14 Belt, Digest, Hashable, MulMod, U256,
15};
16#[cfg(feature = "alloc")]
17use iris_ztd::{Noun, NounDecode, NounEncode};
18use serde::{Deserialize, Serialize};
19
/// A public key, stored as a newtype around a [`CheetahPoint`].
///
/// The wrapped point is public, so callers can construct a key from any point value;
/// nothing in this definition validates the point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[cfg_attr(feature = "alloc", derive(NounEncode, NounDecode))]
#[iris_ztd::wasm_noun_codec]
pub struct PublicKey(pub CheetahPoint);
24
25impl core::fmt::Display for PublicKey {
26 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27 write!(f, "{}", self.0)
28 }
29}
30
31impl TryFrom<&str> for PublicKey {
32 type Error = iris_ztd::crypto::cheetah::CheetahError;
33
34 fn try_from(value: &str) -> Result<Self, Self::Error> {
35 value.try_into().map(Self)
36 }
37}
38
39#[iris_ztd::wasm_member_methods]
40impl PublicKey {
41 pub fn verify(&self, m: &Digest, sig: &Signature) -> bool {
42 if sig.c == U256::ZERO || sig.c >= G_ORDER || sig.s == U256::ZERO || sig.s >= G_ORDER {
43 return false;
44 }
45
46 let sg = match ch_scal_big(&sig.s, &A_GEN) {
49 Ok(pt) => pt,
50 Err(_) => return false,
51 };
52 let c_pk = match ch_scal_big(&sig.c, &self.0) {
53 Ok(pt) => pt,
54 Err(_) => return false,
55 };
56 let scalar = match ch_add(&sg, &ch_neg(&c_pk)) {
57 Ok(pt) => pt,
58 Err(_) => return false,
59 };
60 let chal = {
61 let mut transcript: ArrayVec<Belt, { 6 + 6 + 6 + 6 + 5 }> = ArrayVec::new();
62 transcript.try_extend_from_slice(&scalar.x.0).unwrap();
63 transcript.try_extend_from_slice(&scalar.y.0).unwrap();
64 transcript.try_extend_from_slice(&self.0.x.0).unwrap();
65 transcript.try_extend_from_slice(&self.0.y.0).unwrap();
66 transcript.try_extend_from_slice(&m.0).unwrap();
67 trunc_g_order(&hash_varlen(&transcript))
68 };
69
70 chal == sig.c
71 }
72
73 pub fn from_be_bytes(bytes: &[u8]) -> PublicKey {
74 let mut x = [Belt(0); 6];
75 let mut y = [Belt(0); 6];
76
77 for i in 0..6 {
79 let offset = 1 + i * 8;
80 let mut buf = [0u8; 8];
81 buf.copy_from_slice(&bytes[offset..offset + 8]);
82 y[5 - i] = Belt(u64::from_be_bytes(buf));
83 }
84
85 for i in 0..6 {
87 let offset = 49 + i * 8;
88 let mut buf = [0u8; 8];
89 buf.copy_from_slice(&bytes[offset..offset + 8]);
90 x[5 - i] = Belt(u64::from_be_bytes(buf));
91 }
92
93 PublicKey(CheetahPoint {
94 x: F6lt(x),
95 y: F6lt(y),
96 inf: false,
97 })
98 }
99
100 #[cfg(feature = "alloc")]
101 pub fn to_be_bytes_vec(&self) -> Vec<u8> {
102 self.to_be_bytes().to_vec()
103 }
104
105 #[cfg(feature = "alloc")]
106 pub fn from_hex(hex: &str) -> Option<PublicKey> {
107 let bytes = hex::decode(hex).ok()?;
108 if bytes.len() != 97 {
109 return None;
110 }
111 Some(Self::from_be_bytes(&bytes))
112 }
113
114 #[cfg(feature = "alloc")]
115 pub fn to_hex(&self) -> String {
116 hex::encode(self.to_be_bytes())
117 }
118}
119
120impl PublicKey {
121 pub fn to_be_bytes(&self) -> [u8; 97] {
122 let mut data = [0u8; 97];
123 data[0] = 0x01; let mut offset = 1;
125 for belt in self.0.y.0.iter().rev() {
127 data[offset..offset + 8].copy_from_slice(&belt.0.to_be_bytes());
128 offset += 8;
129 }
130 for belt in self.0.x.0.iter().rev() {
132 data[offset..offset + 8].copy_from_slice(&belt.0.to_be_bytes());
133 offset += 8;
134 }
135 data
136 }
137
138 pub(crate) fn as_slip10_bytes(&self) -> [u8; 96] {
140 let mut data = [0u8; 96];
141 let mut offset = 0;
142 for belt in self.0.y.0.iter().rev().chain(self.0.x.0.iter().rev()) {
143 data[offset..offset + 8].copy_from_slice(&belt.0.to_be_bytes());
144 offset += 8;
145 }
146 data
147 }
148}
149
150impl core::ops::Add for &PublicKey {
151 type Output = PublicKey;
152
153 fn add(self, other: &PublicKey) -> PublicKey {
154 PublicKey(ch_add(&self.0, &other.0).unwrap())
155 }
156}
157
158impl core::ops::Add for PublicKey {
159 type Output = PublicKey;
160
161 fn add(self, other: PublicKey) -> PublicKey {
162 (&self as &PublicKey) + (&other as &PublicKey)
163 }
164}
165
166impl core::ops::AddAssign for PublicKey {
167 fn add_assign(&mut self, other: PublicKey) {
168 *self = *self + other;
169 }
170}
171
172impl core::ops::Sub for &PublicKey {
173 type Output = PublicKey;
174
175 fn sub(self, other: &PublicKey) -> PublicKey {
176 PublicKey(ch_add(&self.0, &ch_neg(&other.0)).unwrap())
177 }
178}
179
180impl core::ops::SubAssign for PublicKey {
181 fn sub_assign(&mut self, other: PublicKey) {
182 *self = &*self - &other;
183 }
184}
185
186impl core::iter::Sum<PublicKey> for PublicKey {
187 fn sum<I: Iterator<Item = PublicKey>>(iter: I) -> Self {
188 iter.fold(PublicKey(CheetahPoint::identity()), |acc, x| acc + x)
189 }
190}
191
192impl<'a> core::iter::Sum<&'a PublicKey> for PublicKey {
193 fn sum<I: Iterator<Item = &'a PublicKey>>(iter: I) -> Self {
194 iter.fold(PublicKey(CheetahPoint::identity()), |acc, x| &acc + x)
195 }
196}
197
/// Hashing for a public key is delegated entirely to the wrapped [`CheetahPoint`].
impl Hashable for PublicKey {
    /// Returns the digest produced by calling `hash` on the inner point.
    fn hash(&self) -> Digest {
        self.0.hash()
    }

    /// Returns the inner point's leaf count.
    fn leaf_count(&self) -> usize {
        self.0.leaf_count()
    }

    /// Returns the inner point's hashable pair, if it has one.
    fn hashable_pair<'a>(&'a self) -> Option<(impl Hashable + 'a, impl Hashable + 'a)> {
        self.0.hashable_pair()
    }
}
211
/// A signature made of a challenge `c` and a response `s`.
///
/// `PublicKey::verify` accepts a signature only when both components are non-zero and
/// below `G_ORDER`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[iris_ztd::wasm_noun_codec]
pub struct Signature {
    /// Challenge: the truncated transcript hash that verification recomputes and compares.
    #[cfg_attr(feature = "wasm", tsify(type = "string"))]
    pub c: U256,
    /// Response: signer nonce plus `c` times the private scalar, reduced modulo `G_ORDER`.
    #[cfg_attr(feature = "wasm", tsify(type = "string"))]
    pub s: U256,
}
222
223impl core::iter::Sum<Signature> for Option<Signature> {
225 fn sum<I: Iterator<Item = Signature>>(mut iter: I) -> Self {
226 let mut c = None;
227 let s = iter.try_fold(U256::ZERO, |acc, x| {
228 if c.is_some() && c.as_ref() != Some(&x.c) {
229 return None;
230 }
231 c = Some(x.c);
232 Some(acc.add_mod(&x.s, &G_ORDER))
233 });
234 Some(Signature { c: c?, s: s? })
235 }
236}
237
#[cfg(feature = "alloc")]
impl NounEncode for Signature {
    /// Encodes the signature as a pair `(c, s)`, where each component is its
    /// little-endian byte string converted by `Belt::from_bytes`.
    fn to_noun(&self) -> Noun {
        let c_belts = Belt::from_bytes(&self.c.to_le_bytes());
        let s_belts = Belt::from_bytes(&self.s.to_le_bytes());
        (c_belts.as_slice(), s_belts.as_slice()).to_noun()
    }
}
248
#[cfg(feature = "alloc")]
impl NounDecode for Signature {
    /// Decodes a signature from a noun holding two groups of eight belts.
    ///
    /// Each group is turned into bytes with `Belt::to_bytes` and read as a
    /// little-endian `U256`. Returns `None` if the noun does not decode to that shape.
    fn from_noun(noun: &Noun) -> Option<Self> {
        let (c_belts, s_belts): ([Belt; 8], [Belt; 8]) = NounDecode::from_noun(noun)?;

        let c_bytes = Belt::to_bytes(&c_belts);
        let s_bytes = Belt::to_bytes(&s_belts);

        let c = U256::from_le_slice(&c_bytes);
        let s = U256::from_le_slice(&s_bytes);
        Some(Signature { c, s })
    }
}
263
#[cfg(feature = "alloc")]
impl Hashable for Signature {
    /// Hashes the signature's noun encoding.
    fn hash(&self) -> Digest {
        let noun = self.to_noun();
        noun.hash()
    }

    /// A signature always counts as a single leaf.
    fn leaf_count(&self) -> usize {
        1
    }

    /// A signature never exposes a hashable pair.
    fn hashable_pair<'a>(&'a self) -> Option<(impl Hashable + 'a, impl Hashable + 'a)> {
        None::<((), ())>
    }
}
279
/// A private key: a secret scalar stored as a `U256`.
///
/// The scalar is overwritten with zero when the value is dropped.
// NOTE(review): the derived `Debug` will print the secret scalar — confirm this is intended.
#[derive(Debug, Clone)]
pub struct PrivateKey(pub U256);
282
283impl Drop for PrivateKey {
284 fn drop(&mut self) {
285 unsafe {
286 core::ptr::write_volatile(&mut self.0, U256::ZERO);
287 }
288 core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
289 }
290}
291
292impl PrivateKey {
293 pub fn public_key(&self) -> PublicKey {
294 PublicKey(ch_scal_big(&self.0, &A_GEN).unwrap())
295 }
296
297 pub fn sign(&self, m: &Digest) -> Signature {
298 self.sign_multi(m, &self.nonce_for(m), &self.public_key())
299 }
300
301 pub fn nonce_for(&self, m: &Digest) -> U256 {
302 let pubkey = self.public_key().0;
303 let nonce = {
304 let mut transcript: ArrayVec<Belt, { 6 + 6 + 5 + 8 }> = ArrayVec::new();
305 transcript.try_extend_from_slice(&pubkey.x.0).unwrap();
306 transcript.try_extend_from_slice(&pubkey.y.0).unwrap();
307 transcript.try_extend_from_slice(&m.0).unwrap();
308 self.0.to_le_bytes().chunks(4).for_each(|chunk| {
309 let mut buf = [0u8; 4];
310 buf[..chunk.len()].copy_from_slice(chunk);
311 transcript.push(Belt(u32::from_le_bytes(buf) as u64));
312 });
313 trunc_g_order(&hash_varlen(&transcript))
314 };
315 nonce
316 }
317
318 pub fn combine_nonces(nonces: &[U256]) -> U256 {
319 nonces
320 .iter()
321 .fold(U256::ZERO, |acc, x| acc.add_mod(x, &G_ORDER))
322 }
323
324 pub fn sign_multi(
352 &self,
353 m: &Digest,
354 shared_nonce: &U256,
355 combined_pubkey: &PublicKey,
356 ) -> Signature {
357 let chal = {
358 let scalar = ch_scal_big(shared_nonce, &A_GEN).unwrap();
360 let mut transcript: ArrayVec<Belt, { 6 + 6 + 6 + 6 + 5 }> = ArrayVec::new();
361 transcript.try_extend_from_slice(&scalar.x.0).unwrap();
362 transcript.try_extend_from_slice(&scalar.y.0).unwrap();
363 transcript
364 .try_extend_from_slice(&combined_pubkey.0.x.0)
365 .unwrap();
366 transcript
367 .try_extend_from_slice(&combined_pubkey.0.y.0)
368 .unwrap();
369 transcript.try_extend_from_slice(&m.0).unwrap();
370 trunc_g_order(&hash_varlen(&transcript))
371 };
372 let nonce = self.nonce_for(m);
373 let chal_mul = MulMod::mul_mod(&chal, &self.0, &G_ORDER);
374 let sig = nonce.add_mod(&chal_mul, &G_ORDER);
375 Signature { c: chal, s: sig }
376 }
377
378 pub fn to_be_bytes(&self) -> [u8; 32] {
379 self.0.to_be_bytes()
380 }
381}
382
383impl core::ops::Add for &PrivateKey {
384 type Output = PrivateKey;
385
386 fn add(self, other: &PrivateKey) -> PrivateKey {
387 PrivateKey(self.0.add_mod(&other.0, &G_ORDER))
388 }
389}
390
391impl core::ops::Add for PrivateKey {
392 type Output = PrivateKey;
393
394 fn add(self, other: PrivateKey) -> PrivateKey {
395 PrivateKey(self.0.add_mod(&other.0, &G_ORDER))
396 }
397}
398
399impl core::ops::AddAssign for PrivateKey {
400 fn add_assign(&mut self, other: PrivateKey) {
401 *self = &*self + &other;
402 }
403}
404
405impl core::ops::Sub for &PrivateKey {
406 type Output = PrivateKey;
407
408 fn sub(self, other: &PrivateKey) -> PrivateKey {
409 PrivateKey(self.0.sub_mod(&other.0, &G_ORDER))
410 }
411}
412
413impl core::ops::SubAssign for PrivateKey {
414 fn sub_assign(&mut self, other: PrivateKey) {
415 *self = &*self - &other;
416 }
417}
418
419impl core::iter::Sum<PrivateKey> for PrivateKey {
420 fn sum<I: Iterator<Item = PrivateKey>>(iter: I) -> Self {
421 iter.fold(PrivateKey(U256::ZERO), |acc, x| &acc + &x)
422 }
423}
424
425impl<'a> core::iter::Sum<&'a PrivateKey> for PrivateKey {
426 fn sum<I: Iterator<Item = &'a PrivateKey>>(iter: I) -> Self {
427 iter.fold(PrivateKey(U256::ZERO), |acc, x| &acc + x)
428 }
429}
430
#[cfg(test)]
mod tests {
    extern crate alloc;
    use super::*;
    use alloc::{vec, vec::Vec};

    /// The three signer keys used by the multi-party tests: 123, 124 and `G_ORDER - 1`.
    fn signer_keys() -> [PrivateKey; 3] {
        [
            U256::from_u64(123),
            U256::from_u64(124),
            G_ORDER.sub_mod(&U256::ONE, &G_ORDER),
        ]
        .map(PrivateKey)
    }

    /// The digest `[1, 2, 3, 4, 5]` used by the signing tests.
    fn sample_digest() -> Digest {
        Digest([Belt(1), Belt(2), Belt(3), Belt(4), Belt(5)])
    }

    /// The fixed signature used by the vector and serde tests.
    fn vector_signature() -> Signature {
        let c_hex = "6f3cd43cd8709f4368aed04cd84292ab1c380cb645aaa7d010669d70375cbe88";
        let s_hex = "5197ab182e307a350b5cf3606d6e99a6f35b0d382c8330dde6e51fb6ef8ebb8c";
        Signature {
            c: U256::from_str_radix_vartime(c_hex, 16).unwrap(),
            s: U256::from_str_radix_vartime(s_hex, 16).unwrap(),
        }
    }

    /// Summing public keys must match the public key of the summed private keys.
    #[test]
    fn mupk_test() {
        let privs = signer_keys();
        let pubs: Vec<PublicKey> = privs.iter().map(|key| key.public_key()).collect();

        let combined_pub: PublicKey = pubs.iter().sum();
        let combined_priv: PrivateKey = privs.iter().sum();

        assert_eq!(combined_pub, combined_priv.public_key());
    }

    /// Aggregated partial signatures must verify under the combined public key.
    #[test]
    fn musig_test() {
        let privs = signer_keys();
        let pubs: Vec<PublicKey> = privs.iter().map(|key| key.public_key()).collect();
        let combined_pub: PublicKey = pubs.iter().sum();
        let combined_priv: PrivateKey = privs.iter().sum();

        let digest = sample_digest();

        // Reference case: sign once with the summed private key.
        let direct = combined_priv.sign(&digest);
        assert!(combined_pub.verify(&digest, &direct));

        // Multi-party case: each signer contributes a share against the shared nonce.
        let nonces: Vec<U256> = privs.iter().map(|key| key.nonce_for(&digest)).collect();
        let shared_nonce = PrivateKey::combine_nonces(&nonces);

        let mut partials = vec![];
        partials.extend(
            privs
                .iter()
                .map(|key| key.sign_multi(&digest, &shared_nonce, &combined_pub)),
        );

        let aggregate: Option<Signature> = partials.into_iter().sum();
        assert!(combined_pub.verify(&digest, &aggregate.unwrap()));
    }

    /// A valid signature verifies; changing the digest, signature or key makes it fail.
    #[test]
    fn test_sign_and_verify() {
        let digest = sample_digest();
        let priv_key = PrivateKey(U256::from_u64(123));
        let signature = priv_key.sign(&digest);
        let pubkey = priv_key.public_key();
        assert!(
            pubkey.verify(&digest, &signature),
            "Signature verification failed!"
        );

        // Changed digest.
        let mut wrong_digest = digest;
        wrong_digest.0[0] = Belt(0);
        assert!(
            !pubkey.verify(&wrong_digest, &signature),
            "Should reject wrong digest"
        );

        // Changed response component.
        let mut wrong_sig = signature;
        wrong_sig.s += U256::from_u64(1);
        assert!(
            !pubkey.verify(&digest, &wrong_sig),
            "Should reject wrong signature"
        );

        // Changed public key coordinate.
        let mut wrong_pubkey = pubkey;
        wrong_pubkey.0.x.0[0].0 += 1;
        assert!(
            !wrong_pubkey.verify(&digest, &signature),
            "Should reject wrong public key"
        );
    }

    /// Verifies a fixed (public key, digest, signature) test vector.
    #[test]
    fn test_vector() {
        let digest = Digest([Belt(8), Belt(9), Belt(10), Belt(11), Belt(12)]);

        let x = [
            2754611494552410273,
            8599518745794843693,
            10526511002404673680,
            4830863958577994148,
            375185138577093320,
            12938930721685970739,
        ]
        .map(Belt);
        let y = [
            3062714866612034253,
            15671931273416742386,
            4071440668668521568,
            7738250649524482367,
            5259065445844042557,
            8456011930642078370,
        ]
        .map(Belt);
        let pubkey = PublicKey(CheetahPoint {
            x: F6lt(x),
            y: F6lt(y),
            inf: false,
        });

        let signature = vector_signature();
        assert!(pubkey.verify(&digest, &signature));
    }

    /// A signature survives a JSON round trip unchanged.
    #[test]
    fn test_serde() {
        let signature = vector_signature();
        let json = serde_json::to_string(&signature).unwrap();
        let decoded: Signature = serde_json::from_str(&json).unwrap();
        assert_eq!(signature, decoded);
    }
}