redis-on-mysql 0.0.1

A Redis-compatible proxy that stores all data and Pub/Sub state in MySQL.
Documentation
use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;

use crate::handlers::util::{arg_as_bytes, arg_as_i64, wrong_arity, wrong_type};
use crate::state::{AppState, SessionHandle, now_ms};
use crate::storage::{TYPE_STRING, delete_key_all, map_sql_err};

/// Outcome of looking up a key that is expected to hold a string value.
enum StringRow {
    /// No row exists for the key, or the row had expired and was deleted.
    Missing,
    /// A row exists, but its `r_type` is not `TYPE_STRING`.
    WrongType,
    /// A string row.
    Value {
        /// Raw bytes read from `r_value`; a NULL column is read as empty.
        data: Vec<u8>,
        /// Value of `expires_at_ms`; `None` when the column is NULL.
        expires_at: Option<i64>,
    },
}

/// Handles `GETBIT key offset`.
///
/// Replies with the bit stored at `offset`, where bit 0 is the most
/// significant bit of the first byte. A missing key, or an offset past the
/// end of the stored value, reads as `0`.
///
/// # Errors
///
/// Returns an error on wrong arity, when the session is not authenticated,
/// when an argument cannot be parsed, when `offset` is negative, or when the
/// lookup fails. A key holding a non-string type produces the wrong-type
/// reply rather than an `Err`.
pub async fn getbit(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() != 2 {
        return Err(wrong_arity("GETBIT"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let offset = arg_as_i64(&cmd.args[1])?;
    if offset < 0 {
        return Err(RespError::invalid_data(
            "ERR bit offset is not an integer or out of range",
        ));
    }

    let now = now_ms() as i64;
    let stored = match fetch_string(auth.pool.as_ref(), key, now).await? {
        StringRow::WrongType => return Ok(wrong_type()),
        StringRow::Missing => return Ok(Value::Integer(0)),
        StringRow::Value { data, .. } => data,
    };

    // Bits are numbered from the most significant bit of each byte, so bit
    // offset 0 maps to a right-shift of 7.
    let shift = 7 - (offset % 8) as u8;
    let bit = stored
        .get((offset / 8) as usize)
        .map_or(0, |&byte| (byte >> shift) & 1);
    Ok(Value::Integer(i64::from(bit)))
}

/// Handles `SETBIT key offset value`.
///
/// Sets or clears the bit at `offset` (bit 0 is the most significant bit of
/// the first byte), zero-extending the stored value when the offset lies past
/// its current end, and replies with the bit's previous value. An existing
/// expiry on the key is carried over unchanged.
///
/// # Errors
///
/// Returns an error on wrong arity, when the session is not authenticated,
/// when an argument cannot be parsed, when `offset` is outside `0..2^32`,
/// when `value` is neither `0` nor `1`, or when a database statement fails.
/// A key holding a non-string type produces the wrong-type reply rather than
/// an `Err`.
pub async fn setbit(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Redis requires `0 <= offset < 2^32`, which caps a bitmap at 512 MiB.
    // Without the upper bound a single request could make `resize` below
    // attempt an allocation of up to 2^60 bytes.
    const MAX_BIT_OFFSET: i64 = (1 << 32) - 1;

    if cmd.args.len() != 3 {
        return Err(wrong_arity("SETBIT"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;

    let offset = arg_as_i64(&cmd.args[1])?;
    if !(0..=MAX_BIT_OFFSET).contains(&offset) {
        return Err(RespError::invalid_data(
            "ERR bit offset is not an integer or out of range",
        ));
    }
    let bit_value = arg_as_i64(&cmd.args[2])?;
    if bit_value != 0 && bit_value != 1 {
        return Err(RespError::invalid_data(
            "ERR bit is not an integer or out of range",
        ));
    }

    let now = now_ms() as i64;
    let (mut data, expires_at) = match fetch_string(auth.pool.as_ref(), key, now).await? {
        StringRow::Missing => (Vec::new(), None),
        StringRow::WrongType => return Ok(wrong_type()),
        StringRow::Value { data, expires_at } => (data, expires_at),
    };

    // Grow with zero bytes so the addressed byte exists.
    let byte_index = (offset / 8) as usize;
    if data.len() <= byte_index {
        data.resize(byte_index + 1, 0);
    }
    let mask = 1u8 << (7 - (offset % 8) as u8);
    let prev = u8::from(data[byte_index] & mask != 0);
    if bit_value == 0 {
        data[byte_index] &= !mask;
    } else {
        data[byte_index] |= mask;
    }

    // NOTE: the read above and this write are separate statements, so two
    // concurrent SETBITs on the same key can overwrite each other's change.
    sqlx::query(
        "INSERT INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) VALUES (?, ?, ?, ?, ?) \
         ON DUPLICATE KEY UPDATE r_value = VALUES(r_value), r_len = VALUES(r_len), expires_at_ms = VALUES(expires_at_ms)",
    )
    .bind(key.as_ref())
    .bind(TYPE_STRING)
    .bind(data.as_slice())
    .bind(data.len() as i64)
    .bind(expires_at)
    .execute(auth.pool.as_ref())
    .await
    .map_err(map_sql_err)?;

    Ok(Value::Integer(i64::from(prev)))
}

/// Handles `BITCOUNT key [start end]`.
///
/// Replies with the number of set bits in the stored value, optionally
/// restricted to the inclusive byte range `start..=end`. Negative indexes
/// count from the end of the value. A missing key counts as `0`.
///
/// # Errors
///
/// Returns an error on wrong arity, when the session is not authenticated,
/// when an argument cannot be parsed, or when the lookup fails. A key holding
/// a non-string type produces the wrong-type reply rather than an `Err`.
pub async fn bitcount(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let argc = cmd.args.len();
    if !matches!(argc, 1 | 3) {
        return Err(wrong_arity("BITCOUNT"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;

    let bytes = match fetch_string(auth.pool.as_ref(), key, now).await? {
        StringRow::Value { data, .. } => data,
        StringRow::WrongType => return Ok(wrong_type()),
        StringRow::Missing => return Ok(Value::Integer(0)),
    };

    // The range arguments are only parsed once the key is known to hold a
    // string; without them the whole value is counted.
    let (start, end) = match argc {
        3 => (arg_as_i64(&cmd.args[1])?, arg_as_i64(&cmd.args[2])?),
        _ => (0, bytes.len() as i64 - 1),
    };

    let (_, window) = slice_range(&bytes, start, end);
    let ones: u64 = window.iter().map(|b| u64::from(b.count_ones())).sum();
    Ok(Value::Integer(ones as i64))
}

/// Handles `BITPOS key bit [start [end]]`.
///
/// Replies with the position of the first bit equal to `bit` inside the
/// inclusive byte range `start..=end` (negative indexes count from the end).
/// Positions are absolute bit offsets into the value, with bit 0 being the
/// most significant bit of the first byte.
///
/// When nothing matches:
/// * searching for a set bit replies `-1`;
/// * searching for a clear bit without an explicit `end` replies with the
///   first bit position past the searched bytes, as if the value were padded
///   with zeros on the right;
/// * searching for a clear bit with an explicit `end` replies `-1`, because
///   the caller bounded the range and it contains no zero bit;
/// * an empty range replies `-1`, since it contains neither a 0 nor a 1.
///
/// A missing key replies `0` when searching for a clear bit and `-1`
/// otherwise.
///
/// # Errors
///
/// Returns an error on wrong arity, when the session is not authenticated,
/// when an argument cannot be parsed, when `bit` is neither `0` nor `1`, or
/// when the lookup fails. A key holding a non-string type produces the
/// wrong-type reply rather than an `Err`.
pub async fn bitpos(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() < 2 || cmd.args.len() > 4 {
        return Err(wrong_arity("BITPOS"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let bit = arg_as_i64(&cmd.args[1])?;
    if bit != 0 && bit != 1 {
        return Err(RespError::invalid_data(
            "ERR bit is not an integer or out of range",
        ));
    }
    let now = now_ms() as i64;
    let data = match fetch_string(auth.pool.as_ref(), key, now).await? {
        StringRow::Missing => {
            return Ok(Value::Integer(if bit == 0 { 0 } else { -1 }));
        }
        StringRow::WrongType => return Ok(wrong_type()),
        StringRow::Value { data, .. } => data,
    };

    let end_given = cmd.args.len() == 4;
    let start = if cmd.args.len() >= 3 {
        arg_as_i64(&cmd.args[2])?
    } else {
        0
    };
    let end = if end_given {
        arg_as_i64(&cmd.args[3])?
    } else {
        data.len() as i64 - 1
    };

    let (start_idx, slice) = slice_range(&data, start, end);
    if slice.is_empty() {
        // An empty range contains neither a 0 nor a 1.
        return Ok(Value::Integer(-1));
    }

    // Count the leading bits that do NOT match; fewer than 8 means this byte
    // holds the first match at that in-byte position.
    let want_set = bit == 1;
    let found = slice.iter().enumerate().find_map(|(i, &byte)| {
        let skipped = if want_set {
            byte.leading_zeros()
        } else {
            byte.leading_ones()
        };
        (skipped < 8).then(|| ((start_idx + i) as i64) * 8 + i64::from(skipped))
    });

    let reply = match found {
        Some(pos) => pos,
        // All ones and no explicit end: the value is treated as zero-padded
        // on the right, so the first clear bit is just past the scanned bytes.
        None if !want_set && !end_given => ((start_idx + slice.len()) as i64) * 8,
        None => -1,
    };
    Ok(Value::Integer(reply))
}

pub async fn bitop(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() < 3 {
        return Err(wrong_arity("BITOP"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let op = arg_as_bytes(&cmd.args[0])?;
    let dest = arg_as_bytes(&cmd.args[1])?;
    let keys = cmd
        .args
        .iter()
        .skip(2)
        .map(arg_as_bytes)
        .collect::<Result<Vec<_>, _>>()?;
    let mut op_upper = op.to_vec();
    for b in &mut op_upper {
        b.make_ascii_uppercase();
    }

    if op_upper.as_slice() == b"NOT" && keys.len() != 1 {
        return Err(wrong_arity("BITOP"));
    }

    let now = now_ms() as i64;
    let mut values = Vec::new();
    let mut max_len = 0usize;
    for key in &keys {
        match fetch_string(auth.pool.as_ref(), key, now).await? {
            StringRow::Missing => values.push(Vec::new()),
            StringRow::WrongType => return Ok(wrong_type()),
            StringRow::Value { data, .. } => {
                max_len = max_len.max(data.len());
                values.push(data);
            }
        }
    }

    let result = match op_upper.as_slice() {
        b"NOT" => {
            let src = &values[0];
            let mut out = Vec::with_capacity(src.len());
            for &b in src {
                out.push(!b);
            }
            out
        }
        b"AND" | b"OR" | b"XOR" => {
            let mut out = vec![0u8; max_len];
            for i in 0..max_len {
                let mut acc = if op_upper.as_slice() == b"AND" {
                    0xFF
                } else {
                    0x00
                };
                for value in &values {
                    let b = if i < value.len() { value[i] } else { 0 };
                    acc = match op_upper.as_slice() {
                        b"AND" => acc & b,
                        b"OR" => acc | b,
                        _ => acc ^ b,
                    };
                }
                out[i] = acc;
            }
            out
        }
        _ => return Err(RespError::invalid_data("ERR syntax error")),
    };

    delete_key_all(auth.pool.as_ref(), dest).await?;
    sqlx::query(
        "INSERT INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) VALUES (?, ?, ?, ?, NULL)",
    )
    .bind(dest.as_ref())
    .bind(TYPE_STRING)
    .bind(result.as_slice())
    .bind(result.len() as i64)
    .execute(auth.pool.as_ref())
    .await
    .map_err(map_sql_err)?;

    Ok(Value::Integer(result.len() as i64))
}

/// Resolves an inclusive `start..=end` byte range against `data`.
///
/// Negative indexes count back from the end of `data` (`-1` is the last
/// byte). `start` is clamped up to `0` and `end` is clamped down to the last
/// index. Returns the resolved start index together with the selected bytes.
/// Empty `data`, or a range that resolves to nothing, yields `(0, &[])`.
fn slice_range(data: &[u8], start: i64, end: i64) -> (usize, &[u8]) {
    let len = data.len() as i64;
    if len == 0 {
        return (0, &[]);
    }

    // Translate a from-the-end index into a from-the-start one.
    let resolve = |idx: i64| if idx < 0 { idx + len } else { idx };
    let first = resolve(start).max(0);
    let last = resolve(end).min(len - 1);

    // `first > last` also covers an `end` that is still negative after
    // resolution, because `first` is never below zero here.
    if first >= len || first > last {
        return (0, &[]);
    }

    let (first, last) = (first as usize, last as usize);
    (first, &data[first..=last])
}

/// Loads the row for `key` from `redis_kv` and classifies it as a string
/// lookup result.
///
/// * No row: [`StringRow::Missing`].
/// * A row whose `expires_at_ms` is at or before `now`: the key is deleted
///   via `delete_key_all` and [`StringRow::Missing`] is returned. This check
///   runs before the type check, so an expired key of another type is purged
///   and treated as absent instead of being reported as a type mismatch.
/// * A live row whose `r_type` is not `TYPE_STRING`: [`StringRow::WrongType`].
/// * Otherwise [`StringRow::Value`] with the bytes of `r_value` (a NULL column
///   is read as empty) and the row's expiry.
///
/// # Errors
///
/// Database and column-decoding failures are converted with `map_sql_err`;
/// errors from `delete_key_all` are propagated unchanged.
async fn fetch_string(
    pool: &sqlx::MySqlPool,
    key: &Bytes,
    now: i64,
) -> Result<StringRow, RespError> {
    let row = sqlx::query("SELECT r_type, r_value, expires_at_ms FROM redis_kv WHERE r_key = ?")
        .bind(key.as_ref())
        .fetch_optional(pool)
        .await
        .map_err(map_sql_err)?;
    let Some(row) = row else {
        return Ok(StringRow::Missing);
    };

    // Expiry comes first: a row past its deadline no longer exists as far as
    // the client is concerned, whatever type it used to hold.
    let expires_at: Option<i64> = row.try_get("expires_at_ms").map_err(map_sql_err)?;
    if expires_at.is_some_and(|exp| exp <= now) {
        delete_key_all(pool, key).await?;
        return Ok(StringRow::Missing);
    }

    let r_type: u8 = row.try_get("r_type").map_err(map_sql_err)?;
    if r_type != TYPE_STRING {
        return Ok(StringRow::WrongType);
    }

    let value: Option<Vec<u8>> = row.try_get("r_value").map_err(map_sql_err)?;
    Ok(StringRow::Value {
        data: value.unwrap_or_default(),
        expires_at,
    })
}