use std::{
collections::HashSet,
net::SocketAddr,
path::PathBuf,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::Stream;
use rustc_hash::FxHashMap;
use tokio::{
net::{ToSocketAddrs, lookup_host},
sync::mpsc,
task::JoinSet,
};
use msg_common::{IpAddrExt, JoinMap};
use msg_transport::{Address, Transport};
use crate::{
ConnectionHook, ConnectionHookErased,
sub::{
Command, DEFAULT_BUFFER_SIZE, PubMessage, SocketState, SubDriver, SubError, SubOptions,
stats::SubStats,
},
};
/// A subscriber socket.
///
/// Operations (connect, disconnect, subscribe, unsubscribe, shutdown) are sent as
/// [`Command`]s over an mpsc channel to a background [`SubDriver`]. Messages the driver
/// forwards arrive on a second channel and are surfaced through this type's `Stream` impl.
pub struct SubSocket<T: Transport<A>, A: Address> {
    /// Command channel to the driver.
    to_driver: mpsc::Sender<Command<A>>,
    /// Messages forwarded by the driver.
    from_driver: mpsc::Receiver<PubMessage<A>>,
    /// Socket options; the driver holds another `Arc` to the same value.
    // NOTE(review): no visible method reads this field, hence the `allow`; presumably kept
    // for future use or introspection — confirm before removing.
    #[allow(unused)]
    options: Arc<SubOptions>,
    /// The driver before it is started: `Some` until the first operation spawns it onto the
    /// Tokio runtime, `None` afterwards.
    driver: Option<SubDriver<T, A>>,
    /// State shared with the driver; statistics are read from here.
    state: Arc<SocketState<A>>,
    /// Optional connection hook; the same `Arc` is also installed on the driver.
    // NOTE(review): not read by any visible code on the socket side — confirm whether it is
    // needed beyond keeping a handle alongside the driver's copy.
    hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
    // NOTE(review): `T` is already used by `driver` and `hook`, so this marker looks redundant.
    _marker: std::marker::PhantomData<T>,
}
impl<T> SubSocket<T, SocketAddr>
where
    T: Transport<SocketAddr> + Send + Sync + Unpin + 'static,
{
    /// Resolves `endpoint` and connects to the first resolved address.
    ///
    /// An unspecified IP (`0.0.0.0` / `::`) is rewritten to localhost before connecting.
    ///
    /// # Errors
    ///
    /// Returns an error if resolution fails, [`SubError::NoValidEndpoints`] if nothing
    /// resolved, or [`SubError::SocketClosed`] if the driver has shut down.
    pub async fn connect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> {
        let endpoint = Self::resolve_endpoint(endpoint).await?;
        self.connect_inner(Self::localize_endpoint(endpoint)).await
    }

    /// Parses `endpoint` as a literal socket address and connects without waiting for room
    /// in the command channel.
    ///
    /// An unspecified IP (`0.0.0.0` / `::`) is rewritten to localhost before connecting.
    ///
    /// # Errors
    ///
    /// [`SubError::NoValidEndpoints`] if `endpoint` does not parse, [`SubError::ChannelFull`]
    /// if the command channel is full, or [`SubError::SocketClosed`] if the driver has shut
    /// down.
    pub fn try_connect(&mut self, endpoint: impl Into<String>) -> Result<(), SubError> {
        let endpoint = Self::parse_endpoint(endpoint)?;
        self.try_connect_inner(Self::localize_endpoint(endpoint))
    }

    /// Resolves `endpoint` and disconnects from the first resolved address.
    ///
    /// Applies the same unspecified-IP-to-localhost rewrite as [`Self::connect`]. Without
    /// it, `disconnect("0.0.0.0:p")` would name a different address than the
    /// matching `connect` call.
    ///
    /// # Errors
    ///
    /// Same as [`Self::connect`].
    pub async fn disconnect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> {
        let endpoint = Self::resolve_endpoint(endpoint).await?;
        self.disconnect_inner(Self::localize_endpoint(endpoint)).await
    }

    /// Parses `endpoint` as a literal socket address and disconnects without waiting for
    /// room in the command channel.
    ///
    /// Applies the same unspecified-IP-to-localhost rewrite as [`Self::try_connect`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::try_connect`].
    pub fn try_disconnect(&mut self, endpoint: impl Into<String>) -> Result<(), SubError> {
        let endpoint = Self::parse_endpoint(endpoint)?;
        self.try_disconnect_inner(Self::localize_endpoint(endpoint))
    }

    /// Resolves `endpoint` via DNS and returns the first address.
    async fn resolve_endpoint(endpoint: impl ToSocketAddrs) -> Result<SocketAddr, SubError> {
        lookup_host(endpoint).await?.next().ok_or(SubError::NoValidEndpoints)
    }

    /// Parses a literal `ip:port` string; any parse failure maps to `NoValidEndpoints`.
    fn parse_endpoint(endpoint: impl Into<String>) -> Result<SocketAddr, SubError> {
        let raw: String = endpoint.into();
        raw.parse::<SocketAddr>().map_err(|_| SubError::NoValidEndpoints)
    }

    /// Rewrites an unspecified IP (`0.0.0.0` / `::`) to the matching localhost address.
    fn localize_endpoint(mut endpoint: SocketAddr) -> SocketAddr {
        if endpoint.ip().is_unspecified() {
            endpoint.set_ip(endpoint.ip().as_localhost());
        }
        endpoint
    }
}
impl<T> SubSocket<T, PathBuf>
where
    T: Transport<PathBuf> + Send + Sync + Unpin + 'static,
{
    /// Connects to the endpoint at `path`, waiting for room in the command channel.
    ///
    /// # Errors
    ///
    /// [`SubError::SocketClosed`] if the driver has shut down.
    pub async fn connect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
        let target: PathBuf = path.into();
        self.connect_inner(target).await
    }

    /// Connects to the endpoint at `path` without waiting for room in the command channel.
    ///
    /// # Errors
    ///
    /// [`SubError::ChannelFull`] or [`SubError::SocketClosed`].
    pub fn try_connect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
        let target: PathBuf = path.into();
        self.try_connect_inner(target)
    }

    /// Disconnects from the endpoint at `path`, waiting for room in the command channel.
    ///
    /// # Errors
    ///
    /// [`SubError::SocketClosed`] if the driver has shut down.
    pub async fn disconnect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
        let target: PathBuf = path.into();
        self.disconnect_inner(target).await
    }

    /// Disconnects from the endpoint at `path` without waiting for room in the command channel.
    ///
    /// # Errors
    ///
    /// [`SubError::ChannelFull`] or [`SubError::SocketClosed`].
    pub fn try_disconnect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
        let target: PathBuf = path.into();
        self.try_disconnect_inner(target)
    }
}
impl<T, A> SubSocket<T, A>
where
    T: Transport<A> + Send + Sync + Unpin + 'static,
    A: Address,
{
    /// Creates a subscriber socket over `transport` using [`SubOptions::default`].
    pub fn new(transport: T) -> Self {
        Self::with_options(transport, SubOptions::default())
    }

    /// Creates a subscriber socket over `transport` with the given `options`.
    ///
    /// The [`SubDriver`] is built here but not spawned. The first connect, disconnect,
    /// subscribe or unsubscribe call that passes validation spawns it. Until then the
    /// socket can still be configured, e.g. with [`Self::with_connection_hook`].
    pub fn with_options(transport: T, options: SubOptions) -> Self {
        // socket -> driver: commands, fixed capacity.
        let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE);
        // driver -> socket: forwarded messages, capacity from `ingress_queue_size`.
        let (to_socket, from_driver) = mpsc::channel(options.ingress_queue_size);

        let options = Arc::new(options);
        let state = Arc::new(SocketState::default());

        let mut publishers = FxHashMap::default();
        publishers.reserve(32);

        let driver = SubDriver {
            options: Arc::clone(&options),
            transport,
            from_socket,
            to_socket,
            conn_tasks: JoinMap::new(),
            hook_tasks: JoinSet::new(),
            subscribed_topics: HashSet::with_capacity(32),
            publishers,
            state: Arc::clone(&state),
            hook: None,
        };

        Self {
            to_driver,
            from_driver,
            driver: Some(driver),
            options,
            state,
            hook: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Installs `hook` as the connection hook. The same `Arc` is stored on both the socket
    /// and the not-yet-started driver.
    ///
    /// # Panics
    ///
    /// Panics if the driver has already been spawned by an earlier connect, disconnect,
    /// subscribe or unsubscribe call.
    pub fn with_connection_hook<H>(mut self, hook: H) -> Self
    where
        H: ConnectionHook<T::Io>,
    {
        let hook_arc: Arc<dyn ConnectionHookErased<T::Io>> = Arc::new(hook);
        let driver =
            self.driver.as_mut().expect("cannot set connection hook after driver has started");
        driver.hook = Some(hook_arc.clone());
        self.hook = Some(hook_arc);
        self
    }

    /// Sends `Connect { endpoint }` to the driver, spawning it first if needed, and waits
    /// for room in the command channel.
    ///
    /// # Errors
    ///
    /// [`SubError::SocketClosed`] if the driver has shut down.
    pub async fn connect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
        self.ensure_active_driver();
        self.send_command(Command::Connect { endpoint }).await
    }

    /// Non-waiting variant of [`Self::connect_inner`].
    ///
    /// # Errors
    ///
    /// [`SubError::ChannelFull`] or [`SubError::SocketClosed`].
    pub fn try_connect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
        self.ensure_active_driver();
        self.try_send_command(Command::Connect { endpoint })
    }

    /// Sends `Disconnect { endpoint }` to the driver, spawning it first if needed, and waits
    /// for room in the command channel.
    ///
    /// # Errors
    ///
    /// [`SubError::SocketClosed`] if the driver has shut down.
    pub async fn disconnect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
        self.ensure_active_driver();
        self.send_command(Command::Disconnect { endpoint }).await
    }

    /// Non-waiting variant of [`Self::disconnect_inner`].
    ///
    /// # Errors
    ///
    /// [`SubError::ChannelFull`] or [`SubError::SocketClosed`].
    pub fn try_disconnect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
        self.ensure_active_driver();
        self.try_send_command(Command::Disconnect { endpoint })
    }

    /// Subscribes to `topic`, waiting for room in the command channel.
    ///
    /// The topic is validated before the driver is spawned. A rejected topic therefore has
    /// no side effects, and the socket stays configurable.
    ///
    /// # Errors
    ///
    /// [`SubError::ReservedTopic`] if `topic` starts with `"MSG"`, or
    /// [`SubError::SocketClosed`] if the driver has shut down.
    pub async fn subscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
        let topic = topic.into();
        Self::ensure_unreserved_topic(&topic)?;
        self.ensure_active_driver();
        self.send_command(Command::Subscribe { topic }).await
    }

    /// Non-waiting variant of [`Self::subscribe`].
    ///
    /// # Errors
    ///
    /// [`SubError::ReservedTopic`], [`SubError::ChannelFull`] or [`SubError::SocketClosed`].
    pub fn try_subscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
        let topic = topic.into();
        Self::ensure_unreserved_topic(&topic)?;
        self.ensure_active_driver();
        self.try_send_command(Command::Subscribe { topic })
    }

    /// Unsubscribes from `topic`, waiting for room in the command channel.
    ///
    /// # Errors
    ///
    /// [`SubError::ReservedTopic`] if `topic` starts with `"MSG"`, or
    /// [`SubError::SocketClosed`] if the driver has shut down.
    pub async fn unsubscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
        let topic = topic.into();
        Self::ensure_unreserved_topic(&topic)?;
        self.ensure_active_driver();
        self.send_command(Command::Unsubscribe { topic }).await
    }

    /// Non-waiting variant of [`Self::unsubscribe`].
    ///
    /// # Errors
    ///
    /// [`SubError::ReservedTopic`], [`SubError::ChannelFull`] or [`SubError::SocketClosed`].
    pub fn try_unsubscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
        let topic = topic.into();
        Self::ensure_unreserved_topic(&topic)?;
        self.ensure_active_driver();
        self.try_send_command(Command::Unsubscribe { topic })
    }

    /// Rejects topics that use the reserved `"MSG"` prefix.
    fn ensure_unreserved_topic(topic: &str) -> Result<(), SubError> {
        if topic.starts_with("MSG") {
            return Err(SubError::ReservedTopic);
        }
        Ok(())
    }

    /// Sends `command`, waiting for channel capacity; a closed channel maps to `SocketClosed`.
    async fn send_command(&self, command: Command<A>) -> Result<(), SubError> {
        self.to_driver.send(command).await.map_err(|_| SubError::SocketClosed)
    }

    /// Sends `command` without waiting; maps a full channel to `ChannelFull` and a closed one
    /// to `SocketClosed`.
    fn try_send_command(&self, command: Command<A>) -> Result<(), SubError> {
        use mpsc::error::TrySendError::*;
        self.to_driver.try_send(command).map_err(|e| match e {
            Full(_) => SubError::ChannelFull,
            Closed(_) => SubError::SocketClosed,
        })
    }

    /// Spawns the driver on first use; later calls do nothing.
    ///
    /// `tokio::spawn` panics when called outside a Tokio runtime, so the first operation
    /// that reaches this must run inside one.
    fn ensure_active_driver(&mut self) {
        if let Some(driver) = self.driver.take() {
            tokio::spawn(driver);
        }
    }

    /// Returns the subscriber-specific statistics held in the shared socket state.
    pub fn stats(&self) -> &SubStats<A> {
        &self.state.stats.specific
    }
}
impl<T: Transport<A>, A: Address> Drop for SubSocket<T, A> {
    /// Asks the driver to shut down without blocking.
    ///
    /// Best-effort only: if the command channel is full or already closed, the shutdown
    /// command is discarded.
    fn drop(&mut self) {
        let shutdown = Command::Shutdown;
        let _ = self.to_driver.try_send(shutdown);
    }
}
impl<T: Transport<A> + Unpin, A: Address> Stream for SubSocket<T, A> {
    type Item = PubMessage<A>;

    /// Yields the next message forwarded by the driver. It returns `None` once every sender
    /// has been dropped and the buffer is drained, which is the semantics of `Receiver::poll_recv`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Unpinning needs `Self: Unpin`, the same requirement a mutable deref of the `Pin` has.
        let socket = Pin::into_inner(self);
        socket.from_driver.poll_recv(cx)
    }
}