use std::{
io,
net::{self, SocketAddr},
time::Duration,
};
use interprocess::local_socket::{self, traits::Stream as StreamExt};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::ipc_common::{RemoteEvent, TrackInfo};
/// Wire encoding of the IPC protocol version: the number `1` as a big-endian `u32`.
pub(crate) const IPC_PROTOCOL_VERSION: [u8; 4] = u32::to_be_bytes(1);
#[cfg(any(feature = "local-session", feature = "client"))]
pub mod cpal_thread;
mod methods;
pub use methods::*;
pub mod dto;
/// Transport used to reach a peer: a local IPC socket or a remote network address.
#[derive(Clone, Debug)]
pub enum ConnectionType {
    /// A local IPC socket, identified by name (presumably a socket path or
    /// namespaced name — confirm against the connecting code).
    Local(String),
    /// A remote peer at the given socket address (presumably reached over TCP).
    Remote(SocketAddr),
}
/// Opaque 32-byte identifier for a session, serialized as a hex string
/// through `hex::serde`.
// NOTE(review): the derived `Debug` prints the raw token bytes; if tokens act
// as credentials, avoid logging them with `{:?}`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SessionToken(#[serde(with = "hex::serde")] [u8; 32]);
impl From<[u8; 32]> for SessionToken {
fn from(value: [u8; 32]) -> Self {
SessionToken(value)
}
}
impl SessionToken {
    /// Reserved all-zero token; `FromStr` maps the literal `"local"` to it.
    #[cfg(any(feature = "local-session", feature = "client"))]
    pub const LOCAL_SESSION: Self = Self([0; 32]);

    /// Creates a new token from 32 bytes produced by `rand::random`.
    pub(crate) fn generate() -> Self {
        let bytes: [u8; 32] = rand::random();
        Self(bytes)
    }
}
/// Parses a token from text.
///
/// With the `local-session` or `client` feature enabled, the word `local`
/// (any ASCII case) yields `SessionToken::LOCAL_SESSION`. Otherwise the input
/// is hex-decoded into the 32-byte token. Malformed or wrongly sized input is
/// reported by `hex::decode_to_slice` as a `hex::FromHexError`.
impl std::str::FromStr for SessionToken {
    type Err = hex::FromHexError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        #[cfg(any(feature = "local-session", feature = "client"))]
        if s.eq_ignore_ascii_case("local") {
            return Ok(Self::LOCAL_SESSION);
        }
        let mut bytes = [0u8; 32];
        hex::decode_to_slice(s, &mut bytes).map(|()| Self(bytes))
    }
}
/// Which session a handshake should attach to.
#[derive(Debug, Serialize, Deserialize)]
pub enum SessionTarget {
    /// Ask for a new session to be created.
    Create,
    /// Join the existing session identified by `token`.
    Join { token: SessionToken },
}
/// Request payload of the session handshake.
#[derive(Serialize, Deserialize)]
pub(crate) struct SessionHandshakeRequest {
    /// Whether to create a new session or join an existing one.
    pub session_target: SessionTarget,
}
/// Response payload of the session handshake.
#[derive(Serialize, Deserialize)]
pub(crate) struct SessionHandshakeResponse {
    /// Token of the session the connection ended up in (presumably the newly
    /// created or joined session — confirm against the server).
    pub token: SessionToken,
}
/// Request payload of the audio handshake.
#[derive(Serialize, Deserialize)]
pub(crate) struct AudioHandshakeRequest {
    /// Whether to create a new session or join an existing one.
    pub session_target: SessionTarget,
}
/// Response payload of the audio handshake.
#[derive(Serialize, Deserialize)]
pub(crate) struct AudioHandshakeResponse {
    /// Current track information, if any (presumably `None` when nothing is
    /// loaded — confirm against the server).
    pub track_info: Option<TrackInfo>,
}
/// Largest payload, in bytes, accepted by `Stream::read_data` (8 MiB).
const MAX_PACKET_LEN: u32 = 1024 * 1024 * 8;

/// Validates a packet length prefix against `max` and returns it as a buffer size.
///
/// Zero-length packets are rejected, as are packets longer than `max`; both
/// produce `io::ErrorKind::InvalidData`.
fn checked_packet_len(len: u32, max: u32) -> io::Result<usize> {
    if len == 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "received empty packet",
        ));
    }
    if len > max {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("received packet was too large: {len}"),
        ));
    }
    // Lossless on any target whose `usize` is at least 32 bits wide.
    Ok(len as usize)
}

/// Returns `true` for the error kinds that mean "no data within the allowed
/// time": `WouldBlock` (non-blocking mode, and read timeouts on Unix) and
/// `TimedOut` (read timeouts on Windows).
fn is_timeout(err: &io::Error) -> bool {
    matches!(
        err.kind(),
        io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut
    )
}

/// Fills `buf` completely. A timeout (see `is_timeout`) is reported as
/// `Ok(false)`; every other error is propagated unchanged.
fn fill_or_timeout<R: io::Read + ?Sized>(reader: &mut R, buf: &mut [u8]) -> io::Result<bool> {
    match reader.read_exact(buf) {
        Ok(()) => Ok(true),
        Err(err) if is_timeout(&err) => Ok(false),
        Err(err) => Err(err),
    }
}

/// A bidirectional byte stream carrying length-prefixed packets.
///
/// Each packet on the wire is a 4-byte big-endian `u32` length followed by
/// exactly that many payload bytes. A length of zero is never valid.
pub trait Stream: io::Read + io::Write + Send + Sync {
    /// Switches the underlying transport between blocking and non-blocking mode.
    fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()>;
    /// Sets (or clears, with `None`) the timeout for blocking reads.
    fn set_read_timeout(&self, duration: Option<Duration>) -> io::Result<()>;
    /// Sets (or clears, with `None`) the timeout for blocking writes.
    fn set_write_timeout(&self, duration: Option<Duration>) -> io::Result<()>;

    /// Reads one complete packet and returns its payload.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::InvalidData` if the length prefix is zero or
    /// exceeds `MAX_PACKET_LEN`. Any error from the underlying reads is
    /// propagated, including `UnexpectedEof` if the stream ends mid-packet.
    fn read_data(&mut self) -> io::Result<Vec<u8>> {
        let mut header = [0u8; 4];
        self.read_exact(&mut header)?;
        let len = checked_packet_len(u32::from_be_bytes(header), MAX_PACKET_LEN)?;
        let mut payload = vec![0u8; len];
        self.read_exact(&mut payload)?;
        Ok(payload)
    }

    /// Polls for one packet of at most 10 KiB.
    ///
    /// Returns `Ok(None)` if the first byte of the length prefix is not
    /// available yet. This is signalled by a `WouldBlock`/`TimedOut` error,
    /// so the caller is expected to have put the stream into non-blocking
    /// mode; this method does not enable it.
    ///
    /// Once the first byte has arrived, the stream is switched to blocking
    /// mode for the rest of the packet, so any configured read timeout
    /// governs it. The stream is left in blocking mode on return.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::InvalidData` for a zero or oversized length
    /// prefix. Any non-timeout I/O error (e.g. `UnexpectedEof`,
    /// `ConnectionReset`) is propagated, whether it occurs before or during
    /// the packet.
    fn read_data_nonblocking(&mut self) -> io::Result<Option<Vec<u8>>> {
        const MAX_NONBLOCKING_PACKET_LEN: u32 = 1024 * 10;

        let mut header = [0u8; 4];
        if !fill_or_timeout(&mut *self, &mut header[..1])? {
            return Ok(None);
        }
        self.set_nonblocking(false)?;

        // NOTE(review): from here on part of the packet has already been
        // consumed. Returning `Ok(None)` on a timeout leaves the stream
        // mid-frame, and the next call will misread the framing. This is kept
        // for compatibility with existing callers; consider turning these
        // timeouts into errors.
        if !fill_or_timeout(&mut *self, &mut header[1..])? {
            return Ok(None);
        }
        let len = checked_packet_len(u32::from_be_bytes(header), MAX_NONBLOCKING_PACKET_LEN)?;
        let mut payload = vec![0u8; len];
        if !fill_or_timeout(&mut *self, &mut payload)? {
            return Ok(None);
        }
        Ok(Some(payload))
    }

    /// Whether this stream is a local IPC connection rather than a network
    /// one. Defaults to `false`.
    fn local(&self) -> bool {
        false
    }
}
/// Local-socket connections provided by `interprocess`.
impl Stream for local_socket::Stream {
    fn local(&self) -> bool {
        true
    }

    fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
        // Fully qualified: this trait also declares `set_nonblocking`, so name
        // the interprocess trait explicitly to select its implementation.
        StreamExt::set_nonblocking(self, nonblocking)
    }

    fn set_read_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
        self.set_recv_timeout(duration)
    }

    fn set_write_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
        self.set_send_timeout(duration)
    }
}
/// TCP connections delegate straight to the standard library's socket
/// options; `local()` keeps its default of `false`.
impl Stream for net::TcpStream {
    // `Self::` resolves to the inherent `TcpStream` methods, which take
    // precedence over this trait's identically named methods.
    fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
        Self::set_nonblocking(self, nonblocking)
    }

    fn set_read_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
        Self::set_read_timeout(self, duration)
    }

    fn set_write_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
        Self::set_write_timeout(self, duration)
    }
}
/// Packet carried on a session connection; currently its only variant wraps
/// a `RemoteEvent`.
#[derive(Serialize, Deserialize)]
pub enum SessionPacket {
    /// A remote event.
    Event(RemoteEvent),
}
#[derive(Debug, Error, Serialize, Deserialize, Clone, Copy)]
pub enum HandshakeError {
#[error("The client is on the wrong version. Expected v{expected}. Found v{connected}")]
WrongVersion { expected: u32, connected: u32 },
#[error(
"The connection was refused. This could be because of invalid data, or because of invalid authentication"
)]
Refused,
}
/// Namespace for the 4-byte tags naming each handshake kind. The enum has no
/// variants, so it is never instantiated; only its associated constants are used.
// NOTE(review): the tag bytes do not match the constant names (e.g. `AUDIO`
// is `SESS`). They are presumably part of the wire protocol, so do not change
// them without bumping the protocol version.
pub enum HandshakeType {}

impl HandshakeType {
    /// Session handshake tag: `CTRL`.
    pub const SESSION: [u8; 4] = *b"CTRL";
    /// Session-method handshake tag: `SMTD`.
    pub const SESSION_METHOD: [u8; 4] = *b"SMTD";
    /// Audio handshake tag: `SESS`.
    pub const AUDIO: [u8; 4] = *b"SESS";
    /// Method handshake tag: `MTHD`.
    pub const METHOD: [u8; 4] = *b"MTHD";
}
/// Loudness-normalization mode. Explicit discriminants fix the numeric value
/// of each variant.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
pub enum LoudnormMode {
    /// No normalization.
    None = 0,
    /// Per-track normalization.
    Track = 1,
    /// Per-album normalization.
    Album = 2,
}