use crate::core::{FlowRequest, Result, ResumaError};
use axum::http::{header, StatusCode};
use axum::response::{IntoResponse, Redirect as AxumRedirect, Response};
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Query-string parameter name that carries a one-shot flash message.
///
/// `redirect_with_flash` writes the message under this key and
/// `flash_message` reads it back from the incoming request.
pub const FLASH_KEY: &str = "flash";
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub struct Redirect {
pub to: String,
}
impl Redirect {
    /// Creates a redirect targeting `path`.
    ///
    /// No validation happens here; an invalid path is reported only when the
    /// redirect is turned into a response.
    pub fn to(path: impl Into<String>) -> Self {
        let to = path.into();
        Self { to }
    }

    /// Converts this redirect into an HTTP response via `redirect_response`.
    ///
    /// The response is a redirect when the target is a valid root-relative path,
    /// or `400 Bad Request` carrying the validation error otherwise.
    pub fn into_response(self) -> axum::response::Response {
        let Self { to } = self;
        redirect_response(&to)
    }
}
impl Serialize for Redirect {
    /// Writes the redirect as a single-field struct named `Redirect` whose only
    /// field is `redirect` (the target path). `extract_redirect` reads that same
    /// key back out of a JSON value.
    fn serialize<S: serde::Serializer>(
        &self,
        serializer: S,
    ) -> std::result::Result<S::Ok, S::Error> {
        // Bring the trait's methods into scope without binding its name.
        use serde::ser::SerializeStruct as _;

        let mut fields = serializer.serialize_struct("Redirect", 1)?;
        fields.serialize_field("redirect", &self.to)?;
        fields.end()
    }
}
pub fn redirect(path: impl Into<String>) -> Redirect {
Redirect::to(path)
}
/// Builds a [`Redirect`] to `path` that carries `message` as a URL-encoded
/// `FLASH_KEY` query parameter, so the next request can show it through
/// `flash_message`.
pub fn redirect_with_flash(path: impl Into<String>, message: impl AsRef<str>) -> Redirect {
    let target: String = path.into();
    let target_with_flash = append_flash(&target, message.as_ref());
    Redirect::to(target_with_flash)
}
/// Returns the flash message carried in the request's `FLASH_KEY` query
/// parameter.
///
/// Returns `None` when the parameter is absent or present but empty.
pub fn flash_message(req: &FlowRequest) -> Option<String> {
    let raw = req.query_param(FLASH_KEY)?;
    if raw.is_empty() {
        return None;
    }
    Some(raw.to_string())
}
/// Appends `FLASH_KEY=<message>` (URL-encoded) to the query string of `path`.
///
/// The parameter is inserted *before* any `#fragment`, because the fragment is
/// never sent to the server and a flash placed after it would be lost.
/// A separator is added only when needed:
/// - `"/a"` becomes `"/a?flash=…"`
/// - `"/a?x=1"` becomes `"/a?x=1&flash=…"`
/// - `"/a?"` becomes `"/a?flash=…"`
/// - `"/a#top"` becomes `"/a?flash=…#top"`
fn append_flash(path: &str, message: &str) -> String {
    // If encoding ever fails, fall back to an empty flash value rather than
    // dropping the redirect.
    let encoded = serde_urlencoded::to_string([(FLASH_KEY, message)])
        .unwrap_or_else(|_| format!("{FLASH_KEY}="));

    // Split off the fragment, keeping its leading `#`, so the query goes in
    // front of it. A `?` inside the fragment no longer affects the separator.
    let (base, fragment) = path
        .find('#')
        .map_or((path, ""), |idx| path.split_at(idx));

    let sep = if !base.contains('?') {
        "?"
    } else if base.ends_with('?') || base.ends_with('&') {
        // The query string is already open; avoid producing an empty `?&` pair.
        ""
    } else {
        "&"
    };

    format!("{base}{sep}{encoded}{fragment}")
}
pub fn extract_redirect(value: &Value) -> Option<String> {
let path = value.get("redirect")?.as_str()?;
validate_redirect_path(path).ok().map(str::to_string)
}
/// Checks that `path` is a same-site, root-relative redirect target and
/// returns it unchanged if so.
///
/// Rejected inputs:
/// - anything not starting with `/`, such as absolute URLs (`https://…`) and
///   relative paths;
/// - protocol-relative targets (`//host`);
/// - the backslash variant `/\host`, because browsers following the WHATWG URL
///   standard treat `\` like `/` in http(s) URLs and resolve it to another host;
/// - any control character (`\t`, `\r`, `\n`, …). Browsers strip tab/newline
///   from URLs, so `/\t/host` would resolve as `//host`, and control characters
///   do not belong in a `Location` header value.
///
/// # Errors
/// Returns `ResumaError::Other` whose message describes why `path` was rejected.
pub fn validate_redirect_path(path: &str) -> Result<&str> {
    if !path.starts_with('/') || path.starts_with("//") {
        return Err(ResumaError::Other(format!(
            "invalid redirect path `{path}` (must start with `/`, not `//`)"
        )));
    }
    if path.starts_with("/\\") {
        return Err(ResumaError::Other(format!(
            "invalid redirect path `{path}` (must not start with `/\\`)"
        )));
    }
    if path.chars().any(char::is_control) {
        // `{:?}` escapes the offending characters so the message itself stays clean.
        return Err(ResumaError::Other(format!(
            "invalid redirect path {path:?} (must not contain control characters)"
        )));
    }
    Ok(path)
}
/// Builds an HTTP redirect to `path` using axum's `Redirect::to`.
///
/// If `path` fails [`validate_redirect_path`], returns `400 Bad Request` with
/// the validation error text as the body instead.
pub fn redirect_response(path: &str) -> Response {
    let location = match validate_redirect_path(path) {
        Ok(location) => location,
        Err(reason) => {
            return (StatusCode::BAD_REQUEST, reason.to_string()).into_response();
        }
    };
    AxumRedirect::to(location).into_response()
}
/// Returns a single `Location` header pointing at `path` when it passes
/// [`validate_redirect_path`], or `None` when it does not.
pub fn redirect_json_headers(path: &str) -> Option<[(header::HeaderName, String); 1]> {
    let location = validate_redirect_path(path).ok()?;
    Some([(header::LOCATION, String::from(location))])
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Plain root-relative paths, with or without a query, are returned unchanged.
    #[test]
    fn accepts_root_relative_paths() {
        for allowed in ["/items", "/items/42?created=1"] {
            assert_eq!(validate_redirect_path(allowed).unwrap(), allowed);
        }
    }

    /// Absolute and protocol-relative targets must be refused.
    #[test]
    fn rejects_open_redirects() {
        for hostile in ["https://evil.test", "//evil.test"] {
            assert!(
                validate_redirect_path(hostile).is_err(),
                "accepted open redirect {hostile}"
            );
        }
    }

    /// The `redirect` field is read when present and ignored otherwise.
    #[test]
    fn extracts_redirect_field() {
        let with_target = json!({ "redirect": "/done" });
        let without_target = json!({ "ok": true });
        assert_eq!(extract_redirect(&with_target), Some(String::from("/done")));
        assert_eq!(extract_redirect(&without_target), None);
    }

    /// The flash parameter is URL-encoded and joined with the right separator.
    #[test]
    fn flash_appends_query_param() {
        let cases = [
            ("/items", "/items?flash=Saved%21"),
            ("/items?page=2", "/items?page=2&flash=Saved%21"),
        ];
        for (path, expected) in cases {
            assert_eq!(append_flash(path, "Saved!"), expected);
        }
    }

    /// A flash written by `redirect_with_flash` is readable by `flash_message`.
    #[test]
    fn flash_roundtrips_through_request() {
        let redirect = redirect_with_flash("/items", "Item created");
        assert!(redirect.to.starts_with("/items?flash="));

        let raw_query = redirect.to.split_once('?').map(|(_, query)| query);
        let query = crate::flow::request::parse_query(raw_query);
        let req = FlowRequest::from_parts(
            "GET",
            "/items",
            Default::default(),
            Default::default(),
            query,
        );

        assert_eq!(flash_message(&req).as_deref(), Some("Item created"));
    }
}