1#![allow(dead_code)]
10use crate::connection::should_keep_alive;
49use crate::expect::{
50 CONTINUE_RESPONSE, ExpectHandler, ExpectResult, PreBodyValidator, PreBodyValidators,
51};
52use crate::http2;
53use crate::parser::{ParseError, ParseLimits, ParseStatus, Parser, StatefulParser};
54use crate::response::{ResponseWrite, ResponseWriter};
55use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
56use asupersync::net::{TcpListener, TcpStream};
57use asupersync::runtime::{RuntimeState, SpawnError, TaskHandle};
58use asupersync::signal::{GracefulOutcome, ShutdownController, ShutdownReceiver};
59use asupersync::stream::Stream;
60use asupersync::time::timeout;
61use asupersync::{Budget, Cx, Scope, Time};
62use fastapi_core::app::App;
63use fastapi_core::{Method, Request, RequestContext, Response, StatusCode};
64use std::future::Future;
65use std::io;
66use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
67use std::pin::Pin;
68use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
69use std::sync::{Arc, Mutex, OnceLock};
70use std::task::Poll;
71use std::time::{Duration, Instant};
72
73static START_TIME: OnceLock<Instant> = OnceLock::new();
76
77fn current_time() -> Time {
82 let start = START_TIME.get_or_init(Instant::now);
83 let now = Instant::now();
84 if now < *start {
85 Time::ZERO
86 } else {
87 let elapsed = now.duration_since(*start);
88 Time::from_nanos(elapsed.as_nanos() as u64)
89 }
90}
91
/// Default per-request processing time limit, in seconds.
pub const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 30;

/// Default size of the per-connection read buffer, in bytes (8 KiB).
pub const DEFAULT_READ_BUFFER_SIZE: usize = 8 * 1024;

/// Default connection cap. NOTE(review): `0` presumably means "unlimited" — confirm
/// against the accept loop.
pub const DEFAULT_MAX_CONNECTIONS: usize = 0;

/// Default idle keep-alive timeout, in seconds.
pub const DEFAULT_KEEP_ALIVE_TIMEOUT_SECS: u64 = 75;

/// Default maximum number of requests served on one HTTP/1.x connection
/// (`0` would mean unlimited).
pub const DEFAULT_MAX_REQUESTS_PER_CONNECTION: usize = 100;

/// Default drain timeout, in seconds.
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
109
110#[derive(Debug, Clone)]
138pub struct ServerConfig {
139 pub bind_addr: String,
141 pub request_timeout: Time,
143 pub max_connections: usize,
145 pub read_buffer_size: usize,
147 pub parse_limits: ParseLimits,
149 pub allowed_hosts: Vec<String>,
151 pub trust_x_forwarded_host: bool,
153 pub tcp_nodelay: bool,
155 pub keep_alive_timeout: Duration,
158 pub max_requests_per_connection: usize,
160 pub drain_timeout: Duration,
163 pub pre_body_validators: PreBodyValidators,
168}
169
170impl ServerConfig {
171 #[must_use]
173 pub fn new(bind_addr: impl Into<String>) -> Self {
174 Self {
175 bind_addr: bind_addr.into(),
176 request_timeout: Time::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS),
177 max_connections: DEFAULT_MAX_CONNECTIONS,
178 read_buffer_size: DEFAULT_READ_BUFFER_SIZE,
179 parse_limits: ParseLimits::default(),
180 allowed_hosts: Vec::new(),
181 trust_x_forwarded_host: false,
182 tcp_nodelay: true,
183 keep_alive_timeout: Duration::from_secs(DEFAULT_KEEP_ALIVE_TIMEOUT_SECS),
184 max_requests_per_connection: DEFAULT_MAX_REQUESTS_PER_CONNECTION,
185 drain_timeout: Duration::from_secs(DEFAULT_DRAIN_TIMEOUT_SECS),
186 pre_body_validators: PreBodyValidators::new(),
187 }
188 }
189
190 #[must_use]
192 pub fn with_request_timeout(mut self, timeout: Time) -> Self {
193 self.request_timeout = timeout;
194 self
195 }
196
197 #[must_use]
199 pub fn with_request_timeout_secs(mut self, secs: u64) -> Self {
200 self.request_timeout = Time::from_secs(secs);
201 self
202 }
203
204 #[must_use]
206 pub fn with_max_connections(mut self, max: usize) -> Self {
207 self.max_connections = max;
208 self
209 }
210
211 #[must_use]
213 pub fn with_read_buffer_size(mut self, size: usize) -> Self {
214 self.read_buffer_size = size;
215 self
216 }
217
218 #[must_use]
220 pub fn with_parse_limits(mut self, limits: ParseLimits) -> Self {
221 self.parse_limits = limits;
222 self
223 }
224
225 #[must_use]
230 pub fn with_allowed_hosts<I, S>(mut self, hosts: I) -> Self
231 where
232 I: IntoIterator<Item = S>,
233 S: Into<String>,
234 {
235 self.allowed_hosts = hosts
237 .into_iter()
238 .map(|s| s.into().to_ascii_lowercase())
239 .collect();
240 self
241 }
242
243 #[must_use]
247 pub fn allow_host(mut self, host: impl Into<String>) -> Self {
248 self.allowed_hosts.push(host.into().to_ascii_lowercase());
250 self
251 }
252
253 #[must_use]
255 pub fn with_trust_x_forwarded_host(mut self, trust: bool) -> Self {
256 self.trust_x_forwarded_host = trust;
257 self
258 }
259
260 #[must_use]
262 pub fn with_tcp_nodelay(mut self, enabled: bool) -> Self {
263 self.tcp_nodelay = enabled;
264 self
265 }
266
267 #[must_use]
269 pub fn with_pre_body_validators(mut self, validators: PreBodyValidators) -> Self {
270 self.pre_body_validators = validators;
271 self
272 }
273
274 #[must_use]
276 pub fn with_pre_body_validator<V: PreBodyValidator + 'static>(mut self, validator: V) -> Self {
277 self.pre_body_validators.add(validator);
278 self
279 }
280
281 #[must_use]
286 pub fn with_keep_alive_timeout(mut self, timeout: Duration) -> Self {
287 self.keep_alive_timeout = timeout;
288 self
289 }
290
291 #[must_use]
293 pub fn with_keep_alive_timeout_secs(mut self, secs: u64) -> Self {
294 self.keep_alive_timeout = Duration::from_secs(secs);
295 self
296 }
297
298 #[must_use]
302 pub fn with_max_requests_per_connection(mut self, max: usize) -> Self {
303 self.max_requests_per_connection = max;
304 self
305 }
306
307 #[must_use]
312 pub fn with_drain_timeout(mut self, timeout: Duration) -> Self {
313 self.drain_timeout = timeout;
314 self
315 }
316
317 #[must_use]
319 pub fn with_drain_timeout_secs(mut self, secs: u64) -> Self {
320 self.drain_timeout = Duration::from_secs(secs);
321 self
322 }
323}
324
325impl Default for ServerConfig {
326 fn default() -> Self {
327 Self::new("127.0.0.1:8080")
328 }
329}
330
331#[derive(Debug)]
333pub enum ServerError {
334 Io(io::Error),
336 Parse(ParseError),
338 Http2(http2::Http2Error),
340 Shutdown,
342 ConnectionLimitReached,
344 KeepAliveTimeout,
346}
347
348impl std::fmt::Display for ServerError {
349 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 match self {
351 Self::Io(e) => write!(f, "IO error: {e}"),
352 Self::Parse(e) => write!(f, "Parse error: {e}"),
353 Self::Http2(e) => write!(f, "HTTP/2 error: {e}"),
354 Self::Shutdown => write!(f, "Server shutdown"),
355 Self::ConnectionLimitReached => write!(f, "Connection limit reached"),
356 Self::KeepAliveTimeout => write!(f, "Keep-alive timeout"),
357 }
358 }
359}
360
361impl std::error::Error for ServerError {
362 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
363 match self {
364 Self::Io(e) => Some(e),
365 Self::Parse(e) => Some(e),
366 Self::Http2(e) => Some(e),
367 _ => None,
368 }
369 }
370}
371
/// Category of a host validation failure; selects the client-facing message.
#[derive(Clone, Debug, Eq, PartialEq)]
enum HostValidationErrorKind {
    /// No `Host` header was present.
    Missing,
    /// The host value was malformed.
    Invalid,
    /// The host is well-formed but matches no allowed pattern.
    NotAllowed,
}

/// A host validation failure: its category plus a diagnostic detail.
#[derive(Clone, Debug)]
struct HostValidationError {
    kind: HostValidationErrorKind,
    /// Human-readable explanation used for tracing; the response body is chosen
    /// from `kind` alone.
    detail: String,
}
388
389impl HostValidationError {
390 fn missing() -> Self {
391 Self {
392 kind: HostValidationErrorKind::Missing,
393 detail: "missing Host header".to_string(),
394 }
395 }
396
397 fn invalid(detail: impl Into<String>) -> Self {
398 Self {
399 kind: HostValidationErrorKind::Invalid,
400 detail: detail.into(),
401 }
402 }
403
404 fn not_allowed(detail: impl Into<String>) -> Self {
405 Self {
406 kind: HostValidationErrorKind::NotAllowed,
407 detail: detail.into(),
408 }
409 }
410
411 fn response(&self) -> Response {
412 let message = match self.kind {
413 HostValidationErrorKind::Missing => "Bad Request: Host header required",
414 HostValidationErrorKind::Invalid => "Bad Request: invalid Host header",
415 HostValidationErrorKind::NotAllowed => "Bad Request: Host not allowed",
416 };
417 Response::with_status(StatusCode::BAD_REQUEST).body(fastapi_core::ResponseBody::Bytes(
418 message.as_bytes().to_vec(),
419 ))
420 }
421}
422
/// A parsed, normalized host value.
#[derive(Clone, Debug, Eq, PartialEq)]
struct HostHeader {
    /// Lowercased hostname, IPv4 address, or IPv6 address (without brackets).
    host: String,
    /// Explicit port, if one was given.
    port: Option<u16>,
}
428
429fn validate_host_header(
430 request: &Request,
431 config: &ServerConfig,
432) -> Result<HostHeader, HostValidationError> {
433 let raw = extract_effective_host(request, config)?;
434 let parsed = parse_host_header(&raw)
435 .ok_or_else(|| HostValidationError::invalid(format!("invalid host value: {raw}")))?;
436
437 if !is_allowed_host(&parsed, &config.allowed_hosts) {
438 return Err(HostValidationError::not_allowed(format!(
439 "host not allowed: {}",
440 parsed.host
441 )));
442 }
443
444 Ok(parsed)
445}
446
447fn extract_effective_host(
448 request: &Request,
449 config: &ServerConfig,
450) -> Result<String, HostValidationError> {
451 if config.trust_x_forwarded_host {
452 if let Some(value) = header_value(request, "x-forwarded-host")? {
453 let forwarded = extract_first_list_value(&value)
454 .ok_or_else(|| HostValidationError::invalid("empty X-Forwarded-Host value"))?;
455 return Ok(forwarded.to_string());
456 }
457 }
458
459 match header_value(request, "host")? {
460 Some(value) => Ok(value),
461 None => Err(HostValidationError::missing()),
462 }
463}
464
465fn header_value(request: &Request, name: &str) -> Result<Option<String>, HostValidationError> {
466 request
467 .headers()
468 .get(name)
469 .map(|bytes| {
470 std::str::from_utf8(bytes)
471 .map(|s| s.trim().to_string())
472 .map_err(|_| {
473 HostValidationError::invalid(format!("invalid UTF-8 in {name} header"))
474 })
475 })
476 .transpose()
477}
478
/// Returns the first non-empty, trimmed entry of a comma-separated header list.
fn extract_first_list_value(value: &str) -> Option<&str> {
    for item in value.split(',') {
        let item = item.trim();
        if !item.is_empty() {
            return Some(item);
        }
    }
    None
}
482
483fn parse_host_header(value: &str) -> Option<HostHeader> {
484 let value = value.trim();
485 if value.is_empty() {
486 return None;
487 }
488 if value.chars().any(|c| c.is_control() || c.is_whitespace()) {
489 return None;
490 }
491
492 if value.starts_with('[') {
493 let end = value.find(']')?;
494 let host = &value[1..end];
495 if host.is_empty() {
496 return None;
497 }
498 if host.parse::<Ipv6Addr>().is_err() {
499 return None;
500 }
501 let rest = &value[end + 1..];
502 let port = if rest.is_empty() {
503 None
504 } else if let Some(port_str) = rest.strip_prefix(':') {
505 parse_port(port_str)
506 } else {
507 return None;
508 };
509 return Some(HostHeader {
510 host: host.to_ascii_lowercase(),
511 port,
512 });
513 }
514
515 let mut parts = value.split(':');
516 let host = parts.next().unwrap_or("");
517 let port_part = parts.next();
518 if parts.next().is_some() {
519 return None;
521 }
522 if host.is_empty() {
523 return None;
524 }
525
526 let port = match port_part {
527 Some(p) => parse_port(p),
528 None => None,
529 };
530
531 if host.parse::<Ipv4Addr>().is_ok() || is_valid_hostname(host) {
532 Some(HostHeader {
533 host: host.to_ascii_lowercase(),
534 port,
535 })
536 } else {
537 None
538 }
539}
540
541fn parse_port(port: &str) -> Option<u16> {
542 if port.is_empty() || !port.chars().all(|c| c.is_ascii_digit()) {
543 return None;
544 }
545 let value = port.parse::<u16>().ok()?;
546 if value == 0 { None } else { Some(value) }
547}
548
549fn is_valid_hostname(host: &str) -> bool {
550 if host.len() > 253 {
552 return false;
553 }
554 for label in host.split('.') {
555 if label.is_empty() || label.len() > 63 {
556 return false;
557 }
558 let bytes = label.as_bytes();
559 if bytes.first() == Some(&b'-') || bytes.last() == Some(&b'-') {
560 return false;
561 }
562 if !label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
563 return false;
564 }
565 }
566 true
567}
568
569fn is_allowed_host(host: &HostHeader, allowed_hosts: &[String]) -> bool {
570 if allowed_hosts.is_empty() {
571 return true;
572 }
573
574 allowed_hosts
575 .iter()
576 .any(|pattern| host_matches_pattern(host, pattern))
577}
578
579fn host_matches_pattern(host: &HostHeader, pattern: &str) -> bool {
580 let pattern = pattern.trim();
582 if pattern.is_empty() {
583 return false;
584 }
585 if pattern == "*" {
586 return true;
587 }
588 if let Some(suffix) = pattern.strip_prefix("*.") {
589 if host.host == suffix {
591 return false;
592 }
593 return host.host.len() > suffix.len() + 1
594 && host.host.ends_with(suffix)
595 && host.host.as_bytes()[host.host.len() - suffix.len() - 1] == b'.';
596 }
597
598 if let Some(parsed) = parse_host_header(pattern) {
599 if parsed.host != host.host {
600 return false;
601 }
602 if let Some(port) = parsed.port {
603 return host.port == Some(port);
604 }
605 return true;
606 }
607
608 false
609}
610
611fn header_str<'a>(req: &'a Request, name: &str) -> Option<&'a str> {
612 req.headers()
613 .get(name)
614 .and_then(|v| std::str::from_utf8(v).ok())
615 .map(str::trim)
616}
617
618fn header_has_token(req: &Request, name: &str, token: &str) -> bool {
619 let Some(v) = header_str(req, name) else {
620 return false;
621 };
622 v.split(',')
623 .map(str::trim)
624 .any(|t| t.eq_ignore_ascii_case(token))
625}
626
627fn connection_has_token(req: &Request, token: &str) -> bool {
628 header_has_token(req, "connection", token)
629}
630
631fn is_websocket_upgrade_request(req: &Request) -> bool {
632 if req.method() != Method::Get {
633 return false;
634 }
635 if !header_has_token(req, "upgrade", "websocket") {
636 return false;
637 }
638 connection_has_token(req, "upgrade")
639}
640
641fn has_request_body_headers(req: &Request) -> bool {
642 if req.headers().contains("transfer-encoding") {
643 return true;
644 }
645 if let Some(v) = header_str(req, "content-length") {
646 if v.is_empty() {
647 return true;
648 }
649 match v.parse::<usize>() {
650 Ok(0) => false,
651 Ok(_) => true,
652 Err(_) => true,
653 }
654 } else {
655 false
656 }
657}
658
659impl From<io::Error> for ServerError {
660 fn from(e: io::Error) -> Self {
661 Self::Io(e)
662 }
663}
664
665impl From<ParseError> for ServerError {
666 fn from(e: ParseError) -> Self {
667 Self::Parse(e)
668 }
669}
670
671impl From<http2::Http2Error> for ServerError {
672 fn from(e: http2::Http2Error) -> Self {
673 Self::Http2(e)
674 }
675}
676
677async fn process_connection<H, Fut>(
681 cx: &Cx,
682 request_counter: &AtomicU64,
683 mut stream: TcpStream,
684 _peer_addr: SocketAddr,
685 config: &ServerConfig,
686 handler: H,
687) -> Result<(), ServerError>
688where
689 H: Fn(RequestContext, &mut Request) -> Fut,
690 Fut: Future<Output = Response>,
691{
692 let (proto, buffered) = sniff_protocol(&mut stream, config.keep_alive_timeout).await?;
693 if proto == SniffedProtocol::Http2PriorKnowledge {
694 return process_connection_http2(cx, request_counter, stream, config, handler).await;
695 }
696
697 let mut parser = StatefulParser::new().with_limits(config.parse_limits.clone());
698 if !buffered.is_empty() {
699 parser.feed(&buffered)?;
700 }
701 let mut read_buffer = vec![0u8; config.read_buffer_size];
702 let mut response_writer = ResponseWriter::new();
703 let mut requests_on_connection: usize = 0;
704 let max_requests = config.max_requests_per_connection;
705
706 loop {
707 if cx.is_cancel_requested() {
709 return Ok(());
710 }
711
712 let parse_result = parser.feed(&[])?;
714
715 let mut request = match parse_result {
716 ParseStatus::Complete { request, .. } => request,
717 ParseStatus::Incomplete => {
718 let keep_alive_timeout = config.keep_alive_timeout;
719
720 let bytes_read = if keep_alive_timeout.is_zero() {
721 read_into_buffer(&mut stream, &mut read_buffer).await?
722 } else {
723 match read_with_timeout(&mut stream, &mut read_buffer, keep_alive_timeout).await
724 {
725 Ok(0) => return Ok(()),
726 Ok(n) => n,
727 Err(e) if e.kind() == io::ErrorKind::TimedOut => {
728 cx.trace(&format!(
729 "Keep-alive timeout ({:?}) - closing idle connection",
730 keep_alive_timeout
731 ));
732 return Err(ServerError::KeepAliveTimeout);
733 }
734 Err(e) => return Err(ServerError::Io(e)),
735 }
736 };
737
738 if bytes_read == 0 {
739 return Ok(());
740 }
741
742 match parser.feed(&read_buffer[..bytes_read])? {
743 ParseStatus::Complete { request, .. } => request,
744 ParseStatus::Incomplete => continue,
745 }
746 }
747 };
748
749 requests_on_connection += 1;
750
751 let request_id = request_counter.fetch_add(1, Ordering::Relaxed);
753 let request_budget = Budget::new().with_deadline(config.request_timeout);
754 let request_cx = Cx::for_testing_with_budget(request_budget);
755 let ctx = RequestContext::new(request_cx, request_id);
756
757 if let Err(err) = validate_host_header(&request, config) {
759 ctx.trace(&format!("Rejecting request: {}", err.detail));
760 let response = err.response().header("connection", b"close".to_vec());
761 let response_write = response_writer.write(response);
762 write_response(&mut stream, response_write).await?;
763 return Ok(());
764 }
765
766 if let Err(response) = config.pre_body_validators.validate_all(&request) {
768 let response = response.header("connection", b"close".to_vec());
769 let response_write = response_writer.write(response);
770 write_response(&mut stream, response_write).await?;
771 return Ok(());
772 }
773
774 match ExpectHandler::check_expect(&request) {
778 ExpectResult::NoExpectation => {
779 }
781 ExpectResult::ExpectsContinue => {
782 ctx.trace("Sending 100 Continue for Expect: 100-continue");
785 write_raw_response(&mut stream, CONTINUE_RESPONSE).await?;
786 }
787 ExpectResult::UnknownExpectation(value) => {
788 ctx.trace(&format!("Rejecting unknown Expect value: {}", value));
790 let response =
791 ExpectHandler::expectation_failed(format!("Unsupported Expect value: {value}"));
792 let response_write = response_writer.write(response);
793 write_response(&mut stream, response_write).await?;
794 return Ok(());
795 }
796 }
797
798 let client_wants_keep_alive = should_keep_alive(&request);
799 let at_max_requests = max_requests > 0 && requests_on_connection >= max_requests;
800 let server_will_keep_alive = client_wants_keep_alive && !at_max_requests;
801
802 let request_start = Instant::now();
803 let timeout_duration = Duration::from_nanos(config.request_timeout.as_nanos());
804
805 let response = handler(ctx, &mut request).await;
807
808 let mut response = if request_start.elapsed() > timeout_duration {
809 Response::with_status(StatusCode::GATEWAY_TIMEOUT).body(
810 fastapi_core::ResponseBody::Bytes(
811 b"Gateway Timeout: request processing exceeded time limit".to_vec(),
812 ),
813 )
814 } else {
815 response
816 };
817
818 response = if server_will_keep_alive {
819 response.header("connection", b"keep-alive".to_vec())
820 } else {
821 response.header("connection", b"close".to_vec())
822 };
823
824 let response_write = response_writer.write(response);
825 write_response(&mut stream, response_write).await?;
826
827 if let Some(tasks) = App::take_background_tasks(&mut request) {
828 tasks.execute_all().await;
829 }
830
831 if !server_will_keep_alive {
832 return Ok(());
833 }
834 }
835}
836
837async fn process_connection_http2<H, Fut>(
838 cx: &Cx,
839 request_counter: &AtomicU64,
840 stream: TcpStream,
841 config: &ServerConfig,
842 handler: H,
843) -> Result<(), ServerError>
844where
845 H: Fn(RequestContext, &mut Request) -> Fut,
846 Fut: Future<Output = Response>,
847{
848 const FLAG_END_HEADERS: u8 = 0x4;
849 const FLAG_ACK: u8 = 0x1;
850
851 let mut framed = http2::FramedH2::new(stream, Vec::new());
852 let mut hpack = http2::HpackDecoder::new();
853 let recv_max_frame_size: u32 = 16 * 1024;
854 let mut peer_max_frame_size: u32 = 16 * 1024;
855 let mut flow_control = http2::H2FlowControl::new();
856
857 let first = framed.read_frame(recv_max_frame_size).await?;
858 if first.header.frame_type() != http2::FrameType::Settings
859 || first.header.stream_id != 0
860 || (first.header.flags & FLAG_ACK) != 0
861 {
862 return Err(http2::Http2Error::Protocol("expected client SETTINGS after preface").into());
863 }
864 apply_http2_settings_with_fc(
865 &mut hpack,
866 &mut peer_max_frame_size,
867 Some(&mut flow_control),
868 &first.payload,
869 )?;
870
871 framed
872 .write_frame(http2::FrameType::Settings, 0, 0, SERVER_SETTINGS_PAYLOAD)
873 .await?;
874 framed
875 .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
876 .await?;
877
878 let default_body_limit = config.parse_limits.max_request_size;
879 let mut last_stream_id: u32 = 0;
880
881 loop {
882 if cx.is_cancel_requested() {
883 let _ = send_goaway(&mut framed, last_stream_id, h2_error_code::NO_ERROR).await;
884 return Ok(());
885 }
886
887 let frame = framed.read_frame(recv_max_frame_size).await?;
888 match frame.header.frame_type() {
889 http2::FrameType::Settings => {
890 let is_ack = validate_settings_frame(
891 frame.header.stream_id,
892 frame.header.flags,
893 &frame.payload,
894 )?;
895 if is_ack {
896 continue;
897 }
898 apply_http2_settings_with_fc(
899 &mut hpack,
900 &mut peer_max_frame_size,
901 Some(&mut flow_control),
902 &frame.payload,
903 )?;
904 framed
905 .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
906 .await?;
907 }
908 http2::FrameType::Ping => {
909 if frame.header.stream_id != 0 || frame.payload.len() != 8 {
910 return Err(http2::Http2Error::Protocol("invalid PING frame").into());
911 }
912 if (frame.header.flags & FLAG_ACK) == 0 {
913 framed
914 .write_frame(http2::FrameType::Ping, FLAG_ACK, 0, &frame.payload)
915 .await?;
916 }
917 }
918 http2::FrameType::Goaway => {
919 validate_goaway_payload(&frame.payload)?;
920 return Ok(());
921 }
922 http2::FrameType::PushPromise => {
923 return Err(
924 http2::Http2Error::Protocol("PUSH_PROMISE not supported by server").into(),
925 );
926 }
927 http2::FrameType::Headers => {
928 let stream_id = frame.header.stream_id;
929 if stream_id == 0 {
930 return Err(
931 http2::Http2Error::Protocol("HEADERS must not be on stream 0").into(),
932 );
933 }
934 if stream_id % 2 == 0 {
935 return Err(http2::Http2Error::Protocol(
936 "client-initiated stream ID must be odd",
937 )
938 .into());
939 }
940 if stream_id <= last_stream_id {
941 return Err(http2::Http2Error::Protocol(
942 "stream ID must be greater than previous",
943 )
944 .into());
945 }
946 last_stream_id = stream_id;
947 let (end_stream, mut header_block) =
948 extract_header_block_fragment(frame.header.flags, &frame.payload)?;
949
950 if (frame.header.flags & FLAG_END_HEADERS) == 0 {
951 loop {
952 let cont = framed.read_frame(recv_max_frame_size).await?;
953 if cont.header.frame_type() != http2::FrameType::Continuation
954 || cont.header.stream_id != stream_id
955 {
956 return Err(http2::Http2Error::Protocol(
957 "expected CONTINUATION for header block",
958 )
959 .into());
960 }
961 header_block.extend_from_slice(&cont.payload);
962 if header_block.len() > MAX_HEADER_BLOCK_SIZE {
963 return Err(http2::Http2Error::Protocol(
964 "header block exceeds maximum size",
965 )
966 .into());
967 }
968 if (cont.header.flags & FLAG_END_HEADERS) != 0 {
969 break;
970 }
971 }
972 }
973
974 let headers = hpack
975 .decode(&header_block)
976 .map_err(http2::Http2Error::from)?;
977 let mut request = request_from_h2_headers(headers)?;
978
979 if !end_stream {
980 let mut body = Vec::new();
981 let mut stream_reset = false;
982 let mut stream_received: u32 = 0;
983 loop {
984 let f = framed.read_frame(recv_max_frame_size).await?;
985 match f.header.frame_type() {
986 http2::FrameType::Data if f.header.stream_id == 0 => {
987 return Err(http2::Http2Error::Protocol(
988 "DATA must not be on stream 0",
989 )
990 .into());
991 }
992 http2::FrameType::Data if f.header.stream_id == stream_id => {
993 let (data, data_end_stream) =
994 extract_data_payload(f.header.flags, &f.payload)?;
995 if body.len().saturating_add(data.len()) > default_body_limit {
996 return Err(http2::Http2Error::Protocol(
997 "request body exceeds configured limit",
998 )
999 .into());
1000 }
1001 body.extend_from_slice(data);
1002
1003 let data_len = u32::try_from(data.len()).unwrap_or(u32::MAX);
1006 stream_received += data_len;
1007 let conn_inc = flow_control.data_received_connection(data_len);
1008 let stream_inc = flow_control.stream_window_update(stream_received);
1009 if stream_inc > 0 {
1010 stream_received = 0;
1011 }
1012 send_window_updates(&mut framed, conn_inc, stream_id, stream_inc)
1013 .await?;
1014
1015 if data_end_stream {
1016 break;
1017 }
1018 }
1019 http2::FrameType::RstStream => {
1020 validate_rst_stream_payload(f.header.stream_id, &f.payload)?;
1021 if f.header.stream_id == stream_id {
1022 stream_reset = true;
1023 break;
1024 }
1025 }
1026 http2::FrameType::PushPromise => {
1027 return Err(http2::Http2Error::Protocol(
1028 "PUSH_PROMISE not supported by server",
1029 )
1030 .into());
1031 }
1032 http2::FrameType::Settings
1033 | http2::FrameType::Ping
1034 | http2::FrameType::Goaway
1035 | http2::FrameType::WindowUpdate
1036 | http2::FrameType::Priority
1037 | http2::FrameType::Unknown => {
1038 if f.header.frame_type() == http2::FrameType::Goaway {
1039 validate_goaway_payload(&f.payload)?;
1040 return Ok(());
1041 }
1042 if f.header.frame_type() == http2::FrameType::Priority {
1043 validate_priority_payload(f.header.stream_id, &f.payload)?;
1044 }
1045 if f.header.frame_type() == http2::FrameType::WindowUpdate {
1046 validate_window_update_payload(&f.payload)?;
1047 let increment = u32::from_be_bytes([
1048 f.payload[0],
1049 f.payload[1],
1050 f.payload[2],
1051 f.payload[3],
1052 ]) & 0x7FFF_FFFF;
1053 if f.header.stream_id == 0 {
1054 apply_send_conn_window_update(
1055 &mut flow_control,
1056 increment,
1057 )?;
1058 }
1059 }
1060 if f.header.frame_type() == http2::FrameType::Ping {
1061 if f.header.stream_id != 0 || f.payload.len() != 8 {
1062 return Err(http2::Http2Error::Protocol(
1063 "invalid PING frame",
1064 )
1065 .into());
1066 }
1067 if (f.header.flags & FLAG_ACK) == 0 {
1068 framed
1069 .write_frame(
1070 http2::FrameType::Ping,
1071 FLAG_ACK,
1072 0,
1073 &f.payload,
1074 )
1075 .await?;
1076 }
1077 }
1078 if f.header.frame_type() == http2::FrameType::Settings {
1079 let is_ack = validate_settings_frame(
1080 f.header.stream_id,
1081 f.header.flags,
1082 &f.payload,
1083 )?;
1084 if !is_ack {
1085 apply_http2_settings_with_fc(
1086 &mut hpack,
1087 &mut peer_max_frame_size,
1088 Some(&mut flow_control),
1089 &f.payload,
1090 )?;
1091 framed
1092 .write_frame(
1093 http2::FrameType::Settings,
1094 FLAG_ACK,
1095 0,
1096 &[],
1097 )
1098 .await?;
1099 }
1100 }
1101 }
1102 _ => {
1103 return Err(http2::Http2Error::Protocol(
1104 "unsupported frame while reading request body",
1105 )
1106 .into());
1107 }
1108 }
1109 }
1110 if stream_reset {
1111 continue;
1112 }
1113 request.set_body(fastapi_core::Body::Bytes(body));
1114 }
1115
1116 let request_id = request_counter.fetch_add(1, Ordering::Relaxed);
1117 let request_budget = Budget::new().with_deadline(config.request_timeout);
1118 let request_cx = Cx::for_testing_with_budget(request_budget);
1119 let ctx = RequestContext::new(request_cx, request_id);
1120
1121 if let Err(err) = validate_host_header(&request, config) {
1122 let response = err.response();
1123 process_connection_http2_write_response(
1124 &mut framed,
1125 response,
1126 stream_id,
1127 peer_max_frame_size,
1128 recv_max_frame_size,
1129 Some(&mut flow_control),
1130 )
1131 .await?;
1132 continue;
1133 }
1134
1135 if let Err(response) = config.pre_body_validators.validate_all(&request) {
1136 process_connection_http2_write_response(
1137 &mut framed,
1138 response,
1139 stream_id,
1140 peer_max_frame_size,
1141 recv_max_frame_size,
1142 Some(&mut flow_control),
1143 )
1144 .await?;
1145 continue;
1146 }
1147
1148 let response = handler(ctx, &mut request).await;
1149 process_connection_http2_write_response(
1150 &mut framed,
1151 response,
1152 stream_id,
1153 peer_max_frame_size,
1154 recv_max_frame_size,
1155 Some(&mut flow_control),
1156 )
1157 .await?;
1158
1159 if let Some(tasks) = App::take_background_tasks(&mut request) {
1160 tasks.execute_all().await;
1161 }
1162 }
1163 http2::FrameType::WindowUpdate => {
1164 validate_window_update_payload(&frame.payload)?;
1165 let increment = u32::from_be_bytes([
1166 frame.payload[0],
1167 frame.payload[1],
1168 frame.payload[2],
1169 frame.payload[3],
1170 ]) & 0x7FFF_FFFF;
1171 if frame.header.stream_id == 0 {
1172 apply_send_conn_window_update(&mut flow_control, increment)?;
1173 }
1174 }
1175 _ => {
1176 handle_h2_idle_frame(&frame)?;
1177 }
1178 }
1179 }
1180}
1181
1182async fn process_connection_http2_write_response(
1183 framed: &mut http2::FramedH2,
1184 response: Response,
1185 stream_id: u32,
1186 mut peer_max_frame_size: u32,
1187 recv_max_frame_size: u32,
1188 mut flow_control: Option<&mut http2::H2FlowControl>,
1189) -> Result<(), ServerError> {
1190 use std::future::poll_fn;
1191
1192 const FLAG_END_STREAM: u8 = 0x1;
1193 const FLAG_END_HEADERS: u8 = 0x4;
1194
1195 let (status, mut headers, mut body) = response.into_parts();
1196 if !status.allows_body() {
1197 body = fastapi_core::ResponseBody::Empty;
1198 }
1199
1200 let mut add_content_length = matches!(body, fastapi_core::ResponseBody::Bytes(_));
1201 for (name, _) in &headers {
1202 if name.eq_ignore_ascii_case("content-length") {
1203 add_content_length = false;
1204 break;
1205 }
1206 }
1207 if add_content_length {
1208 headers.push((
1209 "content-length".to_string(),
1210 body.len().to_string().into_bytes(),
1211 ));
1212 }
1213
1214 let mut block: Vec<u8> = Vec::new();
1215 let status_bytes = status.as_u16().to_string().into_bytes();
1216 http2::hpack_encode_literal_without_indexing(&mut block, b":status", &status_bytes);
1217 for (name, value) in &headers {
1218 if is_h2_forbidden_header_name(name) {
1219 continue;
1220 }
1221 let n = name.to_ascii_lowercase();
1222 http2::hpack_encode_literal_without_indexing(&mut block, n.as_bytes(), value);
1223 }
1224
1225 let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
1226 let mut headers_flags = FLAG_END_HEADERS;
1227 if body.is_empty() {
1228 headers_flags |= FLAG_END_STREAM;
1229 }
1230
1231 if block.len() <= max {
1232 framed
1233 .write_frame(http2::FrameType::Headers, headers_flags, stream_id, &block)
1234 .await?;
1235 } else {
1236 let mut first_flags = 0u8;
1238 if body.is_empty() {
1239 first_flags |= FLAG_END_STREAM;
1240 }
1241 let (first, rest) = block.split_at(max);
1242 framed
1243 .write_frame(http2::FrameType::Headers, first_flags, stream_id, first)
1244 .await?;
1245 let mut remaining = rest;
1246 while remaining.len() > max {
1247 let (chunk, r) = remaining.split_at(max);
1248 framed
1249 .write_frame(http2::FrameType::Continuation, 0, stream_id, chunk)
1250 .await?;
1251 remaining = r;
1252 }
1253 framed
1254 .write_frame(
1255 http2::FrameType::Continuation,
1256 FLAG_END_HEADERS,
1257 stream_id,
1258 remaining,
1259 )
1260 .await?;
1261 }
1262
1263 let mut stream_send_window: i64 = flow_control
1265 .as_ref()
1266 .map_or(i64::MAX, |fc| i64::from(fc.peer_initial_window_size()));
1267
1268 match body {
1269 fastapi_core::ResponseBody::Empty => Ok(()),
1270 fastapi_core::ResponseBody::Bytes(bytes) => {
1271 if bytes.is_empty() {
1272 return Ok(());
1273 }
1274 let mut remaining = bytes.as_slice();
1275 while !remaining.is_empty() {
1276 let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
1277 let send_len = remaining.len().min(max);
1278
1279 let send_len = h2_fc_clamp_send(
1280 framed,
1281 &mut flow_control,
1282 &mut stream_send_window,
1283 stream_id,
1284 send_len,
1285 &mut peer_max_frame_size,
1286 recv_max_frame_size,
1287 )
1288 .await?;
1289
1290 let (chunk, r) = remaining.split_at(send_len);
1291 let flags = if r.is_empty() { FLAG_END_STREAM } else { 0 };
1292 framed
1293 .write_frame(http2::FrameType::Data, flags, stream_id, chunk)
1294 .await?;
1295 remaining = r;
1296 }
1297 Ok(())
1298 }
1299 fastapi_core::ResponseBody::Stream(mut s) => {
1300 loop {
1301 let next = poll_fn(|cx| Pin::new(&mut s).poll_next(cx)).await;
1302 match next {
1303 Some(chunk) => {
1304 let mut remaining = chunk.as_slice();
1305 while !remaining.is_empty() {
1306 let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
1307 let send_len = remaining.len().min(max);
1308 let send_len = h2_fc_clamp_send(
1309 framed,
1310 &mut flow_control,
1311 &mut stream_send_window,
1312 stream_id,
1313 send_len,
1314 &mut peer_max_frame_size,
1315 recv_max_frame_size,
1316 )
1317 .await?;
1318
1319 let (c, r) = remaining.split_at(send_len);
1320 framed
1321 .write_frame(http2::FrameType::Data, 0, stream_id, c)
1322 .await?;
1323 remaining = r;
1324 }
1325 }
1326 None => {
1327 framed
1328 .write_frame(http2::FrameType::Data, FLAG_END_STREAM, stream_id, &[])
1329 .await?;
1330 break;
1331 }
1332 }
1333 }
1334 Ok(())
1335 }
1336 }
1337}
1338
1339async fn h2_fc_clamp_send(
1344 framed: &mut http2::FramedH2,
1345 flow_control: &mut Option<&mut http2::H2FlowControl>,
1346 stream_send_window: &mut i64,
1347 stream_id: u32,
1348 desired: usize,
1349 peer_max_frame_size: &mut u32,
1350 recv_max_frame_size: u32,
1351) -> Result<usize, ServerError> {
1352 let fc = match flow_control.as_mut() {
1353 Some(fc) => fc,
1354 None => return Ok(desired),
1355 };
1356
1357 loop {
1358 let conn_avail = usize::try_from(fc.send_conn_window().max(0)).unwrap_or(0);
1359 let stream_avail = usize::try_from((*stream_send_window).max(0)).unwrap_or(0);
1360 let peer_max = usize::try_from(*peer_max_frame_size).unwrap_or(16 * 1024);
1361 let allowed = desired.min(conn_avail).min(stream_avail).min(peer_max);
1362
1363 if allowed > 0 {
1364 let send = allowed;
1365 fc.consume_send_conn_window(u32::try_from(send).unwrap_or(u32::MAX));
1366 *stream_send_window -= i64::try_from(send).unwrap_or(i64::MAX);
1367 return Ok(send);
1368 }
1369
1370 let frame = framed.read_frame(recv_max_frame_size).await?;
1372 match frame.header.frame_type() {
1373 http2::FrameType::WindowUpdate => {
1374 apply_peer_window_update_for_send(
1375 fc,
1376 stream_send_window,
1377 stream_id,
1378 frame.header.stream_id,
1379 &frame.payload,
1380 )?;
1381 }
1382 http2::FrameType::Ping => {
1383 if frame.header.stream_id != 0 || frame.payload.len() != 8 {
1384 return Err(ServerError::Http2(http2::Http2Error::Protocol(
1385 "invalid PING frame",
1386 )));
1387 }
1388 if frame.header.flags & 0x1 == 0 {
1389 framed
1390 .write_frame(http2::FrameType::Ping, 0x1, 0, &frame.payload)
1391 .await?;
1392 }
1393 }
1394 http2::FrameType::Settings => {
1395 let is_ack = validate_settings_frame(
1396 frame.header.stream_id,
1397 frame.header.flags,
1398 &frame.payload,
1399 )?;
1400 if !is_ack {
1401 apply_peer_settings_for_send(
1402 fc,
1403 stream_send_window,
1404 peer_max_frame_size,
1405 &frame.payload,
1406 )?;
1407 framed
1409 .write_frame(http2::FrameType::Settings, 0x1, 0, &[])
1410 .await?;
1411 }
1412 }
1413 http2::FrameType::Goaway => {
1414 validate_goaway_payload(&frame.payload)?;
1415 return Err(ServerError::Http2(http2::Http2Error::Protocol(
1416 "received GOAWAY while writing response",
1417 )));
1418 }
1419 http2::FrameType::RstStream => {
1420 validate_rst_stream_payload(frame.header.stream_id, &frame.payload)?;
1421 if frame.header.stream_id == stream_id {
1422 return Err(ServerError::Http2(http2::Http2Error::Protocol(
1423 "stream reset by peer during response",
1424 )));
1425 }
1426 }
1427 _ => { }
1428 }
1429 }
1430}
1431
/// HTTP server that accepts TCP connections and dispatches requests to a handler.
///
/// Holds shared counters (request IDs, open connections, metrics), a drain flag, the
/// handles of tasks spawned by the concurrent accept loop, and a shutdown controller.
#[derive(Debug)]
pub struct TcpServer {
    /// Configuration passed to `TcpServer::new`.
    config: ServerConfig,
    /// Source of request IDs; its value is also reported as `total_requests`.
    request_counter: Arc<AtomicU64>,
    /// Number of connections currently counted as open.
    connection_counter: Arc<AtomicU64>,
    /// Set by `start_drain`; accept loops stop taking new connections once true.
    draining: Arc<AtomicBool>,
    /// Handles of connection tasks spawned by `accept_loop_concurrent`.
    connection_handles: Mutex<Vec<TaskHandle<()>>>,
    /// Shutdown signal source; receivers come from `subscribe_shutdown`.
    shutdown_controller: Arc<ShutdownController>,
    /// Accepted / rejected / timed-out / byte counters surfaced via `metrics`.
    metrics_counters: Arc<MetricsCounters>,
}
1452
1453impl TcpServer {
1454 #[must_use]
1456 pub fn new(config: ServerConfig) -> Self {
1457 Self {
1458 config,
1459 request_counter: Arc::new(AtomicU64::new(0)),
1460 connection_counter: Arc::new(AtomicU64::new(0)),
1461 draining: Arc::new(AtomicBool::new(false)),
1462 connection_handles: Mutex::new(Vec::new()),
1463 shutdown_controller: Arc::new(ShutdownController::new()),
1464 metrics_counters: Arc::new(MetricsCounters::new()),
1465 }
1466 }
1467
1468 #[must_use]
1470 pub fn config(&self) -> &ServerConfig {
1471 &self.config
1472 }
1473
1474 fn next_request_id(&self) -> u64 {
1476 self.request_counter.fetch_add(1, Ordering::Relaxed)
1477 }
1478
1479 #[must_use]
1481 pub fn current_connections(&self) -> u64 {
1482 self.connection_counter.load(Ordering::Relaxed)
1483 }
1484
1485 #[must_use]
1487 pub fn metrics(&self) -> ServerMetrics {
1488 ServerMetrics {
1489 active_connections: self.connection_counter.load(Ordering::Relaxed),
1490 total_accepted: self.metrics_counters.total_accepted.load(Ordering::Relaxed),
1491 total_rejected: self.metrics_counters.total_rejected.load(Ordering::Relaxed),
1492 total_timed_out: self
1493 .metrics_counters
1494 .total_timed_out
1495 .load(Ordering::Relaxed),
1496 total_requests: self.request_counter.load(Ordering::Relaxed),
1497 bytes_in: self.metrics_counters.bytes_in.load(Ordering::Relaxed),
1498 bytes_out: self.metrics_counters.bytes_out.load(Ordering::Relaxed),
1499 }
1500 }
1501
1502 fn record_bytes_in(&self, n: u64) {
1504 self.metrics_counters
1505 .bytes_in
1506 .fetch_add(n, Ordering::Relaxed);
1507 }
1508
1509 fn record_bytes_out(&self, n: u64) {
1511 self.metrics_counters
1512 .bytes_out
1513 .fetch_add(n, Ordering::Relaxed);
1514 }
1515
1516 fn try_acquire_connection(&self) -> bool {
1521 let max = self.config.max_connections;
1522 if max == 0 {
1523 self.connection_counter.fetch_add(1, Ordering::Relaxed);
1525 self.metrics_counters
1526 .total_accepted
1527 .fetch_add(1, Ordering::Relaxed);
1528 return true;
1529 }
1530
1531 let mut current = self.connection_counter.load(Ordering::Relaxed);
1533 loop {
1534 if current >= max as u64 {
1535 self.metrics_counters
1536 .total_rejected
1537 .fetch_add(1, Ordering::Relaxed);
1538 return false;
1539 }
1540 match self.connection_counter.compare_exchange_weak(
1541 current,
1542 current + 1,
1543 Ordering::AcqRel,
1544 Ordering::Relaxed,
1545 ) {
1546 Ok(_) => {
1547 self.metrics_counters
1548 .total_accepted
1549 .fetch_add(1, Ordering::Relaxed);
1550 return true;
1551 }
1552 Err(actual) => current = actual,
1553 }
1554 }
1555 }
1556
1557 fn release_connection(&self) {
1559 self.connection_counter.fetch_sub(1, Ordering::Relaxed);
1560 }
1561
1562 #[must_use]
1564 pub fn is_draining(&self) -> bool {
1565 self.draining.load(Ordering::Acquire)
1566 }
1567
1568 pub fn start_drain(&self) {
1575 self.draining.store(true, Ordering::Release);
1576 }
1577
1578 pub async fn wait_for_drain(&self, timeout: Duration, poll_interval: Option<Duration>) -> bool {
1588 let start = Instant::now();
1589 let poll_interval = poll_interval.unwrap_or(Duration::from_millis(10));
1590
1591 while self.current_connections() > 0 {
1592 if start.elapsed() >= timeout {
1593 return false;
1594 }
1595 std::thread::sleep(poll_interval);
1605 }
1606 true
1607 }
1608
1609 pub async fn drain(&self) -> u64 {
1617 self.start_drain();
1618 let drained = self.wait_for_drain(self.config.drain_timeout, None).await;
1619 if drained {
1620 0
1621 } else {
1622 self.current_connections()
1623 }
1624 }
1625
1626 #[must_use]
1631 pub fn shutdown_controller(&self) -> &Arc<ShutdownController> {
1632 &self.shutdown_controller
1633 }
1634
1635 #[must_use]
1640 pub fn subscribe_shutdown(&self) -> ShutdownReceiver {
1641 self.shutdown_controller.subscribe()
1642 }
1643
1644 pub fn shutdown(&self) {
1653 self.start_drain();
1654 self.shutdown_controller.shutdown();
1655 }
1656
1657 #[must_use]
1659 pub fn is_shutting_down(&self) -> bool {
1660 self.shutdown_controller.is_shutting_down() || self.is_draining()
1661 }
1662
1663 pub async fn serve_with_shutdown<H, Fut>(
1693 &self,
1694 cx: &Cx,
1695 mut shutdown: ShutdownReceiver,
1696 handler: H,
1697 ) -> Result<GracefulOutcome<()>, ServerError>
1698 where
1699 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1700 Fut: Future<Output = Response> + Send + 'static,
1701 {
1702 let bind_addr = self.config.bind_addr.clone();
1703 let listener = TcpListener::bind(bind_addr).await?;
1704 let local_addr = listener.local_addr()?;
1705
1706 cx.trace(&format!(
1707 "Server listening on {local_addr} (with graceful shutdown)"
1708 ));
1709
1710 let result = self
1712 .accept_loop_with_shutdown(cx, listener, handler, &mut shutdown)
1713 .await;
1714
1715 match result {
1716 Ok(outcome) => {
1717 if outcome.is_shutdown() {
1718 cx.trace("Shutdown signal received, draining connections");
1719 self.start_drain();
1720 self.drain_connection_tasks(cx).await;
1721 }
1722 Ok(outcome)
1723 }
1724 Err(e) => Err(e),
1725 }
1726 }
1727
1728 async fn accept_loop_with_shutdown<H, Fut>(
1730 &self,
1731 cx: &Cx,
1732 listener: TcpListener,
1733 handler: H,
1734 shutdown: &mut ShutdownReceiver,
1735 ) -> Result<GracefulOutcome<()>, ServerError>
1736 where
1737 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1738 Fut: Future<Output = Response> + Send + 'static,
1739 {
1740 let handler = Arc::new(handler);
1741
1742 loop {
1743 if shutdown.is_shutting_down() {
1745 return Ok(GracefulOutcome::ShutdownSignaled);
1746 }
1747 if cx.is_cancel_requested() || self.is_draining() {
1748 return Ok(GracefulOutcome::ShutdownSignaled);
1749 }
1750
1751 let (mut stream, peer_addr) = match listener.accept().await {
1753 Ok(conn) => conn,
1754 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
1755 continue;
1756 }
1757 Err(e) => {
1758 cx.trace(&format!("Accept error: {e}"));
1759 if is_fatal_accept_error(&e) {
1760 return Err(ServerError::Io(e));
1761 }
1762 continue;
1763 }
1764 };
1765
1766 if !self.try_acquire_connection() {
1768 cx.trace(&format!(
1769 "Connection limit reached ({}), rejecting {peer_addr}",
1770 self.config.max_connections
1771 ));
1772
1773 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
1774 .header("connection", b"close".to_vec())
1775 .body(fastapi_core::ResponseBody::Bytes(
1776 b"503 Service Unavailable: connection limit reached".to_vec(),
1777 ));
1778 let mut writer = crate::response::ResponseWriter::new();
1779 let response_bytes = writer.write(response);
1780 let _ = write_response(&mut stream, response_bytes).await;
1781 continue;
1782 }
1783
1784 if self.config.tcp_nodelay {
1786 let _ = stream.set_nodelay(true);
1787 }
1788
1789 cx.trace(&format!(
1790 "Accepted connection from {peer_addr} ({}/{})",
1791 self.current_connections(),
1792 if self.config.max_connections == 0 {
1793 "∞".to_string()
1794 } else {
1795 self.config.max_connections.to_string()
1796 }
1797 ));
1798
1799 let request_id = self.next_request_id();
1800 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
1801 let request_cx = Cx::for_testing_with_budget(request_budget);
1802 let ctx = RequestContext::new(request_cx, request_id);
1803
1804 let result = self
1806 .handle_connection(&ctx, stream, peer_addr, &*handler)
1807 .await;
1808
1809 self.release_connection();
1810
1811 if let Err(e) = result {
1812 cx.trace(&format!("Connection error from {peer_addr}: {e}"));
1813 }
1814 }
1815 }
1816
1817 pub async fn serve<H, Fut>(&self, cx: &Cx, handler: H) -> Result<(), ServerError>
1831 where
1832 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1833 Fut: Future<Output = Response> + Send + 'static,
1834 {
1835 let bind_addr = self.config.bind_addr.clone();
1836 let listener = TcpListener::bind(bind_addr).await?;
1837 let local_addr = listener.local_addr()?;
1838
1839 cx.trace(&format!("Server listening on {local_addr}"));
1840
1841 self.accept_loop(cx, listener, handler).await
1842 }
1843
1844 pub async fn serve_on<H, Fut>(
1848 &self,
1849 cx: &Cx,
1850 listener: TcpListener,
1851 handler: H,
1852 ) -> Result<(), ServerError>
1853 where
1854 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1855 Fut: Future<Output = Response> + Send + 'static,
1856 {
1857 self.accept_loop(cx, listener, handler).await
1858 }
1859
1860 pub async fn serve_handler(
1881 &self,
1882 cx: &Cx,
1883 handler: Arc<dyn fastapi_core::Handler>,
1884 ) -> Result<(), ServerError> {
1885 let bind_addr = self.config.bind_addr.clone();
1886 let listener = TcpListener::bind(bind_addr).await?;
1887 let local_addr = listener.local_addr()?;
1888
1889 cx.trace(&format!("Server listening on {local_addr}"));
1890
1891 self.accept_loop_handler(cx, listener, handler).await
1892 }
1893
1894 pub async fn serve_app(&self, cx: &Cx, app: Arc<App>) -> Result<(), ServerError> {
1899 let bind_addr = self.config.bind_addr.clone();
1900 let listener = TcpListener::bind(bind_addr).await?;
1901 let local_addr = listener.local_addr()?;
1902
1903 cx.trace(&format!("Server listening on {local_addr}"));
1904 self.accept_loop_app(cx, listener, app).await
1905 }
1906
1907 pub async fn serve_on_handler(
1909 &self,
1910 cx: &Cx,
1911 listener: TcpListener,
1912 handler: Arc<dyn fastapi_core::Handler>,
1913 ) -> Result<(), ServerError> {
1914 self.accept_loop_handler(cx, listener, handler).await
1915 }
1916
1917 pub async fn serve_on_app(
1923 &self,
1924 cx: &Cx,
1925 listener: TcpListener,
1926 app: Arc<App>,
1927 ) -> Result<(), ServerError> {
1928 self.accept_loop_app(cx, listener, app).await
1929 }
1930
1931 async fn accept_loop_app(
1932 &self,
1933 cx: &Cx,
1934 listener: TcpListener,
1935 app: Arc<App>,
1936 ) -> Result<(), ServerError> {
1937 loop {
1938 if cx.is_cancel_requested() {
1939 cx.trace("Server shutdown requested");
1940 return Ok(());
1941 }
1942 if self.is_draining() {
1943 cx.trace("Server draining, stopping accept loop");
1944 return Err(ServerError::Shutdown);
1945 }
1946
1947 let (mut stream, peer_addr) = match listener.accept().await {
1948 Ok(conn) => conn,
1949 Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
1950 Err(e) => {
1951 cx.trace(&format!("Accept error: {e}"));
1952 if is_fatal_accept_error(&e) {
1953 return Err(ServerError::Io(e));
1954 }
1955 continue;
1956 }
1957 };
1958
1959 if !self.try_acquire_connection() {
1960 cx.trace(&format!(
1961 "Connection limit reached ({}), rejecting {peer_addr}",
1962 self.config.max_connections
1963 ));
1964
1965 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
1966 .header("connection", b"close".to_vec())
1967 .body(fastapi_core::ResponseBody::Bytes(
1968 b"503 Service Unavailable: connection limit reached".to_vec(),
1969 ));
1970 let mut writer = crate::response::ResponseWriter::new();
1971 let response_bytes = writer.write(response);
1972 let _ = write_response(&mut stream, response_bytes).await;
1973 continue;
1974 }
1975
1976 if self.config.tcp_nodelay {
1977 let _ = stream.set_nodelay(true);
1978 }
1979
1980 cx.trace(&format!(
1981 "Accepted connection from {peer_addr} ({}/{})",
1982 self.current_connections(),
1983 if self.config.max_connections == 0 {
1984 "∞".to_string()
1985 } else {
1986 self.config.max_connections.to_string()
1987 }
1988 ));
1989
1990 let result = self
1991 .handle_connection_app(cx, stream, peer_addr, app.as_ref())
1992 .await;
1993
1994 self.release_connection();
1995
1996 if let Err(e) = result {
1997 cx.trace(&format!("Connection error from {peer_addr}: {e}"));
1998 }
1999 }
2000 }
2001
2002 async fn accept_loop_handler(
2004 &self,
2005 cx: &Cx,
2006 listener: TcpListener,
2007 handler: Arc<dyn fastapi_core::Handler>,
2008 ) -> Result<(), ServerError> {
2009 loop {
2010 if cx.is_cancel_requested() {
2012 cx.trace("Server shutdown requested");
2013 return Ok(());
2014 }
2015
2016 if self.is_draining() {
2018 cx.trace("Server draining, stopping accept loop");
2019 return Err(ServerError::Shutdown);
2020 }
2021
2022 let (mut stream, peer_addr) = match listener.accept().await {
2024 Ok(conn) => conn,
2025 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
2026 continue;
2027 }
2028 Err(e) => {
2029 cx.trace(&format!("Accept error: {e}"));
2030 if is_fatal_accept_error(&e) {
2031 return Err(ServerError::Io(e));
2032 }
2033 continue;
2034 }
2035 };
2036
2037 if !self.try_acquire_connection() {
2039 cx.trace(&format!(
2040 "Connection limit reached ({}), rejecting {peer_addr}",
2041 self.config.max_connections
2042 ));
2043
2044 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
2045 .header("connection", b"close".to_vec())
2046 .body(fastapi_core::ResponseBody::Bytes(
2047 b"503 Service Unavailable: connection limit reached".to_vec(),
2048 ));
2049 let mut writer = crate::response::ResponseWriter::new();
2050 let response_bytes = writer.write(response);
2051 let _ = write_response(&mut stream, response_bytes).await;
2052 continue;
2053 }
2054
2055 if self.config.tcp_nodelay {
2057 let _ = stream.set_nodelay(true);
2058 }
2059
2060 cx.trace(&format!(
2061 "Accepted connection from {peer_addr} ({}/{})",
2062 self.current_connections(),
2063 if self.config.max_connections == 0 {
2064 "∞".to_string()
2065 } else {
2066 self.config.max_connections.to_string()
2067 }
2068 ));
2069
2070 let result = self
2072 .handle_connection_handler(cx, stream, peer_addr, &*handler)
2073 .await;
2074
2075 self.release_connection();
2076
2077 if let Err(e) = result {
2078 cx.trace(&format!("Connection error from {peer_addr}: {e}"));
2079 }
2080 }
2081 }
2082
2083 #[allow(clippy::too_many_lines)]
2096 pub async fn serve_concurrent<H, Fut>(
2097 &self,
2098 cx: &Cx,
2099 scope: &Scope<'_>,
2100 state: &mut RuntimeState,
2101 handler: H,
2102 ) -> Result<(), ServerError>
2103 where
2104 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
2105 Fut: Future<Output = Response> + Send + 'static,
2106 {
2107 let bind_addr = self.config.bind_addr.clone();
2108 let listener = TcpListener::bind(bind_addr).await?;
2109 let local_addr = listener.local_addr()?;
2110
2111 cx.trace(&format!(
2112 "Server listening on {local_addr} (concurrent mode)"
2113 ));
2114
2115 let handler = Arc::new(handler);
2116
2117 self.accept_loop_concurrent(cx, scope, state, listener, handler)
2118 .await
2119 }
2120
2121 async fn accept_loop_concurrent<H, Fut>(
2123 &self,
2124 cx: &Cx,
2125 scope: &Scope<'_>,
2126 state: &mut RuntimeState,
2127 listener: TcpListener,
2128 handler: Arc<H>,
2129 ) -> Result<(), ServerError>
2130 where
2131 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
2132 Fut: Future<Output = Response> + Send + 'static,
2133 {
2134 loop {
2135 if cx.is_cancel_requested() || self.is_draining() {
2137 cx.trace("Server shutting down, draining connections");
2138 self.drain_connection_tasks(cx).await;
2139 return Ok(());
2140 }
2141
2142 let (mut stream, peer_addr) = match listener.accept().await {
2144 Ok(conn) => conn,
2145 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
2146 continue;
2147 }
2148 Err(e) => {
2149 cx.trace(&format!("Accept error: {e}"));
2150 if is_fatal_accept_error(&e) {
2151 return Err(ServerError::Io(e));
2152 }
2153 continue;
2154 }
2155 };
2156
2157 if !self.try_acquire_connection() {
2159 cx.trace(&format!(
2160 "Connection limit reached ({}), rejecting {peer_addr}",
2161 self.config.max_connections
2162 ));
2163
2164 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
2165 .header("connection", b"close".to_vec())
2166 .body(fastapi_core::ResponseBody::Bytes(
2167 b"503 Service Unavailable: connection limit reached".to_vec(),
2168 ));
2169 let mut writer = crate::response::ResponseWriter::new();
2170 let response_bytes = writer.write(response);
2171 let _ = write_response(&mut stream, response_bytes).await;
2172 continue;
2173 }
2174
2175 if self.config.tcp_nodelay {
2177 let _ = stream.set_nodelay(true);
2178 }
2179
2180 cx.trace(&format!(
2181 "Accepted connection from {peer_addr} ({}/{})",
2182 self.current_connections(),
2183 if self.config.max_connections == 0 {
2184 "∞".to_string()
2185 } else {
2186 self.config.max_connections.to_string()
2187 }
2188 ));
2189
2190 match self.spawn_connection_task(
2192 scope,
2193 state,
2194 cx,
2195 stream,
2196 peer_addr,
2197 Arc::clone(&handler),
2198 ) {
2199 Ok(handle) => {
2200 if let Ok(mut handles) = self.connection_handles.lock() {
2202 handles.push(handle);
2203 }
2204 self.cleanup_completed_handles();
2206 }
2207 Err(e) => {
2208 cx.trace(&format!("Failed to spawn connection task: {e:?}"));
2209 self.release_connection();
2210 }
2211 }
2212 }
2213 }
2214
2215 fn spawn_connection_task<H, Fut>(
2217 &self,
2218 scope: &Scope<'_>,
2219 state: &mut RuntimeState,
2220 cx: &Cx,
2221 stream: TcpStream,
2222 peer_addr: SocketAddr,
2223 handler: Arc<H>,
2224 ) -> Result<TaskHandle<()>, SpawnError>
2225 where
2226 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
2227 Fut: Future<Output = Response> + Send + 'static,
2228 {
2229 let config = self.config.clone();
2230 let request_counter = Arc::clone(&self.request_counter);
2231 let connection_counter = Arc::clone(&self.connection_counter);
2232
2233 scope.spawn_registered(state, cx, move |task_cx| async move {
2234 let result = process_connection(
2235 &task_cx,
2236 &request_counter,
2237 stream,
2238 peer_addr,
2239 &config,
2240 |ctx, req| handler(ctx, req),
2241 )
2242 .await;
2243
2244 connection_counter.fetch_sub(1, Ordering::Relaxed);
2246
2247 if let Err(e) = result {
2248 eprintln!("Connection error from {peer_addr}: {e}");
2250 }
2251 })
2252 }
2253
2254 fn cleanup_completed_handles(&self) {
2256 if let Ok(mut handles) = self.connection_handles.lock() {
2257 handles.retain(|handle| !handle.is_finished());
2258 }
2259 }
2260
2261 async fn drain_connection_tasks(&self, cx: &Cx) {
2263 let drain_timeout = self.config.drain_timeout;
2264 let start = Instant::now();
2265
2266 cx.trace(&format!(
2267 "Draining {} connection tasks (timeout: {:?})",
2268 self.connection_handles.lock().map_or(0, |h| h.len()),
2269 drain_timeout
2270 ));
2271
2272 while start.elapsed() < drain_timeout {
2274 let remaining = self
2275 .connection_handles
2276 .lock()
2277 .map_or(0, |h| h.iter().filter(|t| !t.is_finished()).count());
2278
2279 if remaining == 0 {
2280 cx.trace("All connection tasks drained successfully");
2281 return;
2282 }
2283
2284 asupersync::runtime::yield_now().await;
2286 }
2287
2288 cx.trace(&format!(
2289 "Drain timeout reached with {} tasks still running",
2290 self.connection_handles
2291 .lock()
2292 .map_or(0, |h| h.iter().filter(|t| !t.is_finished()).count())
2293 ));
2294 }
2295
    /// Serve one accepted connection against `app`.
    ///
    /// The first bytes are sniffed (bounded by `keep_alive_timeout`). An HTTP/2
    /// prior-knowledge preface is handed to `handle_connection_app_http2`; anything
    /// else is parsed as HTTP/1.x with `StatefulParser`.
    ///
    /// Each HTTP/1.x request goes through, in order:
    /// 1. Host validation;
    /// 2. pre-body validators;
    /// 3. a WebSocket upgrade (this takes over the connection and returns);
    /// 4. `Expect` handling;
    /// 5. the app handler;
    /// 6. the response write;
    /// 7. any background tasks attached to the request.
    ///
    /// The connection loop ends on any of:
    /// - the peer closing the connection;
    /// - keep-alive not being granted;
    /// - `max_requests_per_connection` being reached;
    /// - cancellation of `cx`.
    ///
    /// # Errors
    ///
    /// Returns parse and I/O errors. Returns `ServerError::KeepAliveTimeout` (and bumps
    /// `total_timed_out`) when a read waits longer than `keep_alive_timeout`.
    /// Rejections are answered on the wire and then return `Ok(())`: bad Host, failed
    /// validators, an invalid WebSocket handshake, or an unknown `Expect`.
    async fn handle_connection_app(
        &self,
        cx: &Cx,
        mut stream: TcpStream,
        peer_addr: SocketAddr,
        app: &App,
    ) -> Result<(), ServerError> {
        // Peek at the first bytes to pick HTTP/1.x vs HTTP/2 prior knowledge; the
        // bytes consumed while sniffing are returned in `buffered`.
        let (proto, buffered) = sniff_protocol(&mut stream, self.config.keep_alive_timeout).await?;
        if !buffered.is_empty() {
            self.record_bytes_in(buffered.len() as u64);
        }

        if proto == SniffedProtocol::Http2PriorKnowledge {
            return self
                .handle_connection_app_http2(cx, stream, peer_addr, app)
                .await;
        }

        let mut parser = StatefulParser::new().with_limits(self.config.parse_limits.clone());
        if !buffered.is_empty() {
            // NOTE(review): the status returned here is discarded. If this call can
            // already yield `Complete` and hand the request out, that request would be
            // lost — confirm the parser retains it for the `feed(&[])` below.
            parser.feed(&buffered)?;
        }
        let mut read_buffer = vec![0u8; self.config.read_buffer_size];
        let mut response_writer = ResponseWriter::new();
        let mut requests_on_connection: usize = 0;
        let max_requests = self.config.max_requests_per_connection;

        loop {
            if cx.is_cancel_requested() {
                return Ok(());
            }

            // Feeding an empty slice first drains any request already buffered
            // (e.g. pipelined requests) before touching the socket again.
            let parse_result = parser.feed(&[])?;
            let mut request = match parse_result {
                ParseStatus::Complete { request, .. } => request,
                ParseStatus::Incomplete => {
                    let keep_alive_timeout = self.config.keep_alive_timeout;
                    // A zero keep-alive timeout means "read without a timeout".
                    let bytes_read = if keep_alive_timeout.is_zero() {
                        read_into_buffer(&mut stream, &mut read_buffer).await?
                    } else {
                        match read_with_timeout(&mut stream, &mut read_buffer, keep_alive_timeout)
                            .await
                        {
                            Ok(0) => return Ok(()),
                            Ok(n) => n,
                            Err(e) if e.kind() == io::ErrorKind::TimedOut => {
                                self.metrics_counters
                                    .total_timed_out
                                    .fetch_add(1, Ordering::Relaxed);
                                return Err(ServerError::KeepAliveTimeout);
                            }
                            Err(e) => return Err(ServerError::Io(e)),
                        }
                    };

                    // Zero bytes: the peer closed the connection.
                    if bytes_read == 0 {
                        return Ok(());
                    }

                    self.record_bytes_in(bytes_read as u64);

                    match parser.feed(&read_buffer[..bytes_read])? {
                        ParseStatus::Complete { request, .. } => request,
                        ParseStatus::Incomplete => continue,
                    }
                }
            };

            requests_on_connection += 1;

            let request_id = self.request_counter.fetch_add(1, Ordering::Relaxed);

            // NOTE(review): the configured timeout is passed to `with_deadline` as-is; if
            // the budget expects an absolute time, it should be offset from "now" — confirm.
            let request_budget = Budget::new().with_deadline(self.config.request_timeout);
            let request_cx = Cx::for_testing_with_budget(request_budget);
            let overrides = app.dependency_overrides();
            let ctx = RequestContext::with_overrides_and_body_limit(
                request_cx,
                request_id,
                overrides,
                app.config().max_body_size,
            );

            // NOTE(review): the rejection responses below are not added to `bytes_out`;
            // only the upgrade response and the final response are counted.
            if let Err(err) = validate_host_header(&request, &self.config) {
                ctx.trace(&format!(
                    "Rejecting request from {peer_addr}: {}",
                    err.detail
                ));
                let response = err.response().header("connection", b"close".to_vec());
                let response_write = response_writer.write(response);
                write_response(&mut stream, response_write).await?;
                return Ok(());
            }

            if let Err(response) = self.config.pre_body_validators.validate_all(&request) {
                let response = response.header("connection", b"close".to_vec());
                let response_write = response_writer.write(response);
                write_response(&mut stream, response_write).await?;
                return Ok(());
            }

            // WebSocket upgrade: only attempted when the app has a WebSocket route for
            // this path. On success the connection is handed over and never returns to
            // the HTTP loop.
            if is_websocket_upgrade_request(&request)
                && app.websocket_route_count() > 0
                && app.has_websocket_route(request.path())
            {
                if has_request_body_headers(&request) {
                    let response = Response::with_status(StatusCode::BAD_REQUEST)
                        .header("connection", b"close".to_vec())
                        .body(fastapi_core::ResponseBody::Bytes(
                            b"Bad Request: websocket handshake must not include a body".to_vec(),
                        ));
                    let response_write = response_writer.write(response);
                    write_response(&mut stream, response_write).await?;
                    return Ok(());
                }

                let Some(key) = header_str(&request, "sec-websocket-key") else {
                    let response = Response::with_status(StatusCode::BAD_REQUEST)
                        .header("connection", b"close".to_vec())
                        .body(fastapi_core::ResponseBody::Bytes(
                            b"Bad Request: missing Sec-WebSocket-Key".to_vec(),
                        ));
                    let response_write = response_writer.write(response);
                    write_response(&mut stream, response_write).await?;
                    return Ok(());
                };
                let accept = match fastapi_core::websocket_accept_from_key(key) {
                    Ok(v) => v,
                    Err(_) => {
                        let response = Response::with_status(StatusCode::BAD_REQUEST)
                            .header("connection", b"close".to_vec())
                            .body(fastapi_core::ResponseBody::Bytes(
                                b"Bad Request: invalid Sec-WebSocket-Key".to_vec(),
                            ));
                        let response_write = response_writer.write(response);
                        write_response(&mut stream, response_write).await?;
                        return Ok(());
                    }
                };

                if header_str(&request, "sec-websocket-version") != Some("13") {
                    let response = Response::with_status(StatusCode::BAD_REQUEST)
                        .header("sec-websocket-version", b"13".to_vec())
                        .header("connection", b"close".to_vec())
                        .body(fastapi_core::ResponseBody::Bytes(
                            b"Bad Request: unsupported Sec-WebSocket-Version".to_vec(),
                        ));
                    let response_write = response_writer.write(response);
                    write_response(&mut stream, response_write).await?;
                    return Ok(());
                }

                let response = Response::with_status(StatusCode::SWITCHING_PROTOCOLS)
                    .header("upgrade", b"websocket".to_vec())
                    .header("connection", b"Upgrade".to_vec())
                    .header("sec-websocket-accept", accept.into_bytes());
                let response_write = response_writer.write(response);
                if let ResponseWrite::Full(ref bytes) = response_write {
                    self.record_bytes_out(bytes.len() as u64);
                }
                write_response(&mut stream, response_write).await?;

                // Bytes the parser already read past the handshake belong to the
                // WebSocket stream, so hand them over with the socket.
                let buffered = parser.take_buffered();

                // The WebSocket session gets a fresh context built from a default
                // `Budget::new()`, not the per-request budget above.
                let ws_root_cx = Cx::for_testing_with_budget(Budget::new());
                let ws_ctx = RequestContext::with_overrides_and_body_limit(
                    ws_root_cx,
                    request_id,
                    app.dependency_overrides(),
                    app.config().max_body_size,
                );

                let ws = fastapi_core::WebSocket::new(stream, buffered);
                // The handler's result is ignored; the connection ends either way.
                let _ = app.handle_websocket(&ws_ctx, &mut request, ws).await;
                return Ok(());
            }

            match ExpectHandler::check_expect(&request) {
                ExpectResult::NoExpectation => {}
                ExpectResult::ExpectsContinue => {
                    ctx.trace("Sending 100 Continue for Expect: 100-continue");
                    write_raw_response(&mut stream, CONTINUE_RESPONSE).await?;
                }
                ExpectResult::UnknownExpectation(value) => {
                    ctx.trace(&format!("Rejecting unknown Expect value: {}", value));
                    let response = ExpectHandler::expectation_failed(format!(
                        "Unsupported Expect value: {value}"
                    ));
                    let response_write = response_writer.write(response);
                    write_response(&mut stream, response_write).await?;
                    return Ok(());
                }
            }

            // Keep-alive is granted only if the client asked for it and the
            // per-connection request budget (0 = unlimited) is not yet used up.
            let client_wants_keep_alive = should_keep_alive(&request);
            let server_will_keep_alive = client_wants_keep_alive
                && (max_requests == 0 || requests_on_connection < max_requests);

            let request_start = Instant::now();
            let timeout_duration = Duration::from_nanos(self.config.request_timeout.as_nanos());

            // The elapsed time is checked only after the handler returns. An overrun
            // replaces the handler's response with a 504 here.
            let response = app.handle(&ctx, &mut request).await;
            let mut response = if request_start.elapsed() > timeout_duration {
                Response::with_status(StatusCode::GATEWAY_TIMEOUT).body(
                    fastapi_core::ResponseBody::Bytes(
                        b"Gateway Timeout: request processing exceeded time limit".to_vec(),
                    ),
                )
            } else {
                response
            };

            response = if server_will_keep_alive {
                response.header("connection", b"keep-alive".to_vec())
            } else {
                response.header("connection", b"close".to_vec())
            };

            let response_write = response_writer.write(response);
            if let ResponseWrite::Full(ref bytes) = response_write {
                self.record_bytes_out(bytes.len() as u64);
            }
            write_response(&mut stream, response_write).await?;

            // Background tasks run after the response is written, on this connection's
            // task, before the next request is read.
            if let Some(tasks) = App::take_background_tasks(&mut request) {
                tasks.execute_all().await;
            }

            if !server_will_keep_alive {
                return Ok(());
            }
        }
    }
2539
2540 async fn handle_connection_app_http2(
2541 &self,
2542 cx: &Cx,
2543 stream: TcpStream,
2544 _peer_addr: SocketAddr,
2545 app: &App,
2546 ) -> Result<(), ServerError> {
2547 const FLAG_END_STREAM: u8 = 0x1;
2548 const FLAG_END_HEADERS: u8 = 0x4;
2549 const FLAG_ACK: u8 = 0x1;
2550
2551 let mut framed = http2::FramedH2::new(stream, Vec::new());
2552 let mut hpack = http2::HpackDecoder::new();
2553 let recv_max_frame_size: u32 = 16 * 1024; let mut peer_max_frame_size: u32 = 16 * 1024;
2555 let mut flow_control = http2::H2FlowControl::new();
2556
2557 let first = framed.read_frame(recv_max_frame_size).await?;
2558 self.record_bytes_in((http2::FrameHeader::LEN + first.payload.len()) as u64);
2559
2560 if first.header.frame_type() != http2::FrameType::Settings
2561 || first.header.stream_id != 0
2562 || (first.header.flags & FLAG_ACK) != 0
2563 {
2564 return Err(
2565 http2::Http2Error::Protocol("expected client SETTINGS after preface").into(),
2566 );
2567 }
2568
2569 apply_http2_settings_with_fc(
2570 &mut hpack,
2571 &mut peer_max_frame_size,
2572 Some(&mut flow_control),
2573 &first.payload,
2574 )?;
2575
2576 framed
2578 .write_frame(http2::FrameType::Settings, 0, 0, SERVER_SETTINGS_PAYLOAD)
2579 .await?;
2580 self.record_bytes_out(http2::FrameHeader::LEN as u64);
2581
2582 framed
2583 .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
2584 .await?;
2585 self.record_bytes_out(http2::FrameHeader::LEN as u64);
2586 let mut last_stream_id: u32 = 0;
2587
2588 loop {
2589 if cx.is_cancel_requested() {
2590 let _ = send_goaway(&mut framed, last_stream_id, h2_error_code::NO_ERROR).await;
2591 return Ok(());
2592 }
2593
2594 let frame = framed.read_frame(recv_max_frame_size).await?;
2595 self.record_bytes_in((http2::FrameHeader::LEN + frame.payload.len()) as u64);
2596
2597 match frame.header.frame_type() {
2598 http2::FrameType::Settings => {
2599 let is_ack = validate_settings_frame(
2600 frame.header.stream_id,
2601 frame.header.flags,
2602 &frame.payload,
2603 )?;
2604 if is_ack {
2605 continue;
2607 }
2608 apply_http2_settings_with_fc(
2609 &mut hpack,
2610 &mut peer_max_frame_size,
2611 Some(&mut flow_control),
2612 &frame.payload,
2613 )?;
2614 framed
2616 .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
2617 .await?;
2618 self.record_bytes_out(http2::FrameHeader::LEN as u64);
2619 }
2620 http2::FrameType::Ping => {
2621 if frame.header.stream_id != 0 || frame.payload.len() != 8 {
2623 return Err(http2::Http2Error::Protocol("invalid PING frame").into());
2624 }
2625 if (frame.header.flags & FLAG_ACK) == 0 {
2626 framed
2627 .write_frame(http2::FrameType::Ping, FLAG_ACK, 0, &frame.payload)
2628 .await?;
2629 self.record_bytes_out((http2::FrameHeader::LEN + 8) as u64);
2630 }
2631 }
2632 http2::FrameType::Goaway => {
2633 validate_goaway_payload(&frame.payload)?;
2634 return Ok(());
2635 }
2636 http2::FrameType::PushPromise => {
2637 return Err(http2::Http2Error::Protocol(
2638 "PUSH_PROMISE not supported by server",
2639 )
2640 .into());
2641 }
2642 http2::FrameType::Headers => {
2643 let stream_id = frame.header.stream_id;
2644 if stream_id == 0 {
2645 return Err(
2646 http2::Http2Error::Protocol("HEADERS must not be on stream 0").into(),
2647 );
2648 }
2649 if stream_id % 2 == 0 {
2650 return Err(http2::Http2Error::Protocol(
2651 "client-initiated stream ID must be odd",
2652 )
2653 .into());
2654 }
2655 if stream_id <= last_stream_id {
2656 return Err(http2::Http2Error::Protocol(
2657 "stream ID must be greater than previous",
2658 )
2659 .into());
2660 }
2661 last_stream_id = stream_id;
2662 let (end_stream, mut header_block) =
2663 extract_header_block_fragment(frame.header.flags, &frame.payload)?;
2664
2665 if (frame.header.flags & FLAG_END_HEADERS) == 0 {
2667 loop {
2668 let cont = framed.read_frame(recv_max_frame_size).await?;
2669 self.record_bytes_in(
2670 (http2::FrameHeader::LEN + cont.payload.len()) as u64,
2671 );
2672 if cont.header.frame_type() != http2::FrameType::Continuation
2673 || cont.header.stream_id != stream_id
2674 {
2675 return Err(http2::Http2Error::Protocol(
2676 "expected CONTINUATION for header block",
2677 )
2678 .into());
2679 }
2680 header_block.extend_from_slice(&cont.payload);
2681 if header_block.len() > MAX_HEADER_BLOCK_SIZE {
2682 return Err(http2::Http2Error::Protocol(
2683 "header block exceeds maximum size",
2684 )
2685 .into());
2686 }
2687 if (cont.header.flags & FLAG_END_HEADERS) != 0 {
2688 break;
2689 }
2690 }
2691 }
2692
2693 let headers = hpack
2694 .decode(&header_block)
2695 .map_err(http2::Http2Error::from)?;
2696 let mut request = request_from_h2_headers(headers)?;
2697 request.set_version(fastapi_core::HttpVersion::Http2);
2698
2699 if !end_stream {
2701 let max = app.config().max_body_size;
2702 let mut body = Vec::new();
2703 let mut stream_reset = false;
2704 let mut stream_received: u32 = 0;
2705 loop {
2706 let f = framed.read_frame(recv_max_frame_size).await?;
2707 self.record_bytes_in(
2708 (http2::FrameHeader::LEN + f.payload.len()) as u64,
2709 );
2710 match f.header.frame_type() {
2711 http2::FrameType::Data if f.header.stream_id == 0 => {
2712 return Err(http2::Http2Error::Protocol(
2713 "DATA must not be on stream 0",
2714 )
2715 .into());
2716 }
2717 http2::FrameType::Data if f.header.stream_id == stream_id => {
2718 let (data, data_end_stream) =
2719 extract_data_payload(f.header.flags, &f.payload)?;
2720 if body.len().saturating_add(data.len()) > max {
2721 return Err(http2::Http2Error::Protocol(
2722 "request body exceeds configured max_body_size",
2723 )
2724 .into());
2725 }
2726 body.extend_from_slice(data);
2727
2728 let data_len = u32::try_from(data.len()).unwrap_or(u32::MAX);
2731 stream_received += data_len;
2732 let conn_inc = flow_control.data_received_connection(data_len);
2733 let stream_inc =
2734 flow_control.stream_window_update(stream_received);
2735 if stream_inc > 0 {
2736 stream_received = 0;
2737 }
2738 send_window_updates(
2739 &mut framed,
2740 conn_inc,
2741 stream_id,
2742 stream_inc,
2743 )
2744 .await?;
2745
2746 if data_end_stream {
2747 break;
2748 }
2749 }
2750 http2::FrameType::RstStream => {
2751 validate_rst_stream_payload(f.header.stream_id, &f.payload)?;
2752 if f.header.stream_id == stream_id {
2753 stream_reset = true;
2754 break;
2755 }
2756 }
2757 http2::FrameType::PushPromise => {
2758 return Err(http2::Http2Error::Protocol(
2759 "PUSH_PROMISE not supported by server",
2760 )
2761 .into());
2762 }
2763 http2::FrameType::Settings
2764 | http2::FrameType::Ping
2765 | http2::FrameType::Goaway
2766 | http2::FrameType::WindowUpdate
2767 | http2::FrameType::Priority
2768 | http2::FrameType::Unknown => {
2769 if f.header.frame_type() == http2::FrameType::Goaway {
2770 validate_goaway_payload(&f.payload)?;
2771 return Ok(());
2772 }
2773 if f.header.frame_type() == http2::FrameType::Priority {
2774 validate_priority_payload(f.header.stream_id, &f.payload)?;
2775 }
2776 if f.header.frame_type() == http2::FrameType::WindowUpdate {
2777 validate_window_update_payload(&f.payload)?;
2778 let increment = u32::from_be_bytes([
2779 f.payload[0],
2780 f.payload[1],
2781 f.payload[2],
2782 f.payload[3],
2783 ]) & 0x7FFF_FFFF;
2784 if f.header.stream_id == 0 {
2785 apply_send_conn_window_update(
2786 &mut flow_control,
2787 increment,
2788 )?;
2789 }
2790 }
2791 if f.header.frame_type() == http2::FrameType::Ping {
2792 if f.header.stream_id != 0 || f.payload.len() != 8 {
2793 return Err(http2::Http2Error::Protocol(
2794 "invalid PING frame",
2795 )
2796 .into());
2797 }
2798 if (f.header.flags & FLAG_ACK) == 0 {
2799 framed
2800 .write_frame(
2801 http2::FrameType::Ping,
2802 FLAG_ACK,
2803 0,
2804 &f.payload,
2805 )
2806 .await?;
2807 self.record_bytes_out(
2808 (http2::FrameHeader::LEN + 8) as u64,
2809 );
2810 }
2811 }
2812 if f.header.frame_type() == http2::FrameType::Settings {
2813 let is_ack = validate_settings_frame(
2814 f.header.stream_id,
2815 f.header.flags,
2816 &f.payload,
2817 )?;
2818 if !is_ack {
2819 apply_http2_settings_with_fc(
2820 &mut hpack,
2821 &mut peer_max_frame_size,
2822 Some(&mut flow_control),
2823 &f.payload,
2824 )?;
2825 framed
2826 .write_frame(
2827 http2::FrameType::Settings,
2828 FLAG_ACK,
2829 0,
2830 &[],
2831 )
2832 .await?;
2833 self.record_bytes_out(http2::FrameHeader::LEN as u64);
2834 }
2835 }
2836 }
2837 _ => {
2838 return Err(http2::Http2Error::Protocol(
2839 "unsupported frame while reading request body",
2840 )
2841 .into());
2842 }
2843 }
2844 }
2845 if stream_reset {
2846 continue;
2847 }
2848 request.set_body(fastapi_core::Body::Bytes(body));
2849 }
2850
2851 let request_id = self.request_counter.fetch_add(1, Ordering::Relaxed);
2852 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
2853 let request_cx = Cx::for_testing_with_budget(request_budget);
2854 let overrides = app.dependency_overrides();
2855 let ctx = RequestContext::with_overrides_and_body_limit(
2856 request_cx,
2857 request_id,
2858 overrides,
2859 app.config().max_body_size,
2860 );
2861
2862 if let Err(err) = validate_host_header(&request, &self.config) {
2863 ctx.trace(&format!("Rejecting HTTP/2 request: {}", err.detail));
2864 let response = err.response();
2865 self.write_h2_response(
2866 &mut framed,
2867 response,
2868 stream_id,
2869 peer_max_frame_size,
2870 recv_max_frame_size,
2871 Some(&mut flow_control),
2872 )
2873 .await?;
2874 continue;
2875 }
2876
2877 if let Err(response) = self.config.pre_body_validators.validate_all(&request) {
2878 self.write_h2_response(
2879 &mut framed,
2880 response,
2881 stream_id,
2882 peer_max_frame_size,
2883 recv_max_frame_size,
2884 Some(&mut flow_control),
2885 )
2886 .await?;
2887 continue;
2888 }
2889
2890 let response = app.handle(&ctx, &mut request).await;
2891
2892 self.write_h2_response(
2894 &mut framed,
2895 response,
2896 stream_id,
2897 peer_max_frame_size,
2898 recv_max_frame_size,
2899 Some(&mut flow_control),
2900 )
2901 .await?;
2902
2903 if let Some(tasks) = App::take_background_tasks(&mut request) {
2904 tasks.execute_all().await;
2905 }
2906
2907 asupersync::runtime::yield_now().await;
2909 }
2910 http2::FrameType::WindowUpdate => {
2911 validate_window_update_payload(&frame.payload)?;
2912 let increment = u32::from_be_bytes([
2913 frame.payload[0],
2914 frame.payload[1],
2915 frame.payload[2],
2916 frame.payload[3],
2917 ]) & 0x7FFF_FFFF;
2918 if frame.header.stream_id == 0 {
2919 apply_send_conn_window_update(&mut flow_control, increment)?;
2920 }
2921 }
2922 _ => {
2923 handle_h2_idle_frame(&frame)?;
2924 }
2925 }
2926 }
2927 }
2928
2929 async fn write_h2_response(
2930 &self,
2931 framed: &mut http2::FramedH2,
2932 response: Response,
2933 stream_id: u32,
2934 mut peer_max_frame_size: u32,
2935 recv_max_frame_size: u32,
2936 mut flow_control: Option<&mut http2::H2FlowControl>,
2937 ) -> Result<(), ServerError> {
2938 use std::future::poll_fn;
2939
2940 const FLAG_END_STREAM: u8 = 0x1;
2941 const FLAG_END_HEADERS: u8 = 0x4;
2942
2943 let (status, mut headers, mut body) = response.into_parts();
2944 if !status.allows_body() {
2945 body = fastapi_core::ResponseBody::Empty;
2946 }
2947
2948 let mut add_content_length = matches!(body, fastapi_core::ResponseBody::Bytes(_));
2949 for (name, _) in &headers {
2950 if name.eq_ignore_ascii_case("content-length") {
2951 add_content_length = false;
2952 break;
2953 }
2954 }
2955
2956 if add_content_length {
2957 let len = body.len();
2958 headers.push(("content-length".to_string(), len.to_string().into_bytes()));
2959 }
2960
2961 let mut block: Vec<u8> = Vec::new();
2963 let status_bytes = status.as_u16().to_string().into_bytes();
2964 http2::hpack_encode_literal_without_indexing(&mut block, b":status", &status_bytes);
2965
2966 for (name, value) in &headers {
2967 if is_h2_forbidden_header_name(name) {
2968 continue;
2969 }
2970 let n = name.to_ascii_lowercase();
2971 http2::hpack_encode_literal_without_indexing(&mut block, n.as_bytes(), value);
2972 }
2973
2974 let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
2976 if block.len() <= max {
2977 let mut flags = FLAG_END_HEADERS;
2978 if body.is_empty() {
2979 flags |= FLAG_END_STREAM;
2980 }
2981 framed
2982 .write_frame(http2::FrameType::Headers, flags, stream_id, &block)
2983 .await?;
2984 self.record_bytes_out((http2::FrameHeader::LEN + block.len()) as u64);
2985 } else {
2986 let mut flags = 0u8;
2987 if body.is_empty() {
2988 flags |= FLAG_END_STREAM;
2989 }
2990 let (first, rest) = block.split_at(max);
2991 framed
2992 .write_frame(http2::FrameType::Headers, flags, stream_id, first)
2993 .await?;
2994 self.record_bytes_out((http2::FrameHeader::LEN + first.len()) as u64);
2995
2996 let mut remaining = rest;
2997 while remaining.len() > max {
2998 let (chunk, r) = remaining.split_at(max);
2999 framed
3000 .write_frame(http2::FrameType::Continuation, 0, stream_id, chunk)
3001 .await?;
3002 self.record_bytes_out((http2::FrameHeader::LEN + chunk.len()) as u64);
3003 remaining = r;
3004 }
3005 framed
3006 .write_frame(
3007 http2::FrameType::Continuation,
3008 FLAG_END_HEADERS,
3009 stream_id,
3010 remaining,
3011 )
3012 .await?;
3013 self.record_bytes_out((http2::FrameHeader::LEN + remaining.len()) as u64);
3014 }
3015
3016 let mut stream_send_window: i64 = flow_control
3018 .as_ref()
3019 .map_or(i64::MAX, |fc| i64::from(fc.peer_initial_window_size()));
3020
3021 match body {
3023 fastapi_core::ResponseBody::Empty => Ok(()),
3024 fastapi_core::ResponseBody::Bytes(bytes) => {
3025 if bytes.is_empty() {
3026 return Ok(());
3027 }
3028 let mut remaining = bytes.as_slice();
3029 while !remaining.is_empty() {
3030 let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
3031 let send_len = remaining.len().min(max);
3032 let send_len = h2_fc_clamp_send(
3033 framed,
3034 &mut flow_control,
3035 &mut stream_send_window,
3036 stream_id,
3037 send_len,
3038 &mut peer_max_frame_size,
3039 recv_max_frame_size,
3040 )
3041 .await?;
3042
3043 let (chunk, r) = remaining.split_at(send_len);
3044 let flags = if r.is_empty() { FLAG_END_STREAM } else { 0 };
3045 framed
3046 .write_frame(http2::FrameType::Data, flags, stream_id, chunk)
3047 .await?;
3048 self.record_bytes_out((http2::FrameHeader::LEN + chunk.len()) as u64);
3049 remaining = r;
3050 }
3051 Ok(())
3052 }
3053 fastapi_core::ResponseBody::Stream(mut s) => {
3054 loop {
3055 let next = poll_fn(|cx| Pin::new(&mut s).poll_next(cx)).await;
3056 match next {
3057 Some(chunk) => {
3058 let mut remaining = chunk.as_slice();
3059 while !remaining.is_empty() {
3060 let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
3061 let send_len = remaining.len().min(max);
3062 let send_len = h2_fc_clamp_send(
3063 framed,
3064 &mut flow_control,
3065 &mut stream_send_window,
3066 stream_id,
3067 send_len,
3068 &mut peer_max_frame_size,
3069 recv_max_frame_size,
3070 )
3071 .await?;
3072
3073 let (c, r) = remaining.split_at(send_len);
3074 framed
3075 .write_frame(http2::FrameType::Data, 0, stream_id, c)
3076 .await?;
3077 self.record_bytes_out((http2::FrameHeader::LEN + c.len()) as u64);
3078 remaining = r;
3079 }
3080 }
3081 None => {
3082 framed
3083 .write_frame(
3084 http2::FrameType::Data,
3085 FLAG_END_STREAM,
3086 stream_id,
3087 &[],
3088 )
3089 .await?;
3090 self.record_bytes_out(http2::FrameHeader::LEN as u64);
3091 break;
3092 }
3093 }
3094 }
3095 Ok(())
3096 }
3097 }
3098 }
3099
3100 async fn handle_connection_handler_http2(
3101 &self,
3102 cx: &Cx,
3103 stream: TcpStream,
3104 handler: &dyn fastapi_core::Handler,
3105 ) -> Result<(), ServerError> {
3106 const FLAG_END_HEADERS: u8 = 0x4;
3107 const FLAG_ACK: u8 = 0x1;
3108
3109 let mut framed = http2::FramedH2::new(stream, Vec::new());
3110 let mut hpack = http2::HpackDecoder::new();
3111 let recv_max_frame_size: u32 = 16 * 1024;
3112 let mut peer_max_frame_size: u32 = 16 * 1024;
3113 let mut flow_control = http2::H2FlowControl::new();
3114
3115 let first = framed.read_frame(recv_max_frame_size).await?;
3116 self.record_bytes_in((http2::FrameHeader::LEN + first.payload.len()) as u64);
3117
3118 if first.header.frame_type() != http2::FrameType::Settings
3119 || first.header.stream_id != 0
3120 || (first.header.flags & FLAG_ACK) != 0
3121 {
3122 return Err(
3123 http2::Http2Error::Protocol("expected client SETTINGS after preface").into(),
3124 );
3125 }
3126
3127 apply_http2_settings_with_fc(
3128 &mut hpack,
3129 &mut peer_max_frame_size,
3130 Some(&mut flow_control),
3131 &first.payload,
3132 )?;
3133
3134 framed
3135 .write_frame(http2::FrameType::Settings, 0, 0, SERVER_SETTINGS_PAYLOAD)
3136 .await?;
3137 self.record_bytes_out(http2::FrameHeader::LEN as u64);
3138
3139 framed
3140 .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
3141 .await?;
3142 self.record_bytes_out(http2::FrameHeader::LEN as u64);
3143
3144 let default_body_limit = self.config.parse_limits.max_request_size;
3145 let mut last_stream_id: u32 = 0;
3146
3147 loop {
3148 if cx.is_cancel_requested() {
3149 let _ = send_goaway(&mut framed, last_stream_id, h2_error_code::NO_ERROR).await;
3150 return Ok(());
3151 }
3152
3153 let frame = framed.read_frame(recv_max_frame_size).await?;
3154 self.record_bytes_in((http2::FrameHeader::LEN + frame.payload.len()) as u64);
3155
3156 match frame.header.frame_type() {
3157 http2::FrameType::Settings => {
3158 let is_ack = validate_settings_frame(
3159 frame.header.stream_id,
3160 frame.header.flags,
3161 &frame.payload,
3162 )?;
3163 if is_ack {
3164 continue;
3165 }
3166 apply_http2_settings_with_fc(
3167 &mut hpack,
3168 &mut peer_max_frame_size,
3169 Some(&mut flow_control),
3170 &frame.payload,
3171 )?;
3172 framed
3173 .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
3174 .await?;
3175 self.record_bytes_out(http2::FrameHeader::LEN as u64);
3176 }
3177 http2::FrameType::Ping => {
3178 if frame.header.stream_id != 0 || frame.payload.len() != 8 {
3179 return Err(http2::Http2Error::Protocol("invalid PING frame").into());
3180 }
3181 if (frame.header.flags & FLAG_ACK) == 0 {
3182 framed
3183 .write_frame(http2::FrameType::Ping, FLAG_ACK, 0, &frame.payload)
3184 .await?;
3185 self.record_bytes_out((http2::FrameHeader::LEN + 8) as u64);
3186 }
3187 }
3188 http2::FrameType::Goaway => {
3189 validate_goaway_payload(&frame.payload)?;
3190 return Ok(());
3191 }
3192 http2::FrameType::PushPromise => {
3193 return Err(http2::Http2Error::Protocol(
3194 "PUSH_PROMISE not supported by server",
3195 )
3196 .into());
3197 }
3198 http2::FrameType::Headers => {
3199 let stream_id = frame.header.stream_id;
3200 if stream_id == 0 {
3201 return Err(
3202 http2::Http2Error::Protocol("HEADERS must not be on stream 0").into(),
3203 );
3204 }
3205 if stream_id % 2 == 0 {
3206 return Err(http2::Http2Error::Protocol(
3207 "client-initiated stream ID must be odd",
3208 )
3209 .into());
3210 }
3211 if stream_id <= last_stream_id {
3212 return Err(http2::Http2Error::Protocol(
3213 "stream ID must be greater than previous",
3214 )
3215 .into());
3216 }
3217 last_stream_id = stream_id;
3218 let (end_stream, mut header_block) =
3219 extract_header_block_fragment(frame.header.flags, &frame.payload)?;
3220
3221 if (frame.header.flags & FLAG_END_HEADERS) == 0 {
3222 loop {
3223 let cont = framed.read_frame(recv_max_frame_size).await?;
3224 self.record_bytes_in(
3225 (http2::FrameHeader::LEN + cont.payload.len()) as u64,
3226 );
3227 if cont.header.frame_type() != http2::FrameType::Continuation
3228 || cont.header.stream_id != stream_id
3229 {
3230 return Err(http2::Http2Error::Protocol(
3231 "expected CONTINUATION for header block",
3232 )
3233 .into());
3234 }
3235 header_block.extend_from_slice(&cont.payload);
3236 if header_block.len() > MAX_HEADER_BLOCK_SIZE {
3237 return Err(http2::Http2Error::Protocol(
3238 "header block exceeds maximum size",
3239 )
3240 .into());
3241 }
3242 if (cont.header.flags & FLAG_END_HEADERS) != 0 {
3243 break;
3244 }
3245 }
3246 }
3247
3248 let headers = hpack
3249 .decode(&header_block)
3250 .map_err(http2::Http2Error::from)?;
3251 let mut request = request_from_h2_headers(headers)?;
3252
3253 if !end_stream {
3254 let mut body = Vec::new();
3255 let mut stream_reset = false;
3256 let mut stream_received: u32 = 0;
3257 loop {
3258 let f = framed.read_frame(recv_max_frame_size).await?;
3259 self.record_bytes_in(
3260 (http2::FrameHeader::LEN + f.payload.len()) as u64,
3261 );
3262 match f.header.frame_type() {
3263 http2::FrameType::Data if f.header.stream_id == 0 => {
3264 return Err(http2::Http2Error::Protocol(
3265 "DATA must not be on stream 0",
3266 )
3267 .into());
3268 }
3269 http2::FrameType::Data if f.header.stream_id == stream_id => {
3270 let (data, data_end_stream) =
3271 extract_data_payload(f.header.flags, &f.payload)?;
3272 if body.len().saturating_add(data.len()) > default_body_limit {
3273 return Err(http2::Http2Error::Protocol(
3274 "request body exceeds configured limit",
3275 )
3276 .into());
3277 }
3278 body.extend_from_slice(data);
3279
3280 let data_len = u32::try_from(data.len()).unwrap_or(u32::MAX);
3283 stream_received += data_len;
3284 let conn_inc = flow_control.data_received_connection(data_len);
3285 let stream_inc =
3286 flow_control.stream_window_update(stream_received);
3287 if stream_inc > 0 {
3288 stream_received = 0;
3289 }
3290 send_window_updates(
3291 &mut framed,
3292 conn_inc,
3293 stream_id,
3294 stream_inc,
3295 )
3296 .await?;
3297
3298 if data_end_stream {
3299 break;
3300 }
3301 }
3302 http2::FrameType::RstStream => {
3303 validate_rst_stream_payload(f.header.stream_id, &f.payload)?;
3304 if f.header.stream_id == stream_id {
3305 stream_reset = true;
3306 break;
3307 }
3308 }
3309 http2::FrameType::PushPromise => {
3310 return Err(http2::Http2Error::Protocol(
3311 "PUSH_PROMISE not supported by server",
3312 )
3313 .into());
3314 }
3315 http2::FrameType::Settings
3316 | http2::FrameType::Ping
3317 | http2::FrameType::Goaway
3318 | http2::FrameType::WindowUpdate
3319 | http2::FrameType::Priority
3320 | http2::FrameType::Unknown => {
3321 if f.header.frame_type() == http2::FrameType::Goaway {
3322 validate_goaway_payload(&f.payload)?;
3323 return Ok(());
3324 }
3325 if f.header.frame_type() == http2::FrameType::Priority {
3326 validate_priority_payload(f.header.stream_id, &f.payload)?;
3327 }
3328 if f.header.frame_type() == http2::FrameType::WindowUpdate {
3329 validate_window_update_payload(&f.payload)?;
3330 let increment = u32::from_be_bytes([
3331 f.payload[0],
3332 f.payload[1],
3333 f.payload[2],
3334 f.payload[3],
3335 ]) & 0x7FFF_FFFF;
3336 if f.header.stream_id == 0 {
3337 apply_send_conn_window_update(
3338 &mut flow_control,
3339 increment,
3340 )?;
3341 }
3342 }
3343 if f.header.frame_type() == http2::FrameType::Ping {
3344 if f.header.stream_id != 0 || f.payload.len() != 8 {
3345 return Err(http2::Http2Error::Protocol(
3346 "invalid PING frame",
3347 )
3348 .into());
3349 }
3350 if (f.header.flags & FLAG_ACK) == 0 {
3351 framed
3352 .write_frame(
3353 http2::FrameType::Ping,
3354 FLAG_ACK,
3355 0,
3356 &f.payload,
3357 )
3358 .await?;
3359 self.record_bytes_out(
3360 (http2::FrameHeader::LEN + 8) as u64,
3361 );
3362 }
3363 }
3364 if f.header.frame_type() == http2::FrameType::Settings {
3365 let is_ack = validate_settings_frame(
3366 f.header.stream_id,
3367 f.header.flags,
3368 &f.payload,
3369 )?;
3370 if !is_ack {
3371 apply_http2_settings_with_fc(
3372 &mut hpack,
3373 &mut peer_max_frame_size,
3374 Some(&mut flow_control),
3375 &f.payload,
3376 )?;
3377 framed
3378 .write_frame(
3379 http2::FrameType::Settings,
3380 FLAG_ACK,
3381 0,
3382 &[],
3383 )
3384 .await?;
3385 self.record_bytes_out(http2::FrameHeader::LEN as u64);
3386 }
3387 }
3388 }
3389 _ => {
3390 return Err(http2::Http2Error::Protocol(
3391 "unsupported frame while reading request body",
3392 )
3393 .into());
3394 }
3395 }
3396 }
3397 if stream_reset {
3398 continue;
3399 }
3400 request.set_body(fastapi_core::Body::Bytes(body));
3401 }
3402
3403 let request_id = self.request_counter.fetch_add(1, Ordering::Relaxed);
3404 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
3405 let request_cx = Cx::for_testing_with_budget(request_budget);
3406
3407 let overrides = handler
3408 .dependency_overrides()
3409 .unwrap_or_else(|| Arc::new(fastapi_core::DependencyOverrides::new()));
3410
3411 let ctx = RequestContext::with_overrides_and_body_limit(
3412 request_cx,
3413 request_id,
3414 overrides,
3415 default_body_limit,
3416 );
3417
3418 if let Err(err) = validate_host_header(&request, &self.config) {
3419 let response = err.response();
3420 self.write_h2_response(
3421 &mut framed,
3422 response,
3423 stream_id,
3424 peer_max_frame_size,
3425 recv_max_frame_size,
3426 Some(&mut flow_control),
3427 )
3428 .await?;
3429 continue;
3430 }
3431 if let Err(response) = self.config.pre_body_validators.validate_all(&request) {
3432 self.write_h2_response(
3433 &mut framed,
3434 response,
3435 stream_id,
3436 peer_max_frame_size,
3437 recv_max_frame_size,
3438 Some(&mut flow_control),
3439 )
3440 .await?;
3441 continue;
3442 }
3443
3444 let response = handler.call(&ctx, &mut request).await;
3445 self.write_h2_response(
3446 &mut framed,
3447 response,
3448 stream_id,
3449 peer_max_frame_size,
3450 recv_max_frame_size,
3451 Some(&mut flow_control),
3452 )
3453 .await?;
3454 }
3455 http2::FrameType::WindowUpdate => {
3456 validate_window_update_payload(&frame.payload)?;
3457 let increment = u32::from_be_bytes([
3458 frame.payload[0],
3459 frame.payload[1],
3460 frame.payload[2],
3461 frame.payload[3],
3462 ]) & 0x7FFF_FFFF;
3463 if frame.header.stream_id == 0 {
3464 apply_send_conn_window_update(&mut flow_control, increment)?;
3465 }
3466 }
3467 _ => {
3468 handle_h2_idle_frame(&frame)?;
3469 }
3470 }
3471 }
3472 }
3473
3474 async fn handle_connection_handler(
3479 &self,
3480 cx: &Cx,
3481 mut stream: TcpStream,
3482 _peer_addr: SocketAddr,
3483 handler: &dyn fastapi_core::Handler,
3484 ) -> Result<(), ServerError> {
3485 let (proto, buffered) = sniff_protocol(&mut stream, self.config.keep_alive_timeout).await?;
3486 if !buffered.is_empty() {
3487 self.record_bytes_in(buffered.len() as u64);
3488 }
3489 if proto == SniffedProtocol::Http2PriorKnowledge {
3490 return self
3491 .handle_connection_handler_http2(cx, stream, handler)
3492 .await;
3493 }
3494
3495 let mut parser = StatefulParser::new().with_limits(self.config.parse_limits.clone());
3496 if !buffered.is_empty() {
3497 parser.feed(&buffered)?;
3498 }
3499 let mut read_buffer = vec![0u8; self.config.read_buffer_size];
3500 let mut response_writer = ResponseWriter::new();
3501 let mut requests_on_connection: usize = 0;
3502 let max_requests = self.config.max_requests_per_connection;
3503
3504 loop {
3505 if cx.is_cancel_requested() {
3507 return Ok(());
3508 }
3509
3510 let parse_result = parser.feed(&[])?;
3512
3513 let mut request = match parse_result {
3514 ParseStatus::Complete { request, .. } => request,
3515 ParseStatus::Incomplete => {
3516 let keep_alive_timeout = self.config.keep_alive_timeout;
3517 let bytes_read = if keep_alive_timeout.is_zero() {
3518 read_into_buffer(&mut stream, &mut read_buffer).await?
3519 } else {
3520 match read_with_timeout(&mut stream, &mut read_buffer, keep_alive_timeout)
3521 .await
3522 {
3523 Ok(0) => return Ok(()),
3524 Ok(n) => n,
3525 Err(e) if e.kind() == io::ErrorKind::TimedOut => {
3526 self.metrics_counters
3527 .total_timed_out
3528 .fetch_add(1, Ordering::Relaxed);
3529 return Err(ServerError::KeepAliveTimeout);
3530 }
3531 Err(e) => return Err(ServerError::Io(e)),
3532 }
3533 };
3534
3535 if bytes_read == 0 {
3536 return Ok(());
3537 }
3538
3539 self.record_bytes_in(bytes_read as u64);
3540
3541 match parser.feed(&read_buffer[..bytes_read])? {
3542 ParseStatus::Complete { request, .. } => request,
3543 ParseStatus::Incomplete => continue,
3544 }
3545 }
3546 };
3547
3548 requests_on_connection += 1;
3549
3550 let request_id = self.request_counter.fetch_add(1, Ordering::Relaxed);
3552 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
3553 let request_cx = Cx::for_testing_with_budget(request_budget);
3554 let ctx = RequestContext::new(request_cx, request_id);
3555
3556 if let Err(err) = validate_host_header(&request, &self.config) {
3558 let response = err.response().header("connection", b"close".to_vec());
3559 let response_write = response_writer.write(response);
3560 write_response(&mut stream, response_write).await?;
3561 return Ok(());
3562 }
3563
3564 if let Err(response) = self.config.pre_body_validators.validate_all(&request) {
3566 let response = response.header("connection", b"close".to_vec());
3567 let response_write = response_writer.write(response);
3568 write_response(&mut stream, response_write).await?;
3569 return Ok(());
3570 }
3571
3572 match ExpectHandler::check_expect(&request) {
3574 ExpectResult::NoExpectation => {}
3575 ExpectResult::ExpectsContinue => {
3576 write_raw_response(&mut stream, CONTINUE_RESPONSE).await?;
3577 }
3578 ExpectResult::UnknownExpectation(_) => {
3579 let response =
3580 ExpectHandler::expectation_failed("Unsupported Expect value".to_string());
3581 let response_write = response_writer.write(response);
3582 write_response(&mut stream, response_write).await?;
3583 return Ok(());
3584 }
3585 }
3586
3587 let response = handler.call(&ctx, &mut request).await;
3589
3590 let client_wants_keep_alive = should_keep_alive(&request);
3592 let server_will_keep_alive = client_wants_keep_alive
3593 && (max_requests == 0 || requests_on_connection < max_requests);
3594
3595 let response = if server_will_keep_alive {
3596 response.header("connection", b"keep-alive".to_vec())
3597 } else {
3598 response.header("connection", b"close".to_vec())
3599 };
3600
3601 let response_write = response_writer.write(response);
3602 if let ResponseWrite::Full(ref bytes) = response_write {
3603 self.record_bytes_out(bytes.len() as u64);
3604 }
3605 write_response(&mut stream, response_write).await?;
3606
3607 if !server_will_keep_alive {
3608 return Ok(());
3609 }
3610 }
3611 }
3612
3613 async fn accept_loop<H, Fut>(
3615 &self,
3616 cx: &Cx,
3617 listener: TcpListener,
3618 handler: H,
3619 ) -> Result<(), ServerError>
3620 where
3621 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
3622 Fut: Future<Output = Response> + Send + 'static,
3623 {
3624 let handler = Arc::new(handler);
3625
3626 loop {
3627 if cx.is_cancel_requested() {
3629 cx.trace("Server shutdown requested");
3630 return Ok(());
3631 }
3632
3633 if self.is_draining() {
3635 cx.trace("Server draining, stopping accept loop");
3636 return Err(ServerError::Shutdown);
3637 }
3638
3639 let (mut stream, peer_addr) = match listener.accept().await {
3641 Ok(conn) => conn,
3642 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
3643 continue;
3645 }
3646 Err(e) => {
3647 cx.trace(&format!("Accept error: {e}"));
3648 if is_fatal_accept_error(&e) {
3651 return Err(ServerError::Io(e));
3652 }
3653 continue;
3654 }
3655 };
3656
3657 if !self.try_acquire_connection() {
3659 cx.trace(&format!(
3660 "Connection limit reached ({}), rejecting {peer_addr}",
3661 self.config.max_connections
3662 ));
3663
3664 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
3666 .header("connection", b"close".to_vec())
3667 .body(fastapi_core::ResponseBody::Bytes(
3668 b"503 Service Unavailable: connection limit reached".to_vec(),
3669 ));
3670 let mut writer = crate::response::ResponseWriter::new();
3671 let response_bytes = writer.write(response);
3672 let _ = write_response(&mut stream, response_bytes).await;
3673 continue;
3674 }
3675
3676 if self.config.tcp_nodelay {
3678 let _ = stream.set_nodelay(true);
3679 }
3680
3681 cx.trace(&format!(
3682 "Accepted connection from {peer_addr} ({}/{})",
3683 self.current_connections(),
3684 if self.config.max_connections == 0 {
3685 "∞".to_string()
3686 } else {
3687 self.config.max_connections.to_string()
3688 }
3689 ));
3690
3691 let request_id = self.next_request_id();
3696 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
3697
3698 let request_cx = Cx::for_testing_with_budget(request_budget);
3703 let ctx = RequestContext::new(request_cx, request_id);
3704
3705 let result = self
3707 .handle_connection(&ctx, stream, peer_addr, &*handler)
3708 .await;
3709
3710 self.release_connection();
3712
3713 if let Err(e) = result {
3714 cx.trace(&format!("Connection error from {peer_addr}: {e}"));
3715 }
3716 }
3717 }
3718
3719 async fn handle_connection<H, Fut>(
3725 &self,
3726 ctx: &RequestContext,
3727 stream: TcpStream,
3728 peer_addr: SocketAddr,
3729 handler: &H,
3730 ) -> Result<(), ServerError>
3731 where
3732 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync,
3733 Fut: Future<Output = Response> + Send,
3734 {
3735 process_connection(
3736 ctx.cx(),
3737 &self.request_counter,
3738 stream,
3739 peer_addr,
3740 &self.config,
3741 |ctx, req| handler(ctx, req),
3742 )
3743 .await
3744 }
3745}
3746
/// Point-in-time view of server metrics, with every value a plain `u64`.
///
/// NOTE(review): the code that builds this snapshot is outside this section.
/// Presumably it copies the `MetricsCounters` atomics and adds the live
/// connection count and the request counter; confirm against the snapshot
/// method.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerMetrics {
    /// Connections open when the snapshot was taken (presumably; confirm).
    pub active_connections: u64,
    /// Connections accepted so far (presumably from `MetricsCounters::total_accepted`).
    pub total_accepted: u64,
    /// Connections turned away (presumably from `MetricsCounters::total_rejected`).
    pub total_rejected: u64,
    /// Connections closed after a keep-alive read timed out (the HTTP/1
    /// handler increments `MetricsCounters::total_timed_out` in that case).
    pub total_timed_out: u64,
    /// Requests handled (presumably derived from the server's request counter).
    pub total_requests: u64,
    /// Bytes read from clients, as recorded through `record_bytes_in`
    /// (presumably; confirm).
    pub bytes_in: u64,
    /// Bytes written to clients, as recorded through `record_bytes_out`
    /// (presumably; confirm).
    pub bytes_out: u64,
}
3768
/// Lock-free counters behind the server's metrics.
///
/// Each value is an `AtomicU64`, so it can be bumped through a shared
/// reference without a lock. For example, the HTTP/1 handler increments
/// `total_timed_out` with `Ordering::Relaxed` when a keep-alive read times
/// out.
#[derive(Debug)]
struct MetricsCounters {
    /// Accepted-connection count (NOTE(review): not updated in this section).
    total_accepted: AtomicU64,
    /// Rejected-connection count (NOTE(review): not updated in this section).
    total_rejected: AtomicU64,
    /// Keep-alive read timeouts.
    total_timed_out: AtomicU64,
    /// Bytes received (presumably fed by `record_bytes_in`; confirm).
    bytes_in: AtomicU64,
    /// Bytes sent (presumably fed by `record_bytes_out`; confirm).
    bytes_out: AtomicU64,
}

impl MetricsCounters {
    /// Creates a counter set with every value at zero.
    fn new() -> Self {
        let zero = || AtomicU64::new(0);
        Self {
            total_accepted: zero(),
            total_rejected: zero(),
            total_timed_out: zero(),
            bytes_in: zero(),
            bytes_out: zero(),
        }
    }
}
3793
3794impl Default for TcpServer {
3795 fn default() -> Self {
3796 Self::new(ServerConfig::default())
3797 }
3798}
3799
/// Reports whether an `accept()` failure should stop the accept loop.
///
/// Only `NotConnected` and `InvalidInput` count as fatal. The accept loop
/// returns those errors to its caller and keeps accepting after any other
/// kind.
fn is_fatal_accept_error(e: &io::Error) -> bool {
    const FATAL_KINDS: [io::ErrorKind; 2] =
        [io::ErrorKind::NotConnected, io::ErrorKind::InvalidInput];
    FATAL_KINDS.contains(&e.kind())
}
3808
3809async fn read_into_buffer(stream: &mut TcpStream, buffer: &mut [u8]) -> io::Result<usize> {
3813 use std::future::poll_fn;
3814
3815 poll_fn(|cx| {
3816 let mut read_buf = ReadBuf::new(buffer);
3817 match Pin::new(&mut *stream).poll_read(cx, &mut read_buf) {
3818 Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
3819 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
3820 Poll::Pending => Poll::Pending,
3821 }
3822 })
3823 .await
3824}
3825
3826async fn read_with_timeout(
3844 stream: &mut TcpStream,
3845 buffer: &mut [u8],
3846 timeout_duration: Duration,
3847) -> io::Result<usize> {
3848 let now = current_time();
3850
3851 let read_future = Box::pin(read_into_buffer(stream, buffer));
3853
3854 match timeout(now, timeout_duration, read_future).await {
3856 Ok(result) => result,
3857 Err(_elapsed) => Err(io::Error::new(
3858 io::ErrorKind::TimedOut,
3859 "keep-alive timeout expired",
3860 )),
3861 }
3862}
3863
/// Result of `sniff_protocol`: which protocol the connection's first bytes indicate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SniffedProtocol {
    /// The bytes diverged from the HTTP/2 preface, or EOF arrived before the full preface.
    Http1,
    /// The complete HTTP/2 connection preface (`http2::PREFACE`) was received.
    Http2PriorKnowledge,
}
3869
3870async fn sniff_protocol(
3874 stream: &mut TcpStream,
3875 keep_alive_timeout: Duration,
3876) -> io::Result<(SniffedProtocol, Vec<u8>)> {
3877 let mut buffered: Vec<u8> = Vec::new();
3878 let preface = http2::PREFACE;
3879
3880 while buffered.len() < preface.len() {
3881 let mut tmp = vec![0u8; preface.len() - buffered.len()];
3882 let n = if keep_alive_timeout.is_zero() {
3883 read_into_buffer(stream, &mut tmp).await?
3884 } else {
3885 read_with_timeout(stream, &mut tmp, keep_alive_timeout).await?
3886 };
3887 if n == 0 {
3888 return Ok((SniffedProtocol::Http1, buffered));
3890 }
3891
3892 buffered.extend_from_slice(&tmp[..n]);
3893 if !preface.starts_with(&buffered) {
3894 return Ok((SniffedProtocol::Http1, buffered));
3895 }
3896 }
3897
3898 Ok((SniffedProtocol::Http2PriorKnowledge, buffered))
3899}
3900
3901fn apply_http2_settings(
3902 hpack: &mut http2::HpackDecoder,
3903 max_frame_size: &mut u32,
3904 payload: &[u8],
3905) -> Result<(), http2::Http2Error> {
3906 apply_http2_settings_with_fc(hpack, max_frame_size, None, payload)
3907}
3908
/// Validates and applies the entries of a (non-ACK) SETTINGS frame payload.
///
/// The payload is a sequence of 6-byte entries: a 16-bit identifier followed
/// by a 32-bit value, both big-endian. Entries are processed in order. If a
/// later entry is invalid, the earlier ones have already been applied when
/// the error is returned.
///
/// * `hpack` – decoder whose dynamic-table and header-list limits are updated.
/// * `max_frame_size` – overwritten by a valid SETTINGS_MAX_FRAME_SIZE.
/// * `flow_control` – when `Some`, receives SETTINGS_INITIAL_WINDOW_SIZE.
///
/// # Errors
/// Returns `Http2Error::Protocol` when:
/// - the payload length is not a multiple of 6;
/// - INITIAL_WINDOW_SIZE exceeds 2^31-1;
/// - MAX_FRAME_SIZE lies outside 16_384..=16_777_215;
/// - ENABLE_PUSH is not 0 or 1.
fn apply_http2_settings_with_fc(
    hpack: &mut http2::HpackDecoder,
    max_frame_size: &mut u32,
    mut flow_control: Option<&mut http2::H2FlowControl>,
    payload: &[u8],
) -> Result<(), http2::Http2Error> {
    if payload.len() % 6 != 0 {
        return Err(http2::Http2Error::Protocol(
            "SETTINGS length must be a multiple of 6",
        ));
    }

    for chunk in payload.chunks_exact(6) {
        let id = u16::from_be_bytes([chunk[0], chunk[1]]);
        let value = u32::from_be_bytes([chunk[2], chunk[3], chunk[4], chunk[5]]);
        match id {
            // SETTINGS_HEADER_TABLE_SIZE, capped at MAX_HPACK_TABLE_SIZE so a
            // peer cannot demand an arbitrarily large table.
            // NOTE(review): RFC 9113 §6.5.2 says this value limits *our encoder*
            // (the peer's decoder table); applying it to our decoder looks
            // direction-reversed — confirm intent.
            0x1 => {
                let capped = (value as usize).min(MAX_HPACK_TABLE_SIZE);
                hpack.set_dynamic_table_max_size(capped);
            }
            // SETTINGS_INITIAL_WINDOW_SIZE; values above 2^31-1 are rejected.
            // NOTE(review): the peer's value normally affects only our send
            // windows; it is also applied via `set_initial_window_size` here — confirm.
            0x3 => {
                if value > 0x7FFF_FFFF {
                    return Err(http2::Http2Error::Protocol(
                        "SETTINGS_INITIAL_WINDOW_SIZE exceeds maximum",
                    ));
                }
                if let Some(ref mut fc) = flow_control {
                    fc.set_initial_window_size(value);
                    fc.set_peer_initial_window_size(value);
                }
            }
            // SETTINGS_MAX_FRAME_SIZE: must lie in 2^14..=2^24-1.
            0x5 => {
                if !(16_384..=16_777_215).contains(&value) {
                    return Err(http2::Http2Error::Protocol(
                        "invalid SETTINGS_MAX_FRAME_SIZE",
                    ));
                }
                *max_frame_size = value;
            }
            // SETTINGS_ENABLE_PUSH: only validated, never stored.
            0x2 => {
                if value > 1 {
                    return Err(http2::Http2Error::Protocol(
                        "SETTINGS_ENABLE_PUSH must be 0 or 1",
                    ));
                }
            }
            // SETTINGS_MAX_CONCURRENT_STREAMS: accepted and ignored here.
            0x4 => {
            }
            // SETTINGS_MAX_HEADER_LIST_SIZE, passed to the decoder uncapped.
            // NOTE(review): unlike 0x1 this value is not bounded, and the RFC
            // defines it as a limit on what *we send* — confirm intent.
            0x6 => {
                hpack.set_max_header_list_size(value as usize);
            }
            // Unknown identifiers are ignored.
            _ => {
            }
        }
    }
    Ok(())
}
3980
3981fn validate_settings_frame(
3982 stream_id: u32,
3983 flags: u8,
3984 payload: &[u8],
3985) -> Result<bool, http2::Http2Error> {
3986 const FLAG_ACK: u8 = 0x1;
3987 if stream_id != 0 {
3988 return Err(http2::Http2Error::Protocol("SETTINGS must be on stream 0"));
3989 }
3990
3991 let is_ack = (flags & FLAG_ACK) != 0;
3992 if is_ack && !payload.is_empty() {
3993 return Err(http2::Http2Error::Protocol(
3994 "SETTINGS ACK frame must have empty payload",
3995 ));
3996 }
3997
3998 Ok(is_ack)
3999}
4000
4001fn validate_window_update_payload(payload: &[u8]) -> Result<(), http2::Http2Error> {
4002 if payload.len() != 4 {
4003 return Err(http2::Http2Error::Protocol(
4004 "WINDOW_UPDATE payload must be 4 bytes",
4005 ));
4006 }
4007
4008 let raw = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
4009 let increment = raw & 0x7FFF_FFFF;
4010 if increment == 0 {
4011 return Err(http2::Http2Error::Protocol(
4012 "WINDOW_UPDATE increment must be non-zero",
4013 ));
4014 }
4015
4016 Ok(())
4017}
4018
4019fn handle_h2_idle_frame(frame: &http2::Frame) -> Result<(), http2::Http2Error> {
4020 match frame.header.frame_type() {
4021 http2::FrameType::RstStream => {
4022 validate_rst_stream_payload(frame.header.stream_id, &frame.payload)
4023 }
4024 http2::FrameType::Priority => {
4025 validate_priority_payload(frame.header.stream_id, &frame.payload)
4026 }
4027 http2::FrameType::Data => Err(http2::Http2Error::Protocol(
4028 "unexpected DATA frame outside active request stream",
4029 )),
4030 http2::FrameType::Continuation => Err(http2::Http2Error::Protocol(
4031 "unexpected CONTINUATION frame outside header block",
4032 )),
4033 http2::FrameType::Unknown => Ok(()),
4034 _ => Ok(()),
4035 }
4036}
4037
/// Largest permitted HTTP/2 flow-control window: 2^31-1 (RFC 9113 §6.9.1).
///
/// The WINDOW_UPDATE and SETTINGS_INITIAL_WINDOW_SIZE handling in this file
/// rejects any change that would push a send window above this value.
const MAX_FLOW_CONTROL_WINDOW: i64 = 0x7FFF_FFFF;

/// SETTINGS payload the server advertises: a single 6-byte entry.
const SERVER_SETTINGS_PAYLOAD: &[u8] = &[
    // Identifier 0x0003 = SETTINGS_MAX_CONCURRENT_STREAMS, value 1
    // (the peer may open at most one concurrent stream).
    0x00, 0x03, 0x00, 0x00, 0x00, 0x01, ];

/// Cap (64 KiB) applied to a peer's SETTINGS_HEADER_TABLE_SIZE before it is
/// handed to the HPACK decoder, so a peer cannot request an unbounded table.
const MAX_HPACK_TABLE_SIZE: usize = 64 * 1024;

/// Maximum header block size: 128 KiB.
/// NOTE(review): enforced outside this chunk — presumably bounds the bytes
/// accumulated across HEADERS + CONTINUATION frames; confirm.
const MAX_HEADER_BLOCK_SIZE: usize = 128 * 1024;
4061
4062fn apply_send_conn_window_update(
4065 fc: &mut http2::H2FlowControl,
4066 increment: u32,
4067) -> Result<(), http2::Http2Error> {
4068 let new_window = fc.send_conn_window() + i64::from(increment);
4069 if new_window > MAX_FLOW_CONTROL_WINDOW {
4070 return Err(http2::Http2Error::Protocol(
4071 "WINDOW_UPDATE causes flow-control window to exceed 2^31-1",
4072 ));
4073 }
4074 fc.peer_window_update_connection(increment);
4075 Ok(())
4076}
4077
4078fn apply_peer_window_update_for_send(
4079 flow_control: &mut http2::H2FlowControl,
4080 stream_send_window: &mut i64,
4081 current_stream_id: u32,
4082 frame_stream_id: u32,
4083 payload: &[u8],
4084) -> Result<(), http2::Http2Error> {
4085 validate_window_update_payload(payload)?;
4086
4087 let increment =
4088 u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]) & 0x7FFF_FFFF;
4089 if frame_stream_id == 0 {
4090 apply_send_conn_window_update(flow_control, increment)?;
4091 } else if frame_stream_id == current_stream_id {
4092 let new_window = *stream_send_window + i64::from(increment);
4093 if new_window > MAX_FLOW_CONTROL_WINDOW {
4094 return Err(http2::Http2Error::Protocol(
4095 "WINDOW_UPDATE causes flow-control window to exceed 2^31-1",
4096 ));
4097 }
4098 *stream_send_window = new_window;
4099 }
4100
4101 Ok(())
4102}
4103
4104fn apply_peer_settings_for_send(
4105 flow_control: &mut http2::H2FlowControl,
4106 stream_send_window: &mut i64,
4107 peer_max_frame_size: &mut u32,
4108 payload: &[u8],
4109) -> Result<(), http2::Http2Error> {
4110 if payload.len() % 6 != 0 {
4111 return Err(http2::Http2Error::Protocol(
4112 "SETTINGS length must be a multiple of 6",
4113 ));
4114 }
4115
4116 for chunk in payload.chunks_exact(6) {
4117 let id = u16::from_be_bytes([chunk[0], chunk[1]]);
4118 let value = u32::from_be_bytes([chunk[2], chunk[3], chunk[4], chunk[5]]);
4119
4120 if id == 0x3 {
4121 if value > 0x7FFF_FFFF {
4123 return Err(http2::Http2Error::Protocol(
4124 "SETTINGS_INITIAL_WINDOW_SIZE exceeds maximum",
4125 ));
4126 }
4127 let old = i64::from(flow_control.peer_initial_window_size());
4128 let new = i64::from(value);
4129 let delta = new - old;
4130 let updated = *stream_send_window + delta;
4131 if updated > MAX_FLOW_CONTROL_WINDOW {
4132 return Err(http2::Http2Error::Protocol(
4133 "SETTINGS_INITIAL_WINDOW_SIZE change causes stream window to exceed 2^31-1",
4134 ));
4135 }
4136 flow_control.set_peer_initial_window_size(value);
4137 *stream_send_window = updated;
4138 } else if id == 0x5 {
4139 if !(16_384..=16_777_215).contains(&value) {
4141 return Err(http2::Http2Error::Protocol(
4142 "invalid SETTINGS_MAX_FRAME_SIZE",
4143 ));
4144 }
4145 *peer_max_frame_size = value;
4146 }
4147 }
4148
4149 Ok(())
4150}
4151
/// Encodes a WINDOW_UPDATE payload: the increment in big-endian, with the
/// reserved high bit cleared.
fn window_update_payload(increment: u32) -> [u8; 4] {
    const RESERVED_BIT: u32 = 1 << 31;
    u32::to_be_bytes(increment & !RESERVED_BIT)
}
4156
4157async fn send_window_updates(
4160 framed: &mut http2::FramedH2,
4161 conn_increment: u32,
4162 stream_id: u32,
4163 stream_increment: u32,
4164) -> Result<(), http2::Http2Error> {
4165 if conn_increment > 0 {
4166 let payload = window_update_payload(conn_increment);
4167 framed
4168 .write_frame(http2::FrameType::WindowUpdate, 0, 0, &payload)
4169 .await?;
4170 }
4171 if stream_increment > 0 {
4172 let payload = window_update_payload(stream_increment);
4173 framed
4174 .write_frame(http2::FrameType::WindowUpdate, 0, stream_id, &payload)
4175 .await?;
4176 }
4177 Ok(())
4178}
4179
/// HTTP/2 error codes (RFC 9113 §7) as carried in RST_STREAM and GOAWAY frames.
///
/// Only a subset of the IANA registry is defined here.
// `dead_code` is allowed presumably because not every code is referenced yet.
#[allow(dead_code)]
mod h2_error_code {
    /// Graceful shutdown; no error (0x0).
    pub const NO_ERROR: u32 = 0x0;
    /// Unspecific protocol violation (0x1).
    pub const PROTOCOL_ERROR: u32 = 0x1;
    /// Flow-control protocol violation (0x3).
    pub const FLOW_CONTROL_ERROR: u32 = 0x3;
    /// SETTINGS was not acknowledged in time (0x4).
    pub const SETTINGS_TIMEOUT: u32 = 0x4;
    /// Frame received on a half-closed stream (0x5).
    pub const STREAM_CLOSED: u32 = 0x5;
    /// Frame with an invalid size (0x6).
    pub const FRAME_SIZE_ERROR: u32 = 0x6;
    /// Stream refused before any application processing (0x7).
    pub const REFUSED_STREAM: u32 = 0x7;
    /// Stream no longer needed (0x8).
    pub const CANCEL: u32 = 0x8;
    /// Peer behaviour may be generating excessive load (0xb).
    pub const ENHANCE_YOUR_CALM: u32 = 0xb;
}
4193
4194fn validate_goaway_payload(payload: &[u8]) -> Result<(), http2::Http2Error> {
4196 if payload.len() < 8 {
4197 return Err(http2::Http2Error::Protocol(
4198 "GOAWAY payload must be at least 8 bytes",
4199 ));
4200 }
4201 Ok(())
4202}
4203
/// Builds the 8-byte GOAWAY payload with no debug data appended.
///
/// The payload is the last stream ID (reserved high bit cleared) followed by
/// the error code, both big-endian.
fn goaway_payload(last_stream_id: u32, error_code: u32) -> [u8; 8] {
    let id = (last_stream_id & 0x7FFF_FFFF).to_be_bytes();
    let code = error_code.to_be_bytes();
    [id[0], id[1], id[2], id[3], code[0], code[1], code[2], code[3]]
}
4211
4212async fn send_goaway(
4214 framed: &mut http2::FramedH2,
4215 last_stream_id: u32,
4216 error_code: u32,
4217) -> Result<(), http2::Http2Error> {
4218 let payload = goaway_payload(last_stream_id, error_code);
4219 framed
4220 .write_frame(http2::FrameType::Goaway, 0, 0, &payload)
4221 .await
4222}
4223
4224fn validate_rst_stream_payload(stream_id: u32, payload: &[u8]) -> Result<(), http2::Http2Error> {
4225 if stream_id == 0 {
4226 return Err(http2::Http2Error::Protocol(
4227 "RST_STREAM must not be on stream 0",
4228 ));
4229 }
4230 if payload.len() != 4 {
4231 return Err(http2::Http2Error::Protocol(
4232 "RST_STREAM payload must be 4 bytes",
4233 ));
4234 }
4235 Ok(())
4236}
4237
4238fn validate_priority_payload(stream_id: u32, payload: &[u8]) -> Result<(), http2::Http2Error> {
4239 if stream_id == 0 {
4240 return Err(http2::Http2Error::Protocol(
4241 "PRIORITY must not be on stream 0",
4242 ));
4243 }
4244
4245 if payload.len() != 5 {
4246 return Err(http2::Http2Error::Protocol(
4247 "PRIORITY payload must be 5 bytes",
4248 ));
4249 }
4250
4251 let dependency_raw = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
4252 let dependency_stream_id = dependency_raw & 0x7FFF_FFFF;
4253 if dependency_stream_id == stream_id {
4254 return Err(http2::Http2Error::Protocol(
4255 "PRIORITY stream dependency must not reference itself",
4256 ));
4257 }
4258
4259 Ok(())
4260}
4261
4262fn extract_header_block_fragment(
4263 flags: u8,
4264 payload: &[u8],
4265) -> Result<(bool, Vec<u8>), http2::Http2Error> {
4266 const FLAG_END_STREAM: u8 = 0x1;
4267 const FLAG_PADDED: u8 = 0x8;
4268 const FLAG_PRIORITY: u8 = 0x20;
4269
4270 let end_stream = (flags & FLAG_END_STREAM) != 0;
4271 let mut idx = 0usize;
4272
4273 let pad_len = if (flags & FLAG_PADDED) != 0 {
4274 if payload.is_empty() {
4275 return Err(http2::Http2Error::Protocol(
4276 "HEADERS PADDED set with empty payload",
4277 ));
4278 }
4279 let v = payload[0] as usize;
4280 idx += 1;
4281 v
4282 } else {
4283 0
4284 };
4285
4286 if (flags & FLAG_PRIORITY) != 0 {
4287 if payload.len().saturating_sub(idx) < 5 {
4289 return Err(http2::Http2Error::Protocol(
4290 "HEADERS PRIORITY set but too short",
4291 ));
4292 }
4293 idx += 5;
4294 }
4295
4296 if payload.len() < idx {
4297 return Err(http2::Http2Error::Protocol("invalid HEADERS payload"));
4298 }
4299 let frag = &payload[idx..];
4300 if frag.len() < pad_len {
4301 return Err(http2::Http2Error::Protocol(
4302 "invalid HEADERS padding length",
4303 ));
4304 }
4305 let end = frag.len() - pad_len;
4306 Ok((end_stream, frag[..end].to_vec()))
4307}
4308
4309fn extract_data_payload(flags: u8, payload: &[u8]) -> Result<(&[u8], bool), http2::Http2Error> {
4310 const FLAG_END_STREAM: u8 = 0x1;
4311 const FLAG_PADDED: u8 = 0x8;
4312
4313 let end_stream = (flags & FLAG_END_STREAM) != 0;
4314 if (flags & FLAG_PADDED) == 0 {
4315 return Ok((payload, end_stream));
4316 }
4317 if payload.is_empty() {
4318 return Err(http2::Http2Error::Protocol(
4319 "DATA PADDED set with empty payload",
4320 ));
4321 }
4322 let pad_len = payload[0] as usize;
4323 let data = &payload[1..];
4324 if data.len() < pad_len {
4325 return Err(http2::Http2Error::Protocol("invalid DATA padding length"));
4326 }
4327 Ok((&data[..data.len() - pad_len], end_stream))
4328}
4329
4330fn request_from_h2_headers(headers: http2::HeaderList) -> Result<Request, http2::Http2Error> {
4331 let mut method: Option<fastapi_core::Method> = None;
4332 let mut path: Option<String> = None;
4333 let mut authority: Option<Vec<u8>> = None;
4334 let mut saw_regular_headers = false;
4335
4336 let mut req_headers: Vec<(String, Vec<u8>)> = Vec::new();
4337
4338 for (name, value) in headers {
4339 if name.starts_with(b":") {
4340 if saw_regular_headers {
4341 return Err(http2::Http2Error::Protocol(
4342 "pseudo-headers must appear before regular headers",
4343 ));
4344 }
4345 match name.as_slice() {
4346 b":method" => {
4347 if method.is_some() {
4348 return Err(http2::Http2Error::Protocol(
4349 "duplicate :method pseudo-header",
4350 ));
4351 }
4352 method = Some(
4353 fastapi_core::Method::from_bytes(&value)
4354 .ok_or(http2::Http2Error::Protocol("invalid :method"))?,
4355 );
4356 }
4357 b":path" => {
4358 if path.is_some() {
4359 return Err(http2::Http2Error::Protocol("duplicate :path pseudo-header"));
4360 }
4361 let s = std::str::from_utf8(&value)
4362 .map_err(|_| http2::Http2Error::Protocol("non-utf8 :path"))?;
4363 path = Some(s.to_string());
4364 }
4365 b":authority" => {
4366 if authority.is_some() {
4367 return Err(http2::Http2Error::Protocol(
4368 "duplicate :authority pseudo-header",
4369 ));
4370 }
4371 authority = Some(value);
4372 }
4373 b":scheme" => {}
4374 _ => return Err(http2::Http2Error::Protocol("unknown pseudo-header")),
4375 }
4376 continue;
4377 }
4378
4379 saw_regular_headers = true;
4380 let n = std::str::from_utf8(&name)
4381 .map_err(|_| http2::Http2Error::Protocol("non-utf8 header name"))?;
4382 req_headers.push((n.to_string(), value));
4383 }
4384
4385 let method = method.ok_or(http2::Http2Error::Protocol("missing :method"))?;
4386 let raw_path = path.ok_or(http2::Http2Error::Protocol("missing :path"))?;
4387 let (path_only, query) = match raw_path.split_once('?') {
4388 Some((p, q)) => (p.to_string(), Some(q.to_string())),
4389 None => (raw_path, None),
4390 };
4391
4392 let mut req = Request::with_version(method, path_only, fastapi_core::HttpVersion::Http2);
4393 req.set_query(query);
4394
4395 if let Some(auth) = authority {
4396 req.headers_mut().insert("host", auth);
4397 }
4398
4399 for (n, v) in req_headers {
4400 req.headers_mut().insert(n, v);
4401 }
4402
4403 Ok(req)
4404}
4405
/// Reports whether `name` is one of the connection-specific header fields.
///
/// The fields are `connection`, `keep-alive`, `proxy-connection`,
/// `transfer-encoding`, `upgrade` and `te`. The comparison is ASCII
/// case-insensitive.
fn is_h2_forbidden_header_name(name: &str) -> bool {
    const CONNECTION_SPECIFIC: [&str; 6] = [
        "connection",
        "keep-alive",
        "proxy-connection",
        "transfer-encoding",
        "upgrade",
        "te",
    ];
    CONNECTION_SPECIFIC
        .iter()
        .any(|forbidden| name.eq_ignore_ascii_case(forbidden))
}
4416
4417async fn write_raw_response(stream: &mut TcpStream, bytes: &[u8]) -> io::Result<()> {
4421 use std::future::poll_fn;
4422 write_all(stream, bytes).await?;
4423 poll_fn(|cx| Pin::new(&mut *stream).poll_flush(cx)).await?;
4424 Ok(())
4425}
4426
4427async fn write_response(stream: &mut TcpStream, response: ResponseWrite) -> io::Result<()> {
4431 use std::future::poll_fn;
4432
4433 match response {
4434 ResponseWrite::Full(bytes) => {
4435 write_all(stream, &bytes).await?;
4436 }
4437 ResponseWrite::Stream(mut encoder) => {
4438 loop {
4440 let chunk = poll_fn(|cx| Pin::new(&mut encoder).poll_next(cx)).await;
4441 match chunk {
4442 Some(bytes) => {
4443 write_all(stream, &bytes).await?;
4444 }
4445 None => break,
4446 }
4447 }
4448 }
4449 }
4450
4451 poll_fn(|cx| Pin::new(&mut *stream).poll_flush(cx)).await?;
4453
4454 Ok(())
4455}
4456
4457async fn write_all(stream: &mut TcpStream, mut buf: &[u8]) -> io::Result<()> {
4459 use std::future::poll_fn;
4460
4461 while !buf.is_empty() {
4462 let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, buf)).await?;
4463 if n == 0 {
4464 return Err(io::Error::new(
4465 io::ErrorKind::WriteZero,
4466 "failed to write whole buffer",
4467 ));
4468 }
4469 buf = &buf[n..];
4470 }
4471 Ok(())
4472}
4473
4474pub struct Server {
4486 parser: Parser,
4487}
4488
4489impl Server {
4490 #[must_use]
4492 pub fn new() -> Self {
4493 Self {
4494 parser: Parser::new(),
4495 }
4496 }
4497
4498 pub fn parse_request(&self, bytes: &[u8]) -> Result<Request, ParseError> {
4504 self.parser.parse(bytes)
4505 }
4506
4507 #[must_use]
4509 pub fn write_response(&self, response: Response) -> ResponseWrite {
4510 let mut writer = ResponseWriter::new();
4511 writer.write(response)
4512 }
4513}
4514
4515impl Default for Server {
4516 fn default() -> Self {
4517 Self::new()
4518 }
4519}
4520
4521#[cfg(test)]
4522mod tests {
4523 use super::*;
4524 use std::future::Future;
4525
4526 fn block_on<F: Future>(f: F) -> F::Output {
4527 let rt = asupersync::runtime::RuntimeBuilder::current_thread()
4528 .build()
4529 .expect("test runtime must build");
4530 rt.block_on(f)
4531 }
4532
4533 #[test]
4534 fn server_config_builder() {
4535 let config = ServerConfig::new("0.0.0.0:3000")
4536 .with_request_timeout_secs(60)
4537 .with_max_connections(1000)
4538 .with_tcp_nodelay(false)
4539 .with_allowed_hosts(["example.com", "api.example.com"])
4540 .with_trust_x_forwarded_host(true);
4541
4542 assert_eq!(config.bind_addr, "0.0.0.0:3000");
4543 assert_eq!(config.request_timeout, Time::from_secs(60));
4544 assert_eq!(config.max_connections, 1000);
4545 assert!(!config.tcp_nodelay);
4546 assert_eq!(config.allowed_hosts.len(), 2);
4547 assert!(config.trust_x_forwarded_host);
4548 }
4549
4550 #[test]
4551 fn server_config_defaults() {
4552 let config = ServerConfig::default();
4553 assert_eq!(config.bind_addr, "127.0.0.1:8080");
4554 assert_eq!(
4555 config.request_timeout,
4556 Time::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)
4557 );
4558 assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS);
4559 assert!(config.tcp_nodelay);
4560 assert!(config.allowed_hosts.is_empty());
4561 assert!(!config.trust_x_forwarded_host);
4562 }
4563
4564 #[test]
4565 fn tcp_server_creates_request_ids() {
4566 let server = TcpServer::default();
4567 let id1 = server.next_request_id();
4568 let id2 = server.next_request_id();
4569 let id3 = server.next_request_id();
4570
4571 assert_eq!(id1, 0);
4572 assert_eq!(id2, 1);
4573 assert_eq!(id3, 2);
4574 }
4575
4576 #[test]
4577 fn server_error_display() {
4578 let io_err = ServerError::Io(io::Error::new(io::ErrorKind::AddrInUse, "address in use"));
4579 assert!(io_err.to_string().contains("IO error"));
4580
4581 let shutdown_err = ServerError::Shutdown;
4582 assert_eq!(shutdown_err.to_string(), "Server shutdown");
4583
4584 let limit_err = ServerError::ConnectionLimitReached;
4585 assert_eq!(limit_err.to_string(), "Connection limit reached");
4586 }
4587
4588 #[test]
4589 fn sync_server_parses_request() {
4590 let server = Server::new();
4591 let request = b"GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n";
4592 let result = server.parse_request(request);
4593 assert!(result.is_ok());
4594 }
4595
4596 #[test]
4597 fn window_update_payload_validation_accepts_non_zero_increment() {
4598 let payload = 1u32.to_be_bytes();
4599 assert!(validate_window_update_payload(&payload).is_ok());
4600 }
4601
4602 #[test]
4603 fn window_update_payload_validation_rejects_bad_length() {
4604 let err = validate_window_update_payload(&[0, 0, 0]).unwrap_err();
4605 assert!(
4606 err.to_string()
4607 .contains("WINDOW_UPDATE payload must be 4 bytes")
4608 );
4609 }
4610
4611 #[test]
4612 fn window_update_payload_validation_rejects_zero_increment() {
4613 let payload = 0u32.to_be_bytes();
4614 let err = validate_window_update_payload(&payload).unwrap_err();
4615 assert!(
4616 err.to_string()
4617 .contains("WINDOW_UPDATE increment must be non-zero")
4618 );
4619 }
4620
4621 #[test]
4622 fn settings_frame_validation_accepts_non_ack_payload() {
4623 let payload = [0u8; 6];
4624 let is_ack = validate_settings_frame(0, 0, &payload).unwrap();
4625 assert!(!is_ack);
4626 }
4627
4628 #[test]
4629 fn settings_frame_validation_accepts_empty_ack_payload() {
4630 let is_ack = validate_settings_frame(0, 0x1, &[]).unwrap();
4631 assert!(is_ack);
4632 }
4633
4634 #[test]
4635 fn settings_frame_validation_rejects_non_zero_stream() {
4636 let err = validate_settings_frame(1, 0, &[]).unwrap_err();
4637 assert!(err.to_string().contains("SETTINGS must be on stream 0"));
4638 }
4639
4640 #[test]
4641 fn settings_frame_validation_rejects_non_empty_ack_payload() {
4642 let err = validate_settings_frame(0, 0x1, &[0, 0, 0, 0, 0, 0]).unwrap_err();
4643 assert!(
4644 err.to_string()
4645 .contains("SETTINGS ACK frame must have empty payload")
4646 );
4647 }
4648
4649 #[test]
4650 fn settings_enable_push_accepts_zero() {
4651 let payload = [0x00, 0x02, 0x00, 0x00, 0x00, 0x00];
4653 let mut hpack = http2::HpackDecoder::new();
4654 let mut max_frame_size = 16384u32;
4655 assert!(apply_http2_settings(&mut hpack, &mut max_frame_size, &payload).is_ok());
4656 }
4657
4658 #[test]
4659 fn settings_enable_push_accepts_one() {
4660 let payload = [0x00, 0x02, 0x00, 0x00, 0x00, 0x01];
4661 let mut hpack = http2::HpackDecoder::new();
4662 let mut max_frame_size = 16384u32;
4663 assert!(apply_http2_settings(&mut hpack, &mut max_frame_size, &payload).is_ok());
4664 }
4665
4666 #[test]
4667 fn settings_enable_push_rejects_invalid_value() {
4668 let payload = [0x00, 0x02, 0x00, 0x00, 0x00, 0x02];
4669 let mut hpack = http2::HpackDecoder::new();
4670 let mut max_frame_size = 16384u32;
4671 let err = apply_http2_settings(&mut hpack, &mut max_frame_size, &payload).unwrap_err();
4672 assert!(
4673 err.to_string()
4674 .contains("SETTINGS_ENABLE_PUSH must be 0 or 1")
4675 );
4676 }
4677
4678 #[test]
4679 fn rst_stream_payload_validation_accepts_valid_payload() {
4680 let payload = 8u32.to_be_bytes();
4681 assert!(validate_rst_stream_payload(1, &payload).is_ok());
4682 }
4683
4684 #[test]
4685 fn rst_stream_payload_validation_rejects_stream_zero() {
4686 let payload = 8u32.to_be_bytes();
4687 let err = validate_rst_stream_payload(0, &payload).unwrap_err();
4688 assert!(
4689 err.to_string()
4690 .contains("RST_STREAM must not be on stream 0")
4691 );
4692 }
4693
4694 #[test]
4695 fn rst_stream_payload_validation_rejects_bad_length() {
4696 let err = validate_rst_stream_payload(1, &[0, 0, 0]).unwrap_err();
4697 assert!(
4698 err.to_string()
4699 .contains("RST_STREAM payload must be 4 bytes")
4700 );
4701 }
4702
4703 #[test]
4704 fn priority_payload_validation_accepts_valid_priority() {
4705 let payload = [0, 0, 0, 0, 16];
4706 assert!(validate_priority_payload(1, &payload).is_ok());
4707 }
4708
4709 #[test]
4710 fn priority_payload_validation_rejects_stream_zero() {
4711 let payload = [0, 0, 0, 0, 16];
4712 let err = validate_priority_payload(0, &payload).unwrap_err();
4713 assert!(err.to_string().contains("PRIORITY must not be on stream 0"));
4714 }
4715
4716 #[test]
4717 fn priority_payload_validation_rejects_bad_length() {
4718 let err = validate_priority_payload(1, &[0, 0, 0, 0]).unwrap_err();
4719 assert!(err.to_string().contains("PRIORITY payload must be 5 bytes"));
4720 }
4721
4722 #[test]
4723 fn priority_payload_validation_rejects_self_dependency() {
4724 let payload = 1u32.to_be_bytes();
4725 let mut with_weight = [0u8; 5];
4726 with_weight[..4].copy_from_slice(&payload);
4727 with_weight[4] = 16;
4728 let err = validate_priority_payload(1, &with_weight).unwrap_err();
4729 assert!(
4730 err.to_string()
4731 .contains("PRIORITY stream dependency must not reference itself")
4732 );
4733 }
4734
4735 #[test]
4736 fn goaway_payload_validation_accepts_valid_payload() {
4737 let payload = goaway_payload(0, 0);
4738 assert!(validate_goaway_payload(&payload).is_ok());
4739 }
4740
4741 #[test]
4742 fn goaway_payload_validation_accepts_payload_with_debug_data() {
4743 let mut payload = Vec::from(goaway_payload(1, 0).as_slice());
4744 payload.extend_from_slice(b"debug info");
4745 assert!(validate_goaway_payload(&payload).is_ok());
4746 }
4747
4748 #[test]
4749 fn goaway_payload_validation_rejects_short_payload() {
4750 let err = validate_goaway_payload(&[0, 0, 0]).unwrap_err();
4751 assert!(
4752 err.to_string()
4753 .contains("GOAWAY payload must be at least 8 bytes")
4754 );
4755 }
4756
4757 #[test]
4758 fn goaway_payload_validation_rejects_empty() {
4759 let err = validate_goaway_payload(&[]).unwrap_err();
4760 assert!(
4761 err.to_string()
4762 .contains("GOAWAY payload must be at least 8 bytes")
4763 );
4764 }
4765
4766 fn h2_test_frame(
4767 frame_type: http2::FrameType,
4768 stream_id: u32,
4769 payload: Vec<u8>,
4770 ) -> http2::Frame {
4771 http2::Frame {
4772 header: http2::FrameHeader {
4773 length: payload.len() as u32,
4774 frame_type: frame_type as u8,
4775 flags: 0,
4776 stream_id,
4777 },
4778 payload,
4779 }
4780 }
4781
4782 #[test]
4783 fn h2_idle_frame_rejects_data_outside_request_stream() {
4784 let frame = h2_test_frame(http2::FrameType::Data, 1, Vec::new());
4785 let err = handle_h2_idle_frame(&frame).unwrap_err();
4786 assert!(
4787 err.to_string()
4788 .contains("unexpected DATA frame outside active request stream")
4789 );
4790 }
4791
4792 #[test]
4793 fn h2_idle_frame_rejects_continuation_outside_header_block() {
4794 let frame = h2_test_frame(http2::FrameType::Continuation, 1, Vec::new());
4795 let err = handle_h2_idle_frame(&frame).unwrap_err();
4796 assert!(
4797 err.to_string()
4798 .contains("unexpected CONTINUATION frame outside header block")
4799 );
4800 }
4801
4802 #[test]
4803 fn h2_idle_frame_validates_rst_stream_payload() {
4804 let invalid = h2_test_frame(http2::FrameType::RstStream, 0, 8u32.to_be_bytes().to_vec());
4805 let err = handle_h2_idle_frame(&invalid).unwrap_err();
4806 assert!(
4807 err.to_string()
4808 .contains("RST_STREAM must not be on stream 0")
4809 );
4810
4811 let valid = h2_test_frame(http2::FrameType::RstStream, 3, 8u32.to_be_bytes().to_vec());
4812 assert!(handle_h2_idle_frame(&valid).is_ok());
4813 }
4814
4815 #[test]
4816 fn h2_idle_frame_validates_priority_payload() {
4817 let invalid = h2_test_frame(http2::FrameType::Priority, 0, vec![0, 0, 0, 0, 16]);
4818 let err = handle_h2_idle_frame(&invalid).unwrap_err();
4819 assert!(err.to_string().contains("PRIORITY must not be on stream 0"));
4820
4821 let valid = h2_test_frame(http2::FrameType::Priority, 1, vec![0, 0, 0, 0, 16]);
4822 assert!(handle_h2_idle_frame(&valid).is_ok());
4823 }
4824
4825 #[test]
4826 fn max_header_block_size_is_128k() {
4827 assert_eq!(MAX_HEADER_BLOCK_SIZE, 128 * 1024);
4828 }
4829
4830 #[test]
4831 fn server_settings_payload_advertises_max_concurrent_streams() {
4832 assert_eq!(SERVER_SETTINGS_PAYLOAD.len(), 6);
4834 assert_eq!(SERVER_SETTINGS_PAYLOAD[0..2], [0x00, 0x03]);
4835 assert_eq!(
4836 u32::from_be_bytes([
4837 SERVER_SETTINGS_PAYLOAD[2],
4838 SERVER_SETTINGS_PAYLOAD[3],
4839 SERVER_SETTINGS_PAYLOAD[4],
4840 SERVER_SETTINGS_PAYLOAD[5],
4841 ]),
4842 1
4843 );
4844 }
4845
4846 #[test]
4847 fn max_hpack_table_size_is_64k() {
4848 assert_eq!(MAX_HPACK_TABLE_SIZE, 64 * 1024);
4849 }
4850
4851 #[test]
4852 fn h2_send_window_update_ignores_other_streams() {
4853 let mut flow_control = http2::H2FlowControl::new();
4854 let mut stream_window = 123i64;
4855 let payload = 7u32.to_be_bytes();
4856
4857 apply_peer_window_update_for_send(&mut flow_control, &mut stream_window, 3, 5, &payload)
4858 .expect("window update on different stream should be ignored");
4859
4860 assert_eq!(stream_window, 123);
4861 }
4862
4863 #[test]
4864 fn h2_send_window_update_applies_connection_and_current_stream() {
4865 let mut flow_control = http2::H2FlowControl::new();
4866 let mut stream_window = 10i64;
4867
4868 let conn_before = flow_control.send_conn_window();
4869 let conn_payload = 11u32.to_be_bytes();
4870 apply_peer_window_update_for_send(
4871 &mut flow_control,
4872 &mut stream_window,
4873 9,
4874 0,
4875 &conn_payload,
4876 )
4877 .expect("connection window update should be applied");
4878 assert_eq!(flow_control.send_conn_window(), conn_before + 11);
4879 assert_eq!(stream_window, 10);
4880
4881 let stream_payload = 13u32.to_be_bytes();
4882 apply_peer_window_update_for_send(
4883 &mut flow_control,
4884 &mut stream_window,
4885 9,
4886 9,
4887 &stream_payload,
4888 )
4889 .expect("stream window update should be applied to current stream");
4890 assert_eq!(stream_window, 23);
4891 }
4892
4893 #[test]
4894 fn h2_send_settings_updates_current_stream_window_delta() {
4895 let mut flow_control = http2::H2FlowControl::new();
4896 let mut stream_window = 50i64;
4897 let mut peer_max_frame_size = 16_384u32;
4898
4899 let payload = [0x00, 0x03, 0x00, 0x01, 0x11, 0x70]; apply_peer_settings_for_send(
4901 &mut flow_control,
4902 &mut stream_window,
4903 &mut peer_max_frame_size,
4904 &payload,
4905 )
4906 .expect("valid SETTINGS_INITIAL_WINDOW_SIZE should apply");
4907
4908 assert_eq!(flow_control.peer_initial_window_size(), 70_000);
4909 assert_eq!(stream_window, 4_515); assert_eq!(peer_max_frame_size, 16_384);
4911 }
4912
4913 #[test]
4914 fn h2_send_settings_rejects_invalid_payload_len() {
4915 let mut flow_control = http2::H2FlowControl::new();
4916 let mut stream_window = 0i64;
4917 let mut peer_max_frame_size = 16_384u32;
4918 let err = apply_peer_settings_for_send(
4919 &mut flow_control,
4920 &mut stream_window,
4921 &mut peer_max_frame_size,
4922 &[0, 1, 2],
4923 )
4924 .unwrap_err();
4925 assert!(
4926 err.to_string()
4927 .contains("SETTINGS length must be a multiple of 6")
4928 );
4929 }
4930
4931 #[test]
4932 fn h2_send_settings_rejects_initial_window_too_large() {
4933 let mut flow_control = http2::H2FlowControl::new();
4934 let mut stream_window = 0i64;
4935 let mut peer_max_frame_size = 16_384u32;
4936 let payload = [0x00, 0x03, 0x80, 0x00, 0x00, 0x00]; let err = apply_peer_settings_for_send(
4938 &mut flow_control,
4939 &mut stream_window,
4940 &mut peer_max_frame_size,
4941 &payload,
4942 )
4943 .unwrap_err();
4944 assert!(
4945 err.to_string()
4946 .contains("SETTINGS_INITIAL_WINDOW_SIZE exceeds maximum")
4947 );
4948 }
4949
4950 #[test]
4951 fn h2_send_settings_window_delta_overflow_is_flow_control_error() {
4952 let mut flow_control = http2::H2FlowControl::new();
4953 let mut peer_max_frame_size = 16_384u32;
4954 let mut stream_window: i64 = 0x7FFF_FFFF - 10;
4956 let new_initial: u32 = 0x7FFF_FFFF;
4960 let payload = [
4961 0x00,
4962 0x03,
4963 new_initial.to_be_bytes()[0],
4964 new_initial.to_be_bytes()[1],
4965 new_initial.to_be_bytes()[2],
4966 new_initial.to_be_bytes()[3],
4967 ];
4968 let err = apply_peer_settings_for_send(
4969 &mut flow_control,
4970 &mut stream_window,
4971 &mut peer_max_frame_size,
4972 &payload,
4973 )
4974 .unwrap_err();
4975 assert!(err.to_string().contains("stream window to exceed 2^31-1"));
4976 }
4977
4978 #[test]
4979 fn h2_send_settings_updates_peer_max_frame_size() {
4980 let mut flow_control = http2::H2FlowControl::new();
4981 let mut stream_window = 0i64;
4982 let mut peer_max_frame_size = 65_535u32;
4983 let payload = [0x00, 0x05, 0x00, 0x00, 0x40, 0x00]; apply_peer_settings_for_send(
4986 &mut flow_control,
4987 &mut stream_window,
4988 &mut peer_max_frame_size,
4989 &payload,
4990 )
4991 .expect("valid SETTINGS_MAX_FRAME_SIZE should apply");
4992
4993 assert_eq!(peer_max_frame_size, 16_384);
4994 }
4995
4996 #[test]
4997 fn h2_send_settings_rejects_invalid_max_frame_size() {
4998 let mut flow_control = http2::H2FlowControl::new();
4999 let mut stream_window = 0i64;
5000 let mut peer_max_frame_size = 16_384u32;
5001 let payload = [0x00, 0x05, 0x00, 0x00, 0x3F, 0xFF]; let err = apply_peer_settings_for_send(
5004 &mut flow_control,
5005 &mut stream_window,
5006 &mut peer_max_frame_size,
5007 &payload,
5008 )
5009 .unwrap_err();
5010 assert!(err.to_string().contains("invalid SETTINGS_MAX_FRAME_SIZE"));
5011 }
5012
5013 #[test]
5014 fn request_from_h2_headers_rejects_unknown_pseudo_header() {
5015 let headers: http2::HeaderList = vec![
5016 (b":method".to_vec(), b"GET".to_vec()),
5017 (b":path".to_vec(), b"/".to_vec()),
5018 (b":weird".to_vec(), b"value".to_vec()),
5019 ];
5020 let err = request_from_h2_headers(headers).unwrap_err();
5021 assert!(err.to_string().contains("unknown pseudo-header"));
5022 }
5023
5024 #[test]
5025 fn request_from_h2_headers_rejects_pseudo_after_regular_header() {
5026 let headers: http2::HeaderList = vec![
5027 (b":method".to_vec(), b"GET".to_vec()),
5028 (b":path".to_vec(), b"/".to_vec()),
5029 (b"x-test".to_vec(), b"ok".to_vec()),
5030 (b":authority".to_vec(), b"example.com".to_vec()),
5031 ];
5032 let err = request_from_h2_headers(headers).unwrap_err();
5033 assert!(
5034 err.to_string()
5035 .contains("pseudo-headers must appear before regular headers")
5036 );
5037 }
5038
5039 #[test]
5044 fn host_validation_missing_host_rejected() {
5045 let config = ServerConfig::default();
5046 let request = Request::new(fastapi_core::Method::Get, "/");
5047 let err = validate_host_header(&request, &config).unwrap_err();
5048 assert_eq!(err.kind, HostValidationErrorKind::Missing);
5049 assert_eq!(err.response().status().as_u16(), 400);
5050 }
5051
5052 #[test]
5053 fn host_validation_allows_configured_host() {
5054 let config = ServerConfig::default().with_allowed_hosts(["example.com"]);
5055 let mut request = Request::new(fastapi_core::Method::Get, "/");
5056 request
5057 .headers_mut()
5058 .insert("Host".to_string(), b"example.com".to_vec());
5059 assert!(validate_host_header(&request, &config).is_ok());
5060 }
5061
5062 #[test]
5063 fn host_validation_rejects_disallowed_host() {
5064 let config = ServerConfig::default().with_allowed_hosts(["example.com"]);
5065 let mut request = Request::new(fastapi_core::Method::Get, "/");
5066 request
5067 .headers_mut()
5068 .insert("Host".to_string(), b"evil.com".to_vec());
5069 let err = validate_host_header(&request, &config).unwrap_err();
5070 assert_eq!(err.kind, HostValidationErrorKind::NotAllowed);
5071 }
5072
5073 #[test]
5074 fn host_validation_wildcard_allows_subdomains_only() {
5075 let config = ServerConfig::default().with_allowed_hosts(["*.example.com"]);
5076 let mut request = Request::new(fastapi_core::Method::Get, "/");
5077 request
5078 .headers_mut()
5079 .insert("Host".to_string(), b"api.example.com".to_vec());
5080 assert!(validate_host_header(&request, &config).is_ok());
5081
5082 let mut request = Request::new(fastapi_core::Method::Get, "/");
5083 request
5084 .headers_mut()
5085 .insert("Host".to_string(), b"example.com".to_vec());
5086 let err = validate_host_header(&request, &config).unwrap_err();
5087 assert_eq!(err.kind, HostValidationErrorKind::NotAllowed);
5088 }
5089
5090 #[test]
5091 fn host_validation_uses_x_forwarded_host_when_trusted() {
5092 let config = ServerConfig::default()
5093 .with_allowed_hosts(["example.com"])
5094 .with_trust_x_forwarded_host(true);
5095 let mut request = Request::new(fastapi_core::Method::Get, "/");
5096 request
5097 .headers_mut()
5098 .insert("Host".to_string(), b"internal.local".to_vec());
5099 request
5100 .headers_mut()
5101 .insert("X-Forwarded-Host".to_string(), b"example.com".to_vec());
5102 assert!(validate_host_header(&request, &config).is_ok());
5103 }
5104
5105 #[test]
5106 fn host_validation_rejects_invalid_host_value() {
5107 let config = ServerConfig::default();
5108 let mut request = Request::new(fastapi_core::Method::Get, "/");
5109 request
5110 .headers_mut()
5111 .insert("Host".to_string(), b"bad host".to_vec());
5112 let err = validate_host_header(&request, &config).unwrap_err();
5113 assert_eq!(err.kind, HostValidationErrorKind::Invalid);
5114 }
5115
5116 #[test]
5121 fn websocket_upgrade_detection_accepts_token_lists_case_insensitive() {
5122 let mut request = Request::new(fastapi_core::Method::Get, "/ws");
5123 request
5124 .headers_mut()
5125 .insert("Upgrade".to_string(), b"h2c, WebSocket".to_vec());
5126 request
5127 .headers_mut()
5128 .insert("Connection".to_string(), b"keep-alive, UPGRADE".to_vec());
5129
5130 assert!(is_websocket_upgrade_request(&request));
5131 }
5132
5133 #[test]
5134 fn websocket_upgrade_detection_rejects_missing_connection_upgrade_token() {
5135 let mut request = Request::new(fastapi_core::Method::Get, "/ws");
5136 request
5137 .headers_mut()
5138 .insert("Upgrade".to_string(), b"websocket".to_vec());
5139 request
5140 .headers_mut()
5141 .insert("Connection".to_string(), b"keep-alive".to_vec());
5142
5143 assert!(!is_websocket_upgrade_request(&request));
5144 }
5145
5146 #[test]
5147 fn websocket_upgrade_detection_rejects_non_get_method() {
5148 let mut request = Request::new(fastapi_core::Method::Post, "/ws");
5149 request
5150 .headers_mut()
5151 .insert("Upgrade".to_string(), b"websocket".to_vec());
5152 request
5153 .headers_mut()
5154 .insert("Connection".to_string(), b"upgrade".to_vec());
5155
5156 assert!(!is_websocket_upgrade_request(&request));
5157 }
5158
5159 #[test]
5164 fn keep_alive_default_http11() {
5165 let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
5167 request
5168 .headers_mut()
5169 .insert("Host".to_string(), b"example.com".to_vec());
5170 assert!(should_keep_alive(&request));
5171 }
5172
5173 #[test]
5174 fn keep_alive_explicit_keep_alive() {
5175 let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
5176 request
5177 .headers_mut()
5178 .insert("Connection".to_string(), b"keep-alive".to_vec());
5179 assert!(should_keep_alive(&request));
5180 }
5181
5182 #[test]
5183 fn keep_alive_connection_close() {
5184 let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
5185 request
5186 .headers_mut()
5187 .insert("Connection".to_string(), b"close".to_vec());
5188 assert!(!should_keep_alive(&request));
5189 }
5190
5191 #[test]
5192 fn keep_alive_connection_close_case_insensitive() {
5193 let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
5194 request
5195 .headers_mut()
5196 .insert("Connection".to_string(), b"CLOSE".to_vec());
5197 assert!(!should_keep_alive(&request));
5198 }
5199
5200 #[test]
5201 fn keep_alive_multiple_values() {
5202 let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
5203 request
5204 .headers_mut()
5205 .insert("Connection".to_string(), b"keep-alive, upgrade".to_vec());
5206 assert!(should_keep_alive(&request));
5207 }
5208
5209 #[test]
5214 fn timeout_budget_created_with_config_deadline() {
5215 let config = ServerConfig::new("127.0.0.1:8080").with_request_timeout_secs(45);
5216 let budget = Budget::new().with_deadline(config.request_timeout);
5217 assert_eq!(budget.deadline, Some(Time::from_secs(45)));
5218 }
5219
5220 #[test]
5221 fn timeout_duration_conversion_from_time() {
5222 let timeout = Time::from_secs(30);
5223 let duration = Duration::from_nanos(timeout.as_nanos());
5224 assert_eq!(duration, Duration::from_secs(30));
5225 }
5226
5227 #[test]
5228 fn timeout_duration_conversion_from_time_millis() {
5229 let timeout = Time::from_millis(1500);
5230 let duration = Duration::from_nanos(timeout.as_nanos());
5231 assert_eq!(duration, Duration::from_millis(1500));
5232 }
5233
5234 #[test]
5235 fn gateway_timeout_response_has_correct_status() {
5236 let response = Response::with_status(StatusCode::GATEWAY_TIMEOUT);
5237 assert_eq!(response.status().as_u16(), 504);
5238 }
5239
5240 #[test]
5241 fn gateway_timeout_response_with_body() {
5242 let response = Response::with_status(StatusCode::GATEWAY_TIMEOUT).body(
5243 fastapi_core::ResponseBody::Bytes(b"Request timed out".to_vec()),
5244 );
5245 assert_eq!(response.status().as_u16(), 504);
5246 assert!(response.body_ref().len() > 0);
5248 }
5249
5250 #[test]
5251 fn elapsed_time_check_logic() {
5252 let start = Instant::now();
5254 let timeout_duration = Duration::from_millis(10);
5255
5256 assert!(start.elapsed() <= timeout_duration);
5258
5259 std::thread::sleep(Duration::from_millis(20));
5261
5262 assert!(start.elapsed() > timeout_duration);
5264 }
5265
5266 #[test]
5271 fn connection_counter_starts_at_zero() {
5272 let server = TcpServer::default();
5273 assert_eq!(server.current_connections(), 0);
5274 }
5275
5276 #[test]
5277 fn try_acquire_connection_unlimited() {
5278 let server = TcpServer::default();
5280 assert_eq!(server.config().max_connections, 0);
5281
5282 for _ in 0..100 {
5284 assert!(server.try_acquire_connection());
5285 }
5286 assert_eq!(server.current_connections(), 100);
5287
5288 for _ in 0..100 {
5290 server.release_connection();
5291 }
5292 assert_eq!(server.current_connections(), 0);
5293 }
5294
5295 #[test]
5296 fn try_acquire_connection_with_limit() {
5297 let config = ServerConfig::new("127.0.0.1:8080").with_max_connections(5);
5298 let server = TcpServer::new(config);
5299
5300 for i in 0..5 {
5302 assert!(
5303 server.try_acquire_connection(),
5304 "Should acquire connection {i}"
5305 );
5306 }
5307 assert_eq!(server.current_connections(), 5);
5308
5309 assert!(!server.try_acquire_connection());
5311 assert_eq!(server.current_connections(), 5);
5312
5313 server.release_connection();
5315 assert_eq!(server.current_connections(), 4);
5316
5317 assert!(server.try_acquire_connection());
5319 assert_eq!(server.current_connections(), 5);
5320 }
5321
5322 #[test]
5323 fn try_acquire_connection_single_connection_limit() {
5324 let config = ServerConfig::new("127.0.0.1:8080").with_max_connections(1);
5325 let server = TcpServer::new(config);
5326
5327 assert!(server.try_acquire_connection());
5329 assert_eq!(server.current_connections(), 1);
5330
5331 assert!(!server.try_acquire_connection());
5333 assert_eq!(server.current_connections(), 1);
5334
5335 server.release_connection();
5337 assert!(server.try_acquire_connection());
5338 }
5339
5340 #[test]
5341 fn service_unavailable_response_has_correct_status() {
5342 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE);
5343 assert_eq!(response.status().as_u16(), 503);
5344 }
5345
5346 #[test]
5347 fn service_unavailable_response_with_body() {
5348 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
5349 .header("connection", b"close".to_vec())
5350 .body(fastapi_core::ResponseBody::Bytes(
5351 b"503 Service Unavailable: connection limit reached".to_vec(),
5352 ));
5353 assert_eq!(response.status().as_u16(), 503);
5354 assert!(response.body_ref().len() > 0);
5355 }
5356
5357 #[test]
5358 fn config_max_connections_default_is_zero() {
5359 let config = ServerConfig::default();
5360 assert_eq!(config.max_connections, 0);
5361 }
5362
5363 #[test]
5364 fn config_max_connections_can_be_set() {
5365 let config = ServerConfig::new("127.0.0.1:8080").with_max_connections(100);
5366 assert_eq!(config.max_connections, 100);
5367 }
5368
5369 #[test]
5374 fn config_keep_alive_timeout_default() {
5375 let config = ServerConfig::default();
5376 assert_eq!(
5377 config.keep_alive_timeout,
5378 Duration::from_secs(DEFAULT_KEEP_ALIVE_TIMEOUT_SECS)
5379 );
5380 }
5381
5382 #[test]
5383 fn config_keep_alive_timeout_can_be_set() {
5384 let config =
5385 ServerConfig::new("127.0.0.1:8080").with_keep_alive_timeout(Duration::from_secs(120));
5386 assert_eq!(config.keep_alive_timeout, Duration::from_secs(120));
5387 }
5388
5389 #[test]
5390 fn config_keep_alive_timeout_can_be_set_secs() {
5391 let config = ServerConfig::new("127.0.0.1:8080").with_keep_alive_timeout_secs(90);
5392 assert_eq!(config.keep_alive_timeout, Duration::from_secs(90));
5393 }
5394
5395 #[test]
5396 fn config_max_requests_per_connection_default() {
5397 let config = ServerConfig::default();
5398 assert_eq!(
5399 config.max_requests_per_connection,
5400 DEFAULT_MAX_REQUESTS_PER_CONNECTION
5401 );
5402 }
5403
5404 #[test]
5405 fn config_max_requests_per_connection_can_be_set() {
5406 let config = ServerConfig::new("127.0.0.1:8080").with_max_requests_per_connection(50);
5407 assert_eq!(config.max_requests_per_connection, 50);
5408 }
5409
5410 #[test]
5411 fn config_max_requests_per_connection_unlimited() {
5412 let config = ServerConfig::new("127.0.0.1:8080").with_max_requests_per_connection(0);
5413 assert_eq!(config.max_requests_per_connection, 0);
5414 }
5415
5416 #[test]
5417 fn response_with_keep_alive_header() {
5418 let response = Response::ok().header("connection", b"keep-alive".to_vec());
5419 let headers = response.headers();
5420 let connection_header = headers
5421 .iter()
5422 .find(|(name, _)| name.eq_ignore_ascii_case("connection"));
5423 assert!(connection_header.is_some());
5424 assert_eq!(connection_header.unwrap().1, b"keep-alive");
5425 }
5426
5427 #[test]
5428 fn response_with_close_header() {
5429 let response = Response::ok().header("connection", b"close".to_vec());
5430 let headers = response.headers();
5431 let connection_header = headers
5432 .iter()
5433 .find(|(name, _)| name.eq_ignore_ascii_case("connection"));
5434 assert!(connection_header.is_some());
5435 assert_eq!(connection_header.unwrap().1, b"close");
5436 }
5437
5438 #[test]
5443 fn config_drain_timeout_default() {
5444 let config = ServerConfig::default();
5445 assert_eq!(
5446 config.drain_timeout,
5447 Duration::from_secs(DEFAULT_DRAIN_TIMEOUT_SECS)
5448 );
5449 }
5450
5451 #[test]
5452 fn config_drain_timeout_can_be_set() {
5453 let config =
5454 ServerConfig::new("127.0.0.1:8080").with_drain_timeout(Duration::from_secs(60));
5455 assert_eq!(config.drain_timeout, Duration::from_secs(60));
5456 }
5457
5458 #[test]
5459 fn config_drain_timeout_can_be_set_secs() {
5460 let config = ServerConfig::new("127.0.0.1:8080").with_drain_timeout_secs(45);
5461 assert_eq!(config.drain_timeout, Duration::from_secs(45));
5462 }
5463
5464 #[test]
5465 fn server_not_draining_initially() {
5466 let server = TcpServer::default();
5467 assert!(!server.is_draining());
5468 }
5469
5470 #[test]
5471 fn server_start_drain_sets_flag() {
5472 let server = TcpServer::default();
5473 assert!(!server.is_draining());
5474 server.start_drain();
5475 assert!(server.is_draining());
5476 }
5477
5478 #[test]
5479 fn server_start_drain_idempotent() {
5480 let server = TcpServer::default();
5481 server.start_drain();
5482 assert!(server.is_draining());
5483 server.start_drain();
5484 assert!(server.is_draining());
5485 }
5486
5487 #[test]
5488 fn wait_for_drain_returns_true_when_no_connections() {
5489 block_on(async {
5490 let server = TcpServer::default();
5491 assert_eq!(server.current_connections(), 0);
5492 let result = server
5493 .wait_for_drain(Duration::from_millis(100), Some(Duration::from_millis(1)))
5494 .await;
5495 assert!(result);
5496 });
5497 }
5498
5499 #[test]
5500 fn wait_for_drain_timeout_with_connections() {
5501 block_on(async {
5502 let server = TcpServer::default();
5503 server.try_acquire_connection();
5505 server.try_acquire_connection();
5506 assert_eq!(server.current_connections(), 2);
5507
5508 let result = server
5510 .wait_for_drain(Duration::from_millis(50), Some(Duration::from_millis(5)))
5511 .await;
5512 assert!(!result);
5513 assert_eq!(server.current_connections(), 2);
5514 });
5515 }
5516
5517 #[test]
5518 fn drain_returns_zero_when_no_connections() {
5519 block_on(async {
5520 let server = TcpServer::new(
5521 ServerConfig::new("127.0.0.1:8080").with_drain_timeout(Duration::from_millis(100)),
5522 );
5523 assert_eq!(server.current_connections(), 0);
5524 let remaining = server.drain().await;
5525 assert_eq!(remaining, 0);
5526 assert!(server.is_draining());
5527 });
5528 }
5529
5530 #[test]
5531 fn drain_returns_count_when_connections_remain() {
5532 block_on(async {
5533 let server = TcpServer::new(
5534 ServerConfig::new("127.0.0.1:8080").with_drain_timeout(Duration::from_millis(50)),
5535 );
5536 server.try_acquire_connection();
5538 server.try_acquire_connection();
5539 server.try_acquire_connection();
5540
5541 let remaining = server.drain().await;
5542 assert_eq!(remaining, 3);
5543 assert!(server.is_draining());
5544 });
5545 }
5546
5547 #[test]
5548 fn server_shutdown_error_display() {
5549 let err = ServerError::Shutdown;
5550 assert_eq!(err.to_string(), "Server shutdown");
5551 }
5552
5553 #[test]
5558 fn server_has_shutdown_controller() {
5559 let server = TcpServer::default();
5560 let controller = server.shutdown_controller();
5561 assert!(!controller.is_shutting_down());
5562 }
5563
5564 #[test]
5565 fn server_subscribe_shutdown_returns_receiver() {
5566 let server = TcpServer::default();
5567 let receiver = server.subscribe_shutdown();
5568 assert!(!receiver.is_shutting_down());
5569 }
5570
5571 #[test]
5572 fn server_shutdown_sets_draining_and_controller() {
5573 let server = TcpServer::default();
5574 assert!(!server.is_shutting_down());
5575 assert!(!server.is_draining());
5576 assert!(!server.shutdown_controller().is_shutting_down());
5577
5578 server.shutdown();
5579
5580 assert!(server.is_shutting_down());
5581 assert!(server.is_draining());
5582 assert!(server.shutdown_controller().is_shutting_down());
5583 }
5584
5585 #[test]
5586 fn server_shutdown_notifies_receivers() {
5587 let server = TcpServer::default();
5588 let receiver1 = server.subscribe_shutdown();
5589 let receiver2 = server.subscribe_shutdown();
5590
5591 assert!(!receiver1.is_shutting_down());
5592 assert!(!receiver2.is_shutting_down());
5593
5594 server.shutdown();
5595
5596 assert!(receiver1.is_shutting_down());
5597 assert!(receiver2.is_shutting_down());
5598 }
5599
5600 #[test]
5601 fn server_shutdown_is_idempotent() {
5602 let server = TcpServer::default();
5603 let receiver = server.subscribe_shutdown();
5604
5605 server.shutdown();
5606 server.shutdown();
5607 server.shutdown();
5608
5609 assert!(server.is_shutting_down());
5610 assert!(receiver.is_shutting_down());
5611 }
5612
5613 #[test]
5618 fn keep_alive_timeout_error_display() {
5619 let err = ServerError::KeepAliveTimeout;
5620 assert_eq!(err.to_string(), "Keep-alive timeout");
5621 }
5622
5623 #[test]
5624 fn keep_alive_timeout_zero_disables_timeout() {
5625 let config = ServerConfig::new("127.0.0.1:8080").with_keep_alive_timeout(Duration::ZERO);
5626 assert!(config.keep_alive_timeout.is_zero());
5627 }
5628
5629 #[test]
5630 fn keep_alive_timeout_default_is_non_zero() {
5631 let config = ServerConfig::default();
5632 assert!(!config.keep_alive_timeout.is_zero());
5633 assert_eq!(
5634 config.keep_alive_timeout,
5635 Duration::from_secs(DEFAULT_KEEP_ALIVE_TIMEOUT_SECS)
5636 );
5637 }
5638
5639 #[test]
5640 fn timed_out_io_error_kind() {
5641 let err = io::Error::new(io::ErrorKind::TimedOut, "test timeout");
5642 assert_eq!(err.kind(), io::ErrorKind::TimedOut);
5643 }
5644
5645 #[test]
5646 fn instant_deadline_calculation() {
5647 let timeout = Duration::from_millis(100);
5648 let deadline = Instant::now() + timeout;
5649
5650 assert!(deadline > Instant::now());
5652
5653 std::thread::sleep(Duration::from_millis(150));
5655 assert!(Instant::now() >= deadline);
5656 }
5657
5658 #[test]
5659 fn server_metrics_initial_state() {
5660 let server = TcpServer::default();
5661 let m = server.metrics();
5662 assert_eq!(m.active_connections, 0);
5663 assert_eq!(m.total_accepted, 0);
5664 assert_eq!(m.total_rejected, 0);
5665 assert_eq!(m.total_timed_out, 0);
5666 assert_eq!(m.total_requests, 0);
5667 assert_eq!(m.bytes_in, 0);
5668 assert_eq!(m.bytes_out, 0);
5669 }
5670
5671 #[test]
5672 fn server_metrics_after_acquire_release() {
5673 let server = TcpServer::new(ServerConfig::new("127.0.0.1:0").with_max_connections(10));
5674 assert!(server.try_acquire_connection());
5675 assert!(server.try_acquire_connection());
5676
5677 let m = server.metrics();
5678 assert_eq!(m.active_connections, 2);
5679 assert_eq!(m.total_accepted, 2);
5680 assert_eq!(m.total_rejected, 0);
5681
5682 server.release_connection();
5683 let m = server.metrics();
5684 assert_eq!(m.active_connections, 1);
5685 assert_eq!(m.total_accepted, 2); }
5687
5688 #[test]
5689 fn server_metrics_rejection_counted() {
5690 let server = TcpServer::new(ServerConfig::new("127.0.0.1:0").with_max_connections(1));
5691 assert!(server.try_acquire_connection());
5692 assert!(!server.try_acquire_connection()); let m = server.metrics();
5695 assert_eq!(m.total_accepted, 1);
5696 assert_eq!(m.total_rejected, 1);
5697 assert_eq!(m.active_connections, 1);
5698 }
5699
5700 #[test]
5701 fn server_metrics_bytes_tracking() {
5702 let server = TcpServer::default();
5703 server.record_bytes_in(1024);
5704 server.record_bytes_in(512);
5705 server.record_bytes_out(2048);
5706
5707 let m = server.metrics();
5708 assert_eq!(m.bytes_in, 1536);
5709 assert_eq!(m.bytes_out, 2048);
5710 }
5711
5712 #[test]
5713 fn server_metrics_unlimited_connections_accepted() {
5714 let server = TcpServer::new(ServerConfig::new("127.0.0.1:0").with_max_connections(0));
5715 for _ in 0..100 {
5716 assert!(server.try_acquire_connection());
5717 }
5718 let m = server.metrics();
5719 assert_eq!(m.total_accepted, 100);
5720 assert_eq!(m.total_rejected, 0);
5721 assert_eq!(m.active_connections, 100);
5722 }
5723
5724 #[test]
5725 fn server_metrics_clone_eq() {
5726 let server = TcpServer::default();
5727 server.record_bytes_in(42);
5728 let m1 = server.metrics();
5729 let m2 = m1.clone();
5730 assert_eq!(m1, m2);
5731 }
5732}
5733
5734pub trait AppServeExt {
5757 fn serve(self, addr: impl Into<String>) -> impl Future<Output = Result<(), ServeError>> + Send;
5797
5798 fn serve_with_config(
5821 self,
5822 config: ServerConfig,
5823 ) -> impl Future<Output = Result<(), ServeError>> + Send;
5824}
5825
5826#[derive(Debug)]
5828pub enum ServeError {
5829 Startup(fastapi_core::StartupHookError),
5831 Server(ServerError),
5833}
5834
5835impl std::fmt::Display for ServeError {
5836 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5837 match self {
5838 Self::Startup(e) => write!(f, "startup hook failed: {}", e.message),
5839 Self::Server(e) => write!(f, "server error: {e}"),
5840 }
5841 }
5842}
5843
5844impl std::error::Error for ServeError {
5845 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
5846 match self {
5847 Self::Startup(_) => None,
5848 Self::Server(e) => Some(e),
5849 }
5850 }
5851}
5852
5853impl From<ServerError> for ServeError {
5854 fn from(e: ServerError) -> Self {
5855 Self::Server(e)
5856 }
5857}
5858
5859impl AppServeExt for App {
5860 fn serve(self, addr: impl Into<String>) -> impl Future<Output = Result<(), ServeError>> + Send {
5861 let config = ServerConfig::new(addr);
5862 self.serve_with_config(config)
5863 }
5864
5865 #[allow(clippy::manual_async_fn)] fn serve_with_config(
5867 self,
5868 config: ServerConfig,
5869 ) -> impl Future<Output = Result<(), ServeError>> + Send {
5870 async move {
5871 match self.run_startup_hooks().await {
5873 fastapi_core::StartupOutcome::Success => {}
5874 fastapi_core::StartupOutcome::PartialSuccess { warnings } => {
5875 eprintln!("Warning: {warnings} startup hook(s) had non-fatal errors");
5877 }
5878 fastapi_core::StartupOutcome::Aborted(e) => {
5879 return Err(ServeError::Startup(e));
5880 }
5881 }
5882
5883 let server = TcpServer::new(config);
5885
5886 let app = Arc::new(self);
5888
5889 let cx = Cx::for_testing();
5891
5892 let bind_addr = &server.config().bind_addr;
5894 println!("🚀 Server starting on http://{bind_addr}");
5895
5896 let result = server.serve_app(&cx, Arc::clone(&app)).await;
5898
5899 app.run_shutdown_hooks().await;
5901
5902 result.map_err(ServeError::from)
5903 }
5904 }
5905}
5906
5907pub async fn serve(app: App, addr: impl Into<String>) -> Result<(), ServeError> {
5925 app.serve(addr).await
5926}
5927
5928pub async fn serve_with_config(app: App, config: ServerConfig) -> Result<(), ServeError> {
5946 app.serve_with_config(config).await
5947}