1use crate::{
6 config::ServerConfig,
7 handler::{BoxedHandler, Handler},
8 connection::{Connection, ConnectionHandle},
9 rate_limit::RateLimitMiddleware,
10};
11use aerosocket_core::{Error, Result, Transport};
12use aerosocket_core::transport::TransportStream;
13use aerosocket_core::handshake::{HandshakeConfig, parse_client_handshake, validate_client_handshake, create_server_handshake, response_to_string};
14use aerosocket_core::error::ConfigError;
15use std::sync::Arc;
16use std::collections::HashMap;
17use std::net::SocketAddr;
18use tokio::sync::{Mutex, mpsc};
19use tokio::time::{timeout, Duration};
20
21pub struct Server {
23 config: ServerConfig,
24 handler: BoxedHandler,
25 rate_limiter: Option<Arc<RateLimitMiddleware>>,
26}
27
28#[derive(Debug)]
30pub struct ConnectionManager {
31 connections: Arc<Mutex<HashMap<u64, ConnectionHandle>>>,
32 next_id: Arc<Mutex<u64>>,
33}
34
35impl ConnectionManager {
36 pub fn new() -> Self {
38 Self {
39 connections: Arc::new(Mutex::new(HashMap::new())),
40 next_id: Arc::new(Mutex::new(1)),
41 }
42 }
43
44 pub async fn add_connection(&self, connection: Connection) -> u64 {
46 let mut next_id = self.next_id.lock().await;
47 let id = *next_id;
48 *next_id += 1;
49
50 let handle = ConnectionHandle::new(id, connection);
51 let mut connections = self.connections.lock().await;
52 connections.insert(id, handle);
53 id
54 }
55
56 pub async fn remove_connection(&self, id: u64) -> Option<ConnectionHandle> {
58 let mut connections = self.connections.lock().await;
59 connections.remove(&id)
60 }
61
62 pub async fn get_connection(&self, id: u64) -> Option<ConnectionHandle> {
64 let connections = self.connections.lock().await;
65 connections.get(&id).cloned()
66 }
67
68 pub async fn get_all_connections(&self) -> Vec<ConnectionHandle> {
70 let connections = self.connections.lock().await;
71 connections.values().cloned().collect()
72 }
73
74 pub async fn connection_count(&self) -> usize {
76 let connections = self.connections.lock().await;
77 connections.len()
78 }
79}
80
81impl std::fmt::Debug for Server {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("Server")
84 .field("config", &self.config)
85 .field("handler", &"<handler>")
86 .finish()
87 }
88}
89
90impl Server {
91 pub fn new(config: ServerConfig, handler: BoxedHandler) -> Self {
93 let rate_limiter = if config.backpressure.enabled {
94 Some(Arc::new(RateLimitMiddleware::new(crate::rate_limit::RateLimitConfig {
95 max_requests: config.backpressure.max_requests_per_minute,
96 window: Duration::from_secs(60),
97 max_connections: config.max_connections / 10, connection_timeout: config.idle_timeout,
99 })))
100 } else {
101 None
102 };
103
104 Self {
105 config,
106 handler,
107 rate_limiter,
108 }
109 }
110
111 pub fn builder() -> ServerBuilder {
113 ServerBuilder::new()
114 }
115
116 pub async fn serve(self) -> Result<()> {
118 let connection_manager = Arc::new(ConnectionManager::new());
119 self.serve_with_connection_manager(connection_manager).await
120 }
121
122 pub async fn serve_with_graceful_shutdown<F>(self, shutdown_signal: F) -> Result<()>
124 where
125 F: std::future::Future<Output = ()> + Send + 'static,
126 {
127 let connection_manager = Arc::new(ConnectionManager::new());
128 let shutdown_signal = Box::pin(shutdown_signal);
129 self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_signal).await
130 }
131
132 async fn serve_with_connection_manager(self, connection_manager: Arc<ConnectionManager>) -> Result<()> {
134 let (_shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
135 let shutdown_rx = Box::pin(async move {
136 let _ = shutdown_rx.recv().await;
137 });
138
139 self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_rx).await
140 }
141
142 async fn serve_with_connection_manager_and_shutdown<F>(
144 self,
145 _connection_manager: Arc<ConnectionManager>,
146 _shutdown_signal: F,
147 ) -> Result<()>
148 where
149 F: std::future::Future<Output = ()> + Send + Unpin + 'static,
150 {
151 #[cfg(feature = "tcp-transport")]
153 {
154 if self.config.transport_type == crate::config::TransportType::Tcp {
155 let transport = crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
156 return self.serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal).await;
157 }
158 }
159
160 #[cfg(feature = "tls-transport")]
161 {
162 if self.config.transport_type == crate::config::TransportType::Tls {
163 let _tls_config = self.config.tls.as_ref()
164 .ok_or_else(|| Error::Other("TLS configuration required for TLS transport".to_string()))?;
165 return Err(Error::Other("TLS transport is not available in this release. Please enable the 'tls-transport' feature and implement proper TLS configuration.".to_string()));
167 }
168 }
169
170 if self.config.transport_type == crate::config::TransportType::Tcp {
172 #[cfg(feature = "tcp-transport")]
173 {
174 let transport = crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
175 return self.serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal).await;
176 }
177 }
178 Err(Error::Config(ConfigError::Validation("No transport available".to_string())))
179 }
180
181 #[cfg(feature = "tcp-transport")]
183 async fn serve_with_tcp_transport<F>(
184 self,
185 transport: crate::tcp_transport::TcpTransport,
186 connection_manager: Arc<ConnectionManager>,
187 _shutdown_signal: F,
188 ) -> Result<()>
189 where
190 F: std::future::Future<Output = ()> + Send + 'static,
191 {
192 let handler = self.handler;
194 let config = self.config.clone();
195 let manager = connection_manager.clone();
196 let rate_limiter = self.rate_limiter.clone();
197
198 let server_task = tokio::spawn(async move {
199 let mut connection_counter = 0u64;
200
201 loop {
202 tokio::select! {
204 result = transport.accept() => {
205 match result {
206 Ok(mut stream) => {
207 let remote_addr = match stream.remote_addr() {
209 Ok(addr) => addr.ip(),
210 Err(e) => {
211 crate::log_error!("Failed to get remote address: {:?}", e);
212 let _ = stream.close().await;
213 continue;
214 }
215 };
216
217 if let Some(ref rate_limiter) = rate_limiter {
219 if !rate_limiter.check_connection(remote_addr).await.unwrap_or(true) {
220 crate::log_warn!("Rate limit exceeded for IP: {}", remote_addr);
221 let _ = stream.close().await;
222 continue;
223 }
224 }
225
226 if manager.connection_count().await >= config.max_connections {
228 crate::log_warn!("Connection limit reached, rejecting connection from {}", remote_addr);
229 let _ = stream.close().await;
231 continue;
232 }
233
234 connection_counter += 1;
235 let manager = manager.clone();
236 let handler = handler.clone();
237 let config = config.clone();
238 let rate_limiter = rate_limiter.clone();
239
240 tokio::spawn(async move {
242 if let Err(e) = Self::handle_connection(
243 stream,
244 handler,
245 config,
246 manager,
247 rate_limiter,
248 ).await {
249 crate::log_error!("Connection handling error: {:?}", e);
250 }
251 });
252 }
253 Err(e) => {
254 crate::log_error!("Accept error: {:?}", e);
255 }
257 }
258 }
259 _ = tokio::signal::ctrl_c() => {
260 break;
261 }
262 }
263 }
264 });
265
266 match server_task.await {
268 Ok(()) => Ok(()),
269 Err(e) => Err(Error::Other(format!("Server task panicked: {}", e))),
270 }
271 }
272
273 #[cfg(feature = "tls-transport")]
275 async fn serve_with_tls_transport<F>(
276 self,
277 _transport: crate::tls_transport::TlsTransport,
278 _connection_manager: Arc<ConnectionManager>,
279 _shutdown_signal: F,
280 ) -> Result<()>
281 where
282 F: std::future::Future<Output = ()> + Send + Unpin + 'static,
283 {
284 Err(Error::Other("TLS transport is not available in this release".to_string()))
285 }
286
287 #[cfg(feature = "tls-transport")]
289 async fn handle_tls_connection(
290 _stream: crate::tls_transport::TlsStreamWrapper,
291 _handler: BoxedHandler,
292 _config: ServerConfig,
293 _connection_manager: Arc<ConnectionManager>,
294 _rate_limiter: Option<Arc<RateLimitMiddleware>>,
295 ) -> Result<()> {
296 Err(Error::Other("TLS connection handling is not available in this release".to_string()))
297 }
298
299 async fn handle_connection(
301 mut stream: crate::tcp_transport::TcpStream,
302 handler: BoxedHandler,
303 config: ServerConfig,
304 connection_manager: Arc<ConnectionManager>,
305 rate_limiter: Option<Arc<RateLimitMiddleware>>,
306 ) -> Result<()> {
307 let (remote_addr, local_addr) = Self::perform_handshake(&mut stream, &config).await?;
309
310 let boxed_stream: Box<dyn TransportStream> = Box::new(stream);
312
313 let connection = Connection::with_stream(remote_addr, local_addr, boxed_stream);
315
316 let connection_id = connection_manager.add_connection(connection).await;
318
319 let connection_handle = connection_manager.get_connection(connection_id).await
321 .ok_or_else(|| Error::Other("Failed to get connection handle".to_string()))?;
322
323 if let Err(e) = handler.handle(connection_handle).await {
325 crate::log_error!("Handler error: {:?}", e);
326 }
327
328 connection_manager.remove_connection(connection_id).await;
330
331 if let Some(ref rate_limiter) = rate_limiter {
333 rate_limiter.connection_closed(remote_addr.ip()).await;
334 }
335
336 Ok(())
337 }
338
339 #[cfg(feature = "tls-transport")]
341 async fn perform_tls_handshake(
342 stream: &mut crate::tls_transport::TlsStreamWrapper,
343 config: &ServerConfig,
344 ) -> Result<(SocketAddr, SocketAddr)> {
345 let request_data = Self::read_tls_handshake_request(stream, config.handshake_timeout).await?;
347 let request_str = String::from_utf8_lossy(&request_data);
348
349 let request = parse_client_handshake(&request_str)?;
351
352 let handshake_config = HandshakeConfig {
354 protocols: config.supported_protocols.clone(),
355 extensions: config.supported_extensions.clone(),
356 origin: config.allowed_origin.clone(),
357 host: None,
358 extra_headers: config.extra_headers.clone(),
359 };
360
361 validate_client_handshake(&request, &handshake_config)?;
363
364 let response = create_server_handshake(&request, &handshake_config)?;
366 let response_str = response_to_string(&response);
367
368 stream.write_all(response_str.as_bytes()).await?;
370 stream.flush().await?;
371
372 let remote_addr = stream.remote_addr()?;
374 let local_addr = stream.local_addr()?;
375
376 Ok((remote_addr, local_addr))
377 }
378
379 #[cfg(feature = "tls-transport")]
381 async fn read_tls_handshake_request(
382 stream: &mut crate::tls_transport::TlsStreamWrapper,
383 timeout_duration: Duration,
384 ) -> Result<Vec<u8>> {
385 let mut buffer = Vec::new();
386 let mut temp_buffer = [0u8; 1024];
387
388 let read_result = timeout(timeout_duration, async {
389 loop {
390 let n = stream.read(&mut temp_buffer).await?;
391 if n == 0 {
392 break;
393 }
394
395 buffer.extend_from_slice(&temp_buffer[..n]);
396
397 if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
399 break;
400 }
401
402 if buffer.len() > 8192 {
404 return Err(Error::Other("TLS handshake request too large".to_string()));
405 }
406 }
407
408 Ok::<(), Error>(())
409 }).await;
410
411 match read_result {
412 Ok(result) => {
413 result?;
414 Ok(buffer)
415 },
416 Err(_) => Err(Error::Other("TLS handshake timeout".to_string())),
417 }
418 }
419
420 async fn perform_handshake(
422 stream: &mut crate::tcp_transport::TcpStream,
423 config: &ServerConfig,
424 ) -> Result<(SocketAddr, SocketAddr)> {
425 let request_data = Self::read_handshake_request(stream, config.handshake_timeout).await?;
427 let request_str = String::from_utf8_lossy(&request_data);
428
429 let request = parse_client_handshake(&request_str)?;
431
432 let handshake_config = HandshakeConfig {
434 protocols: config.supported_protocols.clone(),
435 extensions: config.supported_extensions.clone(),
436 origin: config.allowed_origin.clone(),
437 host: None,
438 extra_headers: config.extra_headers.clone(),
439 };
440
441 validate_client_handshake(&request, &handshake_config)?;
443
444 let response = create_server_handshake(&request, &handshake_config)?;
446 let response_str = response_to_string(&response);
447
448 stream.write_all(response_str.as_bytes()).await?;
450 stream.flush().await?;
451
452 let remote_addr = stream.remote_addr()?;
454 let local_addr = stream.local_addr()?;
455
456 Ok((remote_addr, local_addr))
457 }
458
459 async fn read_handshake_request(
461 stream: &mut crate::tcp_transport::TcpStream,
462 timeout_duration: Duration,
463 ) -> Result<Vec<u8>> {
464 let mut buffer = Vec::new();
465 let mut temp_buffer = [0u8; 1024];
466
467 let read_result = timeout(timeout_duration, async {
468 loop {
469 let n = stream.read(&mut temp_buffer).await?;
470 if n == 0 {
471 break;
472 }
473
474 buffer.extend_from_slice(&temp_buffer[..n]);
475
476 if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
478 break;
479 }
480
481 if buffer.len() > 8192 {
483 return Err(Error::Other("Handshake request too large".to_string()));
484 }
485 }
486
487 Ok::<(), Error>(())
488 }).await;
489
490 match read_result {
491 Ok(result) => {
492 result?;
493 Ok(buffer)
494 },
495 Err(_) => Err(Error::Other("Handshake timeout".to_string())),
496 }
497 }
498
499 async fn graceful_shutdown(&self, connection_manager: Arc<ConnectionManager>) -> Result<()> {
501 let connections = connection_manager.get_all_connections().await;
503
504 for handle in connections {
506 if let Ok(mut connection) = handle.try_lock().await {
507 let _ = connection.close(Some(1000), Some("Server shutdown")).await;
508 }
509 }
510
511 Ok(())
512 }
513}
514
515#[derive(Debug, Clone)]
517pub struct ServerBuilder {
518 config: ServerConfig,
519}
520
521impl ServerBuilder {
522 pub fn new() -> Self {
524 Self {
525 config: ServerConfig::default(),
526 }
527 }
528
529 pub fn bind<A: std::net::ToSocketAddrs>(mut self, addr: A) -> Result<Self> {
531 self.config.bind_address = addr.to_socket_addrs()?.next().ok_or_else(|| {
532 Error::Config(ConfigError::Validation("Invalid bind address".to_string()))
533 })?;
534 Ok(self)
535 }
536
537 pub fn max_connections(mut self, max: usize) -> Self {
539 self.config.max_connections = max;
540 self
541 }
542
543 pub fn max_frame_size(mut self, size: usize) -> Self {
545 self.config.max_frame_size = size;
546 self
547 }
548
549 pub fn max_message_size(mut self, size: usize) -> Self {
551 self.config.max_message_size = size;
552 self
553 }
554
555 pub fn handshake_timeout(mut self, timeout: std::time::Duration) -> Self {
557 self.config.handshake_timeout = timeout;
558 self
559 }
560
561 pub fn idle_timeout(mut self, timeout: std::time::Duration) -> Self {
563 self.config.idle_timeout = timeout;
564 self
565 }
566
567 pub fn compression(mut self, enabled: bool) -> Self {
569 self.config.compression.enabled = enabled;
570 self
571 }
572
573 pub fn backpressure(mut self, strategy: crate::config::BackpressureStrategy) -> Self {
575 self.config.backpressure.strategy = strategy;
576 self
577 }
578
579 pub fn build(self) -> Result<Server> {
581 self.config.validate()?;
583
584 let handler = Box::new(crate::handler::DefaultHandler::new());
586
587 Ok(Server::new(self.config, handler))
588 }
589
590 pub fn build_with_handler<H>(self, handler: H) -> Result<Server>
592 where
593 H: Handler + Send + Sync + 'static,
594 {
595 self.config.validate()?;
597
598 Ok(Server::new(self.config, Box::new(handler)))
599 }
600}
601
602impl Default for ServerBuilder {
603 fn default() -> Self {
604 Self::new()
605 }
606}
607
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully specified builder with a valid address produces a server.
    #[test]
    fn test_server_builder() {
        let builder = ServerBuilder::new()
            .bind("127.0.0.1:8080")
            .unwrap()
            .max_connections(1000)
            .max_frame_size(1024 * 1024)
            .compression(true);

        assert!(builder.build().is_ok());
    }

    /// An address with no port is rejected by `bind`. It is parsed locally, with no DNS lookup.
    #[test]
    fn test_bind_rejects_malformed_address() {
        assert!(ServerBuilder::new().bind("not an address").is_err());
    }

    /// A fresh manager has no registered connections.
    #[test]
    fn test_connection_manager_starts_empty() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build a tokio runtime");
        let manager = ConnectionManager::new();

        runtime.block_on(async {
            assert_eq!(manager.connection_count().await, 0);
            assert!(manager.get_all_connections().await.is_empty());
            assert!(manager.get_connection(1).await.is_none());
        });
    }
}