use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::task::JoinHandle;
use aa_proto::assembly::gateway::v1::invalidation_event::Payload;
use aa_proto::assembly::gateway::v1::invalidation_service_client::InvalidationServiceClient;
use aa_proto::assembly::gateway::v1::{subscribe_request::Kind, Decision, SubscribeInitial, SubscribeRequest};
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
const MAX_BACKOFF: Duration = Duration::from_secs(32);
/// Receiver for invalidation events delivered by [`InvalidationClient`].
///
/// Sinks are shared (via `Arc`) with the background task spawned by
/// [`InvalidationClient::start`], hence the `Send + Sync` bound. Callbacks are invoked
/// synchronously on that task, one sink after another, so a slow callback delays the
/// handling of every later event.
pub trait InvalidationSink: Send + Sync {
    /// Called when the policy for `agent_id` has been invalidated.
    fn on_policy_invalidated(&self, agent_id: &str);

    /// Called when approval request `request_id` has been resolved with `decision`.
    ///
    /// The default implementation ignores the event, so sinks that only care about
    /// policy invalidations do not need to implement it.
    fn on_approval_resolved(&self, request_id: &str, decision: Decision) {
        // Discard both arguments explicitly: keeps the descriptive parameter names in
        // the signature without triggering unused-variable warnings.
        let _ = request_id;
        let _ = decision;
    }
}
/// Returns the reconnect delay that follows `current`: double it, capped at `MAX_BACKOFF`.
///
/// Multiplication saturates, so an unusually large `current` (up to `Duration::MAX`)
/// clamps to `MAX_BACKOFF` instead of panicking on `Duration` overflow.
fn next_backoff(current: Duration) -> Duration {
    current.saturating_mul(2).min(MAX_BACKOFF)
}
/// Entry point for the background invalidation subscriber.
///
/// This is a stateless handle type; all work happens in the task returned by
/// [`InvalidationClient::start`].
pub struct InvalidationClient;

impl InvalidationClient {
    /// Spawns a Tokio task that subscribes to `gateway_url` on behalf of `assembly_id`
    /// and forwards every received event to each of `sinks`, reconnecting with
    /// exponential backoff whenever the stream goes away.
    ///
    /// The task never finishes on its own. It stops only if it is aborted through the
    /// returned handle or if it panics (for example because a sink callback panicked).
    /// Dropping the handle detaches the task rather than stopping it.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, as `tokio::spawn` does.
    pub fn start(
        gateway_url: String,
        assembly_id: String,
        sinks: Vec<Arc<dyn InvalidationSink>>,
    ) -> JoinHandle<()> {
        let subscriber = run(gateway_url, assembly_id, sinks);
        tokio::spawn(subscriber)
    }
}
async fn run(gateway_url: String, assembly_id: String, sinks: Vec<Arc<dyn InvalidationSink>>) {
let mut backoff = INITIAL_BACKOFF;
let mut last_seq_seen: u64 = 0;
loop {
match subscribe_once(&gateway_url, &assembly_id, &mut last_seq_seen, &sinks).await {
Ok(()) => backoff = INITIAL_BACKOFF,
Err(err) => {
metrics::counter!("aa_invalidation_reconnects_total").increment(1);
tracing::warn!(
error = %err,
backoff_secs = backoff.as_secs(),
last_seq_seen,
"invalidation stream dropped; reconnecting after backoff"
);
tokio::time::sleep(backoff).await;
backoff = next_backoff(backoff);
}
}
}
}
/// Opens one subscription to the gateway and dispatches events until the stream ends.
///
/// The single outbound request carries `assembly_id` and the current `*last_seq_seen`.
/// Each received event is fanned out to every sink in `sinks`, in order. Afterwards
/// `*last_seq_seen` is raised to the event's `seq` if that is higher. Events with no
/// payload reach no sink, but still advance the sequence and are timed.
///
/// # Errors
///
/// Returns the first connect, RPC, or stream-read error. `Ok(())` means the gateway
/// ended the response stream without an error.
async fn subscribe_once(
    gateway_url: &str,
    assembly_id: &str,
    last_seq_seen: &mut u64,
    sinks: &[Arc<dyn InvalidationSink>],
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let mut grpc = InvalidationServiceClient::connect(gateway_url.to_owned()).await?;

    let request = SubscribeRequest {
        assembly_id: assembly_id.to_owned(),
        kind: Some(Kind::Initial(SubscribeInitial {
            last_seq_seen: *last_seq_seen,
        })),
    };
    // The outbound stream yields exactly this one request and then ends.
    // NOTE(review): confirm the gateway keeps the response stream open after the
    // client half-closes.
    let mut events = grpc.subscribe(tokio_stream::once(request)).await?.into_inner();

    while let Some(event) = events.message().await? {
        // Times only the local dispatch of this event, not end-to-end delivery.
        let dispatch_started = Instant::now();

        if let Some(payload) = &event.payload {
            match payload {
                Payload::PolicyInvalidated(policy) => {
                    let agent_id: &str = &policy.agent_id;
                    sinks.iter().for_each(|sink| sink.on_policy_invalidated(agent_id));
                }
                Payload::ApprovalResolved(approval) => {
                    let decision = approval.decision();
                    let request_id: &str = &approval.request_id;
                    sinks
                        .iter()
                        .for_each(|sink| sink.on_approval_resolved(request_id, decision));
                }
            }
        }

        // Never move the resume point backwards, e.g. on a replayed or out-of-order event.
        *last_seq_seen = (*last_seq_seen).max(event.seq);

        metrics::histogram!("aa_invalidation_latency_seconds")
            .record(dispatch_started.elapsed().as_secs_f64());
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The reconnect schedule starts at one second, doubles, and then stays at 32s.
    #[test]
    fn backoff_doubles_then_caps_at_32s() {
        let mut current = INITIAL_BACKOFF;
        let mut observed_secs = Vec::with_capacity(7);
        for _ in 0..7 {
            observed_secs.push(current.as_secs());
            current = next_backoff(current);
        }
        assert_eq!(observed_secs, [1, 2, 4, 8, 16, 32, 32]);
    }
}