1use std::{
2 net::SocketAddr,
3 path::PathBuf,
4 pin::Pin,
5 sync::Arc,
6 task::{Context, Poll},
7};
8
9use futures::{Stream, stream::FuturesUnordered};
10use tokio::{
11 net::{ToSocketAddrs, lookup_host},
12 sync::mpsc,
13 task::{JoinHandle, JoinSet},
14};
15use tokio_stream::StreamMap;
16use tracing::{debug, warn};
17
18use crate::{
19 ConnectionHook, ConnectionHookErased, DEFAULT_QUEUE_SIZE, RepOptions, Request,
20 rep::{RepError, SocketState, driver::RepDriver},
21};
22
23use msg_transport::{Address, Transport};
24use msg_wire::compression::Compressor;
25
26use super::stats::RepStats;
27
28pub struct RepSocket<T: Transport<A>, A: Address> {
30 options: Arc<RepOptions>,
32 state: Arc<SocketState>,
34 from_driver: Option<mpsc::Receiver<Request<A>>>,
36 transport: Option<T>,
39 hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
41 local_addr: Option<A>,
43 compressor: Option<Arc<dyn Compressor>>,
45 control_tx: Option<mpsc::Sender<T::Control>>,
47
48 _driver_task: Option<JoinHandle<Result<(), RepError>>>,
50}
51
52impl<T> RepSocket<T, SocketAddr>
53where
54 T: Transport<SocketAddr>,
55{
56 pub async fn bind(&mut self, addr: impl ToSocketAddrs) -> Result<(), RepError> {
58 let addrs = lookup_host(addr).await?;
59 self.try_bind(addrs.collect()).await
60 }
61}
62
63impl<T> RepSocket<T, PathBuf>
64where
65 T: Transport<PathBuf>,
66{
67 pub async fn bind(&mut self, path: impl Into<PathBuf>) -> Result<(), RepError> {
69 let addr = path.into().clone();
70 self.try_bind(vec![addr]).await
71 }
72}
73
74impl<T, A> RepSocket<T, A>
75where
76 T: Transport<A>,
77 A: Address,
78{
79 pub fn new(transport: T) -> Self {
81 Self::with_options(transport, RepOptions::balanced())
82 }
83
84 pub fn with_options(transport: T, options: RepOptions) -> Self {
86 Self {
87 from_driver: None,
88 local_addr: None,
89 transport: Some(transport),
90 options: Arc::new(options),
91 state: Arc::new(SocketState::default()),
92 hook: None,
93 compressor: None,
94 control_tx: None,
95 _driver_task: None,
96 }
97 }
98
99 pub fn with_compressor<C: Compressor + 'static>(mut self, compressor: C) -> Self {
101 self.compressor = Some(Arc::new(compressor));
102 self
103 }
104
105 pub fn with_connection_hook<H>(mut self, hook: H) -> Self
114 where
115 H: ConnectionHook<T::Io>,
116 {
117 assert!(self.transport.is_some(), "cannot set connection hook after socket has been bound");
118 self.hook = Some(Arc::new(hook));
119 self
120 }
121
122 pub async fn try_bind(&mut self, addresses: Vec<A>) -> Result<(), RepError> {
124 let (to_socket, from_backend) = mpsc::channel(DEFAULT_QUEUE_SIZE);
125 let (control_tx, control_rx) = mpsc::channel(DEFAULT_QUEUE_SIZE);
126
127 let mut transport = self.transport.take().expect("transport has been moved already");
128
129 for addr in addresses {
130 match transport.bind(addr.clone()).await {
131 Ok(_) => break,
132 Err(e) => {
133 warn!(?e, ?addr, "failed to bind");
134 continue;
135 }
136 }
137 }
138
139 let Some(local_addr) = transport.local_addr() else {
140 return Err(RepError::NoValidEndpoints);
141 };
142
143 let span = tracing::info_span!(parent: None, "rep_driver", ?local_addr);
144
145 span.in_scope(|| {
146 debug!("listening");
147 });
148
149 let backend = RepDriver {
150 transport,
151 options: Arc::clone(&self.options),
152 state: Arc::clone(&self.state),
153 peer_states: StreamMap::with_capacity(self.options.max_clients.unwrap_or(64)),
154 to_socket,
155 hook: self.hook.take(),
156 hook_tasks: JoinSet::new(),
157 compressor: self.compressor.take(),
158 conn_tasks: FuturesUnordered::new(),
159 control_rx,
160 span,
161 };
162
163 self._driver_task = Some(tokio::spawn(backend));
164 self.local_addr = Some(local_addr);
165 self.from_driver = Some(from_backend);
166 self.control_tx = Some(control_tx);
167
168 Ok(())
169 }
170
171 pub fn stats(&self) -> &RepStats {
173 &self.state.stats.specific
174 }
175
176 pub fn local_addr(&self) -> Option<&A> {
178 self.local_addr.as_ref()
179 }
180
181 pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Option<Request<A>>> {
183 Pin::new(self).poll_next(cx)
184 }
185
186 pub async fn control(
188 &mut self,
189 control: T::Control,
190 ) -> Result<(), mpsc::error::SendError<T::Control>> {
191 let Some(tx) = self.control_tx.as_mut() else {
192 tracing::warn!("calling control on a non-bound socket, this is a no-op");
193 return Ok(());
194 };
195 tx.send(control).await
196 }
197}
198
199impl<T, A> Stream for RepSocket<T, A>
200where
201 T: Transport<A>,
202 A: Address,
203{
204 type Item = Request<A>;
205
206 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207 self.get_mut().from_driver.as_mut().expect("Inactive socket").poll_recv(cx)
208 }
209}