aa-runtime 0.0.1-rc.3

Tokio async runtime wrapper and lifecycle management for Agent Assembly
//! Runtime-side consumer of the gateway's op-control kill switch
//! (`PolicyService.OpControlStream`, AAASM-3491).
//!
//! # Why this exists
//!
//! The gateway can already *publish* pause/resume/terminate signals for an
//! in-flight op (`aa-gateway/src/ops`), but before this module **no client on
//! the agent's execution path subscribed to them** — an operator terminate
//! flipped the gateway registry to `Terminated` and broadcast into a channel
//! with no listener, so the running agent kept executing (a silent no-op /
//! allow-through of the documented kill switch; QA `qa3464-ops-registry-control`).
//!
//! [`OpControlClient`] is the missing consumer. It opens the
//! `OpControlStream` for this agent's composite id, and records each signal in
//! a shared [`OpControlStore`] keyed by `op_id` (`"{trace_id}:{span_id}"`, the
//! same form the gateway and dashboard use). The runtime's per-tool policy
//! check (`pipeline::handle_policy_query`) consults the store before allowing
//! an action, so a terminate **fast-fails** the in-flight action and a pause
//! **blocks** it until resume.
//!
//! # Fail-closed
//!
//! An op the operator has terminated is denied; a paused op is held. The store
//! is the authoritative runtime-side record — once a terminate is observed it
//! is sticky (`Terminated` is never cleared by a later pause/resume), so the
//! kill switch cannot be undone by a racing signal.

use std::sync::Arc;
use std::time::Duration;

use dashmap::DashMap;
use tokio::sync::Notify;
use tokio::task::JoinHandle;

use aa_proto::assembly::common::v1::AgentId;
use aa_proto::assembly::policy::v1::policy_service_client::PolicyServiceClient;
use aa_proto::assembly::policy::v1::{OpControlSignal, OpControlSubscribeRequest};

/// Delay before the first reconnect attempt; every consecutive failure doubles it.
const INITIAL_BACKOFF: Duration = Duration::from_millis(1_000);
/// Ceiling for the reconnect delay: the schedule runs 1s, 2s, 4s, … and stops
/// growing once it reaches 32s.
const MAX_BACKOFF: Duration = Duration::from_millis(32_000);

/// Reserved `op_id` that targets **all** ops of the subscribed agent at once —
/// a global kill switch.
///
/// A signal published under this id stops the agent no matter which op (if
/// any) an incoming request claims. An absent or tampered agent-supplied
/// `trace_id` therefore cannot dodge it (AAASM-3873). The id contains no colon,
/// so it can never equal a per-op id, which always has the form
/// `"{trace_id}:{span_id}"`.
pub const GLOBAL_HALT_OP_ID: &str = "*";

/// Build the reserved `op_id` that targets the whole agent `agent_id`.
///
/// Op-control matching keys on this **server-side** identity — the agent id the
/// runtime already knows — instead of the attacker-controlled `trace_id` /
/// `span_id` carried on a `CheckActionRequest`. A terminate/pause published
/// under this id therefore covers every request from the agent, including
/// requests that omit or forge a `trace_id` (AAASM-3873).
///
/// The result has the shape `"agent:{agent_id}"`. It does contain a colon.
/// It stays disjoint from per-op ids (`"{trace_id}:{span_id}"`) only as long
/// as no real trace id is the literal string `agent`. That is presumably
/// guaranteed by the trace-id format — TODO confirm.
pub fn agent_halt_op_id(agent_id: &str) -> String {
    const PREFIX: &str = "agent:";
    let mut op_id = String::with_capacity(PREFIX.len() + agent_id.len());
    op_id.push_str(PREFIX);
    op_id.push_str(agent_id);
    op_id
}

/// Runtime-side lifecycle state of one op, as last reported by the gateway.
///
/// Reflects the most recent [`OpControlSignal`] received for the op's `op_id`.
/// An op with no entry in the [`OpControlStore`] has no control signal in
/// effect and runs normally.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpState {
    /// The operator paused the op. The next per-tool check blocks until a
    /// `Resume` (or a `Terminate`) arrives.
    Paused,
    /// The op was terminated, by the operator or by a policy deny. Every later
    /// per-tool check fast-fails. This state is final and is never downgraded.
    Terminated,
}

/// Concurrent map of op-control state, keyed by `op_id`.
///
/// The [`OpControlClient`] background subscriber writes to it. The runtime
/// pipeline reads it on every per-tool policy check. Both fields sit behind an
/// `Arc`, so cloning is cheap and yields another handle onto the same store.
#[derive(Default, Clone)]
pub struct OpControlStore {
    /// Maps `op_id` to its latest non-`Resume` state. A `Resume` deletes the
    /// entry, so a resumed op reads as runnable again.
    states: Arc<DashMap<String, OpState>>,
    /// Notified after each applied signal. A check parked on a paused op is
    /// woken to re-read the state instead of polling.
    changed: Arc<Notify>,
}

impl OpControlStore {
    /// Construct an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Apply one `OpControlSignal` for `op_id` and return the resulting state.
    /// The result is `None` once the op is runnable again, i.e. after a `Resume`.
    ///
    /// `Terminate` is sticky. Once recorded, it overrides any later `Pause` or
    /// `Resume`, so a racing signal cannot lift the kill switch.
    /// `Unspecified` is a malformed message and is ignored.
    ///
    /// Each signal's read-check-write runs under a single `DashMap` shard lock
    /// (`entry` / `remove_if`). A `Terminate` applied concurrently through
    /// another clone of the store can therefore never be overwritten by a
    /// `Pause`, nor removed by a `Resume`, that saw the pre-terminate state.
    ///
    /// Every call, including an ignored `Unspecified`, wakes tasks waiting on
    /// [`changed`](Self::changed). Those tasks must re-read
    /// [`state`](Self::state).
    pub fn apply(&self, op_id: &str, signal: OpControlSignal) -> Option<OpState> {
        let result = match signal {
            OpControlSignal::Terminate => {
                self.states.insert(op_id.to_owned(), OpState::Terminated);
                Some(OpState::Terminated)
            }
            OpControlSignal::Pause => {
                // `or_insert` leaves an existing entry untouched. An already
                // terminated op stays `Terminated`, a paused one stays
                // `Paused`, and only an absent op becomes `Paused`. The check
                // and the insert happen atomically under the shard lock.
                let state = *self.states.entry(op_id.to_owned()).or_insert(OpState::Paused);
                Some(state)
            }
            OpControlSignal::Resume => {
                // Remove the entry unless it is `Terminated`. The predicate and
                // the removal run under the same shard lock, so a concurrent
                // terminate can never be dropped here.
                let cleared = self
                    .states
                    .remove_if(op_id, |_, state| *state != OpState::Terminated)
                    .is_some();
                if cleared {
                    None
                } else {
                    // Either nothing was recorded or the op is terminated (sticky).
                    self.state(op_id)
                }
            }
            OpControlSignal::Unspecified => self.state(op_id),
        };
        self.changed.notify_waiters();
        result
    }

    /// Current state of `op_id`, or `None` if no control signal is in effect.
    pub fn state(&self, op_id: &str) -> Option<OpState> {
        self.states.get(op_id).as_deref().copied()
    }

    /// A future that resolves on the next signal application.
    ///
    /// It is returned rather than awaited. A caller parked on a paused op can
    /// then register interest **before** re-reading [`state`](Self::state).
    /// This closes the race where a resume/terminate lands between the state
    /// read and the await. `notify_waiters` only wakes already-registered
    /// waiters, so registering first is required for correctness.
    pub fn changed(&self) -> tokio::sync::futures::Notified<'_> {
        self.changed.notified()
    }
}

/// Compute the delay for the next reconnect attempt.
///
/// Doubles `current` and clamps the result to [`MAX_BACKOFF`]. Repeated calls
/// therefore walk the 1s → 2s → 4s → … schedule and then hold at the cap.
fn next_backoff(current: Duration) -> Duration {
    let doubled = current * 2;
    std::cmp::min(doubled, MAX_BACKOFF)
}

/// Background subscriber that keeps an [`OpControlStore`] in sync with the
/// gateway's `PolicyService.OpControlStream`.
///
/// Modelled on [`crate::invalidation_client::InvalidationClient`]. It
/// subscribes with this agent's composite id, feeds every pushed signal into
/// the store, and reconnects indefinitely with exponential backoff. The gateway
/// filters the broadcast, so only this agent's ops are delivered.
pub struct OpControlClient;

impl OpControlClient {
    /// Launch the subscriber as a Tokio task and hand back its [`JoinHandle`].
    ///
    /// * `gateway_url` — the same endpoint the policy-check path forwards to.
    /// * `agent_id` — must equal the `agent_id` on this agent's
    ///   `CheckActionRequest`s so the gateway routes the right signals here.
    /// * `store` — the shared store the pipeline consults; signals are applied to it.
    ///
    /// The subscribe loop never exits on its own; abort the handle to stop it.
    pub fn start(gateway_url: String, agent_id: AgentId, store: OpControlStore) -> JoinHandle<()> {
        let subscriber = run(gateway_url, agent_id, store);
        tokio::spawn(subscriber)
    }
}

/// Reconnect loop: subscribe, apply signals, and on disconnect back off
/// exponentially before resubscribing.
async fn run(gateway_url: String, agent_id: AgentId, store: OpControlStore) {
    let mut backoff = INITIAL_BACKOFF;
    loop {
        match subscribe_once(&gateway_url, &agent_id, &store).await {
            // The gateway closed the stream cleanly — reconnect promptly.
            Ok(()) => backoff = INITIAL_BACKOFF,
            Err(err) => {
                metrics::counter!("aa_op_control_reconnects_total").increment(1);
                tracing::warn!(
                    error = %err,
                    backoff_secs = backoff.as_secs(),
                    "op-control stream dropped; reconnecting after backoff"
                );
                tokio::time::sleep(backoff).await;
                backoff = next_backoff(backoff);
            }
        }
    }
}

/// Open one `OpControlStream` and apply messages to the store until it ends or
/// errors.
async fn subscribe_once(
    gateway_url: &str,
    agent_id: &AgentId,
    store: &OpControlStore,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let mut client = PolicyServiceClient::connect(gateway_url.to_owned()).await?;
    let request = OpControlSubscribeRequest {
        agent_id: Some(agent_id.clone()),
    };
    let response = client.op_control_stream(request).await?;
    let mut inbound = response.into_inner();

    while let Some(message) = inbound.message().await? {
        let signal = message.signal();
        tracing::debug!(op_id = %message.op_id, ?signal, "op-control signal received");
        store.apply(&message.op_id, signal);
        metrics::counter!("aa_op_control_signals_total").increment(1);
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Per-op id shared by the single-op tests below.
    const OP: &str = "t:s";

    #[test]
    fn backoff_doubles_then_caps_at_32s() {
        let mut delay = INITIAL_BACKOFF;
        let mut observed: Vec<u64> = Vec::with_capacity(7);
        for _ in 0..7 {
            observed.push(delay.as_secs());
            delay = next_backoff(delay);
        }
        assert_eq!(observed, [1, 2, 4, 8, 16, 32, 32]);
    }

    #[test]
    fn terminate_records_terminated_state() {
        let store = OpControlStore::new();
        let outcome = store.apply(OP, OpControlSignal::Terminate);
        assert_eq!(outcome, Some(OpState::Terminated));
        assert_eq!(store.state(OP), Some(OpState::Terminated));
    }

    #[test]
    fn pause_then_resume_clears_state() {
        let store = OpControlStore::new();
        assert_eq!(store.apply(OP, OpControlSignal::Pause), Some(OpState::Paused));
        assert!(store.apply(OP, OpControlSignal::Resume).is_none());
        assert!(store.state(OP).is_none());
    }

    #[test]
    fn terminate_is_sticky_against_later_pause_and_resume() {
        let store = OpControlStore::new();
        store.apply(OP, OpControlSignal::Terminate);
        // Neither a racing pause nor a racing resume may lift the kill switch.
        for late in [OpControlSignal::Pause, OpControlSignal::Resume] {
            assert_eq!(store.apply(OP, late), Some(OpState::Terminated));
        }
        assert_eq!(store.state(OP), Some(OpState::Terminated));
    }

    #[test]
    fn unspecified_signal_is_ignored() {
        let store = OpControlStore::new();
        assert!(store.apply(OP, OpControlSignal::Unspecified).is_none());
        assert!(store.state(OP).is_none());
    }

    #[test]
    fn server_side_halt_keys_are_disjoint_from_per_op_ids() {
        // Per-op ids always carry a colon ("{trace_id}:{span_id}"). The global
        // halt key must stay outside that namespace. The agent halt key is
        // namespaced under the reserved `agent:` prefix.
        assert!(!GLOBAL_HALT_OP_ID.contains(':'));
        let agent_key = agent_halt_op_id("svc-agent");
        assert_eq!(agent_key, "agent:svc-agent");

        // An agent-level halt is addressable independently of any op halt.
        let store = OpControlStore::new();
        store.apply(&agent_key, OpControlSignal::Terminate);
        assert_eq!(store.state(&agent_key), Some(OpState::Terminated));
        assert_eq!(store.state("trace:span"), None);
    }

    #[test]
    fn distinct_ops_are_independent() {
        let store = OpControlStore::new();
        store.apply("a:1", OpControlSignal::Terminate);
        store.apply("b:2", OpControlSignal::Pause);
        let expected = [
            ("a:1", Some(OpState::Terminated)),
            ("b:2", Some(OpState::Paused)),
            ("c:3", None),
        ];
        for (op_id, want) in expected {
            assert_eq!(store.state(op_id), want, "state of {op_id}");
        }
    }
}