selium_server/topic/
reqrep.rs

1use crate::{sink::Router, BoxSink};
2use futures::{
3    channel::mpsc::{self, Receiver, Sender},
4    ready,
5    stream::BoxStream,
6    Future, Sink, SinkExt, Stream, StreamExt,
7};
8use log::{error, warn};
9use pin_project_lite::pin_project;
10use selium_protocol::{
11    error_codes::REPLIER_ALREADY_BOUND,
12    traits::{ShutdownSink, ShutdownStream},
13    ErrorPayload, Frame,
14};
15use selium_std::errors::{Result, SeliumError};
16use std::{
17    collections::HashMap,
18    pin::Pin,
19    task::{Context, Poll},
20};
21use tokio_stream::StreamMap;
22
/// Buffer size handed to `mpsc::channel` when [`Topic::pair`] creates the
/// channel over which newly connected sockets are delivered to the topic.
const SOCK_CHANNEL_SIZE: usize = 100;
24
25type BoxedBiStream = (
26    BoxSink<Frame, SeliumError>,
27    BoxStream<'static, Result<Frame>>,
28);
29
30pub enum Socket {
31    Client(
32        (
33            BoxSink<Frame, SeliumError>,
34            BoxStream<'static, Result<Frame>>,
35        ),
36    ),
37    Server(
38        (
39            BoxSink<Frame, SeliumError>,
40            BoxStream<'static, Result<Frame>>,
41        ),
42    ),
43}
44
45pin_project! {
46    #[project = TopicProj]
47    #[must_use = "futures do nothing unless you `.await` or poll them"]
48    pub struct Topic {
49        #[pin]
50        server: Option<BoxedBiStream>,
51        #[pin]
52        stream: StreamMap<usize, BoxStream<'static, Result<Frame>>>,
53        #[pin]
54        sink: Router<usize, BoxSink<Frame, SeliumError>>,
55        next_id: usize,
56        #[pin]
57        handle: Receiver<Socket>,
58        buffered_req: Option<Frame>,
59        buffered_rep: Option<Frame>,
60        buffered_err: Option<(Option<ErrorPayload>, BoxSink<Frame, SeliumError>)>,
61    }
62}
63
64impl Topic {
65    pub fn pair() -> (Self, Sender<Socket>) {
66        let (tx, rx) = mpsc::channel(SOCK_CHANNEL_SIZE);
67
68        (
69            Self {
70                server: None,
71                stream: StreamMap::new(),
72                sink: Router::new(),
73                next_id: 0,
74                handle: rx,
75                buffered_req: None,
76                buffered_rep: None,
77                buffered_err: None,
78            },
79            tx,
80        )
81    }
82}
83
84impl Future for Topic {
85    type Output = ();
86
87    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
88        let TopicProj {
89            mut server,
90            mut stream,
91            mut sink,
92            next_id,
93            mut handle,
94            buffered_req,
95            buffered_rep,
96            buffered_err,
97        } = self.project();
98
99        loop {
100            let mut server_pending = false;
101            let mut stream_pending = false;
102
103            // If we've got a request buffered already, we need to write it to the replier
104            // before we can do anything else.
105            if buffered_req.is_some() && server.is_some() {
106                let si = &mut server.as_mut().as_pin_mut().unwrap().0;
107                // Unwrapping is safe as the underlying sink is guaranteed not to error
108                ready!(si.poll_ready_unpin(cx)).unwrap();
109                si.start_send_unpin(buffered_req.take().unwrap()).unwrap();
110            }
111
112            // If we've got an error buffered already, we need to write it to the client
113            // before we can do anything else.
114            if let Some((maybe_err, mut si)) = buffered_err.take() {
115                if let Some(err) = maybe_err {
116                    match si.poll_ready_unpin(cx) {
117                        Poll::Ready(Ok(_)) => {
118                            if si.start_send_unpin(Frame::Error(err)).is_ok() {
119                                *buffered_err = Some((None, si));
120                            }
121                        }
122                        Poll::Ready(Err(e)) => warn!("Could not poll replier sink: {e:?}"),
123                        Poll::Pending => {
124                            *buffered_err = Some((Some(err), si));
125                            return Poll::Pending;
126                        }
127                    }
128                } else {
129                    match si.poll_close_unpin(cx) {
130                        Poll::Ready(Ok(_)) => (),
131                        Poll::Ready(Err(e)) => warn!("Could not close replier sink: {e:?}"),
132                        Poll::Pending => {
133                            *buffered_err = Some((None, si));
134                            return Poll::Pending;
135                        }
136                    }
137                }
138            }
139
140            match handle.as_mut().poll_next(cx) {
141                Poll::Ready(Some(sock)) => match sock {
142                    Socket::Client((si, st)) => {
143                        stream.as_mut().insert(*next_id, st);
144                        sink.as_mut().insert(*next_id, si);
145
146                        *next_id += 1;
147                    }
148                    Socket::Server((si, st)) => {
149                        if server.is_some() {
150                            let error_payload = ErrorPayload {
151                                code: REPLIER_ALREADY_BOUND,
152                                message: "A replier already exists for this topic".into(),
153                            };
154                            *buffered_err = Some((Some(error_payload), si));
155                        } else {
156                            let _ = server.insert((si, st));
157                        }
158                    }
159                },
160                // If handle is terminated, the stream is dead
161                Poll::Ready(None) => {
162                    ready!(sink.as_mut().poll_flush(cx)).unwrap();
163                    stream.iter_mut().for_each(|(_, s)| s.shutdown_stream());
164                    sink.iter_mut().for_each(|(_, s)| s.shutdown_sink());
165
166                    if server.is_some() {
167                        server.as_mut().as_pin_mut().unwrap().1.shutdown_stream();
168                    }
169
170                    return Poll::Ready(());
171                }
172                // If no messages are available and there's no work to do, block this future
173                Poll::Pending
174                    if stream.is_empty()
175                        && server.is_none()
176                        && buffered_req.is_none()
177                        && buffered_rep.is_none() =>
178                {
179                    return Poll::Pending
180                }
181                // Otherwise, move on with running the stream
182                Poll::Pending => (),
183            }
184
185            if server.is_some() {
186                let st = &mut server.as_mut().as_pin_mut().unwrap().1;
187
188                match st.poll_next_unpin(cx) {
189                    // Received message from the server stream
190                    Poll::Ready(Some(Ok(item))) => {
191                        *buffered_rep = Some(item);
192                    }
193                    // Encountered an error whilst receiving a message from an inner stream
194                    Poll::Ready(Some(Err(e))) => {
195                        error!("Received invalid message from replier: {e:?}")
196                    }
197                    // Server has finished
198                    Poll::Ready(None) => {
199                        let si = &mut server.as_mut().as_pin_mut().unwrap().0;
200                        ready!(si.poll_flush_unpin(cx)).unwrap();
201                        ready!(sink.as_mut().poll_flush(cx)).unwrap();
202                        *server = None;
203                    }
204                    // No messages are available at this time
205                    Poll::Pending => {
206                        server_pending = true;
207                    }
208                }
209            }
210
211            // If we've got a reply buffered already, we need to write it to the sink
212            // before we can do anything else.
213            if buffered_rep.is_some() {
214                // Unwrapping is safe as the underlying sink is guaranteed not to error
215                ready!(sink.as_mut().poll_ready(cx)).unwrap();
216
217                let r = sink.as_mut().start_send(buffered_rep.take().unwrap());
218
219                if let Some(e) = r.err() {
220                    error!("Failed to send reply to requestor: {e:?}");
221                }
222            }
223
224            match stream.as_mut().poll_next(cx) {
225                // Received message from a client stream
226                Poll::Ready(Some((id, Ok(item)))) => {
227                    let mut payload = item.unwrap_message();
228                    payload
229                        .headers
230                        .get_or_insert(HashMap::new())
231                        .insert("cid".into(), format!("{id}"));
232                    *buffered_req = Some(Frame::Message(payload));
233                }
234                // Encountered an error whilst receiving a message from an inner stream
235                Poll::Ready(Some((_, Err(e)))) => {
236                    error!("Received invalid message from requestor: {e:?}")
237                }
238                // All streams have finished
239                Poll::Ready(None) => {
240                    // Unwrapping is safe as the underlying sink is guaranteed not to error
241                    ready!(sink.as_mut().poll_flush(cx)).unwrap();
242
243                    if server.is_some() {
244                        let si = &mut server.as_mut().as_pin_mut().unwrap().0;
245                        ready!(si.poll_flush_unpin(cx)).unwrap();
246                    }
247                }
248                // No messages are available at this time
249                Poll::Pending => {
250                    stream_pending = true;
251                }
252            }
253
254            if server_pending && stream_pending {
255                // Unwrapping is safe as the underlying sink is guaranteed not to error
256                ready!(sink.poll_flush(cx)).unwrap();
257
258                if server.is_some() {
259                    let si = &mut server.as_mut().as_pin_mut().unwrap().0;
260                    ready!(si.poll_flush_unpin(cx)).unwrap();
261                }
262
263                return Poll::Pending;
264            }
265        }
266    }
267}