hydra_sync/server.rs
1use crate::BUFFER_SIZE;
2use crate::protocol::{Role, perform_server_handshake, read_join_frame, read_raw_frame_into};
3use crate::session::Sessions;
4use anyhow::Result;
5use bytes::BytesMut;
6use std::net::SocketAddr;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::broadcast::error::RecvError;
// TODO: handle backpressure properly; implement handler traits so user-defined functions can be invoked on selected events.
13
14/// A light-weight multi-threaded SPMC (Single Producer Multiple Consumer) E2E relay server.
15///
16/// `HydraServer` implements a zero-copy broadcast relay that:
17/// - Accepts one producer and multiple consumers per session
18/// - Routes data from producer → all connected consumers using Arc-backed `Bytes`
19/// - Handles backpressure and slow consumers with broadcast channel lagging
20/// - Enforces connection limits and per-payload size constraints
21///
22/// Internals
23/// - Producer: Sends encrypted frames → broadcast channel
24/// - Consumers: Subscribe to broadcast, receive clones of `Arc<Bytes>` (zero-copy)
25/// - Sessions: Keyed by 64-byte session_id, one producer per session allowed
26/// - Errors & Logs: Error are predictable and handled gracefully by closing connections and logging without crashing the server
27pub struct HydraServer {
28 /// internal tcp listener for accepting incoming connections
29 listener: TcpListener,
30 /// session management for producers and consumers
31 sessions: Arc<Sessions>,
32 /// atomic counter to track active connections for enforcing limits
33 connections: Arc<AtomicUsize>,
34 /// maximum concurrent connections allowed to prevent resource exhaustion
35 max_connections: usize,
36 /// maximum allowed payload size for incoming frames to prevent abuse
37 max_payload_length: usize,
38 /// capacity of the broadcast channel for each session to handle backpressure
39 broadcast_capacity: usize,
40}
41
42impl HydraServer {
43 /// Binds the relay server with defaults
44 /// - addr: OS-assigned port
45 /// - max_connections: 32
46 /// - max_payload_length: 64 MiB
47 /// - broadcast_capacity: 256 messages
48 pub async fn bind_default() -> Result<(Self, SocketAddr)> {
49 let addr = &"127.0.0.1:0".parse::<SocketAddr>()?;
50 let server = HydraServer::bind(addr, 64 * 1024 * 1024, 32, 256).await?;
51 let local_addr = server.listener.local_addr()?;
52 Ok((server, local_addr))
53 }
54
55 /// Binds the relay server to the specified socket address and initializes internal state
56 pub async fn bind(
57 addr: &SocketAddr,
58 max_payload_length: usize,
59 max_connections: usize,
60 broadcast_capacity: usize,
61 ) -> Result<Self> {
62 let listener = TcpListener::bind(addr).await?;
63 Ok(Self {
64 listener,
65 sessions: Arc::new(Sessions::init()),
66 connections: Arc::new(AtomicUsize::new(0)),
67 max_payload_length,
68 max_connections,
69 broadcast_capacity,
70 })
71 }
72
73 /// Main server loop to accept incoming connections, spawn thread handlers, perform handshakes & session creation
74 /// - `connections_timeout_ms` is the delay before client retries to accept new connections on server when the limit is reached
75 /// - Producer errors; If read fails from client or broadcast send fails, the connection is closed and the error is logged.
76 /// - Producer errors; If writing to client fails or broadcast lags or closed, the connection is closed and the error is logged.
77 /// - EOF check are gracefully handled by closing the connection without logging an error.
78 pub async fn run(self, connections_timeout_ms: u64) -> Result<()> {
79 loop {
80 if self.connections.fetch_add(1, Ordering::Acquire) >= self.max_connections {
81 self.connections.fetch_sub(1, Ordering::Release);
82 tokio::time::sleep(std::time::Duration::from_millis(connections_timeout_ms)).await;
83 continue;
84 }
85
86 match self.listener.accept().await {
87 Ok((stream, peer_addr)) => {
88 stream.set_nodelay(true).ok();
89 let sessions = Arc::clone(&self.sessions);
90 let connections = Arc::clone(&self.connections);
91 // spawn handler thread
92 tokio::spawn(async move {
93 if let Err(e) = Self::handle_connection(
94 stream,
95 sessions,
96 self.max_payload_length,
97 self.broadcast_capacity,
98 )
99 .await
100 {
101 eprintln!("Connection handling error: {} from: {}", e, peer_addr);
102 }
103 connections.fetch_sub(1, Ordering::Release);
104 });
105 }
106 Err(e) => {
107 self.connections.fetch_sub(1, Ordering::Release);
108 eprintln!("Connection accepting error: {}", e);
109 }
110 }
111 }
112 }
113
114 /// Handles an individual client connection, performing handshake, role determination, and routing to producer/consumer handlers
115 async fn handle_connection(
116 mut stream: TcpStream,
117 sessions: Arc<Sessions>,
118 max_payload_length: usize,
119 broadcast_capacity: usize,
120 ) -> Result<()> {
121 stream.set_nodelay(true)?;
122 let mut mem_pool = BytesMut::with_capacity(max_payload_length + 4); // 4 bytes prefix space
123 let (read_h, mut writer_raw) = stream.split();
124 let mut reader = BufReader::with_capacity(BUFFER_SIZE, read_h);
125
126 let transport_key = perform_server_handshake(&mut reader, &mut writer_raw).await?;
127 let (role, session_id) =
128 read_join_frame(&mut reader, &transport_key, &mut mem_pool).await?;
129
130 match role {
131 Role::Producer => {
132 Self::run_producer(
133 &mut reader,
134 sessions,
135 session_id,
136 mem_pool,
137 max_payload_length,
138 broadcast_capacity,
139 )
140 .await
141 }
142 Role::Consumer => {
143 Self::run_consumer(&mut reader, &mut writer_raw, sessions, session_id).await
144 }
145 Role::Admin => Ok(()), // todo; implement this
146 }
147 }
148
149 /// Handles producer clients: reads encrypted frames, decrypts, and broadcasts to consumers via the session's broadcast channel
150 async fn run_producer<R: AsyncReadExt + Unpin>(
151 reader: &mut R,
152 sessions: Arc<Sessions>,
153 session_id: [u8; 64],
154 mut mem_pool: BytesMut,
155 max_payload_length: usize,
156 broadcast_capacity: usize,
157 ) -> Result<()> {
158 let tx = sessions.try_register_producer(session_id, broadcast_capacity)?;
159
160 loop {
161 // read from client read stream (just channel, no intervention)
162 let n = match read_raw_frame_into(reader, &mut mem_pool, max_payload_length).await {
163 Ok(n) => n,
164 Err(e) => {
165 tx.closed().await;
166 eprintln!("Producer read: {e}");
167 break;
168 }
169 };
170
171 // write to broadcast channel
172 if let Err(e) = tx.send(mem_pool.split_to(n).freeze()) {
173 tx.closed().await; // close channel to signal consumers
174 eprintln!("Producer broadcast: {e}");
175 break;
176 }
177 }
178
179 // clean up
180 sessions.unregister_producer(session_id);
181 Ok(())
182 }
183
184 /// Handles consumer clients: subscribes to the session's broadcast channel and writes received data to the client
185 async fn run_consumer<R: AsyncReadExt + Unpin, W: AsyncWriteExt + Unpin>(
186 reader: &mut R,
187 writer: &mut W,
188 sessions: Arc<Sessions>,
189 session_id: [u8; 64],
190 ) -> Result<()> {
191 let tx = sessions
192 .get_for_consumer(session_id)
193 .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
194
195 let mut rx = tx.subscribe();
196
197 let mut peek = [0u8; 1];
198 loop {
199 tokio::select! {
200 // poll from channel
201 result = rx.recv() => {
202 match result {
203 Ok(data) => {
204 // try writing to client read stream first or fail
205 if let Err(e) = writer.write_all(&data).await {
206 let _ = writer.shutdown().await;
207 eprintln!("Consumer write: {e}");
208 break;
209 }
210 // let _ = writer.flush().await;
211 }
212 Err(RecvError::Lagged(n)) => {
213 let _ = writer.flush().await; // flush whatever remaining
214 let _ = writer.shutdown().await;
215 eprintln!("Consumer lagged behind: {n}");
216 break;
217 }
218 Err(RecvError::Closed) => {
219 let _ = writer.flush().await; // flush whatever b4 exiting
220 let _ = writer.shutdown().await;
221 eprintln!("Producer closed");
222 break;
223 },
224 }
225 }
226 result = reader.read(&mut peek) => {
227 match result {
228 Ok(0) => break, // eof check
229 Err(e) => {
230 eprintln!("Consumer read: {e}");
231 break;
232 }
233 _ => {}
234 }
235 }
236 }
237 }
238
239 Ok(())
240 }
241}