loco_oauth2/controllers/oauth2.rs

1#![allow(clippy::unused_async)]
2
3use crate::OAuth2ClientStore;
4use axum::{extract::Query, response::Redirect, Extension};
5use axum_session::{DatabasePool, Session};
6use loco_rs::prelude::*;
7use serde::de::DeserializeOwned;
8use serde::Deserialize;
9use std::fmt::Debug;
10use tokio::sync::MutexGuard;
11
12use crate::controllers::middleware::OAuth2PrivateCookieJarTrait;
13use crate::controllers::middleware::{OAuth2CookieUser, OAuth2PrivateCookieJar};
14use crate::grants::authorization_code::GrantTrait;
15use crate::models::oauth2_sessions::OAuth2SessionsTrait;
16use crate::models::users::OAuth2UserTrait;
17
18#[derive(Debug, Deserialize)]
19pub struct AuthParams {
20    code: String,
21    state: String,
22}
23
24/// Helper function to get the authorization URL and save the CSRF token in the session
25///
26/// # Generics
27/// * `T` - The database pool
28/// # Arguments
29/// * `session` - The axum session
30/// * `oauth2_client` - The `AuthorizationCodeGrant` client
31/// # Returns
32/// * `String` - The authorization URL
33pub async fn get_authorization_url<T: DatabasePool + Clone + Debug + Sync + Send + 'static>(
34    session: Session<T>,
35    oauth2_client: &mut MutexGuard<'_, dyn GrantTrait>,
36) -> String {
37    let (auth_url, csrf_token) = oauth2_client.get_authorization_url();
38    session.set("CSRF_TOKEN", csrf_token.secret().to_owned());
39    auth_url.to_string()
40}
41
42/// Helper function to exchange the code for a token and then get the user profile
43/// then upsert the user and the session and set the token in a short live
44/// cookie
45///
46/// Lastly, it will redirect the user to the protected URL
47/// # Generics
48/// * `T` - The user profile, should implement `DeserializeOwned` and `Send`
49/// * `U` - The user model, should implement `OAuth2UserTrait` and `ModelTrait`
50/// * `V` - The session model, should implement `OAuth2SessionsTrait` and `ModelTrait`
51/// * `W` - The database pool
52/// # Arguments
53/// * `ctx` - The application context
54/// * `session` - The axum session
55/// * `params` - The query parameters
56/// * `jar` - The oauth2 private cookie jar
57/// * `client` - The `AuthorizationCodeGrant` client
58/// # Returns
59/// * `Result<impl IntoResponse>` - The response with the short live cookie and the redirect to the protected URL
60/// # Errors
61/// * `loco_rs::errors::Error`
62pub async fn callback<
63    T: DeserializeOwned + Send,
64    U: OAuth2UserTrait<T> + ModelTrait,
65    V: OAuth2SessionsTrait<U>,
66    W: DatabasePool + Clone + Debug + Sync + Send + 'static,
67>(
68    ctx: AppContext,
69    session: Session<W>,
70    params: AuthParams,
71    // Extract the private cookie jar from the request
72    jar: OAuth2PrivateCookieJar,
73    client: &mut MutexGuard<'_, dyn GrantTrait>,
74) -> Result<impl IntoResponse> {
75    // Get the CSRF token from the session
76    let csrf_token = session
77        .get::<String>("CSRF_TOKEN")
78        .ok_or_else(|| Error::BadRequest("CSRF token not found".to_string()))?;
79    // Exchange the code with a token
80    let (token, profile) = client
81        .verify_code_from_callback(params.code, params.state, csrf_token)
82        .await
83        .map_err(|e| Error::BadRequest(e.to_string()))?;
84    // Get the user profile
85    let profile = profile.json::<T>().await.map_err(|e| {
86        tracing::error!("Error getting profile: {:?}", e);
87        Error::InternalServerError
88    })?;
89    let user = U::upsert_with_oauth(&ctx.db, &profile)
90        .await
91        .map_err(|_e| {
92            tracing::error!("Error creating user");
93            Error::InternalServerError
94        })?;
95    V::upsert_with_oauth2(&ctx.db, &token, &user)
96        .await
97        .map_err(|_e| {
98            tracing::error!("Error creating session");
99            Error::InternalServerError
100        })?;
101    let oauth2_cookie_config = client.get_cookie_config();
102    let jar = OAuth2PrivateCookieJar::create_short_live_cookie_with_token_response(
103        oauth2_cookie_config,
104        &token,
105        jar,
106    )
107    .map_err(|_e| Error::InternalServerError)?;
108    let protect_url = oauth2_cookie_config
109        .protected_url
110        .clone()
111        .unwrap_or_else(|| "/oauth2/protected".to_string());
112    let response = (jar, Redirect::to(&protect_url)).into_response();
113    tracing::info!("response: {:?}", response);
114    Ok(response)
115}
116
117/// Helper function to exchange the code for a token and then get the user profile
118/// then upsert the user and the session and set the token in a short live
119/// cookie
120///
121/// Lastly, it will redirect the user to the protected URL
122/// # Generics
123/// * `T` - The user profile, should implement `DeserializeOwned` and `Send`
124/// * `U` - The user model, should implement `OAuth2UserTrait` and `ModelTrait`
125/// * `V` - The session model, should implement `OAuth2SessionsTrait` and `ModelTrait`
126/// * `W` - The database pool
127/// # Arguments
128/// * `ctx` - The application context
129/// * `session` - The axum session
130/// * `params` - The query parameters
131/// * `client` - The `AuthorizationCodeGrant` client
132/// # Returns
133/// * `Result<U>` - The user
134/// # Errors
135/// * `loco_rs::errors::Error`
136pub async fn callback_jwt<
137    T: DeserializeOwned + Send,
138    U: OAuth2UserTrait<T> + ModelTrait,
139    V: OAuth2SessionsTrait<U>,
140    W: DatabasePool + Clone + Debug + Sync + Send + 'static,
141>(
142    ctx: &AppContext,
143    session: Session<W>,
144    params: AuthParams,
145    client: &mut MutexGuard<'_, dyn GrantTrait>,
146) -> Result<U> {
147    // Get the CSRF token from the session
148    let csrf_token = session
149        .get::<String>("CSRF_TOKEN")
150        .ok_or_else(|| Error::BadRequest("CSRF token not found".to_string()))?;
151    // Exchange the code with a token
152    let (token, profile) = client
153        .verify_code_from_callback(params.code, params.state, csrf_token)
154        .await
155        .map_err(|e| Error::BadRequest(e.to_string()))?;
156    // Get the user profile
157    let profile = profile.json::<T>().await.map_err(|e| {
158        tracing::error!("Error getting profile: {:?}", e);
159        Error::InternalServerError
160    })?;
161    let user = U::upsert_with_oauth(&ctx.db, &profile)
162        .await
163        .map_err(|_e| {
164            tracing::error!("Error creating user");
165            Error::InternalServerError
166        })?;
167    V::upsert_with_oauth2(&ctx.db, &token, &user)
168        .await
169        .map_err(|_e| {
170            tracing::error!("Error creating session");
171            Error::InternalServerError
172        })?;
173
174    Ok(user)
175}
176
177/// The authorization URL for the `OAuth2` flow
178///
179/// This will redirect the user to the `OAuth2` provider's login page
180/// and then to the callback URL
181/// # Generics
182/// * `T` - The database pool
183/// # Arguments
184/// * `session` - The axum session
185/// * `oauth2_store` - The `OAuth2ClientStore` extension
186/// # Returns
187/// The HTML response with the link to the `OAuth2` provider's login page
188/// # Errors
189/// `loco_rs::errors::Error` - When the `OAuth2` client cannot be retrieved
190pub async fn google_authorization_url<T: DatabasePool + Clone + Debug + Sync + Send + 'static>(
191    session: Session<T>,
192    Extension(oauth2_store): Extension<OAuth2ClientStore>,
193) -> Result<String> {
194    let mut client = oauth2_store
195        .get_authorization_code_client("google")
196        .await
197        .map_err(|e| {
198            tracing::error!("Error getting client: {:?}", e);
199            Error::InternalServerError
200        })?;
201    let auth_url = get_authorization_url(session, &mut client).await;
202    drop(client);
203    Ok(auth_url)
204}
205
206/// The callback URL for the `OAuth2` flow
207///
208/// This will exchange the code for a token and then get the user profile
209/// then upsert the user and the session and set the token in a short live
210/// cookie
211///
212/// Lastly, it will redirect the user to the protected URL
213/// # Generics
214/// * `T` - The user profile, should implement `DeserializeOwned` and `Send`
215/// * `U` - The user model, should implement `OAuth2UserTrait` and `ModelTrait`
216/// * `V` - The session model, should implement `OAuth2SessionsTrait` and `ModelTrait`
217/// * `W` - The database pool
218/// # Arguments
219/// * `ctx` - The application context
220/// * `session` - The axum session
221/// * `params` - The query parameters
222/// * `jar` - The oauth2 private cookie jar
223/// * `oauth2_store` - The `OAuth2ClientStore` extension
224/// # Returns
225/// The response with the short live cookie and the redirect to the protected
226/// URL
227/// # Errors
228/// * `loco_rs::errors::Error`
229pub async fn google_callback_cookie<
230    T: DeserializeOwned + Send,
231    U: OAuth2UserTrait<T> + ModelTrait,
232    V: OAuth2SessionsTrait<U>,
233    W: DatabasePool + Clone + Debug + Sync + Send + 'static,
234>(
235    State(ctx): State<AppContext>,
236    session: Session<W>,
237    Query(params): Query<AuthParams>,
238    // Extract the private cookie jar from the request
239    jar: OAuth2PrivateCookieJar,
240    Extension(oauth2_store): Extension<OAuth2ClientStore>,
241) -> Result<impl IntoResponse> {
242    let mut client = oauth2_store
243        .get_authorization_code_client("google")
244        .await
245        .map_err(|e| {
246            tracing::error!("Error getting client: {:?}", e);
247            Error::InternalServerError
248        })?;
249    let response = callback::<T, U, V, W>(ctx, session, params, jar, &mut client).await?;
250    drop(client);
251    Ok(response)
252}
253
254/// The callback URL for the `OAuth2` flow
255///
256/// This will exchange the code for a token and then get the user profile
257/// then upsert the user and the session and set the token in a short live
258/// cookie.
259///
260/// Lastly, it will redirect the user to the protected URL
261/// # Generics
262/// * `T` - The user profile, should implement `DeserializeOwned` and `Send`
263/// * `U` - The user model, should implement `OAuth2UserTrait` and `ModelTrait`
264/// * `V` - The session model, should implement `OAuth2SessionsTrait` and `ModelTrait`
265/// * `W` - The database pool
266/// # Arguments
267/// * `ctx` - The application context
268/// * `session` - The axum session
269/// * `params` - The query parameters
270/// * `oauth2_store` - The `OAuth2ClientStore` extension
271/// # Return
272/// * `Result<impl IntoResponse>` - The response with the jwt token
273/// # Errors
274/// * `loco_rs::errors::Error`
275pub async fn google_callback_jwt<
276    T: DeserializeOwned + Send,
277    U: OAuth2UserTrait<T> + ModelTrait,
278    V: OAuth2SessionsTrait<U>,
279    W: DatabasePool + Clone + Debug + Sync + Send + 'static,
280>(
281    State(ctx): State<AppContext>,
282    session: Session<W>,
283    Query(params): Query<AuthParams>,
284    Extension(oauth2_store): Extension<OAuth2ClientStore>,
285) -> Result<impl IntoResponse> {
286    let mut client = oauth2_store
287        .get_authorization_code_client("google")
288        .await
289        .map_err(|e| {
290            tracing::error!("Error getting client: {:?}", e);
291            Error::InternalServerError
292        })?;
293    let jwt_secret = ctx.config.get_jwt_config()?;
294    let user = callback_jwt::<T, U, V, W>(&ctx, session, params, &mut client).await?;
295    drop(client);
296    let token = user
297        .generate_jwt(&jwt_secret.secret, &jwt_secret.expiration)
298        .or_else(|_| unauthorized("unauthorized!"))?;
299    Ok(token)
300}
301
302/// The protected URL for the `OAuth2` flow
303/// This will return a message indicating that the user is protected
304///
305/// # Generics
306/// * `T` - The user profile, should implement `DeserializeOwned` and `Send`
307/// * `U` - The user model, should implement `OAuth2UserTrait` and `ModelTrait`
308/// * `V` - The session model, should implement `OAuth2SessionsTrait` and `ModelTrait`
309/// # Arguments
310/// * `user` - The `OAuth2CookieUser` that holds the user and the session
311/// # Returns
312/// The response with the message indicating that the user is protected
313/// # Errors
314/// * `loco_rs::errors::Error` - When the user cannot be retrieved
315pub async fn protected<
316    T: DeserializeOwned + Send,
317    U: OAuth2UserTrait<T> + ModelTrait,
318    V: OAuth2SessionsTrait<U> + ModelTrait,
319>(
320    user: OAuth2CookieUser<T, U, V>,
321) -> Result<impl IntoResponse> {
322    let _user = user.as_ref();
323    Ok("You are protected!".to_string())
324}