1#![warn(missing_docs)]
2use std::{
15 io::{self, BufRead, BufReader, Write},
16 net::{SocketAddr, TcpListener},
17 sync::mpsc,
18 time::{Duration, Instant},
19};
20
21use oauth2::{
22 AuthUrl, AuthorizationCode, ClientId, CsrfToken, EmptyExtraTokenFields, EndpointNotSet,
23 EndpointSet, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope,
24 StandardTokenResponse, TokenResponse, TokenUrl, basic::BasicClient, basic::BasicTokenType,
25};
26
27use log::{error, info, trace};
28use thiserror::Error;
29use url::Url;
30
31#[cfg(all(feature = "native-tls", feature = "__rustls"))]
43compile_error!(
44 "Feature \"native-tls\" is mutually exclusive with \"rustls-tls-native-roots\" and \"rustls-tls-webpki-roots\". Enable only one."
45);
46
47#[cfg(not(any(feature = "native-tls", feature = "__rustls")))]
48compile_error!(
49 "Either feature \"native-tls\" (default), \"rustls-tls-native-roots\" or \"rustls-tls-webpki-roots\" must be enabled for this crate."
50);
51
/// Errors that can occur while obtaining or refreshing an OAuth token.
#[derive(Debug, Error)]
pub enum OAuthError {
    /// The redirect URL (pasted on stdin or received by the local listener)
    /// could not be parsed as a URL.
    #[error("Unable to parse redirect URI {uri} ({e})")]
    AuthCodeBadUri {
        /// The URI that failed to parse.
        uri: String,
        /// The underlying URL parse error.
        e: url::ParseError,
    },

    /// The redirect URL did not contain a `code` query parameter.
    #[error("Auth code param not found in URI {uri}")]
    AuthCodeNotFound {
        /// The URI that was searched for the `code` parameter.
        uri: String,
    },

    /// Reading the redirect URL from stdin failed.
    #[error("Failed to read redirect URI from stdin")]
    AuthCodeStdinRead,

    /// The local redirect listener could not bind its socket.
    #[error("Failed to bind server to {addr} ({e})")]
    AuthCodeListenerBind {
        /// The address the listener tried to bind.
        addr: SocketAddr,
        /// The underlying I/O error.
        e: io::Error,
    },

    /// No connection was obtained from the local redirect listener.
    #[error("Listener terminated without accepting a connection")]
    AuthCodeListenerTerminated,

    /// Reading the HTTP request line from the redirect connection failed.
    #[error("Failed to read redirect URI from HTTP request")]
    AuthCodeListenerRead,

    /// The HTTP request line did not contain a request target.
    #[error("Failed to parse redirect URI from HTTP request")]
    AuthCodeListenerParse,

    /// Writing the HTTP response back to the browser failed.
    #[error("Failed to write HTTP response")]
    AuthCodeListenerWrite,

    /// One of the built-in Spotify authorization/token endpoint URLs was rejected.
    #[error("Invalid Spotify OAuth URI")]
    InvalidSpotifyUri,

    /// The configured redirect URI is not a valid URL.
    #[error("Invalid Redirect URI {uri} ({e})")]
    InvalidRedirectUri {
        /// The redirect URI that was rejected.
        uri: String,
        /// The underlying URL parse error.
        e: url::ParseError,
    },

    /// The worker thread performing the token request ended without sending
    /// back a result.
    #[error("Failed to receive code")]
    Recv,

    /// The token-endpoint request failed; used both for exchanging an
    /// authorization code and for refreshing a token.
    #[error("Failed to exchange code for access token ({e})")]
    ExchangeCode {
        /// Text of the underlying request error.
        e: String,
    },
}
124
/// An OAuth token obtained from Spotify's accounts service.
#[derive(Debug, Clone)]
pub struct OAuthToken {
    /// The access token secret.
    pub access_token: String,
    /// The refresh token secret; an empty string when the token response did
    /// not include one.
    pub refresh_token: String,
    /// Local instant at which the access token expires. It is computed as the
    /// time the response was processed plus `expires_in`, or plus one hour when
    /// the response omitted `expires_in`.
    pub expires_at: Instant,
    /// The token type reported by the response, rendered via its `Debug`
    /// representation.
    pub token_type: String,
    /// The granted scopes, or the requested scopes when the response did not
    /// list any.
    pub scopes: Vec<String>,
}
139
140fn get_code(redirect_url: &str) -> Result<AuthorizationCode, OAuthError> {
142 let url = Url::parse(redirect_url).map_err(|e| OAuthError::AuthCodeBadUri {
143 uri: redirect_url.to_string(),
144 e,
145 })?;
146 let code = url
147 .query_pairs()
148 .find(|(key, _)| key == "code")
149 .map(|(_, code)| AuthorizationCode::new(code.into_owned()))
150 .ok_or(OAuthError::AuthCodeNotFound {
151 uri: redirect_url.to_string(),
152 })?;
153
154 Ok(code)
155}
156
157fn get_authcode_stdin() -> Result<AuthorizationCode, OAuthError> {
159 println!("Provide redirect URL");
160 let mut buffer = String::new();
161 let stdin = io::stdin();
162 stdin
163 .read_line(&mut buffer)
164 .map_err(|_| OAuthError::AuthCodeStdinRead)?;
165
166 get_code(buffer.trim())
167}
168
169fn get_authcode_listener(
171 socket_address: SocketAddr,
172 message: String,
173) -> Result<AuthorizationCode, OAuthError> {
174 let listener =
175 TcpListener::bind(socket_address).map_err(|e| OAuthError::AuthCodeListenerBind {
176 addr: socket_address,
177 e,
178 })?;
179 info!("OAuth server listening on {socket_address:?}");
180
181 let mut stream = listener
183 .incoming()
184 .flatten()
185 .next()
186 .ok_or(OAuthError::AuthCodeListenerTerminated)?;
187 let mut reader = BufReader::new(&stream);
188 let mut request_line = String::new();
189 reader
190 .read_line(&mut request_line)
191 .map_err(|_| OAuthError::AuthCodeListenerRead)?;
192
193 let redirect_url = request_line
194 .split_whitespace()
195 .nth(1)
196 .ok_or(OAuthError::AuthCodeListenerParse)?;
197 let code = get_code(&("http://localhost".to_string() + redirect_url));
198
199 let response = format!(
200 "HTTP/1.1 200 OK\r\ncontent-length: {}\r\n\r\n{}",
201 message.len(),
202 message
203 );
204 stream
205 .write_all(response.as_bytes())
206 .map_err(|_| OAuthError::AuthCodeListenerWrite)?;
207
208 code
209}
210
211fn get_socket_address(redirect_uri: &str) -> Option<SocketAddr> {
214 let url = match Url::parse(redirect_uri) {
215 Ok(u) if u.scheme() == "http" && u.port().is_some() => u,
216 _ => return None,
217 };
218 match url.socket_addrs(|| None) {
219 Ok(mut addrs) => addrs.pop(),
220 _ => None,
221 }
222}
223
/// A configured OAuth client for Spotify's accounts service.
///
/// Construct it with [`OAuthClientBuilder`].
pub struct OAuthClient {
    // Scopes requested in the authorization URL; also reported as the token's
    // scopes when the token response does not list any.
    scopes: Vec<String>,
    // Redirect URI; decides whether the code arrives via a local listener or
    // is read from stdin.
    redirect_uri: String,
    // Whether to open the authorization URL in the browser besides printing it.
    should_open_url: bool,
    // Body of the HTTP response the local listener sends to the browser.
    message: String,
    // Underlying oauth2 client with the auth and token endpoints configured.
    client: BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>,
}
232
233impl OAuthClient {
234 fn set_auth_url(&self) -> PkceCodeVerifier {
238 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
239 let request_scopes: Vec<oauth2::Scope> =
242 self.scopes.iter().map(|s| Scope::new(s.into())).collect();
243 let (auth_url, _) = self
244 .client
245 .authorize_url(CsrfToken::new_random)
246 .add_scopes(request_scopes)
247 .set_pkce_challenge(pkce_challenge)
248 .url();
249
250 if self.should_open_url {
251 open::that_in_background(auth_url.as_str());
252 }
253 println!("Browse to: {auth_url}");
254
255 pkce_verifier
256 }
257
258 fn build_token(
259 &self,
260 resp: StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>,
261 ) -> Result<OAuthToken, OAuthError> {
262 trace!("Obtained new access token: {resp:?}");
263
264 let token_scopes: Vec<String> = match resp.scopes() {
265 Some(s) => s.iter().map(|s| s.to_string()).collect(),
266 _ => self.scopes.clone(),
267 };
268 let refresh_token = match resp.refresh_token() {
269 Some(t) => t.secret().to_string(),
270 _ => "".to_string(), };
272 Ok(OAuthToken {
273 access_token: resp.access_token().secret().to_string(),
274 refresh_token,
275 expires_at: Instant::now()
276 + resp
277 .expires_in()
278 .unwrap_or_else(|| Duration::from_secs(3600)),
279 token_type: format!("{:?}", resp.token_type()),
280 scopes: token_scopes,
281 })
282 }
283
284 pub fn get_access_token(&self) -> Result<OAuthToken, OAuthError> {
286 let pkce_verifier = self.set_auth_url();
287
288 let code = match get_socket_address(&self.redirect_uri) {
289 Some(addr) => get_authcode_listener(addr, self.message.clone()),
290 _ => get_authcode_stdin(),
291 }?;
292 trace!("Exchange {code:?} for access token");
293
294 let (tx, rx) = mpsc::channel();
295 let client = self.client.clone();
296 std::thread::spawn(move || {
297 let http_client = reqwest::blocking::Client::new();
298 let resp = client
299 .exchange_code(code)
300 .set_pkce_verifier(pkce_verifier)
301 .request(&http_client);
302 if let Err(e) = tx.send(resp) {
303 error!("OAuth channel send error: {e}");
304 }
305 });
306 let channel_response = rx.recv().map_err(|_| OAuthError::Recv)?;
307 let token_response =
308 channel_response.map_err(|e| OAuthError::ExchangeCode { e: e.to_string() })?;
309
310 self.build_token(token_response)
311 }
312
313 pub fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken, OAuthError> {
315 let refresh_token = RefreshToken::new(refresh_token.to_string());
316 let http_client = reqwest::blocking::Client::new();
317 let resp = self
318 .client
319 .exchange_refresh_token(&refresh_token)
320 .request(&http_client);
321
322 let resp = resp.map_err(|e| OAuthError::ExchangeCode { e: e.to_string() })?;
323 self.build_token(resp)
324 }
325
326 pub async fn get_access_token_async(&self) -> Result<OAuthToken, OAuthError> {
328 let pkce_verifier = self.set_auth_url();
329
330 let code = match get_socket_address(&self.redirect_uri) {
331 Some(addr) => get_authcode_listener(addr, self.message.clone()),
332 _ => get_authcode_stdin(),
333 }?;
334 trace!("Exchange {code:?} for access token");
335
336 let http_client = reqwest::Client::new();
337 let resp = self
338 .client
339 .exchange_code(code)
340 .set_pkce_verifier(pkce_verifier)
341 .request_async(&http_client)
342 .await;
343
344 let resp = resp.map_err(|e| OAuthError::ExchangeCode { e: e.to_string() })?;
345 self.build_token(resp)
346 }
347
348 pub async fn refresh_token_async(&self, refresh_token: &str) -> Result<OAuthToken, OAuthError> {
350 let refresh_token = RefreshToken::new(refresh_token.to_string());
351 let http_client = reqwest::Client::new();
352 let resp = self
353 .client
354 .exchange_refresh_token(&refresh_token)
355 .request_async(&http_client)
356 .await;
357
358 let resp = resp.map_err(|e| OAuthError::ExchangeCode { e: e.to_string() })?;
359 self.build_token(resp)
360 }
361}
362
/// Builder for [`OAuthClient`].
pub struct OAuthClientBuilder {
    // Spotify application client ID.
    client_id: String,
    // Redirect URI registered for the application.
    redirect_uri: String,
    // Scopes to request.
    scopes: Vec<String>,
    // Whether the built client opens the authorization URL in the browser.
    should_open_url: bool,
    // Body of the HTTP response sent to the browser by the local listener.
    message: String,
}
371
372impl OAuthClientBuilder {
373 pub fn new(client_id: &str, redirect_uri: &str, scopes: Vec<&str>) -> Self {
377 Self {
378 client_id: client_id.to_string(),
379 redirect_uri: redirect_uri.to_string(),
380 scopes: scopes.into_iter().map(Into::into).collect(),
381 should_open_url: false,
382 message: String::from("Go back to your terminal :)"),
383 }
384 }
385
386 pub fn open_in_browser(mut self) -> Self {
389 self.should_open_url = true;
390 self
391 }
392
393 pub fn with_custom_message(mut self, message: &str) -> Self {
396 self.message = message.to_string();
397 self
398 }
399
400 pub fn build(self) -> Result<OAuthClient, OAuthError> {
402 let auth_url = AuthUrl::new("https://accounts.spotify.com/authorize".to_string())
403 .map_err(|_| OAuthError::InvalidSpotifyUri)?;
404 let token_url = TokenUrl::new("https://accounts.spotify.com/api/token".to_string())
405 .map_err(|_| OAuthError::InvalidSpotifyUri)?;
406 let redirect_url = RedirectUrl::new(self.redirect_uri.clone()).map_err(|e| {
407 OAuthError::InvalidRedirectUri {
408 uri: self.redirect_uri.clone(),
409 e,
410 }
411 })?;
412
413 let client = BasicClient::new(ClientId::new(self.client_id.to_string()))
414 .set_auth_uri(auth_url)
415 .set_token_uri(token_url)
416 .set_redirect_uri(redirect_url);
417
418 Ok(OAuthClient {
419 scopes: self.scopes,
420 should_open_url: self.should_open_url,
421 message: self.message,
422 redirect_uri: self.redirect_uri,
423 client,
424 })
425 }
426}
427
428#[deprecated(
431 since = "0.7.0",
432 note = "please use builder pattern with `OAuthClientBuilder` instead"
433)]
434pub fn get_access_token(
437 client_id: &str,
438 redirect_uri: &str,
439 scopes: Vec<&str>,
440) -> Result<OAuthToken, OAuthError> {
441 let auth_url = AuthUrl::new("https://accounts.spotify.com/authorize".to_string())
442 .map_err(|_| OAuthError::InvalidSpotifyUri)?;
443 let token_url = TokenUrl::new("https://accounts.spotify.com/api/token".to_string())
444 .map_err(|_| OAuthError::InvalidSpotifyUri)?;
445 let redirect_url =
446 RedirectUrl::new(redirect_uri.to_string()).map_err(|e| OAuthError::InvalidRedirectUri {
447 uri: redirect_uri.to_string(),
448 e,
449 })?;
450 let client = BasicClient::new(ClientId::new(client_id.to_string()))
451 .set_auth_uri(auth_url)
452 .set_token_uri(token_url)
453 .set_redirect_uri(redirect_url);
454
455 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
456
457 let request_scopes: Vec<oauth2::Scope> = scopes
460 .clone()
461 .into_iter()
462 .map(|s| Scope::new(s.into()))
463 .collect();
464 let (auth_url, _) = client
465 .authorize_url(CsrfToken::new_random)
466 .add_scopes(request_scopes)
467 .set_pkce_challenge(pkce_challenge)
468 .url();
469
470 println!("Browse to: {auth_url}");
471
472 let code = match get_socket_address(redirect_uri) {
473 Some(addr) => get_authcode_listener(addr, String::from("Go back to your terminal :)")),
474 _ => get_authcode_stdin(),
475 }?;
476 trace!("Exchange {code:?} for access token");
477
478 let (tx, rx) = mpsc::channel();
480 std::thread::spawn(move || {
481 let http_client = reqwest::blocking::Client::new();
482 let resp = client
483 .exchange_code(code)
484 .set_pkce_verifier(pkce_verifier)
485 .request(&http_client);
486 if let Err(e) = tx.send(resp) {
487 error!("OAuth channel send error: {e}");
488 }
489 });
490 let token_response = rx.recv().map_err(|_| OAuthError::Recv)?;
491 let token = token_response.map_err(|e| OAuthError::ExchangeCode { e: e.to_string() })?;
492 trace!("Obtained new access token: {token:?}");
493
494 let token_scopes: Vec<String> = match token.scopes() {
495 Some(s) => s.iter().map(|s| s.to_string()).collect(),
496 _ => scopes.into_iter().map(|s| s.to_string()).collect(),
497 };
498 let refresh_token = match token.refresh_token() {
499 Some(t) => t.secret().to_string(),
500 _ => "".to_string(), };
502 Ok(OAuthToken {
503 access_token: token.access_token().secret().to_string(),
504 refresh_token,
505 expires_at: Instant::now()
506 + token
507 .expires_in()
508 .unwrap_or_else(|| Duration::from_secs(3600)),
509 token_type: format!("{:?}", token.token_type()).to_string(), scopes: token_scopes,
511 })
512}
513
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    use super::*;

    /// URIs that must not start a local listener: no port, empty port,
    /// or a non-`http` scheme.
    #[test]
    fn get_socket_address_none() {
        let uris = [
            "http://127.0.0.1/foo",
            "http://127.0.0.1:/foo",
            "http://[::1]/foo",
            "https://127.0.0.1/foo",
        ];
        for uri in uris {
            assert_eq!(get_socket_address(uri), None);
        }
    }

    /// `http` URIs with an explicit port map to the corresponding address,
    /// for both IPv4 and IPv6 literals.
    #[test]
    fn get_socket_address_some() {
        let loopback_v4 = SocketAddr::from((Ipv4Addr::LOCALHOST, 1234));
        let loopback_v6 = SocketAddr::from((Ipv6Addr::LOCALHOST, 8888));
        let public_v4 = SocketAddr::from((Ipv4Addr::new(8, 8, 8, 8), 1234));
        let public_v6 = SocketAddr::from((
            Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888),
            8888,
        ));

        let cases = [
            ("http://127.0.0.1:1234/foo", loopback_v4),
            ("http://[0:0:0:0:0:0:0:1]:8888/foo", loopback_v6),
            ("http://[::1]:8888/foo", loopback_v6),
            ("http://8.8.8.8:1234/foo", public_v4),
            ("http://[2001:4860:4860::8888]:8888/foo", public_v6),
        ];
        for (uri, expected) in cases {
            assert_eq!(get_socket_address(uri), Some(expected));
        }
    }
}