use std::collections::HashSet;
use std::collections::hash_map::Entry;
use std::sync::Weak;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use indexmap::IndexSet;
use futures_util::SinkExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::runtime::Handle;
use tokio::task::{JoinError, JoinSet};
use tokio::time::MissedTickBehavior;
use tokio_tungstenite::tungstenite::Message;
use tokio_util::sync::CancellationToken;
use crate::library_version::get_library_version;
use crate::sink_channel_filter::SinkChannelFilter;
use crate::websocket::connected_client::ShutdownReason;
use crate::websocket::streams::{Acceptor, StreamConfiguration, TlsIdentity};
use crate::{Context, FoxgloveError};
use super::connected_client::ConnectedClient;
use super::cow_vec::CowVec;
use super::service::{Service, ServiceId, ServiceMap};
use super::ws_protocol::server::PlaybackState;
use super::ws_protocol::server::{
AdvertiseServices, RemoveStatus, ServerInfo, UnadvertiseServices,
};
use super::{
AssetHandler, Capability, ClientId, ConnectionGraph, Parameter, ParameterHandler,
ServerListener, Status, advertise, handshake,
};
/// Per-client message backlog size used when `ServerOptions::message_backlog_size` is `None`.
const DEFAULT_MESSAGE_BACKLOG_SIZE: usize = 1024;
/// Configuration used by `create_server` to build a [`Server`].
///
/// Every field is optional; unset fields fall back to the defaults noted below.
#[derive(Default)]
pub(crate) struct ServerOptions {
    /// Session ID reported to clients. When `None`, one is generated from the current
    /// wall-clock time by `Server::generate_session_id`.
    pub session_id: Option<String>,
    /// Server name reported in the server info message; defaults to an empty string.
    pub name: Option<String>,
    /// Per-client message backlog size; defaults to `DEFAULT_MESSAGE_BACKLOG_SIZE` (1024).
    pub message_backlog_size: Option<usize>,
    /// Optional callbacks for client connect/disconnect, parameter subscription, and
    /// connection-graph subscription events.
    pub listener: Option<Arc<dyn ServerListener>>,
    /// Capabilities to advertise. `Services`, `PlaybackControl`, `Assets`, and `Parameters`
    /// are added automatically when the corresponding options below are provided.
    pub capabilities: Option<IndexSet<Capability>>,
    /// Services registered at startup. Only the values are kept; presumably the keys are the
    /// service names — verify against callers.
    pub services: HashMap<String, Service>,
    /// Encodings advertised as supported. Request encodings declared by `services` are merged
    /// in automatically. If services are configured, either this set or at least one service
    /// must provide a request encoding.
    pub supported_encodings: Option<IndexSet<String>>,
    /// Tokio runtime used to spawn the server's tasks; defaults to
    /// `crate::runtime::get_runtime_handle()`.
    pub runtime: Option<Handle>,
    /// Handler for asset fetch requests; setting it enables the `Assets` capability.
    pub fetch_asset_handler: Option<Arc<dyn AssetHandler>>,
    /// Handler for parameter requests; setting it enables the `Parameters` capability.
    pub parameter_handler: Option<Arc<dyn ParameterHandler>>,
    /// TLS identity passed to `StreamConfiguration::new` when building the stream acceptor.
    pub tls_identity: Option<TlsIdentity>,
    /// Channel filter handed to every connected client.
    pub channel_filter: Option<Arc<dyn SinkChannelFilter>>,
    /// Extra metadata included in the server info message. The `fg-library` key is reserved
    /// and is overwritten with the library version.
    pub server_info: Option<HashMap<String, String>>,
    /// Playback time range (presumably `(start, end)` timestamps — confirm units). Setting it
    /// enables the `PlaybackControl` capability.
    pub playback_time_range: Option<(u64, u64)>,
}
impl std::fmt::Debug for ServerOptions {
    /// Formats the plain-data options.
    ///
    /// The listener, handlers, runtime, TLS identity, and channel filter are not printed.
    /// `finish_non_exhaustive` marks the output with `..` so readers know fields are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServerOptions")
            .field("session_id", &self.session_id)
            .field("name", &self.name)
            .field("message_backlog_size", &self.message_backlog_size)
            .field("services", &self.services)
            .field("capabilities", &self.capabilities)
            .field("supported_encodings", &self.supported_encodings)
            .field("server_info", &self.server_info)
            .field("playback_time_range", &self.playback_time_range)
            .finish_non_exhaustive()
    }
}
/// Inspects the outcome of a joined server task.
///
/// A task that panicked is logged as a warning. Successful completions and any non-panic
/// join errors are ignored.
fn process_task_result(result: Result<(), JoinError>) {
    let Err(e) = result else {
        return;
    };
    if e.is_panic() {
        tracing::warn!("{e}");
    }
}
/// Handle returned when the server is stopped.
///
/// Use it to wait for, or detach, the connection tasks that had not yet completed when the
/// server stopped.
#[must_use]
#[derive(Debug)]
pub struct ShutdownHandle {
    /// Runtime used by `wait_blocking` to drive the wait to completion.
    runtime: Handle,
    /// Server tasks that were still outstanding when the server stopped.
    tasks: JoinSet<()>,
}
impl ShutdownHandle {
fn new(runtime: Handle, tasks: JoinSet<()>) -> Self {
Self { runtime, tasks }
}
pub fn detach_all(mut self) {
self.tasks.detach_all();
}
async fn wait_inner(&mut self) {
while let Some(result) = self.tasks.join_next().await {
process_task_result(result);
}
tracing::info!("Shutdown complete");
}
pub async fn wait(mut self) {
self.wait_inner().await;
}
pub fn wait_blocking(mut self) {
self.runtime.clone().block_on(self.wait_inner());
}
}
/// Validates `opts` and builds a new, not-yet-started [`Server`] attached to `ctx`.
///
/// # Errors
///
/// - [`FoxgloveError::MissingRequestEncoding`] when services are configured, the options list
///   no supported encodings, and no service declares a request encoding.
/// - Any error returned by `StreamConfiguration::new`.
pub(crate) fn create_server(
    ctx: &Arc<Context>,
    opts: ServerOptions,
) -> Result<Arc<Server>, FoxgloveError> {
    if !opts.services.is_empty() {
        let server_declares_encodings = matches!(
            &opts.supported_encodings,
            Some(encodings) if !encodings.is_empty()
        );
        let no_encoding_anywhere = !server_declares_encodings
            && opts
                .services
                .values()
                .all(|svc| svc.request_encoding().is_none());
        if no_encoding_anywhere {
            // No service has its own encoding at this point; report the first one found.
            let offending = opts
                .services
                .values()
                .find(|svc| svc.request_encoding().is_none());
            if let Some(svc) = offending {
                return Err(FoxgloveError::MissingRequestEncoding(
                    svc.name().to_string(),
                ));
            }
        }
    }

    let stream_config = StreamConfiguration::new(opts.tls_identity.as_ref())?;
    // `new_cyclic` lets the server hold a weak reference to its own `Arc`.
    let server = Arc::new_cyclic(|weak| Server::new(weak.clone(), ctx, opts, stream_config));
    Ok(server)
}
/// WebSocket server state shared by the accept loop, connection tasks, and public API.
pub(crate) struct Server {
    /// Weak self-reference, upgraded by `arc()` to hand owned clones to spawned tasks.
    weak_self: Weak<Self>,
    /// Context that connected clients are added to (and removed from) as sinks; held weakly.
    context: Weak<Context>,
    /// Per-client message backlog size passed to each `ConnectedClient`.
    message_backlog_size: u32,
    /// Runtime on which the accept loop is spawned.
    runtime: Handle,
    /// Current session ID; replaced by `clear_session`.
    session_id: parking_lot::RwLock<String>,
    /// Server name reported in the server info message.
    name: String,
    /// Currently registered clients. Frozen by `stop`, after which new pushes are rejected.
    clients: CowVec<Arc<ConnectedClient>>,
    /// Channel filter handed to every connected client.
    channel_filter: Option<Arc<dyn SinkChannelFilter>>,
    /// Optional callbacks for server events.
    listener: Option<Arc<dyn ServerListener>>,
    /// Advertised capabilities, including those derived from the options.
    capabilities: IndexSet<Capability>,
    /// Parameter name -> IDs of the clients subscribed to it.
    subscribed_parameters: parking_lot::RwLock<HashMap<String, HashSet<ClientId>>>,
    /// Advertised supported encodings.
    supported_encodings: IndexSet<String>,
    /// Current connection graph together with its subscribers.
    connection_graph: parking_lot::Mutex<ConnectionGraph>,
    /// Cancelled by `stop`; ends the accept/reap task and blocks any later `start`.
    cancellation_token: CancellationToken,
    /// Registered services.
    services: parking_lot::RwLock<ServiceMap>,
    /// Handler for asset fetch requests, if any.
    fetch_asset_handler: Option<Arc<dyn AssetHandler>>,
    /// Handler for parameter requests, if any.
    parameter_handler: Option<Arc<dyn ParameterHandler>>,
    /// Connection tasks: `None` before `start`, `Some` while running, and taken by `stop`.
    tasks: parking_lot::Mutex<Option<JoinSet<()>>>,
    /// Accepts incoming TCP streams, optionally wrapping them in TLS.
    stream_config: StreamConfiguration,
    /// User-supplied metadata for the server info message.
    server_info: HashMap<String, String>,
    /// Playback time range, present iff the `PlaybackControl` capability is enabled.
    playback_time_range: Option<(u64, u64)>,
}
impl Server {
pub fn generate_session_id() -> String {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.ok()
.map(|d| d.as_millis().to_string())
.unwrap_or_default()
}
fn new(
weak_self: Weak<Self>,
ctx: &Arc<Context>,
opts: ServerOptions,
stream_config: StreamConfiguration,
) -> Self {
let mut capabilities = opts.capabilities.unwrap_or_default();
let mut supported_encodings = opts.supported_encodings.unwrap_or_default();
if !opts.services.is_empty() {
capabilities.insert(Capability::Services);
supported_encodings.extend(
opts.services
.values()
.filter_map(|svc| svc.schema().request().map(|s| s.encoding.clone())),
);
}
if opts.playback_time_range.is_some() {
capabilities.insert(Capability::PlaybackControl);
} else if capabilities.contains(&Capability::PlaybackControl) {
panic!(
"Server declared the PlaybackControl capability but did not provide a playback time range"
);
}
if opts.fetch_asset_handler.is_some() {
capabilities.insert(Capability::Assets);
}
if opts.parameter_handler.is_some() {
capabilities.insert(Capability::Parameters);
}
Server {
weak_self,
context: Arc::downgrade(ctx),
message_backlog_size: opts
.message_backlog_size
.unwrap_or(DEFAULT_MESSAGE_BACKLOG_SIZE) as u32,
runtime: opts
.runtime
.unwrap_or_else(crate::runtime::get_runtime_handle),
channel_filter: opts.channel_filter.clone(),
listener: opts.listener,
session_id: parking_lot::RwLock::new(
opts.session_id.unwrap_or_else(Self::generate_session_id),
),
name: opts.name.unwrap_or_default(),
clients: CowVec::new(),
subscribed_parameters: parking_lot::RwLock::default(),
capabilities,
supported_encodings,
connection_graph: parking_lot::Mutex::default(),
cancellation_token: CancellationToken::new(),
services: parking_lot::RwLock::new(ServiceMap::from_iter(opts.services.into_values())),
fetch_asset_handler: opts.fetch_asset_handler,
parameter_handler: opts.parameter_handler,
tasks: parking_lot::Mutex::default(),
stream_config,
server_info: opts.server_info.unwrap_or_default(),
playback_time_range: opts.playback_time_range,
}
}
fn arc(&self) -> Arc<Self> {
self.weak_self
.upgrade()
.expect("server cannot be dropped while in use")
}
pub(super) fn has_capability(&self, cap: Capability) -> bool {
self.capabilities.contains(&cap)
}
pub(super) fn supports_encoding(&self, encoding: impl AsRef<str>) -> bool {
self.supported_encodings.contains(encoding.as_ref())
}
pub(super) fn fetch_asset_handler(&self) -> Option<&dyn AssetHandler> {
self.fetch_asset_handler.as_deref()
}
pub(super) fn parameter_handler(&self) -> Option<&dyn ParameterHandler> {
self.parameter_handler.as_deref()
}
pub(super) fn listener(&self) -> Option<&dyn ServerListener> {
self.listener.as_deref()
}
pub async fn start(&self, host: &str, port: u16) -> Result<SocketAddr, FoxgloveError> {
{
let mut tasks = self.tasks.lock();
if tasks.is_some() || self.cancellation_token.is_cancelled() {
return Err(FoxgloveError::ServerAlreadyStarted);
}
tasks.replace(JoinSet::new());
}
let addr = format!("{host}:{port}");
let listener = TcpListener::bind(&addr)
.await
.map_err(FoxgloveError::Bind)?;
let local_addr = listener.local_addr().map_err(FoxgloveError::Bind)?;
let cancellation_token = self.cancellation_token.clone();
let server = self.arc();
self.runtime.spawn(async move {
tokio::select! {
() = server.clone().accept_connections(listener) => (),
() = server.clone().reap_completed_tasks() => (),
() = cancellation_token.cancelled() => (),
}
});
let maybe_tls = if self.is_tls_configured() {
" (TLS enabled)"
} else {
""
};
tracing::info!("Started server on {local_addr}{maybe_tls}");
Ok(local_addr)
}
async fn accept_connections(self: Arc<Self>, listener: TcpListener) {
while let Ok((stream, addr)) = listener.accept().await {
if let Some(tasks) = self.tasks.lock().as_mut() {
tasks.spawn(self.clone().handle_connection(stream, addr));
} else {
break;
}
}
}
async fn reap_completed_tasks(self: Arc<Self>) {
let mut interval = tokio::time::interval(Duration::from_secs(1));
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
loop {
interval.tick().await;
if let Some(tasks) = self.tasks.lock().as_mut() {
while let Some(result) = tasks.try_join_next() {
process_task_result(result);
}
} else {
break;
}
}
}
#[must_use]
pub fn stop(&self) -> Option<ShutdownHandle> {
let tasks = self.tasks.lock().take()?;
tracing::info!("Shutting down");
self.cancellation_token.cancel();
let clients = self.clients.take_and_freeze();
for client in clients.iter() {
client.shutdown(ShutdownReason::ServerStopped);
}
Some(ShutdownHandle::new(self.runtime.clone(), tasks))
}
pub fn client_count(&self) -> usize {
self.clients.get().len()
}
pub fn broadcast_time(&self, timestamp: u64) {
use super::ws_protocol::server::Time;
if !self.has_capability(Capability::Time) {
tracing::error!("Server does not support time capability");
return;
}
let message = Time::new(timestamp);
let clients = self.clients.get();
for client in clients.iter() {
client.send_control_msg(&message);
}
}
pub fn broadcast_playback_state(&self, playback_state: PlaybackState) {
if !self.has_capability(Capability::PlaybackControl) {
tracing::error!("Server does not support the PlaybackControl capability");
return;
}
for client in self.clients.get().iter() {
client.send_control_msg(&playback_state);
}
}
pub(super) fn subscribe_parameters(&self, client_id: ClientId, names: Vec<String>) {
let mut subs = self.subscribed_parameters.write();
let mut new_names = vec![];
for name in names {
match subs.entry(name.clone()) {
Entry::Occupied(mut entry) => {
entry.get_mut().insert(client_id);
}
Entry::Vacant(entry) => {
entry.insert(HashSet::from_iter([client_id]));
new_names.push(name);
}
}
}
if !new_names.is_empty() {
if let Some(listener) = self.listener.as_ref() {
listener.on_parameters_subscribe(new_names);
}
}
}
pub(super) fn unsubscribe_parameters(&self, client_id: ClientId, names: Vec<String>) {
let mut subs = self.subscribed_parameters.write();
let mut old_names = vec![];
for name in names {
if let Some(entry) = subs.get_mut(&name) {
if entry.remove(&client_id) && entry.is_empty() {
subs.remove(&name);
old_names.push(name);
}
}
}
if !old_names.is_empty() {
if let Some(listener) = self.listener.as_ref() {
listener.on_parameters_unsubscribe(old_names);
}
}
}
fn unsubscribe_all_parameters(&self, client_id: ClientId) {
let mut subs = self.subscribed_parameters.write();
let mut old_names = vec![];
for (name, entry) in subs.iter_mut() {
if entry.remove(&client_id) && entry.is_empty() {
old_names.push(name.clone());
}
}
for name in &old_names {
subs.remove(name);
}
if !old_names.is_empty() {
if let Some(listener) = self.listener.as_ref() {
listener.on_parameters_unsubscribe(old_names);
}
}
}
pub(super) fn subscribe_connection_graph(&self, client_id: ClientId) -> Option<Message> {
let mut graph = self.connection_graph.lock();
let first = !graph.has_subscribers();
if !graph.add_subscriber(client_id) {
return None;
}
if first {
if let Some(listener) = self.listener.as_ref() {
listener.on_connection_graph_subscribe();
}
}
let initial_update = Message::from(&graph.as_initial_update());
Some(initial_update)
}
pub(super) fn unsubscribe_connection_graph(&self, client_id: ClientId) -> bool {
let mut graph = self.connection_graph.lock();
if !graph.remove_subscriber(client_id) {
return false;
}
if !graph.has_subscribers() {
if let Some(listener) = self.listener.as_ref() {
listener.on_connection_graph_unsubscribe();
}
}
true
}
pub fn publish_parameter_values(&self, parameters: Vec<Parameter>) {
if !self.has_capability(Capability::Parameters) {
tracing::error!("Server does not support parameters capability");
return;
}
let clients = self.clients.get();
for client in clients.iter() {
let filtered: Vec<_> = {
let subs = self.subscribed_parameters.read();
parameters
.iter()
.filter(|p| {
subs.get(&p.name)
.is_some_and(|ids| ids.contains(&client.id()))
})
.cloned()
.collect()
};
if !filtered.is_empty() {
client.update_parameters(filtered, None);
}
}
}
pub fn publish_status(&self, status: Status) {
let clients = self.clients.get();
for client in clients.iter() {
client.send_status(status.clone());
}
}
pub fn remove_status(&self, status_ids: Vec<String>) {
let message = RemoveStatus { status_ids };
let clients = self.clients.get();
for client in clients.iter() {
client.send_control_msg(&message);
}
}
fn server_info(&self) -> ServerInfo {
let mut metadata = self.server_info.clone();
if metadata.contains_key("fg-library") {
tracing::warn!("Overwriting reserved server_info key 'fg-library'");
}
metadata.insert("fg-library".into(), get_library_version());
ServerInfo::new(&self.name)
.with_capabilities(
self.capabilities
.iter()
.flat_map(Capability::as_protocol_capabilities)
.copied(),
)
.with_metadata(metadata)
.with_supported_encodings(&self.supported_encodings)
.with_session_id(self.session_id.read().clone())
.with_playback_time_range(self.playback_time_range)
}
pub fn clear_session(&self, new_session_id: Option<String>) {
*self.session_id.write() = new_session_id.unwrap_or_else(Self::generate_session_id);
let message = self.server_info();
let clients = self.clients.get();
for client in clients.iter() {
client.send_control_msg(&message);
}
}
async fn handle_connection(self: Arc<Self>, stream: TcpStream, addr: SocketAddr) {
let stream = match self.stream_config.accept(stream).await {
Ok(maybe_tls_stream) => maybe_tls_stream,
Err(e) => {
tracing::error!("Dropping client {addr}: secure handshake failed: {}", e);
return;
}
};
let Ok(mut ws_stream) = handshake::do_handshake(stream).await else {
tracing::error!("Dropping client {addr}: handshake failed");
return;
};
let message = Message::from(&self.server_info());
if let Err(err) = ws_stream.send(message).await {
tracing::error!("Failed to send required server info: {err}");
return;
}
let client = ConnectedClient::new(
&self.context,
&self.weak_self,
ws_stream,
addr,
self.message_backlog_size as usize,
self.channel_filter.clone(),
);
self.register_client_and_advertise(&client);
client.run().await;
self.unregister_client(&client);
}
fn register_client_and_advertise(&self, client: &Arc<ConnectedClient>) {
if !self.clients.push(client.clone()) {
tracing::debug!("Disconnecting client {}: server is stopped", client.addr());
client.shutdown(ShutdownReason::ServerStopped);
return;
}
tracing::info!("Registered client {}", client.addr());
if let Some(listener) = self.listener() {
tracing::debug!("Notifying listener of client connection");
listener.on_client_connect();
}
if let Some(context) = self.context.upgrade() {
context.add_sink(client.clone());
}
let services: Vec<_> = self.services.read().values().cloned().collect();
let msg = advertise::advertise_services(services.iter().map(|s| s.as_ref()));
if msg.services.is_empty() {
return;
}
if client.send_control_msg(&msg) {
for service in services {
tracing::debug!(
"Advertised service {} with id {} to client {}",
service.name(),
service.id(),
client.addr()
);
}
}
}
fn unregister_client(&self, client: &Arc<ConnectedClient>) {
if let Some(context) = self.context.upgrade() {
context.remove_sink(client.sink_id());
}
self.clients.retain(|c| c.id() != client.id());
if self.has_capability(Capability::Parameters) {
self.unsubscribe_all_parameters(client.id());
}
if self.has_capability(Capability::ConnectionGraph) {
self.unsubscribe_connection_graph(client.id());
}
client.on_disconnect();
if let Some(listener) = self.listener() {
listener.on_client_disconnect();
}
tracing::info!("Unregistered client {}", client.addr());
}
pub fn add_services(&self, new_services: Vec<Service>) -> Result<(), FoxgloveError> {
if !self.has_capability(Capability::Services) {
return Err(FoxgloveError::ServicesNotSupported);
}
if new_services.is_empty() {
return Ok(());
}
let mut new_names = HashMap::with_capacity(new_services.len());
let mut msg = AdvertiseServices { services: vec![] };
for service in &new_services {
if new_names
.insert(service.name().to_string(), service.id())
.is_some()
{
return Err(FoxgloveError::DuplicateService(service.name().to_string()));
}
if service.request_encoding().is_none() && self.supported_encodings.is_empty() {
return Err(FoxgloveError::MissingRequestEncoding(
service.name().to_string(),
));
}
if let Some(adv) = advertise::maybe_advertise_service(service) {
msg.services.push(adv.into_owned());
}
}
{
let mut services = self.services.write();
for service in &new_services {
if services.contains_name(service.name()) || services.contains_id(service.id()) {
return Err(FoxgloveError::DuplicateService(service.name().to_string()));
}
}
for service in new_services {
services.insert(service);
}
}
if msg.services.is_empty() {
return Ok(());
}
let clients = self.clients.get();
for client in clients.iter() {
for (name, id) in &new_names {
tracing::debug!(
"Advertising service {name} with id {id} to client {}",
client.addr()
);
}
client.send_control_msg(&msg);
}
Ok(())
}
pub fn remove_services(&self, names: impl IntoIterator<Item = impl AsRef<str>>) {
let names = names.into_iter();
let mut old_services = HashMap::with_capacity(names.size_hint().0);
{
let mut services = self.services.write();
for name in names {
if let Some(service) = services.remove_by_name(name) {
old_services.insert(service.id(), service.name().to_string());
}
}
}
if old_services.is_empty() {
return;
}
let msg = UnadvertiseServices::new(old_services.keys().map(|&id| id.into()));
let clients = self.clients.get();
for client in clients.iter() {
for (id, name) in &old_services {
tracing::debug!(
"Unadvertising service {name} with id {id} to client {}",
client.addr()
);
}
client.send_control_msg(&msg);
}
}
pub(super) fn get_service(&self, id: ServiceId) -> Option<Arc<Service>> {
self.services.read().get_by_id(id)
}
pub fn replace_connection_graph(
&self,
replacement_graph: ConnectionGraph,
) -> Result<(), FoxgloveError> {
if !self.has_capability(Capability::ConnectionGraph) {
return Err(FoxgloveError::ConnectionGraphNotSupported);
}
let mut graph = self.connection_graph.lock();
let msg = graph.update(replacement_graph);
for client in self.clients.get().iter() {
if graph.is_subscriber(client.id()) {
client.send_control_msg(&msg);
}
}
Ok(())
}
pub(crate) fn is_tls_configured(&self) -> bool {
self.stream_config.accepts_tls()
}
}