1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::io::Write;
4use std::net::SocketAddr;
5
6use anyhow::Context as _;
7use dropbox_sdk::check::{self, EchoArg};
8use hyper::service::{make_service_fn, service_fn};
9use hyper::{Body, Method, Request, Response, Server, StatusCode, Uri};
10use rand::distributions::Alphanumeric;
11use rand::{thread_rng, Rng as _};
12use serde::{Deserialize, Serialize};
13use tokio::sync::broadcast::{self, Sender};
14use url::form_urlencoded;
15
16use crate::abs_path::AbsPathBuf;
17use crate::hyper_client::{HyperClient, Oauth2AuthorizeUrlBuilder, Oauth2Type};
18use crate::web::open_in_browser;
19use crate::Result;
20use crate::{convert_dbx_err, Dropbox};
21
/// Length of the random OAuth2 `state` value generated for each authorization.
const STATE_LEN: usize = 16;
/// Query parameter carrying the authorization code in the OAuth2 redirect.
const DBX_CODE_PARAM: &str = "code";
/// Query parameter echoing back the `state` value in the OAuth2 redirect.
const DBX_STATE_PARAM: &str = "state";
25
26#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
27pub struct Token {
28 pub access_token: String,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq, Hash)]
32pub struct DbxAuthorizer<'a> {
33 app_key: &'a str,
34 app_secret: &'a str,
35 redirect_port: u16,
36 redirect_path: &'a str,
37 redirect_uri: String,
38 token_path: &'a AbsPathBuf,
39}
40
41impl<'a> DbxAuthorizer<'a> {
42 pub fn new(
43 app_key: &'a str,
44 app_secret: &'a str,
45 redirect_port: u16,
46 redirect_path: &'a str,
47 token_path: &'a AbsPathBuf,
48 ) -> Self {
49 Self {
50 app_key,
51 app_secret,
52 redirect_port,
53 redirect_path,
54 redirect_uri: format!("http://localhost:{}{}", redirect_port, redirect_path),
55 token_path,
56 }
57 }
58
59 pub fn load_or_request(
60 &self,
61 access_token: Option<String>,
62 cnsl: &mut dyn Write,
63 ) -> Result<Dropbox> {
64 let load_result = self.load_token(access_token, cnsl)?;
65 let (token, is_updated) = match load_result {
66 Some(token) if Self::validate_token(&token)? => (token, false),
67 _ => (self.request_token(cnsl)?, true),
68 };
69
70 if is_updated {
71 self.save_token(&token, cnsl)?;
72 }
73
74 Ok(Dropbox::new(token))
75 }
76
77 fn load_token(
78 &self,
79 access_token: Option<String>,
80 cnsl: &mut dyn Write,
81 ) -> Result<Option<Token>> {
82 if let Some(access_token) = access_token {
83 return Ok(Some(Token { access_token }));
84 }
85
86 if !self.token_path.as_ref().exists() {
87 return Ok(None);
88 }
89
90 let token = self.token_path.load_pretty(
91 |file| serde_json::from_reader(file).context("Could not load token from json file"),
92 None,
93 cnsl,
94 )?;
95
96 Ok(Some(token))
97 }
98
99 fn save_token(&self, token: &Token, cnsl: &mut dyn Write) -> Result<()> {
100 self.token_path.save_pretty(
101 |file| serde_json::to_writer(file, token).context("Could not save token as json file"),
102 true,
103 None,
104 cnsl,
105 )?;
106
107 Ok(())
108 }
109
110 fn validate_token(token: &Token) -> Result<bool> {
111 let client = HyperClient::new(token.access_token.clone());
112 match check::user(&client, &EchoArg { query: "".into() }) {
113 Ok(Ok(_)) => Ok(true),
114 Ok(Err(())) => Ok(false),
115 Err(dropbox_sdk::Error::InvalidToken(_)) => Ok(false),
116 Err(err) => Err(convert_dbx_err(err)),
117 }
118 .context("Could not validate access token")
119 }
120
121 #[tokio::main]
122 async fn request_token(&self, cnsl: &mut dyn Write) -> Result<Token> {
123 let state = gen_random_state();
124 let code = self
125 .authorize(state, cnsl)
126 .await
127 .context("Could not authorize acick on Dropbox")?;
128 let access_token = HyperClient::oauth2_token_from_authorization_code(
129 self.app_key,
130 self.app_secret,
131 &code,
132 Some(&self.redirect_uri),
133 )
134 .map_err(convert_dbx_err)
135 .context("Could not get access token from Dropbox")?;
136
137 Ok(Token { access_token })
138 }
139
140 async fn authorize(&self, state: String, cnsl: &mut dyn Write) -> Result<String> {
141 let (tx, mut rx) = broadcast::channel::<String>(1);
142
143 let addr = SocketAddr::from(([127, 0, 0, 1], self.redirect_port));
145 let make_service = make_service_fn(|_conn| {
146 let redirect_path = self.redirect_path.to_owned();
147 let state = state.clone();
148 let tx = tx.clone();
149 async {
150 Ok::<_, Infallible>(service_fn(move |req| {
151 respond(req, redirect_path.clone(), state.clone(), tx.clone())
152 }))
153 }
154 });
155 let server = Server::bind(&addr).serve(make_service);
156
157 let auth_url = Oauth2AuthorizeUrlBuilder::new(self.app_key, Oauth2Type::AuthorizationCode)
159 .redirect_uri(&self.redirect_uri)
160 .state(&state)
161 .build();
162 open_in_browser(auth_url.as_str())
163 .context("Could not open a url in browser")
164 .unwrap_or_else(|err| writeln!(cnsl, "{}", err).unwrap_or(()));
166 writeln!(cnsl, "Authorize Dropbox in web browser.")?;
167
168 let graceful = server.with_graceful_shutdown(async {
170 let mut rx = tx.subscribe();
171 rx.recv().await.unwrap();
172 });
173 graceful.await?;
174
175 Ok(rx.recv().await?)
176 }
177}
178
179fn gen_random_state() -> String {
180 thread_rng()
181 .sample_iter(&Alphanumeric)
182 .take(STATE_LEN)
183 .collect()
184}
185
186fn get_params(uri: &Uri) -> HashMap<String, String> {
187 uri.query()
188 .map(|query_str| {
189 form_urlencoded::parse(query_str.as_bytes())
190 .into_owned()
191 .collect()
192 })
193 .unwrap_or_else(HashMap::new)
194}
195
196fn respond_param_missing(name: &str) -> Response<Body> {
197 Response::builder()
198 .status(StatusCode::BAD_REQUEST)
199 .body(Body::from(format!("Missing parameter: {}", name)))
200 .unwrap()
201}
202
203fn respond_param_invalid(name: &str) -> Response<Body> {
204 Response::builder()
205 .status(StatusCode::BAD_REQUEST)
206 .body(Body::from(format!("Invalid parameter: {}", name)))
207 .unwrap()
208}
209
210fn respond_not_found() -> Response<Body> {
211 Response::builder()
212 .status(StatusCode::NOT_FOUND)
213 .body(Body::from("Not Found"))
214 .unwrap()
215}
216
217fn handle_callback(req: Request<Body>, tx: Sender<String>, state_expected: &str) -> Response<Body> {
218 let mut params = get_params(req.uri());
219 let code = match params.remove(DBX_CODE_PARAM) {
220 Some(code) => code,
221 None => return respond_param_missing(DBX_CODE_PARAM),
222 };
223 let state = match params.remove(DBX_STATE_PARAM) {
224 Some(state) => state,
225 None => return respond_param_missing(DBX_STATE_PARAM),
226 };
227 if state != state_expected {
228 return respond_param_invalid(DBX_STATE_PARAM);
229 }
230
231 tx.send(code).unwrap_or(0);
233
234 Response::new(Body::from(
235 "Successfully completed authorization. Go back to acick on your terminal.",
236 ))
237}
238
239async fn respond(
240 req: Request<Body>,
241 redirect_path: String,
242 state: String,
243 tx: Sender<String>,
244) -> std::result::Result<Response<Body>, Infallible> {
245 if req.method() == Method::GET && req.uri().path() == redirect_path {
246 return Ok(handle_callback(req, tx, &state));
247 }
248 Ok(respond_not_found())
249}
250
#[cfg(test)]
mod tests {
    use tempfile::{tempdir, TempDir};

    use super::*;

    // Builds a `HashMap` from one or more `key => value` pairs.
    macro_rules! map(
        { $($key:expr => $value:expr),+ } => {
            {
                let mut m = ::std::collections::HashMap::new();
                $(
                    m.insert($key, $value);
                )+
                m
            }
        };
    );

    // Runs `f` with a fresh temporary directory and an authorizer whose token
    // cache is `dbx_token.json` inside that directory. Panics if `f` returns an
    // error.
    fn run_test(f: fn(test_dir: &TempDir, authorizer: DbxAuthorizer) -> anyhow::Result<()>) {
        let test_dir = tempdir().unwrap();
        let token_path = AbsPathBuf::try_new(test_dir.path().join("dbx_token.json")).unwrap();
        let authorizer = DbxAuthorizer::new("test_key", "test_secret", 4100, "/path", &token_path);
        f(&test_dir, authorizer).unwrap();
    }

    #[test]
    fn test_load_token() {
        run_test(|_, authorizer| {
            let access_token = "test_token".to_string();
            let token = Token {
                access_token: access_token.clone(),
            };
            let mut buf = Vec::new();

            // An explicitly supplied token is returned as-is.
            let actual = authorizer.load_token(Some(access_token), &mut buf)?;
            let expected = Some(token);
            assert_eq!(actual, expected);

            // No supplied token and no cache file yet: nothing to load.
            assert_eq!(authorizer.load_token(None, &mut buf)?, None);

            let token_path = authorizer.token_path.as_ref();
            let mut file = std::fs::File::create(token_path)?;
            file.write_all(br#"{"access_token": "test_token"}"#)?;

            // With the cache file present, the token is read from it.
            let actual = authorizer.load_token(None, &mut buf)?;
            assert_eq!(actual, expected);

            Ok(())
        })
    }

    #[test]
    fn test_save_token() {
        run_test(|_, authorizer| {
            let access_token = "test_token".to_string();
            let token = Token { access_token };
            let mut buf = Vec::<u8>::new();
            authorizer.save_token(&token, &mut buf)?;
            // `serde_json::to_writer` emits compact JSON without whitespace.
            let token_str = std::fs::read_to_string(authorizer.token_path.as_ref())?;
            assert_eq!(token_str, r#"{"access_token":"test_token"}"#);
            Ok(())
        })
    }

    // Calls the real Dropbox API. Requires the ACICK_DBX_ACCESS_TOKEN
    // environment variable (the test returns Err without it) and network
    // access.
    #[test]
    fn test_validate_token() -> anyhow::Result<()> {
        let access_token = std::env::var("ACICK_DBX_ACCESS_TOKEN")?;
        assert!(DbxAuthorizer::validate_token(&Token { access_token })?);
        assert!(!DbxAuthorizer::validate_token(&Token {
            access_token: "test_token".into()
        })?);
        Ok(())
    }

    // Binds 127.0.0.1:4100, so it fails if that port is in use. `authorize`
    // calls `open_in_browser`, so this test presumably tries to open a real
    // browser. The request comes from a spawned task, and the server binds
    // only when `future` is first polled; the test presumably relies on
    // `#[tokio::test]`'s single-threaded runtime running the task after that
    // point — confirm.
    #[tokio::test]
    async fn test_authorize() -> anyhow::Result<()> {
        let test_dir = tempdir().unwrap();
        let token_path = AbsPathBuf::try_new(test_dir.path().join("dbx_token.json")).unwrap();
        let authorizer = DbxAuthorizer::new("test_key", "test_secret", 4100, "/path", &token_path);
        let mut buf = Vec::<u8>::new();
        let future = authorizer.authorize("test_state".to_string(), &mut buf);

        tokio::spawn(async {
            let client = hyper::Client::new();
            let uri =
                Uri::from_static("http://localhost:4100/path?code=test_code&state=test_state");
            client.get(uri).await.unwrap();
        });

        let code = future.await?;
        assert_eq!(code, "test_code");
        Ok(())
    }

    #[test]
    fn test_gen_random_state() {
        assert_eq!(gen_random_state().len(), STATE_LEN);
        // Probabilistic check: two 16-character random strings colliding is
        // practically impossible.
        assert_ne!(gen_random_state(), gen_random_state());
    }

    #[test]
    fn test_get_params() {
        // (input uri, expected decoded query parameters)
        let tests = &[
            (Uri::from_static("http://example.com/"), HashMap::new()),
            (Uri::from_static("http://example.com/?"), HashMap::new()),
            (
                Uri::from_static("http://example.com/?hoge=fuga&foo=bar"),
                map!(String::from("hoge") => String::from("fuga"), String::from("foo") => String::from("bar")),
            ),
        ];

        for (left, expected) in tests {
            let actual = get_params(left);
            assert_eq!(&actual, expected);
        }
    }

    #[tokio::test]
    async fn test_respond() -> anyhow::Result<()> {
        // (request path and query, expected status) for the "/path" route with
        // expected state "test_state".
        let tests = &[
            ("/path?code=test_code&state=test_state", StatusCode::OK),
            ("/path", StatusCode::BAD_REQUEST),
            ("/path?code=test_code", StatusCode::BAD_REQUEST),
            (
                "/path?code=test_code&state=invalid_state",
                StatusCode::BAD_REQUEST,
            ),
            (
                "/invalid_path?code=test_code&state=test_state",
                StatusCode::NOT_FOUND,
            ),
        ];

        for (left, expected) in tests {
            // A fresh channel per case; at most one code is sent per request.
            let (tx, mut rx) = broadcast::channel::<String>(2);
            let req = Request::get(format!("http://localhost:4100{}", left)).body(Body::empty())?;
            let redirect_path = "/path".to_string();
            let state = "test_state".to_string();
            let res = respond(req, redirect_path, state, tx).await?;
            assert_eq!(res.status(), *expected);
            // Only a successful callback broadcasts the code.
            if res.status() == StatusCode::OK {
                let code = rx.recv().await?;
                assert_eq!(code, "test_code");
            }
        }
        Ok(())
    }
}