use form_urlencoded::byte_serialize;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Category of an action failure.
///
/// Serialized by serde in `snake_case`; [`ActionKind::Generic`] is the
/// `Default` variant.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ActionKind {
/// Uncategorized failure; the kind assigned by [`ActionError::msg`].
#[default]
Generic,
/// Assigned by [`ActionError::not_found`].
NotFound,
/// Assigned by [`ActionError::forbidden`].
Forbidden,
/// Assigned by [`ActionError::unauthorized`].
Unauthorized,
}
impl ActionKind {
pub(crate) fn as_query_str(&self) -> &'static str {
match self {
Self::Generic => "generic",
Self::NotFound => "not_found",
Self::Forbidden => "forbidden",
Self::Unauthorized => "unauthorized",
}
}
}
/// Severity label attached to the flash payload of a failed action.
///
/// Serialized by serde in `snake_case`; [`FlashVariant::Error`] is the
/// `Default` variant.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum FlashVariant {
/// Flashed with the variant string `"error"`.
#[default]
Error,
/// Flashed with the variant string `"warning"`.
Warning,
/// Flashed with the variant string `"info"`.
Info,
}
/// Error returned by an action handler.
///
/// Its `Display` output is exactly `message`.
#[derive(Debug, Clone, Error)]
#[error("{message}")]
pub struct ActionError {
/// Human-readable description; also what `Display` prints.
pub message: String,
/// Failure category, emitted as the `error=` query parameter on redirect.
pub kind: ActionKind,
/// Variant recorded in the session flash payload.
pub flash_variant: FlashVariant,
/// Optional redirect target used instead of the handler's default.
/// `handle_action_result` only honors it when `is_same_origin` accepts it.
pub redirect_override: Option<String>,
/// When `true`, `handle_action_result` redirects without storing a flash
/// payload and without appending `error=`/`msg=` to the URL.
pub(crate) suppress_url_envelope: bool,
}
impl ActionError {
    /// Creates a [`ActionKind::Generic`] error with the [`FlashVariant::Error`]
    /// flash variant, no redirect override, and the URL envelope enabled.
    pub fn msg(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self {
            suppress_url_envelope: false,
            redirect_override: None,
            flash_variant: FlashVariant::Error,
            kind: ActionKind::Generic,
            message,
        }
    }

    /// Creates an error with an empty message that redirects to `redirect_to`
    /// and suppresses the flash payload / URL query envelope.
    pub fn validation_failed(redirect_to: impl Into<String>) -> Self {
        let mut err = Self::msg(String::new());
        err.redirect_override = Some(redirect_to.into());
        err.suppress_url_envelope = true;
        err
    }

    /// Same as [`ActionError::msg`] but with kind [`ActionKind::NotFound`].
    pub fn not_found(message: impl Into<String>) -> Self {
        let mut err = Self::msg(message);
        err.kind = ActionKind::NotFound;
        err
    }

    /// Same as [`ActionError::msg`] but with kind [`ActionKind::Forbidden`].
    pub fn forbidden(message: impl Into<String>) -> Self {
        let mut err = Self::msg(message);
        err.kind = ActionKind::Forbidden;
        err
    }

    /// Same as [`ActionError::msg`] but with kind [`ActionKind::Unauthorized`].
    ///
    /// No redirect override is set; callers choose one via
    /// [`ActionError::redirect_to`] if they want it.
    pub fn unauthorized(message: impl Into<String>) -> Self {
        let mut err = Self::msg(message);
        err.kind = ActionKind::Unauthorized;
        err
    }

    /// Replaces the flash variant, returning the updated error.
    #[must_use]
    pub fn with_flash(self, variant: FlashVariant) -> Self {
        Self {
            flash_variant: variant,
            ..self
        }
    }

    /// Sets the redirect override to `url`, returning the updated error.
    #[must_use]
    pub fn redirect_to(self, url: impl Into<String>) -> Self {
        Self {
            redirect_override: Some(url.into()),
            ..self
        }
    }
}
impl From<String> for ActionError {
fn from(s: String) -> Self {
Self::msg(s)
}
}
impl From<&'static str> for ActionError {
fn from(s: &'static str) -> Self {
Self::msg(s)
}
}
impl From<crate::error::FrameworkError> for ActionError {
fn from(err: crate::error::FrameworkError) -> Self {
Self::msg(err.to_string())
}
}
impl From<sea_orm::DbErr> for ActionError {
fn from(err: sea_orm::DbErr) -> Self {
Self::msg(err.to_string())
}
}
/// Conversion of an error value into an [`ActionError`].
pub trait IntoActionError {
/// Consumes `self` and produces the corresponding [`ActionError`].
fn into_action_error(self) -> ActionError;
}
/// Blanket conversion: any `Display` value becomes a generic [`ActionError`]
/// whose message is the value's string rendering.
///
/// [`ActionError`] itself is `Display` (it prints its message), so calling
/// this on an existing `ActionError` yields a fresh generic error that keeps
/// only the message text.
impl<E> IntoActionError for E
where
    E: std::fmt::Display,
{
    fn into_action_error(self) -> ActionError {
        let rendered = self.to_string();
        ActionError::msg(rendered)
    }
}
/// Extension trait for turning a fallible value's error side into an
/// [`ActionError`].
pub trait ActionResultExt<T> {
/// Returns the success value unchanged, or the error converted to an
/// [`ActionError`].
fn action_err(self) -> Result<T, ActionError>;
}
/// Implements [`ActionResultExt`] for every `Result` whose error type can be
/// converted through [`IntoActionError`].
impl<T, E> ActionResultExt<T> for Result<T, E>
where
    E: IntoActionError,
{
    fn action_err(self) -> Result<T, ActionError> {
        self.map_err(E::into_action_error)
    }
}
/// Result consumed by `handle_action_result`: `Ok(())` on success, otherwise
/// an [`ActionError`].
pub type ActionResult = Result<(), ActionError>;
/// Per-request overrides read by `handle_action_result` on the success path.
#[derive(Debug, Default, Clone)]
pub(crate) struct ActionOverrides {
/// Message stored in the `"success"` flash payload; when non-empty it is
/// also URL-encoded into the `success=` query parameter.
pub flash: Option<String>,
/// Redirect target used instead of the handler's default, provided
/// `is_same_origin` accepts it.
pub redirect_override: Option<String>,
}
/// Returns `true` only when `url` is a path-absolute reference (`/…`) that a
/// browser will resolve against the current origin.
///
/// Rejected inputs:
/// - anything not starting with `/` (absolute URLs, relative paths, empty);
/// - `//host…` — a scheme-relative URL pointing at another host;
/// - `/\host…` — browsers treat `\` like `/` for http(s) URLs, so this is
///   equivalent to `//host…`;
/// - anything containing a control character — URL parsers strip tabs and
///   newlines, so `/\t/host` would collapse to `//host`, and CR/LF must never
///   reach a `Location` header.
pub(crate) fn is_same_origin(url: &str) -> bool {
    let mut chars = url.chars();

    // Must be rooted at the current origin's path.
    if chars.next() != Some('/') {
        return false;
    }

    // A second slash (or a backslash, which browsers normalize to a slash)
    // turns the reference into a scheme-relative URL with its own authority.
    if matches!(chars.next(), Some('/' | '\\')) {
        return false;
    }

    // Control characters can be stripped by the client before parsing, which
    // would re-create one of the forms rejected above.
    !url.chars().any(char::is_control)
}
/// Returns a copy of `s` in which every control character (as defined by
/// [`char::is_control`]) is replaced by a single ASCII space, so the value
/// cannot introduce line breaks into a log record.
pub(crate) fn sanitize_for_log(s: &str) -> String {
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        cleaned.push(if ch.is_control() { ' ' } else { ch });
    }
    cleaned
}
/// Value stored under the `"_action"` session flash key by
/// `handle_action_result`.
#[derive(Debug, Serialize, Deserialize)]
struct ActionFlashPayload<'a> {
/// `"success"` on the success path, otherwise `"error"`, `"warning"` or
/// `"info"` according to the error's flash variant.
variant: &'a str,
/// The success flash text or the error message.
message: &'a str,
}
/// Turns the outcome of an action handler into a `303` redirect response.
///
/// * `result` — what the handler returned.
/// * `redirect_to` — default redirect target, used when no override is set or
///   the override is rejected by `is_same_origin`.
/// * `handler_name` — included in log records only.
/// * `req` — source of the per-request `ActionOverrides` on the success path.
///
/// On success the target receives a `success=<flash>` query parameter
/// (`success=1` when no non-empty flash was set), and any flash is also stored
/// in the session under `"_action"`.
///
/// On error the failure is logged, and — unless the error suppresses the URL
/// envelope — the message is flashed to the session and `error=<kind>&msg=<message>`
/// is added to the target.
///
/// Query parameters are inserted before any `#fragment` in the target so they
/// remain part of the request sent to the server.
#[doc(hidden)]
pub fn handle_action_result(
    result: ActionResult,
    redirect_to: &'static str,
    handler_name: &'static str,
    req: &mut crate::http::Request,
) -> crate::http::Response {
    match result {
        Ok(()) => {
            let overrides = req.action_overrides().clone();

            // Only honor an override that stays on this origin.
            let target = match overrides.redirect_override.as_deref() {
                Some(url) if is_same_origin(url) => url.to_string(),
                Some(rejected) => {
                    tracing::warn!(
                        handler = %handler_name,
                        rejected_url = %sanitize_for_log(rejected),
                        "redirect_override rejected: not same-origin (success path)"
                    );
                    redirect_to.to_string()
                }
                None => redirect_to.to_string(),
            };

            if let Some(key) = overrides.flash.as_deref() {
                let payload = ActionFlashPayload {
                    variant: "success",
                    message: key,
                };
                crate::session::session_mut(|s| s.flash("_action", &payload));
            }

            let query = match overrides.flash.as_deref() {
                Some(k) if !k.is_empty() => {
                    let encoded_key: String = byte_serialize(k.as_bytes()).collect();
                    format!("success={encoded_key}")
                }
                _ => String::from("success=1"),
            };
            let location = append_redirect_query(&target, &query);

            Ok(crate::http::HttpResponse::new()
                .status(303)
                .header("Location", &location))
        }
        Err(err) => {
            let safe_msg = sanitize_for_log(&err.message);
            tracing::error!(
                handler = %handler_name,
                msg = %safe_msg,
                kind = ?err.kind,
                "action handler error — redirecting"
            );

            // Only honor an override that stays on this origin.
            let target = match err.redirect_override.as_deref() {
                Some(url) if is_same_origin(url) => url.to_string(),
                Some(rejected) => {
                    tracing::warn!(
                        handler = %handler_name,
                        rejected_url = %sanitize_for_log(rejected),
                        "redirect_override rejected: not same-origin (error path)"
                    );
                    redirect_to.to_string()
                }
                None => redirect_to.to_string(),
            };

            let location = if err.suppress_url_envelope {
                // Bare redirect: no flash payload, no query parameters.
                target
            } else {
                let variant_str = match err.flash_variant {
                    FlashVariant::Error => "error",
                    FlashVariant::Warning => "warning",
                    FlashVariant::Info => "info",
                };
                let payload = ActionFlashPayload {
                    variant: variant_str,
                    message: &err.message,
                };
                crate::session::session_mut(|s| s.flash("_action", &payload));

                let encoded_msg: String = byte_serialize(err.message.as_bytes()).collect();
                let query = format!(
                    "error={kind}&msg={msg}",
                    kind = err.kind.as_query_str(),
                    msg = encoded_msg
                );
                append_redirect_query(&target, &query)
            };

            Ok(crate::http::HttpResponse::new()
                .status(303)
                .header("Location", &location))
        }
    }
}

/// Adds `query` (already encoded, without a leading `?`/`&`) to `target`,
/// placing it before any `#fragment` and choosing `?` or `&` depending on
/// whether the part before the fragment already has a query string.
fn append_redirect_query(target: &str, query: &str) -> String {
    let (base, fragment) = match target.find('#') {
        Some(idx) => target.split_at(idx),
        None => (target, ""),
    };
    let sep = if base.contains('?') { '&' } else { '?' };
    format!("{base}{sep}{query}{fragment}")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `msg` yields a generic, error-flavored value with no redirect.
    #[test]
    fn msg_constructor_defaults() {
        let err = ActionError::msg("boom");
        assert_eq!(err.message, "boom");
        assert_eq!(err.kind, ActionKind::Generic);
        assert_eq!(err.flash_variant, FlashVariant::Error);
        assert_eq!(err.redirect_override, None);
    }

    #[test]
    fn not_found_constructor_sets_kind() {
        assert_eq!(ActionError::not_found("missing").kind, ActionKind::NotFound);
    }

    #[test]
    fn forbidden_constructor_sets_kind() {
        assert_eq!(ActionError::forbidden("nope").kind, ActionKind::Forbidden);
    }

    #[test]
    fn unauthorized_constructor_no_default_redirect() {
        let err = ActionError::unauthorized("login first");
        assert_eq!(err.kind, ActionKind::Unauthorized);
        assert!(
            err.redirect_override.is_none(),
            "ferro must not hardcode a default auth-redirect path (D-08)"
        );
    }

    /// Builder methods take `self` by value and can be chained.
    #[test]
    fn builders_consume_self() {
        let base = ActionError::msg("x");
        let flashed = base.with_flash(FlashVariant::Warning);
        let err = flashed.redirect_to("/login");
        assert_eq!(err.flash_variant, FlashVariant::Warning);
        assert_eq!(err.redirect_override.as_deref(), Some("/login"));
    }

    #[test]
    fn from_string_impl() {
        let err = ActionError::from(String::from("oops"));
        assert_eq!(err.message, "oops");
    }

    #[test]
    fn from_static_str_impl() {
        let err = ActionError::from("static");
        assert_eq!(err.message, "static");
    }

    #[test]
    fn from_framework_error_impl() {
        let source = crate::error::FrameworkError::internal("framework boom");
        let err = ActionError::from(source);
        assert!(err.message.contains("framework boom"));
    }

    /// Any `Display` value converts through the blanket impl.
    #[test]
    fn into_action_error_blanket_for_display_types() {
        let err = 42_i32.into_action_error();
        assert_eq!(err.message, "42");
    }

    #[test]
    fn action_err_extension_on_result() {
        let original: Result<(), i32> = Err(7);
        let converted: Result<(), ActionError> = original.action_err();
        let err = converted.unwrap_err();
        assert_eq!(err.message, "7");
    }

    #[test]
    fn sanitize_strips_control_chars() {
        let cleaned = sanitize_for_log("a\nb\tc\x00d");
        assert_eq!(cleaned, "a b c d");
    }

    #[test]
    fn is_same_origin_accepts_relative() {
        for url in ["/dashboard", "/"] {
            assert!(is_same_origin(url));
        }
    }

    #[test]
    fn is_same_origin_rejects_absolute() {
        for url in [
            "https://evil.example/",
            "//evil.example/",
            "http://localhost/",
        ] {
            assert!(!is_same_origin(url));
        }
    }

    #[test]
    fn action_kind_query_strings() {
        let expected = [
            (ActionKind::Generic, "generic"),
            (ActionKind::NotFound, "not_found"),
            (ActionKind::Forbidden, "forbidden"),
            (ActionKind::Unauthorized, "unauthorized"),
        ];
        for (kind, token) in expected {
            assert_eq!(kind.as_query_str(), token);
        }
    }
}