axum_session_mongo/
lib.rs

1#![doc = include_str!("../README.md")]
2#![allow(dead_code)]
3#![warn(clippy::all, nonstandard_style, future_incompatible)]
4#![forbid(unsafe_code)]
5
6use async_trait::async_trait;
7use axum_session::{DatabaseError, DatabasePool, Session, SessionStore};
8use chrono::Utc;
9use mongodb::{
10    bson::{doc, Document},
11    Client,
12};
13use serde::{Deserialize, Serialize};
14
15pub type SessionMongoSession = Session<SessionMongoPool>;
16pub type SessionMongoSessionStore = SessionStore<SessionMongoPool>;
17
18#[derive(Default, Debug, Serialize, Deserialize)]
19struct MongoSessionData {
20    id: String,
21    expires: i64,
22    session: String,
23}
24impl MongoSessionData {
25    fn to_document(&self) -> Document {
26        doc! {
27            "id": &self.id,
28            "expires": self.expires,
29            "session": &self.session
30        }
31    }
32}
33
34///Mongodb's Pool type for the DatabasePool. Needs a mongodb Client.
35#[derive(Debug, Clone)]
36pub struct SessionMongoPool {
37    client: Client,
38}
39
40impl From<Client> for SessionMongoPool {
41    fn from(client: Client) -> Self {
42        SessionMongoPool { client }
43    }
44}
45
46#[async_trait]
47impl DatabasePool for SessionMongoPool {
48    // Make sure the collection exists in the database
49    // by inserting a record then deleting it
50    async fn initiate(&self, table_name: &str) -> Result<(), DatabaseError> {
51        let tmp = MongoSessionData::default();
52
53        if let Some(db) = &self.client.default_database() {
54            let col = db.collection::<MongoSessionData>(table_name);
55
56            let _ = &col
57                .insert_one(&tmp)
58                .await
59                .map_err(|err| DatabaseError::GenericInsertError(err.to_string()))?;
60            let _ = col
61                .find_one_and_delete(tmp.to_document())
62                .await
63                .map_err(|err| DatabaseError::GenericDeleteError(err.to_string()))?;
64        }
65
66        Ok(())
67    }
68
69    async fn delete_by_expiry(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
70        let mut ids: Vec<String> = Vec::new();
71
72        if let Some(db) = &self.client.default_database() {
73            let now = Utc::now().timestamp();
74            let filter = doc! {"expires":
75                {"$lte": now}
76            };
77            let result = db
78                .collection::<MongoSessionData>(table_name)
79                .find(filter.clone())
80                .await
81                .map_err(|err| DatabaseError::GenericSelectError(err.to_string()))?;
82
83            for item in result.deserialize_current().iter() {
84                if !&item.id.is_empty() {
85                    ids.push(item.id.clone());
86                };
87            }
88            db.collection::<MongoSessionData>(table_name)
89                .delete_many(filter)
90                .await
91                .map_err(|err| DatabaseError::GenericDeleteError(err.to_string()))?;
92        }
93
94        Ok(ids)
95    }
96
97    async fn count(&self, table_name: &str) -> Result<i64, DatabaseError> {
98        Ok(match &self.client.default_database() {
99            Some(db) => db
100                .collection::<MongoSessionData>(table_name)
101                .estimated_document_count()
102                .await
103                .map_err(|err| DatabaseError::GenericSelectError(err.to_string()))?
104                as i64,
105            None => 0,
106        })
107    }
108
109    async fn store(
110        &self,
111        id: &str,
112        session: &str,
113        expires: i64,
114        table_name: &str,
115    ) -> Result<(), DatabaseError> {
116        if let Some(db) = &self.client.default_database() {
117            let filter = doc! {
118                "id": id
119            };
120            let update_data = doc! {"$set": {
121                "id": id.to_string(),
122                "expires": expires,
123                "session": session.to_string()
124            }};
125
126            db.collection::<MongoSessionData>(table_name)
127                .update_one(filter, update_data)
128                .upsert(true)
129                .await
130                .map_err(|err| DatabaseError::GenericInsertError(err.to_string()))?;
131        }
132
133        Ok(())
134    }
135
136    async fn load(&self, id: &str, table_name: &str) -> Result<Option<String>, DatabaseError> {
137        Ok(match &self.client.default_database() {
138            Some(db) => {
139                let filter = doc! {
140                    "id": id,
141                    "expires":
142                        {"$gte": Utc::now().timestamp()}
143                };
144                match db
145                    .collection::<MongoSessionData>(table_name)
146                    .find_one(filter)
147                    .await
148                    .unwrap_or_default()
149                {
150                    Some(result) => {
151                        if result.session.is_empty() {
152                            None
153                        } else {
154                            Some(result.session)
155                        }
156                    }
157                    None => None,
158                }
159            }
160            None => None,
161        })
162    }
163
164    async fn delete_one_by_id(&self, id: &str, table_name: &str) -> Result<(), DatabaseError> {
165        if let Some(db) = &self.client.default_database() {
166            let _ = db
167                .collection::<MongoSessionData>(table_name)
168                .delete_one(doc! {"id": id})
169                .await
170                .map_err(|err| DatabaseError::GenericDeleteError(err.to_string()))?;
171        }
172
173        Ok(())
174    }
175
176    async fn exists(&self, id: &str, table_name: &str) -> Result<bool, DatabaseError> {
177        Ok(match &self.client.default_database() {
178            Some(db) => db
179                .collection::<MongoSessionData>(table_name)
180                .find_one(doc! {"id": id})
181                .await
182                .map_err(|err| DatabaseError::GenericSelectError(err.to_string()))?
183                .is_some(),
184            None => false,
185        })
186    }
187
188    async fn delete_all(&self, table_name: &str) -> Result<(), DatabaseError> {
189        if let Some(db) = &self.client.default_database() {
190            let _ = db
191                .collection::<MongoSessionData>(table_name)
192                .drop()
193                .await
194                .map_err(|err| DatabaseError::GenericDeleteError(err.to_string()))?;
195        }
196
197        Ok(())
198    }
199
200    async fn get_ids(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
201        let mut ids: Vec<String> = Vec::new();
202        if let Some(db) = &self.client.default_database() {
203            let filter = doc! {"expires":
204                {"$gte": Utc::now().timestamp()}
205            };
206            let result = db
207                .collection::<MongoSessionData>(table_name)
208                .find(filter)
209                .await
210                .map_err(|err| DatabaseError::GenericSelectError(err.to_string()))?; // add filter for expiration
211
212            for item in result.deserialize_current().iter() {
213                if !&item.id.is_empty() {
214                    ids.push(item.id.clone());
215                };
216            }
217        }
218
219        Ok(ids)
220    }
221
222    fn auto_handles_expiry(&self) -> bool {
223        false
224    }
225}