use std::num::NonZeroUsize;
use miden_node_utils::tracing::OpenTelemetrySpanExt;
use tokio::sync::{Mutex, MutexGuard, SemaphorePermit};
use tracing::instrument;
use crate::server::proof_kind::ProofKind;
use crate::server::prover::Prover;
use crate::{COMPONENT, generated as proto};
pub struct ProverService {
permits: tokio::sync::Semaphore,
prover: tokio::sync::Mutex<Prover>,
kind: ProofKind,
}
impl ProverService {
    /// Creates a service that proves `kind` proofs.
    ///
    /// The service admits at most `capacity` requests at any one time, counting both queued
    /// and in-progress requests.
    ///
    /// `capacity` is clamped to [`tokio::sync::Semaphore::MAX_PERMITS`]. Larger values would
    /// otherwise make `Semaphore::new` panic.
    pub fn with_capacity(kind: ProofKind, capacity: NonZeroUsize) -> Self {
        // `Semaphore::new` panics above `MAX_PERMITS`; clamp instead of crashing on an
        // oversized configuration value.
        let slots = capacity.get().min(tokio::sync::Semaphore::MAX_PERMITS);
        let permits = tokio::sync::Semaphore::new(slots);
        let prover = Mutex::new(Prover::new(kind));
        Self { permits, prover, kind }
    }

    /// Returns `true` if `kind` is the proof kind this service was created for.
    fn is_supported(&self, kind: ProofKind) -> bool {
        self.kind == kind
    }

    /// Reserves a request slot without waiting.
    ///
    /// The slot is released when the returned permit is dropped.
    ///
    /// # Errors
    ///
    /// Returns `ResourceExhausted` if every slot is already taken.
    #[instrument(target=COMPONENT, skip_all, err)]
    fn acquire_permit(&self) -> Result<SemaphorePermit<'_>, tonic::Status> {
        self.permits
            .try_acquire()
            .map_err(|_| tonic::Status::resource_exhausted("proof queue is full"))
    }

    /// Waits until the prover is free and returns a guard granting exclusive access to it.
    #[instrument(target=COMPONENT, skip_all)]
    async fn acquire_prover(&self) -> MutexGuard<'_, Prover> {
        self.prover.lock().await
    }
}
#[async_trait::async_trait]
impl proto::api_server::Api for ProverService {
    /// Handles a single proof request.
    ///
    /// The handler first tags the current tracing span with the caller-supplied
    /// `x-request-id` (or `"unknown"`) and with the requested proof kind. It then rejects
    /// kinds this service does not handle. Finally it reserves a request slot, waits for
    /// exclusive use of the prover, and delegates the work to it.
    ///
    /// # Errors
    ///
    /// - `InvalidArgument` if `proof_type` is not a recognised enum value, or names a kind
    ///   this service is not configured for.
    /// - `ResourceExhausted` if no request slot is free.
    /// - Any error returned by the prover itself.
    async fn prove(
        &self,
        request: tonic::Request<proto::ProofRequest>,
    ) -> Result<tonic::Response<proto::Proof>, tonic::Status> {
        let request_id = match request.metadata().get("x-request-id") {
            Some(header) => header.to_str().unwrap_or("unknown"),
            None => "unknown",
        };
        tracing::Span::current().set_attribute("request.id", request_id);

        let payload = request.into_inner();

        // The generated getter maps unrecognised raw values to a fallback variant. A
        // mismatch after casting back to i32 therefore means the wire value was unknown.
        let decoded = payload.proof_type();
        if decoded as i32 != payload.proof_type {
            return Err(tonic::Status::invalid_argument("unknown proof_type value"));
        }

        let proof_kind = ProofKind::from(decoded);
        tracing::Span::current().set_attribute("request.kind", proof_kind);
        if !self.is_supported(proof_kind) {
            return Err(tonic::Status::invalid_argument("unsupported proof type"));
        }

        // Keep the permit bound until the end of the call so the slot stays reserved while
        // waiting for and running the prover.
        let _permit = self.acquire_permit()?;
        let prover = self.acquire_prover().await;
        prover.prove(payload).await.map(tonic::Response::new)
    }
}