mal/auth/
mod.rs

1/// structs and methods for oauth2 authentication flow
2pub mod redirect;
3
4/// structs and methods for token management
5pub mod token;
6
7/// methods for cache
8pub mod cache;
9
10use crate::config::oauth_config::AuthConfig;
11use color_eyre::Result;
12use rand::{distr::Alphanumeric, rng, Rng};
13use serde::{Deserialize, Serialize};
14use serde_json;
15use serde_urlencoded;
16use std::{io::Error, iter, str::FromStr}; // process::Output
17use token::{Token, TokenWrapper};
18use url::Url;
19
/// User agent used for the HTTP requests this module sends to MyAnimeList.
const USER_AGENT: &str = "mal-cli";
/// MyAnimeList OAuth2 authorization endpoint, opened in the user's browser.
const AUTHORIZE_URL: &str = "https://myanimelist.net/v1/oauth2/authorize";
/// MyAnimeList OAuth2 token endpoint, used for the code exchange and for refreshes.
const TOKEN_URL: &str = "https://myanimelist.net/v1/oauth2/token";
23
/// Errors produced while authenticating against MyAnimeList.
#[derive(Debug, Clone)]
pub enum AuthError {
    /// A failure with no more specific classification.
    UnknownError,
    /// An HTTP request timed out.
    NetworkTimeout,
    /// A response (or redirect query) could not be used; carries a description.
    InvalidResponse(String),
    /// No authorization code is available yet.
    AuthNotPresent,
    /// No token is available yet.
    TokenNotPresent,
}
32
33impl From<reqwest::Error> for AuthError {
34    fn from(e: reqwest::Error) -> Self {
35        if e.is_timeout() {
36            AuthError::NetworkTimeout
37        } else {
38            AuthError::UnknownError
39        }
40    }
41}
42
43impl std::error::Error for AuthError {
44    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
45        match *self {
46            AuthError::UnknownError => None,
47            AuthError::NetworkTimeout => None,
48            AuthError::InvalidResponse(_) => None,
49            AuthError::AuthNotPresent => None,
50            AuthError::TokenNotPresent => None,
51        }
52    }
53}
54
55impl std::fmt::Display for AuthError {
56    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
57        match *self {
58            AuthError::UnknownError => write!(f, "Unknown Error"),
59            AuthError::NetworkTimeout => write!(f, "Network Timeout"),
60            AuthError::InvalidResponse(ref err) => err.fmt(f),
61            AuthError::AuthNotPresent => write!(f, "Auth is not present"),
62            AuthError::TokenNotPresent => write!(f, "Token is not present"),
63        }
64    }
65}
66
/// Number of characters in the PKCE code verifier generated for each new flow
/// (128 is the maximum RFC 7636 allows).
const CODE_CHALLENGE_LENGTH: usize = 128;
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct OAuth {
71    pub client_id: String,
72    #[serde(skip_serializing_if = "Option::is_none")]
73    pub client_secret: Option<String>,
74    pub redirect_url: String,
75    pub user_agent: String,
76    pub challenge: String,
77    pub state: String,
78    pub auth_code: Option<String>,
79    pub token: Option<TokenWrapper>,
80}
81
82impl OAuth {
83    /// Start of a new oauth2 flow
84    /// # Parameters
85    /// * `user`
86    pub fn new<A: ToString>(
87        user_agent: A,
88        client_id: A,
89        client_secret: Option<A>,
90        redirect_url: A,
91    ) -> Self {
92        OAuth {
93            client_id: client_id.to_string(),
94            client_secret: client_secret.map(|cs| cs.to_string()),
95            redirect_url: redirect_url.to_string(),
96            user_agent: user_agent.to_string(),
97            challenge: Self::new_challenge(CODE_CHALLENGE_LENGTH),
98            state: "AUTHSTART".to_string(),
99            auth_code: None,
100            token: None,
101        }
102    }
103
104    /// Generates a new base64-encoded SHA-256 PKCE code
105    /// # Panic
106    /// `len` needs to be a value between 48 and 128
107    fn new_challenge(len: usize) -> String {
108        // Check whether the len in in between the valid length for a
109        // PKCE code (43 chars - 128 chars)
110        if !(48..=128).contains(&len) {
111            panic!("len is not in between 48 and 128");
112        }
113        let mut rng = rng();
114        // needs to be url safe so we use Alphanumeric
115        let challenge: String = iter::repeat(())
116            .map(|()| rng.sample(Alphanumeric) as char)
117            .take(len)
118            .collect();
119        challenge
120    }
121
122    /// Returns user agent
123    pub fn user_agent(&self) -> &String {
124        &self.user_agent
125    }
126
127    /// Creates a new authorization url
128    pub fn get_auth_url(&self) -> Url {
129        #[derive(Serialize, Debug)]
130        struct AuthQuery {
131            response_type: String,
132            client_id: String,
133            code_challenge: String,
134            state: String,
135            redirect_url: String,
136            code_challenge_method: String,
137        }
138
139        let auth_query = AuthQuery {
140            response_type: "code".to_string(),
141            client_id: self.client_id.clone(),
142            code_challenge: self.challenge.clone(),
143            state: self.state.to_string(),
144            redirect_url: self.redirect_url.clone(),
145            // mal only supports plain
146            code_challenge_method: "plain".to_string(),
147        };
148
149        url::Url::from_str(&format!(
150            "{}?{}",
151            AUTHORIZE_URL,
152            serde_urlencoded::to_string(auth_query).unwrap()
153        ))
154        .unwrap()
155    }
156
157    /// Parses redirection url
158    pub fn parse_redirect_query_string(&mut self, query_string: &str) -> Result<(), AuthError> {
159        #[derive(Deserialize, Debug)]
160        struct AuthResponse {
161            code: String,
162            state: String,
163        }
164
165        let auth_response = match serde_urlencoded::from_str::<AuthResponse>(query_string) {
166            Ok(r) => r,
167            Err(e) => {
168                return Err(AuthError::InvalidResponse(e.to_string()));
169            }
170        };
171
172        if auth_response.state != self.state {
173            return Err(AuthError::InvalidResponse("State Mismatch".to_string()));
174        }
175
176        self.auth_code = Some(auth_response.code);
177        Ok(())
178    }
179
180    /// Creates a new url to get the token
181    pub fn get_token_query_string(&self) -> Result<String, AuthError> {
182        #[derive(Serialize, Debug)]
183        struct TokenRequest {
184            client_id: String,
185            #[serde(skip_serializing_if = "Option::is_none")]
186            client_secret: Option<String>,
187            code: String,
188            code_verifier: String,
189            grant_type: String,
190        }
191
192        if self.auth_code.is_none() {
193            return Err(AuthError::AuthNotPresent);
194        }
195
196        let query = TokenRequest {
197            client_id: self.client_id.clone(),
198            client_secret: self.client_secret.clone(),
199            code: self.auth_code.as_ref().unwrap().clone(),
200            code_verifier: self.challenge.clone(),
201            grant_type: "authorization_code".to_string(),
202        };
203
204        Ok(serde_urlencoded::to_string(query).unwrap())
205    }
206
207    /// Get access token
208    pub fn get_access_token(&mut self) -> Result<(), AuthError> {
209        let request = reqwest::blocking::ClientBuilder::new()
210            .user_agent(USER_AGENT)
211            .build()?
212            .post(TOKEN_URL)
213            .header(reqwest::header::ACCEPT, "application/json")
214            .header(
215                reqwest::header::CONTENT_TYPE,
216                "application/x-www-form-urlencoded",
217            )
218            .body(self.get_token_query_string()?);
219
220        let response = request.send()?;
221        let success = response.status().is_success();
222        let body = response.text()?;
223        self.handle_response(success, &body)
224    }
225
226    /// Refresh the token (async)
227    pub async fn get_access_token_async(&mut self) -> Result<(), AuthError> {
228        let request = reqwest::ClientBuilder::new()
229            .user_agent(USER_AGENT)
230            .build()?
231            .post(TOKEN_URL)
232            .header(reqwest::header::ACCEPT, "application/json")
233            .header(
234                reqwest::header::CONTENT_TYPE,
235                "application/x-www-form-urlencoded",
236            )
237            .body(self.get_token_query_string()?);
238
239        let response = request.send().await?;
240        let success = response.status().is_success();
241        let body = response.text().await?;
242        self.handle_response(success, &body)
243    }
244
245    /// Handle a repsonse for get_access_token()
246    pub fn handle_response(&mut self, success: bool, body: &str) -> Result<(), AuthError> {
247        if success {
248            match serde_json::from_str::<Token>(body) {
249                Ok(result) => {
250                    self.token = Some(TokenWrapper::new(result));
251                    Ok(())
252                }
253                Err(e) => Err(AuthError::InvalidResponse(e.to_string())),
254            }
255        } else {
256            println!("{}", body);
257            Err(AuthError::UnknownError)
258        }
259    }
260
261    /// Get a token reference
262    pub fn token(&self) -> Option<&TokenWrapper> {
263        self.token.as_ref()
264    }
265
266    pub fn get_token_refresh_query_string(&self) -> Result<String, AuthError> {
267        #[derive(Serialize, Debug)]
268        struct TokenRequest {
269            client_id: String,
270            #[serde(skip_serializing_if = "Option::is_none")]
271            client_secret: Option<String>,
272            code: String,
273            code_verifier: String,
274            grant_type: String,
275            refresh_token: String,
276        }
277
278        if self.auth_code.is_none() {
279            return Err(AuthError::AuthNotPresent);
280        }
281        if self.token.is_none() {
282            return Err(AuthError::TokenNotPresent);
283        }
284
285        let query = TokenRequest {
286            client_id: self.client_id.clone(),
287            client_secret: self.client_secret.clone(),
288            code: self.auth_code.as_ref().unwrap().clone(),
289            code_verifier: self.challenge.clone(),
290            grant_type: "refresh_token".to_string(),
291            refresh_token: self.token().unwrap().token.refresh_token.clone(),
292        };
293
294        Ok(serde_urlencoded::to_string(query).unwrap())
295    }
296
297    /// Refresh the token
298    pub fn refresh(&mut self) -> Result<(), AuthError> {
299        if self.token().unwrap().expired() {
300            let request = reqwest::blocking::ClientBuilder::new()
301                .user_agent(USER_AGENT)
302                .build()?
303                .post(TOKEN_URL)
304                .header(reqwest::header::ACCEPT, "application/json")
305                .header(
306                    reqwest::header::CONTENT_TYPE,
307                    "application/x-www-form-urlencoded",
308                )
309                .body(self.get_token_refresh_query_string()?);
310
311            let response = request.send()?;
312            let success = response.status().is_success();
313            let body = response.text()?;
314            self.handle_response(success, &body)
315        } else {
316            Ok(())
317        }
318    }
319
320    /// Refresh the token (async)
321    pub async fn refresh_async(&mut self) -> Result<(), AuthError> {
322        if self.token().unwrap().expired() {
323            let request = reqwest::ClientBuilder::new()
324                .user_agent(USER_AGENT)
325                .build()?
326                .post(TOKEN_URL)
327                .header(reqwest::header::ACCEPT, "application/json")
328                .header(
329                    reqwest::header::CONTENT_TYPE,
330                    "application/x-www-form-urlencoded",
331                )
332                .body(self.get_token_refresh_query_string()?);
333
334            let response = request.send().await?;
335            let success = response.status().is_success();
336            let body = response.text().await?;
337            self.handle_response(success, &body)
338        } else {
339            Ok(())
340        }
341    }
342
343    pub async fn get_auth_async(config: AuthConfig) -> Result<OAuth, AuthError> {
344        if let Some(mut auth) = cache::load_cached_auth() {
345            auth.refresh_async().await?;
346            Ok(auth)
347        } else {
348            let auth = OAuth::new(
349                config.get_user_agent(),
350                config.client_id.clone(),
351                None,
352                config.get_redirect_uri(),
353            );
354
355            let url = auth.get_auth_url();
356
357            if test_oauth_url(&url).await {
358                open(&url).unwrap();
359            } else {
360                println!("==> Please verify your creds and retry.");
361                println!("==> Note: cached auth file will be deleted.");
362                // delete oauth cache file
363                cache::delete_cached_auth();
364                // If the URL cannot be opened, return an error
365                return Err(AuthError::InvalidResponse("Failed to open URL".to_string()));
366            }
367            let mut auth = redirect::Server::new(config.get_user_agent(), auth)
368                .go()
369                .unwrap();
370
371            auth.get_access_token_async().await.unwrap();
372
373            cache::cache_auth(&auth);
374
375            Ok(auth)
376        }
377    }
378
379    // for tests
380    pub fn get_auth(config: AuthConfig) -> Result<OAuth, AuthError> {
381        if let Some(mut auth) = cache::load_cached_auth() {
382            auth.refresh()?;
383            Ok(auth)
384        } else {
385            let auth = OAuth::new(
386                config.get_user_agent(),
387                config.client_id.clone(),
388                None,
389                config.get_redirect_uri(),
390            );
391
392            let url = auth.get_auth_url();
393            open(&url).unwrap();
394
395            let mut auth = redirect::Server::new(config.get_user_agent(), auth)
396                .go()
397                .unwrap();
398
399            auth.get_access_token().unwrap();
400
401            cache::cache_auth(&auth);
402
403            Ok(auth)
404        }
405    }
406}
407
408pub async fn test_oauth_url(url: &Url) -> bool {
409    let res = reqwest::ClientBuilder::new()
410        .user_agent(USER_AGENT)
411        .build()
412        .unwrap()
413        .get(url.as_ref())
414        .send()
415        .await;
416
417    match res {
418        Ok(response) => response.status().is_success(),
419        Err(_) => false,
420    }
421}
422/// use webbrowser crate to open url in browser
423pub fn open(url: &Url) -> Result<(), Error> {
424    webbrowser::open(url.as_ref())
425}
426
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Loads the auth config and returns a cached or freshly authorized flow.
    ///
    /// Needs a config file and network access. Without a cached auth it also
    /// opens a browser and waits for the OAuth redirect. Panics on any failure.
    pub fn get_auth() -> OAuth {
        let config = AuthConfig::load().unwrap();
        OAuth::get_auth(config).unwrap()
    }

    /// Builds a flow with dummy credentials; no network is involved.
    fn offline_auth() -> OAuth {
        OAuth::new("test-agent", "test-client", None, "http://localhost:2525")
    }

    #[test]
    #[ignore = "needs a config file, network access and possibly a browser login"]
    fn test_refresh_token() {
        let mut auth = get_auth();
        auth.refresh().unwrap();
        println!("{}", serde_json::to_string(&auth).unwrap());
    }

    #[test]
    #[ignore = "interactive: opens a browser and waits for the OAuth redirect"]
    fn test_get_auth() {
        // Get config from file
        let config = AuthConfig::load().unwrap();

        // make auth
        let auth = OAuth::new(
            config.get_user_agent(),
            config.client_id.clone(),
            None,
            config.get_redirect_uri(),
        );

        println!("{}", auth.redirect_url);

        // create and open url
        let url = auth.get_auth_url();
        open(&url).unwrap();

        // wait for redirect
        let mut auth = redirect::Server::new(config.get_user_agent(), auth)
            .go()
            .unwrap();

        // get access token
        auth.get_access_token().unwrap();
        println!("{}", serde_json::to_string(&auth).unwrap());

        // get refresh token
        auth.refresh().unwrap();
        println!("{}", serde_json::to_string(&auth).unwrap());

        cache::cache_auth(&auth);
    }

    #[test]
    fn test_challenge() {
        let challenge = OAuth::new_challenge(CODE_CHALLENGE_LENGTH);

        assert_eq!(challenge.len(), CODE_CHALLENGE_LENGTH);
        // The verifier is sent without encoding, so it must stay URL-safe.
        assert!(challenge.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[test]
    #[should_panic(expected = "len is not in between 48 and 128")]
    fn test_challenge_len() {
        // should panic
        let _challenge = OAuth::new_challenge(5);
    }

    #[test]
    fn test_auth_url_contains_pkce_params() {
        let auth = offline_auth();
        let url = auth.get_auth_url();
        assert!(url.as_str().starts_with(AUTHORIZE_URL));

        let pairs: Vec<(String, String)> = url.query_pairs().into_owned().collect();
        let has = |k: &str, v: &str| pairs.contains(&(k.to_string(), v.to_string()));
        assert!(has("response_type", "code"));
        assert!(has("client_id", "test-client"));
        assert!(has("code_challenge", &auth.challenge));
        assert!(has("code_challenge_method", "plain"));
        assert!(has("state", &auth.state));
    }

    #[test]
    fn test_redirect_sets_auth_code() {
        let mut auth = offline_auth();
        let query = format!("code=abc&state={}", auth.state);
        auth.parse_redirect_query_string(&query).unwrap();
        assert_eq!(auth.auth_code.as_deref(), Some("abc"));
    }

    #[test]
    fn test_redirect_state_mismatch_is_rejected() {
        let mut auth = offline_auth();
        let err = auth
            .parse_redirect_query_string("code=abc&state=WRONG")
            .unwrap_err();
        assert!(matches!(err, AuthError::InvalidResponse(_)));
        assert!(auth.auth_code.is_none());
    }

    #[test]
    fn test_token_query_requires_auth_code() {
        let auth = offline_auth();
        assert!(matches!(
            auth.get_token_query_string(),
            Err(AuthError::AuthNotPresent)
        ));
    }

    #[test]
    fn test_token_query_contains_code_and_verifier() {
        let mut auth = offline_auth();
        auth.auth_code = Some("abc".to_string());
        let query = auth.get_token_query_string().unwrap();
        assert!(query.contains("grant_type=authorization_code"));
        assert!(query.contains("code=abc"));
        assert!(query.contains(&format!("code_verifier={}", auth.challenge)));
        // No client secret was configured, so none must be sent.
        assert!(!query.contains("client_secret"));
    }

    #[test]
    fn test_refresh_query_requires_token() {
        let mut auth = offline_auth();
        auth.auth_code = Some("abc".to_string());
        assert!(matches!(
            auth.get_token_refresh_query_string(),
            Err(AuthError::TokenNotPresent)
        ));
    }
}