Skip to main content

chio_external_guards/external/threat_intel/
safe_browsing.rs

1//! Google Safe Browsing v4 adapter (phase 13.3).
2//!
3//! Adapted from
4//! `../clawdstrike/crates/libs/clawdstrike/src/async_guards/threat_intel/safe_browsing.rs`.
5//! Accepts `{"url": "<absolute-url>"}` in
6//! [`GuardCallContext::arguments_json`] and denies when Safe Browsing
7//! returns at least one match.
8
9use std::time::Duration;
10
11use async_trait::async_trait;
12use chio_core_types::GuardEvidence;
13use chio_kernel::Verdict;
14use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
15use reqwest::Client;
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18use zeroize::Zeroizing;
19
20use crate::external::bedrock::{classify_reqwest_error, classify_status_error};
21use crate::external::{ExternalGuard, ExternalGuardError, GuardCallContext};
22
/// Name this guard reports through [`SafeBrowsingGuard::name`].
pub const GUARD_NAME: &str = "safe-browsing";

/// Safe Browsing v4 API root used when no base URL override is configured.
pub const DEFAULT_BASE_URL: &str = "https://safebrowsing.googleapis.com/v4";

/// `clientId` sent in each lookup body unless the config overrides it.
pub const DEFAULT_CLIENT_ID: &str = "chio-guards";

/// `clientVersion` sent in each lookup body unless the config overrides it.
pub const DEFAULT_CLIENT_VERSION: &str = "0.1.0";

/// HTTP timeout used when the config does not override it (30 seconds).
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
37
38/// Configuration for [`SafeBrowsingGuard`].
39#[derive(Clone)]
40pub struct SafeBrowsingConfig {
41    /// API key (query parameter `key`).
42    pub api_key: Zeroizing<String>,
43    /// Client identifier submitted in the request body.
44    pub client_id: String,
45    /// Client version submitted in the request body.
46    pub client_version: String,
47    /// Override the base URL (test hook).
48    pub base_url: Option<String>,
49    /// Threat types to query.
50    pub threat_types: Vec<String>,
51    /// Per-request HTTP timeout.
52    pub timeout: Duration,
53}
54
55impl std::fmt::Debug for SafeBrowsingConfig {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("SafeBrowsingConfig")
58            .field("api_key", &"***redacted***")
59            .field("client_id", &self.client_id)
60            .field("client_version", &self.client_version)
61            .field("base_url", &self.base_url)
62            .field("threat_types", &self.threat_types)
63            .field("timeout", &self.timeout)
64            .finish()
65    }
66}
67
68impl SafeBrowsingConfig {
69    /// Construct a config with defaults.
70    pub fn new(api_key: impl Into<String>) -> Self {
71        Self {
72            api_key: Zeroizing::new(api_key.into()),
73            client_id: DEFAULT_CLIENT_ID.to_string(),
74            client_version: DEFAULT_CLIENT_VERSION.to_string(),
75            base_url: None,
76            threat_types: vec![
77                "MALWARE".to_string(),
78                "SOCIAL_ENGINEERING".to_string(),
79                "UNWANTED_SOFTWARE".to_string(),
80                "POTENTIALLY_HARMFUL_APPLICATION".to_string(),
81            ],
82            timeout: DEFAULT_TIMEOUT,
83        }
84    }
85
86    /// Override the base URL (used in tests).
87    pub fn with_base_url(mut self, base: impl Into<String>) -> Self {
88        self.base_url = Some(base.into());
89        self
90    }
91
92    fn resolved_base_url(&self) -> String {
93        self.base_url
94            .clone()
95            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
96            .trim_end_matches('/')
97            .to_string()
98    }
99}
100
101#[derive(Debug, Clone, Deserialize)]
102struct SafeBrowsingArgs {
103    url: String,
104}
105
106#[derive(Debug, Serialize)]
107struct FindRequest<'a> {
108    client: ClientInfo<'a>,
109    #[serde(rename = "threatInfo")]
110    threat_info: ThreatInfo<'a>,
111}
112
113#[derive(Debug, Serialize)]
114struct ClientInfo<'a> {
115    #[serde(rename = "clientId")]
116    client_id: &'a str,
117    #[serde(rename = "clientVersion")]
118    client_version: &'a str,
119}
120
121#[derive(Debug, Serialize)]
122struct ThreatInfo<'a> {
123    #[serde(rename = "threatTypes")]
124    threat_types: &'a [String],
125    #[serde(rename = "platformTypes")]
126    platform_types: Vec<&'a str>,
127    #[serde(rename = "threatEntryTypes")]
128    threat_entry_types: Vec<&'a str>,
129    #[serde(rename = "threatEntries")]
130    threat_entries: Vec<ThreatEntry<'a>>,
131}
132
133#[derive(Debug, Serialize)]
134struct ThreatEntry<'a> {
135    url: &'a str,
136}
137
138#[derive(Debug, Clone, Deserialize)]
139struct FindResponse {
140    #[serde(default)]
141    matches: Vec<MatchEntry>,
142}
143
144#[derive(Debug, Clone, Deserialize, Serialize)]
145struct MatchEntry {
146    #[serde(default, rename = "threatType")]
147    threat_type: Option<String>,
148    #[serde(default, rename = "platformType")]
149    platform_type: Option<String>,
150}
151
152/// Structured receipt evidence.
153#[derive(Debug, Clone, Serialize)]
154pub struct SafeBrowsingEvidence {
155    pub url: String,
156    pub matches: Vec<String>,
157}
158
159/// Guard wrapping Safe Browsing `threatMatches:find`.
160pub struct SafeBrowsingGuard {
161    cfg: SafeBrowsingConfig,
162    base_url: String,
163    http: Client,
164}
165
166impl SafeBrowsingGuard {
167    /// Build a guard with an internally-owned [`reqwest::Client`].
168    pub fn new(cfg: SafeBrowsingConfig) -> Result<Self, ExternalGuardError> {
169        let http = Client::builder()
170            .timeout(cfg.timeout)
171            .build()
172            .map_err(|e| ExternalGuardError::Permanent(format!("reqwest build: {e}")))?;
173        let base_url = cfg.resolved_base_url();
174        Ok(Self {
175            cfg,
176            base_url,
177            http,
178        })
179    }
180
181    /// Build with a caller-supplied client.
182    pub fn with_client(cfg: SafeBrowsingConfig, http: Client) -> Self {
183        let base_url = cfg.resolved_base_url();
184        Self {
185            cfg,
186            base_url,
187            http,
188        }
189    }
190
191    /// Build a [`GuardEvidence`] record for a prior decision.
192    pub fn evidence_from_decision(
193        &self,
194        verdict: Verdict,
195        details: Option<&SafeBrowsingEvidence>,
196    ) -> GuardEvidence {
197        GuardEvidence {
198            guard_name: self.name().to_string(),
199            verdict: matches!(verdict, Verdict::Allow),
200            details: details.and_then(|d| serde_json::to_string(d).ok()),
201        }
202    }
203}
204
205#[async_trait]
206impl ExternalGuard for SafeBrowsingGuard {
207    fn name(&self) -> &str {
208        GUARD_NAME
209    }
210
211    fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
212        let args: SafeBrowsingArgs = serde_json::from_str(&ctx.arguments_json).ok()?;
213        let mut hasher = Sha256::new();
214        hasher.update(args.url.as_bytes());
215        let digest = hasher.finalize();
216        let mut hex = String::with_capacity(digest.len() * 2);
217        for b in digest {
218            hex.push_str(&format!("{b:02x}"));
219        }
220        Some(format!("sb:{hex}"))
221    }
222
223    async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
224        super::super::endpoint_security::validate_external_guard_url(
225            "safe-browsing base_url",
226            &self.base_url,
227        )?;
228        let args: SafeBrowsingArgs = serde_json::from_str(&ctx.arguments_json).map_err(|e| {
229            ExternalGuardError::Permanent(format!("invalid safe-browsing arguments: {e}"))
230        })?;
231
232        let endpoint = format!(
233            "{}/threatMatches:find?key={}",
234            self.base_url,
235            self.cfg.api_key.as_str()
236        );
237
238        let body = FindRequest {
239            client: ClientInfo {
240                client_id: &self.cfg.client_id,
241                client_version: &self.cfg.client_version,
242            },
243            threat_info: ThreatInfo {
244                threat_types: &self.cfg.threat_types,
245                platform_types: vec!["ANY_PLATFORM"],
246                threat_entry_types: vec!["URL"],
247                threat_entries: vec![ThreatEntry { url: &args.url }],
248            },
249        };
250
251        let mut headers = HeaderMap::new();
252        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
253
254        let resp = self
255            .http
256            .post(&endpoint)
257            .headers(headers)
258            .json(&body)
259            .send()
260            .await
261            .map_err(classify_reqwest_error)?;
262
263        let status = resp.status();
264        let text = resp
265            .text()
266            .await
267            .map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;
268
269        if !status.is_success() {
270            return Err(classify_status_error("safe-browsing", status, &text));
271        }
272
273        let parsed: FindResponse = serde_json::from_str(&text).map_err(|e| {
274            ExternalGuardError::Transient(format!("parse safe browsing response: {e}"))
275        })?;
276
277        let matched = !parsed.matches.is_empty();
278        tracing::info!(
279            guard = GUARD_NAME,
280            match_count = parsed.matches.len(),
281            "safe browsing response"
282        );
283
284        Ok(if matched {
285            Verdict::Deny
286        } else {
287            Verdict::Allow
288        })
289    }
290}