actix_session_sqlx_postgres/
lib.rs

1use crate::ConnectionData::{ConnectionPool, ConnectionString};
2use actix_session::storage::{LoadError, SaveError, SessionKey, SessionStore, UpdateError};
3use chrono::Utc;
4use rand::{distributions::Alphanumeric, rngs::OsRng, Rng as _};
5use serde_json::{self, Value};
6use sqlx::postgres::PgPoolOptions;
7use sqlx::{Pool, Postgres, Row};
8use std::collections::HashMap;
9use std::sync::Arc;
10use time::Duration;
11
12/// Use Postgres via Sqlx as session storage backend.
13///
14/// ```no_run
15/// use actix_web::{web, App, HttpServer, HttpResponse, Error};
16/// use actix_session_sqlx_postgres::SqlxPostgresqlSessionStore;
17/// use actix_session::SessionMiddleware;
18/// use actix_web::cookie::Key;
19///
20/// // The secret key would usually be read from a configuration file/environment variables.
21/// fn get_secret_key() -> Key {
22///     # todo!()
23///     // [...]
24/// }
25///
26/// #[actix_web::main]
27/// async fn main() -> std::io::Result<()> {
28///     let secret_key = get_secret_key();
29///     let psql_connection_string = "postgres://<username>:<password>@127.0.0.1:5432/<yourdatabase>";
30///     let store = SqlxPostgresqlSessionStore::new(psql_connection_string).await.unwrap();
31///
32///     HttpServer::new(move ||
33///             App::new()
34///             .wrap(SessionMiddleware::new(
35///                 store.clone(),
36///                 secret_key.clone()
37///             ))
38///             .default_service(web::to(|| HttpResponse::Ok())))
39///         .bind(("127.0.0.1", 8080))?
40///         .run()
41///         .await
42/// }
43/// ```
44/// If you already have a connection pool, you can use something like
45/*/// ```no_run
46/// use actix_web::{web, App, HttpServer, HttpResponse, Error};
47/// use actix_session_sqlx_postgres::SqlxPostgresqlSessionStore;
48/// use actix_session::SessionMiddleware;
49/// use actix_web::cookie::Key;
50///
51/// // The secret key would usually be read from a configuration file/environment variables.
52/// fn get_secret_key() -> Key {
53///     # todo!()
54///     // [...]
55/// }
56/// #[actix_web::main]
57/// async fn main() -> std::io::Result<()> {
58///     use sqlx::postgres::PgPoolOptions;
59/// let secret_key = get_secret_key();
60///     let pool = PgPoolOptions::find_some_way_to_build_your_pool(psql_connection_string);
61///     let store = SqlxPostgresqlSessionStore::from_pool(pool).await.expect("Could not build session store");
62///
63///     HttpServer::new(move ||
64///             App::new()
65///             .wrap(SessionMiddleware::new(
66///                 store.clone(),
67///                 secret_key.clone()
68///             ))
69///             .default_service(web::to(|| HttpResponse::Ok())))
70///         .bind(("127.0.0.1", 8080))?
71///         .run()
72///         .await
73/// }
74/// ```
75*/
76#[derive(Clone)]
77struct CacheConfiguration {
78    cache_keygen: Arc<dyn Fn(&str) -> String + Send + Sync>,
79}
80
81impl Default for CacheConfiguration {
82    fn default() -> Self {
83        Self {
84            cache_keygen: Arc::new(str::to_owned),
85        }
86    }
87}
88
89#[derive(Clone)]
90pub struct SqlxPostgresqlSessionStore {
91    client_pool: Arc<Pool<Postgres>>,
92    configuration: CacheConfiguration,
93}
94
95fn generate_session_key() -> SessionKey {
96    let value = std::iter::repeat(())
97        .map(|()| OsRng.sample(Alphanumeric))
98        .take(64)
99        .collect::<Vec<_>>();
100
101    // These unwraps will never panic because pre-conditions are always verified
102    // (i.e. length and character set)
103    String::from_utf8(value).unwrap().try_into().unwrap()
104}
105
106impl SqlxPostgresqlSessionStore {
107    pub fn builder<S: Into<String>>(connection_string: S) -> SqlxPostgresqlSessionStoreBuilder {
108        SqlxPostgresqlSessionStoreBuilder {
109            connection_data: ConnectionString(connection_string.into()),
110            configuration: CacheConfiguration::default(),
111        }
112    }
113
114    pub async fn new<S: Into<String>>(
115        connection_string: S,
116    ) -> Result<SqlxPostgresqlSessionStore, anyhow::Error> {
117        Self::builder(connection_string).build().await
118    }
119
120    pub fn from_pool(pool: Arc<Pool<Postgres>>) -> SqlxPostgresqlSessionStore {
121        SqlxPostgresqlSessionStore {
122            client_pool: pool,
123            configuration: CacheConfiguration::default(),
124        }
125    }
126}
127
128pub enum ConnectionData {
129    ConnectionString(String),
130    ConnectionPool(Arc<Pool<Postgres>>),
131}
132
133#[must_use]
134pub struct SqlxPostgresqlSessionStoreBuilder {
135    connection_data: ConnectionData,
136    configuration: CacheConfiguration,
137}
138
139impl SqlxPostgresqlSessionStoreBuilder {
140    pub async fn build(self) -> Result<SqlxPostgresqlSessionStore, anyhow::Error> {
141        match self.connection_data {
142            ConnectionString(conn_string) => PgPoolOptions::new()
143                .max_connections(1)
144                .connect(conn_string.as_str())
145                .await
146                .map_err(Into::into)
147                .map(|pool| SqlxPostgresqlSessionStore {
148                    client_pool: Arc::new(pool),
149                    configuration: self.configuration,
150                }),
151            ConnectionPool(pool) => Ok(SqlxPostgresqlSessionStore {
152                client_pool: pool,
153                configuration: self.configuration,
154            }),
155        }
156    }
157}
/// Session payload as handled by this store: a string-to-string map that is
/// serialised to JSON before being written to the database.
pub(crate) type SessionState = HashMap<String, String>;
159
160impl SessionStore for SqlxPostgresqlSessionStore {
161    async fn load(&self, session_key: &SessionKey) -> Result<Option<SessionState>, LoadError> {
162        let key = (self.configuration.cache_keygen)(session_key.as_ref());
163        let row =
164            sqlx::query("SELECT session_state FROM sessions WHERE key = $1 AND expires > NOW()")
165                .bind(key)
166                .fetch_optional(self.client_pool.as_ref())
167                .await
168                .map_err(Into::into)
169                .map_err(LoadError::Other)?;
170        match row {
171            None => Ok(None),
172            Some(r) => {
173                let data: Value = r.get("session_state");
174                let state: SessionState = serde_json::from_value(data)
175                    .map_err(Into::into)
176                    .map_err(LoadError::Deserialization)?;
177                Ok(Some(state))
178            }
179        }
180    }
181
182    async fn save(
183        &self,
184        session_state: SessionState,
185        ttl: &Duration,
186    ) -> Result<SessionKey, SaveError> {
187        let body = serde_json::to_value(&session_state)
188            .map_err(Into::into)
189            .map_err(SaveError::Serialization)?;
190        let key = generate_session_key();
191        let cache_key = (self.configuration.cache_keygen)(key.as_ref());
192        let expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds());
193        sqlx::query("INSERT INTO sessions(key, session_state, expires) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING")
194            .bind(cache_key)
195            .bind(body)
196            .bind(expires)
197            .execute(self.client_pool.as_ref())
198            .await
199            .map_err(Into::into)
200            .map_err(SaveError::Other)?;
201        Ok(key)
202    }
203
204    async fn update(
205        &self,
206        session_key: SessionKey,
207        session_state: SessionState,
208        ttl: &Duration,
209    ) -> Result<SessionKey, UpdateError> {
210        let body = serde_json::to_value(&session_state)
211            .map_err(Into::into)
212            .map_err(UpdateError::Serialization)?;
213        let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
214        let new_expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds());
215        sqlx::query("UPDATE sessions SET session_state = $1, expires = $2 WHERE key = $3")
216            .bind(body)
217            .bind(new_expires)
218            .bind(cache_key)
219            .execute(self.client_pool.as_ref())
220            .await
221            .map_err(Into::into)
222            .map_err(UpdateError::Other)?;
223        Ok(session_key)
224    }
225
226    async fn update_ttl(
227        &self,
228        session_key: &SessionKey,
229        ttl: &Duration,
230    ) -> Result<(), anyhow::Error> {
231        let new_expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds());
232        let key = (self.configuration.cache_keygen)(session_key.as_ref());
233        sqlx::query("UPDATE sessions SET expires = $1 WHERE key = $2")
234            .bind(new_expires)
235            .bind(key)
236            .execute(self.client_pool.as_ref())
237            .await
238            .map_err(Into::into)
239            .map_err(UpdateError::Other)?;
240        Ok(())
241    }
242
243    async fn delete(&self, session_key: &SessionKey) -> Result<(), anyhow::Error> {
244        let key = (self.configuration.cache_keygen)(session_key.as_ref());
245        sqlx::query("DELETE FROM sessions WHERE key = $1")
246            .bind(key)
247            .execute(self.client_pool.as_ref())
248            .await
249            .map_err(Into::into)
250            .map_err(UpdateError::Other)?;
251        Ok(())
252    }
253}