1use log::{error, info, trace};
14use oauth2::reqwest::http_client;
15use oauth2::{
16 basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge,
17 RedirectUrl, Scope, TokenResponse, TokenUrl,
18};
19use std::io;
20use std::time::{Duration, Instant};
21use std::{
22 io::{BufRead, BufReader, Write},
23 net::{SocketAddr, TcpListener},
24 sync::mpsc,
25};
26use thiserror::Error;
27use url::Url;
28
29#[derive(Debug, Error)]
30pub enum OAuthError {
31 #[error("Unable to parse redirect URI {uri} ({e})")]
32 AuthCodeBadUri { uri: String, e: url::ParseError },
33
34 #[error("Auth code param not found in URI {uri}")]
35 AuthCodeNotFound { uri: String },
36
37 #[error("Failed to read redirect URI from stdin")]
38 AuthCodeStdinRead,
39
40 #[error("Failed to bind server to {addr} ({e})")]
41 AuthCodeListenerBind { addr: SocketAddr, e: io::Error },
42
43 #[error("Listener terminated without accepting a connection")]
44 AuthCodeListenerTerminated,
45
46 #[error("Failed to read redirect URI from HTTP request")]
47 AuthCodeListenerRead,
48
49 #[error("Failed to parse redirect URI from HTTP request")]
50 AuthCodeListenerParse,
51
52 #[error("Failed to write HTTP response")]
53 AuthCodeListenerWrite,
54
55 #[error("Invalid Spotify OAuth URI")]
56 InvalidSpotifyUri,
57
58 #[error("Invalid Redirect URI {uri} ({e})")]
59 InvalidRedirectUri { uri: String, e: url::ParseError },
60
61 #[error("Failed to receive code")]
62 Recv,
63
64 #[error("Failed to exchange code for access token ({e})")]
65 ExchangeCode { e: String },
66}
67
/// Result of a successful OAuth token exchange.
#[derive(Debug)]
pub struct OAuthToken {
    /// Secret value of the access token.
    pub access_token: String,
    /// Secret value of the refresh token; empty when the response carried none.
    pub refresh_token: String,
    /// Point in time after which the access token should be considered expired.
    pub expires_at: Instant,
    /// Debug rendering of the token type reported by the token endpoint.
    pub token_type: String,
    /// Scopes reported by the token endpoint, or the requested scopes if none were reported.
    pub scopes: Vec<String>,
}
76
77fn get_code(redirect_url: &str) -> Result<AuthorizationCode, OAuthError> {
79 let url = Url::parse(redirect_url).map_err(|e| OAuthError::AuthCodeBadUri {
80 uri: redirect_url.to_string(),
81 e,
82 })?;
83 let code = url
84 .query_pairs()
85 .find(|(key, _)| key == "code")
86 .map(|(_, code)| AuthorizationCode::new(code.into_owned()))
87 .ok_or(OAuthError::AuthCodeNotFound {
88 uri: redirect_url.to_string(),
89 })?;
90
91 Ok(code)
92}
93
94fn get_authcode_stdin() -> Result<AuthorizationCode, OAuthError> {
96 println!("Provide redirect URL");
97 let mut buffer = String::new();
98 let stdin = io::stdin();
99 stdin
100 .read_line(&mut buffer)
101 .map_err(|_| OAuthError::AuthCodeStdinRead)?;
102
103 get_code(buffer.trim())
104}
105
106fn get_authcode_listener(socket_address: SocketAddr) -> Result<AuthorizationCode, OAuthError> {
108 let listener =
109 TcpListener::bind(socket_address).map_err(|e| OAuthError::AuthCodeListenerBind {
110 addr: socket_address,
111 e,
112 })?;
113 info!("OAuth server listening on {:?}", socket_address);
114
115 let mut stream = listener
117 .incoming()
118 .flatten()
119 .next()
120 .ok_or(OAuthError::AuthCodeListenerTerminated)?;
121 let mut reader = BufReader::new(&stream);
122 let mut request_line = String::new();
123 reader
124 .read_line(&mut request_line)
125 .map_err(|_| OAuthError::AuthCodeListenerRead)?;
126
127 let redirect_url = request_line
128 .split_whitespace()
129 .nth(1)
130 .ok_or(OAuthError::AuthCodeListenerParse)?;
131 let code = get_code(&("http://localhost".to_string() + redirect_url));
132
133 let message = "Go back to your terminal :)";
134 let response = format!(
135 "HTTP/1.1 200 OK\r\ncontent-length: {}\r\n\r\n{}",
136 message.len(),
137 message
138 );
139 stream
140 .write_all(response.as_bytes())
141 .map_err(|_| OAuthError::AuthCodeListenerWrite)?;
142
143 code
144}
145
146fn get_socket_address(redirect_uri: &str) -> Option<SocketAddr> {
149 let url = match Url::parse(redirect_uri) {
150 Ok(u) if u.scheme() == "http" && u.port().is_some() => u,
151 _ => return None,
152 };
153 let socket_addr = match url.socket_addrs(|| None) {
154 Ok(mut addrs) => addrs.pop(),
155 _ => None,
156 };
157 if let Some(s) = socket_addr {
158 if s.ip().is_loopback() {
159 return socket_addr;
160 }
161 }
162 None
163}
164
165pub fn get_access_token(
168 client_id: &str,
169 redirect_uri: &str,
170 scopes: Vec<&str>,
171) -> Result<OAuthToken, OAuthError> {
172 let auth_url = AuthUrl::new("https://accounts.spotify.com/authorize".to_string())
173 .map_err(|_| OAuthError::InvalidSpotifyUri)?;
174 let token_url = TokenUrl::new("https://accounts.spotify.com/api/token".to_string())
175 .map_err(|_| OAuthError::InvalidSpotifyUri)?;
176 let redirect_url =
177 RedirectUrl::new(redirect_uri.to_string()).map_err(|e| OAuthError::InvalidRedirectUri {
178 uri: redirect_uri.to_string(),
179 e,
180 })?;
181 let client = BasicClient::new(
182 ClientId::new(client_id.to_string()),
183 None,
184 auth_url,
185 Some(token_url),
186 )
187 .set_redirect_uri(redirect_url);
188
189 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
190
191 let request_scopes: Vec<oauth2::Scope> = scopes
194 .clone()
195 .into_iter()
196 .map(|s| Scope::new(s.into()))
197 .collect();
198 let (auth_url, _) = client
199 .authorize_url(CsrfToken::new_random)
200 .add_scopes(request_scopes)
201 .set_pkce_challenge(pkce_challenge)
202 .url();
203
204 println!("Browse to: {}", auth_url);
205
206 let code = match get_socket_address(redirect_uri) {
207 Some(addr) => get_authcode_listener(addr),
208 _ => get_authcode_stdin(),
209 }?;
210 trace!("Exchange {code:?} for access token");
211
212 let (tx, rx) = mpsc::channel();
214 std::thread::spawn(move || {
215 let resp = client
216 .exchange_code(code)
217 .set_pkce_verifier(pkce_verifier)
218 .request(http_client);
219 if let Err(e) = tx.send(resp) {
220 error!("OAuth channel send error: {e}");
221 }
222 });
223 let token_response = rx.recv().map_err(|_| OAuthError::Recv)?;
224 let token = token_response.map_err(|e| OAuthError::ExchangeCode { e: e.to_string() })?;
225 trace!("Obtained new access token: {token:?}");
226
227 let token_scopes: Vec<String> = match token.scopes() {
228 Some(s) => s.iter().map(|s| s.to_string()).collect(),
229 _ => scopes.into_iter().map(|s| s.to_string()).collect(),
230 };
231 let refresh_token = match token.refresh_token() {
232 Some(t) => t.secret().to_string(),
233 _ => "".to_string(), };
235 Ok(OAuthToken {
236 access_token: token.access_token().secret().to_string(),
237 refresh_token,
238 expires_at: Instant::now()
239 + token
240 .expires_in()
241 .unwrap_or_else(|| Duration::from_secs(3600)),
242 token_type: format!("{:?}", token.token_type()).to_string(), scopes: token_scopes,
244 })
245}
246
#[cfg(test)]
mod test {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

    use super::*;

    /// Redirect URIs that must not be served by the local listener.
    #[test]
    fn get_socket_address_none() {
        let rejected = [
            // No port.
            "http://127.0.0.1/foo",
            "http://127.0.0.1:/foo",
            "http://[::1]/foo",
            // Not a loopback address.
            "http://56.0.0.1:1234/foo",
            "http://[3ffe:2a00:100:7031::1]:1234/foo",
            // Not plain http.
            "https://127.0.0.1/foo",
        ];

        for uri in rejected {
            assert_eq!(get_socket_address(uri), None);
        }
    }

    /// Loopback http URIs with a port map to the matching socket address.
    #[test]
    fn get_socket_address_localhost() {
        let localhost_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1234);
        let localhost_v6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8888);

        let accepted = [
            ("http://127.0.0.1:1234/foo", localhost_v4),
            ("http://[0:0:0:0:0:0:0:1]:8888/foo", localhost_v6),
            ("http://[::1]:8888/foo", localhost_v6),
        ];

        for (uri, expected) in accepted {
            assert_eq!(get_socket_address(uri), Some(expected));
        }
    }
}