use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use reinhardt_core::macros::model;
use reinhardt_db::DatabaseConnection;
use reinhardt_db::orm::{DatabaseBackend, Filter, FilterOperator, FilterValue, Model};
use reinhardt_query::prelude::{
Alias, ColumnDef, CreateIndexStatement, Expr, ExprTrait, IntoValue, MySqlQueryBuilder,
OnConflict, PostgresQueryBuilder, Query, QueryStatementBuilder, SqliteQueryBuilder,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::sessions::cleanup::{CleanupableBackend, SessionMetadata};
use super::cache::{SessionBackend, SessionError};
/// One row of the `sessions` table.
///
/// `session_data` holds the session payload as a JSON string (written with
/// `serde_json::to_string` and read back with `serde_json::from_str` by the backend below).
/// Every timestamp column is stored as milliseconds since the Unix epoch (UTC).
#[model(table_name = "sessions")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
// Primary key identifying the session (at most 255 characters).
#[field(primary_key = true, max_length = 255)]
pub session_key: String,
// JSON-serialized session payload.
#[field(max_length = 65535)]
pub session_data: String,
// Expiry instant in epoch milliseconds; the row counts as expired once this is reached.
#[field]
pub expire_date: i64,
// Instant of the first save in epoch milliseconds; upserts do not overwrite it.
#[field]
pub created_at: i64,
// Instant of the most recent save in epoch milliseconds; NULL in the table maps to `None`.
#[field]
pub last_accessed: Option<i64>,
}
/// Session backend that persists sessions in an SQL `sessions` table.
///
/// PostgreSQL, MySQL and SQLite are supported; SQL is rendered for whichever dialect the
/// connection reports. Cloning is cheap because clones share the connection through an `Arc`.
#[derive(Clone)]
pub struct DatabaseSessionBackend {
// Shared connection; its `backend()` selects the SQL dialect when statements are rendered.
connection: Arc<DatabaseConnection>,
}
impl DatabaseSessionBackend {
pub async fn new(database_url: &str) -> Result<Self, SessionError> {
let connection = DatabaseConnection::connect(database_url)
.await
.map_err(|e| SessionError::CacheError(format!("Database connection error: {}", e)))?;
Ok(Self {
connection: Arc::new(connection),
})
}
pub fn from_connection(connection: Arc<DatabaseConnection>) -> Self {
Self { connection }
}
fn build_sql<T>(&self, statement: T) -> String
where
T: QueryStatementBuilder,
{
match self.connection.backend() {
DatabaseBackend::Postgres => statement.to_string(PostgresQueryBuilder),
DatabaseBackend::MySql => statement.to_string(MySqlQueryBuilder),
DatabaseBackend::Sqlite => statement.to_string(SqliteQueryBuilder),
}
}
fn build_table_sql<T>(&self, statement: T) -> String
where
T: QueryStatementBuilder,
{
match self.connection.backend() {
DatabaseBackend::Postgres => statement.to_string(PostgresQueryBuilder),
DatabaseBackend::MySql => statement.to_string(MySqlQueryBuilder),
DatabaseBackend::Sqlite => statement.to_string(SqliteQueryBuilder),
}
}
fn build_index_sql(&self, statement: &CreateIndexStatement) -> String {
match self.connection.backend() {
DatabaseBackend::Postgres => statement.to_string(PostgresQueryBuilder),
DatabaseBackend::MySql => statement.to_string(MySqlQueryBuilder),
DatabaseBackend::Sqlite => statement.to_string(SqliteQueryBuilder),
}
}
pub async fn cleanup_expired(&self) -> Result<u64, SessionError> {
let now_timestamp = Utc::now().timestamp_millis();
let stmt = Query::delete()
.from_table(Alias::new("sessions"))
.and_where(Expr::col(Alias::new("expire_date")).lt(now_timestamp))
.to_owned();
let sql = self.build_sql(stmt);
let rows_affected =
self.connection.execute(&sql, vec![]).await.map_err(|e| {
SessionError::CacheError(format!("Failed to cleanup sessions: {}", e))
})?;
Ok(rows_affected)
}
pub async fn create_table(&self) -> Result<(), SessionError> {
let stmt = Query::create_table()
.table(Alias::new("sessions"))
.if_not_exists()
.col(
ColumnDef::new(Alias::new("session_key"))
.string_len(255)
.not_null(true)
.primary_key(true),
)
.col(
ColumnDef::new(Alias::new("session_data"))
.text()
.not_null(true),
)
.col(
ColumnDef::new(Alias::new("expire_date"))
.big_integer()
.not_null(true),
)
.col(
ColumnDef::new(Alias::new("created_at"))
.big_integer()
.not_null(true),
)
.col(ColumnDef::new(Alias::new("last_accessed")).big_integer())
.to_owned();
let sql = self.build_table_sql(stmt);
self.connection.execute(&sql, vec![]).await.map_err(|e| {
SessionError::CacheError(format!("Failed to create sessions table: {}", e))
})?;
let index_stmt = Query::create_index()
.if_not_exists()
.name("idx_sessions_expire_date")
.table(Alias::new("sessions"))
.col(Alias::new("expire_date"))
.to_owned();
let index_sql = self.build_index_sql(&index_stmt);
let _ = self.connection.execute(&index_sql, vec![]).await;
Ok(())
}
}
#[async_trait]
impl SessionBackend for DatabaseSessionBackend {
async fn load<T>(&self, session_key: &str) -> Result<Option<T>, SessionError>
where
T: for<'de> Deserialize<'de> + Send,
{
let session = Session::objects()
.filter_by(Filter::new(
"session_key".to_string(),
FilterOperator::Eq,
FilterValue::String(session_key.to_string()),
))
.first()
.await
.ok()
.flatten();
match session {
Some(session) => {
let expire_date =
DateTime::from_timestamp_millis(session.expire_date).unwrap_or_else(Utc::now);
if expire_date < Utc::now() {
let _ = self.delete(session_key).await;
return Ok(None);
}
let data: T = serde_json::from_str(&session.session_data).map_err(|e| {
SessionError::SerializationError(format!("Deserialization error: {}", e))
})?;
Ok(Some(data))
}
None => Ok(None),
}
}
async fn save<T>(
&self,
session_key: &str,
data: &T,
ttl: Option<u64>,
) -> Result<(), SessionError>
where
T: Serialize + Send + Sync,
{
let session_data = serde_json::to_string(data)
.map_err(|e| SessionError::SerializationError(format!("Serialization error: {}", e)))?;
let now = Utc::now();
let expire_date = match ttl {
Some(seconds) => now + Duration::seconds(seconds as i64),
None => now + Duration::days(14), };
let now_timestamp = now.timestamp_millis();
let expire_timestamp = expire_date.timestamp_millis();
let stmt = Query::insert()
.into_table(Alias::new("sessions"))
.columns([
Alias::new("session_key"),
Alias::new("session_data"),
Alias::new("expire_date"),
Alias::new("created_at"),
Alias::new("last_accessed"),
])
.values_panic(vec![
session_key.into_value(),
session_data.into_value(),
expire_timestamp.into_value(),
now_timestamp.into_value(),
now_timestamp.into_value(),
])
.on_conflict(
OnConflict::column(Alias::new("session_key"))
.update_columns([
Alias::new("session_data"),
Alias::new("expire_date"),
Alias::new("last_accessed"),
])
.to_owned(),
)
.to_owned();
let sql = self.build_sql(stmt);
self.connection
.execute(&sql, vec![])
.await
.map_err(|e| SessionError::CacheError(format!("Failed to save session: {}", e)))?;
Ok(())
}
async fn delete(&self, session_key: &str) -> Result<(), SessionError> {
let stmt = Query::delete()
.from_table(Alias::new("sessions"))
.and_where(Expr::col(Alias::new("session_key")).eq(session_key))
.to_owned();
let sql = self.build_sql(stmt);
self.connection
.execute(&sql, vec![])
.await
.map_err(|e| SessionError::CacheError(format!("Failed to delete session: {}", e)))?;
Ok(())
}
async fn exists(&self, session_key: &str) -> Result<bool, SessionError> {
let now_timestamp = Utc::now().timestamp_millis();
let session = Session::objects()
.filter_by(Filter::new(
"session_key".to_string(),
FilterOperator::Eq,
FilterValue::String(session_key.to_string()),
))
.filter(Filter::new(
"expire_date".to_string(),
FilterOperator::Gt,
FilterValue::Integer(now_timestamp),
))
.first()
.await
.ok()
.flatten();
Ok(session.is_some())
}
}
#[async_trait]
impl CleanupableBackend for DatabaseSessionBackend {
async fn get_all_keys(&self) -> Result<Vec<String>, SessionError> {
let sessions = Session::objects()
.all()
.all()
.await
.map_err(|e| SessionError::CacheError(format!("Failed to get all keys: {}", e)))?;
let keys: Vec<String> = sessions.into_iter().map(|s| s.session_key).collect();
Ok(keys)
}
async fn get_metadata(
&self,
session_key: &str,
) -> Result<Option<SessionMetadata>, SessionError> {
let session = Session::objects()
.filter_by(Filter::new(
"session_key".to_string(),
FilterOperator::Eq,
FilterValue::String(session_key.to_string()),
))
.first()
.await
.ok()
.flatten();
match session {
Some(session) => {
let created_at =
DateTime::from_timestamp_millis(session.created_at).unwrap_or_else(Utc::now);
let last_accessed = session
.last_accessed
.and_then(DateTime::from_timestamp_millis);
Ok(Some(SessionMetadata {
created_at,
last_accessed,
}))
}
None => Ok(None),
}
}
async fn list_keys_with_prefix(&self, prefix: &str) -> Result<Vec<String>, SessionError> {
let sessions = Session::objects()
.filter_by(Filter::new(
"session_key".to_string(),
FilterOperator::StartsWith,
FilterValue::String(prefix.to_string()),
))
.all()
.await
.map_err(|e| SessionError::CacheError(format!("Failed to list session keys: {}", e)))?;
let keys: Vec<String> = sessions.into_iter().map(|s| s.session_key).collect();
Ok(keys)
}
async fn count_keys_with_prefix(&self, prefix: &str) -> Result<usize, SessionError> {
let count = Session::objects()
.filter_by(Filter::new(
"session_key".to_string(),
FilterOperator::StartsWith,
FilterValue::String(prefix.to_string()),
))
.count()
.await
.map_err(|e| {
SessionError::CacheError(format!("Failed to count session keys: {}", e))
})?;
Ok(count)
}
async fn delete_keys_with_prefix(&self, prefix: &str) -> Result<usize, SessionError> {
let pattern = format!("{}%", prefix);
let stmt = Query::delete()
.from_table(Alias::new("sessions"))
.and_where(Expr::col(Alias::new("session_key")).like(pattern.as_str()))
.to_owned();
let sql = self.build_sql(stmt);
let rows_affected = self.connection.execute(&sql, vec![]).await.map_err(|e| {
SessionError::CacheError(format!("Failed to delete session keys: {}", e))
})?;
Ok(rows_affected as usize)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One hour in milliseconds, the unit every `Session` timestamp uses.
    const ONE_HOUR_MS: i64 = 3_600_000;

    /// Fixed epoch-millisecond instant for tests that need deterministic JSON input.
    const FIXED_MS: i64 = 1_700_000_000_000;

    #[test]
    fn test_session_struct_fields() {
        let stamp = Utc::now().timestamp_millis();
        let deadline = stamp + ONE_HOUR_MS;
        let payload = r#"{"user_id": 42}"#;

        let session = Session::new(
            "test_key".to_string(),
            payload.to_string(),
            deadline,
            stamp,
            Some(stamp),
        );

        assert_eq!(session.session_key, "test_key");
        assert_eq!(session.session_data, payload);
        assert_eq!(session.expire_date, deadline);
        assert_eq!(session.created_at, stamp);
        assert_eq!(session.last_accessed, Some(stamp));
    }

    #[test]
    fn test_session_struct_without_last_accessed() {
        let stamp = Utc::now().timestamp_millis();
        let session = Session::new("key".to_string(), "{}".to_string(), stamp + 1000, stamp, None);
        assert_eq!(session.last_accessed, None);
    }

    #[test]
    fn test_session_clone() {
        let stamp = Utc::now().timestamp_millis();
        let original = Session::new(
            "clone_test".to_string(),
            r#"{"data": "value"}"#.to_string(),
            stamp + ONE_HOUR_MS,
            stamp,
            Some(stamp),
        );

        let copy = original.clone();

        assert_eq!(
            (
                &copy.session_key,
                &copy.session_data,
                copy.expire_date,
                copy.created_at,
                copy.last_accessed,
            ),
            (
                &original.session_key,
                &original.session_data,
                original.expire_date,
                original.created_at,
                original.last_accessed,
            ),
        );
    }

    #[test]
    fn test_session_debug() {
        let stamp = Utc::now().timestamp_millis();
        let session = Session::new("debug_key".to_string(), "{}".to_string(), stamp, stamp, None);

        let rendered = format!("{:?}", session);
        for needle in ["Session", "debug_key"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_session_serialize() {
        let stamp = Utc::now().timestamp_millis();
        let session = Session::new(
            "serialize_key".to_string(),
            r#"{"count": 10}"#.to_string(),
            stamp + ONE_HOUR_MS,
            stamp,
            Some(stamp),
        );

        let encoded = serde_json::to_string(&session).unwrap();

        assert!(encoded.contains("serialize_key"));
        assert!(encoded.contains(r#"{\"count\": 10}"#));
    }

    #[test]
    fn test_session_deserialize() {
        let json = format!(
            r#"{{
"session_key": "deserialize_key",
"session_data": "{{\"user\": \"test\"}}",
"expire_date": {},
"created_at": {},
"last_accessed": {}
}}"#,
            FIXED_MS + ONE_HOUR_MS,
            FIXED_MS,
            FIXED_MS
        );

        let decoded: Session = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.session_key, "deserialize_key");
        assert_eq!(decoded.session_data, r#"{"user": "test"}"#);
        assert_eq!(decoded.expire_date, FIXED_MS + ONE_HOUR_MS);
        assert_eq!(decoded.created_at, FIXED_MS);
        assert_eq!(decoded.last_accessed, Some(FIXED_MS));
    }

    #[test]
    fn test_session_deserialize_without_last_accessed() {
        let json = format!(
            r#"{{
"session_key": "no_access",
"session_data": "{{}}",
"expire_date": {},
"created_at": {},
"last_accessed": null
}}"#,
            FIXED_MS + ONE_HOUR_MS,
            FIXED_MS
        );

        let decoded: Session = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.session_key, "no_access");
        assert_eq!(decoded.last_accessed, None);
    }

    #[test]
    fn test_database_session_backend_clone() {
        fn require_clone<C: Clone>() {}
        require_clone::<DatabaseSessionBackend>();
    }
}