use std::{collections::HashMap, fmt::Display, future::Future, sync::Arc, time::Duration};
use indexmap::IndexSet;
use tokio::runtime::Handle;
use tokio::task::JoinHandle;
use crate::{
ChannelDescriptor, Context, FoxgloveError, SinkChannelFilter, SinkId,
protocol::v2::parameter::Parameter,
remote_common::connection_graph::ConnectionGraph,
remote_common::fetch_asset::{AssetHandler, AsyncAssetHandlerFn, BlockingAssetHandlerFn},
remote_common::service::{Service, ServiceMap},
runtime::get_runtime_handle,
sink_channel_filter::SinkChannelFilterFn,
};
use super::qos::{QosClassifier, QosClassifierFn, QosProfile};
use super::connection::{ConnectionParams, ConnectionStatus, RemoteAccessConnection};
use super::{Capability, Client, Listener};
/// Handle to a gateway whose connection task has been spawned.
///
/// Obtained from [`Gateway::start`]. Call [`GatewayHandle::stop`] (async callers) or
/// [`GatewayHandle::stop_blocking`] (sync callers) to shut the connection down.
pub struct GatewayHandle {
    /// Shared connection state; the spawned runner task holds its own clone of this `Arc`.
    connection: Arc<RemoteAccessConnection>,
    /// Join handle of the task spawned via `spawn_run_until_cancelled`.
    runner: JoinHandle<()>,
    /// Runtime handle used by `stop_blocking` to wait for `runner` to finish.
    runtime: Handle,
}
impl GatewayHandle {
    /// Spawns the connection's run loop and wraps it in a handle.
    ///
    /// The runner task receives its own clone of `connection`, so the connection stays
    /// alive for as long as the task runs.
    fn new(connection: Arc<RemoteAccessConnection>, runtime: Handle) -> Self {
        let runner = connection.clone().spawn_run_until_cancelled();
        Self {
            connection,
            runner,
            runtime,
        }
    }

    /// Returns the current status of the remote access connection.
    pub fn connection_status(&self) -> ConnectionStatus {
        self.connection.status()
    }

    /// Returns the sink id currently associated with the connection, if any.
    #[doc(hidden)]
    pub fn sink_id(&self) -> Option<SinkId> {
        self.connection.sink_id()
    }

    /// Registers additional services on the running connection.
    ///
    /// # Errors
    ///
    /// Propagates any [`FoxgloveError`] returned by the connection when adding the services.
    pub fn add_services(
        &self,
        services: impl IntoIterator<Item = Service>,
    ) -> Result<(), FoxgloveError> {
        self.connection.add_services(services.into_iter().collect())
    }

    /// Removes previously registered services by name.
    pub fn remove_services(&self, names: impl IntoIterator<Item = impl AsRef<str>>) {
        self.connection.remove_services(names);
    }

    /// Publishes the given parameter values through the connection.
    pub fn publish_parameter_values(&self, parameters: Vec<Parameter>) {
        self.connection.publish_parameter_values(parameters);
    }

    /// Publishes a status message through the connection.
    pub fn publish_status(&self, status: super::Status) {
        self.connection.publish_status(status);
    }

    /// Removes previously published status messages by id.
    pub fn remove_status(&self, status_ids: Vec<String>) {
        self.connection.remove_status(status_ids);
    }

    /// Replaces the connection graph with `replacement_graph`.
    ///
    /// # Errors
    ///
    /// Propagates any [`FoxgloveError`] returned by the connection.
    pub fn publish_connection_graph(
        &self,
        replacement_graph: ConnectionGraph,
    ) -> Result<(), FoxgloveError> {
        self.connection.replace_connection_graph(replacement_graph)
    }

    /// Requests shutdown and returns the runner's join handle so async callers can await it.
    pub fn stop(self) -> JoinHandle<()> {
        self.connection.shutdown();
        self.runner
    }

    /// Test-only constructor that wraps an arbitrary runner task around a default connection.
    #[cfg(test)]
    fn with_runner(runner: JoinHandle<()>, runtime: Handle) -> Self {
        let params = ConnectionParams {
            name: None,
            device_token: String::new(),
            foxglove_api_url: None,
            foxglove_api_timeout: None,
            listener: None,
            capabilities: Vec::new(),
            supported_encodings: None,
            fetch_asset_handler: None,
            runtime: runtime.clone(),
            channel_filter: None,
            qos_classifier: None,
            server_info: None,
            message_backlog_size: None,
            context: std::sync::Weak::new(),
        };
        let services = Arc::new(parking_lot::RwLock::new(ServiceMap::default()));
        let connection = RemoteAccessConnection::new(params, services);
        Self {
            connection: Arc::new(connection),
            runner,
            runtime,
        }
    }

    /// Requests shutdown and blocks the current thread until the runner task finishes.
    ///
    /// A runner that panicked or was cancelled is logged as a warning rather than propagated.
    ///
    /// # Panics
    ///
    /// Per Tokio's `Handle::block_on` documentation, this panics if called from within an
    /// asynchronous execution context; async callers should use [`GatewayHandle::stop`].
    pub fn stop_blocking(self) {
        self.connection.shutdown();
        match self.runtime.block_on(self.runner) {
            Ok(()) => {}
            // A JoinError is either a panic or a cancellation (e.g. runtime shutdown);
            // report which one actually happened.
            Err(e) if e.is_panic() => {
                tracing::warn!("Gateway connection task panicked: {e}");
            }
            Err(e) => {
                tracing::warn!("Gateway connection task was cancelled: {e}");
            }
        }
    }
}
/// Environment variable read for the device token when the builder has none.
const FOXGLOVE_DEVICE_TOKEN_ENV: &'static str = "FOXGLOVE_DEVICE_TOKEN";
/// Environment variable read for the Foxglove API URL when the builder has none.
const FOXGLOVE_API_URL_ENV: &'static str = "FOXGLOVE_API_URL";
/// Environment variable read for the API timeout (whole seconds) when the builder has none.
const FOXGLOVE_API_TIMEOUT_ENV: &'static str = "FOXGLOVE_API_TIMEOUT";
/// Builder for a remote access gateway; finish configuration with [`Gateway::start`].
#[must_use]
pub struct Gateway {
    /// Optional gateway name, forwarded to the connection parameters.
    name: Option<String>,
    /// Device token; when `None`, `start` falls back to the `FOXGLOVE_DEVICE_TOKEN` variable.
    device_token: Option<String>,
    /// API URL; when `None`, `start` falls back to the `FOXGLOVE_API_URL` variable.
    foxglove_api_url: Option<String>,
    /// API timeout; when `None`, `start` falls back to the `FOXGLOVE_API_TIMEOUT` variable.
    foxglove_api_timeout: Option<Duration>,
    /// Optional listener forwarded to the connection.
    listener: Option<Arc<dyn Listener>>,
    /// Requested capabilities; `start` may append `Services` and `Assets`.
    capabilities: Vec<Capability>,
    /// Supported encodings; `start` extends this with service request encodings.
    supported_encodings: Option<IndexSet<String>>,
    /// Initial services keyed by service name.
    services: HashMap<String, Service>,
    /// Optional handler for asset fetch requests.
    fetch_asset_handler: Option<Box<dyn AssetHandler<Client>>>,
    /// Runtime handle; `start` falls back to `get_runtime_handle()` when unset.
    runtime: Option<Handle>,
    /// Optional channel filter forwarded to the connection.
    channel_filter: Option<Arc<dyn SinkChannelFilter>>,
    /// Optional QoS classifier forwarded to the connection.
    qos_classifier: Option<Arc<dyn QosClassifier>>,
    /// Optional server info map forwarded to the connection.
    server_info: Option<HashMap<String, String>>,
    /// Optional message backlog size forwarded to the connection.
    message_backlog_size: Option<usize>,
    /// Weak reference to the context; defaults to the default context.
    context: std::sync::Weak<Context>,
}
impl Default for Gateway {
    /// Returns an unconfigured builder that holds a weak reference to the default [`Context`].
    fn default() -> Self {
        let default_context = Arc::downgrade(&Context::get_default());
        Self {
            context: default_context,
            services: HashMap::new(),
            capabilities: Vec::new(),
            name: None,
            device_token: None,
            foxglove_api_url: None,
            foxglove_api_timeout: None,
            listener: None,
            supported_encodings: None,
            fetch_asset_handler: None,
            runtime: None,
            channel_filter: None,
            qos_classifier: None,
            server_info: None,
            message_backlog_size: None,
        }
    }
}
impl std::fmt::Debug for Gateway {
    /// Formats the builder without exposing secrets or trait objects.
    ///
    /// The device token, listener, handlers, filters and runtime are reported only as
    /// `has_*` presence flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let context_alive = self.context.strong_count() > 0;
        f.debug_struct("Gateway")
            .field("name", &self.name)
            .field("has_device_token", &self.device_token.is_some())
            .field("foxglove_api_url", &self.foxglove_api_url)
            .field("foxglove_api_timeout", &self.foxglove_api_timeout)
            .field("has_listener", &self.listener.is_some())
            .field("capabilities", &self.capabilities)
            .field("supported_encodings", &self.supported_encodings)
            .field("num_services", &self.services.len())
            .field("has_fetch_asset_handler", &self.fetch_asset_handler.is_some())
            .field("has_runtime", &self.runtime.is_some())
            .field("has_channel_filter", &self.channel_filter.is_some())
            .field("has_qos_classifier", &self.qos_classifier.is_some())
            .field("server_info", &self.server_info)
            .field("message_backlog_size", &self.message_backlog_size)
            .field("has_context", &context_alive)
            .finish()
    }
}
impl Gateway {
pub fn new() -> Self {
Self::default()
}
pub fn name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
pub fn listener(mut self, listener: Arc<dyn Listener>) -> Self {
self.listener = Some(listener);
self
}
pub fn capabilities(mut self, capabilities: impl IntoIterator<Item = Capability>) -> Self {
self.capabilities = capabilities.into_iter().collect();
self
}
pub fn supported_encodings(
mut self,
encodings: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.supported_encodings = Some(encodings.into_iter().map(|e| e.into()).collect());
self
}
#[doc(hidden)]
pub fn server_info(mut self, info: HashMap<String, String>) -> Self {
self.server_info = Some(info);
self
}
pub fn context(mut self, ctx: &Arc<Context>) -> Self {
self.context = Arc::downgrade(ctx);
self
}
#[doc(hidden)]
pub fn tokio_runtime(mut self, handle: &tokio::runtime::Handle) -> Self {
self.runtime = Some(handle.clone());
self
}
pub fn channel_filter(mut self, filter: Arc<dyn SinkChannelFilter>) -> Self {
self.channel_filter = Some(filter);
self
}
pub fn device_token(mut self, token: impl Into<String>) -> Self {
self.device_token = Some(token.into());
self
}
pub fn foxglove_api_url(mut self, url: impl Into<String>) -> Self {
self.foxglove_api_url = Some(url.into());
self
}
pub fn foxglove_api_timeout(mut self, timeout: Duration) -> Self {
self.foxglove_api_timeout = Some(timeout);
self
}
pub fn message_backlog_size(mut self, size: usize) -> Self {
self.message_backlog_size = Some(size);
self
}
pub fn channel_filter_fn(
mut self,
filter: impl Fn(&ChannelDescriptor) -> bool + Sync + Send + 'static,
) -> Self {
self.channel_filter = Some(Arc::new(SinkChannelFilterFn(filter)));
self
}
pub fn qos_classifier(mut self, classifier: Arc<dyn QosClassifier>) -> Self {
self.qos_classifier = Some(classifier);
self
}
pub fn qos_classifier_fn(
mut self,
classifier: impl Fn(&ChannelDescriptor) -> QosProfile + Sync + Send + 'static,
) -> Self {
self.qos_classifier = Some(Arc::new(QosClassifierFn(classifier)));
self
}
pub fn services(mut self, services: impl IntoIterator<Item = Service>) -> Self {
self.services.clear();
for service in services {
let name = service.name().to_string();
if let Some(s) = self.services.insert(name, service) {
tracing::warn!("Redefining service {}", s.name());
}
}
self
}
pub fn fetch_asset_handler(mut self, handler: Box<dyn AssetHandler<Client>>) -> Self {
self.fetch_asset_handler = Some(handler);
self
}
pub fn fetch_asset_handler_blocking_fn<F, T, Err>(mut self, handler: F) -> Self
where
F: Fn(Client, String) -> Result<T, Err> + Send + Sync + 'static,
T: AsRef<[u8]>,
Err: Display,
{
self.fetch_asset_handler = Some(Box::new(BlockingAssetHandlerFn(Arc::new(handler))));
self
}
pub fn fetch_asset_handler_async_fn<F, Fut, T, Err>(mut self, handler: F) -> Self
where
F: Fn(Client, String) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<T, Err>> + Send + 'static,
T: AsRef<[u8]>,
Err: Display,
{
self.fetch_asset_handler = Some(Box::new(AsyncAssetHandlerFn(Arc::new(handler))));
self
}
pub fn start(mut self) -> Result<GatewayHandle, FoxgloveError> {
crate::crypto::install_default_crypto_provider();
let device_token = self
.device_token
.or_else(|| std::env::var(FOXGLOVE_DEVICE_TOKEN_ENV).ok())
.ok_or_else(|| {
FoxgloveError::ConfigurationError(format!(
"No device token provided. Set the {FOXGLOVE_DEVICE_TOKEN_ENV} environment variable or call .device_token() on the builder."
))
})?;
let foxglove_api_url = self
.foxglove_api_url
.or_else(|| std::env::var(FOXGLOVE_API_URL_ENV).ok());
let foxglove_api_timeout = self.foxglove_api_timeout.or_else(|| {
std::env::var(FOXGLOVE_API_TIMEOUT_ENV)
.ok()
.and_then(|s| s.parse::<u64>().ok())
.map(Duration::from_secs)
});
if !self.services.is_empty() {
if !self.capabilities.contains(&Capability::Services) {
self.capabilities.push(Capability::Services);
}
let encodings = self
.supported_encodings
.get_or_insert_with(Default::default);
for svc in self.services.values() {
if let Some(encoding) = svc.request_encoding() {
encodings.insert(encoding.to_string());
}
}
if encodings.is_empty() {
if let Some(svc) = self
.services
.values()
.find(|s| s.request_encoding().is_none())
{
return Err(FoxgloveError::MissingRequestEncoding(
svc.name().to_string(),
));
}
}
}
if self.fetch_asset_handler.is_some() && !self.capabilities.contains(&Capability::Assets) {
self.capabilities.push(Capability::Assets);
}
if self.capabilities.contains(&Capability::Assets) && self.fetch_asset_handler.is_none() {
return Err(FoxgloveError::ConfigurationError(
"The Assets capability requires a fetch asset handler. \
Use fetch_asset_handler(), fetch_asset_handler_blocking_fn(), \
or fetch_asset_handler_async_fn()."
.to_string(),
));
}
let runtime = self.runtime.unwrap_or_else(get_runtime_handle);
let services = Arc::new(parking_lot::RwLock::new(ServiceMap::from_iter(
self.services.into_values(),
)));
let params = ConnectionParams {
name: self.name,
device_token,
foxglove_api_url,
foxglove_api_timeout,
listener: self.listener,
capabilities: self.capabilities,
supported_encodings: self.supported_encodings,
fetch_asset_handler: self.fetch_asset_handler.map(Arc::from),
runtime: runtime.clone(),
channel_filter: self.channel_filter,
qos_classifier: self.qos_classifier,
server_info: self.server_info,
message_backlog_size: self.message_backlog_size,
context: self.context,
};
let connection = RemoteAccessConnection::new(params, services);
Ok(GatewayHandle::new(Arc::new(connection), runtime))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FoxgloveError;
    use crate::remote_common::service::{Service, ServiceSchema};

    #[test]
    fn stop_blocking_clean_shutdown() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let runner = rt.spawn(async {});
        let handle = GatewayHandle::with_runner(runner, rt.handle().clone());
        handle.stop_blocking();
    }

    #[test]
    fn stop_blocking_logs_panic() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        // No sleep is needed: `stop_blocking` blocks on the runner until it completes,
        // so the panic is always surfaced as a `JoinError`.
        let runner = rt.spawn(async { panic!("test panic") });
        let handle = GatewayHandle::with_runner(runner, rt.handle().clone());
        handle.stop_blocking();
    }

    #[test]
    fn test_initial_service_missing_request_encoding() {
        let svc =
            Service::builder("/s", ServiceSchema::new("")).handler_fn(|_| Ok::<_, String>(b""));
        let result = Gateway::new()
            .device_token("test-token")
            .services([svc])
            .start();
        assert!(matches!(
            result,
            Err(FoxgloveError::MissingRequestEncoding(_))
        ));
    }

    #[test]
    fn test_assets_capability_without_handler() {
        let result = Gateway::new()
            .device_token("test-token")
            .capabilities([Capability::Assets])
            .start();
        assert!(matches!(result, Err(FoxgloveError::ConfigurationError(_))));
    }

    #[test]
    fn debug_does_not_leak_device_token() {
        let gateway = Gateway::new().device_token("super-secret-token");
        let rendered = format!("{gateway:?}");
        assert!(!rendered.contains("super-secret-token"));
        assert!(rendered.contains("has_device_token: true"));
    }
}