use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::{MySqlPool, Row};
use std::sync::Arc;
use tokio::time::{self, Duration};
use crate::handlers::util::{
arg_as_bytes, arg_as_i64, invalid_arguments, ok, wrong_arity, wrong_type,
};
use crate::state::{AppState, Session, SessionHandle, now_ms};
use crate::storage::{
TYPE_STRING, delete_key_all, delete_key_in_tx, delete_keys_all, is_expired, load_meta,
map_sql_err,
};
/// Handles `GET key`.
///
/// Replies with the stored bytes as a bulk string when the key holds a live
/// string value. Replies with `Value::Null` when the key is missing, expired,
/// or its `r_value` column is NULL, and with the wrong-type reply when a live
/// key holds a non-string type.
///
/// # Errors
/// Returns the wrong-arity error unless exactly one argument is given,
/// `RespError::NoAuth` when the session is not authenticated, and any SQL
/// failure translated through `map_sql_err`.
pub async fn get(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() != 1 {
        return Err(wrong_arity("GET"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;
    let row = sqlx::query("SELECT r_type, r_value, expires_at_ms FROM redis_kv WHERE r_key = ?")
        .bind(key.as_ref())
        .fetch_optional(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;
    let Some(row) = row else {
        return Ok(Value::Null);
    };
    // Expiry must be evaluated before the type: an expired key is logically
    // absent whatever it used to hold, so it answers nil rather than
    // WRONGTYPE. This matches the order used by the INCR/DECR path.
    let expires_at: Option<i64> = row.try_get("expires_at_ms").map_err(map_sql_err)?;
    if is_expired(expires_at, now) {
        // Lazily reap the stale row so later commands do not see it again.
        delete_key_all(auth.pool.as_ref(), key).await?;
        return Ok(Value::Null);
    }
    let r_type: u8 = row.try_get("r_type").map_err(map_sql_err)?;
    if r_type != TYPE_STRING {
        return Ok(wrong_type());
    }
    let value: Option<Vec<u8>> = row.try_get("r_value").map_err(map_sql_err)?;
    Ok(value
        .map(Bytes::from)
        .map(Value::Bulk)
        .unwrap_or(Value::Null))
}
/// Handles `SET key value [options...]`.
///
/// Options are parsed by `parse_set_options`. A plain SET (optionally with an
/// expiry) is written through `set_simple`; any of NX / XX / GET / KEEPTTL
/// routes the write through `set_with_retry`.
///
/// The reply is the previous value (or `Value::Null`) when GET was requested,
/// `Value::Null` when the write was not applied, and the OK reply otherwise.
/// A key that currently holds a non-string type yields the wrong-type reply.
///
/// # Errors
/// Returns the wrong-arity error with fewer than two arguments,
/// `RespError::NoAuth` for an unauthenticated session, option parse errors,
/// and storage errors from the helpers it calls.
pub async fn set(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() < 2 {
        return Err(wrong_arity("SET"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);

    let key = arg_as_bytes(&cmd.args[0])?;
    let value = arg_as_bytes(&cmd.args[1])?;
    let now = now_ms() as i64;
    let opts = parse_set_options(&cmd.args, now)?;

    // Refuse to overwrite a key that currently holds another data type.
    let existing = load_meta(auth.pool.as_ref(), key, now).await?;
    if existing.is_some_and(|meta| meta.r_type != TYPE_STRING) {
        return Ok(wrong_type());
    }

    let outcome = if !opts.needs_transaction() {
        // Unconditional write: a single upsert is enough.
        set_simple(auth.pool.as_ref(), key, value, &opts).await?;
        SetOutcome {
            applied: true,
            old_value: None,
        }
    } else {
        set_with_retry(
            auth.pool.as_ref(),
            key,
            value,
            &opts,
            now,
            &state.config.string_set_retry,
        )
        .await?
    };

    let reply = if opts.get {
        match outcome.old_value {
            Some(previous) => Value::Bulk(previous),
            None => Value::Null,
        }
    } else if outcome.applied {
        ok()
    } else {
        Value::Null
    };
    Ok(reply)
}
pub async fn setnx(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.len() != 2 {
return Err(wrong_arity("SETNX"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let value = arg_as_bytes(&cmd.args[1])?;
let now = now_ms() as i64;
let opts = SetOptions {
nx: true,
..SetOptions::default()
};
let outcome = set_with_retry(
auth.pool.as_ref(),
key,
value,
&opts,
now,
&state.config.string_set_retry,
)
.await?;
Ok(Value::Integer(if outcome.applied { 1 } else { 0 }))
}
/// Handles `SETEX key seconds value` by delegating to `setex_inner`.
///
/// The TTL argument is expressed in seconds, so it is scaled by 1000 to get
/// the milliseconds that `setex_inner` works with.
pub async fn setex(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(handle): SessionHandle,
) -> Result<Value, RespError> {
    const MS_PER_SECOND: i64 = 1000;
    setex_inner(command, app, handle, MS_PER_SECOND, "SETEX").await
}
/// Handles `PSETEX key milliseconds value` by delegating to `setex_inner`.
///
/// The TTL argument is already in milliseconds, so the multiplier is 1.
pub async fn psetex(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(handle): SessionHandle,
) -> Result<Value, RespError> {
    const MS_PER_MILLISECOND: i64 = 1;
    setex_inner(command, app, handle, MS_PER_MILLISECOND, "PSETEX").await
}
pub async fn del(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.is_empty() {
return Err(wrong_arity("DEL"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let now = now_ms() as i64;
let mut qb = sqlx::QueryBuilder::new(
"SELECT r_key FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ",
);
qb.push_bind(now);
qb.push(") AND r_key IN (");
let mut separated = qb.separated(", ");
for arg in &cmd.args {
let key = arg_as_bytes(arg)?;
separated.push_bind(key.as_ref());
}
qb.push(")");
let rows = qb
.build()
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
if rows.is_empty() {
return Ok(Value::Integer(0));
}
let mut keys = Vec::with_capacity(rows.len());
for row in rows {
let key: Vec<u8> = row.try_get("r_key").map_err(map_sql_err)?;
keys.push(Bytes::from(key));
}
let count = keys.len() as i64;
delete_keys_all(auth.pool.as_ref(), &keys).await?;
Ok(Value::Integer(count))
}
/// Handles `EXISTS key [key ...]`.
///
/// Replies with the number of arguments that name a live (non-expired) key.
/// A key mentioned several times is counted once per mention, so
/// `EXISTS k k` answers `2` when `k` exists.
///
/// # Errors
/// Returns the wrong-arity error when no key is given, `RespError::NoAuth`
/// for an unauthenticated session, argument conversion errors, and SQL
/// failures translated through `map_sql_err`.
pub async fn exists(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.is_empty() {
        return Err(wrong_arity("EXISTS"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let now = now_ms() as i64;

    // Convert every argument once; the same slices feed the query and the
    // per-argument tally below.
    let mut requested = Vec::with_capacity(cmd.args.len());
    for arg in &cmd.args {
        requested.push(arg_as_bytes(arg)?);
    }

    // Fetch the matching keys rather than COUNT(*): an IN list matches each
    // row once, which would under-count repeated arguments.
    let mut qb = sqlx::QueryBuilder::new(
        "SELECT r_key FROM redis_kv WHERE (expires_at_ms IS NULL OR expires_at_ms > ",
    );
    qb.push_bind(now);
    qb.push(") AND r_key IN (");
    let mut separated = qb.separated(", ");
    for key in requested.iter().copied() {
        separated.push_bind(key.as_ref());
    }
    qb.push(")");
    let rows = qb
        .build()
        .fetch_all(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;

    let mut live: std::collections::HashSet<Vec<u8>> =
        std::collections::HashSet::with_capacity(rows.len());
    for row in rows {
        let key: Vec<u8> = row.try_get("r_key").map_err(map_sql_err)?;
        live.insert(key);
    }

    // NOTE(review): the tally compares raw bytes, which assumes `r_key` uses
    // a binary (byte-exact) comparison in MySQL — confirm against the schema.
    let mut count = 0i64;
    for key in &requested {
        if live.contains(&key[..]) {
            count += 1;
        }
    }
    Ok(Value::Integer(count))
}
fn parse_set_options(args: &[Value], now_ms: i64) -> Result<SetOptions, RespError> {
let mut opts = SetOptions::default();
let mut i = 2;
while i < args.len() {
let arg = args
.get(i)
.ok_or_else(|| invalid_arguments("ERR syntax error"))?;
let mut opt = arg_as_bytes(arg)?.to_vec();
for b in &mut opt {
b.make_ascii_uppercase();
}
match opt.as_slice() {
b"NX" => opts.nx = true,
b"XX" => opts.xx = true,
b"GET" => opts.get = true,
b"KEEPTTL" => opts.keep_ttl = true,
b"EX" | b"PX" | b"EXAT" | b"PXAT" => {
let next = args
.get(i + 1)
.ok_or_else(|| invalid_arguments("ERR syntax error"))?;
let value = arg_as_i64(next)?;
let expires_at = match opt.as_slice() {
b"EX" => add_ms(now_ms, value, 1000)?,
b"PX" => add_ms(now_ms, value, 1)?,
b"EXAT" => value
.checked_mul(1000)
.ok_or_else(|| invalid_arguments("ERR invalid expire time in SET"))?,
_ => value,
};
if expires_at <= now_ms {
return Err(invalid_arguments("ERR invalid expire time in SET"));
}
if opts.expires_at_ms.is_some() {
return Err(invalid_arguments("ERR syntax error"));
}
opts.expires_at_ms = Some(expires_at);
i += 1;
}
_ => return Err(invalid_arguments("ERR syntax error")),
}
i += 1;
}
if opts.nx && opts.xx {
return Err(invalid_arguments("ERR syntax error"));
}
if opts.keep_ttl && opts.expires_at_ms.is_some() {
return Err(invalid_arguments("ERR syntax error"));
}
Ok(opts)
}
/// Computes the absolute expiry `now_ms + value * multiplier`.
///
/// `value` is a relative TTL and `multiplier` converts it to milliseconds.
///
/// # Errors
/// Returns `ERR invalid expire time in SET` when `value` is not positive or
/// when either the multiplication or the addition overflows `i64`.
fn add_ms(now_ms: i64, value: i64, multiplier: i64) -> Result<i64, RespError> {
    let bad_expiry = || invalid_arguments("ERR invalid expire time in SET");
    if value < 1 {
        return Err(bad_expiry());
    }
    value
        .checked_mul(multiplier)
        .and_then(|delta| now_ms.checked_add(delta))
        .ok_or_else(bad_expiry)
}
/// Shared implementation of SETEX and PSETEX (`<name> key ttl value`).
///
/// `multiplier` converts the TTL argument to milliseconds and `name` is the
/// command name used in the wrong-arity error. On success the value is
/// stored with an absolute expiry through `set_simple` and the OK reply is
/// returned. A key that currently holds a non-string type yields the
/// wrong-type reply.
///
/// # Errors
/// Returns the wrong-arity error unless exactly three arguments are given,
/// `RespError::NoAuth` for an unauthenticated session, the `add_ms` error
/// for a non-positive or overflowing TTL, and storage errors.
async fn setex_inner(
    cmd: resp_async::Command,
    state: Arc<AppState>,
    session: Arc<Session>,
    multiplier: i64,
    name: &str,
) -> Result<Value, RespError> {
    let [raw_key, raw_ttl, raw_value] = &cmd.args[..] else {
        return Err(wrong_arity(name));
    };
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);

    let key = arg_as_bytes(raw_key)?;
    let ttl = arg_as_i64(raw_ttl)?;
    let value = arg_as_bytes(raw_value)?;
    let now = now_ms() as i64;
    let deadline = add_ms(now, ttl, multiplier)?;

    // Refuse to overwrite a key that currently holds another data type.
    let existing = load_meta(auth.pool.as_ref(), key, now).await?;
    if existing.is_some_and(|meta| meta.r_type != TYPE_STRING) {
        return Ok(wrong_type());
    }

    let opts = SetOptions {
        expires_at_ms: Some(deadline),
        ..SetOptions::default()
    };
    set_simple(auth.pool.as_ref(), key, value, &opts).await?;
    Ok(ok())
}
/// Unconditionally stores `value` under `key` as a string with one upsert.
///
/// An existing row has its value, length, and expiry replaced; the expiry
/// comes from `opts.expires_at_ms` (NULL when `None`). The other fields of
/// `opts` are not consulted here.
///
/// # Errors
/// Returns the SQL failure translated through `map_sql_err`.
async fn set_simple(
    pool: &MySqlPool,
    key: &Bytes,
    value: &Bytes,
    opts: &SetOptions,
) -> Result<(), RespError> {
    const UPSERT_SQL: &str = "INSERT INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) \
        VALUES (?, ?, ?, ?, ?) \
        ON DUPLICATE KEY UPDATE r_value = VALUES(r_value), r_len = VALUES(r_len), expires_at_ms = VALUES(expires_at_ms)";

    let byte_len = value.len() as i64;
    let statement = sqlx::query(UPSERT_SQL)
        .bind(key.as_ref())
        .bind(TYPE_STRING)
        .bind(value.as_ref())
        .bind(byte_len)
        .bind(opts.expires_at_ms);
    statement.execute(pool).await.map_err(map_sql_err)?;
    Ok(())
}
/// Result of a SET-style write attempt.
struct SetOutcome {
    /// `true` when the new value was written, `false` when a condition such
    /// as NX / XX prevented the write.
    applied: bool,
    /// Previous value of the key when the caller asked for it and one was
    /// visible; `None` otherwise.
    old_value: Option<Bytes>,
}
/// Parsed options of a SET-style command. The default is a plain,
/// unconditional write without expiry.
#[derive(Default)]
struct SetOptions {
    /// `NX`: only write when the key does not exist.
    nx: bool,
    /// `XX`: only write when the key already exists.
    xx: bool,
    /// `GET`: reply with the previous value.
    get: bool,
    /// `KEEPTTL`: keep the expiry of the existing key.
    keep_ttl: bool,
    /// Absolute expiry in milliseconds, if any.
    expires_at_ms: Option<i64>,
}

impl SetOptions {
    /// Reports whether any option depends on the key's current state
    /// (NX, XX, GET, or KEEPTTL), in which case a blind upsert is not enough.
    fn needs_transaction(&self) -> bool {
        [self.nx, self.xx, self.get, self.keep_ttl].contains(&true)
    }
}
/// Computes the exponential backoff, in milliseconds, for a retry attempt.
///
/// The delay is `backoff_base_ms * 2^attempt`, with the exponent capped at
/// 20, the multiplication saturating, and the result limited to
/// `backoff_max_ms`. Both configuration values are treated as at least 1.
fn retry_backoff_ms(retry: &crate::config::RetryConfig, attempt: usize) -> u64 {
    let ceiling = retry.backoff_max_ms.max(1);
    let factor = 1u64 << attempt.min(20);
    retry
        .backoff_base_ms
        .max(1)
        .saturating_mul(factor)
        .min(ceiling)
}
/// Performs a conditional SET (NX / XX / GET / KEEPTTL) with optimistic
/// retries.
///
/// Each attempt reads the current string row for `key`, evaluates the
/// conditions in `opts`, and then either updates the row (guarded by the
/// length and expiry that were just read) or inserts a fresh one with
/// `INSERT IGNORE`. When the write touches no row, a concurrent writer is
/// assumed to have changed the key and the attempt is repeated after a
/// backoff, up to `retry.max_attempts` times (at least once).
///
/// `now` is the caller's timestamp in milliseconds and is used for every
/// expiry check, including on retries.
///
/// The returned `SetOutcome` has `applied == false` when NX or XX prevented
/// the write. `old_value` carries the value that was visible before the
/// call whenever `opts.get` is set and such a value existed. That includes
/// the case where NX rejected the write, so `SET key value NX GET` reports
/// the value that blocked it.
///
/// NOTE(review): the UPDATE guard compares only `r_len` and `expires_at_ms`,
/// so a concurrent overwrite with a same-length value is not detected.
///
/// # Errors
/// Returns SQL failures translated through `map_sql_err`, errors from
/// `delete_key_all`, and `RespError::internal()` once all attempts are
/// exhausted.
async fn set_with_retry(
    pool: &MySqlPool,
    key: &Bytes,
    value: &Bytes,
    opts: &SetOptions,
    now: i64,
    retry: &crate::config::RetryConfig,
) -> Result<SetOutcome, RespError> {
    let mut attempt = 0usize;
    let len = value.len() as i64;
    let max_attempts = retry.max_attempts.max(1);
    loop {
        let row = sqlx::query(
            "SELECT r_value, r_len, expires_at_ms FROM redis_kv WHERE r_key = ? AND r_type = ?",
        )
        .bind(key.as_ref())
        .bind(TYPE_STRING)
        .fetch_optional(pool)
        .await
        .map_err(map_sql_err)?;

        // `visible` means a live (non-expired) string row exists right now.
        let mut visible = false;
        let mut old_value = None;
        let mut current_expiry: Option<i64> = None;
        let mut current_len = 0i64;
        if let Some(row) = row {
            current_expiry = row.try_get("expires_at_ms").map_err(map_sql_err)?;
            if is_expired(current_expiry, now) {
                // Reap the stale row so the insert path below can succeed.
                delete_key_all(pool, key).await?;
            } else {
                visible = true;
                current_len = row.try_get("r_len").map_err(map_sql_err)?;
                if opts.get {
                    let stored: Option<Vec<u8>> = row.try_get("r_value").map_err(map_sql_err)?;
                    old_value = stored.map(Bytes::from);
                }
            }
        }

        if opts.nx && visible {
            // The write is rejected, but a GET request still gets the
            // existing value back. `old_value` is only populated when
            // `opts.get` is set, so plain NX callers still receive `None`.
            return Ok(SetOutcome {
                applied: false,
                old_value,
            });
        }
        if opts.xx && !visible {
            return Ok(SetOutcome {
                applied: false,
                old_value: None,
            });
        }

        // KEEPTTL only has something to keep when a live row exists.
        let new_expiry = if opts.keep_ttl && visible {
            current_expiry
        } else {
            opts.expires_at_ms
        };

        let applied = if visible {
            // Optimistic update: only succeeds when the row still has the
            // length and expiry that were read above (`<=>` is NULL-safe).
            let res = sqlx::query(
                "UPDATE redis_kv SET r_value = ?, r_len = ?, expires_at_ms = ? \
                 WHERE r_key = ? AND r_type = ? AND r_len = ? AND (expires_at_ms <=> ?)",
            )
            .bind(value.as_ref())
            .bind(len)
            .bind(new_expiry)
            .bind(key.as_ref())
            .bind(TYPE_STRING)
            .bind(current_len)
            .bind(current_expiry)
            .execute(pool)
            .await
            .map_err(map_sql_err)?;
            res.rows_affected() > 0
        } else {
            // No live row: insert, and let IGNORE report a lost race as
            // zero affected rows instead of an error.
            let res = sqlx::query(
                "INSERT IGNORE INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) \
                 VALUES (?, ?, ?, ?, ?)",
            )
            .bind(key.as_ref())
            .bind(TYPE_STRING)
            .bind(value.as_ref())
            .bind(len)
            .bind(new_expiry)
            .execute(pool)
            .await
            .map_err(map_sql_err)?;
            res.rows_affected() > 0
        };
        if applied {
            return Ok(SetOutcome {
                applied: true,
                old_value,
            });
        }

        attempt += 1;
        if attempt >= max_attempts {
            return Err(RespError::internal());
        }
        // One bit of clock-derived jitter to de-synchronise competing writers.
        let jitter = now_ms() & 1;
        let backoff_ms = retry_backoff_ms(retry, attempt).saturating_add(jitter);
        time::sleep(Duration::from_millis(backoff_ms)).await;
    }
}
/// Handles `INCR key` by delegating to `incr_by` with a delta of `+1`.
pub async fn incr(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(handle): SessionHandle,
) -> Result<Value, RespError> {
    const STEP: i64 = 1;
    incr_by(command, app, handle, STEP, "INCR").await
}
/// Handles `DECR key` by delegating to `incr_by` with a delta of `-1`.
pub async fn decr(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(handle): SessionHandle,
) -> Result<Value, RespError> {
    const STEP: i64 = -1;
    incr_by(command, app, handle, STEP, "DECR").await
}
/// Shared implementation of INCR / DECR: adds `delta` to the integer stored
/// at the single key argument and replies with the new value.
///
/// Each attempt runs in its own transaction:
/// 1. An in-place `UPDATE` that only matches a live string row whose value
///    is an integer that can absorb `delta` without leaving the `i64` range.
/// 2. If nothing matched, the row is inspected to explain why: an expired
///    row is deleted, a non-string row yields the wrong-type reply, and a
///    non-integer or overflowing value yields an error.
/// 3. Otherwise the key is treated as absent and created with value `delta`
///    through `INSERT IGNORE`.
///
/// When the insert is ignored (another writer got there first) the
/// transaction is rolled back and the attempt is repeated after a backoff,
/// bounded by `string_incr_retry.max_attempts` and the optional deadline.
///
/// `name` is the command name used in the wrong-arity error.
///
/// # Errors
/// Returns the wrong-arity error unless exactly one argument is given,
/// `RespError::NoAuth` for an unauthenticated session,
/// `ERR value is not an integer or out of range` for a non-integer value,
/// `ERR increment or decrement would overflow` when the result would not fit
/// in an `i64`, `ERR backend transaction conflict` once retries are
/// exhausted, and SQL failures translated through `map_sql_err`.
async fn incr_by(
    cmd: resp_async::Command,
    state: Arc<AppState>,
    session: Arc<Session>,
    delta: i64,
    name: &'static str,
) -> Result<Value, RespError> {
    if cmd.args.len() != 1 {
        return Err(wrong_arity(name));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let now = now_ms() as i64;
    let retry = &state.config.string_incr_retry;
    let max_attempts = retry.max_attempts.max(1);
    let deadline = retry
        .deadline_ms
        .map(|ms| time::Instant::now() + Duration::from_millis(ms.max(1)));

    // Range of current values for which `current + delta` still fits in an
    // i64; computed in i128 so the bounds themselves cannot overflow.
    let delta_i128 = i128::from(delta);
    let min_bound = (i128::from(i64::MIN) - delta_i128)
        .clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64;
    let max_bound = (i128::from(i64::MAX) - delta_i128)
        .clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64;

    let mut attempt = 0usize;
    loop {
        let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;

        // Fast path: arithmetic done inside MySQL on a live integer string.
        let res = sqlx::query(
            "UPDATE redis_kv \
             SET r_value = CAST(CAST(r_value AS DECIMAL(65,0)) + CAST(? AS DECIMAL(65,0)) AS CHAR), \
             r_len = OCTET_LENGTH(CAST(CAST(r_value AS DECIMAL(65,0)) + CAST(? AS DECIMAL(65,0)) AS CHAR)) \
             WHERE r_key = ? AND r_type = ? \
             AND (expires_at_ms IS NULL OR expires_at_ms > ?) \
             AND r_value REGEXP '^[+-]?[0-9]+$' \
             AND CAST(r_value AS DECIMAL(65,0)) BETWEEN CAST(? AS DECIMAL(65,0)) AND CAST(? AS DECIMAL(65,0))",
        )
        .bind(delta)
        .bind(delta)
        .bind(key.as_ref())
        .bind(TYPE_STRING)
        .bind(now)
        .bind(min_bound)
        .bind(max_bound)
        .execute(&mut *tx)
        .await
        .map_err(map_sql_err)?;
        if res.rows_affected() > 0 {
            let row = sqlx::query(
                "SELECT CAST(r_value AS SIGNED) AS value FROM redis_kv WHERE r_key = ?",
            )
            .bind(key.as_ref())
            .fetch_one(&mut *tx)
            .await
            .map_err(map_sql_err)?;
            let new_value: i64 = row
                .try_get("value")
                .map_err(|_| invalid_arguments("ERR value is not an integer or out of range"))?;
            tx.commit().await.map_err(map_sql_err)?;
            return Ok(Value::Integer(new_value));
        }

        // Nothing was updated: find out why before assuming the key is absent.
        let row =
            sqlx::query("SELECT r_type, r_value, expires_at_ms FROM redis_kv WHERE r_key = ?")
                .bind(key.as_ref())
                .fetch_optional(&mut *tx)
                .await
                .map_err(map_sql_err)?;
        if let Some(row) = row {
            let r_type: u8 = row.try_get("r_type").map_err(map_sql_err)?;
            let expires_at: Option<i64> = row.try_get("expires_at_ms").map_err(map_sql_err)?;
            if is_expired(expires_at, now) {
                // An expired key counts as absent; clear it so the insert
                // below can create a fresh value.
                delete_key_in_tx(&mut tx, key).await?;
            } else if r_type != TYPE_STRING {
                return Ok(wrong_type());
            } else {
                let stored: Option<Vec<u8>> = row.try_get("r_value").map_err(map_sql_err)?;
                let current = stored
                    .as_deref()
                    .and_then(|raw| std::str::from_utf8(raw).ok())
                    .and_then(|text| text.parse::<i64>().ok());
                let Some(current) = current else {
                    return Err(invalid_arguments(
                        "ERR value is not an integer or out of range",
                    ));
                };
                // A valid integer that the UPDATE refused because the result
                // would leave the i64 range. Retrying cannot help, so report
                // the overflow now instead of spinning into a bogus
                // "transaction conflict".
                if current.checked_add(delta).is_none() {
                    return Err(invalid_arguments(
                        "ERR increment or decrement would overflow",
                    ));
                }
                // Otherwise the row changed between the UPDATE and this read.
                // The insert below will be ignored and the attempt retried.
            }
        }

        // Absent key: it starts from 0, so the new value is just `delta`.
        let new_value = delta;
        let new_text = new_value.to_string();
        let new_len = new_text.len() as i64;
        let res = sqlx::query(
            "INSERT IGNORE INTO redis_kv (r_key, r_type, r_value, r_len, expires_at_ms) \
             VALUES (?, ?, ?, ?, NULL)",
        )
        .bind(key.as_ref())
        .bind(TYPE_STRING)
        .bind(new_text.as_bytes())
        .bind(new_len)
        .execute(&mut *tx)
        .await
        .map_err(map_sql_err)?;
        if res.rows_affected() > 0 {
            tx.commit().await.map_err(map_sql_err)?;
            return Ok(Value::Integer(new_value));
        }

        // Lost a race with another writer; best-effort rollback, then retry.
        let _ = tx.rollback().await;
        attempt += 1;
        let out_of_time = deadline
            .map(|deadline| time::Instant::now() >= deadline)
            .unwrap_or(false);
        if attempt >= max_attempts || out_of_time {
            log::debug!(
                target: "handler",
                "incr_by conflict exhausted attempt={} delta={}",
                attempt,
                delta
            );
            return Err(RespError::invalid_data("ERR backend transaction conflict"));
        }
        // One bit of clock-derived jitter to de-synchronise competing writers.
        let jitter = now_ms() & 1;
        let backoff_ms = retry_backoff_ms(retry, attempt).saturating_add(jitter);
        time::sleep(Duration::from_millis(backoff_ms)).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Builds a SET argument list of bulk strings from static text.
    fn bulk_args(parts: &[&'static str]) -> Vec<Value> {
        parts
            .iter()
            .map(|part| Value::Bulk(Bytes::from_static(part.as_bytes())))
            .collect()
    }

    /// `EX <n>` together with `NX` parses into an expiry plus the NX flag.
    #[test]
    fn parse_set_options_basic() {
        let args = bulk_args(&["key", "value", "EX", "10", "NX"]);
        let opts = parse_set_options(&args, 100).unwrap();
        assert!(opts.nx);
        assert!(opts.expires_at_ms.is_some());
    }

    /// `NX` and `XX` are mutually exclusive.
    #[test]
    fn parse_set_options_rejects_conflicts() {
        let args = bulk_args(&["key", "value", "NX", "XX"]);
        assert!(parse_set_options(&args, 0).is_err());
    }
}