use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;
use crate::handlers::util::{arg_as_bytes, wrong_arity, wrong_type};
use crate::state::{AppState, SessionHandle, now_ms};
use crate::storage::{TYPE_STRING, delete_key_all, map_sql_err};
/// HyperLogLog precision: how many low bits of an element's hash pick its register.
const HLL_P: u32 = 14;
/// Number of one-byte registers in a stored HyperLogLog value (`2^HLL_P`).
const HLL_REGISTERS: usize = 1 << HLL_P;
/// Handles `PFADD key [element ...]`.
///
/// Loads the registers stored under `key` (an all-zero set when the key is
/// missing), raises the register selected by each element's hash to that
/// element's rank, and writes the registers back when the key was created or
/// any register grew. An existing row keeps its `expires_at_ms` because the
/// upsert only touches `r_value` and `r_len`.
///
/// Replies with integer `1` when the key was created or a register changed,
/// `0` otherwise, and with the reply produced by `wrong_type()` when the key
/// holds a non-string value.
///
/// # Errors
///
/// Returns the arity error when no key is given, [`RespError::NoAuth`] when
/// the session is not authenticated, and propagates argument-decoding and SQL
/// errors.
pub async fn pfadd(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Redis accepts `PFADD key` without elements: it creates an empty
    // HyperLogLog when the key is missing and is a no-op otherwise.
    if cmd.args.is_empty() {
        return Err(wrong_arity("PFADD"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;

    // A missing key counts as a change so that the new value gets persisted
    // even when no elements were supplied.
    let (mut registers, mut changed) = match fetch_hll(auth.pool.as_ref(), key, now).await? {
        HllRow::Missing => (vec![0u8; HLL_REGISTERS], true),
        HllRow::WrongType => return Ok(wrong_type()),
        HllRow::Value { data, .. } => (data, false),
    };

    let index_mask = (1u64 << HLL_P) - 1;
    for arg in cmd.args.iter().skip(1) {
        let element = arg_as_bytes(arg)?;
        let hash = murmur_hash64(element.as_ref());
        // The low HLL_P bits select the register; the remaining bits feed the rank.
        let index = (hash & index_mask) as usize;
        let element_rank = rank(hash >> HLL_P, 64 - HLL_P);
        if registers[index] < element_rank {
            registers[index] = element_rank;
            changed = true;
        }
    }

    if changed {
        sqlx::query(
            "INSERT INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) VALUES (?, ?, ?, ?, NULL) \
             ON DUPLICATE KEY UPDATE r_value = VALUES(r_value), r_len = VALUES(r_len)",
        )
        .bind(key.as_ref())
        .bind(TYPE_STRING)
        .bind(registers.as_slice())
        .bind(registers.len() as i64)
        .execute(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;
    }
    Ok(Value::Integer(if changed { 1 } else { 0 }))
}
/// Handles `PFCOUNT key [key ...]`.
///
/// Builds the register-wise maximum of every HyperLogLog named in the
/// arguments (missing keys contribute nothing) and replies with the estimated
/// cardinality of that combination, truncated to an integer.
///
/// Replies with the result of `wrong_type()` as soon as one key holds a
/// non-string value.
///
/// # Errors
///
/// Returns the arity error when no key is given, [`RespError::NoAuth`] when
/// the session is not authenticated, and propagates argument-decoding and SQL
/// errors.
pub async fn pfcount(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.is_empty() {
        return Err(wrong_arity("PFCOUNT"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let now = now_ms() as i64;

    let mut combined = vec![0u8; HLL_REGISTERS];
    for arg in &cmd.args {
        let key = arg_as_bytes(arg)?;
        let data = match fetch_hll(auth.pool.as_ref(), key, now).await? {
            HllRow::Missing => continue,
            HllRow::WrongType => return Ok(wrong_type()),
            HllRow::Value { data, .. } => data,
        };
        // `fetch_hll` always yields exactly HLL_REGISTERS bytes, so the zip
        // visits every register.
        for (slot, &incoming) in combined.iter_mut().zip(&data) {
            if *slot < incoming {
                *slot = incoming;
            }
        }
    }

    Ok(Value::Integer(hll_count(&combined) as i64))
}
/// Handles `PFMERGE destkey [sourcekey ...]`.
///
/// Stores under `destkey` the register-wise maximum of the destination's
/// current registers (if it exists) and of every source HyperLogLog. Missing
/// keys contribute nothing. An existing destination row keeps its
/// `expires_at_ms` because the upsert only touches `r_value` and `r_len`.
///
/// Replies with simple string `OK`, or with the result of `wrong_type()` when
/// the destination or any source holds a non-string value; in that case
/// nothing is written.
///
/// # Errors
///
/// Returns the arity error when no destination is given,
/// [`RespError::NoAuth`] when the session is not authenticated, and
/// propagates argument-decoding and SQL errors.
pub async fn pfmerge(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Redis allows `PFMERGE destkey` with no sources.
    if cmd.args.is_empty() {
        return Err(wrong_arity("PFMERGE"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let dest = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;

    let mut merged = vec![0u8; HLL_REGISTERS];
    // The destination is `args[0]`, so walking every argument treats an
    // existing destination as one more source (as Redis does) and also
    // rejects a destination of the wrong type instead of overwriting it.
    for arg in &cmd.args {
        let key = arg_as_bytes(arg)?;
        let data = match fetch_hll(auth.pool.as_ref(), key, now).await? {
            HllRow::Missing => continue,
            HllRow::WrongType => return Ok(wrong_type()),
            HllRow::Value { data, .. } => data,
        };
        for (slot, &incoming) in merged.iter_mut().zip(&data) {
            if *slot < incoming {
                *slot = incoming;
            }
        }
    }

    // Same upsert as PFADD: creates the row or replaces the registers in place.
    sqlx::query(
        "INSERT INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) VALUES (?, ?, ?, ?, NULL) \
         ON DUPLICATE KEY UPDATE r_value = VALUES(r_value), r_len = VALUES(r_len)",
    )
    .bind(dest.as_ref())
    .bind(TYPE_STRING)
    .bind(merged.as_slice())
    .bind(merged.len() as i64)
    .execute(auth.pool.as_ref())
    .await
    .map_err(map_sql_err)?;
    Ok(Value::Simple(Bytes::from_static(b"OK")))
}
/// Outcome of looking up a HyperLogLog key via `fetch_hll`.
enum HllRow {
/// No row exists for the key, or the row had expired and was deleted.
Missing,
/// A row exists but its `r_type` is not `TYPE_STRING`.
WrongType,
/// The stored register bytes, padded or truncated to `HLL_REGISTERS` entries.
Value { data: Vec<u8> },
}
/// Loads the HyperLogLog registers stored under `key` in `redis_kv`.
///
/// Returns [`HllRow::Missing`] when there is no row or the row's
/// `expires_at_ms` is at or before `now` (the expired key is deleted first),
/// [`HllRow::WrongType`] when a live row is not `TYPE_STRING`, and otherwise
/// [`HllRow::Value`] with exactly `HLL_REGISTERS` bytes.
///
/// # Errors
///
/// Propagates SQL errors (mapped through `map_sql_err`) and any error from
/// `delete_key_all`.
async fn fetch_hll(pool: &sqlx::MySqlPool, key: &Bytes, now: i64) -> Result<HllRow, RespError> {
    let row = sqlx::query("SELECT r_type, r_value, expires_at_ms FROM redis_kv WHERE r_key = ?")
        .bind(key.as_ref())
        .fetch_optional(pool)
        .await
        .map_err(map_sql_err)?;
    let Some(row) = row else {
        return Ok(HllRow::Missing);
    };

    // Expiry is checked before the type: an expired key no longer exists, so
    // it must read as missing (and be cleaned up) whatever type it used to hold.
    let expires_at: Option<i64> = row.try_get("expires_at_ms").map_err(map_sql_err)?;
    if expires_at.is_some_and(|deadline| deadline <= now) {
        delete_key_all(pool, key).await?;
        return Ok(HllRow::Missing);
    }

    let r_type: u8 = row.try_get("r_type").map_err(map_sql_err)?;
    if r_type != TYPE_STRING {
        return Ok(HllRow::WrongType);
    }

    let value: Option<Vec<u8>> = row.try_get("r_value").map_err(map_sql_err)?;
    let mut data = value.unwrap_or_default();
    // NOTE(review): any string value is accepted and padded/truncated to the
    // register count; Redis instead rejects strings that are not valid
    // HyperLogLog values. Consider validating before callers overwrite the key.
    data.resize(HLL_REGISTERS, 0);
    Ok(HllRow::Value { data })
}
/// Hashes `data` to 64 bits with a MurmurHash2-style mix (64-bit variant, seed 0).
///
/// Full 8-byte blocks are read little-endian and mixed one at a time; the
/// trailing 0–7 bytes are packed little-endian into one word and mixed once;
/// a final avalanche step follows. All arithmetic wraps.
fn murmur_hash64(data: &[u8]) -> u64 {
    const M: u64 = 0xc6a4a7935bd1e995;
    const R: u32 = 47;

    let mut blocks = data.chunks_exact(8);
    let mut h = (data.len() as u64).wrapping_mul(M);

    for block in blocks.by_ref() {
        let mut word = [0u8; 8];
        word.copy_from_slice(block);
        let mut k = u64::from_le_bytes(word).wrapping_mul(M);
        k ^= k >> R;
        h = (h ^ k.wrapping_mul(M)).wrapping_mul(M);
    }

    let tail = blocks.remainder();
    if !tail.is_empty() {
        // Walking the tail backwards and shifting left places byte `i` at bit `8 * i`.
        let packed = tail
            .iter()
            .rev()
            .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));
        h = (h ^ packed).wrapping_mul(M);
    }

    h ^= h >> R;
    h = h.wrapping_mul(M);
    h ^ (h >> R)
}
/// Returns the HyperLogLog rank of `value`, read as a `bits`-wide bit string.
///
/// The rank is the 1-based position of the most significant set bit within
/// the low `bits` bits of `value`: a value whose top window bit is set has
/// rank 1, and each additional leading zero adds one. A `value` of zero gets
/// the maximum rank, `bits + 1`.
fn rank(value: u64, bits: u32) -> u8 {
    let max_rank = bits.saturating_add(1);
    if value == 0 {
        return max_rank as u8;
    }
    // `leading_zeros` counts across all 64 bits of the `u64`; the `64 - bits`
    // positions above the window are always zero and must not contribute,
    // otherwise every rank is inflated by that amount.
    let padding = 64u32.saturating_sub(bits);
    let leading = value.leading_zeros().saturating_sub(padding);
    (leading + 1).min(max_rank) as u8
}
/// Estimates the cardinality represented by `registers`.
///
/// Computes the raw HyperLogLog estimate `alpha * m^2 / sum(2^-register)`
/// with `m = HLL_REGISTERS`. When that estimate is at most `2.5 * m` and at
/// least one register is zero, the linear-counting estimate
/// `m * ln(m / zeros)` is returned instead.
fn hll_count(registers: &[u8]) -> f64 {
    let m = HLL_REGISTERS as f64;

    // One pass collects both the harmonic sum and the count of empty registers.
    let (harmonic, empty) = registers
        .iter()
        .fold((0.0f64, 0usize), |(acc, zeros), &reg| {
            (
                acc + 2f64.powi(-i32::from(reg)),
                zeros + usize::from(reg == 0),
            )
        });

    let alpha = 0.7213 / (1.0 + 1.079 / m);
    let raw = alpha * m * m / harmonic;

    let use_linear_counting = empty > 0 && raw <= 2.5 * m;
    if use_linear_counting {
        m * (m / empty as f64).ln()
    } else {
        raw
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An all-zero register set must estimate fewer than one element.
    #[test]
    fn hll_empty_count_zero() {
        let regs = vec![0u8; HLL_REGISTERS];
        let estimate = hll_count(&regs);
        assert!(estimate < 1.0);
    }

    /// A zero value receives the maximum rank, `bits + 1`.
    #[test]
    fn rank_zero_value() {
        let r = rank(0, 50);
        assert_eq!(r, 51);
    }
}