cml_chain/
utils.rs

1use cbor_event::{de::Deserializer, se::Serializer, Sz};
2use cml_core::{
3    error::{DeserializeError, DeserializeFailure},
4    serialization::{fit_sz, sz_max, Deserialize, LenEncoding, Serialize},
5    Int, Slot,
6};
7use cml_crypto::{Ed25519KeyHash, RawBytesEncoding, ScriptHash};
8use derivative::Derivative;
9use std::iter::IntoIterator;
10use std::{
11    convert::TryFrom,
12    io::{BufRead, Seek, Write},
13};
14
15use crate::{
16    crypto::hash::{hash_script, ScriptHashNamespace},
17    plutus::{Language, PlutusScript, PlutusV1Script, PlutusV2Script, PlutusV3Script},
18    NativeScript, Script, SubCoin,
19};
20
21impl Script {
22    pub fn hash(&self) -> ScriptHash {
23        match self {
24            Self::Native { script, .. } => script.hash(),
25            Self::PlutusV1 { script, .. } => script.hash(),
26            Self::PlutusV2 { script, .. } => script.hash(),
27            Self::PlutusV3 { script, .. } => script.hash(),
28        }
29    }
30
31    pub fn raw_plutus_bytes(&self) -> Result<&[u8], ScriptConversionError> {
32        match self {
33            Self::Native { .. } => Err(ScriptConversionError::NativeScriptNotPlutus),
34            Self::PlutusV1 { script, .. } => Ok(script.to_raw_bytes()),
35            Self::PlutusV2 { script, .. } => Ok(script.to_raw_bytes()),
36            Self::PlutusV3 { script, .. } => Ok(script.to_raw_bytes()),
37        }
38    }
39
40    // Returns which language the script is if it's a Plutus script
41    // Returns None otherwise (i.e. NativeScript)
42    pub fn language(&self) -> Option<Language> {
43        match self {
44            Self::Native { .. } => None,
45            Self::PlutusV1 { .. } => Some(Language::PlutusV1),
46            Self::PlutusV2 { .. } => Some(Language::PlutusV2),
47            Self::PlutusV3 { .. } => Some(Language::PlutusV3),
48        }
49    }
50}
51
52impl NativeScript {
53    pub fn hash(&self) -> ScriptHash {
54        hash_script(ScriptHashNamespace::NativeScript, &self.to_cbor_bytes())
55    }
56
57    pub fn verify(
58        &self,
59        lower_bound: Option<Slot>,
60        upper_bound: Option<Slot>,
61        key_hashes: &Vec<Ed25519KeyHash>,
62    ) -> bool {
63        fn verify_helper(
64            script: &NativeScript,
65            lower_bound: Option<Slot>,
66            upper_bound: Option<Slot>,
67            key_hashes: &Vec<Ed25519KeyHash>,
68        ) -> bool {
69            match &script {
70                NativeScript::ScriptPubkey(pub_key) => {
71                    key_hashes.contains(&pub_key.ed25519_key_hash)
72                }
73                NativeScript::ScriptAll(script_all) => {
74                    script_all.native_scripts.iter().all(|sub_script| {
75                        verify_helper(sub_script, lower_bound, upper_bound, key_hashes)
76                    })
77                }
78                NativeScript::ScriptAny(script_any) => {
79                    script_any.native_scripts.iter().any(|sub_script| {
80                        verify_helper(sub_script, lower_bound, upper_bound, key_hashes)
81                    })
82                }
83                NativeScript::ScriptNOfK(script_atleast) => {
84                    script_atleast
85                        .native_scripts
86                        .iter()
87                        .map(|sub_script| {
88                            verify_helper(sub_script, lower_bound, upper_bound, key_hashes)
89                        })
90                        .filter(|r| *r)
91                        .count()
92                        >= script_atleast.n as usize
93                }
94                NativeScript::ScriptInvalidBefore(timelock_start) => match lower_bound {
95                    Some(tx_slot) => tx_slot >= timelock_start.before,
96                    _ => false,
97                },
98                NativeScript::ScriptInvalidHereafter(timelock_expiry) => match upper_bound {
99                    Some(tx_slot) => tx_slot < timelock_expiry.after,
100                    _ => false,
101                },
102            }
103        }
104
105        verify_helper(self, lower_bound, upper_bound, key_hashes)
106    }
107}
108
109impl From<NativeScript> for Script {
110    fn from(script: NativeScript) -> Self {
111        Self::new_native(script)
112    }
113}
114
115impl From<PlutusV1Script> for Script {
116    fn from(script: PlutusV1Script) -> Self {
117        Self::new_plutus_v1(script)
118    }
119}
120
121impl From<PlutusV2Script> for Script {
122    fn from(script: PlutusV2Script) -> Self {
123        Self::new_plutus_v2(script)
124    }
125}
126
127impl From<PlutusV3Script> for Script {
128    fn from(script: PlutusV3Script) -> Self {
129        Self::new_plutus_v3(script)
130    }
131}
132
133impl From<PlutusScript> for Script {
134    fn from(script: PlutusScript) -> Self {
135        match script {
136            PlutusScript::PlutusV1(v1) => Self::new_plutus_v1(v1),
137            PlutusScript::PlutusV2(v2) => Self::new_plutus_v2(v2),
138            PlutusScript::PlutusV3(v3) => Self::new_plutus_v3(v3),
139        }
140    }
141}
142
143#[derive(Debug, thiserror::Error)]
144pub enum ScriptConversionError {
145    #[error("Cannot convert NativeScript to PlutusScript")]
146    NativeScriptNotPlutus,
147}
148
149impl TryFrom<Script> for PlutusScript {
150    type Error = ScriptConversionError;
151
152    fn try_from(script: Script) -> Result<PlutusScript, Self::Error> {
153        match script {
154            Script::Native { .. } => Err(ScriptConversionError::NativeScriptNotPlutus),
155            Script::PlutusV1 { script, .. } => Ok(PlutusScript::PlutusV1(script)),
156            Script::PlutusV2 { script, .. } => Ok(PlutusScript::PlutusV2(script)),
157            Script::PlutusV3 { script, .. } => Ok(PlutusScript::PlutusV3(script)),
158        }
159    }
160}
161
/// Maximum length of a definite-length `bounded_bytes` string, and of every chunk of an
/// indefinite-length one.
const BOUNDED_BYTES_CHUNK_SIZE: usize = 64;
163
164// to get around not having access from outside the library we just write the raw CBOR indefinite byte string code here
165fn write_cbor_indefinite_byte_tag<W: Write>(
166    serializer: &mut Serializer<W>,
167) -> cbor_event::Result<&mut Serializer<W>> {
168    serializer.write_raw_bytes(&[0x5f])
169}
170
171use cml_core::serialization::StringEncoding;
172
173fn valid_indefinite_string_encoding(chunks: &[(u64, cbor_event::Sz)], total_len: usize) -> bool {
174    let mut len_counter = 0;
175    let valid_sz = chunks.iter().all(|(len, sz)| {
176        len_counter += len;
177        *len <= sz_max(*sz)
178    });
179    valid_sz && len_counter == total_len as u64
180}
181
182/// Write bounded bytes according to Cardano's special format:
183/// bounded_bytes = bytes .size (0..64)
184///  ; the real bounded_bytes does not have this limit. it instead has a different
185///   ; limit which cannot be expressed in CDDL.
186///   ; The limit is as follows:
187///   ;  - bytes with a definite-length encoding are limited to size 0..64
188///   ;  - for bytes with an indefinite-length CBOR encoding, each chunk is
189///   ;    limited to size 0..64
190///   ;  ( reminder: in CBOR, the indefinite-length encoding of bytestrings
191///   ;    consists of a token #2.31 followed by a sequence of definite-length
192///   ;    encoded bytestrings and a stop code )
193pub fn write_bounded_bytes<'se, W: Write>(
194    serializer: &'se mut Serializer<W>,
195    bytes: &[u8],
196    enc: &StringEncoding,
197    force_canonical: bool,
198) -> cbor_event::Result<&'se mut Serializer<W>> {
199    match enc {
200        StringEncoding::Definite(sz) if !force_canonical => {
201            if bytes.len() <= BOUNDED_BYTES_CHUNK_SIZE {
202                let fit_sz = fit_sz(bytes.len() as u64, Some(*sz), force_canonical);
203                return serializer.write_bytes_sz(bytes, cbor_event::StringLenSz::Len(fit_sz));
204            }
205        }
206        StringEncoding::Indefinite(chunks) if !force_canonical => {
207            if valid_indefinite_string_encoding(chunks, bytes.len()) {
208                write_cbor_indefinite_byte_tag(serializer)?;
209                let mut start = 0;
210                for (len, sz) in chunks {
211                    let end = start + *len as usize;
212                    serializer
213                        .write_bytes_sz(&bytes[start..end], cbor_event::StringLenSz::Len(*sz))?;
214                    start = end;
215                }
216                return serializer.write_special(cbor_event::Special::Break);
217            }
218        }
219        _ =>
220            /* handled below */
221            {}
222    };
223    // This is a fallback for when either it's canonical or the passed in encoding isn't
224    // compatible with the passed in bytes (e.g. someone deserialized then modified the bytes)
225    // If we truly need to encode canonical CBOR there's really no way to abide by both canonical
226    // CBOR as well as following the Cardano format. So this is the best attempt at it while keeping
227    // chunks when len > 64
228    if bytes.len() <= BOUNDED_BYTES_CHUNK_SIZE {
229        serializer.write_bytes(bytes)
230    } else {
231        write_cbor_indefinite_byte_tag(serializer)?;
232        for chunk in bytes.chunks(BOUNDED_BYTES_CHUNK_SIZE) {
233            serializer.write_bytes(chunk)?;
234        }
235        serializer.write_special(cbor_event::Special::Break)
236    }
237}
238
239/// Read bounded bytes according to Cardano's special format:
240/// bounded_bytes = bytes .size (0..64)
241///  ; the real bounded_bytes does not have this limit. it instead has a different
242///  ; limit which cannot be expressed in CDDL.
243///  ; The limit is as follows:
244///  ;  - bytes with a definite-length encoding are limited to size 0..64
245///  ;  - for bytes with an indefinite-length CBOR encoding, each chunk is
246///  ;    limited to size 0..64
247///  ;  ( reminder: in CBOR, the indefinite-length encoding of bytestrings
248///  ;    consists of a token #2.31 followed by a sequence of definite-length
249///  ;    encoded bytestrings and a stop code )
250pub fn read_bounded_bytes<R: BufRead + Seek>(
251    raw: &mut Deserializer<R>,
252) -> Result<(Vec<u8>, StringEncoding), DeserializeError> {
253    let (bytes, bytes_enc) = raw.bytes_sz()?;
254    match &bytes_enc {
255        cbor_event::StringLenSz::Len(_sz) => {
256            if bytes.len() > BOUNDED_BYTES_CHUNK_SIZE {
257                return Err(DeserializeFailure::OutOfRange {
258                    min: 0,
259                    max: BOUNDED_BYTES_CHUNK_SIZE,
260                    found: bytes.len(),
261                }
262                .into());
263            }
264        }
265        cbor_event::StringLenSz::Indefinite(chunks) => {
266            for (chunk_len, _chunk_len_sz) in chunks.iter() {
267                if *chunk_len as usize > BOUNDED_BYTES_CHUNK_SIZE {
268                    return Err(DeserializeFailure::OutOfRange {
269                        min: 0,
270                        max: BOUNDED_BYTES_CHUNK_SIZE,
271                        found: *chunk_len as usize,
272                    }
273                    .into());
274                }
275            }
276        }
277    }
278    Ok((bytes, bytes_enc.into()))
279}
280
281#[derive(Clone, Debug)]
282enum BigIntEncoding {
283    Int(cbor_event::Sz),
284    Bytes(StringEncoding),
285}
286
287#[derive(Clone, Debug, Derivative)]
288#[derivative(Eq, PartialEq, Ord, PartialOrd, Hash)]
289pub struct BigInteger {
290    pub(crate) num: num_bigint::BigInt,
291    #[derivative(
292        PartialEq = "ignore",
293        Ord = "ignore",
294        PartialOrd = "ignore",
295        Hash = "ignore"
296    )]
297    encoding: Option<BigIntEncoding>,
298}
299
300impl serde::Serialize for BigInteger {
301    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
302    where
303        S: serde::Serializer,
304    {
305        serializer.serialize_str(&self.to_string())
306    }
307}
308
309impl<'de> serde::de::Deserialize<'de> for BigInteger {
310    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
311    where
312        D: serde::de::Deserializer<'de>,
313    {
314        use std::str::FromStr;
315        let s = <String as serde::de::Deserialize>::deserialize(deserializer)?;
316        BigInteger::from_str(&s).map_err(|_e| {
317            serde::de::Error::invalid_value(
318                serde::de::Unexpected::Str(&s),
319                &"string rep of a big int",
320            )
321        })
322    }
323}
324
325impl schemars::JsonSchema for BigInteger {
326    fn schema_name() -> String {
327        String::from("BigInteger")
328    }
329    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
330        String::json_schema(gen)
331    }
332    fn is_referenceable() -> bool {
333        String::is_referenceable()
334    }
335}
336
337impl std::fmt::Display for BigInteger {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        self.num.fmt(f)
340    }
341}
342
343impl std::str::FromStr for BigInteger {
344    type Err = num_bigint::ParseBigIntError;
345    fn from_str(string: &str) -> Result<Self, Self::Err> {
346        num_bigint::BigInt::from_str(string).map(|num| Self {
347            num,
348            encoding: None,
349        })
350    }
351}
352
353impl BigInteger {
354    // can't be a trait due to being in other crate
355    pub fn from_int(x: &Int) -> Self {
356        Self {
357            num: Into::<i128>::into(x).into(),
358            encoding: x.encoding().map(BigIntEncoding::Int),
359        }
360    }
361
362    /// Converts to a u64
363    /// Returns None if the number was negative or too big for a u64
364    pub fn as_u64(&self) -> Option<u64> {
365        let (sign, u64_digits) = self.num.to_u64_digits();
366        if sign == num_bigint::Sign::Minus {
367            return None;
368        }
369        match u64_digits.len() {
370            0 => Some(0),
371            1 => Some(*u64_digits.first().unwrap()),
372            _ => None,
373        }
374    }
375
376    /// Converts to a u128
377    /// Returns None if the number was negative or too big for a u128
378    pub fn as_u128(&self) -> Option<u128> {
379        let (sign, u32_digits) = self.num.to_u32_digits();
380        if sign == num_bigint::Sign::Minus {
381            return None;
382        }
383        match *u32_digits {
384            [] => Some(0),
385            [a] => Some(u128::from(a)),
386            [a, b] => Some(u128::from(a) | (u128::from(b) << 32)),
387            [a, b, c] => Some(u128::from(a) | (u128::from(b) << 32) | (u128::from(c) << 64)),
388            [a, b, c, d] => Some(
389                u128::from(a)
390                    | (u128::from(b) << 32)
391                    | (u128::from(c) << 64)
392                    | (u128::from(d) << 96),
393            ),
394            _ => None,
395        }
396    }
397
398    /// Converts to an Int
399    /// Returns None when the number is too big for an Int (outside +/- 64-bit unsigned)
400    /// Retains encoding info if the original was encoded as an Int
401    pub fn as_int(&self) -> Option<Int> {
402        let (sign, u64_digits) = self.num.to_u64_digits();
403        // unsigned raw value that can fit in the up to 8 bytes of a CBOR uint or nint
404        // negative values evaluate to -u64_value - 1
405        let u64_value = match u64_digits.len() {
406            0 => 0u64,
407            1 => {
408                if sign == num_bigint::Sign::Minus {
409                    (*u64_digits.first().unwrap())
410                        .checked_sub(1)
411                        .expect("negative (non-zero) so can't underflow")
412                } else {
413                    *u64_digits.first().unwrap()
414                }
415            }
416            // this could actually be -u64::MAX which in CBOR can be a single u64 as the sign
417            // is encoded separately so values here start from -1 instead of 0.
418            2 if sign == num_bigint::Sign::Minus && u64_digits[0] == 0 && u64_digits[1] == 1 => {
419                u64::MAX
420            }
421            _ => return None,
422        };
423        let encoding = match &self.encoding {
424            Some(BigIntEncoding::Int(sz)) => Some(*sz),
425            _ => None,
426        };
427        match sign {
428            num_bigint::Sign::NoSign | num_bigint::Sign::Plus => Some(Int::Uint {
429                value: u64_value,
430                encoding,
431            }),
432            num_bigint::Sign::Minus => Some(Int::Nint {
433                value: u64_value,
434                encoding,
435            }),
436        }
437    }
438}
439
440impl Serialize for BigInteger {
441    fn serialize<'se, W: Write>(
442        &self,
443        serializer: &'se mut Serializer<W>,
444        force_canonical: bool,
445    ) -> cbor_event::Result<&'se mut Serializer<W>> {
446        let write_self_as_bytes = |serializer: &'se mut Serializer<W>,
447                                   enc: &StringEncoding|
448         -> cbor_event::Result<&'se mut Serializer<W>> {
449            let (sign, bytes) = self.num.to_bytes_be();
450            match sign {
451                // positive bigint
452                num_bigint::Sign::Plus | num_bigint::Sign::NoSign => {
453                    serializer.write_tag(2u64)?;
454                    write_bounded_bytes(serializer, &bytes, enc, force_canonical)
455                }
456                // negative bigint
457                num_bigint::Sign::Minus => {
458                    serializer.write_tag(3u64)?;
459                    use std::ops::Neg;
460                    // CBOR RFC defines this as the bytes of -n -1
461                    let adjusted = self
462                        .num
463                        .clone()
464                        .neg()
465                        .checked_sub(&num_bigint::BigInt::from(1u32))
466                        .unwrap()
467                        .to_biguint()
468                        .unwrap();
469                    write_bounded_bytes(serializer, &adjusted.to_bytes_be(), enc, force_canonical)
470                }
471            }
472        };
473        // use encoding if possible
474        match &self.encoding {
475            Some(BigIntEncoding::Int(_sz)) if !force_canonical => {
476                // as_int() retains encoding info so we can direclty use Int::serialize()
477                if let Some(int) = self.as_int() {
478                    return int.serialize(serializer, force_canonical);
479                }
480            }
481            Some(BigIntEncoding::Bytes(str_enc)) if !force_canonical => {
482                let (_sign, bytes) = self.num.to_bytes_be();
483                let valid_non_canonical = match str_enc {
484                    StringEncoding::Canonical => false,
485                    StringEncoding::Definite(sz) => bytes.len() <= sz_max(*sz) as usize,
486                    StringEncoding::Indefinite(chunks) => {
487                        valid_indefinite_string_encoding(chunks, bytes.len())
488                    }
489                };
490                if valid_non_canonical {
491                    return write_self_as_bytes(serializer, str_enc);
492                }
493            }
494            _ =>
495                /* always fallback to default */
496                {}
497        }
498        // fallback for:
499        // 1) canonical bytes needed
500        // 2) no encoding specified (never deseiralized)
501        // 3) deserialized but data changed and no longer compatible
502        let (sign, u64_digits) = self.num.to_u64_digits();
503        match u64_digits.len() {
504            0 => serializer.write_unsigned_integer(0),
505            // we use the uint/nint encodings to use a minimum of space
506            1 => match sign {
507                // uint
508                num_bigint::Sign::Plus | num_bigint::Sign::NoSign => {
509                    serializer.write_unsigned_integer(*u64_digits.first().unwrap())
510                }
511                // nint
512                num_bigint::Sign::Minus => serializer
513                    .write_negative_integer(-(*u64_digits.first().unwrap() as i128) as i64),
514            },
515            _ => {
516                // Small edge case: nint's minimum is -18446744073709551616 but in this bigint lib
517                // that takes 2 u64 bytes so we put that as a special case here:
518                if sign == num_bigint::Sign::Minus && u64_digits == vec![0, 1] {
519                    serializer.write_negative_integer(-18446744073709551616i128 as i64)
520                } else {
521                    write_self_as_bytes(serializer, &StringEncoding::Canonical)
522                }
523            }
524        }
525    }
526}
527
528impl Deserialize for BigInteger {
529    fn deserialize<R: BufRead + Seek>(raw: &mut Deserializer<R>) -> Result<Self, DeserializeError> {
530        (|| -> Result<_, DeserializeError> {
531            match raw.cbor_type()? {
532                // bigint
533                cbor_event::Type::Tag => {
534                    let tag = raw.tag()?;
535                    let (bytes, bytes_enc) = read_bounded_bytes(raw)?;
536                    match tag {
537                        // positive bigint
538                        2 => Ok(Self {
539                            num: num_bigint::BigInt::from_bytes_be(num_bigint::Sign::Plus, &bytes),
540                            encoding: Some(BigIntEncoding::Bytes(bytes_enc)),
541                        }),
542                        // negative bigint
543                        3 => {
544                            // CBOR RFC defines this as the bytes of -n -1
545                            let initial =
546                                num_bigint::BigInt::from_bytes_be(num_bigint::Sign::Plus, &bytes);
547                            use std::ops::Neg;
548                            let adjusted = initial
549                                .checked_add(&num_bigint::BigInt::from(1u32))
550                                .unwrap()
551                                .neg();
552                            Ok(Self {
553                                num: adjusted,
554                                encoding: Some(BigIntEncoding::Bytes(bytes_enc)),
555                            })
556                        }
557                        _ => Err(DeserializeFailure::TagMismatch {
558                            found: tag,
559                            expected: 2,
560                        }
561                        .into()),
562                    }
563                }
564                // uint
565                cbor_event::Type::UnsignedInteger => {
566                    let (num, num_enc) = raw.unsigned_integer_sz()?;
567                    Ok(Self {
568                        num: num_bigint::BigInt::from(num),
569                        encoding: Some(BigIntEncoding::Int(num_enc)),
570                    })
571                }
572                // nint
573                cbor_event::Type::NegativeInteger => {
574                    let (num, num_enc) = raw.negative_integer_sz()?;
575                    Ok(Self {
576                        num: num_bigint::BigInt::from(num),
577                        encoding: Some(BigIntEncoding::Int(num_enc)),
578                    })
579                }
580                _ => Err(DeserializeFailure::NoVariantMatched.into()),
581            }
582        })()
583        .map_err(|e| e.annotate("BigInteger"))
584    }
585}
586
587impl<T> std::convert::From<T> for BigInteger
588where
589    T: std::convert::Into<num_bigint::BigInt>,
590{
591    fn from(x: T) -> Self {
592        Self {
593            num: x.into(),
594            encoding: None,
595        }
596    }
597}
598
599#[derive(Clone, Copy, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
600pub struct NetworkId {
601    pub network: u64,
602    #[serde(skip)]
603    pub encoding: Option<cbor_event::Sz>,
604}
605
606impl NetworkId {
607    pub fn new(network: u64) -> Self {
608        Self {
609            network,
610            encoding: None,
611        }
612    }
613
614    pub fn mainnet() -> Self {
615        Self {
616            network: 1,
617            encoding: None,
618        }
619    }
620
621    pub fn testnet() -> Self {
622        Self {
623            network: 0,
624            encoding: None,
625        }
626    }
627}
628
629impl From<u64> for NetworkId {
630    fn from(network: u64) -> Self {
631        NetworkId::new(network)
632    }
633}
634
635impl From<NetworkId> for u64 {
636    fn from(id: NetworkId) -> u64 {
637        id.network
638    }
639}
640
641impl Serialize for NetworkId {
642    fn serialize<'se, W: Write>(
643        &self,
644        serializer: &'se mut Serializer<W>,
645        force_canonical: bool,
646    ) -> cbor_event::Result<&'se mut Serializer<W>> {
647        serializer.write_unsigned_integer_sz(
648            self.network,
649            fit_sz(self.network, self.encoding, force_canonical),
650        )
651    }
652}
653
654impl Deserialize for NetworkId {
655    fn deserialize<R: BufRead + Seek>(raw: &mut Deserializer<R>) -> Result<Self, DeserializeError> {
656        let (network, encoding) = raw.unsigned_integer_sz().map(|(x, enc)| (x, Some(enc)))?;
657        Ok(Self { network, encoding })
658    }
659}
660
661impl SubCoin {
662    /// Converts base 10 floats to SubCoin.
663    /// This is the format used by blockfrost for ex units
664    /// Warning: If the passed in float was not meant to be base 10
665    /// this might result in a slightly inaccurate fraction.
666    pub fn from_base10_f32(f: f32) -> Self {
667        let mut denom = 1u64;
668        while (f * (denom as f32)).fract().abs() > f32::EPSILON {
669            denom *= 10;
670        }
671        Self::new((f * (denom as f32)).ceil() as u64, denom)
672    }
673}
674
675// Represents the cddl: #6.258([+ T]) / [* T]
676// it DOES NOT and CAN NOT have any encoding detials per element!
677// so you can NOT use it on any primitives so must be serializable directly
678#[derive(Debug, Clone)]
679pub struct NonemptySet<T> {
680    elems: Vec<T>,
681    len_encoding: LenEncoding,
682    // also controls whether to use the tag encoding (Some) or raw array (None)
683    tag_encoding: Option<Sz>,
684}
685
686impl<T: serde::Serialize> serde::Serialize for NonemptySet<T> {
687    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
688    where
689        S: serde::Serializer,
690    {
691        self.elems.serialize(serializer)
692    }
693}
694
695impl<'de, T: serde::de::Deserialize<'de>> serde::de::Deserialize<'de> for NonemptySet<T> {
696    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
697    where
698        D: serde::de::Deserializer<'de>,
699    {
700        Vec::deserialize(deserializer).map(|elems| Self {
701            elems,
702            len_encoding: LenEncoding::default(),
703            tag_encoding: None,
704        })
705    }
706}
707
708impl<T: schemars::JsonSchema> schemars::JsonSchema for NonemptySet<T> {
709    fn schema_name() -> String {
710        Vec::<T>::schema_name()
711    }
712    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
713        Vec::<T>::json_schema(gen)
714    }
715    fn is_referenceable() -> bool {
716        Vec::<T>::is_referenceable()
717    }
718}
719
720impl<T> AsRef<[T]> for NonemptySet<T> {
721    fn as_ref(&self) -> &[T] {
722        self.elems.as_ref()
723    }
724}
725
726impl<T> IntoIterator for NonemptySet<T> {
727    type Item = T;
728    type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
729
730    fn into_iter(self) -> Self::IntoIter {
731        self.elems.into_iter()
732    }
733}
734
735impl<'a, T> IntoIterator for &'a NonemptySet<T> {
736    type Item = &'a T;
737    type IntoIter = std::slice::Iter<'a, T>;
738
739    fn into_iter(self) -> Self::IntoIter {
740        self.elems.iter()
741    }
742}
743
744impl<'a, T> IntoIterator for &'a mut NonemptySet<T> {
745    type Item = &'a mut T;
746    type IntoIter = std::slice::IterMut<'a, T>;
747
748    fn into_iter(self) -> Self::IntoIter {
749        self.elems.iter_mut()
750    }
751}
752
753impl<T> std::ops::Deref for NonemptySet<T> {
754    type Target = Vec<T>;
755
756    fn deref(&self) -> &Self::Target {
757        &self.elems
758    }
759}
760
761impl<T> std::ops::DerefMut for NonemptySet<T> {
762    fn deref_mut(&mut self) -> &mut Self::Target {
763        &mut self.elems
764    }
765}
766
767impl<T> From<Vec<T>> for NonemptySet<T> {
768    fn from(elems: Vec<T>) -> Self {
769        Self {
770            elems,
771            len_encoding: LenEncoding::default(),
772            tag_encoding: Some(Sz::Two),
773        }
774    }
775}
776
777impl<T> From<NonemptySet<T>> for Vec<T> {
778    fn from(set: NonemptySet<T>) -> Self {
779        set.elems
780    }
781}
782
783impl<T: Serialize> Serialize for NonemptySet<T> {
784    fn serialize<'se, W: Write>(
785        &self,
786        serializer: &'se mut Serializer<W>,
787        force_canonical: bool,
788    ) -> cbor_event::Result<&'se mut Serializer<W>> {
789        if let Some(tag_encoding) = &self.tag_encoding {
790            serializer.write_tag_sz(258, *tag_encoding)?;
791        }
792        serializer.write_array_sz(
793            self.len_encoding
794                .to_len_sz(self.elems.len() as u64, force_canonical),
795        )?;
796        for elem in self.elems.iter() {
797            elem.serialize(serializer, force_canonical)?;
798        }
799        self.len_encoding.end(serializer, force_canonical)
800    }
801}
802
803impl<T: Deserialize> Deserialize for NonemptySet<T> {
804    fn deserialize<R: BufRead + Seek>(raw: &mut Deserializer<R>) -> Result<Self, DeserializeError> {
805        (|| -> Result<_, DeserializeError> {
806            let mut elems = Vec::new();
807            let (arr_len, tag_encoding) = if raw.cbor_type()? == cbor_event::Type::Tag {
808                let (tag, tag_encoding) = raw.tag_sz()?;
809                if tag != 258 {
810                    return Err(DeserializeFailure::TagMismatch {
811                        found: tag,
812                        expected: 258,
813                    }
814                    .into());
815                }
816                (raw.array_sz()?, Some(tag_encoding))
817            } else {
818                (raw.array_sz()?, None)
819            };
820            let len_encoding = arr_len.into();
821            while match arr_len {
822                cbor_event::LenSz::Len(n, _) => (elems.len() as u64) < n,
823                cbor_event::LenSz::Indefinite => true,
824            } {
825                if raw.cbor_type()? == cbor_event::Type::Special {
826                    assert_eq!(raw.special()?, cbor_event::Special::Break);
827                    break;
828                }
829                let elem = T::deserialize(raw)?;
830                elems.push(elem);
831            }
832            Ok(Self {
833                elems,
834                len_encoding,
835                tag_encoding,
836            })
837        })()
838        .map_err(|e| e.annotate("NonemptySet"))
839    }
840}
841
842// for now just do this
843pub type Set<T> = NonemptySet<T>;
844
845// Represents the cddl: #6.258([+ T]) / [* T] where T uses RawBytesEncoding
846#[derive(Debug, Clone)]
847pub struct NonemptySetRawBytes<T> {
848    elems: Vec<T>,
849    len_encoding: LenEncoding,
850    // also controls whether to use the tag encoding (Some) or raw array (None)
851    tag_encoding: Option<Sz>,
852    bytes_encodings: Vec<StringEncoding>,
853}
854
855impl<T: serde::Serialize> serde::Serialize for NonemptySetRawBytes<T> {
856    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
857    where
858        S: serde::Serializer,
859    {
860        self.elems.serialize(serializer)
861    }
862}
863
864impl<'de, T: serde::de::Deserialize<'de>> serde::de::Deserialize<'de> for NonemptySetRawBytes<T> {
865    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
866    where
867        D: serde::de::Deserializer<'de>,
868    {
869        Vec::deserialize(deserializer).map(|elems| Self {
870            elems,
871            len_encoding: LenEncoding::default(),
872            tag_encoding: None,
873            bytes_encodings: Vec::new(),
874        })
875    }
876}
877
878impl<T: schemars::JsonSchema> schemars::JsonSchema for NonemptySetRawBytes<T> {
879    fn schema_name() -> String {
880        Vec::<T>::schema_name()
881    }
882    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
883        Vec::<T>::json_schema(gen)
884    }
885    fn is_referenceable() -> bool {
886        Vec::<T>::is_referenceable()
887    }
888}
889
890impl<T> AsRef<[T]> for NonemptySetRawBytes<T> {
891    fn as_ref(&self) -> &[T] {
892        self.elems.as_ref()
893    }
894}
895
896impl<T> IntoIterator for NonemptySetRawBytes<T> {
897    type Item = T;
898    type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
899
900    fn into_iter(self) -> Self::IntoIter {
901        self.elems.into_iter()
902    }
903}
904
905impl<'a, T> IntoIterator for &'a NonemptySetRawBytes<T> {
906    type Item = &'a T;
907    type IntoIter = std::slice::Iter<'a, T>;
908
909    fn into_iter(self) -> Self::IntoIter {
910        self.elems.iter()
911    }
912}
913
914impl<'a, T> IntoIterator for &'a mut NonemptySetRawBytes<T> {
915    type Item = &'a mut T;
916    type IntoIter = std::slice::IterMut<'a, T>;
917
918    fn into_iter(self) -> Self::IntoIter {
919        self.elems.iter_mut()
920    }
921}
922
923impl<T> std::ops::Deref for NonemptySetRawBytes<T> {
924    type Target = Vec<T>;
925
926    fn deref(&self) -> &Self::Target {
927        &self.elems
928    }
929}
930
931impl<T> std::ops::DerefMut for NonemptySetRawBytes<T> {
932    fn deref_mut(&mut self) -> &mut Self::Target {
933        &mut self.elems
934    }
935}
936
937impl<T> From<Vec<T>> for NonemptySetRawBytes<T> {
938    fn from(elems: Vec<T>) -> Self {
939        Self {
940            elems,
941            len_encoding: LenEncoding::default(),
942            tag_encoding: Some(Sz::Two),
943            bytes_encodings: Vec::new(),
944        }
945    }
946}
947
948impl<T> From<NonemptySetRawBytes<T>> for Vec<T> {
949    fn from(set: NonemptySetRawBytes<T>) -> Self {
950        set.elems
951    }
952}
953
954impl<T: RawBytesEncoding> Serialize for NonemptySetRawBytes<T> {
955    fn serialize<'se, W: Write>(
956        &self,
957        serializer: &'se mut Serializer<W>,
958        force_canonical: bool,
959    ) -> cbor_event::Result<&'se mut Serializer<W>> {
960        if let Some(tag_encoding) = &self.tag_encoding {
961            serializer.write_tag_sz(258, *tag_encoding)?;
962        }
963        serializer.write_array_sz(
964            self.len_encoding
965                .to_len_sz(self.elems.len() as u64, force_canonical),
966        )?;
967        for (i, elem) in self.elems.iter().enumerate() {
968            serializer.write_bytes_sz(
969                elem.to_raw_bytes(),
970                self.bytes_encodings
971                    .get(i)
972                    .cloned()
973                    .unwrap_or_default()
974                    .to_str_len_sz(elem.to_raw_bytes().len() as u64, force_canonical),
975            )?;
976        }
977        self.len_encoding.end(serializer, force_canonical)
978    }
979}
980
981impl<T: RawBytesEncoding> Deserialize for NonemptySetRawBytes<T> {
982    fn deserialize<R: BufRead + Seek>(raw: &mut Deserializer<R>) -> Result<Self, DeserializeError> {
983        (|| -> Result<_, DeserializeError> {
984            let mut elems = Vec::new();
985            let mut bytes_encodings = Vec::new();
986            let (arr_len, tag_encoding) = if raw.cbor_type()? == cbor_event::Type::Tag {
987                let (tag, tag_encoding) = raw.tag_sz()?;
988                if tag != 258 {
989                    return Err(DeserializeFailure::TagMismatch {
990                        found: tag,
991                        expected: 258,
992                    }
993                    .into());
994                }
995                (raw.array_sz()?, Some(tag_encoding))
996            } else {
997                (raw.array_sz()?, None)
998            };
999            let len_encoding = arr_len.into();
1000            while match arr_len {
1001                cbor_event::LenSz::Len(n, _) => (elems.len() as u64) < n,
1002                cbor_event::LenSz::Indefinite => true,
1003            } {
1004                if raw.cbor_type()? == cbor_event::Type::Special {
1005                    assert_eq!(raw.special()?, cbor_event::Special::Break);
1006                    break;
1007                }
1008                let (bytes, bytes_enc) = raw.bytes_sz()?;
1009                let elem = T::from_raw_bytes(&bytes)
1010                    .map_err(|e| DeserializeFailure::InvalidStructure(Box::new(e)))?;
1011                elems.push(elem);
1012                bytes_encodings.push(bytes_enc.into());
1013            }
1014            Ok(Self {
1015                elems,
1016                len_encoding,
1017                tag_encoding,
1018                bytes_encodings,
1019            })
1020        })()
1021        .map_err(|e| e.annotate("NonemptySetRawBytes"))
1022    }
1023}
1024
// Unit tests for `BigInteger`: CBOR round-trips and conversions (`as_u64`, `as_u128`,
// `as_int`, `to_string`) at and just beyond the uint / nint boundaries.
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    // 0x00 is the CBOR uint 0, the smallest unsigned value.
    #[test]
    fn bigint_uint_u64_min() {
        let bytes = [0x00];
        let x = BigInteger::from_cbor_bytes(&bytes).unwrap();
        assert_eq!(bytes, x.to_cbor_bytes().as_slice());
        assert_eq!(x.as_u64(), Some(u64::MIN));
        assert_eq!(x.as_int().unwrap().to_string(), x.to_string());
        assert_eq!(x.to_string(), "0");
    }

    // 0x1B = CBOR uint with an 8-byte argument; all-ones is u64::MAX, the largest
    // value representable without a bignum tag.
    #[test]
    fn bigint_uint_u64_max() {
        let bytes = [0x1B, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
        let x = BigInteger::from_cbor_bytes(&bytes).unwrap();
        assert_eq!(bytes, x.to_cbor_bytes().as_slice());
        assert_eq!(x.as_u64(), Some(u64::MAX));
        assert_eq!(x.as_int().unwrap().to_string(), x.to_string());
        assert_eq!(x.to_string(), "18446744073709551615");
    }

    // A value built from an `Int::Uint` must read back unchanged through `as_u128`.
    #[test]
    fn bigint_uint_u128_roundtrip() {
        let int = 462_164_030_739_157_517;
        let x = BigInteger::from_int(&Int::Uint {
            value: int,
            encoding: None,
        });
        assert_eq!(x.as_u128(), Some(int as u128))
    }

    // Same as above at the lower bound of u64.
    #[test]
    fn bigint_uint_u128_roundtrip_min() {
        let int = u64::MIN;
        let x = BigInteger::from_int(&Int::Uint {
            value: int,
            encoding: None,
        });
        assert_eq!(x.as_u128(), Some(int as u128))
    }

    // Same as above at the upper bound of u64.
    #[test]
    fn bigint_uint_u128_roundtrip_max() {
        let int = u64::MAX;
        let x = BigInteger::from_int(&Int::Uint {
            value: int,
            encoding: None,
        });
        assert_eq!(x.as_u128(), Some(int as u128))
    }

    // Decoding CBOR uint 0 must also be visible through the u128 accessor.
    #[test]
    fn bigint_uint_u128_min() {
        let bytes = [0x00];
        let x = BigInteger::from_cbor_bytes(&bytes).unwrap();
        assert_eq!(bytes, x.to_cbor_bytes().as_slice());
        assert_eq!(x.as_u128(), Some(u128::MIN));
        assert_eq!(x.to_string(), "0");
    }

    // u128::MAX lies above u64::MAX, so its CBOR form is presumably a tag-2 bignum;
    // the bytes come from `to_cbor_bytes` and must round-trip exactly.
    #[test]
    fn bigint_uint_u128_max() {
        let bytes = BigInteger::from_str(&u128::MAX.to_string())
            .unwrap()
            .to_cbor_bytes();
        let x = BigInteger::from_cbor_bytes(&bytes).unwrap();
        assert_eq!(bytes, x.to_cbor_bytes().as_slice());
        assert_eq!(x.as_u128(), Some(u128::MAX));
        assert_eq!(x.to_string(), "340282366920938463463374607431768211455");
    }

    // 0xC2 = tag 2 (unsigned bignum), 0x49 = 9-byte byte string 0x01 followed by eight
    // 0x00 bytes, i.e. 2^64 — one past u64::MAX, so `as_int` must return None.
    #[test]
    fn bigint_above_uint_min() {
        let bytes = [
            0xC2, 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let x = BigInteger::from_cbor_bytes(&bytes).unwrap();
        assert_eq!(bytes, x.to_cbor_bytes().as_slice());
        assert_eq!(x.as_int(), None);
        assert_eq!(x.to_string(), "18446744073709551616");
    }

    // 0x3B = CBOR negative int with an 8-byte argument; all-ones encodes
    // -1 - u64::MAX = -2^64, the most negative value without a bignum tag.
    #[test]
    fn bigint_nint_min() {
        let bytes = [0x3B, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
        let x = BigInteger::from_cbor_bytes(&bytes).unwrap();
        assert_eq!(bytes, x.to_cbor_bytes().as_slice());
        assert_eq!(
            Into::<i128>::into(&x.as_int().unwrap()),
            -((u64::MAX as i128) + 1)
        );
        assert_eq!(x.as_int().unwrap().to_string(), x.to_string());
        assert_eq!(x.to_string(), "-18446744073709551616");
    }

    // 0x20 is the CBOR encoding of -1, the largest negative int; `as_u64` must be None.
    #[test]
    fn bigint_nint_max() {
        let bytes = [0x20];
        let x = BigInteger::from_cbor_bytes(&bytes).unwrap();
        assert_eq!(bytes, x.to_cbor_bytes().as_slice());
        assert_eq!(x.as_u64(), None);
        assert_eq!(x.as_int().unwrap().to_string(), x.to_string());
        assert_eq!(x.to_string(), "-1");
    }

    // 0xC3 = tag 3 (negative bignum) over the 9-byte value 2^64, encoding -1 - 2^64:
    // one below the smallest plain nint, so `as_int` must return None.
    #[test]
    fn bigint_below_nint_min() {
        let bytes = [
            0xC3, 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let x = BigInteger::from_cbor_bytes(&bytes).unwrap();
        assert_eq!(bytes, x.to_cbor_bytes().as_slice());
        assert_eq!(x.as_int(), None);
        assert_eq!(x.to_string(), "-18446744073709551617");
    }
}