1pub mod config;
79pub mod traits;
80
81pub use config::TcpServerConfig;
82pub use traits::*;
83
84use crate::errors::Result;
85use crate::errors::{NetworkError, NetworkEvent};
86use lock_freedom::map::Map as LockfreeMap;
87use mill_io::{EventHandler, EventLoop, ObjectPool, PooledObject};
88use mio::event::Event;
89use mio::net::{TcpListener, TcpStream};
90use mio::{Interest, Token};
91use parking_lot::Mutex;
92use std::io;
93use std::io::{Read, Write};
94use std::net::SocketAddr;
95use std::sync::{
96 atomic::{AtomicU64, AtomicUsize, Ordering},
97 Arc, RwLock, Weak,
98};
99
100pub struct ServerContext {
102 server: RwLock<Option<Arc<dyn ServerOperations>>>,
103 event_loop: RwLock<Option<Arc<EventLoop>>>,
104}
105
106impl ServerContext {
107 pub fn send_to(&self, conn_id: ConnectionId, data: &[u8]) -> Result<()> {
109 if let Some(server) = self.server.read().unwrap().as_ref() {
110 server.send_to(conn_id, data)
111 } else {
112 Ok(())
113 }
114 }
115
116 pub fn broadcast(&self, data: &[u8]) -> Result<()> {
118 if let Some(server) = self.server.read().unwrap().as_ref() {
119 server.broadcast(data)
120 } else {
121 Ok(())
122 }
123 }
124
125 pub fn close_connection(&self, conn_id: ConnectionId) -> Result<()> {
127 let server_guard = self.server.read().unwrap();
128 if let Some(server) = server_guard.as_ref() {
129 if let Some(event_loop) = self.event_loop.read().unwrap().as_ref() {
130 server.close_connection(event_loop, conn_id)
131 } else {
132 Ok(())
133 }
134 } else {
135 Ok(())
136 }
137 }
138}
139
140trait ServerOperations: Send + Sync {
142 fn send_to(&self, conn_id: ConnectionId, data: &[u8]) -> Result<()>;
143 fn broadcast(&self, data: &[u8]) -> Result<()>;
144 fn close_connection(&self, event_loop: &EventLoop, conn_id: ConnectionId) -> Result<()>;
145}
146
147pub struct TcpServer<H: NetworkHandler> {
149 listener: Arc<Mutex<TcpListener>>,
150 connections: Arc<LockfreeMap<ConnectionId, TcpConnection>>,
151 handler: Arc<H>,
152 config: TcpServerConfig,
153 buffer_pool: ObjectPool<Vec<u8>>,
154 next_conn_id: Arc<AtomicU64>,
155 connection_counter: Arc<AtomicUsize>,
156 context: Arc<ServerContext>,
157}
158
159impl<H: NetworkHandler> TcpServer<H> {
160 pub fn new(config: TcpServerConfig, handler: H) -> Result<Self> {
161 let listener = TcpListener::bind(config.address)?;
162
163 Ok(Self {
164 listener: Arc::new(Mutex::new(listener)),
165 connections: Arc::new(LockfreeMap::new()),
166 handler: Arc::new(handler),
167 buffer_pool: ObjectPool::new(20, move || vec![0; config.buffer_size]),
168 next_conn_id: Arc::new(AtomicU64::new(1)),
169 connection_counter: Arc::new(AtomicUsize::new(0)),
170 config,
171 context: Arc::new(ServerContext {
172 server: RwLock::new(None),
173 event_loop: RwLock::new(None),
174 }),
175 })
176 }
177
178 pub fn local_addr(&self) -> Result<SocketAddr> {
180 Ok(self.listener.lock().local_addr()?)
181 }
182
183 pub fn start(
185 self: Arc<Self>,
186 event_loop: &Arc<EventLoop>,
187 listener_token: Token,
188 ) -> Result<()> {
189 *self.context.server.write().unwrap() = Some(self.clone());
191 *self.context.event_loop.write().unwrap() = Some(event_loop.clone());
192
193 let listener_handler = TcpListenerHandler {
194 listener: self.listener.clone(),
195 connections: self.connections.clone(),
196 handler: self.handler.clone(),
197 config: self.config.clone(),
198 buffer_pool: self.buffer_pool.clone(),
199 next_conn_id: self.next_conn_id.clone(),
200 event_loop: Arc::downgrade(event_loop),
201 connection_counter: self.connection_counter.clone(),
202 context: self.context.clone(),
203 };
204
205 event_loop.register(
206 &mut *self.listener.lock(),
207 listener_token,
208 Interest::READABLE,
209 listener_handler,
210 )?;
211
212 Ok(())
213 }
214
215 pub fn connection_count(&self) -> usize {
217 self.connection_counter.load(Ordering::SeqCst)
218 }
219}
220
221impl<H: NetworkHandler> ServerOperations for TcpServer<H> {
222 fn send_to(&self, conn_id: ConnectionId, data: &[u8]) -> Result<()> {
224 if let Some(conn) = self.connections.get(&conn_id) {
225 let mut stream = conn.val().stream.lock();
226 stream.write_all(data)?;
227 } else {
228 return Err(Box::new(NetworkError::ConnectionNotFound(conn_id)));
229 }
230 Ok(())
231 }
232
233 fn close_connection(&self, event_loop: &EventLoop, conn_id: ConnectionId) -> Result<()> {
235 if let Some(conn) = self.connections.remove(&conn_id) {
236 let mut stream = conn.val().stream.lock();
237
238 let _ = event_loop.deregister(&mut *stream, conn.val().token);
239 let _ = stream.shutdown(std::net::Shutdown::Both);
240
241 if let Err(e) = self.handler.on_disconnect(&self.context, conn_id) {
242 self.handler.on_error(
243 &self.context,
244 Some(conn_id),
245 NetworkError::HandlerError(format!("on_disconnect: {}", e)),
246 );
247 }
248
249 let _ = self
250 .handler
251 .on_event(&self.context, NetworkEvent::ConnectionClosed(conn_id));
252 }
253 Ok(())
254 }
255
256 fn broadcast(&self, data: &[u8]) -> Result<()> {
258 for conn in self.connections.iter() {
259 let mut stream = conn.val().stream.lock();
260 if let Err(e) = stream.write_all(data) {
261 self.handler.on_error(
262 &self.context,
263 Some(*conn.key()),
264 NetworkError::Io(Box::new(e)),
265 );
266 }
267 }
268 Ok(())
269 }
270}
271
272struct TcpConnection {
274 stream: Arc<Mutex<TcpStream>>,
275 token: Token,
276 #[allow(dead_code)]
277 peer_addr: SocketAddr,
278}
279
280struct TcpListenerHandler<H: NetworkHandler> {
282 listener: Arc<Mutex<TcpListener>>,
283 connections: Arc<LockfreeMap<ConnectionId, TcpConnection>>,
284 handler: Arc<H>,
285 config: TcpServerConfig,
286 buffer_pool: ObjectPool<Vec<u8>>,
287 next_conn_id: Arc<AtomicU64>,
288 event_loop: Weak<EventLoop>,
289 connection_counter: Arc<AtomicUsize>,
290 context: Arc<ServerContext>,
291}
292
293unsafe impl<H: NetworkHandler> Send for TcpListenerHandler<H> {}
295unsafe impl<H: NetworkHandler> Sync for TcpListenerHandler<H> {}
296
297impl<H: NetworkHandler> EventHandler for TcpListenerHandler<H> {
298 fn handle_event(&self, event: &Event) {
299 if !event.is_readable() {
300 return;
301 }
302
303 loop {
304 let listener = self.listener.lock();
305
306 match listener.accept() {
307 Ok((stream, peer_addr)) => {
308 if let Some(max) = self.config.max_connections {
310 let mut accepted = false;
311 loop {
312 let current = self.connection_counter.load(Ordering::SeqCst);
313 if current >= max {
314 self.handler.on_error(
315 &self.context,
316 None,
317 NetworkError::MaxConnectionsReached(peer_addr),
318 );
319 break;
320 }
321 match self.connection_counter.compare_exchange(
322 current,
323 current + 1,
324 Ordering::SeqCst,
325 Ordering::SeqCst,
326 ) {
327 Ok(_) => {
328 accepted = true;
329 break;
330 }
331 Err(_) => continue,
332 }
333 }
334 if !accepted {
335 continue;
336 }
337 } else {
338 self.connection_counter.fetch_add(1, Ordering::SeqCst);
339 }
340
341 if let Err(e) = stream.set_nodelay(self.config.no_delay) {
342 self.handler.on_error(
343 &self.context,
344 None,
345 NetworkError::Configuration(format!(
346 "Failed to set TCP_NODELAY: {}",
347 e
348 )),
349 );
350 }
351
352 let conn_id = ConnectionId(self.next_conn_id.fetch_add(1, Ordering::SeqCst));
353 let token = Token(conn_id.as_u64() as usize);
354
355 let stream_arc = Arc::new(Mutex::new(stream));
356
357 let conn_handler = TcpConnectionHandler {
358 conn_id,
359 stream: stream_arc.clone(),
360 connections: self.connections.clone(),
361 handler: self.handler.clone(),
362 buffer_pool: self.buffer_pool.clone(),
363 event_loop: self.event_loop.clone(),
364 connection_counter: self.connection_counter.clone(),
365 context: self.context.clone(),
366 };
367
368 let event_loop = if let Some(arc) = self.event_loop.upgrade() {
369 arc
370 } else {
371 self.handler
372 .on_error(&self.context, None, NetworkError::EventLoopGone);
373 self.connection_counter.fetch_sub(1, Ordering::SeqCst);
374 continue;
375 };
376
377 if let Err(e) = event_loop.register(
378 &mut *stream_arc.lock(),
379 token,
380 Interest::READABLE | Interest::WRITABLE,
381 conn_handler,
382 ) {
383 self.handler
384 .on_error(&self.context, Some(conn_id), NetworkError::Io(e));
385 self.connection_counter.fetch_sub(1, Ordering::SeqCst);
386 continue;
387 }
388
389 let conn = TcpConnection {
390 stream: stream_arc,
391 token,
392 peer_addr,
393 };
394 self.connections.insert(conn_id, conn);
395
396 let _ = self.handler.on_event(
397 &self.context,
398 NetworkEvent::ConnectionEstablished(conn_id, peer_addr),
399 );
400
401 if let Err(e) = self.handler.on_connect(&self.context, conn_id) {
402 self.handler.on_error(
403 &self.context,
404 Some(conn_id),
405 NetworkError::HandlerError(format!("on_connect: {}", e)),
406 );
407
408 if let Some(conn) = self.connections.remove(&conn_id) {
410 let mut stream = conn.val().stream.lock();
411 let _ = event_loop.deregister(&mut *stream, conn.val().token);
412 }
413 self.connection_counter.fetch_sub(1, Ordering::SeqCst);
414 continue;
415 }
416 }
417 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
418 break;
419 }
420 Err(e) => {
421 self.handler
422 .on_error(&self.context, None, NetworkError::Accept(Box::new(e)));
423 break;
424 }
425 }
426 }
427 }
428}
429
430struct TcpConnectionHandler<H: NetworkHandler> {
432 conn_id: ConnectionId,
433 stream: Arc<Mutex<TcpStream>>,
434 connections: Arc<LockfreeMap<ConnectionId, TcpConnection>>,
435 handler: Arc<H>,
436 buffer_pool: ObjectPool<Vec<u8>>,
437 event_loop: Weak<EventLoop>,
438 connection_counter: Arc<AtomicUsize>,
439 context: Arc<ServerContext>,
440}
441
442unsafe impl<H: NetworkHandler> Send for TcpConnectionHandler<H> {}
443unsafe impl<H: NetworkHandler> Sync for TcpConnectionHandler<H> {}
444
445impl<H: NetworkHandler> EventHandler for TcpConnectionHandler<H> {
446 fn handle_event(&self, event: &Event) {
447 let is_readable = event.is_readable();
448 let is_writable = event.is_writable();
449
450 if is_readable {
451 self.handle_read();
452 }
453
454 if is_writable {
455 if let Err(e) = self.handler.on_writable(&self.context, self.conn_id) {
456 self.handler.on_error(
457 &self.context,
458 Some(self.conn_id),
459 NetworkError::HandlerError(format!("on_writable: {}", e)),
460 );
461 }
462 }
463 }
464}
465
466impl<H: NetworkHandler> TcpConnectionHandler<H> {
467 fn handle_read(&self) {
468 let mut buffer: PooledObject<Vec<u8>> = self.buffer_pool.acquire();
469
470 let read_result = {
471 let mut stream = self.stream.lock();
472 stream.read(buffer.as_mut())
473 };
474
475 match read_result {
476 Ok(0) => {
477 self.disconnect();
479 }
480 Ok(n) => {
481 if let Err(e) =
482 self.handler
483 .on_data(&self.context, self.conn_id, &buffer.as_ref()[..n])
484 {
485 self.handler.on_error(
486 &self.context,
487 Some(self.conn_id),
488 NetworkError::HandlerError(format!("on_data: {}", e)),
489 );
490 self.disconnect();
491 }
492 }
493 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
494 }
496 Err(e) => {
497 self.handler.on_error(
498 &self.context,
499 Some(self.conn_id),
500 NetworkError::Io(Box::new(e)),
501 );
502 self.disconnect();
503 }
504 }
505 }
506
507 fn disconnect(&self) {
508 if let Some(conn) = self.connections.remove(&self.conn_id) {
509 self.connection_counter.fetch_sub(1, Ordering::SeqCst);
510 if let Some(event_loop) = self.event_loop.upgrade() {
511 let mut stream = conn.val().stream.lock();
512 let _ = event_loop.deregister(&mut *stream, conn.val().token);
513 }
514
515 if let Err(e) = self.handler.on_disconnect(&self.context, self.conn_id) {
516 self.handler.on_error(
517 &self.context,
518 Some(self.conn_id),
519 NetworkError::HandlerError(format!("on_disconnect: {}", e)),
520 );
521 }
522
523 let _ = self
524 .handler
525 .on_event(&self.context, NetworkEvent::ConnectionClosed(self.conn_id));
526 }
527 }
528}
529
530#[derive(Clone)]
532pub struct TcpClient<H: NetworkHandler> {
533 stream: Arc<Mutex<Option<TcpStream>>>,
534 handler: Arc<H>,
535 buffer_pool: ObjectPool<Vec<u8>>,
536 conn_id: ConnectionId,
537 context: Arc<ServerContext>,
538}
539
540impl<H: NetworkHandler> TcpClient<H> {
541 pub fn connect(addr: SocketAddr, handler: H) -> Result<Self> {
542 let stream = TcpStream::connect(addr)?;
543
544 Ok(Self {
545 stream: Arc::new(Mutex::new(Some(stream))),
546 handler: Arc::new(handler),
547 buffer_pool: ObjectPool::new(5, || vec![0; 8192]),
548 conn_id: ConnectionId::new(1),
549 context: Arc::new(ServerContext {
550 server: RwLock::new(None),
551 event_loop: RwLock::new(None),
552 }),
553 })
554 }
555
556 pub fn start(&mut self, event_loop: &Arc<EventLoop>, token: Token) -> Result<()> {
557 *self.context.event_loop.write().unwrap() = Some(event_loop.clone());
558 *self.context.server.write().unwrap() = None;
559
560 let handler = TcpClientHandler {
561 conn_id: self.conn_id,
562 stream: self.stream.clone(),
563 handler: self.handler.clone(),
564 buffer_pool: self.buffer_pool.clone(),
565 event_loop: Arc::downgrade(event_loop),
566 context: self.context.clone(),
567 };
568
569 let mut stream_guard = self.stream.lock();
570 let stream = stream_guard
571 .as_mut()
572 .ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "TCP stream is None"))?;
573
574 event_loop.register(
575 stream,
576 token,
577 Interest::READABLE | Interest::WRITABLE,
578 handler,
579 )?;
580
581 self.handler.on_connect(&self.context, self.conn_id)?;
582
583 Ok(())
584 }
585
586 pub fn send(&self, data: &[u8]) -> Result<()> {
587 let mut stream_guard = self.stream.lock();
588 if let Some(stream) = stream_guard.as_mut() {
589 stream.write_all(data)?;
590 }
591 Ok(())
592 }
593
594 pub fn disconnect(&self) -> Result<()> {
595 *self.stream.lock() = None;
596 self.handler.on_disconnect(&self.context, self.conn_id)?;
597 Ok(())
598 }
599}
600
601struct TcpClientHandler<H: NetworkHandler> {
602 conn_id: ConnectionId,
603 stream: Arc<Mutex<Option<TcpStream>>>,
604 handler: Arc<H>,
605 buffer_pool: ObjectPool<Vec<u8>>,
606 event_loop: Weak<EventLoop>,
607 context: Arc<ServerContext>,
608}
609
610impl<H: NetworkHandler> EventHandler for TcpClientHandler<H> {
611 fn handle_event(&self, event: &Event) {
612 if event.is_readable() {
613 self.handle_read();
614 }
615 if event.is_writable() {
616 if let Err(e) = self.handler.on_writable(&self.context, self.conn_id) {
617 self.handler.on_error(
618 &self.context,
619 Some(self.conn_id),
620 NetworkError::HandlerError(format!("on_writable: {}", e)),
621 );
622 }
623 }
624 }
625}
626
627impl<H: NetworkHandler> TcpClientHandler<H> {
628 fn handle_read(&self) {
629 let mut buffer: PooledObject<Vec<u8>> = self.buffer_pool.acquire();
630
631 let read_result = {
632 let mut stream_guard = self.stream.lock();
633 if let Some(stream) = stream_guard.as_mut() {
634 stream.read(buffer.as_mut())
635 } else {
636 return;
637 }
638 };
639
640 match read_result {
641 Ok(0) => {
642 let _ = self
644 .handler
645 .on_event(&self.context, NetworkEvent::ConnectionClosed(self.conn_id));
646 self.disconnect();
647 }
648 Ok(n) => {
649 if let Err(e) =
650 self.handler
651 .on_data(&self.context, self.conn_id, &buffer.as_ref()[..n])
652 {
653 self.handler.on_error(
654 &self.context,
655 Some(self.conn_id),
656 NetworkError::HandlerError(format!("on_data: {}", e)),
657 );
658 self.disconnect();
659 }
660 }
661 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
662 }
664 Err(e) => {
665 self.handler.on_error(
666 &self.context,
667 Some(self.conn_id),
668 NetworkError::Io(Box::new(e)),
669 );
670 self.disconnect();
671 }
672 }
673 }
674
675 fn disconnect(&self) {
676 let mut stream_guard = self.stream.lock();
677
678 if let Some(stream) = stream_guard.as_mut() {
679 if let Some(event_loop) = self.event_loop.upgrade() {
680 let _ = event_loop.deregister(stream, Token(self.conn_id.as_u64() as usize));
681 }
682 }
683
684 *stream_guard = None;
685
686 if let Err(e) = self.handler.on_disconnect(&self.context, self.conn_id) {
687 self.handler.on_error(
688 &self.context,
689 Some(self.conn_id),
690 NetworkError::HandlerError(format!("on_disconnect: {}", e)),
691 );
692 }
693 }
694}
695
#[cfg(test)]
mod tests {
    use super::*;
    use mill_io::EventLoop;
    use std::sync::{Arc, Condvar, Mutex};
    use std::thread;
    use std::time::Duration;

    /// Callback invoked for every chunk of data a `TestHandler` receives.
    type DataCallback = Box<dyn Fn(&ServerContext, ConnectionId, &[u8]) + Send + Sync>;

    /// A boolean flag paired with a condition variable, used to wait on callbacks.
    type Signal = Arc<(Mutex<bool>, Condvar)>;

    /// Creates a signal in the "not raised" state.
    fn new_signal() -> Signal {
        Arc::new((Mutex::new(false), Condvar::new()))
    }

    /// Sets the flag and wakes every waiter.
    fn raise(signal: &Signal) {
        let (flag, cvar) = &**signal;
        *flag.lock().unwrap() = true;
        cvar.notify_all();
    }

    /// Blocks until the flag is set, panicking with `failure` if any single
    /// two-second wait times out.
    fn wait_for(signal: &Signal, failure: &str) {
        let (flag, cvar) = &**signal;
        let mut raised = flag.lock().unwrap();
        while !*raised {
            let (guard, outcome) = cvar.wait_timeout(raised, Duration::from_secs(2)).unwrap();
            if outcome.timed_out() {
                panic!("{}", failure);
            }
            raised = guard;
        }
    }

    /// `NetworkHandler` whose callbacks are supplied by each test.
    struct TestHandler {
        on_connect_cb: Option<Box<dyn Fn() + Send + Sync>>,
        on_data_cb: Option<DataCallback>,
    }

    impl TestHandler {
        fn new() -> Self {
            TestHandler {
                on_connect_cb: None,
                on_data_cb: None,
            }
        }

        fn with_on_connect<F>(self, f: F) -> Self
        where
            F: Fn() + Send + Sync + 'static,
        {
            TestHandler {
                on_connect_cb: Some(Box::new(f)),
                ..self
            }
        }

        fn with_on_data<F>(self, f: F) -> Self
        where
            F: Fn(&ServerContext, ConnectionId, &[u8]) + Send + Sync + 'static,
        {
            TestHandler {
                on_data_cb: Some(Box::new(f)),
                ..self
            }
        }
    }

    impl NetworkHandler for TestHandler {
        fn on_connect(&self, _ctx: &ServerContext, _conn_id: ConnectionId) -> Result<()> {
            if let Some(callback) = self.on_connect_cb.as_ref() {
                callback();
            }
            Ok(())
        }

        fn on_data(&self, ctx: &ServerContext, conn_id: ConnectionId, data: &[u8]) -> Result<()> {
            if let Some(callback) = self.on_data_cb.as_ref() {
                callback(ctx, conn_id, data);
            }
            Ok(())
        }
    }

    /// End-to-end check: the server echoes a client's message back to it.
    #[test]
    fn test_tcp_server_client_echo() {
        let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap());

        let server_connected = new_signal();
        let client_received = new_signal();
        let received_data = Arc::new(Mutex::new(Vec::new()));

        let server_handler = {
            let connected = server_connected.clone();
            TestHandler::new()
                .with_on_connect(move || raise(&connected))
                .with_on_data(|ctx, conn_id, data| {
                    ctx.send_to(conn_id, data).unwrap();
                })
        };

        let config = TcpServerConfig::builder()
            .address("127.0.0.1:0".parse().unwrap())
            .build();
        let server = Arc::new(TcpServer::new(config, server_handler).unwrap());
        let server_addr = server.local_addr().unwrap();
        Arc::clone(&server).start(&event_loop, Token(1)).unwrap();

        let client_handler = {
            let received = client_received.clone();
            let sink = received_data.clone();
            TestHandler::new().with_on_data(move |_, _, data| {
                sink.lock().unwrap().extend_from_slice(data);
                raise(&received);
            })
        };

        let mut client = TcpClient::connect(server_addr, client_handler).unwrap();
        client.start(&event_loop, Token(2)).unwrap();

        let runner = Arc::clone(&event_loop);
        thread::spawn(move || {
            runner.run().unwrap();
        });

        wait_for(&server_connected, "Server did not accept connection in time");

        let msg = b"Hello, World!";
        client.send(msg).unwrap();

        wait_for(&client_received, "Client did not receive data in time");

        let echoed = received_data.lock().unwrap();
        assert_eq!(*echoed, msg);

        event_loop.stop();
    }
}