use std::sync::Arc;
use std::time::Duration;
use sqlx::SqlitePool;
use tokio::sync::broadcast;
use tokio_util::sync::CancellationToken;
use aa_core::AuditEventType;
use aa_runtime::approval::{ApprovalQueue, ApprovalRequestId};
use super::audit_sink::AuditEventSink;
use super::clock::Clock;
use super::escalation::EscalationEvent;
/// Approval-escalation scheduler whose pending deadlines live in the SQLite
/// table `pending_escalations`.
///
/// Deadlines are written by `register`, removed by `cancel`, and claimed and
/// fired by `tick`; `run` drives `tick` on a fixed interval until cancelled.
pub struct DbEscalationScheduler {
/// Connection pool for the database that holds `pending_escalations`.
pool: SqlitePool,
/// Time source consulted by `tick` to decide which deadlines are due.
clock: Arc<dyn Clock>,
/// Approval queue that `tick` updates via `record_routing` when a deadline fires.
queue: Arc<ApprovalQueue>,
/// Receives an `ApprovalEscalated` audit record for each escalation that fires.
audit_sink: Arc<dyn AuditEventSink>,
/// Broadcast channel on which fired escalations are published; see `subscribe`.
event_tx: broadcast::Sender<EscalationEvent>,
/// Period between polls of the table in `run`.
poll_interval: Duration,
}
impl DbEscalationScheduler {
    /// Builds a scheduler on top of `pool`, applying the embedded `./migrations`
    /// set to the database before returning.
    ///
    /// # Errors
    /// Returns [`DbEscalationError::Migration`] if the migrations cannot be applied.
    pub async fn new(
        pool: SqlitePool,
        clock: Arc<dyn Clock>,
        queue: Arc<ApprovalQueue>,
        audit_sink: Arc<dyn AuditEventSink>,
        event_tx: broadcast::Sender<EscalationEvent>,
        poll_interval: Duration,
    ) -> Result<Self, DbEscalationError> {
        sqlx::migrate!("./migrations").run(&pool).await?;
        Ok(Self {
            pool,
            clock,
            queue,
            audit_sink,
            event_tx,
            poll_interval,
        })
    }

    /// Returns a new receiver for the escalation events published by [`Self::tick`].
    pub fn subscribe(&self) -> broadcast::Receiver<EscalationEvent> {
        self.event_tx.subscribe()
    }

    /// Stores an escalation deadline for `request_id`.
    ///
    /// The row is written with `INSERT OR IGNORE`, so a conflicting insert is
    /// silently dropped rather than overwriting the stored row (the uniqueness
    /// constraint itself lives in the migrations, which are not visible here).
    ///
    /// `escalate_at` is in the same unit as `Clock::now_secs`. Values above
    /// `i64::MAX` are saturated to `i64::MAX`, because the column is bound as a
    /// signed 64-bit integer.
    ///
    /// # Errors
    /// Returns [`DbEscalationError::Db`] if the insert fails.
    pub async fn register(
        &self,
        request_id: ApprovalRequestId,
        team_id: String,
        escalation_role: String,
        from_role: String,
        escalate_at: u64,
    ) -> Result<(), DbEscalationError> {
        let id_str = request_id.to_string();
        // A plain `as i64` cast wraps large deadlines to negative values, which
        // `tick` would treat as already overdue. Saturate instead, matching the
        // clamping `tick` applies to its own clock reading.
        let escalate_at_i = i64::try_from(escalate_at).unwrap_or(i64::MAX);
        // SQL text is kept verbatim (whitespace included) so that any sqlx query
        // metadata keyed on the exact string still matches.
        sqlx::query!(
            r#"
INSERT OR IGNORE INTO pending_escalations
(approval_id, team_id, escalation_role, from_role, escalate_at)
VALUES (?, ?, ?, ?, ?)
"#,
            id_str,
            team_id,
            escalation_role,
            from_role,
            escalate_at_i,
        )
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Removes the pending escalation for `request_id`, if any.
    ///
    /// Returns `Ok(true)` when a row was deleted and `Ok(false)` when there was
    /// nothing to delete.
    ///
    /// # Errors
    /// Returns [`DbEscalationError::Db`] if the delete fails.
    pub async fn cancel(&self, request_id: ApprovalRequestId) -> Result<bool, DbEscalationError> {
        let id_str = request_id.to_string();
        let result = sqlx::query!("DELETE FROM pending_escalations WHERE approval_id = ?", id_str,)
            .execute(&self.pool)
            .await?;
        Ok(result.rows_affected() > 0)
    }

    /// Claims up to 50 overdue escalations and fires them.
    ///
    /// Claiming happens inside a `BEGIN IMMEDIATE` transaction: due rows are
    /// selected (oldest deadline first) and deleted before the commit. Only
    /// after the commit are the claimed rows fired. Because rows are deleted
    /// before firing, a claimed row is not retried if firing is interrupted.
    ///
    /// For each claimed row:
    /// * an unparsable `approval_id` is logged and skipped;
    /// * the routing change is recorded on the approval queue; if the queue
    ///   reports the approval is no longer pending, nothing further is emitted;
    /// * otherwise an `ApprovalEscalated` audit record is emitted and an
    ///   [`EscalationEvent`] is broadcast. A failed broadcast send is
    ///   deliberately ignored.
    ///
    /// # Errors
    /// Returns [`DbEscalationError::Db`] if acquiring the connection or any
    /// statement fails. On failure inside the transaction a `ROLLBACK` is
    /// attempted so the pooled connection is not returned mid-transaction.
    pub async fn tick(&self) -> Result<(), DbEscalationError> {
        // One clock reading per tick: used both for the due check and as the
        // timestamp of the routing-history entries, so an injected clock governs
        // everything this method records.
        let now_secs = self.clock.now_secs();
        // The bound parameter is a signed 64-bit integer; saturate rather than wrap.
        let now = now_secs.min(i64::MAX as u64) as i64;

        let mut conn = self.pool.acquire().await?;
        sqlx::query("BEGIN IMMEDIATE").execute(&mut *conn).await?;

        // Everything between BEGIN and COMMIT/ROLLBACK runs in this block so that
        // a failure at any step can be followed by a rollback below.
        let claimed = async {
            let rows = sqlx::query!(
                r#"
SELECT approval_id, team_id, escalation_role, from_role
FROM pending_escalations
WHERE escalate_at <= ?
ORDER BY escalate_at
LIMIT 50
"#,
                now,
            )
            .fetch_all(&mut *conn)
            .await?;

            if rows.is_empty() {
                sqlx::query("ROLLBACK").execute(&mut *conn).await?;
            } else {
                for row in &rows {
                    sqlx::query!("DELETE FROM pending_escalations WHERE approval_id = ?", row.approval_id,)
                        .execute(&mut *conn)
                        .await?;
                }
                sqlx::query("COMMIT").execute(&mut *conn).await?;
            }
            Ok::<_, sqlx::Error>(rows)
        }
        .await;

        let rows = match claimed {
            Ok(rows) => rows,
            Err(err) => {
                // Do not hand a connection with an open transaction back to the pool.
                if let Err(rollback_err) = sqlx::query("ROLLBACK").execute(&mut *conn).await {
                    tracing::warn!(
                        error = %rollback_err,
                        "escalation tick rollback failed after transaction error"
                    );
                }
                return Err(DbEscalationError::from(err));
            }
        };
        // Release the connection before doing the non-database work below.
        drop(conn);

        for row in rows {
            let approval_id = match row.approval_id.parse::<ApprovalRequestId>() {
                Ok(id) => id,
                Err(_) => {
                    tracing::warn!(
                        approval_id = %row.approval_id,
                        "pending_escalations row has invalid UUID — skipping"
                    );
                    continue;
                }
            };

            let escalation_status = format!("escalated_to_{}", row.escalation_role);
            let history_entry = aa_runtime::approval::RoutingHistoryEntry {
                at: now_secs,
                action: "escalated".to_string(),
                from_role: Some(row.from_role.clone()),
                to_role: row.escalation_role.clone(),
            };
            let still_pending = self.queue.record_routing(
                approval_id,
                escalation_status,
                Some(row.escalation_role.clone()),
                None,
                None,
                Some(history_entry),
            );
            if !still_pending {
                tracing::debug!(
                    %approval_id,
                    "escalation fired but approval already resolved — no event emitted"
                );
                continue;
            }

            tracing::info!(
                %approval_id,
                team_id = %row.team_id,
                from_role = %row.from_role,
                escalation_role = %row.escalation_role,
                "approval escalation fired"
            );
            self.audit_sink.emit(
                AuditEventType::ApprovalEscalated,
                serde_json::json!({
                    "approval_id": approval_id.to_string(),
                    "from_role": row.from_role,
                    "to_role": row.escalation_role,
                    "team_id": row.team_id,
                })
                .to_string(),
            );
            // The send result is intentionally discarded (best-effort broadcast).
            let _ = self.event_tx.send(EscalationEvent {
                request_id: approval_id,
                team_id: row.team_id,
                escalation_approvers: vec![row.escalation_role],
            });
        }
        Ok(())
    }

    /// Polls [`Self::tick`] every `poll_interval` until `token` is cancelled.
    ///
    /// On cancellation one final `tick` is run as a shutdown flush before the
    /// loop exits; since a tick claims at most 50 rows, the flush does too.
    /// Tick failures are logged and never abort the loop.
    ///
    /// # Panics
    /// `tokio::time::interval` panics when given a zero period, so
    /// `poll_interval` must be non-zero.
    pub async fn run(self: Arc<Self>, token: CancellationToken) {
        let mut interval = tokio::time::interval(self.poll_interval);
        loop {
            tokio::select! {
                _ = token.cancelled() => {
                    if let Err(e) = self.tick().await {
                        tracing::error!(error = %e, "escalation tick failed during shutdown flush");
                    }
                    break;
                }
                _ = interval.tick() => {
                    if let Err(e) = self.tick().await {
                        tracing::error!(error = %e, "escalation tick failed");
                    }
                }
            }
        }
    }
}
/// Errors returned by `DbEscalationScheduler`.
#[derive(Debug, thiserror::Error)]
pub enum DbEscalationError {
/// A query or connection operation failed; converted from `sqlx::Error` via `?`.
#[error("db escalation database error: {0}")]
Db(#[from] sqlx::Error),
/// Applying the migrations failed; converted from `sqlx::migrate::MigrateError` via `?`.
#[error("db escalation migration error: {0}")]
Migration(#[from] sqlx::migrate::MigrateError),
}
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;
    use crate::approval::audit_sink::NoopAuditSink;
    use crate::approval::clock::FakeClock;

    /// Clock frozen at `secs`.
    fn fixed_clock(secs: u64) -> Arc<dyn Clock> {
        Arc::new(FakeClock::new(secs))
    }

    /// Scheduler over a fresh in-memory SQLite database (30 s poll interval,
    /// no-op audit sink), together with a receiver on its event channel.
    async fn memory_backed(
        clock: Arc<dyn Clock>,
    ) -> (DbEscalationScheduler, broadcast::Receiver<EscalationEvent>) {
        let (events_tx, events_rx) = broadcast::channel(256);
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let scheduler = DbEscalationScheduler::new(
            pool,
            clock,
            ApprovalQueue::new(),
            Arc::new(NoopAuditSink),
            events_tx,
            Duration::from_secs(30),
        )
        .await
        .unwrap();
        (scheduler, events_rx)
    }

    /// Approval request with a random id for `agent_id` / `team_id`; every
    /// other field uses the fixed values shared by the tests below.
    fn pending_request(
        agent_id: &'static str,
        team_id: &'static str,
    ) -> aa_runtime::approval::ApprovalRequest {
        aa_runtime::approval::ApprovalRequest {
            request_id: Uuid::new_v4(),
            agent_id: agent_id.into(),
            action: "test".into(),
            condition_triggered: "test-policy".into(),
            submitted_at: 0,
            timeout_secs: 3600,
            fallback: aa_core::PolicyResult::Deny {
                reason: "timeout".into(),
            },
            team_id: Some(team_id.into()),
            timeout_override_secs: None,
            escalation_role_override: None,
        }
    }

    /// Registers a TeamAdmin -> OrgAdmin escalation for `id`, panicking on failure.
    async fn schedule(sched: &DbEscalationScheduler, id: Uuid, team: &'static str, escalate_at: u64) {
        sched
            .register(id, team.into(), "OrgAdmin".into(), "TeamAdmin".into(), escalate_at)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn register_then_cancel_returns_true() {
        let (sched, _events) = memory_backed(fixed_clock(1000)).await;
        let id = Uuid::new_v4();
        schedule(&sched, id, "team-a", 2000).await;

        let first = sched.cancel(id).await.unwrap();
        assert!(first);
        let second = sched.cancel(id).await.unwrap();
        assert!(!second);
    }

    #[tokio::test]
    async fn cancel_nonexistent_returns_false() {
        let (sched, _events) = memory_backed(fixed_clock(1000)).await;
        let removed = sched.cancel(Uuid::new_v4()).await.unwrap();
        assert!(!removed);
    }

    #[tokio::test]
    async fn register_idempotent_insert_or_ignore() {
        let (sched, _events) = memory_backed(fixed_clock(1000)).await;
        let id = Uuid::new_v4();
        // Same id twice, with different deadlines.
        for deadline in [2000_u64, 9999] {
            schedule(&sched, id, "team-a", deadline).await;
        }

        let first = sched.cancel(id).await.unwrap();
        assert!(first);
        let second = sched.cancel(id).await.unwrap();
        assert!(!second);
    }

    #[tokio::test]
    async fn tick_fires_overdue_entry_and_emits_event() {
        let (sched, mut events) = memory_backed(fixed_clock(1000)).await;
        let request = pending_request("agent-1", "team-a");
        let id = request.request_id;
        // Held until the end of the test, exactly like the scheduler's queue entry.
        let _pending = sched.queue.submit(request);
        schedule(&sched, id, "team-a", 999).await;

        sched.tick().await.unwrap();

        let fired = events.try_recv().unwrap();
        assert_eq!(fired.request_id, id);
        assert_eq!(fired.team_id, "team-a");
        assert_eq!(fired.escalation_approvers, vec!["OrgAdmin"]);
    }

    #[tokio::test]
    async fn tick_does_not_fire_future_entry() {
        let (sched, mut events) = memory_backed(fixed_clock(1000)).await;
        schedule(&sched, Uuid::new_v4(), "team-a", 9999).await;

        sched.tick().await.unwrap();

        assert!(events.try_recv().is_err());
    }

    #[tokio::test]
    async fn tick_skips_already_resolved_approval() {
        let (sched, mut events) = memory_backed(fixed_clock(1000)).await;
        let id = Uuid::new_v4();
        // No matching request is submitted to the queue.
        schedule(&sched, id, "team-b", 0).await;

        sched.tick().await.unwrap();

        assert!(events.try_recv().is_err());
        // The row was still consumed by the tick.
        let removed = sched.cancel(id).await.unwrap();
        assert!(!removed);
    }

    #[tokio::test]
    async fn run_stops_on_cancellation_and_flushes_due_rows() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let queue = ApprovalQueue::new();
        let request = pending_request("agent-2", "team-b");
        let id = request.request_id;
        let _pending = queue.submit(request);

        let (events_tx, mut events) = broadcast::channel(16);
        let built = DbEscalationScheduler::new(
            pool,
            fixed_clock(1000),
            queue,
            Arc::new(NoopAuditSink),
            events_tx,
            Duration::from_secs(3600),
        )
        .await
        .unwrap();
        let scheduler = Arc::new(built);
        schedule(&scheduler, id, "team-b", 0).await;

        let token = CancellationToken::new();
        let runner = tokio::spawn({
            let scheduler = Arc::clone(&scheduler);
            let token = token.clone();
            async move { scheduler.run(token).await }
        });
        token.cancel();
        runner.await.unwrap();

        let fired = events.try_recv().unwrap();
        assert_eq!(fired.request_id, id);
    }

    #[tokio::test]
    async fn concurrent_instances_fire_each_row_exactly_once() {
        use sqlx::sqlite::SqliteConnectOptions;
        use tempfile::NamedTempFile;

        // The temp file backs the shared database and lives until the test ends.
        let db_file = NamedTempFile::new().unwrap();
        let opts = SqliteConnectOptions::new()
            .filename(db_file.path())
            .create_if_missing(true)
            .busy_timeout(Duration::from_secs(5));

        // Seed 100 overdue rows through a throwaway pool.
        let seed_pool = SqlitePool::connect_with(opts.clone()).await.unwrap();
        sqlx::migrate!("./migrations").run(&seed_pool).await.unwrap();
        for _ in 0..100_usize {
            let id = Uuid::new_v4().to_string();
            let at = 0_i64;
            sqlx::query!(
                "INSERT OR IGNORE INTO pending_escalations (approval_id, team_id, escalation_role, from_role, escalate_at) VALUES (?, 'team-x', 'OrgAdmin', 'TeamAdmin', ?)",
                id,
                at,
            )
            .execute(&seed_pool)
            .await
            .unwrap();
        }
        seed_pool.close().await;

        // Three schedulers, each with its own pool on the same file, sharing one queue.
        let shared_queue = ApprovalQueue::new();
        let mut schedulers = Vec::new();
        for _ in 0..3_usize {
            let pool = SqlitePool::connect_with(opts.clone()).await.unwrap();
            let (events_tx, _events) = broadcast::channel::<EscalationEvent>(256);
            let built = DbEscalationScheduler::new(
                pool,
                fixed_clock(u64::MAX),
                Arc::clone(&shared_queue),
                Arc::new(NoopAuditSink),
                events_tx,
                Duration::from_secs(30),
            )
            .await
            .unwrap();
            schedulers.push(Arc::new(built));
        }

        // One tick per scheduler, all spawned before any is awaited.
        let ticks: Vec<_> = schedulers
            .iter()
            .map(|sched| {
                let sched = Arc::clone(sched);
                tokio::spawn(async move { sched.tick().await.unwrap() })
            })
            .collect();
        for tick in ticks {
            tick.await.unwrap();
        }

        let verify_pool = SqlitePool::connect_with(opts.clone()).await.unwrap();
        let remaining: i64 = sqlx::query_scalar!("SELECT COUNT(*) FROM pending_escalations")
            .fetch_one(&verify_pool)
            .await
            .unwrap();
        verify_pool.close().await;
        assert_eq!(remaining, 0, "all 100 rows must be deleted exactly once");
    }
}