Skip to main content

hydra_sync/
server.rs

1use crate::protocol::{Role, perform_server_handshake, read_join_frame, read_raw_frame_into};
2use crate::session::Sessions;
3use crate::{BUFFER_SIZE, error, info, trace, warn};
4use anyhow::Result;
5use bytes::BytesMut;
6use std::net::SocketAddr;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::broadcast::error::RecvError;
// TODO: handle backpressure properly, add handler traits for invoking user-defined functions on events, and improve logging.
13
14/// A light-weight multi-threaded SPMC (Single Producer Multiple Consumer) E2E relay server.
15///
16/// `HydraServer` implements a zero-copy broadcast relay that:
17/// - Accepts one producer and multiple consumers per session
18/// - Routes data from producer → all connected consumers using Arc-backed `Bytes`
19/// - Handles backpressure and slow consumers with broadcast channel lagging
20/// - Enforces connection limits and per-payload size constraints
21///
22/// Internals
23/// - Producer: Sends encrypted frames → broadcast channel
24/// - Consumers: Subscribe to broadcast, receive clones of `Arc<Bytes>` (zero-copy)
25/// - Sessions: Keyed by 64-byte session_id, one producer per session allowed
26/// - Errors & Logs: Error are predictable and handled gracefully by closing connections and logging without crashing the server
27pub struct HydraServer {
28    /// internal tcp listener for accepting incoming connections
29    listener: TcpListener,
30    /// session management for producers and consumers
31    sessions: Arc<Sessions>,
32    /// atomic counter to track active connections for enforcing limits
33    connections: Arc<AtomicUsize>,
34    /// maximum concurrent connections allowed to prevent resource exhaustion
35    max_connections: usize,
36    /// maximum allowed payload size for incoming frames to prevent abuse
37    max_payload_length: usize,
38    /// capacity of the broadcast channel for each session to handle backpressure
39    broadcast_capacity: usize,
40}
41
42impl HydraServer {
43    /// Binds the relay server with defaults
44    /// - addr: OS-assigned port
45    /// - max_connections: 32
46    /// - max_payload_length: 64 MiB
47    /// - broadcast_capacity: 256 messages
48    pub async fn bind_default() -> Result<(Self, SocketAddr)> {
49        let addr = &"127.0.0.1:0".parse::<SocketAddr>()?;
50        let server = HydraServer::bind(addr, 64 * 1024 * 1024, 32, 256).await?;
51        let local_addr = server.listener.local_addr()?;
52        Ok((server, local_addr))
53    }
54
55    /// Binds the relay server to the specified socket address and initializes internal state
56    pub async fn bind(
57        addr: &SocketAddr,
58        max_payload_length: usize,
59        max_connections: usize,
60        broadcast_capacity: usize,
61    ) -> Result<Self> {
62        let listener = TcpListener::bind(addr).await?;
63        Ok(Self {
64            listener,
65            sessions: Arc::new(Sessions::init()),
66            connections: Arc::new(AtomicUsize::new(0)),
67            max_payload_length,
68            max_connections,
69            broadcast_capacity,
70        })
71    }
72
73    /// Main server loop to accept incoming connections, spawn thread handlers, perform handshakes & session creation
74    /// - `connections_timeout_ms` is the delay before client retries to accept new connections on server when the limit is reached
75    /// - Producer errors; If read fails from client or broadcast send fails, the connection is closed and the error is logged.
76    /// - Producer errors; If writing to client fails or broadcast lags or closed, the connection is closed and the error is logged.
77    /// - EOF check are gracefully handled by closing the connection without logging an error.
78    /// - `LOG_LEVEL` & `LOG_FILE` env vars can be set to control logging verbosity and output file (defaults to `info` level and stdout, not file).
79    pub async fn run(self, connections_timeout_ms: u64) -> Result<()> {
80        loop {
81            if self.connections.fetch_add(1, Ordering::Relaxed) >= self.max_connections {
82                self.connections.fetch_sub(1, Ordering::Relaxed);
83                warn!(
84                    "Max connections reached: {}, waiting {} ms before retrying",
85                    self.max_connections, connections_timeout_ms
86                );
87                tokio::time::sleep(std::time::Duration::from_millis(connections_timeout_ms)).await;
88                continue;
89            }
90
91            match self.listener.accept().await {
92                Ok((stream, peer_addr)) => {
93                    stream.set_nodelay(true).ok();
94                    let sessions = Arc::clone(&self.sessions);
95                    let connections = Arc::clone(&self.connections);
96                    // spawn handler thread
97                    tokio::spawn(async move {
98                        trace!("Accepted connection from: {}", peer_addr);
99                        if let Err(e) = Self::handle_connection(
100                            stream,
101                            sessions,
102                            self.max_payload_length,
103                            self.broadcast_capacity,
104                        )
105                        .await
106                        {
107                            error!("Connection handling error: {} from: {}", e, peer_addr);
108                        }
109                        connections.fetch_sub(1, Ordering::Release);
110                    });
111                }
112                Err(e) => {
113                    self.connections.fetch_sub(1, Ordering::Release);
114                    error!("Connection accepting error: {}", e);
115                }
116            }
117        }
118    }
119
120    /// Handles an individual client connection, performing handshake, role determination, and routing to producer/consumer handlers
121    async fn handle_connection(
122        mut stream: TcpStream,
123        sessions: Arc<Sessions>,
124        max_payload_length: usize,
125        broadcast_capacity: usize,
126    ) -> Result<()> {
127        stream.set_nodelay(true)?;
128        let mut mem_pool = BytesMut::with_capacity(max_payload_length + 4); // 4 bytes prefix space 
129        let peer_addr = stream.peer_addr()?;
130        let (read_h, mut writer_raw) = stream.split();
131        let mut reader = BufReader::with_capacity(BUFFER_SIZE, read_h);
132
133        let transport_key = perform_server_handshake(&mut reader, &mut writer_raw).await?;
134        let (role, session_id) =
135            read_join_frame(&mut reader, &transport_key, &mut mem_pool).await?;
136
137        match role {
138            Role::Producer => {
139                info!(
140                    "Producer addr: {} joined session: {}",
141                    peer_addr,
142                    hex::encode(session_id)
143                );
144                Self::run_producer(
145                    &mut reader,
146                    sessions,
147                    session_id,
148                    &peer_addr,
149                    mem_pool,
150                    max_payload_length,
151                    broadcast_capacity,
152                )
153                .await
154            }
155            Role::Consumer => {
156                info!(
157                    "Consumer addr: {} joined session: {}",
158                    peer_addr,
159                    hex::encode(session_id)
160                );
161                Self::run_consumer(
162                    &mut reader,
163                    &mut writer_raw,
164                    sessions,
165                    session_id,
166                    &peer_addr,
167                )
168                .await
169            }
170            Role::Admin => Ok(()), // todo; implement this
171        }
172    }
173
174    /// Handles producer clients: reads encrypted frames, decrypts, and broadcasts to consumers via the session's broadcast channel
175    async fn run_producer<R: AsyncReadExt + Unpin>(
176        reader: &mut R,
177        sessions: Arc<Sessions>,
178        session_id: [u8; 64],
179        client_addr: &SocketAddr,
180        mut mem_pool: BytesMut,
181        max_payload_length: usize,
182        broadcast_capacity: usize,
183    ) -> Result<()> {
184        let tx = sessions.try_register_producer(session_id, broadcast_capacity)?;
185
186        loop {
187            // read from client read stream (just channel, no intervention)
188            let n = match read_raw_frame_into(reader, &mut mem_pool, max_payload_length).await {
189                Ok(n) => n,
190                Err(e) => {
191                    tx.closed().await;
192                    error!(
193                        "Producer addr: {} session: {} read: {e}",
194                        client_addr,
195                        hex::encode(session_id)
196                    );
197                    break;
198                }
199            };
200
201            // write to broadcast channel
202            if let Err(e) = tx.send(mem_pool.split_to(n).freeze()) {
203                tx.closed().await; // close channel to signal consumers
204                warn!(
205                    "Producer addr: {} session: {} broadcast: {e}",
206                    client_addr,
207                    hex::encode(session_id)
208                );
209                break;
210            }
211        }
212
213        // clean up
214        sessions.unregister_producer(session_id);
215        Ok(())
216    }
217
218    /// Handles consumer clients: subscribes to the session's broadcast channel and writes received data to the client
219    async fn run_consumer<R: AsyncReadExt + Unpin, W: AsyncWriteExt + Unpin>(
220        reader: &mut R,
221        writer: &mut W,
222        sessions: Arc<Sessions>,
223        session_id: [u8; 64],
224        client_addr: &SocketAddr,
225    ) -> Result<()> {
226        let tx = sessions
227            .get_for_consumer(session_id)
228            .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
229
230        let mut rx = tx.subscribe();
231
232        let mut peek = [0u8; 1];
233        loop {
234            tokio::select! {
235                // poll from channel
236                result = rx.recv() => {
237                    match result {
238                        Ok(data) => {
239                            // try writing to client read stream first or fail
240                            if let Err(e) = writer.write_all(&data).await {
241                                let _ = writer.shutdown().await;
242                                error!("Consumer addr: {} session: {} write: {e}", client_addr, hex::encode(session_id));
243                                break;
244                            }
245                            // let _ = writer.flush().await;
246                        }
247                        Err(RecvError::Lagged(n)) => {
248                            let _ = writer.flush().await; // flush whatever remaining
249                            let _ = writer.shutdown().await;
250                            warn!("Consumer addr: {} session: {} lagged by {n} messages", client_addr, hex::encode(session_id));
251                            break;
252                        }
253                        Err(RecvError::Closed) => {
254                            let _ = writer.flush().await; // flush whatever b4 exiting
255                            let _ = writer.shutdown().await;
256                            info!("Producer for session: {} closed, consumer addr: {}", hex::encode(session_id), client_addr);
257                            break;
258                        },
259                    }
260                }
261                result = reader.read(&mut peek) => {
262                    match result {
263                        Ok(0) => break, // eof check
264                        Err(e) => {
265                            error!("Consumer addr: {} session: {} read: {e}", client_addr, hex::encode(session_id));
266                            break;
267                        }
268                        _ => {}
269                    }
270                }
271            }
272        }
273
274        Ok(())
275    }
276}