Skip to main content

hydra_sync/
server.rs

1use crate::protocol::{Role, perform_server_handshake, read_join_frame, read_raw_frame_into};
2use crate::session::Sessions;
3use crate::{BUFFER_SIZE, error, info, trace, warn};
4use anyhow::Result;
5use bytes::BytesMut;
6use std::net::SocketAddr;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::broadcast::error::RecvError;
// TODO: handle backpressure properly, implement handler traits for invoking user-defined functions on some events, and improve logging.
13
14/// A light-weight multi-threaded SPMC (Single Producer Multiple Consumer) E2E relay server.
15///
16/// `HydraServer` implements a zero-copy tokio::broadcast relay that:
17/// - Accepts one producer and multiple consumers per session
18/// - Routes data from producer → all connected consumers using Arc-backed `Bytes`
19/// - Handles backpressure and slow consumers with broadcast channel `RecvError::Lagged(n)`
20/// - Enforces connection limits and per-payload size constraints
21/// ```no_run
22/// use hydra_sync::server::HydraServer;
23///
24/// #[tokio::main]
25/// async fn main() {
26/// // choose an OS-assigned port
27///     let (server, addr) = HydraServer::bind_default().await.unwrap();
28///
29///     println!("Server running on: {}", addr);
30///     tokio::spawn(async move{ server.run(500).await });
31/// }
32/// ```
33/// Internals
34/// - Producer: Sends encrypted frames → broadcast channel
35/// - Consumers: Subscribe to broadcast, receive clones of `Arc<Bytes>` (zero-copy)
36/// - Sessions: Keyed by 64-byte session_id, one producer per session allowed
37/// - Errors & Logs: Error are predictable and handled gracefully by closing connections and logging without crashing the server
38///
39pub struct HydraServer {
40    /// internal tcp listener for accepting incoming connections
41    listener: TcpListener,
42    /// session management for producers and consumers
43    sessions: Arc<Sessions>,
44    /// atomic counter to track active connections for enforcing limits
45    connections: Arc<AtomicUsize>,
46    /// maximum concurrent connections allowed to prevent resource exhaustion
47    max_connections: usize,
48    /// maximum allowed payload size for incoming frames to prevent abuse
49    max_payload_length: usize,
50    /// capacity of the broadcast channel for each session to handle backpressure
51    broadcast_capacity: usize,
52}
53
54impl HydraServer {
55    /// Binds the relay server with defaults
56    /// - addr: OS-assigned port
57    /// - max_connections: 32
58    /// - max_payload_length: 64 MiB
59    /// - broadcast_capacity: 256 messages
60    pub async fn bind_default() -> Result<(Self, SocketAddr)> {
61        let addr = &"127.0.0.1:0".parse::<SocketAddr>()?;
62        let server = HydraServer::bind(addr, 64 * 1024 * 1024, 32, 256).await?;
63        let local_addr = server.listener.local_addr()?;
64        Ok((server, local_addr))
65    }
66
67    /// Binds the relay server to the specified socket address and initializes internal state
68    pub async fn bind(
69        addr: &SocketAddr,
70        max_payload_length: usize,
71        max_connections: usize,
72        broadcast_capacity: usize,
73    ) -> Result<Self> {
74        let listener = TcpListener::bind(addr).await?;
75        Ok(Self {
76            listener,
77            sessions: Arc::new(Sessions::init()),
78            connections: Arc::new(AtomicUsize::new(0)),
79            max_payload_length,
80            max_connections,
81            broadcast_capacity,
82        })
83    }
84
85    /// Main server loop to accept incoming connections, spawn thread handlers, perform handshakes & session creation
86    /// - `connections_timeout_ms` is the delay before client retries to accept new connections on server when the limit is reached
87    /// - Producer errors; If read fails from client or broadcast send fails, the connection is closed and the error is logged.
88    /// - Producer errors; If writing to client fails or broadcast lags or closed, the connection is closed and the error is logged.
89    /// - EOF check are gracefully handled by closing the connection without logging an error.
90    /// - `LOG_LEVEL` & `LOG_FILE` env vars can be set to control logging verbosity and output file (defaults to `info` level and stdout, not file).
91    pub async fn run(self, connections_timeout_ms: u64) -> Result<()> {
92        loop {
93            if self.connections.fetch_add(1, Ordering::Relaxed) >= self.max_connections {
94                self.connections.fetch_sub(1, Ordering::Relaxed);
95                warn!(
96                    "Max connections reached: {}, waiting {} ms before retrying",
97                    self.max_connections, connections_timeout_ms
98                );
99                tokio::time::sleep(std::time::Duration::from_millis(connections_timeout_ms)).await;
100                continue;
101            }
102
103            match self.listener.accept().await {
104                Ok((stream, peer_addr)) => {
105                    stream.set_nodelay(true).ok();
106                    let sessions = Arc::clone(&self.sessions);
107                    let connections = Arc::clone(&self.connections);
108                    // spawn handler thread
109                    tokio::spawn(async move {
110                        trace!("Accepted connection from: {}", peer_addr);
111                        if let Err(e) = Self::handle_connection(
112                            stream,
113                            sessions,
114                            self.max_payload_length,
115                            self.broadcast_capacity,
116                        )
117                        .await
118                        {
119                            error!("Connection handling error: {} from: {}", e, peer_addr);
120                        }
121                        connections.fetch_sub(1, Ordering::Release);
122                    });
123                }
124                Err(e) => {
125                    self.connections.fetch_sub(1, Ordering::Release);
126                    error!("Connection accepting error: {}", e);
127                }
128            }
129        }
130    }
131
132    /// Handles an individual client connection, performing handshake, role determination, and routing to producer/consumer handlers
133    async fn handle_connection(
134        mut stream: TcpStream,
135        sessions: Arc<Sessions>,
136        max_payload_length: usize,
137        broadcast_capacity: usize,
138    ) -> Result<()> {
139        stream.set_nodelay(true)?;
140        let mut mem_pool = BytesMut::with_capacity(max_payload_length + 4); // 4 bytes prefix space 
141        let peer_addr = stream.peer_addr()?;
142        let (read_h, mut writer_raw) = stream.split();
143        let mut reader = BufReader::with_capacity(BUFFER_SIZE, read_h);
144
145        let transport_key = perform_server_handshake(&mut reader, &mut writer_raw).await?;
146        let (role, session_id) =
147            read_join_frame(&mut reader, &transport_key, &mut mem_pool).await?;
148
149        match role {
150            Role::Producer => {
151                info!(
152                    "Producer addr: {} joined session: {}",
153                    peer_addr,
154                    hex::encode(session_id)
155                );
156                Self::run_producer(
157                    &mut reader,
158                    sessions,
159                    session_id,
160                    &peer_addr,
161                    mem_pool,
162                    max_payload_length,
163                    broadcast_capacity,
164                )
165                .await
166            }
167            Role::Consumer => {
168                info!(
169                    "Consumer addr: {} joined session: {}",
170                    peer_addr,
171                    hex::encode(session_id)
172                );
173                Self::run_consumer(
174                    &mut reader,
175                    &mut writer_raw,
176                    sessions,
177                    session_id,
178                    &peer_addr,
179                )
180                .await
181            }
182            Role::Admin => Ok(()), // todo; implement this
183        }
184    }
185
186    /// Handles producer clients: reads encrypted frames, decrypts, and broadcasts to consumers via the session's broadcast channel
187    async fn run_producer<R: AsyncReadExt + Unpin>(
188        reader: &mut R,
189        sessions: Arc<Sessions>,
190        session_id: [u8; 64],
191        client_addr: &SocketAddr,
192        mut mem_pool: BytesMut,
193        max_payload_length: usize,
194        broadcast_capacity: usize,
195    ) -> Result<()> {
196        let tx = sessions.try_register_producer(session_id, broadcast_capacity)?;
197
198        loop {
199            // read from client read stream (just channel, no intervention)
200            let n = match read_raw_frame_into(reader, &mut mem_pool, max_payload_length).await {
201                Ok(n) => n,
202                Err(e) => {
203                    tx.closed().await;
204                    error!(
205                        "Producer addr: {} session: {} read: {e}",
206                        client_addr,
207                        hex::encode(session_id)
208                    );
209                    break;
210                }
211            };
212
213            // write to broadcast channel
214            if let Err(e) = tx.send(mem_pool.split_to(n).freeze()) {
215                tx.closed().await; // close channel to signal consumers
216                warn!(
217                    "Producer addr: {} session: {} broadcast: {e}",
218                    client_addr,
219                    hex::encode(session_id)
220                );
221                break;
222            }
223        }
224
225        // clean up
226        sessions.unregister_producer(session_id);
227        Ok(())
228    }
229
230    /// Handles consumer clients: subscribes to the session's broadcast channel and writes received data to the client
231    async fn run_consumer<R: AsyncReadExt + Unpin, W: AsyncWriteExt + Unpin>(
232        reader: &mut R,
233        writer: &mut W,
234        sessions: Arc<Sessions>,
235        session_id: [u8; 64],
236        client_addr: &SocketAddr,
237    ) -> Result<()> {
238        let tx = sessions
239            .get_for_consumer(session_id)
240            .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
241
242        let mut rx = tx.subscribe();
243
244        let mut peek = [0u8; 1];
245        loop {
246            tokio::select! {
247                // poll from channel
248                result = rx.recv() => {
249                    match result {
250                        Ok(data) => {
251                            // try writing to client read stream first or fail
252                            if let Err(e) = writer.write_all(&data).await {
253                                let _ = writer.shutdown().await;
254                                error!("Consumer addr: {} session: {} write: {e}", client_addr, hex::encode(session_id));
255                                break;
256                            }
257                            // let _ = writer.flush().await;
258                        }
259                        Err(RecvError::Lagged(n)) => {
260                            let _ = writer.flush().await; // flush whatever remaining
261                            let _ = writer.shutdown().await;
262                            warn!("Consumer addr: {} session: {} lagged by {n} messages", client_addr, hex::encode(session_id));
263                            break;
264                        }
265                        Err(RecvError::Closed) => {
266                            let _ = writer.flush().await; // flush whatever b4 exiting
267                            let _ = writer.shutdown().await;
268                            info!("Producer for session: {} closed, consumer addr: {}", hex::encode(session_id), client_addr);
269                            break;
270                        },
271                    }
272                }
273                result = reader.read(&mut peek) => {
274                    match result {
275                        Ok(0) => break, // eof check
276                        Err(e) => {
277                            error!("Consumer addr: {} session: {} read: {e}", client_addr, hex::encode(session_id));
278                            break;
279                        }
280                        _ => {}
281                    }
282                }
283            }
284        }
285
286        Ok(())
287    }
288}