poem_dbsession/sqlx/
mysql.rs1use std::{collections::BTreeMap, time::Duration};
2
3use chrono::Utc;
4use poem::{error::InternalServerError, session::SessionStorage, Result};
5use serde_json::Value;
6use sqlx::{mysql::MySqlStatement, types::Json, Executor, MySqlPool, Statement};
7
8use crate::DatabaseConfig;
9
/// Fetches the serialized session for one id, skipping rows that have expired.
///
/// Bind order: session id, then the current time. A row whose `expires` column
/// is `NULL` never expires. `{table_name}` is substituted with the configured
/// table name before the statement is prepared.
const LOAD_SESSION_SQL: &str = r#"
    select session from {table_name}
    where id = ? and (expires is null or expires > ?)
"#;
14
/// Inserts a session row, or overwrites the existing row on a key conflict.
///
/// Bind order: session id, serialized session entries, expiry timestamp
/// (`NULL` for a session without expiry). When the insert collides with an
/// existing key (e.g. the same `id`), both `expires` and `session` are replaced
/// by the newly supplied values. `{table_name}` is substituted with the
/// configured table name before the statement is prepared.
///
/// NOTE: `VALUES(col)` inside `ON DUPLICATE KEY UPDATE` is deprecated since
/// MySQL 8.0.20 in favour of the row-alias syntax (8.0.19+); it is kept here
/// for compatibility with older MySQL servers and MariaDB.
const UPDATE_SESSION_SQL: &str = r#"
    insert into {table_name} (id, session, expires) values (?, ?, ?)
    on duplicate key update
        expires = values(expires),
        session = values(session)
"#;
21
/// Deletes the session row with the given id (single bound parameter).
///
/// `{table_name}` is substituted with the configured table name before the
/// statement is prepared.
const REMOVE_SESSION_SQL: &str = r#"
    delete from {table_name} where id = ?
"#;
25
/// Deletes every session whose expiry is earlier than the bound timestamp.
///
/// Rows with a `NULL` expiry are never matched, because `NULL < ?` is not true
/// in SQL. `{table_name}` is substituted with the configured table name before
/// the statement is prepared.
const CLEANUP_SQL: &str = r#"
    delete from {table_name} where expires < ?
"#;
29
30#[derive(Clone)]
50pub struct MysqlSessionStorage {
51 pool: MySqlPool,
52 load_stmt: MySqlStatement<'static>,
53 update_stmt: MySqlStatement<'static>,
54 remove_stmt: MySqlStatement<'static>,
55 cleanup_stmt: MySqlStatement<'static>,
56}
57
58impl MysqlSessionStorage {
59 pub async fn try_new(config: DatabaseConfig, pool: MySqlPool) -> sqlx::Result<Self> {
61 let mut conn = pool.acquire().await?;
62
63 let load_stmt = Statement::to_owned(
64 &conn
65 .prepare(&LOAD_SESSION_SQL.replace("{table_name}", &config.table_name))
66 .await?,
67 );
68
69 let update_stmt = Statement::to_owned(
70 &conn
71 .prepare(&UPDATE_SESSION_SQL.replace("{table_name}", &config.table_name))
72 .await?,
73 );
74
75 let remove_stmt = Statement::to_owned(
76 &conn
77 .prepare(&REMOVE_SESSION_SQL.replace("{table_name}", &config.table_name))
78 .await?,
79 );
80
81 let cleanup_stmt = Statement::to_owned(
82 &conn
83 .prepare(&CLEANUP_SQL.replace("{table_name}", &config.table_name))
84 .await?,
85 );
86
87 Ok(Self {
88 pool,
89 load_stmt,
90 update_stmt,
91 remove_stmt,
92 cleanup_stmt,
93 })
94 }
95
96 pub async fn cleanup(&self) -> sqlx::Result<()> {
98 let mut conn = self.pool.acquire().await?;
99 self.cleanup_stmt
100 .query()
101 .bind(Utc::now())
102 .execute(&mut conn)
103 .await?;
104 Ok(())
105 }
106}
107
108#[poem::async_trait]
109impl SessionStorage for MysqlSessionStorage {
110 async fn load_session(&self, session_id: &str) -> Result<Option<BTreeMap<String, Value>>> {
111 let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
112 let res: Option<(Json<BTreeMap<String, Value>>,)> = self
113 .load_stmt
114 .query_as()
115 .bind(session_id)
116 .bind(Utc::now())
117 .fetch_optional(&mut conn)
118 .await
119 .map_err(InternalServerError)?;
120 Ok(res.map(|(value,)| value.0))
121 }
122
123 async fn update_session(
124 &self,
125 session_id: &str,
126 entries: &BTreeMap<String, Value>,
127 expires: Option<Duration>,
128 ) -> Result<()> {
129 let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
130
131 let expires = match expires {
132 Some(expires) => {
133 Some(chrono::Duration::from_std(expires).map_err(InternalServerError)?)
134 }
135 None => None,
136 };
137
138 self.update_stmt
139 .query()
140 .bind(session_id)
141 .bind(Json(entries))
142 .bind(expires.map(|expires| Utc::now() + expires))
143 .execute(&mut conn)
144 .await
145 .map_err(InternalServerError)?;
146 Ok(())
147 }
148
149 async fn remove_session(&self, session_id: &str) -> Result<()> {
150 let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
151 self.remove_stmt
152 .query()
153 .bind(session_id)
154 .execute(&mut conn)
155 .await
156 .map_err(InternalServerError)?;
157 Ok(())
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_harness;

    /// Database used when the `MYSQL_URL` environment variable is not set.
    const DEFAULT_MYSQL_URL: &str = "mysql://root:123456@localhost/test_poem_sessions";

    /// Schema for the session table used by this test.
    ///
    /// `utf8mb4` is used because MySQL's `utf8` (utf8mb3) cannot store 4-byte
    /// code points, and serialized session JSON may contain them.
    /// NOTE(review): assumes `DatabaseConfig::new()` targets the
    /// `poem_sessions` table — confirm.
    const CREATE_TABLE_SQL: &str = r#"
        create table if not exists poem_sessions (
            id varchar(128) not null,
            expires timestamp(6) null,
            session text not null,
            primary key (id),
            key expires (expires)
        )
        engine=innodb
        default charset=utf8mb4
    "#;

    /// Runs the shared storage test suite against a live MySQL server.
    ///
    /// A background task purges expired sessions every second while the suite
    /// runs. Requires a reachable MySQL instance; point `MYSQL_URL` at it to
    /// override the default connection string.
    #[tokio::test]
    async fn test() {
        let url = std::env::var("MYSQL_URL").unwrap_or_else(|_| DEFAULT_MYSQL_URL.to_owned());
        let pool = MySqlPool::connect(&url)
            .await
            .expect("failed to connect to the MySQL test database");

        // Run the DDL through the pool so no connection stays pinned for the test.
        sqlx::query(CREATE_TABLE_SQL)
            .execute(&pool)
            .await
            .expect("failed to create the session table");

        let storage = MysqlSessionStorage::try_new(DatabaseConfig::new(), pool)
            .await
            .expect("failed to prepare session statements");

        let join_handle = tokio::spawn({
            let storage = storage.clone();
            async move {
                loop {
                    tokio::time::sleep(Duration::from_secs(1)).await;
                    storage.cleanup().await.expect("session cleanup failed");
                }
            }
        });
        test_harness::test_storage(storage).await;
        join_handle.abort();
    }
}