1use crate::{OAUTH_TOKEN_URL, OAUTH_URL};
4use oauth2::basic::BasicClient;
5use oauth2::http::Uri;
6use oauth2::reqwest::async_http_client;
7use oauth2::ClientId;
8use oauth2::{
9 AccessToken, AuthUrl, AuthorizationCode, ClientSecret, CsrfToken, PkceCodeChallenge,
10 PkceCodeVerifier, RedirectUrl, RefreshToken, TokenResponse, TokenUrl,
11};
12use serde::{Deserialize, Serialize};
13use std::marker::PhantomData;
14use std::path::Path;
15use std::time::{Duration, SystemTime};
16use std::{env, fs};
17use thiserror::Error;
18use toml;
19use url::Url;
20
/// Fallback access-token lifetime, in seconds, used whenever a token response
/// omits `expires_in`: 28 days minus one hour (2_415_600 s).
const EXPIRATION_IN_SECONDS: u64 = 28 * 24 * 60 * 60 - 60 * 60;
24
25#[derive(Debug, Error)]
26pub enum OauthError {
27 #[error("missing environment variable")]
28 MissingEnvVar,
29
30 #[error("missing client id")]
31 MissingClientId,
32
33 #[error("missing client secret")]
34 MissingClientSecret,
35
36 #[error("missing redirect url")]
37 MissingRedirectUrl,
38
39 #[error("received state does not match")]
40 StateMismatch,
41
42 #[error("server failed to authenticate client")]
43 BadTokenResponse,
44
45 #[error("invalid redirect url")]
46 InvalidRedirectUrl,
47
48 #[error("invalid redirect response")]
49 InvalidRedirectResponse,
50
51 #[error("missing access token")]
52 MissingAccessToken,
53
54 #[error("missing refresh token")]
55 MissingRefreshToken,
56
57 #[error("missing token expiration time")]
58 MissingTokenExpiration,
59
60 #[error("missing config")]
61 MissingConfig,
62
63 #[error("invalid config format")]
64 InvalidConfigFormat,
65
66 #[error("failed to create config")]
67 ConfigCreationFailure,
68
69 #[error("unable to fetch system time")]
70 NoSystemTime,
71
72 #[error("invalid expiration time")]
73 InvalidExpirationTime,
74
75 #[error("failed to refresh the authentication token")]
76 FailedToRefreshToken,
77
78 #[error("missing the code or state from response")]
79 MissingCodeOrState,
80}
81
82#[derive(Debug)]
86pub struct MalClientId(pub ClientId);
87
88impl MalClientId {
89 pub fn new<T: Into<String>>(id: T) -> Self {
93 let client_id = ClientId::new(id.into());
94 Self(client_id)
95 }
96
97 pub fn try_from_env() -> Result<Self, OauthError> {
99 let client_id = OauthClient::load_client_id_from_env()?;
100 Ok(Self(ClientId::new(client_id)))
101 }
102}
103
/// Type-state marker: the client has not yet exchanged an authorization code for tokens.
#[derive(Debug)]
pub struct Unauthenticated;

/// Type-state marker: the client holds an access token and a refresh token.
#[derive(Debug)]
pub struct Authenticated;
111
112#[derive(Debug)]
114pub struct OauthClient<State = Unauthenticated> {
115 client: BasicClient,
116 csrf: CsrfToken,
117 pkce_verifier: PkceCodeVerifier,
118 state: PhantomData<State>,
119 access_token: AccessToken,
120 refresh_token: RefreshToken,
121 expires_at: u64,
122}
123
124impl OauthClient<Unauthenticated> {
125 pub fn new<T: Into<String>>(
127 client_id: T,
128 client_secret: Option<T>,
129 redirect_url: T,
130 ) -> Result<Self, OauthError> {
131 let (client_id, redirect_url) = (client_id.into(), redirect_url.into());
132 let client_secret = client_secret.map(|c| c.into());
133
134 let client = Self::create_oauth2_client(client_id, client_secret, redirect_url)?;
135
136 Ok(Self {
137 client,
138 pkce_verifier: PkceCodeVerifier::new("".to_string()),
139 csrf: CsrfToken::new(String::from("")),
140 state: PhantomData::<Unauthenticated>,
141 access_token: AccessToken::new("".to_string()),
142 refresh_token: RefreshToken::new("".to_string()),
143 expires_at: Duration::new(0, 0).as_secs(),
144 })
145 }
146
147 fn create_oauth2_client(
148 client_id: String,
149 client_secret: Option<String>,
150 redirect_url: String,
151 ) -> Result<BasicClient, OauthError> {
152 match client_secret {
153 Some(c) => {
154 let client = BasicClient::new(
155 ClientId::new(client_id),
156 Some(ClientSecret::new(c.into())),
157 AuthUrl::new(OAUTH_URL.to_string()).unwrap(),
158 Some(TokenUrl::new(OAUTH_TOKEN_URL.to_string()).unwrap()),
159 )
160 .set_redirect_uri(
161 RedirectUrl::new(redirect_url).map_err(|_| OauthError::InvalidRedirectUrl)?,
162 )
163 .set_auth_type(oauth2::AuthType::BasicAuth);
164 Ok(client)
165 }
166 None => {
167 let client = BasicClient::new(
168 ClientId::new(client_id),
169 None,
170 AuthUrl::new(OAUTH_URL.to_string()).unwrap(),
171 Some(TokenUrl::new(OAUTH_TOKEN_URL.to_string()).unwrap()),
172 )
173 .set_redirect_uri(
174 RedirectUrl::new(redirect_url).map_err(|_| OauthError::InvalidRedirectUrl)?,
175 )
176 .set_auth_type(oauth2::AuthType::RequestBody);
177 Ok(client)
178 }
179 }
180 }
181
182 pub fn generate_auth_url(&mut self) -> String {
185 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_plain();
186
187 let (auth_url, csrf_token) = self
188 .client
189 .authorize_url(CsrfToken::new_random)
190 .set_pkce_challenge(pkce_challenge)
191 .url();
192
193 self.csrf = csrf_token;
194 self.pkce_verifier = pkce_verifier;
195
196 auth_url.to_string()
197 }
198
199 pub async fn authenticate(
202 self,
203 authorization_response: RedirectResponse,
204 ) -> Result<OauthClient<Authenticated>, OauthError> {
205 if authorization_response.state != *self.csrf.secret() {
206 return Err(OauthError::StateMismatch);
207 }
208
209 let code = AuthorizationCode::new(authorization_response.code);
210 let token_result = self
211 .client
212 .exchange_code(code)
213 .set_pkce_verifier(self.pkce_verifier)
214 .request_async(async_http_client)
215 .await
216 .map_err(|_| OauthError::BadTokenResponse)?;
217
218 let now = calculate_current_system_time()?;
219
220 Ok(OauthClient::<Authenticated> {
221 client: self.client,
222 csrf: self.csrf,
223 pkce_verifier: PkceCodeVerifier::new("".to_string()),
224 state: PhantomData::<Authenticated>,
225 access_token: token_result.access_token().to_owned(),
226 refresh_token: token_result
227 .refresh_token()
228 .ok_or_else(|| OauthError::MissingRefreshToken)?
229 .to_owned(),
230 expires_at: now
231 + token_result
232 .expires_in()
233 .unwrap_or(Duration::from_secs(EXPIRATION_IN_SECONDS))
234 .as_secs(),
235 })
236 }
237
238 fn load_from_env() -> Result<OauthClient<Authenticated>, OauthError> {
243 let (client_id, redirect_url) = (
244 Self::load_client_id_from_env()?,
245 Self::load_redirect_url_from_env()?,
246 );
247 let client_secret = Self::load_client_secret_from_env().ok();
248
249 let client = Self::create_oauth2_client(client_id, client_secret, redirect_url)?;
250
251 let access_token = Self::load_env_var("MAL_ACCESS_TOKEN")?;
252 let refresh_token = Self::load_env_var("MAL_REFRESH_TOKEN")?;
253 let expires_at = Self::load_env_var("MAL_TOKEN_EXPIRES_AT")
254 .map_err(|_| OauthError::MissingTokenExpiration)?
255 .parse::<u64>()
256 .map_err(|_| OauthError::InvalidExpirationTime)?;
257
258 Ok(OauthClient::<Authenticated> {
259 client,
260 csrf: CsrfToken::new(String::default()),
261 pkce_verifier: PkceCodeVerifier::new(String::default()),
262 state: PhantomData::<Authenticated>,
263 access_token: AccessToken::new(access_token),
264 refresh_token: RefreshToken::new(refresh_token),
265 expires_at,
266 })
267 }
268
269 pub fn load_from_config<T: Into<String>>(
274 path: T,
275 ) -> Result<OauthClient<Authenticated>, OauthError> {
276 let path: String = path.into();
277 let dir = env::current_dir().map_err(|_| OauthError::MissingConfig)?;
278 let path_to_config = dir.join(path);
279 if !Path::new(&path_to_config).exists() {
280 return Err(OauthError::MissingConfig);
281 }
282
283 let toml_content =
284 fs::read_to_string(&path_to_config).map_err(|_| OauthError::InvalidConfigFormat)?;
285 let parsed_toml: MalCredentialsConfig =
286 toml::from_str(&toml_content).map_err(|_| OauthError::InvalidConfigFormat)?;
287
288 env::set_var("MAL_ACCESS_TOKEN", parsed_toml.mal_access_token.to_string());
289 env::set_var(
290 "MAL_REFRESH_TOKEN",
291 parsed_toml.mal_refresh_token.to_string(),
292 );
293 env::set_var(
294 "MAL_TOKEN_EXPIRES_AT",
295 parsed_toml.mal_token_expires_at.to_string(),
296 );
297 Self::load_from_env()
298 }
299
300 pub fn load_from_values<T: Into<String>>(
308 access_token: T,
309 refresh_token: T,
310 client_id: T,
311 client_secret: Option<T>,
312 redirect_url: T,
313 expires_at: u64,
314 ) -> Result<OauthClient<Authenticated>, OauthError> {
315 let (access_token, refresh_token) = (access_token.into(), refresh_token.into());
316 let (client_id, client_secret, redirect_url) = (
317 client_id.into(),
318 client_secret.map(|c| c.into()),
319 redirect_url.into(),
320 );
321
322 let unix_epoch = SystemTime::UNIX_EPOCH
323 .duration_since(SystemTime::UNIX_EPOCH)
324 .map_err(|_| OauthError::NoSystemTime)?
325 .as_secs();
326
327 if expires_at < unix_epoch {
328 return Err(OauthError::InvalidExpirationTime);
329 }
330
331 let client = Self::create_oauth2_client(client_id, client_secret, redirect_url)?;
332
333 Ok(OauthClient::<Authenticated> {
334 client,
335 csrf: CsrfToken::new(String::default()),
336 pkce_verifier: PkceCodeVerifier::new(String::default()),
337 state: PhantomData::<Authenticated>,
338 access_token: AccessToken::new(access_token),
339 refresh_token: RefreshToken::new(refresh_token),
340 expires_at,
341 })
342 }
343
344 fn load_env_var(name: &str) -> Result<String, OauthError> {
345 let result = env::var(name).map_err(|_| OauthError::MissingEnvVar)?;
346 Ok(result)
347 }
348
349 pub fn load_client_id_from_env() -> Result<String, OauthError> {
351 let client_id =
352 Self::load_env_var("MAL_CLIENT_ID").map_err(|_| OauthError::MissingClientId)?;
353 Ok(client_id)
354 }
355
356 pub fn load_client_secret_from_env() -> Result<String, OauthError> {
358 let client_secret =
359 Self::load_env_var("MAL_CLIENT_SECRET").map_err(|_| OauthError::MissingClientSecret)?;
360 Ok(client_secret)
361 }
362
363 pub fn load_redirect_url_from_env() -> Result<String, OauthError> {
365 let redirect_url =
366 Self::load_env_var("MAL_REDIRECT_URL").map_err(|_| OauthError::MissingRedirectUrl)?;
367 Ok(redirect_url)
368 }
369}
370
371#[derive(Debug, Serialize, Deserialize)]
372struct MalCredentialsConfig {
373 mal_access_token: String,
374 mal_refresh_token: String,
375 mal_token_expires_at: u64,
376}
377
378impl OauthClient<Authenticated> {
379 pub(crate) fn get_access_token(&self) -> &AccessToken {
381 &self.access_token
382 }
383
384 pub fn get_access_token_secret(&self) -> &String {
386 &self.access_token.secret()
387 }
388
389 pub fn get_refresh_token_secret(&self) -> &String {
391 &self.refresh_token.secret()
392 }
393
394 pub fn get_expires_at(&self) -> &u64 {
398 &self.expires_at
399 }
400
401 pub fn save_to_config<T: Into<String>>(&self, path: T) -> Result<(), OauthError> {
406 let path: String = path.into();
407 let dir = env::current_dir().map_err(|_| OauthError::MissingConfig)?;
408 let path_to_config = dir.join(path);
409
410 let config = MalCredentialsConfig {
411 mal_access_token: self.access_token.secret().clone(),
412 mal_refresh_token: self.refresh_token.secret().clone(),
413 mal_token_expires_at: *self.get_expires_at(),
414 };
415 let toml = toml::to_string(&config).map_err(|_| OauthError::InvalidConfigFormat)?;
416
417 if let Some(parent_dir) = Path::new(&path_to_config).parent() {
418 fs::create_dir_all(parent_dir).map_err(|_| OauthError::ConfigCreationFailure)?;
419 }
420
421 fs::write(&path_to_config, toml).map_err(|_| OauthError::ConfigCreationFailure)?;
422 Ok(())
423 }
424
425 pub async fn refresh(self) -> Result<Self, OauthError> {
427 let refresh_result = self
428 .client
429 .exchange_refresh_token(&self.refresh_token)
430 .request_async(async_http_client)
431 .await
432 .map_err(|_| OauthError::FailedToRefreshToken)?;
433
434 let now = calculate_current_system_time()?;
435
436 Ok(OauthClient::<Authenticated> {
437 client: self.client,
438 csrf: self.csrf,
439 pkce_verifier: PkceCodeVerifier::new("".to_string()),
440 state: PhantomData::<Authenticated>,
441 access_token: refresh_result.access_token().to_owned(),
442 refresh_token: refresh_result.refresh_token().unwrap().to_owned(),
443 expires_at: now
444 + refresh_result
445 .expires_in()
446 .unwrap_or(Duration::from_secs(EXPIRATION_IN_SECONDS))
447 .as_secs(),
448 })
449 }
450}
451
452#[derive(Debug, Deserialize)]
453pub struct RedirectResponse {
454 code: String,
455 state: String,
456}
457
458impl RedirectResponse {
459 pub fn new<T: Into<String>>(code: T, state: T) -> Self {
461 let code = code.into();
462 let state = state.into();
463 Self { code, state }
464 }
465
466 pub fn try_from_uri(uri: &Uri) -> Result<RedirectResponse, OauthError> {
471 let query_params: Option<Self> = uri.query().map(|query| {
472 serde_urlencoded::from_str(query).expect("Failed to get code and state from response.")
473 });
474
475 match query_params {
476 Some(q) => Ok(q),
477 None => Err(OauthError::InvalidRedirectResponse),
478 }
479 }
480}
481
482impl TryFrom<String> for RedirectResponse {
483 type Error = OauthError;
484
485 fn try_from(value: String) -> Result<Self, Self::Error> {
486 let query_string = value
487 .parse::<Url>()
488 .map_err(|_| OauthError::InvalidRedirectResponse)?;
489
490 let query_params = query_string
491 .query()
492 .ok_or_else(|| OauthError::MissingCodeOrState)?;
493
494 serde_urlencoded::from_str::<RedirectResponse>(&query_params)
495 .map_err(|_| OauthError::MissingCodeOrState)
496 }
497}
498
499fn calculate_current_system_time() -> Result<u64, OauthError> {
500 let now = SystemTime::UNIX_EPOCH
501 .duration_since(SystemTime::UNIX_EPOCH)
502 .map_err(|_| OauthError::NoSystemTime)?
503 .as_secs();
504 Ok(now)
505}