#[cfg(feature = "server")]
use bytes::BytesMut;
use crate::commands::redis::{
bulk, define_redis_command, eq_ignore_ascii_case, error, int, parse_f64, write_resp_null,
write_resp_wrong_arity, write_resp_wrongtype, wrong_arity, wrongtype,
};
use crate::protocol::Frame;
#[cfg(feature = "server")]
use crate::server::wire::ServerWire;
use crate::storage::{EmbeddedStore, RedisObjectResult};
// Registers the `ZAdd` command item under the wire name "ZADD"; the `RedisCommand` impl for it
// follows below.
// NOTE(review): the trailing `true` is interpreted by `define_redis_command!` (presumably a
// "mutates the keyspace" flag) — confirm against the macro definition.
define_redis_command!(ZAdd, "ZADD", true);
impl crate::commands::redis::RedisCommand for ZAdd {
fn execute(store: &EmbeddedStore, args: &[&[u8]]) -> Frame {
if args.len() < 3 {
return wrong_arity("ZADD");
}
let key = args[0];
let mut index = 1;
let mut nx = false;
let mut xx = false;
let mut gt = false;
let mut lt = false;
let mut ch = false;
let mut incr = false;
while index < args.len() {
let option = args[index];
match option {
option if eq_ignore_ascii_case(option, b"NX") => nx = true,
option if eq_ignore_ascii_case(option, b"XX") => xx = true,
option if eq_ignore_ascii_case(option, b"GT") => gt = true,
option if eq_ignore_ascii_case(option, b"LT") => lt = true,
option if eq_ignore_ascii_case(option, b"CH") => ch = true,
option if eq_ignore_ascii_case(option, b"INCR") => incr = true,
_ => break,
}
index += 1;
}
if index >= args.len() || !(args.len() - index).is_multiple_of(2) {
return error("ERR syntax error");
}
let mut total = 0_i64;
let mut last_bulk = None;
for pair in args[index..].chunks_exact(2) {
let Ok(score) = parse_f64(pair[0]) else {
return error("ERR value is not a valid float");
};
match store.zadd_cond(key, score, pair[1], nx, xx, gt, lt, ch, incr) {
RedisObjectResult::Integer(value) => total += value,
RedisObjectResult::Bulk(value) => last_bulk = value,
RedisObjectResult::WrongType => return wrongtype(),
RedisObjectResult::Simple(message) if message.starts_with("ERR ") => {
return error(message);
}
_ => {}
}
}
if incr {
last_bulk.map_or(Frame::Null, bulk)
} else {
int(total)
}
}
#[cfg(feature = "server")]
fn write_resp(store: &EmbeddedStore, args: &[&[u8]], out: &mut BytesMut) {
if args.len() < 3 {
write_resp_wrong_arity(out, "ZADD");
return;
}
let key = args[0];
let mut index = 1;
let mut nx = false;
let mut xx = false;
let mut gt = false;
let mut lt = false;
let mut ch = false;
let mut incr = false;
while index < args.len() {
let option = args[index];
match option {
option if eq_ignore_ascii_case(option, b"NX") => nx = true,
option if eq_ignore_ascii_case(option, b"XX") => xx = true,
option if eq_ignore_ascii_case(option, b"GT") => gt = true,
option if eq_ignore_ascii_case(option, b"LT") => lt = true,
option if eq_ignore_ascii_case(option, b"CH") => ch = true,
option if eq_ignore_ascii_case(option, b"INCR") => incr = true,
_ => break,
}
index += 1;
}
if index >= args.len() || !(args.len() - index).is_multiple_of(2) {
ServerWire::write_resp_error(out, "ERR syntax error");
return;
}
let mut total = 0_i64;
let mut last_bulk = None;
for pair in args[index..].chunks_exact(2) {
let Ok(score) = parse_f64(pair[0]) else {
ServerWire::write_resp_error(out, "ERR value is not a valid float");
return;
};
match store.zadd_cond(key, score, pair[1], nx, xx, gt, lt, ch, incr) {
RedisObjectResult::Integer(value) => total += value,
RedisObjectResult::Bulk(value) => last_bulk = value,
RedisObjectResult::WrongType => {
write_resp_wrongtype(out);
return;
}
RedisObjectResult::Simple(message) if message.starts_with("ERR ") => {
ServerWire::write_resp_error(out, message);
return;
}
_ => {}
}
}
if incr {
match last_bulk {
Some(value) => ServerWire::write_resp_blob_string(out, &value),
None => write_resp_null(out),
}
} else {
ServerWire::write_resp_integer(out, total);
}
}
}