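//! redis-on-mysql 0.0.1
//!
//! A Redis-compatible proxy that stores all data and Pub/Sub state in MySQL.
//!
//! This module implements the keyspace commands: TYPE, KEYS, SCAN, RANDOMKEY,
//! RENAME, RENAMENX, DBSIZE, FLUSHDB, FLUSHALL and UNLINK.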
use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;
use std::sync::Arc;

use crate::handlers::strings;
use crate::handlers::util::{
    arg_as_bytes, arg_as_i64, glob_has_wildcards, glob_match, ok, random_index, wrong_arity,
};
use crate::state::{AppState, SessionHandle, now_ms};
use crate::storage::{
    TYPE_HASH, TYPE_LIST, TYPE_SET, TYPE_ZSET, delete_key_in_tx, load_meta, map_sql_err, type_name,
};

/// `TYPE key`: replies with the stored type name of `key` as a simple string,
/// or `none` when `load_meta` finds no entry for it.
///
/// `load_meta` is given the current time, presumably so that it skips expired keys.
pub async fn type_cmd(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() != 1 {
        return Err(wrong_arity("TYPE"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let meta = load_meta(auth.pool.as_ref(), key, now_ms() as i64).await?;
    let name = meta.map(|m| type_name(m.r_type)).unwrap_or(b"none");
    Ok(Value::Simple(Bytes::from_static(name)))
}

/// `KEYS pattern`: returns every live key matching the glob `pattern`.
///
/// A pattern with no wildcards is answered by one `load_meta` point lookup.
/// Any other pattern loads every unexpired `r_key` from `redis_kv` and matches
/// it in process. That path scales with the size of the whole keyspace.
pub async fn keys(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() != 1 {
        return Err(wrong_arity("KEYS"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let pattern = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;

    if glob_has_wildcards(pattern.as_ref()) {
        let rows = sqlx::query(
            "SELECT r_key FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ?)",
        )
        .bind(now)
        .fetch_all(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;

        let mut matching = Vec::new();
        for row in &rows {
            let candidate: Vec<u8> = row.try_get("r_key").map_err(map_sql_err)?;
            if !glob_match(pattern.as_ref(), &candidate) {
                continue;
            }
            matching.push(Value::Bulk(Bytes::from(candidate)));
        }
        Ok(Value::Array(matching))
    } else {
        // A literal pattern names exactly one key, so a point lookup suffices.
        let hits = match load_meta(auth.pool.as_ref(), pattern, now).await? {
            Some(_) => vec![Value::Bulk(pattern.clone())],
            None => Vec::new(),
        };
        Ok(Value::Array(hits))
    }
}

pub async fn scan(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.is_empty() {
        return Err(wrong_arity("SCAN"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);

    let mut cursor = arg_as_i64(&cmd.args[0])?;
    if cursor < 0 {
        return Err(RespError::invalid_data("ERR invalid cursor"));
    }
    let mut count = 10i64;
    let mut pattern: Option<Bytes> = None;

    let mut i = 1;
    while i < cmd.args.len() {
        let mut token = arg_as_bytes(&cmd.args[i])?.to_vec();
        for b in &mut token {
            b.make_ascii_uppercase();
        }
        match token.as_slice() {
            b"MATCH" => {
                let next = cmd
                    .args
                    .get(i + 1)
                    .ok_or_else(|| RespError::invalid_data("ERR syntax error"))?;
                pattern = Some(arg_as_bytes(next)?.clone());
                i += 2;
            }
            b"COUNT" => {
                let next = cmd
                    .args
                    .get(i + 1)
                    .ok_or_else(|| RespError::invalid_data("ERR syntax error"))?;
                count = arg_as_i64(next)?;
                if count <= 0 {
                    return Err(RespError::invalid_data("ERR invalid COUNT"));
                }
                i += 2;
            }
            _ => return Err(RespError::invalid_data("ERR syntax error")),
        }
    }

    let now = now_ms() as i64;
    let rows = sqlx::query(
        "SELECT r_key FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ?) \
         ORDER BY r_key LIMIT ?, ?",
    )
    .bind(now)
    .bind(cursor)
    .bind(count)
    .fetch_all(auth.pool.as_ref())
    .await
    .map_err(map_sql_err)?;

    let mut out = Vec::new();
    for row in &rows {
        let key: Vec<u8> = row.try_get("r_key").map_err(map_sql_err)?;
        if let Some(pattern) = &pattern
            && !glob_match(pattern.as_ref(), &key)
        {
            continue;
        }
        out.push(Value::Bulk(Bytes::from(key)));
    }

    if rows.len() < count as usize {
        cursor = 0;
    } else {
        cursor = cursor.saturating_add(count);
    }

    Ok(Value::Array(vec![
        Value::Bulk(Bytes::from(cursor.to_string())),
        Value::Array(out),
    ]))
}

/// `RANDOMKEY`: returns one live key, or nil when there are none.
///
/// The live keys are counted first. `random_index` then picks an offset into
/// the `r_key` ordering, and the key at that offset is fetched.
pub async fn randomkey(
    _: Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let now = now_ms() as i64;

    let live: i64 = sqlx::query(
        "SELECT COUNT(*) AS count FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ?)",
    )
    .bind(now)
    .fetch_one(auth.pool.as_ref())
    .await
    .map_err(map_sql_err)?
    .try_get("count")
    .map_err(map_sql_err)?;
    if live <= 0 {
        return Ok(Value::Null);
    }

    // The count and the fetch are separate statements, not one transaction.
    // If keys disappear in between, the offset can overshoot and the reply
    // falls back to nil.
    let picked = sqlx::query(
        "SELECT r_key FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ?) \
         ORDER BY r_key LIMIT 1 OFFSET ?",
    )
    .bind(now)
    .bind(random_index(live))
    .fetch_optional(auth.pool.as_ref())
    .await
    .map_err(map_sql_err)?;

    match picked {
        Some(row) => {
            let key: Vec<u8> = row.try_get("r_key").map_err(map_sql_err)?;
            Ok(Value::Bulk(Bytes::from(key)))
        }
        None => Ok(Value::Null),
    }
}

/// `RENAME key newkey`: moves `key` to `newkey`, replacing any existing
/// `newkey`. Replies `+OK`.
pub async fn rename(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(handle): SessionHandle,
) -> Result<Value, RespError> {
    // RENAME may clobber the destination.
    let allow_overwrite = true;
    rename_inner(command, app, handle, allow_overwrite).await
}

/// `RENAMENX key newkey`: moves `key` to `newkey` only if `newkey` does not
/// already exist. Replies `1` on success and `0` otherwise.
pub async fn renamenx(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(handle): SessionHandle,
) -> Result<Value, RespError> {
    // RENAMENX must never clobber the destination.
    let allow_overwrite = false;
    rename_inner(command, app, handle, allow_overwrite).await
}

async fn rename_inner(
    cmd: resp_async::Command,
    state: Arc<AppState>,
    session: Arc<crate::state::Session>,
    allow_overwrite: bool,
) -> Result<Value, RespError> {
    if cmd.args.len() != 2 {
        return Err(wrong_arity(if allow_overwrite {
            "RENAME"
        } else {
            "RENAMENX"
        }));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let src = arg_as_bytes(&cmd.args[0])?;
    let dest = arg_as_bytes(&cmd.args[1])?;
    if src == dest {
        return Ok(if allow_overwrite {
            ok()
        } else {
            Value::Integer(0)
        });
    }
    let now = now_ms() as i64;
    let src_meta = load_meta(auth.pool.as_ref(), src, now).await?;
    let Some(src_meta) = src_meta else {
        return Err(RespError::invalid_data("ERR no such key"));
    };
    let dest_meta = load_meta(auth.pool.as_ref(), dest, now).await?;
    if dest_meta.is_some() && !allow_overwrite {
        return Ok(Value::Integer(0));
    }

    let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
    if dest_meta.is_some() {
        delete_key_in_tx(&mut tx, dest).await?;
    }

    sqlx::query("UPDATE redis_kv SET r_key = ? WHERE r_key = ?")
        .bind(dest.as_ref())
        .bind(src.as_ref())
        .execute(&mut *tx)
        .await
        .map_err(map_sql_err)?;

    match src_meta.r_type {
        TYPE_HASH => update_key_in_tx(&mut tx, "redis_hash", src, dest).await?,
        TYPE_SET => update_key_in_tx(&mut tx, "redis_set", src, dest).await?,
        TYPE_ZSET => update_key_in_tx(&mut tx, "redis_zset", src, dest).await?,
        TYPE_LIST => {
            update_key_in_tx(&mut tx, "redis_list_meta", src, dest).await?;
            update_key_in_tx(&mut tx, "redis_list", src, dest).await?;
        }
        _ => {}
    }

    tx.commit().await.map_err(map_sql_err)?;

    Ok(if allow_overwrite {
        ok()
    } else {
        Value::Integer(1)
    })
}

/// `DBSIZE`: returns the number of `redis_kv` rows that either have no expiry
/// or have not yet expired.
pub async fn dbsize(
    _: Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let live: i64 = sqlx::query(
        "SELECT COUNT(*) AS count FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ?)",
    )
    .bind(now_ms() as i64)
    .fetch_one(auth.pool.as_ref())
    .await
    .map_err(map_sql_err)?
    .try_get("count")
    .map_err(map_sql_err)?;
    Ok(Value::Integer(live))
}

/// `FLUSHDB`: clears exactly the same data as `FLUSHALL`, since both delegate to `flush_all`.
pub async fn flushdb(
    _: Cmd,
    State(app): State<AppState>,
    SessionHandle(handle): SessionHandle,
) -> Result<Value, RespError> {
    flush_all(app, handle).await
}

/// `FLUSHALL`: deletes every stored key and its member rows, replying `+OK`.
pub async fn flushall(
    _: Cmd,
    State(app): State<AppState>,
    SessionHandle(handle): SessionHandle,
) -> Result<Value, RespError> {
    flush_all(app, handle).await
}

/// Empties every data table in one transaction and replies `+OK`.
///
/// The child tables (hash/set/zset/list members) are cleared before
/// `redis_kv`. Any SQL failure aborts the whole flush, because the
/// uncommitted transaction is dropped.
async fn flush_all(
    state: Arc<AppState>,
    session: Arc<crate::state::Session>,
) -> Result<Value, RespError> {
    const STATEMENTS: [&str; 6] = [
        "DELETE FROM redis_hash",
        "DELETE FROM redis_set",
        "DELETE FROM redis_zset",
        "DELETE FROM redis_list",
        "DELETE FROM redis_list_meta",
        "DELETE FROM redis_kv",
    ];

    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
    for statement in STATEMENTS {
        sqlx::query(statement)
            .execute(&mut *tx)
            .await
            .map_err(map_sql_err)?;
    }
    tx.commit().await.map_err(map_sql_err)?;
    Ok(ok())
}

/// `UNLINK key [key ...]`: delegates to `strings::del` with the same
/// extractors, so it replies exactly as `DEL` does.
pub async fn unlink(
    cmd: Cmd,
    state: State<AppState>,
    session: SessionHandle,
) -> Result<Value, RespError> {
    strings::del(cmd, state, session).await
}

/// Rewrites `r_key` from `src` to `dest` for every row of `table`, inside `tx`.
///
/// `table` is spliced into the SQL text, so callers must pass one of this
/// module's fixed table names and never client input.
async fn update_key_in_tx(
    tx: &mut sqlx::Transaction<'_, sqlx::MySql>,
    table: &str,
    src: &Bytes,
    dest: &Bytes,
) -> Result<(), RespError> {
    let statement = format!("UPDATE {table} SET r_key = ? WHERE r_key = ?");
    sqlx::query(&statement)
        .bind(dest.as_ref())
        .bind(src.as_ref())
        .execute(&mut **tx)
        .await
        .map(|_| ())
        .map_err(map_sql_err)
}