actix_session_sqlx_postgres/
lib.rs1use crate::ConnectionData::{ConnectionPool, ConnectionString};
2use actix_session::storage::{LoadError, SaveError, SessionKey, SessionStore, UpdateError};
3use chrono::Utc;
4use rand::{distributions::Alphanumeric, rngs::OsRng, Rng as _};
5use serde_json::{self, Value};
6use sqlx::postgres::PgPoolOptions;
7use sqlx::{Pool, Postgres, Row};
8use std::collections::HashMap;
9use std::sync::Arc;
10use time::Duration;
11
/// Settings shared by every operation of the session store.
#[derive(Clone)]
struct CacheConfiguration {
    /// Maps a session key (as a string slice) to the key under which the
    /// session row is stored.
    cache_keygen: Arc<dyn Fn(&str) -> String + Send + Sync>,
}

impl Default for CacheConfiguration {
    /// Builds a configuration whose key generator returns an owned copy of
    /// the session key, unchanged.
    fn default() -> Self {
        let cache_keygen: Arc<dyn Fn(&str) -> String + Send + Sync> =
            Arc::new(|session_key: &str| session_key.to_owned());
        Self { cache_keygen }
    }
}
88
/// Session store backed by a PostgreSQL connection pool managed by `sqlx`.
///
/// Cloning is cheap: the pool is held behind an [`Arc`], so clones share it.
#[derive(Clone)]
pub struct SqlxPostgresqlSessionStore {
    /// Connection pool every query of this store is executed on.
    client_pool: Arc<Pool<Postgres>>,
    /// Holds the function that turns a session key into the stored key.
    configuration: CacheConfiguration,
}
94
95fn generate_session_key() -> SessionKey {
96 let value = std::iter::repeat(())
97 .map(|()| OsRng.sample(Alphanumeric))
98 .take(64)
99 .collect::<Vec<_>>();
100
101 String::from_utf8(value).unwrap().try_into().unwrap()
104}
105
106impl SqlxPostgresqlSessionStore {
107 pub fn builder<S: Into<String>>(connection_string: S) -> SqlxPostgresqlSessionStoreBuilder {
108 SqlxPostgresqlSessionStoreBuilder {
109 connection_data: ConnectionString(connection_string.into()),
110 configuration: CacheConfiguration::default(),
111 }
112 }
113
114 pub async fn new<S: Into<String>>(
115 connection_string: S,
116 ) -> Result<SqlxPostgresqlSessionStore, anyhow::Error> {
117 Self::builder(connection_string).build().await
118 }
119
120 pub fn from_pool(pool: Arc<Pool<Postgres>>) -> SqlxPostgresqlSessionStore {
121 SqlxPostgresqlSessionStore {
122 client_pool: pool,
123 configuration: CacheConfiguration::default(),
124 }
125 }
126}
127
/// Where a session store obtains its database connections from.
pub enum ConnectionData {
    /// A connection string; a new pool is created from it when the store is built.
    ConnectionString(String),
    /// An already-constructed pool that the store reuses as-is.
    ConnectionPool(Arc<Pool<Postgres>>),
}
132
133#[must_use]
134pub struct SqlxPostgresqlSessionStoreBuilder {
135 connection_data: ConnectionData,
136 configuration: CacheConfiguration,
137}
138
139impl SqlxPostgresqlSessionStoreBuilder {
140 pub async fn build(self) -> Result<SqlxPostgresqlSessionStore, anyhow::Error> {
141 match self.connection_data {
142 ConnectionString(conn_string) => PgPoolOptions::new()
143 .max_connections(1)
144 .connect(conn_string.as_str())
145 .await
146 .map_err(Into::into)
147 .map(|pool| SqlxPostgresqlSessionStore {
148 client_pool: Arc::new(pool),
149 configuration: self.configuration,
150 }),
151 ConnectionPool(pool) => Ok(SqlxPostgresqlSessionStore {
152 client_pool: pool,
153 configuration: self.configuration,
154 }),
155 }
156 }
157}
/// Session payload: a string-to-string map, stored as JSON in the
/// `session_state` column.
pub(crate) type SessionState = HashMap<String, String>;
159
160impl SessionStore for SqlxPostgresqlSessionStore {
161 async fn load(&self, session_key: &SessionKey) -> Result<Option<SessionState>, LoadError> {
162 let key = (self.configuration.cache_keygen)(session_key.as_ref());
163 let row =
164 sqlx::query("SELECT session_state FROM sessions WHERE key = $1 AND expires > NOW()")
165 .bind(key)
166 .fetch_optional(self.client_pool.as_ref())
167 .await
168 .map_err(Into::into)
169 .map_err(LoadError::Other)?;
170 match row {
171 None => Ok(None),
172 Some(r) => {
173 let data: Value = r.get("session_state");
174 let state: SessionState = serde_json::from_value(data)
175 .map_err(Into::into)
176 .map_err(LoadError::Deserialization)?;
177 Ok(Some(state))
178 }
179 }
180 }
181
182 async fn save(
183 &self,
184 session_state: SessionState,
185 ttl: &Duration,
186 ) -> Result<SessionKey, SaveError> {
187 let body = serde_json::to_value(&session_state)
188 .map_err(Into::into)
189 .map_err(SaveError::Serialization)?;
190 let key = generate_session_key();
191 let cache_key = (self.configuration.cache_keygen)(key.as_ref());
192 let expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds());
193 sqlx::query("INSERT INTO sessions(key, session_state, expires) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING")
194 .bind(cache_key)
195 .bind(body)
196 .bind(expires)
197 .execute(self.client_pool.as_ref())
198 .await
199 .map_err(Into::into)
200 .map_err(SaveError::Other)?;
201 Ok(key)
202 }
203
204 async fn update(
205 &self,
206 session_key: SessionKey,
207 session_state: SessionState,
208 ttl: &Duration,
209 ) -> Result<SessionKey, UpdateError> {
210 let body = serde_json::to_value(&session_state)
211 .map_err(Into::into)
212 .map_err(UpdateError::Serialization)?;
213 let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
214 let new_expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds());
215 sqlx::query("UPDATE sessions SET session_state = $1, expires = $2 WHERE key = $3")
216 .bind(body)
217 .bind(new_expires)
218 .bind(cache_key)
219 .execute(self.client_pool.as_ref())
220 .await
221 .map_err(Into::into)
222 .map_err(UpdateError::Other)?;
223 Ok(session_key)
224 }
225
226 async fn update_ttl(
227 &self,
228 session_key: &SessionKey,
229 ttl: &Duration,
230 ) -> Result<(), anyhow::Error> {
231 let new_expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds());
232 let key = (self.configuration.cache_keygen)(session_key.as_ref());
233 sqlx::query("UPDATE sessions SET expires = $1 WHERE key = $2")
234 .bind(new_expires)
235 .bind(key)
236 .execute(self.client_pool.as_ref())
237 .await
238 .map_err(Into::into)
239 .map_err(UpdateError::Other)?;
240 Ok(())
241 }
242
243 async fn delete(&self, session_key: &SessionKey) -> Result<(), anyhow::Error> {
244 let key = (self.configuration.cache_keygen)(session_key.as_ref());
245 sqlx::query("DELETE FROM sessions WHERE key = $1")
246 .bind(key)
247 .execute(self.client_pool.as_ref())
248 .await
249 .map_err(Into::into)
250 .map_err(UpdateError::Other)?;
251 Ok(())
252 }
253}