Skip to main content

loong_kernel/
task_supervisor.rs

1use crate::{
2    contracts::{CapabilityToken, TaskIntent},
3    kernel::{KernelDispatch, LoongKernel},
4    policy::PolicyEngine,
5};
6use loong_contracts::{Fault, TaskState};
7
8/// Opt-in wrapper around `execute_task` that enforces FSM transitions.
9pub struct TaskSupervisor {
10    state: TaskState,
11}
12
13impl TaskSupervisor {
14    pub fn new(intent: TaskIntent) -> Self {
15        Self {
16            state: TaskState::Runnable(intent),
17        }
18    }
19
20    pub fn state(&self) -> &TaskState {
21        &self.state
22    }
23
24    pub fn is_runnable(&self) -> bool {
25        matches!(self.state, TaskState::Runnable(_))
26    }
27
28    /// Clone the current state before attempting a guarded transition so
29    /// rejected transitions leave the supervisor unchanged.
30    fn take_state(&self) -> TaskState {
31        self.state.clone()
32    }
33
34    /// Execute the task through the kernel, tracking state transitions.
35    pub async fn execute<P: PolicyEngine>(
36        &mut self,
37        kernel: &LoongKernel<P>,
38        pack_id: &str,
39        token: &CapabilityToken,
40    ) -> Result<KernelDispatch, Fault> {
41        // Clone the intent before transitioning, since we need it for the
42        // kernel call and transition_to_in_send consumes it.
43        let intent = match &self.state {
44            TaskState::Runnable(intent) => intent.clone(),
45            TaskState::InSend { .. }
46            | TaskState::InReply { .. }
47            | TaskState::Completed(_)
48            | TaskState::Faulted(_) => {
49                return Err(Fault::ProtocolViolation {
50                    detail: "task is not in Runnable state".to_owned(),
51                });
52            }
53        };
54
55        // Runnable -> InSend (guarded transition)
56        let taken = self.take_state();
57        self.state = taken
58            .transition_to_in_send()
59            .map_err(|detail| Fault::ProtocolViolation { detail })?;
60
61        // InSend -> InReply (guarded transition)
62        let taken = self.take_state();
63        self.state = taken
64            .transition_to_in_reply()
65            .map_err(|detail| Fault::ProtocolViolation { detail })?;
66
67        // Execute through kernel
68        match kernel.execute_task(pack_id, token, intent).await {
69            Ok(dispatch) => {
70                // InReply -> Completed (guarded transition)
71                let taken = self.take_state();
72                self.state = taken
73                    .transition_to_completed(dispatch.outcome.clone())
74                    .map_err(|detail| Fault::ProtocolViolation { detail })?;
75                Ok(dispatch)
76            }
77            Err(kernel_err) => {
78                let fault = Fault::from_kernel_error(kernel_err);
79                // Any non-terminal -> Faulted
80                let taken = self.take_state();
81                self.state = taken.transition_to_faulted(fault.clone());
82                Err(fault)
83            }
84        }
85    }
86
87    /// Force state -- for testing only.
88    #[cfg(test)]
89    pub fn force_state(&mut self, state: TaskState) {
90        self.state = state;
91    }
92}
93
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use serde_json::json;

    use super::TaskSupervisor;
    use crate::contracts::{Capability, TaskIntent};
    use loong_contracts::{Fault, TaskState};

    /// Minimal intent whose `task_id` the assertions below look for.
    fn sample_intent() -> TaskIntent {
        let required_capabilities: BTreeSet<Capability> =
            [Capability::InvokeTool].into_iter().collect();
        TaskIntent {
            task_id: String::from("supervised-guard"),
            objective: String::from("exercise guarded transition"),
            required_capabilities,
            payload: json!({}),
        }
    }

    #[test]
    fn take_state_does_not_poison_supervisor_when_transition_is_rejected() {
        let supervisor = TaskSupervisor::new(sample_intent());

        // Runnable -> InReply skips InSend, so the guard has to refuse it.
        let snapshot = supervisor.take_state();
        let rejection = snapshot
            .transition_to_in_reply()
            .expect_err("Runnable cannot transition directly to InReply");
        assert!(rejection.contains("cannot move to InReply"));

        // The rejected attempt operated on a clone, so the live state is intact.
        let current = supervisor.state();
        let still_runnable = matches!(
            current,
            TaskState::Runnable(intent) if intent.task_id == "supervised-guard"
        );
        assert!(still_runnable);

        let poisoned = matches!(current, TaskState::Faulted(Fault::ProtocolViolation { .. }));
        assert!(!poisoned);
    }
}