use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;
use std::sync::Arc;
use crate::handlers::strings;
use crate::handlers::util::{
arg_as_bytes, arg_as_i64, glob_has_wildcards, glob_match, ok, random_index, wrong_arity,
};
use crate::state::{AppState, SessionHandle, now_ms};
use crate::storage::{
TYPE_HASH, TYPE_LIST, TYPE_SET, TYPE_ZSET, delete_key_in_tx, load_meta, map_sql_err, type_name,
};
/// `TYPE key` — reply with the type name of `key` as a simple string.
///
/// The name comes from `type_name` applied to the stored `r_type`; a key that
/// `load_meta` does not return (missing, or filtered out as expired at the
/// current time) is reported as `none`.
///
/// # Errors
/// Wrong arity unless exactly one argument is given, `RespError::NoAuth` for
/// an unauthenticated session, or whatever `arg_as_bytes` / `load_meta`
/// report.
pub async fn type_cmd(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() != 1 {
        return Err(wrong_arity("TYPE"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);

    let key = arg_as_bytes(&cmd.args[0])?;
    let meta = load_meta(auth.pool.as_ref(), key, now_ms() as i64).await?;
    let name = meta.map(|m| type_name(m.r_type)).unwrap_or(b"none");
    Ok(Value::Simple(Bytes::from_static(name)))
}
pub async fn keys(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.len() != 1 {
return Err(wrong_arity("KEYS"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let pattern = arg_as_bytes(&cmd.args[0])?;
let now = now_ms() as i64;
if !glob_has_wildcards(pattern.as_ref()) {
let exists = load_meta(auth.pool.as_ref(), pattern, now).await?.is_some();
if exists {
return Ok(Value::Array(vec![Value::Bulk(pattern.clone())]));
}
return Ok(Value::Array(Vec::new()));
}
let rows = sqlx::query(
"SELECT r_key FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ?)",
)
.bind(now)
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let mut out = Vec::new();
for row in rows {
let key: Vec<u8> = row.try_get("r_key").map_err(map_sql_err)?;
if glob_match(pattern.as_ref(), &key) {
out.push(Value::Bulk(Bytes::from(key)));
}
}
Ok(Value::Array(out))
}
pub async fn scan(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.is_empty() {
return Err(wrong_arity("SCAN"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let mut cursor = arg_as_i64(&cmd.args[0])?;
if cursor < 0 {
return Err(RespError::invalid_data("ERR invalid cursor"));
}
let mut count = 10i64;
let mut pattern: Option<Bytes> = None;
let mut i = 1;
while i < cmd.args.len() {
let mut token = arg_as_bytes(&cmd.args[i])?.to_vec();
for b in &mut token {
b.make_ascii_uppercase();
}
match token.as_slice() {
b"MATCH" => {
let next = cmd
.args
.get(i + 1)
.ok_or_else(|| RespError::invalid_data("ERR syntax error"))?;
pattern = Some(arg_as_bytes(next)?.clone());
i += 2;
}
b"COUNT" => {
let next = cmd
.args
.get(i + 1)
.ok_or_else(|| RespError::invalid_data("ERR syntax error"))?;
count = arg_as_i64(next)?;
if count <= 0 {
return Err(RespError::invalid_data("ERR invalid COUNT"));
}
i += 2;
}
_ => return Err(RespError::invalid_data("ERR syntax error")),
}
}
let now = now_ms() as i64;
let rows = sqlx::query(
"SELECT r_key FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ?) \
ORDER BY r_key LIMIT ?, ?",
)
.bind(now)
.bind(cursor)
.bind(count)
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let mut out = Vec::new();
for row in &rows {
let key: Vec<u8> = row.try_get("r_key").map_err(map_sql_err)?;
if let Some(pattern) = &pattern
&& !glob_match(pattern.as_ref(), &key)
{
continue;
}
out.push(Value::Bulk(Bytes::from(key)));
}
if rows.len() < count as usize {
cursor = 0;
} else {
cursor = cursor.saturating_add(count);
}
Ok(Value::Array(vec![
Value::Bulk(Bytes::from(cursor.to_string())),
Value::Array(out),
]))
}
pub async fn randomkey(
_cmd: Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let now = now_ms() as i64;
let row = sqlx::query(
"SELECT COUNT(*) AS count FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ?)",
)
.bind(now)
.fetch_one(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let count: i64 = row.try_get("count").map_err(map_sql_err)?;
if count <= 0 {
return Ok(Value::Null);
}
let offset = random_index(count);
let row = sqlx::query(
"SELECT r_key FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ?) \
ORDER BY r_key LIMIT 1 OFFSET ?",
)
.bind(now)
.bind(offset)
.fetch_optional(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let Some(row) = row else {
return Ok(Value::Null);
};
let key: Vec<u8> = row.try_get("r_key").map_err(map_sql_err)?;
Ok(Value::Bulk(Bytes::from(key)))
}
/// `RENAME src dest` — move `src` to `dest`, replacing anything at `dest`.
///
/// Delegates to `rename_inner` with overwriting enabled.
pub async fn rename(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    const OVERWRITE_DEST: bool = true;
    rename_inner(command, app, sess, OVERWRITE_DEST).await
}
/// `RENAMENX src dest` — move `src` to `dest` only if `dest` is not already
/// present.
///
/// Delegates to `rename_inner` with overwriting disabled.
pub async fn renamenx(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    const OVERWRITE_DEST: bool = false;
    rename_inner(command, app, sess, OVERWRITE_DEST).await
}
/// Shared implementation of `RENAME` (`allow_overwrite = true`) and
/// `RENAMENX` (`allow_overwrite = false`).
///
/// Moves `src` to `dest` inside one transaction. It rewrites `r_key` on the
/// `redis_kv` row, which carries `expires_at_ms` along, and on the rows of the
/// child table(s) selected by the source's `r_type`.
///
/// Replies:
/// * `RENAME`   → `+OK`.
/// * `RENAMENX` → `1` when renamed; `0` when `dest` already holds a live key
///   or when `src == dest`.
///
/// # Errors
/// * Wrong arity unless exactly two arguments are given.
/// * `RespError::NoAuth` for an unauthenticated session.
/// * "ERR no such key" when `src` is missing or expired. This is checked
///   before the `src == dest` shortcut, as Redis does.
/// * Any SQL failure, mapped through `map_sql_err`. The transaction is then
///   dropped uncommitted.
async fn rename_inner(
    cmd: resp_async::Command,
    state: Arc<AppState>,
    session: Arc<crate::state::Session>,
    allow_overwrite: bool,
) -> Result<Value, RespError> {
    if cmd.args.len() != 2 {
        return Err(wrong_arity(if allow_overwrite {
            "RENAME"
        } else {
            "RENAMENX"
        }));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let src = arg_as_bytes(&cmd.args[0])?;
    let dest = arg_as_bytes(&cmd.args[1])?;
    let now = now_ms() as i64;

    // The source must exist even when renaming a key onto itself: Redis
    // answers `RENAME missing missing` with "no such key", not `OK`.
    let Some(src_meta) = load_meta(auth.pool.as_ref(), src, now).await? else {
        return Err(RespError::invalid_data("ERR no such key"));
    };
    if src == dest {
        return Ok(if allow_overwrite {
            ok()
        } else {
            Value::Integer(0)
        });
    }
    // NOTE(review): the existence checks use the pool, not `tx`, so a
    // concurrent writer can still race them.
    if !allow_overwrite && load_meta(auth.pool.as_ref(), dest, now).await?.is_some() {
        return Ok(Value::Integer(0));
    }

    let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
    // Clear `dest` unconditionally, not only when `load_meta` saw it live.
    // Expired keys are filtered out by the queries rather than deleted, so
    // their rows appear to remain stored. Keeping them would make the `r_key`
    // update below collide with them (assuming `r_key` is unique, likely the
    // primary key) or mix stale child rows into the renamed value.
    // Assumes `delete_key_in_tx` is a no-op for a key with no rows; confirm.
    delete_key_in_tx(&mut tx, dest).await?;
    sqlx::query("UPDATE redis_kv SET r_key = ? WHERE r_key = ?")
        .bind(dest.as_ref())
        .bind(src.as_ref())
        .execute(&mut *tx)
        .await
        .map_err(map_sql_err)?;
    match src_meta.r_type {
        TYPE_HASH => update_key_in_tx(&mut tx, "redis_hash", src, dest).await?,
        TYPE_SET => update_key_in_tx(&mut tx, "redis_set", src, dest).await?,
        TYPE_ZSET => update_key_in_tx(&mut tx, "redis_zset", src, dest).await?,
        TYPE_LIST => {
            update_key_in_tx(&mut tx, "redis_list_meta", src, dest).await?;
            update_key_in_tx(&mut tx, "redis_list", src, dest).await?;
        }
        // Other types have no child table handled here; presumably their
        // value lives entirely on the `redis_kv` row.
        _ => {}
    }
    tx.commit().await.map_err(map_sql_err)?;
    Ok(if allow_overwrite {
        ok()
    } else {
        Value::Integer(1)
    })
}
pub async fn dbsize(
_cmd: Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let now = now_ms() as i64;
let row = sqlx::query(
"SELECT COUNT(*) AS count FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ?)",
)
.bind(now)
.fetch_one(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let count: i64 = row.try_get("count").map_err(map_sql_err)?;
Ok(Value::Integer(count))
}
/// `FLUSHDB` — delete all key data reachable through the session's pool.
///
/// Arguments are ignored. The work is done by `flush_all`, which `FLUSHALL`
/// also uses.
pub async fn flushdb(
    _cmd: Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    flush_all(app, sess).await
}
/// `FLUSHALL` — delete all key data reachable through the session's pool.
///
/// Arguments are ignored. Behaves exactly like `FLUSHDB`; both delegate to
/// `flush_all`.
pub async fn flushall(
    _cmd: Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    flush_all(app, sess).await
}
/// Empty every storage table for the authenticated session in one
/// transaction, then reply `+OK`.
///
/// Statements run in the order listed: child tables first, `redis_kv` last,
/// presumably so any foreign keys are respected (confirm against the schema).
/// If any statement fails, the error is returned before `commit` and the
/// transaction is dropped uncommitted.
///
/// # Errors
/// `RespError::NoAuth` for an unauthenticated session, or a SQL error mapped
/// through `map_sql_err`.
async fn flush_all(
    state: Arc<AppState>,
    session: Arc<crate::state::Session>,
) -> Result<Value, RespError> {
    const STATEMENTS: [&str; 6] = [
        "DELETE FROM redis_hash",
        "DELETE FROM redis_set",
        "DELETE FROM redis_zset",
        "DELETE FROM redis_list",
        "DELETE FROM redis_list_meta",
        "DELETE FROM redis_kv",
    ];

    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);

    let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
    for statement in STATEMENTS {
        sqlx::query(statement)
            .execute(&mut *tx)
            .await
            .map_err(map_sql_err)?;
    }
    tx.commit().await.map_err(map_sql_err)?;
    Ok(ok())
}
/// `UNLINK key [key ...]` — handled exactly like `DEL`.
///
/// The deletion runs inline by awaiting `strings::del`; the extractors are
/// forwarded unchanged.
pub async fn unlink(
    cmd: Cmd,
    state: State<AppState>,
    session: SessionHandle,
) -> Result<Value, RespError> {
    strings::del(cmd, state, session).await
}
/// Re-point every row of `table` whose `r_key` is `src` so it uses `dest`,
/// inside the caller's transaction.
///
/// `table` is spliced directly into the SQL text because identifiers cannot
/// be bound as parameters, so it must be a trusted, hard-coded table name. The
/// callers visible here pass string literals. Both keys are bound parameters.
///
/// # Errors
/// Any SQL failure, mapped through `map_sql_err`.
async fn update_key_in_tx(
    tx: &mut sqlx::Transaction<'_, sqlx::MySql>,
    table: &str,
    src: &Bytes,
    dest: &Bytes,
) -> Result<(), RespError> {
    let statement = format!("UPDATE {table} SET r_key = ? WHERE r_key = ?");
    sqlx::query(&statement)
        .bind(dest.as_ref())
        .bind(src.as_ref())
        .execute(&mut **tx)
        .await
        .map(|_| ())
        .map_err(map_sql_err)
}