Skip to main content

rusk_prover/
lib.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4//
5// Copyright (c) DUSK NETWORK. All rights reserved.
6
7#![cfg_attr(not(feature = "std"), no_std)]
8#![deny(unused_crate_dependencies)]
9#![deny(unused_extern_crates)]
10
11extern crate alloc;
12#[cfg(feature = "std")]
13extern crate std;
14
15use alloc::format;
16use alloc::vec::Vec;
17
18use dusk_bytes::Serializable;
19use dusk_core::Error;
20use dusk_core::transfer::phoenix::{
21    NOTES_TREE_DEPTH, Prove, TxCircuit, TxCircuitVec,
22};
23use dusk_plonk::prelude::{PlonkVersion, Prover as PlonkProver};
24use once_cell::sync::Lazy;
25
26static TX_CIRCUIT_1_2_PROVER: Lazy<PlonkProver> =
27    Lazy::new(|| fetch_prover("TxCircuitOneTwo"));
28
29static TX_CIRCUIT_2_2_PROVER: Lazy<PlonkProver> =
30    Lazy::new(|| fetch_prover("TxCircuitTwoTwo"));
31
32static TX_CIRCUIT_3_2_PROVER: Lazy<PlonkProver> =
33    Lazy::new(|| fetch_prover("TxCircuitThreeTwo"));
34
35static TX_CIRCUIT_4_2_PROVER: Lazy<PlonkProver> =
36    Lazy::new(|| fetch_prover("TxCircuitFourTwo"));
37
38fn plonk_prove_version_from_mode(mode: Option<&str>) -> PlonkVersion {
39    match mode {
40        Some(mode)
41            if mode.eq_ignore_ascii_case("v2")
42                || mode.eq_ignore_ascii_case("legacy") =>
43        {
44            PlonkVersion::V2
45        }
46        _ => PlonkVersion::V3,
47    }
48}
49
/// Proves Phoenix transaction circuits in-process, using prover keys loaded
/// from the rusk profile.
///
/// The type carries no state, so it is `Copy` and can be passed by value
/// freely.
#[derive(Debug, Default, Clone, Copy)]
pub struct LocalProver;
52
53impl Prove for LocalProver {
54    fn prove(&self, tx_circuit_vec_bytes: &[u8]) -> Result<Vec<u8>, Error> {
55        let tx_circuit_vec = TxCircuitVec::from_slice(tx_circuit_vec_bytes)?;
56
57        // Proving mode is chosen by the prover service. Default is V3 and can
58        // be switched to V2 explicitly for legacy proving.
59        #[cfg(feature = "std")]
60        let plonk_version = plonk_prove_version_from_mode(
61            std::env::var("RUSK_PLONK_PROVE_MODE").ok().as_deref(),
62        );
63        #[cfg(not(feature = "std"))]
64        let plonk_version = plonk_prove_version_from_mode(None);
65
66        #[cfg(not(feature = "unsafe_deterministic_rng"))]
67        let rng = &mut rand::rngs::OsRng;
68
69        #[cfg(feature = "unsafe_deterministic_rng")]
70        use rand::{SeedableRng, rngs::StdRng};
71        #[cfg(feature = "unsafe_deterministic_rng")]
72        let rng = &mut StdRng::seed_from_u64(0xbeef);
73
74        #[cfg(feature = "debug")]
75        tracing::info!(
76            "tx_circuit_vec:\n{}",
77            hex::encode(tx_circuit_vec_bytes)
78        );
79
80        let (proof, _pi) = match tx_circuit_vec.input_notes_info.len() {
81            1 => {
82                let circuit = create_circuit::<1>(tx_circuit_vec)?;
83                TX_CIRCUIT_1_2_PROVER
84                    .prove_with_version(rng, &circuit, plonk_version)
85                    .map_err(|e| Error::PhoenixProver(format!("{e:?}")))?
86            }
87            2 => {
88                let circuit = create_circuit::<2>(tx_circuit_vec)?;
89                TX_CIRCUIT_2_2_PROVER
90                    .prove_with_version(rng, &circuit, plonk_version)
91                    .map_err(|e| Error::PhoenixProver(format!("{e:?}")))?
92            }
93            3 => {
94                let circuit = create_circuit::<3>(tx_circuit_vec)?;
95                TX_CIRCUIT_3_2_PROVER
96                    .prove_with_version(rng, &circuit, plonk_version)
97                    .map_err(|e| Error::PhoenixProver(format!("{e:?}")))?
98            }
99            4 => {
100                let circuit = create_circuit::<4>(tx_circuit_vec)?;
101                TX_CIRCUIT_4_2_PROVER
102                    .prove_with_version(rng, &circuit, plonk_version)
103                    .map_err(|e| Error::PhoenixProver(format!("{e:?}")))?
104            }
105            _ => return Err(Error::InvalidData),
106        };
107
108        Ok(proof.to_bytes().to_vec())
109    }
110}
111
112fn fetch_prover(circuit_name: &str) -> PlonkProver {
113    let circuit_profile = rusk_profile::Circuit::from_name(circuit_name)
114        .unwrap_or_else(|_| {
115            panic!("There should be tx-circuit data stored for {circuit_name}",)
116        });
117    let pk = circuit_profile.get_prover().unwrap_or_else(|_| {
118        panic!("there should be a prover key stored for {circuit_name}")
119    });
120
121    PlonkProver::try_from_bytes(pk).expect("Prover key is expected to by valid")
122}
123
124fn create_circuit<const I: usize>(
125    tx_circuit_vec: TxCircuitVec,
126) -> Result<TxCircuit<NOTES_TREE_DEPTH, I>, Error> {
127    Ok(TxCircuit {
128        input_notes_info: tx_circuit_vec
129            .input_notes_info
130            .try_into()
131            .map_err(|e| Error::PhoenixCircuit(format!("{e:?}")))?,
132        output_notes_info: tx_circuit_vec.output_notes_info,
133        payload_hash: tx_circuit_vec.payload_hash,
134        root: tx_circuit_vec.root,
135        deposit: tx_circuit_vec.deposit,
136        max_fee: tx_circuit_vec.max_fee,
137        sender_pk: tx_circuit_vec.sender_pk,
138        signatures: tx_circuit_vec.signatures,
139    })
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the proving-mode mapping, including that unknown and empty
    /// values keep the V3 default and that matching is case-insensitive.
    #[test]
    fn prove_mode_mapping() {
        let cases = [
            (None, PlonkVersion::V3),
            (Some(""), PlonkVersion::V3),
            (Some("current"), PlonkVersion::V3),
            (Some("random"), PlonkVersion::V3),
            (Some("v3"), PlonkVersion::V3),
            (Some("v2"), PlonkVersion::V2),
            (Some("V2"), PlonkVersion::V2),
            (Some("legacy"), PlonkVersion::V2),
            (Some("Legacy"), PlonkVersion::V2),
            (Some("LEGACY"), PlonkVersion::V2),
        ];

        for (mode, expected) in cases {
            assert_eq!(
                plonk_prove_version_from_mode(mode),
                expected,
                "unexpected PLONK version for mode {mode:?}"
            );
        }
    }

    /// Proves the stored fixture circuit end to end. This requires the
    /// prover keys to be present in the rusk profile.
    #[test]
    fn test_prove_tx_circuit() {
        // Trim so that a trailing newline in the fixture file does not make
        // hex decoding fail.
        let tx_circuit_vec_bytes =
            hex::decode(include_str!("../tests/tx_circuit_vec.hex").trim())
                .expect("the fixture should be valid hex");
        let proof = LocalProver
            .prove(&tx_circuit_vec_bytes)
            .expect("proving the fixture circuit should succeed");
        assert!(!proof.is_empty(), "the serialized proof should not be empty");
    }
}
167}