Skip to main content

hyperion_framework/network/
server.rs

1// -------------------------------------------------------------------------------------------------
2// Hyperion Framework
3// https://github.com/robert-hannah/hyperion-framework
4//
5// A lightweight component-based TCP framework for building service-oriented Rust applications with
6// CLI control, async messaging, and lifecycle management.
7//
8// Copyright 2025 Robert Hannah
9//
10// Licensed under the Apache License, Version 2.0 (the "License");
11// you may not use this file except in compliance with the License.
12// You may obtain a copy of the License at
13//
14//     http://www.apache.org/licenses/LICENSE-2.0
15//
16// Unless required by applicable law or agreed to in writing, software
17// distributed under the License is distributed on an "AS IS" BASIS,
18// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19// See the License for the specific language governing permissions and
20// limitations under the License.
21// -------------------------------------------------------------------------------------------------
22
23// Standard
24use std::sync::Arc as StdArc;
25use std::sync::atomic::{AtomicUsize, Ordering};
26
27// Package
28use serde::{Serialize, de::DeserializeOwned};
29use tokio::io::AsyncReadExt;
30use tokio::net::{TcpListener, TcpStream};
31use tokio::sync::{Notify, mpsc};
32use tokio::task::JoinSet;
33
34// Local
35use crate::containerisation::container_state::ContainerState;
36use crate::network::serialiser;
37use crate::utilities::tx_sender::add_to_tx_with_retry;
38
39#[derive(Debug, Clone)]
40pub struct Server<T> {
41    address: String,
42    server_tx: mpsc::Sender<T>,
43    container_state: StdArc<AtomicUsize>,
44    container_state_notify: StdArc<Notify>,
45}
46
47// ========================================================================
48//    A Server will ONLY receive messages
49// ========================================================================
50
51impl<T> Server<T>
52where
53    T: Clone + Send + Serialize + DeserializeOwned + Sync + 'static,
54{
55    // ========================================================================
56    //    Create new Server
57    // ========================================================================
58    pub fn new(
59        address: String,
60        server_tx: mpsc::Sender<T>,
61        container_state: StdArc<AtomicUsize>,
62        container_state_notify: StdArc<Notify>,
63    ) -> StdArc<Self> {
64        StdArc::new(Self {
65            address,
66            server_tx,
67            container_state,
68            container_state_notify,
69        })
70    }
71
72    // These functions are static as they only handle StdArc<Server<T>>. This is because self is a
73    // server and not a StdArc<Server<T>>.
74    // I think using StdArc on objects for multithreaded async tasks is overall a good idea though,
75    // maintains atomicity. It's just a bit of an ugly implementation.
76
77    // ========================================================================
78    //    STATIC: Run multi-thread async Server
79    // ========================================================================
80    pub async fn run(arc_server: StdArc<Self>) -> Result<(), Box<dyn std::error::Error>> {
81        let listener = TcpListener::bind(arc_server.address.clone()).await?; // Connect to port
82        log::trace!("Server listening on {}", arc_server.address);
83
84        let mut join_set = JoinSet::new(); // Track spawned tasks
85
86        loop {
87            tokio::select! {     // Loop fires on incoming messages from listener or shutdown signal
88                result = listener.accept() => {
89                    match result {
90                        Ok((stream, addr)) => {
91                            log::info!("Accepted connection from {addr}");
92                            // Clones an Arc pointer, increasing the reference count instead of doing a deep copy of the class
93                            // Doing this as String and Sender are expensive to clone and StdArc::clone is far more atomic
94                            let handler = StdArc::clone(&arc_server);
95                            join_set.spawn(async move {
96                                // Spawn new handler for each client connection
97                                Server::handle_stream_notification(&handler, stream).await;
98                            });
99                        },
100
101                        Err(e) => {
102                            log::error!("Failed to accept connection: {e:?}");
103                            continue;  // Keep listening instead of crashing
104                        }
105                    }
106                },
107
108                // Listen for state changes
109                _ = arc_server.container_state_notify.notified() => {
110                    // Check if state has been set to ShuttingDown
111                    if ContainerState::from(arc_server.container_state.load(Ordering::SeqCst)) == ContainerState::ShuttingDown {
112                        log::info!("Server {} graceful shutdown initiated", arc_server.address);
113                        drop(listener);     // Extra explicit way to drop the connection
114                        break;              // Breaks out of listening loop
115                    }
116                }
117            }
118        }
119
120        log::info!("Waiting for ongoing tasks to complete...");
121        while join_set.join_next().await.is_some() {} // Ensures all client handlers die before shutdown
122        log::info!("Server {} shut down gracefully.", arc_server.address);
123
124        // Server failure will bring the container down with it
125        arc_server
126            .container_state
127            .store(ContainerState::ShuttingDown as usize, Ordering::SeqCst);
128        arc_server.container_state_notify.notify_waiters();
129
130        Ok(())
131    }
132
133    // ========================================================================
134    //    STATIC: Deserialise incoming messages from clients
135    // ========================================================================
136    async fn handle_stream_notification(arc_server: &StdArc<Self>, mut stream: TcpStream) {
137        let mut buf = vec![0u8; 65_536]; // This buffer can be increased/reduced depending on use case (currently 64KB) - is a vec to put it on heap
138        let mut message_buf = Vec::new(); // Persistent buffer for accumulating message bytes - for handling partial messages
139
140        loop {
141            tokio::select! {
142                // Attempt to read from stream (All results must be wrapped in Ok(),
143                // other than inactivity err)
144                read_result = stream.read(&mut buf) => {
145                    match read_result {
146                        Ok(0) => {
147                            log::info!("Client disconnected gracefully.");
148                            break;
149                        }
150                        Ok(n) => {
151                            message_buf.extend_from_slice(&buf[..n]);
152
153                            // Attempt to process all complete messages
154                            while message_buf.len() >= 4 {
155                                // Read the length prefix (first 4 bytes)
156                                let len_bytes = &message_buf[..4];
157                                let msg_len = u32::from_be_bytes(len_bytes.try_into().unwrap()) as usize;
158
159                                // Do we have the full message yet?
160                                if message_buf.len() < 4 + msg_len {
161                                    break; // Wait for more data
162                                }
163
164                                // Extract message bytes and remove them from the buffer
165                                let msg_bytes = message_buf[4..4 + msg_len].to_vec();
166                                message_buf.drain(..4 + msg_len);
167
168                                // Try to deserialize
169                                match serialiser::deserialise_message(&msg_bytes) {
170                                    Ok(msg) => {
171                                        add_to_tx_with_retry(&arc_server.server_tx, &msg, "Server", "Main").await;
172                                    }
173                                    Err(e) => {
174                                        log::warn!("Message (raw): {}", String::from_utf8_lossy(&msg_bytes));
175                                        log::error!("Failed to deserialise message: {e:?}");
176                                        continue;
177                                    }
178                                }
179                            }
180                        }
181                        Err(e) => {
182                            log::error!("Failed to read from socket; err = {e:?}");
183                            break;
184                        }
185                    }
186                }
187            }
188        }
189    }
190}