use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use dynamo_kv_router::protocols::WorkerWithDpRank;
/// Period between two scans of the background reaper task that evicts
/// expired bindings from the in-memory store.
const REAPER_INTERVAL: Duration = Duration::from_secs(30);
/// Callback invoked with `(session_id, worker_id)` when an engine-backed
/// binding is dropped because its TTL elapsed.
type ExpiryHandler = Arc<dyn Fn(String, u64) + Send + Sync>;
/// Distinguishes bindings by whether their expiry is reported to the expiry
/// handler.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AffinityKind {
/// Expires silently: the in-memory store drops the entry without calling
/// the expiry handler.
RouterOnly,
/// Expiry is reported: the in-memory store calls the configured expiry
/// handler, if any, with the session id and worker id when the entry
/// times out.
EngineBacked,
}
/// Snapshot of a session binding: the worker it points at and its kind.
///
/// This is the value handed back when a binding is removed from a store.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AffinityBinding {
/// Worker (and data-parallel rank) the session is pinned to.
pub worker: WorkerWithDpRank,
/// Whether the binding's expiry is reported to the expiry handler.
pub kind: AffinityKind,
}
/// Storage backend mapping session ids to the worker they are pinned to.
///
/// The behaviour notes below describe the in-memory implementation in this
/// file; other implementations are presumably expected to match them.
pub trait AffinityStore: Send + Sync {
/// Returns the worker bound to `session_id`, or `None` when there is no
/// live binding. The in-memory implementation extends the binding's
/// deadline by its TTL on every hit.
fn get(&self, session_id: &str) -> Option<WorkerWithDpRank>;
/// Returns the worker bound to `session_id` like [`AffinityStore::get`],
/// but (in the in-memory implementation) leaves the deadline untouched.
fn peek(&self, session_id: &str) -> Option<WorkerWithDpRank>;
/// Binds `session_id` to `worker` for `ttl`. The in-memory implementation
/// replaces any binding already stored under the same id.
fn put(&self, session_id: &str, worker: WorkerWithDpRank, ttl: Duration, kind: AffinityKind);
/// Removes the binding for `session_id` and returns it, or `None` when
/// nothing was stored under that id.
fn remove(&self, session_id: &str) -> Option<AffinityBinding>;
}
/// Internal map value of the in-memory store: a binding plus its expiry
/// bookkeeping.
struct AffinityEntry {
/// Worker the session is pinned to.
worker: WorkerWithDpRank,
/// Lifetime added to "now" whenever the entry is refreshed.
ttl: Duration,
/// Deadline; the entry counts as expired once this instant is reached.
expires_at: Instant,
/// Whether expiry of this entry is reported to the expiry handler.
kind: AffinityKind,
}
impl AffinityEntry {
fn binding(&self) -> AffinityBinding {
AffinityBinding {
worker: self.worker,
kind: self.kind,
}
}
}
/// Process-local [`AffinityStore`] backed by a concurrent map.
///
/// Cloning is cheap: clones share the same underlying map and the same
/// expiry handler.
#[derive(Clone)]
pub struct InMemoryAffinityStore {
/// Session id -> binding, shared between all clones of the store.
map: Arc<DashMap<String, AffinityEntry>>,
/// Optional callback fired when an engine-backed binding expires.
on_expire: Option<ExpiryHandler>,
}
impl Default for InMemoryAffinityStore {
    /// Builds a store without an expiry handler, exactly like
    /// [`InMemoryAffinityStore::new`].
    fn default() -> Self {
        Self::new_with_on_expire(None)
    }
}
impl InMemoryAffinityStore {
    /// Creates a store without an expiry handler.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime; see
    /// [`InMemoryAffinityStore::new_with_on_expire`].
    pub fn new() -> Self {
        Self::new_with_on_expire(None)
    }

    /// Creates a store and starts its background reaper.
    ///
    /// The reaper is a detached Tokio task that scans the map every
    /// [`REAPER_INTERVAL`] and evicts expired bindings. `on_expire`, when
    /// provided, is called with `(session_id, worker_id)` for every
    /// engine-backed binding that expires, whether it is evicted by the
    /// reaper or noticed during a lookup.
    ///
    /// The task owns a clone of the store and loops forever, so the map stays
    /// allocated for as long as the runtime keeps the task alive.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime, because the reaper is
    /// started with `tokio::spawn`.
    pub fn new_with_on_expire(on_expire: Option<ExpiryHandler>) -> Self {
        let store = InMemoryAffinityStore {
            map: Arc::new(DashMap::new()),
            on_expire,
        };
        let reaper_store = store.clone();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(REAPER_INTERVAL);
            loop {
                interval.tick().await;
                reaper_store.reap_expired(Instant::now());
            }
        });
        store
    }

    /// Evicts every binding whose deadline is at or before `now`.
    ///
    /// Expired engine-backed sessions are only *collected* while the map is
    /// being scanned; the expiry handler runs afterwards, once no shard lock
    /// is held. Calling user code from inside `retain` would let a handler
    /// that touches this store deadlock, and would stall every other user of
    /// the shard for as long as the handler runs.
    fn reap_expired(&self, now: Instant) {
        // Skip the per-entry key clone entirely when nobody is listening.
        let notify = self.on_expire.is_some();
        let mut expired_engine_sessions: Vec<(String, u64)> = Vec::new();

        self.map.retain(|session_id, entry: &mut AffinityEntry| {
            let alive = entry.expires_at > now;
            if !alive {
                tracing::debug!(%session_id, "Session affinity expired, removing");
                if notify && entry.kind == AffinityKind::EngineBacked {
                    expired_engine_sessions.push((session_id.clone(), entry.worker.worker_id));
                }
            }
            alive
        });

        if let Some(handler) = &self.on_expire {
            for (session_id, worker_id) in expired_engine_sessions {
                handler(session_id, worker_id);
            }
        }
    }

    /// Shared implementation of `get` (`refresh = true`) and `peek`
    /// (`refresh = false`).
    ///
    /// Returns the bound worker when a live binding exists. An expired
    /// binding is removed on the spot (firing the expiry handler when it is
    /// engine-backed) and reported as `None`.
    fn lookup(&self, session_id: &str, refresh: bool) -> Option<WorkerWithDpRank> {
        let now = Instant::now();
        let mut entry = self.map.get_mut(session_id)?;
        if entry.expires_at <= now {
            let binding = entry.binding();
            let expires_at = entry.expires_at;
            // The guard must be released before the map is touched again.
            drop(entry);
            self.remove_expired_if_current(session_id, binding, expires_at);
            return None;
        }
        let worker = entry.worker;
        if refresh {
            entry.expires_at = now + entry.ttl;
        }
        // Nothing below needs the entry; release the shard before logging.
        drop(entry);
        tracing::info!(
            %session_id,
            worker_id = worker.worker_id,
            dp_rank = worker.dp_rank,
            refreshed = refresh,
            "Sticky session hit"
        );
        Some(worker)
    }

    /// Removes the binding for `session_id` only if it is still the expired
    /// entry the caller observed.
    ///
    /// `binding` and `expires_at` identify that observed entry. If the
    /// session was re-bound or refreshed in the meantime, the stored entry no
    /// longer matches, nothing is removed and the handler is not called.
    /// Otherwise the entry is removed and, for engine-backed bindings, the
    /// expiry handler is invoked after the removal.
    fn remove_expired_if_current(
        &self,
        session_id: &str,
        binding: AffinityBinding,
        expires_at: Instant,
    ) {
        let removed = self.map.remove_if(session_id, |_, entry| {
            entry.worker == binding.worker
                && entry.expires_at == expires_at
                && entry.expires_at <= Instant::now()
        });
        if removed.is_none() {
            return;
        }
        tracing::debug!(%session_id, "Session affinity expired during resolve");
        if binding.kind == AffinityKind::EngineBacked
            && let Some(handler) = &self.on_expire
        {
            handler(session_id.to_owned(), binding.worker.worker_id);
        }
    }
}
impl AffinityStore for InMemoryAffinityStore {
    /// Looks up the session and, on a hit, pushes its deadline out by its TTL.
    fn get(&self, session_id: &str) -> Option<WorkerWithDpRank> {
        let refresh = true;
        self.lookup(session_id, refresh)
    }

    /// Looks up the session without extending its deadline.
    fn peek(&self, session_id: &str) -> Option<WorkerWithDpRank> {
        let refresh = false;
        self.lookup(session_id, refresh)
    }

    /// Stores a binding that expires `ttl` from now, replacing any binding
    /// already stored for `session_id`.
    fn put(&self, session_id: &str, worker: WorkerWithDpRank, ttl: Duration, kind: AffinityKind) {
        let expires_at = Instant::now() + ttl;
        let entry = AffinityEntry {
            worker,
            ttl,
            expires_at,
            kind,
        };
        self.map.insert(session_id.to_owned(), entry);
    }

    /// Removes the binding (expired or not) and returns its public view.
    fn remove(&self, session_id: &str) -> Option<AffinityBinding> {
        let (_, entry) = self.map.remove(session_id)?;
        Some(entry.binding())
    }
}
/// Routes sessions to the worker they were previously bound to, delegating
/// all storage to an [`AffinityStore`].
pub struct StickySessionRouter {
/// Backend holding the session -> worker bindings.
store: Box<dyn AffinityStore>,
}
impl StickySessionRouter {
    /// Wraps `store` in a router; the store is boxed behind the
    /// [`AffinityStore`] trait object.
    pub fn new(store: impl AffinityStore + 'static) -> Self {
        tracing::debug!("StickySessionRouter initialized");
        let store: Box<dyn AffinityStore> = Box::new(store);
        Self { store }
    }

    /// Returns the worker bound to `session_id` via the store's `get`.
    pub fn resolve_session(&self, session_id: &str) -> Option<WorkerWithDpRank> {
        let store = &self.store;
        store.get(session_id)
    }

    /// Returns the worker bound to `session_id` via the store's `peek`.
    pub fn peek_session(&self, session_id: &str) -> Option<WorkerWithDpRank> {
        let store = &self.store;
        store.peek(session_id)
    }

    /// Binds the session with [`AffinityKind::RouterOnly`].
    pub fn bind_router_only(&self, session_id: &str, worker: WorkerWithDpRank, ttl: Duration) {
        let kind = AffinityKind::RouterOnly;
        self.bind_with_kind(session_id, worker, ttl, kind);
    }

    /// Binds the session with [`AffinityKind::EngineBacked`].
    pub fn bind_engine_session(&self, session_id: &str, worker: WorkerWithDpRank, ttl: Duration) {
        let kind = AffinityKind::EngineBacked;
        self.bind_with_kind(session_id, worker, ttl, kind);
    }

    /// Logs the binding and writes it to the store.
    fn bind_with_kind(
        &self,
        session_id: &str,
        worker: WorkerWithDpRank,
        ttl: Duration,
        kind: AffinityKind,
    ) {
        let worker_id = worker.worker_id;
        let dp_rank = worker.dp_rank;
        let ttl_secs = ttl.as_secs();
        tracing::info!(
            %session_id,
            worker_id,
            dp_rank,
            ttl_secs,
            ?kind,
            "Binding session affinity"
        );
        self.store.put(session_id, worker, ttl, kind);
    }

    /// Removes the binding for `session_id`, returning it if one existed.
    pub fn unbind(&self, session_id: &str) -> Option<AffinityBinding> {
        tracing::info!(%session_id, "Removing session affinity");
        let removed = self.store.remove(session_id);
        removed
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// `(session_id, worker_id)` pairs captured from expiry-handler calls.
    type ExpiryLog = Arc<Mutex<Vec<(String, u64)>>>;

    fn worker(worker_id: u64, dp_rank: u32) -> WorkerWithDpRank {
        WorkerWithDpRank::new(worker_id, dp_rank)
    }

    /// Builds a store straight from its fields, so no reaper task is spawned
    /// and the tests do not need a Tokio runtime.
    fn bare_store(on_expire: Option<ExpiryHandler>) -> InMemoryAffinityStore {
        InMemoryAffinityStore {
            map: Arc::new(DashMap::new()),
            on_expire,
        }
    }

    /// Returns an expiry handler together with the log it appends to.
    fn recording_handler() -> (ExpiryLog, ExpiryHandler) {
        let log: ExpiryLog = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let handler: ExpiryHandler = Arc::new(move |session_id: String, worker_id: u64| {
            sink.lock().unwrap().push((session_id, worker_id));
        });
        (log, handler)
    }

    /// Builds an entry whose deadline passed one second ago.
    fn expired_entry(worker: WorkerWithDpRank, ttl: Duration, kind: AffinityKind) -> AffinityEntry {
        AffinityEntry {
            worker,
            ttl,
            expires_at: Instant::now() - Duration::from_secs(1),
            kind,
        }
    }

    #[test]
    fn resolve_returns_none_for_unknown_session() {
        let router = StickySessionRouter::new(bare_store(None));

        assert!(router.resolve_session("unknown-session").is_none());
    }

    #[test]
    fn bind_then_resolve_returns_worker_and_dp_rank() {
        let router = StickySessionRouter::new(bare_store(None));

        router.bind_engine_session("sess-1", worker(42, 3), Duration::from_secs(300));

        assert_eq!(router.resolve_session("sess-1"), Some(worker(42, 3)));
    }

    #[test]
    fn peek_returns_worker_without_refreshing_ttl() {
        let store = bare_store(None);
        let map = Arc::clone(&store.map);
        let ttl = Duration::from_secs(60);
        let expires_at = Instant::now() + Duration::from_secs(5);
        map.insert(
            "sess-peek".to_owned(),
            AffinityEntry {
                worker: worker(7, 2),
                ttl,
                expires_at,
                kind: AffinityKind::EngineBacked,
            },
        );
        let router = StickySessionRouter::new(store);

        assert_eq!(router.peek_session("sess-peek"), Some(worker(7, 2)));

        // The stored deadline must be exactly the one inserted above.
        let entry = map.get("sess-peek").unwrap();
        assert_eq!(entry.expires_at, expires_at);
    }

    #[test]
    fn bind_overwrites_worker_rank_and_ttl() {
        let store = bare_store(None);
        let map = Arc::clone(&store.map);
        let router = StickySessionRouter::new(store);

        router.bind_engine_session("sess-1", worker(1, 0), Duration::from_secs(10));
        router.bind_router_only("sess-1", worker(2, 3), Duration::from_secs(90));

        assert_eq!(router.peek_session("sess-1"), Some(worker(2, 3)));
        let entry = map.get("sess-1").unwrap();
        assert_eq!(entry.worker, worker(2, 3));
        assert_eq!(entry.ttl, Duration::from_secs(90));
        assert_eq!(entry.kind, AffinityKind::RouterOnly);
        assert!(entry.expires_at > Instant::now() + Duration::from_secs(80));
    }

    #[test]
    fn unbind_removes_affinity() {
        let router = StickySessionRouter::new(bare_store(None));
        router.bind_engine_session("sess-1", worker(42, 1), Duration::from_secs(300));

        let expected = AffinityBinding {
            worker: worker(42, 1),
            kind: AffinityKind::EngineBacked,
        };
        assert_eq!(router.unbind("sess-1"), Some(expected));
        assert!(router.resolve_session("sess-1").is_none());
    }

    #[test]
    fn expired_entry_returns_none() {
        let store = bare_store(None);
        store.map.insert(
            "sess-expired".to_owned(),
            expired_entry(
                worker(99, 0),
                Duration::from_secs(0),
                AffinityKind::EngineBacked,
            ),
        );
        let router = StickySessionRouter::new(store);

        assert!(router.resolve_session("sess-expired").is_none());
        assert!(router.store.peek("sess-expired").is_none());
    }

    #[test]
    fn resolve_refreshes_ttl() {
        let store = bare_store(None);
        let map = Arc::clone(&store.map);
        let ttl = Duration::from_secs(60);
        map.insert(
            "sess-refresh".to_owned(),
            AffinityEntry {
                worker: worker(7, 2),
                ttl,
                expires_at: Instant::now() + Duration::from_secs(5),
                kind: AffinityKind::EngineBacked,
            },
        );
        let router = StickySessionRouter::new(store);

        assert_eq!(router.resolve_session("sess-refresh"), Some(worker(7, 2)));

        let entry = map.get("sess-refresh").unwrap();
        let remaining = entry.expires_at.duration_since(Instant::now());
        assert!(
            remaining > Duration::from_secs(50),
            "TTL should have been refreshed, but remaining={remaining:?}"
        );
    }

    #[test]
    fn expired_entry_triggers_close_callback_on_resolve() {
        let (expired_sessions, on_expire) = recording_handler();
        let store = bare_store(Some(on_expire));
        store.map.insert(
            "sess-expired".to_owned(),
            expired_entry(
                worker(99, 0),
                Duration::from_secs(0),
                AffinityKind::EngineBacked,
            ),
        );
        let router = StickySessionRouter::new(store);

        assert!(router.resolve_session("sess-expired").is_none());
        assert_eq!(
            expired_sessions.lock().unwrap().as_slice(),
            &[("sess-expired".to_string(), 99)]
        );
    }

    #[test]
    fn expired_router_only_entry_drops_without_close_callback_on_resolve() {
        let (expired_sessions, on_expire) = recording_handler();
        let store = bare_store(Some(on_expire));
        store.map.insert(
            "sess-router-only".to_owned(),
            expired_entry(
                worker(11, 0),
                Duration::from_secs(0),
                AffinityKind::RouterOnly,
            ),
        );
        let router = StickySessionRouter::new(store);

        assert!(router.resolve_session("sess-router-only").is_none());
        assert!(expired_sessions.lock().unwrap().is_empty());
    }

    #[test]
    fn expired_lookup_does_not_remove_newer_binding() {
        let (expired_sessions, on_expire) = recording_handler();
        let store = bare_store(Some(on_expire));
        store.map.insert(
            "sess-race".to_owned(),
            expired_entry(
                worker(1, 0),
                Duration::from_secs(1),
                AffinityKind::EngineBacked,
            ),
        );

        // Capture what a lookup would have seen just before a concurrent re-bind.
        let stale = store.map.get("sess-race").unwrap();
        let stale_binding = stale.binding();
        let stale_expires_at = stale.expires_at;
        drop(stale);

        store.put(
            "sess-race",
            worker(2, 1),
            Duration::from_secs(300),
            AffinityKind::EngineBacked,
        );
        store.remove_expired_if_current("sess-race", stale_binding, stale_expires_at);

        assert_eq!(store.peek("sess-race"), Some(worker(2, 1)));
        assert!(expired_sessions.lock().unwrap().is_empty());
    }

    #[test]
    fn reaper_triggers_close_callback_for_expired_entry() {
        let (expired_sessions, on_expire) = recording_handler();
        let store = bare_store(Some(on_expire));
        store.map.insert(
            "sess-reaped".to_owned(),
            expired_entry(
                worker(17, 0),
                Duration::from_secs(30),
                AffinityKind::EngineBacked,
            ),
        );

        store.reap_expired(Instant::now());

        assert!(store.map.get("sess-reaped").is_none());
        assert_eq!(
            expired_sessions.lock().unwrap().as_slice(),
            &[("sess-reaped".to_string(), 17)]
        );
    }

    #[test]
    fn reaper_drops_router_only_entry_without_close_callback() {
        let (expired_sessions, on_expire) = recording_handler();
        let store = bare_store(Some(on_expire));
        store.map.insert(
            "sess-router-only-reaped".to_owned(),
            expired_entry(
                worker(18, 0),
                Duration::from_secs(30),
                AffinityKind::RouterOnly,
            ),
        );

        store.reap_expired(Instant::now());

        assert!(store.map.get("sess-router-only-reaped").is_none());
        assert!(expired_sessions.lock().unwrap().is_empty());
    }
}