use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
/// Authentication settings that are rendered as a single HTTP request header.
///
/// With serde this is an internally tagged enum: the `type` field selects the
/// variant and variant names are written in snake_case (`basic`, `bearer`,
/// `api_key`, `header`).
///
/// NOTE(review): the derived `Debug` prints `password`, `token`, `key` and
/// `value` verbatim, so formatting this type with `{:?}` exposes the secret.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthConfig {
/// Username/password pair, sent as `Authorization: Basic <base64 of "username:password">`.
Basic {
/// User name placed before the `:` separator.
username: String,
/// Password placed after the `:` separator.
password: String,
},
/// Token sent as `Authorization: Bearer <token>`.
Bearer {
/// Token text appended after the `Bearer ` prefix.
token: String,
},
/// API key sent as the value of a caller-chosen header.
ApiKey {
/// Name of the header that carries the key.
header: String,
/// Key sent as the header value, without any prefix.
key: String,
},
/// Arbitrary header name/value pair sent as given.
Header {
/// Header name.
name: String,
/// Header value.
value: String,
},
}
impl AuthConfig {
pub fn apply_to_headers(&self, headers: &mut HeaderMap) -> Result<()> {
let (name, value) = self.to_header()?;
headers.insert(name, value);
Ok(())
}
fn to_header(&self) -> Result<(HeaderName, HeaderValue)> {
match self {
AuthConfig::Basic { username, password } => {
let credentials = format!("{}:{}", username, password);
let encoded = BASE64.encode(credentials.as_bytes());
let value = format!("Basic {}", encoded);
Ok((
HeaderName::from_static("authorization"),
HeaderValue::from_str(&value).map_err(|e| {
Error::config(format!("Invalid basic auth credentials: {}", e))
})?,
))
}
AuthConfig::Bearer { token } => {
let value = format!("Bearer {}", token);
Ok((
HeaderName::from_static("authorization"),
HeaderValue::from_str(&value)
.map_err(|e| Error::config(format!("Invalid bearer token: {}", e)))?,
))
}
AuthConfig::ApiKey { header, key } => {
let name = HeaderName::from_bytes(header.as_bytes()).map_err(|e| {
Error::config(format!("Invalid API key header name '{}': {}", header, e))
})?;
let value = HeaderValue::from_str(key).map_err(|e| {
Error::config(format!(
"Invalid API key value for header '{}': {}",
header, e
))
})?;
Ok((name, value))
}
AuthConfig::Header { name, value } => {
let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
Error::config(format!("Invalid auth header name '{}': {}", name, e))
})?;
let header_value = HeaderValue::from_str(value).map_err(|e| {
Error::config(format!("Invalid auth header value for '{}': {}", name, e))
})?;
Ok((header_name, header_value))
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Applies `auth` to an empty header map and returns the text of header `name`.
    ///
    /// Panics if applying fails, the header is absent, or its value is not visible ASCII.
    fn applied_header(auth: &AuthConfig, name: &str) -> String {
        let mut headers = HeaderMap::new();
        auth.apply_to_headers(&mut headers).unwrap();
        headers.get(name).unwrap().to_str().unwrap().to_owned()
    }

    /// Parses a YAML document into an [`AuthConfig`], panicking on failure.
    fn parse(yaml: &str) -> AuthConfig {
        serde_yaml::from_str(yaml).unwrap()
    }

    #[test]
    fn test_basic_auth_header() {
        let auth = AuthConfig::Basic {
            username: "user".to_string(),
            password: "pass".to_string(),
        };
        // "dXNlcjpwYXNz" is the base64 encoding of "user:pass".
        assert_eq!(applied_header(&auth, "authorization"), "Basic dXNlcjpwYXNz");
    }

    #[test]
    fn test_bearer_auth_header() {
        let auth = AuthConfig::Bearer {
            token: "my-jwt-token".to_string(),
        };
        assert_eq!(
            applied_header(&auth, "authorization"),
            "Bearer my-jwt-token"
        );
    }

    #[test]
    fn test_api_key_header() {
        let auth = AuthConfig::ApiKey {
            header: "X-API-Key".to_string(),
            key: "secret-key".to_string(),
        };
        // Looked up in lowercase even though it was configured in mixed case.
        assert_eq!(applied_header(&auth, "x-api-key"), "secret-key");
    }

    #[test]
    fn test_custom_header_auth() {
        let auth = AuthConfig::Header {
            name: "X-Custom-Auth".to_string(),
            value: "custom-value".to_string(),
        };
        assert_eq!(applied_header(&auth, "x-custom-auth"), "custom-value");
    }

    #[test]
    fn test_invalid_header_name() {
        // A space is not allowed in a header name, so building the header must fail.
        let auth = AuthConfig::ApiKey {
            header: "Invalid Header".to_string(),
            key: "key".to_string(),
        };
        let mut headers = HeaderMap::new();
        let message = match auth.apply_to_headers(&mut headers) {
            Ok(()) => panic!("expected the invalid header name to be rejected"),
            Err(err) => err.to_string(),
        };
        assert!(message.contains("Invalid API key header name"));
    }

    #[test]
    fn test_deserialize_basic() {
        let expected = AuthConfig::Basic {
            username: "admin".to_string(),
            password: "secret123".to_string(),
        };
        let parsed = parse(
            r#"
type: basic
username: admin
password: secret123
"#,
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_deserialize_bearer() {
        let expected = AuthConfig::Bearer {
            token: "jwt-token-here".to_string(),
        };
        let parsed = parse(
            r#"
type: bearer
token: "jwt-token-here"
"#,
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_deserialize_api_key() {
        let expected = AuthConfig::ApiKey {
            header: "X-API-Key".to_string(),
            key: "my-api-key".to_string(),
        };
        let parsed = parse(
            r#"
type: api_key
header: X-API-Key
key: "my-api-key"
"#,
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_deserialize_header() {
        let expected = AuthConfig::Header {
            name: "X-Custom".to_string(),
            value: "custom-value".to_string(),
        };
        let parsed = parse(
            r#"
type: header
name: X-Custom
value: "custom-value"
"#,
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_serialize_roundtrip() {
        let original = AuthConfig::Bearer {
            token: "test-token".to_string(),
        };
        let yaml = serde_yaml::to_string(&original).unwrap();
        assert_eq!(original, parse(&yaml));
    }
}