use bytes::Bytes;
use http_body_util::Full;
use hyper::body::Incoming;
use hyper::header::AUTHORIZATION;
use hyper::{Request, Response, StatusCode};
use crate::Config;
pub fn check_auth(req: &Request<Incoming>, config: &Config) -> Result<(), Response<Full<Bytes>>> {
let (Some(want_user), Some(want_pass)) = (&config.user, &config.pass) else {
return Ok(());
};
let Some(auth) = req.headers().get(AUTHORIZATION) else {
return Err(auth_required());
};
let Ok(value) = auth.to_str() else {
return Err(auth_required());
};
if let Some(encoded) = value.strip_prefix("Basic ") {
if let Ok(decoded) = base64_decode(encoded) {
if let Some((u, p)) = decoded.split_once(':') {
if u == want_user && p == want_pass {
return Ok(());
}
}
}
}
Err(auth_required())
}
/// Reports whether `user` and `pass` match the credentials stored in `config`.
///
/// Returns `true` unconditionally when either `config.user` or `config.pass`
/// is unset, i.e. when authentication is disabled.
///
/// Both fields are compared without an early exit on the first differing byte,
/// and both comparisons always run. Response timing therefore does not reveal
/// how much of a guess was correct or which field was wrong. Lengths are not
/// hidden, and this is best-effort: the optimizer gives no hard constant-time
/// guarantee.
pub fn check_credentials(user: &str, pass: &str, config: &Config) -> bool {
    /// Compares two byte strings, accumulating differences instead of
    /// returning at the first mismatch.
    fn ct_eq(a: &[u8], b: &[u8]) -> bool {
        // Unequal lengths can never match; bailing out here leaks only the length.
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
    }

    match (&config.user, &config.pass) {
        (Some(want_u), Some(want_p)) => {
            // Non-short-circuiting `&` so the password check runs even when the
            // user name is already known to be wrong.
            ct_eq(user.as_bytes(), want_u.as_bytes()) & ct_eq(pass.as_bytes(), want_p.as_bytes())
        }
        _ => true,
    }
}
/// Builds the `401 Unauthorized` response that challenges the client for HTTP
/// Basic credentials in the `gitrub` realm.
pub fn auth_required() -> Response<Full<Bytes>> {
    let body = Full::new(Bytes::from("Authentication required"));
    let builder = Response::builder()
        .status(StatusCode::UNAUTHORIZED)
        .header("WWW-Authenticate", "Basic realm=\"gitrub\"");
    // The status code and header are static, valid values, so building cannot fail.
    builder.body(body).unwrap()
}
/// Decodes standard base64 (RFC 4648 §4, `+`/`/` alphabet) into a UTF-8 `String`.
///
/// Trailing `=` padding is optional. When present, there may be at most two
/// padding characters, and they must bring the input to a multiple of four
/// characters.
///
/// # Errors
///
/// Returns `Err(())` if the input:
/// - contains a character outside the base64 alphabet (including `=` anywhere
///   other than the final padding),
/// - has malformed padding,
/// - has a length no valid encoding can produce, or
/// - decodes to bytes that are not valid UTF-8.
pub fn base64_decode(input: &str) -> Result<String, ()> {
    /// Maps one base64 alphabet byte to its 6-bit value.
    fn sextet(b: u8) -> Option<u32> {
        let v = match b {
            b'A'..=b'Z' => b - b'A',
            b'a'..=b'z' => b - b'a' + 26,
            b'0'..=b'9' => b - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        Some(u32::from(v))
    }

    let bytes = input.as_bytes();
    // Split off at most two trailing padding characters; any other '=' is
    // rejected below as a non-alphabet byte.
    let pad = bytes.iter().rev().take(2).take_while(|&&b| b == b'=').count();
    let data = &bytes[..bytes.len() - pad];
    // Padding only ever completes a final group of four characters.
    if pad > 0 && bytes.len() % 4 != 0 {
        return Err(());
    }
    // A lone trailing character carries only 6 bits, less than one byte.
    if data.len() % 4 == 1 {
        return Err(());
    }

    let mut out = Vec::with_capacity(data.len() / 4 * 3 + 2);
    let mut buf: u32 = 0;
    let mut bits: u32 = 0;
    for &b in data {
        buf = (buf << 6) | sextet(b).ok_or(())?;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            // `buf` holds at most `bits + 8` significant bits here, so the
            // shifted value fits in a byte and the cast does not truncate.
            out.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }
    String::from_utf8(out).map_err(|_| ())
}