Skip to main content

axum_security/oauth2/
builder.rs

1use std::{borrow::Cow, error::Error, fmt::Display, sync::Arc};
2
3use cookie_monster::{Cookie, CookieBuilder, SameSite};
4use oauth2::{
5    AuthUrl, Client, ClientId, ClientSecret, RedirectUrl, Scope, TokenUrl,
6    reqwest::Client as HttpClient, url,
7};
8
9use crate::{
10    cookie::{CookieContext, CookieSessionBuilder, CookieStore},
11    http::default_reqwest_client,
12    oauth2::{
13        OAuth2Context, OAuth2Handler, OAuthState, context::OAuth2ContextInner,
14        handler::ErasedOAuth2Handler,
15    },
16    utils::get_env,
17};
18
/// Cookie name shared by the default production cookie and the default dev cookie.
const DEFAULT_COOKIE_NAME: &str = "oauth2-session";
20
21pub struct OAuth2ContextBuilder<S> {
22    cookie_session: CookieSessionBuilder<S>,
23    login_path: Option<Cow<'static, str>>,
24    redirect_url: Option<String>,
25    client_id: Option<String>,
26    client_secret: Option<String>,
27    scopes: Vec<Scope>,
28    auth_url: Option<String>,
29    token_url: Option<String>,
30    http_client: Option<HttpClient>,
31    flow_type: FlowType,
32}
33
34impl<S> OAuth2ContextBuilder<S> {
35    pub fn new(store: S) -> OAuth2ContextBuilder<S> {
36        // Make sure to use "/" as path so all paths can see the cookie in dev mode.
37        let dev_cookie = Cookie::named(DEFAULT_COOKIE_NAME)
38            .path("/")
39            .same_site(SameSite::Lax);
40
41        let cookie = Cookie::named(DEFAULT_COOKIE_NAME)
42            .http_only()
43            .same_site(SameSite::Strict)
44            .secure();
45
46        Self {
47            cookie_session: CookieContext::<()>::builder()
48                .store(store)
49                .cookie(|_| cookie)
50                .dev_cookie(|_| dev_cookie),
51            login_path: None,
52            redirect_url: None,
53            client_id: None,
54            client_secret: None,
55            scopes: Vec::new(),
56            auth_url: None,
57            token_url: None,
58            http_client: None,
59            flow_type: FlowType::AuthorizationCodeFlowPkce,
60        }
61    }
62
63    pub fn redirect_url(mut self, url: impl Into<String>) -> Self {
64        self.redirect_url = Some(url.into());
65        self
66    }
67
68    pub fn redirect_uri_env(self, name: &str) -> Self {
69        self.redirect_url(get_env(name))
70    }
71
72    pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
73        self.client_id = Some(client_id.into());
74        self
75    }
76
77    pub fn client_id_env(self, name: &str) -> Self {
78        self.client_id(get_env(name))
79    }
80
81    pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
82        self.client_secret = Some(client_secret.into());
83        self
84    }
85
86    pub fn client_secret_env(self, name: &str) -> Self {
87        self.client_secret(get_env(name))
88    }
89
90    pub fn auth_url(mut self, auth_url: impl Into<String>) -> Self {
91        self.auth_url = Some(auth_url.into());
92        self
93    }
94
95    pub fn auth_url_env(self, name: &str) -> Self {
96        self.auth_url(get_env(name))
97    }
98
99    pub fn token_url(mut self, token_url: impl Into<String>) -> Self {
100        self.token_url = Some(token_url.into());
101        self
102    }
103
104    pub fn token_url_env(self, name: &str) -> Self {
105        self.token_url(get_env(name))
106    }
107
108    pub fn scopes(mut self, scopes: &[&str]) -> Self {
109        self.scopes = scopes.iter().map(|s| Scope::new(s.to_string())).collect();
110        self
111    }
112
113    pub fn cookie(mut self, f: impl FnOnce(CookieBuilder) -> CookieBuilder) -> Self {
114        self.cookie_session = self.cookie_session.cookie(f);
115        self
116    }
117
118    pub fn dev_cookie(mut self, f: impl FnOnce(CookieBuilder) -> CookieBuilder) -> Self {
119        self.cookie_session = self.cookie_session.dev_cookie(f);
120        self
121    }
122
123    pub fn login_path(mut self, path: impl Into<Cow<'static, str>>) -> Self {
124        self.login_path = Some(path.into());
125        self
126    }
127
128    pub fn use_dev_cookies(mut self, dev: bool) -> Self {
129        self.cookie_session = self.cookie_session.use_dev_cookie(dev);
130        self
131    }
132
133    pub fn use_normal_cookies(self, prod: bool) -> Self {
134        self.use_dev_cookies(!prod)
135    }
136
137    pub fn http_client(mut self, http_client: HttpClient) -> Self {
138        self.http_client = Some(http_client);
139        self
140    }
141
142    pub fn authorization_code_flow(mut self) -> Self {
143        self.flow_type = FlowType::AuthorizationCodeFlow;
144        self
145    }
146
147    pub fn authorization_code_flow_with_pkce(mut self) -> Self {
148        self.flow_type = FlowType::AuthorizationCodeFlowPkce;
149        self
150    }
151
152    pub fn store<S1>(self, store: S1) -> OAuth2ContextBuilder<S1> {
153        OAuth2ContextBuilder {
154            cookie_session: self.cookie_session.store(store),
155            login_path: self.login_path,
156            redirect_url: self.redirect_url,
157            client_id: self.client_id,
158            client_secret: self.client_secret,
159            scopes: self.scopes,
160            auth_url: self.auth_url,
161            token_url: self.token_url,
162            http_client: self.http_client,
163            flow_type: self.flow_type,
164        }
165    }
166
167    pub fn build<T>(self, inner: T) -> OAuth2Context
168    where
169        S: CookieStore<State = OAuthState>,
170        T: OAuth2Handler,
171    {
172        self.try_build(inner).unwrap()
173    }
174
175    pub fn try_build<T>(self, inner: T) -> Result<OAuth2Context, OAuth2BuilderError>
176    where
177        S: CookieStore<State = OAuthState>,
178        T: OAuth2Handler,
179    {
180        let client_id = self
181            .client_id
182            .ok_or(OAuth2BuilderError::MissingClientId)
183            .map(ClientId::new)?;
184
185        let redirect_url = self
186            .redirect_url
187            .ok_or(OAuth2BuilderError::MissingRedirectUrl)?;
188
189        let redirect_url =
190            RedirectUrl::new(redirect_url).map_err(OAuth2BuilderError::InvalidRedirectUrl)?;
191
192        let auth_url = self.auth_url.ok_or(OAuth2BuilderError::MissingAuthUrl)?;
193
194        let auth_url = AuthUrl::new(auth_url).map_err(OAuth2BuilderError::InvalidAuthUrl)?;
195
196        let token_url = self.token_url.ok_or(OAuth2BuilderError::MissingTokenUrl)?;
197
198        let token_url = TokenUrl::new(token_url).map_err(OAuth2BuilderError::InvalidTokenUrl)?;
199
200        let mut basic_client = Client::new(client_id)
201            .set_redirect_uri(redirect_url)
202            .set_auth_uri(auth_url)
203            .set_token_uri(token_url);
204
205        if let Some(client_secret) = self.client_secret {
206            basic_client = basic_client.set_client_secret(ClientSecret::new(client_secret));
207        }
208
209        Ok(OAuth2Context(Arc::new(OAuth2ContextInner {
210            client: basic_client,
211            inner: ErasedOAuth2Handler::new(inner),
212            session: self.cookie_session.build(),
213            login_path: self.login_path,
214            http_client: self.http_client.unwrap_or_else(default_reqwest_client),
215            scopes: self.scopes,
216            flow_type: self.flow_type,
217        })))
218    }
219}
220
/// Which OAuth2 authorization code flow variant the context runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FlowType {
    /// Plain authorization code flow, without PKCE.
    AuthorizationCodeFlow,
    /// Authorization code flow with PKCE; the builder's default.
    AuthorizationCodeFlowPkce,
}
225
226#[derive(Debug)]
227pub enum OAuth2BuilderError {
228    MissingClientId,
229    MissingRedirectUrl,
230    MissingAuthUrl,
231    MissingTokenUrl,
232    InvalidRedirectUrl(url::ParseError),
233    InvalidAuthUrl(url::ParseError),
234    InvalidTokenUrl(url::ParseError),
235}
236
237impl Error for OAuth2BuilderError {}
238
239impl Display for OAuth2BuilderError {
240    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241        match self {
242            OAuth2BuilderError::MissingClientId => f.write_str("client id is missing"),
243            OAuth2BuilderError::MissingRedirectUrl => f.write_str("redirect url is missing"),
244            OAuth2BuilderError::MissingAuthUrl => f.write_str("authorization url is missing"),
245            OAuth2BuilderError::MissingTokenUrl => f.write_str("token url is missing"),
246            OAuth2BuilderError::InvalidRedirectUrl(parse_error) => {
247                write!(f, "could not parse redirect url: {}", parse_error)
248            }
249            OAuth2BuilderError::InvalidAuthUrl(parse_error) => {
250                write!(f, "could not parse authorization url: {}", parse_error)
251            }
252            OAuth2BuilderError::InvalidTokenUrl(parse_error) => {
253                write!(f, "could not parse token url: {}", parse_error)
254            }
255        }
256    }
257}
258
#[cfg(test)]
mod builder {
    use axum::response::IntoResponse;

    use crate::oauth2::{
        AfterLoginCookies, OAuth2BuilderError, OAuth2Context, OAuth2Handler, TokenResponse,
        providers::github,
    };

    const CLIENT_ID: &str = "test_client_id";
    const CLIENT_SECRET: &str = "test_client_secret";
    const REDIRECT_URL: &str = "http://rust-lang.org/redirect";
    const AUTH_URL: &str = github::AUTH_URL;
    const TOKEN_URL: &str = github::TOKEN_URL;

    /// Handler that ignores the token response; only needed to call `try_build`.
    struct TestHandler {}

    impl OAuth2Handler for TestHandler {
        async fn after_login(
            &self,
            _token_res: TokenResponse,
            _context: &mut AfterLoginCookies<'_>,
        ) -> impl IntoResponse {
        }
    }

    /// A fully configured builder succeeds, and the client secret is optional.
    #[test]
    fn builds_with_and_without_client_secret() {
        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .token_url(TOKEN_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler {});

        assert!(res.is_ok());

        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .auth_url(AUTH_URL)
            .token_url(TOKEN_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler {});

        assert!(res.is_ok());
    }

    /// With nothing configured, the client id is the first thing reported.
    #[test]
    fn missing_client_id_is_reported_first() {
        let res = OAuth2Context::builder().try_build(TestHandler {});

        assert!(matches!(res, Err(OAuth2BuilderError::MissingClientId)));
    }

    #[test]
    fn client_id() {
        let res = OAuth2Context::builder()
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .token_url(TOKEN_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler {});

        assert!(matches!(res, Err(OAuth2BuilderError::MissingClientId)));
    }

    #[test]
    fn auth_url() {
        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .token_url(TOKEN_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler {});

        assert!(matches!(res, Err(OAuth2BuilderError::MissingAuthUrl)));

        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url("not an url")
            .token_url(TOKEN_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler {});

        assert!(matches!(res, Err(OAuth2BuilderError::InvalidAuthUrl(_))));
    }

    #[test]
    fn token_url() {
        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler {});

        assert!(matches!(res, Err(OAuth2BuilderError::MissingTokenUrl)));

        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .token_url("not an url")
            .redirect_url(REDIRECT_URL)
            .try_build(TestHandler {});

        assert!(matches!(res, Err(OAuth2BuilderError::InvalidTokenUrl(_))));
    }

    #[test]
    fn redirect_url() {
        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .token_url(TOKEN_URL)
            .try_build(TestHandler {});

        assert!(matches!(res, Err(OAuth2BuilderError::MissingRedirectUrl)));

        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .auth_url(AUTH_URL)
            .token_url(TOKEN_URL)
            .redirect_url("not an url")
            .try_build(TestHandler {});

        assert!(matches!(
            res,
            Err(OAuth2BuilderError::InvalidRedirectUrl(_))
        ));
    }

    /// Error messages are stable, lowercase and name the offending setting.
    #[test]
    fn error_messages() {
        assert_eq!(
            OAuth2BuilderError::MissingClientId.to_string(),
            "client id is missing"
        );
        assert_eq!(
            OAuth2BuilderError::MissingRedirectUrl.to_string(),
            "redirect url is missing"
        );
        assert_eq!(
            OAuth2BuilderError::MissingAuthUrl.to_string(),
            "authorization url is missing"
        );
        assert_eq!(
            OAuth2BuilderError::MissingTokenUrl.to_string(),
            "token url is missing"
        );

        let res = OAuth2Context::builder()
            .client_id(CLIENT_ID)
            .auth_url(AUTH_URL)
            .token_url(TOKEN_URL)
            .redirect_url("not an url")
            .try_build(TestHandler {});

        let Err(err) = res else {
            panic!("an unparsable redirect url must be rejected");
        };
        assert!(
            err.to_string()
                .starts_with("could not parse redirect url: ")
        );
    }
}