Skip to main content

chio_external_guards/external/threat_intel/
virustotal.rs

//! VirusTotal v3 adapter (phase 13.3).
//!
//! Adapted from
//! `../clawdstrike/crates/libs/clawdstrike/src/async_guards/threat_intel/virustotal.rs`
//! with the following deviations:
//!
//! * The Chio `ExternalGuard` surface is synchronous-in-decision: this adapter
//!   returns a [`Verdict`] directly rather than ClawdStrike's `warn`/`block`
//!   tri-state. Below-threshold detections therefore surface as `Allow`; when
//!   at least one engine reported the target as malicious, the adapter still
//!   emits a `tracing::warn!` for the suspicious-but-not-blocked signal.
//! * Arguments are passed via a small JSON envelope in
//!   [`GuardCallContext::arguments_json`] (`{"hash": ...}` or
//!   `{"url": ...}`) instead of ClawdStrike's `GuardAction` enum.
15
16use std::time::Duration;
17
18use async_trait::async_trait;
19use base64::engine::general_purpose::URL_SAFE_NO_PAD;
20use base64::Engine as _;
21use chio_core_types::GuardEvidence;
22use chio_kernel::Verdict;
23use reqwest::header::{HeaderMap, HeaderValue};
24use reqwest::Client;
25use serde::{Deserialize, Serialize};
26use sha2::{Digest, Sha256};
27use zeroize::Zeroizing;
28
29use crate::external::bedrock::{classify_reqwest_error, classify_status_error};
30use crate::external::{ExternalGuard, ExternalGuardError, GuardCallContext};
31
/// Identifier this guard reports through [`VirusTotalGuard::name`]; it is
/// also the `guard_name` stamped on evidence records.
pub const GUARD_NAME: &str = "virustotal";

/// VirusTotal v3 API root used when [`VirusTotalConfig::base_url`] is `None`.
pub const DEFAULT_BASE_URL: &str = "https://www.virustotal.com/api/v3";

/// Default `malicious + suspicious` count at or above which a target is denied.
pub const DEFAULT_MIN_DETECTIONS: u64 = 5;

/// Default per-request HTTP timeout (thirty seconds).
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
43
44/// Configuration for [`VirusTotalGuard`].
45#[derive(Clone)]
46pub struct VirusTotalConfig {
47    /// `x-apikey` header.
48    pub api_key: Zeroizing<String>,
49    /// Override the base URL (test hook).
50    pub base_url: Option<String>,
51    /// Detection threshold. Calls are denied when
52    /// `malicious + suspicious >= min_detections`.
53    pub min_detections: u64,
54    /// Per-request HTTP timeout.
55    pub timeout: Duration,
56}
57
58impl std::fmt::Debug for VirusTotalConfig {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("VirusTotalConfig")
61            .field("api_key", &"***redacted***")
62            .field("base_url", &self.base_url)
63            .field("min_detections", &self.min_detections)
64            .field("timeout", &self.timeout)
65            .finish()
66    }
67}
68
69impl VirusTotalConfig {
70    /// Construct a config with defaults.
71    pub fn new(api_key: impl Into<String>) -> Self {
72        Self {
73            api_key: Zeroizing::new(api_key.into()),
74            base_url: None,
75            min_detections: DEFAULT_MIN_DETECTIONS,
76            timeout: DEFAULT_TIMEOUT,
77        }
78    }
79
80    /// Override the base URL (used in tests).
81    pub fn with_base_url(mut self, base: impl Into<String>) -> Self {
82        self.base_url = Some(base.into());
83        self
84    }
85
86    /// Override the detection threshold.
87    pub fn with_min_detections(mut self, threshold: u64) -> Self {
88        self.min_detections = threshold.max(1);
89        self
90    }
91
92    fn resolved_base_url(&self) -> String {
93        self.base_url
94            .clone()
95            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
96            .trim_end_matches('/')
97            .to_string()
98    }
99}
100
101/// Shape of the JSON we accept in
102/// [`GuardCallContext::arguments_json`].
103#[derive(Debug, Clone, Deserialize)]
104struct VirusTotalArgs {
105    #[serde(default)]
106    hash: Option<String>,
107    #[serde(default)]
108    url: Option<String>,
109}
110
111#[derive(Debug, Clone, Deserialize)]
112struct VirusTotalResponse {
113    #[serde(default)]
114    data: Option<VirusTotalData>,
115}
116
117#[derive(Debug, Clone, Deserialize)]
118struct VirusTotalData {
119    #[serde(default)]
120    attributes: Option<VirusTotalAttributes>,
121}
122
123#[derive(Debug, Clone, Deserialize)]
124struct VirusTotalAttributes {
125    #[serde(default, rename = "last_analysis_stats")]
126    last_analysis_stats: Option<VirusTotalStats>,
127}
128
129#[derive(Debug, Clone, Deserialize)]
130struct VirusTotalStats {
131    #[serde(default)]
132    malicious: u64,
133    #[serde(default)]
134    suspicious: u64,
135}
136
137/// Structured receipt evidence.
138#[derive(Debug, Clone, Serialize)]
139pub struct VirusTotalEvidence {
140    pub target: String,
141    pub malicious: u64,
142    pub suspicious: u64,
143    pub min_detections: u64,
144}
145
146/// Guard that queries VirusTotal for file-hash or URL reputation.
147pub struct VirusTotalGuard {
148    cfg: VirusTotalConfig,
149    base_url: String,
150    http: Client,
151}
152
153impl VirusTotalGuard {
154    /// Build a guard with an internally-owned [`reqwest::Client`].
155    pub fn new(cfg: VirusTotalConfig) -> Result<Self, ExternalGuardError> {
156        let http = Client::builder()
157            .timeout(cfg.timeout)
158            .build()
159            .map_err(|e| ExternalGuardError::Permanent(format!("reqwest build: {e}")))?;
160        let base_url = cfg.resolved_base_url();
161        Ok(Self {
162            cfg,
163            base_url,
164            http,
165        })
166    }
167
168    /// Build with a caller-supplied client.
169    pub fn with_client(cfg: VirusTotalConfig, http: Client) -> Self {
170        let base_url = cfg.resolved_base_url();
171        Self {
172            cfg,
173            base_url,
174            http,
175        }
176    }
177
178    /// Build a [`GuardEvidence`] record for a prior decision.
179    pub fn evidence_from_decision(
180        &self,
181        verdict: Verdict,
182        details: Option<&VirusTotalEvidence>,
183    ) -> GuardEvidence {
184        GuardEvidence {
185            guard_name: self.name().to_string(),
186            verdict: matches!(verdict, Verdict::Allow),
187            details: details.and_then(|d| serde_json::to_string(d).ok()),
188        }
189    }
190}
191
/// Normalizes a SHA-256 digest string to 64 lowercase hex characters.
///
/// Surrounding whitespace is ignored, and a single `sha256:` or `0x` prefix is
/// stripped. Prefixes match ASCII case-insensitively, so `SHA256:` and `0X`
/// are accepted too. Returns `None` unless exactly 64 hex digits remain.
fn normalize_sha256_hex(input: &str) -> Option<String> {
    /// Like `str::strip_prefix`, but compares the prefix ASCII
    /// case-insensitively.
    fn strip_prefix_ignore_case<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
        // `get` yields `None` (instead of panicking) when the input is shorter
        // than the prefix or the cut would split a multi-byte character.
        let head = s.get(..prefix.len())?;
        if head.eq_ignore_ascii_case(prefix) {
            s.get(prefix.len()..)
        } else {
            None
        }
    }

    let trimmed = input.trim();
    let hex = strip_prefix_ignore_case(trimmed, "sha256:")
        .or_else(|| strip_prefix_ignore_case(trimmed, "0x"))
        .unwrap_or(trimmed)
        .trim();
    if hex.len() != 64 || !hex.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    Some(hex.to_ascii_lowercase())
}
207
208#[async_trait]
209impl ExternalGuard for VirusTotalGuard {
210    fn name(&self) -> &str {
211        GUARD_NAME
212    }
213
214    fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
215        let args: VirusTotalArgs = serde_json::from_str(&ctx.arguments_json).ok()?;
216        if let Some(h) = args.hash.as_deref().and_then(normalize_sha256_hex) {
217            return Some(format!("vt:file:{h}"));
218        }
219        if let Some(u) = args.url.as_deref() {
220            let mut hasher = Sha256::new();
221            hasher.update(u.as_bytes());
222            let digest = hasher.finalize();
223            let mut hex = String::with_capacity(digest.len() * 2);
224            for b in digest {
225                hex.push_str(&format!("{b:02x}"));
226            }
227            return Some(format!("vt:url:{hex}"));
228        }
229        None
230    }
231
232    async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
233        let args: VirusTotalArgs = serde_json::from_str(&ctx.arguments_json).map_err(|e| {
234            ExternalGuardError::Permanent(format!("invalid virustotal arguments: {e}"))
235        })?;
236
237        let endpoint = if let Some(raw_hash) = args.hash.as_deref() {
238            let Some(hash) = normalize_sha256_hex(raw_hash) else {
239                return Err(ExternalGuardError::Permanent(
240                    "virustotal: hash is not a sha256 hex string".to_string(),
241                ));
242            };
243            format!("{}/files/{hash}", self.base_url)
244        } else if let Some(target_url) = args.url.as_deref() {
245            let id = URL_SAFE_NO_PAD.encode(target_url.as_bytes());
246            format!("{}/urls/{id}", self.base_url)
247        } else {
248            return Err(ExternalGuardError::Permanent(
249                "virustotal: arguments must include `hash` or `url`".to_string(),
250            ));
251        };
252        super::super::endpoint_security::validate_external_guard_url(
253            "virustotal base_url",
254            &endpoint,
255        )?;
256
257        let mut headers = HeaderMap::new();
258        headers.insert(
259            "x-apikey",
260            HeaderValue::from_str(self.cfg.api_key.as_str())
261                .map_err(|e| ExternalGuardError::Permanent(format!("invalid api key: {e}")))?,
262        );
263
264        let resp = self
265            .http
266            .get(&endpoint)
267            .headers(headers)
268            .send()
269            .await
270            .map_err(classify_reqwest_error)?;
271
272        let status = resp.status();
273        let text = resp
274            .text()
275            .await
276            .map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;
277
278        // 404 -> not found in VT database. We allow-by-default so that a
279        // previously-unseen hash/URL doesn't block benign traffic. Upstream
280        // callers can layer additional controls.
281        if status.as_u16() == 404 {
282            tracing::info!(guard = GUARD_NAME, "virustotal: target not found");
283            return Ok(Verdict::Allow);
284        }
285
286        if !status.is_success() {
287            return Err(classify_status_error("virustotal", status, &text));
288        }
289
290        let parsed: VirusTotalResponse = serde_json::from_str(&text)
291            .map_err(|e| ExternalGuardError::Transient(format!("parse vt response: {e}")))?;
292
293        let (malicious, suspicious) = parsed
294            .data
295            .and_then(|d| d.attributes)
296            .and_then(|a| a.last_analysis_stats)
297            .map(|s| (s.malicious, s.suspicious))
298            .unwrap_or((0, 0));
299
300        let detections = malicious.saturating_add(suspicious);
301        tracing::info!(
302            guard = GUARD_NAME,
303            malicious,
304            suspicious,
305            min_detections = self.cfg.min_detections,
306            "virustotal response"
307        );
308
309        if detections >= self.cfg.min_detections {
310            return Ok(Verdict::Deny);
311        }
312        if malicious > 0 {
313            tracing::warn!(
314                guard = GUARD_NAME,
315                malicious,
316                suspicious,
317                "virustotal: malicious detections below threshold"
318            );
319        }
320
321        Ok(Verdict::Allow)
322    }
323}