use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use std::time::SystemTime;
use parking_lot::Mutex;
use super::protocol::ShimProtocol;
/// Identifier assigned to each registered shim connection by [`ShimRegistry`].
pub type ShimConnId = u64;
/// Metadata stored in a [`ShimRegistry`] for one live shim connection.
#[derive(Debug, Clone)]
pub struct ShimConnEntry {
    /// Protocol supplied by the caller at registration time.
    pub protocol: ShimProtocol,
    /// Process id supplied by the caller at registration time
    /// (presumably the shim peer's OS pid — confirm against callers).
    pub pid: u32,
    /// Wall-clock time at which the entry was created.
    pub connected_at: SystemTime,
}
/// Mutex-protected table of live shim connections keyed by [`ShimConnId`].
///
/// Entries are added through the registration methods and removed
/// automatically when the corresponding [`ShimHandle`] is dropped.
#[derive(Debug, Default)]
pub struct ShimRegistry {
    /// Live entries; every read and write goes through this lock.
    inner: Mutex<HashMap<ShimConnId, ShimConnEntry>>,
    /// Source of fresh ids; starts at 0 (via `Default`) and is pre-incremented,
    /// so the first id handed out is 1.
    next_id: AtomicU64,
}
impl ShimRegistry {
#[must_use]
pub fn new() -> Arc<Self> {
Arc::new(Self::default())
}
pub fn register(self: &Arc<Self>, protocol: ShimProtocol, pid: u32) -> ShimHandle {
let id = self.next_id.fetch_add(1, Ordering::Relaxed) + 1;
let entry = ShimConnEntry {
protocol,
pid,
connected_at: SystemTime::now(),
};
self.inner.lock().insert(id, entry);
ShimHandle {
registry: Arc::downgrade(self),
id,
}
}
#[must_use]
pub fn len(&self) -> usize {
self.inner.lock().len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.lock().is_empty()
}
pub fn try_register_bounded(
self: &Arc<Self>,
protocol: ShimProtocol,
pid: u32,
cap: usize,
) -> Result<ShimHandle, RejectReason> {
let mut guard = self.inner.lock();
let current = guard.len();
if current >= cap {
return Err(RejectReason::CapExceeded { current, cap });
}
let id = self.next_id.fetch_add(1, Ordering::Relaxed) + 1;
let entry = ShimConnEntry {
protocol,
pid,
connected_at: SystemTime::now(),
};
guard.insert(id, entry);
drop(guard);
Ok(ShimHandle {
registry: Arc::downgrade(self),
id,
})
}
}
/// Reason a bounded registration attempt was refused.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RejectReason {
    /// The registry already held `current` entries, which is at or above `cap`.
    CapExceeded {
        current: usize,
        cap: usize,
    },
}

impl std::fmt::Display for RejectReason {
    /// Writes the wire form, e.g. `shim registry full (256 / 256)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // There is only one variant, so this pattern is irrefutable.
        let Self::CapExceeded { current, cap } = self;
        write!(f, "shim registry full ({current} / {cap})")
    }
}

/// Standard error marker; `RejectReason` has no underlying source error.
impl std::error::Error for RejectReason {}
#[derive(Debug)]
pub struct ShimHandle {
registry: Weak<ShimRegistry>,
id: ShimConnId,
}
impl Drop for ShimHandle {
fn drop(&mut self) {
if let Some(reg) = self.registry.upgrade() {
reg.inner.lock().remove(&self.id);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn register_and_drop_cleans_up() {
        let reg = ShimRegistry::new();
        assert_eq!(reg.len(), 0);
        let handle = reg.register(ShimProtocol::Lsp, 42);
        assert_eq!(reg.len(), 1);
        drop(handle);
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn multiple_registrations_unique_ids() {
        let reg = ShimRegistry::new();
        let h1 = reg.register(ShimProtocol::Lsp, 1);
        let h2 = reg.register(ShimProtocol::Mcp, 2);
        assert_ne!(h1.id, h2.id);
        assert_eq!(reg.len(), 2);
    }

    #[test]
    fn handle_outliving_registry_drops_cleanly() {
        let reg = ShimRegistry::new();
        let handle = reg.register(ShimProtocol::Lsp, 7);
        // After this, only the handle's Weak reference remains. Its Drop must
        // tolerate `upgrade()` failing instead of panicking.
        drop(reg);
        drop(handle);
    }

    #[test]
    fn try_register_bounded_zero_cap_rejects_everything() {
        let reg = ShimRegistry::new();
        let err = reg
            .try_register_bounded(ShimProtocol::Lsp, 1, 0)
            .expect_err("cap 0 must admit nothing");
        assert_eq!(err, RejectReason::CapExceeded { current: 0, cap: 0 });
        assert!(reg.is_empty());
    }

    #[test]
    fn try_register_bounded_rejects_when_at_cap() {
        let reg = ShimRegistry::new();
        let h1 = reg.try_register_bounded(ShimProtocol::Lsp, 1, 2).unwrap();
        let h2 = reg.try_register_bounded(ShimProtocol::Mcp, 2, 2).unwrap();
        assert_eq!(reg.len(), 2);
        let err = reg
            .try_register_bounded(ShimProtocol::Lsp, 3, 2)
            .expect_err("must reject at cap");
        assert_eq!(err, RejectReason::CapExceeded { current: 2, cap: 2 });
        assert_eq!(reg.len(), 2);
        // Freeing a slot must make room for exactly one more admission.
        drop(h1);
        let h3 = reg.try_register_bounded(ShimProtocol::Lsp, 3, 2).unwrap();
        assert_eq!(reg.len(), 2);
        drop(h2);
        drop(h3);
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn reject_reason_display_matches_wire_form() {
        let r = RejectReason::CapExceeded {
            current: 256,
            cap: 256,
        };
        assert_eq!(r.to_string(), "shim registry full (256 / 256)");
    }

    #[test]
    fn try_register_bounded_cap_race_256_concurrent_admits_exactly_cap() {
        use std::sync::Barrier;
        use std::thread;

        let reg = ShimRegistry::new();
        let cap: usize = 256;
        let n_threads: usize = 300;
        // Release all threads at once to maximise contention on the cap check.
        let barrier = Arc::new(Barrier::new(n_threads));
        let handles: Vec<_> = (0..n_threads)
            .map(|i| {
                let reg = Arc::clone(&reg);
                let barrier = Arc::clone(&barrier);
                let pid = u32::try_from(i).expect("thread index fits in u32");
                thread::spawn(move || -> Option<ShimHandle> {
                    barrier.wait();
                    reg.try_register_bounded(ShimProtocol::Lsp, pid, cap).ok()
                })
            })
            .collect();
        let mut ok_handles = Vec::new();
        let mut rejected = 0usize;
        for h in handles {
            match h.join().unwrap() {
                Some(handle) => ok_handles.push(handle),
                None => rejected += 1,
            }
        }
        assert_eq!(ok_handles.len(), cap, "exactly cap admissions must succeed");
        assert_eq!(rejected, n_threads - cap, "remaining must be rejected");
        assert_eq!(reg.len(), cap);
        drop(ok_handles);
        assert_eq!(reg.len(), 0);
    }
}