1pub(crate) mod client;
16pub(crate) mod convert;
17#[cfg(feature = "prove")]
18pub(crate) mod server;
19#[cfg(test)]
20#[cfg(feature = "prove")]
21mod tests;
22
23use std::{
24 cell::RefCell,
25 io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write},
26 net::{TcpListener, TcpStream},
27 path::{Path, PathBuf},
28 process::{Child, Command},
29 sync::{
30 atomic::{AtomicBool, Ordering},
31 mpsc::channel,
32 Arc, Mutex,
33 },
34 thread,
35 time::Duration,
36};
37
38use anyhow::{anyhow, bail, Context, Result};
39use bytes::{Buf, BufMut, Bytes};
40use lazy_regex::regex_captures;
41use prost::Message;
42use semver::Version;
43
44use crate::{get_version, ExitCode, Journal, ReceiptClaim};
45
46mod pb {
47 pub(crate) mod api {
48 pub use crate::host::protos::api::*;
49 }
50 pub(crate) mod base {
51 pub use crate::host::protos::base::*;
52 }
53 pub(crate) mod core {
54 pub use crate::host::protos::core::*;
55 }
56}
57
/// How long `ParentProcessConnector::connect` waits for the spawned server to
/// connect back to the listener before giving up (ten seconds).
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
59
60trait RootMessage: Message {}
61
62pub trait Connection {
63 fn stream(&mut self) -> &mut TcpStream;
64 fn close(&mut self) -> Result<i32>;
65}
66
67#[derive(Clone)]
68pub struct ConnectionWrapper {
69 inner: Arc<Mutex<dyn Connection + Send>>,
70}
71
72thread_local! {
73 static LOCAL_BUF: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
74}
75
76impl RootMessage for pb::api::HelloRequest {}
77impl RootMessage for pb::api::HelloReply {}
78impl RootMessage for pb::api::ServerRequest {}
79impl RootMessage for pb::api::ServerReply {}
80impl RootMessage for pb::api::GenericReply {}
81impl RootMessage for pb::api::OnIoReply {}
82impl RootMessage for pb::api::ProveKeccakReply {}
83impl RootMessage for pb::api::ProveSegmentReply {}
84impl RootMessage for pb::api::ProveZkrReply {}
85impl RootMessage for pb::api::LiftRequest {}
86impl RootMessage for pb::api::LiftReply {}
87impl RootMessage for pb::api::JoinRequest {}
88impl RootMessage for pb::api::JoinReply {}
89impl RootMessage for pb::api::ResolveRequest {}
90impl RootMessage for pb::api::ResolveReply {}
91impl RootMessage for pb::api::IdentityP254Request {}
92impl RootMessage for pb::api::IdentityP254Reply {}
93impl RootMessage for pb::api::CompressRequest {}
94impl RootMessage for pb::api::CompressReply {}
95impl RootMessage for pb::api::UnionRequest {}
96impl RootMessage for pb::api::UnionReply {}
97
/// Error reported when the connection mutex cannot be acquired.
///
/// NOTE(review): `std::sync::Mutex::lock` only fails when the mutex is poisoned,
/// so the `WouldBlock` kind here is descriptive rather than literal.
fn lock_err() -> IoError {
    const MESSAGE: &str = "Failed to lock connection mutex";
    let kind = IoErrorKind::WouldBlock;
    IoError::new(kind, MESSAGE)
}
101
102impl ConnectionWrapper {
103 fn new(inner: Arc<Mutex<dyn Connection + Send>>) -> Self {
104 Self { inner }
105 }
106
107 fn send<T: RootMessage>(&mut self, msg: T) -> Result<()> {
108 let mut guard = self.inner.lock().map_err(|_| lock_err())?;
109 self.inner_send(guard.stream(), msg)
110 }
111
112 fn recv<T: Default + RootMessage>(&mut self) -> Result<T> {
113 let mut guard = self.inner.lock().map_err(|_| lock_err())?;
114 self.inner_recv(guard.stream())
115 }
116
117 #[cfg(feature = "prove")]
118 fn send_recv<S: RootMessage, R: Default + RootMessage>(&mut self, msg: S) -> Result<R> {
119 let mut guard = self.inner.lock().map_err(|_| lock_err())?;
120 let stream = guard.stream();
121 self.inner_send(stream, msg)?;
122 self.inner_recv(stream)
123 }
124
125 fn close(&mut self) -> Result<i32> {
126 self.inner.lock().map_err(|_| lock_err())?.close()
127 }
128
129 fn inner_send<T: RootMessage>(&self, stream: &mut TcpStream, msg: T) -> Result<()> {
130 let len = msg.encoded_len();
131 LOCAL_BUF.with_borrow_mut(|buf| {
132 buf.clear();
133 buf.put_u32_le(len as u32);
134 msg.encode(buf)?;
135 Ok(stream.write_all(buf)?)
136 })
137 }
138
139 fn inner_recv<T: Default + RootMessage>(&self, stream: &mut TcpStream) -> Result<T> {
140 LOCAL_BUF.with_borrow_mut(|buf| {
141 buf.resize(4, 0);
142 stream.read_exact(buf).context("rx len failed")?;
143 let len = buf.as_slice().get_u32_le() as usize;
144 buf.resize(len, 0);
145 stream.read_exact(buf).context("rx payload failed")?;
146 T::decode(buf.as_slice()).context("rx decode failed")
147 })
148 }
149}
150
151pub trait Connector {
153 fn connect(&self) -> Result<ConnectionWrapper>;
155}
156
157struct ParentProcessConnector {
158 server_path: PathBuf,
159 listener: TcpListener,
160}
161
162impl ParentProcessConnector {
163 pub fn new<P: AsRef<Path>>(server_path: P) -> Result<Self> {
164 Ok(Self {
165 server_path: server_path.as_ref().to_path_buf(),
166 listener: TcpListener::bind("127.0.0.1:0")?,
167 })
168 }
169
170 pub fn new_narrow_version<P: AsRef<Path>>(server_path: P) -> Result<Self> {
171 let client_version = get_version().map_err(|err| anyhow!(err))?;
173 let server_version = get_server_version(&server_path)?;
174 if !client::Compat::Narrow.check(&client_version, &server_version) {
175 let server_suggestion = if client_version.pre.is_empty() {
176 format!(
177 "1. Install r0vm server version {}.{}\n",
178 server_version.major, server_version.minor
179 )
180 } else {
181 format!("1. Your risc0 dependencies are using a pre-released version {client_version}.\n \
182 If you encounter this error message when running code on the risc0 codebase, you must\n \
183 either run the command `git checkout origin/release-{}.{}` to checkout the version of the\n \
184 risc0 code that is compatible with your server or build the r0vm server from source\n \
185 https://github.com/risc0/risc0/blob/main/CONTRIBUTING.md\n", server_version.major, server_version.minor
186 )
187 };
188 let msg = format!(
189 "Your installation of the r0vm server is not compatible with your host's risc0-zkvm crate.\n\
190 Do one of the following to fix this issue:\n\n\
191 {server_suggestion}\
192 2. Change the risc0-zkvm and risc0-build dependencies in your project to {}.{}\n\n\
193 risc0-zkvm version: {client_version}\n\
194 r0vm server version: {server_version}", server_version.major, server_version.minor
195 );
196 tracing::warn!("{msg}");
197 bail!(msg);
198 }
199
200 Self::new(server_path)
201 }
202
203 fn spawn_fail(&self) -> String {
204 format!(
205 "Could not launch zkvm: \"{}\". \n
206 Use `cargo binstall cargo-risczero` to install the latest zkvm.",
207 self.server_path.to_string_lossy()
208 )
209 }
210}
211
212fn get_server_version<P: AsRef<Path>>(server_path: P) -> Result<Version> {
213 let output = Command::new(server_path.as_ref().as_os_str())
214 .arg("--version")
215 .output()?;
216 let cmd_output = String::from_utf8(output.stdout)?;
217 let (_, version_str) = regex_captures!(r".* (.*)\n$", &cmd_output)
218 .ok_or_else(|| anyhow!("failed to parse server version number"))?;
219 Version::parse(version_str).map_err(|e| anyhow!(e))
220}
221
222impl Connector for ParentProcessConnector {
223 fn connect(&self) -> Result<ConnectionWrapper> {
224 let addr = self.listener.local_addr()?;
225 let child = Command::new(&self.server_path)
226 .arg("--port")
227 .arg(addr.port().to_string())
228 .spawn()
229 .with_context(|| self.spawn_fail())?;
230
231 let shutdown = Arc::new(AtomicBool::new(false));
232 let server_shutdown = shutdown.clone();
233 let (tx, rx) = channel();
234 let listener = self.listener.try_clone()?;
235 let handle = thread::spawn(move || {
236 let stream = listener.accept();
237 if server_shutdown.load(Ordering::Relaxed) {
238 return;
239 }
240 if let Ok((stream, _addr)) = stream {
241 tx.send(stream).unwrap();
242 }
243 });
244
245 let stream = rx.recv_timeout(CONNECT_TIMEOUT);
246 let stream = stream.inspect_err(|_| {
247 shutdown.store(true, Ordering::Relaxed);
248 let _ = TcpStream::connect(addr);
249 handle.join().unwrap();
250 })?;
251
252 Ok(ConnectionWrapper::new(Arc::new(Mutex::new(
253 ParentProcessConnection::new(child, stream),
254 ))))
255 }
256}
257
/// Connector for an r0vm server that is already listening at a TCP address.
#[cfg(feature = "prove")]
struct TcpConnector {
    addr: String,
}

#[cfg(feature = "prove")]
impl TcpConnector {
    /// Creates a connector that dials `addr` every time `connect` is called.
    pub(crate) fn new(addr: &str) -> Self {
        let addr = String::from(addr);
        Self { addr }
    }
}

#[cfg(feature = "prove")]
impl Connector for TcpConnector {
    fn connect(&self) -> Result<ConnectionWrapper> {
        tracing::debug!("connect");
        let connection = TcpConnection::new(TcpStream::connect(&self.addr)?);
        Ok(ConnectionWrapper::new(Arc::new(Mutex::new(connection))))
    }
}
282
283struct ParentProcessConnection {
284 child: Child,
285 stream: TcpStream,
286}
287
288#[cfg(feature = "prove")]
289struct TcpConnection {
290 stream: TcpStream,
291}
292
293impl ParentProcessConnection {
294 pub fn new(child: Child, stream: TcpStream) -> Self {
295 Self { child, stream }
296 }
297}
298
299impl Connection for ParentProcessConnection {
300 fn stream(&mut self) -> &mut TcpStream {
301 &mut self.stream
302 }
303
304 fn close(&mut self) -> Result<i32> {
305 let status = self.child.wait()?;
306 Ok(status.code().unwrap_or_default())
307 }
308}
309
310#[cfg(feature = "prove")]
311impl TcpConnection {
312 pub fn new(stream: TcpStream) -> Self {
313 Self { stream }
314 }
315}
316
317#[cfg(feature = "prove")]
318impl Connection for TcpConnection {
319 fn stream(&mut self) -> &mut TcpStream {
320 &mut self.stream
321 }
322
323 fn close(&mut self) -> Result<i32> {
324 Ok(0)
325 }
326}
327
328fn malformed_err(field: &str) -> anyhow::Error {
329 anyhow!("Malformed error: {field}")
330}
331
332impl pb::api::Asset {
333 fn as_bytes(&self) -> Result<Bytes> {
334 let bytes = match self
335 .kind
336 .as_ref()
337 .ok_or_else(|| malformed_err("Asset.kind"))?
338 {
339 pb::api::asset::Kind::Inline(bytes) => bytes.clone(),
340 pb::api::asset::Kind::Path(path) => std::fs::read(path)?,
341 pb::api::asset::Kind::Redis(_) => bail!("as_bytes not supported for redis"),
342 };
343 Ok(bytes.into())
344 }
345}
346
347#[derive(Clone)]
349pub enum Asset {
350 Inline(Bytes),
352
353 Path(PathBuf),
355
356 Redis(String),
358}
359
/// Settings used when an asset is requested to be stored in Redis.
#[derive(Clone)]
pub struct RedisParams {
    /// URL of the Redis server.
    pub url: String,

    /// Key under which the asset is stored.
    pub key: String,

    /// Expiry for the stored entry. The unit is not visible here; presumably
    /// seconds. TODO confirm.
    pub ttl: u64,
}

/// Requested destination for an asset produced by the server.
#[derive(Clone)]
pub enum AssetRequest {
    /// Return the asset inline, in memory.
    Inline,

    /// Write the asset to this file path.
    Path(PathBuf),

    /// Store the asset in Redis using these parameters.
    Redis(RedisParams),
}
385
386#[derive(Clone, Debug)]
388#[non_exhaustive]
389pub struct SessionInfo {
390 pub segments: Vec<SegmentInfo>,
392
393 pub journal: Journal,
395
396 pub exit_code: ExitCode,
398
399 pub receipt_claim: Option<ReceiptClaim>,
402}
403
404impl SessionInfo {
405 pub fn cycles(&self) -> u64 {
408 self.segments.iter().map(|s| s.cycles as u64).sum()
409 }
410}
411
/// Statistics for a single segment of a session.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct SegmentInfo {
    /// Power-of-two size exponent of the segment (presumably log2 of its padded
    /// cycle count). TODO confirm.
    pub po2: u32,

    /// Number of cycles used by this segment.
    pub cycles: u32,
}
423
424impl Asset {
425 pub fn as_bytes(&self) -> Result<Bytes> {
427 Ok(match self {
428 Asset::Inline(bytes) => bytes.clone(),
429 Asset::Path(path) => std::fs::read(path)?.into(),
430 Asset::Redis(_) => bail!("as_bytes not supported for Asset::Redis"),
431 })
432 }
433}
434
435fn invalid_path() -> anyhow::Error {
436 anyhow::Error::msg("Path must be UTF-8")
437}
438
439fn path_to_string<P: AsRef<Path>>(path: P) -> Result<String> {
440 Ok(path.as_ref().to_str().ok_or_else(invalid_path)?.to_string())
441}