use std::collections::{HashMap, HashSet};
use serde::Deserialize;
use zeph_llm::provider::{LlmProvider, Message, Role};
use super::dag;
use super::error::OrchestrationError;
use super::graph::{FailureStrategy, TaskGraph, TaskId, TaskNode};
use crate::config::OrchestrationConfig;
use crate::subagent::def::{SubAgentDef, ToolPolicy};
/// Decomposes a high-level goal into a [`TaskGraph`] of sub-tasks that can be assigned
/// to the given sub-agents. Implementors must be `Send + Sync`.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because callers
// cannot require the returned future to be `Send`; the lint is silenced deliberately.
#[allow(async_fn_in_trait)]
pub trait Planner: Send + Sync {
/// Produces a task graph for `goal`, given the sub-agents that tasks may be assigned to.
///
/// # Errors
///
/// Returns an [`OrchestrationError`] when no valid plan can be produced.
async fn plan(
&self,
goal: &str,
available_agents: &[SubAgentDef],
) -> Result<TaskGraph, OrchestrationError>;
}
/// [`Planner`] implementation that asks an LLM provider for a structured plan.
pub struct LlmPlanner<P: LlmProvider> {
/// Provider used for the typed (`chat_typed`) planning request.
provider: P,
/// Maximum number of tasks a plan may contain; taken from `OrchestrationConfig::max_tasks`.
max_tasks: u32,
}
impl<P: LlmProvider> LlmPlanner<P> {
    /// Creates a planner that sends its requests through `provider` and caps every plan
    /// at `config.max_tasks` tasks.
    #[must_use]
    pub fn new(provider: P, config: &OrchestrationConfig) -> Self {
        let max_tasks = config.max_tasks;
        Self { provider, max_tasks }
    }
}
/// Top-level object the planner LLM must return: `{"tasks": [...]}`.
///
/// Deserialized from the provider's typed (`chat_typed`) reply. The `JsonSchema` derive
/// presumably supplies the response schema for that call — NOTE(review): confirm
/// against `LlmProvider::chat_typed`.
#[derive(Debug, Clone, Deserialize, schemars::JsonSchema)]
pub(crate) struct PlannerResponse {
/// Planned tasks in emission order; `convert_response` assigns `TaskId(i)` to the i-th entry.
pub tasks: Vec<PlannedTask>,
}
/// One task exactly as emitted by the planner LLM, before validation and id resolution.
#[derive(Debug, Clone, Deserialize, schemars::JsonSchema)]
pub(crate) struct PlannedTask {
/// Planner-chosen identifier. `convert_response` requires it to be unique, at most
/// 64 bytes, and to match `^[a-z0-9]([a-z0-9-]*[a-z0-9])?$`.
pub task_id: String,
/// Short human-readable title.
pub title: String,
/// Task description; the prompt asks the model to make it a complete prompt for the agent.
pub description: String,
/// Optional agent name; names not in the available-agent list are dropped with a warning.
#[serde(default)]
pub agent_hint: Option<String>,
/// `task_id`s this task depends on; empty when the field is absent from the JSON.
#[serde(default)]
pub depends_on: Vec<String>,
/// Optional strategy string (the prompt offers "abort", "retry", "skip", "ask"). It is
/// parsed into [`FailureStrategy`]; unparsable values are ignored with a warning.
#[serde(default)]
pub failure_strategy: Option<String>,
}
impl<P: LlmProvider + Send + Sync> Planner for LlmPlanner<P> {
    /// Plans `goal` by requesting a typed [`PlannerResponse`] from the provider and
    /// turning it into a validated [`TaskGraph`].
    ///
    /// Steps:
    /// 1. Reject a blank goal.
    /// 2. Build the prompt and call `chat_typed`.
    /// 3. Convert the response, checking ids, dependencies and the task limit.
    /// 4. Run `dag::validate` over the resulting tasks.
    ///
    /// # Errors
    ///
    /// - [`OrchestrationError::PlanningFailed`] is returned in three cases:
    ///   - the goal is empty or whitespace-only;
    ///   - the provider call fails (its error is stringified);
    ///   - `convert_response` rejects the reply.
    /// - Any error from `dag::validate` is propagated with `?`, for example a
    ///   dependency cycle.
    async fn plan(
        &self,
        goal: &str,
        available_agents: &[SubAgentDef],
    ) -> Result<TaskGraph, OrchestrationError> {
        // Checked up front so a blank goal never costs an LLM request.
        if goal.trim().is_empty() {
            return Err(OrchestrationError::PlanningFailed(
                "goal cannot be empty".into(),
            ));
        }
        let limit = self.max_tasks;
        let prompt = build_prompt(goal, available_agents, limit);
        let parsed: PlannerResponse = match self.provider.chat_typed(&prompt).await {
            Ok(value) => value,
            Err(err) => return Err(OrchestrationError::PlanningFailed(err.to_string())),
        };
        let graph = convert_response(parsed, goal, available_agents, limit)?;
        // Whole-graph structural checks (such as cycles) are not done by
        // convert_response; they happen here.
        dag::validate(&graph.tasks, limit as usize)?;
        Ok(graph)
    }
}
/// Builds the chat conversation sent to the planner LLM.
///
/// Returns exactly two messages:
/// 1. A system message containing:
///    - the agent catalog, one line per entry of `agents`;
///    - the planning rules, including the `max_tasks` limit;
///    - a two-task JSON example.
/// 2. A user message asking the model to decompose `goal`.
///
/// Agent names and descriptions are rendered with `{:?}` (debug-string escaping)
/// rather than raw interpolation. A description containing `"` or a newline therefore
/// stays a single, correctly quoted catalog line. It cannot break the layout or inject
/// extra lines into the rules that follow.
///
/// An empty `agents` slice renders as `(none)`, so the model is told explicitly that
/// no agents are available.
fn build_prompt(goal: &str, agents: &[SubAgentDef], max_tasks: u32) -> Vec<Message> {
    let agent_catalog = if agents.is_empty() {
        "(none)".to_string()
    } else {
        agents
            .iter()
            .map(render_agent_catalog_line)
            .collect::<Vec<_>>()
            .join("\n")
    };
    // Each `\` at a line end continues the literal and skips the next line's leading
    // whitespace, so the indentation below does not leak into the prompt text.
    let system = format!(
        "You are a task planner. Decompose the user's goal into \
         independent sub-tasks that can be executed by the available agents.\n\n\
         Available agents:\n{agent_catalog}\n\n\
         Rules:\n\
         - Each task must have a unique task_id (short, descriptive, kebab-case: [a-z0-9-]).\n\
         - Each task must have a clear, actionable title and description.\n\
         - The description should be a complete prompt for the assigned agent.\n\
         - Specify dependencies using task_id strings in depends_on.\n\
         - Maximize parallelism: only add a dependency when the output is truly needed.\n\
         - Do not create more than {max_tasks} tasks.\n\
         - Assign agent_hint when a specific agent is clearly appropriate.\n\
         - failure_strategy is optional: \"abort\", \"retry\", \"skip\", \"ask\", or omit for default.\n\n\
         Example (2-task plan):\n\
         {{\"tasks\": [\
         {{\"task_id\": \"fetch-data\", \"title\": \"Fetch raw data\", \
         \"description\": \"Download the dataset from source.\", \
         \"depends_on\": []}},\
         {{\"task_id\": \"process-data\", \"title\": \"Process dataset\", \
         \"description\": \"Transform and clean the downloaded data.\", \
         \"depends_on\": [\"fetch-data\"]}}\
         ]}}"
    );
    let user = format!("Decompose this goal into tasks:\n\n{goal}");
    vec![
        Message::from_legacy(Role::System, system),
        Message::from_legacy(Role::User, user),
    ]
}

/// Renders one agent-catalog line: `- name: "<name>", description: "<desc>", tools: [...]`.
///
/// The tool summary depends on the agent's tool policy:
/// - an allow-list is joined with `, `;
/// - a deny-list becomes `all except: [...]`;
/// - `InheritAll` becomes `all`.
///
/// Name and description are debug-escaped so each agent always occupies exactly one line.
fn render_agent_catalog_line(agent: &SubAgentDef) -> String {
    let tools = match &agent.tools {
        ToolPolicy::AllowList(list) => list.join(", "),
        ToolPolicy::DenyList(excluded) => format!("all except: [{}]", excluded.join(", ")),
        ToolPolicy::InheritAll => "all".to_string(),
    };
    format!(
        "- name: {:?}, description: {:?}, tools: [{tools}]",
        agent.name, agent.description
    )
}
/// Returns `true` when `id` is an acceptable planner task id.
///
/// An accepted id:
/// - is 1–64 bytes long;
/// - contains only lowercase ASCII letters, ASCII digits and `-`;
/// - starts and ends with a letter or digit.
///
/// Equivalently, it matches `^[a-z0-9]([a-z0-9-]*[a-z0-9])?$` under a 64-byte cap.
fn is_valid_task_id(id: &str) -> bool {
    const MAX_LEN: usize = 64;
    let is_lower_alnum = |b: u8| b.is_ascii_lowercase() || b.is_ascii_digit();
    let bytes = id.as_bytes();
    if bytes.len() > MAX_LEN {
        return false;
    }
    // An empty id has no first/last byte and falls through to `false`.
    match (bytes.first(), bytes.last()) {
        (Some(&head), Some(&tail)) => {
            is_lower_alnum(head)
                && is_lower_alnum(tail)
                && bytes.iter().all(|&b| is_lower_alnum(b) || b == b'-')
        }
        _ => false,
    }
}
/// Converts a raw, LLM-produced [`PlannerResponse`] into a [`TaskGraph`] for `goal`.
///
/// Tasks are numbered in the order the planner emitted them: the i-th planned task
/// becomes `TaskId(i)`. String `depends_on` references are resolved to those ids.
///
/// Soft problems are tolerated and logged with `tracing::warn!`:
/// - an `agent_hint` that names no entry of `available_agents` is dropped;
/// - an unparsable `failure_strategy` is dropped, so the node keeps its default;
/// - a dependency listed more than once for the same task is recorded only once.
///
/// Cycle detection is not performed here; `LlmPlanner::plan` runs `dag::validate` on
/// the result.
///
/// # Errors
///
/// Returns [`OrchestrationError::PlanningFailed`] when:
/// - the response has no tasks;
/// - it has more than `max_tasks` tasks;
/// - a `task_id` is malformed or duplicated (the offending id is named);
/// - a task depends on an unknown `task_id`.
fn convert_response(
    response: PlannerResponse,
    goal: &str,
    available_agents: &[SubAgentDef],
    max_tasks: u32,
) -> Result<TaskGraph, OrchestrationError> {
    let planned = response.tasks;
    if planned.is_empty() {
        return Err(OrchestrationError::PlanningFailed(
            "planner returned zero tasks".into(),
        ));
    }
    if planned.len() > max_tasks as usize {
        return Err(OrchestrationError::PlanningFailed(format!(
            "planner returned {} tasks, exceeding limit of {max_tasks}",
            planned.len()
        )));
    }
    if let Some(bad) = planned.iter().find(|pt| !is_valid_task_id(&pt.task_id)) {
        return Err(OrchestrationError::PlanningFailed(format!(
            "invalid task_id '{}': must match ^[a-z0-9]([a-z0-9-]*[a-z0-9])?$ \
             and be at most 64 characters",
            bad.task_id
        )));
    }

    // Map each task_id to its position. `planned.len() <= max_tasks` (a u32) was checked
    // above, so zipping with `0u32..` cannot overflow. `planned` is the left-hand side,
    // so the range is never advanced past the last task.
    let mut id_map: HashMap<&str, u32> = HashMap::with_capacity(planned.len());
    for (pt, idx) in planned.iter().zip(0u32..) {
        if id_map.insert(pt.task_id.as_str(), idx).is_some() {
            return Err(OrchestrationError::PlanningFailed(format!(
                "duplicate task_id '{}' in planner output",
                pt.task_id
            )));
        }
    }

    let agent_names: HashSet<&str> = available_agents.iter().map(|a| a.name.as_str()).collect();
    let mut graph = TaskGraph::new(goal);
    for (pt, idx) in planned.iter().zip(0u32..) {
        let mut node = TaskNode::new(idx, &pt.title, &pt.description);
        for dep_str in &pt.depends_on {
            let Some(&dep_idx) = id_map.get(dep_str.as_str()) else {
                return Err(OrchestrationError::PlanningFailed(format!(
                    "task '{}' depends on unknown task_id '{dep_str}'",
                    pt.task_id
                )));
            };
            let dep = TaskId(dep_idx);
            // Model output may repeat an entry in depends_on. Record each edge only once
            // so downstream graph code never sees a duplicated dependency.
            if node.depends_on.contains(&dep) {
                tracing::warn!(
                    task_id = %pt.task_id,
                    depends_on = %dep_str,
                    "duplicate dependency in planner output, ignoring"
                );
            } else {
                node.depends_on.push(dep);
            }
        }
        if let Some(hint) = &pt.agent_hint {
            if agent_names.contains(hint.as_str()) {
                node.agent_hint = Some(hint.clone());
            } else {
                tracing::warn!(
                    task_id = %pt.task_id,
                    agent_hint = %hint,
                    "unknown agent_hint in planner output, ignoring"
                );
            }
        }
        if let Some(fs_str) = &pt.failure_strategy {
            match fs_str.parse::<FailureStrategy>() {
                Ok(fs) => node.failure_strategy = Some(fs),
                Err(_) => {
                    tracing::warn!(
                        task_id = %pt.task_id,
                        strategy = %fs_str,
                        "invalid failure_strategy in planner output, using default"
                    );
                }
            }
        }
        graph.tasks.push(node);
    }
    Ok(graph)
}
// Unit tests for prompt construction and response conversion, plus `integration`
// tests that drive `LlmPlanner::plan` end to end through `MockProvider`.
#[cfg(test)]
mod tests {
use super::*;
use crate::subagent::def::{SkillFilter, SubAgentDef, SubAgentPermissions, ToolPolicy};
use crate::subagent::hooks::SubagentHooks;
// Builds a minimal agent definition whose description is "<name> agent".
fn make_agent(name: &str, tools: ToolPolicy) -> SubAgentDef {
SubAgentDef {
name: name.to_string(),
description: format!("{name} agent"),
model: None,
tools,
disallowed_tools: Vec::new(),
permissions: SubAgentPermissions::default(),
skills: SkillFilter::default(),
system_prompt: String::new(),
hooks: SubagentHooks::default(),
memory: None,
source: None,
file_path: None,
}
}
// Builds a planned task with description "do <title>" and no failure strategy.
fn make_planned(
task_id: &str,
title: &str,
deps: Vec<&str>,
agent_hint: Option<&str>,
) -> PlannedTask {
PlannedTask {
task_id: task_id.to_string(),
title: title.to_string(),
description: format!("do {title}"),
agent_hint: agent_hint.map(|s| s.to_string()),
depends_on: deps.iter().map(|s| s.to_string()).collect(),
failure_strategy: None,
}
}
// Two agents: "agent-a" inherits all tools, "agent-b" may only use "shell".
fn agents() -> Vec<SubAgentDef> {
vec![
make_agent("agent-a", ToolPolicy::InheritAll),
make_agent("agent-b", ToolPolicy::AllowList(vec!["shell".to_string()])),
]
}
// Ids follow emission order and string deps resolve to those ids.
#[test]
fn test_convert_valid_linear_chain() {
let response = PlannerResponse {
tasks: vec![
make_planned("task-a", "Task A", vec![], None),
make_planned("task-b", "Task B", vec!["task-a"], None),
make_planned("task-c", "Task C", vec!["task-b"], None),
],
};
let graph = convert_response(response, "linear goal", &agents(), 20).unwrap();
assert_eq!(graph.tasks.len(), 3);
assert_eq!(graph.tasks[0].id, TaskId(0));
assert_eq!(graph.tasks[1].depends_on, vec![TaskId(0)]);
assert_eq!(graph.tasks[2].depends_on, vec![TaskId(1)]);
}
// Multiple deps keep their listed order.
#[test]
fn test_convert_valid_diamond() {
let response = PlannerResponse {
tasks: vec![
make_planned("a", "A", vec![], None),
make_planned("b", "B", vec!["a"], None),
make_planned("c", "C", vec!["a"], None),
make_planned("d", "D", vec!["b", "c"], None),
],
};
let graph = convert_response(response, "diamond", &agents(), 20).unwrap();
assert_eq!(graph.tasks[3].depends_on, vec![TaskId(1), TaskId(2)]);
}
#[test]
fn test_convert_parallel_tasks() {
let response = PlannerResponse {
tasks: vec![
make_planned("t1", "T1", vec![], None),
make_planned("t2", "T2", vec![], None),
make_planned("t3", "T3", vec![], None),
],
};
let graph = convert_response(response, "parallel", &agents(), 20).unwrap();
for node in &graph.tasks {
assert!(node.depends_on.is_empty());
}
}
// The rejection cases below only pin the error variant, not the message text.
#[test]
fn test_convert_empty_tasks_rejected() {
let response = PlannerResponse { tasks: vec![] };
let err = convert_response(response, "goal", &agents(), 20).unwrap_err();
assert!(matches!(err, OrchestrationError::PlanningFailed(_)));
}
#[test]
fn test_convert_exceeds_max_tasks() {
let tasks = (0..5)
.map(|i| make_planned(&format!("task-{i}"), &format!("T{i}"), vec![], None))
.collect();
let response = PlannerResponse { tasks };
let err = convert_response(response, "goal", &agents(), 3).unwrap_err();
assert!(matches!(err, OrchestrationError::PlanningFailed(_)));
}
#[test]
fn test_convert_duplicate_task_ids() {
let response = PlannerResponse {
tasks: vec![
make_planned("dup", "First", vec![], None),
make_planned("dup", "Second", vec![], None),
],
};
let err = convert_response(response, "goal", &agents(), 20).unwrap_err();
assert!(matches!(err, OrchestrationError::PlanningFailed(_)));
}
#[test]
fn test_convert_unknown_dependency() {
let response = PlannerResponse {
tasks: vec![make_planned("task-a", "A", vec!["nonexistent"], None)],
};
let err = convert_response(response, "goal", &agents(), 20).unwrap_err();
assert!(matches!(err, OrchestrationError::PlanningFailed(_)));
}
// Unknown hints are dropped (with a warning), not treated as an error.
#[test]
fn test_convert_unknown_agent_hint_warns() {
let response = PlannerResponse {
tasks: vec![make_planned("task-a", "A", vec![], Some("unknown-agent"))],
};
let graph = convert_response(response, "goal", &agents(), 20).unwrap();
assert!(graph.tasks[0].agent_hint.is_none());
}
#[test]
fn test_convert_known_agent_hint_stored() {
let response = PlannerResponse {
tasks: vec![make_planned("task-a", "A", vec![], Some("agent-a"))],
};
let graph = convert_response(response, "goal", &agents(), 20).unwrap();
assert_eq!(graph.tasks[0].agent_hint.as_deref(), Some("agent-a"));
}
#[test]
fn test_convert_invalid_task_id_format() {
// Empty, whitespace, uppercase, underscore, non-ASCII, and leading/trailing '-'.
let cases = vec![
"", " ", "Task A", "-task", "task-", "TASK", "task_one", "задача", ];
for bad_id in cases {
let response = PlannerResponse {
tasks: vec![PlannedTask {
task_id: bad_id.to_string(),
title: "T".to_string(),
description: "d".to_string(),
agent_hint: None,
depends_on: vec![],
failure_strategy: None,
}],
};
let err = convert_response(response, "goal", &agents(), 20).unwrap_err();
assert!(
matches!(err, OrchestrationError::PlanningFailed(_)),
"expected PlanningFailed for task_id '{bad_id}'"
);
}
}
#[test]
fn test_convert_valid_task_id_formats() {
let cases = vec!["a", "a1", "task-a", "fetch-data-v2", "0"];
for id in cases {
assert!(is_valid_task_id(id), "expected valid: '{id}'");
}
}
// An unparsable strategy is ignored rather than rejected.
#[test]
fn test_convert_invalid_failure_strategy_uses_none() {
let response = PlannerResponse {
tasks: vec![PlannedTask {
task_id: "task-a".to_string(),
title: "A".to_string(),
description: "d".to_string(),
agent_hint: None,
depends_on: vec![],
failure_strategy: Some("explode".to_string()),
}],
};
let graph = convert_response(response, "goal", &agents(), 20).unwrap();
assert!(graph.tasks[0].failure_strategy.is_none());
}
#[test]
fn test_convert_goal_is_set() {
let response = PlannerResponse {
tasks: vec![make_planned("t1", "T1", vec![], None)],
};
let graph = convert_response(response, "my goal", &agents(), 20).unwrap();
assert_eq!(graph.goal, "my goal");
}
// The build_prompt tests inspect msgs[0], the system message.
#[test]
fn test_build_prompt_includes_agent_catalog() {
let msgs = build_prompt("do something", &agents(), 20);
let text = &msgs[0].content;
assert!(text.contains("agent-a"));
assert!(text.contains("agent-b"));
assert!(text.contains("shell"));
}
#[test]
fn test_build_prompt_includes_max_tasks() {
let msgs = build_prompt("goal", &agents(), 42);
let text = &msgs[0].content;
assert!(text.contains("42"));
}
#[test]
fn test_build_prompt_deny_list_renders_as_except() {
let a = make_agent(
"restricted",
ToolPolicy::DenyList(vec!["shell".to_string(), "web".to_string()]),
);
let msgs = build_prompt("goal", &[a], 20);
let text = &msgs[0].content;
assert!(text.contains("all except:"));
assert!(text.contains("shell"));
assert!(text.contains("web"));
}
#[test]
fn test_build_prompt_has_two_messages() {
let msgs = build_prompt("goal", &agents(), 20);
assert_eq!(msgs.len(), 2);
}
#[test]
fn test_build_prompt_includes_example_json() {
let msgs = build_prompt("goal", &agents(), 20);
let text = &msgs[0].content;
assert!(
text.contains("fetch-data"),
"example should include fetch-data task_id"
);
assert!(
text.contains("depends_on"),
"example should show depends_on field"
);
}
// End-to-end tests of `LlmPlanner::plan` with canned `MockProvider` replies.
mod integration {
use super::*;
use zeph_llm::mock::MockProvider;
// Two tasks where step-two depends on step-one.
fn valid_json_response() -> String {
r#"{"tasks": [
{"task_id": "step-one", "title": "Step one", "description": "Do step one", "depends_on": []},
{"task_id": "step-two", "title": "Step two", "description": "Do step two", "depends_on": ["step-one"]}
]}"#
.to_string()
}
// a -> b -> a: passes convert_response, so it must be caught by dag::validate.
fn cyclic_json_response() -> String {
r#"{"tasks": [
{"task_id": "a", "title": "A", "description": "A desc", "depends_on": ["b"]},
{"task_id": "b", "title": "B", "description": "B desc", "depends_on": ["a"]}
]}"#
.to_string()
}
fn single_task_json() -> String {
r#"{"tasks": [
{"task_id": "only-task", "title": "The task", "description": "Do it", "depends_on": []}
]}"#
.to_string()
}
fn make_config() -> OrchestrationConfig {
OrchestrationConfig::default()
}
#[tokio::test]
async fn test_plan_valid_response() {
let provider = MockProvider::with_responses(vec![valid_json_response()]);
let planner = LlmPlanner::new(provider, &make_config());
let graph = planner.plan("build and deploy", &agents()).await.unwrap();
assert_eq!(graph.tasks.len(), 2);
assert_eq!(graph.goal, "build and deploy");
}
#[tokio::test]
async fn test_plan_cycle_rejected() {
let provider = MockProvider::with_responses(vec![cyclic_json_response()]);
let planner = LlmPlanner::new(provider, &make_config());
let err = planner.plan("cyclic", &agents()).await.unwrap_err();
assert!(matches!(err, OrchestrationError::CycleDetected));
}
// A whitespace-only goal is rejected before the provider is consulted.
#[tokio::test]
async fn test_plan_empty_goal_rejected() {
let provider = MockProvider::default();
let planner = LlmPlanner::new(provider, &make_config());
let err = planner.plan(" ", &agents()).await.unwrap_err();
assert!(matches!(err, OrchestrationError::PlanningFailed(_)));
}
#[tokio::test]
async fn test_plan_llm_error_maps_to_planning_failed() {
let provider = MockProvider::failing();
let planner = LlmPlanner::new(provider, &make_config());
let err = planner.plan("valid goal", &agents()).await.unwrap_err();
assert!(matches!(err, OrchestrationError::PlanningFailed(_)));
}
#[tokio::test]
async fn test_plan_invalid_failure_strategy_warns() {
let json = r#"{"tasks": [
{"task_id": "t1", "title": "T1", "description": "d", "depends_on": [],
"failure_strategy": "explode"}
]}"#
.to_string();
let provider = MockProvider::with_responses(vec![json]);
let planner = LlmPlanner::new(provider, &make_config());
let graph = planner.plan("goal", &agents()).await.unwrap();
assert!(graph.tasks[0].failure_strategy.is_none());
}
#[tokio::test]
async fn test_plan_single_task_goal() {
let provider = MockProvider::with_responses(vec![single_task_json()]);
let planner = LlmPlanner::new(provider, &make_config());
let graph = planner.plan("simple task", &agents()).await.unwrap();
assert_eq!(graph.tasks.len(), 1);
assert!(graph.tasks[0].depends_on.is_empty());
}
}
}