1use async_trait::async_trait;
2use rain_engine_core::{
3 AgentAction, AgentStateSnapshot, AgentTrigger, GoalId, GoalRecord, GoalStatus, KernelEvent,
4 LlmProvider, Planner, PlannerOutput, ProviderContentPart, ProviderMessage, ProviderRequest,
5 ProviderRole, TaskId, TaskRecord, TaskStatus,
6};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::SystemTime;
11
12pub struct ResearchPlanner {
13 llm: Arc<dyn LlmProvider>,
14}
15
16impl ResearchPlanner {
17 pub fn new(llm: Arc<dyn LlmProvider>) -> Self {
18 Self { llm }
19 }
20}
21
22#[derive(Debug, Serialize, Deserialize)]
23struct ResearchPlan {
24 goal_title: String,
25 goal_detail: String,
26 tasks: Vec<PlannedTask>,
27}
28
29#[derive(Debug, Serialize, Deserialize)]
30struct PlannedTask {
31 id: String,
32 title: String,
33 detail: String,
34 depends_on: Vec<String>,
35}
36
37#[async_trait]
38impl Planner for ResearchPlanner {
39 async fn plan(&self, state: &AgentStateSnapshot, trigger: &AgentTrigger) -> PlannerOutput {
40 let has_active_tasks = state.tasks.iter().any(|task| {
42 matches!(
43 task.status,
44 TaskStatus::Pending
45 | TaskStatus::Ready
46 | TaskStatus::Running
47 | TaskStatus::Blocked
48 | TaskStatus::WaitingHuman
49 )
50 });
51
52 if has_active_tasks {
53 return PlannerOutput::default();
54 }
55
56 let content = match trigger {
57 AgentTrigger::HumanInput { content, .. } | AgentTrigger::Message { content, .. } => {
58 content
59 }
60 _ => return PlannerOutput::default(),
61 };
62
63 let system_prompt = r#"You are a Research Architect. Your job is to decompose a high-level research goal into a directed acyclic graph (DAG) of specific tasks.
64Output your plan as a single JSON object with this structure:
65{
66 "goal_title": "Short title",
67 "goal_detail": "Detailed explanation",
68 "tasks": [
69 {
70 "id": "task-1",
71 "title": "Task title",
72 "detail": "What to do",
73 "depends_on": []
74 },
75 ...
76 ]
77}
78Each task should be atomic. Use "depends_on" to specify which task IDs must be completed before a task can start."#;
79
80 let request = ProviderRequest {
81 trigger: trigger.clone(),
82 context: rain_engine_core::AgentContextSnapshot {
83 session_id: state.agent_id.0.clone(),
84 granted_scopes: vec!["scope:research".to_string()],
85 trigger_id: "planning-trigger".to_string(),
86 idempotency_key: None,
87 current_step: 0,
88 max_steps: 1,
89 history: Vec::new(),
90 prior_tool_results: Vec::new(),
91 session_cost_usd: 0.0,
92 state: state.clone(),
93 policy: Default::default(),
94 active_execution_plan: None,
95 },
96 available_skills: Vec::new(),
97 config: Default::default(),
98 policy: Default::default(),
99 contents: vec![
100 ProviderMessage {
101 role: ProviderRole::System,
102 parts: vec![ProviderContentPart::Text(system_prompt.to_string())],
103 },
104 ProviderMessage {
105 role: ProviderRole::User,
106 parts: vec![ProviderContentPart::Text(format!("Goal: {}", content))],
107 },
108 ],
109 };
110
111 let decision = match self.llm.generate_action(request).await {
112 Ok(d) => d,
113 Err(_) => return PlannerOutput::default(),
114 };
115
116 let response_text = match decision.action {
117 AgentAction::Respond { content } => content,
118 _ => return PlannerOutput::default(),
119 };
120
121 let plan: ResearchPlan = match serde_json::from_str(&extract_json(&response_text)) {
123 Ok(p) => p,
124 Err(_) => return PlannerOutput::default(),
125 };
126
127 let mut events = Vec::new();
128 let goal_id = GoalId(format!("goal-{}", state.goals.len() + 1));
129
130 events.push(KernelEvent::GoalCreated(GoalRecord {
131 goal_id: goal_id.clone(),
132 created_at: SystemTime::now(),
133 title: plan.goal_title,
134 detail: Some(plan.goal_detail),
135 status: GoalStatus::Active,
136 parent_goal_id: None,
137 }));
138
139 let mut id_map = HashMap::new();
140 for (i, planned_task) in plan.tasks.into_iter().enumerate() {
141 let task_id = TaskId(format!("task-{}", state.tasks.len() + i + 1));
142 id_map.insert(planned_task.id, task_id.clone());
143
144 let mut blocked_by = Vec::new();
145 for dep in planned_task.depends_on {
146 if let Some(dep_id) = id_map.get(&dep) {
147 blocked_by.push(dep_id.clone());
148 }
149 }
150
151 events.push(KernelEvent::TaskPlanned(TaskRecord {
152 task_id,
153 goal_id: Some(goal_id.clone()),
154 parent_task_id: None,
155 created_at: SystemTime::now(),
156 title: planned_task.title,
157 detail: Some(planned_task.detail),
158 status: if blocked_by.is_empty() {
159 TaskStatus::Ready
160 } else {
161 TaskStatus::Blocked
162 },
163 assignee: None,
164 blocked_by,
165 }));
166 }
167
168 PlannerOutput {
169 events,
170 proposed_plan: None,
171 }
172 }
173}
174
/// Extracts the outermost `{ … }` span from an LLM reply (e.g. stripping prose or a
/// Markdown code fence around the JSON).
///
/// The span runs from the first `{` to the last `}`. If either brace is missing, or the
/// last `}` comes before the first `{`, the trimmed input is returned unchanged so the
/// caller's JSON parser reports the error instead of this function panicking.
fn extract_json(text: &str) -> String {
    let text = text.trim();
    match (text.find('{'), text.rfind('}')) {
        // `{` and `}` are ASCII, so both byte offsets are valid char boundaries.
        (Some(start), Some(end)) if start < end => text[start..=end].to_string(),
        _ => text.to_string(),
    }
}