risc0_zkvm/host/api/mod.rs

1// Copyright 2024 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub(crate) mod client;
16pub(crate) mod convert;
17#[cfg(feature = "prove")]
18pub(crate) mod server;
19#[cfg(test)]
20#[cfg(feature = "prove")]
21mod tests;
22
23use std::{
24    cell::RefCell,
25    io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write},
26    net::{TcpListener, TcpStream},
27    path::{Path, PathBuf},
28    process::{Child, Command},
29    sync::{
30        atomic::{AtomicBool, Ordering},
31        mpsc::channel,
32        Arc, Mutex,
33    },
34    thread,
35    time::Duration,
36};
37
38use anyhow::{anyhow, bail, Context, Result};
39use bytes::{Buf, BufMut, Bytes};
40use lazy_regex::regex_captures;
41use prost::Message;
42use semver::Version;
43
44use crate::{get_version, ExitCode, Journal, ReceiptClaim};
45
/// Short local aliases for the generated protobuf modules, so the rest of this
/// module can refer to `pb::api::...`, `pb::base::...` and `pb::core::...`.
mod pb {
    /// Re-export of the generated `api` protobuf messages.
    pub(crate) mod api {
        pub use crate::host::protos::api::*;
    }
    /// Re-export of the generated `base` protobuf messages.
    pub(crate) mod base {
        pub use crate::host::protos::base::*;
    }
    /// Re-export of the generated `core` protobuf messages.
    pub(crate) mod core {
        pub use crate::host::protos::core::*;
    }
}
57
/// How long to wait for a spawned server process to connect back before giving up.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
59
/// Marker for protobuf messages that may be sent or received as a complete top-level
/// frame by `ConnectionWrapper::send` / `ConnectionWrapper::recv`.
trait RootMessage: Message {}
61
/// A live client/server connection over TCP.
pub trait Connection {
    /// Returns the TCP stream used to exchange framed messages with the peer.
    fn stream(&mut self) -> &mut TcpStream;

    /// Finishes the connection and returns an integer status. For a child-process
    /// connection this is derived from the child's exit status; a plain TCP
    /// connection reports `0`.
    fn close(&mut self) -> Result<i32>;
}
66
/// Cloneable, shared handle to a [`Connection`].
///
/// Clones share the same underlying connection; every operation locks the inner
/// mutex for its duration.
#[derive(Clone)]
pub struct ConnectionWrapper {
    inner: Arc<Mutex<dyn Connection + Send>>,
}
71
thread_local! {
    /// Per-thread scratch buffer reused by `ConnectionWrapper::inner_send` and
    /// `ConnectionWrapper::inner_recv` to frame messages, so each message does not
    /// allocate a fresh buffer. It keeps the capacity of the largest message it has held.
    static LOCAL_BUF: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
}
75
76impl RootMessage for pb::api::HelloRequest {}
77impl RootMessage for pb::api::HelloReply {}
78impl RootMessage for pb::api::ServerRequest {}
79impl RootMessage for pb::api::ServerReply {}
80impl RootMessage for pb::api::GenericReply {}
81impl RootMessage for pb::api::OnIoReply {}
82impl RootMessage for pb::api::ProveKeccakReply {}
83impl RootMessage for pb::api::ProveSegmentReply {}
84impl RootMessage for pb::api::ProveZkrReply {}
85impl RootMessage for pb::api::LiftRequest {}
86impl RootMessage for pb::api::LiftReply {}
87impl RootMessage for pb::api::JoinRequest {}
88impl RootMessage for pb::api::JoinReply {}
89impl RootMessage for pb::api::ResolveRequest {}
90impl RootMessage for pb::api::ResolveReply {}
91impl RootMessage for pb::api::IdentityP254Request {}
92impl RootMessage for pb::api::IdentityP254Reply {}
93impl RootMessage for pb::api::CompressRequest {}
94impl RootMessage for pb::api::CompressReply {}
95
/// Builds the I/O error reported when locking the connection mutex fails
/// (`Mutex::lock` only fails when the mutex is poisoned).
fn lock_err() -> IoError {
    let kind = IoErrorKind::WouldBlock;
    let message = "Failed to lock connection mutex";
    IoError::new(kind, message)
}
99
100impl ConnectionWrapper {
101    fn new(inner: Arc<Mutex<dyn Connection + Send>>) -> Self {
102        Self { inner }
103    }
104
105    fn send<T: RootMessage>(&mut self, msg: T) -> Result<()> {
106        let mut guard = self.inner.lock().map_err(|_| lock_err())?;
107        self.inner_send(guard.stream(), msg)
108    }
109
110    fn recv<T: Default + RootMessage>(&mut self) -> Result<T> {
111        let mut guard = self.inner.lock().map_err(|_| lock_err())?;
112        self.inner_recv(guard.stream())
113    }
114
115    #[cfg(feature = "prove")]
116    fn send_recv<S: RootMessage, R: Default + RootMessage>(&mut self, msg: S) -> Result<R> {
117        let mut guard = self.inner.lock().map_err(|_| lock_err())?;
118        let stream = guard.stream();
119        self.inner_send(stream, msg)?;
120        self.inner_recv(stream)
121    }
122
123    fn close(&mut self) -> Result<i32> {
124        self.inner.lock().map_err(|_| lock_err())?.close()
125    }
126
127    fn inner_send<T: RootMessage>(&self, stream: &mut TcpStream, msg: T) -> Result<()> {
128        let len = msg.encoded_len();
129        LOCAL_BUF.with_borrow_mut(|buf| {
130            buf.clear();
131            buf.put_u32_le(len as u32);
132            msg.encode(buf)?;
133            Ok(stream.write_all(buf)?)
134        })
135    }
136
137    fn inner_recv<T: Default + RootMessage>(&self, stream: &mut TcpStream) -> Result<T> {
138        LOCAL_BUF.with_borrow_mut(|buf| {
139            buf.resize(4, 0);
140            stream.read_exact(buf).context("rx len failed")?;
141            let len = buf.as_slice().get_u32_le() as usize;
142            buf.resize(len, 0);
143            stream.read_exact(buf).context("rx payload failed")?;
144            T::decode(buf.as_slice()).context("rx decode failed")
145        })
146    }
147}
148
/// Connects a zkVM client and server
///
/// Implementations establish the transport (for example by spawning a server process
/// or dialing a TCP address) and hand back a shared [`ConnectionWrapper`].
pub trait Connector {
    /// Create a client-server connection
    fn connect(&self) -> Result<ConnectionWrapper>;
}
154
/// Connector that launches the r0vm server as a child process and accepts the
/// server's connection on a local TCP listener.
struct ParentProcessConnector {
    /// Path of the r0vm server executable to spawn.
    server_path: PathBuf,
    /// Listener on 127.0.0.1 with an OS-assigned port; the port is passed to the
    /// spawned server via `--port` so it can connect back.
    listener: TcpListener,
}
159
160impl ParentProcessConnector {
161    pub fn new<P: AsRef<Path>>(server_path: P) -> Result<Self> {
162        // Check the version of the client and server to ensure that they're compatible
163        let client_version = get_version().map_err(|err| anyhow!(err))?;
164        let server_version = get_server_version(&server_path)?;
165        if !client::check_server_version(&client_version, &server_version) {
166            let server_suggestion = if client_version.pre.is_empty() {
167                format!(
168                    "1. Install r0vm server version {}.{}\n",
169                    server_version.major, server_version.minor
170                )
171            } else {
172                format!("1. Your risc0 dependencies are using a pre-released version {client_version}.\n   \
173                    If you encounter this error message when running code on the risc0 codebase, you must\n   \
174                    either run the command `git checkout origin/release-{}.{}` to checkout the version of the\n   \
175                    risc0 code that is compatible with your server or build the r0vm server from source\n   \
176                    https://github.com/risc0/risc0/blob/main/CONTRIBUTING.md\n", server_version.major, server_version.minor
177                )
178            };
179            let msg = format!(
180                "Your installation of the r0vm server is not compatible with your host's risc0-zkvm crate.\n\
181                Do one of the following to fix this issue:\n\n\
182                {server_suggestion}\
183                2. Change the risc0-zkvm and risc0-build dependencies in your project to {}.{}\n\n\
184                risc0-zkvm version: {client_version}\n\
185                r0vm server version: {server_version}", server_version.major, server_version.minor
186            );
187            tracing::warn!("{msg}");
188            bail!(msg);
189        }
190
191        Ok(Self {
192            server_path: server_path.as_ref().to_path_buf(),
193            listener: TcpListener::bind("127.0.0.1:0")?,
194        })
195    }
196
197    pub fn new_wide_version<P: AsRef<Path>>(server_path: P) -> Result<Self> {
198        let client_version = get_version().map_err(|err| anyhow!(err))?;
199        let server_version = get_server_version(&server_path)?;
200
201        if !client::check_server_version_wide(&client_version, &server_version) {
202            let msg = format!(
203                "Your installation of r0vm differs by a major version:\n\
204            {client_version} vs {server_version} only minor, patch / pre-releases supported"
205            );
206            tracing::warn!("{msg}");
207            bail!(msg);
208        }
209        Ok(Self {
210            server_path: server_path.as_ref().to_path_buf(),
211            listener: TcpListener::bind("127.0.0.1:0")?,
212        })
213    }
214
215    fn spawn_fail(&self) -> String {
216        format!(
217            "Could not launch zkvm: \"{}\". \n
218            Use `cargo binstall cargo-risczero` to install the latest zkvm.",
219            self.server_path.to_string_lossy()
220        )
221    }
222}
223
224fn get_server_version<P: AsRef<Path>>(server_path: P) -> Result<Version> {
225    let output = Command::new(server_path.as_ref().as_os_str())
226        .arg("--version")
227        .output()?;
228    let cmd_output = String::from_utf8(output.stdout)?;
229    let (_, version_str) = regex_captures!(r".* (.*)\n$", &cmd_output)
230        .ok_or(anyhow!("failed to parse server version number"))?;
231    Version::parse(version_str).map_err(|e| anyhow!(e))
232}
233
234impl Connector for ParentProcessConnector {
235    fn connect(&self) -> Result<ConnectionWrapper> {
236        let addr = self.listener.local_addr()?;
237        let child = Command::new(&self.server_path)
238            .arg("--port")
239            .arg(addr.port().to_string())
240            .spawn()
241            .with_context(|| self.spawn_fail())?;
242
243        let shutdown = Arc::new(AtomicBool::new(false));
244        let server_shutdown = shutdown.clone();
245        let (tx, rx) = channel();
246        let listener = self.listener.try_clone()?;
247        let handle = thread::spawn(move || {
248            let stream = listener.accept();
249            if server_shutdown.load(Ordering::Relaxed) {
250                return;
251            }
252            if let Ok((stream, _addr)) = stream {
253                tx.send(stream).unwrap();
254            }
255        });
256
257        let stream = rx.recv_timeout(CONNECT_TIMEOUT);
258        let stream = stream.inspect_err(|_| {
259            shutdown.store(true, Ordering::Relaxed);
260            let _ = TcpStream::connect(addr);
261            handle.join().unwrap();
262        })?;
263
264        Ok(ConnectionWrapper::new(Arc::new(Mutex::new(
265            ParentProcessConnection::new(child, stream),
266        ))))
267    }
268}
269
/// Connector that dials an already-running server at a fixed TCP address.
#[cfg(feature = "prove")]
struct TcpConnector {
    /// Address handed to `TcpStream::connect`.
    addr: String,
}
274
#[cfg(feature = "prove")]
impl TcpConnector {
    /// Creates a connector that will dial `addr`, which is later handed unchanged to
    /// `TcpStream::connect`.
    pub(crate) fn new(addr: &str) -> Self {
        let addr = String::from(addr);
        TcpConnector { addr }
    }
}
283
#[cfg(feature = "prove")]
impl Connector for TcpConnector {
    /// Opens a TCP connection to the configured server address.
    ///
    /// # Errors
    /// Fails, naming the address, if the connection cannot be established.
    fn connect(&self) -> Result<ConnectionWrapper> {
        tracing::debug!("connect: {}", self.addr);
        let stream = TcpStream::connect(&self.addr)
            .with_context(|| format!("failed to connect to {}", self.addr))?;
        Ok(ConnectionWrapper::new(Arc::new(Mutex::new(
            TcpConnection::new(stream),
        ))))
    }
}
294
/// A [`Connection`] to an r0vm server running as a child process of this one.
struct ParentProcessConnection {
    /// The spawned server process; `close` waits for it to exit.
    child: Child,
    /// The TCP stream on which the server connected back.
    stream: TcpStream,
}
299
/// A [`Connection`] to a server reached over a plain TCP socket (no owned process).
#[cfg(feature = "prove")]
struct TcpConnection {
    /// The connected stream.
    stream: TcpStream,
}
304
305impl ParentProcessConnection {
306    pub fn new(child: Child, stream: TcpStream) -> Self {
307        Self { child, stream }
308    }
309}
310
311impl Connection for ParentProcessConnection {
312    fn stream(&mut self) -> &mut TcpStream {
313        &mut self.stream
314    }
315
316    fn close(&mut self) -> Result<i32> {
317        let status = self.child.wait()?;
318        Ok(status.code().unwrap_or_default())
319    }
320}
321
#[cfg(feature = "prove")]
impl TcpConnection {
    /// Wraps an already-connected TCP stream.
    pub fn new(stream: TcpStream) -> Self {
        TcpConnection { stream }
    }
}
328
#[cfg(feature = "prove")]
impl Connection for TcpConnection {
    fn stream(&mut self) -> &mut TcpStream {
        &mut self.stream
    }

    /// There is no owned process to wait on, so this always reports `0`. The stream
    /// itself is not shut down here.
    fn close(&mut self) -> Result<i32> {
        Ok(0)
    }
}
339
340fn malformed_err() -> anyhow::Error {
341    anyhow!("Malformed error")
342}
343
344impl pb::api::Asset {
345    fn as_bytes(&self) -> Result<Bytes> {
346        let bytes = match self.kind.as_ref().ok_or(malformed_err())? {
347            pb::api::asset::Kind::Inline(bytes) => bytes.clone(),
348            pb::api::asset::Kind::Path(path) => std::fs::read(path)?,
349            pb::api::asset::Kind::Redis(_) => bail!("as_bytes not supported for redis"),
350        };
351        Ok(bytes.into())
352    }
353}
354
/// Determines the format of an asset.
///
/// Use [`Asset::as_bytes`] to obtain the contents of inline and on-disk assets.
#[derive(Clone)]
pub enum Asset {
    /// The asset is encoded inline.
    Inline(Bytes),

    /// The asset is written to disk at this path.
    Path(PathBuf),

    /// The asset is written to redis. The string's meaning (for example a key or URL) is
    /// not visible here. NOTE(review): confirm against the producer.
    Redis(String),
}
367
/// Parameters for [`AssetRequest::Redis`]: where and how to store the asset in redis.
#[derive(Clone)]
pub struct RedisParams {
    /// The url of the redis instance
    pub url: String,

    /// The key used to write to redis
    pub key: String,

    /// Time to live (expiration) for the key being set.
    /// NOTE(review): the unit is not visible here, presumably seconds. Confirm.
    pub ttl: u64,
}
380
/// Determines the format of an asset request, i.e. where a produced asset should be
/// delivered.
#[derive(Clone)]
pub enum AssetRequest {
    /// The asset is encoded inline.
    Inline,

    /// The asset is written to disk at this path.
    Path(PathBuf),

    /// The asset is written to redis using these parameters.
    Redis(RedisParams),
}
393
/// Provides information about the result of execution.
#[derive(Clone, Debug)]
pub struct SessionInfo {
    /// Information about each segment of the session: its power-of-two size and its
    /// user cycle count.
    pub segments: Vec<SegmentInfo>,

    /// The data publicly committed by the guest program.
    pub journal: Journal,

    /// The [ExitCode] of the session.
    pub exit_code: ExitCode,

    /// The [ReceiptClaim] associated with the executed session. This receipt claim is what will be
    /// proven if this session is passed to the Prover. It may be `None`; the conditions
    /// under which it is absent are not determined in this file.
    pub receipt_claim: Option<ReceiptClaim>,
}
410
411impl SessionInfo {
412    /// The total number of user cycles across all segments, without any
413    /// overhead for continuations or po2 padding.
414    pub fn cycles(&self) -> u64 {
415        self.segments.iter().map(|s| s.cycles as u64).sum()
416    }
417}
418
/// Provides information about a segment of execution.
#[derive(Clone, Debug)]
pub struct SegmentInfo {
    /// The number of cycles used for proving, expressed as a power of 2.
    /// (Presumably the segment is proven over `1 << po2` cycles. Confirm.)
    pub po2: u32,

    /// The number of user cycles without any overhead for continuations or po2
    /// padding.
    pub cycles: u32,
}
429
430impl Asset {
431    /// Return the bytes for this asset.
432    pub fn as_bytes(&self) -> Result<Bytes> {
433        Ok(match self {
434            Asset::Inline(bytes) => bytes.clone(),
435            Asset::Path(path) => std::fs::read(path)?.into(),
436            Asset::Redis(_) => bail!("as_bytes not supported for Asset::Redis"),
437        })
438    }
439}
440
441fn invalid_path() -> anyhow::Error {
442    anyhow::Error::msg("Path must be UTF-8")
443}
444
445fn path_to_string<P: AsRef<Path>>(path: P) -> Result<String> {
446    Ok(path.as_ref().to_str().ok_or(invalid_path())?.to_string())
447}