Skip to main content

tehuti_socket/
lib.rs

1use std::{
2    collections::{BTreeMap, btree_map::Entry},
3    error::Error,
4    future::pending,
5    io::{Cursor, ErrorKind, Read, Write},
6    net::{TcpListener, TcpStream},
7    sync::Arc,
8    thread::{Builder, JoinHandle, sleep},
9    time::Duration,
10    vec,
11};
12use tehuti::{
13    engine::EnginePeerDescriptor,
14    event::{Duplex, Receiver, Sender, unbounded},
15    meeting::{Meeting, MeetingEngineEvent, MeetingInterface, MeetingInterfaceResult},
16    peer::{PeerFactory, PeerId},
17    protocol::{ProtocolControlFrame, ProtocolFrame, ProtocolPacketFrame},
18};
19
20pub struct TcpHost {
21    listener: TcpListener,
22    factory: Arc<PeerFactory>,
23}
24
25impl TcpHost {
26    pub fn make(listener: TcpListener, factory: Arc<PeerFactory>) -> Result<Self, Box<dyn Error>> {
27        listener.set_nonblocking(true)?;
28        Ok(TcpHost { listener, factory })
29    }
30
31    pub fn accept(&self) -> Result<Option<TcpSessionResult>, Box<dyn Error>> {
32        match self.listener.accept() {
33            Ok((stream, _)) => match TcpSession::make(stream, self.factory.clone()) {
34                Ok(session_result) => {
35                    return Ok(Some(session_result));
36                }
37                Err(err) => {
38                    tracing::event!(
39                        target: "tehuti::socket::host",
40                        tracing::Level::ERROR,
41                        "Failed to create session: {}",
42                        err,
43                    );
44                }
45            },
46            Err(ref e) if e.kind() == ErrorKind::WouldBlock => {}
47            Err(err) => {
48                tracing::event!(
49                    target: "tehuti::socket::host",
50                    tracing::Level::ERROR,
51                    "Failed to accept incoming connection: {}",
52                    err,
53                );
54            }
55        }
56        Ok(None)
57    }
58
59    pub async fn accept_async(&self) -> Result<TcpSessionResult, Box<dyn Error>> {
60        loop {
61            match self.accept() {
62                Ok(Some(session_result)) => {
63                    return Ok(session_result);
64                }
65                Ok(None) => {
66                    pending::<()>().await;
67                }
68                Err(err) => {
69                    tracing::event!(
70                        target: "tehuti::socket::host",
71                        tracing::Level::ERROR,
72                        "Session listener encountered error: {}",
73                        err,
74                    );
75                    return Err(err);
76                }
77            }
78        }
79    }
80
81    pub fn run(
82        self,
83        interval: Duration,
84        session_interval: Duration,
85        meeting_sender: Sender<(MeetingInterface, Sender<()>)>,
86    ) -> Result<JoinHandle<()>, Box<dyn Error>> {
87        Ok(Builder::new()
88            .name("Session Listener".to_string())
89            .spawn(move || {
90                loop {
91                    match self.accept() {
92                        Ok(Some(session_result)) => {
93                            let (terminate_sender, terminate_receiver) = unbounded();
94                            let TcpSessionResult { session, interface } = session_result;
95                            if let Err(err) = meeting_sender.send((interface, terminate_sender)) {
96                                tracing::event!(
97                                    target: "tehuti::socket::host",
98                                    tracing::Level::ERROR,
99                                    "Failed to send meeting interface to engine: {}",
100                                    err,
101                                );
102                            }
103                            if let Err(err) = session.run(session_interval, terminate_receiver) {
104                                tracing::event!(
105                                    target: "tehuti::socket::host",
106                                    tracing::Level::ERROR,
107                                    "Failed to run session: {}",
108                                    err,
109                                );
110                            }
111                        }
112                        Ok(None) => {}
113                        Err(err) => {
114                            tracing::event!(
115                                target: "tehuti::socket::host",
116                                tracing::Level::ERROR,
117                                "Session listener encountered error: {}",
118                                err,
119                            );
120                        }
121                    }
122                    sleep(interval);
123                }
124            })?)
125    }
126}
127
128pub struct TcpSessionResult {
129    pub session: TcpSession,
130    pub interface: MeetingInterface,
131}
132
133pub struct TcpSession {
134    id: String,
135    stream: TcpStream,
136    meeting: Meeting,
137    engine_event: Duplex<MeetingEngineEvent>,
138    peers: BTreeMap<PeerId, EnginePeerDescriptor>,
139    buffer_in: Vec<u8>,
140    buffer_out: Vec<u8>,
141    terminated: bool,
142}
143
144impl TcpSession {
145    pub fn make(
146        stream: TcpStream,
147        factory: Arc<PeerFactory>,
148    ) -> Result<TcpSessionResult, Box<dyn Error>> {
149        let id = format!("{}<->{}", stream.local_addr()?, stream.peer_addr()?);
150        let MeetingInterfaceResult {
151            meeting,
152            interface,
153            engine_event,
154        } = MeetingInterface::make(factory, id.clone());
155        stream.set_nonblocking(true)?;
156        Ok(TcpSessionResult {
157            session: TcpSession {
158                id,
159                stream,
160                meeting,
161                engine_event,
162                peers: Default::default(),
163                buffer_in: Default::default(),
164                buffer_out: Default::default(),
165                terminated: false,
166            },
167            interface,
168        })
169    }
170
171    pub fn maintain(&mut self) -> Result<(), Box<dyn Error>> {
172        if self.terminated {
173            return Err(format!("Session {} is terminated", self.id).into());
174        }
175        self.receive_frames()?;
176        self.send_frames()?;
177        Ok(())
178    }
179
180    fn receive_frames(&mut self) -> Result<(), Box<dyn Error>> {
181        let mut buffer = vec![0u8; 4096];
182        loop {
183            match self.stream.read(&mut buffer) {
184                Ok(0) => break,
185                Ok(n) => {
186                    self.buffer_in.extend_from_slice(&buffer[..n]);
187                }
188                Err(ref e) if e.kind() == ErrorKind::WouldBlock => break,
189                Err(e) => return Err(Box::new(e)),
190            }
191        }
192        let mut cursor = Cursor::new(&self.buffer_in);
193        let mut frames = Vec::new();
194        loop {
195            match ProtocolFrame::read(&mut cursor) {
196                Ok(frame) => frames.push(frame),
197                Err(ref e) if e.kind() == ErrorKind::UnexpectedEof => break,
198                Err(e) => return Err(Box::new(e)),
199            }
200        }
201        let pos = cursor.position() as usize;
202        self.buffer_in.drain(0..pos);
203        for frame in frames {
204            match frame {
205                ProtocolFrame::Control(frame) => match frame {
206                    ProtocolControlFrame::CreatePeer(peer_id, peer_role_id) => {
207                        self.engine_event
208                            .sender
209                            .send(MeetingEngineEvent::PeerJoined(peer_id, peer_role_id))
210                            .map_err(|err| {
211                                format!("Session {} outside engine sender error: {err}", self.id)
212                            })
213                            .unwrap();
214                    }
215                    ProtocolControlFrame::DestroyPeer(peer_id) => {
216                        self.engine_event
217                            .sender
218                            .send(MeetingEngineEvent::PeerLeft(peer_id))
219                            .map_err(|err| {
220                                format!("Session {} outside engine sender error: {err}", self.id)
221                            })
222                            .unwrap();
223                    }
224                    _ => {
225                        tracing::event!(
226                            target: "tehuti::socket::session",
227                            tracing::Level::WARN,
228                            "Session {} got unhandled control frame: {:?}",
229                            self.id,
230                            frame,
231                        );
232                    }
233                },
234                ProtocolFrame::Packet(frame) => {
235                    if let Some(peer) = self.peers.get(&frame.peer_id) {
236                        if let Some(sender) = peer.packet_senders.get(&frame.channel_id) {
237                            sender
238                                .sender
239                                .send(frame.data)
240                                .map_err(|err| {
241                                    format!("Session {} packet sender error: {err}", self.id)
242                                })
243                                .unwrap();
244                        } else {
245                            tracing::event!(
246                                target: "tehuti::socket::session",
247                                tracing::Level::WARN,
248                                "Session {} got packet frame for unknown channel {:?} of peer {:?}",
249                                self.id,
250                                frame.channel_id,
251                                frame.peer_id,
252                            );
253                        }
254                    } else {
255                        tracing::event!(
256                            target: "tehuti::socket::session",
257                            tracing::Level::WARN,
258                            "Session {} got packet frame for unknown peer {:?}",
259                            self.id,
260                            frame.peer_id,
261                        );
262                    }
263                }
264            }
265        }
266        Ok(())
267    }
268
269    fn send_frames(&mut self) -> Result<(), Box<dyn Error>> {
270        for peer in self.peers.values() {
271            for (channel_id, receiver) in &peer.packet_receivers {
272                for data in receiver.receiver.iter() {
273                    ProtocolFrame::Packet(ProtocolPacketFrame {
274                        peer_id: peer.info.peer_id,
275                        channel_id: *channel_id,
276                        data,
277                    })
278                    .write(&mut self.buffer_out)?;
279                }
280            }
281        }
282        if let Err(err) = self.meeting.pump_all() {
283            tracing::event!(
284                target: "tehuti::socket::session",
285                tracing::Level::ERROR,
286                "Session {} encountered error: {}. Terminating",
287                self.id,
288                err,
289            );
290            self.terminated = true;
291            return Err(err);
292        }
293        for event in self.engine_event.receiver.iter() {
294            match event {
295                MeetingEngineEvent::MeetingDestroyed => {
296                    tracing::event!(
297                        target: "tehuti::socket::session",
298                        tracing::Level::TRACE,
299                        "Session {} terminating",
300                        self.id,
301                    );
302                    self.terminated = true;
303                    return Err(format!("Session {} meeting destroyed", self.id).into());
304                }
305                MeetingEngineEvent::PeerCreated(descriptor) => {
306                    if let Entry::Vacant(entry) = self.peers.entry(descriptor.info.peer_id) {
307                        if !descriptor.info.remote {
308                            ProtocolFrame::Control(ProtocolControlFrame::CreatePeer(
309                                descriptor.info.peer_id,
310                                descriptor.info.role_id,
311                            ))
312                            .write(&mut self.buffer_out)?;
313                        }
314                        tracing::event!(
315                            target: "tehuti::socket::session",
316                            tracing::Level::TRACE,
317                            "Session {} created peer {:?}",
318                            self.id,
319                            descriptor.info.peer_id,
320                        );
321                        entry.insert(descriptor);
322                    } else {
323                        tracing::event!(
324                            target: "tehuti::socket::session",
325                            tracing::Level::WARN,
326                            "Session {} got duplicate peer {:?} created",
327                            self.id,
328                            descriptor.info.peer_id,
329                        );
330                    }
331                }
332                MeetingEngineEvent::PeerDestroyed(peer_id) => {
333                    if self.peers.contains_key(&peer_id) {
334                        ProtocolFrame::Control(ProtocolControlFrame::DestroyPeer(peer_id))
335                            .write(&mut self.buffer_out)?;
336                        tracing::event!(
337                            target: "tehuti::socket::session",
338                            tracing::Level::TRACE,
339                            "Session {} destroyed peer {:?}",
340                            self.id,
341                            peer_id,
342                        );
343                        self.peers.remove(&peer_id);
344                    } else {
345                        tracing::event!(
346                            target: "tehuti::socket::session",
347                            tracing::Level::WARN,
348                            "Session {} got unknown peer {:?} destroyed",
349                            self.id,
350                            peer_id,
351                        );
352                    }
353                }
354                event => {
355                    tracing::event!(
356                        target: "tehuti::socket::session",
357                        tracing::Level::WARN,
358                        "Session {} got unhandled engine event: {:?}",
359                        self.id,
360                        event,
361                    );
362                }
363            }
364        }
365        loop {
366            match self.stream.write(&self.buffer_out) {
367                Ok(0) => break,
368                Ok(n) => {
369                    self.buffer_out.drain(0..n);
370                }
371                Err(ref e) if e.kind() == ErrorKind::WouldBlock => break,
372                Err(e) => return Err(Box::new(e)),
373            }
374        }
375        Ok(())
376    }
377
378    pub fn run(
379        mut self,
380        interval: Duration,
381        terminate_receiver: Receiver<()>,
382    ) -> Result<JoinHandle<()>, Box<dyn Error>> {
383        Ok(Builder::new()
384            .name(format!("Session {}", self.id))
385            .spawn(move || {
386                loop {
387                    if terminate_receiver.try_recv().is_some() {
388                        tracing::event!(
389                            target: "tehuti::socket::session",
390                            tracing::Level::TRACE,
391                            "Session {} terminating on request",
392                            self.id,
393                        );
394                        break;
395                    }
396                    if let Err(err) = self.maintain() {
397                        tracing::event!(
398                            target: "tehuti::socket::session",
399                            tracing::Level::ERROR,
400                            "Session {} terminated with error: {}",
401                            self.id,
402                            err,
403                        );
404                        break;
405                    }
406                    sleep(interval);
407                }
408            })?)
409    }
410
411    pub async fn into_future(mut self) -> Result<(), Box<dyn Error>> {
412        loop {
413            if let Err(err) = self.maintain() {
414                tracing::event!(
415                    target: "tehuti::socket::session",
416                    tracing::Level::ERROR,
417                    "Session {} terminated with error: {}",
418                    self.id,
419                    err,
420                );
421                break;
422            }
423            pending::<()>().await;
424        }
425        Ok(())
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;
    use tehuti::{
        channel::{ChannelId, ChannelMode},
        meeting::MeetingUserEvent,
        peer::{PeerBuilder, PeerDestructurer, PeerRoleId, TypedPeer},
    };
    use tehuti_mock::{mock_env_tracing, mock_recv_matching};

    /// Test peer with one reliable-ordered `String` channel (id 0) used for both
    /// reading and writing.
    struct Chatter {
        pub sender: Sender<String>,
        pub receiver: Receiver<String>,
    }

    impl TypedPeer for Chatter {
        fn builder(builder: PeerBuilder) -> PeerBuilder {
            builder.bind_read_write::<String, String>(
                ChannelId::new(0),
                ChannelMode::ReliableOrdered,
                None,
            )
        }

        fn into_typed(mut destructurer: PeerDestructurer) -> Result<Self, Box<dyn Error>> {
            let sender = destructurer.write::<String>(ChannelId::new(0))?;
            let receiver = destructurer.read::<String>(ChannelId::new(0))?;
            Ok(Self { sender, receiver })
        }
    }

    /// Polls `host` until it produces a session, panicking once `timeout` has elapsed.
    fn accept_within(host: &TcpHost, timeout: Duration) -> TcpSessionResult {
        let deadline = Instant::now() + timeout;
        loop {
            if let Some(result) = host.accept().unwrap() {
                return result;
            }
            assert!(
                Instant::now() < deadline,
                "no incoming connection within {timeout:?}"
            );
            sleep(Duration::from_millis(10));
        }
    }

    #[test]
    fn test_tcp_session_creation() {
        mock_env_tracing();

        let factory = Arc::new(PeerFactory::default().with(PeerRoleId::new(0), Chatter::builder));

        // Bind to an ephemeral port so parallel test runs (or other local services)
        // cannot collide on a fixed port.
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let address = listener.local_addr().unwrap();
        let host = TcpHost::make(listener, factory.clone()).unwrap();
        let stream = TcpStream::connect(address).unwrap();

        let TcpSessionResult {
            session,
            interface: meeting_client,
        } = TcpSession::make(stream, factory).unwrap();
        let (terminate_client, terminate_receiver) = unbounded();
        let session_client = session.run(Duration::ZERO, terminate_receiver).unwrap();

        let TcpSessionResult {
            session,
            interface: meeting_server,
        } = accept_within(&host, Duration::from_secs(1));
        let (terminate_server, terminate_receiver) = unbounded();
        let session_server = session.run(Duration::ZERO, terminate_receiver).unwrap();

        meeting_server
            .sender
            .send(MeetingUserEvent::PeerCreate(
                PeerId::new(0),
                PeerRoleId::new(0),
            ))
            .unwrap();

        let peer_server = mock_recv_matching!(
            meeting_server.receiver,
            Duration::from_secs(1),
            MeetingUserEvent::PeerAdded(peer) => peer
        )
        .into_typed::<Chatter>()
        .unwrap();

        let peer_client = mock_recv_matching!(
            meeting_client.receiver,
            Duration::from_secs(1),
            MeetingUserEvent::PeerAdded(peer) => peer
        )
        .into_typed::<Chatter>()
        .unwrap();

        peer_server
            .sender
            .send("Hello from server to client".to_owned())
            .unwrap();

        let msg = peer_client
            .receiver
            .recv_blocking_timeout(Duration::from_secs(1))
            .unwrap();
        assert_eq!(&msg, "Hello from server to client");

        // A session may already have stopped on its own (e.g. after its peer hung up),
        // so its terminate channel can be closed. Only the thread joins must succeed.
        let _ = terminate_client.send(());
        let _ = terminate_server.send(());
        session_client.join().unwrap();
        session_server.join().unwrap();
    }
}