1use std::collections::*;
4use std::time;
5
6use ecksport_core::peer::PeerData;
7
8use ecksport_core::frame::{CloseData, FrameBody, MsgFlags, NotificationData, OpenData, PushData};
9use ecksport_core::state_mach::{ClientMeta, ServerMeta};
10use ecksport_core::topic;
11use ecksport_core::traits::{AsyncRecvFrame, AsyncSendFrame, AuthConfig};
12
13use crate::channel_state::{self, Creator};
14use crate::errors::Error;
15use crate::event::{InbEvent, PushFlags};
16use crate::handshake::{self, do_server_handshake_async};
17
/// An established ecksport connection over a framed transport `T`.
///
/// Holds the transport together with the per-channel state for the channels
/// carried over it, and buffers inbound events that have been decoded from
/// frames but not yet handed to the caller.
pub struct Connection<T> {
    /// Underlying framed transport that frames are sent on and received from.
    inner: T,

    /// Protocol topic of this connection, as reported by the handshake.
    protocol: topic::Topic,

    /// Which side initiated the connection: `Creator::Local` for connections
    /// created by `perform_handshake_async`, `Creator::Remote` for ones created
    /// by `accept_connection_async`.
    initiator: Creator,

    /// Information about the remote peer, taken from the completed handshake.
    peer: PeerData,

    /// Open/closed bookkeeping for every channel on this connection.
    chan_tbl: channel_state::ChannelTable,

    /// Inbound events decoded from frames but not yet returned to the caller.
    event_queue: VecDeque<InbEvent>,
}
39
40impl<T> Connection<T> {
41 fn new(inner: T, initiator: Creator, protocol: topic::Topic, peer: PeerData) -> Self {
43 Self {
44 inner,
45 protocol,
46 initiator,
47 peer,
48 chan_tbl: channel_state::ChannelTable::new(initiator),
49 event_queue: VecDeque::new(),
50 }
51 }
52
53 pub fn inner(&self) -> &T {
54 &self.inner
55 }
56
57 pub fn inner_mut(&mut self) -> &mut T {
60 &mut self.inner
61 }
62
63 pub fn into_inner(self) -> T {
64 self.inner
65 }
66
67 pub fn protocol(&self) -> topic::Topic {
68 self.protocol
69 }
70
71 pub fn initiator(&self) -> Creator {
72 self.initiator
73 }
74
75 pub fn peer_data(&self) -> &PeerData {
76 &self.peer
77 }
78
79 pub fn num_open_channels(&self) -> usize {
83 self.chan_tbl.num_open_channels()
84 }
85
86 pub fn has_pending_events(&self) -> bool {
88 !self.event_queue.is_empty()
89 }
90
91 fn handle_frame(&mut self, frame: FrameBody) -> Result<(), Error> {
92 match frame {
93 FrameBody::OpenChan(open) => {
94 let close = open.close();
96 let flags = PushFlags::from(open.flags());
97 let id = self.chan_tbl.init_remote_chan(open.topic(), close);
98
99 self.event_queue.push_back(InbEvent::NewChannel(
101 id,
102 open.topic(),
103 flags,
104 open.into_payload(),
105 ));
106 if close {
107 self.event_queue.push_back(InbEvent::CloseChannel(id, true))
108 }
109
110 Ok(())
111 }
112
113 FrameBody::PushChan(push) => {
114 let id = push.chan_id();
115 self.chan_tbl.check_recv_on_chan(id)?;
116 let flags = PushFlags::from(push.flags());
117 let close = push.close();
118
119 self.event_queue
120 .push_back(InbEvent::PushChannel(id, flags, push.into_payload()));
121
122 if close {
123 let removed = self
124 .chan_tbl
125 .mark_chan_remote_closed(id)
126 .expect("connection: close remote");
127
128 self.event_queue
129 .push_back(InbEvent::CloseChannel(id, !removed));
130 }
131
132 Ok(())
133 }
134
135 FrameBody::CloseChan(close) => {
136 self.chan_tbl.check_recv_on_chan(close.chan_id())?;
137
138 let removed = self
139 .chan_tbl
140 .mark_chan_remote_closed(close.chan_id())
141 .expect("connection: close remote");
142
143 self.event_queue
144 .push_back(InbEvent::CloseChannel(close.chan_id(), !removed));
145 Ok(())
146 }
147
148 FrameBody::Notification(notif) => {
149 let topic = notif.topic();
150 self.event_queue
151 .push_back(InbEvent::Notification(topic, notif.into_payload()));
152 Ok(())
153 }
154
155 _ => Err(Error::UnexpectedFrame(frame.ty())),
156 }
157 }
158}
159
160impl<T: AsyncRecvFrame> Connection<T> {
161 async fn recv_frame(&mut self) -> Result<(), Error> {
162 let frame = self.inner.recv_frame_async().await?;
163 self.handle_frame(frame)?;
164 Ok(())
165 }
166
167 pub async fn next_event(&mut self) -> Result<Option<InbEvent>, Error> {
170 if !self.event_queue.is_empty() {
171 return Ok(Some(self.event_queue.pop_front().unwrap()));
172 }
173
174 self.recv_frame().await?;
175 Ok(self.event_queue.pop_front())
176 }
177}
178
179impl<T: AsyncSendFrame> Connection<T> {
180 pub async fn open_channel(
183 &mut self,
184 topic: topic::Topic,
185 payload: Vec<u8>,
186 flags: MsgFlags,
187 ) -> Result<u32, Error> {
188 let im_close = flags.close;
189 let open_data = OpenData::new(topic, flags, payload);
190 let frame = FrameBody::OpenChan(open_data);
191 self.inner.send_frame_async(&frame).await?;
192 let id = self.chan_tbl.init_local_chan(topic, !im_close);
193 Ok(id)
194 }
195
196 pub async fn send_message(
199 &mut self,
200 chan_id: u32,
201 payload: Vec<u8>,
202 flags: MsgFlags,
203 ) -> Result<bool, Error> {
204 self.chan_tbl.check_send_on_chan(chan_id)?;
205
206 let push_data = PushData::new(chan_id, flags, payload);
207 let frame = FrameBody::PushChan(push_data);
208 self.inner.send_frame_async(&frame).await?;
209
210 if flags.close {
211 let removed = self
212 .chan_tbl
213 .mark_chan_local_closed(chan_id)
214 .expect("connection: close local");
215 Ok(!removed)
216 } else {
217 Ok(true)
218 }
219 }
220
221 pub async fn close_channel(&mut self, chan_id: u32) -> Result<bool, Error> {
224 self.chan_tbl.check_send_on_chan(chan_id)?;
225
226 let close_data = CloseData::new(chan_id);
227 let frame = FrameBody::CloseChan(close_data);
228 self.inner.send_frame_async(&frame).await?;
229
230 let removed = self
231 .chan_tbl
232 .mark_chan_local_closed(chan_id)
233 .expect("connection: close local");
234
235 Ok(!removed)
236 }
237
238 pub async fn send_notification(
240 &mut self,
241 topic: topic::Topic,
242 message: Vec<u8>,
243 ) -> Result<(), Error> {
244 let notif_data = NotificationData::new(topic, message);
245 let frame = FrameBody::Notification(notif_data);
246 self.inner.send_frame_async(&frame).await?;
247 Ok(())
248 }
249}
250
251#[derive(Clone, Debug)]
252pub struct ConnectOptions {
253 pub timeout: time::Duration,
254 pub client_meta: ClientMeta,
255}
256
257impl Default for ConnectOptions {
258 fn default() -> Self {
259 Self {
260 timeout: time::Duration::from_millis(15000),
261 client_meta: ClientMeta::new("/ecksport/alpha/".to_owned()),
262 }
263 }
264}
265
266pub async fn perform_handshake_async<
271 T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static,
272 A: AuthConfig,
273>(
274 mut stream: T,
275 protocol: topic::Topic,
276 opts: ConnectOptions,
277 auth: A,
278 peer: PeerData,
279) -> Result<Connection<T>, Error> {
280 let hs_opts = handshake::HandshakeOptions::new(opts.timeout);
281
282 let hs = handshake::do_client_handshake_async(
284 &mut stream,
285 protocol,
286 &opts.client_meta,
287 &hs_opts,
288 auth,
289 peer,
290 )
291 .await?;
292 assert_eq!(hs.ready().protocol(), protocol);
293
294 let peer = hs.into_peer();
295 Ok(Connection::new(stream, Creator::Local, protocol, peer))
296}
297
298#[derive(Clone, Debug)]
299pub struct AcceptOptions {
300 pub timeout: time::Duration,
301 pub server_meta: ServerMeta,
302}
303
304impl Default for AcceptOptions {
305 fn default() -> Self {
306 Self {
307 timeout: time::Duration::from_millis(15000),
308 server_meta: ServerMeta::new("/ecksport/alpha/".to_owned(), Vec::new()),
309 }
310 }
311}
312
313pub async fn accept_connection_async<
318 T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static,
319 A: AuthConfig,
320>(
321 mut stream: T,
322 opts: AcceptOptions,
323 auth: A,
324 peer: PeerData,
325) -> Result<Option<Connection<T>>, Error> {
326 let hs_opts = handshake::HandshakeOptions::new(opts.timeout);
327
328 let Some(hs) =
330 do_server_handshake_async(&mut stream, &opts.server_meta, &hs_opts, auth, peer).await?
331 else {
332 return Ok(None);
333 };
334
335 let proto = hs.ready().protocol();
336 let peer = hs.into_peer();
337
338 let conn = Connection::new(stream, Creator::Remote, proto, peer);
339 Ok(Some(conn))
340}
341
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use ecksport_core::peer::Location;
    use ecksport_core::stream_framing::StreamFramer;
    use ecksport_core::topic;
    use tokio::time::timeout;

    use super::*;

    /// Connection type produced by framing a tokio TCP stream.
    pub type TokioTcpConnection = Connection<StreamFramer<tokio::net::TcpStream>>;

    /// Dials `socket_addr` over TCP, with the dial bounded by `opts.timeout`,
    /// and runs the client handshake for `protocol` on the framed stream.
    async fn connect_tcp_tokio<A: AuthConfig>(
        socket_addr: SocketAddr,
        protocol: topic::Topic,
        opts: ConnectOptions,
        auth: A,
    ) -> Result<TokioTcpConnection, Error> {
        let socket_connect_fut = tokio::net::TcpStream::connect(socket_addr);
        let sock = match timeout(opts.timeout, socket_connect_fut).await {
            Ok(res) => res?,
            Err(_) => return Err(Error::ConnectionTimeout),
        };

        let peer = PeerData::new_loc(Location::Ip(socket_addr));
        let framer = StreamFramer::new(sock);
        perform_handshake_async(framer, protocol, opts, auth, peer).await
    }

    #[tokio::test]
    async fn test_connect_accept() {
        // Let the OS pick a free port so this test cannot collide with other
        // listeners or with parallel test runs.
        let lis = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("test: bind");
        let socket_addr = lis.local_addr().expect("test: local addr");

        let proto = topic::Topic::from_const_str("TESTTEST");
        let topic = topic::Topic::from_const_str("FOOOBARR");
        let topic2 = topic::Topic::from_const_str("BAZZQUUX");
        let mut acc_opts = AcceptOptions::default();
        acc_opts.server_meta.add_protocol(proto);
        let conn_opts = ConnectOptions::default();

        let lj = tokio::spawn(async move {
            let (sock, _sa) = lis.accept().await.expect("test: accept");
            let framer = StreamFramer::new(sock);
            let pd = PeerData::default();

            let mut conn = accept_connection_async(framer, acc_opts, (), pd)
                .await
                .expect("test: server handshake")
                .expect("test: create server connection");

            let ev = conn
                .next_event()
                .await
                .expect("test: accept event")
                .expect("test: read event");
            eprintln!("got event: {ev:?}");

            assert_eq!(conn.num_open_channels(), 1);

            conn.open_channel(topic2, vec![5, 6, 7, 8], MsgFlags::none())
                .await
                .expect("test: open channel");

            assert_eq!(conn.num_open_channels(), 2);

            // The client's open carried the close flag, so a `CloseChannel`
            // event is already buffered and no network read happens here.
            let ev = conn
                .next_event()
                .await
                .expect("test: recv frame")
                .expect("test: recv event");
            eprintln!("got event: {ev:?}");

            assert_eq!(conn.num_open_channels(), 2);

            conn.close_channel(1).await.expect("test: close channel");

            eprintln!("closing channel 0 on the server side");
            conn.close_channel(0)
                .await
                .expect("test: close client chan");
        });

        let cj = tokio::spawn(async move {
            let mut conn = connect_tcp_tokio(socket_addr, proto, conn_opts, ())
                .await
                .expect("test: connect and handshake");

            eprintln!("opening channel, will im_close");
            let ch_idx = conn
                .open_channel(topic, vec![1, 2, 3, 4], MsgFlags::close())
                .await
                .expect("test: open channel");
            assert_eq!(ch_idx, 0);

            assert_eq!(conn.num_open_channels(), 1);

            let ev = conn
                .next_event()
                .await
                .expect("test: recv frame")
                .expect("test: recv event");
            assert_eq!(ev.chan_id(), Some(1));

            assert_eq!(conn.num_open_channels(), 2);
        });

        lj.await.expect("test: server side");
        cj.await.expect("test: client side");
    }
}