1#![allow(dead_code)]
10use crate::connection::should_keep_alive;
49use crate::expect::{
50 CONTINUE_RESPONSE, ExpectHandler, ExpectResult, PreBodyValidator, PreBodyValidators,
51};
52use crate::http2;
53use crate::parser::{ParseError, ParseLimits, ParseStatus, Parser, StatefulParser};
54use crate::response::{ResponseWrite, ResponseWriter};
55use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
56use asupersync::net::{TcpListener, TcpStream};
57use asupersync::runtime::{JoinHandle, Runtime, RuntimeHandle, SpawnError};
58use asupersync::signal::{GracefulOutcome, ShutdownController, ShutdownReceiver};
59use asupersync::stream::Stream;
60use asupersync::time::timeout;
61use asupersync::{Budget, Cx, Time};
62use fastapi_core::app::App;
63use fastapi_core::{Method, Request, RequestContext, Response, StatusCode};
64use std::future::Future;
65use std::io;
66use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
67use std::pin::Pin;
68use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
69use std::sync::{Arc, Mutex, OnceLock};
70use std::task::Poll;
71use std::time::{Duration, Instant};
72
73static START_TIME: OnceLock<Instant> = OnceLock::new();
76
77fn current_time() -> Time {
82 let start = START_TIME.get_or_init(Instant::now);
83 let now = Instant::now();
84 if now < *start {
85 Time::ZERO
86 } else {
87 let elapsed = now.duration_since(*start);
88 Time::from_nanos(elapsed.as_nanos() as u64)
89 }
90}
91
/// Default per-request time limit, in seconds (see `ServerConfig::request_timeout`).
pub const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 30;

/// Default size, in bytes, of the per-connection read buffer (8 KiB).
pub const DEFAULT_READ_BUFFER_SIZE: usize = 8 * 1024;

/// Default value for `ServerConfig::max_connections`.
/// NOTE(review): `0` presumably means "no limit"; confirm where it is enforced.
pub const DEFAULT_MAX_CONNECTIONS: usize = 0;

/// Default keep-alive read timeout, in seconds (see `ServerConfig::keep_alive_timeout`).
pub const DEFAULT_KEEP_ALIVE_TIMEOUT_SECS: u64 = 75;

/// Default number of requests served on one HTTP/1.x connection before it is
/// closed; a configured value of `0` disables the cap.
pub const DEFAULT_MAX_REQUESTS_PER_CONNECTION: usize = 100;

/// Default drain period, in seconds (see `ServerConfig::drain_timeout`).
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
109
/// Future adapter that turns a panic raised while polling the wrapped future
/// into `Err(payload)` instead of unwinding through the caller.
struct CatchUnwind<F>(Pin<Box<F>>);

impl<F: Future> CatchUnwind<F> {
    /// Boxes and pins `future` so it can be polled through `&mut self`.
    fn new(future: F) -> Self {
        CatchUnwind(Box::pin(future))
    }
}

impl<F: Future> Future for CatchUnwind<F> {
    type Output = std::thread::Result<F::Output>;

    /// Polls the inner future inside `catch_unwind`: `Pending` passes through,
    /// a ready value becomes `Ok(value)`, and a panic becomes `Err(payload)`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let pinned = self.0.as_mut();
        let guarded = std::panic::AssertUnwindSafe(move || pinned.poll(cx));
        match std::panic::catch_unwind(guarded) {
            Err(payload) => Poll::Ready(Err(payload)),
            Ok(polled) => polled.map(Ok),
        }
    }
}
131
/// Extracts a human-readable message from a panic payload.
///
/// Payloads from `panic!("literal")` (`&'static str`) and formatted panics
/// (`String`) yield their text; anything else yields a fixed placeholder.
fn panic_payload_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "non-string panic payload".to_string())
}
141
/// Configuration for the HTTP server.
///
/// Build one with [`ServerConfig::new`] (or `Default`, which uses
/// `"127.0.0.1:8080"`) and the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address to bind, e.g. `"127.0.0.1:8080"`.
    pub bind_addr: String,
    /// Per-request time limit. Used as the request `Budget` deadline and, on the
    /// HTTP/1.x path, to replace the handler's response with `504 Gateway Timeout`
    /// when the handler ran longer than this.
    pub request_timeout: Time,
    /// Maximum number of connections.
    /// NOTE(review): not enforced in this part of the file; `0` presumably means unlimited.
    pub max_connections: usize,
    /// Size in bytes of the per-connection read buffer on the HTTP/1.x path.
    pub read_buffer_size: usize,
    /// Limits for the HTTP/1.x parser; `max_request_size` also caps HTTP/2 request bodies.
    pub parse_limits: ParseLimits,
    /// Allowed host patterns: exact `host[:port]`, `*.suffix` wildcards, or `*`.
    /// An empty list allows every host. The builder methods lowercase entries.
    pub allowed_hosts: Vec<String>,
    /// When `true`, the first non-empty `X-Forwarded-Host` value is validated
    /// instead of `Host`. Only safe behind a proxy that sets this header itself.
    pub trust_x_forwarded_host: bool,
    /// Whether `TCP_NODELAY` should be enabled.
    /// NOTE(review): where this is applied is not visible in this part of the file.
    pub tcp_nodelay: bool,
    /// Read timeout while waiting for (more of) a request on an HTTP/1.x
    /// connection; also passed to protocol sniffing. A zero duration disables
    /// the timeout for HTTP/1.x request reads.
    pub keep_alive_timeout: Duration,
    /// Requests served on one HTTP/1.x connection before it is closed; `0` means no limit.
    pub max_requests_per_connection: usize,
    /// Drain period for shutdown.
    /// NOTE(review): inferred from the name; its use is not visible in this part of the file.
    pub drain_timeout: Duration,
    /// Validators run on each request before the handler; a rejection response
    /// is sent in place of the handler's response.
    pub pre_body_validators: PreBodyValidators,
}
202
203impl ServerConfig {
204 #[must_use]
206 pub fn new(bind_addr: impl Into<String>) -> Self {
207 Self {
208 bind_addr: bind_addr.into(),
209 request_timeout: Time::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS),
210 max_connections: DEFAULT_MAX_CONNECTIONS,
211 read_buffer_size: DEFAULT_READ_BUFFER_SIZE,
212 parse_limits: ParseLimits::default(),
213 allowed_hosts: Vec::new(),
214 trust_x_forwarded_host: false,
215 tcp_nodelay: true,
216 keep_alive_timeout: Duration::from_secs(DEFAULT_KEEP_ALIVE_TIMEOUT_SECS),
217 max_requests_per_connection: DEFAULT_MAX_REQUESTS_PER_CONNECTION,
218 drain_timeout: Duration::from_secs(DEFAULT_DRAIN_TIMEOUT_SECS),
219 pre_body_validators: PreBodyValidators::new(),
220 }
221 }
222
223 #[must_use]
225 pub fn with_request_timeout(mut self, timeout: Time) -> Self {
226 self.request_timeout = timeout;
227 self
228 }
229
230 #[must_use]
232 pub fn with_request_timeout_secs(mut self, secs: u64) -> Self {
233 self.request_timeout = Time::from_secs(secs);
234 self
235 }
236
237 #[must_use]
239 pub fn with_max_connections(mut self, max: usize) -> Self {
240 self.max_connections = max;
241 self
242 }
243
244 #[must_use]
246 pub fn with_read_buffer_size(mut self, size: usize) -> Self {
247 self.read_buffer_size = size;
248 self
249 }
250
251 #[must_use]
253 pub fn with_parse_limits(mut self, limits: ParseLimits) -> Self {
254 self.parse_limits = limits;
255 self
256 }
257
258 #[must_use]
263 pub fn with_allowed_hosts<I, S>(mut self, hosts: I) -> Self
264 where
265 I: IntoIterator<Item = S>,
266 S: Into<String>,
267 {
268 self.allowed_hosts = hosts
270 .into_iter()
271 .map(|s| s.into().to_ascii_lowercase())
272 .collect();
273 self
274 }
275
276 #[must_use]
280 pub fn allow_host(mut self, host: impl Into<String>) -> Self {
281 self.allowed_hosts.push(host.into().to_ascii_lowercase());
283 self
284 }
285
286 #[must_use]
288 pub fn with_trust_x_forwarded_host(mut self, trust: bool) -> Self {
289 self.trust_x_forwarded_host = trust;
290 self
291 }
292
293 #[must_use]
295 pub fn with_tcp_nodelay(mut self, enabled: bool) -> Self {
296 self.tcp_nodelay = enabled;
297 self
298 }
299
300 #[must_use]
302 pub fn with_pre_body_validators(mut self, validators: PreBodyValidators) -> Self {
303 self.pre_body_validators = validators;
304 self
305 }
306
307 #[must_use]
309 pub fn with_pre_body_validator<V: PreBodyValidator + 'static>(mut self, validator: V) -> Self {
310 self.pre_body_validators.add(validator);
311 self
312 }
313
314 #[must_use]
319 pub fn with_keep_alive_timeout(mut self, timeout: Duration) -> Self {
320 self.keep_alive_timeout = timeout;
321 self
322 }
323
324 #[must_use]
326 pub fn with_keep_alive_timeout_secs(mut self, secs: u64) -> Self {
327 self.keep_alive_timeout = Duration::from_secs(secs);
328 self
329 }
330
331 #[must_use]
335 pub fn with_max_requests_per_connection(mut self, max: usize) -> Self {
336 self.max_requests_per_connection = max;
337 self
338 }
339
340 #[must_use]
346 pub fn with_drain_timeout(mut self, timeout: Duration) -> Self {
347 self.drain_timeout = timeout;
348 self
349 }
350
351 #[must_use]
353 pub fn with_drain_timeout_secs(mut self, secs: u64) -> Self {
354 self.drain_timeout = Duration::from_secs(secs);
355 self
356 }
357}
358
359impl Default for ServerConfig {
360 fn default() -> Self {
361 Self::new("127.0.0.1:8080")
362 }
363}
364
/// Errors produced while serving a connection.
#[derive(Debug)]
pub enum ServerError {
    /// An I/O error (converted via `From<io::Error>`).
    Io(io::Error),
    /// The HTTP/1.x request parser rejected the input.
    Parse(ParseError),
    /// An HTTP/2 framing, HPACK, or protocol error.
    Http2(http2::Http2Error),
    /// The server is shutting down.
    /// NOTE(review): not constructed in this part of the file; meaning inferred from the name.
    Shutdown,
    /// The connection limit was reached.
    /// NOTE(review): not constructed in this part of the file; presumably tied to `max_connections`.
    ConnectionLimitReached,
    /// An HTTP/1.x read exceeded `ServerConfig::keep_alive_timeout`.
    KeepAliveTimeout,
}
381
382impl std::fmt::Display for ServerError {
383 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384 match self {
385 Self::Io(e) => write!(f, "IO error: {e}"),
386 Self::Parse(e) => write!(f, "Parse error: {e}"),
387 Self::Http2(e) => write!(f, "HTTP/2 error: {e}"),
388 Self::Shutdown => write!(f, "Server shutdown"),
389 Self::ConnectionLimitReached => write!(f, "Connection limit reached"),
390 Self::KeepAliveTimeout => write!(f, "Keep-alive timeout"),
391 }
392 }
393}
394
395impl std::error::Error for ServerError {
396 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
397 match self {
398 Self::Io(e) => Some(e),
399 Self::Parse(e) => Some(e),
400 Self::Http2(e) => Some(e),
401 _ => None,
402 }
403 }
404}
405
/// Reason a request's host was rejected; selects the `400 Bad Request` body text.
#[derive(Debug, Clone, PartialEq, Eq)]
enum HostValidationErrorKind {
    /// No `Host` header was present (and no trusted `X-Forwarded-Host` was used).
    Missing,
    /// The header was not valid UTF-8, was empty, or did not parse as a host value.
    Invalid,
    /// The host parsed but matched none of `ServerConfig::allowed_hosts`.
    NotAllowed,
}
416
417#[derive(Debug, Clone)]
418struct HostValidationError {
419 kind: HostValidationErrorKind,
420 detail: String,
421}
422
423impl HostValidationError {
424 fn missing() -> Self {
425 Self {
426 kind: HostValidationErrorKind::Missing,
427 detail: "missing Host header".to_string(),
428 }
429 }
430
431 fn invalid(detail: impl Into<String>) -> Self {
432 Self {
433 kind: HostValidationErrorKind::Invalid,
434 detail: detail.into(),
435 }
436 }
437
438 fn not_allowed(detail: impl Into<String>) -> Self {
439 Self {
440 kind: HostValidationErrorKind::NotAllowed,
441 detail: detail.into(),
442 }
443 }
444
445 fn response(&self) -> Response {
446 let message = match self.kind {
447 HostValidationErrorKind::Missing => "Bad Request: Host header required",
448 HostValidationErrorKind::Invalid => "Bad Request: invalid Host header",
449 HostValidationErrorKind::NotAllowed => "Bad Request: Host not allowed",
450 };
451 Response::with_status(StatusCode::BAD_REQUEST).body(fastapi_core::ResponseBody::Bytes(
452 message.as_bytes().to_vec(),
453 ))
454 }
455}
456
/// A parsed `Host` (or `X-Forwarded-Host`) value, as produced by `parse_host_header`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct HostHeader {
    /// Host name or IP literal, lowercased; IPv6 literals are stored without brackets.
    host: String,
    /// Port, if the value carried one.
    port: Option<u16>,
}
462
463fn validate_host_header(
464 request: &Request,
465 config: &ServerConfig,
466) -> Result<HostHeader, HostValidationError> {
467 let raw = extract_effective_host(request, config)?;
468 let parsed = parse_host_header(&raw)
469 .ok_or_else(|| HostValidationError::invalid(format!("invalid host value: {raw}")))?;
470
471 if !is_allowed_host(&parsed, &config.allowed_hosts) {
472 return Err(HostValidationError::not_allowed(format!(
473 "host not allowed: {}",
474 parsed.host
475 )));
476 }
477
478 Ok(parsed)
479}
480
481fn extract_effective_host(
482 request: &Request,
483 config: &ServerConfig,
484) -> Result<String, HostValidationError> {
485 if config.trust_x_forwarded_host {
486 if let Some(value) = header_value(request, "x-forwarded-host")? {
487 let forwarded = extract_first_list_value(&value)
488 .ok_or_else(|| HostValidationError::invalid("empty X-Forwarded-Host value"))?;
489 return Ok(forwarded.to_string());
490 }
491 }
492
493 match header_value(request, "host")? {
494 Some(value) => Ok(value),
495 None => Err(HostValidationError::missing()),
496 }
497}
498
499fn header_value(request: &Request, name: &str) -> Result<Option<String>, HostValidationError> {
500 request
501 .headers()
502 .get(name)
503 .map(|bytes| {
504 std::str::from_utf8(bytes)
505 .map(|s| s.trim().to_string())
506 .map_err(|_| {
507 HostValidationError::invalid(format!("invalid UTF-8 in {name} header"))
508 })
509 })
510 .transpose()
511}
512
/// Returns the first non-empty, trimmed entry of a comma-separated header list.
fn extract_first_list_value(value: &str) -> Option<&str> {
    for item in value.split(',') {
        let item = item.trim();
        if !item.is_empty() {
            return Some(item);
        }
    }
    None
}
516
517fn parse_host_header(value: &str) -> Option<HostHeader> {
518 let value = value.trim();
519 if value.is_empty() {
520 return None;
521 }
522 if value.chars().any(|c| c.is_control() || c.is_whitespace()) {
523 return None;
524 }
525
526 if value.starts_with('[') {
527 let end = value.find(']')?;
528 let host = &value[1..end];
529 if host.is_empty() {
530 return None;
531 }
532 if host.parse::<Ipv6Addr>().is_err() {
533 return None;
534 }
535 let rest = &value[end + 1..];
536 let port = if rest.is_empty() {
537 None
538 } else if let Some(port_str) = rest.strip_prefix(':') {
539 parse_port(port_str)
540 } else {
541 return None;
542 };
543 return Some(HostHeader {
544 host: host.to_ascii_lowercase(),
545 port,
546 });
547 }
548
549 let mut parts = value.split(':');
550 let host = parts.next().unwrap_or("");
551 let port_part = parts.next();
552 if parts.next().is_some() {
553 return None;
555 }
556 if host.is_empty() {
557 return None;
558 }
559
560 let port = match port_part {
561 Some(p) => parse_port(p),
562 None => None,
563 };
564
565 if host.parse::<Ipv4Addr>().is_ok() || is_valid_hostname(host) {
566 Some(HostHeader {
567 host: host.to_ascii_lowercase(),
568 port,
569 })
570 } else {
571 None
572 }
573}
574
/// Parses a port component.
///
/// Accepts only a non-empty run of ASCII digits whose value lies in
/// `1..=65535`. Returns `None` for empty input, signs or other characters,
/// overflow, or zero.
fn parse_port(port: &str) -> Option<u16> {
    let all_digits = !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit());
    if !all_digits {
        return None;
    }
    port.parse::<u16>().ok().filter(|&value| value != 0)
}
582
/// Checks DNS-style hostname syntax: at most 253 bytes in total; dot-separated
/// labels of 1..=63 ASCII letters, digits, or hyphens, none starting or ending
/// with a hyphen. Empty labels (including a trailing dot) are rejected.
fn is_valid_hostname(host: &str) -> bool {
    const MAX_NAME_LEN: usize = 253;
    const MAX_LABEL_LEN: usize = 63;

    let label_ok = |label: &str| {
        let bytes = label.as_bytes();
        match (bytes.first(), bytes.last()) {
            (Some(&first), Some(&last)) => {
                bytes.len() <= MAX_LABEL_LEN
                    && first != b'-'
                    && last != b'-'
                    && bytes.iter().all(|&b| b.is_ascii_alphanumeric() || b == b'-')
            }
            _ => false,
        }
    };

    host.len() <= MAX_NAME_LEN && host.split('.').all(label_ok)
}
602
603fn is_allowed_host(host: &HostHeader, allowed_hosts: &[String]) -> bool {
604 if allowed_hosts.is_empty() {
605 return true;
606 }
607
608 allowed_hosts
609 .iter()
610 .any(|pattern| host_matches_pattern(host, pattern))
611}
612
613fn host_matches_pattern(host: &HostHeader, pattern: &str) -> bool {
614 let pattern = pattern.trim();
616 if pattern.is_empty() {
617 return false;
618 }
619 if pattern == "*" {
620 return true;
621 }
622 if let Some(suffix) = pattern.strip_prefix("*.") {
623 if host.host == suffix {
625 return false;
626 }
627 return host.host.len() > suffix.len() + 1
628 && host.host.ends_with(suffix)
629 && host.host.as_bytes()[host.host.len() - suffix.len() - 1] == b'.';
630 }
631
632 if let Some(parsed) = parse_host_header(pattern) {
633 if parsed.host != host.host {
634 return false;
635 }
636 if let Some(port) = parsed.port {
637 return host.port == Some(port);
638 }
639 return true;
640 }
641
642 false
643}
644
645fn header_str<'a>(req: &'a Request, name: &str) -> Option<&'a str> {
646 req.headers()
647 .get(name)
648 .and_then(|v| std::str::from_utf8(v).ok())
649 .map(str::trim)
650}
651
652fn header_has_token(req: &Request, name: &str, token: &str) -> bool {
653 let Some(v) = header_str(req, name) else {
654 return false;
655 };
656 v.split(',')
657 .map(str::trim)
658 .any(|t| t.eq_ignore_ascii_case(token))
659}
660
661fn connection_has_token(req: &Request, token: &str) -> bool {
662 header_has_token(req, "connection", token)
663}
664
665fn is_websocket_upgrade_request(req: &Request) -> bool {
666 if req.method() != Method::Get {
667 return false;
668 }
669 if !header_has_token(req, "upgrade", "websocket") {
670 return false;
671 }
672 connection_has_token(req, "upgrade")
673}
674
675fn has_request_body_headers(req: &Request) -> bool {
676 if req.headers().contains("transfer-encoding") {
677 return true;
678 }
679 if let Some(v) = header_str(req, "content-length") {
680 if v.is_empty() {
681 return true;
682 }
683 match v.parse::<usize>() {
684 Ok(0) => false,
685 Ok(_) => true,
686 Err(_) => true,
687 }
688 } else {
689 false
690 }
691}
692
693impl From<io::Error> for ServerError {
694 fn from(e: io::Error) -> Self {
695 Self::Io(e)
696 }
697}
698
699impl From<ParseError> for ServerError {
700 fn from(e: ParseError) -> Self {
701 Self::Parse(e)
702 }
703}
704
705impl From<http2::Http2Error> for ServerError {
706 fn from(e: http2::Http2Error) -> Self {
707 Self::Http2(e)
708 }
709}
710
711pub async fn process_connection<H, Fut>(
720 cx: &Cx,
721 request_counter: &AtomicU64,
722 mut stream: TcpStream,
723 _peer_addr: SocketAddr,
724 config: &ServerConfig,
725 handler: H,
726) -> Result<(), ServerError>
727where
728 H: Fn(RequestContext, &mut Request) -> Fut,
729 Fut: Future<Output = Response>,
730{
731 let (proto, buffered) = sniff_protocol(&mut stream, config.keep_alive_timeout).await?;
732 if proto == SniffedProtocol::Http2PriorKnowledge {
733 return process_connection_http2(cx, request_counter, stream, config, handler).await;
734 }
735
736 let mut parser = StatefulParser::new().with_limits(config.parse_limits.clone());
737 if !buffered.is_empty() {
738 parser.feed(&buffered)?;
739 }
740 let mut read_buffer = vec![0u8; config.read_buffer_size];
741 let mut response_writer = ResponseWriter::new();
742 let mut requests_on_connection: usize = 0;
743 let max_requests = config.max_requests_per_connection;
744
745 loop {
746 if cx.is_cancel_requested() {
748 return Ok(());
749 }
750
751 let parse_result = parser.feed(&[])?;
753
754 let mut request = match parse_result {
755 ParseStatus::Complete { request, .. } => request,
756 ParseStatus::Incomplete => {
757 let keep_alive_timeout = config.keep_alive_timeout;
758
759 let bytes_read = if keep_alive_timeout.is_zero() {
760 read_into_buffer(&mut stream, &mut read_buffer).await?
761 } else {
762 match read_with_timeout(&mut stream, &mut read_buffer, keep_alive_timeout).await
763 {
764 Ok(0) => return Ok(()),
765 Ok(n) => n,
766 Err(e) if e.kind() == io::ErrorKind::TimedOut => {
767 cx.trace(&format!(
768 "Keep-alive timeout ({:?}) - closing idle connection",
769 keep_alive_timeout
770 ));
771 return Err(ServerError::KeepAliveTimeout);
772 }
773 Err(e) => return Err(ServerError::Io(e)),
774 }
775 };
776
777 if bytes_read == 0 {
778 return Ok(());
779 }
780
781 match parser.feed(&read_buffer[..bytes_read])? {
782 ParseStatus::Complete { request, .. } => request,
783 ParseStatus::Incomplete => continue,
784 }
785 }
786 };
787
788 requests_on_connection += 1;
789
790 let request_id = request_counter.fetch_add(1, Ordering::Relaxed);
792 let request_budget = Budget::new().with_deadline(config.request_timeout);
793 let request_cx = Cx::for_testing_with_budget(request_budget);
794 let ctx = RequestContext::new(request_cx, request_id);
795
796 if let Err(err) = validate_host_header(&request, config) {
798 ctx.trace(&format!("Rejecting request: {}", err.detail));
799 let response = err.response().header("connection", b"close".to_vec());
800 let response_write = response_writer.write(response);
801 write_response(&mut stream, response_write).await?;
802 return Ok(());
803 }
804
805 if let Err(response) = config.pre_body_validators.validate_all(&request) {
807 let response = response.header("connection", b"close".to_vec());
808 let response_write = response_writer.write(response);
809 write_response(&mut stream, response_write).await?;
810 return Ok(());
811 }
812
813 match ExpectHandler::check_expect(&request) {
817 ExpectResult::NoExpectation => {
818 }
820 ExpectResult::ExpectsContinue => {
821 ctx.trace("Sending 100 Continue for Expect: 100-continue");
824 write_raw_response(&mut stream, CONTINUE_RESPONSE).await?;
825 }
826 ExpectResult::UnknownExpectation(value) => {
827 ctx.trace(&format!("Rejecting unknown Expect value: {}", value));
829 let response =
830 ExpectHandler::expectation_failed(format!("Unsupported Expect value: {value}"));
831 let response_write = response_writer.write(response);
832 write_response(&mut stream, response_write).await?;
833 return Ok(());
834 }
835 }
836
837 let client_wants_keep_alive = should_keep_alive(&request);
838 let at_max_requests = max_requests > 0 && requests_on_connection >= max_requests;
839 let server_will_keep_alive = client_wants_keep_alive && !at_max_requests;
840
841 let request_start = Instant::now();
842 let timeout_duration = Duration::from_nanos(config.request_timeout.as_nanos());
843
844 let response = handler(ctx, &mut request).await;
846
847 let mut response = if request_start.elapsed() > timeout_duration {
848 Response::with_status(StatusCode::GATEWAY_TIMEOUT).body(
849 fastapi_core::ResponseBody::Bytes(
850 b"Gateway Timeout: request processing exceeded time limit".to_vec(),
851 ),
852 )
853 } else {
854 response
855 };
856
857 response = if server_will_keep_alive {
858 response.header("connection", b"keep-alive".to_vec())
859 } else {
860 response.header("connection", b"close".to_vec())
861 };
862
863 let response_write = response_writer.write(response);
864 write_response(&mut stream, response_write).await?;
865
866 if let Some(tasks) = App::take_background_tasks(&mut request) {
867 tasks.execute_all().await;
868 }
869
870 if !server_will_keep_alive {
871 return Ok(());
872 }
873 }
874}
875
/// Serves an HTTP/2 connection that `sniff_protocol` identified as prior-knowledge.
/// NOTE(review): assumes the client connection preface was already consumed
/// while sniffing — confirm.
///
/// Performs the SETTINGS exchange, then handles one stream at a time. A HEADERS
/// frame (plus any CONTINUATION frames) opens a request, its DATA frames are
/// buffered into the body, the handler runs, and the response is written with
/// `process_connection_http2_write_response` before the next frame is read.
/// Connection-level frames (SETTINGS, PING, WINDOW_UPDATE, GOAWAY) are handled
/// both between requests and while a body is being read.
///
/// # Errors
///
/// Protocol violations are returned as `ServerError::Http2`; errors from
/// reading or writing frames are propagated. A client GOAWAY, or cancellation
/// of `cx` (checked before each top-level frame read), ends the connection
/// with `Ok(())`.
async fn process_connection_http2<H, Fut>(
    cx: &Cx,
    request_counter: &AtomicU64,
    stream: TcpStream,
    config: &ServerConfig,
    handler: H,
) -> Result<(), ServerError>
where
    H: Fn(RequestContext, &mut Request) -> Fut,
    Fut: Future<Output = Response>,
{
    // Frame flag bits: END_HEADERS (HEADERS/CONTINUATION) and ACK (SETTINGS/PING).
    const FLAG_END_HEADERS: u8 = 0x4;
    const FLAG_ACK: u8 = 0x1;

    let mut framed = http2::FramedH2::new(stream, Vec::new());
    let mut hpack = http2::HpackDecoder::new();
    // Maximum frame payload accepted from the peer.
    let recv_max_frame_size: u32 = 16 * 1024;
    // Maximum frame payload the peer accepts; updated from its SETTINGS.
    let mut peer_max_frame_size: u32 = 16 * 1024;
    let mut flow_control = http2::H2FlowControl::new();

    // The first frame must be a non-ACK SETTINGS frame on stream 0.
    let first = framed.read_frame(recv_max_frame_size).await?;
    if first.header.frame_type() != http2::FrameType::Settings
        || first.header.stream_id != 0
        || (first.header.flags & FLAG_ACK) != 0
    {
        return Err(http2::Http2Error::Protocol("expected client SETTINGS after preface").into());
    }
    apply_http2_settings_with_fc(
        &mut hpack,
        &mut peer_max_frame_size,
        Some(&mut flow_control),
        &first.payload,
    )?;

    // Send the server's SETTINGS, then acknowledge the client's.
    framed
        .write_frame(http2::FrameType::Settings, 0, 0, SERVER_SETTINGS_PAYLOAD)
        .await?;
    framed
        .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
        .await?;

    // Request bodies are capped by the same limit as whole HTTP/1.x requests.
    let default_body_limit = config.parse_limits.max_request_size;
    // Highest client stream ID seen. New streams must exceed it, and it is
    // reported in the GOAWAY sent on cancellation.
    let mut last_stream_id: u32 = 0;

    loop {
        if cx.is_cancel_requested() {
            // Best effort: the GOAWAY write result is deliberately ignored.
            let _ = send_goaway(&mut framed, last_stream_id, h2_error_code::NO_ERROR).await;
            return Ok(());
        }

        let frame = framed.read_frame(recv_max_frame_size).await?;
        match frame.header.frame_type() {
            http2::FrameType::Settings => {
                let is_ack = validate_settings_frame(
                    frame.header.stream_id,
                    frame.header.flags,
                    &frame.payload,
                )?;
                if is_ack {
                    continue;
                }
                apply_http2_settings_with_fc(
                    &mut hpack,
                    &mut peer_max_frame_size,
                    Some(&mut flow_control),
                    &frame.payload,
                )?;
                framed
                    .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
                    .await?;
            }
            http2::FrameType::Ping => {
                if frame.header.stream_id != 0 || frame.payload.len() != 8 {
                    return Err(http2::Http2Error::Protocol("invalid PING frame").into());
                }
                // Echo non-ACK PINGs back with the ACK flag set.
                if (frame.header.flags & FLAG_ACK) == 0 {
                    framed
                        .write_frame(http2::FrameType::Ping, FLAG_ACK, 0, &frame.payload)
                        .await?;
                }
            }
            http2::FrameType::Goaway => {
                validate_goaway_payload(&frame.payload)?;
                return Ok(());
            }
            http2::FrameType::PushPromise => {
                return Err(
                    http2::Http2Error::Protocol("PUSH_PROMISE not supported by server").into(),
                );
            }
            http2::FrameType::Headers => {
                let stream_id = frame.header.stream_id;
                // A new client stream ID must be non-zero, odd, and strictly increasing.
                if stream_id == 0 {
                    return Err(
                        http2::Http2Error::Protocol("HEADERS must not be on stream 0").into(),
                    );
                }
                if stream_id % 2 == 0 {
                    return Err(http2::Http2Error::Protocol(
                        "client-initiated stream ID must be odd",
                    )
                    .into());
                }
                if stream_id <= last_stream_id {
                    return Err(http2::Http2Error::Protocol(
                        "stream ID must be greater than previous",
                    )
                    .into());
                }
                last_stream_id = stream_id;
                let (end_stream, mut header_block) =
                    extract_header_block_fragment(frame.header.flags, &frame.payload)?;

                // Collect CONTINUATION frames for this stream until END_HEADERS,
                // bounding the accumulated block by MAX_HEADER_BLOCK_SIZE. Any other
                // frame in between is a protocol error.
                if (frame.header.flags & FLAG_END_HEADERS) == 0 {
                    loop {
                        let cont = framed.read_frame(recv_max_frame_size).await?;
                        if cont.header.frame_type() != http2::FrameType::Continuation
                            || cont.header.stream_id != stream_id
                        {
                            return Err(http2::Http2Error::Protocol(
                                "expected CONTINUATION for header block",
                            )
                            .into());
                        }
                        header_block.extend_from_slice(&cont.payload);
                        if header_block.len() > MAX_HEADER_BLOCK_SIZE {
                            return Err(http2::Http2Error::Protocol(
                                "header block exceeds maximum size",
                            )
                            .into());
                        }
                        if (cont.header.flags & FLAG_END_HEADERS) != 0 {
                            break;
                        }
                    }
                }

                let headers = hpack
                    .decode(&header_block)
                    .map_err(http2::Http2Error::from)?;
                let mut request = request_from_h2_headers(headers)?;

                if !end_stream {
                    // Buffer the request body. Until END_STREAM or RST_STREAM, only
                    // DATA for this stream and connection-level frames are accepted.
                    let mut body = Vec::new();
                    let mut stream_reset = false;
                    // DATA bytes on this stream not yet returned to the peer via a
                    // stream-level WINDOW_UPDATE.
                    let mut stream_received: u32 = 0;
                    loop {
                        let f = framed.read_frame(recv_max_frame_size).await?;
                        match f.header.frame_type() {
                            http2::FrameType::Data if f.header.stream_id == 0 => {
                                return Err(http2::Http2Error::Protocol(
                                    "DATA must not be on stream 0",
                                )
                                .into());
                            }
                            http2::FrameType::Data if f.header.stream_id == stream_id => {
                                let (data, data_end_stream) =
                                    extract_data_payload(f.header.flags, &f.payload)?;
                                if body.len().saturating_add(data.len()) > default_body_limit {
                                    return Err(http2::Http2Error::Protocol(
                                        "request body exceeds configured limit",
                                    )
                                    .into());
                                }
                                body.extend_from_slice(data);

                                // Replenish the receive windows. `data.len()` is
                                // presumably bounded by `recv_max_frame_size`, so the
                                // conversion should never saturate.
                                let data_len = u32::try_from(data.len()).unwrap_or(u32::MAX);
                                stream_received += data_len;
                                let conn_inc = flow_control.data_received_connection(data_len);
                                let stream_inc = flow_control.stream_window_update(stream_received);
                                if stream_inc > 0 {
                                    stream_received = 0;
                                }
                                send_window_updates(&mut framed, conn_inc, stream_id, stream_inc)
                                    .await?;

                                if data_end_stream {
                                    break;
                                }
                            }
                            http2::FrameType::RstStream => {
                                validate_rst_stream_payload(f.header.stream_id, &f.payload)?;
                                if f.header.stream_id == stream_id {
                                    stream_reset = true;
                                    break;
                                }
                            }
                            http2::FrameType::PushPromise => {
                                return Err(http2::Http2Error::Protocol(
                                    "PUSH_PROMISE not supported by server",
                                )
                                .into());
                            }
                            // Connection-level frames are handled inline; UNKNOWN
                            // frames are ignored.
                            http2::FrameType::Settings
                            | http2::FrameType::Ping
                            | http2::FrameType::Goaway
                            | http2::FrameType::WindowUpdate
                            | http2::FrameType::Priority
                            | http2::FrameType::Unknown => {
                                if f.header.frame_type() == http2::FrameType::Goaway {
                                    validate_goaway_payload(&f.payload)?;
                                    return Ok(());
                                }
                                if f.header.frame_type() == http2::FrameType::Priority {
                                    validate_priority_payload(f.header.stream_id, &f.payload)?;
                                }
                                if f.header.frame_type() == http2::FrameType::WindowUpdate {
                                    validate_window_update_payload(&f.payload)?;
                                    // 31-bit increment; the reserved high bit is masked off.
                                    let increment = u32::from_be_bytes([
                                        f.payload[0],
                                        f.payload[1],
                                        f.payload[2],
                                        f.payload[3],
                                    ]) & 0x7FFF_FFFF;
                                    // Stream-level updates are validated but not applied here.
                                    if f.header.stream_id == 0 {
                                        apply_send_conn_window_update(
                                            &mut flow_control,
                                            increment,
                                        )?;
                                    }
                                }
                                if f.header.frame_type() == http2::FrameType::Ping {
                                    if f.header.stream_id != 0 || f.payload.len() != 8 {
                                        return Err(http2::Http2Error::Protocol(
                                            "invalid PING frame",
                                        )
                                        .into());
                                    }
                                    if (f.header.flags & FLAG_ACK) == 0 {
                                        framed
                                            .write_frame(
                                                http2::FrameType::Ping,
                                                FLAG_ACK,
                                                0,
                                                &f.payload,
                                            )
                                            .await?;
                                    }
                                }
                                if f.header.frame_type() == http2::FrameType::Settings {
                                    let is_ack = validate_settings_frame(
                                        f.header.stream_id,
                                        f.header.flags,
                                        &f.payload,
                                    )?;
                                    if !is_ack {
                                        apply_http2_settings_with_fc(
                                            &mut hpack,
                                            &mut peer_max_frame_size,
                                            Some(&mut flow_control),
                                            &f.payload,
                                        )?;
                                        framed
                                            .write_frame(
                                                http2::FrameType::Settings,
                                                FLAG_ACK,
                                                0,
                                                &[],
                                            )
                                            .await?;
                                    }
                                }
                            }
                            // Includes DATA/HEADERS for other streams and stray CONTINUATION.
                            _ => {
                                return Err(http2::Http2Error::Protocol(
                                    "unsupported frame while reading request body",
                                )
                                .into());
                            }
                        }
                    }
                    // The client reset this stream: drop it without responding.
                    if stream_reset {
                        continue;
                    }
                    request.set_body(fastapi_core::Body::Bytes(body));
                }

                let request_id = request_counter.fetch_add(1, Ordering::Relaxed);
                // NOTE(review): a `for_testing` Cx constructor is used on the serving
                // path; confirm this is intended.
                let request_budget = Budget::new().with_deadline(config.request_timeout);
                let request_cx = Cx::for_testing_with_budget(request_budget);
                let ctx = RequestContext::new(request_cx, request_id);

                // Unlike HTTP/1.x, a rejected stream leaves the connection open.
                if let Err(err) = validate_host_header(&request, config) {
                    let response = err.response();
                    process_connection_http2_write_response(
                        &mut framed,
                        response,
                        stream_id,
                        peer_max_frame_size,
                        recv_max_frame_size,
                        Some(&mut flow_control),
                    )
                    .await?;
                    continue;
                }

                if let Err(response) = config.pre_body_validators.validate_all(&request) {
                    process_connection_http2_write_response(
                        &mut framed,
                        response,
                        stream_id,
                        peer_max_frame_size,
                        recv_max_frame_size,
                        Some(&mut flow_control),
                    )
                    .await?;
                    continue;
                }

                let response = handler(ctx, &mut request).await;
                process_connection_http2_write_response(
                    &mut framed,
                    response,
                    stream_id,
                    peer_max_frame_size,
                    recv_max_frame_size,
                    Some(&mut flow_control),
                )
                .await?;

                // Background tasks run after the response has been written.
                if let Some(tasks) = App::take_background_tasks(&mut request) {
                    tasks.execute_all().await;
                }
            }
            http2::FrameType::WindowUpdate => {
                validate_window_update_payload(&frame.payload)?;
                // 31-bit increment; the reserved high bit is masked off.
                let increment = u32::from_be_bytes([
                    frame.payload[0],
                    frame.payload[1],
                    frame.payload[2],
                    frame.payload[3],
                ]) & 0x7FFF_FFFF;
                // Only connection-level (stream 0) updates are applied here.
                if frame.header.stream_id == 0 {
                    apply_send_conn_window_update(&mut flow_control, increment)?;
                }
            }
            // Remaining frame types outside a request are delegated.
            _ => {
                handle_h2_idle_frame(&frame)?;
            }
        }
    }
}
1220
/// Writes `response` on HTTP/2 stream `stream_id`.
///
/// The headers go in one HEADERS frame, split into CONTINUATION frames when the
/// encoded block exceeds `peer_max_frame_size`. DATA frames follow, and
/// END_STREAM is set on the final frame.
///
/// * The body is dropped for statuses that do not allow one.
/// * `content-length` is added for `Bytes` bodies unless already present.
/// * Header names are lowercased and HPACK-encoded as literals without indexing.
///   Names rejected by `is_h2_forbidden_header_name` are skipped.
/// * DATA chunks are sized by `peer_max_frame_size` and passed through
///   `h2_fc_clamp_send`. That function also receives `recv_max_frame_size` and
///   may update `peer_max_frame_size` — presumably while waiting for
///   WINDOW_UPDATE frames; confirm.
/// * A streaming body is forwarded chunk by chunk and terminated with an empty
///   DATA frame that carries END_STREAM.
///
/// # Errors
///
/// Propagates frame write errors and errors from `h2_fc_clamp_send`.
async fn process_connection_http2_write_response(
    framed: &mut http2::FramedH2,
    response: Response,
    stream_id: u32,
    mut peer_max_frame_size: u32,
    recv_max_frame_size: u32,
    mut flow_control: Option<&mut http2::H2FlowControl>,
) -> Result<(), ServerError> {
    use std::future::poll_fn;

    const FLAG_END_STREAM: u8 = 0x1;
    const FLAG_END_HEADERS: u8 = 0x4;

    let (status, mut headers, mut body) = response.into_parts();
    if !status.allows_body() {
        body = fastapi_core::ResponseBody::Empty;
    }

    // Add content-length for fixed-size bodies unless the handler set one.
    let mut add_content_length = matches!(body, fastapi_core::ResponseBody::Bytes(_));
    for (name, _) in &headers {
        if name.eq_ignore_ascii_case("content-length") {
            add_content_length = false;
            break;
        }
    }
    if add_content_length {
        headers.push((
            "content-length".to_string(),
            body.len().to_string().into_bytes(),
        ));
    }

    // Encode the header block: the :status pseudo-header first, then regular headers.
    let mut block: Vec<u8> = Vec::new();
    let status_bytes = status.as_u16().to_string().into_bytes();
    http2::hpack_encode_literal_without_indexing(&mut block, b":status", &status_bytes);
    for (name, value) in &headers {
        if is_h2_forbidden_header_name(name) {
            continue;
        }
        let n = name.to_ascii_lowercase();
        http2::hpack_encode_literal_without_indexing(&mut block, n.as_bytes(), value);
    }

    let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
    let mut headers_flags = FLAG_END_HEADERS;
    // With no body, the HEADERS frame also ends the stream.
    if body.is_empty() {
        headers_flags |= FLAG_END_STREAM;
    }

    if block.len() <= max {
        framed
            .write_frame(http2::FrameType::Headers, headers_flags, stream_id, &block)
            .await?;
    } else {
        // Fragmented header block: END_STREAM (if any) belongs on the HEADERS
        // frame, END_HEADERS on the last CONTINUATION frame.
        let mut first_flags = 0u8;
        if body.is_empty() {
            first_flags |= FLAG_END_STREAM;
        }
        let (first, rest) = block.split_at(max);
        framed
            .write_frame(http2::FrameType::Headers, first_flags, stream_id, first)
            .await?;
        let mut remaining = rest;
        while remaining.len() > max {
            let (chunk, r) = remaining.split_at(max);
            framed
                .write_frame(http2::FrameType::Continuation, 0, stream_id, chunk)
                .await?;
            remaining = r;
        }
        framed
            .write_frame(
                http2::FrameType::Continuation,
                FLAG_END_HEADERS,
                stream_id,
                remaining,
            )
            .await?;
    }

    // The stream send window starts at the peer's initial window size. Without
    // flow control it is effectively unlimited.
    let mut stream_send_window: i64 = flow_control
        .as_ref()
        .map_or(i64::MAX, |fc| i64::from(fc.peer_initial_window_size()));

    match body {
        fastapi_core::ResponseBody::Empty => Ok(()),
        fastapi_core::ResponseBody::Bytes(bytes) => {
            // END_STREAM was already sent on the HEADERS frame.
            if bytes.is_empty() {
                return Ok(());
            }
            let mut remaining = bytes.as_slice();
            while !remaining.is_empty() {
                let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
                let send_len = remaining.len().min(max);

                let send_len = h2_fc_clamp_send(
                    framed,
                    &mut flow_control,
                    &mut stream_send_window,
                    stream_id,
                    send_len,
                    &mut peer_max_frame_size,
                    recv_max_frame_size,
                )
                .await?;

                let (chunk, r) = remaining.split_at(send_len);
                // The last DATA frame carries END_STREAM.
                let flags = if r.is_empty() { FLAG_END_STREAM } else { 0 };
                framed
                    .write_frame(http2::FrameType::Data, flags, stream_id, chunk)
                    .await?;
                remaining = r;
            }
            Ok(())
        }
        fastapi_core::ResponseBody::Stream(mut s) => {
            loop {
                let next = poll_fn(|cx| Pin::new(&mut s).poll_next(cx)).await;
                match next {
                    Some(chunk) => {
                        let mut remaining = chunk.as_slice();
                        while !remaining.is_empty() {
                            let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
                            let send_len = remaining.len().min(max);
                            let send_len = h2_fc_clamp_send(
                                framed,
                                &mut flow_control,
                                &mut stream_send_window,
                                stream_id,
                                send_len,
                                &mut peer_max_frame_size,
                                recv_max_frame_size,
                            )
                            .await?;

                            let (c, r) = remaining.split_at(send_len);
                            framed
                                .write_frame(http2::FrameType::Data, 0, stream_id, c)
                                .await?;
                            remaining = r;
                        }
                    }
                    // End of stream: the stream length is unknown up front, so
                    // it is closed with an empty END_STREAM DATA frame.
                    None => {
                        framed
                            .write_frame(http2::FrameType::Data, FLAG_END_STREAM, stream_id, &[])
                            .await?;
                        break;
                    }
                }
            }
            Ok(())
        }
    }
}
1377
/// Reserves send capacity for one HTTP/2 DATA frame on `stream_id`.
///
/// Returns how many bytes (at most `desired`) the caller may send right now.
/// When `flow_control` is `None`, `desired` is returned unchanged. Otherwise the
/// amount is clamped to the connection send window, `*stream_send_window` and
/// `*peer_max_frame_size`, and the returned amount is debited from both the
/// connection window and `*stream_send_window` before returning.
///
/// While nothing may be sent, frames are read from the peer and handled inline:
/// - WINDOW_UPDATE: handed to `apply_peer_window_update_for_send` together with
///   the flow-control state and `*stream_send_window`;
/// - PING: validated, and answered with an ACK unless it is itself an ACK;
/// - SETTINGS: validated; non-ACK settings are passed to
///   `apply_peer_settings_for_send` (which receives the stream window and
///   `*peer_max_frame_size` mutably) and then ACKed;
/// - GOAWAY, or RST_STREAM targeting `stream_id`: abort with an error;
/// - any other frame type is ignored.
///
/// NOTE(review): a `desired` of 0 can never satisfy `allowed > 0`, so the loop
/// would keep reading frames until an error occurs; the callers in this file
/// only pass non-zero values as long as the peer max frame size is non-zero.
///
/// # Errors
///
/// Frame read/write failures, malformed PING/SETTINGS/GOAWAY/RST_STREAM frames,
/// a GOAWAY from the peer, or the peer resetting `stream_id`.
async fn h2_fc_clamp_send(
    framed: &mut http2::FramedH2,
    flow_control: &mut Option<&mut http2::H2FlowControl>,
    stream_send_window: &mut i64,
    stream_id: u32,
    desired: usize,
    peer_max_frame_size: &mut u32,
    recv_max_frame_size: u32,
) -> Result<usize, ServerError> {
    // Without flow-control state there is nothing to clamp against.
    let fc = match flow_control.as_mut() {
        Some(fc) => fc,
        None => return Ok(desired),
    };

    loop {
        // Negative windows (possible after a SETTINGS shrink) count as zero.
        let conn_avail = usize::try_from(fc.send_conn_window().max(0)).unwrap_or(0);
        let stream_avail = usize::try_from((*stream_send_window).max(0)).unwrap_or(0);
        let peer_max = usize::try_from(*peer_max_frame_size).unwrap_or(16 * 1024);
        let allowed = desired.min(conn_avail).min(stream_avail).min(peer_max);

        if allowed > 0 {
            // Debit both windows up front so the caller can send exactly `send` bytes.
            let send = allowed;
            fc.consume_send_conn_window(u32::try_from(send).unwrap_or(u32::MAX));
            *stream_send_window -= i64::try_from(send).unwrap_or(i64::MAX);
            return Ok(send);
        }

        // Blocked on flow control: service peer frames until a window opens.
        let frame = framed.read_frame(recv_max_frame_size).await?;
        match frame.header.frame_type() {
            http2::FrameType::WindowUpdate => {
                apply_peer_window_update_for_send(
                    fc,
                    stream_send_window,
                    stream_id,
                    frame.header.stream_id,
                    &frame.payload,
                )?;
            }
            http2::FrameType::Ping => {
                if frame.header.stream_id != 0 || frame.payload.len() != 8 {
                    return Err(ServerError::Http2(http2::Http2Error::Protocol(
                        "invalid PING frame",
                    )));
                }
                // 0x1 is the ACK flag: only answer pings that are not ACKs themselves.
                if frame.header.flags & 0x1 == 0 {
                    framed
                        .write_frame(http2::FrameType::Ping, 0x1, 0, &frame.payload)
                        .await?;
                }
            }
            http2::FrameType::Settings => {
                let is_ack = validate_settings_frame(
                    frame.header.stream_id,
                    frame.header.flags,
                    &frame.payload,
                )?;
                if !is_ack {
                    apply_peer_settings_for_send(
                        fc,
                        stream_send_window,
                        peer_max_frame_size,
                        &frame.payload,
                    )?;
                    // Acknowledge the peer's SETTINGS (empty payload, ACK flag).
                    framed
                        .write_frame(http2::FrameType::Settings, 0x1, 0, &[])
                        .await?;
                }
            }
            http2::FrameType::Goaway => {
                validate_goaway_payload(&frame.payload)?;
                return Err(ServerError::Http2(http2::Http2Error::Protocol(
                    "received GOAWAY while writing response",
                )));
            }
            http2::FrameType::RstStream => {
                validate_rst_stream_payload(frame.header.stream_id, &frame.payload)?;
                // Resets of other streams are ignored; only ours aborts the write.
                if frame.header.stream_id == stream_id {
                    return Err(ServerError::Http2(http2::Http2Error::Protocol(
                        "stream reset by peer during response",
                    )));
                }
            }
            // Frames of any other type are dropped while waiting for window.
            _ => { }
        }
    }
}
1470
/// TCP-based HTTP server: accepts connections and dispatches requests to a
/// handler, a type-erased `Handler`, or an [`App`].
///
/// Tracks request and connection counts, a draining flag, the join handles of
/// connection tasks spawned in concurrent mode, a shutdown controller, and
/// accept/byte metrics.
pub struct TcpServer {
    /// Configuration supplied at construction.
    config: ServerConfig,
    /// Source of request ids (`fetch_add`); also reported as `total_requests`.
    request_counter: Arc<AtomicU64>,
    /// Number of connection slots currently in use.
    connection_counter: Arc<AtomicU64>,
    /// Set by `start_drain`; accept loops stop taking connections once set.
    draining: Arc<AtomicBool>,
    /// Join handles of connection tasks spawned by the concurrent accept loop.
    connection_handles: Mutex<Vec<JoinHandle<()>>>,
    /// Shutdown signal source; receivers come from `subscribe_shutdown`.
    shutdown_controller: Arc<ShutdownController>,
    /// Accepted/rejected/timed-out and byte counters reported by `metrics`.
    metrics_counters: Arc<MetricsCounters>,
}
1490
/// Owner of one already-reserved connection slot.
///
/// Creating the guard does not touch the counter; the slot must have been
/// reserved beforehand. Dropping the guard hands the slot back by decrementing
/// `counter` by one.
struct ConnectionSlotGuard {
    counter: Arc<AtomicU64>,
}

impl ConnectionSlotGuard {
    /// Takes ownership of the slot tracked by `counter`.
    fn new(counter: Arc<AtomicU64>) -> Self {
        ConnectionSlotGuard { counter }
    }
}

impl Drop for ConnectionSlotGuard {
    fn drop(&mut self) {
        let _previous = self.counter.fetch_sub(1, Ordering::Relaxed);
    }
}
1506
1507impl std::fmt::Debug for TcpServer {
1508 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1509 f.debug_struct("TcpServer")
1510 .field("config", &self.config)
1511 .field("request_counter", &self.request_counter)
1512 .field("connection_counter", &self.connection_counter)
1513 .field("draining", &self.draining)
1514 .field(
1515 "connection_handles",
1516 &self.connection_handles.lock().map_or(0, |h| h.len()),
1517 )
1518 .field("shutdown_controller", &self.shutdown_controller)
1519 .field("metrics_counters", &self.metrics_counters)
1520 .finish()
1521 }
1522}
1523
1524impl TcpServer {
1525 #[must_use]
1527 pub fn new(config: ServerConfig) -> Self {
1528 Self {
1529 config,
1530 request_counter: Arc::new(AtomicU64::new(0)),
1531 connection_counter: Arc::new(AtomicU64::new(0)),
1532 draining: Arc::new(AtomicBool::new(false)),
1533 connection_handles: Mutex::new(Vec::new()),
1534 shutdown_controller: Arc::new(ShutdownController::new()),
1535 metrics_counters: Arc::new(MetricsCounters::new()),
1536 }
1537 }
1538
1539 #[must_use]
1541 pub fn config(&self) -> &ServerConfig {
1542 &self.config
1543 }
1544
1545 fn next_request_id(&self) -> u64 {
1547 self.request_counter.fetch_add(1, Ordering::Relaxed)
1548 }
1549
1550 #[must_use]
1552 pub fn current_connections(&self) -> u64 {
1553 self.connection_counter.load(Ordering::Relaxed)
1554 }
1555
1556 #[must_use]
1558 pub fn metrics(&self) -> ServerMetrics {
1559 ServerMetrics {
1560 active_connections: self.connection_counter.load(Ordering::Relaxed),
1561 total_accepted: self.metrics_counters.total_accepted.load(Ordering::Relaxed),
1562 total_rejected: self.metrics_counters.total_rejected.load(Ordering::Relaxed),
1563 total_timed_out: self
1564 .metrics_counters
1565 .total_timed_out
1566 .load(Ordering::Relaxed),
1567 total_requests: self.request_counter.load(Ordering::Relaxed),
1568 bytes_in: self.metrics_counters.bytes_in.load(Ordering::Relaxed),
1569 bytes_out: self.metrics_counters.bytes_out.load(Ordering::Relaxed),
1570 }
1571 }
1572
1573 fn record_bytes_in(&self, n: u64) {
1575 self.metrics_counters
1576 .bytes_in
1577 .fetch_add(n, Ordering::Relaxed);
1578 }
1579
1580 fn record_bytes_out(&self, n: u64) {
1582 self.metrics_counters
1583 .bytes_out
1584 .fetch_add(n, Ordering::Relaxed);
1585 }
1586
1587 fn try_acquire_connection(&self) -> bool {
1592 let max = self.config.max_connections;
1593 if max == 0 {
1594 self.connection_counter.fetch_add(1, Ordering::Relaxed);
1596 self.metrics_counters
1597 .total_accepted
1598 .fetch_add(1, Ordering::Relaxed);
1599 return true;
1600 }
1601
1602 let mut current = self.connection_counter.load(Ordering::Relaxed);
1604 loop {
1605 if current >= max as u64 {
1606 self.metrics_counters
1607 .total_rejected
1608 .fetch_add(1, Ordering::Relaxed);
1609 return false;
1610 }
1611 match self.connection_counter.compare_exchange_weak(
1612 current,
1613 current + 1,
1614 Ordering::AcqRel,
1615 Ordering::Relaxed,
1616 ) {
1617 Ok(_) => {
1618 self.metrics_counters
1619 .total_accepted
1620 .fetch_add(1, Ordering::Relaxed);
1621 return true;
1622 }
1623 Err(actual) => current = actual,
1624 }
1625 }
1626 }
1627
1628 fn release_connection(&self) {
1630 self.connection_counter.fetch_sub(1, Ordering::Relaxed);
1631 }
1632
1633 #[must_use]
1635 pub fn is_draining(&self) -> bool {
1636 self.draining.load(Ordering::Acquire)
1637 }
1638
1639 pub fn start_drain(&self) {
1646 self.draining.store(true, Ordering::Release);
1647 }
1648
1649 pub async fn wait_for_drain(&self, timeout: Duration, poll_interval: Option<Duration>) -> bool {
1659 let start = Instant::now();
1660 let poll_interval = poll_interval.unwrap_or(Duration::from_millis(10));
1661
1662 while self.current_connections() > 0 {
1663 if start.elapsed() >= timeout {
1664 return false;
1665 }
1666 std::thread::sleep(poll_interval);
1676 }
1677 true
1678 }
1679
1680 pub async fn drain(&self) -> u64 {
1688 self.start_drain();
1689 let drained = self.wait_for_drain(self.config.drain_timeout, None).await;
1690 if drained {
1691 0
1692 } else {
1693 self.current_connections()
1694 }
1695 }
1696
1697 #[must_use]
1702 pub fn shutdown_controller(&self) -> &Arc<ShutdownController> {
1703 &self.shutdown_controller
1704 }
1705
1706 #[must_use]
1711 pub fn subscribe_shutdown(&self) -> ShutdownReceiver {
1712 self.shutdown_controller.subscribe()
1713 }
1714
1715 pub fn shutdown(&self) {
1724 self.start_drain();
1725 self.shutdown_controller.shutdown();
1726 }
1727
1728 #[must_use]
1730 pub fn is_shutting_down(&self) -> bool {
1731 self.shutdown_controller.is_shutting_down() || self.is_draining()
1732 }
1733
1734 pub async fn serve_with_shutdown<H, Fut>(
1764 &self,
1765 cx: &Cx,
1766 mut shutdown: ShutdownReceiver,
1767 handler: H,
1768 ) -> Result<GracefulOutcome<()>, ServerError>
1769 where
1770 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1771 Fut: Future<Output = Response> + Send + 'static,
1772 {
1773 let bind_addr = self.config.bind_addr.clone();
1774 let listener = TcpListener::bind(bind_addr).await?;
1775 let local_addr = listener.local_addr()?;
1776
1777 cx.trace(&format!(
1778 "Server listening on {local_addr} (with graceful shutdown)"
1779 ));
1780
1781 let result = self
1783 .accept_loop_with_shutdown(cx, listener, handler, &mut shutdown)
1784 .await;
1785
1786 match result {
1787 Ok(outcome) => {
1788 if outcome.is_shutdown() {
1789 cx.trace("Shutdown signal received, draining connections");
1790 self.start_drain();
1791 self.drain_connection_tasks(cx).await;
1792 }
1793 Ok(outcome)
1794 }
1795 Err(e) => Err(e),
1796 }
1797 }
1798
1799 async fn accept_loop_with_shutdown<H, Fut>(
1801 &self,
1802 cx: &Cx,
1803 listener: TcpListener,
1804 handler: H,
1805 shutdown: &mut ShutdownReceiver,
1806 ) -> Result<GracefulOutcome<()>, ServerError>
1807 where
1808 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1809 Fut: Future<Output = Response> + Send + 'static,
1810 {
1811 let handler = Arc::new(handler);
1812
1813 loop {
1814 if shutdown.is_shutting_down() {
1816 return Ok(GracefulOutcome::ShutdownSignaled);
1817 }
1818 if cx.is_cancel_requested() || self.is_draining() {
1819 return Ok(GracefulOutcome::ShutdownSignaled);
1820 }
1821
1822 let (mut stream, peer_addr) = match listener.accept().await {
1824 Ok(conn) => conn,
1825 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
1826 continue;
1827 }
1828 Err(e) => {
1829 cx.trace(&format!("Accept error: {e}"));
1830 if is_fatal_accept_error(&e) {
1831 self.drain_connection_tasks(cx).await;
1832 return Err(ServerError::Io(e));
1833 }
1834 continue;
1835 }
1836 };
1837
1838 if !self.try_acquire_connection() {
1840 cx.trace(&format!(
1841 "Connection limit reached ({}), rejecting {peer_addr}",
1842 self.config.max_connections
1843 ));
1844
1845 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
1846 .header("connection", b"close".to_vec())
1847 .body(fastapi_core::ResponseBody::Bytes(
1848 b"503 Service Unavailable: connection limit reached".to_vec(),
1849 ));
1850 let mut writer = crate::response::ResponseWriter::new();
1851 let response_bytes = writer.write(response);
1852 let _ = write_response(&mut stream, response_bytes).await;
1853 continue;
1854 }
1855
1856 if self.config.tcp_nodelay {
1858 let _ = stream.set_nodelay(true);
1859 }
1860
1861 cx.trace(&format!(
1862 "Accepted connection from {peer_addr} ({}/{})",
1863 self.current_connections(),
1864 if self.config.max_connections == 0 {
1865 "∞".to_string()
1866 } else {
1867 self.config.max_connections.to_string()
1868 }
1869 ));
1870
1871 let request_id = self.next_request_id();
1872 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
1873 let request_cx = Cx::for_testing_with_budget(request_budget);
1874 let ctx = RequestContext::new(request_cx, request_id);
1875
1876 let result = self
1878 .handle_connection(&ctx, stream, peer_addr, &*handler)
1879 .await;
1880
1881 self.release_connection();
1882
1883 if let Err(e) = result {
1884 cx.trace(&format!("Connection error from {peer_addr}: {e}"));
1885 }
1886 }
1887 }
1888
1889 pub async fn serve<H, Fut>(&self, cx: &Cx, handler: H) -> Result<(), ServerError>
1903 where
1904 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1905 Fut: Future<Output = Response> + Send + 'static,
1906 {
1907 let bind_addr = self.config.bind_addr.clone();
1908 let listener = TcpListener::bind(bind_addr).await?;
1909 let local_addr = listener.local_addr()?;
1910
1911 cx.trace(&format!("Server listening on {local_addr}"));
1912
1913 self.accept_loop(cx, listener, handler).await
1914 }
1915
1916 pub async fn serve_on<H, Fut>(
1920 &self,
1921 cx: &Cx,
1922 listener: TcpListener,
1923 handler: H,
1924 ) -> Result<(), ServerError>
1925 where
1926 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
1927 Fut: Future<Output = Response> + Send + 'static,
1928 {
1929 self.accept_loop(cx, listener, handler).await
1930 }
1931
1932 pub async fn serve_handler(
1953 &self,
1954 cx: &Cx,
1955 handler: Arc<dyn fastapi_core::Handler>,
1956 ) -> Result<(), ServerError> {
1957 let bind_addr = self.config.bind_addr.clone();
1958 let listener = TcpListener::bind(bind_addr).await?;
1959 let local_addr = listener.local_addr()?;
1960
1961 cx.trace(&format!("Server listening on {local_addr}"));
1962
1963 self.accept_loop_handler(cx, listener, handler).await
1964 }
1965
1966 pub async fn serve_app(&self, cx: &Cx, app: Arc<App>) -> Result<(), ServerError> {
1971 let bind_addr = self.config.bind_addr.clone();
1972 let listener = TcpListener::bind(bind_addr).await?;
1973 let local_addr = listener.local_addr()?;
1974
1975 cx.trace(&format!("Server listening on {local_addr}"));
1976 self.accept_loop_app(cx, listener, app).await
1977 }
1978
1979 pub async fn serve_on_handler(
1981 &self,
1982 cx: &Cx,
1983 listener: TcpListener,
1984 handler: Arc<dyn fastapi_core::Handler>,
1985 ) -> Result<(), ServerError> {
1986 self.accept_loop_handler(cx, listener, handler).await
1987 }
1988
1989 pub async fn serve_on_app(
1995 &self,
1996 cx: &Cx,
1997 listener: TcpListener,
1998 app: Arc<App>,
1999 ) -> Result<(), ServerError> {
2000 self.accept_loop_app(cx, listener, app).await
2001 }
2002
2003 async fn accept_loop_app(
2004 &self,
2005 cx: &Cx,
2006 listener: TcpListener,
2007 app: Arc<App>,
2008 ) -> Result<(), ServerError> {
2009 loop {
2010 if cx.is_cancel_requested() {
2011 cx.trace("Server shutdown requested");
2012 return Ok(());
2013 }
2014 if self.is_draining() {
2015 cx.trace("Server draining, stopping accept loop");
2016 return Err(ServerError::Shutdown);
2017 }
2018
2019 let (mut stream, peer_addr) = match listener.accept().await {
2020 Ok(conn) => conn,
2021 Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
2022 Err(e) => {
2023 cx.trace(&format!("Accept error: {e}"));
2024 if is_fatal_accept_error(&e) {
2025 return Err(ServerError::Io(e));
2026 }
2027 continue;
2028 }
2029 };
2030
2031 if !self.try_acquire_connection() {
2032 cx.trace(&format!(
2033 "Connection limit reached ({}), rejecting {peer_addr}",
2034 self.config.max_connections
2035 ));
2036
2037 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
2038 .header("connection", b"close".to_vec())
2039 .body(fastapi_core::ResponseBody::Bytes(
2040 b"503 Service Unavailable: connection limit reached".to_vec(),
2041 ));
2042 let mut writer = crate::response::ResponseWriter::new();
2043 let response_bytes = writer.write(response);
2044 let _ = write_response(&mut stream, response_bytes).await;
2045 continue;
2046 }
2047
2048 if self.config.tcp_nodelay {
2049 let _ = stream.set_nodelay(true);
2050 }
2051
2052 cx.trace(&format!(
2053 "Accepted connection from {peer_addr} ({}/{})",
2054 self.current_connections(),
2055 if self.config.max_connections == 0 {
2056 "∞".to_string()
2057 } else {
2058 self.config.max_connections.to_string()
2059 }
2060 ));
2061
2062 let result = self
2063 .handle_connection_app(cx, stream, peer_addr, app.as_ref())
2064 .await;
2065
2066 self.release_connection();
2067
2068 if let Err(e) = result {
2069 cx.trace(&format!("Connection error from {peer_addr}: {e}"));
2070 }
2071 }
2072 }
2073
2074 async fn accept_loop_handler(
2076 &self,
2077 cx: &Cx,
2078 listener: TcpListener,
2079 handler: Arc<dyn fastapi_core::Handler>,
2080 ) -> Result<(), ServerError> {
2081 loop {
2082 if cx.is_cancel_requested() {
2084 cx.trace("Server shutdown requested");
2085 return Ok(());
2086 }
2087
2088 if self.is_draining() {
2090 cx.trace("Server draining, stopping accept loop");
2091 return Err(ServerError::Shutdown);
2092 }
2093
2094 let (mut stream, peer_addr) = match listener.accept().await {
2096 Ok(conn) => conn,
2097 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
2098 continue;
2099 }
2100 Err(e) => {
2101 cx.trace(&format!("Accept error: {e}"));
2102 if is_fatal_accept_error(&e) {
2103 return Err(ServerError::Io(e));
2104 }
2105 continue;
2106 }
2107 };
2108
2109 if !self.try_acquire_connection() {
2111 cx.trace(&format!(
2112 "Connection limit reached ({}), rejecting {peer_addr}",
2113 self.config.max_connections
2114 ));
2115
2116 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
2117 .header("connection", b"close".to_vec())
2118 .body(fastapi_core::ResponseBody::Bytes(
2119 b"503 Service Unavailable: connection limit reached".to_vec(),
2120 ));
2121 let mut writer = crate::response::ResponseWriter::new();
2122 let response_bytes = writer.write(response);
2123 let _ = write_response(&mut stream, response_bytes).await;
2124 continue;
2125 }
2126
2127 if self.config.tcp_nodelay {
2129 let _ = stream.set_nodelay(true);
2130 }
2131
2132 cx.trace(&format!(
2133 "Accepted connection from {peer_addr} ({}/{})",
2134 self.current_connections(),
2135 if self.config.max_connections == 0 {
2136 "∞".to_string()
2137 } else {
2138 self.config.max_connections.to_string()
2139 }
2140 ));
2141
2142 let result = self
2144 .handle_connection_handler(cx, stream, peer_addr, &*handler)
2145 .await;
2146
2147 self.release_connection();
2148
2149 if let Err(e) = result {
2150 cx.trace(&format!("Connection error from {peer_addr}: {e}"));
2151 }
2152 }
2153 }
2154
2155 #[allow(clippy::too_many_lines)]
2171 pub async fn serve_concurrent<H, Fut>(&self, cx: &Cx, handler: H) -> Result<(), ServerError>
2172 where
2173 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
2174 Fut: Future<Output = Response> + Send + 'static,
2175 {
2176 let bind_addr = self.config.bind_addr.clone();
2177 let listener = TcpListener::bind(bind_addr).await?;
2178 let local_addr = listener.local_addr()?;
2179
2180 cx.trace(&format!(
2181 "Server listening on {local_addr} (concurrent mode)"
2182 ));
2183
2184 let handler = Arc::new(handler);
2185
2186 self.accept_loop_concurrent(cx, listener, handler).await
2187 }
2188
2189 async fn accept_loop_concurrent<H, Fut>(
2191 &self,
2192 cx: &Cx,
2193 listener: TcpListener,
2194 handler: Arc<H>,
2195 ) -> Result<(), ServerError>
2196 where
2197 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
2198 Fut: Future<Output = Response> + Send + 'static,
2199 {
2200 let runtime_handle = Runtime::current_handle()
2201 .expect("serve_concurrent must be called inside an asupersync runtime");
2202 let accept_poll_interval = Duration::from_millis(50);
2203
2204 loop {
2205 self.cleanup_completed_handles(cx).await;
2206
2207 if cx.is_cancel_requested() || self.is_draining() {
2209 cx.trace("Server shutting down, draining connections");
2210 self.drain_connection_tasks(cx).await;
2211 return Ok(());
2212 }
2213
2214 let accept_future = Box::pin(listener.accept());
2217 let (mut stream, peer_addr) =
2218 match timeout(current_time(), accept_poll_interval, accept_future).await {
2219 Ok(Ok(conn)) => conn,
2220 Ok(Err(e)) if e.kind() == io::ErrorKind::WouldBlock => {
2221 continue;
2222 }
2223 Ok(Err(e)) => {
2224 cx.trace(&format!("Accept error: {e}"));
2225 if is_fatal_accept_error(&e) {
2226 return Err(ServerError::Io(e));
2227 }
2228 continue;
2229 }
2230 Err(_elapsed) => continue,
2231 };
2232
2233 if !self.try_acquire_connection() {
2235 cx.trace(&format!(
2236 "Connection limit reached ({}), rejecting {peer_addr}",
2237 self.config.max_connections
2238 ));
2239
2240 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
2241 .header("connection", b"close".to_vec())
2242 .body(fastapi_core::ResponseBody::Bytes(
2243 b"503 Service Unavailable: connection limit reached".to_vec(),
2244 ));
2245 let mut writer = crate::response::ResponseWriter::new();
2246 let response_bytes = writer.write(response);
2247 let _ = write_response(&mut stream, response_bytes).await;
2248 continue;
2249 }
2250
2251 if self.config.tcp_nodelay {
2253 let _ = stream.set_nodelay(true);
2254 }
2255
2256 cx.trace(&format!(
2257 "Accepted connection from {peer_addr} ({}/{})",
2258 self.current_connections(),
2259 if self.config.max_connections == 0 {
2260 "∞".to_string()
2261 } else {
2262 self.config.max_connections.to_string()
2263 }
2264 ));
2265
2266 match self.spawn_connection_task(
2268 &runtime_handle,
2269 cx,
2270 stream,
2271 peer_addr,
2272 Arc::clone(&handler),
2273 ) {
2274 Ok(handle) => {
2275 if let Ok(mut handles) = self.connection_handles.lock() {
2277 handles.push(handle);
2278 }
2279 self.cleanup_completed_handles(cx).await;
2281 }
2282 Err(e) => {
2283 cx.trace(&format!("Failed to spawn connection task: {e:?}"));
2284 }
2285 }
2286 }
2287 }
2288
2289 fn spawn_connection_task<H, Fut>(
2294 &self,
2295 handle: &RuntimeHandle,
2296 cx: &Cx,
2297 stream: TcpStream,
2298 peer_addr: SocketAddr,
2299 handler: Arc<H>,
2300 ) -> Result<JoinHandle<()>, SpawnError>
2301 where
2302 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
2303 Fut: Future<Output = Response> + Send + 'static,
2304 {
2305 let config = self.config.clone();
2306 let connection_cx = cx.clone();
2307 let request_counter = Arc::clone(&self.request_counter);
2308 let connection_counter = Arc::clone(&self.connection_counter);
2309 let connection_slot = ConnectionSlotGuard::new(connection_counter);
2310
2311 handle.try_spawn(async move {
2312 let _connection_slot = connection_slot;
2313 let result = process_connection(
2314 &connection_cx,
2315 &request_counter,
2316 stream,
2317 peer_addr,
2318 &config,
2319 |ctx, req| handler(ctx, req),
2320 )
2321 .await;
2322
2323 if let Err(e) = result {
2324 eprintln!("Connection error from {peer_addr}: {e}");
2326 }
2327 })
2328 }
2329
2330 fn take_finished_connection_handles(&self) -> Vec<JoinHandle<()>> {
2331 if let Ok(mut handles) = self.connection_handles.lock() {
2332 let mut finished = Vec::new();
2333 let mut idx = 0;
2334 while idx < handles.len() {
2335 if handles[idx].is_finished() {
2336 finished.push(handles.swap_remove(idx));
2337 } else {
2338 idx += 1;
2339 }
2340 }
2341 finished
2342 } else {
2343 Vec::new()
2344 }
2345 }
2346
2347 async fn cleanup_completed_handles(&self, cx: &Cx) {
2349 for handle in self.take_finished_connection_handles() {
2350 if let Err(payload) = CatchUnwind::new(handle).await {
2351 let message = panic_payload_message(payload.as_ref());
2352 cx.trace(&format!("Connection task panicked: {message}"));
2353 eprintln!("Connection task panicked: {message}");
2354 }
2355 }
2356 }
2357
2358 async fn drain_connection_tasks(&self, cx: &Cx) {
2360 let drain_timeout = self.config.drain_timeout;
2361 let start = Instant::now();
2362
2363 cx.trace(&format!(
2364 "Draining {} connection tasks (timeout: {:?})",
2365 self.connection_handles.lock().map_or(0, |h| h.len()),
2366 drain_timeout
2367 ));
2368
2369 while start.elapsed() < drain_timeout {
2371 self.cleanup_completed_handles(cx).await;
2372
2373 let remaining = self
2374 .connection_handles
2375 .lock()
2376 .map_or(0, |h| h.iter().filter(|t| !t.is_finished()).count());
2377
2378 if remaining == 0 {
2379 self.cleanup_completed_handles(cx).await;
2380 cx.trace("All connection tasks drained successfully");
2381 return;
2382 }
2383
2384 asupersync::runtime::yield_now().await;
2386 }
2387
2388 self.cleanup_completed_handles(cx).await;
2389 cx.trace(&format!(
2390 "Drain timeout reached with {} tasks still running; lingering connection tasks will continue in the background",
2391 self.connection_handles
2392 .lock()
2393 .map_or(0, |h| h.iter().filter(|t| !t.is_finished()).count())
2394 ));
2395 }
2396
2397 async fn handle_connection_app(
2398 &self,
2399 cx: &Cx,
2400 mut stream: TcpStream,
2401 peer_addr: SocketAddr,
2402 app: &App,
2403 ) -> Result<(), ServerError> {
2404 let (proto, buffered) = sniff_protocol(&mut stream, self.config.keep_alive_timeout).await?;
2405 if !buffered.is_empty() {
2406 self.record_bytes_in(buffered.len() as u64);
2407 }
2408
2409 if proto == SniffedProtocol::Http2PriorKnowledge {
2410 return self
2411 .handle_connection_app_http2(cx, stream, peer_addr, app)
2412 .await;
2413 }
2414
2415 let mut parser = StatefulParser::new().with_limits(self.config.parse_limits.clone());
2416 if !buffered.is_empty() {
2417 parser.feed(&buffered)?;
2418 }
2419 let mut read_buffer = vec![0u8; self.config.read_buffer_size];
2420 let mut response_writer = ResponseWriter::new();
2421 let mut requests_on_connection: usize = 0;
2422 let max_requests = self.config.max_requests_per_connection;
2423
2424 loop {
2425 if cx.is_cancel_requested() {
2426 return Ok(());
2427 }
2428
2429 let parse_result = parser.feed(&[])?;
2430 let mut request = match parse_result {
2431 ParseStatus::Complete { request, .. } => request,
2432 ParseStatus::Incomplete => {
2433 let keep_alive_timeout = self.config.keep_alive_timeout;
2434 let bytes_read = if keep_alive_timeout.is_zero() {
2435 read_into_buffer(&mut stream, &mut read_buffer).await?
2436 } else {
2437 match read_with_timeout(&mut stream, &mut read_buffer, keep_alive_timeout)
2438 .await
2439 {
2440 Ok(0) => return Ok(()),
2441 Ok(n) => n,
2442 Err(e) if e.kind() == io::ErrorKind::TimedOut => {
2443 self.metrics_counters
2444 .total_timed_out
2445 .fetch_add(1, Ordering::Relaxed);
2446 return Err(ServerError::KeepAliveTimeout);
2447 }
2448 Err(e) => return Err(ServerError::Io(e)),
2449 }
2450 };
2451
2452 if bytes_read == 0 {
2453 return Ok(());
2454 }
2455
2456 self.record_bytes_in(bytes_read as u64);
2457
2458 match parser.feed(&read_buffer[..bytes_read])? {
2459 ParseStatus::Complete { request, .. } => request,
2460 ParseStatus::Incomplete => continue,
2461 }
2462 }
2463 };
2464
2465 requests_on_connection += 1;
2466
2467 let request_id = self.request_counter.fetch_add(1, Ordering::Relaxed);
2468
2469 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
2471 let request_cx = Cx::for_testing_with_budget(request_budget);
2472 let overrides = app.dependency_overrides();
2473 let ctx = RequestContext::with_overrides_and_body_limit(
2474 request_cx,
2475 request_id,
2476 overrides,
2477 app.config().max_body_size,
2478 );
2479
2480 if let Err(err) = validate_host_header(&request, &self.config) {
2482 ctx.trace(&format!(
2483 "Rejecting request from {peer_addr}: {}",
2484 err.detail
2485 ));
2486 let response = err.response().header("connection", b"close".to_vec());
2487 let response_write = response_writer.write(response);
2488 write_response(&mut stream, response_write).await?;
2489 return Ok(());
2490 }
2491
2492 if let Err(response) = self.config.pre_body_validators.validate_all(&request) {
2494 let response = response.header("connection", b"close".to_vec());
2495 let response_write = response_writer.write(response);
2496 write_response(&mut stream, response_write).await?;
2497 return Ok(());
2498 }
2499
2500 if is_websocket_upgrade_request(&request)
2505 && app.websocket_route_count() > 0
2506 && app.has_websocket_route(request.path())
2507 {
2508 if has_request_body_headers(&request) {
2510 let response = Response::with_status(StatusCode::BAD_REQUEST)
2511 .header("connection", b"close".to_vec())
2512 .body(fastapi_core::ResponseBody::Bytes(
2513 b"Bad Request: websocket handshake must not include a body".to_vec(),
2514 ));
2515 let response_write = response_writer.write(response);
2516 write_response(&mut stream, response_write).await?;
2517 return Ok(());
2518 }
2519
2520 let Some(key) = header_str(&request, "sec-websocket-key") else {
2521 let response = Response::with_status(StatusCode::BAD_REQUEST)
2522 .header("connection", b"close".to_vec())
2523 .body(fastapi_core::ResponseBody::Bytes(
2524 b"Bad Request: missing Sec-WebSocket-Key".to_vec(),
2525 ));
2526 let response_write = response_writer.write(response);
2527 write_response(&mut stream, response_write).await?;
2528 return Ok(());
2529 };
2530 let accept = match fastapi_core::websocket_accept_from_key(key) {
2531 Ok(v) => v,
2532 Err(_) => {
2533 let response = Response::with_status(StatusCode::BAD_REQUEST)
2534 .header("connection", b"close".to_vec())
2535 .body(fastapi_core::ResponseBody::Bytes(
2536 b"Bad Request: invalid Sec-WebSocket-Key".to_vec(),
2537 ));
2538 let response_write = response_writer.write(response);
2539 write_response(&mut stream, response_write).await?;
2540 return Ok(());
2541 }
2542 };
2543
2544 if header_str(&request, "sec-websocket-version") != Some("13") {
2545 let response = Response::with_status(StatusCode::BAD_REQUEST)
2546 .header("sec-websocket-version", b"13".to_vec())
2547 .header("connection", b"close".to_vec())
2548 .body(fastapi_core::ResponseBody::Bytes(
2549 b"Bad Request: unsupported Sec-WebSocket-Version".to_vec(),
2550 ));
2551 let response_write = response_writer.write(response);
2552 write_response(&mut stream, response_write).await?;
2553 return Ok(());
2554 }
2555
2556 let response = Response::with_status(StatusCode::SWITCHING_PROTOCOLS)
2557 .header("upgrade", b"websocket".to_vec())
2558 .header("connection", b"Upgrade".to_vec())
2559 .header("sec-websocket-accept", accept.into_bytes());
2560 let response_write = response_writer.write(response);
2561 if let ResponseWrite::Full(ref bytes) = response_write {
2562 self.record_bytes_out(bytes.len() as u64);
2563 }
2564 write_response(&mut stream, response_write).await?;
2565
2566 let buffered = parser.take_buffered();
2568
2569 let ws_root_cx = Cx::for_testing_with_budget(Budget::new());
2571 let ws_ctx = RequestContext::with_overrides_and_body_limit(
2572 ws_root_cx,
2573 request_id,
2574 app.dependency_overrides(),
2575 app.config().max_body_size,
2576 );
2577
2578 let ws = fastapi_core::WebSocket::new(stream, buffered);
2579 let _ = app.handle_websocket(&ws_ctx, &mut request, ws).await;
2580 return Ok(());
2581 }
2582
2583 match ExpectHandler::check_expect(&request) {
2585 ExpectResult::NoExpectation => {}
2586 ExpectResult::ExpectsContinue => {
2587 ctx.trace("Sending 100 Continue for Expect: 100-continue");
2588 write_raw_response(&mut stream, CONTINUE_RESPONSE).await?;
2589 }
2590 ExpectResult::UnknownExpectation(value) => {
2591 ctx.trace(&format!("Rejecting unknown Expect value: {}", value));
2592 let response = ExpectHandler::expectation_failed(format!(
2593 "Unsupported Expect value: {value}"
2594 ));
2595 let response_write = response_writer.write(response);
2596 write_response(&mut stream, response_write).await?;
2597 return Ok(());
2598 }
2599 }
2600
2601 let client_wants_keep_alive = should_keep_alive(&request);
2602 let server_will_keep_alive = client_wants_keep_alive
2603 && (max_requests == 0 || requests_on_connection < max_requests);
2604
2605 let request_start = Instant::now();
2606 let timeout_duration = Duration::from_nanos(self.config.request_timeout.as_nanos());
2607
2608 let response = app.handle(&ctx, &mut request).await;
2609 let mut response = if request_start.elapsed() > timeout_duration {
2610 Response::with_status(StatusCode::GATEWAY_TIMEOUT).body(
2611 fastapi_core::ResponseBody::Bytes(
2612 b"Gateway Timeout: request processing exceeded time limit".to_vec(),
2613 ),
2614 )
2615 } else {
2616 response
2617 };
2618
2619 response = if server_will_keep_alive {
2620 response.header("connection", b"keep-alive".to_vec())
2621 } else {
2622 response.header("connection", b"close".to_vec())
2623 };
2624
2625 let response_write = response_writer.write(response);
2626 if let ResponseWrite::Full(ref bytes) = response_write {
2627 self.record_bytes_out(bytes.len() as u64);
2628 }
2629 write_response(&mut stream, response_write).await?;
2630
2631 if let Some(tasks) = App::take_background_tasks(&mut request) {
2632 tasks.execute_all().await;
2633 }
2634
2635 if !server_will_keep_alive {
2636 return Ok(());
2637 }
2638 }
2639 }
2640
2641 async fn handle_connection_app_http2(
2642 &self,
2643 cx: &Cx,
2644 stream: TcpStream,
2645 _peer_addr: SocketAddr,
2646 app: &App,
2647 ) -> Result<(), ServerError> {
2648 const FLAG_END_STREAM: u8 = 0x1;
2649 const FLAG_END_HEADERS: u8 = 0x4;
2650 const FLAG_ACK: u8 = 0x1;
2651
2652 let mut framed = http2::FramedH2::new(stream, Vec::new());
2653 let mut hpack = http2::HpackDecoder::new();
2654 let recv_max_frame_size: u32 = 16 * 1024; let mut peer_max_frame_size: u32 = 16 * 1024;
2656 let mut flow_control = http2::H2FlowControl::new();
2657
2658 let first = framed.read_frame(recv_max_frame_size).await?;
2659 self.record_bytes_in((http2::FrameHeader::LEN + first.payload.len()) as u64);
2660
2661 if first.header.frame_type() != http2::FrameType::Settings
2662 || first.header.stream_id != 0
2663 || (first.header.flags & FLAG_ACK) != 0
2664 {
2665 return Err(
2666 http2::Http2Error::Protocol("expected client SETTINGS after preface").into(),
2667 );
2668 }
2669
2670 apply_http2_settings_with_fc(
2671 &mut hpack,
2672 &mut peer_max_frame_size,
2673 Some(&mut flow_control),
2674 &first.payload,
2675 )?;
2676
2677 framed
2679 .write_frame(http2::FrameType::Settings, 0, 0, SERVER_SETTINGS_PAYLOAD)
2680 .await?;
2681 self.record_bytes_out(http2::FrameHeader::LEN as u64);
2682
2683 framed
2684 .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
2685 .await?;
2686 self.record_bytes_out(http2::FrameHeader::LEN as u64);
2687 let mut last_stream_id: u32 = 0;
2688
2689 loop {
2690 if cx.is_cancel_requested() {
2691 let _ = send_goaway(&mut framed, last_stream_id, h2_error_code::NO_ERROR).await;
2692 return Ok(());
2693 }
2694
2695 let frame = framed.read_frame(recv_max_frame_size).await?;
2696 self.record_bytes_in((http2::FrameHeader::LEN + frame.payload.len()) as u64);
2697
2698 match frame.header.frame_type() {
2699 http2::FrameType::Settings => {
2700 let is_ack = validate_settings_frame(
2701 frame.header.stream_id,
2702 frame.header.flags,
2703 &frame.payload,
2704 )?;
2705 if is_ack {
2706 continue;
2708 }
2709 apply_http2_settings_with_fc(
2710 &mut hpack,
2711 &mut peer_max_frame_size,
2712 Some(&mut flow_control),
2713 &frame.payload,
2714 )?;
2715 framed
2717 .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
2718 .await?;
2719 self.record_bytes_out(http2::FrameHeader::LEN as u64);
2720 }
2721 http2::FrameType::Ping => {
2722 if frame.header.stream_id != 0 || frame.payload.len() != 8 {
2724 return Err(http2::Http2Error::Protocol("invalid PING frame").into());
2725 }
2726 if (frame.header.flags & FLAG_ACK) == 0 {
2727 framed
2728 .write_frame(http2::FrameType::Ping, FLAG_ACK, 0, &frame.payload)
2729 .await?;
2730 self.record_bytes_out((http2::FrameHeader::LEN + 8) as u64);
2731 }
2732 }
2733 http2::FrameType::Goaway => {
2734 validate_goaway_payload(&frame.payload)?;
2735 return Ok(());
2736 }
2737 http2::FrameType::PushPromise => {
2738 return Err(http2::Http2Error::Protocol(
2739 "PUSH_PROMISE not supported by server",
2740 )
2741 .into());
2742 }
2743 http2::FrameType::Headers => {
2744 let stream_id = frame.header.stream_id;
2745 if stream_id == 0 {
2746 return Err(
2747 http2::Http2Error::Protocol("HEADERS must not be on stream 0").into(),
2748 );
2749 }
2750 if stream_id % 2 == 0 {
2751 return Err(http2::Http2Error::Protocol(
2752 "client-initiated stream ID must be odd",
2753 )
2754 .into());
2755 }
2756 if stream_id <= last_stream_id {
2757 return Err(http2::Http2Error::Protocol(
2758 "stream ID must be greater than previous",
2759 )
2760 .into());
2761 }
2762 last_stream_id = stream_id;
2763 let (end_stream, mut header_block) =
2764 extract_header_block_fragment(frame.header.flags, &frame.payload)?;
2765
2766 if (frame.header.flags & FLAG_END_HEADERS) == 0 {
2768 loop {
2769 let cont = framed.read_frame(recv_max_frame_size).await?;
2770 self.record_bytes_in(
2771 (http2::FrameHeader::LEN + cont.payload.len()) as u64,
2772 );
2773 if cont.header.frame_type() != http2::FrameType::Continuation
2774 || cont.header.stream_id != stream_id
2775 {
2776 return Err(http2::Http2Error::Protocol(
2777 "expected CONTINUATION for header block",
2778 )
2779 .into());
2780 }
2781 header_block.extend_from_slice(&cont.payload);
2782 if header_block.len() > MAX_HEADER_BLOCK_SIZE {
2783 return Err(http2::Http2Error::Protocol(
2784 "header block exceeds maximum size",
2785 )
2786 .into());
2787 }
2788 if (cont.header.flags & FLAG_END_HEADERS) != 0 {
2789 break;
2790 }
2791 }
2792 }
2793
2794 let headers = hpack
2795 .decode(&header_block)
2796 .map_err(http2::Http2Error::from)?;
2797 let mut request = request_from_h2_headers(headers)?;
2798 request.set_version(fastapi_core::HttpVersion::Http2);
2799
2800 if !end_stream {
2802 let max = app.config().max_body_size;
2803 let mut body = Vec::new();
2804 let mut stream_reset = false;
2805 let mut stream_received: u32 = 0;
2806 loop {
2807 let f = framed.read_frame(recv_max_frame_size).await?;
2808 self.record_bytes_in(
2809 (http2::FrameHeader::LEN + f.payload.len()) as u64,
2810 );
2811 match f.header.frame_type() {
2812 http2::FrameType::Data if f.header.stream_id == 0 => {
2813 return Err(http2::Http2Error::Protocol(
2814 "DATA must not be on stream 0",
2815 )
2816 .into());
2817 }
2818 http2::FrameType::Data if f.header.stream_id == stream_id => {
2819 let (data, data_end_stream) =
2820 extract_data_payload(f.header.flags, &f.payload)?;
2821 if body.len().saturating_add(data.len()) > max {
2822 return Err(http2::Http2Error::Protocol(
2823 "request body exceeds configured max_body_size",
2824 )
2825 .into());
2826 }
2827 body.extend_from_slice(data);
2828
2829 let data_len = u32::try_from(data.len()).unwrap_or(u32::MAX);
2832 stream_received += data_len;
2833 let conn_inc = flow_control.data_received_connection(data_len);
2834 let stream_inc =
2835 flow_control.stream_window_update(stream_received);
2836 if stream_inc > 0 {
2837 stream_received = 0;
2838 }
2839 send_window_updates(
2840 &mut framed,
2841 conn_inc,
2842 stream_id,
2843 stream_inc,
2844 )
2845 .await?;
2846
2847 if data_end_stream {
2848 break;
2849 }
2850 }
2851 http2::FrameType::RstStream => {
2852 validate_rst_stream_payload(f.header.stream_id, &f.payload)?;
2853 if f.header.stream_id == stream_id {
2854 stream_reset = true;
2855 break;
2856 }
2857 }
2858 http2::FrameType::PushPromise => {
2859 return Err(http2::Http2Error::Protocol(
2860 "PUSH_PROMISE not supported by server",
2861 )
2862 .into());
2863 }
2864 http2::FrameType::Settings
2865 | http2::FrameType::Ping
2866 | http2::FrameType::Goaway
2867 | http2::FrameType::WindowUpdate
2868 | http2::FrameType::Priority
2869 | http2::FrameType::Unknown => {
2870 if f.header.frame_type() == http2::FrameType::Goaway {
2871 validate_goaway_payload(&f.payload)?;
2872 return Ok(());
2873 }
2874 if f.header.frame_type() == http2::FrameType::Priority {
2875 validate_priority_payload(f.header.stream_id, &f.payload)?;
2876 }
2877 if f.header.frame_type() == http2::FrameType::WindowUpdate {
2878 validate_window_update_payload(&f.payload)?;
2879 let increment = u32::from_be_bytes([
2880 f.payload[0],
2881 f.payload[1],
2882 f.payload[2],
2883 f.payload[3],
2884 ]) & 0x7FFF_FFFF;
2885 if f.header.stream_id == 0 {
2886 apply_send_conn_window_update(
2887 &mut flow_control,
2888 increment,
2889 )?;
2890 }
2891 }
2892 if f.header.frame_type() == http2::FrameType::Ping {
2893 if f.header.stream_id != 0 || f.payload.len() != 8 {
2894 return Err(http2::Http2Error::Protocol(
2895 "invalid PING frame",
2896 )
2897 .into());
2898 }
2899 if (f.header.flags & FLAG_ACK) == 0 {
2900 framed
2901 .write_frame(
2902 http2::FrameType::Ping,
2903 FLAG_ACK,
2904 0,
2905 &f.payload,
2906 )
2907 .await?;
2908 self.record_bytes_out(
2909 (http2::FrameHeader::LEN + 8) as u64,
2910 );
2911 }
2912 }
2913 if f.header.frame_type() == http2::FrameType::Settings {
2914 let is_ack = validate_settings_frame(
2915 f.header.stream_id,
2916 f.header.flags,
2917 &f.payload,
2918 )?;
2919 if !is_ack {
2920 apply_http2_settings_with_fc(
2921 &mut hpack,
2922 &mut peer_max_frame_size,
2923 Some(&mut flow_control),
2924 &f.payload,
2925 )?;
2926 framed
2927 .write_frame(
2928 http2::FrameType::Settings,
2929 FLAG_ACK,
2930 0,
2931 &[],
2932 )
2933 .await?;
2934 self.record_bytes_out(http2::FrameHeader::LEN as u64);
2935 }
2936 }
2937 }
2938 _ => {
2939 return Err(http2::Http2Error::Protocol(
2940 "unsupported frame while reading request body",
2941 )
2942 .into());
2943 }
2944 }
2945 }
2946 if stream_reset {
2947 continue;
2948 }
2949 request.set_body(fastapi_core::Body::Bytes(body));
2950 }
2951
2952 let request_id = self.request_counter.fetch_add(1, Ordering::Relaxed);
2953 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
2954 let request_cx = Cx::for_testing_with_budget(request_budget);
2955 let overrides = app.dependency_overrides();
2956 let ctx = RequestContext::with_overrides_and_body_limit(
2957 request_cx,
2958 request_id,
2959 overrides,
2960 app.config().max_body_size,
2961 );
2962
2963 if let Err(err) = validate_host_header(&request, &self.config) {
2964 ctx.trace(&format!("Rejecting HTTP/2 request: {}", err.detail));
2965 let response = err.response();
2966 self.write_h2_response(
2967 &mut framed,
2968 response,
2969 stream_id,
2970 peer_max_frame_size,
2971 recv_max_frame_size,
2972 Some(&mut flow_control),
2973 )
2974 .await?;
2975 continue;
2976 }
2977
2978 if let Err(response) = self.config.pre_body_validators.validate_all(&request) {
2979 self.write_h2_response(
2980 &mut framed,
2981 response,
2982 stream_id,
2983 peer_max_frame_size,
2984 recv_max_frame_size,
2985 Some(&mut flow_control),
2986 )
2987 .await?;
2988 continue;
2989 }
2990
2991 let response = app.handle(&ctx, &mut request).await;
2992
2993 self.write_h2_response(
2995 &mut framed,
2996 response,
2997 stream_id,
2998 peer_max_frame_size,
2999 recv_max_frame_size,
3000 Some(&mut flow_control),
3001 )
3002 .await?;
3003
3004 if let Some(tasks) = App::take_background_tasks(&mut request) {
3005 tasks.execute_all().await;
3006 }
3007
3008 asupersync::runtime::yield_now().await;
3010 }
3011 http2::FrameType::WindowUpdate => {
3012 validate_window_update_payload(&frame.payload)?;
3013 let increment = u32::from_be_bytes([
3014 frame.payload[0],
3015 frame.payload[1],
3016 frame.payload[2],
3017 frame.payload[3],
3018 ]) & 0x7FFF_FFFF;
3019 if frame.header.stream_id == 0 {
3020 apply_send_conn_window_update(&mut flow_control, increment)?;
3021 }
3022 }
3023 _ => {
3024 handle_h2_idle_frame(&frame)?;
3025 }
3026 }
3027 }
3028 }
3029
3030 async fn write_h2_response(
3031 &self,
3032 framed: &mut http2::FramedH2,
3033 response: Response,
3034 stream_id: u32,
3035 mut peer_max_frame_size: u32,
3036 recv_max_frame_size: u32,
3037 mut flow_control: Option<&mut http2::H2FlowControl>,
3038 ) -> Result<(), ServerError> {
3039 use std::future::poll_fn;
3040
3041 const FLAG_END_STREAM: u8 = 0x1;
3042 const FLAG_END_HEADERS: u8 = 0x4;
3043
3044 let (status, mut headers, mut body) = response.into_parts();
3045 if !status.allows_body() {
3046 body = fastapi_core::ResponseBody::Empty;
3047 }
3048
3049 let mut add_content_length = matches!(body, fastapi_core::ResponseBody::Bytes(_));
3050 for (name, _) in &headers {
3051 if name.eq_ignore_ascii_case("content-length") {
3052 add_content_length = false;
3053 break;
3054 }
3055 }
3056
3057 if add_content_length {
3058 let len = body.len();
3059 headers.push(("content-length".to_string(), len.to_string().into_bytes()));
3060 }
3061
3062 let mut block: Vec<u8> = Vec::new();
3064 let status_bytes = status.as_u16().to_string().into_bytes();
3065 http2::hpack_encode_literal_without_indexing(&mut block, b":status", &status_bytes);
3066
3067 for (name, value) in &headers {
3068 if is_h2_forbidden_header_name(name) {
3069 continue;
3070 }
3071 let n = name.to_ascii_lowercase();
3072 http2::hpack_encode_literal_without_indexing(&mut block, n.as_bytes(), value);
3073 }
3074
3075 let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
3077 if block.len() <= max {
3078 let mut flags = FLAG_END_HEADERS;
3079 if body.is_empty() {
3080 flags |= FLAG_END_STREAM;
3081 }
3082 framed
3083 .write_frame(http2::FrameType::Headers, flags, stream_id, &block)
3084 .await?;
3085 self.record_bytes_out((http2::FrameHeader::LEN + block.len()) as u64);
3086 } else {
3087 let mut flags = 0u8;
3088 if body.is_empty() {
3089 flags |= FLAG_END_STREAM;
3090 }
3091 let (first, rest) = block.split_at(max);
3092 framed
3093 .write_frame(http2::FrameType::Headers, flags, stream_id, first)
3094 .await?;
3095 self.record_bytes_out((http2::FrameHeader::LEN + first.len()) as u64);
3096
3097 let mut remaining = rest;
3098 while remaining.len() > max {
3099 let (chunk, r) = remaining.split_at(max);
3100 framed
3101 .write_frame(http2::FrameType::Continuation, 0, stream_id, chunk)
3102 .await?;
3103 self.record_bytes_out((http2::FrameHeader::LEN + chunk.len()) as u64);
3104 remaining = r;
3105 }
3106 framed
3107 .write_frame(
3108 http2::FrameType::Continuation,
3109 FLAG_END_HEADERS,
3110 stream_id,
3111 remaining,
3112 )
3113 .await?;
3114 self.record_bytes_out((http2::FrameHeader::LEN + remaining.len()) as u64);
3115 }
3116
3117 let mut stream_send_window: i64 = flow_control
3119 .as_ref()
3120 .map_or(i64::MAX, |fc| i64::from(fc.peer_initial_window_size()));
3121
3122 match body {
3124 fastapi_core::ResponseBody::Empty => Ok(()),
3125 fastapi_core::ResponseBody::Bytes(bytes) => {
3126 if bytes.is_empty() {
3127 return Ok(());
3128 }
3129 let mut remaining = bytes.as_slice();
3130 while !remaining.is_empty() {
3131 let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
3132 let send_len = remaining.len().min(max);
3133 let send_len = h2_fc_clamp_send(
3134 framed,
3135 &mut flow_control,
3136 &mut stream_send_window,
3137 stream_id,
3138 send_len,
3139 &mut peer_max_frame_size,
3140 recv_max_frame_size,
3141 )
3142 .await?;
3143
3144 let (chunk, r) = remaining.split_at(send_len);
3145 let flags = if r.is_empty() { FLAG_END_STREAM } else { 0 };
3146 framed
3147 .write_frame(http2::FrameType::Data, flags, stream_id, chunk)
3148 .await?;
3149 self.record_bytes_out((http2::FrameHeader::LEN + chunk.len()) as u64);
3150 remaining = r;
3151 }
3152 Ok(())
3153 }
3154 fastapi_core::ResponseBody::Stream(mut s) => {
3155 loop {
3156 let next = poll_fn(|cx| Pin::new(&mut s).poll_next(cx)).await;
3157 match next {
3158 Some(chunk) => {
3159 let mut remaining = chunk.as_slice();
3160 while !remaining.is_empty() {
3161 let max = usize::try_from(peer_max_frame_size).unwrap_or(16 * 1024);
3162 let send_len = remaining.len().min(max);
3163 let send_len = h2_fc_clamp_send(
3164 framed,
3165 &mut flow_control,
3166 &mut stream_send_window,
3167 stream_id,
3168 send_len,
3169 &mut peer_max_frame_size,
3170 recv_max_frame_size,
3171 )
3172 .await?;
3173
3174 let (c, r) = remaining.split_at(send_len);
3175 framed
3176 .write_frame(http2::FrameType::Data, 0, stream_id, c)
3177 .await?;
3178 self.record_bytes_out((http2::FrameHeader::LEN + c.len()) as u64);
3179 remaining = r;
3180 }
3181 }
3182 None => {
3183 framed
3184 .write_frame(
3185 http2::FrameType::Data,
3186 FLAG_END_STREAM,
3187 stream_id,
3188 &[],
3189 )
3190 .await?;
3191 self.record_bytes_out(http2::FrameHeader::LEN as u64);
3192 break;
3193 }
3194 }
3195 }
3196 Ok(())
3197 }
3198 }
3199 }
3200
3201 async fn handle_connection_handler_http2(
3202 &self,
3203 cx: &Cx,
3204 stream: TcpStream,
3205 handler: &dyn fastapi_core::Handler,
3206 ) -> Result<(), ServerError> {
3207 const FLAG_END_HEADERS: u8 = 0x4;
3208 const FLAG_ACK: u8 = 0x1;
3209
3210 let mut framed = http2::FramedH2::new(stream, Vec::new());
3211 let mut hpack = http2::HpackDecoder::new();
3212 let recv_max_frame_size: u32 = 16 * 1024;
3213 let mut peer_max_frame_size: u32 = 16 * 1024;
3214 let mut flow_control = http2::H2FlowControl::new();
3215
3216 let first = framed.read_frame(recv_max_frame_size).await?;
3217 self.record_bytes_in((http2::FrameHeader::LEN + first.payload.len()) as u64);
3218
3219 if first.header.frame_type() != http2::FrameType::Settings
3220 || first.header.stream_id != 0
3221 || (first.header.flags & FLAG_ACK) != 0
3222 {
3223 return Err(
3224 http2::Http2Error::Protocol("expected client SETTINGS after preface").into(),
3225 );
3226 }
3227
3228 apply_http2_settings_with_fc(
3229 &mut hpack,
3230 &mut peer_max_frame_size,
3231 Some(&mut flow_control),
3232 &first.payload,
3233 )?;
3234
3235 framed
3236 .write_frame(http2::FrameType::Settings, 0, 0, SERVER_SETTINGS_PAYLOAD)
3237 .await?;
3238 self.record_bytes_out(http2::FrameHeader::LEN as u64);
3239
3240 framed
3241 .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
3242 .await?;
3243 self.record_bytes_out(http2::FrameHeader::LEN as u64);
3244
3245 let default_body_limit = self.config.parse_limits.max_request_size;
3246 let mut last_stream_id: u32 = 0;
3247
3248 loop {
3249 if cx.is_cancel_requested() {
3250 let _ = send_goaway(&mut framed, last_stream_id, h2_error_code::NO_ERROR).await;
3251 return Ok(());
3252 }
3253
3254 let frame = framed.read_frame(recv_max_frame_size).await?;
3255 self.record_bytes_in((http2::FrameHeader::LEN + frame.payload.len()) as u64);
3256
3257 match frame.header.frame_type() {
3258 http2::FrameType::Settings => {
3259 let is_ack = validate_settings_frame(
3260 frame.header.stream_id,
3261 frame.header.flags,
3262 &frame.payload,
3263 )?;
3264 if is_ack {
3265 continue;
3266 }
3267 apply_http2_settings_with_fc(
3268 &mut hpack,
3269 &mut peer_max_frame_size,
3270 Some(&mut flow_control),
3271 &frame.payload,
3272 )?;
3273 framed
3274 .write_frame(http2::FrameType::Settings, FLAG_ACK, 0, &[])
3275 .await?;
3276 self.record_bytes_out(http2::FrameHeader::LEN as u64);
3277 }
3278 http2::FrameType::Ping => {
3279 if frame.header.stream_id != 0 || frame.payload.len() != 8 {
3280 return Err(http2::Http2Error::Protocol("invalid PING frame").into());
3281 }
3282 if (frame.header.flags & FLAG_ACK) == 0 {
3283 framed
3284 .write_frame(http2::FrameType::Ping, FLAG_ACK, 0, &frame.payload)
3285 .await?;
3286 self.record_bytes_out((http2::FrameHeader::LEN + 8) as u64);
3287 }
3288 }
3289 http2::FrameType::Goaway => {
3290 validate_goaway_payload(&frame.payload)?;
3291 return Ok(());
3292 }
3293 http2::FrameType::PushPromise => {
3294 return Err(http2::Http2Error::Protocol(
3295 "PUSH_PROMISE not supported by server",
3296 )
3297 .into());
3298 }
3299 http2::FrameType::Headers => {
3300 let stream_id = frame.header.stream_id;
3301 if stream_id == 0 {
3302 return Err(
3303 http2::Http2Error::Protocol("HEADERS must not be on stream 0").into(),
3304 );
3305 }
3306 if stream_id % 2 == 0 {
3307 return Err(http2::Http2Error::Protocol(
3308 "client-initiated stream ID must be odd",
3309 )
3310 .into());
3311 }
3312 if stream_id <= last_stream_id {
3313 return Err(http2::Http2Error::Protocol(
3314 "stream ID must be greater than previous",
3315 )
3316 .into());
3317 }
3318 last_stream_id = stream_id;
3319 let (end_stream, mut header_block) =
3320 extract_header_block_fragment(frame.header.flags, &frame.payload)?;
3321
3322 if (frame.header.flags & FLAG_END_HEADERS) == 0 {
3323 loop {
3324 let cont = framed.read_frame(recv_max_frame_size).await?;
3325 self.record_bytes_in(
3326 (http2::FrameHeader::LEN + cont.payload.len()) as u64,
3327 );
3328 if cont.header.frame_type() != http2::FrameType::Continuation
3329 || cont.header.stream_id != stream_id
3330 {
3331 return Err(http2::Http2Error::Protocol(
3332 "expected CONTINUATION for header block",
3333 )
3334 .into());
3335 }
3336 header_block.extend_from_slice(&cont.payload);
3337 if header_block.len() > MAX_HEADER_BLOCK_SIZE {
3338 return Err(http2::Http2Error::Protocol(
3339 "header block exceeds maximum size",
3340 )
3341 .into());
3342 }
3343 if (cont.header.flags & FLAG_END_HEADERS) != 0 {
3344 break;
3345 }
3346 }
3347 }
3348
3349 let headers = hpack
3350 .decode(&header_block)
3351 .map_err(http2::Http2Error::from)?;
3352 let mut request = request_from_h2_headers(headers)?;
3353
3354 if !end_stream {
3355 let mut body = Vec::new();
3356 let mut stream_reset = false;
3357 let mut stream_received: u32 = 0;
3358 loop {
3359 let f = framed.read_frame(recv_max_frame_size).await?;
3360 self.record_bytes_in(
3361 (http2::FrameHeader::LEN + f.payload.len()) as u64,
3362 );
3363 match f.header.frame_type() {
3364 http2::FrameType::Data if f.header.stream_id == 0 => {
3365 return Err(http2::Http2Error::Protocol(
3366 "DATA must not be on stream 0",
3367 )
3368 .into());
3369 }
3370 http2::FrameType::Data if f.header.stream_id == stream_id => {
3371 let (data, data_end_stream) =
3372 extract_data_payload(f.header.flags, &f.payload)?;
3373 if body.len().saturating_add(data.len()) > default_body_limit {
3374 return Err(http2::Http2Error::Protocol(
3375 "request body exceeds configured limit",
3376 )
3377 .into());
3378 }
3379 body.extend_from_slice(data);
3380
3381 let data_len = u32::try_from(data.len()).unwrap_or(u32::MAX);
3384 stream_received += data_len;
3385 let conn_inc = flow_control.data_received_connection(data_len);
3386 let stream_inc =
3387 flow_control.stream_window_update(stream_received);
3388 if stream_inc > 0 {
3389 stream_received = 0;
3390 }
3391 send_window_updates(
3392 &mut framed,
3393 conn_inc,
3394 stream_id,
3395 stream_inc,
3396 )
3397 .await?;
3398
3399 if data_end_stream {
3400 break;
3401 }
3402 }
3403 http2::FrameType::RstStream => {
3404 validate_rst_stream_payload(f.header.stream_id, &f.payload)?;
3405 if f.header.stream_id == stream_id {
3406 stream_reset = true;
3407 break;
3408 }
3409 }
3410 http2::FrameType::PushPromise => {
3411 return Err(http2::Http2Error::Protocol(
3412 "PUSH_PROMISE not supported by server",
3413 )
3414 .into());
3415 }
3416 http2::FrameType::Settings
3417 | http2::FrameType::Ping
3418 | http2::FrameType::Goaway
3419 | http2::FrameType::WindowUpdate
3420 | http2::FrameType::Priority
3421 | http2::FrameType::Unknown => {
3422 if f.header.frame_type() == http2::FrameType::Goaway {
3423 validate_goaway_payload(&f.payload)?;
3424 return Ok(());
3425 }
3426 if f.header.frame_type() == http2::FrameType::Priority {
3427 validate_priority_payload(f.header.stream_id, &f.payload)?;
3428 }
3429 if f.header.frame_type() == http2::FrameType::WindowUpdate {
3430 validate_window_update_payload(&f.payload)?;
3431 let increment = u32::from_be_bytes([
3432 f.payload[0],
3433 f.payload[1],
3434 f.payload[2],
3435 f.payload[3],
3436 ]) & 0x7FFF_FFFF;
3437 if f.header.stream_id == 0 {
3438 apply_send_conn_window_update(
3439 &mut flow_control,
3440 increment,
3441 )?;
3442 }
3443 }
3444 if f.header.frame_type() == http2::FrameType::Ping {
3445 if f.header.stream_id != 0 || f.payload.len() != 8 {
3446 return Err(http2::Http2Error::Protocol(
3447 "invalid PING frame",
3448 )
3449 .into());
3450 }
3451 if (f.header.flags & FLAG_ACK) == 0 {
3452 framed
3453 .write_frame(
3454 http2::FrameType::Ping,
3455 FLAG_ACK,
3456 0,
3457 &f.payload,
3458 )
3459 .await?;
3460 self.record_bytes_out(
3461 (http2::FrameHeader::LEN + 8) as u64,
3462 );
3463 }
3464 }
3465 if f.header.frame_type() == http2::FrameType::Settings {
3466 let is_ack = validate_settings_frame(
3467 f.header.stream_id,
3468 f.header.flags,
3469 &f.payload,
3470 )?;
3471 if !is_ack {
3472 apply_http2_settings_with_fc(
3473 &mut hpack,
3474 &mut peer_max_frame_size,
3475 Some(&mut flow_control),
3476 &f.payload,
3477 )?;
3478 framed
3479 .write_frame(
3480 http2::FrameType::Settings,
3481 FLAG_ACK,
3482 0,
3483 &[],
3484 )
3485 .await?;
3486 self.record_bytes_out(http2::FrameHeader::LEN as u64);
3487 }
3488 }
3489 }
3490 _ => {
3491 return Err(http2::Http2Error::Protocol(
3492 "unsupported frame while reading request body",
3493 )
3494 .into());
3495 }
3496 }
3497 }
3498 if stream_reset {
3499 continue;
3500 }
3501 request.set_body(fastapi_core::Body::Bytes(body));
3502 }
3503
3504 let request_id = self.request_counter.fetch_add(1, Ordering::Relaxed);
3505 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
3506 let request_cx = Cx::for_testing_with_budget(request_budget);
3507
3508 let overrides = handler
3509 .dependency_overrides()
3510 .unwrap_or_else(|| Arc::new(fastapi_core::DependencyOverrides::new()));
3511
3512 let ctx = RequestContext::with_overrides_and_body_limit(
3513 request_cx,
3514 request_id,
3515 overrides,
3516 default_body_limit,
3517 );
3518
3519 if let Err(err) = validate_host_header(&request, &self.config) {
3520 let response = err.response();
3521 self.write_h2_response(
3522 &mut framed,
3523 response,
3524 stream_id,
3525 peer_max_frame_size,
3526 recv_max_frame_size,
3527 Some(&mut flow_control),
3528 )
3529 .await?;
3530 continue;
3531 }
3532 if let Err(response) = self.config.pre_body_validators.validate_all(&request) {
3533 self.write_h2_response(
3534 &mut framed,
3535 response,
3536 stream_id,
3537 peer_max_frame_size,
3538 recv_max_frame_size,
3539 Some(&mut flow_control),
3540 )
3541 .await?;
3542 continue;
3543 }
3544
3545 let response = handler.call(&ctx, &mut request).await;
3546 self.write_h2_response(
3547 &mut framed,
3548 response,
3549 stream_id,
3550 peer_max_frame_size,
3551 recv_max_frame_size,
3552 Some(&mut flow_control),
3553 )
3554 .await?;
3555 }
3556 http2::FrameType::WindowUpdate => {
3557 validate_window_update_payload(&frame.payload)?;
3558 let increment = u32::from_be_bytes([
3559 frame.payload[0],
3560 frame.payload[1],
3561 frame.payload[2],
3562 frame.payload[3],
3563 ]) & 0x7FFF_FFFF;
3564 if frame.header.stream_id == 0 {
3565 apply_send_conn_window_update(&mut flow_control, increment)?;
3566 }
3567 }
3568 _ => {
3569 handle_h2_idle_frame(&frame)?;
3570 }
3571 }
3572 }
3573 }
3574
3575 async fn handle_connection_handler(
3580 &self,
3581 cx: &Cx,
3582 mut stream: TcpStream,
3583 _peer_addr: SocketAddr,
3584 handler: &dyn fastapi_core::Handler,
3585 ) -> Result<(), ServerError> {
3586 let (proto, buffered) = sniff_protocol(&mut stream, self.config.keep_alive_timeout).await?;
3587 if !buffered.is_empty() {
3588 self.record_bytes_in(buffered.len() as u64);
3589 }
3590 if proto == SniffedProtocol::Http2PriorKnowledge {
3591 return self
3592 .handle_connection_handler_http2(cx, stream, handler)
3593 .await;
3594 }
3595
3596 let mut parser = StatefulParser::new().with_limits(self.config.parse_limits.clone());
3597 if !buffered.is_empty() {
3598 parser.feed(&buffered)?;
3599 }
3600 let mut read_buffer = vec![0u8; self.config.read_buffer_size];
3601 let mut response_writer = ResponseWriter::new();
3602 let mut requests_on_connection: usize = 0;
3603 let max_requests = self.config.max_requests_per_connection;
3604
3605 loop {
3606 if cx.is_cancel_requested() {
3608 return Ok(());
3609 }
3610
3611 let parse_result = parser.feed(&[])?;
3613
3614 let mut request = match parse_result {
3615 ParseStatus::Complete { request, .. } => request,
3616 ParseStatus::Incomplete => {
3617 let keep_alive_timeout = self.config.keep_alive_timeout;
3618 let bytes_read = if keep_alive_timeout.is_zero() {
3619 read_into_buffer(&mut stream, &mut read_buffer).await?
3620 } else {
3621 match read_with_timeout(&mut stream, &mut read_buffer, keep_alive_timeout)
3622 .await
3623 {
3624 Ok(0) => return Ok(()),
3625 Ok(n) => n,
3626 Err(e) if e.kind() == io::ErrorKind::TimedOut => {
3627 self.metrics_counters
3628 .total_timed_out
3629 .fetch_add(1, Ordering::Relaxed);
3630 return Err(ServerError::KeepAliveTimeout);
3631 }
3632 Err(e) => return Err(ServerError::Io(e)),
3633 }
3634 };
3635
3636 if bytes_read == 0 {
3637 return Ok(());
3638 }
3639
3640 self.record_bytes_in(bytes_read as u64);
3641
3642 match parser.feed(&read_buffer[..bytes_read])? {
3643 ParseStatus::Complete { request, .. } => request,
3644 ParseStatus::Incomplete => continue,
3645 }
3646 }
3647 };
3648
3649 requests_on_connection += 1;
3650
3651 let request_id = self.request_counter.fetch_add(1, Ordering::Relaxed);
3653 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
3654 let request_cx = Cx::for_testing_with_budget(request_budget);
3655 let ctx = RequestContext::new(request_cx, request_id);
3656
3657 if let Err(err) = validate_host_header(&request, &self.config) {
3659 let response = err.response().header("connection", b"close".to_vec());
3660 let response_write = response_writer.write(response);
3661 write_response(&mut stream, response_write).await?;
3662 return Ok(());
3663 }
3664
3665 if let Err(response) = self.config.pre_body_validators.validate_all(&request) {
3667 let response = response.header("connection", b"close".to_vec());
3668 let response_write = response_writer.write(response);
3669 write_response(&mut stream, response_write).await?;
3670 return Ok(());
3671 }
3672
3673 match ExpectHandler::check_expect(&request) {
3675 ExpectResult::NoExpectation => {}
3676 ExpectResult::ExpectsContinue => {
3677 write_raw_response(&mut stream, CONTINUE_RESPONSE).await?;
3678 }
3679 ExpectResult::UnknownExpectation(_) => {
3680 let response =
3681 ExpectHandler::expectation_failed("Unsupported Expect value".to_string());
3682 let response_write = response_writer.write(response);
3683 write_response(&mut stream, response_write).await?;
3684 return Ok(());
3685 }
3686 }
3687
3688 let response = handler.call(&ctx, &mut request).await;
3690
3691 let client_wants_keep_alive = should_keep_alive(&request);
3693 let server_will_keep_alive = client_wants_keep_alive
3694 && (max_requests == 0 || requests_on_connection < max_requests);
3695
3696 let response = if server_will_keep_alive {
3697 response.header("connection", b"keep-alive".to_vec())
3698 } else {
3699 response.header("connection", b"close".to_vec())
3700 };
3701
3702 let response_write = response_writer.write(response);
3703 if let ResponseWrite::Full(ref bytes) = response_write {
3704 self.record_bytes_out(bytes.len() as u64);
3705 }
3706 write_response(&mut stream, response_write).await?;
3707
3708 if !server_will_keep_alive {
3709 return Ok(());
3710 }
3711 }
3712 }
3713
3714 async fn accept_loop<H, Fut>(
3716 &self,
3717 cx: &Cx,
3718 listener: TcpListener,
3719 handler: H,
3720 ) -> Result<(), ServerError>
3721 where
3722 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
3723 Fut: Future<Output = Response> + Send + 'static,
3724 {
3725 let handler = Arc::new(handler);
3726
3727 loop {
3728 if cx.is_cancel_requested() {
3730 cx.trace("Server shutdown requested");
3731 return Ok(());
3732 }
3733
3734 if self.is_draining() {
3736 cx.trace("Server draining, stopping accept loop");
3737 return Err(ServerError::Shutdown);
3738 }
3739
3740 let (mut stream, peer_addr) = match listener.accept().await {
3742 Ok(conn) => conn,
3743 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
3744 continue;
3746 }
3747 Err(e) => {
3748 cx.trace(&format!("Accept error: {e}"));
3749 if is_fatal_accept_error(&e) {
3752 return Err(ServerError::Io(e));
3753 }
3754 continue;
3755 }
3756 };
3757
3758 if !self.try_acquire_connection() {
3760 cx.trace(&format!(
3761 "Connection limit reached ({}), rejecting {peer_addr}",
3762 self.config.max_connections
3763 ));
3764
3765 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
3767 .header("connection", b"close".to_vec())
3768 .body(fastapi_core::ResponseBody::Bytes(
3769 b"503 Service Unavailable: connection limit reached".to_vec(),
3770 ));
3771 let mut writer = crate::response::ResponseWriter::new();
3772 let response_bytes = writer.write(response);
3773 let _ = write_response(&mut stream, response_bytes).await;
3774 continue;
3775 }
3776
3777 if self.config.tcp_nodelay {
3779 let _ = stream.set_nodelay(true);
3780 }
3781
3782 cx.trace(&format!(
3783 "Accepted connection from {peer_addr} ({}/{})",
3784 self.current_connections(),
3785 if self.config.max_connections == 0 {
3786 "∞".to_string()
3787 } else {
3788 self.config.max_connections.to_string()
3789 }
3790 ));
3791
3792 let request_id = self.next_request_id();
3797 let request_budget = Budget::new().with_deadline(self.config.request_timeout);
3798
3799 let request_cx = Cx::for_testing_with_budget(request_budget);
3804 let ctx = RequestContext::new(request_cx, request_id);
3805
3806 let result = self
3808 .handle_connection(&ctx, stream, peer_addr, &*handler)
3809 .await;
3810
3811 self.release_connection();
3813
3814 if let Err(e) = result {
3815 cx.trace(&format!("Connection error from {peer_addr}: {e}"));
3816 }
3817 }
3818 }
3819
3820 async fn handle_connection<H, Fut>(
3826 &self,
3827 ctx: &RequestContext,
3828 stream: TcpStream,
3829 peer_addr: SocketAddr,
3830 handler: &H,
3831 ) -> Result<(), ServerError>
3832 where
3833 H: Fn(RequestContext, &mut Request) -> Fut + Send + Sync,
3834 Fut: Future<Output = Response> + Send,
3835 {
3836 process_connection(
3837 ctx.cx(),
3838 &self.request_counter,
3839 stream,
3840 peer_addr,
3841 &self.config,
3842 |ctx, req| handler(ctx, req),
3843 )
3844 .await
3845 }
3846}
3847
/// Plain-value view of the server's counters: every field is a `u64` copy,
/// unlike the atomic fields of `MetricsCounters`, so values can be cloned and
/// compared with `==`.
///
/// NOTE(review): the per-field meanings below are inferred from the field
/// names; confirm them against the code that populates this struct.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerMetrics {
    /// Connections currently open (presumably).
    pub active_connections: u64,
    /// Connections accepted since start (presumably).
    pub total_accepted: u64,
    /// Connections rejected, e.g. by the connection limit (presumably).
    pub total_rejected: u64,
    /// Connections or requests that hit a timeout (presumably).
    pub total_timed_out: u64,
    /// Requests handled (presumably).
    pub total_requests: u64,
    /// Bytes read from clients (presumably).
    pub bytes_in: u64,
    /// Bytes written to clients (presumably).
    pub bytes_out: u64,
}
3869
/// Atomic counters shared by the server; each field is an independent `AtomicU64`.
#[derive(Debug)]
struct MetricsCounters {
    total_accepted: AtomicU64,
    total_rejected: AtomicU64,
    total_timed_out: AtomicU64,
    bytes_in: AtomicU64,
    bytes_out: AtomicU64,
}

impl MetricsCounters {
    /// Creates a counter set with every counter starting at zero.
    fn new() -> Self {
        let zeroed = || AtomicU64::new(0);
        Self {
            total_accepted: zeroed(),
            total_rejected: zeroed(),
            total_timed_out: zeroed(),
            bytes_in: zeroed(),
            bytes_out: zeroed(),
        }
    }
}
3894
3895impl Default for TcpServer {
3896 fn default() -> Self {
3897 Self::new(ServerConfig::default())
3898 }
3899}
3900
/// Reports whether an `accept()` failure should terminate the accept loop.
///
/// Only `NotConnected` and `InvalidInput` count as fatal; presumably both mean
/// the listener itself is unusable. Every other kind is traced and retried by
/// the accept loop.
fn is_fatal_accept_error(e: &io::Error) -> bool {
    const FATAL_KINDS: [io::ErrorKind; 2] =
        [io::ErrorKind::NotConnected, io::ErrorKind::InvalidInput];
    FATAL_KINDS.contains(&e.kind())
}
3909
3910pub async fn read_into_buffer(stream: &mut TcpStream, buffer: &mut [u8]) -> io::Result<usize> {
3917 use std::future::poll_fn;
3918
3919 poll_fn(|cx| {
3920 let mut read_buf = ReadBuf::new(buffer);
3921 match Pin::new(&mut *stream).poll_read(cx, &mut read_buf) {
3922 Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
3923 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
3924 Poll::Pending => Poll::Pending,
3925 }
3926 })
3927 .await
3928}
3929
3930async fn read_with_timeout(
3948 stream: &mut TcpStream,
3949 buffer: &mut [u8],
3950 timeout_duration: Duration,
3951) -> io::Result<usize> {
3952 let now = current_time();
3954
3955 let read_future = Box::pin(read_into_buffer(stream, buffer));
3957
3958 match timeout(now, timeout_duration, read_future).await {
3960 Ok(result) => result,
3961 Err(_elapsed) => Err(io::Error::new(
3962 io::ErrorKind::TimedOut,
3963 "keep-alive timeout expired",
3964 )),
3965 }
3966}
3967
/// Outcome of inspecting the first bytes of a connection in `sniff_protocol`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SniffedProtocol {
    /// The bytes diverged from the HTTP/2 connection preface, or the peer
    /// closed the stream before sending all of it.
    Http1,
    /// The complete HTTP/2 connection preface was received (prior-knowledge HTTP/2).
    Http2PriorKnowledge,
}
3973
3974async fn sniff_protocol(
3978 stream: &mut TcpStream,
3979 keep_alive_timeout: Duration,
3980) -> io::Result<(SniffedProtocol, Vec<u8>)> {
3981 let mut buffered: Vec<u8> = Vec::new();
3982 let preface = http2::PREFACE;
3983
3984 while buffered.len() < preface.len() {
3985 let mut tmp = vec![0u8; preface.len() - buffered.len()];
3986 let n = if keep_alive_timeout.is_zero() {
3987 read_into_buffer(stream, &mut tmp).await?
3988 } else {
3989 read_with_timeout(stream, &mut tmp, keep_alive_timeout).await?
3990 };
3991 if n == 0 {
3992 return Ok((SniffedProtocol::Http1, buffered));
3994 }
3995
3996 buffered.extend_from_slice(&tmp[..n]);
3997 if !preface.starts_with(&buffered) {
3998 return Ok((SniffedProtocol::Http1, buffered));
3999 }
4000 }
4001
4002 Ok((SniffedProtocol::Http2PriorKnowledge, buffered))
4003}
4004
4005fn apply_http2_settings(
4006 hpack: &mut http2::HpackDecoder,
4007 max_frame_size: &mut u32,
4008 payload: &[u8],
4009) -> Result<(), http2::Http2Error> {
4010 apply_http2_settings_with_fc(hpack, max_frame_size, None, payload)
4011}
4012
4013fn apply_http2_settings_with_fc(
4014 hpack: &mut http2::HpackDecoder,
4015 max_frame_size: &mut u32,
4016 mut flow_control: Option<&mut http2::H2FlowControl>,
4017 payload: &[u8],
4018) -> Result<(), http2::Http2Error> {
4019 if payload.len() % 6 != 0 {
4021 return Err(http2::Http2Error::Protocol(
4022 "SETTINGS length must be a multiple of 6",
4023 ));
4024 }
4025
4026 for chunk in payload.chunks_exact(6) {
4027 let id = u16::from_be_bytes([chunk[0], chunk[1]]);
4028 let value = u32::from_be_bytes([chunk[2], chunk[3], chunk[4], chunk[5]]);
4029 match id {
4030 0x1 => {
4031 let capped = (value as usize).min(MAX_HPACK_TABLE_SIZE);
4033 hpack.set_dynamic_table_max_size(capped);
4034 }
4035 0x3 => {
4036 if value > 0x7FFF_FFFF {
4039 return Err(http2::Http2Error::Protocol(
4040 "SETTINGS_INITIAL_WINDOW_SIZE exceeds maximum",
4041 ));
4042 }
4043 if let Some(ref mut fc) = flow_control {
4044 fc.set_initial_window_size(value);
4045 fc.set_peer_initial_window_size(value);
4049 }
4050 }
4051 0x5 => {
4052 if !(16_384..=16_777_215).contains(&value) {
4054 return Err(http2::Http2Error::Protocol(
4055 "invalid SETTINGS_MAX_FRAME_SIZE",
4056 ));
4057 }
4058 *max_frame_size = value;
4059 }
4060 0x2 => {
4061 if value > 1 {
4063 return Err(http2::Http2Error::Protocol(
4064 "SETTINGS_ENABLE_PUSH must be 0 or 1",
4065 ));
4066 }
4067 }
4069 0x4 => {
4070 }
4073 0x6 => {
4074 hpack.set_max_header_list_size(value as usize);
4076 }
4077 _ => {
4078 }
4080 }
4081 }
4082 Ok(())
4083}
4084
4085fn validate_settings_frame(
4086 stream_id: u32,
4087 flags: u8,
4088 payload: &[u8],
4089) -> Result<bool, http2::Http2Error> {
4090 const FLAG_ACK: u8 = 0x1;
4091 if stream_id != 0 {
4092 return Err(http2::Http2Error::Protocol("SETTINGS must be on stream 0"));
4093 }
4094
4095 let is_ack = (flags & FLAG_ACK) != 0;
4096 if is_ack && !payload.is_empty() {
4097 return Err(http2::Http2Error::Protocol(
4098 "SETTINGS ACK frame must have empty payload",
4099 ));
4100 }
4101
4102 Ok(is_ack)
4103}
4104
4105fn validate_window_update_payload(payload: &[u8]) -> Result<(), http2::Http2Error> {
4106 if payload.len() != 4 {
4107 return Err(http2::Http2Error::Protocol(
4108 "WINDOW_UPDATE payload must be 4 bytes",
4109 ));
4110 }
4111
4112 let raw = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
4113 let increment = raw & 0x7FFF_FFFF;
4114 if increment == 0 {
4115 return Err(http2::Http2Error::Protocol(
4116 "WINDOW_UPDATE increment must be non-zero",
4117 ));
4118 }
4119
4120 Ok(())
4121}
4122
4123fn handle_h2_idle_frame(frame: &http2::Frame) -> Result<(), http2::Http2Error> {
4124 match frame.header.frame_type() {
4125 http2::FrameType::RstStream => {
4126 validate_rst_stream_payload(frame.header.stream_id, &frame.payload)
4127 }
4128 http2::FrameType::Priority => {
4129 validate_priority_payload(frame.header.stream_id, &frame.payload)
4130 }
4131 http2::FrameType::Data => Err(http2::Http2Error::Protocol(
4132 "unexpected DATA frame outside active request stream",
4133 )),
4134 http2::FrameType::Continuation => Err(http2::Http2Error::Protocol(
4135 "unexpected CONTINUATION frame outside header block",
4136 )),
4137 http2::FrameType::Unknown => Ok(()),
4138 _ => Ok(()),
4139 }
4140}
4141
/// Largest value an HTTP/2 flow-control window may reach: 2^31-1 (RFC 9113 §6.9.1).
const MAX_FLOW_CONTROL_WINDOW: i64 = 0x7FFF_FFFF;

/// SETTINGS payload the server advertises.
///
/// It holds a single 6-byte entry: identifier 0x3 (SETTINGS_MAX_CONCURRENT_STREAMS)
/// with value 1. The peer may therefore have at most one concurrent stream open.
const SERVER_SETTINGS_PAYLOAD: &[u8] = &[
    0x00, 0x03, // identifier: SETTINGS_MAX_CONCURRENT_STREAMS
    0x00, 0x00, 0x00, 0x01, // value: 1
];

/// Ceiling applied to a peer-requested SETTINGS_HEADER_TABLE_SIZE before it is
/// handed to the HPACK decoder (presumably to bound decoder memory).
const MAX_HPACK_TABLE_SIZE: usize = 64 * 1024;

/// Maximum header block size in bytes (128 KiB).
///
/// NOTE(review): presumably enforced while accumulating HEADERS/CONTINUATION
/// fragments; the enforcing code is not visible in this section.
const MAX_HEADER_BLOCK_SIZE: usize = 128 * 1024;
4165
4166fn apply_send_conn_window_update(
4169 fc: &mut http2::H2FlowControl,
4170 increment: u32,
4171) -> Result<(), http2::Http2Error> {
4172 let new_window = fc.send_conn_window() + i64::from(increment);
4173 if new_window > MAX_FLOW_CONTROL_WINDOW {
4174 return Err(http2::Http2Error::Protocol(
4175 "WINDOW_UPDATE causes flow-control window to exceed 2^31-1",
4176 ));
4177 }
4178 fc.peer_window_update_connection(increment);
4179 Ok(())
4180}
4181
4182fn apply_peer_window_update_for_send(
4183 flow_control: &mut http2::H2FlowControl,
4184 stream_send_window: &mut i64,
4185 current_stream_id: u32,
4186 frame_stream_id: u32,
4187 payload: &[u8],
4188) -> Result<(), http2::Http2Error> {
4189 validate_window_update_payload(payload)?;
4190
4191 let increment =
4192 u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]) & 0x7FFF_FFFF;
4193 if frame_stream_id == 0 {
4194 apply_send_conn_window_update(flow_control, increment)?;
4195 } else if frame_stream_id == current_stream_id {
4196 let new_window = *stream_send_window + i64::from(increment);
4197 if new_window > MAX_FLOW_CONTROL_WINDOW {
4198 return Err(http2::Http2Error::Protocol(
4199 "WINDOW_UPDATE causes flow-control window to exceed 2^31-1",
4200 ));
4201 }
4202 *stream_send_window = new_window;
4203 }
4204
4205 Ok(())
4206}
4207
4208fn apply_peer_settings_for_send(
4209 flow_control: &mut http2::H2FlowControl,
4210 stream_send_window: &mut i64,
4211 peer_max_frame_size: &mut u32,
4212 payload: &[u8],
4213) -> Result<(), http2::Http2Error> {
4214 if payload.len() % 6 != 0 {
4215 return Err(http2::Http2Error::Protocol(
4216 "SETTINGS length must be a multiple of 6",
4217 ));
4218 }
4219
4220 for chunk in payload.chunks_exact(6) {
4221 let id = u16::from_be_bytes([chunk[0], chunk[1]]);
4222 let value = u32::from_be_bytes([chunk[2], chunk[3], chunk[4], chunk[5]]);
4223
4224 if id == 0x3 {
4225 if value > 0x7FFF_FFFF {
4227 return Err(http2::Http2Error::Protocol(
4228 "SETTINGS_INITIAL_WINDOW_SIZE exceeds maximum",
4229 ));
4230 }
4231 let old = i64::from(flow_control.peer_initial_window_size());
4232 let new = i64::from(value);
4233 let delta = new - old;
4234 let updated = *stream_send_window + delta;
4235 if updated > MAX_FLOW_CONTROL_WINDOW {
4236 return Err(http2::Http2Error::Protocol(
4237 "SETTINGS_INITIAL_WINDOW_SIZE change causes stream window to exceed 2^31-1",
4238 ));
4239 }
4240 flow_control.set_peer_initial_window_size(value);
4241 *stream_send_window = updated;
4242 } else if id == 0x5 {
4243 if !(16_384..=16_777_215).contains(&value) {
4245 return Err(http2::Http2Error::Protocol(
4246 "invalid SETTINGS_MAX_FRAME_SIZE",
4247 ));
4248 }
4249 *peer_max_frame_size = value;
4250 }
4251 }
4252
4253 Ok(())
4254}
4255
/// Encodes a WINDOW_UPDATE payload.
///
/// The 31-bit increment is written in network byte order, with the reserved
/// high bit cleared.
fn window_update_payload(increment: u32) -> [u8; 4] {
    const INCREMENT_MASK: u32 = 0x7FFF_FFFF;
    u32::to_be_bytes(increment & INCREMENT_MASK)
}
4260
4261async fn send_window_updates(
4264 framed: &mut http2::FramedH2,
4265 conn_increment: u32,
4266 stream_id: u32,
4267 stream_increment: u32,
4268) -> Result<(), http2::Http2Error> {
4269 if conn_increment > 0 {
4270 let payload = window_update_payload(conn_increment);
4271 framed
4272 .write_frame(http2::FrameType::WindowUpdate, 0, 0, &payload)
4273 .await?;
4274 }
4275 if stream_increment > 0 {
4276 let payload = window_update_payload(stream_increment);
4277 framed
4278 .write_frame(http2::FrameType::WindowUpdate, 0, stream_id, &payload)
4279 .await?;
4280 }
4281 Ok(())
4282}
4283
/// HTTP/2 error codes carried in RST_STREAM and GOAWAY frames (RFC 9113 §7).
///
/// This is the complete registered table, so call sites can name the right code
/// (e.g. COMPRESSION_ERROR for HPACK failures) instead of a bare integer.
// Some codes may not be referenced yet; the full table is kept deliberately.
#[allow(dead_code)]
mod h2_error_code {
    /// Graceful shutdown; not the result of an error.
    pub const NO_ERROR: u32 = 0x0;
    /// Unspecific protocol error.
    pub const PROTOCOL_ERROR: u32 = 0x1;
    /// Unexpected internal error.
    pub const INTERNAL_ERROR: u32 = 0x2;
    /// Peer violated the flow-control protocol.
    pub const FLOW_CONTROL_ERROR: u32 = 0x3;
    /// SETTINGS was sent but not acknowledged in time.
    pub const SETTINGS_TIMEOUT: u32 = 0x4;
    /// Frame received after the stream was half-closed.
    pub const STREAM_CLOSED: u32 = 0x5;
    /// Frame had an invalid size.
    pub const FRAME_SIZE_ERROR: u32 = 0x6;
    /// Stream refused before any application processing.
    pub const REFUSED_STREAM: u32 = 0x7;
    /// Stream no longer needed.
    pub const CANCEL: u32 = 0x8;
    /// Header compression (HPACK) context could not be maintained.
    pub const COMPRESSION_ERROR: u32 = 0x9;
    /// Connection for a CONNECT request was reset or abnormally closed.
    pub const CONNECT_ERROR: u32 = 0xa;
    /// Peer is generating excessive load.
    pub const ENHANCE_YOUR_CALM: u32 = 0xb;
    /// Transport does not meet minimum security requirements.
    pub const INADEQUATE_SECURITY: u32 = 0xc;
    /// Endpoint requires HTTP/1.1 instead of HTTP/2.
    pub const HTTP_1_1_REQUIRED: u32 = 0xd;
}
4297
4298fn validate_goaway_payload(payload: &[u8]) -> Result<(), http2::Http2Error> {
4300 if payload.len() < 8 {
4301 return Err(http2::Http2Error::Protocol(
4302 "GOAWAY payload must be at least 8 bytes",
4303 ));
4304 }
4305 Ok(())
4306}
4307
/// Encodes the fixed part of a GOAWAY payload.
///
/// Layout: Last-Stream-ID (reserved bit cleared), then the error code, both in
/// network byte order. No debug data is appended.
fn goaway_payload(last_stream_id: u32, error_code: u32) -> [u8; 8] {
    let [s0, s1, s2, s3] = (last_stream_id & 0x7FFF_FFFF).to_be_bytes();
    let [e0, e1, e2, e3] = error_code.to_be_bytes();
    [s0, s1, s2, s3, e0, e1, e2, e3]
}
4315
4316async fn send_goaway(
4318 framed: &mut http2::FramedH2,
4319 last_stream_id: u32,
4320 error_code: u32,
4321) -> Result<(), http2::Http2Error> {
4322 let payload = goaway_payload(last_stream_id, error_code);
4323 framed
4324 .write_frame(http2::FrameType::Goaway, 0, 0, &payload)
4325 .await
4326}
4327
4328fn validate_rst_stream_payload(stream_id: u32, payload: &[u8]) -> Result<(), http2::Http2Error> {
4329 if stream_id == 0 {
4330 return Err(http2::Http2Error::Protocol(
4331 "RST_STREAM must not be on stream 0",
4332 ));
4333 }
4334 if payload.len() != 4 {
4335 return Err(http2::Http2Error::Protocol(
4336 "RST_STREAM payload must be 4 bytes",
4337 ));
4338 }
4339 Ok(())
4340}
4341
4342fn validate_priority_payload(stream_id: u32, payload: &[u8]) -> Result<(), http2::Http2Error> {
4343 if stream_id == 0 {
4344 return Err(http2::Http2Error::Protocol(
4345 "PRIORITY must not be on stream 0",
4346 ));
4347 }
4348
4349 if payload.len() != 5 {
4350 return Err(http2::Http2Error::Protocol(
4351 "PRIORITY payload must be 5 bytes",
4352 ));
4353 }
4354
4355 let dependency_raw = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
4356 let dependency_stream_id = dependency_raw & 0x7FFF_FFFF;
4357 if dependency_stream_id == stream_id {
4358 return Err(http2::Http2Error::Protocol(
4359 "PRIORITY stream dependency must not reference itself",
4360 ));
4361 }
4362
4363 Ok(())
4364}
4365
4366fn extract_header_block_fragment(
4367 flags: u8,
4368 payload: &[u8],
4369) -> Result<(bool, Vec<u8>), http2::Http2Error> {
4370 const FLAG_END_STREAM: u8 = 0x1;
4371 const FLAG_PADDED: u8 = 0x8;
4372 const FLAG_PRIORITY: u8 = 0x20;
4373
4374 let end_stream = (flags & FLAG_END_STREAM) != 0;
4375 let mut idx = 0usize;
4376
4377 let pad_len = if (flags & FLAG_PADDED) != 0 {
4378 if payload.is_empty() {
4379 return Err(http2::Http2Error::Protocol(
4380 "HEADERS PADDED set with empty payload",
4381 ));
4382 }
4383 let v = payload[0] as usize;
4384 idx += 1;
4385 v
4386 } else {
4387 0
4388 };
4389
4390 if (flags & FLAG_PRIORITY) != 0 {
4391 if payload.len().saturating_sub(idx) < 5 {
4393 return Err(http2::Http2Error::Protocol(
4394 "HEADERS PRIORITY set but too short",
4395 ));
4396 }
4397 idx += 5;
4398 }
4399
4400 if payload.len() < idx {
4401 return Err(http2::Http2Error::Protocol("invalid HEADERS payload"));
4402 }
4403 let frag = &payload[idx..];
4404 if frag.len() < pad_len {
4405 return Err(http2::Http2Error::Protocol(
4406 "invalid HEADERS padding length",
4407 ));
4408 }
4409 let end = frag.len() - pad_len;
4410 Ok((end_stream, frag[..end].to_vec()))
4411}
4412
4413fn extract_data_payload(flags: u8, payload: &[u8]) -> Result<(&[u8], bool), http2::Http2Error> {
4414 const FLAG_END_STREAM: u8 = 0x1;
4415 const FLAG_PADDED: u8 = 0x8;
4416
4417 let end_stream = (flags & FLAG_END_STREAM) != 0;
4418 if (flags & FLAG_PADDED) == 0 {
4419 return Ok((payload, end_stream));
4420 }
4421 if payload.is_empty() {
4422 return Err(http2::Http2Error::Protocol(
4423 "DATA PADDED set with empty payload",
4424 ));
4425 }
4426 let pad_len = payload[0] as usize;
4427 let data = &payload[1..];
4428 if data.len() < pad_len {
4429 return Err(http2::Http2Error::Protocol("invalid DATA padding length"));
4430 }
4431 Ok((&data[..data.len() - pad_len], end_stream))
4432}
4433
4434fn request_from_h2_headers(headers: http2::HeaderList) -> Result<Request, http2::Http2Error> {
4435 let mut method: Option<fastapi_core::Method> = None;
4436 let mut path: Option<String> = None;
4437 let mut authority: Option<Vec<u8>> = None;
4438 let mut saw_regular_headers = false;
4439
4440 let mut req_headers: Vec<(String, Vec<u8>)> = Vec::new();
4441
4442 for (name, value) in headers {
4443 if name.starts_with(b":") {
4444 if saw_regular_headers {
4445 return Err(http2::Http2Error::Protocol(
4446 "pseudo-headers must appear before regular headers",
4447 ));
4448 }
4449 match name.as_slice() {
4450 b":method" => {
4451 if method.is_some() {
4452 return Err(http2::Http2Error::Protocol(
4453 "duplicate :method pseudo-header",
4454 ));
4455 }
4456 method = Some(
4457 fastapi_core::Method::from_bytes(&value)
4458 .ok_or(http2::Http2Error::Protocol("invalid :method"))?,
4459 );
4460 }
4461 b":path" => {
4462 if path.is_some() {
4463 return Err(http2::Http2Error::Protocol("duplicate :path pseudo-header"));
4464 }
4465 let s = std::str::from_utf8(&value)
4466 .map_err(|_| http2::Http2Error::Protocol("non-utf8 :path"))?;
4467 path = Some(s.to_string());
4468 }
4469 b":authority" => {
4470 if authority.is_some() {
4471 return Err(http2::Http2Error::Protocol(
4472 "duplicate :authority pseudo-header",
4473 ));
4474 }
4475 authority = Some(value);
4476 }
4477 b":scheme" => {}
4478 _ => return Err(http2::Http2Error::Protocol("unknown pseudo-header")),
4479 }
4480 continue;
4481 }
4482
4483 saw_regular_headers = true;
4484 let n = std::str::from_utf8(&name)
4485 .map_err(|_| http2::Http2Error::Protocol("non-utf8 header name"))?;
4486 req_headers.push((n.to_string(), value));
4487 }
4488
4489 let method = method.ok_or(http2::Http2Error::Protocol("missing :method"))?;
4490 let raw_path = path.ok_or(http2::Http2Error::Protocol("missing :path"))?;
4491 let (path_only, query) = match raw_path.split_once('?') {
4492 Some((p, q)) => (p.to_string(), Some(q.to_string())),
4493 None => (raw_path, None),
4494 };
4495
4496 let mut req = Request::with_version(method, path_only, fastapi_core::HttpVersion::Http2);
4497 req.set_query(query);
4498
4499 if let Some(auth) = authority {
4500 req.headers_mut().insert("host", auth);
4501 }
4502
4503 for (n, v) in req_headers {
4504 req.headers_mut().insert(n, v);
4505 }
4506
4507 Ok(req)
4508}
4509
/// Reports whether `name` is a connection-specific header that must not appear
/// in an HTTP/2 message.
///
/// The comparison is ASCII case-insensitive and checks the name only; the
/// special case `te: trailers` is not handled here.
fn is_h2_forbidden_header_name(name: &str) -> bool {
    const CONNECTION_SPECIFIC: [&str; 6] = [
        "connection",
        "keep-alive",
        "proxy-connection",
        "transfer-encoding",
        "upgrade",
        "te",
    ];
    CONNECTION_SPECIFIC
        .iter()
        .any(|forbidden| name.eq_ignore_ascii_case(forbidden))
}
4520
4521async fn write_raw_response(stream: &mut TcpStream, bytes: &[u8]) -> io::Result<()> {
4525 use std::future::poll_fn;
4526 write_all(stream, bytes).await?;
4527 poll_fn(|cx| Pin::new(&mut *stream).poll_flush(cx)).await?;
4528 Ok(())
4529}
4530
4531pub async fn write_response(stream: &mut TcpStream, response: ResponseWrite) -> io::Result<()> {
4536 use std::future::poll_fn;
4537
4538 match response {
4539 ResponseWrite::Full(bytes) => {
4540 write_all(stream, &bytes).await?;
4541 }
4542 ResponseWrite::Stream(mut encoder) => {
4543 loop {
4545 let chunk = poll_fn(|cx| Pin::new(&mut encoder).poll_next(cx)).await;
4546 match chunk {
4547 Some(bytes) => {
4548 write_all(stream, &bytes).await?;
4549 }
4550 None => break,
4551 }
4552 }
4553 }
4554 }
4555
4556 poll_fn(|cx| Pin::new(&mut *stream).poll_flush(cx)).await?;
4558
4559 Ok(())
4560}
4561
4562pub async fn write_all(stream: &mut TcpStream, mut buf: &[u8]) -> io::Result<()> {
4564 use std::future::poll_fn;
4565
4566 while !buf.is_empty() {
4567 let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, buf)).await?;
4568 if n == 0 {
4569 return Err(io::Error::new(
4570 io::ErrorKind::WriteZero,
4571 "failed to write whole buffer",
4572 ));
4573 }
4574 buf = &buf[n..];
4575 }
4576 Ok(())
4577}
4578
4579pub struct Server {
4591 parser: Parser,
4592}
4593
4594impl Server {
4595 #[must_use]
4597 pub fn new() -> Self {
4598 Self {
4599 parser: Parser::new(),
4600 }
4601 }
4602
4603 pub fn parse_request(&self, bytes: &[u8]) -> Result<Request, ParseError> {
4609 self.parser.parse(bytes)
4610 }
4611
4612 #[must_use]
4614 pub fn write_response(&self, response: Response) -> ResponseWrite {
4615 let mut writer = ResponseWriter::new();
4616 writer.write(response)
4617 }
4618}
4619
4620impl Default for Server {
4621 fn default() -> Self {
4622 Self::new()
4623 }
4624}
4625
4626#[cfg(test)]
4627mod tests {
4628 use super::*;
4629 use std::future::Future;
4630
4631 fn block_on<F: Future>(f: F) -> F::Output {
4632 let rt = asupersync::runtime::RuntimeBuilder::current_thread()
4633 .build()
4634 .expect("test runtime must build");
4635 rt.block_on(f)
4636 }
4637
4638 #[test]
4639 fn server_config_builder() {
4640 let config = ServerConfig::new("0.0.0.0:3000")
4641 .with_request_timeout_secs(60)
4642 .with_max_connections(1000)
4643 .with_tcp_nodelay(false)
4644 .with_allowed_hosts(["example.com", "api.example.com"])
4645 .with_trust_x_forwarded_host(true);
4646
4647 assert_eq!(config.bind_addr, "0.0.0.0:3000");
4648 assert_eq!(config.request_timeout, Time::from_secs(60));
4649 assert_eq!(config.max_connections, 1000);
4650 assert!(!config.tcp_nodelay);
4651 assert_eq!(config.allowed_hosts.len(), 2);
4652 assert!(config.trust_x_forwarded_host);
4653 }
4654
4655 #[test]
4656 fn server_config_defaults() {
4657 let config = ServerConfig::default();
4658 assert_eq!(config.bind_addr, "127.0.0.1:8080");
4659 assert_eq!(
4660 config.request_timeout,
4661 Time::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)
4662 );
4663 assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS);
4664 assert!(config.tcp_nodelay);
4665 assert!(config.allowed_hosts.is_empty());
4666 assert!(!config.trust_x_forwarded_host);
4667 }
4668
4669 #[test]
4670 fn tcp_server_creates_request_ids() {
4671 let server = TcpServer::default();
4672 let id1 = server.next_request_id();
4673 let id2 = server.next_request_id();
4674 let id3 = server.next_request_id();
4675
4676 assert_eq!(id1, 0);
4677 assert_eq!(id2, 1);
4678 assert_eq!(id3, 2);
4679 }
4680
4681 #[test]
4682 fn server_error_display() {
4683 let io_err = ServerError::Io(io::Error::new(io::ErrorKind::AddrInUse, "address in use"));
4684 assert!(io_err.to_string().contains("IO error"));
4685
4686 let shutdown_err = ServerError::Shutdown;
4687 assert_eq!(shutdown_err.to_string(), "Server shutdown");
4688
4689 let limit_err = ServerError::ConnectionLimitReached;
4690 assert_eq!(limit_err.to_string(), "Connection limit reached");
4691 }
4692
4693 #[test]
4694 fn sync_server_parses_request() {
4695 let server = Server::new();
4696 let request = b"GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n";
4697 let result = server.parse_request(request);
4698 assert!(result.is_ok());
4699 }
4700
4701 #[test]
4702 fn window_update_payload_validation_accepts_non_zero_increment() {
4703 let payload = 1u32.to_be_bytes();
4704 assert!(validate_window_update_payload(&payload).is_ok());
4705 }
4706
4707 #[test]
4708 fn window_update_payload_validation_rejects_bad_length() {
4709 let err = validate_window_update_payload(&[0, 0, 0]).unwrap_err();
4710 assert!(
4711 err.to_string()
4712 .contains("WINDOW_UPDATE payload must be 4 bytes")
4713 );
4714 }
4715
4716 #[test]
4717 fn window_update_payload_validation_rejects_zero_increment() {
4718 let payload = 0u32.to_be_bytes();
4719 let err = validate_window_update_payload(&payload).unwrap_err();
4720 assert!(
4721 err.to_string()
4722 .contains("WINDOW_UPDATE increment must be non-zero")
4723 );
4724 }
4725
4726 #[test]
4727 fn settings_frame_validation_accepts_non_ack_payload() {
4728 let payload = [0u8; 6];
4729 let is_ack = validate_settings_frame(0, 0, &payload).unwrap();
4730 assert!(!is_ack);
4731 }
4732
4733 #[test]
4734 fn settings_frame_validation_accepts_empty_ack_payload() {
4735 let is_ack = validate_settings_frame(0, 0x1, &[]).unwrap();
4736 assert!(is_ack);
4737 }
4738
4739 #[test]
4740 fn settings_frame_validation_rejects_non_zero_stream() {
4741 let err = validate_settings_frame(1, 0, &[]).unwrap_err();
4742 assert!(err.to_string().contains("SETTINGS must be on stream 0"));
4743 }
4744
4745 #[test]
4746 fn settings_frame_validation_rejects_non_empty_ack_payload() {
4747 let err = validate_settings_frame(0, 0x1, &[0, 0, 0, 0, 0, 0]).unwrap_err();
4748 assert!(
4749 err.to_string()
4750 .contains("SETTINGS ACK frame must have empty payload")
4751 );
4752 }
4753
4754 #[test]
4755 fn settings_enable_push_accepts_zero() {
4756 let payload = [0x00, 0x02, 0x00, 0x00, 0x00, 0x00];
4758 let mut hpack = http2::HpackDecoder::new();
4759 let mut max_frame_size = 16384u32;
4760 assert!(apply_http2_settings(&mut hpack, &mut max_frame_size, &payload).is_ok());
4761 }
4762
4763 #[test]
4764 fn settings_enable_push_accepts_one() {
4765 let payload = [0x00, 0x02, 0x00, 0x00, 0x00, 0x01];
4766 let mut hpack = http2::HpackDecoder::new();
4767 let mut max_frame_size = 16384u32;
4768 assert!(apply_http2_settings(&mut hpack, &mut max_frame_size, &payload).is_ok());
4769 }
4770
4771 #[test]
4772 fn settings_enable_push_rejects_invalid_value() {
4773 let payload = [0x00, 0x02, 0x00, 0x00, 0x00, 0x02];
4774 let mut hpack = http2::HpackDecoder::new();
4775 let mut max_frame_size = 16384u32;
4776 let err = apply_http2_settings(&mut hpack, &mut max_frame_size, &payload).unwrap_err();
4777 assert!(
4778 err.to_string()
4779 .contains("SETTINGS_ENABLE_PUSH must be 0 or 1")
4780 );
4781 }
4782
4783 #[test]
4784 fn rst_stream_payload_validation_accepts_valid_payload() {
4785 let payload = 8u32.to_be_bytes();
4786 assert!(validate_rst_stream_payload(1, &payload).is_ok());
4787 }
4788
4789 #[test]
4790 fn rst_stream_payload_validation_rejects_stream_zero() {
4791 let payload = 8u32.to_be_bytes();
4792 let err = validate_rst_stream_payload(0, &payload).unwrap_err();
4793 assert!(
4794 err.to_string()
4795 .contains("RST_STREAM must not be on stream 0")
4796 );
4797 }
4798
4799 #[test]
4800 fn rst_stream_payload_validation_rejects_bad_length() {
4801 let err = validate_rst_stream_payload(1, &[0, 0, 0]).unwrap_err();
4802 assert!(
4803 err.to_string()
4804 .contains("RST_STREAM payload must be 4 bytes")
4805 );
4806 }
4807
4808 #[test]
4809 fn priority_payload_validation_accepts_valid_priority() {
4810 let payload = [0, 0, 0, 0, 16];
4811 assert!(validate_priority_payload(1, &payload).is_ok());
4812 }
4813
4814 #[test]
4815 fn priority_payload_validation_rejects_stream_zero() {
4816 let payload = [0, 0, 0, 0, 16];
4817 let err = validate_priority_payload(0, &payload).unwrap_err();
4818 assert!(err.to_string().contains("PRIORITY must not be on stream 0"));
4819 }
4820
4821 #[test]
4822 fn priority_payload_validation_rejects_bad_length() {
4823 let err = validate_priority_payload(1, &[0, 0, 0, 0]).unwrap_err();
4824 assert!(err.to_string().contains("PRIORITY payload must be 5 bytes"));
4825 }
4826
4827 #[test]
4828 fn priority_payload_validation_rejects_self_dependency() {
4829 let payload = 1u32.to_be_bytes();
4830 let mut with_weight = [0u8; 5];
4831 with_weight[..4].copy_from_slice(&payload);
4832 with_weight[4] = 16;
4833 let err = validate_priority_payload(1, &with_weight).unwrap_err();
4834 assert!(
4835 err.to_string()
4836 .contains("PRIORITY stream dependency must not reference itself")
4837 );
4838 }
4839
4840 #[test]
4841 fn goaway_payload_validation_accepts_valid_payload() {
4842 let payload = goaway_payload(0, 0);
4843 assert!(validate_goaway_payload(&payload).is_ok());
4844 }
4845
4846 #[test]
4847 fn goaway_payload_validation_accepts_payload_with_debug_data() {
4848 let mut payload = Vec::from(goaway_payload(1, 0).as_slice());
4849 payload.extend_from_slice(b"debug info");
4850 assert!(validate_goaway_payload(&payload).is_ok());
4851 }
4852
4853 #[test]
4854 fn goaway_payload_validation_rejects_short_payload() {
4855 let err = validate_goaway_payload(&[0, 0, 0]).unwrap_err();
4856 assert!(
4857 err.to_string()
4858 .contains("GOAWAY payload must be at least 8 bytes")
4859 );
4860 }
4861
4862 #[test]
4863 fn goaway_payload_validation_rejects_empty() {
4864 let err = validate_goaway_payload(&[]).unwrap_err();
4865 assert!(
4866 err.to_string()
4867 .contains("GOAWAY payload must be at least 8 bytes")
4868 );
4869 }
4870
4871 fn h2_test_frame(
4872 frame_type: http2::FrameType,
4873 stream_id: u32,
4874 payload: Vec<u8>,
4875 ) -> http2::Frame {
4876 http2::Frame {
4877 header: http2::FrameHeader {
4878 length: payload.len() as u32,
4879 frame_type: frame_type as u8,
4880 flags: 0,
4881 stream_id,
4882 },
4883 payload,
4884 }
4885 }
4886
4887 #[test]
4888 fn h2_idle_frame_rejects_data_outside_request_stream() {
4889 let frame = h2_test_frame(http2::FrameType::Data, 1, Vec::new());
4890 let err = handle_h2_idle_frame(&frame).unwrap_err();
4891 assert!(
4892 err.to_string()
4893 .contains("unexpected DATA frame outside active request stream")
4894 );
4895 }
4896
4897 #[test]
4898 fn h2_idle_frame_rejects_continuation_outside_header_block() {
4899 let frame = h2_test_frame(http2::FrameType::Continuation, 1, Vec::new());
4900 let err = handle_h2_idle_frame(&frame).unwrap_err();
4901 assert!(
4902 err.to_string()
4903 .contains("unexpected CONTINUATION frame outside header block")
4904 );
4905 }
4906
4907 #[test]
4908 fn h2_idle_frame_validates_rst_stream_payload() {
4909 let invalid = h2_test_frame(http2::FrameType::RstStream, 0, 8u32.to_be_bytes().to_vec());
4910 let err = handle_h2_idle_frame(&invalid).unwrap_err();
4911 assert!(
4912 err.to_string()
4913 .contains("RST_STREAM must not be on stream 0")
4914 );
4915
4916 let valid = h2_test_frame(http2::FrameType::RstStream, 3, 8u32.to_be_bytes().to_vec());
4917 assert!(handle_h2_idle_frame(&valid).is_ok());
4918 }
4919
4920 #[test]
4921 fn h2_idle_frame_validates_priority_payload() {
4922 let invalid = h2_test_frame(http2::FrameType::Priority, 0, vec![0, 0, 0, 0, 16]);
4923 let err = handle_h2_idle_frame(&invalid).unwrap_err();
4924 assert!(err.to_string().contains("PRIORITY must not be on stream 0"));
4925
4926 let valid = h2_test_frame(http2::FrameType::Priority, 1, vec![0, 0, 0, 0, 16]);
4927 assert!(handle_h2_idle_frame(&valid).is_ok());
4928 }
4929
4930 #[test]
4931 fn max_header_block_size_is_128k() {
4932 assert_eq!(MAX_HEADER_BLOCK_SIZE, 128 * 1024);
4933 }
4934
4935 #[test]
4936 fn server_settings_payload_advertises_max_concurrent_streams() {
4937 assert_eq!(SERVER_SETTINGS_PAYLOAD.len(), 6);
4939 assert_eq!(SERVER_SETTINGS_PAYLOAD[0..2], [0x00, 0x03]);
4940 assert_eq!(
4941 u32::from_be_bytes([
4942 SERVER_SETTINGS_PAYLOAD[2],
4943 SERVER_SETTINGS_PAYLOAD[3],
4944 SERVER_SETTINGS_PAYLOAD[4],
4945 SERVER_SETTINGS_PAYLOAD[5],
4946 ]),
4947 1
4948 );
4949 }
4950
4951 #[test]
4952 fn max_hpack_table_size_is_64k() {
4953 assert_eq!(MAX_HPACK_TABLE_SIZE, 64 * 1024);
4954 }
4955
4956 #[test]
4957 fn h2_send_window_update_ignores_other_streams() {
4958 let mut flow_control = http2::H2FlowControl::new();
4959 let mut stream_window = 123i64;
4960 let payload = 7u32.to_be_bytes();
4961
4962 apply_peer_window_update_for_send(&mut flow_control, &mut stream_window, 3, 5, &payload)
4963 .expect("window update on different stream should be ignored");
4964
4965 assert_eq!(stream_window, 123);
4966 }
4967
4968 #[test]
4969 fn h2_send_window_update_applies_connection_and_current_stream() {
4970 let mut flow_control = http2::H2FlowControl::new();
4971 let mut stream_window = 10i64;
4972
4973 let conn_before = flow_control.send_conn_window();
4974 let conn_payload = 11u32.to_be_bytes();
4975 apply_peer_window_update_for_send(
4976 &mut flow_control,
4977 &mut stream_window,
4978 9,
4979 0,
4980 &conn_payload,
4981 )
4982 .expect("connection window update should be applied");
4983 assert_eq!(flow_control.send_conn_window(), conn_before + 11);
4984 assert_eq!(stream_window, 10);
4985
4986 let stream_payload = 13u32.to_be_bytes();
4987 apply_peer_window_update_for_send(
4988 &mut flow_control,
4989 &mut stream_window,
4990 9,
4991 9,
4992 &stream_payload,
4993 )
4994 .expect("stream window update should be applied to current stream");
4995 assert_eq!(stream_window, 23);
4996 }
4997
4998 #[test]
4999 fn h2_send_settings_updates_current_stream_window_delta() {
5000 let mut flow_control = http2::H2FlowControl::new();
5001 let mut stream_window = 50i64;
5002 let mut peer_max_frame_size = 16_384u32;
5003
5004 let payload = [0x00, 0x03, 0x00, 0x01, 0x11, 0x70]; apply_peer_settings_for_send(
5006 &mut flow_control,
5007 &mut stream_window,
5008 &mut peer_max_frame_size,
5009 &payload,
5010 )
5011 .expect("valid SETTINGS_INITIAL_WINDOW_SIZE should apply");
5012
5013 assert_eq!(flow_control.peer_initial_window_size(), 70_000);
5014 assert_eq!(stream_window, 4_515); assert_eq!(peer_max_frame_size, 16_384);
5016 }
5017
5018 #[test]
5019 fn h2_send_settings_rejects_invalid_payload_len() {
5020 let mut flow_control = http2::H2FlowControl::new();
5021 let mut stream_window = 0i64;
5022 let mut peer_max_frame_size = 16_384u32;
5023 let err = apply_peer_settings_for_send(
5024 &mut flow_control,
5025 &mut stream_window,
5026 &mut peer_max_frame_size,
5027 &[0, 1, 2],
5028 )
5029 .unwrap_err();
5030 assert!(
5031 err.to_string()
5032 .contains("SETTINGS length must be a multiple of 6")
5033 );
5034 }
5035
5036 #[test]
5037 fn h2_send_settings_rejects_initial_window_too_large() {
5038 let mut flow_control = http2::H2FlowControl::new();
5039 let mut stream_window = 0i64;
5040 let mut peer_max_frame_size = 16_384u32;
5041 let payload = [0x00, 0x03, 0x80, 0x00, 0x00, 0x00]; let err = apply_peer_settings_for_send(
5043 &mut flow_control,
5044 &mut stream_window,
5045 &mut peer_max_frame_size,
5046 &payload,
5047 )
5048 .unwrap_err();
5049 assert!(
5050 err.to_string()
5051 .contains("SETTINGS_INITIAL_WINDOW_SIZE exceeds maximum")
5052 );
5053 }
5054
5055 #[test]
5056 fn h2_send_settings_window_delta_overflow_is_flow_control_error() {
5057 let mut flow_control = http2::H2FlowControl::new();
5058 let mut peer_max_frame_size = 16_384u32;
5059 let mut stream_window: i64 = 0x7FFF_FFFF - 10;
5061 let new_initial: u32 = 0x7FFF_FFFF;
5065 let payload = [
5066 0x00,
5067 0x03,
5068 new_initial.to_be_bytes()[0],
5069 new_initial.to_be_bytes()[1],
5070 new_initial.to_be_bytes()[2],
5071 new_initial.to_be_bytes()[3],
5072 ];
5073 let err = apply_peer_settings_for_send(
5074 &mut flow_control,
5075 &mut stream_window,
5076 &mut peer_max_frame_size,
5077 &payload,
5078 )
5079 .unwrap_err();
5080 assert!(err.to_string().contains("stream window to exceed 2^31-1"));
5081 }
5082
5083 #[test]
5084 fn h2_send_settings_updates_peer_max_frame_size() {
5085 let mut flow_control = http2::H2FlowControl::new();
5086 let mut stream_window = 0i64;
5087 let mut peer_max_frame_size = 65_535u32;
5088 let payload = [0x00, 0x05, 0x00, 0x00, 0x40, 0x00]; apply_peer_settings_for_send(
5091 &mut flow_control,
5092 &mut stream_window,
5093 &mut peer_max_frame_size,
5094 &payload,
5095 )
5096 .expect("valid SETTINGS_MAX_FRAME_SIZE should apply");
5097
5098 assert_eq!(peer_max_frame_size, 16_384);
5099 }
5100
5101 #[test]
5102 fn h2_send_settings_rejects_invalid_max_frame_size() {
5103 let mut flow_control = http2::H2FlowControl::new();
5104 let mut stream_window = 0i64;
5105 let mut peer_max_frame_size = 16_384u32;
5106 let payload = [0x00, 0x05, 0x00, 0x00, 0x3F, 0xFF]; let err = apply_peer_settings_for_send(
5109 &mut flow_control,
5110 &mut stream_window,
5111 &mut peer_max_frame_size,
5112 &payload,
5113 )
5114 .unwrap_err();
5115 assert!(err.to_string().contains("invalid SETTINGS_MAX_FRAME_SIZE"));
5116 }
5117
5118 #[test]
5119 fn request_from_h2_headers_rejects_unknown_pseudo_header() {
5120 let headers: http2::HeaderList = vec![
5121 (b":method".to_vec(), b"GET".to_vec()),
5122 (b":path".to_vec(), b"/".to_vec()),
5123 (b":weird".to_vec(), b"value".to_vec()),
5124 ];
5125 let err = request_from_h2_headers(headers).unwrap_err();
5126 assert!(err.to_string().contains("unknown pseudo-header"));
5127 }
5128
5129 #[test]
5130 fn request_from_h2_headers_rejects_pseudo_after_regular_header() {
5131 let headers: http2::HeaderList = vec![
5132 (b":method".to_vec(), b"GET".to_vec()),
5133 (b":path".to_vec(), b"/".to_vec()),
5134 (b"x-test".to_vec(), b"ok".to_vec()),
5135 (b":authority".to_vec(), b"example.com".to_vec()),
5136 ];
5137 let err = request_from_h2_headers(headers).unwrap_err();
5138 assert!(
5139 err.to_string()
5140 .contains("pseudo-headers must appear before regular headers")
5141 );
5142 }
5143
5144 #[test]
5149 fn host_validation_missing_host_rejected() {
5150 let config = ServerConfig::default();
5151 let request = Request::new(fastapi_core::Method::Get, "/");
5152 let err = validate_host_header(&request, &config).unwrap_err();
5153 assert_eq!(err.kind, HostValidationErrorKind::Missing);
5154 assert_eq!(err.response().status().as_u16(), 400);
5155 }
5156
5157 #[test]
5158 fn host_validation_allows_configured_host() {
5159 let config = ServerConfig::default().with_allowed_hosts(["example.com"]);
5160 let mut request = Request::new(fastapi_core::Method::Get, "/");
5161 request
5162 .headers_mut()
5163 .insert("Host".to_string(), b"example.com".to_vec());
5164 assert!(validate_host_header(&request, &config).is_ok());
5165 }
5166
5167 #[test]
5168 fn host_validation_rejects_disallowed_host() {
5169 let config = ServerConfig::default().with_allowed_hosts(["example.com"]);
5170 let mut request = Request::new(fastapi_core::Method::Get, "/");
5171 request
5172 .headers_mut()
5173 .insert("Host".to_string(), b"evil.com".to_vec());
5174 let err = validate_host_header(&request, &config).unwrap_err();
5175 assert_eq!(err.kind, HostValidationErrorKind::NotAllowed);
5176 }
5177
5178 #[test]
5179 fn host_validation_wildcard_allows_subdomains_only() {
5180 let config = ServerConfig::default().with_allowed_hosts(["*.example.com"]);
5181 let mut request = Request::new(fastapi_core::Method::Get, "/");
5182 request
5183 .headers_mut()
5184 .insert("Host".to_string(), b"api.example.com".to_vec());
5185 assert!(validate_host_header(&request, &config).is_ok());
5186
5187 let mut request = Request::new(fastapi_core::Method::Get, "/");
5188 request
5189 .headers_mut()
5190 .insert("Host".to_string(), b"example.com".to_vec());
5191 let err = validate_host_header(&request, &config).unwrap_err();
5192 assert_eq!(err.kind, HostValidationErrorKind::NotAllowed);
5193 }
5194
5195 #[test]
5196 fn host_validation_uses_x_forwarded_host_when_trusted() {
5197 let config = ServerConfig::default()
5198 .with_allowed_hosts(["example.com"])
5199 .with_trust_x_forwarded_host(true);
5200 let mut request = Request::new(fastapi_core::Method::Get, "/");
5201 request
5202 .headers_mut()
5203 .insert("Host".to_string(), b"internal.local".to_vec());
5204 request
5205 .headers_mut()
5206 .insert("X-Forwarded-Host".to_string(), b"example.com".to_vec());
5207 assert!(validate_host_header(&request, &config).is_ok());
5208 }
5209
5210 #[test]
5211 fn host_validation_rejects_invalid_host_value() {
5212 let config = ServerConfig::default();
5213 let mut request = Request::new(fastapi_core::Method::Get, "/");
5214 request
5215 .headers_mut()
5216 .insert("Host".to_string(), b"bad host".to_vec());
5217 let err = validate_host_header(&request, &config).unwrap_err();
5218 assert_eq!(err.kind, HostValidationErrorKind::Invalid);
5219 }
5220
5221 #[test]
5226 fn websocket_upgrade_detection_accepts_token_lists_case_insensitive() {
5227 let mut request = Request::new(fastapi_core::Method::Get, "/ws");
5228 request
5229 .headers_mut()
5230 .insert("Upgrade".to_string(), b"h2c, WebSocket".to_vec());
5231 request
5232 .headers_mut()
5233 .insert("Connection".to_string(), b"keep-alive, UPGRADE".to_vec());
5234
5235 assert!(is_websocket_upgrade_request(&request));
5236 }
5237
5238 #[test]
5239 fn websocket_upgrade_detection_rejects_missing_connection_upgrade_token() {
5240 let mut request = Request::new(fastapi_core::Method::Get, "/ws");
5241 request
5242 .headers_mut()
5243 .insert("Upgrade".to_string(), b"websocket".to_vec());
5244 request
5245 .headers_mut()
5246 .insert("Connection".to_string(), b"keep-alive".to_vec());
5247
5248 assert!(!is_websocket_upgrade_request(&request));
5249 }
5250
5251 #[test]
5252 fn websocket_upgrade_detection_rejects_non_get_method() {
5253 let mut request = Request::new(fastapi_core::Method::Post, "/ws");
5254 request
5255 .headers_mut()
5256 .insert("Upgrade".to_string(), b"websocket".to_vec());
5257 request
5258 .headers_mut()
5259 .insert("Connection".to_string(), b"upgrade".to_vec());
5260
5261 assert!(!is_websocket_upgrade_request(&request));
5262 }
5263
5264 #[test]
5269 fn keep_alive_default_http11() {
5270 let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
5272 request
5273 .headers_mut()
5274 .insert("Host".to_string(), b"example.com".to_vec());
5275 assert!(should_keep_alive(&request));
5276 }
5277
5278 #[test]
5279 fn keep_alive_explicit_keep_alive() {
5280 let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
5281 request
5282 .headers_mut()
5283 .insert("Connection".to_string(), b"keep-alive".to_vec());
5284 assert!(should_keep_alive(&request));
5285 }
5286
5287 #[test]
5288 fn keep_alive_connection_close() {
5289 let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
5290 request
5291 .headers_mut()
5292 .insert("Connection".to_string(), b"close".to_vec());
5293 assert!(!should_keep_alive(&request));
5294 }
5295
5296 #[test]
5297 fn keep_alive_connection_close_case_insensitive() {
5298 let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
5299 request
5300 .headers_mut()
5301 .insert("Connection".to_string(), b"CLOSE".to_vec());
5302 assert!(!should_keep_alive(&request));
5303 }
5304
5305 #[test]
5306 fn keep_alive_multiple_values() {
5307 let mut request = Request::new(fastapi_core::Method::Get, "/path".to_string());
5308 request
5309 .headers_mut()
5310 .insert("Connection".to_string(), b"keep-alive, upgrade".to_vec());
5311 assert!(should_keep_alive(&request));
5312 }
5313
5314 #[test]
5319 fn timeout_budget_created_with_config_deadline() {
5320 let config = ServerConfig::new("127.0.0.1:8080").with_request_timeout_secs(45);
5321 let budget = Budget::new().with_deadline(config.request_timeout);
5322 assert_eq!(budget.deadline, Some(Time::from_secs(45)));
5323 }
5324
5325 #[test]
5326 fn timeout_duration_conversion_from_time() {
5327 let timeout = Time::from_secs(30);
5328 let duration = Duration::from_nanos(timeout.as_nanos());
5329 assert_eq!(duration, Duration::from_secs(30));
5330 }
5331
5332 #[test]
5333 fn timeout_duration_conversion_from_time_millis() {
5334 let timeout = Time::from_millis(1500);
5335 let duration = Duration::from_nanos(timeout.as_nanos());
5336 assert_eq!(duration, Duration::from_millis(1500));
5337 }
5338
5339 #[test]
5340 fn gateway_timeout_response_has_correct_status() {
5341 let response = Response::with_status(StatusCode::GATEWAY_TIMEOUT);
5342 assert_eq!(response.status().as_u16(), 504);
5343 }
5344
5345 #[test]
5346 fn gateway_timeout_response_with_body() {
5347 let response = Response::with_status(StatusCode::GATEWAY_TIMEOUT).body(
5348 fastapi_core::ResponseBody::Bytes(b"Request timed out".to_vec()),
5349 );
5350 assert_eq!(response.status().as_u16(), 504);
5351 assert!(response.body_ref().len() > 0);
5353 }
5354
5355 #[test]
5356 fn elapsed_time_check_logic() {
5357 let start = Instant::now();
5359 let timeout_duration = Duration::from_millis(10);
5360
5361 assert!(start.elapsed() <= timeout_duration);
5363
5364 std::thread::sleep(Duration::from_millis(20));
5366
5367 assert!(start.elapsed() > timeout_duration);
5369 }
5370
5371 #[test]
5376 fn connection_counter_starts_at_zero() {
5377 let server = TcpServer::default();
5378 assert_eq!(server.current_connections(), 0);
5379 }
5380
5381 #[test]
5382 fn try_acquire_connection_unlimited() {
5383 let server = TcpServer::default();
5385 assert_eq!(server.config().max_connections, 0);
5386
5387 for _ in 0..100 {
5389 assert!(server.try_acquire_connection());
5390 }
5391 assert_eq!(server.current_connections(), 100);
5392
5393 for _ in 0..100 {
5395 server.release_connection();
5396 }
5397 assert_eq!(server.current_connections(), 0);
5398 }
5399
5400 #[test]
5401 fn try_acquire_connection_with_limit() {
5402 let config = ServerConfig::new("127.0.0.1:8080").with_max_connections(5);
5403 let server = TcpServer::new(config);
5404
5405 for i in 0..5 {
5407 assert!(
5408 server.try_acquire_connection(),
5409 "Should acquire connection {i}"
5410 );
5411 }
5412 assert_eq!(server.current_connections(), 5);
5413
5414 assert!(!server.try_acquire_connection());
5416 assert_eq!(server.current_connections(), 5);
5417
5418 server.release_connection();
5420 assert_eq!(server.current_connections(), 4);
5421
5422 assert!(server.try_acquire_connection());
5424 assert_eq!(server.current_connections(), 5);
5425 }
5426
5427 #[test]
5428 fn try_acquire_connection_single_connection_limit() {
5429 let config = ServerConfig::new("127.0.0.1:8080").with_max_connections(1);
5430 let server = TcpServer::new(config);
5431
5432 assert!(server.try_acquire_connection());
5434 assert_eq!(server.current_connections(), 1);
5435
5436 assert!(!server.try_acquire_connection());
5438 assert_eq!(server.current_connections(), 1);
5439
5440 server.release_connection();
5442 assert!(server.try_acquire_connection());
5443 }
5444
5445 #[test]
5446 fn service_unavailable_response_has_correct_status() {
5447 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE);
5448 assert_eq!(response.status().as_u16(), 503);
5449 }
5450
5451 #[test]
5452 fn service_unavailable_response_with_body() {
5453 let response = Response::with_status(StatusCode::SERVICE_UNAVAILABLE)
5454 .header("connection", b"close".to_vec())
5455 .body(fastapi_core::ResponseBody::Bytes(
5456 b"503 Service Unavailable: connection limit reached".to_vec(),
5457 ));
5458 assert_eq!(response.status().as_u16(), 503);
5459 assert!(response.body_ref().len() > 0);
5460 }
5461
5462 #[test]
5463 fn config_max_connections_default_is_zero() {
5464 let config = ServerConfig::default();
5465 assert_eq!(config.max_connections, 0);
5466 }
5467
5468 #[test]
5469 fn config_max_connections_can_be_set() {
5470 let config = ServerConfig::new("127.0.0.1:8080").with_max_connections(100);
5471 assert_eq!(config.max_connections, 100);
5472 }
5473
5474 #[test]
5479 fn config_keep_alive_timeout_default() {
5480 let config = ServerConfig::default();
5481 assert_eq!(
5482 config.keep_alive_timeout,
5483 Duration::from_secs(DEFAULT_KEEP_ALIVE_TIMEOUT_SECS)
5484 );
5485 }
5486
5487 #[test]
5488 fn config_keep_alive_timeout_can_be_set() {
5489 let config =
5490 ServerConfig::new("127.0.0.1:8080").with_keep_alive_timeout(Duration::from_secs(120));
5491 assert_eq!(config.keep_alive_timeout, Duration::from_secs(120));
5492 }
5493
5494 #[test]
5495 fn config_keep_alive_timeout_can_be_set_secs() {
5496 let config = ServerConfig::new("127.0.0.1:8080").with_keep_alive_timeout_secs(90);
5497 assert_eq!(config.keep_alive_timeout, Duration::from_secs(90));
5498 }
5499
5500 #[test]
5501 fn config_max_requests_per_connection_default() {
5502 let config = ServerConfig::default();
5503 assert_eq!(
5504 config.max_requests_per_connection,
5505 DEFAULT_MAX_REQUESTS_PER_CONNECTION
5506 );
5507 }
5508
5509 #[test]
5510 fn config_max_requests_per_connection_can_be_set() {
5511 let config = ServerConfig::new("127.0.0.1:8080").with_max_requests_per_connection(50);
5512 assert_eq!(config.max_requests_per_connection, 50);
5513 }
5514
5515 #[test]
5516 fn config_max_requests_per_connection_unlimited() {
5517 let config = ServerConfig::new("127.0.0.1:8080").with_max_requests_per_connection(0);
5518 assert_eq!(config.max_requests_per_connection, 0);
5519 }
5520
5521 #[test]
5522 fn response_with_keep_alive_header() {
5523 let response = Response::ok().header("connection", b"keep-alive".to_vec());
5524 let headers = response.headers();
5525 let connection_header = headers
5526 .iter()
5527 .find(|(name, _)| name.eq_ignore_ascii_case("connection"));
5528 assert!(connection_header.is_some());
5529 assert_eq!(connection_header.unwrap().1, b"keep-alive");
5530 }
5531
5532 #[test]
5533 fn response_with_close_header() {
5534 let response = Response::ok().header("connection", b"close".to_vec());
5535 let headers = response.headers();
5536 let connection_header = headers
5537 .iter()
5538 .find(|(name, _)| name.eq_ignore_ascii_case("connection"));
5539 assert!(connection_header.is_some());
5540 assert_eq!(connection_header.unwrap().1, b"close");
5541 }
5542
5543 #[test]
5548 fn config_drain_timeout_default() {
5549 let config = ServerConfig::default();
5550 assert_eq!(
5551 config.drain_timeout,
5552 Duration::from_secs(DEFAULT_DRAIN_TIMEOUT_SECS)
5553 );
5554 }
5555
5556 #[test]
5557 fn config_drain_timeout_can_be_set() {
5558 let config =
5559 ServerConfig::new("127.0.0.1:8080").with_drain_timeout(Duration::from_secs(60));
5560 assert_eq!(config.drain_timeout, Duration::from_secs(60));
5561 }
5562
5563 #[test]
5564 fn config_drain_timeout_can_be_set_secs() {
5565 let config = ServerConfig::new("127.0.0.1:8080").with_drain_timeout_secs(45);
5566 assert_eq!(config.drain_timeout, Duration::from_secs(45));
5567 }
5568
5569 #[test]
5570 fn server_not_draining_initially() {
5571 let server = TcpServer::default();
5572 assert!(!server.is_draining());
5573 }
5574
5575 #[test]
5576 fn server_start_drain_sets_flag() {
5577 let server = TcpServer::default();
5578 assert!(!server.is_draining());
5579 server.start_drain();
5580 assert!(server.is_draining());
5581 }
5582
5583 #[test]
5584 fn server_start_drain_idempotent() {
5585 let server = TcpServer::default();
5586 server.start_drain();
5587 assert!(server.is_draining());
5588 server.start_drain();
5589 assert!(server.is_draining());
5590 }
5591
5592 #[test]
5593 fn wait_for_drain_returns_true_when_no_connections() {
5594 block_on(async {
5595 let server = TcpServer::default();
5596 assert_eq!(server.current_connections(), 0);
5597 let result = server
5598 .wait_for_drain(Duration::from_millis(100), Some(Duration::from_millis(1)))
5599 .await;
5600 assert!(result);
5601 });
5602 }
5603
5604 #[test]
5605 fn wait_for_drain_timeout_with_connections() {
5606 block_on(async {
5607 let server = TcpServer::default();
5608 server.try_acquire_connection();
5610 server.try_acquire_connection();
5611 assert_eq!(server.current_connections(), 2);
5612
5613 let result = server
5615 .wait_for_drain(Duration::from_millis(50), Some(Duration::from_millis(5)))
5616 .await;
5617 assert!(!result);
5618 assert_eq!(server.current_connections(), 2);
5619 });
5620 }
5621
5622 #[test]
5623 fn drain_returns_zero_when_no_connections() {
5624 block_on(async {
5625 let server = TcpServer::new(
5626 ServerConfig::new("127.0.0.1:8080").with_drain_timeout(Duration::from_millis(100)),
5627 );
5628 assert_eq!(server.current_connections(), 0);
5629 let remaining = server.drain().await;
5630 assert_eq!(remaining, 0);
5631 assert!(server.is_draining());
5632 });
5633 }
5634
5635 #[test]
5636 fn drain_returns_count_when_connections_remain() {
5637 block_on(async {
5638 let server = TcpServer::new(
5639 ServerConfig::new("127.0.0.1:8080").with_drain_timeout(Duration::from_millis(50)),
5640 );
5641 server.try_acquire_connection();
5643 server.try_acquire_connection();
5644 server.try_acquire_connection();
5645
5646 let remaining = server.drain().await;
5647 assert_eq!(remaining, 3);
5648 assert!(server.is_draining());
5649 });
5650 }
5651
5652 #[test]
5653 fn cleanup_completed_handles_prunes_finished_runtime_tasks() {
5654 use std::sync::mpsc;
5655 use std::time::{Duration, Instant as StdInstant};
5656
5657 let runtime = asupersync::runtime::RuntimeBuilder::new()
5658 .worker_threads(2)
5659 .build()
5660 .expect("runtime build");
5661 let handle = runtime.handle();
5662 let server = TcpServer::default();
5663 let (tx, rx) = mpsc::sync_channel(1);
5664
5665 let join = handle.spawn(async move {
5666 tx.send(()).expect("completion signal should send");
5667 });
5668
5669 server
5670 .connection_handles
5671 .lock()
5672 .expect("connection handle mutex should not be poisoned")
5673 .push(join);
5674
5675 rx.recv_timeout(Duration::from_secs(1))
5676 .expect("spawned runtime task should complete");
5677
5678 let deadline = StdInstant::now() + Duration::from_secs(1);
5679 loop {
5680 if server
5681 .connection_handles
5682 .lock()
5683 .expect("connection handle mutex should not be poisoned")[0]
5684 .is_finished()
5685 {
5686 break;
5687 }
5688 assert!(
5689 StdInstant::now() < deadline,
5690 "JoinHandle should report completion after task exit"
5691 );
5692 std::thread::sleep(Duration::from_millis(10));
5693 }
5694
5695 block_on(async {
5696 server.cleanup_completed_handles(&Cx::for_testing()).await;
5697 });
5698
5699 let remaining = server
5700 .connection_handles
5701 .lock()
5702 .expect("connection handle mutex should not be poisoned")
5703 .len();
5704 assert_eq!(remaining, 0);
5705 }
5706
5707 #[test]
5708 fn cleanup_completed_handles_reaps_panicked_runtime_tasks_without_panicking() {
5709 use std::time::{Duration, Instant as StdInstant};
5710
5711 let runtime = asupersync::runtime::RuntimeBuilder::new()
5712 .worker_threads(2)
5713 .build()
5714 .expect("runtime build");
5715 let handle = runtime.handle();
5716 let server = TcpServer::default();
5717
5718 let join = handle.spawn(async move {
5719 panic!("intentional panic to verify cleanup panics are observed");
5720 });
5721
5722 server
5723 .connection_handles
5724 .lock()
5725 .expect("connection handle mutex should not be poisoned")
5726 .push(join);
5727
5728 let deadline = StdInstant::now() + Duration::from_secs(1);
5729 loop {
5730 let finished = server
5731 .connection_handles
5732 .lock()
5733 .expect("connection handle mutex should not be poisoned")[0]
5734 .is_finished();
5735 if finished {
5736 break;
5737 }
5738 assert!(
5739 StdInstant::now() < deadline,
5740 "panicking task should finish promptly"
5741 );
5742 std::thread::sleep(Duration::from_millis(10));
5743 }
5744
5745 block_on(async {
5746 server.cleanup_completed_handles(&Cx::for_testing()).await;
5747 });
5748
5749 let remaining = server
5750 .connection_handles
5751 .lock()
5752 .expect("connection handle mutex should not be poisoned")
5753 .len();
5754 assert_eq!(remaining, 0);
5755 }
5756
5757 #[test]
5758 fn connection_slot_guard_releases_counter_when_task_panics() {
5759 use std::time::{Duration, Instant as StdInstant};
5760
5761 let runtime = asupersync::runtime::RuntimeBuilder::new()
5762 .worker_threads(2)
5763 .build()
5764 .expect("runtime build");
5765 let handle = runtime.handle();
5766 let counter = Arc::new(AtomicU64::new(1));
5767
5768 let panic_task = handle.spawn({
5769 let counter = Arc::clone(&counter);
5770 async move {
5771 let _connection_slot = ConnectionSlotGuard::new(counter);
5772 panic!("intentional panic to verify connection slot cleanup");
5773 }
5774 });
5775
5776 let deadline = StdInstant::now() + Duration::from_secs(1);
5777 while !panic_task.is_finished() {
5778 assert!(
5779 StdInstant::now() < deadline,
5780 "panicing task should finish promptly"
5781 );
5782 std::thread::sleep(Duration::from_millis(10));
5783 }
5784
5785 assert_eq!(
5786 counter.load(Ordering::Relaxed),
5787 0,
5788 "connection slot must be released even when the task unwinds"
5789 );
5790 }
5791
5792 #[test]
5793 fn connection_slot_guard_releases_counter_when_future_drops_before_poll() {
5794 let counter = Arc::new(AtomicU64::new(1));
5795 let connection_slot = ConnectionSlotGuard::new(Arc::clone(&counter));
5796
5797 let future = async move {
5798 let _connection_slot = connection_slot;
5799 };
5800
5801 drop(future);
5802
5803 assert_eq!(
5804 counter.load(Ordering::Relaxed),
5805 0,
5806 "connection slot must be released even if the spawned future is dropped before polling"
5807 );
5808 }
5809
5810 #[test]
5811 fn serve_concurrent_shutdown_wakes_idle_accept_loop() {
5812 use std::time::{Duration, Instant as StdInstant};
5813
5814 let server = Arc::new(TcpServer::new(ServerConfig::new("127.0.0.1:0")));
5815 let server_for_thread = Arc::clone(&server);
5816
5817 let serve_thread = std::thread::spawn(move || {
5818 block_on(async {
5819 let cx = Cx::for_testing();
5820 server_for_thread
5821 .serve_concurrent(&cx, |_ctx, _req| async {
5822 Response::ok().body(fastapi_core::ResponseBody::Bytes(b"ok".to_vec()))
5823 })
5824 .await
5825 })
5826 });
5827
5828 std::thread::sleep(Duration::from_millis(100));
5829 server.shutdown();
5830
5831 let deadline = StdInstant::now() + Duration::from_secs(2);
5832 while !serve_thread.is_finished() {
5833 assert!(
5834 StdInstant::now() < deadline,
5835 "serve_concurrent should exit promptly after shutdown without a new connection"
5836 );
5837 std::thread::sleep(Duration::from_millis(20));
5838 }
5839
5840 let result = serve_thread
5841 .join()
5842 .expect("serve_concurrent regression thread should not panic");
5843 assert!(
5844 result.is_ok(),
5845 "serve_concurrent should stop cleanly on shutdown"
5846 );
5847 }
5848
5849 #[test]
5850 fn server_shutdown_error_display() {
5851 let err = ServerError::Shutdown;
5852 assert_eq!(err.to_string(), "Server shutdown");
5853 }
5854
5855 #[test]
5860 fn server_has_shutdown_controller() {
5861 let server = TcpServer::default();
5862 let controller = server.shutdown_controller();
5863 assert!(!controller.is_shutting_down());
5864 }
5865
5866 #[test]
5867 fn server_subscribe_shutdown_returns_receiver() {
5868 let server = TcpServer::default();
5869 let receiver = server.subscribe_shutdown();
5870 assert!(!receiver.is_shutting_down());
5871 }
5872
5873 #[test]
5874 fn server_shutdown_sets_draining_and_controller() {
5875 let server = TcpServer::default();
5876 assert!(!server.is_shutting_down());
5877 assert!(!server.is_draining());
5878 assert!(!server.shutdown_controller().is_shutting_down());
5879
5880 server.shutdown();
5881
5882 assert!(server.is_shutting_down());
5883 assert!(server.is_draining());
5884 assert!(server.shutdown_controller().is_shutting_down());
5885 }
5886
5887 #[test]
5888 fn server_shutdown_notifies_receivers() {
5889 let server = TcpServer::default();
5890 let receiver1 = server.subscribe_shutdown();
5891 let receiver2 = server.subscribe_shutdown();
5892
5893 assert!(!receiver1.is_shutting_down());
5894 assert!(!receiver2.is_shutting_down());
5895
5896 server.shutdown();
5897
5898 assert!(receiver1.is_shutting_down());
5899 assert!(receiver2.is_shutting_down());
5900 }
5901
5902 #[test]
5903 fn server_shutdown_is_idempotent() {
5904 let server = TcpServer::default();
5905 let receiver = server.subscribe_shutdown();
5906
5907 server.shutdown();
5908 server.shutdown();
5909 server.shutdown();
5910
5911 assert!(server.is_shutting_down());
5912 assert!(receiver.is_shutting_down());
5913 }
5914
5915 #[test]
5920 fn keep_alive_timeout_error_display() {
5921 let err = ServerError::KeepAliveTimeout;
5922 assert_eq!(err.to_string(), "Keep-alive timeout");
5923 }
5924
5925 #[test]
5926 fn keep_alive_timeout_zero_disables_timeout() {
5927 let config = ServerConfig::new("127.0.0.1:8080").with_keep_alive_timeout(Duration::ZERO);
5928 assert!(config.keep_alive_timeout.is_zero());
5929 }
5930
5931 #[test]
5932 fn keep_alive_timeout_default_is_non_zero() {
5933 let config = ServerConfig::default();
5934 assert!(!config.keep_alive_timeout.is_zero());
5935 assert_eq!(
5936 config.keep_alive_timeout,
5937 Duration::from_secs(DEFAULT_KEEP_ALIVE_TIMEOUT_SECS)
5938 );
5939 }
5940
5941 #[test]
5942 fn timed_out_io_error_kind() {
5943 let err = io::Error::new(io::ErrorKind::TimedOut, "test timeout");
5944 assert_eq!(err.kind(), io::ErrorKind::TimedOut);
5945 }
5946
5947 #[test]
5948 fn instant_deadline_calculation() {
5949 let timeout = Duration::from_millis(100);
5950 let deadline = Instant::now() + timeout;
5951
5952 assert!(deadline > Instant::now());
5954
5955 std::thread::sleep(Duration::from_millis(150));
5957 assert!(Instant::now() >= deadline);
5958 }
5959
5960 #[test]
5961 fn server_metrics_initial_state() {
5962 let server = TcpServer::default();
5963 let m = server.metrics();
5964 assert_eq!(m.active_connections, 0);
5965 assert_eq!(m.total_accepted, 0);
5966 assert_eq!(m.total_rejected, 0);
5967 assert_eq!(m.total_timed_out, 0);
5968 assert_eq!(m.total_requests, 0);
5969 assert_eq!(m.bytes_in, 0);
5970 assert_eq!(m.bytes_out, 0);
5971 }
5972
5973 #[test]
5974 fn server_metrics_after_acquire_release() {
5975 let server = TcpServer::new(ServerConfig::new("127.0.0.1:0").with_max_connections(10));
5976 assert!(server.try_acquire_connection());
5977 assert!(server.try_acquire_connection());
5978
5979 let m = server.metrics();
5980 assert_eq!(m.active_connections, 2);
5981 assert_eq!(m.total_accepted, 2);
5982 assert_eq!(m.total_rejected, 0);
5983
5984 server.release_connection();
5985 let m = server.metrics();
5986 assert_eq!(m.active_connections, 1);
5987 assert_eq!(m.total_accepted, 2); }
5989
5990 #[test]
5991 fn server_metrics_rejection_counted() {
5992 let server = TcpServer::new(ServerConfig::new("127.0.0.1:0").with_max_connections(1));
5993 assert!(server.try_acquire_connection());
5994 assert!(!server.try_acquire_connection()); let m = server.metrics();
5997 assert_eq!(m.total_accepted, 1);
5998 assert_eq!(m.total_rejected, 1);
5999 assert_eq!(m.active_connections, 1);
6000 }
6001
6002 #[test]
6003 fn server_metrics_bytes_tracking() {
6004 let server = TcpServer::default();
6005 server.record_bytes_in(1024);
6006 server.record_bytes_in(512);
6007 server.record_bytes_out(2048);
6008
6009 let m = server.metrics();
6010 assert_eq!(m.bytes_in, 1536);
6011 assert_eq!(m.bytes_out, 2048);
6012 }
6013
6014 #[test]
6015 fn server_metrics_unlimited_connections_accepted() {
6016 let server = TcpServer::new(ServerConfig::new("127.0.0.1:0").with_max_connections(0));
6017 for _ in 0..100 {
6018 assert!(server.try_acquire_connection());
6019 }
6020 let m = server.metrics();
6021 assert_eq!(m.total_accepted, 100);
6022 assert_eq!(m.total_rejected, 0);
6023 assert_eq!(m.active_connections, 100);
6024 }
6025
6026 #[test]
6027 fn server_metrics_clone_eq() {
6028 let server = TcpServer::default();
6029 server.record_bytes_in(42);
6030 let m1 = server.metrics();
6031 let m2 = m1.clone();
6032 assert_eq!(m1, m2);
6033 }
6034}
6035
6036pub trait AppServeExt {
6059 fn serve(self, addr: impl Into<String>) -> impl Future<Output = Result<(), ServeError>> + Send;
6099
6100 fn serve_with_config(
6123 self,
6124 config: ServerConfig,
6125 ) -> impl Future<Output = Result<(), ServeError>> + Send;
6126}
6127
6128#[derive(Debug)]
6130pub enum ServeError {
6131 Startup(fastapi_core::StartupHookError),
6133 Server(ServerError),
6135}
6136
6137impl std::fmt::Display for ServeError {
6138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
6139 match self {
6140 Self::Startup(e) => write!(f, "startup hook failed: {}", e.message),
6141 Self::Server(e) => write!(f, "server error: {e}"),
6142 }
6143 }
6144}
6145
6146impl std::error::Error for ServeError {
6147 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
6148 match self {
6149 Self::Startup(_) => None,
6150 Self::Server(e) => Some(e),
6151 }
6152 }
6153}
6154
6155impl From<ServerError> for ServeError {
6156 fn from(e: ServerError) -> Self {
6157 Self::Server(e)
6158 }
6159}
6160
6161impl AppServeExt for App {
6162 fn serve(self, addr: impl Into<String>) -> impl Future<Output = Result<(), ServeError>> + Send {
6163 let config = ServerConfig::new(addr);
6164 self.serve_with_config(config)
6165 }
6166
6167 #[allow(clippy::manual_async_fn)] fn serve_with_config(
6169 self,
6170 config: ServerConfig,
6171 ) -> impl Future<Output = Result<(), ServeError>> + Send {
6172 async move {
6173 match self.run_startup_hooks().await {
6175 fastapi_core::StartupOutcome::Success => {}
6176 fastapi_core::StartupOutcome::PartialSuccess { warnings } => {
6177 eprintln!("Warning: {warnings} startup hook(s) had non-fatal errors");
6179 }
6180 fastapi_core::StartupOutcome::Aborted(e) => {
6181 return Err(ServeError::Startup(e));
6182 }
6183 }
6184
6185 let server = TcpServer::new(config);
6187
6188 let app = Arc::new(self);
6190
6191 let cx = Cx::for_testing();
6193
6194 let bind_addr = &server.config().bind_addr;
6196 println!("🚀 Server starting on http://{bind_addr}");
6197
6198 let result = server.serve_app(&cx, Arc::clone(&app)).await;
6200
6201 app.run_shutdown_hooks().await;
6203
6204 result.map_err(ServeError::from)
6205 }
6206 }
6207}
6208
6209pub async fn serve(app: App, addr: impl Into<String>) -> Result<(), ServeError> {
6227 app.serve(addr).await
6228}
6229
6230pub async fn serve_with_config(app: App, config: ServerConfig) -> Result<(), ServeError> {
6248 app.serve_with_config(config).await
6249}