// redis-on-mysql 0.0.1
//
// A Redis-compatible proxy that stores all data and Pub/Sub state in MySQL.
// This module implements the BITFIELD command.
use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;

use crate::handlers::util::{arg_as_bytes, arg_as_i64, wrong_arity, wrong_type};
use crate::state::{AppState, SessionHandle, now_ms};
use crate::storage::{TYPE_STRING, delete_key_all, map_sql_err};

/// Result of looking up a key that a string command wants to operate on.
#[derive(Debug)]
enum StringRow {
    /// The key does not exist (or had expired and was removed).
    Missing,
    /// The key exists but holds a non-string type.
    WrongType,
    /// The key holds a string.
    Value {
        /// Raw string bytes (empty when the stored value was NULL).
        data: Vec<u8>,
        /// Absolute expiry in milliseconds, if the key has a TTL.
        expires_at: Option<i64>,
    },
}

/// Overflow policy selected by `BITFIELD ... OVERFLOW <mode>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Overflow {
    /// Wrap around modulo the field width (the default).
    Wrap,
    /// Saturate at the field's minimum or maximum.
    Sat,
    /// Skip the write and reply with a null.
    Fail,
}

/// A parsed bitfield type such as `i8` or `u16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TypeSpec {
    /// `true` for `iN` (two's complement), `false` for `uN`.
    signed: bool,
    /// Field width in bits.
    bits: u32,
}

pub async fn bitfield(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() < 2 {
        return Err(wrong_arity("BITFIELD"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;

    let mut data = Vec::new();
    let mut expires_at = None;
    match fetch_string(auth.pool.as_ref(), key, now).await? {
        StringRow::Missing => {}
        StringRow::WrongType => return Ok(wrong_type()),
        StringRow::Value {
            data: existing,
            expires_at: exp,
        } => {
            data = existing;
            expires_at = exp;
        }
    }

    let mut overflow = Overflow::Wrap;
    let mut results = Vec::new();
    let mut modified = false;

    let mut i = 1;
    while i < cmd.args.len() {
        let mut token = arg_as_bytes(&cmd.args[i])?.to_vec();
        for b in &mut token {
            b.make_ascii_uppercase();
        }
        match token.as_slice() {
            b"OVERFLOW" => {
                let next = cmd
                    .args
                    .get(i + 1)
                    .ok_or_else(|| RespError::invalid_data("ERR syntax error"))?;
                let mut mode = arg_as_bytes(next)?.to_vec();
                for b in &mut mode {
                    b.make_ascii_uppercase();
                }
                overflow = match mode.as_slice() {
                    b"WRAP" => Overflow::Wrap,
                    b"SAT" => Overflow::Sat,
                    b"FAIL" => Overflow::Fail,
                    _ => return Err(RespError::invalid_data("ERR syntax error")),
                };
                i += 2;
            }
            b"GET" => {
                let spec = parse_type(arg_as_bytes(&cmd.args[i + 1])?)?;
                let offset = parse_offset(arg_as_bytes(&cmd.args[i + 2])?, spec.bits)?;
                let value = read_bits(&data, offset, spec.bits);
                let signed = if spec.signed {
                    to_signed(value, spec.bits)
                } else {
                    value as i64
                };
                results.push(Value::Integer(signed));
                i += 3;
            }
            b"SET" => {
                let spec = parse_type(arg_as_bytes(&cmd.args[i + 1])?)?;
                let offset = parse_offset(arg_as_bytes(&cmd.args[i + 2])?, spec.bits)?;
                let new_value = arg_as_i64(&cmd.args[i + 3])?;
                let current = read_bits(&data, offset, spec.bits);
                let current_signed = if spec.signed {
                    to_signed(current, spec.bits)
                } else {
                    current as i64
                };
                let new_unsigned = to_unsigned(new_value, spec.bits);
                write_bits(&mut data, offset, spec.bits, new_unsigned);
                results.push(Value::Integer(current_signed));
                modified = true;
                i += 4;
            }
            b"INCRBY" => {
                let spec = parse_type(arg_as_bytes(&cmd.args[i + 1])?)?;
                let offset = parse_offset(arg_as_bytes(&cmd.args[i + 2])?, spec.bits)?;
                let increment = arg_as_i64(&cmd.args[i + 3])?;
                let current = read_bits(&data, offset, spec.bits);
                let current_signed = if spec.signed {
                    to_signed(current, spec.bits)
                } else {
                    current as i64
                };
                let (min, max) = type_range(&spec);
                let mut next = current_signed.saturating_add(increment);
                let mut write = true;
                if next < min || next > max {
                    match overflow {
                        Overflow::Wrap => {
                            next = wrap_value(next, min, max);
                        }
                        Overflow::Sat => {
                            next = next.clamp(min, max);
                        }
                        Overflow::Fail => {
                            write = false;
                        }
                    }
                }
                if write {
                    let next_unsigned = to_unsigned(next, spec.bits);
                    write_bits(&mut data, offset, spec.bits, next_unsigned);
                    results.push(Value::Integer(next));
                    modified = true;
                } else {
                    results.push(Value::Null);
                }
                i += 4;
            }
            _ => return Err(RespError::invalid_data("ERR syntax error")),
        }
    }

    if modified {
        sqlx::query(
            "INSERT INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) VALUES (?, ?, ?, ?, ?) \
             ON DUPLICATE KEY UPDATE r_value = VALUES(r_value), r_len = VALUES(r_len), expires_at_ms = VALUES(expires_at_ms)",
        )
        .bind(key.as_ref())
        .bind(TYPE_STRING)
        .bind(data.as_slice())
        .bind(data.len() as i64)
        .bind(expires_at)
        .execute(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;
    }

    Ok(Value::Array(results))
}

/// Parses a bitfield type such as `i8`, `u16` or `I64`.
///
/// The prefix (case-insensitive) selects signedness:
/// - `i`: signed, 1–64 bits.
/// - `u`: unsigned, 1–63 bits. Like Redis, `u64` is rejected because its values
///   would not fit the signed 64-bit integers used in replies.
///
/// # Errors
/// `ERR invalid type` for any other input.
fn parse_type(raw: &Bytes) -> Result<TypeSpec, RespError> {
    let invalid = || RespError::invalid_data("ERR invalid type");
    let text = std::str::from_utf8(raw).map_err(|_| invalid())?;
    let (signed, width) = match text.as_bytes().first().copied() {
        // The prefix is a single ASCII byte, so slicing at 1 is a char boundary.
        Some(b'i' | b'I') => (true, &text[1..]),
        Some(b'u' | b'U') => (false, &text[1..]),
        _ => return Err(invalid()),
    };
    let bits: u32 = width.parse().map_err(|_| invalid())?;
    let max_bits = if signed { 64 } else { 63 };
    if bits == 0 || bits > max_bits {
        return Err(invalid());
    }
    Ok(TypeSpec { signed, bits })
}

/// Parses a bitfield offset and returns it as an absolute bit offset.
///
/// Two forms are accepted:
/// - `N`: a bit offset.
/// - `#N`: the N-th field of `width` bits, i.e. bit `N * width`.
///
/// # Errors
/// `ERR invalid offset` for non-numeric or negative input. The same error is returned
/// when the offset reaches 2^32 bits (Redis's default 512 MB string limit). Without
/// this bound, a single SET could force a multi-gigabyte buffer allocation.
fn parse_offset(raw: &Bytes, width: u32) -> Result<u64, RespError> {
    const MAX_BIT_OFFSET: u64 = 512 * 1024 * 1024 * 8;
    let invalid = || RespError::invalid_data("ERR invalid offset");
    let s = std::str::from_utf8(raw).map_err(|_| invalid())?;
    let offset = match s.strip_prefix('#') {
        Some(rest) => {
            let idx: u64 = rest.parse().map_err(|_| invalid())?;
            idx.checked_mul(u64::from(width)).ok_or_else(invalid)?
        }
        None => {
            let value: i64 = s.parse().map_err(|_| invalid())?;
            u64::try_from(value).map_err(|_| invalid())?
        }
    };
    if offset >= MAX_BIT_OFFSET {
        return Err(invalid());
    }
    Ok(offset)
}

/// Inclusive `(min, max)` value range representable by `spec`.
///
/// Uses arithmetic shifts of `i64::MIN`/`i64::MAX`/`u64::MAX`. The previous
/// `(1 << (bits - 1)) - 1` form overflowed, and panicked in debug builds, for 63/64-bit
/// types. Assumes `1 <= spec.bits <= 64`. An unsigned 64-bit type is capped at
/// `i64::MAX`, since the range is expressed in i64.
fn type_range(spec: &TypeSpec) -> (i64, i64) {
    let shift = 64 - spec.bits;
    if spec.signed {
        (i64::MIN >> shift, i64::MAX >> shift)
    } else {
        let max = u64::MAX >> shift;
        (0, i64::try_from(max).unwrap_or(i64::MAX))
    }
}

/// Wraps `value` into the inclusive range `[min, max]` modulo its size.
///
/// The arithmetic is done in i128 because, for 63/64-bit ranges, `max - min + 1` does
/// not fit in i64. At 64 bits it wrapped to 0 and then divided by zero. Assumes
/// `min <= max`. The result lies in `[min, max]`, so the final cast is lossless.
fn wrap_value(value: i64, min: i64, max: i64) -> i64 {
    let (value, min, max) = (i128::from(value), i128::from(min), i128::from(max));
    let span = max - min + 1;
    ((value - min).rem_euclid(span) + min) as i64
}

/// Interprets the low `bits` bits of `value` as a two's-complement signed integer.
///
/// `value` is expected to fit in `bits` bits (as produced by `read_bits`).
fn to_signed(value: u64, bits: u32) -> i64 {
    match bits {
        // Full width: the cast itself is the two's-complement reinterpretation.
        64 => value as i64,
        _ => {
            let sign_mask = 1u64 << (bits - 1);
            if value & sign_mask != 0 {
                // Negative: magnitude = two's complement within `bits` bits.
                let all_ones = (1u64 << bits) - 1;
                let magnitude = (value ^ all_ones) + 1;
                -(magnitude as i64)
            } else {
                value as i64
            }
        }
    }
}

/// Returns the low `bits` bits of `value`'s two's-complement representation.
///
/// `as u64` already yields two's complement, so masking is all that is needed. The
/// old `-value` negation overflowed, and panicked in debug builds, for `i64::MIN`.
fn to_unsigned(value: i64, bits: u32) -> u64 {
    let raw = value as u64;
    if bits >= 64 {
        raw
    } else {
        raw & ((1u64 << bits) - 1)
    }
}

/// Reads `bits` consecutive bits starting at bit `offset`, most significant first.
///
/// Bits beyond the end of `data` read as zero.
fn read_bits(data: &[u8], offset: u64, bits: u32) -> u64 {
    (0..u64::from(bits)).fold(0u64, |acc, step| {
        (acc << 1) | u64::from(get_bit(data, offset + step))
    })
}

/// Writes the low `bits` bits of `value` at bit `offset`, most significant first.
///
/// The buffer is zero-extended as needed so the whole field fits; it is never shrunk.
fn write_bits(data: &mut Vec<u8>, offset: u64, bits: u32, value: u64) {
    let needed = (offset + u64::from(bits)).div_ceil(8) as usize;
    if needed > data.len() {
        data.resize(needed, 0);
    }
    // `shift` walks from the field's MSB down to bit 0 while `step` walks forward.
    for (step, shift) in (0..bits).rev().enumerate() {
        set_bit(data, offset + step as u64, ((value >> shift) & 1) as u8);
    }
}

/// Returns bit `offset` of `data` (bit 0 = MSB of the first byte), or 0 past the end.
fn get_bit(data: &[u8], offset: u64) -> u8 {
    let shift = 7 - (offset % 8) as u8;
    data.get((offset / 8) as usize)
        .map_or(0, |&byte| (byte >> shift) & 1)
}

/// Sets bit `offset` of `data` (bit 0 = MSB of the first byte) to 0 when `bit == 0`,
/// otherwise to 1.
///
/// # Panics
/// If `offset / 8` is out of bounds; callers resize the buffer first.
fn set_bit(data: &mut [u8], offset: u64, bit: u8) {
    let mask = 1u8 << (7 - (offset % 8) as u8);
    let byte = &mut data[(offset / 8) as usize];
    if bit == 0 {
        *byte &= !mask;
    } else {
        *byte |= mask;
    }
}

/// Loads `key` from `redis_kv` and classifies it for a string command.
///
/// - Returns `Missing` when no row exists.
/// - Returns `WrongType` when the row's `r_type` is not `TYPE_STRING`.
/// - Returns `Value` with the stored bytes (NULL read as empty) and expiry otherwise.
///
/// A string row whose `expires_at_ms` is at or before `now` is deleted via
/// `delete_key_all` and reported as `Missing`.
async fn fetch_string(
    pool: &sqlx::MySqlPool,
    key: &Bytes,
    now: i64,
) -> Result<StringRow, RespError> {
    let maybe_row =
        sqlx::query("SELECT r_type, r_value, expires_at_ms FROM redis_kv WHERE r_key = ?")
            .bind(key.as_ref())
            .fetch_optional(pool)
            .await
            .map_err(map_sql_err)?;
    let row = match maybe_row {
        Some(row) => row,
        None => return Ok(StringRow::Missing),
    };

    let r_type: u8 = row.try_get("r_type").map_err(map_sql_err)?;
    if r_type != TYPE_STRING {
        return Ok(StringRow::WrongType);
    }

    let expires_at: Option<i64> = row.try_get("expires_at_ms").map_err(map_sql_err)?;
    if expires_at.is_some_and(|deadline| deadline <= now) {
        // Lazily purge the expired key so later reads also see it as missing.
        delete_key_all(pool, key).await?;
        return Ok(StringRow::Missing);
    }

    let stored: Option<Vec<u8>> = row.try_get("r_value").map_err(map_sql_err)?;
    Ok(StringRow::Value {
        data: stored.unwrap_or_default(),
        expires_at,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_type_signed() {
        let spec = parse_type(&Bytes::from_static(b"i8")).unwrap();
        assert!(spec.signed);
        assert_eq!(spec.bits, 8);
    }

    #[test]
    fn parse_type_unsigned_is_case_insensitive() {
        let spec = parse_type(&Bytes::from_static(b"U16")).unwrap();
        assert!(!spec.signed);
        assert_eq!(spec.bits, 16);
    }

    #[test]
    fn parse_type_rejects_malformed() {
        let cases: [&[u8]; 6] = [b"x8", b"i", b"i0", b"i65", b"u-1", b""];
        for raw in cases {
            assert!(parse_type(&Bytes::copy_from_slice(raw)).is_err(), "{raw:?}");
        }
    }

    #[test]
    fn parse_offset_forms() {
        assert_eq!(parse_offset(&Bytes::from_static(b"12"), 8).unwrap(), 12);
        assert_eq!(parse_offset(&Bytes::from_static(b"#3"), 8).unwrap(), 24);
        assert!(parse_offset(&Bytes::from_static(b"-1"), 8).is_err());
        assert!(parse_offset(&Bytes::from_static(b"abc"), 8).is_err());
    }

    #[test]
    fn signed_conversions_round_trip() {
        for v in -128i64..=127 {
            assert_eq!(to_signed(to_unsigned(v, 8), 8), v);
        }
        assert_eq!(to_unsigned(-1, 8), 0xFF);
        assert_eq!(to_signed(0x80, 8), -128);
    }

    #[test]
    fn bitfield_read_write() {
        let mut data = vec![0u8];
        write_bits(&mut data, 0, 4, 0b1010);
        let value = read_bits(&data, 0, 4);
        assert_eq!(value, 0b1010);
    }

    #[test]
    fn write_spanning_bytes_grows_buffer() {
        let mut data = Vec::new();
        write_bits(&mut data, 6, 4, 0b1011);
        assert_eq!(data, vec![0b0000_0010, 0b1100_0000]);
        assert_eq!(read_bits(&data, 6, 4), 0b1011);
    }

    #[test]
    fn reads_past_end_are_zero() {
        assert_eq!(read_bits(&[0xFF], 4, 8), 0xF0);
        assert_eq!(read_bits(&[], 100, 16), 0);
    }
}