1pub(crate) mod client;
16pub(crate) mod convert;
17#[cfg(feature = "prove")]
18pub(crate) mod server;
19#[cfg(test)]
20#[cfg(feature = "prove")]
21mod tests;
22
23use std::{
24 cell::RefCell,
25 io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write},
26 net::{TcpListener, TcpStream},
27 path::{Path, PathBuf},
28 process::{Child, Command},
29 sync::{
30 atomic::{AtomicBool, Ordering},
31 mpsc::channel,
32 Arc, Mutex,
33 },
34 thread,
35 time::Duration,
36};
37
38use anyhow::{anyhow, bail, Context, Result};
39use bytes::{Buf, BufMut, Bytes};
40use lazy_regex::regex_captures;
41use prost::Message;
42use semver::Version;
43
44use crate::{get_version, ExitCode, Journal, ReceiptClaim};
45
/// Short aliases for the protobuf message modules under
/// `crate::host::protos`, so the rest of this module can refer to them as
/// `pb::api::...`, `pb::base::...` and `pb::core::...`.
mod pb {
    /// Re-exports everything from `crate::host::protos::api`.
    pub(crate) mod api {
        pub use crate::host::protos::api::*;
    }
    /// Re-exports everything from `crate::host::protos::base`.
    pub(crate) mod base {
        pub use crate::host::protos::base::*;
    }
    /// Re-exports everything from `crate::host::protos::core`.
    pub(crate) mod core {
        pub use crate::host::protos::core::*;
    }
}
57
/// Maximum time `ParentProcessConnector::connect` waits for the freshly
/// spawned server process to connect back to the local listener (10 s).
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
59
/// Marker for protobuf messages that may be sent or received as a complete,
/// length-prefixed frame by `ConnectionWrapper` (its send/recv helpers are
/// bounded on this trait).
trait RootMessage: Message {}
61
/// A live link to an r0vm server.
pub trait Connection {
    /// Returns the TCP stream on which framed messages are exchanged.
    fn stream(&mut self) -> &mut TcpStream;

    /// Ends the connection and returns an exit code. The implementations in
    /// this file either wait for the spawned server process or return `0`.
    fn close(&mut self) -> Result<i32>;
}
66
/// Cloneable handle to a shared [`Connection`].
///
/// All clones point at the same connection. The mutex serializes access to
/// the underlying stream, and `Send` lets the handle move between threads.
#[derive(Clone)]
pub struct ConnectionWrapper {
    inner: Arc<Mutex<dyn Connection + Send>>,
}
71
thread_local! {
    /// Per-thread scratch buffer reused by `ConnectionWrapper` to encode
    /// outgoing frames and receive incoming ones, so each message does not
    /// need a fresh allocation.
    static LOCAL_BUF: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
}
75
76impl RootMessage for pb::api::HelloRequest {}
77impl RootMessage for pb::api::HelloReply {}
78impl RootMessage for pb::api::ServerRequest {}
79impl RootMessage for pb::api::ServerReply {}
80impl RootMessage for pb::api::GenericReply {}
81impl RootMessage for pb::api::OnIoReply {}
82impl RootMessage for pb::api::ProveKeccakReply {}
83impl RootMessage for pb::api::ProveSegmentReply {}
84impl RootMessage for pb::api::ProveZkrReply {}
85impl RootMessage for pb::api::LiftRequest {}
86impl RootMessage for pb::api::LiftReply {}
87impl RootMessage for pb::api::JoinRequest {}
88impl RootMessage for pb::api::JoinReply {}
89impl RootMessage for pb::api::ResolveRequest {}
90impl RootMessage for pb::api::ResolveReply {}
91impl RootMessage for pb::api::IdentityP254Request {}
92impl RootMessage for pb::api::IdentityP254Reply {}
93impl RootMessage for pb::api::CompressRequest {}
94impl RootMessage for pb::api::CompressReply {}
95
/// Builds the error reported when the connection mutex cannot be locked.
///
/// `Mutex::lock` only fails when the mutex is poisoned (a thread panicked
/// while holding it). That condition is permanent, so the error uses
/// `ErrorKind::Other`. `WouldBlock` would wrongly suggest that retrying
/// could succeed.
fn lock_err() -> IoError {
    IoError::other("Failed to lock connection mutex")
}
99
100impl ConnectionWrapper {
101 fn new(inner: Arc<Mutex<dyn Connection + Send>>) -> Self {
102 Self { inner }
103 }
104
105 fn send<T: RootMessage>(&mut self, msg: T) -> Result<()> {
106 let mut guard = self.inner.lock().map_err(|_| lock_err())?;
107 self.inner_send(guard.stream(), msg)
108 }
109
110 fn recv<T: Default + RootMessage>(&mut self) -> Result<T> {
111 let mut guard = self.inner.lock().map_err(|_| lock_err())?;
112 self.inner_recv(guard.stream())
113 }
114
115 #[cfg(feature = "prove")]
116 fn send_recv<S: RootMessage, R: Default + RootMessage>(&mut self, msg: S) -> Result<R> {
117 let mut guard = self.inner.lock().map_err(|_| lock_err())?;
118 let stream = guard.stream();
119 self.inner_send(stream, msg)?;
120 self.inner_recv(stream)
121 }
122
123 fn close(&mut self) -> Result<i32> {
124 self.inner.lock().map_err(|_| lock_err())?.close()
125 }
126
127 fn inner_send<T: RootMessage>(&self, stream: &mut TcpStream, msg: T) -> Result<()> {
128 let len = msg.encoded_len();
129 LOCAL_BUF.with_borrow_mut(|buf| {
130 buf.clear();
131 buf.put_u32_le(len as u32);
132 msg.encode(buf)?;
133 Ok(stream.write_all(buf)?)
134 })
135 }
136
137 fn inner_recv<T: Default + RootMessage>(&self, stream: &mut TcpStream) -> Result<T> {
138 LOCAL_BUF.with_borrow_mut(|buf| {
139 buf.resize(4, 0);
140 stream.read_exact(buf).context("rx len failed")?;
141 let len = buf.as_slice().get_u32_le() as usize;
142 buf.resize(len, 0);
143 stream.read_exact(buf).context("rx payload failed")?;
144 T::decode(buf.as_slice()).context("rx decode failed")
145 })
146 }
147}
148
/// Something that can open a [`ConnectionWrapper`] to an r0vm server.
pub trait Connector {
    /// Opens a new connection to the server.
    fn connect(&self) -> Result<ConnectionWrapper>;
}
154
/// Connector that launches the r0vm server as a child process and waits for
/// it to connect back to a local listener.
struct ParentProcessConnector {
    /// Path of the server executable to spawn.
    server_path: PathBuf,
    /// Listener bound to an OS-assigned port on `127.0.0.1`. Its port is
    /// passed to the child with `--port`.
    listener: TcpListener,
}
159
160impl ParentProcessConnector {
161 pub fn new<P: AsRef<Path>>(server_path: P) -> Result<Self> {
162 let client_version = get_version().map_err(|err| anyhow!(err))?;
164 let server_version = get_server_version(&server_path)?;
165 if !client::check_server_version(&client_version, &server_version) {
166 let server_suggestion = if client_version.pre.is_empty() {
167 format!(
168 "1. Install r0vm server version {}.{}\n",
169 server_version.major, server_version.minor
170 )
171 } else {
172 format!("1. Your risc0 dependencies are using a pre-released version {client_version}.\n \
173 If you encounter this error message when running code on the risc0 codebase, you must\n \
174 either run the command `git checkout origin/release-{}.{}` to checkout the version of the\n \
175 risc0 code that is compatible with your server or build the r0vm server from source\n \
176 https://github.com/risc0/risc0/blob/main/CONTRIBUTING.md\n", server_version.major, server_version.minor
177 )
178 };
179 let msg = format!(
180 "Your installation of the r0vm server is not compatible with your host's risc0-zkvm crate.\n\
181 Do one of the following to fix this issue:\n\n\
182 {server_suggestion}\
183 2. Change the risc0-zkvm and risc0-build dependencies in your project to {}.{}\n\n\
184 risc0-zkvm version: {client_version}\n\
185 r0vm server version: {server_version}", server_version.major, server_version.minor
186 );
187 tracing::warn!("{msg}");
188 bail!(msg);
189 }
190
191 Ok(Self {
192 server_path: server_path.as_ref().to_path_buf(),
193 listener: TcpListener::bind("127.0.0.1:0")?,
194 })
195 }
196
197 pub fn new_wide_version<P: AsRef<Path>>(server_path: P) -> Result<Self> {
198 let client_version = get_version().map_err(|err| anyhow!(err))?;
199 let server_version = get_server_version(&server_path)?;
200
201 if !client::check_server_version_wide(&client_version, &server_version) {
202 let msg = format!(
203 "Your installation of r0vm differs by a major version:\n\
204 {client_version} vs {server_version} only minor, patch / pre-releases supported"
205 );
206 tracing::warn!("{msg}");
207 bail!(msg);
208 }
209 Ok(Self {
210 server_path: server_path.as_ref().to_path_buf(),
211 listener: TcpListener::bind("127.0.0.1:0")?,
212 })
213 }
214
215 fn spawn_fail(&self) -> String {
216 format!(
217 "Could not launch zkvm: \"{}\". \n
218 Use `cargo binstall cargo-risczero` to install the latest zkvm.",
219 self.server_path.to_string_lossy()
220 )
221 }
222}
223
224fn get_server_version<P: AsRef<Path>>(server_path: P) -> Result<Version> {
225 let output = Command::new(server_path.as_ref().as_os_str())
226 .arg("--version")
227 .output()?;
228 let cmd_output = String::from_utf8(output.stdout)?;
229 let (_, version_str) = regex_captures!(r".* (.*)\n$", &cmd_output)
230 .ok_or(anyhow!("failed to parse server version number"))?;
231 Version::parse(version_str).map_err(|e| anyhow!(e))
232}
233
234impl Connector for ParentProcessConnector {
235 fn connect(&self) -> Result<ConnectionWrapper> {
236 let addr = self.listener.local_addr()?;
237 let child = Command::new(&self.server_path)
238 .arg("--port")
239 .arg(addr.port().to_string())
240 .spawn()
241 .with_context(|| self.spawn_fail())?;
242
243 let shutdown = Arc::new(AtomicBool::new(false));
244 let server_shutdown = shutdown.clone();
245 let (tx, rx) = channel();
246 let listener = self.listener.try_clone()?;
247 let handle = thread::spawn(move || {
248 let stream = listener.accept();
249 if server_shutdown.load(Ordering::Relaxed) {
250 return;
251 }
252 if let Ok((stream, _addr)) = stream {
253 tx.send(stream).unwrap();
254 }
255 });
256
257 let stream = rx.recv_timeout(CONNECT_TIMEOUT);
258 let stream = stream.inspect_err(|_| {
259 shutdown.store(true, Ordering::Relaxed);
260 let _ = TcpStream::connect(addr);
261 handle.join().unwrap();
262 })?;
263
264 Ok(ConnectionWrapper::new(Arc::new(Mutex::new(
265 ParentProcessConnection::new(child, stream),
266 ))))
267 }
268}
269
/// Connector that dials an already running r0vm server over TCP.
#[cfg(feature = "prove")]
struct TcpConnector {
    /// Address handed to `TcpStream::connect` (a `host:port` string).
    addr: String,
}
274
#[cfg(feature = "prove")]
impl TcpConnector {
    /// Creates a connector that will dial `addr` on every `connect` call.
    pub(crate) fn new(addr: &str) -> Self {
        let addr = addr.to_owned();
        Self { addr }
    }
}
283
#[cfg(feature = "prove")]
impl Connector for TcpConnector {
    /// Opens a fresh TCP connection to `self.addr`.
    ///
    /// # Errors
    ///
    /// Fails if the TCP connection cannot be established. The error names
    /// the address that was dialed.
    fn connect(&self) -> Result<ConnectionWrapper> {
        tracing::debug!("connect");
        let stream = TcpStream::connect(&self.addr)
            .with_context(|| format!("failed to connect to r0vm server at {}", self.addr))?;
        Ok(ConnectionWrapper::new(Arc::new(Mutex::new(
            TcpConnection::new(stream),
        ))))
    }
}
294
/// Connection to a server process spawned by `ParentProcessConnector`.
struct ParentProcessConnection {
    /// The spawned server process; waited on in `close`.
    child: Child,
    /// Stream on which the child connected back.
    stream: TcpStream,
}
299
/// Connection to an externally managed server reached over TCP.
#[cfg(feature = "prove")]
struct TcpConnection {
    /// The connected socket.
    stream: TcpStream,
}
304
305impl ParentProcessConnection {
306 pub fn new(child: Child, stream: TcpStream) -> Self {
307 Self { child, stream }
308 }
309}
310
311impl Connection for ParentProcessConnection {
312 fn stream(&mut self) -> &mut TcpStream {
313 &mut self.stream
314 }
315
316 fn close(&mut self) -> Result<i32> {
317 let status = self.child.wait()?;
318 Ok(status.code().unwrap_or_default())
319 }
320}
321
#[cfg(feature = "prove")]
impl TcpConnection {
    /// Wraps an already connected socket.
    pub fn new(stream: TcpStream) -> Self {
        TcpConnection { stream }
    }
}
328
#[cfg(feature = "prove")]
impl Connection for TcpConnection {
    /// Returns the wrapped socket.
    fn stream(&mut self) -> &mut TcpStream {
        let Self { stream } = self;
        stream
    }

    /// This connection owns no process to wait for, so closing always
    /// reports exit code `0`.
    fn close(&mut self) -> Result<i32> {
        const SUCCESS: i32 = 0;
        Ok(SUCCESS)
    }
}
339
340fn malformed_err() -> anyhow::Error {
341 anyhow!("Malformed error")
342}
343
344impl pb::api::Asset {
345 fn as_bytes(&self) -> Result<Bytes> {
346 let bytes = match self.kind.as_ref().ok_or(malformed_err())? {
347 pb::api::asset::Kind::Inline(bytes) => bytes.clone(),
348 pb::api::asset::Kind::Path(path) => std::fs::read(path)?,
349 pb::api::asset::Kind::Redis(_) => bail!("as_bytes not supported for redis"),
350 };
351 Ok(bytes.into())
352 }
353}
354
/// Where the bytes of a binary blob live.
#[derive(Clone)]
pub enum Asset {
    /// The bytes themselves, held in memory.
    Inline(Bytes),

    /// A local file containing the bytes.
    Path(PathBuf),

    /// Bytes stored in Redis. The string presumably identifies the entry
    /// (key or URL); confirm against the code that writes it.
    Redis(String),
}
367
/// Settings for storing an asset in Redis (carried by `AssetRequest::Redis`).
#[derive(Clone)]
pub struct RedisParams {
    /// URL of the Redis server.
    pub url: String,

    /// Key under which the asset is stored.
    pub key: String,

    /// Time-to-live for the stored key. The unit is not visible here;
    /// presumably seconds, TODO confirm.
    pub ttl: u64,
}
380
/// How a caller asks for an output asset to be delivered (presumably
/// mirrors the [`Asset`] variants; confirm against the server side).
#[derive(Clone)]
pub enum AssetRequest {
    /// Return the bytes inline.
    Inline,

    /// Write the bytes to this filesystem path.
    Path(PathBuf),

    /// Store the bytes in Redis using these parameters.
    Redis(RedisParams),
}
393
/// Summary of an executed session.
#[derive(Clone, Debug)]
pub struct SessionInfo {
    /// Per-segment statistics; see [`SessionInfo::cycles`] for their total.
    pub segments: Vec<SegmentInfo>,

    /// The session's journal.
    pub journal: Journal,

    /// How the session terminated.
    pub exit_code: ExitCode,

    /// Claim describing the session's receipt, if one is available.
    pub receipt_claim: Option<ReceiptClaim>,
}
410
411impl SessionInfo {
412 pub fn cycles(&self) -> u64 {
415 self.segments.iter().map(|s| s.cycles as u64).sum()
416 }
417}
418
/// Statistics for one segment of a session.
#[derive(Clone, Debug)]
pub struct SegmentInfo {
    /// Power-of-two exponent describing the segment size (presumably the
    /// segment spans `2^po2` cycles; confirm).
    pub po2: u32,

    /// Number of cycles counted for this segment.
    pub cycles: u32,
}
429
430impl Asset {
431 pub fn as_bytes(&self) -> Result<Bytes> {
433 Ok(match self {
434 Asset::Inline(bytes) => bytes.clone(),
435 Asset::Path(path) => std::fs::read(path)?.into(),
436 Asset::Redis(_) => bail!("as_bytes not supported for Asset::Redis"),
437 })
438 }
439}
440
441fn invalid_path() -> anyhow::Error {
442 anyhow::Error::msg("Path must be UTF-8")
443}
444
445fn path_to_string<P: AsRef<Path>>(path: P) -> Result<String> {
446 Ok(path.as_ref().to_str().ok_or(invalid_path())?.to_string())
447}