1use fastapi_core::{Request, Response, ResponseBody, StatusCode};
61use std::sync::Arc;
62
/// Raw bytes of the interim `100 Continue` status line, including the blank
/// line that terminates the (empty) header section.
pub const CONTINUE_RESPONSE: &[u8] = "HTTP/1.1 100 Continue\r\n\r\n".as_bytes();

/// The only `Expect` member this module treats as a known expectation
/// (compared after the header value is lowercased).
pub const EXPECT_100_CONTINUE: &str = "100-continue";
74
/// Outcome of inspecting a request's `Expect` header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExpectResult {
    /// The request carries no `Expect` header.
    NoExpectation,
    /// Every member of the `Expect` header is `100-continue`, i.e. the client
    /// asked for an interim `100 Continue` before it sends the body.
    ExpectsContinue,
    /// The `Expect` header contains something other than `100-continue`
    /// (including empty list members).
    ///
    /// Holds the trimmed, lowercased header value, or an empty string when the
    /// value was not valid UTF-8.
    UnknownExpectation(String),
}
85
/// Header-level request checks intended to run before the request body is read.
///
/// Build one with `ExpectHandler::new()` and the `with_*` builder methods; the
/// default value has every check disabled.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExpectHandler {
    /// Maximum accepted `Content-Length` in bytes; `0` disables the check.
    pub max_content_length: usize,
    /// Media type the request's `Content-Type` must carry; `None` disables the check.
    pub required_content_type: Option<String>,
}
94
95impl ExpectHandler {
96 #[must_use]
98 pub fn new() -> Self {
99 Self::default()
100 }
101
102 #[must_use]
104 pub fn with_max_content_length(mut self, max: usize) -> Self {
105 self.max_content_length = max;
106 self
107 }
108
109 #[must_use]
111 pub fn with_required_content_type(mut self, content_type: impl Into<String>) -> Self {
112 self.required_content_type = Some(content_type.into());
113 self
114 }
115
116 #[must_use]
123 pub fn check_expect(request: &Request) -> ExpectResult {
124 match request.headers().get("expect") {
125 None => ExpectResult::NoExpectation,
126 Some(value) => {
127 let value_str = match std::str::from_utf8(value) {
128 Ok(s) => s.trim().to_ascii_lowercase(),
129 Err(_) => return ExpectResult::UnknownExpectation(String::new()),
130 };
131
132 let mut saw_continue = false;
133 for token in value_str.split(',').map(str::trim) {
134 if token.is_empty() {
135 return ExpectResult::UnknownExpectation(value_str);
136 }
137 if token == EXPECT_100_CONTINUE {
138 saw_continue = true;
139 } else {
140 return ExpectResult::UnknownExpectation(value_str);
141 }
142 }
143
144 if saw_continue {
145 ExpectResult::ExpectsContinue
146 } else {
147 ExpectResult::UnknownExpectation(value_str)
148 }
149 }
150 }
151 }
152
153 #[must_use]
158 pub fn expects_continue(request: &Request) -> bool {
159 matches!(Self::check_expect(request), ExpectResult::ExpectsContinue)
160 }
161
162 pub fn validate_content_length(&self, request: &Request) -> Result<(), Response> {
167 if self.max_content_length == 0 {
168 return Ok(()); }
170
171 if let Some(value) = request.headers().get("content-length") {
172 if let Ok(len_str) = std::str::from_utf8(value) {
173 if let Ok(len) = len_str.trim().parse::<usize>() {
174 if len > self.max_content_length {
175 return Err(Self::payload_too_large(format!(
176 "Content-Length {} exceeds maximum {}",
177 len, self.max_content_length
178 )));
179 }
180 }
181 }
182 }
183
184 Ok(())
185 }
186
187 pub fn validate_content_type(&self, request: &Request) -> Result<(), Response> {
192 let required = match &self.required_content_type {
193 Some(ct) => ct,
194 None => return Ok(()),
195 };
196
197 match request.headers().get("content-type") {
198 None => Err(Self::unsupported_media_type(format!(
199 "Content-Type required: {required}"
200 ))),
201 Some(value) => {
202 let content_type = std::str::from_utf8(value)
203 .map(|s| s.trim().to_ascii_lowercase())
204 .unwrap_or_default();
205
206 if content_type.starts_with(&required.to_ascii_lowercase()) {
207 Ok(())
208 } else {
209 Err(Self::unsupported_media_type(format!(
210 "Expected Content-Type: {required}, got: {content_type}"
211 )))
212 }
213 }
214 }
215 }
216
217 pub fn validate_all(&self, request: &Request) -> Result<(), Response> {
221 self.validate_content_length(request)?;
222 self.validate_content_type(request)?;
223 Ok(())
224 }
225
226 #[must_use]
228 pub fn expectation_failed(detail: impl Into<String>) -> Response {
229 let detail = detail.into();
230 let body = format!("417 Expectation Failed: {detail}");
231 Response::with_status(StatusCode::from_u16(417))
233 .header("content-type", b"text/plain; charset=utf-8".to_vec())
234 .header("connection", b"close".to_vec())
235 .body(ResponseBody::Bytes(body.into_bytes()))
236 }
237
238 #[must_use]
240 pub fn unauthorized(detail: impl Into<String>) -> Response {
241 let detail = detail.into();
242 let body = format!("401 Unauthorized: {detail}");
243 Response::with_status(StatusCode::UNAUTHORIZED)
244 .header("content-type", b"text/plain; charset=utf-8".to_vec())
245 .header("connection", b"close".to_vec())
246 .body(ResponseBody::Bytes(body.into_bytes()))
247 }
248
249 #[must_use]
251 pub fn forbidden(detail: impl Into<String>) -> Response {
252 let detail = detail.into();
253 let body = format!("403 Forbidden: {detail}");
254 Response::with_status(StatusCode::FORBIDDEN)
255 .header("content-type", b"text/plain; charset=utf-8".to_vec())
256 .header("connection", b"close".to_vec())
257 .body(ResponseBody::Bytes(body.into_bytes()))
258 }
259
260 #[must_use]
262 pub fn payload_too_large(detail: impl Into<String>) -> Response {
263 let detail = detail.into();
264 let body = format!("413 Payload Too Large: {detail}");
265 Response::with_status(StatusCode::PAYLOAD_TOO_LARGE)
266 .header("content-type", b"text/plain; charset=utf-8".to_vec())
267 .header("connection", b"close".to_vec())
268 .body(ResponseBody::Bytes(body.into_bytes()))
269 }
270
271 #[must_use]
273 pub fn unsupported_media_type(detail: impl Into<String>) -> Response {
274 let detail = detail.into();
275 let body = format!("415 Unsupported Media Type: {detail}");
276 Response::with_status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
277 .header("content-type", b"text/plain; charset=utf-8".to_vec())
278 .header("connection", b"close".to_vec())
279 .body(ResponseBody::Bytes(body.into_bytes()))
280 }
281}
282
/// A header-level check intended to run before the request body is read.
///
/// Implementations are stored as `Arc<dyn PreBodyValidator>` inside
/// [`PreBodyValidators`], which is why the trait requires `Send + Sync`.
pub trait PreBodyValidator: Send + Sync {
    /// Inspects `request` and decides whether processing may continue.
    ///
    /// # Errors
    ///
    /// Returns the `Response` to send to the client instead of continuing.
    fn validate(&self, request: &Request) -> Result<(), Response>;

    /// Identifier shown in `PreBodyValidators`' `Debug` output.
    ///
    /// Defaults to `"PreBodyValidator"`; override it to tell validators apart.
    fn name(&self) -> &'static str {
        "PreBodyValidator"
    }
}
299
300#[derive(Default, Clone)]
302pub struct PreBodyValidators {
303 validators: Vec<Arc<dyn PreBodyValidator>>,
304}
305
306impl std::fmt::Debug for PreBodyValidators {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 f.debug_struct("PreBodyValidators")
309 .field("len", &self.validators.len())
310 .field(
311 "validators",
312 &self.validators.iter().map(|v| v.name()).collect::<Vec<_>>(),
313 )
314 .finish()
315 }
316}
317
318impl PreBodyValidators {
319 #[must_use]
321 pub fn new() -> Self {
322 Self::default()
323 }
324
325 pub fn add<V: PreBodyValidator + 'static>(&mut self, validator: V) {
327 self.validators.push(Arc::new(validator));
328 }
329
330 #[must_use]
332 pub fn with<V: PreBodyValidator + 'static>(mut self, validator: V) -> Self {
333 self.add(validator);
334 self
335 }
336
337 pub fn validate_all(&self, request: &Request) -> Result<(), Response> {
341 for validator in &self.validators {
342 validator.validate(request)?;
343 }
344 Ok(())
345 }
346
347 #[must_use]
349 pub fn is_empty(&self) -> bool {
350 self.validators.is_empty()
351 }
352
353 #[must_use]
355 pub fn len(&self) -> usize {
356 self.validators.len()
357 }
358}
359
360pub struct FnValidator<F> {
362 name: &'static str,
363 validate_fn: F,
364}
365
366impl<F> FnValidator<F>
367where
368 F: Fn(&Request) -> Result<(), Response> + Send + Sync,
369{
370 pub fn new(name: &'static str, validate_fn: F) -> Self {
372 Self { name, validate_fn }
373 }
374}
375
376impl<F> PreBodyValidator for FnValidator<F>
377where
378 F: Fn(&Request) -> Result<(), Response> + Send + Sync,
379{
380 fn validate(&self, request: &Request) -> Result<(), Response> {
381 (self.validate_fn)(request)
382 }
383
384 fn name(&self) -> &'static str {
385 self.name
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use fastapi_core::Method;

    /// Builds a `POST /upload` request carrying the given headers.
    fn request_with_headers(headers: &[(&str, &str)]) -> Request {
        let mut req = Request::new(Method::Post, "/upload");
        for (name, value) in headers {
            req.headers_mut()
                .insert(name.to_string(), value.as_bytes().to_vec());
        }
        req
    }

    /// Builds a `POST /upload` request whose only header is `Expect: {value}`.
    fn request_with_expect(value: &str) -> Request {
        request_with_headers(&[("expect", value)])
    }

    /// Unwraps the value of an `UnknownExpectation`, failing the test on any other variant.
    fn unknown_value(result: ExpectResult) -> String {
        match result {
            ExpectResult::UnknownExpectation(val) => val,
            other => panic!("expected UnknownExpectation, got {other:?}"),
        }
    }

    /// A validator that relies on the trait's default `name`.
    struct AlwaysOk;

    impl PreBodyValidator for AlwaysOk {
        fn validate(&self, _request: &Request) -> Result<(), Response> {
            Ok(())
        }
    }

    #[test]
    fn check_expect_none() {
        let req = Request::new(Method::Get, "/");
        assert!(matches!(
            ExpectHandler::check_expect(&req),
            ExpectResult::NoExpectation
        ));
    }

    #[test]
    fn check_expect_100_continue() {
        let req = request_with_expect("100-continue");
        assert!(matches!(
            ExpectHandler::check_expect(&req),
            ExpectResult::ExpectsContinue
        ));
    }

    #[test]
    fn check_expect_100_continue_case_insensitive() {
        for value in ["100-Continue", "100-CONTINUE"] {
            let req = request_with_expect(value);
            assert!(
                matches!(
                    ExpectHandler::check_expect(&req),
                    ExpectResult::ExpectsContinue
                ),
                "value {value:?} should be recognised"
            );
        }
    }

    #[test]
    fn check_expect_trims_surrounding_whitespace() {
        let req = request_with_expect("  100-continue  ");
        assert!(matches!(
            ExpectHandler::check_expect(&req),
            ExpectResult::ExpectsContinue
        ));
    }

    #[test]
    fn check_expect_100_continue_token_list() {
        let req = request_with_expect("100-continue, 100-continue");
        assert!(matches!(
            ExpectHandler::check_expect(&req),
            ExpectResult::ExpectsContinue
        ));
    }

    #[test]
    fn check_expect_unknown() {
        let req = request_with_expect("something-else");
        assert_eq!(
            unknown_value(ExpectHandler::check_expect(&req)),
            "something-else"
        );
    }

    #[test]
    fn check_expect_unknown_value_is_lowercased() {
        let req = request_with_expect("X-Custom");
        assert_eq!(unknown_value(ExpectHandler::check_expect(&req)), "x-custom");
    }

    #[test]
    fn check_expect_empty_value_is_unknown() {
        let req = request_with_expect("");
        assert_eq!(unknown_value(ExpectHandler::check_expect(&req)), "");
    }

    #[test]
    fn check_expect_non_utf8_is_unknown_with_empty_value() {
        let mut req = Request::new(Method::Post, "/upload");
        req.headers_mut()
            .insert("expect".to_string(), vec![0xff, 0xfe]);
        assert_eq!(unknown_value(ExpectHandler::check_expect(&req)), "");
    }

    #[test]
    fn expects_continue_helper() {
        let req_yes = request_with_expect("100-continue");
        assert!(ExpectHandler::expects_continue(&req_yes));

        let req_no = Request::new(Method::Get, "/");
        assert!(!ExpectHandler::expects_continue(&req_no));
    }

    #[test]
    fn check_expect_mixed_token_list_is_unknown() {
        let req = request_with_expect("100-continue, custom");
        assert_eq!(
            unknown_value(ExpectHandler::check_expect(&req)),
            "100-continue, custom"
        );
    }

    #[test]
    fn check_expect_empty_token_is_unknown() {
        let req = request_with_expect("100-continue,");
        assert_eq!(
            unknown_value(ExpectHandler::check_expect(&req)),
            "100-continue,"
        );
    }

    #[test]
    fn validate_content_length_no_limit() {
        let handler = ExpectHandler::new();
        let req = request_with_headers(&[("content-length", "1000000")]);
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_within_limit() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = request_with_headers(&[("content-length", "500")]);
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_at_limit_passes() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = request_with_headers(&[("content-length", "1024")]);
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_missing_header_passes() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = Request::new(Method::Post, "/upload");
        assert!(handler.validate_content_length(&req).is_ok());
    }

    #[test]
    fn validate_content_length_exceeds_limit() {
        let handler = ExpectHandler::new().with_max_content_length(1024);
        let req = request_with_headers(&[("content-length", "2048")]);
        let response = handler.validate_content_length(&req).unwrap_err();
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn validate_content_type_no_requirement() {
        let handler = ExpectHandler::new();
        let req = request_with_headers(&[("content-type", "text/plain")]);
        assert!(handler.validate_content_type(&req).is_ok());
    }

    #[test]
    fn validate_content_type_matches() {
        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = request_with_headers(&[("content-type", "application/json; charset=utf-8")]);
        assert!(handler.validate_content_type(&req).is_ok());
    }

    #[test]
    fn validate_content_type_case_insensitive() {
        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = request_with_headers(&[("content-type", "Application/JSON")]);
        assert!(handler.validate_content_type(&req).is_ok());
    }

    #[test]
    fn validate_content_type_missing() {
        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = Request::new(Method::Post, "/upload");
        let response = handler.validate_content_type(&req).unwrap_err();
        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[test]
    fn validate_content_type_mismatch() {
        let handler = ExpectHandler::new().with_required_content_type("application/json");
        let req = request_with_headers(&[("content-type", "text/plain")]);
        let response = handler.validate_content_type(&req).unwrap_err();
        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[test]
    fn validate_all_passes() {
        let handler = ExpectHandler::new()
            .with_max_content_length(1024)
            .with_required_content_type("application/json");
        let req = request_with_headers(&[
            ("content-length", "100"),
            ("content-type", "application/json"),
        ]);
        assert!(handler.validate_all(&req).is_ok());
    }

    #[test]
    fn validate_all_fails_on_first_error() {
        let handler = ExpectHandler::new()
            .with_max_content_length(100)
            .with_required_content_type("application/json");
        // Both checks would fail; the length check runs first.
        let req = request_with_headers(&[
            ("content-length", "1000"),
            ("content-type", "text/plain"),
        ]);
        let response = handler.validate_all(&req).unwrap_err();
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn error_responses() {
        let resp = ExpectHandler::expectation_failed("test");
        assert_eq!(resp.status().as_u16(), 417);

        let resp = ExpectHandler::unauthorized("test");
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);

        let resp = ExpectHandler::forbidden("test");
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        let resp = ExpectHandler::payload_too_large("test");
        assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);

        let resp = ExpectHandler::unsupported_media_type("test");
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[test]
    fn continue_response_format() {
        let expected = b"HTTP/1.1 100 Continue\r\n\r\n";
        assert_eq!(CONTINUE_RESPONSE, expected);
    }

    #[test]
    fn pre_body_validators() {
        let mut validators = PreBodyValidators::new();
        assert!(validators.is_empty());
        assert_eq!(validators.len(), 0);

        validators.add(FnValidator::new("auth_check", |req: &Request| {
            if req.headers().get("authorization").is_some() {
                Ok(())
            } else {
                Err(ExpectHandler::unauthorized("Missing Authorization header"))
            }
        }));

        assert!(!validators.is_empty());
        assert_eq!(validators.len(), 1);

        let req_no_auth = Request::new(Method::Post, "/upload");
        let response = validators.validate_all(&req_no_auth).unwrap_err();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

        let req_with_auth = request_with_headers(&[("authorization", "Bearer token")]);
        assert!(validators.validate_all(&req_with_auth).is_ok());
    }

    #[test]
    fn pre_body_validators_chain() {
        let validators = PreBodyValidators::new()
            .with(FnValidator::new("auth", |req: &Request| {
                if req.headers().get("authorization").is_some() {
                    Ok(())
                } else {
                    Err(ExpectHandler::unauthorized("Missing auth"))
                }
            }))
            .with(FnValidator::new("content_type", |req: &Request| {
                if let Some(ct) = req.headers().get("content-type") {
                    if ct.starts_with(b"application/json") {
                        return Ok(());
                    }
                }
                Err(ExpectHandler::unsupported_media_type("Expected JSON"))
            }));

        assert_eq!(validators.len(), 2);

        let req = request_with_headers(&[
            ("authorization", "Bearer token"),
            ("content-type", "application/json"),
        ]);
        assert!(validators.validate_all(&req).is_ok());

        // Missing auth fails the first validator even though the second would pass.
        let req = request_with_headers(&[("content-type", "application/json")]);
        let response = validators.validate_all(&req).unwrap_err();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn pre_body_validators_first_failure_wins() {
        let validators = PreBodyValidators::new()
            .with(FnValidator::new("deny_401", |_: &Request| {
                Err(ExpectHandler::unauthorized("first"))
            }))
            .with(FnValidator::new("deny_403", |_: &Request| {
                Err(ExpectHandler::forbidden("second"))
            }));
        let req = Request::new(Method::Post, "/upload");
        let response = validators.validate_all(&req).unwrap_err();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn pre_body_validators_debug_lists_names() {
        let validators =
            PreBodyValidators::new().with(FnValidator::new("auth_check", |_: &Request| Ok(())));
        let debug = format!("{validators:?}");
        assert!(debug.contains("len: 1"), "unexpected Debug output: {debug}");
        assert!(
            debug.contains("\"auth_check\""),
            "unexpected Debug output: {debug}"
        );
    }

    #[test]
    fn pre_body_validator_names() {
        assert_eq!(AlwaysOk.name(), "PreBodyValidator");
        assert_eq!(
            FnValidator::new("named", |_: &Request| Ok(())).name(),
            "named"
        );
    }
}