Skip to main content

simplicity/
analysis.rs

1// SPDX-License-Identifier: CC0-1.0
2
3use crate::jet::Jet;
4use std::{cmp, fmt};
5
6use crate::value::Word;
7#[cfg(feature = "elements")]
8use elements::encode::Encodable;
9#[cfg(feature = "serde")]
10use serde::Serialize;
11#[cfg(feature = "elements")]
12use std::{convert::TryFrom, io};
13
/// A [`u32`]-backed counterpart of [`bitcoin::Weight`] (which is backed by [`u64`]).
///
/// Serves as the intermediate representation when converting between
/// [`bitcoin::Weight`] and [`Cost`] (which is backed by [`u32`]).
#[cfg(feature = "bitcoin")]
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct U32Weight(u32);

#[cfg(feature = "bitcoin")]
impl std::ops::Sub for U32Weight {
    type Output = Self;

    /// Subtract two weights, clamping at zero instead of underflowing.
    fn sub(self, rhs: Self) -> Self::Output {
        let U32Weight(minuend) = self;
        let U32Weight(subtrahend) = rhs;
        U32Weight(minuend.saturating_sub(subtrahend))
    }
}

#[cfg(feature = "bitcoin")]
impl From<bitcoin::Weight> for U32Weight {
    /// Narrow a [`bitcoin::Weight`]; weights beyond the `u32` range clamp to `u32::MAX`.
    fn from(value: bitcoin::Weight) -> Self {
        let clamped = value.to_wu().min(u64::from(u32::MAX));
        // Lossless: `clamped` was just bounded by `u32::MAX`.
        U32Weight(clamped as u32)
    }
}

#[cfg(feature = "bitcoin")]
impl From<U32Weight> for bitcoin::Weight {
    /// Widen back into a [`bitcoin::Weight`]; this direction is always lossless.
    fn from(value: U32Weight) -> Self {
        let U32Weight(wu) = value;
        bitcoin::Weight::from_wu(wu.into())
    }
}
44
/// CPU cost of a Simplicity expression.
///
/// The cost is measured in milli weight units
/// and can be converted into weight units using the appropriate method.
///
/// Roughly speaking, the operational semantics of a combinator
/// on the Bit Machine determine its cost.
///
/// First, every combinator has a fixed overhead cost.
/// Frame allocations, copy and write operations cost proportional
/// to the number of allocated or written bits.
/// Frame moves / drops or cursor moves are one-step operations
/// that are covered by the overhead.
///
/// The cost of a program is compared to its _budget_.
/// A program is valid if it does not exceed its budget.
///
/// The budget is the size of the witness stack
/// of the transaction input that includes the program.
/// Users pay for their Simplicity programs in terms of fees
/// which are based on transaction size, like normal Tapscript.
///
/// Programs that are CPU-heavy need to be padded
/// so that the witness stack provides a large-enough budget.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct Cost(u32);

impl Cost {
    /// Overhead constant.
    ///
    /// Every combinator that is executed has this overhead added to its cost.
    const OVERHEAD: Self = Cost(100);

    /// Cost of combinators that are never executed.
    ///
    /// **This should only be used for `fail` nodes!**
    const NEVER_EXECUTED: Self = Cost(0);

    /// Maximum cost allowed by consensus.
    ///
    /// This is equal to the maximum budget that any program
    /// can have inside a Taproot transaction:
    /// 4 million weight units plus 50 free weight units for validation.
    ///
    /// This assumes a block that consists of a single transaction
    /// which in turn consists of nothing but its witness stack.
    ///
    /// Transactions include other data besides the witness stack.
    /// Also, transaction may have multiple inputs and
    /// blocks usually include multiple transactions.
    /// This means that the maximum budget is an unreachable upper bound.
    pub const CONSENSUS_MAX: Self = Cost(4_000_050_000);

    /// Return the cost of a type with the given bit width.
    ///
    /// Every bit costs one milli weight unit. Bit widths that do not fit
    /// into a `u32` saturate at `u32::MAX` (far above [`Cost::CONSENSUS_MAX`])
    /// instead of being truncated: a truncating cast could wrap a huge type
    /// around to a tiny, consensus-valid cost.
    pub const fn of_type(bit_width: usize) -> Self {
        // `usize -> u64` is lossless on every supported target, so the
        // comparison is exact. `TryFrom` cannot be used inside a `const fn`.
        if (bit_width as u64) > (u32::MAX as u64) {
            Cost(u32::MAX)
        } else {
            Cost(bit_width as u32)
        }
    }

    /// Convert the given milli weight units into cost.
    pub const fn from_milliweight(milliweight: u32) -> Self {
        Cost(milliweight)
    }

    /// Return whether the cost is allowed by consensus.
    ///
    /// This means the cost is within the maximum budget
    /// that any program inside a Taproot transaction can have.
    pub fn is_consensus_valid(self) -> bool {
        self <= Self::CONSENSUS_MAX
    }

    /// Return the budget (in weight units) of the given script witness
    /// of a transaction input.
    ///
    /// The budget is the consensus-serialized length of the witness stack
    /// plus 50 free weight units, saturating at `u32::MAX`.
    ///
    /// The script witness is passed as `&Vec<Vec<u8>>` in order to use
    /// the consensus encoding implemented for this type.
    #[cfg(feature = "elements")]
    fn get_budget(script_witness: &Vec<Vec<u8>>) -> U32Weight {
        let mut sink = io::sink();
        let witness_stack_serialized_len = script_witness
            .consensus_encode(&mut sink)
            .expect("writing to sink never fails");
        let budget = u32::try_from(witness_stack_serialized_len)
            .expect("Serialized witness stack must be shorter than 2^32 elements")
            .saturating_add(50);
        U32Weight(budget)
    }

    /// Return whether the cost is within the budget of
    /// the given script witness of a transaction input.
    ///
    /// The script witness is passed as `&Vec<Vec<u8>>` in order to use
    /// the consensus encoding implemented for this type.
    #[cfg(feature = "elements")]
    pub fn is_budget_valid(self, script_witness: &Vec<Vec<u8>>) -> bool {
        let budget = Self::get_budget(script_witness);
        self.0 <= budget.0.saturating_mul(1000)
    }

    /// Return the annex bytes that are required as padding
    /// so the transaction input has enough budget to cover the cost.
    ///
    /// Returns `None` if the given script witness already provides enough budget.
    ///
    /// The first annex byte is 0x50, as defined in BIP 341.
    /// The following padding bytes are 0x00.
    #[cfg(feature = "elements")]
    pub fn get_padding(self, script_witness: &Vec<Vec<u8>>) -> Option<Vec<u8>> {
        let weight = U32Weight::from(self);
        let budget = Self::get_budget(script_witness);
        if weight <= budget {
            return None;
        }

        // Adding the annex to the witness stack increases the serialized size by:
        //
        // 1. CompactSize(annex_len): the length prefix of the annex item
        // 2. annex_len: the annex bytes themselves (0x50 tag + zero padding)
        //
        // CompactSize uses 1 byte for values <= 252, 3 bytes for <= 65535,
        // and 5 bytes for larger values. The overhead subtracted must account
        // for the actual CompactSize encoding length of the resulting annex.
        let deficit = (weight - budget).0 as usize; // cast safety: 32-bit machine or higher

        // overhead = compact_size_len + 1 (for 0x50 tag)
        let padding_len = match deficit {
            // annex_len <= 252, compact_size uses 1 byte, overhead = 2
            0..=253 => deficit.saturating_sub(2),
            // Boundary region: annex must be >= 253 bytes (3-byte compact_size),
            // but deficit - 4 < 252. Use minimum padding for 3-byte encoding.
            254..=255 => 252,
            // annex_len in 253..=65535, compact_size uses 3 bytes, overhead = 4
            256..=65538 => deficit - 4,
            // Boundary region for 5-byte compact_size encoding.
            65539..=65540 => 65535,
            // annex_len >= 65536, compact_size uses 5 bytes, overhead = 6
            _ => deficit - 6,
            // Note: the 9-byte compact_size boundary (deficit > 4_294_967_300)
            // is unreachable because Cost uses u32 milliweight, limiting the
            // maximum deficit to ~4_294_968 weight units.
        };
        let annex_bytes: Vec<u8> = std::iter::once(0x50)
            .chain(std::iter::repeat(0x00).take(padding_len))
            .collect();

        Some(annex_bytes)
    }
}
192
193impl fmt::Display for Cost {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        fmt::Display::fmt(&self.0, f)
196    }
197}
198
199impl std::ops::Add for Cost {
200    type Output = Self;
201
202    fn add(self, rhs: Self) -> Self::Output {
203        Cost(self.0.saturating_add(rhs.0))
204    }
205}
206
#[cfg(feature = "bitcoin")]
impl From<U32Weight> for Cost {
    /// One weight unit is 1000 milli weight units; the product saturates at `u32::MAX`.
    fn from(value: U32Weight) -> Self {
        let U32Weight(wu) = value;
        Cost(wu.saturating_mul(1000))
    }
}

#[cfg(feature = "bitcoin")]
impl From<Cost> for U32Weight {
    /// Round a cost up to whole weight units.
    fn from(value: Cost) -> Self {
        // Adding 999 saturates rather than overflowing, so costs within 999 of
        // `u32::MAX` round slightly differently. All such costs are strictly
        // larger than `Cost::CONSENSUS_MAX`, so the difference is irrelevant.
        let Cost(milliweight) = value;
        U32Weight(milliweight.saturating_add(999) / 1000)
    }
}

#[cfg(feature = "bitcoin")]
impl From<bitcoin::Weight> for Cost {
    /// Clamp the weight into `u32` range, then scale it to milli weight units.
    fn from(value: bitcoin::Weight) -> Self {
        Cost::from(U32Weight::from(value))
    }
}

#[cfg(feature = "bitcoin")]
impl From<Cost> for bitcoin::Weight {
    /// Round the cost up to whole weight units and widen it into a [`bitcoin::Weight`].
    fn from(value: Cost) -> Self {
        bitcoin::Weight::from(U32Weight::from(value))
    }
}
237
238/// Bounds on the resources required by a node during execution on the Bit Machine
239#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
240#[cfg_attr(feature = "serde", derive(Serialize))]
241pub struct NodeBounds {
242    /// Upper bound on the required number of cells (bits).
243    /// The root additionally requires the bit width of its source and target type (input, output)
244    pub extra_cells: usize,
245    /// Upper bound on the required number of frames (sum of read and write frames).
246    /// The root additionally requires two frames (input, output)
247    pub extra_frames: usize,
248    /// CPU cost
249    pub cost: Cost,
250}
251
252impl NodeBounds {
253    const NOP: Self = NodeBounds {
254        extra_cells: 0,
255        extra_frames: 0,
256        cost: Cost::OVERHEAD,
257    };
258    const NEVER_EXECUTED: Self = NodeBounds {
259        extra_cells: 0,
260        extra_frames: 0,
261        cost: Cost::NEVER_EXECUTED,
262    };
263
264    fn from_child(child: Self) -> Self {
265        NodeBounds {
266            extra_cells: child.extra_cells,
267            extra_frames: child.extra_frames,
268            cost: Cost::OVERHEAD + child.cost,
269        }
270    }
271
272    /// Node bounds for an `iden` node
273    pub fn iden(target_type: usize) -> NodeBounds {
274        NodeBounds {
275            extra_cells: 0,
276            extra_frames: 0,
277            cost: Cost::OVERHEAD + Cost::of_type(target_type),
278        }
279    }
280
281    /// Node bounds for a `unit` node
282    pub const fn unit() -> NodeBounds {
283        NodeBounds::NOP
284    }
285
286    /// Node bounds for an `injl` node
287    pub fn injl(child: Self) -> NodeBounds {
288        Self::from_child(child)
289    }
290
291    /// Node bounds for an `injr` node
292    pub fn injr(child: Self) -> NodeBounds {
293        Self::from_child(child)
294    }
295
296    /// Node bounds for a `take` node
297    pub fn take(child: Self) -> NodeBounds {
298        Self::from_child(child)
299    }
300
301    /// Node bounds for a `drop` node
302    pub fn drop(child: Self) -> NodeBounds {
303        Self::from_child(child)
304    }
305
306    /// Node bounds for a `comp` node
307    pub fn comp(left: Self, right: Self, mid_ty_bit_width: usize) -> NodeBounds {
308        NodeBounds {
309            extra_cells: mid_ty_bit_width + cmp::max(left.extra_cells, right.extra_cells),
310            extra_frames: 1 + cmp::max(left.extra_frames, right.extra_frames),
311            cost: Cost::OVERHEAD + Cost::of_type(mid_ty_bit_width) + left.cost + right.cost,
312        }
313    }
314
315    /// Node bounds for a `case` node
316    pub fn case(left: Self, right: Self) -> NodeBounds {
317        NodeBounds {
318            extra_cells: cmp::max(left.extra_cells, right.extra_cells),
319            extra_frames: cmp::max(left.extra_frames, right.extra_frames),
320            cost: Cost::OVERHEAD + cmp::max(left.cost, right.cost),
321        }
322    }
323
324    /// Node bounds for a `assertl` node
325    pub fn assertl(child: Self) -> NodeBounds {
326        Self::from_child(child)
327    }
328
329    /// Node bounds for a `assertr` node
330    pub fn assertr(child: Self) -> NodeBounds {
331        Self::from_child(child)
332    }
333
334    /// Node bounds for a `pair` node
335    pub fn pair(left: Self, right: Self) -> NodeBounds {
336        NodeBounds {
337            extra_cells: cmp::max(left.extra_cells, right.extra_cells),
338            extra_frames: cmp::max(left.extra_frames, right.extra_frames),
339            cost: Cost::OVERHEAD + left.cost + right.cost,
340        }
341    }
342
343    // disconnect, jet, witness, word
344    /// Node bounds for a `disconnect` node
345    pub fn disconnect(
346        left: Self,
347        right: Self,
348        left_target_b_bit_width: usize, // bit width of B in (b x C) target type
349        left_source_bit_width: usize,
350        left_target_bit_width: usize,
351    ) -> NodeBounds {
352        NodeBounds {
353            extra_cells: left_source_bit_width
354                + left_target_bit_width
355                + cmp::max(left.extra_cells, right.extra_cells),
356            extra_frames: 2 + cmp::max(left.extra_frames, right.extra_frames),
357            cost: Cost::OVERHEAD
358                + Cost::of_type(left_source_bit_width)
359                + Cost::of_type(left_source_bit_width)
360                + Cost::of_type(left_target_bit_width)
361                + Cost::of_type(left_target_b_bit_width)
362                + left.cost
363                + right.cost,
364        }
365    }
366
367    /// Node bounds for an arbitrary jet node
368    pub fn witness(target_ty_bit_width: usize) -> NodeBounds {
369        NodeBounds {
370            extra_cells: target_ty_bit_width,
371            extra_frames: 0,
372            cost: Cost::OVERHEAD + Cost::of_type(target_ty_bit_width),
373        }
374    }
375
376    /// Node bounds for an arbitrary jet node
377    pub fn jet(jet: &dyn Jet) -> NodeBounds {
378        NodeBounds {
379            extra_cells: 0,
380            extra_frames: 0,
381            cost: Cost::OVERHEAD + jet.cost(),
382        }
383    }
384
385    /// Node bounds for an arbitrary constant word node
386    pub fn const_word(word: &Word) -> NodeBounds {
387        NodeBounds {
388            extra_cells: 0,
389            extra_frames: 0,
390            cost: Cost::OVERHEAD + Cost::of_type(word.len()),
391        }
392    }
393
394    /// Node bounds for a `fail` node.
395    ///
396    /// This is a bit of a silly constructor because if a `fail` node is actually
397    /// executed in the bit machine, it will fail instantly, while if it *isn't*
398    /// executed, it will fail the "no unexecuted nodes" check. But to analyze
399    /// arbitrary programs, we need it.
400    pub const fn fail() -> NodeBounds {
401        NodeBounds::NEVER_EXECUTED
402    }
403}
404
/// Frames needed at the root of a Simplicity expression in addition to
/// [`NodeBounds::extra_frames`]: one for the input and one for the output.
pub(crate) const IO_EXTRA_FRAMES: usize = 2;
407
#[cfg(test)]
mod tests {
    use super::*;
    use simplicity_sys::ffi::bounded::cost_overhead;

    #[test]
    fn test_overhead() {
        // Check that C overhead is same OVERHEAD
        assert_eq!(Cost::OVERHEAD.0, cost_overhead());
    }

    #[test]
    #[cfg(feature = "bitcoin")]
    fn cost_to_weight() {
        let test_vectors = vec![
            (Cost::NEVER_EXECUTED, 0),
            (Cost::from_milliweight(1), 1),
            (Cost::from_milliweight(999), 1),
            (Cost::from_milliweight(1_000), 1),
            (Cost::from_milliweight(1_001), 2),
            (Cost::from_milliweight(1_999), 2),
            (Cost::from_milliweight(2_000), 2),
            (Cost::CONSENSUS_MAX, 4_000_050),
        ];

        for (cost, expected_weight) in test_vectors {
            let converted_cost = U32Weight::from(cost);
            let expected_weight = U32Weight(expected_weight);
            assert_eq!(converted_cost, expected_weight);
        }
    }

    #[test]
    #[cfg(feature = "elements")]
    fn test_get_padding() {
        // The budget of the empty witness stack is 51 WU:
        //
        // 1. 50 WU of free signature operations
        // 2. 1 WU for the length byte of the witness stack
        let empty = 51_000;

        // The annex is prefixed by its CompactSize length (1, 3 or 5 bytes),
        // which also counts towards the budget, so the expected annex length
        // is the deficit minus the size of that prefix.
        let test_vectors = vec![
            (Cost::from_milliweight(0), vec![], None),
            (Cost::from_milliweight(empty), vec![], None),
            (Cost::from_milliweight(empty + 1), vec![], Some(1)),
            (Cost::from_milliweight(empty + 2_000), vec![], Some(1)),
            (Cost::from_milliweight(empty + 2_001), vec![], Some(2)),
            (Cost::from_milliweight(empty + 3_000), vec![], Some(2)),
            (Cost::from_milliweight(empty + 3_001), vec![], Some(3)),
            (Cost::from_milliweight(empty + 4_000), vec![], Some(3)),
            (Cost::from_milliweight(empty + 4_001), vec![], Some(4)),
            (Cost::from_milliweight(empty + 50_000), vec![], Some(49)),
            // Test around CompactSize boundary (annex_len crossing 252 -> 253)
            // deficit = 253: annex_len = 252, compact_size = 1 byte, overhead = 2
            (Cost::from_milliweight(empty + 253_000), vec![], Some(252)),
            // deficit = 254: annex_len must be 253 (3-byte compact_size), overhead = 4
            (Cost::from_milliweight(empty + 254_000), vec![], Some(253)),
            // deficit = 255: same boundary case
            (Cost::from_milliweight(empty + 255_000), vec![], Some(253)),
            // deficit = 256: annex_len = 253, compact_size = 3, exact fit
            (Cost::from_milliweight(empty + 256_000), vec![], Some(253)),
            // deficit = 257: annex_len = 254
            (Cost::from_milliweight(empty + 257_000), vec![], Some(254)),
            // Large annex (exercises the 3-byte compact_size path)
            (
                Cost::from_milliweight(empty + 7_424_000),
                vec![],
                Some(7_421),
            ),
            // Test around CompactSize boundary (annex_len crossing 65_535 -> 65_536)
            // deficit = 65_538: annex_len = 65_535, compact_size = 3 bytes, exact fit
            (
                Cost::from_milliweight(empty + 65_538_000),
                vec![],
                Some(65_535),
            ),
            // deficit = 65_539: annex_len must be 65_536 (5-byte compact_size), overhead = 6
            (
                Cost::from_milliweight(empty + 65_539_000),
                vec![],
                Some(65_536),
            ),
            // deficit = 65_540: same boundary case
            (
                Cost::from_milliweight(empty + 65_540_000),
                vec![],
                Some(65_536),
            ),
            // deficit = 65_541: annex_len = 65_536, compact_size = 5 bytes, exact fit
            (
                Cost::from_milliweight(empty + 65_541_000),
                vec![],
                Some(65_536),
            ),
            // deficit = 65_542: annex_len = 65_537
            (
                Cost::from_milliweight(empty + 65_542_000),
                vec![],
                Some(65_537),
            ),
            // Hash loop example
            (
                Cost::from_milliweight(8_045_103),
                vec![vec![], vec![0; 497], vec![0; 32], vec![0; 33]],
                Some(7_424),
            ),
            // Max
            (Cost::CONSENSUS_MAX, vec![], Some(3_999_994)),
        ];

        for (cost, mut witness, maybe_padding) in test_vectors {
            match maybe_padding {
                None => {
                    assert!(cost.is_budget_valid(&witness));
                    assert!(cost.get_padding(&witness).is_none());
                }
                Some(expected_annex_len) => {
                    assert!(!cost.is_budget_valid(&witness));

                    let annex_bytes = cost.get_padding(&witness).expect("not enough budget");
                    assert_eq!(expected_annex_len, annex_bytes.len());
                    witness.extend(std::iter::once(annex_bytes));
                    assert!(cost.is_budget_valid(&witness));

                    witness.pop();
                    assert!(!cost.is_budget_valid(&witness), "Padding must be minimal");
                }
            }
        }
    }
}
508}