light_prover_client/
proof_client.rs

1use std::time::{Duration, Instant};
2
3use reqwest::Client;
4use serde::Deserialize;
5use tokio::time::sleep;
6use tracing::{debug, error, info, trace, warn};
7
8use crate::{
9    constants::PROVE_PATH,
10    errors::ProverClientError,
11    proof::{
12        compress_proof, deserialize_gnark_proof_json, proof_from_json_struct, ProofCompressed,
13    },
14    proof_types::{
15        batch_address_append::{to_json, BatchAddressAppendInputs},
16        batch_append::{BatchAppendInputsJson, BatchAppendsCircuitInputs},
17        batch_update::{update_inputs_string, BatchUpdateCircuitInputs},
18    },
19};
20
/// Upper bound on retries, both for whole proof requests and for consecutive
/// transient errors while polling a single job.
const MAX_RETRIES: u32 = 10;
/// Back-off unit in seconds; the delay before retry `n` is `n` times this value.
const BASE_RETRY_DELAY_SECS: u64 = 1;
/// Default pause, in seconds, between two job-status polls.
const DEFAULT_POLLING_INTERVAL_SECS: u64 = 1;
/// Default overall time budget, in seconds, for producing one proof (ten minutes).
const DEFAULT_MAX_WAIT_TIME_SECS: u64 = 10 * 60;
/// Prover server address used by `ProofClient::local`.
const DEFAULT_LOCAL_SERVER: &str = "http://localhost:3001";
26
/// Body of a prover-server reply to a proof request that was accepted for
/// asynchronous processing (this client parses it from HTTP 202 responses).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ProofResponse {
    /// The request was queued as a background job.
    Async {
        /// Identifier used to poll the job's status.
        job_id: String,
        /// Optional estimate reported by the server; not read by this client.
        estimated_time: Option<String>,
    },
}
35
/// Body returned by the job-status endpoint while polling for a proof.
#[derive(Debug, Deserialize)]
pub struct JobStatusResponse {
    /// Job state. This client recognises `"completed"`, `"failed"`,
    /// `"processing"` and `"queued"`; any other value is treated as still pending.
    pub status: String,
    /// Optional human-readable detail; used as the error text when the job failed.
    pub message: Option<String>,
    /// Proof payload as JSON; expected to be present once the job is completed.
    pub result: Option<serde_json::Value>,
}
42
/// Structured error body the prover server may send with a non-success reply.
#[derive(Debug, Deserialize)]
pub struct ErrorResponse {
    /// Error code reported by the server.
    pub code: String,
    /// Error description reported by the server.
    pub message: String,
}
48
/// HTTP client for a prover server: submits proof requests, polls
/// asynchronous jobs, and retries failures within a bounded time budget.
pub struct ProofClient {
    /// Underlying HTTP client used for every request.
    client: Client,
    /// Base URL of the prover server (scheme, host, port), without a trailing path.
    server_address: String,
    /// Pause between two job-status polls.
    polling_interval: Duration,
    /// Overall time budget for producing one proof, including retries and polling.
    max_wait_time: Duration,
    /// When set, sent as the `X-API-Key` header on every request.
    api_key: Option<String>,
}
56
57impl ProofClient {
58    pub fn local() -> Self {
59        Self {
60            client: Client::new(),
61            server_address: DEFAULT_LOCAL_SERVER.to_string(),
62            polling_interval: Duration::from_secs(DEFAULT_POLLING_INTERVAL_SECS),
63            max_wait_time: Duration::from_secs(DEFAULT_MAX_WAIT_TIME_SECS),
64            api_key: None,
65        }
66    }
67
68    #[allow(unused)]
69    pub fn with_config(
70        server_address: String,
71        polling_interval: Duration,
72        max_wait_time: Duration,
73        api_key: Option<String>,
74    ) -> Self {
75        Self {
76            client: Client::new(),
77            server_address,
78            polling_interval,
79            max_wait_time,
80            api_key,
81        }
82    }
83
84    pub async fn generate_proof(
85        &self,
86        inputs_json: String,
87    ) -> Result<ProofCompressed, ProverClientError> {
88        let start_time = Instant::now();
89        let mut retries = 0;
90
91        loop {
92            let elapsed = start_time.elapsed();
93            if elapsed > self.max_wait_time {
94                return Err(ProverClientError::ProverServerError(format!(
95                    "Overall proof generation timed out after {:?} (max: {:?}), retries: {}",
96                    elapsed, self.max_wait_time, retries
97                )));
98            }
99
100            match self.try_generate_proof(&inputs_json, elapsed).await {
101                Ok(proof) => return Ok(proof),
102                Err(err) if self.should_retry(&err, retries, elapsed) => {
103                    retries += 1;
104                    let retry_delay = Duration::from_secs(BASE_RETRY_DELAY_SECS * retries as u64);
105
106                    if elapsed + retry_delay > self.max_wait_time {
107                        warn!(
108                            "Skipping retry due to max wait time constraint: elapsed={:?}, retry_delay={:?}, max_wait={:?}",
109                            elapsed, retry_delay, self.max_wait_time
110                        );
111                        return Err(err);
112                    }
113
114                    warn!(
115                        "Retrying proof generation ({}/{}) after {:?} due to: {}",
116                        retries, MAX_RETRIES, retry_delay, err
117                    );
118                    sleep(retry_delay).await;
119                }
120                Err(err) => {
121                    debug!(
122                        "Not retrying error (retries={}, elapsed={:?}): {}",
123                        retries, elapsed, err
124                    );
125                    return Err(err);
126                }
127            }
128        }
129    }
130
131    async fn try_generate_proof(
132        &self,
133        inputs_json: &str,
134        elapsed: Duration,
135    ) -> Result<ProofCompressed, ProverClientError> {
136        let response = self.send_proof_request(inputs_json).await?;
137        let status_code = response.status();
138        let response_text = response.text().await.map_err(|e| {
139            ProverClientError::ProverServerError(format!("Failed to read response body: {}", e))
140        })?;
141
142        self.log_response(status_code, &response_text);
143        self.handle_proof_response(status_code, &response_text, elapsed)
144            .await
145    }
146
147    async fn send_proof_request(
148        &self,
149        inputs_json: &str,
150    ) -> Result<reqwest::Response, ProverClientError> {
151        let url = format!("{}{}", self.server_address, PROVE_PATH);
152
153        let mut request = self
154            .client
155            .post(&url)
156            .header("Content-Type", "application/json");
157
158        if let Some(api_key) = &self.api_key {
159            request = request.header("X-API-Key", api_key);
160        }
161
162        request
163            .body(inputs_json.to_string())
164            .send()
165            .await
166            .map_err(|e| {
167                ProverClientError::ProverServerError(format!(
168                    "Failed to send request to prover server: {}",
169                    e
170                ))
171            })
172    }
173
174    fn log_response(&self, status_code: reqwest::StatusCode, response_text: &str) {
175        if !status_code.is_success() {
176            error!("HTTP error: status={}, body={}", status_code, response_text);
177        }
178    }
179
180    async fn handle_proof_response(
181        &self,
182        status_code: reqwest::StatusCode,
183        response_text: &str,
184        start_elapsed: Duration,
185    ) -> Result<ProofCompressed, ProverClientError> {
186        match status_code {
187            reqwest::StatusCode::OK => self.parse_proof_from_json(response_text),
188            reqwest::StatusCode::ACCEPTED => {
189                let job_response = self.parse_job_response(response_text)?;
190                self.handle_async_job(job_response, start_elapsed).await
191            }
192            _ => self.handle_error_response(response_text),
193        }
194    }
195
196    fn parse_job_response(&self, response_text: &str) -> Result<ProofResponse, ProverClientError> {
197        serde_json::from_str(response_text).map_err(|e| {
198            error!("Failed to parse async response: {}", e);
199            ProverClientError::ProverServerError(format!("Failed to parse async response: {}", e))
200        })
201    }
202
203    async fn handle_async_job(
204        &self,
205        job_response: ProofResponse,
206        start_elapsed: Duration,
207    ) -> Result<ProofCompressed, ProverClientError> {
208        match job_response {
209            ProofResponse::Async { job_id, .. } => {
210                info!("Proof job queued with ID: {}", job_id);
211                self.poll_for_result(&job_id, start_elapsed).await
212            }
213        }
214    }
215
216    fn handle_error_response(
217        &self,
218        response_text: &str,
219    ) -> Result<ProofCompressed, ProverClientError> {
220        if let Ok(error_response) = serde_json::from_str::<ErrorResponse>(response_text) {
221            error!(
222                "Prover server error: {} - {}",
223                error_response.code, error_response.message
224            );
225            Err(ProverClientError::ProverServerError(format!(
226                "Prover server error: {} - {}",
227                error_response.code, error_response.message
228            )))
229        } else {
230            error!("Prover server error: {}", response_text);
231            Err(ProverClientError::ProverServerError(format!(
232                "Prover server error: {}",
233                response_text
234            )))
235        }
236    }
237
238    fn should_retry(&self, error: &ProverClientError, retries: u32, elapsed: Duration) -> bool {
239        let error_str = error.to_string();
240
241        let is_constraint_error =
242            error_str.contains("constraint") || error_str.contains("is not satisfied");
243        if is_constraint_error {
244            return false;
245        }
246
247        let is_retryable_error = error_str.contains("job_not_found")
248            || error_str.contains("connection")
249            || error_str.contains("timeout")
250            || error_str.contains("503")
251            || error_str.contains("502")
252            || error_str.contains("500");
253        let should_retry =
254            retries < MAX_RETRIES && is_retryable_error && elapsed < self.max_wait_time;
255
256        trace!(
257            "Retry check: retries={}/{}, is_retryable_error={}, elapsed={:?}/{:?}, should_retry={}, error={}",
258            retries, MAX_RETRIES, is_retryable_error, elapsed, self.max_wait_time, should_retry, error_str
259        );
260
261        should_retry
262    }
263
264    async fn poll_for_result(
265        &self,
266        job_id: &str,
267        start_elapsed: Duration,
268    ) -> Result<ProofCompressed, ProverClientError> {
269        let poll_start_time = Instant::now();
270        let status_url = format!("{}/prove/status?job_id={}", self.server_address, job_id);
271
272        info!("Starting to poll for job {} at URL: {}", job_id, status_url);
273
274        let mut poll_count = 0;
275        let mut transient_error_count = 0;
276
277        loop {
278            poll_count += 1;
279            let poll_elapsed = poll_start_time.elapsed();
280            let total_elapsed = start_elapsed + poll_elapsed;
281
282            if total_elapsed > self.max_wait_time {
283                return Err(ProverClientError::ProverServerError(format!(
284                    "Job {} timed out after {:?} total (max: {:?}), polling time: {:?}, total polls: {}",
285                    job_id, total_elapsed, self.max_wait_time, poll_elapsed, poll_count
286                )));
287            }
288
289            trace!(
290                "Poll #{} for job {} at total elapsed time {:?} (polling: {:?})",
291                poll_count,
292                job_id,
293                total_elapsed,
294                poll_elapsed
295            );
296
297            match self.poll_job_status(&status_url, job_id, poll_count).await {
298                Ok(response) => {
299                    transient_error_count = 0;
300
301                    if let Some(proof) = self
302                        .handle_job_status(response, job_id, total_elapsed, poll_count)
303                        .await?
304                    {
305                        return Ok(proof);
306                    }
307
308                    if total_elapsed + self.polling_interval > self.max_wait_time {
309                        warn!(
310                            "Skipping polling interval due to max wait time constraint: total_elapsed={:?}, polling_interval={:?}, max_wait={:?}",
311                            total_elapsed, self.polling_interval, self.max_wait_time
312                        );
313                        return Err(ProverClientError::ProverServerError(format!(
314                            "Job {} polling stopped due to max wait time constraint",
315                            job_id
316                        )));
317                    }
318
319                    sleep(self.polling_interval).await;
320                }
321                Err(err) if self.is_job_not_found_error(&err) => {
322                    error!(
323                        "Job {} not found during polling - will retry with new proof request at higher level: {}",
324                        job_id, err
325                    );
326                    return Err(err);
327                }
328                Err(err) if self.is_transient_polling_error(&err) => {
329                    transient_error_count += 1;
330
331                    trace!(
332                        "Transient polling error for job {}: attempt {}/{}, error: {}",
333                        job_id,
334                        transient_error_count,
335                        MAX_RETRIES,
336                        err
337                    );
338
339                    if transient_error_count >= MAX_RETRIES {
340                        error!(
341                            "Job {} polling failed after {} transient errors, giving up",
342                            job_id, transient_error_count
343                        );
344                        return Err(err);
345                    }
346
347                    let retry_delay =
348                        Duration::from_secs(BASE_RETRY_DELAY_SECS * transient_error_count as u64);
349
350                    if total_elapsed + retry_delay > self.max_wait_time {
351                        warn!(
352                            "Skipping transient error retry due to max wait time constraint: total_elapsed={:?}, retry_delay={:?}, max_wait={:?}",
353                            total_elapsed, retry_delay, self.max_wait_time
354                        );
355                        return Err(err);
356                    }
357
358                    warn!(
359                        "Job {} transient error (attempt {}/{}), retrying after {:?}",
360                        job_id, transient_error_count, MAX_RETRIES, retry_delay
361                    );
362                    sleep(retry_delay).await;
363                }
364                Err(err) => {
365                    debug!("Not retrying polling error for job {}: {}", job_id, err);
366                    return Err(err);
367                }
368            }
369        }
370    }
371
372    async fn poll_job_status(
373        &self,
374        status_url: &str,
375        job_id: &str,
376        poll_count: u32,
377    ) -> Result<JobStatusResponse, ProverClientError> {
378        let mut request = self.client.get(status_url);
379
380        if let Some(api_key) = &self.api_key {
381            request = request.header("X-API-Key", api_key);
382        }
383
384        let response = request.send().await.map_err(|e| {
385            error!("Failed to send status request for job {}: {}", job_id, e);
386            ProverClientError::ProverServerError(format!("Failed to check job status: {}", e))
387        })?;
388
389        let status_code = response.status();
390        let response_text = response.text().await.unwrap_or_default();
391
392        trace!(
393            "Poll #{} for job {}: status={}, body_len={}",
394            poll_count,
395            job_id,
396            status_code,
397            response_text.len()
398        );
399
400        if !status_code.is_success() {
401            return Err(ProverClientError::ProverServerError(format!(
402                "HTTP error while polling for result: status={}, body={}",
403                status_code, response_text
404            )));
405        }
406
407        serde_json::from_str(&response_text).map_err(|e| {
408            error!(
409                "Failed to parse status response on poll #{} for job {}: error={}, body={}",
410                poll_count, job_id, e, response_text
411            );
412            ProverClientError::ProverServerError(format!(
413                "Failed to parse status response: {}, body: {}",
414                e, response_text
415            ))
416        })
417    }
418
419    async fn handle_job_status(
420        &self,
421        status_response: JobStatusResponse,
422        job_id: &str,
423        elapsed: Duration,
424        poll_count: u32,
425    ) -> Result<Option<ProofCompressed>, ProverClientError> {
426        trace!(
427            "Poll #{} for job {}: status='{}', message='{}'",
428            poll_count,
429            job_id,
430            status_response.status,
431            status_response.message.as_deref().unwrap_or("none")
432        );
433
434        match status_response.status.as_str() {
435            "completed" => {
436                info!(
437                    "Job {} completed successfully after {:?} and {} polls",
438                    job_id, elapsed, poll_count
439                );
440                self.extract_proof_from_result(status_response.result, job_id)
441                    .map(Some)
442            }
443            "failed" => {
444                let error_msg = status_response
445                    .message
446                    .unwrap_or_else(|| "No error message provided".to_string());
447                error!(
448                    "Job {} failed after {:?} and {} polls: {}",
449                    job_id, elapsed, poll_count, error_msg
450                );
451                Err(ProverClientError::ProverServerError(format!(
452                    "Proof job {} failed: {}",
453                    job_id, error_msg
454                )))
455            }
456            "processing" | "queued" => {
457                trace!(
458                    "Job {} still {} after {:?} (poll #{}), waiting {:?} before next check",
459                    job_id,
460                    status_response.status,
461                    elapsed,
462                    poll_count,
463                    self.polling_interval
464                );
465                Ok(None)
466            }
467            _ => {
468                warn!(
469                    "Job {} has unknown status '{}' on poll #{} after {:?}, continuing to poll",
470                    job_id, status_response.status, poll_count, elapsed
471                );
472                Ok(None)
473            }
474        }
475    }
476
477    fn extract_proof_from_result(
478        &self,
479        result: Option<serde_json::Value>,
480        job_id: &str,
481    ) -> Result<ProofCompressed, ProverClientError> {
482        match result {
483            Some(result) => {
484                trace!("Job {} has result, parsing proof JSON", job_id);
485                let proof_json = serde_json::to_string(&result).map_err(|e| {
486                    error!("Failed to serialize result for job {}: {}", job_id, e);
487                    ProverClientError::ProverServerError("Cannot serialize result".to_string())
488                })?;
489                self.parse_proof_from_json(&proof_json)
490            }
491            None => {
492                error!("Job {} completed but has no result", job_id);
493                Err(ProverClientError::ProverServerError(
494                    "No result in completed job status".to_string(),
495                ))
496            }
497        }
498    }
499
500    fn is_job_not_found_error(&self, error: &ProverClientError) -> bool {
501        error.to_string().contains("job_not_found")
502    }
503
504    fn is_transient_polling_error(&self, error: &ProverClientError) -> bool {
505        let error_str = error.to_string();
506        error_str.contains("503") || error_str.contains("502") || error_str.contains("500")
507    }
508
509    fn parse_proof_from_json(&self, json_str: &str) -> Result<ProofCompressed, ProverClientError> {
510        let proof_json = deserialize_gnark_proof_json(json_str).map_err(|e| {
511            ProverClientError::ProverServerError(format!("Failed to deserialize proof JSON: {}", e))
512        })?;
513
514        let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json);
515        let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c);
516
517        Ok(ProofCompressed {
518            a: proof_a,
519            b: proof_b,
520            c: proof_c,
521        })
522    }
523
524    pub async fn generate_batch_address_append_proof(
525        &self,
526        inputs: BatchAddressAppendInputs,
527    ) -> Result<(ProofCompressed, [u8; 32]), ProverClientError> {
528        let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>(&inputs.new_root)?;
529        let inputs_json = to_json(&inputs);
530        let proof = self.generate_proof(inputs_json).await?;
531        Ok((proof, new_root))
532    }
533
534    pub async fn generate_batch_append_proof(
535        &self,
536        circuit_inputs: BatchAppendsCircuitInputs,
537    ) -> Result<(ProofCompressed, [u8; 32]), ProverClientError> {
538        let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>(
539            &circuit_inputs.new_root.to_biguint().unwrap(),
540        )?;
541        let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string();
542        let proof = self.generate_proof(inputs_json).await?;
543        Ok((proof, new_root))
544    }
545
546    pub async fn generate_batch_update_proof(
547        &self,
548        circuit_inputs: BatchUpdateCircuitInputs,
549    ) -> Result<(ProofCompressed, [u8; 32]), ProverClientError> {
550        let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>(
551            &circuit_inputs.new_root.to_biguint().unwrap(),
552        )?;
553        let json_str = update_inputs_string(&circuit_inputs);
554        let proof = self.generate_proof(json_str).await?;
555        Ok((proof, new_root))
556    }
557}