use axum::extract::FromRequestParts;
use http::request::Parts;
use serde::{Deserialize, Serialize};
use crate::session::Session;
/// Session key under which the queued flash messages are stored.
///
/// The value written under this key is the JSON serialization of a
/// `Vec<FlashMessage>`.
const FLASH_SESSION_KEY: &str = "__autumn_flash";
/// Severity / category attached to a flash message.
///
/// Serialized by serde as the lowercase variant name (`"success"`, `"info"`,
/// `"warning"`, `"error"`). The enum is `#[non_exhaustive]`, so code outside
/// this crate must include a wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum FlashLevel {
/// An operation completed successfully.
Success,
/// Neutral, informational notice.
Info,
/// Something the user should pay attention to.
Warning,
/// An operation failed.
Error,
}
impl FlashLevel {
    /// Returns the lowercase name of this level.
    ///
    /// The returned text is one of `"success"`, `"info"`, `"warning"` or
    /// `"error"`, the same spelling serde uses when serializing the level.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Error => "error",
            Self::Warning => "warning",
            Self::Info => "info",
            Self::Success => "success",
        }
    }
}
/// A single flash message: a severity level plus the text to show.
///
/// Serializable so that a list of messages can be stored in the session as
/// JSON and read back later.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FlashMessage {
/// Severity / category of the message.
pub level: FlashLevel,
/// Message text.
pub message: String,
}
/// Handle for queuing and reading flash messages kept in a [`Session`].
///
/// Messages pushed through this handle stay in the session until
/// `consume` reads and removes them.
#[derive(Debug, Clone)]
pub struct Flash {
// Session that holds the serialized message queue.
session: Session,
}
impl Flash {
    /// Creates a flash handle that stores its messages in `session`.
    #[must_use]
    pub const fn new(session: Session) -> Self {
        Self { session }
    }

    /// Appends a message with the given `level` to the queue in the session.
    ///
    /// The current queue is read with [`Flash::peek`], extended, and written
    /// back as a JSON string under the flash session key. If the queue cannot
    /// be serialized the session is left untouched.
    pub async fn push(&self, level: FlashLevel, message: impl Into<String>) {
        let mut messages = self.peek().await;
        messages.push(FlashMessage {
            level,
            message: message.into(),
        });
        if let Ok(json) = serde_json::to_string(&messages) {
            self.session.insert(FLASH_SESSION_KEY, json).await;
        }
    }

    /// Queues `message` with [`FlashLevel::Success`].
    pub async fn success(&self, message: impl Into<String>) {
        self.push(FlashLevel::Success, message).await;
    }

    /// Queues `message` with [`FlashLevel::Info`].
    pub async fn info(&self, message: impl Into<String>) {
        self.push(FlashLevel::Info, message).await;
    }

    /// Queues `message` with [`FlashLevel::Warning`].
    pub async fn warning(&self, message: impl Into<String>) {
        self.push(FlashLevel::Warning, message).await;
    }

    /// Queues `message` with [`FlashLevel::Error`].
    pub async fn error(&self, message: impl Into<String>) {
        self.push(FlashLevel::Error, message).await;
    }

    /// Returns the queued messages without removing them from the session.
    ///
    /// Yields an empty vector when the flash key is absent or when its value
    /// cannot be deserialized as a list of [`FlashMessage`]s.
    pub async fn peek(&self) -> Vec<FlashMessage> {
        self.session
            .get(FLASH_SESSION_KEY)
            .await
            .map_or_else(Vec::new, |json| {
                serde_json::from_str(&json).unwrap_or_default()
            })
    }

    /// Returns the queued messages and clears the queue.
    ///
    /// The flash key is only removed when at least one message was read, so
    /// consuming an empty queue performs no write on the session.
    pub async fn consume(&self) -> Vec<FlashMessage> {
        let messages = self.peek().await;
        if !messages.is_empty() {
            self.session.remove(FLASH_SESSION_KEY).await;
        }
        messages
    }

    /// Consumes the queued messages and attaches them to `response` as an
    /// `hx-trigger` header of the form `{"flash":[{"level":…,"message":…}]}`.
    ///
    /// When the queue is empty the response is returned without the header.
    /// The header is set with `insert`, which replaces an `hx-trigger` header
    /// the response may already carry.
    ///
    /// The JSON payload is escaped to pure ASCII before it is turned into a
    /// header value. Without that, a message containing a DEL character made
    /// the header conversion fail after the queue had already been cleared
    /// (losing the messages), and non-ASCII text was sent as raw UTF-8 bytes,
    /// which header transport does not reliably preserve.
    #[cfg(feature = "htmx")]
    pub async fn inject_hx_trigger<T: axum::response::IntoResponse>(
        &self,
        response: T,
    ) -> axum::response::Response {
        let messages = self.consume().await;
        let mut res = response.into_response();
        if !messages.is_empty() {
            let payload = serde_json::json!({
                "flash": messages
            });
            let header_json = Self::escape_json_for_header(&payload.to_string());
            // After escaping, the value is visible ASCII only, so this
            // conversion is not expected to fail; the fallback stays
            // best-effort rather than panicking.
            if let Ok(v) = http::header::HeaderValue::from_str(&header_json) {
                res.headers_mut()
                    .insert(http::header::HeaderName::from_static("hx-trigger"), v);
            }
        }
        res
    }

    /// Rewrites compact JSON so that it contains only visible ASCII.
    ///
    /// Every non-ASCII character, and the DEL character, is replaced by its
    /// JSON `\uXXXX` escape (a surrogate pair for characters outside the
    /// Basic Multilingual Plane). In the payload built by
    /// [`Flash::inject_hx_trigger`] such characters can only occur inside
    /// string literals, where these escapes are valid, so the result parses
    /// to the same JSON value as the input.
    #[cfg(feature = "htmx")]
    fn escape_json_for_header(json: &str) -> String {
        let mut escaped = String::with_capacity(json.len());
        for ch in json.chars() {
            if ch.is_ascii() && ch != '\u{7f}' {
                escaped.push(ch);
            } else {
                let mut units = [0_u16; 2];
                for unit in ch.encode_utf16(&mut units).iter() {
                    escaped.push_str(&format!("\\u{unit:04x}"));
                }
            }
        }
        escaped
    }
}
/// Extractor support: a [`Flash`] is built from the request's [`Session`].
impl<S> FromRequestParts<S> for Flash
where
    S: Send + Sync,
{
    /// Extraction fails exactly when extracting the [`Session`] fails, with
    /// the same rejection.
    type Rejection = <Session as FromRequestParts<S>>::Rejection;

    /// Extracts the session from the request parts and wraps it in a
    /// [`Flash`] handle.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        Session::from_request_parts(parts, state)
            .await
            .map(Self::new)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Pushed messages keep their insertion order, survive `peek`, and are
    /// cleared by `consume`.
    #[tokio::test]
    async fn flash_push_and_consume() {
        let session = Session::new_for_test("test_id".to_string(), HashMap::new());
        let flash = Flash::new(session);

        flash.success("Saved!").await;
        flash.error("Failed!").await;

        let messages = flash.peek().await;
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].level, FlashLevel::Success);
        assert_eq!(messages[0].message, "Saved!");
        assert_eq!(messages[1].level, FlashLevel::Error);
        assert_eq!(messages[1].message, "Failed!");

        // Peeking must not drain the queue.
        assert_eq!(flash.peek().await.len(), 2);

        let consumed = flash.consume().await;
        assert_eq!(consumed.len(), 2);
        assert_eq!(flash.peek().await.len(), 0);
    }

    /// `as_str` is synchronous, so this test needs no async runtime.
    #[test]
    fn flash_level_as_str() {
        assert_eq!(FlashLevel::Success.as_str(), "success");
        assert_eq!(FlashLevel::Info.as_str(), "info");
        assert_eq!(FlashLevel::Warning.as_str(), "warning");
        assert_eq!(FlashLevel::Error.as_str(), "error");
    }

    /// Consuming an empty queue leaves unrelated session data alone and does
    /// not create the flash key.
    #[tokio::test]
    async fn should_not_remove_key_when_consuming_empty_flash() -> Result<(), String> {
        let session = Session::new_for_test("test_id".to_string(), HashMap::new());
        session.insert("dummy", "val").await;

        // The session is inspected again below, so the handle gets a clone.
        let flash = Flash::new(session.clone());
        let messages = flash.consume().await;
        assert_eq!(messages.len(), 0);

        assert_eq!(
            session.get("dummy").await.ok_or("missing key dummy")?,
            "val"
        );
        assert!(!session.contains_key(FLASH_SESSION_KEY).await);
        Ok(())
    }

    /// A corrupt value under the flash key is treated as an empty queue.
    #[tokio::test]
    async fn should_handle_invalid_json_gracefully() {
        let session = Session::new_for_test("test_id".to_string(), HashMap::new());
        session
            .insert(FLASH_SESSION_KEY, "{ invalid_json: true")
            .await;

        let flash = Flash::new(session);
        let messages = flash.peek().await;
        assert_eq!(messages.len(), 0);
    }

    /// Each convenience method queues a message with the matching level.
    #[tokio::test]
    async fn should_support_all_convenience_methods() {
        let session = Session::new_for_test("test_id".to_string(), HashMap::new());
        let flash = Flash::new(session);

        flash.success("Success msg").await;
        flash.info("Info msg").await;
        flash.warning("Warning msg").await;
        flash.error("Error msg").await;

        let messages = flash.peek().await;
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].level, FlashLevel::Success);
        assert_eq!(messages[0].message, "Success msg");
        assert_eq!(messages[1].level, FlashLevel::Info);
        assert_eq!(messages[1].message, "Info msg");
        assert_eq!(messages[2].level, FlashLevel::Warning);
        assert_eq!(messages[2].message, "Warning msg");
        assert_eq!(messages[3].level, FlashLevel::Error);
        assert_eq!(messages[3].message, "Error msg");
    }

    /// Queued messages end up as JSON in the `hx-trigger` response header.
    #[cfg(feature = "htmx")]
    #[tokio::test]
    async fn should_inject_hx_trigger() {
        let session = Session::new_for_test("test_id".to_string(), HashMap::new());
        let flash = Flash::new(session);
        flash.success("Item saved").await;

        let response = flash.inject_hx_trigger("OK").await;

        let header = response
            .headers()
            .get("hx-trigger")
            .expect("hx-trigger header should be set when messages are queued");
        let json_str = header
            .to_str()
            .expect("hx-trigger header should be visible ASCII");
        let payload: serde_json::Value =
            serde_json::from_str(json_str).expect("hx-trigger header should contain valid JSON");

        assert_eq!(payload["flash"][0]["level"], "success");
        assert_eq!(payload["flash"][0]["message"], "Item saved");
    }
}