bonsai_rest_api_mock/
lib.rs1mod error;
16mod prover;
17mod routes;
18mod state;
19
20use std::sync::{Arc, RwLock};
21
22use anyhow::Context;
23use axum::{
24 extract::DefaultBodyLimit,
25 routing::{get, post, put},
26 Extension, Router,
27};
28use tokio::{net::TcpListener, sync::mpsc};
29use tower_http::trace::{DefaultOnRequest, TraceLayer};
30use tracing::{info, Level};
31
32use crate::{
33 prover::{Prover, ProverHandle},
34 routes::{
35 create_session, create_snark, get_image_upload, get_input_upload, get_receipt,
36 get_receipt_upload, put_image_upload, put_input_upload, put_receipt, session_status,
37 snark_status,
38 },
39 state::BonsaiState,
40};
41
42fn app(state: Arc<RwLock<BonsaiState>>, prover_handle: ProverHandle) -> Router {
43 Router::new()
44 .route("/images/upload/:image_id", get(get_image_upload))
45 .route("/images/:image_id", put(put_image_upload))
46 .route("/inputs/upload", get(get_input_upload))
47 .route("/inputs/:input_id", put(put_input_upload))
48 .route("/sessions/create", post(create_session))
49 .route("/sessions/status/:session_id", get(session_status))
50 .route("/snark/create", post(create_snark))
51 .route("/snark/status/:snark_id", get(snark_status))
52 .route("/receipts/:session_id", get(get_receipt))
53 .route("/receipts/:session_id", put(put_receipt))
54 .route("/receipts/upload", get(get_receipt_upload))
55 .layer(Extension(prover_handle))
56 .with_state(state)
57 .layer(DefaultBodyLimit::max(256 * 1024 * 1024))
58 .layer(TraceLayer::new_for_http().on_request(
59 DefaultOnRequest::new().level(Level::TRACE), ))
61}
62
63pub async fn serve(listener: TcpListener) -> anyhow::Result<()> {
68 let local_addr = listener.local_addr().unwrap();
69 let port = local_addr.port();
70 let local_url = format!("http://127.0.0.1:{port}");
71 let state = Arc::new(RwLock::new(BonsaiState::new(local_url)));
72
73 let (sender, receiver) = mpsc::channel(8);
74 let mut prover = Prover::new(receiver, Arc::clone(&state));
75
76 let prover_handle = ProverHandle { sender };
77
78 tokio::spawn(async move { prover.run().await });
79
80 info!("Local Bonsai started on {local_addr}");
81
82 axum::serve(listener, app(state, prover_handle))
83 .await
84 .context(format!("failed to serve Local Bonsai API on {local_addr}"))
85}
86
#[cfg(test)]
mod test {
    use std::time::Duration;

    use anyhow::{bail, Context, Result};
    use bonsai_sdk::alpha_async as bonsai_sdk;
    use risc0_zkvm::compute_image_id;
    use risc0_zkvm_methods::HELLO_COMMIT_ELF;
    use tokio::{net::TcpListener, task::JoinHandle};

    use crate::serve;

    /// Delay between two session status polls.
    const POLL_INTERVAL: Duration = Duration::from_secs(15);

    /// Starts the mock server on an ephemeral loopback port.
    ///
    /// Returns the server's base URL and the handle of the background server task.
    async fn spawn_local_bonsai() -> (String, JoinHandle<Result<()>>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        // The socket is already bound and listening here, so clients can connect right away.
        // Connections queue until the server task accepts them, so no startup sleep is needed.
        let handle = tokio::spawn(serve(listener));
        (format!("http://{local_addr}"), handle)
    }

    /// Runs a full Bonsai workflow against the server at `bonsai_api_url`.
    ///
    /// The steps are: upload `elf`, upload an empty input and an empty receipt, create a
    /// session, poll it until it is no longer `RUNNING`, then download the receipt.
    ///
    /// # Errors
    ///
    /// Returns an error if any API call fails, if the session ends with a status other than
    /// `SUCCEEDED`, or if a succeeded session has no receipt URL.
    async fn run_bonsai(bonsai_api_url: String, bonsai_api_key: String, elf: &[u8]) -> Result<()> {
        let client =
            bonsai_sdk::get_client_from_parts(bonsai_api_url, bonsai_api_key, risc0_zkvm::VERSION)
                .await?;

        // Images are uploaded under their hex-encoded image ID.
        let image_id = hex::encode(compute_image_id(elf)?);
        bonsai_sdk::upload_img(client.clone(), image_id.clone(), elf.to_vec()).await?;

        let input_id = bonsai_sdk::upload_input(client.clone(), vec![]).await?;
        let receipts_ids = vec![bonsai_sdk::upload_receipt(client.clone(), vec![]).await?];

        let session =
            bonsai_sdk::create_session(client.clone(), image_id, input_id, receipts_ids).await?;

        loop {
            let res = bonsai_sdk::session_status(client.clone(), session.clone()).await?;
            if res.status == "RUNNING" {
                // Must not block the thread: `#[tokio::test]` uses a current-thread runtime.
                // A blocking sleep would stall the server and prover tasks meanwhile.
                tokio::time::sleep(POLL_INTERVAL).await;
                continue;
            }
            if res.status != "SUCCEEDED" {
                bail!("session did not succeed, final status: {}", res.status);
            }
            let receipt_url = res
                .receipt_url
                .context("API error, missing receipt on completed session")?;
            bonsai_sdk::download(client.clone(), receipt_url).await?;
            return Ok(());
        }
    }

    #[tokio::test]
    async fn local_bonsai() {
        let (url, server) = spawn_local_bonsai().await;

        run_bonsai(url, "test_key".to_string(), HELLO_COMMIT_ELF)
            .await
            .unwrap();

        server.abort();
    }

    #[tokio::test]
    async fn local_bonsai_wrong_elf() {
        let (url, server) = spawn_local_bonsai().await;

        assert!(run_bonsai(url, "test_key".to_string(), b"wrong ELF")
            .await
            .is_err());

        server.abort();
    }
}