use async_trait::async_trait;
use hyper::header::COOKIE;
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::session::SessionData;
/// HTTP header name `X-Messages`.
///
/// NOTE(review): nothing in this file reads or writes this header (in particular
/// `MessageMiddleware::process` does not attach messages to responses); presumably
/// intended for transporting queued messages — confirm against other modules.
pub const MESSAGE_HEADER: &str = "X-Messages";
/// Severity of a [`Message`].
///
/// Variants are declared from least to most severe, so the derived
/// `PartialOrd`/`Ord` order them `Debug < Info < Success < Warning < Error`.
/// This lets callers filter by threshold, e.g. `msg.level >= MessageLevel::Warning`.
/// `Hash` is derived so levels can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum MessageLevel {
    /// Debug-level message.
    Debug,
    /// Informational message.
    Info,
    /// Message reporting a successful operation.
    Success,
    /// Warning message.
    Warning,
    /// Error message.
    Error,
}
/// A single message queued for a session: a severity level plus its text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Severity of the message.
pub level: MessageLevel,
/// Message text.
pub text: String,
}
impl Message {
    /// Creates a message with the given severity `level` and `text`.
    ///
    /// `text` accepts anything convertible into a `String` (`String`, `&str`, ...),
    /// so callers no longer need an explicit `.to_string()`.
    pub fn new(level: MessageLevel, text: impl Into<String>) -> Self {
        Self {
            level,
            text: text.into(),
        }
    }

    /// Creates a [`MessageLevel::Debug`] message.
    pub fn debug(text: impl Into<String>) -> Self {
        Self::new(MessageLevel::Debug, text)
    }

    /// Creates a [`MessageLevel::Info`] message.
    pub fn info(text: impl Into<String>) -> Self {
        Self::new(MessageLevel::Info, text)
    }

    /// Creates a [`MessageLevel::Success`] message.
    pub fn success(text: impl Into<String>) -> Self {
        Self::new(MessageLevel::Success, text)
    }

    /// Creates a [`MessageLevel::Warning`] message.
    pub fn warning(text: impl Into<String>) -> Self {
        Self::new(MessageLevel::Warning, text)
    }

    /// Creates a [`MessageLevel::Error`] message.
    pub fn error(text: impl Into<String>) -> Self {
        Self::new(MessageLevel::Error, text)
    }
}
/// Backend that queues [`Message`]s per session id.
///
/// The `Send + Sync` supertraits allow a storage to be shared across threads as
/// `Arc<dyn MessageStorage>`, which is how `MessageMiddleware` holds it.
pub trait MessageStorage: Send + Sync {
/// Appends `message` to the queue for `session_id` (the storages in this module
/// keep insertion order).
fn add_message(&self, session_id: &str, message: Message);
/// Returns the queued messages for `session_id` and removes them from storage;
/// returns an empty `Vec` if nothing is queued.
fn get_and_clear_messages(&self, session_id: &str) -> Vec<Message>;
/// Returns a copy of the queued messages for `session_id` without removing them;
/// returns an empty `Vec` if nothing is queued.
fn get_messages(&self, session_id: &str) -> Vec<Message>;
}
/// In-memory [`MessageStorage`] keeping each session's pending messages in a map
/// keyed by session id, guarded by a reader-writer lock.
///
/// A poisoned lock is not treated as fatal: the guard is recovered from the
/// `PoisonError` and the map is used as-is.
#[derive(Default)]
pub struct SessionStorage {
    messages: Arc<RwLock<HashMap<String, Vec<Message>>>>,
}

impl SessionStorage {
    /// Creates an empty storage.
    pub fn new() -> Self {
        Self::default()
    }
}

impl MessageStorage for SessionStorage {
    fn add_message(&self, session_id: &str, message: Message) {
        let mut map = match self.messages.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let queue = map.entry(session_id.to_owned()).or_default();
        queue.push(message);
    }

    fn get_and_clear_messages(&self, session_id: &str) -> Vec<Message> {
        let mut map = match self.messages.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        map.remove(session_id).unwrap_or_default()
    }

    fn get_messages(&self, session_id: &str) -> Vec<Message> {
        let map = match self.messages.read() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        map.get(session_id)
            .map(|queue| queue.to_vec())
            .unwrap_or_default()
    }
}
/// [`MessageStorage`] holding messages in process memory, keyed by session id.
///
/// Despite its name, this type never reads or writes cookies itself; its
/// behaviour is the same lock-guarded map as `SessionStorage`. A poisoned lock
/// is recovered rather than propagated.
pub struct CookieStorage {
    messages: Arc<RwLock<HashMap<String, Vec<Message>>>>,
}

impl CookieStorage {
    /// Creates an empty storage.
    pub fn new() -> Self {
        Self {
            messages: Arc::default(),
        }
    }
}

impl Default for CookieStorage {
    fn default() -> Self {
        Self::new()
    }
}

impl MessageStorage for CookieStorage {
    fn add_message(&self, session_id: &str, message: Message) {
        // The write guard is a temporary and is released at the end of this statement.
        self.messages
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .entry(session_id.to_string())
            .or_default()
            .push(message);
    }

    fn get_and_clear_messages(&self, session_id: &str) -> Vec<Message> {
        let taken = self
            .messages
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(session_id);
        taken.unwrap_or_default()
    }

    fn get_messages(&self, session_id: &str) -> Vec<Message> {
        self.messages
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(session_id)
            .cloned()
            .unwrap_or_default()
    }
}
/// Middleware that resolves a session id for each request and forwards the
/// request to the next handler.
// `storage` is only written (in `MessageMiddleware::new`); no code in this module
// reads the field back yet, so the dead-code lint is silenced for it.
#[allow(dead_code)]
pub struct MessageMiddleware {
/// Message storage handed in at construction; not consulted by `process` yet.
storage: Arc<dyn MessageStorage>,
}
impl MessageMiddleware {
    /// Creates a middleware backed by the given message storage.
    pub fn new(storage: Arc<dyn MessageStorage>) -> Self {
        Self { storage }
    }

    /// Resolves the key under which messages for `request` are stored.
    ///
    /// Resolution order:
    /// 1. the `id` of a [`SessionData`] extension attached to the request;
    /// 2. the first non-empty `sessionid` cookie found across *all* `Cookie`
    ///    header fields (HTTP/2 permits a client to split cookies over several
    ///    fields, so only looking at the first one can miss the session);
    /// 3. the literal `"default"`.
    ///
    /// Every request without a resolvable session shares the `"default"` key.
    fn get_session_id(request: &Request) -> String {
        if let Some(session_data) = request.extensions.get::<SessionData>() {
            return session_data.id.clone();
        }
        request
            .headers
            .get_all(COOKIE)
            .iter()
            // Header values that are not visible ASCII cannot hold our cookie; skip them.
            .filter_map(|value| value.to_str().ok())
            .find_map(Self::session_cookie_value)
            .map(String::from)
            .unwrap_or_else(|| "default".to_string())
    }

    /// Extracts the `sessionid` value from one `Cookie` header string
    /// (`name=value` pairs separated by `;`).
    ///
    /// Whitespace around names and values is ignored; returns `None` when the
    /// cookie is absent or only present with an empty value.
    fn session_cookie_value(cookies: &str) -> Option<&str> {
        cookies.split(';').find_map(|pair| {
            let (name, value) = pair.split_once('=')?;
            let value = value.trim();
            (name.trim() == "sessionid" && !value.is_empty()).then_some(value)
        })
    }
}
#[async_trait]
impl Middleware for MessageMiddleware {
    /// Forwards `request` to `handler` and returns its response.
    ///
    /// A handler error is not propagated: it is converted into a response with
    /// `Response::from`, so this method always returns `Ok`. The session id is
    /// resolved before the request is handed off but is not otherwise used yet.
    async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
        // Must be computed before `request` is moved into the handler.
        let _session_id = Self::get_session_id(&request);
        let outcome = handler.handle(request).await;
        Ok(outcome.unwrap_or_else(Response::from))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};

    /// Builds a `GET /page` HTTP/1.1 request with an empty body and, when
    /// `cookie` is given, a single `Cookie` header carrying that raw value.
    fn page_request(cookie: Option<&str>) -> Request {
        let mut headers = HeaderMap::new();
        if let Some(raw) = cookie {
            headers.insert(COOKIE, raw.parse().unwrap());
        }
        Request::builder()
            .method(Method::GET)
            .uri("/page")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Creates a middleware and a [`TestHandler`] that share `storage`.
    fn middleware_and_handler(
        storage: &Arc<dyn MessageStorage>,
    ) -> (MessageMiddleware, Arc<TestHandler>) {
        let middleware = MessageMiddleware::new(Arc::clone(storage));
        let handler = Arc::new(TestHandler {
            storage: Arc::clone(storage),
        });
        (middleware, handler)
    }

    #[test]
    fn test_message_creation() {
        let cases = [
            (Message::debug("Debug message".to_string()), MessageLevel::Debug),
            (Message::info("Info message".to_string()), MessageLevel::Info),
            (Message::success("Success message".to_string()), MessageLevel::Success),
            (Message::warning("Warning message".to_string()), MessageLevel::Warning),
            (Message::error("Error message".to_string()), MessageLevel::Error),
        ];
        for (msg, expected) in cases {
            assert_eq!(msg.level, expected);
        }
    }

    #[test]
    fn test_session_storage_add_and_get() {
        let store = SessionStorage::new();
        let sid = "test-session";
        store.add_message(sid, Message::info("Message 1".to_string()));
        store.add_message(sid, Message::success("Message 2".to_string()));
        let levels: Vec<MessageLevel> = store.get_messages(sid).iter().map(|m| m.level).collect();
        assert_eq!(levels, vec![MessageLevel::Info, MessageLevel::Success]);
    }

    #[test]
    fn test_session_storage_clear() {
        let store = SessionStorage::new();
        let sid = "test-session";
        for text in ["Message 1", "Message 2"] {
            store.add_message(sid, Message::info(text.to_string()));
        }
        assert_eq!(store.get_and_clear_messages(sid).len(), 2);
        assert!(store.get_messages(sid).is_empty());
    }

    #[test]
    fn test_cookie_storage_add_and_get() {
        let store = CookieStorage::new();
        let sid = "test-session";
        store.add_message(sid, Message::warning("Warning 1".to_string()));
        store.add_message(sid, Message::error("Error 1".to_string()));
        let levels: Vec<MessageLevel> = store.get_messages(sid).iter().map(|m| m.level).collect();
        assert_eq!(levels, vec![MessageLevel::Warning, MessageLevel::Error]);
    }

    #[test]
    fn test_cookie_storage_clear() {
        let store = CookieStorage::new();
        let sid = "test-session";
        store.add_message(sid, Message::info("Info 1".to_string()));
        assert_eq!(store.get_and_clear_messages(sid).len(), 1);
        assert!(store.get_messages(sid).is_empty());
    }

    #[test]
    fn test_separate_sessions() {
        let store = SessionStorage::new();
        store.add_message("session1", Message::info("Session 1 message".to_string()));
        store.add_message(
            "session2",
            Message::success("Session 2 message".to_string()),
        );
        let first = store.get_messages("session1");
        let second = store.get_messages("session2");
        assert_eq!(first.len(), 1);
        assert_eq!(second.len(), 1);
        assert_eq!(first[0].level, MessageLevel::Info);
        assert_eq!(second[0].level, MessageLevel::Success);
    }

    /// Handler that records one success message under the request's session id.
    struct TestHandler {
        storage: Arc<dyn MessageStorage>,
    }

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let sid = MessageMiddleware::get_session_id(&request);
            self.storage
                .add_message(&sid, Message::success("Test message".to_string()));
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
        }
    }

    #[tokio::test]
    async fn test_middleware_with_session_storage() {
        let storage: Arc<dyn MessageStorage> = Arc::new(SessionStorage::new());
        let (middleware, handler) = middleware_and_handler(&storage);
        let response = middleware
            .process(page_request(Some("sessionid=test-session")), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        let drained = storage.get_and_clear_messages("test-session");
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0].level, MessageLevel::Success);
    }

    #[tokio::test]
    async fn test_middleware_default_session() {
        let storage: Arc<dyn MessageStorage> = Arc::new(SessionStorage::new());
        let (middleware, handler) = middleware_and_handler(&storage);
        let response = middleware
            .process(page_request(None), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(storage.get_messages("default").len(), 1);
    }

    #[tokio::test]
    async fn test_middleware_with_cookie_storage() {
        let storage: Arc<dyn MessageStorage> = Arc::new(CookieStorage::new());
        let (middleware, handler) = middleware_and_handler(&storage);
        let response = middleware
            .process(page_request(Some("sessionid=cookie-session")), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        let drained = storage.get_and_clear_messages("cookie-session");
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0].level, MessageLevel::Success);
    }
}