use crate::sessions::backends::{SessionBackend, SessionError};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[cfg(feature = "compression-zstd")]
mod zstd;
#[cfg(feature = "compression-zstd")]
pub use self::zstd::ZstdCompressor;
#[cfg(feature = "compression-gzip")]
mod gzip;
#[cfg(feature = "compression-gzip")]
pub use self::gzip::GzipCompressor;
#[cfg(feature = "compression-brotli")]
mod brotli;
#[cfg(feature = "compression-brotli")]
pub use self::brotli::BrotliCompressor;
/// Failure reported by a [`Compressor`] while encoding or decoding a payload.
///
/// The `String` carried by each variant is appended to the variant's display message.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum CompressionError {
    /// Encoding the input bytes did not succeed.
    #[error("Compression failed: {0}")]
    CompressionFailed(String),

    /// Decoding previously compressed bytes did not succeed.
    #[error("Decompression failed: {0}")]
    DecompressionFailed(String),
}
/// A pluggable compression algorithm that can be shared across threads and cloned.
///
/// Implementations must round-trip: feeding the output of [`compress`](Self::compress)
/// into [`decompress`](Self::decompress) has to reproduce the original bytes, because
/// session data written compressed is read back through `decompress` alone.
pub trait Compressor: Clone + Send + Sync {
    /// Returns a static name identifying this compressor.
    fn name(&self) -> &'static str;

    /// Encodes `data`, returning the compressed bytes.
    fn compress(&self, data: &[u8]) -> Result<Vec<u8>, CompressionError>;

    /// Decodes bytes previously produced by [`compress`](Self::compress).
    fn decompress(&self, compressed: &[u8]) -> Result<Vec<u8>, CompressionError>;
}
/// Envelope written to the inner backend in place of the raw session value.
///
/// `payload` holds the `serde_json` encoding of the session data, run through the
/// configured compressor when `is_compressed` is `true`. Field names are part of the
/// persisted format (serde derive uses them as keys), so renaming them would make
/// previously stored sessions unreadable.
// NOTE(review): serde serializes `Vec<u8>` as a sequence of integers, so if the inner
// backend stores JSON each byte may take several characters — confirm the inner
// backend's encoding before relying on compression to reduce stored size.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct CompressedData {
/// Serialized session bytes; compressed only when `is_compressed` is set.
payload: Vec<u8>,
/// Whether `payload` must be decompressed before JSON deserialization.
is_compressed: bool,
}
/// A [`SessionBackend`] decorator that persists sessions through an inner backend,
/// passing serialized payloads larger than a configurable threshold through a
/// compressor.
///
/// `Debug` is derived so the wrapper can be logged whenever both the inner backend and
/// the compressor are themselves `Debug`.
#[derive(Clone, Debug)]
pub struct CompressedSessionBackend<B, C> {
    /// Backend that actually stores the (possibly compressed) envelopes.
    backend: B,
    /// Algorithm used to compress and decompress payloads.
    compressor: C,
    /// Serialized size, in bytes, a payload must exceed before compression is attempted.
    threshold_bytes: usize,
}
impl<B, C> CompressedSessionBackend<B, C>
where
    B: SessionBackend,
    C: Compressor,
{
    /// Threshold, in bytes of serialized session data, used by [`new`](Self::new).
    pub const DEFAULT_THRESHOLD_BYTES: usize = 512;

    /// Wraps `backend` using `compressor` and the default threshold of
    /// [`DEFAULT_THRESHOLD_BYTES`](Self::DEFAULT_THRESHOLD_BYTES).
    pub fn new(backend: B, compressor: C) -> Self {
        Self::with_threshold(backend, compressor, Self::DEFAULT_THRESHOLD_BYTES)
    }

    /// Wraps `backend` using `compressor`; serialized payloads whose length is strictly
    /// greater than `threshold_bytes` are passed through the compressor on save.
    pub fn with_threshold(backend: B, compressor: C, threshold_bytes: usize) -> Self {
        Self {
            backend,
            compressor,
            threshold_bytes,
        }
    }

    /// Returns the configured compression threshold in bytes.
    pub fn threshold(&self) -> usize {
        self.threshold_bytes
    }
}
/// Converts any displayable failure (compression or JSON) into the session error type
/// used by this backend.
fn serialization_error<E: std::fmt::Display>(err: E) -> SessionError {
    SessionError::SerializationError(err.to_string())
}

#[async_trait]
impl<B, C> SessionBackend for CompressedSessionBackend<B, C>
where
    B: SessionBackend,
    C: Compressor,
{
    /// Loads the envelope stored under `session_key` from the inner backend.
    ///
    /// If the envelope is marked compressed, its payload is decompressed first. The
    /// payload bytes are then deserialized from JSON into `T`.
    ///
    /// # Errors
    ///
    /// Propagates inner-backend errors unchanged. Decompression and JSON failures are
    /// reported as `SessionError::SerializationError`.
    async fn load<T>(&self, session_key: &str) -> Result<Option<T>, SessionError>
    where
        T: for<'de> Deserialize<'de> + Send,
    {
        let stored: Option<CompressedData> = self.backend.load(session_key).await?;
        let envelope = match stored {
            Some(envelope) => envelope,
            None => return Ok(None),
        };

        let payload = if envelope.is_compressed {
            self.compressor
                .decompress(&envelope.payload)
                .map_err(serialization_error)?
        } else {
            envelope.payload
        };

        let data: T = serde_json::from_slice(&payload).map_err(serialization_error)?;
        Ok(Some(data))
    }

    /// Serializes `data` to JSON, compresses it when it exceeds the threshold and doing
    /// so actually shrinks it, then stores the envelope through the inner backend.
    ///
    /// # Errors
    ///
    /// JSON and compression failures are reported as `SessionError::SerializationError`.
    /// Inner-backend errors are propagated unchanged.
    async fn save<T>(
        &self,
        session_key: &str,
        data: &T,
        ttl: Option<u64>,
    ) -> Result<(), SessionError>
    where
        T: Serialize + Send + Sync,
    {
        let serialized = serde_json::to_vec(data).map_err(serialization_error)?;

        let envelope = if serialized.len() > self.threshold_bytes {
            let compressed = self
                .compressor
                .compress(&serialized)
                .map_err(serialization_error)?;
            // Incompressible input (random tokens, already-compressed blobs, or payloads
            // barely above the threshold) can come out larger than it went in because of
            // format overhead. Keep whichever form is smaller; `load` reads both.
            if compressed.len() < serialized.len() {
                CompressedData {
                    payload: compressed,
                    is_compressed: true,
                }
            } else {
                CompressedData {
                    payload: serialized,
                    is_compressed: false,
                }
            }
        } else {
            CompressedData {
                payload: serialized,
                is_compressed: false,
            }
        };

        self.backend.save(session_key, &envelope, ttl).await
    }

    /// Deletes the session directly in the inner backend (keys are not transformed).
    async fn delete(&self, session_key: &str) -> Result<(), SessionError> {
        self.backend.delete(session_key).await
    }

    /// Checks for the session directly in the inner backend (keys are not transformed).
    async fn exists(&self, session_key: &str) -> Result<bool, SessionError> {
        self.backend.exists(session_key).await
    }
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "compression-zstd")]
    use super::{CompressedSessionBackend, ZstdCompressor};
    #[cfg(feature = "compression-zstd")]
    use crate::sessions::{InMemorySessionBackend, SessionBackend};

    /// Saves `expected` under `"test_key"`, reads it back, and asserts it is unchanged.
    #[cfg(feature = "compression-zstd")]
    async fn assert_round_trip<S>(store: &S, expected: serde_json::Value)
    where
        S: SessionBackend,
    {
        store.save("test_key", &expected, None).await.unwrap();
        let actual: Option<serde_json::Value> = store.load("test_key").await.unwrap();
        assert_eq!(actual.unwrap(), expected);
    }

    #[cfg(feature = "compression-zstd")]
    #[tokio::test]
    async fn test_compressed_backend_above_threshold() {
        // A 10-byte threshold means this payload exceeds it and goes through the compressor.
        let store = CompressedSessionBackend::with_threshold(
            InMemorySessionBackend::new(),
            ZstdCompressor::new(),
            10,
        );
        let expected = serde_json::json!({
            "key": "value_with_many_characters_to_exceed_threshold",
        });
        assert_round_trip(&store, expected).await;
    }

    #[cfg(feature = "compression-zstd")]
    #[tokio::test]
    async fn test_compressed_backend_below_threshold() {
        // A 1000-byte threshold keeps this small payload on the uncompressed path.
        let store = CompressedSessionBackend::with_threshold(
            InMemorySessionBackend::new(),
            ZstdCompressor::new(),
            1000,
        );
        assert_round_trip(&store, serde_json::json!({"key": "value"})).await;
    }

    #[cfg(feature = "compression-zstd")]
    #[tokio::test]
    async fn test_compressed_backend_delete() {
        let store =
            CompressedSessionBackend::new(InMemorySessionBackend::new(), ZstdCompressor::new());
        let value = serde_json::json!({"key": "value"});
        store.save("test_key", &value, None).await.unwrap();

        assert!(store.exists("test_key").await.unwrap());
        store.delete("test_key").await.unwrap();
        assert!(!store.exists("test_key").await.unwrap());
    }
}