1use std::{
2 collections::HashMap,
3 error::Error as StdError,
4 future::Future,
5 process::{Command, Stdio},
6 sync::{
7 atomic::{AtomicBool, Ordering},
8 Arc, Mutex,
9 },
10 time::{Duration, Instant},
11};
12
13use crate::proto::api::ProverServiceClient;
14use async_trait::async_trait;
15use proto::api::ReadyRequest;
16use reqwest::{Request, Response};
17use serde::{Deserialize, Serialize};
18use monerochan_core_machine::{io::MONEROCHANStdin, reduce::MONEROCHANReduceProof, utils::MONEROCHANCoreProverError};
19use monerochan_prover::{
20 InnerSC, OuterSC, MONEROCHANCoreProof, MONEROCHANProvingKey, MONEROCHANRecursionProverError, MONEROCHANVerifyingKey,
21};
22use std::sync::LazyLock;
23use tokio::task::block_in_place;
24use twirp::{
25 async_trait,
26 reqwest::{self},
27 url::Url,
28 Client, ClientError, Middleware, Next,
29};
30
31#[rustfmt::skip]
32pub mod proto {
33 pub mod api;
34}
35
/// Process-wide registry of locally started containers.
///
/// Maps each container name to a shared "already cleaned up" flag. The Ctrl+C
/// handler walks this map to remove every container that has not been cleaned
/// up yet; the same flag is shared with the owning prover so the container is
/// not removed twice.
static MOONGATE_CONTAINERS: LazyLock<Mutex<HashMap<String, Arc<AtomicBool>>>> =
    LazyLock::new(|| {
        let registry: HashMap<String, Arc<AtomicBool>> = HashMap::new();
        Mutex::new(registry)
    });
38
39pub struct MONEROCHANCudaProver {
45 client: Client,
47 managed_container: Option<CudaProverContainer>,
49}
50
/// Bookkeeping for a Docker container started by this process.
pub struct CudaProverContainer {
    /// Name passed to `docker run --name`, later used for `docker rm -f`.
    name: String,
    /// Shared flag that is set once the container has been removed, so that
    /// the drop path and the Ctrl+C handler do not both remove it.
    cleaned_up: Arc<AtomicBool>,
}
57
58#[derive(Serialize, Deserialize)]
62pub struct SetupRequestPayload {
63 pub elf: Vec<u8>,
64}
65
66#[derive(Serialize, Deserialize)]
70pub struct SetupResponsePayload {
71 pub pk: MONEROCHANProvingKey,
72 pub vk: MONEROCHANVerifyingKey,
73}
74
75#[derive(Serialize, Deserialize)]
79pub struct ProveCoreRequestPayload {
80 pub stdin: MONEROCHANStdin,
82}
83
84#[derive(Serialize, Deserialize)]
90pub struct StatelessProveCoreRequestPayload {
91 pub stdin: MONEROCHANStdin,
93 pub pk: MONEROCHANProvingKey,
95}
96
97#[derive(Serialize, Deserialize)]
101pub struct CompressRequestPayload {
102 pub vk: MONEROCHANVerifyingKey,
104 pub proof: MONEROCHANCoreProof,
106 pub deferred_proofs: Vec<MONEROCHANReduceProof<InnerSC>>,
108}
109
110#[derive(Serialize, Deserialize)]
114pub struct ShrinkRequestPayload {
115 pub reduced_proof: MONEROCHANReduceProof<InnerSC>,
116}
117
118#[derive(Serialize, Deserialize)]
122pub struct WrapRequestPayload {
123 pub reduced_proof: MONEROCHANReduceProof<InnerSC>,
124}
125
/// Where the proving server lives.
#[derive(Debug)]
pub enum MoongateServer {
    /// Connect to a server that is already running at `endpoint`.
    External {
        /// Base URL of the server.
        endpoint: String,
    },
    /// Start a Docker container on this machine and connect to it.
    Local {
        /// GPU index handed to `docker run --gpus device=<index>`;
        /// `None` exposes all GPUs.
        visible_device_index: Option<u64>,
        /// Host port mapped to the container; `None` selects port 3000.
        port: Option<u64>,
    },
}
132
133impl Default for MoongateServer {
134 fn default() -> Self {
135 Self::Local { visible_device_index: None, port: None }
136 }
137}
138
139impl MONEROCHANCudaProver {
140 pub fn new(moongate_server: MoongateServer) -> Result<Self, Box<dyn StdError>> {
143 let reqwest_middlewares = vec![Box::new(LoggingMiddleware) as Box<dyn Middleware>];
144
145 let prover = match moongate_server {
146 MoongateServer::External { endpoint } => {
147 let client = Client::new(
148 Url::parse(&endpoint).expect("failed to parse url"),
149 reqwest::Client::new(),
150 reqwest_middlewares,
151 )
152 .expect("failed to create client");
153
154 MONEROCHANCudaProver { client, managed_container: None }
155 }
156 MoongateServer::Local { visible_device_index, port } => {
157 Self::start_moongate_server(reqwest_middlewares, visible_device_index, port)?
158 }
159 };
160
161 let timeout = Duration::from_secs(300);
162 let start_time = Instant::now();
163
164 block_on(async {
165 tracing::info!("waiting for proving server to be ready");
166 loop {
167 if start_time.elapsed() > timeout {
168 return Err("Timeout: proving server did not become ready within 60 seconds. Please check your Docker container and network settings.".to_string());
169 }
170
171 let request = ReadyRequest {};
172 match prover.client.ready(request).await {
173 Ok(response) if response.ready => {
174 tracing::info!("proving server is ready");
175 break;
176 }
177 Ok(_) => {
178 tracing::info!("proving server is not ready, retrying...");
179 }
180 Err(e) => {
181 tracing::warn!("Error checking server readiness: {}", e);
182 }
183 }
184 tokio::time::sleep(Duration::from_secs(2)).await;
185 }
186 Ok(())
187 })?;
188
189 Ok(prover)
190 }
191
/// Reports whether `docker version` can be run and exits successfully.
///
/// A failure to launch the `docker` binary is reported as `Ok(false)`, not as
/// an error, so this function never returns `Err`.
fn check_docker_availability() -> Result<bool, Box<dyn std::error::Error>> {
    let available = Command::new("docker")
        .arg("version")
        .output()
        .map_or(false, |out| out.status.success());
    Ok(available)
}
198
199 fn start_moongate_server(
200 reqwest_middlewares: Vec<Box<dyn Middleware>>,
201 visible_device_index: Option<u64>,
202 port: Option<u64>,
203 ) -> Result<MONEROCHANCudaProver, Box<dyn StdError>> {
204 let container_name = port.map(|p| format!("monerochan-gpu-{p}")).unwrap_or("monerochan-gpu".to_string());
206 let image_name = std::env::var("MONEROCHAN_GPU_IMAGE")
207 .unwrap_or_else(|_| "public.ecr.aws/succinct-labs/moongate:v5.0.8".to_string());
208
209 let cleaned_up = Arc::new(AtomicBool::new(false));
210 let port = port.unwrap_or(3000);
211 let gpus = visible_device_index.map(|i| format!("device={i}")).unwrap_or("all".to_string());
212
213 if !Self::check_docker_availability()? {
215 return Err("Docker is not available or you don't have the necessary permissions. Please ensure Docker is installed and you are part of the docker group.".into());
216 }
217
218 if let Err(e) = Command::new("docker").args(["pull", &image_name]).output() {
220 return Err(format!("Failed to pull Docker image: {e}. Please check your internet connection and Docker permissions.").into());
221 }
222
223 let rust_log_level = std::env::var("RUST_LOG").unwrap_or_else(|_| "none".to_string());
225 Command::new("docker")
226 .args([
227 "run",
228 "-e",
229 &format!("RUST_LOG={rust_log_level}"),
230 "-p",
231 &format!("{port}:3000"),
232 "--rm",
233 "--gpus",
234 &gpus,
235 "--name",
236 &container_name,
237 &image_name,
238 ])
239 .stdout(Stdio::inherit())
241 .stderr(Stdio::inherit())
242 .spawn()
243 .map_err(|e| format!("Failed to start Docker container: {e}. Please check your Docker installation and permissions."))?;
244
245 MOONGATE_CONTAINERS.lock()?.insert(container_name.clone(), cleaned_up.clone());
246
247 let _ = ctrlc::set_handler(move || {
251 tracing::info!("received Ctrl+C, cleaning up...");
252
253 for (container_name, cleanup_flag) in MOONGATE_CONTAINERS.lock().unwrap().iter() {
254 if !cleanup_flag.load(Ordering::SeqCst) {
255 cleanup_container(container_name);
256 cleanup_flag.store(true, Ordering::SeqCst);
257 }
258 }
259 std::process::exit(0);
260 });
261
262 std::thread::sleep(Duration::from_secs(2));
264
265 let client = Client::new(
266 Url::parse(&format!("http://localhost:{port}/twirp/")).expect("failed to parse url"),
267 reqwest::Client::new(),
268 reqwest_middlewares,
269 )
270 .expect("failed to create client");
271
272 Ok(MONEROCHANCudaProver {
273 client,
274 managed_container: Some(CudaProverContainer { name: container_name, cleaned_up }),
275 })
276 }
277
278 pub fn setup(&self, elf: &[u8]) -> Result<(MONEROCHANProvingKey, MONEROCHANVerifyingKey), Box<dyn StdError>> {
280 let payload = SetupRequestPayload { elf: elf.to_vec() };
281 let request =
282 crate::proto::api::SetupRequest { data: bincode::serialize(&payload).unwrap() };
283 let response = block_on(async { self.client.setup(request).await }).unwrap();
284 let payload: SetupResponsePayload = bincode::deserialize(&response.result).unwrap();
285 Ok((payload.pk, payload.vk))
286 }
287
288 pub fn prove_core(&self, stdin: &MONEROCHANStdin) -> Result<MONEROCHANCoreProof, MONEROCHANCoreProverError> {
292 let payload = ProveCoreRequestPayload { stdin: stdin.clone() };
293 let request =
294 crate::proto::api::ProveCoreRequest { data: bincode::serialize(&payload).unwrap() };
295 let response = block_on(async { self.client.prove_core(request).await }).unwrap();
296 let proof: MONEROCHANCoreProof = bincode::deserialize(&response.result).unwrap();
297 Ok(proof)
298 }
299
300 pub fn prove_core_stateless(
304 &self,
305 pk: &MONEROCHANProvingKey,
306 stdin: &MONEROCHANStdin,
307 ) -> Result<MONEROCHANCoreProof, MONEROCHANCoreProverError> {
308 let payload = StatelessProveCoreRequestPayload { pk: pk.clone(), stdin: stdin.clone() };
309 let request =
310 crate::proto::api::ProveCoreRequest { data: bincode::serialize(&payload).unwrap() };
311 let response = block_on(async { self.client.prove_core_stateless(request).await }).unwrap();
312 let proof: MONEROCHANCoreProof = bincode::deserialize(&response.result).unwrap();
313 Ok(proof)
314 }
315
316 pub fn compress(
320 &self,
321 vk: &MONEROCHANVerifyingKey,
322 proof: MONEROCHANCoreProof,
323 deferred_proofs: Vec<MONEROCHANReduceProof<InnerSC>>,
324 ) -> Result<MONEROCHANReduceProof<InnerSC>, MONEROCHANRecursionProverError> {
325 let payload = CompressRequestPayload { vk: vk.clone(), proof, deferred_proofs };
326 let request =
327 crate::proto::api::CompressRequest { data: bincode::serialize(&payload).unwrap() };
328
329 let response = block_on(async { self.client.compress(request).await }).unwrap();
330 let proof: MONEROCHANReduceProof<InnerSC> = bincode::deserialize(&response.result).unwrap();
331 Ok(proof)
332 }
333
334 pub fn shrink(
338 &self,
339 reduced_proof: MONEROCHANReduceProof<InnerSC>,
340 ) -> Result<MONEROCHANReduceProof<InnerSC>, MONEROCHANRecursionProverError> {
341 let payload = ShrinkRequestPayload { reduced_proof: reduced_proof.clone() };
342 let request =
343 crate::proto::api::ShrinkRequest { data: bincode::serialize(&payload).unwrap() };
344
345 let response = block_on(async { self.client.shrink(request).await }).unwrap();
346 let proof: MONEROCHANReduceProof<InnerSC> = bincode::deserialize(&response.result).unwrap();
347 Ok(proof)
348 }
349
350 pub fn wrap_bn254(
354 &self,
355 reduced_proof: MONEROCHANReduceProof<InnerSC>,
356 ) -> Result<MONEROCHANReduceProof<OuterSC>, MONEROCHANRecursionProverError> {
357 let payload = WrapRequestPayload { reduced_proof: reduced_proof.clone() };
358 let request =
359 crate::proto::api::WrapRequest { data: bincode::serialize(&payload).unwrap() };
360
361 let response = block_on(async { self.client.wrap(request).await }).unwrap();
362 let proof: MONEROCHANReduceProof<OuterSC> = bincode::deserialize(&response.result).unwrap();
363 Ok(proof)
364 }
365}
366
367impl Default for MONEROCHANCudaProver {
368 fn default() -> Self {
369 Self::new(Default::default()).expect("Failed to create MONEROCHANCudaProver")
370 }
371}
372
373impl Drop for MONEROCHANCudaProver {
374 fn drop(&mut self) {
375 if let Some(container) = &self.managed_container {
376 if !container.cleaned_up.load(Ordering::SeqCst) {
377 tracing::debug!("dropping MONEROCHANProverClient, cleaning up...");
378 cleanup_container(&container.name);
379 container.cleaned_up.store(true, Ordering::SeqCst);
380 }
381 }
382 }
383}
384
/// Force-removes the named Docker container with `docker rm -f`.
///
/// Best effort: if the command cannot be launched, a hint is printed to stderr
/// and the function returns normally. The command's exit status is not inspected.
fn cleanup_container(container_name: &str) {
    let removal = Command::new("docker").args(["rm", "-f", container_name]).output();

    if let Some(e) = removal.err() {
        eprintln!(
            "Failed to remove container: {e}. You may need to manually remove it using 'docker rm -f {container_name}'"
        );
    }
}
393
394pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
399 if let Ok(handle) = tokio::runtime::Handle::try_current() {
401 block_in_place(|| handle.block_on(fut))
402 } else {
403 let rt = tokio::runtime::Runtime::new().expect("Failed to create a new runtime");
405 rt.block_on(fut)
406 }
407}
408
/// Stateless client middleware that logs every successful response.
struct LoggingMiddleware;
410
411pub type Result<T, E = ClientError> = std::result::Result<T, E>;
412
413#[async_trait]
414impl Middleware for LoggingMiddleware {
415 async fn handle(&self, req: Request, next: Next<'_>) -> Result<Response> {
416 let response = next.run(req).await;
417 match response {
418 Ok(response) => {
419 tracing::info!("{:?}", response);
420 Ok(response)
421 }
422 Err(e) => Err(e),
423 }
424 }
425}
426
427