af_keys/
multisig.rs

// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0
3
4use std::hash::{Hash, Hasher};
5use std::str::FromStr;
6
7pub use enum_dispatch::enum_dispatch;
8use fastcrypto::encoding::{Base64, Encoding};
9use fastcrypto::error::FastCryptoError;
10use fastcrypto::hash::HashFunction as _;
11use fastcrypto::traits::ToFromBytes;
12use once_cell::sync::OnceCell;
13use serde::{Deserialize, Serialize};
14use serde_with::serde_as;
15use sui_sdk_types::Address as SuiAddress;
16use sui_sdk_types::bcs::{FromBcs, ToBcs};
17
18use crate::crypto::{
19    CompressedSignature,
20    DefaultHash,
21    Error,
22    PublicKey,
23    Signature,
24    SignatureScheme,
25};
26
/// Weight carried by a single public key of a multisig committee.
pub type WeightUnit = u8;
/// Type of the multisig threshold; weights are widened to it when summed.
pub type ThresholdUnit = u16;
/// Bitmap recording which committee members contributed a signature.
pub type BitmapUnit = u16;
/// Maximum number of public keys a multisig committee may contain.
pub const MAX_SIGNER_IN_MULTISIG: usize = 10;
/// Largest valid bitmap: one set bit per possible signer (`0b1111111111`).
pub const MAX_BITMAP_VALUE: BitmapUnit = (1 << MAX_SIGNER_IN_MULTISIG) - 1;
32
// =============================================================================
//  MultiSigSigner
// =============================================================================
36
37/// Data needed for signing as a multisig.
38#[derive(Deserialize, Debug)]
39pub struct MultiSigSigner {
40    pub multisig_pk: MultiSigPublicKey,
41    /// The indexes of the public keys in `multisig_pk` to sign for.
42    pub signers: Vec<usize>,
43}
44
// =============================================================================
//  MultiSig
// =============================================================================
48
49/// The struct that contains signatures and public keys necessary for authenticating a MultiSig.
50#[serde_as]
51#[derive(Debug, Serialize, Deserialize, Clone)]
52pub struct MultiSig {
53    /// The plain signature encoded with signature scheme.
54    sigs: Vec<CompressedSignature>,
55    /// A bitmap that indicates the position of which public key the signature should be authenticated with.
56    bitmap: BitmapUnit,
57    /// The public key encoded with each public key with its signature scheme used along with the corresponding weight.
58    multisig_pk: MultiSigPublicKey,
59    /// A bytes representation of [struct MultiSig]. This helps with implementing [trait AsRef<[u8]>].
60    #[serde(skip)]
61    bytes: OnceCell<Vec<u8>>,
62}
63
64impl MultiSig {
65    /// This combines a list of [enum Signature] `flag || signature || pk` to a MultiSig.
66    /// The order of full_sigs must be the same as the order of public keys in
67    /// [enum MultiSigPublicKey]. e.g. for [pk1, pk2, pk3, pk4, pk5],
68    /// [sig1, sig2, sig5] is valid, but [sig2, sig1, sig5] is invalid.
69    pub fn combine(
70        full_sigs: Vec<Signature>,
71        multisig_pk: MultiSigPublicKey,
72    ) -> Result<Self, Error> {
73        multisig_pk
74            .validate()
75            .map_err(|_| Error::InvalidSignature {
76                error: "Invalid multisig public key".to_string(),
77            })?;
78
79        if full_sigs.len() > multisig_pk.pk_map.len() || full_sigs.is_empty() {
80            return Err(Error::InvalidSignature {
81                error: "Invalid number of signatures".to_string(),
82            });
83        }
84        let mut bitmap = 0;
85        let mut sigs = Vec::with_capacity(full_sigs.len());
86        for s in full_sigs {
87            let pk = s.to_public_key()?;
88            let index = multisig_pk
89                .get_index(&pk)
90                .ok_or_else(|| Error::IncorrectSigner {
91                    error: format!("pk does not exist: {pk:?}"),
92                })?;
93            if bitmap & (1 << index) != 0 {
94                return Err(Error::InvalidSignature {
95                    error: "Duplicate public key".to_string(),
96                });
97            }
98            bitmap |= 1 << index;
99            sigs.push(s.to_compressed()?);
100        }
101
102        Ok(Self {
103            sigs,
104            bitmap,
105            multisig_pk,
106            bytes: OnceCell::new(),
107        })
108    }
109
110    pub fn init_and_validate(&self) -> Result<Self, FastCryptoError> {
111        if self.sigs.len() > self.multisig_pk.pk_map.len()
112            || self.sigs.is_empty()
113            || self.bitmap > MAX_BITMAP_VALUE
114        {
115            return Err(FastCryptoError::InvalidInput);
116        }
117        self.multisig_pk.validate()?;
118        Ok(self.to_owned())
119    }
120
121    pub const fn get_pk(&self) -> &MultiSigPublicKey {
122        &self.multisig_pk
123    }
124
125    #[rustversion::attr(
126        stable,
127        expect(
128            clippy::missing_const_for_fn,
129            reason = "Not changing the public API right now"
130        )
131    )]
132    pub fn get_sigs(&self) -> &[CompressedSignature] {
133        &self.sigs
134    }
135
136    pub fn get_indices(&self) -> Result<Vec<u8>, Error> {
137        as_indices(self.bitmap)
138    }
139}
140
141/// Necessary trait for [struct SenderSignedData].
142impl PartialEq for MultiSig {
143    fn eq(&self, other: &Self) -> bool {
144        self.sigs == other.sigs
145            && self.bitmap == other.bitmap
146            && self.multisig_pk == other.multisig_pk
147    }
148}
149
150/// Necessary trait for [struct SenderSignedData].
151impl Eq for MultiSig {}
152
153/// Necessary trait for [struct SenderSignedData].
154impl Hash for MultiSig {
155    fn hash<H: Hasher>(&self, state: &mut H) {
156        self.as_ref().hash(state);
157    }
158}
159
160/// Interpret a bitmap of 01s as a list of indices that is set to 1s.
161/// e.g. 22 = 0b10110, then the result is [1, 2, 4].
162pub fn as_indices(bitmap: u16) -> Result<Vec<u8>, Error> {
163    if bitmap > MAX_BITMAP_VALUE {
164        return Err(Error::InvalidSignature {
165            error: "Invalid bitmap".to_string(),
166        });
167    }
168    let mut res = Vec::new();
169    for i in 0..10 {
170        if bitmap & (1 << i) != 0 {
171            res.push(i as u8);
172        }
173    }
174    Ok(res)
175}
176
177impl ToFromBytes for MultiSig {
178    fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
179        // The first byte matches the flag of MultiSig.
180        if bytes.first().ok_or(FastCryptoError::InvalidInput)? != &SignatureScheme::MultiSig.flag()
181        {
182            return Err(FastCryptoError::InvalidInput);
183        }
184        let multisig: Self =
185            FromBcs::from_bcs(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
186        multisig.init_and_validate()
187    }
188}
189
190impl FromStr for MultiSig {
191    type Err = Error;
192
193    fn from_str(s: &str) -> Result<Self, Self::Err> {
194        let bytes = Base64::decode(s).map_err(|_| Error::InvalidSignature {
195            error: "Invalid base64 string".to_string(),
196        })?;
197        let sig = Self::from_bytes(&bytes).map_err(|_| Error::InvalidSignature {
198            error: "Invalid multisig bytes".to_string(),
199        })?;
200        Ok(sig)
201    }
202}
203
204/// This initialize the underlying bytes representation of MultiSig. It encodes
205/// [struct MultiSig] as the MultiSig flag (0x03) concat with the bcs bytes
206/// of [struct MultiSig] i.e. `flag || bcs_bytes(MultiSig)`.
207impl AsRef<[u8]> for MultiSig {
208    fn as_ref(&self) -> &[u8] {
209        self.bytes
210            .get_or_try_init::<_, eyre::Report>(|| {
211                let as_bytes = self.to_bcs().expect("BCS serialization should not fail");
212                let mut bytes = Vec::with_capacity(1 + as_bytes.len());
213                bytes.push(SignatureScheme::MultiSig.flag());
214                bytes.extend_from_slice(as_bytes.as_slice());
215                Ok(bytes)
216            })
217            .expect("OnceCell invariant violated")
218    }
219}
220
221impl From<MultiSig> for sui_sdk_types::UserSignature {
222    fn from(value: MultiSig) -> Self {
223        Self::from_bytes(value.as_bytes()).expect("Compatible")
224    }
225}
226
// =============================================================================
//  MultiSigPublicKey
// =============================================================================
230
231/// The struct that contains the public key used for authenticating a MultiSig.
232#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
233pub struct MultiSigPublicKey {
234    /// A list of public key and its corresponding weight.
235    pk_map: Vec<(PublicKey, WeightUnit)>,
236    /// If the total weight of the public keys corresponding to verified signatures is larger than threshold, the MultiSig is verified.
237    threshold: ThresholdUnit,
238}
239
240impl MultiSigPublicKey {
241    /// Construct MultiSigPublicKey without validation.
242    #[expect(
243        clippy::missing_const_for_fn,
244        reason = "Don't want to risk breaking the API if this uses a non-const init in the future"
245    )]
246    pub fn insecure_new(pk_map: Vec<(PublicKey, WeightUnit)>, threshold: ThresholdUnit) -> Self {
247        Self { pk_map, threshold }
248    }
249
250    pub fn new(
251        pks: Vec<PublicKey>,
252        weights: Vec<WeightUnit>,
253        threshold: ThresholdUnit,
254    ) -> Result<Self, Error> {
255        if pks.is_empty()
256            || weights.is_empty()
257            || threshold == 0
258            || pks.len() != weights.len()
259            || pks.len() > MAX_SIGNER_IN_MULTISIG
260            || weights.contains(&0)
261            || weights
262                .iter()
263                .map(|w| *w as ThresholdUnit)
264                .sum::<ThresholdUnit>()
265                < threshold
266            || pks
267                .iter()
268                .enumerate()
269                .any(|(i, pk)| pks.iter().skip(i + 1).any(|other_pk| *pk == *other_pk))
270        {
271            return Err(Error::InvalidSignature {
272                error: "Invalid multisig public key construction".to_string(),
273            });
274        }
275
276        Ok(Self {
277            pk_map: pks.into_iter().zip(weights).collect(),
278            threshold,
279        })
280    }
281
282    pub fn get_index(&self, pk: &PublicKey) -> Option<u8> {
283        self.pk_map.iter().position(|x| &x.0 == pk).map(|x| x as u8)
284    }
285
286    pub const fn threshold(&self) -> &ThresholdUnit {
287        &self.threshold
288    }
289
290    pub const fn pubkeys(&self) -> &Vec<(PublicKey, WeightUnit)> {
291        &self.pk_map
292    }
293
294    pub fn validate(&self) -> Result<Self, FastCryptoError> {
295        let pk_map = self.pubkeys();
296        if self.threshold == 0
297            || pk_map.is_empty()
298            || pk_map.len() > MAX_SIGNER_IN_MULTISIG
299            || pk_map.iter().any(|(_pk, weight)| *weight == 0)
300            || pk_map
301                .iter()
302                .map(|(_pk, weight)| *weight as ThresholdUnit)
303                .sum::<ThresholdUnit>()
304                < self.threshold
305            || pk_map.iter().enumerate().any(|(i, (pk, _weight))| {
306                pk_map
307                    .iter()
308                    .skip(i + 1)
309                    .any(|(other_pk, _weight)| *pk == *other_pk)
310            })
311        {
312            return Err(FastCryptoError::InvalidInput);
313        }
314        Ok(self.to_owned())
315    }
316}
317
318impl From<&MultiSigPublicKey> for SuiAddress {
319    /// Derive a SuiAddress from [struct MultiSigPublicKey]. A MultiSig address
320    /// is defined as the 32-byte Blake2b hash of serializing the flag, the
321    /// threshold, concatenation of all n flag, public keys and
322    /// its weight. `flag_MultiSig || threshold || flag_1 || pk_1 || weight_1
323    /// || ... || flag_n || pk_n || weight_n`.
324    ///
325    /// When flag_i is ZkLogin, pk_i refers to [struct ZkLoginPublicIdentifier]
326    /// derived from padded address seed in bytes and iss.
327    fn from(multisig_pk: &MultiSigPublicKey) -> Self {
328        let mut hasher = DefaultHash::default();
329        hasher.update([SignatureScheme::MultiSig.flag()]);
330        hasher.update(multisig_pk.threshold().to_le_bytes());
331        multisig_pk.pubkeys().iter().for_each(|(pk, w)| {
332            hasher.update([pk.flag()]);
333            hasher.update(pk.as_ref());
334            hasher.update(w.to_le_bytes());
335        });
336        Self::new(hasher.finalize().digest)
337    }
338}