Skip to main content

tehuti_socket/
lib.rs

1use std::{
2    collections::{BTreeMap, btree_map::Entry},
3    error::Error,
4    future::pending,
5    io::{Cursor, ErrorKind, Read, Write},
6    net::{TcpListener, TcpStream},
7    sync::Arc,
8    thread::{Builder, JoinHandle, sleep},
9    time::Duration,
10    vec,
11};
12use tehuti::{
13    engine::EnginePeerDescriptor,
14    event::{Duplex, Receiver, Sender, unbounded},
15    meeting::{Meeting, MeetingEngineEvent, MeetingInterface, MeetingInterfaceResult},
16    peer::{PeerFactory, PeerId},
17    protocol::{ProtocolControlFrame, ProtocolFrame, ProtocolPacketFrame},
18};
19
20pub struct TcpHost {
21    listener: TcpListener,
22    factory: Arc<PeerFactory>,
23}
24
25impl TcpHost {
26    pub fn make(listener: TcpListener, factory: Arc<PeerFactory>) -> Result<Self, Box<dyn Error>> {
27        listener.set_nonblocking(true)?;
28        Ok(TcpHost { listener, factory })
29    }
30
31    pub fn accept(&self) -> Result<Option<TcpSessionResult>, Box<dyn Error>> {
32        match self.listener.accept() {
33            Ok((stream, _)) => match TcpSession::make(stream, self.factory.clone()) {
34                Ok(session_result) => {
35                    return Ok(Some(session_result));
36                }
37                Err(err) => {
38                    tracing::event!(
39                        target: "tehuti::socket::host",
40                        tracing::Level::ERROR,
41                        "Failed to create session: {}",
42                        err,
43                    );
44                }
45            },
46            Err(ref e) if e.kind() == ErrorKind::WouldBlock => {}
47            Err(err) => {
48                tracing::event!(
49                    target: "tehuti::socket::host",
50                    tracing::Level::ERROR,
51                    "Failed to accept incoming connection: {}",
52                    err,
53                );
54            }
55        }
56        Ok(None)
57    }
58
59    pub async fn accept_async(&self) -> Result<TcpSessionResult, Box<dyn Error>> {
60        loop {
61            match self.accept() {
62                Ok(Some(session_result)) => {
63                    return Ok(session_result);
64                }
65                Ok(None) => {
66                    pending::<()>().await;
67                }
68                Err(err) => {
69                    tracing::event!(
70                        target: "tehuti::socket::host",
71                        tracing::Level::ERROR,
72                        "Session listener encountered error: {}",
73                        err,
74                    );
75                    return Err(err);
76                }
77            }
78        }
79    }
80
81    pub fn run(
82        self,
83        interval: Duration,
84        session_interval: Duration,
85        meeting_sender: Sender<(MeetingInterface, Sender<()>)>,
86    ) -> Result<JoinHandle<()>, Box<dyn Error>> {
87        Ok(Builder::new()
88            .name("Session Listener".to_string())
89            .spawn(move || {
90                loop {
91                    match self.accept() {
92                        Ok(Some(session_result)) => {
93                            let (terminate_sender, terminate_receiver) = unbounded();
94                            let TcpSessionResult { session, interface } = session_result;
95                            if let Err(err) = meeting_sender.send((interface, terminate_sender)) {
96                                tracing::event!(
97                                    target: "tehuti::socket::host",
98                                    tracing::Level::ERROR,
99                                    "Failed to send meeting interface to engine: {}",
100                                    err,
101                                );
102                            }
103                            if let Err(err) = session.run(session_interval, terminate_receiver) {
104                                tracing::event!(
105                                    target: "tehuti::socket::host",
106                                    tracing::Level::ERROR,
107                                    "Failed to run session: {}",
108                                    err,
109                                );
110                            }
111                        }
112                        Ok(None) => {}
113                        Err(err) => {
114                            tracing::event!(
115                                target: "tehuti::socket::host",
116                                tracing::Level::ERROR,
117                                "Session listener encountered error: {}",
118                                err,
119                            );
120                        }
121                    }
122                    sleep(interval);
123                }
124            })?)
125    }
126}
127
128pub struct TcpSessionResult {
129    pub session: TcpSession,
130    pub interface: MeetingInterface,
131}
132
133pub struct TcpSession {
134    id: String,
135    stream: TcpStream,
136    meeting: Meeting,
137    engine_event: Duplex<MeetingEngineEvent>,
138    peers: BTreeMap<PeerId, EnginePeerDescriptor>,
139    buffer_in: Vec<u8>,
140    buffer_out: Vec<u8>,
141    terminated: bool,
142}
143
144impl TcpSession {
145    pub fn make(
146        stream: TcpStream,
147        factory: Arc<PeerFactory>,
148    ) -> Result<TcpSessionResult, Box<dyn Error>> {
149        let id = format!("{}<->{}", stream.local_addr()?, stream.peer_addr()?);
150        let MeetingInterfaceResult {
151            meeting,
152            interface,
153            engine_event,
154        } = MeetingInterface::make(factory, id.clone());
155        stream.set_nonblocking(true)?;
156        Ok(TcpSessionResult {
157            session: TcpSession {
158                id,
159                stream,
160                meeting,
161                engine_event,
162                peers: Default::default(),
163                buffer_in: Default::default(),
164                buffer_out: Default::default(),
165                terminated: false,
166            },
167            interface,
168        })
169    }
170
171    pub fn maintain(&mut self) -> Result<(), Box<dyn Error>> {
172        if self.terminated {
173            return Err(format!("Session {} is terminated", self.id).into());
174        }
175        self.receive_frames()?;
176        self.send_frames()?;
177        Ok(())
178    }
179
180    fn receive_frames(&mut self) -> Result<(), Box<dyn Error>> {
181        let mut buffer = vec![0u8; 4096];
182        loop {
183            match self.stream.read(&mut buffer) {
184                Ok(0) => break,
185                Ok(n) => {
186                    self.buffer_in.extend_from_slice(&buffer[..n]);
187                }
188                Err(ref e) if e.kind() == ErrorKind::WouldBlock => break,
189                Err(e) => return Err(Box::new(e)),
190            }
191        }
192        let mut cursor = Cursor::new(&self.buffer_in);
193        let mut frames = Vec::new();
194        loop {
195            match ProtocolFrame::read(&mut cursor) {
196                Ok(frame) => frames.push(frame),
197                Err(ref e) if e.kind() == ErrorKind::UnexpectedEof => break,
198                Err(e) => return Err(Box::new(e)),
199            }
200        }
201        let pos = cursor.position() as usize;
202        self.buffer_in.drain(0..pos);
203        for frame in frames {
204            match frame {
205                ProtocolFrame::Control(frame) => match frame {
206                    ProtocolControlFrame::CreatePeer(peer_id, peer_role_id) => {
207                        tracing::event!(
208                            target: "tehuti::socket::session",
209                            tracing::Level::TRACE,
210                            "Session {} got create peer {:?} with role {:?}",
211                            self.id,
212                            peer_id,
213                            peer_role_id,
214                        );
215                        self.engine_event
216                            .sender
217                            .send(MeetingEngineEvent::PeerJoined(peer_id, peer_role_id))
218                            .map_err(|err| {
219                                format!("Session {} outside engine sender error: {err}", self.id)
220                            })
221                            .unwrap();
222                    }
223                    ProtocolControlFrame::DestroyPeer(peer_id) => {
224                        tracing::event!(
225                            target: "tehuti::socket::session",
226                            tracing::Level::TRACE,
227                            "Session {} got destroy peer {:?}",
228                            self.id,
229                            peer_id,
230                        );
231                        self.engine_event
232                            .sender
233                            .send(MeetingEngineEvent::PeerLeft(peer_id))
234                            .map_err(|err| {
235                                format!("Session {} outside engine sender error: {err}", self.id)
236                            })
237                            .unwrap();
238                    }
239                    _ => {
240                        tracing::event!(
241                            target: "tehuti::socket::session",
242                            tracing::Level::WARN,
243                            "Session {} got unhandled control frame: {:?}",
244                            self.id,
245                            frame,
246                        );
247                    }
248                },
249                ProtocolFrame::Packet(frame) => {
250                    if let Some(peer) = self.peers.get(&frame.peer_id) {
251                        if let Some(sender) = peer.packet_senders.get(&frame.channel_id) {
252                            tracing::event!(
253                                target: "tehuti::socket::session",
254                                tracing::Level::TRACE,
255                                "Session {} got packet frame for peer {:?} channel {:?}: {} bytes",
256                                self.id,
257                                frame.peer_id,
258                                frame.channel_id,
259                                frame.data.len(),
260                            );
261                            sender
262                                .sender
263                                .send(frame.data)
264                                .map_err(|err| {
265                                    format!("Session {} packet sender error: {err}", self.id)
266                                })
267                                .unwrap();
268                        } else {
269                            tracing::event!(
270                                target: "tehuti::socket::session",
271                                tracing::Level::WARN,
272                                "Session {} got packet frame for unknown channel {:?} of peer {:?}",
273                                self.id,
274                                frame.channel_id,
275                                frame.peer_id,
276                            );
277                        }
278                    } else {
279                        tracing::event!(
280                            target: "tehuti::socket::session",
281                            tracing::Level::WARN,
282                            "Session {} got packet frame for unknown peer {:?}",
283                            self.id,
284                            frame.peer_id,
285                        );
286                    }
287                }
288            }
289        }
290        Ok(())
291    }
292
293    fn send_frames(&mut self) -> Result<(), Box<dyn Error>> {
294        for peer in self.peers.values() {
295            for (channel_id, receiver) in &peer.packet_receivers {
296                for data in receiver.receiver.iter() {
297                    tracing::event!(
298                        target: "tehuti::socket::session",
299                        tracing::Level::TRACE,
300                        "Session {} sending packet frame for peer {:?} channel {:?}: {} bytes",
301                        self.id,
302                        peer.info.peer_id,
303                        channel_id,
304                        data.len(),
305                    );
306                    ProtocolFrame::Packet(ProtocolPacketFrame {
307                        peer_id: peer.info.peer_id,
308                        channel_id: *channel_id,
309                        data,
310                    })
311                    .write(&mut self.buffer_out)?;
312                }
313            }
314        }
315        if let Err(err) = self.meeting.pump_all() {
316            tracing::event!(
317                target: "tehuti::socket::session",
318                tracing::Level::ERROR,
319                "Session {} encountered error: {}. Terminating",
320                self.id,
321                err,
322            );
323            self.terminated = true;
324            return Err(err);
325        }
326        for event in self.engine_event.receiver.iter() {
327            match event {
328                MeetingEngineEvent::MeetingDestroyed => {
329                    tracing::event!(
330                        target: "tehuti::socket::session",
331                        tracing::Level::TRACE,
332                        "Session {} terminating",
333                        self.id,
334                    );
335                    self.terminated = true;
336                    return Err(format!("Session {} meeting destroyed", self.id).into());
337                }
338                MeetingEngineEvent::PeerCreated(descriptor) => {
339                    if let Entry::Vacant(entry) = self.peers.entry(descriptor.info.peer_id) {
340                        if !descriptor.info.remote {
341                            ProtocolFrame::Control(ProtocolControlFrame::CreatePeer(
342                                descriptor.info.peer_id,
343                                descriptor.info.role_id,
344                            ))
345                            .write(&mut self.buffer_out)?;
346                        }
347                        tracing::event!(
348                            target: "tehuti::socket::session",
349                            tracing::Level::TRACE,
350                            "Session {} created peer {:?}",
351                            self.id,
352                            descriptor.info.peer_id,
353                        );
354                        entry.insert(descriptor);
355                    } else {
356                        tracing::event!(
357                            target: "tehuti::socket::session",
358                            tracing::Level::WARN,
359                            "Session {} got duplicate peer {:?} created",
360                            self.id,
361                            descriptor.info.peer_id,
362                        );
363                    }
364                }
365                MeetingEngineEvent::PeerDestroyed(peer_id) => {
366                    if self.peers.contains_key(&peer_id) {
367                        ProtocolFrame::Control(ProtocolControlFrame::DestroyPeer(peer_id))
368                            .write(&mut self.buffer_out)?;
369                        tracing::event!(
370                            target: "tehuti::socket::session",
371                            tracing::Level::TRACE,
372                            "Session {} destroyed peer {:?}",
373                            self.id,
374                            peer_id,
375                        );
376                        self.peers.remove(&peer_id);
377                    } else {
378                        tracing::event!(
379                            target: "tehuti::socket::session",
380                            tracing::Level::WARN,
381                            "Session {} got unknown peer {:?} destroyed",
382                            self.id,
383                            peer_id,
384                        );
385                    }
386                }
387                event => {
388                    tracing::event!(
389                        target: "tehuti::socket::session",
390                        tracing::Level::WARN,
391                        "Session {} got unhandled engine event: {:?}",
392                        self.id,
393                        event,
394                    );
395                }
396            }
397        }
398        loop {
399            match self.stream.write(&self.buffer_out) {
400                Ok(0) => break,
401                Ok(n) => {
402                    self.buffer_out.drain(0..n);
403                }
404                Err(ref e) if e.kind() == ErrorKind::WouldBlock => break,
405                Err(e) => return Err(Box::new(e)),
406            }
407        }
408        Ok(())
409    }
410
411    pub fn run(
412        mut self,
413        interval: Duration,
414        terminate_receiver: Receiver<()>,
415    ) -> Result<JoinHandle<()>, Box<dyn Error>> {
416        Ok(Builder::new()
417            .name(format!("Session {}", self.id))
418            .spawn(move || {
419                loop {
420                    if terminate_receiver.try_recv().is_some() {
421                        tracing::event!(
422                            target: "tehuti::socket::session",
423                            tracing::Level::TRACE,
424                            "Session {} terminating on request",
425                            self.id,
426                        );
427                        break;
428                    }
429                    if let Err(err) = self.maintain() {
430                        tracing::event!(
431                            target: "tehuti::socket::session",
432                            tracing::Level::ERROR,
433                            "Session {} terminated with error: {}",
434                            self.id,
435                            err,
436                        );
437                        break;
438                    }
439                    sleep(interval);
440                }
441            })?)
442    }
443
444    pub async fn into_future(mut self) -> Result<(), Box<dyn Error>> {
445        loop {
446            if let Err(err) = self.maintain() {
447                tracing::event!(
448                    target: "tehuti::socket::session",
449                    tracing::Level::ERROR,
450                    "Session {} terminated with error: {}",
451                    self.id,
452                    err,
453                );
454                break;
455            }
456            pending::<()>().await;
457        }
458        Ok(())
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;
    use tehuti::{
        channel::{ChannelId, ChannelMode},
        meeting::MeetingUserEvent,
        peer::{PeerBuilder, PeerDestructurer, PeerRoleId, TypedPeer},
    };
    use tehuti_mock::{mock_env_tracing, mock_recv_matching};

    /// Test peer with one reliable, ordered `String` channel in each direction.
    struct Chatter {
        pub sender: Sender<String>,
        pub receiver: Receiver<String>,
    }

    impl TypedPeer for Chatter {
        fn builder(builder: PeerBuilder) -> PeerBuilder {
            builder.bind_read_write::<String, String>(
                ChannelId::new(0),
                ChannelMode::ReliableOrdered,
                None,
            )
        }

        fn into_typed(mut destructurer: PeerDestructurer) -> Result<Self, Box<dyn Error>> {
            let sender = destructurer.write::<String>(ChannelId::new(0))?;
            let receiver = destructurer.read::<String>(ChannelId::new(0))?;
            Ok(Self { sender, receiver })
        }
    }

    /// End-to-end check: a peer created on the server appears on the client, and a message
    /// sent by the server-side peer reaches the client-side peer over TCP.
    #[test]
    fn test_tcp_session_creation() {
        mock_env_tracing();

        let factory = Arc::new(PeerFactory::default().with(PeerRoleId::new(0), Chatter::builder));

        // Bind an ephemeral port so concurrent runs cannot collide on a fixed port.
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let address = listener.local_addr().unwrap();
        let host = TcpHost::make(listener, factory.clone()).unwrap();
        let stream = TcpStream::connect(address).unwrap();

        let TcpSessionResult {
            session,
            interface: meeting_client,
        } = TcpSession::make(stream, factory).unwrap();
        let (terminate_client, terminate_receiver) = unbounded();
        let session_client = session.run(Duration::ZERO, terminate_receiver).unwrap();

        // The listener is non-blocking: poll (bounded) until the connection is accepted
        // instead of relying on a fixed sleep.
        let deadline = Instant::now() + Duration::from_secs(1);
        let TcpSessionResult {
            session,
            interface: meeting_server,
        } = loop {
            if let Some(session_result) = host.accept().unwrap() {
                break session_result;
            }
            assert!(
                Instant::now() < deadline,
                "timed out waiting for the incoming connection"
            );
            sleep(Duration::from_millis(10));
        };
        let (terminate_server, terminate_receiver) = unbounded();
        let session_server = session.run(Duration::ZERO, terminate_receiver).unwrap();

        meeting_server
            .sender
            .send(MeetingUserEvent::PeerCreate(
                PeerId::new(0),
                PeerRoleId::new(0),
            ))
            .unwrap();

        let peer_server = mock_recv_matching!(
            meeting_server.receiver,
            Duration::from_secs(1),
            MeetingUserEvent::PeerAdded(peer) => peer
        )
        .into_typed::<Chatter>()
        .unwrap();

        let peer_client = mock_recv_matching!(
            meeting_client.receiver,
            Duration::from_secs(1),
            MeetingUserEvent::PeerAdded(peer) => peer
        )
        .into_typed::<Chatter>()
        .unwrap();

        peer_server
            .sender
            .send("Hello from server to client".to_owned())
            .unwrap();

        let msg = peer_client
            .receiver
            .recv_blocking_timeout(Duration::from_secs(1))
            .unwrap();
        assert_eq!(&msg, "Hello from server to client");

        terminate_client.send(()).unwrap();
        terminate_server.send(()).unwrap();
        session_client.join().unwrap();
        session_server.join().unwrap();
    }
}