1use elliptic_curve::{
2 generic_array::typenum::Unsigned,
3 group::{cofactor::CofactorGroup, Curve, GroupEncoding},
4 ops::{Invert as _, Reduce},
5 point::AffineCoordinates,
6 scalar::IsHigh as _,
7 sec1::{EncodedPoint, FromEncodedPoint, ModulusSize, ToEncodedPoint},
8 subtle::ConditionallySelectable as _,
9 Curve as ECurve, CurveArithmetic, Field as _, Group, PrimeCurve, PrimeField,
10};
11use hd_keys_curves_wasm::{HDDerivable, HDDeriver};
12use js_sys::Uint8Array;
13use k256::Secp256k1;
14use p256::NistP256;
15use serde::Deserialize;
16use serde_bytes::Bytes;
17use tsify::Tsify;
18use wasm_bindgen::{prelude::*, JsError};
19
20use crate::abi::{from_js, into_js, into_uint8array, JsResult};
21
/// Selects which elliptic curve the exported ECDSA entry points operate on.
///
/// Deserialized from the JavaScript side (`from_wasm_abi`).
#[derive(Tsify, Deserialize)]
#[tsify(from_wasm_abi)]
pub enum EcdsaVariant {
    /// Dispatches to `Ecdsa::<Secp256k1>`.
    K256,
    /// Dispatches to `Ecdsa::<NistP256>`.
    P256,
}
28
/// Namespace type carrying the curve parameter `C`; all functionality lives in
/// associated functions of `impl<C> Ecdsa<C>`, which never take a `self`.
struct Ecdsa<C>(C);
30
/// Per-curve context bytes used for hierarchical key derivation.
trait HdCtx {
    /// Context passed as the second argument to `C::Scalar::create` when a
    /// key is derived from an id.
    const CTX: &'static [u8];
}
34
/// Key-derivation context for secp256k1.
impl HdCtx for Secp256k1 {
    // Runtime value: changing a single byte changes every derived key.
    const CTX: &'static [u8] = b"LIT_HD_KEY_ID_K256_XMD:SHA-256_SSWU_RO_NUL_";
}
38
/// Key-derivation context for NIST P-256.
impl HdCtx for NistP256 {
    // Runtime value: changing a single byte changes every derived key.
    const CTX: &'static [u8] = b"LIT_HD_KEY_ID_P256_XMD:SHA-256_SSWU_RO_NUL_";
}
42
#[wasm_bindgen]
extern "C" {
    /// JavaScript-side signature value, typed in TypeScript as a 3-tuple.
    ///
    /// As produced by `Ecdsa::signature_into_js`, the elements are the bytes
    /// of `r`, the bytes of `s`, and `v` (0 or 1, taken from the parity of the
    /// nonce point's y-coordinate).
    #[wasm_bindgen(typescript_type = "[Uint8Array, Uint8Array, number]")]
    pub type EcdsaSignature;
}
48
49impl<C> Ecdsa<C>
50where
51 C: PrimeCurve + CurveArithmetic + HdCtx,
52 C::AffinePoint: GroupEncoding + FromEncodedPoint<C>,
53 C::Scalar: HDDeriver,
54 C::FieldBytesSize: ModulusSize,
55 C::ProjectivePoint: CofactorGroup + HDDerivable + FromEncodedPoint<C> + ToEncodedPoint<C>,
56{
57 pub fn combine(
58 presignature: Uint8Array,
59 signature_shares: Vec<Uint8Array>,
60 ) -> JsResult<EcdsaSignature> {
61 let (big_r, s) = Self::combine_inner(presignature, signature_shares)?;
62 Self::signature_into_js(big_r.to_affine(), s)
63 }
64
65 pub(crate) fn combine_inner(
66 presignature: Uint8Array,
67 signature_shares: Vec<Uint8Array>,
68 ) -> JsResult<(C::ProjectivePoint, C::Scalar)> {
69 let signature_shares = signature_shares
70 .into_iter()
71 .map(Self::scalar_from_js)
72 .collect::<JsResult<Vec<_>>>()?;
73
74 let big_r: C::AffinePoint = Self::point_from_js(presignature)?;
75 let s = Self::sum_scalars(signature_shares)?;
76 Ok((C::ProjectivePoint::from(big_r), s))
77 }
78
79 pub fn verify(
80 message_hash: Uint8Array,
81 public_key: Uint8Array,
82 signature: EcdsaSignature,
83 ) -> JsResult<()> {
84 let (r, s, _) = Self::signature_from_js(signature)?;
85
86 let z = Self::scalar_from_hash(message_hash)?;
87 let public_key: C::ProjectivePoint = Self::point_from_js(public_key)?;
88
89 if r.is_zero().into() {
90 return Err(JsError::new("invalid signature"));
91 }
92 let s_inv = Option::<C::Scalar>::from(s.invert_vartime())
94 .ok_or_else(|| JsError::new("invalid signature"))?;
95
96 if z.is_zero().into() {
97 return Err(JsError::new("invalid message digest"));
98 }
99
100 let reproduced =
101 (<C::ProjectivePoint as Group>::generator() * (z * s_inv)) + (public_key * (r * s_inv));
102 let reproduced_x = Self::x_coordinate(&reproduced.to_affine());
103
104 if reproduced_x != r {
105 return Err(JsError::new("invalid signature"));
106 }
107
108 Ok(())
109 }
110
111 fn sum_scalars(values: Vec<C::Scalar>) -> JsResult<C::Scalar> {
112 if values.is_empty() {
113 return Err(JsError::new("no shares provided"));
114 }
115 let mut acc: C::Scalar = values.into_iter().sum();
116 acc.conditional_assign(&(-acc), acc.is_high());
117 Ok(acc)
118 }
119
120 pub fn derive_key(id: Uint8Array, public_keys: Vec<Uint8Array>) -> JsResult<Uint8Array> {
121 let k = Self::derive_key_inner(id, public_keys)?;
122 let k = k.to_encoded_point(false);
123
124 into_uint8array(k.as_bytes())
125 }
126
127 fn derive_key_inner(
128 id: Uint8Array,
129 public_keys: Vec<Uint8Array>,
130 ) -> JsResult<C::ProjectivePoint> {
131 let id = from_js::<Vec<u8>>(id)?;
132 let public_keys = public_keys
133 .into_iter()
134 .map(Self::point_from_js::<C::ProjectivePoint>)
135 .collect::<JsResult<Vec<_>>>()?;
136
137 let deriver = C::Scalar::create(&id, C::CTX);
138 Ok(deriver.hd_derive_public_key(&public_keys))
139 }
140
141 fn scalar_from_js(s: Uint8Array) -> JsResult<C::Scalar> {
142 let s = from_js::<Vec<u8>>(s)?;
143 Self::scalar_from_bytes(s)
144 }
145
146 fn scalar_from_bytes(s: Vec<u8>) -> JsResult<C::Scalar> {
147 let s = C::Scalar::from_repr(<C::Scalar as PrimeField>::Repr::from_slice(&s).clone());
148 let s = Option::from(s);
149 let s = s.ok_or_else(|| JsError::new("cannot deserialize"))?;
150
151 Ok(s)
152 }
153
154 fn point_from_js<T: FromEncodedPoint<C>>(q: Uint8Array) -> JsResult<T> {
155 let q = from_js::<Vec<u8>>(q)?;
156 let q = EncodedPoint::<C>::from_bytes(q)?;
157 let q = T::from_encoded_point(&q);
158 let q = Option::<T>::from(q);
159 let q = q.ok_or_else(|| JsError::new("cannot deserialize"))?;
160
161 Ok(q)
162 }
163
164 fn signature_from_js(signature: EcdsaSignature) -> JsResult<(C::Scalar, C::Scalar, u8)> {
165 let (r, s, v): (Vec<u8>, Vec<u8>, u8) = from_js(signature)?;
166 let r = Self::scalar_from_bytes(r)?;
167 let s = Self::scalar_from_bytes(s)?;
168 Ok((r, s, v))
169 }
170
171 fn signature_into_js(big_r: C::AffinePoint, s: C::Scalar) -> JsResult<EcdsaSignature> {
172 let r = Self::x_coordinate(&big_r).to_repr();
173 let s = s.to_repr();
174 let v = u8::conditional_select(&0, &1, big_r.y_is_odd());
175
176 Ok(EcdsaSignature {
177 obj: into_js(&(Bytes::new(&r), Bytes::new(&s), v))?,
178 })
179 }
180
181 pub(crate) fn x_coordinate(pt: &C::AffinePoint) -> C::Scalar {
182 <C::Scalar as Reduce<<C as ECurve>::Uint>>::reduce_bytes(&pt.x())
183 }
184
185 pub fn scalar_from_hash(msg_digest: Uint8Array) -> JsResult<C::Scalar> {
186 let digest = from_js::<Vec<u8>>(msg_digest)?;
187 if digest.len() != C::FieldBytesSize::to_usize() {
188 return Err(JsError::new("invalid message digest length"));
189 }
190 let z_bytes =
191 <C::Scalar as Reduce<<C as ECurve>::Uint>>::Bytes::from_slice(digest.as_slice());
192 Ok(<C::Scalar as Reduce<<C as ECurve>::Uint>>::reduce_bytes(
193 z_bytes,
194 ))
195 }
196
197 pub fn combine_and_verify_with_derived_key(
198 pre_signature: Uint8Array,
199 signature_shares: Vec<Uint8Array>,
200 message_hash: Uint8Array,
201 id: Uint8Array,
202 public_keys: Vec<Uint8Array>,
203 ) -> JsResult<EcdsaSignature> {
204 let public_key = Self::derive_key_inner(id, public_keys)?;
205 Self::combine_and_verify(pre_signature, signature_shares, message_hash, public_key)
206 }
207
208 pub fn combine_and_verify_with_specified_key(
209 pre_signature: Uint8Array,
210 signature_shares: Vec<Uint8Array>,
211 message_hash: Uint8Array,
212 public_key: Uint8Array,
213 ) -> JsResult<EcdsaSignature> {
214 let public_key: C::ProjectivePoint = Self::point_from_js(public_key)?;
215 Self::combine_and_verify(pre_signature, signature_shares, message_hash, public_key)
216 }
217
218 fn combine_and_verify(
219 pre_signature: Uint8Array,
220 signature_shares: Vec<Uint8Array>,
221 message_hash: Uint8Array,
222 public_key: C::ProjectivePoint,
223 ) -> JsResult<EcdsaSignature> {
224 let z = Self::scalar_from_hash(message_hash)?;
225 let (big_r, s) = Self::combine_inner(pre_signature, signature_shares)?;
226 let r = Self::x_coordinate(&big_r.to_affine());
227
228 if z.is_zero().into() {
229 return Err(JsError::new("invalid message digest"));
230 }
231 if (s.is_zero() | big_r.is_identity()).into() {
232 return Err(JsError::new("invalid signature"));
233 }
234 if r.is_zero().into() {
235 return Err(JsError::new("invalid r coordinate"));
236 }
237 if (big_r * s - (public_key * r + C::ProjectivePoint::generator() * z))
241 .is_identity()
242 .into()
243 {
244 Self::signature_into_js(big_r.to_affine(), s)
245 } else {
246 Err(JsError::new("invalid signature"))
247 }
248 }
249}
250
251#[wasm_bindgen(js_name = "ecdsaCombineAndVerifyWithDerivedKey")]
253pub fn ecdsa_combine_and_verify_with_derived_key(
254 variant: EcdsaVariant,
255 pre_signature: Uint8Array,
256 signature_shares: Vec<Uint8Array>,
257 message_hash: Uint8Array,
258 id: Uint8Array,
259 public_keys: Vec<Uint8Array>,
260) -> JsResult<EcdsaSignature> {
261 match variant {
262 EcdsaVariant::K256 => Ecdsa::<Secp256k1>::combine_and_verify_with_derived_key(
263 pre_signature,
264 signature_shares,
265 message_hash,
266 id,
267 public_keys,
268 ),
269 EcdsaVariant::P256 => Ecdsa::<NistP256>::combine_and_verify_with_derived_key(
270 pre_signature,
271 signature_shares,
272 message_hash,
273 id,
274 public_keys,
275 ),
276 }
277}
278
279#[wasm_bindgen(js_name = "ecdsaCombineAndVerify")]
281pub fn ecdsa_combine_and_verify(
282 variant: EcdsaVariant,
283 pre_signature: Uint8Array,
284 signature_shares: Vec<Uint8Array>,
285 message_hash: Uint8Array,
286 public_key: Uint8Array,
287) -> JsResult<EcdsaSignature> {
288 match variant {
289 EcdsaVariant::K256 => Ecdsa::<Secp256k1>::combine_and_verify_with_specified_key(
290 pre_signature,
291 signature_shares,
292 message_hash,
293 public_key,
294 ),
295 EcdsaVariant::P256 => Ecdsa::<NistP256>::combine_and_verify_with_specified_key(
296 pre_signature,
297 signature_shares,
298 message_hash,
299 public_key,
300 ),
301 }
302}
303
304#[wasm_bindgen(js_name = "ecdsaCombine")]
306pub fn ecdsa_combine(
307 variant: EcdsaVariant,
308 presignature: Uint8Array,
309 signature_shares: Vec<Uint8Array>,
310) -> JsResult<EcdsaSignature> {
311 match variant {
312 EcdsaVariant::K256 => Ecdsa::<Secp256k1>::combine(presignature, signature_shares),
313 EcdsaVariant::P256 => Ecdsa::<NistP256>::combine(presignature, signature_shares),
314 }
315}
316
317#[wasm_bindgen(js_name = "ecdsaVerify")]
318pub fn ecdsa_verify(
319 variant: EcdsaVariant,
320 message_hash: Uint8Array,
321 public_key: Uint8Array,
322 signature: EcdsaSignature,
323) -> JsResult<()> {
324 match variant {
325 EcdsaVariant::K256 => Ecdsa::<Secp256k1>::verify(message_hash, public_key, signature),
326 EcdsaVariant::P256 => Ecdsa::<NistP256>::verify(message_hash, public_key, signature),
327 }
328}
329
330#[wasm_bindgen(js_name = "ecdsaDeriveKey")]
331pub fn ecdsa_derive_key(
332 variant: EcdsaVariant,
333 id: Uint8Array,
334 public_keys: Vec<Uint8Array>,
335) -> JsResult<Uint8Array> {
336 match variant {
337 EcdsaVariant::K256 => Ecdsa::<Secp256k1>::derive_key(id, public_keys),
338 EcdsaVariant::P256 => Ecdsa::<NistP256>::derive_key(id, public_keys),
339 }
340}