use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{FromRequest, RequestContext, ServerHooks};
use sqlx::MySqlPool;
use tokio::sync::Mutex as AsyncMutex;
use crate::config::ProxyConfig;
/// Application state bundling the proxy configuration with the pool and session registries.
///
/// Cloning copies the `Arc` to the configuration and clones both registry handles.
#[derive(Clone)]
pub struct AppState {
/// Proxy configuration, shared behind an `Arc`.
pub config: Arc<ProxyConfig>,
/// Registry of backend MySQL pools.
pub pools: BackendPoolRegistry,
/// Registry of sessions keyed by client id.
pub sessions: SessionRegistry,
}
impl AppState {
    /// Wraps `config` in an `Arc` and pairs it with freshly created, empty registries.
    pub fn new(config: ProxyConfig) -> Self {
        let config = Arc::new(config);
        let pools = BackendPoolRegistry::new();
        let sessions = SessionRegistry::new();
        Self {
            config,
            pools,
            sessions,
        }
    }

    /// Builds connection hooks that operate on a clone of this state's session registry.
    pub fn hooks(&self) -> ProxyHooks {
        let sessions = self.sessions.clone();
        ProxyHooks { sessions }
    }
}
/// Clonable handle to a map of backend pools keyed by a `String` (the `user` argument of the
/// registry methods).
#[derive(Clone)]
pub struct BackendPoolRegistry {
// Map guarded by a std (blocking) mutex and held behind an `Arc`, so clones of the
// registry refer to the same entries.
inner: Arc<Mutex<HashMap<String, PoolEntry>>>,
}
/// A registered pool together with the time it was last used.
struct PoolEntry {
// The pool handed out (as an `Arc` clone) to callers of the registry.
pool: Arc<MySqlPool>,
// Last-use timestamp in milliseconds, taken from `now_ms()`; atomic so it can be
// refreshed through a shared reference to the entry.
last_used_ms: AtomicU64,
}
impl Default for BackendPoolRegistry {
fn default() -> Self {
Self::new()
}
}
impl BackendPoolRegistry {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn get(&self, user: &str) -> Option<Arc<MySqlPool>> {
let guard = self.inner.lock().unwrap();
guard.get(user).map(|entry| {
entry.last_used_ms.store(now_ms(), Ordering::Release);
Arc::clone(&entry.pool)
})
}
pub fn insert_if_absent(&self, user: String, pool: Arc<MySqlPool>) -> Arc<MySqlPool> {
let mut guard = self.inner.lock().unwrap();
match guard.get(&user) {
Some(entry) => {
entry.last_used_ms.store(now_ms(), Ordering::Release);
Arc::clone(&entry.pool)
}
None => {
let entry = PoolEntry {
pool: Arc::clone(&pool),
last_used_ms: AtomicU64::new(now_ms()),
};
guard.insert(user, entry);
pool
}
}
}
pub fn touch(&self, user: &str) {
if let Some(entry) = self.inner.lock().unwrap().get(user) {
entry.last_used_ms.store(now_ms(), Ordering::Release);
}
}
pub fn snapshot_pools(&self) -> Vec<Arc<MySqlPool>> {
let guard = self.inner.lock().unwrap();
guard
.values()
.map(|entry| Arc::clone(&entry.pool))
.collect()
}
pub fn prune_idle(&self, idle_ttl_ms: u64) -> usize {
if idle_ttl_ms == 0 {
let mut guard = self.inner.lock().unwrap();
let removed = guard.len();
guard.clear();
return removed;
}
let cutoff = now_ms().saturating_sub(idle_ttl_ms);
let mut guard = self.inner.lock().unwrap();
let before = guard.len();
guard.retain(|_, entry| entry.last_used_ms.load(Ordering::Acquire) >= cutoff);
before - guard.len()
}
}
/// Clonable handle to a map from client id (`u64`) to that client's shared `Session`.
#[derive(Clone)]
pub struct SessionRegistry {
// Map guarded by a std (blocking) mutex and held behind an `Arc`, so clones of the
// registry refer to the same sessions.
inner: Arc<Mutex<HashMap<u64, Arc<Session>>>>,
}
impl Default for SessionRegistry {
fn default() -> Self {
Self::new()
}
}
impl SessionRegistry {
    /// Creates a registry with no sessions.
    pub fn new() -> Self {
        let sessions = HashMap::new();
        Self {
            inner: Arc::new(Mutex::new(sessions)),
        }
    }

    /// Creates a fresh session, stores it under `client_id` (replacing any session already
    /// stored there), and returns a handle to it.
    ///
    /// # Panics
    ///
    /// Panics if the registry mutex is poisoned.
    pub fn insert(&self, client_id: u64) -> Arc<Session> {
        let mut sessions = self.inner.lock().unwrap();
        let fresh = Arc::new(Session::new());
        let stored = Arc::clone(&fresh);
        sessions.insert(client_id, stored);
        fresh
    }

    /// Drops the registry's reference to the session stored under `client_id`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the registry mutex is poisoned.
    pub fn remove(&self, client_id: u64) {
        let mut sessions = self.inner.lock().unwrap();
        sessions.remove(&client_id);
    }

    /// Returns a handle to the session stored under `client_id`, or `None` if there is none.
    ///
    /// # Panics
    ///
    /// Panics if the registry mutex is poisoned.
    pub fn get(&self, client_id: u64) -> Option<Arc<Session>> {
        let sessions = self.inner.lock().unwrap();
        sessions.get(&client_id).cloned()
    }
}
/// Connection hooks holding a handle to the session registry; its `ServerHooks` impl inserts
/// a session when a connection opens and removes it when the connection closes.
#[derive(Clone)]
pub struct ProxyHooks {
// Registry the hooks insert into and remove from.
sessions: SessionRegistry,
}
impl ServerHooks for ProxyHooks {
    /// Registers a new session under the connection's id; the returned handle is not kept.
    fn on_connection_open(&self, info: resp_async::ConnectionInfo) {
        let client_id = info.id;
        drop(self.sessions.insert(client_id));
    }

    /// Removes the session registered under the connection's id.
    fn on_connection_close(&self, info: resp_async::ConnectionInfo) {
        let client_id = info.id;
        self.sessions.remove(client_id);
    }
}
/// Newtype around a shared `Session`; it implements `FromRequest<AppState>` so it can be
/// obtained from a request's client id.
#[derive(Clone)]
pub struct SessionHandle(pub Arc<Session>);
impl FromRequest<AppState> for SessionHandle {
type Rejection = RespError;
async fn from_request(
ctx: &mut RequestContext,
state: &Arc<AppState>,
) -> Result<Self, Self::Rejection> {
state
.sessions
.get(ctx.client_id)
.map(SessionHandle)
.ok_or_else(RespError::internal)
}
}
/// State kept for one session; each mutable part sits behind its own async mutex.
pub struct Session {
// Authentication context; `None` until `set_auth` stores one.
auth: AsyncMutex<Option<AuthContext>>,
// Pub/sub bookkeeping, exposed to callers as a lock guard via `pubsub_state`.
pubsub: AsyncMutex<PubSubState>,
// Optional client name; `None` until `set_client_name` stores one.
client_name: AsyncMutex<Option<Bytes>>,
// Flag claimed by `try_activate_poller` and cleared by `deactivate_poller`.
// NOTE(review): what the "poller" is lives outside this file — confirm with callers.
poller_active: AtomicBool,
}
impl Default for Session {
fn default() -> Self {
Self::new()
}
}
impl Session {
    /// Creates a session with no auth context, default pub/sub state, no client name, and
    /// the poller flag cleared.
    pub fn new() -> Self {
        let auth = AsyncMutex::new(None);
        let pubsub = AsyncMutex::new(PubSubState::default());
        let client_name = AsyncMutex::new(None);
        let poller_active = AtomicBool::new(false);
        Self {
            auth,
            pubsub,
            client_name,
            poller_active,
        }
    }

    /// Stores `auth` as the session's auth context, replacing any previous one.
    pub async fn set_auth(&self, auth: AuthContext) {
        let mut slot = self.auth.lock().await;
        *slot = Some(auth);
    }

    /// Returns a clone of the current auth context, or `None` if none has been set.
    pub async fn auth(&self) -> Option<AuthContext> {
        let slot = self.auth.lock().await;
        Option::clone(&slot)
    }

    /// Locks the pub/sub state and returns the guard; the lock is held until it is dropped.
    pub async fn pubsub_state(&self) -> tokio::sync::MutexGuard<'_, PubSubState> {
        let guard = self.pubsub.lock().await;
        guard
    }

    /// Replaces the stored client name with `name` (`None` clears it).
    pub async fn set_client_name(&self, name: Option<Bytes>) {
        let mut slot = self.client_name.lock().await;
        *slot = name;
    }

    /// Returns a clone of the stored client name, if any.
    pub async fn client_name(&self) -> Option<Bytes> {
        let slot = self.client_name.lock().await;
        Option::clone(&slot)
    }

    /// Atomically flips the poller flag from `false` to `true`.
    ///
    /// Returns `true` only for the caller that performed the flip; returns `false` when the
    /// flag was already set.
    pub fn try_activate_poller(&self) -> bool {
        let claimed = self.poller_active.compare_exchange(
            false,
            true,
            Ordering::AcqRel,
            Ordering::Acquire,
        );
        claimed.is_ok()
    }

    /// Clears the poller flag so a later `try_activate_poller` can succeed again.
    pub fn deactivate_poller(&self) {
        let flag = &self.poller_active;
        flag.store(false, Ordering::Release);
    }
}
/// Authentication details a session stores via `Session::set_auth`.
#[derive(Clone)]
pub struct AuthContext {
/// User name. NOTE(review): presumably the same string used as the pool-registry key — confirm.
pub user: String,
/// Tenant identifier bytes.
// NOTE(review): the `allow(dead_code)` suggests nothing reads this field yet — confirm.
#[allow(dead_code)]
pub tenant_id: Bytes,
/// Shared backend MySQL pool associated with this auth context.
pub pool: Arc<MySqlPool>,
}
/// Pub/sub bookkeeping for a session; the default value has no subscriber id and an empty
/// channel set.
#[derive(Default)]
pub struct PubSubState {
/// Optional subscriber id. NOTE(review): who assigns it is not visible in this file.
pub subscriber_id: Option<u64>,
/// Set of channel names. NOTE(review): presumably the channels currently subscribed — confirm.
pub channels: HashSet<Bytes>,
}
/// Returns the current wall-clock time as whole milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`. A value too large for `u64` saturates at
/// `u64::MAX` instead of being silently truncated by an `as` cast.
pub fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // Built from the `u64` seconds and `u32` sub-second parts so no narrowing cast is needed.
    since_epoch
        .as_secs()
        .saturating_mul(1_000)
        .saturating_add(u64::from(since_epoch.subsec_millis()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A session is visible after insertion and gone after removal.
    #[test]
    fn session_registry_roundtrip() {
        let sessions = SessionRegistry::new();
        let held = sessions.insert(42);
        assert!(sessions.get(42).is_some());

        sessions.remove(42);
        assert!(sessions.get(42).is_none());
        drop(held);
    }

    /// A TTL of zero prunes every registered pool.
    #[tokio::test]
    async fn pool_registry_prune() {
        let pools = BackendPoolRegistry::new();
        // `connect_lazy` is used so the test does not need a reachable database.
        let lazy_pool = MySqlPool::connect_lazy("mysql://root@localhost/test").unwrap();
        pools.insert_if_absent("user".to_string(), Arc::new(lazy_pool));

        let pruned = pools.prune_idle(0);
        assert_eq!(pruned, 1);
    }
}