extern crate redis;
use self::redis::aio::ConnectionManager;
use self::redis::ConnectionInfo;
use crate::counter::Counter;
use crate::limit::Limit;
use crate::storage::keys::*;
use crate::storage::redis::is_limited;
use crate::storage::redis::scripts::{SCRIPT_UPDATE_COUNTER, VALUES_AND_TTLS};
use crate::storage::{AsyncCounterStorage, Authorization, StorageErr};
use async_trait::async_trait;
use redis::{AsyncCommands, ErrorKind, RedisError};
use std::collections::HashSet;
use std::ops::Deref;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{info_span, Instrument};
/// Asynchronous [`AsyncCounterStorage`] implementation that keeps its counters in Redis.
///
/// Every operation talks to Redis through a [`ConnectionManager`]. Deriving `Clone` means a
/// cloned storage shares the same underlying manager handle.
#[derive(Clone)]
pub struct AsyncRedisStorage {
    /// Manager handle; each operation clones it to obtain its own mutable connection handle.
    conn_manager: ConnectionManager,
}
#[async_trait]
impl AsyncCounterStorage for AsyncRedisStorage {
    /// Reports whether adding `delta` hits to `counter` would keep it within its limit.
    ///
    /// A counter with no key in Redis is treated as having zero hits. Nothing is written.
    #[tracing::instrument(skip_all)]
    async fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result<bool, StorageErr> {
        let mut con = self.conn_manager.clone();
        let stored = con
            .get::<Vec<u8>, Option<i64>>(key_for_counter(counter))
            .instrument(info_span!("datastore"))
            .await?;
        // A missing key and a negative stored value both count as zero hits.
        let hits = stored.map_or(0, |val| u64::try_from(val).unwrap_or(0));
        // An overflowing sum is necessarily above any `u64` limit. A plain `+` would panic in
        // debug builds and wrap around (reporting "within limits") in release builds.
        Ok(matches!(
            hits.checked_add(delta),
            Some(total) if total <= counter.max_value()
        ))
    }

    /// Applies `delta` to `counter` through the `SCRIPT_UPDATE_COUNTER` Lua script, without
    /// checking the limit first.
    ///
    /// The script receives:
    /// - keys: the counter key and the key of the set of counters for the counter's limit;
    /// - args: the window length in seconds and `delta`.
    #[tracing::instrument(skip_all)]
    async fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> {
        let mut con = self.conn_manager.clone();
        redis::Script::new(SCRIPT_UPDATE_COUNTER)
            .key(key_for_counter(counter))
            .key(key_for_counters_of_limit(counter.limit()))
            .arg(counter.window().as_secs())
            .arg(delta)
            .invoke_async::<()>(&mut con)
            .instrument(info_span!("datastore"))
            .await?;
        Ok(())
    }

    /// Checks `counters` against their limits and, if none is limited, applies `delta` to all
    /// of them in a single pipeline.
    ///
    /// How the current values are read depends on `load_counters`:
    /// - `true`: values and TTLs are fetched with the `VALUES_AND_TTLS` script, and the
    ///   decision is delegated to `is_limited`, which receives `counters` mutably.
    /// - `false`: only the values are fetched with `MGET`. The first counter whose hits plus
    ///   `delta` exceed its maximum yields [`Authorization::Limited`] with that limit's name.
    ///
    /// In the limited case nothing is written. Otherwise every counter is updated and
    /// [`Authorization::Ok`] is returned. An empty `counters` is trivially `Ok`.
    ///
    /// The check and the update are separate round trips, so concurrent callers can both pass
    /// the check and together push a counter past its limit.
    // NOTE(review): `'a` is unused in this body; presumably it is kept to match the trait
    // declaration — confirm before removing.
    #[tracing::instrument(skip_all)]
    async fn check_and_update<'a>(
        &self,
        counters: &mut Vec<Counter>,
        delta: u64,
        load_counters: bool,
    ) -> Result<Authorization, StorageErr> {
        // Nothing to check or update. Returning early also avoids an argument-less `MGET`,
        // which Redis rejects.
        if counters.is_empty() {
            return Ok(Authorization::Ok);
        }

        let mut con = self.conn_manager.clone();
        let counter_keys: Vec<Vec<u8>> = counters.iter().map(key_for_counter).collect();

        if load_counters {
            let script = redis::Script::new(VALUES_AND_TTLS);
            let mut script_invocation = script.prepare_invoke();
            for counter_key in &counter_keys {
                script_invocation.key(counter_key);
            }
            let script_res: Vec<Option<i64>> = script_invocation
                .invoke_async(&mut con)
                .instrument(info_span!("datastore"))
                .await?;
            if let Some(res) = is_limited(counters, delta, script_res) {
                return Ok(res);
            }
        } else {
            let counter_vals: Vec<Option<i64>> = redis::cmd("MGET")
                .arg(&counter_keys)
                .query_async(&mut con)
                .instrument(info_span!("datastore"))
                .await?;
            for (counter, val) in counters.iter().zip(counter_vals) {
                let hits = val.map_or(0, |v| u64::try_from(v).unwrap_or(0));
                // Overflow of `hits + delta` means the limit is exceeded.
                let within_limit = matches!(
                    hits.checked_add(delta),
                    Some(total) if total <= counter.max_value()
                );
                if !within_limit {
                    return Ok(Authorization::Limited(
                        counter.limit().name().map(|n| n.to_owned()),
                    ));
                }
            }
        }

        let script = redis::Script::new(SCRIPT_UPDATE_COUNTER);
        let mut pipeline = redis::pipe();
        for (counter, key) in counters.iter().zip(&counter_keys) {
            pipeline
                .invoke_script(
                    script
                        .key(key)
                        .key(key_for_counters_of_limit(counter.limit()))
                        .arg(counter.window().as_secs())
                        .arg(delta),
                )
                .ignore();
        }

        match pipeline
            .query_async::<()>(&mut con)
            .instrument(info_span!("datastore"))
            .await
        {
            Ok(()) => {}
            // The pipelined script calls fail with NoScriptError when the script is missing
            // from the server's cache. Load it, then replay the whole pipeline once.
            Err(err) if err.kind() == ErrorKind::NoScriptError => {
                script.prepare_invoke().load_async(&mut con).await?;
                pipeline
                    .query_async::<()>(&mut con)
                    .instrument(info_span!("datastore"))
                    .await?;
            }
            Err(err) => return Err(err.into()),
        }
        Ok(Authorization::Ok)
    }

    /// Returns the counters currently stored for each limit in `limits`.
    ///
    /// For every member of a limit's counter set, this reads the counter value and TTL.
    /// Members whose value key no longer exists are skipped.
    #[tracing::instrument(skip_all)]
    async fn get_counters(
        &self,
        limits: &HashSet<Arc<Limit>>,
    ) -> Result<HashSet<Counter>, StorageErr> {
        let mut res = HashSet::new();
        let mut con = self.conn_manager.clone();
        for limit in limits {
            let counter_keys = con
                .smembers::<Vec<u8>, HashSet<Vec<u8>>>(key_for_counters_of_limit(limit))
                .instrument(info_span!("datastore"))
                .await?;
            for counter_key in counter_keys {
                let mut counter: Counter =
                    counter_from_counter_key(&counter_key, Arc::clone(limit));
                let option = con
                    .get::<_, Option<i64>>(&counter_key)
                    .instrument(info_span!("datastore"))
                    .await?;
                if let Some(val) = option {
                    let hits = u64::try_from(val).unwrap_or(0);
                    // The stored hits can exceed the maximum: `update_counter` never checks
                    // the limit, and `check_and_update` checks and writes in separate round
                    // trips. Clamp to zero instead of underflowing, which would panic in
                    // debug builds and wrap to a huge value in release builds.
                    counter.set_remaining(limit.max_value().saturating_sub(hits));
                    let ttl: i64 = con
                        .ttl(&counter_key)
                        .instrument(info_span!("datastore"))
                        .await?;
                    // Negative TTL replies (no expiry / key gone) map to zero.
                    counter.set_expires_in(Duration::from_secs(u64::try_from(ttl).unwrap_or(0)));
                    res.insert(counter);
                }
            }
        }
        Ok(res)
    }

    /// Deletes all stored counters of every limit in `limits`, together with each limit's
    /// counter set.
    #[tracing::instrument(skip_all)]
    async fn delete_counters(&self, limits: &HashSet<Arc<Limit>>) -> Result<(), StorageErr> {
        for limit in limits {
            self.delete_counters_associated_with_limit(limit.deref())
                .instrument(info_span!("datastore"))
                .await?
        }
        Ok(())
    }

    /// Removes every key in the currently selected Redis database (`FLUSHDB`), not only the
    /// keys written by this storage.
    #[tracing::instrument(skip_all)]
    async fn clear(&self) -> Result<(), StorageErr> {
        let mut con = self.conn_manager.clone();
        redis::cmd("FLUSHDB")
            .query_async::<()>(&mut con)
            .instrument(info_span!("datastore"))
            .await?;
        Ok(())
    }
}
impl AsyncRedisStorage {
    /// Connects to the Redis server at `redis_url` and preloads the Lua scripts used by this
    /// storage.
    ///
    /// # Errors
    ///
    /// Returns a [`RedisError`] if any of these steps fail:
    /// - parsing the URL (e.g. an unsupported scheme yields `InvalidClientConfig`);
    /// - creating the client;
    /// - establishing the connection;
    /// - loading the scripts.
    pub async fn new(redis_url: &str) -> Result<Self, RedisError> {
        let info = ConnectionInfo::from_str(redis_url)?;
        // Propagate instead of panicking: this function already reports errors via `Result`.
        let client = redis::Client::open(info)?;
        let conn_manager = ConnectionManager::new(client).await?;
        Self::new_with_conn_manager(conn_manager).await
    }

    /// Builds the storage on top of an existing [`ConnectionManager`] and preloads the
    /// `SCRIPT_UPDATE_COUNTER` and `VALUES_AND_TTLS` scripts.
    ///
    /// # Errors
    ///
    /// Returns a [`RedisError`] if either script cannot be loaded.
    pub async fn new_with_conn_manager(
        conn_manager: ConnectionManager,
    ) -> Result<Self, RedisError> {
        let store = Self { conn_manager };
        store.load_script(SCRIPT_UPDATE_COUNTER).await?;
        store.load_script(VALUES_AND_TTLS).await?;
        Ok(store)
    }

    /// Deletes every counter key tracked for `limit`, plus the set that tracks them.
    ///
    /// All keys are removed with a single multi-key `DEL` rather than one round trip per
    /// key. The set key is always part of the argument list, so `DEL` never receives zero
    /// keys (which Redis would reject).
    async fn delete_counters_associated_with_limit(&self, limit: &Limit) -> Result<(), StorageErr> {
        let mut con = self.conn_manager.clone();
        let set_key = key_for_counters_of_limit(limit);
        let counter_keys = con
            .smembers::<_, HashSet<Vec<u8>>>(&set_key)
            .instrument(info_span!("datastore"))
            .await?;
        let mut keys: Vec<Vec<u8>> = counter_keys.into_iter().collect();
        keys.push(set_key);
        con.del::<_, ()>(keys)
            .instrument(info_span!("datastore"))
            .await?;
        Ok(())
    }

    /// Uploads the Lua `script` to the server's script cache via redis-rs `load_async`.
    pub(super) async fn load_script(&self, script: &str) -> Result<(), RedisError> {
        let mut con = self.conn_manager.clone();
        let script = redis::Script::new(script);
        script.prepare_invoke().load_async(&mut con).await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::storage::redis::AsyncRedisStorage;
    use redis::ErrorKind;

    /// A URL with a non-Redis scheme is rejected as an invalid client configuration.
    #[tokio::test]
    async fn errs_on_bad_url() {
        let outcome = AsyncRedisStorage::new("cassandra://127.0.0.1:6379").await;
        assert!(outcome.is_err());
        let kind = outcome.err().unwrap().kind();
        assert_eq!(kind, ErrorKind::InvalidClientConfig);
    }

    /// Connecting to a port that (presumably) has no Redis listener surfaces a refused I/O
    /// error.
    #[tokio::test]
    async fn errs_on_connection_issue() {
        let outcome = AsyncRedisStorage::new("redis://127.0.0.1:21").await;
        assert!(outcome.is_err());
        let failure = outcome.err().unwrap();
        assert_eq!(failure.kind(), ErrorKind::IoError);
        assert!(failure.is_connection_refusal())
    }
}