Skip to main content

amaru_uplc/flat/encode/
mod.rs

1mod encoder;
2mod error;
3
4pub use encoder::Encoder;
5pub use error::FlatEncodeError;
6
7use crate::{binder::Binder, constant::Constant, program::Program, term::Term, typ::Type};
8
9use super::tag;
10
11pub fn encode<'a, V>(program: &'a Program<'a, V>) -> Result<Vec<u8>, FlatEncodeError>
12where
13    V: Binder<'a>,
14{
15    let mut encoder = Encoder::default();
16
17    encoder
18        .word(program.version.major())
19        .word(program.version.minor())
20        .word(program.version.patch());
21
22    encode_term(&mut encoder, program.term)?;
23
24    encoder.filler();
25
26    Ok(encoder.buffer)
27}
28
29fn encode_term<'a, V>(encoder: &mut Encoder, term: &'a Term<'a, V>) -> Result<(), FlatEncodeError>
30where
31    V: Binder<'a>,
32{
33    match term {
34        Term::Var(name) => {
35            encode_term_tag(encoder, tag::VAR)?;
36
37            name.var_encode(encoder)?;
38        }
39        Term::Lambda { parameter, body } => {
40            encode_term_tag(encoder, tag::LAMBDA)?;
41
42            parameter.parameter_encode(encoder)?;
43
44            encode_term(encoder, body)?;
45        }
46        Term::Apply { function, argument } => {
47            encode_term_tag(encoder, tag::APPLY)?;
48
49            encode_term(encoder, function)?;
50
51            encode_term(encoder, argument)?;
52        }
53        Term::Delay(body) => {
54            encode_term_tag(encoder, tag::DELAY)?;
55
56            encode_term(encoder, body)?;
57        }
58        Term::Force(body) => {
59            encode_term_tag(encoder, tag::FORCE)?;
60
61            encode_term(encoder, body)?;
62        }
63        Term::Case { constr, branches } => {
64            encode_term_tag(encoder, tag::CASE)?;
65
66            encode_term(encoder, constr)?;
67
68            encoder.list_with(branches, |e, t| encode_term(e, t))?;
69        }
70        Term::Constr { tag, fields } => {
71            encode_term_tag(encoder, tag::CONSTR)?;
72
73            encoder.word(*tag);
74
75            encoder.list_with(fields, |e, t| encode_term(e, t))?;
76        }
77        Term::Constant(c) => {
78            encode_term_tag(encoder, tag::CONSTANT)?;
79
80            encode_constant(encoder, c)?;
81        }
82        Term::Builtin(b) => {
83            encode_term_tag(encoder, tag::BUILTIN)?;
84
85            encoder.bits(tag::BUILTIN_TAG_WIDTH as i64, **b as u8);
86        }
87        Term::Error => {
88            encode_term_tag(encoder, tag::ERROR)?;
89        }
90    }
91
92    Ok(())
93}
94
95fn encode_constant<'a>(e: &mut Encoder, constant: &'a Constant<'a>) -> Result<(), FlatEncodeError> {
96    match constant {
97        Constant::Integer(i) => {
98            e.list_with(&[tag::INTEGER], encode_constant_tag)?;
99
100            e.integer(i);
101        }
102        Constant::ByteString(b) => {
103            e.list_with(&[tag::BYTE_STRING], encode_constant_tag)?;
104
105            e.bytes(b)?;
106        }
107        Constant::String(s) => {
108            e.list_with(&[tag::STRING], encode_constant_tag)?;
109
110            e.utf8(s)?;
111        }
112        Constant::Unit => {
113            e.list_with(&[tag::UNIT], encode_constant_tag)?;
114        }
115        Constant::Boolean(b) => {
116            e.list_with(&[tag::BOOL], encode_constant_tag)?;
117
118            e.bool(*b);
119        }
120        Constant::Data(data) => {
121            e.list_with(&[tag::DATA], encode_constant_tag)?;
122
123            let data = minicbor::to_vec(*data)?;
124
125            e.bytes(&data)?;
126        }
127        Constant::ProtoList(typ, list) => {
128            let mut type_encodings = vec![tag::PROTO_LIST_ONE, tag::PROTO_LIST_TWO];
129
130            encode_type(typ, &mut type_encodings)?;
131
132            e.list_with(&type_encodings, encode_constant_tag)?;
133
134            e.list_with(list, encode_constant_value)?;
135        }
136        Constant::ProtoArray(typ, array) => {
137            let mut type_encodings = vec![tag::PROTO_ARRAY_ONE, tag::PROTO_ARRAY_TWO];
138
139            encode_type(typ, &mut type_encodings)?;
140
141            e.list_with(&type_encodings, encode_constant_tag)?;
142
143            e.list_with(array, encode_constant_value)?;
144        }
145        Constant::ProtoPair(fst_type, snd_type, fst, snd) => {
146            let mut type_encodings = vec![
147                tag::PROTO_PAIR_ONE,
148                tag::PROTO_PAIR_TWO,
149                tag::PROTO_PAIR_THREE,
150            ];
151
152            encode_type(fst_type, &mut type_encodings)?;
153
154            encode_type(snd_type, &mut type_encodings)?;
155
156            e.list_with(&type_encodings, encode_constant_tag)?;
157
158            encode_constant_value(e, fst)?;
159            encode_constant_value(e, snd)?;
160        }
161        Constant::Bls12_381G1Element(_)
162        | Constant::Bls12_381G2Element(_)
163        | Constant::Bls12_381MlResult(_) => return Err(FlatEncodeError::BlsElementNotSupported),
164    }
165
166    Ok(())
167}
168
169fn encode_term_tag(e: &mut Encoder, tag: u8) -> Result<(), FlatEncodeError> {
170    safe_encode_bits(e, tag::TERM_TAG_WIDTH, tag)
171}
172
173fn encode_constant_tag(e: &mut Encoder, tag: &u8) -> Result<(), FlatEncodeError> {
174    safe_encode_bits(e, tag::CONST_TAG_WIDTH, *tag)
175}
176
177fn encode_type(typ: &Type, bytes: &mut Vec<u8>) -> Result<(), FlatEncodeError> {
178    match typ {
179        Type::Integer => bytes.push(tag::INTEGER),
180        Type::ByteString => bytes.push(tag::BYTE_STRING),
181        Type::String => bytes.push(tag::STRING),
182        Type::Unit => bytes.push(tag::UNIT),
183        Type::Bool => bytes.push(tag::BOOL),
184        Type::List(sub_typ) => {
185            bytes.extend(vec![tag::PROTO_LIST_ONE, tag::PROTO_LIST_TWO]);
186
187            encode_type(sub_typ, bytes)?;
188        }
189        Type::Array(sub_typ) => {
190            bytes.extend(vec![tag::PROTO_ARRAY_ONE, tag::PROTO_ARRAY_TWO]);
191
192            encode_type(sub_typ, bytes)?;
193        }
194        Type::Pair(type1, type2) => {
195            bytes.extend(vec![
196                tag::PROTO_PAIR_ONE,
197                tag::PROTO_PAIR_TWO,
198                tag::PROTO_PAIR_THREE,
199            ]);
200
201            encode_type(type1, bytes)?;
202            encode_type(type2, bytes)?;
203        }
204        Type::Data => bytes.push(tag::DATA),
205        Type::Bls12_381G1Element | Type::Bls12_381G2Element | Type::Bls12_381MlResult => {
206            return Err(FlatEncodeError::BlsElementNotSupported)
207        }
208    }
209
210    Ok(())
211}
212
213fn encode_constant_value<'a>(e: &mut Encoder, x: &'a &Constant<'a>) -> Result<(), FlatEncodeError> {
214    match *x {
215        Constant::Integer(x) => {
216            e.integer(x);
217        }
218        Constant::ByteString(b) => {
219            e.bytes(b)?;
220        }
221        Constant::String(s) => {
222            e.utf8(s)?;
223        }
224        Constant::Unit => (),
225        Constant::Boolean(b) => {
226            e.bool(*b);
227        }
228        Constant::ProtoList(_, list) => {
229            e.list_with(list, encode_constant_value)?;
230        }
231        Constant::ProtoArray(_, array) => {
232            e.list_with(array, encode_constant_value)?;
233        }
234        Constant::ProtoPair(_, _, a, b) => {
235            encode_constant_value(e, a)?;
236
237            encode_constant_value(e, b)?;
238        }
239        Constant::Data(_data) => {
240            todo!();
241        }
242        Constant::Bls12_381G1Element(_)
243        | Constant::Bls12_381G2Element(_)
244        | Constant::Bls12_381MlResult(_) => return Err(FlatEncodeError::BlsElementNotSupported),
245    }
246
247    Ok(())
248}
249
250fn safe_encode_bits(e: &mut Encoder, num_bits: usize, byte: u8) -> Result<(), FlatEncodeError> {
251    if 2_u8.pow(num_bits as u32) <= byte {
252        Err(FlatEncodeError::Overflow { byte, num_bits })
253    } else {
254        e.bits(num_bits as i64, byte);
255
256        Ok(())
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::Arena;
    use crate::binder::DeBruijn;
    use crate::flat::decode;
    use crate::machine::PlutusVersion;

    /// Decodes `bytes_hex` as a V3 program, re-encodes it, and asserts the bytes are unchanged.
    ///
    /// Panics with the underlying error message if either decoding or encoding fails, so a
    /// failing round-trip always reports *why* it failed.
    fn assert_roundtrip(bytes_hex: &str) {
        let bytes = hex::decode(bytes_hex).expect("test vector must be valid hex");
        let arena = Arena::new();

        let program: &Program<DeBruijn> = decode(&arena, &bytes, PlutusVersion::V3, 9)
            .unwrap_or_else(|e| panic!("decode failed: {}", e));

        let roundtripped = encode(program).unwrap_or_else(|e| panic!("encode failed: {}", e));

        assert_eq!(bytes_hex, hex::encode(roundtripped));
    }

    #[test]
    fn roundtrip_program_big_constr_tag() {
        // (program 1.1.0
        //   [
        //     [
        //       (builtin addInteger)
        //       (con integer 1)
        //     ]
        //     [ (force (force (builtin fstPair)))
        //       [ (builtin unConstrData)
        //         (con data (Constr 128 [B #00, B #0101]))
        //       ]
        //     ]
        //   ])
        assert_roundtrip("0101003370090011aab9d37549810cd8668218809f4100420101ff0001");
    }

    #[test]
    fn roundtrip_program_bigint() {
        // (program 1.1.0
        //   [
        //     [
        //       (builtin addInteger)
        //       (con integer 1)
        //     ]
        //     [ (builtin unIData)
        //       [ (force (builtin headList))
        //         [ (force (force (builtin sndPair)))
        //           [ (builtin unConstrData)
        //             (con data (Constr 0 [I 999999999999999999999999999]))
        //           ]
        //         ]
        //       ]
        //     ]
        //   ])
        assert_roundtrip(
            "0101003370090011bad357426aae78dd526112d8799fc24c033b2e3c9fd0803ce7ffffffff0001",
        );
    }

    #[test]
    fn roundtrip_program_list() {
        // (program 1.1.0
        //   [
        //     [
        //       (builtin multiplyInteger)
        //       (con integer 2)
        //     ]
        //     [ (builtin unIData)
        //       [ (force (builtin headList))
        //         [ (force (builtin tailList))
        //           [ (builtin unListData)
        //             (con data (List [I 7, I 14]))
        //           ]
        //         ]
        //       ]
        //     ]
        //   ])
        assert_roundtrip("0101003370490021bad357426ae88dd62601049f070eff0001");
    }
}