1pub mod redirect;
3
4pub mod token;
6
7pub mod cache;
9
10use crate::config::oauth_config::AuthConfig;
11use color_eyre::Result;
12use rand::{distr::Alphanumeric, rng, Rng};
13use serde::{Deserialize, Serialize};
14use serde_json;
15use serde_urlencoded;
16use std::{io::Error, iter, str::FromStr}; use token::{Token, TokenWrapper};
18use url::Url;
19
/// User-Agent header sent with every HTTP request this module makes.
const USER_AGENT: &str = "mal-cli";
/// MyAnimeList OAuth2 authorization endpoint; the generated auth URL points here.
const AUTHORIZE_URL: &str = "https://myanimelist.net/v1/oauth2/authorize";
/// MyAnimeList OAuth2 token endpoint, used for both the authorization-code
/// exchange and the refresh-token grant.
const TOKEN_URL: &str = "https://myanimelist.net/v1/oauth2/token";
23
/// Errors produced while obtaining or refreshing OAuth credentials.
#[derive(Clone, Debug)]
pub enum AuthError {
    /// Failure without a more specific classification: a non-timeout
    /// `reqwest` error, or a non-success status from the token endpoint.
    UnknownError,
    /// An HTTP request timed out.
    NetworkTimeout,
    /// A response or redirect query could not be parsed or failed validation;
    /// the payload is a human-readable description of the problem.
    InvalidResponse(String),
    /// No authorization code has been captured yet.
    AuthNotPresent,
    /// No token is available.
    TokenNotPresent,
}
32
33impl From<reqwest::Error> for AuthError {
34 fn from(e: reqwest::Error) -> Self {
35 if e.is_timeout() {
36 AuthError::NetworkTimeout
37 } else {
38 AuthError::UnknownError
39 }
40 }
41}
42
43impl std::error::Error for AuthError {
44 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
45 match *self {
46 AuthError::UnknownError => None,
47 AuthError::NetworkTimeout => None,
48 AuthError::InvalidResponse(_) => None,
49 AuthError::AuthNotPresent => None,
50 AuthError::TokenNotPresent => None,
51 }
52 }
53}
54
55impl std::fmt::Display for AuthError {
56 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
57 match *self {
58 AuthError::UnknownError => write!(f, "Unknown Error"),
59 AuthError::NetworkTimeout => write!(f, "Network Timeout"),
60 AuthError::InvalidResponse(ref err) => err.fmt(f),
61 AuthError::AuthNotPresent => write!(f, "Auth is not present"),
62 AuthError::TokenNotPresent => write!(f, "Token is not present"),
63 }
64 }
65}
66
/// Length of the random code challenge generated for each new session
/// (`OAuth::new_challenge` only accepts lengths in `48..=128`).
const CODE_CHALLENGE_LENGTH: usize = 128;
68
/// OAuth2 session state for the MyAnimeList API (authorization-code flow with
/// a `plain` code challenge).
///
/// Serializable so that an authenticated session can be cached and reloaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth {
    /// Application client id, sent to both the authorize and token endpoints.
    pub client_id: String,
    /// Optional client secret; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,
    /// Redirect URL placed in the authorization URL's query string.
    pub redirect_url: String,
    /// Stored user agent. Note: HTTP requests in this module send the
    /// `USER_AGENT` constant, not this field.
    pub user_agent: String,
    /// Random code challenge; because the method is `plain`, the same value is
    /// also sent as the `code_verifier` to the token endpoint.
    pub challenge: String,
    /// `state` value sent with the authorization request and compared against
    /// the redirect's `state`.
    pub state: String,
    /// Authorization code captured from the redirect, once available.
    pub auth_code: Option<String>,
    /// Token obtained from the token endpoint, once available.
    pub token: Option<TokenWrapper>,
}
81
82impl OAuth {
83 pub fn new<A: ToString>(
87 user_agent: A,
88 client_id: A,
89 client_secret: Option<A>,
90 redirect_url: A,
91 ) -> Self {
92 OAuth {
93 client_id: client_id.to_string(),
94 client_secret: client_secret.map(|cs| cs.to_string()),
95 redirect_url: redirect_url.to_string(),
96 user_agent: user_agent.to_string(),
97 challenge: Self::new_challenge(CODE_CHALLENGE_LENGTH),
98 state: "AUTHSTART".to_string(),
99 auth_code: None,
100 token: None,
101 }
102 }
103
104 fn new_challenge(len: usize) -> String {
108 if !(48..=128).contains(&len) {
111 panic!("len is not in between 48 and 128");
112 }
113 let mut rng = rng();
114 let challenge: String = iter::repeat(())
116 .map(|()| rng.sample(Alphanumeric) as char)
117 .take(len)
118 .collect();
119 challenge
120 }
121
122 pub fn user_agent(&self) -> &String {
124 &self.user_agent
125 }
126
127 pub fn get_auth_url(&self) -> Url {
129 #[derive(Serialize, Debug)]
130 struct AuthQuery {
131 response_type: String,
132 client_id: String,
133 code_challenge: String,
134 state: String,
135 redirect_url: String,
136 code_challenge_method: String,
137 }
138
139 let auth_query = AuthQuery {
140 response_type: "code".to_string(),
141 client_id: self.client_id.clone(),
142 code_challenge: self.challenge.clone(),
143 state: self.state.to_string(),
144 redirect_url: self.redirect_url.clone(),
145 code_challenge_method: "plain".to_string(),
147 };
148
149 url::Url::from_str(&format!(
150 "{}?{}",
151 AUTHORIZE_URL,
152 serde_urlencoded::to_string(auth_query).unwrap()
153 ))
154 .unwrap()
155 }
156
157 pub fn parse_redirect_query_string(&mut self, query_string: &str) -> Result<(), AuthError> {
159 #[derive(Deserialize, Debug)]
160 struct AuthResponse {
161 code: String,
162 state: String,
163 }
164
165 let auth_response = match serde_urlencoded::from_str::<AuthResponse>(query_string) {
166 Ok(r) => r,
167 Err(e) => {
168 return Err(AuthError::InvalidResponse(e.to_string()));
169 }
170 };
171
172 if auth_response.state != self.state {
173 return Err(AuthError::InvalidResponse("State Mismatch".to_string()));
174 }
175
176 self.auth_code = Some(auth_response.code);
177 Ok(())
178 }
179
180 pub fn get_token_query_string(&self) -> Result<String, AuthError> {
182 #[derive(Serialize, Debug)]
183 struct TokenRequest {
184 client_id: String,
185 #[serde(skip_serializing_if = "Option::is_none")]
186 client_secret: Option<String>,
187 code: String,
188 code_verifier: String,
189 grant_type: String,
190 }
191
192 if self.auth_code.is_none() {
193 return Err(AuthError::AuthNotPresent);
194 }
195
196 let query = TokenRequest {
197 client_id: self.client_id.clone(),
198 client_secret: self.client_secret.clone(),
199 code: self.auth_code.as_ref().unwrap().clone(),
200 code_verifier: self.challenge.clone(),
201 grant_type: "authorization_code".to_string(),
202 };
203
204 Ok(serde_urlencoded::to_string(query).unwrap())
205 }
206
207 pub fn get_access_token(&mut self) -> Result<(), AuthError> {
209 let request = reqwest::blocking::ClientBuilder::new()
210 .user_agent(USER_AGENT)
211 .build()?
212 .post(TOKEN_URL)
213 .header(reqwest::header::ACCEPT, "application/json")
214 .header(
215 reqwest::header::CONTENT_TYPE,
216 "application/x-www-form-urlencoded",
217 )
218 .body(self.get_token_query_string()?);
219
220 let response = request.send()?;
221 let success = response.status().is_success();
222 let body = response.text()?;
223 self.handle_response(success, &body)
224 }
225
226 pub async fn get_access_token_async(&mut self) -> Result<(), AuthError> {
228 let request = reqwest::ClientBuilder::new()
229 .user_agent(USER_AGENT)
230 .build()?
231 .post(TOKEN_URL)
232 .header(reqwest::header::ACCEPT, "application/json")
233 .header(
234 reqwest::header::CONTENT_TYPE,
235 "application/x-www-form-urlencoded",
236 )
237 .body(self.get_token_query_string()?);
238
239 let response = request.send().await?;
240 let success = response.status().is_success();
241 let body = response.text().await?;
242 self.handle_response(success, &body)
243 }
244
245 pub fn handle_response(&mut self, success: bool, body: &str) -> Result<(), AuthError> {
247 if success {
248 match serde_json::from_str::<Token>(body) {
249 Ok(result) => {
250 self.token = Some(TokenWrapper::new(result));
251 Ok(())
252 }
253 Err(e) => Err(AuthError::InvalidResponse(e.to_string())),
254 }
255 } else {
256 println!("{}", body);
257 Err(AuthError::UnknownError)
258 }
259 }
260
261 pub fn token(&self) -> Option<&TokenWrapper> {
263 self.token.as_ref()
264 }
265
266 pub fn get_token_refresh_query_string(&self) -> Result<String, AuthError> {
267 #[derive(Serialize, Debug)]
268 struct TokenRequest {
269 client_id: String,
270 #[serde(skip_serializing_if = "Option::is_none")]
271 client_secret: Option<String>,
272 code: String,
273 code_verifier: String,
274 grant_type: String,
275 refresh_token: String,
276 }
277
278 if self.auth_code.is_none() {
279 return Err(AuthError::AuthNotPresent);
280 }
281 if self.token.is_none() {
282 return Err(AuthError::TokenNotPresent);
283 }
284
285 let query = TokenRequest {
286 client_id: self.client_id.clone(),
287 client_secret: self.client_secret.clone(),
288 code: self.auth_code.as_ref().unwrap().clone(),
289 code_verifier: self.challenge.clone(),
290 grant_type: "refresh_token".to_string(),
291 refresh_token: self.token().unwrap().token.refresh_token.clone(),
292 };
293
294 Ok(serde_urlencoded::to_string(query).unwrap())
295 }
296
297 pub fn refresh(&mut self) -> Result<(), AuthError> {
299 if self.token().unwrap().expired() {
300 let request = reqwest::blocking::ClientBuilder::new()
301 .user_agent(USER_AGENT)
302 .build()?
303 .post(TOKEN_URL)
304 .header(reqwest::header::ACCEPT, "application/json")
305 .header(
306 reqwest::header::CONTENT_TYPE,
307 "application/x-www-form-urlencoded",
308 )
309 .body(self.get_token_refresh_query_string()?);
310
311 let response = request.send()?;
312 let success = response.status().is_success();
313 let body = response.text()?;
314 self.handle_response(success, &body)
315 } else {
316 Ok(())
317 }
318 }
319
320 pub async fn refresh_async(&mut self) -> Result<(), AuthError> {
322 if self.token().unwrap().expired() {
323 let request = reqwest::ClientBuilder::new()
324 .user_agent(USER_AGENT)
325 .build()?
326 .post(TOKEN_URL)
327 .header(reqwest::header::ACCEPT, "application/json")
328 .header(
329 reqwest::header::CONTENT_TYPE,
330 "application/x-www-form-urlencoded",
331 )
332 .body(self.get_token_refresh_query_string()?);
333
334 let response = request.send().await?;
335 let success = response.status().is_success();
336 let body = response.text().await?;
337 self.handle_response(success, &body)
338 } else {
339 Ok(())
340 }
341 }
342
343 pub async fn get_auth_async(config: AuthConfig) -> Result<OAuth, AuthError> {
344 if let Some(mut auth) = cache::load_cached_auth() {
345 auth.refresh_async().await?;
346 Ok(auth)
347 } else {
348 let auth = OAuth::new(
349 config.get_user_agent(),
350 config.client_id.clone(),
351 None,
352 config.get_redirect_uri(),
353 );
354
355 let url = auth.get_auth_url();
356
357 if test_oauth_url(&url).await {
358 open(&url).unwrap();
359 } else {
360 println!("==> Please verify your creds and retry.");
361 println!("==> Note: cached auth file will be deleted.");
362 cache::delete_cached_auth();
364 return Err(AuthError::InvalidResponse("Failed to open URL".to_string()));
366 }
367 let mut auth = redirect::Server::new(config.get_user_agent(), auth)
368 .go()
369 .unwrap();
370
371 auth.get_access_token_async().await.unwrap();
372
373 cache::cache_auth(&auth);
374
375 Ok(auth)
376 }
377 }
378
379 pub fn get_auth(config: AuthConfig) -> Result<OAuth, AuthError> {
381 if let Some(mut auth) = cache::load_cached_auth() {
382 auth.refresh()?;
383 Ok(auth)
384 } else {
385 let auth = OAuth::new(
386 config.get_user_agent(),
387 config.client_id.clone(),
388 None,
389 config.get_redirect_uri(),
390 );
391
392 let url = auth.get_auth_url();
393 open(&url).unwrap();
394
395 let mut auth = redirect::Server::new(config.get_user_agent(), auth)
396 .go()
397 .unwrap();
398
399 auth.get_access_token().unwrap();
400
401 cache::cache_auth(&auth);
402
403 Ok(auth)
404 }
405 }
406}
407
408pub async fn test_oauth_url(url: &Url) -> bool {
409 let res = reqwest::ClientBuilder::new()
410 .user_agent(USER_AGENT)
411 .build()
412 .unwrap()
413 .get(url.as_ref())
414 .send()
415 .await;
416
417 match res {
418 Ok(response) => response.status().is_success(),
419 Err(_) => false,
420 }
421}
422pub fn open(url: &Url) -> Result<(), Error> {
424 webbrowser::open(url.as_ref())
425}
426
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Loads the app config and returns an authenticated session.
    ///
    /// Uses the cached auth when present; otherwise it runs the interactive
    /// browser flow.
    pub fn get_auth() -> OAuth {
        let config = AuthConfig::load().unwrap();
        OAuth::get_auth(config).unwrap()
    }

    // Needs network access and valid credentials (and possibly a browser
    // login), so it is excluded from a plain `cargo test`; run with `--ignored`.
    #[test]
    #[ignore = "requires network access and MyAnimeList credentials"]
    fn test_refresh_token() {
        let mut auth = get_auth();
        auth.refresh().unwrap();
        println!("{}", serde_json::to_string(&auth).unwrap());
    }

    // Interactive: opens a browser and blocks until the OAuth redirect arrives.
    #[test]
    #[ignore = "interactive: opens a browser and waits for the OAuth redirect"]
    fn test_get_auth() {
        let config = AuthConfig::load().unwrap();

        let auth = OAuth::new(
            config.get_user_agent(),
            config.client_id.clone(),
            None,
            config.get_redirect_uri(),
        );

        println!("{}", auth.redirect_url);

        let url = auth.get_auth_url();
        open(&url).unwrap();

        let mut auth = redirect::Server::new(config.get_user_agent(), auth)
            .go()
            .unwrap();

        auth.get_access_token().unwrap();
        println!("{}", serde_json::to_string(&auth).unwrap());

        auth.refresh().unwrap();
        println!("{}", serde_json::to_string(&auth).unwrap());

        cache::cache_auth(&auth);
    }

    #[test]
    fn test_challenge() {
        let challenge = OAuth::new_challenge(CODE_CHALLENGE_LENGTH);

        assert_eq!(challenge.len(), CODE_CHALLENGE_LENGTH);
        assert!(challenge.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_challenge_bounds_accepted() {
        assert_eq!(OAuth::new_challenge(48).len(), 48);
        assert_eq!(OAuth::new_challenge(128).len(), 128);
    }

    #[test]
    #[should_panic(expected = "len is not in between 48 and 128")]
    fn test_challenge_len() {
        let _challenge = OAuth::new_challenge(5);
    }

    #[test]
    #[should_panic(expected = "len is not in between 48 and 128")]
    fn test_challenge_len_above_max() {
        let _challenge = OAuth::new_challenge(129);
    }
}