1use std::{
16 error::Error as StdError,
17 io::{BufReader, Error as IoError, ErrorKind as IoErrorKind, Read, Write},
18 path::{Path, PathBuf},
19};
20
21use anyhow::{anyhow, bail, Context, Result};
22use bytes::Bytes;
23use prost::Message;
24use risc0_zkp::core::digest::Digest;
25
26use super::{malformed_err, path_to_string, pb, ConnectionWrapper, Connector, TcpConnector};
27use crate::{
28 get_prover_server, get_version,
29 host::{
30 api::convert::keccak_input_to_bytes,
31 client::{
32 env::{CoprocessorCallback, ProveKeccakRequest, ProveZkrRequest},
33 slice_io::SliceIo,
34 },
35 server::{prove::keccak::prove_keccak, session::NullSegmentRef},
36 },
37 prove_registered_zkr,
38 recursion::identity_p254,
39 AssetRequest, Assumption, ExecutorEnv, ExecutorImpl, InnerAssumptionReceipt, ProverOpts,
40 Receipt, ReceiptClaim, Segment, SegmentReceipt, Session, SuccinctReceipt, TraceCallback,
41 TraceEvent, VerifierContext,
42};
43
/// Server side of the r0vm client/server API.
///
/// Each call to [`Server::run`] obtains one connection from `connector`, performs the
/// version handshake, and serves a single request on that connection.
pub struct Server {
    /// Produces the connection used by [`Server::run`].
    connector: Box<dyn Connector>,
}
48struct PosixIoProxy {
49 fd: u32,
50 conn: ConnectionWrapper,
51}
52
53impl PosixIoProxy {
54 fn new(fd: u32, conn: ConnectionWrapper) -> Self {
55 PosixIoProxy { fd, conn }
56 }
57}
58
59impl Read for PosixIoProxy {
60 fn read(&mut self, to_guest: &mut [u8]) -> std::io::Result<usize> {
61 let nread = to_guest.len().try_into().map_io_err()?;
62 let request = pb::api::ServerReply {
63 kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
64 kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
65 kind: Some(pb::api::on_io_request::Kind::Posix(pb::api::PosixIo {
66 fd: self.fd,
67 cmd: Some(pb::api::PosixCmd {
68 kind: Some(pb::api::posix_cmd::Kind::Read(nread)),
69 }),
70 })),
71 })),
72 })),
73 };
74
75 tracing::trace!("tx: {request:?}");
76 let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
77 tracing::trace!("rx: {reply:?}");
78
79 let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
80 match kind {
81 pb::api::on_io_reply::Kind::Ok(bytes) => {
82 let (head, _) = to_guest.split_at_mut(bytes.len());
83 head.copy_from_slice(&bytes);
84 Ok(bytes.len())
85 }
86 pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
87 }
88 }
89}
90
91impl Write for PosixIoProxy {
92 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
93 let request = pb::api::ServerReply {
94 kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
95 kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
96 kind: Some(pb::api::on_io_request::Kind::Posix(pb::api::PosixIo {
97 fd: self.fd,
98 cmd: Some(pb::api::PosixCmd {
99 kind: Some(pb::api::posix_cmd::Kind::Write(buf.into())),
100 }),
101 })),
102 })),
103 })),
104 };
105
106 tracing::trace!("tx: {request:?}");
107 let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
108 tracing::trace!("rx: {reply:?}");
109
110 let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
111 match kind {
112 pb::api::on_io_reply::Kind::Ok(_) => Ok(buf.len()),
113 pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
114 }
115 }
116
117 fn flush(&mut self) -> std::io::Result<()> {
118 Ok(())
119 }
120}
121
122#[derive(Clone)]
123struct SliceIoProxy {
124 conn: ConnectionWrapper,
125}
126
127impl SliceIoProxy {
128 fn new(conn: ConnectionWrapper) -> Self {
129 Self { conn }
130 }
131}
132
133impl SliceIo for SliceIoProxy {
134 fn handle_io(&mut self, syscall: &str, from_guest: Bytes) -> Result<Bytes> {
135 let request = pb::api::ServerReply {
136 kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
137 kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
138 kind: Some(pb::api::on_io_request::Kind::Slice(pb::api::SliceIo {
139 name: syscall.to_string(),
140 from_guest: from_guest.into(),
141 })),
142 })),
143 })),
144 };
145 tracing::trace!("tx: {request:?}");
146 let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
147 tracing::trace!("rx: {reply:?}");
148
149 let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
150 match kind {
151 pb::api::on_io_reply::Kind::Ok(buf) => Ok(buf.into()),
152 pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
153 }
154 }
155}
156
157struct TraceProxy {
158 conn: ConnectionWrapper,
159}
160
161impl TraceProxy {
162 fn new(conn: ConnectionWrapper) -> Self {
163 Self { conn }
164 }
165}
166
167impl TraceCallback for TraceProxy {
168 fn trace_callback(&mut self, event: TraceEvent) -> Result<()> {
169 let Ok(event) = event.clone().try_into() else {
170 tracing::trace!("ignoring unknown event {event:?}");
171 return Ok(());
172 };
173
174 let request = pb::api::ServerReply {
175 kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
176 kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
177 kind: Some(pb::api::on_io_request::Kind::Trace(event)),
178 })),
179 })),
180 };
181 tracing::trace!("tx: {request:?}");
182 let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
183 tracing::trace!("rx: {reply:?}");
184
185 let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
186 match kind {
187 pb::api::on_io_reply::Kind::Ok(_) => Ok(()),
188 pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
189 }
190 }
191}
192
193struct CoprocessorProxy {
194 conn: ConnectionWrapper,
195}
196
197impl CoprocessorProxy {
198 fn new(conn: ConnectionWrapper) -> Self {
199 Self { conn }
200 }
201}
202
203impl CoprocessorCallback for CoprocessorProxy {
204 fn prove_zkr(&mut self, proof_request: ProveZkrRequest) -> Result<()> {
205 let request = pb::api::ServerReply {
206 kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
207 kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
208 kind: Some(pb::api::on_io_request::Kind::Coprocessor(
209 pb::api::CoprocessorRequest {
210 kind: Some(pb::api::coprocessor_request::Kind::ProveZkr({
211 pb::api::ProveZkrRequest {
212 claim_digest: Some(proof_request.claim_digest.into()),
213 control_id: Some(proof_request.control_id.into()),
214 input: proof_request.input,
215 receipt_out: None,
216 }
217 })),
218 },
219 )),
220 })),
221 })),
222 };
223 tracing::trace!("tx: {request:?}");
224 let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
225 tracing::trace!("rx: {reply:?}");
226
227 let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
228 match kind {
229 pb::api::on_io_reply::Kind::Ok(_) => Ok(()),
230 pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
231 }
232 }
233
234 fn prove_keccak(&mut self, proof_request: ProveKeccakRequest) -> Result<()> {
235 let input = keccak_input_to_bytes(&proof_request.input);
236 let request = pb::api::ServerReply {
237 kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
238 kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
239 kind: Some(pb::api::on_io_request::Kind::Coprocessor(
240 pb::api::CoprocessorRequest {
241 kind: Some(pb::api::coprocessor_request::Kind::ProveKeccak({
242 pb::api::ProveKeccakRequest {
243 claim_digest: Some(proof_request.claim_digest.into()),
244 po2: proof_request.po2 as u32,
245 control_root: Some(proof_request.control_root.into()),
246 input,
247 receipt_out: None,
248 }
249 })),
250 },
251 )),
252 })),
253 })),
254 };
255 tracing::trace!("tx: {request:?}");
256 self.conn.send(request)?;
257
258 let reply: pb::api::OnIoReply = self.conn.recv().map_io_err()?;
259 tracing::trace!("rx: {reply:?}");
260
261 let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
262 match kind {
263 pb::api::on_io_reply::Kind::Ok(_) => Ok(()),
264 pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
265 }
266 }
267}
268
269impl Server {
270 pub fn new(connector: Box<dyn Connector>) -> Self {
272 Self { connector }
273 }
274
275 pub fn new_tcp<A: AsRef<str>>(addr: A) -> Self {
278 let connector = TcpConnector::new(addr.as_ref());
279 Self::new(Box::new(connector))
280 }
281
282 pub fn run(&self) -> Result<()> {
284 tracing::debug!("connect");
285 let mut conn = self.connector.connect()?;
286
287 let server_version = get_version().map_err(|err| anyhow!(err))?;
288
289 let request: pb::api::HelloRequest = conn.recv()?;
290 tracing::trace!("rx: {request:?}");
291
292 let client_version: semver::Version = request
293 .version
294 .ok_or(malformed_err())?
295 .try_into()
296 .map_err(|err: semver::Error| anyhow!(err))?;
297
298 #[cfg(not(feature = "r0vm-ver-compat"))]
299 let check_client_func = check_client_version;
300 #[cfg(feature = "r0vm-ver-compat")]
301 let check_client_func = check_client_version_compat;
302
303 if !check_client_func(&client_version, &server_version) {
304 let msg = format!(
305 "incompatible client version: {client_version}, server version: {server_version}"
306 );
307 tracing::debug!("{msg}");
308 bail!(msg);
309 }
310
311 let reply = pb::api::HelloReply {
312 kind: Some(pb::api::hello_reply::Kind::Ok(pb::api::HelloResult {
313 version: Some(server_version.into()),
314 })),
315 };
316 tracing::trace!("tx: {reply:?}");
317 let request: pb::api::ServerRequest = conn.send_recv(reply)?;
318 tracing::trace!("rx: {request:?}");
319
320 match request.kind.ok_or(malformed_err())? {
321 pb::api::server_request::Kind::Prove(request) => self.on_prove(conn, request),
322 pb::api::server_request::Kind::Execute(request) => self.on_execute(conn, request),
323 pb::api::server_request::Kind::ProveSegment(request) => {
324 self.on_prove_segment(conn, request)
325 }
326 pb::api::server_request::Kind::Lift(request) => self.on_lift(conn, request),
327 pb::api::server_request::Kind::Join(request) => self.on_join(conn, request),
328 pb::api::server_request::Kind::Resolve(request) => self.on_resolve(conn, request),
329 pb::api::server_request::Kind::IdentityP254(request) => {
330 self.on_identity_p254(conn, request)
331 }
332 pb::api::server_request::Kind::Compress(request) => self.on_compress(conn, request),
333 pb::api::server_request::Kind::Verify(request) => self.on_verify(conn, request),
334 pb::api::server_request::Kind::ProveZkr(request) => self.on_prove_zkr(conn, request),
335 pb::api::server_request::Kind::ProveKeccak(request) => {
336 self.on_prove_keccak(conn, request)
337 }
338 }
339 }
340
341 fn on_execute(
342 &self,
343 mut conn: ConnectionWrapper,
344 request: pb::api::ExecuteRequest,
345 ) -> Result<()> {
346 fn inner(
347 conn: &mut ConnectionWrapper,
348 request: pb::api::ExecuteRequest,
349 ) -> Result<pb::api::ServerReply> {
350 let env_request = request.env.ok_or(malformed_err())?;
351 let env = build_env(conn, &env_request)?;
352
353 let binary = env_request.binary.ok_or(malformed_err())?;
354
355 let segments_out = request.segments_out.ok_or(malformed_err())?;
356 let bytes = binary.as_bytes()?;
357 let mut exec = ExecutorImpl::from_elf(env, &bytes)?;
358
359 let session = match AssetRequest::try_from(segments_out.clone())? {
360 #[cfg(feature = "redis")]
361 AssetRequest::Redis(params) => execute_redis(conn, &mut exec, params)?,
362 _ => execute_default(conn, &mut exec, &segments_out)?,
363 };
364
365 let receipt_claim = session.claim()?;
366 Ok(pb::api::ServerReply {
367 kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
368 kind: Some(pb::api::client_callback::Kind::SessionDone(
369 pb::api::OnSessionDone {
370 session: Some(pb::api::SessionInfo {
371 segments: session.segments.len().try_into()?,
372 journal: session.journal.unwrap_or_default().bytes,
373 exit_code: Some(session.exit_code.into()),
374 receipt_claim: Some(pb::api::Asset::from_bytes(
375 &pb::api::AssetRequest {
376 kind: Some(pb::api::asset_request::Kind::Inline(())),
377 },
378 pb::core::ReceiptClaim::from(receipt_claim)
379 .encode_to_vec()
380 .into(),
381 "session_info.claim",
382 )?),
383 }),
384 },
385 )),
386 })),
387 })
388 }
389
390 let msg = inner(&mut conn, request).unwrap_or_else(|err| pb::api::ServerReply {
391 kind: Some(pb::api::server_reply::Kind::Error(pb::api::GenericError {
392 reason: err.to_string(),
393 })),
394 });
395
396 tracing::trace!("tx: {msg:?}");
397 conn.send(msg)
398 }
399
400 fn on_prove(&self, mut conn: ConnectionWrapper, request: pb::api::ProveRequest) -> Result<()> {
401 fn inner(
402 conn: &mut ConnectionWrapper,
403 request: pb::api::ProveRequest,
404 ) -> Result<pb::api::ServerReply> {
405 let env_request = request.env.ok_or(malformed_err())?;
406 let env = build_env(conn, &env_request)?;
407
408 let binary = env_request.binary.ok_or(malformed_err())?;
409 let bytes = binary.as_bytes()?;
410
411 let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
412 let prover = get_prover_server(&opts)?;
413 let ctx = VerifierContext::default();
414 let prove_info = prover.prove_with_ctx(env, &ctx, &bytes)?;
415
416 let prove_info: pb::core::ProveInfo = prove_info.into();
417 let prove_info_bytes = prove_info.encode_to_vec();
418 let asset = pb::api::Asset::from_bytes(
419 &request.receipt_out.ok_or(malformed_err())?,
420 prove_info_bytes.into(),
421 "prove_info.zkp",
422 )?;
423
424 Ok(pb::api::ServerReply {
425 kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
426 kind: Some(pb::api::client_callback::Kind::ProveDone(
427 pb::api::OnProveDone {
428 prove_info: Some(asset),
429 },
430 )),
431 })),
432 })
433 }
434
435 let msg = inner(&mut conn, request).unwrap_or_else(|err| pb::api::ServerReply {
436 kind: Some(pb::api::server_reply::Kind::Error(pb::api::GenericError {
437 reason: err.to_string(),
438 })),
439 });
440
441 tracing::trace!("tx: {msg:?}");
442 conn.send(msg)
443 }
444
445 fn on_prove_segment(
446 &self,
447 mut conn: ConnectionWrapper,
448 request: pb::api::ProveSegmentRequest,
449 ) -> Result<()> {
450 fn inner(request: pb::api::ProveSegmentRequest) -> Result<pb::api::ProveSegmentReply> {
451 let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
452 let segment_bytes = request.segment.ok_or(malformed_err())?.as_bytes()?;
453 let segment: Segment = bincode::deserialize(&segment_bytes)?;
454
455 let prover = get_prover_server(&opts)?;
456 let ctx = VerifierContext::default();
457 let receipt = prover.prove_segment(&ctx, &segment)?;
458
459 let receipt_pb: pb::core::SegmentReceipt = receipt.into();
460 let receipt_bytes = receipt_pb.encode_to_vec();
461 let asset = pb::api::Asset::from_bytes(
462 &request.receipt_out.ok_or(malformed_err())?,
463 receipt_bytes.into(),
464 "receipt.zkp",
465 )?;
466
467 Ok(pb::api::ProveSegmentReply {
468 kind: Some(pb::api::prove_segment_reply::Kind::Ok(
469 pb::api::ProveSegmentResult {
470 receipt: Some(asset),
471 },
472 )),
473 })
474 }
475
476 let msg = inner(request).unwrap_or_else(|err| pb::api::ProveSegmentReply {
477 kind: Some(pb::api::prove_segment_reply::Kind::Error(
478 pb::api::GenericError {
479 reason: err.to_string(),
480 },
481 )),
482 });
483
484 tracing::trace!("tx: {msg:?}");
485 conn.send(msg)
486 }
487
488 fn on_prove_zkr(
489 &self,
490 mut conn: ConnectionWrapper,
491 request: pb::api::ProveZkrRequest,
492 ) -> Result<()> {
493 fn inner(request: pb::api::ProveZkrRequest) -> Result<pb::api::ProveZkrReply> {
494 let control_id = request.control_id.ok_or(malformed_err())?.try_into()?;
495 let receipt = prove_registered_zkr(&control_id, vec![control_id], &request.input)?;
496
497 let receipt_pb: pb::core::SuccinctReceipt = receipt.into();
498 let receipt_bytes = receipt_pb.encode_to_vec();
499 let asset = pb::api::Asset::from_bytes(
500 &request.receipt_out.ok_or(malformed_err())?,
501 receipt_bytes.into(),
502 "receipt.zkp",
503 )?;
504
505 Ok(pb::api::ProveZkrReply {
506 kind: Some(pb::api::prove_zkr_reply::Kind::Ok(
507 pb::api::ProveZkrResult {
508 receipt: Some(asset),
509 },
510 )),
511 })
512 }
513
514 let msg = inner(request).unwrap_or_else(|err| pb::api::ProveZkrReply {
515 kind: Some(pb::api::prove_zkr_reply::Kind::Error(
516 pb::api::GenericError {
517 reason: err.to_string(),
518 },
519 )),
520 });
521
522 tracing::trace!("tx: {msg:?}");
523 conn.send(msg)
524 }
525
526 fn on_prove_keccak(
527 &self,
528 mut conn: ConnectionWrapper,
529 request: pb::api::ProveKeccakRequest,
530 ) -> Result<()> {
531 fn inner(request_pb: pb::api::ProveKeccakRequest) -> Result<pb::api::ProveKeccakReply> {
532 let request: ProveKeccakRequest = request_pb.clone().try_into()?;
533 let receipt = prove_keccak(&request)?;
534
535 let receipt_pb: pb::core::SuccinctReceipt = receipt.into();
536 let receipt_bytes = receipt_pb.encode_to_vec();
537 let asset = pb::api::Asset::from_bytes(
538 &request_pb.receipt_out.ok_or(malformed_err())?,
539 receipt_bytes.into(),
540 "receipt.zkp",
541 )?;
542
543 Ok(pb::api::ProveKeccakReply {
544 kind: Some(pb::api::prove_keccak_reply::Kind::Ok(
545 pb::api::ProveKeccakResult {
546 receipt: Some(asset),
547 },
548 )),
549 })
550 }
551
552 let msg = inner(request).unwrap_or_else(|err| pb::api::ProveKeccakReply {
553 kind: Some(pb::api::prove_keccak_reply::Kind::Error(
554 pb::api::GenericError {
555 reason: err.to_string(),
556 },
557 )),
558 });
559
560 tracing::trace!("tx: {msg:?}");
561 conn.send(msg)
562 }
563
564 fn on_lift(&self, mut conn: ConnectionWrapper, request: pb::api::LiftRequest) -> Result<()> {
565 fn inner(request: pb::api::LiftRequest) -> Result<pb::api::LiftReply> {
566 let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
567 let receipt_bytes = request.receipt.ok_or(malformed_err())?.as_bytes()?;
568 let segment_receipt: SegmentReceipt = bincode::deserialize(&receipt_bytes)?;
569
570 let prover = get_prover_server(&opts)?;
571 let receipt = prover.lift(&segment_receipt)?;
572
573 let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.into();
574 let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
575 let asset = pb::api::Asset::from_bytes(
576 &request.receipt_out.ok_or(malformed_err())?,
577 succinct_receipt_bytes.into(),
578 "receipt.zkp",
579 )?;
580
581 Ok(pb::api::LiftReply {
582 kind: Some(pb::api::lift_reply::Kind::Ok(pb::api::LiftResult {
583 receipt: Some(asset),
584 })),
585 })
586 }
587
588 let msg = inner(request).unwrap_or_else(|err| pb::api::LiftReply {
589 kind: Some(pb::api::lift_reply::Kind::Error(pb::api::GenericError {
590 reason: err.to_string(),
591 })),
592 });
593
594 conn.send(msg)
596 }
597
598 fn on_join(&self, mut conn: ConnectionWrapper, request: pb::api::JoinRequest) -> Result<()> {
599 fn inner(request: pb::api::JoinRequest) -> Result<pb::api::JoinReply> {
600 let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
601 let left_receipt_bytes = request.left_receipt.ok_or(malformed_err())?.as_bytes()?;
602 let left_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
603 bincode::deserialize(&left_receipt_bytes)?;
604 let right_receipt_bytes = request.right_receipt.ok_or(malformed_err())?.as_bytes()?;
605 let right_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
606 bincode::deserialize(&right_receipt_bytes)?;
607
608 let prover = get_prover_server(&opts)?;
609 let receipt = prover.join(&left_succinct_receipt, &right_succinct_receipt)?;
610
611 let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.into();
612 let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
613 let asset = pb::api::Asset::from_bytes(
614 &request.receipt_out.ok_or(malformed_err())?,
615 succinct_receipt_bytes.into(),
616 "receipt.zkp",
617 )?;
618
619 Ok(pb::api::JoinReply {
620 kind: Some(pb::api::join_reply::Kind::Ok(pb::api::JoinResult {
621 receipt: Some(asset),
622 })),
623 })
624 }
625
626 let msg = inner(request).unwrap_or_else(|err| pb::api::JoinReply {
627 kind: Some(pb::api::join_reply::Kind::Error(pb::api::GenericError {
628 reason: err.to_string(),
629 })),
630 });
631
632 conn.send(msg)
634 }
635
636 fn on_resolve(
637 &self,
638 mut conn: ConnectionWrapper,
639 request: pb::api::ResolveRequest,
640 ) -> Result<()> {
641 fn inner(request: pb::api::ResolveRequest) -> Result<pb::api::ResolveReply> {
642 let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
643 let conditional_receipt_bytes = request
644 .conditional_receipt
645 .ok_or(malformed_err())?
646 .as_bytes()?;
647 let conditional_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
648 bincode::deserialize(&conditional_receipt_bytes)?;
649 let assumption_receipt_bytes = request
650 .assumption_receipt
651 .ok_or(malformed_err())?
652 .as_bytes()?;
653 let assumption_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
654 bincode::deserialize(&assumption_receipt_bytes)?;
655
656 let prover = get_prover_server(&opts)?;
657 let receipt = prover.resolve(
658 &conditional_succinct_receipt,
659 &assumption_succinct_receipt.into_unknown(),
660 )?;
661
662 let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.into();
663 let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
664 let asset = pb::api::Asset::from_bytes(
665 &request.receipt_out.ok_or(malformed_err())?,
666 succinct_receipt_bytes.into(),
667 "receipt.zkp",
668 )?;
669
670 Ok(pb::api::ResolveReply {
671 kind: Some(pb::api::resolve_reply::Kind::Ok(pb::api::ResolveResult {
672 receipt: Some(asset),
673 })),
674 })
675 }
676
677 let msg = inner(request).unwrap_or_else(|err| pb::api::ResolveReply {
678 kind: Some(pb::api::resolve_reply::Kind::Error(pb::api::GenericError {
679 reason: err.to_string(),
680 })),
681 });
682
683 conn.send(msg)
685 }
686
687 fn on_identity_p254(
688 &self,
689 mut conn: ConnectionWrapper,
690 request: pb::api::IdentityP254Request,
691 ) -> Result<()> {
692 fn inner(request: pb::api::IdentityP254Request) -> Result<pb::api::IdentityP254Reply> {
693 let receipt_bytes = request.receipt.ok_or(malformed_err())?.as_bytes()?;
694 let succinct_receipt: SuccinctReceipt<ReceiptClaim> =
695 bincode::deserialize(&receipt_bytes)?;
696
697 let receipt = identity_p254(&succinct_receipt)?;
698 let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.into();
699 let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
700 let asset = pb::api::Asset::from_bytes(
701 &request.receipt_out.ok_or(malformed_err())?,
702 succinct_receipt_bytes.into(),
703 "receipt.zkp",
704 )?;
705
706 Ok(pb::api::IdentityP254Reply {
707 kind: Some(pb::api::identity_p254_reply::Kind::Ok(
708 pb::api::IdentityP254Result {
709 receipt: Some(asset),
710 },
711 )),
712 })
713 }
714
715 let msg = inner(request).unwrap_or_else(|err| pb::api::IdentityP254Reply {
716 kind: Some(pb::api::identity_p254_reply::Kind::Error(
717 pb::api::GenericError {
718 reason: err.to_string(),
719 },
720 )),
721 });
722
723 conn.send(msg)
725 }
726
727 fn on_compress(
728 &self,
729 mut conn: ConnectionWrapper,
730 request: pb::api::CompressRequest,
731 ) -> Result<()> {
732 fn inner(request: pb::api::CompressRequest) -> Result<pb::api::CompressReply> {
733 let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
734 let receipt_bytes = request.receipt.ok_or(malformed_err())?.as_bytes()?;
735 let receipt: Receipt = bincode::deserialize(&receipt_bytes)?;
736
737 let prover = get_prover_server(&opts)?;
738 let receipt = prover.compress(&opts, &receipt)?;
739
740 let receipt_pb: pb::core::Receipt = receipt.into();
741 let receipt_bytes = receipt_pb.encode_to_vec();
742 let asset = pb::api::Asset::from_bytes(
743 &request.receipt_out.ok_or(malformed_err())?,
744 receipt_bytes.into(),
745 "receipt.zkp",
746 )?;
747
748 Ok(pb::api::CompressReply {
749 kind: Some(pb::api::compress_reply::Kind::Ok(pb::api::CompressResult {
750 receipt: Some(asset),
751 })),
752 })
753 }
754
755 let msg = inner(request).unwrap_or_else(|err| pb::api::CompressReply {
756 kind: Some(pb::api::compress_reply::Kind::Error(
757 pb::api::GenericError {
758 reason: err.to_string(),
759 },
760 )),
761 });
762
763 conn.send(msg)
765 }
766
767 fn on_verify(
768 &self,
769 mut conn: ConnectionWrapper,
770 request: pb::api::VerifyRequest,
771 ) -> Result<()> {
772 fn inner(request: pb::api::VerifyRequest) -> Result<()> {
773 let receipt_bytes = request.receipt.ok_or(malformed_err())?.as_bytes()?;
774 let receipt: Receipt =
775 bincode::deserialize(&receipt_bytes).context("deserialize receipt")?;
776 let image_id: Digest = request.image_id.ok_or(malformed_err())?.try_into()?;
777 receipt
778 .verify(image_id)
779 .map_err(|err| anyhow!("verify failed: {err}"))
780 }
781
782 let msg: pb::api::GenericReply = inner(request).into();
783 conn.send(msg)
785 }
786}
787
788fn build_env<'a>(
789 conn: &ConnectionWrapper,
790 request: &pb::api::ExecutorEnv,
791) -> Result<ExecutorEnv<'a>> {
792 let mut env_builder = ExecutorEnv::builder();
793 env_builder.env_vars(request.env_vars.clone());
794 env_builder.args(&request.args);
795 for fd in request.read_fds.iter() {
796 let proxy = PosixIoProxy::new(*fd, conn.clone());
797 let reader = BufReader::new(proxy);
798 env_builder.read_fd(*fd, reader);
799 }
800 for fd in request.write_fds.iter() {
801 let proxy = PosixIoProxy::new(*fd, conn.clone());
802 env_builder.write_fd(*fd, proxy);
803 }
804 let proxy = SliceIoProxy::new(conn.clone());
805 for name in request.slice_ios.iter() {
806 env_builder.slice_io(name, proxy.clone());
807 }
808 if let Some(segment_limit_po2) = request.segment_limit_po2 {
809 env_builder.segment_limit_po2(segment_limit_po2);
810 }
811 env_builder.session_limit(request.session_limit);
812 if request.trace_events.is_some() {
813 let proxy = TraceProxy::new(conn.clone());
814 env_builder.trace_callback(proxy);
815 }
816 if !request.pprof_out.is_empty() {
817 env_builder.enable_profiler(Path::new(&request.pprof_out));
818 }
819 if !request.segment_path.is_empty() {
820 env_builder.segment_path(Path::new(&request.segment_path));
821 }
822 if request.coprocessor {
823 let proxy = CoprocessorProxy::new(conn.clone());
824 env_builder.coprocessor_callback(proxy);
825 }
826
827 for assumption in request.assumptions.iter() {
828 match assumption.kind.as_ref().ok_or(malformed_err())? {
829 pb::api::assumption_receipt::Kind::Proven(asset) => {
830 let receipt: InnerAssumptionReceipt =
831 pb::core::InnerReceipt::decode(asset.as_bytes()?)?.try_into()?;
832 env_builder.add_assumption(receipt)
833 }
834 pb::api::assumption_receipt::Kind::Unresolved(asset) => {
835 let assumption: Assumption =
836 pb::core::Assumption::decode(asset.as_bytes()?)?.try_into()?;
837 env_builder.add_assumption(assumption)
838 }
839 };
840 }
841 env_builder.build()
842}
843
/// Extension for converting any boxable error into a [`std::io::Error`].
trait IoOtherError<T> {
    /// Maps the error variant to an I/O error of kind [`IoErrorKind::Other`] that wraps
    /// the original error; a success value passes through unchanged.
    fn map_io_err(self) -> Result<T, IoError>;
}

impl<T, E> IoOtherError<T> for Result<T, E>
where
    E: Into<Box<dyn StdError + Send + Sync>>,
{
    fn map_io_err(self) -> Result<T, IoError> {
        match self {
            Ok(value) => Ok(value),
            Err(cause) => Err(IoError::new(IoErrorKind::Other, cause)),
        }
    }
}
853
854impl From<pb::api::GenericError> for IoError {
855 fn from(err: pb::api::GenericError) -> Self {
856 IoError::new(IoErrorKind::Other, err.reason)
857 }
858}
859
860impl pb::api::Asset {
861 pub fn from_bytes<P: AsRef<Path>>(
862 request: &pb::api::AssetRequest,
863 bytes: Bytes,
864 path: P,
865 ) -> Result<Self> {
866 match request.kind.as_ref().ok_or(malformed_err())? {
867 pb::api::asset_request::Kind::Inline(()) => Ok(Self {
868 kind: Some(pb::api::asset::Kind::Inline(bytes.into())),
869 }),
870 pb::api::asset_request::Kind::Path(base_path) => {
871 let base_path = PathBuf::from(base_path);
872 let path = base_path.join(path);
873 std::fs::write(&path, bytes)?;
874 Ok(Self {
875 kind: Some(pb::api::asset::Kind::Path(path_to_string(path)?)),
876 })
877 }
878 pb::api::asset_request::Kind::Redis(_) => {
879 tracing::error!("It's likely that r0vm is not installed with the redis feature");
880 bail!("from_bytes not supported for redis")
881 }
882 }
883 }
884}
885
886#[allow(dead_code)]
887fn check_client_version(client: &semver::Version, server: &semver::Version) -> bool {
888 if server.pre.is_empty() {
889 let comparator = semver::Comparator {
890 op: semver::Op::GreaterEq,
891 major: server.major,
892 minor: Some(server.minor),
893 patch: None,
894 pre: semver::Prerelease::EMPTY,
895 };
896 comparator.matches(client)
897 } else {
898 client == server
899 }
900}
901
902#[allow(dead_code)]
903fn check_client_version_compat(client: &semver::Version, server: &semver::Version) -> bool {
904 client.major == server.major
905}
906
/// Runs the executor while a background thread stores each segment in Redis.
///
/// Segments are handed to the storage thread over a bounded channel. Its capacity is
/// read from `RISC0_REDIS_CHANNEL_SIZE` and defaults to 100 when the variable is unset
/// or unparsable. For each segment, the thread:
///
/// * stores the bincode-serialized segment under `"{params.key}:{index}"` with
///   `SetExpiry::EX(params.ttl)`, retrying once on a fresh connection;
/// * notifies the client with a `SegmentDone` message whose asset names that key.
///
/// # Errors
///
/// Fails if execution fails or the storage thread panics. It also fails if storing or
/// reporting any segment fails, including failures that occur after the executor has
/// already handed off its final segments.
#[cfg(feature = "redis")]
fn execute_redis(
    conn: &mut ConnectionWrapper,
    exec: &mut ExecutorImpl,
    params: super::RedisParams,
) -> Result<Session> {
    use redis::{Client, Commands, ConnectionLike, SetExpiry, SetOptions};
    use std::{
        sync::{
            mpsc::{sync_channel, Receiver},
            Arc, Mutex,
        },
        thread::{spawn, JoinHandle},
    };

    let channel_size = match std::env::var("RISC0_REDIS_CHANNEL_SIZE") {
        Ok(val_str) => val_str.parse::<usize>().unwrap_or(100),
        Err(_) => 100,
    };
    let (sender, receiver) = sync_channel::<(String, Segment)>(channel_size);
    let opts = SetOptions::default().with_expiration(SetExpiry::EX(params.ttl));

    // First error hit by the storage thread, handed back to this thread.
    let redis_err = Arc::new(Mutex::new(None));
    let redis_err_clone = redis_err.clone();

    let conn = conn.clone();
    let join_handle: JoinHandle<()> = spawn(move || {
        fn inner(
            redis_url: String,
            receiver: &Receiver<(String, Segment)>,
            opts: SetOptions,
            mut conn: ConnectionWrapper,
        ) -> Result<()> {
            let client = Client::open(redis_url).context("Failed to open Redis connection")?;
            let mut connection = client
                .get_connection()
                .context("Failed to get redis connection")?;
            while let Ok((segment_key, segment)) = receiver.recv() {
                if !connection.is_open() {
                    connection = client
                        .get_connection()
                        .context("Failed to get redis connection")?;
                }
                let segment_bytes =
                    bincode::serialize(&segment).context("Failed to serialize segment")?;
                // Pass key and payload by reference so the retry path does not require
                // copying the whole serialized segment up front.
                match connection.set_options(segment_key.as_str(), segment_bytes.as_slice(), opts)
                {
                    Ok(()) => (),
                    Err(err) => {
                        tracing::warn!(
                            "Failed to set redis key with TTL, trying again. Error: {err}"
                        );
                        connection = client
                            .get_connection()
                            .context("Failed to get redis connection")?;
                        let _: () = connection
                            .set_options(segment_key.as_str(), segment_bytes.as_slice(), opts)
                            .context("Failed to set redis key with TTL again")?;
                    }
                };
                let asset = pb::api::Asset {
                    kind: Some(pb::api::asset::Kind::Redis(segment_key)),
                };
                send_segment_done_msg(&mut conn, segment, Some(asset))
                    .context("Failed to send segment_done msg")?;
            }
            Ok(())
        }

        if let Err(err) = inner(params.url, &receiver, opts, conn) {
            *redis_err_clone.lock().unwrap() = Some(err);
        }
    });

    let session = exec.run_with_callback(|segment| {
        let segment_key = format!("{}:{}", params.key, segment.index);
        if let Err(send_err) = sender.send((segment_key, segment)) {
            // The receiver is gone, which means the storage thread exited; prefer its
            // error over the generic send failure.
            let mut redis_err_opt = redis_err.lock().unwrap();
            let redis_err_inner = redis_err_opt.take();
            return Err(match redis_err_inner {
                Some(redis_thread_err) => {
                    tracing::error!(
                        "Redis err: {redis_thread_err} root: {:?}",
                        redis_thread_err.root_cause()
                    );
                    anyhow!(redis_thread_err)
                }
                None => send_err.into(),
            });
        }
        Ok(Box::new(NullSegmentRef))
    });

    // Closing the channel lets the storage thread drain the remaining segments and exit.
    drop(sender);

    join_handle
        .join()
        .map_err(|err| anyhow!("redis task join failed: {err:?}"))?;

    let session = session?;

    // The storage thread can fail after the executor has already handed off its final
    // segments (e.g. while storing buffered ones). The callback above never sees that
    // error, so check for it here instead of silently reporting success.
    if let Some(err) = redis_err.lock().unwrap().take() {
        return Err(err);
    }

    Ok(session)
}
1007
1008fn execute_default(
1009 conn: &mut ConnectionWrapper,
1010 exec: &mut ExecutorImpl,
1011 segments_out: &pb::api::AssetRequest,
1012) -> Result<Session> {
1013 exec.run_with_callback(|segment| {
1014 let segment_bytes = bincode::serialize(&segment)?;
1015 let asset = pb::api::Asset::from_bytes(
1016 segments_out,
1017 segment_bytes.into(),
1018 format!("segment-{}", segment.index),
1019 )?;
1020 send_segment_done_msg(conn, segment, Some(asset))?;
1021 Ok(Box::new(NullSegmentRef))
1022 })
1023}
1024
1025fn send_segment_done_msg(
1026 conn: &mut ConnectionWrapper,
1027 segment: Segment,
1028 some_asset: Option<pb::api::Asset>,
1029) -> Result<()> {
1030 let segment = Some(pb::api::SegmentInfo {
1031 index: segment.index,
1032 po2: segment.inner.po2 as u32,
1033 cycles: segment.inner.insn_cycles as u32,
1034 segment: some_asset,
1035 });
1036
1037 let msg = pb::api::ServerReply {
1038 kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
1039 kind: Some(pb::api::client_callback::Kind::SegmentDone(
1040 pb::api::OnSegmentDone { segment },
1041 )),
1042 })),
1043 };
1044
1045 tracing::trace!("tx: {msg:?}");
1046 let reply: pb::api::GenericReply = conn.send_recv(msg)?;
1047 tracing::trace!("rx: {reply:?}");
1048
1049 let kind = reply.kind.ok_or(malformed_err())?;
1050 if let pb::api::generic_reply::Kind::Error(err) = kind {
1051 bail!(err)
1052 }
1053 Ok(())
1054}
1055
#[cfg(test)]
mod tests {
    use semver::Version;

    use super::{check_client_version, check_client_version_compat};

    /// Signature shared by both version-compatibility predicates.
    type CheckFn = fn(&Version, &Version) -> bool;

    /// Parses both version strings and applies `check` to them.
    fn run_check(check: CheckFn, client: &str, server: &str) -> bool {
        let client = Version::parse(client).unwrap();
        let server = Version::parse(server).unwrap();
        check(&client, &server)
    }

    /// Asserts that `check` yields `expected` for every `(client, server)` pair.
    fn assert_cases(check: CheckFn, expected: bool, cases: &[(&str, &str)]) {
        for &(client, server) in cases {
            assert_eq!(
                run_check(check, client, server),
                expected,
                "client={client} server={server}"
            );
        }
    }

    #[test]
    fn check_version() {
        assert_cases(
            check_client_version,
            true,
            &[
                ("0.18.0", "0.18.0"),
                ("0.18.1", "0.18.0"),
                ("0.18.0", "0.18.1"),
                ("0.19.0", "0.18.0"),
                ("1.0.0", "0.18.0"),
                ("1.1.0", "1.0.0"),
                ("0.19.0-alpha.1", "0.19.0-alpha.1"),
            ],
        );
        assert_cases(
            check_client_version,
            false,
            &[
                ("0.19.0-alpha.1", "0.19.0-alpha.2"),
                ("0.18.0", "0.19.0"),
                ("0.18.0", "1.0.0"),
            ],
        );
    }

    #[test]
    fn check_version_compat() {
        assert_cases(
            check_client_version_compat,
            true,
            &[
                ("1.1.0", "1.1.0"),
                ("1.1.1", "1.1.1"),
                ("1.2.0", "1.1.1"),
                ("1.2.0-rc.1", "1.1.1"),
            ],
        );
        assert_cases(check_client_version_compat, false, &[("2.0.0", "1.1.1")]);
    }
}