//! redis-on-mysql 0.0.1
//!
//! A Redis-compatible proxy that stores all data and Pub/Sub state in MySQL.
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};

use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{FromRequest, RequestContext, ServerHooks};
use sqlx::MySqlPool;
use tokio::sync::Mutex as AsyncMutex;

use crate::config::ProxyConfig;

/// Application-wide state shared with request handlers through the router.
///
/// Cloning is cheap: the config sits behind an `Arc`, and both registries are
/// themselves handles onto shared maps.
#[derive(Clone)]
pub struct AppState {
    pub config: Arc<ProxyConfig>,
    pub pools: BackendPoolRegistry,
    pub sessions: SessionRegistry,
}

impl AppState {
    /// Builds the state around `config` with empty pool and session registries.
    pub fn new(config: ProxyConfig) -> Self {
        let config = Arc::new(config);
        let pools = BackendPoolRegistry::new();
        let sessions = SessionRegistry::new();
        AppState {
            config,
            pools,
            sessions,
        }
    }

    /// Returns connection lifecycle hooks that operate on this state's session
    /// registry (the clone shares the same underlying map).
    pub fn hooks(&self) -> ProxyHooks {
        let sessions = SessionRegistry::clone(&self.sessions);
        ProxyHooks { sessions }
    }
}

/// Per-user MySQL pool registry with idle-eviction metadata.
///
/// Clones share the same underlying map, so every handle sees the same pools.
#[derive(Clone)]
pub struct BackendPoolRegistry {
    inner: Arc<Mutex<HashMap<String, PoolEntry>>>,
}

/// A cached pool plus the time (ms since the Unix epoch, as returned by
/// `now_ms`) at which it was last handed out or touched.
struct PoolEntry {
    pool: Arc<MySqlPool>,
    last_used_ms: AtomicU64,
}

impl Default for BackendPoolRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl BackendPoolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the pool cached for `user`, if any, and marks it as just used.
    pub fn get(&self, user: &str) -> Option<Arc<MySqlPool>> {
        let guard = self.inner.lock().unwrap();
        let entry = guard.get(user)?;
        entry.last_used_ms.store(now_ms(), Ordering::Release);
        Some(Arc::clone(&entry.pool))
    }

    /// Caches `pool` for `user` unless an entry already exists, then returns
    /// whichever pool is cached for `user` afterwards.
    ///
    /// If an entry is already present (e.g. another caller inserted first), the
    /// existing pool is returned and the `pool` argument is dropped. In both
    /// cases the entry's last-used timestamp is refreshed.
    pub fn insert_if_absent(&self, user: String, pool: Arc<MySqlPool>) -> Arc<MySqlPool> {
        let mut guard = self.inner.lock().unwrap();
        let now = now_ms();
        // Single hash lookup via the entry API instead of `get` followed by `insert`.
        let entry = guard.entry(user).or_insert_with(|| PoolEntry {
            pool,
            last_used_ms: AtomicU64::new(now),
        });
        // Refresh on both paths; for a freshly inserted entry this is a no-op.
        entry.last_used_ms.store(now, Ordering::Release);
        Arc::clone(&entry.pool)
    }

    /// Marks the pool cached for `user` as just used; does nothing if absent.
    pub fn touch(&self, user: &str) {
        let guard = self.inner.lock().unwrap();
        if let Some(entry) = guard.get(user) {
            entry.last_used_ms.store(now_ms(), Ordering::Release);
        }
    }

    /// Returns handles to every cached pool without touching their timestamps.
    pub fn snapshot_pools(&self) -> Vec<Arc<MySqlPool>> {
        let guard = self.inner.lock().unwrap();
        guard
            .values()
            .map(|entry| Arc::clone(&entry.pool))
            .collect()
    }

    /// Evicts entries whose last use is older than `idle_ttl_ms` milliseconds
    /// and returns how many were removed.
    ///
    /// An `idle_ttl_ms` of `0` evicts every entry. Eviction only releases the
    /// registry's `Arc`; pools still referenced elsewhere (for example by an
    /// `AuthContext`) stay alive until those references are dropped.
    pub fn prune_idle(&self, idle_ttl_ms: u64) -> usize {
        let mut guard = self.inner.lock().unwrap();
        let before = guard.len();
        if idle_ttl_ms == 0 {
            // With a zero TTL the cutoff would equal "now", so entries touched in
            // the current millisecond would survive; treat 0 as "evict all".
            guard.clear();
        } else {
            let cutoff = now_ms().saturating_sub(idle_ttl_ms);
            guard.retain(|_, entry| entry.last_used_ms.load(Ordering::Acquire) >= cutoff);
        }
        before - guard.len()
    }
}

/// Registry of live client sessions, keyed by the connection's client id.
///
/// Clones share the same underlying map.
#[derive(Clone)]
pub struct SessionRegistry {
    inner: Arc<Mutex<HashMap<u64, Arc<Session>>>>,
}

impl Default for SessionRegistry {
    fn default() -> Self {
        SessionRegistry::new()
    }
}

impl SessionRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        let map = HashMap::new();
        SessionRegistry {
            inner: Arc::new(Mutex::new(map)),
        }
    }

    /// Creates a fresh session for `client_id`, registers it, and returns a
    /// handle to it. Any session previously registered under the same id is
    /// replaced.
    pub fn insert(&self, client_id: u64) -> Arc<Session> {
        let fresh = Arc::new(Session::new());
        let mut sessions = self.inner.lock().unwrap();
        sessions.insert(client_id, Arc::clone(&fresh));
        fresh
    }

    /// Unregisters the session for `client_id`; unknown ids are ignored.
    pub fn remove(&self, client_id: u64) {
        let mut sessions = self.inner.lock().unwrap();
        sessions.remove(&client_id);
    }

    /// Returns a handle to the session registered for `client_id`, if any.
    pub fn get(&self, client_id: u64) -> Option<Arc<Session>> {
        let sessions = self.inner.lock().unwrap();
        sessions.get(&client_id).cloned()
    }
}

/// Connection lifecycle hooks that keep the session registry in step with
/// open client connections.
#[derive(Clone)]
pub struct ProxyHooks {
    sessions: SessionRegistry,
}

impl ServerHooks for ProxyHooks {
    /// Registers a fresh session under the new connection's id.
    fn on_connection_open(&self, info: resp_async::ConnectionInfo) {
        let client_id = info.id;
        // Handlers look the session up by id later, so the returned handle is not kept.
        drop(self.sessions.insert(client_id));
    }

    /// Forgets the session that belonged to the closed connection.
    fn on_connection_close(&self, info: resp_async::ConnectionInfo) {
        let client_id = info.id;
        self.sessions.remove(client_id);
    }
}

/// Handler extractor that resolves the calling connection's session.
#[derive(Clone)]
pub struct SessionHandle(pub Arc<Session>);

impl FromRequest<AppState> for SessionHandle {
    type Rejection = RespError;

    /// Looks up the session registered for `ctx.client_id` in `state.sessions`.
    ///
    /// Rejects with `RespError::internal()` when no session is registered for
    /// that id.
    async fn from_request(
        ctx: &mut RequestContext,
        state: &Arc<AppState>,
    ) -> Result<Self, Self::Rejection> {
        let Some(session) = state.sessions.get(ctx.client_id) else {
            return Err(RespError::internal());
        };
        Ok(SessionHandle(session))
    }
}

/// State owned by a single client connection.
///
/// Authentication, Pub/Sub bookkeeping and the client name each sit behind
/// their own async mutex, so locking one does not wait on another. The poller
/// flag is a plain atomic.
pub struct Session {
    auth: AsyncMutex<Option<AuthContext>>,
    pubsub: AsyncMutex<PubSubState>,
    client_name: AsyncMutex<Option<Bytes>>,
    poller_active: AtomicBool,
}

impl Default for Session {
    fn default() -> Self {
        Session::new()
    }
}

impl Session {
    /// Creates a session with no authentication, default Pub/Sub state, no
    /// client name, and an inactive poller.
    pub fn new() -> Self {
        Session {
            auth: AsyncMutex::new(None),
            pubsub: AsyncMutex::new(Default::default()),
            client_name: AsyncMutex::new(None),
            poller_active: AtomicBool::new(false),
        }
    }

    /// Stores `auth`, replacing any previously stored authentication.
    pub async fn set_auth(&self, auth: AuthContext) {
        let mut slot = self.auth.lock().await;
        *slot = Some(auth);
    }

    /// Returns a clone of the stored authentication, if any.
    pub async fn auth(&self) -> Option<AuthContext> {
        let slot = self.auth.lock().await;
        Option::clone(&slot)
    }

    /// Locks the Pub/Sub state and returns the guard; the lock is held until
    /// the guard is dropped.
    pub async fn pubsub_state(&self) -> tokio::sync::MutexGuard<'_, PubSubState> {
        self.pubsub.lock().await
    }

    /// Sets the client name, or clears it when `name` is `None`.
    pub async fn set_client_name(&self, name: Option<Bytes>) {
        let mut slot = self.client_name.lock().await;
        *slot = name;
    }

    /// Returns a clone of the stored client name, if any.
    pub async fn client_name(&self) -> Option<Bytes> {
        let slot = self.client_name.lock().await;
        Option::clone(&slot)
    }

    /// Atomically claims the poller flag.
    ///
    /// Returns `true` only for the caller that flips it from inactive to
    /// active; all other callers get `false` until `deactivate_poller` runs.
    pub fn try_activate_poller(&self) -> bool {
        let outcome = self.poller_active.compare_exchange(
            false,
            true,
            Ordering::AcqRel,
            Ordering::Acquire,
        );
        matches!(outcome, Ok(_))
    }

    /// Clears the poller flag so a later `try_activate_poller` can succeed.
    pub fn deactivate_poller(&self) {
        self.poller_active.store(false, Ordering::Release);
    }
}

/// Authenticated user context stored in a session (via `Session::set_auth`).
#[derive(Clone)]
pub struct AuthContext {
    /// Name of the authenticated user.
    pub user: String,
    /// Tenant identifier associated with the user.
    // NOTE(review): `dead_code` is allowed because this field is not read in this
    // module; confirm whether any other module consumes it before removing it.
    #[allow(dead_code)]
    pub tenant_id: Bytes,
    /// MySQL pool for this session's backend queries. Presumably the same pool
    /// cached in `BackendPoolRegistry` for `user` — confirm with the auth handler.
    pub pool: Arc<MySqlPool>,
}

/// Pub/Sub bookkeeping for one connection: its subscriber id and the set of
/// channels it is subscribed to. The default value has no id and no channels.
#[derive(Default)]
pub struct PubSubState {
    /// Subscriber identifier, `None` until one is assigned (assumed to be
    /// allocated by the Pub/Sub handlers — confirm there).
    pub subscriber_id: Option<u64>,
    /// Channel names this connection is currently subscribed to.
    pub channels: HashSet<Bytes>,
}

/// Current Unix time in milliseconds.
///
/// Returns `0` if the system clock reports a time before the Unix epoch, and
/// saturates at `u64::MAX` rather than silently truncating the `u128`
/// millisecond count.
pub fn now_ms() -> u64 {
    let elapsed = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pool that does not connect until first use.
    fn lazy_pool() -> Arc<MySqlPool> {
        Arc::new(MySqlPool::connect_lazy("mysql://root@localhost/test").unwrap())
    }

    #[test]
    fn session_registry_roundtrip() {
        let registry = SessionRegistry::new();
        let session = registry.insert(42);
        let fetched = registry.get(42).expect("session registered under id 42");
        assert!(Arc::ptr_eq(&session, &fetched));
        registry.remove(42);
        assert!(registry.get(42).is_none());
    }

    #[test]
    fn session_registry_unknown_id() {
        assert!(SessionRegistry::new().get(7).is_none());
    }

    #[test]
    fn session_poller_flag_is_exclusive() {
        let session = Session::new();
        assert!(session.try_activate_poller());
        assert!(!session.try_activate_poller());
        session.deactivate_poller();
        assert!(session.try_activate_poller());
    }

    #[tokio::test]
    async fn pool_registry_insert_if_absent_keeps_first() {
        let registry = BackendPoolRegistry::new();
        let first = lazy_pool();
        let kept = registry.insert_if_absent("user".to_string(), Arc::clone(&first));
        assert!(Arc::ptr_eq(&kept, &first));
        let again = registry.insert_if_absent("user".to_string(), lazy_pool());
        assert!(Arc::ptr_eq(&again, &first));
        let fetched = registry.get("user").expect("pool cached for user");
        assert!(Arc::ptr_eq(&fetched, &first));
    }

    #[tokio::test]
    async fn pool_registry_prune() {
        let registry = BackendPoolRegistry::new();
        registry.insert_if_absent("user".to_string(), lazy_pool());
        // A freshly used entry survives a generous TTL.
        assert_eq!(registry.prune_idle(60_000), 0);
        assert_eq!(registry.snapshot_pools().len(), 1);
        // A zero TTL evicts everything.
        assert_eq!(registry.prune_idle(0), 1);
        assert!(registry.snapshot_pools().is_empty());
        assert!(registry.get("user").is_none());
    }
}