use crate::types::RedisBytes;
use crate::{
types::{QueueDescriptor, RsmqMessage, RsmqQueueAttributes},
RsmqError, RsmqResult,
};
use core::convert::TryFrom;
use radix_fmt::radix_36;
use rand::seq::IteratorRandom;
use redis::{aio::ConnectionLike, pipe};
use std::convert::TryInto;
use std::time::Duration;
/// Largest duration (in milliseconds) accepted for queue/message `vt` and `delay` values;
/// every user-supplied duration is checked against it with `number_in_range`.
const JS_COMPAT_MAX_TIME_MILLIS: u64 = 9_999_999_000;

/// Factor that turns a millisecond duration into the unit of `QueueDescriptor::ts`.
///
/// With `break-js-comp`, `ts` is kept in microseconds (so the factor is 1000); otherwise
/// `ts` is in milliseconds and durations are added unscaled.
const DURATION_SCALE: u64 = if cfg!(feature = "break-js-comp") { 1000 } else { 1 };

/// Flag forwarded to `getQueueAttributes.lua` as its `time_multiplier` argument:
/// 1 when the `break-js-comp` feature is enabled, 0 otherwise.
/// NOTE(review): how the script interprets this flag lives in the Lua file — confirm there.
const USE_MICROSECONDS: u64 = if cfg!(feature = "break-js-comp") { 1 } else { 0 };
/// Stateless implementation of the RSMQ commands, generic over the async connection type `T`.
///
/// Only configuration is stored here; the connection itself is passed to every method.
pub struct RsmqFunctions<T: ConnectionLike> {
    /// Key prefix: queues live under `{ns}:{qname}` and are indexed in `{ns}:QUEUES`.
    pub(crate) ns: String,
    /// When true, `send_message` publishes on `{ns}:rt:{qname}` after enqueuing.
    pub(crate) realtime: bool,
    /// Ties the struct to a connection type without storing a connection.
    pub(crate) conn: std::marker::PhantomData<T>,
}

// Hand-written instead of `#[derive(Clone)]`: the derive would add a `T: Clone` bound even
// though no `T` value is stored, making the helper uncloneable for non-`Clone` connections.
impl<T: ConnectionLike> Clone for RsmqFunctions<T> {
    fn clone(&self) -> Self {
        Self {
            ns: self.ns.clone(),
            realtime: self.realtime,
            conn: std::marker::PhantomData,
        }
    }
}

// Manual `Debug` so that no `T: Debug` bound is required; prints only the type name.
impl<T: ConnectionLike> std::fmt::Debug for RsmqFunctions<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "RsmqFunctions")
    }
}
/// Digests returned by `SCRIPT LOAD` for the Lua scripts this module runs through `EVALSHA`.
///
/// NOTE(review): Redis does not persist its script cache, and `SCRIPT FLUSH` clears it. After
/// either, `EVALSHA` fails with `NOSCRIPT` until `load_scripts` is called again — confirm
/// callers handle that.
#[derive(Debug, Clone)]
pub struct CachedScript {
    /// Digest of `redis-scripts/changeMessageVisibility.lua`.
    change_message_visibility_sha1: String,
    /// Digest of `redis-scripts/receiveMessage.lua`.
    receive_message_sha1: String,
    /// Digest of `redis-scripts/getQueueAttributes.lua`.
    get_queue_attributes_sha1: String,
}

impl CachedScript {
    /// Registers every bundled Lua script with the server and records the returned digests.
    ///
    /// The scripts are loaded one at a time; the first Redis error is returned.
    async fn init<T: ConnectionLike>(conn: &mut T) -> RsmqResult<Self> {
        let change_message_visibility_sha1 = Self::load_script(
            conn,
            include_str!("./redis-scripts/changeMessageVisibility.lua"),
        )
        .await?;
        let receive_message_sha1 =
            Self::load_script(conn, include_str!("./redis-scripts/receiveMessage.lua")).await?;
        let get_queue_attributes_sha1 =
            Self::load_script(conn, include_str!("./redis-scripts/getQueueAttributes.lua"))
                .await?;

        Ok(Self {
            change_message_visibility_sha1,
            receive_message_sha1,
            get_queue_attributes_sha1,
        })
    }

    /// Sends `SCRIPT LOAD <source>` and returns the digest the server reports.
    async fn load_script<T: ConnectionLike>(conn: &mut T, source: &str) -> RsmqResult<String> {
        let digest: String = redis::cmd("SCRIPT")
            .arg("LOAD")
            .arg(source)
            .query_async(conn)
            .await?;
        Ok(digest)
    }

    /// Runs `changeMessageVisibility.lua` with three KEYS: `key1`, `key2`, `key3`.
    async fn invoke_change_message_visibility<R, T: ConnectionLike>(
        &self,
        conn: &mut T,
        key1: String,
        key2: String,
        key3: String,
    ) -> RsmqResult<R>
    where
        R: redis::FromRedisValue,
    {
        let mut evalsha = redis::cmd("EVALSHA");
        evalsha
            .arg(&self.change_message_visibility_sha1)
            .arg(3)
            .arg(key1)
            .arg(key2)
            .arg(key3);
        let reply: R = evalsha.query_async(conn).await?;
        Ok(reply)
    }

    /// Runs `receiveMessage.lua` with three KEYS (`key1`..`key3`) and one ARGV (`should_delete`).
    async fn invoke_receive_message<R, T: ConnectionLike>(
        &self,
        conn: &mut T,
        key1: String,
        key2: String,
        key3: String,
        should_delete: String,
    ) -> RsmqResult<R>
    where
        R: redis::FromRedisValue,
    {
        let mut evalsha = redis::cmd("EVALSHA");
        evalsha
            .arg(&self.receive_message_sha1)
            .arg(3)
            .arg(key1)
            .arg(key2)
            .arg(key3)
            .arg(should_delete);
        let reply: R = evalsha.query_async(conn).await?;
        Ok(reply)
    }

    /// Runs `getQueueAttributes.lua` with two KEYS (`key_hash`, `key_set`) and one ARGV
    /// (`time_multiplier`).
    async fn invoke_get_queue_attributes<R, T: ConnectionLike>(
        &self,
        conn: &mut T,
        key_hash: String,
        key_set: String,
        time_multiplier: u64,
    ) -> RsmqResult<R>
    where
        R: redis::FromRedisValue,
    {
        let mut evalsha = redis::cmd("EVALSHA");
        evalsha
            .arg(&self.get_queue_attributes_sha1)
            .arg(2)
            .arg(key_hash)
            .arg(key_set)
            .arg(time_multiplier);
        let reply: R = evalsha.query_async(conn).await?;
        Ok(reply)
    }
}
impl<T: ConnectionLike> RsmqFunctions<T> {
    /// Changes how long message `message_id` in queue `qname` stays hidden.
    ///
    /// Computes the new visibility deadline as the Redis server time plus `hidden`. It then
    /// passes the deadline and the message id to the `changeMessageVisibility` Lua script.
    ///
    /// # Errors
    /// * `RsmqError::QueueNotFound` if the queue's attribute hash is missing.
    /// * `RsmqError::InvalidValue` if `hidden` exceeds `JS_COMPAT_MAX_TIME_MILLIS` ms.
    /// * Any Redis error raised while running the script.
    pub async fn change_message_visibility(
        &self,
        conn: &mut T,
        qname: &str,
        message_id: &str,
        hidden: Duration,
        cached_script: &CachedScript,
    ) -> RsmqResult<()> {
        let hidden = get_redis_duration(Some(hidden), &Duration::from_secs(30));
        let queue = self.get_queue(conn, qname, false).await?;
        number_in_range(hidden, 0, JS_COMPAT_MAX_TIME_MILLIS)?;

        // `ts` is in ms (or µs with `break-js-comp`); DURATION_SCALE aligns `hidden` with it.
        let visible_at = queue.ts + hidden * DURATION_SCALE;
        cached_script
            .invoke_change_message_visibility::<(), T>(
                conn,
                format!("{}:{}", self.ns, qname),
                message_id.to_string(),
                visible_at.to_string(),
            )
            .await?;
        Ok(())
    }

    /// Loads the bundled Lua scripts into Redis and returns their digests.
    pub async fn load_scripts(&self, conn: &mut T) -> RsmqResult<CachedScript> {
        CachedScript::init(conn).await
    }

    /// Creates queue `qname` and registers it in `{ns}:QUEUES`.
    ///
    /// Defaults:
    /// * `hidden` (visibility timeout): 30 s.
    /// * `delay`: 0.
    /// * `maxsize` (maximum message size in bytes): 65536. `-1` means unlimited.
    ///
    /// The attribute hash `{ns}:{qname}:Q` is populated with `HSETNX`, so existing fields are
    /// never overwritten.
    ///
    /// # Errors
    /// * `RsmqError::InvalidFormat` for a bad queue name.
    /// * `RsmqError::InvalidValue` for out-of-range durations or `maxsize`.
    /// * `RsmqError::QueueExists` if the queue's `vt` field was already set.
    pub async fn create_queue(
        &self,
        conn: &mut T,
        qname: &str,
        hidden: Option<Duration>,
        delay: Option<Duration>,
        maxsize: Option<i64>,
    ) -> RsmqResult<()> {
        valid_name_format(qname)?;
        let key = format!("{}:{}:Q", self.ns, qname);
        let hidden = get_redis_duration(hidden, &Duration::from_secs(30));
        let delay = get_redis_duration(delay, &Duration::ZERO);
        let maxsize = maxsize.unwrap_or(65536);
        number_in_range(hidden, 0, JS_COMPAT_MAX_TIME_MILLIS)?;
        number_in_range(delay, 0, JS_COMPAT_MAX_TIME_MILLIS)?;
        // -1 is the "unlimited" sentinel and bypasses the size range check.
        if maxsize != -1 {
            number_in_range(maxsize, 1024, 65536)?;
        }

        let (now_sec, _now_usec): (u64, u64) = redis::cmd("TIME").query_async(conn).await?;
        let results: Vec<i64> = pipe()
            .atomic()
            .cmd("HSETNX")
            .arg(&key)
            .arg("vt")
            .arg(hidden)
            .cmd("HSETNX")
            .arg(&key)
            .arg("delay")
            .arg(delay)
            .cmd("HSETNX")
            .arg(&key)
            .arg("maxsize")
            .arg(maxsize)
            .cmd("HSETNX")
            .arg(&key)
            .arg("created")
            .arg(now_sec)
            .cmd("HSETNX")
            .arg(&key)
            .arg("modified")
            .arg(now_sec)
            .cmd("HSETNX")
            .arg(&key)
            .arg("totalrecv")
            .arg(0_i32)
            .cmd("HSETNX")
            .arg(&key)
            .arg("totalsent")
            .arg(0_i32)
            .cmd("SADD")
            .arg(format!("{}:QUEUES", self.ns))
            .arg(qname)
            .query_async(conn)
            .await?;

        // The first HSETNX (on "vt") replies 0 when the field already existed.
        if results.first() == Some(&0) {
            return Err(RsmqError::QueueExists);
        }
        Ok(())
    }

    /// Deletes message `id` from queue `qname`.
    ///
    /// Removes the message from the `{ns}:{qname}` sorted set. It also removes the message's
    /// body, `{id}:rc` and `{id}:fr` fields from the `{ns}:{qname}:Q` hash.
    ///
    /// Returns `true` only if the message was in the sorted set and at least one hash field
    /// was deleted.
    pub async fn delete_message(&self, conn: &mut T, qname: &str, id: &str) -> RsmqResult<bool> {
        let key = format!("{}:{}", self.ns, qname);
        let (removed_from_set, removed_fields): (u16, u16) = pipe()
            .atomic()
            .cmd("ZREM")
            .arg(&key)
            .arg(id)
            .cmd("HDEL")
            .arg(format!("{}:Q", &key))
            .arg(id)
            .arg(format!("{}:rc", id))
            .arg(format!("{}:fr", id))
            .query_async(conn)
            .await?;
        Ok(removed_from_set == 1 && removed_fields > 0)
    }

    /// Deletes queue `qname` (its attribute hash and message sorted set) and unregisters it.
    ///
    /// # Errors
    /// `RsmqError::QueueNotFound` if neither key existed.
    pub async fn delete_queue(&self, conn: &mut T, qname: &str) -> RsmqResult<()> {
        let key = format!("{}:{}", self.ns, qname);
        let (deleted_keys, _unregistered): (u16, u16) = pipe()
            .atomic()
            .cmd("DEL")
            .arg(format!("{}:Q", &key))
            .arg(key)
            .cmd("SREM")
            .arg(format!("{}:QUEUES", self.ns))
            .arg(qname)
            .query_async(conn)
            .await?;
        if deleted_keys == 0 {
            return Err(RsmqError::QueueNotFound);
        }
        Ok(())
    }

    /// Returns the stored attributes and message counters of queue `qname`.
    ///
    /// Stored `vt`/`delay` are interpreted as milliseconds. Negative or missing numeric fields
    /// fall back to zero.
    ///
    /// # Errors
    /// `RsmqError::QueueNotFound` when the script reports no `vt` for the queue.
    pub async fn get_queue_attributes(
        &self,
        conn: &mut T,
        qname: &str,
        cached_script: &CachedScript,
    ) -> RsmqResult<RsmqQueueAttributes> {
        let key = format!("{}:{}", self.ns, qname);
        // The tuple mirrors the 11-element reply of getQueueAttributes.lua.
        #[allow(clippy::type_complexity)]
        let (
            _time_sec,
            _time_usec,
            vt,
            delay,
            maxsize,
            totalrecv,
            totalsent,
            created,
            modified,
            msgs,
            hiddenmsgs,
        ): (
            u64,
            u64,
            Option<i64>,
            Option<i64>,
            Option<i64>,
            Option<i64>,
            Option<i64>,
            Option<i64>,
            Option<i64>,
            u64,
            u64,
        ) = cached_script
            .invoke_get_queue_attributes(conn, format!("{}:Q", key), key, USE_MICROSECONDS)
            .await?;

        let Some(vt) = vt else {
            return Err(RsmqError::QueueNotFound);
        };

        Ok(RsmqQueueAttributes {
            vt: Duration::from_millis(vt.try_into().unwrap_or(0)),
            delay: delay
                .map(|dur| Duration::from_millis(dur.try_into().unwrap_or(0)))
                .unwrap_or(Duration::ZERO),
            maxsize: maxsize.unwrap_or(0),
            totalrecv: totalrecv.and_then(|v| v.try_into().ok()).unwrap_or(0),
            totalsent: totalsent.and_then(|v| v.try_into().ok()).unwrap_or(0),
            created: created.and_then(|v| v.try_into().ok()).unwrap_or(0),
            modified: modified.and_then(|v| v.try_into().ok()).unwrap_or(0),
            msgs,
            hiddenmsgs,
        })
    }

    /// Returns the names registered in `{ns}:QUEUES`.
    pub async fn list_queues(&self, conn: &mut T) -> RsmqResult<Vec<String>> {
        let queues = redis::cmd("SMEMBERS")
            .arg(format!("{}:QUEUES", self.ns))
            .query_async(conn)
            .await?;
        Ok(queues)
    }

    /// Receives the next visible message of `qname` and asks the script to delete it.
    ///
    /// Invokes `receiveMessage.lua` with the current server time as both timestamps and
    /// `"true"` as its delete flag. Returns `Ok(None)` when the script reports no message.
    ///
    /// # Errors
    /// * `RsmqError::QueueNotFound` if the queue is missing.
    /// * `RsmqError::CannotDecodeMessage` if the payload cannot be converted to `E`.
    pub async fn pop_message<E: TryFrom<RedisBytes, Error = Vec<u8>>>(
        &self,
        conn: &mut T,
        qname: &str,
        cached_script: &CachedScript,
    ) -> RsmqResult<Option<RsmqMessage<E>>> {
        let queue = self.get_queue(conn, qname, false).await?;
        let now = queue.ts.to_string();
        let reply: (bool, String, Vec<u8>, u64, u64) = cached_script
            .invoke_receive_message(
                conn,
                format!("{}:{}", self.ns, qname),
                now.clone(),
                now,
                "true".to_string(),
            )
            .await?;
        Self::decode_received(reply)
    }

    /// Receives the next visible message of `qname` and hides it for `hidden`.
    ///
    /// `hidden` defaults to the queue's `vt`. Returns `Ok(None)` when the script reports no
    /// message.
    ///
    /// # Errors
    /// * `RsmqError::QueueNotFound` if the queue is missing.
    /// * `RsmqError::InvalidValue` if `hidden` is out of range.
    /// * `RsmqError::CannotDecodeMessage` if the payload cannot be converted to `E`.
    pub async fn receive_message<E: TryFrom<RedisBytes, Error = Vec<u8>>>(
        &self,
        conn: &mut T,
        qname: &str,
        hidden: Option<Duration>,
        cached_script: &CachedScript,
    ) -> RsmqResult<Option<RsmqMessage<E>>> {
        let queue = self.get_queue(conn, qname, false).await?;
        let hidden = get_redis_duration(hidden, &queue.vt);
        number_in_range(hidden, 0, JS_COMPAT_MAX_TIME_MILLIS)?;
        let reply: (bool, String, Vec<u8>, u64, u64) = cached_script
            .invoke_receive_message(
                conn,
                format!("{}:{}", self.ns, qname),
                queue.ts.to_string(),
                (queue.ts + hidden * DURATION_SCALE).to_string(),
                "false".to_string(),
            )
            .await?;
        Self::decode_received(reply)
    }

    /// Turns a `receiveMessage.lua` reply `(found, id, payload, rc, fr)` into a message.
    ///
    /// `sent` is parsed from the first 10 base-36 characters of the id (the timestamp prefix
    /// written by `get_queue`). It is 0 if that prefix is missing or unparsable.
    fn decode_received<E: TryFrom<RedisBytes, Error = Vec<u8>>>(
        (found, id, payload, rc, fr): (bool, String, Vec<u8>, u64, u64),
    ) -> RsmqResult<Option<RsmqMessage<E>>> {
        if !found {
            return Ok(None);
        }
        let message = E::try_from(RedisBytes(payload)).map_err(RsmqError::CannotDecodeMessage)?;
        let sent = id
            .get(0..10)
            .and_then(|prefix| u64::from_str_radix(prefix, 36).ok())
            .unwrap_or(0);
        Ok(Some(RsmqMessage {
            id,
            message,
            rc,
            fr,
            sent,
        }))
    }

    /// Enqueues `message` on `qname`, visible after `delay` (default: the queue's delay).
    ///
    /// Returns the generated message id. With `realtime` enabled, the sorted-set cardinality
    /// is then published on `{ns}:rt:{qname}`.
    ///
    /// # Errors
    /// * `RsmqError::QueueNotFound` if the queue is missing.
    /// * `RsmqError::InvalidValue` if `delay` is out of range.
    /// * `RsmqError::MessageTooLong` if the payload exceeds the queue's `maxsize` bytes
    ///   (unless `maxsize` is -1).
    pub async fn send_message<E: Into<RedisBytes>>(
        &self,
        conn: &mut T,
        qname: &str,
        message: E,
        delay: Option<Duration>,
    ) -> RsmqResult<String> {
        let queue = self.get_queue(conn, qname, true).await?;
        let delay = get_redis_duration(delay, &queue.delay);
        let key = format!("{}:{}", self.ns, qname);
        number_in_range(delay, 0, JS_COMPAT_MAX_TIME_MILLIS)?;

        let message: RedisBytes = message.into();
        let msg_len: i64 = message
            .0
            .len()
            .try_into()
            .map_err(|_| RsmqError::MessageTooLong)?;
        if queue.maxsize != -1 && msg_len > queue.maxsize {
            return Err(RsmqError::MessageTooLong);
        }

        // `get_queue(.., true)` always generates an id; this guards the Option anyway.
        let Some(queue_uid) = queue.uid else {
            return Err(RsmqError::QueueNotFound);
        };

        let queue_key = format!("{}:Q", key);
        let mut pipeline = pipe();
        pipeline
            .atomic()
            .cmd("ZADD")
            .arg(&key)
            .arg(queue.ts + delay * DURATION_SCALE)
            .arg(&queue_uid)
            .cmd("HSET")
            .arg(&queue_key)
            .arg(&queue_uid)
            .arg(message.0)
            .cmd("HINCRBY")
            .arg(&queue_key)
            .arg("totalsent")
            .arg(1_u64);
        if self.realtime {
            pipeline.cmd("ZCARD").arg(&key);
        }
        let result: Vec<i64> = pipeline.query_async(conn).await?;

        if self.realtime {
            // Index 3 is the ZCARD reply, appended above only in realtime mode.
            redis::cmd("PUBLISH")
                .arg(format!("{}:rt:{}", self.ns, qname))
                .arg(result[3])
                .query_async::<()>(conn)
                .await?;
        }
        Ok(queue_uid)
    }

    /// Updates the supplied attributes of `qname`, bumps `modified`, and returns the new
    /// attributes.
    ///
    /// All supplied values are validated before anything is written.
    ///
    /// # Errors
    /// * `RsmqError::QueueNotFound` if the queue is missing (checked first).
    /// * `RsmqError::InvalidValue` for out-of-range values.
    pub async fn set_queue_attributes(
        &self,
        conn: &mut T,
        qname: &str,
        hidden: Option<Duration>,
        delay: Option<Duration>,
        maxsize: Option<i64>,
        cached_script: &CachedScript,
    ) -> RsmqResult<RsmqQueueAttributes> {
        self.get_queue(conn, qname, false).await?;
        let queue_name = format!("{}:{}:Q", self.ns, qname);

        let hidden = hidden.map(|h| get_redis_duration(Some(h), &Duration::from_secs(30)));
        if let Some(hidden) = hidden {
            number_in_range(hidden, 0, JS_COMPAT_MAX_TIME_MILLIS)?;
        }
        let delay = delay.map(|d| get_redis_duration(Some(d), &Duration::ZERO));
        if let Some(delay) = delay {
            number_in_range(delay, 0, JS_COMPAT_MAX_TIME_MILLIS)?;
        }
        if let Some(maxsize) = maxsize {
            // -1 is the "unlimited" sentinel and bypasses the size range check.
            if maxsize != -1 {
                number_in_range(maxsize, 1024, 65536)?;
            }
        }

        let (now_sec, _now_usec): (u64, u64) = redis::cmd("TIME").query_async(conn).await?;
        let mut pipeline = pipe();
        pipeline
            .atomic()
            .cmd("HSET")
            .arg(&queue_name)
            .arg("modified")
            .arg(now_sec);
        if let Some(hidden) = hidden {
            pipeline.cmd("HSET").arg(&queue_name).arg("vt").arg(hidden);
        }
        if let Some(delay) = delay {
            pipeline.cmd("HSET").arg(&queue_name).arg("delay").arg(delay);
        }
        if let Some(maxsize) = maxsize {
            pipeline.cmd("HSET").arg(&queue_name).arg("maxsize").arg(maxsize);
        }
        pipeline.query_async::<()>(conn).await?;

        self.get_queue_attributes(conn, qname, cached_script).await
    }

    /// Reads the `vt`, `delay` and `maxsize` fields of `qname` together with the server TIME,
    /// in one atomic pipeline.
    ///
    /// `ts` is the server time in milliseconds, or microseconds with `break-js-comp`. When
    /// `uid` is true, a fresh message id is generated: the base-36 server time in µs followed
    /// by 22 random alphanumerics.
    ///
    /// # Errors
    /// * `RsmqError::QueueNotFound` if any of the three fields is missing.
    /// * `CannotParse*` if a field is not numeric.
    async fn get_queue(&self, conn: &mut T, qname: &str, uid: bool) -> RsmqResult<QueueDescriptor> {
        let (fields, (sec, usec)): (Vec<Option<String>>, (u64, u64)) = pipe()
            .atomic()
            .cmd("HMGET")
            .arg(format!("{}:{}:Q", self.ns, qname))
            .arg("vt")
            .arg("delay")
            .arg("maxsize")
            .cmd("TIME")
            .query_async(conn)
            .await?;

        let time_us = sec * 1_000_000 + usec;
        #[cfg(feature = "break-js-comp")]
        let ts = time_us;
        #[cfg(not(feature = "break-js-comp"))]
        let ts = sec * 1000 + usec / 1000;

        let (vt, delay, maxsize) = match fields.as_slice() {
            [Some(vt), Some(delay), Some(maxsize), ..] => (vt, delay, maxsize),
            _ => return Err(RsmqError::QueueNotFound),
        };

        let quid = if uid {
            Some(radix_36(time_us).to_string() + &Self::make_id(22)?)
        } else {
            None
        };

        Ok(QueueDescriptor {
            vt: Duration::from_millis(vt.parse().map_err(|_| RsmqError::CannotParseVT)?),
            delay: Duration::from_millis(
                delay.parse().map_err(|_| RsmqError::CannotParseDelay)?,
            ),
            maxsize: maxsize
                .parse()
                .map_err(|_| RsmqError::CannotParseMaxsize)?,
            ts,
            uid: quid,
        })
    }

    /// Builds a random string of `len` characters drawn from `[A-Za-z0-9]`.
    fn make_id(len: usize) -> RsmqResult<String> {
        const POSSIBLE: &[u8] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
        let mut rng = rand::rng();
        let mut id = String::with_capacity(len);
        for _ in 0..len {
            // `choose` only returns None for an empty range, which POSSIBLE never is.
            let idx = (0..POSSIBLE.len())
                .choose(&mut rng)
                .ok_or(RsmqError::BugCreatingRandomValue)?;
            id.push(char::from(POSSIBLE[idx]));
        }
        Ok(id)
    }
}
/// Checks that `value` lies in the closed interval `[min, max]`.
///
/// # Errors
/// Returns `RsmqError::InvalidValue(value, min, max)`, with each number rendered via
/// `Display`, when `value` is outside the interval.
fn number_in_range<T: std::cmp::PartialOrd + std::fmt::Display>(
    value: T,
    min: T,
    max: T,
) -> RsmqResult<()> {
    let in_bounds = value >= min && value <= max;
    if in_bounds {
        return Ok(());
    }
    Err(RsmqError::InvalidValue(
        value.to_string(),
        min.to_string(),
        max.to_string(),
    ))
}
/// Validates a queue name: 1 to 160 bytes, consisting only of ASCII letters, digits, `-`
/// and `_`.
///
/// # Errors
/// Returns `RsmqError::InvalidFormat(name)` when the name violates either rule.
fn valid_name_format(name: &str) -> RsmqResult<()> {
    let acceptable = (1..=160).contains(&name.len())
        && name
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_');
    if acceptable {
        Ok(())
    } else {
        Err(RsmqError::InvalidFormat(name.to_string()))
    }
}
/// Converts an optional duration into whole milliseconds (sub-millisecond parts truncated).
///
/// * `None` yields `default` in milliseconds, or 30_000 if `default` does not fit in `u64`.
/// * A supplied duration whose millisecond count does not fit in `u64` saturates to
///   `u64::MAX`. Callers then reject it through their range check instead of silently
///   falling back to `default`.
fn get_redis_duration(d: Option<Duration>, default: &Duration) -> u64 {
    match d {
        Some(duration) => u64::try_from(duration.as_millis()).unwrap_or(u64::MAX),
        None => u64::try_from(default.as_millis()).unwrap_or(30_000),
    }
}