use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;
use std::collections::HashMap;
use std::sync::Arc;
use crate::handlers::util::{
arg_as_bytes, arg_as_f64, arg_as_i64, glob_match, wrong_arity, wrong_type,
};
use crate::state::{AppState, SessionHandle, now_ms};
use crate::storage::{TYPE_ZSET, load_meta, map_sql_err};
pub async fn zadd(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.len() < 3 {
return Err(wrong_arity("ZADD"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let now = now_ms() as i64;
if let Some(meta) = load_meta(auth.pool.as_ref(), key, now).await?
&& meta.r_type != TYPE_ZSET
{
return Ok(wrong_type());
}
let mut nx = false;
let mut xx = false;
let mut ch = false;
let mut incr = false;
let mut i = 1;
while i < cmd.args.len() {
let mut token = arg_as_bytes(&cmd.args[i])?.to_vec();
for b in &mut token {
b.make_ascii_uppercase();
}
match token.as_slice() {
b"NX" => nx = true,
b"XX" => xx = true,
b"CH" => ch = true,
b"INCR" => incr = true,
_ => break,
}
i += 1;
}
if i >= cmd.args.len() || (cmd.args.len() - i) % 2 != 0 {
return Err(wrong_arity("ZADD"));
}
let mut pairs = Vec::new();
let mut j = i;
while j < cmd.args.len() {
let score = arg_as_f64(&cmd.args[j])?;
let member = arg_as_bytes(&cmd.args[j + 1])?.clone();
pairs.push((member, score));
j += 2;
}
if incr && pairs.len() != 1 {
return Err(RespError::invalid_data(
"ERR INCR option supports a single increment",
));
}
let mut existing = HashMap::new();
if !pairs.is_empty() {
let mut qb = sqlx::QueryBuilder::new("SELECT member, score FROM redis_zset WHERE r_key = ");
qb.push_bind(key.as_ref());
qb.push(" AND member IN (");
let mut separated = qb.separated(", ");
for (member, _) in &pairs {
separated.push_bind(member.as_ref());
}
qb.push(")");
let rows = qb
.build()
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
for row in rows {
let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
let score: f64 = row.try_get("score").map_err(map_sql_err)?;
existing.insert(member, score);
}
}
if incr {
let (member, increment) = &pairs[0];
let current = existing.get(member.as_ref()).copied().unwrap_or(0.0);
if nx && existing.contains_key(member.as_ref()) {
return Ok(Value::Null);
}
if xx && !existing.contains_key(member.as_ref()) {
return Ok(Value::Null);
}
let new_score = current + increment;
let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
let res = sqlx::query(
"INSERT INTO redis_zset (r_key, member, score) VALUES (?, ?, ?) \
ON DUPLICATE KEY UPDATE score = VALUES(score)",
)
.bind(key.as_ref())
.bind(member.as_ref())
.bind(new_score)
.execute(&mut *tx)
.await
.map_err(map_sql_err)?;
if res.rows_affected() > 0 && !existing.contains_key(member.as_ref()) {
sqlx::query(
"INSERT INTO redis_kv (r_key, r_type, r_len, expires_at_ms) \
VALUES (?, ?, 1, NULL) \
ON DUPLICATE KEY UPDATE r_len = r_len + 1",
)
.bind(key.as_ref())
.bind(TYPE_ZSET)
.execute(&mut *tx)
.await
.map_err(map_sql_err)?;
}
tx.commit().await.map_err(map_sql_err)?;
return Ok(Value::Bulk(Bytes::from(new_score.to_string())));
}
let mut to_write = Vec::new();
let mut added = 0i64;
let mut changed = 0i64;
for (member, score) in &pairs {
let existing_score = existing.get(member.as_ref());
if nx && existing_score.is_some() {
continue;
}
if xx && existing_score.is_none() {
continue;
}
if existing_score.is_none() {
added += 1;
} else if existing_score
.map(|v| (*v - score).abs() > f64::EPSILON)
.unwrap_or(false)
{
changed += 1;
}
to_write.push((member.clone(), *score));
}
if !to_write.is_empty() {
let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
if added > 0 {
sqlx::query(
"INSERT INTO redis_kv (r_key, r_type, r_len, expires_at_ms) \
VALUES (?, ?, ?, NULL) \
ON DUPLICATE KEY UPDATE r_len = r_len + VALUES(r_len)",
)
.bind(key.as_ref())
.bind(TYPE_ZSET)
.bind(added)
.execute(&mut *tx)
.await
.map_err(map_sql_err)?;
}
let mut qb = sqlx::QueryBuilder::new("INSERT INTO redis_zset (r_key, member, score) ");
qb.push_values(to_write.iter(), |mut row, (member, score)| {
row.push_bind(key.as_ref());
row.push_bind(member.as_ref());
row.push_bind(*score);
});
qb.push(" ON DUPLICATE KEY UPDATE score = VALUES(score)");
qb.build().execute(&mut *tx).await.map_err(map_sql_err)?;
tx.commit().await.map_err(map_sql_err)?;
}
let count = if ch { added + changed } else { added };
Ok(Value::Integer(count))
}
/// Implements `ZREM key member [member ...]`.
///
/// Deletes the named members and replies with how many rows were actually
/// removed. In the same transaction the key's stored `r_len` is reduced by that
/// amount, and the `redis_kv` row is deleted once `r_len` drops to zero or below.
pub async fn zrem(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() < 2 {
        return Err(wrong_arity("ZREM"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    match load_meta(auth.pool.as_ref(), key, now_ms() as i64).await? {
        None => return Ok(Value::Integer(0)),
        Some(meta) if meta.r_type != TYPE_ZSET => return Ok(wrong_type()),
        Some(_) => {}
    }

    let mut members = Vec::with_capacity(cmd.args.len() - 1);
    for arg in cmd.args.iter().skip(1) {
        members.push(arg_as_bytes(arg)?);
    }

    let mut delete = sqlx::QueryBuilder::new("DELETE FROM redis_zset WHERE r_key = ");
    delete.push_bind(key.as_ref());
    delete.push(" AND member IN (");
    {
        let mut list = delete.separated(", ");
        for &member in &members {
            list.push_bind(member.as_ref());
        }
    }
    delete.push(")");

    let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
    let removed = delete
        .build()
        .execute(&mut *tx)
        .await
        .map_err(map_sql_err)?
        .rows_affected() as i64;
    if removed > 0 {
        let updated =
            sqlx::query("UPDATE redis_kv SET r_len = r_len - ? WHERE r_key = ? AND r_type = ?")
                .bind(removed)
                .bind(key.as_ref())
                .bind(TYPE_ZSET)
                .execute(&mut *tx)
                .await
                .map_err(map_sql_err)?
                .rows_affected();
        if updated > 0 {
            // The set may now be empty; drop its metadata row in that case.
            sqlx::query("DELETE FROM redis_kv WHERE r_key = ? AND r_type = ? AND r_len <= 0")
                .bind(key.as_ref())
                .bind(TYPE_ZSET)
                .execute(&mut *tx)
                .await
                .map_err(map_sql_err)?;
        }
    }
    tx.commit().await.map_err(map_sql_err)?;
    Ok(Value::Integer(removed))
}
/// Implements `ZINCRBY key increment member`.
///
/// Adds `increment` to the member's score (a missing member starts at 0) and
/// replies with the new score as a bulk string. The first insert of a member
/// also bumps the key's `r_len` in `redis_kv`, creating that row if needed.
pub async fn zincrby(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() != 3 {
        return Err(wrong_arity("ZINCRBY"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let increment = arg_as_f64(&cmd.args[1])?;
    let member = arg_as_bytes(&cmd.args[2])?;
    let now = now_ms() as i64;
    if let Some(meta) = load_meta(auth.pool.as_ref(), key, now).await?
        && meta.r_type != TYPE_ZSET
    {
        return Ok(wrong_type());
    }
    let row = sqlx::query("SELECT score FROM redis_zset WHERE r_key = ? AND member = ?")
        .bind(key.as_ref())
        .bind(member.as_ref())
        .fetch_optional(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;
    let current = match row {
        Some(row) => Some(row.try_get::<f64, _>("score").map_err(map_sql_err)?),
        None => None,
    };
    let score = current.unwrap_or(0.0) + increment;
    // e.g. `+inf` incremented by `-inf`. A NaN score has no defined ordering,
    // and Redis rejects it with this error.
    if score.is_nan() {
        return Err(RespError::invalid_data(
            "ERR resulting score is not a number (NaN)",
        ));
    }
    let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
    let res = sqlx::query(
        "INSERT INTO redis_zset (r_key, member, score) VALUES (?, ?, ?) \
         ON DUPLICATE KEY UPDATE score = VALUES(score)",
    )
    .bind(key.as_ref())
    .bind(member.as_ref())
    .bind(score)
    .execute(&mut *tx)
    .await
    .map_err(map_sql_err)?;
    if res.rows_affected() > 0 && current.is_none() {
        sqlx::query(
            "INSERT INTO redis_kv (r_key, r_type, r_len, expires_at_ms) \
             VALUES (?, ?, 1, NULL) \
             ON DUPLICATE KEY UPDATE r_len = r_len + 1",
        )
        .bind(key.as_ref())
        .bind(TYPE_ZSET)
        .execute(&mut *tx)
        .await
        .map_err(map_sql_err)?;
    }
    tx.commit().await.map_err(map_sql_err)?;
    Ok(Value::Bulk(Bytes::from(score.to_string())))
}
/// Implements `ZSCORE key member`.
///
/// Replies with the member's score as a bulk string, or nil when either the
/// key or the member does not exist.
pub async fn zscore(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() != 2 {
        return Err(wrong_arity("ZSCORE"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let (key, member) = (arg_as_bytes(&cmd.args[0])?, arg_as_bytes(&cmd.args[1])?);
    match load_meta(auth.pool.as_ref(), key, now_ms() as i64).await? {
        None => return Ok(Value::Null),
        Some(meta) if meta.r_type != TYPE_ZSET => return Ok(wrong_type()),
        Some(_) => {}
    }
    let score = sqlx::query("SELECT score FROM redis_zset WHERE r_key = ? AND member = ?")
        .bind(key.as_ref())
        .bind(member.as_ref())
        .fetch_optional(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?
        .map(|row| row.try_get::<f64, _>("score"))
        .transpose()
        .map_err(map_sql_err)?;
    Ok(match score {
        Some(score) => Value::Bulk(Bytes::from(score.to_string())),
        None => Value::Null,
    })
}
/// Implements `ZCARD key`.
///
/// Replies with the member count recorded in the key's metadata (`r_len`),
/// or 0 when the key does not exist.
pub async fn zcard(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() != 1 {
        return Err(wrong_arity("ZCARD"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    match load_meta(auth.pool.as_ref(), key, now_ms() as i64).await? {
        None => Ok(Value::Integer(0)),
        Some(meta) if meta.r_type != TYPE_ZSET => Ok(wrong_type()),
        Some(meta) => Ok(Value::Integer(meta.r_len)),
    }
}
/// Implements `ZCOUNT key min max`.
///
/// Bounds accept the forms understood by `parse_score_bound` (a `(` prefix
/// makes a bound exclusive). Replies with the number of members whose score
/// lies within the range, or 0 for a missing key.
pub async fn zcount(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() != 3 {
        return Err(wrong_arity("ZCOUNT"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    match load_meta(auth.pool.as_ref(), key, now_ms() as i64).await? {
        None => return Ok(Value::Integer(0)),
        Some(meta) if meta.r_type != TYPE_ZSET => return Ok(wrong_type()),
        Some(_) => {}
    }
    let (min, min_ex) = parse_score_bound(arg_as_bytes(&cmd.args[1])?)?;
    let (max, max_ex) = parse_score_bound(arg_as_bytes(&cmd.args[2])?)?;
    let sql = format!(
        "SELECT COUNT(*) AS count FROM redis_zset WHERE r_key = ? AND {}",
        score_cond("score", min_ex, max_ex)
    );
    let count: i64 = sqlx::query(&sql)
        .bind(key.as_ref())
        .bind(min)
        .bind(max)
        .fetch_one(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?
        .try_get("count")
        .map_err(map_sql_err)?;
    Ok(Value::Integer(count))
}
/// `ZRANGE key start stop [WITHSCORES]`: members by ascending rank.
pub async fn zrange(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let descending = false;
    range_by_rank(cmd, state, session, descending).await
}
/// `ZREVRANGE key start stop [WITHSCORES]`: members by descending rank.
pub async fn zrevrange(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let descending = true;
    range_by_rank(cmd, state, session, descending).await
}
async fn range_by_rank(
cmd: resp_async::Command,
state: Arc<AppState>,
session: Arc<crate::state::Session>,
reverse: bool,
) -> Result<Value, RespError> {
if cmd.args.len() < 3 {
return Err(wrong_arity(if reverse { "ZREVRANGE" } else { "ZRANGE" }));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let start = arg_as_i64(&cmd.args[1])?;
let stop = arg_as_i64(&cmd.args[2])?;
let with_scores = cmd.args.len() > 3
&& arg_as_bytes(&cmd.args[3])?
.as_ref()
.eq_ignore_ascii_case(b"WITHSCORES");
let now = now_ms() as i64;
let meta = load_meta(auth.pool.as_ref(), key, now).await?;
let Some(meta) = meta else {
return Ok(Value::Array(Vec::new()));
};
if meta.r_type != TYPE_ZSET {
return Ok(wrong_type());
}
let Some((offset, count)) = normalize_range(start, stop, meta.r_len) else {
return Ok(Value::Array(Vec::new()));
};
let order = if reverse { "DESC" } else { "ASC" };
let sql = format!(
"SELECT member, score FROM redis_zset WHERE r_key = ? \
ORDER BY score {order}, member {order} LIMIT ?, ?"
);
let rows = sqlx::query(&sql)
.bind(key.as_ref())
.bind(offset)
.bind(count)
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
Ok(build_member_score(rows, with_scores))
}
/// `ZRANGEBYSCORE key min max [WITHSCORES] [LIMIT offset count]`, ascending.
pub async fn zrangebyscore(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let descending = false;
    range_by_score(cmd, state, session, descending).await
}
/// `ZREVRANGEBYSCORE key max min [WITHSCORES] [LIMIT offset count]`, descending.
pub async fn zrevrangebyscore(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let descending = true;
    range_by_score(cmd, state, session, descending).await
}
async fn range_by_score(
cmd: resp_async::Command,
state: Arc<AppState>,
session: Arc<crate::state::Session>,
reverse: bool,
) -> Result<Value, RespError> {
if cmd.args.len() < 3 {
return Err(wrong_arity(if reverse {
"ZREVRANGEBYSCORE"
} else {
"ZRANGEBYSCORE"
}));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let (min, min_ex, max, max_ex) = if reverse {
let (max, max_ex) = parse_score_bound(arg_as_bytes(&cmd.args[1])?)?;
let (min, min_ex) = parse_score_bound(arg_as_bytes(&cmd.args[2])?)?;
(min, min_ex, max, max_ex)
} else {
let (min, min_ex) = parse_score_bound(arg_as_bytes(&cmd.args[1])?)?;
let (max, max_ex) = parse_score_bound(arg_as_bytes(&cmd.args[2])?)?;
(min, min_ex, max, max_ex)
};
let mut with_scores = false;
let mut offset: Option<i64> = None;
let mut count: Option<i64> = None;
let mut i = 3;
while i < cmd.args.len() {
let mut token = arg_as_bytes(&cmd.args[i])?.to_vec();
for b in &mut token {
b.make_ascii_uppercase();
}
match token.as_slice() {
b"WITHSCORES" => {
with_scores = true;
i += 1;
}
b"LIMIT" => {
let off = arg_as_i64(&cmd.args[i + 1])?;
let cnt = arg_as_i64(&cmd.args[i + 2])?;
offset = Some(off);
count = Some(cnt);
i += 3;
}
_ => return Err(RespError::invalid_data("ERR syntax error")),
}
}
let now = now_ms() as i64;
let meta = load_meta(auth.pool.as_ref(), key, now).await?;
let Some(meta) = meta else {
return Ok(Value::Array(Vec::new()));
};
if meta.r_type != TYPE_ZSET {
return Ok(wrong_type());
}
let order = if reverse { "DESC" } else { "ASC" };
let mut sql = format!(
"SELECT member, score FROM redis_zset WHERE r_key = ? AND {} \
ORDER BY score {order}, member {order}",
score_cond("score", min_ex, max_ex)
);
if let Some(off) = offset {
sql.push_str(" LIMIT ");
sql.push_str(&format!("{}, {}", off.max(0), count.unwrap_or(0).max(0)));
}
let rows = sqlx::query(&sql)
.bind(key.as_ref())
.bind(min)
.bind(max)
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
Ok(build_member_score(rows, with_scores))
}
/// `ZRANK key member`: 0-based position in ascending order.
pub async fn zrank(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let descending = false;
    rank(cmd, state, session, descending).await
}
/// `ZREVRANK key member`: 0-based position in descending order.
pub async fn zrevrank(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let descending = true;
    rank(cmd, state, session, descending).await
}
/// Shared implementation of `ZRANK` / `ZREVRANK key member`.
///
/// Replies with the member's 0-based position in (score, member) order
/// (descending when `reverse` is set), or nil when the key or member is absent.
/// The rank is the number of members ordered strictly before this one.
async fn rank(
    cmd: resp_async::Command,
    state: Arc<AppState>,
    session: Arc<crate::state::Session>,
    reverse: bool,
) -> Result<Value, RespError> {
    let name = if reverse { "ZREVRANK" } else { "ZRANK" };
    if cmd.args.len() != 2 {
        return Err(wrong_arity(name));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let member = arg_as_bytes(&cmd.args[1])?;
    match load_meta(auth.pool.as_ref(), key, now_ms() as i64).await? {
        None => return Ok(Value::Null),
        Some(meta) if meta.r_type != TYPE_ZSET => return Ok(wrong_type()),
        Some(_) => {}
    }
    let Some(found) = sqlx::query("SELECT score FROM redis_zset WHERE r_key = ? AND member = ?")
        .bind(key.as_ref())
        .bind(member.as_ref())
        .fetch_optional(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?
    else {
        return Ok(Value::Null);
    };
    let score: f64 = found.try_get("score").map_err(map_sql_err)?;
    // "Ahead" means a better score, or an equal score with a member that sorts first.
    let ahead = if reverse { ">" } else { "<" };
    let sql = format!(
        "SELECT COUNT(*) AS count FROM redis_zset WHERE r_key = ? AND (score {ahead} ? OR (score = ? AND member {ahead} ?))"
    );
    let count: i64 = sqlx::query(&sql)
        .bind(key.as_ref())
        .bind(score)
        .bind(score)
        .bind(member.as_ref())
        .fetch_one(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?
        .try_get("count")
        .map_err(map_sql_err)?;
    Ok(Value::Integer(count))
}
/// `ZPOPMIN key [count]`: remove and return the lowest-scored members.
pub async fn zpopmin(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let highest_first = false;
    pop_extreme(cmd, state, session, highest_first).await
}
/// `ZPOPMAX key [count]`: remove and return the highest-scored members.
pub async fn zpopmax(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let highest_first = true;
    pop_extreme(cmd, state, session, highest_first).await
}
async fn pop_extreme(
cmd: resp_async::Command,
state: Arc<AppState>,
session: Arc<crate::state::Session>,
reverse: bool,
) -> Result<Value, RespError> {
if cmd.args.is_empty() {
return Err(wrong_arity(if reverse { "ZPOPMAX" } else { "ZPOPMIN" }));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let count = if cmd.args.len() >= 2 {
arg_as_i64(&cmd.args[1])?
} else {
1
};
let now = now_ms() as i64;
let meta = load_meta(auth.pool.as_ref(), key, now).await?;
let Some(meta) = meta else {
return Ok(Value::Array(Vec::new()));
};
if meta.r_type != TYPE_ZSET {
return Ok(wrong_type());
}
let order = if reverse { "DESC" } else { "ASC" };
let sql = format!(
"SELECT member, score FROM redis_zset WHERE r_key = ? ORDER BY score {order}, member {order} LIMIT ?"
);
let rows = sqlx::query(&sql)
.bind(key.as_ref())
.bind(count)
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
if rows.is_empty() {
return Ok(Value::Array(Vec::new()));
}
let mut members = Vec::with_capacity(rows.len());
for row in &rows {
let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
members.push(Bytes::from(member));
}
let mut qb = sqlx::QueryBuilder::new("DELETE FROM redis_zset WHERE r_key = ");
qb.push_bind(key.as_ref());
qb.push(" AND member IN (");
let mut separated = qb.separated(", ");
for member in &members {
separated.push_bind(member.as_ref());
}
qb.push(")");
let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
qb.build().execute(&mut *tx).await.map_err(map_sql_err)?;
let removed = members.len() as i64;
let res = sqlx::query("UPDATE redis_kv SET r_len = r_len - ? WHERE r_key = ? AND r_type = ?")
.bind(removed)
.bind(key.as_ref())
.bind(TYPE_ZSET)
.execute(&mut *tx)
.await
.map_err(map_sql_err)?;
if res.rows_affected() > 0 {
sqlx::query("DELETE FROM redis_kv WHERE r_key = ? AND r_type = ? AND r_len <= 0")
.bind(key.as_ref())
.bind(TYPE_ZSET)
.execute(&mut *tx)
.await
.map_err(map_sql_err)?;
}
tx.commit().await.map_err(map_sql_err)?;
Ok(build_member_score(rows, true))
}
pub async fn zlexcount(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.len() != 3 {
return Err(wrong_arity("ZLEXCOUNT"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let now = now_ms() as i64;
let meta = load_meta(auth.pool.as_ref(), key, now).await?;
let Some(meta) = meta else {
return Ok(Value::Integer(0));
};
if meta.r_type != TYPE_ZSET {
return Ok(wrong_type());
}
ensure_lex_compatible(auth.pool.as_ref(), key).await?;
let (min, min_ex) = parse_lex_bound(arg_as_bytes(&cmd.args[1])?)?;
let (max, max_ex) = parse_lex_bound(arg_as_bytes(&cmd.args[2])?)?;
let mut sql = String::from("SELECT COUNT(*) AS count FROM redis_zset WHERE r_key = ?");
sql.push_str(&lex_where("member", &min, min_ex, &max, max_ex));
let mut query = sqlx::query(&sql).bind(key.as_ref());
let min_bind = min.as_ref().map(|v| v.to_vec());
let max_bind = max.as_ref().map(|v| v.to_vec());
if let Some(val) = min_bind {
query = query.bind(val);
}
if let Some(val) = max_bind {
query = query.bind(val);
}
let row = query
.fetch_one(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let count: i64 = row.try_get("count").map_err(map_sql_err)?;
Ok(Value::Integer(count))
}
/// `ZRANGEBYLEX key min max [LIMIT offset count]`, ascending member order.
pub async fn zrangebylex(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let descending = false;
    range_by_lex(cmd, state, session, descending).await
}
/// `ZREVRANGEBYLEX key max min [LIMIT offset count]`, descending member order.
pub async fn zrevrangebylex(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let descending = true;
    range_by_lex(cmd, state, session, descending).await
}
async fn range_by_lex(
cmd: resp_async::Command,
state: Arc<AppState>,
session: Arc<crate::state::Session>,
reverse: bool,
) -> Result<Value, RespError> {
if cmd.args.len() < 3 {
return Err(wrong_arity(if reverse {
"ZREVRANGEBYLEX"
} else {
"ZRANGEBYLEX"
}));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let (min, min_ex, max, max_ex) = if reverse {
let (max, max_ex) = parse_lex_bound(arg_as_bytes(&cmd.args[1])?)?;
let (min, min_ex) = parse_lex_bound(arg_as_bytes(&cmd.args[2])?)?;
(min, min_ex, max, max_ex)
} else {
let (min, min_ex) = parse_lex_bound(arg_as_bytes(&cmd.args[1])?)?;
let (max, max_ex) = parse_lex_bound(arg_as_bytes(&cmd.args[2])?)?;
(min, min_ex, max, max_ex)
};
let mut offset: Option<i64> = None;
let mut count: Option<i64> = None;
let mut i = 3;
while i < cmd.args.len() {
let mut token = arg_as_bytes(&cmd.args[i])?.to_vec();
for b in &mut token {
b.make_ascii_uppercase();
}
match token.as_slice() {
b"LIMIT" => {
let off = arg_as_i64(&cmd.args[i + 1])?;
let cnt = arg_as_i64(&cmd.args[i + 2])?;
offset = Some(off);
count = Some(cnt);
i += 3;
}
_ => return Err(RespError::invalid_data("ERR syntax error")),
}
}
let now = now_ms() as i64;
let meta = load_meta(auth.pool.as_ref(), key, now).await?;
let Some(meta) = meta else {
return Ok(Value::Array(Vec::new()));
};
if meta.r_type != TYPE_ZSET {
return Ok(wrong_type());
}
ensure_lex_compatible(auth.pool.as_ref(), key).await?;
let order = if reverse { "DESC" } else { "ASC" };
let mut sql = format!(
"SELECT member FROM redis_zset WHERE r_key = ?{} ORDER BY member {order}",
lex_where("member", &min, min_ex, &max, max_ex)
);
if let Some(off) = offset {
sql.push_str(" LIMIT ");
sql.push_str(&format!("{}, {}", off.max(0), count.unwrap_or(0).max(0)));
}
let mut query = sqlx::query(&sql).bind(key.as_ref());
let min_bind = min.as_ref().map(|v| v.to_vec());
let max_bind = max.as_ref().map(|v| v.to_vec());
if let Some(val) = min_bind {
query = query.bind(val);
}
if let Some(val) = max_bind {
query = query.bind(val);
}
let rows = query
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let mut out = Vec::with_capacity(rows.len());
for row in rows {
let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
out.push(Value::Bulk(Bytes::from(member)));
}
Ok(Value::Array(out))
}
/// Implements `ZSCAN key cursor [MATCH pattern] [COUNT count]`.
///
/// Iteration is keyset-based. Members are visited in ascending byte order, and
/// each reply's cursor records the last member examined, so the next call
/// resumes strictly after it. The cursor is `0` to start and when iteration is
/// finished. Otherwise it is `1` followed by the lowercase hex of that member;
/// hex-encoding keeps arbitrary binary members from colliding with the `0`
/// sentinel. `MATCH` filters the returned members only; it does not affect
/// paging.
pub async fn zscan(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() < 2 {
        return Err(wrong_arity("ZSCAN"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    // `None` = start from the first member; `Some(m)` = resume after member `m`.
    let after = zscan_cursor_decode(arg_as_bytes(&cmd.args[1])?)?;
    let now = now_ms() as i64;
    let meta = load_meta(auth.pool.as_ref(), key, now).await?;
    let Some(meta) = meta else {
        return Ok(Value::Array(vec![
            Value::Bulk(Bytes::from_static(b"0")),
            Value::Array(Vec::new()),
        ]));
    };
    if meta.r_type != TYPE_ZSET {
        return Ok(wrong_type());
    }
    let mut count = 10i64;
    let mut pattern: Option<Bytes> = None;
    let mut i = 2;
    while i < cmd.args.len() {
        let mut token = arg_as_bytes(&cmd.args[i])?.to_vec();
        token.make_ascii_uppercase();
        match token.as_slice() {
            b"MATCH" => {
                let next = cmd
                    .args
                    .get(i + 1)
                    .ok_or_else(|| RespError::invalid_data("ERR syntax error"))?;
                pattern = Some(arg_as_bytes(next)?.clone());
                i += 2;
            }
            b"COUNT" => {
                let next = cmd
                    .args
                    .get(i + 1)
                    .ok_or_else(|| RespError::invalid_data("ERR syntax error"))?;
                count = arg_as_i64(next)?;
                if count <= 0 {
                    return Err(RespError::invalid_data("ERR invalid COUNT"));
                }
                i += 2;
            }
            _ => return Err(RespError::invalid_data("ERR syntax error")),
        }
    }
    // The initial cursor imposes no lower bound, so members sorting at or
    // before "0" (e.g. "", "-1") are not skipped.
    let sql = if after.is_some() {
        "SELECT member, score FROM redis_zset WHERE r_key = ? AND member > ? \
         ORDER BY member LIMIT ?"
    } else {
        "SELECT member, score FROM redis_zset WHERE r_key = ? ORDER BY member LIMIT ?"
    };
    let mut query = sqlx::query(sql).bind(key.as_ref());
    if let Some(after) = &after {
        query = query.bind(after.as_slice());
    }
    let rows = query
        .bind(count)
        .fetch_all(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;
    let mut out = Vec::new();
    let mut last_member: Option<Vec<u8>> = None;
    for row in &rows {
        let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
        let score: f64 = row.try_get("score").map_err(map_sql_err)?;
        let keep = pattern
            .as_ref()
            .is_none_or(|p| glob_match(p.as_ref(), &member));
        if keep {
            out.push(Value::Bulk(Bytes::from(member.clone())));
            out.push(Value::Bulk(Bytes::from(score.to_string())));
        }
        // Paging advances past every scanned row, matched or not.
        last_member = Some(member);
    }
    // A short page means the set is exhausted.
    let cursor_value = match last_member {
        Some(member) if rows.len() as i64 == count => zscan_cursor_encode(&member),
        _ => Bytes::from_static(b"0"),
    };
    Ok(Value::Array(vec![
        Value::Bulk(cursor_value),
        Value::Array(out),
    ]))
}
/// Encodes a resume position for `zscan` as `1` + lowercase hex of `member`.
fn zscan_cursor_encode(member: &[u8]) -> Bytes {
    const HEX: &[u8; 16] = b"0123456789abcdef";
    let mut out = Vec::with_capacity(1 + member.len() * 2);
    out.push(b'1');
    for &byte in member {
        out.push(HEX[usize::from(byte >> 4)]);
        out.push(HEX[usize::from(byte & 0x0f)]);
    }
    Bytes::from(out)
}
/// Decodes a `zscan` cursor: `0` means "from the start" (`None`); otherwise it
/// must be `1` + hex as produced by `zscan_cursor_encode`, or the cursor is invalid.
fn zscan_cursor_decode(raw: &[u8]) -> Result<Option<Vec<u8>>, RespError> {
    if raw == b"0" {
        return Ok(None);
    }
    let invalid = || RespError::invalid_data("ERR invalid cursor");
    let hex = raw.strip_prefix(b"1").ok_or_else(invalid)?;
    if hex.len() % 2 != 0 {
        return Err(invalid());
    }
    hex.chunks_exact(2)
        .map(|pair| {
            let hi = (pair[0] as char).to_digit(16).ok_or_else(invalid)?;
            let lo = (pair[1] as char).to_digit(16).ok_or_else(invalid)?;
            Ok((hi * 16 + lo) as u8)
        })
        .collect::<Result<Vec<u8>, RespError>>()
        .map(Some)
}
/// Converts Redis-style inclusive rank bounds into a SQL `(offset, count)` pair.
///
/// Negative ranks count back from the end (`-1` is the last element). The
/// start is clamped to 0 and the stop to `len - 1`. Returns `None` when the set
/// is empty or the resulting range selects nothing.
fn normalize_range(start: i64, stop: i64, len: i64) -> Option<(i64, i64)> {
    if len <= 0 {
        return None;
    }
    let resolve = |rank: i64| if rank < 0 { len + rank } else { rank };
    let first = resolve(start).max(0);
    let last = resolve(stop).min(len - 1);
    // Checked before subtracting so extreme inputs cannot overflow `last - first`.
    if first > last || first >= len {
        None
    } else {
        Some((first, last - first + 1))
    }
}
/// Flattens `member`/`score` rows into a RESP array: each member as a bulk
/// string, followed by its score as a bulk string when `with_scores` is set.
///
/// NOTE(review): a column that fails to decode is replaced by an empty member
/// or a `0` score instead of being reported. Surfacing the error would require
/// returning `Result`, which changes this helper's signature for all callers.
fn build_member_score(rows: Vec<sqlx::mysql::MySqlRow>, with_scores: bool) -> Value {
    let per_row = if with_scores { 2 } else { 1 };
    let mut items = Vec::with_capacity(rows.len() * per_row);
    for row in &rows {
        let member: Vec<u8> = row.try_get("member").unwrap_or_default();
        items.push(Value::Bulk(Bytes::from(member)));
        if with_scores {
            let score: f64 = row.try_get("score").unwrap_or(0.0);
            items.push(Value::Bulk(Bytes::from(score.to_string())));
        }
    }
    Value::Array(items)
}
/// Parses a score range bound such as `1.5`, `(1.5`, `-inf`, `+inf` or `inf`.
///
/// Returns the value and whether the bound is exclusive (a leading `(`).
/// Infinities go through `f64::from_str`, which accepts `inf`/`infinity` with
/// an optional sign. NaN is rejected, because a NaN bound matches no score.
///
/// # Errors
/// `ERR invalid float` when the text is not UTF-8, not a number, or NaN.
fn parse_score_bound(arg: &Bytes) -> Result<(f64, bool), RespError> {
    let raw: &[u8] = arg.as_ref();
    let (body, exclusive) = match raw.strip_prefix(b"(") {
        Some(rest) => (rest, true),
        None => (raw, false),
    };
    let value = std::str::from_utf8(body)
        .ok()
        .and_then(|s| s.parse::<f64>().ok())
        .filter(|v| !v.is_nan())
        .ok_or_else(|| RespError::invalid_data("ERR invalid float"))?;
    Ok((value, exclusive))
}
/// Builds the SQL predicate `field >[=] ? AND field <[=] ?` for a score range.
///
/// The caller binds the lower bound first and the upper bound second. The
/// `*_ex` flags select strict comparisons for exclusive bounds.
fn score_cond(field: &str, min_ex: bool, max_ex: bool) -> String {
    format!(
        "{field} {} ? AND {field} {} ?",
        if min_ex { ">" } else { ">=" },
        if max_ex { "<" } else { "<=" },
    )
}
/// Parses a lexicographic range bound.
///
/// `(value` yields `(Some(value), true)` (exclusive) and `[value` yields
/// `(Some(value), false)` (inclusive). Both `-` and `+` yield `(None, false)`,
/// i.e. "no bound on this side". Callers that need to tell which infinity was
/// meant must inspect the raw argument.
///
/// # Errors
/// `ERR invalid lex range` for any other form, including the empty string.
fn parse_lex_bound(arg: &Bytes) -> Result<(Option<Bytes>, bool), RespError> {
    match arg.as_ref() {
        b"-" | b"+" => Ok((None, false)),
        [b'(', rest @ ..] => Ok((Some(Bytes::copy_from_slice(rest)), true)),
        [b'[', rest @ ..] => Ok((Some(Bytes::copy_from_slice(rest)), false)),
        _ => Err(RespError::invalid_data("ERR invalid lex range")),
    }
}
/// Builds the extra `AND` clauses for a lex range, one per present bound.
///
/// The lower bound's clause always precedes the upper bound's, so callers must
/// bind the bound values in that order. Absent bounds contribute nothing, and
/// the result is empty when both are `None`.
fn lex_where(
    field: &str,
    min: &Option<Bytes>,
    min_ex: bool,
    max: &Option<Bytes>,
    max_ex: bool,
) -> String {
    let lower = min.as_ref().map(|_| if min_ex { ">" } else { ">=" });
    let upper = max.as_ref().map(|_| if max_ex { "<" } else { "<=" });
    [lower, upper]
        .into_iter()
        .flatten()
        .map(|op| format!(" AND {field} {op} ?"))
        .collect()
}
/// Rejects lexicographic range operations on a sorted set whose members do
/// not all share a single score.
///
/// Lex commands order purely by member, which only agrees with the set's real
/// order when every score is equal. An empty set (both aggregates NULL) is
/// accepted.
///
/// # Errors
/// `ERR operation against a sorted set with different scores` when the
/// minimum and maximum score differ; SQL failures via `map_sql_err`.
async fn ensure_lex_compatible(pool: &sqlx::MySqlPool, key: &Bytes) -> Result<(), RespError> {
    let row = sqlx::query(
        "SELECT MIN(score) AS min_score, MAX(score) AS max_score FROM redis_zset WHERE r_key = ?",
    )
    .bind(key.as_ref())
    .fetch_one(pool)
    .await
    .map_err(map_sql_err)?;
    let min: Option<f64> = row.try_get("min_score").map_err(map_sql_err)?;
    let max: Option<f64> = row.try_get("max_score").map_err(map_sql_err)?;
    let (Some(min), Some(max)) = (min, max) else {
        return Ok(());
    };
    // Exact comparison: any two distinct scores make member order diverge from
    // score order, however close the scores are.
    if min != max {
        return Err(RespError::invalid_data(
            "ERR operation against a sorted set with different scores",
        ));
    }
    Ok(())
}