use async_trait::async_trait;
use redis::AsyncCommands;
use serde_json::Value;
use uuid::Uuid;
use crate::runtime::invocation::{SessionDataError, SessionDataStore};
/// [`SessionDataStore`] implementation backed by Redis.
///
/// All data for one session is kept in a single Redis hash (see `hash_key`);
/// each stored entry is one field of that hash whose value is JSON-encoded text.
pub struct RedisSessionDataStore {
    /// Client from which a connection is requested for each store operation.
    client: redis::Client,
}

// Written by hand instead of derived: `redis::Client`'s own `Debug` output includes
// its connection info, which may carry credentials taken from the Redis URL, and
// this type is likely to end up in logs/traces. Fields are deliberately omitted.
impl std::fmt::Debug for RedisSessionDataStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RedisSessionDataStore").finish_non_exhaustive()
    }
}
impl RedisSessionDataStore {
    /// Creates a store from a Redis connection URL (e.g. `redis://host:6379/0`).
    ///
    /// Only the URL is validated here; connections are requested later, per operation.
    ///
    /// # Errors
    ///
    /// Returns [`SessionDataError::Storage`] when `redis::Client::open` rejects `url`.
    pub fn new(url: &str) -> Result<Self, SessionDataError> {
        redis::Client::open(url)
            .map(Self::from_client)
            .map_err(|e| SessionDataError::Storage {
                message: format!("redis client error: {e}"),
            })
    }

    /// Wraps an already-configured [`redis::Client`].
    #[must_use]
    pub fn from_client(client: redis::Client) -> Self {
        Self { client }
    }

    /// Obtains a multiplexed async connection for a single store operation.
    ///
    /// Connection failures are reported as [`SessionDataError::Storage`].
    ///
    /// NOTE(review): a new connection is requested from the client on every call.
    /// If connection setup shows up in profiles, consider caching one
    /// (`MultiplexedConnection` is cloneable). Before doing so, confirm how the
    /// redis version in use behaves after a dropped connection.
    async fn conn(&self) -> Result<redis::aio::MultiplexedConnection, SessionDataError> {
        let connected = self.client.get_multiplexed_async_connection().await;
        connected.map_err(|e| SessionDataError::Storage {
            message: format!("redis connection error: {e}"),
        })
    }

    /// Returns the Redis key of the hash holding every entry for `session_id`.
    fn hash_key(session_id: Uuid) -> String {
        format!("behest:session_data:{session_id}")
    }
}
#[async_trait]
impl SessionDataStore for RedisSessionDataStore {
async fn set(
&self,
session_id: Uuid,
key: String,
value: Value,
) -> Result<(), SessionDataError> {
let mut conn = self.conn().await?;
let hash = Self::hash_key(session_id);
let json = serde_json::to_string(&value).map_err(|e| SessionDataError::Storage {
message: format!("serialization error: {e}"),
})?;
conn.hset::<_, _, _, ()>(&hash, &key, &json)
.await
.map_err(|e| SessionDataError::Storage {
message: format!("redis HSET error: {e}"),
})
}
async fn get(&self, session_id: Uuid, key: &str) -> Result<Option<Value>, SessionDataError> {
let mut conn = self.conn().await?;
let hash = Self::hash_key(session_id);
let raw: Option<String> =
conn.hget(&hash, key)
.await
.map_err(|e| SessionDataError::Storage {
message: format!("redis HGET error: {e}"),
})?;
match raw {
Some(s) => {
let val: Value =
serde_json::from_str(&s).map_err(|e| SessionDataError::Storage {
message: format!("deserialization error: {e}"),
})?;
Ok(Some(val))
}
None => Ok(None),
}
}
async fn delete(&self, session_id: Uuid, key: &str) -> Result<(), SessionDataError> {
let mut conn = self.conn().await?;
let hash = Self::hash_key(session_id);
conn.hdel::<_, _, ()>(&hash, key)
.await
.map_err(|e| SessionDataError::Storage {
message: format!("redis HDEL error: {e}"),
})
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap on a hard-coded, known-valid UUID literal
mod tests {
    use super::*;

    /// The hash key is the `behest:session_data:` prefix followed by the
    /// hyphenated session UUID.
    #[test]
    fn hash_key_format() {
        const RAW_ID: &str = "550e8400-e29b-41d4-a716-446655440000";
        let session_id: Uuid = RAW_ID.parse().unwrap();
        let expected = format!("behest:session_data:{RAW_ID}");
        assert_eq!(RedisSessionDataStore::hash_key(session_id), expected);
    }
}