use crate::invocation::{Invocation, InvocationError, InvocationResponse, InvocationStatus};
use crate::simulator::SimulatorPhase;
use chrono::{DateTime, Utc};
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::{Mutex, Notify};
/// Outcome of trying to record a response or an error for an invocation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecordResult {
    /// The result was stored on the invocation's tracking record.
    Recorded,
    /// The invocation is known but is not `InProgress` (e.g. it already finished),
    /// so the result was rejected.
    AlreadyCompleted,
    /// No invocation with the given request id is being tracked.
    NotFound,
}
/// Tracking record for one invocation, stored by the runtime under its request id.
#[derive(Clone, Debug)]
pub struct InvocationState {
    /// The invocation exactly as it was enqueued.
    pub invocation: Invocation,
    /// Current lifecycle status (`Pending` on enqueue, `InProgress` once dequeued, then
    /// `Success` or `Error` when a result is recorded).
    pub status: InvocationStatus,
    /// Time at which the invocation was dequeued; `None` while it is still queued.
    pub started_at: Option<DateTime<Utc>>,
    /// The successful response, once recorded.
    pub response: Option<InvocationResponse>,
    /// The reported error, once recorded.
    pub error: Option<InvocationError>,
}
/// Shared state that coordinates queued invocations, their per-request tracking records,
/// and the simulator's lifecycle phase.
#[derive(Debug)]
pub(crate) struct RuntimeState {
    /// Invocations waiting to be handed out, in FIFO order (`push_back` / `pop_front`).
    pending_invocations: Mutex<VecDeque<Invocation>>,
    /// Tracking records keyed by the invocation's `request_id`.
    invocation_states: Mutex<HashMap<String, InvocationState>>,
    /// Signalled with `notify_one` each time an invocation is enqueued.
    invocation_available: Notify,
    /// Signalled with `notify_waiters` each time a response or error is recorded.
    state_changed: Notify,
    /// Current simulator lifecycle phase.
    phase: Mutex<SimulatorPhase>,
    /// Signalled with `notify_waiters` each time `phase` is written.
    phase_changed: Notify,
    /// Error message reported during initialization, if any.
    init_error: Mutex<Option<String>>,
    /// Timestamp captured when this state was constructed.
    init_started_at: DateTime<Utc>,
    /// Flipped to `true` by `mark_init_telemetry_emitted`, whose return value tells the
    /// caller whether it had already been set.
    init_telemetry_emitted: AtomicBool,
}
impl RuntimeState {
pub fn new() -> Self {
Self {
pending_invocations: Mutex::new(VecDeque::new()),
invocation_states: Mutex::new(HashMap::new()),
invocation_available: Notify::new(),
state_changed: Notify::new(),
phase: Mutex::new(SimulatorPhase::Initializing),
phase_changed: Notify::new(),
init_error: Mutex::new(None),
init_started_at: Utc::now(),
init_telemetry_emitted: AtomicBool::new(false),
}
}
pub fn init_started_at(&self) -> DateTime<Utc> {
self.init_started_at
}
pub fn mark_init_telemetry_emitted(&self) -> bool {
self.init_telemetry_emitted.swap(true, Ordering::SeqCst)
}
pub(crate) async fn enqueue_invocation(&self, invocation: Invocation) {
let request_id = invocation.request_id.clone();
let state = InvocationState {
invocation: invocation.clone(),
status: InvocationStatus::Pending,
started_at: None,
response: None,
error: None,
};
self.invocation_states
.lock()
.await
.insert(request_id, state);
self.pending_invocations.lock().await.push_back(invocation);
self.invocation_available.notify_one();
}
pub async fn next_invocation(&self) -> Invocation {
loop {
{
let mut queue = self.pending_invocations.lock().await;
if let Some(invocation) = queue.pop_front() {
if let Some(state) = self
.invocation_states
.lock()
.await
.get_mut(&invocation.request_id)
{
state.status = InvocationStatus::InProgress;
state.started_at = Some(Utc::now());
}
return invocation;
}
}
self.invocation_available.notified().await;
}
}
pub async fn record_response(&self, response: InvocationResponse) -> RecordResult {
let mut states = self.invocation_states.lock().await;
let Some(state) = states.get_mut(&response.request_id) else {
return RecordResult::NotFound;
};
if state.status != InvocationStatus::InProgress {
return RecordResult::AlreadyCompleted;
}
state.status = InvocationStatus::Success;
state.response = Some(response);
drop(states);
self.state_changed.notify_waiters();
RecordResult::Recorded
}
pub async fn record_error(&self, error: InvocationError) -> RecordResult {
let mut states = self.invocation_states.lock().await;
let Some(state) = states.get_mut(&error.request_id) else {
return RecordResult::NotFound;
};
if state.status != InvocationStatus::InProgress {
return RecordResult::AlreadyCompleted;
}
state.status = InvocationStatus::Error;
state.error = Some(error);
drop(states);
self.state_changed.notify_waiters();
RecordResult::Recorded
}
pub async fn mark_initialized(&self) {
*self.phase.lock().await = SimulatorPhase::Ready;
self.phase_changed.notify_waiters();
}
pub async fn mark_shutting_down(&self) {
*self.phase.lock().await = SimulatorPhase::ShuttingDown;
self.phase_changed.notify_waiters();
}
pub async fn is_initialized(&self) -> bool {
matches!(
*self.phase.lock().await,
SimulatorPhase::Ready | SimulatorPhase::ShuttingDown
)
}
pub async fn get_phase(&self) -> SimulatorPhase {
*self.phase.lock().await
}
pub(crate) async fn wait_for_phase(&self, target_phase: SimulatorPhase) {
loop {
if *self.phase.lock().await == target_phase {
return;
}
self.phase_changed.notified().await;
}
}
pub async fn record_init_error(&self, error: String) {
*self.init_error.lock().await = Some(error);
}
pub async fn get_init_error(&self) -> Option<String> {
self.init_error.lock().await.clone()
}
pub(crate) async fn wait_for_state_change(&self) {
self.state_changed.notified().await;
}
pub async fn get_invocation_state(&self, request_id: &str) -> Option<InvocationState> {
self.invocation_states.lock().await.get(request_id).cloned()
}
pub async fn get_all_states(&self) -> Vec<InvocationState> {
self.invocation_states
.lock()
.await
.values()
.cloned()
.collect()
}
}
impl Default for RuntimeState {
fn default() -> Self {
Self::new()
}
}