1use crate::jet::Jet;
4use std::{cmp, fmt};
5
6use crate::value::Word;
7#[cfg(feature = "elements")]
8use elements::encode::Encodable;
9#[cfg(feature = "elements")]
10use std::{convert::TryFrom, io};
11
/// A weight measured in whole weight units (WU), stored in a `u32`.
///
/// Converting from a wider weight clamps to `u32::MAX`, and subtraction clamps at zero,
/// so values of this type never wrap around.
#[cfg(feature = "bitcoin")]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct U32Weight(u32);
19
#[cfg(feature = "bitcoin")]
impl std::ops::Sub for U32Weight {
    type Output = Self;

    /// Subtract one weight from another, clamping the result at zero instead of underflowing.
    fn sub(self, rhs: Self) -> Self::Output {
        let (U32Weight(minuend), U32Weight(subtrahend)) = (self, rhs);
        U32Weight(u32::saturating_sub(minuend, subtrahend))
    }
}
28
#[cfg(feature = "bitcoin")]
impl From<bitcoin::Weight> for U32Weight {
    /// Narrow a 64-bit weight to 32 bits, clamping anything too large to `u32::MAX`.
    fn from(value: bitcoin::Weight) -> Self {
        let wu = value.to_wu();
        match u32::try_from(wu) {
            Ok(fits) => U32Weight(fits),
            Err(_) => U32Weight(u32::MAX),
        }
    }
}
35
#[cfg(feature = "bitcoin")]
impl From<U32Weight> for bitcoin::Weight {
    /// Widen a 32-bit weight into a `bitcoin::Weight`; this conversion is lossless.
    fn from(value: U32Weight) -> Self {
        let U32Weight(wu) = value;
        Self::from_wu(wu.into())
    }
}
42
/// The cost of executing a Simplicity expression, measured in milliweight
/// (one thousandth of a weight unit).
///
/// Arithmetic on costs saturates at `u32::MAX` rather than wrapping around.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Cost(u32);

impl Cost {
    /// Fixed cost charged for every executed node.
    ///
    /// This must agree with the value exported by the C library (checked by `test_overhead`).
    const OVERHEAD: Self = Cost(100);

    /// Cost assigned to nodes that are never executed.
    const NEVER_EXECUTED: Self = Cost(0);

    /// Largest cost accepted by consensus, in milliweight.
    ///
    /// NOTE(review): 4_000_050 WU presumably corresponds to the maximum block weight plus
    /// the fixed 50 WU budget allowance — confirm against the consensus specification.
    pub const CONSENSUS_MAX: Self = Cost(4_000_050_000);

    /// Return the cost associated with a type of the given bit width (one milliweight per bit).
    ///
    /// Widths that do not fit in a `u32` saturate to `u32::MAX` instead of being truncated,
    /// so an enormous type can never be mistaken for a cheap one.
    pub const fn of_type(bit_width: usize) -> Self {
        // `usize -> u64` is lossless on every platform Rust supports, so this comparison is exact.
        if bit_width as u64 > u32::MAX as u64 {
            Cost(u32::MAX)
        } else {
            // cast safety: checked above that `bit_width` fits in a `u32`
            Cost(bit_width as u32)
        }
    }

    /// Create a cost from a raw milliweight value.
    pub const fn from_milliweight(milliweight: u32) -> Self {
        Cost(milliweight)
    }

    /// Return whether this cost does not exceed [`Cost::CONSENSUS_MAX`].
    pub fn is_consensus_valid(self) -> bool {
        self <= Self::CONSENSUS_MAX
    }

    /// Return the weight budget granted by `script_witness`: the length of its consensus
    /// serialization in bytes, plus 50 weight units.
    ///
    /// # Panics
    ///
    /// Panics if the serialized witness stack is 2^32 bytes or longer.
    #[cfg(feature = "elements")]
    fn get_budget(script_witness: &Vec<Vec<u8>>) -> U32Weight {
        let mut sink = io::sink();
        let witness_stack_serialized_len = script_witness
            .consensus_encode(&mut sink)
            .expect("writing to sink never fails");
        let budget = u32::try_from(witness_stack_serialized_len)
            .expect("Serialized witness stack must be shorter than 2^32 bytes")
            .saturating_add(50);
        U32Weight(budget)
    }

    /// Return whether this cost fits within the budget granted by `script_witness`.
    ///
    /// The comparison happens in milliweight, so no rounding is involved.
    #[cfg(feature = "elements")]
    pub fn is_budget_valid(self, script_witness: &Vec<Vec<u8>>) -> bool {
        let budget = Self::get_budget(script_witness);
        self.0 <= budget.0.saturating_mul(1000)
    }

    /// Return annex bytes that enlarge `script_witness` enough to cover this cost.
    ///
    /// Returns `None` if the cost already fits within the witness budget. Otherwise returns a
    /// byte vector consisting of a leading `0x50` byte followed by zero padding. Pushing it onto
    /// the witness stack makes the budget cover the cost.
    #[cfg(feature = "elements")]
    pub fn get_padding(self, script_witness: &Vec<Vec<u8>>) -> Option<Vec<u8>> {
        let weight = U32Weight::from(self);
        let budget = Self::get_budget(script_witness);
        if weight <= budget {
            return None;
        }

        // Pushing the annex adds its own bytes plus a length prefix to the serialized witness.
        // The annex already contains the `0x50` byte and its prefix is assumed to be one byte,
        // so `weight - budget - 2` zero bytes are needed. `U32Weight`'s `Sub` saturates at
        // zero, so the annex always contains at least the `0x50` byte.
        //
        // NOTE(review): the one-byte-prefix assumption presumably fails for annexes of 253+
        // bytes or when pushing the annex grows the stack-count prefix. In that case the result
        // still covers the cost but may be a couple of bytes longer than necessary — confirm.
        let required_padding = weight - budget - U32Weight(2);
        let padding_len = required_padding.0 as usize; // cast safety: 32-bit machine or higher
        let annex_bytes: Vec<u8> = std::iter::once(0x50)
            .chain(std::iter::repeat(0x00).take(padding_len))
            .collect();

        Some(annex_bytes)
    }
}
170
171impl fmt::Display for Cost {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 fmt::Display::fmt(&self.0, f)
174 }
175}
176
177impl std::ops::Add for Cost {
178 type Output = Self;
179
180 fn add(self, rhs: Self) -> Self::Output {
181 Cost(self.0.saturating_add(rhs.0))
182 }
183}
184
#[cfg(feature = "bitcoin")]
impl From<U32Weight> for Cost {
    /// Convert whole weight units to milliweight, saturating at `u32::MAX`.
    fn from(value: U32Weight) -> Self {
        let U32Weight(wu) = value;
        Cost(u32::saturating_mul(wu, 1000))
    }
}
191
#[cfg(feature = "bitcoin")]
impl From<Cost> for U32Weight {
    /// Convert a cost in milliweight to whole weight units, rounding up.
    ///
    /// Rounding up guarantees that the resulting weight is never smaller than the cost.
    fn from(value: Cost) -> Self {
        // Compute the ceiling as quotient plus a remainder flag. The previous
        // `saturating_add(999) / 1000` form saturated for costs above `u32::MAX - 999` and
        // then silently rounded *down* (e.g. `u32::MAX` gave 4_294_967 instead of 4_294_968).
        // The quotient is at most `u32::MAX / 1000`, so adding 1 cannot overflow.
        let whole = value.0 / 1000;
        let has_remainder = value.0 % 1000 != 0;
        Self(whole + u32::from(has_remainder))
    }
}
201
#[cfg(feature = "bitcoin")]
impl From<bitcoin::Weight> for Cost {
    /// Convert a `bitcoin::Weight` to a cost by first clamping it to a `U32Weight`
    /// and then scaling to milliweight (saturating at `u32::MAX`).
    fn from(value: bitcoin::Weight) -> Self {
        let clamped = U32Weight::from(value);
        Self::from(clamped)
    }
}
208
#[cfg(feature = "bitcoin")]
impl From<Cost> for bitcoin::Weight {
    /// Convert a cost to a `bitcoin::Weight` via `U32Weight`, rounding up to whole weight units.
    fn from(value: Cost) -> Self {
        let rounded: U32Weight = value.into();
        Self::from(rounded)
    }
}
215
216#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
218pub struct NodeBounds {
219 pub extra_cells: usize,
222 pub extra_frames: usize,
225 pub cost: Cost,
227}
228
229impl NodeBounds {
230 const NOP: Self = NodeBounds {
231 extra_cells: 0,
232 extra_frames: 0,
233 cost: Cost::OVERHEAD,
234 };
235 const NEVER_EXECUTED: Self = NodeBounds {
236 extra_cells: 0,
237 extra_frames: 0,
238 cost: Cost::NEVER_EXECUTED,
239 };
240
241 fn from_child(child: Self) -> Self {
242 NodeBounds {
243 extra_cells: child.extra_cells,
244 extra_frames: child.extra_frames,
245 cost: Cost::OVERHEAD + child.cost,
246 }
247 }
248
249 pub fn iden(target_type: usize) -> NodeBounds {
251 NodeBounds {
252 extra_cells: 0,
253 extra_frames: 0,
254 cost: Cost::OVERHEAD + Cost::of_type(target_type),
255 }
256 }
257
258 pub const fn unit() -> NodeBounds {
260 NodeBounds::NOP
261 }
262
263 pub fn injl(child: Self) -> NodeBounds {
265 Self::from_child(child)
266 }
267
268 pub fn injr(child: Self) -> NodeBounds {
270 Self::from_child(child)
271 }
272
273 pub fn take(child: Self) -> NodeBounds {
275 Self::from_child(child)
276 }
277
278 pub fn drop(child: Self) -> NodeBounds {
280 Self::from_child(child)
281 }
282
283 pub fn comp(left: Self, right: Self, mid_ty_bit_width: usize) -> NodeBounds {
285 NodeBounds {
286 extra_cells: mid_ty_bit_width + cmp::max(left.extra_cells, right.extra_cells),
287 extra_frames: 1 + cmp::max(left.extra_frames, right.extra_frames),
288 cost: Cost::OVERHEAD + Cost::of_type(mid_ty_bit_width) + left.cost + right.cost,
289 }
290 }
291
292 pub fn case(left: Self, right: Self) -> NodeBounds {
294 NodeBounds {
295 extra_cells: cmp::max(left.extra_cells, right.extra_cells),
296 extra_frames: cmp::max(left.extra_frames, right.extra_frames),
297 cost: Cost::OVERHEAD + cmp::max(left.cost, right.cost),
298 }
299 }
300
301 pub fn assertl(child: Self) -> NodeBounds {
303 Self::from_child(child)
304 }
305
306 pub fn assertr(child: Self) -> NodeBounds {
308 Self::from_child(child)
309 }
310
311 pub fn pair(left: Self, right: Self) -> NodeBounds {
313 NodeBounds {
314 extra_cells: cmp::max(left.extra_cells, right.extra_cells),
315 extra_frames: cmp::max(left.extra_frames, right.extra_frames),
316 cost: Cost::OVERHEAD + left.cost + right.cost,
317 }
318 }
319
320 pub fn disconnect(
323 left: Self,
324 right: Self,
325 left_target_b_bit_width: usize, left_source_bit_width: usize,
327 left_target_bit_width: usize,
328 ) -> NodeBounds {
329 NodeBounds {
330 extra_cells: left_source_bit_width
331 + left_target_bit_width
332 + cmp::max(left.extra_cells, right.extra_cells),
333 extra_frames: 2 + cmp::max(left.extra_frames, right.extra_frames),
334 cost: Cost::OVERHEAD
335 + Cost::of_type(left_source_bit_width)
336 + Cost::of_type(left_source_bit_width)
337 + Cost::of_type(left_target_bit_width)
338 + Cost::of_type(left_target_b_bit_width)
339 + left.cost
340 + right.cost,
341 }
342 }
343
344 pub fn witness(target_ty_bit_width: usize) -> NodeBounds {
346 NodeBounds {
347 extra_cells: target_ty_bit_width,
348 extra_frames: 0,
349 cost: Cost::OVERHEAD + Cost::of_type(target_ty_bit_width),
350 }
351 }
352
353 pub fn jet<J: Jet>(jet: J) -> NodeBounds {
355 NodeBounds {
356 extra_cells: 0,
357 extra_frames: 0,
358 cost: Cost::OVERHEAD + jet.cost(),
359 }
360 }
361
362 pub fn const_word(word: &Word) -> NodeBounds {
364 NodeBounds {
365 extra_cells: 0,
366 extra_frames: 0,
367 cost: Cost::OVERHEAD + Cost::of_type(word.len()),
368 }
369 }
370
371 pub const fn fail() -> NodeBounds {
378 NodeBounds::NEVER_EXECUTED
379 }
380}
381
/// Extra frames associated with program input/output.
///
/// NOTE(review): judging by the name, this presumably reserves one frame for the program's
/// input and one for its output, on top of the `extra_frames` bound computed by `NodeBounds`.
/// Confirm this at the use site.
pub(crate) const IO_EXTRA_FRAMES: usize = 2;
384
#[cfg(test)]
mod tests {
    use super::*;
    use simplicity_sys::ffi::bounded::cost_overhead;

    #[test]
    fn test_overhead() {
        // The Rust overhead constant must stay in sync with the C library.
        let c_overhead = cost_overhead();
        assert_eq!(Cost::OVERHEAD.0, c_overhead);
    }

    #[test]
    #[cfg(feature = "bitcoin")]
    fn cost_to_weight() {
        // (cost, expected weight in WU): the conversion must round up to whole weight units.
        let cases = [
            (Cost::NEVER_EXECUTED, 0),
            (Cost::from_milliweight(1), 1),
            (Cost::from_milliweight(999), 1),
            (Cost::from_milliweight(1_000), 1),
            (Cost::from_milliweight(1_001), 2),
            (Cost::from_milliweight(1_999), 2),
            (Cost::from_milliweight(2_000), 2),
            (Cost::CONSENSUS_MAX, 4_000_050),
        ];

        for &(cost, weight) in cases.iter() {
            assert_eq!(U32Weight::from(cost), U32Weight(weight));
        }
    }

    #[test]
    #[cfg(feature = "elements")]
    fn test_get_padding() {
        // Milliweight that exactly exhausts the budget of an empty witness stack.
        let empty = 51_000;

        // (cost in milliweight, expected annex length, or `None` if no padding is needed)
        let cases = [
            (0, None),
            (empty, None),
            (empty + 1, Some(1)),
            (empty + 2_000, Some(1)),
            (empty + 2_001, Some(2)),
            (empty + 3_000, Some(2)),
            (empty + 3_001, Some(3)),
            (empty + 4_000, Some(3)),
            (empty + 4_001, Some(4)),
            (empty + 50_000, Some(49)),
        ];

        for &(milliweight, expected_annex_len) in cases.iter() {
            let cost = Cost::from_milliweight(milliweight);
            let mut witness: Vec<Vec<u8>> = Vec::new();

            if let Some(expected_len) = expected_annex_len {
                assert!(!cost.is_budget_valid(&witness));

                let annex = cost.get_padding(&witness).expect("not enough budget");
                assert_eq!(expected_len, annex.len());

                witness.push(annex);
                assert!(cost.is_budget_valid(&witness));

                witness.pop();
                assert!(!cost.is_budget_valid(&witness), "Padding must be minimal");
            } else {
                assert!(cost.is_budget_valid(&witness));
                assert!(cost.get_padding(&witness).is_none());
            }
        }
    }
}