poem_dbsession/sqlx/
postgres.rs1use std::{collections::BTreeMap, time::Duration};
2
3use chrono::Utc;
4use poem::{error::InternalServerError, session::SessionStorage, Result};
5use serde_json::Value;
6use sqlx::{postgres::PgStatement, types::Json, Executor, PgPool, Statement};
7
8use crate::DatabaseConfig;
9
/// Selects the stored payload for a session id, skipping rows whose `expires`
/// is not strictly after the supplied time.
///
/// `{table_name}` is substituted with the configured table name before the
/// statement is prepared. Binds: `$1` = session id, `$2` = current time.
const LOAD_SESSION_SQL: &str = r#"
    select session from {table_name}
        where id = $1 and (expires is null or expires > $2)
    "#;

/// Inserts a session, or overwrites the payload and expiry of an existing id.
///
/// Binds: `$1` = session id, `$2` = JSON payload, `$3` = absolute expiry
/// timestamp (null means the session never expires).
const UPDATE_SESSION_SQL: &str = r#"
    insert into {table_name} (id, session, expires) values ($1, $2, $3)
        on conflict(id) do update set
            expires = excluded.expires,
            session = excluded.session
"#;

/// Deletes a single session. Binds: `$1` = session id.
const REMOVE_SESSION_SQL: &str = r#"
    delete from {table_name} where id = $1
"#;

/// Deletes every session that has expired as of `$1` (the current time).
///
/// `<=` keeps the boundary consistent with `LOAD_SESSION_SQL`, which already
/// treats a row whose `expires` equals the current time as expired. Rows with a
/// null `expires` never satisfy the comparison and are kept.
const CLEANUP_SQL: &str = r#"
    delete from {table_name} where expires <= $1
"#;
29
30#[derive(Clone)]
48pub struct PgSessionStorage {
49 pool: PgPool,
50 load_stmt: PgStatement<'static>,
51 update_stmt: PgStatement<'static>,
52 remove_stmt: PgStatement<'static>,
53 cleanup_stmt: PgStatement<'static>,
54}
55
56impl PgSessionStorage {
57 pub async fn try_new(config: DatabaseConfig, pool: PgPool) -> sqlx::Result<Self> {
59 let mut conn = pool.acquire().await?;
60
61 let load_stmt = Statement::to_owned(
62 &conn
63 .prepare(&LOAD_SESSION_SQL.replace("{table_name}", &config.table_name))
64 .await?,
65 );
66
67 let update_stmt = Statement::to_owned(
68 &conn
69 .prepare(&UPDATE_SESSION_SQL.replace("{table_name}", &config.table_name))
70 .await?,
71 );
72
73 let remove_stmt = Statement::to_owned(
74 &conn
75 .prepare(&REMOVE_SESSION_SQL.replace("{table_name}", &config.table_name))
76 .await?,
77 );
78
79 let cleanup_stmt = Statement::to_owned(
80 &conn
81 .prepare(&CLEANUP_SQL.replace("{table_name}", &config.table_name))
82 .await?,
83 );
84
85 Ok(Self {
86 pool,
87 load_stmt,
88 update_stmt,
89 remove_stmt,
90 cleanup_stmt,
91 })
92 }
93
94 pub async fn cleanup(&self) -> sqlx::Result<()> {
96 let mut conn = self.pool.acquire().await?;
97 self.cleanup_stmt
98 .query()
99 .bind(Utc::now())
100 .execute(&mut conn)
101 .await?;
102 Ok(())
103 }
104}
105
106#[poem::async_trait]
107impl SessionStorage for PgSessionStorage {
108 async fn load_session(&self, session_id: &str) -> Result<Option<BTreeMap<String, Value>>> {
109 let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
110 let res: Option<(Json<BTreeMap<String, Value>>,)> = self
111 .load_stmt
112 .query_as()
113 .bind(session_id)
114 .bind(Utc::now())
115 .fetch_optional(&mut conn)
116 .await
117 .map_err(InternalServerError)?;
118 Ok(res.map(|(value,)| value.0))
119 }
120
121 async fn update_session(
122 &self,
123 session_id: &str,
124 entries: &BTreeMap<String, Value>,
125 expires: Option<Duration>,
126 ) -> Result<()> {
127 let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
128
129 let expires = match expires {
130 Some(expires) => {
131 Some(chrono::Duration::from_std(expires).map_err(InternalServerError)?)
132 }
133 None => None,
134 };
135
136 self.update_stmt
137 .query()
138 .bind(session_id)
139 .bind(Json(entries))
140 .bind(expires.map(|expires| Utc::now() + expires))
141 .execute(&mut conn)
142 .await
143 .map_err(InternalServerError)?;
144 Ok(())
145 }
146
147 async fn remove_session(&self, session_id: &str) -> Result<()> {
148 let mut conn = self.pool.acquire().await.map_err(InternalServerError)?;
149 self.remove_stmt
150 .query()
151 .bind(session_id)
152 .execute(&mut conn)
153 .await
154 .map_err(InternalServerError)?;
155 Ok(())
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_harness;

    /// Runs the shared storage test-suite against a live Postgres database while a
    /// background task calls `cleanup` every second.
    #[tokio::test]
    async fn test() {
        let pool = PgPool::connect("postgres://postgres:123456@localhost/test_poem_sessions")
            .await
            .unwrap();

        // Schema setup runs directly on the pool rather than pinning one connection
        // for the whole test.
        sqlx::query(
            r#"
            create table if not exists poem_sessions (
                id varchar not null primary key,
                expires timestamp with time zone null,
                session jsonb not null
            )
            "#,
        )
        .execute(&pool)
        .await
        .unwrap();

        sqlx::query(
            r#"
            create index if not exists poem_sessions_expires_idx on poem_sessions (expires)
            "#,
        )
        .execute(&pool)
        .await
        .unwrap();

        let storage = PgSessionStorage::try_new(DatabaseConfig::new(), pool)
            .await
            .unwrap();

        // Exercise `cleanup` concurrently with the storage operations under test.
        let join_handle = tokio::spawn({
            let storage = storage.clone();
            async move {
                loop {
                    tokio::time::sleep(Duration::from_secs(1)).await;
                    storage.cleanup().await.unwrap();
                }
            }
        });
        test_harness::test_storage(storage).await;

        // A panic inside the spawned task (e.g. a failing `cleanup`) is only visible
        // through its JoinHandle; awaiting it makes such a failure fail the test
        // instead of being silently discarded by `abort`.
        join_handle.abort();
        match join_handle.await {
            Err(err) if err.is_cancelled() => {}
            Err(err) => panic!("background cleanup task failed: {}", err),
            Ok(_) => unreachable!("cleanup task loops until it is aborted"),
        }
    }
}