1mod encoder;
2mod error;
3
4pub use encoder::Encoder;
5pub use error::FlatEncodeError;
6
7use crate::{binder::Binder, constant::Constant, program::Program, term::Term, typ::Type};
8
9use super::tag;
10
11pub fn encode<'a, V>(program: &'a Program<'a, V>) -> Result<Vec<u8>, FlatEncodeError>
12where
13 V: Binder<'a>,
14{
15 let mut encoder = Encoder::default();
16
17 encoder
18 .word(program.version.major())
19 .word(program.version.minor())
20 .word(program.version.patch());
21
22 encode_term(&mut encoder, program.term)?;
23
24 encoder.filler();
25
26 Ok(encoder.buffer)
27}
28
29fn encode_term<'a, V>(encoder: &mut Encoder, term: &'a Term<'a, V>) -> Result<(), FlatEncodeError>
30where
31 V: Binder<'a>,
32{
33 match term {
34 Term::Var(name) => {
35 encode_term_tag(encoder, tag::VAR)?;
36
37 name.var_encode(encoder)?;
38 }
39 Term::Lambda { parameter, body } => {
40 encode_term_tag(encoder, tag::LAMBDA)?;
41
42 parameter.parameter_encode(encoder)?;
43
44 encode_term(encoder, body)?;
45 }
46 Term::Apply { function, argument } => {
47 encode_term_tag(encoder, tag::APPLY)?;
48
49 encode_term(encoder, function)?;
50
51 encode_term(encoder, argument)?;
52 }
53 Term::Delay(body) => {
54 encode_term_tag(encoder, tag::DELAY)?;
55
56 encode_term(encoder, body)?;
57 }
58 Term::Force(body) => {
59 encode_term_tag(encoder, tag::FORCE)?;
60
61 encode_term(encoder, body)?;
62 }
63 Term::Case { constr, branches } => {
64 encode_term_tag(encoder, tag::CASE)?;
65
66 encode_term(encoder, constr)?;
67
68 encoder.list_with(branches, |e, t| encode_term(e, t))?;
69 }
70 Term::Constr { tag, fields } => {
71 encode_term_tag(encoder, tag::CONSTR)?;
72
73 encoder.word(*tag);
74
75 encoder.list_with(fields, |e, t| encode_term(e, t))?;
76 }
77 Term::Constant(c) => {
78 encode_term_tag(encoder, tag::CONSTANT)?;
79
80 encode_constant(encoder, c)?;
81 }
82 Term::Builtin(b) => {
83 encode_term_tag(encoder, tag::BUILTIN)?;
84
85 encoder.bits(tag::BUILTIN_TAG_WIDTH as i64, **b as u8);
86 }
87 Term::Error => {
88 encode_term_tag(encoder, tag::ERROR)?;
89 }
90 }
91
92 Ok(())
93}
94
95fn encode_constant<'a>(e: &mut Encoder, constant: &'a Constant<'a>) -> Result<(), FlatEncodeError> {
96 match constant {
97 Constant::Integer(i) => {
98 e.list_with(&[tag::INTEGER], encode_constant_tag)?;
99
100 e.integer(i);
101 }
102 Constant::ByteString(b) => {
103 e.list_with(&[tag::BYTE_STRING], encode_constant_tag)?;
104
105 e.bytes(b)?;
106 }
107 Constant::String(s) => {
108 e.list_with(&[tag::STRING], encode_constant_tag)?;
109
110 e.utf8(s)?;
111 }
112 Constant::Unit => {
113 e.list_with(&[tag::UNIT], encode_constant_tag)?;
114 }
115 Constant::Boolean(b) => {
116 e.list_with(&[tag::BOOL], encode_constant_tag)?;
117
118 e.bool(*b);
119 }
120 Constant::Data(data) => {
121 e.list_with(&[tag::DATA], encode_constant_tag)?;
122
123 let data = minicbor::to_vec(*data)?;
124
125 e.bytes(&data)?;
126 }
127 Constant::ProtoList(typ, list) => {
128 let mut type_encodings = vec![tag::PROTO_LIST_ONE, tag::PROTO_LIST_TWO];
129
130 encode_type(typ, &mut type_encodings)?;
131
132 e.list_with(&type_encodings, encode_constant_tag)?;
133
134 e.list_with(list, encode_constant_value)?;
135 }
136 Constant::ProtoArray(typ, array) => {
137 let mut type_encodings = vec![tag::PROTO_ARRAY_ONE, tag::PROTO_ARRAY_TWO];
138
139 encode_type(typ, &mut type_encodings)?;
140
141 e.list_with(&type_encodings, encode_constant_tag)?;
142
143 e.list_with(array, encode_constant_value)?;
144 }
145 Constant::ProtoPair(fst_type, snd_type, fst, snd) => {
146 let mut type_encodings = vec![
147 tag::PROTO_PAIR_ONE,
148 tag::PROTO_PAIR_TWO,
149 tag::PROTO_PAIR_THREE,
150 ];
151
152 encode_type(fst_type, &mut type_encodings)?;
153
154 encode_type(snd_type, &mut type_encodings)?;
155
156 e.list_with(&type_encodings, encode_constant_tag)?;
157
158 encode_constant_value(e, fst)?;
159 encode_constant_value(e, snd)?;
160 }
161 Constant::Bls12_381G1Element(_)
162 | Constant::Bls12_381G2Element(_)
163 | Constant::Bls12_381MlResult(_) => return Err(FlatEncodeError::BlsElementNotSupported),
164 }
165
166 Ok(())
167}
168
169fn encode_term_tag(e: &mut Encoder, tag: u8) -> Result<(), FlatEncodeError> {
170 safe_encode_bits(e, tag::TERM_TAG_WIDTH, tag)
171}
172
173fn encode_constant_tag(e: &mut Encoder, tag: &u8) -> Result<(), FlatEncodeError> {
174 safe_encode_bits(e, tag::CONST_TAG_WIDTH, *tag)
175}
176
177fn encode_type(typ: &Type, bytes: &mut Vec<u8>) -> Result<(), FlatEncodeError> {
178 match typ {
179 Type::Integer => bytes.push(tag::INTEGER),
180 Type::ByteString => bytes.push(tag::BYTE_STRING),
181 Type::String => bytes.push(tag::STRING),
182 Type::Unit => bytes.push(tag::UNIT),
183 Type::Bool => bytes.push(tag::BOOL),
184 Type::List(sub_typ) => {
185 bytes.extend(vec![tag::PROTO_LIST_ONE, tag::PROTO_LIST_TWO]);
186
187 encode_type(sub_typ, bytes)?;
188 }
189 Type::Array(sub_typ) => {
190 bytes.extend(vec![tag::PROTO_ARRAY_ONE, tag::PROTO_ARRAY_TWO]);
191
192 encode_type(sub_typ, bytes)?;
193 }
194 Type::Pair(type1, type2) => {
195 bytes.extend(vec![
196 tag::PROTO_PAIR_ONE,
197 tag::PROTO_PAIR_TWO,
198 tag::PROTO_PAIR_THREE,
199 ]);
200
201 encode_type(type1, bytes)?;
202 encode_type(type2, bytes)?;
203 }
204 Type::Data => bytes.push(tag::DATA),
205 Type::Bls12_381G1Element | Type::Bls12_381G2Element | Type::Bls12_381MlResult => {
206 return Err(FlatEncodeError::BlsElementNotSupported)
207 }
208 }
209
210 Ok(())
211}
212
213fn encode_constant_value<'a>(e: &mut Encoder, x: &'a &Constant<'a>) -> Result<(), FlatEncodeError> {
214 match *x {
215 Constant::Integer(x) => {
216 e.integer(x);
217 }
218 Constant::ByteString(b) => {
219 e.bytes(b)?;
220 }
221 Constant::String(s) => {
222 e.utf8(s)?;
223 }
224 Constant::Unit => (),
225 Constant::Boolean(b) => {
226 e.bool(*b);
227 }
228 Constant::ProtoList(_, list) => {
229 e.list_with(list, encode_constant_value)?;
230 }
231 Constant::ProtoArray(_, array) => {
232 e.list_with(array, encode_constant_value)?;
233 }
234 Constant::ProtoPair(_, _, a, b) => {
235 encode_constant_value(e, a)?;
236
237 encode_constant_value(e, b)?;
238 }
239 Constant::Data(_data) => {
240 todo!();
241 }
242 Constant::Bls12_381G1Element(_)
243 | Constant::Bls12_381G2Element(_)
244 | Constant::Bls12_381MlResult(_) => return Err(FlatEncodeError::BlsElementNotSupported),
245 }
246
247 Ok(())
248}
249
250fn safe_encode_bits(e: &mut Encoder, num_bits: usize, byte: u8) -> Result<(), FlatEncodeError> {
251 if 2_u8.pow(num_bits as u32) <= byte {
252 Err(FlatEncodeError::Overflow { byte, num_bits })
253 } else {
254 e.bits(num_bits as i64, byte);
255
256 Ok(())
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::Arena;
    use crate::binder::DeBruijn;
    use crate::flat::decode;
    use crate::machine::PlutusVersion;

    /// Decodes `bytes_hex` as a V3 program, re-encodes it, and asserts that the
    /// re-encoded bytes are identical to the input. Decode and encode failures
    /// panic with the underlying error so the failing stage is visible.
    fn assert_roundtrip(bytes_hex: &str) {
        let bytes = hex::decode(bytes_hex).expect("test vector must be valid hex");
        let arena = Arena::new();

        let program: &Program<DeBruijn> = decode(&arena, &bytes, PlutusVersion::V3, 9)
            .unwrap_or_else(|e| panic!("decode failed: {}", e));

        let roundtripped = encode(program).unwrap_or_else(|e| panic!("encode failed: {}", e));

        assert_eq!(bytes_hex, hex::encode(roundtripped));
    }

    #[test]
    fn roundtrip_program_big_constr_tag() {
        assert_roundtrip("0101003370090011aab9d37549810cd8668218809f4100420101ff0001");
    }

    #[test]
    fn roundtrip_program_bigint() {
        assert_roundtrip(
            "0101003370090011bad357426aae78dd526112d8799fc24c033b2e3c9fd0803ce7ffffffff0001",
        );
    }

    #[test]
    fn roundtrip_program_list() {
        assert_roundtrip("0101003370490021bad357426ae88dd62601049f070eff0001");
    }
}