use crate::core::ctx::Ctx;
use crate::core::engine::Engine;
use crate::types::{CapToken, TaskId};
use crate::worker::Worker;
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;
/// Errors returned by [`SpawnerAdapter::spawn`].
#[derive(Debug, Error)]
pub enum SpawnError {
    /// No worker function is registered under the requested agent name; the
    /// payload is that agent name.
    #[error("worker not registered: {0}")]
    NotRegistered(String),
    /// A middleware layer refused the spawn; the payload is its reason.
    /// NOTE(review): not constructed in this file — presumably produced by
    /// middleware-aware spawners elsewhere; confirm.
    #[error("spawn rejected by middleware: {0}")]
    RejectedByMiddleware(String),
    /// Any other failure while preparing the worker (for `InProcSpawner`,
    /// a failed `fetch_prompt` call); the payload is a description.
    #[error("internal: {0}")]
    Internal(String),
}
/// Failure outcome of a worker function run.
#[derive(Debug, Error)]
pub enum WorkerError {
    /// The worker function reported an error; the payload is its message.
    /// NOTE(review): not constructed in this file — presumably returned by
    /// registered worker functions; confirm.
    #[error("worker fn returned error: {0}")]
    Failed(String),
    /// The worker's cancellation token fired before the worker finished.
    #[error("cancelled")]
    Cancelled,
}
/// Value produced by a worker function that ran to completion.
///
/// `InProcSpawner` forwards both fields unchanged as an inline final output
/// event. A result with `ok == false` is still reported as a successful
/// (non-error) completion on the worker's completion channel.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WorkerResult {
    /// The worker's output payload.
    pub value: Value,
    /// Worker-reported success flag, copied into the final output event.
    pub ok: bool,
}
/// Strategy for turning a task attempt into a running [`Worker`].
#[async_trait]
pub trait SpawnerAdapter: Send + Sync {
    /// Starts a worker for attempt number `attempt` of task `task_id` and
    /// returns it as a boxed [`Worker`].
    ///
    /// `token` is the attempt's `CapToken`; `InProcSpawner` passes it to
    /// `Engine::fetch_prompt` and `Engine::submit_output`.
    ///
    /// # Errors
    ///
    /// Returns a [`SpawnError`] when the worker cannot be started; which
    /// variants can occur depends on the implementation.
    async fn spawn(
        &self,
        engine: &Engine,
        ctx: &Ctx,
        task_id: TaskId,
        attempt: u32,
        token: CapToken,
    ) -> Result<Box<dyn Worker>, SpawnError>;
}
/// Everything a worker function receives for one task attempt.
#[derive(Clone)]
pub struct WorkerInvocation {
    /// Capability token of the attempt.
    pub token: CapToken,
    /// Task being worked on.
    pub task_id: TaskId,
    /// Attempt number of this run.
    pub attempt: u32,
    /// Agent name the worker function was registered under.
    pub agent: String,
    /// Prompt text fetched from the engine before spawning.
    pub prompt: String,
    /// Optional sink for emitting output events; `InProcSpawner` always sets
    /// an engine-backed sink.
    pub sink: Option<std::sync::Arc<dyn crate::worker::output::OutputSink>>,
    /// Optional cancellation signal; `InProcSpawner` sets a clone of the
    /// token held by the worker's join handler.
    pub cancel_token: Option<tokio_util::sync::CancellationToken>,
}
impl std::fmt::Debug for WorkerInvocation {
    /// Formats every field of the invocation. The `sink` and `cancel_token`
    /// handles are not printed; they are rendered as `Some("<OutputSink>")` /
    /// `Some("<CancellationToken>")` when present and `None` otherwise.
    ///
    /// NOTE(review): `token` is printed through `CapToken`'s own `Debug`
    /// impl — confirm that impl does not expose secret material in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructure: adding a field forces this impl to be revisited.
        let Self {
            token,
            task_id,
            attempt,
            agent,
            prompt,
            sink,
            cancel_token,
        } = self;
        let sink_marker = sink.is_some().then_some("<OutputSink>");
        let cancel_marker = cancel_token.is_some().then_some("<CancellationToken>");

        let mut builder = f.debug_struct("WorkerInvocation");
        builder
            .field("token", token)
            .field("task_id", task_id)
            .field("attempt", attempt)
            .field("agent", agent)
            .field("prompt", prompt)
            .field("sink", &sink_marker)
            .field("cancel_token", &cancel_marker);
        builder.finish()
    }
}
/// Type-erased, shareable worker function: takes a [`WorkerInvocation`] and
/// returns a boxed `Send` future resolving to the worker's result.
pub type WorkerFn = Arc<
    dyn Fn(
            WorkerInvocation,
        ) -> Pin<Box<dyn Future<Output = Result<WorkerResult, WorkerError>> + Send>>
        + Send
        + Sync,
>;
/// Spawner that runs registered worker functions as tokio tasks inside the
/// current process.
///
/// `W` is the [`Worker`] type each spawned task is wrapped in (built with
/// `W::from(WorkerJoinHandler)`); it defaults to `MiddlewareWorker`.
pub struct InProcSpawner<W = crate::worker::MiddlewareWorker> {
    /// Worker functions keyed by agent name; looked up with `Ctx::agent`.
    pub registry: HashMap<String, WorkerFn>,
    /// The spawner never stores a `W`, it only constructs them. `fn() -> W`
    /// keeps covariance in `W` without making the spawner's `Send`/`Sync`
    /// or drop-check behaviour depend on `W`, unlike `PhantomData<W>`.
    _phantom: std::marker::PhantomData<fn() -> W>,
}
impl InProcSpawner {
    /// Creates an empty spawner that wraps spawned tasks in the default
    /// worker type (`MiddlewareWorker`).
    pub fn new() -> Self {
        Self {
            registry: HashMap::new(),
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<W> InProcSpawner<W> {
    /// Registers `f` as the worker function for `agent` and returns `self`
    /// for chaining.
    ///
    /// Available for every worker type `W`, so spawners built with
    /// `InProcSpawner::typed()` can register functions too. Registering the
    /// same agent name again replaces the previous function.
    pub fn register<F, Fut>(&mut self, agent: impl Into<String>, f: F) -> &mut Self
    where
        F: Fn(WorkerInvocation) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<WorkerResult, WorkerError>> + Send + 'static,
    {
        // `F: Fn` can be called through a shared borrow and `Fut` is `'static`,
        // so the closure can own `f` directly; no inner `Arc`/clone per call.
        let wrapped: WorkerFn = Arc::new(move |inv| Box::pin(f(inv)));
        self.registry.insert(agent.into(), wrapped);
        self
    }
}
impl<W> InProcSpawner<W>
where
    W: Worker + From<crate::worker::WorkerJoinHandler> + Send + Sync + 'static,
{
    /// Creates an empty spawner whose spawned tasks are wrapped in `W` rather
    /// than the default worker type.
    ///
    /// The bound on `W` is the same one the `SpawnerAdapter` impl requires,
    /// so only worker types this spawner can actually build are accepted.
    pub fn typed() -> Self {
        let registry = HashMap::default();
        Self {
            registry,
            _phantom: Default::default(),
        }
    }
}
impl Default for InProcSpawner {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl<W: Worker + From<crate::worker::WorkerJoinHandler> + Send + Sync + 'static> SpawnerAdapter
for InProcSpawner<W>
{
async fn spawn(
&self,
engine: &Engine,
ctx: &Ctx,
task_id: TaskId,
attempt: u32,
token: CapToken,
) -> Result<Box<dyn Worker>, SpawnError> {
let f = self
.registry
.get(&ctx.agent)
.cloned()
.ok_or_else(|| SpawnError::NotRegistered(ctx.agent.clone()))?;
let prompt = engine
.fetch_prompt(&token, &task_id)
.await
.map_err(|e| SpawnError::Internal(format!("fetch_prompt: {e}")))?;
let (tx, rx) = tokio::sync::oneshot::channel();
let cancel = tokio_util::sync::CancellationToken::new();
let cancel_inner = cancel.clone();
let worker_id = crate::types::WorkerId::new();
let engine_for_emit = engine.clone();
let token_for_emit = token.clone();
let task_id_for_emit = task_id.clone();
let sink = std::sync::Arc::new(crate::worker::output::EngineSink::new(
engine.clone(),
token.clone(),
task_id.clone(),
attempt,
)) as std::sync::Arc<dyn crate::worker::output::OutputSink>;
let inv = WorkerInvocation {
token,
task_id,
attempt,
agent: ctx.agent.clone(),
prompt,
sink: Some(sink),
cancel_token: Some(cancel_inner.clone()),
};
tokio::spawn(async move {
let result = tokio::select! {
r = f(inv) => r,
_ = cancel_inner.cancelled() => Err(WorkerError::Cancelled),
};
if let Ok(wr) = &result {
let ev = crate::worker::output::OutputEvent::Final {
content: crate::worker::output::ContentRef::Inline {
value: wr.value.clone(),
},
ok: wr.ok,
};
let _ = engine_for_emit
.submit_output(&token_for_emit, &task_id_for_emit, attempt, ev)
.await;
}
let signal: Result<(), WorkerError> = result.map(|_| ());
let _ = tx.send(signal);
});
let handler = crate::worker::WorkerJoinHandler {
worker_id,
cancel,
completion: rx,
};
Ok(Box::new(W::from(handler)))
}
}