1use std::{net::SocketAddr, path::PathBuf, sync::Arc};
2
3use bytes::Bytes;
4use futures::stream::FuturesUnordered;
5use tokio::{
6 net::{ToSocketAddrs, lookup_host},
7 sync::broadcast,
8 task::JoinSet,
9};
10use tracing::{debug, trace, warn};
11
12use super::{PubError, PubMessage, PubOptions, SocketState, driver::PubDriver, stats::PubStats};
13use crate::{ConnectionHook, ConnectionHookErased};
14
15use msg_transport::{Address, Transport};
16use msg_wire::compression::Compressor;
17
/// A publisher socket that fans published messages out to subscriber sessions.
///
/// The socket starts unbound and owns its `transport`. Binding (`try_bind`) moves the
/// transport into a spawned `PubDriver` task. The socket keeps the broadcast sender used by
/// `publish` / `try_publish`.
///
/// `Clone` is derived, so cloning requires `T: Clone` and `A: Clone`. Clones share the same
/// `Arc`ed options/state and, once bound, the same broadcast sender.
#[derive(Clone)]
pub struct PubSocket<T: Transport<A>, A: Address> {
    /// Socket options; an `Arc` clone is handed to the driver on bind.
    options: Arc<PubOptions>,
    /// Shared socket state (read by `stats()`); an `Arc` clone is handed to the driver on bind.
    state: Arc<SocketState>,
    /// The transport: `Some` until a successful bind, after which it is owned by the driver.
    transport: Option<T>,
    /// Sending half of the broadcast channel that published messages go out on. The
    /// receiving half is given to the driver. `None` until bound.
    to_sessions_bcast: Option<broadcast::Sender<PubMessage>>,
    /// Optional connection hook; taken and handed to the driver on bind.
    hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
    /// Optional compressor applied to outgoing payloads by the publish methods.
    compressor: Option<Arc<dyn Compressor>>,
    /// Address reported by the transport after binding; `None` until bound.
    local_addr: Option<A>,
}
47
48impl<T> PubSocket<T, SocketAddr>
49where
50 T: Transport<SocketAddr>,
51{
52 pub async fn bind(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> {
57 let addrs = lookup_host(addr).await?;
58 self.try_bind(addrs.collect()).await
59 }
60}
61
62impl<T> PubSocket<T, PathBuf>
63where
64 T: Transport<PathBuf>,
65{
66 pub async fn bind(&mut self, path: impl Into<PathBuf>) -> Result<(), PubError> {
71 self.try_bind(vec![path.into()]).await
72 }
73}
74
75impl<T, A> PubSocket<T, A>
76where
77 T: Transport<A>,
78 A: Address,
79{
80 pub fn new(transport: T) -> Self {
82 Self::with_options(transport, PubOptions::default())
83 }
84
85 pub fn with_options(transport: T, options: PubOptions) -> Self {
87 Self {
88 local_addr: None,
89 to_sessions_bcast: None,
90 options: Arc::new(options),
91 transport: Some(transport),
92 state: Arc::new(SocketState::default()),
93 hook: None,
94 compressor: None,
95 }
96 }
97
98 pub fn with_compressor<C: Compressor + 'static>(mut self, compressor: C) -> Self {
100 self.compressor = Some(Arc::new(compressor));
101 self
102 }
103
104 pub fn with_connection_hook<H>(mut self, hook: H) -> Self
113 where
114 H: ConnectionHook<T::Io>,
115 {
116 assert!(self.transport.is_some(), "cannot set connection hook after socket has been bound");
117 self.hook = Some(Arc::new(hook));
118 self
119 }
120
121 pub async fn try_bind(&mut self, addresses: Vec<A>) -> Result<(), PubError> {
125 let (to_sessions_bcast, from_socket_bcast) =
126 broadcast::channel(self.options.high_water_mark);
127
128 let mut transport = self.transport.take().expect("Transport has been moved already");
129
130 for addr in addresses {
131 match transport.bind(addr.clone()).await {
132 Ok(_) => break,
133 Err(e) => {
134 warn!(err = ?e, "Failed to bind to {:?}, trying next address", addr);
135 continue;
136 }
137 }
138 }
139
140 let Some(local_addr) = transport.local_addr() else {
141 return Err(PubError::NoValidEndpoints);
142 };
143
144 debug!("Listening on {:?}", local_addr);
145
146 let backend = PubDriver {
147 id_counter: 0,
148 transport,
149 options: Arc::clone(&self.options),
150 state: Arc::clone(&self.state),
151 hook: self.hook.take(),
152 hook_tasks: JoinSet::new(),
153 conn_tasks: FuturesUnordered::new(),
154 from_socket_bcast,
155 };
156
157 tokio::spawn(backend);
158
159 self.local_addr = Some(local_addr);
160 self.to_sessions_bcast = Some(to_sessions_bcast);
161
162 Ok(())
163 }
164
165 pub async fn publish(&self, topic: impl Into<String>, message: Bytes) -> Result<(), PubError> {
167 let mut msg = PubMessage::new(topic.into(), message);
168
169 let len_before = msg.payload().len();
173 if len_before > self.options.min_compress_size &&
174 let Some(ref compressor) = self.compressor
175 {
176 msg.compress(compressor.as_ref())?;
177 trace!("Compressed message from {} to {} bytes", len_before, msg.payload().len());
178 }
179
180 if self.to_sessions_bcast.as_ref().ok_or(PubError::SocketClosed)?.send(msg).is_err() {
182 debug!("No active subscriber sessions");
183 }
184
185 Ok(())
186 }
187
188 pub fn try_publish(&self, topic: String, message: Bytes) -> Result<(), PubError> {
191 let mut msg = PubMessage::new(topic, message);
192
193 if let Some(ref compressor) = self.compressor {
195 let len_before = msg.payload().len();
196
197 msg.compress(compressor.as_ref())?;
199
200 debug!("Compressed message from {} to {} bytes", len_before, msg.payload().len(),);
201 }
202
203 if self.to_sessions_bcast.as_ref().ok_or(PubError::SocketClosed)?.send(msg).is_err() {
205 debug!("No active subscriber sessions");
206 }
207
208 Ok(())
209 }
210
211 pub fn stats(&self) -> &PubStats {
212 &self.state.stats.specific
213 }
214
215 pub fn local_addr(&self) -> Option<&A> {
217 self.local_addr.as_ref()
218 }
219}