Skip to main content

fastapi_http/
expect.rs

1//! HTTP `Expect: 100-continue` handling.
2//!
3//! This module provides support for the HTTP `Expect: 100-continue` mechanism
//! as defined in [RFC 9110 Section 10.1.1](https://www.rfc-editor.org/rfc/rfc9110#section-10.1.1)
//! (which obsoletes RFC 7231 Section 5.1.1).
5//!
6//! # Overview
7//!
8//! When a client sends a request with `Expect: 100-continue`, it is indicating
9//! that it will wait for a `100 Continue` interim response before sending the
10//! request body. This allows the server to:
11//!
12//! - Validate request headers before receiving potentially large body data
13//! - Reject unauthorized requests without reading the body
14//! - Check Content-Type and Content-Length before accepting uploads
15//!
16//! # Example
17//!
//! ```ignore
//! use fastapi_http::expect::{ExpectHandler, ExpectResult, CONTINUE_RESPONSE};
//!
//! let handler = ExpectHandler::new()
//!     .with_max_content_length(10 * 1024 * 1024)
//!     .with_required_content_type("application/json");
//!
//! match ExpectHandler::check_expect(&request) {
//!     // No Expect header: read the body as usual.
//!     ExpectResult::NoExpectation => {}
//!     // Expectation we cannot meet: answer 417 without reading the body.
//!     ExpectResult::UnknownExpectation(value) => {
//!         return ExpectHandler::expectation_failed(format!("unsupported expectation: {value}"));
//!     }
//!     ExpectResult::ExpectsContinue => {
//!         // Run pre-body validation (auth, content-type, etc.)
//!         if !validate_auth(&request) {
//!             return ExpectHandler::unauthorized("Invalid credentials");
//!         }
//!         if let Err(response) = handler.validate_all(&request) {
//!             return response;
//!         }
//!
//!         // Validation passed - send 100 Continue
//!         stream.write_all(CONTINUE_RESPONSE).await?;
//!     }
//! }
//!
//! // Now proceed to read body and handle request
//! ```
37//!
38//! # Error Responses
39//!
40//! When pre-body validation fails, the server should NOT send `100 Continue`.
41//! Instead, it should send an appropriate error response:
42//!
43//! - `417 Expectation Failed` - The expectation cannot be met
44//! - `401 Unauthorized` - Authentication required
45//! - `403 Forbidden` - Authorization failed
46//! - `413 Payload Too Large` - Content-Length exceeds limits
47//! - `415 Unsupported Media Type` - Content-Type not accepted
48//!
49//! # Wire Format
50//!
51//! The `100 Continue` response is a simple interim response:
52//!
53//! ```text
54//! HTTP/1.1 100 Continue\r\n
55//! \r\n
56//! ```
57//!
58//! After sending this, the server proceeds to read the request body.
59
60use fastapi_core::{Request, Response, ResponseBody, StatusCode};
61
/// Wire bytes of the minimal HTTP/1.1 `100 Continue` interim response: the
/// status line followed by the blank line that ends the (empty) header section.
///
/// ```text
/// HTTP/1.1 100 Continue\r\n
/// \r\n
/// ```
pub const CONTINUE_RESPONSE: &[u8] = "HTTP/1.1 100 Continue\r\n\r\n".as_bytes();

/// The only `Expect` value this module honours, spelled in lowercase because
/// [`ExpectHandler::check_expect`] lowercases the header value before comparing.
pub const EXPECT_100_CONTINUE: &str = "100-continue";
73
74/// Result of checking the Expect header.
75#[derive(Debug, Clone)]
76pub enum ExpectResult {
77    /// No Expect header present - proceed normally without waiting
78    NoExpectation,
79    /// Expect: 100-continue present - must validate before reading body
80    ExpectsContinue,
81    /// Unknown expectation - should return 417 Expectation Failed
82    UnknownExpectation(String),
83}
84
/// Configurable pre-body checks for requests carrying `Expect: 100-continue`.
///
/// Build one with [`ExpectHandler::new`] and the `with_*` builder methods; the
/// default value performs no Content-Length or Content-Type checks.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExpectHandler {
    /// Maximum Content-Length to accept (`0` = unlimited).
    pub max_content_length: usize,
    /// Required Content-Type prefix, compared ASCII case-insensitively
    /// (`None` = any Content-Type is accepted).
    pub required_content_type: Option<String>,
}
93
94impl ExpectHandler {
95    /// Create a new ExpectHandler with default settings.
96    #[must_use]
97    pub fn new() -> Self {
98        Self::default()
99    }
100
101    /// Set the maximum Content-Length to accept.
102    #[must_use]
103    pub fn with_max_content_length(mut self, max: usize) -> Self {
104        self.max_content_length = max;
105        self
106    }
107
108    /// Set the required Content-Type prefix.
109    #[must_use]
110    pub fn with_required_content_type(mut self, content_type: impl Into<String>) -> Self {
111        self.required_content_type = Some(content_type.into());
112        self
113    }
114
115    /// Check if a request has an Expect header and what it contains.
116    ///
117    /// Returns:
118    /// - `ExpectResult::NoExpectation` - No Expect header, proceed normally
119    /// - `ExpectResult::ExpectsContinue` - Expect: 100-continue present
120    /// - `ExpectResult::UnknownExpectation` - Unknown expectation value
121    #[must_use]
122    pub fn check_expect(request: &Request) -> ExpectResult {
123        match request.headers().get("expect") {
124            None => ExpectResult::NoExpectation,
125            Some(value) => {
126                let value_str = std::str::from_utf8(value)
127                    .map(|s| s.trim().to_ascii_lowercase())
128                    .unwrap_or_default();
129
130                if value_str == EXPECT_100_CONTINUE {
131                    ExpectResult::ExpectsContinue
132                } else {
133                    ExpectResult::UnknownExpectation(value_str)
134                }
135            }
136        }
137    }
138
139    /// Check if the request expects 100-continue.
140    ///
141    /// This is a convenience method that returns true only for valid
142    /// `Expect: 100-continue` headers.
143    #[must_use]
144    pub fn expects_continue(request: &Request) -> bool {
145        matches!(Self::check_expect(request), ExpectResult::ExpectsContinue)
146    }
147
148    /// Validate Content-Length against maximum limit.
149    ///
150    /// Returns `Ok(())` if Content-Length is within limits or not specified,
151    /// or `Err(Response)` with 413 Payload Too Large if exceeded.
152    pub fn validate_content_length(&self, request: &Request) -> Result<(), Response> {
153        if self.max_content_length == 0 {
154            return Ok(()); // No limit
155        }
156
157        if let Some(value) = request.headers().get("content-length") {
158            if let Ok(len_str) = std::str::from_utf8(value) {
159                if let Ok(len) = len_str.trim().parse::<usize>() {
160                    if len > self.max_content_length {
161                        return Err(Self::payload_too_large(format!(
162                            "Content-Length {} exceeds maximum {}",
163                            len, self.max_content_length
164                        )));
165                    }
166                }
167            }
168        }
169
170        Ok(())
171    }
172
173    /// Validate Content-Type against required type.
174    ///
175    /// Returns `Ok(())` if Content-Type matches or no requirement is set,
176    /// or `Err(Response)` with 415 Unsupported Media Type if mismatched.
177    pub fn validate_content_type(&self, request: &Request) -> Result<(), Response> {
178        let required = match &self.required_content_type {
179            Some(ct) => ct,
180            None => return Ok(()),
181        };
182
183        match request.headers().get("content-type") {
184            None => Err(Self::unsupported_media_type(format!(
185                "Content-Type required: {required}"
186            ))),
187            Some(value) => {
188                let content_type = std::str::from_utf8(value)
189                    .map(|s| s.trim().to_ascii_lowercase())
190                    .unwrap_or_default();
191
192                if content_type.starts_with(&required.to_ascii_lowercase()) {
193                    Ok(())
194                } else {
195                    Err(Self::unsupported_media_type(format!(
196                        "Expected Content-Type: {required}, got: {content_type}"
197                    )))
198                }
199            }
200        }
201    }
202
203    /// Run all configured validations.
204    ///
205    /// Returns `Ok(())` if all validations pass, or the first error response.
206    pub fn validate_all(&self, request: &Request) -> Result<(), Response> {
207        self.validate_content_length(request)?;
208        self.validate_content_type(request)?;
209        Ok(())
210    }
211
212    /// Create a 417 Expectation Failed response.
213    #[must_use]
214    pub fn expectation_failed(detail: impl Into<String>) -> Response {
215        let detail = detail.into();
216        let body = format!("417 Expectation Failed: {detail}");
217        // StatusCode::EXPECTATION_FAILED is 417
218        Response::with_status(StatusCode::from_u16(417))
219            .header("content-type", b"text/plain; charset=utf-8".to_vec())
220            .header("connection", b"close".to_vec())
221            .body(ResponseBody::Bytes(body.into_bytes()))
222    }
223
224    /// Create a 401 Unauthorized response.
225    #[must_use]
226    pub fn unauthorized(detail: impl Into<String>) -> Response {
227        let detail = detail.into();
228        let body = format!("401 Unauthorized: {detail}");
229        Response::with_status(StatusCode::UNAUTHORIZED)
230            .header("content-type", b"text/plain; charset=utf-8".to_vec())
231            .header("connection", b"close".to_vec())
232            .body(ResponseBody::Bytes(body.into_bytes()))
233    }
234
235    /// Create a 403 Forbidden response.
236    #[must_use]
237    pub fn forbidden(detail: impl Into<String>) -> Response {
238        let detail = detail.into();
239        let body = format!("403 Forbidden: {detail}");
240        Response::with_status(StatusCode::FORBIDDEN)
241            .header("content-type", b"text/plain; charset=utf-8".to_vec())
242            .header("connection", b"close".to_vec())
243            .body(ResponseBody::Bytes(body.into_bytes()))
244    }
245
246    /// Create a 413 Payload Too Large response.
247    #[must_use]
248    pub fn payload_too_large(detail: impl Into<String>) -> Response {
249        let detail = detail.into();
250        let body = format!("413 Payload Too Large: {detail}");
251        Response::with_status(StatusCode::PAYLOAD_TOO_LARGE)
252            .header("content-type", b"text/plain; charset=utf-8".to_vec())
253            .header("connection", b"close".to_vec())
254            .body(ResponseBody::Bytes(body.into_bytes()))
255    }
256
257    /// Create a 415 Unsupported Media Type response.
258    #[must_use]
259    pub fn unsupported_media_type(detail: impl Into<String>) -> Response {
260        let detail = detail.into();
261        let body = format!("415 Unsupported Media Type: {detail}");
262        Response::with_status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
263            .header("content-type", b"text/plain; charset=utf-8".to_vec())
264            .header("connection", b"close".to_vec())
265            .body(ResponseBody::Bytes(body.into_bytes()))
266    }
267}
268
269/// Trait for pre-body validation hooks.
270///
271/// Implement this trait to add custom validation that runs before
272/// sending 100 Continue and reading the request body.
273pub trait PreBodyValidator: Send + Sync {
274    /// Validate the request headers before body is read.
275    ///
276    /// Returns `Ok(())` if validation passes, or `Err(Response)` with
277    /// an appropriate error response if validation fails.
278    fn validate(&self, request: &Request) -> Result<(), Response>;
279
280    /// Optional name for debugging/logging.
281    fn name(&self) -> &'static str {
282        "PreBodyValidator"
283    }
284}
285
286/// A collection of pre-body validators.
287#[derive(Default)]
288pub struct PreBodyValidators {
289    validators: Vec<Box<dyn PreBodyValidator>>,
290}
291
292impl PreBodyValidators {
293    /// Create a new empty validator collection.
294    #[must_use]
295    pub fn new() -> Self {
296        Self::default()
297    }
298
299    /// Add a validator to the collection.
300    pub fn add<V: PreBodyValidator + 'static>(&mut self, validator: V) {
301        self.validators.push(Box::new(validator));
302    }
303
304    /// Add a validator and return self for chaining.
305    #[must_use]
306    pub fn with<V: PreBodyValidator + 'static>(mut self, validator: V) -> Self {
307        self.add(validator);
308        self
309    }
310
311    /// Run all validators in order.
312    ///
313    /// Returns `Ok(())` if all pass, or the first error response.
314    pub fn validate_all(&self, request: &Request) -> Result<(), Response> {
315        for validator in &self.validators {
316            validator.validate(request)?;
317        }
318        Ok(())
319    }
320
321    /// Returns true if there are no validators.
322    #[must_use]
323    pub fn is_empty(&self) -> bool {
324        self.validators.is_empty()
325    }
326
327    /// Returns the number of validators.
328    #[must_use]
329    pub fn len(&self) -> usize {
330        self.validators.len()
331    }
332}
333
334/// A simple function-based pre-body validator.
335pub struct FnValidator<F> {
336    name: &'static str,
337    validate_fn: F,
338}
339
340impl<F> FnValidator<F>
341where
342    F: Fn(&Request) -> Result<(), Response> + Send + Sync,
343{
344    /// Create a new function validator.
345    pub fn new(name: &'static str, validate_fn: F) -> Self {
346        Self { name, validate_fn }
347    }
348}
349
350impl<F> PreBodyValidator for FnValidator<F>
351where
352    F: Fn(&Request) -> Result<(), Response> + Send + Sync,
353{
354    fn validate(&self, request: &Request) -> Result<(), Response> {
355        (self.validate_fn)(request)
356    }
357
358    fn name(&self) -> &'static str {
359        self.name
360    }
361}
362
363// ============================================================================
364// Tests
365// ============================================================================
366
#[cfg(test)]
mod tests {
    use super::*;
    use fastapi_core::Method;

    fn request_with_expect(value: &str) -> Request {
        let mut req = Request::new(Method::Post, "/upload");
        req.headers_mut()
            .insert("expect".to_string(), value.as_bytes().to_vec());
        req
    }

    fn request_with_headers(headers: &[(&str, &str)]) -> Request {
        let mut req = Request::new(Method::Post, "/upload");
        for (name, value) in headers {
            req.headers_mut()
                .insert(name.to_string(), value.as_bytes().to_vec());
        }
        req
    }

    #[test]
    fn check_expect_none() {
        let req = Request::new(Method::Get, "/");
        assert!(matches!(
            ExpectHandler::check_expect(&req),
            ExpectResult::NoExpectation
        ));
    }

    #[test]
    fn check_expect_100_continue() {
        let req = request_with_expect("100-continue");
        assert!(matches!(
            ExpectHandler::check_expect(&req),
            ExpectResult::ExpectsContinue
        ));
    }

    #[test]
    fn check_expect_100_continue_case_insensitive() {
        let req = request_with_expect("100-Continue");
        assert!(matches!(
            ExpectHandler::check_expect(&req),
            ExpectResult::ExpectsContinue
        ));

        let req = request_with_expect("100-CONTINUE");
        assert!(matches!(
            ExpectHandler::check_expect(&req),
            ExpectResult::ExpectsContinue
        ));
    }

    #[test]
    fn check_expect_trims_whitespace() {
        let req = request_with_expect("  100-continue \t");
        assert!(matches!(
            ExpectHandler::check_expect(&req),
            ExpectResult::ExpectsContinue
        ));
    }

    #[test]
    fn check_expect_unknown() {
        let req = request_with_expect("something-else");
        match ExpectHandler::check_expect(&req) {
            ExpectResult::UnknownExpectation(val) => assert_eq!(val, "something-else"),
            _ => panic!("Expected UnknownExpectation"),
        }
    }

    #[test]
    fn check_expect_unknown_is_normalized() {
        // Unknown values are reported trimmed and lowercased.
        let req = request_with_expect(" Foo-Bar ");
        match ExpectHandler::check_expect(&req) {
            ExpectResult::UnknownExpectation(val) => assert_eq!(val, "foo-bar"),
            other => panic!("Expected UnknownExpectation, got {other:?}"),
        }
    }

    #[test]
    fn check_expect_non_utf8_is_unknown_empty() {
        let mut req = Request::new(Method::Post, "/upload");
        req.headers_mut()
            .insert("expect".to_string(), vec![0xff, 0xfe]);
        match ExpectHandler::check_expect(&req) {
            ExpectResult::UnknownExpectation(val) => assert!(val.is_empty()),
            other => panic!("Expected UnknownExpectation, got {other:?}"),
        }
    }

    #[test]
    fn expects_continue_helper() {
        let req_yes = request_with_expect("100-continue");
        assert!(ExpectHandler::expects_continue(&req_yes));

        let req_no = Request::new(Method::Get, "/");
        assert!(!ExpectHandler::expects_continue(&req_no));
    }

    #[test]
    fn expects_continue_false_for_unknown_expectation() {
        let req = request_with_expect("something-else");
        assert!(!ExpectHandler::expects_continue(&req));
    }

    #[test]
    fn validate_content_length_no_limit() {
        let handler = ExpectHandler::new();
        let req = request_with_headers(&[("content-length", "1000000")]);
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_within_limit() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = request_with_headers(&[("content-length", "500")]);
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_at_limit() {
        // The limit is inclusive: exactly `max` bytes is accepted.
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = request_with_headers(&[("content-length", "1024")]);
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_missing_header() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = Request::new(Method::Post, "/upload");
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_exceeds_limit() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = request_with_headers(&[("content-length", "2048")]);
        let result = handler.validate_content_length(&req);
        assert!(result.is_err());
        let response = result.unwrap_err();
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn validate_content_length_trims_whitespace() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = request_with_headers(&[("content-length", " 2048 ")]);
        let response = handler.validate_content_length(&req).unwrap_err();
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn validate_content_type_no_requirement() {
        let handler = ExpectHandler::new();
        let req = request_with_headers(&[("content-type", "text/plain")]);
        assert!(handler.validate_content_type(&req).is_ok());
    }

    #[test]
    fn validate_content_type_matches() {
        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = request_with_headers(&[("content-type", "application/json; charset=utf-8")]);
        assert!(handler.validate_content_type(&req).is_ok());
    }

    #[test]
    fn validate_content_type_is_case_insensitive() {
        let handler = ExpectHandler::new().with_required_content_type("Application/JSON");
        let req = request_with_headers(&[("content-type", "application/json")]);
        assert!(handler.validate_content_type(&req).is_ok());

        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = request_with_headers(&[("content-type", "APPLICATION/JSON")]);
        assert!(handler.validate_content_type(&req).is_ok());
    }

    #[test]
    fn validate_content_type_missing() {
        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = Request::new(Method::Post, "/upload");
        let result = handler.validate_content_type(&req);
        assert!(result.is_err());
        let response = result.unwrap_err();
        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[test]
    fn validate_content_type_mismatch() {
        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = request_with_headers(&[("content-type", "text/plain")]);
        let result = handler.validate_content_type(&req);
        assert!(result.is_err());
        let response = result.unwrap_err();
        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[test]
    fn validate_all_passes() {
        let handler = ExpectHandler::new()
            .with_max_content_length(1024)
            .with_required_content_type("application/json");
        let req = request_with_headers(&[
            ("content-length", "100"),
            ("content-type", "application/json"),
        ]);
        assert!(handler.validate_all(&req).is_ok());
    }

    #[test]
    fn validate_all_without_configuration_passes() {
        let handler = ExpectHandler::new();
        let req = request_with_headers(&[
            ("content-length", "999999999"),
            ("content-type", "application/octet-stream"),
        ]);
        assert!(handler.validate_all(&req).is_ok());
    }

    #[test]
    fn validate_all_fails_on_first_error() {
        let handler = ExpectHandler::new()
            .with_max_content_length(100)
            .with_required_content_type("application/json");
        let req = request_with_headers(&[
            ("content-length", "1000"),     // Exceeds limit
            ("content-type", "text/plain"), // Wrong type
        ]);
        let result = handler.validate_all(&req);
        assert!(result.is_err());
        // Should fail on content-length first
        let response = result.unwrap_err();
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn error_responses() {
        let resp = ExpectHandler::expectation_failed("test");
        assert_eq!(resp.status().as_u16(), 417);

        let resp = ExpectHandler::unauthorized("test");
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);

        let resp = ExpectHandler::forbidden("test");
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        let resp = ExpectHandler::payload_too_large("test");
        assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);

        let resp = ExpectHandler::unsupported_media_type("test");
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[test]
    fn continue_response_format() {
        let expected = b"HTTP/1.1 100 Continue\r\n\r\n";
        assert_eq!(CONTINUE_RESPONSE, expected);
    }

    #[test]
    fn pre_body_validators() {
        let mut validators = PreBodyValidators::new();
        assert!(validators.is_empty());
        assert_eq!(validators.len(), 0);

        // Add a validator that checks for Authorization header
        validators.add(FnValidator::new("auth_check", |req: &Request| {
            if req.headers().get("authorization").is_some() {
                Ok(())
            } else {
                Err(ExpectHandler::unauthorized("Missing Authorization header"))
            }
        }));

        assert!(!validators.is_empty());
        assert_eq!(validators.len(), 1);

        // Test with missing auth
        let req_no_auth = Request::new(Method::Post, "/upload");
        let result = validators.validate_all(&req_no_auth);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status(), StatusCode::UNAUTHORIZED);

        // Test with auth
        let req_with_auth = request_with_headers(&[("authorization", "Bearer token")]);
        assert!(validators.validate_all(&req_with_auth).is_ok());
    }

    #[test]
    fn pre_body_validators_chain() {
        let validators = PreBodyValidators::new()
            .with(FnValidator::new("auth", |req: &Request| {
                if req.headers().get("authorization").is_some() {
                    Ok(())
                } else {
                    Err(ExpectHandler::unauthorized("Missing auth"))
                }
            }))
            .with(FnValidator::new("content_type", |req: &Request| {
                if let Some(ct) = req.headers().get("content-type") {
                    if ct.starts_with(b"application/json") {
                        return Ok(());
                    }
                }
                Err(ExpectHandler::unsupported_media_type("Expected JSON"))
            }));

        assert_eq!(validators.len(), 2);

        // Both pass
        let req = request_with_headers(&[
            ("authorization", "Bearer token"),
            ("content-type", "application/json"),
        ]);
        assert!(validators.validate_all(&req).is_ok());

        // First fails
        let req = request_with_headers(&[("content-type", "application/json")]);
        let result = validators.validate_all(&req);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status(), StatusCode::UNAUTHORIZED);
    }
}