use http::HeaderMap;
use parlov_core::ResponseSurface;
use super::auth_types::{AuthBlockConfidence, LoginRedirectSignal};
/// Path prefixes (lowercase) that identify a login / sign-in / SSO endpoint.
///
/// A location path matches when it equals one of these or continues with `/`
/// after one of them.
const LOGIN_PATHS: &[&str] = &[
    "/login",
    "/signin",
    "/sign-in",
    "/auth",
    "/authenticate",
    "/session/new",
    "/users/sign_in",
    "/oauth/authorize",
    "/sso",
];

/// Query-parameter names whose presence suggests an OAuth authorization request.
const OAUTH_QUERY_PARAMS: &[&str] = &[
    "client_id",
    "redirect_uri",
    "response_type",
    "scope",
    "state",
];

/// Lowercase substrings that mark a `Set-Cookie` as session / anti-forgery related.
const SESSION_COOKIE_TOKENS: &[&str] = &[
    "session",
    "csrf",
    "xsrf",
    "state",
    "nonce",
];
/// Recognises a redirect that bounces the client to a login / SSO page.
///
/// Returns `None` unless all of the following hold:
/// - the status is `302`, `303`, `307` or `308`;
/// - a `Location` header is present and convertible to `&str`;
/// - the location targets a known login path or carries an OAuth
///   authorization query parameter (see `classify_location`).
///
/// The confidence is [`AuthBlockConfidence::Strong`] when the location carries
/// OAuth parameters or the response sets a session-like cookie (see
/// `has_session_cookie`), and [`AuthBlockConfidence::Medium`] when only the
/// path matched.
#[must_use]
pub fn is_login_redirect(res: &ResponseSurface) -> Option<LoginRedirectSignal> {
    let is_redirect = matches!(res.status.as_u16(), 302 | 303 | 307 | 308);
    if !is_redirect {
        return None;
    }

    let location = res.headers.get(http::header::LOCATION)?.to_str().ok()?;
    let (login_path, oauth_params) = classify_location(location);
    if !(login_path || oauth_params) {
        return None;
    }

    let sets_session_cookie = has_session_cookie(&res.headers);
    let confidence = match (sets_session_cookie, oauth_params) {
        (false, false) => AuthBlockConfidence::Medium,
        _ => AuthBlockConfidence::Strong,
    };

    Some(LoginRedirectSignal {
        location: location.to_owned(),
        confidence,
    })
}
/// Classifies a `Location` header value.
///
/// Returns `(path_match, oauth_match)`:
/// - `path_match` is `true` when the ASCII-lowercased path equals one of
///   [`LOGIN_PATHS`] or continues with `/` after one (e.g. `/login/` or
///   `/sso/saml`), but not for `/loginfoo`.
/// - `oauth_match` is `true` when the query string contains at least one of
///   [`OAUTH_QUERY_PARAMS`] as a parameter name.
fn classify_location(loc: &str) -> (bool, bool) {
    let (path, query) = split_location(loc);
    let lower_path = path.to_ascii_lowercase();
    // Segment-wise prefix check without allocating a `"{p}/"` string for every
    // candidate: the remainder after the login prefix must be empty or start a
    // new path segment.
    let path_match = LOGIN_PATHS.iter().any(|p| {
        lower_path
            .strip_prefix(*p)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('/'))
    });
    let oauth_match =
        query.is_some_and(|q| OAUTH_QUERY_PARAMS.iter().any(|p| query_has_param(q, p)));
    (path_match, oauth_match)
}
/// Splits a `Location` header value into its path and optional query string
/// (query returned without the leading `?`).
///
/// Accepts absolute URLs (`https://host/path?query`), network-path references
/// (`//host/path?query`) and relative references (`/path?query`, `path?query`).
/// For the first two forms the scheme and authority are discarded. The
/// authority ends at the first `/` or `?` (RFC 3986 §3.2), so a `/` that
/// appears only inside the query (e.g. `?redirect_uri=https://app/cb`) is not
/// mistaken for the start of the path. Any `#fragment` is dropped first so it
/// never leaks into the path or query.
///
/// The returned path is empty when an absolute URL has no path component.
fn split_location(loc: &str) -> (&str, Option<&str>) {
    // `#` cannot appear unescaped before the fragment, so everything after the
    // first one is fragment data and irrelevant here.
    let without_fragment = loc.split_once('#').map_or(loc, |(before, _)| before);
    let path_and_query = strip_authority(without_fragment);
    match path_and_query.split_once('?') {
        Some((path, query)) => (path, Some(query)),
        None => (path_and_query, None),
    }
}

/// Removes a leading `scheme://authority` or `//authority` from `reference`.
///
/// Returns the remaining path-and-query, which may be empty. Relative
/// references are returned unchanged.
fn strip_authority(reference: &str) -> &str {
    let after_slashes = if has_scheme_prefix(reference) {
        // `has_scheme_prefix` guarantees the first `://` directly follows the scheme.
        reference.find("://").map(|idx| &reference[idx + 3..])
    } else {
        reference.strip_prefix("//")
    };
    match after_slashes {
        Some(rest) => {
            let end = rest
                .find(|c: char| matches!(c, '/' | '?'))
                .unwrap_or(rest.len());
            &rest[end..]
        }
        None => reference,
    }
}

/// Returns `true` when `loc` starts with an RFC 3986 scheme
/// (`ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`) immediately followed by `://`.
fn has_scheme_prefix(loc: &str) -> bool {
    let bytes = loc.as_bytes();
    let Some(first) = bytes.first() else {
        return false;
    };
    if !first.is_ascii_alphabetic() {
        return false;
    }
    let is_scheme_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'+' || b == b'-' || b == b'.';
    match bytes.iter().position(|&b| !is_scheme_byte(b)) {
        Some(end) => bytes[end..].starts_with(b"://"),
        None => false,
    }
}
/// Reports whether the raw query string `query` (without the leading `?`)
/// contains a parameter whose key is exactly `name`.
///
/// Pairs are separated by `&`. The key is everything before the first `=`, or
/// the whole pair when it has no `=`. Keys are compared byte-for-byte, with no
/// percent-decoding or case folding.
fn query_has_param(query: &str, name: &str) -> bool {
    query
        .split('&')
        .map(|pair| match pair.split_once('=') {
            Some((key, _)) => key,
            None => pair,
        })
        .any(|key| key == name)
}
/// Returns `true` when any `Set-Cookie` header sets a cookie whose *name*
/// contains one of [`SESSION_COOKIE_TOKENS`] (ASCII case-insensitive).
///
/// Only the cookie name is inspected: the text before the first `;`, then
/// before the first `=` (RFC 6265 §5.2). Matching against the whole header
/// line would also hit cookie values and attributes such as `Path=/session`
/// or `Domain=state.example`, which wrongly inflates the confidence. Header
/// values that fail `HeaderValue::to_str` are skipped.
fn has_session_cookie(headers: &HeaderMap) -> bool {
    headers
        .get_all(http::header::SET_COOKIE)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .any(|cookie| {
            let name_value = cookie.split_once(';').map_or(cookie, |(pair, _)| pair);
            let name = name_value.split_once('=').map_or(name_value, |(name, _)| name);
            let lower = name.to_ascii_lowercase();
            SESSION_COOKIE_TOKENS.iter().any(|tok| lower.contains(tok))
        })
}