dwh/
oidc.rs

1use std::ops::Deref;
2use std::time::Duration;
3
4use chrono::Utc;
5use clap::Args;
6use oauth2::basic::{BasicErrorResponseType, BasicTokenType};
7use oauth2::{
8    RefreshToken, RequestTokenError, RevocationErrorResponseType, StandardRevocableToken,
9};
10use openidconnect::core::{
11    CoreAuthDisplay, CoreAuthPrompt, CoreClient, CoreGenderClaim, CoreJsonWebKey,
12    CoreJsonWebKeyType, CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm,
13    CoreJwsSigningAlgorithm, CoreProviderMetadata,
14};
15use openidconnect::{
16    Client, ClientId, ClientSecret, EmptyAdditionalClaims, EmptyExtraTokenFields, IdTokenFields,
17    IssuerUrl, OAuth2TokenResponse, ResourceOwnerPassword, ResourceOwnerUsername,
18    StandardErrorResponse, StandardTokenIntrospectionResponse, StandardTokenResponse,
19};
20use reqwest::RequestBuilder;
21use serde_with::formats::Flexible;
22use serde_with::TimestampSeconds;
23
24use openidconnect::reqwest::async_http_client;
25use tokio::sync::Mutex;
26
27use crate::config::DwhConfig;
28use crate::jwt::JwtError;
29use crate::ReqwestHooks;
30
31#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
32pub struct OidcConfig {
33    pub url: String,
34    pub user: String,
35    pub password: String,
36    pub client_id: String,
37    pub client_secret: Option<String>,
38}
39
40#[derive(Args, Debug)]
41pub struct OidcArgs {
42    /// set version manually
43    #[arg(
44        long,
45        env,
46        requires("oidc_user"),
47        requires("oidc_password"),
48        requires("oidc_client_id")
49    )]
50    oidc_server_url: Option<String>,
51
52    #[arg(long, env)]
53    oidc_user: Option<String>,
54
55    #[arg(long, env)]
56    oidc_password: Option<String>,
57
58    #[arg(long, env)]
59    oidc_client_id: Option<String>,
60
61    #[arg(long, env)]
62    oidc_client_secret: Option<String>,
63
64    #[arg(long, env)]
65    oidc_profile: Option<String>,
66}
67
68impl OidcArgs {
69    pub fn oidc_config(&self) -> Option<OidcConfig> {
70        if let Some(profile) = self.oidc_profile.clone() {
71            match DwhConfig::read() {
72                Ok(config) => {
73                    let Some(profiles) = config.profiles else {
74                        panic!("No Profiles in dwh config")
75                    };
76                    return profiles.get(&profile).cloned();
77                }
78                Err(e) => {
79                    panic!("failed to read config dwh config file {}", e)
80                }
81            }
82        }
83        let Some(url) = self.oidc_server_url.clone() else {
84            return None;
85        };
86        let Some(user) = self.oidc_user.clone() else {
87            return None;
88        };
89        let Some(password) = self.oidc_password.clone() else {
90            return None;
91        };
92        let Some(client_id) = self.oidc_client_id.clone() else {
93            return None;
94        };
95        return Some(OidcConfig {
96            url,
97            user,
98            password,
99            client_id,
100            client_secret: self.oidc_client_secret.clone(),
101        });
102    }
103}
104
105#[derive(thiserror::Error, Debug)]
106pub enum TokenProviderError {
107    #[error("The given credentials are not authirzed to create a token. reason: {0}")]
108    Unauthorized(String),
109    #[error("Failed to retreive a token. Server is not answering")]
110    Connection,
111    #[error("An unknown Error has Been occurred")]
112    Other,
113}
114
115#[async_trait::async_trait]
116pub trait TokenProvider {
117    async fn get_access_token(&self) -> Result<String, TokenProviderError>;
118}
119
120type OidcClient = Client<
121    EmptyAdditionalClaims,
122    CoreAuthDisplay,
123    CoreGenderClaim,
124    CoreJweContentEncryptionAlgorithm,
125    CoreJwsSigningAlgorithm,
126    CoreJsonWebKeyType,
127    CoreJsonWebKeyUse,
128    CoreJsonWebKey,
129    CoreAuthPrompt,
130    StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
131    StandardTokenResponse<
132        IdTokenFields<
133            EmptyAdditionalClaims,
134            EmptyExtraTokenFields,
135            CoreGenderClaim,
136            CoreJweContentEncryptionAlgorithm,
137            CoreJwsSigningAlgorithm,
138            CoreJsonWebKeyType,
139        >,
140        BasicTokenType,
141    >,
142    BasicTokenType,
143    StandardTokenIntrospectionResponse<EmptyExtraTokenFields, BasicTokenType>,
144    StandardRevocableToken,
145    StandardErrorResponse<RevocationErrorResponseType>,
146>;
147
148type TokenType = StandardTokenResponse<
149    IdTokenFields<
150        EmptyAdditionalClaims,
151        EmptyExtraTokenFields,
152        CoreGenderClaim,
153        CoreJweContentEncryptionAlgorithm,
154        CoreJwsSigningAlgorithm,
155        CoreJsonWebKeyType,
156    >,
157    BasicTokenType,
158>;
159
160pub struct Token {
161    pub token: String,
162    pub claims: Claims,
163}
164
165impl Token {
166    pub fn expires_soon(&self, min_time_left: Duration) -> bool {
167        return chrono::Utc::now() + min_time_left > self.claims.exp;
168    }
169}
170
171pub struct TokenState {
172    pub access_token: Token,
173    pub refresh_token: Option<Token>,
174}
175
176impl TryFrom<TokenType> for TokenState {
177    type Error = JwtError;
178
179    fn try_from(value: TokenType) -> Result<Self, Self::Error> {
180        let token = value.access_token().secret();
181        let access_token = Token {
182            claims: crate::jwt::decode(token)?,
183            token: token.to_string(),
184        };
185        let refresh_token = match value.refresh_token() {
186            Some(rt) => {
187                let token = rt.secret();
188                Some(Token {
189                    claims: crate::jwt::decode(token)?,
190                    token: token.to_string(),
191                })
192            }
193            None => None,
194        };
195        Ok(TokenState {
196            access_token,
197            refresh_token,
198        })
199    }
200}
201
202#[derive(thiserror::Error, Debug)]
203pub enum OidcTokenServiceError {
204    #[error("failed to request token {0}")]
205    RequestTokenError(String),
206    #[error("Failed to parse token {0}")]
207    JwtError(#[from] JwtError),
208}
209
210impl From<OidcTokenServiceError> for TokenProviderError {
211    //    fn into(self) -> TokenProviderError {}
212
213    fn from(value: OidcTokenServiceError) -> Self {
214        match value {
215            OidcTokenServiceError::RequestTokenError(_value) => TokenProviderError::Other,
216            OidcTokenServiceError::JwtError(_) => TokenProviderError::Other,
217        }
218    }
219}
220
221pub struct OidcTokenService {
222    user: String,
223    password: String,
224    client: OidcClient,
225    token: tokio::sync::Mutex<Option<TokenState>>,
226}
227
228impl OidcTokenService {
229    pub async fn new(config: OidcConfig) -> anyhow::Result<Self> {
230        let provider_metadata =
231            CoreProviderMetadata::discover_async(IssuerUrl::new(config.url)?, async_http_client)
232                .await?;
233
234        let client = CoreClient::from_provider_metadata(
235            provider_metadata,
236            ClientId::new(config.client_id.clone()),
237            config.client_secret.map(|secret| ClientSecret::new(secret)),
238        );
239
240        Ok(Self {
241            client,
242            user: config.user,
243            password: config.password,
244            token: Mutex::new(None),
245        })
246    }
247
248    pub async fn refresh_access_token_with_credentials(
249        &self,
250    ) -> Result<TokenState, OidcTokenServiceError> {
251        let user = ResourceOwnerUsername::new(self.user.clone());
252        let password = ResourceOwnerPassword::new(self.password.clone());
253        let result = self
254            .client
255            .exchange_password(&user, &password)
256            .request_async(async_http_client)
257            .await
258            .map_err(|e| OidcTokenServiceError::RequestTokenError(e.to_string()))?;
259        //let token = result.access_token().secret();
260        Ok(result.try_into()?)
261    }
262    pub async fn refresh_access_token_with_refresh_token(
263        &self,
264        refresh_token: String,
265    ) -> Result<TokenState, OidcTokenServiceError> {
266        let refresh_token = RefreshToken::new(refresh_token);
267        Ok(self
268            .client
269            .exchange_refresh_token(&refresh_token)
270            .request_async(async_http_client)
271            .await
272            .map_err(|e| OidcTokenServiceError::RequestTokenError(e.to_string()))?
273            .try_into()?)
274    }
275}
276
277impl
278    From<
279        RequestTokenError<
280            oauth2::reqwest::Error<reqwest::Error>,
281            StandardErrorResponse<BasicErrorResponseType>,
282        >,
283    > for TokenProviderError
284{
285    fn from(
286        value: RequestTokenError<
287            oauth2::reqwest::Error<reqwest::Error>,
288            StandardErrorResponse<BasicErrorResponseType>,
289        >,
290    ) -> Self {
291        let response = match value {
292            RequestTokenError::ServerResponse(response) => response,
293            RequestTokenError::Request(_) => return Self::Connection,
294            RequestTokenError::Parse(_, _) => return Self::Other,
295            RequestTokenError::Other(_) => return Self::Other,
296        };
297        Self::Unauthorized(
298            response
299                .error_description()
300                .cloned()
301                .unwrap_or("Unknown".to_string()),
302        )
303    }
304}
305
306#[serde_with::serde_as]
307#[derive(serde::Deserialize, serde::Serialize)]
308pub struct Claims {
309    #[serde_as(as = "TimestampSeconds<String, Flexible>")]
310    exp: chrono::DateTime<Utc>,
311}
312
313#[async_trait::async_trait]
314impl TokenProvider for OidcTokenService {
315    async fn get_access_token(&self) -> Result<String, TokenProviderError> {
316        let mut token_container = self.token.lock().await;
317        if let Some(token) = token_container.deref() {
318            //we have tokens!
319            if !token.access_token.expires_soon(Duration::from_secs(7)) {
320                //happy path. we have a token and nobody has a problem
321                return Ok(token.access_token.token.clone());
322            }
323            //do we have a refresh token?
324            if let Some(refresh_token) = token.refresh_token.as_ref() {
325                // is that refresh_token still valid?
326                if !refresh_token.expires_soon(Duration::from_secs(7)) {
327                    //use that refresh token to get a new access token
328                    let token = self
329                        .refresh_access_token_with_refresh_token(refresh_token.token.clone())
330                        .await?;
331                    let token_string = token.access_token.token.clone();
332                    *token_container = Some(token);
333                    return Ok(token_string);
334                }
335            }
336        };
337        // no token
338        // refresh token epxired
339        // no refresh token or refresh token expired
340        let token = self.refresh_access_token_with_credentials().await?;
341        let token_string = token.access_token.token.clone();
342        *token_container = Some(token);
343        return Ok(token_string);
344    }
345}
346
347#[async_trait::async_trait]
348impl ReqwestHooks for OidcTokenService {
349    async fn before_send(&self, req: RequestBuilder) -> crate::Result<RequestBuilder> {
350        let token = self.get_access_token().await?;
351        let req = req.header("Authorization", format!("Bearer {}", token));
352        Ok(req)
353    }
354}
355
356// #[derive(serde::Serialize, serde::Deserialize)]
357// #[serde(untagged)]
358// pub enum MacType {
359//     Vec(Vec<String>),
360//     String(String),
361// }
362
363// impl MacType {
364//     pub fn into_vec(self) -> Vec<String> {
365//         match self {
366//             MacType::String(mac) => vec![mac],
367//             MacType::Vec(macs) => macs,
368//         }
369//     }
370// }
371// #[cfg(test)]
372// mod test {
373//     use super::MacType;
374//     #[test]
375
376//     pub fn test_deserialize() -> anyhow::Result<()> {
377//         let t: Vec<String> = serde_json::from_str::<MacType>(r#""00:11:22:33:44:55""#)?.into_vec();
378//         assert_eq!(1, t.len());
379//         let t: Vec<String> =
380//             serde_json::from_str::<MacType>(r#"["00:11:22:33:44:55", "00:11:22:33:44:56"]"#)?
381//                 .into_vec();
382//         assert_eq!(2, t.len());
383//         Ok(())
384//     }
385// }