mal_api/
oauth.rs

1//! Module for working through MAL OAuth2 flow
2
3use crate::{OAUTH_TOKEN_URL, OAUTH_URL};
4use oauth2::basic::BasicClient;
5use oauth2::http::Uri;
6use oauth2::reqwest::async_http_client;
7use oauth2::ClientId;
8use oauth2::{
9    AccessToken, AuthUrl, AuthorizationCode, ClientSecret, CsrfToken, PkceCodeChallenge,
10    PkceCodeVerifier, RedirectUrl, RefreshToken, TokenResponse, TokenUrl,
11};
12use serde::{Deserialize, Serialize};
13use std::marker::PhantomData;
14use std::path::Path;
15use std::time::{Duration, SystemTime};
16use std::{env, fs};
17use thiserror::Error;
18use toml;
19use url::Url;
20
/// Fallback access-token lifetime, used when a token response omits `expires_in`.
///
/// MAL access tokens are assumed to expire after about one month. To stay on the safe
/// side we assume 28 days minus one hour (2,415,600 seconds), so a token is treated as
/// expired slightly before the server would reject it.
const EXPIRATION_IN_SECONDS: u64 = 28 * 24 * 60 * 60 - 60 * 60;
24
25#[derive(Debug, Error)]
26pub enum OauthError {
27    #[error("missing environment variable")]
28    MissingEnvVar,
29
30    #[error("missing client id")]
31    MissingClientId,
32
33    #[error("missing client secret")]
34    MissingClientSecret,
35
36    #[error("missing redirect url")]
37    MissingRedirectUrl,
38
39    #[error("received state does not match")]
40    StateMismatch,
41
42    #[error("server failed to authenticate client")]
43    BadTokenResponse,
44
45    #[error("invalid redirect url")]
46    InvalidRedirectUrl,
47
48    #[error("invalid redirect response")]
49    InvalidRedirectResponse,
50
51    #[error("missing access token")]
52    MissingAccessToken,
53
54    #[error("missing refresh token")]
55    MissingRefreshToken,
56
57    #[error("missing token expiration time")]
58    MissingTokenExpiration,
59
60    #[error("missing config")]
61    MissingConfig,
62
63    #[error("invalid config format")]
64    InvalidConfigFormat,
65
66    #[error("failed to create config")]
67    ConfigCreationFailure,
68
69    #[error("unable to fetch system time")]
70    NoSystemTime,
71
72    #[error("invalid expiration time")]
73    InvalidExpirationTime,
74
75    #[error("failed to refresh the authentication token")]
76    FailedToRefreshToken,
77
78    #[error("missing the code or state from response")]
79    MissingCodeOrState,
80}
81
82/// If you only need to access public information on MAL that does
83/// not require an Oauth access token, you can use the [MalClientId]
84/// as your authorization client
85#[derive(Debug)]
86pub struct MalClientId(pub ClientId);
87
88impl MalClientId {
89    /// Create a [MalClientId] by passing in your ClientId as a string
90    ///
91    /// Useful if you want to control how your program fetches your MAL `MAL_CLIENT_ID`
92    pub fn new<T: Into<String>>(id: T) -> Self {
93        let client_id = ClientId::new(id.into());
94        Self(client_id)
95    }
96
97    /// Try to load your MAL ClientId from the environment variable `MAL_CLIENT_ID`
98    pub fn try_from_env() -> Result<Self, OauthError> {
99        let client_id = OauthClient::load_client_id_from_env()?;
100        Ok(Self(ClientId::new(client_id)))
101    }
102}
103
/// Typestate marker for an [OauthClient] that has not completed the
/// authorization flow yet and therefore holds no access or refresh token.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Unauthenticated;

/// Typestate marker for an [OauthClient] that holds an access token and a
/// refresh token; token accessors, `refresh`, and `save_to_config` are only
/// available in this state.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Authenticated;
111
112/// Client used to navigate and manage Oauth credentials with MAL
113#[derive(Debug)]
114pub struct OauthClient<State = Unauthenticated> {
115    client: BasicClient,
116    csrf: CsrfToken,
117    pkce_verifier: PkceCodeVerifier,
118    state: PhantomData<State>,
119    access_token: AccessToken,
120    refresh_token: RefreshToken,
121    expires_at: u64,
122}
123
124impl OauthClient<Unauthenticated> {
125    /// Creates a new [OauthClient] for the PKCE flow
126    pub fn new<T: Into<String>>(
127        client_id: T,
128        client_secret: Option<T>,
129        redirect_url: T,
130    ) -> Result<Self, OauthError> {
131        let (client_id, redirect_url) = (client_id.into(), redirect_url.into());
132        let client_secret = client_secret.map(|c| c.into());
133
134        let client = Self::create_oauth2_client(client_id, client_secret, redirect_url)?;
135
136        Ok(Self {
137            client,
138            pkce_verifier: PkceCodeVerifier::new("".to_string()),
139            csrf: CsrfToken::new(String::from("")),
140            state: PhantomData::<Unauthenticated>,
141            access_token: AccessToken::new("".to_string()),
142            refresh_token: RefreshToken::new("".to_string()),
143            expires_at: Duration::new(0, 0).as_secs(),
144        })
145    }
146
147    fn create_oauth2_client(
148        client_id: String,
149        client_secret: Option<String>,
150        redirect_url: String,
151    ) -> Result<BasicClient, OauthError> {
152        match client_secret {
153            Some(c) => {
154                let client = BasicClient::new(
155                    ClientId::new(client_id),
156                    Some(ClientSecret::new(c.into())),
157                    AuthUrl::new(OAUTH_URL.to_string()).unwrap(),
158                    Some(TokenUrl::new(OAUTH_TOKEN_URL.to_string()).unwrap()),
159                )
160                .set_redirect_uri(
161                    RedirectUrl::new(redirect_url).map_err(|_| OauthError::InvalidRedirectUrl)?,
162                )
163                .set_auth_type(oauth2::AuthType::BasicAuth);
164                Ok(client)
165            }
166            None => {
167                let client = BasicClient::new(
168                    ClientId::new(client_id),
169                    None,
170                    AuthUrl::new(OAUTH_URL.to_string()).unwrap(),
171                    Some(TokenUrl::new(OAUTH_TOKEN_URL.to_string()).unwrap()),
172                )
173                .set_redirect_uri(
174                    RedirectUrl::new(redirect_url).map_err(|_| OauthError::InvalidRedirectUrl)?,
175                )
176                .set_auth_type(oauth2::AuthType::RequestBody);
177                Ok(client)
178            }
179        }
180    }
181
182    /// Generate an authorization URL for the user to navigate to,
183    /// to begin the authorization process
184    pub fn generate_auth_url(&mut self) -> String {
185        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_plain();
186
187        let (auth_url, csrf_token) = self
188            .client
189            .authorize_url(CsrfToken::new_random)
190            .set_pkce_challenge(pkce_challenge)
191            .url();
192
193        self.csrf = csrf_token;
194        self.pkce_verifier = pkce_verifier;
195
196        auth_url.to_string()
197    }
198
199    /// Try and authenticate the client using a redirect response to
200    /// get an authenticated Oauth client back
201    pub async fn authenticate(
202        self,
203        authorization_response: RedirectResponse,
204    ) -> Result<OauthClient<Authenticated>, OauthError> {
205        if authorization_response.state != *self.csrf.secret() {
206            return Err(OauthError::StateMismatch);
207        }
208
209        let code = AuthorizationCode::new(authorization_response.code);
210        let token_result = self
211            .client
212            .exchange_code(code)
213            .set_pkce_verifier(self.pkce_verifier)
214            .request_async(async_http_client)
215            .await
216            .map_err(|_| OauthError::BadTokenResponse)?;
217
218        let now = calculate_current_system_time()?;
219
220        Ok(OauthClient::<Authenticated> {
221            client: self.client,
222            csrf: self.csrf,
223            pkce_verifier: PkceCodeVerifier::new("".to_string()),
224            state: PhantomData::<Authenticated>,
225            access_token: token_result.access_token().to_owned(),
226            refresh_token: token_result
227                .refresh_token()
228                .ok_or_else(|| OauthError::MissingRefreshToken)?
229                .to_owned(),
230            expires_at: now
231                + token_result
232                    .expires_in()
233                    .unwrap_or(Duration::from_secs(EXPIRATION_IN_SECONDS))
234                    .as_secs(),
235        })
236    }
237
238    /// Load Oauth credentials from the environment
239    ///
240    /// `Note`: This is expected to work after saving the credentials from an
241    /// authenticated OauthClient
242    fn load_from_env() -> Result<OauthClient<Authenticated>, OauthError> {
243        let (client_id, redirect_url) = (
244            Self::load_client_id_from_env()?,
245            Self::load_redirect_url_from_env()?,
246        );
247        let client_secret = Self::load_client_secret_from_env().ok();
248
249        let client = Self::create_oauth2_client(client_id, client_secret, redirect_url)?;
250
251        let access_token = Self::load_env_var("MAL_ACCESS_TOKEN")?;
252        let refresh_token = Self::load_env_var("MAL_REFRESH_TOKEN")?;
253        let expires_at = Self::load_env_var("MAL_TOKEN_EXPIRES_AT")
254            .map_err(|_| OauthError::MissingTokenExpiration)?
255            .parse::<u64>()
256            .map_err(|_| OauthError::InvalidExpirationTime)?;
257
258        Ok(OauthClient::<Authenticated> {
259            client,
260            csrf: CsrfToken::new(String::default()),
261            pkce_verifier: PkceCodeVerifier::new(String::default()),
262            state: PhantomData::<Authenticated>,
263            access_token: AccessToken::new(access_token),
264            refresh_token: RefreshToken::new(refresh_token),
265            expires_at,
266        })
267    }
268
269    /// Load an authenticated Oauth client from a MAL config file
270    ///
271    /// It is recommended to refresh the client after loading to ensure
272    /// that all of the tokens are still valid
273    pub fn load_from_config<T: Into<String>>(
274        path: T,
275    ) -> Result<OauthClient<Authenticated>, OauthError> {
276        let path: String = path.into();
277        let dir = env::current_dir().map_err(|_| OauthError::MissingConfig)?;
278        let path_to_config = dir.join(path);
279        if !Path::new(&path_to_config).exists() {
280            return Err(OauthError::MissingConfig);
281        }
282
283        let toml_content =
284            fs::read_to_string(&path_to_config).map_err(|_| OauthError::InvalidConfigFormat)?;
285        let parsed_toml: MalCredentialsConfig =
286            toml::from_str(&toml_content).map_err(|_| OauthError::InvalidConfigFormat)?;
287
288        env::set_var("MAL_ACCESS_TOKEN", parsed_toml.mal_access_token.to_string());
289        env::set_var(
290            "MAL_REFRESH_TOKEN",
291            parsed_toml.mal_refresh_token.to_string(),
292        );
293        env::set_var(
294            "MAL_TOKEN_EXPIRES_AT",
295            parsed_toml.mal_token_expires_at.to_string(),
296        );
297        Self::load_from_env()
298    }
299
300    /// Load an authenticated OauthClient by passing the necessary values
301    ///
302    /// It's recommended to refresh the client after to ensure that
303    /// the given values are still valid credentials.
304    ///
305    /// `Note`: This method still relies on the `MAL_CLIENT_ID`, `MAL_CLIENT_SECRET`, and
306    /// `MAL_REDIRECT_URL` environment variables being set
307    pub fn load_from_values<T: Into<String>>(
308        access_token: T,
309        refresh_token: T,
310        client_id: T,
311        client_secret: Option<T>,
312        redirect_url: T,
313        expires_at: u64,
314    ) -> Result<OauthClient<Authenticated>, OauthError> {
315        let (access_token, refresh_token) = (access_token.into(), refresh_token.into());
316        let (client_id, client_secret, redirect_url) = (
317            client_id.into(),
318            client_secret.map(|c| c.into()),
319            redirect_url.into(),
320        );
321
322        let unix_epoch = SystemTime::UNIX_EPOCH
323            .duration_since(SystemTime::UNIX_EPOCH)
324            .map_err(|_| OauthError::NoSystemTime)?
325            .as_secs();
326
327        if expires_at < unix_epoch {
328            return Err(OauthError::InvalidExpirationTime);
329        }
330
331        let client = Self::create_oauth2_client(client_id, client_secret, redirect_url)?;
332
333        Ok(OauthClient::<Authenticated> {
334            client,
335            csrf: CsrfToken::new(String::default()),
336            pkce_verifier: PkceCodeVerifier::new(String::default()),
337            state: PhantomData::<Authenticated>,
338            access_token: AccessToken::new(access_token),
339            refresh_token: RefreshToken::new(refresh_token),
340            expires_at,
341        })
342    }
343
344    fn load_env_var(name: &str) -> Result<String, OauthError> {
345        let result = env::var(name).map_err(|_| OauthError::MissingEnvVar)?;
346        Ok(result)
347    }
348
349    /// Load the MAL_CLIENT_ID environment variable
350    pub fn load_client_id_from_env() -> Result<String, OauthError> {
351        let client_id =
352            Self::load_env_var("MAL_CLIENT_ID").map_err(|_| OauthError::MissingClientId)?;
353        Ok(client_id)
354    }
355
356    /// Load the MAL_CLIENT_SECRET environment variable
357    pub fn load_client_secret_from_env() -> Result<String, OauthError> {
358        let client_secret =
359            Self::load_env_var("MAL_CLIENT_SECRET").map_err(|_| OauthError::MissingClientSecret)?;
360        Ok(client_secret)
361    }
362
363    /// Load the MAL_REDIRECT_URL environment variable
364    pub fn load_redirect_url_from_env() -> Result<String, OauthError> {
365        let redirect_url =
366            Self::load_env_var("MAL_REDIRECT_URL").map_err(|_| OauthError::MissingRedirectUrl)?;
367        Ok(redirect_url)
368    }
369}
370
371#[derive(Debug, Serialize, Deserialize)]
372struct MalCredentialsConfig {
373    mal_access_token: String,
374    mal_refresh_token: String,
375    mal_token_expires_at: u64,
376}
377
378impl OauthClient<Authenticated> {
379    /// Get the access token for the OauthClient
380    pub(crate) fn get_access_token(&self) -> &AccessToken {
381        &self.access_token
382    }
383
384    /// Get the access token secret value
385    pub fn get_access_token_secret(&self) -> &String {
386        &self.access_token.secret()
387    }
388
389    /// Get the refresh token secret value
390    pub fn get_refresh_token_secret(&self) -> &String {
391        &self.refresh_token.secret()
392    }
393
394    /// Get the time at which the token will expire
395    ///
396    /// The time is represented as number of seconds since the Unix Epoch
397    pub fn get_expires_at(&self) -> &u64 {
398        &self.expires_at
399    }
400
401    /// Save the Oauth credentials to the config
402    ///
403    /// This method is available if you want to persist your
404    /// access, refresh, and expires_at values on the host
405    pub fn save_to_config<T: Into<String>>(&self, path: T) -> Result<(), OauthError> {
406        let path: String = path.into();
407        let dir = env::current_dir().map_err(|_| OauthError::MissingConfig)?;
408        let path_to_config = dir.join(path);
409
410        let config = MalCredentialsConfig {
411            mal_access_token: self.access_token.secret().clone(),
412            mal_refresh_token: self.refresh_token.secret().clone(),
413            mal_token_expires_at: *self.get_expires_at(),
414        };
415        let toml = toml::to_string(&config).map_err(|_| OauthError::InvalidConfigFormat)?;
416
417        if let Some(parent_dir) = Path::new(&path_to_config).parent() {
418            fs::create_dir_all(parent_dir).map_err(|_| OauthError::ConfigCreationFailure)?;
419        }
420
421        fs::write(&path_to_config, toml).map_err(|_| OauthError::ConfigCreationFailure)?;
422        Ok(())
423    }
424
425    /// Refresh the access token using the refresh token
426    pub async fn refresh(self) -> Result<Self, OauthError> {
427        let refresh_result = self
428            .client
429            .exchange_refresh_token(&self.refresh_token)
430            .request_async(async_http_client)
431            .await
432            .map_err(|_| OauthError::FailedToRefreshToken)?;
433
434        let now = calculate_current_system_time()?;
435
436        Ok(OauthClient::<Authenticated> {
437            client: self.client,
438            csrf: self.csrf,
439            pkce_verifier: PkceCodeVerifier::new("".to_string()),
440            state: PhantomData::<Authenticated>,
441            access_token: refresh_result.access_token().to_owned(),
442            refresh_token: refresh_result.refresh_token().unwrap().to_owned(),
443            expires_at: now
444                + refresh_result
445                    .expires_in()
446                    .unwrap_or(Duration::from_secs(EXPIRATION_IN_SECONDS))
447                    .as_secs(),
448        })
449    }
450}
451
452#[derive(Debug, Deserialize)]
453pub struct RedirectResponse {
454    code: String,
455    state: String,
456}
457
458impl RedirectResponse {
459    /// Create a new RedirectResponse from given code and state
460    pub fn new<T: Into<String>>(code: T, state: T) -> Self {
461        let code = code.into();
462        let state = state.into();
463        Self { code, state }
464    }
465
466    /// Create a RedirectResponse from the given OAuth2 redirect result
467    ///
468    /// This function just requires a reference to a Uri, that includes
469    /// the `code` and `state` parameters
470    pub fn try_from_uri(uri: &Uri) -> Result<RedirectResponse, OauthError> {
471        let query_params: Option<Self> = uri.query().map(|query| {
472            serde_urlencoded::from_str(query).expect("Failed to get code and state from response.")
473        });
474
475        match query_params {
476            Some(q) => Ok(q),
477            None => Err(OauthError::InvalidRedirectResponse),
478        }
479    }
480}
481
482impl TryFrom<String> for RedirectResponse {
483    type Error = OauthError;
484
485    fn try_from(value: String) -> Result<Self, Self::Error> {
486        let query_string = value
487            .parse::<Url>()
488            .map_err(|_| OauthError::InvalidRedirectResponse)?;
489
490        let query_params = query_string
491            .query()
492            .ok_or_else(|| OauthError::MissingCodeOrState)?;
493
494        serde_urlencoded::from_str::<RedirectResponse>(&query_params)
495            .map_err(|_| OauthError::MissingCodeOrState)
496    }
497}
498
499fn calculate_current_system_time() -> Result<u64, OauthError> {
500    let now = SystemTime::UNIX_EPOCH
501        .duration_since(SystemTime::UNIX_EPOCH)
502        .map_err(|_| OauthError::NoSystemTime)?
503        .as_secs();
504    Ok(now)
505}