1pub mod errors;
40
41use cord_message::{errors::Error as MessageError, Codec, Message, Pattern};
42use errors::{Error, ErrorKind, Result};
43use futures::{
44 future::{self, try_select},
45 stream::SplitSink,
46 Sink, SinkExt, Stream, StreamExt, TryStreamExt,
47};
48use futures_locks::Mutex;
49use tokio::{
50 net::{TcpStream, ToSocketAddrs},
51 sync::mpsc,
52 sync::oneshot,
53};
54use tokio_util::codec::Framed;
55
56use std::{
57 collections::HashMap,
58 convert::Into,
59 ops::Drop,
60 pin::Pin,
61 result,
62 sync::Arc,
63 task::{Context, Poll},
64};
65
/// The concrete client type produced by `ClientConn::connect`: a `ClientConn` whose
/// sink is the write half obtained by splitting a `Framed<TcpStream, Codec>`.
pub type Client = ClientConn<SplitSink<Framed<TcpStream, Codec>, Message>>;
75
/// A client connection, generic over the sink `S` that outgoing messages are written
/// to (the tests substitute an in-memory channel for the TCP sink).
///
/// Incoming messages are handled by a background router task spawned in `connect`,
/// which dispatches them to the subscribers recorded in `inner`.
pub struct ClientConn<S> {
    /// Destination for outgoing `Message`s.
    sink: S,
    /// State shared with every `Subscriber` created from this connection.
    inner: Arc<Inner>,
}
84
/// A stream of `(namespace, data)` events for one subscription, returned by
/// `ClientConn::subscribe`.
pub struct Subscriber {
    /// Receives the messages that `route` forwards for this subscription's pattern.
    receiver: mpsc::Receiver<Message>,
    /// Keeps the shared state alive (and so keeps `Inner::drop` from stopping the
    /// router task) for as long as this subscriber exists.
    _inner: Arc<Inner>,
}
113
/// State shared between a `ClientConn` and the `Subscriber`s created from it.
struct Inner {
    /// Subscriber senders keyed by the pattern they subscribed with. A clone of this
    /// mutex is also held by the router task spawned in `connect`.
    receivers: Mutex<HashMap<Pattern, Vec<mpsc::Sender<Message>>>>,
    /// Fired from `Drop` to stop the router task; only ever taken in `drop`.
    detonator: Option<oneshot::Sender<()>>,
}
118
119impl<S> ClientConn<S>
120where
121 S: Sink<Message, Error = MessageError> + Unpin,
122{
123 pub async fn connect<A>(addr: A) -> Result<Client>
125 where
126 A: ToSocketAddrs,
127 {
128 let (det_tx, det_rx) = oneshot::channel();
130
131 let sock = TcpStream::connect(addr).await?;
133
134 let framed = Framed::new(sock, Codec::default());
136 let (sink, stream) = framed.split();
137
138 let receivers = Mutex::new(HashMap::new());
140 let receivers_c = receivers.clone();
141
142 let router = Box::pin(
144 stream
145 .map_err(|e| Error::from_kind(ErrorKind::Message(e)))
146 .try_fold(receivers_c, |recv, message| async move {
147 route(&recv, message).await;
148 Ok(recv)
149 }),
150 );
151
152 tokio::spawn(try_select(router, det_rx));
153
154 Ok(ClientConn {
155 sink,
156 inner: Arc::new(Inner {
157 receivers,
158 detonator: Some(det_tx),
159 }),
160 })
161 }
162
163 pub async fn provide(&mut self, namespace: Pattern) -> Result<()> {
165 self.sink
166 .send(Message::Provide(namespace))
167 .await
168 .map_err(|e| ErrorKind::Message(e).into())
169 }
170
171 pub async fn revoke(&mut self, namespace: Pattern) -> Result<()> {
173 self.sink
174 .send(Message::Revoke(namespace))
175 .await
176 .map_err(|e| ErrorKind::Message(e).into())
177 }
178
179 pub async fn subscribe(&mut self, namespace: Pattern) -> Result<Subscriber> {
203 let namespace_c = namespace.clone();
204 self.sink.send(Message::Subscribe(namespace)).await?;
205
206 let (tx, rx) = mpsc::channel(10);
207 self.inner
208 .receivers
209 .with(move |mut guard| {
210 (*guard)
211 .entry(namespace_c)
212 .or_insert_with(Vec::new)
213 .push(tx);
214 let ok: result::Result<_, ()> = Ok(());
215 future::ready(ok)
216 })
217 .await
218 .unwrap();
219 Ok(Subscriber {
220 receiver: rx,
221 _inner: self.inner.clone(),
222 })
223 }
224
225 pub async fn unsubscribe(&mut self, namespace: Pattern) -> Result<()> {
227 let namespace_c = namespace.clone();
228 self.sink.send(Message::Unsubscribe(namespace)).await?;
229
230 self.inner
231 .receivers
232 .with(move |mut guard| {
233 (*guard).remove(&namespace_c);
234 future::ready(())
235 })
236 .await;
237 Ok(())
238 }
239
240 pub async fn event<Str: Into<String>>(&mut self, namespace: Pattern, data: Str) -> Result<()> {
242 self.sink
243 .send(Message::Event(namespace, data.into()))
244 .await
245 .map_err(|e| ErrorKind::Message(e).into())
246 }
247}
248
249impl<E, S, T> Sink<T> for ClientConn<S>
250where
251 S: Sink<T, Error = E>,
252 E: Into<Error>,
253{
254 type Error = Error;
255
256 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<result::Result<(), Self::Error>> {
257 unsafe { Pin::map_unchecked_mut(self, |x| &mut x.sink) }
258 .poll_ready(cx)
259 .map_err(|e| e.into())
260 }
261
262 fn start_send(self: Pin<&mut Self>, item: T) -> result::Result<(), Self::Error> {
263 unsafe { Pin::map_unchecked_mut(self, |x| &mut x.sink) }
264 .start_send(item)
265 .map_err(|e| e.into())
266 }
267
268 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<result::Result<(), Self::Error>> {
269 unsafe { Pin::map_unchecked_mut(self, |x| &mut x.sink) }
270 .poll_flush(cx)
271 .map_err(|e| e.into())
272 }
273
274 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<result::Result<(), Self::Error>> {
275 unsafe { Pin::map_unchecked_mut(self, |x| &mut x.sink) }
276 .poll_close(cx)
277 .map_err(|e| e.into())
278 }
279}
280
281impl Stream for Subscriber {
282 type Item = (Pattern, String);
283
284 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
285 unsafe { Pin::map_unchecked_mut(self, |x| &mut x.receiver) }
286 .poll_recv(cx)
287 .map(|opt_msg| match opt_msg {
288 Some(Message::Event(pattern, data)) => Some((pattern, data)),
289 None => None,
290 _ => unreachable!(),
291 })
292 }
293}
294
295impl Drop for Inner {
296 fn drop(&mut self) {
297 let _ = self
300 .detonator
301 .take()
302 .expect("Inner has already been terminated")
303 .send(());
304 }
305}
306
307async fn route(receivers: &Mutex<HashMap<Pattern, Vec<mpsc::Sender<Message>>>>, message: Message) {
308 receivers
309 .with(move |mut guard| {
310 (*guard).retain(|namespace, senders| {
312 if namespace.contains(message.namespace()) {
315 senders.retain_mut(|tx| tx.try_send(message.clone()).is_ok());
317 }
318
319 !senders.is_empty()
321 });
322
323 future::ready(())
324 })
325 .await
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    use cord_message::errors::ErrorKind as MessageErrorKind;

    use futures::channel::mpsc::{unbounded, UnboundedReceiver};

    /// The subscriber table type shared by `Inner` and `route`.
    type Table = Mutex<HashMap<Pattern, Vec<mpsc::Sender<Message>>>>;

    /// A stream that yields its messages back-to-front (it pops from the end of the `Vec`).
    struct ForwardStream(Vec<Message>);

    impl Stream for ForwardStream {
        type Item = Result<Message>;

        fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
            let next = self.get_mut().0.pop();
            Poll::Ready(next.map(Ok))
        }
    }

    /// Builds a client whose outgoing messages land in an in-memory channel.
    ///
    /// Returns the client, the receiving end of that channel, and a handle to the
    /// client's subscriber table (a clone sharing the same map).
    fn setup_client() -> (
        ClientConn<impl Sink<Message, Error = MessageError>>,
        UnboundedReceiver<Message>,
        Table,
    ) {
        let (outgoing, sent) = unbounded();
        let (detonator, _) = oneshot::channel();
        let table: Table = Mutex::new(HashMap::new());

        let sink = outgoing.sink_map_err(|e| -> MessageError {
            MessageErrorKind::Msg(format!("{}", e)).into()
        });
        let inner = Arc::new(Inner {
            receivers: table.clone(),
            detonator: Some(detonator),
        });

        (ClientConn { sink, inner }, sent, table)
    }

    /// Registers `senders` under `pattern` in `table`.
    async fn seed(table: &Table, pattern: Pattern, senders: Vec<mpsc::Sender<Message>>) {
        table
            .with(move |mut guard| {
                (*guard).insert(pattern, senders);
                future::ready(())
            })
            .await;
    }

    #[tokio::test]
    async fn test_forward() {
        let (client, mut sent, _) = setup_client();

        // `ForwardStream` pops from the back, so these are sent in reverse order.
        let source = ForwardStream(vec![
            Message::Event("/a".into(), "b".into()),
            Message::Provide("/a".into()),
        ]);
        source.forward(client).await.unwrap();

        assert_eq!(sent.next().await, Some(Message::Provide("/a".into())));
        assert_eq!(sent.next().await, Some(Message::Event("/a".into(), "b".into())));
    }

    #[tokio::test]
    async fn test_provide() {
        let (mut client, mut sent, _) = setup_client();

        client.provide("/a/b".into()).await.unwrap();

        let first = sent.next().await.unwrap();
        assert_eq!(first, Message::Provide("/a/b".into()));
    }

    #[tokio::test]
    async fn test_revoke() {
        let (mut client, mut sent, _) = setup_client();

        client.revoke("/a/b".into()).await.unwrap();

        let first = sent.next().await.unwrap();
        assert_eq!(first, Message::Revoke("/a/b".into()));
    }

    #[tokio::test]
    async fn test_subscribe() {
        let (mut client, mut sent, table) = setup_client();

        client.subscribe("/a/b".into()).await.unwrap();

        let first = sent.next().await.unwrap();
        assert_eq!(first, Message::Subscribe("/a/b".into()));

        let map = table.lock().await;
        assert!(map.contains_key(&"/a/b".into()));
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let (mut client, mut sent, table) = setup_client();
        seed(&table, "/a/b".into(), Vec::new()).await;

        client.unsubscribe("/a/b".into()).await.unwrap();

        let first = sent.next().await.unwrap();
        assert_eq!(first, Message::Unsubscribe("/a/b".into()));

        let map = table.lock().await;
        assert!(map.is_empty());
    }

    #[tokio::test]
    async fn test_event() {
        let (mut client, mut sent, _) = setup_client();

        client.event("/a/b".into(), "moo").await.unwrap();

        let first = sent.next().await.unwrap();
        assert_eq!(first, Message::Event("/a/b".into(), "moo".into()));
    }

    #[tokio::test]
    async fn test_route() {
        let (tx, mut rx) = mpsc::channel(10);
        let table: Table = Mutex::new(HashMap::new());
        seed(&table, "/a/b".into(), vec![tx]).await;

        let expected = Message::Event("/a/b".into(), "Moo!".into());
        route(&table, expected.clone()).await;

        assert_eq!(rx.recv().await.unwrap(), expected);

        let map = table.lock().await;
        assert!(map.contains_key(&"/a/b".into()));
    }

    #[tokio::test]
    async fn test_route_norecv() {
        // The receiver is dropped immediately, so delivery must fail and prune the entry.
        let (tx, _) = mpsc::channel(10);
        let table: Table = Mutex::new(HashMap::new());
        seed(&table, "/a/b".into(), vec![tx]).await;

        route(&table, Message::Event("/a/b".into(), "Moo!".into())).await;

        let map = table.lock().await;
        assert!(map.is_empty());
    }
}