risc0_zkvm/host/api/
server.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    error::Error as StdError,
17    io::{BufReader, Error as IoError, Read, Write},
18    path::{Path, PathBuf},
19};
20
21use anyhow::{anyhow, bail, Context, Result};
22use bytes::Bytes;
23use prost::Message;
24use risc0_binfmt::PovwJobId;
25use risc0_zkp::core::digest::Digest;
26
27use super::{malformed_err, path_to_string, pb, ConnectionWrapper, Connector, TcpConnector};
28use crate::{
29    get_prover_server, get_version,
30    host::{
31        api::convert::keccak_input_to_bytes,
32        client::{
33            env::{CoprocessorCallback, ProveKeccakRequest},
34            slice_io::SliceIo,
35        },
36        server::{
37            exec::executor::ExecutorImpl, prove::keccak::prove_keccak, session::NullSegmentRef,
38        },
39    },
40    recursion::identity_p254,
41    AssetRequest, Assumption, ExecutorEnv, InnerAssumptionReceipt, ProverOpts, Receipt,
42    ReceiptClaim, Segment, SegmentReceipt, SegmentRef, Session, SuccinctReceipt, TraceCallback,
43    TraceEvent, Unknown, VerifierContext,
44};
45
46/// A server implementation for handling requests by clients of the zkVM.
47pub struct Server {
48    connector: Box<dyn Connector>,
49}
50
51struct PosixIoProxy {
52    fd: u32,
53    conn: ConnectionWrapper,
54}
55
56impl PosixIoProxy {
57    fn new(fd: u32, conn: ConnectionWrapper) -> Self {
58        PosixIoProxy { fd, conn }
59    }
60}
61
62impl Read for PosixIoProxy {
63    fn read(&mut self, to_guest: &mut [u8]) -> std::io::Result<usize> {
64        let nread = to_guest.len().try_into().map_io_err()?;
65        let request = pb::api::ServerReply {
66            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
67                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
68                    kind: Some(pb::api::on_io_request::Kind::Posix(pb::api::PosixIo {
69                        fd: self.fd,
70                        cmd: Some(pb::api::PosixCmd {
71                            kind: Some(pb::api::posix_cmd::Kind::Read(nread)),
72                        }),
73                    })),
74                })),
75            })),
76        };
77
78        tracing::trace!("tx: {request:?}");
79        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
80        tracing::trace!("rx: {reply:?}");
81
82        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
83        match kind {
84            pb::api::on_io_reply::Kind::Ok(bytes) => {
85                let (head, _) = to_guest.split_at_mut(bytes.len());
86                head.copy_from_slice(&bytes);
87                Ok(bytes.len())
88            }
89            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
90        }
91    }
92}
93
94impl Write for PosixIoProxy {
95    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
96        let request = pb::api::ServerReply {
97            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
98                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
99                    kind: Some(pb::api::on_io_request::Kind::Posix(pb::api::PosixIo {
100                        fd: self.fd,
101                        cmd: Some(pb::api::PosixCmd {
102                            kind: Some(pb::api::posix_cmd::Kind::Write(buf.into())),
103                        }),
104                    })),
105                })),
106            })),
107        };
108
109        tracing::trace!("tx: {request:?}");
110        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
111        tracing::trace!("rx: {reply:?}");
112
113        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
114        match kind {
115            pb::api::on_io_reply::Kind::Ok(_) => Ok(buf.len()),
116            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
117        }
118    }
119
120    fn flush(&mut self) -> std::io::Result<()> {
121        Ok(())
122    }
123}
124
125#[derive(Clone)]
126struct SliceIoProxy {
127    conn: ConnectionWrapper,
128}
129
130impl SliceIoProxy {
131    fn new(conn: ConnectionWrapper) -> Self {
132        Self { conn }
133    }
134}
135
136impl SliceIo for SliceIoProxy {
137    fn handle_io(&mut self, syscall: &str, from_guest: Bytes) -> Result<Bytes> {
138        let request = pb::api::ServerReply {
139            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
140                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
141                    kind: Some(pb::api::on_io_request::Kind::Slice(pb::api::SliceIo {
142                        name: syscall.to_string(),
143                        from_guest: from_guest.into(),
144                    })),
145                })),
146            })),
147        };
148        tracing::trace!("tx: {request:?}");
149        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
150        tracing::trace!("rx: {reply:?}");
151
152        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
153        match kind {
154            pb::api::on_io_reply::Kind::Ok(buf) => Ok(buf.into()),
155            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
156        }
157    }
158}
159
160struct TraceProxy {
161    conn: ConnectionWrapper,
162}
163
164impl TraceProxy {
165    fn new(conn: ConnectionWrapper) -> Self {
166        Self { conn }
167    }
168}
169
170impl TraceCallback for TraceProxy {
171    fn trace_callback(&mut self, event: TraceEvent) -> Result<()> {
172        let Ok(event) = event.clone().try_into() else {
173            tracing::trace!("ignoring unknown event {event:?}");
174            return Ok(());
175        };
176
177        let request = pb::api::ServerReply {
178            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
179                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
180                    kind: Some(pb::api::on_io_request::Kind::Trace(event)),
181                })),
182            })),
183        };
184        tracing::trace!("tx: {request:?}");
185        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
186        tracing::trace!("rx: {reply:?}");
187
188        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
189        match kind {
190            pb::api::on_io_reply::Kind::Ok(_) => Ok(()),
191            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
192        }
193    }
194}
195
196struct CoprocessorProxy {
197    conn: ConnectionWrapper,
198}
199
200impl CoprocessorProxy {
201    fn new(conn: ConnectionWrapper) -> Self {
202        Self { conn }
203    }
204}
205
206impl CoprocessorCallback for CoprocessorProxy {
207    fn prove_keccak(&mut self, proof_request: ProveKeccakRequest) -> Result<()> {
208        let input = keccak_input_to_bytes(&proof_request.input);
209        let request = pb::api::ServerReply {
210            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
211                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
212                    kind: Some(pb::api::on_io_request::Kind::Coprocessor(
213                        pb::api::CoprocessorRequest {
214                            kind: Some(pb::api::coprocessor_request::Kind::ProveKeccak({
215                                pb::api::ProveKeccakRequest {
216                                    claim_digest: Some(proof_request.claim_digest.into()),
217                                    po2: proof_request.po2 as u32,
218                                    control_root: Some(proof_request.control_root.into()),
219                                    input,
220                                    receipt_out: None,
221                                }
222                            })),
223                        },
224                    )),
225                })),
226            })),
227        };
228        tracing::trace!("tx: {request:?}");
229        self.conn.send(request)?;
230
231        let reply: pb::api::OnIoReply = self.conn.recv().map_io_err()?;
232        tracing::trace!("rx: {reply:?}");
233
234        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
235        match kind {
236            pb::api::on_io_reply::Kind::Ok(_) => Ok(()),
237            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
238        }
239    }
240}
241
242impl Server {
243    /// Construct a new [Server] with the specified [Connector].
244    pub fn new(connector: Box<dyn Connector>) -> Self {
245        Self { connector }
246    }
247
248    /// Construct a new [Server] which will connect to the specified TCP/IP
249    /// address.
250    pub fn new_tcp<A: AsRef<str>>(addr: A) -> Self {
251        let connector = TcpConnector::new(addr.as_ref());
252        Self::new(Box::new(connector))
253    }
254
255    /// Start the [Server] and run until all requests are complete.
256    pub fn run(&self) -> Result<()> {
257        tracing::debug!("connect");
258        let mut conn = self.connector.connect()?;
259
260        let server_version = get_version().map_err(|err| anyhow!(err))?;
261
262        let request: pb::api::HelloRequest = conn.recv()?;
263        tracing::trace!("rx: {request:?}");
264
265        let client_version: semver::Version = request
266            .version
267            .ok_or_else(|| malformed_err("HelloRequest.version"))?
268            .try_into()
269            .map_err(|err: semver::Error| anyhow!(err))?;
270
271        #[cfg(not(feature = "r0vm-ver-compat"))]
272        let check_client_func = check_client_version;
273        #[cfg(feature = "r0vm-ver-compat")]
274        let check_client_func = check_client_version_compat;
275
276        if !check_client_func(&client_version, &server_version) {
277            let msg = format!(
278                "incompatible client version: {client_version}, server version: {server_version}"
279            );
280            tracing::debug!("{msg}");
281            bail!(msg);
282        }
283
284        let reply = pb::api::HelloReply {
285            kind: Some(pb::api::hello_reply::Kind::Ok(pb::api::HelloResult {
286                version: Some(server_version.into()),
287            })),
288        };
289        tracing::trace!("tx: {reply:?}");
290        let request: pb::api::ServerRequest = conn.send_recv(reply)?;
291        tracing::trace!("rx: {request:?}");
292
293        match request
294            .kind
295            .ok_or_else(|| malformed_err("ServerRequest.kind"))?
296        {
297            pb::api::server_request::Kind::Prove(request) => self.on_prove(conn, request),
298            pb::api::server_request::Kind::Execute(request) => self.on_execute(conn, request),
299            pb::api::server_request::Kind::ProveSegment(request) => {
300                self.on_prove_segment(conn, request)
301            }
302            pb::api::server_request::Kind::Lift(request) => self.on_lift(conn, request),
303            pb::api::server_request::Kind::Join(request) => self.on_join(conn, request),
304            pb::api::server_request::Kind::Resolve(request) => self.on_resolve(conn, request),
305            pb::api::server_request::Kind::IdentityP254(request) => {
306                self.on_identity_p254(conn, request)
307            }
308            pb::api::server_request::Kind::Compress(request) => self.on_compress(conn, request),
309            pb::api::server_request::Kind::Verify(request) => self.on_verify(conn, request),
310            pb::api::server_request::Kind::ProveKeccak(request) => {
311                self.on_prove_keccak(conn, request)
312            }
313            pb::api::server_request::Kind::Union(request) => self.on_union(conn, request),
314        }
315    }
316
317    fn on_execute(
318        &self,
319        mut conn: ConnectionWrapper,
320        request: pb::api::ExecuteRequest,
321    ) -> Result<()> {
322        fn inner(
323            conn: &mut ConnectionWrapper,
324            request: pb::api::ExecuteRequest,
325        ) -> Result<pb::api::ServerReply> {
326            let env_request = request
327                .env
328                .ok_or_else(|| malformed_err("ExecuteRequest.env"))?;
329            let env = build_env(conn, &env_request)?;
330
331            let binary = env_request
332                .binary
333                .ok_or_else(|| malformed_err("ExecuteRequest.binary"))?;
334
335            let segments_out = request
336                .segments_out
337                .ok_or_else(|| malformed_err("ExecuteRequest.segments_out"))?;
338            let bytes = binary.as_bytes()?;
339
340            let session = match AssetRequest::try_from(segments_out.clone())? {
341                #[cfg(feature = "redis")]
342                AssetRequest::Redis(params) => execute_redis(conn, env, bytes, params)?,
343                _ => execute_default(conn, env, bytes, &segments_out)?,
344            };
345
346            let receipt_claim = session.claim()?;
347            Ok(pb::api::ServerReply {
348                kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
349                    kind: Some(pb::api::client_callback::Kind::SessionDone(
350                        pb::api::OnSessionDone {
351                            session: Some(pb::api::SessionInfo {
352                                segments: session.segments.len().try_into()?,
353                                journal: session.journal.unwrap_or_default().bytes,
354                                exit_code: Some(session.exit_code.try_into()?),
355                                receipt_claim: Some(pb::api::Asset::from_bytes(
356                                    &pb::api::AssetRequest {
357                                        kind: Some(pb::api::asset_request::Kind::Inline(())),
358                                    },
359                                    pb::core::ReceiptClaim::try_from(receipt_claim)?
360                                        .encode_to_vec()
361                                        .into(),
362                                    "session_info.claim",
363                                )?),
364                            }),
365                        },
366                    )),
367                })),
368            })
369        }
370
371        let msg = inner(&mut conn, request).unwrap_or_else(|err| pb::api::ServerReply {
372            kind: Some(pb::api::server_reply::Kind::Error(pb::api::GenericError {
373                reason: err.to_string(),
374            })),
375        });
376
377        tracing::trace!("tx: {msg:?}");
378        conn.send(msg)
379    }
380
381    fn on_prove(&self, mut conn: ConnectionWrapper, request: pb::api::ProveRequest) -> Result<()> {
382        fn inner(
383            conn: &mut ConnectionWrapper,
384            request: pb::api::ProveRequest,
385        ) -> Result<pb::api::ServerReply> {
386            let env_request = request
387                .env
388                .ok_or_else(|| malformed_err("ProveRequest.env"))?;
389            let env = build_env(conn, &env_request)?;
390
391            let binary = env_request
392                .binary
393                .ok_or_else(|| malformed_err("ProveRequest.env_request.binary"))?;
394            let bytes = binary.as_bytes()?;
395
396            let opts: ProverOpts = request
397                .opts
398                .ok_or_else(|| malformed_err("ProveRequest.opts"))?
399                .try_into()?;
400            let prover = get_prover_server(&opts)?;
401            let ctx = VerifierContext::default().with_dev_mode(opts.dev_mode());
402            let prove_info = prover.prove_with_ctx(env, &ctx, &bytes)?;
403
404            let prove_info: pb::core::ProveInfo = prove_info.try_into()?;
405            let prove_info_bytes = prove_info.encode_to_vec();
406            let asset = pb::api::Asset::from_bytes(
407                &request
408                    .receipt_out
409                    .ok_or_else(|| malformed_err("ProveRequest.receipt_out"))?,
410                prove_info_bytes.into(),
411                "prove_info.zkp",
412            )?;
413
414            Ok(pb::api::ServerReply {
415                kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
416                    kind: Some(pb::api::client_callback::Kind::ProveDone(
417                        pb::api::OnProveDone {
418                            prove_info: Some(asset),
419                        },
420                    )),
421                })),
422            })
423        }
424
425        let msg = inner(&mut conn, request).unwrap_or_else(|err| pb::api::ServerReply {
426            kind: Some(pb::api::server_reply::Kind::Error(pb::api::GenericError {
427                reason: err.to_string(),
428            })),
429        });
430
431        tracing::trace!("tx: {msg:?}");
432        conn.send(msg)
433    }
434
435    fn on_prove_segment(
436        &self,
437        mut conn: ConnectionWrapper,
438        request: pb::api::ProveSegmentRequest,
439    ) -> Result<()> {
440        fn inner(request: pb::api::ProveSegmentRequest) -> Result<pb::api::ProveSegmentReply> {
441            let opts: ProverOpts = request
442                .opts
443                .ok_or_else(|| malformed_err("ProveSegmentRequest.opts"))?
444                .try_into()?;
445            let segment_bytes = request
446                .segment
447                .ok_or_else(|| malformed_err("ProveSegmentRequest.segment"))?
448                .as_bytes()?;
449            let segment: Segment = bincode::deserialize(&segment_bytes)?;
450
451            let prover = get_prover_server(&opts)?;
452            let ctx = VerifierContext::default().with_dev_mode(opts.dev_mode());
453            let receipt = prover.prove_segment(&ctx, &segment)?;
454
455            let receipt_pb: pb::core::SegmentReceipt = receipt.try_into()?;
456            let receipt_bytes = receipt_pb.encode_to_vec();
457            let asset = pb::api::Asset::from_bytes(
458                &request
459                    .receipt_out
460                    .ok_or_else(|| malformed_err("ProveSegmentRequest.receipt_out"))?,
461                receipt_bytes.into(),
462                "receipt.zkp",
463            )?;
464
465            Ok(pb::api::ProveSegmentReply {
466                kind: Some(pb::api::prove_segment_reply::Kind::Ok(
467                    pb::api::ProveSegmentResult {
468                        receipt: Some(asset),
469                    },
470                )),
471            })
472        }
473
474        let msg = inner(request).unwrap_or_else(|err| pb::api::ProveSegmentReply {
475            kind: Some(pb::api::prove_segment_reply::Kind::Error(
476                pb::api::GenericError {
477                    reason: err.to_string(),
478                },
479            )),
480        });
481
482        tracing::trace!("tx: {msg:?}");
483        conn.send(msg)
484    }
485
486    fn on_prove_keccak(
487        &self,
488        mut conn: ConnectionWrapper,
489        request: pb::api::ProveKeccakRequest,
490    ) -> Result<()> {
491        fn inner(request_pb: pb::api::ProveKeccakRequest) -> Result<pb::api::ProveKeccakReply> {
492            let request: ProveKeccakRequest = request_pb.clone().try_into()?;
493            let receipt = prove_keccak(&request)?;
494
495            let receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
496            let receipt_bytes = receipt_pb.encode_to_vec();
497            let asset = pb::api::Asset::from_bytes(
498                &request_pb
499                    .receipt_out
500                    .ok_or_else(|| malformed_err("ProveKeccakRequest.receipt_out"))?,
501                receipt_bytes.into(),
502                "receipt.zkp",
503            )?;
504
505            Ok(pb::api::ProveKeccakReply {
506                kind: Some(pb::api::prove_keccak_reply::Kind::Ok(
507                    pb::api::ProveKeccakResult {
508                        receipt: Some(asset),
509                    },
510                )),
511            })
512        }
513
514        let msg = inner(request).unwrap_or_else(|err| pb::api::ProveKeccakReply {
515            kind: Some(pb::api::prove_keccak_reply::Kind::Error(
516                pb::api::GenericError {
517                    reason: err.to_string(),
518                },
519            )),
520        });
521
522        tracing::trace!("tx: {msg:?}");
523        conn.send(msg)
524    }
525
526    fn on_lift(&self, mut conn: ConnectionWrapper, request: pb::api::LiftRequest) -> Result<()> {
527        fn inner(request: pb::api::LiftRequest) -> Result<pb::api::LiftReply> {
528            let opts: ProverOpts = request
529                .opts
530                .ok_or_else(|| malformed_err("LiftRequest.opts"))?
531                .try_into()?;
532            let receipt_bytes = request
533                .receipt
534                .ok_or_else(|| malformed_err("LiftRequest.receipt"))?
535                .as_bytes()?;
536            let segment_receipt: SegmentReceipt = bincode::deserialize(&receipt_bytes)?;
537
538            let prover = get_prover_server(&opts)?;
539            let receipt = prover.lift(&segment_receipt)?;
540
541            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
542            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
543            let asset = pb::api::Asset::from_bytes(
544                &request
545                    .receipt_out
546                    .ok_or_else(|| malformed_err("LiftRequest.receipt_out"))?,
547                succinct_receipt_bytes.into(),
548                "receipt.zkp",
549            )?;
550
551            Ok(pb::api::LiftReply {
552                kind: Some(pb::api::lift_reply::Kind::Ok(pb::api::LiftResult {
553                    receipt: Some(asset),
554                })),
555            })
556        }
557
558        let msg = inner(request).unwrap_or_else(|err| pb::api::LiftReply {
559            kind: Some(pb::api::lift_reply::Kind::Error(pb::api::GenericError {
560                reason: err.to_string(),
561            })),
562        });
563
564        // tracing::trace!("tx: {msg:?}");
565        conn.send(msg)
566    }
567
568    fn on_join(&self, mut conn: ConnectionWrapper, request: pb::api::JoinRequest) -> Result<()> {
569        fn inner(request: pb::api::JoinRequest) -> Result<pb::api::JoinReply> {
570            let opts: ProverOpts = request
571                .opts
572                .ok_or_else(|| malformed_err("JoinRequest.opts"))?
573                .try_into()?;
574            let left_receipt_bytes = request
575                .left_receipt
576                .ok_or_else(|| malformed_err("JoinRequest.left_receipt"))?
577                .as_bytes()?;
578            let left_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
579                bincode::deserialize(&left_receipt_bytes)?;
580            let right_receipt_bytes = request
581                .right_receipt
582                .ok_or_else(|| malformed_err("JoinRequest.right_receipt"))?
583                .as_bytes()?;
584            let right_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
585                bincode::deserialize(&right_receipt_bytes)?;
586
587            let prover = get_prover_server(&opts)?;
588            let receipt = prover.join(&left_succinct_receipt, &right_succinct_receipt)?;
589
590            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
591            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
592            let asset = pb::api::Asset::from_bytes(
593                &request
594                    .receipt_out
595                    .ok_or_else(|| malformed_err("JoinRequest.receipt_out"))?,
596                succinct_receipt_bytes.into(),
597                "receipt.zkp",
598            )?;
599
600            Ok(pb::api::JoinReply {
601                kind: Some(pb::api::join_reply::Kind::Ok(pb::api::JoinResult {
602                    receipt: Some(asset),
603                })),
604            })
605        }
606
607        let msg = inner(request).unwrap_or_else(|err| pb::api::JoinReply {
608            kind: Some(pb::api::join_reply::Kind::Error(pb::api::GenericError {
609                reason: err.to_string(),
610            })),
611        });
612
613        // tracing::trace!("tx: {msg:?}");
614        conn.send(msg)
615    }
616
617    fn on_union(&self, mut conn: ConnectionWrapper, request: pb::api::UnionRequest) -> Result<()> {
618        fn inner(request: pb::api::UnionRequest) -> Result<pb::api::UnionReply> {
619            let opts: ProverOpts = request
620                .opts
621                .ok_or_else(|| malformed_err("UnionRequest.opts"))?
622                .try_into()?;
623            let left_receipt_bytes = request
624                .left_receipt
625                .ok_or_else(|| malformed_err("UnionRequest.left_receipt"))?
626                .as_bytes()?;
627            let left_succinct_receipt: SuccinctReceipt<Unknown> =
628                bincode::deserialize(&left_receipt_bytes)?;
629            let right_receipt_bytes = request
630                .right_receipt
631                .ok_or_else(|| malformed_err("UnionRequest.right_receipt"))?
632                .as_bytes()?;
633            let right_succinct_receipt: SuccinctReceipt<Unknown> =
634                bincode::deserialize(&right_receipt_bytes)?;
635
636            let prover = get_prover_server(&opts)?;
637            let receipt = prover.union(&left_succinct_receipt, &right_succinct_receipt)?;
638
639            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
640            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
641            let asset = pb::api::Asset::from_bytes(
642                &request
643                    .receipt_out
644                    .ok_or_else(|| malformed_err("UnionRequest.receipt_out"))?,
645                succinct_receipt_bytes.into(),
646                "receipt.zkp",
647            )?;
648
649            Ok(pb::api::UnionReply {
650                kind: Some(pb::api::union_reply::Kind::Ok(pb::api::UnionResult {
651                    receipt: Some(asset),
652                })),
653            })
654        }
655
656        let msg = inner(request).unwrap_or_else(|err| pb::api::UnionReply {
657            kind: Some(pb::api::union_reply::Kind::Error(pb::api::GenericError {
658                reason: err.to_string(),
659            })),
660        });
661
662        // tracing::trace!("tx: {msg:?}");
663        conn.send(msg)
664    }
665
666    fn on_resolve(
667        &self,
668        mut conn: ConnectionWrapper,
669        request: pb::api::ResolveRequest,
670    ) -> Result<()> {
671        fn inner(request: pb::api::ResolveRequest) -> Result<pb::api::ResolveReply> {
672            let opts: ProverOpts = request
673                .opts
674                .ok_or_else(|| malformed_err("ResolveRequest.opts"))?
675                .try_into()?;
676            let conditional_receipt_bytes = request
677                .conditional_receipt
678                .ok_or_else(|| malformed_err("ResolveRequest.conditional_receipt"))?
679                .as_bytes()?;
680            let conditional_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
681                bincode::deserialize(&conditional_receipt_bytes)?;
682            let assumption_receipt_bytes = request
683                .assumption_receipt
684                .ok_or_else(|| malformed_err("ResolveRequest.assumption_receipt"))?
685                .as_bytes()?;
686            let assumption_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
687                bincode::deserialize(&assumption_receipt_bytes)?;
688
689            let prover = get_prover_server(&opts)?;
690            let receipt = prover.resolve(
691                &conditional_succinct_receipt,
692                &assumption_succinct_receipt.into_unknown(),
693            )?;
694
695            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
696            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
697            let asset = pb::api::Asset::from_bytes(
698                &request
699                    .receipt_out
700                    .ok_or_else(|| malformed_err("ResolveRequest.receipt_out"))?,
701                succinct_receipt_bytes.into(),
702                "receipt.zkp",
703            )?;
704
705            Ok(pb::api::ResolveReply {
706                kind: Some(pb::api::resolve_reply::Kind::Ok(pb::api::ResolveResult {
707                    receipt: Some(asset),
708                })),
709            })
710        }
711
712        let msg = inner(request).unwrap_or_else(|err| pb::api::ResolveReply {
713            kind: Some(pb::api::resolve_reply::Kind::Error(pb::api::GenericError {
714                reason: err.to_string(),
715            })),
716        });
717
718        // tracing::trace!("tx: {msg:?}");
719        conn.send(msg)
720    }
721
722    fn on_identity_p254(
723        &self,
724        mut conn: ConnectionWrapper,
725        request: pb::api::IdentityP254Request,
726    ) -> Result<()> {
727        fn inner(request: pb::api::IdentityP254Request) -> Result<pb::api::IdentityP254Reply> {
728            let receipt_bytes = request
729                .receipt
730                .ok_or_else(|| malformed_err("IdentityP254Request.receipt"))?
731                .as_bytes()?;
732            let succinct_receipt: SuccinctReceipt<ReceiptClaim> =
733                bincode::deserialize(&receipt_bytes)?;
734
735            let receipt = identity_p254(&succinct_receipt)?;
736            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
737            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
738            let asset = pb::api::Asset::from_bytes(
739                &request
740                    .receipt_out
741                    .ok_or_else(|| malformed_err("IdentityP254Request.receipt_out"))?,
742                succinct_receipt_bytes.into(),
743                "receipt.zkp",
744            )?;
745
746            Ok(pb::api::IdentityP254Reply {
747                kind: Some(pb::api::identity_p254_reply::Kind::Ok(
748                    pb::api::IdentityP254Result {
749                        receipt: Some(asset),
750                    },
751                )),
752            })
753        }
754
755        let msg = inner(request).unwrap_or_else(|err| pb::api::IdentityP254Reply {
756            kind: Some(pb::api::identity_p254_reply::Kind::Error(
757                pb::api::GenericError {
758                    reason: err.to_string(),
759                },
760            )),
761        });
762
763        // tracing::trace!("tx: {msg:?}");
764        conn.send(msg)
765    }
766
767    fn on_compress(
768        &self,
769        mut conn: ConnectionWrapper,
770        request: pb::api::CompressRequest,
771    ) -> Result<()> {
772        fn inner(request: pb::api::CompressRequest) -> Result<pb::api::CompressReply> {
773            let opts: ProverOpts = request
774                .opts
775                .ok_or_else(|| malformed_err("CompressRequest.opts"))?
776                .try_into()?;
777            let receipt_bytes = request
778                .receipt
779                .ok_or_else(|| malformed_err("CompressRequest.receipt"))?
780                .as_bytes()?;
781            let receipt: Receipt = bincode::deserialize(&receipt_bytes)?;
782
783            let prover = get_prover_server(&opts)?;
784            let receipt = prover.compress(&opts, &receipt)?;
785
786            let receipt_pb: pb::core::Receipt = receipt.try_into()?;
787            let receipt_bytes = receipt_pb.encode_to_vec();
788            let asset = pb::api::Asset::from_bytes(
789                &request
790                    .receipt_out
791                    .ok_or_else(|| malformed_err("CompressRequest.receipt_out"))?,
792                receipt_bytes.into(),
793                "receipt.zkp",
794            )?;
795
796            Ok(pb::api::CompressReply {
797                kind: Some(pb::api::compress_reply::Kind::Ok(pb::api::CompressResult {
798                    receipt: Some(asset),
799                })),
800            })
801        }
802
803        let msg = inner(request).unwrap_or_else(|err| pb::api::CompressReply {
804            kind: Some(pb::api::compress_reply::Kind::Error(
805                pb::api::GenericError {
806                    reason: err.to_string(),
807                },
808            )),
809        });
810
811        // tracing::trace!("tx: {msg:?}");
812        conn.send(msg)
813    }
814
815    fn on_verify(
816        &self,
817        mut conn: ConnectionWrapper,
818        request: pb::api::VerifyRequest,
819    ) -> Result<()> {
820        fn inner(request: pb::api::VerifyRequest) -> Result<()> {
821            let receipt_bytes = request
822                .receipt
823                .ok_or_else(|| malformed_err("VerifyRequest.receipt"))?
824                .as_bytes()?;
825            let receipt: Receipt =
826                bincode::deserialize(&receipt_bytes).context("deserialize receipt")?;
827            let image_id: Digest = request
828                .image_id
829                .ok_or_else(|| malformed_err("VerifyRequest.image_id"))?
830                .try_into()?;
831            receipt
832                .verify(image_id)
833                .map_err(|err| anyhow!("verify failed: {err}"))
834        }
835
836        let msg: pb::api::GenericReply = inner(request).into();
837        // tracing::trace!("tx: {msg:?}");
838        conn.send(msg)
839    }
840}
841
842fn build_env<'a>(
843    conn: &ConnectionWrapper,
844    request: &pb::api::ExecutorEnv,
845) -> Result<ExecutorEnv<'a>> {
846    let mut env_builder = ExecutorEnv::builder();
847    env_builder.env_vars(request.env_vars.clone());
848    env_builder.args(&request.args);
849    for fd in request.read_fds.iter() {
850        let proxy = PosixIoProxy::new(*fd, conn.clone());
851        let reader = BufReader::new(proxy);
852        env_builder.read_fd(*fd, reader);
853    }
854    for fd in request.write_fds.iter() {
855        let proxy = PosixIoProxy::new(*fd, conn.clone());
856        env_builder.write_fd(*fd, proxy);
857    }
858    let proxy = SliceIoProxy::new(conn.clone());
859    for name in request.slice_ios.iter() {
860        env_builder.slice_io(name, proxy.clone());
861    }
862    if let Some(segment_limit_po2) = request.segment_limit_po2 {
863        env_builder.segment_limit_po2(segment_limit_po2);
864    }
865    if let Some(keccak_max_po2) = request.keccak_max_po2 {
866        env_builder.keccak_max_po2(keccak_max_po2)?;
867    }
868    env_builder.session_limit(request.session_limit);
869    if request.trace_events.is_some() {
870        let proxy = TraceProxy::new(conn.clone());
871        env_builder.trace_callback(proxy);
872    }
873    if !request.pprof_out.is_empty() {
874        env_builder.enable_profiler(Path::new(&request.pprof_out));
875    }
876    if !request.segment_path.is_empty() {
877        env_builder.segment_path(Path::new(&request.segment_path));
878    }
879    if request.coprocessor {
880        let proxy = CoprocessorProxy::new(conn.clone());
881        env_builder.coprocessor_callback(proxy);
882    }
883    if !request.povw_job_id.is_empty() {
884        let id = PovwJobId::from_bytes(
885            request
886                .povw_job_id
887                .as_slice()
888                .try_into()
889                .with_context(|| malformed_err("ExecutorEnv.povw_job_id"))?,
890        );
891        env_builder.povw(id);
892    }
893
894    for assumption in request.assumptions.iter() {
895        match assumption
896            .kind
897            .as_ref()
898            .ok_or_else(|| malformed_err("Assumption.kind"))?
899        {
900            pb::api::assumption_receipt::Kind::Proven(asset) => {
901                let receipt: InnerAssumptionReceipt =
902                    pb::core::InnerReceipt::decode(asset.as_bytes()?)?.try_into()?;
903                env_builder.add_assumption(receipt)
904            }
905            pb::api::assumption_receipt::Kind::Unresolved(asset) => {
906                let assumption: Assumption =
907                    pb::core::Assumption::decode(asset.as_bytes()?)?.try_into()?;
908                env_builder.add_assumption(assumption)
909            }
910        };
911    }
912    env_builder.build()
913}
914
/// Converts the error half of a `Result` into a [`std::io::Error`] of kind
/// [`std::io::ErrorKind::Other`] that wraps the original error.
trait IoOtherError<T> {
    /// Maps `Err(e)` to `Err(IoError::other(e))`; `Ok` values pass through untouched.
    fn map_io_err(self) -> Result<T, IoError>;
}

impl<T, E: Into<Box<dyn StdError + Send + Sync>>> IoOtherError<T> for Result<T, E> {
    fn map_io_err(self) -> Result<T, IoError> {
        // `IoError::other` already has the `FnOnce(E) -> IoError` shape `map_err` needs,
        // so pass it directly instead of wrapping it in a closure.
        self.map_err(IoError::other)
    }
}
924
925impl From<pb::api::GenericError> for IoError {
926    fn from(err: pb::api::GenericError) -> Self {
927        IoError::other(err.reason)
928    }
929}
930
931impl pb::api::Asset {
932    pub fn from_bytes<P: AsRef<Path>>(
933        request: &pb::api::AssetRequest,
934        bytes: Bytes,
935        path: P,
936    ) -> Result<Self> {
937        match request
938            .kind
939            .as_ref()
940            .ok_or_else(|| malformed_err("AssetRequest.kind"))?
941        {
942            pb::api::asset_request::Kind::Inline(()) => Ok(Self {
943                kind: Some(pb::api::asset::Kind::Inline(bytes.into())),
944            }),
945            pb::api::asset_request::Kind::Path(base_path) => {
946                let base_path = PathBuf::from(base_path);
947                let path = base_path.join(path);
948                std::fs::write(&path, bytes)?;
949                Ok(Self {
950                    kind: Some(pb::api::asset::Kind::Path(path_to_string(path)?)),
951                })
952            }
953            pb::api::asset_request::Kind::Redis(_) => {
954                tracing::error!("It's likely that r0vm is not installed with the redis feature");
955                bail!("from_bytes not supported for redis")
956            }
957        }
958    }
959}
960
961#[allow(dead_code)]
962fn check_client_version(client: &semver::Version, server: &semver::Version) -> bool {
963    if server.pre.is_empty() {
964        let comparator = semver::Comparator {
965            op: semver::Op::GreaterEq,
966            major: server.major,
967            minor: Some(server.minor),
968            patch: None,
969            pre: semver::Prerelease::EMPTY,
970        };
971        comparator.matches(client)
972    } else {
973        client == server
974    }
975}
976
977#[allow(dead_code)]
978fn check_client_version_compat(client: &semver::Version, server: &semver::Version) -> bool {
979    client.major == server.major
980}
981
/// Runs the guest ELF in `bytes` and uploads every produced segment to Redis.
///
/// Segments are handed to a background worker thread through a bounded channel. Its
/// capacity comes from `RISC0_REDIS_CHANNEL_SIZE` and defaults to 100 when the variable is
/// unset or unparsable. For each segment, the worker does two things:
/// 1. stores the bincode-serialized segment under `"{params.key}:{index}"` with a TTL of
///    `params.ttl`;
/// 2. notifies the client with a `SegmentDone` callback whose asset is that Redis key.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - execution fails;
/// - the worker thread panics;
/// - the worker fails to store a segment or notify the client. This includes failures that
///   occur after the executor has produced its last segment.
#[cfg(feature = "redis")]
fn execute_redis(
    conn: &mut ConnectionWrapper,
    env: ExecutorEnv,
    bytes: Bytes,
    params: super::RedisParams,
) -> Result<Session> {
    use redis::{Client, Commands, ConnectionLike, SetExpiry, SetOptions};
    use std::{
        sync::{
            mpsc::{sync_channel, Receiver},
            Arc, Mutex,
        },
        thread::{spawn, JoinHandle},
    };

    let channel_size = match std::env::var("RISC0_REDIS_CHANNEL_SIZE") {
        Ok(val_str) => val_str.parse::<usize>().unwrap_or(100),
        Err(_) => 100,
    };
    let (sender, receiver) = sync_channel::<(String, Segment)>(channel_size);
    let opts = SetOptions::default().with_expiration(SetExpiry::EX(params.ttl));

    // The first error hit by the worker thread. It is read by the executor callback when a
    // send fails, and again after the worker has been joined.
    let redis_err = Arc::new(Mutex::new(None));
    let redis_err_clone = redis_err.clone();

    let conn = conn.clone();
    let join_handle: JoinHandle<()> = spawn(move || {
        // Worker loop: stores each received segment in Redis, then notifies the client.
        fn inner(
            redis_url: String,
            receiver: &Receiver<(String, Segment)>,
            opts: SetOptions,
            mut conn: ConnectionWrapper,
        ) -> Result<()> {
            let client = Client::open(redis_url).context("Failed to open Redis connection")?;
            let mut connection = client
                .get_connection()
                .context("Failed to get redis connection")?;
            // `recv` fails once every sender is dropped and the buffer is drained.
            while let Ok((segment_key, segment)) = receiver.recv() {
                if !connection.is_open() {
                    connection = client
                        .get_connection()
                        .context("Failed to get redis connection")?;
                }
                let segment_bytes =
                    bincode::serialize(&segment).context("Failed to serialize segment")?;
                // One retry on a fresh connection before giving up on this segment.
                match connection.set_options(segment_key.clone(), segment_bytes.clone(), opts) {
                    Ok(()) => (),
                    Err(err) => {
                        tracing::warn!(
                            "Failed to set redis key with TTL, trying again. Error: {err}"
                        );
                        connection = client
                            .get_connection()
                            .context("Failed to get redis connection")?;
                        let _: () = connection
                            .set_options(segment_key.clone(), segment_bytes, opts)
                            .context("Failed to set redis key with TTL again")?;
                    }
                };
                let asset = pb::api::Asset {
                    kind: Some(pb::api::asset::Kind::Redis(segment_key)),
                };
                send_segment_done_msg(&mut conn, segment, Some(asset))
                    .context("Failed to send segment_done msg")?;
            }
            Ok(())
        }

        if let Err(err) = inner(params.url, &receiver, opts, conn) {
            *redis_err_clone.lock().unwrap() = Some(err);
        }
    });

    let callback = |segment: Segment| -> Result<Box<dyn SegmentRef>> {
        let segment_key = format!("{}:{}", params.key, segment.index);
        // A failed send means the worker has exited and dropped the receiver; prefer the
        // error it recorded over the generic channel error.
        if let Err(send_err) = sender.send((segment_key, segment)) {
            let mut redis_err_opt = redis_err.lock().unwrap();
            let redis_err_inner = redis_err_opt.take();
            return Err(match redis_err_inner {
                Some(redis_thread_err) => {
                    tracing::error!(
                        "Redis err: {redis_thread_err} root: {:?}",
                        redis_thread_err.root_cause()
                    );
                    anyhow!(redis_thread_err)
                }
                None => send_err.into(),
            });
        }
        Ok(Box::new(NullSegmentRef))
    };

    let session = ExecutorImpl::from_elf(env, &bytes)?.run_with_callback(callback);

    // Closing the channel lets the worker drain the remaining segments and exit.
    drop(sender);

    join_handle
        .join()
        .map_err(|err| anyhow!("redis task join failed: {err:?}"))?;

    let session = session?;
    // The worker can fail after the executor's last successful `send`, for example while
    // uploading the final segment or ones still buffered in the channel. The callback never
    // observes that, so surface it here instead of reporting success.
    if let Some(err) = redis_err.lock().unwrap().take() {
        return Err(err);
    }
    Ok(session)
}
1085
1086fn execute_default(
1087    conn: &mut ConnectionWrapper,
1088    env: ExecutorEnv,
1089    bytes: Bytes,
1090    segments_out: &pb::api::AssetRequest,
1091) -> Result<Session> {
1092    let callback = |segment: Segment| -> Result<Box<dyn SegmentRef>> {
1093        let segment_bytes = bincode::serialize(&segment)?;
1094        let asset = pb::api::Asset::from_bytes(
1095            segments_out,
1096            segment_bytes.into(),
1097            format!("segment-{}", segment.index),
1098        )?;
1099        send_segment_done_msg(conn, segment, Some(asset))?;
1100        Ok(Box::new(NullSegmentRef))
1101    };
1102
1103    ExecutorImpl::from_elf(env, &bytes)?.run_with_callback(callback)
1104}
1105
1106fn send_segment_done_msg(
1107    conn: &mut ConnectionWrapper,
1108    segment: Segment,
1109    some_asset: Option<pb::api::Asset>,
1110) -> Result<()> {
1111    let segment = Some(pb::api::SegmentInfo {
1112        index: segment.index,
1113        po2: segment.po2() as u32,
1114        cycles: segment.user_cycles(),
1115        segment: some_asset,
1116    });
1117
1118    let msg = pb::api::ServerReply {
1119        kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
1120            kind: Some(pb::api::client_callback::Kind::SegmentDone(
1121                pb::api::OnSegmentDone { segment },
1122            )),
1123        })),
1124    };
1125
1126    tracing::trace!("tx: {msg:?}");
1127    let reply: pb::api::GenericReply = conn.send_recv(msg)?;
1128    tracing::trace!("rx: {reply:?}");
1129
1130    let kind = reply
1131        .kind
1132        .ok_or_else(|| malformed_err("GenericReply.kind"))?;
1133    if let pb::api::generic_reply::Kind::Error(err) = kind {
1134        bail!(err)
1135    }
1136    Ok(())
1137}
1138
#[cfg(test)]
mod tests {
    use semver::Version;

    use super::{check_client_version, check_client_version_compat};

    /// Parses `client` and `server` as semver versions and feeds them to `check`.
    fn run(check: fn(&Version, &Version) -> bool, client: &str, server: &str) -> bool {
        let client = Version::parse(client).unwrap();
        let server = Version::parse(server).unwrap();
        check(&client, &server)
    }

    #[test]
    fn check_version() {
        // (client, server) pairs the strict check must accept.
        let accepted = [
            ("0.18.0", "0.18.0"),
            ("0.18.1", "0.18.0"),
            ("0.18.0", "0.18.1"),
            ("0.19.0", "0.18.0"),
            ("1.0.0", "0.18.0"),
            ("1.1.0", "1.0.0"),
            ("0.19.0-alpha.1", "0.19.0-alpha.1"),
        ];
        // (client, server) pairs the strict check must reject.
        let rejected = [
            ("0.19.0-alpha.1", "0.19.0-alpha.2"),
            ("0.18.0", "0.19.0"),
            ("0.18.0", "1.0.0"),
        ];

        for &(client, server) in &accepted {
            assert!(
                run(check_client_version, client, server),
                "client {client} should be accepted by server {server}"
            );
        }
        for &(client, server) in &rejected {
            assert!(
                !run(check_client_version, client, server),
                "client {client} should be rejected by server {server}"
            );
        }
    }

    #[test]
    fn check_version_compat() {
        // (client, server) pairs the loose check must accept.
        let accepted = [
            ("1.1.0", "1.1.0"),
            ("1.1.1", "1.1.1"),
            ("1.2.0", "1.1.1"),
            ("1.2.0-rc.1", "1.1.1"),
        ];
        // (client, server) pairs the loose check must reject.
        let rejected = [("2.0.0", "1.1.1")];

        for &(client, server) in &accepted {
            assert!(
                run(check_client_version_compat, client, server),
                "client {client} should be compatible with server {server}"
            );
        }
        for &(client, server) in &rejected {
            assert!(
                !run(check_client_version_compat, client, server),
                "client {client} should be incompatible with server {server}"
            );
        }
    }
}