poem_dbsession/sqlx/sqlite.rs

1use std::{collections::BTreeMap, time::Duration};
2
3use chrono::Utc;
4use poem::{error::InternalServerError, session::SessionStorage, Result};
5use serde_json::Value;
6use sqlx::{sqlite::SqliteStatement, types::Json, Executor, SqlitePool, Statement};
7
8use crate::DatabaseConfig;
9
/// Fetches the serialized session entries for one id, skipping expired rows.
///
/// Bind order: 1. session id, 2. current time. A row is returned only when its
/// `expires` is null (the session never expires) or later than the bound time.
/// `{table_name}` is substituted textually before the statement is prepared.
const LOAD_SESSION_SQL: &str = r#"
    select session from {table_name}
        where id = ? and (expires is null or expires > ?)
    "#;
14
/// Inserts a session row, or overwrites its `session` and `expires` columns when
/// a row with the same id already exists (SQLite `on conflict ... do update`).
///
/// Bind order: 1. session id, 2. JSON-encoded entries, 3. absolute expiry time
/// or null for a session without expiry.
/// `{table_name}` is substituted textually before the statement is prepared.
const UPDATE_SESSION_SQL: &str = r#"
    insert into {table_name} (id, session, expires) values (?, ?, ?)
        on conflict(id) do update set
            expires = excluded.expires,
            session = excluded.session
"#;
21
/// Deletes the session row with the given id (the only bound parameter).
/// `{table_name}` is substituted textually before the statement is prepared.
const REMOVE_SESSION_SQL: &str = r#"
    delete from {table_name} where id = ?
"#;
25
/// Deletes every session whose `expires` is earlier than the single bound
/// parameter (`cleanup` binds the current time).
///
/// Rows with a null `expires` never satisfy the comparison, so sessions stored
/// without an expiry are kept. Uses a positional `?` placeholder, like the
/// other statements here, rather than the numbered `$1` form.
/// `{table_name}` is substituted textually before the statement is prepared.
const CLEANUP_SQL: &str = r#"
    delete from {table_name} where expires < ?
"#;
29
/// Session storage using Sqlite.
///
/// The load, upsert, remove and cleanup statements are prepared once by
/// [`SqliteSessionStorage::try_new`] and reused. Each operation checks a
/// connection out of the pool to run them. Cloning the storage clones the pool
/// handle and the prepared statements.
///
/// # Errors
///
/// - [`sqlx::Error`] is returned by [`SqliteSessionStorage::try_new`] and
///   [`SqliteSessionStorage::cleanup`].
/// - The [`SessionStorage`] methods wrap database failures in
///   `InternalServerError`.
///
/// # Create the table for session storage
///
/// The table name must match the `table_name` in the [`DatabaseConfig`] passed
/// to `try_new`. The tests use `poem_sessions` together with
/// `DatabaseConfig::new()`.
///
/// ```sql
/// create table poem_sessions (
///     id text primary key not null,
///     expires integer null,
///     session text not null
/// )
/// ```
///
/// NOTE(review): `expires` is declared `integer`, but the code binds a chrono
/// `DateTime<Utc>` for it. Confirm how sqlx encodes that type for SQLite
/// (possibly as TEXT) and whether the `<`/`>` comparisons behave as intended.
#[derive(Clone)]
pub struct SqliteSessionStorage {
    // Pool from which every operation checks out a connection.
    pool: SqlitePool,
    // Prepared from LOAD_SESSION_SQL.
    load_stmt: SqliteStatement<'static>,
    // Prepared from UPDATE_SESSION_SQL.
    update_stmt: SqliteStatement<'static>,
    // Prepared from REMOVE_SESSION_SQL.
    remove_stmt: SqliteStatement<'static>,
    // Prepared from CLEANUP_SQL.
    cleanup_stmt: SqliteStatement<'static>,
}
53
54impl SqliteSessionStorage {
55    /// Create an [`SqliteSessionStorage`].
56    pub async fn try_new(config: DatabaseConfig, pool: SqlitePool) -> sqlx::Result<Self> {
57        let mut conn = pool.acquire().await?;
58
59        let load_stmt = Statement::to_owned(
60            &conn
61                .prepare(&LOAD_SESSION_SQL.replace("{table_name}", &config.table_name))
62                .await?,
63        );
64
65        let update_stmt = Statement::to_owned(
66            &conn
67                .prepare(&UPDATE_SESSION_SQL.replace("{table_name}", &config.table_name))
68                .await?,
69        );
70
71        let remove_stmt = Statement::to_owned(
72            &conn
73                .prepare(&REMOVE_SESSION_SQL.replace("{table_name}", &config.table_name))
74                .await?,
75        );
76
77        let cleanup_stmt = Statement::to_owned(
78            &conn
79                .prepare(&CLEANUP_SQL.replace("{table_name}", &config.table_name))
80                .await?,
81        );
82
83        Ok(Self {
84            pool,
85            load_stmt,
86            update_stmt,
87            remove_stmt,
88            cleanup_stmt,
89        })
90    }
91
92    /// Cleanup expired sessions.
93    pub async fn cleanup(&self) -> sqlx::Result<()> {
94        let mut conn = self.pool.acquire().await?;
95        self.cleanup_stmt
96            .query()
97            .bind(Utc::now())
98            .execute(&mut conn)
99            .await?;
100        Ok(())
101    }
102}
103
104#[poem::async_trait]
105impl SessionStorage for SqliteSessionStorage {
106    async fn load_session(&self, session_id: &str) -> Result<Option<BTreeMap<String, Value>>> {
107        let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
108        let res: Option<(Json<BTreeMap<String, Value>>,)> = self
109            .load_stmt
110            .query_as()
111            .bind(session_id)
112            .bind(Utc::now())
113            .fetch_optional(&mut conn)
114            .await
115            .map_err(InternalServerError)?;
116        Ok(res.map(|(value,)| value.0))
117    }
118
119    async fn update_session(
120        &self,
121        session_id: &str,
122        entries: &BTreeMap<String, Value>,
123        expires: Option<Duration>,
124    ) -> Result<()> {
125        let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
126
127        let expires = match expires {
128            Some(expires) => {
129                Some(chrono::Duration::from_std(expires).map_err(InternalServerError)?)
130            }
131            None => None,
132        };
133
134        self.update_stmt
135            .query()
136            .bind(session_id)
137            .bind(Json(entries))
138            .bind(expires.map(|expires| Utc::now() + expires))
139            .execute(&mut conn)
140            .await
141            .map_err(InternalServerError)?;
142        Ok(())
143    }
144
145    async fn remove_session(&self, session_id: &str) -> Result<()> {
146        let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
147        self.remove_stmt
148            .query()
149            .bind(session_id)
150            .execute(&mut conn)
151            .await
152            .map_err(InternalServerError)?;
153        Ok(())
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_harness;

    /// Runs the shared storage test-suite against an in-memory SQLite database,
    /// with a background task calling `cleanup` once per second.
    #[tokio::test]
    async fn test() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();

        // Create the session table on a connection that stays checked out
        // until the end of this test.
        let mut setup_conn = pool.acquire().await.unwrap();
        let create_table = sqlx::query(
            r#"
        create table poem_sessions (
            id text primary key not null,
            expires integer null,
            session text not null
        )
        "#,
        );
        create_table.execute(&mut setup_conn).await.unwrap();

        let storage = SqliteSessionStorage::try_new(DatabaseConfig::new(), pool)
            .await
            .unwrap();

        // Periodically purge expired rows while the harness runs.
        let sweeper = {
            let sweeper_storage = storage.clone();
            tokio::spawn(async move {
                loop {
                    tokio::time::sleep(Duration::from_secs(1)).await;
                    sweeper_storage.cleanup().await.unwrap();
                }
            })
        };

        test_harness::test_storage(storage).await;
        sweeper.abort();
    }
}
196}