poem_dbsession/sqlx/
mysql.rs

1use std::{collections::BTreeMap, time::Duration};
2
3use chrono::Utc;
4use poem::{error::InternalServerError, session::SessionStorage, Result};
5use serde_json::Value;
6use sqlx::{mysql::MySqlStatement, types::Json, Executor, MySqlPool, Statement};
7
8use crate::DatabaseConfig;
9
// SQL templates used by `MysqlSessionStorage`. The `{table_name}` placeholder
// is replaced with `DatabaseConfig::table_name` before each statement is
// prepared; every `?` is a positional bind parameter.

/// Fetches the serialized session for an id, skipping rows whose `expires`
/// is set and not later than the supplied timestamp.
///
/// Binds: session id, current time.
const LOAD_SESSION_SQL: &str = r#"
    select session from {table_name}
        where id = ? and (expires is null or expires > ?)
    "#;

/// Inserts a session row, or overwrites `session` and `expires` of the
/// existing row when the insert collides on a unique key (the recommended
/// schema makes `id` the primary key).
///
/// Binds: session id, JSON-encoded session, expiry timestamp (or NULL).
const UPDATE_SESSION_SQL: &str = r#"
    insert into {table_name} (id, session, expires) values (?, ?, ?)
        on duplicate key update
            expires = values(expires),
            session = values(session)
"#;

/// Deletes the session with the given id.
///
/// Binds: session id.
const REMOVE_SESSION_SQL: &str = r#"
    delete from {table_name} where id = ?
"#;

/// Deletes every session whose `expires` is earlier than the supplied
/// timestamp. Rows with a NULL `expires` never match and are kept.
///
/// Binds: current time.
const CLEANUP_SQL: &str = r#"
    delete from {table_name} where expires < ?
"#;
29
30/// Session storage using Mysql.
31///
32/// # Errors
33///
34/// - [`sqlx::Error`]
35///
36/// # Create the table for session storage
37///
38/// ```sql
39/// create table if not exists poem_sessions (
40///     id varchar(128) not null,
41///     expires timestamp(6) null,
42///     session text not null,
43///     primary key (id),
44///     key expires (expires)
45/// )
46/// engine=innodb
47/// default charset=utf8
48/// ```
49#[derive(Clone)]
50pub struct MysqlSessionStorage {
51    pool: MySqlPool,
52    load_stmt: MySqlStatement<'static>,
53    update_stmt: MySqlStatement<'static>,
54    remove_stmt: MySqlStatement<'static>,
55    cleanup_stmt: MySqlStatement<'static>,
56}
57
58impl MysqlSessionStorage {
59    /// Create an [`MysqlSessionStorage`].
60    pub async fn try_new(config: DatabaseConfig, pool: MySqlPool) -> sqlx::Result<Self> {
61        let mut conn = pool.acquire().await?;
62
63        let load_stmt = Statement::to_owned(
64            &conn
65                .prepare(&LOAD_SESSION_SQL.replace("{table_name}", &config.table_name))
66                .await?,
67        );
68
69        let update_stmt = Statement::to_owned(
70            &conn
71                .prepare(&UPDATE_SESSION_SQL.replace("{table_name}", &config.table_name))
72                .await?,
73        );
74
75        let remove_stmt = Statement::to_owned(
76            &conn
77                .prepare(&REMOVE_SESSION_SQL.replace("{table_name}", &config.table_name))
78                .await?,
79        );
80
81        let cleanup_stmt = Statement::to_owned(
82            &conn
83                .prepare(&CLEANUP_SQL.replace("{table_name}", &config.table_name))
84                .await?,
85        );
86
87        Ok(Self {
88            pool,
89            load_stmt,
90            update_stmt,
91            remove_stmt,
92            cleanup_stmt,
93        })
94    }
95
96    /// Cleanup expired sessions.
97    pub async fn cleanup(&self) -> sqlx::Result<()> {
98        let mut conn = self.pool.acquire().await?;
99        self.cleanup_stmt
100            .query()
101            .bind(Utc::now())
102            .execute(&mut conn)
103            .await?;
104        Ok(())
105    }
106}
107
108#[poem::async_trait]
109impl SessionStorage for MysqlSessionStorage {
110    async fn load_session(&self, session_id: &str) -> Result<Option<BTreeMap<String, Value>>> {
111        let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
112        let res: Option<(Json<BTreeMap<String, Value>>,)> = self
113            .load_stmt
114            .query_as()
115            .bind(session_id)
116            .bind(Utc::now())
117            .fetch_optional(&mut conn)
118            .await
119            .map_err(InternalServerError)?;
120        Ok(res.map(|(value,)| value.0))
121    }
122
123    async fn update_session(
124        &self,
125        session_id: &str,
126        entries: &BTreeMap<String, Value>,
127        expires: Option<Duration>,
128    ) -> Result<()> {
129        let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
130
131        let expires = match expires {
132            Some(expires) => {
133                Some(chrono::Duration::from_std(expires).map_err(InternalServerError)?)
134            }
135            None => None,
136        };
137
138        self.update_stmt
139            .query()
140            .bind(session_id)
141            .bind(Json(entries))
142            .bind(expires.map(|expires| Utc::now() + expires))
143            .execute(&mut conn)
144            .await
145            .map_err(InternalServerError)?;
146        Ok(())
147    }
148
149    async fn remove_session(&self, session_id: &str) -> Result<()> {
150        let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
151        self.remove_stmt
152            .query()
153            .bind(session_id)
154            .execute(&mut conn)
155            .await
156            .map_err(InternalServerError)?;
157        Ok(())
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_harness;

    /// End-to-end test of [`MysqlSessionStorage`].
    ///
    /// Requires a MySQL server on `localhost` with the `test_poem_sessions`
    /// database and the credentials in the connection URL below. A background
    /// task runs [`MysqlSessionStorage::cleanup`] every second while the shared
    /// storage test harness runs.
    #[tokio::test]
    async fn test() {
        let pool = MySqlPool::connect("mysql://root:123456@localhost/test_poem_sessions")
            .await
            .unwrap();

        let mut conn = pool.acquire().await.unwrap();
        // Same schema as the type-level docs. `utf8mb4` is used so that session
        // JSON containing characters outside the BMP can be stored.
        sqlx::query(
            r#"
        create table if not exists poem_sessions (
            id varchar(128) not null,
            expires timestamp(6) null,
            session text not null,
            primary key (id),
            key expires (expires)
        )
        engine=innodb
        default charset=utf8mb4
        "#,
        )
        .execute(&mut conn)
        .await
        .unwrap();

        let storage = MysqlSessionStorage::try_new(DatabaseConfig::new(), pool)
            .await
            .unwrap();

        let join_handle = tokio::spawn({
            let storage = storage.clone();
            async move {
                loop {
                    tokio::time::sleep(Duration::from_secs(1)).await;
                    storage.cleanup().await.unwrap();
                }
            }
        });
        test_harness::test_storage(storage).await;
        join_handle.abort();
    }
}