1use std::ops::Deref;
2use std::time::Duration;
3
4use chrono::Utc;
5use clap::Args;
6use oauth2::basic::{BasicErrorResponseType, BasicTokenType};
7use oauth2::{
8 RefreshToken, RequestTokenError, RevocationErrorResponseType, StandardRevocableToken,
9};
10use openidconnect::core::{
11 CoreAuthDisplay, CoreAuthPrompt, CoreClient, CoreGenderClaim, CoreJsonWebKey,
12 CoreJsonWebKeyType, CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm,
13 CoreJwsSigningAlgorithm, CoreProviderMetadata,
14};
15use openidconnect::{
16 Client, ClientId, ClientSecret, EmptyAdditionalClaims, EmptyExtraTokenFields, IdTokenFields,
17 IssuerUrl, OAuth2TokenResponse, ResourceOwnerPassword, ResourceOwnerUsername,
18 StandardErrorResponse, StandardTokenIntrospectionResponse, StandardTokenResponse,
19};
20use reqwest::RequestBuilder;
21use serde_with::formats::Flexible;
22use serde_with::TimestampSeconds;
23
24use openidconnect::reqwest::async_http_client;
25use tokio::sync::Mutex;
26
27use crate::config::DwhConfig;
28use crate::jwt::JwtError;
29use crate::ReqwestHooks;
30
31#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
32pub struct OidcConfig {
33 pub url: String,
34 pub user: String,
35 pub password: String,
36 pub client_id: String,
37 pub client_secret: Option<String>,
38}
39
40#[derive(Args, Debug)]
41pub struct OidcArgs {
42 #[arg(
44 long,
45 env,
46 requires("oidc_user"),
47 requires("oidc_password"),
48 requires("oidc_client_id")
49 )]
50 oidc_server_url: Option<String>,
51
52 #[arg(long, env)]
53 oidc_user: Option<String>,
54
55 #[arg(long, env)]
56 oidc_password: Option<String>,
57
58 #[arg(long, env)]
59 oidc_client_id: Option<String>,
60
61 #[arg(long, env)]
62 oidc_client_secret: Option<String>,
63
64 #[arg(long, env)]
65 oidc_profile: Option<String>,
66}
67
68impl OidcArgs {
69 pub fn oidc_config(&self) -> Option<OidcConfig> {
70 if let Some(profile) = self.oidc_profile.clone() {
71 match DwhConfig::read() {
72 Ok(config) => {
73 let Some(profiles) = config.profiles else {
74 panic!("No Profiles in dwh config")
75 };
76 return profiles.get(&profile).cloned();
77 }
78 Err(e) => {
79 panic!("failed to read config dwh config file {}", e)
80 }
81 }
82 }
83 let Some(url) = self.oidc_server_url.clone() else {
84 return None;
85 };
86 let Some(user) = self.oidc_user.clone() else {
87 return None;
88 };
89 let Some(password) = self.oidc_password.clone() else {
90 return None;
91 };
92 let Some(client_id) = self.oidc_client_id.clone() else {
93 return None;
94 };
95 return Some(OidcConfig {
96 url,
97 user,
98 password,
99 client_id,
100 client_secret: self.oidc_client_secret.clone(),
101 });
102 }
103}
104
105#[derive(thiserror::Error, Debug)]
106pub enum TokenProviderError {
107 #[error("The given credentials are not authirzed to create a token. reason: {0}")]
108 Unauthorized(String),
109 #[error("Failed to retreive a token. Server is not answering")]
110 Connection,
111 #[error("An unknown Error has Been occurred")]
112 Other,
113}
114
/// Source of bearer access tokens.
#[async_trait::async_trait]
pub trait TokenProvider {
    /// Returns an access token string, obtaining or renewing it first if the implementation
    /// needs to.
    ///
    /// # Errors
    ///
    /// Returns a [`TokenProviderError`] if no token could be obtained.
    async fn get_access_token(&self) -> Result<String, TokenProviderError>;
}
119
120type OidcClient = Client<
121 EmptyAdditionalClaims,
122 CoreAuthDisplay,
123 CoreGenderClaim,
124 CoreJweContentEncryptionAlgorithm,
125 CoreJwsSigningAlgorithm,
126 CoreJsonWebKeyType,
127 CoreJsonWebKeyUse,
128 CoreJsonWebKey,
129 CoreAuthPrompt,
130 StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
131 StandardTokenResponse<
132 IdTokenFields<
133 EmptyAdditionalClaims,
134 EmptyExtraTokenFields,
135 CoreGenderClaim,
136 CoreJweContentEncryptionAlgorithm,
137 CoreJwsSigningAlgorithm,
138 CoreJsonWebKeyType,
139 >,
140 BasicTokenType,
141 >,
142 BasicTokenType,
143 StandardTokenIntrospectionResponse<EmptyExtraTokenFields, BasicTokenType>,
144 StandardRevocableToken,
145 StandardErrorResponse<RevocationErrorResponseType>,
146>;
147
/// Token-endpoint response type returned by [`OidcClient`] requests.
///
/// It holds an access token, an optional refresh token, and the OIDC ID-token fields, with no
/// additional claims or extra token fields.
type TokenType = StandardTokenResponse<
    IdTokenFields<
        EmptyAdditionalClaims,
        EmptyExtraTokenFields,
        CoreGenderClaim,
        CoreJweContentEncryptionAlgorithm,
        CoreJwsSigningAlgorithm,
        CoreJsonWebKeyType,
    >,
    BasicTokenType,
>;
159
/// An encoded token together with its decoded JWT claims.
pub struct Token {
    /// The encoded token string as returned by the token endpoint.
    pub token: String,
    /// Claims decoded from `token`; `exp` is used to decide when the token must be renewed.
    pub claims: Claims,
}
164
165impl Token {
166 pub fn expires_soon(&self, min_time_left: Duration) -> bool {
167 return chrono::Utc::now() + min_time_left > self.claims.exp;
168 }
169}
170
/// Tokens obtained from one token-endpoint response.
pub struct TokenState {
    /// The access token sent as the bearer credential.
    pub access_token: Token,
    /// The refresh token, if the server issued one.
    pub refresh_token: Option<Token>,
}
175
176impl TryFrom<TokenType> for TokenState {
177 type Error = JwtError;
178
179 fn try_from(value: TokenType) -> Result<Self, Self::Error> {
180 let token = value.access_token().secret();
181 let access_token = Token {
182 claims: crate::jwt::decode(token)?,
183 token: token.to_string(),
184 };
185 let refresh_token = match value.refresh_token() {
186 Some(rt) => {
187 let token = rt.secret();
188 Some(Token {
189 claims: crate::jwt::decode(token)?,
190 token: token.to_string(),
191 })
192 }
193 None => None,
194 };
195 Ok(TokenState {
196 access_token,
197 refresh_token,
198 })
199 }
200}
201
/// Errors produced by the token requests of [`OidcTokenService`].
#[derive(thiserror::Error, Debug)]
pub enum OidcTokenServiceError {
    /// The token request failed; holds the underlying error rendered as a string.
    #[error("failed to request token {0}")]
    RequestTokenError(String),
    /// A token in the response could not be decoded by `crate::jwt::decode`.
    #[error("Failed to parse token {0}")]
    JwtError(#[from] JwtError),
}
209
210impl From<OidcTokenServiceError> for TokenProviderError {
211 fn from(value: OidcTokenServiceError) -> Self {
214 match value {
215 OidcTokenServiceError::RequestTokenError(_value) => TokenProviderError::Other,
216 OidcTokenServiceError::JwtError(_) => TokenProviderError::Other,
217 }
218 }
219}
220
/// [`TokenProvider`] backed by an OpenID Connect server.
///
/// Tokens are obtained with the resource-owner password grant, cached, and renewed with the
/// refresh-token grant where possible.
pub struct OidcTokenService {
    // Resource-owner credentials sent with the password grant.
    user: String,
    password: String,
    // Client configured from the provider's discovered metadata.
    client: OidcClient,
    // Most recently obtained tokens; `None` until the first successful request. An async mutex
    // is used because the lock is held across token requests.
    token: tokio::sync::Mutex<Option<TokenState>>,
}
227
228impl OidcTokenService {
229 pub async fn new(config: OidcConfig) -> anyhow::Result<Self> {
230 let provider_metadata =
231 CoreProviderMetadata::discover_async(IssuerUrl::new(config.url)?, async_http_client)
232 .await?;
233
234 let client = CoreClient::from_provider_metadata(
235 provider_metadata,
236 ClientId::new(config.client_id.clone()),
237 config.client_secret.map(|secret| ClientSecret::new(secret)),
238 );
239
240 Ok(Self {
241 client,
242 user: config.user,
243 password: config.password,
244 token: Mutex::new(None),
245 })
246 }
247
248 pub async fn refresh_access_token_with_credentials(
249 &self,
250 ) -> Result<TokenState, OidcTokenServiceError> {
251 let user = ResourceOwnerUsername::new(self.user.clone());
252 let password = ResourceOwnerPassword::new(self.password.clone());
253 let result = self
254 .client
255 .exchange_password(&user, &password)
256 .request_async(async_http_client)
257 .await
258 .map_err(|e| OidcTokenServiceError::RequestTokenError(e.to_string()))?;
259 Ok(result.try_into()?)
261 }
262 pub async fn refresh_access_token_with_refresh_token(
263 &self,
264 refresh_token: String,
265 ) -> Result<TokenState, OidcTokenServiceError> {
266 let refresh_token = RefreshToken::new(refresh_token);
267 Ok(self
268 .client
269 .exchange_refresh_token(&refresh_token)
270 .request_async(async_http_client)
271 .await
272 .map_err(|e| OidcTokenServiceError::RequestTokenError(e.to_string()))?
273 .try_into()?)
274 }
275}
276
277impl
278 From<
279 RequestTokenError<
280 oauth2::reqwest::Error<reqwest::Error>,
281 StandardErrorResponse<BasicErrorResponseType>,
282 >,
283 > for TokenProviderError
284{
285 fn from(
286 value: RequestTokenError<
287 oauth2::reqwest::Error<reqwest::Error>,
288 StandardErrorResponse<BasicErrorResponseType>,
289 >,
290 ) -> Self {
291 let response = match value {
292 RequestTokenError::ServerResponse(response) => response,
293 RequestTokenError::Request(_) => return Self::Connection,
294 RequestTokenError::Parse(_, _) => return Self::Other,
295 RequestTokenError::Other(_) => return Self::Other,
296 };
297 Self::Unauthorized(
298 response
299 .error_description()
300 .cloned()
301 .unwrap_or("Unknown".to_string()),
302 )
303 }
304}
305
/// JWT claims read by this module.
#[serde_with::serde_as]
#[derive(serde::Deserialize, serde::Serialize)]
pub struct Claims {
    /// Expiration time (`exp`) as seconds since the Unix epoch. `TimestampSeconds<String,
    /// Flexible>` serializes it as a string. When deserializing, it accepts either a JSON
    /// number or a numeric string.
    #[serde_as(as = "TimestampSeconds<String, Flexible>")]
    exp: chrono::DateTime<Utc>,
}
312
313#[async_trait::async_trait]
314impl TokenProvider for OidcTokenService {
315 async fn get_access_token(&self) -> Result<String, TokenProviderError> {
316 let mut token_container = self.token.lock().await;
317 if let Some(token) = token_container.deref() {
318 if !token.access_token.expires_soon(Duration::from_secs(7)) {
320 return Ok(token.access_token.token.clone());
322 }
323 if let Some(refresh_token) = token.refresh_token.as_ref() {
325 if !refresh_token.expires_soon(Duration::from_secs(7)) {
327 let token = self
329 .refresh_access_token_with_refresh_token(refresh_token.token.clone())
330 .await?;
331 let token_string = token.access_token.token.clone();
332 *token_container = Some(token);
333 return Ok(token_string);
334 }
335 }
336 };
337 let token = self.refresh_access_token_with_credentials().await?;
341 let token_string = token.access_token.token.clone();
342 *token_container = Some(token);
343 return Ok(token_string);
344 }
345}
346
347#[async_trait::async_trait]
348impl ReqwestHooks for OidcTokenService {
349 async fn before_send(&self, req: RequestBuilder) -> crate::Result<RequestBuilder> {
350 let token = self.get_access_token().await?;
351 let req = req.header("Authorization", format!("Bearer {}", token));
352 Ok(req)
353 }
354}
355
356