librespot_oauth/
lib.rs

1//! Provides a Spotify access token using the OAuth authorization code flow
2//! with PKCE.
3//!
4//! Assuming sufficient scopes, the returned access token may be used with Spotify's
5//! Web API, and/or to establish a new Session with [`librespot_core`].
6//!
7//! The authorization code flow is an interactive process which requires a web browser
8//! to complete. The resulting code must then be provided back from the browser to this
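//! library for exchange into an access token. Providing the code can be automatic, via
//! a temporary HTTP server spawned when the redirect URI is an `http` loopback address
//! with an explicit port (mimicking Spotify's client), or manual, by pasting the full
//! redirect URL into stdin. The latter is appropriate for headless systems.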
12
13use log::{error, info, trace};
14use oauth2::reqwest::http_client;
15use oauth2::{
16    basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge,
17    RedirectUrl, Scope, TokenResponse, TokenUrl,
18};
19use std::io;
20use std::time::{Duration, Instant};
21use std::{
22    io::{BufRead, BufReader, Write},
23    net::{SocketAddr, TcpListener},
24    sync::mpsc,
25};
26use thiserror::Error;
27use url::Url;
28
/// Errors that can occur while obtaining a Spotify OAuth access token.
#[derive(Debug, Error)]
pub enum OAuthError {
    /// The redirect URI could not be parsed as a URL.
    #[error("Unable to parse redirect URI {uri} ({e})")]
    AuthCodeBadUri { uri: String, e: url::ParseError },

    /// The redirect URI parsed, but has no `code` query parameter.
    #[error("Auth code param not found in URI {uri}")]
    AuthCodeNotFound { uri: String },

    /// Reading the redirect URI from stdin failed (the underlying I/O error is discarded).
    #[error("Failed to read redirect URI from stdin")]
    AuthCodeStdinRead,

    /// The local OAuth callback server could not bind to `addr`.
    #[error("Failed to bind server to {addr} ({e})")]
    AuthCodeListenerBind { addr: SocketAddr, e: io::Error },

    /// The callback server's stream of incoming connections ended before one was accepted.
    #[error("Listener terminated without accepting a connection")]
    AuthCodeListenerTerminated,

    /// Reading the HTTP request line from the callback connection failed.
    #[error("Failed to read redirect URI from HTTP request")]
    AuthCodeListenerRead,

    /// The HTTP request line did not contain a request target.
    #[error("Failed to parse redirect URI from HTTP request")]
    AuthCodeListenerParse,

    /// Writing the HTTP response back to the callback connection failed.
    #[error("Failed to write HTTP response")]
    AuthCodeListenerWrite,

    /// One of the built-in Spotify authorization/token endpoint URLs failed to parse.
    #[error("Invalid Spotify OAuth URI")]
    InvalidSpotifyUri,

    /// The caller-supplied redirect URI is not a valid URL.
    #[error("Invalid Redirect URI {uri} ({e})")]
    InvalidRedirectUri { uri: String, e: url::ParseError },

    /// The token-exchange worker thread finished without sending a result
    /// (presumably because it panicked).
    #[error("Failed to receive code")]
    Recv,

    /// The authorization-code-for-token exchange request returned an error;
    /// `e` is that error rendered as a string.
    #[error("Failed to exchange code for access token ({e})")]
    ExchangeCode { e: String },
}
67
/// An access token obtained through the OAuth authorization code flow, plus
/// the metadata returned alongside it.
#[derive(Debug, Clone)]
pub struct OAuthToken {
    /// Bearer secret used to authenticate requests.
    pub access_token: String,
    /// Token used to obtain a new access token; empty if the server did not
    /// return one.
    pub refresh_token: String,
    /// Moment the access token expires. It is computed locally when the
    /// response is received; if the server omits a lifetime, one hour is assumed.
    pub expires_at: Instant,
    /// Textual form of the token type reported by the server (e.g. `Bearer`).
    pub token_type: String,
    /// Scopes granted by the server. If the server does not list them, this
    /// holds the scopes that were requested.
    pub scopes: Vec<String>,
}
76
77/// Return code query-string parameter from the redirect URI.
78fn get_code(redirect_url: &str) -> Result<AuthorizationCode, OAuthError> {
79    let url = Url::parse(redirect_url).map_err(|e| OAuthError::AuthCodeBadUri {
80        uri: redirect_url.to_string(),
81        e,
82    })?;
83    let code = url
84        .query_pairs()
85        .find(|(key, _)| key == "code")
86        .map(|(_, code)| AuthorizationCode::new(code.into_owned()))
87        .ok_or(OAuthError::AuthCodeNotFound {
88            uri: redirect_url.to_string(),
89        })?;
90
91    Ok(code)
92}
93
94/// Prompt for redirect URI on stdin and return auth code.
95fn get_authcode_stdin() -> Result<AuthorizationCode, OAuthError> {
96    println!("Provide redirect URL");
97    let mut buffer = String::new();
98    let stdin = io::stdin();
99    stdin
100        .read_line(&mut buffer)
101        .map_err(|_| OAuthError::AuthCodeStdinRead)?;
102
103    get_code(buffer.trim())
104}
105
106/// Spawn HTTP server at provided socket address to accept OAuth callback and return auth code.
107fn get_authcode_listener(socket_address: SocketAddr) -> Result<AuthorizationCode, OAuthError> {
108    let listener =
109        TcpListener::bind(socket_address).map_err(|e| OAuthError::AuthCodeListenerBind {
110            addr: socket_address,
111            e,
112        })?;
113    info!("OAuth server listening on {:?}", socket_address);
114
115    // The server will terminate itself after collecting the first code.
116    let mut stream = listener
117        .incoming()
118        .flatten()
119        .next()
120        .ok_or(OAuthError::AuthCodeListenerTerminated)?;
121    let mut reader = BufReader::new(&stream);
122    let mut request_line = String::new();
123    reader
124        .read_line(&mut request_line)
125        .map_err(|_| OAuthError::AuthCodeListenerRead)?;
126
127    let redirect_url = request_line
128        .split_whitespace()
129        .nth(1)
130        .ok_or(OAuthError::AuthCodeListenerParse)?;
131    let code = get_code(&("http://localhost".to_string() + redirect_url));
132
133    let message = "Go back to your terminal :)";
134    let response = format!(
135        "HTTP/1.1 200 OK\r\ncontent-length: {}\r\n\r\n{}",
136        message.len(),
137        message
138    );
139    stream
140        .write_all(response.as_bytes())
141        .map_err(|_| OAuthError::AuthCodeListenerWrite)?;
142
143    code
144}
145
146// If the specified `redirect_uri` is HTTP, loopback, and contains a port,
147// then the corresponding socket address is returned.
148fn get_socket_address(redirect_uri: &str) -> Option<SocketAddr> {
149    let url = match Url::parse(redirect_uri) {
150        Ok(u) if u.scheme() == "http" && u.port().is_some() => u,
151        _ => return None,
152    };
153    let socket_addr = match url.socket_addrs(|| None) {
154        Ok(mut addrs) => addrs.pop(),
155        _ => None,
156    };
157    if let Some(s) = socket_addr {
158        if s.ip().is_loopback() {
159            return socket_addr;
160        }
161    }
162    None
163}
164
165/// Obtain a Spotify access token using the authorization code with PKCE OAuth flow.
166/// The redirect_uri must match what is registered to the client ID.
167pub fn get_access_token(
168    client_id: &str,
169    redirect_uri: &str,
170    scopes: Vec<&str>,
171) -> Result<OAuthToken, OAuthError> {
172    let auth_url = AuthUrl::new("https://accounts.spotify.com/authorize".to_string())
173        .map_err(|_| OAuthError::InvalidSpotifyUri)?;
174    let token_url = TokenUrl::new("https://accounts.spotify.com/api/token".to_string())
175        .map_err(|_| OAuthError::InvalidSpotifyUri)?;
176    let redirect_url =
177        RedirectUrl::new(redirect_uri.to_string()).map_err(|e| OAuthError::InvalidRedirectUri {
178            uri: redirect_uri.to_string(),
179            e,
180        })?;
181    let client = BasicClient::new(
182        ClientId::new(client_id.to_string()),
183        None,
184        auth_url,
185        Some(token_url),
186    )
187    .set_redirect_uri(redirect_url);
188
189    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
190
191    // Generate the full authorization URL.
192    // Some of these scopes are unavailable for custom client IDs. Which?
193    let request_scopes: Vec<oauth2::Scope> = scopes
194        .clone()
195        .into_iter()
196        .map(|s| Scope::new(s.into()))
197        .collect();
198    let (auth_url, _) = client
199        .authorize_url(CsrfToken::new_random)
200        .add_scopes(request_scopes)
201        .set_pkce_challenge(pkce_challenge)
202        .url();
203
204    println!("Browse to: {}", auth_url);
205
206    let code = match get_socket_address(redirect_uri) {
207        Some(addr) => get_authcode_listener(addr),
208        _ => get_authcode_stdin(),
209    }?;
210    trace!("Exchange {code:?} for access token");
211
212    // Do this sync in another thread because I am too stupid to make the async version work.
213    let (tx, rx) = mpsc::channel();
214    std::thread::spawn(move || {
215        let resp = client
216            .exchange_code(code)
217            .set_pkce_verifier(pkce_verifier)
218            .request(http_client);
219        if let Err(e) = tx.send(resp) {
220            error!("OAuth channel send error: {e}");
221        }
222    });
223    let token_response = rx.recv().map_err(|_| OAuthError::Recv)?;
224    let token = token_response.map_err(|e| OAuthError::ExchangeCode { e: e.to_string() })?;
225    trace!("Obtained new access token: {token:?}");
226
227    let token_scopes: Vec<String> = match token.scopes() {
228        Some(s) => s.iter().map(|s| s.to_string()).collect(),
229        _ => scopes.into_iter().map(|s| s.to_string()).collect(),
230    };
231    let refresh_token = match token.refresh_token() {
232        Some(t) => t.secret().to_string(),
233        _ => "".to_string(), // Spotify always provides a refresh token.
234    };
235    Ok(OAuthToken {
236        access_token: token.access_token().secret().to_string(),
237        refresh_token,
238        expires_at: Instant::now()
239            + token
240                .expires_in()
241                .unwrap_or_else(|| Duration::from_secs(3600)),
242        token_type: format!("{:?}", token.token_type()).to_string(), // Urgh!?
243        scopes: token_scopes,
244    })
245}
246
#[cfg(test)]
mod test {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

    use super::*;

    /// URIs that must not yield a callback socket address.
    #[test]
    fn get_socket_address_none() {
        let rejected = [
            // No port
            "http://127.0.0.1/foo",
            "http://127.0.0.1:/foo",
            "http://[::1]/foo",
            // Not localhost
            "http://56.0.0.1:1234/foo",
            "http://[3ffe:2a00:100:7031::1]:1234/foo",
            // Not http
            "https://127.0.0.1/foo",
        ];
        for uri in rejected {
            assert_eq!(get_socket_address(uri), None, "{uri}");
        }
    }

    /// Loopback http URIs with a port map to the matching socket address.
    #[test]
    fn get_socket_address_localhost() {
        let v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1234);
        let v6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8888);

        let accepted = [
            ("http://127.0.0.1:1234/foo", v4),
            ("http://[0:0:0:0:0:0:0:1]:8888/foo", v6),
            ("http://[::1]:8888/foo", v6),
        ];
        for (uri, expected) in accepted {
            assert_eq!(get_socket_address(uri), Some(expected), "{uri}");
        }
    }
}
287}