1use std::{
16 ffi::{c_char, CStr, CString, NulError},
17 path::{Path, PathBuf},
18};
19
20use anyhow::{anyhow, Result};
21
22#[cfg(feature = "cuda")]
23pub use sppark::Error as SpparkError;
24
25pub struct SetupParams {
26 pub pcoeffs_path: RawPath,
27 pub fres_path: RawPath,
28 pub srs_path: RawPath,
29}
30
31impl SetupParams {
32 pub fn new(root_dir: &Path) -> anyhow::Result<Self> {
33 Ok(SetupParams {
34 pcoeffs_path: root_dir.join("preprocessed_coeffs.bin").try_into()?,
35 fres_path: root_dir.join("fuzzed_msm_results.bin").try_into()?,
36 srs_path: root_dir.join("stark_verify_final.zkey").try_into()?,
37 })
38 }
39}
40
/// Location of the graph file `stark_verify_graph.bin` under a root directory.
pub struct WitnessParams {
    /// Set by [`WitnessParams::new`] to `<root_dir>/stark_verify_graph.bin`.
    pub graph_path: PathBuf,
}

impl WitnessParams {
    /// Builds the parameters for files stored under `root_dir`.
    ///
    /// This only joins path components; the file system is not touched.
    pub fn new(root_dir: &Path) -> Self {
        let graph_path = root_dir.join("stark_verify_graph.bin");
        Self { graph_path }
    }
}
52
53pub struct ProverParams {
54 pub public_path: RawPath,
55 pub proof_path: RawPath,
56 pub witness: *const u8,
57}
58
59impl ProverParams {
60 pub fn new(root_dir: &Path, witness: *const u8) -> anyhow::Result<Self> {
61 Ok(Self {
62 public_path: root_dir.join("public.json").try_into()?,
63 proof_path: root_dir.join("proof.json").try_into()?,
64 witness,
65 })
66 }
67}
68
/// Calls the native `risc0_groth16_cuda_prove` entry point.
///
/// The safe parameter structs are lowered into their `#[repr(C)]` pointer
/// views and passed by address to the native function.
///
/// # Errors
///
/// Returns the message produced by the native call whenever it reports a
/// non-null error string.
#[cfg(feature = "cuda")]
pub fn prove(prover_params: &ProverParams, setup_params: &SetupParams) -> anyhow::Result<()> {
    // The raw views only borrow the C strings owned by the arguments, so the
    // pointers stay valid for as long as the arguments do.
    let raw_setup = RawSetupParams {
        pcoeffs_path: setup_params.pcoeffs_path.c_str.as_ptr(),
        fres_path: setup_params.fres_path.c_str.as_ptr(),
        srs_path: setup_params.srs_path.c_str.as_ptr(),
    };
    let raw_prover = RawProverParams {
        public_path: prover_params.public_path.c_str.as_ptr(),
        proof_path: prover_params.proof_path.c_str.as_ptr(),
        witness: prover_params.witness,
    };

    // SAFETY: both raw structs live on this stack frame for the whole call and
    // their path pointers come from `CString`s borrowed through the arguments.
    // The validity of `witness` is the caller's responsibility.
    ffi_wrap(|| unsafe { risc0_groth16_cuda_prove(&raw_setup, &raw_prover) })
}
84
/// Calls the native `risc0_groth16_cuda_setup` entry point.
///
/// # Errors
///
/// Returns the message produced by the native call whenever it reports a
/// non-null error string.
#[cfg(all(feature = "cuda", feature = "setup"))]
pub fn setup(params: &SetupParams) -> anyhow::Result<()> {
    let SetupParams {
        pcoeffs_path,
        fres_path,
        srs_path,
    } = params;

    // Pointer-only view of `params`; the C strings remain owned by `params`.
    let raw = RawSetupParams {
        pcoeffs_path: pcoeffs_path.c_str.as_ptr(),
        fres_path: fres_path.c_str.as_ptr(),
        srs_path: srs_path.c_str.as_ptr(),
    };

    // SAFETY: `raw` lives on this stack frame for the whole call and its
    // pointers come from `CString`s borrowed through `params`.
    ffi_wrap(|| unsafe { risc0_groth16_cuda_setup(&raw) })
}
94
/// C-layout view of the prover parameters, passed by pointer to
/// `risc0_groth16_cuda_prove`.
///
/// The struct holds borrowed pointers only; it owns nothing.
/// NOTE(review): field order must match the native declaration, which is not
/// visible in this file.
#[repr(C)]
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
struct RawProverParams {
    /// C string holding the `public_path` location.
    pub public_path: *const c_char,
    /// C string holding the `proof_path` location.
    pub proof_path: *const c_char,
    /// Witness pointer, forwarded as supplied by the caller.
    pub witness: *const u8,
}
102
/// C-layout view of the setup parameters, passed by pointer to the native
/// prove and setup entry points.
///
/// The struct holds borrowed pointers only; it owns nothing.
/// NOTE(review): field order must match the native declaration, which is not
/// visible in this file.
#[repr(C)]
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
struct RawSetupParams {
    /// C string holding the `pcoeffs_path` location.
    pub pcoeffs_path: *const c_char,
    /// C string holding the `fres_path` location.
    pub fres_path: *const c_char,
    /// C string holding the `srs_path` location.
    pub srs_path: *const c_char,
}
110
// Native entry points; their definitions live outside this file.
//
// Both functions return a C string pointer that `ffi_wrap` interprets as
// "null means success, anything else is an error message to read and free".
extern "C" {
    /// Native prove entry point, taking the setup and prover parameter
    /// structs by pointer.
    #[cfg(feature = "cuda")]
    fn risc0_groth16_cuda_prove(
        setup: *const RawSetupParams,
        params: *const RawProverParams,
    ) -> *const c_char;

    /// Native setup entry point, taking the setup parameter struct by
    /// pointer. Only declared when both the `cuda` and `setup` features are on.
    #[cfg(all(feature = "cuda", feature = "setup"))]
    fn risc0_groth16_cuda_setup(params: *const RawSetupParams) -> *const c_char;
}
121
122#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
123fn ffi_wrap<F>(mut inner: F) -> Result<()>
124where
125 F: FnMut() -> *const c_char,
126{
127 extern "C" {
128 fn free(str: *const c_char);
129 }
130
131 let c_ptr = inner();
132 if c_ptr.is_null() {
133 Ok(())
134 } else {
135 let what = unsafe {
136 let msg = CStr::from_ptr(c_ptr)
137 .to_str()
138 .unwrap_or("Invalid error msg pointer")
139 .to_string();
140 free(c_ptr);
141 msg
142 };
143 Err(anyhow!(what))
144 }
145}
146
/// A filesystem path stored together with a NUL-terminated copy of its bytes.
///
/// The C-string copy is what gets handed to native code as `*const c_char`;
/// keeping it next to the path ties the pointer's lifetime to this value.
/// The bytes come from `OsStr::as_encoded_bytes`, whose encoding is
/// platform-specific.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub struct RawPath {
    /// The path exactly as supplied at construction.
    path: PathBuf,
    /// NUL-terminated copy of `path`'s encoded bytes.
    c_str: CString,
}

impl RawPath {
    /// Returns the path this value was created from.
    pub fn as_path(&self) -> &Path {
        &self.path
    }
}

impl TryFrom<&Path> for RawPath {
    type Error = NulError;

    /// Copies `value` and builds its C-string form.
    ///
    /// # Errors
    ///
    /// Returns [`NulError`] if the path contains a NUL byte.
    fn try_from(value: &Path) -> Result<Self, Self::Error> {
        // Build the C string first so a rejected path never pays for the
        // `PathBuf` copy.
        let c_str = CString::new(value.as_os_str().as_encoded_bytes())?;
        Ok(RawPath {
            path: value.to_path_buf(),
            c_str,
        })
    }
}

impl TryFrom<PathBuf> for RawPath {
    type Error = NulError;

    /// Takes ownership of `value` and builds its C-string form.
    ///
    /// # Errors
    ///
    /// Returns [`NulError`] if the path contains a NUL byte.
    fn try_from(value: PathBuf) -> Result<Self, Self::Error> {
        let c_str = CString::new(value.as_os_str().as_encoded_bytes())?;
        // Move the owned buffer in instead of borrowing it and cloning again.
        Ok(RawPath { path: value, c_str })
    }
}