use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use osproxy_core::{Clock, ClusterId, Instant, SystemClock};
/// Default TTL for a cursor pin: five minutes (300 s).
/// Presumably intended as the `ttl` argument to `CursorAffinity::new`; confirm at call sites.
pub const DEFAULT_CURSOR_TTL: Duration = Duration::from_secs(5 * 60);
/// Default maximum number of cursor pins held at once (100 000).
/// Presumably intended as the `capacity` argument to `CursorAffinity::new`; confirm at call sites.
pub const DEFAULT_CAPACITY: usize = 100 * 1_000;
/// Cursor-affinity mode. `Off` is the default.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Affinity {
    /// Affinity disabled (the `Default` variant).
    #[default]
    Off,
    /// Affinity enabled. Presumably means cursors are pinned to a cluster via
    /// `CursorAffinity`; NOTE(review): confirm at the use site.
    Pin,
}
/// One stored pin: the cluster a cursor is bound to and when it was written.
#[derive(Debug, Clone)]
struct Pinned {
    /// Cluster the cursor resolves to while the pin is live.
    cluster: ClusterId,
    /// Clock reading taken when the pin was (re)written; TTL is measured from here.
    pinned_at: Instant,
}
/// Mutex-guarded table mapping cursor ids to a pinned cluster, with per-entry
/// TTL expiry and an upper bound on the number of stored pins.
pub struct CursorAffinity {
    /// How long a pin remains resolvable after it was written.
    ttl: Duration,
    /// Maximum number of stored pins (clamped to at least 1 by `new`).
    capacity: usize,
    /// Time source used to stamp and expire pins.
    clock: Arc<dyn Clock>,
    /// Cursor id -> pin; the mutex lets `&self` methods mutate the map.
    entries: Mutex<HashMap<String, Pinned>>,
}
impl std::fmt::Debug for CursorAffinity {
    /// Shows the configuration (`ttl`, `capacity`) and the current number of
    /// stored pins under `live`; the clock and the map contents are omitted.
    ///
    /// Note that `live` comes from `len()`, which also counts expired pins
    /// that have not been swept yet.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `len()` briefly takes the entries lock; read it before formatting.
        let live = self.len();
        let mut out = f.debug_struct("CursorAffinity");
        out.field("ttl", &self.ttl);
        out.field("capacity", &self.capacity);
        out.field("live", &live);
        out.finish_non_exhaustive()
    }
}
impl CursorAffinity {
    /// Creates an empty pin table that reads time from `SystemClock`.
    ///
    /// `ttl` is how long a pin stays resolvable after it was written by
    /// [`pin`](Self::pin). `capacity` bounds the number of stored pins; `0` is
    /// clamped to `1` so the table can always hold at least one entry.
    #[must_use]
    pub fn new(ttl: Duration, capacity: usize) -> Self {
        Self {
            clock: Arc::new(SystemClock),
            ttl,
            capacity: capacity.max(1),
            entries: Mutex::new(HashMap::new()),
        }
    }

    /// Replaces the time source used to stamp and expire pins (for example a
    /// controllable clock in tests). Builder-style: consumes and returns `self`.
    #[must_use]
    pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
        self.clock = clock;
        self
    }

    /// Pins `cursor_id` to `cluster`, overwriting any existing pin for that
    /// cursor and restarting its TTL from the current clock reading.
    ///
    /// Expired pins are swept first. If the table is still full and
    /// `cursor_id` is *not* already stored, the pin with the oldest
    /// `pinned_at` is evicted to make room. Re-pinning a cursor that is
    /// already stored never evicts another cursor: the insert just overwrites
    /// its existing slot, so the table does not grow.
    ///
    /// Both the sweep and the eviction scan are linear in the number of
    /// stored pins and run while the entries lock is held.
    pub fn pin(&self, cursor_id: impl Into<String>, cluster: ClusterId) {
        // Convert before locking so any allocation happens outside the critical section.
        let cursor_id = cursor_id.into();
        let now = self.clock.now();
        let mut entries = self.lock();
        entries.retain(|_, p| !self.is_expired(p, now));
        // Only a genuinely new key can push the table past capacity; evicting
        // for an overwrite would needlessly drop a live, unrelated cursor.
        if entries.len() >= self.capacity && !entries.contains_key(&cursor_id) {
            if let Some(oldest) = entries
                .iter()
                .min_by_key(|(_, p)| p.pinned_at)
                .map(|(k, _)| k.clone())
            {
                entries.remove(&oldest);
            }
        }
        entries.insert(
            cursor_id,
            Pinned {
                cluster,
                pinned_at: now,
            },
        );
    }

    /// Returns the cluster `cursor_id` is pinned to, or `None` if it has no
    /// pin or its pin has expired.
    ///
    /// Expired pins are not removed here; they are swept by the next
    /// [`pin`](Self::pin) call.
    #[must_use]
    pub fn resolve(&self, cursor_id: &str) -> Option<ClusterId> {
        let now = self.clock.now();
        let entries = self.lock();
        entries
            .get(cursor_id)
            .filter(|p| !self.is_expired(p, now))
            .map(|p| p.cluster.clone())
    }

    /// Removes the pin for `cursor_id`, if any. Releasing an unknown cursor is
    /// a no-op.
    pub fn release(&self, cursor_id: &str) {
        self.lock().remove(cursor_id);
    }

    /// Number of stored pins, *including* expired pins that have not yet been
    /// swept by [`pin`](Self::pin).
    #[must_use]
    pub fn len(&self) -> usize {
        self.lock().len()
    }

    /// `true` when no pins are stored. Like [`len`](Self::len), expired but
    /// unswept pins still count as stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.lock().is_empty()
    }

    /// A pin is expired once at least `ttl` has elapsed since `pinned_at`
    /// (so a zero `ttl` expires pins immediately).
    ///
    /// NOTE(review): `saturating_duration_since` presumably clamps to zero if
    /// the clock reads earlier than `pinned_at`; confirm in `osproxy_core`.
    fn is_expired(&self, p: &Pinned, now: Instant) -> bool {
        now.saturating_duration_since(p.pinned_at) >= self.ttl
    }

    /// Locks the entries map, recovering the guard if the mutex was poisoned
    /// rather than propagating the panic.
    ///
    /// NOTE(review): recovery is presumably safe because each operation here
    /// is a single std `HashMap` call, which leaves the map in a valid state
    /// even if a panic interrupted it; confirm if more complex mutations are added.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<String, Pinned>> {
        self.entries
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}