1use std::sync::Arc;
12
13use anyhow::{Result, anyhow};
14use serde::{Deserialize, Serialize};
15
16use brainwires_core::{Provider, Task, TaskPriority};
17
18use crate::context::AgentContext;
19use crate::system_prompts::planner_agent_prompt;
20use crate::task_agent::{TaskAgent, TaskAgentConfig, TaskAgentResult};
21
22#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
26#[serde(rename_all = "lowercase")]
27pub enum DynamicTaskPriority {
28 Urgent,
30 High,
32 Normal,
34 Low,
36}
37
38impl From<DynamicTaskPriority> for TaskPriority {
39 fn from(p: DynamicTaskPriority) -> Self {
40 match p {
41 DynamicTaskPriority::Urgent => TaskPriority::Urgent,
42 DynamicTaskPriority::High => TaskPriority::High,
43 DynamicTaskPriority::Normal => TaskPriority::Normal,
44 DynamicTaskPriority::Low => TaskPriority::Low,
45 }
46 }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct DynamicTaskSpec {
52 pub id: String,
54 pub description: String,
56 #[serde(default)]
58 pub files_involved: Vec<String>,
59 #[serde(default)]
61 pub depends_on: Vec<String>,
62 #[serde(default = "default_priority")]
64 pub priority: DynamicTaskPriority,
65 #[serde(default)]
67 pub estimated_iterations: Option<u32>,
68 #[serde(skip)]
70 pub agent_config_override: Option<TaskAgentConfig>,
71}
72
73fn default_priority() -> DynamicTaskPriority {
74 DynamicTaskPriority::Normal
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct SubPlannerRequest {
80 pub focus_area: String,
82 pub context: String,
84 #[serde(default = "default_max_depth")]
86 pub max_depth: u32,
87}
88
/// Serde default for [`SubPlannerRequest::max_depth`] when the key is omitted.
fn default_max_depth() -> u32 {
    const DEFAULT_SUB_PLANNER_DEPTH: u32 = 1;
    DEFAULT_SUB_PLANNER_DEPTH
}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct PlannerOutput {
96 pub tasks: Vec<DynamicTaskSpec>,
98 #[serde(default)]
100 pub sub_planners: Vec<SubPlannerRequest>,
101 #[serde(default)]
103 pub rationale: String,
104}
105
/// Tuning knobs for a [`PlannerAgent`].
#[derive(Clone, Debug)]
pub struct PlannerAgentConfig {
    /// Forwarded to the wrapped task agent's `max_iterations`.
    pub max_iterations: u32,
    /// Upper bound on the number of tasks kept from the planner output.
    pub max_tasks: usize,
    /// Upper bound on the number of sub-planner requests kept from the planner output.
    pub max_sub_planners: usize,
    /// NOTE(review): not read anywhere in this module; presumably consumed by the
    /// caller that spawns sub-planners — confirm.
    pub planning_depth: u32,
    /// Sampling temperature forwarded to the wrapped task agent.
    pub temperature: f32,
    /// Token limit forwarded to the wrapped task agent.
    pub max_tokens: u32,
}
122
123impl Default for PlannerAgentConfig {
124 fn default() -> Self {
125 Self {
126 max_iterations: 20,
127 max_tasks: 15,
128 max_sub_planners: 3,
129 planning_depth: 2,
130 temperature: 0.7,
131 max_tokens: 4096,
132 }
133 }
134}
135
136pub struct PlannerAgent {
143 agent: Arc<TaskAgent>,
144 config: PlannerAgentConfig,
145}
146
147impl PlannerAgent {
148 pub fn new(
158 id: String,
159 goal: &str,
160 hints: &[String],
161 provider: Arc<dyn Provider>,
162 context: Arc<AgentContext>,
163 config: PlannerAgentConfig,
164 ) -> Self {
165 let system_prompt = planner_agent_prompt(&id, &context.working_directory, goal, hints);
166
167 let agent_config = TaskAgentConfig {
168 max_iterations: config.max_iterations,
169 system_prompt: Some(system_prompt),
170 temperature: config.temperature,
171 max_tokens: config.max_tokens,
172 validation_config: None, ..Default::default()
174 };
175
176 let task = Task::new(
177 format!("planner-{}", uuid::Uuid::new_v4()),
178 format!("Plan tasks for: {}", goal),
179 );
180
181 let agent = Arc::new(TaskAgent::new(id, task, provider, context, agent_config));
182
183 Self { agent, config }
184 }
185
186 pub async fn execute(&self) -> Result<(PlannerOutput, TaskAgentResult)> {
188 let result = self.agent.execute().await?;
189
190 if !result.success {
191 return Err(anyhow!("Planner agent failed: {}", result.summary));
192 }
193
194 let output = Self::parse_output(&result.summary, &self.config)?;
195 Ok((output, result))
196 }
197
198 pub fn parse_output(text: &str, config: &PlannerAgentConfig) -> Result<PlannerOutput> {
202 let json_str = extract_json_block(text)
203 .ok_or_else(|| anyhow!("No JSON block found in planner output"))?;
204
205 let mut output: PlannerOutput = serde_json::from_str(&json_str)
206 .map_err(|e| anyhow!("Failed to parse planner JSON: {}", e))?;
207
208 output.tasks.truncate(config.max_tasks);
210 output.sub_planners.truncate(config.max_sub_planners);
211
212 for task in &mut output.tasks {
214 if task.id.is_empty() {
215 task.id = uuid::Uuid::new_v4().to_string();
216 }
217 }
218
219 validate_task_graph(&output.tasks)?;
221
222 Ok(output)
223 }
224
225 pub fn agent(&self) -> &Arc<TaskAgent> {
227 &self.agent
228 }
229}
230
/// Extract the first JSON object embedded in free-form model output.
///
/// Strategies, tried in order:
/// 1. the body of the first fence tagged `json`;
/// 2. the body of the first other fenced block whose trimmed content starts with `{`
///    (every fenced block is inspected, not only the first one);
/// 3. the first brace-balanced `{ ... }` span in the raw text. Braces inside JSON
///    string literals, including escaped quotes, are ignored.
///
/// Returns `None` when no candidate is found. The result is not validated as JSON;
/// the caller is expected to parse it.
fn extract_json_block(text: &str) -> Option<String> {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    // 1. An explicitly tagged JSON fence.
    if let Some(start) = text.find(JSON_FENCE) {
        let content_start = start + JSON_FENCE.len();
        if let Some(end) = text[content_start..].find(FENCE) {
            return Some(text[content_start..content_start + end].trim().to_string());
        }
    }

    // 2. Any other fenced block that holds an object. Keep scanning past blocks
    //    that don't (e.g. a leading code sample) instead of giving up after the first.
    let mut search_from = 0;
    while let Some(offset) = text[search_from..].find(FENCE) {
        let info_start = search_from + offset + FENCE.len();
        // The rest of the opening fence line is an info string (e.g. a language tag).
        let Some(newline) = text[info_start..].find('\n') else {
            break;
        };
        let body_start = info_start + newline + 1;
        let Some(body_len) = text[body_start..].find(FENCE) else {
            break;
        };
        let candidate = text[body_start..body_start + body_len].trim();
        if candidate.starts_with('{') {
            return Some(candidate.to_string());
        }
        // Resume after this block's closing fence.
        search_from = body_start + body_len + FENCE.len();
    }

    // 3. Raw JSON embedded in prose.
    let start = text.find('{')?;
    balanced_object_at(text, start).map(str::to_string)
}

/// Return the brace-balanced span that starts at byte `start` (expected to be `{`).
///
/// Braces inside double-quoted strings are skipped, honouring backslash escapes.
/// Returns `None` if the object is never closed or the braces are unbalanced.
fn balanced_object_at(text: &str, start: usize) -> Option<&str> {
    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    for (offset, ch) in text[start..].char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }
        match ch {
            '"' => in_string = true,
            '{' => depth += 1,
            '}' => {
                // A stray `}` before any `{` means there is no object here.
                depth = depth.checked_sub(1)?;
                if depth == 0 {
                    return Some(&text[start..start + offset + 1]);
                }
            }
            _ => {}
        }
    }
    None
}
286
287fn validate_task_graph(tasks: &[DynamicTaskSpec]) -> Result<()> {
289 use std::collections::{HashMap, HashSet, VecDeque};
290
291 let id_set: HashSet<&str> = tasks.iter().map(|t| t.id.as_str()).collect();
292
293 let mut in_degree: HashMap<&str, usize> = tasks.iter().map(|t| (t.id.as_str(), 0)).collect();
295 for task in tasks {
297 let count = task
298 .depends_on
299 .iter()
300 .filter(|d| id_set.contains(d.as_str()))
301 .count();
302 in_degree.insert(task.id.as_str(), count);
303 }
304
305 let mut queue: VecDeque<&str> = in_degree
306 .iter()
307 .filter(|(_, deg)| **deg == 0)
308 .map(|(&id, _)| id)
309 .collect();
310
311 let mut visited = 0usize;
312 while let Some(node) = queue.pop_front() {
313 visited += 1;
314 for task in tasks {
316 if task.depends_on.iter().any(|d| d == node) && id_set.contains(task.id.as_str()) {
317 let deg = in_degree.get_mut(task.id.as_str()).unwrap();
318 *deg -= 1;
319 if *deg == 0 {
320 queue.push_back(task.id.as_str());
321 }
322 }
323 }
324 }
325
326 if visited < tasks.len() {
327 return Err(anyhow!(
328 "Circular dependency detected in planner task graph"
329 ));
330 }
331
332 Ok(())
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal normal-priority task whose description is its id upper-cased.
    fn spec(id: &str, deps: &[&str]) -> DynamicTaskSpec {
        DynamicTaskSpec {
            id: id.to_string(),
            description: id.to_uppercase(),
            files_involved: Vec::new(),
            depends_on: deps.iter().map(|dep| dep.to_string()).collect(),
            priority: DynamicTaskPriority::Normal,
            estimated_iterations: None,
            agent_config_override: None,
        }
    }

    #[test]
    fn test_extract_json_block_fenced() {
        let text = r#"Here is the plan:

```json
{"tasks": [], "rationale": "nothing to do"}
```

Done."#;
        let extracted = extract_json_block(text).expect("fenced JSON should be found");
        assert!(extracted.contains("tasks"));
    }

    #[test]
    fn test_extract_json_block_raw() {
        let text = r#"I think the plan is {"tasks": [], "rationale": "test"} and that's it."#;
        let extracted = extract_json_block(text).expect("raw JSON should be found");
        let value: serde_json::Value =
            serde_json::from_str(&extracted).expect("extracted text should be valid JSON");
        assert!(value["tasks"].is_array());
    }

    #[test]
    fn test_parse_planner_output() {
        let text = r#"```json
{
    "tasks": [
        {
            "id": "task-1",
            "description": "Add error handling to parser",
            "files_involved": ["src/parser.rs"],
            "depends_on": [],
            "priority": "high",
            "estimated_iterations": 10
        },
        {
            "id": "task-2",
            "description": "Add tests for parser",
            "files_involved": ["tests/parser_test.rs"],
            "depends_on": ["task-1"],
            "priority": "normal",
            "estimated_iterations": 5
        }
    ],
    "sub_planners": [],
    "rationale": "Parser needs error handling before tests can be written"
}
```"#;

        let output = PlannerAgent::parse_output(text, &PlannerAgentConfig::default())
            .expect("well-formed plan should parse");

        assert_eq!(output.tasks.len(), 2);
        assert_eq!(output.tasks[0].id, "task-1");
        assert_eq!(output.tasks[1].depends_on, vec!["task-1"]);
        assert_eq!(
            output.rationale,
            "Parser needs error handling before tests can be written"
        );
    }

    #[test]
    fn test_validate_task_graph_no_cycle() {
        let tasks = [spec("a", &[]), spec("b", &["a"])];
        assert!(validate_task_graph(&tasks).is_ok());
    }

    #[test]
    fn test_validate_task_graph_cycle() {
        let tasks = [spec("a", &["b"]), spec("b", &["a"])];
        assert!(validate_task_graph(&tasks).is_err());
    }

    #[test]
    fn test_truncate_limits() {
        let text = r#"```json
{
    "tasks": [
        {"id": "1", "description": "t1"},
        {"id": "2", "description": "t2"},
        {"id": "3", "description": "t3"}
    ],
    "sub_planners": [
        {"focus_area": "a", "context": "c", "max_depth": 1},
        {"focus_area": "b", "context": "c", "max_depth": 1}
    ],
    "rationale": "test"
}
```"#;

        let limits = PlannerAgentConfig {
            max_tasks: 2,
            max_sub_planners: 1,
            ..Default::default()
        };
        let output =
            PlannerAgent::parse_output(text, &limits).expect("plan within limits should parse");

        assert_eq!(output.tasks.len(), 2);
        assert_eq!(output.sub_planners.len(), 1);
    }

    #[test]
    fn test_dynamic_task_priority_conversion() {
        let cases = [
            (DynamicTaskPriority::Urgent, TaskPriority::Urgent),
            (DynamicTaskPriority::Normal, TaskPriority::Normal),
        ];
        for (dynamic, expected) in cases {
            assert_eq!(TaskPriority::from(dynamic), expected);
        }
    }
}