1use crate::codec::*;
2use crate::endpoint::Endpoint;
3use crate::error::ZmqResult;
4use crate::message::*;
5use crate::transport::AcceptStopHandle;
6use crate::util::PeerIdentity;
7use crate::{async_rt, CaptureSocket, SocketOptions};
8use crate::{
9 MultiPeerBackend, Socket, SocketBackend, SocketEvent, SocketSend, SocketType, ZmqError,
10};
11
12use async_trait::async_trait;
13use dashmap::DashMap;
14use futures_channel::{mpsc, oneshot};
15use futures_util::{select, FutureExt, StreamExt};
16use parking_lot::Mutex;
17
18use std::collections::HashMap;
19use std::io::ErrorKind;
20use std::pin::Pin;
21use std::sync::Arc;
22
/// Per-peer state kept by a PUB socket for one connected subscriber.
pub(crate) struct Subscriber {
    /// Topic prefixes this peer subscribed to; `send` delivers a message when its first
    /// frame starts with any of them.
    pub(crate) subscriptions: Vec<Vec<u8>>,
    /// Write half of the peer's framed connection.
    pub(crate) send_queue: Pin<Box<ZmqFramedWrite>>,
    // Never sent on: it is held only so that dropping this `Subscriber` resolves the
    // paired receiver and ends the peer's reader task.
    _subscription_coro_stop: oneshot::Sender<()>,
}
28
/// Shared state behind a [`PubSocket`], also driven by the per-peer reader tasks.
pub(crate) struct PubSocketBackend {
    /// Connected subscribers keyed by peer identity.
    subscribers: DashMap<PeerIdentity, Subscriber>,
    /// Optional monitor channel, installed by `PubSocket::monitor`.
    socket_monitor: Mutex<Option<mpsc::Sender<SocketEvent>>>,
    /// Options the socket was created with.
    socket_options: SocketOptions,
}
34
35impl PubSocketBackend {
36 fn message_received(&self, peer_id: &PeerIdentity, message: Message) {
37 let message = match message {
38 Message::Message(m) => m,
39 _ => return,
40 };
41 assert_eq!(message.len(), 1);
42 let data: Vec<u8> = message.into_vec().pop().unwrap().to_vec();
43 if data.is_empty() {
44 return;
45 }
46 match data[0] {
47 1 => {
48 self.subscribers
50 .get_mut(peer_id)
51 .unwrap()
52 .subscriptions
53 .push(Vec::from(&data[1..]));
54 }
55 0 => {
56 let mut del_index = None;
58 let sub = Vec::from(&data[1..]);
59 for (idx, subscription) in self
60 .subscribers
61 .get(peer_id)
62 .unwrap()
63 .subscriptions
64 .iter()
65 .enumerate()
66 {
67 if &sub == subscription {
68 del_index = Some(idx);
69 break;
70 }
71 }
72 if let Some(index) = del_index {
73 self.subscribers
74 .get_mut(peer_id)
75 .unwrap()
76 .subscriptions
77 .remove(index);
78 }
79 }
80 _ => (),
81 }
82 }
83}
84
85impl SocketBackend for PubSocketBackend {
86 fn socket_type(&self) -> SocketType {
87 SocketType::PUB
88 }
89
90 fn socket_options(&self) -> &SocketOptions {
91 &self.socket_options
92 }
93
94 fn shutdown(&self) {
95 self.subscribers.clear();
96 }
97
98 fn monitor(&self) -> &Mutex<Option<mpsc::Sender<SocketEvent>>> {
99 &self.socket_monitor
100 }
101}
102
103#[async_trait]
104impl MultiPeerBackend for PubSocketBackend {
105 async fn peer_connected(self: Arc<Self>, peer_id: &PeerIdentity, io: FramedIo) {
106 let (mut recv_queue, send_queue) = io.into_parts();
107 let (sender, stop_receiver) = oneshot::channel();
109 self.subscribers.insert(
110 peer_id.clone(),
111 Subscriber {
112 subscriptions: vec![],
113 send_queue: Box::pin(send_queue),
114 _subscription_coro_stop: sender,
115 },
116 );
117 let backend = self;
118 let peer_id = peer_id.clone();
119 async_rt::task::spawn(async move {
120 let mut stop_receiver = stop_receiver.fuse();
121 loop {
122 select! {
123 _ = stop_receiver => {
124 break;
125 },
126 message = recv_queue.next().fuse() => {
127 match message {
128 Some(Ok(m)) => backend.message_received(&peer_id, m),
129 Some(Err(e)) => {
130 dbg!(e);
131 backend.peer_disconnected(&peer_id);
132 break;
133 }
134 None => {
135 backend.peer_disconnected(&peer_id);
136 break
137 }
138 }
139
140 }
141 }
142 }
143 });
144 }
145
146 fn peer_disconnected(&self, peer_id: &PeerIdentity) {
147 log::info!("Client disconnected {:?}", peer_id);
148 self.subscribers.remove(peer_id);
149 }
150}
151
/// A ZeroMQ PUB socket: each sent message is fanned out to every connected subscriber
/// whose subscription is a prefix of the message's first frame.
pub struct PubSocket {
    /// Shared backend, also referenced by the per-peer reader tasks.
    pub(crate) backend: Arc<PubSocketBackend>,
    /// Bound endpoints and the handles that stop their accept loops.
    binds: HashMap<Endpoint, AcceptStopHandle>,
}
156
157impl Drop for PubSocket {
158 fn drop(&mut self) {
159 self.backend.shutdown();
160 }
161}
162
163#[async_trait]
164impl SocketSend for PubSocket {
165 async fn send(&mut self, message: ZmqMessage) -> ZmqResult<()> {
166 let mut dead_peers = Vec::new();
167 for mut subscriber in self.backend.subscribers.iter_mut() {
168 for sub_filter in &subscriber.subscriptions {
169 if sub_filter.len() <= message.get(0).unwrap().len()
170 && sub_filter.as_slice() == &message.get(0).unwrap()[0..sub_filter.len()]
171 {
172 let res = subscriber
173 .send_queue
174 .as_mut()
175 .try_send(Message::Message(message.clone()));
176 match res {
177 Ok(()) => {}
178 Err(ZmqError::Codec(CodecError::Io(e))) => {
179 if e.kind() == ErrorKind::BrokenPipe {
180 dead_peers.push(subscriber.key().clone());
181 } else {
182 dbg!(e);
183 }
184 }
185 Err(ZmqError::BufferFull(_)) => {
186 }
190 Err(e) => {
191 dbg!(e);
192 todo!()
193 }
194 }
195 break;
196 }
197 }
198 }
199 for peer in dead_peers {
200 self.backend.peer_disconnected(&peer);
201 }
202 Ok(())
203 }
204}
205
// Opts PubSocket into CaptureSocket using only the trait's default items (the impl body is empty).
impl CaptureSocket for PubSocket {}
207
208#[async_trait]
209impl Socket for PubSocket {
210 fn with_options(options: SocketOptions) -> Self {
211 Self {
212 backend: Arc::new(PubSocketBackend {
213 subscribers: DashMap::new(),
214 socket_monitor: Mutex::new(None),
215 socket_options: options,
216 }),
217 binds: HashMap::new(),
218 }
219 }
220
221 fn backend(&self) -> Arc<dyn MultiPeerBackend> {
222 self.backend.clone()
223 }
224
225 fn binds(&mut self) -> &mut HashMap<Endpoint, AcceptStopHandle> {
226 &mut self.binds
227 }
228
229 fn monitor(&mut self) -> mpsc::Receiver<SocketEvent> {
230 let (sender, receiver) = mpsc::channel(1024);
231 self.backend.socket_monitor.lock().replace(sender);
232 receiver
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::tests::{
        test_bind_to_any_port_helper, test_bind_to_unspecified_interface_helper,
    };
    use crate::ZmqResult;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Binding a PUB socket to an OS-assigned port works.
    #[async_rt::test]
    async fn test_bind_to_any_port() -> ZmqResult<()> {
        test_bind_to_any_port_helper(PubSocket::new()).await
    }

    /// Binding a PUB socket to the IPv4 wildcard address (0.0.0.0) works.
    #[async_rt::test]
    async fn test_bind_to_any_ipv4_interface() -> ZmqResult<()> {
        let wildcard = IpAddr::V4(Ipv4Addr::UNSPECIFIED);
        let socket = PubSocket::new();
        test_bind_to_unspecified_interface_helper(wildcard, socket, 4000).await
    }

    /// Binding a PUB socket to the IPv6 wildcard address (::) works.
    #[async_rt::test]
    async fn test_bind_to_any_ipv6_interface() -> ZmqResult<()> {
        let wildcard = IpAddr::V6(Ipv6Addr::UNSPECIFIED);
        let socket = PubSocket::new();
        test_bind_to_unspecified_interface_helper(wildcard, socket, 4010).await
    }
}