1use crate::{
6 config::ServerConfig,
7 connection::{Connection, ConnectionHandle},
8 handler::{BoxedHandler, Handler},
9 rate_limit::RateLimitMiddleware,
10};
11use aerosocket_core::error::ConfigError;
12use aerosocket_core::handshake::{
13 create_server_handshake, parse_client_handshake, response_to_string, validate_client_handshake,
14 HandshakeConfig,
15};
16use aerosocket_core::transport::TransportStream;
17use aerosocket_core::{Error, Result, Transport};
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex};
22use tokio::time::{timeout, Duration};
23
/// A WebSocket server: its configuration, the handler run for each upgraded
/// connection, and an optional rate limiter.
pub struct Server {
    /// Settings used for binding, connection limits, timeouts and handshake negotiation.
    config: ServerConfig,
    /// Handler that is cloned into the task of every accepted connection.
    handler: BoxedHandler,
    /// Present only when `config.backpressure.enabled` was set at construction.
    rate_limiter: Option<Arc<RateLimitMiddleware>>,
}
30
31#[derive(Debug)]
33pub struct ConnectionManager {
34 connections: Arc<Mutex<HashMap<u64, ConnectionHandle>>>,
35 next_id: Arc<Mutex<u64>>,
36}
37
38impl ConnectionManager {
39 pub fn new() -> Self {
41 Self {
42 connections: Arc::new(Mutex::new(HashMap::new())),
43 next_id: Arc::new(Mutex::new(1)),
44 }
45 }
46
47 pub async fn add_connection(&self, connection: Connection) -> u64 {
49 let mut next_id = self.next_id.lock().await;
50 let id = *next_id;
51 *next_id += 1;
52
53 let handle = ConnectionHandle::new(id, connection);
54 let mut connections = self.connections.lock().await;
55 connections.insert(id, handle);
56 id
57 }
58
59 pub async fn remove_connection(&self, id: u64) -> Option<ConnectionHandle> {
61 let mut connections = self.connections.lock().await;
62 connections.remove(&id)
63 }
64
65 pub async fn get_connection(&self, id: u64) -> Option<ConnectionHandle> {
67 let connections = self.connections.lock().await;
68 connections.get(&id).cloned()
69 }
70
71 pub async fn get_all_connections(&self) -> Vec<ConnectionHandle> {
73 let connections = self.connections.lock().await;
74 connections.values().cloned().collect()
75 }
76
77 pub async fn connection_count(&self) -> usize {
79 let connections = self.connections.lock().await;
80 connections.len()
81 }
82}
83
84impl std::fmt::Debug for Server {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct("Server")
87 .field("config", &self.config)
88 .field("handler", &"<handler>")
89 .finish()
90 }
91}
92
93impl Server {
94 pub fn new(config: ServerConfig, handler: BoxedHandler) -> Self {
96 let rate_limiter = if config.backpressure.enabled {
97 Some(Arc::new(RateLimitMiddleware::new(
98 crate::rate_limit::RateLimitConfig {
99 max_requests: config.backpressure.max_requests_per_minute,
100 window: Duration::from_secs(60),
101 max_connections: config.max_connections / 10, connection_timeout: config.idle_timeout,
103 },
104 )))
105 } else {
106 None
107 };
108
109 Self {
110 config,
111 handler,
112 rate_limiter,
113 }
114 }
115
116 pub fn builder() -> ServerBuilder {
118 ServerBuilder::new()
119 }
120
121 pub async fn serve(self) -> Result<()> {
123 let connection_manager = Arc::new(ConnectionManager::new());
124 self.serve_with_connection_manager(connection_manager).await
125 }
126
127 pub async fn serve_with_graceful_shutdown<F>(self, shutdown_signal: F) -> Result<()>
129 where
130 F: std::future::Future<Output = ()> + Send + 'static,
131 {
132 let connection_manager = Arc::new(ConnectionManager::new());
133 let shutdown_signal = Box::pin(shutdown_signal);
134 self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_signal)
135 .await
136 }
137
138 async fn serve_with_connection_manager(
140 self,
141 connection_manager: Arc<ConnectionManager>,
142 ) -> Result<()> {
143 let (_shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
144 let shutdown_rx = Box::pin(async move {
145 let _ = shutdown_rx.recv().await;
146 });
147
148 self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_rx)
149 .await
150 }
151
152 async fn serve_with_connection_manager_and_shutdown<F>(
154 self,
155 _connection_manager: Arc<ConnectionManager>,
156 _shutdown_signal: F,
157 ) -> Result<()>
158 where
159 F: std::future::Future<Output = ()> + Send + Unpin + 'static,
160 {
161 #[cfg(feature = "tcp-transport")]
163 {
164 if self.config.transport_type == crate::config::TransportType::Tcp {
165 let transport =
166 crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
167 return self
168 .serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal)
169 .await;
170 }
171 }
172
173 #[cfg(feature = "tls-transport")]
174 {
175 if self.config.transport_type == crate::config::TransportType::Tls {
176 let _tls_config = self.config.tls.as_ref().ok_or_else(|| {
177 Error::Other("TLS configuration required for TLS transport".to_string())
178 })?;
179 return Err(Error::Other("TLS transport is not available in this release. Please enable the 'tls-transport' feature and implement proper TLS configuration.".to_string()));
181 }
182 }
183
184 if self.config.transport_type == crate::config::TransportType::Tcp {
186 #[cfg(feature = "tcp-transport")]
187 {
188 let transport =
189 crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
190 return self
191 .serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal)
192 .await;
193 }
194 }
195 Err(Error::Config(ConfigError::Validation(
196 "No transport available".to_string(),
197 )))
198 }
199
200 #[cfg(feature = "tcp-transport")]
202 async fn serve_with_tcp_transport<F>(
203 self,
204 transport: crate::tcp_transport::TcpTransport,
205 connection_manager: Arc<ConnectionManager>,
206 _shutdown_signal: F,
207 ) -> Result<()>
208 where
209 F: std::future::Future<Output = ()> + Send + 'static,
210 {
211 let handler = self.handler;
213 let config = self.config.clone();
214 let manager = connection_manager.clone();
215 let rate_limiter = self.rate_limiter.clone();
216
217 let server_task = tokio::spawn(async move {
218 let mut connection_counter = 0u64;
219
220 loop {
221 tokio::select! {
223 result = transport.accept() => {
224 match result {
225 Ok(mut stream) => {
226 let remote_addr = match stream.remote_addr() {
228 Ok(addr) => addr.ip(),
229 Err(e) => {
230 crate::log_error!("Failed to get remote address: {:?}", e);
231 let _ = stream.close().await;
232 continue;
233 }
234 };
235
236 if let Some(ref rate_limiter) = rate_limiter {
238 if !rate_limiter.check_connection(remote_addr).await.unwrap_or(true) {
239 crate::log_warn!("Rate limit exceeded for IP: {}", remote_addr);
240 let _ = stream.close().await;
241 continue;
242 }
243 }
244
245 if manager.connection_count().await >= config.max_connections {
247 crate::log_warn!("Connection limit reached, rejecting connection from {}", remote_addr);
248 let _ = stream.close().await;
250 continue;
251 }
252
253 connection_counter += 1;
254 crate::log_debug!("Accepted connection #{} from {}", connection_counter, remote_addr);
255 let manager = manager.clone();
256 let handler = handler.clone();
257 let config = config.clone();
258 let rate_limiter = rate_limiter.clone();
259
260 tokio::spawn(async move {
262 if let Err(e) = Self::handle_connection(
263 stream,
264 handler,
265 config,
266 manager,
267 rate_limiter,
268 ).await {
269 crate::log_error!("Connection handling error: {:?}", e);
270 }
271 });
272 }
273 Err(e) => {
274 crate::log_error!("Accept error: {:?}", e);
275 }
277 }
278 }
279 _ = tokio::signal::ctrl_c() => {
280 break;
281 }
282 }
283 }
284 });
285
286 match server_task.await {
288 Ok(()) => Ok(()),
289 Err(e) => Err(Error::Other(format!("Server task panicked: {}", e))),
290 }
291 }
292
293 #[cfg(feature = "tls-transport")]
295 async fn serve_with_tls_transport<F>(
296 self,
297 _transport: crate::tls_transport::TlsTransport,
298 _connection_manager: Arc<ConnectionManager>,
299 _shutdown_signal: F,
300 ) -> Result<()>
301 where
302 F: std::future::Future<Output = ()> + Send + Unpin + 'static,
303 {
304 Err(Error::Other(
305 "TLS transport is not available in this release".to_string(),
306 ))
307 }
308
309 #[cfg(feature = "tls-transport")]
311 async fn handle_tls_connection(
312 _stream: crate::tls_transport::TlsStreamWrapper,
313 _handler: BoxedHandler,
314 _config: ServerConfig,
315 _connection_manager: Arc<ConnectionManager>,
316 _rate_limiter: Option<Arc<RateLimitMiddleware>>,
317 ) -> Result<()> {
318 Err(Error::Other(
319 "TLS connection handling is not available in this release".to_string(),
320 ))
321 }
322
323 async fn handle_connection(
325 mut stream: crate::tcp_transport::TcpStream,
326 handler: BoxedHandler,
327 config: ServerConfig,
328 connection_manager: Arc<ConnectionManager>,
329 rate_limiter: Option<Arc<RateLimitMiddleware>>,
330 ) -> Result<()> {
331 let (remote_addr, local_addr) = Self::perform_handshake(&mut stream, &config).await?;
333
334 let boxed_stream: Box<dyn TransportStream> = Box::new(stream);
336
337 let connection = Connection::with_stream(remote_addr, local_addr, boxed_stream);
339
340 let connection_id = connection_manager.add_connection(connection).await;
342
343 let connection_handle = connection_manager
345 .get_connection(connection_id)
346 .await
347 .ok_or_else(|| Error::Other("Failed to get connection handle".to_string()))?;
348
349 if let Err(e) = handler.handle(connection_handle).await {
351 crate::log_error!("Handler error: {:?}", e);
352 }
353
354 connection_manager.remove_connection(connection_id).await;
356
357 if let Some(ref rate_limiter) = rate_limiter {
359 rate_limiter.connection_closed(remote_addr.ip()).await;
360 }
361
362 Ok(())
363 }
364
365 #[cfg(feature = "tls-transport")]
367 async fn perform_tls_handshake(
368 stream: &mut crate::tls_transport::TlsStreamWrapper,
369 config: &ServerConfig,
370 ) -> Result<(SocketAddr, SocketAddr)> {
371 let request_data =
373 Self::read_tls_handshake_request(stream, config.handshake_timeout).await?;
374 let request_str = String::from_utf8_lossy(&request_data);
375
376 let request = parse_client_handshake(&request_str)?;
378
379 let handshake_config = HandshakeConfig {
381 protocols: config.supported_protocols.clone(),
382 extensions: config.supported_extensions.clone(),
383 origin: config.allowed_origin.clone(),
384 host: None,
385 extra_headers: config.extra_headers.clone(),
386 };
387
388 validate_client_handshake(&request, &handshake_config)?;
390
391 let response = create_server_handshake(&request, &handshake_config)?;
393 let response_str = response_to_string(&response);
394
395 stream.write_all(response_str.as_bytes()).await?;
397 stream.flush().await?;
398
399 let remote_addr = stream.remote_addr()?;
401 let local_addr = stream.local_addr()?;
402
403 Ok((remote_addr, local_addr))
404 }
405
406 #[cfg(feature = "tls-transport")]
408 async fn read_tls_handshake_request(
409 stream: &mut crate::tls_transport::TlsStreamWrapper,
410 timeout_duration: Duration,
411 ) -> Result<Vec<u8>> {
412 let mut buffer = Vec::new();
413 let mut temp_buffer = [0u8; 1024];
414
415 let read_result = timeout(timeout_duration, async {
416 loop {
417 let n = stream.read(&mut temp_buffer).await?;
418 if n == 0 {
419 break;
420 }
421
422 buffer.extend_from_slice(&temp_buffer[..n]);
423
424 if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
426 break;
427 }
428
429 if buffer.len() > 8192 {
431 return Err(Error::Other("TLS handshake request too large".to_string()));
432 }
433 }
434
435 Ok::<(), Error>(())
436 })
437 .await;
438
439 match read_result {
440 Ok(result) => {
441 result?;
442 Ok(buffer)
443 }
444 Err(_) => Err(Error::Other("TLS handshake timeout".to_string())),
445 }
446 }
447
448 async fn perform_handshake(
450 stream: &mut crate::tcp_transport::TcpStream,
451 config: &ServerConfig,
452 ) -> Result<(SocketAddr, SocketAddr)> {
453 let request_data = Self::read_handshake_request(stream, config.handshake_timeout).await?;
455 let request_str = String::from_utf8_lossy(&request_data);
456
457 let request = parse_client_handshake(&request_str)?;
459
460 let handshake_config = HandshakeConfig {
462 protocols: config.supported_protocols.clone(),
463 extensions: config.supported_extensions.clone(),
464 origin: config.allowed_origin.clone(),
465 host: None,
466 extra_headers: config.extra_headers.clone(),
467 };
468
469 validate_client_handshake(&request, &handshake_config)?;
471
472 let response = create_server_handshake(&request, &handshake_config)?;
474 let response_str = response_to_string(&response);
475
476 stream.write_all(response_str.as_bytes()).await?;
478 stream.flush().await?;
479
480 let remote_addr = stream.remote_addr()?;
482 let local_addr = stream.local_addr()?;
483
484 Ok((remote_addr, local_addr))
485 }
486
487 async fn read_handshake_request(
489 stream: &mut crate::tcp_transport::TcpStream,
490 timeout_duration: Duration,
491 ) -> Result<Vec<u8>> {
492 let mut buffer = Vec::new();
493 let mut temp_buffer = [0u8; 1024];
494
495 let read_result = timeout(timeout_duration, async {
496 loop {
497 let n = stream.read(&mut temp_buffer).await?;
498 if n == 0 {
499 break;
500 }
501
502 buffer.extend_from_slice(&temp_buffer[..n]);
503
504 if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
506 break;
507 }
508
509 if buffer.len() > 8192 {
511 return Err(Error::Other("Handshake request too large".to_string()));
512 }
513 }
514
515 Ok::<(), Error>(())
516 })
517 .await;
518
519 match read_result {
520 Ok(result) => {
521 result?;
522 Ok(buffer)
523 }
524 Err(_) => Err(Error::Other("Handshake timeout".to_string())),
525 }
526 }
527
528 async fn graceful_shutdown(&self, connection_manager: Arc<ConnectionManager>) -> Result<()> {
530 let connections = connection_manager.get_all_connections().await;
532
533 for handle in connections {
535 if let Ok(mut connection) = handle.try_lock().await {
536 let _ = connection.close(Some(1000), Some("Server shutdown")).await;
537 }
538 }
539
540 Ok(())
541 }
542}
543
/// Fluent builder for [`Server`], starting from `ServerConfig::default()`.
#[derive(Debug, Clone)]
pub struct ServerBuilder {
    /// Configuration accumulated by the builder methods and validated by `build`.
    config: ServerConfig,
}
549
550impl ServerBuilder {
551 pub fn new() -> Self {
553 Self {
554 config: ServerConfig::default(),
555 }
556 }
557
558 pub fn bind<A: std::net::ToSocketAddrs>(mut self, addr: A) -> Result<Self> {
560 self.config.bind_address = addr.to_socket_addrs()?.next().ok_or_else(|| {
561 Error::Config(ConfigError::Validation("Invalid bind address".to_string()))
562 })?;
563 Ok(self)
564 }
565
566 pub fn max_connections(mut self, max: usize) -> Self {
568 self.config.max_connections = max;
569 self
570 }
571
572 pub fn max_frame_size(mut self, size: usize) -> Self {
574 self.config.max_frame_size = size;
575 self
576 }
577
578 pub fn max_message_size(mut self, size: usize) -> Self {
580 self.config.max_message_size = size;
581 self
582 }
583
584 pub fn handshake_timeout(mut self, timeout: std::time::Duration) -> Self {
586 self.config.handshake_timeout = timeout;
587 self
588 }
589
590 pub fn idle_timeout(mut self, timeout: std::time::Duration) -> Self {
592 self.config.idle_timeout = timeout;
593 self
594 }
595
596 pub fn compression(mut self, enabled: bool) -> Self {
598 self.config.compression.enabled = enabled;
599 self
600 }
601
602 pub fn backpressure(mut self, strategy: crate::config::BackpressureStrategy) -> Self {
604 self.config.backpressure.strategy = strategy;
605 self
606 }
607
608 pub fn build(self) -> Result<Server> {
610 self.config.validate()?;
612
613 let handler = Box::new(crate::handler::DefaultHandler::new());
615
616 Ok(Server::new(self.config, handler))
617 }
618
619 pub fn build_with_handler<H>(self, handler: H) -> Result<Server>
621 where
622 H: Handler + Send + Sync + 'static,
623 {
624 self.config.validate()?;
626
627 Ok(Server::new(self.config, Box::new(handler)))
628 }
629}
630
631impl Default for ServerBuilder {
632 fn default() -> Self {
633 Self::new()
634 }
635}
636
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records every setting it is given, and the result validates.
    #[test]
    fn test_server_builder() {
        let builder = ServerBuilder::new()
            .bind("127.0.0.1:8080")
            .unwrap()
            .max_connections(1000)
            .max_frame_size(1024 * 1024)
            .compression(true);

        assert_eq!(
            builder.config.bind_address,
            "127.0.0.1:8080".parse::<std::net::SocketAddr>().unwrap()
        );
        assert_eq!(builder.config.max_connections, 1000);
        assert_eq!(builder.config.max_frame_size, 1024 * 1024);
        assert!(builder.config.compression.enabled);

        assert!(builder.build().is_ok());
    }

    /// An address without a port cannot be resolved, so `bind` must fail
    /// instead of silently keeping the previous address.
    #[test]
    fn test_bind_rejects_unparseable_address() {
        assert!(ServerBuilder::new().bind("not an address").is_err());
    }
}
652}