1#![allow(clippy::unused_async)]
2
3use crate::OAuth2ClientStore;
4use axum::{extract::Query, response::Redirect, Extension};
5use axum_session::{DatabasePool, Session};
6use loco_rs::prelude::*;
7use serde::de::DeserializeOwned;
8use serde::Deserialize;
9use std::fmt::Debug;
10use tokio::sync::MutexGuard;
11
12use crate::controllers::middleware::OAuth2PrivateCookieJarTrait;
13use crate::controllers::middleware::{OAuth2CookieUser, OAuth2PrivateCookieJar};
14use crate::grants::authorization_code::GrantTrait;
15use crate::models::oauth2_sessions::OAuth2SessionsTrait;
16use crate::models::users::OAuth2UserTrait;
17
18#[derive(Debug, Deserialize)]
19pub struct AuthParams {
20 code: String,
21 state: String,
22}
23
24pub async fn get_authorization_url<T: DatabasePool + Clone + Debug + Sync + Send + 'static>(
34 session: Session<T>,
35 oauth2_client: &mut MutexGuard<'_, dyn GrantTrait>,
36) -> String {
37 let (auth_url, csrf_token) = oauth2_client.get_authorization_url();
38 session.set("CSRF_TOKEN", csrf_token.secret().to_owned());
39 auth_url.to_string()
40}
41
42pub async fn callback<
63 T: DeserializeOwned + Send,
64 U: OAuth2UserTrait<T> + ModelTrait,
65 V: OAuth2SessionsTrait<U>,
66 W: DatabasePool + Clone + Debug + Sync + Send + 'static,
67>(
68 ctx: AppContext,
69 session: Session<W>,
70 params: AuthParams,
71 jar: OAuth2PrivateCookieJar,
73 client: &mut MutexGuard<'_, dyn GrantTrait>,
74) -> Result<impl IntoResponse> {
75 let csrf_token = session
77 .get::<String>("CSRF_TOKEN")
78 .ok_or_else(|| Error::BadRequest("CSRF token not found".to_string()))?;
79 let (token, profile) = client
81 .verify_code_from_callback(params.code, params.state, csrf_token)
82 .await
83 .map_err(|e| Error::BadRequest(e.to_string()))?;
84 let profile = profile.json::<T>().await.map_err(|e| {
86 tracing::error!("Error getting profile: {:?}", e);
87 Error::InternalServerError
88 })?;
89 let user = U::upsert_with_oauth(&ctx.db, &profile)
90 .await
91 .map_err(|_e| {
92 tracing::error!("Error creating user");
93 Error::InternalServerError
94 })?;
95 V::upsert_with_oauth2(&ctx.db, &token, &user)
96 .await
97 .map_err(|_e| {
98 tracing::error!("Error creating session");
99 Error::InternalServerError
100 })?;
101 let oauth2_cookie_config = client.get_cookie_config();
102 let jar = OAuth2PrivateCookieJar::create_short_live_cookie_with_token_response(
103 oauth2_cookie_config,
104 &token,
105 jar,
106 )
107 .map_err(|_e| Error::InternalServerError)?;
108 let protect_url = oauth2_cookie_config
109 .protected_url
110 .clone()
111 .unwrap_or_else(|| "/oauth2/protected".to_string());
112 let response = (jar, Redirect::to(&protect_url)).into_response();
113 tracing::info!("response: {:?}", response);
114 Ok(response)
115}
116
117pub async fn callback_jwt<
137 T: DeserializeOwned + Send,
138 U: OAuth2UserTrait<T> + ModelTrait,
139 V: OAuth2SessionsTrait<U>,
140 W: DatabasePool + Clone + Debug + Sync + Send + 'static,
141>(
142 ctx: &AppContext,
143 session: Session<W>,
144 params: AuthParams,
145 client: &mut MutexGuard<'_, dyn GrantTrait>,
146) -> Result<U> {
147 let csrf_token = session
149 .get::<String>("CSRF_TOKEN")
150 .ok_or_else(|| Error::BadRequest("CSRF token not found".to_string()))?;
151 let (token, profile) = client
153 .verify_code_from_callback(params.code, params.state, csrf_token)
154 .await
155 .map_err(|e| Error::BadRequest(e.to_string()))?;
156 let profile = profile.json::<T>().await.map_err(|e| {
158 tracing::error!("Error getting profile: {:?}", e);
159 Error::InternalServerError
160 })?;
161 let user = U::upsert_with_oauth(&ctx.db, &profile)
162 .await
163 .map_err(|_e| {
164 tracing::error!("Error creating user");
165 Error::InternalServerError
166 })?;
167 V::upsert_with_oauth2(&ctx.db, &token, &user)
168 .await
169 .map_err(|_e| {
170 tracing::error!("Error creating session");
171 Error::InternalServerError
172 })?;
173
174 Ok(user)
175}
176
177pub async fn google_authorization_url<T: DatabasePool + Clone + Debug + Sync + Send + 'static>(
191 session: Session<T>,
192 Extension(oauth2_store): Extension<OAuth2ClientStore>,
193) -> Result<String> {
194 let mut client = oauth2_store
195 .get_authorization_code_client("google")
196 .await
197 .map_err(|e| {
198 tracing::error!("Error getting client: {:?}", e);
199 Error::InternalServerError
200 })?;
201 let auth_url = get_authorization_url(session, &mut client).await;
202 drop(client);
203 Ok(auth_url)
204}
205
206pub async fn google_callback_cookie<
230 T: DeserializeOwned + Send,
231 U: OAuth2UserTrait<T> + ModelTrait,
232 V: OAuth2SessionsTrait<U>,
233 W: DatabasePool + Clone + Debug + Sync + Send + 'static,
234>(
235 State(ctx): State<AppContext>,
236 session: Session<W>,
237 Query(params): Query<AuthParams>,
238 jar: OAuth2PrivateCookieJar,
240 Extension(oauth2_store): Extension<OAuth2ClientStore>,
241) -> Result<impl IntoResponse> {
242 let mut client = oauth2_store
243 .get_authorization_code_client("google")
244 .await
245 .map_err(|e| {
246 tracing::error!("Error getting client: {:?}", e);
247 Error::InternalServerError
248 })?;
249 let response = callback::<T, U, V, W>(ctx, session, params, jar, &mut client).await?;
250 drop(client);
251 Ok(response)
252}
253
254pub async fn google_callback_jwt<
276 T: DeserializeOwned + Send,
277 U: OAuth2UserTrait<T> + ModelTrait,
278 V: OAuth2SessionsTrait<U>,
279 W: DatabasePool + Clone + Debug + Sync + Send + 'static,
280>(
281 State(ctx): State<AppContext>,
282 session: Session<W>,
283 Query(params): Query<AuthParams>,
284 Extension(oauth2_store): Extension<OAuth2ClientStore>,
285) -> Result<impl IntoResponse> {
286 let mut client = oauth2_store
287 .get_authorization_code_client("google")
288 .await
289 .map_err(|e| {
290 tracing::error!("Error getting client: {:?}", e);
291 Error::InternalServerError
292 })?;
293 let jwt_secret = ctx.config.get_jwt_config()?;
294 let user = callback_jwt::<T, U, V, W>(&ctx, session, params, &mut client).await?;
295 drop(client);
296 let token = user
297 .generate_jwt(&jwt_secret.secret, &jwt_secret.expiration)
298 .or_else(|_| unauthorized("unauthorized!"))?;
299 Ok(token)
300}
301
302pub async fn protected<
316 T: DeserializeOwned + Send,
317 U: OAuth2UserTrait<T> + ModelTrait,
318 V: OAuth2SessionsTrait<U> + ModelTrait,
319>(
320 user: OAuth2CookieUser<T, U, V>,
321) -> Result<impl IntoResponse> {
322 let _user = user.as_ref();
323 Ok("You are protected!".to_string())
324}