// zeromq/pub.rs

1use crate::codec::*;
2use crate::endpoint::Endpoint;
3use crate::error::ZmqResult;
4use crate::message::*;
5use crate::transport::AcceptStopHandle;
6use crate::util::PeerIdentity;
7use crate::{async_rt, CaptureSocket, SocketOptions};
8use crate::{
9    MultiPeerBackend, Socket, SocketBackend, SocketEvent, SocketSend, SocketType, ZmqError,
10};
11
12use async_trait::async_trait;
13use dashmap::DashMap;
14use futures_channel::{mpsc, oneshot};
15use futures_util::{select, FutureExt, StreamExt};
16use parking_lot::Mutex;
17
18use std::collections::HashMap;
19use std::io::ErrorKind;
20use std::pin::Pin;
21use std::sync::Arc;
22
23pub(crate) struct Subscriber {
24    pub(crate) subscriptions: Vec<Vec<u8>>,
25    pub(crate) send_queue: Pin<Box<ZmqFramedWrite>>,
26    _subscription_coro_stop: oneshot::Sender<()>,
27}
28
29pub(crate) struct PubSocketBackend {
30    subscribers: DashMap<PeerIdentity, Subscriber>,
31    socket_monitor: Mutex<Option<mpsc::Sender<SocketEvent>>>,
32    socket_options: SocketOptions,
33}
34
35impl PubSocketBackend {
36    fn message_received(&self, peer_id: &PeerIdentity, message: Message) {
37        let message = match message {
38            Message::Message(m) => m,
39            _ => return,
40        };
41        assert_eq!(message.len(), 1);
42        let data: Vec<u8> = message.into_vec().pop().unwrap().to_vec();
43        if data.is_empty() {
44            return;
45        }
46        match data[0] {
47            1 => {
48                // Subscribe
49                self.subscribers
50                    .get_mut(peer_id)
51                    .unwrap()
52                    .subscriptions
53                    .push(Vec::from(&data[1..]));
54            }
55            0 => {
56                // Unsubscribe
57                let mut del_index = None;
58                let sub = Vec::from(&data[1..]);
59                for (idx, subscription) in self
60                    .subscribers
61                    .get(peer_id)
62                    .unwrap()
63                    .subscriptions
64                    .iter()
65                    .enumerate()
66                {
67                    if &sub == subscription {
68                        del_index = Some(idx);
69                        break;
70                    }
71                }
72                if let Some(index) = del_index {
73                    self.subscribers
74                        .get_mut(peer_id)
75                        .unwrap()
76                        .subscriptions
77                        .remove(index);
78                }
79            }
80            _ => (),
81        }
82    }
83}
84
85impl SocketBackend for PubSocketBackend {
86    fn socket_type(&self) -> SocketType {
87        SocketType::PUB
88    }
89
90    fn socket_options(&self) -> &SocketOptions {
91        &self.socket_options
92    }
93
94    fn shutdown(&self) {
95        self.subscribers.clear();
96    }
97
98    fn monitor(&self) -> &Mutex<Option<mpsc::Sender<SocketEvent>>> {
99        &self.socket_monitor
100    }
101}
102
103#[async_trait]
104impl MultiPeerBackend for PubSocketBackend {
105    async fn peer_connected(self: Arc<Self>, peer_id: &PeerIdentity, io: FramedIo) {
106        let (mut recv_queue, send_queue) = io.into_parts();
107        // TODO provide handling for recv_queue
108        let (sender, stop_receiver) = oneshot::channel();
109        self.subscribers.insert(
110            peer_id.clone(),
111            Subscriber {
112                subscriptions: vec![],
113                send_queue: Box::pin(send_queue),
114                _subscription_coro_stop: sender,
115            },
116        );
117        let backend = self;
118        let peer_id = peer_id.clone();
119        async_rt::task::spawn(async move {
120            let mut stop_receiver = stop_receiver.fuse();
121            loop {
122                select! {
123                     _ = stop_receiver => {
124                         break;
125                     },
126                     message = recv_queue.next().fuse() => {
127                        match message {
128                            Some(Ok(m)) => backend.message_received(&peer_id, m),
129                            Some(Err(e)) => {
130                                dbg!(e);
131                                backend.peer_disconnected(&peer_id);
132                                break;
133                            }
134                            None => {
135                                backend.peer_disconnected(&peer_id);
136                                break
137                            }
138                        }
139
140                     }
141                }
142            }
143        });
144    }
145
146    fn peer_disconnected(&self, peer_id: &PeerIdentity) {
147        log::info!("Client disconnected {:?}", peer_id);
148        self.subscribers.remove(peer_id);
149    }
150}
151
152pub struct PubSocket {
153    pub(crate) backend: Arc<PubSocketBackend>,
154    binds: HashMap<Endpoint, AcceptStopHandle>,
155}
156
157impl Drop for PubSocket {
158    fn drop(&mut self) {
159        self.backend.shutdown();
160    }
161}
162
163#[async_trait]
164impl SocketSend for PubSocket {
165    async fn send(&mut self, message: ZmqMessage) -> ZmqResult<()> {
166        let mut dead_peers = Vec::new();
167        for mut subscriber in self.backend.subscribers.iter_mut() {
168            for sub_filter in &subscriber.subscriptions {
169                if sub_filter.len() <= message.get(0).unwrap().len()
170                    && sub_filter.as_slice() == &message.get(0).unwrap()[0..sub_filter.len()]
171                {
172                    let res = subscriber
173                        .send_queue
174                        .as_mut()
175                        .try_send(Message::Message(message.clone()));
176                    match res {
177                        Ok(()) => {}
178                        Err(ZmqError::Codec(CodecError::Io(e))) => {
179                            if e.kind() == ErrorKind::BrokenPipe {
180                                dead_peers.push(subscriber.key().clone());
181                            } else {
182                                dbg!(e);
183                            }
184                        }
185                        Err(ZmqError::BufferFull(_)) => {
186                            // ignore silently. https://rfc.zeromq.org/spec/29/ says:
187                            // For processing outgoing messages:
188                            //   SHALL silently drop the message if the queue for a subscriber is full.
189                        }
190                        Err(e) => {
191                            dbg!(e);
192                            todo!()
193                        }
194                    }
195                    break;
196                }
197            }
198        }
199        for peer in dead_peers {
200            self.backend.peer_disconnected(&peer);
201        }
202        Ok(())
203    }
204}
205
206impl CaptureSocket for PubSocket {}
207
208#[async_trait]
209impl Socket for PubSocket {
210    fn with_options(options: SocketOptions) -> Self {
211        Self {
212            backend: Arc::new(PubSocketBackend {
213                subscribers: DashMap::new(),
214                socket_monitor: Mutex::new(None),
215                socket_options: options,
216            }),
217            binds: HashMap::new(),
218        }
219    }
220
221    fn backend(&self) -> Arc<dyn MultiPeerBackend> {
222        self.backend.clone()
223    }
224
225    fn binds(&mut self) -> &mut HashMap<Endpoint, AcceptStopHandle> {
226        &mut self.binds
227    }
228
229    fn monitor(&mut self) -> mpsc::Receiver<SocketEvent> {
230        let (sender, receiver) = mpsc::channel(1024);
231        self.backend.socket_monitor.lock().replace(sender);
232        receiver
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::tests::{
        test_bind_to_any_port_helper, test_bind_to_unspecified_interface_helper,
    };
    use crate::ZmqResult;
    use std::net::IpAddr;

    /// Runs the shared bind-to-any-port helper against a fresh PUB socket.
    #[async_rt::test]
    async fn test_bind_to_any_port() -> ZmqResult<()> {
        test_bind_to_any_port_helper(PubSocket::new()).await
    }

    /// Runs the unspecified-interface helper with the IPv4 wildcard address.
    #[async_rt::test]
    async fn test_bind_to_any_ipv4_interface() -> ZmqResult<()> {
        let any_ipv4 = "0.0.0.0".parse::<IpAddr>().unwrap();
        test_bind_to_unspecified_interface_helper(any_ipv4, PubSocket::new(), 4000).await
    }

    /// Runs the unspecified-interface helper with the IPv6 wildcard address.
    #[async_rt::test]
    async fn test_bind_to_any_ipv6_interface() -> ZmqResult<()> {
        let any_ipv6 = "::".parse::<IpAddr>().unwrap();
        test_bind_to_unspecified_interface_helper(any_ipv6, PubSocket::new(), 4010).await
    }
}