use std::time::Duration;
use async_trait::async_trait;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine as _;
use chio_core_types::GuardEvidence;
use chio_kernel::Verdict;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use zeroize::Zeroizing;
use crate::external::bedrock::{classify_reqwest_error, classify_status_error};
use crate::external::{ExternalGuard, ExternalGuardError, GuardCallContext};
/// Name returned by this guard's `ExternalGuard::name` and attached to its log events.
pub const GUARD_NAME: &str = "virustotal";
/// Base URL used when `VirusTotalConfig::base_url` is `None`.
pub const DEFAULT_BASE_URL: &str = "https://www.virustotal.com/api/v3";
/// Detection threshold that `VirusTotalConfig::new` starts with.
pub const DEFAULT_MIN_DETECTIONS: u64 = 5;
/// Request timeout that `VirusTotalConfig::new` starts with.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Configuration for the VirusTotal guard.
///
/// `Debug` is not derived: the hand-written impl in this file prints a
/// placeholder instead of the API key.
#[derive(Clone)]
pub struct VirusTotalConfig {
/// API key sent in the `x-apikey` request header; held in a `Zeroizing` wrapper.
pub api_key: Zeroizing<String>,
/// Optional override of the API base URL; `None` means `DEFAULT_BASE_URL`.
pub base_url: Option<String>,
/// The guard denies when `malicious + suspicious` is at or above this value.
/// `with_min_detections` clamps to at least 1, but direct assignment does not.
pub min_detections: u64,
/// Timeout applied to the `reqwest` client built by `VirusTotalGuard::new`.
pub timeout: Duration,
}
/// Hand-written `Debug` so the API key never reaches formatted output: every
/// field is printed as-is except `api_key`, which is shown as a fixed
/// placeholder string.
impl std::fmt::Debug for VirusTotalConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("VirusTotalConfig");
        // The secret is replaced, not truncated, so nothing about it leaks.
        out.field("api_key", &"***redacted***");
        out.field("base_url", &self.base_url);
        out.field("min_detections", &self.min_detections);
        out.field("timeout", &self.timeout);
        out.finish()
    }
}
impl VirusTotalConfig {
    /// Builds a configuration around `api_key`, with no base-URL override and
    /// the default detection threshold and timeout.
    pub fn new(api_key: impl Into<String>) -> Self {
        let api_key = Zeroizing::new(api_key.into());
        Self {
            api_key,
            base_url: None,
            min_detections: DEFAULT_MIN_DETECTIONS,
            timeout: DEFAULT_TIMEOUT,
        }
    }

    /// Returns the configuration with the API base URL overridden by `base`.
    pub fn with_base_url(self, base: impl Into<String>) -> Self {
        Self {
            base_url: Some(base.into()),
            ..self
        }
    }

    /// Returns the configuration with the detection threshold set to
    /// `threshold`, raised to 1 if a smaller value is given.
    pub fn with_min_detections(self, threshold: u64) -> Self {
        // A threshold of 0 would make every looked-up target match.
        let min_detections = std::cmp::max(threshold, 1);
        Self {
            min_detections,
            ..self
        }
    }

    /// Returns the effective base URL (the override if set, otherwise
    /// `DEFAULT_BASE_URL`) with all trailing `/` characters removed.
    fn resolved_base_url(&self) -> String {
        let configured = self.base_url.as_deref().unwrap_or(DEFAULT_BASE_URL);
        configured.trim_end_matches('/').to_owned()
    }
}
/// Arguments the guard reads from `GuardCallContext::arguments_json`.
///
/// Both fields are optional at the serde level. When both are present the
/// guard's `eval` looks at `hash` first.
#[derive(Debug, Clone, Deserialize)]
struct VirusTotalArgs {
/// File hash to look up; must normalize to a 64-character SHA-256 hex string.
#[serde(default)]
hash: Option<String>,
/// URL to look up; used only when `hash` is absent.
#[serde(default)]
url: Option<String>,
}
/// Top level of the response body, reduced to the one key the guard reads.
#[derive(Debug, Clone, Deserialize)]
struct VirusTotalResponse {
/// The `data` object; `None` when the key is missing from the body.
#[serde(default)]
data: Option<VirusTotalData>,
}
/// The response's `data` object, reduced to its `attributes` key.
#[derive(Debug, Clone, Deserialize)]
struct VirusTotalData {
/// The `attributes` object; `None` when the key is missing.
#[serde(default)]
attributes: Option<VirusTotalAttributes>,
}
/// The response's `attributes` object, reduced to the analysis counters.
#[derive(Debug, Clone, Deserialize)]
struct VirusTotalAttributes {
/// The `last_analysis_stats` object; `None` when the key is missing.
/// The explicit `rename` is identical to the field name.
#[serde(default, rename = "last_analysis_stats")]
last_analysis_stats: Option<VirusTotalStats>,
}
/// Counters read from `last_analysis_stats`; the guard sums them and compares
/// the total against the configured threshold.
#[derive(Debug, Clone, Deserialize)]
struct VirusTotalStats {
/// Value of the `malicious` counter; 0 when the key is absent.
#[serde(default)]
malicious: u64,
/// Value of the `suspicious` counter; 0 when the key is absent.
#[serde(default)]
suspicious: u64,
}
/// Serializable detail record that `VirusTotalGuard::evidence_from_decision`
/// encodes as JSON into `GuardEvidence::details`.
// NOTE(review): nothing visible in this file constructs this type; the field
// meanings below are inferred from their names — confirm against the caller.
#[derive(Debug, Clone, Serialize)]
pub struct VirusTotalEvidence {
/// Presumably the hash or URL that was looked up.
pub target: String,
/// Presumably the `malicious` counter from the response.
pub malicious: u64,
/// Presumably the `suspicious` counter from the response.
pub suspicious: u64,
/// Presumably the threshold that was in effect for the decision.
pub min_detections: u64,
}
/// External guard that looks a file hash or URL up on VirusTotal and denies
/// when the detection count reaches the configured threshold.
pub struct VirusTotalGuard {
/// Configuration the guard was built from (API key, threshold, timeout).
cfg: VirusTotalConfig,
/// Base URL resolved once at construction, with trailing `/` removed.
base_url: String,
/// HTTP client used for every lookup.
http: Client,
}
impl VirusTotalGuard {
    /// Builds a guard with its own `reqwest` client, configured with
    /// `cfg.timeout`.
    ///
    /// # Errors
    ///
    /// Returns `ExternalGuardError::Permanent` when the client cannot be built.
    pub fn new(cfg: VirusTotalConfig) -> Result<Self, ExternalGuardError> {
        let http = match Client::builder().timeout(cfg.timeout).build() {
            Ok(client) => client,
            Err(e) => {
                return Err(ExternalGuardError::Permanent(format!(
                    "reqwest build: {e}"
                )))
            }
        };
        Ok(Self::with_client(cfg, http))
    }

    /// Builds a guard around a caller-supplied client.
    ///
    /// The client is used as given; `cfg.timeout` is not applied to it here.
    pub fn with_client(cfg: VirusTotalConfig, http: Client) -> Self {
        Self {
            // Resolved before `cfg` is moved into the struct.
            base_url: cfg.resolved_base_url(),
            cfg,
            http,
        }
    }

    /// Packages a verdict as `GuardEvidence` for this guard.
    ///
    /// `verdict` is recorded as `true` only for `Verdict::Allow`. `details`,
    /// when given, is serialized to JSON; a serialization failure leaves the
    /// evidence without details rather than failing.
    pub fn evidence_from_decision(
        &self,
        verdict: Verdict,
        details: Option<&VirusTotalEvidence>,
    ) -> GuardEvidence {
        let allowed = matches!(verdict, Verdict::Allow);
        let details = details.and_then(|evidence| serde_json::to_string(evidence).ok());
        GuardEvidence {
            guard_name: self.name().to_string(),
            verdict: allowed,
            details,
        }
    }
}
/// Normalizes a user-supplied SHA-256 digest to 64 lowercase hex characters.
///
/// Surrounding whitespace is ignored. A single leading `sha256:` prefix or,
/// failing that, a single leading `0x` prefix is removed (both matched
/// case-sensitively), and whitespace left after the prefix is trimmed too.
///
/// Returns `None` unless exactly 64 ASCII hex digits remain.
fn normalize_sha256_hex(input: &str) -> Option<String> {
    let candidate = input.trim();
    // Only one prefix is ever stripped; `sha256:` wins when it matches.
    let digits = candidate
        .strip_prefix("sha256:")
        .or_else(|| candidate.strip_prefix("0x"))
        .unwrap_or(candidate)
        .trim();
    let well_formed = digits.len() == 64 && digits.chars().all(|c| c.is_ascii_hexdigit());
    well_formed.then(|| digits.to_ascii_lowercase())
}
#[async_trait]
impl ExternalGuard for VirusTotalGuard {
    /// Returns the fixed guard name, `GUARD_NAME`.
    fn name(&self) -> &str {
        GUARD_NAME
    }

    /// Derives a cache key for the lookup `eval` would perform.
    ///
    /// * `vt:file:<sha256>` when `hash` is present and well-formed.
    /// * `vt:url:<hex sha256 of the URL string>` when only `url` is present.
    /// * `None` when the arguments do not parse, when `hash` is present but
    ///   malformed, or when neither field is present.
    fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
        let args: VirusTotalArgs = serde_json::from_str(&ctx.arguments_json).ok()?;

        // `eval` gives `hash` precedence and rejects a malformed one outright.
        // A malformed hash must therefore not fall through to a URL-derived
        // key: that could serve a cached URL verdict for a call `eval` would
        // refuse to evaluate.
        if let Some(raw_hash) = args.hash.as_deref() {
            return normalize_sha256_hex(raw_hash).map(|h| format!("vt:file:{h}"));
        }

        let target_url = args.url.as_deref()?;
        // URLs are hashed so the key has a fixed length and does not embed
        // the raw URL.
        let mut hasher = Sha256::new();
        hasher.update(target_url.as_bytes());
        let digest = hasher.finalize();
        let mut hex = String::with_capacity(digest.len() * 2);
        for b in digest {
            hex.push_str(&format!("{b:02x}"));
        }
        Some(format!("vt:url:{hex}"))
    }

    /// Looks the target up on VirusTotal and turns the detection counters
    /// into a verdict.
    ///
    /// Returns `Verdict::Deny` when `malicious + suspicious` is at or above
    /// `min_detections`, and `Verdict::Allow` otherwise. A 404 from the API
    /// (target unknown to VirusTotal) is also treated as `Allow`.
    ///
    /// # Errors
    ///
    /// * `Permanent` — arguments are not valid JSON for `VirusTotalArgs`,
    ///   `hash` is not a SHA-256 hex string, neither `hash` nor `url` is
    ///   given, or the API key is not a valid header value.
    /// * Whatever `validate_external_guard_url` returns for a rejected endpoint.
    /// * Whatever `classify_reqwest_error` / `classify_status_error` return
    ///   for a failed request or a non-success, non-404 status.
    /// * `Transient` — the body cannot be read or is not the expected JSON.
    async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
        let args: VirusTotalArgs = serde_json::from_str(&ctx.arguments_json).map_err(|e| {
            ExternalGuardError::Permanent(format!("invalid virustotal arguments: {e}"))
        })?;

        let endpoint = if let Some(raw_hash) = args.hash.as_deref() {
            let Some(hash) = normalize_sha256_hex(raw_hash) else {
                return Err(ExternalGuardError::Permanent(
                    "virustotal: hash is not a sha256 hex string".to_string(),
                ));
            };
            format!("{}/files/{hash}", self.base_url)
        } else if let Some(target_url) = args.url.as_deref() {
            // The URL endpoint is addressed by the unpadded URL-safe base64
            // of the URL string.
            let id = URL_SAFE_NO_PAD.encode(target_url.as_bytes());
            format!("{}/urls/{id}", self.base_url)
        } else {
            return Err(ExternalGuardError::Permanent(
                "virustotal: arguments must include `hash` or `url`".to_string(),
            ));
        };

        super::super::endpoint_security::validate_external_guard_url(
            "virustotal base_url",
            &endpoint,
        )?;

        let mut api_key = HeaderValue::from_str(self.cfg.api_key.as_str())
            .map_err(|e| ExternalGuardError::Permanent(format!("invalid api key: {e}")))?;
        // Flag the value as sensitive so `Debug` output of the header map
        // does not print the key, matching the redacted `Debug` on the config.
        api_key.set_sensitive(true);
        let mut headers = HeaderMap::new();
        headers.insert("x-apikey", api_key);

        let resp = self
            .http
            .get(&endpoint)
            .headers(headers)
            .send()
            .await
            .map_err(classify_reqwest_error)?;
        let status = resp.status();
        let text = resp
            .text()
            .await
            .map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;

        // A target VirusTotal has no record of is allowed, not treated as an error.
        if status.as_u16() == 404 {
            tracing::info!(guard = GUARD_NAME, "virustotal: target not found");
            return Ok(Verdict::Allow);
        }
        if !status.is_success() {
            return Err(classify_status_error("virustotal", status, &text));
        }

        let parsed: VirusTotalResponse = serde_json::from_str(&text)
            .map_err(|e| ExternalGuardError::Transient(format!("parse vt response: {e}")))?;
        // A successful response without analysis stats counts as zero detections.
        let (malicious, suspicious) = parsed
            .data
            .and_then(|d| d.attributes)
            .and_then(|a| a.last_analysis_stats)
            .map(|s| (s.malicious, s.suspicious))
            .unwrap_or((0, 0));
        let detections = malicious.saturating_add(suspicious);

        tracing::info!(
            guard = GUARD_NAME,
            malicious,
            suspicious,
            min_detections = self.cfg.min_detections,
            "virustotal response"
        );

        if detections >= self.cfg.min_detections {
            return Ok(Verdict::Deny);
        }
        // Below threshold but not clean: allow, and leave a trace for operators.
        if malicious > 0 {
            tracing::warn!(
                guard = GUARD_NAME,
                malicious,
                suspicious,
                "virustotal: malicious detections below threshold"
            );
        }
        Ok(Verdict::Allow)
    }
}