use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use super::client::OrgApi;
use super::state::OrgState;
use crate::Engine;
/// How often the background policy-pull task probes the org API for a newer
/// shieldset version and the current killswitch state (30 seconds).
const POLL_INTERVAL: Duration = Duration::from_millis(30_000);
/// Handle returned by `start_policy_pull`, giving access to the state published
/// by the background policy-pull task.
pub struct PolicyPullHandle {
/// Latest policy engine. Starts as the `initial_engine` passed to
/// `start_policy_pull` and is replaced whenever the task hot-reloads a newer
/// shieldset compiled with `Engine::from_yaml`.
pub current: watch::Receiver<Arc<Engine>>,
/// Killswitch flag as last reported by the org API's version probe
/// (`false` until the first successful probe).
pub killswitch: watch::Receiver<bool>,
/// Version number of the most recently accepted policy; starts at
/// `initial_version` and is updated by the task on each hot-reload.
pub version: Arc<tokio::sync::Mutex<u64>>,
/// The spawned polling task. Note: dropping a tokio `JoinHandle` detaches the
/// task rather than cancelling it; call `abort()` on it to stop polling.
pub _task: tokio::task::JoinHandle<()>,
}
pub fn start_policy_pull(
api: Arc<OrgApi>,
state: OrgState,
initial_engine: Arc<Engine>,
initial_version: u64,
) -> PolicyPullHandle {
let (tx, rx) = watch::channel(initial_engine);
let (ks_tx, ks_rx) = watch::channel(false);
let version = Arc::new(tokio::sync::Mutex::new(initial_version));
let version_for_task = version.clone();
let task = tokio::spawn(async move {
let mut ticker = tokio::time::interval(POLL_INTERVAL);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
ticker.tick().await;
loop {
ticker.tick().await;
let probe = match api.get_shieldset_version(&state.policy_group).await {
Ok(v) => v,
Err(e) => {
log::warn!("[shield] policy version probe failed: {}", e);
continue;
}
};
let _ = ks_tx.send(probe.killswitch.on);
if probe.killswitch.on {
log::warn!(
"[shield] killswitch ON (reason={:?}) -- block-all in effect",
probe.killswitch.reason
);
}
let cur = version_for_task.lock().await;
if probe.version <= *cur {
continue;
}
drop(cur);
let pulled = match api.get_shieldset(&state.policy_group).await {
Ok(p) => p,
Err(e) => {
log::warn!("[shield] shieldset fetch failed: {}", e);
continue;
}
};
let (yaml, new_version) = pulled;
let new_engine = match Engine::from_yaml(&yaml) {
Ok(e) => e,
Err(e) => {
log::error!(
"[shield] pulled shieldset is invalid (version={}): {}. Keeping previous policy.",
new_version, e
);
continue;
}
};
let mut cur = version_for_task.lock().await;
*cur = new_version;
drop(cur);
log::warn!(
"[shield] hot-reloaded policy: group={} version={} rules={}",
state.policy_group,
new_version,
new_engine.rules.len()
);
let _ = tx.send(Arc::new(new_engine));
}
});
PolicyPullHandle {
current: rx,
killswitch: ks_rx,
version,
_task: task,
}
}