risc0_zkvm/host/api/
mod.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub(crate) mod client;
16pub(crate) mod convert;
17#[cfg(feature = "prove")]
18pub(crate) mod server;
19#[cfg(test)]
20#[cfg(feature = "prove")]
21mod tests;
22
23use std::{
24    cell::RefCell,
25    io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write},
26    net::{TcpListener, TcpStream},
27    path::{Path, PathBuf},
28    process::{Child, Command},
29    sync::{
30        atomic::{AtomicBool, Ordering},
31        mpsc::channel,
32        Arc, Mutex,
33    },
34    thread,
35    time::Duration,
36};
37
38use anyhow::{anyhow, bail, Context, Result};
39use bytes::{Buf, BufMut, Bytes};
40use lazy_regex::regex_captures;
41use prost::Message;
42use semver::Version;
43
44use crate::{get_version, ExitCode, Journal, ReceiptClaim};
45
46mod pb {
47    pub(crate) mod api {
48        pub use crate::host::protos::api::*;
49    }
50    pub(crate) mod base {
51        pub use crate::host::protos::base::*;
52    }
53    pub(crate) mod core {
54        pub use crate::host::protos::core::*;
55    }
56}
57
/// How long a spawned server process is given to connect back before the attempt is abandoned.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
59
60trait RootMessage: Message {}
61
62pub trait Connection {
63    fn stream(&mut self) -> &mut TcpStream;
64    fn close(&mut self) -> Result<i32>;
65}
66
67#[derive(Clone)]
68pub struct ConnectionWrapper {
69    inner: Arc<Mutex<dyn Connection + Send>>,
70}
71
thread_local! {
    /// Per-thread scratch buffer reused for encoding outgoing frames and reading incoming
    /// ones, so each message does not need a fresh allocation.
    static LOCAL_BUF: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
}
75
76impl RootMessage for pb::api::HelloRequest {}
77impl RootMessage for pb::api::HelloReply {}
78impl RootMessage for pb::api::ServerRequest {}
79impl RootMessage for pb::api::ServerReply {}
80impl RootMessage for pb::api::GenericReply {}
81impl RootMessage for pb::api::OnIoReply {}
82impl RootMessage for pb::api::ProveKeccakReply {}
83impl RootMessage for pb::api::ProveSegmentReply {}
84impl RootMessage for pb::api::LiftRequest {}
85impl RootMessage for pb::api::LiftReply {}
86impl RootMessage for pb::api::JoinRequest {}
87impl RootMessage for pb::api::JoinReply {}
88impl RootMessage for pb::api::ResolveRequest {}
89impl RootMessage for pb::api::ResolveReply {}
90impl RootMessage for pb::api::IdentityP254Request {}
91impl RootMessage for pb::api::IdentityP254Reply {}
92impl RootMessage for pb::api::CompressRequest {}
93impl RootMessage for pb::api::CompressReply {}
94impl RootMessage for pb::api::UnionRequest {}
95impl RootMessage for pb::api::UnionReply {}
96
/// Builds the I/O error reported when the connection mutex cannot be acquired.
///
/// `Mutex::lock` only fails when the mutex is poisoned (a previous holder panicked). That state
/// is permanent, so the error kind is `Other` rather than `WouldBlock`, which would suggest to
/// callers that retrying could succeed.
fn lock_err() -> IoError {
    IoError::other("Failed to lock connection mutex")
}
100
101impl ConnectionWrapper {
102    fn new(inner: Arc<Mutex<dyn Connection + Send>>) -> Self {
103        Self { inner }
104    }
105
106    fn send<T: RootMessage>(&mut self, msg: T) -> Result<()> {
107        let mut guard = self.inner.lock().map_err(|_| lock_err())?;
108        self.inner_send(guard.stream(), msg)
109    }
110
111    fn recv<T: Default + RootMessage>(&mut self) -> Result<T> {
112        let mut guard = self.inner.lock().map_err(|_| lock_err())?;
113        self.inner_recv(guard.stream())
114    }
115
116    #[cfg(feature = "prove")]
117    fn send_recv<S: RootMessage, R: Default + RootMessage>(&mut self, msg: S) -> Result<R> {
118        let mut guard = self.inner.lock().map_err(|_| lock_err())?;
119        let stream = guard.stream();
120        self.inner_send(stream, msg)?;
121        self.inner_recv(stream)
122    }
123
124    fn close(&mut self) -> Result<i32> {
125        self.inner.lock().map_err(|_| lock_err())?.close()
126    }
127
128    fn inner_send<T: RootMessage>(&self, stream: &mut TcpStream, msg: T) -> Result<()> {
129        let len = msg.encoded_len();
130        LOCAL_BUF.with_borrow_mut(|buf| {
131            buf.clear();
132            buf.put_u32_le(len as u32);
133            msg.encode(buf)?;
134            Ok(stream.write_all(buf)?)
135        })
136    }
137
138    fn inner_recv<T: Default + RootMessage>(&self, stream: &mut TcpStream) -> Result<T> {
139        LOCAL_BUF.with_borrow_mut(|buf| {
140            buf.resize(4, 0);
141            stream.read_exact(buf).context("rx len failed")?;
142            let len = buf.as_slice().get_u32_le() as usize;
143            buf.resize(len, 0);
144            stream.read_exact(buf).context("rx payload failed")?;
145            T::decode(buf.as_slice()).context("rx decode failed")
146        })
147    }
148}
149
150/// Connects a zkVM client and server
151pub trait Connector {
152    /// Create a client-server connection
153    fn connect(&self) -> Result<ConnectionWrapper>;
154}
155
/// [Connector] that launches the r0vm server binary as a child process and waits for it to
/// connect back to a local listener.
struct ParentProcessConnector {
    /// Path of the server binary to launch.
    server_path: PathBuf,
    /// Listener whose port is passed to the server via `--port`.
    listener: TcpListener,
}
160
161impl ParentProcessConnector {
162    pub fn new<P: AsRef<Path>>(server_path: P) -> Result<Self> {
163        Ok(Self {
164            server_path: server_path.as_ref().to_path_buf(),
165            listener: TcpListener::bind("127.0.0.1:0")?,
166        })
167    }
168
169    pub fn new_narrow_version<P: AsRef<Path>>(server_path: P) -> Result<Self> {
170        // Check the version of the client and server to ensure that they're compatible
171        let client_version = get_version().map_err(|err| anyhow!(err))?;
172        let server_version = get_server_version(&server_path)?;
173        if !client::Compat::Narrow.check(&client_version, &server_version) {
174            let server_suggestion = if client_version.pre.is_empty() {
175                format!(
176                    "1. Install r0vm server version {}.{}\n",
177                    client_version.major, client_version.minor
178                )
179            } else {
180                format!("1. Your risc0 dependencies are using a pre-released version {client_version}.\n   \
181                    If you encounter this error message when running code on the risc0 codebase, you must\n   \
182                    either run the command `git checkout origin/release-{}.{}` to checkout the version of the\n   \
183                    risc0 code that is compatible with your server or build the r0vm server from source\n   \
184                    https://github.com/risc0/risc0/blob/main/CONTRIBUTING.md\n", server_version.major, server_version.minor
185                )
186            };
187            let msg = format!(
188                "Your installation of the r0vm server is not compatible with your host's risc0-zkvm crate.\n\
189                Do one of the following to fix this issue:\n\n\
190                {server_suggestion}\
191                2. Change the risc0-zkvm and risc0-build dependencies in your project to {}.{}\n\n\
192                risc0-zkvm version: {client_version}\n\
193                r0vm server version: {server_version}", server_version.major, server_version.minor
194            );
195            tracing::warn!("{msg}");
196            bail!(msg);
197        }
198
199        Self::new(server_path)
200    }
201
202    fn spawn_fail(&self) -> String {
203        format!(
204            "Could not launch zkvm: \"{}\". \n
205            Use `cargo binstall cargo-risczero` to install the latest zkvm.",
206            self.server_path.to_string_lossy()
207        )
208    }
209}
210
211fn get_server_version<P: AsRef<Path>>(server_path: P) -> Result<Version> {
212    let output = Command::new(server_path.as_ref().as_os_str())
213        .arg("--version")
214        .output()?;
215    let cmd_output = String::from_utf8(output.stdout)?;
216    let (_, version_str) = regex_captures!(r".* (.*)\n$", &cmd_output)
217        .ok_or_else(|| anyhow!("failed to parse server version number"))?;
218    Version::parse(version_str).map_err(|e| anyhow!(e))
219}
220
221impl Connector for ParentProcessConnector {
222    fn connect(&self) -> Result<ConnectionWrapper> {
223        let addr = self.listener.local_addr()?;
224        let child = Command::new(&self.server_path)
225            .arg("--port")
226            .arg(addr.port().to_string())
227            .spawn()
228            .with_context(|| self.spawn_fail())?;
229
230        let shutdown = Arc::new(AtomicBool::new(false));
231        let server_shutdown = shutdown.clone();
232        let (tx, rx) = channel();
233        let listener = self.listener.try_clone()?;
234        let handle = thread::spawn(move || {
235            let stream = listener.accept();
236            if server_shutdown.load(Ordering::Relaxed) {
237                return;
238            }
239            if let Ok((stream, _addr)) = stream {
240                tx.send(stream).unwrap();
241            }
242        });
243
244        let stream = rx.recv_timeout(CONNECT_TIMEOUT);
245        let stream = stream.inspect_err(|_| {
246            shutdown.store(true, Ordering::Relaxed);
247            let _ = TcpStream::connect(addr);
248            handle.join().unwrap();
249        })?;
250
251        Ok(ConnectionWrapper::new(Arc::new(Mutex::new(
252            ParentProcessConnection::new(child, stream),
253        ))))
254    }
255}
256
/// [Connector] that dials an already-running server over TCP.
#[cfg(feature = "prove")]
struct TcpConnector {
    /// Address handed to `TcpStream::connect`.
    addr: String,
}
261
#[cfg(feature = "prove")]
impl TcpConnector {
    /// Creates a connector that will dial `addr` on each `connect` call.
    pub(crate) fn new(addr: &str) -> Self {
        let addr = addr.to_owned();
        Self { addr }
    }
}
270
#[cfg(feature = "prove")]
impl Connector for TcpConnector {
    /// Opens a fresh TCP stream to the configured address.
    fn connect(&self) -> Result<ConnectionWrapper> {
        tracing::debug!("connect");
        let connection = TcpConnection::new(TcpStream::connect(&self.addr)?);
        Ok(ConnectionWrapper::new(Arc::new(Mutex::new(connection))))
    }
}
281
/// Connection to a server running as a child process of this one.
struct ParentProcessConnection {
    /// The spawned server process; waited on when the connection is closed.
    child: Child,
    /// Stream the server connected back on.
    stream: TcpStream,
}
286
/// Connection to a server reached over plain TCP, with no associated process.
#[cfg(feature = "prove")]
struct TcpConnection {
    /// The connected stream.
    stream: TcpStream,
}
291
292impl ParentProcessConnection {
293    pub fn new(child: Child, stream: TcpStream) -> Self {
294        Self { child, stream }
295    }
296}
297
298impl Connection for ParentProcessConnection {
299    fn stream(&mut self) -> &mut TcpStream {
300        &mut self.stream
301    }
302
303    fn close(&mut self) -> Result<i32> {
304        let status = self.child.wait()?;
305        Ok(status.code().unwrap_or_default())
306    }
307}
308
#[cfg(feature = "prove")]
impl TcpConnection {
    /// Wraps an already-connected stream.
    pub fn new(stream: TcpStream) -> Self {
        TcpConnection { stream }
    }
}
315
#[cfg(feature = "prove")]
impl Connection for TcpConnection {
    /// There is no process to reap for a plain TCP connection, so this always reports 0.
    fn close(&mut self) -> Result<i32> {
        Ok(0)
    }

    fn stream(&mut self) -> &mut TcpStream {
        &mut self.stream
    }
}
326
327fn malformed_err(field: &str) -> anyhow::Error {
328    anyhow!("Malformed error: {field}")
329}
330
331impl pb::api::Asset {
332    fn as_bytes(&self) -> Result<Bytes> {
333        let bytes = match self
334            .kind
335            .as_ref()
336            .ok_or_else(|| malformed_err("Asset.kind"))?
337        {
338            pb::api::asset::Kind::Inline(bytes) => bytes.clone(),
339            pb::api::asset::Kind::Path(path) => std::fs::read(path)?,
340            pb::api::asset::Kind::Redis(_) => bail!("as_bytes not supported for redis"),
341        };
342        Ok(bytes.into())
343    }
344}
345
346/// Determines the format of an asset.
347#[derive(Clone)]
348pub enum Asset {
349    /// The asset is encoded inline.
350    Inline(Bytes),
351
352    /// The asset is written to disk.
353    Path(PathBuf),
354
355    /// The asset is written to redis.
356    Redis(String),
357}
358
/// Settings for storing an asset in redis via `AssetRequest::Redis`.
#[derive(Clone)]
pub struct RedisParams {
    /// URL of the redis instance to connect to.
    pub url: String,

    /// Key under which the asset is written.
    pub key: String,

    /// Time to live (expiration) applied to the key being set.
    pub ttl: u64,
}
371
372/// Determines the format of an asset request.
373#[derive(Clone)]
374pub enum AssetRequest {
375    /// The asset is encoded inline.
376    Inline,
377
378    /// The asset is written to disk.
379    Path(PathBuf),
380
381    /// The asset is written to redis.
382    Redis(RedisParams),
383}
384
385/// Provides information about the result of execution.
386#[derive(Clone, Debug)]
387#[non_exhaustive]
388pub struct SessionInfo {
389    /// The number of user cycles for each segment.
390    pub segments: Vec<SegmentInfo>,
391
392    /// The data publicly committed by the guest program.
393    pub journal: Journal,
394
395    /// The [ExitCode] of the session.
396    pub exit_code: ExitCode,
397
398    /// The [ReceiptClaim] associated with the executed session. This receipt claim is what will be
399    /// proven if this session is passed to the Prover.
400    pub receipt_claim: Option<ReceiptClaim>,
401}
402
403impl SessionInfo {
404    /// The total number of user cycles across all segments, without any
405    /// overhead for continuations or po2 padding.
406    pub fn cycles(&self) -> u64 {
407        self.segments.iter().map(|s| s.cycles as u64).sum()
408    }
409}
410
/// Summary of a single execution segment.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct SegmentInfo {
    /// Size used for proving this segment, as a power of two (in cycles).
    pub po2: u32,

    /// User cycles executed in this segment, excluding continuation overhead and po2 padding.
    pub cycles: u32,
}
422
423impl Asset {
424    /// Return the bytes for this asset.
425    pub fn as_bytes(&self) -> Result<Bytes> {
426        Ok(match self {
427            Asset::Inline(bytes) => bytes.clone(),
428            Asset::Path(path) => std::fs::read(path)?.into(),
429            Asset::Redis(_) => bail!("as_bytes not supported for Asset::Redis"),
430        })
431    }
432}
433
434fn invalid_path() -> anyhow::Error {
435    anyhow::Error::msg("Path must be UTF-8")
436}
437
438fn path_to_string<P: AsRef<Path>>(path: P) -> Result<String> {
439    Ok(path.as_ref().to_str().ok_or_else(invalid_path)?.to_string())
440}