use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;
use crate::handlers::util::{arg_as_bytes, arg_as_i64, wrong_arity, wrong_type};
use crate::state::{AppState, SessionHandle, now_ms};
use crate::storage::{TYPE_STRING, delete_key_all, map_sql_err};
/// The outcome of looking up a key and interpreting it as a string value.
#[derive(Debug, Clone, PartialEq, Eq)]
enum StringRow {
    /// No row exists for the key, or the row had already expired.
    Missing,
    /// A row exists, but its stored type is not the string type.
    WrongType,
    /// A live string value.
    Value {
        /// The raw stored bytes (empty when the stored value is NULL).
        data: Vec<u8>,
        /// The key's expiry as a millisecond timestamp, if it has one.
        expires_at: Option<i64>,
    },
}
/// Handles `GETBIT key offset`.
///
/// Replies with the bit (0 or 1) stored at `offset` in the string at `key`. Bits are numbered
/// from the most significant bit of the first byte. A missing key, or an offset past the end of
/// the stored value, reads as 0. A non-string key yields a WRONGTYPE reply.
///
/// # Errors
/// Returns an error for:
/// - wrong arity or missing authentication;
/// - an unparsable offset, or one outside `0..=2^32 - 1`;
/// - a storage failure.
pub async fn getbit(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Redis-compatible cap on bit offsets (a 512 MiB string). It also guarantees that
    // `offset / 8` fits in `usize` on 32-bit targets.
    const MAX_BIT_OFFSET: i64 = (1_i64 << 32) - 1;

    if cmd.args.len() != 2 {
        return Err(wrong_arity("GETBIT"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let offset = arg_as_i64(&cmd.args[1])?;
    if !(0..=MAX_BIT_OFFSET).contains(&offset) {
        return Err(RespError::invalid_data(
            "ERR bit offset is not an integer or out of range",
        ));
    }
    let now = now_ms() as i64;
    match fetch_string(auth.pool.as_ref(), key, now).await? {
        StringRow::Missing => Ok(Value::Integer(0)),
        StringRow::WrongType => Ok(wrong_type()),
        StringRow::Value { data, .. } => {
            let byte_index = (offset / 8) as usize;
            let bit_index = 7 - (offset % 8) as u32;
            // Bytes beyond the stored value are implicitly zero.
            let bit = data
                .get(byte_index)
                .map_or(0, |&byte| (byte >> bit_index) & 1);
            Ok(Value::Integer(i64::from(bit)))
        }
    }
}
/// Handles `SETBIT key offset value`.
///
/// Sets or clears the bit at `offset` in the string at `key`. Bits are numbered from the most
/// significant bit of the first byte. If `offset` is past the end, the value is zero-extended. A
/// missing key starts as an empty string, and an existing expiry is written back unchanged.
///
/// Replies with the bit's previous value, or a WRONGTYPE reply when `key` holds a non-string.
///
/// # Errors
/// Returns an error for:
/// - wrong arity or missing authentication;
/// - an offset outside `0..=2^32 - 1`;
/// - a bit value other than 0 or 1;
/// - a storage failure.
pub async fn setbit(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Redis-compatible cap on bit offsets (a 512 MiB string). Without it, a client-supplied
    // offset could make `resize` below attempt an enormous allocation, or overflow `usize` on
    // 32-bit targets.
    const MAX_BIT_OFFSET: i64 = (1_i64 << 32) - 1;

    if cmd.args.len() != 3 {
        return Err(wrong_arity("SETBIT"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let offset = arg_as_i64(&cmd.args[1])?;
    if !(0..=MAX_BIT_OFFSET).contains(&offset) {
        return Err(RespError::invalid_data(
            "ERR bit offset is not an integer or out of range",
        ));
    }
    let bit_value = arg_as_i64(&cmd.args[2])?;
    if bit_value != 0 && bit_value != 1 {
        return Err(RespError::invalid_data(
            "ERR bit is not an integer or out of range",
        ));
    }
    let now = now_ms() as i64;
    let (mut data, expires_at) = match fetch_string(auth.pool.as_ref(), key, now).await? {
        StringRow::Missing => (Vec::new(), None),
        StringRow::WrongType => return Ok(wrong_type()),
        StringRow::Value { data, expires_at } => (data, expires_at),
    };
    let byte_index = (offset / 8) as usize;
    if data.len() <= byte_index {
        data.resize(byte_index + 1, 0);
    }
    let mask: u8 = 0x80 >> (offset % 8);
    let prev = i64::from((data[byte_index] & mask) != 0);
    if bit_value == 1 {
        data[byte_index] |= mask;
    } else {
        data[byte_index] &= !mask;
    }
    sqlx::query(
        "INSERT INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) VALUES (?, ?, ?, ?, ?) \
         ON DUPLICATE KEY UPDATE r_value = VALUES(r_value), r_len = VALUES(r_len), expires_at_ms = VALUES(expires_at_ms)",
    )
    .bind(key.as_ref())
    .bind(TYPE_STRING)
    .bind(data.as_slice())
    .bind(data.len() as i64)
    .bind(expires_at)
    .execute(auth.pool.as_ref())
    .await
    .map_err(map_sql_err)?;
    Ok(Value::Integer(prev))
}
/// Handles `BITCOUNT key [start end]`.
///
/// Counts the set bits in the string at `key`, optionally restricted to the inclusive byte
/// range `[start, end]`. Negative indices count from the end. A missing key counts as 0, and a
/// non-string key yields a WRONGTYPE reply.
///
/// # Errors
/// Returns an error for:
/// - an argument count other than 1 or 3;
/// - missing authentication;
/// - unparsable range bounds;
/// - a storage failure.
pub async fn bitcount(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let argc = cmd.args.len();
    if !matches!(argc, 1 | 3) {
        return Err(wrong_arity("BITCOUNT"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let data = match fetch_string(auth.pool.as_ref(), key, now_ms() as i64).await? {
        StringRow::Value { data, .. } => data,
        StringRow::WrongType => return Ok(wrong_type()),
        StringRow::Missing => return Ok(Value::Integer(0)),
    };
    // The range is parsed only after the key has been found, as before.
    let (start, end) = match argc {
        3 => (arg_as_i64(&cmd.args[1])?, arg_as_i64(&cmd.args[2])?),
        _ => (0, data.len() as i64 - 1),
    };
    let (_, selected) = slice_range(&data, start, end);
    let ones: u64 = selected.iter().map(|b| u64::from(b.count_ones())).sum();
    Ok(Value::Integer(ones as i64))
}
/// Handles `BITPOS key bit [start [end]]`.
///
/// Replies with the absolute position of the first bit equal to `bit` (0 or 1) within the
/// inclusive byte range `[start, end]`. Negative indices count from the end, and the default
/// range is the whole value.
///
/// Special cases:
/// - A missing key behaves like an infinite run of zero bits: the reply is 0 for `bit == 0` and
///   -1 for `bit == 1`.
/// - When searching for a 0 without an explicit `end` and the range is all ones, the reply is
///   the first bit past the range.
/// - With an explicit `end` and no match, or for an empty range, the reply is -1.
/// - A non-string key yields a WRONGTYPE reply.
///
/// # Errors
/// Returns an error for:
/// - wrong arity or missing authentication;
/// - unparsable arguments;
/// - a `bit` other than 0 or 1;
/// - a storage failure.
pub async fn bitpos(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() < 2 || cmd.args.len() > 4 {
        return Err(wrong_arity("BITPOS"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let bit = arg_as_i64(&cmd.args[1])?;
    if bit != 0 && bit != 1 {
        return Err(RespError::invalid_data(
            "ERR bit is not an integer or out of range",
        ));
    }
    let now = now_ms() as i64;
    let data = match fetch_string(auth.pool.as_ref(), key, now).await? {
        StringRow::Missing => {
            return Ok(Value::Integer(if bit == 0 { 0 } else { -1 }));
        }
        StringRow::WrongType => return Ok(wrong_type()),
        StringRow::Value { data, .. } => data,
    };
    let end_given = cmd.args.len() == 4;
    let start = if cmd.args.len() >= 3 {
        arg_as_i64(&cmd.args[2])?
    } else {
        0
    };
    let end = if end_given {
        arg_as_i64(&cmd.args[3])?
    } else {
        data.len() as i64 - 1
    };
    let (start_idx, range) = slice_range(&data, start, end);
    Ok(Value::Integer(first_bit_position(
        range,
        start_idx,
        bit == 1,
        end_given,
    )))
}

/// Finds the first bit equal to `want_one` in `range`, which begins at byte `start_idx` of the
/// full value, and returns its absolute bit position (or -1).
///
/// An empty range holds neither a 0 nor a 1, so it yields -1. When looking for a 0 with no
/// explicit end, the bits right of the range count as zero padding, so an all-ones range yields
/// the first bit past it. With an explicit end there is no such padding, and the result is -1.
fn first_bit_position(range: &[u8], start_idx: usize, want_one: bool, end_given: bool) -> i64 {
    if range.is_empty() {
        return -1;
    }
    // Bytes equal to `skip` cannot contain the wanted bit.
    let skip: u8 = if want_one { 0x00 } else { 0xFF };
    match range.iter().position(|&b| b != skip) {
        Some(i) => {
            let byte = range[i];
            let within = if want_one {
                byte.leading_zeros()
            } else {
                (!byte).leading_zeros()
            };
            ((start_idx + i) as i64) * 8 + i64::from(within)
        }
        None if !want_one && !end_given => ((start_idx + range.len()) as i64) * 8,
        None => -1,
    }
}
pub async fn bitop(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.len() < 3 {
return Err(wrong_arity("BITOP"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let op = arg_as_bytes(&cmd.args[0])?;
let dest = arg_as_bytes(&cmd.args[1])?;
let keys = cmd
.args
.iter()
.skip(2)
.map(arg_as_bytes)
.collect::<Result<Vec<_>, _>>()?;
let mut op_upper = op.to_vec();
for b in &mut op_upper {
b.make_ascii_uppercase();
}
if op_upper.as_slice() == b"NOT" && keys.len() != 1 {
return Err(wrong_arity("BITOP"));
}
let now = now_ms() as i64;
let mut values = Vec::new();
let mut max_len = 0usize;
for key in &keys {
match fetch_string(auth.pool.as_ref(), key, now).await? {
StringRow::Missing => values.push(Vec::new()),
StringRow::WrongType => return Ok(wrong_type()),
StringRow::Value { data, .. } => {
max_len = max_len.max(data.len());
values.push(data);
}
}
}
let result = match op_upper.as_slice() {
b"NOT" => {
let src = &values[0];
let mut out = Vec::with_capacity(src.len());
for &b in src {
out.push(!b);
}
out
}
b"AND" | b"OR" | b"XOR" => {
let mut out = vec![0u8; max_len];
for i in 0..max_len {
let mut acc = if op_upper.as_slice() == b"AND" {
0xFF
} else {
0x00
};
for value in &values {
let b = if i < value.len() { value[i] } else { 0 };
acc = match op_upper.as_slice() {
b"AND" => acc & b,
b"OR" => acc | b,
_ => acc ^ b,
};
}
out[i] = acc;
}
out
}
_ => return Err(RespError::invalid_data("ERR syntax error")),
};
delete_key_all(auth.pool.as_ref(), dest).await?;
sqlx::query(
"INSERT INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) VALUES (?, ?, ?, ?, NULL)",
)
.bind(dest.as_ref())
.bind(TYPE_STRING)
.bind(result.as_slice())
.bind(result.len() as i64)
.execute(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
Ok(Value::Integer(result.len() as i64))
}
/// Resolves a Redis-style inclusive byte range `[start, end]` against `data`.
///
/// Negative indices count back from the end (`-1` is the last byte). The start is clamped to 0
/// and the end to the last byte. Returns the resolved start index together with the selected
/// bytes, or `(0, &[])` when `data` is empty or the range selects nothing.
fn slice_range(data: &[u8], start: i64, end: i64) -> (usize, &[u8]) {
    let len = data.len() as i64;
    if len == 0 {
        return (0, &[]);
    }
    let resolve = |index: i64| if index < 0 { len + index } else { index };
    let first = resolve(start).max(0);
    let last = resolve(end).min(len - 1);
    if first >= len || first > last {
        return (0, &[]);
    }
    let first = first as usize;
    (first, &data[first..=last as usize])
}
/// Loads the `redis_kv` row for `key` and classifies it as a string value.
///
/// Returns:
/// - `StringRow::Missing` when no row exists, or when `expires_at_ms` is at or before `now`.
///   In the expired case the key is first removed with `delete_key_all`.
/// - `StringRow::WrongType` when `r_type` is not `TYPE_STRING`.
/// - Otherwise `StringRow::Value`, holding the stored bytes (a NULL `r_value` becomes an empty
///   vector) and the expiry.
///
/// # Errors
/// Propagates SQL failures (mapped through `map_sql_err`) and errors from `delete_key_all`.
async fn fetch_string(
    pool: &sqlx::MySqlPool,
    key: &Bytes,
    now: i64,
) -> Result<StringRow, RespError> {
    let maybe_row =
        sqlx::query("SELECT r_type, r_value, expires_at_ms FROM redis_kv WHERE r_key = ?")
            .bind(key.as_ref())
            .fetch_optional(pool)
            .await
            .map_err(map_sql_err)?;
    let row = match maybe_row {
        Some(row) => row,
        None => return Ok(StringRow::Missing),
    };
    let kind: u8 = row.try_get("r_type").map_err(map_sql_err)?;
    if kind != TYPE_STRING {
        return Ok(StringRow::WrongType);
    }
    let expires_at: Option<i64> = row.try_get("expires_at_ms").map_err(map_sql_err)?;
    if expires_at.is_some_and(|deadline| deadline <= now) {
        // The key has expired: delete it and report it as absent.
        delete_key_all(pool, key).await?;
        return Ok(StringRow::Missing);
    }
    let stored: Option<Vec<u8>> = row.try_get("r_value").map_err(map_sql_err)?;
    Ok(StringRow::Value {
        data: stored.unwrap_or_default(),
        expires_at,
    })
}