email/account/config/
oauth2.rs

1//! Module dedicated to OAuth 2.0 configuration.
2//!
3//! This module contains everything related to OAuth 2.0
4//! configuration.
5
6use std::{fmt, io, net::TcpListener, vec};
7
8use oauth::v2_0::{AuthorizationCodeGrant, Client, RefreshAccessToken};
9use secret::Secret;
10use tracing::debug;
11
12#[doc(inline)]
13pub use super::{Error, Result};
14
15/// The OAuth 2.0 configuration.
16#[derive(Clone, Debug, Default, Eq, PartialEq)]
17#[cfg_attr(
18    feature = "derive",
19    derive(serde::Serialize, serde::Deserialize),
20    serde(rename_all = "kebab-case")
21)]
22pub struct OAuth2Config {
23    /// Method for presenting an OAuth 2.0 bearer token to a service
24    /// for authentication.
25    pub method: OAuth2Method,
26
27    /// Client identifier issued to the client during the registration process described by
28    /// [Section 2.2](https://datatracker.ietf.org/doc/html/rfc6749#section-2.2).
29    pub client_id: String,
30
31    /// Client password issued to the client during the registration process described by
32    /// [Section 2.2](https://datatracker.ietf.org/doc/html/rfc6749#section-2.2).
33    pub client_secret: Option<Secret>,
34
35    /// URL of the authorization server's authorization endpoint.
36    pub auth_url: String,
37
38    /// URL of the authorization server's token endpoint.
39    pub token_url: String,
40
41    /// Access token returned by the token endpoint and used to access
42    /// protected resources.
43    #[cfg_attr(
44        feature = "derive",
45        serde(default, skip_serializing_if = "Secret::is_empty")
46    )]
47    pub access_token: Secret,
48
49    /// Refresh token used to obtain a new access token (if supported
50    /// by the authorization server).
51    #[cfg_attr(
52        feature = "derive",
53        serde(default, skip_serializing_if = "Secret::is_empty")
54    )]
55    pub refresh_token: Secret,
56
57    /// Enable the [PKCE](https://datatracker.ietf.org/doc/html/rfc7636) protection.
58    /// The value must have a minimum length of 43 characters and a maximum length of 128 characters.
59    /// Each character must be ASCII alphanumeric or one of the characters “-” / “.” / “_” / “~”.
60    pub pkce: bool,
61
62    pub redirect_scheme: Option<String>,
63    pub redirect_host: Option<String>,
64    pub redirect_port: Option<u16>,
65
66    /// Access token scope(s), as defined by the authorization server.
67    #[cfg_attr(feature = "derive", serde(flatten))]
68    pub scopes: OAuth2Scopes,
69}
70
71impl OAuth2Config {
72    pub const LOCALHOST: &'static str = "localhost";
73
74    /// Return the first available port on [`LOCALHOST`].
75    pub fn get_first_available_port() -> Result<u16> {
76        (49_152..65_535)
77            .find(|port| TcpListener::bind((OAuth2Config::LOCALHOST, *port)).is_ok())
78            .ok_or(Error::GetAvailablePortError)
79    }
80
81    /// Resets the three secrets of the OAuth 2.0 configuration.
82    pub async fn reset(&self) -> Result<()> {
83        if let Some(secret) = self.client_secret.as_ref() {
84            secret
85                .delete_if_keyring()
86                .await
87                .map_err(Error::DeleteClientSecretOauthError)?;
88        }
89
90        self.access_token
91            .delete_if_keyring()
92            .await
93            .map_err(Error::DeleteAccessTokenOauthError)?;
94        self.refresh_token
95            .delete_if_keyring()
96            .await
97            .map_err(Error::DeleteRefreshTokenOauthError)?;
98
99        Ok(())
100    }
101
102    /// If the access token is not defined, runs the authorization
103    /// code grant OAuth 2.0 flow in order to save the acces token and
104    /// the refresh token if present.
105    pub async fn configure(
106        &self,
107        get_client_secret: impl Fn() -> io::Result<String>,
108    ) -> Result<()> {
109        if self.access_token.get().await.is_ok() {
110            return Ok(());
111        }
112
113        let redirect_scheme = match self.redirect_scheme.as_ref() {
114            Some(scheme) => scheme.clone(),
115            None => "http".into(),
116        };
117
118        let redirect_host = match self.redirect_host.as_ref() {
119            Some(host) => host.clone(),
120            None => OAuth2Config::LOCALHOST.to_owned(),
121        };
122
123        let redirect_port = match self.redirect_port {
124            Some(port) => port,
125            None => OAuth2Config::get_first_available_port()?,
126        };
127
128        let client_secret = match self.client_secret.as_ref() {
129            None => None,
130            Some(secret) => Some(match secret.find().await {
131                Ok(None) => {
132                    debug!("cannot find oauth2 client secret from keyring, setting it");
133                    secret
134                        .set_if_keyring(
135                            get_client_secret()
136                                .map_err(Error::GetClientSecretFromUserOauthError)?,
137                        )
138                        .await
139                        .map_err(Error::SetClientSecretIntoKeyringOauthError)
140                }
141                Ok(Some(client_secret)) => Ok(client_secret),
142                Err(err) => Err(Error::GetClientSecretFromKeyringOauthError(err)),
143            }?),
144        };
145
146        let client = Client::new(
147            self.client_id.clone(),
148            client_secret,
149            self.auth_url.clone(),
150            self.token_url.clone(),
151            redirect_scheme,
152            redirect_host,
153            redirect_port,
154        )
155        .map_err(Error::BuildOauthClientError)?;
156
157        let mut auth_code_grant = AuthorizationCodeGrant::new();
158
159        if self.pkce {
160            auth_code_grant = auth_code_grant.with_pkce();
161        }
162
163        for scope in self.scopes.clone() {
164            auth_code_grant = auth_code_grant.with_scope(scope);
165        }
166
167        let (redirect_url, csrf_token) = auth_code_grant.get_redirect_url(&client);
168
169        println!("To complete your OAuth 2.0 setup, click on the following link:");
170        println!();
171        println!("{}", redirect_url);
172
173        let (access_token, refresh_token) = auth_code_grant
174            .wait_for_redirection(&client, csrf_token)
175            .await
176            .map_err(Error::WaitForOauthRedirectionError)?;
177
178        self.access_token
179            .set_if_keyring(access_token)
180            .await
181            .map_err(Error::SetAccessTokenOauthError)?;
182
183        if let Some(refresh_token) = &refresh_token {
184            self.refresh_token
185                .set_if_keyring(refresh_token)
186                .await
187                .map_err(Error::SetRefreshTokenOauthError)?;
188        }
189
190        Ok(())
191    }
192
193    /// Runs the refresh access token OAuth 2.0 flow by exchanging a
194    /// refresh token with a new pair of access/refresh token.
195    pub async fn refresh_access_token(&self) -> Result<String> {
196        let redirect_scheme = match self.redirect_scheme.as_ref() {
197            Some(scheme) => scheme.clone(),
198            None => "http".into(),
199        };
200
201        let redirect_host = match self.redirect_host.as_ref() {
202            Some(host) => host.clone(),
203            None => OAuth2Config::LOCALHOST.to_owned(),
204        };
205
206        let redirect_port = match self.redirect_port {
207            Some(port) => port,
208            None => OAuth2Config::get_first_available_port()?,
209        };
210
211        let client_secret = match self.client_secret.as_ref() {
212            None => None,
213            Some(secret) => {
214                let secret = secret
215                    .get()
216                    .await
217                    .map_err(Error::GetClientSecretFromKeyringOauthError)?;
218                Some(secret)
219            }
220        };
221
222        let client = Client::new(
223            self.client_id.clone(),
224            client_secret,
225            self.auth_url.clone(),
226            self.token_url.clone(),
227            redirect_scheme,
228            redirect_host,
229            redirect_port,
230        )
231        .map_err(Error::BuildOauthClientError)?;
232
233        let refresh_token = self
234            .refresh_token
235            .get()
236            .await
237            .map_err(Error::GetRefreshTokenOauthError)?;
238
239        let (access_token, refresh_token) = RefreshAccessToken::new()
240            .refresh_access_token(&client, refresh_token)
241            .await
242            .map_err(Error::RefreshAccessTokenOauthError)?;
243
244        self.access_token
245            .set_if_keyring(&access_token)
246            .await
247            .map_err(Error::SetAccessTokenOauthError)?;
248
249        if let Some(refresh_token) = &refresh_token {
250            self.refresh_token
251                .set_if_keyring(refresh_token)
252                .await
253                .map_err(Error::SetRefreshTokenOauthError)?;
254        }
255
256        Ok(access_token)
257    }
258
259    /// Returns the access token if existing, otherwise returns an
260    /// error.
261    pub async fn access_token(&self) -> Result<String> {
262        self.access_token
263            .get()
264            .await
265            .map_err(Error::GetAccessTokenOauthError)
266    }
267}
268
/// How an OAuth 2.0 bearer token is presented to a service when
/// authenticating.
///
/// Serialized in lowercase. Deserialization also accepts the upper-case
/// names `XOAUTH2` and `OAUTHBEARER`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "derive",
    derive(serde::Serialize, serde::Deserialize),
    serde(rename_all = "lowercase")
)]
pub enum OAuth2Method {
    /// `XOAUTH2`, the default method.
    #[cfg_attr(feature = "derive", serde(alias = "XOAUTH2"))]
    #[default]
    XOAuth2,
    /// `OAUTHBEARER`.
    #[cfg_attr(feature = "derive", serde(alias = "OAUTHBEARER"))]
    OAuthBearer,
}

impl fmt::Display for OAuth2Method {
    /// Writes the method name in upper case, e.g. `XOAUTH2`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::XOAuth2 => "XOAUTH2",
            Self::OAuthBearer => "OAUTHBEARER",
        };
        f.write_str(name)
    }
}
293
/// Access token scope(s), as defined by the authorization server.
///
/// Holds either a single scope or a list of scopes. Both forms iterate
/// as a flat sequence of scope strings, either by value (consuming) or
/// by reference (borrowing, no allocation).
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(
    feature = "derive",
    derive(serde::Serialize, serde::Deserialize),
    serde(rename_all = "kebab-case")
)]
pub enum OAuth2Scopes {
    /// A single scope.
    Scope(String),
    /// A list of scopes.
    Scopes(Vec<String>),
}

impl Default for OAuth2Scopes {
    /// Defaults to an empty list of scopes.
    fn default() -> Self {
        Self::Scopes(Vec::new())
    }
}

impl IntoIterator for OAuth2Scopes {
    type IntoIter = vec::IntoIter<Self::Item>;
    type Item = String;

    /// Consumes the scopes, yielding each one as an owned string.
    fn into_iter(self) -> Self::IntoIter {
        match self {
            Self::Scope(scope) => vec![scope].into_iter(),
            Self::Scopes(scopes) => scopes.into_iter(),
        }
    }
}

impl<'a> IntoIterator for &'a OAuth2Scopes {
    type IntoIter = std::slice::Iter<'a, String>;
    type Item = &'a String;

    /// Borrows the scopes, yielding each one by reference without
    /// cloning or allocating.
    fn into_iter(self) -> Self::IntoIter {
        match self {
            // View the single scope as a one-element slice so both arms
            // return the same iterator type.
            OAuth2Scopes::Scope(scope) => std::slice::from_ref(scope).iter(),
            OAuth2Scopes::Scopes(scopes) => scopes.iter(),
        }
    }
}