// poem_dbsession/sqlx/postgres.rs

use std::{collections::BTreeMap, time::Duration};

use chrono::Utc;
use poem::{error::InternalServerError, session::SessionStorage, Result};
use serde_json::Value;
use sqlx::{postgres::PgStatement, types::Json, Executor, PgPool, Statement};

use crate::DatabaseConfig;

/// Loads the JSON payload of a single, still-valid session.
///
/// `{table_name}` is substituted textually (not escaped) before the statement is
/// prepared, so the table name must come from trusted configuration.
/// Bind parameters: `$1` is the session id and `$2` the current time. A row with a
/// null `expires` never expires; otherwise it is returned only while `expires > $2`.
const LOAD_SESSION_SQL: &str = r#"
    select session from {table_name}
        where id = $1 and (expires is null or expires > $2)
    "#;

/// Inserts a session or, when a row with the same `id` already exists, overwrites
/// its `session` payload and `expires` timestamp (an upsert).
///
/// `{table_name}` is substituted textually before the statement is prepared.
/// Bind parameters: `$1` session id, `$2` JSON-encoded entries, `$3` absolute
/// expiry time, or null for a session that never expires.
const UPDATE_SESSION_SQL: &str = r#"
    insert into {table_name} (id, session, expires) values ($1, $2, $3)
        on conflict(id) do update set
            expires = excluded.expires,
            session = excluded.session
"#;

/// Deletes the session whose `id` equals `$1`.
///
/// `{table_name}` is substituted textually before the statement is prepared.
const REMOVE_SESSION_SQL: &str = r#"
    delete from {table_name} where id = $1
"#;

/// Purges sessions whose `expires` timestamp is earlier than `$1` (the current time).
///
/// Rows with a null `expires` are kept, because `null < $1` is never true.
/// `{table_name}` is substituted textually before the statement is prepared.
const CLEANUP_SQL: &str = r#"
    delete from {table_name} where expires < $1
"#;

30/// Session storage using Postgres.
31///
32/// # Errors
33///
34/// - [`sqlx::Error`]
35///
36/// # Create the table for session storage
37///
38/// ```sql
39/// create table if not exists poem_sessions (
40///     id varchar not null primary key,
41///     expires timestamp with time zone null,
42///     session jsonb not null
43/// );
44///
45/// create index if not exists poem_sessions_expires_idx on poem_sessions (expires);
46/// ```
47#[derive(Clone)]
48pub struct PgSessionStorage {
49    pool: PgPool,
50    load_stmt: PgStatement<'static>,
51    update_stmt: PgStatement<'static>,
52    remove_stmt: PgStatement<'static>,
53    cleanup_stmt: PgStatement<'static>,
54}
55
56impl PgSessionStorage {
57    /// Create an [`PgSessionStorage`].
58    pub async fn try_new(config: DatabaseConfig, pool: PgPool) -> sqlx::Result<Self> {
59        let mut conn = pool.acquire().await?;
60
61        let load_stmt = Statement::to_owned(
62            &conn
63                .prepare(&LOAD_SESSION_SQL.replace("{table_name}", &config.table_name))
64                .await?,
65        );
66
67        let update_stmt = Statement::to_owned(
68            &conn
69                .prepare(&UPDATE_SESSION_SQL.replace("{table_name}", &config.table_name))
70                .await?,
71        );
72
73        let remove_stmt = Statement::to_owned(
74            &conn
75                .prepare(&REMOVE_SESSION_SQL.replace("{table_name}", &config.table_name))
76                .await?,
77        );
78
79        let cleanup_stmt = Statement::to_owned(
80            &conn
81                .prepare(&CLEANUP_SQL.replace("{table_name}", &config.table_name))
82                .await?,
83        );
84
85        Ok(Self {
86            pool,
87            load_stmt,
88            update_stmt,
89            remove_stmt,
90            cleanup_stmt,
91        })
92    }
93
94    /// Cleanup expired sessions.
95    pub async fn cleanup(&self) -> sqlx::Result<()> {
96        let mut conn = self.pool.acquire().await?;
97        self.cleanup_stmt
98            .query()
99            .bind(Utc::now())
100            .execute(&mut conn)
101            .await?;
102        Ok(())
103    }
104}
105
106#[poem::async_trait]
107impl SessionStorage for PgSessionStorage {
108    async fn load_session(&self, session_id: &str) -> Result<Option<BTreeMap<String, Value>>> {
109        let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
110        let res: Option<(Json<BTreeMap<String, Value>>,)> = self
111            .load_stmt
112            .query_as()
113            .bind(session_id)
114            .bind(Utc::now())
115            .fetch_optional(&mut conn)
116            .await
117            .map_err(InternalServerError)?;
118        Ok(res.map(|(value,)| value.0))
119    }
120
121    async fn update_session(
122        &self,
123        session_id: &str,
124        entries: &BTreeMap<String, Value>,
125        expires: Option<Duration>,
126    ) -> Result<()> {
127        let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
128
129        let expires = match expires {
130            Some(expires) => {
131                Some(chrono::Duration::from_std(expires).map_err(InternalServerError)?)
132            }
133            None => None,
134        };
135
136        self.update_stmt
137            .query()
138            .bind(session_id)
139            .bind(Json(entries))
140            .bind(expires.map(|expires| Utc::now() + expires))
141            .execute(&mut conn)
142            .await
143            .map_err(InternalServerError)?;
144        Ok(())
145    }
146
147    async fn remove_session(&self, session_id: &str) -> Result<()> {
148        let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
149        self.remove_stmt
150            .query()
151            .bind(session_id)
152            .execute(&mut conn)
153            .await
154            .map_err(InternalServerError)?;
155        Ok(())
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_harness;

    /// Local database the integration test connects to.
    const DATABASE_URL: &str = "postgres://postgres:123456@localhost/test_poem_sessions";

    /// Schema statements, executed in order: the sessions table, then its
    /// `expires` index.
    const SCHEMA: [&str; 2] = [
        r#"
        create table if not exists poem_sessions (
            id varchar not null primary key,
            expires timestamp with time zone null,
            session jsonb not null
        )
        "#,
        r#"
        create index if not exists poem_sessions_expires_idx on poem_sessions (expires)
        "#,
    ];

    #[tokio::test]
    async fn test() {
        let pool = PgPool::connect(DATABASE_URL).await.unwrap();

        let mut conn = pool.acquire().await.unwrap();
        for ddl in SCHEMA.iter().copied() {
            sqlx::query(ddl).execute(&mut conn).await.unwrap();
        }

        let storage = PgSessionStorage::try_new(DatabaseConfig::new(), pool)
            .await
            .unwrap();

        // Background task that purges expired rows every second while the shared
        // storage test suite runs. It is aborted once the suite finishes.
        let cleaner = {
            let storage = storage.clone();
            tokio::spawn(async move {
                loop {
                    tokio::time::sleep(Duration::from_secs(1)).await;
                    storage.cleanup().await.unwrap();
                }
            })
        };
        test_harness::test_storage(storage).await;
        cleaner.abort();
    }
}
209}