use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::{MySqlPool, Row};
use std::sync::Arc;
use crate::handlers::util::{
arg_as_bytes, arg_as_i64, glob_match, random_index, shuffle_slice, wrong_arity, wrong_type,
};
use crate::state::{AppState, SessionHandle, now_ms};
use crate::storage::{TYPE_SET, delete_key_in_tx, load_meta, map_sql_err};
/// `SADD key member [member ...]`: adds members to the set stored at `key`.
///
/// Creates the set when `key` is absent and replies with the number of members that were
/// newly inserted; members already in the set are not counted. Replies with a WRONGTYPE
/// error when `key` holds a non-set value.
///
/// # Errors
/// Wrong arity with fewer than two arguments, `NoAuth` for an unauthenticated session, and a
/// mapped SQL error when any statement fails (the transaction is then rolled back on drop).
pub async fn sadd(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Rows per INSERT statement. Each row binds two parameters, so every statement stays far
    // below MySQL's 65,535-placeholder limit for prepared statements even for huge SADDs.
    const INSERT_CHUNK: usize = 1000;

    if cmd.args.len() < 2 {
        return Err(wrong_arity("SADD"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;
    // NOTE(review): this type check runs outside the write transaction below, so a concurrent
    // writer could change the key's type in between. Also, if load_meta treats an expired key
    // as absent without removing its rows, leftover members would be merged into this set and
    // the ON DUPLICATE KEY branch would keep the old row's r_type/expires_at_ms — confirm.
    if let Some(meta) = load_meta(auth.pool.as_ref(), key, now).await?
        && meta.r_type != TYPE_SET
    {
        return Ok(wrong_type());
    }
    let mut members = Vec::with_capacity(cmd.args.len() - 1);
    for arg in cmd.args.iter().skip(1) {
        members.push(arg_as_bytes(arg)?.clone());
    }
    let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
    let mut added: i64 = 0;
    for chunk in members.chunks(INSERT_CHUNK) {
        // INSERT IGNORE skips members that already exist (presumably via a unique index on
        // (r_key, member)), including duplicates earlier in this same command, so
        // rows_affected counts only genuinely new members.
        let mut qb = sqlx::QueryBuilder::new("INSERT IGNORE INTO redis_set (r_key, member) ");
        qb.push_values(chunk, |mut row, member| {
            row.push_bind(key.as_ref());
            row.push_bind(member.as_ref());
        });
        let res = qb.build().execute(&mut *tx).await.map_err(map_sql_err)?;
        added += res.rows_affected() as i64;
    }
    if added > 0 {
        // Create the metadata row or bump the cached cardinality of an existing set.
        sqlx::query(
            "INSERT INTO redis_kv (r_key, r_type, r_len, expires_at_ms) \
             VALUES (?, ?, ?, NULL) \
             ON DUPLICATE KEY UPDATE r_len = r_len + VALUES(r_len)",
        )
        .bind(key.as_ref())
        .bind(TYPE_SET)
        .bind(added)
        .execute(&mut *tx)
        .await
        .map_err(map_sql_err)?;
    }
    tx.commit().await.map_err(map_sql_err)?;
    Ok(Value::Integer(added))
}
/// `SREM key member [member ...]`: removes the given members from the set at `key`.
///
/// Replies with the number of members actually removed (0 for a missing key). Replies with a
/// WRONGTYPE error when `key` holds a non-set value. When the cached cardinality reaches zero
/// the key's metadata row is deleted as well.
///
/// # Errors
/// Wrong arity with fewer than two arguments, `NoAuth` for an unauthenticated session, and a
/// mapped SQL error when any statement fails.
pub async fn srem(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    // Members per DELETE statement; keeps each `IN (...)` list far below MySQL's 65,535
    // placeholder limit for prepared statements.
    const DELETE_CHUNK: usize = 1000;

    if cmd.args.len() < 2 {
        return Err(wrong_arity("SREM"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;
    let meta = load_meta(auth.pool.as_ref(), key, now).await?;
    let Some(meta) = meta else {
        return Ok(Value::Integer(0));
    };
    if meta.r_type != TYPE_SET {
        return Ok(wrong_type());
    }
    // Validate every member argument before opening a transaction.
    let mut members = Vec::with_capacity(cmd.args.len() - 1);
    for arg in cmd.args.iter().skip(1) {
        members.push(arg_as_bytes(arg)?);
    }
    let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
    let mut removed: i64 = 0;
    for chunk in members.chunks(DELETE_CHUNK) {
        let mut qb = sqlx::QueryBuilder::new("DELETE FROM redis_set WHERE r_key = ");
        qb.push_bind(key.as_ref());
        qb.push(" AND member IN (");
        let mut separated = qb.separated(", ");
        for member in chunk.iter().copied() {
            separated.push_bind(member.as_ref());
        }
        qb.push(")");
        // Duplicate members in one command delete a single row, so rows_affected is the
        // true number of removals.
        let res = qb.build().execute(&mut *tx).await.map_err(map_sql_err)?;
        removed += res.rows_affected() as i64;
    }
    if removed > 0 {
        let res =
            sqlx::query("UPDATE redis_kv SET r_len = r_len - ? WHERE r_key = ? AND r_type = ?")
                .bind(removed)
                .bind(key.as_ref())
                .bind(TYPE_SET)
                .execute(&mut *tx)
                .await
                .map_err(map_sql_err)?;
        if res.rows_affected() > 0 {
            // An emptied set must not leave a dangling metadata row behind.
            sqlx::query("DELETE FROM redis_kv WHERE r_key = ? AND r_type = ? AND r_len <= 0")
                .bind(key.as_ref())
                .bind(TYPE_SET)
                .execute(&mut *tx)
                .await
                .map_err(map_sql_err)?;
        }
    }
    tx.commit().await.map_err(map_sql_err)?;
    Ok(Value::Integer(removed))
}
/// `SISMEMBER key member`: replies `1` when `member` belongs to the set at `key`, else `0`.
///
/// A missing key behaves like an empty set; a key holding another type yields WRONGTYPE.
pub async fn sismember(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    if command.args.len() != 2 {
        return Err(wrong_arity("SISMEMBER"));
    }
    let auth = sess.auth().await.ok_or(RespError::NoAuth)?;
    app.pools.touch(&auth.user);
    let (key, candidate) = (
        arg_as_bytes(&command.args[0])?,
        arg_as_bytes(&command.args[1])?,
    );
    let now = now_ms() as i64;
    match load_meta(auth.pool.as_ref(), key, now).await? {
        None => Ok(Value::Integer(0)),
        Some(meta) if meta.r_type != TYPE_SET => Ok(wrong_type()),
        Some(_) => {
            let hit = sqlx::query("SELECT 1 FROM redis_set WHERE r_key = ? AND member = ?")
                .bind(key.as_ref())
                .bind(candidate.as_ref())
                .fetch_optional(auth.pool.as_ref())
                .await
                .map_err(map_sql_err)?;
            Ok(Value::Integer(i64::from(hit.is_some())))
        }
    }
}
/// `SMEMBERS key`: replies with every member of the set at `key`.
///
/// A missing key replies with an empty array; a key holding another type yields WRONGTYPE.
/// The query has no ORDER BY, so member order is whatever the database returns.
pub async fn smembers(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    if command.args.len() != 1 {
        return Err(wrong_arity("SMEMBERS"));
    }
    let auth = sess.auth().await.ok_or(RespError::NoAuth)?;
    app.pools.touch(&auth.user);
    let key = arg_as_bytes(&command.args[0])?;
    let now = now_ms() as i64;
    match load_meta(auth.pool.as_ref(), key, now).await? {
        None => return Ok(Value::Array(Vec::new())),
        Some(meta) if meta.r_type != TYPE_SET => return Ok(wrong_type()),
        Some(_) => {}
    }
    let rows = sqlx::query("SELECT member FROM redis_set WHERE r_key = ?")
        .bind(key.as_ref())
        .fetch_all(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;
    let members = rows
        .iter()
        .map(|row| {
            row.try_get::<Vec<u8>, _>("member")
                .map(|raw| Value::Bulk(Bytes::from(raw)))
                .map_err(map_sql_err)
        })
        .collect::<Result<Vec<_>, _>>()?;
    Ok(Value::Array(members))
}
/// `SCARD key`: replies with the cardinality of the set at `key`.
///
/// The count comes from the cached `r_len` metadata rather than counting member rows. A
/// missing key replies `0`; a key holding another type yields WRONGTYPE.
pub async fn scard(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    if command.args.len() != 1 {
        return Err(wrong_arity("SCARD"));
    }
    let auth = sess.auth().await.ok_or(RespError::NoAuth)?;
    app.pools.touch(&auth.user);
    let key = arg_as_bytes(&command.args[0])?;
    let now = now_ms() as i64;
    let reply = match load_meta(auth.pool.as_ref(), key, now).await? {
        Some(meta) if meta.r_type == TYPE_SET => Value::Integer(meta.r_len),
        Some(_) => wrong_type(),
        None => Value::Integer(0),
    };
    Ok(reply)
}
pub async fn srandmember(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.is_empty() {
return Err(wrong_arity("SRANDMEMBER"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let count = if cmd.args.len() >= 2 {
Some(arg_as_i64(&cmd.args[1])?)
} else {
None
};
let key = arg_as_bytes(&cmd.args[0])?;
let now = now_ms() as i64;
let meta = load_meta(auth.pool.as_ref(), key, now).await?;
let Some(meta) = meta else {
return Ok(if count.is_some() {
Value::Array(Vec::new())
} else {
Value::Null
});
};
if meta.r_type != TYPE_SET {
return Ok(wrong_type());
}
let len = meta.r_len;
if len <= 0 {
return Ok(if count.is_some() {
Value::Array(Vec::new())
} else {
Value::Null
});
}
match count {
None => {
let offset = random_index(len);
let row = sqlx::query(
"SELECT member FROM redis_set WHERE r_key = ? ORDER BY member LIMIT 1 OFFSET ?",
)
.bind(key.as_ref())
.bind(offset)
.fetch_optional(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let Some(row) = row else {
return Ok(Value::Null);
};
let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
Ok(Value::Bulk(Bytes::from(member)))
}
Some(n) if n >= 0 => {
if n == 0 {
return Ok(Value::Array(Vec::new()));
}
let fetch_all = n >= len;
let rows = if fetch_all {
sqlx::query("SELECT member FROM redis_set WHERE r_key = ? ORDER BY member")
.bind(key.as_ref())
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?
} else {
let range = len - n + 1;
let offset = random_index(range);
sqlx::query(
"SELECT member FROM redis_set WHERE r_key = ? ORDER BY member LIMIT ? OFFSET ?",
)
.bind(key.as_ref())
.bind(n)
.bind(offset)
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?
};
let mut out = Vec::with_capacity(rows.len());
for row in rows {
let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
out.push(Bytes::from(member));
}
if fetch_all {
shuffle_slice(&mut out);
}
Ok(Value::Array(out.into_iter().map(Value::Bulk).collect()))
}
Some(n) => {
let n = n.abs();
let mut out = Vec::new();
for _ in 0..n {
let offset = random_index(len);
let row = sqlx::query(
"SELECT member FROM redis_set WHERE r_key = ? ORDER BY member LIMIT 1 OFFSET ?",
)
.bind(key.as_ref())
.bind(offset)
.fetch_optional(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let Some(row) = row else {
break;
};
let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
out.push(Value::Bulk(Bytes::from(member)));
}
Ok(Value::Array(out))
}
}
}
pub async fn spop(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.is_empty() {
return Err(wrong_arity("SPOP"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let now = now_ms() as i64;
let meta = load_meta(auth.pool.as_ref(), key, now).await?;
let Some(meta) = meta else {
return Ok(Value::Null);
};
if meta.r_type != TYPE_SET {
return Ok(wrong_type());
}
let count = if cmd.args.len() >= 2 {
arg_as_i64(&cmd.args[1])?
} else {
1
};
if count <= 0 {
return Ok(Value::Array(Vec::new()));
}
let len = meta.r_len;
if len <= 0 {
return Ok(Value::Null);
}
let fetch_all = count >= len;
let rows = if fetch_all {
sqlx::query("SELECT member FROM redis_set WHERE r_key = ? ORDER BY member")
.bind(key.as_ref())
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?
} else {
let range = len - count + 1;
let offset = random_index(range);
sqlx::query("SELECT member FROM redis_set WHERE r_key = ? ORDER BY member LIMIT ? OFFSET ?")
.bind(key.as_ref())
.bind(count)
.bind(offset)
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?
};
if rows.is_empty() {
return Ok(Value::Null);
}
let mut members = Vec::with_capacity(rows.len());
for row in &rows {
let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
members.push(Bytes::from(member));
}
if fetch_all {
shuffle_slice(&mut members);
}
let mut qb = sqlx::QueryBuilder::new("DELETE FROM redis_set WHERE r_key = ");
qb.push_bind(key.as_ref());
qb.push(" AND member IN (");
let mut separated = qb.separated(", ");
for member in &members {
separated.push_bind(member.as_ref());
}
qb.push(")");
let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
qb.build().execute(&mut *tx).await.map_err(map_sql_err)?;
let removed = members.len() as i64;
let res = sqlx::query("UPDATE redis_kv SET r_len = r_len - ? WHERE r_key = ? AND r_type = ?")
.bind(removed)
.bind(key.as_ref())
.bind(TYPE_SET)
.execute(&mut *tx)
.await
.map_err(map_sql_err)?;
if res.rows_affected() > 0 {
sqlx::query("DELETE FROM redis_kv WHERE r_key = ? AND r_type = ? AND r_len <= 0")
.bind(key.as_ref())
.bind(TYPE_SET)
.execute(&mut *tx)
.await
.map_err(map_sql_err)?;
}
tx.commit().await.map_err(map_sql_err)?;
if count == 1 {
return Ok(Value::Bulk(members.remove(0)));
}
Ok(Value::Array(members.into_iter().map(Value::Bulk).collect()))
}
/// `SINTER key [key ...]`: replies with the members present in every given set.
pub async fn sinter(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    // Reply with the members instead of writing them to a destination key.
    const STORE_RESULT: bool = false;
    set_op(command, app, sess, SetOp::Inter, STORE_RESULT).await
}
/// `SUNION key [key ...]`: replies with the members present in any of the given sets.
pub async fn sunion(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    // Reply with the members instead of writing them to a destination key.
    const STORE_RESULT: bool = false;
    set_op(command, app, sess, SetOp::Union, STORE_RESULT).await
}
/// `SDIFF key [key ...]`: replies with the members of the first set absent from all others.
pub async fn sdiff(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    // Reply with the members instead of writing them to a destination key.
    const STORE_RESULT: bool = false;
    set_op(command, app, sess, SetOp::Diff, STORE_RESULT).await
}
/// `SINTERSTORE destination key [key ...]`: stores the intersection in `destination` and
/// replies with its cardinality.
pub async fn sinterstore(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    // The first argument is the destination key.
    const STORE_RESULT: bool = true;
    set_op(command, app, sess, SetOp::Inter, STORE_RESULT).await
}
/// `SUNIONSTORE destination key [key ...]`: stores the union in `destination` and replies
/// with its cardinality.
pub async fn sunionstore(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    // The first argument is the destination key.
    const STORE_RESULT: bool = true;
    set_op(command, app, sess, SetOp::Union, STORE_RESULT).await
}
/// `SDIFFSTORE destination key [key ...]`: stores the difference in `destination` and
/// replies with its cardinality.
pub async fn sdiffstore(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    // The first argument is the destination key.
    const STORE_RESULT: bool = true;
    set_op(command, app, sess, SetOp::Diff, STORE_RESULT).await
}
/// Which multi-set combination `set_op` computes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SetOp {
    /// Members present in every source set (SINTER / SINTERSTORE).
    Inter,
    /// Members present in at least one source set (SUNION / SUNIONSTORE).
    Union,
    /// Members of the first set absent from all later sets (SDIFF / SDIFFSTORE).
    Diff,
}
/// Shared implementation of SINTER/SUNION/SDIFF and their `*STORE` variants.
///
/// Without `store`, every argument is a source key and the reply is the resulting members.
/// With `store`, the first argument is the destination. It is rewritten via `store_set`
/// (which clears it first, so an empty result leaves it cleared), and the reply is the
/// result's cardinality.
///
/// Missing source keys count as empty sets. Every source is type-checked before anything is
/// queried or written, and any non-set source yields a WRONGTYPE reply.
///
/// # Errors
/// Wrong arity (too few keys), `NoAuth`, and mapped SQL errors from the helpers.
async fn set_op(
    cmd: resp_async::Command,
    state: Arc<AppState>,
    session: Arc<crate::state::Session>,
    op: SetOp,
    store: bool,
) -> Result<Value, RespError> {
    if cmd.args.len() < if store { 2 } else { 1 } {
        return Err(wrong_arity(match (&op, store) {
            (SetOp::Inter, true) => "SINTERSTORE",
            (SetOp::Union, true) => "SUNIONSTORE",
            (SetOp::Diff, true) => "SDIFFSTORE",
            (SetOp::Inter, false) => "SINTER",
            (SetOp::Union, false) => "SUNION",
            (SetOp::Diff, false) => "SDIFF",
        }));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let (dest, keys) = if store {
        let dest = arg_as_bytes(&cmd.args[0])?.clone();
        let mut keys = Vec::with_capacity(cmd.args.len() - 1);
        for arg in cmd.args.iter().skip(1) {
            keys.push(arg_as_bytes(arg)?.clone());
        }
        (Some(dest), keys)
    } else {
        let mut keys = Vec::with_capacity(cmd.args.len());
        for arg in cmd.args.iter() {
            keys.push(arg_as_bytes(arg)?.clone());
        }
        (None, keys)
    };
    let now = now_ms() as i64;
    let mut present_keys = Vec::with_capacity(keys.len());
    let mut first_present = false;
    for (idx, key) in keys.iter().enumerate() {
        if let Some(meta) = load_meta(auth.pool.as_ref(), key, now).await? {
            if meta.r_type != TYPE_SET {
                return Ok(wrong_type());
            }
            first_present |= idx == 0;
            present_keys.push(key.clone());
        }
    }
    // Results that are empty regardless of the other sets' contents:
    // - SINTER: any missing source is an empty set, so the intersection is empty.
    // - SDIFF: the result is a subset of the first source. If that key is missing, the result
    //   is empty. It must not be dropped from `present_keys`, which would silently promote
    //   the second key to the set being subtracted from.
    let known_empty = match op {
        SetOp::Inter => present_keys.len() != keys.len(),
        SetOp::Diff => !first_present,
        SetOp::Union => false,
    };
    let members = if known_empty {
        Vec::new()
    } else {
        match op {
            SetOp::Inter => query_intersection(auth.pool.as_ref(), &present_keys).await?,
            SetOp::Union => query_union(auth.pool.as_ref(), &present_keys).await?,
            SetOp::Diff => query_diff(auth.pool.as_ref(), &present_keys).await?,
        }
    };
    if let Some(dest) = dest {
        // Always go through store_set, even for an empty result, so the destination's old
        // value is removed (Redis deletes the destination when the result is empty).
        store_set(auth.pool.as_ref(), &dest, &members).await?;
        return Ok(Value::Integer(members.len() as i64));
    }
    Ok(Value::Array(members.into_iter().map(Value::Bulk).collect()))
}
/// Returns the members present in every set named in `keys`.
///
/// Duplicate keys are collapsed before querying. The SQL counts distinct `r_key` values per
/// member, so comparing that count with the raw `keys.len()` would make, for example,
/// `SINTER a a` return nothing. An empty `keys` slice yields an empty result without a query.
///
/// # Errors
/// Returns a mapped SQL error when the query or row decoding fails.
async fn query_intersection(pool: &MySqlPool, keys: &[Bytes]) -> Result<Vec<Bytes>, RespError> {
    let mut distinct: Vec<&Bytes> = keys.iter().collect();
    distinct.sort_unstable();
    distinct.dedup();
    if distinct.is_empty() {
        return Ok(Vec::new());
    }
    let mut qb = sqlx::QueryBuilder::new("SELECT member FROM redis_set WHERE r_key IN (");
    let mut separated = qb.separated(", ");
    for key in distinct.iter().copied() {
        separated.push_bind(key.as_ref());
    }
    // A member belongs to the intersection iff it appears under every distinct key.
    qb.push(") GROUP BY member HAVING COUNT(DISTINCT r_key) = ");
    qb.push_bind(distinct.len() as i64);
    let rows = qb.build().fetch_all(pool).await.map_err(map_sql_err)?;
    let mut out = Vec::with_capacity(rows.len());
    for row in rows {
        let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
        out.push(Bytes::from(member));
    }
    Ok(out)
}
/// Returns every distinct member found in any of the sets named in `keys`.
///
/// An empty `keys` slice yields an empty result without issuing a query.
async fn query_union(pool: &MySqlPool, keys: &[Bytes]) -> Result<Vec<Bytes>, RespError> {
    if keys.is_empty() {
        return Ok(Vec::new());
    }
    let mut qb = sqlx::QueryBuilder::new("SELECT DISTINCT member FROM redis_set WHERE r_key IN (");
    {
        let mut key_list = qb.separated(", ");
        keys.iter().for_each(|key| {
            key_list.push_bind(key.as_ref());
        });
    }
    qb.push(")");
    let rows = qb.build().fetch_all(pool).await.map_err(map_sql_err)?;
    rows.iter()
        .map(|row| {
            row.try_get::<Vec<u8>, _>("member")
                .map(Bytes::from)
                .map_err(map_sql_err)
        })
        .collect()
}
/// Returns the members of the set `keys[0]` that appear in none of `keys[1..]`.
///
/// With a single key this is just that set's members. An empty `keys` slice yields an
/// empty result without issuing a query.
async fn query_diff(pool: &MySqlPool, keys: &[Bytes]) -> Result<Vec<Bytes>, RespError> {
    let Some((base, subtrahends)) = keys.split_first() else {
        return Ok(Vec::new());
    };
    let mut qb = sqlx::QueryBuilder::new("SELECT member FROM redis_set WHERE r_key = ");
    qb.push_bind(base.as_ref());
    if !subtrahends.is_empty() {
        // Exclude anything found in the UNION of the remaining sets.
        qb.push(" AND member NOT IN (");
        for (idx, key) in subtrahends.iter().enumerate() {
            if idx > 0 {
                qb.push(" UNION ");
            }
            qb.push("SELECT member FROM redis_set WHERE r_key = ");
            qb.push_bind(key.as_ref());
        }
        qb.push(")");
    }
    let rows = qb.build().fetch_all(pool).await.map_err(map_sql_err)?;
    rows.iter()
        .map(|row| {
            row.try_get::<Vec<u8>, _>("member")
                .map(Bytes::from)
                .map_err(map_sql_err)
        })
        .collect()
}
/// Atomically replaces whatever is stored at `key` with a set holding exactly `members`.
///
/// The old value is removed via `delete_key_in_tx` in the same transaction. When `members`
/// is empty, nothing is re-created and the key is simply left cleared. `members` is assumed
/// to be duplicate-free (the set-op queries produce distinct members); a duplicate would make
/// the plain INSERT fail on the set's unique key — TODO confirm the schema.
///
/// # Errors
/// Returns a mapped SQL error if any statement or the commit fails.
async fn store_set(pool: &MySqlPool, key: &Bytes, members: &[Bytes]) -> Result<(), RespError> {
    // Rows per INSERT statement. Each row binds two parameters, so large SUNIONSTORE/SDIFFSTORE
    // results stay below MySQL's 65,535-placeholder limit for prepared statements.
    const INSERT_CHUNK: usize = 1000;

    let mut tx = pool.begin().await.map_err(map_sql_err)?;
    delete_key_in_tx(&mut tx, key).await?;
    if !members.is_empty() {
        sqlx::query(
            "INSERT INTO redis_kv (r_key, r_type, r_len, expires_at_ms) VALUES (?, ?, ?, NULL)",
        )
        .bind(key.as_ref())
        .bind(TYPE_SET)
        .bind(members.len() as i64)
        .execute(&mut *tx)
        .await
        .map_err(map_sql_err)?;
        for chunk in members.chunks(INSERT_CHUNK) {
            let mut qb = sqlx::QueryBuilder::new("INSERT INTO redis_set (r_key, member) ");
            qb.push_values(chunk, |mut row, member| {
                row.push_bind(key.as_ref());
                row.push_bind(member.as_ref());
            });
            qb.build().execute(&mut *tx).await.map_err(map_sql_err)?;
        }
    }
    tx.commit().await.map_err(map_sql_err)?;
    Ok(())
}
pub async fn sscan(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.len() < 2 {
return Err(wrong_arity("SSCAN"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let now = now_ms() as i64;
let meta = load_meta(auth.pool.as_ref(), key, now).await?;
let Some(meta) = meta else {
return Ok(Value::Array(vec![
Value::Bulk(Bytes::from_static(b"0")),
Value::Array(Vec::new()),
]));
};
if meta.r_type != TYPE_SET {
return Ok(wrong_type());
}
let cursor = arg_as_bytes(&cmd.args[1])?;
let mut count = 10i64;
let mut pattern: Option<Bytes> = None;
let mut i = 2;
while i < cmd.args.len() {
let mut token = arg_as_bytes(&cmd.args[i])?.to_vec();
for b in &mut token {
b.make_ascii_uppercase();
}
match token.as_slice() {
b"MATCH" => {
let next = cmd
.args
.get(i + 1)
.ok_or_else(|| RespError::invalid_data("ERR syntax error"))?;
pattern = Some(arg_as_bytes(next)?.clone());
i += 2;
}
b"COUNT" => {
let next = cmd
.args
.get(i + 1)
.ok_or_else(|| RespError::invalid_data("ERR syntax error"))?;
count = arg_as_i64(next)?;
if count <= 0 {
return Err(RespError::invalid_data("ERR invalid COUNT"));
}
i += 2;
}
_ => return Err(RespError::invalid_data("ERR syntax error")),
}
}
let rows = sqlx::query(
"SELECT member FROM redis_set WHERE r_key = ? AND member > ? \
ORDER BY member LIMIT ?",
)
.bind(key.as_ref())
.bind(cursor.as_ref())
.bind(count)
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let mut out = Vec::new();
let mut next_cursor: Option<Bytes> = None;
for row in &rows {
let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
next_cursor = Some(Bytes::from(member.clone()));
if let Some(pattern) = &pattern
&& !glob_match(pattern.as_ref(), &member)
{
continue;
}
out.push(Value::Bulk(Bytes::from(member)));
}
let cursor_value = match next_cursor {
Some(value) if rows.len() as i64 == count => value,
_ => Bytes::from_static(b"0"),
};
Ok(Value::Array(vec![
Value::Bulk(cursor_value),
Value::Array(out),
]))
}