1use anyhow::{Context, Result};
3use chrono::{DateTime, Utc};
4use indexmap::IndexMap;
5use log::warn;
6use openid::{
7 error::ClientError, Bearer, Client, DiscoveredClient, OAuth2Error, OAuth2ErrorCode, Options,
8 StandardClaims, Token, Userinfo,
9};
10use serde::{Deserialize, Serialize};
11use std::{sync::Arc, time::Duration};
12use tokio::sync::RwLock;
13use url::Url;
14
/// Configuration for a single OpenID Connect provider.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct OidcConfig {
    /// Key under which this provider's handler is registered in `OidcController`.
    pub name: String,
    /// OAuth2 client identifier passed to OIDC discovery.
    pub client_id: String,
    /// OAuth2 client secret passed to OIDC discovery.
    // NOTE(review): `Debug` is derived, so this secret is printed by `{:?}` — confirm
    // configs are never debug-logged.
    pub client_secret: String,
    /// Issuer URL the provider metadata is discovered from.
    pub issuer: Url,
    /// Default redirect URI set on the discovered client; individual requests may
    /// override it.
    pub redirect: Url,
    /// How long a discovered client is used before `validate_code` rediscovers it.
    // NOTE(review): serde represents std `Duration` as a `{ secs, nanos }` struct —
    // confirm config files use that shape.
    pub refresh_cycle: Duration,
}
24
25pub struct OidcController {
26 handlers: IndexMap<String, OidcHandler>,
27}
28
29impl OidcController {
30 pub async fn new(configs: &[OidcConfig]) -> Self {
31 let mut handlers = IndexMap::new();
32 for config in configs {
33 handlers.insert(config.name.clone(), OidcHandler::new(config).await);
34 }
35 Self { handlers }
36 }
37
38 pub fn handler(&self, name: &str) -> Option<&OidcHandler> {
39 self.handlers.get(name)
40 }
41
42 pub fn handlers(&self) -> impl Iterator<Item = &String> {
43 self.handlers.keys()
44 }
45}
46
47#[derive(Clone)]
48pub struct OidcHandler {
49 client: Arc<RwLock<(DateTime<Utc>, Client)>>,
50 config: OidcConfig,
51}
52
53impl OidcHandler {
54 pub async fn new(config: &OidcConfig) -> Self {
55 let client = loop {
56 match DiscoveredClient::discover(
57 config.client_id.to_string(),
58 config.client_secret.to_string(),
59 Some(config.redirect.to_string()),
60 config.issuer.clone(),
61 )
62 .await
63 {
64 Ok(x) => break x,
65 Err(e) => {
66 warn!("failed to discover OIDC: {e:?}");
67 tokio::time::sleep(Duration::from_secs(1)).await;
68 }
69 }
70 };
71 Self {
72 client: Arc::new(RwLock::new((
73 Utc::now() + chrono::Duration::from_std(config.refresh_cycle).unwrap(),
74 client,
75 ))),
76 config: config.clone(),
77 }
78 }
79
80 async fn recreate(&self) -> Client {
81 loop {
82 match DiscoveredClient::discover(
83 self.config.client_id.clone(),
84 self.config.client_secret.clone(),
85 Some(self.config.redirect.to_string()),
86 self.config.issuer.clone(),
87 )
88 .await
89 {
90 Ok(x) => break x,
91 Err(e) => {
92 warn!("failed to rediscover OIDC: {e:?}");
93 tokio::time::sleep(Duration::from_secs(1)).await;
94 }
95 }
96 }
97 }
98
99 pub async fn auth_url(&self, redirect: Option<&Url>) -> Url {
100 let client = self.client.read().await;
101 let mut tclient;
102 let client = if let Some(redirect) = redirect {
103 tclient = client.1.clone();
104 tclient.redirect_uri = Some(redirect.to_string());
105 &tclient
106 } else {
107 &client.1
108 };
109 client.auth_url(&Options {
110 scope: Some("openid email profile".into()),
111 state: None,
112 ..Default::default()
113 })
114 }
115
116 pub async fn validate_code(
117 &self,
118 code: &str,
119 redirect: Option<&Url>,
120 ) -> Result<Option<(Bearer, StandardClaims, Userinfo)>> {
121 let mut client = self.client.read().await;
122 let now = Utc::now();
123 if client.0 < now {
124 drop(client);
125 let mut old_client = self.client.write().await;
126 if old_client.0 < now {
127 let new_client = self.recreate().await;
128 *old_client = (
129 now + chrono::Duration::from_std(self.config.refresh_cycle).unwrap(),
130 new_client,
131 )
132 }
133 drop(old_client);
134 client = self.client.read().await;
135 }
136 let mut tclient;
137 let client = if let Some(redirect) = redirect {
138 tclient = client.1.clone();
139 tclient.redirect_uri = Some(redirect.to_string());
140 &tclient
141 } else {
142 &client.1
143 };
144 let mut token: Token = match client.request_token(code).await {
145 Ok(x) => x.into(),
146 Err(ClientError::OAuth2(OAuth2Error {
147 error: OAuth2ErrorCode::InvalidGrant,
148 ..
149 })) => {
150 return Ok(None);
151 }
152 Err(e) => return Err(e.into()),
153 };
154
155 if let Some(id_token) = &mut token.id_token {
156 client
157 .decode_token(id_token)
158 .context("failed to decode token")?;
159 client
160 .validate_token(id_token, None, None)
161 .context("failed to validate token")?;
162 } else {
163 return Ok(None);
164 };
165
166 let info = client.request_userinfo(&token).await?;
167
168 Ok(Some((
169 token.bearer,
170 token.id_token.unwrap().unwrap_decoded().1,
171 info,
172 )))
173 }
174}