chia_protocol/
program.rs

1use crate::bytes::Bytes;
2#[cfg(feature = "py-bindings")]
3use crate::LazyNode;
4use chia_sha2::Sha256;
5use chia_traits::chia_error::{Error, Result};
6use chia_traits::Streamable;
7use clvm_traits::{FromClvm, FromClvmError, ToClvm, ToClvmError};
8use clvmr::allocator::NodePtr;
9use clvmr::cost::Cost;
10use clvmr::error::EvalErr;
11use clvmr::run_program;
12use clvmr::serde::{
13    node_from_bytes, node_from_bytes_backrefs, node_to_bytes, serialized_length_from_bytes,
14    serialized_length_from_bytes_trusted,
15};
16#[cfg(feature = "py-bindings")]
17use clvmr::SExp;
18use clvmr::{Allocator, ChiaDialect};
19#[cfg(feature = "py-bindings")]
20use pyo3::prelude::*;
21#[cfg(feature = "py-bindings")]
22use pyo3::types::PyType;
23use std::io::Cursor;
24use std::ops::Deref;
25#[cfg(feature = "py-bindings")]
26use std::rc::Rc;
27
28#[cfg(feature = "py-bindings")]
29use clvm_utils::CurriedProgram;
30
31#[cfg_attr(feature = "py-bindings", pyclass(subclass), derive(PyStreamable))]
32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
33#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
34pub struct Program(Bytes);
35
36impl Default for Program {
37    fn default() -> Self {
38        Self(vec![0x80].into())
39    }
40}
41
42impl Program {
43    pub fn new(bytes: Bytes) -> Self {
44        Self(bytes)
45    }
46
47    pub fn len(&self) -> usize {
48        self.0.len()
49    }
50
51    pub fn is_empty(&self) -> bool {
52        self.0.is_empty()
53    }
54
55    pub fn as_slice(&self) -> &[u8] {
56        self.0.as_slice()
57    }
58
59    pub fn to_vec(&self) -> Vec<u8> {
60        self.0.to_vec()
61    }
62
63    pub fn into_inner(self) -> Bytes {
64        self.0
65    }
66
67    pub fn into_bytes(self) -> Vec<u8> {
68        self.0.into_inner()
69    }
70
71    pub fn run<A: ToClvm<Allocator>>(
72        &self,
73        a: &mut Allocator,
74        flags: u32,
75        max_cost: Cost,
76        arg: &A,
77    ) -> std::result::Result<(Cost, NodePtr), EvalErr> {
78        let arg = arg.to_clvm(a).map_err(|_| {
79            EvalErr::InvalidAllocArg(
80                a.nil(),
81                "failed to convert argument to CLVM objects".to_string(),
82            )
83        })?;
84        let program =
85            node_from_bytes_backrefs(a, self.0.as_ref()).expect("invalid SerializedProgram");
86        let dialect = ChiaDialect::new(flags);
87        let reduction = run_program(a, &dialect, program, arg, max_cost)?;
88        Ok((reduction.0, reduction.1))
89    }
90}
91
92impl From<Bytes> for Program {
93    fn from(value: Bytes) -> Self {
94        Self(value)
95    }
96}
97
98impl From<Program> for Bytes {
99    fn from(value: Program) -> Self {
100        value.0
101    }
102}
103
104impl From<Vec<u8>> for Program {
105    fn from(value: Vec<u8>) -> Self {
106        Self(Bytes::new(value))
107    }
108}
109
110impl From<&[u8]> for Program {
111    fn from(value: &[u8]) -> Self {
112        Self(value.into())
113    }
114}
115
116impl From<Program> for Vec<u8> {
117    fn from(value: Program) -> Self {
118        value.0.into()
119    }
120}
121
122impl AsRef<[u8]> for Program {
123    fn as_ref(&self) -> &[u8] {
124        self.0.as_ref()
125    }
126}
127
128impl Deref for Program {
129    type Target = [u8];
130
131    fn deref(&self) -> &[u8] {
132        &self.0
133    }
134}
135
136#[cfg(feature = "arbitrary")]
137impl<'a> arbitrary::Arbitrary<'a> for Program {
138    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
139        // generate an arbitrary CLVM structure. Not likely a valid program.
140        let mut items_left = 1;
141        let mut total_items = 0;
142        let mut buf = Vec::<u8>::with_capacity(200);
143
144        while items_left > 0 {
145            if total_items < 100 && u.ratio(1, 4).unwrap() {
146                // make a pair
147                buf.push(0xff);
148                items_left += 2;
149            } else {
150                // make an atom. just single bytes for now
151                buf.push(u.int_in_range(0..=0x80).unwrap());
152            }
153            total_items += 1;
154            items_left -= 1;
155        }
156        Ok(Self(buf.into()))
157    }
158}
159
160#[cfg(feature = "py-bindings")]
161use chia_traits::{FromJsonDict, ToJsonDict};
162
163#[cfg(feature = "py-bindings")]
164use chia_py_streamable_macro::PyStreamable;
165
166#[cfg(feature = "py-bindings")]
167use pyo3::types::{PyList, PyTuple};
168
169#[cfg(feature = "py-bindings")]
170use pyo3::exceptions::*;
171
172#[cfg(feature = "py-bindings")]
173#[allow(clippy::needless_pass_by_value)]
174// We use this for the map function, so using a reference here is not ideal.
175fn map_pyerr(err: EvalErr) -> PyErr {
176    // Convert EvalErr to PyErr, so that it can be used in python bindings
177    PyValueError::new_err(err.to_string())
178}
179
180// TODO: this conversion function should probably be converted to a type holding
181// the PyAny object implementing the ToClvm trait. That way, the Program::to()
182// function could turn a python structure directly into bytes, without taking
183// the detour via Allocator. propagating python errors through ToClvmError is a
184// bit tricky though
185#[cfg(feature = "py-bindings")]
186fn clvm_convert(a: &mut Allocator, o: &Bound<'_, PyAny>) -> PyResult<NodePtr> {
187    // None
188    if o.is_none() {
189        Ok(a.nil())
190    // bytes
191    } else if let Ok(buffer) = o.extract::<&[u8]>() {
192        a.new_atom(buffer)
193            .map_err(|e| PyMemoryError::new_err(e.to_string()))
194    // str
195    } else if let Ok(text) = o.extract::<String>() {
196        a.new_atom(text.as_bytes())
197            .map_err(|e| PyMemoryError::new_err(e.to_string()))
198    // int
199    } else if let Ok(val) = o.extract::<clvmr::number::Number>() {
200        a.new_number(val)
201            .map_err(|e| PyMemoryError::new_err(e.to_string()))
202    // Tuple (SExp-like)
203    } else if let Ok(pair) = o.downcast::<PyTuple>() {
204        if pair.len() == 2 {
205            let left = clvm_convert(a, &pair.get_item(0)?)?;
206            let right = clvm_convert(a, &pair.get_item(1)?)?;
207            a.new_pair(left, right)
208                .map_err(|e| PyMemoryError::new_err(e.to_string()))
209        } else {
210            Err(PyValueError::new_err(format!(
211                "can't cast tuple of size {}",
212                pair.len()
213            )))
214        }
215    // List
216    } else if let Ok(list) = o.downcast::<PyList>() {
217        let mut rev = Vec::new();
218        for py_item in list.iter() {
219            rev.push(py_item);
220        }
221        let mut ret = a.nil();
222        for py_item in rev.into_iter().rev() {
223            let item = clvm_convert(a, &py_item)?;
224            ret = a
225                .new_pair(item, ret)
226                .map_err(|e| PyMemoryError::new_err(e.to_string()))?;
227        }
228        Ok(ret)
229    // SExp (such as clvm.SExp)
230    } else if let (Ok(atom), Ok(pair)) = (o.getattr("atom"), o.getattr("pair")) {
231        if atom.is_none() {
232            if pair.is_none() {
233                Err(PyTypeError::new_err(format!("invalid SExp item {o}")))
234            } else {
235                let pair = pair.downcast::<PyTuple>()?;
236                let left = clvm_convert(a, &pair.get_item(0)?)?;
237                let right = clvm_convert(a, &pair.get_item(1)?)?;
238                a.new_pair(left, right)
239                    .map_err(|e| PyMemoryError::new_err(e.to_string()))
240            }
241        } else {
242            a.new_atom(atom.extract::<&[u8]>()?)
243                .map_err(|e| PyMemoryError::new_err(e.to_string()))
244        }
245    // Program itself. This is interpreted as a program in serialized form, and
246    // just a buffer of that serialization. This is an optimization to finding
247    // __bytes__() and calling it
248    } else if let Ok(prg) = o.extract::<Program>() {
249        a.new_atom(prg.0.as_slice())
250            .map_err(|e| PyMemoryError::new_err(e.to_string()))
251    // anything convertible to bytes
252    } else if let Ok(fun) = o.getattr("__bytes__") {
253        let bytes = fun.call0()?;
254        let buffer = bytes.extract::<&[u8]>()?;
255        a.new_atom(buffer)
256            .map_err(|e| PyMemoryError::new_err(e.to_string()))
257    } else {
258        Err(PyTypeError::new_err(format!(
259            "unknown parameter to run_with_cost() {o}"
260        )))
261    }
262}
263
264#[cfg(feature = "py-bindings")]
265fn clvm_serialize(a: &mut Allocator, o: &Bound<'_, PyAny>) -> PyResult<NodePtr> {
266    /*
267    When passing arguments to run(), there's some special treatment, before falling
268    back on the regular python -> CLVM conversion (implemented by clvm_convert
269    above). This function mimics the _serialize() function in python:
270
271       def _serialize(node: object) -> bytes:
272           if isinstance(node, list):
273               serialized_list = bytearray()
274               for a in node:
275                   serialized_list += b"\xff"
276                   serialized_list += _serialize(a)
277               serialized_list += b"\x80"
278               return bytes(serialized_list)
279           if type(node) is SerializedProgram:
280               return bytes(node)
281           if type(node) is Program:
282               return bytes(node)
283           else:
284               ret: bytes = SExp.to(node).as_bin()
285               return ret
286    */
287
288    // List
289    if let Ok(list) = o.downcast::<PyList>() {
290        let mut rev = Vec::new();
291        for py_item in list.iter() {
292            rev.push(py_item);
293        }
294        let mut ret = a.nil();
295        for py_item in rev.into_iter().rev() {
296            let item = clvm_serialize(a, &py_item)?;
297            ret = a
298                .new_pair(item, ret)
299                .map_err(|e| PyMemoryError::new_err(e.to_string()))?;
300        }
301        Ok(ret)
302    // Program itself
303    } else if let Ok(prg) = o.extract::<Program>() {
304        node_from_bytes_backrefs(a, prg.0.as_slice()).map_err(map_pyerr)
305    } else {
306        clvm_convert(a, o)
307    }
308}
309
310#[cfg(feature = "py-bindings")]
311#[allow(clippy::needless_pass_by_value)]
312#[pymethods]
313impl Program {
314    #[pyo3(name = "default")]
315    #[staticmethod]
316    fn py_default() -> Self {
317        Self::default()
318    }
319
320    #[staticmethod]
321    #[pyo3(name = "to")]
322    fn py_to(args: &Bound<'_, PyAny>) -> PyResult<Program> {
323        let mut a = Allocator::new_limited(500_000_000);
324        let clvm = clvm_convert(&mut a, args)?;
325        Program::from_clvm(&a, clvm)
326            .map_err(|error| PyErr::new::<PyTypeError, _>(error.to_string()))
327    }
328
329    fn get_tree_hash(&self) -> crate::Bytes32 {
330        clvm_utils::tree_hash_from_bytes(self.0.as_ref())
331            .unwrap()
332            .into()
333    }
334
335    #[staticmethod]
336    fn fromhex(h: String) -> Result<Self> {
337        let s = if let Some(st) = h.strip_prefix("0x") {
338            st
339        } else {
340            &h[..]
341        };
342        Self::from_bytes(hex::decode(s).map_err(|_| Error::InvalidString)?.as_slice())
343    }
344
345    fn run_rust(
346        &self,
347        py: Python<'_>,
348        max_cost: u64,
349        flags: u32,
350        args: &Bound<'_, PyAny>,
351    ) -> PyResult<(u64, LazyNode)> {
352        use clvmr::reduction::Response;
353
354        let mut a = Allocator::new_limited(500_000_000);
355        // The python behavior here is a bit messy, and is best not emulated
356        // on the rust side. We must be able to pass a Program as an argument,
357        // and it being treated as the CLVM structure it represents. In python's
358        // SerializedProgram, we have a hack where we interpret the first
359        // "layer" of SerializedProgram, or lists of SerializedProgram this way.
360        // But if we encounter an Optional or tuple, we defer to the clvm
361        // wheel's conversion function to SExp. This level does not have any
362        // special treatment for SerializedProgram (as that would cause a
363        // circular dependency).
364        let clvm_args = clvm_serialize(&mut a, args)?;
365
366        let r: Response = (|| -> PyResult<Response> {
367            let program = node_from_bytes_backrefs(&mut a, self.0.as_ref()).map_err(map_pyerr)?;
368            let dialect = ChiaDialect::new(flags);
369
370            Ok(py.allow_threads(|| run_program(&mut a, &dialect, program, clvm_args, max_cost)))
371        })()?;
372        match r {
373            Ok(reduction) => {
374                let val = LazyNode::new(Rc::new(a), reduction.1);
375                Ok((reduction.0, val))
376            }
377            Err(eval_err) => {
378                let blob = node_to_bytes(&a, eval_err.node_ptr()).ok().map(hex::encode);
379                Err(PyValueError::new_err((eval_err.to_string(), blob)))
380            }
381        }
382    }
383
384    fn uncurry_rust(&self) -> PyResult<(LazyNode, LazyNode)> {
385        let mut a = Allocator::new_limited(500_000_000);
386        let prg = node_from_bytes_backrefs(&mut a, self.0.as_ref()).map_err(map_pyerr)?;
387        let Ok(uncurried) = CurriedProgram::<NodePtr, NodePtr>::from_clvm(&a, prg) else {
388            let a = Rc::new(a);
389            let prg = LazyNode::new(a.clone(), prg);
390            let ret = a.nil();
391            let ret = LazyNode::new(a, ret);
392            return Ok((prg, ret));
393        };
394
395        let mut curried_args = Vec::<NodePtr>::new();
396        let mut args = uncurried.args;
397        loop {
398            if let SExp::Atom = a.sexp(args) {
399                break;
400            }
401            // the args of curried puzzles are in the form of:
402            // (c . ((q . <arg1>) . (<rest> . ())))
403            let (_, ((_, arg), (rest, ()))) =
404                <(
405                    clvm_traits::MatchByte<4>,
406                    (clvm_traits::match_quote!(NodePtr), (NodePtr, ())),
407                ) as FromClvm<Allocator>>::from_clvm(&a, args)
408                .map_err(|error| PyErr::new::<PyTypeError, _>(error.to_string()))?;
409            curried_args.push(arg);
410            args = rest;
411        }
412        let mut ret = a.nil();
413        for item in curried_args.into_iter().rev() {
414            ret = a.new_pair(item, ret).map_err(|_e| Error::EndOfBuffer)?;
415        }
416        let a = Rc::new(a);
417        let prg = LazyNode::new(a.clone(), uncurried.program);
418        let ret = LazyNode::new(a, ret);
419        Ok((prg, ret))
420    }
421}
422
423impl Streamable for Program {
424    fn update_digest(&self, digest: &mut Sha256) {
425        digest.update(&self.0);
426    }
427
428    fn stream(&self, out: &mut Vec<u8>) -> Result<()> {
429        out.extend_from_slice(self.0.as_ref());
430        Ok(())
431    }
432
433    fn parse<const TRUSTED: bool>(input: &mut Cursor<&[u8]>) -> Result<Self> {
434        let pos = input.position();
435        let buf: &[u8] = &input.get_ref()[pos as usize..];
436        let len = if TRUSTED {
437            serialized_length_from_bytes_trusted(buf).map_err(|_e| Error::EndOfBuffer)?
438        } else {
439            serialized_length_from_bytes(buf).map_err(|_e| Error::EndOfBuffer)?
440        };
441        if buf.len() < len as usize {
442            return Err(Error::EndOfBuffer);
443        }
444        let program = buf[..len as usize].to_vec();
445        input.set_position(pos + len);
446        Ok(Program(program.into()))
447    }
448}
449
450#[cfg(feature = "py-bindings")]
451impl ToJsonDict for Program {
452    fn to_json_dict(&self, py: Python<'_>) -> PyResult<PyObject> {
453        self.0.to_json_dict(py)
454    }
455}
456
457#[cfg(feature = "py-bindings")]
458#[pymethods]
459impl Program {
460    #[classmethod]
461    #[pyo3(name = "from_parent")]
462    pub fn from_parent(_cls: &Bound<'_, PyType>, _instance: &Self) -> PyResult<PyObject> {
463        Err(PyNotImplementedError::new_err(
464            "This class does not support from_parent().",
465        ))
466    }
467}
468
469#[cfg(feature = "py-bindings")]
470impl FromJsonDict for Program {
471    fn from_json_dict(o: &Bound<'_, PyAny>) -> PyResult<Self> {
472        let bytes = Bytes::from_json_dict(o)?;
473        let len =
474            serialized_length_from_bytes(bytes.as_slice()).map_err(|_e| Error::EndOfBuffer)?;
475        if len as usize != bytes.len() {
476            // If the bytes in the JSON string is not a valid CLVM
477            // serialization, or if it has garbage at the end of the string,
478            // reject it
479            return Err(Error::InvalidClvm)?;
480        }
481        Ok(Self(bytes))
482    }
483}
484
485impl FromClvm<Allocator> for Program {
486    fn from_clvm(a: &Allocator, node: NodePtr) -> std::result::Result<Self, FromClvmError> {
487        Ok(Self(
488            node_to_bytes(a, node)
489                .map_err(|error| FromClvmError::Custom(error.to_string()))?
490                .into(),
491        ))
492    }
493}
494
495impl ToClvm<Allocator> for Program {
496    fn to_clvm(&self, a: &mut Allocator) -> std::result::Result<NodePtr, ToClvmError> {
497        node_from_bytes(a, self.0.as_ref()).map_err(|error| ToClvmError::Custom(error.to_string()))
498    }
499}
500
501#[cfg(test)]
502mod tests {
503    use super::*;
504
505    #[test]
506    fn program_roundtrip() {
507        let a = &mut Allocator::new();
508        let expected = "ff01ff02ff62ff0480";
509        let expected_bytes = hex::decode(expected).unwrap();
510
511        let ptr = node_from_bytes(a, &expected_bytes).unwrap();
512        let program = Program::from_clvm(a, ptr).unwrap();
513
514        let round_trip = program.to_clvm(a).unwrap();
515        assert_eq!(expected, hex::encode(node_to_bytes(a, round_trip).unwrap()));
516    }
517
518    #[test]
519    fn program_run() {
520        let a = &mut Allocator::new();
521
522        // (+ 2 5)
523        let prg = Program::from_bytes(&hex::decode("ff10ff02ff0580").expect("hex::decode"))
524            .expect("from_bytes");
525        let (cost, result) = prg.run(a, 0, 1000, &[1300, 37]).expect("run");
526        assert_eq!(cost, 869);
527        assert_eq!(a.number(result), 1337.into());
528    }
529}
530
531#[cfg(all(test, feature = "serde"))]
532mod serde_tests {
533    use super::*;
534
535    #[test]
536    fn test_program_is_bytes() -> anyhow::Result<()> {
537        let bytes = Bytes::new(vec![1, 2, 3]);
538        let program = Program::new(bytes.clone());
539
540        let bytes_json = serde_json::to_string(&bytes)?;
541        let program_json = serde_json::to_string(&program)?;
542
543        assert_eq!(program_json, bytes_json);
544
545        Ok(())
546    }
547}