1use std::hash::{Hash, Hasher};
5use std::str::FromStr;
6
7pub use enum_dispatch::enum_dispatch;
8use fastcrypto::encoding::{Base64, Encoding};
9use fastcrypto::error::FastCryptoError;
10use fastcrypto::hash::HashFunction as _;
11use fastcrypto::traits::ToFromBytes;
12use once_cell::sync::OnceCell;
13use serde::{Deserialize, Serialize};
14use serde_with::serde_as;
15use sui_sdk_types::Address as SuiAddress;
16use sui_sdk_types::bcs::{FromBcs, ToBcs};
17
18use crate::crypto::{
19 CompressedSignature,
20 DefaultHash,
21 Error,
22 PublicKey,
23 Signature,
24 SignatureScheme,
25};
26
/// Weight carried by a single member key of a multisig public key.
pub type WeightUnit = u8;
/// Type of a multisig threshold; member weights are summed in this type.
pub type ThresholdUnit = u16;
/// Signer bitmap: bit `i` set means member `i` of the multisig public key signed.
pub type BitmapUnit = u16;
/// Maximum number of member keys a multisig public key may contain.
pub const MAX_SIGNER_IN_MULTISIG: usize = 10;
/// Largest valid signer bitmap: the low `MAX_SIGNER_IN_MULTISIG` bits all set
/// (`0b11_1111_1111`). Derived from the member limit so the two cannot drift apart.
pub const MAX_BITMAP_VALUE: BitmapUnit = (1 << MAX_SIGNER_IN_MULTISIG) - 1;
32
/// A multisig public key together with the members that are expected to sign
/// with it. Only `Deserialize` is derived, so this is read from input, never written.
///
/// NOTE(review): `signers` presumably holds indices into `multisig_pk`'s member
/// list (the order returned by `MultiSigPublicKey::pubkeys`) — confirm against the
/// consumers of this type.
#[derive(Deserialize, Debug)]
pub struct MultiSigSigner {
    /// The multisig public key being signed for.
    pub multisig_pk: MultiSigPublicKey,
    /// The members that will sign; see the NOTE(review) on the struct about how
    /// they are identified.
    pub signers: Vec<usize>,
}
44
/// A multisig signature: the signatures of a subset of members plus the multisig
/// public key they were produced for.
///
/// Byte form (see the `AsRef<[u8]>` and `ToFromBytes` impls): the multisig scheme
/// flag byte followed by the BCS encoding of the serialized fields below.
#[serde_as]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MultiSig {
    /// Compressed signatures of the contributing members; `combine` stores exactly
    /// one per set bit of `bitmap`.
    // NOTE(review): the bitmap carries no ordering, so pairing a signature with its
    // member is presumably positional against the ascending indices returned by
    // `get_indices` — confirm the required order with the verifier.
    sigs: Vec<CompressedSignature>,
    /// Bit `i` is set when member `i` of `multisig_pk` contributed a signature.
    bitmap: BitmapUnit,
    /// The multisig public key (weighted members and threshold) being signed for.
    multisig_pk: MultiSigPublicKey,
    /// Lazily computed cache of the flag-prefixed byte encoding. It is skipped by
    /// serde and ignored by `PartialEq`.
    #[serde(skip)]
    bytes: OnceCell<Vec<u8>>,
}
63
64impl MultiSig {
65 pub fn combine(
70 full_sigs: Vec<Signature>,
71 multisig_pk: MultiSigPublicKey,
72 ) -> Result<Self, Error> {
73 multisig_pk
74 .validate()
75 .map_err(|_| Error::InvalidSignature {
76 error: "Invalid multisig public key".to_string(),
77 })?;
78
79 if full_sigs.len() > multisig_pk.pk_map.len() || full_sigs.is_empty() {
80 return Err(Error::InvalidSignature {
81 error: "Invalid number of signatures".to_string(),
82 });
83 }
84 let mut bitmap = 0;
85 let mut sigs = Vec::with_capacity(full_sigs.len());
86 for s in full_sigs {
87 let pk = s.to_public_key()?;
88 let index = multisig_pk
89 .get_index(&pk)
90 .ok_or_else(|| Error::IncorrectSigner {
91 error: format!("pk does not exist: {pk:?}"),
92 })?;
93 if bitmap & (1 << index) != 0 {
94 return Err(Error::InvalidSignature {
95 error: "Duplicate public key".to_string(),
96 });
97 }
98 bitmap |= 1 << index;
99 sigs.push(s.to_compressed()?);
100 }
101
102 Ok(Self {
103 sigs,
104 bitmap,
105 multisig_pk,
106 bytes: OnceCell::new(),
107 })
108 }
109
110 pub fn init_and_validate(&self) -> Result<Self, FastCryptoError> {
111 if self.sigs.len() > self.multisig_pk.pk_map.len()
112 || self.sigs.is_empty()
113 || self.bitmap > MAX_BITMAP_VALUE
114 {
115 return Err(FastCryptoError::InvalidInput);
116 }
117 self.multisig_pk.validate()?;
118 Ok(self.to_owned())
119 }
120
121 pub const fn get_pk(&self) -> &MultiSigPublicKey {
122 &self.multisig_pk
123 }
124
125 #[rustversion::attr(
126 stable,
127 expect(
128 clippy::missing_const_for_fn,
129 reason = "Not changing the public API right now"
130 )
131 )]
132 pub fn get_sigs(&self) -> &[CompressedSignature] {
133 &self.sigs
134 }
135
136 pub fn get_indices(&self) -> Result<Vec<u8>, Error> {
137 as_indices(self.bitmap)
138 }
139}
140
141impl PartialEq for MultiSig {
143 fn eq(&self, other: &Self) -> bool {
144 self.sigs == other.sigs
145 && self.bitmap == other.bitmap
146 && self.multisig_pk == other.multisig_pk
147 }
148}
149
150impl Eq for MultiSig {}
152
153impl Hash for MultiSig {
155 fn hash<H: Hasher>(&self, state: &mut H) {
156 self.as_ref().hash(state);
157 }
158}
159
160pub fn as_indices(bitmap: u16) -> Result<Vec<u8>, Error> {
163 if bitmap > MAX_BITMAP_VALUE {
164 return Err(Error::InvalidSignature {
165 error: "Invalid bitmap".to_string(),
166 });
167 }
168 let mut res = Vec::new();
169 for i in 0..10 {
170 if bitmap & (1 << i) != 0 {
171 res.push(i as u8);
172 }
173 }
174 Ok(res)
175}
176
177impl ToFromBytes for MultiSig {
178 fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
179 if bytes.first().ok_or(FastCryptoError::InvalidInput)? != &SignatureScheme::MultiSig.flag()
181 {
182 return Err(FastCryptoError::InvalidInput);
183 }
184 let multisig: Self =
185 FromBcs::from_bcs(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
186 multisig.init_and_validate()
187 }
188}
189
190impl FromStr for MultiSig {
191 type Err = Error;
192
193 fn from_str(s: &str) -> Result<Self, Self::Err> {
194 let bytes = Base64::decode(s).map_err(|_| Error::InvalidSignature {
195 error: "Invalid base64 string".to_string(),
196 })?;
197 let sig = Self::from_bytes(&bytes).map_err(|_| Error::InvalidSignature {
198 error: "Invalid multisig bytes".to_string(),
199 })?;
200 Ok(sig)
201 }
202}
203
204impl AsRef<[u8]> for MultiSig {
208 fn as_ref(&self) -> &[u8] {
209 self.bytes
210 .get_or_try_init::<_, eyre::Report>(|| {
211 let as_bytes = self.to_bcs().expect("BCS serialization should not fail");
212 let mut bytes = Vec::with_capacity(1 + as_bytes.len());
213 bytes.push(SignatureScheme::MultiSig.flag());
214 bytes.extend_from_slice(as_bytes.as_slice());
215 Ok(bytes)
216 })
217 .expect("OnceCell invariant violated")
218 }
219}
220
221impl From<MultiSig> for sui_sdk_types::UserSignature {
222 fn from(value: MultiSig) -> Self {
223 Self::from_bytes(value.as_bytes()).expect("Compatible")
224 }
225}
226
/// A multisig public key: an ordered list of weighted member keys plus a threshold.
///
/// Member order is significant in two ways:
/// - a member's position is its bit in a `MultiSig` bitmap;
/// - the members are hashed in this order when deriving the multisig address.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MultiSigPublicKey {
    /// `(public key, weight)` pairs, in member order.
    pk_map: Vec<(PublicKey, WeightUnit)>,
    /// Presumably the minimum combined member weight a valid signature must reach.
    /// Verification is not shown here, so confirm this. Validation only guarantees
    /// the threshold is non-zero and no larger than the sum of all weights.
    threshold: ThresholdUnit,
}
239
240impl MultiSigPublicKey {
241 #[expect(
243 clippy::missing_const_for_fn,
244 reason = "Don't want to risk breaking the API if this uses a non-const init in the future"
245 )]
246 pub fn insecure_new(pk_map: Vec<(PublicKey, WeightUnit)>, threshold: ThresholdUnit) -> Self {
247 Self { pk_map, threshold }
248 }
249
250 pub fn new(
251 pks: Vec<PublicKey>,
252 weights: Vec<WeightUnit>,
253 threshold: ThresholdUnit,
254 ) -> Result<Self, Error> {
255 if pks.is_empty()
256 || weights.is_empty()
257 || threshold == 0
258 || pks.len() != weights.len()
259 || pks.len() > MAX_SIGNER_IN_MULTISIG
260 || weights.contains(&0)
261 || weights
262 .iter()
263 .map(|w| *w as ThresholdUnit)
264 .sum::<ThresholdUnit>()
265 < threshold
266 || pks
267 .iter()
268 .enumerate()
269 .any(|(i, pk)| pks.iter().skip(i + 1).any(|other_pk| *pk == *other_pk))
270 {
271 return Err(Error::InvalidSignature {
272 error: "Invalid multisig public key construction".to_string(),
273 });
274 }
275
276 Ok(Self {
277 pk_map: pks.into_iter().zip(weights).collect(),
278 threshold,
279 })
280 }
281
282 pub fn get_index(&self, pk: &PublicKey) -> Option<u8> {
283 self.pk_map.iter().position(|x| &x.0 == pk).map(|x| x as u8)
284 }
285
286 pub const fn threshold(&self) -> &ThresholdUnit {
287 &self.threshold
288 }
289
290 pub const fn pubkeys(&self) -> &Vec<(PublicKey, WeightUnit)> {
291 &self.pk_map
292 }
293
294 pub fn validate(&self) -> Result<Self, FastCryptoError> {
295 let pk_map = self.pubkeys();
296 if self.threshold == 0
297 || pk_map.is_empty()
298 || pk_map.len() > MAX_SIGNER_IN_MULTISIG
299 || pk_map.iter().any(|(_pk, weight)| *weight == 0)
300 || pk_map
301 .iter()
302 .map(|(_pk, weight)| *weight as ThresholdUnit)
303 .sum::<ThresholdUnit>()
304 < self.threshold
305 || pk_map.iter().enumerate().any(|(i, (pk, _weight))| {
306 pk_map
307 .iter()
308 .skip(i + 1)
309 .any(|(other_pk, _weight)| *pk == *other_pk)
310 })
311 {
312 return Err(FastCryptoError::InvalidInput);
313 }
314 Ok(self.to_owned())
315 }
316}
317
318impl From<&MultiSigPublicKey> for SuiAddress {
319 fn from(multisig_pk: &MultiSigPublicKey) -> Self {
328 let mut hasher = DefaultHash::default();
329 hasher.update([SignatureScheme::MultiSig.flag()]);
330 hasher.update(multisig_pk.threshold().to_le_bytes());
331 multisig_pk.pubkeys().iter().for_each(|(pk, w)| {
332 hasher.update([pk.flag()]);
333 hasher.update(pk.as_ref());
334 hasher.update(w.to_le_bytes());
335 });
336 Self::new(hasher.finalize().digest)
337 }
338}