use std::net::IpAddr;
use std::str::FromStr;
use std::time::Duration;
use crate::abi::{self, FastlyStatus};
use crate::http::body::handle::BodyHandle;
use crate::http::request::handle::RequestHandle;
use crate::http::request::Request;
use crate::Response;
use lazy_static::lazy_static;
use serde::Deserialize;
lazy_static! {
    /// Shared placeholder body handle, created once on first use.
    ///
    /// `InspectConfig::from_request` falls back to this handle when the request has no
    /// body handle of its own, so the host call always receives some body handle.
    // NOTE(review): what `BodyHandle::new()` allocates on the host is not visible here — confirm.
    static ref NULL_BODY_HANDLE: BodyHandle = BodyHandle::new();
}
/// Errors that can occur while running a security inspection with `inspect`.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum InspectError {
    /// The host's output was valid UTF-8 but could not be parsed as an inspection result.
    #[error("Failed to deserialize security response: {0}")]
    DeserializeError(serde_json::Error),
    /// The host's output was not valid UTF-8; carries the decoding error and the raw bytes.
    #[error("Response from Security was not valid UTF-8")]
    InvalidBytes(std::str::Utf8Error, Vec<u8>),
    /// The host rejected the call with `FastlyStatus::INVAL`.
    #[error("Invalid Argument in Configuration")]
    InvalidConfig,
    /// The host reported success but wrote zero bytes of output.
    #[error("Inspection request completed without any verdict")]
    NoVerdict,
    /// The host returned a status other than `OK`, `INVAL`, or `BUFLEN`.
    #[error("Failed to send request to security module: {0:?}")]
    RequestError(FastlyStatus),
    /// The host returned `FastlyStatus::BUFLEN`; carries the size in bytes of the
    /// buffer that was supplied.
    #[error("Buffer ({0} bytes) was not large enough to fit response")]
    BufferSizeError(usize),
}
/// Verdict reached by the security module for an inspected request.
///
/// Parsed from the `verdict` string of the inspection result. Matching is exact and
/// case-sensitive, and unrecognised strings are preserved in [`InspectVerdict::Other`].
///
/// Derives the common value-type traits so callers can compare, clone, and hash
/// verdicts directly instead of only pattern-matching on them.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum InspectVerdict {
    /// The verdict string was `"allow"`.
    Allow,
    /// The verdict string was `"block"`.
    Block,
    /// The verdict string was `"unauthorized"`.
    Unauthorized,
    /// Any other verdict string, kept verbatim.
    Other(String),
}
impl FromStr for InspectVerdict {
type Err = std::convert::Infallible;
fn from_str(input: &str) -> Result<Self, Self::Err> {
let verdict = match input {
"allow" => InspectVerdict::Allow,
"block" => InspectVerdict::Block,
"unauthorized" => InspectVerdict::Unauthorized,
_ => InspectVerdict::Other(input.to_string()),
};
Ok(verdict)
}
}
impl<'de> Deserialize<'de> for InspectVerdict {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Self::from_str(&s).map_err(serde::de::Error::custom)
}
}
/// Result of a security inspection, deserialized from the JSON the host writes during `inspect`.
#[derive(Deserialize, Debug)]
pub struct InspectResponse {
    // Status code reported by the security module. It is also accepted under the
    // legacy `waf_response` key.
    #[serde(alias = "waf_response")]
    response: i16,
    // Redirect target as reported; may be empty. It is only exposed through the
    // `redirect_url()` accessor for `Block` verdicts.
    redirect_url: String,
    // Tags reported with the result.
    tags: Vec<String>,
    // Parsed verdict string.
    verdict: InspectVerdict,
    // Raw `decision_ms` value; the accessor interprets it as milliseconds.
    decision_ms: u64,
}
impl InspectResponse {
    /// Deprecated alias for [`InspectResponse::status`].
    #[deprecated(note = "Please use InspectResponse::status instead")]
    #[doc(hidden)]
    pub fn waf_response(&self) -> i16 {
        self.response
    }

    /// Status code reported by the security module.
    ///
    /// This is the `response` key of the inspection result, or the legacy `waf_response` key.
    pub fn status(&self) -> i16 {
        self.response
    }

    /// Redirect target reported by the security module.
    ///
    /// Returns `Some` only when the verdict is [`InspectVerdict::Block`] and the reported
    /// URL is non-empty.
    pub fn redirect_url(&self) -> Option<&str> {
        if matches!(self.verdict, InspectVerdict::Block) && !self.redirect_url.is_empty() {
            Some(self.redirect_url.as_str())
        } else {
            None
        }
    }

    /// Tags reported with the inspection result, borrowed from `self`.
    pub fn tags(&self) -> Vec<&str> {
        self.tags.iter().map(String::as_str).collect()
    }

    /// Verdict reached for the inspected request.
    pub fn verdict(&self) -> &InspectVerdict {
        &self.verdict
    }

    /// The reported `decision_ms` value as a [`Duration`], interpreted as milliseconds.
    pub fn decision_ms(&self) -> Duration {
        Duration::from_millis(self.decision_ms)
    }

    /// Whether the result describes a redirect.
    ///
    /// This requires the status to be in `300..=309` and
    /// [`InspectResponse::redirect_url`] to return `Some`.
    pub fn is_redirect(&self) -> bool {
        (300..=309).contains(&self.response) && self.redirect_url().is_some()
    }

    /// Converts the result into a response with the reported status and a `Location`
    /// header set to the redirect URL.
    ///
    /// Returns `None` in either of these cases:
    /// - [`InspectResponse::redirect_url`] returns `None`.
    /// - The reported status is not a three-digit HTTP status code (`100..=999`).
    ///
    /// This does not require [`InspectResponse::is_redirect`] to be true.
    pub fn into_redirect(self) -> Option<Response> {
        // `response` is an `i16`. Casting a negative or out-of-range value with `as u16`
        // would wrap (e.g. -1 -> 65535) and produce an invalid status code, so reject it.
        if !(100..=999).contains(&self.response) {
            return None;
        }
        let status = self.response as u16; // lossless: checked to be in 100..=999
        self.redirect_url()
            .map(|url| Response::from_status(status).with_header(http::header::LOCATION, url))
    }
}
/// Configuration for a call to `inspect`, built with the builder-style methods on this type.
///
/// It borrows the request and body handles for `'a`. Optional extra information
/// (corp, workspace, client IP override) is passed to the host only when set.
pub struct InspectConfig<'a> {
    // Corp name; sent to the host with `InspectInfoMask::CORP` when `Some`.
    corp: Option<String>,
    // Workspace name; sent to the host with `InspectInfoMask::WORKSPACE` when `Some`.
    workspace: Option<String>,
    // Handle of the request to inspect.
    req_handle: &'a RequestHandle,
    // Handle of the request body to inspect.
    body_handle: &'a BodyHandle,
    // Size in bytes of the buffer allocated to receive the host's JSON output.
    buffer_size: usize,
    // When `Some`, its octets are sent with `InspectInfoMask::OVERRIDE_CLIENT_IP`.
    // NOTE(review): presumably replaces the client IP the host would otherwise use — confirm.
    override_client_ip: Option<IpAddr>,
}
impl<'a> InspectConfig<'a> {
#[deprecated(
since = "0.11.7",
note = "Use InspectConfig::from_handles or InspectConfig::from_request instead"
)]
pub fn new(req_handle: &'a RequestHandle, body_handle: &'a BodyHandle) -> Self {
Self::from_handles(req_handle, body_handle)
}
pub fn from_handles(req_handle: &'a RequestHandle, body_handle: &'a BodyHandle) -> Self {
Self {
corp: None,
workspace: None,
req_handle,
body_handle,
buffer_size: 16 * 1024, override_client_ip: None,
}
}
pub fn from_request(req: &'a Request) -> Self {
let (rh, bh) = req.get_handles();
let bh = bh.unwrap_or(&NULL_BODY_HANDLE);
Self::from_handles(rh, bh)
}
pub fn client_ip(mut self, ip: IpAddr) -> Self {
self.override_client_ip = Some(ip);
self
}
pub fn corp(mut self, name: impl ToString) -> Self {
self.corp = name.to_string().into();
self
}
pub fn workspace(mut self, name: impl ToString) -> Self {
self.workspace = name.to_string().into();
self
}
pub fn buffer_size(self, buffer_size: usize) -> Self {
Self {
buffer_size,
..self
}
}
}
/// Inspects a request with the security module through the host `inspect` call.
///
/// The request and body handles from `config` are passed to the host together with
/// any corp, workspace, or client-IP override that was set. The host writes a JSON
/// result into a buffer of `config.buffer_size` bytes, which is then parsed into an
/// [`InspectResponse`].
///
/// # Errors
///
/// - [`InspectError::NoVerdict`] if the host succeeded but wrote no bytes.
/// - [`InspectError::InvalidConfig`] if the host returned `FastlyStatus::INVAL`.
/// - [`InspectError::BufferSizeError`] if the host returned `FastlyStatus::BUFLEN`.
///   Retry with a larger buffer (see `InspectConfig::buffer_size`).
/// - [`InspectError::RequestError`] for any other non-`OK` host status.
/// - [`InspectError::InvalidBytes`] if the output is not valid UTF-8.
/// - [`InspectError::DeserializeError`] if the output is not a valid inspection result.
pub fn inspect(config: InspectConfig) -> Result<InspectResponse, InspectError> {
    use abi::fastly_http_req::{InspectInfo, InspectInfoMask};

    let mut add_info_mask = InspectInfoMask::empty();
    let mut add_info = InspectInfo::default();

    // `add_info` stores raw pointers into `config`'s strings. `config` is owned by this
    // function, so those strings stay alive until after the host call below.
    if let Some(corp) = config.corp.as_deref() {
        add_info.corp = corp.as_ptr();
        add_info.corp_len = corp.len() as u32;
        add_info_mask.insert(InspectInfoMask::CORP);
    }
    if let Some(workspace) = config.workspace.as_deref() {
        add_info.workspace = workspace.as_ptr();
        add_info.workspace_len = workspace.len() as u32;
        add_info_mask.insert(InspectInfoMask::WORKSPACE);
    }

    // These are declared outside the `if` so the octet arrays outlive the host call,
    // which reads them through `add_info.override_client_ip_ptr`.
    let ipv4_bytes;
    let ipv6_bytes;
    if let Some(ip) = config.override_client_ip {
        let bytes = match ip {
            IpAddr::V4(ip) => {
                ipv4_bytes = ip.octets();
                &ipv4_bytes[..]
            }
            IpAddr::V6(ip) => {
                ipv6_bytes = ip.octets();
                &ipv6_bytes[..]
            }
        };
        add_info.override_client_ip_ptr = bytes.as_ptr();
        add_info.override_client_ip_len = bytes.len() as u32;
        add_info_mask.insert(InspectInfoMask::OVERRIDE_CLIENT_IP);
    }

    // The buffer is zero-initialised, so its whole `len()` is valid, initialised memory.
    // The host is handed exactly that length, and the result is shortened with the safe
    // `truncate` instead of `set_len`.
    let mut buf = vec![0u8; config.buffer_size];
    let mut nwritten = 0;
    // SAFETY: every pointer in `add_info` refers to data (`config`'s strings or the
    // local octet arrays above) that lives until the end of this function. `buf` is
    // valid for writes of `buf.len()` bytes, and `nwritten` is a live local.
    let status = unsafe {
        abi::fastly_http_req::inspect(
            config.req_handle.as_u32(),
            config.body_handle.as_u32(),
            add_info_mask,
            &add_info,
            buf.as_mut_ptr(),
            buf.len(),
            &mut nwritten,
        )
    };

    match status {
        FastlyStatus::OK => {
            if nwritten == 0 {
                return Err(InspectError::NoVerdict);
            }
            // `truncate` never grows the buffer. An out-of-range `nwritten` therefore
            // cannot expose memory the host did not receive.
            buf.truncate(nwritten);
        }
        FastlyStatus::INVAL => return Err(InspectError::InvalidConfig),
        FastlyStatus::BUFLEN => return Err(InspectError::BufferSizeError(buf.len())),
        status => return Err(InspectError::RequestError(status)),
    }

    let s = match std::str::from_utf8(&buf) {
        Ok(s) => s,
        Err(e) => return Err(InspectError::InvalidBytes(e, buf)),
    };
    serde_json::from_str(s).map_err(InspectError::DeserializeError)
}