redis-on-mysql 0.0.1

A Redis-compatible proxy that stores all data and Pub/Sub state in MySQL.
Documentation
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;
use std::sync::Arc;

use crate::handlers::util::{arg_as_bytes, arg_as_i64, invalid_integer, wrong_arity};
use crate::state::{AppState, Session, SessionHandle, now_ms};
use crate::storage::{delete_key_all, map_sql_err};

/// `EXPIRE key seconds`: gives `key` a time-to-live expressed in seconds.
///
/// Delegates to `expire_relative`, which works in milliseconds.
pub async fn expire(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Seconds -> milliseconds.
    let ms_per_unit: i64 = 1000;
    expire_relative(cmd, state, session, ms_per_unit, "EXPIRE").await
}

/// `PEXPIRE key milliseconds`: gives `key` a time-to-live expressed in milliseconds.
///
/// Delegates to `expire_relative`; no unit conversion is needed.
pub async fn pexpire(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Input is already in milliseconds.
    let ms_per_unit: i64 = 1;
    expire_relative(cmd, state, session, ms_per_unit, "PEXPIRE").await
}

/// `EXPIREAT key timestamp`: sets an absolute deadline for `key`, given in seconds
/// on the same clock as `now_ms`.
pub async fn expireat(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Seconds -> milliseconds.
    let ms_per_unit: i64 = 1000;
    expire_absolute(cmd, state, session, ms_per_unit, "EXPIREAT").await
}

/// `PEXPIREAT key timestamp`: sets an absolute deadline for `key`, given in
/// milliseconds on the same clock as `now_ms`.
pub async fn pexpireat(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Input is already in milliseconds.
    let ms_per_unit: i64 = 1;
    expire_absolute(cmd, state, session, ms_per_unit, "PEXPIREAT").await
}

/// `PERSIST key`: removes the deadline from `key`.
///
/// Replies with the number of rows the UPDATE reported. The update only touches a
/// row that currently has a deadline lying after "now". A missing key, an
/// already-expired key, or a key without a deadline therefore yields 0. The
/// reply is presumably at most 1, assuming `r_key` is unique (TODO confirm against
/// the schema).
///
/// # Errors
/// Arity error, `RespError::NoAuth` for unauthenticated sessions, argument
/// decoding errors, and mapped SQL errors.
pub async fn persist(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    const CLEAR_DEADLINE_SQL: &str = "UPDATE redis_kv SET expires_at_ms = NULL \
         WHERE r_key = ? AND expires_at_ms IS NOT NULL AND expires_at_ms > ?";

    if cmd.args.len() != 1 {
        return Err(wrong_arity("PERSIST"));
    }
    let auth = match session.auth().await {
        Some(auth) => auth,
        None => return Err(RespError::NoAuth),
    };
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    // Rows whose deadline has already passed are left alone so that an expired
    // key behaves as if it did not exist.
    let now = now_ms() as i64;
    let outcome = sqlx::query(CLEAR_DEADLINE_SQL)
        .bind(key.as_ref())
        .bind(now)
        .execute(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;
    let updated = outcome.rows_affected() as i64;
    Ok(Value::Integer(updated))
}

/// `TTL key`: remaining time-to-live of `key`, in seconds.
///
/// Delegates to `ttl_inner`, which divides the remaining milliseconds by the
/// given unit.
pub async fn ttl(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Milliseconds -> seconds.
    let ms_per_unit: i64 = 1000;
    ttl_inner(cmd, state, session, ms_per_unit, "TTL").await
}

/// `PTTL key`: remaining time-to-live of `key`, in milliseconds.
///
/// Delegates to `ttl_inner` with a unit of one millisecond.
pub async fn pttl(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Output stays in milliseconds.
    let ms_per_unit: i64 = 1;
    ttl_inner(cmd, state, session, ms_per_unit, "PTTL").await
}

/// Shared implementation of `EXPIRE` / `PEXPIRE`: sets a deadline relative to now.
///
/// `multiplier` converts the user-supplied TTL into milliseconds (1000 for
/// seconds, 1 for milliseconds). `name` is the command name used in the arity
/// error.
///
/// The absolute deadline is computed with checked arithmetic *before* deciding
/// what to do. This matches `expire_absolute`, and Redis, which reports
/// unit-conversion overflow even for negative TTLs. A deadline at or before
/// "now" (i.e. a non-positive TTL) deletes the key via `delete_key_if_present`.
/// Otherwise the deadline is stored via `expire_at_ms`.
///
/// # Errors
/// Arity error, `RespError::NoAuth` for unauthenticated sessions, argument
/// parse errors, `invalid_integer()` when the deadline overflows `i64`, and
/// mapped SQL errors.
async fn expire_relative(
    cmd: resp_async::Command,
    state: Arc<AppState>,
    session: Arc<Session>,
    multiplier: i64,
    name: &str,
) -> Result<Value, RespError> {
    if cmd.args.len() != 2 {
        return Err(wrong_arity(name));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let ttl = arg_as_i64(&cmd.args[1])?;
    // Take a single timestamp so the expiry decision and the stored deadline agree.
    let now = now_ms() as i64;
    let expires_at = ttl
        .checked_mul(multiplier)
        .and_then(|delta| now.checked_add(delta))
        .ok_or_else(invalid_integer)?;
    // With a positive multiplier this is exactly the `ttl <= 0` case.
    if expires_at <= now {
        return delete_key_if_present(auth.pool.as_ref(), key, now).await;
    }
    expire_at_ms(auth.pool.as_ref(), key, expires_at, now).await
}

async fn expire_absolute(
    cmd: resp_async::Command,
    state: Arc<AppState>,
    session: Arc<Session>,
    multiplier: i64,
    name: &str,
) -> Result<Value, RespError> {
    if cmd.args.len() != 2 {
        return Err(wrong_arity(name));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let ts = arg_as_i64(&cmd.args[1])?;
    let expires_at = ts.checked_mul(multiplier).ok_or_else(invalid_integer)?;
    let now = now_ms() as i64;
    if expires_at <= now {
        return delete_key_if_present(auth.pool.as_ref(), key, now).await;
    }
    expire_at_ms(auth.pool.as_ref(), key, expires_at, now).await
}

/// Stores `expires_at` (absolute milliseconds) as the deadline of `key`.
///
/// Only a row that is live at `now` is updated: one with no deadline, or a
/// deadline after `now`. An already-expired key is therefore treated as missing.
/// Replies with the affected-row count reported by the UPDATE.
///
/// NOTE(review): MySQL reports *changed* rows unless the client sets
/// CLIENT_FOUND_ROWS. Without that flag, re-applying an identical deadline would
/// reply 0. Confirm the driver's connection flags.
async fn expire_at_ms(
    pool: &sqlx::MySqlPool,
    key: &bytes::Bytes,
    expires_at: i64,
    now: i64,
) -> Result<Value, RespError> {
    const SET_DEADLINE_SQL: &str = "UPDATE redis_kv SET expires_at_ms = ? \
         WHERE r_key = ? AND (expires_at_ms IS NULL OR expires_at_ms > ?)";

    let outcome = sqlx::query(SET_DEADLINE_SQL)
        .bind(expires_at)
        .bind(key.as_ref())
        .bind(now)
        .execute(pool)
        .await;
    let affected = outcome.map_err(map_sql_err)?.rows_affected();
    Ok(Value::Integer(affected as i64))
}

/// Deletes `key` if it is live at `now`, replying 1. Replies 0 when the key is
/// missing or its deadline is at or before `now`.
///
/// NOTE(review): the existence check and `delete_key_all` run as separate
/// statements, with no transaction visible here. The reply can be stale under
/// concurrent writers.
async fn delete_key_if_present(
    pool: &sqlx::MySqlPool,
    key: &bytes::Bytes,
    now: i64,
) -> Result<Value, RespError> {
    const LIVE_KEY_SQL: &str =
        "SELECT r_key FROM redis_kv WHERE r_key = ? AND (expires_at_ms IS NULL OR expires_at_ms > ?)";

    let found = sqlx::query(LIVE_KEY_SQL)
        .bind(key.as_ref())
        .bind(now)
        .fetch_optional(pool)
        .await
        .map_err(map_sql_err)?;
    match found {
        None => Ok(Value::Integer(0)),
        Some(_) => {
            delete_key_all(pool, key).await?;
            Ok(Value::Integer(1))
        }
    }
}

/// Shared implementation of `TTL` / `PTTL`.
///
/// Replies:
/// * `-2` when the key has no row, or its deadline is at or before "now". In the
///   latter case the key is also removed via `delete_key_all`.
/// * `-1` when the key exists without a deadline.
/// * otherwise the remaining time converted to the caller's unit. `divisor` is
///   1000 for seconds and 1 for milliseconds.
///
/// Seconds are rounded to the nearest whole second, as Redis does with
/// `(ms + 500) / 1000`. Plain truncation made `EXPIRE k 10` followed at once by
/// `TTL k` report 9. Milliseconds are returned unchanged.
///
/// NOTE(review): the SELECT and the lazy delete are separate statements. If
/// another client rewrites the key in between, `delete_key_all` presumably
/// removes the fresh value too. Consider a conditional delete.
///
/// # Errors
/// Arity error, `RespError::NoAuth` for unauthenticated sessions, argument
/// decoding errors, and mapped SQL errors.
async fn ttl_inner(
    cmd: resp_async::Command,
    state: Arc<AppState>,
    session: Arc<Session>,
    divisor: i64,
    name: &str,
) -> Result<Value, RespError> {
    if cmd.args.len() != 1 {
        return Err(wrong_arity(name));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;
    let row = sqlx::query("SELECT expires_at_ms FROM redis_kv WHERE r_key = ?")
        .bind(key.as_ref())
        .fetch_optional(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;

    let Some(row) = row else {
        return Ok(Value::Integer(-2));
    };
    let expires_at: Option<i64> = row.try_get("expires_at_ms").map_err(map_sql_err)?;
    let Some(expires_at) = expires_at else {
        return Ok(Value::Integer(-1));
    };
    if expires_at <= now {
        delete_key_all(auth.pool.as_ref(), key).await?;
        return Ok(Value::Integer(-2));
    }
    // `expires_at > now` here, so the difference is positive and cannot overflow.
    Ok(Value::Integer(remaining_in_units(expires_at - now, divisor)))
}

/// Converts a positive remaining time in milliseconds into units of `divisor`
/// milliseconds, rounding half up. With `divisor == 1` the value is unchanged.
fn remaining_in_units(remaining_ms: i64, divisor: i64) -> i64 {
    debug_assert!(divisor > 0, "divisor must be positive");
    remaining_ms.saturating_add(divisor / 2) / divisor
}

#[cfg(test)]
mod tests {
    /// Sanity check of the millisecond arithmetic: integer division of a positive
    /// millisecond delta by 1000 truncates (1500 ms -> 1).
    #[test]
    fn ttl_division_floor() {
        let (now_ms, deadline_ms) = (1_000_i64, 2_500_i64);
        let remaining_ms = deadline_ms - now_ms;
        assert_eq!(remaining_ms / 1_000, 1);
    }
}