1use std::time::{Duration, Instant};
2
3use reqwest::Client;
4use serde::Deserialize;
5use tokio::time::sleep;
6use tracing::{debug, error, info, trace, warn};
7
8use crate::{
9 constants::PROVE_PATH,
10 errors::ProverClientError,
11 proof::{
12 compress_proof, deserialize_gnark_proof_json, proof_from_json_struct, ProofCompressed,
13 },
14 proof_types::{
15 batch_address_append::{to_json, BatchAddressAppendInputs},
16 batch_append::{BatchAppendInputsJson, BatchAppendsCircuitInputs},
17 batch_update::{update_inputs_string, BatchUpdateCircuitInputs},
18 },
19};
20
/// Upper bound on retry attempts. Used both by the top-level proof request
/// retry check and by the transient-error counter while polling a job.
const MAX_RETRIES: u32 = 10;
/// Base unit, in seconds, of the retry delay. It is multiplied by the attempt
/// number, so the delay grows linearly with each retry.
const BASE_RETRY_DELAY_SECS: u64 = 1;
/// Polling interval, in seconds, used by `ProofClient::local`.
const DEFAULT_POLLING_INTERVAL_SECS: u64 = 1;
/// Maximum total wait time, in seconds, used by `ProofClient::local`.
const DEFAULT_MAX_WAIT_TIME_SECS: u64 = 600;
/// Prover server address used by `ProofClient::local`.
const DEFAULT_LOCAL_SERVER: &str = "http://localhost:3001";
26
/// Body of a `202 Accepted` reply to a proof request.
///
/// Deserialized as an untagged enum, so the JSON object is matched directly
/// against the fields of the single `Async` variant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ProofResponse {
    /// The proof request was queued; the result has to be polled for.
    Async {
        /// Identifier passed to the job status endpoint.
        job_id: String,
        /// Optional estimate sent by the server. The job handling code in
        /// this file ignores it.
        estimated_time: Option<String>,
    },
}
35
/// Body returned by the job status endpoint while polling.
#[derive(Debug, Deserialize)]
pub struct JobStatusResponse {
    /// Job state. The client recognises `"completed"`, `"failed"`,
    /// `"processing"` and `"queued"`; any other value is logged and polling
    /// continues.
    pub status: String,
    /// Optional message; used as the failure reason when `status` is
    /// `"failed"`.
    pub message: Option<String>,
    /// Proof payload, expected to be present when `status` is `"completed"`.
    pub result: Option<serde_json::Value>,
}
42
/// Structured error body the client tries to parse from any reply whose
/// status is neither `200 OK` nor `202 Accepted`.
#[derive(Debug, Deserialize)]
pub struct ErrorResponse {
    /// Error code reported by the server.
    pub code: String,
    /// Error description reported by the server.
    pub message: String,
}
48
/// HTTP client that submits proof requests to a prover server and, when the
/// server queues the job, polls until a result is available.
pub struct ProofClient {
    /// Underlying HTTP client used for every request.
    client: Client,
    /// Base URL of the prover server; request paths are appended to it as-is.
    server_address: String,
    /// Pause between two consecutive job status polls.
    polling_interval: Duration,
    /// Overall time budget for one `generate_proof` call, including retries
    /// and polling.
    max_wait_time: Duration,
    /// When set, sent as the `X-API-Key` header on every request.
    api_key: Option<String>,
}
56
57impl ProofClient {
58 pub fn local() -> Self {
59 Self {
60 client: Client::new(),
61 server_address: DEFAULT_LOCAL_SERVER.to_string(),
62 polling_interval: Duration::from_secs(DEFAULT_POLLING_INTERVAL_SECS),
63 max_wait_time: Duration::from_secs(DEFAULT_MAX_WAIT_TIME_SECS),
64 api_key: None,
65 }
66 }
67
68 #[allow(unused)]
69 pub fn with_config(
70 server_address: String,
71 polling_interval: Duration,
72 max_wait_time: Duration,
73 api_key: Option<String>,
74 ) -> Self {
75 Self {
76 client: Client::new(),
77 server_address,
78 polling_interval,
79 max_wait_time,
80 api_key,
81 }
82 }
83
84 pub async fn generate_proof(
85 &self,
86 inputs_json: String,
87 ) -> Result<ProofCompressed, ProverClientError> {
88 let start_time = Instant::now();
89 let mut retries = 0;
90
91 loop {
92 let elapsed = start_time.elapsed();
93 if elapsed > self.max_wait_time {
94 return Err(ProverClientError::ProverServerError(format!(
95 "Overall proof generation timed out after {:?} (max: {:?}), retries: {}",
96 elapsed, self.max_wait_time, retries
97 )));
98 }
99
100 match self.try_generate_proof(&inputs_json, elapsed).await {
101 Ok(proof) => return Ok(proof),
102 Err(err) if self.should_retry(&err, retries, elapsed) => {
103 retries += 1;
104 let retry_delay = Duration::from_secs(BASE_RETRY_DELAY_SECS * retries as u64);
105
106 if elapsed + retry_delay > self.max_wait_time {
107 warn!(
108 "Skipping retry due to max wait time constraint: elapsed={:?}, retry_delay={:?}, max_wait={:?}",
109 elapsed, retry_delay, self.max_wait_time
110 );
111 return Err(err);
112 }
113
114 warn!(
115 "Retrying proof generation ({}/{}) after {:?} due to: {}",
116 retries, MAX_RETRIES, retry_delay, err
117 );
118 sleep(retry_delay).await;
119 }
120 Err(err) => {
121 debug!(
122 "Not retrying error (retries={}, elapsed={:?}): {}",
123 retries, elapsed, err
124 );
125 return Err(err);
126 }
127 }
128 }
129 }
130
131 async fn try_generate_proof(
132 &self,
133 inputs_json: &str,
134 elapsed: Duration,
135 ) -> Result<ProofCompressed, ProverClientError> {
136 let response = self.send_proof_request(inputs_json).await?;
137 let status_code = response.status();
138 let response_text = response.text().await.map_err(|e| {
139 ProverClientError::ProverServerError(format!("Failed to read response body: {}", e))
140 })?;
141
142 self.log_response(status_code, &response_text);
143 self.handle_proof_response(status_code, &response_text, elapsed)
144 .await
145 }
146
147 async fn send_proof_request(
148 &self,
149 inputs_json: &str,
150 ) -> Result<reqwest::Response, ProverClientError> {
151 let url = format!("{}{}", self.server_address, PROVE_PATH);
152
153 let mut request = self
154 .client
155 .post(&url)
156 .header("Content-Type", "application/json");
157
158 if let Some(api_key) = &self.api_key {
159 request = request.header("X-API-Key", api_key);
160 }
161
162 request
163 .body(inputs_json.to_string())
164 .send()
165 .await
166 .map_err(|e| {
167 ProverClientError::ProverServerError(format!(
168 "Failed to send request to prover server: {}",
169 e
170 ))
171 })
172 }
173
174 fn log_response(&self, status_code: reqwest::StatusCode, response_text: &str) {
175 if !status_code.is_success() {
176 error!("HTTP error: status={}, body={}", status_code, response_text);
177 }
178 }
179
180 async fn handle_proof_response(
181 &self,
182 status_code: reqwest::StatusCode,
183 response_text: &str,
184 start_elapsed: Duration,
185 ) -> Result<ProofCompressed, ProverClientError> {
186 match status_code {
187 reqwest::StatusCode::OK => self.parse_proof_from_json(response_text),
188 reqwest::StatusCode::ACCEPTED => {
189 let job_response = self.parse_job_response(response_text)?;
190 self.handle_async_job(job_response, start_elapsed).await
191 }
192 _ => self.handle_error_response(response_text),
193 }
194 }
195
196 fn parse_job_response(&self, response_text: &str) -> Result<ProofResponse, ProverClientError> {
197 serde_json::from_str(response_text).map_err(|e| {
198 error!("Failed to parse async response: {}", e);
199 ProverClientError::ProverServerError(format!("Failed to parse async response: {}", e))
200 })
201 }
202
203 async fn handle_async_job(
204 &self,
205 job_response: ProofResponse,
206 start_elapsed: Duration,
207 ) -> Result<ProofCompressed, ProverClientError> {
208 match job_response {
209 ProofResponse::Async { job_id, .. } => {
210 info!("Proof job queued with ID: {}", job_id);
211 self.poll_for_result(&job_id, start_elapsed).await
212 }
213 }
214 }
215
216 fn handle_error_response(
217 &self,
218 response_text: &str,
219 ) -> Result<ProofCompressed, ProverClientError> {
220 if let Ok(error_response) = serde_json::from_str::<ErrorResponse>(response_text) {
221 error!(
222 "Prover server error: {} - {}",
223 error_response.code, error_response.message
224 );
225 Err(ProverClientError::ProverServerError(format!(
226 "Prover server error: {} - {}",
227 error_response.code, error_response.message
228 )))
229 } else {
230 error!("Prover server error: {}", response_text);
231 Err(ProverClientError::ProverServerError(format!(
232 "Prover server error: {}",
233 response_text
234 )))
235 }
236 }
237
238 fn should_retry(&self, error: &ProverClientError, retries: u32, elapsed: Duration) -> bool {
239 let error_str = error.to_string();
240
241 let is_constraint_error =
242 error_str.contains("constraint") || error_str.contains("is not satisfied");
243 if is_constraint_error {
244 return false;
245 }
246
247 let is_retryable_error = error_str.contains("job_not_found")
248 || error_str.contains("connection")
249 || error_str.contains("timeout")
250 || error_str.contains("503")
251 || error_str.contains("502")
252 || error_str.contains("500");
253 let should_retry =
254 retries < MAX_RETRIES && is_retryable_error && elapsed < self.max_wait_time;
255
256 trace!(
257 "Retry check: retries={}/{}, is_retryable_error={}, elapsed={:?}/{:?}, should_retry={}, error={}",
258 retries, MAX_RETRIES, is_retryable_error, elapsed, self.max_wait_time, should_retry, error_str
259 );
260
261 should_retry
262 }
263
/// Polls the job status endpoint until the job completes, fails, or the time
/// budget runs out.
///
/// `start_elapsed` is the time the caller had already spent before polling
/// started; it is added to the polling time when checking `max_wait_time`.
///
/// # Errors
///
/// Returns an error when the total time exceeds `max_wait_time`, when the
/// next sleep would exceed it, when the job reports failure, when the job is
/// not found, after `MAX_RETRIES` consecutive transient errors, or on any
/// other polling error.
async fn poll_for_result(
    &self,
    job_id: &str,
    start_elapsed: Duration,
) -> Result<ProofCompressed, ProverClientError> {
    let poll_start_time = Instant::now();
    let status_url = format!("{}/prove/status?job_id={}", self.server_address, job_id);

    info!("Starting to poll for job {} at URL: {}", job_id, status_url);

    let mut poll_count = 0;
    // Counts consecutive transient errors; reset after every successful poll.
    let mut transient_error_count = 0;

    loop {
        poll_count += 1;
        let poll_elapsed = poll_start_time.elapsed();
        // Time budget covers both the caller's earlier work and the polling.
        let total_elapsed = start_elapsed + poll_elapsed;

        if total_elapsed > self.max_wait_time {
            return Err(ProverClientError::ProverServerError(format!(
                "Job {} timed out after {:?} total (max: {:?}), polling time: {:?}, total polls: {}",
                job_id, total_elapsed, self.max_wait_time, poll_elapsed, poll_count
            )));
        }

        trace!(
            "Poll #{} for job {} at total elapsed time {:?} (polling: {:?})",
            poll_count,
            job_id,
            total_elapsed,
            poll_elapsed
        );

        match self.poll_job_status(&status_url, job_id, poll_count).await {
            Ok(response) => {
                transient_error_count = 0;

                // `Some(proof)` means the job completed; `None` means it is
                // still pending. A failed job is propagated by `?`.
                if let Some(proof) = self
                    .handle_job_status(response, job_id, total_elapsed, poll_count)
                    .await?
                {
                    return Ok(proof);
                }

                // Do not sleep if waking up would already be past the budget.
                if total_elapsed + self.polling_interval > self.max_wait_time {
                    warn!(
                        "Skipping polling interval due to max wait time constraint: total_elapsed={:?}, polling_interval={:?}, max_wait={:?}",
                        total_elapsed, self.polling_interval, self.max_wait_time
                    );
                    return Err(ProverClientError::ProverServerError(format!(
                        "Job {} polling stopped due to max wait time constraint",
                        job_id
                    )));
                }

                sleep(self.polling_interval).await;
            }
            // A missing job is returned to the caller, whose retry logic
            // treats "job_not_found" as retryable and submits a new request.
            Err(err) if self.is_job_not_found_error(&err) => {
                error!(
                    "Job {} not found during polling - will retry with new proof request at higher level: {}",
                    job_id, err
                );
                return Err(err);
            }
            Err(err) if self.is_transient_polling_error(&err) => {
                transient_error_count += 1;

                trace!(
                    "Transient polling error for job {}: attempt {}/{}, error: {}",
                    job_id,
                    transient_error_count,
                    MAX_RETRIES,
                    err
                );

                if transient_error_count >= MAX_RETRIES {
                    error!(
                        "Job {} polling failed after {} transient errors, giving up",
                        job_id, transient_error_count
                    );
                    return Err(err);
                }

                // Linear back-off based on the number of consecutive errors.
                let retry_delay =
                    Duration::from_secs(BASE_RETRY_DELAY_SECS * transient_error_count as u64);

                if total_elapsed + retry_delay > self.max_wait_time {
                    warn!(
                        "Skipping transient error retry due to max wait time constraint: total_elapsed={:?}, retry_delay={:?}, max_wait={:?}",
                        total_elapsed, retry_delay, self.max_wait_time
                    );
                    return Err(err);
                }

                warn!(
                    "Job {} transient error (attempt {}/{}), retrying after {:?}",
                    job_id, transient_error_count, MAX_RETRIES, retry_delay
                );
                sleep(retry_delay).await;
            }
            Err(err) => {
                debug!("Not retrying polling error for job {}: {}", job_id, err);
                return Err(err);
            }
        }
    }
}
371
372 async fn poll_job_status(
373 &self,
374 status_url: &str,
375 job_id: &str,
376 poll_count: u32,
377 ) -> Result<JobStatusResponse, ProverClientError> {
378 let mut request = self.client.get(status_url);
379
380 if let Some(api_key) = &self.api_key {
381 request = request.header("X-API-Key", api_key);
382 }
383
384 let response = request.send().await.map_err(|e| {
385 error!("Failed to send status request for job {}: {}", job_id, e);
386 ProverClientError::ProverServerError(format!("Failed to check job status: {}", e))
387 })?;
388
389 let status_code = response.status();
390 let response_text = response.text().await.unwrap_or_default();
391
392 trace!(
393 "Poll #{} for job {}: status={}, body_len={}",
394 poll_count,
395 job_id,
396 status_code,
397 response_text.len()
398 );
399
400 if !status_code.is_success() {
401 return Err(ProverClientError::ProverServerError(format!(
402 "HTTP error while polling for result: status={}, body={}",
403 status_code, response_text
404 )));
405 }
406
407 serde_json::from_str(&response_text).map_err(|e| {
408 error!(
409 "Failed to parse status response on poll #{} for job {}: error={}, body={}",
410 poll_count, job_id, e, response_text
411 );
412 ProverClientError::ProverServerError(format!(
413 "Failed to parse status response: {}, body: {}",
414 e, response_text
415 ))
416 })
417 }
418
419 async fn handle_job_status(
420 &self,
421 status_response: JobStatusResponse,
422 job_id: &str,
423 elapsed: Duration,
424 poll_count: u32,
425 ) -> Result<Option<ProofCompressed>, ProverClientError> {
426 trace!(
427 "Poll #{} for job {}: status='{}', message='{}'",
428 poll_count,
429 job_id,
430 status_response.status,
431 status_response.message.as_deref().unwrap_or("none")
432 );
433
434 match status_response.status.as_str() {
435 "completed" => {
436 info!(
437 "Job {} completed successfully after {:?} and {} polls",
438 job_id, elapsed, poll_count
439 );
440 self.extract_proof_from_result(status_response.result, job_id)
441 .map(Some)
442 }
443 "failed" => {
444 let error_msg = status_response
445 .message
446 .unwrap_or_else(|| "No error message provided".to_string());
447 error!(
448 "Job {} failed after {:?} and {} polls: {}",
449 job_id, elapsed, poll_count, error_msg
450 );
451 Err(ProverClientError::ProverServerError(format!(
452 "Proof job {} failed: {}",
453 job_id, error_msg
454 )))
455 }
456 "processing" | "queued" => {
457 trace!(
458 "Job {} still {} after {:?} (poll #{}), waiting {:?} before next check",
459 job_id,
460 status_response.status,
461 elapsed,
462 poll_count,
463 self.polling_interval
464 );
465 Ok(None)
466 }
467 _ => {
468 warn!(
469 "Job {} has unknown status '{}' on poll #{} after {:?}, continuing to poll",
470 job_id, status_response.status, poll_count, elapsed
471 );
472 Ok(None)
473 }
474 }
475 }
476
477 fn extract_proof_from_result(
478 &self,
479 result: Option<serde_json::Value>,
480 job_id: &str,
481 ) -> Result<ProofCompressed, ProverClientError> {
482 match result {
483 Some(result) => {
484 trace!("Job {} has result, parsing proof JSON", job_id);
485 let proof_json = serde_json::to_string(&result).map_err(|e| {
486 error!("Failed to serialize result for job {}: {}", job_id, e);
487 ProverClientError::ProverServerError("Cannot serialize result".to_string())
488 })?;
489 self.parse_proof_from_json(&proof_json)
490 }
491 None => {
492 error!("Job {} completed but has no result", job_id);
493 Err(ProverClientError::ProverServerError(
494 "No result in completed job status".to_string(),
495 ))
496 }
497 }
498 }
499
500 fn is_job_not_found_error(&self, error: &ProverClientError) -> bool {
501 error.to_string().contains("job_not_found")
502 }
503
504 fn is_transient_polling_error(&self, error: &ProverClientError) -> bool {
505 let error_str = error.to_string();
506 error_str.contains("503") || error_str.contains("502") || error_str.contains("500")
507 }
508
509 fn parse_proof_from_json(&self, json_str: &str) -> Result<ProofCompressed, ProverClientError> {
510 let proof_json = deserialize_gnark_proof_json(json_str).map_err(|e| {
511 ProverClientError::ProverServerError(format!("Failed to deserialize proof JSON: {}", e))
512 })?;
513
514 let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json);
515 let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c);
516
517 Ok(ProofCompressed {
518 a: proof_a,
519 b: proof_b,
520 c: proof_c,
521 })
522 }
523
524 pub async fn generate_batch_address_append_proof(
525 &self,
526 inputs: BatchAddressAppendInputs,
527 ) -> Result<(ProofCompressed, [u8; 32]), ProverClientError> {
528 let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>(&inputs.new_root)?;
529 let inputs_json = to_json(&inputs);
530 let proof = self.generate_proof(inputs_json).await?;
531 Ok((proof, new_root))
532 }
533
534 pub async fn generate_batch_append_proof(
535 &self,
536 circuit_inputs: BatchAppendsCircuitInputs,
537 ) -> Result<(ProofCompressed, [u8; 32]), ProverClientError> {
538 let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>(
539 &circuit_inputs.new_root.to_biguint().unwrap(),
540 )?;
541 let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string();
542 let proof = self.generate_proof(inputs_json).await?;
543 Ok((proof, new_root))
544 }
545
546 pub async fn generate_batch_update_proof(
547 &self,
548 circuit_inputs: BatchUpdateCircuitInputs,
549 ) -> Result<(ProofCompressed, [u8; 32]), ProverClientError> {
550 let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>(
551 &circuit_inputs.new_root.to_biguint().unwrap(),
552 )?;
553 let json_str = update_inputs_string(&circuit_inputs);
554 let proof = self.generate_proof(json_str).await?;
555 Ok((proof, new_root))
556 }
557}