volga 0.9.0

Easy & Fast Web Framework for Rust
Documentation
use futures_util::{TryFutureExt, future::BoxFuture};
use std::net::SocketAddr;
use std::sync::{Arc, Weak};
use tokio_util::sync::CancellationToken;

use hyper::{
    HeaderMap, Method, Request, Response,
    body::{Incoming, SizeHint},
    header::{ALLOW, CONTENT_LENGTH, HeaderValue},
    service::Service,
};

use crate::{
    ClientIp, HttpBody, HttpRequest, HttpResult, Limit,
    app::AppEnv,
    error::{Error, handler::extract_error_args},
    headers::CACHE_CONTROL,
    http::{endpoints::FindResult, request_scope::HttpRequestScope},
    status,
};

#[cfg(feature = "tls")]
use crate::{
    headers::{HOST, STRICT_TRANSPORT_SECURITY},
    tls::HstsHeader,
};

#[cfg(feature = "tracing")]
use {
    crate::tracing::TracingConfig,
    tracing::{Id, trace_span},
};

#[cfg(feature = "middleware")]
use crate::middleware::HttpContext;

const REQUEST_HEADERS_TOO_LARGE_MESSAGE: &str = "Request headers too large.";

/// Represents the execution scope of the current connection
///
/// A `Scope` is cloneable and acts as the per-connection `hyper` service:
/// every incoming request is dispatched with a clone of the fields below.
#[derive(Clone)]
pub(crate) struct Scope {
    /// Weak handle to the shared application environment. It is upgraded at
    /// the start of each request; a failed upgrade yields a `500` response.
    pub(crate) env: Weak<AppEnv>,
    /// Cancellation token created together with the scope; a clone of it is
    /// handed to every request handled through this scope.
    pub(crate) cancellation_token: CancellationToken,
    /// Remote socket address supplied at construction; exposed to handlers
    /// as the client IP of each request.
    peer_addr: SocketAddr,
}

/// Lets `hyper` drive a [`Scope`] as the request service.
impl Service<Request<Incoming>> for Scope {
    type Response = Response<HttpBody>;
    type Error = Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Handles a single request.
    ///
    /// The scope's state is cloned into the returned future so that the
    /// future is `'static` and does not borrow from `self`.
    #[inline]
    fn call(&self, request: Request<Incoming>) -> Self::Future {
        let env = Weak::clone(&self.env);
        let token = self.cancellation_token.clone();

        // `map_ok(Into::into)` converts the successful result into the
        // `Response<HttpBody>` that `hyper` expects.
        Box::pin(Self::handle_request(request, self.peer_addr, env, token).map_ok(Into::into))
    }
}

impl Scope {
    pub(crate) fn new(env: Weak<AppEnv>, peer_addr: SocketAddr) -> Self {
        Self {
            cancellation_token: CancellationToken::new(),
            peer_addr,
            env,
        }
    }

    pub(super) async fn handle_request(
        request: Request<Incoming>,
        peer_addr: SocketAddr,
        env: Weak<AppEnv>,
        cancellation_token: CancellationToken,
    ) -> HttpResult {
        let env = match env.upgrade() {
            Some(shared) => shared,
            None => {
                #[cfg(feature = "tracing")]
                tracing::error!("app instance could not be upgraded; aborting...");

                return status!(500);
            }
        };

        let method = request.method().clone();

        #[cfg(feature = "tracing")]
        let span = env.tracing_config.as_ref().map(|_| {
            let uri = request.uri();
            trace_span!("request", %method, %uri)
        });

        #[cfg(feature = "tracing")]
        let _guard = span.as_ref().map(|s| s.enter());

        #[cfg(feature = "tls")]
        let host = request.headers().get(HOST).cloned();

        let response = handle_impl(request, peer_addr, env.clone(), cancellation_token).await;

        finalize_response(
            method,
            response,
            &env,
            #[cfg(feature = "tracing")]
            span.as_ref().and_then(|s| s.id()),
            #[cfg(feature = "tls")]
            host,
        )
    }
}

/// Runs a single request through the application.
///
/// In order, this:
/// 1. Enforces the configured header limits (size under `http1`, count under
///    `http2`), answering `431` with [`REQUEST_HEADERS_TOO_LARGE_MESSAGE`]
///    when a check fails.
/// 2. Attaches feature-dependent values to the request extensions.
/// 3. Looks up the endpoint for the request method and URI:
///    - no route: delegates to the pipeline fallback;
///    - route exists but not for this method: answers `405` with an `Allow`
///      header;
///    - match: builds the request scope, runs the route (through the
///      middleware pipeline when one is present) and maps a handler error
///      through the pipeline's error handler.
async fn handle_impl(
    request: Request<Incoming>,
    peer_addr: SocketAddr,
    env: Arc<AppEnv>,
    cancellation_token: CancellationToken,
) -> HttpResult {
    // Header limit checks are scoped so the borrow of `request.headers()`
    // ends before `request` is rebound/consumed below.
    {
        let headers = request.headers();

        #[cfg(feature = "http1")]
        if let Limit::Limited(max_header_size) = env.max_header_size
            && !check_max_header_size(headers, max_header_size)
        {
            #[cfg(feature = "tracing")]
            tracing::warn!("Request rejected due to headers exceeding configured limits");

            return status!(431, text: REQUEST_HEADERS_TOO_LARGE_MESSAGE);
        }

        #[cfg(feature = "http2")]
        if let Limit::Limited(max_header_count) = env.max_header_count
            && !check_max_header_count(headers, max_header_count)
        {
            #[cfg(feature = "tracing")]
            tracing::warn!("Request rejected due to headers exceeding configured limits");

            return status!(431, text: REQUEST_HEADERS_TOO_LARGE_MESSAGE);
        }
    }

    // Make the hosting environment available via request extensions.
    #[cfg(feature = "static-files")]
    let request = {
        let mut request = request;
        request.extensions_mut().insert(env.host_env.clone());
        request
    };

    // Store the result of `create_scope()` in the request extensions
    // (presumably a per-request DI scope — confirm against the container API).
    #[cfg(feature = "di")]
    let request = {
        let mut request = request;
        request
            .extensions_mut()
            .insert(env.container.create_scope());
        request
    };

    let pipeline = &env.pipeline;
    match pipeline.endpoints().find(
        request.method(),
        request.uri(),
        #[cfg(feature = "middleware")]
        env.cors.is_enabled,
        #[cfg(feature = "middleware")]
        request.headers(),
    ) {
        // Unknown route: let the configured fallback produce the response.
        FindResult::RouteNotFound => pipeline.fallback(request).await,
        // Known route, unsupported method: advertise the allowed methods.
        FindResult::MethodNotFound(allowed) => status!(405; [
            (ALLOW, allowed.as_ref())
        ]),
        FindResult::Ok(endpoint) => {
            #[cfg(feature = "middleware")]
            let (route_pipeline, params, cors) = endpoint.into_parts();
            #[cfg(not(feature = "middleware"))]
            let (route_pipeline, params) = endpoint.into_parts();

            let error_handler = pipeline.error_handler();
            let (mut parts, body) = request.into_parts();

            // Per-request state made available through the request extensions;
            // most fields are copies/clones of values held by `env`.
            parts.extensions.insert(HttpRequestScope {
                client_ip: ClientIp(peer_addr),
                cancellation_token,
                body_limit: env.body_limit,
                params,
                #[cfg(feature = "ws")]
                error_handler: Arc::clone(error_handler),
                #[cfg(feature = "jwt-auth")]
                bearer_token_service: env.bearer_token_service.clone(),
                #[cfg(any(
                    feature = "decompression-brotli",
                    feature = "decompression-gzip",
                    feature = "decompression-zstd",
                    feature = "decompression-full"
                ))]
                decompression_limits: env.decompression_limits,
                #[cfg(feature = "rate-limiting")]
                rate_limiter: env.rate_limiter.clone(),
                #[cfg(feature = "rate-limiting")]
                trusted_proxies: env.trusted_proxies.clone(),
                #[cfg(feature = "config")]
                config: env.config.clone(),
            });

            // Pre-extract error handler args from parts before consuming them.
            let error_args = extract_error_args(error_handler, &parts);

            // Reassemble the request and apply the configured body limit.
            let request =
                HttpRequest::new(Request::from_parts(parts, body)).into_limited(env.body_limit);

            // With middleware registered, the whole pipeline runs and receives
            // the route pipeline; otherwise the route is called directly.
            #[cfg(feature = "middleware")]
            let response = if pipeline.has_middleware_pipeline() {
                let ctx = HttpContext::new(request, Some(route_pipeline), cors);
                pipeline.execute(ctx).await
            } else {
                route_pipeline
                    .call(HttpContext::new(request, None, cors))
                    .await
            };
            #[cfg(not(feature = "middleware"))]
            let response = route_pipeline.call(request).await;

            // Handler errors are turned into responses by the error handler.
            match response {
                Ok(response) => Ok(response),
                Err(err) => error_args.call(err).await,
            }
        }
    }
}

/// Applies response post-processing shared by every successful response.
///
/// An `Err` result is returned untouched. For an `Ok` response:
/// - on `HEAD`, `Content-Length` is preserved from the body's size hint and
///   the body is then replaced with an empty one;
/// - the default `Cache-Control` header is applied when configured;
/// - the tracing span header is applied when tracing is configured;
/// - the HSTS header is applied when HSTS is configured.
#[inline]
fn finalize_response(
    method: Method,
    response: HttpResult,
    shared: &AppEnv,
    #[cfg(feature = "tracing")] span_id: Option<Id>,
    #[cfg(feature = "tls")] host: Option<HeaderValue>,
) -> HttpResult {
    let mut finalized = response?;

    if method == Method::HEAD {
        // Read the size hint before the body is dropped so the length of the
        // would-be body can still be advertised.
        let hint = finalized.size_hint();
        keep_content_length(hint, finalized.headers_mut());
        *finalized.body_mut() = HttpBody::empty();
    }

    if let Some(default_cache_control) = shared.cache_control.as_ref() {
        apply_default_cache_control(&mut finalized, default_cache_control.clone());
    }

    #[cfg(feature = "tracing")]
    if let Some(config) = shared.tracing_config.as_ref() {
        apply_tracing_headers(&mut finalized, config, span_id);
    }

    #[cfg(feature = "tls")]
    if let Some(policy) = shared.hsts.as_ref() {
        apply_hsts_headers(&mut finalized, policy, host);
    }

    Ok(finalized)
}

/// Sets `Cache-Control` to `header` unless the response already carries one.
///
/// A value set by the handler always wins over the configured default.
#[inline]
fn apply_default_cache_control(resp: &mut crate::HttpResponse, header: HeaderValue) {
    // `or_insert` only writes when the entry is vacant.
    resp.headers_mut().entry(CACHE_CONTROL).or_insert(header);
}

/// Writes the current span id into the configured response header.
///
/// Does nothing unless `tracing.include_header` is set. When no span id is
/// available the header value is `0`.
#[inline]
#[cfg(feature = "tracing")]
fn apply_tracing_headers(
    resp: &mut crate::HttpResponse,
    tracing: &TracingConfig,
    span_id: Option<Id>,
) {
    if tracing.include_header {
        let raw_id = match span_id {
            Some(id) => id.into_u64(),
            None => 0,
        };

        // A decimal integer is always a valid header value.
        let header_value: HeaderValue = raw_id.to_string().parse().expect("valid span id");

        resp.headers_mut()
            .insert(tracing.span_header_name.clone(), header_value);
    }
}

/// Adds the `Strict-Transport-Security` header to `resp`.
///
/// The header is skipped when the request's `Host` is one of
/// `hsts.exclude_hosts`.
#[cfg(feature = "tls")]
fn apply_hsts_headers(
    resp: &mut crate::HttpResponse,
    hsts: &HstsHeader,
    host: Option<HeaderValue>,
) {
    let excluded = is_excluded(host, &hsts.exclude_hosts);

    if !excluded {
        resp.headers_mut()
            .insert(STRICT_TRANSPORT_SECURITY, hsts.value());
    }
}

/// Returns `true` when the request host is on the HSTS exclusion list.
///
/// The `Host` header value is normalized with [`normalize_host`] and then
/// compared against each entry of `exclude_hosts` **ignoring ASCII case**,
/// because host names are case-insensitive (`LOCALHOST` and `localhost` name
/// the same host).
///
/// Returns `false` when there is no `Host` header or it is not valid
/// visible-ASCII text.
#[inline]
#[cfg(feature = "tls")]
fn is_excluded(host: Option<HeaderValue>, exclude_hosts: &[String]) -> bool {
    host.as_ref()
        .and_then(|h| h.to_str().ok())
        .map(normalize_host)
        .is_some_and(|h| exclude_hosts.iter().any(|e| e.eq_ignore_ascii_case(h)))
}

#[inline]
#[cfg(feature = "tls")]
/// Reduces a `Host` header value to a bare host name for comparison.
///
/// Surrounding whitespace is trimmed, a trailing `:443` or `:80` port is
/// removed (any other port is kept), and trailing dots are stripped.
fn normalize_host(host: &str) -> &str {
    let trimmed = host.trim();

    // Neither "443" nor "80" contains a colon, so matching these suffixes is
    // the same as inspecting the text after the last ':'.
    let without_default_port = trimmed
        .strip_suffix(":443")
        .or_else(|| trimmed.strip_suffix(":80"))
        .unwrap_or(trimmed);

    without_default_port.trim_end_matches('.')
}

/// Ensures `headers` carries a `Content-Length` derived from `size_hint`.
///
/// An existing `Content-Length` is left untouched, and nothing is inserted
/// when the size hint does not report an exact size.
fn keep_content_length(size_hint: SizeHint, headers: &mut HeaderMap) {
    if headers.contains_key(CONTENT_LENGTH) {
        return;
    }

    let Some(exact) = size_hint.exact() else {
        return;
    };

    let value = match exact {
        // Avoid formatting for the empty body.
        0 => HeaderValue::from_static("0"),
        len => {
            let mut digits = itoa::Buffer::new();
            HeaderValue::from_str(digits.format(len)).unwrap()
        }
    };

    headers.insert(CONTENT_LENGTH, value);
}

/// Returns `true` when the number of header entries is within the limit.
///
/// `max_header_count` is an inclusive maximum: a request carrying exactly
/// that many headers does not exceed the limit and is accepted.
#[inline(always)]
#[cfg(feature = "http2")]
fn check_max_header_count(headers: &HeaderMap, max_header_count: usize) -> bool {
    headers.len() <= max_header_count
}

/// Returns `true` when the combined header size is within the limit.
///
/// The size is the sum, over all header entries, of the name length plus the
/// value length in bytes. `max_header_size` is an inclusive maximum: headers
/// totalling exactly that many bytes do not exceed the limit and are
/// accepted.
#[inline(always)]
#[cfg(feature = "http1")]
fn check_max_header_size(headers: &HeaderMap, max_header_size: usize) -> bool {
    let total_size: usize = headers
        .iter()
        .map(|(name, value)| name.as_str().len() + value.as_bytes().len())
        .sum();

    total_size <= max_header_size
}