use fastapi_core::{Request, Response, ResponseBody, StatusCode};
use std::sync::Arc;
/// Raw bytes of the interim `100 Continue` status line (with the terminating blank line)
/// that a server writes before reading a request body.
pub const CONTINUE_RESPONSE: &[u8] = "HTTP/1.1 100 Continue\r\n\r\n".as_bytes();
/// Lowercase `Expect` token requesting an interim `100 Continue` response.
pub const EXPECT_100_CONTINUE: &str = "100-continue";
/// Classification of a request's `Expect` header, as produced by
/// `ExpectHandler::check_expect`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExpectResult {
    /// The request carries no `expect` header.
    NoExpectation,
    /// Every element of the header is `100-continue` (ASCII case-insensitive).
    ExpectsContinue,
    /// Any other value. Carries the trimmed, lowercased header value, or an empty
    /// string when the raw value is not valid UTF-8.
    UnknownExpectation(String),
}
/// Configurable checks that can run on a request's headers before its body is read.
///
/// The `Default` value imposes no size limit and no content-type requirement.
#[derive(Debug, Clone, Default)]
pub struct ExpectHandler {
    /// Largest declared `Content-Length`, in bytes, that is accepted; `0` disables the limit.
    pub max_content_length: usize,
    /// Media type the request's `Content-Type` must declare; `None` accepts any.
    pub required_content_type: Option<String>,
}
impl ExpectHandler {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_max_content_length(mut self, max: usize) -> Self {
self.max_content_length = max;
self
}
#[must_use]
pub fn with_required_content_type(mut self, content_type: impl Into<String>) -> Self {
self.required_content_type = Some(content_type.into());
self
}
#[must_use]
pub fn check_expect(request: &Request) -> ExpectResult {
match request.headers().get("expect") {
None => ExpectResult::NoExpectation,
Some(value) => {
let value_str = match std::str::from_utf8(value) {
Ok(s) => s.trim().to_ascii_lowercase(),
Err(_) => return ExpectResult::UnknownExpectation(String::new()),
};
let mut saw_continue = false;
for token in value_str.split(',').map(str::trim) {
if token.is_empty() {
return ExpectResult::UnknownExpectation(value_str);
}
if token == EXPECT_100_CONTINUE {
saw_continue = true;
} else {
return ExpectResult::UnknownExpectation(value_str);
}
}
if saw_continue {
ExpectResult::ExpectsContinue
} else {
ExpectResult::UnknownExpectation(value_str)
}
}
}
}
#[must_use]
pub fn expects_continue(request: &Request) -> bool {
matches!(Self::check_expect(request), ExpectResult::ExpectsContinue)
}
pub fn validate_content_length(&self, request: &Request) -> Result<(), Response> {
if self.max_content_length == 0 {
return Ok(()); }
if let Some(value) = request.headers().get("content-length") {
if let Ok(len_str) = std::str::from_utf8(value) {
if let Ok(len) = len_str.trim().parse::<usize>() {
if len > self.max_content_length {
return Err(Self::payload_too_large(format!(
"Content-Length {} exceeds maximum {}",
len, self.max_content_length
)));
}
}
}
}
Ok(())
}
pub fn validate_content_type(&self, request: &Request) -> Result<(), Response> {
let required = match &self.required_content_type {
Some(ct) => ct,
None => return Ok(()),
};
match request.headers().get("content-type") {
None => Err(Self::unsupported_media_type(format!(
"Content-Type required: {required}"
))),
Some(value) => {
let content_type = std::str::from_utf8(value)
.map(|s| s.trim().to_ascii_lowercase())
.unwrap_or_default();
if content_type.starts_with(&required.to_ascii_lowercase()) {
Ok(())
} else {
Err(Self::unsupported_media_type(format!(
"Expected Content-Type: {required}, got: {content_type}"
)))
}
}
}
}
pub fn validate_all(&self, request: &Request) -> Result<(), Response> {
self.validate_content_length(request)?;
self.validate_content_type(request)?;
Ok(())
}
#[must_use]
pub fn expectation_failed(detail: impl Into<String>) -> Response {
let detail = detail.into();
let body = format!("417 Expectation Failed: {detail}");
Response::with_status(StatusCode::from_u16(417))
.header("content-type", b"text/plain; charset=utf-8".to_vec())
.header("connection", b"close".to_vec())
.body(ResponseBody::Bytes(body.into_bytes()))
}
#[must_use]
pub fn unauthorized(detail: impl Into<String>) -> Response {
let detail = detail.into();
let body = format!("401 Unauthorized: {detail}");
Response::with_status(StatusCode::UNAUTHORIZED)
.header("content-type", b"text/plain; charset=utf-8".to_vec())
.header("connection", b"close".to_vec())
.body(ResponseBody::Bytes(body.into_bytes()))
}
#[must_use]
pub fn forbidden(detail: impl Into<String>) -> Response {
let detail = detail.into();
let body = format!("403 Forbidden: {detail}");
Response::with_status(StatusCode::FORBIDDEN)
.header("content-type", b"text/plain; charset=utf-8".to_vec())
.header("connection", b"close".to_vec())
.body(ResponseBody::Bytes(body.into_bytes()))
}
#[must_use]
pub fn payload_too_large(detail: impl Into<String>) -> Response {
let detail = detail.into();
let body = format!("413 Payload Too Large: {detail}");
Response::with_status(StatusCode::PAYLOAD_TOO_LARGE)
.header("content-type", b"text/plain; charset=utf-8".to_vec())
.header("connection", b"close".to_vec())
.body(ResponseBody::Bytes(body.into_bytes()))
}
#[must_use]
pub fn unsupported_media_type(detail: impl Into<String>) -> Response {
let detail = detail.into();
let body = format!("415 Unsupported Media Type: {detail}");
Response::with_status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
.header("content-type", b"text/plain; charset=utf-8".to_vec())
.header("connection", b"close".to_vec())
.body(ResponseBody::Bytes(body.into_bytes()))
}
}
/// A check run against a request without looking at its body (presumably before the
/// body is read, as the name suggests — confirm with the server integration).
///
/// `Send + Sync` lets validators be shared as `Arc<dyn PreBodyValidator>` across threads.
pub trait PreBodyValidator: Send + Sync {
    /// Returns `Ok(())` to accept the request, or the `Response` to send instead.
    fn validate(&self, request: &Request) -> Result<(), Response>;

    /// Human-readable identifier, used in `Debug` output. Defaults to `"PreBodyValidator"`.
    fn name(&self) -> &'static str {
        let fallback = "PreBodyValidator";
        fallback
    }
}
/// Ordered chain of [`PreBodyValidator`]s; cloning shares the validators via `Arc`.
#[derive(Clone, Default)]
pub struct PreBodyValidators {
    /// Validators in insertion order; evaluation stops at the first rejection.
    validators: Vec<Arc<dyn PreBodyValidator>>,
}
impl std::fmt::Debug for PreBodyValidators {
    /// Shows the number of validators and their names (validators themselves need not be `Debug`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let names: Vec<&'static str> = self
            .validators
            .iter()
            .map(|validator| validator.name())
            .collect();
        f.debug_struct("PreBodyValidators")
            .field("len", &names.len())
            .field("validators", &names)
            .finish()
    }
}
impl PreBodyValidators {
    /// Creates an empty chain.
    #[must_use]
    pub fn new() -> Self {
        Self {
            validators: Vec::new(),
        }
    }

    /// Appends `validator` to the end of the chain.
    pub fn add<V: PreBodyValidator + 'static>(&mut self, validator: V) {
        let shared: Arc<dyn PreBodyValidator> = Arc::new(validator);
        self.validators.push(shared);
    }

    /// Builder form of [`add`](Self::add).
    #[must_use]
    pub fn with<V: PreBodyValidator + 'static>(self, validator: V) -> Self {
        let mut chain = self;
        chain.add(validator);
        chain
    }

    /// Runs every validator in insertion order.
    ///
    /// # Errors
    ///
    /// Returns the response of the first validator that rejects; later ones are not run.
    pub fn validate_all(&self, request: &Request) -> Result<(), Response> {
        self.validators
            .iter()
            .try_for_each(|validator| validator.validate(request))
    }

    /// Returns `true` when no validators are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.validators.is_empty()
    }

    /// Number of registered validators.
    #[must_use]
    pub fn len(&self) -> usize {
        self.validators.len()
    }
}
/// Adapter that turns a closure or function into a named [`PreBodyValidator`].
pub struct FnValidator<F> {
    /// Name reported by `PreBodyValidator::name`.
    name: &'static str,
    /// The wrapped check.
    validate_fn: F,
}
impl<F> FnValidator<F>
where
    F: Fn(&Request) -> Result<(), Response> + Send + Sync,
{
    /// Wraps `validate_fn` as a validator reported under `name`.
    ///
    /// Marked `#[must_use]` like the file's other constructors: a discarded
    /// `FnValidator` never runs.
    #[must_use]
    pub fn new(name: &'static str, validate_fn: F) -> Self {
        Self { name, validate_fn }
    }
}
impl<F> PreBodyValidator for FnValidator<F>
where
    F: Fn(&Request) -> Result<(), Response> + Send + Sync,
{
    /// Delegates to the wrapped closure.
    fn validate(&self, request: &Request) -> Result<(), Response> {
        let check = &self.validate_fn;
        check(request)
    }

    /// Returns the name given to [`FnValidator::new`].
    fn name(&self) -> &'static str {
        let Self { name, .. } = self;
        *name
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use fastapi_core::Method;

    /// Builds a `POST /upload` request carrying the given headers.
    fn request_with_headers(headers: &[(&str, &str)]) -> Request {
        let mut req = Request::new(Method::Post, "/upload");
        headers.iter().for_each(|&(name, value)| {
            req.headers_mut()
                .insert(name.to_string(), value.as_bytes().to_vec());
        });
        req
    }

    /// Builds a `POST /upload` request whose `expect` header is `value`.
    fn request_with_expect(value: &str) -> Request {
        request_with_headers(&[("expect", value)])
    }

    /// Runs `check_expect` on a request carrying `Expect: value`.
    fn classify(value: &str) -> ExpectResult {
        ExpectHandler::check_expect(&request_with_expect(value))
    }

    /// Unwraps the payload of an `UnknownExpectation`, panicking on any other variant.
    fn unknown_value(result: ExpectResult) -> String {
        match result {
            ExpectResult::UnknownExpectation(value) => value,
            other => panic!("expected UnknownExpectation, got {other:?}"),
        }
    }

    #[test]
    fn check_expect_none() {
        let request = Request::new(Method::Get, "/");
        let result = ExpectHandler::check_expect(&request);
        assert!(matches!(result, ExpectResult::NoExpectation));
    }

    #[test]
    fn check_expect_100_continue() {
        assert!(matches!(
            classify("100-continue"),
            ExpectResult::ExpectsContinue
        ));
    }

    #[test]
    fn check_expect_100_continue_case_insensitive() {
        for value in ["100-Continue", "100-CONTINUE"] {
            assert!(
                matches!(classify(value), ExpectResult::ExpectsContinue),
                "{value}"
            );
        }
    }

    #[test]
    fn check_expect_100_continue_token_list() {
        assert!(matches!(
            classify("100-continue, 100-continue"),
            ExpectResult::ExpectsContinue
        ));
    }

    #[test]
    fn check_expect_unknown() {
        assert_eq!(unknown_value(classify("something-else")), "something-else");
    }

    #[test]
    fn expects_continue_helper() {
        let with_expect = request_with_expect("100-continue");
        let without_expect = Request::new(Method::Get, "/");
        assert!(ExpectHandler::expects_continue(&with_expect));
        assert!(!ExpectHandler::expects_continue(&without_expect));
    }

    #[test]
    fn check_expect_mixed_token_list_is_unknown() {
        assert_eq!(
            unknown_value(classify("100-continue, custom")),
            "100-continue, custom"
        );
    }

    #[test]
    fn check_expect_empty_token_is_unknown() {
        assert_eq!(unknown_value(classify("100-continue,")), "100-continue,");
    }

    #[test]
    fn validate_content_length_no_limit() {
        let request = request_with_headers(&[("content-length", "1000000")]);
        assert!(ExpectHandler::new()
            .validate_content_length(&request)
            .is_ok());
    }

    #[test]
    fn validate_content_length_within_limit() {
        let request = request_with_headers(&[("content-length", "500")]);
        assert!(ExpectHandler::new()
            .with_max_content_length(1024)
            .validate_content_length(&request)
            .is_ok());
    }

    #[test]
    fn validate_content_length_exceeds_limit() {
        let request = request_with_headers(&[("content-length", "2048")]);
        let response = ExpectHandler::new()
            .with_max_content_length(1024)
            .validate_content_length(&request)
            .expect_err("2048 bytes must exceed a 1024-byte limit");
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn validate_content_type_no_requirement() {
        let request = request_with_headers(&[("content-type", "text/plain")]);
        assert!(ExpectHandler::new().validate_content_type(&request).is_ok());
    }

    #[test]
    fn validate_content_type_matches() {
        let request =
            request_with_headers(&[("content-type", "application/json; charset=utf-8")]);
        assert!(ExpectHandler::new()
            .with_required_content_type("application/json")
            .validate_content_type(&request)
            .is_ok());
    }

    #[test]
    fn validate_content_type_missing() {
        let request = Request::new(Method::Post, "/upload");
        let response = ExpectHandler::new()
            .with_required_content_type("application/json")
            .validate_content_type(&request)
            .expect_err("a missing Content-Type must be rejected");
        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[test]
    fn validate_content_type_mismatch() {
        let request = request_with_headers(&[("content-type", "text/plain")]);
        let response = ExpectHandler::new()
            .with_required_content_type("application/json")
            .validate_content_type(&request)
            .expect_err("text/plain must not satisfy application/json");
        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[test]
    fn validate_all_passes() {
        let handler = ExpectHandler::new()
            .with_max_content_length(1024)
            .with_required_content_type("application/json");
        let request = request_with_headers(&[
            ("content-length", "100"),
            ("content-type", "application/json"),
        ]);
        assert!(handler.validate_all(&request).is_ok());
    }

    #[test]
    fn validate_all_fails_on_first_error() {
        let handler = ExpectHandler::new()
            .with_max_content_length(100)
            .with_required_content_type("application/json");
        // Both checks would fail; the length check runs first, so 413 wins over 415.
        let request = request_with_headers(&[
            ("content-length", "1000"),
            ("content-type", "text/plain"),
        ]);
        let response = handler
            .validate_all(&request)
            .expect_err("an oversized, non-JSON request must be rejected");
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn error_responses() {
        assert_eq!(
            ExpectHandler::expectation_failed("test").status().as_u16(),
            417
        );
        let cases = [
            (ExpectHandler::unauthorized("test"), StatusCode::UNAUTHORIZED),
            (ExpectHandler::forbidden("test"), StatusCode::FORBIDDEN),
            (
                ExpectHandler::payload_too_large("test"),
                StatusCode::PAYLOAD_TOO_LARGE,
            ),
            (
                ExpectHandler::unsupported_media_type("test"),
                StatusCode::UNSUPPORTED_MEDIA_TYPE,
            ),
        ];
        for (response, expected) in cases {
            assert_eq!(response.status(), expected);
        }
    }

    #[test]
    fn continue_response_format() {
        assert_eq!(CONTINUE_RESPONSE, b"HTTP/1.1 100 Continue\r\n\r\n");
    }

    #[test]
    fn pre_body_validators() {
        let mut chain = PreBodyValidators::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);

        chain.add(FnValidator::new("auth_check", |req: &Request| {
            match req.headers().get("authorization") {
                Some(_) => Ok(()),
                None => Err(ExpectHandler::unauthorized("Missing Authorization header")),
            }
        }));
        assert!(!chain.is_empty());
        assert_eq!(chain.len(), 1);

        let anonymous = Request::new(Method::Post, "/upload");
        let rejected = chain
            .validate_all(&anonymous)
            .expect_err("a request without Authorization must be rejected");
        assert_eq!(rejected.status(), StatusCode::UNAUTHORIZED);

        let authorized = request_with_headers(&[("authorization", "Bearer token")]);
        assert!(chain.validate_all(&authorized).is_ok());
    }

    #[test]
    fn pre_body_validators_chain() {
        let auth = FnValidator::new("auth", |req: &Request| {
            if req.headers().get("authorization").is_none() {
                return Err(ExpectHandler::unauthorized("Missing auth"));
            }
            Ok(())
        });
        let json_only = FnValidator::new("content_type", |req: &Request| {
            match req.headers().get("content-type") {
                Some(ct) if ct.starts_with(b"application/json") => Ok(()),
                _ => Err(ExpectHandler::unsupported_media_type("Expected JSON")),
            }
        });
        let chain = PreBodyValidators::new().with(auth).with(json_only);
        assert_eq!(chain.len(), 2);

        let good = request_with_headers(&[
            ("authorization", "Bearer token"),
            ("content-type", "application/json"),
        ]);
        assert!(chain.validate_all(&good).is_ok());

        // Auth is registered first, so a missing token is reported even though the type is fine.
        let missing_auth = request_with_headers(&[("content-type", "application/json")]);
        let rejected = chain
            .validate_all(&missing_auth)
            .expect_err("the auth validator must reject first");
        assert_eq!(rejected.status(), StatusCode::UNAUTHORIZED);
    }
}