// create_rust_app/auth/oidc/controller.rs
use anyhow::{Context, Result};
use diesel::OptionalExtension;
use openidconnect::{
    core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
    reqwest::async_http_client,
    AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce,
    OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, TokenResponse,
};
use rand::{distributions::Alphanumeric, Rng};

use super::{
    model::{CreateUserOauth2Link, UpdateUserOauth2Link, UserOauth2Link},
    OIDCProvider,
};
use crate::{
    auth::{
        controller::{create_user_session, generate_salt, ARGON_CONFIG},
        AuthConfig, User, UserChangeset,
    },
    AppConfig, Database,
};

24async fn create_oidc_client(provider: &OIDCProvider, app_url: String) -> Result<CoreClient> {
25 let provider_metadata = CoreProviderMetadata::discover_async(
26 IssuerUrl::new(provider.clone().issuer_url)?,
27 async_http_client,
28 )
29 .await?;
30
31 Ok(CoreClient::from_provider_metadata(
32 provider_metadata,
33 ClientId::new(provider.clone().client_id),
34 Some(ClientSecret::new(provider.clone().client_secret)),
35 )
36 .set_redirect_uri(RedirectUrl::new(provider.redirect_uri(&app_url))?))
37}
38
39pub async fn oidc_login_url(
42 db: &Database,
43 app_config: &AppConfig,
44 auth_config: &AuthConfig,
45 provider_name: String,
46) -> Result<Option<String>> {
47 let mut db = db.get_connection().unwrap();
48
49 let Some(provider) = auth_config
50 .clone()
51 .oidc_providers
52 .into_iter()
53 .find(|provider_config| provider_config.name.eq(&provider_name))
54 else {
55 return Ok(None);
56 };
57
58 let client = create_oidc_client(&provider, app_config.clone().app_url).await?;
59
60 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
61
62 let (auth_url, csrf_token, nonce) = client
64 .authorize_url(
65 CoreAuthenticationFlow::AuthorizationCode,
66 CsrfToken::new_random,
67 Nonce::new_random,
68 )
69 .add_extra_param("access_type", "offline") .add_scopes(provider.scope.into_iter().map(Scope::new))
71 .set_pkce_challenge(pkce_challenge)
72 .url();
73
74 UserOauth2Link::create(
75 &mut db,
76 &CreateUserOauth2Link {
77 provider: provider_name,
78 access_token: None,
79 refresh_token: None,
80 subject_id: None,
81 user_id: None,
82 csrf_token: csrf_token.secret().clone(),
83 nonce: nonce.secret().clone(),
84 pkce_secret: pkce_verifier.secret().clone(),
85 },
86 )?;
87
88 Ok(Some(auth_url.to_string()))
89}
90
/// Status code attached to a failed OAuth login (the file uses values such as 400/401/500).
type StatusCode = u16;
/// Human-readable explanation paired with a [`StatusCode`].
type Message = String;
/// Access token of the session returned by `oauth_login`.
type AccessToken = String;
/// Refresh token returned alongside an [`AccessToken`] by `oauth_login`.
type RefreshToken = String;

96#[allow(clippy::too_many_lines)]
107pub async fn oauth_login(
108 db: &Database,
109 app_config: &AppConfig,
110 auth_config: &AuthConfig,
111 provider_name: String,
112 query_param_code: Option<String>,
113 query_param_error: Option<String>,
114 query_param_state: Option<String>,
115) -> Result<(AccessToken, RefreshToken), (StatusCode, Message)> {
116 let db = &mut db.get_connection().unwrap();
117
118 let Some(provider) = auth_config
120 .clone()
121 .oidc_providers
122 .into_iter()
123 .find(|provider_config| provider_config.name.eq(&provider_name))
124 else {
125 return Err((501, "This oauth provider is not supported".into()));
126 };
127
128 if let Some(query_param_error) = query_param_error {
130 return Err((401, query_param_error));
175 }
176
177 let Some(state) = query_param_state else {
180 return Err((400, "Invalid CSRF token".into()));
181 };
182 let oauth_request = UserOauth2Link::read_by_csrf_token(db, provider_name.clone(), state)
183 .expect("Invalid oauth2 redirection");
184
185 let pkce_verifier = PkceCodeVerifier::new(oauth_request.pkce_secret);
186
187 let Some(code) = query_param_code else {
189 return Err((400, "Invalid code".into()));
190 };
191
192 let Ok(client) = create_oidc_client(&provider, app_config.clone().app_url).await else {
193 return Err((500, "Internal server error".into()));
194 };
195
196 let Ok(token_response) = client
197 .exchange_code(AuthorizationCode::new(code))
198 .set_pkce_verifier(pkce_verifier)
199 .request_async(async_http_client)
200 .await
201 else {
202 return Err((400, "Invalid code".into()));
203 };
204
205 let Some(id_token) = token_response.id_token() else {
206 return Err((500, "Server did not return an ID token".into()));
207 };
208
209 let Ok(claims) = id_token.claims(
210 &client.id_token_verifier(),
211 &Nonce::new(oauth_request.nonce),
212 ) else {
213 return Err((500, "Invalid ID token claims".into()));
214 };
215
216 if let Some(expected_access_token_hash) = claims.access_token_hash() {
217 let Ok(signing_alg) = id_token.signing_alg() else {
218 return Err((500, "Invalid signing algorithm".into()));
219 };
220
221 let Ok(actual_access_token_hash) =
222 AccessTokenHash::from_token(token_response.access_token(), &signing_alg)
223 else {
224 return Err((500, "Invalid access token".into()));
225 };
226
227 if actual_access_token_hash != *expected_access_token_hash {
228 return Err((401, "Invalid access token".into()));
229 }
230 }
231
232 let subject = claims.subject().to_string();
233
234 let user = match UserOauth2Link::read_by_subject(db, subject).optional() {
239 Ok(Some(oauth2_link)) => {
240 if oauth2_link.user_id.is_none() {
242 return Err((500, "Internal server error".into()));
243 }
244 let Ok(user) = User::read(db, oauth2_link.user_id.unwrap()) else {
245 return Err((500, "Internal server error".into()));
246 };
247
248 UserOauth2Link::update(
251 db,
252 oauth_request.id,
253 &UpdateUserOauth2Link {
254 provider: None,
255 access_token: Some(Some(token_response.access_token().secret().to_string())),
256 refresh_token: token_response
257 .refresh_token()
258 .map(|token| Some(token.secret().to_string())),
259 csrf_token: None,
260 nonce: None,
261 pkce_secret: None,
262 user_id: None,
263 subject_id: None,
264 created_at: None,
265 updated_at: None,
266 },
267 )
268 .unwrap();
269
270 user
271 }
272 Ok(None) => {
273 let email = match (claims.email(), claims.email_verified()) {
275 (Some(email), Some(true)) => email.to_string(),
276 (None, _) => return Err((500, "No email returned".into())),
277 (_, Some(false) | None) => return Err((500, "Email not verified".into())),
278 };
279
280 match User::find_by_email(db, email.clone()).optional() {
281 Ok(Some(_)) => {
282 return Err((500, "Email already registered".into()));
283 }
284 Err(_) => {
285 return Err((500, "Internal server error".into()));
286 }
287 Ok(None) => {}
288 }
289
290 let salt = generate_salt();
292 let random_password = rand::thread_rng()
293 .sample_iter(&Alphanumeric)
294 .take(64)
295 .map(char::from)
296 .collect::<String>();
297 let hash =
298 argon2::hash_encoded(random_password.as_bytes(), &salt, &ARGON_CONFIG).unwrap();
299 let Ok(new_user) = User::create(
300 db,
301 &UserChangeset {
302 email,
303 activated: false, hash_password: hash,
305 },
306 ) else {
307 return Err((500, "Internal server error".into()));
308 };
309
310 UserOauth2Link::update(
313 db,
314 oauth_request.id,
315 &UpdateUserOauth2Link {
316 provider: None,
317 access_token: Some(Some(token_response.access_token().secret().to_string())),
318 refresh_token: Some(
319 token_response
320 .refresh_token()
321 .map(|token| token.secret().into()),
322 ),
323 csrf_token: Some(String::new()),
324 nonce: Some(String::new()),
325 pkce_secret: Some(String::new()),
326 user_id: Some(Some(new_user.id)),
327 subject_id: Some(Some(claims.subject().to_string())),
328 created_at: None,
329 updated_at: None,
330 },
331 )
332 .unwrap();
333
334 new_user
335 }
336 Err(_) => return Err((500, "Internal server error".into())),
337 };
338
339 create_user_session(
340 db,
341 Some(format!("Oauth2 - {}", &provider_name)),
342 None,
343 user.id,
344 )
345 .map_err(|error| (error.0, error.1.to_string()))
346}