// griffin_core/uplc/flat.rs

1use crate::pallas_codec::flat::{
2    de::{self, Decode, Decoder},
3    en::{self, Encode, Encoder},
4    Flat,
5};
6use crate::pallas_primitives::{conway::PlutusData, Fragment};
7use crate::uplc::{
8    ast::{
9        Constant, DeBruijn, FakeNamedDeBruijn, Name, NamedDeBruijn, Program, Term, Type, Unique,
10    },
11    builtins::DefaultFunction,
12    machine::runtime::Compressable,
13};
14use alloc::{
15    collections::VecDeque,
16    rc::Rc,
17    string::{String, ToString},
18    vec::Vec,
19};
20use core::fmt::Debug;
21use num_bigint::BigInt;
22
/// Width, in bits, of the constructor tag written before every term.
const TERM_TAG_WIDTH: u32 = 4;
/// Width, in bits, of each entry in a constant's type-tag list.
const CONST_TAG_WIDTH: u32 = 4;
/// Width, in bits, of a builtin function's tag.
const BUILTIN_TAG_WIDTH: u32 = 7;
26
27pub trait Binder<'b>: Encode + Decode<'b> {
28    fn binder_encode(&self, e: &mut Encoder) -> Result<(), en::Error>;
29    fn binder_decode(d: &mut Decoder) -> Result<Self, de::Error>;
30    fn text(&self) -> String;
31}
32
33impl<'b, T> Flat<'b> for Program<T> where T: Binder<'b> + Debug {}
34
35impl<'b, T> Program<T>
36where
37    T: Binder<'b> + Debug,
38{
39    pub fn from_cbor(bytes: &'b [u8], buffer: &'b mut Vec<u8>) -> Result<Self, de::Error> {
40        let mut cbor_decoder = crate::pallas_codec::minicbor::Decoder::new(bytes);
41
42        let flat_bytes = cbor_decoder
43            .bytes()
44            .map_err(|err| de::Error::Message(err.to_string()))?;
45
46        buffer.extend(flat_bytes);
47
48        Self::unflat(buffer)
49    }
50
51    pub fn from_flat(bytes: &'b [u8]) -> Result<Self, de::Error> {
52        Self::unflat(bytes)
53    }
54
55    pub fn from_hex(
56        hex_str: &str,
57        cbor_buffer: &'b mut Vec<u8>,
58        flat_buffer: &'b mut Vec<u8>,
59    ) -> Result<Self, de::Error> {
60        let cbor_bytes = hex::decode(hex_str).map_err(|err| de::Error::Message(err.to_string()))?;
61
62        cbor_buffer.extend(cbor_bytes);
63
64        Self::from_cbor(cbor_buffer, flat_buffer)
65    }
66
67    /// Convert a program to cbor bytes.
68    ///
69    /// _note: The cbor bytes of a program are merely
70    /// the flat bytes of the program encoded as cbor bytes._
71    ///
72    /// # Examples
73    ///
74    /// ```
75    /// use uplc::ast::{Program, Name, Term};
76    ///
77    /// let term = Term::var("x").lambda("x");
78    /// let program = Program { version: (1, 0, 0), term };
79    ///
80    /// assert_eq!(
81    ///     program.to_debruijn().unwrap().to_cbor().unwrap(),
82    ///     vec![
83    ///         0x46, 0x01, 0x00, 0x00,
84    ///         0x20, 0x01, 0x01
85    ///     ],
86    /// );
87    /// ```
88    pub fn to_cbor(&self) -> Result<Vec<u8>, en::Error> {
89        let flat_bytes = self.flat()?;
90
91        let mut bytes = Vec::new();
92
93        let mut cbor_encoder = crate::pallas_codec::minicbor::Encoder::new(&mut bytes);
94
95        cbor_encoder
96            .bytes(&flat_bytes)
97            .map_err(|err| en::Error::Message(err.to_string()))?;
98
99        Ok(bytes)
100    }
101
102    /// Convert a program to a flat bytes.
103    ///
104    /// _**note**: Convenient so that people don't need to depend on the flat crate
105    /// directly to call programs flat function._
106    ///
107    /// # Examples
108    ///
109    /// ```
110    /// use uplc::ast::{Program, Name, Term};
111    ///
112    /// let term = Term::var("x").lambda("x");
113    /// let program = Program { version: (1, 0, 0), term };
114    ///
115    /// assert_eq!(
116    ///     program
117    ///         .to_debruijn()
118    ///         .unwrap()
119    ///         .to_flat()
120    ///         .unwrap(),
121    ///     vec![
122    ///         0x01, 0x00, 0x00,
123    ///         0x20, 0x01, 0x01
124    ///     ],
125    /// );
126    /// ```
127    pub fn to_flat(&self) -> Result<Vec<u8>, en::Error> {
128        self.flat()
129    }
130
131    /// Convert a program to hex encoded cbor bytes
132    ///
133    /// # Examples
134    ///
135    /// ```
136    /// use uplc::ast::{Program, Name, Term};
137    ///
138    /// let term = Term::var("x").lambda("x");
139    /// let program = Program { version: (1, 0, 0), term };
140    ///
141    /// assert_eq!(
142    ///     program.to_debruijn().unwrap().to_hex().unwrap(),
143    ///     "46010000200101".to_string(),
144    /// );
145    /// ```
146    pub fn to_hex(&self) -> Result<String, en::Error> {
147        let bytes = self.to_cbor()?;
148
149        let hex = hex::encode(bytes);
150
151        Ok(hex)
152    }
153}
154
155impl<'b, T> Encode for Program<T>
156where
157    T: Binder<'b> + Debug,
158{
159    fn encode(&self, e: &mut Encoder) -> Result<(), en::Error> {
160        let (major, minor, patch) = self.version;
161
162        major.encode(e)?;
163        minor.encode(e)?;
164        patch.encode(e)?;
165
166        self.term.encode(e)?;
167
168        Ok(())
169    }
170}
171
172impl<'b, T> Decode<'b> for Program<T>
173where
174    T: Binder<'b>,
175{
176    fn decode(d: &mut Decoder) -> Result<Self, de::Error> {
177        let mut state_log: Vec<String> = vec![];
178        let version = (usize::decode(d)?, usize::decode(d)?, usize::decode(d)?);
179        let term_option = Term::decode_debug(d, &mut state_log);
180
181        match term_option {
182            Ok(term) => Ok(Program { version, term }),
183            Err(error) => Err(de::Error::Message(format!(
184                "{} {error}",
185                state_log.join("")
186            ))),
187        }
188    }
189}
190
191impl<'b, T> Encode for Term<T>
192where
193    T: Binder<'b> + Debug,
194{
195    fn encode(&self, e: &mut Encoder) -> Result<(), en::Error> {
196        match self {
197            Term::Var(name) => {
198                encode_term_tag(0, e)?;
199                name.encode(e)?;
200            }
201            Term::Delay(term) => {
202                encode_term_tag(1, e)?;
203                term.encode(e)?;
204            }
205            Term::Lambda {
206                parameter_name,
207                body,
208            } => {
209                encode_term_tag(2, e)?;
210                parameter_name.binder_encode(e)?;
211                body.encode(e)?;
212            }
213            Term::Apply { function, argument } => {
214                encode_term_tag(3, e)?;
215                function.encode(e)?;
216                argument.encode(e)?;
217            }
218
219            Term::Constant(constant) => {
220                encode_term_tag(4, e)?;
221                constant.encode(e)?;
222            }
223
224            Term::Force(term) => {
225                encode_term_tag(5, e)?;
226                term.encode(e)?;
227            }
228
229            Term::Error => {
230                encode_term_tag(6, e)?;
231            }
232            Term::Builtin(builtin) => {
233                encode_term_tag(7, e)?;
234
235                builtin.encode(e)?;
236            }
237            Term::Constr { tag, fields } => {
238                encode_term_tag(8, e)?;
239
240                tag.encode(e)?;
241
242                e.encode_list_with(fields, |term, e| (*term).encode(e))?;
243            }
244            Term::Case { constr, branches } => {
245                encode_term_tag(9, e)?;
246
247                constr.encode(e)?;
248
249                e.encode_list_with(branches, |term, e| (*term).encode(e))?;
250            }
251        }
252
253        Ok(())
254    }
255}
256
257impl<'b, T> Decode<'b> for Term<T>
258where
259    T: Binder<'b>,
260{
261    fn decode(d: &mut Decoder) -> Result<Self, de::Error> {
262        match decode_term_tag(d)? {
263            0 => Ok(Term::Var(T::decode(d)?.into())),
264            1 => Ok(Term::Delay(Rc::new(Term::decode(d)?))),
265            2 => Ok(Term::Lambda {
266                parameter_name: T::binder_decode(d)?.into(),
267                body: Rc::new(Term::decode(d)?),
268            }),
269            3 => Ok(Term::Apply {
270                function: Rc::new(Term::decode(d)?),
271                argument: Rc::new(Term::decode(d)?),
272            }),
273            // Need size limit for Constant
274            4 => Ok(Term::Constant(Constant::decode(d)?.into())),
275            5 => Ok(Term::Force(Rc::new(Term::decode(d)?))),
276            6 => Ok(Term::Error),
277            7 => Ok(Term::Builtin(DefaultFunction::decode(d)?)),
278            8 => {
279                let tag = usize::decode(d)?;
280                let fields = d.decode_list_with(Term::<T>::decode)?;
281
282                Ok(Term::Constr { tag, fields })
283            }
284            9 => {
285                let constr = (Term::<T>::decode(d)?).into();
286
287                let branches = d.decode_list_with(Term::<T>::decode)?;
288
289                Ok(Term::Case { constr, branches })
290            }
291            x => {
292                let buffer_slice: Vec<u8> = d
293                    .buffer
294                    .to_vec()
295                    .iter()
296                    .skip(if d.pos > 5 { d.pos - 5 } else { 0 })
297                    .take(10)
298                    .cloned()
299                    .collect();
300
301                Err(de::Error::UnknownTermConstructor(
302                    x,
303                    if d.pos > 5 { 5 } else { d.pos },
304                    format!("{buffer_slice:02X?}"),
305                    d.pos,
306                    d.buffer.len(),
307                ))
308            }
309        }
310    }
311}
312
313impl<'b, T> Term<T>
314where
315    T: Binder<'b>,
316{
317    fn decode_debug(d: &mut Decoder, state_log: &mut Vec<String>) -> Result<Term<T>, de::Error> {
318        match decode_term_tag(d)? {
319            0 => {
320                state_log.push("(var ".to_string());
321                let var_option = T::decode(d);
322                match var_option {
323                    Ok(var) => {
324                        state_log.push(format!("{})", var.text()));
325                        Ok(Term::Var(var.into()))
326                    }
327                    Err(error) => {
328                        state_log.push("parse error)".to_string());
329                        Err(error)
330                    }
331                }
332            }
333
334            1 => {
335                state_log.push("(delay ".to_string());
336                let term_option = Term::decode_debug(d, state_log);
337                match term_option {
338                    Ok(term) => {
339                        state_log.push(")".to_string());
340                        Ok(Term::Delay(Rc::new(term)))
341                    }
342                    Err(error) => {
343                        state_log.push(")".to_string());
344                        Err(error)
345                    }
346                }
347            }
348            2 => {
349                state_log.push("(lam ".to_string());
350
351                let var_option = T::binder_decode(d);
352                match var_option {
353                    Ok(var) => {
354                        state_log.push(var.text());
355                        let term_option = Term::decode_debug(d, state_log);
356                        match term_option {
357                            Ok(term) => {
358                                state_log.push(")".to_string());
359                                Ok(Term::Lambda {
360                                    parameter_name: var.into(),
361                                    body: Rc::new(term),
362                                })
363                            }
364                            Err(error) => {
365                                state_log.push(")".to_string());
366                                Err(error)
367                            }
368                        }
369                    }
370                    Err(error) => {
371                        state_log.push(")".to_string());
372                        Err(error)
373                    }
374                }
375            }
376            3 => {
377                state_log.push("[ ".to_string());
378
379                let function_term_option = Term::decode_debug(d, state_log);
380                match function_term_option {
381                    Ok(function) => {
382                        state_log.push(" ".to_string());
383                        let arg_term_option = Term::decode_debug(d, state_log);
384                        match arg_term_option {
385                            Ok(argument) => {
386                                state_log.push("]".to_string());
387                                Ok(Term::Apply {
388                                    function: Rc::new(function),
389                                    argument: Rc::new(argument),
390                                })
391                            }
392                            Err(error) => {
393                                state_log.push("]".to_string());
394                                Err(error)
395                            }
396                        }
397                    }
398                    Err(error) => {
399                        state_log.push(" not parsed]".to_string());
400                        Err(error)
401                    }
402                }
403            }
404            // Need size limit for Constant
405            4 => {
406                state_log.push("(con ".to_string());
407
408                let con_option = Constant::decode(d);
409                match con_option {
410                    Ok(constant) => {
411                        state_log.push(format!("{:?})", constant));
412                        Ok(Term::Constant(constant.into()))
413                    }
414                    Err(error) => {
415                        state_log.push("parse error)".to_string());
416                        Err(error)
417                    }
418                }
419            }
420            5 => {
421                state_log.push("(force ".to_string());
422                let term_option = Term::decode_debug(d, state_log);
423                match term_option {
424                    Ok(term) => {
425                        state_log.push(")".to_string());
426                        Ok(Term::Force(Rc::new(term)))
427                    }
428                    Err(error) => {
429                        state_log.push(")".to_string());
430                        Err(error)
431                    }
432                }
433            }
434            6 => {
435                state_log.push("(error)".to_string());
436                Ok(Term::Error)
437            }
438            7 => {
439                state_log.push("(builtin ".to_string());
440
441                let builtin_option = DefaultFunction::decode(d);
442                match builtin_option {
443                    Ok(builtin) => {
444                        state_log.push(format!("{builtin})"));
445                        Ok(Term::Builtin(builtin))
446                    }
447                    Err(error) => {
448                        state_log.push("parse error)".to_string());
449                        Err(error)
450                    }
451                }
452            }
453            8 => {
454                state_log.push("(constr ".to_string());
455
456                let tag = usize::decode(d)?;
457
458                let fields = d.decode_list_with_debug(
459                    |d, state_log| Term::<T>::decode_debug(d, state_log),
460                    state_log,
461                )?;
462
463                Ok(Term::Constr { tag, fields })
464            }
465            9 => {
466                state_log.push("(case ".to_string());
467                let constr = Term::<T>::decode_debug(d, state_log)?.into();
468
469                let branches = d.decode_list_with_debug(
470                    |d, state_log| Term::<T>::decode_debug(d, state_log),
471                    state_log,
472                )?;
473
474                Ok(Term::Case { constr, branches })
475            }
476            x => {
477                state_log.push("parse error".to_string());
478
479                let buffer_slice: Vec<u8> = d
480                    .buffer
481                    .to_vec()
482                    .iter()
483                    .skip(if d.pos > 5 { d.pos - 5 } else { 0 })
484                    .take(10)
485                    .cloned()
486                    .collect();
487
488                Err(de::Error::UnknownTermConstructor(
489                    x,
490                    if d.pos > 5 { 5 } else { d.pos },
491                    format!("{buffer_slice:02X?}"),
492                    d.pos,
493                    d.buffer.len(),
494                ))
495            }
496        }
497    }
498}
499
500/// Integers are typically smaller so we save space
501/// by encoding them in 7 bits and this allows it to be byte alignment agnostic.
502/// Strings and bytestrings span multiple bytes so using bytestring is
503/// the most effective encoding.
504/// i.e. A 17 or greater length byte array loses efficiency being encoded as
505/// a unsigned integer instead of a byte array
506impl Encode for Constant {
507    fn encode(&self, e: &mut Encoder) -> Result<(), en::Error> {
508        match self {
509            Constant::Integer(i) => {
510                encode_constant(&[0], e)?;
511                i.encode(e)?;
512            }
513
514            Constant::ByteString(bytes) => {
515                encode_constant(&[1], e)?;
516                bytes.encode(e)?;
517            }
518            Constant::String(s) => {
519                encode_constant(&[2], e)?;
520                s.encode(e)?;
521            }
522            Constant::Unit => encode_constant(&[3], e)?,
523            Constant::Bool(b) => {
524                encode_constant(&[4], e)?;
525                b.encode(e)?;
526            }
527            Constant::ProtoList(typ, list) => {
528                let mut type_encode = vec![7, 5];
529
530                encode_type(typ, &mut type_encode);
531
532                encode_constant(&type_encode, e)?;
533
534                e.encode_list_with(list, encode_constant_value)?;
535            }
536            Constant::ProtoPair(type1, type2, a, b) => {
537                let mut type_encode = vec![7, 7, 6];
538
539                encode_type(type1, &mut type_encode);
540
541                encode_type(type2, &mut type_encode);
542
543                encode_constant(&type_encode, e)?;
544                encode_constant_value(a, e)?;
545                encode_constant_value(b, e)?;
546            }
547            Constant::Data(data) => {
548                encode_constant(&[8], e)?;
549
550                let cbor = data
551                    .encode_fragment()
552                    .map_err(|err| en::Error::Message(err.to_string()))?;
553
554                cbor.encode(e)?;
555            }
556            Constant::Bls12_381G1Element(_) => {
557                encode_constant(&[9], e)?;
558
559                return Err(en::Error::Message(
560                    "BLS12-381 G1 points are not supported for flat encoding".to_string(),
561                ));
562            }
563            Constant::Bls12_381G2Element(_) => {
564                encode_constant(&[10], e)?;
565
566                return Err(en::Error::Message(
567                    "BLS12-381 G2 points are not supported for flat encoding".to_string(),
568                ));
569            }
570            Constant::Bls12_381MlResult(_) => {
571                encode_constant(&[11], e)?;
572
573                return Err(en::Error::Message(
574                    "BLS12-381 ML results are not supported for flat encoding".to_string(),
575                ));
576            }
577        }
578
579        Ok(())
580    }
581}
582
583fn encode_constant_value(x: &Constant, e: &mut Encoder) -> Result<(), en::Error> {
584    match x {
585        Constant::Integer(x) => x.encode(e),
586        Constant::ByteString(b) => b.encode(e),
587        Constant::String(s) => s.encode(e),
588        Constant::Unit => Ok(()),
589        Constant::Bool(b) => b.encode(e),
590        Constant::ProtoList(_, list) => {
591            e.encode_list_with(list, encode_constant_value)?;
592            Ok(())
593        }
594        Constant::ProtoPair(_, _, a, b) => {
595            encode_constant_value(a, e)?;
596
597            encode_constant_value(b, e)
598        }
599        Constant::Data(data) => {
600            let cbor = data
601                .encode_fragment()
602                .map_err(|err| en::Error::Message(err.to_string()))?;
603
604            cbor.encode(e)
605        }
606        Constant::Bls12_381G1Element(_) => Err(en::Error::Message(
607            "BLS12-381 G1 points are not supported for flat encoding".to_string(),
608        )),
609        Constant::Bls12_381G2Element(_) => Err(en::Error::Message(
610            "BLS12-381 G2 points are not supported for flat encoding".to_string(),
611        )),
612        Constant::Bls12_381MlResult(_) => Err(en::Error::Message(
613            "BLS12-381 ML results are not supported for flat encoding".to_string(),
614        )),
615    }
616}
617
618fn encode_type(typ: &Type, bytes: &mut Vec<u8>) {
619    match typ {
620        Type::Integer => bytes.push(0),
621        Type::ByteString => bytes.push(1),
622        Type::String => bytes.push(2),
623        Type::Unit => bytes.push(3),
624        Type::Bool => bytes.push(4),
625        Type::List(sub_typ) => {
626            bytes.extend(vec![7, 5]);
627            encode_type(sub_typ, bytes);
628        }
629        Type::Pair(type1, type2) => {
630            bytes.extend(vec![7, 7, 6]);
631            encode_type(type1, bytes);
632            encode_type(type2, bytes);
633        }
634        Type::Data => bytes.push(8),
635        Type::Bls12_381G1Element => bytes.push(9),
636        Type::Bls12_381G2Element => bytes.push(10),
637        Type::Bls12_381MlResult => bytes.push(11),
638    }
639}
640
641impl Decode<'_> for Constant {
642    fn decode(d: &mut Decoder) -> Result<Self, de::Error> {
643        match &decode_constant(d)?[..] {
644            [0] => Ok(Constant::Integer(BigInt::decode(d)?)),
645            [1] => Ok(Constant::ByteString(Vec::<u8>::decode(d)?)),
646            [2] => Ok(Constant::String(String::decode(d)?)),
647            [3] => Ok(Constant::Unit),
648            [4] => Ok(Constant::Bool(bool::decode(d)?)),
649            [7, 5, rest @ ..] => {
650                let mut rest = VecDeque::from(rest.to_vec());
651
652                let typ = decode_type(&mut rest)?;
653
654                let list: Vec<Constant> =
655                    d.decode_list_with(|d| decode_constant_value(typ.clone().into(), d))?;
656
657                Ok(Constant::ProtoList(typ, list))
658            }
659            [7, 7, 6, rest @ ..] => {
660                let mut rest = VecDeque::from(rest.to_vec());
661
662                let type1 = decode_type(&mut rest)?;
663                let type2 = decode_type(&mut rest)?;
664
665                let a = decode_constant_value(type1.clone().into(), d)?;
666                let b = decode_constant_value(type2.clone().into(), d)?;
667
668                Ok(Constant::ProtoPair(type1, type2, a.into(), b.into()))
669            }
670            [8] => {
671                let cbor = Vec::<u8>::decode(d)?;
672
673                let data = PlutusData::decode_fragment(&cbor)
674                    .map_err(|err| de::Error::Message(err.to_string()))?;
675
676                Ok(Constant::Data(data))
677            }
678            [9] => {
679                let p1 = Vec::<u8>::decode(d)?;
680
681                let _p1 = blst::blst_p1::uncompress(&p1).map_err(|err| {
682                    de::Error::Message(format!("Failed to uncompress p1: {}", err))
683                })?;
684
685                Err(de::Error::Message(
686                    "BLS12-381 G1 points are not supported for flat decoding.".to_string(),
687                ))
688            }
689
690            [10] => {
691                let p2 = Vec::<u8>::decode(d)?;
692
693                let _p2 = blst::blst_p2::uncompress(&p2).map_err(|err| {
694                    de::Error::Message(format!("Failed to uncompress p2: {}", err))
695                })?;
696
697                Err(de::Error::Message(
698                    "BLS12-381 G2 points are not supported for flat decoding.".to_string(),
699                ))
700            }
701            [11] => Err(de::Error::Message(
702                "BLS12-381 ML results are not supported for flat decoding".to_string(),
703            )),
704            x => Err(de::Error::Message(format!(
705                "Unknown constant constructor tag: {x:?}"
706            ))),
707        }
708    }
709}
710
711fn decode_constant_value(typ: Rc<Type>, d: &mut Decoder) -> Result<Constant, de::Error> {
712    match typ.as_ref() {
713        Type::Integer => Ok(Constant::Integer(BigInt::decode(d)?)),
714        Type::ByteString => Ok(Constant::ByteString(Vec::<u8>::decode(d)?)),
715        Type::String => Ok(Constant::String(String::decode(d)?)),
716        Type::Unit => Ok(Constant::Unit),
717        Type::Bool => Ok(Constant::Bool(bool::decode(d)?)),
718        Type::List(sub_type) => {
719            let list: Vec<Constant> =
720                d.decode_list_with(|d| decode_constant_value(sub_type.clone(), d))?;
721
722            Ok(Constant::ProtoList(sub_type.as_ref().clone(), list))
723        }
724        Type::Pair(type1, type2) => {
725            let a = decode_constant_value(type1.clone(), d)?;
726            let b = decode_constant_value(type2.clone(), d)?;
727
728            Ok(Constant::ProtoPair(
729                type1.as_ref().clone(),
730                type2.as_ref().clone(),
731                a.into(),
732                b.into(),
733            ))
734        }
735        Type::Data => {
736            let cbor = Vec::<u8>::decode(d)?;
737
738            let data = PlutusData::decode_fragment(&cbor)
739                .map_err(|err| de::Error::Message(err.to_string()))?;
740
741            Ok(Constant::Data(data))
742        }
743        Type::Bls12_381G1Element => {
744            let p1 = Vec::<u8>::decode(d)?;
745
746            let _p1 = blst::blst_p1::uncompress(&p1)
747                .map_err(|err| de::Error::Message(format!("Failed to uncompress p1: {}", err)))?;
748
749            Err(de::Error::Message(
750                "BLS12-381 G1 points are not supported for flat decoding.".to_string(),
751            ))
752        }
753        Type::Bls12_381G2Element => {
754            let p2 = Vec::<u8>::decode(d)?;
755
756            let _p2 = blst::blst_p2::uncompress(&p2)
757                .map_err(|err| de::Error::Message(format!("Failed to uncompress p2: {}", err)))?;
758
759            Err(de::Error::Message(
760                "BLS12-381 G2 points are not supported for flat decoding.".to_string(),
761            ))
762        }
763        Type::Bls12_381MlResult => Err(de::Error::Message(
764            "BLS12-381 ML results are not supported for flat decoding".to_string(),
765        )),
766    }
767}
768
769fn decode_type(types: &mut VecDeque<u8>) -> Result<Type, de::Error> {
770    match types.pop_front() {
771        Some(4) => Ok(Type::Bool),
772        Some(0) => Ok(Type::Integer),
773        Some(2) => Ok(Type::String),
774        Some(1) => Ok(Type::ByteString),
775        Some(3) => Ok(Type::Unit),
776        Some(8) => Ok(Type::Data),
777        Some(9) => Ok(Type::Bls12_381G1Element),
778        Some(10) => Ok(Type::Bls12_381G2Element),
779        Some(11) => Ok(Type::Bls12_381MlResult),
780        Some(7) => match types.pop_front() {
781            Some(5) => Ok(Type::List(decode_type(types)?.into())),
782            Some(7) => match types.pop_front() {
783                Some(6) => {
784                    let type1 = decode_type(types)?;
785                    let type2 = decode_type(types)?;
786
787                    Ok(Type::Pair(type1.into(), type2.into()))
788                }
789                Some(x) => Err(de::Error::Message(format!(
790                    "Unknown constant type tag: {x}"
791                ))),
792                None => Err(de::Error::Message("Unexpected empty buffer".to_string())),
793            },
794            Some(x) => Err(de::Error::Message(format!(
795                "Unknown constant type tag: {x}"
796            ))),
797            None => Err(de::Error::Message("Unexpected empty buffer".to_string())),
798        },
799
800        Some(x) => Err(de::Error::Message(format!(
801            "Unknown constant type tag: {x}"
802        ))),
803        None => Err(de::Error::Message("Unexpected empty buffer".to_string())),
804    }
805}
806
807impl Encode for Unique {
808    fn encode(&self, e: &mut Encoder) -> Result<(), en::Error> {
809        isize::from(*self).encode(e)?;
810
811        Ok(())
812    }
813}
814
815impl Decode<'_> for Unique {
816    fn decode(d: &mut Decoder) -> Result<Self, de::Error> {
817        Ok(isize::decode(d)?.into())
818    }
819}
820
821impl Encode for Name {
822    fn encode(&self, e: &mut Encoder) -> Result<(), en::Error> {
823        self.text.encode(e)?;
824        self.unique.encode(e)?;
825
826        Ok(())
827    }
828}
829
830impl Decode<'_> for Name {
831    fn decode(d: &mut Decoder) -> Result<Self, de::Error> {
832        Ok(Name {
833            text: String::decode(d)?,
834            unique: Unique::decode(d)?,
835        })
836    }
837}
838
839impl Binder<'_> for Name {
840    fn binder_encode(&self, e: &mut Encoder) -> Result<(), en::Error> {
841        self.encode(e)?;
842
843        Ok(())
844    }
845
846    fn binder_decode(d: &mut Decoder) -> Result<Self, de::Error> {
847        Name::decode(d)
848    }
849
850    fn text(&self) -> String {
851        self.text.clone()
852    }
853}
854
855impl Encode for NamedDeBruijn {
856    fn encode(&self, e: &mut Encoder) -> Result<(), en::Error> {
857        self.text.encode(e)?;
858        self.index.encode(e)?;
859
860        Ok(())
861    }
862}
863
864impl Decode<'_> for NamedDeBruijn {
865    fn decode(d: &mut Decoder) -> Result<Self, de::Error> {
866        Ok(NamedDeBruijn {
867            text: String::decode(d)?,
868            index: DeBruijn::decode(d)?,
869        })
870    }
871}
872
873impl Binder<'_> for NamedDeBruijn {
874    fn binder_encode(&self, e: &mut Encoder) -> Result<(), en::Error> {
875        self.text.encode(e)?;
876        self.index.encode(e)?;
877
878        Ok(())
879    }
880
881    fn binder_decode(d: &mut Decoder) -> Result<Self, de::Error> {
882        Ok(NamedDeBruijn {
883            text: String::decode(d)?,
884            index: DeBruijn::decode(d)?,
885        })
886    }
887
888    fn text(&self) -> String {
889        format!("{}_{}", &self.text, self.index)
890    }
891}
892
893impl Encode for DeBruijn {
894    fn encode(&self, e: &mut Encoder) -> Result<(), en::Error> {
895        usize::from(*self).encode(e)?;
896
897        Ok(())
898    }
899}
900
901impl Decode<'_> for DeBruijn {
902    fn decode(d: &mut Decoder) -> Result<Self, de::Error> {
903        Ok(usize::decode(d)?.into())
904    }
905}
906
907impl Binder<'_> for DeBruijn {
908    fn binder_encode(&self, _: &mut Encoder) -> Result<(), en::Error> {
909        Ok(())
910    }
911
912    fn binder_decode(_d: &mut Decoder) -> Result<Self, de::Error> {
913        Ok(DeBruijn::new(0))
914    }
915
916    fn text(&self) -> String {
917        format!("i_{}", self)
918    }
919}
920
921impl Encode for FakeNamedDeBruijn {
922    fn encode(&self, e: &mut Encoder) -> Result<(), en::Error> {
923        let index: DeBruijn = self.clone().into();
924
925        index.encode(e)?;
926
927        Ok(())
928    }
929}
930
931impl Decode<'_> for FakeNamedDeBruijn {
932    fn decode(d: &mut Decoder) -> Result<Self, de::Error> {
933        let index = DeBruijn::decode(d)?;
934
935        Ok(index.into())
936    }
937}
938
939impl Binder<'_> for FakeNamedDeBruijn {
940    fn binder_encode(&self, _: &mut Encoder) -> Result<(), en::Error> {
941        Ok(())
942    }
943
944    fn binder_decode(_d: &mut Decoder) -> Result<Self, de::Error> {
945        let index = DeBruijn::new(0);
946
947        Ok(index.into())
948    }
949
950    fn text(&self) -> String {
951        format!("{}_{}", self.0.text, self.0.index)
952    }
953}
954
955impl Encode for DefaultFunction {
956    fn encode(&self, e: &mut Encoder) -> Result<(), en::Error> {
957        e.bits(BUILTIN_TAG_WIDTH as i64, *self as u8);
958
959        Ok(())
960    }
961}
962
963impl Decode<'_> for DefaultFunction {
964    fn decode(d: &mut Decoder) -> Result<Self, de::Error> {
965        let builtin_tag = d.bits8(BUILTIN_TAG_WIDTH as usize)?;
966        builtin_tag.try_into()
967    }
968}
969
970fn encode_term_tag(tag: u8, e: &mut Encoder) -> Result<(), en::Error> {
971    safe_encode_bits(TERM_TAG_WIDTH, tag, e)
972}
973
974fn decode_term_tag(d: &mut Decoder) -> Result<u8, de::Error> {
975    d.bits8(TERM_TAG_WIDTH as usize)
976}
977
978fn safe_encode_bits(num_bits: u32, byte: u8, e: &mut Encoder) -> Result<(), en::Error> {
979    if 2_u8.pow(num_bits) <= byte {
980        Err(en::Error::Message(format!(
981            "Overflow detected, cannot fit {byte} in {num_bits} bits."
982        )))
983    } else {
984        e.bits(num_bits as i64, byte);
985        Ok(())
986    }
987}
988
989pub fn encode_constant(tag: &[u8], e: &mut Encoder) -> Result<(), en::Error> {
990    e.encode_list_with(tag, encode_constant_tag)?;
991
992    Ok(())
993}
994
995pub fn decode_constant(d: &mut Decoder) -> Result<Vec<u8>, de::Error> {
996    d.decode_list_with(decode_constant_tag)
997}
998
999pub fn encode_constant_tag(tag: &u8, e: &mut Encoder) -> Result<(), en::Error> {
1000    safe_encode_bits(CONST_TAG_WIDTH, *tag, e)
1001}
1002
1003pub fn decode_constant_tag(d: &mut Decoder) -> Result<u8, de::Error> {
1004    d.bits8(CONST_TAG_WIDTH as usize)
1005}
1006
#[cfg(test)]
mod tests {
    use super::{Constant, Program, Term};
    use crate::pallas_codec::flat::Flat;
    use crate::uplc::ast::{Name, Type};

    /// Flat encoding of [`integer_program`].
    const INTEGER_FLAT: [u8; 6] = [
        0b00001011, 0b00010110, 0b00100001, 0b01001000, 0b00000101, 0b10000001,
    ];

    /// Flat encoding of [`list_list_integer_program`].
    const LIST_LIST_INTEGER_FLAT: [u8; 10] = [
        0b00000001, 0b00000000, 0b00000000, 0b01001011, 0b11010110, 0b11110101, 0b10000011,
        0b00001110, 0b01100001, 0b01000001,
    ];

    /// Flat encoding of [`pair_pair_integer_bool_integer_program`].
    const PAIR_PAIR_INTEGER_BOOL_INTEGER_FLAT: [u8; 12] = [
        0b00000001, 0b00000000, 0b00000000, 0b01001011, 0b11011110, 0b11010111, 0b10111101,
        0b10100001, 0b01001000, 0b00000101, 0b10100010, 0b11000001,
    ];

    /// Version (11, 22, 33) wrapping the integer constant 11.
    fn integer_program() -> Program<Name> {
        Program {
            version: (11, 22, 33),
            term: Term::Constant(Constant::Integer(11.into()).into()),
        }
    }

    /// Version (1, 0, 0) wrapping a list of integer lists: `[[7], [5]]`.
    fn list_list_integer_program() -> Program<Name> {
        let seven = Constant::ProtoList(Type::Integer, vec![Constant::Integer(7.into())]);
        let five = Constant::ProtoList(Type::Integer, vec![Constant::Integer(5.into())]);
        let outer = Constant::ProtoList(Type::List(Type::Integer.into()), vec![seven, five]);

        Program {
            version: (1, 0, 0),
            term: Term::Constant(outer.into()),
        }
    }

    /// Version (1, 0, 0) wrapping the nested pair `((11, true), 11)`.
    fn pair_pair_integer_bool_integer_program() -> Program<Name> {
        let inner = Constant::ProtoPair(
            Type::Integer,
            Type::Bool,
            Constant::Integer(11.into()).into(),
            Constant::Bool(true).into(),
        );
        let outer = Constant::ProtoPair(
            Type::Pair(Type::Integer.into(), Type::Bool.into()),
            Type::Integer,
            inner.into(),
            Constant::Integer(11.into()).into(),
        );

        Program {
            version: (1, 0, 0),
            term: Term::Constant(outer.into()),
        }
    }

    #[test]
    fn flat_encode_integer() {
        let encoded = integer_program().to_flat().unwrap();
        assert_eq!(encoded, INTEGER_FLAT.to_vec())
    }

    #[test]
    fn flat_encode_list_list_integer() {
        let encoded = list_list_integer_program().to_flat().unwrap();
        assert_eq!(encoded, LIST_LIST_INTEGER_FLAT.to_vec())
    }

    #[test]
    fn flat_encode_pair_pair_integer_bool_integer() {
        let encoded = pair_pair_integer_bool_integer_program().to_flat().unwrap();
        assert_eq!(encoded, PAIR_PAIR_INTEGER_BOOL_INTEGER_FLAT.to_vec())
    }

    #[test]
    fn flat_decode_list_list_integer() {
        let bytes = LIST_LIST_INTEGER_FLAT.to_vec();
        let decoded: Program<Name> = Program::unflat(&bytes).unwrap();
        assert_eq!(decoded, list_list_integer_program())
    }

    #[test]
    fn flat_decode_pair_pair_integer_bool_integer() {
        let bytes = PAIR_PAIR_INTEGER_BOOL_INTEGER_FLAT.to_vec();
        let decoded: Program<Name> = Program::unflat(&bytes).unwrap();
        assert_eq!(decoded, pair_pair_integer_bool_integer_program())
    }

    #[test]
    fn flat_decode_integer() {
        let bytes = INTEGER_FLAT.to_vec();
        let decoded: Program<Name> = Program::unflat(&bytes).unwrap();
        assert_eq!(decoded, integer_program())
    }
}