1use std::{
2 collections::HashMap,
3 future::Future,
4 io,
5 pin::Pin,
6 process::exit,
7 sync::{Arc, Weak},
8 task::{Context, Poll},
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
12use byteorder::{BigEndian, ByteOrder};
13use bytes::Bytes;
14use futures_core::TryStream;
15use futures_util::StreamExt;
16use librespot_protocol::authentication::AuthenticationType;
17use num_traits::FromPrimitive;
18use once_cell::sync::OnceCell;
19use parking_lot::RwLock;
20use pin_project_lite::pin_project;
21use quick_xml::events::Event;
22use thiserror::Error;
23use tokio::{
24 sync::mpsc,
25 time::{sleep, Duration as TokioDuration, Instant as TokioInstant, Sleep},
26};
27use tokio_stream::wrappers::UnboundedReceiverStream;
28
29use crate::{
30 apresolve::{ApResolver, SocketAddress},
31 audio_key::AudioKeyManager,
32 authentication::Credentials,
33 cache::Cache,
34 channel::ChannelManager,
35 config::SessionConfig,
36 connection::{self, AuthenticationError, Transport},
37 http_client::HttpClient,
38 login5::Login5Manager,
39 mercury::MercuryManager,
40 packet::PacketType,
41 protocol::keyexchange::ErrorCode,
42 spclient::SpClient,
43 token::TokenProvider,
44 Error,
45};
46
/// Errors produced by [`Session`] itself (as opposed to its sub-components).
///
/// Converted into the crate-wide [`Error`] type by the `From` impl below this enum.
#[derive(Debug, Error)]
pub enum SessionError {
    /// Authentication against the access point failed; the message is forwarded unchanged.
    #[error(transparent)]
    AuthenticationError(#[from] AuthenticationError),
    /// An I/O error occurred while creating or running the session.
    #[error("Cannot create session: {0}")]
    IoError(#[from] io::Error),
    /// No packet sender is installed (returned by `send_packet` before a successful
    /// `connect`, and by `connect` when the sender could not be installed).
    #[error("Session is not connected")]
    NotConnected,
    /// A packet with the given command byte is unknown or not handled.
    #[error("packet {0} unknown")]
    Packet(u8),
}
58
59impl From<SessionError> for Error {
60 fn from(err: SessionError) -> Self {
61 match err {
62 SessionError::AuthenticationError(_) => Error::unauthenticated(err),
63 SessionError::IoError(_) => Error::unavailable(err),
64 SessionError::NotConnected => Error::unavailable(err),
65 SessionError::Packet(_) => Error::unimplemented(err),
66 }
67 }
68}
69
/// Key/value user attributes (e.g. the `ProductInfo` entries received from the server).
pub type UserAttributes = HashMap<String, String>;
71
/// Information about the logged-in user, as collected by the session.
#[derive(Debug, Clone, Default)]
pub struct UserData {
    /// Country reported by the server in a `CountryCode` packet.
    pub country: String,
    /// Username taken from the reusable credentials after authentication.
    pub canonical_username: String,
    /// Attributes parsed from `ProductInfo` or set via `set_user_attribute(s)`.
    pub attributes: UserAttributes,
}
78
/// Mutable per-session state; kept behind an `RwLock` in `SessionInternal`.
#[derive(Debug, Clone, Default)]
struct SessionData {
    // Initialised from `SessionConfig::client_id`; can be overwritten via `set_client_id`.
    client_id: String,
    client_name: String,
    client_brand_name: String,
    client_model_name: String,
    // Only set through `set_connection_id`; its origin is not visible here.
    connection_id: String,
    // Auth data from the reusable credentials obtained during `connect`.
    auth_data: Vec<u8>,
    // Server timestamp minus local Unix time in seconds, updated on every `Ping`.
    time_delta: i64,
    // Set to true by `Session::shutdown`.
    invalid: bool,
    user_data: UserData,
}
91
/// Shared state behind a [`Session`] handle.
///
/// NOTE: field declaration order is also the drop order.
struct SessionInternal {
    config: SessionConfig,
    data: RwLock<SessionData>,

    http_client: HttpClient,
    // Installed once by `Session::connect`; carries (packet type, payload) to the sender task.
    tx_connection: OnceCell<mpsc::UnboundedSender<(u8, Vec<u8>)>>,

    // Sub-components, each created lazily by the matching `Session` accessor.
    apresolver: OnceCell<ApResolver>,
    audio_key: OnceCell<AudioKeyManager>,
    channel: OnceCell<ChannelManager>,
    mercury: OnceCell<MercuryManager>,
    spclient: OnceCell<SpClient>,
    token_provider: OnceCell<TokenProvider>,
    login5: OnceCell<Login5Manager>,
    cache: Option<Arc<Cache>>,

    // Runtime handle captured in `Session::new`, used by `Session::spawn`.
    handle: tokio::runtime::Handle,
}
110
/// A cheaply clonable handle to a Spotify session; clones share the same internal state.
#[derive(Clone)]
pub struct Session(Arc<SessionInternal>);
122
123impl Session {
124 pub fn new(config: SessionConfig, cache: Option<Cache>) -> Self {
125 let http_client = HttpClient::new(config.proxy.as_ref());
126
127 debug!("new Session");
128
129 let session_data = SessionData {
130 client_id: config.client_id.clone(),
131 ..SessionData::default()
132 };
133
134 Self(Arc::new(SessionInternal {
135 config,
136 data: RwLock::new(session_data),
137 http_client,
138 tx_connection: OnceCell::new(),
139 cache: cache.map(Arc::new),
140 apresolver: OnceCell::new(),
141 audio_key: OnceCell::new(),
142 channel: OnceCell::new(),
143 mercury: OnceCell::new(),
144 spclient: OnceCell::new(),
145 token_provider: OnceCell::new(),
146 login5: OnceCell::new(),
147 handle: tokio::runtime::Handle::current(),
148 }))
149 }
150
151 async fn connect_inner(
152 &self,
153 access_point: &SocketAddress,
154 credentials: Credentials,
155 ) -> Result<(Credentials, Transport), Error> {
156 const MAX_RETRIES: u8 = 1;
157 let mut transport = connection::connect_with_retry(
158 &access_point.0,
159 access_point.1,
160 self.config().proxy.as_ref(),
161 MAX_RETRIES,
162 )
163 .await?;
164 let mut reusable_credentials = connection::authenticate(
165 &mut transport,
166 credentials.clone(),
167 &self.config().device_id,
168 )
169 .await?;
170
171 if credentials.auth_type == AuthenticationType::AUTHENTICATION_SPOTIFY_TOKEN {
173 trace!(
174 "Reconnect using stored credentials as token authed sessions cannot use keymaster."
175 );
176 transport = connection::connect_with_retry(
177 &access_point.0,
178 access_point.1,
179 self.config().proxy.as_ref(),
180 MAX_RETRIES,
181 )
182 .await?;
183 reusable_credentials = connection::authenticate(
184 &mut transport,
185 reusable_credentials.clone(),
186 &self.config().device_id,
187 )
188 .await?;
189 }
190
191 Ok((reusable_credentials, transport))
192 }
193
194 pub async fn connect(
195 &self,
196 credentials: Credentials,
197 store_credentials: bool,
198 ) -> Result<(), Error> {
199 const MAX_AP_TRIES: u8 = 6;
201 let mut num_ap_tries = 0;
202 let (reusable_credentials, transport) = loop {
203 let ap = self.apresolver().resolve("accesspoint").await?;
204 info!("Connecting to AP \"{}:{}\"", ap.0, ap.1);
205 match self.connect_inner(&ap, credentials.clone()).await {
206 Ok(ct) => break ct,
207 Err(e) => {
208 num_ap_tries += 1;
209 if MAX_AP_TRIES == num_ap_tries {
210 error!("Tried too many access points");
211 return Err(e);
212 }
213 if let Some(AuthenticationError::LoginFailed(ErrorCode::TryAnotherAP)) =
214 e.error.downcast_ref::<AuthenticationError>()
215 {
216 warn!("Instructed to try another access point...");
217 continue;
218 } else if let Some(AuthenticationError::LoginFailed(..)) =
219 e.error.downcast_ref::<AuthenticationError>()
220 {
221 return Err(e);
222 } else {
223 warn!("Try another access point...");
224 continue;
225 }
226 }
227 }
228 };
229
230 let username = reusable_credentials
231 .username
232 .as_ref()
233 .map_or("UNKNOWN", |s| s.as_str());
234 info!("Authenticated as '{username}' !");
235 self.set_username(username);
236 self.set_auth_data(&reusable_credentials.auth_data);
237 if let Some(cache) = self.cache() {
238 if store_credentials {
239 let cred_changed = cache
240 .credentials()
241 .map(|c| c != reusable_credentials)
242 .unwrap_or(true);
243 if cred_changed {
244 cache.save_credentials(&reusable_credentials);
245 }
246 }
247 }
248
249 let (tx_connection, rx_connection) = mpsc::unbounded_channel();
252 self.0
253 .tx_connection
254 .set(tx_connection)
255 .map_err(|_| SessionError::NotConnected)?;
256
257 let (sink, stream) = transport.split();
258 let sender_task = UnboundedReceiverStream::new(rx_connection)
259 .map(Ok)
260 .forward(sink);
261 let session_weak = self.weak();
262 tokio::spawn(async move {
263 if let Err(e) = sender_task.await {
264 error!("{}", e);
265 if let Some(session) = session_weak.try_upgrade() {
266 if !session.is_invalid() {
267 session.shutdown();
268 }
269 }
270 }
271 });
272
273 tokio::spawn(DispatchTask::new(self.weak(), stream));
274
275 Ok(())
276 }
277
278 pub fn apresolver(&self) -> &ApResolver {
279 self.0
280 .apresolver
281 .get_or_init(|| ApResolver::new(self.weak()))
282 }
283
284 pub fn audio_key(&self) -> &AudioKeyManager {
285 self.0
286 .audio_key
287 .get_or_init(|| AudioKeyManager::new(self.weak()))
288 }
289
290 pub fn channel(&self) -> &ChannelManager {
291 self.0
292 .channel
293 .get_or_init(|| ChannelManager::new(self.weak()))
294 }
295
296 pub fn http_client(&self) -> &HttpClient {
297 &self.0.http_client
298 }
299
300 pub fn mercury(&self) -> &MercuryManager {
301 self.0
302 .mercury
303 .get_or_init(|| MercuryManager::new(self.weak()))
304 }
305
306 pub fn spclient(&self) -> &SpClient {
307 self.0.spclient.get_or_init(|| SpClient::new(self.weak()))
308 }
309
310 pub fn token_provider(&self) -> &TokenProvider {
311 self.0
312 .token_provider
313 .get_or_init(|| TokenProvider::new(self.weak()))
314 }
315
316 pub fn login5(&self) -> &Login5Manager {
317 self.0
318 .login5
319 .get_or_init(|| Login5Manager::new(self.weak()))
320 }
321
322 pub fn time_delta(&self) -> i64 {
323 self.0.data.read().time_delta
324 }
325
326 pub fn spawn<T>(&self, task: T)
327 where
328 T: Future + Send + 'static,
329 T::Output: Send + 'static,
330 {
331 self.0.handle.spawn(task);
332 }
333
334 fn debug_info(&self) {
335 debug!(
336 "Session strong={} weak={}",
337 Arc::strong_count(&self.0),
338 Arc::weak_count(&self.0)
339 );
340 }
341
342 fn check_catalogue(attributes: &UserAttributes) {
343 if let Some(account_type) = attributes.get("type") {
344 if account_type != "premium" {
345 error!("librespot does not support {:?} accounts.", account_type);
346 info!("Please support Spotify and your artists and sign up for a premium account.");
347
348 exit(1);
350 }
351 }
352 }
353
354 pub fn send_packet(&self, cmd: PacketType, data: Vec<u8>) -> Result<(), Error> {
355 match self.0.tx_connection.get() {
356 Some(tx) => Ok(tx.send((cmd as u8, data))?),
357 None => Err(SessionError::NotConnected.into()),
358 }
359 }
360
361 pub fn cache(&self) -> Option<&Arc<Cache>> {
362 self.0.cache.as_ref()
363 }
364
365 pub fn config(&self) -> &SessionConfig {
366 &self.0.config
367 }
368
369 pub fn user_data(&self) -> UserData {
373 self.0.data.read().user_data.clone()
374 }
375
376 pub fn device_id(&self) -> &str {
377 &self.config().device_id
378 }
379
380 pub fn client_id(&self) -> String {
381 self.0.data.read().client_id.clone()
382 }
383
384 pub fn set_client_id(&self, client_id: &str) {
385 client_id.clone_into(&mut self.0.data.write().client_id);
386 }
387
388 pub fn client_name(&self) -> String {
389 self.0.data.read().client_name.clone()
390 }
391
392 pub fn set_client_name(&self, client_name: &str) {
393 client_name.clone_into(&mut self.0.data.write().client_name);
394 }
395
396 pub fn client_brand_name(&self) -> String {
397 self.0.data.read().client_brand_name.clone()
398 }
399
400 pub fn set_client_brand_name(&self, client_brand_name: &str) {
401 client_brand_name.clone_into(&mut self.0.data.write().client_brand_name);
402 }
403
404 pub fn client_model_name(&self) -> String {
405 self.0.data.read().client_model_name.clone()
406 }
407
408 pub fn set_client_model_name(&self, client_model_name: &str) {
409 client_model_name.clone_into(&mut self.0.data.write().client_model_name);
410 }
411
412 pub fn connection_id(&self) -> String {
413 self.0.data.read().connection_id.clone()
414 }
415
416 pub fn set_connection_id(&self, connection_id: &str) {
417 connection_id.clone_into(&mut self.0.data.write().connection_id);
418 }
419
420 pub fn username(&self) -> String {
421 self.0.data.read().user_data.canonical_username.clone()
422 }
423
424 pub fn set_username(&self, username: &str) {
425 username.clone_into(&mut self.0.data.write().user_data.canonical_username);
426 }
427
428 pub fn auth_data(&self) -> Vec<u8> {
429 self.0.data.read().auth_data.clone()
430 }
431
432 pub fn set_auth_data(&self, auth_data: &[u8]) {
433 self.0.data.write().auth_data = auth_data.to_owned();
434 }
435
436 pub fn country(&self) -> String {
437 self.0.data.read().user_data.country.clone()
438 }
439
440 pub fn filter_explicit_content(&self) -> bool {
441 match self.get_user_attribute("filter-explicit-content") {
442 Some(value) => matches!(&*value, "1"),
443 None => false,
444 }
445 }
446
447 pub fn autoplay(&self) -> bool {
448 if let Some(overide) = self.config().autoplay {
449 return overide;
450 }
451
452 match self.get_user_attribute("autoplay") {
453 Some(value) => matches!(&*value, "1"),
454 None => false,
455 }
456 }
457
458 pub fn set_user_attribute(&self, key: &str, value: &str) -> Option<String> {
459 let mut dummy_attributes = UserAttributes::new();
460 dummy_attributes.insert(key.to_owned(), value.to_owned());
461 Self::check_catalogue(&dummy_attributes);
462
463 self.0
464 .data
465 .write()
466 .user_data
467 .attributes
468 .insert(key.to_owned(), value.to_owned())
469 }
470
471 pub fn set_user_attributes(&self, attributes: UserAttributes) {
472 Self::check_catalogue(&attributes);
473
474 self.0.data.write().user_data.attributes.extend(attributes)
475 }
476
477 pub fn get_user_attribute(&self, key: &str) -> Option<String> {
478 self.0.data.read().user_data.attributes.get(key).cloned()
479 }
480
481 fn weak(&self) -> SessionWeak {
482 SessionWeak(Arc::downgrade(&self.0))
483 }
484
485 pub fn shutdown(&self) {
486 debug!("Shutdown: Invalidating session");
487 self.0.data.write().invalid = true;
488 self.mercury().shutdown();
489 self.channel().shutdown();
490 }
491
492 pub fn is_invalid(&self) -> bool {
493 self.0.data.read().invalid
494 }
495}
496
/// A non-owning handle to a [`Session`], held by sub-components and background tasks.
#[derive(Clone)]
pub struct SessionWeak(Weak<SessionInternal>);
499
500impl SessionWeak {
501 fn try_upgrade(&self) -> Option<Session> {
502 self.0.upgrade().map(Session)
503 }
504
505 pub(crate) fn upgrade(&self) -> Session {
506 self.try_upgrade()
507 .expect("session was dropped and so should have this component")
508 }
509}
510
// Logs when the shared session state is dropped, i.e. when the last strong `Session`
// handle (the `Arc` inside `Session`) goes away.
impl Drop for SessionInternal {
    fn drop(&mut self) {
        debug!("drop Session");
    }
}
516
/// State of the keep-alive exchange with the access point, driven by `DispatchTask`.
#[derive(Clone, Copy, Default, Debug, PartialEq)]
enum KeepAliveState {
    /// Waiting for the server's next `Ping`; the initial state and the state after a `PongAck`.
    #[default]
    ExpectingPing,

    /// A `Ping` was received; a `Pong` is sent when the `PONG_DELAY` timer expires.
    PendingPong,

    /// A `Pong` was sent; waiting for the server's `PongAck`.
    ExpectingPongAck,
}
529
/// How long to wait for the first `Ping` after the dispatcher starts.
const INITIAL_PING_TIMEOUT: TokioDuration = TokioDuration::from_secs(20);
/// How long to wait for the next `Ping` after a `PongAck`.
const PING_TIMEOUT: TokioDuration = TokioDuration::from_secs(80);
/// Delay between receiving a `Ping` and answering it with a `Pong`.
const PONG_DELAY: TokioDuration = TokioDuration::from_secs(60);
/// How long to wait for a `PongAck` after sending a `Pong`.
const PONG_ACK_TIMEOUT: TokioDuration = TokioDuration::from_secs(20);
534
535impl KeepAliveState {
536 fn debug(&self, sleep: &Sleep) {
537 let delay = sleep
538 .deadline()
539 .checked_duration_since(TokioInstant::now())
540 .map(|t| t.as_secs_f64())
541 .unwrap_or(f64::INFINITY);
542
543 trace!("keep-alive state: {:?}, timeout in {:.1}", self, delay);
544 }
545}
546
pin_project! {
    // Future that reads packets from the access point connection, dispatches them and runs
    // the keep-alive state machine; spawned by `Session::connect`.
    struct DispatchTask<S>
    where
        S: TryStream<Ok = (u8, Bytes)>
    {
        // Weak so that the task does not keep the session alive on its own.
        session: SessionWeak,
        keep_alive_state: KeepAliveState,
        // Incoming (packet type, payload) pairs.
        #[pin]
        stream: S,
        // Deadline for the current keep-alive state; reset on every transition.
        #[pin]
        timeout: Sleep,
    }

    impl<S> PinnedDrop for DispatchTask<S>
    where
        S: TryStream<Ok = (u8, Bytes)>
    {
        fn drop(_this: Pin<&mut Self>) {
            debug!("drop Dispatch");
        }
    }
}
569
570impl<S> DispatchTask<S>
571where
572 S: TryStream<Ok = (u8, Bytes)>,
573{
574 fn new(session: SessionWeak, stream: S) -> Self {
575 Self {
576 session,
577 keep_alive_state: KeepAliveState::ExpectingPing,
578 stream,
579 timeout: sleep(INITIAL_PING_TIMEOUT),
580 }
581 }
582
583 fn dispatch(
584 mut self: Pin<&mut Self>,
585 session: &Session,
586 cmd: u8,
587 data: Bytes,
588 ) -> Result<(), Error> {
589 use KeepAliveState::*;
590 use PacketType::*;
591
592 let packet_type = FromPrimitive::from_u8(cmd);
593 let cmd = match packet_type {
594 Some(cmd) => cmd,
595 None => {
596 trace!("Ignoring unknown packet {:x}", cmd);
597 return Err(SessionError::Packet(cmd).into());
598 }
599 };
600
601 match packet_type {
602 Some(Ping) => {
603 trace!("Received Ping");
604 if self.keep_alive_state != ExpectingPing {
605 warn!("Received unexpected Ping from server")
606 }
607 let mut this = self.as_mut().project();
608 *this.keep_alive_state = PendingPong;
609 this.timeout
610 .as_mut()
611 .reset(TokioInstant::now() + PONG_DELAY);
612 this.keep_alive_state.debug(&this.timeout);
613
614 let server_timestamp = BigEndian::read_u32(data.as_ref()) as i64;
615 let timestamp = SystemTime::now()
616 .duration_since(UNIX_EPOCH)
617 .unwrap_or(Duration::ZERO)
618 .as_secs() as i64;
619 {
620 let mut data = session.0.data.write();
621 data.time_delta = server_timestamp.saturating_sub(timestamp);
622 }
623
624 session.debug_info();
625
626 Ok(())
627 }
628 Some(PongAck) => {
629 trace!("Received PongAck");
630 if self.keep_alive_state != ExpectingPongAck {
631 warn!("Received unexpected PongAck from server")
632 }
633 let mut this = self.as_mut().project();
634 *this.keep_alive_state = ExpectingPing;
635 this.timeout
636 .as_mut()
637 .reset(TokioInstant::now() + PING_TIMEOUT);
638 this.keep_alive_state.debug(&this.timeout);
639
640 Ok(())
641 }
642 Some(CountryCode) => {
643 let country = String::from_utf8(data.as_ref().to_owned())?;
644 info!("Country: {:?}", country);
645 session.0.data.write().user_data.country = country;
646 Ok(())
647 }
648 Some(StreamChunkRes) | Some(ChannelError) => session.channel().dispatch(cmd, data),
649 Some(AesKey) | Some(AesKeyError) => session.audio_key().dispatch(cmd, data),
650 Some(MercuryReq) | Some(MercurySub) | Some(MercuryUnsub) | Some(MercuryEvent) => {
651 session.mercury().dispatch(cmd, data)
652 }
653 Some(ProductInfo) => {
654 let data = std::str::from_utf8(&data)?;
655 let mut reader = quick_xml::Reader::from_str(data);
656
657 let mut buf = Vec::new();
658 let mut current_element = String::new();
659 let mut user_attributes: UserAttributes = HashMap::new();
660
661 loop {
662 match reader.read_event_into(&mut buf) {
663 Ok(Event::Start(ref element)) => {
664 std::str::from_utf8(element)?.clone_into(&mut current_element)
665 }
666 Ok(Event::End(_)) => {
667 current_element = String::new();
668 }
669 Ok(Event::Text(ref value)) => {
670 if !current_element.is_empty() {
671 let _ = user_attributes
672 .insert(current_element.clone(), value.unescape()?.to_string());
673 }
674 }
675 Ok(Event::Eof) => break,
676 Ok(_) => (),
677 Err(e) => warn!(
678 "Error parsing XML at position {}: {:?}",
679 reader.buffer_position(),
680 e
681 ),
682 }
683 }
684
685 trace!("Received product info: {:#?}", user_attributes);
686 Session::check_catalogue(&user_attributes);
687
688 session.0.data.write().user_data.attributes = user_attributes;
689 Ok(())
690 }
691 Some(SecretBlock)
692 | Some(LegacyWelcome)
693 | Some(UnknownDataAllZeros)
694 | Some(LicenseVersion) => Ok(()),
695 _ => {
696 trace!("Ignoring {:?} packet with data {:#?}", cmd, data);
697 Err(SessionError::Packet(cmd as u8).into())
698 }
699 }
700 }
701}
702
703impl<S> Future for DispatchTask<S>
704where
705 S: TryStream<Ok = (u8, Bytes), Error = std::io::Error>,
706 <S as TryStream>::Ok: std::fmt::Debug,
707{
708 type Output = Result<(), S::Error>;
709
710 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
711 use KeepAliveState::*;
712
713 let session = match self.session.try_upgrade() {
714 Some(session) => session,
715 None => return Poll::Ready(Ok(())),
716 };
717
718 loop {
720 match self.as_mut().project().stream.try_poll_next(cx) {
721 Poll::Ready(Some(Ok((cmd, data)))) => {
722 let result = self.as_mut().dispatch(&session, cmd, data);
723 if let Err(e) = result {
724 debug!("could not dispatch command: {}", e);
725 }
726 }
727 Poll::Ready(None) => {
728 warn!("Connection to server closed.");
729 session.shutdown();
730 return Poll::Ready(Ok(()));
731 }
732 Poll::Ready(Some(Err(e))) => {
733 error!("Connection to server closed.");
734 session.shutdown();
735 return Poll::Ready(Err(e));
736 }
737 Poll::Pending => break,
738 }
739 }
740
741 let mut this = self.as_mut().project();
759 if let Poll::Ready(()) = this.timeout.as_mut().poll(cx) {
760 match this.keep_alive_state {
761 ExpectingPing | ExpectingPongAck => {
762 if !session.is_invalid() {
763 session.shutdown();
764 }
765 return Poll::Ready(Err(io::Error::new(
767 io::ErrorKind::TimedOut,
768 format!(
769 "session lost connection to server ({:?})",
770 this.keep_alive_state
771 ),
772 )));
773 }
774 PendingPong => {
775 trace!("Sending Pong");
776 let _ = session.send_packet(PacketType::Pong, vec![0, 0, 0, 0]);
779 *this.keep_alive_state = ExpectingPongAck;
780 this.timeout
781 .as_mut()
782 .reset(TokioInstant::now() + PONG_ACK_TIMEOUT);
783 this.keep_alive_state.debug(&this.timeout);
784 }
785 }
786 }
787
788 Poll::Pending
789 }
790}