1use std::{borrow::Cow, error::Error, fmt::Display, sync::Arc};
2
3use cookie_monster::{Cookie, CookieBuilder, SameSite};
4use oauth2::{
5 AuthUrl, Client, ClientId, ClientSecret, RedirectUrl, Scope, TokenUrl,
6 reqwest::Client as HttpClient, url,
7};
8
9use crate::{
10 cookie::{CookieContext, CookieSessionBuilder, CookieStore},
11 http::default_reqwest_client,
12 oauth2::{
13 OAuth2Context, OAuth2Handler, OAuthState, context::OAuth2ContextInner,
14 handler::ErasedOAuth2Handler,
15 },
16 utils::get_env,
17};
18
/// Name given by default to both the production and the development session cookie.
const DEFAULT_COOKIE_NAME: &str = "oauth2-session";
20
21pub struct OAuth2ContextBuilder<S> {
22 cookie_session: CookieSessionBuilder<S>,
23 login_path: Option<Cow<'static, str>>,
24 redirect_url: Option<String>,
25 client_id: Option<String>,
26 client_secret: Option<String>,
27 scopes: Vec<Scope>,
28 auth_url: Option<String>,
29 token_url: Option<String>,
30 http_client: Option<HttpClient>,
31 flow_type: FlowType,
32}
33
34impl<S> OAuth2ContextBuilder<S> {
35 pub fn new(store: S) -> OAuth2ContextBuilder<S> {
36 let dev_cookie = Cookie::named(DEFAULT_COOKIE_NAME)
38 .path("/")
39 .same_site(SameSite::Lax);
40
41 let cookie = Cookie::named(DEFAULT_COOKIE_NAME)
42 .http_only()
43 .same_site(SameSite::Strict)
44 .secure();
45
46 Self {
47 cookie_session: CookieContext::<()>::builder()
48 .store(store)
49 .cookie(|_| cookie)
50 .dev_cookie(|_| dev_cookie),
51 login_path: None,
52 redirect_url: None,
53 client_id: None,
54 client_secret: None,
55 scopes: Vec::new(),
56 auth_url: None,
57 token_url: None,
58 http_client: None,
59 flow_type: FlowType::AuthorizationCodeFlowPkce,
60 }
61 }
62
63 pub fn redirect_url(mut self, url: impl Into<String>) -> Self {
64 self.redirect_url = Some(url.into());
65 self
66 }
67
68 pub fn redirect_uri_env(self, name: &str) -> Self {
69 self.redirect_url(get_env(name))
70 }
71
72 pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
73 self.client_id = Some(client_id.into());
74 self
75 }
76
77 pub fn client_id_env(self, name: &str) -> Self {
78 self.client_id(get_env(name))
79 }
80
81 pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
82 self.client_secret = Some(client_secret.into());
83 self
84 }
85
86 pub fn client_secret_env(self, name: &str) -> Self {
87 self.client_secret(get_env(name))
88 }
89
90 pub fn auth_url(mut self, auth_url: impl Into<String>) -> Self {
91 self.auth_url = Some(auth_url.into());
92 self
93 }
94
95 pub fn auth_url_env(self, name: &str) -> Self {
96 self.auth_url(get_env(name))
97 }
98
99 pub fn token_url(mut self, token_url: impl Into<String>) -> Self {
100 self.token_url = Some(token_url.into());
101 self
102 }
103
104 pub fn token_url_env(self, name: &str) -> Self {
105 self.token_url(get_env(name))
106 }
107
108 pub fn scopes(mut self, scopes: &[&str]) -> Self {
109 self.scopes = scopes.iter().map(|s| Scope::new(s.to_string())).collect();
110 self
111 }
112
113 pub fn cookie(mut self, f: impl FnOnce(CookieBuilder) -> CookieBuilder) -> Self {
114 self.cookie_session = self.cookie_session.cookie(f);
115 self
116 }
117
118 pub fn dev_cookie(mut self, f: impl FnOnce(CookieBuilder) -> CookieBuilder) -> Self {
119 self.cookie_session = self.cookie_session.dev_cookie(f);
120 self
121 }
122
123 pub fn login_path(mut self, path: impl Into<Cow<'static, str>>) -> Self {
124 self.login_path = Some(path.into());
125 self
126 }
127
128 pub fn use_dev_cookies(mut self, dev: bool) -> Self {
129 self.cookie_session = self.cookie_session.use_dev_cookie(dev);
130 self
131 }
132
133 pub fn use_normal_cookies(self, prod: bool) -> Self {
134 self.use_dev_cookies(!prod)
135 }
136
137 pub fn http_client(mut self, http_client: HttpClient) -> Self {
138 self.http_client = Some(http_client);
139 self
140 }
141
142 pub fn authorization_code_flow(mut self) -> Self {
143 self.flow_type = FlowType::AuthorizationCodeFlow;
144 self
145 }
146
147 pub fn authorization_code_flow_with_pkce(mut self) -> Self {
148 self.flow_type = FlowType::AuthorizationCodeFlowPkce;
149 self
150 }
151
152 pub fn store<S1>(self, store: S1) -> OAuth2ContextBuilder<S1> {
153 OAuth2ContextBuilder {
154 cookie_session: self.cookie_session.store(store),
155 login_path: self.login_path,
156 redirect_url: self.redirect_url,
157 client_id: self.client_id,
158 client_secret: self.client_secret,
159 scopes: self.scopes,
160 auth_url: self.auth_url,
161 token_url: self.token_url,
162 http_client: self.http_client,
163 flow_type: self.flow_type,
164 }
165 }
166
167 pub fn build<T>(self, inner: T) -> OAuth2Context
168 where
169 S: CookieStore<State = OAuthState>,
170 T: OAuth2Handler,
171 {
172 self.try_build(inner).unwrap()
173 }
174
175 pub fn try_build<T>(self, inner: T) -> Result<OAuth2Context, OAuth2BuilderError>
176 where
177 S: CookieStore<State = OAuthState>,
178 T: OAuth2Handler,
179 {
180 let client_id = self
181 .client_id
182 .ok_or(OAuth2BuilderError::MissingClientId)
183 .map(ClientId::new)?;
184
185 let redirect_url = self
186 .redirect_url
187 .ok_or(OAuth2BuilderError::MissingRedirectUrl)?;
188
189 let redirect_url =
190 RedirectUrl::new(redirect_url).map_err(OAuth2BuilderError::InvalidRedirectUrl)?;
191
192 let auth_url = self.auth_url.ok_or(OAuth2BuilderError::MissingAuthUrl)?;
193
194 let auth_url = AuthUrl::new(auth_url).map_err(OAuth2BuilderError::InvalidAuthUrl)?;
195
196 let token_url = self.token_url.ok_or(OAuth2BuilderError::MissingTokenUrl)?;
197
198 let token_url = TokenUrl::new(token_url).map_err(OAuth2BuilderError::InvalidTokenUrl)?;
199
200 let mut basic_client = Client::new(client_id)
201 .set_redirect_uri(redirect_url)
202 .set_auth_uri(auth_url)
203 .set_token_uri(token_url);
204
205 if let Some(client_secret) = self.client_secret {
206 basic_client = basic_client.set_client_secret(ClientSecret::new(client_secret));
207 }
208
209 Ok(OAuth2Context(Arc::new(OAuth2ContextInner {
210 client: basic_client,
211 inner: ErasedOAuth2Handler::new(inner),
212 session: self.cookie_session.build(),
213 login_path: self.login_path,
214 http_client: self.http_client.unwrap_or_else(default_reqwest_client),
215 scopes: self.scopes,
216 flow_type: self.flow_type,
217 })))
218 }
219}
220
/// Which OAuth2 authorization code flow variant the context uses.
///
/// This is a small fieldless enum, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FlowType {
    /// Authorization code flow without PKCE.
    AuthorizationCodeFlow,
    /// Authorization code flow with PKCE.
    AuthorizationCodeFlowPkce,
}
225
226#[derive(Debug)]
227pub enum OAuth2BuilderError {
228 MissingClientId,
229 MissingRedirectUrl,
230 MissingAuthUrl,
231 MissingTokenUrl,
232 InvalidRedirectUrl(url::ParseError),
233 InvalidAuthUrl(url::ParseError),
234 InvalidTokenUrl(url::ParseError),
235}
236
237impl Error for OAuth2BuilderError {}
238
239impl Display for OAuth2BuilderError {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 match self {
242 OAuth2BuilderError::MissingClientId => f.write_str("client id is missing"),
243 OAuth2BuilderError::MissingRedirectUrl => f.write_str("redirect url is missing"),
244 OAuth2BuilderError::MissingAuthUrl => f.write_str("authorization url is missing"),
245 OAuth2BuilderError::MissingTokenUrl => f.write_str("token url is missing"),
246 OAuth2BuilderError::InvalidRedirectUrl(parse_error) => {
247 write!(f, "could not parse redirect url: {}", parse_error)
248 }
249 OAuth2BuilderError::InvalidAuthUrl(parse_error) => {
250 write!(f, "could not parse authorization url: {}", parse_error)
251 }
252 OAuth2BuilderError::InvalidTokenUrl(parse_error) => {
253 write!(f, "could not parse token url: {}", parse_error)
254 }
255 }
256 }
257}
258
#[cfg(test)]
mod builder {
    use axum::response::IntoResponse;

    use crate::oauth2::{
        AfterLoginCookies, OAuth2BuilderError, OAuth2Context, OAuth2Handler, TokenResponse,
        providers::github,
    };

    const CLIENT_ID: &str = "test_client_id";
    const CLIENT_SECRET: &str = "test_client_secret";
    const REDIRECT_URL: &str = "http://rust-lang.org/redirect";
    const AUTH_URL: &str = github::AUTH_URL;
    const TOKEN_URL: &str = github::TOKEN_URL;

    /// Handler that does nothing after login; the builder only needs some handler.
    struct TestHandler;

    impl OAuth2Handler for TestHandler {
        async fn after_login(
            &self,
            _token_res: TokenResponse,
            _context: &mut AfterLoginCookies<'_>,
        ) -> impl IntoResponse {
        }
    }

    /// A fully configured builder succeeds, with or without a client secret.
    #[test]
    fn builds_with_complete_config() {
        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .token_url(TOKEN_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler);

        assert!(res.is_ok(), "unexpected error: {:?}", res.as_ref().err());

        // The client secret is optional.
        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .auth_url(AUTH_URL)
            .token_url(TOKEN_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler);

        assert!(res.is_ok(), "unexpected error: {:?}", res.as_ref().err());
    }

    /// `build` panics instead of returning an error on invalid configuration.
    #[test]
    #[should_panic]
    fn build_panics_on_missing_config() {
        let _ = OAuth2Context::builder().build(TestHandler);
    }

    #[test]
    fn client_id() {
        let res = OAuth2Context::builder()
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .token_url(TOKEN_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler);

        assert!(matches!(res, Err(OAuth2BuilderError::MissingClientId)));
    }

    #[test]
    fn auth_url() {
        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .token_url(TOKEN_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler);

        assert!(matches!(res, Err(OAuth2BuilderError::MissingAuthUrl)));

        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url("not an url")
            .token_url(TOKEN_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler);

        assert!(matches!(res, Err(OAuth2BuilderError::InvalidAuthUrl(_))));
    }

    #[test]
    fn token_url() {
        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler);

        assert!(matches!(res, Err(OAuth2BuilderError::MissingTokenUrl)));

        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .token_url("not an url")
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler);

        assert!(matches!(res, Err(OAuth2BuilderError::InvalidTokenUrl(_))));
    }

    #[test]
    fn redirect_url() {
        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .token_url(TOKEN_URL)
            .try_build(TestHandler);

        assert!(matches!(res, Err(OAuth2BuilderError::MissingRedirectUrl)));

        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .token_url(TOKEN_URL)
            .redirect_url("not an url")
            .try_build(TestHandler);

        assert!(matches!(
            res,
            Err(OAuth2BuilderError::InvalidRedirectUrl(_))
        ));
    }

    /// The `Display` messages for missing settings are stable and human-readable.
    #[test]
    fn error_messages() {
        assert_eq!(
            OAuth2BuilderError::MissingClientId.to_string(),
            "client id is missing"
        );
        assert_eq!(
            OAuth2BuilderError::MissingRedirectUrl.to_string(),
            "redirect url is missing"
        );
        assert_eq!(
            OAuth2BuilderError::MissingAuthUrl.to_string(),
            "authorization url is missing"
        );
        assert_eq!(
            OAuth2BuilderError::MissingTokenUrl.to_string(),
            "token url is missing"
        );
    }
}