1use std::{
2 collections::HashMap,
3 future::Future,
4 io,
5 pin::Pin,
6 process::exit,
7 sync::OnceLock,
8 sync::{Arc, Weak},
9 task::{Context, Poll},
10 time::{Duration, SystemTime, UNIX_EPOCH},
11};
12
13use crate::dealer::manager::DealerManager;
14use crate::{
15 Error,
16 apresolve::{ApResolver, SocketAddress},
17 audio_key::AudioKeyManager,
18 authentication::Credentials,
19 cache::Cache,
20 channel::ChannelManager,
21 config::SessionConfig,
22 connection::{self, AuthenticationError, Transport},
23 http_client::HttpClient,
24 login5::Login5Manager,
25 mercury::MercuryManager,
26 packet::PacketType,
27 protocol::keyexchange::ErrorCode,
28 spclient::SpClient,
29 token::TokenProvider,
30};
31use byteorder::{BigEndian, ByteOrder};
32use bytes::Bytes;
33use futures_core::TryStream;
34use futures_util::StreamExt;
35use librespot_protocol::authentication::AuthenticationType;
36use num_traits::FromPrimitive;
37use parking_lot::RwLock;
38use pin_project_lite::pin_project;
39use quick_xml::events::Event;
40use thiserror::Error;
41use tokio::{
42 sync::mpsc,
43 time::{Duration as TokioDuration, Instant as TokioInstant, Sleep, sleep},
44};
45use tokio_stream::wrappers::UnboundedReceiverStream;
46use uuid::Uuid;
47
48#[derive(Debug, Error)]
49pub enum SessionError {
50 #[error(transparent)]
51 AuthenticationError(#[from] AuthenticationError),
52 #[error("Cannot create session: {0}")]
53 IoError(#[from] io::Error),
54 #[error("Session is not connected")]
55 NotConnected,
56 #[error("packet {0} unknown")]
57 Packet(u8),
58}
59
60impl From<SessionError> for Error {
61 fn from(err: SessionError) -> Self {
62 match err {
63 SessionError::AuthenticationError(_) => Error::unauthenticated(err),
64 SessionError::IoError(_) => Error::unavailable(err),
65 SessionError::NotConnected => Error::unavailable(err),
66 SessionError::Packet(_) => Error::unimplemented(err),
67 }
68 }
69}
70
71impl From<quick_xml::encoding::EncodingError> for Error {
72 fn from(err: quick_xml::encoding::EncodingError) -> Self {
73 Error::invalid_argument(err)
74 }
75}
76
/// Free-form key/value attributes of the logged-in user (for example the
/// account `type`), keyed by attribute name.
pub type UserAttributes = HashMap<String, String>;

/// Information about the authenticated user, as reported by the server.
#[derive(Debug, Clone, Default)]
pub struct UserData {
    /// Country reported by the server in the `CountryCode` packet.
    pub country: String,
    /// Username returned with the reusable credentials after authentication.
    pub canonical_username: String,
    /// Attributes parsed from the `ProductInfo` packet or set explicitly.
    pub attributes: UserAttributes,
}
85
86#[derive(Debug, Clone, Default)]
87struct SessionData {
88 session_id: String,
89 client_id: String,
90 client_name: String,
91 client_brand_name: String,
92 client_model_name: String,
93 connection_id: String,
94 auth_data: Vec<u8>,
95 time_delta: i64,
96 invalid: bool,
97 user_data: UserData,
98}
99
100struct SessionInternal {
101 config: SessionConfig,
102 data: RwLock<SessionData>,
103
104 http_client: HttpClient,
105 tx_connection: OnceLock<mpsc::UnboundedSender<(u8, Vec<u8>)>>,
106
107 apresolver: OnceLock<ApResolver>,
108 audio_key: OnceLock<AudioKeyManager>,
109 channel: OnceLock<ChannelManager>,
110 mercury: OnceLock<MercuryManager>,
111 dealer: OnceLock<DealerManager>,
112 spclient: OnceLock<SpClient>,
113 token_provider: OnceLock<TokenProvider>,
114 login5: OnceLock<Login5Manager>,
115 cache: Option<Arc<Cache>>,
116
117 handle: tokio::runtime::Handle,
118}
119
120#[derive(Clone)]
130pub struct Session(Arc<SessionInternal>);
131
132impl Session {
133 pub fn new(config: SessionConfig, cache: Option<Cache>) -> Self {
134 let http_client = HttpClient::new(config.proxy.as_ref());
135
136 debug!("new Session");
137
138 let session_data = SessionData {
139 client_id: config.client_id.clone(),
140 session_id: Uuid::new_v4().as_simple().to_string(),
142 ..SessionData::default()
143 };
144
145 Self(Arc::new(SessionInternal {
146 config,
147 data: RwLock::new(session_data),
148 http_client,
149 tx_connection: OnceLock::new(),
150 cache: cache.map(Arc::new),
151 apresolver: OnceLock::new(),
152 audio_key: OnceLock::new(),
153 channel: OnceLock::new(),
154 mercury: OnceLock::new(),
155 dealer: OnceLock::new(),
156 spclient: OnceLock::new(),
157 token_provider: OnceLock::new(),
158 login5: OnceLock::new(),
159 handle: tokio::runtime::Handle::current(),
160 }))
161 }
162
163 async fn connect_inner(
164 &self,
165 access_point: &SocketAddress,
166 credentials: Credentials,
167 ) -> Result<(Credentials, Transport), Error> {
168 const MAX_RETRIES: u8 = 1;
169 let mut transport = connection::connect_with_retry(
170 &access_point.0,
171 access_point.1,
172 self.config().proxy.as_ref(),
173 MAX_RETRIES,
174 )
175 .await?;
176 let mut reusable_credentials = connection::authenticate(
177 &mut transport,
178 credentials.clone(),
179 &self.config().device_id,
180 )
181 .await?;
182
183 if credentials.auth_type == AuthenticationType::AUTHENTICATION_SPOTIFY_TOKEN {
185 trace!(
186 "Reconnect using stored credentials as token authed sessions cannot use keymaster."
187 );
188 transport = connection::connect_with_retry(
189 &access_point.0,
190 access_point.1,
191 self.config().proxy.as_ref(),
192 MAX_RETRIES,
193 )
194 .await?;
195 reusable_credentials = connection::authenticate(
196 &mut transport,
197 reusable_credentials.clone(),
198 &self.config().device_id,
199 )
200 .await?;
201 }
202
203 Ok((reusable_credentials, transport))
204 }
205
206 pub async fn connect(
207 &self,
208 credentials: Credentials,
209 store_credentials: bool,
210 ) -> Result<(), Error> {
211 const MAX_AP_TRIES: u8 = 6;
213 let mut num_ap_tries = 0;
214 let (reusable_credentials, transport) = loop {
215 let ap = self.apresolver().resolve("accesspoint").await?;
216 info!("Connecting to AP \"{}:{}\"", ap.0, ap.1);
217 match self.connect_inner(&ap, credentials.clone()).await {
218 Ok(ct) => break ct,
219 Err(e) => {
220 num_ap_tries += 1;
221 if MAX_AP_TRIES == num_ap_tries {
222 error!("Tried too many access points");
223 return Err(e);
224 }
225 if let Some(AuthenticationError::LoginFailed(ErrorCode::TryAnotherAP)) =
226 e.error.downcast_ref::<AuthenticationError>()
227 {
228 warn!("Instructed to try another access point...");
229 continue;
230 } else if let Some(AuthenticationError::LoginFailed(..)) =
231 e.error.downcast_ref::<AuthenticationError>()
232 {
233 return Err(e);
234 } else {
235 warn!("Try another access point...");
236 continue;
237 }
238 }
239 }
240 };
241
242 let username = reusable_credentials
243 .username
244 .as_ref()
245 .map_or("UNKNOWN", |s| s.as_str());
246 info!("Authenticated as '{username}' !");
247 self.set_username(username);
248 self.set_auth_data(&reusable_credentials.auth_data);
249 if let Some(cache) = self.cache() {
250 if store_credentials {
251 let cred_changed = cache
252 .credentials()
253 .map(|c| c != reusable_credentials)
254 .unwrap_or(true);
255 if cred_changed {
256 cache.save_credentials(&reusable_credentials);
257 }
258 }
259 }
260
261 let (tx_connection, rx_connection) = mpsc::unbounded_channel();
264 self.0
265 .tx_connection
266 .set(tx_connection)
267 .map_err(|_| SessionError::NotConnected)?;
268
269 let (sink, stream) = transport.split();
270 let sender_task = UnboundedReceiverStream::new(rx_connection)
271 .map(Ok)
272 .forward(sink);
273 let session_weak = self.weak();
274 tokio::spawn(async move {
275 if let Err(e) = sender_task.await {
276 error!("{e}");
277 if let Some(session) = session_weak.try_upgrade() {
278 if !session.is_invalid() {
279 session.shutdown();
280 }
281 }
282 }
283 });
284
285 tokio::spawn(DispatchTask::new(self.weak(), stream));
286
287 Ok(())
288 }
289
290 pub fn apresolver(&self) -> &ApResolver {
291 self.0
292 .apresolver
293 .get_or_init(|| ApResolver::new(self.weak()))
294 }
295
296 pub fn audio_key(&self) -> &AudioKeyManager {
297 self.0
298 .audio_key
299 .get_or_init(|| AudioKeyManager::new(self.weak()))
300 }
301
302 pub fn channel(&self) -> &ChannelManager {
303 self.0
304 .channel
305 .get_or_init(|| ChannelManager::new(self.weak()))
306 }
307
308 pub fn http_client(&self) -> &HttpClient {
309 &self.0.http_client
310 }
311
312 pub fn mercury(&self) -> &MercuryManager {
313 self.0
314 .mercury
315 .get_or_init(|| MercuryManager::new(self.weak()))
316 }
317
318 pub fn dealer(&self) -> &DealerManager {
319 self.0
320 .dealer
321 .get_or_init(|| DealerManager::new(self.weak()))
322 }
323
324 pub fn spclient(&self) -> &SpClient {
325 self.0.spclient.get_or_init(|| SpClient::new(self.weak()))
326 }
327
328 pub fn token_provider(&self) -> &TokenProvider {
329 self.0
330 .token_provider
331 .get_or_init(|| TokenProvider::new(self.weak()))
332 }
333
334 pub fn login5(&self) -> &Login5Manager {
335 self.0
336 .login5
337 .get_or_init(|| Login5Manager::new(self.weak()))
338 }
339
340 pub fn time_delta(&self) -> i64 {
341 self.0.data.read().time_delta
342 }
343
344 pub fn spawn<T>(&self, task: T)
345 where
346 T: Future + Send + 'static,
347 T::Output: Send + 'static,
348 {
349 self.0.handle.spawn(task);
350 }
351
352 fn debug_info(&self) {
353 debug!(
354 "Session strong={} weak={}",
355 Arc::strong_count(&self.0),
356 Arc::weak_count(&self.0)
357 );
358 }
359
360 fn check_catalogue(attributes: &UserAttributes) {
361 if let Some(account_type) = attributes.get("type") {
362 if account_type != "premium" {
363 error!("librespot does not support {account_type:?} accounts.");
364 info!("Please support Spotify and your artists and sign up for a premium account.");
365
366 exit(1);
368 }
369 }
370 }
371
372 pub fn send_packet(&self, cmd: PacketType, data: Vec<u8>) -> Result<(), Error> {
373 match self.0.tx_connection.get() {
374 Some(tx) => Ok(tx.send((cmd as u8, data))?),
375 None => Err(SessionError::NotConnected.into()),
376 }
377 }
378
379 pub fn cache(&self) -> Option<&Arc<Cache>> {
380 self.0.cache.as_ref()
381 }
382
383 pub fn config(&self) -> &SessionConfig {
384 &self.0.config
385 }
386
387 pub fn user_data(&self) -> UserData {
391 self.0.data.read().user_data.clone()
392 }
393
394 pub fn session_id(&self) -> String {
395 self.0.data.read().session_id.clone()
396 }
397
398 pub fn set_session_id(&self, session_id: &str) {
399 session_id.clone_into(&mut self.0.data.write().session_id);
400 }
401
402 pub fn device_id(&self) -> &str {
403 &self.config().device_id
404 }
405
406 pub fn client_id(&self) -> String {
407 self.0.data.read().client_id.clone()
408 }
409
410 pub fn set_client_id(&self, client_id: &str) {
411 client_id.clone_into(&mut self.0.data.write().client_id);
412 }
413
414 pub fn client_name(&self) -> String {
415 self.0.data.read().client_name.clone()
416 }
417
418 pub fn set_client_name(&self, client_name: &str) {
419 client_name.clone_into(&mut self.0.data.write().client_name);
420 }
421
422 pub fn client_brand_name(&self) -> String {
423 self.0.data.read().client_brand_name.clone()
424 }
425
426 pub fn set_client_brand_name(&self, client_brand_name: &str) {
427 client_brand_name.clone_into(&mut self.0.data.write().client_brand_name);
428 }
429
430 pub fn client_model_name(&self) -> String {
431 self.0.data.read().client_model_name.clone()
432 }
433
434 pub fn set_client_model_name(&self, client_model_name: &str) {
435 client_model_name.clone_into(&mut self.0.data.write().client_model_name);
436 }
437
438 pub fn connection_id(&self) -> String {
439 self.0.data.read().connection_id.clone()
440 }
441
442 pub fn set_connection_id(&self, connection_id: &str) {
443 connection_id.clone_into(&mut self.0.data.write().connection_id);
444 }
445
446 pub fn username(&self) -> String {
447 self.0.data.read().user_data.canonical_username.clone()
448 }
449
450 pub fn set_username(&self, username: &str) {
451 username.clone_into(&mut self.0.data.write().user_data.canonical_username);
452 }
453
454 pub fn auth_data(&self) -> Vec<u8> {
455 self.0.data.read().auth_data.clone()
456 }
457
458 pub fn set_auth_data(&self, auth_data: &[u8]) {
459 auth_data.clone_into(&mut self.0.data.write().auth_data);
460 }
461
462 pub fn country(&self) -> String {
463 self.0.data.read().user_data.country.clone()
464 }
465
466 pub fn filter_explicit_content(&self) -> bool {
467 match self.get_user_attribute("filter-explicit-content") {
468 Some(value) => matches!(&*value, "1"),
469 None => false,
470 }
471 }
472
473 pub fn autoplay(&self) -> bool {
474 if let Some(overide) = self.config().autoplay {
475 return overide;
476 }
477
478 match self.get_user_attribute("autoplay") {
479 Some(value) => matches!(&*value, "1"),
480 None => false,
481 }
482 }
483
484 pub fn set_user_attribute(&self, key: &str, value: &str) -> Option<String> {
485 let mut dummy_attributes = UserAttributes::new();
486 dummy_attributes.insert(key.to_owned(), value.to_owned());
487 Self::check_catalogue(&dummy_attributes);
488
489 self.0
490 .data
491 .write()
492 .user_data
493 .attributes
494 .insert(key.to_owned(), value.to_owned())
495 }
496
497 pub fn set_user_attributes(&self, attributes: UserAttributes) {
498 Self::check_catalogue(&attributes);
499
500 self.0.data.write().user_data.attributes.extend(attributes)
501 }
502
503 pub fn get_user_attribute(&self, key: &str) -> Option<String> {
504 self.0.data.read().user_data.attributes.get(key).cloned()
505 }
506
507 fn weak(&self) -> SessionWeak {
508 SessionWeak(Arc::downgrade(&self.0))
509 }
510
511 pub fn shutdown(&self) {
512 debug!("Shutdown: Invalidating session");
513 self.0.data.write().invalid = true;
514 self.mercury().shutdown();
515 self.channel().shutdown();
516 }
517
518 pub fn is_invalid(&self) -> bool {
519 self.0.data.read().invalid
520 }
521}
522
523#[derive(Clone)]
524pub struct SessionWeak(Weak<SessionInternal>);
525
526impl SessionWeak {
527 fn try_upgrade(&self) -> Option<Session> {
528 self.0.upgrade().map(Session)
529 }
530
531 pub(crate) fn upgrade(&self) -> Session {
532 self.try_upgrade()
533 .expect("session was dropped and so should have this component")
534 }
535}
536
537impl Drop for SessionInternal {
538 fn drop(&mut self) {
539 debug!("drop Session");
540 }
541}
542
/// Position in the access-point keep-alive exchange: the server sends
/// `Ping`, the client answers with `Pong`, and the server acknowledges with
/// `PongAck`.
#[derive(Clone, Copy, Default, Debug, PartialEq)]
enum KeepAliveState {
    /// Waiting for the server's next `Ping`.
    #[default]
    ExpectingPing,

    /// A `Ping` arrived; a `Pong` is sent once the pong delay has elapsed.
    PendingPong,

    /// A `Pong` was sent; waiting for the server's `PongAck`.
    ExpectingPongAck,
}
555
/// Time allowed for the first `Ping` after connecting.
const INITIAL_PING_TIMEOUT: TokioDuration = TokioDuration::from_secs(20);
/// Time allowed for the next `Ping` after a `PongAck` (the 60 s pong delay
/// plus a 20 s margin).
const PING_TIMEOUT: TokioDuration = TokioDuration::from_secs(80);
/// Delay between receiving a `Ping` and sending the `Pong`.
const PONG_DELAY: TokioDuration = TokioDuration::from_secs(60);
/// Time allowed for the `PongAck` after sending a `Pong`.
const PONG_ACK_TIMEOUT: TokioDuration = TokioDuration::from_secs(20);
560
561impl KeepAliveState {
562 fn debug(&self, sleep: &Sleep) {
563 let delay = sleep
564 .deadline()
565 .checked_duration_since(TokioInstant::now())
566 .map(|t| t.as_secs_f64())
567 .unwrap_or(f64::INFINITY);
568
569 trace!("keep-alive state: {self:?}, timeout in {delay:.1}");
570 }
571}
572
573pin_project! {
574 struct DispatchTask<S>
575 where
576 S: TryStream<Ok = (u8, Bytes)>
577 {
578 session: SessionWeak,
579 keep_alive_state: KeepAliveState,
580 #[pin]
581 stream: S,
582 #[pin]
583 timeout: Sleep,
584 }
585
586 impl<S> PinnedDrop for DispatchTask<S>
587 where
588 S: TryStream<Ok = (u8, Bytes)>
589 {
590 fn drop(_this: Pin<&mut Self>) {
591 debug!("drop Dispatch");
592 }
593 }
594}
595
596impl<S> DispatchTask<S>
597where
598 S: TryStream<Ok = (u8, Bytes)>,
599{
600 fn new(session: SessionWeak, stream: S) -> Self {
601 Self {
602 session,
603 keep_alive_state: KeepAliveState::ExpectingPing,
604 stream,
605 timeout: sleep(INITIAL_PING_TIMEOUT),
606 }
607 }
608
609 fn dispatch(
610 mut self: Pin<&mut Self>,
611 session: &Session,
612 cmd: u8,
613 data: Bytes,
614 ) -> Result<(), Error> {
615 use KeepAliveState::*;
616 use PacketType::*;
617
618 let packet_type = FromPrimitive::from_u8(cmd);
619 let cmd = match packet_type {
620 Some(cmd) => cmd,
621 None => {
622 trace!("Ignoring unknown packet {cmd:x}");
623 return Err(SessionError::Packet(cmd).into());
624 }
625 };
626
627 match packet_type {
628 Some(Ping) => {
629 trace!("Received Ping");
630 if self.keep_alive_state != ExpectingPing {
631 warn!("Received unexpected Ping from server")
632 }
633 let mut this = self.as_mut().project();
634 *this.keep_alive_state = PendingPong;
635 this.timeout
636 .as_mut()
637 .reset(TokioInstant::now() + PONG_DELAY);
638 this.keep_alive_state.debug(&this.timeout);
639
640 let server_timestamp = BigEndian::read_u32(data.as_ref()) as i64;
641 let timestamp = SystemTime::now()
642 .duration_since(UNIX_EPOCH)
643 .unwrap_or(Duration::ZERO)
644 .as_secs() as i64;
645 {
646 let mut data = session.0.data.write();
647 data.time_delta = server_timestamp.saturating_sub(timestamp);
648 }
649
650 session.debug_info();
651
652 Ok(())
653 }
654 Some(PongAck) => {
655 trace!("Received PongAck");
656 if self.keep_alive_state != ExpectingPongAck {
657 warn!("Received unexpected PongAck from server")
658 }
659 let mut this = self.as_mut().project();
660 *this.keep_alive_state = ExpectingPing;
661 this.timeout
662 .as_mut()
663 .reset(TokioInstant::now() + PING_TIMEOUT);
664 this.keep_alive_state.debug(&this.timeout);
665
666 Ok(())
667 }
668 Some(CountryCode) => {
669 let country = String::from_utf8(data.as_ref().to_owned())?;
670 info!("Country: {country:?}");
671 session.0.data.write().user_data.country = country;
672 Ok(())
673 }
674 Some(StreamChunkRes) | Some(ChannelError) => session.channel().dispatch(cmd, data),
675 Some(AesKey) | Some(AesKeyError) => session.audio_key().dispatch(cmd, data),
676 Some(MercuryReq) | Some(MercurySub) | Some(MercuryUnsub) | Some(MercuryEvent) => {
677 session.mercury().dispatch(cmd, data)
678 }
679 Some(ProductInfo) => {
680 let data = std::str::from_utf8(&data)?;
681 let mut reader = quick_xml::Reader::from_str(data);
682
683 let mut buf = Vec::new();
684 let mut current_element = String::new();
685 let mut user_attributes: UserAttributes = HashMap::new();
686
687 loop {
688 match reader.read_event_into(&mut buf) {
689 Ok(Event::Start(ref element)) => {
690 std::str::from_utf8(element)?.clone_into(&mut current_element)
691 }
692 Ok(Event::End(_)) => {
693 current_element = String::new();
694 }
695 Ok(Event::Text(ref value)) => {
696 if !current_element.is_empty() {
697 let _ = user_attributes.insert(
698 current_element.clone(),
699 value.xml_content()?.to_string(),
700 );
701 }
702 }
703 Ok(Event::Eof) => break,
704 Ok(_) => (),
705 Err(e) => warn!(
706 "Error parsing XML at position {}: {:?}",
707 reader.buffer_position(),
708 e
709 ),
710 }
711 }
712
713 trace!("Received product info: {user_attributes:#?}");
714 Session::check_catalogue(&user_attributes);
715
716 session.0.data.write().user_data.attributes = user_attributes;
717 Ok(())
718 }
719 Some(SecretBlock)
720 | Some(LegacyWelcome)
721 | Some(UnknownDataAllZeros)
722 | Some(LicenseVersion) => Ok(()),
723 _ => {
724 trace!("Ignoring {cmd:?} packet with data {data:#?}");
725 Err(SessionError::Packet(cmd as u8).into())
726 }
727 }
728 }
729}
730
731impl<S> Future for DispatchTask<S>
732where
733 S: TryStream<Ok = (u8, Bytes), Error = std::io::Error>,
734 <S as TryStream>::Ok: std::fmt::Debug,
735{
736 type Output = Result<(), S::Error>;
737
738 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
739 use KeepAliveState::*;
740
741 let session = match self.session.try_upgrade() {
742 Some(session) => session,
743 None => return Poll::Ready(Ok(())),
744 };
745
746 loop {
748 match self.as_mut().project().stream.try_poll_next(cx) {
749 Poll::Ready(Some(Ok((cmd, data)))) => {
750 let result = self.as_mut().dispatch(&session, cmd, data);
751 if let Err(e) = result {
752 debug!("could not dispatch command: {e}");
753 }
754 }
755 Poll::Ready(None) => {
756 warn!("Connection to server closed.");
757 session.shutdown();
758 return Poll::Ready(Ok(()));
759 }
760 Poll::Ready(Some(Err(e))) => {
761 error!("Connection to server closed.");
762 session.shutdown();
763 return Poll::Ready(Err(e));
764 }
765 Poll::Pending => break,
766 }
767 }
768
769 let mut this = self.as_mut().project();
787 if let Poll::Ready(()) = this.timeout.as_mut().poll(cx) {
788 match this.keep_alive_state {
789 ExpectingPing | ExpectingPongAck => {
790 if !session.is_invalid() {
791 session.shutdown();
792 }
793 return Poll::Ready(Err(io::Error::new(
795 io::ErrorKind::TimedOut,
796 format!(
797 "session lost connection to server ({:?})",
798 this.keep_alive_state
799 ),
800 )));
801 }
802 PendingPong => {
803 trace!("Sending Pong");
804 let _ = session.send_packet(PacketType::Pong, vec![0, 0, 0, 0]);
807 *this.keep_alive_state = ExpectingPongAck;
808 this.timeout
809 .as_mut()
810 .reset(TokioInstant::now() + PONG_ACK_TIMEOUT);
811 this.keep_alive_state.debug(&this.timeout);
812 }
813 }
814 }
815
816 Poll::Pending
817 }
818}