use std::{net::SocketAddr, path::PathBuf, sync::Arc};
use bytes::Bytes;
use futures::stream::FuturesUnordered;
use tokio::{
net::{ToSocketAddrs, lookup_host},
sync::broadcast,
task::JoinSet,
};
use tracing::{debug, trace, warn};
use super::{PubError, PubMessage, PubOptions, SocketState, driver::PubDriver, stats::PubStats};
use crate::{ConnectionHook, ConnectionHookErased};
use msg_transport::{Address, Transport};
use msg_wire::compression::Compressor;
/// A publisher socket that fans messages out to subscriber sessions over a transport `T`
/// addressed by `A`.
///
/// A freshly constructed socket is unbound; `try_bind` (or one of the `bind` helpers) takes
/// the transport and spawns a driver that owns it.
#[derive(Clone)]
pub struct PubSocket<T: Transport<A>, A: Address> {
    // --- configuration & shared state -------------------------------------------------
    /// Socket options; the same `Arc` is handed to the driver on bind.
    options: Arc<PubOptions>,
    /// State shared with the driver (statistics live here).
    state: Arc<SocketState>,
    /// Optional compressor applied to outgoing payloads.
    compressor: Option<Arc<dyn Compressor>>,

    // --- pre-bind resources -----------------------------------------------------------
    /// The transport; present until `try_bind` takes it to hand to the driver.
    transport: Option<T>,
    /// Optional connection hook, moved into the driver by `try_bind`.
    hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,

    // --- post-bind resources ----------------------------------------------------------
    /// Sender half of the broadcast channel whose receiver is owned by the driver;
    /// `None` until the socket has been bound.
    to_sessions_bcast: Option<broadcast::Sender<PubMessage>>,
    /// Address reported by the transport after a successful bind; `None` before that.
    local_addr: Option<A>,
}
impl<T> PubSocket<T, SocketAddr>
where
    T: Transport<SocketAddr>,
{
    /// Resolves `addr` with [`lookup_host`] and binds to the first resolved address that
    /// the transport accepts.
    ///
    /// # Errors
    ///
    /// Propagates resolution failures from [`lookup_host`], and any error returned by
    /// `try_bind` (e.g. when none of the resolved addresses can be bound).
    pub async fn bind(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> {
        let candidates: Vec<SocketAddr> = lookup_host(addr).await?.collect();
        self.try_bind(candidates).await
    }
}
impl<T> PubSocket<T, PathBuf>
where
    T: Transport<PathBuf>,
{
    /// Binds the socket to the single path `path`.
    ///
    /// # Errors
    ///
    /// Returns whatever `try_bind` returns when the path cannot be bound.
    pub async fn bind(&mut self, path: impl Into<PathBuf>) -> Result<(), PubError> {
        let target: PathBuf = path.into();
        self.try_bind(vec![target]).await
    }
}
impl<T, A> PubSocket<T, A>
where
    T: Transport<A>,
    A: Address,
{
    /// Creates a new, unbound publisher socket over `transport` using default [`PubOptions`].
    pub fn new(transport: T) -> Self {
        Self::with_options(transport, PubOptions::default())
    }

    /// Creates a new, unbound publisher socket over `transport` using `options`.
    pub fn with_options(transport: T, options: PubOptions) -> Self {
        Self {
            local_addr: None,
            to_sessions_bcast: None,
            options: Arc::new(options),
            transport: Some(transport),
            state: Arc::new(SocketState::default()),
            hook: None,
            compressor: None,
        }
    }

    /// Sets the compressor used for outgoing payloads larger than
    /// `PubOptions::min_compress_size` bytes.
    pub fn with_compressor<C: Compressor + 'static>(mut self, compressor: C) -> Self {
        self.compressor = Some(Arc::new(compressor));
        self
    }

    /// Installs a connection hook that is handed to the driver when the socket is bound.
    ///
    /// # Panics
    ///
    /// Panics if the socket has already been bound.
    pub fn with_connection_hook<H>(mut self, hook: H) -> Self
    where
        H: ConnectionHook<T::Io>,
    {
        assert!(self.transport.is_some(), "cannot set connection hook after socket has been bound");
        self.hook = Some(Arc::new(hook));
        self
    }

    /// Binds the transport to the first address in `addresses` that accepts a bind, then
    /// spawns the driver task that owns the transport.
    ///
    /// Addresses are tried in order; failures are logged and the next address is tried.
    ///
    /// # Errors
    ///
    /// Returns [`PubError::NoValidEndpoints`] if the transport reports no local address after
    /// all attempts (for example when `addresses` is empty or every bind failed). In that
    /// case the transport is kept in the socket, so the call may be retried.
    ///
    /// # Panics
    ///
    /// Panics if the socket has already been bound successfully.
    pub async fn try_bind(&mut self, addresses: Vec<A>) -> Result<(), PubError> {
        let mut transport = self.transport.take().expect("Transport has been moved already");

        for addr in addresses {
            match transport.bind(addr.clone()).await {
                Ok(_) => break,
                Err(e) => warn!(err = ?e, "Failed to bind to {:?}, trying next address", addr),
            }
        }

        let Some(local_addr) = transport.local_addr() else {
            // Put the transport back: dropping it here would make every later `try_bind`
            // panic and make `with_connection_hook` claim the socket is already bound.
            self.transport = Some(transport);
            return Err(PubError::NoValidEndpoints);
        };

        debug!("Listening on {:?}", local_addr);

        // Only create the fan-out channel once we know the driver will actually run.
        let (to_sessions_bcast, from_socket_bcast) =
            broadcast::channel(self.options.high_water_mark);

        let backend = PubDriver {
            id_counter: 0,
            transport,
            options: Arc::clone(&self.options),
            state: Arc::clone(&self.state),
            hook: self.hook.take(),
            hook_tasks: JoinSet::new(),
            conn_tasks: FuturesUnordered::new(),
            from_socket_bcast,
        };

        tokio::spawn(backend);

        self.local_addr = Some(local_addr);
        self.to_sessions_bcast = Some(to_sessions_bcast);

        Ok(())
    }

    /// Publishes `message` on `topic` to the driver's subscriber sessions.
    ///
    /// This method never awaits; it is equivalent to [`Self::try_publish`] and is kept
    /// `async` for API compatibility.
    ///
    /// # Errors
    ///
    /// See [`Self::try_publish`].
    pub async fn publish(&self, topic: impl Into<String>, message: Bytes) -> Result<(), PubError> {
        self.try_publish(topic.into(), message)
    }

    /// Publishes `message` on `topic` without awaiting.
    ///
    /// If a compressor is configured and the payload is larger than
    /// `PubOptions::min_compress_size`, the payload is compressed first. Having no active
    /// receiver on the broadcast channel is not an error; it is only logged.
    ///
    /// # Errors
    ///
    /// Returns [`PubError::SocketClosed`] if the socket has not been bound, and propagates
    /// any compression error.
    pub fn try_publish(&self, topic: String, message: Bytes) -> Result<(), PubError> {
        // Check the socket is bound before doing any (potentially expensive) compression.
        let sender = self.to_sessions_bcast.as_ref().ok_or(PubError::SocketClosed)?;

        let mut msg = PubMessage::new(topic, message);

        // Honour the size threshold in both publish paths so small payloads are sent as-is.
        let len_before = msg.payload().len();
        if len_before > self.options.min_compress_size &&
            let Some(ref compressor) = self.compressor
        {
            msg.compress(compressor.as_ref())?;
            trace!("Compressed message from {} to {} bytes", len_before, msg.payload().len());
        }

        if sender.send(msg).is_err() {
            debug!("No active subscriber sessions");
        }

        Ok(())
    }

    /// Returns the publisher statistics kept in the state shared with the driver.
    pub fn stats(&self) -> &PubStats {
        &self.state.stats.specific
    }

    /// Returns the address the socket is bound to, or `None` if it has not been bound yet.
    pub fn local_addr(&self) -> Option<&A> {
        self.local_addr.as_ref()
    }
}