use crate::bindings::vtx::api::vtx_auth_types::UserContext;
use crate::error::{VtxError, VtxResult};
/// Read-only view over a request's headers with helpers for extracting the
/// credentials carried in them.
///
/// The headers are borrowed, not copied; values returned by the accessors are
/// slices of the borrowed data.
pub struct AuthRequest<'a> {
    /// Borrowed `(name, value)` pairs, searched front to back.
    headers: &'a [(String, String)],
}

impl<'a> AuthRequest<'a> {
    /// Wraps a borrowed slice of `(name, value)` header pairs.
    pub fn new(headers: &'a [(String, String)]) -> Self {
        Self { headers }
    }

    /// Returns the value of the first header whose name equals `key`,
    /// ignoring ASCII case, or `None` when no header matches.
    ///
    /// The comparison is done in place, so a lookup performs no allocation.
    pub fn header(&self, key: &str) -> Option<&str> {
        self.headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(key))
            .map(|(_, value)| value.as_str())
    }

    /// Like [`header`](Self::header), but treats a missing header as a failure.
    ///
    /// # Errors
    ///
    /// Returns `VtxError::AuthDenied(401)` when no header named `key` exists.
    pub fn require_header(&self, key: &str) -> VtxResult<&str> {
        self.header(key).ok_or(VtxError::AuthDenied(401))
    }

    /// Returns the text following `Bearer ` in the `Authorization` header.
    ///
    /// The scheme is matched ignoring ASCII case (`Bearer`, `bearer`,
    /// `BEARER`, ...) and must be followed by a single space. The remainder is
    /// returned as-is, so `"Bearer "` yields `Some("")`. Returns `None` when
    /// the header is absent or uses a different scheme.
    pub fn bearer_token(&self) -> Option<&str> {
        Self::strip_scheme(self.header("Authorization")?, "Bearer")
    }

    /// Like [`bearer_token`](Self::bearer_token), but treats absence as a failure.
    ///
    /// # Errors
    ///
    /// Returns `VtxError::AuthDenied(401)` when there is no `Authorization`
    /// header or it does not use the `Bearer` scheme.
    pub fn require_bearer_token(&self) -> VtxResult<&str> {
        self.bearer_token().ok_or(VtxError::AuthDenied(401))
    }

    /// Returns the text following `Basic ` in the `Authorization` header.
    ///
    /// The scheme is matched ignoring ASCII case and must be followed by a
    /// single space. The remainder is returned without decoding. Returns
    /// `None` when the header is absent or uses a different scheme.
    pub fn basic_auth(&self) -> Option<&str> {
        Self::strip_scheme(self.header("Authorization")?, "Basic")
    }

    /// Returns what follows `"<scheme> "` in `value` when `value` starts with
    /// `scheme` (compared ignoring ASCII case) and one space; `None` otherwise.
    fn strip_scheme<'v>(value: &'v str, scheme: &str) -> Option<&'v str> {
        let split = scheme.len();
        // `get` yields `None` rather than panicking when `value` is too short
        // or when `split` falls inside a multi-byte character.
        let head = value.get(..split)?;
        let rest = value.get(split..)?;
        if head.eq_ignore_ascii_case(scheme) {
            rest.strip_prefix(' ')
        } else {
            None
        }
    }
}
/// Chained builder that collects identity data and turns it into a
/// [`UserContext`].
pub struct UserBuilder {
    /// Copied verbatim into `UserContext::user_id`.
    user_id: String,
    /// Copied verbatim into `UserContext::username`.
    username: String,
    /// Group names in the order they were added.
    groups: Vec<String>,
    /// Metadata entries; serialized to a JSON string by [`build`](Self::build).
    metadata: serde_json::Map<String, serde_json::Value>,
}

impl UserBuilder {
    /// Starts a builder for the given id and name, with no groups and no
    /// metadata.
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        let user_id = id.into();
        let username = name.into();
        Self {
            user_id,
            username,
            groups: Vec::new(),
            metadata: serde_json::Map::new(),
        }
    }

    /// Appends one group name and hands the builder back for chaining.
    pub fn group(mut self, group: impl Into<String>) -> Self {
        let name: String = group.into();
        self.groups.push(name);
        self
    }

    /// Stores `value` under `key` in the metadata map.
    ///
    /// The value is converted with `serde_json::to_value`; when that
    /// conversion fails the entry is skipped and the builder is returned
    /// unchanged. Inserting under an existing key replaces the earlier value.
    pub fn meta<V: serde::Serialize>(mut self, key: &str, value: V) -> Self {
        let json = match serde_json::to_value(value) {
            Ok(json) => json,
            // Best effort: a value that cannot be serialized is dropped.
            Err(_) => return self,
        };
        self.metadata.insert(key.to_owned(), json);
        self
    }

    /// Consumes the builder and produces the final [`UserContext`].
    ///
    /// The metadata map is rendered as a JSON string; if rendering fails the
    /// string `"{}"` is used instead.
    pub fn build(self) -> UserContext {
        let Self {
            user_id,
            username,
            groups,
            metadata,
        } = self;
        let metadata = match serde_json::to_string(&metadata) {
            Ok(json) => json,
            Err(_) => String::from("{}"),
        };
        UserContext {
            user_id,
            username,
            groups,
            metadata,
        }
    }
}
/// Conversion of an authentication outcome into either the authenticated
/// [`UserContext`] or a bare numeric status code.
pub trait IntoAuthResult {
    /// Consumes `self`, yielding the user context on success or a `u16`
    /// status code on failure.
    fn into_auth_result(self) -> Result<UserContext, u16>;
}

impl IntoAuthResult for VtxResult<UserContext> {
    /// Passes `Ok` through untouched and replaces each error with a code:
    /// `AuthDenied` keeps the code it carries, `PermissionDenied` becomes 403,
    /// `NotFound` becomes 404, and the remaining variants become 500.
    fn into_auth_result(self) -> Result<UserContext, u16> {
        // The match is intentionally exhaustive (no `_` arm) so that adding a
        // variant to `VtxError` forces a decision here.
        self.map_err(|failure| match failure {
            VtxError::AuthDenied(status) => status,
            VtxError::PermissionDenied(_) => 403,
            VtxError::NotFound(_) => 404,
            VtxError::DatabaseError(_)
            | VtxError::SerializationError(_)
            | VtxError::Internal(_) => 500,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::{AuthRequest, IntoAuthResult, UserBuilder};
    use crate::error::VtxError;
    use serde::Serialize;

    // Header names must match regardless of the case used by either side.
    #[test]
    fn header_lookup_is_case_insensitive() {
        let headers = vec![
            ("authorization".to_string(), "Bearer token".to_string()),
            ("X-Trace-Id".to_string(), "abc".to_string()),
        ];
        let req = AuthRequest::new(&headers);
        assert_eq!(req.header("Authorization"), Some("Bearer token"));
        assert_eq!(req.header("x-trace-id"), Some("abc"));
        assert_eq!(req.header("missing"), None);
    }

    // Each accessor only recognizes its own scheme.
    #[test]
    fn bearer_and_basic_auth_parsing() {
        let headers = vec![("Authorization".to_string(), "Bearer abc".to_string())];
        let req = AuthRequest::new(&headers);
        assert_eq!(req.bearer_token(), Some("abc"));
        assert!(req.basic_auth().is_none());
        let headers = vec![(
            "Authorization".to_string(),
            "Basic Zm9vOmJhcg==".to_string(),
        )];
        let req = AuthRequest::new(&headers);
        assert_eq!(req.basic_auth(), Some("Zm9vOmJhcg=="));
        assert!(req.bearer_token().is_none());
    }

    // The all-lowercase spelling of each scheme is accepted as well.
    #[test]
    fn lowercase_scheme_prefixes_are_accepted() {
        let headers = vec![("Authorization".to_string(), "bearer abc".to_string())];
        let req = AuthRequest::new(&headers);
        assert_eq!(req.bearer_token(), Some("abc"));
        let headers = vec![(
            "Authorization".to_string(),
            "basic Zm9vOmJhcg==".to_string(),
        )];
        let req = AuthRequest::new(&headers);
        assert_eq!(req.basic_auth(), Some("Zm9vOmJhcg=="));
    }

    // A scheme name that is not followed by a space is not a match.
    #[test]
    fn scheme_without_separator_is_rejected() {
        let headers = vec![("Authorization".to_string(), "Bearerabc".to_string())];
        let req = AuthRequest::new(&headers);
        assert!(req.bearer_token().is_none());
        let headers = vec![("Authorization".to_string(), "Basic".to_string())];
        let req = AuthRequest::new(&headers);
        assert!(req.basic_auth().is_none());
    }

    #[test]
    fn require_header_returns_value_or_401() {
        let headers = vec![("X-Api-Key".to_string(), "secret".to_string())];
        let req = AuthRequest::new(&headers);
        assert!(matches!(req.require_header("x-api-key"), Ok("secret")));
        assert!(matches!(
            req.require_header("missing"),
            Err(VtxError::AuthDenied(401))
        ));
    }

    #[test]
    fn require_bearer_token_rejects_missing_or_invalid() {
        // Wrong scheme.
        let headers = vec![("Authorization".to_string(), "Token abc".to_string())];
        let req = AuthRequest::new(&headers);
        let err = req.require_bearer_token().unwrap_err();
        assert!(matches!(err, VtxError::AuthDenied(401)));
        // No Authorization header at all.
        let headers: Vec<(String, String)> = Vec::new();
        let req = AuthRequest::new(&headers);
        let err = req.require_bearer_token().unwrap_err();
        assert!(matches!(err, VtxError::AuthDenied(401)));
    }

    // A metadata value that fails to serialize is dropped; the rest survive.
    #[test]
    fn user_builder_ignores_unserializable_meta() {
        struct BadSerialize;
        impl Serialize for BadSerialize {
            fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                Err(serde::ser::Error::custom("nope"))
            }
        }
        let ctx = UserBuilder::new("u1", "tester")
            .group("admin")
            .meta("good", 123)
            .meta("bad", BadSerialize)
            .build();
        let meta: serde_json::Value =
            serde_json::from_str(&ctx.metadata).expect("valid metadata json");
        assert_eq!(meta["good"], 123);
        assert!(meta.get("bad").is_none());
    }

    #[test]
    fn into_auth_result_maps_errors_to_status_codes() {
        let err = Err(VtxError::PermissionDenied("nope".to_string())).into_auth_result();
        assert!(matches!(err, Err(403)));
        let err = Err(VtxError::NotFound("missing".to_string())).into_auth_result();
        assert!(matches!(err, Err(404)));
        let err = Err(VtxError::Internal("boom".to_string())).into_auth_result();
        assert!(matches!(err, Err(500)));
    }

    // `AuthDenied` carries its own code, and success is left untouched.
    #[test]
    fn into_auth_result_passes_through_auth_code_and_ok() {
        let err = Err(VtxError::AuthDenied(429)).into_auth_result();
        assert!(matches!(err, Err(429)));
        let ctx = UserBuilder::new("u1", "tester").build();
        let ok = Ok::<_, VtxError>(ctx).into_auth_result();
        assert!(matches!(ok, Ok(ref ctx) if ctx.user_id == "u1"));
    }
}