ecksport_net/
connection.rs

1//! Connection bookkeeping to track and emit user-facing events.
2
3use std::collections::*;
4use std::time;
5
6use ecksport_core::peer::PeerData;
7
8use ecksport_core::frame::{CloseData, FrameBody, MsgFlags, NotificationData, OpenData, PushData};
9use ecksport_core::state_mach::{ClientMeta, ServerMeta};
10use ecksport_core::topic;
11use ecksport_core::traits::{AsyncRecvFrame, AsyncSendFrame, AuthConfig};
12
13use crate::channel_state::{self, Creator};
14use crate::errors::Error;
15use crate::event::{InbEvent, PushFlags};
16use crate::handshake::{self, do_server_handshake_async};
17
18/// Low-level connection type that handles basic channel idx bookkeeping and
19/// inbound event buffering.
20pub struct Connection<T> {
21    inner: T,
22
23    /// Protocol chosen by initiator.
24    protocol: topic::Topic,
25
26    /// The side that created the channel.
27    initiator: Creator,
28
29    /// Identity data about the peer figured out during handshake.
30    peer: PeerData,
31
32    /// Bookkeeping for channel state.
33    chan_tbl: channel_state::ChannelTable,
34
35    /// Buffer for events that have been produced in response to frames but not
36    /// emitted yet.
37    event_queue: VecDeque<InbEvent>,
38}
39
40impl<T> Connection<T> {
41    /// Wraps an existing connection that's already been handshaked.
42    fn new(inner: T, initiator: Creator, protocol: topic::Topic, peer: PeerData) -> Self {
43        Self {
44            inner,
45            protocol,
46            initiator,
47            peer,
48            chan_tbl: channel_state::ChannelTable::new(initiator),
49            event_queue: VecDeque::new(),
50        }
51    }
52
53    pub fn inner(&self) -> &T {
54        &self.inner
55    }
56
57    /// Returns a mutable reference to the inner recv instance.  Note that using
58    /// this to recv messages can break the bookkeeping, so use caution.
59    pub fn inner_mut(&mut self) -> &mut T {
60        &mut self.inner
61    }
62
63    pub fn into_inner(self) -> T {
64        self.inner
65    }
66
67    pub fn protocol(&self) -> topic::Topic {
68        self.protocol
69    }
70
71    pub fn initiator(&self) -> Creator {
72        self.initiator
73    }
74
75    pub fn peer_data(&self) -> &PeerData {
76        &self.peer
77    }
78
79    /// Returns the number of open channels in the internal channel table.  This
80    /// includes channels opened in response to pending events that are buffered
81    /// and need to be processed.
82    pub fn num_open_channels(&self) -> usize {
83        self.chan_tbl.num_open_channels()
84    }
85
86    /// Returns if there's events that have been produced but not consumed yet.
87    pub fn has_pending_events(&self) -> bool {
88        !self.event_queue.is_empty()
89    }
90
91    fn handle_frame(&mut self, frame: FrameBody) -> Result<(), Error> {
92        match frame {
93            FrameBody::OpenChan(open) => {
94                // Set up new channel bookkeeping.
95                let close = open.close();
96                let flags = PushFlags::from(open.flags());
97                let id = self.chan_tbl.init_remote_chan(open.topic(), close);
98
99                // Insert appropriate events.
100                self.event_queue.push_back(InbEvent::NewChannel(
101                    id,
102                    open.topic(),
103                    flags,
104                    open.into_payload(),
105                ));
106                if close {
107                    self.event_queue.push_back(InbEvent::CloseChannel(id, true))
108                }
109
110                Ok(())
111            }
112
113            FrameBody::PushChan(push) => {
114                let id = push.chan_id();
115                self.chan_tbl.check_recv_on_chan(id)?;
116                let flags = PushFlags::from(push.flags());
117                let close = push.close();
118
119                self.event_queue
120                    .push_back(InbEvent::PushChannel(id, flags, push.into_payload()));
121
122                if close {
123                    let removed = self
124                        .chan_tbl
125                        .mark_chan_remote_closed(id)
126                        .expect("connection: close remote");
127
128                    self.event_queue
129                        .push_back(InbEvent::CloseChannel(id, !removed));
130                }
131
132                Ok(())
133            }
134
135            FrameBody::CloseChan(close) => {
136                self.chan_tbl.check_recv_on_chan(close.chan_id())?;
137
138                let removed = self
139                    .chan_tbl
140                    .mark_chan_remote_closed(close.chan_id())
141                    .expect("connection: close remote");
142
143                self.event_queue
144                    .push_back(InbEvent::CloseChannel(close.chan_id(), !removed));
145                Ok(())
146            }
147
148            FrameBody::Notification(notif) => {
149                let topic = notif.topic();
150                self.event_queue
151                    .push_back(InbEvent::Notification(topic, notif.into_payload()));
152                Ok(())
153            }
154
155            _ => Err(Error::UnexpectedFrame(frame.ty())),
156        }
157    }
158}
159
160impl<T: AsyncRecvFrame> Connection<T> {
161    async fn recv_frame(&mut self) -> Result<(), Error> {
162        let frame = self.inner.recv_frame_async().await?;
163        self.handle_frame(frame)?;
164        Ok(())
165    }
166
167    /// Takes the next event out of the buffer or, if empty, waits for a new
168    /// event to be produced.
169    pub async fn next_event(&mut self) -> Result<Option<InbEvent>, Error> {
170        if !self.event_queue.is_empty() {
171            return Ok(Some(self.event_queue.pop_front().unwrap()));
172        }
173
174        self.recv_frame().await?;
175        Ok(self.event_queue.pop_front())
176    }
177}
178
179impl<T: AsyncSendFrame> Connection<T> {
180    /// Opens a channel on a topic with a channel, optionally closing our end of
181    /// it immediately.
182    pub async fn open_channel(
183        &mut self,
184        topic: topic::Topic,
185        payload: Vec<u8>,
186        flags: MsgFlags,
187    ) -> Result<u32, Error> {
188        let im_close = flags.close;
189        let open_data = OpenData::new(topic, flags, payload);
190        let frame = FrameBody::OpenChan(open_data);
191        self.inner.send_frame_async(&frame).await?;
192        let id = self.chan_tbl.init_local_chan(topic, !im_close);
193        Ok(id)
194    }
195
196    /// Sends a message on a channel that we haven't closed ourselves yet.
197    /// Returns if the channel is kept alive after this.
198    pub async fn send_message(
199        &mut self,
200        chan_id: u32,
201        payload: Vec<u8>,
202        flags: MsgFlags,
203    ) -> Result<bool, Error> {
204        self.chan_tbl.check_send_on_chan(chan_id)?;
205
206        let push_data = PushData::new(chan_id, flags, payload);
207        let frame = FrameBody::PushChan(push_data);
208        self.inner.send_frame_async(&frame).await?;
209
210        if flags.close {
211            let removed = self
212                .chan_tbl
213                .mark_chan_local_closed(chan_id)
214                .expect("connection: close local");
215            Ok(!removed)
216        } else {
217            Ok(true)
218        }
219    }
220
221    /// Close a channel without sending a message.  Returns if the channel is
222    /// kept alive after this.
223    pub async fn close_channel(&mut self, chan_id: u32) -> Result<bool, Error> {
224        self.chan_tbl.check_send_on_chan(chan_id)?;
225
226        let close_data = CloseData::new(chan_id);
227        let frame = FrameBody::CloseChan(close_data);
228        self.inner.send_frame_async(&frame).await?;
229
230        let removed = self
231            .chan_tbl
232            .mark_chan_local_closed(chan_id)
233            .expect("connection: close local");
234
235        Ok(!removed)
236    }
237
238    /// Sends a notification on a topic.
239    pub async fn send_notification(
240        &mut self,
241        topic: topic::Topic,
242        message: Vec<u8>,
243    ) -> Result<(), Error> {
244        let notif_data = NotificationData::new(topic, message);
245        let frame = FrameBody::Notification(notif_data);
246        self.inner.send_frame_async(&frame).await?;
247        Ok(())
248    }
249}
250
251#[derive(Clone, Debug)]
252pub struct ConnectOptions {
253    pub timeout: time::Duration,
254    pub client_meta: ClientMeta,
255}
256
257impl Default for ConnectOptions {
258    fn default() -> Self {
259        Self {
260            timeout: time::Duration::from_millis(15000),
261            client_meta: ClientMeta::new("/ecksport/alpha/".to_owned()),
262        }
263    }
264}
265
266/// Takes a freshly constructed connection, available as frames, and performs
267/// the client side of a normal handshake to select the specified protocol,
268/// returning a new low-level connection if the handshake completes
269/// successfully.
270pub async fn perform_handshake_async<
271    T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static,
272    A: AuthConfig,
273>(
274    mut stream: T,
275    protocol: topic::Topic,
276    opts: ConnectOptions,
277    auth: A,
278    peer: PeerData,
279) -> Result<Connection<T>, Error> {
280    let hs_opts = handshake::HandshakeOptions::new(opts.timeout);
281
282    // Perform the actual client handshake.
283    let hs = handshake::do_client_handshake_async(
284        &mut stream,
285        protocol,
286        &opts.client_meta,
287        &hs_opts,
288        auth,
289        peer,
290    )
291    .await?;
292    assert_eq!(hs.ready().protocol(), protocol);
293
294    let peer = hs.into_peer();
295    Ok(Connection::new(stream, Creator::Local, protocol, peer))
296}
297
298#[derive(Clone, Debug)]
299pub struct AcceptOptions {
300    pub timeout: time::Duration,
301    pub server_meta: ServerMeta,
302}
303
304impl Default for AcceptOptions {
305    fn default() -> Self {
306        Self {
307            timeout: time::Duration::from_millis(15000),
308            server_meta: ServerMeta::new("/ecksport/alpha/".to_owned(), Vec::new()),
309        }
310    }
311}
312
313/// Takes a freshly accepted connection and performs the server side of a
314/// handshake, returning a new low-level connection if the handshake completes
315/// successfully.
316// TODO figure out how to reuse the options on each accept
317pub async fn accept_connection_async<
318    T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static,
319    A: AuthConfig,
320>(
321    mut stream: T,
322    opts: AcceptOptions,
323    auth: A,
324    peer: PeerData,
325) -> Result<Option<Connection<T>>, Error> {
326    let hs_opts = handshake::HandshakeOptions::new(opts.timeout);
327
328    // The server side is simpler.
329    let Some(hs) =
330        do_server_handshake_async(&mut stream, &opts.server_meta, &hs_opts, auth, peer).await?
331    else {
332        return Ok(None);
333    };
334
335    let proto = hs.ready().protocol();
336    let peer = hs.into_peer();
337
338    let conn = Connection::new(stream, Creator::Remote, proto, peer);
339    Ok(Some(conn))
340}
341
#[cfg(test)]
mod tests {
    use core::net;

    use ecksport_core::stream_framing::StreamFramer;
    use ecksport_core::topic;
    use tokio::time::timeout;

    use super::*;

    /// Simple connection over a TCP stream.
    pub type TokioTcpConnection = Connection<StreamFramer<tokio::net::TcpStream>>;

    /// Connects to a TCP socket with Tokio, requests a particular protocol,
    /// and returns the resulting low-level connection.
    ///
    /// The TCP connect is bounded by `opts.timeout`, the same value used to
    /// configure the handshake.
    async fn connect_tcp_tokio<A: AuthConfig>(
        socket_addr: net::SocketAddr,
        protocol: topic::Topic,
        opts: ConnectOptions,
        auth: A,
    ) -> Result<TokioTcpConnection, Error> {
        let socket_connect_fut = tokio::net::TcpStream::connect(socket_addr);
        let sock = match timeout(opts.timeout, socket_connect_fut).await {
            Ok(res) => res?,
            Err(_) => return Err(Error::ConnectionTimeout),
        };

        // The tests never inspect peer identity, so default peer data is
        // sufficient here.
        let peer = PeerData::default();
        let framer = StreamFramer::new(sock);
        perform_handshake_async(framer, protocol, opts, auth, peer).await
    }

    /// Opens and closes channels from both sides and checks that the channel
    /// bookkeeping and the emitted events line up.
    #[tokio::test]
    async fn test_connect_accept() {
        // Bind to an OS-assigned port so concurrent test runs can't collide.
        let lis = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("test: bind");
        let socket_addr = lis.local_addr().expect("test: listener addr");

        let proto = topic::Topic::from_const_str("TESTTEST");
        let topic1 = topic::Topic::from_const_str("FOOOBARR");
        let topic2 = topic::Topic::from_const_str("BAZZQUUX");
        let mut acc_opts = AcceptOptions::default();
        acc_opts.server_meta.add_protocol(proto);
        let conn_opts = ConnectOptions::default();

        let lj = tokio::spawn(async move {
            let (sock, _sa) = lis.accept().await.expect("test: accept");
            let framer = StreamFramer::new(sock);
            let pd = PeerData::default();

            let mut conn = accept_connection_async(framer, acc_opts, (), pd)
                .await
                .expect("test: server handshake")
                .expect("test: create server connection");

            // The client's open arrives first...
            let ev = conn
                .next_event()
                .await
                .expect("test: accept event")
                .expect("test: read event");
            eprintln!("got event: {ev:?}");
            assert!(
                matches!(ev, InbEvent::NewChannel(0, _, _, _)),
                "unexpected event: {ev:?}"
            );

            assert_eq!(conn.num_open_channels(), 1);

            // Open a new channel from the server side.
            conn.open_channel(topic2, vec![5, 6, 7, 8], MsgFlags::none())
                .await
                .expect("test: open channel");

            assert_eq!(conn.num_open_channels(), 2);

            // ...then its close, because the client opened with the close flag
            // set.  Our end is still open, so the channel stays alive.
            let ev = conn
                .next_event()
                .await
                .expect("test: recv frame")
                .expect("test: recv event");
            eprintln!("got event: {ev:?}");
            assert!(
                matches!(ev, InbEvent::CloseChannel(0, true)),
                "unexpected event: {ev:?}"
            );

            assert_eq!(conn.num_open_channels(), 2);

            conn.close_channel(1).await.expect("test: close channel");

            eprintln!("closing channel 0 on the server side");
            conn.close_channel(0)
                .await
                .expect("test: close client chan");
        });

        let cj = tokio::spawn(async move {
            let mut conn = connect_tcp_tokio(socket_addr, proto, conn_opts, ())
                .await
                .expect("test: connect and handshake");

            eprintln!("opening channel, will im_close");
            let ch_idx = conn
                .open_channel(topic1, vec![1, 2, 3, 4], MsgFlags::close())
                .await
                .expect("test: open channel");
            assert_eq!(ch_idx, 0);

            assert_eq!(conn.num_open_channels(), 1);

            let ev = conn
                .next_event()
                .await
                .expect("test: recv frame")
                .expect("test: recv event");
            assert_eq!(ev.chan_id(), Some(1));

            assert_eq!(conn.num_open_channels(), 2);

            // Drain the server's two closes before dropping the connection.
            // Otherwise the server's close writes race against our socket
            // teardown and can fail with a reset/broken pipe.
            let ev = conn
                .next_event()
                .await
                .expect("test: recv close frame")
                .expect("test: recv close event");
            assert!(
                matches!(ev, InbEvent::CloseChannel(1, _)),
                "unexpected event: {ev:?}"
            );

            let ev = conn
                .next_event()
                .await
                .expect("test: recv close frame")
                .expect("test: recv close event");
            assert!(
                matches!(ev, InbEvent::CloseChannel(0, _)),
                "unexpected event: {ev:?}"
            );
        });

        lj.await.expect("test: server side");
        cj.await.expect("test: client side");
    }
}