use std::sync::Arc;
use futures::stream::BoxStream;
use serde::Deserialize;
use wesichain_core::{
LlmRequest, LlmResponse, Message, MessageContent, Runnable, StreamEvent, WesichainError,
};
pub trait WorkerRunner: Send + Sync {
fn run(&self, task: String) -> BoxStream<'static, Result<StreamEvent, WesichainError>>;
}
pub struct WorkerSpec {
pub name: String,
pub description: String,
pub runner: Arc<dyn WorkerRunner>,
}
pub struct SupervisorBuilder<L> {
llm: L,
workers: Vec<WorkerSpec>,
max_rounds: usize,
model: String,
}
impl<L> SupervisorBuilder<L>
where
L: Runnable<LlmRequest, LlmResponse> + Clone + Send + Sync + 'static,
{
pub fn new(llm: L) -> Self {
Self { llm, workers: Vec::new(), max_rounds: 10, model: String::new() }
}
pub fn add_worker(mut self, worker: WorkerSpec) -> Self {
self.workers.push(worker);
self
}
pub fn max_rounds(mut self, max: usize) -> Self {
self.max_rounds = max;
self
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn build(self) -> Supervisor<L> {
Supervisor {
llm: self.llm,
workers: Arc::new(self.workers),
max_rounds: self.max_rounds,
model: self.model,
}
}
}
pub struct Supervisor<L> {
llm: L,
workers: Arc<Vec<WorkerSpec>>,
max_rounds: usize,
model: String,
}
impl<L> Supervisor<L>
where
L: Runnable<LlmRequest, LlmResponse> + Clone + Send + Sync + 'static,
{
pub async fn run(&self, task: String) -> Result<String, WesichainError> {
let worker_list = self
.workers
.iter()
.map(|w| format!("- {}: {}", w.name, w.description))
.collect::<Vec<_>>()
.join("\n");
let system_prompt = format!(
"You are a supervisor agent. You have access to the following workers:\n\
{worker_list}\n\n\
To delegate a task respond ONLY with JSON: \
{{\"action\":\"delegate\",\"worker\":\"<name>\",\"task\":\"<task>\"}}\n\
When you have a final answer respond ONLY with JSON: \
{{\"action\":\"finish\",\"answer\":\"<answer>\"}}"
);
let mut messages = vec![
Message { role: wesichain_core::Role::System, content: MessageContent::Text(system_prompt), tool_call_id: None, tool_calls: vec![] },
Message { role: wesichain_core::Role::User, content: MessageContent::Text(task.clone()), tool_call_id: None, tool_calls: vec![] },
];
for _ in 0..self.max_rounds {
let req = LlmRequest {
model: self.model.clone(),
messages: messages.clone(),
tools: vec![],
temperature: Some(0.0),
max_tokens: None,
stop_sequences: vec![],
};
let resp = self.llm.invoke(req).await?;
let text = resp.content.trim().to_string();
match serde_json::from_str::<SupervisorDecision>(&text) {
Ok(SupervisorDecision::Delegate { worker, task: sub_task }) => {
let runner = self
.workers
.iter()
.find(|w| w.name == worker)
.map(|w| w.runner.clone());
let worker_output = if let Some(runner) = runner {
collect_stream_to_text(runner.run(sub_task)).await
} else {
format!("[error: unknown worker '{worker}']")
};
messages.push(Message {
role: wesichain_core::Role::Assistant,
content: MessageContent::Text(text),
tool_call_id: None,
tool_calls: vec![],
});
messages.push(Message {
role: wesichain_core::Role::User,
content: MessageContent::Text(format!(
"Worker '{worker}' returned:\n{worker_output}\n\nContinue."
)),
tool_call_id: None,
tool_calls: vec![],
});
}
Ok(SupervisorDecision::Finish { answer }) => {
return Ok(answer);
}
Err(_) => {
return Ok(text);
}
}
}
Err(WesichainError::Custom(format!(
"Supervisor exceeded max_rounds ({}) without finishing",
self.max_rounds
)))
}
}
#[derive(Debug, Deserialize)]
#[serde(tag = "action", rename_all = "lowercase")]
enum SupervisorDecision {
Delegate { worker: String, task: String },
Finish { answer: String },
}
async fn collect_stream_to_text(
mut stream: BoxStream<'static, Result<StreamEvent, WesichainError>>,
) -> String {
use futures::StreamExt;
let mut buf = String::new();
while let Some(item) = stream.next().await {
match item {
Ok(StreamEvent::ContentChunk(s)) | Ok(StreamEvent::FinalAnswer(s)) => buf.push_str(&s),
_ => {}
}
}
buf
}