use std::{
borrow::Cow,
fmt::Debug,
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine};
use http::StatusCode;
use livekit_protocol as proto;
use livekit_runtime::{interval, sleep, Instant, JoinHandle};
use parking_lot::Mutex;
use prost::Message;
use thiserror::Error;
use tokio::sync::{mpsc, Mutex as AsyncMutex, RwLock as AsyncRwLock};
#[cfg(feature = "signal-client-tokio")]
use tokio_tungstenite::tungstenite::Error as WsError;
#[cfg(feature = "__signal-client-async-compatible")]
use async_tungstenite::tungstenite::Error as WsError;
use crate::{http_client, signal_client::signal_stream::SignalStream};
mod region;
mod signal_stream;
pub use region::RegionUrlProvider;
/// Sending half of the channel on which a [`SignalClient`]'s background task delivers
/// [`SignalEvent`]s.
pub type SignalEmitter = mpsc::UnboundedSender<SignalEvent>;
/// Receiving half of the event channel, returned by [`SignalClient::connect`].
pub type SignalEvents = mpsc::UnboundedReceiver<SignalEvent>;
/// Result type used throughout the signal client.
pub type SignalResult<T> = Result<T, SignalError>;
/// How long to wait for the `Join` / `Reconnect` response once the websocket is open.
pub const JOIN_RESPONSE_TIMEOUT: Duration = Duration::from_secs(5);
/// Default websocket connect timeout (the default of `SignalOptions::connect_timeout`).
pub const SIGNAL_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
// NOTE(review): not referenced in this file; presumably used by the `region` submodule — confirm.
const REGION_FETCH_TIMEOUT: Duration = Duration::from_secs(3);
/// Upper bound on the diagnostic HTTP `/validate` request made after a failed connect.
const VALIDATE_TIMEOUT: Duration = Duration::from_secs(3);
/// Signal protocol version advertised to the server.
pub const PROTOCOL_VERSION: u32 = 17;
/// Capabilities advertised to the server: inside `ClientInfo` on the v1 path, or as the
/// comma-separated `capabilities` query parameter on the v0 path.
const CLIENT_CAPABILITIES: &[proto::client_info::Capability] =
    &[proto::client_info::Capability::CapPacketTrailer];
/// Errors produced by the signal client.
#[derive(Error, Debug)]
pub enum SignalError {
    /// Websocket-level failure. Also produced when the event stream closes before an expected
    /// response arrives (`WsError::ConnectionClosed`).
    #[error("ws failure: {0}")]
    WsError(#[from] WsError),
    /// The signal URL could not be parsed, has no host, or uses an unsupported scheme.
    #[error("failed to parse the url: {0}")]
    UrlParse(String),
    /// The access token was rejected for its format (raised outside this file; presumably by
    /// `SignalStream::connect` — confirm).
    #[error("access token has invalid characters")]
    TokenFormat,
    /// The `/validate` endpoint answered with a 4xx status; carries the status and body.
    #[error("client error: {0} - {1}")]
    Client(StatusCode, String),
    /// The `/validate` endpoint answered with a 5xx status; carries the status and body.
    #[error("server error: {0} - {1}")]
    Server(StatusCode, String),
    /// A protobuf message from the server could not be decoded.
    #[error("failed to decode messages from server: {0}")]
    ProtoParse(#[from] prost::DecodeError),
    /// An operation (connect, join/reconnect response, validate) exceeded its deadline.
    #[error("{0}")]
    Timeout(String),
    /// A request could not be written to the stream; `SignalInner::send` re-queues
    /// non-pass-through requests on this error.
    #[error("failed to send message to the server")]
    SendError,
    /// Fetching region URLs for the fallback connection failed.
    #[error("failed to retrieve region info: {0}")]
    RegionError(String),
    /// The server answered a reconnect attempt with a `Leave` instead of a `Reconnect`.
    #[error("server sent leave during reconnect: reason={reason:?}, action={action:?}")]
    LeaveRequest { reason: proto::DisconnectReason, action: proto::leave_request::Action },
}
/// Identifies the SDK to the server when the signal connection is opened.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SignalSdkOptions {
    /// SDK name, sent as the `sdk` query parameter on the v0 signal path.
    pub sdk: String,
    /// Optional SDK version: the `version` query parameter (v0) or `ClientInfo.version` (v1).
    pub sdk_version: Option<String>,
}

impl Default for SignalSdkOptions {
    /// Reports the SDK as `"rust"` with no explicit version.
    fn default() -> Self {
        let sdk = String::from("rust");
        SignalSdkOptions { sdk, sdk_version: None }
    }
}
/// Options controlling how the signal connection is established.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SignalOptions {
    /// Sent to the server as `auto_subscribe` (query parameter or `ConnectionSettings`).
    pub auto_subscribe: bool,
    /// Sent to the server as `adaptive_stream` (query parameter or `ConnectionSettings`).
    pub adaptive_stream: bool,
    /// SDK identification reported to the server.
    pub sdk_options: SignalSdkOptions,
    /// Try the v1 (`/rtc/v1`) path first; the connect logic falls back to v0 on a 404.
    pub single_peer_connection: bool,
    /// Timeout handed to each websocket connect attempt.
    pub connect_timeout: Duration,
}

impl Default for SignalOptions {
    /// Auto-subscribe on, adaptive stream off, v0 path, default SDK info and connect timeout.
    fn default() -> Self {
        let sdk_options = SignalSdkOptions::default();
        Self {
            connect_timeout: SIGNAL_CONNECT_TIMEOUT,
            single_peer_connection: false,
            sdk_options,
            adaptive_stream: false,
            auto_subscribe: true,
        }
    }
}
/// Event delivered to the owner of a [`SignalClient`] through [`SignalEvents`].
pub enum SignalEvent {
    /// A message received from the server, forwarded unchanged.
    Message(Box<proto::signal_response::Message>),
    /// The background signal task stopped; the payload is the reason
    /// (`"stream closed"` or `"ping timeout"` in this file).
    Close(Cow<'static, str>),
}
/// State shared between a [`SignalClient`] and its background signal task.
struct SignalInner {
    /// Current websocket stream; `None` after `close` or a failed `restart`.
    stream: AsyncRwLock<Option<SignalStream>>,
    /// Current access token; replaced when the server sends `RefreshToken`.
    token: Mutex<String>,
    /// True while a restart is in progress; non-pass-through requests are queued meanwhile.
    reconnecting: AtomicBool,
    /// Requests buffered while reconnecting or after a failed send, flushed later.
    queue: AsyncMutex<Vec<proto::signal_request::Message>>,
    /// URL given to `connect`, before scheme/path rewriting.
    url: String,
    /// Options used for the initial connect and reused on restart.
    options: SignalOptions,
    /// Join response received during the initial connect.
    join_response: proto::JoinResponse,
    /// Counter backing `next_request_id`.
    request_id: AtomicU32,
    /// Whether the v1 single-PC path was actually used (false after a v0 fallback).
    single_pc_mode_active: bool,
}
/// Client for the signal websocket.
///
/// Holds the shared [`SignalInner`] state, the event sender handed to each background
/// `signal_task`, and the handle of the task currently running (taken and awaited by `close`).
pub struct SignalClient {
    inner: Arc<SignalInner>,
    // Cloned into every spawned `signal_task` so events keep flowing to the same receiver
    // across restarts.
    emitter: SignalEmitter,
    // Handle of the running `signal_task`, if any.
    handle: Mutex<Option<JoinHandle<()>>>,
}
impl Debug for SignalClient {
    /// Shows the URL, join response and options; the token is intentionally not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SignalClient");
        out.field("url", &self.url());
        out.field("join_response", &self.join_response());
        out.field("options", &self.options());
        out.finish()
    }
}
impl SignalClient {
pub async fn connect(
url: &str,
token: &str,
options: SignalOptions,
) -> SignalResult<(Self, proto::JoinResponse, SignalEvents)> {
let handle_success = |inner: Arc<SignalInner>, join_response, stream_events| {
let (emitter, events) = mpsc::unbounded_channel();
let signal_task =
livekit_runtime::spawn(signal_task(inner.clone(), emitter.clone(), stream_events));
(Self { inner, emitter, handle: Mutex::new(Some(signal_task)) }, join_response, events)
};
match SignalInner::connect(url, token, options.clone()).await {
Ok((inner, join_response, stream_events)) => {
Ok(handle_success(inner, join_response, stream_events))
}
Err(err) => {
if matches!(&err, SignalError::WsError(WsError::Http(e)) if e.status() != 403) {
log::error!("unexpected signal error: {}", err.to_string());
}
let urls = RegionUrlProvider::fetch_region_urls(url, token).await?;
let mut last_err = err;
for url in urls.iter() {
log::info!("fallback connection to: {}", url);
match SignalInner::connect(url, token, options.clone()).await {
Ok((inner, join_response, stream_events)) => {
return Ok(handle_success(inner, join_response, stream_events))
}
Err(err) => last_err = err,
}
}
Err(last_err)
}
}
}
pub async fn restart(&self) -> SignalResult<proto::ReconnectResponse> {
self.close().await;
let (reconnect_response, stream_events) = self.inner.restart().await?;
let signal_task = livekit_runtime::spawn(signal_task(
self.inner.clone(),
self.emitter.clone(),
stream_events,
));
*self.handle.lock() = Some(signal_task);
Ok(reconnect_response)
}
pub async fn set_reconnected(&self) {
self.inner.set_reconnected().await;
}
pub async fn send(&self, signal: proto::signal_request::Message) {
self.inner.send(signal).await
}
pub async fn close(&self) {
self.inner.close(true).await;
let handle = self.handle.lock().take();
if let Some(signal_task) = handle {
let _ = signal_task.await;
}
}
pub fn join_response(&self) -> proto::JoinResponse {
self.inner.join_response.clone()
}
pub fn options(&self) -> SignalOptions {
self.inner.options.clone()
}
pub fn url(&self) -> String {
self.inner.url.clone()
}
pub fn token(&self) -> String {
self.inner.token.lock().clone()
}
pub fn next_request_id(&self) -> u32 {
self.inner.next_request_id().clone()
}
pub fn is_single_pc_mode_active(&self) -> bool {
self.inner.is_single_pc_mode_active()
}
pub async fn is_connected(&self) -> bool {
self.inner.stream.read().await.is_some()
}
}
impl SignalInner {
pub async fn connect(
url: &str,
token: &str,
options: SignalOptions,
) -> SignalResult<(
Arc<Self>,
proto::JoinResponse,
mpsc::UnboundedReceiver<Box<proto::signal_response::Message>>,
)> {
let use_v1_path = options.single_peer_connection;
let lk_url = get_livekit_url(url, &options, use_v1_path, false, None, "")?;
let (stream, mut events, single_pc_mode_active) =
match SignalStream::connect(lk_url.clone(), token, options.connect_timeout).await {
Ok((new_stream, stream_events)) => {
log::debug!(
"signal connection successful: path={}, single_pc_mode={}",
if use_v1_path { "v1" } else { "v0" },
use_v1_path
);
(new_stream, stream_events, use_v1_path)
}
Err(err) => {
log::warn!(
"signal connection failed on {} path: {:?}",
if use_v1_path { "v1" } else { "v0" },
err
);
if let SignalError::TokenFormat = err {
return Err(err);
}
let is_not_found =
matches!(&err, SignalError::WsError(WsError::Http(e)) if e.status() == 404);
if use_v1_path && is_not_found {
let lk_url_v0 = get_livekit_url(url, &options, false, false, None, "")?;
log::warn!("v1 path not found (404), falling back to v0 path");
match SignalStream::connect(
lk_url_v0.clone(),
token,
options.connect_timeout,
)
.await
{
Ok((new_stream, stream_events)) => (new_stream, stream_events, false),
Err(err) => {
log::error!("v0 fallback also failed: {:?}", err);
if let SignalError::TokenFormat = err {
return Err(err);
}
Self::validate(lk_url_v0, token).await?;
return Err(err);
}
}
} else {
Self::validate(lk_url, token).await?;
return Err(err);
}
}
};
let join_response = get_join_response(&mut events).await?;
let inner = Arc::new(SignalInner {
stream: AsyncRwLock::new(Some(stream)),
token: Mutex::new(token.to_owned()),
reconnecting: AtomicBool::new(false),
queue: Default::default(),
options,
url: url.to_string(),
join_response: join_response.clone(),
request_id: AtomicU32::new(1),
single_pc_mode_active,
});
Ok((inner, join_response, events))
}
async fn validate(ws_url: url::Url, token: &str) -> SignalResult<()> {
let validate_url = get_validate_url(ws_url);
let validate_fut = async {
if let Ok(res) = http_client::get_with_token(validate_url.as_str(), token).await {
let status = res.status();
let body = res.text().await.ok().unwrap_or_default();
if status.is_client_error() {
return Err(SignalError::Client(status, body));
} else if status.is_server_error() {
return Err(SignalError::Server(status, body));
}
}
Ok(())
};
livekit_runtime::timeout(VALIDATE_TIMEOUT, validate_fut)
.await
.map_err(|_| SignalError::Timeout("validate request timed out".into()))?
}
pub fn is_single_pc_mode_active(&self) -> bool {
self.single_pc_mode_active
}
pub async fn restart(
self: &Arc<Self>,
) -> SignalResult<(
proto::ReconnectResponse,
mpsc::UnboundedReceiver<Box<proto::signal_response::Message>>,
)> {
self.reconnecting.store(true, Ordering::Release);
let mut stream_guard = self.stream.write().await;
if let Some(old_stream) = stream_guard.take() {
old_stream.close(false).await;
}
let sid = &self.join_response.participant.as_ref().unwrap().sid;
let token = self.token.lock().clone();
let lk_url =
get_livekit_url(&self.url, &self.options, self.single_pc_mode_active, true, None, sid)
.unwrap();
let result = async {
let (new_stream, mut events) =
SignalStream::connect(lk_url, &token, self.options.connect_timeout).await?;
let reconnect_response = get_reconnect_response(&mut events).await?;
SignalResult::Ok((new_stream, reconnect_response, events))
}
.await;
match result {
Ok((new_stream, reconnect_response, events)) => {
*stream_guard = Some(new_stream);
drop(stream_guard);
Ok((reconnect_response, events))
}
Err(err) => {
drop(stream_guard);
self.reconnecting.store(false, Ordering::Release);
Err(err)
}
}
}
pub async fn set_reconnected(&self) {
self.reconnecting.store(false, Ordering::Release);
self.flush_queue().await;
}
pub async fn close(&self, notify_close: bool) {
if let Some(stream) = self.stream.write().await.take() {
stream.close(notify_close).await;
}
}
pub async fn send(&self, signal: proto::signal_request::Message) {
let pass_through = is_pass_through(&signal);
let reconnecting = self.reconnecting.load(Ordering::Acquire);
if reconnecting && !pass_through {
self.queue.lock().await.push(signal);
return;
}
if !reconnecting {
self.flush_queue().await;
}
if let Some(stream) = self.stream.read().await.as_ref() {
if let Err(SignalError::SendError) = stream.send(signal.clone()).await {
if !pass_through {
self.queue.lock().await.push(signal);
} else {
log::warn!("dropping pass-through signal — send failed");
}
}
} else if !pass_through {
self.queue.lock().await.push(signal);
} else {
log::warn!("dropping pass-through signal — no stream available");
}
}
pub async fn flush_queue(&self) {
let mut queue = self.queue.lock().await;
if queue.is_empty() {
return;
}
if let Some(stream) = self.stream.read().await.as_ref() {
for signal in queue.drain(..) {
if let Err(err) = stream.send(signal).await {
log::error!("failed to send queued signal: {}", err); }
}
}
}
pub fn next_request_id(&self) -> u32 {
self.request_id.fetch_add(1, Ordering::SeqCst)
}
}
async fn signal_task(
inner: Arc<SignalInner>,
emitter: SignalEmitter, mut internal_events: mpsc::UnboundedReceiver<Box<proto::signal_response::Message>>,
) {
let mut ping_interval = interval(Duration::from_secs(inner.join_response.ping_interval as u64));
let timeout_duration = Duration::from_secs(inner.join_response.ping_timeout as u64);
let ping_timeout = sleep(timeout_duration);
tokio::pin!(ping_timeout);
let mut rtt = 0;
loop {
tokio::select! {
signal = internal_events.recv() => {
if let Some(signal) = signal {
match signal.as_ref() {
proto::signal_response::Message::RefreshToken(ref token) => {
*inner.token.lock() = token.clone();
}
proto::signal_response::Message::PongResp(ref pong) => {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as i64;
rtt = now - pong.last_ping_timestamp;
}
_ => {}
}
ping_timeout.as_mut().reset(Instant::now() + timeout_duration);
let _ = emitter.send(SignalEvent::Message(signal));
} else {
let _ = emitter.send(SignalEvent::Close("stream closed".into()));
break; }
}
_ = ping_interval.tick() => {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as i64;
let ping = proto::signal_request::Message::PingReq(proto::Ping{
timestamp: now,
rtt,
});
inner.send(ping).await;
}
_ = &mut ping_timeout => {
let _ = emitter.send(SignalEvent::Close("ping timeout".into()));
break;
}
}
}
inner.close(true).await; }
/// Returns `true` for requests that bypass the reconnect queue.
///
/// Such requests are sent even while reconnecting. When they cannot be delivered they are
/// dropped (with a warning) rather than queued.
fn is_pass_through(signal: &proto::signal_request::Message) -> bool {
    use proto::signal_request::Message as Req;
    match signal {
        Req::SyncState(_)
        | Req::Trickle(_)
        | Req::Offer(_)
        | Req::Answer(_)
        | Req::Simulate(_)
        | Req::Leave(_) => true,
        _ => false,
    }
}
/// Builds the `join_request` query value for the v1 signal path.
///
/// The `JoinRequest` (client info, connection settings, reconnect data) is protobuf-encoded,
/// wrapped in a `WrappedJoinRequest`, encoded again, and returned as standard base64.
fn create_join_request_param(
    options: &SignalOptions,
    reconnect: bool,
    reconnect_reason: Option<i32>,
    participant_sid: &str,
    os: String,
    os_version: String,
    device_model: String,
) -> String {
    let capabilities: Vec<i32> = CLIENT_CAPABILITIES.iter().map(|&cap| cap as i32).collect();

    let client_info = proto::ClientInfo {
        sdk: proto::client_info::Sdk::Rust as i32,
        version: options.sdk_options.sdk_version.clone().unwrap_or_default(),
        protocol: PROTOCOL_VERSION as i32,
        os,
        os_version,
        device_model,
        capabilities,
        ..Default::default()
    };

    let join_request = proto::JoinRequest {
        client_info: Some(client_info),
        connection_settings: Some(proto::ConnectionSettings {
            auto_subscribe: options.auto_subscribe,
            adaptive_stream: options.adaptive_stream,
            ..Default::default()
        }),
        reconnect,
        // An empty sid / absent reason equals the protobuf default, so assigning
        // unconditionally matches leaving the fields untouched.
        participant_sid: participant_sid.to_owned(),
        reconnect_reason: reconnect_reason.unwrap_or_default(),
        ..Default::default()
    };

    let wrapped = proto::WrappedJoinRequest {
        join_request: join_request.encode_to_vec(),
        ..Default::default()
    };
    BASE64_STANDARD.encode(wrapped.encode_to_vec())
}
/// Builds the signal websocket URL for `url`.
///
/// Steps:
/// - Normalises `http`/`https` to `ws`/`wss`; other schemes besides `ws`/`wss` are rejected.
/// - Appends `/rtc` (plus `/v1` when `use_v1_path`).
/// - Adds connection parameters. The v1 path gets a single base64 `join_request`; the v0 path
///   gets individual query pairs, with `reconnect=1&sid=<participant_sid>` when `reconnect`
///   is set.
///
/// # Errors
/// [`SignalError::UrlParse`] if `url` does not parse, has no host, or has an unsupported
/// scheme.
fn get_livekit_url(
    url: &str,
    options: &SignalOptions,
    use_v1_path: bool,
    reconnect: bool,
    reconnect_reason: Option<i32>,
    participant_sid: &str,
) -> SignalResult<url::Url> {
    let mut lk_url = url::Url::parse(url).map_err(|err| SignalError::UrlParse(err.to_string()))?;
    if !lk_url.has_host() {
        return Err(SignalError::UrlParse("missing host or scheme".into()));
    }

    let ws_scheme = match lk_url.scheme() {
        "https" | "wss" => "wss",
        "http" | "ws" => "ws",
        other => return Err(SignalError::UrlParse(format!("unsupported scheme: {}", other))),
    };
    // All four schemes are "special" in the URL standard, so switching between them succeeds.
    lk_url.set_scheme(ws_scheme).expect("switching between special schemes cannot fail");

    if let Ok(mut segs) = lk_url.path_segments_mut() {
        // Drop a trailing empty segment first so "wss://host/prefix/" yields "/prefix/rtc"
        // rather than "/prefix//rtc".
        segs.pop_if_empty().push("rtc");
        if use_v1_path {
            segs.push("v1");
        }
    }

    let os_info = os_info::get();
    let os_type = os_info.os_type().to_string();
    let os_version = os_info.version().to_string();
    let device_model = device_info::device_info().map(|info| info.model).unwrap_or_default();

    if use_v1_path {
        let join_request_param = create_join_request_param(
            options,
            reconnect,
            reconnect_reason,
            participant_sid,
            os_type,
            os_version,
            device_model.to_string(),
        );
        lk_url.query_pairs_mut().append_pair("join_request", &join_request_param);
    } else {
        let mut query = lk_url.query_pairs_mut();
        query
            .append_pair("sdk", options.sdk_options.sdk.as_str())
            .append_pair("os", &os_type)
            .append_pair("os_version", &os_version)
            .append_pair("device_model", device_model.to_string().as_str())
            .append_pair("protocol", PROTOCOL_VERSION.to_string().as_str())
            .append_pair("auto_subscribe", if options.auto_subscribe { "1" } else { "0" })
            .append_pair("adaptive_stream", if options.adaptive_stream { "1" } else { "0" });
        if let Some(sdk_version) = &options.sdk_options.sdk_version {
            query.append_pair("version", sdk_version.as_str());
        }
        if !CLIENT_CAPABILITIES.is_empty() {
            let caps =
                CLIENT_CAPABILITIES.iter().map(|c| c.as_str_name()).collect::<Vec<_>>().join(",");
            query.append_pair("capabilities", &caps);
        }
        if reconnect {
            query.append_pair("reconnect", "1").append_pair("sid", participant_sid);
        }
        // The serializer writes back into `lk_url` when dropped.
        drop(query);
    }

    Ok(lk_url)
}
/// Turns a signal websocket URL into its HTTP `/validate` counterpart.
///
/// `wss` becomes `https` and any other scheme becomes `http`. A `validate` segment is appended
/// to the existing path; the query string is kept.
fn get_validate_url(mut ws_url: url::Url) -> url::Url {
    let http_scheme = match ws_url.scheme() {
        "wss" => "https",
        _ => "http",
    };
    ws_url.set_scheme(http_scheme).unwrap();
    if let Ok(mut segments) = ws_url.path_segments_mut() {
        segments.push("validate");
    }
    ws_url
}
// Generates `async fn $fnc(receiver) -> SignalResult<$ty>`.
//
// The generated function:
// - skips messages until one matches `$pattern`, then returns `$result`;
// - yields `WsError::ConnectionClosed` if the channel ends first;
// - yields `SignalError::Timeout` if nothing matches within `JOIN_RESPONSE_TIMEOUT`.
macro_rules! get_async_message {
    ($fnc:ident, $pattern:pat => $result:expr, $ty:ty) => {
        async fn $fnc(
            receiver: &mut mpsc::UnboundedReceiver<Box<proto::signal_response::Message>>,
        ) -> SignalResult<$ty> {
            let join = async {
                while let Some(event) = receiver.recv().await {
                    if let $pattern = *event {
                        return Ok($result);
                    }
                }
                // Channel closed before the expected message: convert via `From<WsError>`.
                Err(WsError::ConnectionClosed)?
            };
            livekit_runtime::timeout(JOIN_RESPONSE_TIMEOUT, join).await.map_err(|_| {
                SignalError::Timeout(format!("failed to receive {}", std::any::type_name::<$ty>()))
            })?
        }
    };
}
// Waits for the server's `Join` message after the initial connect.
get_async_message!(
    get_join_response,
    proto::signal_response::Message::Join(msg) => msg,
    proto::JoinResponse
);
/// Waits for the server's `Reconnect` message after a reconnect connect.
///
/// Other messages are skipped. A `Leave` aborts with [`SignalError::LeaveRequest`], and a
/// closed channel yields `WsError::ConnectionClosed`. Gives up with [`SignalError::Timeout`]
/// after `JOIN_RESPONSE_TIMEOUT`.
async fn get_reconnect_response(
    receiver: &mut mpsc::UnboundedReceiver<Box<proto::signal_response::Message>>,
) -> SignalResult<proto::ReconnectResponse> {
    let wait_for_reconnect = async {
        loop {
            let Some(event) = receiver.recv().await else {
                break Err(SignalError::from(WsError::ConnectionClosed));
            };
            match *event {
                proto::signal_response::Message::Reconnect(msg) => break Ok(msg),
                proto::signal_response::Message::Leave(leave) => {
                    break Err(SignalError::LeaveRequest {
                        reason: leave.reason(),
                        action: leave.action(),
                    });
                }
                _ => continue,
            }
        }
    };

    match livekit_runtime::timeout(JOIN_RESPONSE_TIMEOUT, wait_for_reconnect).await {
        Ok(outcome) => outcome,
        Err(_) => Err(SignalError::Timeout(format!(
            "failed to receive {}",
            std::any::type_name::<proto::ReconnectResponse>()
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SignalInner` with no stream, an empty token and a default join response, for
    /// exercising the send/queue logic without a server.
    fn make_stub_inner() -> Arc<SignalInner> {
        Arc::new(SignalInner {
            stream: AsyncRwLock::new(None),
            token: Mutex::new(String::new()),
            reconnecting: AtomicBool::new(false),
            queue: Default::default(),
            url: "wss://localhost:7880".to_string(),
            options: SignalOptions::default(),
            join_response: proto::JoinResponse::default(),
            request_id: AtomicU32::new(1),
            single_pc_mode_active: false,
        })
    }

    // Non-pass-through requests sent while `reconnecting` is set must land in the queue.
    #[cfg(feature = "signal-client-tokio")]
    #[tokio::test]
    async fn send_queues_queueable_signals_during_reconnect() {
        let inner = make_stub_inner();
        inner.reconnecting.store(true, Ordering::Release);
        inner
            .send(proto::signal_request::Message::AddTrack(proto::AddTrackRequest {
                cid: "track1".into(),
                ..Default::default()
            }))
            .await;
        inner
            .send(proto::signal_request::Message::Mute(proto::MuteTrackRequest {
                sid: "sid1".into(),
                muted: true,
            }))
            .await;
        inner
            .send(proto::signal_request::Message::Subscription(proto::UpdateSubscription {
                track_sids: vec!["sid2".into()],
                ..Default::default()
            }))
            .await;
        let queue = inner.queue.lock().await;
        assert_eq!(queue.len(), 3, "all three queueable signals should be buffered");
    }

    // Every variant listed in `is_pass_through` must be dropped, not queued, when there is no
    // stream.
    #[cfg(feature = "signal-client-tokio")]
    #[tokio::test]
    async fn send_does_not_queue_pass_through_signals_during_reconnect() {
        let inner = make_stub_inner();
        inner.reconnecting.store(true, Ordering::Release);
        inner.send(proto::signal_request::Message::Trickle(proto::TrickleRequest::default())).await;
        inner
            .send(proto::signal_request::Message::Offer(proto::SessionDescription::default()))
            .await;
        inner
            .send(proto::signal_request::Message::Answer(proto::SessionDescription::default()))
            .await;
        inner.send(proto::signal_request::Message::SyncState(proto::SyncState::default())).await;
        inner
            .send(proto::signal_request::Message::Simulate(proto::SimulateScenario::default()))
            .await;
        inner.send(proto::signal_request::Message::Leave(proto::LeaveRequest::default())).await;
        let queue = inner.queue.lock().await;
        assert!(queue.is_empty(), "pass-through signals must not be queued, got {}", queue.len());
    }

    // NOTE(review): despite the name, this only checks that the flag is cleared. The stub has
    // no stream, so `flush_queue` leaves the queued request in place and nothing asserts on
    // draining. Either rename the test or give it a stream and assert the queue is empty.
    #[cfg(feature = "signal-client-tokio")]
    #[tokio::test]
    async fn set_reconnected_drains_queue_and_clears_flag() {
        let inner = make_stub_inner();
        inner.reconnecting.store(true, Ordering::Release);
        inner
            .send(proto::signal_request::Message::Mute(proto::MuteTrackRequest {
                sid: "sid1".into(),
                muted: true,
            }))
            .await;
        assert_eq!(inner.queue.lock().await.len(), 1);
        inner.set_reconnected().await;
        assert!(!inner.reconnecting.load(Ordering::Acquire), "flag must be cleared");
    }

    // Scheme normalisation: http(s) -> ws(s), ws(s) kept, missing/unsupported schemes rejected.
    #[test]
    fn livekit_url_test() {
        let io = SignalOptions::default();
        assert!(get_livekit_url("localhost:7880", &io, false, false, None, "").is_err());
        assert_eq!(
            get_livekit_url("https://localhost:7880", &io, false, false, None, "")
                .unwrap()
                .scheme(),
            "wss"
        );
        assert_eq!(
            get_livekit_url("http://localhost:7880", &io, false, false, None, "").unwrap().scheme(),
            "ws"
        );
        assert_eq!(
            get_livekit_url("wss://localhost:7880", &io, false, false, None, "").unwrap().scheme(),
            "wss"
        );
        assert_eq!(
            get_livekit_url("ws://localhost:7880", &io, false, false, None, "").unwrap().scheme(),
            "ws"
        );
        assert!(get_livekit_url("ftp://localhost:7880", &io, false, false, None, "").is_err());
    }

    // The validate URL is the signal URL with `/validate` appended and an HTTP(S) scheme.
    #[test]
    fn validate_url_test() {
        let io = SignalOptions::default();
        let lk_url = get_livekit_url("wss://localhost:7880", &io, false, false, None, "").unwrap();
        let validate_url = get_validate_url(lk_url);
        assert_eq!(validate_url.path(), "/rtc/validate");
        assert_eq!(validate_url.scheme(), "https");
    }

    // On the v0 path, capabilities are sent as a comma-joined list of their proto names.
    #[test]
    fn livekit_url_includes_client_capabilities() {
        let io = SignalOptions::default();
        let lk_url = get_livekit_url("wss://localhost:7880", &io, false, false, None, "").unwrap();
        let capabilities = lk_url
            .query_pairs()
            .find_map(|(key, value)| (key == "capabilities").then(|| value.into_owned()))
            .unwrap();
        let expected = CLIENT_CAPABILITIES
            .iter()
            .map(|capability| capability.as_str_name())
            .collect::<Vec<_>>()
            .join(",");
        assert_eq!(capabilities, expected);
    }

    // A one-shot local HTTP server captures the raw request so the Authorization header can
    // be checked. The test assumes the whole request arrives in a single 4 KiB read.
    #[cfg(feature = "signal-client-tokio")]
    #[tokio::test]
    async fn validate_sends_bearer_token() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = tokio::sync::oneshot::channel::<Vec<u8>>();
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buf = vec![0u8; 4096];
            let n = socket.read(&mut buf).await.unwrap();
            buf.truncate(n);
            let _ = tx.send(buf);
            let response = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
            let _ = socket.write_all(response).await;
        });
        let ws_url = url::Url::parse(&format!("ws://127.0.0.1:{}/rtc", addr.port())).unwrap();
        let result = SignalInner::validate(ws_url, "test-bearer-token").await;
        assert!(result.is_ok(), "expected Ok from validate, got: {:?}", result);
        let request = rx.await.expect("server task never received a request");
        let request = String::from_utf8_lossy(&request);
        assert!(
            request.to_lowercase().contains("authorization: bearer test-bearer-token"),
            "validate() must attach the access token as a Bearer header; request was:\n{}",
            request
        );
    }

    // The listener accepts TCP but never completes the websocket handshake, so the connect
    // must give up with `SignalError::Timeout` after the configured 500 ms.
    #[cfg(feature = "signal-client-tokio")]
    #[tokio::test]
    async fn signal_stream_connect_timeout() {
        use tokio::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let _accept_task = tokio::spawn(async move {
            loop {
                let Ok((_socket, _)) = listener.accept().await else {
                    break;
                };
                tokio::time::sleep(Duration::from_secs(60)).await;
            }
        });
        let url = url::Url::parse(&format!("ws://127.0.0.1:{}", addr.port())).unwrap();
        let result = SignalStream::connect(url, "fake-token", Duration::from_millis(500)).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, SignalError::Timeout(_)), "expected Timeout error, got: {:?}", err);
    }

    // A canned JSON regions response must be parsed into the URLs, preserving order.
    #[cfg(feature = "signal-client-tokio")]
    #[tokio::test]
    async fn region_fetch_parses_response() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 4096];
            let _ = tokio::io::AsyncReadExt::read(&mut socket, &mut buf).await;
            let body = r#"{"regions":[{"region":"us-east-1","url":"wss://us-east.livekit.cloud","distance":"100"},{"region":"eu-west-1","url":"wss://eu-west.livekit.cloud","distance":"200"}]}"#;
            let response = format!(
                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
                body.len(),
                body
            );
            socket.write_all(response.as_bytes()).await.unwrap();
        });
        let endpoint = format!("http://127.0.0.1:{}/settings/regions", addr.port());
        let result = region::fetch_from_endpoint(&endpoint, "fake-token").await;
        let urls = result.unwrap();
        assert_eq!(
            urls,
            vec![
                "wss://us-east.livekit.cloud".to_string(),
                "wss://eu-west.livekit.cloud".to_string(),
            ]
        );
    }

    // A server that accepts but never answers must produce a `RegionError` mentioning
    // "timed out". The timeout itself is configured inside the `region` module.
    #[cfg(feature = "signal-client-tokio")]
    #[tokio::test]
    async fn region_fetch_timeout() {
        use tokio::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let Ok((_socket, _)) = listener.accept().await else {
                    break;
                };
                tokio::time::sleep(Duration::from_secs(60)).await;
            }
        });
        let endpoint = format!("http://127.0.0.1:{}/settings/regions", addr.port());
        let result = region::fetch_from_endpoint(&endpoint, "fake-token").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, SignalError::RegionError(ref msg) if msg.contains("timed out")),
            "expected RegionError with 'timed out', got: {:?}",
            err
        );
    }
}