1use std::collections::HashMap;
17use std::fmt::{self, Debug};
18use std::time::SystemTime;
19
20use async_trait::async_trait;
21use http::{HeaderMap, HeaderValue};
22use reqwest::{Client, Request, Response};
23use secrecy::SecretString;
24use thiserror::Error;
25use tracing::{Level, event, info, instrument};
26
27pub mod authtoken;
28pub mod authtoken_scope;
29pub mod types;
30
31pub use authtoken::{AuthToken, AuthTokenError};
32pub use authtoken_scope::AuthTokenScope;
33pub use types::*;
34
35#[derive(Debug, Error)]
37#[non_exhaustive]
38pub enum AuthError {
39 #[error("authentication rejected")]
41 AuthReceipt(AuthReceiptResponse),
42
43 #[error("authentication receipt cannot be converted to string")]
45 AuthReceiptNotString,
46
47 #[error("AuthToken error: {}", source)]
49 AuthToken {
50 #[from]
52 source: AuthTokenError,
53 },
54
55 #[error("token missing in the response")]
57 AuthTokenNotInResponse,
58
59 #[error("token cannot be converted to string")]
61 AuthTokenNotString,
62
63 #[error("value necessary for the chosen auth method was not supplied to the auth method")]
65 AuthValueNotSupplied(String),
66
67 #[error("authentication method error: {}", .0.message)]
69 Identity(IdentityError),
70
71 #[error("plugin specified malformed requirements")]
74 PluginMalformedRequirement,
75
76 #[error("failed to deserialize response body: {}", source)]
78 Serde {
79 #[from]
81 source: serde_json::Error,
82 },
83
84 #[error("header value error: {}", source)]
86 HeaderValue {
87 #[from]
89 source: http::header::InvalidHeaderValue,
90 },
91
92 #[error("plugin error: {}", source)]
94 Plugin {
95 #[source]
97 source: Box<dyn std::error::Error + Send + Sync + 'static>,
98 },
99
100 #[error(transparent)]
102 Reqwest {
103 #[from]
105 source: reqwest::Error,
106 },
107
108 #[error("identity service error")]
110 UnknownAuth {
111 code: u16,
113 message: Option<String>,
115 },
116
117 #[error(transparent)]
119 Url {
120 #[from]
122 source: url::ParseError,
123 },
124}
125
126impl AuthError {
127 pub fn plugin<E>(error: E) -> Self
128 where
129 E: std::error::Error + Send + Sync + 'static,
130 {
131 Self::Plugin {
132 source: Box::new(error),
133 }
134 }
135}
136
137#[async_trait]
139pub trait OpenStackAuthType: Send + Sync {
140 fn get_supported_auth_methods(&self) -> Vec<&'static str>;
142
143 fn requirements(
145 &self,
146 hints: Option<&serde_json::Value>,
147 ) -> Result<serde_json::Value, AuthError>;
148
149 fn api_version(&self) -> (u8, u8);
151
152 async fn auth(
154 &self,
155 http_client: &reqwest::Client,
156 identity_url: &url::Url,
157 values: HashMap<String, SecretString>,
158 scope: Option<&AuthTokenScope>,
159 hints: Option<&serde_json::Value>,
160 ) -> Result<Auth, AuthError>;
161}
162
163pub struct AuthPluginRegistration {
165 pub method: &'static dyn OpenStackAuthType,
166}
167
168inventory::collect!(AuthPluginRegistration);
170
171pub trait OpenStackMultifactorAuthMethod: Send + Sync {
173 fn get_supported_auth_methods(&self) -> Vec<&'static str>;
175
176 fn requirements(
178 &self,
179 hints: Option<&serde_json::Value>,
180 ) -> Result<serde_json::Value, AuthError>;
181
182 fn get_auth_data(
184 &self,
185 values: &HashMap<String, SecretString>,
186 ) -> Result<(&'static str, serde_json::Value), AuthError>;
187}
188
189pub struct AuthMethodPluginRegistration {
191 pub method: &'static dyn OpenStackMultifactorAuthMethod,
192}
193inventory::collect!(AuthMethodPluginRegistration);
194
195#[instrument(name="request", skip_all, fields(http.uri = request.url().as_str(), http.method = request.method().as_str(), openstack.ver=request.headers().get("openstack-api-version").map(|v| v.to_str().unwrap_or(""))))]
196pub async fn execute_auth_request(
197 client: &Client,
198 request: Request,
199) -> Result<Response, reqwest::Error> {
200 info!("Sending request {:?}", request);
201 let url = request.url().clone();
202 let method = request.method().clone();
203 let start = SystemTime::now();
204 let rsp = client.execute(request).await?;
205 let elapsed = SystemTime::now().duration_since(start).unwrap_or_default();
206 event!(
207 name: "http_request",
208 Level::INFO,
209 url=url.as_str(),
210 duration_ms=elapsed.as_millis(),
211 status=rsp.status().as_u16(),
212 method=method.as_str(),
213 request_id=rsp.headers().get("x-openstack-request-id").map(|v| v.to_str().unwrap_or("")),
214 "Request completed with status {}",
215 rsp.status(),
216 );
217 Ok(rsp)
218}
219
220#[derive(Clone)]
222#[non_exhaustive]
223pub enum Auth {
224 AuthToken(Box<AuthToken>),
226 None,
228}
229
230impl Auth {
231 pub fn set_header<'a>(
235 &self,
236 headers: &'a mut HeaderMap<HeaderValue>,
237 ) -> Result<&'a mut HeaderMap<HeaderValue>, AuthError> {
238 if let Auth::AuthToken(token) = self {
239 let _ = token.set_header(headers);
240 }
241
242 Ok(headers)
243 }
244}
245
246impl fmt::Debug for Auth {
247 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248 write!(
249 f,
250 "Auth {}",
251 match self {
252 Auth::AuthToken(_) => "Token",
253 Auth::None => "unauthed",
254 }
255 )
256 }
257}
258
259impl TryFrom<http::Response<bytes::Bytes>> for Auth {
260 type Error = AuthError;
261 fn try_from(value: http::Response<bytes::Bytes>) -> Result<Self, Self::Error> {
262 Ok(Self::AuthToken(Box::new(AuthToken::try_from(value)?)))
263 }
264}
265
/// Validity of an authentication token, as reported by `AuthToken::get_state`.
#[derive(Debug, Eq, PartialEq)]
pub enum AuthState {
    /// The token is usable (e.g. it expires well in the future).
    Valid,
    /// The token's expiry time has already passed.
    Expired,
    /// The token expires within the look-ahead window passed to `get_state`.
    AboutToExpire,
    /// No token information is available (e.g. a default-constructed token).
    Unset,
}
278
279#[derive(Debug, Error)]
284#[non_exhaustive]
285pub enum BuilderError {
286 #[error("{0}")]
288 UninitializedField(String),
289 #[error("{0}")]
291 Validation(String),
292}
293
294impl From<String> for BuilderError {
295 fn from(s: String) -> Self {
296 Self::Validation(s)
297 }
298}
299
300impl From<derive_builder::UninitializedFieldError> for BuilderError {
301 fn from(ufe: derive_builder::UninitializedFieldError) -> Self {
302 Self::UninitializedField(ufe.to_string())
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    // The token payload type from `types` is aliased so it does not shadow the
    // `AuthToken` wrapper re-exported by the parent module.
    use crate::types::{AuthResponse, AuthToken as TokenPayload};

    /// Builds a `super::AuthToken` whose response payload expires at `expires_at`.
    fn token_expiring_at(expires_at: chrono::DateTime<chrono::Utc>) -> super::AuthToken {
        let response = AuthResponse {
            token: TokenPayload {
                expires_at,
                ..Default::default()
            },
        };
        super::AuthToken::new(String::new(), Some(response))
    }

    #[test]
    fn test_auth_validity_unset() {
        let token = super::AuthToken::default();
        let state = token.get_state(None);
        assert!(matches!(state, AuthState::Unset));
    }

    #[test]
    fn test_auth_validity_expired() {
        let yesterday = chrono::Utc::now() - chrono::TimeDelta::days(1);
        let token = token_expiring_at(yesterday);
        assert!(matches!(token.get_state(None), AuthState::Expired));
    }

    #[test]
    fn test_auth_validity_expire_soon() {
        let in_ten_minutes = chrono::Utc::now() + chrono::TimeDelta::minutes(10);
        let token = token_expiring_at(in_ten_minutes);
        let window = Some(chrono::TimeDelta::minutes(15));
        assert!(matches!(token.get_state(window), AuthState::AboutToExpire));
    }

    #[test]
    fn test_auth_validity_valid() {
        let tomorrow = chrono::Utc::now() + chrono::TimeDelta::days(1);
        let token = token_expiring_at(tomorrow);
        assert!(matches!(token.get_state(None), AuthState::Valid));
    }
}