pub mod input_inject;
pub mod lua_layer;
pub mod project_name_alias;
pub mod resolver;
pub mod sink;
use crate::core::ctx::{Ctx, OperatorKind};
use crate::core::engine::Engine;
use crate::core::state::Event;
use crate::types::{CapToken, TaskId};
use crate::worker::adapter::{SpawnError, SpawnerAdapter};
use crate::worker::output::{ContentRef, OutputEvent};
use crate::worker::{wrap_join, MiddlewareWorker, Worker, WorkerJoinHandler};
use async_trait::async_trait;
use serde_json::Value;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::broadcast;
/// Looks up the last `Final` output event recorded for `task_id` / `attempt`.
///
/// Returns the event's payload together with its `ok` flag, or `None` when
/// the output tail contains no usable `Final` event. Inline content is cloned
/// as-is; a file reference is reported as the JSON object
/// `{"file_ref": <path>}`, with the path converted lossily to UTF-8.
async fn pull_final_value_ok(
    engine: &Engine,
    task_id: &TaskId,
    attempt: u32,
) -> Option<(Value, bool)> {
    let events = engine.output_tail(task_id, attempt).await;
    // Scan the tail back-to-front so the last matching event wins.
    for event in events.iter().rev() {
        match event {
            OutputEvent::Final {
                content: ContentRef::Inline { value },
                ok,
            } => return Some((value.clone(), *ok)),
            OutputEvent::Final {
                content: ContentRef::FileRef { path, .. },
                ok,
            } => {
                let wrapped = serde_json::json!({"file_ref": path.to_string_lossy()});
                return Some((wrapped, *ok));
            }
            // Anything else is skipped and the scan continues.
            _ => {}
        }
    }
    None
}
/// A decorator over [`SpawnerAdapter`]: given an adapter, produces a new
/// adapter that wraps it.
pub trait SpawnerLayer: Send + Sync + 'static {
/// Returns an adapter built around `inner`.
fn wrap(&self, inner: Arc<dyn SpawnerAdapter>) -> Arc<dyn SpawnerAdapter>;
}
/// Builder that accumulates [`SpawnerLayer`]s around a base [`SpawnerAdapter`].
pub struct SpawnerStack {
// The adapter assembled so far: the base adapter, or the result of the most
// recent `wrap` call.
inner: Arc<dyn SpawnerAdapter>,
}
impl SpawnerStack {
pub fn new(base: Arc<dyn SpawnerAdapter>) -> Self {
Self { inner: base }
}
pub fn layer<L: SpawnerLayer>(mut self, layer: L) -> Self {
self.inner = layer.wrap(self.inner);
self
}
pub fn layer_dyn(mut self, layer: Arc<dyn SpawnerLayer>) -> Self {
self.inner = layer.wrap(self.inner);
self
}
pub fn build(self) -> Arc<dyn SpawnerAdapter> {
self.inner
}
}
/// Shared, thread-safe closure that produces a [`SpawnerLayer`] from a
/// reference to the engine.
pub type LayerFactory =
Arc<dyn Fn(&crate::core::engine::Engine) -> Arc<dyn SpawnerLayer> + Send + Sync + 'static>;
/// Collection of [`LayerFactory`] values: an ordered list of base factories
/// plus factories that can be looked up by a string key.
#[derive(Default, Clone)]
pub struct LayerRegistry {
// Factories added through `with_base`, in the order they were added.
base: Vec<LayerFactory>,
// Factories added through `with_hint`, keyed by the hint string; inserting an
// existing key replaces the previous factory.
hints: std::collections::HashMap<String, LayerFactory>,
}
impl LayerRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn with_base<F>(mut self, factory: F) -> Self
where
F: Fn(&crate::core::engine::Engine) -> Arc<dyn SpawnerLayer> + Send + Sync + 'static,
{
self.base.push(Arc::new(factory));
self
}
pub fn with_hint<F>(mut self, key: impl Into<String>, factory: F) -> Self
where
F: Fn(&crate::core::engine::Engine) -> Arc<dyn SpawnerLayer> + Send + Sync + 'static,
{
self.hints.insert(key.into(), Arc::new(factory));
self
}
pub fn base_factories(&self) -> &[LayerFactory] {
&self.base
}
pub fn lookup_hint(&self, key: &str) -> Option<&LayerFactory> {
self.hints.get(key)
}
}
/// Spawner layer whose wrapped adapter publishes `Event::TaskAttemptStarted`
/// on the event channel before delegating each spawn.
pub struct AuditMiddleware {
/// Broadcast sender the audit event is published on.
pub event_tx: broadcast::Sender<Event>,
}
impl AuditMiddleware {
pub fn new(event_tx: broadcast::Sender<Event>) -> Self {
Self { event_tx }
}
}
impl SpawnerLayer for AuditMiddleware {
    /// Wraps `inner` in an adapter that holds its own clone of the event sender.
    fn wrap(&self, inner: Arc<dyn SpawnerAdapter>) -> Arc<dyn SpawnerAdapter> {
        let event_tx = self.event_tx.clone();
        let wrapped = AuditWrapped { inner, event_tx };
        Arc::new(wrapped)
    }
}
// Adapter produced by `AuditMiddleware::wrap`.
struct AuditWrapped {
// Adapter the spawn call is forwarded to.
inner: Arc<dyn SpawnerAdapter>,
// Channel on which the "attempt started" event is sent.
event_tx: broadcast::Sender<Event>,
}
#[async_trait]
impl SpawnerAdapter for AuditWrapped {
    /// Publishes `Event::TaskAttemptStarted` for the attempt, then forwards the
    /// spawn to the wrapped adapter and returns its result unchanged.
    async fn spawn(
        &self,
        engine: &Engine,
        ctx: &Ctx,
        task_id: TaskId,
        attempt: u32,
        token: CapToken,
    ) -> Result<Box<dyn Worker>, SpawnError> {
        let started = Event::TaskAttemptStarted {
            task_id: task_id.clone(),
            attempt,
        };
        // The send result is discarded on purpose: a failed broadcast must not
        // prevent the spawn itself.
        let _ = self.event_tx.send(started);

        let spawned = self.inner.spawn(engine, ctx, task_id, attempt, token).await;
        spawned
    }
}
/// Spawner layer whose wrapped adapter runs the operator's spawn hook around
/// attempts of `MainAi` and `Composite` operators.
pub struct MainAIMiddleware;
impl MainAIMiddleware {
pub fn new() -> Self {
Self
}
}
impl Default for MainAIMiddleware {
fn default() -> Self {
Self::new()
}
}
impl SpawnerLayer for MainAIMiddleware {
    /// Wraps `inner` in the hook-running adapter.
    fn wrap(&self, inner: Arc<dyn SpawnerAdapter>) -> Arc<dyn SpawnerAdapter> {
        let wrapped = MainAIWrapped { inner };
        Arc::new(wrapped)
    }
}
// Adapter produced by `MainAIMiddleware::wrap`.
struct MainAIWrapped {
// Adapter the spawn call is forwarded to.
inner: Arc<dyn SpawnerAdapter>,
}
#[async_trait]
impl SpawnerAdapter for MainAIWrapped {
    /// Spawns through the wrapped adapter, running the operator's spawn hook
    /// around the attempt when the operator kind is `MainAi` or `Composite`.
    ///
    /// For any other kind, or when no hook is configured, the wrapped adapter's
    /// worker is returned untouched. A failing `before` hook aborts the spawn
    /// with `SpawnError::RejectedByMiddleware`. The `after` hook receives the
    /// last final output value (`Value::Null` if there is none) on success, or
    /// the error rendered as a string on failure; its own result is ignored and
    /// the original join signal is passed through.
    async fn spawn(
        &self,
        engine: &Engine,
        ctx: &Ctx,
        task_id: TaskId,
        attempt: u32,
        token: CapToken,
    ) -> Result<Box<dyn Worker>, SpawnError> {
        let hooked_kind = matches!(
            ctx.operator.kind,
            OperatorKind::MainAi | OperatorKind::Composite
        );
        // Operators of other kinds never see the hook, even if one is set.
        let hook_ref = if hooked_kind {
            ctx.operator.spawn_hook.as_ref()
        } else {
            None
        };

        if let Some(pre) = hook_ref {
            pre.before(ctx)
                .await
                .map_err(SpawnError::RejectedByMiddleware)?;
        }

        let worker = self
            .inner
            .spawn(engine, ctx, task_id.clone(), attempt, token)
            .await?;

        let Some(hook) = hook_ref.cloned() else {
            return Ok(worker);
        };

        // Owned copies for the join callback, which cannot borrow from this call.
        let owned_ctx = ctx.clone();
        let owned_engine = engine.clone();
        let owned_task = task_id;
        let on_join = move |signal| {
            let hook = hook.clone();
            let hook_ctx = owned_ctx.clone();
            let hook_engine = owned_engine.clone();
            let hook_task = owned_task.clone();
            async move {
                let payload = match &signal {
                    Ok(()) => {
                        match pull_final_value_ok(&hook_engine, &hook_task, attempt).await {
                            Some((value, _)) => value,
                            None => Value::Null,
                        }
                    }
                    Err(err) => Value::String(err.to_string()),
                };
                let _ = hook.after(&hook_ctx, &payload).await;
                signal
            }
        };
        Ok(wrap_join(worker, on_join))
    }
}
/// Spawner layer whose wrapped adapter consults the operator's senior bridge
/// when a worker finishes with a final output marked `ok = false`.
pub struct SeniorEscalationMiddleware;
impl SeniorEscalationMiddleware {
pub fn new() -> Self {
Self
}
}
impl Default for SeniorEscalationMiddleware {
fn default() -> Self {
Self::new()
}
}
impl SpawnerLayer for SeniorEscalationMiddleware {
    /// Wraps `inner` in the escalation adapter.
    fn wrap(&self, inner: Arc<dyn SpawnerAdapter>) -> Arc<dyn SpawnerAdapter> {
        let wrapped = SeniorWrapped { inner };
        Arc::new(wrapped)
    }
}
// Adapter produced by `SeniorEscalationMiddleware::wrap`.
struct SeniorWrapped {
// Adapter the spawn call is forwarded to.
inner: Arc<dyn SpawnerAdapter>,
}
#[async_trait]
impl SpawnerAdapter for SeniorWrapped {
    /// Spawns through the wrapped adapter and, when the operator has a senior
    /// bridge, attaches a join callback that escalates failed results.
    ///
    /// The callback propagates a join error unchanged. Otherwise, if the last
    /// final output has `ok = false`, it asks the bridge with the original
    /// value; when the bridge answers, a new final output with `ok = true`
    /// carrying both the original value and the answer is submitted. A bridge
    /// error or a failed submission is ignored.
    async fn spawn(
        &self,
        engine: &Engine,
        ctx: &Ctx,
        task_id: TaskId,
        attempt: u32,
        token: CapToken,
    ) -> Result<Box<dyn Worker>, SpawnError> {
        // Captured up front: `token` is moved into the inner spawn below.
        let maybe_bridge = ctx.operator.senior_bridge.clone();
        let owned_engine = engine.clone();
        let owned_token = token.clone();

        let worker = self
            .inner
            .spawn(engine, ctx, task_id.clone(), attempt, token)
            .await?;

        let bridge = match maybe_bridge {
            Some(bridge) => bridge,
            None => return Ok(worker),
        };

        let on_join = move |signal| {
            let bridge = bridge.clone();
            let hook_task = task_id.clone();
            let hook_engine = owned_engine.clone();
            let hook_token = owned_token.clone();
            async move {
                signal?;

                // Only a final output explicitly marked `ok = false` escalates.
                let Some((value, false)) =
                    pull_final_value_ok(&hook_engine, &hook_task, attempt).await
                else {
                    return Ok(());
                };

                let question = serde_json::json!({
                    "reason": "worker reported ok=false",
                    "value": value.clone(),
                });
                let Ok(answer) = bridge.ask(&hook_task, question).await else {
                    return Ok(());
                };

                let override_val = serde_json::json!({
                    "original": value,
                    "senior_answer": answer,
                });
                let replacement = OutputEvent::Final {
                    content: ContentRef::Inline {
                        value: override_val,
                    },
                    ok: true,
                };
                let _ = hook_engine
                    .submit_output(&hook_token, &hook_task, attempt, replacement)
                    .await;
                Ok(())
            }
        };
        Ok(wrap_join(worker, on_join))
    }
}
/// Spawner layer whose wrapped adapter runs the context's operator
/// implementation in a spawned task when one is present, and otherwise
/// forwards to the inner adapter.
pub struct OperatorDelegateMiddleware;
impl OperatorDelegateMiddleware {
pub fn new() -> Self {
Self
}
}
impl Default for OperatorDelegateMiddleware {
fn default() -> Self {
Self::new()
}
}
impl SpawnerLayer for OperatorDelegateMiddleware {
    /// Wraps `inner` in the delegating adapter.
    fn wrap(&self, inner: Arc<dyn SpawnerAdapter>) -> Arc<dyn SpawnerAdapter> {
        let wrapped = OperatorDelegateWrapped { inner };
        Arc::new(wrapped)
    }
}
// Adapter produced by `OperatorDelegateMiddleware::wrap`.
struct OperatorDelegateWrapped {
// Adapter used when the context carries no operator implementation.
inner: Arc<dyn SpawnerAdapter>,
}
#[async_trait]
impl SpawnerAdapter for OperatorDelegateWrapped {
    /// Runs the context's operator implementation in place of the wrapped adapter.
    ///
    /// When `ctx.operator.operator` is `None`, the call is forwarded to the inner
    /// adapter. Otherwise the prompt is fetched (a failure becomes
    /// `SpawnError::Internal`) and the operator's `execute` is raced against a
    /// cancellation token on a spawned task. If execution succeeds and the
    /// output tail holds no `Final` event yet, one is submitted from the worker
    /// result. The outcome, stripped of its value, is delivered through the
    /// returned worker's completion channel.
    async fn spawn(
        &self,
        engine: &Engine,
        ctx: &Ctx,
        task_id: TaskId,
        attempt: u32,
        token: CapToken,
    ) -> Result<Box<dyn Worker>, SpawnError> {
        let operator = match ctx.operator.operator.clone() {
            Some(delegate) => delegate,
            None => return self.inner.spawn(engine, ctx, task_id, attempt, token).await,
        };

        let prompt = match engine.fetch_prompt(&token, &task_id).await {
            Ok(prompt) => prompt,
            Err(e) => return Err(SpawnError::Internal(format!("fetch_prompt: {e}"))),
        };

        // Owned state for the background task.
        let task_engine = engine.clone();
        let submit_token = token.clone();
        let exec_token = token.clone();
        let task_key = task_id.clone();
        let task_ctx = ctx.clone();

        // Completion and cancellation plumbing shared with the returned worker.
        let (done_tx, done_rx) = tokio::sync::oneshot::channel();
        let cancel = tokio_util::sync::CancellationToken::new();
        let task_cancel = cancel.clone();
        let worker_id = crate::types::WorkerId::new();

        tokio::spawn(async move {
            let outcome: Result<
                crate::worker::adapter::WorkerResult,
                crate::worker::adapter::WorkerError,
            > = tokio::select! {
                r = operator.execute(&task_ctx, None, prompt, None, exec_token) => r,
                _ = task_cancel.cancelled() => Err(crate::worker::adapter::WorkerError::Cancelled),
            };

            if let Ok(finished) = &outcome {
                // Do not add a second `Final` if the operator already emitted one.
                let tail = task_engine.output_tail(&task_key, attempt).await;
                let mut already_final = false;
                for ev in tail.iter() {
                    if matches!(ev, crate::worker::output::OutputEvent::Final { .. }) {
                        already_final = true;
                        break;
                    }
                }
                if !already_final {
                    let synthesized = crate::worker::output::OutputEvent::Final {
                        content: crate::worker::output::ContentRef::Inline {
                            value: finished.value.clone(),
                        },
                        ok: finished.ok,
                    };
                    let _ = task_engine
                        .submit_output(&submit_token, &task_key, attempt, synthesized)
                        .await;
                }
            }

            let signal: Result<(), crate::worker::adapter::WorkerError> = match outcome {
                Ok(_) => Ok(()),
                Err(err) => Err(err),
            };
            // The receiver may already be gone; that is not an error here.
            let _ = done_tx.send(signal);
        });

        let handler = WorkerJoinHandler {
            worker_id,
            cancel,
            completion: done_rx,
        };
        Ok(Box::new(MiddlewareWorker { handler }))
    }
}
/// Spawner layer whose wrapped adapter emits a `long_hold_warn` event when the
/// time measured from spawn to the join callback exceeds `default_hold`.
pub struct LongHoldMiddleware {
/// Elapsed time above which the warning event is sent.
pub default_hold: Duration,
/// Broadcast sender the warning event is published on.
pub event_tx: broadcast::Sender<Event>,
}
impl LongHoldMiddleware {
pub fn new(default_hold: Duration, event_tx: broadcast::Sender<Event>) -> Self {
Self {
default_hold,
event_tx,
}
}
}
impl SpawnerLayer for LongHoldMiddleware {
    /// Wraps `inner` in an adapter carrying this layer's threshold and a clone
    /// of its event sender.
    fn wrap(&self, inner: Arc<dyn SpawnerAdapter>) -> Arc<dyn SpawnerAdapter> {
        let default_hold = self.default_hold;
        let event_tx = self.event_tx.clone();
        let wrapped = LongHoldWrapped {
            inner,
            default_hold,
            event_tx,
        };
        Arc::new(wrapped)
    }
}
// Adapter produced by `LongHoldMiddleware::wrap`.
struct LongHoldWrapped {
// Adapter the spawn call is forwarded to.
inner: Arc<dyn SpawnerAdapter>,
// Elapsed time above which the warning event is sent.
default_hold: Duration,
// Channel on which the warning event is sent.
event_tx: broadcast::Sender<Event>,
}
#[async_trait]
impl SpawnerAdapter for LongHoldWrapped {
    /// Spawns through the wrapped adapter and attaches a join callback that
    /// reports attempts which ran longer than the configured threshold.
    ///
    /// Timing starts once the inner spawn has returned. If the elapsed time
    /// observed by the callback is strictly greater than `default_hold`, an
    /// `Event::TaskAttemptCompleted` carrying `long_hold_warn` and both
    /// durations in milliseconds is sent; the send result is ignored. The join
    /// signal is always passed through unchanged.
    async fn spawn(
        &self,
        engine: &Engine,
        ctx: &Ctx,
        task_id: TaskId,
        attempt: u32,
        token: CapToken,
    ) -> Result<Box<dyn Worker>, SpawnError> {
        let worker = self
            .inner
            .spawn(engine, ctx, task_id.clone(), attempt, token)
            .await?;

        let started = Instant::now();
        let threshold = self.default_hold;
        let events = self.event_tx.clone();

        let on_join = move |signal| {
            // Measured when the callback is invoked, not when its future is polled.
            let elapsed = started.elapsed();
            let events = events.clone();
            let warned_task = task_id.clone();
            async move {
                if elapsed > threshold {
                    let warning = serde_json::json!({
                        "long_hold_warn": true,
                        "elapsed_ms": elapsed.as_millis() as u64,
                        "default_hold_ms": threshold.as_millis() as u64,
                    });
                    let _ = events.send(Event::TaskAttemptCompleted {
                        task_id: warned_task,
                        attempt,
                        result: warning,
                    });
                }
                signal
            }
        };
        Ok(wrap_join(worker, on_join))
    }
}