1use crate::{
6 config::ServerConfig,
7 connection::{Connection, ConnectionHandle},
8 handler::{BoxedHandler, Handler},
9 rate_limit::RateLimitMiddleware,
10};
11use aerosocket_core::error::ConfigError;
12use aerosocket_core::handshake::{
13 create_server_handshake, parse_client_handshake, response_to_string, validate_client_handshake,
14 HandshakeConfig,
15};
16use aerosocket_core::transport::TransportStream;
17use aerosocket_core::{Error, Result, Transport};
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex};
22use tokio::time::{timeout, Duration};
23
24pub struct Server {
26 config: ServerConfig,
27 handler: BoxedHandler,
28 rate_limiter: Option<Arc<RateLimitMiddleware>>,
29}
30
31#[derive(Debug)]
33pub struct ConnectionManager {
34 connections: Arc<Mutex<HashMap<u64, ConnectionHandle>>>,
35 next_id: Arc<Mutex<u64>>,
36}
37
38impl ConnectionManager {
39 pub fn new() -> Self {
41 Self {
42 connections: Arc::new(Mutex::new(HashMap::new())),
43 next_id: Arc::new(Mutex::new(1)),
44 }
45 }
46
47 pub async fn add_connection(&self, connection: Connection) -> u64 {
49 let mut next_id = self.next_id.lock().await;
50 let id = *next_id;
51 *next_id += 1;
52
53 let handle = ConnectionHandle::new(id, connection);
54 let mut connections = self.connections.lock().await;
55 connections.insert(id, handle);
56 id
57 }
58
59 pub async fn remove_connection(&self, id: u64) -> Option<ConnectionHandle> {
61 let mut connections = self.connections.lock().await;
62 connections.remove(&id)
63 }
64
65 pub async fn get_connection(&self, id: u64) -> Option<ConnectionHandle> {
67 let connections = self.connections.lock().await;
68 connections.get(&id).cloned()
69 }
70
71 pub async fn get_all_connections(&self) -> Vec<ConnectionHandle> {
73 let connections = self.connections.lock().await;
74 connections.values().cloned().collect()
75 }
76
77 pub async fn connection_count(&self) -> usize {
79 let connections = self.connections.lock().await;
80 connections.len()
81 }
82}
83
84impl std::fmt::Debug for Server {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct("Server")
87 .field("config", &self.config)
88 .field("handler", &"<handler>")
89 .finish()
90 }
91}
92
93impl Server {
94 pub fn new(config: ServerConfig, handler: BoxedHandler) -> Self {
96 let rate_limiter = if config.backpressure.enabled {
97 Some(Arc::new(RateLimitMiddleware::new(
98 crate::rate_limit::RateLimitConfig {
99 max_requests: config.backpressure.max_requests_per_minute,
100 window: Duration::from_secs(60),
101 max_connections: config.max_connections / 10, connection_timeout: config.idle_timeout,
103 },
104 )))
105 } else {
106 None
107 };
108
109 Self {
110 config,
111 handler,
112 rate_limiter,
113 }
114 }
115
116 pub fn builder() -> ServerBuilder {
118 ServerBuilder::new()
119 }
120
121 pub async fn serve(self) -> Result<()> {
123 let connection_manager = Arc::new(ConnectionManager::new());
124 self.serve_with_connection_manager(connection_manager).await
125 }
126
127 pub async fn serve_with_graceful_shutdown<F>(self, shutdown_signal: F) -> Result<()>
129 where
130 F: std::future::Future<Output = ()> + Send + 'static,
131 {
132 let connection_manager = Arc::new(ConnectionManager::new());
133 let shutdown_signal = Box::pin(shutdown_signal);
134 self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_signal)
135 .await
136 }
137
138 async fn serve_with_connection_manager(
140 self,
141 connection_manager: Arc<ConnectionManager>,
142 ) -> Result<()> {
143 let (_shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
144 let shutdown_rx = Box::pin(async move {
145 let _ = shutdown_rx.recv().await;
146 });
147
148 self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_rx)
149 .await
150 }
151
152 async fn serve_with_connection_manager_and_shutdown<F>(
154 self,
155 _connection_manager: Arc<ConnectionManager>,
156 _shutdown_signal: F,
157 ) -> Result<()>
158 where
159 F: std::future::Future<Output = ()> + Send + Unpin + 'static,
160 {
161 #[cfg(feature = "tcp-transport")]
163 {
164 if self.config.transport_type == crate::config::TransportType::Tcp {
165 let transport =
166 crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
167 return self
168 .serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal)
169 .await;
170 }
171 }
172
173 #[cfg(feature = "tls-transport")]
174 {
175 if self.config.transport_type == crate::config::TransportType::Tls {
176 let _tls_config = self.config.tls.as_ref().ok_or_else(|| {
177 Error::Other("TLS configuration required for TLS transport".to_string())
178 })?;
179 return Err(Error::Other("TLS transport is not available in this release. Please enable the 'tls-transport' feature and implement proper TLS configuration.".to_string()));
181 }
182 }
183
184 if self.config.transport_type == crate::config::TransportType::Tcp {
186 #[cfg(feature = "tcp-transport")]
187 {
188 let transport =
189 crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
190 return self
191 .serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal)
192 .await;
193 }
194 }
195 Err(Error::Config(ConfigError::Validation(
196 "No transport available".to_string(),
197 )))
198 }
199
/// Accept loop for the TCP transport.
///
/// Spawns one task that accepts connections until either Ctrl-C is received
/// or `_shutdown_signal` completes, then returns. For every accepted stream
/// the loop:
///
/// 1. resolves the peer address (closing the stream if that fails),
/// 2. consults the rate limiter, if one is configured (a limiter error is
///    treated as "allowed"),
/// 3. rejects the stream when the manager already tracks
///    `config.max_connections` connections,
/// 4. otherwise spawns a task running `handle_connection`.
///
/// # Errors
///
/// Returns `Error::Other` if the accept task panics or is cancelled.
///
/// NOTE(review): connections that are still active when the loop exits are
/// not closed by this function; `graceful_shutdown` is not invoked here.
#[cfg(feature = "tcp-transport")]
async fn serve_with_tcp_transport<F>(
    self,
    transport: crate::tcp_transport::TcpTransport,
    connection_manager: Arc<ConnectionManager>,
    _shutdown_signal: F,
) -> Result<()>
where
    F: std::future::Future<Output = ()> + Send + 'static,
{
    let Self {
        handler,
        config,
        rate_limiter,
    } = self;
    let manager = connection_manager;

    let server_task = tokio::spawn(async move {
        // Pinned on the heap so the same future can be polled by reference on
        // every loop iteration without requiring `F: Unpin`.
        let mut shutdown_signal = Box::pin(_shutdown_signal);

        loop {
            tokio::select! {
                result = transport.accept() => {
                    let mut stream = match result {
                        Ok(stream) => stream,
                        Err(e) => {
                            crate::log_error!("Accept error: {:?}", e);
                            continue;
                        }
                    };

                    let remote_addr = match stream.remote_addr() {
                        Ok(addr) => addr.ip(),
                        Err(e) => {
                            crate::log_error!("Failed to get remote address: {:?}", e);
                            let _ = stream.close().await;
                            continue;
                        }
                    };

                    if let Some(ref rate_limiter) = rate_limiter {
                        // A limiter error is treated as "allowed" (fail open).
                        if !rate_limiter.check_connection(remote_addr).await.unwrap_or(true) {
                            crate::log_warn!("Rate limit exceeded for IP: {}", remote_addr);
                            let _ = stream.close().await;
                            continue;
                        }
                    }

                    if manager.connection_count().await >= config.max_connections {
                        crate::log_warn!("Connection limit reached, rejecting connection from {}", remote_addr);
                        let _ = stream.close().await;
                        continue;
                    }

                    let manager = manager.clone();
                    let handler = handler.clone();
                    let config = config.clone();
                    let rate_limiter = rate_limiter.clone();

                    tokio::spawn(async move {
                        if let Err(e) = Self::handle_connection(
                            stream,
                            handler,
                            config,
                            manager,
                            rate_limiter,
                        ).await {
                            crate::log_error!("Connection handling error: {:?}", e);
                        }
                    });
                }
                _ = tokio::signal::ctrl_c() => {
                    break;
                }
                // Honor the caller-supplied shutdown signal; previously it was
                // accepted but never polled.
                _ = &mut shutdown_signal => {
                    break;
                }
            }
        }
    });

    match server_task.await {
        Ok(()) => Ok(()),
        Err(e) => Err(Error::Other(format!("Server task panicked: {}", e))),
    }
}
291
/// Placeholder for serving over TLS: ignores all arguments and always returns
/// an `Error::Other` stating that TLS transport is unavailable.
#[cfg(feature = "tls-transport")]
async fn serve_with_tls_transport<F>(
    self,
    _transport: crate::tls_transport::TlsTransport,
    _connection_manager: Arc<ConnectionManager>,
    _shutdown_signal: F,
) -> Result<()>
where
    F: std::future::Future<Output = ()> + Send + Unpin + 'static,
{
    let reason = "TLS transport is not available in this release";
    Err(Error::Other(reason.to_string()))
}
307
/// Placeholder for handling a single TLS connection: ignores all arguments
/// and always returns an `Error::Other` stating that TLS connection handling
/// is unavailable.
#[cfg(feature = "tls-transport")]
async fn handle_tls_connection(
    _stream: crate::tls_transport::TlsStreamWrapper,
    _handler: BoxedHandler,
    _config: ServerConfig,
    _connection_manager: Arc<ConnectionManager>,
    _rate_limiter: Option<Arc<RateLimitMiddleware>>,
) -> Result<()> {
    let reason = "TLS connection handling is not available in this release";
    Err(Error::Other(reason.to_string()))
}
321
322 async fn handle_connection(
324 mut stream: crate::tcp_transport::TcpStream,
325 handler: BoxedHandler,
326 config: ServerConfig,
327 connection_manager: Arc<ConnectionManager>,
328 rate_limiter: Option<Arc<RateLimitMiddleware>>,
329 ) -> Result<()> {
330 let (remote_addr, local_addr) = Self::perform_handshake(&mut stream, &config).await?;
332
333 let boxed_stream: Box<dyn TransportStream> = Box::new(stream);
335
336 let connection = Connection::with_stream(remote_addr, local_addr, boxed_stream);
338
339 let connection_id = connection_manager.add_connection(connection).await;
341
342 let connection_handle = connection_manager
344 .get_connection(connection_id)
345 .await
346 .ok_or_else(|| Error::Other("Failed to get connection handle".to_string()))?;
347
348 if let Err(e) = handler.handle(connection_handle).await {
350 crate::log_error!("Handler error: {:?}", e);
351 }
352
353 connection_manager.remove_connection(connection_id).await;
355
356 if let Some(ref rate_limiter) = rate_limiter {
358 rate_limiter.connection_closed(remote_addr.ip()).await;
359 }
360
361 Ok(())
362 }
363
/// Performs the server side of the opening handshake on a TLS stream.
///
/// Reads the client request (bounded by `config.handshake_timeout`), parses
/// and validates it against the protocols, extensions, origin and extra
/// headers taken from `config`, then writes and flushes the server response.
///
/// Returns `(remote_addr, local_addr)` of the stream on success; any read,
/// parse, validation, write or address lookup failure is propagated.
#[cfg(feature = "tls-transport")]
async fn perform_tls_handshake(
    stream: &mut crate::tls_transport::TlsStreamWrapper,
    config: &ServerConfig,
) -> Result<(SocketAddr, SocketAddr)> {
    let raw_request = Self::read_tls_handshake_request(stream, config.handshake_timeout).await?;
    // Invalid UTF-8 is replaced rather than rejected at this stage.
    let request_text = String::from_utf8_lossy(&raw_request);
    let client_request = parse_client_handshake(&request_text)?;

    let negotiation = HandshakeConfig {
        protocols: config.supported_protocols.clone(),
        extensions: config.supported_extensions.clone(),
        origin: config.allowed_origin.clone(),
        host: None,
        extra_headers: config.extra_headers.clone(),
    };
    validate_client_handshake(&client_request, &negotiation)?;

    let server_response = create_server_handshake(&client_request, &negotiation)?;
    let response_text = response_to_string(&server_response);
    stream.write_all(response_text.as_bytes()).await?;
    stream.flush().await?;

    Ok((stream.remote_addr()?, stream.local_addr()?))
}
404
/// Reads the raw handshake request from a TLS stream.
///
/// Bytes are read in 1024-byte chunks until one of the following happens:
///
/// * the header terminator `\r\n\r\n` appears — the buffer is returned,
/// * the peer closes the stream (a zero-length read) — whatever was read so
///   far is returned,
/// * more than 8192 bytes have accumulated without a terminator — an error,
/// * `timeout_duration` elapses — an error.
///
/// Only the bytes added by the latest read (plus the three preceding bytes,
/// so a terminator split across reads is still found) are scanned. Rescanning
/// the whole buffer after every read is quadratic in the request size when
/// the peer trickles data in small pieces.
#[cfg(feature = "tls-transport")]
async fn read_tls_handshake_request(
    stream: &mut crate::tls_transport::TlsStreamWrapper,
    timeout_duration: Duration,
) -> Result<Vec<u8>> {
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";
    const MAX_REQUEST_BYTES: usize = 8192;
    const READ_CHUNK_BYTES: usize = 1024;

    let mut buffer = Vec::new();
    let mut chunk = [0u8; READ_CHUNK_BYTES];

    let read_result = timeout(timeout_duration, async {
        loop {
            let n = stream.read(&mut chunk).await?;
            if n == 0 {
                break;
            }

            // Earlier iterations found no terminator, so any match must end
            // inside the newly appended bytes.
            let scan_from = buffer.len().saturating_sub(HEADER_TERMINATOR.len() - 1);
            buffer.extend_from_slice(&chunk[..n]);

            if buffer[scan_from..]
                .windows(HEADER_TERMINATOR.len())
                .any(|window| window == HEADER_TERMINATOR)
            {
                break;
            }

            if buffer.len() > MAX_REQUEST_BYTES {
                return Err(Error::Other("TLS handshake request too large".to_string()));
            }
        }

        Ok::<(), Error>(())
    })
    .await;

    match read_result {
        Ok(Ok(())) => Ok(buffer),
        Ok(Err(e)) => Err(e),
        Err(_) => Err(Error::Other("TLS handshake timeout".to_string())),
    }
}
446
447 async fn perform_handshake(
449 stream: &mut crate::tcp_transport::TcpStream,
450 config: &ServerConfig,
451 ) -> Result<(SocketAddr, SocketAddr)> {
452 let request_data = Self::read_handshake_request(stream, config.handshake_timeout).await?;
454 let request_str = String::from_utf8_lossy(&request_data);
455
456 let request = parse_client_handshake(&request_str)?;
458
459 let handshake_config = HandshakeConfig {
461 protocols: config.supported_protocols.clone(),
462 extensions: config.supported_extensions.clone(),
463 origin: config.allowed_origin.clone(),
464 host: None,
465 extra_headers: config.extra_headers.clone(),
466 };
467
468 validate_client_handshake(&request, &handshake_config)?;
470
471 let response = create_server_handshake(&request, &handshake_config)?;
473 let response_str = response_to_string(&response);
474
475 stream.write_all(response_str.as_bytes()).await?;
477 stream.flush().await?;
478
479 let remote_addr = stream.remote_addr()?;
481 let local_addr = stream.local_addr()?;
482
483 Ok((remote_addr, local_addr))
484 }
485
486 async fn read_handshake_request(
488 stream: &mut crate::tcp_transport::TcpStream,
489 timeout_duration: Duration,
490 ) -> Result<Vec<u8>> {
491 let mut buffer = Vec::new();
492 let mut temp_buffer = [0u8; 1024];
493
494 let read_result = timeout(timeout_duration, async {
495 loop {
496 let n = stream.read(&mut temp_buffer).await?;
497 if n == 0 {
498 break;
499 }
500
501 buffer.extend_from_slice(&temp_buffer[..n]);
502
503 if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
505 break;
506 }
507
508 if buffer.len() > 8192 {
510 return Err(Error::Other("Handshake request too large".to_string()));
511 }
512 }
513
514 Ok::<(), Error>(())
515 })
516 .await;
517
518 match read_result {
519 Ok(result) => {
520 result?;
521 Ok(buffer)
522 }
523 Err(_) => Err(Error::Other("Handshake timeout".to_string())),
524 }
525 }
526
527 async fn graceful_shutdown(&self, connection_manager: Arc<ConnectionManager>) -> Result<()> {
529 let connections = connection_manager.get_all_connections().await;
531
532 for handle in connections {
534 if let Ok(mut connection) = handle.try_lock().await {
535 let _ = connection.close(Some(1000), Some("Server shutdown")).await;
536 }
537 }
538
539 Ok(())
540 }
541}
542
543#[derive(Debug, Clone)]
545pub struct ServerBuilder {
546 config: ServerConfig,
547}
548
549impl ServerBuilder {
550 pub fn new() -> Self {
552 Self {
553 config: ServerConfig::default(),
554 }
555 }
556
557 pub fn bind<A: std::net::ToSocketAddrs>(mut self, addr: A) -> Result<Self> {
559 self.config.bind_address = addr.to_socket_addrs()?.next().ok_or_else(|| {
560 Error::Config(ConfigError::Validation("Invalid bind address".to_string()))
561 })?;
562 Ok(self)
563 }
564
565 pub fn max_connections(mut self, max: usize) -> Self {
567 self.config.max_connections = max;
568 self
569 }
570
571 pub fn max_frame_size(mut self, size: usize) -> Self {
573 self.config.max_frame_size = size;
574 self
575 }
576
577 pub fn max_message_size(mut self, size: usize) -> Self {
579 self.config.max_message_size = size;
580 self
581 }
582
583 pub fn handshake_timeout(mut self, timeout: std::time::Duration) -> Self {
585 self.config.handshake_timeout = timeout;
586 self
587 }
588
589 pub fn idle_timeout(mut self, timeout: std::time::Duration) -> Self {
591 self.config.idle_timeout = timeout;
592 self
593 }
594
595 pub fn compression(mut self, enabled: bool) -> Self {
597 self.config.compression.enabled = enabled;
598 self
599 }
600
601 pub fn backpressure(mut self, strategy: crate::config::BackpressureStrategy) -> Self {
603 self.config.backpressure.strategy = strategy;
604 self
605 }
606
607 pub fn build(self) -> Result<Server> {
609 self.config.validate()?;
611
612 let handler = Box::new(crate::handler::DefaultHandler::new());
614
615 Ok(Server::new(self.config, handler))
616 }
617
618 pub fn build_with_handler<H>(self, handler: H) -> Result<Server>
620 where
621 H: Handler + Send + Sync + 'static,
622 {
623 self.config.validate()?;
625
626 Ok(Server::new(self.config, Box::new(handler)))
627 }
628}
629
630impl Default for ServerBuilder {
631 fn default() -> Self {
632 Self::new()
633 }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    /// A builder configured with a loopback address and a few limits must
    /// produce a server without error.
    #[test]
    fn test_server_builder() {
        let bound = ServerBuilder::new().bind("127.0.0.1:8080").unwrap();
        let configured = bound
            .max_connections(1000)
            .max_frame_size(1024 * 1024)
            .compression(true);

        let built = configured.build();
        assert!(built.is_ok());
    }
}
651}