use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::broadcast;
use aa_runtime::approval::ApprovalRequestId;
/// A single pending escalation in the form written to the escalation file.
///
/// The scheduler fires it once the current Unix time (in whole seconds)
/// reaches `escalate_at`.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct PersistedEscalation {
    /// Approval request this escalation is attached to; also its map key.
    pub request_id: ApprovalRequestId,
    /// Team identifier, copied into the fired [`EscalationEvent`].
    pub team_id: String,
    /// Approvers copied into the fired [`EscalationEvent`].
    pub escalation_approvers: Vec<String>,
    /// Unix-epoch second at or after which the escalation is due.
    pub escalate_at: u64,
}
/// Top-level document stored in the escalation JSON file.
#[derive(Default, Debug, serde::Deserialize, serde::Serialize)]
struct PersistedEscalations {
    /// Every escalation that has been registered and not yet cancelled or fired.
    pending: Vec<PersistedEscalation>,
}
/// Notification broadcast to subscribers when a pending escalation comes due.
#[derive(Clone, Debug)]
pub struct EscalationEvent {
    /// Approval request whose escalation fired.
    pub request_id: ApprovalRequestId,
    /// Team identifier recorded at registration time.
    pub team_id: String,
    /// Approvers recorded at registration time.
    pub escalation_approvers: Vec<String>,
}
/// Tracks approval requests that should escalate after a timeout.
///
/// The pending set is mirrored to a JSON file, and an [`EscalationEvent`]
/// is broadcast for each entry once it becomes due.
pub struct EscalationScheduler {
    /// Pending escalations keyed by request id, behind a shared mutex.
    state: Arc<Mutex<HashMap<ApprovalRequestId, PersistedEscalation>>>,
    /// Sender on which fired escalations are broadcast.
    event_tx: broadcast::Sender<EscalationEvent>,
    /// Delay between background polls performed by `run`.
    poll_interval: Duration,
    /// File the pending set is persisted to.
    path: PathBuf,
}
impl EscalationScheduler {
    /// Creates a scheduler backed by the JSON file at `path`.
    ///
    /// Escalations previously persisted at `path` are loaded into memory; a
    /// missing file starts the scheduler empty. Fired escalations are sent on
    /// `event_tx`, and [`run`](Self::run) polls every `poll_interval`.
    ///
    /// # Errors
    /// Returns [`EscalationError`] if an existing file cannot be read or parsed.
    pub fn new(
        path: impl Into<PathBuf>,
        event_tx: broadcast::Sender<EscalationEvent>,
        poll_interval: Duration,
    ) -> Result<Self, EscalationError> {
        let path = path.into();
        let initial: HashMap<_, _> = load_escalations(&path)?
            .into_iter()
            .map(|e| (e.request_id, e))
            .collect();
        Ok(Self {
            path,
            state: Arc::new(Mutex::new(initial)),
            event_tx,
            poll_interval,
        })
    }

    /// Returns a new receiver for fired [`EscalationEvent`]s.
    pub fn subscribe(&self) -> broadcast::Receiver<EscalationEvent> {
        self.event_tx.subscribe()
    }

    /// Schedules `request_id` to escalate `timeout_secs` seconds from now.
    ///
    /// Registering an id that is already pending replaces the earlier entry.
    ///
    /// # Errors
    /// Returns [`EscalationError`] if the updated pending set cannot be
    /// persisted. The entry stays registered in memory in that case.
    pub fn register(
        &self,
        request_id: ApprovalRequestId,
        team_id: String,
        escalation_approvers: Vec<String>,
        timeout_secs: u64,
    ) -> Result<(), EscalationError> {
        // Saturate rather than overflow: a huge timeout means "effectively never".
        let escalate_at = current_epoch_secs().saturating_add(timeout_secs);
        let entry = PersistedEscalation {
            request_id,
            team_id,
            escalation_approvers,
            escalate_at,
        };
        // Mutate and persist under one lock so concurrent writers cannot leave
        // an older snapshot on disk after a newer one.
        let mut state = self.lock_state();
        state.insert(request_id, entry);
        self.persist_locked(&state)
    }

    /// Cancels a pending escalation.
    ///
    /// Returns `Ok(true)` if `request_id` was pending and has been removed,
    /// `Ok(false)` if it was not pending (nothing is written in that case).
    ///
    /// # Errors
    /// Returns [`EscalationError`] if the updated pending set cannot be persisted.
    pub fn cancel(&self, request_id: ApprovalRequestId) -> Result<bool, EscalationError> {
        let mut state = self.lock_state();
        let removed = state.remove(&request_id).is_some();
        if removed {
            self.persist_locked(&state)?;
        }
        Ok(removed)
    }

    /// Fires every escalation whose `escalate_at` is at or before now.
    ///
    /// Due entries are removed from the pending set, the set is persisted, and
    /// one [`EscalationEvent`] per entry is broadcast. A persist failure is
    /// logged, and the events are still sent.
    pub fn tick(&self) {
        let now = current_epoch_secs();
        let overdue: Vec<PersistedEscalation> = {
            let mut state = self.lock_state();
            let due_ids: Vec<ApprovalRequestId> = state
                .iter()
                .filter(|(_, entry)| entry.escalate_at <= now)
                .map(|(id, _)| *id)
                .collect();
            let removed: Vec<PersistedEscalation> =
                due_ids.iter().filter_map(|id| state.remove(id)).collect();
            if !removed.is_empty() {
                if let Err(err) = self.persist_locked(&state) {
                    // The on-disk file may still list the fired entries, so they
                    // could fire again after a restart.
                    tracing::warn!(
                        error = %err,
                        "failed to persist escalation state after firing"
                    );
                }
            }
            removed
        };
        for entry in overdue {
            tracing::info!(
                request_id = %entry.request_id,
                team_id = %entry.team_id,
                "approval escalation fired"
            );
            // `send` only fails when there are no receivers; dropping is intended.
            let _ = self.event_tx.send(EscalationEvent {
                request_id: entry.request_id,
                team_id: entry.team_id,
                escalation_approvers: entry.escalation_approvers,
            });
        }
    }

    /// Calls [`tick`](Self::tick) every `poll_interval`, forever.
    ///
    /// NOTE(review): `tick` does blocking file I/O on the executor thread;
    /// consider `spawn_blocking` if the pending set grows large.
    pub async fn run(self: Arc<Self>) {
        // `tokio::time::interval` panics on a zero period; clamp to a minimal
        // positive period instead of crashing the background task.
        let period = self.poll_interval.max(Duration::from_millis(1));
        let mut interval = tokio::time::interval(period);
        loop {
            interval.tick().await;
            self.tick();
        }
    }

    /// Locks the pending map, recovering from a poisoned mutex.
    fn lock_state(
        &self,
    ) -> std::sync::MutexGuard<'_, HashMap<ApprovalRequestId, PersistedEscalation>> {
        // Every mutation is a single insert/remove, so a guard recovered from a
        // poisoned mutex still holds a usable map.
        self.state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Writes `state` to the escalation file.
    ///
    /// The caller must hold the state lock for the duration of the call, so
    /// file writes are serialized in the same order as the mutations.
    fn persist_locked(
        &self,
        state: &HashMap<ApprovalRequestId, PersistedEscalation>,
    ) -> Result<(), EscalationError> {
        let snapshot = PersistedEscalations {
            pending: state.values().cloned().collect(),
        };
        save_escalations(&self.path, &snapshot)
    }
}
/// Reads the pending escalations stored at `path`.
///
/// A file that does not exist yields an empty list.
///
/// # Errors
/// [`EscalationError::Io`] for any read failure other than `NotFound`;
/// [`EscalationError::Json`] if the contents are not a valid
/// `PersistedEscalations` document.
fn load_escalations(path: &Path) -> Result<Vec<PersistedEscalation>, EscalationError> {
    let contents = match std::fs::read_to_string(path) {
        Ok(text) => text,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
        Err(err) => return Err(EscalationError::Io(err)),
    };
    serde_json::from_str::<PersistedEscalations>(&contents)
        .map(|document| document.pending)
        .map_err(EscalationError::Json)
}
/// Writes `state` to `path` via the shared `persistence::write_json_atomic`
/// helper, mapping its I/O and JSON failures onto [`EscalationError`].
fn save_escalations(path: &Path, state: &PersistedEscalations) -> Result<(), EscalationError> {
    use super::persistence::write_json_atomic;

    write_json_atomic(path, state, EscalationError::Io, EscalationError::Json)
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn current_epoch_secs() -> u64 {
    let since_epoch = std::time::UNIX_EPOCH.elapsed().unwrap_or_default();
    since_epoch.as_secs()
}
#[derive(Debug, thiserror::Error)]
pub enum EscalationError {
#[error("escalation I/O error: {0}")]
Io(std::io::Error),
#[error("escalation JSON error: {0}")]
Json(serde_json::Error),
}
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// A uniquely named JSON path in the system temp dir that is removed on
    /// drop, so test files are cleaned up even when an assertion panics.
    struct TempFile(PathBuf);

    impl TempFile {
        fn new() -> Self {
            let mut p = std::env::temp_dir();
            p.push(format!("escalation_test_{}.json", Uuid::new_v4()));
            Self(p)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Builds a scheduler on a fresh temp file. Keep the returned `TempFile`
    /// bound for the whole test; dropping it deletes the backing file.
    fn make_scheduler() -> (
        TempFile,
        Arc<EscalationScheduler>,
        broadcast::Receiver<EscalationEvent>,
    ) {
        let file = TempFile::new();
        let (tx, rx) = broadcast::channel(16);
        let s = Arc::new(
            EscalationScheduler::new(file.path(), tx, Duration::from_millis(50)).unwrap(),
        );
        (file, s, rx)
    }

    #[test]
    fn register_then_cancel_returns_true() {
        let (_file, s, _rx) = make_scheduler();
        let id = Uuid::new_v4();
        s.register(id, "team-a".to_string(), vec!["mgr".to_string()], 300)
            .unwrap();
        assert!(s.cancel(id).unwrap());
        assert!(!s.cancel(id).unwrap());
    }

    #[test]
    fn cancel_nonexistent_returns_false() {
        let (_file, s, _rx) = make_scheduler();
        assert!(!s.cancel(Uuid::new_v4()).unwrap());
    }

    #[test]
    fn register_persists_to_disk() {
        let (file, s, _rx) = make_scheduler();
        let id = Uuid::new_v4();
        s.register(id, "team-b".to_string(), vec![], 600).unwrap();
        let loaded = load_escalations(file.path()).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].request_id, id);
    }

    #[tokio::test]
    async fn overdue_entry_fires_event() {
        let (file, s, mut rx) = make_scheduler();
        let id = Uuid::new_v4();
        s.register(id, "team-c".to_string(), vec!["mgr".to_string()], 0)
            .unwrap();
        s.tick();
        let event = rx.try_recv().unwrap();
        assert_eq!(event.request_id, id);
        assert_eq!(event.team_id, "team-c");
        // A fired escalation is removed from memory and from disk, so it
        // does not fire again.
        assert!(load_escalations(file.path()).unwrap().is_empty());
        s.tick();
        assert!(rx.try_recv().is_err());
        assert!(!s.cancel(id).unwrap());
    }

    #[tokio::test]
    async fn not_yet_overdue_does_not_fire() {
        let (_file, s, mut rx) = make_scheduler();
        let id = Uuid::new_v4();
        s.register(id, "team-d".to_string(), vec![], 3600).unwrap();
        s.tick();
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn reload_restores_registered_entry() {
        let file = TempFile::new();
        let (tx, _rx) = broadcast::channel(4);
        let s = EscalationScheduler::new(file.path(), tx, Duration::from_millis(50)).unwrap();
        let id = Uuid::new_v4();
        s.register(id, "team-e".to_string(), vec![], 120).unwrap();
        drop(s);
        let (tx2, _rx2) = broadcast::channel(4);
        let s2 = EscalationScheduler::new(file.path(), tx2, Duration::from_millis(50)).unwrap();
        assert!(s2.cancel(id).unwrap());
    }

    #[test]
    fn corrupt_file_is_rejected() {
        let file = TempFile::new();
        std::fs::write(file.path(), "not json").unwrap();
        let (tx, _rx) = broadcast::channel(4);
        let Err(err) = EscalationScheduler::new(file.path(), tx, Duration::from_millis(50))
        else {
            panic!("expected a JSON error for a corrupt escalation file");
        };
        assert!(matches!(err, EscalationError::Json(_)));
    }

    #[test]
    fn escalation_error_display_io() {
        let e = EscalationError::Io(std::io::Error::other("disk full"));
        assert!(e.to_string().contains("escalation I/O error"));
    }

    #[test]
    fn escalation_error_display_json() {
        let raw: Result<PersistedEscalations, _> = serde_json::from_str("not json");
        let e = EscalationError::Json(raw.unwrap_err());
        assert!(e.to_string().contains("escalation JSON error"));
    }
}