Skip to main content

mill_net/tcp/
mod.rs

1//! TCP server and client implementations with lockfree connection management.
2//!
3//! This module provides high-performance TCP networking built on Mill-IO's event loop.
4//!
5//! Each TCP connection is assigned a unique ConnectionId generated atomically.
6//! The connection state is stored in a lockfree map, allowing concurrent access
7//! from multiple worker threads without blocking.
8//!
9//! ```text
10//! Connection Storage:
11//!   LockfreeMap<u64, TcpConnection>
12//!        │
13//!        ├--> ConnId(1) --> TcpConnection { stream, token, addr }
14//!        ├--> ConnId(2) --> TcpConnection { stream, token, addr }
15//!        └--> ConnId(N) --> TcpConnection { stream, token, addr }
16//! ```
17//!
18//! ## Event Handling Pipeline
19//!
20//! ```text
21//! 1. Listener Events:
22//!    New Connection --> TcpListenerHandler::handle_event()
23//!        - accept() --> Create ConnectionId
24//!        - Register TcpConnectionHandler with EventLoop
25//!        - Insert into LockfreeMap
26//!        - Call handler.on_connect()
27//!
28//! 2. Connection Events:
29//!    Readable Event --> TcpConnectionHandler::handle_event()
30//!        - Read from stream into pooled buffer
31//!        - Call handler.on_data()
32//!        - If EOF: disconnect()
33//!
34//!    Writable Event --> TcpConnectionHandler::handle_event()
35//!        - Call handler.on_writable()
36//!
37//! 3. Disconnection:
38//!    disconnect() --> Remove from LockfreeMap
39//!        - Deregister from EventLoop
40//!        - Call handler.on_disconnect()
41//! ```
42//!
43//! ## Configuration
44//!
45//! TcpServerConfig uses the builder pattern for ergonomic configuration:
46//!
47//! ```rust
48//! use mill_net::tcp::config::TcpServerConfig;
49//! # use std::sync::Arc;
50//!
51//! let config = TcpServerConfig::builder()
52//!     .address("0.0.0.0:8080".parse().unwrap())
53//!     .buffer_size(16384)              // Larger buffers for high throughput
54//!     .max_connections(1000)           // Limit concurrent connections
55//!     .no_delay(true)                  // Disable Nagle's algorithm
56//!     .build();
57//! ```
58//!
59//! ## Handler Implementation
60//!
//! Your handler must implement the `NetworkHandler` trait.
62//!
63//! ```rust
64//! use mill_net::tcp::{traits::{NetworkHandler, ConnectionId}, ServerContext};
65//! use mill_net::errors::Result;
66//!
67//! struct MyHandler;
68//!
69//! impl NetworkHandler for MyHandler {
70//!     fn on_data(&self, ctx: &ServerContext, conn_id: ConnectionId, data: &[u8]) -> Result<()> {
71//!         println!("Received {} bytes from {:?}", data.len(), conn_id);
72//!         ctx.send_to(conn_id, b"some response")?;
73//!         Ok(())
74//!     }
75//! }
76//! ```
77
78pub mod config;
79pub mod traits;
80
81pub use config::TcpServerConfig;
82pub use traits::*;
83
84use crate::errors::Result;
85use crate::errors::{NetworkError, NetworkEvent};
86use lock_freedom::map::Map as LockfreeMap;
87use mill_io::{EventHandler, EventLoop, ObjectPool, PooledObject};
88use mio::event::Event;
89use mio::net::{TcpListener, TcpStream};
90use mio::{Interest, Token};
91use parking_lot::Mutex;
92use std::io;
93use std::io::{Read, Write};
94use std::net::SocketAddr;
95use std::sync::{
96    atomic::{AtomicU64, AtomicUsize, Ordering},
97    Arc, RwLock, Weak,
98};
99
100/// Context for network handlers to interact with the server.
101pub struct ServerContext {
102    server: RwLock<Option<Arc<dyn ServerOperations>>>,
103    event_loop: RwLock<Option<Arc<EventLoop>>>,
104}
105
106impl ServerContext {
107    /// Send data to a specific connection.
108    pub fn send_to(&self, conn_id: ConnectionId, data: &[u8]) -> Result<()> {
109        if let Some(server) = self.server.read().unwrap().as_ref() {
110            server.send_to(conn_id, data)
111        } else {
112            Ok(())
113        }
114    }
115
116    /// Broadcast data to all connections.
117    pub fn broadcast(&self, data: &[u8]) -> Result<()> {
118        if let Some(server) = self.server.read().unwrap().as_ref() {
119            server.broadcast(data)
120        } else {
121            Ok(())
122        }
123    }
124
125    /// Close a specific connection.
126    pub fn close_connection(&self, conn_id: ConnectionId) -> Result<()> {
127        let server_guard = self.server.read().unwrap();
128        if let Some(server) = server_guard.as_ref() {
129            if let Some(event_loop) = self.event_loop.read().unwrap().as_ref() {
130                server.close_connection(event_loop, conn_id)
131            } else {
132                Ok(())
133            }
134        } else {
135            Ok(())
136        }
137    }
138}
139
140/// Trait defining server operations, allowing for a weak reference from the context.
141trait ServerOperations: Send + Sync {
142    fn send_to(&self, conn_id: ConnectionId, data: &[u8]) -> Result<()>;
143    fn broadcast(&self, data: &[u8]) -> Result<()>;
144    fn close_connection(&self, event_loop: &EventLoop, conn_id: ConnectionId) -> Result<()>;
145}
146
147/// High-level TCP server
148pub struct TcpServer<H: NetworkHandler> {
149    listener: Arc<Mutex<TcpListener>>,
150    connections: Arc<LockfreeMap<ConnectionId, TcpConnection>>,
151    handler: Arc<H>,
152    config: TcpServerConfig,
153    buffer_pool: ObjectPool<Vec<u8>>,
154    next_conn_id: Arc<AtomicU64>,
155    connection_counter: Arc<AtomicUsize>,
156    context: Arc<ServerContext>,
157}
158
159impl<H: NetworkHandler> TcpServer<H> {
160    pub fn new(config: TcpServerConfig, handler: H) -> Result<Self> {
161        let listener = TcpListener::bind(config.address)?;
162
163        Ok(Self {
164            listener: Arc::new(Mutex::new(listener)),
165            connections: Arc::new(LockfreeMap::new()),
166            handler: Arc::new(handler),
167            buffer_pool: ObjectPool::new(20, move || vec![0; config.buffer_size]),
168            next_conn_id: Arc::new(AtomicU64::new(1)),
169            connection_counter: Arc::new(AtomicUsize::new(0)),
170            config,
171            context: Arc::new(ServerContext {
172                server: RwLock::new(None),
173                event_loop: RwLock::new(None),
174            }),
175        })
176    }
177
178    /// Get the local address the server is bound to
179    pub fn local_addr(&self) -> Result<SocketAddr> {
180        Ok(self.listener.lock().local_addr()?)
181    }
182
183    /// Start the server by registering with the event loop
184    pub fn start(
185        self: Arc<Self>,
186        event_loop: &Arc<EventLoop>,
187        listener_token: Token,
188    ) -> Result<()> {
189        // Update the context.
190        *self.context.server.write().unwrap() = Some(self.clone());
191        *self.context.event_loop.write().unwrap() = Some(event_loop.clone());
192
193        let listener_handler = TcpListenerHandler {
194            listener: self.listener.clone(),
195            connections: self.connections.clone(),
196            handler: self.handler.clone(),
197            config: self.config.clone(),
198            buffer_pool: self.buffer_pool.clone(),
199            next_conn_id: self.next_conn_id.clone(),
200            event_loop: Arc::downgrade(event_loop),
201            connection_counter: self.connection_counter.clone(),
202            context: self.context.clone(),
203        };
204
205        event_loop.register(
206            &mut *self.listener.lock(),
207            listener_token,
208            Interest::READABLE,
209            listener_handler,
210        )?;
211
212        Ok(())
213    }
214
215    /// Get active connection count
216    pub fn connection_count(&self) -> usize {
217        self.connection_counter.load(Ordering::SeqCst)
218    }
219}
220
221impl<H: NetworkHandler> ServerOperations for TcpServer<H> {
222    /// Send data to a specific connection
223    fn send_to(&self, conn_id: ConnectionId, data: &[u8]) -> Result<()> {
224        if let Some(conn) = self.connections.get(&conn_id) {
225            let mut stream = conn.val().stream.lock();
226            stream.write_all(data)?;
227        } else {
228            return Err(Box::new(NetworkError::ConnectionNotFound(conn_id)));
229        }
230        Ok(())
231    }
232
233    /// Close a specific connection
234    fn close_connection(&self, event_loop: &EventLoop, conn_id: ConnectionId) -> Result<()> {
235        if let Some(conn) = self.connections.remove(&conn_id) {
236            let mut stream = conn.val().stream.lock();
237
238            let _ = event_loop.deregister(&mut *stream, conn.val().token);
239            let _ = stream.shutdown(std::net::Shutdown::Both);
240
241            if let Err(e) = self.handler.on_disconnect(&self.context, conn_id) {
242                self.handler.on_error(
243                    &self.context,
244                    Some(conn_id),
245                    NetworkError::HandlerError(format!("on_disconnect: {}", e)),
246                );
247            }
248
249            let _ = self
250                .handler
251                .on_event(&self.context, NetworkEvent::ConnectionClosed(conn_id));
252        }
253        Ok(())
254    }
255
256    /// Broadcast data to all connections
257    fn broadcast(&self, data: &[u8]) -> Result<()> {
258        for conn in self.connections.iter() {
259            let mut stream = conn.val().stream.lock();
260            if let Err(e) = stream.write_all(data) {
261                self.handler.on_error(
262                    &self.context,
263                    Some(*conn.key()),
264                    NetworkError::Io(Box::new(e)),
265                );
266            }
267        }
268        Ok(())
269    }
270}
271
272/// Internal TCP connection representation
273struct TcpConnection {
274    stream: Arc<Mutex<TcpStream>>,
275    token: Token,
276    #[allow(dead_code)]
277    peer_addr: SocketAddr,
278}
279
280/// Handler for accepting new connections
281struct TcpListenerHandler<H: NetworkHandler> {
282    listener: Arc<Mutex<TcpListener>>,
283    connections: Arc<LockfreeMap<ConnectionId, TcpConnection>>,
284    handler: Arc<H>,
285    config: TcpServerConfig,
286    buffer_pool: ObjectPool<Vec<u8>>,
287    next_conn_id: Arc<AtomicU64>,
288    event_loop: Weak<EventLoop>,
289    connection_counter: Arc<AtomicUsize>,
290    context: Arc<ServerContext>,
291}
292
293// Safety: We ensure event_loop_ref is valid for the lifetime of the handler
294unsafe impl<H: NetworkHandler> Send for TcpListenerHandler<H> {}
295unsafe impl<H: NetworkHandler> Sync for TcpListenerHandler<H> {}
296
297impl<H: NetworkHandler> EventHandler for TcpListenerHandler<H> {
298    fn handle_event(&self, event: &Event) {
299        if !event.is_readable() {
300            return;
301        }
302
303        loop {
304            let listener = self.listener.lock();
305
306            match listener.accept() {
307                Ok((stream, peer_addr)) => {
308                    // Atomically check and increment the connection count.
309                    if let Some(max) = self.config.max_connections {
310                        let mut accepted = false;
311                        loop {
312                            let current = self.connection_counter.load(Ordering::SeqCst);
313                            if current >= max {
314                                self.handler.on_error(
315                                    &self.context,
316                                    None,
317                                    NetworkError::MaxConnectionsReached(peer_addr),
318                                );
319                                break;
320                            }
321                            match self.connection_counter.compare_exchange(
322                                current,
323                                current + 1,
324                                Ordering::SeqCst,
325                                Ordering::SeqCst,
326                            ) {
327                                Ok(_) => {
328                                    accepted = true;
329                                    break;
330                                }
331                                Err(_) => continue,
332                            }
333                        }
334                        if !accepted {
335                            continue;
336                        }
337                    } else {
338                        self.connection_counter.fetch_add(1, Ordering::SeqCst);
339                    }
340
341                    if let Err(e) = stream.set_nodelay(self.config.no_delay) {
342                        self.handler.on_error(
343                            &self.context,
344                            None,
345                            NetworkError::Configuration(format!(
346                                "Failed to set TCP_NODELAY: {}",
347                                e
348                            )),
349                        );
350                    }
351
352                    let conn_id = ConnectionId(self.next_conn_id.fetch_add(1, Ordering::SeqCst));
353                    let token = Token(conn_id.as_u64() as usize);
354
355                    let stream_arc = Arc::new(Mutex::new(stream));
356
357                    let conn_handler = TcpConnectionHandler {
358                        conn_id,
359                        stream: stream_arc.clone(),
360                        connections: self.connections.clone(),
361                        handler: self.handler.clone(),
362                        buffer_pool: self.buffer_pool.clone(),
363                        event_loop: self.event_loop.clone(),
364                        connection_counter: self.connection_counter.clone(),
365                        context: self.context.clone(),
366                    };
367
368                    let event_loop = if let Some(arc) = self.event_loop.upgrade() {
369                        arc
370                    } else {
371                        self.handler
372                            .on_error(&self.context, None, NetworkError::EventLoopGone);
373                        self.connection_counter.fetch_sub(1, Ordering::SeqCst);
374                        continue;
375                    };
376
377                    if let Err(e) = event_loop.register(
378                        &mut *stream_arc.lock(),
379                        token,
380                        Interest::READABLE | Interest::WRITABLE,
381                        conn_handler,
382                    ) {
383                        self.handler
384                            .on_error(&self.context, Some(conn_id), NetworkError::Io(e));
385                        self.connection_counter.fetch_sub(1, Ordering::SeqCst);
386                        continue;
387                    }
388
389                    let conn = TcpConnection {
390                        stream: stream_arc,
391                        token,
392                        peer_addr,
393                    };
394                    self.connections.insert(conn_id, conn);
395
396                    let _ = self.handler.on_event(
397                        &self.context,
398                        NetworkEvent::ConnectionEstablished(conn_id, peer_addr),
399                    );
400
401                    if let Err(e) = self.handler.on_connect(&self.context, conn_id) {
402                        self.handler.on_error(
403                            &self.context,
404                            Some(conn_id),
405                            NetworkError::HandlerError(format!("on_connect: {}", e)),
406                        );
407
408                        // Remove the connection from the map and deregister from the event loop.
409                        if let Some(conn) = self.connections.remove(&conn_id) {
410                            let mut stream = conn.val().stream.lock();
411                            let _ = event_loop.deregister(&mut *stream, conn.val().token);
412                        }
413                        self.connection_counter.fetch_sub(1, Ordering::SeqCst);
414                        continue;
415                    }
416                }
417                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
418                    break;
419                }
420                Err(e) => {
421                    self.handler
422                        .on_error(&self.context, None, NetworkError::Accept(Box::new(e)));
423                    break;
424                }
425            }
426        }
427    }
428}
429
430/// Handler for individual TCP connections
431struct TcpConnectionHandler<H: NetworkHandler> {
432    conn_id: ConnectionId,
433    stream: Arc<Mutex<TcpStream>>,
434    connections: Arc<LockfreeMap<ConnectionId, TcpConnection>>,
435    handler: Arc<H>,
436    buffer_pool: ObjectPool<Vec<u8>>,
437    event_loop: Weak<EventLoop>,
438    connection_counter: Arc<AtomicUsize>,
439    context: Arc<ServerContext>,
440}
441
442unsafe impl<H: NetworkHandler> Send for TcpConnectionHandler<H> {}
443unsafe impl<H: NetworkHandler> Sync for TcpConnectionHandler<H> {}
444
445impl<H: NetworkHandler> EventHandler for TcpConnectionHandler<H> {
446    fn handle_event(&self, event: &Event) {
447        let is_readable = event.is_readable();
448        let is_writable = event.is_writable();
449
450        if is_readable {
451            self.handle_read();
452        }
453
454        if is_writable {
455            if let Err(e) = self.handler.on_writable(&self.context, self.conn_id) {
456                self.handler.on_error(
457                    &self.context,
458                    Some(self.conn_id),
459                    NetworkError::HandlerError(format!("on_writable: {}", e)),
460                );
461            }
462        }
463    }
464}
465
466impl<H: NetworkHandler> TcpConnectionHandler<H> {
467    fn handle_read(&self) {
468        let mut buffer: PooledObject<Vec<u8>> = self.buffer_pool.acquire();
469
470        let read_result = {
471            let mut stream = self.stream.lock();
472            stream.read(buffer.as_mut())
473        };
474
475        match read_result {
476            Ok(0) => {
477                // connection closed
478                self.disconnect();
479            }
480            Ok(n) => {
481                if let Err(e) =
482                    self.handler
483                        .on_data(&self.context, self.conn_id, &buffer.as_ref()[..n])
484                {
485                    self.handler.on_error(
486                        &self.context,
487                        Some(self.conn_id),
488                        NetworkError::HandlerError(format!("on_data: {}", e)),
489                    );
490                    self.disconnect();
491                }
492            }
493            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
494                // expected for non-blocking I/O
495            }
496            Err(e) => {
497                self.handler.on_error(
498                    &self.context,
499                    Some(self.conn_id),
500                    NetworkError::Io(Box::new(e)),
501                );
502                self.disconnect();
503            }
504        }
505    }
506
507    fn disconnect(&self) {
508        if let Some(conn) = self.connections.remove(&self.conn_id) {
509            self.connection_counter.fetch_sub(1, Ordering::SeqCst);
510            if let Some(event_loop) = self.event_loop.upgrade() {
511                let mut stream = conn.val().stream.lock();
512                let _ = event_loop.deregister(&mut *stream, conn.val().token);
513            }
514
515            if let Err(e) = self.handler.on_disconnect(&self.context, self.conn_id) {
516                self.handler.on_error(
517                    &self.context,
518                    Some(self.conn_id),
519                    NetworkError::HandlerError(format!("on_disconnect: {}", e)),
520                );
521            }
522
523            let _ = self
524                .handler
525                .on_event(&self.context, NetworkEvent::ConnectionClosed(self.conn_id));
526        }
527    }
528}
529
530/// High-level TCP client
531#[derive(Clone)]
532pub struct TcpClient<H: NetworkHandler> {
533    stream: Arc<Mutex<Option<TcpStream>>>,
534    handler: Arc<H>,
535    buffer_pool: ObjectPool<Vec<u8>>,
536    conn_id: ConnectionId,
537    context: Arc<ServerContext>,
538}
539
540impl<H: NetworkHandler> TcpClient<H> {
541    pub fn connect(addr: SocketAddr, handler: H) -> Result<Self> {
542        let stream = TcpStream::connect(addr)?;
543
544        Ok(Self {
545            stream: Arc::new(Mutex::new(Some(stream))),
546            handler: Arc::new(handler),
547            buffer_pool: ObjectPool::new(5, || vec![0; 8192]),
548            conn_id: ConnectionId::new(1),
549            context: Arc::new(ServerContext {
550                server: RwLock::new(None),
551                event_loop: RwLock::new(None),
552            }),
553        })
554    }
555
556    pub fn start(&mut self, event_loop: &Arc<EventLoop>, token: Token) -> Result<()> {
557        *self.context.event_loop.write().unwrap() = Some(event_loop.clone());
558        *self.context.server.write().unwrap() = None;
559
560        let handler = TcpClientHandler {
561            conn_id: self.conn_id,
562            stream: self.stream.clone(),
563            handler: self.handler.clone(),
564            buffer_pool: self.buffer_pool.clone(),
565            event_loop: Arc::downgrade(event_loop),
566            context: self.context.clone(),
567        };
568
569        let mut stream_guard = self.stream.lock();
570        let stream = stream_guard
571            .as_mut()
572            .ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "TCP stream is None"))?;
573
574        event_loop.register(
575            stream,
576            token,
577            Interest::READABLE | Interest::WRITABLE,
578            handler,
579        )?;
580
581        self.handler.on_connect(&self.context, self.conn_id)?;
582
583        Ok(())
584    }
585
586    pub fn send(&self, data: &[u8]) -> Result<()> {
587        let mut stream_guard = self.stream.lock();
588        if let Some(stream) = stream_guard.as_mut() {
589            stream.write_all(data)?;
590        }
591        Ok(())
592    }
593
594    pub fn disconnect(&self) -> Result<()> {
595        *self.stream.lock() = None;
596        self.handler.on_disconnect(&self.context, self.conn_id)?;
597        Ok(())
598    }
599}
600
601struct TcpClientHandler<H: NetworkHandler> {
602    conn_id: ConnectionId,
603    stream: Arc<Mutex<Option<TcpStream>>>,
604    handler: Arc<H>,
605    buffer_pool: ObjectPool<Vec<u8>>,
606    event_loop: Weak<EventLoop>,
607    context: Arc<ServerContext>,
608}
609
610impl<H: NetworkHandler> EventHandler for TcpClientHandler<H> {
611    fn handle_event(&self, event: &Event) {
612        if event.is_readable() {
613            self.handle_read();
614        }
615        if event.is_writable() {
616            if let Err(e) = self.handler.on_writable(&self.context, self.conn_id) {
617                self.handler.on_error(
618                    &self.context,
619                    Some(self.conn_id),
620                    NetworkError::HandlerError(format!("on_writable: {}", e)),
621                );
622            }
623        }
624    }
625}
626
627impl<H: NetworkHandler> TcpClientHandler<H> {
628    fn handle_read(&self) {
629        let mut buffer: PooledObject<Vec<u8>> = self.buffer_pool.acquire();
630
631        let read_result = {
632            let mut stream_guard = self.stream.lock();
633            if let Some(stream) = stream_guard.as_mut() {
634                stream.read(buffer.as_mut())
635            } else {
636                return;
637            }
638        };
639
640        match read_result {
641            Ok(0) => {
642                // connection closed by remote peer (server)
643                let _ = self
644                    .handler
645                    .on_event(&self.context, NetworkEvent::ConnectionClosed(self.conn_id));
646                self.disconnect();
647            }
648            Ok(n) => {
649                if let Err(e) =
650                    self.handler
651                        .on_data(&self.context, self.conn_id, &buffer.as_ref()[..n])
652                {
653                    self.handler.on_error(
654                        &self.context,
655                        Some(self.conn_id),
656                        NetworkError::HandlerError(format!("on_data: {}", e)),
657                    );
658                    self.disconnect();
659                }
660            }
661            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
662                // expected for non-blocking I/O
663            }
664            Err(e) => {
665                self.handler.on_error(
666                    &self.context,
667                    Some(self.conn_id),
668                    NetworkError::Io(Box::new(e)),
669                );
670                self.disconnect();
671            }
672        }
673    }
674
675    fn disconnect(&self) {
676        let mut stream_guard = self.stream.lock();
677
678        if let Some(stream) = stream_guard.as_mut() {
679            if let Some(event_loop) = self.event_loop.upgrade() {
680                let _ = event_loop.deregister(stream, Token(self.conn_id.as_u64() as usize));
681            }
682        }
683
684        *stream_guard = None;
685
686        if let Err(e) = self.handler.on_disconnect(&self.context, self.conn_id) {
687            self.handler.on_error(
688                &self.context,
689                Some(self.conn_id),
690                NetworkError::HandlerError(format!("on_disconnect: {}", e)),
691            );
692        }
693    }
694}
695
#[cfg(test)]
mod tests {
    use super::*;
    use mill_io::EventLoop;
    use std::sync::{Arc, Condvar, Mutex};
    use std::thread;
    use std::time::Duration;

    /// Boolean flag paired with a condvar, used to wait for callbacks on the loop thread.
    type Flag = (Mutex<bool>, Condvar);

    /// Set `flag` to true and wake every waiter.
    fn raise_flag(flag: &Flag) {
        let (lock, cvar) = flag;
        *lock.lock().unwrap() = true;
        cvar.notify_all();
    }

    /// Block until `flag` is true. Panics with `msg` if any single wait exceeds two seconds.
    fn wait_for_flag(flag: &Flag, msg: &str) {
        let (lock, cvar) = flag;
        let mut is_set = lock.lock().unwrap();
        while !*is_set {
            let (guard, timeout) = cvar.wait_timeout(is_set, Duration::from_secs(2)).unwrap();
            if timeout.timed_out() {
                panic!("{}", msg);
            }
            is_set = guard;
        }
    }

    /// `NetworkHandler` whose `on_connect` / `on_data` delegate to optional closures.
    struct TestHandler {
        on_connect_cb: Option<Box<dyn Fn() + Send + Sync>>,
        #[allow(clippy::type_complexity)] // boxed callback signature; an alias adds nothing
        on_data_cb: Option<Box<dyn Fn(&ServerContext, ConnectionId, &[u8]) + Send + Sync>>,
    }

    impl TestHandler {
        fn new() -> Self {
            TestHandler {
                on_data_cb: None,
                on_connect_cb: None,
            }
        }

        fn with_on_connect<F>(self, f: F) -> Self
        where
            F: Fn() + Send + Sync + 'static,
        {
            TestHandler {
                on_connect_cb: Some(Box::new(f)),
                ..self
            }
        }

        fn with_on_data<F>(self, f: F) -> Self
        where
            F: Fn(&ServerContext, ConnectionId, &[u8]) + Send + Sync + 'static,
        {
            TestHandler {
                on_data_cb: Some(Box::new(f)),
                ..self
            }
        }
    }

    impl NetworkHandler for TestHandler {
        fn on_connect(&self, _ctx: &ServerContext, _conn_id: ConnectionId) -> Result<()> {
            if let Some(callback) = self.on_connect_cb.as_ref() {
                callback();
            }
            Ok(())
        }

        fn on_data(&self, ctx: &ServerContext, conn_id: ConnectionId, data: &[u8]) -> Result<()> {
            if let Some(callback) = self.on_data_cb.as_ref() {
                callback(ctx, conn_id, data);
            }
            Ok(())
        }
    }

    #[test]
    fn test_tcp_server_client_echo() {
        let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap());

        let server_connected: Arc<Flag> = Arc::new((Mutex::new(false), Condvar::new()));
        let client_received: Arc<Flag> = Arc::new((Mutex::new(false), Condvar::new()));
        let received_data = Arc::new(Mutex::new(Vec::new()));

        // Server: flag the first connection and echo every payload back to its sender.
        let connected_flag = Arc::clone(&server_connected);
        let server_handler = TestHandler::new()
            .with_on_connect(move || raise_flag(&connected_flag))
            .with_on_data(|ctx, conn_id, data| {
                ctx.send_to(conn_id, data).unwrap();
            });

        let config = TcpServerConfig::builder()
            .address("127.0.0.1:0".parse().unwrap())
            .build();
        let server = Arc::new(TcpServer::new(config, server_handler).unwrap());
        let server_addr = server.local_addr().unwrap();

        // NOTE(review): server connection tokens are derived from connection ids starting at 1,
        // so the first accepted connection shares Token(1) with the listener registered here.
        server.clone().start(&event_loop, Token(1)).unwrap();

        // Client: accumulate the echoed bytes, then flag their arrival.
        let received_flag = Arc::clone(&client_received);
        let sink = Arc::clone(&received_data);
        let client_handler = TestHandler::new().with_on_data(move |_, _, data| {
            sink.lock().unwrap().extend_from_slice(data);
            raise_flag(&received_flag);
        });

        let mut client = TcpClient::connect(server_addr, client_handler).unwrap();
        client.start(&event_loop, Token(2)).unwrap();

        let runner = Arc::clone(&event_loop);
        thread::spawn(move || runner.run().unwrap());

        wait_for_flag(&server_connected, "Server did not accept connection in time");

        let msg = b"Hello, World!";
        client.send(msg).unwrap();

        wait_for_flag(&client_received, "Client did not receive data in time");

        assert_eq!(*received_data.lock().unwrap(), msg);

        event_loop.stop();
    }
}