chia_protocol/
program.rs

1use crate::bytes::Bytes;
2#[cfg(feature = "py-bindings")]
3use crate::LazyNode;
4use chia_sha2::Sha256;
5use chia_traits::chia_error::{Error, Result};
6use chia_traits::Streamable;
7use clvm_traits::{FromClvm, FromClvmError, ToClvm, ToClvmError};
8use clvmr::allocator::NodePtr;
9use clvmr::cost::Cost;
10use clvmr::reduction::EvalErr;
11use clvmr::run_program;
12use clvmr::serde::{
13    node_from_bytes, node_from_bytes_backrefs, node_to_bytes, serialized_length_from_bytes,
14    serialized_length_from_bytes_trusted,
15};
16#[cfg(feature = "py-bindings")]
17use clvmr::SExp;
18use clvmr::{Allocator, ChiaDialect};
19#[cfg(feature = "py-bindings")]
20use pyo3::prelude::*;
21#[cfg(feature = "py-bindings")]
22use pyo3::types::PyType;
23use std::io::Cursor;
24use std::ops::Deref;
25#[cfg(feature = "py-bindings")]
26use std::rc::Rc;
27
28#[cfg(feature = "py-bindings")]
29use clvm_utils::CurriedProgram;
30
31#[cfg_attr(feature = "py-bindings", pyclass(subclass), derive(PyStreamable))]
32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
33#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
34pub struct Program(Bytes);
35
36impl Default for Program {
37    fn default() -> Self {
38        Self(vec![0x80].into())
39    }
40}
41
42impl Program {
43    pub fn new(bytes: Bytes) -> Self {
44        Self(bytes)
45    }
46
47    pub fn len(&self) -> usize {
48        self.0.len()
49    }
50
51    pub fn is_empty(&self) -> bool {
52        self.0.is_empty()
53    }
54
55    pub fn as_slice(&self) -> &[u8] {
56        self.0.as_slice()
57    }
58
59    pub fn to_vec(&self) -> Vec<u8> {
60        self.0.to_vec()
61    }
62
63    pub fn into_inner(self) -> Bytes {
64        self.0
65    }
66
67    pub fn into_bytes(self) -> Vec<u8> {
68        self.0.into_inner()
69    }
70
71    pub fn run<A: ToClvm<Allocator>>(
72        &self,
73        a: &mut Allocator,
74        flags: u32,
75        max_cost: Cost,
76        arg: &A,
77    ) -> std::result::Result<(Cost, NodePtr), EvalErr> {
78        let arg = arg.to_clvm(a).map_err(|_| {
79            EvalErr(
80                a.nil(),
81                "failed to convert argument to CLVM objects".to_string(),
82            )
83        })?;
84        let program =
85            node_from_bytes_backrefs(a, self.0.as_ref()).expect("invalid SerializedProgram");
86        let dialect = ChiaDialect::new(flags);
87        let reduction = run_program(a, &dialect, program, arg, max_cost)?;
88        Ok((reduction.0, reduction.1))
89    }
90}
91
92impl From<Bytes> for Program {
93    fn from(value: Bytes) -> Self {
94        Self(value)
95    }
96}
97
98impl From<Program> for Bytes {
99    fn from(value: Program) -> Self {
100        value.0
101    }
102}
103
104impl From<Vec<u8>> for Program {
105    fn from(value: Vec<u8>) -> Self {
106        Self(Bytes::new(value))
107    }
108}
109
110impl From<&[u8]> for Program {
111    fn from(value: &[u8]) -> Self {
112        Self(value.into())
113    }
114}
115
116impl From<Program> for Vec<u8> {
117    fn from(value: Program) -> Self {
118        value.0.into()
119    }
120}
121
122impl AsRef<[u8]> for Program {
123    fn as_ref(&self) -> &[u8] {
124        self.0.as_ref()
125    }
126}
127
128impl Deref for Program {
129    type Target = [u8];
130
131    fn deref(&self) -> &[u8] {
132        &self.0
133    }
134}
135
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Program {
    /// Generates the serialization of an arbitrary CLVM tree. The result is a
    /// structurally complete tree, but not likely a valid program.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the `Unstructured` source instead of
    /// panicking on it.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Number of tree items that still have to be emitted; we start with
        // the root.
        let mut items_left = 1;
        // Total items emitted so far. Once 100 have been emitted no further
        // pairs are created, so the loop terminates.
        let mut total_items = 0;
        let mut buf = Vec::<u8>::with_capacity(200);

        while items_left > 0 {
            if total_items < 100 && u.ratio(1, 4)? {
                // make a pair: the marker byte is followed by two more items
                buf.push(0xff);
                items_left += 2;
            } else {
                // make an atom. just single bytes for now
                buf.push(u.int_in_range(0..=0x80)?);
            }
            total_items += 1;
            items_left -= 1;
        }
        Ok(Self(buf.into()))
    }
}
159
160#[cfg(feature = "py-bindings")]
161use chia_traits::{FromJsonDict, ToJsonDict};
162
163#[cfg(feature = "py-bindings")]
164use chia_py_streamable_macro::PyStreamable;
165
166#[cfg(feature = "py-bindings")]
167use pyo3::types::{PyList, PyTuple};
168
169#[cfg(feature = "py-bindings")]
170use pyo3::exceptions::*;
171
// TODO: this conversion function should probably be replaced by a type that
// holds the PyAny object and implements the ToClvm trait. That way, the
// Program::to() function could turn a python structure directly into bytes,
// without taking the detour via Allocator. Propagating python errors through
// ToClvmError is a bit tricky though.
//
// Converts a python object into a CLVM node allocated in `a`. The checks run
// in a fixed order and the first one that matches wins: None, bytes, str,
// int, 2-tuple, list, SExp-like object (has `atom` and `pair` attributes),
// Program, and finally anything exposing `__bytes__`.
#[cfg(feature = "py-bindings")]
fn clvm_convert(a: &mut Allocator, o: &Bound<'_, PyAny>) -> PyResult<NodePtr> {
    // Errors reported by the allocator are surfaced to python as MemoryError.
    fn alloc_failed<E: ToString>(error: E) -> PyErr {
        PyMemoryError::new_err(error.to_string())
    }

    // None
    if o.is_none() {
        return Ok(a.nil());
    }

    // bytes
    if let Ok(buffer) = o.extract::<&[u8]>() {
        return a.new_atom(buffer).map_err(alloc_failed);
    }

    // str
    if let Ok(text) = o.extract::<String>() {
        return a.new_atom(text.as_bytes()).map_err(alloc_failed);
    }

    // int
    if let Ok(val) = o.extract::<clvmr::number::Number>() {
        return a.new_number(val).map_err(alloc_failed);
    }

    // Tuple (SExp-like): only 2-tuples are accepted, and become a pair
    if let Ok(pair) = o.downcast::<PyTuple>() {
        if pair.len() != 2 {
            return Err(PyValueError::new_err(format!(
                "can't cast tuple of size {}",
                pair.len()
            )));
        }
        let left = clvm_convert(a, &pair.get_item(0)?)?;
        let right = clvm_convert(a, &pair.get_item(1)?)?;
        return a.new_pair(left, right).map_err(alloc_failed);
    }

    // List: snapshot the items, then build the nil-terminated chain of pairs
    // from the back
    if let Ok(list) = o.downcast::<PyList>() {
        let items: Vec<_> = list.iter().collect();
        let mut tail = a.nil();
        for element in items.into_iter().rev() {
            let head = clvm_convert(a, &element)?;
            tail = a.new_pair(head, tail).map_err(alloc_failed)?;
        }
        return Ok(tail);
    }

    // SExp (such as clvm.SExp)
    if let (Ok(atom), Ok(pair)) = (o.getattr("atom"), o.getattr("pair")) {
        if !atom.is_none() {
            return a
                .new_atom(atom.extract::<&[u8]>()?)
                .map_err(alloc_failed);
        }
        if pair.is_none() {
            return Err(PyTypeError::new_err(format!("invalid SExp item {o}")));
        }
        let pair = pair.downcast::<PyTuple>()?;
        let left = clvm_convert(a, &pair.get_item(0)?)?;
        let right = clvm_convert(a, &pair.get_item(1)?)?;
        return a.new_pair(left, right).map_err(alloc_failed);
    }

    // Program itself. This is interpreted as a program in serialized form, and
    // just a buffer of that serialization. This is an optimization over
    // finding __bytes__() and calling it
    if let Ok(prg) = o.extract::<Program>() {
        return a.new_atom(prg.0.as_slice()).map_err(alloc_failed);
    }

    // anything convertible to bytes
    if let Ok(fun) = o.getattr("__bytes__") {
        let bytes = fun.call0()?;
        let buffer = bytes.extract::<&[u8]>()?;
        return a.new_atom(buffer).map_err(alloc_failed);
    }

    Err(PyTypeError::new_err(format!(
        "unknown parameter to run_with_cost() {o}"
    )))
}
255
#[cfg(feature = "py-bindings")]
fn clvm_serialize(a: &mut Allocator, o: &Bound<'_, PyAny>) -> PyResult<NodePtr> {
    /*
    Arguments passed to run() get some special treatment before we fall back
    on the regular python -> CLVM conversion (clvm_convert above). This
    function mimics the _serialize() function in python:

       def _serialize(node: object) -> bytes:
           if isinstance(node, list):
               serialized_list = bytearray()
               for a in node:
                   serialized_list += b"\xff"
                   serialized_list += _serialize(a)
               serialized_list += b"\x80"
               return bytes(serialized_list)
           if type(node) is SerializedProgram:
               return bytes(node)
           if type(node) is Program:
               return bytes(node)
           else:
               ret: bytes = SExp.to(node).as_bin()
               return ret
    */

    // List: every element is serialized recursively with this same function,
    // and the chain of pairs is built from the back, ending in nil
    if let Ok(list) = o.downcast::<PyList>() {
        let items: Vec<_> = list.iter().collect();
        let mut tail = a.nil();
        for element in items.into_iter().rev() {
            let head = clvm_serialize(a, &element)?;
            tail = a
                .new_pair(head, tail)
                .map_err(|e| PyMemoryError::new_err(e.to_string()))?;
        }
        return Ok(tail);
    }

    // Program itself: deserialized into the CLVM structure it represents
    if let Ok(prg) = o.extract::<Program>() {
        let node = node_from_bytes_backrefs(a, prg.0.as_slice())?;
        return Ok(node);
    }

    // everything else goes through the generic conversion
    clvm_convert(a, o)
}
301
#[cfg(feature = "py-bindings")]
#[allow(clippy::needless_pass_by_value)]
#[pymethods]
impl Program {
    // Python: Program.default()
    #[pyo3(name = "default")]
    #[staticmethod]
    fn py_default() -> Self {
        Program::default()
    }

    // Python: Program.to(obj). Converts a python object to CLVM (see
    // clvm_convert) and serializes the resulting tree.
    #[staticmethod]
    #[pyo3(name = "to")]
    fn py_to(args: &Bound<'_, PyAny>) -> PyResult<Program> {
        let mut allocator = Allocator::new_limited(500_000_000);
        let node = clvm_convert(&mut allocator, args)?;
        match Program::from_clvm(&allocator, node) {
            Ok(program) => Ok(program),
            Err(error) => Err(PyTypeError::new_err(error.to_string())),
        }
    }

    // Tree hash computed directly from the serialized bytes. Panics if the
    // hash cannot be computed from the stored bytes.
    fn get_tree_hash(&self) -> crate::Bytes32 {
        let hash = clvm_utils::tree_hash_from_bytes(self.0.as_ref()).unwrap();
        hash.into()
    }

    // Parses a program from a hex string; a leading "0x" is optional.
    #[staticmethod]
    fn fromhex(h: String) -> Result<Self> {
        let digits = match h.strip_prefix("0x") {
            Some(stripped) => stripped,
            None => h.as_str(),
        };
        let raw = hex::decode(digits).map_err(|_| Error::InvalidString)?;
        Self::from_bytes(raw.as_slice())
    }

    // Runs this program with `args` as its environment. On success returns
    // the cost and a lazy view of the result; on failure raises ValueError
    // carrying the error message and, when it can be serialized, the hex
    // encoding of the node the error refers to.
    fn run_rust(
        &self,
        py: Python<'_>,
        max_cost: u64,
        flags: u32,
        args: &Bound<'_, PyAny>,
    ) -> PyResult<(u64, LazyNode)> {
        use clvmr::reduction::Response;

        let mut allocator = Allocator::new_limited(500_000_000);
        // The python behavior here is a bit messy, and is best not emulated
        // on the rust side. We must be able to pass a Program as an argument
        // and have it treated as the CLVM structure it represents. In
        // python's SerializedProgram there is a hack where the first "layer"
        // of SerializedProgram, or lists of SerializedProgram, is interpreted
        // this way. But when an Optional or tuple is encountered, we defer to
        // the clvm wheel's conversion function to SExp. That level has no
        // special treatment for SerializedProgram (as that would cause a
        // circular dependency).
        let clvm_args = clvm_serialize(&mut allocator, args)?;

        let program = node_from_bytes_backrefs(&mut allocator, self.0.as_ref())?;
        let dialect = ChiaDialect::new(flags);
        // the program runs inside `allow_threads`
        let outcome: Response = py.allow_threads(|| {
            run_program(&mut allocator, &dialect, program, clvm_args, max_cost)
        });

        match outcome {
            Ok(reduction) => {
                let result = LazyNode::new(Rc::new(allocator), reduction.1);
                Ok((reduction.0, result))
            }
            Err(EvalErr(node, message)) => {
                let blob = node_to_bytes(&allocator, node).ok().map(hex::encode);
                Err(PyValueError::new_err((message, blob)))
            }
        }
    }

    // Splits a curried program into (inner program, list of curried
    // arguments). If the program does not have the curried form, the program
    // itself is returned together with nil.
    fn uncurry_rust(&self) -> PyResult<(LazyNode, LazyNode)> {
        let mut allocator = Allocator::new_limited(500_000_000);
        let root = node_from_bytes_backrefs(&mut allocator, self.0.as_ref())?;
        let uncurried = match CurriedProgram::<NodePtr, NodePtr>::from_clvm(&allocator, root) {
            Ok(curried) => curried,
            Err(_) => {
                let shared = Rc::new(allocator);
                let whole = LazyNode::new(Rc::clone(&shared), root);
                let nil = shared.nil();
                return Ok((whole, LazyNode::new(shared, nil)));
            }
        };

        let mut curried_args = Vec::<NodePtr>::new();
        let mut cursor = uncurried.args;
        // walk the argument structure until an atom terminates it
        while !matches!(allocator.sexp(cursor), SExp::Atom) {
            // the args of curried puzzles are in the form of:
            // (c . ((q . <arg1>) . (<rest> . ())))
            let (_, ((_, arg), (rest, ()))) =
                <(
                    clvm_traits::MatchByte<4>,
                    (clvm_traits::match_quote!(NodePtr), (NodePtr, ())),
                ) as FromClvm<Allocator>>::from_clvm(&allocator, cursor)
                .map_err(|error| PyTypeError::new_err(error.to_string()))?;
            curried_args.push(arg);
            cursor = rest;
        }

        // rebuild the collected arguments as a nil-terminated list, from the
        // back
        let arg_list = curried_args
            .into_iter()
            .rev()
            .try_fold(allocator.nil(), |tail, item| {
                allocator
                    .new_pair(item, tail)
                    .map_err(|_e| Error::EndOfBuffer)
            })?;

        let shared = Rc::new(allocator);
        let inner = LazyNode::new(Rc::clone(&shared), uncurried.program);
        Ok((inner, LazyNode::new(shared, arg_list)))
    }
}
414
415impl Streamable for Program {
416    fn update_digest(&self, digest: &mut Sha256) {
417        digest.update(&self.0);
418    }
419
420    fn stream(&self, out: &mut Vec<u8>) -> Result<()> {
421        out.extend_from_slice(self.0.as_ref());
422        Ok(())
423    }
424
425    fn parse<const TRUSTED: bool>(input: &mut Cursor<&[u8]>) -> Result<Self> {
426        let pos = input.position();
427        let buf: &[u8] = &input.get_ref()[pos as usize..];
428        let len = if TRUSTED {
429            serialized_length_from_bytes_trusted(buf).map_err(|_e| Error::EndOfBuffer)?
430        } else {
431            serialized_length_from_bytes(buf).map_err(|_e| Error::EndOfBuffer)?
432        };
433        if buf.len() < len as usize {
434            return Err(Error::EndOfBuffer);
435        }
436        let program = buf[..len as usize].to_vec();
437        input.set_position(pos + len);
438        Ok(Program(program.into()))
439    }
440}
441
#[cfg(feature = "py-bindings")]
impl ToJsonDict for Program {
    /// Delegates to the JSON representation of the wrapped `Bytes`.
    fn to_json_dict(&self, py: Python<'_>) -> PyResult<PyObject> {
        let Program(inner) = self;
        inner.to_json_dict(py)
    }
}
448
#[cfg(feature = "py-bindings")]
#[pymethods]
impl Program {
    // Always raises NotImplementedError: this class does not support
    // from_parent().
    #[classmethod]
    #[pyo3(name = "from_parent")]
    pub fn from_parent(_cls: &Bound<'_, PyType>, _instance: &Self) -> PyResult<PyObject> {
        let message = "This class does not support from_parent().";
        Err(PyNotImplementedError::new_err(message))
    }
}
460
#[cfg(feature = "py-bindings")]
impl FromJsonDict for Program {
    /// Reads the bytes from the JSON value and accepts them only when they
    /// are exactly one CLVM serialization.
    fn from_json_dict(o: &Bound<'_, PyAny>) -> PyResult<Self> {
        let payload = Bytes::from_json_dict(o)?;
        let serialized_len =
            serialized_length_from_bytes(payload.as_slice()).map_err(|_e| Error::EndOfBuffer)?;
        if serialized_len as usize == payload.len() {
            Ok(Program(payload))
        } else {
            // If the bytes in the JSON string are not a valid CLVM
            // serialization, or if there is garbage at the end of the
            // string, reject it
            Err(Error::InvalidClvm.into())
        }
    }
}
476
477impl FromClvm<Allocator> for Program {
478    fn from_clvm(a: &Allocator, node: NodePtr) -> std::result::Result<Self, FromClvmError> {
479        Ok(Self(
480            node_to_bytes(a, node)
481                .map_err(|error| FromClvmError::Custom(error.to_string()))?
482                .into(),
483        ))
484    }
485}
486
487impl ToClvm<Allocator> for Program {
488    fn to_clvm(&self, a: &mut Allocator) -> std::result::Result<NodePtr, ToClvmError> {
489        node_from_bytes(a, self.0.as_ref()).map_err(|error| ToClvmError::Custom(error.to_string()))
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Node -> Program -> node must reproduce the original serialization.
    #[test]
    fn program_roundtrip() {
        let mut allocator = Allocator::new();
        let expected = "ff01ff02ff62ff0480";
        let expected_bytes = hex::decode(expected).unwrap();

        let ptr = node_from_bytes(&mut allocator, &expected_bytes).unwrap();
        let program = Program::from_clvm(&allocator, ptr).unwrap();

        let round_trip = program.to_clvm(&mut allocator).unwrap();
        let reserialized = node_to_bytes(&allocator, round_trip).unwrap();
        assert_eq!(expected, hex::encode(reserialized));
    }

    /// Running `(+ 2 5)` on the argument list `(1300 37)` yields 1337.
    #[test]
    fn program_run() {
        let mut allocator = Allocator::new();

        // (+ 2 5)
        let serialized = hex::decode("ff10ff02ff0580").expect("hex::decode");
        let prg = Program::from_bytes(&serialized).expect("from_bytes");
        let (cost, result) = prg
            .run(&mut allocator, 0, 1000, &[1300, 37])
            .expect("run");
        assert_eq!(cost, 869);
        assert_eq!(allocator.number(result), 1337.into());
    }
}
522
#[cfg(all(test, feature = "serde"))]
mod serde_tests {
    use super::*;

    /// A `Program` must serialize to exactly the same JSON as the `Bytes`
    /// it wraps.
    #[test]
    fn test_program_is_bytes() -> anyhow::Result<()> {
        let raw = Bytes::new(vec![1, 2, 3]);
        let bytes_json = serde_json::to_string(&raw)?;

        let wrapped = Program::new(raw);
        let program_json = serde_json::to_string(&wrapped)?;

        assert_eq!(program_json, bytes_json);

        Ok(())
    }
}
539}