use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Notify, RwLock};
use tracing::{debug, info, warn};
/// Wellbeing level most recently reported for the agent.
///
/// Serialized in `snake_case` (for example `"distressed"`). The default is
/// [`WellbeingState::Healthy`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WellbeingState {
    /// Default state; never triggers a pause.
    #[default]
    Healthy,
    /// Never triggers a pause on its own.
    Cautious,
    /// Triggers a pause when `pause_on_concerned` is enabled.
    Concerned,
    /// Triggers a pause when `pause_on_distressed` is enabled.
    Distressed,
}
impl std::fmt::Display for WellbeingState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Healthy => write!(f, "healthy"),
Self::Cautious => write!(f, "cautious"),
Self::Concerned => write!(f, "concerned"),
Self::Distressed => write!(f, "distressed"),
}
}
}
/// Settings for [`InterventionController`].
///
/// Any field missing during deserialization falls back to its default. The defaults are:
/// `enabled = true`, both `pause_on_*` flags `false`, `max_pause_secs = 300` and
/// `wait_timeout_secs = 60`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct InterventionConfig {
    /// Master switch. When `false`, `update_state` ignores updates and `gate_request` never
    /// blocks.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Pause inference when the state becomes `Distressed`.
    #[serde(default)]
    pub pause_on_distressed: bool,
    /// Pause inference when the state becomes `Concerned`.
    #[serde(default)]
    pub pause_on_concerned: bool,
    /// Pause length in whole seconds after which the pause is lifted automatically.
    #[serde(default = "default_max_pause")]
    pub max_pause_secs: u64,
    /// Seconds a gated request waits for a pause to clear before it fails.
    #[serde(default = "default_wait_timeout")]
    pub wait_timeout_secs: u64,
}
/// Serde default for `InterventionConfig::enabled`.
fn default_true() -> bool {
    true
}

/// Serde default for `InterventionConfig::max_pause_secs`: five minutes.
fn default_max_pause() -> u64 {
    5 * 60
}

/// Serde default for `InterventionConfig::wait_timeout_secs`: one minute.
fn default_wait_timeout() -> u64 {
    60
}
impl Default for InterventionConfig {
fn default() -> Self {
Self {
enabled: true,
pause_on_distressed: false, pause_on_concerned: false,
max_pause_secs: 300,
wait_timeout_secs: 60,
}
}
}
/// Point-in-time snapshot of the controller's counters, as returned by
/// `InterventionController::metrics`.
#[derive(Default, Debug, Clone)]
pub struct InterventionMetrics {
    /// Wellbeing state changes observed by `update_state`.
    pub interventions_total: u64,
    /// Transitions from running to paused.
    pub pauses_total: u64,
    /// Transitions from paused to running.
    pub resumes_total: u64,
    /// Gated requests that found inference paused and had to wait.
    pub requests_blocked: u64,
    /// Blocked requests that were eventually let through.
    pub requests_waited: u64,
    /// Blocked requests that gave up after the wait timeout.
    pub requests_timed_out: u64,
    /// Whole seconds the current pause has lasted; 0 when no pause is active.
    pub current_pause_duration_secs: u64,
}
/// Pauses and resumes inference in response to wellbeing state changes, and gates incoming
/// requests while a pause is active.
pub struct InterventionController {
    /// Immutable settings supplied at construction.
    config: InterventionConfig,
    /// Most recently recorded wellbeing state.
    state: RwLock<WellbeingState>,
    /// Whether inference is currently paused; flipped atomically by `pause`/`resume`.
    paused: AtomicBool,
    /// When the current pause began; used for auto-resume and metrics.
    pause_start: RwLock<Option<Instant>>,
    /// Human-readable reason for the current pause.
    pause_reason: RwLock<Option<String>>,
    /// Signalled by `resume` to wake requests waiting in `gate_request`.
    unpause_notify: Notify,
    // Monotonic counters surfaced through `metrics()`.
    interventions_total: AtomicU64,
    pauses_total: AtomicU64,
    resumes_total: AtomicU64,
    requests_blocked: AtomicU64,
    requests_waited: AtomicU64,
    requests_timed_out: AtomicU64,
}
impl InterventionController {
#[must_use]
pub fn new(config: InterventionConfig) -> Self {
info!(
"Wellbeing intervention controller initialized (enabled={}, pause_on_distressed={})",
config.enabled, config.pause_on_distressed
);
Self {
config,
state: RwLock::new(WellbeingState::Healthy),
paused: AtomicBool::new(false),
pause_start: RwLock::new(None),
pause_reason: RwLock::new(None),
unpause_notify: Notify::new(),
interventions_total: AtomicU64::new(0),
pauses_total: AtomicU64::new(0),
resumes_total: AtomicU64::new(0),
requests_blocked: AtomicU64::new(0),
requests_waited: AtomicU64::new(0),
requests_timed_out: AtomicU64::new(0),
}
}
pub async fn update_state(&self, state: WellbeingState, reason: Option<String>) {
if !self.config.enabled {
return;
}
let old_state = {
let mut s = self.state.write().await;
let old = *s;
*s = state;
old
};
if old_state != state {
info!("Wellbeing state changed: {} -> {}", old_state, state);
self.interventions_total.fetch_add(1, Ordering::Relaxed);
}
let should_pause = match state {
WellbeingState::Distressed => self.config.pause_on_distressed,
WellbeingState::Concerned => self.config.pause_on_concerned,
_ => false,
};
if should_pause && !self.is_paused() {
let reason = reason.unwrap_or_else(|| format!("Agent entered {} state", state));
self.pause(&reason).await;
} else if !should_pause && self.is_paused() {
self.resume().await;
}
}
pub async fn current_state(&self) -> WellbeingState {
*self.state.read().await
}
#[must_use]
pub fn is_paused(&self) -> bool {
self.paused.load(Ordering::SeqCst)
}
pub async fn pause(&self, reason: &str) {
if self.paused.swap(true, Ordering::SeqCst) {
return;
}
*self.pause_start.write().await = Some(Instant::now());
*self.pause_reason.write().await = Some(reason.to_string());
self.pauses_total.fetch_add(1, Ordering::Relaxed);
warn!("Inference paused due to wellbeing intervention: {}", reason);
}
pub async fn resume(&self) {
if !self.paused.swap(false, Ordering::SeqCst) {
return;
}
let duration = {
let start = self.pause_start.read().await;
start.map(|s| s.elapsed())
};
*self.pause_start.write().await = None;
*self.pause_reason.write().await = None;
self.resumes_total.fetch_add(1, Ordering::Relaxed);
self.unpause_notify.notify_waiters();
if let Some(d) = duration {
info!("Inference resumed after {:.1}s pause", d.as_secs_f32());
} else {
info!("Inference resumed");
}
}
pub async fn gate_request(&self) -> Result<(), InterventionError> {
if !self.config.enabled {
return Ok(());
}
if !self.is_paused() {
return Ok(());
}
self.check_auto_resume().await;
if !self.is_paused() {
return Ok(());
}
self.requests_blocked.fetch_add(1, Ordering::Relaxed);
debug!("Request blocked - waiting for wellbeing intervention to clear");
let timeout = Duration::from_secs(self.config.wait_timeout_secs);
tokio::select! {
_ = self.unpause_notify.notified() => {
self.requests_waited.fetch_add(1, Ordering::Relaxed);
debug!("Request proceeding after wait");
Ok(())
}
_ = tokio::time::sleep(timeout) => {
self.requests_timed_out.fetch_add(1, Ordering::Relaxed);
let reason = self.pause_reason.read().await.clone()
.unwrap_or_else(|| "Unknown".to_string());
Err(InterventionError::Paused { reason, timeout_secs: self.config.wait_timeout_secs })
}
}
}
async fn check_auto_resume(&self) {
let should_resume = {
let start = self.pause_start.read().await;
start.map_or(false, |s| {
s.elapsed().as_secs() >= self.config.max_pause_secs
})
};
if should_resume {
warn!(
"Auto-resuming inference after max pause duration ({}s)",
self.config.max_pause_secs
);
self.resume().await;
}
}
pub async fn metrics(&self) -> InterventionMetrics {
let current_pause_duration_secs = {
let start = self.pause_start.read().await;
start.map_or(0, |s| s.elapsed().as_secs())
};
InterventionMetrics {
interventions_total: self.interventions_total.load(Ordering::Relaxed),
pauses_total: self.pauses_total.load(Ordering::Relaxed),
resumes_total: self.resumes_total.load(Ordering::Relaxed),
requests_blocked: self.requests_blocked.load(Ordering::Relaxed),
requests_waited: self.requests_waited.load(Ordering::Relaxed),
requests_timed_out: self.requests_timed_out.load(Ordering::Relaxed),
current_pause_duration_secs,
}
}
pub async fn pause_reason(&self) -> Option<String> {
self.pause_reason.read().await.clone()
}
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum InterventionError {
#[error(
"inference paused due to wellbeing intervention: {reason} (timeout after {timeout_secs}s)"
)]
Paused {
reason: String,
timeout_secs: u64,
},
}
/// Reference-counted handle for sharing one controller across tasks.
pub type SharedInterventionController = Arc<InterventionController>;

/// Builds an [`InterventionController`] from `config` and wraps it in an `Arc`.
#[must_use]
pub fn create_intervention_controller(config: InterventionConfig) -> SharedInterventionController {
    let controller = InterventionController::new(config);
    controller.into()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_default_not_paused() {
        let controller = InterventionController::new(InterventionConfig::default());
        assert!(!controller.is_paused());
        assert!(controller.gate_request().await.is_ok());
    }

    #[tokio::test]
    async fn test_manual_pause_resume() {
        let controller = InterventionController::new(InterventionConfig::default());
        controller.pause("test pause").await;
        assert!(controller.is_paused());
        controller.resume().await;
        assert!(!controller.is_paused());
    }

    #[tokio::test]
    async fn test_auto_pause_on_distressed() {
        let config = InterventionConfig {
            enabled: true,
            pause_on_distressed: true,
            ..Default::default()
        };
        let controller = InterventionController::new(config);
        controller
            .update_state(WellbeingState::Distressed, None)
            .await;
        assert!(controller.is_paused());
        controller.update_state(WellbeingState::Healthy, None).await;
        assert!(!controller.is_paused());
    }

    #[tokio::test]
    async fn test_metrics() {
        let config = InterventionConfig {
            enabled: true,
            pause_on_distressed: true,
            ..Default::default()
        };
        let controller = InterventionController::new(config);
        controller
            .update_state(WellbeingState::Distressed, None)
            .await;
        controller.update_state(WellbeingState::Healthy, None).await;
        let metrics = controller.metrics().await;
        assert_eq!(metrics.pauses_total, 1);
        assert_eq!(metrics.resumes_total, 1);
        assert!(metrics.interventions_total >= 2);
    }

    /// The recorded reason is visible while paused and cleared on resume.
    #[tokio::test]
    async fn test_pause_reason_lifecycle() {
        let controller = InterventionController::new(InterventionConfig::default());
        controller.pause("test pause").await;
        assert_eq!(controller.pause_reason().await.as_deref(), Some("test pause"));
        controller.resume().await;
        assert_eq!(controller.pause_reason().await, None);
    }

    /// A disabled controller ignores state updates and never blocks requests.
    #[tokio::test]
    async fn test_disabled_never_blocks() {
        let config = InterventionConfig {
            enabled: false,
            pause_on_distressed: true,
            ..Default::default()
        };
        let controller = InterventionController::new(config);
        controller
            .update_state(WellbeingState::Distressed, None)
            .await;
        assert!(!controller.is_paused());
        assert_eq!(controller.current_state().await, WellbeingState::Healthy);
        controller.pause("manual").await;
        assert!(controller.gate_request().await.is_ok());
    }

    /// With a zero wait timeout, a paused gate fails right away and reports the reason.
    #[tokio::test]
    async fn test_gate_times_out_while_paused() {
        let config = InterventionConfig {
            wait_timeout_secs: 0,
            ..Default::default()
        };
        let controller = InterventionController::new(config);
        controller.pause("test pause").await;
        match controller.gate_request().await {
            Err(InterventionError::Paused {
                reason,
                timeout_secs,
            }) => {
                assert_eq!(reason, "test pause");
                assert_eq!(timeout_secs, 0);
            }
            Ok(()) => panic!("expected the gate to time out while paused"),
        }
        let metrics = controller.metrics().await;
        assert_eq!(metrics.requests_blocked, 1);
        assert_eq!(metrics.requests_timed_out, 1);
    }

    /// Once `max_pause_secs` has elapsed, the gate lifts the pause itself.
    #[tokio::test]
    async fn test_gate_auto_resumes_after_max_pause() {
        let config = InterventionConfig {
            max_pause_secs: 0,
            ..Default::default()
        };
        let controller = InterventionController::new(config);
        controller.pause("test pause").await;
        assert!(controller.gate_request().await.is_ok());
        assert!(!controller.is_paused());
        assert_eq!(controller.metrics().await.resumes_total, 1);
    }

    /// A blocked request is released when another task resumes inference.
    #[tokio::test]
    async fn test_gate_waits_for_resume() {
        let controller = InterventionController::new(InterventionConfig::default());
        controller.pause("test pause").await;
        let (gate, ()) = tokio::join!(controller.gate_request(), async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            controller.resume().await;
        });
        assert!(gate.is_ok());
        let metrics = controller.metrics().await;
        assert_eq!(metrics.requests_blocked, 1);
        assert_eq!(metrics.requests_waited, 1);
        assert_eq!(metrics.requests_timed_out, 0);
    }
}