use std::time::Duration;
use async_trait::async_trait;
use chio_core_types::GuardEvidence;
use chio_kernel::Verdict;
use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use zeroize::Zeroizing;
use crate::external::bedrock::{classify_reqwest_error, classify_status_error};
use crate::external::{ExternalGuard, ExternalGuardError, GuardCallContext};
/// Name this guard reports through `ExternalGuard::name`; also used as the
/// `guard` field on the tracing event emitted by `eval`.
pub const GUARD_NAME: &str = "safe-browsing";
/// Base URL used when `SafeBrowsingConfig::base_url` is `None`.
pub const DEFAULT_BASE_URL: &str = "https://safebrowsing.googleapis.com/v4";
/// Timeout assigned by `SafeBrowsingConfig::new`.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Client identifier assigned by `SafeBrowsingConfig::new`; sent as `clientId`.
pub const DEFAULT_CLIENT_ID: &str = "chio-guards";
/// Client version assigned by `SafeBrowsingConfig::new`; sent as `clientVersion`.
pub const DEFAULT_CLIENT_VERSION: &str = "0.1.0";
/// Settings for the Safe Browsing guard.
///
/// `Debug` is implemented by hand (not derived) so that `api_key` is never
/// printed.
#[derive(Clone)]
pub struct SafeBrowsingConfig {
/// API key appended to the request URL as the `key` query parameter.
/// Held in a `Zeroizing` wrapper.
pub api_key: Zeroizing<String>,
/// Value sent as `client.clientId` in the request body.
pub client_id: String,
/// Value sent as `client.clientVersion` in the request body.
pub client_version: String,
/// Optional base URL override; `None` selects `DEFAULT_BASE_URL`.
/// Trailing `/` characters are trimmed when the URL is resolved.
pub base_url: Option<String>,
/// Values sent as `threatInfo.threatTypes` in the request body.
pub threat_types: Vec<String>,
/// Request timeout applied to the HTTP client built by
/// `SafeBrowsingGuard::new`. `SafeBrowsingGuard::with_client` does not read
/// this field; the caller-supplied client is used as-is.
pub timeout: Duration,
}
/// Hand-written `Debug` that prints a fixed placeholder instead of the API key.
impl std::fmt::Debug for SafeBrowsingConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure exhaustively so that adding a field to the struct forces
        // a decision here about whether it is safe to print.
        let Self {
            api_key: _,
            client_id,
            client_version,
            base_url,
            threat_types,
            timeout,
        } = self;

        let mut out = f.debug_struct("SafeBrowsingConfig");
        out.field("api_key", &"***redacted***");
        out.field("client_id", client_id);
        out.field("client_version", client_version);
        out.field("base_url", base_url);
        out.field("threat_types", threat_types);
        out.field("timeout", timeout);
        out.finish()
    }
}
impl SafeBrowsingConfig {
    /// Creates a configuration for `api_key` with every other field defaulted.
    ///
    /// Defaults: `DEFAULT_CLIENT_ID`, `DEFAULT_CLIENT_VERSION`, no base URL
    /// override, `DEFAULT_TIMEOUT`, and the four threat types `MALWARE`,
    /// `SOCIAL_ENGINEERING`, `UNWANTED_SOFTWARE` and
    /// `POTENTIALLY_HARMFUL_APPLICATION`.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: Zeroizing::new(api_key.into()),
            client_id: String::from(DEFAULT_CLIENT_ID),
            client_version: String::from(DEFAULT_CLIENT_VERSION),
            base_url: None,
            threat_types: [
                "MALWARE",
                "SOCIAL_ENGINEERING",
                "UNWANTED_SOFTWARE",
                "POTENTIALLY_HARMFUL_APPLICATION",
            ]
            .iter()
            .map(|name| (*name).to_owned())
            .collect(),
            timeout: DEFAULT_TIMEOUT,
        }
    }

    /// Returns the configuration with `base_url` replaced by `base`.
    pub fn with_base_url(self, base: impl Into<String>) -> Self {
        Self {
            base_url: Some(base.into()),
            ..self
        }
    }

    /// Returns the configured base URL, or `DEFAULT_BASE_URL` when none is
    /// set, with all trailing `/` characters removed.
    fn resolved_base_url(&self) -> String {
        let configured = self.base_url.as_deref().unwrap_or(DEFAULT_BASE_URL);
        configured.trim_end_matches('/').to_owned()
    }
}
/// Guard arguments deserialized from `GuardCallContext::arguments_json`.
#[derive(Debug, Clone, Deserialize)]
struct SafeBrowsingArgs {
/// URL to check; sent as the single threat entry and hashed for the cache key.
url: String,
}
/// JSON body posted to the `threatMatches:find` endpoint.
///
/// Fields serialize under camelCase names (`client`, `threatInfo`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct FindRequest<'a> {
    /// Identification of the calling client.
    client: ClientInfo<'a>,
    /// Threat lists and entries to look up.
    threat_info: ThreatInfo<'a>,
}
/// `client` section of the request body.
///
/// Fields serialize under camelCase names (`clientId`, `clientVersion`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ClientInfo<'a> {
    /// Borrowed client identifier.
    client_id: &'a str,
    /// Borrowed client version string.
    client_version: &'a str,
}
/// `threatInfo` section of the request body.
///
/// Fields serialize under camelCase names (`threatTypes`, `platformTypes`,
/// `threatEntryTypes`, `threatEntries`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ThreatInfo<'a> {
    /// Threat type names to match against.
    threat_types: &'a [String],
    /// Platform type names to match against.
    platform_types: Vec<&'a str>,
    /// Kinds of entries contained in `threat_entries`.
    threat_entry_types: Vec<&'a str>,
    /// Entries to look up.
    threat_entries: Vec<ThreatEntry<'a>>,
}
/// One entry of `threatInfo.threatEntries` in the request body.
#[derive(Debug, Serialize)]
struct ThreatEntry<'a> {
/// Borrowed URL to look up.
url: &'a str,
}
/// Response body parsed by `eval` after a successful request.
#[derive(Debug, Clone, Deserialize)]
struct FindResponse {
/// Reported matches. `#[serde(default)]` makes a missing `matches` key
/// deserialize as an empty list, which `eval` treats as "no threat found".
#[serde(default)]
matches: Vec<MatchEntry>,
}
/// One element of the response's `matches` array.
///
/// Fields use camelCase names on the wire (`threatType`, `platformType`), and
/// either may be absent.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
struct MatchEntry {
    /// Threat type reported for the match, when present.
    #[serde(default)]
    threat_type: Option<String>,
    /// Platform type reported for the match, when present.
    #[serde(default)]
    platform_type: Option<String>,
}
/// Evidence payload that `SafeBrowsingGuard::evidence_from_decision`
/// serializes to JSON for the `details` field of a `GuardEvidence`.
#[derive(Debug, Clone, Serialize)]
pub struct SafeBrowsingEvidence {
/// URL the evidence refers to.
pub url: String,
/// Match descriptions for `url`.
/// NOTE(review): presumably threat type names; confirm with the code that
/// builds this struct, which is not visible here.
pub matches: Vec<String>,
}
/// External guard that checks a URL via the `threatMatches:find` endpoint and
/// denies the call when any match is returned.
pub struct SafeBrowsingGuard {
/// Configuration supplied at construction.
cfg: SafeBrowsingConfig,
/// Base URL resolved from `cfg` once, at construction.
base_url: String,
/// HTTP client used for every request.
http: Client,
}
impl SafeBrowsingGuard {
    /// Builds a guard with its own HTTP client, using `cfg.timeout` as the
    /// client timeout.
    ///
    /// # Errors
    ///
    /// Returns `ExternalGuardError::Permanent` when the HTTP client cannot be
    /// built.
    pub fn new(cfg: SafeBrowsingConfig) -> Result<Self, ExternalGuardError> {
        match Client::builder().timeout(cfg.timeout).build() {
            Ok(http) => Ok(Self::with_client(cfg, http)),
            Err(e) => Err(ExternalGuardError::Permanent(format!(
                "reqwest build: {e}"
            ))),
        }
    }

    /// Builds a guard around a caller-supplied HTTP client.
    ///
    /// `cfg.timeout` is not applied here; the client is used as given.
    pub fn with_client(cfg: SafeBrowsingConfig, http: Client) -> Self {
        Self {
            // Resolved before `cfg` is moved into the struct on the next line.
            base_url: cfg.resolved_base_url(),
            cfg,
            http,
        }
    }

    /// Packages a verdict and optional details into a `GuardEvidence`.
    ///
    /// The `verdict` field is `true` only for `Verdict::Allow`. `details` is
    /// the JSON serialization of the supplied evidence, or `None` when no
    /// evidence is given or serialization fails.
    pub fn evidence_from_decision(
        &self,
        verdict: Verdict,
        details: Option<&SafeBrowsingEvidence>,
    ) -> GuardEvidence {
        let allowed = matches!(verdict, Verdict::Allow);
        let serialized = match details {
            Some(evidence) => serde_json::to_string(evidence).ok(),
            None => None,
        };
        GuardEvidence {
            guard_name: self.name().to_string(),
            verdict: allowed,
            details: serialized,
        }
    }
}
#[async_trait]
impl ExternalGuard for SafeBrowsingGuard {
fn name(&self) -> &str {
GUARD_NAME
}
fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
let args: SafeBrowsingArgs = serde_json::from_str(&ctx.arguments_json).ok()?;
let mut hasher = Sha256::new();
hasher.update(args.url.as_bytes());
let digest = hasher.finalize();
let mut hex = String::with_capacity(digest.len() * 2);
for b in digest {
hex.push_str(&format!("{b:02x}"));
}
Some(format!("sb:{hex}"))
}
async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
super::super::endpoint_security::validate_external_guard_url(
"safe-browsing base_url",
&self.base_url,
)?;
let args: SafeBrowsingArgs = serde_json::from_str(&ctx.arguments_json).map_err(|e| {
ExternalGuardError::Permanent(format!("invalid safe-browsing arguments: {e}"))
})?;
let endpoint = format!(
"{}/threatMatches:find?key={}",
self.base_url,
self.cfg.api_key.as_str()
);
let body = FindRequest {
client: ClientInfo {
client_id: &self.cfg.client_id,
client_version: &self.cfg.client_version,
},
threat_info: ThreatInfo {
threat_types: &self.cfg.threat_types,
platform_types: vec!["ANY_PLATFORM"],
threat_entry_types: vec!["URL"],
threat_entries: vec![ThreatEntry { url: &args.url }],
},
};
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
let resp = self
.http
.post(&endpoint)
.headers(headers)
.json(&body)
.send()
.await
.map_err(classify_reqwest_error)?;
let status = resp.status();
let text = resp
.text()
.await
.map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;
if !status.is_success() {
return Err(classify_status_error("safe-browsing", status, &text));
}
let parsed: FindResponse = serde_json::from_str(&text).map_err(|e| {
ExternalGuardError::Transient(format!("parse safe browsing response: {e}"))
})?;
let matched = !parsed.matches.is_empty();
tracing::info!(
guard = GUARD_NAME,
match_count = parsed.matches.len(),
"safe browsing response"
);
Ok(if matched {
Verdict::Deny
} else {
Verdict::Allow
})
}
}