Skip to main content

risc0_groth16_sys/
lib.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    ffi::{c_char, CStr, CString, NulError},
17    path::{Path, PathBuf},
18};
19
20use anyhow::{anyhow, Result};
21
22#[cfg(feature = "cuda")]
23pub use sppark::Error as SpparkError;
24
25pub struct SetupParams {
26    pub pcoeffs_path: RawPath,
27    pub fres_path: RawPath,
28    pub srs_path: RawPath,
29}
30
31impl SetupParams {
32    pub fn new(root_dir: &Path) -> anyhow::Result<Self> {
33        Ok(SetupParams {
34            pcoeffs_path: root_dir.join("preprocessed_coeffs.bin").try_into()?,
35            fres_path: root_dir.join("fuzzed_msm_results.bin").try_into()?,
36            srs_path: root_dir.join("stark_verify_final.zkey").try_into()?,
37        })
38    }
39}
40
/// Location of the witness-generation graph file.
///
/// Unlike the setup and prover parameters, the path is kept as a plain [`PathBuf`]
/// and is not converted to a C string here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WitnessParams {
    /// Normally `<root>/stark_verify_graph.bin` (see [`WitnessParams::new`]).
    pub graph_path: PathBuf,
}

impl WitnessParams {
    /// Builds witness parameters pointing at `stark_verify_graph.bin` inside `root_dir`.
    ///
    /// This never fails: the path is only joined, not validated or converted.
    pub fn new(root_dir: &Path) -> Self {
        WitnessParams {
            graph_path: root_dir.join("stark_verify_graph.bin"),
        }
    }
}
52
53pub struct ProverParams {
54    pub public_path: RawPath,
55    pub proof_path: RawPath,
56    pub witness: *const u8,
57}
58
59impl ProverParams {
60    pub fn new(root_dir: &Path, witness: *const u8) -> anyhow::Result<Self> {
61        Ok(Self {
62            public_path: root_dir.join("public.json").try_into()?,
63            proof_path: root_dir.join("proof.json").try_into()?,
64            witness,
65        })
66    }
67}
68
/// Runs the native CUDA Groth16 prover with the given prover and setup parameters.
///
/// Only compiled when the `cuda` feature is enabled.
///
/// # Errors
///
/// Returns the error message reported by the native prover, if it reports one.
#[cfg(feature = "cuda")]
pub fn prove(prover_params: &ProverParams, setup_params: &SetupParams) -> anyhow::Result<()> {
    let ProverParams {
        public_path,
        proof_path,
        witness,
    } = prover_params;
    let SetupParams {
        pcoeffs_path,
        fres_path,
        srs_path,
    } = setup_params;

    let raw_setup = RawSetupParams {
        pcoeffs_path: pcoeffs_path.c_str.as_ptr(),
        fres_path: fres_path.c_str.as_ptr(),
        srs_path: srs_path.c_str.as_ptr(),
    };
    let raw_prover = RawProverParams {
        public_path: public_path.c_str.as_ptr(),
        proof_path: proof_path.c_str.as_ptr(),
        witness: *witness,
    };

    // SAFETY: every path pointer borrows a `CString` owned by the borrowed arguments, which
    // outlive this call. The witness pointer is passed through unchecked, so its validity
    // is the caller's responsibility.
    ffi_wrap(|| unsafe { risc0_groth16_cuda_prove(&raw_setup, &raw_prover) })
}
84
/// Runs the native CUDA setup routine for the artifacts described by `params`.
///
/// Only compiled when both the `cuda` and `setup` features are enabled.
///
/// # Errors
///
/// Returns the error message reported by the native setup routine, if it reports one.
#[cfg(all(feature = "cuda", feature = "setup"))]
pub fn setup(params: &SetupParams) -> anyhow::Result<()> {
    let SetupParams {
        pcoeffs_path,
        fres_path,
        srs_path,
    } = params;
    let raw = RawSetupParams {
        pcoeffs_path: pcoeffs_path.c_str.as_ptr(),
        fres_path: fres_path.c_str.as_ptr(),
        srs_path: srs_path.c_str.as_ptr(),
    };

    // SAFETY: the three pointers borrow `CString`s owned by `params`, which outlives this call.
    ffi_wrap(|| unsafe { risc0_groth16_cuda_setup(&raw) })
}
94
/// C-ABI view of [`ProverParams`], passed by pointer to `risc0_groth16_cuda_prove`.
///
/// `#[repr(C)]` fixes the field order and layout, which must match the struct expected
/// by the native code; do not reorder fields.
/// NOTE(review): the native definition is not visible here — confirm it stays in sync.
// Only constructed by `prove`, which exists only with the `cuda` feature.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
#[repr(C)]
struct RawProverParams {
    /// NUL-terminated path, borrowed from `ProverParams::public_path`.
    pub public_path: *const c_char,
    /// NUL-terminated path, borrowed from `ProverParams::proof_path`.
    pub proof_path: *const c_char,
    /// Copied unchanged from `ProverParams::witness`.
    pub witness: *const u8,
}
102
/// C-ABI view of [`SetupParams`], passed by pointer to the native prove/setup entry points.
///
/// `#[repr(C)]` fixes the field order and layout, which must match the struct expected
/// by the native code; do not reorder fields.
/// NOTE(review): the native definition is not visible here — confirm it stays in sync.
// Only constructed by `prove`/`setup`, which exist only with the `cuda` feature.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
#[repr(C)]
struct RawSetupParams {
    /// NUL-terminated path, borrowed from `SetupParams::pcoeffs_path`.
    pub pcoeffs_path: *const c_char,
    /// NUL-terminated path, borrowed from `SetupParams::fres_path`.
    pub fres_path: *const c_char,
    /// NUL-terminated path, borrowed from `SetupParams::srs_path`.
    pub srs_path: *const c_char,
}
110
// Native entry points. Their `*const c_char` results are interpreted by `ffi_wrap`:
// null means success; otherwise the pointer is a NUL-terminated error message that is
// copied and then released with the C library's `free`.
extern "C" {
    /// Native prove entry point; called from `prove`.
    #[cfg(feature = "cuda")]
    fn risc0_groth16_cuda_prove(
        setup: *const RawSetupParams,
        params: *const RawProverParams,
    ) -> *const c_char;

    /// Native setup entry point; called from `setup` (requires `cuda` and `setup` features).
    #[cfg(all(feature = "cuda", feature = "setup"))]
    fn risc0_groth16_cuda_setup(params: *const RawSetupParams) -> *const c_char;
}
121
122#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
123fn ffi_wrap<F>(mut inner: F) -> Result<()>
124where
125    F: FnMut() -> *const c_char,
126{
127    extern "C" {
128        fn free(str: *const c_char);
129    }
130
131    let c_ptr = inner();
132    if c_ptr.is_null() {
133        Ok(())
134    } else {
135        let what = unsafe {
136            let msg = CStr::from_ptr(c_ptr)
137                .to_str()
138                .unwrap_or("Invalid error msg pointer")
139                .to_string();
140            free(c_ptr);
141            msg
142        };
143        Err(anyhow!(what))
144    }
145}
146
/// A filesystem path paired with a NUL-terminated C-string copy of its bytes.
///
/// The C string is built once, at conversion time, so it can be handed to native code
/// repeatedly without further allocation.
#[derive(Debug, Clone)]
pub struct RawPath {
    path: PathBuf,
    c_str: CString,
}

impl RawPath {
    /// Returns the original path.
    pub fn as_path(&self) -> &Path {
        &self.path
    }

    /// Returns the NUL-terminated C string built from the path's encoded bytes.
    pub fn as_c_str(&self) -> &CStr {
        &self.c_str
    }
}

impl TryFrom<&Path> for RawPath {
    type Error = NulError;

    /// Copies `value` and builds its C-string form.
    ///
    /// # Errors
    ///
    /// Returns [`NulError`] if the path contains an interior NUL byte.
    fn try_from(value: &Path) -> Result<Self, Self::Error> {
        RawPath::try_from(value.to_path_buf())
    }
}

impl TryFrom<PathBuf> for RawPath {
    type Error = NulError;

    /// Takes ownership of `value` (no second copy of the path) and builds its C-string form.
    ///
    /// # Errors
    ///
    /// Returns [`NulError`] if the path contains an interior NUL byte.
    fn try_from(value: PathBuf) -> Result<Self, Self::Error> {
        // `as_encoded_bytes` yields the platform-specific encoding of the OS string
        // (see `OsStr::as_encoded_bytes`); `CString::new` rejects interior NULs.
        let c_str = CString::new(value.as_os_str().as_encoded_bytes())?;
        Ok(RawPath { path: value, c_str })
    }
}