use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use crate::preprocessor::PreprocessedRequest;
/// How often the background reaper sweeps the in-memory store for expired bindings.
const REAPER_INTERVAL: Duration = Duration::from_secs(30);
/// Callback invoked with `(session_id, worker_id)` when an in-memory affinity binding is
/// removed because its deadline has passed (either by the reaper or during a lookup).
type ExpiryHandler = Arc<dyn Fn(String, u64) + Send + Sync>;
/// Storage backend mapping session ids to the worker they are pinned to.
///
/// Implementations must be `Send + Sync`.
pub trait AffinityStore: Send + Sync {
/// Returns the worker bound to `session_id`, or `None` if there is no usable binding.
///
/// NOTE(review): implementations may treat a lookup as activity (the in-memory store
/// slides the deadline forward on a hit) — confirm this is expected of every backend.
fn get(&self, session_id: &str) -> Option<u64>;
/// Binds `session_id` to `worker_id` for `ttl`.
fn put(&self, session_id: &str, worker_id: u64, ttl: Duration);
/// Drops any binding for `session_id`.
fn remove(&self, session_id: &str);
}
/// A single session → worker binding held by [`InMemoryAffinityStore`].
struct AffinityEntry {
/// Worker the session is pinned to.
worker_id: u64,
/// Sliding TTL; a successful lookup pushes `expires_at` to "now + ttl".
ttl: Duration,
/// Deadline after which the binding is considered expired (expired when `<= now`).
expires_at: Instant,
}
/// Process-local [`AffinityStore`] backed by a concurrent `DashMap`.
///
/// Clones share the same underlying map (it is held behind an `Arc`).
#[derive(Clone)]
pub struct InMemoryAffinityStore {
/// Session id → binding.
map: Arc<DashMap<String, AffinityEntry>>,
/// Optional callback fired with `(session_id, worker_id)` when a binding expires.
on_expire: Option<ExpiryHandler>,
}
impl Default for InMemoryAffinityStore {
fn default() -> Self {
Self::new()
}
}
impl InMemoryAffinityStore {
pub fn new() -> Self {
Self::new_with_on_expire(None)
}
pub fn new_with_on_expire(on_expire: Option<ExpiryHandler>) -> Self {
let map = Arc::new(DashMap::new());
let store = InMemoryAffinityStore { map, on_expire };
let reaper_store = store.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(REAPER_INTERVAL);
loop {
interval.tick().await;
reaper_store.reap_expired(Instant::now());
}
});
store
}
fn reap_expired(&self, now: Instant) {
let on_expire = self.on_expire.clone();
self.map.retain(|session_id, entry: &mut AffinityEntry| {
let alive = entry.expires_at > now;
if !alive {
tracing::debug!(%session_id, "Session affinity expired, removing");
if let Some(handler) = &on_expire {
handler(session_id.clone(), entry.worker_id);
}
}
alive
});
}
}
impl AffinityStore for InMemoryAffinityStore {
fn get(&self, session_id: &str) -> Option<u64> {
let mut entry = self.map.get_mut(session_id)?;
if entry.expires_at <= Instant::now() {
let worker_id = entry.worker_id;
drop(entry);
self.map.remove(session_id);
tracing::debug!(%session_id, "Session affinity expired during resolve");
if let Some(handler) = &self.on_expire {
handler(session_id.to_owned(), worker_id);
}
return None;
}
entry.expires_at = Instant::now() + entry.ttl;
let worker_id = entry.worker_id;
tracing::info!(%session_id, worker_id, "Sticky session hit");
Some(worker_id)
}
fn put(&self, session_id: &str, worker_id: u64, ttl: Duration) {
self.map.insert(
session_id.to_owned(),
AffinityEntry {
worker_id,
ttl,
expires_at: Instant::now() + ttl,
},
);
}
fn remove(&self, session_id: &str) {
self.map.remove(session_id);
}
}
/// Routes requests that carry a session id to the worker bound to that session.
///
/// Bindings are kept in a pluggable [`AffinityStore`].
pub struct StickySessionRouter {
store: Box<dyn AffinityStore>,
}
impl StickySessionRouter {
    /// Builds a router that keeps its bindings in `store`.
    pub fn new(store: impl AffinityStore + 'static) -> Self {
        tracing::debug!("StickySessionRouter initialized");
        Self {
            store: Box::new(store),
        }
    }

    /// Looks up the worker bound to the request's session.
    ///
    /// Returns `None` when the request has no routing hints or no session control, or
    /// when the store returns `None` for the session id.
    pub fn resolve(&self, request: &PreprocessedRequest) -> Option<u64> {
        request
            .routing
            .as_ref()
            .and_then(|hints| hints.session_control.as_ref())
            .and_then(|control| self.store.get(control.session_id.as_str()))
    }

    /// Pins `session_id` to `worker_id` for `ttl`.
    pub fn bind(&self, session_id: &str, worker_id: u64, ttl: Duration) {
        let ttl_secs = ttl.as_secs();
        tracing::info!(%session_id, worker_id, ttl_secs, "Binding session affinity");
        self.store.put(session_id, worker_id, ttl);
    }

    /// Drops any binding for `session_id`.
    pub fn unbind(&self, session_id: &str) {
        tracing::info!(%session_id, "Removing session affinity");
        self.store.remove(session_id);
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use crate::protocols::common::preprocessor::{PreprocessedRequest, RoutingHints};
    use crate::protocols::openai::nvext::SessionControl;

    /// `(session_id, worker_id)` pairs captured by [`recording_handler`], in call order.
    type ExpiryLog = Arc<Mutex<Vec<(String, u64)>>>;

    /// Builds a store directly around `map`, bypassing the constructor (and its reaper).
    fn store_over(
        map: Arc<DashMap<String, AffinityEntry>>,
        on_expire: Option<ExpiryHandler>,
    ) -> InMemoryAffinityStore {
        InMemoryAffinityStore { map, on_expire }
    }

    /// An empty store with no expiry callback.
    fn empty_store() -> InMemoryAffinityStore {
        store_over(Arc::new(DashMap::new()), None)
    }

    /// Returns a handler that appends every invocation to the returned log.
    fn recording_handler() -> (ExpiryLog, ExpiryHandler) {
        let log: ExpiryLog = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let handler: ExpiryHandler = Arc::new(move |session_id: String, worker_id: u64| {
            sink.lock().unwrap().push((session_id, worker_id));
        });
        (log, handler)
    }

    /// An entry whose deadline passed one second ago.
    fn stale_entry(worker_id: u64, ttl: Duration) -> AffinityEntry {
        AffinityEntry {
            worker_id,
            ttl,
            expires_at: Instant::now() - Duration::from_secs(1),
        }
    }

    fn make_request(session_id: Option<&str>) -> PreprocessedRequest {
        let routing = session_id.map(|id| {
            let control = SessionControl {
                session_id: id.to_owned(),
                action: None,
                timeout: 300,
            };
            RoutingHints {
                session_control: Some(control),
                ..Default::default()
            }
        });
        PreprocessedRequest::builder()
            .model("test".to_string())
            .token_ids(vec![1, 2, 3])
            .stop_conditions(Default::default())
            .sampling_options(Default::default())
            .output_options(Default::default())
            .routing(routing)
            .build()
            .unwrap()
    }

    #[test]
    fn resolve_returns_none_for_unknown_session() {
        let router = StickySessionRouter::new(empty_store());
        assert!(router.resolve(&make_request(Some("unknown-session"))).is_none());
    }

    #[test]
    fn resolve_returns_none_when_no_session_id() {
        let router = StickySessionRouter::new(empty_store());
        assert!(router.resolve(&make_request(None)).is_none());
    }

    #[test]
    fn bind_then_resolve_returns_worker() {
        let router = StickySessionRouter::new(empty_store());
        router.bind("sess-1", 42, Duration::from_secs(300));
        assert_eq!(router.resolve(&make_request(Some("sess-1"))), Some(42));
    }

    #[test]
    fn unbind_removes_affinity() {
        let router = StickySessionRouter::new(empty_store());
        router.bind("sess-1", 42, Duration::from_secs(300));
        router.unbind("sess-1");
        assert!(router.resolve(&make_request(Some("sess-1"))).is_none());
    }

    #[test]
    fn expired_entry_returns_none() {
        let store = empty_store();
        store.map.insert(
            "sess-expired".to_owned(),
            stale_entry(99, Duration::from_secs(0)),
        );
        let router = StickySessionRouter::new(store);
        assert!(router.resolve(&make_request(Some("sess-expired"))).is_none());
        assert!(router.store.get("sess-expired").is_none());
    }

    #[test]
    fn resolve_refreshes_ttl() {
        let map = Arc::new(DashMap::new());
        let ttl = Duration::from_secs(60);
        map.insert(
            "sess-refresh".to_owned(),
            AffinityEntry {
                worker_id: 7,
                ttl,
                expires_at: Instant::now() + Duration::from_secs(5),
            },
        );
        let router = StickySessionRouter::new(store_over(Arc::clone(&map), None));
        assert_eq!(router.resolve(&make_request(Some("sess-refresh"))), Some(7));

        let remaining = map
            .get("sess-refresh")
            .unwrap()
            .expires_at
            .duration_since(Instant::now());
        assert!(
            remaining > Duration::from_secs(50),
            "TTL should have been refreshed, but remaining={remaining:?}"
        );
    }

    #[test]
    fn expired_entry_triggers_close_callback_on_resolve() {
        let (log, handler) = recording_handler();
        let store = store_over(Arc::new(DashMap::new()), Some(handler));
        store.map.insert(
            "sess-expired".to_owned(),
            stale_entry(99, Duration::from_secs(0)),
        );
        let router = StickySessionRouter::new(store);
        assert!(router.resolve(&make_request(Some("sess-expired"))).is_none());
        assert_eq!(
            log.lock().unwrap().as_slice(),
            &[("sess-expired".to_string(), 99)]
        );
    }

    #[test]
    fn reaper_triggers_close_callback_for_expired_entry() {
        let (log, handler) = recording_handler();
        let store = store_over(Arc::new(DashMap::new()), Some(handler));
        store.map.insert(
            "sess-reaped".to_owned(),
            stale_entry(17, Duration::from_secs(30)),
        );
        store.reap_expired(Instant::now());
        assert!(store.map.get("sess-reaped").is_none());
        assert_eq!(
            log.lock().unwrap().as_slice(),
            &[("sess-reaped".to_string(), 17)]
        );
    }
}