// hyper_auth_proxy/redis_session.rs

1use std::io;
2use redis::{Client, AsyncCommands, IntoConnectionInfo, RedisError, RedisResult};
3use redis::aio::Connection;
4use serde::{Serialize, Deserialize};
5use thiserror::Error;
6
7#[derive(Clone, Debug)]
8pub struct RedisSessionStore {
9    client: Client,
10}
11
12#[derive(Serialize, Deserialize)]
13pub struct Session {
14    pub credentials: String,
15}
16
17#[derive(Error, Debug)]
18pub enum StoreError {
19    #[error("io error ({0})")]
20    Io(io::Error),
21    #[error("json deserialization error ({0})")]
22    Json(serde_json::Error),
23    #[error("redis error ({0})")]
24    Redis( #[from] RedisError)
25}
26
27impl From<serde_json::Error> for StoreError {
28    fn from(err: serde_json::Error) -> StoreError {
29        use serde_json::error::Category;
30        match err.classify() {
31            Category::Io => {
32                StoreError::Io(err.into())
33            }
34            Category::Syntax | Category::Data | Category::Eof => {
35                StoreError::Json(err)
36            }
37        }
38    }
39}
40
41impl RedisSessionStore {
42    pub(crate) async fn get(&self, sid: &str) -> Result<Option<Session>, StoreError> {
43        let mut connection = self.connection().await?;
44        let session_str: Option<String> = connection.get(sid).await?;
45        match session_str {
46            Some(json) => Ok(serde_json::from_str(&json)?),
47            None => Ok(None)
48        }
49    }
50    pub async fn set(&self, sid: &str, session: Session) -> Result<(), StoreError> {
51        let session_str = serde_json::to_string(&session)?;
52        let mut connection = self.connection().await?;
53        connection.set(sid, session_str).await?;
54        Ok(())
55    }
56    pub fn new(connection_info: impl IntoConnectionInfo) -> RedisResult<Self> {
57        Ok(Self {client: Client::open(connection_info)?})
58    }
59    async fn connection(&self) -> RedisResult<Connection> {
60        self.client.get_async_connection().await
61    }
62    pub async fn clear_store(&self, keys: &[&str]) -> Result<(), StoreError> {
63        let mut connection = self.connection().await?;
64        for key in keys {
65            connection.del(key).await?
66        }
67        Ok(())
68    }
69}
70
#[cfg(test)]
mod test {
    use super::*;

    // These tests need a reachable Redis at `redis://redis/1`.
    // Each test owns distinct keys: the test harness runs tests in parallel by
    // default, and with a shared key one test's cleanup could delete another
    // test's session between its `set` and `get`.

    #[tokio::test]
    async fn get_unknown_key() {
        let store = create_store(&["unknown"]).await;
        assert!(store.get("unknown").await.unwrap().is_none())
    }

    #[tokio::test]
    async fn get_session() {
        let store = create_store(&["get_session_sid"]).await;
        store
            .set("get_session_sid", Session { credentials: String::from("credentials") })
            .await
            .unwrap();

        let session = store.get("get_session_sid").await.unwrap().unwrap();

        assert_eq!(session.credentials, "credentials");
    }

    /// Builds a store on the test database and deletes `keys` so the calling
    /// test starts from a known state.
    async fn create_store(keys: &[&str]) -> RedisSessionStore {
        let store = RedisSessionStore::new("redis://redis/1").unwrap();
        store.clear_store(keys).await.unwrap();
        store
    }
}