1pub(crate) mod client;
16pub(crate) mod convert;
17#[cfg(feature = "prove")]
18pub(crate) mod server;
19#[cfg(test)]
20#[cfg(feature = "prove")]
21mod tests;
22
23use std::{
24 cell::RefCell,
25 io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write},
26 net::{TcpListener, TcpStream},
27 path::{Path, PathBuf},
28 process::{Child, Command},
29 sync::{
30 atomic::{AtomicBool, Ordering},
31 mpsc::channel,
32 Arc, Mutex,
33 },
34 thread,
35 time::Duration,
36};
37
38use anyhow::{anyhow, bail, Context, Result};
39use bytes::{Buf, BufMut, Bytes};
40use lazy_regex::regex_captures;
41use prost::Message;
42use semver::Version;
43
44use crate::{get_version, ExitCode, Journal, ReceiptClaim};
45
/// Short local aliases for the generated protobuf modules under `crate::host::protos`,
/// so code in this module can write `pb::api::Foo` instead of the full path.
mod pb {
    /// Re-exports everything from `crate::host::protos::api`.
    pub(crate) mod api {
        pub use crate::host::protos::api::*;
    }
    /// Re-exports everything from `crate::host::protos::base`.
    pub(crate) mod base {
        pub use crate::host::protos::base::*;
    }
    /// Re-exports everything from `crate::host::protos::core`.
    pub(crate) mod core {
        pub use crate::host::protos::core::*;
    }
}
57
/// How long `ParentProcessConnector::connect` waits for the spawned server process to
/// connect back to the host's listener before giving up.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
59
/// Marker for protobuf messages that may be sent or received as a whole, top-level
/// frame on a connection (see the `T: RootMessage` bounds on `ConnectionWrapper`).
trait RootMessage: Message {}
61
/// A live connection between the host and an r0vm server.
pub trait Connection {
    /// Returns the TCP stream that framed messages are written to and read from.
    fn stream(&mut self) -> &mut TcpStream;

    /// Shuts the connection down and returns an implementation-defined status code
    /// (the server's exit code for a spawned child process, `0` for a plain TCP
    /// connection).
    fn close(&mut self) -> Result<i32>;
}
66
/// Cloneable handle to a shared [`Connection`].
///
/// All clones refer to the same underlying connection through an `Arc<Mutex<..>>`,
/// so each send/receive takes the mutex for the duration of the operation.
#[derive(Clone)]
pub struct ConnectionWrapper {
    inner: Arc<Mutex<dyn Connection + Send>>,
}
71
thread_local! {
    /// Per-thread scratch buffer reused by `ConnectionWrapper::inner_send` and
    /// `ConnectionWrapper::inner_recv`, so framing a message does not allocate a fresh
    /// `Vec` each time. Its capacity grows to the largest frame seen on this thread.
    static LOCAL_BUF: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
}
75
76impl RootMessage for pb::api::HelloRequest {}
77impl RootMessage for pb::api::HelloReply {}
78impl RootMessage for pb::api::ServerRequest {}
79impl RootMessage for pb::api::ServerReply {}
80impl RootMessage for pb::api::GenericReply {}
81impl RootMessage for pb::api::OnIoReply {}
82impl RootMessage for pb::api::ProveKeccakReply {}
83impl RootMessage for pb::api::ProveSegmentReply {}
84impl RootMessage for pb::api::LiftRequest {}
85impl RootMessage for pb::api::LiftReply {}
86impl RootMessage for pb::api::JoinRequest {}
87impl RootMessage for pb::api::JoinReply {}
88impl RootMessage for pb::api::ResolveRequest {}
89impl RootMessage for pb::api::ResolveReply {}
90impl RootMessage for pb::api::IdentityP254Request {}
91impl RootMessage for pb::api::IdentityP254Reply {}
92impl RootMessage for pb::api::CompressRequest {}
93impl RootMessage for pb::api::CompressReply {}
94impl RootMessage for pb::api::UnionRequest {}
95impl RootMessage for pb::api::UnionReply {}
96
/// Error reported when a connection mutex cannot be locked, which for
/// `Mutex::lock` means the mutex was poisoned by a panicking holder.
fn lock_err() -> IoError {
    let kind = IoErrorKind::WouldBlock;
    IoError::new(kind, "Failed to lock connection mutex")
}
100
101impl ConnectionWrapper {
102 fn new(inner: Arc<Mutex<dyn Connection + Send>>) -> Self {
103 Self { inner }
104 }
105
106 fn send<T: RootMessage>(&mut self, msg: T) -> Result<()> {
107 let mut guard = self.inner.lock().map_err(|_| lock_err())?;
108 self.inner_send(guard.stream(), msg)
109 }
110
111 fn recv<T: Default + RootMessage>(&mut self) -> Result<T> {
112 let mut guard = self.inner.lock().map_err(|_| lock_err())?;
113 self.inner_recv(guard.stream())
114 }
115
116 #[cfg(feature = "prove")]
117 fn send_recv<S: RootMessage, R: Default + RootMessage>(&mut self, msg: S) -> Result<R> {
118 let mut guard = self.inner.lock().map_err(|_| lock_err())?;
119 let stream = guard.stream();
120 self.inner_send(stream, msg)?;
121 self.inner_recv(stream)
122 }
123
124 fn close(&mut self) -> Result<i32> {
125 self.inner.lock().map_err(|_| lock_err())?.close()
126 }
127
128 fn inner_send<T: RootMessage>(&self, stream: &mut TcpStream, msg: T) -> Result<()> {
129 let len = msg.encoded_len();
130 LOCAL_BUF.with_borrow_mut(|buf| {
131 buf.clear();
132 buf.put_u32_le(len as u32);
133 msg.encode(buf)?;
134 Ok(stream.write_all(buf)?)
135 })
136 }
137
138 fn inner_recv<T: Default + RootMessage>(&self, stream: &mut TcpStream) -> Result<T> {
139 LOCAL_BUF.with_borrow_mut(|buf| {
140 buf.resize(4, 0);
141 stream.read_exact(buf).context("rx len failed")?;
142 let len = buf.as_slice().get_u32_le() as usize;
143 buf.resize(len, 0);
144 stream.read_exact(buf).context("rx payload failed")?;
145 T::decode(buf.as_slice()).context("rx decode failed")
146 })
147 }
148}
149
/// Something that can establish a new connection to an r0vm server.
pub trait Connector {
    /// Establishes a connection and returns a shareable handle to it.
    fn connect(&self) -> Result<ConnectionWrapper>;
}
155
/// Connector that launches the r0vm server as a child process and waits for it to
/// connect back to a host-side TCP listener.
struct ParentProcessConnector {
    /// Path to the r0vm server executable to launch.
    server_path: PathBuf,
    /// Loopback listener whose port is passed to the server via `--port`.
    listener: TcpListener,
}
160
161impl ParentProcessConnector {
162 pub fn new<P: AsRef<Path>>(server_path: P) -> Result<Self> {
163 Ok(Self {
164 server_path: server_path.as_ref().to_path_buf(),
165 listener: TcpListener::bind("127.0.0.1:0")?,
166 })
167 }
168
169 pub fn new_narrow_version<P: AsRef<Path>>(server_path: P) -> Result<Self> {
170 let client_version = get_version().map_err(|err| anyhow!(err))?;
172 let server_version = get_server_version(&server_path)?;
173 if !client::Compat::Narrow.check(&client_version, &server_version) {
174 let server_suggestion = if client_version.pre.is_empty() {
175 format!(
176 "1. Install r0vm server version {}.{}\n",
177 client_version.major, client_version.minor
178 )
179 } else {
180 format!("1. Your risc0 dependencies are using a pre-released version {client_version}.\n \
181 If you encounter this error message when running code on the risc0 codebase, you must\n \
182 either run the command `git checkout origin/release-{}.{}` to checkout the version of the\n \
183 risc0 code that is compatible with your server or build the r0vm server from source\n \
184 https://github.com/risc0/risc0/blob/main/CONTRIBUTING.md\n", server_version.major, server_version.minor
185 )
186 };
187 let msg = format!(
188 "Your installation of the r0vm server is not compatible with your host's risc0-zkvm crate.\n\
189 Do one of the following to fix this issue:\n\n\
190 {server_suggestion}\
191 2. Change the risc0-zkvm and risc0-build dependencies in your project to {}.{}\n\n\
192 risc0-zkvm version: {client_version}\n\
193 r0vm server version: {server_version}", server_version.major, server_version.minor
194 );
195 tracing::warn!("{msg}");
196 bail!(msg);
197 }
198
199 Self::new(server_path)
200 }
201
202 fn spawn_fail(&self) -> String {
203 format!(
204 "Could not launch zkvm: \"{}\". \n
205 Use `cargo binstall cargo-risczero` to install the latest zkvm.",
206 self.server_path.to_string_lossy()
207 )
208 }
209}
210
211fn get_server_version<P: AsRef<Path>>(server_path: P) -> Result<Version> {
212 let output = Command::new(server_path.as_ref().as_os_str())
213 .arg("--version")
214 .output()?;
215 let cmd_output = String::from_utf8(output.stdout)?;
216 let (_, version_str) = regex_captures!(r".* (.*)\n$", &cmd_output)
217 .ok_or_else(|| anyhow!("failed to parse server version number"))?;
218 Version::parse(version_str).map_err(|e| anyhow!(e))
219}
220
221impl Connector for ParentProcessConnector {
222 fn connect(&self) -> Result<ConnectionWrapper> {
223 let addr = self.listener.local_addr()?;
224 let child = Command::new(&self.server_path)
225 .arg("--port")
226 .arg(addr.port().to_string())
227 .spawn()
228 .with_context(|| self.spawn_fail())?;
229
230 let shutdown = Arc::new(AtomicBool::new(false));
231 let server_shutdown = shutdown.clone();
232 let (tx, rx) = channel();
233 let listener = self.listener.try_clone()?;
234 let handle = thread::spawn(move || {
235 let stream = listener.accept();
236 if server_shutdown.load(Ordering::Relaxed) {
237 return;
238 }
239 if let Ok((stream, _addr)) = stream {
240 tx.send(stream).unwrap();
241 }
242 });
243
244 let stream = rx.recv_timeout(CONNECT_TIMEOUT);
245 let stream = stream.inspect_err(|_| {
246 shutdown.store(true, Ordering::Relaxed);
247 let _ = TcpStream::connect(addr);
248 handle.join().unwrap();
249 })?;
250
251 Ok(ConnectionWrapper::new(Arc::new(Mutex::new(
252 ParentProcessConnection::new(child, stream),
253 ))))
254 }
255}
256
/// Connector that opens a plain TCP connection to a server at a fixed address,
/// rather than spawning one.
#[cfg(feature = "prove")]
struct TcpConnector {
    /// Address passed to `TcpStream::connect` (e.g. `"host:port"`).
    addr: String,
}
261
#[cfg(feature = "prove")]
impl TcpConnector {
    /// Creates a connector for `addr`. The address is stored as given and is only
    /// resolved when `connect` is called.
    pub(crate) fn new(addr: &str) -> Self {
        let addr = String::from(addr);
        Self { addr }
    }
}
270
#[cfg(feature = "prove")]
impl Connector for TcpConnector {
    /// Opens a TCP connection to the configured address and wraps it.
    fn connect(&self) -> Result<ConnectionWrapper> {
        tracing::debug!("connect");
        let stream = TcpStream::connect(self.addr.as_str())?;
        let shared: Arc<Mutex<dyn Connection + Send>> =
            Arc::new(Mutex::new(TcpConnection::new(stream)));
        Ok(ConnectionWrapper::new(shared))
    }
}
281
/// Connection to a server that this host spawned as a child process.
struct ParentProcessConnection {
    /// The spawned server process; waited on in `close`.
    child: Child,
    /// The TCP stream the server connected back on.
    stream: TcpStream,
}
286
/// Connection to a server reached over plain TCP (no child process to manage).
#[cfg(feature = "prove")]
struct TcpConnection {
    /// The connected TCP stream.
    stream: TcpStream,
}
291
292impl ParentProcessConnection {
293 pub fn new(child: Child, stream: TcpStream) -> Self {
294 Self { child, stream }
295 }
296}
297
298impl Connection for ParentProcessConnection {
299 fn stream(&mut self) -> &mut TcpStream {
300 &mut self.stream
301 }
302
303 fn close(&mut self) -> Result<i32> {
304 let status = self.child.wait()?;
305 Ok(status.code().unwrap_or_default())
306 }
307}
308
#[cfg(feature = "prove")]
impl TcpConnection {
    /// Wraps an already-connected TCP stream.
    pub fn new(stream: TcpStream) -> Self {
        TcpConnection { stream }
    }
}
315
#[cfg(feature = "prove")]
impl Connection for TcpConnection {
    /// Returns the wrapped stream.
    fn stream(&mut self) -> &mut TcpStream {
        let Self { stream } = self;
        stream
    }

    /// There is no child process to reap here, so this always reports `0`.
    fn close(&mut self) -> Result<i32> {
        Ok(0_i32)
    }
}
326
327fn malformed_err(field: &str) -> anyhow::Error {
328 anyhow!("Malformed error: {field}")
329}
330
331impl pb::api::Asset {
332 fn as_bytes(&self) -> Result<Bytes> {
333 let bytes = match self
334 .kind
335 .as_ref()
336 .ok_or_else(|| malformed_err("Asset.kind"))?
337 {
338 pb::api::asset::Kind::Inline(bytes) => bytes.clone(),
339 pb::api::asset::Kind::Path(path) => std::fs::read(path)?,
340 pb::api::asset::Kind::Redis(_) => bail!("as_bytes not supported for redis"),
341 };
342 Ok(bytes.into())
343 }
344}
345
/// Where an asset's bytes live.
#[derive(Clone)]
pub enum Asset {
    /// The bytes are held in memory.
    Inline(Bytes),

    /// The bytes are stored in a file at this path (read by `Asset::as_bytes`).
    Path(PathBuf),

    /// The bytes are stored in Redis. The string is presumably the key or locator;
    /// TODO confirm. `Asset::as_bytes` does not support this variant.
    Redis(String),
}
358
/// Parameters for storing an asset in Redis.
#[derive(Clone)]
pub struct RedisParams {
    /// Redis connection URL.
    pub url: String,

    /// Key under which the asset is stored.
    pub key: String,

    /// Time-to-live for the stored entry. Units are not visible here (presumably
    /// seconds); TODO confirm.
    pub ttl: u64,
}
371
/// How a caller asks for a produced asset to be delivered. This is inferred from the
/// variant names (confirm against the callers).
#[derive(Clone)]
pub enum AssetRequest {
    /// Return the bytes inline.
    Inline,

    /// Write the asset to this path.
    Path(PathBuf),

    /// Store the asset in Redis using these parameters.
    Redis(RedisParams),
}
384
/// Summary information about an execution session.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct SessionInfo {
    /// Per-segment information, in the order reported (see [`SegmentInfo`]).
    pub segments: Vec<SegmentInfo>,

    /// The session's journal.
    pub journal: Journal,

    /// The exit code the session finished with.
    pub exit_code: ExitCode,

    /// The receipt claim for the session, when one is available. The conditions
    /// under which it is `None` are not visible here.
    pub receipt_claim: Option<ReceiptClaim>,
}
402
403impl SessionInfo {
404 pub fn cycles(&self) -> u64 {
407 self.segments.iter().map(|s| s.cycles as u64).sum()
408 }
409}
410
/// Summary information about a single segment of a session.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct SegmentInfo {
    /// Segment size as a power of two (presumably log2 of the cycle capacity);
    /// TODO confirm.
    pub po2: u32,

    /// Number of cycles in this segment (summed by `SessionInfo::cycles`).
    pub cycles: u32,
}
422
423impl Asset {
424 pub fn as_bytes(&self) -> Result<Bytes> {
426 Ok(match self {
427 Asset::Inline(bytes) => bytes.clone(),
428 Asset::Path(path) => std::fs::read(path)?.into(),
429 Asset::Redis(_) => bail!("as_bytes not supported for Asset::Redis"),
430 })
431 }
432}
433
434fn invalid_path() -> anyhow::Error {
435 anyhow::Error::msg("Path must be UTF-8")
436}
437
438fn path_to_string<P: AsRef<Path>>(path: P) -> Result<String> {
439 Ok(path.as_ref().to_str().ok_or_else(invalid_path)?.to_string())
440}