Skip to main content

midnight_onchain_runtime/
context.rs

1// This file is part of midnight-ledger.
2// Copyright (C) 2025 Midnight Foundation
3// SPDX-License-Identifier: Apache-2.0
4// Licensed under the Apache License, Version 2.0 (the "License");
5// You may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::cost_model::CostModel;
15use crate::error::TranscriptRejected;
16use crate::ops::Op;
17use crate::result_mode::ResultMode;
18use crate::state::StateValue;
19use crate::transcript::Transcript;
20use crate::vm::run_program;
21use crate::vm_value::{ValueStrength, VmValue};
22use base_crypto::cost_model::RunningCost;
23use base_crypto::fab::{Aligned, Alignment};
24use base_crypto::fab::{InvalidBuiltinDecode, Value, ValueSlice};
25use base_crypto::hash::HashOutput;
26use base_crypto::time::Timestamp;
27use coin_structure::coin::PublicAddress;
28use coin_structure::coin::UserAddress;
29use coin_structure::coin::{
30    Commitment as CoinCommitment, Info as CoinInfo, Nullifier, QualifiedInfo as QualifiedCoinInfo,
31    TokenType,
32};
33use coin_structure::coin::{ShieldedTokenType, UnshieldedTokenType};
34use coin_structure::contract::ContractAddress;
35use coin_structure::transfer::Recipient;
36use derive_where::derive_where;
37use hex::FromHexError;
38use hex::{FromHex, ToHex};
39use onchain_runtime_state::state::ChargedState;
40use onchain_vm::error::OnchainProgramError;
41use onchain_vm::result_mode::ResultModeVerify;
42#[cfg(feature = "proptest")]
43use proptest_derive::Arbitrary;
44use rand::Rng;
45use rand::distributions::Standard;
46use rand::prelude::Distribution;
47use serde::{
48    de::{Deserialize, Deserializer},
49    ser::{Serialize, Serializer},
50};
51use serialize::{self, Deserializable, Serializable, Tagged, tag_enforcement_test};
52use std::collections::{HashMap, HashSet};
53use std::hash::Hash;
54use std::ops::Deref;
55use storage::arena::{ArenaKey, Sp};
56use storage::db::DB;
57use storage::storage::Map;
58use storage::{Storable, storable::Loader};
59use transient_crypto::curve::Fr;
60
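// Serde mirror types: `BlockContext`, `CallContext` and `Effects` are converted to and from the
// `Serde*` structs below (hashes and addresses hex-encoded) whenever they go through serde.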
62
/// Serde mirror of `BlockContext`, used through its `#[serde(try_from, into)]` attributes.
///
/// Field names are serialized in camelCase; timestamps are stored as seconds and the parent
/// block hash as a hex string.
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct SerdeBlockContext {
    /// `BlockContext::tblock`, via `Timestamp::to_secs`.
    seconds_since_epoch: u64,
    /// `BlockContext::tblock_err`, copied unchanged.
    seconds_since_epoch_err: u32,
    /// Hex encoding of `BlockContext::parent_block_hash`.
    parent_block_hash: String,
    /// `BlockContext::last_block_time`, via `Timestamp::to_secs`.
    last_block_time: u64,
}
71
/// Serde mirror of `CallContext`, produced and consumed by its `Serialize`/`Deserialize` impls.
///
/// Field names are serialized in camelCase. Addresses, hashes and commitments are hex strings,
/// timestamps are seconds, and token types / public addresses use their tagged serde mirrors.
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct SerdeCallContext {
    /// Hex of the contract address's binary `Serializable` encoding.
    own_address: String,
    /// `CallContext::tblock`, via `Timestamp::to_secs`.
    seconds_since_epoch: u64,
    /// `CallContext::tblock_err`, copied unchanged.
    seconds_since_epoch_err: u32,
    /// Hex encoding of `CallContext::parent_block_hash`.
    parent_block_hash: String,
    /// Tagged, hex-encoded caller address, if any.
    caller: Option<SerdePublicAddress>,
    /// Amount per token type.
    // NOTE(review): keyed by a struct; formats such as JSON only accept string map keys, so
    // this presumably targets a format with structured keys — confirm.
    balance: HashMap<SerdeTokenType, u128>,
    /// `CallContext::com_indices`, keyed by hex-encoded coin commitment.
    com_indices: HashMap<String, u64>,
    /// `CallContext::last_block_time`, via `Timestamp::to_secs`.
    last_block_time: u64,
}
84
/// Serde mirror of [`TokenType`].
///
/// A string `tag` names the variant. For shielded and unshielded tokens, `raw` carries the
/// hex-encoded token hash.
#[derive(serde::Serialize, serde::Deserialize, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "camelCase")]
struct SerdeTokenType {
    /// One of `SERDE_UNSHIELDED_TAG`, `SERDE_SHIELDED_TAG` or `SERDE_DUST_TAG`.
    tag: String,
    /// Hex-encoded hash: `Some` for the shielded/unshielded tags, `None` for dust.
    raw: Option<String>,
}

/// `SerdeTokenType::tag` value for `TokenType::Unshielded`.
const SERDE_UNSHIELDED_TAG: &str = "unshielded";
/// `SerdeTokenType::tag` value for `TokenType::Shielded`.
const SERDE_SHIELDED_TAG: &str = "shielded";
/// `SerdeTokenType::tag` value for `TokenType::Dust`.
const SERDE_DUST_TAG: &str = "dust";
95
/// Serde mirror of [`PublicAddress`].
///
/// A string `tag` names the variant. `address` is the hex encoding of that variant's binary
/// `Serializable` form.
#[derive(serde::Serialize, serde::Deserialize, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "camelCase")]
struct SerdePublicAddress {
    /// Either `SERDE_CONTRACT_TAG` or `SERDE_USER_TAG`.
    tag: String,
    /// Hex-encoded binary address.
    address: String,
}

/// `SerdePublicAddress::tag` value for `PublicAddress::Contract`.
const SERDE_CONTRACT_TAG: &str = "contract";
/// `SerdePublicAddress::tag` value for `PublicAddress::User`.
const SERDE_USER_TAG: &str = "user";
105
106fn hex_from_tt(tt: TokenType) -> SerdeTokenType {
107    let (variant, val) = match tt {
108        TokenType::Unshielded(unshielded_token_type) => (
109            SERDE_UNSHIELDED_TAG.to_string(),
110            Some(unshielded_token_type.0),
111        ),
112        TokenType::Shielded(shielded_token_type) => {
113            (SERDE_SHIELDED_TAG.to_string(), Some(shielded_token_type.0))
114        }
115        TokenType::Dust => (SERDE_DUST_TAG.to_string(), None),
116    };
117
118    SerdeTokenType {
119        tag: variant,
120        raw: val.map(|v| v.0.encode_hex()),
121    }
122}
123
124fn tt_from_hex(serde_token_type: SerdeTokenType) -> Result<TokenType, std::io::Error> {
125    let hash_output = serde_token_type
126        .raw
127        .map(|raw| Ok::<_, std::io::Error>(HashOutput(FromHex::from_hex(raw).map_err(err_conv)?)))
128        .transpose()?;
129
130    match (serde_token_type.tag.as_str(), hash_output) {
131        (SERDE_UNSHIELDED_TAG, Some(hash_output)) => {
132            Ok(TokenType::Unshielded(UnshieldedTokenType(hash_output)))
133        }
134        (SERDE_SHIELDED_TAG, Some(hash_output)) => {
135            Ok(TokenType::Shielded(ShieldedTokenType(hash_output)))
136        }
137        (SERDE_SHIELDED_TAG, None) | (SERDE_UNSHIELDED_TAG, None) => Err(std::io::Error::new(
138            std::io::ErrorKind::InvalidData,
139            format!(
140                "expected raw data with tag {}, but got none",
141                serde_token_type.tag
142            ),
143        )),
144        (SERDE_DUST_TAG, Some(_)) => Err(std::io::Error::new(
145            std::io::ErrorKind::InvalidData,
146            format!(
147                "expected no raw data with tag {}, but got some",
148                serde_token_type.tag
149            ),
150        )),
151        (SERDE_DUST_TAG, None) => Ok(TokenType::Dust),
152        _ => Err(std::io::Error::new(
153            std::io::ErrorKind::InvalidData,
154            format!(
155                "Incorrect discriminant, expected one of \"unshielded\", \"shielded\", or \"dust\"; got {}",
156                serde_token_type.tag
157            ),
158        ))?,
159    }
160}
161
162fn public_address_from_hex(
163    serde_public_address: SerdePublicAddress,
164) -> Result<PublicAddress, std::io::Error> {
165    let mut address_bytes =
166        &mut &<Vec<u8>>::from_hex(serde_public_address.address.as_bytes()).map_err(err_conv)?[..];
167
168    Ok(match serde_public_address.tag.as_str() {
169        SERDE_CONTRACT_TAG => {
170            let addr = <ContractAddress as Deserializable>::deserialize(&mut address_bytes, 0)?;
171            ensure_fully_deserialized(address_bytes)?;
172            PublicAddress::Contract(addr)
173        }
174        SERDE_USER_TAG => {
175            let addr = <UserAddress as Deserializable>::deserialize(&mut address_bytes, 0)?;
176            ensure_fully_deserialized(address_bytes)?;
177            PublicAddress::User(addr)
178        }
179        _ => Err(std::io::Error::new(
180            std::io::ErrorKind::InvalidData,
181            format!(
182                "Incorrect discriminant, expected \"contract\" or \"user\", got {}",
183                serde_public_address.tag
184            ),
185        ))?,
186    })
187}
188
189fn public_address_to_hex(public_address: PublicAddress) -> SerdePublicAddress {
190    let mut addr_vec = Vec::new();
191
192    let variant = match public_address {
193        PublicAddress::Contract(contract_address) => {
194            <ContractAddress as Serializable>::serialize(&contract_address, &mut addr_vec)
195                .expect("In-memory serialization should succeed");
196            SERDE_CONTRACT_TAG
197        }
198        PublicAddress::User(user_address) => {
199            <UserAddress as Serializable>::serialize(&user_address, &mut addr_vec)
200                .expect("In-memory serialization should succeed");
201            SERDE_USER_TAG
202        }
203    };
204
205    SerdePublicAddress {
206        tag: variant.to_string(),
207        address: addr_vec.encode_hex(),
208    }
209}
210
211fn err_conv(err: FromHexError) -> std::io::Error {
212    std::io::Error::new(std::io::ErrorKind::InvalidData, err.to_string())
213}
214
215impl From<BlockContext> for SerdeBlockContext {
216    fn from(ctxt: BlockContext) -> SerdeBlockContext {
217        SerdeBlockContext {
218            seconds_since_epoch: ctxt.tblock.to_secs(),
219            seconds_since_epoch_err: ctxt.tblock_err,
220            parent_block_hash: ctxt.parent_block_hash.0.encode_hex(),
221            last_block_time: ctxt.last_block_time.to_secs(),
222        }
223    }
224}
225
226impl TryFrom<SerdeBlockContext> for BlockContext {
227    type Error = std::io::Error;
228
229    fn try_from(ctxt: SerdeBlockContext) -> Result<BlockContext, std::io::Error> {
230        let hash =
231            <[u8; base_crypto::hash::PERSISTENT_HASH_BYTES]>::from_hex(ctxt.parent_block_hash)
232                .map_err(err_conv)?;
233        Ok(BlockContext {
234            tblock: Timestamp::from_secs(ctxt.seconds_since_epoch),
235            tblock_err: ctxt.seconds_since_epoch_err,
236            parent_block_hash: HashOutput(hash),
237            last_block_time: Timestamp::from_secs(ctxt.last_block_time),
238        })
239    }
240}
241
242impl<D: DB> From<CallContext<D>> for SerdeCallContext {
243    fn from(ctxt: CallContext<D>) -> SerdeCallContext {
244        let mut own_address_vec = Vec::new();
245        <ContractAddress as Serializable>::serialize(&ctxt.own_address, &mut own_address_vec)
246            .expect("In-memory serialization should succeed");
247        SerdeCallContext {
248            own_address: own_address_vec.encode_hex(),
249            seconds_since_epoch: ctxt.tblock.to_secs(),
250            seconds_since_epoch_err: ctxt.tblock_err,
251            parent_block_hash: ctxt.parent_block_hash.0.encode_hex(),
252            caller: ctxt.caller.map(public_address_to_hex),
253            balance: ctxt
254                .balance
255                .iter()
256                .map(|tt_x_val| (hex_from_tt(*tt_x_val.0.deref()), *tt_x_val.1.deref()))
257                .collect(),
258            com_indices: ctxt
259                .com_indices
260                .iter()
261                .map(|(com, val)| (com.0.0.encode_hex(), *val))
262                .collect(),
263            last_block_time: ctxt.last_block_time.to_secs(),
264        }
265    }
266}
267
268impl<D: DB> TryFrom<SerdeCallContext> for CallContext<D> {
269    type Error = std::io::Error;
270
271    fn try_from(ctxt: SerdeCallContext) -> Result<CallContext<D>, std::io::Error> {
272        let block_hash =
273            <[u8; base_crypto::hash::PERSISTENT_HASH_BYTES]>::from_hex(ctxt.parent_block_hash)
274                .map_err(err_conv)?;
275        let mut own_address_bytes =
276            &mut &<Vec<u8>>::from_hex(ctxt.own_address.as_bytes()).map_err(err_conv)?[..];
277        let own_address =
278            <ContractAddress as Deserializable>::deserialize(&mut own_address_bytes, 0)?;
279        ensure_fully_deserialized(own_address_bytes)?;
280
281        let caller = ctxt.caller.map(public_address_from_hex).transpose()?;
282        Ok(CallContext {
283            own_address,
284            tblock: Timestamp::from_secs(ctxt.seconds_since_epoch),
285            tblock_err: ctxt.seconds_since_epoch_err,
286            parent_block_hash: HashOutput(block_hash),
287            caller,
288            balance: ctxt
289                .balance
290                .into_iter()
291                .map(|(tt, val)| Ok::<_, std::io::Error>((tt_from_hex(tt)?, val)))
292                .collect::<Result<storage::storage::HashMap<TokenType, u128, D>, _>>()?,
293            com_indices: ctxt
294                .com_indices
295                .into_iter()
296                .map(|(com, val)| {
297                    Ok::<_, std::io::Error>((
298                        CoinCommitment(HashOutput(FromHex::from_hex(com).map_err(err_conv)?)),
299                        val,
300                    ))
301                })
302                .collect::<Result<Map<CoinCommitment, u64>, _>>()?,
303            last_block_time: Timestamp::from_secs(ctxt.last_block_time),
304        })
305    }
306}
307
/// Per-call context.
///
/// Holds the called contract's address, block timing information, the optional caller, token
/// balances and coin-commitment indices.
///
/// With serde it is (de)serialized through the private `SerdeCallContext` mirror. Timestamps
/// are stored as seconds; hashes and addresses as hex.
#[derive_where(Clone, Debug, Default)]
pub struct CallContext<D: DB> {
    /// Address of the contract this context belongs to (presumably the callee — confirm).
    pub own_address: ContractAddress,
    /// Block timestamp; serialized via serde as seconds.
    pub tblock: Timestamp,
    /// Serialized via serde as `secondsSinceEpochErr`; presumably the allowed error (in
    /// seconds) of `tblock` — confirm.
    pub tblock_err: u32,
    /// Hash of the parent block.
    pub parent_block_hash: HashOutput,
    /// The caller's public address, if one is set.
    pub caller: Option<PublicAddress>,
    /// Amount held per token type.
    pub balance: storage::storage::HashMap<TokenType, u128, D>,
    /// `u64` value associated with each coin commitment (presumably its index — confirm).
    pub com_indices: Map<CoinCommitment, u64>,
    /// Serialized via serde as seconds.
    pub last_block_time: Timestamp,
}
319
320impl<D: DB> Serialize for CallContext<D> {
321    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
322    where
323        S: Serializer,
324    {
325        let ser_effects: SerdeCallContext = self.clone().into();
326        <SerdeCallContext as Serialize>::serialize(&ser_effects, serializer)
327    }
328}
329
330impl<'de, DD: DB> Deserialize<'de> for CallContext<DD> {
331    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
332    where
333        D: Deserializer<'de>,
334    {
335        let ser_effects = <SerdeCallContext as Deserialize>::deserialize(deserializer)?;
336        CallContext::try_from(ser_effects).map_err(serde::de::Error::custom)
337    }
338}
339
/// Block-level context.
///
/// Holds the block timestamp (with an accompanying `tblock_err` value), the parent block hash
/// and the last block time.
///
/// With serde it goes through the private `SerdeBlockContext` mirror (timestamps as seconds,
/// hash as hex). The binary `Serializable` encoding uses the tag `block-context[v2]`.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize, Serializable)]
#[tag = "block-context[v2]"]
#[serde(try_from = "SerdeBlockContext", into = "SerdeBlockContext")]
pub struct BlockContext {
    /// Block timestamp; serialized via serde as seconds.
    pub tblock: Timestamp,
    /// Serialized via serde as `secondsSinceEpochErr`; presumably the allowed error (in
    /// seconds) of `tblock` — confirm.
    pub tblock_err: u32,
    /// Hash of the parent block.
    pub parent_block_hash: HashOutput,
    /// Serialized via serde as seconds.
    pub last_block_time: Timestamp,
}
// Tag-enforcement test for `BlockContext`, generated by the `serialize` crate's macro.
tag_enforcement_test!(BlockContext);
350
/// Serde mirror of `Effects`, produced and consumed by its `Serialize`/`Deserialize` impls.
///
/// Field names are serialized in camelCase. Nullifiers, commitments, hashes and addresses are
/// hex strings; token types and public addresses use their tagged serde mirrors.
///
/// NOTE(review): several maps here are keyed by structs or tuples. Formats such as JSON only
/// accept string map keys, so this presumably targets a format with structured keys — confirm.
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct SerdeEffects {
    /// Hex-encoded nullifier hashes.
    claimed_nullifiers: HashSet<String>,
    /// Hex-encoded coin commitment hashes.
    claimed_shielded_receives: HashSet<String>,
    /// Hex-encoded coin commitment hashes.
    claimed_shielded_spends: HashSet<String>,
    /// `(u64, hex contract-address encoding, hex HashOutput, Fr)` per `ClaimedContractCallsValue`.
    claimed_contract_calls: HashSet<(u64, String, String, Fr)>,
    /// Hex-encoded `HashOutput` key to amount.
    shielded_mints: HashMap<String, u64>,
    /// Hex-encoded `HashOutput` key to amount.
    unshielded_mints: HashMap<String, u64>,
    /// Amount per token type.
    unshielded_inputs: HashMap<SerdeTokenType, u128>,
    /// Amount per token type.
    unshielded_outputs: HashMap<SerdeTokenType, u128>,
    /// Amount per `(token type, public address)` pair.
    claimed_unshielded_spends: HashMap<(SerdeTokenType, SerdePublicAddress), u128>,
}
364
365impl<D: DB> From<Effects<D>> for SerdeEffects {
366    fn from(eff: Effects<D>) -> SerdeEffects {
367        SerdeEffects {
368            claimed_nullifiers: eff
369                .claimed_nullifiers
370                .iter()
371                .map(|n| n.0.0.encode_hex())
372                .collect(),
373            claimed_shielded_receives: eff
374                .claimed_shielded_receives
375                .iter()
376                .map(|cm| cm.0.0.encode_hex())
377                .collect(),
378            claimed_shielded_spends: eff
379                .claimed_shielded_spends
380                .iter()
381                .map(|cm| cm.0.0.encode_hex())
382                .collect(),
383            claimed_contract_calls: eff
384                .claimed_contract_calls
385                .iter()
386                .map(|sp| {
387                    let (seq, addr, ep_hash, comm_hash) = sp.deref().into_inner();
388                    let mut addr_bytes = Vec::new();
389                    Serializable::serialize(&addr, &mut addr_bytes)
390                        .expect("In-memory serialization must succeed");
391                    (
392                        seq,
393                        addr_bytes.encode_hex(),
394                        ep_hash.0.encode_hex(),
395                        comm_hash,
396                    )
397                })
398                .collect(),
399            shielded_mints: eff
400                .shielded_mints
401                .into_iter()
402                .map(|(tt, val)| (tt.0.encode_hex(), val))
403                .collect(),
404            unshielded_mints: eff
405                .unshielded_mints
406                .into_iter()
407                .map(|(tt, val)| (tt.0.encode_hex(), val))
408                .collect(),
409            unshielded_inputs: eff
410                .unshielded_inputs
411                .into_iter()
412                .map(|(tt, val)| (hex_from_tt(tt), val))
413                .collect(),
414            unshielded_outputs: eff
415                .unshielded_outputs
416                .into_iter()
417                .map(|(tt, val)| (hex_from_tt(tt), val))
418                .collect(),
419            claimed_unshielded_spends: eff
420                .claimed_unshielded_spends
421                .into_iter()
422                .map(|(spends_key, val)| {
423                    let (tt, addr) = spends_key.into_inner();
424                    ((hex_from_tt(tt), public_address_to_hex(addr)), val)
425                })
426                .collect(),
427        }
428    }
429}
430
431impl<D: DB> TryFrom<SerdeEffects> for Effects<D> {
432    type Error = std::io::Error;
433
434    fn try_from(eff: SerdeEffects) -> Result<Effects<D>, std::io::Error> {
435        Ok(Effects {
436            claimed_nullifiers: eff
437                .claimed_nullifiers
438                .into_iter()
439                .map(|n| Ok::<_, FromHexError>(Nullifier(HashOutput(FromHex::from_hex(n)?))))
440                .collect::<Result<_, _>>()
441                .map_err(err_conv)?,
442            claimed_shielded_receives: eff
443                .claimed_shielded_receives
444                .into_iter()
445                .map(|cm| Ok::<_, FromHexError>(CoinCommitment(HashOutput(FromHex::from_hex(cm)?))))
446                .collect::<Result<_, _>>()
447                .map_err(err_conv)?,
448            claimed_shielded_spends: eff
449                .claimed_shielded_spends
450                .into_iter()
451                .map(|cm| Ok::<_, FromHexError>(CoinCommitment(HashOutput(FromHex::from_hex(cm)?))))
452                .collect::<Result<_, _>>()
453                .map_err(err_conv)?,
454            claimed_contract_calls: eff
455                .claimed_contract_calls
456                .into_iter()
457                .map(|(seq, addr, ep_hash, comm_hash)| {
458                    let addr_bytes: Vec<u8> = FromHex::from_hex(addr).map_err(err_conv)?;
459                    Ok::<_, std::io::Error>(ClaimedContractCallsValue(
460                        seq,
461                        Deserializable::deserialize(&mut &addr_bytes[..], 0)?,
462                        HashOutput(FromHex::from_hex(ep_hash).map_err(err_conv)?),
463                        comm_hash,
464                    ))
465                })
466                .collect::<Result<_, _>>()?,
467            shielded_mints: eff
468                .shielded_mints
469                .into_iter()
470                .map(|(tt, val)| Ok::<_, FromHexError>((HashOutput(FromHex::from_hex(tt)?), val)))
471                .collect::<Result<_, _>>()
472                .map_err(err_conv)?,
473            unshielded_mints: eff
474                .unshielded_mints
475                .into_iter()
476                .map(|(tt, val)| Ok::<_, FromHexError>((HashOutput(FromHex::from_hex(tt)?), val)))
477                .collect::<Result<_, _>>()
478                .map_err(err_conv)?,
479            unshielded_inputs: eff
480                .unshielded_inputs
481                .into_iter()
482                .map(|(tt, val)| Ok::<_, std::io::Error>((tt_from_hex(tt)?, val)))
483                .collect::<Result<storage::storage::HashMap<TokenType, u128, D>, _>>()?,
484            unshielded_outputs: eff
485                .unshielded_outputs
486                .into_iter()
487                .map(|(tt, val)| Ok::<_, std::io::Error>((tt_from_hex(tt)?, val)))
488                .collect::<Result<storage::storage::HashMap<TokenType, u128, D>, _>>()?,
489            claimed_unshielded_spends:
490                eff.claimed_unshielded_spends
491                    .into_iter()
492                    .map(|((tt, addr), val)| {
493                        Ok::<_, std::io::Error>((
494                            ClaimedUnshieldedSpendsKey(
495                                tt_from_hex(tt)?,
496                                public_address_from_hex(addr)?,
497                            ),
498                            val,
499                        ))
500                    })
501                    .collect::<Result<
502                        storage::storage::HashMap<ClaimedUnshieldedSpendsKey, u128, D>,
503                        _,
504                    >>()?,
505        })
506    }
507}
508
/// Key of `Effects::claimed_unshielded_spends`: a token type paired with a public address.
///
/// As a VM `Value`, it is the concatenation of the token type's encoding and the address's
/// encoding.
#[derive(
    Clone, Debug, PartialEq, Eq, Serializable, Storable, Hash, serde::Serialize, serde::Deserialize,
)]
#[storable(base)]
#[tag = "contract-effects-claimed-unshielded-spends-key[v1]"]
pub struct ClaimedUnshieldedSpendsKey(pub TokenType, pub PublicAddress);
// Tag-enforcement test for `ClaimedUnshieldedSpendsKey`, generated by the `serialize` crate's macro.
tag_enforcement_test!(ClaimedUnshieldedSpendsKey);
516
517impl ClaimedUnshieldedSpendsKey {
518    pub fn into_inner(&self) -> (TokenType, PublicAddress) {
519        (self.0, self.1)
520    }
521
522    pub fn from_inner(tt: TokenType, addr: PublicAddress) -> ClaimedUnshieldedSpendsKey {
523        ClaimedUnshieldedSpendsKey(tt, addr)
524    }
525}
526
527impl From<ClaimedUnshieldedSpendsKey> for Value {
528    fn from(val: ClaimedUnshieldedSpendsKey) -> Value {
529        let v1: Value = val.0.into();
530        let v2: Value = val.1.into();
531        Value::concat([&v1, &v2])
532    }
533}
534
535impl TryFrom<&ValueSlice> for ClaimedUnshieldedSpendsKey {
536    type Error = InvalidBuiltinDecode;
537
538    fn try_from(value: &ValueSlice) -> Result<ClaimedUnshieldedSpendsKey, InvalidBuiltinDecode> {
539        if value.0.len() == 6 {
540            Ok(ClaimedUnshieldedSpendsKey(
541                (&value[0..3]).try_into()?,
542                (&value[3..6]).try_into()?,
543            ))
544        } else {
545            Err(InvalidBuiltinDecode("ClaimedUnshieldedSpendsKey"))
546        }
547    }
548}
549
550impl Aligned for ClaimedUnshieldedSpendsKey {
551    fn alignment() -> Alignment {
552        Alignment::concat([&TokenType::alignment(), &PublicAddress::alignment()])
553    }
554}
555
556impl Distribution<ClaimedUnshieldedSpendsKey> for Standard {
557    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> ClaimedUnshieldedSpendsKey {
558        ClaimedUnshieldedSpendsKey(rng.r#gen(), rng.r#gen())
559    }
560}
561
/// Element of `Effects::claimed_contract_calls`.
///
/// The four components are a `u64`, the target contract address, a `HashOutput` and an `Fr`.
/// `from_inner` names them `pos`, `addr`, `hash` and `rnd`.
///
/// As a VM `Value`, it is the concatenation of the four components' encodings, in order.
#[derive(
    Clone,
    Debug,
    Default,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Serializable,
    Storable,
    Hash,
    serde::Serialize,
    serde::Deserialize,
)]
#[storable(base)]
#[tag = "contract-effects-claimed-contract-calls-value[v1]"]
pub struct ClaimedContractCallsValue(pub u64, pub ContractAddress, pub HashOutput, pub Fr);
// Tag-enforcement test for `ClaimedContractCallsValue`, generated by the `serialize` crate's macro.
tag_enforcement_test!(ClaimedContractCallsValue);
580
581impl ClaimedContractCallsValue {
582    pub fn into_inner(&self) -> (u64, ContractAddress, HashOutput, Fr) {
583        (self.0, self.1, self.2, self.3)
584    }
585
586    pub fn from_inner(
587        pos: u64,
588        addr: ContractAddress,
589        hash: HashOutput,
590        rnd: Fr,
591    ) -> ClaimedContractCallsValue {
592        ClaimedContractCallsValue(pos, addr, hash, rnd)
593    }
594}
595
596impl From<ClaimedContractCallsValue> for Value {
597    fn from(val: ClaimedContractCallsValue) -> Value {
598        let v1: Value = val.0.into();
599        let v2: Value = val.1.into();
600        let v3: Value = val.2.into();
601        let v4: Value = val.3.into();
602        Value::concat([&v1, &v2, &v3, &v4])
603    }
604}
605
606impl TryFrom<&ValueSlice> for ClaimedContractCallsValue {
607    type Error = InvalidBuiltinDecode;
608
609    fn try_from(value: &ValueSlice) -> Result<ClaimedContractCallsValue, InvalidBuiltinDecode> {
610        if value.0.len() == 4 {
611            Ok(ClaimedContractCallsValue(
612                (&value.0[0]).try_into()?,
613                (&value.0[1]).try_into()?,
614                (&value.0[2]).try_into()?,
615                (&value.0[3]).try_into()?,
616            ))
617        } else {
618            Err(InvalidBuiltinDecode("ClaimedContractCallsValue"))
619        }
620    }
621}
622
623impl Aligned for ClaimedContractCallsValue {
624    fn alignment() -> Alignment {
625        Alignment::concat([
626            &u64::alignment(),
627            &ContractAddress::alignment(),
628            &HashOutput::alignment(),
629            &Fr::alignment(),
630        ])
631    }
632}
633
634impl Distribution<ClaimedContractCallsValue> for Standard {
635    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> ClaimedContractCallsValue {
636        ClaimedContractCallsValue(rng.r#gen(), rng.r#gen(), rng.r#gen(), rng.r#gen())
637    }
638}
639
/// Effects recorded by a contract transcript.
///
/// Holds claimed nullifiers, shielded receives/spends, claimed contract calls, mints,
/// unshielded inputs/outputs and claimed unshielded spends.
///
/// With serde it goes through the private `SerdeEffects` mirror. In the VM it is represented
/// as a 9-element `StateValue::Array` with one entry per field, in declaration order.
#[derive(Storable)]
#[derive_where(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(Arbitrary))]
#[storable(db = D)]
#[tag = "contract-effects[v3]"]
pub struct Effects<D: DB> {
    /// Nullifiers claimed by the transcript.
    pub claimed_nullifiers: storage::storage::HashSet<Nullifier, D>,
    /// Coin commitments claimed as shielded receives.
    pub claimed_shielded_receives: storage::storage::HashSet<CoinCommitment, D>,
    /// Coin commitments claimed as shielded spends.
    pub claimed_shielded_spends: storage::storage::HashSet<CoinCommitment, D>,
    /// Contract calls claimed by the transcript.
    pub claimed_contract_calls: storage::storage::HashSet<ClaimedContractCallsValue, D>,
    /// Shielded mint amounts, keyed by `HashOutput` (presumably a token identifier — confirm).
    pub shielded_mints: storage::storage::HashMap<HashOutput, u64, D>,
    /// Unshielded mint amounts, keyed by `HashOutput` (presumably a token identifier — confirm).
    pub unshielded_mints: storage::storage::HashMap<HashOutput, u64, D>,
    /// Unshielded input amount per token type.
    pub unshielded_inputs: storage::storage::HashMap<TokenType, u128, D>,
    /// Unshielded output amount per token type.
    pub unshielded_outputs: storage::storage::HashMap<TokenType, u128, D>,
    /// Claimed unshielded spend amount per `(token type, public address)` key.
    pub claimed_unshielded_spends: storage::storage::HashMap<ClaimedUnshieldedSpendsKey, u128, D>,
}
// Tag-enforcement test for `Effects`.
// NOTE(review): `InMemoryDB` is only imported under `cfg(all(test, feature = "proptest"))`
// further down; presumably this macro only expands under a compatible cfg — confirm.
tag_enforcement_test!(Effects<InMemoryDB>);
657
658impl<D: DB> Serialize for Effects<D> {
659    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
660    where
661        S: Serializer,
662    {
663        let ser_effects: SerdeEffects = self.clone().into();
664        <SerdeEffects as Serialize>::serialize(&ser_effects, serializer)
665    }
666}
667
668impl<'de, DD: DB> Deserialize<'de> for Effects<DD> {
669    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
670    where
671        D: Deserializer<'de>,
672    {
673        let ser_effects = <SerdeEffects as Deserialize>::deserialize(deserializer)?;
674        Effects::try_from(ser_effects).map_err(serde::de::Error::custom)
675    }
676}
677
678impl<D: DB> rand::distributions::Distribution<Effects<D>> for rand::distributions::Standard {
679    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Effects<D> {
680        Effects {
681            claimed_nullifiers: vec![rng.r#gen(); 5].into_iter().collect(),
682            claimed_shielded_receives: vec![rng.r#gen(); 5].into_iter().collect(),
683            claimed_shielded_spends: vec![rng.r#gen(); 5].into_iter().collect(),
684            claimed_contract_calls: vec![rng.r#gen(); 5].into_iter().collect(),
685            shielded_mints: vec![rng.r#gen(); 5].into_iter().collect(),
686            unshielded_mints: vec![rng.r#gen(); 5].into_iter().collect(),
687            unshielded_inputs: vec![rng.r#gen(); 5].into_iter().collect(),
688            unshielded_outputs: vec![rng.r#gen(); 5].into_iter().collect(),
689            claimed_unshielded_spends: vec![rng.r#gen(); 5].into_iter().collect(),
690        }
691    }
692}
693
// Concrete database type used to instantiate the storage-generic types in test code.
#[cfg(all(test, feature = "proptest"))]
use storage::db::InMemoryDB;
// Randomised serialization test for `Effects` from the `serialize` crate (presumably a
// serialize/deserialize round trip), compiled only with the `proptest` feature.
// NOTE(review): this is gated on `feature = "proptest"` alone while the `use` above also
// requires `test`; presumably the macro expands to `#[cfg(test)]` items — confirm.
#[cfg(feature = "proptest")]
serialize::randomised_serialization_test!(Effects<InMemoryDB>);
698
699impl<'a, D: DB> From<&'a Effects<D>> for VmValue<D> {
700    fn from(eff: &'a Effects<D>) -> VmValue<D> {
701        VmValue::new(
702            ValueStrength::Weak,
703            StateValue::Array(
704                vec![
705                    StateValue::Map(
706                        eff.claimed_nullifiers
707                            .iter()
708                            .map(|k| ((**k).into(), StateValue::Null))
709                            .collect(),
710                    ),
711                    StateValue::Map(
712                        eff.claimed_shielded_receives
713                            .iter()
714                            .map(|k| ((**k).into(), StateValue::Null))
715                            .collect(),
716                    ),
717                    StateValue::Map(
718                        eff.claimed_shielded_spends
719                            .iter()
720                            .map(|k| ((**k).into(), StateValue::Null))
721                            .collect(),
722                    ),
723                    StateValue::Map(
724                        eff.claimed_contract_calls
725                            .iter()
726                            .map(|sp_item| {
727                                let value_sp = &(*sp_item);
728                                let value: ClaimedContractCallsValue =
729                                    (*(*value_sp).clone()).clone();
730                                (value.into(), StateValue::Null)
731                            })
732                            .collect(),
733                    ),
734                    StateValue::Map(
735                        eff.shielded_mints
736                            .iter()
737                            .map(|x| ((*x.0).into(), StateValue::Cell(Sp::new((*(x.1)).into()))))
738                            .collect(),
739                    ),
740                    StateValue::Map(
741                        eff.unshielded_mints
742                            .iter()
743                            .map(|x| ((*x.0).into(), StateValue::Cell(Sp::new((*(x.1)).into()))))
744                            .collect(),
745                    ),
746                    StateValue::Map(
747                        eff.unshielded_inputs
748                            .iter()
749                            .map(|x| ((*x.0).into(), StateValue::Cell(Sp::new((*(x.1)).into()))))
750                            .collect(),
751                    ),
752                    StateValue::Map(
753                        eff.unshielded_outputs
754                            .iter()
755                            .map(|x| ((*x.0).into(), StateValue::Cell(Sp::new((*(x.1)).into()))))
756                            .collect(),
757                    ),
758                    StateValue::Map(
759                        eff.claimed_unshielded_spends
760                            .iter()
761                            .map(|sp_item| {
762                                let (ref key_sp, ref value_sp) = *sp_item;
763                                let key: ClaimedUnshieldedSpendsKey = (*(*key_sp).clone()).clone();
764                                let value: u128 = *(*value_sp).clone();
765                                (key.into(), StateValue::Cell(Sp::new(value.into())))
766                            })
767                            .collect(),
768                    ),
769                ]
770                .into(),
771            ),
772        )
773    }
774}
775
776impl<D: DB> TryFrom<VmValue<D>> for Effects<D> {
777    type Error = TranscriptRejected<D>;
778
779    fn try_from(val: VmValue<D>) -> Result<Effects<D>, TranscriptRejected<D>> {
780        fn map_from<
781            K: Eq
782                + Hash
783                + for<'a> TryFrom<&'a ValueSlice, Error = InvalidBuiltinDecode>
784                + Serializable
785                + Storable<D>,
786            V: Default + for<'a> TryFrom<&'a ValueSlice, Error = InvalidBuiltinDecode> + Storable<D>,
787            D: DB,
788        >(
789            st: &StateValue<D>,
790        ) -> Result<storage::storage::HashMap<K, V, D>, TranscriptRejected<D>> {
791            if let StateValue::Map(m) = st {
792                Ok(m.iter()
793                    .map(|kv| {
794                        let v = match *kv.1 {
795                            StateValue::Cell(ref v) => (&*v.value).try_into()?,
796                            StateValue::Null => V::default(),
797                            _ => return Err(TranscriptRejected::EffectDecodeError),
798                        };
799                        Ok::<_, TranscriptRejected<D>>((
800                            (&**AsRef::<Value>::as_ref(&(*kv.0))).try_into()?,
801                            v,
802                        ))
803                    })
804                    .collect::<Result<_, _>>()?)
805            } else {
806                Err(TranscriptRejected::EffectDecodeError)
807            }
808        }
809        if let StateValue::Array(arr) = &val.value
810            && arr.len() == 9
811        {
812            return Ok(Effects {
813                claimed_nullifiers: map_from::<Nullifier, (), D>(arr.get(0).unwrap())?
814                    .iter()
815                    .map(|x| *x.0)
816                    .collect(),
817                claimed_shielded_receives: map_from::<CoinCommitment, (), D>(arr.get(1).unwrap())?
818                    .iter()
819                    .map(|x| *x.0)
820                    .collect(),
821                claimed_shielded_spends: map_from::<CoinCommitment, (), D>(arr.get(2).unwrap())?
822                    .iter()
823                    .map(|x| *x.0)
824                    .collect(),
825                claimed_contract_calls: map_from::<ClaimedContractCallsValue, (), D>(
826                    arr.get(3).unwrap(),
827                )?
828                .iter()
829                .map(|x| (*x.0).clone())
830                .collect(),
831                shielded_mints: map_from(arr.get(4).unwrap())?,
832                unshielded_mints: map_from(arr.get(5).unwrap())?,
833                unshielded_inputs: map_from(arr.get(6).unwrap())?,
834                unshielded_outputs: map_from(arr.get(7).unwrap())?,
835                claimed_unshielded_spends: map_from(arr.get(8).unwrap())?,
836            });
837        }
838        Err(TranscriptRejected::EffectDecodeError)
839    }
840}
841
/// Execution context for running queries and transcripts against a single
/// contract's state.
///
/// `QueryContext::query` works on a clone of this value and returns the
/// updated copy; the receiver itself is not modified.
#[derive_where(Clone, Debug)]
pub struct QueryContext<D: DB> {
    /// The contract state, wrapped in a [`ChargedState`]. `query` replaces it
    /// in the returned copy with the state left on top of the VM stack.
    pub state: ChargedState<D>,
    /// Effects passed to the VM. `query` replaces them in the returned copy
    /// with the effects decoded from the program's final stack.
    pub effects: Effects<D>,
    // TODO WG
    // Either this (`address`) should be removed, or `own_address` should be removed
    // from `CallContext` and `call_context` should be optional.
    /// Address of the contract. It is used as the coin recipient in `qualify`
    /// and is exposed to the VM as the first context entry.
    pub address: ContractAddress,
    /// Per-call data exposed to the VM, with fields such as `com_indices`,
    /// `tblock`, `tblock_err`, `parent_block_hash`, `balance`, `caller` and
    /// `last_block_time`.
    pub call_context: CallContext<D>,
}
852
853impl<D: DB> From<&QueryContext<D>> for VmValue<D> {
854    fn from(context: &QueryContext<D>) -> VmValue<D> {
855        VmValue::new(
856            ValueStrength::Weak,
857            StateValue::Array(
858                vec![
859                    StateValue::Cell(Sp::new(context.address.into())),
860                    StateValue::Map(
861                        context
862                            .call_context
863                            .com_indices
864                            .iter()
865                            .map(|(k, v)| {
866                                (k.into(), StateValue::Cell(Sp::new((*v.clone()).into())))
867                            })
868                            .collect(),
869                    ),
870                    StateValue::Cell(Sp::new(context.call_context.tblock.into())),
871                    StateValue::Cell(Sp::new(context.call_context.tblock_err.into())),
872                    StateValue::Cell(Sp::new(context.call_context.parent_block_hash.into())),
873                    StateValue::Map(
874                        context
875                            .call_context
876                            .balance
877                            .iter()
878                            .map(|tt_x_amount| {
879                                (
880                                    (*tt_x_amount.0.deref()).into(),
881                                    StateValue::Cell(Sp::new((*tt_x_amount.1.deref()).into())),
882                                )
883                            })
884                            .collect(),
885                    ),
886                    match context.call_context.caller {
887                        Some(x) => StateValue::Cell(Sp::new(x.into())),
888                        None => StateValue::Null,
889                    },
890                    StateValue::Cell(Sp::new(context.call_context.last_block_time.into())),
891                ]
892                .into(),
893            ),
894        )
895    }
896}
897
/// Outcome of a successful `QueryContext::query` call.
#[derive(Debug)]
pub struct QueryResults<M: ResultMode<D>, D: DB> {
    /// The context after the query, holding the new contract state and the
    /// decoded effects.
    pub context: QueryContext<D>,
    /// Events returned by the VM run of the program.
    pub events: Vec<M::Event>,
    /// Total cost: the VM's execution cost plus the cost charged for
    /// updating the stored state.
    pub gas_cost: RunningCost,
}
904
905impl<D: DB> QueryContext<D> {
906    pub fn new(state: ChargedState<D>, address: ContractAddress) -> Self {
907        QueryContext {
908            state,
909            address,
910            effects: Effects::default(),
911            call_context: CallContext::default(),
912        }
913    }
914
915    pub fn qualify(&self, coin: &CoinInfo) -> Option<QualifiedCoinInfo> {
916        self.call_context
917            .com_indices
918            .get(&coin.commitment(&Recipient::Contract(self.address)))
919            .map(|idx| coin.qualify(*idx))
920    }
921
922    #[instrument(skip(self, cost_model))]
923    pub fn query<M: ResultMode<D>>(
924        &self,
925        query: &[Op<M, D>],
926        gas_limit: Option<RunningCost>,
927        cost_model: &CostModel,
928    ) -> Result<QueryResults<M, D>, TranscriptRejected<D>> {
929        let mut state: Self = (*self).clone();
930        let mut res = run_program(&self.to_vm_stack(), query, gas_limit, cost_model)?;
931        if res.stack.len() != 3 {
932            return Err(TranscriptRejected::FinalStackWrongLength);
933        }
934        let new_state = match res.stack.pop().unwrap() {
935            VmValue {
936                strength: ValueStrength::Strong,
937                value,
938            } => value,
939            VmValue {
940                strength: ValueStrength::Weak,
941                ..
942            } => return Err(TranscriptRejected::WeakStateReturned),
943        };
944        state.effects = res.stack.pop().unwrap().try_into()?;
945
946        let (new_charged_state, state_cost) = state.state.update(
947            new_state,
948            |writes, deletes| {
949                RunningCost::compute(
950                    cost_model.gc_rcmap_constant
951                        + cost_model.gc_rcmap_coeff_keys_removed_size * deletes
952                        + cost_model.update_rcmap_constant
953                        + cost_model.update_rcmap_coeff_keys_added_size * writes
954                        + cost_model.get_writes_constant
955                        + cost_model.get_writes_coeff_keys_added_size * writes,
956                )
957            },
958            |budget| {
959                (budget.compute_time / cost_model.gc_rcmap_coeff_keys_removed_size)
960                    .into_atomic_units(1) as usize
961            },
962        );
963        state.state = new_charged_state;
964        let gas_cost = res.gas_cost + state_cost;
965        if let Some(gas_limit) = gas_limit
966            && gas_cost > gas_limit
967        {
968            // TODO?: return a more specific error, explaining that gas
969            // limit was exceeded by write+delete vs by cpu during vm eval?
970            return Err(TranscriptRejected::Execution(OnchainProgramError::OutOfGas));
971        }
972
973        trace!("transcript application successful");
974        Ok(QueryResults {
975            context: state,
976            events: res.events,
977            gas_cost,
978        })
979    }
980
981    pub fn to_vm_stack(&self) -> Vec<VmValue<D>> {
982        vec![
983            self.into(),
984            (&self.effects).into(),
985            VmValue::new(ValueStrength::Strong, (*self.state.get()).clone()),
986        ]
987    }
988
989    #[instrument(skip(self, cost_model))]
990    pub fn run_transcript(
991        &self,
992        transcript: &Transcript<D>,
993        cost_model: &CostModel,
994    ) -> Result<QueryResults<ResultModeVerify, D>, TranscriptRejected<D>> {
995        self.query(
996            &Vec::from(&transcript.program),
997            Some(transcript.gas),
998            cost_model,
999        )
1000    }
1001}
1002
/// Checks that deserialization consumed the entire input.
///
/// # Errors
///
/// Returns an [`std::io::ErrorKind::InvalidData`] error that reports how many
/// bytes are left over when `data` is not empty.
fn ensure_fully_deserialized(data: &[u8]) -> Result<(), std::io::Error> {
    match data.len() {
        0 => Ok(()),
        remaining => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("Not all bytes read, {remaining} bytes remaining"),
        )),
    }
}