use base64::prelude::*;
use dragonfly_client_core::{
error::{ErrorType, OrErr},
Error, Result,
};
use http::header::{self, HeaderMap};
/// Username/password pair that incoming requests are checked against using
/// HTTP Basic authentication (see [`Credentials::verify`]).
pub struct Credentials {
    /// Username a request must present.
    pub username: String,

    /// Password a request must present.
    pub password: String,
}

impl Credentials {
    /// Creates credentials from the expected username and password.
    pub fn new(username: &str, password: &str) -> Credentials {
        Self {
            username: username.to_string(),
            password: password.to_string(),
        }
    }

    /// Verifies the `Authorization` header in `header` against these credentials.
    ///
    /// The header value must have the form `<scheme> <payload>`, where `<scheme>`
    /// is `Basic` (matched case-insensitively) and `<payload>` is the standard
    /// base64 encoding of `username:password`.
    ///
    /// # Errors
    ///
    /// - [`Error::Unauthorized`] if the header is missing, has no space separating
    ///   the scheme from the payload, uses a scheme other than `Basic`, the decoded
    ///   payload contains no `:`, or the username/password do not match.
    /// - An error built with [`ErrorType::ParseError`] if the header value cannot
    ///   be read as a string (`HeaderValue::to_str` fails), the payload is not
    ///   valid base64, or the decoded bytes are not valid UTF-8.
    pub fn verify(&self, header: &HeaderMap) -> Result<()> {
        let Some(auth_header) = header.get(header::AUTHORIZATION) else {
            return Err(Error::Unauthorized);
        };

        let auth_header = auth_header.to_str().or_err(ErrorType::ParseError)?;

        // A value with no scheme/payload separator (e.g. `Basic` or `garbage`)
        // carries no credentials at all, so it must be rejected. Falling through
        // here would authorize the request without checking anything.
        let Some((scheme, payload)) = auth_header.split_once(' ') else {
            return Err(Error::Unauthorized);
        };

        if !scheme.eq_ignore_ascii_case("basic") {
            return Err(Error::Unauthorized);
        }

        let decoded = String::from_utf8(
            BASE64_STANDARD
                .decode(payload)
                .or_err(ErrorType::ParseError)?,
        )
        .or_err(ErrorType::ParseError)?;

        // Only the first `:` separates the username from the password, so the
        // password itself may contain `:`.
        let Some((username, password)) = decoded.split_once(':') else {
            return Err(Error::Unauthorized);
        };

        // NOTE(review): string `!=` stops at the first differing byte, so this
        // comparison is not constant-time; consider a constant-time comparison
        // if timing side channels matter for this endpoint.
        if username != self.username || password != self.password {
            return Err(Error::Unauthorized);
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::header::HeaderValue;

    /// Builds a header map whose only entry is `Authorization: <value>`.
    fn authorization(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::AUTHORIZATION, HeaderValue::from_static(value));
        headers
    }

    /// The credentials every test verifies against.
    fn credentials() -> Credentials {
        Credentials::new("user", "pass")
    }

    #[test]
    fn test_verify_no_auth_header() {
        let outcome = credentials().verify(&HeaderMap::new());
        assert!(matches!(outcome, Err(Error::Unauthorized)));
    }

    #[test]
    fn test_verify_invalid_auth_type() {
        let outcome = credentials().verify(&authorization("Bearer some_token"));
        assert!(matches!(outcome, Err(Error::Unauthorized)));
    }

    #[test]
    fn test_verify_invalid_base64() {
        let failure = credentials()
            .verify(&authorization("Basic invalid_base64"))
            .unwrap_err();
        let expected = format!(
            "{:?} cause: Invalid symbol 95, offset 7.",
            ErrorType::ParseError
        );
        assert_eq!(failure.to_string(), expected);
    }

    #[test]
    fn test_verify_invalid_format() {
        // "dXNlcg==" decodes to "user", which lacks the ':' separator.
        let outcome = credentials().verify(&authorization("Basic dXNlcg=="));
        assert!(matches!(outcome, Err(Error::Unauthorized)));
    }

    #[test]
    fn test_verify_incorrect_credentials() {
        // Decodes to "user:pass_error".
        let outcome = credentials().verify(&authorization("Basic dXNlcjpwYXNzX2Vycm9y"));
        assert!(matches!(outcome, Err(Error::Unauthorized)));
    }

    #[test]
    fn test_verify_correct_credentials() {
        // Decodes to "user:pass".
        let outcome = credentials().verify(&authorization("Basic dXNlcjpwYXNz"));
        assert!(outcome.is_ok());
    }
}