1use std::{
2 collections::HashSet,
3 net::SocketAddr,
4 path::PathBuf,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use futures::Stream;
11use rustc_hash::FxHashMap;
12use tokio::{
13 net::{ToSocketAddrs, lookup_host},
14 sync::mpsc,
15 task::JoinSet,
16};
17
18use msg_common::{IpAddrExt, JoinMap};
19use msg_transport::{Address, Transport};
20
21use crate::{
22 ConnectionHook, ConnectionHookErased,
23 sub::{
24 Command, DEFAULT_BUFFER_SIZE, PubMessage, SocketState, SubDriver, SubError, SubOptions,
25 stats::SubStats,
26 },
27};
28
29pub struct SubSocket<T: Transport<A>, A: Address> {
31 to_driver: mpsc::Sender<Command<A>>,
33 from_driver: mpsc::Receiver<PubMessage<A>>,
35 #[allow(unused)]
37 options: Arc<SubOptions>,
38 driver: Option<SubDriver<T, A>>,
40 state: Arc<SocketState<A>>,
42 hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
44 _marker: std::marker::PhantomData<T>,
46}
47
48impl<T> SubSocket<T, SocketAddr>
49where
50 T: Transport<SocketAddr> + Send + Sync + Unpin + 'static,
51{
52 pub async fn connect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> {
54 let mut addrs = lookup_host(endpoint).await?;
55 let mut endpoint = addrs.next().ok_or(SubError::NoValidEndpoints)?;
56
57 if endpoint.ip().is_unspecified() {
60 endpoint.set_ip(endpoint.ip().as_localhost());
61 }
62
63 self.connect_inner(endpoint).await
64 }
65
66 pub fn try_connect(&mut self, endpoint: impl Into<String>) -> Result<(), SubError> {
68 let addr = endpoint.into();
69 let mut endpoint: SocketAddr = addr.parse().map_err(|_| SubError::NoValidEndpoints)?;
70
71 if endpoint.ip().is_unspecified() {
74 endpoint.set_ip(endpoint.ip().as_localhost());
75 }
76
77 self.try_connect_inner(endpoint)
78 }
79
80 pub async fn disconnect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> {
82 let mut addrs = lookup_host(endpoint).await?;
83 let endpoint = addrs.next().ok_or(SubError::NoValidEndpoints)?;
84
85 self.disconnect_inner(endpoint).await
86 }
87
88 pub fn try_disconnect(&mut self, endpoint: impl Into<String>) -> Result<(), SubError> {
90 let endpoint = endpoint.into();
91 let endpoint: SocketAddr = endpoint.parse().map_err(|_| SubError::NoValidEndpoints)?;
92
93 self.try_disconnect_inner(endpoint)
94 }
95}
96
97impl<T> SubSocket<T, PathBuf>
98where
99 T: Transport<PathBuf> + Send + Sync + Unpin + 'static,
100{
101 pub async fn connect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
103 self.connect_inner(path.into()).await
104 }
105
106 pub fn try_connect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
108 self.try_connect_inner(path.into())
109 }
110
111 pub async fn disconnect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
113 self.disconnect_inner(path.into()).await
114 }
115
116 pub fn try_disconnect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
118 self.try_disconnect_inner(path.into())
119 }
120}
121
122impl<T, A> SubSocket<T, A>
123where
124 T: Transport<A> + Send + Sync + Unpin + 'static,
125 A: Address,
126{
127 pub fn new(transport: T) -> Self {
129 Self::with_options(transport, SubOptions::default())
130 }
131
132 pub fn with_options(transport: T, options: SubOptions) -> Self {
134 let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE);
135 let (to_socket, from_driver) = mpsc::channel(options.ingress_queue_size);
136
137 let options = Arc::new(options);
138
139 let state = Arc::new(SocketState::default());
140
141 let mut publishers = FxHashMap::default();
142 publishers.reserve(32);
143
144 let driver = SubDriver {
145 options: Arc::clone(&options),
146 transport,
147 from_socket,
148 to_socket,
149 conn_tasks: JoinMap::new(),
150 hook_tasks: JoinSet::new(),
151 subscribed_topics: HashSet::with_capacity(32),
152 publishers,
153 state: Arc::clone(&state),
154 hook: None,
155 };
156
157 Self {
158 to_driver,
159 from_driver,
160 driver: Some(driver),
161 options,
162 state,
163 hook: None,
164 _marker: std::marker::PhantomData,
165 }
166 }
167
168 pub fn with_connection_hook<H>(mut self, hook: H) -> Self
177 where
178 H: ConnectionHook<T::Io>,
179 {
180 let hook_arc: Arc<dyn ConnectionHookErased<T::Io>> = Arc::new(hook);
181
182 let driver =
184 self.driver.as_mut().expect("cannot set connection hook after driver has started");
185 driver.hook = Some(hook_arc.clone());
186
187 self.hook = Some(hook_arc);
188 self
189 }
190
191 pub async fn connect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
193 self.ensure_active_driver();
194 self.send_command(Command::Connect { endpoint }).await?;
195 Ok(())
196 }
197
198 pub fn try_connect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
200 self.ensure_active_driver();
201 self.try_send_command(Command::Connect { endpoint })?;
202 Ok(())
203 }
204
205 pub async fn disconnect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
207 self.ensure_active_driver();
208 self.send_command(Command::Disconnect { endpoint }).await?;
209 Ok(())
210 }
211
212 pub fn try_disconnect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
214 self.ensure_active_driver();
215 self.try_send_command(Command::Disconnect { endpoint })?;
216 Ok(())
217 }
218
219 pub async fn subscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
223 self.ensure_active_driver();
224
225 let topic = topic.into();
226 if topic.starts_with("MSG") {
227 return Err(SubError::ReservedTopic);
228 }
229
230 self.send_command(Command::Subscribe { topic }).await?;
231
232 Ok(())
233 }
234
235 pub fn try_subscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
237 self.ensure_active_driver();
238
239 let topic = topic.into();
240 if topic.starts_with("MSG") {
241 return Err(SubError::ReservedTopic);
242 }
243
244 self.try_send_command(Command::Subscribe { topic })?;
245
246 Ok(())
247 }
248
249 pub async fn unsubscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
251 self.ensure_active_driver();
252
253 let topic = topic.into();
254 if topic.starts_with("MSG") {
255 return Err(SubError::ReservedTopic);
256 }
257
258 self.send_command(Command::Unsubscribe { topic }).await?;
259
260 Ok(())
261 }
262
263 pub fn try_unsubscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
265 self.ensure_active_driver();
266
267 let topic = topic.into();
268 if topic.starts_with("MSG") {
269 return Err(SubError::ReservedTopic);
270 }
271
272 self.try_send_command(Command::Unsubscribe { topic })?;
273
274 Ok(())
275 }
276
277 async fn send_command(&self, command: Command<A>) -> Result<(), SubError> {
280 self.to_driver.send(command).await.map_err(|_| SubError::SocketClosed)?;
281
282 Ok(())
283 }
284
285 fn try_send_command(&self, command: Command<A>) -> Result<(), SubError> {
286 use mpsc::error::TrySendError::*;
287 self.to_driver.try_send(command).map_err(|e| match e {
288 Full(_) => SubError::ChannelFull,
289 Closed(_) => SubError::SocketClosed,
290 })?;
291 Ok(())
292 }
293
294 fn ensure_active_driver(&mut self) {
297 if let Some(driver) = self.driver.take() {
298 tokio::spawn(driver);
299 }
300 }
301
302 pub fn stats(&self) -> &SubStats<A> {
304 &self.state.stats.specific
305 }
306}
307
308impl<T: Transport<A>, A: Address> Drop for SubSocket<T, A> {
309 fn drop(&mut self) {
310 let _ = self.to_driver.try_send(Command::Shutdown);
312 }
313}
314
315impl<T: Transport<A> + Unpin, A: Address> Stream for SubSocket<T, A> {
316 type Item = PubMessage<A>;
317
318 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
319 self.from_driver.poll_recv(cx)
320 }
321}