axum_util/
oidc.rs

1// use always_cell::AlwaysCell;
2use anyhow::{Context, Result};
3use chrono::{DateTime, Utc};
4use indexmap::IndexMap;
5use log::warn;
6use openid::{
7    error::ClientError, Bearer, Client, DiscoveredClient, OAuth2Error, OAuth2ErrorCode, Options,
8    StandardClaims, Token, Userinfo,
9};
10use serde::{Deserialize, Serialize};
11use std::{sync::Arc, time::Duration};
12use tokio::sync::RwLock;
13use url::Url;
14
15#[derive(Serialize, Deserialize, Clone, Debug)]
16pub struct OidcConfig {
17    pub name: String,
18    pub client_id: String,
19    pub client_secret: String,
20    pub issuer: Url,
21    pub redirect: Url,
22    pub refresh_cycle: Duration,
23}
24
25pub struct OidcController {
26    handlers: IndexMap<String, OidcHandler>,
27}
28
29impl OidcController {
30    pub async fn new(configs: &[OidcConfig]) -> Self {
31        let mut handlers = IndexMap::new();
32        for config in configs {
33            handlers.insert(config.name.clone(), OidcHandler::new(config).await);
34        }
35        Self { handlers }
36    }
37
38    pub fn handler(&self, name: &str) -> Option<&OidcHandler> {
39        self.handlers.get(name)
40    }
41
42    pub fn handlers(&self) -> impl Iterator<Item = &String> {
43        self.handlers.keys()
44    }
45}
46
47#[derive(Clone)]
48pub struct OidcHandler {
49    client: Arc<RwLock<(DateTime<Utc>, Client)>>,
50    config: OidcConfig,
51}
52
53impl OidcHandler {
54    pub async fn new(config: &OidcConfig) -> Self {
55        let client = loop {
56            match DiscoveredClient::discover(
57                config.client_id.to_string(),
58                config.client_secret.to_string(),
59                Some(config.redirect.to_string()),
60                config.issuer.clone(),
61            )
62            .await
63            {
64                Ok(x) => break x,
65                Err(e) => {
66                    warn!("failed to discover OIDC: {e:?}");
67                    tokio::time::sleep(Duration::from_secs(1)).await;
68                }
69            }
70        };
71        Self {
72            client: Arc::new(RwLock::new((
73                Utc::now() + chrono::Duration::from_std(config.refresh_cycle).unwrap(),
74                client,
75            ))),
76            config: config.clone(),
77        }
78    }
79
80    async fn recreate(&self) -> Client {
81        loop {
82            match DiscoveredClient::discover(
83                self.config.client_id.clone(),
84                self.config.client_secret.clone(),
85                Some(self.config.redirect.to_string()),
86                self.config.issuer.clone(),
87            )
88            .await
89            {
90                Ok(x) => break x,
91                Err(e) => {
92                    warn!("failed to rediscover OIDC: {e:?}");
93                    tokio::time::sleep(Duration::from_secs(1)).await;
94                }
95            }
96        }
97    }
98
99    pub async fn auth_url(&self, redirect: Option<&Url>) -> Url {
100        let client = self.client.read().await;
101        let mut tclient;
102        let client = if let Some(redirect) = redirect {
103            tclient = client.1.clone();
104            tclient.redirect_uri = Some(redirect.to_string());
105            &tclient
106        } else {
107            &client.1
108        };
109        client.auth_url(&Options {
110            scope: Some("openid email profile".into()),
111            state: None,
112            ..Default::default()
113        })
114    }
115
116    pub async fn validate_code(
117        &self,
118        code: &str,
119        redirect: Option<&Url>,
120    ) -> Result<Option<(Bearer, StandardClaims, Userinfo)>> {
121        let mut client = self.client.read().await;
122        let now = Utc::now();
123        if client.0 < now {
124            drop(client);
125            let mut old_client = self.client.write().await;
126            if old_client.0 < now {
127                let new_client = self.recreate().await;
128                *old_client = (
129                    now + chrono::Duration::from_std(self.config.refresh_cycle).unwrap(),
130                    new_client,
131                )
132            }
133            drop(old_client);
134            client = self.client.read().await;
135        }
136        let mut tclient;
137        let client = if let Some(redirect) = redirect {
138            tclient = client.1.clone();
139            tclient.redirect_uri = Some(redirect.to_string());
140            &tclient
141        } else {
142            &client.1
143        };
144        let mut token: Token = match client.request_token(code).await {
145            Ok(x) => x.into(),
146            Err(ClientError::OAuth2(OAuth2Error {
147                error: OAuth2ErrorCode::InvalidGrant,
148                ..
149            })) => {
150                return Ok(None);
151            }
152            Err(e) => return Err(e.into()),
153        };
154
155        if let Some(id_token) = &mut token.id_token {
156            client
157                .decode_token(id_token)
158                .context("failed to decode token")?;
159            client
160                .validate_token(id_token, None, None)
161                .context("failed to validate token")?;
162        } else {
163            return Ok(None);
164        };
165
166        let info = client.request_userinfo(&token).await?;
167
168        Ok(Some((
169            token.bearer,
170            token.id_token.unwrap().unwrap_decoded().1,
171            info,
172        )))
173    }
174}