use std::convert::AsRef;
use std::future::{ready, Ready};
use std::sync::Arc;
use std::time::Duration;
use actix_web::{dev::Payload, error::ErrorInternalServerError, FromRequest, HttpRequest};
use crate::dev::{Expiry, ExpiryStore, Store};
use crate::error::Result;
#[cfg(feature = "with-serde")]
use crate::format::{deserialize, serialize, Format};
/// Scope every `Storage` starts in when produced by `StorageBuilder::finish`.
///
/// A different scope can be selected afterwards with `Storage::scope`.
pub const GLOBAL_SCOPE: [u8; 20] = {
    const SCOPE_NAME: &[u8; 20] = b"STORAGE_GLOBAL_SCOPE";
    *SCOPE_NAME
};
/// Cheaply clonable handle to a key/value store, bound to a single scope.
///
/// Clones share the same underlying store; `Storage::scope` derives a handle
/// that addresses a different scope of that same store.
#[derive(Clone)]
pub struct Storage {
    /// Backend all operations are delegated to.
    store: Arc<dyn ExpiryStore>,
    /// Scope passed along with every key to the backend.
    scope: Arc<[u8]>,
    /// Serialization format used by the typed (serde) accessors.
    #[cfg(feature = "with-serde")]
    format: Format,
}
impl Storage {
pub fn build() -> StorageBuilder {
StorageBuilder::default()
}
pub fn scope(&self, scope: impl AsRef<[u8]>) -> Storage {
Storage {
scope: scope.as_ref().into(),
store: self.store.clone(),
#[cfg(feature = "with-serde")]
format: self.format,
}
}
#[cfg(feature = "with-serde")]
pub async fn set<V>(&self, key: impl AsRef<[u8]>, value: &V) -> Result<()>
where
V: serde::Serialize,
{
self.store
.set(
self.scope.clone(),
key.as_ref().into(),
serialize(value, &self.format)?.into(),
)
.await
}
#[cfg(feature = "with-serde")]
pub async fn set_expiring<V>(
&self,
key: impl AsRef<[u8]>,
value: &V,
expires_in: Duration,
) -> Result<()>
where
V: serde::Serialize,
{
self.store
.set_expiring(
self.scope.clone(),
key.as_ref().into(),
serialize(value, &self.format)?.into(),
expires_in,
)
.await
}
pub async fn set_bytes(&self, key: impl AsRef<[u8]>, value: impl AsRef<[u8]>) -> Result<()> {
self.store
.set(
self.scope.clone(),
key.as_ref().into(),
value.as_ref().into(),
)
.await
}
pub async fn set_expiring_bytes(
&self,
key: impl AsRef<[u8]>,
value: impl AsRef<[u8]>,
expires_in: Duration,
) -> Result<()> {
self.store
.set_expiring(
self.scope.clone(),
key.as_ref().into(),
value.as_ref().into(),
expires_in,
)
.await
}
#[cfg(feature = "with-serde")]
pub async fn get<K, V>(&self, key: K) -> Result<Option<V>>
where
K: AsRef<[u8]>,
V: serde::de::DeserializeOwned,
{
let val = self
.store
.get(self.scope.clone(), key.as_ref().into())
.await?;
val.map(|val| deserialize(val.as_ref(), &self.format))
.transpose()
}
#[cfg(feature = "with-serde")]
pub async fn get_expiring<K, V>(&self, key: K) -> Result<Option<(V, Option<Duration>)>>
where
K: AsRef<[u8]>,
V: serde::de::DeserializeOwned,
{
if let Some((val, expiry)) = self
.store
.get_expiring(self.scope.clone(), key.as_ref().into())
.await?
{
let val = deserialize(val.as_ref(), &self.format)?;
Ok(Some((val, expiry)))
} else {
Ok(None)
}
}
pub async fn get_bytes(&self, key: impl AsRef<[u8]>) -> Result<Option<Vec<u8>>> {
Ok(self
.store
.get(self.scope.clone(), key.as_ref().into())
.await?
.map(|val| {
let mut new_value = vec![];
new_value.extend_from_slice(val.as_ref());
new_value
}))
}
pub async fn get_expiring_bytes(
&self,
key: impl AsRef<[u8]>,
) -> Result<Option<(Vec<u8>, Option<Duration>)>> {
if let Some((val, expiry)) = self
.store
.get_expiring(self.scope.clone(), key.as_ref().into())
.await?
{
Ok(Some((val.as_ref().into(), expiry)))
} else {
Ok(None)
}
}
pub async fn get_bytes_ref(&self, key: impl AsRef<[u8]>) -> Result<Option<Arc<[u8]>>> {
self.store
.get(self.scope.clone(), key.as_ref().into())
.await
}
pub async fn get_expiring_bytes_ref(
&self,
key: impl AsRef<[u8]>,
) -> Result<Option<(Arc<[u8]>, Option<Duration>)>> {
if let Some((val, expiry)) = self
.store
.get_expiring(self.scope.clone(), key.as_ref().into())
.await?
{
Ok(Some((val, expiry)))
} else {
Ok(None)
}
}
pub async fn delete(&self, key: impl AsRef<[u8]>) -> Result<()> {
self.store
.delete(self.scope.clone(), key.as_ref().into())
.await
}
pub async fn contains_key(&self, key: impl AsRef<[u8]>) -> Result<bool> {
self.store
.contains_key(self.scope.clone(), key.as_ref().into())
.await
}
pub async fn expire(&self, key: impl AsRef<[u8]>, expire_in: Duration) -> Result<()> {
self.store
.expire(self.scope.clone(), key.as_ref().into(), expire_in)
.await
}
pub async fn expiry(&self, key: impl AsRef<[u8]>) -> Result<Option<Duration>> {
self.store
.expiry(self.scope.clone(), key.as_ref().into())
.await
}
pub async fn extend(&self, key: impl AsRef<[u8]>, expire_in: Duration) -> Result<()> {
self.store
.extend(self.scope.clone(), key.as_ref().into(), expire_in)
.await
}
pub async fn persist(&self, key: impl AsRef<[u8]>) -> Result<()> {
self.store
.persist(self.scope.clone(), key.as_ref().into())
.await
}
}
/// Builder for [`Storage`]; obtain one with `Storage::build()`.
///
/// Either an `expiry_store`, or at least a `store` (optionally paired with an
/// `expiry` provider), must be supplied before calling `finish`.
#[derive(Default)]
pub struct StorageBuilder {
    /// Combined store-with-expiry backend. When set, it takes priority over `store`/`expiry`.
    expiry_store: Option<Arc<dyn ExpiryStore>>,
    /// Plain key/value backend.
    store: Option<Arc<dyn Store>>,
    /// Optional expiry provider paired with `store`.
    expiry: Option<Arc<dyn Expiry>>,
    /// Serialization format for the typed accessors.
    #[cfg(feature = "with-serde")]
    format: Format,
}
impl StorageBuilder {
    /// Sets the plain key/value store.
    ///
    /// `finish` ignores this if an `expiry_store` has been set as well.
    #[must_use = "Builder must be used by calling finish"]
    pub fn store(self, store: impl Store + 'static) -> Self {
        Self {
            store: Some(Arc::new(store)),
            ..self
        }
    }

    /// Sets the expiry provider used together with `store`.
    ///
    /// Without one, expiry-related `Storage` operations fail with an error.
    #[must_use = "Builder must be used by calling finish"]
    pub fn expiry(self, expiry: impl Expiry + 'static) -> Self {
        Self {
            expiry: Some(Arc::new(expiry)),
            ..self
        }
    }

    /// Sets a backend that implements storage and expiry natively.
    ///
    /// When present, it takes priority over anything given to `store`/`expiry`.
    #[must_use = "Builder must be used by calling finish"]
    pub fn expiry_store<T>(self, expiry_store: T) -> Self
    where
        T: 'static + Store + Expiry + ExpiryStore,
    {
        Self {
            expiry_store: Some(Arc::new(expiry_store)),
            ..self
        }
    }

    /// Sets the serialization format used by the typed accessors.
    #[cfg(feature = "with-serde")]
    #[must_use = "Builder must be used by calling finish"]
    pub fn format(self, format: Format) -> Self {
        Self { format, ..self }
    }

    /// Consumes the builder and produces a [`Storage`] bound to [`GLOBAL_SCOPE`].
    ///
    /// # Panics
    /// Panics if neither `expiry_store` nor `store` was provided.
    pub fn finish(self) -> Storage {
        let backend: Arc<dyn ExpiryStore> = match (self.expiry_store, self.store) {
            (Some(native), _) => native,
            // Wrap the plain store (and optional expiry) so it satisfies ExpiryStore.
            (None, Some(plain)) => Arc::new(self::private::ExpiryStoreGlue(plain, self.expiry)),
            (None, None) => panic!("Storage builder needs at least a store"),
        };
        Storage {
            scope: Arc::new(GLOBAL_SCOPE),
            store: backend,
            #[cfg(feature = "with-serde")]
            format: self.format,
        }
    }
}
/// Allows `Storage` to be used directly as an actix-web handler argument.
///
/// The `Storage` must have been registered as application data (looked up with
/// `req.app_data::<Storage>()`); otherwise extraction fails with a 500 error.
impl FromRequest for Storage {
    type Error = actix_web::Error;
    type Future = Ready<std::result::Result<Self, actix_web::Error>>;

    #[inline]
    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
        if let Some(st) = req.app_data::<Storage>() {
            ready(Ok(st.clone()))
        } else {
            log::debug!(
                "Failed to construct Storage(actix-storage). \
                 Request path: {:?}",
                req.path(),
            );
            // A `\` line continuation strips the next line's leading whitespace,
            // so the separating space must come before the backslash.
            ready(Err(ErrorInternalServerError(
                "Storage is not configured, please refer to actix-storage documentation \
                 for more information.",
            )))
        }
    }
}
mod private {
    use std::sync::Arc;
    use std::time::Duration;
    use crate::{
        error::{Result, StorageError},
        provider::{Expiry, ExpiryStore, Store},
    };

    /// Adapter that pairs a plain [`Store`] with an optional [`Expiry`] provider so the
    /// pair can be used wherever an [`ExpiryStore`] is required.
    ///
    /// Every expiry-related operation fails with `StorageError::MethodNotSupported`
    /// when no expiry provider was supplied.
    pub(crate) struct ExpiryStoreGlue(pub Arc<dyn Store>, pub Option<Arc<dyn Expiry>>);

    impl ExpiryStoreGlue {
        /// Borrows the expiry provider, or fails with `MethodNotSupported` if there is none.
        fn expiry_provider(&self) -> Result<&Arc<dyn Expiry>> {
            self.1.as_ref().ok_or(StorageError::MethodNotSupported)
        }
    }

    #[async_trait::async_trait]
    impl Expiry for ExpiryStoreGlue {
        async fn expire(
            &self,
            scope: Arc<[u8]>,
            key: Arc<[u8]>,
            expire_in: Duration,
        ) -> Result<()> {
            self.expiry_provider()?
                .expire(scope, key, expire_in)
                .await
        }

        async fn expiry(&self, scope: Arc<[u8]>, key: Arc<[u8]>) -> Result<Option<Duration>> {
            self.expiry_provider()?.expiry(scope, key).await
        }

        async fn extend(
            &self,
            scope: Arc<[u8]>,
            key: Arc<[u8]>,
            expire_in: Duration,
        ) -> Result<()> {
            self.expiry_provider()?
                .extend(scope, key, expire_in)
                .await
        }

        async fn persist(&self, scope: Arc<[u8]>, key: Arc<[u8]>) -> Result<()> {
            self.expiry_provider()?.persist(scope, key).await
        }
    }

    #[async_trait::async_trait]
    impl Store for ExpiryStoreGlue {
        async fn set(&self, scope: Arc<[u8]>, key: Arc<[u8]>, value: Arc<[u8]>) -> Result<()> {
            self.0.set(scope, key.clone(), value).await?;
            // Notify the expiry provider (if any) that the key was (re)written.
            if let Some(expiry) = &self.1 {
                expiry.set_called(key).await;
            }
            Ok(())
        }

        async fn get(&self, scope: Arc<[u8]>, key: Arc<[u8]>) -> Result<Option<Arc<[u8]>>> {
            self.0.get(scope, key).await
        }

        async fn delete(&self, scope: Arc<[u8]>, key: Arc<[u8]>) -> Result<()> {
            self.0.delete(scope, key).await
        }

        async fn contains_key(&self, scope: Arc<[u8]>, key: Arc<[u8]>) -> Result<bool> {
            self.0.contains_key(scope, key).await
        }
    }

    #[async_trait::async_trait]
    impl ExpiryStore for ExpiryStoreGlue {
        /// Writes `value` and then sets its expiry.
        ///
        /// Fails before writing anything when no expiry provider is configured.
        async fn set_expiring(
            &self,
            scope: Arc<[u8]>,
            key: Arc<[u8]>,
            value: Arc<[u8]>,
            expire_in: Duration,
        ) -> Result<()> {
            let expiry = self.expiry_provider()?;
            self.0.set(scope.clone(), key.clone(), value).await?;
            expiry.expire(scope, key, expire_in).await
        }

        /// Reads the value and, only if it exists, asks the provider for its expiry.
        ///
        /// Fails before reading anything when no expiry provider is configured.
        async fn get_expiring(
            &self,
            scope: Arc<[u8]>,
            key: Arc<[u8]>,
        ) -> Result<Option<(Arc<[u8]>, Option<Duration>)>> {
            let expiry = self.expiry_provider()?;
            match self.0.get(scope.clone(), key.clone()).await? {
                Some(val) => {
                    let remaining = expiry.expiry(scope, key).await?;
                    Ok(Some((val, remaining)))
                }
                None => Ok(None),
            }
        }
    }
}
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;

    /// With only a plain store, expiry operations must fail while plain ones succeed.
    #[actix::test]
    async fn test_no_expiry() {
        struct OnlyStore;

        #[async_trait::async_trait]
        impl Store for OnlyStore {
            async fn set(&self, _: Arc<[u8]>, _: Arc<[u8]>, _: Arc<[u8]>) -> Result<()> {
                Ok(())
            }
            async fn get(&self, _: Arc<[u8]>, _: Arc<[u8]>) -> Result<Option<Arc<[u8]>>> {
                Ok(None)
            }
            async fn contains_key(&self, _: Arc<[u8]>, _: Arc<[u8]>) -> Result<bool> {
                Ok(false)
            }
            async fn delete(&self, _: Arc<[u8]>, _: Arc<[u8]>) -> Result<()> {
                Ok(())
            }
        }

        let storage = Storage::build().store(OnlyStore).finish();
        let k = "key";
        let v = "value".as_bytes();
        let d = Duration::from_secs(1);

        assert!(storage.expire(k, d).await.is_err());
        assert!(storage.expiry(k).await.is_err());
        assert!(storage.extend(k, d).await.is_err());
        assert!(storage.persist(k).await.is_err());
        assert!(storage.set_expiring_bytes(k, v, d).await.is_err());
        assert!(storage.get_expiring_bytes(k).await.is_err());

        assert!(storage.set_bytes(k, v).await.is_ok());
        assert!(storage.get_bytes(k).await.is_ok());
        assert!(storage.delete(k).await.is_ok());
        assert!(storage.contains_key(k).await.is_ok());
    }

    /// A store plus a separate expiry provider must behave like a full ExpiryStore.
    #[actix::test]
    async fn test_expiry_store_polyfill() {
        #[derive(Clone)]
        struct SampleStore;

        #[async_trait::async_trait]
        impl Store for SampleStore {
            async fn set(&self, _: Arc<[u8]>, _: Arc<[u8]>, _: Arc<[u8]>) -> Result<()> {
                Ok(())
            }
            async fn get(&self, _: Arc<[u8]>, _: Arc<[u8]>) -> Result<Option<Arc<[u8]>>> {
                Ok(Some("v".as_bytes().into()))
            }
            async fn contains_key(&self, _: Arc<[u8]>, _: Arc<[u8]>) -> Result<bool> {
                Ok(false)
            }
            async fn delete(&self, _: Arc<[u8]>, _: Arc<[u8]>) -> Result<()> {
                Ok(())
            }
        }

        #[async_trait::async_trait]
        impl Expiry for SampleStore {
            async fn expire(&self, _: Arc<[u8]>, _: Arc<[u8]>, _: Duration) -> Result<()> {
                Ok(())
            }
            async fn expiry(&self, _: Arc<[u8]>, _: Arc<[u8]>) -> Result<Option<Duration>> {
                Ok(Some(Duration::from_secs(1)))
            }
            async fn extend(&self, _: Arc<[u8]>, _: Arc<[u8]>, _: Duration) -> Result<()> {
                Ok(())
            }
            async fn persist(&self, _: Arc<[u8]>, _: Arc<[u8]>) -> Result<()> {
                Ok(())
            }
        }

        let k = "key";
        let v = "value".as_bytes();
        let d = Duration::from_secs(1);
        let store = SampleStore;
        let storage = Storage::build().store(store.clone()).expiry(store).finish();

        // Also exercises passing a `&str` value through `impl AsRef<[u8]>`.
        assert!(storage.set_expiring_bytes(k, "value", d).await.is_ok());
        assert!(storage.expire(k, d).await.is_ok());
        assert!(storage.expiry(k).await.is_ok());
        assert!(storage.extend(k, d).await.is_ok());
        assert!(storage.persist(k).await.is_ok());
        assert!(storage.set_expiring_bytes(k, v, d).await.is_ok());
        assert!(storage.get_expiring_bytes(k).await.is_ok());
        assert!(storage.set_bytes(k, v).await.is_ok());
        assert!(storage.get_bytes(k).await.is_ok());
        assert!(storage.delete(k).await.is_ok());
        assert!(storage.contains_key(k).await.is_ok());

        let res = storage.get_expiring_bytes(k).await;
        assert!(res.is_ok());
        let expected: Option<(Vec<u8>, Option<Duration>)> = Some((b"v".to_vec(), Some(d)));
        assert_eq!(res.unwrap(), expected);
    }

    #[test]
    #[should_panic(expected = "Storage builder needs at least a store")]
    fn test_no_store() {
        Storage::build().finish();
    }
}