risc0_zkvm/host/api/
mod.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub(crate) mod client;
16pub(crate) mod convert;
17#[cfg(feature = "prove")]
18pub(crate) mod server;
19#[cfg(test)]
20#[cfg(feature = "prove")]
21mod tests;
22
23use std::{
24    cell::RefCell,
25    io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write},
26    net::{TcpListener, TcpStream},
27    path::{Path, PathBuf},
28    process::{Child, Command},
29    sync::{
30        atomic::{AtomicBool, Ordering},
31        mpsc::channel,
32        Arc, Mutex,
33    },
34    thread,
35    time::Duration,
36};
37
38use anyhow::{anyhow, bail, Context, Result};
39use bytes::{Buf, BufMut, Bytes};
40use lazy_regex::regex_captures;
41use prost::Message;
42use semver::Version;
43
44use crate::{get_version, ExitCode, Journal, ReceiptClaim};
45
46mod pb {
47    pub(crate) mod api {
48        pub use crate::host::protos::api::*;
49    }
50    pub(crate) mod base {
51        pub use crate::host::protos::base::*;
52    }
53    pub(crate) mod core {
54        pub use crate::host::protos::core::*;
55    }
56}
57
/// How long `ParentProcessConnector::connect` waits for the spawned server to
/// connect back before giving up.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
59
60trait RootMessage: Message {}
61
62pub trait Connection {
63    fn stream(&mut self) -> &mut TcpStream;
64    fn close(&mut self) -> Result<i32>;
65}
66
67#[derive(Clone)]
68pub struct ConnectionWrapper {
69    inner: Arc<Mutex<dyn Connection + Send>>,
70}
71
thread_local! {
    /// Per-thread scratch buffer reused by the frame encoder/decoder so that each
    /// send or receive does not need a fresh allocation.
    static LOCAL_BUF: RefCell<Vec<u8>> = const { RefCell::new(Vec::<u8>::new()) };
}
75
76impl RootMessage for pb::api::HelloRequest {}
77impl RootMessage for pb::api::HelloReply {}
78impl RootMessage for pb::api::ServerRequest {}
79impl RootMessage for pb::api::ServerReply {}
80impl RootMessage for pb::api::GenericReply {}
81impl RootMessage for pb::api::OnIoReply {}
82impl RootMessage for pb::api::ProveKeccakReply {}
83impl RootMessage for pb::api::ProveSegmentReply {}
84impl RootMessage for pb::api::ProveZkrReply {}
85impl RootMessage for pb::api::LiftRequest {}
86impl RootMessage for pb::api::LiftReply {}
87impl RootMessage for pb::api::JoinRequest {}
88impl RootMessage for pb::api::JoinReply {}
89impl RootMessage for pb::api::ResolveRequest {}
90impl RootMessage for pb::api::ResolveReply {}
91impl RootMessage for pb::api::IdentityP254Request {}
92impl RootMessage for pb::api::IdentityP254Reply {}
93impl RootMessage for pb::api::CompressRequest {}
94impl RootMessage for pb::api::CompressReply {}
95impl RootMessage for pb::api::UnionRequest {}
96impl RootMessage for pb::api::UnionReply {}
97
/// Builds the I/O error reported when the connection mutex cannot be locked
/// (`Mutex::lock` only fails when the mutex has been poisoned).
fn lock_err() -> IoError {
    let kind = IoErrorKind::WouldBlock;
    IoError::new(kind, "Failed to lock connection mutex")
}
101
102impl ConnectionWrapper {
103    fn new(inner: Arc<Mutex<dyn Connection + Send>>) -> Self {
104        Self { inner }
105    }
106
107    fn send<T: RootMessage>(&mut self, msg: T) -> Result<()> {
108        let mut guard = self.inner.lock().map_err(|_| lock_err())?;
109        self.inner_send(guard.stream(), msg)
110    }
111
112    fn recv<T: Default + RootMessage>(&mut self) -> Result<T> {
113        let mut guard = self.inner.lock().map_err(|_| lock_err())?;
114        self.inner_recv(guard.stream())
115    }
116
117    #[cfg(feature = "prove")]
118    fn send_recv<S: RootMessage, R: Default + RootMessage>(&mut self, msg: S) -> Result<R> {
119        let mut guard = self.inner.lock().map_err(|_| lock_err())?;
120        let stream = guard.stream();
121        self.inner_send(stream, msg)?;
122        self.inner_recv(stream)
123    }
124
125    fn close(&mut self) -> Result<i32> {
126        self.inner.lock().map_err(|_| lock_err())?.close()
127    }
128
129    fn inner_send<T: RootMessage>(&self, stream: &mut TcpStream, msg: T) -> Result<()> {
130        let len = msg.encoded_len();
131        LOCAL_BUF.with_borrow_mut(|buf| {
132            buf.clear();
133            buf.put_u32_le(len as u32);
134            msg.encode(buf)?;
135            Ok(stream.write_all(buf)?)
136        })
137    }
138
139    fn inner_recv<T: Default + RootMessage>(&self, stream: &mut TcpStream) -> Result<T> {
140        LOCAL_BUF.with_borrow_mut(|buf| {
141            buf.resize(4, 0);
142            stream.read_exact(buf).context("rx len failed")?;
143            let len = buf.as_slice().get_u32_le() as usize;
144            buf.resize(len, 0);
145            stream.read_exact(buf).context("rx payload failed")?;
146            T::decode(buf.as_slice()).context("rx decode failed")
147        })
148    }
149}
150
151/// Connects a zkVM client and server
152pub trait Connector {
153    /// Create a client-server connection
154    fn connect(&self) -> Result<ConnectionWrapper>;
155}
156
/// A [Connector] that launches the r0vm server as a child process and waits for
/// it to connect back over loopback TCP.
struct ParentProcessConnector {
    /// Location of the r0vm server executable.
    server_path: PathBuf,

    /// Loopback listener whose port is handed to the spawned server.
    listener: TcpListener,
}
161
162impl ParentProcessConnector {
163    pub fn new<P: AsRef<Path>>(server_path: P) -> Result<Self> {
164        Ok(Self {
165            server_path: server_path.as_ref().to_path_buf(),
166            listener: TcpListener::bind("127.0.0.1:0")?,
167        })
168    }
169
170    pub fn new_narrow_version<P: AsRef<Path>>(server_path: P) -> Result<Self> {
171        // Check the version of the client and server to ensure that they're compatible
172        let client_version = get_version().map_err(|err| anyhow!(err))?;
173        let server_version = get_server_version(&server_path)?;
174        if !client::Compat::Narrow.check(&client_version, &server_version) {
175            let server_suggestion = if client_version.pre.is_empty() {
176                format!(
177                    "1. Install r0vm server version {}.{}\n",
178                    server_version.major, server_version.minor
179                )
180            } else {
181                format!("1. Your risc0 dependencies are using a pre-released version {client_version}.\n   \
182                    If you encounter this error message when running code on the risc0 codebase, you must\n   \
183                    either run the command `git checkout origin/release-{}.{}` to checkout the version of the\n   \
184                    risc0 code that is compatible with your server or build the r0vm server from source\n   \
185                    https://github.com/risc0/risc0/blob/main/CONTRIBUTING.md\n", server_version.major, server_version.minor
186                )
187            };
188            let msg = format!(
189                "Your installation of the r0vm server is not compatible with your host's risc0-zkvm crate.\n\
190                Do one of the following to fix this issue:\n\n\
191                {server_suggestion}\
192                2. Change the risc0-zkvm and risc0-build dependencies in your project to {}.{}\n\n\
193                risc0-zkvm version: {client_version}\n\
194                r0vm server version: {server_version}", server_version.major, server_version.minor
195            );
196            tracing::warn!("{msg}");
197            bail!(msg);
198        }
199
200        Self::new(server_path)
201    }
202
203    fn spawn_fail(&self) -> String {
204        format!(
205            "Could not launch zkvm: \"{}\". \n
206            Use `cargo binstall cargo-risczero` to install the latest zkvm.",
207            self.server_path.to_string_lossy()
208        )
209    }
210}
211
212fn get_server_version<P: AsRef<Path>>(server_path: P) -> Result<Version> {
213    let output = Command::new(server_path.as_ref().as_os_str())
214        .arg("--version")
215        .output()?;
216    let cmd_output = String::from_utf8(output.stdout)?;
217    let (_, version_str) = regex_captures!(r".* (.*)\n$", &cmd_output)
218        .ok_or_else(|| anyhow!("failed to parse server version number"))?;
219    Version::parse(version_str).map_err(|e| anyhow!(e))
220}
221
222impl Connector for ParentProcessConnector {
223    fn connect(&self) -> Result<ConnectionWrapper> {
224        let addr = self.listener.local_addr()?;
225        let child = Command::new(&self.server_path)
226            .arg("--port")
227            .arg(addr.port().to_string())
228            .spawn()
229            .with_context(|| self.spawn_fail())?;
230
231        let shutdown = Arc::new(AtomicBool::new(false));
232        let server_shutdown = shutdown.clone();
233        let (tx, rx) = channel();
234        let listener = self.listener.try_clone()?;
235        let handle = thread::spawn(move || {
236            let stream = listener.accept();
237            if server_shutdown.load(Ordering::Relaxed) {
238                return;
239            }
240            if let Ok((stream, _addr)) = stream {
241                tx.send(stream).unwrap();
242            }
243        });
244
245        let stream = rx.recv_timeout(CONNECT_TIMEOUT);
246        let stream = stream.inspect_err(|_| {
247            shutdown.store(true, Ordering::Relaxed);
248            let _ = TcpStream::connect(addr);
249            handle.join().unwrap();
250        })?;
251
252        Ok(ConnectionWrapper::new(Arc::new(Mutex::new(
253            ParentProcessConnection::new(child, stream),
254        ))))
255    }
256}
257
/// A [Connector] that dials an already-running r0vm server at a fixed address.
#[cfg(feature = "prove")]
struct TcpConnector {
    /// Target address, in any form accepted by `TcpStream::connect`.
    addr: String,
}

#[cfg(feature = "prove")]
impl TcpConnector {
    /// Creates a connector that will dial `addr`.
    pub(crate) fn new(addr: &str) -> Self {
        let addr = String::from(addr);
        Self { addr }
    }
}

#[cfg(feature = "prove")]
impl Connector for TcpConnector {
    /// Opens a fresh TCP stream to the configured address.
    fn connect(&self) -> Result<ConnectionWrapper> {
        tracing::debug!("connect");
        let stream = TcpStream::connect(self.addr.as_str())?;
        let conn: Arc<Mutex<dyn Connection + Send>> =
            Arc::new(Mutex::new(TcpConnection::new(stream)));
        Ok(ConnectionWrapper::new(conn))
    }
}
282
/// A [Connection] to an r0vm server running as a child process of this one.
struct ParentProcessConnection {
    /// Handle to the spawned server process; waited on by `close`.
    child: Child,

    /// The stream the server connected back on.
    stream: TcpStream,
}
287
/// A [Connection] to an independently running r0vm server reached over TCP.
#[cfg(feature = "prove")]
struct TcpConnection {
    /// The connected stream to the server.
    stream: TcpStream,
}
292
293impl ParentProcessConnection {
294    pub fn new(child: Child, stream: TcpStream) -> Self {
295        Self { child, stream }
296    }
297}
298
299impl Connection for ParentProcessConnection {
300    fn stream(&mut self) -> &mut TcpStream {
301        &mut self.stream
302    }
303
304    fn close(&mut self) -> Result<i32> {
305        let status = self.child.wait()?;
306        Ok(status.code().unwrap_or_default())
307    }
308}
309
#[cfg(feature = "prove")]
impl TcpConnection {
    /// Wraps an already-connected stream.
    pub fn new(stream: TcpStream) -> Self {
        TcpConnection { stream }
    }
}

#[cfg(feature = "prove")]
impl Connection for TcpConnection {
    fn stream(&mut self) -> &mut TcpStream {
        let TcpConnection { stream } = self;
        stream
    }

    /// There is no owned server process to wait on, so this always reports `0`.
    fn close(&mut self) -> Result<i32> {
        Ok(0)
    }
}
327
328fn malformed_err(field: &str) -> anyhow::Error {
329    anyhow!("Malformed error: {field}")
330}
331
332impl pb::api::Asset {
333    fn as_bytes(&self) -> Result<Bytes> {
334        let bytes = match self
335            .kind
336            .as_ref()
337            .ok_or_else(|| malformed_err("Asset.kind"))?
338        {
339            pb::api::asset::Kind::Inline(bytes) => bytes.clone(),
340            pb::api::asset::Kind::Path(path) => std::fs::read(path)?,
341            pb::api::asset::Kind::Redis(_) => bail!("as_bytes not supported for redis"),
342        };
343        Ok(bytes.into())
344    }
345}
346
347/// Determines the format of an asset.
348#[derive(Clone)]
349pub enum Asset {
350    /// The asset is encoded inline.
351    Inline(Bytes),
352
353    /// The asset is written to disk.
354    Path(PathBuf),
355
356    /// The asset is written to redis.
357    Redis(String),
358}
359
/// Parameters for [AssetRequest::Redis]: where and how to store an asset in redis.
#[derive(Clone)]
pub struct RedisParams {
    /// URL of the redis instance to write to.
    pub url: String,

    /// Key under which the asset is written.
    pub key: String,

    /// Time to live (expiration) applied to `key` when it is set.
    pub ttl: u64,
}
372
373/// Determines the format of an asset request.
374#[derive(Clone)]
375pub enum AssetRequest {
376    /// The asset is encoded inline.
377    Inline,
378
379    /// The asset is written to disk.
380    Path(PathBuf),
381
382    /// The asset is written to redis.
383    Redis(RedisParams),
384}
385
386/// Provides information about the result of execution.
387#[derive(Clone, Debug)]
388#[non_exhaustive]
389pub struct SessionInfo {
390    /// The number of user cycles for each segment.
391    pub segments: Vec<SegmentInfo>,
392
393    /// The data publicly committed by the guest program.
394    pub journal: Journal,
395
396    /// The [ExitCode] of the session.
397    pub exit_code: ExitCode,
398
399    /// The [ReceiptClaim] associated with the executed session. This receipt claim is what will be
400    /// proven if this session is passed to the Prover.
401    pub receipt_claim: Option<ReceiptClaim>,
402}
403
404impl SessionInfo {
405    /// The total number of user cycles across all segments, without any
406    /// overhead for continuations or po2 padding.
407    pub fn cycles(&self) -> u64 {
408        self.segments.iter().map(|s| s.cycles as u64).sum()
409    }
410}
411
/// Provides information about a single segment of execution.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct SegmentInfo {
    /// The segment size used for proving, expressed as a power-of-two exponent
    /// (the "po2").
    pub po2: u32,

    /// The number of user cycles in this segment, without any overhead for
    /// continuations or po2 padding.
    pub cycles: u32,
}
423
424impl Asset {
425    /// Return the bytes for this asset.
426    pub fn as_bytes(&self) -> Result<Bytes> {
427        Ok(match self {
428            Asset::Inline(bytes) => bytes.clone(),
429            Asset::Path(path) => std::fs::read(path)?.into(),
430            Asset::Redis(_) => bail!("as_bytes not supported for Asset::Redis"),
431        })
432    }
433}
434
435fn invalid_path() -> anyhow::Error {
436    anyhow::Error::msg("Path must be UTF-8")
437}
438
439fn path_to_string<P: AsRef<Path>>(path: P) -> Result<String> {
440    Ok(path.as_ref().to_str().ok_or_else(invalid_path)?.to_string())
441}