axum_session_mongo/
lib.rs1#![doc = include_str!("../README.md")]
2#![allow(dead_code)]
3#![warn(clippy::all, nonstandard_style, future_incompatible)]
4#![forbid(unsafe_code)]
5
6use async_trait::async_trait;
7use axum_session::{DatabaseError, DatabasePool, Session, SessionStore};
8use chrono::Utc;
9use mongodb::{
10 bson::{doc, Document},
11 Client,
12};
13use serde::{Deserialize, Serialize};
14
15pub type SessionMongoSession = Session<SessionMongoPool>;
16pub type SessionMongoSessionStore = SessionStore<SessionMongoPool>;
17
18#[derive(Default, Debug, Serialize, Deserialize)]
19struct MongoSessionData {
20 id: String,
21 expires: i64,
22 session: String,
23}
24impl MongoSessionData {
25 fn to_document(&self) -> Document {
26 doc! {
27 "id": &self.id,
28 "expires": self.expires,
29 "session": &self.session
30 }
31 }
32}
33
34#[derive(Debug, Clone)]
36pub struct SessionMongoPool {
37 client: Client,
38}
39
40impl From<Client> for SessionMongoPool {
41 fn from(client: Client) -> Self {
42 SessionMongoPool { client }
43 }
44}
45
46#[async_trait]
47impl DatabasePool for SessionMongoPool {
48 async fn initiate(&self, table_name: &str) -> Result<(), DatabaseError> {
51 let tmp = MongoSessionData::default();
52
53 if let Some(db) = &self.client.default_database() {
54 let col = db.collection::<MongoSessionData>(table_name);
55
56 let _ = &col
57 .insert_one(&tmp)
58 .await
59 .map_err(|err| DatabaseError::GenericInsertError(err.to_string()))?;
60 let _ = col
61 .find_one_and_delete(tmp.to_document())
62 .await
63 .map_err(|err| DatabaseError::GenericDeleteError(err.to_string()))?;
64 }
65
66 Ok(())
67 }
68
69 async fn delete_by_expiry(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
70 let mut ids: Vec<String> = Vec::new();
71
72 if let Some(db) = &self.client.default_database() {
73 let now = Utc::now().timestamp();
74 let filter = doc! {"expires":
75 {"$lte": now}
76 };
77 let result = db
78 .collection::<MongoSessionData>(table_name)
79 .find(filter.clone())
80 .await
81 .map_err(|err| DatabaseError::GenericSelectError(err.to_string()))?;
82
83 for item in result.deserialize_current().iter() {
84 if !&item.id.is_empty() {
85 ids.push(item.id.clone());
86 };
87 }
88 db.collection::<MongoSessionData>(table_name)
89 .delete_many(filter)
90 .await
91 .map_err(|err| DatabaseError::GenericDeleteError(err.to_string()))?;
92 }
93
94 Ok(ids)
95 }
96
97 async fn count(&self, table_name: &str) -> Result<i64, DatabaseError> {
98 Ok(match &self.client.default_database() {
99 Some(db) => db
100 .collection::<MongoSessionData>(table_name)
101 .estimated_document_count()
102 .await
103 .map_err(|err| DatabaseError::GenericSelectError(err.to_string()))?
104 as i64,
105 None => 0,
106 })
107 }
108
109 async fn store(
110 &self,
111 id: &str,
112 session: &str,
113 expires: i64,
114 table_name: &str,
115 ) -> Result<(), DatabaseError> {
116 if let Some(db) = &self.client.default_database() {
117 let filter = doc! {
118 "id": id
119 };
120 let update_data = doc! {"$set": {
121 "id": id.to_string(),
122 "expires": expires,
123 "session": session.to_string()
124 }};
125
126 db.collection::<MongoSessionData>(table_name)
127 .update_one(filter, update_data)
128 .upsert(true)
129 .await
130 .map_err(|err| DatabaseError::GenericInsertError(err.to_string()))?;
131 }
132
133 Ok(())
134 }
135
136 async fn load(&self, id: &str, table_name: &str) -> Result<Option<String>, DatabaseError> {
137 Ok(match &self.client.default_database() {
138 Some(db) => {
139 let filter = doc! {
140 "id": id,
141 "expires":
142 {"$gte": Utc::now().timestamp()}
143 };
144 match db
145 .collection::<MongoSessionData>(table_name)
146 .find_one(filter)
147 .await
148 .unwrap_or_default()
149 {
150 Some(result) => {
151 if result.session.is_empty() {
152 None
153 } else {
154 Some(result.session)
155 }
156 }
157 None => None,
158 }
159 }
160 None => None,
161 })
162 }
163
164 async fn delete_one_by_id(&self, id: &str, table_name: &str) -> Result<(), DatabaseError> {
165 if let Some(db) = &self.client.default_database() {
166 let _ = db
167 .collection::<MongoSessionData>(table_name)
168 .delete_one(doc! {"id": id})
169 .await
170 .map_err(|err| DatabaseError::GenericDeleteError(err.to_string()))?;
171 }
172
173 Ok(())
174 }
175
176 async fn exists(&self, id: &str, table_name: &str) -> Result<bool, DatabaseError> {
177 Ok(match &self.client.default_database() {
178 Some(db) => db
179 .collection::<MongoSessionData>(table_name)
180 .find_one(doc! {"id": id})
181 .await
182 .map_err(|err| DatabaseError::GenericSelectError(err.to_string()))?
183 .is_some(),
184 None => false,
185 })
186 }
187
188 async fn delete_all(&self, table_name: &str) -> Result<(), DatabaseError> {
189 if let Some(db) = &self.client.default_database() {
190 let _ = db
191 .collection::<MongoSessionData>(table_name)
192 .drop()
193 .await
194 .map_err(|err| DatabaseError::GenericDeleteError(err.to_string()))?;
195 }
196
197 Ok(())
198 }
199
200 async fn get_ids(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
201 let mut ids: Vec<String> = Vec::new();
202 if let Some(db) = &self.client.default_database() {
203 let filter = doc! {"expires":
204 {"$gte": Utc::now().timestamp()}
205 };
206 let result = db
207 .collection::<MongoSessionData>(table_name)
208 .find(filter)
209 .await
210 .map_err(|err| DatabaseError::GenericSelectError(err.to_string()))?; for item in result.deserialize_current().iter() {
213 if !&item.id.is_empty() {
214 ids.push(item.id.clone());
215 };
216 }
217 }
218
219 Ok(ids)
220 }
221
222 fn auto_handles_expiry(&self) -> bool {
223 false
224 }
225}