selium_server/sink/
router.rs

//! `Router` is a keyed collection of sinks that forwards each frame to a single sink.
//! Its design is based on `tokio_stream::StreamMap`.
2
3use std::{
4    collections::{hash_map::IterMut, HashMap},
5    fmt::Debug,
6    hash::Hash,
7    pin::Pin,
8    str::FromStr,
9    task::{Context, Poll},
10};
11
12use anyhow::{anyhow, Result};
13use futures::Sink;
14use log::error;
15use selium_protocol::{Frame, MessagePayload};
16use tokio::pin;
17
/// Name of the message header that carries the destination client ID. `Router::start_send`
/// removes this header from the message and parses its value into the router's key type.
const CLIENT_ID_HEADER: &str = "cid";
19
/// A collection of sinks addressed by key.
///
/// Each entry maps a key of type `K` to a sink of type `V`. The map itself places no
/// requirements on `V`; the `Sink` behaviour is provided by a separate trait impl.
#[must_use = "sinks do nothing unless you poll them"]
pub struct Router<K, V> {
    /// Registered sinks, keyed by their identifier.
    entries: HashMap<K, V>,
}

impl<K, V> Router<K, V>
where
    K: Eq + Hash,
{
    /// Creates an empty router.
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }

    /// Creates an empty router with room for at least `capacity` sinks before reallocating.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            entries: HashMap::with_capacity(capacity),
        }
    }

    /// Returns an iterator over all `(key, sink)` pairs, yielding mutable access to each sink.
    ///
    /// Iteration order is that of the underlying `HashMap` and is therefore unspecified.
    pub fn iter_mut(&mut self) -> IterMut<'_, K, V> {
        self.entries.iter_mut()
    }

    /// Registers `sink` under key `k`.
    ///
    /// Returns the sink previously registered under an equal key, if any.
    pub fn insert(&mut self, k: K, sink: V) -> Option<V> {
        // `HashMap::insert` already hands back the displaced value, so a separate
        // `remove` (and its second hash lookup) is unnecessary.
        self.entries.insert(k, sink)
    }

    /// Unregisters and returns the sink stored under `k`, or `None` if there is none.
    pub fn remove(&mut self, k: &K) -> Option<V> {
        self.entries.remove(k)
    }
}
56
57impl<K, V> Default for Router<K, V>
58where
59    K: Eq + Hash,
60{
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66// Note that this is an inefficient implementation as a slow sink will block other sinks
67// from receiving data, despite the slow sink not actually affecting other sinks.
68// https://github.com/seliumlabs/selium/issues/148
69impl<K, V> Sink<Frame> for Router<K, V>
70where
71    K: Eq + Hash + FromStr,
72    <K as FromStr>::Err: std::error::Error + Send + Sync + 'static,
73    V: Sink<Frame> + Unpin,
74    V::Error: Debug,
75    Self: Unpin,
76{
77    type Error = anyhow::Error;
78
79    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80        let mut pending = false;
81
82        self.entries.retain(|_, sink| {
83            if pending {
84                return true;
85            }
86            pin!(sink);
87            match sink.poll_ready(cx) {
88                Poll::Pending => {
89                    pending = true;
90                    true
91                }
92                Poll::Ready(Err(_)) => false,
93                Poll::Ready(Ok(())) => true,
94            }
95        });
96
97        if pending {
98            Poll::Pending
99        } else {
100            Poll::Ready(Ok(()))
101        }
102    }
103
104    fn start_send(mut self: Pin<&mut Self>, frame: Frame) -> Result<(), Self::Error> {
105        let payload = frame.unwrap_message();
106        if payload.headers.is_none() {
107            return Err(anyhow!("Expected headers for message"));
108        }
109        let mut headers = payload.headers.unwrap();
110        if !headers.contains_key(CLIENT_ID_HEADER) {
111            return Err(anyhow!("Missing CLIENT_ID_HEADER in message headers"));
112        }
113        let cid = headers.remove(CLIENT_ID_HEADER).unwrap().parse::<K>()?;
114        let headers = if headers.is_empty() {
115            None
116        } else {
117            Some(headers)
118        };
119
120        let item = self.entries.get_mut(&cid);
121
122        if item.is_none() {
123            return Err(anyhow!("Client ID not found"));
124        }
125
126        let sink = item.unwrap();
127        pin!(sink);
128        if let Err(e) = sink.start_send(Frame::Message(MessagePayload {
129            headers,
130            message: payload.message,
131        })) {
132            error!("Evicting broken sink from Router::start_send with err: {e:?}");
133            self.entries.remove(&cid);
134        }
135
136        Ok(())
137    }
138
139    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140        let mut pending = false;
141
142        self.entries.retain(|_, sink| {
143            if pending {
144                return true;
145            }
146            pin!(sink);
147            match sink.poll_flush(cx) {
148                Poll::Pending => {
149                    pending = true;
150                    true
151                }
152                Poll::Ready(Err(_)) => false,
153                Poll::Ready(Ok(())) => true,
154            }
155        });
156
157        if pending {
158            Poll::Pending
159        } else {
160            Poll::Ready(Ok(()))
161        }
162    }
163
164    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165        let mut pending = false;
166
167        self.entries.retain(|_, sink| {
168            if pending {
169                return true;
170            }
171            pin!(sink);
172            match sink.poll_close(cx) {
173                Poll::Pending => {
174                    pending = true;
175                    true
176                }
177                Poll::Ready(Err(_)) => false,
178                Poll::Ready(Ok(())) => true,
179            }
180        });
181
182        if pending {
183            Poll::Pending
184        } else {
185            Poll::Ready(Ok(()))
186        }
187    }
188}
189
190impl<K, V> FromIterator<(K, V)> for Router<K, V>
191where
192    K: Eq + Hash,
193{
194    fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
195        let iterator = iter.into_iter();
196        let (lower_bound, _) = iterator.size_hint();
197        let mut sink_map = Self::with_capacity(lower_bound);
198
199        for (key, value) in iterator {
200            sink_map.insert(key, value);
201        }
202
203        sink_map
204    }
205}
206
207impl<K, V> Extend<(K, V)> for Router<K, V>
208where
209    K: Eq + Hash,
210{
211    fn extend<T>(&mut self, iter: T)
212    where
213        T: IntoIterator<Item = (K, V)>,
214    {
215        self.entries.extend(iter);
216    }
217}