1use fastapi_core::{Request, Response, ResponseBody, StatusCode};
61
/// Raw bytes of an HTTP/1.1 `100 Continue` interim response: the status line
/// followed by the empty line that terminates the (empty) header section.
pub const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";

/// The lowercase expectation token that [`ExpectHandler::check_expect`]
/// compares the normalized `Expect` header value against.
pub const EXPECT_100_CONTINUE: &str = "100-continue";
73
/// Classification of a request's `Expect` header.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly (for example
/// with `assert_eq!`) instead of only through pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExpectResult {
    /// The request carries no `Expect` header.
    NoExpectation,
    /// The header value, once trimmed and ASCII-lowercased, is `100-continue`.
    ExpectsContinue,
    /// The header holds some other value; the payload is the trimmed,
    /// ASCII-lowercased text of that value.
    UnknownExpectation(String),
}
84
/// Configuration for header checks performed on a request.
///
/// The default value performs no checks: the length limit is `0` and no
/// content type is required. `PartialEq`/`Eq` are derived so two
/// configurations can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExpectHandler {
    /// Largest acceptable `Content-Length`, in bytes. `0` disables the check.
    pub max_content_length: usize,
    /// Content type the request must declare. `None` disables the check.
    pub required_content_type: Option<String>,
}
93
94impl ExpectHandler {
95 #[must_use]
97 pub fn new() -> Self {
98 Self::default()
99 }
100
101 #[must_use]
103 pub fn with_max_content_length(mut self, max: usize) -> Self {
104 self.max_content_length = max;
105 self
106 }
107
108 #[must_use]
110 pub fn with_required_content_type(mut self, content_type: impl Into<String>) -> Self {
111 self.required_content_type = Some(content_type.into());
112 self
113 }
114
115 #[must_use]
122 pub fn check_expect(request: &Request) -> ExpectResult {
123 match request.headers().get("expect") {
124 None => ExpectResult::NoExpectation,
125 Some(value) => {
126 let value_str = std::str::from_utf8(value)
127 .map(|s| s.trim().to_ascii_lowercase())
128 .unwrap_or_default();
129
130 if value_str == EXPECT_100_CONTINUE {
131 ExpectResult::ExpectsContinue
132 } else {
133 ExpectResult::UnknownExpectation(value_str)
134 }
135 }
136 }
137 }
138
139 #[must_use]
144 pub fn expects_continue(request: &Request) -> bool {
145 matches!(Self::check_expect(request), ExpectResult::ExpectsContinue)
146 }
147
148 pub fn validate_content_length(&self, request: &Request) -> Result<(), Response> {
153 if self.max_content_length == 0 {
154 return Ok(()); }
156
157 if let Some(value) = request.headers().get("content-length") {
158 if let Ok(len_str) = std::str::from_utf8(value) {
159 if let Ok(len) = len_str.trim().parse::<usize>() {
160 if len > self.max_content_length {
161 return Err(Self::payload_too_large(format!(
162 "Content-Length {} exceeds maximum {}",
163 len, self.max_content_length
164 )));
165 }
166 }
167 }
168 }
169
170 Ok(())
171 }
172
173 pub fn validate_content_type(&self, request: &Request) -> Result<(), Response> {
178 let required = match &self.required_content_type {
179 Some(ct) => ct,
180 None => return Ok(()),
181 };
182
183 match request.headers().get("content-type") {
184 None => Err(Self::unsupported_media_type(format!(
185 "Content-Type required: {required}"
186 ))),
187 Some(value) => {
188 let content_type = std::str::from_utf8(value)
189 .map(|s| s.trim().to_ascii_lowercase())
190 .unwrap_or_default();
191
192 if content_type.starts_with(&required.to_ascii_lowercase()) {
193 Ok(())
194 } else {
195 Err(Self::unsupported_media_type(format!(
196 "Expected Content-Type: {required}, got: {content_type}"
197 )))
198 }
199 }
200 }
201 }
202
203 pub fn validate_all(&self, request: &Request) -> Result<(), Response> {
207 self.validate_content_length(request)?;
208 self.validate_content_type(request)?;
209 Ok(())
210 }
211
212 #[must_use]
214 pub fn expectation_failed(detail: impl Into<String>) -> Response {
215 let detail = detail.into();
216 let body = format!("417 Expectation Failed: {detail}");
217 Response::with_status(StatusCode::from_u16(417))
219 .header("content-type", b"text/plain; charset=utf-8".to_vec())
220 .header("connection", b"close".to_vec())
221 .body(ResponseBody::Bytes(body.into_bytes()))
222 }
223
224 #[must_use]
226 pub fn unauthorized(detail: impl Into<String>) -> Response {
227 let detail = detail.into();
228 let body = format!("401 Unauthorized: {detail}");
229 Response::with_status(StatusCode::UNAUTHORIZED)
230 .header("content-type", b"text/plain; charset=utf-8".to_vec())
231 .header("connection", b"close".to_vec())
232 .body(ResponseBody::Bytes(body.into_bytes()))
233 }
234
235 #[must_use]
237 pub fn forbidden(detail: impl Into<String>) -> Response {
238 let detail = detail.into();
239 let body = format!("403 Forbidden: {detail}");
240 Response::with_status(StatusCode::FORBIDDEN)
241 .header("content-type", b"text/plain; charset=utf-8".to_vec())
242 .header("connection", b"close".to_vec())
243 .body(ResponseBody::Bytes(body.into_bytes()))
244 }
245
246 #[must_use]
248 pub fn payload_too_large(detail: impl Into<String>) -> Response {
249 let detail = detail.into();
250 let body = format!("413 Payload Too Large: {detail}");
251 Response::with_status(StatusCode::PAYLOAD_TOO_LARGE)
252 .header("content-type", b"text/plain; charset=utf-8".to_vec())
253 .header("connection", b"close".to_vec())
254 .body(ResponseBody::Bytes(body.into_bytes()))
255 }
256
257 #[must_use]
259 pub fn unsupported_media_type(detail: impl Into<String>) -> Response {
260 let detail = detail.into();
261 let body = format!("415 Unsupported Media Type: {detail}");
262 Response::with_status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
263 .header("content-type", b"text/plain; charset=utf-8".to_vec())
264 .header("connection", b"close".to_vec())
265 .body(ResponseBody::Bytes(body.into_bytes()))
266 }
267}
268
/// A check that inspects a request and either accepts it or produces the
/// response to reject it with.
///
/// Implementors must be `Send + Sync`. Judging by the name, validators are
/// presumably run before the request body is read — confirm against the caller.
pub trait PreBodyValidator: Send + Sync {
    /// Validates `request`.
    ///
    /// # Errors
    ///
    /// Returns the [`Response`] describing the rejection when validation fails.
    fn validate(&self, request: &Request) -> Result<(), Response>;

    /// Returns a static label for this validator.
    ///
    /// The default implementation returns `"PreBodyValidator"`.
    fn name(&self) -> &'static str {
        "PreBodyValidator"
    }
}
285
/// A collection of boxed [`PreBodyValidator`]s.
///
/// `Default` yields an empty collection.
#[derive(Default)]
pub struct PreBodyValidators {
    /// Registered validators; `add` appends to the end and `validate_all`
    /// walks them front to back.
    validators: Vec<Box<dyn PreBodyValidator>>,
}
291
292impl PreBodyValidators {
293 #[must_use]
295 pub fn new() -> Self {
296 Self::default()
297 }
298
299 pub fn add<V: PreBodyValidator + 'static>(&mut self, validator: V) {
301 self.validators.push(Box::new(validator));
302 }
303
304 #[must_use]
306 pub fn with<V: PreBodyValidator + 'static>(mut self, validator: V) -> Self {
307 self.add(validator);
308 self
309 }
310
311 pub fn validate_all(&self, request: &Request) -> Result<(), Response> {
315 for validator in &self.validators {
316 validator.validate(request)?;
317 }
318 Ok(())
319 }
320
321 #[must_use]
323 pub fn is_empty(&self) -> bool {
324 self.validators.is_empty()
325 }
326
327 #[must_use]
329 pub fn len(&self) -> usize {
330 self.validators.len()
331 }
332}
333
/// Pairs a validation function or closure with a static name.
pub struct FnValidator<F> {
    /// Label for this validator.
    name: &'static str,
    /// The function or closure stored by `new`.
    validate_fn: F,
}
339
340impl<F> FnValidator<F>
341where
342 F: Fn(&Request) -> Result<(), Response> + Send + Sync,
343{
344 pub fn new(name: &'static str, validate_fn: F) -> Self {
346 Self { name, validate_fn }
347 }
348}
349
350impl<F> PreBodyValidator for FnValidator<F>
351where
352 F: Fn(&Request) -> Result<(), Response> + Send + Sync,
353{
354 fn validate(&self, request: &Request) -> Result<(), Response> {
355 (self.validate_fn)(request)
356 }
357
358 fn name(&self) -> &'static str {
359 self.name
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use fastapi_core::Method;

    /// Builds a `POST /upload` request carrying the given header pairs.
    fn request_with_headers(headers: &[(&str, &str)]) -> Request {
        let mut request = Request::new(Method::Post, "/upload");
        for (key, val) in headers {
            request
                .headers_mut()
                .insert(key.to_string(), val.as_bytes().to_vec());
        }
        request
    }

    /// Builds a `POST /upload` request whose only header is `expect: <value>`.
    fn request_with_expect(value: &str) -> Request {
        request_with_headers(&[("expect", value)])
    }

    #[test]
    fn check_expect_none() {
        let request = Request::new(Method::Get, "/");
        let outcome = ExpectHandler::check_expect(&request);
        assert!(matches!(outcome, ExpectResult::NoExpectation));
    }

    #[test]
    fn check_expect_100_continue() {
        let request = request_with_expect("100-continue");
        let outcome = ExpectHandler::check_expect(&request);
        assert!(matches!(outcome, ExpectResult::ExpectsContinue));
    }

    #[test]
    fn check_expect_100_continue_case_insensitive() {
        // Mixed case first, then all upper case.
        for spelling in ["100-Continue", "100-CONTINUE"] {
            let request = request_with_expect(spelling);
            assert!(matches!(
                ExpectHandler::check_expect(&request),
                ExpectResult::ExpectsContinue
            ));
        }
    }

    #[test]
    fn check_expect_unknown() {
        let request = request_with_expect("something-else");
        if let ExpectResult::UnknownExpectation(val) = ExpectHandler::check_expect(&request) {
            assert_eq!(val, "something-else");
        } else {
            panic!("Expected UnknownExpectation");
        }
    }

    #[test]
    fn expects_continue_helper() {
        let with_header = request_with_expect("100-continue");
        assert!(ExpectHandler::expects_continue(&with_header));

        let without_header = Request::new(Method::Get, "/");
        assert!(!ExpectHandler::expects_continue(&without_header));
    }

    #[test]
    fn validate_content_length_no_limit() {
        // The default handler has a limit of 0, which disables the check.
        let request = request_with_headers(&[("content-length", "1000000")]);
        let outcome = ExpectHandler::new().validate_content_length(&request);
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_content_length_within_limit() {
        let request = request_with_headers(&[("content-length", "500")]);
        let outcome = ExpectHandler::new()
            .with_max_content_length(1024)
            .validate_content_length(&request);
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_content_length_exceeds_limit() {
        let request = request_with_headers(&[("content-length", "2048")]);
        let outcome = ExpectHandler::new()
            .with_max_content_length(1024)
            .validate_content_length(&request);
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn validate_content_type_no_requirement() {
        let request = request_with_headers(&[("content-type", "text/plain")]);
        let outcome = ExpectHandler::new().validate_content_type(&request);
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_content_type_matches() {
        // Parameters after the media type must not cause a rejection.
        let request = request_with_headers(&[("content-type", "application/json; charset=utf-8")]);
        let outcome = ExpectHandler::new()
            .with_required_content_type("application/json")
            .validate_content_type(&request);
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_content_type_missing() {
        let request = Request::new(Method::Post, "/upload");
        let outcome = ExpectHandler::new()
            .with_required_content_type("application/json")
            .validate_content_type(&request);
        assert!(outcome.is_err());
        assert_eq!(
            outcome.unwrap_err().status(),
            StatusCode::UNSUPPORTED_MEDIA_TYPE
        );
    }

    #[test]
    fn validate_content_type_mismatch() {
        let request = request_with_headers(&[("content-type", "text/plain")]);
        let outcome = ExpectHandler::new()
            .with_required_content_type("application/json")
            .validate_content_type(&request);
        assert!(outcome.is_err());
        assert_eq!(
            outcome.unwrap_err().status(),
            StatusCode::UNSUPPORTED_MEDIA_TYPE
        );
    }

    #[test]
    fn validate_all_passes() {
        let request = request_with_headers(&[
            ("content-length", "100"),
            ("content-type", "application/json"),
        ]);
        let checks = ExpectHandler::new()
            .with_max_content_length(1024)
            .with_required_content_type("application/json");
        assert!(checks.validate_all(&request).is_ok());
    }

    #[test]
    fn validate_all_fails_on_first_error() {
        // Both checks would fail; the length check runs first, so its status wins.
        let request = request_with_headers(&[
            ("content-length", "1000"),
            ("content-type", "text/plain"),
        ]);
        let checks = ExpectHandler::new()
            .with_max_content_length(100)
            .with_required_content_type("application/json");
        let outcome = checks.validate_all(&request);
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn error_responses() {
        assert_eq!(
            ExpectHandler::expectation_failed("test").status().as_u16(),
            417
        );
        assert_eq!(
            ExpectHandler::unauthorized("test").status(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(
            ExpectHandler::forbidden("test").status(),
            StatusCode::FORBIDDEN
        );
        assert_eq!(
            ExpectHandler::payload_too_large("test").status(),
            StatusCode::PAYLOAD_TOO_LARGE
        );
        assert_eq!(
            ExpectHandler::unsupported_media_type("test").status(),
            StatusCode::UNSUPPORTED_MEDIA_TYPE
        );
    }

    #[test]
    fn continue_response_format() {
        assert_eq!(CONTINUE_RESPONSE, b"HTTP/1.1 100 Continue\r\n\r\n");
    }

    #[test]
    fn pre_body_validators() {
        let mut chain = PreBodyValidators::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);

        chain.add(FnValidator::new("auth_check", |req: &Request| {
            match req.headers().get("authorization") {
                Some(_) => Ok(()),
                None => Err(ExpectHandler::unauthorized("Missing Authorization header")),
            }
        }));

        assert!(!chain.is_empty());
        assert_eq!(chain.len(), 1);

        // Without the header the single validator rejects the request.
        let anonymous = Request::new(Method::Post, "/upload");
        let outcome = chain.validate_all(&anonymous);
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().status(), StatusCode::UNAUTHORIZED);

        // With the header it passes.
        let authorized = request_with_headers(&[("authorization", "Bearer token")]);
        assert!(chain.validate_all(&authorized).is_ok());
    }

    #[test]
    fn pre_body_validators_chain() {
        let auth = FnValidator::new("auth", |req: &Request| {
            match req.headers().get("authorization") {
                Some(_) => Ok(()),
                None => Err(ExpectHandler::unauthorized("Missing auth")),
            }
        });
        let json_only = FnValidator::new("content_type", |req: &Request| {
            match req.headers().get("content-type") {
                Some(ct) if ct.starts_with(b"application/json") => Ok(()),
                _ => Err(ExpectHandler::unsupported_media_type("Expected JSON")),
            }
        });
        let chain = PreBodyValidators::new().with(auth).with(json_only);

        assert_eq!(chain.len(), 2);

        // Both validators are satisfied.
        let complete = request_with_headers(&[
            ("authorization", "Bearer token"),
            ("content-type", "application/json"),
        ]);
        assert!(chain.validate_all(&complete).is_ok());

        // The auth validator was registered first, so its rejection is the one returned.
        let missing_auth = request_with_headers(&[("content-type", "application/json")]);
        let outcome = chain.validate_all(&missing_auth);
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().status(), StatusCode::UNAUTHORIZED);
    }
}