1use derive_more::Display;
46
/// OpenID Connect authentication flow; selects the `response_type` sent in the
/// authorization request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthenticationFlow {
    /// Authorization code flow (`response_type=code`).
    AuthorizationCodeFlow,
    /// Implicit flow (`id_token`); `token` additionally requests an access token.
    ImplicitFlow { token: bool },
    /// Hybrid flow (`code` plus tokens); `token` / `id_token` select which tokens are
    /// returned alongside the code. At least one of them should be set.
    HybridFlow { token: bool, id_token: bool },
}
52
53impl AuthenticationFlow {
54 pub fn as_reponse_type(&self) -> &'static str {
55 match self {
56 Self::AuthorizationCodeFlow => "code",
57 Self::ImplicitFlow { token: with_token } => {
58 if *with_token {
59 "id_token token"
60 } else {
61 "id_token"
62 }
63 }
64 Self::HybridFlow {
65 token: with_token,
66 id_token: with_id_token,
67 } => {
68 if *with_token && *with_id_token {
69 "code id_token token"
70 } else if *with_token && !*with_id_token {
71 "code token"
72 } else if !*with_token && *with_id_token {
73 "code id_token"
74 } else {
75 tracing::warn!("Using HybridFlow without tokens fallbacks to 'code id_token'");
76
77 "code id_token"
78 }
79 }
80 }
81 }
82}
83
84#[derive(Display)]
85pub enum AuthenticationRequestParameters {
86 #[display("code")]
87 Code,
88 #[display("response_type")]
89 ResponseType,
90 #[display("scope")]
91 Scope,
92}
93
94#[derive(Display)]
95pub enum AuthenticationRequestScope {
96 #[display("openid")]
97 OpenID,
98 #[display("profile")]
99 Profile,
100 #[display("email")]
101 Email,
102 #[display("address")]
103 Address,
104 #[display("phone")]
105 Phone,
106 #[display("offline_access")]
107 OfflineAccess,
108 #[display("{}", _0)]
109 Unchecked(&'static str),
110}
111
#[cfg(test)]
mod tests {
    use super::AuthenticationFlow;

    use pretty_assertions::assert_eq;

    /// Asserts that every flow in `cases` yields its paired `response_type` value.
    fn check(cases: &[(AuthenticationFlow, &str)]) {
        for (flow, expected) in cases {
            assert_eq!(flow.as_reponse_type(), *expected);
        }
    }

    #[test]
    fn check_response_type_of_authorization_code_flow() {
        check(&[(AuthenticationFlow::AuthorizationCodeFlow, "code")]);
    }

    #[test]
    fn check_response_type_of_implicit_flow() {
        check(&[
            (AuthenticationFlow::ImplicitFlow { token: false }, "id_token"),
            (AuthenticationFlow::ImplicitFlow { token: true }, "id_token token"),
        ]);
    }

    #[test]
    fn check_response_type_of_hybrid_flow() {
        check(&[
            // Requesting no extra token is invalid and falls back to `code id_token`.
            (
                AuthenticationFlow::HybridFlow {
                    token: false,
                    id_token: false,
                },
                "code id_token",
            ),
            (
                AuthenticationFlow::HybridFlow {
                    token: true,
                    id_token: false,
                },
                "code token",
            ),
            (
                AuthenticationFlow::HybridFlow {
                    token: false,
                    id_token: true,
                },
                "code id_token",
            ),
            (
                AuthenticationFlow::HybridFlow {
                    token: true,
                    id_token: true,
                },
                "code id_token token",
            ),
        ]);
    }
}
173
174pub mod authorization_code_flow {
175 use serde::Deserialize;
176 use url::Url;
177
178 use crate::{config, id_token};
179
180 use super::{AuthenticationFlow, AuthenticationRequestParameters, AuthenticationRequestScope};
181
    /// Token endpoint response body for the authorization code flow.
    ///
    /// Only the fields below are deserialized; serde ignores any other fields in the body.
    #[derive(Deserialize)]
    pub struct AuthorizationCodeFlowTokenResponse {
        /// ID token; `fetch_authorization_tokens` calls `validate()` on it before returning.
        pub id_token: id_token::IDToken,
        /// Access token issued by the provider.
        pub access_token: String,
    }
187
    /// Client for the OpenID Connect authorization code flow against a single provider.
    ///
    /// Provider endpoints are looked up from the configuration document at `oidc_uri`.
    #[non_exhaustive]
    pub struct AuthorizationCodeFlowClient {
        /// Flow whose `response_type` is sent in the authorization request.
        flow: AuthenticationFlow,
        /// HTTP client used for the configuration lookup and the token request.
        http: reqwest::Client,
        /// URI of the provider's OpenID configuration document.
        oidc_uri: String,
        /// Scopes requested in the authorization request; `new` seeds it with `openid`.
        scopes: Vec<AuthenticationRequestScope>,
    }
225
226 impl AuthorizationCodeFlowClient {
227 pub fn new(oidc_uri: &str) -> Self {
228 Self {
229 flow: AuthenticationFlow::AuthorizationCodeFlow,
230 http: reqwest::Client::new(),
231 oidc_uri: oidc_uri.to_owned(),
232 scopes: vec![AuthenticationRequestScope::OpenID],
233 }
234 }
235
236 pub fn with_scope(mut self, s: AuthenticationRequestScope) -> Self {
258 match s {
259 AuthenticationRequestScope::OpenID => self,
260 _ => {
261 self.scopes.push(s);
262 self
263 }
264 }
265 }
266
267 pub async fn build_authorization_endpoint(&self) -> anyhow::Result<Url> {
278 let conf = config::OpenIDConfiguration::from_remote(&self.http, &self.oidc_uri).await?;
279
280 let mut authorization_endpoint = Url::parse(&conf.authorization_endpoint)?;
281
282 if authorization_endpoint.scheme() != "https" {
283 anyhow::bail!("authorization endpoint must be TLS");
284 }
285
286 let request_params = [
287 (
288 AuthenticationRequestParameters::Scope.to_string(),
289 self.scopes.iter().map(|n| n.to_string()).collect::<Vec<String>>().join(" "),
290 ),
291 (
292 AuthenticationRequestParameters::ResponseType.to_string(),
293 self.flow.as_reponse_type().to_owned(),
294 ),
295 ];
296
297 for (key, value) in request_params.iter() {
298 authorization_endpoint.query_pairs_mut().append_pair(key, value);
299 }
300
301 Ok(authorization_endpoint)
302 }
303
304 pub async fn fetch_authorization_tokens(&self, code: &str) -> anyhow::Result<AuthorizationCodeFlowTokenResponse> {
305 let conf = config::OpenIDConfiguration::from_remote(&self.http, &self.oidc_uri).await?;
306 let mut token_endpoint = Url::parse(&conf.token_endpoint)?;
307 let request_params = [(AuthenticationRequestParameters::Code.to_string(), code)];
308
309 for (key, value) in request_params.iter() {
310 token_endpoint.query_pairs_mut().append_pair(key, value);
311 }
312
313 let token_response = self
314 .http
315 .post(token_endpoint.as_str())
316 .form(&request_params)
317 .send()
318 .await?
319 .json::<AuthorizationCodeFlowTokenResponse>()
320 .await?;
321
322 token_response.id_token.validate()?;
323 Ok(token_response)
324 }
325
326 pub fn extract_authorization_code(&self, url: &str) -> anyhow::Result<String> {
327 let url = Url::parse(url)?;
328
329 let Some(code) = url
330 .query_pairs()
331 .find(|n| n.0 == AuthenticationRequestParameters::Code.to_string())
332 else {
333 anyhow::bail!("Code not found in provided url");
334 };
335
336 Ok(code.1.into_owned())
337 }
338 }
339
340 #[cfg(test)]
341 mod tests {
342 use axum::{Json, Router, routing::get};
343 use serde_json::json;
344 use tokio::net::TcpListener;
345
346 use crate::authentication::{AuthenticationRequestParameters, AuthenticationRequestScope};
347
348 use super::AuthorizationCodeFlowClient;
349
350 use pretty_assertions::assert_eq;
351
352 #[tokio::test]
353 async fn authorization_endpoint_is_tls() {
354 let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
355 let addr = listener.local_addr().unwrap();
356
357 let oidc_uri_path = "/.well-known/openid-configuration";
358
359 tokio::spawn(async move {
360 axum::serve(
361 listener,
362 Router::new().route(
363 oidc_uri_path,
364 get(|| async {
365 Json(json!({
366 "token_endpoint": "https://_/token",
367 "authorization_endpoint": "https://_/authorize"
368 }))
369 }),
370 ),
371 )
372 .await
373 .unwrap()
374 });
375
376 let oidc_uri = format!(
377 "http://{ip}:{port}{path}",
378 ip = addr.ip(),
379 port = addr.port(),
380 path = oidc_uri_path
381 );
382
383 let client = AuthorizationCodeFlowClient::new(&oidc_uri);
384 let authorization_endpoint = client.build_authorization_endpoint().await;
385 assert!(authorization_endpoint.is_ok());
386 }
387
388 #[tokio::test]
389 async fn authorization_endpoint_must_be_tls() {
390 let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
391 let addr = listener.local_addr().unwrap();
392
393 let oidc_uri_path = "/.well-known/openid-configuration";
394
395 tokio::spawn(async move {
396 axum::serve(
397 listener,
398 Router::new().route(
399 oidc_uri_path,
400 get(|| async {
401 Json(json!({
402 "token_endpoint": "http://_/token",
403 "authorization_endpoint": "http://_/authorize"
404 }))
405 }),
406 ),
407 )
408 .await
409 .unwrap()
410 });
411
412 let oidc_uri = format!(
413 "http://{ip}:{port}{path}",
414 ip = addr.ip(),
415 port = addr.port(),
416 path = oidc_uri_path
417 );
418
419 let client = AuthorizationCodeFlowClient::new(&oidc_uri);
420 let authorization_endpoint = client.build_authorization_endpoint().await;
421 assert!(authorization_endpoint.is_err());
422 }
423
424 #[tokio::test]
425 async fn authorization_request_param_response_type_must_be_correct() {
426 let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
427 let addr = listener.local_addr().unwrap();
428
429 let oidc_uri_path = "/.well-known/openid-configuration";
430
431 tokio::spawn(async move {
432 axum::serve(
433 listener,
434 Router::new().route(
435 oidc_uri_path,
436 get(|| async {
437 Json(json!({
438 "token_endpoint": "https://_/token",
439 "authorization_endpoint": "https://_/authorize"
440 }))
441 }),
442 ),
443 )
444 .await
445 .unwrap()
446 });
447
448 let oidc_uri = format!(
449 "http://{ip}:{port}{path}",
450 ip = addr.ip(),
451 port = addr.port(),
452 path = oidc_uri_path
453 );
454
455 let client = AuthorizationCodeFlowClient::new(&oidc_uri);
456
457 let authorization_endpoint = client.build_authorization_endpoint().await.unwrap();
458
459 assert_eq!(
460 authorization_endpoint
461 .query_pairs()
462 .find(|n| n.0 == AuthenticationRequestParameters::ResponseType.to_string())
463 .map(|n| n.1.into_owned()),
464 Some(String::from("code"))
465 );
466 }
467
468 #[tokio::test]
469 async fn authorization_request_param_scope_must_be_openid() {
470 let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
471 let addr = listener.local_addr().unwrap();
472
473 let oidc_uri_path = "/.well-known/openid-configuration";
474
475 tokio::spawn(async move {
476 axum::serve(
477 listener,
478 Router::new().route(
479 oidc_uri_path,
480 get(|| async {
481 Json(json!({
482 "token_endpoint": "https://_/token",
483 "authorization_endpoint": "https://_/authorize"
484 }))
485 }),
486 ),
487 )
488 .await
489 .unwrap()
490 });
491
492 let oidc_uri = format!(
493 "http://{ip}:{port}{path}",
494 ip = addr.ip(),
495 port = addr.port(),
496 path = oidc_uri_path
497 );
498
499 let client = AuthorizationCodeFlowClient::new(&oidc_uri);
500
501 let authorization_endpoint = client.build_authorization_endpoint().await.unwrap();
502
503 assert_eq!(
504 authorization_endpoint
505 .query_pairs()
506 .find(|n| n.0 == AuthenticationRequestParameters::Scope.to_string())
507 .map(|n| n.1.into_owned()),
508 Some(AuthenticationRequestScope::OpenID.to_string())
509 );
510 }
511
512 #[tokio::test]
513 async fn authorization_request_param_scope_can_be_added() {
514 let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
515 let addr = listener.local_addr().unwrap();
516
517 let oidc_uri_path = "/.well-known/openid-configuration";
518
519 tokio::spawn(async move {
520 axum::serve(
521 listener,
522 Router::new().route(
523 oidc_uri_path,
524 get(|| async {
525 Json(json!({
526 "token_endpoint": "https://_/token",
527 "authorization_endpoint": "https://_/authorize"
528 }))
529 }),
530 ),
531 )
532 .await
533 .unwrap()
534 });
535
536 let oidc_uri = format!(
537 "http://{ip}:{port}{path}",
538 ip = addr.ip(),
539 port = addr.port(),
540 path = oidc_uri_path
541 );
542
543 let client = AuthorizationCodeFlowClient::new(&oidc_uri)
544 .with_scope(AuthenticationRequestScope::Email)
545 .with_scope(AuthenticationRequestScope::Address)
546 .with_scope(AuthenticationRequestScope::Phone)
547 .with_scope(AuthenticationRequestScope::Profile)
548 .with_scope(AuthenticationRequestScope::OfflineAccess)
549 .with_scope(AuthenticationRequestScope::Unchecked("api://_/.default"));
550
551 let authorization_endpoint = client.build_authorization_endpoint().await.unwrap();
552
553 assert_eq!(
554 authorization_endpoint
555 .query_pairs()
556 .find(|n| n.0 == AuthenticationRequestParameters::Scope.to_string())
557 .map(|n| n.1.into_owned()),
558 Some(String::from(
559 "openid email address phone profile offline_access api://_/.default"
560 ))
561 );
562 }
563
564 #[tokio::test]
565 async fn authorization_request_param_scope_type_openid_can_only_be_added_once() {
566 let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
567 let addr = listener.local_addr().unwrap();
568
569 let oidc_uri_path = "/.well-known/openid-configuration";
570
571 tokio::spawn(async move {
572 axum::serve(
573 listener,
574 Router::new().route(
575 oidc_uri_path,
576 get(|| async {
577 Json(json!({
578 "token_endpoint": "https://_/token",
579 "authorization_endpoint": "https://_/authorize"
580 }))
581 }),
582 ),
583 )
584 .await
585 .unwrap()
586 });
587
588 let oidc_uri = format!(
589 "http://{ip}:{port}{path}",
590 ip = addr.ip(),
591 port = addr.port(),
592 path = oidc_uri_path
593 );
594
595 let client = AuthorizationCodeFlowClient::new(&oidc_uri)
596 .with_scope(AuthenticationRequestScope::OpenID)
597 .with_scope(AuthenticationRequestScope::OpenID);
598
599 let authorization_endpoint = client.build_authorization_endpoint().await.unwrap();
600
601 assert_eq!(
602 authorization_endpoint
603 .query_pairs()
604 .find(|n| n.0 == AuthenticationRequestParameters::Scope.to_string())
605 .map(|n| n.1.into_owned()),
606 Some(AuthenticationRequestScope::OpenID.to_string())
607 );
608 }
609 }
610}