use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use aa_proto::assembly::common::v1::AgentId;
use aa_proto::assembly::policy::v1::policy_service_client::PolicyServiceClient;
use aa_proto::assembly::policy::v1::{OpControlSignal, OpControlSubscribeRequest};
/// Delay before the first reconnect attempt after the op-control stream drops.
const INITIAL_BACKOFF: Duration = Duration::from_millis(1_000);
/// Upper bound for the exponentially growing reconnect delay.
const MAX_BACKOFF: Duration = Duration::from_millis(32_000);
/// Control state recorded for a single operation.
///
/// Only these two states are ever stored; an op with no recorded state has either never been
/// paused/terminated or has been resumed.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum OpState {
    /// A `Pause` signal was applied and no `Resume` has cleared it yet.
    Paused,
    /// A `Terminate` signal was applied; later `Pause`/`Resume` signals leave this unchanged.
    Terminated,
}
/// Shared record of the control state of each operation, keyed by op id.
///
/// Both fields sit behind `Arc`, so clones are cheap and share the same map and notifier:
/// a clone handed to the background subscription updates exactly the state other clones read.
#[derive(Clone, Debug, Default)]
pub struct OpControlStore {
    /// Recorded state per op id; a missing key means no pause or termination is in effect.
    states: Arc<DashMap<String, OpState>>,
    /// Notifier woken by `apply`; futures are obtained through `changed()`.
    changed: Arc<Notify>,
}
impl OpControlStore {
    /// Creates an empty store with no recorded op states.
    pub fn new() -> Self {
        Self::default()
    }

    /// Applies a control `signal` to operation `op_id` and returns the op's resulting state.
    ///
    /// - `Terminate` always records [`OpState::Terminated`].
    /// - `Pause` records [`OpState::Paused`] unless the op is already terminated.
    /// - `Resume` clears the recorded state (returning `None`) unless the op is terminated.
    /// - `Unspecified` leaves the store untouched and reports the current state.
    ///
    /// Termination is sticky. Each check-and-update below runs while DashMap holds the
    /// shard write lock for `op_id`, so a `Pause` or `Resume` racing with a `Terminate` can
    /// never overwrite or remove the `Terminated` entry.
    ///
    /// Every call wakes the futures previously obtained from [`changed`](Self::changed).
    pub fn apply(&self, op_id: &str, signal: OpControlSignal) -> Option<OpState> {
        let result = match signal {
            OpControlSignal::Terminate => {
                self.states.insert(op_id.to_owned(), OpState::Terminated);
                Some(OpState::Terminated)
            }
            OpControlSignal::Pause => {
                // `entry` keeps the shard locked across the check and the write, closing
                // the window in which a concurrent Terminate could be clobbered.
                let state = self
                    .states
                    .entry(op_id.to_owned())
                    .and_modify(|current| {
                        if *current != OpState::Terminated {
                            *current = OpState::Paused;
                        }
                    })
                    .or_insert(OpState::Paused);
                Some(*state)
            }
            OpControlSignal::Resume => {
                // `remove_if` evaluates the predicate and removes under one lock, so a
                // Terminated entry can never be dropped by a racing Resume.
                let mut terminated = false;
                self.states.remove_if(op_id, |_, current| {
                    terminated = *current == OpState::Terminated;
                    !terminated
                });
                if terminated {
                    Some(OpState::Terminated)
                } else {
                    None
                }
            }
            OpControlSignal::Unspecified => self.state(op_id),
        };
        self.changed.notify_waiters();
        result
    }

    /// Returns the recorded state of `op_id`, or `None` if nothing is recorded for it.
    pub fn state(&self, op_id: &str) -> Option<OpState> {
        self.states.get(op_id).as_deref().copied()
    }

    /// Returns a future that completes the next time [`apply`](Self::apply) runs.
    ///
    /// This is backed by `Notify::notify_waiters`, which stores no permit. Obtain this future
    /// *before* re-reading [`state`](Self::state); otherwise a signal applied in between
    /// will not wake it.
    pub fn changed(&self) -> tokio::sync::futures::Notified<'_> {
        self.changed.notified()
    }
}
/// Returns the reconnect delay that follows `current`: twice as long, capped at [`MAX_BACKOFF`].
///
/// The multiplication saturates, so an arbitrarily large input clamps to the cap instead of
/// panicking on `Duration` overflow.
fn next_backoff(current: Duration) -> Duration {
    current.saturating_mul(2).min(MAX_BACKOFF)
}
/// Launcher for the background op-control subscription task.
pub struct OpControlClient;

impl OpControlClient {
    /// Spawns the reconnecting subscription loop on the current Tokio runtime.
    ///
    /// Signals received from `gateway_url` for `agent_id` are applied to `store`. The loop
    /// never returns on its own, so the handle only completes if the task is aborted or panics.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (a `tokio::spawn` requirement).
    pub fn start(gateway_url: String, agent_id: AgentId, store: OpControlStore) -> JoinHandle<()> {
        let subscription = run(gateway_url, agent_id, store);
        tokio::spawn(subscription)
    }
}
async fn run(gateway_url: String, agent_id: AgentId, store: OpControlStore) {
let mut backoff = INITIAL_BACKOFF;
loop {
match subscribe_once(&gateway_url, &agent_id, &store).await {
Ok(()) => backoff = INITIAL_BACKOFF,
Err(err) => {
metrics::counter!("aa_op_control_reconnects_total").increment(1);
tracing::warn!(
error = %err,
backoff_secs = backoff.as_secs(),
"op-control stream dropped; reconnecting after backoff"
);
tokio::time::sleep(backoff).await;
backoff = next_backoff(backoff);
}
}
}
}
/// Opens a single op-control subscription and applies every received signal to `store`.
///
/// Returns `Ok(())` when the gateway ends the stream cleanly.
///
/// # Errors
///
/// Propagates any failure from connecting, opening the stream, or reading a message.
///
/// NOTE(review): the client is built from the bare URL with no connect timeout or HTTP/2
/// keepalive configured here, so a silently half-open connection may leave `message()`
/// pending indefinitely — confirm whether that is acceptable.
async fn subscribe_once(
    gateway_url: &str,
    agent_id: &AgentId,
    store: &OpControlStore,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let endpoint = gateway_url.to_owned();
    let mut policy = PolicyServiceClient::connect(endpoint).await?;
    let mut stream = policy
        .op_control_stream(OpControlSubscribeRequest {
            agent_id: Some(agent_id.clone()),
        })
        .await?
        .into_inner();
    loop {
        match stream.message().await? {
            Some(update) => {
                let signal = update.signal();
                tracing::debug!(op_id = %update.op_id, ?signal, "op-control signal received");
                store.apply(&update.op_id, signal);
                metrics::counter!("aa_op_control_signals_total").increment(1);
            }
            None => break,
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn backoff_doubles_then_caps_at_32s() {
        let schedule: Vec<u64> =
            std::iter::successors(Some(INITIAL_BACKOFF), |&d| Some(next_backoff(d)))
                .take(7)
                .map(|d| d.as_secs())
                .collect();
        assert_eq!(schedule, vec![1, 2, 4, 8, 16, 32, 32]);
    }

    #[test]
    fn terminate_records_terminated_state() {
        let store = OpControlStore::new();
        assert_eq!(
            store.apply("t:s", OpControlSignal::Terminate),
            Some(OpState::Terminated)
        );
        assert_eq!(store.state("t:s"), Some(OpState::Terminated));
    }

    #[test]
    fn pause_then_resume_clears_state() {
        let store = OpControlStore::new();
        assert_eq!(store.apply("t:s", OpControlSignal::Pause), Some(OpState::Paused));
        assert_eq!(store.apply("t:s", OpControlSignal::Resume), None);
        assert_eq!(store.state("t:s"), None);
    }

    #[test]
    fn terminate_is_sticky_against_later_pause_and_resume() {
        let store = OpControlStore::new();
        store.apply("t:s", OpControlSignal::Terminate);
        assert_eq!(store.apply("t:s", OpControlSignal::Pause), Some(OpState::Terminated));
        assert_eq!(store.apply("t:s", OpControlSignal::Resume), Some(OpState::Terminated));
        assert_eq!(store.state("t:s"), Some(OpState::Terminated));
    }

    #[test]
    fn unspecified_signal_is_ignored() {
        let store = OpControlStore::new();
        assert_eq!(store.apply("t:s", OpControlSignal::Unspecified), None);
        assert_eq!(store.state("t:s"), None);
    }

    #[test]
    fn unspecified_signal_reports_existing_state() {
        let store = OpControlStore::new();
        store.apply("t:s", OpControlSignal::Pause);
        assert_eq!(
            store.apply("t:s", OpControlSignal::Unspecified),
            Some(OpState::Paused)
        );
        assert_eq!(store.state("t:s"), Some(OpState::Paused));
    }

    #[test]
    fn resume_of_unknown_op_is_a_noop() {
        let store = OpControlStore::new();
        assert_eq!(store.apply("t:s", OpControlSignal::Resume), None);
        assert_eq!(store.state("t:s"), None);
    }

    #[test]
    fn repeated_pause_is_idempotent() {
        let store = OpControlStore::new();
        assert_eq!(store.apply("t:s", OpControlSignal::Pause), Some(OpState::Paused));
        assert_eq!(store.apply("t:s", OpControlSignal::Pause), Some(OpState::Paused));
        assert_eq!(store.state("t:s"), Some(OpState::Paused));
    }

    #[test]
    fn clones_share_state() {
        // The background task receives a clone of the store, so updates must be visible
        // through every clone.
        let store = OpControlStore::new();
        let handle = store.clone();
        handle.apply("t:s", OpControlSignal::Terminate);
        assert_eq!(store.state("t:s"), Some(OpState::Terminated));
    }

    #[test]
    fn distinct_ops_are_independent() {
        let store = OpControlStore::new();
        store.apply("a:1", OpControlSignal::Terminate);
        store.apply("b:2", OpControlSignal::Pause);
        assert_eq!(store.state("a:1"), Some(OpState::Terminated));
        assert_eq!(store.state("b:2"), Some(OpState::Paused));
        assert_eq!(store.state("c:3"), None);
    }
}