// poem_dbsession/sqlx/sqlite.rs
use std::{collections::BTreeMap, io, time::Duration};

use chrono::Utc;
use poem::{error::InternalServerError, session::SessionStorage, Result};
use serde_json::Value;
use sqlx::{sqlite::SqliteStatement, types::Json, Executor, SqlitePool, Statement};

use crate::DatabaseConfig;

// Every template below contains a `{table_name}` placeholder that
// `SqliteSessionStorage::try_new` replaces with the configured table name
// before preparing the statement. All bind parameters are positional (`?`)
// and are bound in the order listed for each template.

/// Fetches the stored session for an id, skipping rows whose expiry has passed.
///
/// Binds: session id, current time.
const LOAD_SESSION_SQL: &str = r#"
    select session from {table_name}
    where id = ? and (expires is null or expires > ?)
"#;

/// Inserts a session, or overwrites the payload and expiry of an existing id.
///
/// Binds: session id, JSON-encoded session entries, expiry timestamp (nullable).
const UPDATE_SESSION_SQL: &str = r#"
    insert into {table_name} (id, session, expires) values (?, ?, ?)
    on conflict(id) do update set
        expires = excluded.expires,
        session = excluded.session
"#;

/// Deletes the session with the given id.
///
/// Binds: session id.
const REMOVE_SESSION_SQL: &str = r#"
    delete from {table_name} where id = ?
"#;

/// Deletes every session whose expiry is earlier than the bound time.
/// Rows with a null expiry are not matched and are kept.
///
/// Binds: current time. Uses a positional `?` like the other templates
/// (previously `$1`, the only template with a different placeholder style).
const CLEANUP_SQL: &str = r#"
    delete from {table_name} where expires < ?
"#;

30#[derive(Clone)]
46pub struct SqliteSessionStorage {
47 pool: SqlitePool,
48 load_stmt: SqliteStatement<'static>,
49 update_stmt: SqliteStatement<'static>,
50 remove_stmt: SqliteStatement<'static>,
51 cleanup_stmt: SqliteStatement<'static>,
52}
53
54impl SqliteSessionStorage {
55 pub async fn try_new(config: DatabaseConfig, pool: SqlitePool) -> sqlx::Result<Self> {
57 let mut conn = pool.acquire().await?;
58
59 let load_stmt = Statement::to_owned(
60 &conn
61 .prepare(&LOAD_SESSION_SQL.replace("{table_name}", &config.table_name))
62 .await?,
63 );
64
65 let update_stmt = Statement::to_owned(
66 &conn
67 .prepare(&UPDATE_SESSION_SQL.replace("{table_name}", &config.table_name))
68 .await?,
69 );
70
71 let remove_stmt = Statement::to_owned(
72 &conn
73 .prepare(&REMOVE_SESSION_SQL.replace("{table_name}", &config.table_name))
74 .await?,
75 );
76
77 let cleanup_stmt = Statement::to_owned(
78 &conn
79 .prepare(&CLEANUP_SQL.replace("{table_name}", &config.table_name))
80 .await?,
81 );
82
83 Ok(Self {
84 pool,
85 load_stmt,
86 update_stmt,
87 remove_stmt,
88 cleanup_stmt,
89 })
90 }
91
92 pub async fn cleanup(&self) -> sqlx::Result<()> {
94 let mut conn = self.pool.acquire().await?;
95 self.cleanup_stmt
96 .query()
97 .bind(Utc::now())
98 .execute(&mut conn)
99 .await?;
100 Ok(())
101 }
102}
103
104#[poem::async_trait]
105impl SessionStorage for SqliteSessionStorage {
106 async fn load_session(&self, session_id: &str) -> Result<Option<BTreeMap<String, Value>>> {
107 let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
108 let res: Option<(Json<BTreeMap<String, Value>>,)> = self
109 .load_stmt
110 .query_as()
111 .bind(session_id)
112 .bind(Utc::now())
113 .fetch_optional(&mut conn)
114 .await
115 .map_err(InternalServerError)?;
116 Ok(res.map(|(value,)| value.0))
117 }
118
119 async fn update_session(
120 &self,
121 session_id: &str,
122 entries: &BTreeMap<String, Value>,
123 expires: Option<Duration>,
124 ) -> Result<()> {
125 let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
126
127 let expires = match expires {
128 Some(expires) => {
129 Some(chrono::Duration::from_std(expires).map_err(InternalServerError)?)
130 }
131 None => None,
132 };
133
134 self.update_stmt
135 .query()
136 .bind(session_id)
137 .bind(Json(entries))
138 .bind(expires.map(|expires| Utc::now() + expires))
139 .execute(&mut conn)
140 .await
141 .map_err(InternalServerError)?;
142 Ok(())
143 }
144
145 async fn remove_session(&self, session_id: &str) -> Result<()> {
146 let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
147 self.remove_stmt
148 .query()
149 .bind(session_id)
150 .execute(&mut conn)
151 .await
152 .map_err(InternalServerError)?;
153 Ok(())
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_harness;

    /// Runs the shared storage test harness against an in-memory SQLite pool,
    /// with a background task calling `cleanup` once per second.
    #[tokio::test]
    async fn test() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();

        // The table is created through this pooled connection, which stays
        // checked out until the end of the test.
        let mut conn = pool.acquire().await.unwrap();
        let create_table = sqlx::query(
            r#"
            create table poem_sessions (
                id text primary key not null,
                expires integer null,
                session text not null
            )
            "#,
        );
        create_table.execute(&mut conn).await.unwrap();

        let storage = SqliteSessionStorage::try_new(DatabaseConfig::new(), pool)
            .await
            .unwrap();

        // Periodically purge expired rows while the harness runs.
        let cleaner = storage.clone();
        let cleanup_task = tokio::spawn(async move {
            loop {
                tokio::time::sleep(Duration::from_secs(1)).await;
                cleaner.cleanup().await.unwrap();
            }
        });

        test_harness::test_storage(storage).await;
        cleanup_task.abort();
    }
}