Skip to main content

midnight_onchain_state/
state.rs

1// This file is part of midnight-ledger.
2// Copyright (C) 2025 Midnight Foundation
3// SPDX-License-Identifier: Apache-2.0
4// Licensed under the Apache License, Version 2.0 (the "License");
5// You may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use base_crypto::cost_model::RunningCost;
15use base_crypto::fab::{Aligned, AlignedValue, Alignment, AlignmentAtom};
16use base_crypto::hash::{HashOutput, persistent_commit};
17use base_crypto::repr::MemWrite;
18use base_crypto::signatures::VerifyingKey;
19use coin_structure::coin::TokenType;
20use derive_where::derive_where;
21use fake::Dummy;
22use hex::ToHex;
23#[cfg(feature = "proptest")]
24use proptest::arbitrary::Arbitrary;
25#[cfg(feature = "proptest")]
26use proptest_derive::Arbitrary;
27use rand::Rng;
28use rand::distributions::{Distribution, Standard};
29use serde::{
30    Deserialize, Deserializer, Serialize, Serializer, de, de::MapAccess, de::SeqAccess,
31    de::Visitor, ser::SerializeStruct,
32};
33#[cfg(feature = "proptest")]
34use serialize::NoStrategy;
35#[cfg(feature = "proptest")]
36use serialize::randomised_serialization_test;
37#[cfg(feature = "proptest")]
38use serialize::simple_arbitrary;
39use serialize::{self, Deserializable, Serializable, Tagged, tag_enforcement_test};
40use std::borrow::Borrow;
41use std::fmt::{self, Debug, Formatter};
42use std::hash::Hash;
43use std::io::{self, Read, Write};
44use std::marker::PhantomData;
45use std::ops::Deref;
46use storage::Storable;
47use storage::arena::Sp;
48use storage::db::{DB, InMemoryDB};
49use storage::delta_tracking::{incremental_write_delete_costs, initial_write_delete_costs};
50use storage::{
51    arena::ArenaKey,
52    delta_tracking::RcMap,
53    storable::Loader,
54    storage::{Array, HashMap},
55};
56use transient_crypto::curve::Fr;
57use transient_crypto::merkle_tree::MerkleTree;
58use transient_crypto::proofs::VerifierKey;
59use transient_crypto::repr::FieldRepr;
60
/// Proptest filter for generated [`StateValue`]s: rejects arrays with more than
/// 16 elements (the same bound `StateValue::invariant` enforces) and accepts
/// every other value.
#[cfg(feature = "proptest")]
fn proptest_valid<D: DB>(value: &StateValue<D>) -> bool {
    let oversized_array = matches!(value, StateValue::Array(arr) if arr.len() > 16);
    !oversized_array
}
68
/// The size limit, in bytes, for the serialized value held by a `Cell`: 32 KiB (2^15).
/// `StateValue::invariant` rejects cells whose value serializes to more than this.
pub const CELL_BOUND: usize = 1 << 15;
71
/// A value in a contract's on-chain state.
///
/// Values form a tree: `Map` and `Array` contain nested `StateValue`s, while `Null`, `Cell`
/// and `BoundedMerkleTree` are leaves. Per-variant size limits are checked by
/// `StateValue::invariant`, which is registered as the storage invariant below.
#[derive(Default, Storable)]
#[derive_where(Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "proptest", derive(Arbitrary))]
#[cfg_attr(feature = "proptest", proptest(filter = "proptest_valid"))]
#[storable(db = D, invariant = StateValue::invariant)]
#[tag = "impact-state-value[v2]"]
#[non_exhaustive]
pub enum StateValue<D: DB = InMemoryDB> {
    /// The empty value; this is also the `Default`.
    #[default]
    Null,
    /// A single aligned value whose serialization may be at most [`CELL_BOUND`] bytes.
    Cell(#[storable(child)] Sp<AlignedValue, D>),
    /// A map from aligned-value keys to nested state values.
    Map(HashMap<AlignedValue, StateValue<D>, D>),
    /// A fixed size array, with `0 <= len <= 16`. The upper 5 bits of the
    /// argument to the `new` opcode specify the length at creation time. The
    /// underlying `storage::Array` type is not fixed length, but in the VM we
    /// only allow size preserving operations.
    Array(Array<StateValue<D>, D>),
    /// Merkle tree with `0 < height <= 32`. `StateValue::invariant` additionally
    /// requires the tree to have been rehashed (its `root()` must be `Some`).
    BoundedMerkleTree(
        // The `Serializable::unversioned_serialize` impl requires this.
        #[cfg_attr(
            feature = "proptest",
            proptest(filter = "|mt| !(mt.height() == 0 || mt.height() > 32)")
        )]
        MerkleTree<(), D>,
    ),
}
tag_enforcement_test!(StateValue);
100
101impl<D: DB> From<u64> for StateValue<D> {
102    fn from(value: u64) -> Self {
103        StateValue::Cell(Sp::new(value.into()))
104    }
105}
106
// We need to manually implement `Drop` to avoid implicit unbounded recursion, which could lead to
// stack overflows. See https://rust-unofficial.github.io/too-many-lists/first-drop.html.
//
// The compiler-generated drop glue would descend into nested `Map`/`Array` children recursively.
// This impl instead moves children onto an explicit, heap-allocated work stack (`frontier`) and
// drops values one at a time in a loop.
impl<D: DB> Drop for StateValue<D> {
    fn drop(&mut self) {
        // Early return for values with nothing to unfold: `Null`, `Cell`, `BoundedMerkleTree`, and
        // empty maps/arrays. This is the base case for `Drop`. Without it, every call would push
        // at least one value onto the work stack and then drop it, recursing without end.
        match self {
            StateValue::Null | StateValue::Cell(_) | StateValue::BoundedMerkleTree(_) => return,
            StateValue::Map(m) if m.size() == 0 => return,
            StateValue::Array(a) if a.is_empty() => return,
            _ => {}
        }
        // Move the value out of `&mut self` so it can be taken apart by value, leaving the
        // `Default` (`Null`) in its place. This relies on `Null` hitting the base case above
        // when `self` itself is dropped afterwards.
        let mut frontier = vec![std::mem::take(self)];
        while let Some(mut curr) = frontier.pop() {
            match &mut curr {
                StateValue::Map(m) => {
                    // Swap in an empty map, then push the old map's values onto the stack.
                    // NOTE(review): presumably `into_inner_for_drop` (from `storage`) only yields
                    // children that are not still shared elsewhere — confirm there.
                    let mut tmp = HashMap::new();
                    std::mem::swap(m, &mut tmp);
                    frontier.extend(tmp.into_inner_for_drop().flat_map(|(_, v)| v.into_iter()));
                }
                StateValue::Array(a) => {
                    // Same for arrays: swap in an empty array and push the old elements.
                    let mut tmp = Array::new();
                    std::mem::swap(a, &mut tmp);
                    frontier.extend(tmp.into_inner_for_drop());
                }
                _ => {}
            }
            // It is now safe to drop `curr`: it holds an empty map/array (or a leaf), so its own
            // `Drop` takes the early return above instead of recursing further.
            drop(curr);
        }
    }
}
142
143impl<D: DB> Distribution<StateValue<D>> for Standard {
144    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> StateValue<D> {
145        let disc = rng.gen_range(0..40);
146        // Converges because:
147        // - 38/40 cases do not recurse
148        // - 1/40 (Map) samples randomly between 0..8 children (-> expected 3.5 recursive calls)
149        // - 1/40 (Array) samples randomly between 0..=16 children (-> expected 8 recursive calls)
150        // => 11.5 / 40 expected recursive calls < 1
151        match disc {
152            20..=35 => StateValue::Cell(Sp::new(rng.r#gen())),
153            36..=37 => {
154                let mut mt: MerkleTree<(), D> = rng.r#gen();
155                // The `Serializable::unversioned_serialize` impl requires this.
156                while mt.height() == 0 || mt.height() > 32 {
157                    mt = rng.r#gen();
158                }
159                StateValue::BoundedMerkleTree(mt)
160            }
161            38 => StateValue::Map(rng.r#gen()),
162            39 => {
163                let len = rng.gen_range(0..=16);
164                let arr = (0..len).fold(Array::new(), |arr, _| arr.push(rng.r#gen()));
165                StateValue::Array(arr)
166            }
167            _ => StateValue::Null,
168        }
169    }
170}
171
/// Field-element encoding of a state value.
///
/// Each value starts with one header element. Its low 4 bits hold the variant tag:
/// `0` Null, `1` Cell, `2` Map, `3` Array, `4` BoundedMerkleTree.
///
/// The higher header bits carry size information:
/// - Map: the entry count, starting at bit 4;
/// - Array: the length, starting at bit 4;
/// - BoundedMerkleTree: the height starting at bit 4, and the entry count starting at bit 12.
///
/// The header is followed by the encodings of the payload / children.
impl<D: DB> FieldRepr for StateValue<D> {
    fn field_repr<W: MemWrite<Fr>>(&self, writer: &mut W) {
        use StateValue::*;
        match self {
            Null => writer.write(&[0.into()]),
            Cell(v) => {
                writer.write(&[1.into()]);
                v.field_repr(writer);
            }
            Map(m) => {
                writer.write(&[(2u128 | ((m.size() as u128) << 4)).into()]);
                // Entries are sorted so the encoding does not depend on the map's
                // iteration order.
                let mut sorted = m.iter().collect::<Vec<_>>();
                sorted.sort();
                for kv in sorted.into_iter() {
                    kv.0.field_repr(writer);
                    kv.1.field_repr(writer);
                }
            }
            Array(arr) => {
                writer.write(&[(3u64 | ((arr.len() as u64) << 4)).into()]);
                for elem in arr.iter() {
                    elem.field_repr(writer);
                }
            }
            BoundedMerkleTree(t) => {
                // Collected up front because the entry count is part of the header.
                let entries = t.iter().collect::<Vec<_>>();
                writer.write(&[(4u128
                    | ((t.height() as u128) << 4)
                    | ((entries.len() as u128) << 12))
                    .into()]);
                for entry in entries.into_iter() {
                    entry.field_repr(writer);
                }
            }
        }
    }

    /// Number of field elements `field_repr` writes for this value: one header
    /// element plus the sizes of the payload / children.
    fn field_size(&self) -> usize {
        use StateValue::*;
        match self {
            Null => 1,
            Cell(v) => 1 + v.field_size(),
            Map(m) => {
                1 + m
                    .iter()
                    .map(|kv| kv.0.field_size() + kv.1.field_size())
                    .sum::<usize>()
            }
            Array(arr) => 1 + arr.iter().map(|s| s.field_size()).sum::<usize>(),
            // NOTE(review): counts each Merkle-tree entry's index as exactly one field
            // element (the `1 +`); this must match how `field_repr` encodes entries — confirm.
            BoundedMerkleTree(t) => 1 + t.iter().map(|(_, v)| 1 + v.field_size()).sum::<usize>(),
        }
    }
}
225
226impl<D: DB> Serialize for StateValue<D> {
227    fn serialize<S: Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
228        match self {
229            StateValue::Null => {
230                let mut ser = ser.serialize_struct("StateValue", 1)?;
231                ser.serialize_field("tag", "null")?;
232                ser.end()
233            }
234            StateValue::Cell(val) => {
235                let mut ser = ser.serialize_struct("StateValue", 2)?;
236                ser.serialize_field("tag", "cell")?;
237                ser.serialize_field("content", &**val)?;
238                ser.end()
239            }
240            StateValue::Map(val) => {
241                let mut ser = ser.serialize_struct("StateValue", 2)?;
242                ser.serialize_field("tag", "map")?;
243                ser.serialize_field("content", val)?;
244                ser.end()
245            }
246            StateValue::Array(val) => {
247                let mut ser = ser.serialize_struct("StateValue", 2)?;
248                ser.serialize_field("tag", "array")?;
249                ser.serialize_field("content", val)?;
250                ser.end()
251            }
252            StateValue::BoundedMerkleTree(val) => {
253                let mut ser = ser.serialize_struct("StateValue", 2)?;
254                ser.serialize_field("tag", "boundedMerkleTree")?;
255                ser.serialize_field("content", val)?;
256                ser.end()
257            }
258        }
259    }
260}
261
262struct StateValueVisitor<D: DB>(PhantomData<D>);
263
264impl<'de, D: DB> Visitor<'de> for StateValueVisitor<D> {
265    type Value = StateValue<D>;
266    fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
267        write!(formatter, "a state value")
268    }
269
270    fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<StateValue<D>, V::Error> {
271        let tag: String = seq
272            .next_element()?
273            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
274        match &tag[..] {
275            "null" => Ok(StateValue::Null),
276            "cell" => Ok(StateValue::Cell(Sp::new(
277                seq.next_element()?
278                    .ok_or_else(|| de::Error::invalid_length(1, &self))?,
279            ))),
280            "map" => Ok(StateValue::Map(
281                seq.next_element()?
282                    .ok_or_else(|| de::Error::invalid_length(1, &self))?,
283            )),
284            "array" => Ok(StateValue::Array(
285                seq.next_element()?
286                    .ok_or_else(|| de::Error::invalid_length(1, &self))?,
287            )),
288            "boundedMerkleTree" => Ok(StateValue::BoundedMerkleTree(
289                seq.next_element()?
290                    .ok_or_else(|| de::Error::invalid_length(1, &self))?,
291            )),
292            tag => Err(de::Error::unknown_variant(
293                tag,
294                &["null", "cell", "map", "array", "boundedMerkleTree"],
295            )),
296        }
297    }
298
299    fn visit_map<V: MapAccess<'de>>(self, mut map: V) -> Result<StateValue<D>, V::Error> {
300        let first_key: String = map
301            .next_key()?
302            .ok_or_else(|| de::Error::missing_field("tag"))?;
303        match &first_key[..] {
304            "tag" => {
305                let tag: String = map.next_value()?;
306                fn get_content<'de2, V: MapAccess<'de2>, T: Deserialize<'de2>>(
307                    map: &mut V,
308                ) -> Result<T, V::Error> {
309                    let entry: (String, T) = map
310                        .next_entry()?
311                        .ok_or_else(|| de::Error::missing_field("content"))?;
312                    if &entry.0[..] == "content" {
313                        Ok(entry.1)
314                    } else {
315                        Err(de::Error::unknown_field(&entry.0[..], &["tag", "content"]))
316                    }
317                }
318                match &tag[..] {
319                    "null" => Ok(StateValue::Null),
320                    "cell" => Ok(StateValue::Cell(Sp::new(get_content(&mut map)?))),
321                    "map" => Ok(StateValue::Map(get_content(&mut map)?)),
322                    "array" => Ok(StateValue::Array(get_content(&mut map)?)),
323                    "boundedMerkleTree" => {
324                        Ok(StateValue::BoundedMerkleTree(get_content(&mut map)?))
325                    }
326                    tag => Err(de::Error::unknown_variant(
327                        tag,
328                        &["null", "cell", "map", "array", "boundedMerkleTree"],
329                    )),
330                }
331            }
332            "content" => Err(de::Error::custom(
333                "limitation of current deserialization: StateValue tag must preceed contents",
334            )),
335            field => Err(de::Error::unknown_field(field, &["tag", "content"])),
336        }
337    }
338}
339
340impl<'de, D1: DB> Deserialize<'de> for StateValue<D1> {
341    fn deserialize<D: Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
342        de.deserialize_struct(
343            "StateValue",
344            &["tag", "content"],
345            StateValueVisitor(PhantomData),
346        )
347    }
348}
349
/// Builds a [`StateValue`] from a compact literal syntax:
///
/// - `stval!(null)`: `StateValue::Null`;
/// - `stval!((expr))`: a `Cell` holding `expr.into()`;
/// - `stval!({MT(height) {key => value, ...}})`: a `BoundedMerkleTree` built from
///   `MerkleTree::blank(height)`, with `update_hash(key, value, ())` applied for each
///   entry and then `rehash()`ed;
/// - `stval!({key => val, ...})`: a `Map` with keys converted by `.into()`, where each `val`
///   is itself an `stval!` token tree;
/// - `stval!({key => val}; n)`: a `Map` whose keys are `AlignedValue::from(key + x as u32)`
///   for `x` in `0..n`, each mapped to `stval!(val)`;
/// - `stval!([val, ...])`: an `Array` of nested `stval!` values;
/// - `stval!([elem; n])`: an `Array` of `n` copies of `stval!(elem)`.
///
/// The `MT(...)` arm is listed before the generic `{key => val, ...}` map arm so that
/// it is tried first.
///
/// NOTE: the expansion names `StateValue`, `Sp`, `MerkleTree`, `HashMap`, `AlignedValue`
/// and `stval` without a `$crate::` prefix, so they must all be in scope at the call site.
#[macro_export]
macro_rules! stval {
    (null) => {
        StateValue::Null
    };
    (($val:expr_2021)) => {
        StateValue::Cell(Sp::new($val.into()))
    };
    ({MT($height:expr_2021) {$($key:expr_2021 => $val:expr_2021),*}}) => {
        StateValue::BoundedMerkleTree(MerkleTree::blank($height)$(.update_hash($key, $val, ()))*.rehash())
    };
    ({$($key:expr_2021 => $val:tt),*}) => {
        StateValue::Map(HashMap::new()$(.insert($key.into(), stval!($val)))*)
    };
    ({$key:expr_2021 => $val:tt}; $n:expr_2021) => {
        {
            StateValue::Map((0..$n).into_iter().map(|x|{
                (AlignedValue::from($key + x as u32), stval!($val))
            }).collect())
        }
    };
    ([$($val:tt),*]) => {
        StateValue::Array(vec![$(stval!($val)),*].into())
    };
    ([$elem:tt; $n:expr_2021]) => {
        StateValue::Array(vec![stval!($elem); $n].into())
    };
}

// Also make the macro importable by path from this module (in addition to the crate root,
// where `#[macro_export]` places it).
pub use stval;
380
381impl<D: DB> Debug for StateValue<D> {
382    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383        use StateValue::*;
384        match self {
385            Null => write!(formatter, "null"),
386            Cell(v) => write!(formatter, "{v:?}"),
387            Map(m) => {
388                write!(formatter, "Map ")?;
389                formatter
390                    .debug_map()
391                    .entries(m.iter().map(|kv| (kv.0.clone(), kv.1.clone())))
392                    .finish()
393            }
394            Array(arr) => {
395                write!(formatter, "Array({}) ", arr.len())?;
396                formatter.debug_list().entries(arr.iter()).finish()
397            }
398            BoundedMerkleTree(t) => {
399                write!(formatter, "MerkleTree({}) ", t.height())?;
400                formatter.debug_map().entries(t.iter()).finish()
401            }
402        }
403    }
404}
405
406impl<D: DB> StateValue<D> {
407    fn invariant(&self) -> std::io::Result<()> {
408        match self {
409            StateValue::Null | StateValue::Map(_) => {}
410            StateValue::Cell(v) => {
411                if (**v).serialized_size() > CELL_BOUND {
412                    return Err(std::io::Error::new(
413                        std::io::ErrorKind::InvalidData,
414                        format!("Cell exceeded maximum bound of {CELL_BOUND}"),
415                    ));
416                }
417            }
418            StateValue::Array(arr) => {
419                if arr.len() > 16 {
420                    return Err(std::io::Error::new(
421                        std::io::ErrorKind::InvalidData,
422                        "Array eceeded maximum length of 16",
423                    ));
424                }
425            }
426            StateValue::BoundedMerkleTree(bmt) => {
427                if bmt.height() > 32 {
428                    return Err(std::io::Error::new(
429                        std::io::ErrorKind::InvalidData,
430                        "BMT exceeded maximum height of 32",
431                    ));
432                }
433                if bmt.height() == 0 {
434                    return Err(std::io::Error::new(
435                        std::io::ErrorKind::InvalidData,
436                        "BMT has invalid height of 0",
437                    ));
438                }
439                if bmt.root().is_none() {
440                    return Err(std::io::Error::new(
441                        std::io::ErrorKind::InvalidData,
442                        "BMT must be rehashed",
443                    ));
444                }
445            }
446        }
447        Ok(())
448    }
449
450    pub fn log_size(&self) -> usize {
451        use StateValue::*;
452        match self {
453            Null => 0,
454            Cell(a) => {
455                // TODO: this is O(n), but probably needs to be O(1).
456                //
457                // Possible fixes: cache the size of the AlignedValue in the
458                // constructor. Not sure if this "size" necessarily needs to be the serialized
459                // size.
460                <AlignedValue as Serializable>::serialized_size(&**a)
461                    .next_power_of_two()
462                    .ilog2() as usize
463            }
464            Map(m) => (m.size() as u128).next_power_of_two().ilog2() as usize,
465            Array(a) => (a.len() as u128).next_power_of_two().ilog2() as usize,
466            BoundedMerkleTree(t) => t.height() as usize,
467        }
468    }
469}
470
471impl<D: DB> From<AlignedValue> for StateValue<D> {
472    fn from(val: AlignedValue) -> StateValue<D> {
473        StateValue::Cell(Sp::new(val))
474    }
475}
476
/// Writes `int` as a little-endian base-128 varint of one to three bytes.
///
/// Each byte carries 7 bits of the value. A set high bit (`0x80`) marks that
/// another byte follows.
///
/// # Errors
///
/// Fails with [`io::ErrorKind::InvalidData`], without writing anything, if `int`
/// does not fit in 21 bits (i.e. exceeds `0x1F_FFFF`). Otherwise it propagates any
/// error from `writer`.
pub fn write_int<W: Write>(writer: &mut W, int: u64) -> io::Result<()> {
    const MORE: u8 = 0x80;
    if int > 0x1F_FFFF {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "too many entries to serialize state value length!",
        ));
    }
    let low = (int & 0x7F) as u8;
    let mid = ((int >> 7) & 0x7F) as u8;
    let high = (int >> 14) as u8;
    if int < 0x80 {
        writer.write_all(&[low])
    } else if int < 0x4000 {
        writer.write_all(&[MORE | low, mid])
    } else {
        writer.write_all(&[MORE | low, MORE | mid, high])
    }
}
494
/// Number of bytes `write_int` uses to encode `int`.
///
/// Values too large for `write_int` (above `0x1F_FFFF`) report 4.
pub fn int_size(int: u64) -> usize {
    if int < 0x80 {
        1
    } else if int < 0x4000 {
        2
    } else if int < 0x20_0000 {
        3
    } else {
        4
    }
}
503
/// Reads a length written by `write_int`.
///
/// The format is a little-endian base-128 varint of at most three bytes, where a
/// set high bit (`0x80`) means another byte follows. Exactly as many bytes as the
/// encoding uses are consumed.
///
/// # Errors
///
/// Propagates read errors (e.g. [`io::ErrorKind::UnexpectedEof`] on truncated
/// input). Fails with [`io::ErrorKind::InvalidData`] if the third byte still has
/// its continuation bit set.
///
/// NOTE(review): overlong encodings are accepted. For example, `[0x80, 0x00]`
/// decodes to 0, just like `[0x00]`, so byte representations are not unique on
/// read. Rejecting them would change which inputs deserialize; confirm whether any
/// caller relies on byte-level canonicity before tightening this.
pub fn read_int<R: Read>(reader: &mut R) -> io::Result<u64> {
    let mut value = 0u64;
    for shift in [0u32, 7, 14] {
        let mut byte = [0u8; 1];
        reader.read_exact(&mut byte)?;
        let b = byte[0];
        if b & 0x80 == 0 {
            // Final byte: with its high bit clear it contributes its whole value.
            return Ok(value | (u64::from(b) << shift));
        }
        value |= u64::from(b & 0x7F) << shift;
    }
    Err(io::Error::new(
        io::ErrorKind::InvalidData,
        "reserved range for deserializing state value length",
    ))
}
524
/// How a byte buffer is presented in serde output and `Debug`: as text or as raw
/// bytes (chosen by `maybe_str`).
enum MaybeStr<'a> {
    /// The buffer viewed as a string; every byte is a permitted character.
    Str(&'a str),
    /// The buffer as-is.
    Bytes(&'a [u8]),
}
529
530struct MaybeStrVisitor<T>(PhantomData<T>);
531
532impl<'de, T: From<Vec<u8>>> serde::de::Visitor<'de> for MaybeStrVisitor<T> {
533    type Value = T;
534    fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
535        formatter.write_str("[byte]string")
536    }
537
538    fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
539        self.visit_string(v.to_owned())
540    }
541
542    fn visit_string<E: serde::de::Error>(self, v: String) -> Result<Self::Value, E> {
543        Ok(v.into_bytes().into())
544    }
545
546    fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
547        self.visit_byte_buf(v.to_vec())
548    }
549
550    fn visit_byte_buf<E: serde::de::Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
551        Ok(v.into())
552    }
553    // Required for serde_json compatibility. See
554    // https://github.com/serde-rs/json/pull/557
555    fn visit_seq<A: serde::de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
556        let mut res = Vec::new();
557        while let Some(byte) = seq.next_element()? {
558            res.push(byte);
559        }
560        Ok(res.into())
561    }
562}
563
564fn maybe_str(buf: &[u8]) -> MaybeStr<'_> {
565    // For alphanumeric characters, as well as the following: '+-_":/\?#$%^*&.
566    // we will use a string as-is. For others, we will use byte enocding.
567    // This is to permit arbitrary bytes, while presenting strings to users
568    // where sensible.
569    fn permitted(c: u8) -> bool {
570        c.is_ascii_alphanumeric() || b"'+-_\":/\\?#$^*&.".contains(&c)
571    }
572    if buf.iter().copied().all(permitted)
573        && let Ok(s) = std::str::from_utf8(buf)
574    {
575        return MaybeStr::Str(s);
576    }
577    MaybeStr::Bytes(buf)
578}
579
// Defines a pair of byte-identifier types:
// - `$refty<'a>`, an alias for a borrowed `&'a [u8]`;
// - `$bufty`, an owned newtype around `Vec<u8>`, with the trait impls below.
//
// Serde serialization and `Debug` show the bytes as a plain string when `maybe_str` allows
// it. Otherwise serde emits raw bytes and `Debug` prints them as hex. Deserialization
// accepts a string, a byte string, or a sequence of bytes.
macro_rules! idty {
    ($refty:ident, $bufty:ident) => {
        pub type $refty<'a> = &'a [u8];

        #[derive(
            FieldRepr, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Serializable, Dummy, Storable,
        )]
        #[storable(base)]
        #[cfg_attr(feature = "proptest", derive(Arbitrary))]
        pub struct $bufty(pub Vec<u8>);

        // Human-readable serde form: a string if `maybe_str` permits, raw bytes otherwise.
        impl Serialize for $bufty {
            fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
                match maybe_str(&self.0) {
                    MaybeStr::Str(s) => serializer.serialize_str(s),
                    MaybeStr::Bytes(b) => serializer.serialize_bytes(b),
                }
            }
        }

        impl From<Vec<u8>> for $bufty {
            fn from(vec: Vec<u8>) -> $bufty {
                $bufty(vec)
            }
        }

        // `deserialize_any` lets the input choose string, bytes or sequence form.
        impl<'de> Deserialize<'de> for $bufty {
            fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
                deserializer.deserialize_any(MaybeStrVisitor(PhantomData))
            }
        }

        // Prints the string form when permitted, otherwise the bytes as hex.
        impl Debug for $bufty {
            fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
                match maybe_str(&self.0) {
                    MaybeStr::Str(s) => formatter.write_str(s),
                    MaybeStr::Bytes(b) => formatter.write_str(&b.encode_hex::<String>()),
                }
            }
        }

        impl Deref for $bufty {
            type Target = [u8];
            fn deref(&self) -> &[u8] {
                &self.0
            }
        }

        impl Borrow<[u8]> for $bufty {
            fn borrow(&self) -> &[u8] {
                &self.0
            }
        }

        impl From<&[u8]> for $bufty {
            fn from(e: &[u8]) -> $bufty {
                $bufty(e.to_owned())
            }
        }
    };
}
641
// Contract entry-point names: `EntryPoint<'a>` borrows the bytes, `EntryPointBuf` owns them.
idty!(EntryPoint, EntryPointBuf);
#[cfg(feature = "proptest")]
randomised_serialization_test!(EntryPointBuf);
645
646impl Distribution<EntryPointBuf> for Standard {
647    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> EntryPointBuf {
648        let length = rng.gen_range(0..10);
649        EntryPointBuf(
650            vec![0; length]
651                .iter()
652                .map(|_| rng.r#gen::<u8>())
653                .collect::<Vec<u8>>()
654                .to_owned(),
655        )
656    }
657}
658
659impl Tagged for EntryPointBuf {
660    fn tag() -> std::borrow::Cow<'static, str> {
661        "entry-point".into()
662    }
663    fn tag_unique_factor() -> String {
664        "vec(u8)".into()
665    }
666}
667tag_enforcement_test!(EntryPointBuf);
668
669impl EntryPointBuf {
670    pub fn ep_hash(&self) -> HashOutput {
671        persistent_commit(
672            &self[..],
673            HashOutput(*b"midnight:entry-point\0\0\0\0\0\0\0\0\0\0\0\0"),
674        )
675    }
676}
677
678impl Aligned for EntryPointBuf {
679    fn alignment() -> Alignment {
680        Alignment::singleton(AlignmentAtom::Compress)
681    }
682}
683
/// The committee configuration governing maintenance of a contract.
///
/// NOTE(review): presumably maintenance updates must be signed by at least `threshold`
/// members of `committee`, with `counter` guarding against replays — confirm against
/// the code that verifies maintenance updates.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serializable,
    Storable,
    Serialize,
    Deserialize,
)]
#[storable(base)]
#[tag = "contract-maintenance-authority[v1]"]
#[cfg_attr(feature = "proptest", derive(Arbitrary))]
pub struct ContractMaintenanceAuthority {
    /// Verifying keys of the committee members.
    pub committee: Vec<VerifyingKey>,
    /// NOTE(review): presumably the number of committee signatures required — confirm.
    pub threshold: u32,
    /// NOTE(review): presumably a counter advanced on each maintenance update — confirm.
    pub counter: u32,
}
tag_enforcement_test!(ContractMaintenanceAuthority);
706
707impl ContractMaintenanceAuthority {
708    pub fn new() -> Self {
709        ContractMaintenanceAuthority {
710            committee: vec![],
711            threshold: 1,
712            counter: 0,
713        }
714    }
715}
716
717impl Default for ContractMaintenanceAuthority {
718    fn default() -> Self {
719        Self::new()
720    }
721}
722
/// The state associated with a single contract.
#[derive(Storable)]
#[derive_where(Clone, PartialEq, Eq)]
#[storable(db = D)]
#[cfg_attr(feature = "proptest", derive(Arbitrary))]
#[tag = "contract-state[v6]"]
pub struct ContractState<D: DB> {
    /// The contract's state value, together with its storage-charging bookkeeping.
    pub data: ChargedState<D>,
    /// The contract's operations, keyed by entry-point name.
    pub operations: HashMap<EntryPointBuf, ContractOperation, D>,
    /// The committee configuration governing maintenance of this contract.
    pub maintenance_authority: ContractMaintenanceAuthority,
    /// Amount per token type; `ContractState::new` starts it empty.
    pub balance: HashMap<TokenType, u128, D>,
}
tag_enforcement_test!(ContractState<InMemoryDB>);
735
/// A state value paired with `charged_keys`, the bookkeeping of which storage keys
/// have already been charged for. `ChargedState::update` uses this bookkeeping to
/// cost state changes incrementally.
#[derive(Storable)]
#[derive_where(Clone, PartialEq, Eq)]
#[storable(db = D)]
#[cfg_attr(feature = "proptest", derive(Arbitrary))]
#[tag = "charged-state[v1]"]
pub struct ChargedState<D: DB> {
    // Each field is declared twice: fully `pub` with the `public-internal-structure`
    // feature, `pub(crate)` otherwise.
    #[cfg(feature = "public-internal-structure")]
    pub state: Sp<StateValue<D>, D>,
    #[cfg(not(feature = "public-internal-structure"))]
    pub(crate) state: Sp<StateValue<D>, D>,
    // TODO: it would be better to generate charged keys from `state`, since it's
    // an invariant that the chargeable contract state is always a subset of the
    // `charged_keys`. I assume this implies a manual `Arbitrary`
    // implementation, but maybe there is some `proptest` magic that supports
    // deriving this ... For now, proptest always generates empty `charged_keys`.
    #[cfg(feature = "public-internal-structure")]
    #[cfg_attr(feature = "proptest", proptest(value = "RcMap::default()"))]
    pub charged_keys: RcMap<D>,
    #[cfg_attr(feature = "proptest", proptest(value = "RcMap::default()"))]
    #[cfg(not(feature = "public-internal-structure"))]
    pub(crate) charged_keys: RcMap<D>,
}
tag_enforcement_test!(ChargedState<InMemoryDB>);
759
impl<D: DB> Debug for ChargedState<D> {
    // Formats only the wrapped state; the `charged_keys` bookkeeping is omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        self.state.fmt(f)
    }
}
765
766impl<D: DB> ChargedState<D> {
767    /// Creates a new charged state from a given state value. This assumes that
768    /// this state's storage is paid for elsewhere, and therefore the resulting
769    /// `ChargedState` *is* accounted for in its storage usage.
770    ///
771    /// Specifically, for contract deployments, the happens with a manual
772    /// `tree_copy` costing of `ContractDeploy` operations.
773    pub fn new(state: StateValue<D>) -> Self {
774        let state = Sp::new(state);
775        let charged_keys =
776            initial_write_delete_costs(&[state.as_child()].into_iter().collect(), |_, _| {
777                Default::default()
778            })
779            .updated_charged_keys;
780        ChargedState {
781            state,
782            charged_keys,
783        }
784    }
785
786    pub fn get(&self) -> Sp<StateValue<D>, D> {
787        self.state.clone()
788    }
789
790    pub fn get_ref(&self) -> &StateValue<D> {
791        &self.state
792    }
793
794    pub fn update(
795        &self,
796        new_state: StateValue<D>,
797        cpu_cost: impl Fn(u64, u64) -> RunningCost,
798        gc_limit: impl FnOnce(RunningCost) -> usize,
799    ) -> (Self, RunningCost) {
800        // WARNING: Need to be sure the old and new StateValue state is in the
801        // backend before doing calcs over their keys. The old state is already
802        // in the backend, because contract states get persisted after contract
803        // calls. But the new state we're working with now has not been
804        // persisted yet, indeed, it may never be, e.g. if we run out of gas
805        // when we cost its writes+deletes.
806        //
807        // This sp creation here should be cheap, since we quickly run into sps
808        // under the covers. However, the top level of the StateValue state is
809        // *not* an sp. Another solution would be to require the top-level of
810        // the StateValue state itself be an sp, e.g. by wrapping all fields of
811        // the ContractState in sps.
812        let new_state = Sp::new(new_state);
813        let results = incremental_write_delete_costs(
814            &self.charged_keys,
815            &[new_state.as_child()].into_iter().collect(),
816            cpu_cost,
817            gc_limit,
818        );
819        let cost = results.running_cost();
820        let state = ChargedState {
821            state: new_state,
822            charged_keys: results.updated_charged_keys,
823        };
824        (state, cost)
825    }
826}
827
828impl<D: DB> ContractState<D> {
829    pub fn new(
830        data: StateValue<D>,
831        operations: HashMap<EntryPointBuf, ContractOperation, D>,
832        maintenance_authority: ContractMaintenanceAuthority,
833    ) -> Self {
834        ContractState {
835            data: ChargedState::new(data),
836            operations,
837            maintenance_authority,
838            balance: HashMap::default(),
839        }
840    }
841}
842
843impl<D: DB> Debug for ContractState<D> {
844    fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
845        write!(formatter, "ContractState (")?;
846        self.data.state.fmt(formatter)?;
847        self.operations.fmt(formatter)?;
848        write!(formatter, "ContractState )")?;
849        Ok(())
850    }
851}
852
853impl<D: DB> Default for ContractState<D> {
854    fn default() -> Self {
855        Self::new(
856            StateValue::Null,
857            HashMap::new(),
858            ContractMaintenanceAuthority::default(),
859        )
860    }
861}
862
/// A single operation of a contract, as stored in `ContractState::operations`
/// keyed by `EntryPointBuf`.
///
/// Carries at most one verifier key, in the `v2` slot, which
/// `ContractOperation::latest` / `latest_mut` expose.
/// NOTE(review): `None` presumably means no verifier key has been installed
/// for this entry point — confirm against callers.
#[derive(
    Serializable, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Storable,
)]
#[storable(base)]
// Serialization tag for this type; presumably pinned by the
// `tag_enforcement_test!` below, so changing it is a format change.
#[tag = "contract-operation[v4]"]
#[non_exhaustive]
pub struct ContractOperation {
    /// The verifier key for this operation, if any.
    pub v2: Option<VerifierKey>,
}
tag_enforcement_test!(ContractOperation);
873
874impl ContractOperation {
875    pub fn new(vk: Option<VerifierKey>) -> Self {
876        ContractOperation { v2: vk }
877    }
878
879    pub fn latest(&self) -> Option<&VerifierKey> {
880        self.v2.as_ref()
881    }
882
883    pub fn latest_mut(&mut self) -> &mut Option<VerifierKey> {
884        &mut self.v2
885    }
886}
887
// Property-testing support, compiled only with the `proptest` feature:
// `simple_arbitrary!` provides a proptest `Arbitrary` for `ContractOperation`,
// and `randomised_serialization_test!` adds a randomized serialization test.
// NOTE(review): both macros live in the `serialize` crate; the generation
// strategy (presumably the `Standard` distribution impl) and the exact test
// performed (presumably a round-trip) should be confirmed there.
#[cfg(feature = "proptest")]
simple_arbitrary!(ContractOperation);
#[cfg(feature = "proptest")]
randomised_serialization_test!(ContractOperation);
892
893impl Distribution<ContractOperation> for Standard {
894    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> ContractOperation {
895        let some: bool = rng.r#gen();
896        if some {
897            ContractOperation {
898                v2: Some(rng.r#gen()),
899            }
900        } else {
901            ContractOperation { v2: None }
902        }
903    }
904}
905
906impl FieldRepr for ContractOperation {
907    fn field_repr<W: MemWrite<Fr>>(&self, writer: &mut W) {
908        match self.v2 {
909            Some(ref vk) => {
910                writer.write(&[0x01.into()]);
911                let mut bytes: Vec<u8> = Vec::new();
912                <VerifierKey as Serializable>::serialize(vk, &mut bytes)
913                    .expect("VerifierKey is serializable");
914                bytes.field_repr(writer);
915            }
916            None => writer.write(&[0x00.into()]),
917        }
918    }
919
920    fn field_size(&self) -> usize {
921        match self.v2 {
922            Some(ref vk) => {
923                let mut bytes: Vec<u8> = Vec::new();
924                <VerifierKey as Serializable>::serialize(vk, &mut bytes)
925                    .expect("VerifierKey is serializable");
926                1 + bytes.into_iter().fold(0, |acc, b| acc + b.field_size())
927            }
928            None => 1,
929        }
930    }
931}
932
933impl Debug for ContractOperation {
934    fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
935        write!(formatter, "<verifier key>")
936    }
937}
938
939impl<F> Dummy<F> for ContractOperation {
940    fn dummy_with_rng<R: rand::Rng + ?Sized>(_config: &F, _rng: &mut R) -> Self {
941        ContractOperation { v2: None }
942    }
943}
944
#[cfg(test)]
mod tests {
    use storage::db::InMemoryDB;

    use super::*;

    /// Round-trips `x` through `write_int` / `read_int`, checking that the
    /// value survives and that `read_int` consumes exactly the bytes written.
    fn test_compact_int(x: u64) {
        let mut bytes = Vec::new();
        write_int(&mut bytes, x).unwrap();
        let ptr = &mut &bytes[..];
        let y = read_int(ptr).unwrap();
        assert_eq!(x, y);
        assert!(ptr.is_empty());
    }

    #[test]
    fn test_ints() {
        test_compact_int(0x0);
        test_compact_int(0x1);
        test_compact_int(0x42);
        test_compact_int(0x80);
        test_compact_int(0xff);
        test_compact_int(0x100);
        test_compact_int(0x1000);
        test_compact_int(0x10000);
    }

    /// Builds a 12,000-deep chain of single-element `StateValue::Array`s and
    /// drops it.
    ///
    /// NOTE(review): presumably guards against stack overflow from recursive
    /// drop of deeply nested values — confirm against `StateValue`'s drop path.
    #[test]
    fn test_nested_drop() {
        let mut sv: StateValue<InMemoryDB> = StateValue::Null;
        for _ in 0..12_000 {
            sv = StateValue::Array(vec![sv].into());
        }
        drop(sv);
    }

    /// Serializes `val`, checks the byte count against `serialized_size`, then
    /// deserializes and checks that all input was consumed and the value matches.
    fn test_ser<T: Serializable + Deserializable + Eq + Debug>(val: T) {
        dbg!(&val);
        let mut bytes = Vec::new();
        T::serialize(&val, &mut bytes).unwrap();
        assert_eq!(bytes.len(), T::serialized_size(&val));
        let mut b = bytes.as_slice();
        let copy = T::deserialize(&mut b, 0).unwrap();
        assert!(b.is_empty(), "deserialization left trailing bytes");
        assert_eq!(val, copy);
    }

    #[test]
    fn test_state_ser() {
        let cs = ChargedState::<InMemoryDB>::new(StateValue::Null);
        test_ser::<RcMap<InMemoryDB>>(cs.charged_keys);
        test_ser::<ContractState<InMemoryDB>>(ContractState::default());
        test_ser::<StateValue<InMemoryDB>>(stval!((512u64)));
        test_ser::<StateValue<InMemoryDB>>(stval!({ 512u64 => (12u64) }));
        test_ser::<StateValue<InMemoryDB>>(stval!([(512u64)]));
        test_ser::<StateValue<InMemoryDB>>(stval!(null));
        test_ser::<StateValue<InMemoryDB>>(stval!({MT(12) {}}));
    }

    #[test]
    fn test_log_size() {
        use transient_crypto::merkle_tree::MerkleTree;

        // Like stval, but force database param to be InMemoryDB
        macro_rules! s {
            ($($tt:tt)*) => {
                {
                    let sv: StateValue<InMemoryDB> = stval!($($tt)*);
                    sv
                }
            };
        }

        assert_eq!(s!(null).log_size(), 0);

        assert_eq!(s!((0u8)).log_size(), 1);
        assert_eq!(s!((0u16)).log_size(), 1);
        assert_eq!(s!((0u32)).log_size(), 1);
        assert_eq!(s!((0u64)).log_size(), 1);

        assert_eq!(s!({}).log_size(), 0);
        assert_eq!(s!({ 0u32 => (1u32) }; 3).log_size(), 2);
        assert_eq!(s!({ 0u32 => (1u32) }; 4).log_size(), 2);
        assert_eq!(s!({ 0u32 => (1u32) }; 5).log_size(), 3);
        assert_eq!(s!({ 0u32 => (1u32) }; 7).log_size(), 3);
        assert_eq!(s!({ 0u32 => (1u32) }; 8).log_size(), 3);
        assert_eq!(s!({ 0u32 => (1u32) }; 9).log_size(), 4);

        assert_eq!(s!([]).log_size(), 0);
        assert_eq!(s!([(1u32); 3]).log_size(), 2);
        assert_eq!(s!([(1u32); 4]).log_size(), 2);
        assert_eq!(s!([(1u32); 7]).log_size(), 3);
        assert_eq!(s!([(1u32); 8]).log_size(), 3);
        assert_eq!(s!([(1u32); 15]).log_size(), 4);
        assert_eq!(s!([(1u32); 16]).log_size(), 4);

        for h in 0..16 {
            assert_eq!(s!({MT(h) {}}).log_size(), h as usize);
        }
    }
}