chio_external_guards/external/threat_intel/
safe_browsing.rs1use std::time::Duration;
10
11use async_trait::async_trait;
12use chio_core_types::GuardEvidence;
13use chio_kernel::Verdict;
14use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
15use reqwest::Client;
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18use zeroize::Zeroizing;
19
20use crate::external::bedrock::{classify_reqwest_error, classify_status_error};
21use crate::external::{ExternalGuard, ExternalGuardError, GuardCallContext};
22
/// Name this guard reports from `ExternalGuard::name` and attaches to its log events.
pub const GUARD_NAME: &str = "safe-browsing";

/// Base URL used when `SafeBrowsingConfig::base_url` is `None`.
pub const DEFAULT_BASE_URL: &str = "https://safebrowsing.googleapis.com/v4";

/// Timeout installed by `SafeBrowsingConfig::new`.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Client identifier installed by `SafeBrowsingConfig::new`.
pub const DEFAULT_CLIENT_ID: &str = "chio-guards";

/// Client version installed by `SafeBrowsingConfig::new`.
pub const DEFAULT_CLIENT_VERSION: &str = "0.1.0";
37
38#[derive(Clone)]
40pub struct SafeBrowsingConfig {
41 pub api_key: Zeroizing<String>,
43 pub client_id: String,
45 pub client_version: String,
47 pub base_url: Option<String>,
49 pub threat_types: Vec<String>,
51 pub timeout: Duration,
53}
54
55impl std::fmt::Debug for SafeBrowsingConfig {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("SafeBrowsingConfig")
58 .field("api_key", &"***redacted***")
59 .field("client_id", &self.client_id)
60 .field("client_version", &self.client_version)
61 .field("base_url", &self.base_url)
62 .field("threat_types", &self.threat_types)
63 .field("timeout", &self.timeout)
64 .finish()
65 }
66}
67
68impl SafeBrowsingConfig {
69 pub fn new(api_key: impl Into<String>) -> Self {
71 Self {
72 api_key: Zeroizing::new(api_key.into()),
73 client_id: DEFAULT_CLIENT_ID.to_string(),
74 client_version: DEFAULT_CLIENT_VERSION.to_string(),
75 base_url: None,
76 threat_types: vec![
77 "MALWARE".to_string(),
78 "SOCIAL_ENGINEERING".to_string(),
79 "UNWANTED_SOFTWARE".to_string(),
80 "POTENTIALLY_HARMFUL_APPLICATION".to_string(),
81 ],
82 timeout: DEFAULT_TIMEOUT,
83 }
84 }
85
86 pub fn with_base_url(mut self, base: impl Into<String>) -> Self {
88 self.base_url = Some(base.into());
89 self
90 }
91
92 fn resolved_base_url(&self) -> String {
93 self.base_url
94 .clone()
95 .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
96 .trim_end_matches('/')
97 .to_string()
98 }
99}
100
101#[derive(Debug, Clone, Deserialize)]
102struct SafeBrowsingArgs {
103 url: String,
104}
105
106#[derive(Debug, Serialize)]
107struct FindRequest<'a> {
108 client: ClientInfo<'a>,
109 #[serde(rename = "threatInfo")]
110 threat_info: ThreatInfo<'a>,
111}
112
113#[derive(Debug, Serialize)]
114struct ClientInfo<'a> {
115 #[serde(rename = "clientId")]
116 client_id: &'a str,
117 #[serde(rename = "clientVersion")]
118 client_version: &'a str,
119}
120
121#[derive(Debug, Serialize)]
122struct ThreatInfo<'a> {
123 #[serde(rename = "threatTypes")]
124 threat_types: &'a [String],
125 #[serde(rename = "platformTypes")]
126 platform_types: Vec<&'a str>,
127 #[serde(rename = "threatEntryTypes")]
128 threat_entry_types: Vec<&'a str>,
129 #[serde(rename = "threatEntries")]
130 threat_entries: Vec<ThreatEntry<'a>>,
131}
132
133#[derive(Debug, Serialize)]
134struct ThreatEntry<'a> {
135 url: &'a str,
136}
137
138#[derive(Debug, Clone, Deserialize)]
139struct FindResponse {
140 #[serde(default)]
141 matches: Vec<MatchEntry>,
142}
143
144#[derive(Debug, Clone, Deserialize, Serialize)]
145struct MatchEntry {
146 #[serde(default, rename = "threatType")]
147 threat_type: Option<String>,
148 #[serde(default, rename = "platformType")]
149 platform_type: Option<String>,
150}
151
152#[derive(Debug, Clone, Serialize)]
154pub struct SafeBrowsingEvidence {
155 pub url: String,
156 pub matches: Vec<String>,
157}
158
159pub struct SafeBrowsingGuard {
161 cfg: SafeBrowsingConfig,
162 base_url: String,
163 http: Client,
164}
165
166impl SafeBrowsingGuard {
167 pub fn new(cfg: SafeBrowsingConfig) -> Result<Self, ExternalGuardError> {
169 let http = Client::builder()
170 .timeout(cfg.timeout)
171 .build()
172 .map_err(|e| ExternalGuardError::Permanent(format!("reqwest build: {e}")))?;
173 let base_url = cfg.resolved_base_url();
174 Ok(Self {
175 cfg,
176 base_url,
177 http,
178 })
179 }
180
181 pub fn with_client(cfg: SafeBrowsingConfig, http: Client) -> Self {
183 let base_url = cfg.resolved_base_url();
184 Self {
185 cfg,
186 base_url,
187 http,
188 }
189 }
190
191 pub fn evidence_from_decision(
193 &self,
194 verdict: Verdict,
195 details: Option<&SafeBrowsingEvidence>,
196 ) -> GuardEvidence {
197 GuardEvidence {
198 guard_name: self.name().to_string(),
199 verdict: matches!(verdict, Verdict::Allow),
200 details: details.and_then(|d| serde_json::to_string(d).ok()),
201 }
202 }
203}
204
205#[async_trait]
206impl ExternalGuard for SafeBrowsingGuard {
207 fn name(&self) -> &str {
208 GUARD_NAME
209 }
210
211 fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
212 let args: SafeBrowsingArgs = serde_json::from_str(&ctx.arguments_json).ok()?;
213 let mut hasher = Sha256::new();
214 hasher.update(args.url.as_bytes());
215 let digest = hasher.finalize();
216 let mut hex = String::with_capacity(digest.len() * 2);
217 for b in digest {
218 hex.push_str(&format!("{b:02x}"));
219 }
220 Some(format!("sb:{hex}"))
221 }
222
223 async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
224 super::super::endpoint_security::validate_external_guard_url(
225 "safe-browsing base_url",
226 &self.base_url,
227 )?;
228 let args: SafeBrowsingArgs = serde_json::from_str(&ctx.arguments_json).map_err(|e| {
229 ExternalGuardError::Permanent(format!("invalid safe-browsing arguments: {e}"))
230 })?;
231
232 let endpoint = format!(
233 "{}/threatMatches:find?key={}",
234 self.base_url,
235 self.cfg.api_key.as_str()
236 );
237
238 let body = FindRequest {
239 client: ClientInfo {
240 client_id: &self.cfg.client_id,
241 client_version: &self.cfg.client_version,
242 },
243 threat_info: ThreatInfo {
244 threat_types: &self.cfg.threat_types,
245 platform_types: vec!["ANY_PLATFORM"],
246 threat_entry_types: vec!["URL"],
247 threat_entries: vec![ThreatEntry { url: &args.url }],
248 },
249 };
250
251 let mut headers = HeaderMap::new();
252 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
253
254 let resp = self
255 .http
256 .post(&endpoint)
257 .headers(headers)
258 .json(&body)
259 .send()
260 .await
261 .map_err(classify_reqwest_error)?;
262
263 let status = resp.status();
264 let text = resp
265 .text()
266 .await
267 .map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;
268
269 if !status.is_success() {
270 return Err(classify_status_error("safe-browsing", status, &text));
271 }
272
273 let parsed: FindResponse = serde_json::from_str(&text).map_err(|e| {
274 ExternalGuardError::Transient(format!("parse safe browsing response: {e}"))
275 })?;
276
277 let matched = !parsed.matches.is_empty();
278 tracing::info!(
279 guard = GUARD_NAME,
280 match_count = parsed.matches.len(),
281 "safe browsing response"
282 );
283
284 Ok(if matched {
285 Verdict::Deny
286 } else {
287 Verdict::Allow
288 })
289 }
290}