use std::{
sync::{Arc, Mutex},
time::Duration,
};
use sp1_prover_types::{ArtifactClient, ArtifactId, ShardPermit, TaskStatus};
use tokio::task::AbortHandle;
use crate::worker::{ProofId, TaskId, TaskSubscriber, WorkerClient};
/// Handle that hands out shard permits and schedules their release once a worker task ends.
///
/// Every clone shares the same reference-counted inner state.
pub struct ProveShardGate<A, W>
where
    A: ArtifactClient,
    W: WorkerClient,
{
    inner: Arc<GateInner<A, W>>,
}
impl<A, W> std::fmt::Debug for ProveShardGate<A, W>
where
    A: ArtifactClient,
    W: WorkerClient,
{
    /// Formats the gate opaquely; none of the inner clients or handles are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ProveShardGate");
        builder.finish_non_exhaustive()
    }
}
/// State shared by every clone of `ProveShardGate`.
struct GateInner<A, W>
where
    A: ArtifactClient,
    W: WorkerClient,
{
    /// Client used to acquire shard permits.
    artifact_client: A,
    /// Subscription used to observe worker task status.
    subscriber: TaskSubscriber<W>,
    /// Abort handles for spawned release tasks.
    release_handles: Mutex<Vec<AbortHandle>>,
}
impl<A: ArtifactClient, W: WorkerClient> Drop for GateInner<A, W> {
    /// Closes the task subscriber and aborts every release task still registered.
    ///
    /// Aborting a release task drops its future, and with it the `ShardPermit` that task owns.
    fn drop(&mut self) {
        self.subscriber.close();
        // `&mut self` guarantees exclusive access, so no lock is needed. Recover the vector even
        // if the mutex was poisoned: skipping the aborts would leave detached tasks running that
        // still own their permits.
        let handles = self
            .release_handles
            .get_mut()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        for handle in handles.drain(..) {
            handle.abort();
        }
    }
}
impl<A: ArtifactClient, W: WorkerClient> Clone for ProveShardGate<A, W> {
fn clone(&self) -> Self {
Self { inner: self.inner.clone() }
}
}
impl<A: ArtifactClient, W: WorkerClient> ProveShardGate<A, W> {
    /// Creates a gate for `proof_id`, subscribing to its tasks through `worker_client`.
    ///
    /// # Errors
    ///
    /// Returns an error if `worker_client.subscriber(proof_id)` fails.
    pub async fn new(
        artifact_client: A,
        worker_client: W,
        proof_id: ProofId,
    ) -> anyhow::Result<Self> {
        // NOTE(review): `per_task()` presumably converts the subscription into one that can be
        // queried per task id (as `wait_task` does below) — confirm against `TaskSubscriber`.
        let subscriber = worker_client.subscriber(proof_id).await?.per_task();
        Ok(Self {
            inner: Arc::new(GateInner {
                artifact_client,
                subscriber,
                release_handles: Mutex::new(Vec::new()),
            }),
        })
    }

    /// Obtains a shard permit for `artifact` from the artifact client.
    pub async fn acquire(&self, artifact: &impl ArtifactId) -> ShardPermit {
        self.inner.artifact_client.acquire_shard_permit(artifact).await
    }

    /// Spawns a background task that owns `permit` until task `task_id` finishes, then drops it.
    ///
    /// The permit is dropped once `wait_task` reports `Succeeded` or `FailedFatal`, or when
    /// `wait_task` returns an error (logged as a warning). Any other status is re-queried after
    /// a short delay. If the gate's shared state is dropped first, the background task is
    /// aborted, which also drops the permit.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it uses [`tokio::spawn`].
    pub fn schedule_release(&self, task_id: TaskId, permit: ShardPermit) {
        /// Delay before asking again when the task is not yet in a terminal status.
        const POLL_INTERVAL: Duration = Duration::from_millis(500);

        let subscriber = self.inner.subscriber.clone();
        let handle = tokio::spawn(async move {
            // Moving the permit into the task ties its lifetime to this future.
            let _permit = permit;
            loop {
                match subscriber.wait_task(task_id.clone()).await {
                    Ok(TaskStatus::Succeeded | TaskStatus::FailedFatal) => break,
                    Ok(_) => tokio::time::sleep(POLL_INTERVAL).await,
                    Err(e) => {
                        tracing::warn!(
                            %task_id,
                            error = %e,
                            "wait_task failed, releasing permit"
                        );
                        break;
                    }
                }
            }
        });

        // Recover from poisoning rather than silently losing the handle: an unregistered task
        // would never be aborted when the gate is dropped.
        let mut handles = self
            .inner
            .release_handles
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Drop handles of tasks that already completed so the list does not grow without bound
        // for the lifetime of the gate.
        handles.retain(|h| !h.is_finished());
        handles.push(handle.abort_handle());
    }
}