//! redis-on-mysql 0.0.1
//!
//! A Redis-compatible proxy that stores all data and Pub/Sub state in MySQL.
//!
//! This module implements the HyperLogLog commands `PFADD`, `PFCOUNT`, and `PFMERGE`.
use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;

use crate::handlers::util::{arg_as_bytes, wrong_arity, wrong_type};
use crate::state::{AppState, SessionHandle, now_ms};
use crate::storage::{TYPE_STRING, delete_key_all, map_sql_err};

/// Number of hash bits used to select a register (HyperLogLog precision `p`).
const HLL_P: u32 = 14;
/// Number of registers in a dense HyperLogLog; derived from [`HLL_P`] so the two cannot drift.
const HLL_REGISTERS: usize = 1usize << HLL_P;

/// Handles `PFADD key [element ...]`.
///
/// Loads the HyperLogLog at `key`, starting from an all-zero register array when the key is
/// missing or expired. It then folds every element into the registers and writes them back
/// only if something changed.
///
/// Returns `Integer(1)` when the key was created or at least one register increased, and
/// `Integer(0)` otherwise. Returns the `wrong_type()` reply when `key` cannot be read as a
/// HyperLogLog. `PFADD key` with no elements is accepted: it only creates a missing key.
///
/// # Errors
/// `wrong_arity` when no key is given, `RespError::NoAuth` when the session has no auth, and
/// SQL failures mapped through `map_sql_err`.
pub async fn pfadd(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Only the key is mandatory; elements are optional.
    if cmd.args.is_empty() {
        return Err(wrong_arity("PFADD"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;

    // NOTE(review): this read-modify-write runs without a transaction or row lock, so two
    // concurrent PFADDs on the same key can overwrite each other's register updates.
    let (mut registers, created) = match fetch_hll(auth.pool.as_ref(), key, now).await? {
        HllRow::Missing => (vec![0u8; HLL_REGISTERS], true),
        HllRow::WrongType => return Ok(wrong_type()),
        HllRow::Value { data, .. } => (data, false),
    };

    // Creating the key counts as a modification even when no element is supplied.
    let mut changed = created;
    for arg in &cmd.args[1..] {
        let value = arg_as_bytes(arg)?;
        let hash = murmur_hash64(value.as_ref());
        // The low HLL_P bits select the register; the remaining bits determine the rank.
        let index = (hash & ((1u64 << HLL_P) - 1)) as usize;
        let r = rank(hash >> HLL_P, 64 - HLL_P);
        if registers[index] < r {
            registers[index] = r;
            changed = true;
        }
    }

    if changed {
        // Upsert: an existing row keeps its r_type and expires_at_ms; only value/length change.
        sqlx::query(
            "INSERT INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) VALUES (?, ?, ?, ?, NULL) \
             ON DUPLICATE KEY UPDATE r_value = VALUES(r_value), r_len = VALUES(r_len)",
        )
        .bind(key.as_ref())
        .bind(TYPE_STRING)
        .bind(registers.as_slice())
        .bind(registers.len() as i64)
        .execute(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;
    }

    Ok(Value::Integer(i64::from(changed)))
}

/// Handles `PFCOUNT key [key ...]`.
///
/// Builds the register-wise maximum (union) of every listed HyperLogLog and replies with its
/// estimated cardinality, truncated to an integer. Keys that are missing or expired contribute
/// nothing. If any key cannot be read as a HyperLogLog, the reply is `wrong_type()`.
///
/// # Errors
/// `wrong_arity` when no key is given, `RespError::NoAuth` when the session has no auth, and
/// any error surfaced by `fetch_hll`.
pub async fn pfcount(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.is_empty() {
        return Err(wrong_arity("PFCOUNT"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let now = now_ms() as i64;

    let mut union = vec![0u8; HLL_REGISTERS];
    for raw in cmd.args.iter() {
        let key = arg_as_bytes(raw)?;
        let data = match fetch_hll(auth.pool.as_ref(), key, now).await? {
            HllRow::Missing => continue,
            HllRow::WrongType => return Ok(wrong_type()),
            HllRow::Value { data, .. } => data,
        };
        for (i, slot) in union.iter_mut().enumerate() {
            *slot = (*slot).max(data[i]);
        }
    }

    Ok(Value::Integer(hll_count(&union) as i64))
}

/// Handles `PFMERGE destkey [sourcekey ...]`.
///
/// Computes the register-wise maximum of `destkey` and every `sourcekey` and stores it in
/// `destkey`, creating the key if needed. Keys that are missing or expired contribute
/// nothing. An existing destination row is updated in place, so its `expires_at_ms` is kept.
///
/// Returns `+OK`. If any involved key, including `destkey`, cannot be read as a HyperLogLog,
/// it returns the `wrong_type()` reply and no merged value is written.
///
/// # Errors
/// `wrong_arity` when no destination is given, `RespError::NoAuth` when the session has no
/// auth, and SQL failures mapped through `map_sql_err`.
pub async fn pfmerge(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Only the destination is mandatory; with no sources the command just ensures it exists.
    if cmd.args.is_empty() {
        return Err(wrong_arity("PFMERGE"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let dest = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;

    // args[0] is the destination. It is deliberately part of the union so its existing
    // registers survive the merge, and so a non-HyperLogLog destination is rejected
    // instead of silently overwritten.
    let mut merged = vec![0u8; HLL_REGISTERS];
    for arg in &cmd.args {
        let key = arg_as_bytes(arg)?;
        match fetch_hll(auth.pool.as_ref(), key, now).await? {
            HllRow::Missing => {}
            HllRow::WrongType => return Ok(wrong_type()),
            HllRow::Value { data, .. } => {
                for (slot, &reg) in merged.iter_mut().zip(&data) {
                    *slot = (*slot).max(reg);
                }
            }
        }
    }

    // NOTE(review): as in PFADD, the read above and this write are not atomic; a concurrent
    // writer can change `dest` between them.
    sqlx::query(
        "INSERT INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) VALUES (?, ?, ?, ?, NULL) \
         ON DUPLICATE KEY UPDATE r_value = VALUES(r_value), r_len = VALUES(r_len)",
    )
    .bind(dest.as_ref())
    .bind(TYPE_STRING)
    .bind(merged.as_slice())
    .bind(merged.len() as i64)
    .execute(auth.pool.as_ref())
    .await
    .map_err(map_sql_err)?;

    Ok(Value::Simple(Bytes::from_static(b"OK")))
}

/// Result of looking up a key as a HyperLogLog.
#[derive(Debug, Clone, PartialEq, Eq)]
enum HllRow {
    /// No live row exists for the key.
    Missing,
    /// The key exists but cannot be used as a HyperLogLog.
    WrongType,
    /// The dense register array loaded for the key.
    Value { data: Vec<u8> },
}

async fn fetch_hll(pool: &sqlx::MySqlPool, key: &Bytes, now: i64) -> Result<HllRow, RespError> {
    let row = sqlx::query("SELECT r_type, r_value, expires_at_ms FROM redis_kv WHERE r_key = ?")
        .bind(key.as_ref())
        .fetch_optional(pool)
        .await
        .map_err(map_sql_err)?;
    let Some(row) = row else {
        return Ok(HllRow::Missing);
    };
    let r_type: u8 = row.try_get("r_type").map_err(map_sql_err)?;
    if r_type != TYPE_STRING {
        return Ok(HllRow::WrongType);
    }
    let expires_at: Option<i64> = row.try_get("expires_at_ms").map_err(map_sql_err)?;
    if let Some(exp) = expires_at
        && exp <= now
    {
        delete_key_all(pool, key).await?;
        return Ok(HllRow::Missing);
    }
    let value: Option<Vec<u8>> = row.try_get("r_value").map_err(map_sql_err)?;
    let mut data = value.unwrap_or_default();
    if data.len() != HLL_REGISTERS {
        data.resize(HLL_REGISTERS, 0);
    }
    Ok(HllRow::Value { data })
}

/// Hashes `data` with a MurmurHash64A-style mix. The initial state is `len * MULT`; no
/// seed is mixed in.
///
/// The input is consumed as little-endian 8-byte words. Any 1-7 byte tail is packed
/// little-endian and mixed once, then a final avalanche is applied. The result is fully
/// deterministic; the empty input hashes to `0`.
fn murmur_hash64(data: &[u8]) -> u64 {
    const MULT: u64 = 0xc6a4a7935bd1e995;
    const SHIFT: u32 = 47;

    let words = data.chunks_exact(8);
    let tail = words.remainder();
    let seed = (data.len() as u64).wrapping_mul(MULT);

    let mut h = words.fold(seed, |acc, word| {
        let mut k = u64::from_le_bytes(word.try_into().expect("chunks_exact yields 8-byte words"));
        k = k.wrapping_mul(MULT);
        k ^= k >> SHIFT;
        k = k.wrapping_mul(MULT);
        (acc ^ k).wrapping_mul(MULT)
    });

    if !tail.is_empty() {
        // Walk the tail backwards so byte i lands at bit offset 8 * i (little-endian).
        let packed = tail.iter().rev().fold(0u64, |acc, &b| (acc << 8) | u64::from(b));
        h = (h ^ packed).wrapping_mul(MULT);
    }

    h ^= h >> SHIFT;
    h = h.wrapping_mul(MULT);
    h ^ (h >> SHIFT)
}

/// Returns the HyperLogLog rank of `value`: the 1-based position of its lowest set bit.
///
/// `bits` is the width of the meaningful part of `value` (expected `<= 64`). If none of the
/// low `bits` bits is set, the rank saturates at `bits + 1`, so the result is always in
/// `1..=bits + 1`.
///
/// The count starts from the least-significant end on purpose. The caller passes
/// `hash >> HLL_P`, whose top `HLL_P` bits are always zero. Counting `leading_zeros` of the
/// full `u64` would therefore add `64 - bits` to every rank and inflate the estimate.
fn rank(value: u64, bits: u32) -> u8 {
    debug_assert!(bits <= 64, "rank window cannot exceed 64 bits");
    let max_rank = bits + 1;
    // `trailing_zeros(0)` is 64; the clamp turns that (and any zero window) into `max_rank`.
    (value.trailing_zeros() + 1).min(max_rank) as u8
}

/// Estimates the cardinality represented by dense HyperLogLog `registers`.
///
/// Uses the raw estimator `alpha_m * m^2 / sum(2^-r)`. When the raw estimate is at most
/// `2.5 * m` and some register is still zero, it switches to linear counting
/// `m * ln(m / zeros)`.
///
/// `m` is taken from `registers.len()`, so the estimate always matches the slice actually
/// passed. The `alpha_m` approximation used here is the one meant for `m >= 128`; this
/// module uses `HLL_REGISTERS`. An empty slice yields `0.0` instead of `NaN`.
fn hll_count(registers: &[u8]) -> f64 {
    if registers.is_empty() {
        return 0.0;
    }
    let m = registers.len() as f64;
    let alpha = 0.7213 / (1.0 + 1.079 / m);
    let (sum, zeros) = registers.iter().fold((0.0f64, 0usize), |(sum, zeros), &r| {
        (sum + 2f64.powi(-i32::from(r)), zeros + usize::from(r == 0))
    });
    let estimate = alpha * m * m / sum;
    if estimate <= 2.5 * m && zeros > 0 {
        m * (m / zeros as f64).ln()
    } else {
        estimate
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// An all-zero register array must estimate (approximately) zero elements.
    #[test]
    fn hll_empty_count_zero() {
        let untouched = [0u8; HLL_REGISTERS];
        assert!(hll_count(&untouched) < 1.0);
    }

    /// A zero remainder saturates at the maximum rank, `bits + 1`.
    #[test]
    fn rank_zero_value() {
        assert_eq!(rank(0, 50), 51);
    }
}