1#![allow(dead_code)]
9use crate::connection::should_keep_alive;
48use crate::expect::{CONTINUE_RESPONSE, ExpectHandler, ExpectResult};
49use crate::parser::{ParseError, ParseLimits, ParseStatus, Parser, StatefulParser};
50use crate::response::{ResponseWrite, ResponseWriter};
51use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
52use asupersync::net::{TcpListener, TcpStream};
53use asupersync::runtime::{RuntimeState, SpawnError, TaskHandle};
54use asupersync::signal::{GracefulOutcome, ShutdownController, ShutdownReceiver};
55use asupersync::stream::Stream;
56use asupersync::time::timeout;
57use asupersync::{Budget, Cx, Scope, Time};
58use fastapi_core::app::App;
59use fastapi_core::{Request, RequestContext, Response, StatusCode};
60use std::future::Future;
61use std::io;
62use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
63use std::pin::Pin;
64use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
65use std::sync::{Arc, Mutex, OnceLock};
66use std::task::Poll;
67use std::time::{Duration, Instant};
68
69static START_TIME: OnceLock<Instant> = OnceLock::new();
72
73fn current_time() -> Time {
78 let start = START_TIME.get_or_init(Instant::now);
79 let now = Instant::now();
80 if now < *start {
81 Time::ZERO
82 } else {
83 let elapsed = now.duration_since(*start);
84 Time::from_nanos(elapsed.as_nanos() as u64)
85 }
86}
87
/// Default per-request time budget, in seconds.
pub const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 30;

/// Default size, in bytes, of the buffer each connection reads into.
pub const DEFAULT_READ_BUFFER_SIZE: usize = 8 * 1024;

/// Default cap on simultaneous connections; `0` means "no limit".
pub const DEFAULT_MAX_CONNECTIONS: usize = 0;

/// Default idle time, in seconds, a kept-alive connection may wait for its next request.
pub const DEFAULT_KEEP_ALIVE_TIMEOUT_SECS: u64 = 75;

/// Default number of requests served on one connection before it is closed (`0` = no limit).
pub const DEFAULT_MAX_REQUESTS_PER_CONNECTION: usize = 100;

/// Default time, in seconds, to wait for in-flight connections while draining.
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
105
106#[derive(Debug, Clone)]
134pub struct ServerConfig {
135 pub bind_addr: String,
137 pub request_timeout: Time,
139 pub max_connections: usize,
141 pub read_buffer_size: usize,
143 pub parse_limits: ParseLimits,
145 pub allowed_hosts: Vec<String>,
147 pub trust_x_forwarded_host: bool,
149 pub tcp_nodelay: bool,
151 pub keep_alive_timeout: Duration,
154 pub max_requests_per_connection: usize,
156 pub drain_timeout: Duration,
159}
160
161impl ServerConfig {
162 #[must_use]
164 pub fn new(bind_addr: impl Into<String>) -> Self {
165 Self {
166 bind_addr: bind_addr.into(),
167 request_timeout: Time::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS),
168 max_connections: DEFAULT_MAX_CONNECTIONS,
169 read_buffer_size: DEFAULT_READ_BUFFER_SIZE,
170 parse_limits: ParseLimits::default(),
171 allowed_hosts: Vec::new(),
172 trust_x_forwarded_host: false,
173 tcp_nodelay: true,
174 keep_alive_timeout: Duration::from_secs(DEFAULT_KEEP_ALIVE_TIMEOUT_SECS),
175 max_requests_per_connection: DEFAULT_MAX_REQUESTS_PER_CONNECTION,
176 drain_timeout: Duration::from_secs(DEFAULT_DRAIN_TIMEOUT_SECS),
177 }
178 }
179
180 #[must_use]
182 pub fn with_request_timeout(mut self, timeout: Time) -> Self {
183 self.request_timeout = timeout;
184 self
185 }
186
187 #[must_use]
189 pub fn with_request_timeout_secs(mut self, secs: u64) -> Self {
190 self.request_timeout = Time::from_secs(secs);
191 self
192 }
193
194 #[must_use]
196 pub fn with_max_connections(mut self, max: usize) -> Self {
197 self.max_connections = max;
198 self
199 }
200
201 #[must_use]
203 pub fn with_read_buffer_size(mut self, size: usize) -> Self {
204 self.read_buffer_size = size;
205 self
206 }
207
208 #[must_use]
210 pub fn with_parse_limits(mut self, limits: ParseLimits) -> Self {
211 self.parse_limits = limits;
212 self
213 }
214
215 #[must_use]
220 pub fn with_allowed_hosts<I, S>(mut self, hosts: I) -> Self
221 where
222 I: IntoIterator<Item = S>,
223 S: Into<String>,
224 {
225 self.allowed_hosts = hosts
227 .into_iter()
228 .map(|s| s.into().to_ascii_lowercase())
229 .collect();
230 self
231 }
232
233 #[must_use]
237 pub fn allow_host(mut self, host: impl Into<String>) -> Self {
238 self.allowed_hosts.push(host.into().to_ascii_lowercase());
240 self
241 }
242
243 #[must_use]
245 pub fn with_trust_x_forwarded_host(mut self, trust: bool) -> Self {
246 self.trust_x_forwarded_host = trust;
247 self
248 }
249
250 #[must_use]
252 pub fn with_tcp_nodelay(mut self, enabled: bool) -> Self {
253 self.tcp_nodelay = enabled;
254 self
255 }
256
257 #[must_use]
262 pub fn with_keep_alive_timeout(mut self, timeout: Duration) -> Self {
263 self.keep_alive_timeout = timeout;
264 self
265 }
266
267 #[must_use]
269 pub fn with_keep_alive_timeout_secs(mut self, secs: u64) -> Self {
270 self.keep_alive_timeout = Duration::from_secs(secs);
271 self
272 }
273
274 #[must_use]
278 pub fn with_max_requests_per_connection(mut self, max: usize) -> Self {
279 self.max_requests_per_connection = max;
280 self
281 }
282
283 #[must_use]
288 pub fn with_drain_timeout(mut self, timeout: Duration) -> Self {
289 self.drain_timeout = timeout;
290 self
291 }
292
293 #[must_use]
295 pub fn with_drain_timeout_secs(mut self, secs: u64) -> Self {
296 self.drain_timeout = Duration::from_secs(secs);
297 self
298 }
299}
300
301impl Default for ServerConfig {
302 fn default() -> Self {
303 Self::new("127.0.0.1:8080")
304 }
305}
306
307#[derive(Debug)]
309pub enum ServerError {
310 Io(io::Error),
312 Parse(ParseError),
314 Shutdown,
316 ConnectionLimitReached,
318 KeepAliveTimeout,
320}
321
322impl std::fmt::Display for ServerError {
323 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324 match self {
325 Self::Io(e) => write!(f, "IO error: {e}"),
326 Self::Parse(e) => write!(f, "Parse error: {e}"),
327 Self::Shutdown => write!(f, "Server shutdown"),
328 Self::ConnectionLimitReached => write!(f, "Connection limit reached"),
329 Self::KeepAliveTimeout => write!(f, "Keep-alive timeout"),
330 }
331 }
332}
333
334impl std::error::Error for ServerError {
335 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
336 match self {
337 Self::Io(e) => Some(e),
338 Self::Parse(e) => Some(e),
339 _ => None,
340 }
341 }
342}
343
344#[derive(Debug, Clone, PartialEq, Eq)]
349enum HostValidationErrorKind {
350 Missing,
351 Invalid,
352 NotAllowed,
353}
354
355#[derive(Debug, Clone)]
356struct HostValidationError {
357 kind: HostValidationErrorKind,
358 detail: String,
359}
360
361impl HostValidationError {
362 fn missing() -> Self {
363 Self {
364 kind: HostValidationErrorKind::Missing,
365 detail: "missing Host header".to_string(),
366 }
367 }
368
369 fn invalid(detail: impl Into<String>) -> Self {
370 Self {
371 kind: HostValidationErrorKind::Invalid,
372 detail: detail.into(),
373 }
374 }
375
376 fn not_allowed(detail: impl Into<String>) -> Self {
377 Self {
378 kind: HostValidationErrorKind::NotAllowed,
379 detail: detail.into(),
380 }
381 }
382
383 fn response(&self) -> Response {
384 let message = match self.kind {
385 HostValidationErrorKind::Missing => "Bad Request: Host header required",
386 HostValidationErrorKind::Invalid => "Bad Request: invalid Host header",
387 HostValidationErrorKind::NotAllowed => "Bad Request: Host not allowed",
388 };
389 Response::with_status(StatusCode::BAD_REQUEST).body(fastapi_core::ResponseBody::Bytes(
390 message.as_bytes().to_vec(),
391 ))
392 }
393}
394
/// A parsed, normalised Host value.
#[derive(Debug, Clone, PartialEq, Eq)]
struct HostHeader {
    /// ASCII-lower-cased host name or IP literal (IPv6 without brackets).
    host: String,
    /// Explicit port, if one was given.
    port: Option<u16>,
}
400
401fn validate_host_header(
402 request: &Request,
403 config: &ServerConfig,
404) -> Result<HostHeader, HostValidationError> {
405 let raw = extract_effective_host(request, config)?;
406 let parsed = parse_host_header(&raw)
407 .ok_or_else(|| HostValidationError::invalid(format!("invalid host value: {raw}")))?;
408
409 if !is_allowed_host(&parsed, &config.allowed_hosts) {
410 return Err(HostValidationError::not_allowed(format!(
411 "host not allowed: {}",
412 parsed.host
413 )));
414 }
415
416 Ok(parsed)
417}
418
419fn extract_effective_host(
420 request: &Request,
421 config: &ServerConfig,
422) -> Result<String, HostValidationError> {
423 if config.trust_x_forwarded_host {
424 if let Some(value) = header_value(request, "x-forwarded-host")? {
425 let forwarded = extract_first_list_value(&value)
426 .ok_or_else(|| HostValidationError::invalid("empty X-Forwarded-Host value"))?;
427 return Ok(forwarded.to_string());
428 }
429 }
430
431 match header_value(request, "host")? {
432 Some(value) => Ok(value),
433 None => Err(HostValidationError::missing()),
434 }
435}
436
437fn header_value(request: &Request, name: &str) -> Result<Option<String>, HostValidationError> {
438 request
439 .headers()
440 .get(name)
441 .map(|bytes| {
442 std::str::from_utf8(bytes)
443 .map(|s| s.trim().to_string())
444 .map_err(|_| {
445 HostValidationError::invalid(format!("invalid UTF-8 in {name} header"))
446 })
447 })
448 .transpose()
449}
450
/// Returns the first non-empty, trimmed entry of a comma-separated header list.
fn extract_first_list_value(value: &str) -> Option<&str> {
    for candidate in value.split(',') {
        let candidate = candidate.trim();
        if !candidate.is_empty() {
            return Some(candidate);
        }
    }
    None
}
454
455fn parse_host_header(value: &str) -> Option<HostHeader> {
456 let value = value.trim();
457 if value.is_empty() {
458 return None;
459 }
460 if value.chars().any(|c| c.is_control() || c.is_whitespace()) {
461 return None;
462 }
463
464 if value.starts_with('[') {
465 let end = value.find(']')?;
466 let host = &value[1..end];
467 if host.is_empty() {
468 return None;
469 }
470 if host.parse::<Ipv6Addr>().is_err() {
471 return None;
472 }
473 let rest = &value[end + 1..];
474 let port = if rest.is_empty() {
475 None
476 } else if let Some(port_str) = rest.strip_prefix(':') {
477 parse_port(port_str)
478 } else {
479 return None;
480 };
481 return Some(HostHeader {
482 host: host.to_ascii_lowercase(),
483 port,
484 });
485 }
486
487 let mut parts = value.split(':');
488 let host = parts.next().unwrap_or("");
489 let port_part = parts.next();
490 if parts.next().is_some() {
491 return None;
493 }
494 if host.is_empty() {
495 return None;
496 }
497
498 let port = match port_part {
499 Some(p) => parse_port(p),
500 None => None,
501 };
502
503 if host.parse::<Ipv4Addr>().is_ok() || is_valid_hostname(host) {
504 Some(HostHeader {
505 host: host.to_ascii_lowercase(),
506 port,
507 })
508 } else {
509 None
510 }
511}
512
513fn parse_port(port: &str) -> Option<u16> {
514 if port.is_empty() || !port.chars().all(|c| c.is_ascii_digit()) {
515 return None;
516 }
517 let value = port.parse::<u16>().ok()?;
518 if value == 0 { None } else { Some(value) }
519}
520
521fn is_valid_hostname(host: &str) -> bool {
522 if host.len() > 253 {
524 return false;
525 }
526 for label in host.split('.') {
527 if label.is_empty() || label.len() > 63 {
528 return false;
529 }
530 let bytes = label.as_bytes();
531 if bytes.first() == Some(&b'-') || bytes.last() == Some(&b'-') {
532 return false;
533 }
534 if !label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
535 return false;
536 }
537 }
538 true
539}
540
541fn is_allowed_host(host: &HostHeader, allowed_hosts: &[String]) -> bool {
542 if allowed_hosts.is_empty() {
543 return true;
544 }
545
546 allowed_hosts
547 .iter()
548 .any(|pattern| host_matches_pattern(host, pattern))
549}
550
551fn host_matches_pattern(host: &HostHeader, pattern: &str) -> bool {
552 let pattern = pattern.trim();
554 if pattern.is_empty() {
555 return false;
556 }
557 if pattern == "*" {
558 return true;
559 }
560 if let Some(suffix) = pattern.strip_prefix("*.") {
561 if host.host == suffix {
563 return false;
564 }
565 return host.host.len() > suffix.len() + 1
566 && host.host.ends_with(suffix)
567 && host.host.as_bytes()[host.host.len() - suffix.len() - 1] == b'.';
568 }
569
570 if let Some(parsed) = parse_host_header(pattern) {
571 if parsed.host != host.host {
572 return false;
573 }
574 if let Some(port) = parsed.port {
575 return host.port == Some(port);
576 }
577 return true;
578 }
579
580 false
581}
582
583impl From<io::Error> for ServerError {
584 fn from(e: io::Error) -> Self {
585 Self::Io(e)
586 }
587}
588
589impl From<ParseError> for ServerError {
590 fn from(e: ParseError) -> Self {
591 Self::Parse(e)
592 }
593}
594
595async fn process_connection<H, Fut>(
599 cx: &Cx,
600 request_counter: &AtomicU64,
601 mut stream: TcpStream,
602 _peer_addr: SocketAddr,
603 config: &ServerConfig,
604 handler: H,
605) -> Result<(), ServerError>
606where
607 H: Fn(RequestContext, &mut Request) -> Fut,
608 Fut: Future<Output = Response>,
609{
610 let mut parser = StatefulParser::new().with_limits(config.parse_limits.clone());
611 let mut read_buffer = vec![0u8; config.read_buffer_size];
612 let mut response_writer = ResponseWriter::new();
613 let mut requests_on_connection: usize = 0;
614 let max_requests = config.max_requests_per_connection;
615
616 loop {
617 if cx.is_cancel_requested() {
619 return Ok(());
620 }
621
622 let parse_result = parser.feed(&[])?;
624
625 let mut request = match parse_result {
626 ParseStatus::Complete { request, .. } => request,
627 ParseStatus::Incomplete => {
628 let keep_alive_timeout = config.keep_alive_timeout;
629
630 let bytes_read = if keep_alive_timeout.is_zero() {
631 read_into_buffer(&mut stream, &mut read_buffer).await?
632 } else {
633 match read_with_timeout(&mut stream, &mut read_buffer, keep_alive_timeout).await
634 {
635 Ok(0) => return Ok(()),
636 Ok(n) => n,
637 Err(e) if e.kind() == io::ErrorKind::TimedOut => {
638 cx.trace(&format!(
639 "Keep-alive timeout ({:?}) - closing idle connection",
640 keep_alive_timeout
641 ));
642 return Err(ServerError::KeepAliveTimeout);
643 }
644 Err(e) => return Err(ServerError::Io(e)),
645 }
646 };
647
648 if bytes_read == 0 {
649 return Ok(());
650 }
651
652 match parser.feed(&read_buffer[..bytes_read])? {
653 ParseStatus::Complete { request, .. } => request,
654 ParseStatus::Incomplete => continue,
655 }
656 }
657 };
658
659 requests_on_connection += 1;
660
661 let request_id = request_counter.fetch_add(1, Ordering::Relaxed);
663 let request_budget = Budget::new().with_deadline(config.request_timeout);
664 let request_cx = Cx::for_testing_with_budget(request_budget);
665 let ctx = RequestContext::new(request_cx, request_id);
666
667 if let Err(err) = validate_host_header(&request, config) {
669 ctx.trace(&format!("Rejecting request: {}", err.detail));
670 let response = err.response().header("connection", b"close".to_vec());
671 let response_write = response_writer.write(response);
672 write_response(&mut stream, response_write).await?;
673 return Ok(());
674 }
675
676 match ExpectHandler::check_expect(&request) {
680 ExpectResult::NoExpectation => {
681 }
683 ExpectResult::ExpectsContinue => {
684 ctx.trace("Sending 100 Continue for Expect: 100-continue");
689 write_raw_response(&mut stream, CONTINUE_RESPONSE).await?;
690 }
691 ExpectResult::UnknownExpectation(value) => {
692 ctx.trace(&format!("Rejecting unknown Expect value: {}", value));
694 let response =
695 ExpectHandler::expectation_failed(format!("Unsupported Expect value: {value}"));
696 let response_write = response_writer.write(response);
697 write_response(&mut stream, response_write).await?;
698 return Ok(());
699 }
700 }
701
702 let client_wants_keep_alive = should_keep_alive(&request);
703 let at_max_requests = max_requests > 0 && requests_on_connection >= max_requests;
704 let server_will_keep_alive = client_wants_keep_alive && !at_max_requests;
705
706 let request_start = Instant::now();
707 let timeout_duration = Duration::from_nanos(config.request_timeout.as_nanos());
708
709 let response = handler(ctx, &mut request).await;
711
712 let mut response = if request_start.elapsed() > timeout_duration {
713 Response::with_status(StatusCode::GATEWAY_TIMEOUT).body(
714 fastapi_core::ResponseBody::Bytes(
715 b"Gateway Timeout: request processing exceeded time limit".to_vec(),
716 ),
717 )
718 } else {
719 response
720 };
721
722 response = if server_will_keep_alive {
723 response.header("connection", b"keep-alive".to_vec())
724 } else {
725 response.header("connection", b"close".to_vec())
726 };
727
728 let response_write = response_writer.write(response);
729 write_response(&mut stream, response_write).await?;
730
731 if let Some(tasks) = App::take_background_tasks(&mut request) {
732 tasks.execute_all().await;
733 }
734
735 if !server_will_keep_alive {
736 return Ok(());
737 }
738 }
739}
740
741#[derive(Debug)]
747pub struct TcpServer {
748 config: ServerConfig,
749 request_counter: Arc<AtomicU64>,
750 connection_counter: Arc<AtomicU64>,
752 draining: Arc<AtomicBool>,
754 connection_handles: Mutex<Vec<TaskHandle<()>>>,
756 shutdown_controller: Arc<ShutdownController>,
758 metrics_counters: Arc<MetricsCounters>,
760}
761
762impl TcpServer {
763 #[must_use]
765 pub fn new(config: ServerConfig) -> Self {
766 Self {
767 config,
768 request_counter: Arc::new(AtomicU64::new(0)),
769 connection_counter: Arc::new(AtomicU64::new(0)),
770 draining: Arc::new(AtomicBool::new(false)),
771 connection_handles: Mutex::new(Vec::new()),
772 shutdown_controller: Arc::new(ShutdownController::new()),
773 metrics_counters: Arc::new(MetricsCounters::new()),
774 }
775 }
776
777 #[must_use]
779 pub fn config(&self) -> &ServerConfig {
780 &self.config
781 }
782
783 fn next_request_id(&self) -> u64 {
785 self.request_counter.fetch_add(1, Ordering::Relaxed)
786 }
787
788 #[must_use]
790 pub fn current_connections(&self) -> u64 {
791 self.connection_counter.load(Ordering::Relaxed)
792 }
793
794 #[must_use]
796 pub fn metrics(&self) -> ServerMetrics {
797 ServerMetrics {
798 active_connections: self.connection_counter.load(Ordering::Relaxed),
799 total_accepted: self.metrics_counters.total_accepted.load(Ordering::Relaxed),
800 total_rejected: self.metrics_counters.total_rejected.load(Ordering::Relaxed),
801 total_timed_out: self
802 .metrics_counters
803 .total_timed_out
804 .load(Ordering::Relaxed),
805 total_requests: self.request_counter.load(Ordering::Relaxed),
806 bytes_in: self.metrics_counters.bytes_in.load(Ordering::Relaxed),
807 bytes_out: self.metrics_counters.bytes_out.load(Ordering::Relaxed),
808 }
809 }
810
811 fn record_bytes_in(&self, n: u64) {
813 self.metrics_counters
814 .bytes_in
815 .fetch_add(n, Ordering::Relaxed);
816 }
817
818 fn record_bytes_out(&self, n: u64) {
820 self.metrics_counters
821 .bytes_out
822 .fetch_add(n, Ordering::Relaxed);
823 }
824
825 fn try_acquire_connection(&self) -> bool {
830 let max = self.config.max_connections;
831 if max == 0 {
832 self.connection_counter.fetch_add(1, Ordering::Relaxed);
834 self.metrics_counters
835 .total_accepted
836 .fetch_add(1, Ordering::Relaxed);
837 return true;
838 }
839
840 let mut current = self.connection_counter.load(Ordering::Relaxed);
842 loop {
843 if current >= max as u64 {
844 self.metrics_counters
845 .total_rejected
846 .fetch_add(1, Ordering::Relaxed);
847 return false;
848 }
849 match self.connection_counter.compare_exchange_weak(
850 current,
851 current + 1,
852 Ordering::AcqRel,
853 Ordering::Relaxed,
854 ) {
855 Ok(_) => {
856 self.metrics_counters
857 .total_accepted
858 .fetch_add(1, Ordering::Relaxed);
859 return true;
860 }
861 Err(actual) => current = actual,
862 }
863 }
864 }
865
866 fn release_connection(&self) {
868 self.connection_counter.fetch_sub(1, Ordering::Relaxed);
869 }
870
871 #[must_use]
873 pub fn is_draining(&self) -> bool {
874 self.draining.load(Ordering::Acquire)
875 }
876
877 pub fn start_drain(&self) {
884 self.draining.store(true, Ordering::Release);
885 }
886
887 pub async fn wait_for_drain(&self, timeout: Duration, poll_interval: Option<Duration>) -> bool {
897 let start = Instant::now();
898 let poll_interval = poll_interval.unwrap_or(Duration::from_millis(10));
899
900 while self.current_connections() > 0 {
901 if start.elapsed() >= timeout {
902 return false;
903 }
904 std::thread::sleep(poll_interval);
914 }
915 true
916 }
917
918 pub async fn drain(&self) -> u64 {
926 self.start_drain();
927 let drained = self.wait_for_drain(self.config.drain_timeout, None).await;
928 if drained {
929 0
930 } else {
931 self.current_connections()
932 }
933 }
934
935 #[must_use]
940 pub fn shutdown_controller(&self) -> &Arc<ShutdownController> {
941 &self.shutdown_controller
942 }
943
944 #[must_use]
949 pub fn subscribe_shutdown(&self) -> ShutdownReceiver {
950 self.shutdown_controller.subscribe()
951 }
952
953 pub fn shutdown(&self) {
962 self.start_drain();
963 self.shutdown_controller.shutdown();
964 }
965
966 #[must_use]
968 pub fn is_shutting_down(&self) -> bool {
969 self.shutdown_controller.is_shutting_down() || self.is_draining()
970 }
971
972 pub async fn serve_with_shutdown<H, Fut>(
1002 &self,
1003 cx: &Cx,
1004 mut shutdown: ShutdownReceiver,
1005 handler: H,
1006 ) -> Result<GracefulOutcome<()>, ServerError>
1007 where
1008 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1009 Fut: Future<Output = Response> + Send + 'static,
1010 {
1011 let bind_addr = self.config.bind_addr.clone();
1012 let listener = TcpListener::bind(bind_addr).await?;
1013 let local_addr = listener.local_addr()?;
1014
1015 cx.trace(&format!(
1016 "Server listening on {local_addr} (with graceful shutdown)"
1017 ));
1018
1019 let result = self
1021 .accept_loop_with_shutdown(cx, listener, handler, &mut shutdown)
1022 .await;
1023
1024 match result {
1025 Ok(outcome) => {
1026 if outcome.is_shutdown() {
1027 cx.trace("Shutdown signal received, draining connections");
1028 self.start_drain();
1029 self.drain_connection_tasks(cx).await;
1030 }
1031 Ok(outcome)
1032 }
1033 Err(e) => Err(e),
1034 }
1035 }
1036
1037 async fn accept_loop_with_shutdown<H, Fut>(
1039 &self,
1040 cx: &Cx,
1041 listener: TcpListener,
1042 handler: H,
1043 shutdown: &mut ShutdownReceiver,
1044 ) -> Result<GracefulOutcome<()>, ServerError>
1045 where
1046 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1047 Fut: Future<Output = Response> + Send + 'static,
1048 {
1049 let handler = Arc::new(handler);
1050
1051 loop {
1052 if shutdown.is_shutting_down() {
1054 return Ok(GracefulOutcome::ShutdownSignaled);
1055 }
1056 if cx.is_cancel_requested() || self.is_draining() {
1057 return Ok(GracefulOutcome::ShutdownSignaled);
1058 }
1059
1060 let (mut stream, peer_addr) = match listener.accept().await {
1062 Ok(conn) => conn,
1063 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
1064 continue;
1065 }
1066 Err(e) => {
1067 cx.trace(&format!("Accept error: {e}"));
1068 if is_fatal_accept_error(&e) {
1069 return Err(ServerError::Io(e));
1070 }
1071 continue;
1072 }
1073 };
1074
1075 if !self.try_acquire_connection() {
1077 cx.trace(&format!(
1078 "Connection limit reached ({}), rejecting {peer_addr}",
1079 self.config.max_connections
1080 ));
1081
1082 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
1083 .header("connection", b"close".to_vec())
1084 .body(fastapi_core::ResponseBody::Bytes(
1085 b"503 Service Unavailable: connection limit reached".to_vec(),
1086 ));
1087 let mut writer = crate::response::ResponseWriter::new();
1088 let response_bytes = writer.write(response);
1089 let _ = write_response(&mut stream, response_bytes).await;
1090 continue;
1091 }
1092
1093 if self.config.tcp_nodelay {
1095 let _ = stream.set_nodelay(true);
1096 }
1097
1098 cx.trace(&format!(
1099 "Accepted connection from {peer_addr} ({}/{})",
1100 self.current_connections(),
1101 if self.config.max_connections == 0 {
1102 "∞".to_string()
1103 } else {
1104 self.config.max_connections.to_string()
1105 }
1106 ));
1107
1108 let request_id = self.next_request_id();
1109 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
1110 let request_cx = Cx::for_testing_with_budget(request_budget);
1111 let ctx = RequestContext::new(request_cx, request_id);
1112
1113 let result = self
1115 .handle_connection(&ctx, stream, peer_addr, &*handler)
1116 .await;
1117
1118 self.release_connection();
1119
1120 if let Err(e) = result {
1121 cx.trace(&format!("Connection error from {peer_addr}: {e}"));
1122 }
1123 }
1124 }
1125
1126 pub async fn serve<H, Fut>(&self, cx: &Cx, handler: H) -> Result<(), ServerError>
1140 where
1141 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1142 Fut: Future<Output = Response> + Send + 'static,
1143 {
1144 let bind_addr = self.config.bind_addr.clone();
1145 let listener = TcpListener::bind(bind_addr).await?;
1146 let local_addr = listener.local_addr()?;
1147
1148 cx.trace(&format!("Server listening on {local_addr}"));
1149
1150 self.accept_loop(cx, listener, handler).await
1151 }
1152
1153 pub async fn serve_on<H, Fut>(
1157 &self,
1158 cx: &Cx,
1159 listener: TcpListener,
1160 handler: H,
1161 ) -> Result<(), ServerError>
1162 where
1163 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1164 Fut: Future<Output = Response> + Send + 'static,
1165 {
1166 self.accept_loop(cx, listener, handler).await
1167 }
1168
1169 pub async fn serve_handler(
1190 &self,
1191 cx: &Cx,
1192 handler: Arc<dyn fastapi_core::Handler>,
1193 ) -> Result<(), ServerError> {
1194 let bind_addr = self.config.bind_addr.clone();
1195 let listener = TcpListener::bind(bind_addr).await?;
1196 let local_addr = listener.local_addr()?;
1197
1198 cx.trace(&format!("Server listening on {local_addr}"));
1199
1200 self.accept_loop_handler(cx, listener, handler).await
1201 }
1202
1203 pub async fn serve_on_handler(
1205 &self,
1206 cx: &Cx,
1207 listener: TcpListener,
1208 handler: Arc<dyn fastapi_core::Handler>,
1209 ) -> Result<(), ServerError> {
1210 self.accept_loop_handler(cx, listener, handler).await
1211 }
1212
1213 async fn accept_loop_handler(
1215 &self,
1216 cx: &Cx,
1217 listener: TcpListener,
1218 handler: Arc<dyn fastapi_core::Handler>,
1219 ) -> Result<(), ServerError> {
1220 loop {
1221 if cx.is_cancel_requested() {
1223 cx.trace("Server shutdown requested");
1224 return Ok(());
1225 }
1226
1227 if self.is_draining() {
1229 cx.trace("Server draining, stopping accept loop");
1230 return Err(ServerError::Shutdown);
1231 }
1232
1233 let (mut stream, peer_addr) = match listener.accept().await {
1235 Ok(conn) => conn,
1236 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
1237 continue;
1238 }
1239 Err(e) => {
1240 cx.trace(&format!("Accept error: {e}"));
1241 if is_fatal_accept_error(&e) {
1242 return Err(ServerError::Io(e));
1243 }
1244 continue;
1245 }
1246 };
1247
1248 if !self.try_acquire_connection() {
1250 cx.trace(&format!(
1251 "Connection limit reached ({}), rejecting {peer_addr}",
1252 self.config.max_connections
1253 ));
1254
1255 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
1256 .header("connection", b"close".to_vec())
1257 .body(fastapi_core::ResponseBody::Bytes(
1258 b"503 Service Unavailable: connection limit reached".to_vec(),
1259 ));
1260 let mut writer = crate::response::ResponseWriter::new();
1261 let response_bytes = writer.write(response);
1262 let _ = write_response(&mut stream, response_bytes).await;
1263 continue;
1264 }
1265
1266 if self.config.tcp_nodelay {
1268 let _ = stream.set_nodelay(true);
1269 }
1270
1271 cx.trace(&format!(
1272 "Accepted connection from {peer_addr} ({}/{})",
1273 self.current_connections(),
1274 if self.config.max_connections == 0 {
1275 "∞".to_string()
1276 } else {
1277 self.config.max_connections.to_string()
1278 }
1279 ));
1280
1281 let result = self
1283 .handle_connection_handler(cx, stream, peer_addr, &*handler)
1284 .await;
1285
1286 self.release_connection();
1287
1288 if let Err(e) = result {
1289 cx.trace(&format!("Connection error from {peer_addr}: {e}"));
1290 }
1291 }
1292 }
1293
1294 #[allow(clippy::too_many_lines)]
1307 pub async fn serve_concurrent<H, Fut>(
1308 &self,
1309 cx: &Cx,
1310 scope: &Scope<'_>,
1311 state: &mut RuntimeState,
1312 handler: H,
1313 ) -> Result<(), ServerError>
1314 where
1315 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1316 Fut: Future<Output = Response> + Send + 'static,
1317 {
1318 let bind_addr = self.config.bind_addr.clone();
1319 let listener = TcpListener::bind(bind_addr).await?;
1320 let local_addr = listener.local_addr()?;
1321
1322 cx.trace(&format!(
1323 "Server listening on {local_addr} (concurrent mode)"
1324 ));
1325
1326 let handler = Arc::new(handler);
1327
1328 self.accept_loop_concurrent(cx, scope, state, listener, handler)
1329 .await
1330 }
1331
1332 async fn accept_loop_concurrent<H, Fut>(
1334 &self,
1335 cx: &Cx,
1336 scope: &Scope<'_>,
1337 state: &mut RuntimeState,
1338 listener: TcpListener,
1339 handler: Arc<H>,
1340 ) -> Result<(), ServerError>
1341 where
1342 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1343 Fut: Future<Output = Response> + Send + 'static,
1344 {
1345 loop {
1346 if cx.is_cancel_requested() || self.is_draining() {
1348 cx.trace("Server shutting down, draining connections");
1349 self.drain_connection_tasks(cx).await;
1350 return Ok(());
1351 }
1352
1353 let (mut stream, peer_addr) = match listener.accept().await {
1355 Ok(conn) => conn,
1356 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
1357 continue;
1358 }
1359 Err(e) => {
1360 cx.trace(&format!("Accept error: {e}"));
1361 if is_fatal_accept_error(&e) {
1362 return Err(ServerError::Io(e));
1363 }
1364 continue;
1365 }
1366 };
1367
1368 if !self.try_acquire_connection() {
1370 cx.trace(&format!(
1371 "Connection limit reached ({}), rejecting {peer_addr}",
1372 self.config.max_connections
1373 ));
1374
1375 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
1376 .header("connection", b"close".to_vec())
1377 .body(fastapi_core::ResponseBody::Bytes(
1378 b"503 Service Unavailable: connection limit reached".to_vec(),
1379 ));
1380 let mut writer = crate::response::ResponseWriter::new();
1381 let response_bytes = writer.write(response);
1382 let _ = write_response(&mut stream, response_bytes).await;
1383 continue;
1384 }
1385
1386 if self.config.tcp_nodelay {
1388 let _ = stream.set_nodelay(true);
1389 }
1390
1391 cx.trace(&format!(
1392 "Accepted connection from {peer_addr} ({}/{})",
1393 self.current_connections(),
1394 if self.config.max_connections == 0 {
1395 "∞".to_string()
1396 } else {
1397 self.config.max_connections.to_string()
1398 }
1399 ));
1400
1401 match self.spawn_connection_task(
1403 scope,
1404 state,
1405 cx,
1406 stream,
1407 peer_addr,
1408 Arc::clone(&handler),
1409 ) {
1410 Ok(handle) => {
1411 if let Ok(mut handles) = self.connection_handles.lock() {
1413 handles.push(handle);
1414 }
1415 self.cleanup_completed_handles();
1417 }
1418 Err(e) => {
1419 cx.trace(&format!("Failed to spawn connection task: {e:?}"));
1420 self.release_connection();
1421 }
1422 }
1423 }
1424 }
1425
1426 fn spawn_connection_task<H, Fut>(
1428 &self,
1429 scope: &Scope<'_>,
1430 state: &mut RuntimeState,
1431 cx: &Cx,
1432 stream: TcpStream,
1433 peer_addr: SocketAddr,
1434 handler: Arc<H>,
1435 ) -> Result<TaskHandle<()>, SpawnError>
1436 where
1437 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1438 Fut: Future<Output = Response> + Send + 'static,
1439 {
1440 let config = self.config.clone();
1441 let request_counter = Arc::clone(&self.request_counter);
1442 let connection_counter = Arc::clone(&self.connection_counter);
1443
1444 scope.spawn_registered(state, cx, move |task_cx| async move {
1445 let result = process_connection(
1446 &task_cx,
1447 &request_counter,
1448 stream,
1449 peer_addr,
1450 &config,
1451 |ctx, req| handler(ctx, req),
1452 )
1453 .await;
1454
1455 connection_counter.fetch_sub(1, Ordering::Relaxed);
1457
1458 if let Err(e) = result {
1459 eprintln!("Connection error from {peer_addr}: {e}");
1461 }
1462 })
1463 }
1464
1465 fn cleanup_completed_handles(&self) {
1467 if let Ok(mut handles) = self.connection_handles.lock() {
1468 handles.retain(|handle| !handle.is_finished());
1469 }
1470 }
1471
1472 async fn drain_connection_tasks(&self, cx: &Cx) {
1474 let drain_timeout = self.config.drain_timeout;
1475 let start = Instant::now();
1476
1477 cx.trace(&format!(
1478 "Draining {} connection tasks (timeout: {:?})",
1479 self.connection_handles.lock().map_or(0, |h| h.len()),
1480 drain_timeout
1481 ));
1482
1483 while start.elapsed() < drain_timeout {
1485 let remaining = self
1486 .connection_handles
1487 .lock()
1488 .map_or(0, |h| h.iter().filter(|t| !t.is_finished()).count());
1489
1490 if remaining == 0 {
1491 cx.trace("All connection tasks drained successfully");
1492 return;
1493 }
1494
1495 asupersync::runtime::yield_now().await;
1497 }
1498
1499 cx.trace(&format!(
1500 "Drain timeout reached with {} tasks still running",
1501 self.connection_handles
1502 .lock()
1503 .map_or(0, |h| h.iter().filter(|t| !t.is_finished()).count())
1504 ));
1505 }
1506
    /// Serves sequential HTTP/1.1 requests on one connection through a
    /// `fastapi_core::Handler` trait object.
    ///
    /// Each iteration does the following:
    /// - obtains one complete request from the stateful parser, reading from `stream` as
    ///   needed;
    /// - calls `handler`;
    /// - adds a `connection` header reflecting the keep-alive decision;
    /// - writes the response.
    ///
    /// The loop ends with `Ok(())` in three cases: cancellation is requested on `cx`, the
    /// peer closes the connection (a read of 0 bytes), or a response has been sent with
    /// `connection: close`.
    ///
    /// `_peer_addr` is accepted but currently unused.
    ///
    /// # Errors
    ///
    /// - `ServerError::KeepAliveTimeout` if no data arrives within
    ///   `config.keep_alive_timeout`; the timeout is also counted in `total_timed_out`.
    /// - `ServerError::Io` (or a `?`-converted I/O error) for read/write failures.
    /// - Parser errors, converted into `ServerError` via `?`.
    async fn handle_connection_handler(
        &self,
        cx: &Cx,
        mut stream: TcpStream,
        _peer_addr: SocketAddr,
        handler: &dyn fastapi_core::Handler,
    ) -> Result<(), ServerError> {
        let mut parser = StatefulParser::new().with_limits(self.config.parse_limits.clone());
        let mut read_buffer = vec![0u8; self.config.read_buffer_size];
        let mut response_writer = ResponseWriter::new();
        let mut requests_on_connection: usize = 0;
        // 0 means "no per-connection request limit" (see the keep-alive decision below).
        let max_requests = self.config.max_requests_per_connection;

        loop {
            // Cooperative cancellation: checked once per request, before blocking on I/O.
            if cx.is_cancel_requested() {
                return Ok(());
            }

            // Feed an empty slice first so that a request already buffered by the parser
            // (e.g. one pipelined behind the previous request) is served before reading again.
            // NOTE(review): relies on StatefulParser keeping unconsumed bytes between
            // `feed` calls — confirm.
            let parse_result = parser.feed(&[])?;

            let mut request = match parse_result {
                ParseStatus::Complete { request, .. } => request,
                ParseStatus::Incomplete => {
                    let keep_alive_timeout = self.config.keep_alive_timeout;
                    // A zero keep-alive timeout disables the idle timeout entirely.
                    let bytes_read = if keep_alive_timeout.is_zero() {
                        read_into_buffer(&mut stream, &mut read_buffer).await?
                    } else {
                        match read_with_timeout(&mut stream, &mut read_buffer, keep_alive_timeout)
                            .await
                        {
                            Ok(0) => return Ok(()),
                            Ok(n) => n,
                            Err(e) if e.kind() == io::ErrorKind::TimedOut => {
                                self.metrics_counters
                                    .total_timed_out
                                    .fetch_add(1, Ordering::Relaxed);
                                return Err(ServerError::KeepAliveTimeout);
                            }
                            Err(e) => return Err(ServerError::Io(e)),
                        }
                    };

                    // Peer closed the connection. Only reachable on the no-timeout path,
                    // because the timeout path already returned on `Ok(0)`.
                    if bytes_read == 0 {
                        return Ok(());
                    }

                    self.record_bytes_in(bytes_read as u64);

                    match parser.feed(&read_buffer[..bytes_read])? {
                        ParseStatus::Complete { request, .. } => request,
                        // Not enough bytes for a full request yet: go back and read more.
                        ParseStatus::Incomplete => continue,
                    }
                }
            };

            requests_on_connection += 1;

            let request_id = self.request_counter.fetch_add(1, Ordering::Relaxed);
            // NOTE(review): `request_timeout` is used directly as the budget deadline. If
            // `Budget` deadlines are absolute `Time` values, this probably needs to be offset
            // from the current time — confirm against asupersync's `Budget` semantics.
            let request_budget = Budget::new().with_deadline(self.config.request_timeout);
            // NOTE(review): the per-request Cx is built with a testing constructor —
            // confirm this is intended for production traffic.
            let request_cx = Cx::for_testing_with_budget(request_budget);
            let ctx = RequestContext::new(request_cx, request_id);

            let response = handler.call(&ctx, &mut request).await;

            // Keep the connection open only if the client asked for it and the
            // per-connection request limit (0 = unlimited) has not been reached.
            let client_wants_keep_alive = should_keep_alive(&request);
            let server_will_keep_alive = client_wants_keep_alive
                && (max_requests == 0 || requests_on_connection < max_requests);

            let response = if server_will_keep_alive {
                response.header("connection", b"keep-alive".to_vec())
            } else {
                response.header("connection", b"close".to_vec())
            };

            let response_write = response_writer.write(response);
            // Only fully-buffered responses are counted in `bytes_out`; streamed bodies
            // are not measured here.
            if let ResponseWrite::Full(ref bytes) = response_write {
                self.record_bytes_out(bytes.len() as u64);
            }
            write_response(&mut stream, response_write).await?;

            if !server_will_keep_alive {
                return Ok(());
            }
        }
    }
1601
1602 async fn accept_loop<H, Fut>(
1604 &self,
1605 cx: &Cx,
1606 listener: TcpListener,
1607 handler: H,
1608 ) -> Result<(), ServerError>
1609 where
1610 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1611 Fut: Future<Output = Response> + Send + 'static,
1612 {
1613 let handler = Arc::new(handler);
1614
1615 loop {
1616 if cx.is_cancel_requested() {
1618 cx.trace("Server shutdown requested");
1619 return Ok(());
1620 }
1621
1622 if self.is_draining() {
1624 cx.trace("Server draining, stopping accept loop");
1625 return Err(ServerError::Shutdown);
1626 }
1627
1628 let (mut stream, peer_addr) = match listener.accept().await {
1630 Ok(conn) => conn,
1631 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
1632 continue;
1634 }
1635 Err(e) => {
1636 cx.trace(&format!("Accept error: {e}"));
1637 if is_fatal_accept_error(&e) {
1640 return Err(ServerError::Io(e));
1641 }
1642 continue;
1643 }
1644 };
1645
1646 if !self.try_acquire_connection() {
1648 cx.trace(&format!(
1649 "Connection limit reached ({}), rejecting {peer_addr}",
1650 self.config.max_connections
1651 ));
1652
1653 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
1655 .header("connection", b"close".to_vec())
1656 .body(fastapi_core::ResponseBody::Bytes(
1657 b"503 Service Unavailable: connection limit reached".to_vec(),
1658 ));
1659 let mut writer = crate::response::ResponseWriter::new();
1660 let response_bytes = writer.write(response);
1661 let _ = write_response(&mut stream, response_bytes).await;
1662 continue;
1663 }
1664
1665 if self.config.tcp_nodelay {
1667 let _ = stream.set_nodelay(true);
1668 }
1669
1670 cx.trace(&format!(
1671 "Accepted connection from {peer_addr} ({}/{})",
1672 self.current_connections(),
1673 if self.config.max_connections == 0 {
1674 "∞".to_string()
1675 } else {
1676 self.config.max_connections.to_string()
1677 }
1678 ));
1679
1680 #[cfg(feature = "concurrent")]
1684 {
1685 self.spawn_connection_handler(cx.clone(), stream, peer_addr, Arc::clone(&handler));
1686 }
1687
1688 #[cfg(not(feature = "concurrent"))]
1691 {
1692 let request_id = self.next_request_id();
1693 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
1694
1695 let request_cx = Cx::for_testing_with_budget(request_budget);
1699 let ctx = RequestContext::new(request_cx, request_id);
1700
1701 let result = self
1703 .handle_connection(&ctx, stream, peer_addr, &*handler)
1704 .await;
1705
1706 self.release_connection();
1708
1709 if let Err(e) = result {
1710 cx.trace(&format!("Connection error from {peer_addr}: {e}"));
1711 }
1712 }
1713 }
1714 }
1715
1716 #[cfg(feature = "concurrent")]
1724 fn spawn_connection_handler<H, Fut>(
1725 &self,
1726 server_cx: Cx,
1727 stream: TcpStream,
1728 peer_addr: SocketAddr,
1729 handler: Arc<H>,
1730 ) where
1731 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1732 Fut: Future<Output = Response> + Send + 'static,
1733 {
1734 let config = self.config.clone();
1736 let request_counter = Arc::clone(&self.request_counter);
1737 let connection_counter = Arc::clone(&self.connection_counter);
1738
1739 tokio::spawn(async move {
1743 let result = process_connection(
1744 &server_cx,
1745 &request_counter,
1746 stream,
1747 peer_addr,
1748 &config,
1749 |ctx, req| handler(ctx, req),
1750 )
1751 .await;
1752
1753 connection_counter.fetch_sub(1, Ordering::Relaxed);
1755
1756 if let Err(e) = result {
1757 server_cx.trace(&format!("Connection error from {peer_addr}: {e}"));
1758 }
1759 });
1760 }
1761
1762 async fn handle_connection<H, Fut>(
1768 &self,
1769 ctx: &RequestContext,
1770 stream: TcpStream,
1771 peer_addr: SocketAddr,
1772 handler: &H,
1773 ) -> Result<(), ServerError>
1774 where
1775 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync,
1776 Fut: Future<Output = Response> + Send,
1777 {
1778 process_connection(
1779 ctx.cx(),
1780 &self.request_counter,
1781 stream,
1782 peer_addr,
1783 &self.config,
1784 |ctx, req| handler(ctx, req),
1785 )
1786 .await
1787 }
1788}
1789
/// Point-in-time snapshot of server statistics.
///
/// Every field is a plain `u64` copy, so a snapshot can be cloned and compared freely.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerMetrics {
    /// Connections currently held (acquired and not yet released).
    pub active_connections: u64,
    /// Connections admitted since the server was created.
    pub total_accepted: u64,
    /// Connection attempts refused because the connection limit was reached.
    pub total_rejected: u64,
    /// Keep-alive timeouts observed while waiting for a client's next request.
    pub total_timed_out: u64,
    /// Requests handled so far (presumably taken from the shared request counter — confirm).
    pub total_requests: u64,
    /// Bytes read from clients, as recorded through `record_bytes_in`.
    pub bytes_in: u64,
    /// Bytes written to clients, as recorded through `record_bytes_out`. In
    /// `handle_connection_handler`, only fully-buffered (`ResponseWrite::Full`) responses
    /// are recorded.
    pub bytes_out: u64,
}
1811
/// Atomic counters backing the `ServerMetrics` snapshots.
///
/// Each field is an independent `AtomicU64`, so counters can be bumped from any
/// connection task without taking a lock.
#[derive(Debug)]
struct MetricsCounters {
    /// Admitted connections (presumably bumped by `try_acquire_connection` — confirm).
    total_accepted: AtomicU64,
    /// Refused connections (presumably bumped when the connection limit is hit — confirm).
    total_rejected: AtomicU64,
    /// Keep-alive timeouts observed by connection handlers.
    total_timed_out: AtomicU64,
    /// Bytes read from clients.
    bytes_in: AtomicU64,
    /// Bytes written to clients.
    bytes_out: AtomicU64,
}

impl MetricsCounters {
    /// Creates a counter set with every counter at zero.
    fn new() -> Self {
        let zeroed = || AtomicU64::new(0);
        Self {
            total_accepted: zeroed(),
            total_rejected: zeroed(),
            total_timed_out: zeroed(),
            bytes_in: zeroed(),
            bytes_out: zeroed(),
        }
    }
}
1836
1837impl Default for TcpServer {
1838 fn default() -> Self {
1839 Self::new(ServerConfig::default())
1840 }
1841}
1842
/// Reports whether an `accept()` failure should stop the accept loop.
///
/// Only `NotConnected` and `InvalidInput` are treated as fatal. Every other error kind is
/// considered transient: the accept loop logs it and keeps accepting.
fn is_fatal_accept_error(e: &io::Error) -> bool {
    const FATAL_KINDS: [io::ErrorKind; 2] =
        [io::ErrorKind::NotConnected, io::ErrorKind::InvalidInput];
    FATAL_KINDS.contains(&e.kind())
}
1851
1852async fn read_into_buffer(stream: &mut TcpStream, buffer: &mut [u8]) -> io::Result<usize> {
1856 use std::future::poll_fn;
1857
1858 poll_fn(|cx| {
1859 let mut read_buf = ReadBuf::new(buffer);
1860 match Pin::new(&mut *stream).poll_read(cx, &mut read_buf) {
1861 Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
1862 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1863 Poll::Pending => Poll::Pending,
1864 }
1865 })
1866 .await
1867}
1868
1869async fn read_with_timeout(
1887 stream: &mut TcpStream,
1888 buffer: &mut [u8],
1889 timeout_duration: Duration,
1890) -> io::Result<usize> {
1891 let now = current_time();
1893
1894 let read_future = Box::pin(read_into_buffer(stream, buffer));
1896
1897 match timeout(now, timeout_duration, read_future).await {
1899 Ok(result) => result,
1900 Err(_elapsed) => Err(io::Error::new(
1901 io::ErrorKind::TimedOut,
1902 "keep-alive timeout expired",
1903 )),
1904 }
1905}
1906
1907async fn write_raw_response(stream: &mut TcpStream, bytes: &[u8]) -> io::Result<()> {
1911 use std::future::poll_fn;
1912 write_all(stream, bytes).await?;
1913 poll_fn(|cx| Pin::new(&mut *stream).poll_flush(cx)).await?;
1914 Ok(())
1915}
1916
1917async fn write_response(stream: &mut TcpStream, response: ResponseWrite) -> io::Result<()> {
1921 use std::future::poll_fn;
1922
1923 match response {
1924 ResponseWrite::Full(bytes) => {
1925 write_all(stream, &bytes).await?;
1926 }
1927 ResponseWrite::Stream(mut encoder) => {
1928 loop {
1930 let chunk = poll_fn(|cx| Pin::new(&mut encoder).poll_next(cx)).await;
1931 match chunk {
1932 Some(bytes) => {
1933 write_all(stream, &bytes).await?;
1934 }
1935 None => break,
1936 }
1937 }
1938 }
1939 }
1940
1941 poll_fn(|cx| Pin::new(&mut *stream).poll_flush(cx)).await?;
1943
1944 Ok(())
1945}
1946
1947async fn write_all(stream: &mut TcpStream, mut buf: &[u8]) -> io::Result<()> {
1949 use std::future::poll_fn;
1950
1951 while !buf.is_empty() {
1952 let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, buf)).await?;
1953 if n == 0 {
1954 return Err(io::Error::new(
1955 io::ErrorKind::WriteZero,
1956 "failed to write whole buffer",
1957 ));
1958 }
1959 buf = &buf[n..];
1960 }
1961 Ok(())
1962}
1963
/// Socket-free helper that parses raw request bytes and encodes responses.
pub struct Server {
    /// Parser used by `parse_request`.
    parser: Parser,
}
1978
1979impl Server {
1980 #[must_use]
1982 pub fn new() -> Self {
1983 Self {
1984 parser: Parser::new(),
1985 }
1986 }
1987
1988 pub fn parse_request(&self, bytes: &[u8]) -> Result<Request, ParseError> {
1994 self.parser.parse(bytes)
1995 }
1996
1997 #[must_use]
1999 pub fn write_response(&self, response: Response) -> ResponseWrite {
2000 let mut writer = ResponseWriter::new();
2001 writer.write(response)
2002 }
2003}
2004
2005impl Default for Server {
2006 fn default() -> Self {
2007 Self::new()
2008 }
2009}
2010
#[cfg(test)]
mod tests {
    // Unit tests covering configuration, host validation, keep-alive decisions,
    // connection limiting, draining/shutdown, and metrics.
    use super::*;

    // --- ServerConfig builder and defaults ---

    #[test]
    fn server_config_builder() {
        let config = ServerConfig::new("0.0.0.0:3000")
            .with_request_timeout_secs(60)
            .with_max_connections(1000)
            .with_tcp_nodelay(false)
            .with_allowed_hosts(["example.com", "api.example.com"])
            .with_trust_x_forwarded_host(true);

        assert_eq!(config.bind_addr, "0.0.0.0:3000");
        assert_eq!(config.request_timeout, Time::from_secs(60));
        assert_eq!(config.max_connections, 1000);
        assert!(!config.tcp_nodelay);
        assert_eq!(config.allowed_hosts.len(), 2);
        assert!(config.trust_x_forwarded_host);
    }

    #[test]
    fn server_config_defaults() {
        let config = ServerConfig::default();
        assert_eq!(config.bind_addr, "127.0.0.1:8080");
        assert_eq!(
            config.request_timeout,
            Time::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)
        );
        assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS);
        assert!(config.tcp_nodelay);
        assert!(config.allowed_hosts.is_empty());
        assert!(!config.trust_x_forwarded_host);
    }

    // --- Request ids, error display, and the socket-free Server helper ---

    #[test]
    fn tcp_server_creates_request_ids() {
        let server = TcpServer::default();
        let id1 = server.next_request_id();
        let id2 = server.next_request_id();
        let id3 = server.next_request_id();

        // Ids are handed out sequentially, starting from 0.
        assert_eq!(id1, 0);
        assert_eq!(id2, 1);
        assert_eq!(id3, 2);
    }

    #[test]
    fn server_error_display() {
        let io_err = ServerError::Io(io::Error::new(io::ErrorKind::AddrInUse, "address in use"));
        assert!(io_err.to_string().contains("IO error"));

        let shutdown_err = ServerError::Shutdown;
        assert_eq!(shutdown_err.to_string(), "Server shutdown");

        let limit_err = ServerError::ConnectionLimitReached;
        assert_eq!(limit_err.to_string(), "Connection limit reached");
    }

    #[test]
    fn sync_server_parses_request() {
        let server = Server::new();
        let request = b"GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n";
        let result = server.parse_request(request);
        assert!(result.is_ok());
    }

    // --- Host header validation ---

    #[test]
    fn host_validation_missing_host_rejected() {
        let config = ServerConfig::default();
        let request = Request::new(fastapi_core::Method::Get, "/");
        let err = validate_host_header(&request, &config).unwrap_err();
        assert_eq!(err.kind, HostValidationErrorKind::Missing);
        assert_eq!(err.response().status().as_u16(), 400);
    }

    #[test]
    fn host_validation_allows_configured_host() {
        let config = ServerConfig::default().with_allowed_hosts(["example.com"]);
        let mut request = Request::new(fastapi_core::Method::Get, "/");
        request
            .headers_mut()
            .insert("Host".to_string(), b"example.com".to_vec());
        assert!(validate_host_header(&request, &config).is_ok());
    }

    #[test]
    fn host_validation_rejects_disallowed_host() {
        let config = ServerConfig::default().with_allowed_hosts(["example.com"]);
        let mut request = Request::new(fastapi_core::Method::Get, "/");
        request
            .headers_mut()
            .insert("Host".to_string(), b"evil.com".to_vec());
        let err = validate_host_header(&request, &config).unwrap_err();
        assert_eq!(err.kind, HostValidationErrorKind::NotAllowed);
    }

    #[test]
    fn host_validation_wildcard_allows_subdomains_only() {
        let config = ServerConfig::default().with_allowed_hosts(["*.example.com"]);
        let mut request = Request::new(fastapi_core::Method::Get, "/");
        request
            .headers_mut()
            .insert("Host".to_string(), b"api.example.com".to_vec());
        assert!(validate_host_header(&request, &config).is_ok());

        // The bare apex domain must not match a `*.` wildcard.
        let mut request = Request::new(fastapi_core::Method::Get, "/");
        request
            .headers_mut()
            .insert("Host".to_string(), b"example.com".to_vec());
        let err = validate_host_header(&request, &config).unwrap_err();
        assert_eq!(err.kind, HostValidationErrorKind::NotAllowed);
    }

    #[test]
    fn host_validation_uses_x_forwarded_host_when_trusted() {
        let config = ServerConfig::default()
            .with_allowed_hosts(["example.com"])
            .with_trust_x_forwarded_host(true);
        let mut request = Request::new(fastapi_core::Method::Get, "/");
        request
            .headers_mut()
            .insert("Host".to_string(), b"internal.local".to_vec());
        request
            .headers_mut()
            .insert("X-Forwarded-Host".to_string(), b"example.com".to_vec());
        assert!(validate_host_header(&request, &config).is_ok());
    }

    #[test]
    fn host_validation_rejects_invalid_host_value() {
        let config = ServerConfig::default();
        let mut request = Request::new(fastapi_core::Method::Get, "/");
        request
            .headers_mut()
            .insert("Host".to_string(), b"bad host".to_vec());
        let err = validate_host_header(&request, &config).unwrap_err();
        assert_eq!(err.kind, HostValidationErrorKind::Invalid);
    }

    // --- Keep-alive decision (`should_keep_alive`) ---

    #[test]
    fn keep_alive_default_http11() {
        let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
        request
            .headers_mut()
            .insert("Host".to_string(), b"example.com".to_vec());
        assert!(should_keep_alive(&request));
    }

    #[test]
    fn keep_alive_explicit_keep_alive() {
        let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
        request
            .headers_mut()
            .insert("Connection".to_string(), b"keep-alive".to_vec());
        assert!(should_keep_alive(&request));
    }

    #[test]
    fn keep_alive_connection_close() {
        let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
        request
            .headers_mut()
            .insert("Connection".to_string(), b"close".to_vec());
        assert!(!should_keep_alive(&request));
    }

    #[test]
    fn keep_alive_connection_close_case_insensitive() {
        let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
        request
            .headers_mut()
            .insert("Connection".to_string(), b"CLOSE".to_vec());
        assert!(!should_keep_alive(&request));
    }

    #[test]
    fn keep_alive_multiple_values() {
        let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
        request
            .headers_mut()
            .insert("Connection".to_string(), b"keep-alive, upgrade".to_vec());
        assert!(should_keep_alive(&request));
    }

    // --- Request timeout helpers ---

    #[test]
    fn timeout_budget_created_with_config_deadline() {
        let config = ServerConfig::new("127.0.0.1:8080").with_request_timeout_secs(45);
        let budget = Budget::new().with_deadline(config.request_timeout);
        assert_eq!(budget.deadline, Some(Time::from_secs(45)));
    }

    #[test]
    fn timeout_duration_conversion_from_time() {
        let timeout = Time::from_secs(30);
        let duration = Duration::from_nanos(timeout.as_nanos());
        assert_eq!(duration, Duration::from_secs(30));
    }

    #[test]
    fn timeout_duration_conversion_from_time_millis() {
        let timeout = Time::from_millis(1500);
        let duration = Duration::from_nanos(timeout.as_nanos());
        assert_eq!(duration, Duration::from_millis(1500));
    }

    #[test]
    fn gateway_timeout_response_has_correct_status() {
        let response = Response::with_status(StatusCode::GATEWAY_TIMEOUT);
        assert_eq!(response.status().as_u16(), 504);
    }

    #[test]
    fn gateway_timeout_response_with_body() {
        let response = Response::with_status(StatusCode::GATEWAY_TIMEOUT).body(
            fastapi_core::ResponseBody::Bytes(b"Request timed out".to_vec()),
        );
        assert_eq!(response.status().as_u16(), 504);
        assert!(response.body_ref().len() > 0);
    }

    // NOTE(review): exercises only `std::time::Instant` (no server code) and sleeps 20 ms.
    #[test]
    fn elapsed_time_check_logic() {
        let start = Instant::now();
        let timeout_duration = Duration::from_millis(10);

        assert!(start.elapsed() <= timeout_duration);

        std::thread::sleep(Duration::from_millis(20));

        assert!(start.elapsed() > timeout_duration);
    }

    // --- Connection limiting ---

    #[test]
    fn connection_counter_starts_at_zero() {
        let server = TcpServer::default();
        assert_eq!(server.current_connections(), 0);
    }

    #[test]
    fn try_acquire_connection_unlimited() {
        let server = TcpServer::default();
        // The default limit of 0 means "unlimited".
        assert_eq!(server.config().max_connections, 0);

        for _ in 0..100 {
            assert!(server.try_acquire_connection());
        }
        assert_eq!(server.current_connections(), 100);

        for _ in 0..100 {
            server.release_connection();
        }
        assert_eq!(server.current_connections(), 0);
    }

    #[test]
    fn try_acquire_connection_with_limit() {
        let config = ServerConfig::new("127.0.0.1:8080").with_max_connections(5);
        let server = TcpServer::new(config);

        for i in 0..5 {
            assert!(
                server.try_acquire_connection(),
                "Should acquire connection {i}"
            );
        }
        assert_eq!(server.current_connections(), 5);

        // At the limit: further acquisitions fail without changing the count.
        assert!(!server.try_acquire_connection());
        assert_eq!(server.current_connections(), 5);

        server.release_connection();
        assert_eq!(server.current_connections(), 4);

        assert!(server.try_acquire_connection());
        assert_eq!(server.current_connections(), 5);
    }

    #[test]
    fn try_acquire_connection_single_connection_limit() {
        let config = ServerConfig::new("127.0.0.1:8080").with_max_connections(1);
        let server = TcpServer::new(config);

        assert!(server.try_acquire_connection());
        assert_eq!(server.current_connections(), 1);

        assert!(!server.try_acquire_connection());
        assert_eq!(server.current_connections(), 1);

        server.release_connection();
        assert!(server.try_acquire_connection());
    }

    #[test]
    fn service_unavailable_response_has_correct_status() {
        let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(response.status().as_u16(), 503);
    }

    #[test]
    fn service_unavailable_response_with_body() {
        let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
            .header("connection", b"close".to_vec())
            .body(fastapi_core::ResponseBody::Bytes(
                b"503 Service Unavailable: connection limit reached".to_vec(),
            ));
        assert_eq!(response.status().as_u16(), 503);
        assert!(response.body_ref().len() > 0);
    }

    #[test]
    fn config_max_connections_default_is_zero() {
        let config = ServerConfig::default();
        assert_eq!(config.max_connections, 0);
    }

    #[test]
    fn config_max_connections_can_be_set() {
        let config = ServerConfig::new("127.0.0.1:8080").with_max_connections(100);
        assert_eq!(config.max_connections, 100);
    }

    // --- Keep-alive timeout and per-connection request limit configuration ---

    #[test]
    fn config_keep_alive_timeout_default() {
        let config = ServerConfig::default();
        assert_eq!(
            config.keep_alive_timeout,
            Duration::from_secs(DEFAULT_KEEP_ALIVE_TIMEOUT_SECS)
        );
    }

    #[test]
    fn config_keep_alive_timeout_can_be_set() {
        let config =
            ServerConfig::new("127.0.0.1:8080").with_keep_alive_timeout(Duration::from_secs(120));
        assert_eq!(config.keep_alive_timeout, Duration::from_secs(120));
    }

    #[test]
    fn config_keep_alive_timeout_can_be_set_secs() {
        let config = ServerConfig::new("127.0.0.1:8080").with_keep_alive_timeout_secs(90);
        assert_eq!(config.keep_alive_timeout, Duration::from_secs(90));
    }

    #[test]
    fn config_max_requests_per_connection_default() {
        let config = ServerConfig::default();
        assert_eq!(
            config.max_requests_per_connection,
            DEFAULT_MAX_REQUESTS_PER_CONNECTION
        );
    }

    #[test]
    fn config_max_requests_per_connection_can_be_set() {
        let config = ServerConfig::new("127.0.0.1:8080").with_max_requests_per_connection(50);
        assert_eq!(config.max_requests_per_connection, 50);
    }

    #[test]
    fn config_max_requests_per_connection_unlimited() {
        let config = ServerConfig::new("127.0.0.1:8080").with_max_requests_per_connection(0);
        assert_eq!(config.max_requests_per_connection, 0);
    }

    #[test]
    fn response_with_keep_alive_header() {
        let response = Response::ok().header("connection", b"keep-alive".to_vec());
        let headers = response.headers();
        let connection_header = headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("connection"));
        assert!(connection_header.is_some());
        assert_eq!(connection_header.unwrap().1, b"keep-alive");
    }

    #[test]
    fn response_with_close_header() {
        let response = Response::ok().header("connection", b"close".to_vec());
        let headers = response.headers();
        let connection_header = headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("connection"));
        assert!(connection_header.is_some());
        assert_eq!(connection_header.unwrap().1, b"close");
    }

    // --- Graceful drain ---

    #[test]
    fn config_drain_timeout_default() {
        let config = ServerConfig::default();
        assert_eq!(
            config.drain_timeout,
            Duration::from_secs(DEFAULT_DRAIN_TIMEOUT_SECS)
        );
    }

    #[test]
    fn config_drain_timeout_can_be_set() {
        let config =
            ServerConfig::new("127.0.0.1:8080").with_drain_timeout(Duration::from_secs(60));
        assert_eq!(config.drain_timeout, Duration::from_secs(60));
    }

    #[test]
    fn config_drain_timeout_can_be_set_secs() {
        let config = ServerConfig::new("127.0.0.1:8080").with_drain_timeout_secs(45);
        assert_eq!(config.drain_timeout, Duration::from_secs(45));
    }

    #[test]
    fn server_not_draining_initially() {
        let server = TcpServer::default();
        assert!(!server.is_draining());
    }

    #[test]
    fn server_start_drain_sets_flag() {
        let server = TcpServer::default();
        assert!(!server.is_draining());
        server.start_drain();
        assert!(server.is_draining());
    }

    #[test]
    fn server_start_drain_idempotent() {
        let server = TcpServer::default();
        server.start_drain();
        assert!(server.is_draining());
        server.start_drain();
        assert!(server.is_draining());
    }

    #[tokio::test]
    async fn wait_for_drain_returns_true_when_no_connections() {
        let server = TcpServer::default();
        assert_eq!(server.current_connections(), 0);
        let result = server
            .wait_for_drain(Duration::from_millis(100), Some(Duration::from_millis(1)))
            .await;
        assert!(result);
    }

    #[tokio::test]
    async fn wait_for_drain_timeout_with_connections() {
        let server = TcpServer::default();
        // Hold two connection slots that are never released.
        server.try_acquire_connection();
        server.try_acquire_connection();
        assert_eq!(server.current_connections(), 2);

        let result = server
            .wait_for_drain(Duration::from_millis(50), Some(Duration::from_millis(5)))
            .await;
        assert!(!result);
        assert_eq!(server.current_connections(), 2);
    }

    #[tokio::test]
    async fn drain_returns_zero_when_no_connections() {
        let server = TcpServer::new(
            ServerConfig::new("127.0.0.1:8080").with_drain_timeout(Duration::from_millis(100)),
        );
        assert_eq!(server.current_connections(), 0);
        let remaining = server.drain().await;
        assert_eq!(remaining, 0);
        assert!(server.is_draining());
    }

    #[tokio::test]
    async fn drain_returns_count_when_connections_remain() {
        let server = TcpServer::new(
            ServerConfig::new("127.0.0.1:8080").with_drain_timeout(Duration::from_millis(50)),
        );
        // Three slots held for the whole test, so the drain must time out.
        server.try_acquire_connection();
        server.try_acquire_connection();
        server.try_acquire_connection();

        let remaining = server.drain().await;
        assert_eq!(remaining, 3);
        assert!(server.is_draining());
    }

    #[test]
    fn server_shutdown_error_display() {
        let err = ServerError::Shutdown;
        assert_eq!(err.to_string(), "Server shutdown");
    }

    // --- Shutdown controller integration ---

    #[test]
    fn server_has_shutdown_controller() {
        let server = TcpServer::default();
        let controller = server.shutdown_controller();
        assert!(!controller.is_shutting_down());
    }

    #[test]
    fn server_subscribe_shutdown_returns_receiver() {
        let server = TcpServer::default();
        let receiver = server.subscribe_shutdown();
        assert!(!receiver.is_shutting_down());
    }

    #[test]
    fn server_shutdown_sets_draining_and_controller() {
        let server = TcpServer::default();
        assert!(!server.is_shutting_down());
        assert!(!server.is_draining());
        assert!(!server.shutdown_controller().is_shutting_down());

        server.shutdown();

        assert!(server.is_shutting_down());
        assert!(server.is_draining());
        assert!(server.shutdown_controller().is_shutting_down());
    }

    #[test]
    fn server_shutdown_notifies_receivers() {
        let server = TcpServer::default();
        let receiver1 = server.subscribe_shutdown();
        let receiver2 = server.subscribe_shutdown();

        assert!(!receiver1.is_shutting_down());
        assert!(!receiver2.is_shutting_down());

        server.shutdown();

        assert!(receiver1.is_shutting_down());
        assert!(receiver2.is_shutting_down());
    }

    #[test]
    fn server_shutdown_is_idempotent() {
        let server = TcpServer::default();
        let receiver = server.subscribe_shutdown();

        server.shutdown();
        server.shutdown();
        server.shutdown();

        assert!(server.is_shutting_down());
        assert!(receiver.is_shutting_down());
    }

    // --- Keep-alive timeout behavior ---

    #[test]
    fn keep_alive_timeout_error_display() {
        let err = ServerError::KeepAliveTimeout;
        assert_eq!(err.to_string(), "Keep-alive timeout");
    }

    #[test]
    fn keep_alive_timeout_zero_disables_timeout() {
        let config = ServerConfig::new("127.0.0.1:8080").with_keep_alive_timeout(Duration::ZERO);
        assert!(config.keep_alive_timeout.is_zero());
    }

    #[test]
    fn keep_alive_timeout_default_is_non_zero() {
        let config = ServerConfig::default();
        assert!(!config.keep_alive_timeout.is_zero());
        assert_eq!(
            config.keep_alive_timeout,
            Duration::from_secs(DEFAULT_KEEP_ALIVE_TIMEOUT_SECS)
        );
    }

    #[test]
    fn timed_out_io_error_kind() {
        let err = io::Error::new(io::ErrorKind::TimedOut, "test timeout");
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    // NOTE(review): exercises only `std::time::Instant` (no server code) and sleeps 150 ms.
    #[test]
    fn instant_deadline_calculation() {
        let timeout = Duration::from_millis(100);
        let deadline = Instant::now() + timeout;

        assert!(deadline > Instant::now());

        std::thread::sleep(Duration::from_millis(150));
        assert!(Instant::now() >= deadline);
    }

    // --- Metrics ---

    #[test]
    fn server_metrics_initial_state() {
        let server = TcpServer::default();
        let m = server.metrics();
        assert_eq!(m.active_connections, 0);
        assert_eq!(m.total_accepted, 0);
        assert_eq!(m.total_rejected, 0);
        assert_eq!(m.total_timed_out, 0);
        assert_eq!(m.total_requests, 0);
        assert_eq!(m.bytes_in, 0);
        assert_eq!(m.bytes_out, 0);
    }

    #[test]
    fn server_metrics_after_acquire_release() {
        let server = TcpServer::new(ServerConfig::new("127.0.0.1:0").with_max_connections(10));
        assert!(server.try_acquire_connection());
        assert!(server.try_acquire_connection());

        let m = server.metrics();
        assert_eq!(m.active_connections, 2);
        assert_eq!(m.total_accepted, 2);
        assert_eq!(m.total_rejected, 0);

        server.release_connection();
        let m = server.metrics();
        assert_eq!(m.active_connections, 1);
        // Releasing does not decrease the cumulative accepted count.
        assert_eq!(m.total_accepted, 2);
    }

    #[test]
    fn server_metrics_rejection_counted() {
        let server = TcpServer::new(ServerConfig::new("127.0.0.1:0").with_max_connections(1));
        assert!(server.try_acquire_connection());
        assert!(!server.try_acquire_connection());
        let m = server.metrics();
        assert_eq!(m.total_accepted, 1);
        assert_eq!(m.total_rejected, 1);
        assert_eq!(m.active_connections, 1);
    }

    #[test]
    fn server_metrics_bytes_tracking() {
        let server = TcpServer::default();
        server.record_bytes_in(1024);
        server.record_bytes_in(512);
        server.record_bytes_out(2048);

        let m = server.metrics();
        assert_eq!(m.bytes_in, 1536);
        assert_eq!(m.bytes_out, 2048);
    }

    #[test]
    fn server_metrics_unlimited_connections_accepted() {
        let server = TcpServer::new(ServerConfig::new("127.0.0.1:0").with_max_connections(0));
        for _ in 0..100 {
            assert!(server.try_acquire_connection());
        }
        let m = server.metrics();
        assert_eq!(m.total_accepted, 100);
        assert_eq!(m.total_rejected, 0);
        assert_eq!(m.active_connections, 100);
    }

    #[test]
    fn server_metrics_clone_eq() {
        let server = TcpServer::default();
        server.record_bytes_in(42);
        let m1 = server.metrics();
        let m2 = m1.clone();
        assert_eq!(m1, m2);
    }
}
2721
/// Extension trait that lets an application run itself as an HTTP server.
///
/// Both methods consume `self` and return a `Send` future that resolves once serving ends.
pub trait AppServeExt {
    /// Serves the application on `addr`.
    ///
    /// # Errors
    ///
    /// Resolves to [`ServeError::Startup`] if a startup hook aborts, or
    /// [`ServeError::Server`] if the server returns an error.
    fn serve(self, addr: impl Into<String>) -> impl Future<Output = Result<(), ServeError>> + Send;

    /// Serves the application using an explicit [`ServerConfig`].
    ///
    /// # Errors
    ///
    /// Resolves to [`ServeError::Startup`] if a startup hook aborts, or
    /// [`ServeError::Server`] if the server returns an error.
    fn serve_with_config(
        self,
        config: ServerConfig,
    ) -> impl Future<Output = Result<(), ServeError>> + Send;
}
2809
/// Error returned by the [`AppServeExt`] serve methods and the free `serve` functions.
#[derive(Debug)]
pub enum ServeError {
    /// A startup hook aborted; in the `App` implementation this happens before the
    /// TCP server is constructed.
    Startup(fastapi_core::StartupHookError),
    /// The TCP server returned an error from `serve_handler`.
    Server(ServerError),
}
2818
2819impl std::fmt::Display for ServeError {
2820 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2821 match self {
2822 Self::Startup(e) => write!(f, "startup hook failed: {}", e.message),
2823 Self::Server(e) => write!(f, "server error: {e}"),
2824 }
2825 }
2826}
2827
2828impl std::error::Error for ServeError {
2829 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
2830 match self {
2831 Self::Startup(_) => None,
2832 Self::Server(e) => Some(e),
2833 }
2834 }
2835}
2836
2837impl From<ServerError> for ServeError {
2838 fn from(e: ServerError) -> Self {
2839 Self::Server(e)
2840 }
2841}
2842
2843impl AppServeExt for App {
2844 fn serve(self, addr: impl Into<String>) -> impl Future<Output = Result<(), ServeError>> + Send {
2845 let config = ServerConfig::new(addr);
2846 self.serve_with_config(config)
2847 }
2848
2849 #[allow(clippy::manual_async_fn)] fn serve_with_config(
2851 self,
2852 config: ServerConfig,
2853 ) -> impl Future<Output = Result<(), ServeError>> + Send {
2854 async move {
2855 match self.run_startup_hooks().await {
2857 fastapi_core::StartupOutcome::Success => {}
2858 fastapi_core::StartupOutcome::PartialSuccess { warnings } => {
2859 eprintln!("Warning: {warnings} startup hook(s) had non-fatal errors");
2861 }
2862 fastapi_core::StartupOutcome::Aborted(e) => {
2863 return Err(ServeError::Startup(e));
2864 }
2865 }
2866
2867 let server = TcpServer::new(config);
2869
2870 let app = Arc::new(self);
2873 let handler: Arc<dyn fastapi_core::Handler> =
2874 Arc::clone(&app) as Arc<dyn fastapi_core::Handler>;
2875
2876 let cx = Cx::for_testing();
2878
2879 let bind_addr = &server.config().bind_addr;
2881 println!("🚀 Server starting on http://{bind_addr}");
2882
2883 let result = server.serve_handler(&cx, handler).await;
2885
2886 app.run_shutdown_hooks().await;
2888
2889 result.map_err(ServeError::from)
2890 }
2891 }
2892}
2893
2894pub async fn serve(app: App, addr: impl Into<String>) -> Result<(), ServeError> {
2912 app.serve(addr).await
2913}
2914
2915pub async fn serve_with_config(app: App, config: ServerConfig) -> Result<(), ServeError> {
2933 app.serve_with_config(config).await
2934}