use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use url::Url;
use crate::antibot::SessionState;
use crate::config::RenderSessionScope;
/// Session time-to-live in seconds (3600 s = one hour).
///
/// NOTE(review): not referenced elsewhere in this file; presumably the value callers hand to
/// `SessionRegistry::new` when no TTL is configured — confirm at the call sites.
pub const DEFAULT_SESSION_TTL_SECS: u64 = 3600;
/// Live bookkeeping record for a single session stored in a `SessionRegistry`.
///
/// Timestamps are kept twice: as monotonic [`Instant`]s (used for TTL arithmetic) and as
/// Unix seconds (the form copied into a `SessionSnapshot`).
#[derive(Debug, Clone)]
pub struct SessionEntry {
    /// Session identifier; the same string is used as the registry key.
    pub id: String,
    /// Scope the session was created with.
    pub scope: RenderSessionScope,
    /// Key derived from the creating URL via `SessionRegistry::scope_key_for`.
    pub scope_key: String,
    /// Identifier attached through `SessionRegistry::set_bundle`; `None` until then.
    pub bundle_id: Option<u64>,
    /// Current state; new sessions start as `SessionState::Clean`.
    pub state: SessionState,
    /// Monotonic instant at which the entry was created.
    pub created_at: Instant,
    /// Monotonic instant of the most recent `get_or_create` or `touch`; drives expiry.
    pub last_used: Instant,
    /// Creation time as whole seconds since the Unix epoch.
    pub created_unix: i64,
    /// Most recent use as whole seconds since the Unix epoch.
    pub last_used_unix: i64,
    /// Per-session TTL that replaces the registry default when present.
    pub ttl_override: Option<Duration>,
    /// Starts at 1 and is bumped (saturating) on every later `get_or_create` for this id.
    pub urls_visited: u32,
    /// Number of `SessionRegistry::bump_challenge` calls recorded (saturating).
    pub challenges_seen: u32,
    /// Distinct proxies recorded for the session, in first-seen order.
    pub proxy_history: Vec<Url>,
}
impl SessionEntry {
pub fn effective_ttl(&self, default: Duration) -> Duration {
self.ttl_override.unwrap_or(default)
}
pub fn as_snapshot(&self) -> SessionSnapshot {
SessionSnapshot {
id: self.id.clone(),
scope: self.scope,
scope_key: self.scope_key.clone(),
bundle_id: self.bundle_id,
state: self.state,
created_unix: self.created_unix,
last_used_unix: self.last_used_unix,
urls_visited: self.urls_visited,
challenges_seen: self.challenges_seen,
proxy_history: self.proxy_history.iter().map(|u| u.to_string()).collect(),
}
}
}
/// Serializable view of a `SessionEntry`, produced by `SessionEntry::as_snapshot`.
///
/// Holds only wall-clock timestamps and plain strings, so it can be derived for
/// `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionSnapshot {
    /// Session identifier.
    pub id: String,
    /// Scope the session was created with.
    pub scope: RenderSessionScope,
    /// Scope key stored on the entry.
    pub scope_key: String,
    /// Bundle identifier attached to the session, if any.
    pub bundle_id: Option<u64>,
    /// Session state at the time the snapshot was taken.
    pub state: SessionState,
    /// Creation time as whole seconds since the Unix epoch.
    pub created_unix: i64,
    /// Most recent use as whole seconds since the Unix epoch.
    pub last_used_unix: i64,
    /// Value of the entry's `urls_visited` counter.
    pub urls_visited: u32,
    /// Value of the entry's `challenges_seen` counter.
    pub challenges_seen: u32,
    /// Proxy URLs recorded for the session, rendered as strings.
    pub proxy_history: Vec<String>,
}
/// In-memory table of sessions keyed by session id.
///
/// Backed by a `DashMap`, so every operation takes `&self`.
pub struct SessionRegistry {
    /// Session id -> entry.
    entries: DashMap<String, SessionEntry>,
    /// TTL applied to sessions that carry no per-session override.
    default_ttl: Duration,
}
impl SessionRegistry {
    /// Creates an empty registry whose default TTL is `ttl_secs` seconds.
    ///
    /// A value of `0` is raised to one second, so the default TTL is never zero.
    pub fn new(ttl_secs: u64) -> Self {
        Self {
            entries: DashMap::new(),
            default_ttl: Duration::from_secs(ttl_secs.max(1)),
        }
    }

    /// TTL applied to sessions without a per-session override.
    pub fn default_ttl(&self) -> Duration {
        self.default_ttl
    }

    /// Derives the scope key for `url` at the granularity selected by `scope`.
    ///
    /// * `RegistrableDomain` - result of `registrable_domain(host)`, or the bare host when
    ///   that lookup yields nothing.
    /// * `Host` - the host, followed by `:port` when `url.port()` reports one.
    /// * `Origin` - the ASCII serialization of the URL's origin.
    /// * `Url` - the full URL string.
    ///
    /// A URL without a host contributes an empty host string.
    pub fn scope_key_for(scope: RenderSessionScope, url: &Url) -> String {
        let host = url.host_str().unwrap_or_default();
        match scope {
            RenderSessionScope::RegistrableDomain => {
                crate::discovery::subdomains::registrable_domain(host)
                    .unwrap_or_else(|| host.to_string())
            }
            RenderSessionScope::Host => match url.port() {
                Some(port) => format!("{host}:{port}"),
                None => host.to_string(),
            },
            RenderSessionScope::Origin => url.origin().ascii_serialization(),
            RenderSessionScope::Url => url.as_str().to_string(),
        }
    }

    /// Returns a copy of the session `id`, creating it when it does not exist yet.
    ///
    /// * Existing session: `last_used`/`last_used_unix` are refreshed and `urls_visited` is
    ///   incremented (saturating). `scope` and `url` are ignored; the stored scope and scope
    ///   key are left untouched.
    /// * New session: inserted in state `Clean` with `urls_visited == 1`, and the scope key
    ///   derived from `scope` and `url`.
    pub fn get_or_create(&self, id: &str, scope: RenderSessionScope, url: &Url) -> SessionEntry {
        let now = Instant::now();
        let now_unix = now_unix();
        self.entries
            .entry(id.to_string())
            .and_modify(|e| {
                e.last_used = now;
                e.last_used_unix = now_unix;
                e.urls_visited = e.urls_visited.saturating_add(1);
            })
            .or_insert_with(|| SessionEntry {
                id: id.to_string(),
                scope,
                // Derived only on insert: the hit path above never reads it.
                scope_key: Self::scope_key_for(scope, url),
                bundle_id: None,
                state: SessionState::Clean,
                created_at: now,
                last_used: now,
                created_unix: now_unix,
                last_used_unix: now_unix,
                ttl_override: None,
                urls_visited: 1,
                challenges_seen: 0,
                proxy_history: Vec::new(),
            })
            .clone()
    }

    /// Refreshes the last-used timestamps of session `id`; does nothing for an unknown id.
    pub fn touch(&self, id: &str) {
        if let Some(mut e) = self.entries.get_mut(id) {
            e.last_used = Instant::now();
            e.last_used_unix = now_unix();
        }
    }

    /// Moves session `id` into `state`.
    ///
    /// Returns `Some((previous, new))` when the state changed. Returns `None` both when the
    /// id is unknown and when the session is already in `state`.
    pub fn mark(&self, id: &str, state: SessionState) -> Option<(SessionState, SessionState)> {
        let mut e = self.entries.get_mut(id)?;
        let from = e.state;
        if from == state {
            return None;
        }
        e.state = state;
        Some((from, state))
    }

    /// Increments (saturating) the challenge counter of session `id`, if it exists.
    pub fn bump_challenge(&self, id: &str) {
        if let Some(mut e) = self.entries.get_mut(id) {
            e.challenges_seen = e.challenges_seen.saturating_add(1);
        }
    }

    /// Attaches `bundle_id` to session `id`, if it exists.
    pub fn set_bundle(&self, id: &str, bundle_id: u64) {
        if let Some(mut e) = self.entries.get_mut(id) {
            e.bundle_id = Some(bundle_id);
        }
    }

    /// Appends `proxy` to the proxy history of session `id` unless it is already listed.
    pub fn record_proxy(&self, id: &str, proxy: &Url) {
        if let Some(mut e) = self.entries.get_mut(id) {
            if !e.proxy_history.contains(proxy) {
                e.proxy_history.push(proxy.clone());
            }
        }
    }

    /// Sets (`Some`) or clears (`None`) the per-session TTL of session `id`, if it exists.
    pub fn set_ttl_override(&self, id: &str, ttl: Option<Duration>) {
        if let Some(mut e) = self.entries.get_mut(id) {
            e.ttl_override = ttl;
        }
    }

    /// Returns the ids of all sessions idle for at least their effective TTL.
    ///
    /// Nothing is removed; the result is a point-in-time view that other threads may
    /// invalidate immediately.
    pub fn expired(&self) -> Vec<String> {
        let now = Instant::now();
        self.entries
            .iter()
            .filter(|r| {
                // A concurrent `touch`/`get_or_create` can store a `last_used` later than
                // `now`; the saturating form treats that as zero idle time and cannot panic.
                now.saturating_duration_since(r.last_used) >= r.effective_ttl(self.default_ttl)
            })
            .map(|r| r.id.clone())
            .collect()
    }

    /// Removes session `id` and returns its entry, or `None` when the id is unknown.
    pub fn evict(&self, id: &str) -> Option<SessionEntry> {
        self.entries.remove(id).map(|(_, v)| v)
    }

    /// Whether a session with this id is currently registered.
    pub fn contains(&self, id: &str) -> bool {
        self.entries.contains_key(id)
    }

    /// Returns a copy of session `id` without refreshing its timestamps.
    pub fn get(&self, id: &str) -> Option<SessionEntry> {
        self.entries.get(id).map(|entry| entry.value().clone())
    }

    /// Number of registered sessions.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether the registry holds no sessions.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Snapshots every session, or only those whose state equals `filter` when it is `Some`.
    pub fn list(&self, filter: Option<SessionState>) -> Vec<SessionSnapshot> {
        self.entries
            .iter()
            .filter(|e| filter.map_or(true, |wanted| wanted == e.state))
            .map(|e| e.as_snapshot())
            .collect()
    }
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` when the system clock reads earlier than the epoch. The `u64` second count
/// is converted with a checked conversion and saturates at `i64::MAX` instead of wrapping
/// to a negative value.
fn now_unix() -> i64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| i64::try_from(elapsed.as_secs()).unwrap_or(i64::MAX))
}
/// Why a session was removed from the registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EvictionReason {
    /// The session outlived its TTL (used by the cleanup task).
    Ttl,
    /// Labelled `"blocked"`.
    Blocked,
    /// Labelled `"manual"`.
    Manual,
    /// Labelled `"run_ended"`.
    RunEnded,
}

impl EvictionReason {
    /// Stable lowercase label for this reason.
    pub fn as_str(&self) -> &'static str {
        match *self {
            EvictionReason::Ttl => "ttl",
            EvictionReason::Blocked => "blocked",
            EvictionReason::Manual => "manual",
            EvictionReason::RunEnded => "run_ended",
        }
    }
}
/// Component that is told to drop a session by id.
///
/// The cleanup task calls this for each expired session before evicting it from the
/// registry.
#[async_trait::async_trait]
pub trait SessionDropTarget: Send + Sync {
    /// Drops whatever the implementor holds for session `id`.
    async fn drop_session(&self, id: &str);
}
/// Spawns a background task that tears down expired sessions once per `tick`.
///
/// Each round sleeps for `tick`, asks the registry for its expired ids, and for every id:
///
/// 1. re-checks that the session is still idle past its TTL,
/// 2. calls `drop_target.drop_session`,
/// 3. evicts the entry from the registry, and
/// 4. hands the evicted entry to `archive` (when provided) with [`EvictionReason::Ttl`].
///
/// The loop has no exit condition; stop it by aborting the returned handle.
pub fn spawn_cleanup_task(
    registry: Arc<SessionRegistry>,
    drop_target: Arc<dyn SessionDropTarget>,
    archive: Option<Arc<dyn SessionArchive>>,
    tick: Duration,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        loop {
            tokio::time::sleep(tick).await;
            for id in registry.expired() {
                // The expired list is a snapshot, and the awaits below can take a while for
                // earlier ids in the same batch. Skip a session that was used again in the
                // meantime instead of tearing it down mid-use. An id that has vanished from
                // the registry still goes through `drop_session`, as before.
                // This narrows the race but cannot close it: the session may still be
                // touched between this check and the eviction.
                let refreshed = registry.get(&id).map_or(false, |entry| {
                    entry.last_used.elapsed() < entry.effective_ttl(registry.default_ttl())
                });
                if refreshed {
                    continue;
                }

                drop_target.drop_session(&id).await;
                if let Some(entry) = registry.evict(&id) {
                    if let Some(sink) = archive.as_ref() {
                        // Archiving is best-effort: a failure must not stop the cleanup loop.
                        let _ = sink.archive_session(&entry, EvictionReason::Ttl).await;
                    }
                }
            }
        }
    })
}
/// Sink that receives a session entry after it has been evicted.
#[async_trait::async_trait]
pub trait SessionArchive: Send + Sync {
    /// Records `entry` together with the `reason` it was evicted for.
    ///
    /// The cleanup task discards any error returned from here.
    async fn archive_session(
        &self,
        entry: &SessionEntry,
        reason: EvictionReason,
    ) -> crate::Result<()>;
}
pub struct StorageArchive(pub std::sync::Arc<dyn crate::storage::StateStorage>);
#[async_trait::async_trait]
impl SessionArchive for StorageArchive {
async fn archive_session(
&self,
entry: &SessionEntry,
reason: EvictionReason,
) -> crate::Result<()> {
self.0.archive_session(entry, reason).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use url::Url;

    /// Parses a URL literal, panicking on malformed test input.
    fn url(s: &str) -> Url {
        Url::parse(s).unwrap()
    }

    /// A second `get_or_create` for the same id reuses the entry and bumps the visit count.
    #[test]
    fn get_or_create_is_idempotent() {
        let registry = SessionRegistry::new(60);
        let target = url("https://example.com/a");
        let scope = RenderSessionScope::RegistrableDomain;
        let first = registry.get_or_create("s1", scope, &target);
        let second = registry.get_or_create("s1", scope, &target);
        assert_eq!(first.id, second.id);
        assert_eq!(second.urls_visited, 2);
    }

    /// `mark` reports the transition once and `None` when the state is already set.
    #[test]
    fn mark_transitions_state() {
        let registry = SessionRegistry::new(60);
        let target = url("https://x.test/");
        let _ = registry.get_or_create("s2", RenderSessionScope::RegistrableDomain, &target);
        let transition = registry.mark("s2", SessionState::Contaminated);
        assert_eq!(
            transition,
            Some((SessionState::Clean, SessionState::Contaminated))
        );
        assert_eq!(registry.mark("s2", SessionState::Contaminated), None);
    }

    /// A zero TTL override makes the session show up in `expired`.
    #[test]
    fn expired_detects_ttl() {
        let registry = SessionRegistry::new(1);
        let target = url("https://y.test/");
        let _ = registry.get_or_create("s3", RenderSessionScope::Url, &target);
        registry.set_ttl_override("s3", Some(Duration::from_millis(0)));
        std::thread::sleep(std::time::Duration::from_millis(5));
        let stale = registry.expired();
        assert!(stale.contains(&"s3".to_string()));
    }

    /// Each scope produces a key of the expected granularity.
    #[test]
    fn scope_key_for_picks_right_granularity() {
        let target = url("https://www.example.com:8443/path?q=1");

        let domain_key =
            SessionRegistry::scope_key_for(RenderSessionScope::RegistrableDomain, &target);
        assert!(domain_key.ends_with("example.com"));

        let host_key = SessionRegistry::scope_key_for(RenderSessionScope::Host, &target);
        assert_eq!(host_key, "www.example.com:8443");

        let origin_key = SessionRegistry::scope_key_for(RenderSessionScope::Origin, &target);
        assert!(origin_key.starts_with("https://www.example.com"));

        let url_key = SessionRegistry::scope_key_for(RenderSessionScope::Url, &target);
        assert_eq!(url_key, target.as_str());
    }

    /// `list` honours the optional state filter.
    #[test]
    fn list_filters_by_state() {
        let registry = SessionRegistry::new(60);
        let _ = registry.get_or_create("a", RenderSessionScope::Url, &url("https://a.test/"));
        let _ = registry.get_or_create("b", RenderSessionScope::Url, &url("https://b.test/"));
        registry.mark("b", SessionState::Blocked);

        let only_blocked = registry.list(Some(SessionState::Blocked));
        assert_eq!(only_blocked.len(), 1);
        assert_eq!(only_blocked[0].id, "b");

        let everything = registry.list(None);
        assert_eq!(everything.len(), 2);
    }

    /// `evict` returns the removed entry and the id is gone afterwards.
    #[test]
    fn drop_removes_entry() {
        let registry = SessionRegistry::new(60);
        let _ = registry.get_or_create("k", RenderSessionScope::Url, &url("https://k.test/"));
        assert!(registry.contains("k"));
        let evicted = registry.evict("k");
        assert!(evicted.is_some());
        assert!(!registry.contains("k"));
    }
}