Skip to main content

aperion_shield/orgmode/
policy_pull.rs

1//! Policy fan-out -- pull side.
2//!
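//! Polls `/api/enterprise/shield/shieldset/<group>/version` every 30 s.
//! When the server reports a newer version, downloads the YAML and
//! publishes a fresh [`crate::Engine`] on a `tokio::sync::watch` channel
//! that the MCP middleman snapshots on every tool call. Each probe also
//! reports the server-side killswitch flag, which is published on its own
//! `watch` channel so block-all can take effect without re-pulling policy.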
7//!
//! The polling cadence is intentionally generous -- this is the M2 first
//! cut; we can swap in SSE / WebSocket fan-out in M2.5 without changing
//! the consumer API (still a `watch::Receiver<Arc<Engine>>`).
11
12use std::sync::Arc;
13use std::time::Duration;
14
15use tokio::sync::watch;
16
17use super::client::OrgApi;
18use super::state::OrgState;
19use crate::Engine;
20
/// Interval between successive policy-version probes against the server.
const POLL_INTERVAL: Duration = Duration::from_millis(30_000);
23
24/// Handle returned to `main()`. Holds the receiver side of the watch
25/// channel + the running task. Dropping it cancels the task.
26pub struct PolicyPullHandle {
27    pub current: watch::Receiver<Arc<Engine>>,
28    pub killswitch: watch::Receiver<bool>,
29    /// Latest version we've seen from the server, exposed for the
30    /// status / metrics path.
31    pub version: Arc<tokio::sync::Mutex<u64>>,
32    pub _task: tokio::task::JoinHandle<()>,
33}
34
35/// Spawn the policy-pull loop. `initial_engine` is the engine the
36/// process started with -- usually the result of
37/// `orgmode::load_initial_engine` -- and is published as the first
38/// value on the watch channel.
39pub fn start_policy_pull(
40    api: Arc<OrgApi>,
41    state: OrgState,
42    initial_engine: Arc<Engine>,
43    initial_version: u64,
44) -> PolicyPullHandle {
45    let (tx, rx) = watch::channel(initial_engine);
46    let (ks_tx, ks_rx) = watch::channel(false);
47    let version = Arc::new(tokio::sync::Mutex::new(initial_version));
48    let version_for_task = version.clone();
49
50    let task = tokio::spawn(async move {
51        let mut ticker = tokio::time::interval(POLL_INTERVAL);
52        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
53        // First tick fires immediately -- skip; we already published
54        // the initial engine.
55        ticker.tick().await;
56        loop {
57            ticker.tick().await;
58            let probe = match api.get_shieldset_version(&state.policy_group).await {
59                Ok(v) => v,
60                Err(e) => {
61                    log::warn!("[shield] policy version probe failed: {}", e);
62                    continue;
63                }
64            };
65
66            // Killswitch publishes independently so we can react fast
67            // without re-pulling policy.
68            let _ = ks_tx.send(probe.killswitch.on);
69            if probe.killswitch.on {
70                log::warn!(
71                    "[shield] killswitch ON (reason={:?}) -- block-all in effect",
72                    probe.killswitch.reason
73                );
74            }
75
76            let cur = version_for_task.lock().await;
77            if probe.version <= *cur {
78                continue;
79            }
80            drop(cur); // free the lock for the duration of the fetch
81
82            let pulled = match api.get_shieldset(&state.policy_group).await {
83                Ok(p) => p,
84                Err(e) => {
85                    log::warn!("[shield] shieldset fetch failed: {}", e);
86                    continue;
87                }
88            };
89            let (yaml, new_version) = pulled;
90            let new_engine = match Engine::from_yaml(&yaml) {
91                Ok(e) => e,
92                Err(e) => {
93                    log::error!(
94                        "[shield] pulled shieldset is invalid (version={}): {}. Keeping previous policy.",
95                        new_version, e
96                    );
97                    continue;
98                }
99            };
100            let mut cur = version_for_task.lock().await;
101            *cur = new_version;
102            drop(cur);
103
104            log::warn!(
105                "[shield] hot-reloaded policy: group={} version={} rules={}",
106                state.policy_group,
107                new_version,
108                new_engine.rules.len()
109            );
110            // `send` returns Err only if every receiver has been
111            // dropped; in that case the process is shutting down.
112            let _ = tx.send(Arc::new(new_engine));
113        }
114    });
115
116    PolicyPullHandle {
117        current: rx,
118        killswitch: ks_rx,
119        version,
120        _task: task,
121    }
122}