use crate::http_token_utils::{http_token, HttpToken};
use crate::request_analyzer::{HttpRequestData, RequestAnalysisResult};
use crate::ultralight_rate_limiter::UltraLightRateLimiter;
use crate::{
ClassificationReason, ExtClassificationMetricsSettings, ExtLoggingSettings,
ExtRequestAnalysisMetricsUnit, ExtString, ExtTierMetricsSettings, RequestSafetyTier,
MESSAGE_MAX_SIZE,
};
use core::hash::Hash;
use lazy_static::lazy_static;
use smallvec::SmallVec;
use std::collections::HashMap;
use std::sync::atomic::AtomicPtr;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Instant;
/// Rate handed to the logging rate limiter (written originally as `100 / 1`, i.e. 100 per unit
/// of the limiter's rate; presumably messages per second — confirm against `UltraLightRateLimiter`).
const LOG_MESSAGE_RATE_LIMIT: f64 = 100.0;
/// Inline capacity of the `SmallVec`s used for metrics snapshots.
const METRICS_VECTOR_SIZE: usize = 128;
/// HTTP methods that get their own counters. The final entry, `"OTHER"`, is the catch-all
/// bucket for any method not listed here and must remain last.
pub const MONITORED_METHODS: &[&str] = &[
    "GET",
    "PUT",
    "POST",
    "PATCH",
    "DELETE",
    "HEAD",
    "OPTIONS",
    "OTHER",
];
/// Global logging configuration, installed through `LoggingSettings::set`.
static LOGGER_SETTINGS: AtomicSettings<LoggingSettings> =
    AtomicSettings::<LoggingSettings>::empty();
/// Global tier-metrics configuration, installed through `TierMetricsSettings::set`.
static TIER_METRICS_SETTINGS: AtomicSettings<TierMetricsSettings> =
    AtomicSettings::<TierMetricsSettings>::empty();
/// Global classification-metrics configuration, installed through
/// `ClassificationMetricsSettings::set`.
static CLASSIFICATION_METRICS_SETTINGS: AtomicSettings<ClassificationMetricsSettings> =
    AtomicSettings::<ClassificationMetricsSettings>::empty();
/// A global slot holding a heap-allocated `T`, readable without locking.
///
/// Outside of tests the slot can be filled only once. Stored values are never freed, which is
/// what allows [`AtomicSettings::get`] to hand out `&'static T`.
pub struct AtomicSettings<T> {
    settings: AtomicPtr<T>,
}

impl<T> AtomicSettings<T> {
    /// Creates an empty slot; `const` so it can initialize a `static`.
    pub const fn empty() -> Self {
        Self {
            settings: AtomicPtr::new(std::ptr::null_mut()),
        }
    }

    /// Returns the stored value, or `None` if nothing has been stored yet.
    #[inline(always)]
    pub fn get(&self) -> Option<&'static T> {
        // `Acquire` pairs with the `Release` used when publishing in `store`, so a thread that
        // observes the pointer also observes the fully initialized value behind it.
        let ptr = self.settings.load(Ordering::Acquire);
        // SAFETY: a non-null pointer here was produced by `Box::into_raw` in `store` and is
        // never freed afterwards, so it stays valid for `'static`.
        unsafe { ptr.as_ref() }
    }

    /// Test-only: unconditionally replaces the stored value.
    ///
    /// The previous value is intentionally leaked rather than freed, because `&'static`
    /// references obtained from `get` may still point at it.
    #[cfg(test)]
    pub fn store(&self, new_value: T) -> Result<(), &str> {
        self.settings
            .store(Box::into_raw(Box::new(new_value)), Ordering::Release);
        Ok(())
    }

    /// Stores `new_value` if the slot is still empty.
    ///
    /// # Errors
    /// Returns `Err("Can be set only once.")` if a value was already stored; `new_value` is
    /// dropped in that case instead of being leaked.
    #[cfg(not(test))]
    pub fn store(&self, new_value: T) -> Result<(), &str> {
        let new_ptr = Box::into_raw(Box::new(new_value));
        // `Release` on success publishes the initialized value to `Acquire` loads in `get`.
        match self.settings.compare_exchange(
            std::ptr::null_mut(),
            new_ptr,
            Ordering::Release,
            Ordering::Relaxed,
        ) {
            Ok(_) => Ok(()),
            Err(_) => {
                // SAFETY: `new_ptr` was just produced by `Box::into_raw` and was never
                // published, so this is its only owner.
                drop(unsafe { Box::from_raw(new_ptr) });
                Err("Can be set only once.")
            }
        }
    }
}
lazy_static! {
pub static ref TIER_STATS: RequestAnalysisStatisticsStore<RequestSafetyTier> =
RequestAnalysisStatisticsStore::<RequestSafetyTier>::new();
}
lazy_static! {
pub static ref CLASSIFICATION_STATS: RequestAnalysisStatisticsStore<ClassificationReason> =
RequestAnalysisStatisticsStore::<ClassificationReason>::new();
}
lazy_static! {
pub static ref START_TIME: Instant = Instant::now();
}
/// Logging configuration installed via `LoggingSettings::set`.
pub struct LoggingSettings {
    /// Invoked as `(tier, message length in bytes, pointer to the message bytes)`.
    callback: extern "C" fn(RequestSafetyTier, u32, *const u8),
    /// Gates how often `callback` is invoked.
    rate_limiter: UltraLightRateLimiter,
}

/// Tier-metrics configuration installed via `TierMetricsSettings::set`.
pub struct TierMetricsSettings {
    /// Invoked as `(number of entries, pointer to the first entry)` of a metrics snapshot.
    callback: extern "C" fn(u32, *const ExtRequestAnalysisMetricsUnit<RequestSafetyTier>),
    /// Gates how often a snapshot is emitted.
    rate_limiter: UltraLightRateLimiter,
}

/// Classification-metrics configuration installed via `ClassificationMetricsSettings::set`.
pub struct ClassificationMetricsSettings {
    /// Invoked as `(number of entries, pointer to the first entry)` of a metrics snapshot.
    callback: extern "C" fn(u32, *const ExtRequestAnalysisMetricsUnit<ClassificationReason>),
    /// Gates how often a snapshot is emitted.
    rate_limiter: UltraLightRateLimiter,
}
/// Counters for one HTTP method: one `AtomicU32` per variant of `T`.
pub struct RequestAnalysisCounterPerMethod<T: Hash + Eq + strum::IntoEnumIterator> {
    method: HttpToken<'static>,
    counters_map: HashMap<T, AtomicU32>,
}

/// One non-zero entry of a metrics snapshot: `count` requests with `method` and `counter_type`.
pub struct RequestAnalysisMetricsUnit<T: Hash + Eq> {
    pub method: HttpToken<'static>,
    pub counter_type: T,
    pub count: u32,
}

/// Counters for every entry of `MONITORED_METHODS`, stored in the same order.
pub struct RequestAnalysisStatisticsStore<T: Hash + Eq + strum::IntoEnumIterator> {
    stats: Vec<RequestAnalysisCounterPerMethod<T>>,
}
pub trait CounterType {
fn get_counter_type(result: &RequestAnalysisResult) -> Self;
}
impl CounterType for RequestSafetyTier {
fn get_counter_type(result: &RequestAnalysisResult) -> Self {
result.tier
}
}
impl CounterType for ClassificationReason {
fn get_counter_type(result: &RequestAnalysisResult) -> Self {
result.reason
}
}
/// Per-counter-type access to the globally configured metrics callback.
pub trait MetricsCallback<T>
where
    T: Hash + Eq,
{
    /// Whether metrics settings for this counter type have been installed.
    fn is_enabled() -> bool;
    /// Returns the callback if settings are installed and the settings' rate limiter
    /// (`try_acquire_value`) yields it; `None` otherwise.
    fn get_metrics_callback() -> Option<extern "C" fn(u32, *const ExtRequestAnalysisMetricsUnit<T>)>;
}

impl MetricsCallback<RequestSafetyTier> for RequestSafetyTier {
    #[inline(always)]
    fn is_enabled() -> bool {
        matches!(TierMetricsSettings::get(), Some(_))
    }

    #[inline(always)]
    fn get_metrics_callback(
    ) -> Option<extern "C" fn(u32, *const ExtRequestAnalysisMetricsUnit<RequestSafetyTier>)> {
        TierMetricsSettings::get()
            .and_then(|cfg| cfg.rate_limiter.try_acquire_value(cfg.callback))
    }
}

impl MetricsCallback<ClassificationReason> for ClassificationReason {
    #[inline(always)]
    fn is_enabled() -> bool {
        matches!(ClassificationMetricsSettings::get(), Some(_))
    }

    #[inline(always)]
    fn get_metrics_callback(
    ) -> Option<extern "C" fn(u32, *const ExtRequestAnalysisMetricsUnit<ClassificationReason>)>
    {
        ClassificationMetricsSettings::get()
            .and_then(|cfg| cfg.rate_limiter.try_acquire_value(cfg.callback))
    }
}
impl LoggingSettings {
    /// Installs the global logging settings built from `settings`.
    ///
    /// Delivery is throttled by an `UltraLightRateLimiter` configured with
    /// `LOG_MESSAGE_RATE_LIMIT`.
    ///
    /// # Errors
    /// Propagates the error from `AtomicSettings::store` (outside of tests: logging settings
    /// were already installed).
    ///
    /// # Panics
    /// Panics if `settings.callback` is `None`.
    pub fn set(settings: &ExtLoggingSettings) -> Result<(), &str> {
        let callback = settings
            .callback
            .expect("Callback to send logs is required");
        let settings = LoggingSettings {
            rate_limiter: UltraLightRateLimiter::new(
                LOG_MESSAGE_RATE_LIMIT,
                get_current_time_millis,
            ),
            callback,
        };
        LOGGER_SETTINGS.store(settings)
    }

    /// Forwards `msg` to the logging callback when logging is configured and the rate limiter
    /// grants a permit; otherwise the message is silently dropped.
    #[inline(always)]
    pub fn log_message(tier: RequestSafetyTier, msg: &str) {
        if let Some(s) = LOGGER_SETTINGS.get() {
            s.try_log_message(tier, msg);
        }
    }

    /// Invokes the callback with at most `MESSAGE_MAX_SIZE` bytes of `msg`, if
    /// `rate_limiter.try_acquire()` succeeds.
    #[inline(always)]
    fn try_log_message(&self, tier: RequestSafetyTier, msg: &str) {
        if self.rate_limiter.try_acquire() {
            // Truncate on a char boundary so the receiver never gets a split UTF-8 sequence.
            let len = Self::truncated_len(msg, MESSAGE_MAX_SIZE);
            (self.callback)(tier, len as u32, msg.as_ptr());
        }
    }

    /// Largest byte length `<= max_len` that ends on a UTF-8 character boundary of `msg`.
    fn truncated_len(msg: &str, max_len: usize) -> usize {
        let mut len = msg.len().min(max_len);
        // `is_char_boundary(0)` is always true, so this terminates; a UTF-8 character is at
        // most 4 bytes, so it steps back at most 3 times.
        while !msg.is_char_boundary(len) {
            len -= 1;
        }
        len
    }
}
impl TierMetricsSettings {
    /// Installs the global tier-metrics settings built from `settings`.
    ///
    /// The rate limiter is configured with a rate of `1 / period_seconds`.
    /// NOTE(review): a `period_seconds` of 0 yields an infinite rate; confirm that
    /// `UltraLightRateLimiter` treats that as intended.
    ///
    /// # Errors
    /// Propagates the error from `AtomicSettings::store`.
    ///
    /// # Panics
    /// Panics if `settings.callback` is `None`.
    pub fn set(settings: &ExtTierMetricsSettings) -> Result<(), &str> {
        let rate = 1. / settings.period_seconds as f64;
        let callback = match settings.callback {
            Some(callback) => callback,
            None => panic!("Callback to send tier metrics is required"),
        };
        TIER_METRICS_SETTINGS.store(TierMetricsSettings {
            rate_limiter: UltraLightRateLimiter::new(rate, get_current_time_millis),
            callback,
        })
    }

    /// Returns the installed settings, if any.
    #[inline(always)]
    fn get() -> Option<&'static Self> {
        TIER_METRICS_SETTINGS.get()
    }
}
impl ClassificationMetricsSettings {
    /// Installs the global classification-metrics settings built from `settings`.
    ///
    /// The rate limiter is configured with a rate of `1 / period_seconds`.
    /// NOTE(review): a `period_seconds` of 0 yields an infinite rate; confirm that
    /// `UltraLightRateLimiter` treats that as intended.
    ///
    /// # Errors
    /// Propagates the error from `AtomicSettings::store`.
    ///
    /// # Panics
    /// Panics if `settings.callback` is `None`.
    pub fn set(settings: &ExtClassificationMetricsSettings) -> Result<(), &str> {
        let rate = 1. / settings.period_seconds as f64;
        let callback = match settings.callback {
            Some(callback) => callback,
            None => panic!("Callback to send classification metrics is required"),
        };
        CLASSIFICATION_METRICS_SETTINGS.store(ClassificationMetricsSettings {
            rate_limiter: UltraLightRateLimiter::new(rate, get_current_time_millis),
            callback,
        })
    }

    /// Returns the installed settings, if any.
    #[inline(always)]
    fn get() -> Option<&'static Self> {
        CLASSIFICATION_METRICS_SETTINGS.get()
    }
}
impl<T: strum::IntoEnumIterator> RequestAnalysisCounterPerMethod<T>
where
T: Hash + Eq,
{
fn new(method: HttpToken<'static>) -> Self {
let counters_map = T::iter().map(|tier| (tier, Default::default())).collect();
Self {
method,
counters_map,
}
}
}
impl<T> RequestAnalysisStatisticsStore<T>
where
    T: Hash + Eq + CounterType + MetricsCallback<T> + strum::IntoEnumIterator + std::fmt::Debug,
{
    /// Builds one counter set per entry of `MONITORED_METHODS`, in the same order.
    fn new() -> Self {
        let stats = MONITORED_METHODS
            .iter()
            .map(|method| RequestAnalysisCounterPerMethod::<T>::new(http_token(method)))
            .collect();
        Self { stats }
    }

    /// Counts one analyzed request, then tries to emit a metrics snapshot.
    ///
    /// Does nothing when metrics for `T` are not configured. A request whose method is not in
    /// `MONITORED_METHODS` is counted under the last entry (`"OTHER"`).
    pub fn update_counters(&self, request: &HttpRequestData, result: &RequestAnalysisResult) {
        if !<T as MetricsCallback<T>>::is_enabled() {
            return;
        }
        let counters = self
            .stats
            .iter()
            .find(|m| m.method == request.method)
            .or_else(|| self.stats.last())
            .expect("MONITORED_METHODS is non-empty, so a fallback bucket always exists");
        // Every variant of `T` was inserted in `RequestAnalysisCounterPerMethod::new`, so the
        // index cannot miss.
        counters.counters_map[&T::get_counter_type(result)].fetch_add(1, Ordering::Relaxed);
        self.emit_granular_metrics_event();
    }

    /// Atomically reads and clears every counter, returning only the non-zero ones.
    pub fn get_and_reset_metrics_snapshot(
        &self,
    ) -> SmallVec<[RequestAnalysisMetricsUnit<T>; METRICS_VECTOR_SIZE]> {
        let mut metrics_collection = SmallVec::new();
        for counter_aggregate in &self.stats {
            for counter_type in T::iter() {
                // `swap(0)` is the direct read-and-clear; `fetch_and(0)` with a used result
                // can need a compare-and-swap loop on some targets.
                let count = counter_aggregate.counters_map[&counter_type]
                    .swap(0, Ordering::Relaxed);
                if count > 0 {
                    metrics_collection.push(RequestAnalysisMetricsUnit {
                        method: counter_aggregate.method,
                        counter_type,
                        count,
                    });
                }
            }
        }
        metrics_collection
    }

    /// Sends a snapshot through the metrics callback when `get_metrics_callback` yields one.
    ///
    /// Taking the snapshot resets the counters; nothing is sent if every counter was zero.
    fn emit_granular_metrics_event(&self) {
        if let Some(metrics_callback) = <T as MetricsCallback<T>>::get_metrics_callback() {
            let snapshot: SmallVec<[ExtRequestAnalysisMetricsUnit<T>; METRICS_VECTOR_SIZE]> = self
                .get_and_reset_metrics_snapshot()
                .into_iter()
                .map(|item| {
                    ExtRequestAnalysisMetricsUnit::<T>::new(
                        ExtString::from_http_token(item.method),
                        item.counter_type,
                        item.count,
                    )
                })
                .collect();
            if !snapshot.is_empty() {
                // `snapshot` outlives the call, so the pointer stays valid while the callback
                // runs.
                (metrics_callback)(snapshot.len() as u32, snapshot.as_ptr());
            }
        }
    }
}
impl<T: Hash + Eq> ExtRequestAnalysisMetricsUnit<T> {
    /// Packs one snapshot entry into the form passed to the metrics callback.
    fn new(method: ExtString, counter_type: T, count: u32) -> Self {
        Self {
            count,
            counter_type,
            method,
        }
    }
}
/// Cached clock value in milliseconds: the last real reading relative to `START_TIME`
/// (outside tests), or the fake clock's current value (in tests).
static LAST_TIME: AtomicU64 = AtomicU64::new(0);
/// Shared call counter that decides when `get_current_time_millis` re-reads the real clock.
#[cfg(not(test))]
static SKIPS: AtomicU64 = AtomicU64::new(0);
/// Cheap, approximate clock used by this module's rate limiters: milliseconds elapsed since
/// `START_TIME`.
///
/// To avoid calling `Instant::now()` on every invocation, the real clock is read only on the
/// call that sees the shared counter `SKIPS` at 1; every other call returns the cached
/// `LAST_TIME`. The counter is cleared once its previous value exceeds
/// `(LOG_MESSAGE_RATE_LIMIT as u64) >> 4` (6 when `LOG_MESSAGE_RATE_LIMIT` is 100). Without
/// contention the clock is therefore read about once every 8 calls. The very first call
/// returns the initial cached value (0) without reading the clock.
///
/// NOTE(review): staleness is bounded by call count, not wall time. When calls are rare, the
/// returned value can lag far behind the real clock. Concurrent callers may also reset or
/// refresh out of step; the result is only approximate.
#[cfg(not(test))]
fn get_current_time_millis() -> u64 {
    let start_time = *START_TIME;
    // Every call bumps the shared counter; the previous value drives the logic below.
    let x = SKIPS.fetch_add(1, Ordering::Relaxed);
    // `as` binds tighter than `>>`, which binds tighter than `>`: the threshold is
    // `(LOG_MESSAGE_RATE_LIMIT as u64) >> 4`. Past it, clear the counter to start a new cycle.
    if x > LOG_MESSAGE_RATE_LIMIT as u64 >> 4 {
        SKIPS.fetch_and(0, Ordering::Relaxed);
    }
    // Only one call per cycle pays for `Instant::now()`; it also refreshes the cache.
    if x == 1 {
        let current = Instant::now()
            .saturating_duration_since(start_time)
            .as_millis() as u64;
        LAST_TIME.store(current, Ordering::Relaxed);
        current
    } else {
        LAST_TIME.load(Ordering::Relaxed)
    }
}
/// Deterministic test clock: returns the current fake time and advances it by 10_000, so
/// each successive call observes time moving forward by 10_000 ms.
#[cfg(test)]
fn get_current_time_millis() -> u64 {
    let current = LAST_TIME.load(Ordering::Relaxed);
    LAST_TIME.store(current + 10_000, Ordering::Relaxed);
    current
}