use ainl_contracts::{Mission, MissionId, MissionState};
use thiserror::Error;
use crate::state_machine::{transition, StateMachineError};
pub trait MissionLifecycle {
fn load_mission(&self, mission_id: &MissionId) -> Result<Mission, MissionLifecycleError>;
fn save_mission(&self, mission: &Mission) -> Result<(), MissionLifecycleError>;
fn set_state(
&self,
mission_id: &MissionId,
new_state: MissionState,
) -> Result<MissionState, MissionLifecycleError> {
let mut mission = self.load_mission(mission_id)?;
let next = transition(mission.state, new_state).map_err(MissionLifecycleError::State)?;
mission.state = next;
self.save_mission(&mission)?;
Ok(next)
}
fn initialize(&self, mission_id: &MissionId) -> Result<MissionState, MissionLifecycleError> {
self.set_state(mission_id, MissionState::Initializing)
}
fn start_running(&self, mission_id: &MissionId) -> Result<MissionState, MissionLifecycleError> {
let mission = self.load_mission(mission_id)?;
let target = match mission.state {
MissionState::Initializing | MissionState::Paused | MissionState::OrchestratorTurn => {
MissionState::Running
}
other => other,
};
self.set_state(mission_id, target)
}
fn pause(&self, mission_id: &MissionId) -> Result<MissionState, MissionLifecycleError> {
self.set_state(mission_id, MissionState::Paused)
}
fn resume(&self, mission_id: &MissionId) -> Result<MissionState, MissionLifecycleError> {
self.set_state(mission_id, MissionState::Running)
}
fn enter_orchestrator_turn(
&self,
mission_id: &MissionId,
) -> Result<MissionState, MissionLifecycleError> {
self.set_state(mission_id, MissionState::OrchestratorTurn)
}
fn complete(&self, mission_id: &MissionId) -> Result<MissionState, MissionLifecycleError> {
self.set_state(mission_id, MissionState::Completed)
}
fn cancel(&self, mission_id: &MissionId) -> Result<MissionState, MissionLifecycleError> {
self.set_state(mission_id, MissionState::Cancelled)
}
}
/// Errors returned by [`MissionLifecycle`] operations.
#[derive(Debug, Error)]
pub enum MissionLifecycleError {
/// The state machine (`crate::state_machine::transition`) rejected a requested move.
///
/// `#[from]` lets `?` convert a [`StateMachineError`] into this variant.
/// NOTE(review): `#[from]` also exposes the inner error as `source()`, while the
/// Display text embeds it via `{0}`. Reporters that print the whole error chain
/// may therefore show the message twice.
#[error("state machine: {0}")]
State(#[from] StateMachineError),
/// No mission matched the requested id. The in-crate test store fills the
/// payload with the id text; other implementors may differ.
#[error("not found: {0}")]
NotFound(String),
/// NOTE(review): presumably raised by implementors when storing a mission
/// fails; not constructed anywhere in this file.
#[error("persist: {0}")]
Persist(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use ainl_contracts::MissionCapabilityFlags;
    use chrono::Utc;
    use std::cell::RefCell;
    use std::collections::HashMap;

    /// Id of the single mission seeded by [`MemLifecycle::with_mission`].
    const SEEDED_ID: &str = "m1";

    /// Convenience constructor for the seeded mission's id.
    fn seeded_id() -> MissionId {
        MissionId(SEEDED_ID.into())
    }

    /// In-memory `MissionLifecycle` keyed on the mission id string.
    ///
    /// `RefCell` provides the interior mutability `save_mission(&self, ..)` needs.
    struct MemLifecycle {
        missions: RefCell<HashMap<String, Mission>>,
    }

    impl MemLifecycle {
        /// Builds a store holding exactly one mission (`SEEDED_ID`) in `state`.
        fn with_mission(state: MissionState) -> Self {
            let m = Mission {
                mission_id: seeded_id(),
                objective_md: "test".into(),
                state,
                milestone_ids: vec![],
                mission_root: None,
                created_at: Utc::now(),
                last_orchestrator_turn_at: None,
                capability_flags: MissionCapabilityFlags::default(),
            };
            let mut map = HashMap::new();
            map.insert(SEEDED_ID.into(), m);
            Self {
                missions: RefCell::new(map),
            }
        }
    }

    impl MissionLifecycle for MemLifecycle {
        fn load_mission(&self, mission_id: &MissionId) -> Result<Mission, MissionLifecycleError> {
            self.missions
                .borrow()
                .get(mission_id.as_str())
                .cloned()
                .ok_or_else(|| MissionLifecycleError::NotFound(mission_id.as_str().into()))
        }

        fn save_mission(&self, mission: &Mission) -> Result<(), MissionLifecycleError> {
            self.missions
                .borrow_mut()
                .insert(mission.mission_id.as_str().into(), mission.clone());
            Ok(())
        }
    }

    #[test]
    fn initialize_then_run() {
        let lc = MemLifecycle::with_mission(MissionState::AwaitingInput);
        let id = seeded_id();
        assert_eq!(lc.initialize(&id).unwrap(), MissionState::Initializing);
        assert_eq!(lc.start_running(&id).unwrap(), MissionState::Running);
    }

    #[test]
    fn pause_resume() {
        let lc = MemLifecycle::with_mission(MissionState::Running);
        let id = seeded_id();
        assert_eq!(lc.pause(&id).unwrap(), MissionState::Paused);
        assert_eq!(lc.resume(&id).unwrap(), MissionState::Running);
    }

    /// `start_running` must map `Paused` to `Running`, like `resume` does.
    #[test]
    fn start_running_from_paused() {
        let lc = MemLifecycle::with_mission(MissionState::Paused);
        let id = seeded_id();
        assert_eq!(lc.start_running(&id).unwrap(), MissionState::Running);
    }

    /// The state returned by `set_state` must also be what ends up saved.
    #[test]
    fn state_change_is_persisted() {
        let lc = MemLifecycle::with_mission(MissionState::Running);
        let id = seeded_id();
        lc.pause(&id).unwrap();
        assert_eq!(lc.load_mission(&id).unwrap().state, MissionState::Paused);
    }

    /// Lookup failures must surface from the default methods unchanged, and
    /// must not disturb other stored missions.
    #[test]
    fn missing_mission_reports_not_found() {
        let lc = MemLifecycle::with_mission(MissionState::Running);
        let missing = MissionId("nope".into());
        match lc.pause(&missing) {
            Err(MissionLifecycleError::NotFound(id)) => assert_eq!(id, "nope"),
            other => panic!("expected NotFound, got {:?}", other),
        }
        assert_eq!(
            lc.load_mission(&seeded_id()).unwrap().state,
            MissionState::Running
        );
    }
}