bonsai_rest_api_mock/
lib.rs

1// Copyright 2024 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod error;
16mod prover;
17mod routes;
18mod state;
19
20use std::sync::{Arc, RwLock};
21
22use anyhow::Context;
23use axum::{
24    extract::DefaultBodyLimit,
25    routing::{get, post, put},
26    Extension, Router,
27};
28use tokio::{net::TcpListener, sync::mpsc};
29use tower_http::trace::{DefaultOnRequest, TraceLayer};
30use tracing::{info, Level};
31
32use crate::{
33    prover::{Prover, ProverHandle},
34    routes::{
35        create_session, create_snark, get_image_upload, get_input_upload, get_receipt,
36        get_receipt_upload, put_image_upload, put_input_upload, put_receipt, session_status,
37        snark_status,
38    },
39    state::BonsaiState,
40};
41
42fn app(state: Arc<RwLock<BonsaiState>>, prover_handle: ProverHandle) -> Router {
43    Router::new()
44        .route("/images/upload/:image_id", get(get_image_upload))
45        .route("/images/:image_id", put(put_image_upload))
46        .route("/inputs/upload", get(get_input_upload))
47        .route("/inputs/:input_id", put(put_input_upload))
48        .route("/sessions/create", post(create_session))
49        .route("/sessions/status/:session_id", get(session_status))
50        .route("/snark/create", post(create_snark))
51        .route("/snark/status/:snark_id", get(snark_status))
52        .route("/receipts/:session_id", get(get_receipt))
53        .route("/receipts/:session_id", put(put_receipt))
54        .route("/receipts/upload", get(get_receipt_upload))
55        .layer(Extension(prover_handle))
56        .with_state(state)
57        .layer(DefaultBodyLimit::max(256 * 1024 * 1024))
58        .layer(TraceLayer::new_for_http().on_request(
59            DefaultOnRequest::new().level(Level::TRACE), // make on_request less visible
60        ))
61}
62
63/// Starts a mock of Bonsai on localhost at the given port. It exposes the same
64/// REST API of Bonsai alpha.
65///
66/// Note that this mock only performs execution, no proving.
67pub async fn serve(listener: TcpListener) -> anyhow::Result<()> {
68    let local_addr = listener.local_addr().unwrap();
69    let port = local_addr.port();
70    let local_url = format!("http://127.0.0.1:{port}");
71    let state = Arc::new(RwLock::new(BonsaiState::new(local_url)));
72
73    let (sender, receiver) = mpsc::channel(8);
74    let mut prover = Prover::new(receiver, Arc::clone(&state));
75
76    let prover_handle = ProverHandle { sender };
77
78    tokio::spawn(async move { prover.run().await });
79
80    info!("Local Bonsai started on {local_addr}");
81
82    axum::serve(listener, app(state, prover_handle))
83        .await
84        .context(format!("failed to serve Local Bonsai API on {local_addr}"))
85}
86
#[cfg(test)]
mod test {
    use std::time::Duration;

    use anyhow::{bail, Context, Result};
    use bonsai_sdk::alpha_async as bonsai_sdk;
    use risc0_zkvm::compute_image_id;
    use risc0_zkvm_methods::HELLO_COMMIT_ELF;
    use tokio::{net::TcpListener, task::JoinHandle};

    use crate::serve;

    /// Delay between session status polls. The mock only executes (no proving),
    /// so there is no need for the long interval used against the real service.
    const POLL_INTERVAL: Duration = Duration::from_secs(1);

    /// Binds an ephemeral localhost port, starts the mock on it, and returns the
    /// base URL together with the server task handle.
    ///
    /// No start-up delay is needed: the listener is already bound and listening,
    /// so client connections queue in the backlog until the server task accepts.
    async fn start_local_bonsai() -> (String, JoinHandle<Result<()>>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let handle = tokio::spawn(serve(listener));
        (format!("http://{local_addr}"), handle)
    }

    /// Runs a full upload -> create session -> poll -> download flow against the
    /// Bonsai API at `bonsai_api_url`, returning an error if any step fails or
    /// the session ends in a status other than `SUCCEEDED`.
    async fn run_bonsai(bonsai_api_url: String, bonsai_api_key: String, elf: &[u8]) -> Result<()> {
        let client =
            bonsai_sdk::get_client_from_parts(bonsai_api_url, bonsai_api_key, risc0_zkvm::VERSION)
                .await?;

        // Compute the image_id, then upload the ELF with the image_id as its key.
        // TODO: it would be nice if `bonsai_sdk::upload_img` only took the ELF
        // so that the image_id can be computed server-side.
        let image_id = hex::encode(compute_image_id(elf)?);
        bonsai_sdk::upload_img(client.clone(), image_id.clone(), elf.to_vec()).await?;

        // Upload an empty input.
        let input_id = bonsai_sdk::upload_input(client.clone(), vec![]).await?;

        // Upload a single placeholder (empty) receipt as the assumption list.
        let receipts_ids = vec![bonsai_sdk::upload_receipt(client.clone(), vec![]).await?];

        // Start a session and poll until it leaves the RUNNING state.
        let session =
            bonsai_sdk::create_session(client.clone(), image_id, input_id, receipts_ids).await?;
        let res = loop {
            let res = bonsai_sdk::session_status(client.clone(), session.clone()).await?;
            if res.status != "RUNNING" {
                break res;
            }
            // Must be the async sleep: `std::thread::sleep` would block the
            // single-threaded test runtime, stalling the mock server and prover.
            tokio::time::sleep(POLL_INTERVAL).await;
        };

        if res.status != "SUCCEEDED" {
            bail!("session finished with unexpected status: {}", res.status);
        }

        // Download the receipt, containing the output.
        let receipt_url = res
            .receipt_url
            .context("API error, missing receipt on completed session")?;
        bonsai_sdk::download(client, receipt_url).await?;

        Ok(())
    }

    #[tokio::test]
    async fn local_bonsai() {
        let (url, local_bonsai_handle) = start_local_bonsai().await;

        run_bonsai(url, "test_key".to_string(), HELLO_COMMIT_ELF)
            .await
            .unwrap();

        local_bonsai_handle.abort();
    }

    #[tokio::test]
    async fn local_bonsai_wrong_elf() {
        let (url, local_bonsai_handle) = start_local_bonsai().await;

        assert!(run_bonsai(url, "test_key".to_string(), b"wrong ELF")
            .await
            .is_err());

        local_bonsai_handle.abort();
    }
}