hydra_sync/server.rs
1use crate::protocol::{Role, perform_server_handshake, read_join_frame, read_raw_frame_into};
2use crate::session::Sessions;
3use crate::{BUFFER_SIZE, error, info, trace, warn};
4use anyhow::Result;
5use bytes::BytesMut;
6use std::net::SocketAddr;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::broadcast::error::RecvError;
// TODO: handle backpressure "properly", add handler traits that invoke user-defined callbacks on selected events, and improve logging.
13
14/// A light-weight multi-threaded SPMC (Single Producer Multiple Consumer) E2E relay server.
15///
16/// `HydraServer` implements a zero-copy tokio::broadcast relay that:
17/// - Accepts one producer and multiple consumers per session
18/// - Routes data from producer → all connected consumers using Arc-backed `Bytes`
19/// - Handles backpressure and slow consumers with broadcast channel `RecvError::Lagged(n)`
20/// - Enforces connection limits and per-payload size constraints
21/// ```no_run
22/// use hydra_sync::server::HydraServer;
23///
24/// #[tokio::main]
25/// async fn main() {
26/// // choose an OS-assigned port
27/// let (server, addr) = HydraServer::bind_default().await.unwrap();
28///
29/// println!("Server running on: {}", addr);
30/// tokio::spawn(async move{ server.run(500).await });
31/// }
32/// ```
33/// Internals
34/// - Producer: Sends encrypted frames → broadcast channel
35/// - Consumers: Subscribe to broadcast, receive clones of `Arc<Bytes>` (zero-copy)
36/// - Sessions: Keyed by 64-byte session_id, one producer per session allowed
37/// - Errors & Logs: Error are predictable and handled gracefully by closing connections and logging without crashing the server
38///
39pub struct HydraServer {
40 /// internal tcp listener for accepting incoming connections
41 listener: TcpListener,
42 /// session management for producers and consumers
43 sessions: Arc<Sessions>,
44 /// atomic counter to track active connections for enforcing limits
45 connections: Arc<AtomicUsize>,
46 /// maximum concurrent connections allowed to prevent resource exhaustion
47 max_connections: usize,
48 /// maximum allowed payload size for incoming frames to prevent abuse
49 max_payload_length: usize,
50 /// capacity of the broadcast channel for each session to handle backpressure
51 broadcast_capacity: usize,
52}
53
54impl HydraServer {
55 /// Binds the relay server with defaults
56 /// - addr: OS-assigned port
57 /// - max_connections: 32
58 /// - max_payload_length: 64 MiB
59 /// - broadcast_capacity: 256 messages
60 pub async fn bind_default() -> Result<(Self, SocketAddr)> {
61 let addr = &"127.0.0.1:0".parse::<SocketAddr>()?;
62 let server = HydraServer::bind(addr, 64 * 1024 * 1024, 32, 256).await?;
63 let local_addr = server.listener.local_addr()?;
64 Ok((server, local_addr))
65 }
66
67 /// Binds the relay server to the specified socket address and initializes internal state
68 pub async fn bind(
69 addr: &SocketAddr,
70 max_payload_length: usize,
71 max_connections: usize,
72 broadcast_capacity: usize,
73 ) -> Result<Self> {
74 let listener = TcpListener::bind(addr).await?;
75 Ok(Self {
76 listener,
77 sessions: Arc::new(Sessions::init()),
78 connections: Arc::new(AtomicUsize::new(0)),
79 max_payload_length,
80 max_connections,
81 broadcast_capacity,
82 })
83 }
84
85 /// Main server loop to accept incoming connections, spawn thread handlers, perform handshakes & session creation
86 /// - `connections_timeout_ms` is the delay before client retries to accept new connections on server when the limit is reached
87 /// - Producer errors; If read fails from client or broadcast send fails, the connection is closed and the error is logged.
88 /// - Producer errors; If writing to client fails or broadcast lags or closed, the connection is closed and the error is logged.
89 /// - EOF check are gracefully handled by closing the connection without logging an error.
90 /// - `LOG_LEVEL` & `LOG_FILE` env vars can be set to control logging verbosity and output file (defaults to `info` level and stdout, not file).
91 pub async fn run(self, connections_timeout_ms: u64) -> Result<()> {
92 loop {
93 if self.connections.fetch_add(1, Ordering::Relaxed) >= self.max_connections {
94 self.connections.fetch_sub(1, Ordering::Relaxed);
95 warn!(
96 "Max connections reached: {}, waiting {} ms before retrying",
97 self.max_connections, connections_timeout_ms
98 );
99 tokio::time::sleep(std::time::Duration::from_millis(connections_timeout_ms)).await;
100 continue;
101 }
102
103 match self.listener.accept().await {
104 Ok((stream, peer_addr)) => {
105 stream.set_nodelay(true).ok();
106 let sessions = Arc::clone(&self.sessions);
107 let connections = Arc::clone(&self.connections);
108 // spawn handler thread
109 tokio::spawn(async move {
110 trace!("Accepted connection from: {}", peer_addr);
111 if let Err(e) = Self::handle_connection(
112 stream,
113 sessions,
114 self.max_payload_length,
115 self.broadcast_capacity,
116 )
117 .await
118 {
119 error!("Connection handling error: {} from: {}", e, peer_addr);
120 }
121 connections.fetch_sub(1, Ordering::Release);
122 });
123 }
124 Err(e) => {
125 self.connections.fetch_sub(1, Ordering::Release);
126 error!("Connection accepting error: {}", e);
127 }
128 }
129 }
130 }
131
132 /// Handles an individual client connection, performing handshake, role determination, and routing to producer/consumer handlers
133 async fn handle_connection(
134 mut stream: TcpStream,
135 sessions: Arc<Sessions>,
136 max_payload_length: usize,
137 broadcast_capacity: usize,
138 ) -> Result<()> {
139 stream.set_nodelay(true)?;
140 let mut mem_pool = BytesMut::with_capacity(max_payload_length + 4); // 4 bytes prefix space
141 let peer_addr = stream.peer_addr()?;
142 let (read_h, mut writer_raw) = stream.split();
143 let mut reader = BufReader::with_capacity(BUFFER_SIZE, read_h);
144
145 let transport_key = perform_server_handshake(&mut reader, &mut writer_raw).await?;
146 let (role, session_id) =
147 read_join_frame(&mut reader, &transport_key, &mut mem_pool).await?;
148
149 match role {
150 Role::Producer => {
151 info!(
152 "Producer addr: {} joined session: {}",
153 peer_addr,
154 hex::encode(session_id)
155 );
156 Self::run_producer(
157 &mut reader,
158 sessions,
159 session_id,
160 &peer_addr,
161 mem_pool,
162 max_payload_length,
163 broadcast_capacity,
164 )
165 .await
166 }
167 Role::Consumer => {
168 info!(
169 "Consumer addr: {} joined session: {}",
170 peer_addr,
171 hex::encode(session_id)
172 );
173 Self::run_consumer(
174 &mut reader,
175 &mut writer_raw,
176 sessions,
177 session_id,
178 &peer_addr,
179 )
180 .await
181 }
182 Role::Admin => Ok(()), // todo; implement this
183 }
184 }
185
186 /// Handles producer clients: reads encrypted frames, decrypts, and broadcasts to consumers via the session's broadcast channel
187 async fn run_producer<R: AsyncReadExt + Unpin>(
188 reader: &mut R,
189 sessions: Arc<Sessions>,
190 session_id: [u8; 64],
191 client_addr: &SocketAddr,
192 mut mem_pool: BytesMut,
193 max_payload_length: usize,
194 broadcast_capacity: usize,
195 ) -> Result<()> {
196 let tx = sessions.try_register_producer(session_id, broadcast_capacity)?;
197
198 loop {
199 // read from client read stream (just channel, no intervention)
200 let n = match read_raw_frame_into(reader, &mut mem_pool, max_payload_length).await {
201 Ok(n) => n,
202 Err(e) => {
203 tx.closed().await;
204 error!(
205 "Producer addr: {} session: {} read: {e}",
206 client_addr,
207 hex::encode(session_id)
208 );
209 break;
210 }
211 };
212
213 // write to broadcast channel
214 if let Err(e) = tx.send(mem_pool.split_to(n).freeze()) {
215 tx.closed().await; // close channel to signal consumers
216 warn!(
217 "Producer addr: {} session: {} broadcast: {e}",
218 client_addr,
219 hex::encode(session_id)
220 );
221 break;
222 }
223 }
224
225 // clean up
226 sessions.unregister_producer(session_id);
227 Ok(())
228 }
229
230 /// Handles consumer clients: subscribes to the session's broadcast channel and writes received data to the client
231 async fn run_consumer<R: AsyncReadExt + Unpin, W: AsyncWriteExt + Unpin>(
232 reader: &mut R,
233 writer: &mut W,
234 sessions: Arc<Sessions>,
235 session_id: [u8; 64],
236 client_addr: &SocketAddr,
237 ) -> Result<()> {
238 let tx = sessions
239 .get_for_consumer(session_id)
240 .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
241
242 let mut rx = tx.subscribe();
243
244 let mut peek = [0u8; 1];
245 loop {
246 tokio::select! {
247 // poll from channel
248 result = rx.recv() => {
249 match result {
250 Ok(data) => {
251 // try writing to client read stream first or fail
252 if let Err(e) = writer.write_all(&data).await {
253 let _ = writer.shutdown().await;
254 error!("Consumer addr: {} session: {} write: {e}", client_addr, hex::encode(session_id));
255 break;
256 }
257 // let _ = writer.flush().await;
258 }
259 Err(RecvError::Lagged(n)) => {
260 let _ = writer.flush().await; // flush whatever remaining
261 let _ = writer.shutdown().await;
262 warn!("Consumer addr: {} session: {} lagged by {n} messages", client_addr, hex::encode(session_id));
263 break;
264 }
265 Err(RecvError::Closed) => {
266 let _ = writer.flush().await; // flush whatever b4 exiting
267 let _ = writer.shutdown().await;
268 info!("Producer for session: {} closed, consumer addr: {}", hex::encode(session_id), client_addr);
269 break;
270 },
271 }
272 }
273 result = reader.read(&mut peek) => {
274 match result {
275 Ok(0) => break, // eof check
276 Err(e) => {
277 error!("Consumer addr: {} session: {} read: {e}", client_addr, hex::encode(session_id));
278 break;
279 }
280 _ => {}
281 }
282 }
283 }
284 }
285
286 Ok(())
287 }
288}