use std::{
num::NonZeroU64,
time::{SystemTime, UNIX_EPOCH},
};
use super::rand;
/// Extension parsed from the tag that follows a caller-supplied prefix.
///
/// Every non-`None` variant carries a 64-bit FxHash of the relevant tag data,
/// not the raw tag text itself.
// `TTL` stays upper-case because the variant is part of the public API;
// renaming it would break existing callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Extension {
    /// No recognized (or no valid) extension. This is the default.
    #[default]
    None,
    /// TTL extension: hash of the start of the current TTL time window.
    TTL(u64),
    /// Range extension: hash of the range tag text.
    Range(u64),
    /// Session extension: hash of the full input string that contained the
    /// session marker.
    Session(u64),
}
impl Extension {
    /// Marker introducing a TTL tag (a non-zero number of seconds), e.g. `-ttl-60`.
    const EXTENSION_TTL: &'static str = "-ttl-";
    /// Marker introducing a session tag; it is searched for in the whole input string.
    const EXTENSION_SESSION: &'static str = "-session-";
    /// Marker introducing a range tag, e.g. `-range-abc`.
    const EXTENSION_RANGE_SESSION: &'static str = "-range-";

    /// Parses the extension encoded in `full` after `prefix`.
    ///
    /// Returns `Ok(Extension::None)` when `full` does not start with `prefix`
    /// or carries no recognized marker.
    ///
    /// # Errors
    ///
    /// Parsing itself cannot fail; the `Result` return type is kept so that
    /// existing callers continue to compile unchanged.
    #[inline]
    pub async fn try_from<O>(prefix: &str, full: O) -> crate::Result<Extension>
    where
        O: Into<String>,
    {
        // The work is a few substring scans plus one hash. It never blocks, and
        // it is far cheaper than the thread-pool hop `spawn_blocking` would add,
        // so run it inline on the current task.
        Ok(parser(prefix.to_owned(), full.into()))
    }
}
/// Resolves the extension tag that follows `prefix` in `full`.
///
/// Yields [`Extension::None`] when `full` does not start with `prefix`, or when
/// none of the known markers is present. Markers are tried in a fixed order:
///
/// 1. session, looked up in and hashed from the complete `full` string
///    (prefix included);
/// 2. TTL, looked up in the text after `prefix`;
/// 3. range, looked up in the text after `prefix`.
///
/// The first marker that is present decides the result, even when its handler
/// itself produces [`Extension::None`].
#[inline]
fn parser(prefix: String, full: String) -> Extension {
    full.strip_prefix(prefix.as_str())
        .and_then(|tag| {
            parse_extension(
                false,
                &full,
                Extension::EXTENSION_SESSION,
                parse_session_extension,
            )
            .or_else(|| {
                parse_extension(true, tag, Extension::EXTENSION_TTL, parse_ttl_extension)
            })
            .or_else(|| {
                parse_extension(
                    true,
                    tag,
                    Extension::EXTENSION_RANGE_SESSION,
                    parse_range_extension,
                )
            })
        })
        .unwrap_or_default()
}
/// Runs `handler` on `s` if the marker `prefix` occurs anywhere in `s`.
///
/// Returns `None` when the marker is absent. Otherwise returns
/// `Some(handler(..))`, even if the handler reports [`Extension::None`].
///
/// When `trim` is set, all leading repetitions of the marker are removed
/// before calling `handler`. If the marker only occurs later in `s`, nothing
/// is trimmed and the handler sees `s` unchanged.
#[tracing::instrument(level = "trace", skip(handler))]
#[inline]
fn parse_extension(
    trim: bool,
    s: &str,
    prefix: &str,
    handler: fn(&str) -> Extension,
) -> Option<Extension> {
    s.contains(prefix).then(|| {
        let payload = if trim { s.trim_start_matches(prefix) } else { s };
        let parsed = handler(payload);
        tracing::trace!("Extension: {:?}", parsed);
        parsed
    })
}
/// Wraps the FxHash of the tag's raw UTF-8 bytes in [`Extension::Range`].
#[inline]
fn parse_range_extension(s: &str) -> Extension {
    Extension::Range(fxhash::hash64(s.as_bytes()))
}
/// Wraps the FxHash of the input's raw UTF-8 bytes in [`Extension::Session`].
#[inline]
fn parse_session_extension(s: &str) -> Extension {
    Extension::Session(fxhash::hash64(s.as_bytes()))
}
/// Parses a TTL tag, in seconds, into an [`Extension::TTL`] keyed by the current time window.
///
/// `s` must parse as a non-zero `u64`; anything else (empty, non-numeric, `0`)
/// yields [`Extension::None`].
///
/// The current Unix time, in seconds, is rounded down to a multiple of the TTL,
/// and that window start is hashed. Every call inside the same window therefore
/// yields the same value. If the system clock reads earlier than the Unix
/// epoch, a random timestamp from `rand::random_u64` is used instead.
#[inline]
fn parse_ttl_extension(s: &str) -> Extension {
    let ttl = match s.parse::<u64>().ok().and_then(NonZeroU64::new) {
        Some(ttl) => ttl.get(),
        None => return Extension::None,
    };

    // Only consult the RNG when the clock really is before the epoch;
    // `unwrap_or(rand::random_u64())` drew a random number on every call.
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or_else(|_| rand::random_u64());

    // Start of the current TTL window. `timestamp % ttl <= timestamp`, so the
    // subtraction cannot underflow, and `ttl` is non-zero, so `%` cannot panic.
    let window_start = timestamp - timestamp % ttl;
    Extension::TTL(fxhash::hash64(&window_start.to_be_bytes()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_ttl_extension_zero() {
        assert!(matches!(parse_ttl_extension("0"), Extension::None));
    }

    #[test]
    fn test_parse_ttl_extension_nonzero() {
        assert!(matches!(parse_ttl_extension("60"), Extension::TTL(_)));
    }

    #[test]
    fn test_parse_ttl_extension_invalid() {
        assert!(matches!(parse_ttl_extension("abc"), Extension::None));
        assert!(matches!(parse_ttl_extension(""), Extension::None));
    }

    #[test]
    fn test_parse_ttl_extension_window_is_stable() {
        // With a TTL of u64::MAX, the window start is 0 for any realistic clock
        // value, so the resulting hash is fully deterministic.
        let expected = fxhash::hash64(&0u64.to_be_bytes());
        assert!(matches!(
            parse_ttl_extension(&u64::MAX.to_string()),
            Extension::TTL(h) if h == expected
        ));
    }

    #[test]
    fn test_parser_requires_prefix() {
        assert!(matches!(
            parser("user".to_owned(), "other-range-abc".to_owned()),
            Extension::None
        ));
    }

    #[test]
    fn test_parser_without_marker() {
        assert!(matches!(
            parser("user".to_owned(), "user".to_owned()),
            Extension::None
        ));
    }

    #[test]
    fn test_parser_session_hashes_full_string() {
        let full = "user-session-abc";
        let expected = fxhash::hash64(full.as_bytes());
        assert!(matches!(
            parser("user".to_owned(), full.to_owned()),
            Extension::Session(h) if h == expected
        ));
    }

    #[test]
    fn test_parser_session_takes_precedence() {
        assert!(matches!(
            parser("user".to_owned(), "user-ttl-60-session-x".to_owned()),
            Extension::Session(_)
        ));
    }

    #[test]
    fn test_parser_range_hashes_tag_after_marker() {
        let expected = fxhash::hash64("abc".as_bytes());
        assert!(matches!(
            parser("user".to_owned(), "user-range-abc".to_owned()),
            Extension::Range(h) if h == expected
        ));
    }

    #[test]
    fn test_parser_ttl() {
        assert!(matches!(
            parser("user".to_owned(), "user-ttl-60".to_owned()),
            Extension::TTL(_)
        ));
        // A present-but-invalid TTL still decides the result.
        assert!(matches!(
            parser("user".to_owned(), "user-ttl-abc".to_owned()),
            Extension::None
        ));
    }
}