use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;
use std::sync::Arc;
use crate::handlers::util::{arg_as_bytes, arg_as_i64, invalid_integer, wrong_arity};
use crate::state::{AppState, Session, SessionHandle, now_ms};
use crate::storage::{delete_key_all, map_sql_err};
/// `EXPIRE key seconds`: give `key` a timeout `seconds` from now.
///
/// Thin adapter over `expire_relative`; the factor of 1000 scales the
/// caller's seconds into the milliseconds stored in `expires_at_ms`.
pub async fn expire(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    let ms_per_second: i64 = 1000;
    expire_relative(command, app, sess, ms_per_second, "EXPIRE").await
}
/// `PEXPIRE key milliseconds`: give `key` a timeout `milliseconds` from now.
///
/// Thin adapter over `expire_relative`; the TTL is already in milliseconds,
/// so the unit factor is 1.
pub async fn pexpire(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    let ms_per_unit: i64 = 1;
    expire_relative(command, app, sess, ms_per_unit, "PEXPIRE").await
}
/// `EXPIREAT key unix-seconds`: make `key` expire at an absolute time.
///
/// Thin adapter over `expire_absolute`; the factor of 1000 converts the
/// supplied timestamp from seconds to milliseconds.
pub async fn expireat(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    let ms_per_second: i64 = 1000;
    expire_absolute(command, app, sess, ms_per_second, "EXPIREAT").await
}
/// `PEXPIREAT key unix-milliseconds`: make `key` expire at an absolute time.
///
/// Thin adapter over `expire_absolute`; the timestamp is already in
/// milliseconds, so the unit factor is 1.
pub async fn pexpireat(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    let ms_per_unit: i64 = 1;
    expire_absolute(command, app, sess, ms_per_unit, "PEXPIREAT").await
}
/// `PERSIST key`: remove the timeout from `key`.
///
/// The UPDATE only matches a row whose `expires_at_ms` is set and still lies
/// in the future. A missing key, a key without a timeout, or an already
/// expired key therefore leaves the table untouched. The reply is the
/// affected-row count the UPDATE reports.
///
/// # Errors
/// Wrong arity, `RespError::NoAuth` for an unauthenticated session, argument
/// parsing failures, or SQL errors mapped through `map_sql_err`.
pub async fn persist(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    if command.args.len() != 1 {
        return Err(wrong_arity("PERSIST"));
    }
    let Some(auth) = sess.auth().await else {
        return Err(RespError::NoAuth);
    };
    app.pools.touch(&auth.user);
    let key = arg_as_bytes(&command.args[0])?;
    let current_ms = now_ms() as i64;
    let outcome = sqlx::query(
        "UPDATE redis_kv SET expires_at_ms = NULL \
         WHERE r_key = ? AND expires_at_ms IS NOT NULL AND expires_at_ms > ?",
    )
    .bind(key.as_ref())
    .bind(current_ms)
    .execute(auth.pool.as_ref())
    .await
    .map_err(map_sql_err)?;
    let cleared = outcome.rows_affected() as i64;
    Ok(Value::Integer(cleared))
}
/// `TTL key`: remaining lifetime of `key`, reported in whole seconds.
///
/// Thin adapter over `ttl_inner`; dividing the millisecond difference by 1000
/// yields seconds.
pub async fn ttl(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    let ms_per_second: i64 = 1000;
    ttl_inner(command, app, sess, ms_per_second, "TTL").await
}
/// `PTTL key`: remaining lifetime of `key`, reported in milliseconds.
///
/// Thin adapter over `ttl_inner` with a divisor of 1, so the raw
/// millisecond difference is returned.
pub async fn pttl(
    Cmd(command): Cmd,
    State(app): State<AppState>,
    SessionHandle(sess): SessionHandle,
) -> Result<Value, RespError> {
    let ms_per_unit: i64 = 1;
    ttl_inner(command, app, sess, ms_per_unit, "PTTL").await
}
/// Shared implementation of `EXPIRE` / `PEXPIRE`: set a timeout relative to now.
///
/// `multiplier` converts the caller's TTL unit into milliseconds (callers pass
/// `1000` for seconds and `1` for milliseconds). `name` is used only to build
/// the arity error.
///
/// The deadline is computed with overflow checking *before* deciding whether
/// the key should be deleted. This mirrors Redis, which rejects a TTL whose
/// unit conversion or addition to the current time overflows, even when that
/// TTL is negative. It does not silently treat such a TTL as "already
/// expired". A deadline at or before the current time (a non-positive TTL,
/// given the positive multipliers callers use) deletes the key. This is the
/// same rule `expire_absolute` applies.
///
/// # Errors
/// - wrong-arity error unless exactly `key ttl` is supplied;
/// - `RespError::NoAuth` if the session is not authenticated;
/// - argument-parsing errors from `arg_as_bytes` / `arg_as_i64`;
/// - `invalid_integer()` if the deadline overflows `i64`;
/// - errors propagated from `delete_key_if_present` / `expire_at_ms`.
async fn expire_relative(
    cmd: resp_async::Command,
    state: Arc<AppState>,
    session: Arc<Session>,
    multiplier: i64,
    name: &str,
) -> Result<Value, RespError> {
    if cmd.args.len() != 2 {
        return Err(wrong_arity(name));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let ttl = arg_as_i64(&cmd.args[1])?;
    // Read the clock once so the deadline and the liveness check agree.
    let now = now_ms() as i64;
    // Overflow is detected for every TTL, negative ones included (Redis semantics).
    let expires_at = ttl
        .checked_mul(multiplier)
        .and_then(|delta| now.checked_add(delta))
        .ok_or_else(invalid_integer)?;
    if expires_at <= now {
        return delete_key_if_present(auth.pool.as_ref(), key, now).await;
    }
    expire_at_ms(auth.pool.as_ref(), key, expires_at, now).await
}
async fn expire_absolute(
cmd: resp_async::Command,
state: Arc<AppState>,
session: Arc<Session>,
multiplier: i64,
name: &str,
) -> Result<Value, RespError> {
if cmd.args.len() != 2 {
return Err(wrong_arity(name));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let ts = arg_as_i64(&cmd.args[1])?;
let expires_at = ts.checked_mul(multiplier).ok_or_else(invalid_integer)?;
let now = now_ms() as i64;
if expires_at <= now {
return delete_key_if_present(auth.pool.as_ref(), key, now).await;
}
expire_at_ms(auth.pool.as_ref(), key, expires_at, now).await
}
/// Store `expires_at` (epoch milliseconds) as the deadline of `key`.
///
/// Only a row that is currently live is updated: either it has no deadline,
/// or its deadline is later than `now`. The reply is the affected-row count
/// from the UPDATE. NOTE(review): this is presumably 0 or 1 if `r_key` is
/// unique; confirm against the schema.
async fn expire_at_ms(
    pool: &sqlx::MySqlPool,
    key: &bytes::Bytes,
    expires_at: i64,
    now: i64,
) -> Result<Value, RespError> {
    let outcome = sqlx::query(
        "UPDATE redis_kv SET expires_at_ms = ? \
         WHERE r_key = ? AND (expires_at_ms IS NULL OR expires_at_ms > ?)",
    )
    .bind(expires_at)
    .bind(key.as_ref())
    .bind(now)
    .execute(pool)
    .await
    .map_err(map_sql_err)?;
    let updated = outcome.rows_affected() as i64;
    Ok(Value::Integer(updated))
}
/// Delete `key` if it is currently live, replying `1` if it was deleted and `0` otherwise.
///
/// "Live" means a `redis_kv` row exists whose deadline is absent or later
/// than `now`. An already-expired row is left alone and reported as `0`.
///
/// NOTE(review): the existence check and `delete_key_all` are separate
/// statements, so a concurrent writer can slip in between them; confirm
/// whether that race matters for callers.
async fn delete_key_if_present(
    pool: &sqlx::MySqlPool,
    key: &bytes::Bytes,
    now: i64,
) -> Result<Value, RespError> {
    let is_live = sqlx::query(
        "SELECT r_key FROM redis_kv WHERE r_key = ? AND (expires_at_ms IS NULL OR expires_at_ms > ?)",
    )
    .bind(key.as_ref())
    .bind(now)
    .fetch_optional(pool)
    .await
    .map_err(map_sql_err)?
    .is_some();
    if is_live {
        delete_key_all(pool, key).await?;
    }
    Ok(Value::Integer(i64::from(is_live)))
}
/// Shared implementation of `TTL` / `PTTL`.
///
/// Replies are:
/// - `-2` if the key has no row, or its deadline has already passed (in
///   which case the key is deleted here);
/// - `-1` if the key exists without a deadline;
/// - otherwise the remaining milliseconds divided by `divisor` (integer
///   division, so `TTL` truncates rather than rounds).
///
/// `name` is used only in the arity error.
///
/// NOTE(review): the lazy delete of an expired key is unconditional and runs
/// after the SELECT. A concurrent write to the same key in between could be
/// removed; confirm whether that is acceptable.
async fn ttl_inner(
    cmd: resp_async::Command,
    state: Arc<AppState>,
    session: Arc<Session>,
    divisor: i64,
    name: &str,
) -> Result<Value, RespError> {
    if cmd.args.len() != 1 {
        return Err(wrong_arity(name));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let current_ms = now_ms() as i64;
    let maybe_row = sqlx::query("SELECT expires_at_ms FROM redis_kv WHERE r_key = ?")
        .bind(key.as_ref())
        .fetch_optional(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;
    let Some(row) = maybe_row else {
        return Ok(Value::Integer(-2));
    };
    let deadline: Option<i64> = row.try_get("expires_at_ms").map_err(map_sql_err)?;
    let reply = match deadline {
        None => -1,
        Some(at) if at <= current_ms => {
            // Expired but not yet reaped: remove it and report it as missing.
            delete_key_all(auth.pool.as_ref(), key).await?;
            -2
        }
        Some(at) => (at - current_ms) / divisor,
    };
    Ok(Value::Integer(reply))
}
#[cfg(test)]
mod tests {
    /// TTL replies truncate: 1.5 s of remaining lifetime is reported as 1 s.
    ///
    /// NOTE(review): this pins the integer-division semantics `ttl_inner`
    /// relies on; it does not call `ttl_inner` itself.
    #[test]
    fn ttl_division_floor() {
        let (now, expires_at) = (1000i64, 2500i64);
        let remaining_ms = expires_at - now;
        assert_eq!(remaining_ms / 1000, 1);
    }
}