chio_external_guards/external/threat_intel/
virustotal.rs1use std::time::Duration;
17
18use async_trait::async_trait;
19use base64::engine::general_purpose::URL_SAFE_NO_PAD;
20use base64::Engine as _;
21use chio_core_types::GuardEvidence;
22use chio_kernel::Verdict;
23use reqwest::header::{HeaderMap, HeaderValue};
24use reqwest::Client;
25use serde::{Deserialize, Serialize};
26use sha2::{Digest, Sha256};
27use zeroize::Zeroizing;
28
29use crate::external::bedrock::{classify_reqwest_error, classify_status_error};
30use crate::external::{ExternalGuard, ExternalGuardError, GuardCallContext};
31
/// Name this guard reports under (returned by `ExternalGuard::name`).
pub const GUARD_NAME: &str = "virustotal";

/// VirusTotal v3 API root used when the configuration does not override it.
pub const DEFAULT_BASE_URL: &str = "https://www.virustotal.com/api/v3";

/// Default threshold of combined malicious + suspicious engine detections at or
/// above which the guard denies.
pub const DEFAULT_MIN_DETECTIONS: u64 = 5;

/// Default HTTP request timeout: thirty seconds.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
44#[derive(Clone)]
46pub struct VirusTotalConfig {
47 pub api_key: Zeroizing<String>,
49 pub base_url: Option<String>,
51 pub min_detections: u64,
54 pub timeout: Duration,
56}
57
58impl std::fmt::Debug for VirusTotalConfig {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("VirusTotalConfig")
61 .field("api_key", &"***redacted***")
62 .field("base_url", &self.base_url)
63 .field("min_detections", &self.min_detections)
64 .field("timeout", &self.timeout)
65 .finish()
66 }
67}
68
69impl VirusTotalConfig {
70 pub fn new(api_key: impl Into<String>) -> Self {
72 Self {
73 api_key: Zeroizing::new(api_key.into()),
74 base_url: None,
75 min_detections: DEFAULT_MIN_DETECTIONS,
76 timeout: DEFAULT_TIMEOUT,
77 }
78 }
79
80 pub fn with_base_url(mut self, base: impl Into<String>) -> Self {
82 self.base_url = Some(base.into());
83 self
84 }
85
86 pub fn with_min_detections(mut self, threshold: u64) -> Self {
88 self.min_detections = threshold.max(1);
89 self
90 }
91
92 fn resolved_base_url(&self) -> String {
93 self.base_url
94 .clone()
95 .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
96 .trim_end_matches('/')
97 .to_string()
98 }
99}
100
101#[derive(Debug, Clone, Deserialize)]
104struct VirusTotalArgs {
105 #[serde(default)]
106 hash: Option<String>,
107 #[serde(default)]
108 url: Option<String>,
109}
110
111#[derive(Debug, Clone, Deserialize)]
112struct VirusTotalResponse {
113 #[serde(default)]
114 data: Option<VirusTotalData>,
115}
116
117#[derive(Debug, Clone, Deserialize)]
118struct VirusTotalData {
119 #[serde(default)]
120 attributes: Option<VirusTotalAttributes>,
121}
122
123#[derive(Debug, Clone, Deserialize)]
124struct VirusTotalAttributes {
125 #[serde(default, rename = "last_analysis_stats")]
126 last_analysis_stats: Option<VirusTotalStats>,
127}
128
129#[derive(Debug, Clone, Deserialize)]
130struct VirusTotalStats {
131 #[serde(default)]
132 malicious: u64,
133 #[serde(default)]
134 suspicious: u64,
135}
136
137#[derive(Debug, Clone, Serialize)]
139pub struct VirusTotalEvidence {
140 pub target: String,
141 pub malicious: u64,
142 pub suspicious: u64,
143 pub min_detections: u64,
144}
145
146pub struct VirusTotalGuard {
148 cfg: VirusTotalConfig,
149 base_url: String,
150 http: Client,
151}
152
153impl VirusTotalGuard {
154 pub fn new(cfg: VirusTotalConfig) -> Result<Self, ExternalGuardError> {
156 let http = Client::builder()
157 .timeout(cfg.timeout)
158 .build()
159 .map_err(|e| ExternalGuardError::Permanent(format!("reqwest build: {e}")))?;
160 let base_url = cfg.resolved_base_url();
161 Ok(Self {
162 cfg,
163 base_url,
164 http,
165 })
166 }
167
168 pub fn with_client(cfg: VirusTotalConfig, http: Client) -> Self {
170 let base_url = cfg.resolved_base_url();
171 Self {
172 cfg,
173 base_url,
174 http,
175 }
176 }
177
178 pub fn evidence_from_decision(
180 &self,
181 verdict: Verdict,
182 details: Option<&VirusTotalEvidence>,
183 ) -> GuardEvidence {
184 GuardEvidence {
185 guard_name: self.name().to_string(),
186 verdict: matches!(verdict, Verdict::Allow),
187 details: details.and_then(|d| serde_json::to_string(d).ok()),
188 }
189 }
190}
191
/// Normalize a user-supplied SHA-256 digest to 64 lowercase hex characters.
///
/// Surrounding whitespace is ignored. A single optional `sha256:` or `0x` prefix is
/// stripped, matched ASCII case-insensitively just like the hex digits themselves,
/// so `SHA256:` and `0X` are accepted too. Returns `None` unless exactly 64 ASCII
/// hex digits remain.
fn normalize_sha256_hex(input: &str) -> Option<String> {
    /// `s` with `prefix` removed, comparing ASCII case-insensitively.
    fn strip_prefix_ignore_case<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
        // `get` returns None (instead of panicking) if the cut is not a char boundary.
        let head = s.get(..prefix.len())?;
        head.eq_ignore_ascii_case(prefix)
            .then_some(&s[prefix.len()..])
    }

    let trimmed = input.trim();
    let hex = strip_prefix_ignore_case(trimmed, "sha256:")
        .or_else(|| strip_prefix_ignore_case(trimmed, "0x"))
        .unwrap_or(trimmed)
        .trim();
    if hex.len() != 64 || !hex.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    Some(hex.to_ascii_lowercase())
}
207
208#[async_trait]
209impl ExternalGuard for VirusTotalGuard {
210 fn name(&self) -> &str {
211 GUARD_NAME
212 }
213
214 fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
215 let args: VirusTotalArgs = serde_json::from_str(&ctx.arguments_json).ok()?;
216 if let Some(h) = args.hash.as_deref().and_then(normalize_sha256_hex) {
217 return Some(format!("vt:file:{h}"));
218 }
219 if let Some(u) = args.url.as_deref() {
220 let mut hasher = Sha256::new();
221 hasher.update(u.as_bytes());
222 let digest = hasher.finalize();
223 let mut hex = String::with_capacity(digest.len() * 2);
224 for b in digest {
225 hex.push_str(&format!("{b:02x}"));
226 }
227 return Some(format!("vt:url:{hex}"));
228 }
229 None
230 }
231
232 async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
233 let args: VirusTotalArgs = serde_json::from_str(&ctx.arguments_json).map_err(|e| {
234 ExternalGuardError::Permanent(format!("invalid virustotal arguments: {e}"))
235 })?;
236
237 let endpoint = if let Some(raw_hash) = args.hash.as_deref() {
238 let Some(hash) = normalize_sha256_hex(raw_hash) else {
239 return Err(ExternalGuardError::Permanent(
240 "virustotal: hash is not a sha256 hex string".to_string(),
241 ));
242 };
243 format!("{}/files/{hash}", self.base_url)
244 } else if let Some(target_url) = args.url.as_deref() {
245 let id = URL_SAFE_NO_PAD.encode(target_url.as_bytes());
246 format!("{}/urls/{id}", self.base_url)
247 } else {
248 return Err(ExternalGuardError::Permanent(
249 "virustotal: arguments must include `hash` or `url`".to_string(),
250 ));
251 };
252 super::super::endpoint_security::validate_external_guard_url(
253 "virustotal base_url",
254 &endpoint,
255 )?;
256
257 let mut headers = HeaderMap::new();
258 headers.insert(
259 "x-apikey",
260 HeaderValue::from_str(self.cfg.api_key.as_str())
261 .map_err(|e| ExternalGuardError::Permanent(format!("invalid api key: {e}")))?,
262 );
263
264 let resp = self
265 .http
266 .get(&endpoint)
267 .headers(headers)
268 .send()
269 .await
270 .map_err(classify_reqwest_error)?;
271
272 let status = resp.status();
273 let text = resp
274 .text()
275 .await
276 .map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;
277
278 if status.as_u16() == 404 {
282 tracing::info!(guard = GUARD_NAME, "virustotal: target not found");
283 return Ok(Verdict::Allow);
284 }
285
286 if !status.is_success() {
287 return Err(classify_status_error("virustotal", status, &text));
288 }
289
290 let parsed: VirusTotalResponse = serde_json::from_str(&text)
291 .map_err(|e| ExternalGuardError::Transient(format!("parse vt response: {e}")))?;
292
293 let (malicious, suspicious) = parsed
294 .data
295 .and_then(|d| d.attributes)
296 .and_then(|a| a.last_analysis_stats)
297 .map(|s| (s.malicious, s.suspicious))
298 .unwrap_or((0, 0));
299
300 let detections = malicious.saturating_add(suspicious);
301 tracing::info!(
302 guard = GUARD_NAME,
303 malicious,
304 suspicious,
305 min_detections = self.cfg.min_detections,
306 "virustotal response"
307 );
308
309 if detections >= self.cfg.min_detections {
310 return Ok(Verdict::Deny);
311 }
312 if malicious > 0 {
313 tracing::warn!(
314 guard = GUARD_NAME,
315 malicious,
316 suspicious,
317 "virustotal: malicious detections below threshold"
318 );
319 }
320
321 Ok(Verdict::Allow)
322 }
323}