Skip to main content

rain_engine_cognition/
lib.rs

1//! Optional planning and task-graph orchestration for RainEngine.
2//!
3//! This crate composes over `rain-engine-core`; the kernel remains useful
4//! without it.
5
6mod rag;
7mod research_planner;
8
9pub use rag::*;
10pub use research_planner::*;
11
12use async_trait::async_trait;
13use rain_engine_core::{
14    AgentStateSnapshot, AgentTrigger, GoalId, GoalRecord, GoalStatus, KernelEvent, Planner,
15    PlannerOutput, ResumeToken, TaskId, TaskRecord, TaskStatus, WakeId, WakeRequestRecord,
16};
17
/// Concurrency limit for whatever executes planned tasks.
///
/// Nothing in this file reads `max_active_tasks`; presumably an executor
/// outside this crate enforces it — TODO confirm.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ExecutorPolicy {
    /// Maximum number of tasks allowed to be active at once.
    pub max_active_tasks: usize,
}

impl Default for ExecutorPolicy {
    /// Allows four active tasks.
    fn default() -> Self {
        let max_active_tasks = 4;
        ExecutorPolicy { max_active_tasks }
    }
}
30
/// Settings that decide when work must go through human review.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReviewPolicy {
    /// Whether delegating work should require a review step.
    pub require_review_for_delegation: bool,
    /// Scope string that, when requested, marks an action as needing human approval.
    pub approval_scope: String,
}

impl Default for ReviewPolicy {
    /// Requires review for delegation and uses `scope:human_approval` as the
    /// approval scope.
    fn default() -> Self {
        ReviewPolicy {
            approval_scope: String::from("scope:human_approval"),
            require_review_for_delegation: true,
        }
    }
}

impl ReviewPolicy {
    /// Returns `true` when `required_scopes` contains this policy's
    /// `approval_scope` (exact string match).
    ///
    /// `require_review_for_delegation` is not consulted here.
    pub fn requires_human_review(&self, required_scopes: &[String]) -> bool {
        required_scopes.contains(&self.approval_scope)
    }
}
53
/// Controls whether, and how long after planning, a follow-up wake is requested.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WakePolicy {
    /// Whether a follow-up wake should be scheduled at all.
    pub schedule_follow_up: bool,
    /// Delay before the follow-up wake, in milliseconds.
    pub follow_up_ms: u64,
}

impl Default for WakePolicy {
    /// Follow-ups enabled, due thirty minutes later.
    fn default() -> Self {
        const THIRTY_MINUTES_MS: u64 = 30 * 60 * 1000;
        WakePolicy {
            schedule_follow_up: true,
            follow_up_ms: THIRTY_MINUTES_MS,
        }
    }
}
68
/// Thresholds for when an agent should pause to review or re-plan.
///
/// Nothing in this file reads these counters; the consumer presumably lives
/// elsewhere — TODO confirm.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReflectionPolicy {
    /// Completed tasks to accumulate before a review.
    pub completed_tasks_before_review: usize,
    /// Failed tasks to accumulate before re-planning.
    pub failed_tasks_before_replan: usize,
}

impl Default for ReflectionPolicy {
    /// Review after three completions; re-plan after the first failure.
    fn default() -> Self {
        let (completed_tasks_before_review, failed_tasks_before_replan) = (3, 1);
        ReflectionPolicy {
            completed_tasks_before_review,
            failed_tasks_before_replan,
        }
    }
}
83
84#[derive(Debug, Clone, PartialEq, Eq)]
85pub struct TaskRoute {
86    pub task_id: TaskId,
87    pub lane: String,
88}
89
/// Summary of how an agent kernel is configured, as labels and limits.
///
/// The string fields look like policy names (`"event"`, `"external"`,
/// `"scoped"` by default); the accepted vocabulary is not defined in this
/// file — TODO confirm against the consumer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AgentKernelProfile {
    /// Label describing when planning runs.
    pub planning_cadence: String,
    /// Maximum number of simultaneously active tasks.
    pub max_active_tasks: usize,
    /// Threshold count before reflection.
    pub reflection_threshold: usize,
    /// Label describing how wakes are produced.
    pub wake_policy: String,
    /// Label describing how human approval is applied.
    pub human_approval_policy: String,
}

impl Default for AgentKernelProfile {
    /// Event-driven planning, four active tasks, reflection threshold two,
    /// external wakes and scoped human approval.
    fn default() -> Self {
        let label = |text: &str| text.to_owned();
        AgentKernelProfile {
            planning_cadence: label("event"),
            max_active_tasks: 4,
            reflection_threshold: 2,
            wake_policy: label("external"),
            human_approval_policy: label("scoped"),
        }
    }
}
110
111#[async_trait]
112pub trait TaskRouter: Send + Sync {
113    async fn route(&self, task: &TaskRecord) -> TaskRoute;
114}
115
116pub trait ReviewPolicyDecider: Send + Sync {
117    fn review_policy(&self) -> ReviewPolicy;
118}
119
120#[derive(Debug, Clone, Default)]
121pub struct MinimalTaskGraphPlanner;
122
123#[async_trait]
124impl Planner for MinimalTaskGraphPlanner {
125    async fn plan(&self, state: &AgentStateSnapshot, trigger: &AgentTrigger) -> PlannerOutput {
126        let has_active_tasks = state.tasks.iter().any(|task| {
127            matches!(
128                task.status,
129                TaskStatus::Pending
130                    | TaskStatus::Ready
131                    | TaskStatus::Running
132                    | TaskStatus::Blocked
133                    | TaskStatus::WaitingHuman
134            )
135        });
136
137        match trigger {
138            AgentTrigger::HumanInput { content, .. } | AgentTrigger::Message { content, .. }
139                if !has_active_tasks =>
140            {
141                let goal = GoalRecord {
142                    goal_id: GoalId(format!("goal-{}", state.goals.len() + 1)),
143                    created_at: std::time::SystemTime::now(),
144                    title: content.clone(),
145                    detail: Some("created from user input".to_string()),
146                    status: GoalStatus::Active,
147                    parent_goal_id: None,
148                };
149                let task = TaskRecord {
150                    task_id: TaskId(format!("task-{}", state.tasks.len() + 1)),
151                    goal_id: Some(goal.goal_id.clone()),
152                    parent_task_id: None,
153                    created_at: std::time::SystemTime::now(),
154                    title: format!("Investigate: {}", goal.title),
155                    detail: Some("planned from new observation".to_string()),
156                    status: TaskStatus::Ready,
157                    assignee: None,
158                    blocked_by: Vec::new(),
159                };
160                let wake = follow_up_wake(Some(task.task_id.clone()));
161                PlannerOutput {
162                    events: vec![
163                        KernelEvent::GoalCreated(goal),
164                        KernelEvent::TaskPlanned(task),
165                        KernelEvent::WakeScheduled(wake),
166                    ],
167                    proposed_plan: None,
168                }
169            }
170            AgentTrigger::SystemObservation { source, .. }
171            | AgentTrigger::ExternalEvent { source, .. }
172                if !has_active_tasks =>
173            {
174                let goal = GoalRecord {
175                    goal_id: GoalId(format!("goal-{}", state.goals.len() + 1)),
176                    created_at: std::time::SystemTime::now(),
177                    title: format!("Respond to {source}"),
178                    detail: Some("created from external observation".to_string()),
179                    status: GoalStatus::Active,
180                    parent_goal_id: None,
181                };
182                let task = TaskRecord {
183                    task_id: TaskId(format!("task-{}", state.tasks.len() + 1)),
184                    goal_id: Some(goal.goal_id.clone()),
185                    parent_task_id: None,
186                    created_at: std::time::SystemTime::now(),
187                    title: format!("Triage {source}"),
188                    detail: Some("planned from system observation".to_string()),
189                    status: TaskStatus::Ready,
190                    assignee: None,
191                    blocked_by: Vec::new(),
192                };
193                let wake = follow_up_wake(Some(task.task_id.clone()));
194                PlannerOutput {
195                    events: vec![
196                        KernelEvent::GoalCreated(goal),
197                        KernelEvent::TaskPlanned(task),
198                        KernelEvent::WakeScheduled(wake),
199                    ],
200                    proposed_plan: None,
201                }
202            }
203            AgentTrigger::ScheduledWake { .. } => {
204                if let Some(task) = state.tasks.iter().find(|task| {
205                    matches!(
206                        task.status,
207                        TaskStatus::Ready | TaskStatus::Blocked | TaskStatus::WaitingHuman
208                    )
209                }) {
210                    PlannerOutput {
211                        events: vec![KernelEvent::TaskClaimed {
212                            task_id: task.task_id.clone(),
213                            claimed_at: std::time::SystemTime::now(),
214                            assignee: Some("scheduler".to_string()),
215                        }],
216                        proposed_plan: None,
217                    }
218                } else {
219                    PlannerOutput {
220                        events: Vec::new(),
221                        proposed_plan: None,
222                    }
223                }
224            }
225            _ => PlannerOutput {
226                events: Vec::new(),
227                proposed_plan: None,
228            },
229        }
230    }
231}
232
233pub fn human_review_event(
234    task_id: Option<TaskId>,
235    prompt: impl Into<String>,
236    resume_token: ResumeToken,
237) -> KernelEvent {
238    KernelEvent::HumanInputRequested {
239        task_id,
240        requested_at: std::time::SystemTime::now(),
241        prompt: prompt.into(),
242        resume_token,
243    }
244}
245
246fn follow_up_wake(task_id: Option<TaskId>) -> WakeRequestRecord {
247    let requested_at = std::time::SystemTime::now();
248    WakeRequestRecord {
249        wake_id: WakeId(format!(
250            "wake-{}",
251            requested_at
252                .duration_since(std::time::UNIX_EPOCH)
253                .unwrap_or_default()
254                .as_millis()
255        )),
256        requested_at,
257        due_at: requested_at + std::time::Duration::from_millis(WakePolicy::default().follow_up_ms),
258        reason: "follow up on active task".to_string(),
259        task_id,
260    }
261}
262
263#[derive(Debug, Clone, Default)]
264pub struct RoundRobinTaskRouter;
265
266#[async_trait]
267impl TaskRouter for RoundRobinTaskRouter {
268    async fn route(&self, task: &TaskRecord) -> TaskRoute {
269        TaskRoute {
270            task_id: task.task_id.clone(),
271            lane: if task.goal_id.is_some() {
272                "goal-backed".to_string()
273            } else {
274                "default".to_string()
275            },
276        }
277    }
278}
279
#[cfg(test)]
mod tests {
    // `AgentStateSnapshot`, `TaskRecord`, etc. already come in via `super::*`.
    use super::*;
    use rain_engine_core::AgentId;

    /// Snapshot of a fresh agent with no goals, tasks, observations or pending wake.
    fn empty_state() -> AgentStateSnapshot {
        AgentStateSnapshot {
            agent_id: AgentId("agent-1".to_string()),
            profile: None,
            goals: Vec::new(),
            tasks: Vec::new(),
            observations: Vec::new(),
            artifacts: Vec::new(),
            resources: Vec::new(),
            relationships: Vec::new(),
            pending_wake: None,
        }
    }

    /// A human-input trigger from `"user"` with no attachments.
    fn human_input(content: &str) -> AgentTrigger {
        AgentTrigger::HumanInput {
            actor_id: "user".to_string(),
            content: content.to_string(),
            attachments: Vec::new(),
        }
    }

    /// A goal-less `Ready` task, which is enough to make the agent count as busy.
    fn ready_task(id: &str) -> TaskRecord {
        TaskRecord {
            task_id: TaskId(id.to_string()),
            goal_id: None,
            parent_task_id: None,
            created_at: std::time::SystemTime::now(),
            title: "existing work".to_string(),
            detail: None,
            status: TaskStatus::Ready,
            assignee: None,
            blocked_by: Vec::new(),
        }
    }

    #[tokio::test]
    async fn planner_creates_goal_and_task_for_new_input() {
        let output = MinimalTaskGraphPlanner
            .plan(&empty_state(), &human_input("Investigate outage"))
            .await;

        assert_eq!(output.events.len(), 3);
        match &output.events[0] {
            KernelEvent::GoalCreated(goal) => assert_eq!(goal.title, "Investigate outage"),
            _ => panic!("first event should be GoalCreated"),
        }
        match &output.events[1] {
            KernelEvent::TaskPlanned(task) => {
                assert_eq!(task.title, "Investigate: Investigate outage");
                assert!(matches!(task.status, TaskStatus::Ready));
            }
            _ => panic!("second event should be TaskPlanned"),
        }
        assert!(matches!(output.events[2], KernelEvent::WakeScheduled(_)));
    }

    #[tokio::test]
    async fn planner_ignores_new_input_while_tasks_are_active() {
        let state = AgentStateSnapshot {
            tasks: vec![ready_task("task-1")],
            ..empty_state()
        };
        let output = MinimalTaskGraphPlanner
            .plan(&state, &human_input("Another request"))
            .await;

        assert!(output.events.is_empty());
        assert!(output.proposed_plan.is_none());
    }
}