risc0_zkvm/host/client/prove/default.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    io::{Read, Write},
17    os::{fd::OwnedFd, unix::net::UnixStream},
18    path::Path,
19    process::{Child, Command},
20    sync::Arc,
21};
22
23use anyhow::{bail, Context as _, Result};
24
25use crate::{
26    rpc::{JobInfo, JobStatus, ProofRequest},
27    ExecutorEnv, ProveInfo, Receipt, SessionInfo, SessionStats, VerifierContext,
28};
29
30use super::{Executor, Prover, ProverOpts};
31
/// A [Prover] that runs proof workloads via a locally spawned `r0vm` process.
///
/// [DefaultProver::new] launches `r0vm --rpc` with one end of a Unix socket pair as its stdin,
/// and RPC messages are exchanged over the other end. Dropping the value shuts the socket down
/// and waits for the child process to exit.
pub struct DefaultProver {
    /// The spawned `r0vm` child process.
    child: Child,
    /// Parent end of the socket pair whose other end is the child's stdin.
    socket: UnixStream,
}
37
38/// An implementation of a [Prover] that runs proof workloads via local `r0vm` cluster.
39impl DefaultProver {
40    /// Construct a [DefaultProver].
41    pub fn new<P: AsRef<Path>>(r0vm_path: P) -> Result<Self> {
42        let r0vm_path = r0vm_path.as_ref();
43
44        let (socket, child_socket) = UnixStream::pair()?;
45        let child_fd: OwnedFd = child_socket.into();
46        let mut cmd = Command::new(r0vm_path);
47        cmd.stdin(child_fd).arg("--rpc");
48        if let Ok(num_gpus) = std::env::var("RISC0_DEFAULT_PROVER_NUM_GPUS") {
49            cmd.arg("--num-gpus").arg(num_gpus);
50        }
51        let child = cmd.spawn().with_context(|| spawn_fail(r0vm_path))?;
52
53        Ok(Self { child, socket })
54    }
55
56    /// TODO
57    pub fn stop(&mut self) -> Result<()> {
58        self.socket.shutdown(std::net::Shutdown::Both)?;
59        self.child.wait()?;
60
61        Ok(())
62    }
63}
64
65impl Drop for DefaultProver {
66    fn drop(&mut self) {
67        if let Err(error) = self.stop() {
68            tracing::warn!("error stopping r0vm: {error}");
69        }
70    }
71}
72
/// Builds the error-context message used when launching the `r0vm` binary at `path` fails.
///
/// Non-UTF-8 path bytes are rendered lossily.
fn spawn_fail(path: &Path) -> String {
    let shown = path.to_string_lossy();
    format!("Could not launch \"{shown}\".")
}
76
77impl Prover for DefaultProver {
78    fn get_name(&self) -> String {
79        "default".to_string()
80    }
81
82    fn prove_with_ctx(
83        &self,
84        env: ExecutorEnv<'_>,
85        _ctx: &VerifierContext,
86        elf: &[u8],
87        _opts: &ProverOpts,
88    ) -> Result<ProveInfo> {
89        let proof_request = ProofRequest {
90            binary: elf.to_vec(),
91            input: env.input,
92            assumptions: env.assumptions.borrow().0.clone(),
93            segment_limit_po2: env.segment_limit_po2,
94        };
95
96        let mut buf = vec![0u8; 4];
97        bincode::serialize_into(&mut buf, &proof_request)
98            .context("error serializing RPC header")?;
99        let body_len = buf.len() as u32 - 4;
100        bincode::serialize_into(&mut buf[0..4], &body_len).context("error serializing RPC body")?;
101        let mut socket = &self.socket;
102
103        socket
104            .write_all(&buf)
105            .context("error sending RPC message")?;
106
107        let mut buf = vec![0u8; 4];
108        socket
109            .read_exact(&mut buf)
110            .context("error receiving RPC header")?;
111        let body_len: u32 = bincode::deserialize(&buf).context("error deserializing RPC header")?;
112        let mut buf = vec![0u8; body_len as usize];
113        socket
114            .read_exact(&mut buf)
115            .context("error receiving RPC body")?;
116        let job_info: JobInfo =
117            bincode::deserialize(&buf).context("error deserializing RPC body")?;
118
119        tracing::info!("Elapsed time: {:?}", job_info.elapsed_time);
120
121        let prove_info = match job_info.status {
122            JobStatus::Running(progress) => bail!("Job is still running: {progress}"),
123            JobStatus::Succeeded(result) => ProveInfo {
124                receipt: Arc::into_inner(result.receipt).unwrap(),
125                work_receipt: None, // TODO(povw): implement PoVW here
126                stats: SessionStats {
127                    segments: result.session.segment_count,
128                    total_cycles: result.session.total_cycles,
129                    user_cycles: result.session.user_cycles,
130                    paging_cycles: 0,
131                    reserved_cycles: 0,
132                },
133            },
134            JobStatus::Failed(err) => bail!(format!("Task error: {err:?}")),
135            JobStatus::TimedOut => bail!("TimedOut"),
136            JobStatus::Aborted => bail!("Aborted"),
137        };
138
139        Ok(prove_info)
140    }
141
142    fn compress(&self, _opts: &ProverOpts, _receipt: &Receipt) -> Result<Receipt> {
143        unimplemented!()
144    }
145}
146
147impl Executor for DefaultProver {
148    fn execute(&self, _env: ExecutorEnv<'_>, _elf: &[u8]) -> Result<SessionInfo> {
149        todo!()
150    }
151}