use crate::request::RequestInfo;
use http::Method;
/// The kind of response [`decide`] selects for a request.
///
/// This type only names the outcome; turning it into an actual HTTP response
/// happens elsewhere.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ResponseShape {
    /// Selected for non-Inertia requests when `csr_only` is off.
    Html,
    /// Selected for Inertia requests that need no reload, and for non-Inertia
    /// requests when `csr_only` is on.
    Json,
    /// Selected for [`Redirect::Internal`]; the variant name matches HTTP
    /// `303 See Other`.
    SeeOther {
        /// Target of the redirect.
        location: String,
    },
    /// Selected for [`Redirect::External`] and for Inertia `GET` requests whose
    /// client version does not equal the server version.
    InertiaLocation {
        /// URL the client is told to visit.
        location: String,
    },
}
/// Everything [`decide`] looks at when choosing a [`ResponseShape`].
#[derive(Clone, Debug)]
pub struct DecisionInputs<'a> {
    /// The request being answered; `decide` reads its `method`, `is_inertia`,
    /// `client_version` and `url` fields.
    pub req: &'a RequestInfo,
    /// Version string compared against `req.client_version` for Inertia `GET`
    /// requests.
    pub server_version: &'a str,
    /// A redirect to issue, if any. When present it takes precedence over
    /// every other rule.
    pub redirect: Option<Redirect>,
    /// When `true`, non-Inertia requests get [`ResponseShape::Json`] instead of
    /// [`ResponseShape::Html`].
    // NOTE(review): "csr" presumably stands for client-side rendering — confirm.
    pub csr_only: bool,
}
/// A redirect requested by the handler, tagged by how [`decide`] should
/// express it.
#[derive(Clone, Debug)]
pub enum Redirect {
    /// Redirect target that `decide` maps to [`ResponseShape::SeeOther`].
    // NOTE(review): presumably a URL served by this application — confirm.
    Internal(String),
    /// Redirect target that `decide` maps to
    /// [`ResponseShape::InertiaLocation`].
    // NOTE(review): presumably a URL outside this application — confirm.
    External(String),
}
pub fn decide(input: DecisionInputs<'_>) -> ResponseShape {
if let Some(r) = input.redirect {
return match r {
Redirect::External(loc) => ResponseShape::InertiaLocation { location: loc },
Redirect::Internal(loc) => {
if matches!(
input.req.method,
Method::POST | Method::PUT | Method::PATCH | Method::DELETE
) {
ResponseShape::SeeOther { location: loc }
} else {
ResponseShape::SeeOther { location: loc }
}
}
};
}
if input.req.is_inertia {
if input.req.method == Method::GET
&& input.req.client_version.as_deref() != Some(input.server_version)
{
return ResponseShape::InertiaLocation {
location: input.req.url.clone(),
};
}
return ResponseShape::Json;
}
if input.csr_only {
ResponseShape::Json
} else {
ResponseShape::Html
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderMap;

    /// Builds a `RequestInfo` with empty headers, then overrides the Inertia
    /// flag and client version so each test controls them directly.
    fn request(method: Method, url: &str, is_inertia: bool, version: Option<&str>) -> RequestInfo {
        let headers = HeaderMap::new();
        let mut built = RequestInfo::from_parts(method, url.to_string(), &headers);
        built.is_inertia = is_inertia;
        built.client_version = version.map(String::from);
        built
    }

    /// Runs `decide` without repeating the `DecisionInputs` literal in every
    /// test.
    fn shape_for(
        req: &RequestInfo,
        server_version: &str,
        redirect: Option<Redirect>,
        csr_only: bool,
    ) -> ResponseShape {
        decide(DecisionInputs {
            req,
            server_version,
            redirect,
            csr_only,
        })
    }

    #[test]
    fn plain_get_returns_html() {
        let incoming = request(Method::GET, "/", false, None);
        let shape = shape_for(&incoming, "v1", None, false);
        assert_eq!(shape, ResponseShape::Html);
    }

    #[test]
    fn csr_only_returns_json_for_plain_get() {
        let incoming = request(Method::GET, "/", false, None);
        let shape = shape_for(&incoming, "v1", None, true);
        assert_eq!(shape, ResponseShape::Json);
    }

    #[test]
    fn xhr_with_matching_version_returns_json() {
        let incoming = request(Method::GET, "/", true, Some("v1"));
        let shape = shape_for(&incoming, "v1", None, false);
        assert_eq!(shape, ResponseShape::Json);
    }

    #[test]
    fn xhr_get_with_stale_version_returns_409_at_same_url() {
        let incoming = request(Method::GET, "/users", true, Some("old"));
        let shape = shape_for(&incoming, "new", None, false);
        let expected = ResponseShape::InertiaLocation {
            location: String::from("/users"),
        };
        assert_eq!(shape, expected);
    }

    #[test]
    fn xhr_post_with_stale_version_still_returns_json() {
        let incoming = request(Method::POST, "/users", true, Some("old"));
        let shape = shape_for(&incoming, "new", None, false);
        assert_eq!(shape, ResponseShape::Json);
    }

    #[test]
    fn internal_redirect_from_post_returns_303() {
        let incoming = request(Method::POST, "/users", true, Some("v1"));
        let redirect = Redirect::Internal(String::from("/users/42"));
        let shape = shape_for(&incoming, "v1", Some(redirect), false);
        let expected = ResponseShape::SeeOther {
            location: String::from("/users/42"),
        };
        assert_eq!(shape, expected);
    }

    #[test]
    fn external_redirect_returns_inertia_location() {
        let incoming = request(Method::GET, "/oauth", true, Some("v1"));
        let redirect = Redirect::External(String::from("https://example.com/oauth"));
        let shape = shape_for(&incoming, "v1", Some(redirect), false);
        let expected = ResponseShape::InertiaLocation {
            location: String::from("https://example.com/oauth"),
        };
        assert_eq!(shape, expected);
    }
}