selium_server/sink/
router.rs1use std::{
4 collections::{hash_map::IterMut, HashMap},
5 fmt::Debug,
6 hash::Hash,
7 pin::Pin,
8 str::FromStr,
9 task::{Context, Poll},
10};
11
12use anyhow::{anyhow, Result};
13use futures::Sink;
14use log::error;
15use selium_protocol::{Frame, MessagePayload};
16use tokio::pin;
17
/// Message header whose value identifies the destination client.
///
/// `Router::start_send` removes this header from an outgoing message and parses its
/// value into the router's key type to pick the inner sink to forward to.
const CLIENT_ID_HEADER: &str = "cid";
19
/// Routes frames to one of several inner sinks, selected by client ID.
///
/// Each inner sink is registered under a key of type `K`. When used as a sink, the
/// router reads the destination key from each message's client ID header and
/// forwards the message to the matching inner sink.
#[must_use = "sinks do nothing unless you poll them"]
pub struct Router<K, V> {
    /// Inner sinks, keyed by client ID.
    entries: HashMap<K, V>,
}

impl<K, V> Router<K, V>
where
    K: Eq + Hash,
{
    /// Creates an empty router.
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }

    /// Creates an empty router with room for at least `capacity` sinks before it
    /// needs to reallocate.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            entries: HashMap::with_capacity(capacity),
        }
    }

    /// Returns an iterator over `(key, sink)` pairs with mutable access to each sink.
    ///
    /// Iteration order is unspecified.
    pub fn iter_mut(&mut self) -> IterMut<'_, K, V> {
        self.entries.iter_mut()
    }

    /// Registers `sink` under `k`.
    ///
    /// Returns the sink previously registered under an equal key, if any. In that
    /// case the stored key is kept and only the sink is replaced.
    pub fn insert(&mut self, k: K, sink: V) -> Option<V> {
        // `HashMap::insert` already hands back the displaced value. A separate
        // `remove` followed by `insert` would hash and look up the key twice.
        self.entries.insert(k, sink)
    }

    /// Unregisters and returns the sink stored under `k`, if present.
    pub fn remove(&mut self, k: &K) -> Option<V> {
        self.entries.remove(k)
    }
}
56
57impl<K, V> Default for Router<K, V>
58where
59 K: Eq + Hash,
60{
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl<K, V> Sink<Frame> for Router<K, V>
70where
71 K: Eq + Hash + FromStr,
72 <K as FromStr>::Err: std::error::Error + Send + Sync + 'static,
73 V: Sink<Frame> + Unpin,
74 V::Error: Debug,
75 Self: Unpin,
76{
77 type Error = anyhow::Error;
78
79 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80 let mut pending = false;
81
82 self.entries.retain(|_, sink| {
83 if pending {
84 return true;
85 }
86 pin!(sink);
87 match sink.poll_ready(cx) {
88 Poll::Pending => {
89 pending = true;
90 true
91 }
92 Poll::Ready(Err(_)) => false,
93 Poll::Ready(Ok(())) => true,
94 }
95 });
96
97 if pending {
98 Poll::Pending
99 } else {
100 Poll::Ready(Ok(()))
101 }
102 }
103
104 fn start_send(mut self: Pin<&mut Self>, frame: Frame) -> Result<(), Self::Error> {
105 let payload = frame.unwrap_message();
106 if payload.headers.is_none() {
107 return Err(anyhow!("Expected headers for message"));
108 }
109 let mut headers = payload.headers.unwrap();
110 if !headers.contains_key(CLIENT_ID_HEADER) {
111 return Err(anyhow!("Missing CLIENT_ID_HEADER in message headers"));
112 }
113 let cid = headers.remove(CLIENT_ID_HEADER).unwrap().parse::<K>()?;
114 let headers = if headers.is_empty() {
115 None
116 } else {
117 Some(headers)
118 };
119
120 let item = self.entries.get_mut(&cid);
121
122 if item.is_none() {
123 return Err(anyhow!("Client ID not found"));
124 }
125
126 let sink = item.unwrap();
127 pin!(sink);
128 if let Err(e) = sink.start_send(Frame::Message(MessagePayload {
129 headers,
130 message: payload.message,
131 })) {
132 error!("Evicting broken sink from Router::start_send with err: {e:?}");
133 self.entries.remove(&cid);
134 }
135
136 Ok(())
137 }
138
139 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140 let mut pending = false;
141
142 self.entries.retain(|_, sink| {
143 if pending {
144 return true;
145 }
146 pin!(sink);
147 match sink.poll_flush(cx) {
148 Poll::Pending => {
149 pending = true;
150 true
151 }
152 Poll::Ready(Err(_)) => false,
153 Poll::Ready(Ok(())) => true,
154 }
155 });
156
157 if pending {
158 Poll::Pending
159 } else {
160 Poll::Ready(Ok(()))
161 }
162 }
163
164 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165 let mut pending = false;
166
167 self.entries.retain(|_, sink| {
168 if pending {
169 return true;
170 }
171 pin!(sink);
172 match sink.poll_close(cx) {
173 Poll::Pending => {
174 pending = true;
175 true
176 }
177 Poll::Ready(Err(_)) => false,
178 Poll::Ready(Ok(())) => true,
179 }
180 });
181
182 if pending {
183 Poll::Pending
184 } else {
185 Poll::Ready(Ok(()))
186 }
187 }
188}
189
190impl<K, V> FromIterator<(K, V)> for Router<K, V>
191where
192 K: Eq + Hash,
193{
194 fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
195 let iterator = iter.into_iter();
196 let (lower_bound, _) = iterator.size_hint();
197 let mut sink_map = Self::with_capacity(lower_bound);
198
199 for (key, value) in iterator {
200 sink_map.insert(key, value);
201 }
202
203 sink_map
204 }
205}
206
207impl<K, V> Extend<(K, V)> for Router<K, V>
208where
209 K: Eq + Hash,
210{
211 fn extend<T>(&mut self, iter: T)
212 where
213 T: IntoIterator<Item = (K, V)>,
214 {
215 self.entries.extend(iter);
216 }
217}