1use std::hash::{Hash, Hasher};
5use std::str::FromStr;
6
7use af_sui_types::Address as SuiAddress;
8pub use enum_dispatch::enum_dispatch;
9use fastcrypto::encoding::{Base64, Encoding};
10use fastcrypto::error::FastCryptoError;
11use fastcrypto::hash::HashFunction as _;
12use fastcrypto::traits::ToFromBytes;
13use once_cell::sync::OnceCell;
14use serde::{Deserialize, Serialize};
15use serde_with::serde_as;
16
17use crate::crypto::{
18 CompressedSignature,
19 DefaultHash,
20 Error,
21 PublicKey,
22 Signature,
23 SignatureScheme,
24};
25
/// Weight carried by a single member of a multisig key set.
pub type WeightUnit = u8;
/// Weight threshold of a multisig key set.
pub type ThresholdUnit = u16;
/// Bitmap of participating signers; bit `i` refers to member `i` of the key set.
pub type BitmapUnit = u16;
/// Maximum number of members a multisig key set may contain.
pub const MAX_SIGNER_IN_MULTISIG: usize = 10;
/// Largest valid signer bitmap: one bit set for every possible member (`0b11_1111_1111`).
///
/// Derived from [`MAX_SIGNER_IN_MULTISIG`] so the two limits cannot drift apart.
pub const MAX_BITMAP_VALUE: BitmapUnit = (1 << MAX_SIGNER_IN_MULTISIG) - 1;

// Compile-time guard: every member index must map to a distinct bit of `BitmapUnit`, and the
// shift used to build `MAX_BITMAP_VALUE` must not overflow.
const _: () = assert!(
    MAX_SIGNER_IN_MULTISIG < BitmapUnit::BITS as usize,
    "BitmapUnit is too narrow for MAX_SIGNER_IN_MULTISIG"
);
31
32#[derive(Deserialize, Debug)]
38pub struct MultiSigSigner {
39 pub multisig_pk: MultiSigPublicKey,
40 pub signers: Vec<usize>,
42}
43
44#[serde_as]
50#[derive(Debug, Serialize, Deserialize, Clone)]
51pub struct MultiSig {
52 sigs: Vec<CompressedSignature>,
54 bitmap: BitmapUnit,
56 multisig_pk: MultiSigPublicKey,
58 #[serde(skip)]
60 bytes: OnceCell<Vec<u8>>,
61}
62
63impl MultiSig {
64 pub fn combine(
69 full_sigs: Vec<Signature>,
70 multisig_pk: MultiSigPublicKey,
71 ) -> Result<Self, Error> {
72 multisig_pk
73 .validate()
74 .map_err(|_| Error::InvalidSignature {
75 error: "Invalid multisig public key".to_string(),
76 })?;
77
78 if full_sigs.len() > multisig_pk.pk_map.len() || full_sigs.is_empty() {
79 return Err(Error::InvalidSignature {
80 error: "Invalid number of signatures".to_string(),
81 });
82 }
83 let mut bitmap = 0;
84 let mut sigs = Vec::with_capacity(full_sigs.len());
85 for s in full_sigs {
86 let pk = s.to_public_key()?;
87 let index = multisig_pk
88 .get_index(&pk)
89 .ok_or_else(|| Error::IncorrectSigner {
90 error: format!("pk does not exist: {pk:?}"),
91 })?;
92 if bitmap & (1 << index) != 0 {
93 return Err(Error::InvalidSignature {
94 error: "Duplicate public key".to_string(),
95 });
96 }
97 bitmap |= 1 << index;
98 sigs.push(s.to_compressed()?);
99 }
100
101 Ok(Self {
102 sigs,
103 bitmap,
104 multisig_pk,
105 bytes: OnceCell::new(),
106 })
107 }
108
109 pub fn init_and_validate(&self) -> Result<Self, FastCryptoError> {
110 if self.sigs.len() > self.multisig_pk.pk_map.len()
111 || self.sigs.is_empty()
112 || self.bitmap > MAX_BITMAP_VALUE
113 {
114 return Err(FastCryptoError::InvalidInput);
115 }
116 self.multisig_pk.validate()?;
117 Ok(self.to_owned())
118 }
119
120 pub const fn get_pk(&self) -> &MultiSigPublicKey {
121 &self.multisig_pk
122 }
123
124 #[rustversion::attr(
125 stable,
126 expect(
127 clippy::missing_const_for_fn,
128 reason = "Not changing the public API right now"
129 )
130 )]
131 pub fn get_sigs(&self) -> &[CompressedSignature] {
132 &self.sigs
133 }
134
135 pub fn get_indices(&self) -> Result<Vec<u8>, Error> {
136 as_indices(self.bitmap)
137 }
138}
139
140impl PartialEq for MultiSig {
142 fn eq(&self, other: &Self) -> bool {
143 self.sigs == other.sigs
144 && self.bitmap == other.bitmap
145 && self.multisig_pk == other.multisig_pk
146 }
147}
148
149impl Eq for MultiSig {}
151
152impl Hash for MultiSig {
154 fn hash<H: Hasher>(&self, state: &mut H) {
155 self.as_ref().hash(state);
156 }
157}
158
159pub fn as_indices(bitmap: u16) -> Result<Vec<u8>, Error> {
162 if bitmap > MAX_BITMAP_VALUE {
163 return Err(Error::InvalidSignature {
164 error: "Invalid bitmap".to_string(),
165 });
166 }
167 let mut res = Vec::new();
168 for i in 0..10 {
169 if bitmap & (1 << i) != 0 {
170 res.push(i as u8);
171 }
172 }
173 Ok(res)
174}
175
176impl ToFromBytes for MultiSig {
177 fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
178 if bytes.first().ok_or(FastCryptoError::InvalidInput)? != &SignatureScheme::MultiSig.flag()
180 {
181 return Err(FastCryptoError::InvalidInput);
182 }
183 let multisig: Self =
184 bcs::from_bytes(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
185 multisig.init_and_validate()
186 }
187}
188
189impl FromStr for MultiSig {
190 type Err = Error;
191
192 fn from_str(s: &str) -> Result<Self, Self::Err> {
193 let bytes = Base64::decode(s).map_err(|_| Error::InvalidSignature {
194 error: "Invalid base64 string".to_string(),
195 })?;
196 let sig = Self::from_bytes(&bytes).map_err(|_| Error::InvalidSignature {
197 error: "Invalid multisig bytes".to_string(),
198 })?;
199 Ok(sig)
200 }
201}
202
203impl AsRef<[u8]> for MultiSig {
207 fn as_ref(&self) -> &[u8] {
208 self.bytes
209 .get_or_try_init::<_, eyre::Report>(|| {
210 let as_bytes = bcs::to_bytes(self).expect("BCS serialization should not fail");
211 let mut bytes = Vec::with_capacity(1 + as_bytes.len());
212 bytes.push(SignatureScheme::MultiSig.flag());
213 bytes.extend_from_slice(as_bytes.as_slice());
214 Ok(bytes)
215 })
216 .expect("OnceCell invariant violated")
217 }
218}
219
220impl From<MultiSig> for af_sui_types::UserSignature {
221 fn from(value: MultiSig) -> Self {
222 Self::from_bytes(value.as_bytes()).expect("Compatible")
223 }
224}
225
226#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
232pub struct MultiSigPublicKey {
233 pk_map: Vec<(PublicKey, WeightUnit)>,
235 threshold: ThresholdUnit,
237}
238
239impl MultiSigPublicKey {
240 #[expect(
242 clippy::missing_const_for_fn,
243 reason = "Don't want to risk breaking the API if this uses a non-const init in the future"
244 )]
245 pub fn insecure_new(pk_map: Vec<(PublicKey, WeightUnit)>, threshold: ThresholdUnit) -> Self {
246 Self { pk_map, threshold }
247 }
248
249 pub fn new(
250 pks: Vec<PublicKey>,
251 weights: Vec<WeightUnit>,
252 threshold: ThresholdUnit,
253 ) -> Result<Self, Error> {
254 if pks.is_empty()
255 || weights.is_empty()
256 || threshold == 0
257 || pks.len() != weights.len()
258 || pks.len() > MAX_SIGNER_IN_MULTISIG
259 || weights.contains(&0)
260 || weights
261 .iter()
262 .map(|w| *w as ThresholdUnit)
263 .sum::<ThresholdUnit>()
264 < threshold
265 || pks
266 .iter()
267 .enumerate()
268 .any(|(i, pk)| pks.iter().skip(i + 1).any(|other_pk| *pk == *other_pk))
269 {
270 return Err(Error::InvalidSignature {
271 error: "Invalid multisig public key construction".to_string(),
272 });
273 }
274
275 Ok(Self {
276 pk_map: pks.into_iter().zip(weights).collect(),
277 threshold,
278 })
279 }
280
281 pub fn get_index(&self, pk: &PublicKey) -> Option<u8> {
282 self.pk_map.iter().position(|x| &x.0 == pk).map(|x| x as u8)
283 }
284
285 pub const fn threshold(&self) -> &ThresholdUnit {
286 &self.threshold
287 }
288
289 pub const fn pubkeys(&self) -> &Vec<(PublicKey, WeightUnit)> {
290 &self.pk_map
291 }
292
293 pub fn validate(&self) -> Result<Self, FastCryptoError> {
294 let pk_map = self.pubkeys();
295 if self.threshold == 0
296 || pk_map.is_empty()
297 || pk_map.len() > MAX_SIGNER_IN_MULTISIG
298 || pk_map.iter().any(|(_pk, weight)| *weight == 0)
299 || pk_map
300 .iter()
301 .map(|(_pk, weight)| *weight as ThresholdUnit)
302 .sum::<ThresholdUnit>()
303 < self.threshold
304 || pk_map.iter().enumerate().any(|(i, (pk, _weight))| {
305 pk_map
306 .iter()
307 .skip(i + 1)
308 .any(|(other_pk, _weight)| *pk == *other_pk)
309 })
310 {
311 return Err(FastCryptoError::InvalidInput);
312 }
313 Ok(self.to_owned())
314 }
315}
316
317impl From<&MultiSigPublicKey> for SuiAddress {
318 fn from(multisig_pk: &MultiSigPublicKey) -> Self {
327 let mut hasher = DefaultHash::default();
328 hasher.update([SignatureScheme::MultiSig.flag()]);
329 hasher.update(multisig_pk.threshold().to_le_bytes());
330 multisig_pk.pubkeys().iter().for_each(|(pk, w)| {
331 hasher.update([pk.flag()]);
332 hasher.update(pk.as_ref());
333 hasher.update(w.to_le_bytes());
334 });
335 Self::new(hasher.finalize().digest)
336 }
337}