monerochan_cuda/
lib.rs

1use std::{
2    collections::HashMap,
3    error::Error as StdError,
4    future::Future,
5    process::{Command, Stdio},
6    sync::{
7        atomic::{AtomicBool, Ordering},
8        Arc, Mutex,
9    },
10    time::{Duration, Instant},
11};
12
13use crate::proto::api::ProverServiceClient;
14use async_trait::async_trait;
15use proto::api::ReadyRequest;
16use reqwest::{Request, Response};
17use serde::{Deserialize, Serialize};
18use monerochan_core_machine::{io::MONEROCHANStdin, reduce::MONEROCHANReduceProof, utils::MONEROCHANCoreProverError};
19use monerochan_prover::{
20    InnerSC, OuterSC, MONEROCHANCoreProof, MONEROCHANProvingKey, MONEROCHANRecursionProverError, MONEROCHANVerifyingKey,
21};
22use std::sync::LazyLock;
23use tokio::task::block_in_place;
24use twirp::{
25    async_trait,
26    reqwest::{self},
27    url::Url,
28    Client, ClientError, Middleware, Next,
29};
30
31#[rustfmt::skip]
32pub mod proto {
33    pub mod api;
34}
35
/// Process-wide registry of the containers started by this crate, keyed by container name.
///
/// Each value is the shared "already cleaned up" flag of that container. The Ctrl+C handler
/// walks this map to remove every container whose flag is still `false`.
static MOONGATE_CONTAINERS: LazyLock<Mutex<HashMap<String, Arc<AtomicBool>>>> =
    LazyLock::new(|| Mutex::default());
38
39/// A remote client to [monerochan_prover::MONEROCHANProver] that runs inside a container.
40///
41/// This is currently used to provide experimental support for GPU hardware acceleration.
42///
43/// **WARNING**: This is an experimental feature and may not work as expected.
44pub struct MONEROCHANCudaProver {
45    /// The gRPC client to communicate with the container.
46    client: Client,
47    /// The Moongate server container, if managed by the prover.
48    managed_container: Option<CudaProverContainer>,
49}
50
/// Bookkeeping for a Docker container started on behalf of a [MONEROCHANCudaProver].
pub struct CudaProverContainer {
    /// Name the container was started with (passed to `docker rm -f` on cleanup).
    name: String,
    /// Set to `true` once the container has been removed, so cleanup runs at most once
    /// between the prover's `Drop` and the Ctrl+C handler.
    cleaned_up: Arc<AtomicBool>,
}
57
58/// The payload for the [monerochan_prover::MONEROCHANProver::setup] method.
59///
60/// This object is used to serialize and deserialize the payloads for the Moongate server.
61#[derive(Serialize, Deserialize)]
62pub struct SetupRequestPayload {
63    pub elf: Vec<u8>,
64}
65
66/// The payload for the [monerochan_prover::MONEROCHANProver::setup] method response.
67///
68/// We use this object to serialize and deserialize the payload from the server to the client.
69#[derive(Serialize, Deserialize)]
70pub struct SetupResponsePayload {
71    pub pk: MONEROCHANProvingKey,
72    pub vk: MONEROCHANVerifyingKey,
73}
74
75/// The payload for the [monerochan_prover::MONEROCHANProver::prove_core] method.
76///
77/// This object is used to serialize and deserialize the payloads for the Moongate server.
78#[derive(Serialize, Deserialize)]
79pub struct ProveCoreRequestPayload {
80    /// The input stream.
81    pub stdin: MONEROCHANStdin,
82}
83
84/// The payload for the [monerochan_prover::MONEROCHANProver::stateless_prove_core] method.
85///
86/// This object is used to serialize and deserialize the payloads for the Moongate server.
87/// The proving key is sent in the payload with the request to allow the Moongate server to generate
88/// proofs without re-generating the proving key.
89#[derive(Serialize, Deserialize)]
90pub struct StatelessProveCoreRequestPayload {
91    /// The input stream.
92    pub stdin: MONEROCHANStdin,
93    /// The proving key.
94    pub pk: MONEROCHANProvingKey,
95}
96
97/// The payload for the [monerochan_prover::MONEROCHANProver::compress] method.
98///
99/// This object is used to serialize and deserialize the payloads for the Moongate server.
100#[derive(Serialize, Deserialize)]
101pub struct CompressRequestPayload {
102    /// The verifying key.
103    pub vk: MONEROCHANVerifyingKey,
104    /// The core proof.
105    pub proof: MONEROCHANCoreProof,
106    /// The deferred proofs.
107    pub deferred_proofs: Vec<MONEROCHANReduceProof<InnerSC>>,
108}
109
110/// The payload for the [monerochan_prover::MONEROCHANProver::shrink] method.
111///
112/// This object is used to serialize and deserialize the payloads for the Moongate server.
113#[derive(Serialize, Deserialize)]
114pub struct ShrinkRequestPayload {
115    pub reduced_proof: MONEROCHANReduceProof<InnerSC>,
116}
117
118/// The payload for the [monerochan_prover::MONEROCHANProver::wrap_bn254] method.
119///
120/// This object is used to serialize and deserialize the payloads for the Moongate server.
121#[derive(Serialize, Deserialize)]
122pub struct WrapRequestPayload {
123    pub reduced_proof: MONEROCHANReduceProof<InnerSC>,
124}
125
/// Selects where the Moongate server used by a prover comes from.
#[derive(Debug)]
pub enum MoongateServer {
    /// Connect to an already-running server reachable at `endpoint`.
    External { endpoint: String },
    /// Start a server in a local Docker container.
    ///
    /// `visible_device_index` limits the container to that single GPU device (all GPUs when
    /// `None`); `port` is the host port the server is published on (3000 when `None`).
    Local { visible_device_index: Option<u64>, port: Option<u64> },
}

impl Default for MoongateServer {
    /// A locally started server with no device or port override.
    fn default() -> Self {
        MoongateServer::Local { port: None, visible_device_index: None }
    }
}
138
139impl MONEROCHANCudaProver {
140    /// Creates a new [MONEROCHANCudaProver] that can be used to communicate with the Moongate server at
141    /// `moongate_endpoint`, or if not provided, create one that runs inside a Docker container.
142    pub fn new(moongate_server: MoongateServer) -> Result<Self, Box<dyn StdError>> {
143        let reqwest_middlewares = vec![Box::new(LoggingMiddleware) as Box<dyn Middleware>];
144
145        let prover = match moongate_server {
146            MoongateServer::External { endpoint } => {
147                let client = Client::new(
148                    Url::parse(&endpoint).expect("failed to parse url"),
149                    reqwest::Client::new(),
150                    reqwest_middlewares,
151                )
152                .expect("failed to create client");
153
154                MONEROCHANCudaProver { client, managed_container: None }
155            }
156            MoongateServer::Local { visible_device_index, port } => {
157                Self::start_moongate_server(reqwest_middlewares, visible_device_index, port)?
158            }
159        };
160
161        let timeout = Duration::from_secs(300);
162        let start_time = Instant::now();
163
164        block_on(async {
165            tracing::info!("waiting for proving server to be ready");
166            loop {
167                if start_time.elapsed() > timeout {
168                    return Err("Timeout: proving server did not become ready within 60 seconds. Please check your Docker container and network settings.".to_string());
169                }
170
171                let request = ReadyRequest {};
172                match prover.client.ready(request).await {
173                    Ok(response) if response.ready => {
174                        tracing::info!("proving server is ready");
175                        break;
176                    }
177                    Ok(_) => {
178                        tracing::info!("proving server is not ready, retrying...");
179                    }
180                    Err(e) => {
181                        tracing::warn!("Error checking server readiness: {}", e);
182                    }
183                }
184                tokio::time::sleep(Duration::from_secs(2)).await;
185            }
186            Ok(())
187        })?;
188
189        Ok(prover)
190    }
191
192    fn check_docker_availability() -> Result<bool, Box<dyn std::error::Error>> {
193        match Command::new("docker").arg("version").output() {
194            Ok(output) => Ok(output.status.success()),
195            Err(_) => Ok(false),
196        }
197    }
198
199    fn start_moongate_server(
200        reqwest_middlewares: Vec<Box<dyn Middleware>>,
201        visible_device_index: Option<u64>,
202        port: Option<u64>,
203    ) -> Result<MONEROCHANCudaProver, Box<dyn StdError>> {
204        // If the moongate endpoint url hasn't been provided, we start the Docker container
205        let container_name = port.map(|p| format!("monerochan-gpu-{p}")).unwrap_or("monerochan-gpu".to_string());
206        let image_name = std::env::var("MONEROCHAN_GPU_IMAGE")
207            .unwrap_or_else(|_| "public.ecr.aws/succinct-labs/moongate:v5.0.8".to_string());
208
209        let cleaned_up = Arc::new(AtomicBool::new(false));
210        let port = port.unwrap_or(3000);
211        let gpus = visible_device_index.map(|i| format!("device={i}")).unwrap_or("all".to_string());
212
213        // Check if Docker is available and the user has necessary permissions
214        if !Self::check_docker_availability()? {
215            return Err("Docker is not available or you don't have the necessary permissions. Please ensure Docker is installed and you are part of the docker group.".into());
216        }
217
218        // Pull the docker image if it's not present
219        if let Err(e) = Command::new("docker").args(["pull", &image_name]).output() {
220            return Err(format!("Failed to pull Docker image: {e}. Please check your internet connection and Docker permissions.").into());
221        }
222
223        // Start the docker container
224        let rust_log_level = std::env::var("RUST_LOG").unwrap_or_else(|_| "none".to_string());
225        Command::new("docker")
226            .args([
227                "run",
228                "-e",
229                &format!("RUST_LOG={rust_log_level}"),
230                "-p",
231                &format!("{port}:3000"),
232                "--rm",
233                "--gpus",
234                &gpus,
235                "--name",
236                &container_name,
237                &image_name,
238            ])
239            // Redirect stdout and stderr to the parent process
240            .stdout(Stdio::inherit())
241            .stderr(Stdio::inherit())
242            .spawn()
243            .map_err(|e| format!("Failed to start Docker container: {e}. Please check your Docker installation and permissions."))?;
244
245        MOONGATE_CONTAINERS.lock()?.insert(container_name.clone(), cleaned_up.clone());
246
247        // Kill the container on control-c
248        // The error returned by set_handler is ignored to avoid panic when the handler has already
249        // been set.
250        let _ = ctrlc::set_handler(move || {
251            tracing::info!("received Ctrl+C, cleaning up...");
252
253            for (container_name, cleanup_flag) in MOONGATE_CONTAINERS.lock().unwrap().iter() {
254                if !cleanup_flag.load(Ordering::SeqCst) {
255                    cleanup_container(container_name);
256                    cleanup_flag.store(true, Ordering::SeqCst);
257                }
258            }
259            std::process::exit(0);
260        });
261
262        // Wait a few seconds for the container to start
263        std::thread::sleep(Duration::from_secs(2));
264
265        let client = Client::new(
266            Url::parse(&format!("http://localhost:{port}/twirp/")).expect("failed to parse url"),
267            reqwest::Client::new(),
268            reqwest_middlewares,
269        )
270        .expect("failed to create client");
271
272        Ok(MONEROCHANCudaProver {
273            client,
274            managed_container: Some(CudaProverContainer { name: container_name, cleaned_up }),
275        })
276    }
277
278    /// Executes the [monerochan_prover::MONEROCHANProver::setup] method inside the container.
279    pub fn setup(&self, elf: &[u8]) -> Result<(MONEROCHANProvingKey, MONEROCHANVerifyingKey), Box<dyn StdError>> {
280        let payload = SetupRequestPayload { elf: elf.to_vec() };
281        let request =
282            crate::proto::api::SetupRequest { data: bincode::serialize(&payload).unwrap() };
283        let response = block_on(async { self.client.setup(request).await }).unwrap();
284        let payload: SetupResponsePayload = bincode::deserialize(&response.result).unwrap();
285        Ok((payload.pk, payload.vk))
286    }
287
288    /// Executes the [monerochan_prover::MONEROCHANProver::prove_core] method inside the container.
289    ///
290    /// You will need at least 24GB of VRAM to run this method.
291    pub fn prove_core(&self, stdin: &MONEROCHANStdin) -> Result<MONEROCHANCoreProof, MONEROCHANCoreProverError> {
292        let payload = ProveCoreRequestPayload { stdin: stdin.clone() };
293        let request =
294            crate::proto::api::ProveCoreRequest { data: bincode::serialize(&payload).unwrap() };
295        let response = block_on(async { self.client.prove_core(request).await }).unwrap();
296        let proof: MONEROCHANCoreProof = bincode::deserialize(&response.result).unwrap();
297        Ok(proof)
298    }
299
300    /// Executes the [monerochan_prover::MONEROCHANProver::prove_core] method inside the container.
301    ///
302    /// You will need at least 24GB of VRAM to run this method.
303    pub fn prove_core_stateless(
304        &self,
305        pk: &MONEROCHANProvingKey,
306        stdin: &MONEROCHANStdin,
307    ) -> Result<MONEROCHANCoreProof, MONEROCHANCoreProverError> {
308        let payload = StatelessProveCoreRequestPayload { pk: pk.clone(), stdin: stdin.clone() };
309        let request =
310            crate::proto::api::ProveCoreRequest { data: bincode::serialize(&payload).unwrap() };
311        let response = block_on(async { self.client.prove_core_stateless(request).await }).unwrap();
312        let proof: MONEROCHANCoreProof = bincode::deserialize(&response.result).unwrap();
313        Ok(proof)
314    }
315
316    /// Executes the [monerochan_prover::MONEROCHANProver::compress] method inside the container.
317    ///
318    /// You will need at least 24GB of VRAM to run this method.
319    pub fn compress(
320        &self,
321        vk: &MONEROCHANVerifyingKey,
322        proof: MONEROCHANCoreProof,
323        deferred_proofs: Vec<MONEROCHANReduceProof<InnerSC>>,
324    ) -> Result<MONEROCHANReduceProof<InnerSC>, MONEROCHANRecursionProverError> {
325        let payload = CompressRequestPayload { vk: vk.clone(), proof, deferred_proofs };
326        let request =
327            crate::proto::api::CompressRequest { data: bincode::serialize(&payload).unwrap() };
328
329        let response = block_on(async { self.client.compress(request).await }).unwrap();
330        let proof: MONEROCHANReduceProof<InnerSC> = bincode::deserialize(&response.result).unwrap();
331        Ok(proof)
332    }
333
334    /// Executes the [monerochan_prover::MONEROCHANProver::shrink] method inside the container.
335    ///
336    /// You will need at least 24GB of VRAM to run this method.
337    pub fn shrink(
338        &self,
339        reduced_proof: MONEROCHANReduceProof<InnerSC>,
340    ) -> Result<MONEROCHANReduceProof<InnerSC>, MONEROCHANRecursionProverError> {
341        let payload = ShrinkRequestPayload { reduced_proof: reduced_proof.clone() };
342        let request =
343            crate::proto::api::ShrinkRequest { data: bincode::serialize(&payload).unwrap() };
344
345        let response = block_on(async { self.client.shrink(request).await }).unwrap();
346        let proof: MONEROCHANReduceProof<InnerSC> = bincode::deserialize(&response.result).unwrap();
347        Ok(proof)
348    }
349
350    /// Executes the [monerochan_prover::MONEROCHANProver::wrap_bn254] method inside the container.
351    ///
352    /// You will need at least 24GB of VRAM to run this method.
353    pub fn wrap_bn254(
354        &self,
355        reduced_proof: MONEROCHANReduceProof<InnerSC>,
356    ) -> Result<MONEROCHANReduceProof<OuterSC>, MONEROCHANRecursionProverError> {
357        let payload = WrapRequestPayload { reduced_proof: reduced_proof.clone() };
358        let request =
359            crate::proto::api::WrapRequest { data: bincode::serialize(&payload).unwrap() };
360
361        let response = block_on(async { self.client.wrap(request).await }).unwrap();
362        let proof: MONEROCHANReduceProof<OuterSC> = bincode::deserialize(&response.result).unwrap();
363        Ok(proof)
364    }
365}
366
367impl Default for MONEROCHANCudaProver {
368    fn default() -> Self {
369        Self::new(Default::default()).expect("Failed to create MONEROCHANCudaProver")
370    }
371}
372
373impl Drop for MONEROCHANCudaProver {
374    fn drop(&mut self) {
375        if let Some(container) = &self.managed_container {
376            if !container.cleaned_up.load(Ordering::SeqCst) {
377                tracing::debug!("dropping MONEROCHANProverClient, cleaning up...");
378                cleanup_container(&container.name);
379                container.cleaned_up.store(true, Ordering::SeqCst);
380            }
381        }
382    }
383}
384
/// Force-removes the Docker container with the given name via `docker rm -f`.
///
/// If the `docker` command cannot be run, a message with manual removal instructions is
/// printed to stderr; the command's exit status is not inspected.
fn cleanup_container(container_name: &str) {
    let outcome = Command::new("docker").args(["rm", "-f", container_name]).output();
    if let Err(e) = outcome {
        eprintln!(
            "Failed to remove container: {e}. You may need to manually remove it using 'docker rm -f {container_name}'"
        );
    }
}
393
394/// Utility method for blocking on an async function.
395///
396/// If we're already in a tokio runtime, we'll block in place. Otherwise, we'll create a new
397/// runtime.
398pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
399    // Handle case if we're already in an tokio runtime.
400    if let Ok(handle) = tokio::runtime::Handle::try_current() {
401        block_in_place(|| handle.block_on(fut))
402    } else {
403        // Otherwise create a new runtime.
404        let rt = tokio::runtime::Runtime::new().expect("Failed to create a new runtime");
405        rt.block_on(fut)
406    }
407}
408
409struct LoggingMiddleware;
410
411pub type Result<T, E = ClientError> = std::result::Result<T, E>;
412
413#[async_trait]
414impl Middleware for LoggingMiddleware {
415    async fn handle(&self, req: Request, next: Next<'_>) -> Result<Response> {
416        let response = next.run(req).await;
417        match response {
418            Ok(response) => {
419                tracing::info!("{:?}", response);
420                Ok(response)
421            }
422            Err(e) => Err(e),
423        }
424    }
425}
426
427// #[cfg(feature = "protobuf")]
428// #[cfg(test)]
429// mod tests {
430//     use monerochan_core_machine::{
431//         reduce::MONEROCHANReduceProof,
432//         utils::{setup_logger, tests::FIBONACCI_ELF},
433//     };
434//     use monerochan_prover::{components::DefaultProverComponents, InnerSC, MONEROCHANCoreProof, MONEROCHANProver};
435//     use twirp::{url::Url, Client};
436
437//     use crate::{
438//         proto::api::ProverServiceClient, CompressRequestPayload, ProveCoreRequestPayload,
439//         MONEROCHANCudaProver, MONEROCHANStdin,
440//     };
441
442//     #[test]
443//     fn test_client() {
444//         setup_logger();
445
446//         let prover = MONEROCHANProver::<DefaultProverComponents>::new();
447//         let client = MONEROCHANCudaProver::new().expect("Failed to create MONEROCHANCudaProver");
448//         let (pk, vk) = prover.setup(FIBONACCI_ELF);
449
450//         println!("proving core");
451//         let proof = client.prove_core(&pk, &MONEROCHANStdin::new()).unwrap();
452
453//         println!("verifying core");
454//         prover.verify(&proof.proof, &vk).unwrap();
455
456//         println!("proving compress");
457//         let proof = client.compress(&vk, proof, vec![]).unwrap();
458
459//         println!("verifying compress");
460//         prover.verify_compressed(&proof, &vk).unwrap();
461
462//         println!("proving shrink");
463//         let proof = client.shrink(proof).unwrap();
464
465//         println!("verifying shrink");
466//         prover.verify_shrink(&proof, &vk).unwrap();
467
468//         println!("proving wrap_bn254");
469//         let proof = client.wrap_bn254(proof).unwrap();
470
471//         println!("verifying wrap_bn254");
472//         prover.verify_wrap_bn254(&proof, &vk).unwrap();
473//     }
474
475//     #[tokio::test]
476//     async fn test_prove_core() {
477//         let client =
478//             Client::from_base_url(Url::parse("http://localhost:3000/twirp/").unwrap()).unwrap();
479
480//         let prover = MONEROCHANProver::<DefaultProverComponents>::new();
481//         let (pk, vk) = prover.setup(FIBONACCI_ELF);
482//         let payload = ProveCoreRequestPayload { pk, stdin: MONEROCHANStdin::new() };
483//         let request =
484//             crate::proto::api::ProveCoreRequest { data: bincode::serialize(&payload).unwrap() };
485//         let proof = client.prove_core(request).await.unwrap();
486//         let proof: MONEROCHANCoreProof = bincode::deserialize(&proof.result).unwrap();
487//         prover.verify(&proof.proof, &vk).unwrap();
488
489//         tracing::info!("compress");
490//         let payload = CompressRequestPayload { vk: vk.clone(), proof, deferred_proofs: vec![] };
491//         let request =
492//             crate::proto::api::CompressRequest { data: bincode::serialize(&payload).unwrap() };
493//         let compressed_proof = client.compress(request).await.unwrap();
494//         let compressed_proof: MONEROCHANReduceProof<InnerSC> =
495//             bincode::deserialize(&compressed_proof.result).unwrap();
496
497//         tracing::info!("verify compressed");
498//         prover.verify_compressed(&compressed_proof, &vk).unwrap();
499//     }
500// }