hyperion_framework/network/server.rs
1// -------------------------------------------------------------------------------------------------
2// Hyperion Framework
3// https://github.com/robert-hannah/hyperion-framework
4//
5// A lightweight component-based TCP framework for building service-oriented Rust applications with
6// CLI control, async messaging, and lifecycle management.
7//
8// Copyright 2025 Robert Hannah
9//
10// Licensed under the Apache License, Version 2.0 (the "License");
11// you may not use this file except in compliance with the License.
12// You may obtain a copy of the License at
13//
14// http://www.apache.org/licenses/LICENSE-2.0
15//
16// Unless required by applicable law or agreed to in writing, software
17// distributed under the License is distributed on an "AS IS" BASIS,
18// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19// See the License for the specific language governing permissions and
20// limitations under the License.
21// -------------------------------------------------------------------------------------------------
22
23// Standard
24use std::sync::Arc as StdArc;
25use std::sync::atomic::{AtomicUsize, Ordering};
26
27// Package
28use serde::{Serialize, de::DeserializeOwned};
29use tokio::io::AsyncReadExt;
30use tokio::net::{TcpListener, TcpStream};
31use tokio::sync::{Notify, mpsc};
32use tokio::task::JoinSet;
33
34// Local
35use crate::containerisation::container_state::ContainerState;
36use crate::network::serialiser;
37use crate::utilities::tx_sender::add_to_tx_with_retry;
38
39#[derive(Debug, Clone)]
40pub struct Server<T> {
41 address: String,
42 server_tx: mpsc::Sender<T>,
43 container_state: StdArc<AtomicUsize>,
44 container_state_notify: StdArc<Notify>,
45}
46
47// ========================================================================
48// A Server will ONLY receive messages
49// ========================================================================
50
51impl<T> Server<T>
52where
53 T: Clone + Send + Serialize + DeserializeOwned + Sync + 'static,
54{
55 // ========================================================================
56 // Create new Server
57 // ========================================================================
58 pub fn new(
59 address: String,
60 server_tx: mpsc::Sender<T>,
61 container_state: StdArc<AtomicUsize>,
62 container_state_notify: StdArc<Notify>,
63 ) -> StdArc<Self> {
64 StdArc::new(Self {
65 address,
66 server_tx,
67 container_state,
68 container_state_notify,
69 })
70 }
71
72 // These functions are static as they only handle StdArc<Server<T>>. This is because self is a
73 // server and not a StdArc<Server<T>>.
74 // I think using StdArc on objects for multithreaded async tasks is overall a good idea though,
75 // maintains atomicity. It's just a bit of an ugly implementation.
76
77 // ========================================================================
78 // STATIC: Run multi-thread async Server
79 // ========================================================================
80 pub async fn run(arc_server: StdArc<Self>) -> Result<(), Box<dyn std::error::Error>> {
81 let listener = TcpListener::bind(arc_server.address.clone()).await?; // Connect to port
82 log::trace!("Server listening on {}", arc_server.address);
83
84 let mut join_set = JoinSet::new(); // Track spawned tasks
85
86 loop {
87 tokio::select! { // Loop fires on incoming messages from listener or shutdown signal
88 result = listener.accept() => {
89 match result {
90 Ok((stream, addr)) => {
91 log::info!("Accepted connection from {addr}");
92 // Clones an Arc pointer, increasing the reference count instead of doing a deep copy of the class
93 // Doing this as String and Sender are expensive to clone and StdArc::clone is far more atomic
94 let handler = StdArc::clone(&arc_server);
95 join_set.spawn(async move {
96 // Spawn new handler for each client connection
97 Server::handle_stream_notification(&handler, stream).await;
98 });
99 },
100
101 Err(e) => {
102 log::error!("Failed to accept connection: {e:?}");
103 continue; // Keep listening instead of crashing
104 }
105 }
106 },
107
108 // Listen for state changes
109 _ = arc_server.container_state_notify.notified() => {
110 // Check if state has been set to ShuttingDown
111 if ContainerState::from(arc_server.container_state.load(Ordering::SeqCst)) == ContainerState::ShuttingDown {
112 log::info!("Server {} graceful shutdown initiated", arc_server.address);
113 drop(listener); // Extra explicit way to drop the connection
114 break; // Breaks out of listening loop
115 }
116 }
117 }
118 }
119
120 log::info!("Waiting for ongoing tasks to complete...");
121 while join_set.join_next().await.is_some() {} // Ensures all client handlers die before shutdown
122 log::info!("Server {} shut down gracefully.", arc_server.address);
123
124 // Server failure will bring the container down with it
125 arc_server
126 .container_state
127 .store(ContainerState::ShuttingDown as usize, Ordering::SeqCst);
128 arc_server.container_state_notify.notify_waiters();
129
130 Ok(())
131 }
132
133 // ========================================================================
134 // STATIC: Deserialise incoming messages from clients
135 // ========================================================================
136 async fn handle_stream_notification(arc_server: &StdArc<Self>, mut stream: TcpStream) {
137 let mut buf = vec![0u8; 65_536]; // This buffer can be increased/reduced depending on use case (currently 64KB) - is a vec to put it on heap
138 let mut message_buf = Vec::new(); // Persistent buffer for accumulating message bytes - for handling partial messages
139
140 loop {
141 tokio::select! {
142 // Attempt to read from stream (All results must be wrapped in Ok(),
143 // other than inactivity err)
144 read_result = stream.read(&mut buf) => {
145 match read_result {
146 Ok(0) => {
147 log::info!("Client disconnected gracefully.");
148 break;
149 }
150 Ok(n) => {
151 message_buf.extend_from_slice(&buf[..n]);
152
153 // Attempt to process all complete messages
154 while message_buf.len() >= 4 {
155 // Read the length prefix (first 4 bytes)
156 let len_bytes = &message_buf[..4];
157 let msg_len = u32::from_be_bytes(len_bytes.try_into().unwrap()) as usize;
158
159 // Do we have the full message yet?
160 if message_buf.len() < 4 + msg_len {
161 break; // Wait for more data
162 }
163
164 // Extract message bytes and remove them from the buffer
165 let msg_bytes = message_buf[4..4 + msg_len].to_vec();
166 message_buf.drain(..4 + msg_len);
167
168 // Try to deserialize
169 match serialiser::deserialise_message(&msg_bytes) {
170 Ok(msg) => {
171 add_to_tx_with_retry(&arc_server.server_tx, &msg, "Server", "Main").await;
172 }
173 Err(e) => {
174 log::warn!("Message (raw): {}", String::from_utf8_lossy(&msg_bytes));
175 log::error!("Failed to deserialise message: {e:?}");
176 continue;
177 }
178 }
179 }
180 }
181 Err(e) => {
182 log::error!("Failed to read from socket; err = {e:?}");
183 break;
184 }
185 }
186 }
187 }
188 }
189 }
190}