chia_protocol/
program.rs

1#[cfg(feature = "py-bindings")]
2use crate::LazyNode;
3use crate::bytes::Bytes;
4use chia_sha2::Sha256;
5use chia_traits::Streamable;
6use chia_traits::chia_error::{Error, Result};
7use clvm_traits::{FromClvm, FromClvmError, ToClvm, ToClvmError};
8#[cfg(feature = "py-bindings")]
9use clvmr::SExp;
10use clvmr::allocator::NodePtr;
11use clvmr::cost::Cost;
12use clvmr::error::EvalErr;
13use clvmr::run_program;
14use clvmr::serde::{
15    node_from_bytes, node_from_bytes_backrefs, node_to_bytes, serialized_length_from_bytes,
16    serialized_length_from_bytes_trusted,
17};
18use clvmr::{Allocator, ChiaDialect};
19#[cfg(feature = "py-bindings")]
20use pyo3::prelude::*;
21#[cfg(feature = "py-bindings")]
22use pyo3::types::{PyList, PyTuple, PyType};
23use std::io::Cursor;
24use std::ops::Deref;
25#[cfg(feature = "py-bindings")]
26use std::rc::Rc;
27
28#[cfg(feature = "py-bindings")]
29use clvm_utils::CurriedProgram;
30
31#[cfg_attr(feature = "py-bindings", pyclass(subclass), derive(PyStreamable))]
32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
33#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
34pub struct Program(Bytes);
35
36impl Default for Program {
37    fn default() -> Self {
38        Self(vec![0x80].into())
39    }
40}
41
42impl Program {
43    pub fn new(bytes: Bytes) -> Self {
44        Self(bytes)
45    }
46
47    pub fn len(&self) -> usize {
48        self.0.len()
49    }
50
51    pub fn is_empty(&self) -> bool {
52        self.0.is_empty()
53    }
54
55    pub fn as_slice(&self) -> &[u8] {
56        self.0.as_slice()
57    }
58
59    pub fn to_vec(&self) -> Vec<u8> {
60        self.0.to_vec()
61    }
62
63    pub fn into_inner(self) -> Bytes {
64        self.0
65    }
66
67    pub fn into_bytes(self) -> Vec<u8> {
68        self.0.into_inner()
69    }
70
71    pub fn run<A: ToClvm<Allocator>>(
72        &self,
73        a: &mut Allocator,
74        flags: u32,
75        max_cost: Cost,
76        arg: &A,
77    ) -> std::result::Result<(Cost, NodePtr), EvalErr> {
78        let arg = arg.to_clvm(a).map_err(|_| {
79            EvalErr::InvalidAllocArg(
80                a.nil(),
81                "failed to convert argument to CLVM objects".to_string(),
82            )
83        })?;
84        let program =
85            node_from_bytes_backrefs(a, self.0.as_ref()).expect("invalid SerializedProgram");
86        let dialect = ChiaDialect::new(flags);
87        let reduction = run_program(a, &dialect, program, arg, max_cost)?;
88        Ok((reduction.0, reduction.1))
89    }
90}
91
92impl From<Bytes> for Program {
93    fn from(value: Bytes) -> Self {
94        Self(value)
95    }
96}
97
98impl From<Program> for Bytes {
99    fn from(value: Program) -> Self {
100        value.0
101    }
102}
103
104impl From<Vec<u8>> for Program {
105    fn from(value: Vec<u8>) -> Self {
106        Self(Bytes::new(value))
107    }
108}
109
110impl From<&[u8]> for Program {
111    fn from(value: &[u8]) -> Self {
112        Self(value.into())
113    }
114}
115
116impl From<Program> for Vec<u8> {
117    fn from(value: Program) -> Self {
118        value.0.into()
119    }
120}
121
122impl AsRef<[u8]> for Program {
123    fn as_ref(&self) -> &[u8] {
124        self.0.as_ref()
125    }
126}
127
128impl Deref for Program {
129    type Target = [u8];
130
131    fn deref(&self) -> &[u8] {
132        &self.0
133    }
134}
135
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Program {
    /// Generates a structurally complete CLVM serialization of a random tree.
    ///
    /// Each step emits either a pair marker (`0xff`, chosen with ratio 1:4
    /// while fewer than 100 items have been emitted) or a single-byte atom in
    /// `0x00..=0x80`. The result is well-formed but unlikely to be a
    /// meaningful program.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Number of CLVM values still owed before the tree is complete.
        let mut items_left = 1;
        let mut total_items = 0;
        let mut buf = Vec::<u8>::with_capacity(200);

        while items_left > 0 {
            if total_items < 100 && u.ratio(1, 4)? {
                // A pair: it fills one owed slot and owes two new values.
                buf.push(0xff);
                items_left += 2;
            } else {
                // An atom. Just single bytes for now.
                buf.push(u.int_in_range(0..=0x80)?);
            }
            total_items += 1;
            items_left -= 1;
        }
        Ok(Self(buf.into()))
    }
}
159
160#[cfg(feature = "py-bindings")]
161use chia_traits::{FromJsonDict, ToJsonDict};
162
163#[cfg(feature = "py-bindings")]
164use chia_py_streamable_macro::PyStreamable;
165
166#[cfg(feature = "py-bindings")]
167use pyo3::exceptions::*;
168
/// Converts an `EvalErr` into a Python `ValueError` carrying the error's
/// message, so it can be used directly with `map_err` in the Python bindings.
// `err` is taken by value (hence the clippy allow) so this function can be
// passed straight to `map_err`.
#[cfg(feature = "py-bindings")]
#[allow(clippy::needless_pass_by_value)]
fn map_pyerr(err: EvalErr) -> PyErr {
    let message = err.to_string();
    PyValueError::new_err(message)
}
176
// TODO: this conversion function should probably be converted to a type holding
// the PyAny object implementing the ToClvm trait. That way, the Program::to()
// function could turn a python structure directly into bytes, without taking
// the detour via Allocator. Propagating python errors through ToClvmError is a
// bit tricky though.
/// Converts an arbitrary Python value into a CLVM node allocated in `a`.
///
/// Candidates are tried in this order; the first match wins:
/// `None` -> nil; `bytes` -> atom; `str` -> atom of its UTF-8 bytes;
/// `int` -> number; 2-tuple -> pair (other sizes raise `ValueError`);
/// `list` -> nil-terminated list; objects exposing `atom` and `pair`
/// attributes (SExp-like); `Program` -> atom holding its serialized bytes;
/// objects with `__bytes__` -> atom of the returned bytes. Anything else
/// raises `TypeError`. Allocator failures are raised as `MemoryError`.
#[cfg(feature = "py-bindings")]
fn clvm_convert(a: &mut Allocator, o: &Bound<'_, PyAny>) -> PyResult<NodePtr> {
    /// Surfaces an allocator failure to Python as `MemoryError`.
    fn memory_error(e: impl ToString) -> PyErr {
        PyMemoryError::new_err(e.to_string())
    }

    if o.is_none() {
        return Ok(a.nil());
    }

    if let Ok(buffer) = o.extract::<&[u8]>() {
        return a.new_atom(buffer).map_err(memory_error);
    }

    if let Ok(text) = o.extract::<String>() {
        return a.new_atom(text.as_bytes()).map_err(memory_error);
    }

    if let Ok(val) = o.extract::<clvmr::number::Number>() {
        return a.new_number(val).map_err(memory_error);
    }

    // Only a 2-tuple has a CLVM meaning (a cons pair).
    if let Ok(pair) = o.cast::<PyTuple>() {
        let size = pair.len();
        if size != 2 {
            return Err(PyValueError::new_err(format!(
                "can't cast tuple of size {size}"
            )));
        }
        let left = clvm_convert(a, &pair.get_item(0)?)?;
        let right = clvm_convert(a, &pair.get_item(1)?)?;
        return a.new_pair(left, right).map_err(memory_error);
    }

    // Lists are built back-to-front so each element is consed onto the
    // already-converted tail.
    if let Ok(list) = o.cast::<PyList>() {
        let items: Vec<_> = list.iter().collect();
        let mut tail = a.nil();
        for py_item in items.iter().rev() {
            let head = clvm_convert(a, py_item)?;
            tail = a.new_pair(head, tail).map_err(memory_error)?;
        }
        return Ok(tail);
    }

    // SExp-like objects (such as clvm.SExp). A non-None `atom` wins;
    // otherwise `pair` must be a tuple of two convertible items.
    if let (Ok(atom), Ok(pair)) = (o.getattr("atom"), o.getattr("pair")) {
        if !atom.is_none() {
            return a.new_atom(atom.extract::<&[u8]>()?).map_err(memory_error);
        }
        if pair.is_none() {
            return Err(PyTypeError::new_err(format!("invalid SExp item {o}")));
        }
        let pair = pair.cast::<PyTuple>()?;
        let left = clvm_convert(a, &pair.get_item(0)?)?;
        let right = clvm_convert(a, &pair.get_item(1)?)?;
        return a.new_pair(left, right).map_err(memory_error);
    }

    // A Program becomes an atom containing its serialization. This is a
    // shortcut for finding and calling `__bytes__()` below.
    if let Ok(prg) = o.extract::<Program>() {
        return a.new_atom(prg.0.as_slice()).map_err(memory_error);
    }

    // Anything convertible to bytes.
    if let Ok(fun) = o.getattr("__bytes__") {
        let bytes = fun.call0()?;
        let buffer = bytes.extract::<&[u8]>()?;
        return a.new_atom(buffer).map_err(memory_error);
    }

    Err(PyTypeError::new_err(format!(
        "unknown parameter to run_with_cost() {o}"
    )))
}
260
/// Converts an argument passed to `run_rust()` into a CLVM node in `a`.
///
/// Lists are converted element-wise with this function (so a list of
/// `Program`s works). A `Program` stands for the CLVM structure it encodes.
/// Everything else falls back to `clvm_convert`. This mimics the Python
/// `_serialize()` helper:
///
/// ```text
/// def _serialize(node: object) -> bytes:
///     if isinstance(node, list):
///         serialized_list = bytearray()
///         for a in node:
///             serialized_list += b"\xff"
///             serialized_list += _serialize(a)
///         serialized_list += b"\x80"
///         return bytes(serialized_list)
///     if type(node) is SerializedProgram:
///         return bytes(node)
///     if type(node) is Program:
///         return bytes(node)
///     else:
///         ret: bytes = SExp.to(node).as_bin()
///         return ret
/// ```
#[cfg(feature = "py-bindings")]
fn clvm_serialize(a: &mut Allocator, o: &Bound<'_, PyAny>) -> PyResult<NodePtr> {
    if let Ok(list) = o.cast::<PyList>() {
        // Fold from the last element backwards, consing each converted item
        // onto the tail built so far.
        let items: Vec<_> = list.iter().collect();
        let nil = a.nil();
        return items.iter().rev().try_fold(nil, |tail, py_item| {
            let head = clvm_serialize(a, py_item)?;
            a.new_pair(head, tail)
                .map_err(|e| PyMemoryError::new_err(e.to_string()))
        });
    }

    if let Ok(prg) = o.extract::<Program>() {
        return node_from_bytes_backrefs(a, prg.0.as_slice()).map_err(map_pyerr);
    }

    clvm_convert(a, o)
}
306
#[cfg(feature = "py-bindings")]
#[allow(clippy::needless_pass_by_value)]
#[pymethods]
impl Program {
    /// Returns the default program (serialized nil).
    #[pyo3(name = "default")]
    #[staticmethod]
    fn py_default() -> Self {
        Self::default()
    }

    /// Converts a Python value to CLVM and returns its serialization.
    #[staticmethod]
    #[pyo3(name = "to")]
    fn py_to(args: &Bound<'_, PyAny>) -> PyResult<Program> {
        let mut a = Allocator::new_limited(500_000_000);
        let clvm = clvm_convert(&mut a, args)?;
        Program::from_clvm(&a, clvm)
            .map_err(|error| PyErr::new::<PyTypeError, _>(error.to_string()))
    }

    /// Returns the CLVM tree hash of this program.
    ///
    /// Malformed bytes raise `ValueError`, consistent with `run_rust` and
    /// `uncurry_rust`, instead of panicking.
    fn get_tree_hash(&self) -> PyResult<crate::Bytes32> {
        clvm_utils::tree_hash_from_bytes(self.0.as_ref())
            .map(Into::into)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Parses a hex string, optionally prefixed with `0x`, into a program via
    /// `Streamable::from_bytes`. Invalid hex yields `Error::InvalidString`.
    #[staticmethod]
    fn fromhex(h: String) -> Result<Self> {
        let s = h.strip_prefix("0x").unwrap_or(h.as_str());
        Self::from_bytes(hex::decode(s).map_err(|_| Error::InvalidString)?.as_slice())
    }

    /// Runs this program with `args` using a `ChiaDialect` built from `flags`.
    ///
    /// Returns `(cost, result)`. On evaluation failure, raises `ValueError`
    /// with `(message, hex of the offending node or None)`.
    fn run_rust(
        &self,
        py: Python<'_>,
        max_cost: u64,
        flags: u32,
        args: &Bound<'_, PyAny>,
    ) -> PyResult<(u64, LazyNode)> {
        use clvmr::reduction::Response;

        let mut a = Allocator::new_limited(500_000_000);
        // The python behavior here is a bit messy, and is best not emulated
        // on the rust side. We must be able to pass a Program as an argument,
        // and it being treated as the CLVM structure it represents. In python's
        // SerializedProgram, we have a hack where we interpret the first
        // "layer" of SerializedProgram, or lists of SerializedProgram this way.
        // But if we encounter an Optional or tuple, we defer to the clvm
        // wheel's conversion function to SExp. This level does not have any
        // special treatment for SerializedProgram (as that would cause a
        // circular dependency).
        let clvm_args = clvm_serialize(&mut a, args)?;

        let program = node_from_bytes_backrefs(&mut a, self.0.as_ref()).map_err(map_pyerr)?;
        let dialect = ChiaDialect::new(flags);
        // The interpreter runs with this thread detached from Python.
        let r: Response =
            py.detach(|| run_program(&mut a, &dialect, program, clvm_args, max_cost));

        match r {
            Ok(reduction) => {
                let val = LazyNode::new(Rc::new(a), reduction.1);
                Ok((reduction.0, val))
            }
            Err(eval_err) => {
                let blob = node_to_bytes(&a, eval_err.node_ptr()).ok().map(hex::encode);
                Err(PyValueError::new_err((eval_err.to_string(), blob)))
            }
        }
    }

    /// Splits a curried program into `(inner_program, curried_args_list)`.
    ///
    /// If the program is not in curried form, returns `(self, nil)`. Curried
    /// arguments that do not have the expected shape raise `TypeError`.
    fn uncurry_rust(&self) -> PyResult<(LazyNode, LazyNode)> {
        let mut a = Allocator::new_limited(500_000_000);
        let prg = node_from_bytes_backrefs(&mut a, self.0.as_ref()).map_err(map_pyerr)?;
        let Ok(uncurried) = CurriedProgram::<NodePtr, NodePtr>::from_clvm(&a, prg) else {
            let a = Rc::new(a);
            let prg = LazyNode::new(a.clone(), prg);
            let ret = a.nil();
            let ret = LazyNode::new(a, ret);
            return Ok((prg, ret));
        };

        let mut curried_args = Vec::<NodePtr>::new();
        let mut args = uncurried.args;
        while !matches!(a.sexp(args), SExp::Atom) {
            // the args of curried puzzles are in the form of:
            // (c . ((q . <arg1>) . (<rest> . ())))
            let (_, ((_, arg), (rest, ()))) =
                <(
                    clvm_traits::MatchByte<4>,
                    (clvm_traits::match_quote!(NodePtr), (NodePtr, ())),
                ) as FromClvm<Allocator>>::from_clvm(&a, args)
                .map_err(|error| PyErr::new::<PyTypeError, _>(error.to_string()))?;
            curried_args.push(arg);
            args = rest;
        }

        let mut ret = a.nil();
        for item in curried_args.into_iter().rev() {
            // Allocation failures are MemoryError, like elsewhere in this
            // file (previously mis-reported as EndOfBuffer).
            ret = a
                .new_pair(item, ret)
                .map_err(|e| PyMemoryError::new_err(e.to_string()))?;
        }
        let a = Rc::new(a);
        let prg = LazyNode::new(a.clone(), uncurried.program);
        let ret = LazyNode::new(a, ret);
        Ok((prg, ret))
    }
}
419
420impl Streamable for Program {
421    fn update_digest(&self, digest: &mut Sha256) {
422        digest.update(&self.0);
423    }
424
425    fn stream(&self, out: &mut Vec<u8>) -> Result<()> {
426        out.extend_from_slice(self.0.as_ref());
427        Ok(())
428    }
429
430    fn parse<const TRUSTED: bool>(input: &mut Cursor<&[u8]>) -> Result<Self> {
431        let pos = input.position();
432        let buf: &[u8] = &input.get_ref()[pos as usize..];
433        let len = if TRUSTED {
434            serialized_length_from_bytes_trusted(buf).map_err(|_e| Error::EndOfBuffer)?
435        } else {
436            serialized_length_from_bytes(buf).map_err(|_e| Error::EndOfBuffer)?
437        };
438        if buf.len() < len as usize {
439            return Err(Error::EndOfBuffer);
440        }
441        let program = buf[..len as usize].to_vec();
442        input.set_position(pos + len);
443        Ok(Program(program.into()))
444    }
445}
446
#[cfg(feature = "py-bindings")]
impl ToJsonDict for Program {
    /// Uses the JSON representation of the wrapped `Bytes`.
    fn to_json_dict(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
        let Self(bytes) = self;
        bytes.to_json_dict(py)
    }
}
453
#[cfg(feature = "py-bindings")]
#[pymethods]
impl Program {
    /// Always raises `NotImplementedError`: `Program` cannot be built from a
    /// parent-class instance.
    #[classmethod]
    #[pyo3(name = "from_parent")]
    pub fn from_parent(_cls: &Bound<'_, PyType>, _instance: &Self) -> PyResult<Py<PyAny>> {
        const MESSAGE: &str = "This class does not support from_parent().";
        Err(PyNotImplementedError::new_err(MESSAGE))
    }
}
465
#[cfg(feature = "py-bindings")]
impl FromJsonDict for Program {
    /// Parses the JSON byte-string form and checks it holds exactly one CLVM
    /// serialization.
    ///
    /// A serialization that cannot be scanned yields `Error::EndOfBuffer`.
    /// Trailing bytes after a complete value yield `Error::InvalidClvm`.
    fn from_json_dict(o: &Bound<'_, PyAny>) -> PyResult<Self> {
        let bytes = Bytes::from_json_dict(o)?;
        let len =
            serialized_length_from_bytes(bytes.as_slice()).map_err(|_e| Error::EndOfBuffer)?;
        // The serialization must span the whole string: reject garbage after
        // the end of the CLVM value.
        if usize::try_from(len).ok() != Some(bytes.len()) {
            return Err(Error::InvalidClvm.into());
        }
        Ok(Self(bytes))
    }
}
481
482impl FromClvm<Allocator> for Program {
483    fn from_clvm(a: &Allocator, node: NodePtr) -> std::result::Result<Self, FromClvmError> {
484        Ok(Self(
485            node_to_bytes(a, node)
486                .map_err(|error| FromClvmError::Custom(error.to_string()))?
487                .into(),
488        ))
489    }
490}
491
492impl ToClvm<Allocator> for Program {
493    fn to_clvm(&self, a: &mut Allocator) -> std::result::Result<NodePtr, ToClvmError> {
494        node_from_bytes(a, self.0.as_ref()).map_err(|error| ToClvmError::Custom(error.to_string()))
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Converting a node to a `Program` and back reproduces the serialization.
    #[test]
    fn program_roundtrip() {
        const EXPECTED: &str = "ff01ff02ff62ff0480";
        let mut allocator = Allocator::new();
        let source = hex::decode(EXPECTED).unwrap();

        let node = node_from_bytes(&mut allocator, &source).unwrap();
        let program = Program::from_clvm(&allocator, node).unwrap();
        let back = program.to_clvm(&mut allocator).unwrap();
        let reserialized = node_to_bytes(&allocator, back).unwrap();

        assert_eq!(EXPECTED, hex::encode(reserialized));
    }

    /// Running `(+ 2 5)` against `(1300 37)` adds the two arguments.
    #[test]
    fn program_run() {
        let mut allocator = Allocator::new();

        // (+ 2 5)
        let serialized = hex::decode("ff10ff02ff0580").expect("hex::decode");
        let prg = Program::from_bytes(&serialized).expect("from_bytes");
        let (cost, result) = prg
            .run(&mut allocator, 0, 1000, &[1300, 37])
            .expect("run");

        assert_eq!(cost, 869);
        assert_eq!(allocator.number(result), 1337.into());
    }
}
527
#[cfg(all(test, feature = "serde"))]
mod serde_tests {
    use super::*;

    /// A `Program` serializes to JSON exactly like the `Bytes` it wraps.
    #[test]
    fn test_program_is_bytes() -> anyhow::Result<()> {
        let raw = Bytes::new(vec![1, 2, 3]);
        let wrapped = Program::new(raw.clone());

        let expected = serde_json::to_string(&raw)?;
        let actual = serde_json::to_string(&wrapped)?;
        assert_eq!(actual, expected);

        Ok(())
    }
}