use axum::http::{header, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use crate::core::Model;
use crate::query::QuerySet;
use crate::sql::{
ExecError, FetcherPool, LoadRelated, MaybeMyFromRow, MaybeMyLoadRelated, MaybePgFromRow,
MaybeSqliteFromRow, MaybeSqliteLoadRelated, Pool,
};
/// Error type shared by the shortcut helpers in this module.
///
/// The `IntoResponse` impl in this file maps `NotFound` to HTTP 404 and
/// `Database` to HTTP 500, so a handler can return this error directly.
#[derive(Debug)]
pub enum ShortcutError {
/// The lookup succeeded but produced no row(s); `message` becomes the
/// response body of the 404.
NotFound { message: String },
/// The underlying query failed; wraps the executor error unchanged.
Database(ExecError),
}
impl ShortcutError {
    /// Builds a [`ShortcutError::NotFound`] from anything that converts into
    /// a `String`.
    #[must_use]
    pub fn not_found(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self::NotFound { message }
    }
}
impl std::fmt::Display for ShortcutError {
    /// Writes the bare message for `NotFound`, and the wrapped error prefixed
    /// with `database error: ` for `Database`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Database(cause) => f.write_fmt(format_args!("database error: {cause}")),
            Self::NotFound { message: text } => f.write_str(text),
        }
    }
}
impl std::error::Error for ShortcutError {
    /// Exposes the wrapped executor error as the source of a `Database`
    /// error; `NotFound` has no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::Database(inner) => inner,
            Self::NotFound { .. } => return None,
        };
        Some(cause)
    }
}
impl IntoResponse for ShortcutError {
    /// Converts the error into an HTTP response: 404 with the message for
    /// `NotFound`, 500 with `database error: …` for `Database`.
    ///
    /// NOTE(review): the 500 body includes the executor error's `Display`
    /// text — confirm that is acceptable to expose to clients.
    fn into_response(self) -> Response {
        let (status, body) = match self {
            Self::NotFound { message } => (StatusCode::NOT_FOUND, message),
            Self::Database(e) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("database error: {e}"),
            ),
        };
        (status, body).into_response()
    }
}
impl From<ExecError> for ShortcutError {
fn from(e: ExecError) -> Self {
Self::Database(e)
}
}
/// Returns the first row produced by `qs`, or a `NotFound` error when the
/// queryset yields nothing.
///
/// # Errors
///
/// * [`ShortcutError::NotFound`] with the message `no <schema name> matches`
///   when `qs.first` returns `None`.
/// * Any error from running the query, converted through `?`.
pub async fn get_object_or_404<T>(qs: QuerySet<T>, pool: &Pool) -> Result<T, ShortcutError>
where
    T: Model
        + Send
        + Unpin
        + MaybePgFromRow
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated,
{
    let found = qs.first(pool).await?;
    found.ok_or_else(|| ShortcutError::not_found(format!("no {} matches", T::SCHEMA.name)))
}
/// Fetches every row produced by `qs`, treating an empty result as
/// "not found".
///
/// # Errors
///
/// * [`ShortcutError::NotFound`] with the message `no <schema name> matches`
///   when the fetched list is empty.
/// * Any error from running the query, converted through `?`.
pub async fn get_list_or_404<T>(qs: QuerySet<T>, pool: &Pool) -> Result<Vec<T>, ShortcutError>
where
    T: Model
        + Send
        + Unpin
        + MaybePgFromRow
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated,
{
    let items = qs.fetch_pool(pool).await?;
    if !items.is_empty() {
        return Ok(items);
    }
    let message = format!("no {} matches", T::SCHEMA.name);
    Err(ShortcutError::not_found(message))
}
/// Renders template `name` with `ctx` and wraps the output in an HTML
/// response.
///
/// A rendering failure becomes a 500 response whose body names the template
/// and includes the Tera error text.
#[must_use]
pub fn render(tera: &tera::Tera, name: &str, ctx: &tera::Context) -> Response {
    let html = match tera.render(name, ctx) {
        Ok(html) => html,
        Err(e) => {
            let detail = format!("template `{name}` failed: {e}");
            return (StatusCode::INTERNAL_SERVER_ERROR, detail).into_response();
        }
    };
    axum::response::Html(html).into_response()
}
/// Renders template `name` with `ctx` and returns the resulting text,
/// leaving error handling to the caller.
///
/// # Errors
///
/// Returns the `tera::Error` produced by `Tera::render` unchanged.
pub fn render_to_string(
    tera: &tera::Tera,
    name: &str,
    ctx: &tera::Context,
) -> Result<String, tera::Error> {
    let rendered = tera.render(name, ctx)?;
    Ok(rendered)
}
/// Builds a `302 Found` redirect to `url`.
///
/// The `Location` header is omitted when `url` is not a valid header value
/// (see `build_redirect`).
#[must_use]
pub fn redirect(url: impl Into<String>) -> Response {
    let target: String = url.into();
    build_redirect(StatusCode::FOUND, target)
}
/// Builds a `301 Moved Permanently` redirect to `url`.
///
/// The `Location` header is omitted when `url` is not a valid header value
/// (see `build_redirect`).
#[must_use]
pub fn redirect_permanent(url: impl Into<String>) -> Response {
    let target: String = url.into();
    build_redirect(StatusCode::MOVED_PERMANENTLY, target)
}
/// Builds a `303 See Other` redirect to `url`.
///
/// The `Location` header is omitted when `url` is not a valid header value
/// (see `build_redirect`).
#[must_use]
pub fn redirect_see_other(url: impl Into<String>) -> Response {
    let target: String = url.into();
    build_redirect(StatusCode::SEE_OTHER, target)
}
/// Builds a `307 Temporary Redirect` to `url`.
///
/// The `Location` header is omitted when `url` is not a valid header value
/// (see `build_redirect`).
#[must_use]
pub fn redirect_temporary(url: impl Into<String>) -> Response {
    let target: String = url.into();
    build_redirect(StatusCode::TEMPORARY_REDIRECT, target)
}
/// Builds a `308 Permanent Redirect` to `url`.
///
/// The `Location` header is omitted when `url` is not a valid header value
/// (see `build_redirect`).
#[must_use]
pub fn redirect_permanent_preserve_method(url: impl Into<String>) -> Response {
    let target: String = url.into();
    build_redirect(StatusCode::PERMANENT_REDIRECT, target)
}
/// Builds a `302 Found` redirect to `login_url` carrying `next` as a
/// URL-encoded `next=` query parameter.
///
/// The parameter is joined with `&` when `login_url` already contains a
/// `?`, and with `?` otherwise.
#[must_use]
pub fn redirect_to_login(next: &str, login_url: &str) -> Response {
    let encoded_next = crate::url_codec::url_encode(next);
    let joiner = match login_url.contains('?') {
        true => '&',
        false => '?',
    };
    let target = format!("{login_url}{joiner}next={encoded_next}");
    redirect(target)
}
/// Shared constructor for the redirect helpers: an empty-bodied response
/// with the given `status` and, when possible, a `Location` header.
///
/// If `url` cannot be turned into a `HeaderValue` (for example because it
/// contains CR/LF), the response is returned without a `Location` header
/// rather than emitting an attacker-influenced header.
fn build_redirect(status: StatusCode, url: String) -> Response {
    let mut response = Response::builder()
        .status(status)
        .body(axum::body::Body::empty())
        .expect("status + empty body is always valid");
    match HeaderValue::from_str(&url) {
        Ok(location) => {
            response.headers_mut().insert(header::LOCATION, location);
        }
        Err(_) => {}
    }
    response
}
/// Serializes `data` to JSON and returns it with the given numeric `status`
/// and an `application/json` content type.
///
/// * If serialization fails, an empty-bodied 500 response is returned.
/// * If `status` is not accepted by `StatusCode::from_u16`, 200 is used.
#[must_use]
pub fn json_response<T: serde::Serialize>(data: &T, status: u16) -> Response {
    let Ok(payload) = serde_json::to_vec(data) else {
        return Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .body(axum::body::Body::empty())
            .expect("500 + empty body is always valid");
    };
    let code = match StatusCode::from_u16(status) {
        Ok(code) => code,
        Err(_) => StatusCode::OK,
    };
    let mut response = Response::builder()
        .status(code)
        .body(axum::body::Body::from(payload))
        .expect("status + non-empty body is always valid");
    let content_type = HeaderValue::from_static("application/json");
    response
        .headers_mut()
        .insert(header::CONTENT_TYPE, content_type);
    response
}
#[must_use]
pub fn json_ok<T: serde::Serialize>(data: &T) -> Response {
json_response(data, 200)
}
#[must_use]
pub fn json_bad_request<T: serde::Serialize>(data: &T) -> Response {
json_response(data, 400)
}
#[must_use]
pub fn json_unauthorized<T: serde::Serialize>(data: &T) -> Response {
json_response(data, 401)
}
#[must_use]
pub fn json_forbidden<T: serde::Serialize>(data: &T) -> Response {
json_response(data, 403)
}
#[must_use]
pub fn json_not_found<T: serde::Serialize>(data: &T) -> Response {
json_response(data, 404)
}
#[must_use]
pub fn json_server_error<T: serde::Serialize>(data: &T) -> Response {
json_response(data, 500)
}
/// Returns `content` as a `text/html; charset=utf-8` response with the
/// given numeric `status`.
///
/// A `status` that `StatusCode::from_u16` rejects falls back to 200.
#[must_use]
pub fn html_response(content: impl Into<String>, status: u16) -> Response {
    let markup: String = content.into();
    let code = match StatusCode::from_u16(status) {
        Ok(code) => code,
        Err(_) => StatusCode::OK,
    };
    let mut response = Response::builder()
        .status(code)
        .body(axum::body::Body::from(markup))
        .expect("status + body is always valid");
    let content_type = HeaderValue::from_static("text/html; charset=utf-8");
    response
        .headers_mut()
        .insert(header::CONTENT_TYPE, content_type);
    response
}
/// Returns `content` as a `text/plain; charset=utf-8` response with the
/// given numeric `status`.
///
/// A `status` that `StatusCode::from_u16` rejects falls back to 200.
#[must_use]
pub fn text_response(content: impl Into<String>, status: u16) -> Response {
    let text: String = content.into();
    let code = match StatusCode::from_u16(status) {
        Ok(code) => code,
        Err(_) => StatusCode::OK,
    };
    let mut response = Response::builder()
        .status(code)
        .body(axum::body::Body::from(text))
        .expect("status + body is always valid");
    let content_type = HeaderValue::from_static("text/plain; charset=utf-8");
    response
        .headers_mut()
        .insert(header::CONTENT_TYPE, content_type);
    response
}
/// Returns `content` as a 200 download response with
/// `Content-Disposition: attachment; filename="…"`.
///
/// * `filename` is passed through `sanitize_attachment_filename` before it
///   is embedded in the header.
/// * `content_type` is used as-is when it is a valid header value, and
///   replaced with `application/octet-stream` otherwise.
/// * If the assembled `Content-Disposition` value is not a valid header
///   value, that header is left off the response.
#[must_use]
pub fn file_response(
    content: impl Into<axum::body::Bytes>,
    filename: &str,
    content_type: &str,
) -> Response {
    let payload: axum::body::Bytes = content.into();
    let disposition = format!(
        r#"attachment; filename="{}""#,
        sanitize_attachment_filename(filename)
    );
    let mut response = Response::builder()
        .status(StatusCode::OK)
        .body(axum::body::Body::from(payload))
        .expect("200 + body is always valid");
    let media_type = match HeaderValue::from_str(content_type) {
        Ok(value) => value,
        Err(_) => HeaderValue::from_static("application/octet-stream"),
    };
    let headers = response.headers_mut();
    headers.insert(header::CONTENT_TYPE, media_type);
    if let Ok(value) = HeaderValue::from_str(&disposition) {
        headers.insert(header::CONTENT_DISPOSITION, value);
    }
    response
}
/// Makes `name` safe to embed inside the quoted `filename="…"` parameter of
/// a `Content-Disposition` header by replacing unsafe characters with `_`.
///
/// Replaced characters:
///
/// * `"` — would terminate the quoted-string early and let the rest of the
///   name be parsed as extra header parameters.
/// * `\` — is the escape character inside a quoted-string, so a name ending
///   in a backslash would escape the closing quote.
/// * every control character (including CR and LF) — these enable header
///   splitting or can make the header value invalid, in which case
///   `file_response` would leave `Content-Disposition` off entirely.
///
/// All other characters, including non-ASCII ones, are passed through
/// unchanged.
fn sanitize_attachment_filename(name: &str) -> String {
    name.chars()
        .map(|c| {
            if c == '"' || c == '\\' || c.is_control() {
                '_'
            } else {
                c
            }
        })
        .collect()
}
/// Unit tests for the shortcut helpers: error conversion, template
/// rendering, redirects, and the JSON / HTML / text / file response builders.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- ShortcutError -------------------------------------------------

    #[test]
    fn not_found_message_round_trips() {
        let err = ShortcutError::not_found("post #42 not found");
        assert_eq!(err.to_string(), "post #42 not found");
    }

    #[tokio::test]
    async fn not_found_into_response_is_404() {
        let err = ShortcutError::not_found("missing");
        let res = err.into_response();
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn database_error_into_response_is_500_not_404() {
        let exec_err = ExecError::Sql(crate::sql::SqlError::EmptyInList);
        let err = ShortcutError::Database(exec_err);
        let res = err.into_response();
        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn from_exec_error_routes_to_database_variant() {
        let exec_err = ExecError::Sql(crate::sql::SqlError::EmptyInList);
        let err: ShortcutError = exec_err.into();
        assert!(matches!(err, ShortcutError::Database(_)), "got: {err:?}");
    }

    // ---- render --------------------------------------------------------

    #[tokio::test]
    async fn render_template_returns_html_200() {
        let mut tera = tera::Tera::default();
        tera.add_raw_template("hello", "Hello, {{ name }}!")
            .unwrap();
        let mut ctx = tera::Context::new();
        ctx.insert("name", "alice");
        let res = render(&tera, "hello", &ctx);
        assert_eq!(res.status(), StatusCode::OK);
        let ct = res
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .expect("content-type");
        assert!(ct.to_str().unwrap().starts_with("text/html"), "got: {ct:?}");
    }

    #[tokio::test]
    async fn render_template_missing_template_is_500() {
        let tera = tera::Tera::default();
        let ctx = tera::Context::new();
        let res = render(&tera, "nope.html", &ctx);
        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    // ---- redirects -----------------------------------------------------

    #[tokio::test]
    async fn redirect_is_302_with_location_header() {
        let res = redirect("/posts/42");
        assert_eq!(res.status(), StatusCode::FOUND);
        let loc = res
            .headers()
            .get(axum::http::header::LOCATION)
            .expect("location");
        assert_eq!(loc.to_str().unwrap(), "/posts/42");
    }

    #[tokio::test]
    async fn redirect_permanent_is_301_with_location_header() {
        let res = redirect_permanent("/posts/42");
        assert_eq!(res.status(), StatusCode::MOVED_PERMANENTLY);
        let loc = res
            .headers()
            .get(axum::http::header::LOCATION)
            .expect("location");
        assert_eq!(loc.to_str().unwrap(), "/posts/42");
    }

    #[tokio::test]
    async fn redirect_see_other_is_303_with_location_header() {
        let res = redirect_see_other("/items");
        assert_eq!(res.status(), StatusCode::SEE_OTHER);
        let loc = res
            .headers()
            .get(axum::http::header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(loc, "/items");
    }

    #[tokio::test]
    async fn redirect_temporary_is_307_with_location_header() {
        let res = redirect_temporary("/api/v1-new/widgets");
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
        let loc = res
            .headers()
            .get(axum::http::header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(loc, "/api/v1-new/widgets");
    }

    #[tokio::test]
    async fn redirect_permanent_preserve_method_is_308() {
        let res = redirect_permanent_preserve_method("/api/v2/widgets");
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
    }

    #[tokio::test]
    async fn redirect_with_crlf_drops_location_header_no_response_splitting() {
        let res = redirect("/posts/42\r\nSet-Cookie: pwned=1");
        assert_eq!(res.status(), StatusCode::FOUND);
        assert!(
            res.headers().get(axum::http::header::LOCATION).is_none(),
            "CRLF-injected URL must NOT produce a Location header"
        );
    }

    // ---- render_to_string ----------------------------------------------

    #[test]
    fn render_to_string_returns_rendered_body() {
        let mut tera = tera::Tera::default();
        tera.add_raw_template("hi", "Hello, {{ name }}!").unwrap();
        let mut ctx = tera::Context::new();
        ctx.insert("name", "alice");
        let out = render_to_string(&tera, "hi", &ctx).unwrap();
        assert_eq!(out, "Hello, alice!");
    }

    #[test]
    fn render_to_string_propagates_template_errors() {
        let tera = tera::Tera::default();
        let ctx = tera::Context::new();
        let err = render_to_string(&tera, "absent.html", &ctx).unwrap_err();
        let s = format!("{err}");
        assert!(s.contains("absent") || s.contains("not found"), "got: {s}");
    }

    // ---- redirect_to_login ---------------------------------------------

    #[tokio::test]
    async fn redirect_to_login_appends_next_with_question_mark() {
        let res = redirect_to_login("/profile", "/login");
        assert_eq!(res.status(), StatusCode::FOUND);
        let loc = res
            .headers()
            .get(axum::http::header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned();
        assert_eq!(loc, "/login?next=%2Fprofile");
    }

    #[tokio::test]
    async fn redirect_to_login_appends_next_with_ampersand_when_url_has_query() {
        let res = redirect_to_login("/profile", "/login?lang=fr");
        let loc = res
            .headers()
            .get(axum::http::header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned();
        assert_eq!(loc, "/login?lang=fr&next=%2Fprofile");
    }

    #[tokio::test]
    async fn redirect_to_login_encodes_special_chars_in_next() {
        let res = redirect_to_login("/posts/42?utm=ad&q=spaces here", "/login");
        let loc = res
            .headers()
            .get(axum::http::header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned();
        assert!(loc.starts_with("/login?next="), "got: {loc}");
        // Index 12 is the first byte after the `/login?next=` prefix.
        assert!(!loc[12..].contains(' '), "next= must percent-escape spaces");
        assert!(!loc[12..].contains('&'), "next= must percent-escape &");
        assert!(!loc[12..].contains('?'), "next= must percent-escape ?");
    }

    // ---- JSON responses ------------------------------------------------

    /// Collects a response body (capped at 1 MiB) into a byte vector.
    async fn body_bytes(res: Response) -> Vec<u8> {
        axum::body::to_bytes(res.into_body(), 1024 * 1024)
            .await
            .unwrap()
            .to_vec()
    }

    #[tokio::test]
    async fn json_response_emits_status_content_type_and_serialized_body() {
        let res = json_response(&serde_json::json!({"id": 1, "name": "Alice"}), 201);
        assert_eq!(res.status(), StatusCode::CREATED);
        assert_eq!(
            res.headers()
                .get(axum::http::header::CONTENT_TYPE)
                .unwrap()
                .to_str()
                .unwrap(),
            "application/json"
        );
        let body = body_bytes(res).await;
        let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(v, serde_json::json!({"id": 1, "name": "Alice"}));
    }

    #[tokio::test]
    async fn json_ok_is_200() {
        let res = json_ok(&serde_json::json!({"hello": "world"}));
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn json_bad_request_is_400() {
        let res = json_bad_request(&serde_json::json!({"error": "validation"}));
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        let body = body_bytes(res).await;
        let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(v["error"], "validation");
    }

    #[tokio::test]
    async fn json_response_status_below_100_falls_back_to_200() {
        let res = json_response(&serde_json::json!({}), 42);
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn json_response_serializes_arbitrary_struct() {
        #[derive(serde::Serialize)]
        struct Out {
            ok: bool,
            count: u32,
        }
        let res = json_response(&Out { ok: true, count: 7 }, 200);
        assert_eq!(res.status(), StatusCode::OK);
        let body = body_bytes(res).await;
        let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(v, serde_json::json!({"ok": true, "count": 7}));
    }

    #[tokio::test]
    async fn json_unauthorized_is_401() {
        let res = json_unauthorized(&serde_json::json!({"error": "login required"}));
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn json_forbidden_is_403() {
        let res = json_forbidden(&serde_json::json!({"error": "no access"}));
        assert_eq!(res.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn json_not_found_is_404() {
        let res = json_not_found(&serde_json::json!({"error": "no such item"}));
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn json_server_error_is_500() {
        let res = json_server_error(&serde_json::json!({"error": "internal"}));
        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn json_error_helpers_all_emit_application_json_content_type() {
        let cases: Vec<axum::response::Response> = vec![
            json_unauthorized(&serde_json::json!({})),
            json_forbidden(&serde_json::json!({})),
            json_not_found(&serde_json::json!({})),
            json_server_error(&serde_json::json!({})),
        ];
        for res in cases {
            let ct = res
                .headers()
                .get(axum::http::header::CONTENT_TYPE)
                .unwrap()
                .to_str()
                .unwrap()
                .to_owned();
            assert_eq!(ct, "application/json");
        }
    }

    // ---- HTML / text responses -----------------------------------------

    #[tokio::test]
    async fn html_response_emits_html_content_type_and_body() {
        let res = html_response("<h1>Hello</h1>", 200);
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()
                .get(axum::http::header::CONTENT_TYPE)
                .unwrap()
                .to_str()
                .unwrap(),
            "text/html; charset=utf-8"
        );
        let body = body_bytes(res).await;
        assert_eq!(String::from_utf8(body).unwrap(), "<h1>Hello</h1>");
    }

    #[tokio::test]
    async fn html_response_respects_custom_status() {
        let res = html_response("<h1>Down</h1>", 503);
        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn html_response_invalid_status_falls_back_to_200() {
        let res = html_response("x", 42);
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn text_response_emits_text_content_type_and_body() {
        let res = text_response("ok", 200);
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()
                .get(axum::http::header::CONTENT_TYPE)
                .unwrap()
                .to_str()
                .unwrap(),
            "text/plain; charset=utf-8"
        );
        let body = body_bytes(res).await;
        assert_eq!(String::from_utf8(body).unwrap(), "ok");
    }

    #[tokio::test]
    async fn text_response_respects_custom_status() {
        let res = text_response("teapot", 418);
        assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
    }

    // ---- file responses ------------------------------------------------

    #[tokio::test]
    async fn file_response_emits_attachment_disposition_with_filename() {
        let res = file_response(b"a,b,c\n1,2,3\n".to_vec(), "report.csv", "text/csv");
        assert_eq!(res.status(), StatusCode::OK);
        let cd = res
            .headers()
            .get(axum::http::header::CONTENT_DISPOSITION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned();
        assert_eq!(cd, r#"attachment; filename="report.csv""#);
    }

    #[tokio::test]
    async fn file_response_passes_through_content_type() {
        let res = file_response(b"PDF...".to_vec(), "x.pdf", "application/pdf");
        let ct = res
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned();
        assert_eq!(ct, "application/pdf");
    }

    #[tokio::test]
    async fn file_response_returns_body_bytes_verbatim() {
        let body = b"raw bytes \xFF\xFE here".to_vec();
        let res = file_response(body.clone(), "x.bin", "application/octet-stream");
        let got = body_bytes(res).await;
        assert_eq!(got, body);
    }

    #[tokio::test]
    async fn file_response_sanitizes_quotes_in_filename() {
        let res = file_response(b"".to_vec(), r#"a"; injected=bad.txt"#, "text/plain");
        let cd = res
            .headers()
            .get(axum::http::header::CONTENT_DISPOSITION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned();
        // The only double quotes left must be the pair that delimits the
        // filename parameter: an embedded quote would end the quoted-string
        // early and let the rest of the name be read as extra parameters.
        assert_eq!(cd.matches('"').count(), 2, "got: {cd}");
        assert!(!cd.contains("\";"), "got: {cd}");
        assert!(cd.contains("filename=\"a_;"), "got: {cd}");
    }

    #[tokio::test]
    async fn file_response_sanitizes_crlf_in_filename_no_header_splitting() {
        let res = file_response(b"".to_vec(), "a\r\nX-Hack: y.txt", "text/plain");
        let cd = res
            .headers()
            .get(axum::http::header::CONTENT_DISPOSITION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned();
        assert!(!cd.contains('\r'));
        assert!(!cd.contains('\n'));
    }

    #[tokio::test]
    async fn file_response_invalid_content_type_falls_back_to_octet_stream() {
        let res = file_response(b"".to_vec(), "x.bin", "text/x\0bad");
        let ct = res
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned();
        assert_eq!(ct, "application/octet-stream");
    }
}