use std::collections::HashMap;
use std::error::Error;
use async_trait::async_trait;
use thiserror::Error;
use tower_sessions::session::{Id, Record};
use tower_sessions::{SessionStore, session_store};
use crate::db::{Auto, Database, DatabaseError, Model, query};
use crate::session::db::Session;
use crate::session::store::{ERROR_PREFIX, MAX_COLLISION_RETRIES};
use crate::utils::chrono::DateTimeWithOffsetAdapter;
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum DbStoreError {
#[error("{ERROR_PREFIX} {0} ")]
DatabaseError(#[from] DatabaseError),
#[error("{ERROR_PREFIX} sessionāid collision retried too many times ({0})")]
TooManyIdCollisions(u32),
#[error("{ERROR_PREFIX} JSON serialization error: {0}")]
Serialize(Box<dyn Error + Send + Sync>),
#[error("{ERROR_PREFIX} JSON serialization error: {0}")]
Deserialize(Box<dyn Error + Send + Sync>),
}
impl From<DbStoreError> for session_store::Error {
    /// Maps a [`DbStoreError`] onto the generic `tower_sessions` error kinds.
    ///
    /// Serialization failures become `Encode` and deserialization failures
    /// become `Decode`. Database failures and ID-collision exhaustion are
    /// reported as `Backend`.
    fn from(err: DbStoreError) -> Self {
        // Deliberately exhaustive (no `_` arm): `#[non_exhaustive]` does not apply
        // inside this crate, so adding a variant forces a decision here.
        match err {
            DbStoreError::DatabaseError(db_err) => Self::Backend(db_err.to_string()),
            DbStoreError::Serialize(ser_err) => Self::Encode(ser_err.to_string()),
            DbStoreError::Deserialize(de_err) => Self::Decode(de_err.to_string()),
            err @ DbStoreError::TooManyIdCollisions(_) => Self::Backend(err.to_string()),
        }
    }
}
/// A [`SessionStore`] that keeps session records as rows of the [`Session`]
/// model in the application database.
#[derive(Clone, Debug)]
pub struct DbStore {
    /// Database handle used for every query the store issues.
    connection: Database,
}
impl DbStore {
#[must_use]
pub fn new(connection: Database) -> DbStore {
DbStore { connection }
}
}
#[async_trait]
impl SessionStore for DbStore {
async fn create(&self, record: &mut Record) -> session_store::Result<()> {
for _ in 0..=MAX_COLLISION_RETRIES {
let key = record.id.to_string();
let data = serde_json::to_string(&record.data).unwrap();
let expiry = DateTimeWithOffsetAdapter::try_from(record.expiry_date)
.expect("Failed to convert expiry date to a valid datetime")
.into_chrono_db_safe();
let mut model = Session {
id: Auto::auto(),
key,
data,
expiry,
};
let res = self.connection.insert(&mut model).await;
match res {
Ok(()) => {
return Ok(());
}
Err(DatabaseError::UniqueViolation) => {
record.id = Id::default();
}
Err(err) => return Err(DbStoreError::DatabaseError(err))?,
}
}
Err(DbStoreError::TooManyIdCollisions(MAX_COLLISION_RETRIES))?
}
async fn save(&self, record: &Record) -> session_store::Result<()> {
let key = record.id.to_string();
let data = serde_json::to_string(&record.data)
.map_err(|err| DbStoreError::Serialize(Box::new(err)))?;
let query = query!(Session, $key == key)
.get(&self.connection)
.await
.map_err(DbStoreError::DatabaseError)?;
if let Some(mut model) = query {
model.data = data;
model
.update(&self.connection)
.await
.map_err(DbStoreError::DatabaseError)?;
} else {
let mut record = record.clone();
self.create(&mut record).await?;
}
Ok(())
}
async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>> {
let key = session_id.to_string();
let query = query!(Session, $key == key)
.get(&self.connection)
.await
.map_err(DbStoreError::DatabaseError)?;
if let Some(session) = query {
let data = serde_json::from_str::<HashMap<String, serde_json::Value>>(&session.data)
.map_err(|err| DbStoreError::Serialize(Box::new(err)))?;
let id = session
.key
.parse::<Id>()
.map_err(|err| DbStoreError::Deserialize(Box::new(err)))?;
let expiry_date = DateTimeWithOffsetAdapter::new(session.expiry).into_offsetdatetime();
let rec = Record {
id,
data,
expiry_date,
};
Ok(Some(rec))
} else {
Ok(None)
}
}
async fn delete(&self, session_id: &Id) -> session_store::Result<()> {
let key = session_id.to_string();
query!(Session, $key == key)
.delete(&self.connection)
.await
.map_err(DbStoreError::DatabaseError)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use std::io;

    use tower_sessions::session::Record;

    use super::*;
    use crate::db::DatabaseError;

    /// Every `DbStoreError` variant must convert into the matching
    /// `session_store::Error` category.
    #[cot::test]
    async fn test_from_db_store_error_to_session_store_error() {
        // Database failures -> Backend.
        let engine_failure =
            DatabaseError::DatabaseEngineError(sqlx::Error::Protocol("protocol error".into()));
        let backend = session_store::Error::from(DbStoreError::DatabaseError(engine_failure));
        assert!(matches!(backend, session_store::Error::Backend(_)));

        // Serialization failures -> Encode.
        let encode = session_store::Error::from(DbStoreError::Serialize(Box::new(
            io::Error::other("oops"),
        )));
        assert!(matches!(encode, session_store::Error::Encode(_)));

        // Deserialization failures -> Decode.
        let json_failure = serde_json::from_str::<Record>("not a json").unwrap_err();
        let decode = session_store::Error::from(DbStoreError::Deserialize(Box::new(json_failure)));
        assert!(matches!(decode, session_store::Error::Decode(_)));

        // Exhausted ID-collision retries -> Backend.
        let collisions = session_store::Error::from(DbStoreError::TooManyIdCollisions(99));
        assert!(matches!(collisions, session_store::Error::Backend(_)));
    }
}