use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;
use crate::handlers::util::{arg_as_bytes, arg_as_i64, wrong_arity, wrong_type};
use crate::state::{AppState, SessionHandle, now_ms};
use crate::storage::{TYPE_STRING, delete_key_all, map_sql_err};
/// Outcome of loading a key for a string command (see `fetch_string`).
#[derive(Debug, Clone, PartialEq, Eq)]
enum StringRow {
    /// No row exists, or the row had expired (and was deleted).
    Missing,
    /// The key exists but holds a non-string type.
    WrongType,
    /// The key holds a string.
    Value {
        /// Raw string bytes (a NULL `r_value` is read as empty).
        data: Vec<u8>,
        /// Expiry timestamp in milliseconds, compared against `now_ms()`; `None` when the
        /// key has no expiry.
        expires_at: Option<i64>,
    },
}
/// BITFIELD overflow policy selected with `OVERFLOW WRAP|SAT|FAIL`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum Overflow {
    /// Wrap around modulo the field width (two's complement). This is the default.
    #[default]
    Wrap,
    /// Saturate at the field's minimum or maximum value.
    Sat,
    /// Refuse the write and reply with a null for that sub-command.
    Fail,
}
/// A parsed BITFIELD field type such as `i8` or `u16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TypeSpec {
    /// `true` for an `i` (signed, two's complement) type, `false` for `u` (unsigned).
    signed: bool,
    /// Field width in bits.
    bits: u32,
}
pub async fn bitfield(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.len() < 2 {
return Err(wrong_arity("BITFIELD"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let now = now_ms() as i64;
let mut data = Vec::new();
let mut expires_at = None;
match fetch_string(auth.pool.as_ref(), key, now).await? {
StringRow::Missing => {}
StringRow::WrongType => return Ok(wrong_type()),
StringRow::Value {
data: existing,
expires_at: exp,
} => {
data = existing;
expires_at = exp;
}
}
let mut overflow = Overflow::Wrap;
let mut results = Vec::new();
let mut modified = false;
let mut i = 1;
while i < cmd.args.len() {
let mut token = arg_as_bytes(&cmd.args[i])?.to_vec();
for b in &mut token {
b.make_ascii_uppercase();
}
match token.as_slice() {
b"OVERFLOW" => {
let next = cmd
.args
.get(i + 1)
.ok_or_else(|| RespError::invalid_data("ERR syntax error"))?;
let mut mode = arg_as_bytes(next)?.to_vec();
for b in &mut mode {
b.make_ascii_uppercase();
}
overflow = match mode.as_slice() {
b"WRAP" => Overflow::Wrap,
b"SAT" => Overflow::Sat,
b"FAIL" => Overflow::Fail,
_ => return Err(RespError::invalid_data("ERR syntax error")),
};
i += 2;
}
b"GET" => {
let spec = parse_type(arg_as_bytes(&cmd.args[i + 1])?)?;
let offset = parse_offset(arg_as_bytes(&cmd.args[i + 2])?, spec.bits)?;
let value = read_bits(&data, offset, spec.bits);
let signed = if spec.signed {
to_signed(value, spec.bits)
} else {
value as i64
};
results.push(Value::Integer(signed));
i += 3;
}
b"SET" => {
let spec = parse_type(arg_as_bytes(&cmd.args[i + 1])?)?;
let offset = parse_offset(arg_as_bytes(&cmd.args[i + 2])?, spec.bits)?;
let new_value = arg_as_i64(&cmd.args[i + 3])?;
let current = read_bits(&data, offset, spec.bits);
let current_signed = if spec.signed {
to_signed(current, spec.bits)
} else {
current as i64
};
let new_unsigned = to_unsigned(new_value, spec.bits);
write_bits(&mut data, offset, spec.bits, new_unsigned);
results.push(Value::Integer(current_signed));
modified = true;
i += 4;
}
b"INCRBY" => {
let spec = parse_type(arg_as_bytes(&cmd.args[i + 1])?)?;
let offset = parse_offset(arg_as_bytes(&cmd.args[i + 2])?, spec.bits)?;
let increment = arg_as_i64(&cmd.args[i + 3])?;
let current = read_bits(&data, offset, spec.bits);
let current_signed = if spec.signed {
to_signed(current, spec.bits)
} else {
current as i64
};
let (min, max) = type_range(&spec);
let mut next = current_signed.saturating_add(increment);
let mut write = true;
if next < min || next > max {
match overflow {
Overflow::Wrap => {
next = wrap_value(next, min, max);
}
Overflow::Sat => {
next = next.clamp(min, max);
}
Overflow::Fail => {
write = false;
}
}
}
if write {
let next_unsigned = to_unsigned(next, spec.bits);
write_bits(&mut data, offset, spec.bits, next_unsigned);
results.push(Value::Integer(next));
modified = true;
} else {
results.push(Value::Null);
}
i += 4;
}
_ => return Err(RespError::invalid_data("ERR syntax error")),
}
}
if modified {
sqlx::query(
"INSERT INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) VALUES (?, ?, ?, ?, ?) \
ON DUPLICATE KEY UPDATE r_value = VALUES(r_value), r_len = VALUES(r_len), expires_at_ms = VALUES(expires_at_ms)",
)
.bind(key.as_ref())
.bind(TYPE_STRING)
.bind(data.as_slice())
.bind(data.len() as i64)
.bind(expires_at)
.execute(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
}
Ok(Value::Array(results))
}
/// Parses a BITFIELD type such as `i8`, `u16` or `i64` into a `TypeSpec`.
///
/// The `i`/`u` prefix is case-insensitive and is followed by a decimal width. Signed types
/// may be 1 to 64 bits wide and unsigned types 1 to 63 bits. Like Redis, `u64` is rejected:
/// its values cannot be returned as a signed 64-bit RESP integer, and its maximum does not
/// fit the `i64` bounds that `type_range` returns.
///
/// # Errors
///
/// `ERR invalid type` for any other input.
fn parse_type(raw: &Bytes) -> Result<TypeSpec, RespError> {
    let invalid = || RespError::invalid_data("ERR invalid type");
    let raw = std::str::from_utf8(raw).map_err(|_| invalid())?;
    if raw.len() < 2 {
        return Err(invalid());
    }
    // The first byte is ASCII when matched, so slicing at 1 is on a char boundary.
    let (signed, bits_str) = match raw.as_bytes()[0] {
        b'i' | b'I' => (true, &raw[1..]),
        b'u' | b'U' => (false, &raw[1..]),
        _ => return Err(invalid()),
    };
    let bits: u32 = bits_str.parse().map_err(|_| invalid())?;
    let max_bits = if signed { 64 } else { 63 };
    if bits == 0 || bits > max_bits {
        return Err(invalid());
    }
    Ok(TypeSpec { signed, bits })
}
/// Parses a BITFIELD offset into an absolute bit position.
///
/// Accepts either a plain non-negative bit offset (`"100"`) or a `#`-prefixed field index
/// (`"#2"`), which is multiplied by the field `width`. As in Redis, offsets must be below
/// 2^32 bits (512 MiB, Redis's default `proto-max-bulk-len`). This bounds how far a single
/// SET/INCRBY can grow the value and keeps `offset + width` far from `u64` overflow.
///
/// # Errors
///
/// `ERR invalid offset` for non-numeric, negative, overflowing (`#idx * width`) or
/// too-large offsets.
fn parse_offset(raw: &Bytes, width: u32) -> Result<u64, RespError> {
    // One past the highest accepted bit offset: 512 MiB expressed in bits.
    const MAX_BIT_OFFSET: u64 = 512 * 1024 * 1024 * 8;
    let invalid = || RespError::invalid_data("ERR invalid offset");
    let s = std::str::from_utf8(raw).map_err(|_| invalid())?;
    let offset = match s.strip_prefix('#') {
        Some(index) => index
            .parse::<u64>()
            .ok()
            .and_then(|idx| idx.checked_mul(u64::from(width))),
        None => s.parse::<i64>().ok().and_then(|v| u64::try_from(v).ok()),
    }
    .ok_or_else(invalid)?;
    if offset >= MAX_BIT_OFFSET {
        return Err(invalid());
    }
    Ok(offset)
}
/// Returns the inclusive `(min, max)` range representable by a field of type `spec`.
///
/// `spec.bits` must be in `1..=64`. Signed widths use two's-complement bounds (`i8` gives
/// `(-128, 127)`, `i64` gives `(i64::MIN, i64::MAX)`). A `u64` maximum does not fit in an
/// `i64`, so it is reported as `i64::MAX`; callers should reject `u64` the way Redis does.
fn type_range(spec: &TypeSpec) -> (i64, i64) {
    // Right-shifting the full-width bounds keeps exactly `bits` significant bits. The
    // arithmetic shift of i64::MIN sign-extends, giving -2^(bits-1), without the overflow
    // that `(1 << 63) - 1` or `-(1 << 63)` would hit.
    let unused = 64 - spec.bits;
    if spec.signed {
        (i64::MIN >> unused, i64::MAX >> unused)
    } else {
        let max = u64::MAX >> unused;
        (0, i64::try_from(max).unwrap_or(i64::MAX))
    }
}
/// Wraps `value` into the inclusive range `[min, max]` modulo `max - min + 1`.
///
/// For bounds from `type_range`, this is two's-complement wrap-around at the field width.
/// The arithmetic is done in `i128`, so the period and the distance from `min` cannot
/// overflow, even for the full `i64` range. Requires `min <= max`.
fn wrap_value(value: i64, min: i64, max: i64) -> i64 {
    let (value, min, max) = (i128::from(value), i128::from(min), i128::from(max));
    let period = max - min + 1;
    // rem_euclid is never negative, so the result lies in [min, max] and fits back in i64.
    ((value - min).rem_euclid(period) + min) as i64
}
/// Sign-extends a `bits`-wide two's-complement field held in the low bits of `value`.
///
/// A 64-bit field is reinterpreted as-is. For narrower fields, a clear sign bit returns the
/// value unchanged, and a set sign bit returns minus the two's-complement magnitude
/// `(value ^ mask) + 1`. `bits` must be in `1..=64`, and `value` is expected to have no bits
/// set above the field width (as produced by `read_bits`).
fn to_signed(value: u64, bits: u32) -> i64 {
    match bits {
        64 => value as i64,
        _ => {
            let negative = value & (1u64 << (bits - 1)) != 0;
            if !negative {
                return value as i64;
            }
            let field_mask = (1u64 << bits) - 1;
            let magnitude = (value ^ field_mask) + 1;
            -(magnitude as i64)
        }
    }
}
/// Returns the low `bits` bits of `value`'s two's-complement representation, i.e. the
/// bit pattern a `bits`-wide field stores for `value`.
///
/// `value as u64` already is the two's-complement pattern for negative inputs, so masking
/// it gives the same result as negate-invert-add-one. Unlike negation, it cannot overflow,
/// and `-value` overflows for `i64::MIN`. `bits == 0` yields 0; widths of 64 or more keep
/// every bit.
fn to_unsigned(value: i64, bits: u32) -> u64 {
    let mask = if bits >= 64 {
        u64::MAX
    } else {
        (1u64 << bits) - 1
    };
    (value as u64) & mask
}
/// Reads `bits` consecutive bits starting at bit `offset` and returns them as an unsigned
/// integer whose most significant bit is the first bit read.
///
/// Bit addressing and out-of-range handling come from `get_bit`: bit 0 is the high bit of
/// byte 0, and bits past the end of `data` read as zero.
fn read_bits(data: &[u8], offset: u64, bits: u32) -> u64 {
    (0..u64::from(bits)).fold(0u64, |acc, step| {
        (acc << 1) | u64::from(get_bit(data, offset + step))
    })
}
/// Writes the low `bits` bits of `value` at bit `offset`, most significant bit first.
///
/// `data` is first zero-extended to hold bit `offset + bits - 1`, so the write never
/// indexes out of bounds.
fn write_bits(data: &mut Vec<u8>, offset: u64, bits: u32, value: u64) {
    let needed_bytes = (offset + u64::from(bits)).div_ceil(8) as usize;
    if needed_bytes > data.len() {
        data.resize(needed_bytes, 0);
    }
    let mut position = offset;
    for shift in (0..bits).rev() {
        set_bit(data, position, ((value >> shift) & 1) as u8);
        position += 1;
    }
}
/// Returns bit `offset` of `data` (0 or 1), where bit 0 is the most significant bit of
/// byte 0. Offsets beyond the end of `data` read as 0.
fn get_bit(data: &[u8], offset: u64) -> u8 {
    let shift = 7 - (offset % 8) as u8;
    data.get((offset / 8) as usize)
        .map_or(0, |&byte| (byte >> shift) & 1)
}
/// Sets bit `offset` of `data` (bit 0 = most significant bit of byte 0) to 0 when `bit`
/// is 0, and to 1 for any other `bit`.
///
/// # Panics
///
/// Panics if `offset` lies beyond the end of `data`.
fn set_bit(data: &mut [u8], offset: u64, bit: u8) {
    let byte = &mut data[(offset / 8) as usize];
    let flag = 1u8 << (7 - (offset % 8) as u8);
    if bit == 0 {
        *byte &= !flag;
    } else {
        *byte |= flag;
    }
}
/// Loads `key` from `redis_kv` and classifies it for a string command.
///
/// * `StringRow::Missing` when no row exists, or when `expires_at_ms` is at or before
///   `now`; in the expired case the key is also removed with `delete_key_all`.
/// * `StringRow::WrongType` when `r_type` is not `TYPE_STRING`.
/// * Otherwise `StringRow::Value` with the stored bytes (a NULL `r_value` becomes an empty
///   vector) and the expiry.
///
/// # Errors
///
/// Propagates SQL errors (mapped through `map_sql_err`) and errors from `delete_key_all`.
async fn fetch_string(
    pool: &sqlx::MySqlPool,
    key: &Bytes,
    now: i64,
) -> Result<StringRow, RespError> {
    let maybe_row =
        sqlx::query("SELECT r_type, r_value, expires_at_ms FROM redis_kv WHERE r_key = ?")
            .bind(key.as_ref())
            .fetch_optional(pool)
            .await
            .map_err(map_sql_err)?;
    let row = match maybe_row {
        Some(row) => row,
        None => return Ok(StringRow::Missing),
    };
    let kind: u8 = row.try_get("r_type").map_err(map_sql_err)?;
    if kind != TYPE_STRING {
        return Ok(StringRow::WrongType);
    }
    let expires_at: Option<i64> = row.try_get("expires_at_ms").map_err(map_sql_err)?;
    let expired = expires_at.is_some_and(|deadline| deadline <= now);
    if expired {
        delete_key_all(pool, key).await?;
        return Ok(StringRow::Missing);
    }
    let stored: Option<Vec<u8>> = row.try_get("r_value").map_err(map_sql_err)?;
    Ok(StringRow::Value {
        data: stored.unwrap_or_default(),
        expires_at,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `i8` parses as a signed, 8-bit field.
    #[test]
    fn parse_type_signed() {
        let TypeSpec { signed, bits } = parse_type(&Bytes::from_static(b"i8")).unwrap();
        assert!(signed);
        assert_eq!(bits, 8);
    }

    /// A nibble written at bit 0 reads back unchanged.
    #[test]
    fn bitfield_read_write() {
        let mut buf = vec![0u8; 1];
        write_bits(&mut buf, 0, 4, 0b1010);
        assert_eq!(read_bits(&buf, 0, 4), 0b1010);
    }
}