loong_kernel/
task_supervisor.rs1use crate::{
2 contracts::{CapabilityToken, TaskIntent},
3 kernel::{KernelDispatch, LoongKernel},
4 policy::PolicyEngine,
5};
6use loong_contracts::{Fault, TaskState};
7
8pub struct TaskSupervisor {
10 state: TaskState,
11}
12
13impl TaskSupervisor {
14 pub fn new(intent: TaskIntent) -> Self {
15 Self {
16 state: TaskState::Runnable(intent),
17 }
18 }
19
20 pub fn state(&self) -> &TaskState {
21 &self.state
22 }
23
24 pub fn is_runnable(&self) -> bool {
25 matches!(self.state, TaskState::Runnable(_))
26 }
27
28 fn take_state(&self) -> TaskState {
31 self.state.clone()
32 }
33
34 pub async fn execute<P: PolicyEngine>(
36 &mut self,
37 kernel: &LoongKernel<P>,
38 pack_id: &str,
39 token: &CapabilityToken,
40 ) -> Result<KernelDispatch, Fault> {
41 let intent = match &self.state {
44 TaskState::Runnable(intent) => intent.clone(),
45 TaskState::InSend { .. }
46 | TaskState::InReply { .. }
47 | TaskState::Completed(_)
48 | TaskState::Faulted(_) => {
49 return Err(Fault::ProtocolViolation {
50 detail: "task is not in Runnable state".to_owned(),
51 });
52 }
53 };
54
55 let taken = self.take_state();
57 self.state = taken
58 .transition_to_in_send()
59 .map_err(|detail| Fault::ProtocolViolation { detail })?;
60
61 let taken = self.take_state();
63 self.state = taken
64 .transition_to_in_reply()
65 .map_err(|detail| Fault::ProtocolViolation { detail })?;
66
67 match kernel.execute_task(pack_id, token, intent).await {
69 Ok(dispatch) => {
70 let taken = self.take_state();
72 self.state = taken
73 .transition_to_completed(dispatch.outcome.clone())
74 .map_err(|detail| Fault::ProtocolViolation { detail })?;
75 Ok(dispatch)
76 }
77 Err(kernel_err) => {
78 let fault = Fault::from_kernel_error(kernel_err);
79 let taken = self.take_state();
81 self.state = taken.transition_to_faulted(fault.clone());
82 Err(fault)
83 }
84 }
85 }
86
87 #[cfg(test)]
89 pub fn force_state(&mut self, state: TaskState) {
90 self.state = state;
91 }
92}
93
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use serde_json::json;

    use super::TaskSupervisor;
    use crate::contracts::{Capability, TaskIntent};
    use loong_contracts::{Fault, TaskState};

    /// Builds a minimal intent that requires only the `InvokeTool` capability.
    fn sample_intent() -> TaskIntent {
        let required_capabilities = BTreeSet::from([Capability::InvokeTool]);
        TaskIntent {
            task_id: "supervised-guard".to_owned(),
            objective: "exercise guarded transition".to_owned(),
            required_capabilities,
            payload: json!({}),
        }
    }

    /// A rejected transition applied to the copy returned by `take_state`
    /// must leave the supervisor's own state untouched.
    #[test]
    fn take_state_does_not_poison_supervisor_when_transition_is_rejected() {
        let supervisor = TaskSupervisor::new(sample_intent());

        // Jumping straight from Runnable to InReply skips InSend and must be refused.
        let snapshot = supervisor.take_state();
        let rejection = snapshot
            .transition_to_in_reply()
            .expect_err("Runnable cannot transition directly to InReply");
        assert!(rejection.contains("cannot move to InReply"));

        // The supervisor still holds the original Runnable intent and was not faulted.
        let current = supervisor.state();
        assert!(matches!(
            current,
            TaskState::Runnable(intent) if intent.task_id == "supervised-guard"
        ));
        assert!(!matches!(
            current,
            TaskState::Faulted(Fault::ProtocolViolation { .. })
        ));
    }
}
132}