use security_core::classification::DataClassification;
use security_core::severity::SecuritySeverity;
use security_events::event::{EventOutcome, SecurityEvent};
use security_events::kind::EventKind;
use std::fmt;
/// Schemes rejected outright by `is_dangerous_scheme`, before any scheme
/// allowlist is consulted.
///
/// Entries must be lowercase: `is_dangerous_scheme` lowercases its input and
/// compares it for exact equality against these strings.
const DANGEROUS_SCHEMES: &[&str] = &[
    "javascript",
    "data",
    "blob",
    "vbscript",
];
/// Reason a URL was refused by one of the platform validators in this module.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PlatformRejection {
    /// The scheme is well-formed but not in the validator's allowlist.
    InvalidScheme,
    /// The scheme is listed in `DANGEROUS_SCHEMES`.
    DangerousScheme,
    /// A `..` traversal sequence was found in the URL.
    PathTraversal,
    /// A host allowlist is configured and the host is missing or not listed.
    UntrustedHost,
    /// The WebView validator refused a `file` URL.
    FileAccessBlocked,
    /// No scheme could be extracted from the input.
    MalformedUrl,
}
impl fmt::Display for PlatformRejection {
    /// Writes a fixed, human-readable message. Width/alignment flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            PlatformRejection::InvalidScheme => "URL scheme not in allowed list",
            PlatformRejection::DangerousScheme => "dangerous URL scheme blocked",
            PlatformRejection::PathTraversal => "path traversal detected in URL",
            PlatformRejection::UntrustedHost => "URL host not in trusted list",
            PlatformRejection::FileAccessBlocked => "file:// URL blocked in WebView",
            PlatformRejection::MalformedUrl => "malformed URL",
        };
        f.write_str(message)
    }
}
impl std::error::Error for PlatformRejection {}
/// Returns the scheme of `url`, i.e. the text before the first `:`.
///
/// Returns `None` when there is no `:`, when the scheme would be empty, or when
/// it contains anything other than ASCII letters, ASCII digits, `+`, `-` or `.`.
/// Unlike RFC 3986 §3.1, a scheme that starts with a digit or symbol is still
/// accepted here.
fn extract_scheme(url: &str) -> Option<&str> {
    let colon_at = url.find(':')?;
    let candidate = &url[..colon_at];
    let is_scheme_char =
        |c: char| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.');
    if candidate.is_empty() || !candidate.chars().all(is_scheme_char) {
        return None;
    }
    Some(candidate)
}
/// Returns the host of a hierarchical `scheme://authority/...` URL, or `None`.
///
/// The authority is recognised only when `//` immediately follows the scheme
/// delimiter (the first `:`). Searching for `"://"` anywhere in the string would
/// let a value smuggled into a path, query or fragment be taken as the host.
/// For example, `https:evil.example/?x=://trusted.example` would report
/// `trusted.example` even though browsers load `evil.example`.
///
/// The authority ends at the first `/`, `?` or `#`. A trailing `:port` is
/// stripped only when it is all ASCII digits.
///
/// Userinfo (`user@`) is deliberately *not* stripped, and `\` does not end the
/// authority. A host containing `@` or `\` therefore never equals an entry in an
/// exact-match allowlist, so the check fails closed. Naively taking the text after
/// the last `@` would be unsafe: for `http(s)` URLs a `\` ends the authority, so
/// `https://evil.example\@trusted.example` actually targets `evil.example`.
///
/// Returns `None` when there is no `:`, when `//` does not directly follow it,
/// or when the resulting host is empty.
fn extract_host(url: &str) -> Option<&str> {
    let colon = url.find(':')?;
    // Anchor the authority to the scheme delimiter; see the doc comment.
    let rest = url[colon + 1..].strip_prefix("//")?;
    let host_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
    let host_with_port = &rest[..host_end];
    // Only an all-digit suffix counts as a port. The last `:` of a bracketed
    // IPv6 literal such as `[::1]` is followed by `]`, so it is left intact.
    let host = match host_with_port.rfind(':') {
        Some(pos)
            if host_with_port[pos + 1..]
                .chars()
                .all(|c| c.is_ascii_digit()) =>
        {
            &host_with_port[..pos]
        }
        _ => host_with_port,
    };
    if host.is_empty() {
        None
    } else {
        Some(host)
    }
}
/// Reports whether `url` contains a `..` directory-traversal sequence outside
/// its authority, in literal or percent-encoded form.
///
/// Everything after the first `:` is inspected, or the whole string if there
/// is no `:`. This includes any query and fragment. When that remainder starts
/// with `//`, the authority is skipped; it ends at the first `/`, `?`, `#` or
/// `\`.
///
/// The split is anchored to the first `:` rather than to the first `"://"`
/// anywhere in the string. Otherwise a later `://`, e.g. inside a query value,
/// could hide traversal that precedes it. `\` ends the authority because WHATWG
/// URL parsing treats it as a path separator for special schemes such as
/// `http(s)`, and stopping early only widens the region that is checked.
///
/// Percent-encoded `.`, `/` and `\` (`%2e`, `%2f`, `%5c`, any case) are decoded
/// before the segment checks. Mixed forms such as `.%2e/` and `%2e./` are
/// therefore caught alongside `%2e%2e`, `..%2f` and `..%5c`. Double encoding
/// (`%252e`) is not decoded.
fn has_path_traversal(url: &str) -> bool {
    let after_scheme = match url.find(':') {
        Some(idx) => &url[idx + 1..],
        None => url,
    };
    let path = match after_scheme.strip_prefix("//") {
        Some(authority_and_rest) => {
            let authority_end = authority_and_rest
                .find(['/', '?', '#', '\\'])
                .unwrap_or(authority_and_rest.len());
            &authority_and_rest[authority_end..]
        }
        None => after_scheme,
    };

    let lower = path.to_ascii_lowercase();
    // Any encoded `%2e%2e` run is rejected outright, even when it is not a whole
    // path segment (deliberately strict).
    if lower.contains("%2e%2e") {
        return true;
    }
    // Decoding only replaces `%xx` triplets, so literal `../`, `..\` and trailing
    // `..` sequences survive unchanged and are still detected below.
    let decoded = lower
        .replace("%2e", ".")
        .replace("%2f", "/")
        .replace("%5c", "\\");
    decoded.contains("../")
        || decoded.contains("..\\")
        || decoded == ".."
        || decoded.ends_with("/..")
        || decoded.ends_with("\\..")
}
/// Returns `true` when `scheme`, once lowercased, is listed in
/// [`DANGEROUS_SCHEMES`].
fn is_dangerous_scheme(scheme: &str) -> bool {
    let normalized = scheme.to_lowercase();
    DANGEROUS_SCHEMES.contains(&normalized.as_str())
}
/// Builds the security event that the validators' `validate_with_events`
/// methods return alongside a rejection.
///
/// The event always has kind `PlatformSafetyViolation`, severity `High` and
/// outcome `Blocked`. The specific [`PlatformRejection`] reason is not recorded
/// on it.
fn make_violation_event() -> SecurityEvent {
    let kind = EventKind::PlatformSafetyViolation;
    let severity = SecuritySeverity::High;
    let outcome = EventOutcome::Blocked;
    SecurityEvent::new(kind, severity, outcome)
}
/// A deep-link URL that was accepted by [`DeepLinkValidator::validate`].
///
/// The wrapped string is private, so code outside this module cannot build one
/// directly.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafeDeepLink(String);
impl SafeDeepLink {
    /// Borrows the validated URL.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        self.0.as_str()
    }
    /// Consumes the wrapper and returns the validated URL.
    #[must_use]
    pub fn into_inner(self) -> String {
        let SafeDeepLink(url) = self;
        url
    }
}
impl fmt::Display for SafeDeepLink {
    /// Writes the URL verbatim. Width/alignment flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_inner())
    }
}
/// Validates deep-link URLs against a scheme allowlist and, optionally, a host
/// allowlist.
#[derive(Clone, Debug)]
pub struct DeepLinkValidator {
    /// Accepted schemes, stored lowercased.
    allowed_schemes: Vec<String>,
    /// Accepted hosts, stored lowercased. `None` disables the host check.
    allowed_hosts: Option<Vec<String>>,
}
impl DeepLinkValidator {
    /// Creates a validator that accepts the given schemes (compared
    /// case-insensitively) and places no restriction on hosts.
    #[must_use]
    pub fn new(allowed_schemes: &[&str]) -> Self {
        let allowed_schemes = allowed_schemes
            .iter()
            .map(|scheme| scheme.to_lowercase())
            .collect::<Vec<_>>();
        DeepLinkValidator {
            allowed_schemes,
            allowed_hosts: None,
        }
    }
    /// Restricts accepted URLs to the given hosts (exact, case-insensitive
    /// match). Once set, URLs from which no host can be extracted are rejected.
    #[must_use]
    pub fn with_allowed_hosts(self, hosts: &[&str]) -> Self {
        let lowered = hosts.iter().map(|host| host.to_lowercase()).collect();
        DeepLinkValidator {
            allowed_hosts: Some(lowered),
            ..self
        }
    }
    /// Checks `url` against this validator's policy and wraps it on success.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    /// - [`PlatformRejection::MalformedUrl`] if no scheme can be extracted;
    /// - [`PlatformRejection::DangerousScheme`] if the scheme is listed in
    ///   `DANGEROUS_SCHEMES`, even when it was allowlisted;
    /// - [`PlatformRejection::InvalidScheme`] if the scheme is not allowlisted;
    /// - [`PlatformRejection::PathTraversal`] if a `..` traversal is detected;
    /// - [`PlatformRejection::UntrustedHost`] if a host allowlist is set and the
    ///   host is missing or not listed.
    pub fn validate(&self, url: &str) -> Result<SafeDeepLink, PlatformRejection> {
        let scheme = match extract_scheme(url) {
            Some(scheme) => scheme,
            None => return Err(PlatformRejection::MalformedUrl),
        };
        // The dangerous-scheme check deliberately precedes the allowlist so it
        // cannot be overridden by configuration.
        if is_dangerous_scheme(scheme) {
            return Err(PlatformRejection::DangerousScheme);
        }
        let lowered_scheme = scheme.to_lowercase();
        let scheme_allowed = self
            .allowed_schemes
            .iter()
            .any(|allowed| *allowed == lowered_scheme);
        if !scheme_allowed {
            return Err(PlatformRejection::InvalidScheme);
        }
        if has_path_traversal(url) {
            return Err(PlatformRejection::PathTraversal);
        }
        if let Some(hosts) = self.allowed_hosts.as_ref() {
            let host = match extract_host(url) {
                Some(host) => host.to_lowercase(),
                None => return Err(PlatformRejection::UntrustedHost),
            };
            if !hosts.iter().any(|allowed| *allowed == host) {
                return Err(PlatformRejection::UntrustedHost);
            }
        }
        Ok(SafeDeepLink(String::from(url)))
    }
    /// Runs [`validate`](Self::validate) and also returns the security events to
    /// record: one violation event on rejection, none on success.
    pub fn validate_with_events(
        &self,
        url: &str,
    ) -> (Result<SafeDeepLink, PlatformRejection>, Vec<SecurityEvent>) {
        let outcome = self.validate(url);
        let mut events = Vec::new();
        if outcome.is_err() {
            events.push(make_violation_event());
        }
        (outcome, events)
    }
}
/// Clipboard restrictions for data of a given [`DataClassification`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClipboardPolicy {
    /// Value reported by [`ClipboardPolicy::restrict_to_local_device`].
    local_only: bool,
    /// Value reported by [`ClipboardPolicy::expiration_seconds`]. `None` means no
    /// expiry is requested.
    expiration_secs: Option<u64>,
}
impl ClipboardPolicy {
    /// Maps a classification to its clipboard policy:
    ///
    /// | classification                     | local only | expiry |
    /// |------------------------------------|------------|--------|
    /// | `Public`, `Internal`               | no         | none   |
    /// | `Confidential`, `PII`, `Regulated` | yes        | none   |
    /// | `Secret`, `Credentials`            | yes        | 60 s   |
    /// | any other classification           | yes        | 60 s   |
    #[must_use]
    pub fn for_classification(class: DataClassification) -> Self {
        // Expiry applied to the most sensitive tier.
        const SHORT_EXPIRY_SECS: u64 = 60;
        let (local_only, expiration_secs) = match class {
            DataClassification::Public | DataClassification::Internal => (false, None),
            DataClassification::Confidential
            | DataClassification::PII
            | DataClassification::Regulated => (true, None),
            DataClassification::Secret | DataClassification::Credentials => {
                (true, Some(SHORT_EXPIRY_SECS))
            }
            // Any classification not listed above gets the strictest policy.
            _ => (true, Some(SHORT_EXPIRY_SECS)),
        };
        ClipboardPolicy {
            local_only,
            expiration_secs,
        }
    }
    /// Returns `true` when copied data should be kept on the local device.
    #[must_use]
    pub fn restrict_to_local_device(&self) -> bool {
        self.local_only
    }
    /// Returns the requested clipboard expiry in seconds, if any.
    #[must_use]
    pub fn expiration_seconds(&self) -> Option<u64> {
        self.expiration_secs
    }
}
/// A URL that was accepted by [`WebViewUrlValidator::validate`].
///
/// The wrapped string is private, so code outside this module cannot build one
/// directly.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafeWebViewUrl(String);
impl SafeWebViewUrl {
    /// Borrows the validated URL.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        self.0.as_str()
    }
    /// Consumes the wrapper and returns the validated URL.
    #[must_use]
    pub fn into_inner(self) -> String {
        let SafeWebViewUrl(url) = self;
        url
    }
}
impl fmt::Display for SafeWebViewUrl {
    /// Writes the URL verbatim. Width/alignment flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_inner())
    }
}
/// Validates URLs before they are loaded into a WebView.
///
/// Only `http` and `https` are accepted. An optional domain allowlist further
/// restricts which hosts may be loaded.
#[derive(Clone, Debug)]
pub struct WebViewUrlValidator {
    /// Accepted hosts, stored lowercased. `None` disables the host check.
    allowed_domains: Option<Vec<String>>,
}
impl WebViewUrlValidator {
    /// Creates a validator with no domain restriction.
    #[must_use]
    pub fn new() -> Self {
        WebViewUrlValidator {
            allowed_domains: None,
        }
    }
    /// Restricts loads to the given domains (exact, case-insensitive host match).
    /// Once set, URLs from which no host can be extracted are rejected.
    #[must_use]
    pub fn with_allowed_domains(mut self, domains: &[&str]) -> Self {
        let lowered: Vec<String> = domains
            .iter()
            .map(|domain| domain.to_lowercase())
            .collect();
        self.allowed_domains = Some(lowered);
        self
    }
    /// Checks `url` against the WebView policy and wraps it on success.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    /// - [`PlatformRejection::MalformedUrl`] if no scheme can be extracted;
    /// - [`PlatformRejection::FileAccessBlocked`] for the `file` scheme (any case);
    /// - [`PlatformRejection::DangerousScheme`] for a scheme in `DANGEROUS_SCHEMES`;
    /// - [`PlatformRejection::InvalidScheme`] for any scheme other than
    ///   `http`/`https`;
    /// - [`PlatformRejection::UntrustedHost`] if a domain allowlist is set and
    ///   the host is missing or not listed.
    pub fn validate(&self, url: &str) -> Result<SafeWebViewUrl, PlatformRejection> {
        let scheme = match extract_scheme(url) {
            Some(scheme) => scheme,
            None => return Err(PlatformRejection::MalformedUrl),
        };
        // Arm order mirrors the documented check order.
        match scheme.to_lowercase().as_str() {
            "file" => return Err(PlatformRejection::FileAccessBlocked),
            _ if is_dangerous_scheme(scheme) => {
                return Err(PlatformRejection::DangerousScheme)
            }
            "http" | "https" => {}
            _ => return Err(PlatformRejection::InvalidScheme),
        }
        if let Some(domains) = self.allowed_domains.as_ref() {
            let host = match extract_host(url) {
                Some(host) => host.to_lowercase(),
                None => return Err(PlatformRejection::UntrustedHost),
            };
            if !domains.iter().any(|domain| *domain == host) {
                return Err(PlatformRejection::UntrustedHost);
            }
        }
        Ok(SafeWebViewUrl(String::from(url)))
    }
    /// Runs [`validate`](Self::validate) and also returns the security events to
    /// record: one violation event on rejection, none on success.
    pub fn validate_with_events(
        &self,
        url: &str,
    ) -> (
        Result<SafeWebViewUrl, PlatformRejection>,
        Vec<SecurityEvent>,
    ) {
        let outcome = self.validate(url);
        let mut events = Vec::new();
        if outcome.is_err() {
            events.push(make_violation_event());
        }
        (outcome, events)
    }
}
impl Default for WebViewUrlValidator {
    /// Same as [`WebViewUrlValidator::new`]: no domain restriction.
    fn default() -> Self {
        WebViewUrlValidator::new()
    }
}
/// Whether screenshots should be prevented, optionally derived from a
/// [`DataClassification`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ScreenshotPolicy {
    /// Value reported by [`ScreenshotPolicy::should_prevent_screenshot`].
    prevent: bool,
}
impl ScreenshotPolicy {
    /// A policy that asks for screenshots to be prevented.
    #[must_use]
    pub fn prevent() -> Self {
        ScreenshotPolicy { prevent: true }
    }
    /// A policy that allows screenshots.
    #[must_use]
    pub fn allow() -> Self {
        ScreenshotPolicy { prevent: false }
    }
    /// Allows screenshots only for `Public` and `Internal` data. Every other
    /// classification prevents them.
    #[must_use]
    pub fn for_classification(class: DataClassification) -> Self {
        let low_sensitivity = matches!(
            class,
            DataClassification::Public | DataClassification::Internal
        );
        if low_sensitivity {
            Self::allow()
        } else {
            Self::prevent()
        }
    }
    /// Returns `true` when screenshots should be prevented.
    #[must_use]
    pub fn should_prevent_screenshot(&self) -> bool {
        self.prevent
    }
}