1use super::*;
16use http::header::{CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING};
17use http::{Method, StatusCode};
18use pingora_cache::key::CacheHashKey;
19use pingora_cache::lock::LockStatus;
20use pingora_cache::max_file_size::ERR_RESPONSE_TOO_LARGE;
21use pingora_cache::{ForcedFreshness, HitHandler, HitStatus, RespCacheable::*};
22use pingora_core::protocols::http::conditional_filter::to_304;
23use pingora_core::protocols::http::v1::common::header_value_content_length;
24use pingora_core::ErrorType;
25use range_filter::RangeBodyFilter;
26use std::time::SystemTime;
27
28impl<SV, C> HttpProxy<SV, C>
29where
30 C: custom::Connector,
31{
    /// Run the cache phase for a request before it is proxied upstream.
    ///
    /// Steps, in order: run `request_cache_filter`, compute the cache key (only when caching
    /// is enabled), handle purge requests, bypass the cache when the cacheability prediction
    /// says the asset is not cacheable, then loop on `cache_lookup`, waiting on cache locks
    /// and retrying the lookup where `handle_lock_status` says to.
    ///
    /// Returns `Some((reuse, err))` when this phase fully handled the request (a purge via
    /// `proxy_purge`, or a cache hit served via `proxy_cache_hit`). Returns `None` when the
    /// request should continue to the upstream: caching disabled/bypassed, a miss, an expired
    /// asset that is not served stale, a variance mismatch, or a lookup error.
    pub(crate) async fn proxy_cache(
        self: &Arc<Self>,
        session: &mut Session,
        ctx: &mut SV::CTX,
    ) -> Option<(bool, Option<Box<Error>>)>
    where
        SV: ProxyHttp + Send + Sync + 'static,
        SV::CTX: Send + Sync,
    {
        // A failing request_cache_filter is only logged; the request still proceeds.
        if let Err(e) = self.inner.request_cache_filter(session, ctx) {
            warn!(
                "Fail to request_cache_filter: {e}, {}",
                self.inner.request_summary(session, ctx)
            );
        }

        if session.cache.enabled() {
            // Without a cache key there is nothing to look up, so caching is turned off.
            match self.inner.cache_key_callback(session, ctx) {
                Ok(key) => {
                    session.cache.set_cache_key(key);
                }
                Err(e) => {
                    session.cache.disable(NoCacheReason::StorageError);
                    warn!(
                        "Fail to cache_key_callback: {e}, {}",
                        self.inner.request_summary(session, ctx)
                    );
                }
            }
        }

        // Purges are handled here regardless of whether the cache is enabled.
        if self.inner.is_purge(session, ctx) {
            return self.proxy_purge(session, ctx).await;
        }

        // Predicted-uncacheable assets skip the lookup but keep the cache in bypass mode.
        if session.cache.enabled() && !session.cache.cacheable_prediction() {
            session.cache.bypass();
        }

        if !session.cache.enabled() {
            return None;
        }

        // Each iteration performs one lookup; `continue` retries (after a vary re-key or a
        // cache lock wait), `break` yields the function result.
        loop {
            match session.cache.cache_lookup().await {
                Ok(res) => {
                    // Stays `None` when no asset was found.
                    let mut hit_status_opt = None;
                    if let Some((mut meta, mut handler)) = res {
                        let cache_key = session.cache.cache_key();
                        if let Some(variance) = cache_key.variance_bin() {
                            // The key already carries a variance: the asset found must have
                            // the same one, otherwise give up on caching for this request.
                            if Some(variance) != meta.variance() {
                                warn!("Cache variance mismatch, {variance:?}, {cache_key:?}");
                                session.cache.disable(NoCacheReason::InternalError);
                                break None;
                            }
                        } else {
                            // No variance in the key yet: ask the vary filter whether this
                            // asset varies for this request. A `false` from
                            // `cache_vary_lookup` means the lookup must be redone
                            // (presumably with the variance-qualified key — confirm in
                            // HttpCache).
                            let req_header = session.req_header();
                            let variance = self.inner.cache_vary_filter(&meta, ctx, req_header);
                            if let Some(variance) = variance {
                                if !session.cache.cache_vary_lookup(variance, &meta) {
                                    continue;
                                }
                            }
                        }

                        let is_fresh = meta.is_fresh(SystemTime::now());
                        // The user hook may override freshness; a hook error is logged and
                        // recorded as FailedHitFilter.
                        let hit_status = match self
                            .inner
                            .cache_hit_filter(session, &meta, &mut handler, is_fresh, ctx)
                            .await
                        {
                            Err(e) => {
                                error!(
                                    "Failed to filter cache hit: {e}, {}",
                                    self.inner.request_summary(session, ctx)
                                );
                                HitStatus::FailedHitFilter
                            }
                            Ok(None) => {
                                if is_fresh {
                                    HitStatus::Fresh
                                } else {
                                    HitStatus::Expired
                                }
                            }
                            Ok(Some(ForcedFreshness::ForceExpired)) => {
                                // A forced expiry must not be satisfied by serving stale.
                                meta.disable_serve_stale();
                                HitStatus::ForceExpired
                            }
                            Ok(Some(ForcedFreshness::ForceMiss)) => HitStatus::ForceMiss,
                            Ok(Some(ForcedFreshness::ForceFresh)) => HitStatus::Fresh,
                        };

                        hit_status_opt = Some(hit_status);

                        session.cache.cache_found(meta, handler, hit_status);
                    }

                    // Miss path: no asset, or a status that counts as a miss.
                    if hit_status_opt.is_none_or(HitStatus::is_treated_as_miss) {
                        if session.cache.is_cache_locked() {
                            // Another request holds the cache lock: wait, then either retry
                            // the lookup or go upstream, per handle_lock_status.
                            let lock_status = session.cache.cache_lock_wait().await;
                            if self.handle_lock_status(session, ctx, lock_status) {
                                continue;
                            } else {
                                break None;
                            }
                        } else {
                            self.inner.cache_miss(session, ctx);
                            break None;
                        }
                    }

                    let hit_status = hit_status_opt.expect("None case handled as miss");

                    // Expired hit: decide between waiting on the lock, serving stale while
                    // updating, or going upstream.
                    if !hit_status.is_fresh() {
                        if session.cache.is_cache_locked() {
                            // A background subrequest may carry a write lock handed over by
                            // its parent; adopt it and go upstream to refresh the asset.
                            if let Some(write_lock) = session
                                .subrequest_ctx
                                .as_mut()
                                .and_then(|ctx| ctx.take_write_lock())
                            {
                                session.cache.set_write_lock(write_lock);
                                session.cache.tag_as_subrequest();
                                break None;
                            }
                            let will_serve_stale = session.cache.can_serve_stale_updating()
                                && self.inner.should_serve_stale(session, ctx, None);
                            if !will_serve_stale {
                                let lock_status = session.cache.cache_lock_wait().await;
                                if self.handle_lock_status(session, ctx, lock_status) {
                                    continue;
                                } else {
                                    break None;
                                }
                            }
                            session.cache.set_stale_updating();
                        } else if session.cache.is_cache_lock_writer() {
                            let will_serve_stale = session.cache.can_serve_stale_updating()
                                && self.inner.should_serve_stale(session, ctx, None);
                            if will_serve_stale {
                                // Hand the write lock to a background subrequest that
                                // refreshes the asset while this request serves stale.
                                let (permit, cache_lock) = session.cache.take_write_lock();
                                SubrequestSpawner::new(self.clone()).spawn_background_subrequest(
                                    session.as_ref(),
                                    subrequest::Ctx::builder()
                                        .cache_write_lock(
                                            cache_lock,
                                            session.cache.cache_key().clone(),
                                            permit,
                                        )
                                        .build(),
                                );
                                session.cache.set_stale_updating();
                            } else {
                                break None;
                            }
                        } else {
                            break None;
                        }
                    }

                    // Fresh hit, or stale being served while updating.
                    let (reuse, err) = self.proxy_cache_hit(session, ctx).await;
                    if let Some(e) = err.as_ref() {
                        error!(
                            "Fail to serve cache: {e}, {}",
                            self.inner.request_summary(session, ctx)
                        );
                    }
                    break Some((reuse, err));
                }
                Err(e) => {
                    // A failed lookup is treated as a miss and the request goes upstream.
                    self.inner.cache_miss(session, ctx);
                    warn!(
                        "Fail to cache lookup: {e}, {}",
                        self.inner.request_summary(session, ctx)
                    );
                    break None;
                }
            }
        }
    }
262
    /// Serve the asset currently held by `session.cache` to the downstream client.
    ///
    /// Builds the response header from the cached meta, applies the not-modified (304) filter,
    /// the range filter (only when the hit handler can seek and downstream ranges are not
    /// ignored), `response_filter` and the downstream modules, writes the header, then streams
    /// the body (seeking for single/multipart ranges) unless the response is header-only.
    ///
    /// Returns `(reuse, err)`. `true` is returned both on success and after a 500 was sent in
    /// response to a filter error; `false` when writing to downstream or reading the cache
    /// failed. (`reuse` presumably tells the caller whether the downstream session can be
    /// reused — confirm at the call site.)
    pub(crate) async fn proxy_cache_hit(
        &self,
        session: &mut Session,
        ctx: &mut SV::CTX,
    ) -> (bool, Option<Box<Error>>)
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        use range_filter::*;

        let seekable = session.cache.hit_handler().can_seek();
        let mut header = cache_hit_header(&session.cache);

        let req = session.req_header();

        // A not-modified filter error is logged and treated as "modified".
        let not_modified = match self.inner.cache_not_modified_filter(session, &header, ctx) {
            Ok(not_modified) => not_modified,
            Err(e) => {
                warn!(
                    "Failed to run cache not modified filter: {e}, {}",
                    self.inner.request_summary(session, ctx)
                );
                false
            }
        };
        if not_modified {
            to_304(&mut header);
        }
        let header_only = not_modified || req.method == http::method::Method::HEAD;

        // Ranges are only honored when the body can be seeked.
        let range_type = if seekable && !session.ignore_downstream_range {
            self.inner.range_header_filter(session, &mut header, ctx)
        } else {
            RangeType::None
        };

        // An unsatisfiable range is answered with headers only.
        let header_only = header_only || matches!(range_type, RangeType::Invalid);
        debug!("header: {header:?}");

        match self.inner.response_filter(session, &mut header, ctx).await {
            Ok(_) => {
                if let Err(e) = session
                    .downstream_modules_ctx
                    .response_header_filter(&mut header, header_only)
                    .await
                {
                    error!(
                        "Failed to run downstream modules response header filter in hit: {e}, {}",
                        self.inner.request_summary(session, ctx)
                    );
                    session
                        .as_mut()
                        .respond_error(500)
                        .await
                        .unwrap_or_else(|e| {
                            error!("failed to send error response to downstream: {e}");
                        });
                    return (true, Some(e));
                }

                if let Err(e) = session
                    .as_mut()
                    .write_response_header(header)
                    .await
                    .map_err(|e| e.into_down())
                {
                    return (false, Some(e));
                }
            }
            Err(e) => {
                error!(
                    "Failed to run response filter in hit: {e}, {}",
                    self.inner.request_summary(session, ctx)
                );
                session
                    .as_mut()
                    .respond_error(500)
                    .await
                    .unwrap_or_else(|e| {
                        error!("failed to send error response to downstream: {e}");
                    });
                return (true, Some(e));
            }
        }
        debug!("finished sending cached header to downstream");

        /// Seek the hit handler to the next part of a multipart range request.
        ///
        /// Returns `Ok(false)` (without seeking) when the range is not multipart or the hit
        /// handler cannot seek across multipart ranges.
        fn seek_multipart(
            hit_handler: &mut HitHandler,
            range_filter: &mut RangeBodyFilter,
        ) -> Result<bool> {
            if !range_filter.is_multipart_range() || !hit_handler.can_seek_multipart() {
                return Ok(false);
            }
            let r = range_filter.next_cache_multipart_range();
            hit_handler.seek_multipart(r.start, Some(r.end))?;
            // Keep the filter's position in sync with where the handler now reads from.
            range_filter.set_current_cursor(r.start);
            Ok(true)
        }

        if !header_only {
            // A single range is served by seeking directly when possible; otherwise (and for
            // multipart) a RangeBodyFilter trims/frames the body.
            let mut maybe_range_filter = match &range_type {
                RangeType::Single(r) => {
                    if session.cache.hit_handler().can_seek() {
                        if let Err(e) = session.cache.hit_handler().seek(r.start, Some(r.end)) {
                            return (false, Some(e));
                        }
                        None
                    } else {
                        Some(RangeBodyFilter::new_range(range_type.clone()))
                    }
                }
                RangeType::Multi(_) => {
                    let mut range_filter = RangeBodyFilter::new_range(range_type.clone());
                    if let Err(e) = seek_multipart(session.cache.hit_handler(), &mut range_filter) {
                        return (false, Some(e));
                    }
                    Some(range_filter)
                }
                // Invalid ranges forced `header_only` above.
                RangeType::Invalid => unreachable!(),
                RangeType::None => None,
            };
            loop {
                match session.cache.hit_handler().read_body().await {
                    Ok(raw_body) => {
                        let end = raw_body.is_none();

                        if end {
                            // The handler finished the current seek window; for multipart
                            // ranges with more parts pending, seek to the next part and
                            // keep reading instead of ending the body.
                            if let Some(range_filter) = maybe_range_filter.as_mut() {
                                if range_filter.should_cache_seek_again() {
                                    let e = match seek_multipart(
                                        session.cache.hit_handler(),
                                        range_filter,
                                    ) {
                                        Ok(true) => {
                                            continue;
                                        }
                                        Ok(false) => {
                                            Error::explain(
                                                InternalError,
                                                "hit handler cannot seek for multipart again",
                                            )
                                        }
                                        Err(e) => e,
                                    };
                                    return (false, Some(e));
                                }
                            }
                        }

                        let mut body = if let Some(range_filter) = maybe_range_filter.as_mut() {
                            range_filter.filter_body(raw_body)
                        } else {
                            raw_body
                        };

                        // The body filter may ask for the response to be delayed.
                        match self
                            .inner
                            .response_body_filter(session, &mut body, end, ctx)
                        {
                            Ok(Some(duration)) => {
                                trace!("delaying response for {duration:?}");
                                time::sleep(duration).await;
                            }
                            Ok(None) => {}
                            Err(e) => {
                                return (false, Some(e));
                            }
                        }

                        if let Err(e) = session
                            .downstream_modules_ctx
                            .response_body_filter(&mut body, end)
                        {
                            return (false, Some(e));
                        }

                        // Nothing to write for an empty, non-final chunk.
                        if !end && body.as_ref().is_none_or(|b| b.is_empty()) {
                            continue;
                        }

                        let b = body.unwrap_or_default();
                        if let Err(e) = session
                            .as_mut()
                            .write_response_body(b, end)
                            .await
                            .map_err(|e| e.into_down())
                        {
                            return (false, Some(e));
                        }
                        if end {
                            break;
                        }
                    }
                    Err(e) => return (false, Some(e)),
                }
            }
        }

        // Errors while finishing the hit handler are logged but do not fail the response.
        if let Err(e) = session.cache.finish_hit_handler().await {
            warn!("Error during finish_hit_handler: {}", e);
        }

        match session.as_mut().finish_body().await {
            Ok(_) => {
                debug!("finished sending cached body to downstream");
                (true, None)
            }
            Err(e) => (false, Some(e)),
        }
    }
502
503 pub(crate) fn downstream_response_conditional_filter(
506 &self,
507 use_cache: &mut ServeFromCache,
508 session: &Session,
509 resp: &mut ResponseHeader,
510 ctx: &mut SV::CTX,
511 ) where
512 SV: ProxyHttp,
513 {
514 let req = session.req_header();
516
517 let not_modified = match self.inner.cache_not_modified_filter(session, resp, ctx) {
518 Ok(not_modified) => not_modified,
519 Err(e) => {
520 warn!(
523 "Failed to run cache not modified filter: {e}, {}",
524 self.inner.request_summary(session, ctx)
525 );
526 false
527 }
528 };
529
530 if not_modified {
531 to_304(resp);
532 }
533 let header_only = not_modified || req.method == http::method::Method::HEAD;
534 if header_only && use_cache.is_on() {
535 use_cache.enable_header_only();
538 }
539 }
540
    /// Feed one upstream response task into the cache (miss handling / cache filling).
    ///
    /// No-op unless caching is enabled or the cache is in bypass mode.
    /// - `Header`: informational headers other than `101 Switching Protocols` are ignored.
    ///   `response_cache_filter` decides cacheability. For a cacheable response the max file
    ///   size is enforced via `Content-Length`, then the meta/variance are stored and a miss
    ///   handler is created (finished immediately when the header ends the stream).
    /// - `Body`/`UpgradedBody`: bytes are counted against the max file size and written to
    ///   the miss handler, which is finished on end of stream.
    /// - `Done`: finishes the miss handler. `Trailer` and `Failed` are ignored here.
    ///
    /// # Errors
    /// Propagates errors from `response_cache_filter` and the miss handler, and returns
    /// `ERR_RESPONSE_TOO_LARGE` when body data would exceed the configured max file size.
    pub(crate) async fn cache_http_task(
        &self,
        session: &mut Session,
        task: &HttpTask,
        ctx: &mut SV::CTX,
        serve_from_cache: &mut ServeFromCache,
    ) -> Result<()>
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        if !session.cache.enabled() && !session.cache.bypassing() {
            return Ok(());
        }

        match task {
            HttpTask::Header(header, end_stream) => {
                // Only final headers (and 101 upgrades) are relevant for caching.
                if header.status.is_informational()
                    && header.status != StatusCode::SWITCHING_PROTOCOLS
                {
                    return Ok(());
                }
                match self.inner.response_cache_filter(session, header, ctx)? {
                    Cacheable(meta) => {
                        let mut fill_cache = true;
                        // Bypass mode: the request skipped the lookup (predicted
                        // uncacheable) but the response turned out cacheable.
                        if session.cache.bypassing() {
                            // With a size limit configured, a response without
                            // Content-Length cannot be size-checked up front.
                            if session.cache.max_file_size_bytes().is_some()
                                && !meta.headers().contains_key(header::CONTENT_LENGTH)
                            {
                                session
                                    .cache
                                    .disable(NoCacheReason::PredictedResponseTooLarge);
                                return Ok(());
                            }

                            session.cache.response_became_cacheable();

                            // Only a GET 200 may start filling the cache from a bypass;
                            // anything else is deferred.
                            if session.req_header().method == Method::GET
                                && meta.response_header().status == StatusCode::OK
                            {
                                self.inner.cache_miss(session, ctx);
                                if !session.cache.enabled() {
                                    fill_cache = false;
                                }
                            } else {
                                fill_cache = false;
                                session.cache.disable(NoCacheReason::Deferred);
                            }
                        }

                        if session.cache.enabled() {
                            // Reject responses whose declared length exceeds the limit.
                            if let Some(max_file_size) = session.cache.max_file_size_bytes() {
                                let content_length_hdr = meta.headers().get(header::CONTENT_LENGTH);
                                if let Some(content_length) =
                                    header_value_content_length(content_length_hdr)
                                {
                                    if content_length > max_file_size {
                                        fill_cache = false;
                                        session.cache.response_became_uncacheable(
                                            NoCacheReason::ResponseTooLarge,
                                        );
                                        session.cache.disable(NoCacheReason::ResponseTooLarge);
                                        // NOTE(review): presumably so the oversized response
                                        // is passed through without range handling — confirm.
                                        session.ignore_downstream_range = true;
                                    }
                                }
                            }
                        }
                        if fill_cache {
                            let req_header = session.req_header();
                            let variance = self.inner.cache_vary_filter(&meta, ctx, req_header);
                            session.cache.set_cache_meta(meta);
                            session.cache.update_variance(variance);
                            session.cache.set_miss_handler().await?;
                            // Let downstream read the body from the cache as it is written.
                            if session.cache.miss_body_reader().is_some() {
                                serve_from_cache.enable_miss();
                            }
                            // Header-only response: complete the cached asset right away.
                            if *end_stream {
                                session
                                    .cache
                                    .miss_handler()
                                    .unwrap()
                                    .write_body(Bytes::new(), true)
                                    .await?;
                                session.cache.finish_miss_handler().await?;
                            }
                        }
                    }
                    Uncacheable(reason) => {
                        // In bypass mode the asset was already treated as uncacheable.
                        if !session.cache.bypassing() {
                            session.cache.response_became_uncacheable(reason);
                        }
                        session.cache.disable(reason);
                    }
                }
            }
            HttpTask::Body(data, end_stream) | HttpTask::UpgradedBody(data, end_stream) => {
                match data {
                    Some(d) => {
                        if session.cache.enabled() {
                            // Enforces the size limit for bodies whose length was not known
                            // from the header.
                            let body_size_allowed =
                                session.cache.track_body_bytes_for_max_file_size(d.len());
                            if !body_size_allowed {
                                debug!("chunked response exceeded max cache size, remembering that it is uncacheable");
                                session
                                    .cache
                                    .response_became_uncacheable(NoCacheReason::ResponseTooLarge);

                                return Error::e_explain(
                                    ERR_RESPONSE_TOO_LARGE,
                                    format!(
                                        "writing data of size {} bytes would exceed max file size of {} bytes",
                                        d.len(),
                                        session.cache.max_file_size_bytes().expect("max file size bytes must be set to exceed size")
                                    ),
                                );
                            }

                            // Assumes a miss handler was set when the header was cached.
                            let miss_handler = session.cache.miss_handler().unwrap();

                            miss_handler.write_body(d.clone(), *end_stream).await?;
                            if *end_stream {
                                session.cache.finish_miss_handler().await?;
                            }
                        }
                    }
                    None => {
                        if session.cache.enabled() && *end_stream {
                            session.cache.finish_miss_handler().await?;
                        }
                    }
                }
            }
            // Trailers are not stored.
            HttpTask::Trailer(_) => {}
            HttpTask::Done => {
                if session.cache.enabled() {
                    session.cache.finish_miss_handler().await?;
                }
            }
            HttpTask::Failed(_) => {
                // Nothing to do for the cache here.
            }
        }
        Ok(())
    }
720
    /// Handle an upstream response header that may revalidate the cached asset or allow
    /// serving stale.
    ///
    /// - `304 Not Modified` with a cached asset: run `upstream_response_filter`, merge the 304
    ///   into the cached header and re-run `response_cache_filter` to refresh the cached meta
    ///   (keeping the old variance) or mark it uncacheable. Returns `true`.
    /// - `304` without a cached asset: caching is disabled for this request; returns `false`.
    /// - 5xx: when stale-if-error is allowed, nothing was written downstream yet and
    ///   `should_serve_stale` agrees, the write lock is released and `true` is returned.
    /// - Anything else (or caching disabled): `false`.
    ///
    /// A `true` return presumably tells the caller to serve from cache — confirm at call site.
    pub(crate) async fn revalidate_or_stale(
        &self,
        session: &mut Session,
        task: &mut HttpTask,
        ctx: &mut SV::CTX,
    ) -> bool
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        if !session.cache.enabled() {
            return false;
        }

        match task {
            HttpTask::Header(resp, _eos) => {
                if resp.status == StatusCode::NOT_MODIFIED {
                    if session.cache.maybe_cache_meta().is_some() {
                        // The 304 goes through the same upstream filter as other responses.
                        if let Err(err) = self
                            .inner
                            .upstream_response_filter(session, resp, ctx)
                            .await
                        {
                            error!("upstream response filter error on 304: {err:?}");
                            session.cache.revalidate_uncacheable(
                                *resp.clone(),
                                NoCacheReason::InternalError,
                            );
                            return true;
                        }
                        // Combine the 304's headers with the cached response header.
                        let merged_header = session.cache.revalidate_merge_header(resp);
                        match self
                            .inner
                            .response_cache_filter(session, &merged_header, ctx)
                        {
                            Ok(Cacheable(mut meta)) => {
                                // Keep the revalidated asset under the same variance.
                                let old_meta = session.cache.maybe_cache_meta().unwrap();
                                if let Some(old_variance) = old_meta.variance() {
                                    meta.set_variance(old_variance);
                                }
                                // A failed meta update is logged only.
                                if let Err(e) = session.cache.revalidate_cache_meta(meta).await {
                                    warn!("revalidate_cache_meta failed {e:?}");
                                }
                            }
                            Ok(Uncacheable(reason)) => {
                                debug!("Uncacheable {reason:?} 304 received");
                                session.cache.response_became_uncacheable(reason);
                                session.cache.revalidate_uncacheable(merged_header, reason);
                            }
                            Err(e) => {
                                warn!("Error {e:?} response_cache_filter during revalidation");
                                session.cache.revalidate_uncacheable(
                                    merged_header,
                                    NoCacheReason::InternalError,
                                );
                            }
                        }
                        true
                    } else {
                        // A 304 with nothing cached cannot be served from cache.
                        warn!("304 received without cached asset, disable caching");
                        let reason = NoCacheReason::Custom("304 on miss");
                        session.cache.response_became_uncacheable(reason);
                        session.cache.disable(reason);
                        false
                    }
                } else if resp.status.is_server_error() {
                    // Stale-if-error is only possible before anything went downstream.
                    if !session.cache.can_serve_stale_error()
                        || session.response_written().is_some()
                    {
                        return false;
                    }

                    // Wrap the status as an upstream error for the should_serve_stale hook.
                    let http_status_error = Error::create(
                        ErrorType::HTTPStatus(resp.status.as_u16()),
                        ErrorSource::Upstream,
                        None,
                        None,
                    );
                    if self
                        .inner
                        .should_serve_stale(session, ctx, Some(&http_status_error))
                    {
                        session
                            .cache
                            .release_write_lock(NoCacheReason::UpstreamError);
                        true
                    } else {
                        false
                    }
                } else {
                    false
                }
            }
            _ => false,
        }
    }
857
858 pub(crate) async fn handle_stale_if_error(
861 &self,
862 session: &mut Session,
863 ctx: &mut SV::CTX,
864 error: &Error,
865 ) -> Option<(bool, Option<Box<Error>>)>
866 where
867 SV: ProxyHttp + Send + Sync,
868 SV::CTX: Send + Sync,
869 {
870 if !session.cache.can_serve_stale_error() {
872 return None;
873 }
874
875 if session.response_written().is_some() {
878 return None;
879 }
880
881 if !self.inner.should_serve_stale(session, ctx, Some(error)) {
883 return None;
884 }
885
886 warn!(
888 "Fail to proxy: {}, serving stale, {}",
889 error,
890 self.inner.request_summary(session, ctx)
891 );
892
893 session
895 .cache
896 .release_write_lock(NoCacheReason::UpstreamError);
897
898 Some(self.proxy_cache_hit(session, ctx).await)
899 }
900
901 fn handle_lock_status(
903 &self,
904 session: &mut Session,
905 ctx: &SV::CTX,
906 lock_status: LockStatus,
907 ) -> bool
908 where
909 SV: ProxyHttp,
910 {
911 debug!("cache unlocked {lock_status:?}");
912 match lock_status {
913 LockStatus::Done => true,
915 LockStatus::TransientError => true,
917 LockStatus::GiveUp => {
919 session.cache.disable(NoCacheReason::CacheLockGiveUp);
921 false
923 }
924 LockStatus::Dangling => {
926 warn!(
928 "Dangling cache lock, {}",
929 self.inner.request_summary(session, ctx)
930 );
931 true
932 }
933 LockStatus::WaitTimeout => {
936 warn!(
937 "Cache lock timeout, {}",
938 self.inner.request_summary(session, ctx)
939 );
940 session.cache.disable(NoCacheReason::CacheLockTimeout);
941 false
943 }
944 LockStatus::AgeTimeout => true,
948 LockStatus::Waiting => panic!("impossible LockStatus::Waiting"),
950 }
951 }
952}
953
954fn cache_hit_header(cache: &HttpCache) -> Box<ResponseHeader> {
955 let mut header = Box::new(cache.cache_meta().response_header_copy());
956 let no_body = matches!(header.status.as_u16(), 204 | 304);
960
961 if !cache.upstream_used() {
965 let age = cache.cache_meta().age().as_secs();
966 header.insert_header(http::header::AGE, age).unwrap();
967 }
968 log::debug!("cache header: {header:?} {:?}", cache.phase());
969
970 header.set_version(Version::HTTP_11);
974
975 if !no_body
978 && !header.status.is_informational()
979 && header.headers.get(http::header::CONTENT_LENGTH).is_none()
980 {
981 header
982 .insert_header(http::header::TRANSFER_ENCODING, "chunked")
983 .unwrap();
984 }
985 header
986}
987
988pub mod range_filter {
990 use super::*;
991 use bytes::BytesMut;
992 use http::header::*;
993 use std::ops::Range;
994
995 fn parse_number(input: &[u8]) -> Option<usize> {
997 str::from_utf8(input).ok()?.parse().ok()
998 }
999
1000 fn parse_range_header(
1001 range: &[u8],
1002 content_length: usize,
1003 max_multipart_ranges: Option<usize>,
1004 ) -> RangeType {
1005 use regex::Regex;
1006
1007 static RE_SINGLE_RANGE_PART: Lazy<Regex> =
1009 Lazy::new(|| Regex::new(r"(?i)^\s*(?P<start>\d*)-(?P<end>\d*)\s*$").unwrap());
1010
1011 let range_str = match str::from_utf8(range) {
1013 Ok(s) => s,
1014 Err(_) => return RangeType::None,
1015 };
1016
1017 let mut parts = range_str.splitn(2, "=");
1019
1020 let prefix = parts.next();
1022 if !prefix.is_some_and(|s| s.eq_ignore_ascii_case("bytes")) {
1023 return RangeType::None;
1024 }
1025
1026 let Some(ranges_str) = parts.next() else {
1027 return RangeType::None;
1029 };
1030
1031 if ranges_str.trim().is_empty() {
1034 return RangeType::Invalid;
1035 }
1036
1037 let mut range_count = 0;
1039 for _ in ranges_str.split(',') {
1040 range_count += 1;
1041 if let Some(max_ranges) = max_multipart_ranges {
1042 if range_count >= max_ranges {
1043 return RangeType::None;
1045 }
1046 }
1047 }
1048 let mut ranges: Vec<Range<usize>> = Vec::with_capacity(range_count);
1049
1050 let mut last_range_end = 0;
1052 for part in ranges_str.split(',') {
1053 let captured = match RE_SINGLE_RANGE_PART.captures(part) {
1054 Some(c) => c,
1055 None => {
1056 return RangeType::None;
1057 }
1058 };
1059
1060 let maybe_start = captured
1061 .name("start")
1062 .and_then(|s| s.as_str().parse::<usize>().ok());
1063 let end = captured
1064 .name("end")
1065 .and_then(|s| s.as_str().parse::<usize>().ok());
1066
1067 let range = if let Some(start) = maybe_start {
1068 if start >= content_length {
1069 continue;
1071 }
1072 let end = std::cmp::min(end.unwrap_or(content_length - 1), content_length - 1) + 1;
1076 if end <= start {
1077 continue;
1079 }
1080 start..end
1081 } else {
1082 if let Some(end) = end {
1085 if content_length >= end {
1086 (content_length - end)..content_length
1087 } else {
1088 0..content_length
1090 }
1091 } else {
1092 continue;
1094 }
1095 };
1096 if range.start < last_range_end {
1099 return RangeType::None;
1100 }
1101 last_range_end = range.end;
1102 ranges.push(range);
1103 }
1104
1105 if ranges.is_empty() {
1115 RangeType::Invalid
1117 } else if ranges.len() == 1 {
1118 RangeType::Single(ranges[0].clone()) } else {
1120 RangeType::Multi(MultiRangeInfo::new(ranges))
1121 }
1122 }
1123 #[test]
1124 fn test_parse_range() {
1125 assert_eq!(
1126 parse_range_header(b"bytes=0-1", 10, None),
1127 RangeType::new_single(0, 2)
1128 );
1129 assert_eq!(
1130 parse_range_header(b"bYTes=0-9", 10, None),
1131 RangeType::new_single(0, 10)
1132 );
1133 assert_eq!(
1134 parse_range_header(b"bytes=0-12", 10, None),
1135 RangeType::new_single(0, 10)
1136 );
1137 assert_eq!(
1138 parse_range_header(b"bytes=0-", 10, None),
1139 RangeType::new_single(0, 10)
1140 );
1141 assert_eq!(
1142 parse_range_header(b"bytes=2-1", 10, None),
1143 RangeType::Invalid
1144 );
1145 assert_eq!(
1146 parse_range_header(b"bytes=10-11", 10, None),
1147 RangeType::Invalid
1148 );
1149 assert_eq!(
1150 parse_range_header(b"bytes=-2", 10, None),
1151 RangeType::new_single(8, 10)
1152 );
1153 assert_eq!(
1154 parse_range_header(b"bytes=-12", 10, None),
1155 RangeType::new_single(0, 10)
1156 );
1157 assert_eq!(parse_range_header(b"bytes=-", 10, None), RangeType::Invalid);
1158 assert_eq!(parse_range_header(b"bytes=", 10, None), RangeType::Invalid);
1159 assert_eq!(
1160 parse_range_header(b"bytes= ", 10, None),
1161 RangeType::Invalid
1162 );
1163 }
1164
1165 #[test]
1167 fn test_parse_range_header_multi() {
1168 assert_eq!(
1169 parse_range_header(b"bytes=0-1,4-5", 10, None)
1170 .get_multirange_info()
1171 .expect("Should have multipart info for Multipart range request")
1172 .ranges,
1173 (vec![Range { start: 0, end: 2 }, Range { start: 4, end: 6 }])
1174 );
1175 assert_eq!(
1177 parse_range_header(b"bytEs=0-99,200-299,400-499", 320, None)
1178 .get_multirange_info()
1179 .expect("Should have multipart info for Multipart range request")
1180 .ranges,
1181 (vec![
1182 Range { start: 0, end: 100 },
1183 Range {
1184 start: 200,
1185 end: 300
1186 }
1187 ])
1188 );
1189 assert_eq!(
1191 parse_range_header(b"bytEs=0-99,200-299,400-499", 500, None)
1192 .get_multirange_info()
1193 .expect("Should have multipart info for Multipart range request")
1194 .ranges,
1195 vec![
1196 Range { start: 0, end: 100 },
1197 Range {
1198 start: 200,
1199 end: 300
1200 },
1201 Range {
1202 start: 400,
1203 end: 500
1204 },
1205 ]
1206 );
1207 assert_eq!(
1209 parse_range_header(b"bytes=0-,-2", 10, None),
1210 RangeType::None,
1211 );
1212 assert!(parse_range_header(b"bytes=0-,-2", 10, None)
1214 .get_multirange_info()
1215 .is_none());
1216 assert_eq!(
1218 parse_range_header(b"bytes=0-3,2-5", 10, None),
1219 RangeType::None,
1220 );
1221 assert!(parse_range_header(b"bytes=0-3,2-5", 10, None)
1222 .get_multirange_info()
1223 .is_none());
1224
1225 assert_eq!(
1227 parse_range_header(b"bytes=0-5,10-", 2, None),
1228 RangeType::new_single(0, 2)
1229 );
1230 assert!(parse_range_header(b"bytes=0-5,10-", 2, None)
1231 .get_multirange_info()
1232 .is_none());
1233
1234 assert_eq!(
1236 parse_range_header(b"bytes=0-5, 10-20, 30-18", 200, None)
1237 .get_multirange_info()
1238 .expect("Should have multipart info for Multipart range request")
1239 .ranges,
1240 vec![Range { start: 0, end: 6 }, Range { start: 10, end: 21 },]
1241 );
1242 assert_eq!(
1244 parse_range_header(b"bytes=5-0, 20-15, 30-25", 200, None),
1245 RangeType::Invalid
1246 );
1247
1248 fn generate_range_header(count: usize) -> Vec<u8> {
1250 let mut s = String::from("bytes=");
1251 for i in 0..count {
1252 let start = i * 4;
1253 let end = start + 1;
1254 if i > 0 {
1255 s.push(',');
1256 }
1257 s.push_str(&start.to_string());
1258 s.push('-');
1259 s.push_str(&end.to_string());
1260 }
1261 s.into_bytes()
1262 }
1263
1264 let ranges = generate_range_header(201);
1266 assert_eq!(
1267 parse_range_header(&ranges, 1000, Some(200)),
1268 RangeType::None
1269 )
1270 }
1271
    /// Description of a multipart (`multipart/byteranges`) range response.
    #[derive(Debug, Eq, PartialEq, Clone)]
    pub struct MultiRangeInfo {
        /// Requested byte ranges as half-open `start..end` intervals, ascending and
        /// non-overlapping when produced by `parse_range_header`.
        pub ranges: Vec<Range<usize>>,
        /// Multipart boundary string (randomly generated by `MultiRangeInfo::new`).
        pub boundary: String,
        /// Full length of the underlying representation, used in each part's Content-Range.
        total_length: usize,
        /// Content type of the underlying representation; contributes a per-part header
        /// line to the computed multipart length when set.
        content_type: Option<String>,
    }
1281
1282 impl MultiRangeInfo {
1283 pub fn new(ranges: Vec<Range<usize>>) -> Self {
1285 Self {
1286 ranges,
1287 boundary: Self::generate_boundary(),
1289 total_length: 0,
1290 content_type: None,
1291 }
1292 }
1293 pub fn set_content_type(&mut self, content_type: String) {
1294 self.content_type = Some(content_type)
1295 }
1296 pub fn set_total_length(&mut self, total_length: usize) {
1297 self.total_length = total_length;
1298 }
1299 fn generate_boundary() -> String {
1304 use rand::Rng;
1305 let mut rng: rand::prelude::ThreadRng = rand::thread_rng();
1306 format!("{:016x}", rng.gen::<u64>())
1307 }
1308 pub fn calculate_multipart_length(&self) -> usize {
1309 let mut total_length = 0;
1310 let content_type = self.content_type.as_ref();
1311 for range in self.ranges.clone() {
1312 total_length += 4 + self.boundary.len() + 2;
1319 total_length += content_type.map_or(0, |ct| 14 + ct.len() + 2);
1320 total_length += format!(
1321 "Content-Range: bytes {}-{}/{}",
1322 range.start,
1323 range.end - 1,
1324 self.total_length
1325 )
1326 .len()
1327 + 2;
1328 total_length += 2;
1329 total_length += range.end - range.start;
1330 }
1331 total_length += 4 + self.boundary.len() + 4;
1333 total_length
1334 }
1335 }
    /// Outcome of evaluating a request's `Range` header against a response.
    #[derive(Debug, Eq, PartialEq, Clone)]
    pub enum RangeType {
        /// No range handling: serve the full response.
        None,
        /// A single satisfiable byte range (`start..end`, end exclusive).
        Single(Range<usize>),
        /// Several satisfiable byte ranges, served as `multipart/byteranges`.
        Multi(MultiRangeInfo),
        /// A byte-range request with no satisfiable range (answered with 416).
        Invalid,
    }
1343
1344 impl RangeType {
1345 #[allow(dead_code)]
1347 fn new_single(start: usize, end: usize) -> Self {
1348 RangeType::Single(Range { start, end })
1349 }
1350 #[allow(dead_code)]
1351 pub fn new_multi(ranges: Vec<Range<usize>>) -> Self {
1352 RangeType::Multi(MultiRangeInfo::new(ranges))
1353 }
1354 #[allow(dead_code)]
1355 fn get_multirange_info(&self) -> Option<&MultiRangeInfo> {
1356 match self {
1357 RangeType::Multi(multi_range_info) => Some(multi_range_info),
1358 _ => None,
1359 }
1360 }
1361 #[allow(dead_code)]
1362 fn update_multirange_info(&mut self, content_length: usize, content_type: Option<String>) {
1363 if let RangeType::Multi(multipart_range_info) = self {
1364 multipart_range_info.content_type = content_type;
1365 multipart_range_info.set_total_length(content_length);
1366 }
1367 }
1368 }
1369
1370 pub fn range_header_filter(
1372 req: &RequestHeader,
1373 resp: &mut ResponseHeader,
1374 max_multipart_ranges: Option<usize>,
1375 ) -> RangeType {
1376 if resp.status != StatusCode::OK {
1380 return RangeType::None;
1381 }
1382
1383 let Some(content_length_bytes) = resp.headers.get(CONTENT_LENGTH) else {
1386 return RangeType::None;
1387 };
1388 let Some(content_length) = parse_number(content_length_bytes.as_bytes()) else {
1390 return RangeType::None;
1391 };
1392
1393 fn request_range_type(
1398 req: &RequestHeader,
1399 resp: &ResponseHeader,
1400 content_length: usize,
1401 max_multipart_ranges: Option<usize>,
1402 ) -> RangeType {
1403 if req.method != http::Method::GET && req.method != http::Method::HEAD {
1405 return RangeType::None;
1406 }
1407
1408 let Some(range_header) = req.headers.get(RANGE) else {
1409 return RangeType::None;
1410 };
1411
1412 if let Some(if_range) = req.headers.get(IF_RANGE) {
1420 let ir = if_range.as_bytes();
1421 let matches = if ir.len() >= 2 && ir.last() == Some(&b'"') {
1422 resp.headers.get(ETAG).is_some_and(|etag| etag == if_range)
1423 } else if let Some(last_modified) = resp.headers.get(LAST_MODIFIED) {
1424 last_modified == if_range
1425 } else {
1426 false
1427 };
1428 if !matches {
1429 return RangeType::None;
1430 }
1431 }
1432
1433 parse_range_header(
1434 range_header.as_bytes(),
1435 content_length,
1436 max_multipart_ranges,
1437 )
1438 }
1439
1440 let mut range_type = request_range_type(req, resp, content_length, max_multipart_ranges);
1441
1442 match &mut range_type {
1443 RangeType::None => {
1444 resp.insert_header(&ACCEPT_RANGES, "bytes").unwrap();
1447 }
1448 RangeType::Single(r) => {
1449 resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap();
1451 resp.remove_header(&ACCEPT_RANGES);
1452 resp.insert_header(&CONTENT_LENGTH, r.end - r.start)
1453 .unwrap();
1454 resp.insert_header(
1455 &CONTENT_RANGE,
1456 format!("bytes {}-{}/{content_length}", r.start, r.end - 1), )
1458 .unwrap()
1459 }
1460
1461 RangeType::Multi(multi_range_info) => {
1462 let content_type = resp
1463 .headers
1464 .get(CONTENT_TYPE)
1465 .and_then(|v| v.to_str().ok())
1466 .unwrap_or("application/octet-stream");
1467 multi_range_info.set_total_length(content_length);
1469 multi_range_info.set_content_type(content_type.to_string());
1470
1471 let total_length = multi_range_info.calculate_multipart_length();
1472
1473 resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap();
1474 resp.remove_header(&ACCEPT_RANGES);
1475 resp.insert_header(CONTENT_LENGTH, total_length).unwrap();
1476 resp.insert_header(
1477 CONTENT_TYPE,
1478 format!(
1479 "multipart/byteranges; boundary={}",
1480 multi_range_info.boundary
1481 ), )
1483 .unwrap();
1484 resp.remove_header(&CONTENT_RANGE);
1485 }
1486 RangeType::Invalid => {
1487 resp.set_status(StatusCode::RANGE_NOT_SATISFIABLE).unwrap();
1489 resp.insert_header(&CONTENT_LENGTH, HeaderValue::from_static("0"))
1491 .unwrap();
1492 resp.remove_header(&ACCEPT_RANGES);
1493 resp.remove_header(&CONTENT_TYPE);
1494 resp.remove_header(&CONTENT_ENCODING);
1495 resp.remove_header(&TRANSFER_ENCODING);
1496 resp.insert_header(&CONTENT_RANGE, format!("bytes */{content_length}"))
1497 .unwrap()
1498 }
1499 }
1500
1501 range_type
1502 }
1503
1504 #[test]
1505 fn test_range_filter_single() {
1506 fn gen_req() -> RequestHeader {
1507 RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap()
1508 }
1509 fn gen_resp() -> ResponseHeader {
1510 let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1511 resp.append_header("Content-Length", "10").unwrap();
1512 resp
1513 }
1514
1515 let req = gen_req();
1517 let mut resp = gen_resp();
1518 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1519 assert_eq!(resp.status.as_u16(), 200);
1520 assert_eq!(
1521 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1522 b"bytes"
1523 );
1524
1525 let mut req = gen_req();
1527 req.method = Method::HEAD;
1528 let mut resp = gen_resp();
1529 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1530 assert_eq!(resp.status.as_u16(), 200);
1531 assert_eq!(
1532 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1533 b"bytes"
1534 );
1535
1536 let mut req = gen_req();
1538 req.insert_header("Range", "bytes=0-1").unwrap();
1539 let mut resp = gen_resp();
1540 assert_eq!(
1541 RangeType::new_single(0, 2),
1542 range_header_filter(&req, &mut resp, None)
1543 );
1544 assert_eq!(resp.status.as_u16(), 206);
1545 assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"2");
1546 assert_eq!(
1547 resp.headers.get("content-range").unwrap().as_bytes(),
1548 b"bytes 0-1/10"
1549 );
1550 assert!(resp.headers.get("accept-ranges").is_none());
1551
1552 let mut req = gen_req();
1554 req.insert_header("Range", "bytes=0-1").unwrap();
1555 let mut resp = gen_resp();
1556 resp.insert_header("Accept-Ranges", "bytes").unwrap();
1557 assert_eq!(
1558 RangeType::new_single(0, 2),
1559 range_header_filter(&req, &mut resp, None)
1560 );
1561 assert_eq!(resp.status.as_u16(), 206);
1562 assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"2");
1563 assert_eq!(
1564 resp.headers.get("content-range").unwrap().as_bytes(),
1565 b"bytes 0-1/10"
1566 );
1567 assert!(resp.headers.get("accept-ranges").is_none());
1569
1570 let mut req = gen_req();
1572 req.insert_header("Range", "bytes=1-0").unwrap();
1573 let mut resp = gen_resp();
1574 resp.insert_header("Accept-Ranges", "bytes").unwrap();
1575 resp.insert_header("Content-Encoding", "gzip").unwrap();
1576 resp.insert_header("Transfer-Encoding", "chunked").unwrap();
1577 assert_eq!(
1578 RangeType::Invalid,
1579 range_header_filter(&req, &mut resp, None)
1580 );
1581 assert_eq!(resp.status.as_u16(), 416);
1582 assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"0");
1583 assert_eq!(
1584 resp.headers.get("content-range").unwrap().as_bytes(),
1585 b"bytes */10"
1586 );
1587 assert!(resp.headers.get("accept-ranges").is_none());
1588 assert!(resp.headers.get("content-encoding").is_none());
1589 assert!(resp.headers.get("transfer-encoding").is_none());
1590 }
1591
1592 #[test]
1594 fn test_range_filter_multipart() {
1595 fn gen_req() -> RequestHeader {
1596 let mut req: RequestHeader =
1597 RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1598 req.append_header("Range", "bytes=0-1,3-4,6-7").unwrap();
1599 req
1600 }
1601 fn gen_req_overlap_range() -> RequestHeader {
1602 let mut req: RequestHeader =
1603 RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1604 req.append_header("Range", "bytes=0-3,2-5,7-8").unwrap();
1605 req
1606 }
1607 fn gen_resp() -> ResponseHeader {
1608 let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1609 resp.append_header("Content-Length", "10").unwrap();
1610 resp
1611 }
1612
1613 let req = gen_req();
1615 let mut resp = gen_resp();
1616 let result = range_header_filter(&req, &mut resp, None);
1617 let mut boundary_str = String::new();
1618
1619 assert!(matches!(result, RangeType::Multi(_)));
1620 if let RangeType::Multi(multi_part_info) = result {
1621 assert_eq!(multi_part_info.ranges.len(), 3);
1622 assert_eq!(multi_part_info.ranges[0], Range { start: 0, end: 2 });
1623 assert_eq!(multi_part_info.ranges[1], Range { start: 3, end: 5 });
1624 assert_eq!(multi_part_info.ranges[2], Range { start: 6, end: 8 });
1625 assert!(multi_part_info.content_type.is_some());
1627 assert_eq!(multi_part_info.total_length, 10);
1628 assert!(!multi_part_info.boundary.is_empty());
1629 boundary_str = multi_part_info.boundary;
1630 }
1631 assert_eq!(resp.status.as_u16(), 206);
1632 assert_eq!(
1634 resp.headers.get("content-type").unwrap().to_str().unwrap(),
1635 format!("multipart/byteranges; boundary={boundary_str}")
1636 );
1637 assert!(resp.headers.get("content_length").is_none());
1638 assert!(resp.headers.get("accept-ranges").is_none());
1639
1640 let req = gen_req_overlap_range();
1642 let mut resp = gen_resp();
1643 let result = range_header_filter(&req, &mut resp, None);
1644
1645 assert!(matches!(result, RangeType::None));
1646 assert_eq!(resp.status.as_u16(), 200);
1647 assert!(resp.headers.get("content-type").is_none());
1648 assert_eq!(
1649 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1650 b"bytes"
1651 );
1652
1653 let mut req = gen_req();
1655 req.insert_header("Range", "bytes=1-0, 12-9, 50-40")
1656 .unwrap();
1657 let mut resp = gen_resp();
1658 resp.insert_header("Content-Encoding", "br").unwrap();
1659 resp.insert_header("Transfer-Encoding", "chunked").unwrap();
1660 let result = range_header_filter(&req, &mut resp, None);
1661 assert!(matches!(result, RangeType::Invalid));
1662 assert_eq!(resp.status.as_u16(), 416);
1663 assert!(resp.headers.get("accept-ranges").is_none());
1664 assert!(resp.headers.get("content-encoding").is_none());
1665 assert!(resp.headers.get("transfer-encoding").is_none());
1666 }
1667
1668 #[test]
1669 fn test_if_range() {
1670 const DATE: &str = "Fri, 07 Jul 2023 22:03:29 GMT";
1671 const ETAG: &str = "\"1234\"";
1672
1673 fn gen_req() -> RequestHeader {
1674 let mut req = RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1675 req.append_header("Range", "bytes=0-1").unwrap();
1676 req
1677 }
1678 fn get_multipart_req() -> RequestHeader {
1679 let mut req = RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1680 _ = req.append_header("Range", "bytes=0-1,3-4,6-7");
1681 req
1682 }
1683 fn gen_resp() -> ResponseHeader {
1684 let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1685 resp.append_header("Content-Length", "10").unwrap();
1686 resp.append_header("Last-Modified", DATE).unwrap();
1687 resp.append_header("ETag", ETAG).unwrap();
1688 resp
1689 }
1690
1691 let mut req = gen_req();
1693 req.insert_header("If-Range", DATE).unwrap();
1694 let mut resp = gen_resp();
1695 assert_eq!(
1696 RangeType::new_single(0, 2),
1697 range_header_filter(&req, &mut resp, None)
1698 );
1699
1700 let mut req = gen_req();
1702 req.insert_header("If-Range", "Fri, 07 Jul 2023 22:03:25 GMT")
1703 .unwrap();
1704 let mut resp = gen_resp();
1705 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1706 assert_eq!(resp.status.as_u16(), 200);
1707 assert_eq!(
1708 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1709 b"bytes"
1710 );
1711
1712 let mut req = gen_req();
1714 req.insert_header("If-Range", ETAG).unwrap();
1715 let mut resp = gen_resp();
1716 assert_eq!(
1717 RangeType::new_single(0, 2),
1718 range_header_filter(&req, &mut resp, None)
1719 );
1720 assert_eq!(resp.status.as_u16(), 206);
1721 assert!(resp.headers.get("accept-ranges").is_none());
1722
1723 let mut req = gen_req();
1725 req.insert_header("If-Range", "\"4567\"").unwrap();
1726 let mut resp = gen_resp();
1727 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1728 assert_eq!(resp.status.as_u16(), 200);
1729 assert_eq!(
1730 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1731 b"bytes"
1732 );
1733
1734 let mut req = gen_req();
1735 req.insert_header("If-Range", "1234").unwrap();
1736 let mut resp = gen_resp();
1737 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1738 assert_eq!(resp.status.as_u16(), 200);
1739 assert_eq!(
1740 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1741 b"bytes"
1742 );
1743
1744 let mut req = get_multipart_req();
1746 req.insert_header("If-Range", DATE).unwrap();
1747 let mut resp = gen_resp();
1748 let result = range_header_filter(&req, &mut resp, None);
1749 assert!(matches!(result, RangeType::Multi(_)));
1750 assert_eq!(resp.status.as_u16(), 206);
1751 assert!(resp.headers.get("accept-ranges").is_none());
1752
1753 let req = get_multipart_req();
1755 let mut resp = gen_resp();
1756 assert!(matches!(
1757 range_header_filter(&req, &mut resp, None),
1758 RangeType::Multi(_)
1759 ));
1760
1761 let mut req = get_multipart_req();
1763 req.insert_header("If-Range", "\"wrong\"").unwrap();
1764 let mut resp = gen_resp();
1765 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1766 assert_eq!(resp.status.as_u16(), 200);
1767 assert_eq!(
1768 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1769 b"bytes"
1770 );
1771 }
1772
    /// Response body filter that trims the body down to the byte range(s) described by a
    /// `RangeType` (such as the one returned by `range_header_filter`), wrapping the parts
    /// in multipart/byteranges framing when several ranges were requested.
    pub struct RangeBodyFilter {
        /// The range being served; `RangeType::None` passes the body through unchanged.
        pub range: RangeType,
        /// Absolute offset, within the full body, of the next chunk given to `filter_body`.
        current: usize,
        /// Index of the multipart part the body filter is currently emitting
        /// (`Some` only for multipart ranges).
        multipart_idx: Option<usize>,
        /// Index of the part the cache body reader last seeked to; `None` before the
        /// first seek.
        cache_multipart_idx: Option<usize>,
    }
1779
1780 impl Default for RangeBodyFilter {
1781 fn default() -> Self {
1782 Self::new()
1783 }
1784 }
1785
1786 impl RangeBodyFilter {
1787 pub fn new() -> Self {
1788 RangeBodyFilter {
1789 range: RangeType::None,
1790 current: 0,
1791 multipart_idx: None,
1792 cache_multipart_idx: None,
1793 }
1794 }
1795
1796 pub fn new_range(range: RangeType) -> Self {
1797 RangeBodyFilter {
1798 multipart_idx: matches!(range, RangeType::Multi(_)).then_some(0),
1799 range,
1800 ..Default::default()
1801 }
1802 }
1803
1804 pub fn is_multipart_range(&self) -> bool {
1805 matches!(self.range, RangeType::Multi(_))
1806 }
1807
1808 pub fn should_cache_seek_again(&self) -> bool {
1811 match &self.range {
1812 RangeType::Multi(multipart_info) => self
1813 .cache_multipart_idx
1814 .is_some_and(|idx| idx != multipart_info.ranges.len() - 1),
1815 _ => false,
1816 }
1817 }
1818
1819 pub fn next_cache_multipart_range(&mut self) -> Range<usize> {
1821 match &self.range {
1822 RangeType::Multi(multipart_info) => {
1823 match self.cache_multipart_idx.as_mut() {
1824 Some(v) => *v += 1,
1825 None => self.cache_multipart_idx = Some(0),
1826 }
1827 let cache_multipart_idx = self.cache_multipart_idx.expect("set above");
1828 let multipart_idx = self.multipart_idx.expect("must be set on multirange");
1829 assert_eq!(multipart_idx, cache_multipart_idx,
1832 "cache multipart idx should match multipart idx, or there is a hit handler bug");
1833 multipart_info.ranges[cache_multipart_idx].clone()
1834 }
1835 _ => panic!("tried to advance multipart idx on non-multipart range"),
1836 }
1837 }
1838
1839 pub fn set_current_cursor(&mut self, current: usize) {
1840 self.current = current;
1841 }
1842
1843 pub fn set(&mut self, range: RangeType) {
1844 self.multipart_idx = matches!(range, RangeType::Multi(_)).then_some(0);
1845 self.range = range;
1846 }
1847
1848 pub fn finalize(&self, boundary: &String) -> Option<Bytes> {
1850 if let RangeType::Multi(_) = self.range {
1851 Some(Bytes::from(format!("\r\n--{boundary}--\r\n")))
1852 } else {
1853 None
1854 }
1855 }
1856
1857 pub fn filter_body(&mut self, data: Option<Bytes>) -> Option<Bytes> {
1858 match &self.range {
1859 RangeType::None => data,
1860 RangeType::Invalid => None,
1861 RangeType::Single(r) => {
1862 let current = self.current;
1863 self.current += data.as_ref().map_or(0, |d| d.len());
1864 data.and_then(|d| Self::filter_range_data(r.start, r.end, current, d))
1865 }
1866
1867 RangeType::Multi(_) => {
1868 let data = data?;
1869 let current = self.current;
1870 let data_len = data.len();
1871 self.current += data_len;
1872 self.filter_multi_range_body(data, current, data_len)
1873 }
1874 }
1875 }
1876
1877 fn filter_range_data(
1878 start: usize,
1879 end: usize,
1880 current: usize,
1881 data: Bytes,
1882 ) -> Option<Bytes> {
1883 if current + data.len() < start || current >= end {
1884 None
1886 } else if current >= start && current + data.len() <= end {
1887 Some(data)
1889 } else {
1890 let slice_start = start.saturating_sub(current);
1893 let slice_end = std::cmp::min(data.len(), end - current);
1894 Some(data.slice(slice_start..slice_end))
1895 }
1896 }
1897
1898 fn build_multipart_header(
1900 &self,
1901 range: &Range<usize>,
1902 boundary: &str,
1903 total_length: &usize,
1904 content_type: Option<&str>,
1905 ) -> Bytes {
1906 Bytes::from(format!(
1907 "\r\n--{}\r\n{}Content-Range: bytes {}-{}/{}\r\n\r\n",
1908 boundary,
1909 content_type.map_or(String::new(), |ct| format!("Content-Type: {ct}\r\n")),
1910 range.start,
1911 range.end - 1,
1912 total_length
1913 ))
1914 }
1915
1916 fn current_chunk_includes_range_start(
1918 &self,
1919 range: &Range<usize>,
1920 current: usize,
1921 data_len: usize,
1922 ) -> bool {
1923 range.start >= current && range.start < current + data_len
1924 }
1925
1926 fn current_chunk_includes_range_end(
1928 &self,
1929 range: &Range<usize>,
1930 current: usize,
1931 data_len: usize,
1932 ) -> bool {
1933 range.end > current && range.end <= current + data_len
1934 }
1935
1936 fn filter_multi_range_body(
1937 &mut self,
1938 data: Bytes,
1939 current: usize,
1940 data_len: usize,
1941 ) -> Option<Bytes> {
1942 let mut result = BytesMut::new();
1943
1944 let RangeType::Multi(multi_part_info) = &self.range else {
1945 return None;
1946 };
1947
1948 let multipart_idx = self.multipart_idx.expect("must be set on multirange");
1949 let final_range = multi_part_info.ranges.last()?;
1950
1951 let (_, remaining_ranges) = multi_part_info.ranges.as_slice().split_at(multipart_idx);
1952 for range in remaining_ranges {
1955 if let Some(sliced) =
1956 Self::filter_range_data(range.start, range.end, current, data.clone())
1957 {
1958 if self.current_chunk_includes_range_start(range, current, data_len) {
1959 result.extend_from_slice(&self.build_multipart_header(
1960 range,
1961 multi_part_info.boundary.as_ref(),
1962 &multi_part_info.total_length,
1963 multi_part_info.content_type.as_deref(),
1964 ));
1965 }
1966 result.extend_from_slice(&sliced);
1968 if self.current_chunk_includes_range_end(range, current, data_len) {
1969 if range == final_range {
1971 if let Some(final_chunk) = self.finalize(&multi_part_info.boundary) {
1972 result.extend_from_slice(&final_chunk);
1973 }
1974 }
1975 self.multipart_idx = Some(self.multipart_idx.expect("must be set") + 1);
1977 }
1978 } else {
1979 break;
1983 }
1984 }
1985 if result.is_empty() {
1986 None
1987 } else {
1988 Some(result.freeze())
1989 }
1990 }
1991 }
1992
1993 #[test]
1994 fn test_range_body_filter_single() {
1995 let mut body_filter = RangeBodyFilter::new_range(RangeType::None);
1996 assert_eq!(body_filter.filter_body(Some("123".into())).unwrap(), "123");
1997
1998 let mut body_filter = RangeBodyFilter::new_range(RangeType::Invalid);
1999 assert!(body_filter.filter_body(Some("123".into())).is_none());
2000
2001 let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(0, 1));
2002 assert_eq!(body_filter.filter_body(Some("012".into())).unwrap(), "0");
2003 assert!(body_filter.filter_body(Some("345".into())).is_none());
2004
2005 let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(4, 6));
2006 assert!(body_filter.filter_body(Some("012".into())).is_none());
2007 assert_eq!(body_filter.filter_body(Some("345".into())).unwrap(), "45");
2008 assert!(body_filter.filter_body(Some("678".into())).is_none());
2009
2010 let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(1, 7));
2011 assert_eq!(body_filter.filter_body(Some("012".into())).unwrap(), "12");
2012 assert_eq!(body_filter.filter_body(Some("345".into())).unwrap(), "345");
2013 assert_eq!(body_filter.filter_body(Some("678".into())).unwrap(), "6");
2014 }
2015
2016 #[test]
2017 fn test_range_body_filter_multipart() {
2018 let data = Bytes::from("0123456789");
2020 let ranges = vec![0..3, 6..9];
2021 let content_length = data.len();
2022 let mut body_filter = RangeBodyFilter::new();
2023 body_filter.set(RangeType::new_multi(ranges.clone()));
2024
2025 body_filter
2026 .range
2027 .update_multirange_info(content_length, None);
2028
2029 let multi_range_info = body_filter
2030 .range
2031 .get_multirange_info()
2032 .cloned()
2033 .expect("Multipart Ranges should have MultiPartInfo struct");
2034
2035 let output = body_filter.filter_body(Some(data)).unwrap();
2037 let footer = body_filter.finalize(&multi_range_info.boundary).unwrap();
2038
2039 let output_str = str::from_utf8(&output).unwrap();
2041 let final_boundary = str::from_utf8(&footer).unwrap();
2042 let boundary = &multi_range_info.boundary;
2043
2044 for (i, range) in ranges.iter().enumerate() {
2046 let header = &format!(
2047 "--{}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2048 boundary,
2049 range.start,
2050 range.end - 1,
2051 content_length
2052 );
2053 assert!(
2054 output_str.contains(header),
2055 "Missing part header {} in multipart body",
2056 i
2057 );
2058 let expected_body = &"0123456789"[range.clone()];
2060 assert!(
2061 output_str.contains(expected_body),
2062 "Missing body {} for range {:?}",
2063 expected_body,
2064 range
2065 )
2066 }
2067 assert_eq!(final_boundary, format!("\r\n--{}--\r\n", boundary));
2069
2070 let full_body = b"0123456789";
2072 let ranges = vec![0..2, 4..6, 8..9];
2073 let content_length = full_body.len();
2074 let content_type = "text/plain".to_string();
2075 let mut body_filter = RangeBodyFilter::new();
2076 body_filter.set(RangeType::new_multi(ranges.clone()));
2077
2078 body_filter
2079 .range
2080 .update_multirange_info(content_length, Some(content_type.clone()));
2081
2082 let multi_range_info = body_filter
2083 .range
2084 .get_multirange_info()
2085 .cloned()
2086 .expect("Multipart Ranges should have MultiPartInfo struct");
2087
2088 let chunk1 = Bytes::from_static(b"012");
2090 let chunk2 = Bytes::from_static(b"345");
2091 let chunk3 = Bytes::from_static(b"678");
2092 let chunk4 = Bytes::from_static(b"9");
2093
2094 let mut collected_bytes = BytesMut::new();
2095 for chunk in [chunk1, chunk2, chunk3, chunk4] {
2096 if let Some(filtered) = body_filter.filter_body(Some(chunk)) {
2097 collected_bytes.extend_from_slice(&filtered);
2098 }
2099 }
2100 if let Some(final_boundary) = body_filter.finalize(&multi_range_info.boundary) {
2101 collected_bytes.extend_from_slice(&final_boundary);
2102 }
2103
2104 let output_str = str::from_utf8(&collected_bytes).unwrap();
2105 let boundary = multi_range_info.boundary;
2106
2107 for (i, range) in ranges.iter().enumerate() {
2108 let header = &format!(
2109 "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2110 boundary,
2111 content_type,
2112 range.start,
2113 range.end - 1,
2114 content_length
2115 );
2116 let expected_body = &full_body[range.clone()];
2117 let expected_output = format!("{}{}", header, str::from_utf8(expected_body).unwrap());
2118
2119 assert!(
2120 output_str.contains(&expected_output),
2121 "Missing or malformed part {} in multipart body. \n Expected: \n{}\n Got: \n{}",
2122 i,
2123 expected_output,
2124 output_str
2125 )
2126 }
2127
2128 assert!(
2129 output_str.ends_with(&format!("\r\n--{}--\r\n", boundary)),
2130 "Missing final boundary"
2131 );
2132
2133 let full_body = b"abcdefghijkl";
2135 let ranges = vec![2..7, 9..11];
2136 let content_length = full_body.len();
2137 let content_type = "application/octet-stream".to_string();
2138 let mut body_filter = RangeBodyFilter::new();
2139 body_filter.set(RangeType::new_multi(ranges.clone()));
2140
2141 body_filter
2142 .range
2143 .update_multirange_info(content_length, Some(content_type.clone()));
2144
2145 let multi_range_info = body_filter
2146 .range
2147 .clone()
2148 .get_multirange_info()
2149 .cloned()
2150 .expect("Multipart Ranges should have MultiPartInfo struct");
2151
2152 let chunk1 = Bytes::from_static(b"abc");
2154 let chunk2 = Bytes::from_static(b"def");
2155 let chunk3 = Bytes::from_static(b"ghi");
2156 let chunk4 = Bytes::from_static(b"jkl");
2157
2158 let mut collected_bytes = BytesMut::new();
2159 for chunk in [chunk1, chunk2, chunk3, chunk4] {
2160 if let Some(filtered) = body_filter.filter_body(Some(chunk)) {
2161 collected_bytes.extend_from_slice(&filtered);
2162 }
2163 }
2164 if let Some(final_boundary) = body_filter.finalize(&multi_range_info.boundary) {
2165 collected_bytes.extend_from_slice(&final_boundary);
2166 }
2167
2168 let output_str = str::from_utf8(&collected_bytes).unwrap();
2169 let boundary = &multi_range_info.boundary;
2170
2171 let header1 = &format!(
2172 "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2173 boundary,
2174 content_type,
2175 ranges[0].start,
2176 ranges[0].end - 1,
2177 content_length
2178 );
2179 let header2 = &format!(
2180 "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2181 boundary,
2182 content_type,
2183 ranges[1].start,
2184 ranges[1].end - 1,
2185 content_length
2186 );
2187
2188 assert!(output_str.contains(header1));
2189 assert!(output_str.contains(header2));
2190
2191 let expected_body_slices = ["cdefg", "jk"];
2192
2193 assert!(
2194 output_str.contains(expected_body_slices[0]),
2195 "Missing expected sliced body {}",
2196 expected_body_slices[0]
2197 );
2198
2199 assert!(
2200 output_str.contains(expected_body_slices[1]),
2201 "Missing expected sliced body {}",
2202 expected_body_slices[1]
2203 );
2204
2205 assert!(
2206 output_str.ends_with(&format!("\r\n--{}--\r\n", boundary)),
2207 "Missing final boundary"
2208 );
2209 }
2210}
2211
/// State machine for serving a response out of the cache, either from a cache hit
/// (hit handler) or from the cache miss body reader; advanced by `next_http_task`.
#[derive(Debug)]
pub(crate) enum ServeFromCache {
    /// Not serving from cache.
    Off,
    /// Cache hit: the header is produced next, then the body (`CacheBody(true)`).
    CacheHeader,
    /// Cache hit, header only: the header is produced next, then `Done`.
    CacheHeaderOnly,
    /// Miss, header only: the header is produced next, then `DoneMiss`.
    CacheHeaderOnlyMiss,
    /// Cache hit body; the flag says whether the hit handler should try to seek to the
    /// requested range before the next read.
    CacheBody(bool),
    /// Miss: the header is produced next, then the body from the miss body reader
    /// (`CacheBodyMiss(true)`).
    CacheHeaderMiss,
    /// Miss body read from the miss body reader; the flag says whether that reader
    /// should try to seek to the requested range before the next read.
    CacheBodyMiss(bool),
    /// Cache hit fully served.
    Done,
    /// Miss fully served.
    DoneMiss,
}
2237
2238impl ServeFromCache {
2239 pub fn new() -> Self {
2240 Self::Off
2241 }
2242
2243 pub fn is_on(&self) -> bool {
2244 !matches!(self, Self::Off)
2245 }
2246
2247 pub fn is_miss(&self) -> bool {
2248 matches!(
2249 self,
2250 Self::CacheHeaderMiss
2251 | Self::CacheHeaderOnlyMiss
2252 | Self::CacheBodyMiss(_)
2253 | Self::DoneMiss
2254 )
2255 }
2256
2257 pub fn is_miss_header(&self) -> bool {
2258 matches!(self, Self::CacheHeaderMiss)
2261 }
2262
2263 pub fn is_miss_body(&self) -> bool {
2264 matches!(self, Self::CacheBodyMiss(_))
2265 }
2266
2267 pub fn should_discard_upstream(&self) -> bool {
2268 self.is_on() && !self.is_miss()
2269 }
2270
2271 pub fn should_send_to_downstream(&self) -> bool {
2272 !self.is_on()
2273 }
2274
2275 pub fn enable(&mut self) {
2276 *self = Self::CacheHeader;
2277 }
2278
2279 pub fn enable_miss(&mut self) {
2280 if !self.is_on() {
2281 *self = Self::CacheHeaderMiss;
2282 }
2283 }
2284
2285 pub fn enable_header_only(&mut self) {
2286 match self {
2287 Self::CacheBody(_) => *self = Self::Done, Self::CacheBodyMiss(_) => *self = Self::DoneMiss,
2289 _ => {
2290 if self.is_miss() {
2291 *self = Self::CacheHeaderOnlyMiss;
2292 } else {
2293 *self = Self::CacheHeaderOnly;
2294 }
2295 }
2296 }
2297 }
2298
2299 pub async fn next_http_task(
2301 &mut self,
2302 cache: &mut HttpCache,
2303 range: &mut RangeBodyFilter,
2304 upgraded: bool,
2305 ) -> Result<HttpTask> {
2306 fn body_task(data: Bytes, upgraded: bool) -> HttpTask {
2307 if upgraded {
2308 HttpTask::UpgradedBody(Some(data), false)
2309 } else {
2310 HttpTask::Body(Some(data), false)
2311 }
2312 }
2313
2314 if !cache.enabled() {
2315 return Error::e_explain(InternalError, "Cache disabled");
2319 }
2320 match self {
2321 Self::Off => panic!("ProxyUseCache not enabled"),
2322 Self::CacheHeader => {
2323 *self = Self::CacheBody(true);
2324 Ok(HttpTask::Header(cache_hit_header(cache), false)) }
2326 Self::CacheHeaderMiss => {
2327 *self = Self::CacheBodyMiss(true);
2328 Ok(HttpTask::Header(cache_hit_header(cache), false)) }
2330 Self::CacheHeaderOnly => {
2331 *self = Self::Done;
2332 Ok(HttpTask::Header(cache_hit_header(cache), true))
2333 }
2334 Self::CacheHeaderOnlyMiss => {
2335 *self = Self::DoneMiss;
2336 Ok(HttpTask::Header(cache_hit_header(cache), true))
2337 }
2338 Self::CacheBody(should_seek) => {
2339 log::trace!("cache body should seek: {should_seek}");
2340 if *should_seek {
2341 self.maybe_seek_hit_handler(cache, range)?;
2342 }
2343 loop {
2344 if let Some(b) = cache.hit_handler().read_body().await? {
2345 return Ok(body_task(b, upgraded));
2346 }
2347 if range.should_cache_seek_again() {
2350 self.maybe_seek_hit_handler(cache, range)?;
2351 } else {
2352 *self = Self::Done;
2353 return Ok(HttpTask::Done);
2354 }
2355 }
2356 }
2357 Self::CacheBodyMiss(should_seek) => {
2358 if *should_seek {
2359 self.maybe_seek_miss_handler(cache, range)?;
2360 }
2361 loop {
2363 if let Some(b) = cache.miss_body_reader().unwrap().read_body().await? {
2364 return Ok(body_task(b, upgraded));
2365 } else {
2366 if range.should_cache_seek_again() {
2369 self.maybe_seek_miss_handler(cache, range)?;
2370 } else {
2371 *self = Self::DoneMiss;
2372 return Ok(HttpTask::Done);
2373 }
2374 }
2375 }
2376 }
2377 Self::Done => Ok(HttpTask::Done),
2378 Self::DoneMiss => Ok(HttpTask::Done),
2379 }
2380 }
2381
2382 fn maybe_seek_miss_handler(
2383 &mut self,
2384 cache: &mut HttpCache,
2385 range_filter: &mut RangeBodyFilter,
2386 ) -> Result<()> {
2387 match &range_filter.range {
2388 RangeType::Single(range) => {
2389 if cache.miss_body_reader().unwrap().can_seek() {
2391 cache
2392 .miss_body_reader()
2393 .unwrap()
2395 .seek(range.start, Some(range.end))
2396 .or_err(InternalError, "cannot seek miss handler")?;
2397 range_filter.range = RangeType::None;
2400 }
2401 }
2402 RangeType::Multi(_info) => {
2403 if cache.miss_body_reader().unwrap().can_seek_multipart() {
2405 let range = range_filter.next_cache_multipart_range();
2406 cache
2407 .miss_body_reader()
2408 .unwrap()
2409 .seek_multipart(range.start, Some(range.end))
2410 .or_err(InternalError, "cannot seek hit handler for multirange")?;
2411 range_filter.set_current_cursor(range.start);
2414 }
2415 }
2416 _ => {}
2417 }
2418
2419 *self = Self::CacheBodyMiss(false);
2420 Ok(())
2421 }
2422
2423 fn maybe_seek_hit_handler(
2424 &mut self,
2425 cache: &mut HttpCache,
2426 range_filter: &mut RangeBodyFilter,
2427 ) -> Result<()> {
2428 match &range_filter.range {
2429 RangeType::Single(range) => {
2430 if cache.hit_handler().can_seek() {
2431 cache
2432 .hit_handler()
2433 .seek(range.start, Some(range.end))
2434 .or_err(InternalError, "cannot seek hit handler")?;
2435 range_filter.range = RangeType::None;
2438 }
2439 }
2440 RangeType::Multi(_info) => {
2441 if cache.hit_handler().can_seek_multipart() {
2442 let range = range_filter.next_cache_multipart_range();
2443 cache
2444 .hit_handler()
2445 .seek_multipart(range.start, Some(range.end))
2446 .or_err(InternalError, "cannot seek hit handler for multirange")?;
2447 range_filter.set_current_cursor(range.start);
2450 }
2451 }
2452 _ => {}
2453 }
2454 *self = Self::CacheBody(false);
2455 Ok(())
2456 }
2457}