use std::sync::Arc;
use entelix_core::ir::Message;
use entelix_core::{Error, ExecutionContext, Result};
use entelix_graph::StateGraph;
use crate::agent::Agent;
use entelix_runnable::{Runnable, RunnableLambda};
use crate::state::SupervisorState;
/// Routing decision produced by the supervisor's router after each turn.
///
/// The supervisor graph stores the decision in `SupervisorState::next_speaker`
/// and uses [`SupervisorDecision::agent_name`] to choose the next node. The enum
/// is `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum SupervisorDecision {
/// Give the next turn to the agent registered under this name.
Agent(String),
/// End the conversation; the graph routes to its `finish` node.
Finish,
/// Give the next turn to `agent` and pass `payload` along with it.
///
/// In `build_supervisor_graph` the payload is rendered as JSON into a system
/// message appended to the conversation just before the target agent runs.
Handoff {
/// Name of the agent that receives the handoff.
agent: String,
/// JSON context forwarded to the receiving agent.
payload: serde_json::Value,
},
}
impl SupervisorDecision {
    /// Builds a decision that routes the next turn to the agent called `name`.
    pub fn agent(name: impl Into<String>) -> Self {
        let target: String = name.into();
        Self::Agent(target)
    }

    /// Builds a decision that hands off to `agent`, forwarding `payload` as context.
    pub fn handoff(agent: impl Into<String>, payload: serde_json::Value) -> Self {
        let target: String = agent.into();
        Self::Handoff {
            agent: target,
            payload,
        }
    }

    /// Returns the name of the agent this decision targets.
    ///
    /// Returns `None` for [`SupervisorDecision::Finish`].
    #[must_use]
    pub fn agent_name(&self) -> Option<&str> {
        match self {
            Self::Finish => None,
            Self::Agent(target) | Self::Handoff { agent: target, .. } => Some(target.as_str()),
        }
    }
}
/// One worker agent that the supervisor can route turns to.
pub struct AgentEntry {
/// Unique name of the agent. It is used both as the graph node name and as the
/// routing key that [`SupervisorDecision::agent_name`] is matched against.
pub name: String,
/// The agent itself: it receives the full conversation so far and returns one
/// reply message, which the supervisor graph appends to the conversation.
pub agent: Arc<dyn Runnable<Vec<Message>, Message>>,
}
impl AgentEntry {
    /// Wraps `agent` in a shared trait object and registers it under `name`.
    pub fn new<R>(name: impl Into<String>, agent: R) -> Self
    where
        R: Runnable<Vec<Message>, Message> + 'static,
    {
        let name: String = name.into();
        let shared: Arc<dyn Runnable<Vec<Message>, Message>> = Arc::new(agent);
        Self {
            name,
            agent: shared,
        }
    }
}
pub fn build_supervisor_graph<R>(
router: R,
agents: Vec<AgentEntry>,
) -> Result<entelix_graph::CompiledGraph<SupervisorState>>
where
R: Runnable<Vec<Message>, SupervisorDecision> + 'static,
{
if agents.is_empty() {
return Err(Error::config(
"build_supervisor_graph: at least one agent required",
));
}
let router = Arc::new(router);
let supervisor_node =
RunnableLambda::new(move |mut state: SupervisorState, ctx: ExecutionContext| {
let router = router.clone();
async move {
let decision = router.invoke(state.messages.clone(), &ctx).await?;
if let Some(name) = decision.agent_name()
&& let Some(handle) = ctx.audit_sink()
{
handle
.as_sink()
.record_agent_handoff(state.last_speaker.as_deref(), name);
}
state.next_speaker = Some(decision);
Ok::<_, _>(state)
}
});
let mut graph = StateGraph::<SupervisorState>::new()
.add_node("supervisor", supervisor_node)
.set_entry_point("supervisor");
let finish_node =
RunnableLambda::new(|state: SupervisorState, _ctx| async move { Ok::<_, _>(state) });
graph = graph
.add_node("finish", finish_node)
.add_finish_point("finish");
let known_names: std::collections::HashSet<String> =
agents.iter().map(|e| e.name.clone()).collect();
let mut conditional_mapping: Vec<(String, String)> =
vec![(FINISH_KEY.to_owned(), "finish".to_owned())];
for entry in agents {
let AgentEntry { name, agent } = entry;
let label = name.clone();
let node_name = name.clone();
let agent_node =
RunnableLambda::new(move |mut state: SupervisorState, ctx: ExecutionContext| {
let agent = agent.clone();
let label = label.clone();
async move {
if let Some(SupervisorDecision::Handoff { payload, .. }) =
state.next_speaker.take()
{
let rendered = serde_json::to_string_pretty(&payload)
.unwrap_or_else(|_| payload.to_string());
state
.messages
.push(Message::system(format!("Handoff payload:\n{rendered}")));
}
let reply = agent.invoke(state.messages.clone(), &ctx).await?;
state.messages.push(reply);
state.last_speaker = Some(label.clone());
state.next_speaker = None;
Ok::<_, _>(state)
}
});
graph = graph
.add_node(node_name.clone(), agent_node)
.add_edge(node_name.clone(), "supervisor");
conditional_mapping.push((name.clone(), name));
}
graph = graph.add_conditional_edges(
"supervisor",
move |state: &SupervisorState| match state.next_speaker.as_ref().and_then(SupervisorDecision::agent_name) {
Some(name) if known_names.contains(name) => name.to_owned(),
Some(name) => {
tracing::warn!(
target: "entelix_agents::supervisor",
unknown_agent = %name,
"supervisor router emitted decision routing to '{name}' but no AgentEntry by that name; finishing"
);
FINISH_KEY.to_owned()
}
None => FINISH_KEY.to_owned(),
},
conditional_mapping,
);
graph.compile()
}
/// Conditional-edge key that routes the supervisor to the terminal `finish` node.
/// The router yields it when no agent is named, or when the named agent is not registered.
const FINISH_KEY: &str = "__finish__";
/// Builds the supervisor graph and wraps it in an [`Agent`] named `"supervisor"`.
///
/// # Errors
///
/// Propagates any error from `build_supervisor_graph` or from the agent builder.
pub fn create_supervisor_agent<R>(
    router: R,
    agents: Vec<AgentEntry>,
) -> Result<Agent<SupervisorState>>
where
    R: Runnable<Vec<Message>, SupervisorDecision> + 'static,
{
    build_supervisor_graph(router, agents).and_then(|graph| {
        Agent::<SupervisorState>::builder()
            .with_name("supervisor")
            .with_runnable(graph)
            .build()
    })
}
pub fn team_from_supervisor(team: Agent<SupervisorState>) -> impl Runnable<Vec<Message>, Message> {
let team = Arc::new(team);
RunnableLambda::new(move |messages: Vec<Message>, ctx: ExecutionContext| {
let team = team.clone();
async move {
let state = SupervisorState {
messages,
last_speaker: None,
next_speaker: None,
};
let final_state = team.execute(state, &ctx).await?.into_state();
final_state.messages.last().cloned().ok_or_else(|| {
Error::invalid_request(
"team_from_supervisor: team finished with empty conversation",
)
})
}
})
}