1#![cfg_attr(not(feature = "std"), no_std)]
8#![deny(unused_crate_dependencies)]
9#![deny(unused_extern_crates)]
10
11extern crate alloc;
12#[cfg(feature = "std")]
13extern crate std;
14
15use alloc::format;
16use alloc::vec::Vec;
17
18use dusk_bytes::Serializable;
19use dusk_core::Error;
20use dusk_core::transfer::phoenix::{
21 NOTES_TREE_DEPTH, Prove, TxCircuit, TxCircuitVec,
22};
23use dusk_plonk::prelude::{PlonkVersion, Prover as PlonkProver};
24use once_cell::sync::Lazy;
25
26static TX_CIRCUIT_1_2_PROVER: Lazy<PlonkProver> =
27 Lazy::new(|| fetch_prover("TxCircuitOneTwo"));
28
29static TX_CIRCUIT_2_2_PROVER: Lazy<PlonkProver> =
30 Lazy::new(|| fetch_prover("TxCircuitTwoTwo"));
31
32static TX_CIRCUIT_3_2_PROVER: Lazy<PlonkProver> =
33 Lazy::new(|| fetch_prover("TxCircuitThreeTwo"));
34
35static TX_CIRCUIT_4_2_PROVER: Lazy<PlonkProver> =
36 Lazy::new(|| fetch_prover("TxCircuitFourTwo"));
37
38fn plonk_prove_version_from_mode(mode: Option<&str>) -> PlonkVersion {
39 match mode {
40 Some(mode)
41 if mode.eq_ignore_ascii_case("v2")
42 || mode.eq_ignore_ascii_case("legacy") =>
43 {
44 PlonkVersion::V2
45 }
46 _ => PlonkVersion::V3,
47 }
48}
49
50#[derive(Debug, Default)]
51pub struct LocalProver;
52
53impl Prove for LocalProver {
54 fn prove(&self, tx_circuit_vec_bytes: &[u8]) -> Result<Vec<u8>, Error> {
55 let tx_circuit_vec = TxCircuitVec::from_slice(tx_circuit_vec_bytes)?;
56
57 #[cfg(feature = "std")]
60 let plonk_version = plonk_prove_version_from_mode(
61 std::env::var("RUSK_PLONK_PROVE_MODE").ok().as_deref(),
62 );
63 #[cfg(not(feature = "std"))]
64 let plonk_version = plonk_prove_version_from_mode(None);
65
66 #[cfg(not(feature = "unsafe_deterministic_rng"))]
67 let rng = &mut rand::rngs::OsRng;
68
69 #[cfg(feature = "unsafe_deterministic_rng")]
70 use rand::{SeedableRng, rngs::StdRng};
71 #[cfg(feature = "unsafe_deterministic_rng")]
72 let rng = &mut StdRng::seed_from_u64(0xbeef);
73
74 #[cfg(feature = "debug")]
75 tracing::info!(
76 "tx_circuit_vec:\n{}",
77 hex::encode(tx_circuit_vec_bytes)
78 );
79
80 let (proof, _pi) = match tx_circuit_vec.input_notes_info.len() {
81 1 => {
82 let circuit = create_circuit::<1>(tx_circuit_vec)?;
83 TX_CIRCUIT_1_2_PROVER
84 .prove_with_version(rng, &circuit, plonk_version)
85 .map_err(|e| Error::PhoenixProver(format!("{e:?}")))?
86 }
87 2 => {
88 let circuit = create_circuit::<2>(tx_circuit_vec)?;
89 TX_CIRCUIT_2_2_PROVER
90 .prove_with_version(rng, &circuit, plonk_version)
91 .map_err(|e| Error::PhoenixProver(format!("{e:?}")))?
92 }
93 3 => {
94 let circuit = create_circuit::<3>(tx_circuit_vec)?;
95 TX_CIRCUIT_3_2_PROVER
96 .prove_with_version(rng, &circuit, plonk_version)
97 .map_err(|e| Error::PhoenixProver(format!("{e:?}")))?
98 }
99 4 => {
100 let circuit = create_circuit::<4>(tx_circuit_vec)?;
101 TX_CIRCUIT_4_2_PROVER
102 .prove_with_version(rng, &circuit, plonk_version)
103 .map_err(|e| Error::PhoenixProver(format!("{e:?}")))?
104 }
105 _ => return Err(Error::InvalidData),
106 };
107
108 Ok(proof.to_bytes().to_vec())
109 }
110}
111
112fn fetch_prover(circuit_name: &str) -> PlonkProver {
113 let circuit_profile = rusk_profile::Circuit::from_name(circuit_name)
114 .unwrap_or_else(|_| {
115 panic!("There should be tx-circuit data stored for {circuit_name}",)
116 });
117 let pk = circuit_profile.get_prover().unwrap_or_else(|_| {
118 panic!("there should be a prover key stored for {circuit_name}")
119 });
120
121 PlonkProver::try_from_bytes(pk).expect("Prover key is expected to by valid")
122}
123
124fn create_circuit<const I: usize>(
125 tx_circuit_vec: TxCircuitVec,
126) -> Result<TxCircuit<NOTES_TREE_DEPTH, I>, Error> {
127 Ok(TxCircuit {
128 input_notes_info: tx_circuit_vec
129 .input_notes_info
130 .try_into()
131 .map_err(|e| Error::PhoenixCircuit(format!("{e:?}")))?,
132 output_notes_info: tx_circuit_vec.output_notes_info,
133 payload_hash: tx_circuit_vec.payload_hash,
134 root: tx_circuit_vec.root,
135 deposit: tx_circuit_vec.deposit,
136 max_fee: tx_circuit_vec.max_fee,
137 sender_pk: tx_circuit_vec.sender_pk,
138 signatures: tx_circuit_vec.signatures,
139 })
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the mode-string to PLONK-version mapping, including the ASCII
    /// case-insensitive matching of the legacy aliases.
    #[test]
    fn prove_mode_mapping() {
        let cases = [
            (None, PlonkVersion::V3),
            (Some(""), PlonkVersion::V3),
            (Some("current"), PlonkVersion::V3),
            (Some("random"), PlonkVersion::V3),
            (Some("v3"), PlonkVersion::V3),
            (Some("v2"), PlonkVersion::V2),
            (Some("legacy"), PlonkVersion::V2),
            // Alias matching ignores ASCII case.
            (Some("V2"), PlonkVersion::V2),
            (Some("LEGACY"), PlonkVersion::V2),
            (Some("Legacy"), PlonkVersion::V2),
        ];

        for (mode, expected) in cases {
            assert_eq!(
                plonk_prove_version_from_mode(mode),
                expected,
                "unexpected PLONK version for mode {mode:?}"
            );
        }
    }

    /// End-to-end proof over the checked-in fixture circuit.
    #[test]
    fn test_prove_tx_circuit() {
        // Trim so a trailing newline in the fixture file cannot break
        // hex decoding.
        let tx_circuit_vec_bytes =
            hex::decode(include_str!("../tests/tx_circuit_vec.hex").trim())
                .expect("fixture should be valid hex");
        let proof = LocalProver
            .prove(&tx_circuit_vec_bytes)
            .expect("proving the fixture circuit should succeed");
        assert!(!proof.is_empty(), "serialized proof should not be empty");
    }
}