Skip to main content

rain_engine_cognition/
research_planner.rs

1use async_trait::async_trait;
2use rain_engine_core::{
3    AgentAction, AgentStateSnapshot, AgentTrigger, GoalId, GoalRecord, GoalStatus, KernelEvent,
4    LlmProvider, Planner, PlannerOutput, ProviderContentPart, ProviderMessage, ProviderRequest,
5    ProviderRole, TaskId, TaskRecord, TaskStatus,
6};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::SystemTime;
11
12pub struct ResearchPlanner {
13    llm: Arc<dyn LlmProvider>,
14}
15
16impl ResearchPlanner {
17    pub fn new(llm: Arc<dyn LlmProvider>) -> Self {
18        Self { llm }
19    }
20}
21
22#[derive(Debug, Serialize, Deserialize)]
23struct ResearchPlan {
24    goal_title: String,
25    goal_detail: String,
26    tasks: Vec<PlannedTask>,
27}
28
29#[derive(Debug, Serialize, Deserialize)]
30struct PlannedTask {
31    id: String,
32    title: String,
33    detail: String,
34    depends_on: Vec<String>,
35}
36
37#[async_trait]
38impl Planner for ResearchPlanner {
39    async fn plan(&self, state: &AgentStateSnapshot, trigger: &AgentTrigger) -> PlannerOutput {
40        // Only plan if we have no active tasks and this is a new message/input
41        let has_active_tasks = state.tasks.iter().any(|task| {
42            matches!(
43                task.status,
44                TaskStatus::Pending
45                    | TaskStatus::Ready
46                    | TaskStatus::Running
47                    | TaskStatus::Blocked
48                    | TaskStatus::WaitingHuman
49            )
50        });
51
52        if has_active_tasks {
53            return PlannerOutput::default();
54        }
55
56        let content = match trigger {
57            AgentTrigger::HumanInput { content, .. } | AgentTrigger::Message { content, .. } => {
58                content
59            }
60            _ => return PlannerOutput::default(),
61        };
62
63        let system_prompt = r#"You are a Research Architect. Your job is to decompose a high-level research goal into a directed acyclic graph (DAG) of specific tasks.
64Output your plan as a single JSON object with this structure:
65{
66  "goal_title": "Short title",
67  "goal_detail": "Detailed explanation",
68  "tasks": [
69    {
70      "id": "task-1",
71      "title": "Task title",
72      "detail": "What to do",
73      "depends_on": []
74    },
75    ...
76  ]
77}
78Each task should be atomic. Use "depends_on" to specify which task IDs must be completed before a task can start."#;
79
80        let request = ProviderRequest {
81            trigger: trigger.clone(),
82            context: rain_engine_core::AgentContextSnapshot {
83                session_id: state.agent_id.0.clone(),
84                granted_scopes: vec!["scope:research".to_string()],
85                trigger_id: "planning-trigger".to_string(),
86                idempotency_key: None,
87                current_step: 0,
88                max_steps: 1,
89                history: Vec::new(),
90                prior_tool_results: Vec::new(),
91                session_cost_usd: 0.0,
92                state: state.clone(),
93                policy: Default::default(),
94                active_execution_plan: None,
95            },
96            available_skills: Vec::new(),
97            config: Default::default(),
98            policy: Default::default(),
99            contents: vec![
100                ProviderMessage {
101                    role: ProviderRole::System,
102                    parts: vec![ProviderContentPart::Text(system_prompt.to_string())],
103                },
104                ProviderMessage {
105                    role: ProviderRole::User,
106                    parts: vec![ProviderContentPart::Text(format!("Goal: {}", content))],
107                },
108            ],
109        };
110
111        let decision = match self.llm.generate_action(request).await {
112            Ok(d) => d,
113            Err(_) => return PlannerOutput::default(),
114        };
115
116        let response_text = match decision.action {
117            AgentAction::Respond { content } => content,
118            _ => return PlannerOutput::default(),
119        };
120
121        // Attempt to parse the JSON plan
122        let plan: ResearchPlan = match serde_json::from_str(&extract_json(&response_text)) {
123            Ok(p) => p,
124            Err(_) => return PlannerOutput::default(),
125        };
126
127        let mut events = Vec::new();
128        let goal_id = GoalId(format!("goal-{}", state.goals.len() + 1));
129
130        events.push(KernelEvent::GoalCreated(GoalRecord {
131            goal_id: goal_id.clone(),
132            created_at: SystemTime::now(),
133            title: plan.goal_title,
134            detail: Some(plan.goal_detail),
135            status: GoalStatus::Active,
136            parent_goal_id: None,
137        }));
138
139        let mut id_map = HashMap::new();
140        for (i, planned_task) in plan.tasks.into_iter().enumerate() {
141            let task_id = TaskId(format!("task-{}", state.tasks.len() + i + 1));
142            id_map.insert(planned_task.id, task_id.clone());
143
144            let mut blocked_by = Vec::new();
145            for dep in planned_task.depends_on {
146                if let Some(dep_id) = id_map.get(&dep) {
147                    blocked_by.push(dep_id.clone());
148                }
149            }
150
151            events.push(KernelEvent::TaskPlanned(TaskRecord {
152                task_id,
153                goal_id: Some(goal_id.clone()),
154                parent_task_id: None,
155                created_at: SystemTime::now(),
156                title: planned_task.title,
157                detail: Some(planned_task.detail),
158                status: if blocked_by.is_empty() {
159                    TaskStatus::Ready
160                } else {
161                    TaskStatus::Blocked
162                },
163                assignee: None,
164                blocked_by,
165            }));
166        }
167
168        PlannerOutput {
169            events,
170            proposed_plan: None,
171        }
172    }
173}
174
/// Returns the substring of `text` spanning from the first `{` to the last
/// `}` (inclusive), so surrounding prose or code fences are stripped.
///
/// If `text` has no `{`, no `}`, or the last `}` comes before the first `{`,
/// the trimmed input is returned unchanged and the caller's JSON parser is
/// left to reject it.
fn extract_json(text: &str) -> String {
    let text = text.trim();
    match (text.find('{'), text.rfind('}')) {
        // `start < end` guards the slice: with reversed braces (e.g. "} {")
        // the inclusive range would be inverted and slicing would panic.
        (Some(start), Some(end)) if start < end => text[start..=end].to_string(),
        _ => text.to_string(),
    }
}