use std::time::Instant;
use blstrs::{Bls12, Scalar};
use rand_core::OsRng;
use webgpu_groth16::gpu::GpuContext;
use webgpu_groth16::{bellman, prover};
struct RepeatedSquaringCircuit<S: ff::PrimeField> {
x: Option<S>,
num_squarings: usize,
}
impl<S: ff::PrimeField> bellman::Circuit<S> for RepeatedSquaringCircuit<S> {
fn synthesize<CS: bellman::ConstraintSystem<S>>(
self,
cs: &mut CS,
) -> Result<(), bellman::SynthesisError> {
let mut cur_val = self.x;
let mut cur_var = cs.alloc(
|| "x",
|| cur_val.ok_or(bellman::SynthesisError::AssignmentMissing),
)?;
for i in 0..self.num_squarings {
let next_val = cur_val.map(|v| v.square());
let next_var = if i == self.num_squarings - 1 {
cs.alloc_input(
|| format!("sq_{i}"),
|| {
next_val
.ok_or(bellman::SynthesisError::AssignmentMissing)
},
)?
} else {
cs.alloc(
|| format!("sq_{i}"),
|| {
next_val
.ok_or(bellman::SynthesisError::AssignmentMissing)
},
)?
};
cs.enforce(
|| format!("sq_constraint_{i}"),
|lc| lc + cur_var,
|lc| lc + cur_var,
|lc| lc + next_var,
);
cur_val = next_val;
cur_var = next_var;
}
Ok(())
}
}
fn main() {
let args: Vec<String> = std::env::args().collect();
let num_squarings: usize =
args.get(1).and_then(|s| s.parse().ok()).unwrap_or(10_000);
let iterations: usize =
args.get(2).and_then(|s| s.parse().ok()).unwrap_or(5);
eprintln!("=== Groth16 GPU Profiling ===");
#[cfg(feature = "profiling")]
eprintln!(
" GPU profiling enabled — trace will be written to profile.json"
);
#[cfg(not(feature = "profiling"))]
eprintln!(
" Wall-clock timing only (enable 'profiling' feature for GPU \
breakdown)"
);
eprintln!(" constraints: {num_squarings}");
eprintln!(" iterations: {iterations}");
let t_setup = Instant::now();
let mut rng = OsRng;
let setup_circuit = RepeatedSquaringCircuit::<Scalar> {
x: None,
num_squarings,
};
let params = bellman::groth16::generate_random_parameters::<Bls12, _, _>(
setup_circuit,
&mut rng,
)
.expect("param gen failed");
let ppk = prover::prepare_proving_key::<Bls12, Bls12>(¶ms);
eprintln!(" setup: {:?}", t_setup.elapsed());
let t_gpu = Instant::now();
let rt = tokio::runtime::Runtime::new().expect("tokio runtime");
let gpu = rt
.block_on(GpuContext::<Bls12>::new())
.expect("gpu init failed");
eprintln!(" gpu init: {:?}", t_gpu.elapsed());
eprintln!(" warmup...");
{
let circuit = RepeatedSquaringCircuit::<Scalar> {
x: Some(Scalar::from(3u64)),
num_squarings,
};
rt.block_on(prover::create_proof::<Bls12, Bls12, _, _>(
circuit, ¶ms, &ppk, &gpu, &mut rng,
))
.expect("warmup proof failed");
#[cfg(feature = "profiling")]
{
gpu.end_profiler_frame();
let _ = gpu.process_profiler_results();
}
}
eprintln!(" profiling {iterations} iterations...");
let x = Scalar::from(3u64);
let t_total = Instant::now();
#[cfg(feature = "profiling")]
let mut all_profiling_data = Vec::new();
for i in 0..iterations {
let t_iter = Instant::now();
let circuit = RepeatedSquaringCircuit::<Scalar> {
x: Some(x),
num_squarings,
};
let _proof = rt
.block_on(prover::create_proof::<Bls12, Bls12, _, _>(
circuit, ¶ms, &ppk, &gpu, &mut rng,
))
.expect("proof failed");
#[cfg(feature = "profiling")]
{
gpu.end_profiler_frame();
if let Some(results) = gpu.process_profiler_results() {
eprintln!(
" iter {}: {:?} (GPU breakdown below)",
i + 1,
t_iter.elapsed()
);
print_gpu_results(&results, 2);
all_profiling_data.extend(results);
} else {
eprintln!(
" iter {}: {:?} (GPU results pending)",
i + 1,
t_iter.elapsed()
);
}
}
#[cfg(not(feature = "profiling"))]
eprintln!(" iter {}: {:?}", i + 1, t_iter.elapsed());
}
eprintln!(" total: {:?} ({iterations} proofs)", t_total.elapsed());
eprintln!(
" avg: {:?}/proof",
t_total.elapsed() / iterations as u32
);
#[cfg(feature = "profiling")]
{
let trace_path = std::path::Path::new("profile.json");
write_chrometrace(trace_path, &all_profiling_data)
.expect("failed to write chrome trace");
eprintln!();
eprintln!(" Trace written to {}", trace_path.display());
eprintln!(" Open in chrome://tracing or https://ui.perfetto.dev");
}
}
#[cfg(feature = "profiling")]
fn effective_time(
r: &wgpu_profiler::GpuTimerQueryResult,
) -> Option<std::ops::Range<f64>> {
if !r.nested_queries.is_empty() {
let mut start = f64::MAX;
let mut end = f64::MIN;
for child in &r.nested_queries {
if let Some(ct) = effective_time(child) {
start = start.min(ct.start);
end = end.max(ct.end);
}
}
if start < end { Some(start..end) } else { None }
} else if let Some(ref t) = r.time {
let dur = t.end - t.start;
if dur > 0.0 && t.start > 0.0 {
Some(t.clone())
} else {
None
}
} else {
None
}
}
#[cfg(feature = "profiling")]
fn print_gpu_results(
results: &[wgpu_profiler::GpuTimerQueryResult],
indent: usize,
) {
let pad = " ".repeat(indent);
for r in results {
if let Some(time) = effective_time(r) {
let duration_s = time.end - time.start;
if duration_s < 0.001 {
eprintln!(
"{pad}{}: {:.1} us",
r.label,
duration_s * 1_000_000.0
);
} else {
eprintln!("{pad}{}: {:.2} ms", r.label, duration_s * 1_000.0);
}
} else {
eprintln!("{pad}{}: (no timing data)", r.label);
}
if !r.nested_queries.is_empty() {
print_gpu_results(&r.nested_queries, indent + 2);
}
}
}
#[cfg(feature = "profiling")]
fn find_min_time(results: &[wgpu_profiler::GpuTimerQueryResult]) -> f64 {
let mut min_t = f64::MAX;
for r in results {
if let Some(t) = effective_time(r) {
min_t = min_t.min(t.start);
}
min_t = min_t.min(find_min_time(&r.nested_queries));
}
min_t
}
#[cfg(feature = "profiling")]
fn write_chrometrace(
path: &std::path::Path,
results: &[wgpu_profiler::GpuTimerQueryResult],
) -> std::io::Result<()> {
use std::io::Write;
let t0 = find_min_time(results);
let mut file = std::io::BufWriter::new(std::fs::File::create(path)?);
write!(file, "{{\n\"traceEvents\": [\n")?;
let mut first = true;
for r in results {
write_trace_event(&mut file, r, &mut first, t0)?;
}
write!(file, "\n]\n}}\n")?;
Ok(())
}
#[cfg(feature = "profiling")]
fn tid_int(r: &wgpu_profiler::GpuTimerQueryResult) -> u64 {
let raw = format!("{:?}", r.tid);
raw.chars()
.filter(|c| c.is_ascii_digit())
.collect::<String>()
.parse::<u64>()
.unwrap_or(1)
}
#[cfg(feature = "profiling")]
fn write_trace_event(
w: &mut impl std::io::Write,
r: &wgpu_profiler::GpuTimerQueryResult,
first: &mut bool,
t0: f64,
) -> std::io::Result<()> {
let has_children = !r.nested_queries.is_empty();
if let Some(time) = effective_time(r) {
let ts_us = (time.start - t0) * 1_000_000.0;
let pid = r.pid;
let tid = tid_int(r);
if has_children {
let b_us = ts_us - 0.01;
let comma = if *first { "" } else { ",\n" };
write!(
w,
"{comma}{{ \"pid\":{pid}, \"tid\":{tid}, \"ts\":{b_us}, \
\"ph\":\"B\", \"name\":\"{}\" }}",
r.label,
)?;
*first = false;
for child in &r.nested_queries {
write_trace_event(w, child, first, t0)?;
}
let e_us = (time.end - t0) * 1_000_000.0 + 0.01;
write!(
w,
",\n{{ \"pid\":{pid}, \"tid\":{tid}, \"ts\":{e_us}, \
\"ph\":\"E\", \"name\":\"{}\" }}",
r.label,
)?;
} else {
let dur_us = (time.end - time.start) * 1_000_000.0;
let comma = if *first { "" } else { ",\n" };
write!(
w,
"{comma}{{ \"pid\":{pid}, \"tid\":{tid}, \"ts\":{ts_us}, \
\"dur\":{dur_us}, \"ph\":\"X\", \"name\":\"{}\" }}",
r.label,
)?;
*first = false;
}
} else {
for child in &r.nested_queries {
write_trace_event(w, child, first, t0)?;
}
}
Ok(())
}