use std::collections::HashMap;
use std::fmt::{self, Debug};
use std::time::SystemTime;
use async_trait::async_trait;
use http::{HeaderMap, HeaderValue};
use reqwest::{Client, Request, Response};
use secrecy::SecretString;
use thiserror::Error;
use tracing::{Level, event, info, instrument};
// Public submodules declared by this file.
pub mod authtoken;
pub mod authtoken_scope;
pub mod types;
// Re-exports so that these items are reachable directly from this module.
pub use authtoken::{AuthToken, AuthTokenError};
pub use authtoken_scope::AuthTokenScope;
pub use types::*;
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum AuthError {
#[error("authentication rejected")]
AuthReceipt(AuthReceiptResponse),
#[error("authentication receipt cannot be converted to string")]
AuthReceiptNotString,
#[error("AuthToken error: {}", source)]
AuthToken {
#[from]
source: AuthTokenError,
},
#[error("token missing in the response")]
AuthTokenNotInResponse,
#[error("token cannot be converted to string")]
AuthTokenNotString,
#[error("value necessary for the chosen auth method was not supplied to the auth method")]
AuthValueNotSupplied(String),
#[error("authentication method error: {}", .0.message)]
Identity(IdentityError),
#[error("plugin specified malformed requirements")]
PluginMalformedRequirement,
#[error("failed to deserialize response body: {}", source)]
Serde {
#[from]
source: serde_json::Error,
},
#[error("header value error: {}", source)]
HeaderValue {
#[from]
source: http::header::InvalidHeaderValue,
},
#[error("plugin error: {}", source)]
Plugin {
#[source]
source: Box<dyn std::error::Error + Send + Sync + 'static>,
},
#[error(transparent)]
Reqwest {
#[from]
source: reqwest::Error,
},
#[error("identity service error")]
UnknownAuth {
code: u16,
message: Option<String>,
},
#[error(transparent)]
Url {
#[from]
source: url::ParseError,
},
}
impl AuthError {
pub fn plugin<E>(error: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
Self::Plugin {
source: Box::new(error),
}
}
}
/// Interface of an OpenStack authentication plugin.
///
/// Implementations must be `Send + Sync`. `AuthPluginRegistration` holds a
/// `&'static dyn OpenStackAuthType`, so the trait is used as a trait object.
#[async_trait]
pub trait OpenStackAuthType: Send + Sync {
/// Returns the names of the auth methods supported by this implementation.
fn get_supported_auth_methods(&self) -> Vec<&'static str>;
/// Returns the plugin requirements as a JSON value, optionally taking
/// `hints` into account.
///
/// NOTE(review): the schema of both `hints` and the returned value is not
/// visible in this file — confirm against the implementations.
fn requirements(
&self,
hints: Option<&serde_json::Value>,
) -> Result<serde_json::Value, AuthError>;
/// Returns the API version as a pair of numbers.
///
/// NOTE(review): presumably `(major, minor)` — confirm.
fn api_version(&self) -> (u8, u8);
/// Performs the authentication.
///
/// Receives the HTTP client, the identity service URL, the secret values
/// keyed by name, an optional scope and optional hints; yields an `Auth`
/// on success or an `AuthError` on failure.
async fn auth(
&self,
http_client: &reqwest::Client,
identity_url: &url::Url,
values: HashMap<String, SecretString>,
scope: Option<&AuthTokenScope>,
hints: Option<&serde_json::Value>,
) -> Result<Auth, AuthError>;
}
/// Registration entry for an authentication plugin.
///
/// Instances are gathered through the `inventory::collect!` invocation that
/// follows the struct definition.
pub struct AuthPluginRegistration {
/// The registered plugin implementation.
pub method: &'static dyn OpenStackAuthType,
}
inventory::collect!(AuthPluginRegistration);
/// Interface of a single authentication method usable for multifactor
/// authentication.
///
/// Implementations must be `Send + Sync`. `AuthMethodPluginRegistration`
/// holds a `&'static dyn OpenStackMultifactorAuthMethod`.
pub trait OpenStackMultifactorAuthMethod: Send + Sync {
/// Returns the names of the auth methods supported by this implementation.
fn get_supported_auth_methods(&self) -> Vec<&'static str>;
/// Returns the method requirements as a JSON value, optionally taking
/// `hints` into account.
///
/// NOTE(review): the schema of both `hints` and the returned value is not
/// visible in this file — confirm against the implementations.
fn requirements(
&self,
hints: Option<&serde_json::Value>,
) -> Result<serde_json::Value, AuthError>;
/// Builds the auth data from the supplied secret values.
///
/// On success returns a static string together with a JSON value.
/// NOTE(review): the string is presumably the method name under which the
/// JSON payload is placed in the auth request — confirm.
fn get_auth_data(
&self,
values: &HashMap<String, SecretString>,
) -> Result<(&'static str, serde_json::Value), AuthError>;
}
/// Registration entry for a multifactor authentication method.
///
/// Instances are gathered through the `inventory::collect!` invocation that
/// follows the struct definition.
pub struct AuthMethodPluginRegistration {
/// The registered authentication method implementation.
pub method: &'static dyn OpenStackMultifactorAuthMethod,
}
inventory::collect!(AuthMethodPluginRegistration);
/// Executes `request` with `client` and emits tracing information about it.
///
/// The request is logged before it is sent. Once the response has arrived an
/// `http_request` event is emitted carrying the URL, the elapsed time in
/// milliseconds, the response status, the HTTP method and the value of the
/// `x-openstack-request-id` response header (when present).
///
/// # Errors
///
/// Returns the `reqwest::Error` produced by `Client::execute`.
#[instrument(name="request", skip_all, fields(http.uri = request.url().as_str(), http.method = request.method().as_str(), openstack.ver=request.headers().get("openstack-api-version").map(|v| v.to_str().unwrap_or(""))))]
pub async fn execute_auth_request(
    client: &Client,
    request: Request,
) -> Result<Response, reqwest::Error> {
    info!("Sending request {:?}", request);

    // `request` is consumed by `execute`, so keep copies of what the
    // completion event needs.
    let target = request.url().clone();
    let verb = request.method().clone();

    // NOTE(review): `SystemTime` is a wall clock. If it steps backwards while
    // the request is in flight the measurement fails and is reported as a
    // zero duration; a monotonic clock (`std::time::Instant`) would avoid it.
    let started_at = SystemTime::now();
    let response = client.execute(request).await?;
    let took = started_at.elapsed().unwrap_or_default();

    let status = response.status();
    event!(
        name: "http_request",
        Level::INFO,
        url = target.as_str(),
        duration_ms = took.as_millis(),
        status = status.as_u16(),
        method = verb.as_str(),
        request_id = response
            .headers()
            .get("x-openstack-request-id")
            .map(|v| v.to_str().unwrap_or("")),
        "Request completed with status {}",
        status,
    );

    Ok(response)
}
/// Authentication data attached to requests.
///
/// The enum is `#[non_exhaustive]`: further variants may be added.
#[derive(Clone)]
#[non_exhaustive]
pub enum Auth {
/// Token based authentication; the token is boxed.
AuthToken(Box<AuthToken>),
/// No authentication.
None,
}
impl Auth {
    /// Adds the authentication header(s) to `headers`.
    ///
    /// For `Auth::AuthToken` the work is delegated to the token's own
    /// `set_header`; for every other variant `headers` is left untouched.
    /// The same `headers` reference is handed back for call chaining.
    ///
    /// # Errors
    ///
    /// Returns an error when the token fails to set its header. Previously
    /// that failure was silently discarded, so the caller received `Ok`
    /// although no authentication header had been added to the request.
    pub fn set_header<'a>(
        &self,
        headers: &'a mut HeaderMap<HeaderValue>,
    ) -> Result<&'a mut HeaderMap<HeaderValue>, AuthError> {
        if let Auth::AuthToken(token) = self {
            // Propagate instead of ignoring: an unauthenticated request
            // must not be sent as if the header had been set.
            token.set_header(headers)?;
        }
        Ok(headers)
    }
}
impl fmt::Debug for Auth {
    /// Formats only the kind of authentication; the token content is never
    /// written.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Auth::AuthToken(_) => "Token",
            Auth::None => "unauthed",
        };
        write!(f, "Auth {}", label)
    }
}
impl TryFrom<http::Response<bytes::Bytes>> for Auth {
type Error = AuthError;
fn try_from(value: http::Response<bytes::Bytes>) -> Result<Self, Self::Error> {
Ok(Self::AuthToken(Box::new(AuthToken::try_from(value)?)))
}
}
/// State of an authentication token, as reported by `AuthToken::get_state`
/// in the unit tests of this file.
#[derive(Debug, Eq, PartialEq)]
pub enum AuthState {
/// The token is valid (the tests use an expiration one day ahead).
Valid,
/// The token expiration lies in the past.
Expired,
/// The token expires within the offset passed to `get_state` (the tests use
/// a token expiring in 10 minutes with an offset of 15 minutes).
AboutToExpire,
/// No authentication data is present (state of a default-constructed token).
Unset,
}
/// Errors reported by builders.
///
/// Both variants render as the contained message. The enum is
/// `#[non_exhaustive]`: further variants may be added.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum BuilderError {
/// A field was not initialized; produced from
/// `derive_builder::UninitializedFieldError`.
#[error("{0}")]
UninitializedField(String),
/// A validation failure; produced from a plain `String` message.
#[error("{0}")]
Validation(String),
}
impl From<String> for BuilderError {
fn from(s: String) -> Self {
Self::Validation(s)
}
}
impl From<derive_builder::UninitializedFieldError> for BuilderError {
fn from(ufe: derive_builder::UninitializedFieldError) -> Self {
Self::UninitializedField(ufe.to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AuthResponse, AuthToken};

    /// Builds a token wrapper with an empty token string whose auth response
    /// expires at `expires_at`; all other response fields take their defaults.
    fn token_expiring_at(expires_at: chrono::DateTime<chrono::Utc>) -> super::AuthToken {
        let response = AuthResponse {
            token: AuthToken {
                expires_at,
                ..Default::default()
            },
        };
        super::AuthToken::new(String::new(), Some(response))
    }

    /// A default-constructed token reports `Unset`.
    #[test]
    fn test_auth_validity_unset() {
        let auth = super::AuthToken::default();
        assert!(matches!(auth.get_state(None), AuthState::Unset));
    }

    /// A token that expired yesterday reports `Expired`.
    #[test]
    fn test_auth_validity_expired() {
        let yesterday = chrono::Utc::now() - chrono::TimeDelta::days(1);
        let auth = token_expiring_at(yesterday);
        assert!(matches!(auth.get_state(None), AuthState::Expired));
    }

    /// A token expiring in 10 minutes reports `AboutToExpire` for a
    /// 15 minute offset.
    #[test]
    fn test_auth_validity_expire_soon() {
        let in_ten_minutes = chrono::Utc::now() + chrono::TimeDelta::minutes(10);
        let auth = token_expiring_at(in_ten_minutes);
        let offset = Some(chrono::TimeDelta::minutes(15));
        assert!(matches!(auth.get_state(offset), AuthState::AboutToExpire));
    }

    /// A token expiring tomorrow reports `Valid`.
    #[test]
    fn test_auth_validity_valid() {
        let tomorrow = chrono::Utc::now() + chrono::TimeDelta::days(1);
        let auth = token_expiring_at(tomorrow);
        assert!(matches!(auth.get_state(None), AuthState::Valid));
    }
}