risc0_zkvm/host/api/
server.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    error::Error as StdError,
17    io::{BufReader, Error as IoError, ErrorKind as IoErrorKind, Read, Write},
18    path::{Path, PathBuf},
19};
20
21use anyhow::{anyhow, bail, Context, Result};
22use bytes::Bytes;
23use prost::Message;
24use risc0_zkp::core::digest::Digest;
25
26use super::{malformed_err, path_to_string, pb, ConnectionWrapper, Connector, TcpConnector};
27use crate::{
28    get_prover_server, get_version,
29    host::{
30        api::convert::keccak_input_to_bytes,
31        client::{
32            env::{CoprocessorCallback, ProveKeccakRequest, ProveZkrRequest},
33            slice_io::SliceIo,
34        },
35        server::{prove::keccak::prove_keccak, session::NullSegmentRef},
36    },
37    prove_registered_zkr,
38    recursion::identity_p254,
39    AssetRequest, Assumption, ExecutorEnv, ExecutorImpl, InnerAssumptionReceipt, ProverOpts,
40    Receipt, ReceiptClaim, Segment, SegmentReceipt, Session, SuccinctReceipt, TraceCallback,
41    TraceEvent, VerifierContext,
42};
43
/// A server implementation for handling requests by clients of the zkVM.
pub struct Server {
    /// Used by [`Server::run`] to establish the connection that requests are
    /// received on and replies are sent over.
    connector: Box<dyn Connector>,
}
/// Forwards reads and writes for a single file descriptor to the peer on the
/// other end of the connection using `PosixIo` callback messages.
struct PosixIoProxy {
    /// File descriptor number placed in every `PosixIo` message.
    fd: u32,
    /// Connection the callback messages are exchanged over.
    conn: ConnectionWrapper,
}
52
53impl PosixIoProxy {
54    fn new(fd: u32, conn: ConnectionWrapper) -> Self {
55        PosixIoProxy { fd, conn }
56    }
57}
58
59impl Read for PosixIoProxy {
60    fn read(&mut self, to_guest: &mut [u8]) -> std::io::Result<usize> {
61        let nread = to_guest.len().try_into().map_io_err()?;
62        let request = pb::api::ServerReply {
63            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
64                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
65                    kind: Some(pb::api::on_io_request::Kind::Posix(pb::api::PosixIo {
66                        fd: self.fd,
67                        cmd: Some(pb::api::PosixCmd {
68                            kind: Some(pb::api::posix_cmd::Kind::Read(nread)),
69                        }),
70                    })),
71                })),
72            })),
73        };
74
75        tracing::trace!("tx: {request:?}");
76        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
77        tracing::trace!("rx: {reply:?}");
78
79        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
80        match kind {
81            pb::api::on_io_reply::Kind::Ok(bytes) => {
82                let (head, _) = to_guest.split_at_mut(bytes.len());
83                head.copy_from_slice(&bytes);
84                Ok(bytes.len())
85            }
86            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
87        }
88    }
89}
90
91impl Write for PosixIoProxy {
92    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
93        let request = pb::api::ServerReply {
94            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
95                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
96                    kind: Some(pb::api::on_io_request::Kind::Posix(pb::api::PosixIo {
97                        fd: self.fd,
98                        cmd: Some(pb::api::PosixCmd {
99                            kind: Some(pb::api::posix_cmd::Kind::Write(buf.into())),
100                        }),
101                    })),
102                })),
103            })),
104        };
105
106        tracing::trace!("tx: {request:?}");
107        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
108        tracing::trace!("rx: {reply:?}");
109
110        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
111        match kind {
112            pb::api::on_io_reply::Kind::Ok(_) => Ok(buf.len()),
113            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
114        }
115    }
116
117    fn flush(&mut self) -> std::io::Result<()> {
118        Ok(())
119    }
120}
121
/// Forwards named slice-I/O syscalls to the peer on the other end of the
/// connection using `SliceIo` callback messages.
#[derive(Clone)]
struct SliceIoProxy {
    /// Connection the callback messages are exchanged over.
    conn: ConnectionWrapper,
}
126
127impl SliceIoProxy {
128    fn new(conn: ConnectionWrapper) -> Self {
129        Self { conn }
130    }
131}
132
133impl SliceIo for SliceIoProxy {
134    fn handle_io(&mut self, syscall: &str, from_guest: Bytes) -> Result<Bytes> {
135        let request = pb::api::ServerReply {
136            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
137                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
138                    kind: Some(pb::api::on_io_request::Kind::Slice(pb::api::SliceIo {
139                        name: syscall.to_string(),
140                        from_guest: from_guest.into(),
141                    })),
142                })),
143            })),
144        };
145        tracing::trace!("tx: {request:?}");
146        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
147        tracing::trace!("rx: {reply:?}");
148
149        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
150        match kind {
151            pb::api::on_io_reply::Kind::Ok(buf) => Ok(buf.into()),
152            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
153        }
154    }
155}
156
/// Forwards trace events to the peer on the other end of the connection
/// using `Trace` callback messages.
struct TraceProxy {
    /// Connection the callback messages are exchanged over.
    conn: ConnectionWrapper,
}
160
161impl TraceProxy {
162    fn new(conn: ConnectionWrapper) -> Self {
163        Self { conn }
164    }
165}
166
167impl TraceCallback for TraceProxy {
168    fn trace_callback(&mut self, event: TraceEvent) -> Result<()> {
169        let Ok(event) = event.clone().try_into() else {
170            tracing::trace!("ignoring unknown event {event:?}");
171            return Ok(());
172        };
173
174        let request = pb::api::ServerReply {
175            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
176                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
177                    kind: Some(pb::api::on_io_request::Kind::Trace(event)),
178                })),
179            })),
180        };
181        tracing::trace!("tx: {request:?}");
182        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
183        tracing::trace!("rx: {reply:?}");
184
185        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
186        match kind {
187            pb::api::on_io_reply::Kind::Ok(_) => Ok(()),
188            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
189        }
190    }
191}
192
/// Forwards coprocessor proof requests (ZKR and keccak) to the peer on the
/// other end of the connection using `Coprocessor` callback messages.
struct CoprocessorProxy {
    /// Connection the callback messages are exchanged over.
    conn: ConnectionWrapper,
}
196
197impl CoprocessorProxy {
198    fn new(conn: ConnectionWrapper) -> Self {
199        Self { conn }
200    }
201}
202
203impl CoprocessorCallback for CoprocessorProxy {
204    fn prove_zkr(&mut self, proof_request: ProveZkrRequest) -> Result<()> {
205        let request = pb::api::ServerReply {
206            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
207                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
208                    kind: Some(pb::api::on_io_request::Kind::Coprocessor(
209                        pb::api::CoprocessorRequest {
210                            kind: Some(pb::api::coprocessor_request::Kind::ProveZkr({
211                                pb::api::ProveZkrRequest {
212                                    claim_digest: Some(proof_request.claim_digest.into()),
213                                    control_id: Some(proof_request.control_id.into()),
214                                    input: proof_request.input,
215                                    receipt_out: None,
216                                }
217                            })),
218                        },
219                    )),
220                })),
221            })),
222        };
223        tracing::trace!("tx: {request:?}");
224        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
225        tracing::trace!("rx: {reply:?}");
226
227        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
228        match kind {
229            pb::api::on_io_reply::Kind::Ok(_) => Ok(()),
230            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
231        }
232    }
233
234    fn prove_keccak(&mut self, proof_request: ProveKeccakRequest) -> Result<()> {
235        let input = keccak_input_to_bytes(&proof_request.input);
236        let request = pb::api::ServerReply {
237            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
238                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
239                    kind: Some(pb::api::on_io_request::Kind::Coprocessor(
240                        pb::api::CoprocessorRequest {
241                            kind: Some(pb::api::coprocessor_request::Kind::ProveKeccak({
242                                pb::api::ProveKeccakRequest {
243                                    claim_digest: Some(proof_request.claim_digest.into()),
244                                    po2: proof_request.po2 as u32,
245                                    control_root: Some(proof_request.control_root.into()),
246                                    input,
247                                    receipt_out: None,
248                                }
249                            })),
250                        },
251                    )),
252                })),
253            })),
254        };
255        tracing::trace!("tx: {request:?}");
256        self.conn.send(request)?;
257
258        let reply: pb::api::OnIoReply = self.conn.recv().map_io_err()?;
259        tracing::trace!("rx: {reply:?}");
260
261        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
262        match kind {
263            pb::api::on_io_reply::Kind::Ok(_) => Ok(()),
264            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
265        }
266    }
267}
268
269impl Server {
270    /// Construct a new [Server] with the specified [Connector].
271    pub fn new(connector: Box<dyn Connector>) -> Self {
272        Self { connector }
273    }
274
275    /// Construct a new [Server] which will connect to the specified TCP/IP
276    /// address.
277    pub fn new_tcp<A: AsRef<str>>(addr: A) -> Self {
278        let connector = TcpConnector::new(addr.as_ref());
279        Self::new(Box::new(connector))
280    }
281
282    /// Start the [Server] and run until all requests are complete.
283    pub fn run(&self) -> Result<()> {
284        tracing::debug!("connect");
285        let mut conn = self.connector.connect()?;
286
287        let server_version = get_version().map_err(|err| anyhow!(err))?;
288
289        let request: pb::api::HelloRequest = conn.recv()?;
290        tracing::trace!("rx: {request:?}");
291
292        let client_version: semver::Version = request
293            .version
294            .ok_or(malformed_err())?
295            .try_into()
296            .map_err(|err: semver::Error| anyhow!(err))?;
297
298        #[cfg(not(feature = "r0vm-ver-compat"))]
299        let check_client_func = check_client_version;
300        #[cfg(feature = "r0vm-ver-compat")]
301        let check_client_func = check_client_version_compat;
302
303        if !check_client_func(&client_version, &server_version) {
304            let msg = format!(
305                "incompatible client version: {client_version}, server version: {server_version}"
306            );
307            tracing::debug!("{msg}");
308            bail!(msg);
309        }
310
311        let reply = pb::api::HelloReply {
312            kind: Some(pb::api::hello_reply::Kind::Ok(pb::api::HelloResult {
313                version: Some(server_version.into()),
314            })),
315        };
316        tracing::trace!("tx: {reply:?}");
317        let request: pb::api::ServerRequest = conn.send_recv(reply)?;
318        tracing::trace!("rx: {request:?}");
319
320        match request.kind.ok_or(malformed_err())? {
321            pb::api::server_request::Kind::Prove(request) => self.on_prove(conn, request),
322            pb::api::server_request::Kind::Execute(request) => self.on_execute(conn, request),
323            pb::api::server_request::Kind::ProveSegment(request) => {
324                self.on_prove_segment(conn, request)
325            }
326            pb::api::server_request::Kind::Lift(request) => self.on_lift(conn, request),
327            pb::api::server_request::Kind::Join(request) => self.on_join(conn, request),
328            pb::api::server_request::Kind::Resolve(request) => self.on_resolve(conn, request),
329            pb::api::server_request::Kind::IdentityP254(request) => {
330                self.on_identity_p254(conn, request)
331            }
332            pb::api::server_request::Kind::Compress(request) => self.on_compress(conn, request),
333            pb::api::server_request::Kind::Verify(request) => self.on_verify(conn, request),
334            pb::api::server_request::Kind::ProveZkr(request) => self.on_prove_zkr(conn, request),
335            pb::api::server_request::Kind::ProveKeccak(request) => {
336                self.on_prove_keccak(conn, request)
337            }
338        }
339    }
340
341    fn on_execute(
342        &self,
343        mut conn: ConnectionWrapper,
344        request: pb::api::ExecuteRequest,
345    ) -> Result<()> {
346        fn inner(
347            conn: &mut ConnectionWrapper,
348            request: pb::api::ExecuteRequest,
349        ) -> Result<pb::api::ServerReply> {
350            let env_request = request.env.ok_or(malformed_err())?;
351            let env = build_env(conn, &env_request)?;
352
353            let binary = env_request.binary.ok_or(malformed_err())?;
354
355            let segments_out = request.segments_out.ok_or(malformed_err())?;
356            let bytes = binary.as_bytes()?;
357            let mut exec = ExecutorImpl::from_elf(env, &bytes)?;
358
359            let session = match AssetRequest::try_from(segments_out.clone())? {
360                #[cfg(feature = "redis")]
361                AssetRequest::Redis(params) => execute_redis(conn, &mut exec, params)?,
362                _ => execute_default(conn, &mut exec, &segments_out)?,
363            };
364
365            let receipt_claim = session.claim()?;
366            Ok(pb::api::ServerReply {
367                kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
368                    kind: Some(pb::api::client_callback::Kind::SessionDone(
369                        pb::api::OnSessionDone {
370                            session: Some(pb::api::SessionInfo {
371                                segments: session.segments.len().try_into()?,
372                                journal: session.journal.unwrap_or_default().bytes,
373                                exit_code: Some(session.exit_code.into()),
374                                receipt_claim: Some(pb::api::Asset::from_bytes(
375                                    &pb::api::AssetRequest {
376                                        kind: Some(pb::api::asset_request::Kind::Inline(())),
377                                    },
378                                    pb::core::ReceiptClaim::from(receipt_claim)
379                                        .encode_to_vec()
380                                        .into(),
381                                    "session_info.claim",
382                                )?),
383                            }),
384                        },
385                    )),
386                })),
387            })
388        }
389
390        let msg = inner(&mut conn, request).unwrap_or_else(|err| pb::api::ServerReply {
391            kind: Some(pb::api::server_reply::Kind::Error(pb::api::GenericError {
392                reason: err.to_string(),
393            })),
394        });
395
396        tracing::trace!("tx: {msg:?}");
397        conn.send(msg)
398    }
399
400    fn on_prove(&self, mut conn: ConnectionWrapper, request: pb::api::ProveRequest) -> Result<()> {
401        fn inner(
402            conn: &mut ConnectionWrapper,
403            request: pb::api::ProveRequest,
404        ) -> Result<pb::api::ServerReply> {
405            let env_request = request.env.ok_or(malformed_err())?;
406            let env = build_env(conn, &env_request)?;
407
408            let binary = env_request.binary.ok_or(malformed_err())?;
409            let bytes = binary.as_bytes()?;
410
411            let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
412            let prover = get_prover_server(&opts)?;
413            let ctx = VerifierContext::default();
414            let prove_info = prover.prove_with_ctx(env, &ctx, &bytes)?;
415
416            let prove_info: pb::core::ProveInfo = prove_info.into();
417            let prove_info_bytes = prove_info.encode_to_vec();
418            let asset = pb::api::Asset::from_bytes(
419                &request.receipt_out.ok_or(malformed_err())?,
420                prove_info_bytes.into(),
421                "prove_info.zkp",
422            )?;
423
424            Ok(pb::api::ServerReply {
425                kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
426                    kind: Some(pb::api::client_callback::Kind::ProveDone(
427                        pb::api::OnProveDone {
428                            prove_info: Some(asset),
429                        },
430                    )),
431                })),
432            })
433        }
434
435        let msg = inner(&mut conn, request).unwrap_or_else(|err| pb::api::ServerReply {
436            kind: Some(pb::api::server_reply::Kind::Error(pb::api::GenericError {
437                reason: err.to_string(),
438            })),
439        });
440
441        tracing::trace!("tx: {msg:?}");
442        conn.send(msg)
443    }
444
445    fn on_prove_segment(
446        &self,
447        mut conn: ConnectionWrapper,
448        request: pb::api::ProveSegmentRequest,
449    ) -> Result<()> {
450        fn inner(request: pb::api::ProveSegmentRequest) -> Result<pb::api::ProveSegmentReply> {
451            let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
452            let segment_bytes = request.segment.ok_or(malformed_err())?.as_bytes()?;
453            let segment: Segment = bincode::deserialize(&segment_bytes)?;
454
455            let prover = get_prover_server(&opts)?;
456            let ctx = VerifierContext::default();
457            let receipt = prover.prove_segment(&ctx, &segment)?;
458
459            let receipt_pb: pb::core::SegmentReceipt = receipt.into();
460            let receipt_bytes = receipt_pb.encode_to_vec();
461            let asset = pb::api::Asset::from_bytes(
462                &request.receipt_out.ok_or(malformed_err())?,
463                receipt_bytes.into(),
464                "receipt.zkp",
465            )?;
466
467            Ok(pb::api::ProveSegmentReply {
468                kind: Some(pb::api::prove_segment_reply::Kind::Ok(
469                    pb::api::ProveSegmentResult {
470                        receipt: Some(asset),
471                    },
472                )),
473            })
474        }
475
476        let msg = inner(request).unwrap_or_else(|err| pb::api::ProveSegmentReply {
477            kind: Some(pb::api::prove_segment_reply::Kind::Error(
478                pb::api::GenericError {
479                    reason: err.to_string(),
480                },
481            )),
482        });
483
484        tracing::trace!("tx: {msg:?}");
485        conn.send(msg)
486    }
487
488    fn on_prove_zkr(
489        &self,
490        mut conn: ConnectionWrapper,
491        request: pb::api::ProveZkrRequest,
492    ) -> Result<()> {
493        fn inner(request: pb::api::ProveZkrRequest) -> Result<pb::api::ProveZkrReply> {
494            let control_id = request.control_id.ok_or(malformed_err())?.try_into()?;
495            let receipt = prove_registered_zkr(&control_id, vec![control_id], &request.input)?;
496
497            let receipt_pb: pb::core::SuccinctReceipt = receipt.into();
498            let receipt_bytes = receipt_pb.encode_to_vec();
499            let asset = pb::api::Asset::from_bytes(
500                &request.receipt_out.ok_or(malformed_err())?,
501                receipt_bytes.into(),
502                "receipt.zkp",
503            )?;
504
505            Ok(pb::api::ProveZkrReply {
506                kind: Some(pb::api::prove_zkr_reply::Kind::Ok(
507                    pb::api::ProveZkrResult {
508                        receipt: Some(asset),
509                    },
510                )),
511            })
512        }
513
514        let msg = inner(request).unwrap_or_else(|err| pb::api::ProveZkrReply {
515            kind: Some(pb::api::prove_zkr_reply::Kind::Error(
516                pb::api::GenericError {
517                    reason: err.to_string(),
518                },
519            )),
520        });
521
522        tracing::trace!("tx: {msg:?}");
523        conn.send(msg)
524    }
525
526    fn on_prove_keccak(
527        &self,
528        mut conn: ConnectionWrapper,
529        request: pb::api::ProveKeccakRequest,
530    ) -> Result<()> {
531        fn inner(request_pb: pb::api::ProveKeccakRequest) -> Result<pb::api::ProveKeccakReply> {
532            let request: ProveKeccakRequest = request_pb.clone().try_into()?;
533            let receipt = prove_keccak(&request)?;
534
535            let receipt_pb: pb::core::SuccinctReceipt = receipt.into();
536            let receipt_bytes = receipt_pb.encode_to_vec();
537            let asset = pb::api::Asset::from_bytes(
538                &request_pb.receipt_out.ok_or(malformed_err())?,
539                receipt_bytes.into(),
540                "receipt.zkp",
541            )?;
542
543            Ok(pb::api::ProveKeccakReply {
544                kind: Some(pb::api::prove_keccak_reply::Kind::Ok(
545                    pb::api::ProveKeccakResult {
546                        receipt: Some(asset),
547                    },
548                )),
549            })
550        }
551
552        let msg = inner(request).unwrap_or_else(|err| pb::api::ProveKeccakReply {
553            kind: Some(pb::api::prove_keccak_reply::Kind::Error(
554                pb::api::GenericError {
555                    reason: err.to_string(),
556                },
557            )),
558        });
559
560        tracing::trace!("tx: {msg:?}");
561        conn.send(msg)
562    }
563
564    fn on_lift(&self, mut conn: ConnectionWrapper, request: pb::api::LiftRequest) -> Result<()> {
565        fn inner(request: pb::api::LiftRequest) -> Result<pb::api::LiftReply> {
566            let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
567            let receipt_bytes = request.receipt.ok_or(malformed_err())?.as_bytes()?;
568            let segment_receipt: SegmentReceipt = bincode::deserialize(&receipt_bytes)?;
569
570            let prover = get_prover_server(&opts)?;
571            let receipt = prover.lift(&segment_receipt)?;
572
573            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.into();
574            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
575            let asset = pb::api::Asset::from_bytes(
576                &request.receipt_out.ok_or(malformed_err())?,
577                succinct_receipt_bytes.into(),
578                "receipt.zkp",
579            )?;
580
581            Ok(pb::api::LiftReply {
582                kind: Some(pb::api::lift_reply::Kind::Ok(pb::api::LiftResult {
583                    receipt: Some(asset),
584                })),
585            })
586        }
587
588        let msg = inner(request).unwrap_or_else(|err| pb::api::LiftReply {
589            kind: Some(pb::api::lift_reply::Kind::Error(pb::api::GenericError {
590                reason: err.to_string(),
591            })),
592        });
593
594        // tracing::trace!("tx: {msg:?}");
595        conn.send(msg)
596    }
597
598    fn on_join(&self, mut conn: ConnectionWrapper, request: pb::api::JoinRequest) -> Result<()> {
599        fn inner(request: pb::api::JoinRequest) -> Result<pb::api::JoinReply> {
600            let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
601            let left_receipt_bytes = request.left_receipt.ok_or(malformed_err())?.as_bytes()?;
602            let left_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
603                bincode::deserialize(&left_receipt_bytes)?;
604            let right_receipt_bytes = request.right_receipt.ok_or(malformed_err())?.as_bytes()?;
605            let right_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
606                bincode::deserialize(&right_receipt_bytes)?;
607
608            let prover = get_prover_server(&opts)?;
609            let receipt = prover.join(&left_succinct_receipt, &right_succinct_receipt)?;
610
611            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.into();
612            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
613            let asset = pb::api::Asset::from_bytes(
614                &request.receipt_out.ok_or(malformed_err())?,
615                succinct_receipt_bytes.into(),
616                "receipt.zkp",
617            )?;
618
619            Ok(pb::api::JoinReply {
620                kind: Some(pb::api::join_reply::Kind::Ok(pb::api::JoinResult {
621                    receipt: Some(asset),
622                })),
623            })
624        }
625
626        let msg = inner(request).unwrap_or_else(|err| pb::api::JoinReply {
627            kind: Some(pb::api::join_reply::Kind::Error(pb::api::GenericError {
628                reason: err.to_string(),
629            })),
630        });
631
632        // tracing::trace!("tx: {msg:?}");
633        conn.send(msg)
634    }
635
636    fn on_resolve(
637        &self,
638        mut conn: ConnectionWrapper,
639        request: pb::api::ResolveRequest,
640    ) -> Result<()> {
641        fn inner(request: pb::api::ResolveRequest) -> Result<pb::api::ResolveReply> {
642            let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
643            let conditional_receipt_bytes = request
644                .conditional_receipt
645                .ok_or(malformed_err())?
646                .as_bytes()?;
647            let conditional_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
648                bincode::deserialize(&conditional_receipt_bytes)?;
649            let assumption_receipt_bytes = request
650                .assumption_receipt
651                .ok_or(malformed_err())?
652                .as_bytes()?;
653            let assumption_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
654                bincode::deserialize(&assumption_receipt_bytes)?;
655
656            let prover = get_prover_server(&opts)?;
657            let receipt = prover.resolve(
658                &conditional_succinct_receipt,
659                &assumption_succinct_receipt.into_unknown(),
660            )?;
661
662            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.into();
663            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
664            let asset = pb::api::Asset::from_bytes(
665                &request.receipt_out.ok_or(malformed_err())?,
666                succinct_receipt_bytes.into(),
667                "receipt.zkp",
668            )?;
669
670            Ok(pb::api::ResolveReply {
671                kind: Some(pb::api::resolve_reply::Kind::Ok(pb::api::ResolveResult {
672                    receipt: Some(asset),
673                })),
674            })
675        }
676
677        let msg = inner(request).unwrap_or_else(|err| pb::api::ResolveReply {
678            kind: Some(pb::api::resolve_reply::Kind::Error(pb::api::GenericError {
679                reason: err.to_string(),
680            })),
681        });
682
683        // tracing::trace!("tx: {msg:?}");
684        conn.send(msg)
685    }
686
687    fn on_identity_p254(
688        &self,
689        mut conn: ConnectionWrapper,
690        request: pb::api::IdentityP254Request,
691    ) -> Result<()> {
692        fn inner(request: pb::api::IdentityP254Request) -> Result<pb::api::IdentityP254Reply> {
693            let receipt_bytes = request.receipt.ok_or(malformed_err())?.as_bytes()?;
694            let succinct_receipt: SuccinctReceipt<ReceiptClaim> =
695                bincode::deserialize(&receipt_bytes)?;
696
697            let receipt = identity_p254(&succinct_receipt)?;
698            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.into();
699            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
700            let asset = pb::api::Asset::from_bytes(
701                &request.receipt_out.ok_or(malformed_err())?,
702                succinct_receipt_bytes.into(),
703                "receipt.zkp",
704            )?;
705
706            Ok(pb::api::IdentityP254Reply {
707                kind: Some(pb::api::identity_p254_reply::Kind::Ok(
708                    pb::api::IdentityP254Result {
709                        receipt: Some(asset),
710                    },
711                )),
712            })
713        }
714
715        let msg = inner(request).unwrap_or_else(|err| pb::api::IdentityP254Reply {
716            kind: Some(pb::api::identity_p254_reply::Kind::Error(
717                pb::api::GenericError {
718                    reason: err.to_string(),
719                },
720            )),
721        });
722
723        // tracing::trace!("tx: {msg:?}");
724        conn.send(msg)
725    }
726
727    fn on_compress(
728        &self,
729        mut conn: ConnectionWrapper,
730        request: pb::api::CompressRequest,
731    ) -> Result<()> {
732        fn inner(request: pb::api::CompressRequest) -> Result<pb::api::CompressReply> {
733            let opts: ProverOpts = request.opts.ok_or(malformed_err())?.try_into()?;
734            let receipt_bytes = request.receipt.ok_or(malformed_err())?.as_bytes()?;
735            let receipt: Receipt = bincode::deserialize(&receipt_bytes)?;
736
737            let prover = get_prover_server(&opts)?;
738            let receipt = prover.compress(&opts, &receipt)?;
739
740            let receipt_pb: pb::core::Receipt = receipt.into();
741            let receipt_bytes = receipt_pb.encode_to_vec();
742            let asset = pb::api::Asset::from_bytes(
743                &request.receipt_out.ok_or(malformed_err())?,
744                receipt_bytes.into(),
745                "receipt.zkp",
746            )?;
747
748            Ok(pb::api::CompressReply {
749                kind: Some(pb::api::compress_reply::Kind::Ok(pb::api::CompressResult {
750                    receipt: Some(asset),
751                })),
752            })
753        }
754
755        let msg = inner(request).unwrap_or_else(|err| pb::api::CompressReply {
756            kind: Some(pb::api::compress_reply::Kind::Error(
757                pb::api::GenericError {
758                    reason: err.to_string(),
759                },
760            )),
761        });
762
763        // tracing::trace!("tx: {msg:?}");
764        conn.send(msg)
765    }
766
767    fn on_verify(
768        &self,
769        mut conn: ConnectionWrapper,
770        request: pb::api::VerifyRequest,
771    ) -> Result<()> {
772        fn inner(request: pb::api::VerifyRequest) -> Result<()> {
773            let receipt_bytes = request.receipt.ok_or(malformed_err())?.as_bytes()?;
774            let receipt: Receipt =
775                bincode::deserialize(&receipt_bytes).context("deserialize receipt")?;
776            let image_id: Digest = request.image_id.ok_or(malformed_err())?.try_into()?;
777            receipt
778                .verify(image_id)
779                .map_err(|err| anyhow!("verify failed: {err}"))
780        }
781
782        let msg: pb::api::GenericReply = inner(request).into();
783        // tracing::trace!("tx: {msg:?}");
784        conn.send(msg)
785    }
786}
787
788fn build_env<'a>(
789    conn: &ConnectionWrapper,
790    request: &pb::api::ExecutorEnv,
791) -> Result<ExecutorEnv<'a>> {
792    let mut env_builder = ExecutorEnv::builder();
793    env_builder.env_vars(request.env_vars.clone());
794    env_builder.args(&request.args);
795    for fd in request.read_fds.iter() {
796        let proxy = PosixIoProxy::new(*fd, conn.clone());
797        let reader = BufReader::new(proxy);
798        env_builder.read_fd(*fd, reader);
799    }
800    for fd in request.write_fds.iter() {
801        let proxy = PosixIoProxy::new(*fd, conn.clone());
802        env_builder.write_fd(*fd, proxy);
803    }
804    let proxy = SliceIoProxy::new(conn.clone());
805    for name in request.slice_ios.iter() {
806        env_builder.slice_io(name, proxy.clone());
807    }
808    if let Some(segment_limit_po2) = request.segment_limit_po2 {
809        env_builder.segment_limit_po2(segment_limit_po2);
810    }
811    env_builder.session_limit(request.session_limit);
812    if request.trace_events.is_some() {
813        let proxy = TraceProxy::new(conn.clone());
814        env_builder.trace_callback(proxy);
815    }
816    if !request.pprof_out.is_empty() {
817        env_builder.enable_profiler(Path::new(&request.pprof_out));
818    }
819    if !request.segment_path.is_empty() {
820        env_builder.segment_path(Path::new(&request.segment_path));
821    }
822    if request.coprocessor {
823        let proxy = CoprocessorProxy::new(conn.clone());
824        env_builder.coprocessor_callback(proxy);
825    }
826
827    for assumption in request.assumptions.iter() {
828        match assumption.kind.as_ref().ok_or(malformed_err())? {
829            pb::api::assumption_receipt::Kind::Proven(asset) => {
830                let receipt: InnerAssumptionReceipt =
831                    pb::core::InnerReceipt::decode(asset.as_bytes()?)?.try_into()?;
832                env_builder.add_assumption(receipt)
833            }
834            pb::api::assumption_receipt::Kind::Unresolved(asset) => {
835                let assumption: Assumption =
836                    pb::core::Assumption::decode(asset.as_bytes()?)?.try_into()?;
837                env_builder.add_assumption(assumption)
838            }
839        };
840    }
841    env_builder.build()
842}
843
/// Extension trait that turns the error side of a `Result` into an
/// [`IoError`].
trait IoOtherError<T> {
    /// Passes an `Ok` value through unchanged and wraps an `Err` value in an
    /// [`IoError`] of kind [`IoErrorKind::Other`].
    fn map_io_err(self) -> Result<T, IoError>;
}

impl<T, E> IoOtherError<T> for Result<T, E>
where
    E: Into<Box<dyn StdError + Send + Sync>>,
{
    fn map_io_err(self) -> Result<T, IoError> {
        match self {
            Ok(value) => Ok(value),
            Err(cause) => Err(IoError::new(IoErrorKind::Other, cause)),
        }
    }
}
853
854impl From<pb::api::GenericError> for IoError {
855    fn from(err: pb::api::GenericError) -> Self {
856        IoError::new(IoErrorKind::Other, err.reason)
857    }
858}
859
860impl pb::api::Asset {
861    pub fn from_bytes<P: AsRef<Path>>(
862        request: &pb::api::AssetRequest,
863        bytes: Bytes,
864        path: P,
865    ) -> Result<Self> {
866        match request.kind.as_ref().ok_or(malformed_err())? {
867            pb::api::asset_request::Kind::Inline(()) => Ok(Self {
868                kind: Some(pb::api::asset::Kind::Inline(bytes.into())),
869            }),
870            pb::api::asset_request::Kind::Path(base_path) => {
871                let base_path = PathBuf::from(base_path);
872                let path = base_path.join(path);
873                std::fs::write(&path, bytes)?;
874                Ok(Self {
875                    kind: Some(pb::api::asset::Kind::Path(path_to_string(path)?)),
876                })
877            }
878            pb::api::asset_request::Kind::Redis(_) => {
879                tracing::error!("It's likely that r0vm is not installed with the redis feature");
880                bail!("from_bytes not supported for redis")
881            }
882        }
883    }
884}
885
886#[allow(dead_code)]
887fn check_client_version(client: &semver::Version, server: &semver::Version) -> bool {
888    if server.pre.is_empty() {
889        let comparator = semver::Comparator {
890            op: semver::Op::GreaterEq,
891            major: server.major,
892            minor: Some(server.minor),
893            patch: None,
894            pre: semver::Prerelease::EMPTY,
895        };
896        comparator.matches(client)
897    } else {
898        client == server
899    }
900}
901
902#[allow(dead_code)]
903fn check_client_version_compat(client: &semver::Version, server: &semver::Version) -> bool {
904    client.major == server.major
905}
906
/// Runs `exec` to completion, storing every produced segment in Redis.
///
/// Segments are handed over a bounded channel to a dedicated writer thread.
/// For each segment the writer serializes it with `bincode`, stores it under
/// the key `"{params.key}:{segment.index}"` with an `EX` expiry of
/// `params.ttl`, and then sends a segment-done message to the client on a
/// clone of `conn`. A failed store is retried once on a fresh connection.
///
/// The channel capacity is read from the `RISC0_REDIS_CHANNEL_SIZE`
/// environment variable and falls back to 100 when unset or unparsable.
///
/// # Errors
///
/// Returns an error if the executor fails, if the writer thread panics, or if
/// the writer thread fails on any segment, including segments that were
/// still queued when the executor finished.
#[cfg(feature = "redis")]
fn execute_redis(
    conn: &mut ConnectionWrapper,
    exec: &mut ExecutorImpl,
    params: super::RedisParams,
) -> Result<Session> {
    use redis::{Client, Commands, ConnectionLike, SetExpiry, SetOptions};
    use std::{
        sync::{
            mpsc::{sync_channel, Receiver},
            Arc, Mutex,
        },
        thread::{spawn, JoinHandle},
    };

    let channel_size = std::env::var("RISC0_REDIS_CHANNEL_SIZE")
        .ok()
        .and_then(|val_str| val_str.parse::<usize>().ok())
        .unwrap_or(100);
    let (sender, receiver) = sync_channel::<(String, Segment)>(channel_size);
    let opts = SetOptions::default().with_expiration(SetExpiry::EX(params.ttl));

    // Slot through which the writer thread reports the error that stopped it.
    let redis_err = Arc::new(Mutex::new(None));
    let redis_err_clone = redis_err.clone();

    let conn = conn.clone();
    let join_handle: JoinHandle<()> = spawn(move || {
        /// Drains `receiver`, storing each segment in Redis and announcing it
        /// to the client, until the channel is closed or an error occurs.
        fn inner(
            redis_url: String,
            receiver: &Receiver<(String, Segment)>,
            opts: SetOptions,
            mut conn: ConnectionWrapper,
        ) -> Result<()> {
            let client = Client::open(redis_url).context("Failed to open Redis connection")?;
            let mut connection = client
                .get_connection()
                .context("Failed to get redis connection")?;
            while let Ok((segment_key, segment)) = receiver.recv() {
                if !connection.is_open() {
                    connection = client
                        .get_connection()
                        .context("Failed to get redis connection")?;
                }
                let segment_bytes =
                    bincode::serialize(&segment).context("Failed to serialize segment")?;
                // Key and payload are passed by reference so that the retry
                // below does not need a copy of the serialized segment.
                match connection.set_options(segment_key.as_str(), segment_bytes.as_slice(), opts)
                {
                    Ok(()) => (),
                    Err(err) => {
                        tracing::warn!(
                            "Failed to set redis key with TTL, trying again. Error: {err}"
                        );
                        connection = client
                            .get_connection()
                            .context("Failed to get redis connection")?;
                        let _: () = connection
                            .set_options(segment_key.as_str(), segment_bytes.as_slice(), opts)
                            .context("Failed to set redis key with TTL again")?;
                    }
                };
                let asset = pb::api::Asset {
                    kind: Some(pb::api::asset::Kind::Redis(segment_key)),
                };
                send_segment_done_msg(&mut conn, segment, Some(asset))
                    .context("Failed to send segment_done msg")?;
            }
            Ok(())
        }

        if let Err(err) = inner(params.url, &receiver, opts, conn) {
            *redis_err_clone.lock().unwrap() = Some(err);
        }
    });

    let session = exec.run_with_callback(|segment| {
        let segment_key = format!("{}:{}", params.key, segment.index);
        if let Err(send_err) = sender.send((segment_key, segment)) {
            // A failed send means the writer thread has gone away; prefer the
            // error it recorded over the less informative channel error.
            let mut redis_err_opt = redis_err.lock().unwrap();
            let redis_err_inner = redis_err_opt.take();
            return Err(match redis_err_inner {
                Some(redis_thread_err) => {
                    tracing::error!(
                        "Redis err: {redis_thread_err} root: {:?}",
                        redis_thread_err.root_cause()
                    );
                    anyhow!(redis_thread_err)
                }
                None => send_err.into(),
            });
        }
        Ok(Box::new(NullSegmentRef))
    });

    // Closing the channel ends the writer's receive loop once the queue is drained.
    drop(sender);

    join_handle
        .join()
        .map_err(|err| anyhow!("redis task join failed: {err:?}"))?;

    // An executor failure takes precedence over anything the writer recorded.
    let session = session?;

    // The writer may have failed on segments that were already queued when the
    // executor finished. No later `send` exists to notice that, so the recorded
    // error has to be checked here or those segments would be lost silently.
    let pending_err = redis_err.lock().unwrap().take();
    if let Some(redis_thread_err) = pending_err {
        tracing::error!(
            "Redis err: {redis_thread_err} root: {:?}",
            redis_thread_err.root_cause()
        );
        return Err(redis_thread_err);
    }

    Ok(session)
}
1007
1008fn execute_default(
1009    conn: &mut ConnectionWrapper,
1010    exec: &mut ExecutorImpl,
1011    segments_out: &pb::api::AssetRequest,
1012) -> Result<Session> {
1013    exec.run_with_callback(|segment| {
1014        let segment_bytes = bincode::serialize(&segment)?;
1015        let asset = pb::api::Asset::from_bytes(
1016            segments_out,
1017            segment_bytes.into(),
1018            format!("segment-{}", segment.index),
1019        )?;
1020        send_segment_done_msg(conn, segment, Some(asset))?;
1021        Ok(Box::new(NullSegmentRef))
1022    })
1023}
1024
1025fn send_segment_done_msg(
1026    conn: &mut ConnectionWrapper,
1027    segment: Segment,
1028    some_asset: Option<pb::api::Asset>,
1029) -> Result<()> {
1030    let segment = Some(pb::api::SegmentInfo {
1031        index: segment.index,
1032        po2: segment.inner.po2 as u32,
1033        cycles: segment.inner.insn_cycles as u32,
1034        segment: some_asset,
1035    });
1036
1037    let msg = pb::api::ServerReply {
1038        kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
1039            kind: Some(pb::api::client_callback::Kind::SegmentDone(
1040                pb::api::OnSegmentDone { segment },
1041            )),
1042        })),
1043    };
1044
1045    tracing::trace!("tx: {msg:?}");
1046    let reply: pb::api::GenericReply = conn.send_recv(msg)?;
1047    tracing::trace!("rx: {reply:?}");
1048
1049    let kind = reply.kind.ok_or(malformed_err())?;
1050    if let pb::api::generic_reply::Kind::Error(err) = kind {
1051        bail!(err)
1052    }
1053    Ok(())
1054}
1055
#[cfg(test)]
mod tests {
    use semver::Version;

    use super::{check_client_version, check_client_version_compat};

    /// Signature shared by the version-check functions under test.
    type CheckFn = fn(&Version, &Version) -> bool;

    /// Parses both version strings and runs `check_func` on them.
    fn test_inner(check_func: CheckFn, client: &str, server: &str) -> bool {
        let client = Version::parse(client).unwrap();
        let server = Version::parse(server).unwrap();
        check_func(&client, &server)
    }

    /// Asserts that `check_func` yields `expected` for every
    /// `(client, server)` pair in `cases`.
    fn assert_all(check_func: CheckFn, expected: bool, cases: &[(&str, &str)]) {
        for &(client, server) in cases {
            assert_eq!(
                test_inner(check_func, client, server),
                expected,
                "client {client} against server {server}"
            );
        }
    }

    #[test]
    fn check_version() {
        assert_all(
            check_client_version,
            true,
            &[
                ("0.18.0", "0.18.0"),
                ("0.18.1", "0.18.0"),
                ("0.18.0", "0.18.1"),
                ("0.19.0", "0.18.0"),
                ("1.0.0", "0.18.0"),
                ("1.1.0", "1.0.0"),
                ("0.19.0-alpha.1", "0.19.0-alpha.1"),
            ],
        );

        assert_all(
            check_client_version,
            false,
            &[
                ("0.19.0-alpha.1", "0.19.0-alpha.2"),
                ("0.18.0", "0.19.0"),
                ("0.18.0", "1.0.0"),
            ],
        );
    }

    #[test]
    fn check_version_compat() {
        assert_all(
            check_client_version_compat,
            true,
            &[
                ("1.1.0", "1.1.0"),
                ("1.1.1", "1.1.1"),
                ("1.2.0", "1.1.1"),
                ("1.2.0-rc.1", "1.1.1"),
            ],
        );

        assert_all(check_client_version_compat, false, &[("2.0.0", "1.1.1")]);
    }
}