use std::{collections::BTreeMap, sync::Arc};
use hashbrown::{HashMap, HashSet};
use mti::prelude::{MagicTypeIdExt, V7};
use sp1_prover_types::{ProofRequestStatus, TaskStatus, TaskType};
use tokio::sync::{mpsc, watch, RwLock};
use crate::worker::{
ProofId, RawTaskRequest, SubscriberBuilder, TaskId, TaskMetadata, WorkerClient,
};
/// Per-task in-memory byte-message channel used to relay opaque payloads
/// between `send_task_message` and at most one `subscribe_task_messages`
/// caller.
struct MessageChannelState {
    // Sender half; messages can be buffered here before anyone subscribes.
    tx: mpsc::UnboundedSender<Vec<u8>>,
    // Receiver half; `Some` until a subscriber takes it (it can be taken only once).
    rx: Option<mpsc::UnboundedReceiver<Vec<u8>>>,
}
/// Maps each task id to the watch channel broadcasting its latest [`TaskStatus`].
type LocalDb =
    Arc<RwLock<HashMap<TaskId, (watch::Sender<TaskStatus>, watch::Receiver<TaskStatus>)>>>;
/// Maps each proof id to the set of task ids submitted under it.
type ProofIndex = Arc<RwLock<HashMap<ProofId, HashSet<TaskId>>>>;
/// Receiver side of the per-task-type submission queues, handed to whoever
/// runs the actual workers (e.g. `test_utils::mock_worker_client`).
pub struct LocalWorkerClientChannels {
    // One bounded receiver per task type; tasks sent via `submit_task` arrive here.
    pub task_receivers: BTreeMap<TaskType, mpsc::Receiver<(TaskId, RawTaskRequest)>>,
}
/// Shared state behind [`LocalWorkerClient`]: per-task status channels, the
/// proof -> tasks index, submission queues, and per-task message channels.
pub struct LocalWorkerClientInner {
    // Status watch channels, keyed by task id.
    db: LocalDb,
    // Which tasks belong to which proof; used for bulk cleanup in `complete_proof`.
    proof_index: ProofIndex,
    // Sender side of the per-task-type queues (receivers are handed out via
    // `LocalWorkerClientChannels`).
    input_task_queues: HashMap<TaskType, mpsc::Sender<(TaskId, RawTaskRequest)>>,
    // Byte-message channels keyed by task id; entries removed on terminal status.
    task_channels: RwLock<HashMap<TaskId, MessageChannelState>>,
}
impl LocalWorkerClientInner {
    /// Mint a fresh unique [`TaskId`]: a UUIDv7-based type id with a
    /// `local_worker` prefix.
    fn create_id() -> TaskId {
        let id = "local_worker".create_type_id::<V7>();
        TaskId::new(id.to_string())
    }

    /// Construct the shared state together with the channel bundle holding the
    /// receiver side of every per-task-type submission queue.
    fn init() -> (Self, LocalWorkerClientChannels) {
        // Every known task type gets its own bounded (capacity 1) queue.
        let all_task_types = [
            TaskType::UnspecifiedTaskType,
            TaskType::Controller,
            TaskType::ProveShard,
            TaskType::RecursionReduce,
            TaskType::RecursionDeferred,
            TaskType::ShrinkWrap,
            TaskType::SetupVkey,
            TaskType::MarkerDeferredRecord,
            TaskType::PlonkWrap,
            TaskType::Groth16Wrap,
            TaskType::ExecuteOnly,
            TaskType::UtilVkeyMapChunk,
            TaskType::UtilVkeyMapController,
            TaskType::CoreExecute,
        ];
        let mut receivers = BTreeMap::new();
        let mut senders = HashMap::new();
        for kind in all_task_types {
            let (tx, rx) = mpsc::channel(1);
            senders.insert(kind, tx);
            receivers.insert(kind, rx);
        }
        let inner = Self {
            db: Arc::new(RwLock::new(HashMap::new())),
            proof_index: Arc::new(RwLock::new(HashMap::new())),
            input_task_queues: senders,
            task_channels: RwLock::new(HashMap::new()),
        };
        let channels = LocalWorkerClientChannels { task_receivers: receivers };
        (inner, channels)
    }
}
/// In-process [`WorkerClient`] implementation backed entirely by channels and
/// shared maps; cheap to clone because all state lives behind an [`Arc`].
pub struct LocalWorkerClient {
    inner: Arc<LocalWorkerClientInner>,
}
impl LocalWorkerClient {
#[must_use]
pub fn init() -> (Self, LocalWorkerClientChannels) {
let (inner, channels) = LocalWorkerClientInner::init();
(Self { inner: Arc::new(inner) }, channels)
}
pub async fn update_task_status(
&self,
task_id: TaskId,
status: TaskStatus,
) -> anyhow::Result<()> {
let (status_tx, _) = self
.inner
.db
.read()
.await
.get(&task_id)
.cloned()
.ok_or_else(|| anyhow::anyhow!("task does not exist"))?;
status_tx.send(status).map_err(|_| anyhow::anyhow!("failed to send status to task"))?;
if matches!(
status,
TaskStatus::Succeeded | TaskStatus::FailedFatal | TaskStatus::FailedRetryable
) {
self.inner.task_channels.write().await.remove(&task_id);
}
Ok(())
}
}
impl Clone for LocalWorkerClient {
fn clone(&self) -> Self {
Self { inner: self.inner.clone() }
}
}
impl WorkerClient for LocalWorkerClient {
    /// Register a new task: index it under its proof, create its status watch
    /// channel (starting at `Pending`), and push it onto the queue for `kind`.
    ///
    /// # Errors
    /// Fails when no queue exists for `kind` or the queue's receiver was dropped.
    async fn submit_task(&self, kind: TaskType, task: RawTaskRequest) -> anyhow::Result<TaskId> {
        tracing::debug!("submitting task of kind {kind:?}");
        let task_id = LocalWorkerClientInner::create_id();
        // Index the task under its proof so `complete_proof` can clean up later.
        self.inner
            .proof_index
            .write()
            .await
            .entry(task.context.proof_id.clone())
            .or_default()
            .insert(task_id.clone());
        let (tx, rx) = watch::channel(TaskStatus::Pending);
        self.inner.db.write().await.insert(task_id.clone(), (tx, rx));
        // Checked lookup instead of indexing: indexing would panic if a task
        // type were ever added without a queue being registered in `init`.
        let queue = self
            .inner
            .input_task_queues
            .get(&kind)
            .ok_or_else(|| anyhow::anyhow!("no queue registered for task kind {kind:?}"))?;
        queue
            .send((task_id.clone(), task))
            .await
            .map_err(|e| anyhow::anyhow!("failed to send task of kind {:?} to queue: {e}", kind))?;
        Ok(task_id)
    }

    /// Mark a single task as succeeded (status broadcast + channel cleanup).
    async fn complete_task(
        &self,
        _proof_id: ProofId,
        task_id: TaskId,
        _metadata: TaskMetadata,
    ) -> anyhow::Result<()> {
        self.update_task_status(task_id, TaskStatus::Succeeded).await
    }

    /// Tear down all state associated with `proof_id`: the proof index entry,
    /// every task's status channel, and any lingering message channels.
    ///
    /// # Errors
    /// Fails when the proof id is unknown.
    async fn complete_proof(
        &self,
        proof_id: ProofId,
        _task_id: Option<TaskId>,
        _status: ProofRequestStatus,
        _extra_data: impl Into<String> + Send,
    ) -> anyhow::Result<()> {
        let tasks = self
            .inner
            .proof_index
            .write()
            .await
            .remove(&proof_id)
            .ok_or_else(|| anyhow::anyhow!("proof does not exist for id {proof_id}"))?;
        // Acquire each lock once, not once per task.
        {
            let mut db = self.inner.db.write().await;
            for task_id in &tasks {
                db.remove(task_id);
            }
        }
        // Also drop message channels for tasks that never reached a terminal
        // status; `update_task_status` only cleans these up on terminal states,
        // so skipping this would leak entries.
        {
            let mut channels = self.inner.task_channels.write().await;
            for task_id in &tasks {
                channels.remove(task_id);
            }
        }
        Ok(())
    }

    /// Build a subscriber: task ids fed into the returned sender are watched
    /// in the background, and their terminal statuses are forwarded to the
    /// returned receiver.
    async fn subscriber(&self, _proof_id: ProofId) -> anyhow::Result<SubscriberBuilder<Self>> {
        let (subscriber_input_tx, mut subscriber_input_rx) = mpsc::unbounded_channel();
        let (subscriber_output_tx, subscriber_output_rx) = mpsc::unbounded_channel();
        tokio::task::spawn({
            let db = self.inner.db.clone();
            let output_tx = subscriber_output_tx.clone();
            async move {
                while let Some(id) = subscriber_input_rx.recv().await {
                    let db = db.clone();
                    let output_tx = output_tx.clone();
                    tokio::task::spawn(async move {
                        // The task may already have been removed by
                        // `complete_proof`; skip gracefully rather than
                        // panicking the watcher (the original `expect` here
                        // could abort on that race).
                        let Some((_, mut rx)) = db.read().await.get(&id).cloned() else {
                            tracing::warn!("subscriber requested unknown task {id}");
                            return;
                        };
                        // Re-deliver the current value in case the task already
                        // reached a terminal status before we subscribed.
                        rx.mark_changed();
                        while rx.changed().await.is_ok() {
                            let value = *rx.borrow();
                            if matches!(
                                value,
                                TaskStatus::FailedFatal
                                    | TaskStatus::FailedRetryable
                                    | TaskStatus::Succeeded
                            ) {
                                // Receiver may be gone; nothing to do then.
                                output_tx.send((id, value)).ok();
                                return;
                            }
                        }
                    });
                }
            }
        });
        Ok(SubscriberBuilder::new(self.clone(), subscriber_input_tx, subscriber_output_rx))
    }

    /// Take the receive half of a task's message channel, creating the channel
    /// if it does not exist yet.
    ///
    /// # Errors
    /// Fails when the channel's receiver was already taken by a prior call.
    async fn subscribe_task_messages(
        &self,
        task_id: &TaskId,
    ) -> anyhow::Result<mpsc::UnboundedReceiver<Vec<u8>>> {
        let mut channels = self.inner.task_channels.write().await;
        if let Some(state) = channels.get_mut(task_id) {
            // Channel exists (a message was sent before anyone subscribed):
            // hand over its buffered receiver, which can be taken only once.
            let rx = state
                .rx
                .take()
                .ok_or_else(|| anyhow::anyhow!("task channel already subscribed for {task_id}"))?;
            return Ok(rx);
        }
        let (tx, rx) = mpsc::unbounded_channel();
        channels.insert(task_id.clone(), MessageChannelState { tx, rx: None });
        Ok(rx)
    }

    /// Send an opaque payload to a task, buffering it in a fresh channel when
    /// no subscriber has shown up yet.
    ///
    /// # Errors
    /// Fails when the subscriber's receiver has been dropped.
    async fn send_task_message(&self, task_id: &TaskId, payload: Vec<u8>) -> anyhow::Result<()> {
        let mut channels = self.inner.task_channels.write().await;
        if let Some(state) = channels.get_mut(task_id) {
            state.tx.send(payload).map_err(|_| anyhow::anyhow!("task channel receiver dropped"))?;
        } else {
            let (tx, rx) = mpsc::unbounded_channel();
            tx.send(payload).expect("just-created channel cannot be closed");
            // Keep the receiver so a later subscriber gets the buffered message.
            channels.insert(task_id.clone(), MessageChannelState { tx, rx: Some(rx) });
        }
        Ok(())
    }
}
#[cfg(test)]
pub mod test_utils {
    use std::{ops::Range, time::Duration};
    use rand::Rng;
    use super::*;

    /// Build a [`LocalWorkerClient`] whose workers are mocked: each received
    /// task is completed successfully after a random delay drawn from the
    /// per-task-type range in `random_interval`.
    ///
    /// # Panics
    /// Panics if `random_interval` lacks an entry for any mocked task type.
    pub fn mock_worker_client(
        mut random_interval: HashMap<TaskType, Range<Duration>>,
    ) -> LocalWorkerClient {
        let (client, mut channels) = LocalWorkerClient::init();
        let mocked_types = [
            TaskType::Controller,
            TaskType::SetupVkey,
            TaskType::ProveShard,
            TaskType::MarkerDeferredRecord,
            TaskType::RecursionReduce,
            TaskType::RecursionDeferred,
            TaskType::ShrinkWrap,
            TaskType::PlonkWrap,
            TaskType::Groth16Wrap,
            TaskType::ExecuteOnly,
            TaskType::CoreExecute,
        ];
        for kind in mocked_types {
            let mut task_rx = channels.task_receivers.remove(&kind).unwrap();
            let delay_range = random_interval.remove(&kind).unwrap();
            let completer = client.clone();
            // One listener per task type; each task completes on its own spawn.
            tokio::task::spawn(async move {
                while let Some((task_id, request)) = task_rx.recv().await {
                    let completer = completer.clone();
                    let delay_range = delay_range.clone();
                    tokio::spawn(async move {
                        // Sample inside a block so the non-Send thread-local
                        // RNG is dropped before the first await point.
                        let delay = { rand::thread_rng().gen_range(delay_range) };
                        tokio::time::sleep(delay).await;
                        completer
                            .complete_task(
                                request.context.proof_id,
                                task_id,
                                TaskMetadata { gpu_ms: None },
                            )
                            .await
                            .unwrap();
                    });
                }
            });
        }
        client
    }
}