acick_dropbox/
authorizer.rs

1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::io::Write;
4use std::net::SocketAddr;
5
6use anyhow::Context as _;
7use dropbox_sdk::check::{self, EchoArg};
8use hyper::service::{make_service_fn, service_fn};
9use hyper::{Body, Method, Request, Response, Server, StatusCode, Uri};
10use rand::distributions::Alphanumeric;
11use rand::{thread_rng, Rng as _};
12use serde::{Deserialize, Serialize};
13use tokio::sync::broadcast::{self, Sender};
14use url::form_urlencoded;
15
16use crate::abs_path::AbsPathBuf;
17use crate::hyper_client::{HyperClient, Oauth2AuthorizeUrlBuilder, Oauth2Type};
18use crate::web::open_in_browser;
19use crate::Result;
20use crate::{convert_dbx_err, Dropbox};
21
/// Number of alphanumeric characters in the random OAuth2 `state` value.
const STATE_LEN: usize = 16;
/// Query parameter carrying the authorization code on the redirect request.
const DBX_CODE_PARAM: &str = "code";
/// Query parameter echoing back the `state` value on the redirect request.
const DBX_STATE_PARAM: &str = "state";
25
26#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
27pub struct Token {
28    pub access_token: String,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq, Hash)]
32pub struct DbxAuthorizer<'a> {
33    app_key: &'a str,
34    app_secret: &'a str,
35    redirect_port: u16,
36    redirect_path: &'a str,
37    redirect_uri: String,
38    token_path: &'a AbsPathBuf,
39}
40
41impl<'a> DbxAuthorizer<'a> {
42    pub fn new(
43        app_key: &'a str,
44        app_secret: &'a str,
45        redirect_port: u16,
46        redirect_path: &'a str,
47        token_path: &'a AbsPathBuf,
48    ) -> Self {
49        Self {
50            app_key,
51            app_secret,
52            redirect_port,
53            redirect_path,
54            redirect_uri: format!("http://localhost:{}{}", redirect_port, redirect_path),
55            token_path,
56        }
57    }
58
59    pub fn load_or_request(
60        &self,
61        access_token: Option<String>,
62        cnsl: &mut dyn Write,
63    ) -> Result<Dropbox> {
64        let load_result = self.load_token(access_token, cnsl)?;
65        let (token, is_updated) = match load_result {
66            Some(token) if Self::validate_token(&token)? => (token, false),
67            _ => (self.request_token(cnsl)?, true),
68        };
69
70        if is_updated {
71            self.save_token(&token, cnsl)?;
72        }
73
74        Ok(Dropbox::new(token))
75    }
76
77    fn load_token(
78        &self,
79        access_token: Option<String>,
80        cnsl: &mut dyn Write,
81    ) -> Result<Option<Token>> {
82        if let Some(access_token) = access_token {
83            return Ok(Some(Token { access_token }));
84        }
85
86        if !self.token_path.as_ref().exists() {
87            return Ok(None);
88        }
89
90        let token = self.token_path.load_pretty(
91            |file| serde_json::from_reader(file).context("Could not load token from json file"),
92            None,
93            cnsl,
94        )?;
95
96        Ok(Some(token))
97    }
98
99    fn save_token(&self, token: &Token, cnsl: &mut dyn Write) -> Result<()> {
100        self.token_path.save_pretty(
101            |file| serde_json::to_writer(file, token).context("Could not save token as json file"),
102            true,
103            None,
104            cnsl,
105        )?;
106
107        Ok(())
108    }
109
110    fn validate_token(token: &Token) -> Result<bool> {
111        let client = HyperClient::new(token.access_token.clone());
112        match check::user(&client, &EchoArg { query: "".into() }) {
113            Ok(Ok(_)) => Ok(true),
114            Ok(Err(())) => Ok(false),
115            Err(dropbox_sdk::Error::InvalidToken(_)) => Ok(false),
116            Err(err) => Err(convert_dbx_err(err)),
117        }
118        .context("Could not validate access token")
119    }
120
121    #[tokio::main]
122    async fn request_token(&self, cnsl: &mut dyn Write) -> Result<Token> {
123        let state = gen_random_state();
124        let code = self
125            .authorize(state, cnsl)
126            .await
127            .context("Could not authorize acick on Dropbox")?;
128        let access_token = HyperClient::oauth2_token_from_authorization_code(
129            self.app_key,
130            self.app_secret,
131            &code,
132            Some(&self.redirect_uri),
133        )
134        .map_err(convert_dbx_err)
135        .context("Could not get access token from Dropbox")?;
136
137        Ok(Token { access_token })
138    }
139
140    async fn authorize(&self, state: String, cnsl: &mut dyn Write) -> Result<String> {
141        let (tx, mut rx) = broadcast::channel::<String>(1);
142
143        // start local server
144        let addr = SocketAddr::from(([127, 0, 0, 1], self.redirect_port));
145        let make_service = make_service_fn(|_conn| {
146            let redirect_path = self.redirect_path.to_owned();
147            let state = state.clone();
148            let tx = tx.clone();
149            async {
150                Ok::<_, Infallible>(service_fn(move |req| {
151                    respond(req, redirect_path.clone(), state.clone(), tx.clone())
152                }))
153            }
154        });
155        let server = Server::bind(&addr).serve(make_service);
156
157        // open auth url in browser
158        let auth_url = Oauth2AuthorizeUrlBuilder::new(self.app_key, Oauth2Type::AuthorizationCode)
159            .redirect_uri(&self.redirect_uri)
160            .state(&state)
161            .build();
162        open_in_browser(auth_url.as_str())
163            .context("Could not open a url in browser")
164            // coerce error
165            .unwrap_or_else(|err| writeln!(cnsl, "{}", err).unwrap_or(()));
166        writeln!(cnsl, "Authorize Dropbox in web browser.")?;
167
168        // wait for code to arrive and shutdown server
169        let graceful = server.with_graceful_shutdown(async {
170            let mut rx = tx.subscribe();
171            rx.recv().await.unwrap();
172        });
173        graceful.await?;
174
175        Ok(rx.recv().await?)
176    }
177}
178
179fn gen_random_state() -> String {
180    thread_rng()
181        .sample_iter(&Alphanumeric)
182        .take(STATE_LEN)
183        .collect()
184}
185
186fn get_params(uri: &Uri) -> HashMap<String, String> {
187    uri.query()
188        .map(|query_str| {
189            form_urlencoded::parse(query_str.as_bytes())
190                .into_owned()
191                .collect()
192        })
193        .unwrap_or_else(HashMap::new)
194}
195
196fn respond_param_missing(name: &str) -> Response<Body> {
197    Response::builder()
198        .status(StatusCode::BAD_REQUEST)
199        .body(Body::from(format!("Missing parameter: {}", name)))
200        .unwrap()
201}
202
203fn respond_param_invalid(name: &str) -> Response<Body> {
204    Response::builder()
205        .status(StatusCode::BAD_REQUEST)
206        .body(Body::from(format!("Invalid parameter: {}", name)))
207        .unwrap()
208}
209
210fn respond_not_found() -> Response<Body> {
211    Response::builder()
212        .status(StatusCode::NOT_FOUND)
213        .body(Body::from("Not Found"))
214        .unwrap()
215}
216
217fn handle_callback(req: Request<Body>, tx: Sender<String>, state_expected: &str) -> Response<Body> {
218    let mut params = get_params(req.uri());
219    let code = match params.remove(DBX_CODE_PARAM) {
220        Some(code) => code,
221        None => return respond_param_missing(DBX_CODE_PARAM),
222    };
223    let state = match params.remove(DBX_STATE_PARAM) {
224        Some(state) => state,
225        None => return respond_param_missing(DBX_STATE_PARAM),
226    };
227    if state != state_expected {
228        return respond_param_invalid(DBX_STATE_PARAM);
229    }
230
231    // send auth code to Authorizer
232    tx.send(code).unwrap_or(0);
233
234    Response::new(Body::from(
235        "Successfully completed authorization. Go back to acick on your terminal.",
236    ))
237}
238
239async fn respond(
240    req: Request<Body>,
241    redirect_path: String,
242    state: String,
243    tx: Sender<String>,
244) -> std::result::Result<Response<Body>, Infallible> {
245    if req.method() == Method::GET && req.uri().path() == redirect_path {
246        return Ok(handle_callback(req, tx, &state));
247    }
248    Ok(respond_not_found())
249}
250
#[cfg(test)]
mod tests {
    // Unit tests for the authorizer: token load/save, token validation against Dropbox,
    // the local callback server, and the small helpers.
    use tempfile::{tempdir, TempDir};

    use super::*;

    // Builds a `HashMap` from `key => value` pairs (at least one pair is required).
    macro_rules! map(
        { $($key:expr => $value:expr),+ } => {
            {
                let mut m = ::std::collections::HashMap::new();
                $(
                    m.insert($key, $value);
                )+
                m
            }
         };
    );

    /// Runs `f` with a fresh temporary directory and an authorizer whose token file is
    /// `dbx_token.json` inside it. Panics if setup or `f` fails.
    fn run_test(f: fn(test_dir: &TempDir, authorizer: DbxAuthorizer) -> anyhow::Result<()>) {
        let test_dir = tempdir().unwrap();
        let token_path = AbsPathBuf::try_new(test_dir.path().join("dbx_token.json")).unwrap();
        let authorizer = DbxAuthorizer::new("test_key", "test_secret", 4100, "/path", &token_path);
        f(&test_dir, authorizer).unwrap();
    }

    /// An explicit token wins; with no token and no file the result is `None`; once the file
    /// exists, its contents are loaded.
    #[test]
    fn test_load_token() {
        run_test(|_, authorizer| {
            let access_token = "test_token".to_string();
            let token = Token {
                access_token: access_token.clone(),
            };
            let mut buf = Vec::new();

            let actual = authorizer.load_token(Some(access_token), &mut buf)?;
            let expected = Some(token);
            assert_eq!(actual, expected);

            // The token file has not been created yet.
            assert_eq!(authorizer.load_token(None, &mut buf)?, None);

            let token_path = authorizer.token_path.as_ref();
            let mut file = std::fs::File::create(token_path)?;
            file.write_all(br#"{"access_token": "test_token"}"#)?;

            let actual = authorizer.load_token(None, &mut buf)?;
            assert_eq!(actual, expected);

            Ok(())
        })
    }

    /// The token is written as compact JSON.
    #[test]
    fn test_save_token() {
        run_test(|_, authorizer| {
            let access_token = "test_token".to_string();
            let token = Token { access_token };
            let mut buf = Vec::<u8>::new();
            authorizer.save_token(&token, &mut buf)?;
            let token_str = std::fs::read_to_string(authorizer.token_path.as_ref())?;
            assert_eq!(token_str, r#"{"access_token":"test_token"}"#);
            Ok(())
        })
    }

    /// Calls the real Dropbox API. Requires a valid token in the `ACICK_DBX_ACCESS_TOKEN`
    /// environment variable; the test returns an error if it is unset.
    #[test]
    fn test_validate_token() -> anyhow::Result<()> {
        let access_token = std::env::var("ACICK_DBX_ACCESS_TOKEN")?;
        assert!(DbxAuthorizer::validate_token(&Token { access_token })?);
        assert!(!DbxAuthorizer::validate_token(&Token {
            access_token: "test_token".into()
        })?);
        Ok(())
    }

    /// End-to-end check of the callback server on port 4100: a simulated redirect delivers the
    /// code to `authorize`. Note that `authorize` also tries to open the auth url in a browser.
    // NOTE(review): `future` is not polled until `future.await`; this seems to rely on the
    // single-threaded `#[tokio::test]` runtime so that the port is bound before the spawned
    // client runs — confirm.
    #[tokio::test]
    async fn test_authorize() -> anyhow::Result<()> {
        let test_dir = tempdir().unwrap();
        let token_path = AbsPathBuf::try_new(test_dir.path().join("dbx_token.json")).unwrap();
        let authorizer = DbxAuthorizer::new("test_key", "test_secret", 4100, "/path", &token_path);
        let mut buf = Vec::<u8>::new();
        let future = authorizer.authorize("test_state".to_string(), &mut buf);

        tokio::spawn(async {
            let client = hyper::Client::new();
            let uri =
                Uri::from_static("http://localhost:4100/path?code=test_code&state=test_state");
            client.get(uri).await.unwrap();
        });

        let code = future.await?;
        assert_eq!(code, "test_code");
        Ok(())
    }

    /// The state has the configured length, and two draws (almost certainly) differ.
    #[test]
    fn test_gen_random_state() {
        assert_eq!(gen_random_state().len(), STATE_LEN);
        assert_ne!(gen_random_state(), gen_random_state());
    }

    /// Missing and empty queries give an empty map; pairs are decoded into the map.
    #[test]
    fn test_get_params() {
        let tests = &[
            (Uri::from_static("http://example.com/"), HashMap::new()),
            (Uri::from_static("http://example.com/?"), HashMap::new()),
            (
                Uri::from_static("http://example.com/?hoge=fuga&foo=bar"),
                map!(String::from("hoge") => String::from("fuga"), String::from("foo") => String::from("bar")),
            ),
        ];

        for (left, expected) in tests {
            let actual = get_params(left);
            assert_eq!(&actual, expected);
        }
    }

    /// Status codes for the callback routing; on `200 OK` the code must also reach the channel.
    #[tokio::test]
    async fn test_respond() -> anyhow::Result<()> {
        let tests = &[
            ("/path?code=test_code&state=test_state", StatusCode::OK),
            ("/path", StatusCode::BAD_REQUEST),
            ("/path?code=test_code", StatusCode::BAD_REQUEST),
            (
                "/path?code=test_code&state=invalid_state",
                StatusCode::BAD_REQUEST,
            ),
            (
                "/invalid_path?code=test_code&state=test_state",
                StatusCode::NOT_FOUND,
            ),
        ];

        for (left, expected) in tests {
            let (tx, mut rx) = broadcast::channel::<String>(2);
            let req = Request::get(format!("http://localhost:4100{}", left)).body(Body::empty())?;
            let redirect_path = "/path".to_string();
            let state = "test_state".to_string();
            let res = respond(req, redirect_path, state, tx).await?;
            assert_eq!(res.status(), *expected);
            if res.status() == StatusCode::OK {
                let code = rx.recv().await?;
                assert_eq!(code, "test_code");
            }
        }
        Ok(())
    }
}