use axum::{extract::FromRequestParts, http::request::Parts};
use serde::{Deserialize, Serialize};
use tower_sessions::Session;
use crate::error::Error;
/// Session key under which the pending `Vec<FlashMessage>` queue is stored.
const FLASH_SESSION_KEY: &str = "_flash_messages";
/// Category of a flash message.
///
/// Serialized with serde in `snake_case` (`"success"`, `"info"`, `"warning"`, `"error"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FlashKind {
    /// Success-level message.
    Success,
    /// Informational message.
    Info,
    /// Warning-level message.
    Warning,
    /// Error-level message.
    Error,
}
impl FlashKind {
    /// Static presentation data for this kind, as `(css_class, icon)`.
    ///
    /// Keeping both strings in one table keeps the two accessors below in sync.
    fn presentation(self) -> (&'static str, &'static str) {
        match self {
            Self::Success => ("flash-success", "check-circle"),
            Self::Info => ("flash-info", "info-circle"),
            Self::Warning => ("flash-warning", "exclamation-triangle"),
            Self::Error => ("flash-error", "x-circle"),
        }
    }

    /// CSS class name for rendering a message of this kind (e.g. `"flash-error"`).
    #[must_use]
    pub fn css_class(&self) -> &'static str {
        let (class, _icon) = self.presentation();
        class
    }

    /// Icon identifier for a message of this kind (e.g. `"x-circle"`).
    #[must_use]
    pub fn icon(&self) -> &'static str {
        let (_class, icon) = self.presentation();
        icon
    }
}
/// A single flash notification: a [`FlashKind`] plus its text.
///
/// Serialized with serde so it can be stored in the session between requests.
/// `PartialEq`/`Eq` let callers and tests compare messages directly.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FlashMessage {
    /// Category of the message (controls CSS class / icon via [`FlashKind`]).
    pub kind: FlashKind,
    /// Message text.
    pub message: String,
}
impl FlashMessage {
    /// Builds a message of the given `kind` from anything convertible into a `String`.
    #[must_use]
    pub fn new(kind: FlashKind, message: impl Into<String>) -> Self {
        let text: String = message.into();
        Self {
            kind,
            message: text,
        }
    }

    /// Shorthand for `FlashMessage::new(FlashKind::Success, message)`.
    #[must_use]
    pub fn success(message: impl Into<String>) -> Self {
        Self::new(FlashKind::Success, message)
    }

    /// Shorthand for `FlashMessage::new(FlashKind::Info, message)`.
    #[must_use]
    pub fn info(message: impl Into<String>) -> Self {
        Self::new(FlashKind::Info, message)
    }

    /// Shorthand for `FlashMessage::new(FlashKind::Warning, message)`.
    #[must_use]
    pub fn warning(message: impl Into<String>) -> Self {
        Self::new(FlashKind::Warning, message)
    }

    /// Shorthand for `FlashMessage::new(FlashKind::Error, message)`.
    #[must_use]
    pub fn error(message: impl Into<String>) -> Self {
        Self::new(FlashKind::Error, message)
    }
}
/// Flash messages taken out of the session for the current request.
///
/// Produced by the [`FromRequestParts`] extractor, which removes the stored queue
/// from the session. New messages are queued with [`FlashMessages::push`] /
/// [`FlashMessages::push_many`].
#[derive(Debug, Clone, Default)]
pub struct FlashMessages {
    /// Messages in the order they were queued.
    messages: Vec<FlashMessage>,
}
impl FlashMessages {
#[must_use]
pub fn messages(&self) -> &[FlashMessage] {
&self.messages
}
#[must_use]
pub fn into_messages(self) -> Vec<FlashMessage> {
self.messages
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
#[must_use]
pub fn len(&self) -> usize {
self.messages.len()
}
#[must_use]
pub fn by_kind(&self, kind: FlashKind) -> Vec<&FlashMessage> {
self.messages.iter().filter(|m| m.kind == kind).collect()
}
#[must_use]
pub fn has_success(&self) -> bool {
self.messages.iter().any(|m| m.kind == FlashKind::Success)
}
#[must_use]
pub fn has_errors(&self) -> bool {
self.messages.iter().any(|m| m.kind == FlashKind::Error)
}
pub async fn push(session: &Session, message: FlashMessage) -> Result<(), Error> {
let mut messages: Vec<FlashMessage> = session
.get(FLASH_SESSION_KEY)
.await
.map_err(|e| Error::Session(format!("Failed to read flash messages: {e}")))?
.unwrap_or_default();
messages.push(message);
session
.insert(FLASH_SESSION_KEY, &messages)
.await
.map_err(|e| Error::Session(format!("Failed to write flash messages: {e}")))
}
pub async fn push_many(
session: &Session,
new_messages: impl IntoIterator<Item = FlashMessage>,
) -> Result<(), Error> {
let mut messages: Vec<FlashMessage> = session
.get(FLASH_SESSION_KEY)
.await
.map_err(|e| Error::Session(format!("Failed to read flash messages: {e}")))?
.unwrap_or_default();
messages.extend(new_messages);
session
.insert(FLASH_SESSION_KEY, &messages)
.await
.map_err(|e| Error::Session(format!("Failed to write flash messages: {e}")))
}
pub async fn clear(session: &Session) -> Result<(), Error> {
session
.remove::<Vec<FlashMessage>>(FLASH_SESSION_KEY)
.await
.map_err(|e| Error::Session(format!("Failed to clear flash messages: {e}")))?;
Ok(())
}
}
impl<S> FromRequestParts<S> for FlashMessages
where
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let session = parts.extensions.get::<Session>().cloned().ok_or_else(|| {
Error::Session("Session not found in request extensions for flash messages".to_string())
})?;
let messages: Vec<FlashMessage> = session
.remove(FLASH_SESSION_KEY)
.await
.map_err(|e| Error::Session(format!("Failed to read flash messages: {e}")))?
.unwrap_or_default();
Ok(Self { messages })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flash_message_constructors() {
        let success = FlashMessage::success("Done!");
        assert_eq!(success.kind, FlashKind::Success);
        assert_eq!(success.message, "Done!");

        let error = FlashMessage::error("Failed");
        assert_eq!(error.kind, FlashKind::Error);
        assert_eq!(error.message, "Failed");

        let info = FlashMessage::info("FYI");
        assert_eq!(info.kind, FlashKind::Info);
        assert_eq!(info.message, "FYI");

        let warning = FlashMessage::warning(String::from("Careful"));
        assert_eq!(warning.kind, FlashKind::Warning);
        assert_eq!(warning.message, "Careful");

        let custom = FlashMessage::new(FlashKind::Warning, "custom");
        assert_eq!(custom.kind, FlashKind::Warning);
        assert_eq!(custom.message, "custom");
    }

    #[test]
    fn test_flash_kind_css_class() {
        assert_eq!(FlashKind::Success.css_class(), "flash-success");
        assert_eq!(FlashKind::Error.css_class(), "flash-error");
        assert_eq!(FlashKind::Warning.css_class(), "flash-warning");
        assert_eq!(FlashKind::Info.css_class(), "flash-info");
    }

    #[test]
    fn test_flash_kind_icon() {
        assert_eq!(FlashKind::Success.icon(), "check-circle");
        assert_eq!(FlashKind::Info.icon(), "info-circle");
        assert_eq!(FlashKind::Warning.icon(), "exclamation-triangle");
        assert_eq!(FlashKind::Error.icon(), "x-circle");
    }

    #[test]
    fn test_flash_messages_filtering() {
        let messages = FlashMessages {
            messages: vec![
                FlashMessage::success("OK"),
                FlashMessage::error("Bad"),
                FlashMessage::success("Also OK"),
            ],
        };
        assert_eq!(messages.len(), 3);
        assert!(!messages.is_empty());
        assert!(messages.has_success());
        assert!(messages.has_errors());

        let successes = messages.by_kind(FlashKind::Success);
        assert_eq!(successes.len(), 2);
        // Filtering must preserve queue order.
        assert_eq!(successes[0].message, "OK");
        assert_eq!(successes[1].message, "Also OK");

        assert_eq!(messages.by_kind(FlashKind::Error).len(), 1);
        assert!(messages.by_kind(FlashKind::Warning).is_empty());
    }

    #[test]
    fn test_flash_messages_empty() {
        let messages = FlashMessages {
            messages: Vec::new(),
        };
        assert!(messages.is_empty());
        assert_eq!(messages.len(), 0);
        assert!(!messages.has_success());
        assert!(!messages.has_errors());
        assert!(messages.by_kind(FlashKind::Info).is_empty());
        assert!(messages.messages().is_empty());
    }

    #[test]
    fn test_into_messages_preserves_order() {
        let messages = FlashMessages {
            messages: vec![
                FlashMessage::info("first"),
                FlashMessage::warning("second"),
            ],
        };
        assert_eq!(messages.messages().len(), 2);

        let owned = messages.into_messages();
        assert_eq!(owned.len(), 2);
        assert_eq!(owned[0].kind, FlashKind::Info);
        assert_eq!(owned[0].message, "first");
        assert_eq!(owned[1].kind, FlashKind::Warning);
        assert_eq!(owned[1].message, "second");
    }
}