simplicity/bit_encoding/
decode.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! # Decoding
4//!
5//! Functionality to decode Simplicity programs.
6//! Refer to [`crate::encode`] for information on the encoding.
7
8use crate::dag::{Dag, DagLike, InternalSharing};
9use crate::jet::Jet;
10use crate::merkle::cmr::Cmr;
11use crate::node::{
12    ConstructNode, CoreConstructible, DisconnectConstructible, JetConstructible,
13    WitnessConstructible,
14};
15use crate::types;
16use crate::value::Word;
17use crate::{BitIter, FailEntropy};
18use std::collections::HashSet;
19use std::sync::Arc;
20use std::{error, fmt};
21
22use super::bititer::u2;
23
24type ArcNode<J> = Arc<ConstructNode<J>>;
25
26/// Decoding error
27#[non_exhaustive]
28#[derive(Debug)]
29pub enum Error {
30    /// Node made a back-reference past the beginning of the program
31    BadIndex,
32    /// Error closing the bitstream
33    BitIter(crate::BitIterCloseError),
34    /// Both children of a node are hidden
35    BothChildrenHidden,
36    /// Program must not be empty
37    EmptyProgram,
38    /// Bitstream ended early
39    EndOfStream,
40    /// Hidden node occurred outside of a case combinator
41    HiddenNode,
42    /// Tried to parse a jet but the name wasn't recognized
43    InvalidJet,
44    /// Number exceeded 32 bits
45    NaturalOverflow,
46    /// Program is not encoded in canonical order
47    NotInCanonicalOrder,
48    /// Program does not have maximal sharing
49    SharingNotMaximal,
50    /// Tried to allocate too many nodes in a program
51    TooManyNodes(usize),
52    /// Type-checking error
53    Type(crate::types::Error),
54}
55
56impl From<crate::EarlyEndOfStreamError> for Error {
57    fn from(_: crate::EarlyEndOfStreamError) -> Error {
58        Error::EndOfStream
59    }
60}
61
62impl From<crate::BitIterCloseError> for Error {
63    fn from(e: crate::BitIterCloseError) -> Error {
64        Error::BitIter(e)
65    }
66}
67
68impl From<crate::types::Error> for Error {
69    fn from(e: crate::types::Error) -> Error {
70        Error::Type(e)
71    }
72}
73
74impl fmt::Display for Error {
75    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
76        match self {
77            Error::BadIndex => {
78                f.write_str("node made a back-reference past the beginning of the program")
79            }
80            Error::BitIter(ref e) => fmt::Display::fmt(e, f),
81            Error::BothChildrenHidden => f.write_str("both children of a case node are hidden"),
82            Error::EmptyProgram => f.write_str("empty program"),
83            Error::EndOfStream => f.write_str("bitstream ended early"),
84            Error::HiddenNode => write!(f, "hidden node occurred outside of a case combinator"),
85            Error::InvalidJet => write!(f, "unrecognized jet"),
86            Error::NaturalOverflow => f.write_str("encoded number exceeded 32 bits"),
87            Error::NotInCanonicalOrder => f.write_str("program not in canonical order"),
88            Error::SharingNotMaximal => f.write_str("Decoded programs must have maximal sharing"),
89            Error::TooManyNodes(k) => {
90                write!(f, "program has too many nodes ({})", k)
91            }
92            Error::Type(ref e) => fmt::Display::fmt(e, f),
93        }
94    }
95}
96
97impl error::Error for Error {
98    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
99        match *self {
100            Error::BadIndex => None,
101            Error::BitIter(ref e) => Some(e),
102            Error::BothChildrenHidden => None,
103            Error::EmptyProgram => None,
104            Error::EndOfStream => None,
105            Error::HiddenNode => None,
106            Error::InvalidJet => None,
107            Error::NaturalOverflow => None,
108            Error::NotInCanonicalOrder => None,
109            Error::SharingNotMaximal => None,
110            Error::TooManyNodes(..) => None,
111            Error::Type(ref e) => Some(e),
112        }
113    }
114}
115
116#[derive(Debug)]
117enum DecodeNode<J: Jet> {
118    Iden,
119    Unit,
120    InjL(usize),
121    InjR(usize),
122    Take(usize),
123    Drop(usize),
124    Comp(usize, usize),
125    Case(usize, usize),
126    Pair(usize, usize),
127    Disconnect1(usize),
128    Disconnect(usize, usize),
129    Witness,
130    Fail(FailEntropy),
131    Hidden(Cmr),
132    Jet(J),
133    Word(Word),
134}
135
136impl<J: Jet> DagLike for (usize, &'_ [DecodeNode<J>]) {
137    type Node = DecodeNode<J>;
138
139    fn data(&self) -> &DecodeNode<J> {
140        &self.1[self.0]
141    }
142
143    fn as_dag_node(&self) -> Dag<Self> {
144        let nodes = &self.1;
145        match self.1[self.0] {
146            DecodeNode::Iden
147            | DecodeNode::Unit
148            | DecodeNode::Fail(..)
149            | DecodeNode::Hidden(..)
150            | DecodeNode::Jet(..)
151            | DecodeNode::Word(..) => Dag::Nullary,
152            DecodeNode::InjL(i)
153            | DecodeNode::InjR(i)
154            | DecodeNode::Take(i)
155            | DecodeNode::Drop(i)
156            | DecodeNode::Disconnect1(i) => Dag::Unary((i, nodes)),
157            DecodeNode::Comp(li, ri)
158            | DecodeNode::Case(li, ri)
159            | DecodeNode::Pair(li, ri)
160            | DecodeNode::Disconnect(li, ri) => Dag::Binary((li, nodes), (ri, nodes)),
161            DecodeNode::Witness => Dag::Nullary,
162        }
163    }
164}
165
166pub fn decode_expression<I: Iterator<Item = u8>, J: Jet>(
167    bits: &mut BitIter<I>,
168) -> Result<ArcNode<J>, Error> {
169    enum Converted<J: Jet> {
170        Node(ArcNode<J>),
171        Hidden(Cmr),
172    }
173    use Converted::{Hidden, Node};
174    impl<J: Jet> Converted<J> {
175        fn get(&self) -> Result<&ArcNode<J>, Error> {
176            match self {
177                Node(arc) => Ok(arc),
178                Hidden(_) => Err(Error::HiddenNode),
179            }
180        }
181    }
182
183    let len = bits.read_natural(None)?;
184
185    if len == 0 {
186        return Err(Error::EmptyProgram);
187    }
188    // FIXME: check maximum length of DAG that is allowed by consensus
189    if len > 1_000_000 {
190        return Err(Error::TooManyNodes(len));
191    }
192
193    let inference_context = types::Context::new();
194    let mut nodes = Vec::with_capacity(len);
195    for _ in 0..len {
196        let new_node = decode_node(bits, nodes.len())?;
197        nodes.push(new_node);
198    }
199
200    // It is a sharing violation for any hidden node to be repeated. Track them in this set.
201    let mut hidden_set = HashSet::<Cmr>::new();
202    // Convert the DecodeNode structure into a CommitNode structure
203    let mut converted = Vec::<Converted<J>>::with_capacity(len);
204    for data in (nodes.len() - 1, &nodes[..]).post_order_iter::<InternalSharing>() {
205        // Check canonical order as we go
206        if data.index != data.node.0 {
207            return Err(Error::NotInCanonicalOrder);
208        }
209
210        let new = match nodes[data.node.0] {
211            DecodeNode::Unit => Node(ArcNode::unit(&inference_context)),
212            DecodeNode::Iden => Node(ArcNode::iden(&inference_context)),
213            DecodeNode::InjL(i) => Node(ArcNode::injl(converted[i].get()?)),
214            DecodeNode::InjR(i) => Node(ArcNode::injr(converted[i].get()?)),
215            DecodeNode::Take(i) => Node(ArcNode::take(converted[i].get()?)),
216            DecodeNode::Drop(i) => Node(ArcNode::drop_(converted[i].get()?)),
217            DecodeNode::Comp(i, j) => {
218                Node(ArcNode::comp(converted[i].get()?, converted[j].get()?)?)
219            }
220            DecodeNode::Case(i, j) => {
221                // Case is a special case, since it uniquely is allowed to have hidden
222                // children (but only one!) in which case it becomes an assertion.
223                match (&converted[i], &converted[j]) {
224                    (Node(left), Node(right)) => Node(ArcNode::case(left, right)?),
225                    (Node(left), Hidden(cmr)) => Node(ArcNode::assertl(left, *cmr)?),
226                    (Hidden(cmr), Node(right)) => Node(ArcNode::assertr(*cmr, right)?),
227                    (Hidden(_), Hidden(_)) => return Err(Error::BothChildrenHidden),
228                }
229            }
230            DecodeNode::Pair(i, j) => {
231                Node(ArcNode::pair(converted[i].get()?, converted[j].get()?)?)
232            }
233            DecodeNode::Disconnect1(i) => Node(ArcNode::disconnect(converted[i].get()?, &None)?),
234            DecodeNode::Disconnect(i, j) => Node(ArcNode::disconnect(
235                converted[i].get()?,
236                &Some(Arc::clone(converted[j].get()?)),
237            )?),
238            DecodeNode::Witness => Node(ArcNode::witness(&inference_context, None)),
239            DecodeNode::Fail(entropy) => Node(ArcNode::fail(&inference_context, entropy)),
240            DecodeNode::Hidden(cmr) => {
241                if !hidden_set.insert(cmr) {
242                    return Err(Error::SharingNotMaximal);
243                }
244                Hidden(cmr)
245            }
246            DecodeNode::Jet(j) => Node(ArcNode::jet(&inference_context, j)),
247            DecodeNode::Word(ref w) => {
248                Node(ArcNode::const_word(&inference_context, w.shallow_clone()))
249            }
250        };
251        converted.push(new);
252    }
253
254    converted[len - 1].get().map(Arc::clone)
255}
256
257/// Decode a single Simplicity node from bits and
258/// insert it into a hash map at its index for future reference by ancestor nodes.
259fn decode_node<I: Iterator<Item = u8>, J: Jet>(
260    bits: &mut BitIter<I>,
261    index: usize,
262) -> Result<DecodeNode<J>, Error> {
263    // First bit: 1 for jets/words, 0 for normal combinators
264    if bits.read_bit()? {
265        // Second bit: 1 for jets, 0 for words
266        if bits.read_bit()? {
267            J::decode(bits).map(|jet| DecodeNode::Jet(jet))
268        } else {
269            let n = bits.read_natural(Some(32))? as u32; // cast safety: decoded number is at most the number 32
270            let word = Word::from_bits(bits, n - 1)?;
271            Ok(DecodeNode::Word(word))
272        }
273    } else {
274        // Bits 2 and 3: code
275        match bits.read_u2()? {
276            u2::_0 => {
277                let subcode = bits.read_u2()?;
278                let i_abs = index - bits.read_natural(Some(index))?;
279                let j_abs = index - bits.read_natural(Some(index))?;
280
281                // Bits 4 and 5: subcode
282                match subcode {
283                    u2::_0 => Ok(DecodeNode::Comp(i_abs, j_abs)),
284                    u2::_1 => Ok(DecodeNode::Case(i_abs, j_abs)),
285                    u2::_2 => Ok(DecodeNode::Pair(i_abs, j_abs)),
286                    u2::_3 => Ok(DecodeNode::Disconnect(i_abs, j_abs)),
287                }
288            }
289            u2::_1 => {
290                let subcode = bits.read_u2()?;
291                let i_abs = index - bits.read_natural(Some(index))?;
292                // Bits 4 and 5: subcode
293                match subcode {
294                    u2::_0 => Ok(DecodeNode::InjL(i_abs)),
295                    u2::_1 => Ok(DecodeNode::InjR(i_abs)),
296                    u2::_2 => Ok(DecodeNode::Take(i_abs)),
297                    u2::_3 => Ok(DecodeNode::Drop(i_abs)),
298                }
299            }
300            u2::_2 => {
301                // Bits 4 and 5: subcode
302                match bits.read_u2()? {
303                    u2::_0 => Ok(DecodeNode::Iden),
304                    u2::_1 => Ok(DecodeNode::Unit),
305                    u2::_2 => Ok(DecodeNode::Fail(bits.read_fail_entropy()?)),
306                    u2::_3 => {
307                        let i_abs = index - bits.read_natural(Some(index))?;
308                        Ok(DecodeNode::Disconnect1(i_abs))
309                    }
310                }
311            }
312            u2::_3 => {
313                // Bit 4: subcode
314                if bits.read_bit()? {
315                    Ok(DecodeNode::Witness)
316                } else {
317                    Ok(DecodeNode::Hidden(bits.read_cmr()?))
318                }
319            }
320        }
321    }
322}
323
324/// Decode a natural number from bits.
325/// If a bound is specified, then the decoding terminates before trying to decode a larger number.
326pub fn decode_natural<I: Iterator<Item = bool>>(
327    iter: &mut I,
328    bound: Option<usize>,
329) -> Result<usize, Error> {
330    let mut recurse_depth = 0;
331    loop {
332        match iter.next() {
333            Some(true) => recurse_depth += 1,
334            Some(false) => break,
335            None => return Err(Error::EndOfStream),
336        }
337    }
338
339    let mut len = 0;
340    loop {
341        let mut n = 1;
342        for _ in 0..len {
343            let bit = match iter.next() {
344                Some(false) => 0,
345                Some(true) => 1,
346                None => return Err(Error::EndOfStream),
347            };
348            n = 2 * n + bit;
349        }
350
351        if recurse_depth == 0 {
352            if let Some(bound) = bound {
353                if n > bound {
354                    return Err(Error::BadIndex);
355                }
356            }
357            return Ok(n);
358        } else {
359            len = n;
360            if len > 31 {
361                return Err(Error::NaturalOverflow);
362            }
363            recurse_depth -= 1;
364        }
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode;
    use crate::jet::Core;
    use crate::node::{CommitNode, RedeemNode};
    use crate::BitWriter;

    #[test]
    fn root_unit_to_unit() {
        // main = jet_eq_32 :: 2^64 -> 2 # 7387d279
        let justjet = [0x6d, 0xb8, 0x80];

        // A lone jet decodes fine as an expression...
        decode_expression::<_, Core>(&mut BitIter::from(&justjet[..])).unwrap();

        // ...but is rejected as a `CommitNode`...
        assert!(CommitNode::<Core>::decode(BitIter::from(&justjet[..])).is_err());

        // ...and as a `RedeemNode`.
        let no_witness = BitIter::from(&[][..]);
        assert!(RedeemNode::<Core>::decode(BitIter::from(&justjet[..]), no_witness).is_err());
    }

    #[test]
    fn decode_fixed_natural() {
        let tries = vec![
            (1, vec![false]),
            (2, vec![true, false, false]),
            (3, vec![true, false, true]),
            (4, vec![true, true, false, false, false, false]),
            (5, vec![true, true, false, false, false, true]),
            (6, vec![true, true, false, false, true, false]),
            (7, vec![true, true, false, false, true, true]),
            (8, vec![true, true, false, true, false, false, false]),
            (15, vec![true, true, false, true, true, true, true]),
            (
                16,
                vec![
                    true, true, true, false, // len: 1
                    false, // len: 2
                    false, false, // len: 4
                    false, false, false, false,
                ],
            ),
            // 31
            (
                31,
                vec![
                    true, true, true, false, // len: 1
                    false, // len: 2
                    false, false, // len: 4
                    true, true, true, true,
                ],
            ),
            // 32
            (
                32,
                vec![
                    true, true, true, false, // len: 1
                    false, // len: 2
                    false, true, // len: 5
                    false, false, false, false, false,
                ],
            ),
            // 2^15
            (
                32768,
                vec![
                    true, true, true, false, // len: 1
                    true,  // len: 3
                    true, true, true, // len: 15
                    false, false, false, false, false, false, false, false, false, false, false,
                    false, false, false, false,
                ],
            ),
            (
                65535,
                vec![
                    true, true, true, false, // len: 1
                    true,  // len: 3
                    true, true, true, // len: 15
                    true, true, true, true, true, true, true, true, true, true, true, true, true,
                    true, true,
                ],
            ),
            (
                65536,
                vec![
                    true, true, true, true, false, // len: 1
                    false, // len: 2
                    false, false, // len: 4
                    false, false, false, false, // len: 16
                    false, false, false, false, false, false, false, false, false, false, false,
                    false, false, false, false, false,
                ],
            ),
        ];

        for (natural, bits) in tries {
            // Dropping the final bit must leave an incomplete encoding.
            let (_, prefix) = bits.split_last().expect("every encoding has at least one bit");
            let truncated_result = decode_natural(&mut prefix.iter().copied(), None);
            assert!(matches!(truncated_result, Err(Error::EndOfStream)));

            // Encoding must produce exactly the expected number of bits...
            let mut sink = Vec::<u8>::new();
            let mut writer = BitWriter::from(&mut sink);
            encode::encode_natural(natural, &mut writer).expect("encoding to vector");
            writer.flush_all().expect("flushing");
            assert_eq!(writer.n_total_written(), bits.len());

            // ...and decoding it must round-trip.
            let mut reader = BitIter::from(sink.into_iter());
            let decoded = decode_natural(&mut reader, None).expect("decoding from vector");
            assert_eq!(natural, decoded);
        }
    }
}