lit_utilities_wasm/
ecdsa.rs

1use elliptic_curve::{
2    generic_array::typenum::Unsigned,
3    group::{cofactor::CofactorGroup, Curve, GroupEncoding},
4    ops::{Invert as _, Reduce},
5    point::AffineCoordinates,
6    scalar::IsHigh as _,
7    sec1::{EncodedPoint, FromEncodedPoint, ModulusSize, ToEncodedPoint},
8    subtle::ConditionallySelectable as _,
9    Curve as ECurve, CurveArithmetic, Field as _, Group, PrimeCurve, PrimeField,
10};
11use hd_keys_curves_wasm::{HDDerivable, HDDeriver};
12use js_sys::Uint8Array;
13use k256::Secp256k1;
14use p256::NistP256;
15use serde::Deserialize;
16use serde_bytes::Bytes;
17use tsify::Tsify;
18use wasm_bindgen::{prelude::*, JsError};
19
20use crate::abi::{from_js, into_js, into_uint8array, JsResult};
21
22#[derive(Tsify, Deserialize)]
23#[tsify(from_wasm_abi)]
24pub enum EcdsaVariant {
25    K256,
26    P256,
27}
28
/// Type-level namespace grouping the ECDSA operations for curve `C`.
///
/// It is never instantiated — all functionality lives in associated functions — so it
/// only carries `C` as a zero-sized marker instead of pretending to store a curve value.
struct Ecdsa<C>(std::marker::PhantomData<C>);
30
31trait HdCtx {
32    const CTX: &'static [u8];
33}
34
35impl HdCtx for Secp256k1 {
36    const CTX: &'static [u8] = b"LIT_HD_KEY_ID_K256_XMD:SHA-256_SSWU_RO_NUL_";
37}
38
39impl HdCtx for NistP256 {
40    const CTX: &'static [u8] = b"LIT_HD_KEY_ID_P256_XMD:SHA-256_SSWU_RO_NUL_";
41}
42
43#[wasm_bindgen]
44extern "C" {
45    #[wasm_bindgen(typescript_type = "[Uint8Array, Uint8Array, number]")]
46    pub type EcdsaSignature;
47}
48
49impl<C> Ecdsa<C>
50where
51    C: PrimeCurve + CurveArithmetic + HdCtx,
52    C::AffinePoint: GroupEncoding + FromEncodedPoint<C>,
53    C::Scalar: HDDeriver,
54    C::FieldBytesSize: ModulusSize,
55    C::ProjectivePoint: CofactorGroup + HDDerivable + FromEncodedPoint<C> + ToEncodedPoint<C>,
56{
57    pub fn combine(
58        presignature: Uint8Array,
59        signature_shares: Vec<Uint8Array>,
60    ) -> JsResult<EcdsaSignature> {
61        let (big_r, s) = Self::combine_inner(presignature, signature_shares)?;
62        Self::signature_into_js(big_r.to_affine(), s)
63    }
64
65    pub(crate) fn combine_inner(
66        presignature: Uint8Array,
67        signature_shares: Vec<Uint8Array>,
68    ) -> JsResult<(C::ProjectivePoint, C::Scalar)> {
69        let signature_shares = signature_shares
70            .into_iter()
71            .map(Self::scalar_from_js)
72            .collect::<JsResult<Vec<_>>>()?;
73
74        let big_r: C::AffinePoint = Self::point_from_js(presignature)?;
75        let s = Self::sum_scalars(signature_shares)?;
76        Ok((C::ProjectivePoint::from(big_r), s))
77    }
78
79    pub fn verify(
80        message_hash: Uint8Array,
81        public_key: Uint8Array,
82        signature: EcdsaSignature,
83    ) -> JsResult<()> {
84        let (r, s, _) = Self::signature_from_js(signature)?;
85
86        let z = Self::scalar_from_hash(message_hash)?;
87        let public_key: C::ProjectivePoint = Self::point_from_js(public_key)?;
88
89        if r.is_zero().into() {
90            return Err(JsError::new("invalid signature"));
91        }
92        // This will fail if s == 0
93        let s_inv = Option::<C::Scalar>::from(s.invert_vartime())
94            .ok_or_else(|| JsError::new("invalid signature"))?;
95
96        if z.is_zero().into() {
97            return Err(JsError::new("invalid message digest"));
98        }
99
100        let reproduced =
101            (<C::ProjectivePoint as Group>::generator() * (z * s_inv)) + (public_key * (r * s_inv));
102        let reproduced_x = Self::x_coordinate(&reproduced.to_affine());
103
104        if reproduced_x != r {
105            return Err(JsError::new("invalid signature"));
106        }
107
108        Ok(())
109    }
110
111    fn sum_scalars(values: Vec<C::Scalar>) -> JsResult<C::Scalar> {
112        if values.is_empty() {
113            return Err(JsError::new("no shares provided"));
114        }
115        let mut acc: C::Scalar = values.into_iter().sum();
116        acc.conditional_assign(&(-acc), acc.is_high());
117        Ok(acc)
118    }
119
120    pub fn derive_key(id: Uint8Array, public_keys: Vec<Uint8Array>) -> JsResult<Uint8Array> {
121        let k = Self::derive_key_inner(id, public_keys)?;
122        let k = k.to_encoded_point(false);
123
124        into_uint8array(k.as_bytes())
125    }
126
127    fn derive_key_inner(
128        id: Uint8Array,
129        public_keys: Vec<Uint8Array>,
130    ) -> JsResult<C::ProjectivePoint> {
131        let id = from_js::<Vec<u8>>(id)?;
132        let public_keys = public_keys
133            .into_iter()
134            .map(Self::point_from_js::<C::ProjectivePoint>)
135            .collect::<JsResult<Vec<_>>>()?;
136
137        let deriver = C::Scalar::create(&id, C::CTX);
138        Ok(deriver.hd_derive_public_key(&public_keys))
139    }
140
141    fn scalar_from_js(s: Uint8Array) -> JsResult<C::Scalar> {
142        let s = from_js::<Vec<u8>>(s)?;
143        Self::scalar_from_bytes(s)
144    }
145
146    fn scalar_from_bytes(s: Vec<u8>) -> JsResult<C::Scalar> {
147        let s = C::Scalar::from_repr(<C::Scalar as PrimeField>::Repr::from_slice(&s).clone());
148        let s = Option::from(s);
149        let s = s.ok_or_else(|| JsError::new("cannot deserialize"))?;
150
151        Ok(s)
152    }
153
154    fn point_from_js<T: FromEncodedPoint<C>>(q: Uint8Array) -> JsResult<T> {
155        let q = from_js::<Vec<u8>>(q)?;
156        let q = EncodedPoint::<C>::from_bytes(q)?;
157        let q = T::from_encoded_point(&q);
158        let q = Option::<T>::from(q);
159        let q = q.ok_or_else(|| JsError::new("cannot deserialize"))?;
160
161        Ok(q)
162    }
163
164    fn signature_from_js(signature: EcdsaSignature) -> JsResult<(C::Scalar, C::Scalar, u8)> {
165        let (r, s, v): (Vec<u8>, Vec<u8>, u8) = from_js(signature)?;
166        let r = Self::scalar_from_bytes(r)?;
167        let s = Self::scalar_from_bytes(s)?;
168        Ok((r, s, v))
169    }
170
171    fn signature_into_js(big_r: C::AffinePoint, s: C::Scalar) -> JsResult<EcdsaSignature> {
172        let r = Self::x_coordinate(&big_r).to_repr();
173        let s = s.to_repr();
174        let v = u8::conditional_select(&0, &1, big_r.y_is_odd());
175
176        Ok(EcdsaSignature {
177            obj: into_js(&(Bytes::new(&r), Bytes::new(&s), v))?,
178        })
179    }
180
181    pub(crate) fn x_coordinate(pt: &C::AffinePoint) -> C::Scalar {
182        <C::Scalar as Reduce<<C as ECurve>::Uint>>::reduce_bytes(&pt.x())
183    }
184
185    pub fn scalar_from_hash(msg_digest: Uint8Array) -> JsResult<C::Scalar> {
186        let digest = from_js::<Vec<u8>>(msg_digest)?;
187        if digest.len() != C::FieldBytesSize::to_usize() {
188            return Err(JsError::new("invalid message digest length"));
189        }
190        let z_bytes =
191            <C::Scalar as Reduce<<C as ECurve>::Uint>>::Bytes::from_slice(digest.as_slice());
192        Ok(<C::Scalar as Reduce<<C as ECurve>::Uint>>::reduce_bytes(
193            z_bytes,
194        ))
195    }
196
197    pub fn combine_and_verify_with_derived_key(
198        pre_signature: Uint8Array,
199        signature_shares: Vec<Uint8Array>,
200        message_hash: Uint8Array,
201        id: Uint8Array,
202        public_keys: Vec<Uint8Array>,
203    ) -> JsResult<EcdsaSignature> {
204        let public_key = Self::derive_key_inner(id, public_keys)?;
205        Self::combine_and_verify(pre_signature, signature_shares, message_hash, public_key)
206    }
207
208    pub fn combine_and_verify_with_specified_key(
209        pre_signature: Uint8Array,
210        signature_shares: Vec<Uint8Array>,
211        message_hash: Uint8Array,
212        public_key: Uint8Array,
213    ) -> JsResult<EcdsaSignature> {
214        let public_key: C::ProjectivePoint = Self::point_from_js(public_key)?;
215        Self::combine_and_verify(pre_signature, signature_shares, message_hash, public_key)
216    }
217
218    fn combine_and_verify(
219        pre_signature: Uint8Array,
220        signature_shares: Vec<Uint8Array>,
221        message_hash: Uint8Array,
222        public_key: C::ProjectivePoint,
223    ) -> JsResult<EcdsaSignature> {
224        let z = Self::scalar_from_hash(message_hash)?;
225        let (big_r, s) = Self::combine_inner(pre_signature, signature_shares)?;
226        let r = Self::x_coordinate(&big_r.to_affine());
227
228        if z.is_zero().into() {
229            return Err(JsError::new("invalid message digest"));
230        }
231        if (s.is_zero() | big_r.is_identity()).into() {
232            return Err(JsError::new("invalid signature"));
233        }
234        if r.is_zero().into() {
235            return Err(JsError::new("invalid r coordinate"));
236        }
237        // sR == zG * rY =
238        // (z + rx/k) * k * G == zG + rxG =
239        // (z + rx) G == (z + rx) G
240        if (big_r * s - (public_key * r + C::ProjectivePoint::generator() * z))
241            .is_identity()
242            .into()
243        {
244            Self::signature_into_js(big_r.to_affine(), s)
245        } else {
246            Err(JsError::new("invalid signature"))
247        }
248    }
249}
250
251/// Perform all three functions at once
252#[wasm_bindgen(js_name = "ecdsaCombineAndVerifyWithDerivedKey")]
253pub fn ecdsa_combine_and_verify_with_derived_key(
254    variant: EcdsaVariant,
255    pre_signature: Uint8Array,
256    signature_shares: Vec<Uint8Array>,
257    message_hash: Uint8Array,
258    id: Uint8Array,
259    public_keys: Vec<Uint8Array>,
260) -> JsResult<EcdsaSignature> {
261    match variant {
262        EcdsaVariant::K256 => Ecdsa::<Secp256k1>::combine_and_verify_with_derived_key(
263            pre_signature,
264            signature_shares,
265            message_hash,
266            id,
267            public_keys,
268        ),
269        EcdsaVariant::P256 => Ecdsa::<NistP256>::combine_and_verify_with_derived_key(
270            pre_signature,
271            signature_shares,
272            message_hash,
273            id,
274            public_keys,
275        ),
276    }
277}
278
279/// Perform combine and verify with a specified public key
280#[wasm_bindgen(js_name = "ecdsaCombineAndVerify")]
281pub fn ecdsa_combine_and_verify(
282    variant: EcdsaVariant,
283    pre_signature: Uint8Array,
284    signature_shares: Vec<Uint8Array>,
285    message_hash: Uint8Array,
286    public_key: Uint8Array,
287) -> JsResult<EcdsaSignature> {
288    match variant {
289        EcdsaVariant::K256 => Ecdsa::<Secp256k1>::combine_and_verify_with_specified_key(
290            pre_signature,
291            signature_shares,
292            message_hash,
293            public_key,
294        ),
295        EcdsaVariant::P256 => Ecdsa::<NistP256>::combine_and_verify_with_specified_key(
296            pre_signature,
297            signature_shares,
298            message_hash,
299            public_key,
300        ),
301    }
302}
303
304/// Combine ECDSA signatures shares
305#[wasm_bindgen(js_name = "ecdsaCombine")]
306pub fn ecdsa_combine(
307    variant: EcdsaVariant,
308    presignature: Uint8Array,
309    signature_shares: Vec<Uint8Array>,
310) -> JsResult<EcdsaSignature> {
311    match variant {
312        EcdsaVariant::K256 => Ecdsa::<Secp256k1>::combine(presignature, signature_shares),
313        EcdsaVariant::P256 => Ecdsa::<NistP256>::combine(presignature, signature_shares),
314    }
315}
316
317#[wasm_bindgen(js_name = "ecdsaVerify")]
318pub fn ecdsa_verify(
319    variant: EcdsaVariant,
320    message_hash: Uint8Array,
321    public_key: Uint8Array,
322    signature: EcdsaSignature,
323) -> JsResult<()> {
324    match variant {
325        EcdsaVariant::K256 => Ecdsa::<Secp256k1>::verify(message_hash, public_key, signature),
326        EcdsaVariant::P256 => Ecdsa::<NistP256>::verify(message_hash, public_key, signature),
327    }
328}
329
330#[wasm_bindgen(js_name = "ecdsaDeriveKey")]
331pub fn ecdsa_derive_key(
332    variant: EcdsaVariant,
333    id: Uint8Array,
334    public_keys: Vec<Uint8Array>,
335) -> JsResult<Uint8Array> {
336    match variant {
337        EcdsaVariant::K256 => Ecdsa::<Secp256k1>::derive_key(id, public_keys),
338        EcdsaVariant::P256 => Ecdsa::<NistP256>::derive_key(id, public_keys),
339    }
340}