use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use miden_crypto::utils::Serializable;
use miden_protocol::block::{BlockNumber, BlockProof};
use miden_remote_prover_client::RemoteProverClientError;
use thiserror::Error;
use tokio::sync::watch;
use tokio::task::{JoinHandle, JoinSet};
use tracing::{error, info, instrument};
use crate::COMPONENT;
use crate::blocks::BlockStore;
use crate::db::Db;
use crate::errors::{DatabaseError, ProofSchedulerError};
use crate::server::block_prover_client::{BlockProver, StoreProverError};
/// Maximum wall-clock time a single block-proving attempt may take before it
/// is aborted and retried by `prove_block`.
//
// NOTE: written as `from_secs(4 * 60)` rather than `Duration::from_mins(4)`;
// `from_mins` was gated behind `duration_constructors_lite` until very
// recently, so this form keeps the crate buildable on older stable toolchains.
const BLOCK_PROVE_TIMEOUT: Duration = Duration::from_secs(4 * 60);

/// Default cap on the number of block proofs generated concurrently.
pub const DEFAULT_MAX_CONCURRENT_PROOFS: NonZeroUsize = NonZeroUsize::new(8).unwrap();
/// Set of in-flight block-proving tasks.
///
/// Thin wrapper over [`JoinSet`] whose [`Self::join_next`] stays pending
/// forever when the set is empty, making it safe to race inside a `select!`
/// loop without busy-spinning.
struct ProofTaskJoinSet(JoinSet<anyhow::Result<()>>);

impl ProofTaskJoinSet {
    /// Creates an empty task set.
    fn new() -> Self {
        Self(JoinSet::new())
    }

    /// Number of proving tasks currently in flight.
    fn len(&self) -> usize {
        self.0.len()
    }

    /// Queues a task that proves `block_num` using the given services.
    ///
    /// Each task holds its own `Arc` handles so it stays valid for as long as
    /// it runs, independent of the caller.
    fn spawn(
        &mut self,
        db: &Arc<Db>,
        block_prover: &Arc<BlockProver>,
        block_store: &Arc<BlockStore>,
        block_num: BlockNumber,
    ) {
        let database = Arc::clone(db);
        let prover = Arc::clone(block_prover);
        let store = Arc::clone(block_store);
        self.0.spawn(async move {
            prove_block(&database, &prover, &store, block_num).await
        });
    }

    /// Waits for the next proving task to complete.
    ///
    /// On an empty set this future never resolves (instead of yielding
    /// `None`), so a `select!` caller simply waits on its other branches.
    /// A panicked task surfaces as an error rather than aborting the caller.
    async fn join_next(&mut self) -> anyhow::Result<()> {
        if self.0.is_empty() {
            return std::future::pending().await;
        }
        let joined = self.0.join_next().await.expect("join set is not empty");
        // Outer `?` unwraps the join result (panic => error); the inner
        // `anyhow::Result<()>` is the task's own outcome.
        joined.context("proving task panicked")?
    }
}
/// Launches the proof scheduler on the tokio runtime.
///
/// The returned handle resolves when the scheduler exits: `Ok(())` once the
/// chain-tip channel closes, or the first unrecoverable error otherwise.
pub fn spawn(
    db: Arc<Db>,
    block_prover: Arc<BlockProver>,
    block_store: Arc<BlockStore>,
    chain_tip_rx: watch::Receiver<BlockNumber>,
    max_concurrent_proofs: NonZeroUsize,
) -> JoinHandle<anyhow::Result<()>> {
    let scheduler = run(db, block_prover, block_store, chain_tip_rx, max_concurrent_proofs);
    tokio::spawn(scheduler)
}
/// Scheduler loop: keeps up to `max_concurrent_proofs` proving tasks in
/// flight, topping the set back up whenever a task finishes or the chain tip
/// advances.
///
/// Returns `Ok(())` when the chain-tip channel closes (shutdown), or the
/// first error surfaced by the database or a proving task.
async fn run(
    db: Arc<Db>,
    block_prover: Arc<BlockProver>,
    block_store: Arc<BlockStore>,
    mut chain_tip_rx: watch::Receiver<BlockNumber>,
    max_concurrent_proofs: NonZeroUsize,
) -> anyhow::Result<()> {
    info!(target: COMPONENT, "Proof scheduler started");
    let mut join_set = ProofTaskJoinSet::new();
    // Resume after the last block already proven in sequence so a restart
    // does not re-schedule completed work.
    let mut highest_scheduled = db.select_latest_proven_in_sequence_block_num().await?;
    loop {
        // Fill any free slots with the next unproven blocks past the highest
        // block we have already scheduled.
        let capacity = max_concurrent_proofs.get() - join_set.len();
        if capacity > 0 {
            let unproven = db.select_unproven_blocks(highest_scheduled, capacity).await?;
            // Advance the scheduling cursor before spawning so the same
            // blocks are not fetched again on the next iteration.
            if let Some(&last) = unproven.last() {
                highest_scheduled = last;
            }
            for block_num in unproven {
                join_set.spawn(&db, &block_prover, &block_store, block_num);
            }
        }
        // Wake when either a proving task completes (freeing a slot) or the
        // chain tip moves (possibly adding work). `join_next` pends forever
        // on an empty set, so this select never busy-loops.
        tokio::select! {
            result = join_set.join_next() => {
                // Propagate task failure (including panics) to the caller.
                result?;
            },
            result = chain_tip_rx.changed() => {
                if result.is_err() {
                    // All senders dropped: the node is shutting down.
                    info!(target: COMPONENT, "Chain tip channel closed, proof scheduler exiting");
                    return Ok(());
                }
            },
        }
    }
}
#[instrument(target = COMPONENT, name = "prove_block", skip_all, fields(block.number=block_num.as_u32()), err)]
async fn prove_block(
db: &Db,
block_prover: &BlockProver,
block_store: &BlockStore,
block_num: BlockNumber,
) -> anyhow::Result<()> {
const MAX_RETRIES: u32 = 10;
for _ in 0..MAX_RETRIES {
match tokio::time::timeout(
BLOCK_PROVE_TIMEOUT,
generate_block_proof(db, block_prover, block_num),
)
.await
{
Ok(Ok(proof)) => {
block_store.save_proof(block_num, &proof.to_bytes()).await?;
let advanced_in_sequence = db.mark_proven_and_advance_sequence(block_num).await?;
if let Some(&last) = advanced_in_sequence.last() {
info!(
target = COMPONENT,
block.number = %block_num,
proven_in_sequence_tip = %last,
"Block proven and in-sequence advanced",
);
} else {
info!(target = COMPONENT, block.number = %block_num, "Block proven");
}
return Ok(());
},
Ok(Err(ProveBlockError::Fatal(err))) => Err(err).context("fatal error")?,
Ok(Err(ProveBlockError::Transient(err))) => {
error!(target = COMPONENT, block.number = %block_num, err = ?err, "transient error proving block, retrying");
},
Err(elapsed) => {
error!(target = COMPONENT, block.number = %block_num, %elapsed, "block proving timed out, retrying");
},
}
}
anyhow::bail!("maximum retries ({MAX_RETRIES}) exceeded");
}
/// Loads the proving inputs for `block_num` from the database and asks the
/// prover for a proof.
///
/// Missing inputs are a [`ProveBlockError::Fatal`]: they will not appear on a
/// retry. Database and prover failures are classified by the
/// `ProveBlockError::from_*` constructors.
#[instrument(target = COMPONENT, name = "prove_block.generate", skip_all, fields(block.number=block_num.as_u32()), err)]
async fn generate_block_proof(
    db: &Db,
    block_prover: &BlockProver,
    block_num: BlockNumber,
) -> Result<BlockProof, ProveBlockError> {
    let inputs = db
        .select_block_proving_inputs(block_num)
        .await
        .map_err(ProveBlockError::from_db_error)?;
    let request = inputs.ok_or_else(|| {
        ProveBlockError::Fatal(ProofSchedulerError::MissingProvingInputs(block_num))
    })?;

    block_prover
        .prove(request.tx_batches, request.block_inputs, &request.block_header)
        .await
        .map_err(ProveBlockError::from_prover_error)
}
/// Classification of proving failures, driving the retry loop in
/// `prove_block`.
#[derive(Debug, Error)]
enum ProveBlockError {
    /// Retrying cannot succeed (e.g. corrupt stored data, invalid prover
    /// endpoint); the proving task aborts immediately.
    #[error("fatal error")]
    Fatal(#[source] ProofSchedulerError),
    /// May succeed on a later attempt; `prove_block` retries these.
    #[error("transient error: {0}")]
    Transient(Box<dyn std::error::Error + Send + Sync + 'static>),
}
impl ProveBlockError {
    /// Classifies a database error.
    ///
    /// Deserialization failures mean the stored data itself is bad and are
    /// fatal; every other database error is assumed to be transient.
    fn from_db_error(err: DatabaseError) -> Self {
        match err {
            DatabaseError::DeserializationError(inner) => {
                Self::Fatal(ProofSchedulerError::DeserializationFailed(inner))
            },
            other => Self::Transient(other.into()),
        }
    }

    /// Classifies a prover error.
    ///
    /// A misconfigured remote-prover endpoint cannot be fixed by retrying and
    /// is fatal; every other prover error is assumed to be transient.
    fn from_prover_error(err: StoreProverError) -> Self {
        match err {
            StoreProverError::RemoteProvingFailed(
                RemoteProverClientError::InvalidEndpoint(endpoint),
            ) => Self::Fatal(ProofSchedulerError::InvalidProverEndpoint(endpoint)),
            other => Self::Transient(other.into()),
        }
    }
}