Skip to main content

fastapi_http/
expect.rs

1//! HTTP `Expect: 100-continue` handling.
2//!
//! This module provides support for the HTTP `Expect: 100-continue` mechanism
//! as defined in [RFC 9110 Section 10.1.1](https://www.rfc-editor.org/rfc/rfc9110#section-10.1.1)
//! (formerly [RFC 7231 Section 5.1.1](https://tools.ietf.org/html/rfc7231#section-5.1.1)).
5//!
6//! # Overview
7//!
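//! When a client sends a request with `Expect: 100-continue`, it is asking the
//! server to confirm, with a `100 Continue` interim response, that the request
//! body is wanted before the client sends it. Clients are not required to wait
//! indefinitely and may send the body anyway after a timeout, so the server must
//! still be prepared to receive it. This mechanism allows the server to: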
11//!
12//! - Validate request headers before receiving potentially large body data
13//! - Reject unauthorized requests without reading the body
14//! - Check Content-Type and Content-Length before accepting uploads
15//!
//! # Example
//!
//! ```ignore
//! use fastapi_http::expect::{ExpectHandler, ExpectResult, CONTINUE_RESPONSE};
//!
//! match ExpectHandler::check_expect(&request) {
//!     ExpectResult::ExpectsContinue => {
//!         // Run pre-body validation (auth, content-type, etc.)
//!         if !validate_auth(&request) {
//!             return ExpectHandler::unauthorized("Invalid credentials");
//!         }
//!         if !validate_content_type(&request) {
//!             return ExpectHandler::unsupported_media_type("Expected application/json");
//!         }
//!
//!         // Validation passed - send 100 Continue
//!         stream.write_all(CONTINUE_RESPONSE).await?;
//!     }
//!     ExpectResult::UnknownExpectation(value) => {
//!         // Unsupported expectation - answer 417 without reading the body
//!         return ExpectHandler::expectation_failed(format!("Unsupported expectation: {value}"));
//!     }
//!     ExpectResult::NoExpectation => {}
//! }
//!
//! // Now proceed to read body and handle request
//! ```
37//!
38//! # Error Responses
39//!
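//! When pre-body validation fails, the server should NOT send `100 Continue`.
//! Instead, it should send an appropriate final error response. The helper
//! constructors on `ExpectHandler` all set `Connection: close`, so the unread
//! request body is not mistaken for the next request on the same connection: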
42//!
43//! - `417 Expectation Failed` - The expectation cannot be met
44//! - `401 Unauthorized` - Authentication required
45//! - `403 Forbidden` - Authorization failed
46//! - `413 Payload Too Large` - Content-Length exceeds limits
47//! - `415 Unsupported Media Type` - Content-Type not accepted
48//!
49//! # Wire Format
50//!
51//! The `100 Continue` response is a simple interim response:
52//!
53//! ```text
54//! HTTP/1.1 100 Continue\r\n
55//! \r\n
56//! ```
57//!
58//! After sending this, the server proceeds to read the request body.
59
60use fastapi_core::{Request, Response, ResponseBody, StatusCode};
61use std::sync::Arc;
62
/// Raw bytes of the interim `HTTP/1.1 100 Continue` response.
///
/// The status line is followed directly by the blank line that terminates the
/// (empty) header section:
/// ```text
/// HTTP/1.1 100 Continue\r\n
/// \r\n
/// ```
pub const CONTINUE_RESPONSE: &[u8] = concat!("HTTP/1.1 100 Continue\r\n", "\r\n").as_bytes();

/// Lowercase `Expect` header token that requests 100-continue handling.
pub const EXPECT_100_CONTINUE: &str = "100-continue";
74
/// Outcome of inspecting a request's `Expect` header.
///
/// Produced by [`ExpectHandler::check_expect`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExpectResult {
    /// No `Expect` header is present; read the body as usual.
    NoExpectation,
    /// Every token in `Expect` is `100-continue`: validate the headers, then
    /// either send `100 Continue` or reject before reading the body.
    ExpectsContinue,
    /// The header contained something other than `100-continue`; the server
    /// should answer `417 Expectation Failed`.
    ///
    /// Carries the trimmed, ASCII-lowercased header value (empty when the raw
    /// value was not valid UTF-8).
    UnknownExpectation(String),
}
85
/// Pre-body checks for requests that carry `Expect: 100-continue`.
///
/// The two configurable limits below are checked by
/// [`ExpectHandler::validate_all`]; associated functions on this type build the
/// matching error responses.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExpectHandler {
    /// Maximum accepted `Content-Length` in bytes; `0` disables the check.
    pub max_content_length: usize,
    /// ASCII case-insensitive prefix the `Content-Type` header must start with;
    /// `None` accepts any (or no) content type.
    pub required_content_type: Option<String>,
}
94
95impl ExpectHandler {
96    /// Create a new ExpectHandler with default settings.
97    #[must_use]
98    pub fn new() -> Self {
99        Self::default()
100    }
101
102    /// Set the maximum Content-Length to accept.
103    #[must_use]
104    pub fn with_max_content_length(mut self, max: usize) -> Self {
105        self.max_content_length = max;
106        self
107    }
108
109    /// Set the required Content-Type prefix.
110    #[must_use]
111    pub fn with_required_content_type(mut self, content_type: impl Into<String>) -> Self {
112        self.required_content_type = Some(content_type.into());
113        self
114    }
115
116    /// Check if a request has an Expect header and what it contains.
117    ///
118    /// Returns:
119    /// - `ExpectResult::NoExpectation` - No Expect header, proceed normally
120    /// - `ExpectResult::ExpectsContinue` - Expect: 100-continue present
121    /// - `ExpectResult::UnknownExpectation` - Unknown expectation value
122    #[must_use]
123    pub fn check_expect(request: &Request) -> ExpectResult {
124        match request.headers().get("expect") {
125            None => ExpectResult::NoExpectation,
126            Some(value) => {
127                let value_str = match std::str::from_utf8(value) {
128                    Ok(s) => s.trim().to_ascii_lowercase(),
129                    Err(_) => return ExpectResult::UnknownExpectation(String::new()),
130                };
131
132                let mut saw_continue = false;
133                for token in value_str.split(',').map(str::trim) {
134                    if token.is_empty() {
135                        return ExpectResult::UnknownExpectation(value_str);
136                    }
137                    if token == EXPECT_100_CONTINUE {
138                        saw_continue = true;
139                    } else {
140                        return ExpectResult::UnknownExpectation(value_str);
141                    }
142                }
143
144                if saw_continue {
145                    ExpectResult::ExpectsContinue
146                } else {
147                    ExpectResult::UnknownExpectation(value_str)
148                }
149            }
150        }
151    }
152
153    /// Check if the request expects 100-continue.
154    ///
155    /// This is a convenience method that returns true only for valid
156    /// `Expect: 100-continue` headers.
157    #[must_use]
158    pub fn expects_continue(request: &Request) -> bool {
159        matches!(Self::check_expect(request), ExpectResult::ExpectsContinue)
160    }
161
162    /// Validate Content-Length against maximum limit.
163    ///
164    /// Returns `Ok(())` if Content-Length is within limits or not specified,
165    /// or `Err(Response)` with 413 Payload Too Large if exceeded.
166    pub fn validate_content_length(&self, request: &Request) -> Result<(), Response> {
167        if self.max_content_length == 0 {
168            return Ok(()); // No limit
169        }
170
171        if let Some(value) = request.headers().get("content-length") {
172            if let Ok(len_str) = std::str::from_utf8(value) {
173                if let Ok(len) = len_str.trim().parse::<usize>() {
174                    if len > self.max_content_length {
175                        return Err(Self::payload_too_large(format!(
176                            "Content-Length {} exceeds maximum {}",
177                            len, self.max_content_length
178                        )));
179                    }
180                }
181            }
182        }
183
184        Ok(())
185    }
186
187    /// Validate Content-Type against required type.
188    ///
189    /// Returns `Ok(())` if Content-Type matches or no requirement is set,
190    /// or `Err(Response)` with 415 Unsupported Media Type if mismatched.
191    pub fn validate_content_type(&self, request: &Request) -> Result<(), Response> {
192        let required = match &self.required_content_type {
193            Some(ct) => ct,
194            None => return Ok(()),
195        };
196
197        match request.headers().get("content-type") {
198            None => Err(Self::unsupported_media_type(format!(
199                "Content-Type required: {required}"
200            ))),
201            Some(value) => {
202                let content_type = std::str::from_utf8(value)
203                    .map(|s| s.trim().to_ascii_lowercase())
204                    .unwrap_or_default();
205
206                if content_type.starts_with(&required.to_ascii_lowercase()) {
207                    Ok(())
208                } else {
209                    Err(Self::unsupported_media_type(format!(
210                        "Expected Content-Type: {required}, got: {content_type}"
211                    )))
212                }
213            }
214        }
215    }
216
217    /// Run all configured validations.
218    ///
219    /// Returns `Ok(())` if all validations pass, or the first error response.
220    pub fn validate_all(&self, request: &Request) -> Result<(), Response> {
221        self.validate_content_length(request)?;
222        self.validate_content_type(request)?;
223        Ok(())
224    }
225
226    /// Create a 417 Expectation Failed response.
227    #[must_use]
228    pub fn expectation_failed(detail: impl Into<String>) -> Response {
229        let detail = detail.into();
230        let body = format!("417 Expectation Failed: {detail}");
231        // StatusCode::EXPECTATION_FAILED is 417
232        Response::with_status(StatusCode::from_u16(417))
233            .header("content-type", b"text/plain; charset=utf-8".to_vec())
234            .header("connection", b"close".to_vec())
235            .body(ResponseBody::Bytes(body.into_bytes()))
236    }
237
238    /// Create a 401 Unauthorized response.
239    #[must_use]
240    pub fn unauthorized(detail: impl Into<String>) -> Response {
241        let detail = detail.into();
242        let body = format!("401 Unauthorized: {detail}");
243        Response::with_status(StatusCode::UNAUTHORIZED)
244            .header("content-type", b"text/plain; charset=utf-8".to_vec())
245            .header("connection", b"close".to_vec())
246            .body(ResponseBody::Bytes(body.into_bytes()))
247    }
248
249    /// Create a 403 Forbidden response.
250    #[must_use]
251    pub fn forbidden(detail: impl Into<String>) -> Response {
252        let detail = detail.into();
253        let body = format!("403 Forbidden: {detail}");
254        Response::with_status(StatusCode::FORBIDDEN)
255            .header("content-type", b"text/plain; charset=utf-8".to_vec())
256            .header("connection", b"close".to_vec())
257            .body(ResponseBody::Bytes(body.into_bytes()))
258    }
259
260    /// Create a 413 Payload Too Large response.
261    #[must_use]
262    pub fn payload_too_large(detail: impl Into<String>) -> Response {
263        let detail = detail.into();
264        let body = format!("413 Payload Too Large: {detail}");
265        Response::with_status(StatusCode::PAYLOAD_TOO_LARGE)
266            .header("content-type", b"text/plain; charset=utf-8".to_vec())
267            .header("connection", b"close".to_vec())
268            .body(ResponseBody::Bytes(body.into_bytes()))
269    }
270
271    /// Create a 415 Unsupported Media Type response.
272    #[must_use]
273    pub fn unsupported_media_type(detail: impl Into<String>) -> Response {
274        let detail = detail.into();
275        let body = format!("415 Unsupported Media Type: {detail}");
276        Response::with_status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
277            .header("content-type", b"text/plain; charset=utf-8".to_vec())
278            .header("connection", b"close".to_vec())
279            .body(ResponseBody::Bytes(body.into_bytes()))
280    }
281}
282
283/// Trait for pre-body validation hooks.
284///
285/// Implement this trait to add custom validation that runs before
286/// sending 100 Continue and reading the request body.
287pub trait PreBodyValidator: Send + Sync {
288    /// Validate the request headers before body is read.
289    ///
290    /// Returns `Ok(())` if validation passes, or `Err(Response)` with
291    /// an appropriate error response if validation fails.
292    fn validate(&self, request: &Request) -> Result<(), Response>;
293
294    /// Optional name for debugging/logging.
295    fn name(&self) -> &'static str {
296        "PreBodyValidator"
297    }
298}
299
300/// A collection of pre-body validators.
301#[derive(Default, Clone)]
302pub struct PreBodyValidators {
303    validators: Vec<Arc<dyn PreBodyValidator>>,
304}
305
306impl std::fmt::Debug for PreBodyValidators {
307    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        f.debug_struct("PreBodyValidators")
309            .field("len", &self.validators.len())
310            .field(
311                "validators",
312                &self.validators.iter().map(|v| v.name()).collect::<Vec<_>>(),
313            )
314            .finish()
315    }
316}
317
318impl PreBodyValidators {
319    /// Create a new empty validator collection.
320    #[must_use]
321    pub fn new() -> Self {
322        Self::default()
323    }
324
325    /// Add a validator to the collection.
326    pub fn add<V: PreBodyValidator + 'static>(&mut self, validator: V) {
327        self.validators.push(Arc::new(validator));
328    }
329
330    /// Add a validator and return self for chaining.
331    #[must_use]
332    pub fn with<V: PreBodyValidator + 'static>(mut self, validator: V) -> Self {
333        self.add(validator);
334        self
335    }
336
337    /// Run all validators in order.
338    ///
339    /// Returns `Ok(())` if all pass, or the first error response.
340    pub fn validate_all(&self, request: &Request) -> Result<(), Response> {
341        for validator in &self.validators {
342            validator.validate(request)?;
343        }
344        Ok(())
345    }
346
347    /// Returns true if there are no validators.
348    #[must_use]
349    pub fn is_empty(&self) -> bool {
350        self.validators.is_empty()
351    }
352
353    /// Returns the number of validators.
354    #[must_use]
355    pub fn len(&self) -> usize {
356        self.validators.len()
357    }
358}
359
360/// A simple function-based pre-body validator.
361pub struct FnValidator<F> {
362    name: &'static str,
363    validate_fn: F,
364}
365
366impl<F> FnValidator<F>
367where
368    F: Fn(&Request) -> Result<(), Response> + Send + Sync,
369{
370    /// Create a new function validator.
371    pub fn new(name: &'static str, validate_fn: F) -> Self {
372        Self { name, validate_fn }
373    }
374}
375
376impl<F> PreBodyValidator for FnValidator<F>
377where
378    F: Fn(&Request) -> Result<(), Response> + Send + Sync,
379{
380    fn validate(&self, request: &Request) -> Result<(), Response> {
381        (self.validate_fn)(request)
382    }
383
384    fn name(&self) -> &'static str {
385        self.name
386    }
387}
388
389// ============================================================================
390// Tests
391// ============================================================================
392
#[cfg(test)]
mod tests {
    use super::*;
    use fastapi_core::Method;

    /// Build a `POST /upload` request carrying the given headers.
    fn request_with_headers(headers: &[(&str, &str)]) -> Request {
        let mut req = Request::new(Method::Post, "/upload");
        for (name, value) in headers {
            req.headers_mut()
                .insert(name.to_string(), value.as_bytes().to_vec());
        }
        req
    }

    /// Build a `POST /upload` request whose only header is `Expect: {value}`.
    fn request_with_expect(value: &str) -> Request {
        request_with_headers(&[("expect", value)])
    }

    /// Assert that `Expect: {value}` is classified as `ExpectsContinue`.
    fn assert_expects_continue(value: &str) {
        let result = ExpectHandler::check_expect(&request_with_expect(value));
        assert!(
            matches!(result, ExpectResult::ExpectsContinue),
            "Expect: {value:?} gave {result:?}"
        );
    }

    /// Assert that `Expect: {value}` is rejected and carries `expected` as the
    /// normalized value. A wrong variant fails the test instead of silently
    /// skipping the comparison.
    fn assert_unknown_expectation(value: &str, expected: &str) {
        match ExpectHandler::check_expect(&request_with_expect(value)) {
            ExpectResult::UnknownExpectation(val) => assert_eq!(val, expected),
            other => panic!("Expect: {value:?} gave {other:?}, wanted UnknownExpectation"),
        }
    }

    #[test]
    fn check_expect_none() {
        let req = Request::new(Method::Get, "/");
        assert!(matches!(
            ExpectHandler::check_expect(&req),
            ExpectResult::NoExpectation
        ));
    }

    #[test]
    fn check_expect_100_continue() {
        assert_expects_continue("100-continue");
    }

    #[test]
    fn check_expect_100_continue_case_insensitive() {
        assert_expects_continue("100-Continue");
        assert_expects_continue("100-CONTINUE");
    }

    #[test]
    fn check_expect_100_continue_surrounding_whitespace() {
        assert_expects_continue("  100-continue\t");
    }

    #[test]
    fn check_expect_100_continue_token_list() {
        assert_expects_continue("100-continue, 100-continue");
    }

    #[test]
    fn check_expect_unknown() {
        assert_unknown_expectation("something-else", "something-else");
    }

    #[test]
    fn check_expect_unknown_is_lowercased() {
        assert_unknown_expectation("Something-Else", "something-else");
    }

    #[test]
    fn expects_continue_helper() {
        let req_yes = request_with_expect("100-continue");
        assert!(ExpectHandler::expects_continue(&req_yes));

        let req_no = Request::new(Method::Get, "/");
        assert!(!ExpectHandler::expects_continue(&req_no));
    }

    #[test]
    fn check_expect_mixed_token_list_is_unknown() {
        assert_unknown_expectation("100-continue, custom", "100-continue, custom");
    }

    #[test]
    fn check_expect_empty_token_is_unknown() {
        assert_unknown_expectation("100-continue,", "100-continue,");
    }

    #[test]
    fn check_expect_empty_value_is_unknown() {
        assert_unknown_expectation("", "");
    }

    #[test]
    fn check_expect_non_utf8_is_unknown() {
        let mut req = Request::new(Method::Post, "/upload");
        req.headers_mut()
            .insert("expect".to_string(), vec![0xff, 0xfe]);
        match ExpectHandler::check_expect(&req) {
            ExpectResult::UnknownExpectation(val) => assert!(val.is_empty()),
            other => panic!("non-UTF-8 Expect gave {other:?}"),
        }
    }

    #[test]
    fn validate_content_length_no_limit() {
        let handler = ExpectHandler::new();
        let req = request_with_headers(&[("content-length", "1000000")]);
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_within_limit() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = request_with_headers(&[("content-length", "500")]);
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_at_limit() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = request_with_headers(&[("content-length", "1024")]);
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_missing_header() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = Request::new(Method::Post, "/upload");
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_exceeds_limit() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = request_with_headers(&[("content-length", "2048")]);
        let result = handler.validate_content_length(&req);
        assert!(result.is_err());
        let response = result.unwrap_err();
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn validate_content_type_no_requirement() {
        let handler = ExpectHandler::new();
        let req = request_with_headers(&[("content-type", "text/plain")]);
        assert!(handler.validate_content_type(&req).is_ok());
    }

    #[test]
    fn validate_content_type_matches() {
        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = request_with_headers(&[("content-type", "application/json; charset=utf-8")]);
        assert!(handler.validate_content_type(&req).is_ok());
    }

    #[test]
    fn validate_content_type_is_case_insensitive() {
        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = request_with_headers(&[("content-type", "Application/JSON")]);
        assert!(handler.validate_content_type(&req).is_ok());
    }

    #[test]
    fn validate_content_type_missing() {
        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = Request::new(Method::Post, "/upload");
        let result = handler.validate_content_type(&req);
        assert!(result.is_err());
        let response = result.unwrap_err();
        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[test]
    fn validate_content_type_mismatch() {
        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = request_with_headers(&[("content-type", "text/plain")]);
        let result = handler.validate_content_type(&req);
        assert!(result.is_err());
        let response = result.unwrap_err();
        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[test]
    fn validate_all_passes() {
        let handler = ExpectHandler::new()
            .with_max_content_length(1024)
            .with_required_content_type("application/json");
        let req = request_with_headers(&[
            ("content-length", "100"),
            ("content-type", "application/json"),
        ]);
        assert!(handler.validate_all(&req).is_ok());
    }

    #[test]
    fn validate_all_fails_on_first_error() {
        let handler = ExpectHandler::new()
            .with_max_content_length(100)
            .with_required_content_type("application/json");
        let req = request_with_headers(&[
            ("content-length", "1000"),     // Exceeds limit
            ("content-type", "text/plain"), // Wrong type
        ]);
        let result = handler.validate_all(&req);
        assert!(result.is_err());
        // Should fail on content-length first
        let response = result.unwrap_err();
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn error_responses() {
        let resp = ExpectHandler::expectation_failed("test");
        assert_eq!(resp.status().as_u16(), 417);

        let resp = ExpectHandler::unauthorized("test");
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);

        let resp = ExpectHandler::forbidden("test");
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        let resp = ExpectHandler::payload_too_large("test");
        assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);

        let resp = ExpectHandler::unsupported_media_type("test");
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[test]
    fn continue_response_format() {
        let expected = b"HTTP/1.1 100 Continue\r\n\r\n";
        assert_eq!(CONTINUE_RESPONSE, expected);
    }

    #[test]
    fn pre_body_validators() {
        let mut validators = PreBodyValidators::new();
        assert!(validators.is_empty());
        assert_eq!(validators.len(), 0);

        // Add a validator that checks for Authorization header
        validators.add(FnValidator::new("auth_check", |req: &Request| {
            if req.headers().get("authorization").is_some() {
                Ok(())
            } else {
                Err(ExpectHandler::unauthorized("Missing Authorization header"))
            }
        }));

        assert!(!validators.is_empty());
        assert_eq!(validators.len(), 1);
        assert!(format!("{validators:?}").contains("auth_check"));

        // Test with missing auth
        let req_no_auth = Request::new(Method::Post, "/upload");
        let result = validators.validate_all(&req_no_auth);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status(), StatusCode::UNAUTHORIZED);

        // Test with auth
        let req_with_auth = request_with_headers(&[("authorization", "Bearer token")]);
        assert!(validators.validate_all(&req_with_auth).is_ok());
    }

    #[test]
    fn pre_body_validators_chain() {
        let validators = PreBodyValidators::new()
            .with(FnValidator::new("auth", |req: &Request| {
                if req.headers().get("authorization").is_some() {
                    Ok(())
                } else {
                    Err(ExpectHandler::unauthorized("Missing auth"))
                }
            }))
            .with(FnValidator::new("content_type", |req: &Request| {
                if let Some(ct) = req.headers().get("content-type") {
                    if ct.starts_with(b"application/json") {
                        return Ok(());
                    }
                }
                Err(ExpectHandler::unsupported_media_type("Expected JSON"))
            }));

        assert_eq!(validators.len(), 2);

        // Both pass
        let req = request_with_headers(&[
            ("authorization", "Bearer token"),
            ("content-type", "application/json"),
        ]);
        assert!(validators.validate_all(&req).is_ok());

        // First fails, so the second is never consulted
        let req = request_with_headers(&[("content-type", "application/json")]);
        let result = validators.validate_all(&req);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status(), StatusCode::UNAUTHORIZED);
    }
}