risc0_zkvm/host/api/
server.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    error::Error as StdError,
17    io::{BufReader, Error as IoError, ErrorKind as IoErrorKind, Read, Write},
18    path::{Path, PathBuf},
19};
20
21use anyhow::{anyhow, bail, Context, Result};
22use bytes::Bytes;
23use prost::Message;
24use risc0_zkp::core::digest::Digest;
25
26use super::{malformed_err, path_to_string, pb, ConnectionWrapper, Connector, TcpConnector};
27use crate::{
28    get_prover_server, get_version,
29    host::{
30        api::convert::keccak_input_to_bytes,
31        client::{
32            env::{CoprocessorCallback, ProveKeccakRequest, ProveZkrRequest},
33            slice_io::SliceIo,
34        },
35        server::{
36            exec::executor::ExecutorImpl, prove::keccak::prove_keccak, session::NullSegmentRef,
37        },
38    },
39    prove_registered_zkr,
40    recursion::identity_p254,
41    AssetRequest, Assumption, ExecutorEnv, InnerAssumptionReceipt, ProverOpts, Receipt,
42    ReceiptClaim, Segment, SegmentReceipt, SegmentRef, Session, SuccinctReceipt, TraceCallback,
43    TraceEvent, Unknown, VerifierContext,
44};
45
46/// A server implementation for handling requests by clients of the zkVM.
47pub struct Server {
48    connector: Box<dyn Connector>,
49}
50
51struct PosixIoProxy {
52    fd: u32,
53    conn: ConnectionWrapper,
54}
55
56impl PosixIoProxy {
57    fn new(fd: u32, conn: ConnectionWrapper) -> Self {
58        PosixIoProxy { fd, conn }
59    }
60}
61
62impl Read for PosixIoProxy {
63    fn read(&mut self, to_guest: &mut [u8]) -> std::io::Result<usize> {
64        let nread = to_guest.len().try_into().map_io_err()?;
65        let request = pb::api::ServerReply {
66            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
67                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
68                    kind: Some(pb::api::on_io_request::Kind::Posix(pb::api::PosixIo {
69                        fd: self.fd,
70                        cmd: Some(pb::api::PosixCmd {
71                            kind: Some(pb::api::posix_cmd::Kind::Read(nread)),
72                        }),
73                    })),
74                })),
75            })),
76        };
77
78        tracing::trace!("tx: {request:?}");
79        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
80        tracing::trace!("rx: {reply:?}");
81
82        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
83        match kind {
84            pb::api::on_io_reply::Kind::Ok(bytes) => {
85                let (head, _) = to_guest.split_at_mut(bytes.len());
86                head.copy_from_slice(&bytes);
87                Ok(bytes.len())
88            }
89            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
90        }
91    }
92}
93
94impl Write for PosixIoProxy {
95    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
96        let request = pb::api::ServerReply {
97            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
98                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
99                    kind: Some(pb::api::on_io_request::Kind::Posix(pb::api::PosixIo {
100                        fd: self.fd,
101                        cmd: Some(pb::api::PosixCmd {
102                            kind: Some(pb::api::posix_cmd::Kind::Write(buf.into())),
103                        }),
104                    })),
105                })),
106            })),
107        };
108
109        tracing::trace!("tx: {request:?}");
110        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
111        tracing::trace!("rx: {reply:?}");
112
113        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
114        match kind {
115            pb::api::on_io_reply::Kind::Ok(_) => Ok(buf.len()),
116            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
117        }
118    }
119
120    fn flush(&mut self) -> std::io::Result<()> {
121        Ok(())
122    }
123}
124
125#[derive(Clone)]
126struct SliceIoProxy {
127    conn: ConnectionWrapper,
128}
129
130impl SliceIoProxy {
131    fn new(conn: ConnectionWrapper) -> Self {
132        Self { conn }
133    }
134}
135
136impl SliceIo for SliceIoProxy {
137    fn handle_io(&mut self, syscall: &str, from_guest: Bytes) -> Result<Bytes> {
138        let request = pb::api::ServerReply {
139            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
140                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
141                    kind: Some(pb::api::on_io_request::Kind::Slice(pb::api::SliceIo {
142                        name: syscall.to_string(),
143                        from_guest: from_guest.into(),
144                    })),
145                })),
146            })),
147        };
148        tracing::trace!("tx: {request:?}");
149        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
150        tracing::trace!("rx: {reply:?}");
151
152        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
153        match kind {
154            pb::api::on_io_reply::Kind::Ok(buf) => Ok(buf.into()),
155            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
156        }
157    }
158}
159
160struct TraceProxy {
161    conn: ConnectionWrapper,
162}
163
164impl TraceProxy {
165    fn new(conn: ConnectionWrapper) -> Self {
166        Self { conn }
167    }
168}
169
170impl TraceCallback for TraceProxy {
171    fn trace_callback(&mut self, event: TraceEvent) -> Result<()> {
172        let Ok(event) = event.clone().try_into() else {
173            tracing::trace!("ignoring unknown event {event:?}");
174            return Ok(());
175        };
176
177        let request = pb::api::ServerReply {
178            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
179                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
180                    kind: Some(pb::api::on_io_request::Kind::Trace(event)),
181                })),
182            })),
183        };
184        tracing::trace!("tx: {request:?}");
185        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
186        tracing::trace!("rx: {reply:?}");
187
188        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
189        match kind {
190            pb::api::on_io_reply::Kind::Ok(_) => Ok(()),
191            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
192        }
193    }
194}
195
196struct CoprocessorProxy {
197    conn: ConnectionWrapper,
198}
199
200impl CoprocessorProxy {
201    fn new(conn: ConnectionWrapper) -> Self {
202        Self { conn }
203    }
204}
205
206impl CoprocessorCallback for CoprocessorProxy {
207    fn prove_zkr(&mut self, proof_request: ProveZkrRequest) -> Result<()> {
208        let request = pb::api::ServerReply {
209            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
210                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
211                    kind: Some(pb::api::on_io_request::Kind::Coprocessor(
212                        pb::api::CoprocessorRequest {
213                            kind: Some(pb::api::coprocessor_request::Kind::ProveZkr({
214                                pb::api::ProveZkrRequest {
215                                    claim_digest: Some(proof_request.claim_digest.into()),
216                                    control_id: Some(proof_request.control_id.into()),
217                                    input: proof_request.input,
218                                    receipt_out: None,
219                                }
220                            })),
221                        },
222                    )),
223                })),
224            })),
225        };
226        tracing::trace!("tx: {request:?}");
227        let reply: pb::api::OnIoReply = self.conn.send_recv(request).map_io_err()?;
228        tracing::trace!("rx: {reply:?}");
229
230        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
231        match kind {
232            pb::api::on_io_reply::Kind::Ok(_) => Ok(()),
233            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
234        }
235    }
236
237    fn prove_keccak(&mut self, proof_request: ProveKeccakRequest) -> Result<()> {
238        let input = keccak_input_to_bytes(&proof_request.input);
239        let request = pb::api::ServerReply {
240            kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
241                kind: Some(pb::api::client_callback::Kind::Io(pb::api::OnIoRequest {
242                    kind: Some(pb::api::on_io_request::Kind::Coprocessor(
243                        pb::api::CoprocessorRequest {
244                            kind: Some(pb::api::coprocessor_request::Kind::ProveKeccak({
245                                pb::api::ProveKeccakRequest {
246                                    claim_digest: Some(proof_request.claim_digest.into()),
247                                    po2: proof_request.po2 as u32,
248                                    control_root: Some(proof_request.control_root.into()),
249                                    input,
250                                    receipt_out: None,
251                                }
252                            })),
253                        },
254                    )),
255                })),
256            })),
257        };
258        tracing::trace!("tx: {request:?}");
259        self.conn.send(request)?;
260
261        let reply: pb::api::OnIoReply = self.conn.recv().map_io_err()?;
262        tracing::trace!("rx: {reply:?}");
263
264        let kind = reply.kind.ok_or("Malformed message").map_io_err()?;
265        match kind {
266            pb::api::on_io_reply::Kind::Ok(_) => Ok(()),
267            pb::api::on_io_reply::Kind::Error(err) => Err(err.into()),
268        }
269    }
270}
271
272impl Server {
273    /// Construct a new [Server] with the specified [Connector].
274    pub fn new(connector: Box<dyn Connector>) -> Self {
275        Self { connector }
276    }
277
278    /// Construct a new [Server] which will connect to the specified TCP/IP
279    /// address.
280    pub fn new_tcp<A: AsRef<str>>(addr: A) -> Self {
281        let connector = TcpConnector::new(addr.as_ref());
282        Self::new(Box::new(connector))
283    }
284
285    /// Start the [Server] and run until all requests are complete.
286    pub fn run(&self) -> Result<()> {
287        tracing::debug!("connect");
288        let mut conn = self.connector.connect()?;
289
290        let server_version = get_version().map_err(|err| anyhow!(err))?;
291
292        let request: pb::api::HelloRequest = conn.recv()?;
293        tracing::trace!("rx: {request:?}");
294
295        let client_version: semver::Version = request
296            .version
297            .ok_or_else(|| malformed_err("HelloRequest.version"))?
298            .try_into()
299            .map_err(|err: semver::Error| anyhow!(err))?;
300
301        #[cfg(not(feature = "r0vm-ver-compat"))]
302        let check_client_func = check_client_version;
303        #[cfg(feature = "r0vm-ver-compat")]
304        let check_client_func = check_client_version_compat;
305
306        if !check_client_func(&client_version, &server_version) {
307            let msg = format!(
308                "incompatible client version: {client_version}, server version: {server_version}"
309            );
310            tracing::debug!("{msg}");
311            bail!(msg);
312        }
313
314        let reply = pb::api::HelloReply {
315            kind: Some(pb::api::hello_reply::Kind::Ok(pb::api::HelloResult {
316                version: Some(server_version.into()),
317            })),
318        };
319        tracing::trace!("tx: {reply:?}");
320        let request: pb::api::ServerRequest = conn.send_recv(reply)?;
321        tracing::trace!("rx: {request:?}");
322
323        match request
324            .kind
325            .ok_or_else(|| malformed_err("ServerRequest.kind"))?
326        {
327            pb::api::server_request::Kind::Prove(request) => self.on_prove(conn, request),
328            pb::api::server_request::Kind::Execute(request) => self.on_execute(conn, request),
329            pb::api::server_request::Kind::ProveSegment(request) => {
330                self.on_prove_segment(conn, request)
331            }
332            pb::api::server_request::Kind::Lift(request) => self.on_lift(conn, request),
333            pb::api::server_request::Kind::Join(request) => self.on_join(conn, request),
334            pb::api::server_request::Kind::Resolve(request) => self.on_resolve(conn, request),
335            pb::api::server_request::Kind::IdentityP254(request) => {
336                self.on_identity_p254(conn, request)
337            }
338            pb::api::server_request::Kind::Compress(request) => self.on_compress(conn, request),
339            pb::api::server_request::Kind::Verify(request) => self.on_verify(conn, request),
340            pb::api::server_request::Kind::ProveZkr(request) => self.on_prove_zkr(conn, request),
341            pb::api::server_request::Kind::ProveKeccak(request) => {
342                self.on_prove_keccak(conn, request)
343            }
344            pb::api::server_request::Kind::Union(request) => self.on_union(conn, request),
345        }
346    }
347
348    fn on_execute(
349        &self,
350        mut conn: ConnectionWrapper,
351        request: pb::api::ExecuteRequest,
352    ) -> Result<()> {
353        fn inner(
354            conn: &mut ConnectionWrapper,
355            request: pb::api::ExecuteRequest,
356        ) -> Result<pb::api::ServerReply> {
357            let env_request = request
358                .env
359                .ok_or_else(|| malformed_err("ExecuteRequest.env"))?;
360            let env = build_env(conn, &env_request)?;
361
362            let binary = env_request
363                .binary
364                .ok_or_else(|| malformed_err("ExecuteRequest.binary"))?;
365
366            let segments_out = request
367                .segments_out
368                .ok_or_else(|| malformed_err("ExecuteRequest.segments_out"))?;
369            let bytes = binary.as_bytes()?;
370
371            let session = match AssetRequest::try_from(segments_out.clone())? {
372                #[cfg(feature = "redis")]
373                AssetRequest::Redis(params) => execute_redis(conn, env, bytes, params)?,
374                _ => execute_default(conn, env, bytes, &segments_out)?,
375            };
376
377            let receipt_claim = session.claim()?;
378            Ok(pb::api::ServerReply {
379                kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
380                    kind: Some(pb::api::client_callback::Kind::SessionDone(
381                        pb::api::OnSessionDone {
382                            session: Some(pb::api::SessionInfo {
383                                segments: session.segments.len().try_into()?,
384                                journal: session.journal.unwrap_or_default().bytes,
385                                exit_code: Some(session.exit_code.try_into()?),
386                                receipt_claim: Some(pb::api::Asset::from_bytes(
387                                    &pb::api::AssetRequest {
388                                        kind: Some(pb::api::asset_request::Kind::Inline(())),
389                                    },
390                                    pb::core::ReceiptClaim::try_from(receipt_claim)?
391                                        .encode_to_vec()
392                                        .into(),
393                                    "session_info.claim",
394                                )?),
395                            }),
396                        },
397                    )),
398                })),
399            })
400        }
401
402        let msg = inner(&mut conn, request).unwrap_or_else(|err| pb::api::ServerReply {
403            kind: Some(pb::api::server_reply::Kind::Error(pb::api::GenericError {
404                reason: err.to_string(),
405            })),
406        });
407
408        tracing::trace!("tx: {msg:?}");
409        conn.send(msg)
410    }
411
412    fn on_prove(&self, mut conn: ConnectionWrapper, request: pb::api::ProveRequest) -> Result<()> {
413        fn inner(
414            conn: &mut ConnectionWrapper,
415            request: pb::api::ProveRequest,
416        ) -> Result<pb::api::ServerReply> {
417            let env_request = request
418                .env
419                .ok_or_else(|| malformed_err("ProveRequest.env"))?;
420            let env = build_env(conn, &env_request)?;
421
422            let binary = env_request
423                .binary
424                .ok_or_else(|| malformed_err("ProveRequest.env_request.binary"))?;
425            let bytes = binary.as_bytes()?;
426
427            let opts: ProverOpts = request
428                .opts
429                .ok_or_else(|| malformed_err("ProveRequest.opts"))?
430                .try_into()?;
431            let prover = get_prover_server(&opts)?;
432            let ctx = VerifierContext::default();
433            let prove_info = prover.prove_with_ctx(env, &ctx, &bytes)?;
434
435            let prove_info: pb::core::ProveInfo = prove_info.try_into()?;
436            let prove_info_bytes = prove_info.encode_to_vec();
437            let asset = pb::api::Asset::from_bytes(
438                &request
439                    .receipt_out
440                    .ok_or_else(|| malformed_err("ProveRequest.receipt_out"))?,
441                prove_info_bytes.into(),
442                "prove_info.zkp",
443            )?;
444
445            Ok(pb::api::ServerReply {
446                kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
447                    kind: Some(pb::api::client_callback::Kind::ProveDone(
448                        pb::api::OnProveDone {
449                            prove_info: Some(asset),
450                        },
451                    )),
452                })),
453            })
454        }
455
456        let msg = inner(&mut conn, request).unwrap_or_else(|err| pb::api::ServerReply {
457            kind: Some(pb::api::server_reply::Kind::Error(pb::api::GenericError {
458                reason: err.to_string(),
459            })),
460        });
461
462        tracing::trace!("tx: {msg:?}");
463        conn.send(msg)
464    }
465
466    fn on_prove_segment(
467        &self,
468        mut conn: ConnectionWrapper,
469        request: pb::api::ProveSegmentRequest,
470    ) -> Result<()> {
471        fn inner(request: pb::api::ProveSegmentRequest) -> Result<pb::api::ProveSegmentReply> {
472            let opts: ProverOpts = request
473                .opts
474                .ok_or_else(|| malformed_err("ProveSegmentRequest.opts"))?
475                .try_into()?;
476            let segment_bytes = request
477                .segment
478                .ok_or_else(|| malformed_err("ProveSegmentRequest.segment"))?
479                .as_bytes()?;
480            let segment: Segment = bincode::deserialize(&segment_bytes)?;
481
482            let prover = get_prover_server(&opts)?;
483            let ctx = VerifierContext::default();
484            let receipt = prover.prove_segment(&ctx, &segment)?;
485
486            let receipt_pb: pb::core::SegmentReceipt = receipt.try_into()?;
487            let receipt_bytes = receipt_pb.encode_to_vec();
488            let asset = pb::api::Asset::from_bytes(
489                &request
490                    .receipt_out
491                    .ok_or_else(|| malformed_err("ProveSegmentRequest.receipt_out"))?,
492                receipt_bytes.into(),
493                "receipt.zkp",
494            )?;
495
496            Ok(pb::api::ProveSegmentReply {
497                kind: Some(pb::api::prove_segment_reply::Kind::Ok(
498                    pb::api::ProveSegmentResult {
499                        receipt: Some(asset),
500                    },
501                )),
502            })
503        }
504
505        let msg = inner(request).unwrap_or_else(|err| pb::api::ProveSegmentReply {
506            kind: Some(pb::api::prove_segment_reply::Kind::Error(
507                pb::api::GenericError {
508                    reason: err.to_string(),
509                },
510            )),
511        });
512
513        tracing::trace!("tx: {msg:?}");
514        conn.send(msg)
515    }
516
517    fn on_prove_zkr(
518        &self,
519        mut conn: ConnectionWrapper,
520        request: pb::api::ProveZkrRequest,
521    ) -> Result<()> {
522        fn inner(request: pb::api::ProveZkrRequest) -> Result<pb::api::ProveZkrReply> {
523            let control_id = request
524                .control_id
525                .ok_or_else(|| malformed_err("ProveZkrRequest.control_id"))?
526                .try_into()?;
527            let receipt = prove_registered_zkr(&control_id, vec![control_id], &request.input)?;
528
529            let receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
530            let receipt_bytes = receipt_pb.encode_to_vec();
531            let asset = pb::api::Asset::from_bytes(
532                &request
533                    .receipt_out
534                    .ok_or_else(|| malformed_err("ProveZkrRequest.receipt_out"))?,
535                receipt_bytes.into(),
536                "receipt.zkp",
537            )?;
538
539            Ok(pb::api::ProveZkrReply {
540                kind: Some(pb::api::prove_zkr_reply::Kind::Ok(
541                    pb::api::ProveZkrResult {
542                        receipt: Some(asset),
543                    },
544                )),
545            })
546        }
547
548        let msg = inner(request).unwrap_or_else(|err| pb::api::ProveZkrReply {
549            kind: Some(pb::api::prove_zkr_reply::Kind::Error(
550                pb::api::GenericError {
551                    reason: err.to_string(),
552                },
553            )),
554        });
555
556        tracing::trace!("tx: {msg:?}");
557        conn.send(msg)
558    }
559
560    fn on_prove_keccak(
561        &self,
562        mut conn: ConnectionWrapper,
563        request: pb::api::ProveKeccakRequest,
564    ) -> Result<()> {
565        fn inner(request_pb: pb::api::ProveKeccakRequest) -> Result<pb::api::ProveKeccakReply> {
566            let request: ProveKeccakRequest = request_pb.clone().try_into()?;
567            let receipt = prove_keccak(&request)?;
568
569            let receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
570            let receipt_bytes = receipt_pb.encode_to_vec();
571            let asset = pb::api::Asset::from_bytes(
572                &request_pb
573                    .receipt_out
574                    .ok_or_else(|| malformed_err("ProveKeccakRequest.receipt_out"))?,
575                receipt_bytes.into(),
576                "receipt.zkp",
577            )?;
578
579            Ok(pb::api::ProveKeccakReply {
580                kind: Some(pb::api::prove_keccak_reply::Kind::Ok(
581                    pb::api::ProveKeccakResult {
582                        receipt: Some(asset),
583                    },
584                )),
585            })
586        }
587
588        let msg = inner(request).unwrap_or_else(|err| pb::api::ProveKeccakReply {
589            kind: Some(pb::api::prove_keccak_reply::Kind::Error(
590                pb::api::GenericError {
591                    reason: err.to_string(),
592                },
593            )),
594        });
595
596        tracing::trace!("tx: {msg:?}");
597        conn.send(msg)
598    }
599
600    fn on_lift(&self, mut conn: ConnectionWrapper, request: pb::api::LiftRequest) -> Result<()> {
601        fn inner(request: pb::api::LiftRequest) -> Result<pb::api::LiftReply> {
602            let opts: ProverOpts = request
603                .opts
604                .ok_or_else(|| malformed_err("LiftRequest.opts"))?
605                .try_into()?;
606            let receipt_bytes = request
607                .receipt
608                .ok_or_else(|| malformed_err("LiftRequest.receipt"))?
609                .as_bytes()?;
610            let segment_receipt: SegmentReceipt = bincode::deserialize(&receipt_bytes)?;
611
612            let prover = get_prover_server(&opts)?;
613            let receipt = prover.lift(&segment_receipt)?;
614
615            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
616            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
617            let asset = pb::api::Asset::from_bytes(
618                &request
619                    .receipt_out
620                    .ok_or_else(|| malformed_err("LiftRequest.receipt_out"))?,
621                succinct_receipt_bytes.into(),
622                "receipt.zkp",
623            )?;
624
625            Ok(pb::api::LiftReply {
626                kind: Some(pb::api::lift_reply::Kind::Ok(pb::api::LiftResult {
627                    receipt: Some(asset),
628                })),
629            })
630        }
631
632        let msg = inner(request).unwrap_or_else(|err| pb::api::LiftReply {
633            kind: Some(pb::api::lift_reply::Kind::Error(pb::api::GenericError {
634                reason: err.to_string(),
635            })),
636        });
637
638        // tracing::trace!("tx: {msg:?}");
639        conn.send(msg)
640    }
641
642    fn on_join(&self, mut conn: ConnectionWrapper, request: pb::api::JoinRequest) -> Result<()> {
643        fn inner(request: pb::api::JoinRequest) -> Result<pb::api::JoinReply> {
644            let opts: ProverOpts = request
645                .opts
646                .ok_or_else(|| malformed_err("JoinRequest.opts"))?
647                .try_into()?;
648            let left_receipt_bytes = request
649                .left_receipt
650                .ok_or_else(|| malformed_err("JoinRequest.left_receipt"))?
651                .as_bytes()?;
652            let left_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
653                bincode::deserialize(&left_receipt_bytes)?;
654            let right_receipt_bytes = request
655                .right_receipt
656                .ok_or_else(|| malformed_err("JoinRequest.right_receipt"))?
657                .as_bytes()?;
658            let right_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
659                bincode::deserialize(&right_receipt_bytes)?;
660
661            let prover = get_prover_server(&opts)?;
662            let receipt = prover.join(&left_succinct_receipt, &right_succinct_receipt)?;
663
664            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
665            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
666            let asset = pb::api::Asset::from_bytes(
667                &request
668                    .receipt_out
669                    .ok_or_else(|| malformed_err("JoinRequest.receipt_out"))?,
670                succinct_receipt_bytes.into(),
671                "receipt.zkp",
672            )?;
673
674            Ok(pb::api::JoinReply {
675                kind: Some(pb::api::join_reply::Kind::Ok(pb::api::JoinResult {
676                    receipt: Some(asset),
677                })),
678            })
679        }
680
681        let msg = inner(request).unwrap_or_else(|err| pb::api::JoinReply {
682            kind: Some(pb::api::join_reply::Kind::Error(pb::api::GenericError {
683                reason: err.to_string(),
684            })),
685        });
686
687        // tracing::trace!("tx: {msg:?}");
688        conn.send(msg)
689    }
690
691    fn on_union(&self, mut conn: ConnectionWrapper, request: pb::api::UnionRequest) -> Result<()> {
692        fn inner(request: pb::api::UnionRequest) -> Result<pb::api::UnionReply> {
693            let opts: ProverOpts = request
694                .opts
695                .ok_or_else(|| malformed_err("UnionRequest.opts"))?
696                .try_into()?;
697            let left_receipt_bytes = request
698                .left_receipt
699                .ok_or_else(|| malformed_err("UnionRequest.left_receipt"))?
700                .as_bytes()?;
701            let left_succinct_receipt: SuccinctReceipt<Unknown> =
702                bincode::deserialize(&left_receipt_bytes)?;
703            let right_receipt_bytes = request
704                .right_receipt
705                .ok_or_else(|| malformed_err("UnionRequest.right_receipt"))?
706                .as_bytes()?;
707            let right_succinct_receipt: SuccinctReceipt<Unknown> =
708                bincode::deserialize(&right_receipt_bytes)?;
709
710            let prover = get_prover_server(&opts)?;
711            let receipt = prover.union(&left_succinct_receipt, &right_succinct_receipt)?;
712
713            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
714            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
715            let asset = pb::api::Asset::from_bytes(
716                &request
717                    .receipt_out
718                    .ok_or_else(|| malformed_err("UnionRequest.receipt_out"))?,
719                succinct_receipt_bytes.into(),
720                "receipt.zkp",
721            )?;
722
723            Ok(pb::api::UnionReply {
724                kind: Some(pb::api::union_reply::Kind::Ok(pb::api::UnionResult {
725                    receipt: Some(asset),
726                })),
727            })
728        }
729
730        let msg = inner(request).unwrap_or_else(|err| pb::api::UnionReply {
731            kind: Some(pb::api::union_reply::Kind::Error(pb::api::GenericError {
732                reason: err.to_string(),
733            })),
734        });
735
736        // tracing::trace!("tx: {msg:?}");
737        conn.send(msg)
738    }
739
740    fn on_resolve(
741        &self,
742        mut conn: ConnectionWrapper,
743        request: pb::api::ResolveRequest,
744    ) -> Result<()> {
745        fn inner(request: pb::api::ResolveRequest) -> Result<pb::api::ResolveReply> {
746            let opts: ProverOpts = request
747                .opts
748                .ok_or_else(|| malformed_err("ResolveRequest.opts"))?
749                .try_into()?;
750            let conditional_receipt_bytes = request
751                .conditional_receipt
752                .ok_or_else(|| malformed_err("ResolveRequest.conditional_receipt"))?
753                .as_bytes()?;
754            let conditional_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
755                bincode::deserialize(&conditional_receipt_bytes)?;
756            let assumption_receipt_bytes = request
757                .assumption_receipt
758                .ok_or_else(|| malformed_err("ResolveRequest.assumption_receipt"))?
759                .as_bytes()?;
760            let assumption_succinct_receipt: SuccinctReceipt<ReceiptClaim> =
761                bincode::deserialize(&assumption_receipt_bytes)?;
762
763            let prover = get_prover_server(&opts)?;
764            let receipt = prover.resolve(
765                &conditional_succinct_receipt,
766                &assumption_succinct_receipt.into_unknown(),
767            )?;
768
769            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
770            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
771            let asset = pb::api::Asset::from_bytes(
772                &request
773                    .receipt_out
774                    .ok_or_else(|| malformed_err("ResolveRequest.receipt_out"))?,
775                succinct_receipt_bytes.into(),
776                "receipt.zkp",
777            )?;
778
779            Ok(pb::api::ResolveReply {
780                kind: Some(pb::api::resolve_reply::Kind::Ok(pb::api::ResolveResult {
781                    receipt: Some(asset),
782                })),
783            })
784        }
785
786        let msg = inner(request).unwrap_or_else(|err| pb::api::ResolveReply {
787            kind: Some(pb::api::resolve_reply::Kind::Error(pb::api::GenericError {
788                reason: err.to_string(),
789            })),
790        });
791
792        // tracing::trace!("tx: {msg:?}");
793        conn.send(msg)
794    }
795
796    fn on_identity_p254(
797        &self,
798        mut conn: ConnectionWrapper,
799        request: pb::api::IdentityP254Request,
800    ) -> Result<()> {
801        fn inner(request: pb::api::IdentityP254Request) -> Result<pb::api::IdentityP254Reply> {
802            let receipt_bytes = request
803                .receipt
804                .ok_or_else(|| malformed_err("IdentityP254Request.receipt"))?
805                .as_bytes()?;
806            let succinct_receipt: SuccinctReceipt<ReceiptClaim> =
807                bincode::deserialize(&receipt_bytes)?;
808
809            let receipt = identity_p254(&succinct_receipt)?;
810            let succinct_receipt_pb: pb::core::SuccinctReceipt = receipt.try_into()?;
811            let succinct_receipt_bytes = succinct_receipt_pb.encode_to_vec();
812            let asset = pb::api::Asset::from_bytes(
813                &request
814                    .receipt_out
815                    .ok_or_else(|| malformed_err("IdentityP254Request.receipt_out"))?,
816                succinct_receipt_bytes.into(),
817                "receipt.zkp",
818            )?;
819
820            Ok(pb::api::IdentityP254Reply {
821                kind: Some(pb::api::identity_p254_reply::Kind::Ok(
822                    pb::api::IdentityP254Result {
823                        receipt: Some(asset),
824                    },
825                )),
826            })
827        }
828
829        let msg = inner(request).unwrap_or_else(|err| pb::api::IdentityP254Reply {
830            kind: Some(pb::api::identity_p254_reply::Kind::Error(
831                pb::api::GenericError {
832                    reason: err.to_string(),
833                },
834            )),
835        });
836
837        // tracing::trace!("tx: {msg:?}");
838        conn.send(msg)
839    }
840
841    fn on_compress(
842        &self,
843        mut conn: ConnectionWrapper,
844        request: pb::api::CompressRequest,
845    ) -> Result<()> {
846        fn inner(request: pb::api::CompressRequest) -> Result<pb::api::CompressReply> {
847            let opts: ProverOpts = request
848                .opts
849                .ok_or_else(|| malformed_err("CompressRequest.opts"))?
850                .try_into()?;
851            let receipt_bytes = request
852                .receipt
853                .ok_or_else(|| malformed_err("CompressRequest.receipt"))?
854                .as_bytes()?;
855            let receipt: Receipt = bincode::deserialize(&receipt_bytes)?;
856
857            let prover = get_prover_server(&opts)?;
858            let receipt = prover.compress(&opts, &receipt)?;
859
860            let receipt_pb: pb::core::Receipt = receipt.try_into()?;
861            let receipt_bytes = receipt_pb.encode_to_vec();
862            let asset = pb::api::Asset::from_bytes(
863                &request
864                    .receipt_out
865                    .ok_or_else(|| malformed_err("CompressRequest.receipt_out"))?,
866                receipt_bytes.into(),
867                "receipt.zkp",
868            )?;
869
870            Ok(pb::api::CompressReply {
871                kind: Some(pb::api::compress_reply::Kind::Ok(pb::api::CompressResult {
872                    receipt: Some(asset),
873                })),
874            })
875        }
876
877        let msg = inner(request).unwrap_or_else(|err| pb::api::CompressReply {
878            kind: Some(pb::api::compress_reply::Kind::Error(
879                pb::api::GenericError {
880                    reason: err.to_string(),
881                },
882            )),
883        });
884
885        // tracing::trace!("tx: {msg:?}");
886        conn.send(msg)
887    }
888
889    fn on_verify(
890        &self,
891        mut conn: ConnectionWrapper,
892        request: pb::api::VerifyRequest,
893    ) -> Result<()> {
894        fn inner(request: pb::api::VerifyRequest) -> Result<()> {
895            let receipt_bytes = request
896                .receipt
897                .ok_or_else(|| malformed_err("VerifyRequest.receipt"))?
898                .as_bytes()?;
899            let receipt: Receipt =
900                bincode::deserialize(&receipt_bytes).context("deserialize receipt")?;
901            let image_id: Digest = request
902                .image_id
903                .ok_or_else(|| malformed_err("VerifyRequest.image_id"))?
904                .try_into()?;
905            receipt
906                .verify(image_id)
907                .map_err(|err| anyhow!("verify failed: {err}"))
908        }
909
910        let msg: pb::api::GenericReply = inner(request).into();
911        // tracing::trace!("tx: {msg:?}");
912        conn.send(msg)
913    }
914}
915
916fn build_env<'a>(
917    conn: &ConnectionWrapper,
918    request: &pb::api::ExecutorEnv,
919) -> Result<ExecutorEnv<'a>> {
920    let mut env_builder = ExecutorEnv::builder();
921    env_builder.env_vars(request.env_vars.clone());
922    env_builder.args(&request.args);
923    for fd in request.read_fds.iter() {
924        let proxy = PosixIoProxy::new(*fd, conn.clone());
925        let reader = BufReader::new(proxy);
926        env_builder.read_fd(*fd, reader);
927    }
928    for fd in request.write_fds.iter() {
929        let proxy = PosixIoProxy::new(*fd, conn.clone());
930        env_builder.write_fd(*fd, proxy);
931    }
932    let proxy = SliceIoProxy::new(conn.clone());
933    for name in request.slice_ios.iter() {
934        env_builder.slice_io(name, proxy.clone());
935    }
936    if let Some(segment_limit_po2) = request.segment_limit_po2 {
937        env_builder.segment_limit_po2(segment_limit_po2);
938    }
939    if let Some(keccak_max_po2) = request.keccak_max_po2 {
940        env_builder.keccak_max_po2(keccak_max_po2)?;
941    }
942    env_builder.session_limit(request.session_limit);
943    if request.trace_events.is_some() {
944        let proxy = TraceProxy::new(conn.clone());
945        env_builder.trace_callback(proxy);
946    }
947    if !request.pprof_out.is_empty() {
948        env_builder.enable_profiler(Path::new(&request.pprof_out));
949    }
950    if !request.segment_path.is_empty() {
951        env_builder.segment_path(Path::new(&request.segment_path));
952    }
953    if request.coprocessor {
954        let proxy = CoprocessorProxy::new(conn.clone());
955        env_builder.coprocessor_callback(proxy);
956    }
957
958    for assumption in request.assumptions.iter() {
959        match assumption
960            .kind
961            .as_ref()
962            .ok_or_else(|| malformed_err("Assumption.kind"))?
963        {
964            pb::api::assumption_receipt::Kind::Proven(asset) => {
965                let receipt: InnerAssumptionReceipt =
966                    pb::core::InnerReceipt::decode(asset.as_bytes()?)?.try_into()?;
967                env_builder.add_assumption(receipt)
968            }
969            pb::api::assumption_receipt::Kind::Unresolved(asset) => {
970                let assumption: Assumption =
971                    pb::core::Assumption::decode(asset.as_bytes()?)?.try_into()?;
972                env_builder.add_assumption(assumption)
973            }
974        };
975    }
976    env_builder.build()
977}
978
/// Extension trait that turns a `Result` with any boxable error into a
/// `Result` carrying an [`IoError`] of kind [`IoErrorKind::Other`].
trait IoOtherError<T> {
    /// Passes `Ok` through unchanged and wraps an `Err` as `IoErrorKind::Other`.
    fn map_io_err(self) -> Result<T, IoError>;
}

impl<T, E> IoOtherError<T> for Result<T, E>
where
    E: Into<Box<dyn StdError + Send + Sync>>,
{
    fn map_io_err(self) -> Result<T, IoError> {
        match self {
            Ok(value) => Ok(value),
            Err(source) => Err(IoError::new(IoErrorKind::Other, source)),
        }
    }
}
988
989impl From<pb::api::GenericError> for IoError {
990    fn from(err: pb::api::GenericError) -> Self {
991        IoError::new(IoErrorKind::Other, err.reason)
992    }
993}
994
995impl pb::api::Asset {
996    pub fn from_bytes<P: AsRef<Path>>(
997        request: &pb::api::AssetRequest,
998        bytes: Bytes,
999        path: P,
1000    ) -> Result<Self> {
1001        match request
1002            .kind
1003            .as_ref()
1004            .ok_or_else(|| malformed_err("AssetRequest.kind"))?
1005        {
1006            pb::api::asset_request::Kind::Inline(()) => Ok(Self {
1007                kind: Some(pb::api::asset::Kind::Inline(bytes.into())),
1008            }),
1009            pb::api::asset_request::Kind::Path(base_path) => {
1010                let base_path = PathBuf::from(base_path);
1011                let path = base_path.join(path);
1012                std::fs::write(&path, bytes)?;
1013                Ok(Self {
1014                    kind: Some(pb::api::asset::Kind::Path(path_to_string(path)?)),
1015                })
1016            }
1017            pb::api::asset_request::Kind::Redis(_) => {
1018                tracing::error!("It's likely that r0vm is not installed with the redis feature");
1019                bail!("from_bytes not supported for redis")
1020            }
1021        }
1022    }
1023}
1024
1025#[allow(dead_code)]
1026fn check_client_version(client: &semver::Version, server: &semver::Version) -> bool {
1027    if server.pre.is_empty() {
1028        let comparator = semver::Comparator {
1029            op: semver::Op::GreaterEq,
1030            major: server.major,
1031            minor: Some(server.minor),
1032            patch: None,
1033            pre: semver::Prerelease::EMPTY,
1034        };
1035        comparator.matches(client)
1036    } else {
1037        client == server
1038    }
1039}
1040
1041#[allow(dead_code)]
1042fn check_client_version_compat(client: &semver::Version, server: &semver::Version) -> bool {
1043    client.major == server.major
1044}
1045
/// Executes the guest ELF in `bytes` and publishes every produced segment to
/// Redis instead of the local filesystem.
///
/// Segments are handed from the executor callback to a dedicated writer thread
/// over a bounded channel. Its capacity is read from `RISC0_REDIS_CHANNEL_SIZE`
/// and defaults to 100 when the variable is unset or unparsable. For each
/// segment the writer:
/// 1. serializes it with `bincode`;
/// 2. stores it at key `"{params.key}:{segment.index}"` with an `EX` expiry of
///    `params.ttl`, reconnecting and retrying once if the first write fails;
/// 3. sends a `SegmentDone` callback to the client whose asset names that key.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the executor fails;
/// * the writer thread panics;
/// * the writer thread fails to store or announce any segment, including
///   segments that were still queued in the channel when execution finished.
#[cfg(feature = "redis")]
fn execute_redis(
    conn: &mut ConnectionWrapper,
    env: ExecutorEnv,
    bytes: Bytes,
    params: super::RedisParams,
) -> Result<Session> {
    use redis::{Client, Commands, ConnectionLike, SetExpiry, SetOptions};
    use std::{
        sync::{
            mpsc::{sync_channel, Receiver},
            Arc, Mutex,
        },
        thread::{spawn, JoinHandle},
    };

    const DEFAULT_CHANNEL_SIZE: usize = 100;

    let channel_size = std::env::var("RISC0_REDIS_CHANNEL_SIZE")
        .ok()
        .and_then(|val_str| val_str.parse::<usize>().ok())
        .unwrap_or(DEFAULT_CHANNEL_SIZE);
    let (sender, receiver) = sync_channel::<(String, Segment)>(channel_size);
    let opts = SetOptions::default().with_expiration(SetExpiry::EX(params.ttl));

    // The first error hit by the writer thread. The executor callback reads it
    // when the channel closes, and it is checked again after the thread is
    // joined.
    let redis_err = Arc::new(Mutex::new(None));
    let redis_err_clone = redis_err.clone();

    let conn = conn.clone();
    let join_handle: JoinHandle<()> = spawn(move || {
        /// Drains `receiver`, storing each segment in Redis and notifying the
        /// client. Stops at the first error or once every sender is dropped.
        fn inner(
            redis_url: String,
            receiver: &Receiver<(String, Segment)>,
            opts: SetOptions,
            mut conn: ConnectionWrapper,
        ) -> Result<()> {
            let client = Client::open(redis_url).context("Failed to open Redis connection")?;
            let mut connection = client
                .get_connection()
                .context("Failed to get redis connection")?;
            while let Ok((segment_key, segment)) = receiver.recv() {
                if !connection.is_open() {
                    connection = client
                        .get_connection()
                        .context("Failed to get redis connection")?;
                }
                let segment_bytes =
                    bincode::serialize(&segment).context("Failed to serialize segment")?;
                // Pass borrows so the retry path does not require a full copy
                // of the (potentially large) serialized segment.
                match connection.set_options(&segment_key, &segment_bytes, opts) {
                    Ok(()) => (),
                    Err(err) => {
                        tracing::warn!(
                            "Failed to set redis key with TTL, trying again. Error: {err}"
                        );
                        connection = client
                            .get_connection()
                            .context("Failed to get redis connection")?;
                        let _: () = connection
                            .set_options(&segment_key, &segment_bytes, opts)
                            .context("Failed to set redis key with TTL again")?;
                    }
                };
                let asset = pb::api::Asset {
                    kind: Some(pb::api::asset::Kind::Redis(segment_key)),
                };
                send_segment_done_msg(&mut conn, segment, Some(asset))
                    .context("Failed to send segment_done msg")?;
            }
            Ok(())
        }

        if let Err(err) = inner(params.url, &receiver, opts, conn) {
            *redis_err_clone.lock().unwrap() = Some(err);
        }
    });

    let callback = |segment: Segment| -> Result<Box<dyn SegmentRef>> {
        let segment_key = format!("{}:{}", params.key, segment.index);
        if let Err(send_err) = sender.send((segment_key, segment)) {
            // The channel only closes when the writer thread has exited,
            // which happens after it has recorded its error.
            let mut redis_err_opt = redis_err.lock().unwrap();
            let redis_err_inner = redis_err_opt.take();
            return Err(match redis_err_inner {
                Some(redis_thread_err) => {
                    tracing::error!(
                        "Redis err: {redis_thread_err} root: {:?}",
                        redis_thread_err.root_cause()
                    );
                    anyhow!(redis_thread_err)
                }
                None => send_err.into(),
            });
        }
        Ok(Box::new(NullSegmentRef))
    };

    let session = ExecutorImpl::from_elf(env, &bytes)?.run_with_callback(callback);

    // Closing the channel lets the writer drain what is queued and then exit.
    drop(sender);

    join_handle
        .join()
        .map_err(|err| anyhow!("redis task join failed: {err:?}"))?;

    // The writer can fail on segments that were still queued when execution
    // finished. The callback never observes those failures, so check here
    // instead of reporting success for segments that were never stored or
    // announced to the client.
    let writer_err = redis_err.lock().unwrap().take();
    if let Some(err) = writer_err {
        tracing::error!("Redis err: {err} root: {:?}", err.root_cause());
        return Err(err);
    }

    session
}
1149
1150fn execute_default(
1151    conn: &mut ConnectionWrapper,
1152    env: ExecutorEnv,
1153    bytes: Bytes,
1154    segments_out: &pb::api::AssetRequest,
1155) -> Result<Session> {
1156    let callback = |segment: Segment| -> Result<Box<dyn SegmentRef>> {
1157        let segment_bytes = bincode::serialize(&segment)?;
1158        let asset = pb::api::Asset::from_bytes(
1159            segments_out,
1160            segment_bytes.into(),
1161            format!("segment-{}", segment.index),
1162        )?;
1163        send_segment_done_msg(conn, segment, Some(asset))?;
1164        Ok(Box::new(NullSegmentRef))
1165    };
1166
1167    ExecutorImpl::from_elf(env, &bytes)?.run_with_callback(callback)
1168}
1169
1170fn send_segment_done_msg(
1171    conn: &mut ConnectionWrapper,
1172    segment: Segment,
1173    some_asset: Option<pb::api::Asset>,
1174) -> Result<()> {
1175    let segment = Some(pb::api::SegmentInfo {
1176        index: segment.index,
1177        po2: segment.po2() as u32,
1178        cycles: segment.user_cycles(),
1179        segment: some_asset,
1180    });
1181
1182    let msg = pb::api::ServerReply {
1183        kind: Some(pb::api::server_reply::Kind::Ok(pb::api::ClientCallback {
1184            kind: Some(pb::api::client_callback::Kind::SegmentDone(
1185                pb::api::OnSegmentDone { segment },
1186            )),
1187        })),
1188    };
1189
1190    tracing::trace!("tx: {msg:?}");
1191    let reply: pb::api::GenericReply = conn.send_recv(msg)?;
1192    tracing::trace!("rx: {reply:?}");
1193
1194    let kind = reply
1195        .kind
1196        .ok_or_else(|| malformed_err("GenericReply.kind"))?;
1197    if let pb::api::generic_reply::Kind::Error(err) = kind {
1198        bail!(err)
1199    }
1200    Ok(())
1201}
1202
#[cfg(test)]
mod tests {
    use semver::Version;

    use super::{check_client_version, check_client_version_compat};

    /// Parses `client` and `server` as semver versions and evaluates `check`
    /// on them. Panics if either string is not valid semver.
    fn eval_pair(check: fn(&Version, &Version) -> bool, client: &str, server: &str) -> bool {
        let client = Version::parse(client).unwrap();
        let server = Version::parse(server).unwrap();
        check(&client, &server)
    }

    #[test]
    fn check_version() {
        // (client, server) pairs that must be accepted.
        const ACCEPTED: &[(&str, &str)] = &[
            ("0.18.0", "0.18.0"),
            ("0.18.1", "0.18.0"),
            ("0.18.0", "0.18.1"),
            ("0.19.0", "0.18.0"),
            ("1.0.0", "0.18.0"),
            ("1.1.0", "1.0.0"),
            ("0.19.0-alpha.1", "0.19.0-alpha.1"),
        ];
        // (client, server) pairs that must be rejected.
        const REJECTED: &[(&str, &str)] = &[
            ("0.19.0-alpha.1", "0.19.0-alpha.2"),
            ("0.18.0", "0.19.0"),
            ("0.18.0", "1.0.0"),
        ];

        for &(client, server) in ACCEPTED {
            assert!(
                eval_pair(check_client_version, client, server),
                "client {client} should be accepted by server {server}"
            );
        }
        for &(client, server) in REJECTED {
            assert!(
                !eval_pair(check_client_version, client, server),
                "client {client} should be rejected by server {server}"
            );
        }
    }

    #[test]
    fn check_version_compat() {
        // (client, server) pairs that must be considered compatible.
        const ACCEPTED: &[(&str, &str)] = &[
            ("1.1.0", "1.1.0"),
            ("1.1.1", "1.1.1"),
            ("1.2.0", "1.1.1"),
            ("1.2.0-rc.1", "1.1.1"),
        ];
        // (client, server) pairs that must be considered incompatible.
        const REJECTED: &[(&str, &str)] = &[("2.0.0", "1.1.1")];

        for &(client, server) in ACCEPTED {
            assert!(
                eval_pair(check_client_version_compat, client, server),
                "client {client} should be compatible with server {server}"
            );
        }
        for &(client, server) in REJECTED {
            assert!(
                !eval_pair(check_client_version_compat, client, server),
                "client {client} should be incompatible with server {server}"
            );
        }
    }
}