use crate::{
attack_signal::{BoundaryViolation, ViolationKind},
error::BoundaryRejection,
};
use serde::{Deserialize, Deserializer};
use std::fmt;
/// Constructs a [`BoundaryViolation`] for `kind` / `code` and immediately
/// calls `emit()` on it.
///
/// Used by the `TryFrom` validators in this file each time they reject an
/// input (and by `LdapSafeString` when it had to escape characters). What
/// `emit()` actually does is defined in `attack_signal`, not here.
fn emit_violation(kind: ViolationKind, code: &'static str) {
    BoundaryViolation::new(kind, code).emit();
}
/// A relative filesystem path.
///
/// Values built through `TryFrom<&str>` or serde deserialization have passed
/// the traversal checks in `SafePath::try_from`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafePath(String);

impl SafePath {
    /// Borrows the validated path as a string slice.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        self.0.as_str()
    }

    /// Consumes the wrapper and returns the owned path string.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(path) = self;
        path
    }
}
impl TryFrom<&str> for SafePath {
    type Error = BoundaryRejection;

    /// Validates `s` as a relative path that cannot climb out of the directory
    /// it is later joined onto.
    ///
    /// # Errors
    ///
    /// Emits a `path_traversal` violation and returns
    /// [`BoundaryRejection::PathTraversal`] when `s`:
    /// - contains a NUL byte;
    /// - is rooted (`/…`, `\…`) or starts with a Windows drive prefix (`C:…`),
    ///   either of which replaces the base directory when joined;
    /// - contains a `..` segment with either separator;
    /// - contains (ASCII case-insensitively) a percent-encoded separator
    ///   (`%2f`, `%5c`), a `..` with at least one encoded dot (`%2e%2e`,
    ///   `.%2e`, `%2e.`), or a double-encoded dot or separator (`%252e`,
    ///   `%252f`, `%255c`), in case the value is URL-decoded downstream.
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        // Encoded fragments that decode (once or twice) into a separator or
        // into part of a `..` segment.
        const ENCODED_MARKERS: [&str; 8] = [
            "%2e%2e", ".%2e", "%2e.", "%2f", "%5c", "%252e", "%252f", "%255c",
        ];

        let reject = || {
            emit_violation(ViolationKind::SyntaxViolation, "path_traversal");
            BoundaryRejection::PathTraversal
        };

        if s.contains('\0') {
            return Err(reject());
        }

        // A leading separator or a `X:` drive prefix makes the path absolute
        // (or drive-relative) on at least one platform.
        let mut chars = s.chars();
        let has_drive_prefix = matches!(
            (chars.next(), chars.next()),
            (Some(letter), Some(':')) if letter.is_ascii_alphabetic()
        );
        if s.starts_with('/') || s.starts_with('\\') || has_drive_prefix {
            return Err(reject());
        }

        // `..` as a whole segment at the start, in the middle, or at the end.
        // Deliberately conservative: names merely ending in `..` (`a../b`)
        // are rejected as well.
        if s == ".."
            || s.contains("../")
            || s.contains("..\\")
            || s.ends_with("/..")
            || s.ends_with("\\..")
        {
            return Err(reject());
        }

        let lower = s.to_ascii_lowercase();
        if ENCODED_MARKERS
            .iter()
            .any(|marker| lower.contains(*marker))
        {
            return Err(reject());
        }

        Ok(Self(s.to_owned()))
    }
}
impl<'de> Deserialize<'de> for SafePath {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
SafePath::try_from(s.as_str()).map_err(serde::de::Error::custom)
}
}
impl fmt::Display for SafePath {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.0)
}
}
/// A single path component usable as a filename.
///
/// Values built through `TryFrom<&str>` or serde deserialization have passed
/// the checks in `SafeFilename::try_from`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafeFilename(String);

impl SafeFilename {
    /// Borrows the validated filename.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        let Self(name) = self;
        name
    }

    /// Consumes the wrapper and returns the owned filename.
    #[must_use]
    pub fn into_inner(self) -> String {
        match self {
            Self(name) => name,
        }
    }
}
impl TryFrom<&str> for SafeFilename {
    type Error = BoundaryRejection;

    /// Validates `s` as a single path component safe to use as a filename.
    ///
    /// # Errors
    ///
    /// Emits an `invalid_filename` violation and returns
    /// [`BoundaryRejection::InjectionAttempt`] with code `"invalid_filename"`
    /// when `s`:
    /// - is empty, or is one of the special directory entries `.` / `..`;
    /// - contains a NUL byte or a path separator (`/`, `\`);
    /// - contains a shell metacharacter (`;`, `|`, `&`, `` ` ``, `$`, `>`,
    ///   `<`) or a line break (`\n`, `\r`).
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        let reject = || {
            emit_violation(ViolationKind::SyntaxViolation, "invalid_filename");
            BoundaryRejection::InjectionAttempt {
                code: "invalid_filename",
            }
        };

        // `.` and `..` name directories rather than files; `..` also escapes
        // the parent directory.
        if s.is_empty() || s == "." || s == ".." {
            return Err(reject());
        }

        // A separator would turn the "filename" into a path (which also covers
        // any `../` prefix); NUL is never valid inside an OS file name.
        if s.chars().any(|c| matches!(c, '\0' | '/' | '\\')) {
            return Err(reject());
        }

        // Same metacharacter set `SafeCommandArg` rejects: newline and carriage
        // return separate commands in a shell just like `;` does.
        if s.chars()
            .any(|c| matches!(c, ';' | '|' | '&' | '`' | '$' | '>' | '<' | '\n' | '\r'))
        {
            return Err(reject());
        }

        Ok(Self(s.to_owned()))
    }
}
impl<'de> Deserialize<'de> for SafeFilename {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
SafeFilename::try_from(s.as_str()).map_err(serde::de::Error::custom)
}
}
impl fmt::Display for SafeFilename {
    /// Writes the filename verbatim via `write_str` (width/fill flags are not
    /// applied).
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str(self.0.as_str())
    }
}
/// A single command-line argument free of shell metacharacters.
///
/// Values built through `TryFrom<&str>` or serde deserialization have passed
/// the checks in `SafeCommandArg::try_from`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafeCommandArg(String);

impl SafeCommandArg {
    /// Borrows the validated argument.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        &self.0[..]
    }

    /// Consumes the wrapper and returns the owned argument string.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(arg) = self;
        arg
    }
}
impl TryFrom<&str> for SafeCommandArg {
    type Error = BoundaryRejection;

    /// Accepts `s` unless it contains a character a shell treats as a command
    /// separator, pipe, redirection, or substitution trigger.
    ///
    /// # Errors
    ///
    /// Emits a `command_injection` violation and returns
    /// [`BoundaryRejection::InjectionAttempt`] with code `"command_injection"`
    /// if `s` contains any of `;`, `|`, `&`, `` ` ``, `$`, `>`, `<`, `\n`, `\r`.
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        const SHELL_METACHARACTERS: [char; 9] =
            [';', '|', '&', '`', '$', '>', '<', '\n', '\r'];

        if s.contains(&SHELL_METACHARACTERS[..]) {
            emit_violation(ViolationKind::SyntaxViolation, "command_injection");
            return Err(BoundaryRejection::InjectionAttempt {
                code: "command_injection",
            });
        }

        Ok(Self(s.to_owned()))
    }
}
impl<'de> Deserialize<'de> for SafeCommandArg {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
SafeCommandArg::try_from(s.as_str()).map_err(serde::de::Error::custom)
}
}
impl fmt::Display for SafeCommandArg {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.0)
}
}
/// An absolute `http`/`https` URL.
///
/// Values built through `TryFrom<&str>` or serde deserialization have had
/// their literal host checked by `SafeUrl::try_from`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafeUrl(String);

impl SafeUrl {
    /// Borrows the validated URL text.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        self.0.as_str()
    }

    /// Consumes the wrapper and returns the owned URL string.
    #[must_use]
    pub fn into_inner(self) -> String {
        match self {
            Self(url) => url,
        }
    }
}
impl TryFrom<&str> for SafeUrl {
    type Error = BoundaryRejection;

    /// Validates `s` as an absolute `http`/`https` URL whose host is not an
    /// internal destination (as judged by `is_private_ip`).
    ///
    /// The host is located the way WHATWG-style URL parsers locate it, so the
    /// host checked here and the host a client actually contacts cannot
    /// diverge: `\` ends the authority just like `/`, and any `userinfo@`
    /// prefix is skipped. Only the literal host text is inspected; DNS names
    /// are not resolved, so a public name that resolves to an internal address
    /// is NOT detected here.
    ///
    /// # Errors
    ///
    /// Emits an `ssrf_attempt` violation and returns
    /// [`BoundaryRejection::SsrfAttempt`] when:
    /// - the scheme is not `http://` or `https://` (ASCII case-insensitive);
    /// - `s` contains an ASCII control character (WHATWG parsers silently
    ///   drop tabs and newlines, which could re-assemble a blocked host);
    /// - the userinfo itself contains `@`, which parsers split differently;
    /// - the host is empty or contains `%` (clients percent-decode hosts);
    /// - `is_private_ip` flags the host.
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        // Returns what follows `scheme` in `url`, comparing ASCII
        // case-insensitively and never slicing inside a multi-byte char.
        fn strip_scheme<'a>(url: &'a str, scheme: &str) -> Option<&'a str> {
            let head = url.get(..scheme.len())?;
            if head.eq_ignore_ascii_case(scheme) {
                url.get(scheme.len()..)
            } else {
                None
            }
        }

        let reject = || {
            emit_violation(ViolationKind::SyntaxViolation, "ssrf_attempt");
            BoundaryRejection::SsrfAttempt
        };

        if s.chars().any(|c| c.is_ascii_control()) {
            return Err(reject());
        }

        let rest = match strip_scheme(s, "https://").or_else(|| strip_scheme(s, "http://")) {
            Some(rest) => rest,
            None => return Err(reject()),
        };

        // The authority ends at the first path, query, or fragment delimiter.
        // For http(s), WHATWG parsers treat `\` as a path separator too.
        let authority_end = rest.find(['/', '?', '#', '\\']).unwrap_or(rest.len());
        let authority = &rest[..authority_end];

        // Parsers take the host after the LAST `@`; an `@` left inside the
        // userinfo is ambiguous across parser implementations, so refuse it.
        let host_with_port = match authority.rsplit_once('@') {
            Some((userinfo, _)) if userinfo.contains('@') => return Err(reject()),
            Some((_, host_with_port)) => host_with_port,
            None => authority,
        };

        let host = if host_with_port.starts_with('[') {
            // Bracketed IPv6 literal: keep everything up to and including `]`.
            let bracket_end = host_with_port
                .find(']')
                .map_or(host_with_port.len(), |i| i + 1);
            &host_with_port[..bracket_end]
        } else {
            // Drop a `:port` suffix (all digits after the last colon).
            match host_with_port.rfind(':') {
                Some(pos)
                    if host_with_port[pos + 1..]
                        .chars()
                        .all(|c| c.is_ascii_digit()) =>
                {
                    &host_with_port[..pos]
                }
                _ => host_with_port,
            }
        };

        let bare_host = host.trim_start_matches('[').trim_end_matches(']');
        if bare_host.is_empty() || host.contains('%') || is_private_ip(host) {
            return Err(reject());
        }

        Ok(Self(s.to_owned()))
    }
}
impl<'de> Deserialize<'de> for SafeUrl {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
SafeUrl::try_from(s.as_str()).map_err(serde::de::Error::custom)
}
}
impl fmt::Display for SafeUrl {
    /// Writes the URL verbatim via `write_str` (width/fill flags are not
    /// applied).
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str(&self.0[..])
    }
}
/// Reports whether `host` (as extracted from a URL authority) names a
/// loopback, private, or otherwise internal destination.
///
/// Surrounding brackets (`[::1]`) and one trailing root dot (`127.0.0.1.`)
/// are ignored. Recognised forms:
/// - `localhost` and any `*.localhost` name (ASCII case-insensitive);
/// - IPv4 in every notation `inet_aton`/WHATWG-style parsers accept: dotted
///   quad, fewer than four parts (`127.1`), a single integer (`2130706433`),
///   and octal/hex parts (`0177.0.0.1`, `0x7f.0.0.1`). Non-canonical
///   spellings therefore cannot smuggle a blocked address past the check;
/// - IPv6 (a `%zone` suffix is ignored), including IPv4-mapped and
///   IPv4-compatible addresses, which are also checked against the embedded
///   IPv4 address.
///
/// Any other hostname returns `false`; names are not resolved here.
fn is_private_ip(host: &str) -> bool {
    let host = host.trim_matches(|c| c == '[' || c == ']');
    let host = host.strip_suffix('.').unwrap_or(host);

    let lower = host.to_ascii_lowercase();
    if lower == "localhost" || lower.ends_with(".localhost") {
        return true;
    }

    if let Some(addr) = parse_ipv4_lenient(host) {
        return is_private_ipv4(addr);
    }

    let without_zone = host.split('%').next().unwrap_or(host);
    if let Ok(addr) = without_zone.parse::<std::net::Ipv6Addr>() {
        if is_private_ipv6(addr) {
            return true;
        }
        // `::ffff:a.b.c.d` and `::a.b.c.d` carry an IPv4 destination.
        if let Some(v4) = addr.to_ipv4() {
            return is_private_ipv4(v4);
        }
    }

    false
}

/// Parses `host` as an IPv4 literal using the permissive `inet_aton` / WHATWG
/// rules: one to four dot-separated parts, each decimal, octal (leading `0`)
/// or hex (`0x` prefix), with the last part filling all remaining low-order
/// bytes. Returns `None` if `host` is not such a literal.
fn parse_ipv4_lenient(host: &str) -> Option<std::net::Ipv4Addr> {
    let parts = host
        .split('.')
        .map(parse_ipv4_part)
        .collect::<Option<Vec<u64>>>()?;
    if parts.len() > 4 {
        return None;
    }
    let (&last, leading) = parts.split_last()?;

    // Every part but the last must fit in one byte; the last part must fit
    // in the bytes the leading parts left over.
    if leading.iter().any(|&part| part > 0xff) {
        return None;
    }
    if last >> (8 * (4 - leading.len())) != 0 {
        return None;
    }

    let value = leading
        .iter()
        .enumerate()
        .fold(last, |acc, (i, &part)| acc | (part << (8 * (3 - i))));
    u32::try_from(value).ok().map(std::net::Ipv4Addr::from)
}

/// Parses one part of a lenient IPv4 literal: decimal, `0`-prefixed octal,
/// or `0x`-prefixed hex. An empty hex body (`0x`) counts as zero, as in the
/// WHATWG URL standard.
fn parse_ipv4_part(part: &str) -> Option<u64> {
    let (digits, radix) =
        if let Some(hex) = part.strip_prefix("0x").or_else(|| part.strip_prefix("0X")) {
            if hex.is_empty() {
                return Some(0);
            }
            (hex, 16)
        } else if part.len() > 1 && part.starts_with('0') {
            (&part[1..], 8)
        } else {
            (part, 10)
        };

    // `from_str_radix` tolerates a leading `+`; a host part must be digits only.
    if digits.is_empty() || !digits.chars().all(|c| c.is_digit(radix)) {
        return None;
    }
    u64::from_str_radix(digits, radix).ok()
}
/// Reports whether `addr` lies in an IPv4 range that must not be reachable
/// from user-supplied URLs.
///
/// Blocked ranges:
/// - `0.0.0.0/8`: "this network" (`0.0.0.0` is commonly routed to the local host)
/// - `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`: RFC 1918 private
/// - `100.64.0.0/10`: shared address space (carrier-grade NAT), not publicly routable
/// - `127.0.0.0/8`: loopback
/// - `169.254.0.0/16`: link-local, which includes the `169.254.169.254` metadata address
/// - `224.0.0.0/4`: multicast
/// - `240.0.0.0/4`: reserved, including the limited broadcast `255.255.255.255`
fn is_private_ipv4(addr: std::net::Ipv4Addr) -> bool {
    let [a, b, _, _] = addr.octets();
    a == 0
        || a == 10
        || a == 127
        || (a == 100 && (64..=127).contains(&b))
        || (a == 169 && b == 254)
        || (a == 172 && (16..=31).contains(&b))
        || (a == 192 && b == 168)
        || a >= 224
}
/// Reports whether `addr` is an IPv6 address that must not be reachable from
/// user-supplied URLs: loopback `::1`, unspecified `::`, unique-local
/// `fc00::/7`, link-local `fe80::/10`, or multicast `ff00::/8`.
fn is_private_ipv6(addr: std::net::Ipv6Addr) -> bool {
    let head = addr.segments()[0];
    let unique_local = (head & 0xfe00) == 0xfc00;
    let link_local = (head & 0xffc0) == 0xfe80;
    let multicast = (head & 0xff00) == 0xff00;
    addr.is_loopback() || addr.is_unspecified() || unique_local || link_local || multicast
}
/// A local redirect target (an absolute path on the current host).
///
/// Values built through `TryFrom<&str>` or serde deserialization have passed
/// the checks in `SafeRedirectUrl::try_from`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafeRedirectUrl(String);

impl SafeRedirectUrl {
    /// Borrows the validated redirect target.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        let Self(target) = self;
        target
    }

    /// Consumes the wrapper and returns the owned redirect target.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(target) = self;
        target
    }
}
impl TryFrom<&str> for SafeRedirectUrl {
    type Error = BoundaryRejection;

    /// Validates `s` as a same-origin redirect target: an absolute path on
    /// the current host.
    ///
    /// # Errors
    ///
    /// Emits an `invalid_redirect` violation and returns
    /// [`BoundaryRejection::InjectionAttempt`] with code `"invalid_redirect"`
    /// when `s`:
    /// - does not start with `/`, or starts with `//` (a protocol-relative
    ///   URL pointing at another host);
    /// - contains `\`, which browsers normalise to `/` for http(s) URLs, so
    ///   `/\evil.example` would become `//evil.example`;
    /// - contains an ASCII control character, which browsers strip from URLs,
    ///   so `/<TAB>/evil.example` would also collapse to `//evil.example`;
    /// - contains `:`, which blocks scheme-bearing values such as `javascript:`.
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        let reject = || {
            emit_violation(ViolationKind::SyntaxViolation, "invalid_redirect");
            BoundaryRejection::InjectionAttempt {
                code: "invalid_redirect",
            }
        };

        if !s.starts_with('/') || s.starts_with("//") {
            return Err(reject());
        }

        if s.chars()
            .any(|c| c == ':' || c == '\\' || c.is_ascii_control())
        {
            return Err(reject());
        }

        Ok(Self(s.to_owned()))
    }
}
impl<'de> Deserialize<'de> for SafeRedirectUrl {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
SafeRedirectUrl::try_from(s.as_str()).map_err(serde::de::Error::custom)
}
}
impl fmt::Display for SafeRedirectUrl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.0)
}
}
/// A bare SQL identifier (`[A-Za-z_][A-Za-z0-9_]*`, at most 128 bytes).
///
/// Values built through `TryFrom<&str>` or serde deserialization have passed
/// the checks in `SqlIdentifier::try_from`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SqlIdentifier(String);

impl SqlIdentifier {
    /// Borrows the validated identifier.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        &self.0[..]
    }

    /// Consumes the wrapper and returns the owned identifier.
    #[must_use]
    pub fn into_inner(self) -> String {
        match self {
            Self(ident) => ident,
        }
    }
}
impl TryFrom<&str> for SqlIdentifier {
    type Error = BoundaryRejection;

    /// Accepts `s` only if it is 1 to 128 bytes long, starts with an ASCII
    /// letter or `_`, and continues with ASCII letters, digits, or `_`.
    ///
    /// # Errors
    ///
    /// Emits an `invalid_sql_identifier` violation and returns
    /// [`BoundaryRejection::InjectionAttempt`] with code
    /// `"invalid_sql_identifier"` for any input outside that shape.
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        let is_word_char = |c: char| c.is_ascii_alphanumeric() || c == '_';
        let mut chars = s.chars();

        let well_formed = s.len() <= 128
            && matches!(chars.next(), Some(first) if first == '_' || first.is_ascii_alphabetic())
            && chars.all(is_word_char);

        if well_formed {
            Ok(Self(s.to_owned()))
        } else {
            emit_violation(ViolationKind::SyntaxViolation, "invalid_sql_identifier");
            Err(BoundaryRejection::InjectionAttempt {
                code: "invalid_sql_identifier",
            })
        }
    }
}
impl<'de> Deserialize<'de> for SqlIdentifier {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
SqlIdentifier::try_from(s.as_str()).map_err(serde::de::Error::custom)
}
}
impl fmt::Display for SqlIdentifier {
    /// Writes the identifier verbatim (unquoted) via `write_str`; width/fill
    /// flags are not applied.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str(self.0.as_str())
    }
}
/// A string whose LDAP filter metacharacters have been hex-escaped.
///
/// The stored value is the ESCAPED form produced by
/// `LdapSafeString::try_from`, not the original input.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LdapSafeString(String);

impl LdapSafeString {
    /// Borrows the escaped string.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        self.0.as_str()
    }

    /// Consumes the wrapper and returns the owned escaped string.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(escaped) = self;
        escaped
    }
}
impl TryFrom<&str> for LdapSafeString {
    type Error = BoundaryRejection;

    /// Escapes `s` for use inside an LDAP filter value.
    ///
    /// NUL, `*`, `(`, `)` and `\` are replaced by `\00`, `\2a`, `\28`, `\29`
    /// and `\5c` respectively; every other character is copied unchanged. If
    /// at least one replacement was made, an `ldap_injection_chars` violation
    /// is emitted. This conversion never returns `Err`.
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        let mut out = String::with_capacity(s.len() * 2);
        let mut escaped_any = false;

        for ch in s.chars() {
            let replacement = match ch {
                '\0' => Some("\\00"),
                '*' => Some("\\2a"),
                '(' => Some("\\28"),
                ')' => Some("\\29"),
                '\\' => Some("\\5c"),
                _ => None,
            };
            if let Some(sequence) = replacement {
                out.push_str(sequence);
                escaped_any = true;
            } else {
                out.push(ch);
            }
        }

        if escaped_any {
            emit_violation(ViolationKind::SyntaxViolation, "ldap_injection_chars");
        }
        Ok(Self(out))
    }
}
impl<'de> Deserialize<'de> for LdapSafeString {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
LdapSafeString::try_from(s.as_str()).map_err(serde::de::Error::custom)
}
}
impl fmt::Display for LdapSafeString {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.0)
}
}