af_keys/
multisig.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::hash::{Hash, Hasher};
5use std::str::FromStr;
6
7pub use enum_dispatch::enum_dispatch;
8use fastcrypto::encoding::{Base64, Encoding};
9use fastcrypto::error::FastCryptoError;
10use fastcrypto::hash::HashFunction as _;
11use fastcrypto::traits::ToFromBytes;
12use once_cell::sync::OnceCell;
13use serde::{Deserialize, Serialize};
14use serde_with::serde_as;
15use sui_sdk_types::Address as SuiAddress;
16
17use crate::crypto::{
18    CompressedSignature,
19    DefaultHash,
20    Error,
21    PublicKey,
22    Signature,
23    SignatureScheme,
24};
25
/// Weight carried by a single member public key of a multisig.
pub type WeightUnit = u8;
/// Type of the multisig threshold; weights are widened to this type before being summed.
pub type ThresholdUnit = u16;
/// Bitmap recording which member keys contributed a signature (bit `i` = key at index `i`).
pub type BitmapUnit = u16;
/// Maximum number of member public keys in a multisig.
pub const MAX_SIGNER_IN_MULTISIG: usize = 10;
/// Largest valid bitmap: one bit per possible signer (`0b1111111111` for 10 signers).
///
/// Derived from [`MAX_SIGNER_IN_MULTISIG`] so the two limits cannot drift apart; if the
/// signer limit ever outgrows `BitmapUnit`, this becomes a compile-time overflow error.
pub const MAX_BITMAP_VALUE: BitmapUnit = (1 << MAX_SIGNER_IN_MULTISIG) - 1;
31
32// =============================================================================
33//  MultiSigSigner
34// =============================================================================
35
36/// Data needed for signing as a multisig.
37#[derive(Deserialize, Debug)]
38pub struct MultiSigSigner {
39    pub multisig_pk: MultiSigPublicKey,
40    /// The indexes of the public keys in `multisig_pk` to sign for.
41    pub signers: Vec<usize>,
42}
43
44// =============================================================================
45//  MultiSig
46// =============================================================================
47
48/// The struct that contains signatures and public keys necessary for authenticating a MultiSig.
49#[serde_as]
50#[derive(Debug, Serialize, Deserialize, Clone)]
51pub struct MultiSig {
52    /// The plain signature encoded with signature scheme.
53    sigs: Vec<CompressedSignature>,
54    /// A bitmap that indicates the position of which public key the signature should be authenticated with.
55    bitmap: BitmapUnit,
56    /// The public key encoded with each public key with its signature scheme used along with the corresponding weight.
57    multisig_pk: MultiSigPublicKey,
58    /// A bytes representation of [struct MultiSig]. This helps with implementing [trait AsRef<[u8]>].
59    #[serde(skip)]
60    bytes: OnceCell<Vec<u8>>,
61}
62
63impl MultiSig {
64    /// This combines a list of [enum Signature] `flag || signature || pk` to a MultiSig.
65    /// The order of full_sigs must be the same as the order of public keys in
66    /// [enum MultiSigPublicKey]. e.g. for [pk1, pk2, pk3, pk4, pk5],
67    /// [sig1, sig2, sig5] is valid, but [sig2, sig1, sig5] is invalid.
68    pub fn combine(
69        full_sigs: Vec<Signature>,
70        multisig_pk: MultiSigPublicKey,
71    ) -> Result<Self, Error> {
72        multisig_pk
73            .validate()
74            .map_err(|_| Error::InvalidSignature {
75                error: "Invalid multisig public key".to_string(),
76            })?;
77
78        if full_sigs.len() > multisig_pk.pk_map.len() || full_sigs.is_empty() {
79            return Err(Error::InvalidSignature {
80                error: "Invalid number of signatures".to_string(),
81            });
82        }
83        let mut bitmap = 0;
84        let mut sigs = Vec::with_capacity(full_sigs.len());
85        for s in full_sigs {
86            let pk = s.to_public_key()?;
87            let index = multisig_pk
88                .get_index(&pk)
89                .ok_or_else(|| Error::IncorrectSigner {
90                    error: format!("pk does not exist: {pk:?}"),
91                })?;
92            if bitmap & (1 << index) != 0 {
93                return Err(Error::InvalidSignature {
94                    error: "Duplicate public key".to_string(),
95                });
96            }
97            bitmap |= 1 << index;
98            sigs.push(s.to_compressed()?);
99        }
100
101        Ok(Self {
102            sigs,
103            bitmap,
104            multisig_pk,
105            bytes: OnceCell::new(),
106        })
107    }
108
109    pub fn init_and_validate(&self) -> Result<Self, FastCryptoError> {
110        if self.sigs.len() > self.multisig_pk.pk_map.len()
111            || self.sigs.is_empty()
112            || self.bitmap > MAX_BITMAP_VALUE
113        {
114            return Err(FastCryptoError::InvalidInput);
115        }
116        self.multisig_pk.validate()?;
117        Ok(self.to_owned())
118    }
119
120    pub const fn get_pk(&self) -> &MultiSigPublicKey {
121        &self.multisig_pk
122    }
123
124    #[rustversion::attr(
125        stable,
126        expect(
127            clippy::missing_const_for_fn,
128            reason = "Not changing the public API right now"
129        )
130    )]
131    pub fn get_sigs(&self) -> &[CompressedSignature] {
132        &self.sigs
133    }
134
135    pub fn get_indices(&self) -> Result<Vec<u8>, Error> {
136        as_indices(self.bitmap)
137    }
138}
139
140/// Necessary trait for [struct SenderSignedData].
141impl PartialEq for MultiSig {
142    fn eq(&self, other: &Self) -> bool {
143        self.sigs == other.sigs
144            && self.bitmap == other.bitmap
145            && self.multisig_pk == other.multisig_pk
146    }
147}
148
149/// Necessary trait for [struct SenderSignedData].
150impl Eq for MultiSig {}
151
152/// Necessary trait for [struct SenderSignedData].
153impl Hash for MultiSig {
154    fn hash<H: Hasher>(&self, state: &mut H) {
155        self.as_ref().hash(state);
156    }
157}
158
159/// Interpret a bitmap of 01s as a list of indices that is set to 1s.
160/// e.g. 22 = 0b10110, then the result is [1, 2, 4].
161pub fn as_indices(bitmap: u16) -> Result<Vec<u8>, Error> {
162    if bitmap > MAX_BITMAP_VALUE {
163        return Err(Error::InvalidSignature {
164            error: "Invalid bitmap".to_string(),
165        });
166    }
167    let mut res = Vec::new();
168    for i in 0..10 {
169        if bitmap & (1 << i) != 0 {
170            res.push(i as u8);
171        }
172    }
173    Ok(res)
174}
175
176impl ToFromBytes for MultiSig {
177    fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
178        // The first byte matches the flag of MultiSig.
179        if bytes.first().ok_or(FastCryptoError::InvalidInput)? != &SignatureScheme::MultiSig.flag()
180        {
181            return Err(FastCryptoError::InvalidInput);
182        }
183        let multisig: Self =
184            bcs::from_bytes(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
185        multisig.init_and_validate()
186    }
187}
188
189impl FromStr for MultiSig {
190    type Err = Error;
191
192    fn from_str(s: &str) -> Result<Self, Self::Err> {
193        let bytes = Base64::decode(s).map_err(|_| Error::InvalidSignature {
194            error: "Invalid base64 string".to_string(),
195        })?;
196        let sig = Self::from_bytes(&bytes).map_err(|_| Error::InvalidSignature {
197            error: "Invalid multisig bytes".to_string(),
198        })?;
199        Ok(sig)
200    }
201}
202
203/// This initialize the underlying bytes representation of MultiSig. It encodes
204/// [struct MultiSig] as the MultiSig flag (0x03) concat with the bcs bytes
205/// of [struct MultiSig] i.e. `flag || bcs_bytes(MultiSig)`.
206impl AsRef<[u8]> for MultiSig {
207    fn as_ref(&self) -> &[u8] {
208        self.bytes
209            .get_or_try_init::<_, eyre::Report>(|| {
210                let as_bytes = bcs::to_bytes(self).expect("BCS serialization should not fail");
211                let mut bytes = Vec::with_capacity(1 + as_bytes.len());
212                bytes.push(SignatureScheme::MultiSig.flag());
213                bytes.extend_from_slice(as_bytes.as_slice());
214                Ok(bytes)
215            })
216            .expect("OnceCell invariant violated")
217    }
218}
219
220impl From<MultiSig> for sui_sdk_types::UserSignature {
221    fn from(value: MultiSig) -> Self {
222        Self::from_bytes(value.as_bytes()).expect("Compatible")
223    }
224}
225
226// =============================================================================
227//  MultiSigPublicKey
228// =============================================================================
229
230/// The struct that contains the public key used for authenticating a MultiSig.
231#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
232pub struct MultiSigPublicKey {
233    /// A list of public key and its corresponding weight.
234    pk_map: Vec<(PublicKey, WeightUnit)>,
235    /// If the total weight of the public keys corresponding to verified signatures is larger than threshold, the MultiSig is verified.
236    threshold: ThresholdUnit,
237}
238
239impl MultiSigPublicKey {
240    /// Construct MultiSigPublicKey without validation.
241    #[expect(
242        clippy::missing_const_for_fn,
243        reason = "Don't want to risk breaking the API if this uses a non-const init in the future"
244    )]
245    pub fn insecure_new(pk_map: Vec<(PublicKey, WeightUnit)>, threshold: ThresholdUnit) -> Self {
246        Self { pk_map, threshold }
247    }
248
249    pub fn new(
250        pks: Vec<PublicKey>,
251        weights: Vec<WeightUnit>,
252        threshold: ThresholdUnit,
253    ) -> Result<Self, Error> {
254        if pks.is_empty()
255            || weights.is_empty()
256            || threshold == 0
257            || pks.len() != weights.len()
258            || pks.len() > MAX_SIGNER_IN_MULTISIG
259            || weights.contains(&0)
260            || weights
261                .iter()
262                .map(|w| *w as ThresholdUnit)
263                .sum::<ThresholdUnit>()
264                < threshold
265            || pks
266                .iter()
267                .enumerate()
268                .any(|(i, pk)| pks.iter().skip(i + 1).any(|other_pk| *pk == *other_pk))
269        {
270            return Err(Error::InvalidSignature {
271                error: "Invalid multisig public key construction".to_string(),
272            });
273        }
274
275        Ok(Self {
276            pk_map: pks.into_iter().zip(weights).collect(),
277            threshold,
278        })
279    }
280
281    pub fn get_index(&self, pk: &PublicKey) -> Option<u8> {
282        self.pk_map.iter().position(|x| &x.0 == pk).map(|x| x as u8)
283    }
284
285    pub const fn threshold(&self) -> &ThresholdUnit {
286        &self.threshold
287    }
288
289    pub const fn pubkeys(&self) -> &Vec<(PublicKey, WeightUnit)> {
290        &self.pk_map
291    }
292
293    pub fn validate(&self) -> Result<Self, FastCryptoError> {
294        let pk_map = self.pubkeys();
295        if self.threshold == 0
296            || pk_map.is_empty()
297            || pk_map.len() > MAX_SIGNER_IN_MULTISIG
298            || pk_map.iter().any(|(_pk, weight)| *weight == 0)
299            || pk_map
300                .iter()
301                .map(|(_pk, weight)| *weight as ThresholdUnit)
302                .sum::<ThresholdUnit>()
303                < self.threshold
304            || pk_map.iter().enumerate().any(|(i, (pk, _weight))| {
305                pk_map
306                    .iter()
307                    .skip(i + 1)
308                    .any(|(other_pk, _weight)| *pk == *other_pk)
309            })
310        {
311            return Err(FastCryptoError::InvalidInput);
312        }
313        Ok(self.to_owned())
314    }
315}
316
317impl From<&MultiSigPublicKey> for SuiAddress {
318    /// Derive a SuiAddress from [struct MultiSigPublicKey]. A MultiSig address
319    /// is defined as the 32-byte Blake2b hash of serializing the flag, the
320    /// threshold, concatenation of all n flag, public keys and
321    /// its weight. `flag_MultiSig || threshold || flag_1 || pk_1 || weight_1
322    /// || ... || flag_n || pk_n || weight_n`.
323    ///
324    /// When flag_i is ZkLogin, pk_i refers to [struct ZkLoginPublicIdentifier]
325    /// derived from padded address seed in bytes and iss.
326    fn from(multisig_pk: &MultiSigPublicKey) -> Self {
327        let mut hasher = DefaultHash::default();
328        hasher.update([SignatureScheme::MultiSig.flag()]);
329        hasher.update(multisig_pk.threshold().to_le_bytes());
330        multisig_pk.pubkeys().iter().for_each(|(pk, w)| {
331            hasher.update([pk.flag()]);
332            hasher.update(pk.as_ref());
333            hasher.update(w.to_le_bytes());
334        });
335        Self::new(hasher.finalize().digest)
336    }
337}