use super::imp::CAT;
use super::mediaformat::*;
use atomic_refcell::AtomicRefCell;
use data_encoding::BASE64;
use futures::future;
use futures::prelude::*;
use gst::glib;
use httparse::Response;
use std::pin::Pin;
use std::sync::{Arc, LazyLock, Mutex};
use std::task::{Context, Poll};
use rustls_pki_types::ServerName;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::runtime;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio::time::Duration;
use rustls::ClientConfig;
use rustls_platform_verifier::BuilderVerifierExt;
use tokio_rustls::TlsConnector;
use url::Url;
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
enum TcpOrTlsStream {
Plain(tokio::net::TcpStream),
Tls(tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
}
// Every write operation is simply forwarded to whichever transport is active.
impl AsyncWrite for TcpOrTlsStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match &mut *self {
            Self::Plain(tcp) => Pin::new(tcp).poll_write(cx, buf),
            Self::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match &mut *self {
            Self::Plain(tcp) => Pin::new(tcp).poll_flush(cx),
            Self::Tls(tls) => Pin::new(tls).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        match &mut *self {
            Self::Plain(tcp) => Pin::new(tcp).poll_shutdown(cx),
            Self::Tls(tls) => Pin::new(tls).poll_shutdown(cx),
        }
    }
}
// Reads are forwarded to whichever transport is active.
impl AsyncRead for TcpOrTlsStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        match &mut *self {
            Self::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
            Self::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
        }
    }
}
/// Lifecycle of an [`IceClient`] connection.
// `Streaming` carries the (large) transport while every other variant is tiny;
// a client owns a single `State`, so the size difference is accepted.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
enum State {
    /// The background task is connecting and handshaking; its outcome is
    /// collected through `join_handle`.
    Connecting {
        join_handle: JoinHandle<Result<TcpOrTlsStream, gst::ErrorMessage>>,
    },
    /// Placeholder while the connection task is being awaited.
    WaitingForConnect,
    /// Handshake finished; media data is written to `stream`.
    Streaming { stream: TcpOrTlsStream },
    /// Connecting or sending failed or was cancelled.
    Error,
    /// The client is being dropped.
    Dropping,
}
/// Tracks the abortable operation that is currently in flight, if any.
#[derive(Debug, Default)]
enum Canceller {
    /// Nothing in flight and no cancellation pending.
    #[default]
    None,
    /// An operation is in flight and can be aborted through this handle.
    Armed(future::AbortHandle),
    /// Cancellation was requested; new waits are refused until cleared.
    Cancelled,
}
impl Canceller {
    /// Aborts the in-flight operation (if armed) and enters the `Cancelled`
    /// state, whatever the previous state was.
    fn cancel(&mut self) {
        let previous = std::mem::replace(self, Canceller::Cancelled);
        if let Canceller::Armed(handle) = previous {
            handle.abort();
        }
    }

    /// Leaves the `Cancelled` state again; any other state is left untouched.
    fn clear_cancel(&mut self) {
        if let Self::Cancelled = self {
            *self = Self::None;
        }
    }
}
/// Client for a single Icecast source connection.
#[derive(Debug)]
pub(super) struct IceClient {
    /// Current connection state.
    state: AtomicRefCell<State>,
    /// Sends the media format to the connection task, which waits for it
    /// before issuing the `PUT` request; `None` once it has been sent.
    caps_tx: AtomicRefCell<Option<oneshot::Sender<MediaFormat>>>,
    /// Cancellation state of the operation currently in flight.
    canceller: Mutex<Canceller>,
    /// Identifier attached to all log output of this client.
    log_id: glib::GString,
}
/// `User-Agent` header value sent with every request:
/// `GStreamer icecastsink <crate version>-<COMMIT_ID>`.
const USER_AGENT: &str =
    concat!("GStreamer icecastsink ", env!("CARGO_PKG_VERSION"), "-", env!("COMMIT_ID"));
/// Shared Tokio runtime with a single worker thread, created lazily on first
/// use; connection tasks are spawned on it and synchronous waits block on it.
static RUNTIME: LazyLock<runtime::Runtime> = LazyLock::new(|| {
    let mut builder = runtime::Builder::new_multi_thread();
    builder.worker_threads(1).enable_all();
    builder.build().unwrap()
});
impl IceClient {
pub(super) fn new(
url: Url,
public: bool,
stream_name: Option<String>,
log_id: glib::GString,
) -> Result<Self, gst::ErrorMessage> {
let (abort_handle, abort_registration) = future::AbortHandle::new_pair();
let (caps_tx, caps_rx) = oneshot::channel();
let debug_log_id = log_id.clone();
gst::info!(
CAT,
id = &log_id,
"Initiating connection to server (in new thread).. "
);
let future = async move {
let public = public as i32;
let scheme = url.scheme();
let host_name = url.host_str().unwrap();
let port = url.port().unwrap_or(8000);
let path = url.path();
let username = url.username();
let password = url.password().unwrap_or("");
gst::info!(
CAT,
id = &log_id,
"Connecting via {scheme} to server {host_name}:{port}.."
);
let stream = TcpStream::connect(format!("{host_name}:{port}"))
.await
.map_err(|err| {
gst::error_msg!(
gst::ResourceError::OpenWrite,
["Failed to connect to server {host_name}:{port}: {err}"]
)
})?;
gst::info!(CAT, id = &log_id, "Connected to server {host_name}:{port}");
let stream = match (scheme, stream) {
("ice+https", stream) => {
let provider = Arc::new(rustls::crypto::ring::default_provider());
let config = ClientConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()
.unwrap()
.with_platform_verifier()
.unwrap()
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(config));
let dnsname = ServerName::try_from(host_name.to_string()).map_err(|err| {
gst::error_msg!(
gst::ResourceError::Write,
["Server name failed for '{host_name}': {err}"]
)
})?;
gst::info!(CAT, id = &log_id, "TLS connect..");
let stream = connector.connect(dnsname, stream).await.map_err(|err| {
gst::error_msg!(
gst::ResourceError::Write,
["TLS handshake with server failed: {err}"]
)
})?;
gst::info!(CAT, id = &log_id, "TLS setup done");
TcpOrTlsStream::Tls(stream)
}
("ice+http", stream) => TcpOrTlsStream::Plain(stream),
_ => unreachable!(),
};
let mut stream = BufReader::new(stream);
gst::info!(CAT, id = &log_id, "Sending OPTIONS request to server..");
let options_request = format!(
"\
OPTIONS * HTTP/1.1\r\n\
Host: {host_name}:{port}\r\n\
User-Agent: {USER_AGENT}\r\n\
Connection: keep-alive\r\n\
\r\n"
);
stream
.write_all(options_request.as_bytes())
.await
.map_err(|err| {
gst::error_msg!(
gst::ResourceError::OpenWrite,
["Failed to send OPTIONS request to server: {err}"]
)
})?;
const MAX_RESPONSE_LENGTH: usize = 2048;
let mut response = String::new();
while !response.ends_with("\r\n\r\n") {
stream.read_line(&mut response).await.map_err(|err| {
gst::error_msg!(
gst::ResourceError::OpenWrite,
["Failed to complete OPTIONS handshake with server: {err}"]
)
})?;
if response.len() > MAX_RESPONSE_LENGTH {
return Err(gst::error_msg!(
gst::ResourceError::OpenRead,
["Excessive server OPTIONS response length"]
));
}
}
if response.contains("Server: rocketstreamingserver")
&& !response.contains("Content-Length: ")
&& stream.buffer().ends_with(b"\r\n\r\n")
{
response.pop().unwrap();
response.pop().unwrap();
while !response.ends_with("\r\n\r\n") {
stream.read_line(&mut response).await.map_err(|err| {
gst::error_msg!(
gst::ResourceError::OpenWrite,
["Failed to complete OPTIONS handshake with server: {err}"]
)
})?;
if response.len() > MAX_RESPONSE_LENGTH {
return Err(gst::error_msg!(
gst::ResourceError::OpenRead,
["Excessive server OPTIONS response length"]
));
}
}
}
gst::info!(CAT, id = &log_id, "OPTIONS response: {response}");
let mut r_headers = [httparse::EMPTY_HEADER; 32];
let mut r = Response::new(&mut r_headers);
use httparse::Status::{Complete, Partial};
match r.parse(response.as_bytes()) {
Ok(Complete(_)) => {
gst::trace!(CAT, id = &log_id, "Parsed OPTIONS response: {r:?}");
match r.code {
Some(200..=204) => Ok(()),
Some(401) => Ok(()),
Some(405) => Ok(()),
_ => Err(gst::error_msg!(
gst::ResourceError::OpenWrite,
[
"Error probing server via OPTIONS request: {} {}",
r.code.unwrap(),
r.reason.unwrap_or("")
]
)),
}
}
Ok(Partial) => Err(gst::error_msg!(
gst::ResourceError::OpenWrite,
["Failed to parse OPTIONS response from server: partial response"]
)),
Err(err) => Err(gst::error_msg!(
gst::ResourceError::OpenWrite,
["Failed to parse OPTIONS response from server: {err:?}"]
)),
}?;
if let Some(mut content_len) = r
.headers
.iter()
.find(|h| h.name.eq_ignore_ascii_case("content-length"))
.and_then(|h| std::str::from_utf8(h.value).ok())
.and_then(|s| s.parse::<usize>().ok())
{
gst::debug!(CAT, id = &log_id, "Content-Length: {content_len} bytes");
while content_len > 0 {
let n_bytes = content_len.min(4096);
gst::trace!(CAT, id = &log_id, "Reading {n_bytes} content bytes");
let mut buf = vec![0u8; n_bytes];
let _ = stream.read_exact(&mut buf).await.map_err(|err| {
gst::error_msg!(
gst::ResourceError::OpenWrite,
["Failed to read content after OPTIONS response: {err}"]
)
})?;
content_len -= n_bytes;
}
}
if !stream.buffer().is_empty() {
let n_bytes = stream.buffer().len();
gst::warning!(
CAT,
id = &log_id,
"Discarding {n_bytes} excess bytes after OPTIONS response!"
);
let mut buf = vec![0u8; n_bytes];
let _ = stream.read_exact(&mut buf).await.unwrap();
}
gst::info!(
CAT,
id = &log_id,
"Waiting for initial caps, media format.."
);
let media_format: MediaFormat = caps_rx.await.map_err(|err| {
gst::error_msg!(
gst::ResourceError::Read,
["Failed to receive media format: {err}"]
)
})?;
let (content_type, ice_audio_info) = {
if media_format == MediaFormat::None {
return Err(gst::error_msg!(
gst::StreamError::Format,
["No media format configured"]
));
}
gst::info!(CAT, id = &log_id, "Media format: {media_format:?}");
let content_type = media_format.content_type().unwrap();
let ice_audio_info = media_format.ice_audio_info().unwrap();
(content_type, ice_audio_info)
};
let auth_header = if !username.is_empty() || !password.is_empty() {
format!(
"Authorization: Basic {}",
BASE64.encode(format!("{username}:{password}").as_bytes())
)
} else {
String::new()
};
let mut ice_headers = String::with_capacity(1024);
ice_headers.push_str(&format!("Ice-audio-info: {ice_audio_info}\r\n"));
ice_headers.push_str(&format!("Ice-public: {public}\r\n"));
if let Some(stream_name) = stream_name {
ice_headers.push_str(&format!("Ice-name: {stream_name}\r\n"));
}
let put_request = format!(
"\
PUT {path} HTTP/1.1\r\n\
Host: {host_name}:{port}\r\n\
{auth_header}\r\n\
User-Agent: {USER_AGENT}\r\n\
Content-Type: {content_type}\r\n\
{ice_headers}\
Expect: 100-continue\r\n\
\r\n"
);
gst::info!(CAT, id = &log_id, "PUT request: {put_request}");
stream
.write_all(put_request.as_bytes())
.await
.map_err(|err| {
gst::error_msg!(
gst::ResourceError::OpenWrite,
["Failed to send PUT request to server: {err}"]
)
})?;
response.clear();
while !response.ends_with("\r\n\r\n") {
stream.read_line(&mut response).await.map_err(|err| {
gst::error_msg!(
gst::ResourceError::OpenWrite,
["Failed to receive PUT request response: {err}"]
)
})?;
if response.len() > MAX_RESPONSE_LENGTH {
return Err(gst::error_msg!(
gst::ResourceError::OpenRead,
["Excessive server PUT request response length"]
));
}
}
gst::info!(CAT, id = &log_id, "PUT response: {response}");
let mut r_headers = [httparse::EMPTY_HEADER; 32];
let mut r = Response::new(&mut r_headers);
match r.parse(response.as_bytes()) {
Ok(Complete(_)) => {
gst::trace!(CAT, id = &log_id, "Parsed PUT response: {r:?}");
match r.code {
Some(100) => Ok(()),
Some(200..=204) => Ok(()),
Some(401) => Err(if auth_header.is_empty() {
gst::error_msg!(
gst::ResourceError::NotAuthorized,
[
"Server requires authorization, but no username and/or password configured"
]
)
} else {
gst::error_msg!(
gst::ResourceError::NotAuthorized,
["Server didn't accept credentials for mount point {path}"]
)
}),
Some(403) => match r.reason {
Some("Content-type not supported") => Err(gst::error_msg!(
gst::ResourceError::Settings,
["Server doesn't support content type {content_type}"]
)),
Some("Mountpoint in use") => Err(gst::error_msg!(
gst::ResourceError::Busy,
["Mount point {path} already in use by another client"]
)),
_ => Err(gst::error_msg!(
gst::ResourceError::Settings,
[
"Server didn't accept content type {content_type} on mount point {path} ({})",
r.reason.unwrap_or("no reason")
]
)),
},
Some(404) => Err(gst::error_msg!(
gst::ResourceError::NotFound,
[
"Server didn't accept mount point {path} ({})",
r.reason.unwrap_or("no reason")
]
)),
Some(405) => Err(gst::error_msg!(
gst::ResourceError::OpenWrite,
["Server doesn't support PUT method, upgrade your server!"]
)),
_ => Err(gst::error_msg!(
gst::ResourceError::OpenWrite,
[
"Error sending PUT request: {} {}",
r.code.unwrap(),
r.reason.unwrap_or("")
]
)),
}
}
Ok(Partial) => Err(gst::error_msg!(
gst::ResourceError::OpenWrite,
["Failed to parse PUT response from server: partial response"]
)),
Err(err) => Err(gst::error_msg!(
gst::ResourceError::OpenWrite,
["Failed to parse PUT response from server: {err}"]
)),
}?;
if !stream.buffer().is_empty() {
let n_bytes = stream.buffer().len();
gst::warning!(
CAT,
id = &log_id,
"Discarding {n_bytes} excess bytes after PUT response!"
);
let mut buf = vec![0u8; n_bytes];
let _ = stream.read_exact(&mut buf).await.unwrap();
}
Ok(stream.into_inner()) };
let future = async {
future::Abortable::new(future, abort_registration)
.await
.map_err(|err| gst::error_msg!(gst::LibraryError::Failed, ["{err}"]))?
};
let join_handle = RUNTIME.spawn(future);
let client = IceClient {
state: AtomicRefCell::new(State::Connecting { join_handle }),
caps_tx: AtomicRefCell::new(Some(caps_tx)),
canceller: Mutex::new(Canceller::Armed(abort_handle)),
log_id: debug_log_id,
};
Ok(client)
}
pub(super) fn set_media_format(&self, media_format: MediaFormat) {
gst::info!(CAT, id = &self.log_id, "{media_format:?}");
let mut caps_tx_storage = self.caps_tx.borrow_mut();
if let Some(caps_tx) = caps_tx_storage.take() {
let _ = caps_tx.send(media_format);
}
}
pub(super) fn cancel(&self) {
gst::info!(CAT, id = &self.log_id, "Cancelling..");
let mut canceller = self.canceller.lock().unwrap();
canceller.cancel();
gst::log!(CAT, id = &self.log_id, "Cancelled!");
}
pub(super) fn clear_cancel(&self) {
let mut canceller = self.canceller.lock().unwrap();
canceller.clear_cancel();
gst::info!(CAT, id = &self.log_id, "Cancel cleared");
}
pub(super) fn wait_for_connection_and_handshake(
&self,
timeout_millisecs: u32,
) -> Result<(), Option<gst::ErrorMessage>> {
let mut state = self.state.borrow_mut();
if !matches!(*state, State::Connecting { .. }) {
return Ok(());
}
let mut temp = State::WaitingForConnect;
std::mem::swap(&mut *state, &mut temp);
let State::Connecting { join_handle } = temp else {
unreachable!()
};
gst::info!(
CAT,
id = &self.log_id,
"Waiting for connection + handshake to server to complete"
);
let future = async {
join_handle
.await
.map_err(|err| gst::error_msg!(gst::LibraryError::Failed, ["{err}"]))?
};
let res = self.sync_wait(future, timeout_millisecs);
let stream = match res {
Ok(res) => res,
Err(err) => {
gst::debug!(CAT, id = &self.log_id, "Error {err:?}");
*state = State::Error;
return Err(err);
}
};
*state = State::Streaming { stream };
gst::info!(CAT, id = &self.log_id, "Ready to stream");
Ok(())
}
pub(super) fn send_data(
&self,
data: &[u8],
timeout_millisecs: u32,
) -> Result<(), Option<gst::ErrorMessage>> {
let mut state = self.state.borrow_mut();
let State::Streaming { ref mut stream } = *state else {
unreachable!();
};
gst::trace!(
CAT,
id = &self.log_id,
"Sending {} bytes of data..",
data.len()
);
let future = async move {
stream.write_all(data).await.map_err(|err| {
gst::error_msg!(
gst::ResourceError::Write,
["Failed to send data to server: {err}"]
)
})
};
let res = self.sync_wait(future, timeout_millisecs);
gst::trace!(CAT, id = &self.log_id, "Done sending data: {res:?}");
if let Err(err) = res {
gst::debug!(CAT, id = &self.log_id, "Error {err:?}");
*state = State::Error;
return Err(err);
}
Ok(())
}
fn sync_wait<F, T>(&self, future: F, timeout: u32) -> Result<T, Option<gst::ErrorMessage>>
where
F: Send + Future<Output = Result<T, gst::ErrorMessage>>,
T: Send + 'static,
{
let mut canceller = self.canceller.lock().unwrap();
if matches!(*canceller, Canceller::Cancelled) {
return Err(None);
}
let (abort_handle, abort_registration) = future::AbortHandle::new_pair();
*canceller = Canceller::Armed(abort_handle);
drop(canceller);
let future = async {
if timeout == 0 {
future.await
} else {
let res = tokio::time::timeout(Duration::from_millis(timeout.into()), future).await;
match res {
Ok(res) => res,
Err(_) => Err(gst::error_msg!(
gst::ResourceError::Write,
["Request timeout"]
)),
}
}
};
let future = async {
match future::Abortable::new(future, abort_registration).await {
Ok(res) => res.map_err(Some),
Err(_) => Err(None),
}
};
let res = RUNTIME.block_on(future);
let mut canceller = self.canceller.lock().unwrap();
if matches!(*canceller, Canceller::Cancelled) {
return Err(None);
}
*canceller = Canceller::None;
res
}
}
impl Drop for IceClient {
    /// Marks the client as dropping. If the connection task was never
    /// awaited, it is aborted so it does not keep running on the runtime.
    fn drop(&mut self) {
        gst::log!(CAT, id = &self.log_id, "Dropping client");

        let previous = std::mem::replace(&mut *self.state.borrow_mut(), State::Dropping);
        if let State::Connecting { join_handle } = previous {
            gst::info!(
                CAT,
                id = &self.log_id,
                "Aborting join handle of thread doing the initial connection"
            );
            join_handle.abort();
        }
    }
}