1use super::*;
16use http::header::{CONTENT_LENGTH, CONTENT_TYPE};
17use http::{Method, StatusCode};
18use pingora_cache::key::CacheHashKey;
19use pingora_cache::lock::LockStatus;
20use pingora_cache::max_file_size::ERR_RESPONSE_TOO_LARGE;
21use pingora_cache::{ForcedFreshness, HitHandler, HitStatus, RespCacheable::*};
22use pingora_core::protocols::http::conditional_filter::to_304;
23use pingora_core::protocols::http::v1::common::header_value_content_length;
24use pingora_core::ErrorType;
25use range_filter::RangeBodyFilter;
26use std::time::SystemTime;
27
28impl<SV, C> HttpProxy<SV, C>
29where
30 C: custom::Connector,
31{
/// Request-phase cache handling: runs the cache filters, looks the asset up and, when a
/// usable hit is found, serves it via `proxy_cache_hit`.
///
/// Returns `Some(..)` when this function already produced the outcome (the result of
/// `proxy_purge`, or `(reuse, err)` from `proxy_cache_hit`). Returns `None` when nothing
/// was served from cache here: cache disabled or bypassed, miss, cache-lock give-up or
/// timeout, lookup error, or a non-fresh asset that will not be served stale.
32 pub(crate) async fn proxy_cache(
34 self: &Arc<Self>,
35 session: &mut Session,
36 ctx: &mut SV::CTX,
37 ) -> Option<(bool, Option<Box<Error>>)>
38 where
40 SV: ProxyHttp + Send + Sync + 'static,
41 SV::CTX: Send + Sync,
42 {
// A failing request_cache_filter is only logged; processing continues.
43 if let Err(e) = self.inner.request_cache_filter(session, ctx) {
45 warn!(
47 "Fail to request_cache_filter: {e}, {}",
48 self.inner.request_summary(session, ctx)
49 );
50 }
51
// Compute the cache key only if the cache is enabled at this point.
52 if session.cache.enabled() {
54 match self.inner.cache_key_callback(session, ctx) {
55 Ok(key) => {
56 session.cache.set_cache_key(key);
57 }
58 Err(e) => {
// No key could be produced: turn the cache off for this request and log.
59 session.cache.disable(NoCacheReason::StorageError);
61 warn!(
62 "Fail to cache_key_callback: {e}, {}",
63 self.inner.request_summary(session, ctx)
64 );
65 }
66 }
67 }
68
// Purge requests short-circuit everything below.
69 if self.inner.is_purge(session, ctx) {
71 return self.proxy_purge(session, ctx).await;
72 }
73
// Bypass the lookup when the response is predicted to be uncacheable.
74 if session.cache.enabled() && !session.cache.cacheable_prediction() {
76 session.cache.bypass();
77 }
78
79 if !session.cache.enabled() {
80 return None;
81 }
82
// Lookup loop. `continue` repeats the lookup (after a vary lookup asked for another
// slot, or after waiting on a cache lock); `break <value>` is the function's result.
83 loop {
85 match session.cache.cache_lookup().await {
87 Ok(res) => {
// Stays `None` when the lookup found nothing.
88 let mut hit_status_opt = None;
89 if let Some((mut meta, mut handler)) = res {
90 let cache_key = session.cache.cache_key();
// The key already carries a variance: the asset found must carry the same one,
// otherwise the cache is disabled for this request.
95 if let Some(variance) = cache_key.variance_bin() {
96 if Some(variance) != meta.variance() {
99 warn!("Cache variance mismatch, {variance:?}, {cache_key:?}");
100 session.cache.disable(NoCacheReason::InternalError);
101 break None;
102 }
103 } else {
// No variance on the key yet: ask the vary filter whether this asset varies.
104 let req_header = session.req_header();
106 let variance = self.inner.cache_vary_filter(&meta, ctx, req_header);
107 if let Some(variance) = variance {
// A `false` result restarts the lookup.
// NOTE(review): presumably cache_vary_lookup updated the key to the variant slot — confirm.
108 if !session.cache.cache_vary_lookup(variance, &meta) {
110 continue;
113 }
114 } }
116
// Classify the hit; the user filter may override the freshness computed from meta.
117 let is_fresh = meta.is_fresh(SystemTime::now());
122 let hit_status = match self
124 .inner
125 .cache_hit_filter(session, &meta, &mut handler, is_fresh, ctx)
126 .await
127 {
128 Err(e) => {
129 error!(
130 "Failed to filter cache hit: {e}, {}",
131 self.inner.request_summary(session, ctx)
132 );
133 HitStatus::FailedHitFilter
135 }
136 Ok(None) => {
137 if is_fresh {
138 HitStatus::Fresh
139 } else {
140 HitStatus::Expired
141 }
142 }
143 Ok(Some(ForcedFreshness::ForceExpired)) => {
// A force-expired asset also loses its serve-stale eligibility.
144 meta.disable_serve_stale();
147 HitStatus::ForceExpired
148 }
149 Ok(Some(ForcedFreshness::ForceMiss)) => HitStatus::ForceMiss,
150 Ok(Some(ForcedFreshness::ForceFresh)) => HitStatus::Fresh,
151 };
152
153 hit_status_opt = Some(hit_status);
154
// Hand meta, handler and status over to the session cache state.
155 session.cache.cache_found(meta, handler, hit_status);
157 }
158
// Nothing found, or the status counts as a miss.
159 if hit_status_opt.is_none_or(HitStatus::is_treated_as_miss) {
160 if session.cache.is_cache_locked() {
// Wait on the cache lock, then retry the lookup or give up based on the lock status.
162 let lock_status = session.cache.cache_lock_wait().await;
164 if self.handle_lock_status(session, ctx, lock_status) {
165 continue;
166 } else {
167 break None;
168 }
169 } else {
170 self.inner.cache_miss(session, ctx);
171 break None;
172 }
173 }
174
// Cannot panic: the `None` case was handled by the branch above.
175 let hit_status = hit_status_opt.expect("None case handled as miss");
178
// Found but not fresh: decide between serving stale, waiting, or returning `None`.
179 if !hit_status.is_fresh() {
180 if session.cache.is_cache_locked() {
// If the subrequest context hands over a write lock, adopt it, tag this request
// as a subrequest and stop here without serving from cache.
182 if let Some(write_lock) = session
184 .subrequest_ctx
185 .as_mut()
186 .and_then(|ctx| ctx.take_write_lock())
187 {
188 session.cache.set_write_lock(write_lock);
190 session.cache.tag_as_subrequest();
191 break None;
193 }
194 let will_serve_stale = session.cache.can_serve_stale_updating()
195 && self.inner.should_serve_stale(session, ctx, None);
196 if !will_serve_stale {
197 let lock_status = session.cache.cache_lock_wait().await;
198 if self.handle_lock_status(session, ctx, lock_status) {
199 continue;
200 } else {
201 break None;
202 }
203 }
// Falls through to serve the stale asset below.
204 session.cache.set_stale_updating();
206 } else if session.cache.is_cache_lock_writer() {
// This request holds the write lock.
207 let will_serve_stale = session.cache.can_serve_stale_updating()
209 && self.inner.should_serve_stale(session, ctx, None);
210 if will_serve_stale {
// Give the write lock to a background subrequest and serve stale here.
211 let (permit, cache_lock) = session.cache.take_write_lock();
215 SubrequestSpawner::new(self.clone()).spawn_background_subrequest(
216 session.as_ref(),
217 subrequest::Ctx::builder()
218 .cache_write_lock(
219 cache_lock,
220 session.cache.cache_key().clone(),
221 permit,
222 )
223 .build(),
224 );
225 session.cache.set_stale_updating();
227 } else {
228 break None;
230 }
231 } else {
232 break None;
234 }
235 }
236
// Fresh hit, or stale asset being served while it is updated.
237 let (reuse, err) = self.proxy_cache_hit(session, ctx).await;
238 if let Some(e) = err.as_ref() {
239 error!(
240 "Fail to serve cache: {e}, {}",
241 self.inner.request_summary(session, ctx)
242 );
243 }
244 break Some((reuse, err));
246 }
247 Err(e) => {
// A lookup error is handled like a miss: cache_miss is still invoked, then log.
248 self.inner.cache_miss(session, ctx);
253 warn!(
254 "Fail to cache lookup: {e}, {}",
255 self.inner.request_summary(session, ctx)
256 );
257 break None;
258 }
259 }
260 }
261 }
262
/// Serves the cached asset currently held by `session.cache` to the downstream client.
///
/// Returns `(reuse, error)`. Going by the return sites below:
/// - `(true, None)`: header and body were written and `finish_body` succeeded;
/// - `(true, Some(e))`: a response filter or downstream-module header filter failed
///   before anything was written, and a 500 was attempted instead;
/// - `(false, Some(e))`: writing to downstream, seeking/reading the cached body, a body
///   filter, or `finish_body` failed.
263 pub(crate) async fn proxy_cache_hit(
265 &self,
266 session: &mut Session,
267 ctx: &mut SV::CTX,
268 ) -> (bool, Option<Box<Error>>)
269 where
270 SV: ProxyHttp + Send + Sync,
271 SV::CTX: Send + Sync,
272 {
273 use range_filter::*;
274
275 let seekable = session.cache.hit_handler().can_seek();
276 let mut header = cache_hit_header(&session.cache);
277
278 let req = session.req_header();
279
// An error from the not-modified filter is logged and treated as "modified".
280 let not_modified = match self.inner.cache_not_modified_filter(session, &header, ctx) {
281 Ok(not_modified) => not_modified,
282 Err(e) => {
283 warn!(
286 "Failed to run cache not modified filter: {e}, {}",
287 self.inner.request_summary(session, ctx)
288 );
289 false
290 }
291 };
292 if not_modified {
293 to_304(&mut header);
294 }
// No body is sent for a 304 or for a HEAD request.
295 let header_only = not_modified || req.method == http::method::Method::HEAD;
296
// Range handling runs only when the hit handler can seek and ranging is not disabled.
297 let range_type = if seekable && !session.ignore_downstream_range {
299 self.inner.range_header_filter(session, &mut header, ctx)
300 } else {
301 RangeType::None
302 };
303
// An invalid range is also answered with headers only.
304 let header_only = header_only || matches!(range_type, RangeType::Invalid);
306 debug!("header: {header:?}");
307
308 match self.inner.response_filter(session, &mut header, ctx).await {
310 Ok(_) => {
311 if let Err(e) = session
312 .downstream_modules_ctx
313 .response_header_filter(&mut header, header_only)
314 .await
315 {
316 error!(
317 "Failed to run downstream modules response header filter in hit: {e}, {}",
318 self.inner.request_summary(session, ctx)
319 );
320 session
321 .as_mut()
322 .respond_error(500)
323 .await
324 .unwrap_or_else(|e| {
325 error!("failed to send error response to downstream: {e}");
326 });
// The cached header was not written yet, so `reuse` is reported as true.
327 return (true, Some(e));
329 }
330
331 if let Err(e) = session
332 .as_mut()
333 .write_response_header(header)
334 .await
335 .map_err(|e| e.into_down())
336 {
// Writing to downstream failed: report the connection as not reusable.
337 return (false, Some(e));
339 }
340 }
341 Err(e) => {
342 error!(
343 "Failed to run response filter in hit: {e}, {}",
344 self.inner.request_summary(session, ctx)
345 );
346 session
347 .as_mut()
348 .respond_error(500)
349 .await
350 .unwrap_or_else(|e| {
351 error!("failed to send error response to downstream: {e}");
352 });
353 return (true, Some(e));
355 }
356 }
357 debug!("finished sending cached header to downstream");
358
// Positions the hit handler on the filter's next multipart range.
// Ok(false): nothing was done (not a multipart range, or the handler cannot seek
// multipart). Ok(true): the handler was seeked and the filter's cursor was set to the
// range start. Err: the seek itself failed.
359 fn seek_multipart(
364 hit_handler: &mut HitHandler,
365 range_filter: &mut RangeBodyFilter,
366 ) -> Result<bool> {
367 if !range_filter.is_multipart_range() || !hit_handler.can_seek_multipart() {
368 return Ok(false);
369 }
370 let r = range_filter.next_cache_multipart_range();
371 hit_handler.seek_multipart(r.start, Some(r.end))?;
372 range_filter.set_current_cursor(r.start);
375 Ok(true)
376 }
377
378 if !header_only {
// Decide how the range is applied: by seeking the handler (no filter needed), or by
// passing the body through a RangeBodyFilter.
379 let mut maybe_range_filter = match &range_type {
380 RangeType::Single(r) => {
381 if session.cache.hit_handler().can_seek() {
382 if let Err(e) = session.cache.hit_handler().seek(r.start, Some(r.end)) {
383 return (false, Some(e));
384 }
385 None
386 } else {
387 Some(RangeBodyFilter::new_range(range_type.clone()))
388 }
389 }
390 RangeType::Multi(_) => {
// The filter is kept even when the handler seeks.
391 let mut range_filter = RangeBodyFilter::new_range(range_type.clone());
392 if let Err(e) = seek_multipart(session.cache.hit_handler(), &mut range_filter) {
393 return (false, Some(e));
394 }
395 Some(range_filter)
396 }
// Invalid forces header_only above, so this arm cannot be reached here.
397 RangeType::Invalid => unreachable!(),
398 RangeType::None => None,
399 };
// Body loop: read from the hit handler, filter, write downstream, until the end.
400 loop {
401 match session.cache.hit_handler().read_body().await {
402 Ok(raw_body) => {
// `None` from read_body marks the end of the current read.
403 let end = raw_body.is_none();
404
405 if end {
406 if let Some(range_filter) = maybe_range_filter.as_mut() {
// The filter wants another seek: seek to the next part and read again, or fail.
407 if range_filter.should_cache_seek_again() {
408 let e = match seek_multipart(
409 session.cache.hit_handler(),
410 range_filter,
411 ) {
412 Ok(true) => {
413 continue;
415 }
416 Ok(false) => {
417 Error::explain(
424 InternalError,
425 "hit handler cannot seek for multipart again",
426 )
427 }
429 Err(e) => e,
430 };
431 return (false, Some(e));
432 }
433 }
434 }
435
436 let mut body = if let Some(range_filter) = maybe_range_filter.as_mut() {
437 range_filter.filter_body(raw_body)
438 } else {
439 raw_body
440 };
441
// The user body filter may request a delay before the chunk is sent.
442 match self
443 .inner
444 .response_body_filter(session, &mut body, end, ctx)
445 {
446 Ok(Some(duration)) => {
447 trace!("delaying response for {duration:?}");
448 time::sleep(duration).await;
449 }
450 Ok(None) => { }
451 Err(e) => {
452 return (false, Some(e));
454 }
455 }
456
457 if let Err(e) = session
458 .downstream_modules_ctx
459 .response_body_filter(&mut body, end)
460 {
461 return (false, Some(e));
463 }
464
// Skip empty or absent chunks while more data is expected; nothing is written for them.
465 if !end && body.as_ref().is_none_or(|b| b.is_empty()) {
466 continue;
469 }
470
471 let b = body.unwrap_or_default();
473 if let Err(e) = session
474 .as_mut()
475 .write_response_body(b, end)
476 .await
477 .map_err(|e| e.into_down())
478 {
479 return (false, Some(e));
480 }
481 if end {
482 break;
483 }
484 }
485 Err(e) => return (false, Some(e)),
486 }
487 }
488 }
489
// A failure to finish the hit handler is only logged.
490 if let Err(e) = session.cache.finish_hit_handler().await {
491 warn!("Error during finish_hit_handler: {}", e);
492 }
493
494 match session.as_mut().finish_body().await {
495 Ok(_) => {
496 debug!("finished sending cached body to downstream");
497 (true, None)
498 }
499 Err(e) => (false, Some(e)),
500 }
501 }
502
503 pub(crate) fn downstream_response_conditional_filter(
506 &self,
507 use_cache: &mut ServeFromCache,
508 session: &Session,
509 resp: &mut ResponseHeader,
510 ctx: &mut SV::CTX,
511 ) where
512 SV: ProxyHttp,
513 {
514 let req = session.req_header();
516
517 let not_modified = match self.inner.cache_not_modified_filter(session, resp, ctx) {
518 Ok(not_modified) => not_modified,
519 Err(e) => {
520 warn!(
523 "Failed to run cache not modified filter: {e}, {}",
524 self.inner.request_summary(session, ctx)
525 );
526 false
527 }
528 };
529
530 if not_modified {
531 to_304(resp);
532 }
533 let header_only = not_modified || req.method == http::method::Method::HEAD;
534 if header_only && use_cache.is_on() {
535 use_cache.enable_header_only();
538 }
539 }
540
/// Feeds one upstream `HttpTask` into the cache fill path.
///
/// Does nothing unless the cache is enabled or bypassing. For a header task it decides
/// cacheability via `response_cache_filter` and, if the asset is to be stored, installs
/// the cache meta and miss handler. For body tasks it writes the data to the miss
/// handler, and it finishes the miss handler at end of stream / `Done`.
///
/// # Errors
/// Propagates errors from `response_cache_filter` and from the miss handler, and returns
/// `ERR_RESPONSE_TOO_LARGE` when a body chunk would exceed the configured max file size.
541 pub(crate) async fn cache_http_task(
544 &self,
545 session: &mut Session,
546 task: &HttpTask,
547 ctx: &mut SV::CTX,
548 serve_from_cache: &mut ServeFromCache,
549 ) -> Result<()>
550 where
551 SV: ProxyHttp + Send + Sync,
552 SV::CTX: Send + Sync,
553 {
554 if !session.cache.enabled() && !session.cache.bypassing() {
555 return Ok(());
556 }
557
558 match task {
559 HttpTask::Header(header, end_stream) => {
// Skip 1xx headers, except 101 Switching Protocols which is still processed.
560 if header.status.is_informational()
564 && header.status != StatusCode::SWITCHING_PROTOCOLS
565 {
566 return Ok(());
567 }
568 match self.inner.response_cache_filter(session, header, ctx)? {
569 Cacheable(meta) => {
570 let mut fill_cache = true;
571 if session.cache.bypassing() {
// While bypassing: with a max file size configured and no Content-Length on the
// response, the cache is disabled for this request rather than re-enabled.
572 if session.cache.max_file_size_bytes().is_some()
579 && !meta.headers().contains_key(header::CONTENT_LENGTH)
580 {
581 session
582 .cache
583 .disable(NoCacheReason::PredictedResponseTooLarge);
584 return Ok(());
585 }
586
587 session.cache.response_became_cacheable();
588
// Only a GET answered with 200 may be stored on this bypassed request.
589 if session.req_header().method == Method::GET
590 && meta.response_header().status == StatusCode::OK
591 {
592 self.inner.cache_miss(session, ctx);
// cache_miss may have left the cache disabled; then nothing is stored.
593 if !session.cache.enabled() {
594 fill_cache = false;
595 }
596 } else {
597 fill_cache = false;
603 session.cache.disable(NoCacheReason::Deferred);
604 }
605 }
606
// With a known Content-Length and a configured limit, reject oversized responses up front.
607 if session.cache.enabled() {
610 if let Some(max_file_size) = session.cache.max_file_size_bytes() {
611 let content_length_hdr = meta.headers().get(header::CONTENT_LENGTH);
612 if let Some(content_length) =
613 header_value_content_length(content_length_hdr)
614 {
615 if content_length > max_file_size {
616 fill_cache = false;
617 session.cache.response_became_uncacheable(
618 NoCacheReason::ResponseTooLarge,
619 );
620 session.cache.disable(NoCacheReason::ResponseTooLarge);
// Range handling is switched off for this too-large response.
621 session.ignore_downstream_range = true;
623 }
624 }
625 }
629 }
630 if fill_cache {
631 let req_header = session.req_header();
// Compute the variance with the same vary filter used in proxy_cache's lookup.
632 let variance = self.inner.cache_vary_filter(&meta, ctx, req_header);
637 session.cache.set_cache_meta(meta);
638 session.cache.update_variance(variance);
639 session.cache.set_miss_handler().await?;
// If a miss body reader is available, the response is served from the cache being filled.
641 if session.cache.miss_body_reader().is_some() {
642 serve_from_cache.enable_miss();
643 }
// Header-only response: write an empty final body and finish right away.
644 if *end_stream {
645 session
646 .cache
647 .miss_handler()
648 .unwrap() .write_body(Bytes::new(), true)
650 .await?;
651 session.cache.finish_miss_handler().await?;
652 }
653 }
654 }
655 Uncacheable(reason) => {
// Record the uncacheable outcome only when not already bypassing; always disable.
656 if !session.cache.bypassing() {
657 session.cache.response_became_uncacheable(reason);
659 }
660 session.cache.disable(reason);
661 }
662 }
663 }
664 HttpTask::Body(data, end_stream) => match data {
665 Some(d) => {
666 if session.cache.enabled() {
// Enforce the max file size while the body streams in.
667 let body_size_allowed =
670 session.cache.track_body_bytes_for_max_file_size(d.len());
671 if !body_size_allowed {
672 debug!("chunked response exceeded max cache size, remembering that it is uncacheable");
673 session
674 .cache
675 .response_became_uncacheable(NoCacheReason::ResponseTooLarge);
676
677 return Error::e_explain(
678 ERR_RESPONSE_TOO_LARGE,
679 format!(
680 "writing data of size {} bytes would exceed max file size of {} bytes",
681 d.len(),
682 session.cache.max_file_size_bytes().expect("max file size bytes must be set to exceed size")
683 ),
684 );
685 }
686
// Panics if no miss handler is present.
// NOTE(review): presumably one is always set while the cache is enabled here — confirm.
687 let miss_handler = session.cache.miss_handler().unwrap();
690
691 miss_handler.write_body(d.clone(), *end_stream).await?;
692 if *end_stream {
693 session.cache.finish_miss_handler().await?;
694 }
695 }
696 }
697 None => {
// No data; only the end-of-stream marker matters.
698 if session.cache.enabled() && *end_stream {
699 session.cache.finish_miss_handler().await?;
700 }
701 }
702 },
// Trailers are ignored; Done finishes the miss handler.
703 HttpTask::Trailer(_) => {} HttpTask::Done => {
705 if session.cache.enabled() {
706 session.cache.finish_miss_handler().await?;
707 }
708 }
// Failed tasks are currently not handled here.
709 HttpTask::Failed(_) => {
710 }
712 }
713 Ok(())
714 }
715
/// Inspects an upstream response task during revalidation and decides whether the cached
/// asset should be served instead of the upstream response.
///
/// Returns `true` when a 304 arrived for an asset that has cache meta, or when a 5xx
/// arrived and serving stale on error is allowed and accepted by `should_serve_stale`.
/// Returns `false` in every other case, including when the cache is not enabled or the
/// task is not a header.
716 pub(crate) async fn revalidate_or_stale(
721 &self,
722 session: &mut Session,
723 task: &mut HttpTask,
724 ctx: &mut SV::CTX,
725 ) -> bool
726 where
727 SV: ProxyHttp + Send + Sync,
728 SV::CTX: Send + Sync,
729 {
730 if !session.cache.enabled() {
731 return false;
732 }
733
734 match task {
735 HttpTask::Header(resp, _eos) => {
736 if resp.status == StatusCode::NOT_MODIFIED {
737 if session.cache.maybe_cache_meta().is_some() {
// Run the upstream response filter on the 304 first. On failure the asset is marked
// uncacheable, but the function still returns true.
738 if let Err(err) = self
740 .inner
741 .upstream_response_filter(session, resp, ctx)
742 .await
743 {
744 error!("upstream response filter error on 304: {err:?}");
745 session.cache.revalidate_uncacheable(
746 *resp.clone(),
747 NoCacheReason::InternalError,
748 );
749 return true;
751 }
// response_cache_filter is given the merged header rather than the bare 304.
752 let merged_header = session.cache.revalidate_merge_header(resp);
755 match self
756 .inner
757 .response_cache_filter(session, &merged_header, ctx)
758 {
759 Ok(Cacheable(mut meta)) => {
// Carry the previous variance over to the new meta before storing it.
// The unwrap relies on the is_some() check above.
760 let old_meta = session.cache.maybe_cache_meta().unwrap(); if let Some(old_variance) = old_meta.variance() {
769 meta.set_variance(old_variance);
770 }
// A failure to store the revalidated meta is only logged.
771 if let Err(e) = session.cache.revalidate_cache_meta(meta).await {
772 warn!("revalidate_cache_meta failed {e:?}");
775 }
776 }
777 Ok(Uncacheable(reason)) => {
778 debug!("Uncacheable {reason:?} 304 received");
791 session.cache.response_became_uncacheable(reason);
792 session.cache.revalidate_uncacheable(merged_header, reason);
793 }
794 Err(e) => {
795 warn!("Error {e:?} response_cache_filter during revalidation");
799 session.cache.revalidate_uncacheable(
800 merged_header,
801 NoCacheReason::InternalError,
802 );
803 }
805 }
// All three outcomes above end with true.
806 true
808 } else {
// 304 without any cached meta: disable caching for this request.
809 warn!("304 received without cached asset, disable caching");
811 let reason = NoCacheReason::Custom("304 on miss");
812 session.cache.response_became_uncacheable(reason);
813 session.cache.disable(reason);
814 false
815 }
816 } else if resp.status.is_server_error() {
// Stale-on-error path for 5xx; skipped if not allowed or if a response was already written.
817 if !session.cache.can_serve_stale_error()
821 || session.response_written().is_some()
822 {
823 return false;
824 }
825
// Wrap the status code in an Error so should_serve_stale can inspect it.
826 let http_status_error = Error::create(
828 ErrorType::HTTPStatus(resp.status.as_u16()),
829 ErrorSource::Upstream,
830 None,
831 None,
832 );
833 if self
834 .inner
835 .should_serve_stale(session, ctx, Some(&http_status_error))
836 {
// Release the write lock before serving stale.
837 session
839 .cache
840 .release_write_lock(NoCacheReason::UpstreamError);
841 true
842 } else {
843 false
844 }
845 } else {
// Neither a 304 nor a 5xx.
846 false }
848 }
// Not a header task.
849 _ => false, }
851 }
852
853 pub(crate) async fn handle_stale_if_error(
856 &self,
857 session: &mut Session,
858 ctx: &mut SV::CTX,
859 error: &Error,
860 ) -> Option<(bool, Option<Box<Error>>)>
861 where
862 SV: ProxyHttp + Send + Sync,
863 SV::CTX: Send + Sync,
864 {
865 if !session.cache.can_serve_stale_error() {
867 return None;
868 }
869
870 if session.response_written().is_some() {
873 return None;
874 }
875
876 if !self.inner.should_serve_stale(session, ctx, Some(error)) {
878 return None;
879 }
880
881 warn!(
883 "Fail to proxy: {}, serving stale, {}",
884 error,
885 self.inner.request_summary(session, ctx)
886 );
887
888 session
890 .cache
891 .release_write_lock(NoCacheReason::UpstreamError);
892
893 Some(self.proxy_cache_hit(session, ctx).await)
894 }
895
896 fn handle_lock_status(
898 &self,
899 session: &mut Session,
900 ctx: &SV::CTX,
901 lock_status: LockStatus,
902 ) -> bool
903 where
904 SV: ProxyHttp,
905 {
906 debug!("cache unlocked {lock_status:?}");
907 match lock_status {
908 LockStatus::Done => true,
910 LockStatus::TransientError => true,
912 LockStatus::GiveUp => {
914 session.cache.disable(NoCacheReason::CacheLockGiveUp);
916 false
918 }
919 LockStatus::Dangling => {
921 warn!(
923 "Dangling cache lock, {}",
924 self.inner.request_summary(session, ctx)
925 );
926 true
927 }
928 LockStatus::WaitTimeout => {
931 warn!(
932 "Cache lock timeout, {}",
933 self.inner.request_summary(session, ctx)
934 );
935 session.cache.disable(NoCacheReason::CacheLockTimeout);
936 false
938 }
939 LockStatus::AgeTimeout => true,
943 LockStatus::Waiting => panic!("impossible LockStatus::Waiting"),
945 }
946 }
947}
948
949fn cache_hit_header(cache: &HttpCache) -> Box<ResponseHeader> {
950 let mut header = Box::new(cache.cache_meta().response_header_copy());
951 let no_body = matches!(header.status.as_u16(), 204 | 304);
955
956 if !cache.upstream_used() {
960 let age = cache.cache_meta().age().as_secs();
961 header.insert_header(http::header::AGE, age).unwrap();
962 }
963 log::debug!("cache header: {header:?} {:?}", cache.phase());
964
965 header.set_version(Version::HTTP_11);
969
970 if !no_body
973 && !header.status.is_informational()
974 && header.headers.get(http::header::CONTENT_LENGTH).is_none()
975 {
976 header
977 .insert_header(http::header::TRANSFER_ENCODING, "chunked")
978 .unwrap();
979 }
980 header
981}
982
983pub mod range_filter {
985 use super::*;
986 use bytes::BytesMut;
987 use http::header::*;
988 use std::ops::Range;
989
/// Parses a `usize` from raw bytes.
///
/// Returns `None` when the bytes are not valid UTF-8 or do not parse as a `usize`.
fn parse_number(input: &[u8]) -> Option<usize> {
    let text = std::str::from_utf8(input).ok()?;
    text.parse::<usize>().ok()
}
994
995 fn parse_range_header(
996 range: &[u8],
997 content_length: usize,
998 max_multipart_ranges: Option<usize>,
999 ) -> RangeType {
1000 use regex::Regex;
1001
1002 static RE_SINGLE_RANGE_PART: Lazy<Regex> =
1004 Lazy::new(|| Regex::new(r"(?i)^\s*(?P<start>\d*)-(?P<end>\d*)\s*$").unwrap());
1005
1006 let range_str = match str::from_utf8(range) {
1008 Ok(s) => s,
1009 Err(_) => return RangeType::None,
1010 };
1011
1012 let mut parts = range_str.splitn(2, "=");
1014
1015 let prefix = parts.next();
1017 if !prefix.is_some_and(|s| s.eq_ignore_ascii_case("bytes")) {
1018 return RangeType::None;
1019 }
1020
1021 let Some(ranges_str) = parts.next() else {
1022 return RangeType::None;
1024 };
1025
1026 let mut range_count = 0;
1028 for _ in ranges_str.split(',') {
1029 range_count += 1;
1030 if let Some(max_ranges) = max_multipart_ranges {
1031 if range_count >= max_ranges {
1032 return RangeType::None;
1034 }
1035 }
1036 }
1037 let mut ranges: Vec<Range<usize>> = Vec::with_capacity(range_count);
1038
1039 let mut last_range_end = 0;
1041 for part in ranges_str.split(',') {
1042 let captured = match RE_SINGLE_RANGE_PART.captures(part) {
1043 Some(c) => c,
1044 None => {
1045 return RangeType::None;
1046 }
1047 };
1048
1049 let maybe_start = captured
1050 .name("start")
1051 .and_then(|s| s.as_str().parse::<usize>().ok());
1052 let end = captured
1053 .name("end")
1054 .and_then(|s| s.as_str().parse::<usize>().ok());
1055
1056 let range = if let Some(start) = maybe_start {
1057 if start >= content_length {
1058 continue;
1060 }
1061 let end = std::cmp::min(end.unwrap_or(content_length - 1), content_length - 1) + 1;
1065 if end <= start {
1066 continue;
1068 }
1069 start..end
1070 } else {
1071 if let Some(end) = end {
1074 if content_length >= end {
1075 (content_length - end)..content_length
1076 } else {
1077 0..content_length
1079 }
1080 } else {
1081 continue;
1083 }
1084 };
1085 if range.start < last_range_end {
1088 return RangeType::None;
1089 }
1090 last_range_end = range.end;
1091 ranges.push(range);
1092 }
1093
1094 if ranges.is_empty() {
1104 RangeType::Invalid
1106 } else if ranges.len() == 1 {
1107 RangeType::Single(ranges[0].clone()) } else {
1109 RangeType::Multi(MultiRangeInfo::new(ranges))
1110 }
1111 }
/// Single-range parsing against a 10-byte representation, without a range-count limit.
#[test]
fn test_parse_range() {
    let cases = [
        ("bytes=0-1", RangeType::new_single(0, 2)),
        ("bYTes=0-9", RangeType::new_single(0, 10)),
        ("bytes=0-12", RangeType::new_single(0, 10)),
        ("bytes=0-", RangeType::new_single(0, 10)),
        ("bytes=2-1", RangeType::Invalid),
        ("bytes=10-11", RangeType::Invalid),
        ("bytes=-2", RangeType::new_single(8, 10)),
        ("bytes=-12", RangeType::new_single(0, 10)),
        ("bytes=-", RangeType::Invalid),
        ("bytes=", RangeType::None),
    ];

    for (header, expected) in cases {
        assert_eq!(parse_range_header(header.as_bytes(), 10, None), expected);
    }
}
1149
/// Multi-range parsing: clamping, skipped unsatisfiable parts, overlap rejection and the
/// range-count limit.
#[test]
fn test_parse_range_header_multi() {
    /// Parses `header` and returns the ranges of the resulting multipart info.
    fn multi_ranges(header: &[u8], content_length: usize) -> Vec<Range<usize>> {
        parse_range_header(header, content_length, None)
            .get_multirange_info()
            .expect("Should have multipart info for Multipart range request")
            .ranges
            .clone()
    }

    /// Whether parsing `header` yields multipart info at all.
    fn has_multi_info(header: &[u8], content_length: usize) -> bool {
        parse_range_header(header, content_length, None)
            .get_multirange_info()
            .is_some()
    }

    /// Builds a header with `count` two-byte ranges spaced four bytes apart.
    fn generate_range_header(count: usize) -> Vec<u8> {
        let specs: Vec<String> = (0..count)
            .map(|i| format!("{}-{}", i * 4, i * 4 + 1))
            .collect();
        format!("bytes={}", specs.join(",")).into_bytes()
    }

    assert_eq!(multi_ranges(b"bytes=0-1,4-5", 10), vec![0..2, 4..6]);

    // The last range starts past the end of the content and is dropped.
    assert_eq!(
        multi_ranges(b"bytEs=0-99,200-299,400-499", 320),
        vec![0..100, 200..300]
    );

    // All three ranges fit.
    assert_eq!(
        multi_ranges(b"bytEs=0-99,200-299,400-499", 500),
        vec![0..100, 200..300, 400..500]
    );

    // Overlapping ranges cause the header to be ignored.
    assert_eq!(
        parse_range_header(b"bytes=0-,-2", 10, None),
        RangeType::None
    );
    assert!(!has_multi_info(b"bytes=0-,-2", 10));

    assert_eq!(
        parse_range_header(b"bytes=0-3,2-5", 10, None),
        RangeType::None
    );
    assert!(!has_multi_info(b"bytes=0-3,2-5", 10));

    // Only one part survives, so the result is a single range.
    assert_eq!(
        parse_range_header(b"bytes=0-5,10-", 2, None),
        RangeType::new_single(0, 2)
    );
    assert!(!has_multi_info(b"bytes=0-5,10-", 2));

    // A part with last < first is skipped.
    assert_eq!(
        multi_ranges(b"bytes=0-5, 10-20, 30-18", 200),
        vec![0..6, 10..21]
    );

    // Every part is reversed, so nothing is satisfiable.
    assert_eq!(
        parse_range_header(b"bytes=5-0, 20-15, 30-25", 200, None),
        RangeType::Invalid
    );

    // More range-specs than the configured limit: the header is ignored.
    let ranges = generate_range_header(201);
    assert_eq!(
        parse_range_header(&ranges, 1000, Some(200)),
        RangeType::None
    );
}
1256
1257 #[derive(Debug, Eq, PartialEq, Clone)]
1260 pub struct MultiRangeInfo {
1261 pub ranges: Vec<Range<usize>>,
1262 pub boundary: String,
1263 total_length: usize,
1264 content_type: Option<String>,
1265 }
1266
1267 impl MultiRangeInfo {
1268 pub fn new(ranges: Vec<Range<usize>>) -> Self {
1270 Self {
1271 ranges,
1272 boundary: Self::generate_boundary(),
1274 total_length: 0,
1275 content_type: None,
1276 }
1277 }
1278 pub fn set_content_type(&mut self, content_type: String) {
1279 self.content_type = Some(content_type)
1280 }
1281 pub fn set_total_length(&mut self, total_length: usize) {
1282 self.total_length = total_length;
1283 }
1284 fn generate_boundary() -> String {
1289 use rand::Rng;
1290 let mut rng: rand::prelude::ThreadRng = rand::thread_rng();
1291 format!("{:016x}", rng.gen::<u64>())
1292 }
1293 pub fn calculate_multipart_length(&self) -> usize {
1294 let mut total_length = 0;
1295 let content_type = self.content_type.as_ref();
1296 for range in self.ranges.clone() {
1297 total_length += 4 + self.boundary.len() + 2;
1304 total_length += content_type.map_or(0, |ct| 14 + ct.len() + 2);
1305 total_length += format!(
1306 "Content-Range: bytes {}-{}/{}",
1307 range.start,
1308 range.end - 1,
1309 self.total_length
1310 )
1311 .len()
1312 + 2;
1313 total_length += 2;
1314 total_length += range.end - range.start;
1315 }
1316 total_length += 4 + self.boundary.len() + 4;
1318 total_length
1319 }
1320 }
1321 #[derive(Debug, Eq, PartialEq, Clone)]
1322 pub enum RangeType {
1323 None,
1324 Single(Range<usize>),
1325 Multi(MultiRangeInfo),
1326 Invalid,
1327 }
1328
1329 impl RangeType {
1330 #[allow(dead_code)]
1332 fn new_single(start: usize, end: usize) -> Self {
1333 RangeType::Single(Range { start, end })
1334 }
1335 #[allow(dead_code)]
1336 pub fn new_multi(ranges: Vec<Range<usize>>) -> Self {
1337 RangeType::Multi(MultiRangeInfo::new(ranges))
1338 }
1339 #[allow(dead_code)]
1340 fn get_multirange_info(&self) -> Option<&MultiRangeInfo> {
1341 match self {
1342 RangeType::Multi(multi_range_info) => Some(multi_range_info),
1343 _ => None,
1344 }
1345 }
1346 #[allow(dead_code)]
1347 fn update_multirange_info(&mut self, content_length: usize, content_type: Option<String>) {
1348 if let RangeType::Multi(multipart_range_info) = self {
1349 multipart_range_info.content_type = content_type;
1350 multipart_range_info.set_total_length(content_length);
1351 }
1352 }
1353 }
1354
1355 pub fn range_header_filter(
1357 req: &RequestHeader,
1358 resp: &mut ResponseHeader,
1359 max_multipart_ranges: Option<usize>,
1360 ) -> RangeType {
1361 if resp.status != StatusCode::OK {
1365 return RangeType::None;
1366 }
1367
1368 let Some(content_length_bytes) = resp.headers.get(CONTENT_LENGTH) else {
1371 return RangeType::None;
1372 };
1373 let Some(content_length) = parse_number(content_length_bytes.as_bytes()) else {
1375 return RangeType::None;
1376 };
1377
1378 fn request_range_type(
1383 req: &RequestHeader,
1384 resp: &ResponseHeader,
1385 content_length: usize,
1386 max_multipart_ranges: Option<usize>,
1387 ) -> RangeType {
1388 if req.method != http::Method::GET && req.method != http::Method::HEAD {
1390 return RangeType::None;
1391 }
1392
1393 let Some(range_header) = req.headers.get(RANGE) else {
1394 return RangeType::None;
1395 };
1396
1397 if let Some(if_range) = req.headers.get(IF_RANGE) {
1405 let ir = if_range.as_bytes();
1406 let matches = if ir.len() >= 2 && ir.last() == Some(&b'"') {
1407 resp.headers.get(ETAG).is_some_and(|etag| etag == if_range)
1408 } else if let Some(last_modified) = resp.headers.get(LAST_MODIFIED) {
1409 last_modified == if_range
1410 } else {
1411 false
1412 };
1413 if !matches {
1414 return RangeType::None;
1415 }
1416 }
1417
1418 parse_range_header(
1419 range_header.as_bytes(),
1420 content_length,
1421 max_multipart_ranges,
1422 )
1423 }
1424
1425 let mut range_type = request_range_type(req, resp, content_length, max_multipart_ranges);
1426
1427 match &mut range_type {
1428 RangeType::None => {
1429 resp.insert_header(&ACCEPT_RANGES, "bytes").unwrap();
1432 }
1433 RangeType::Single(r) => {
1434 resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap();
1436 resp.remove_header(&ACCEPT_RANGES);
1437 resp.insert_header(&CONTENT_LENGTH, r.end - r.start)
1438 .unwrap();
1439 resp.insert_header(
1440 &CONTENT_RANGE,
1441 format!("bytes {}-{}/{content_length}", r.start, r.end - 1), )
1443 .unwrap()
1444 }
1445
1446 RangeType::Multi(multi_range_info) => {
1447 let content_type = resp
1448 .headers
1449 .get(CONTENT_TYPE)
1450 .and_then(|v| v.to_str().ok())
1451 .unwrap_or("application/octet-stream");
1452 multi_range_info.set_total_length(content_length);
1454 multi_range_info.set_content_type(content_type.to_string());
1455
1456 let total_length = multi_range_info.calculate_multipart_length();
1457
1458 resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap();
1459 resp.remove_header(&ACCEPT_RANGES);
1460 resp.insert_header(CONTENT_LENGTH, total_length).unwrap();
1461 resp.insert_header(
1462 CONTENT_TYPE,
1463 format!(
1464 "multipart/byteranges; boundary={}",
1465 multi_range_info.boundary
1466 ), )
1468 .unwrap();
1469 resp.remove_header(&CONTENT_RANGE);
1470 }
1471 RangeType::Invalid => {
1472 resp.set_status(StatusCode::RANGE_NOT_SATISFIABLE).unwrap();
1474 resp.insert_header(&CONTENT_LENGTH, HeaderValue::from_static("0"))
1476 .unwrap();
1477 resp.remove_header(&ACCEPT_RANGES);
1478 resp.remove_header(&CONTENT_TYPE);
1480 resp.insert_header(&CONTENT_RANGE, format!("bytes */{content_length}"))
1481 .unwrap()
1482 }
1483 }
1484
1485 range_type
1486 }
1487
1488 #[test]
1489 fn test_range_filter_single() {
1490 fn gen_req() -> RequestHeader {
1491 RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap()
1492 }
1493 fn gen_resp() -> ResponseHeader {
1494 let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1495 resp.append_header("Content-Length", "10").unwrap();
1496 resp
1497 }
1498
1499 let req = gen_req();
1501 let mut resp = gen_resp();
1502 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1503 assert_eq!(resp.status.as_u16(), 200);
1504 assert_eq!(
1505 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1506 b"bytes"
1507 );
1508
1509 let mut req = gen_req();
1511 req.method = Method::HEAD;
1512 let mut resp = gen_resp();
1513 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1514 assert_eq!(resp.status.as_u16(), 200);
1515 assert_eq!(
1516 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1517 b"bytes"
1518 );
1519
1520 let mut req = gen_req();
1522 req.insert_header("Range", "bytes=0-1").unwrap();
1523 let mut resp = gen_resp();
1524 assert_eq!(
1525 RangeType::new_single(0, 2),
1526 range_header_filter(&req, &mut resp, None)
1527 );
1528 assert_eq!(resp.status.as_u16(), 206);
1529 assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"2");
1530 assert_eq!(
1531 resp.headers.get("content-range").unwrap().as_bytes(),
1532 b"bytes 0-1/10"
1533 );
1534 assert!(resp.headers.get("accept-ranges").is_none());
1535
1536 let mut req = gen_req();
1538 req.insert_header("Range", "bytes=0-1").unwrap();
1539 let mut resp = gen_resp();
1540 resp.insert_header("Accept-Ranges", "bytes").unwrap();
1541 assert_eq!(
1542 RangeType::new_single(0, 2),
1543 range_header_filter(&req, &mut resp, None)
1544 );
1545 assert_eq!(resp.status.as_u16(), 206);
1546 assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"2");
1547 assert_eq!(
1548 resp.headers.get("content-range").unwrap().as_bytes(),
1549 b"bytes 0-1/10"
1550 );
1551 assert!(resp.headers.get("accept-ranges").is_none());
1553
1554 let mut req = gen_req();
1556 req.insert_header("Range", "bytes=1-0").unwrap();
1557 let mut resp = gen_resp();
1558 resp.insert_header("Accept-Ranges", "bytes").unwrap();
1559 assert_eq!(
1560 RangeType::Invalid,
1561 range_header_filter(&req, &mut resp, None)
1562 );
1563 assert_eq!(resp.status.as_u16(), 416);
1564 assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"0");
1565 assert_eq!(
1566 resp.headers.get("content-range").unwrap().as_bytes(),
1567 b"bytes */10"
1568 );
1569 assert!(resp.headers.get("accept-ranges").is_none());
1570 }
1571
1572 #[test]
1574 fn test_range_filter_multipart() {
1575 fn gen_req() -> RequestHeader {
1576 let mut req: RequestHeader =
1577 RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1578 req.append_header("Range", "bytes=0-1,3-4,6-7").unwrap();
1579 req
1580 }
1581 fn gen_req_overlap_range() -> RequestHeader {
1582 let mut req: RequestHeader =
1583 RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1584 req.append_header("Range", "bytes=0-3,2-5,7-8").unwrap();
1585 req
1586 }
1587 fn gen_resp() -> ResponseHeader {
1588 let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1589 resp.append_header("Content-Length", "10").unwrap();
1590 resp
1591 }
1592
1593 let req = gen_req();
1595 let mut resp = gen_resp();
1596 let result = range_header_filter(&req, &mut resp, None);
1597 let mut boundary_str = String::new();
1598
1599 assert!(matches!(result, RangeType::Multi(_)));
1600 if let RangeType::Multi(multi_part_info) = result {
1601 assert_eq!(multi_part_info.ranges.len(), 3);
1602 assert_eq!(multi_part_info.ranges[0], Range { start: 0, end: 2 });
1603 assert_eq!(multi_part_info.ranges[1], Range { start: 3, end: 5 });
1604 assert_eq!(multi_part_info.ranges[2], Range { start: 6, end: 8 });
1605 assert!(multi_part_info.content_type.is_some());
1607 assert_eq!(multi_part_info.total_length, 10);
1608 assert!(!multi_part_info.boundary.is_empty());
1609 boundary_str = multi_part_info.boundary;
1610 }
1611 assert_eq!(resp.status.as_u16(), 206);
1612 assert_eq!(
1614 resp.headers.get("content-type").unwrap().to_str().unwrap(),
1615 format!("multipart/byteranges; boundary={boundary_str}")
1616 );
1617 assert!(resp.headers.get("content_length").is_none());
1618 assert!(resp.headers.get("accept-ranges").is_none());
1619
1620 let req = gen_req_overlap_range();
1622 let mut resp = gen_resp();
1623 let result = range_header_filter(&req, &mut resp, None);
1624
1625 assert!(matches!(result, RangeType::None));
1626 assert_eq!(resp.status.as_u16(), 200);
1627 assert!(resp.headers.get("content-type").is_none());
1628 assert_eq!(
1629 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1630 b"bytes"
1631 );
1632
1633 let mut req = gen_req();
1635 req.insert_header("Range", "bytes=1-0, 12-9, 50-40")
1636 .unwrap();
1637 let mut resp = gen_resp();
1638 let result = range_header_filter(&req, &mut resp, None);
1639 assert!(matches!(result, RangeType::Invalid));
1640 assert_eq!(resp.status.as_u16(), 416);
1641 assert!(resp.headers.get("accept-ranges").is_none());
1642 }
1643
1644 #[test]
1645 fn test_if_range() {
1646 const DATE: &str = "Fri, 07 Jul 2023 22:03:29 GMT";
1647 const ETAG: &str = "\"1234\"";
1648
1649 fn gen_req() -> RequestHeader {
1650 let mut req = RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1651 req.append_header("Range", "bytes=0-1").unwrap();
1652 req
1653 }
1654 fn get_multipart_req() -> RequestHeader {
1655 let mut req = RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1656 _ = req.append_header("Range", "bytes=0-1,3-4,6-7");
1657 req
1658 }
1659 fn gen_resp() -> ResponseHeader {
1660 let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1661 resp.append_header("Content-Length", "10").unwrap();
1662 resp.append_header("Last-Modified", DATE).unwrap();
1663 resp.append_header("ETag", ETAG).unwrap();
1664 resp
1665 }
1666
1667 let mut req = gen_req();
1669 req.insert_header("If-Range", DATE).unwrap();
1670 let mut resp = gen_resp();
1671 assert_eq!(
1672 RangeType::new_single(0, 2),
1673 range_header_filter(&req, &mut resp, None)
1674 );
1675
1676 let mut req = gen_req();
1678 req.insert_header("If-Range", "Fri, 07 Jul 2023 22:03:25 GMT")
1679 .unwrap();
1680 let mut resp = gen_resp();
1681 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1682 assert_eq!(resp.status.as_u16(), 200);
1683 assert_eq!(
1684 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1685 b"bytes"
1686 );
1687
1688 let mut req = gen_req();
1690 req.insert_header("If-Range", ETAG).unwrap();
1691 let mut resp = gen_resp();
1692 assert_eq!(
1693 RangeType::new_single(0, 2),
1694 range_header_filter(&req, &mut resp, None)
1695 );
1696 assert_eq!(resp.status.as_u16(), 206);
1697 assert!(resp.headers.get("accept-ranges").is_none());
1698
1699 let mut req = gen_req();
1701 req.insert_header("If-Range", "\"4567\"").unwrap();
1702 let mut resp = gen_resp();
1703 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1704 assert_eq!(resp.status.as_u16(), 200);
1705 assert_eq!(
1706 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1707 b"bytes"
1708 );
1709
1710 let mut req = gen_req();
1711 req.insert_header("If-Range", "1234").unwrap();
1712 let mut resp = gen_resp();
1713 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1714 assert_eq!(resp.status.as_u16(), 200);
1715 assert_eq!(
1716 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1717 b"bytes"
1718 );
1719
1720 let mut req = get_multipart_req();
1722 req.insert_header("If-Range", DATE).unwrap();
1723 let mut resp = gen_resp();
1724 let result = range_header_filter(&req, &mut resp, None);
1725 assert!(matches!(result, RangeType::Multi(_)));
1726 assert_eq!(resp.status.as_u16(), 206);
1727 assert!(resp.headers.get("accept-ranges").is_none());
1728
1729 let req = get_multipart_req();
1731 let mut resp = gen_resp();
1732 assert!(matches!(
1733 range_header_filter(&req, &mut resp, None),
1734 RangeType::Multi(_)
1735 ));
1736
1737 let mut req = get_multipart_req();
1739 req.insert_header("If-Range", "\"wrong\"").unwrap();
1740 let mut resp = gen_resp();
1741 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1742 assert_eq!(resp.status.as_u16(), 200);
1743 assert_eq!(
1744 resp.headers.get("accept-ranges").unwrap().as_bytes(),
1745 b"bytes"
1746 );
1747 }
1748
    /// Filters response body chunks down to the byte range(s) selected by a [`RangeType`].
    ///
    /// Chunks are fed in order through `filter_body`; the filter tracks the body offset of
    /// each chunk to decide which bytes to keep.
    pub struct RangeBodyFilter {
        /// The range selection applied to the body.
        pub range: RangeType,
        /// Body offset of the next chunk handed to `filter_body`. Starts at 0, advances by
        /// each chunk's length, and can be repositioned with `set_current_cursor`.
        current: usize,
        /// Index of the first range not yet completely emitted by the multipart body
        /// filter. `Some(0)` when a multipart range is installed, `None` otherwise.
        multipart_idx: Option<usize>,
        /// Index of the range most recently returned by `next_cache_multipart_range`;
        /// `None` until that method is called for the first time.
        cache_multipart_idx: Option<usize>,
    }
1755
1756 impl Default for RangeBodyFilter {
1757 fn default() -> Self {
1758 Self::new()
1759 }
1760 }
1761
1762 impl RangeBodyFilter {
1763 pub fn new() -> Self {
1764 RangeBodyFilter {
1765 range: RangeType::None,
1766 current: 0,
1767 multipart_idx: None,
1768 cache_multipart_idx: None,
1769 }
1770 }
1771
1772 pub fn new_range(range: RangeType) -> Self {
1773 RangeBodyFilter {
1774 multipart_idx: matches!(range, RangeType::Multi(_)).then_some(0),
1775 range,
1776 ..Default::default()
1777 }
1778 }
1779
1780 pub fn is_multipart_range(&self) -> bool {
1781 matches!(self.range, RangeType::Multi(_))
1782 }
1783
1784 pub fn should_cache_seek_again(&self) -> bool {
1787 match &self.range {
1788 RangeType::Multi(multipart_info) => self
1789 .cache_multipart_idx
1790 .is_some_and(|idx| idx != multipart_info.ranges.len() - 1),
1791 _ => false,
1792 }
1793 }
1794
1795 pub fn next_cache_multipart_range(&mut self) -> Range<usize> {
1797 match &self.range {
1798 RangeType::Multi(multipart_info) => {
1799 match self.cache_multipart_idx.as_mut() {
1800 Some(v) => *v += 1,
1801 None => self.cache_multipart_idx = Some(0),
1802 }
1803 let cache_multipart_idx = self.cache_multipart_idx.expect("set above");
1804 let multipart_idx = self.multipart_idx.expect("must be set on multirange");
1805 assert_eq!(multipart_idx, cache_multipart_idx,
1808 "cache multipart idx should match multipart idx, or there is a hit handler bug");
1809 multipart_info.ranges[cache_multipart_idx].clone()
1810 }
1811 _ => panic!("tried to advance multipart idx on non-multipart range"),
1812 }
1813 }
1814
1815 pub fn set_current_cursor(&mut self, current: usize) {
1816 self.current = current;
1817 }
1818
1819 pub fn set(&mut self, range: RangeType) {
1820 self.multipart_idx = matches!(range, RangeType::Multi(_)).then_some(0);
1821 self.range = range;
1822 }
1823
1824 pub fn finalize(&self, boundary: &String) -> Option<Bytes> {
1826 if let RangeType::Multi(_) = self.range {
1827 Some(Bytes::from(format!("\r\n--{boundary}--\r\n")))
1828 } else {
1829 None
1830 }
1831 }
1832
1833 pub fn filter_body(&mut self, data: Option<Bytes>) -> Option<Bytes> {
1834 match &self.range {
1835 RangeType::None => data,
1836 RangeType::Invalid => None,
1837 RangeType::Single(r) => {
1838 let current = self.current;
1839 self.current += data.as_ref().map_or(0, |d| d.len());
1840 data.and_then(|d| Self::filter_range_data(r.start, r.end, current, d))
1841 }
1842
1843 RangeType::Multi(_) => {
1844 let data = data?;
1845 let current = self.current;
1846 let data_len = data.len();
1847 self.current += data_len;
1848 self.filter_multi_range_body(data, current, data_len)
1849 }
1850 }
1851 }
1852
1853 fn filter_range_data(
1854 start: usize,
1855 end: usize,
1856 current: usize,
1857 data: Bytes,
1858 ) -> Option<Bytes> {
1859 if current + data.len() < start || current >= end {
1860 None
1862 } else if current >= start && current + data.len() <= end {
1863 Some(data)
1865 } else {
1866 let slice_start = start.saturating_sub(current);
1869 let slice_end = std::cmp::min(data.len(), end - current);
1870 Some(data.slice(slice_start..slice_end))
1871 }
1872 }
1873
1874 fn build_multipart_header(
1876 &self,
1877 range: &Range<usize>,
1878 boundary: &str,
1879 total_length: &usize,
1880 content_type: Option<&str>,
1881 ) -> Bytes {
1882 Bytes::from(format!(
1883 "\r\n--{}\r\n{}Content-Range: bytes {}-{}/{}\r\n\r\n",
1884 boundary,
1885 content_type.map_or(String::new(), |ct| format!("Content-Type: {ct}\r\n")),
1886 range.start,
1887 range.end - 1,
1888 total_length
1889 ))
1890 }
1891
1892 fn current_chunk_includes_range_start(
1894 &self,
1895 range: &Range<usize>,
1896 current: usize,
1897 data_len: usize,
1898 ) -> bool {
1899 range.start >= current && range.start < current + data_len
1900 }
1901
1902 fn current_chunk_includes_range_end(
1904 &self,
1905 range: &Range<usize>,
1906 current: usize,
1907 data_len: usize,
1908 ) -> bool {
1909 range.end > current && range.end <= current + data_len
1910 }
1911
1912 fn filter_multi_range_body(
1913 &mut self,
1914 data: Bytes,
1915 current: usize,
1916 data_len: usize,
1917 ) -> Option<Bytes> {
1918 let mut result = BytesMut::new();
1919
1920 let RangeType::Multi(multi_part_info) = &self.range else {
1921 return None;
1922 };
1923
1924 let multipart_idx = self.multipart_idx.expect("must be set on multirange");
1925 let final_range = multi_part_info.ranges.last()?;
1926
1927 let (_, remaining_ranges) = multi_part_info.ranges.as_slice().split_at(multipart_idx);
1928 for range in remaining_ranges {
1931 if let Some(sliced) =
1932 Self::filter_range_data(range.start, range.end, current, data.clone())
1933 {
1934 if self.current_chunk_includes_range_start(range, current, data_len) {
1935 result.extend_from_slice(&self.build_multipart_header(
1936 range,
1937 multi_part_info.boundary.as_ref(),
1938 &multi_part_info.total_length,
1939 multi_part_info.content_type.as_deref(),
1940 ));
1941 }
1942 result.extend_from_slice(&sliced);
1944 if self.current_chunk_includes_range_end(range, current, data_len) {
1945 if range == final_range {
1947 if let Some(final_chunk) = self.finalize(&multi_part_info.boundary) {
1948 result.extend_from_slice(&final_chunk);
1949 }
1950 }
1951 self.multipart_idx = Some(self.multipart_idx.expect("must be set") + 1);
1953 }
1954 } else {
1955 break;
1959 }
1960 }
1961 if result.is_empty() {
1962 None
1963 } else {
1964 Some(result.freeze())
1965 }
1966 }
1967 }
1968
1969 #[test]
1970 fn test_range_body_filter_single() {
1971 let mut body_filter = RangeBodyFilter::new_range(RangeType::None);
1972 assert_eq!(body_filter.filter_body(Some("123".into())).unwrap(), "123");
1973
1974 let mut body_filter = RangeBodyFilter::new_range(RangeType::Invalid);
1975 assert!(body_filter.filter_body(Some("123".into())).is_none());
1976
1977 let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(0, 1));
1978 assert_eq!(body_filter.filter_body(Some("012".into())).unwrap(), "0");
1979 assert!(body_filter.filter_body(Some("345".into())).is_none());
1980
1981 let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(4, 6));
1982 assert!(body_filter.filter_body(Some("012".into())).is_none());
1983 assert_eq!(body_filter.filter_body(Some("345".into())).unwrap(), "45");
1984 assert!(body_filter.filter_body(Some("678".into())).is_none());
1985
1986 let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(1, 7));
1987 assert_eq!(body_filter.filter_body(Some("012".into())).unwrap(), "12");
1988 assert_eq!(body_filter.filter_body(Some("345".into())).unwrap(), "345");
1989 assert_eq!(body_filter.filter_body(Some("678".into())).unwrap(), "6");
1990 }
1991
1992 #[test]
1993 fn test_range_body_filter_multipart() {
1994 let data = Bytes::from("0123456789");
1996 let ranges = vec![0..3, 6..9];
1997 let content_length = data.len();
1998 let mut body_filter = RangeBodyFilter::new();
1999 body_filter.set(RangeType::new_multi(ranges.clone()));
2000
2001 body_filter
2002 .range
2003 .update_multirange_info(content_length, None);
2004
2005 let multi_range_info = body_filter
2006 .range
2007 .get_multirange_info()
2008 .cloned()
2009 .expect("Multipart Ranges should have MultiPartInfo struct");
2010
2011 let output = body_filter.filter_body(Some(data)).unwrap();
2013 let footer = body_filter.finalize(&multi_range_info.boundary).unwrap();
2014
2015 let output_str = str::from_utf8(&output).unwrap();
2017 let final_boundary = str::from_utf8(&footer).unwrap();
2018 let boundary = &multi_range_info.boundary;
2019
2020 for (i, range) in ranges.iter().enumerate() {
2022 let header = &format!(
2023 "--{}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2024 boundary,
2025 range.start,
2026 range.end - 1,
2027 content_length
2028 );
2029 assert!(
2030 output_str.contains(header),
2031 "Missing part header {} in multipart body",
2032 i
2033 );
2034 let expected_body = &"0123456789"[range.clone()];
2036 assert!(
2037 output_str.contains(expected_body),
2038 "Missing body {} for range {:?}",
2039 expected_body,
2040 range
2041 )
2042 }
2043 assert_eq!(final_boundary, format!("\r\n--{}--\r\n", boundary));
2045
2046 let full_body = b"0123456789";
2048 let ranges = vec![0..2, 4..6, 8..9];
2049 let content_length = full_body.len();
2050 let content_type = "text/plain".to_string();
2051 let mut body_filter = RangeBodyFilter::new();
2052 body_filter.set(RangeType::new_multi(ranges.clone()));
2053
2054 body_filter
2055 .range
2056 .update_multirange_info(content_length, Some(content_type.clone()));
2057
2058 let multi_range_info = body_filter
2059 .range
2060 .get_multirange_info()
2061 .cloned()
2062 .expect("Multipart Ranges should have MultiPartInfo struct");
2063
2064 let chunk1 = Bytes::from_static(b"012");
2066 let chunk2 = Bytes::from_static(b"345");
2067 let chunk3 = Bytes::from_static(b"678");
2068 let chunk4 = Bytes::from_static(b"9");
2069
2070 let mut collected_bytes = BytesMut::new();
2071 for chunk in [chunk1, chunk2, chunk3, chunk4] {
2072 if let Some(filtered) = body_filter.filter_body(Some(chunk)) {
2073 collected_bytes.extend_from_slice(&filtered);
2074 }
2075 }
2076 if let Some(final_boundary) = body_filter.finalize(&multi_range_info.boundary) {
2077 collected_bytes.extend_from_slice(&final_boundary);
2078 }
2079
2080 let output_str = str::from_utf8(&collected_bytes).unwrap();
2081 let boundary = multi_range_info.boundary;
2082
2083 for (i, range) in ranges.iter().enumerate() {
2084 let header = &format!(
2085 "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2086 boundary,
2087 content_type,
2088 range.start,
2089 range.end - 1,
2090 content_length
2091 );
2092 let expected_body = &full_body[range.clone()];
2093 let expected_output = format!("{}{}", header, str::from_utf8(expected_body).unwrap());
2094
2095 assert!(
2096 output_str.contains(&expected_output),
2097 "Missing or malformed part {} in multipart body. \n Expected: \n{}\n Got: \n{}",
2098 i,
2099 expected_output,
2100 output_str
2101 )
2102 }
2103
2104 assert!(
2105 output_str.ends_with(&format!("\r\n--{}--\r\n", boundary)),
2106 "Missing final boundary"
2107 );
2108
2109 let full_body = b"abcdefghijkl";
2111 let ranges = vec![2..7, 9..11];
2112 let content_length = full_body.len();
2113 let content_type = "application/octet-stream".to_string();
2114 let mut body_filter = RangeBodyFilter::new();
2115 body_filter.set(RangeType::new_multi(ranges.clone()));
2116
2117 body_filter
2118 .range
2119 .update_multirange_info(content_length, Some(content_type.clone()));
2120
2121 let multi_range_info = body_filter
2122 .range
2123 .clone()
2124 .get_multirange_info()
2125 .cloned()
2126 .expect("Multipart Ranges should have MultiPartInfo struct");
2127
2128 let chunk1 = Bytes::from_static(b"abc");
2130 let chunk2 = Bytes::from_static(b"def");
2131 let chunk3 = Bytes::from_static(b"ghi");
2132 let chunk4 = Bytes::from_static(b"jkl");
2133
2134 let mut collected_bytes = BytesMut::new();
2135 for chunk in [chunk1, chunk2, chunk3, chunk4] {
2136 if let Some(filtered) = body_filter.filter_body(Some(chunk)) {
2137 collected_bytes.extend_from_slice(&filtered);
2138 }
2139 }
2140 if let Some(final_boundary) = body_filter.finalize(&multi_range_info.boundary) {
2141 collected_bytes.extend_from_slice(&final_boundary);
2142 }
2143
2144 let output_str = str::from_utf8(&collected_bytes).unwrap();
2145 let boundary = &multi_range_info.boundary;
2146
2147 let header1 = &format!(
2148 "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2149 boundary,
2150 content_type,
2151 ranges[0].start,
2152 ranges[0].end - 1,
2153 content_length
2154 );
2155 let header2 = &format!(
2156 "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2157 boundary,
2158 content_type,
2159 ranges[1].start,
2160 ranges[1].end - 1,
2161 content_length
2162 );
2163
2164 assert!(output_str.contains(header1));
2165 assert!(output_str.contains(header2));
2166
2167 let expected_body_slices = ["cdefg", "jk"];
2168
2169 assert!(
2170 output_str.contains(expected_body_slices[0]),
2171 "Missing expected sliced body {}",
2172 expected_body_slices[0]
2173 );
2174
2175 assert!(
2176 output_str.contains(expected_body_slices[1]),
2177 "Missing expected sliced body {}",
2178 expected_body_slices[1]
2179 );
2180
2181 assert!(
2182 output_str.ends_with(&format!("\r\n--{}--\r\n", boundary)),
2183 "Missing final boundary"
2184 );
2185 }
2186}
2187
/// Progress of serving a response out of the cache.
///
/// `ServeFromCache::next_http_task` drives the transitions: the header states emit the
/// cached response header, the body states read body chunks, and the `Done*` states only
/// report completion. The `*Miss` variants are the ones `is_miss` reports; their body is
/// read through `HttpCache::miss_body_reader` instead of `HttpCache::hit_handler`.
#[derive(Debug)]
pub(crate) enum ServeFromCache {
    /// Serving from cache is not enabled.
    Off,
    /// The cached header is sent next, followed by the body (`CacheBody`).
    CacheHeader,
    /// The cached header is sent next with the `HttpTask::Header` flag set to `true`
    /// (presumably end-of-stream; confirm against `HttpTask`); no body is read.
    CacheHeaderOnly,
    /// Counterpart of `CacheHeaderOnly` that is reported by `is_miss` and finishes in
    /// `DoneMiss`.
    CacheHeaderOnlyMiss,
    /// The body is being read from the hit handler. The flag is `true` while a seek to
    /// the requested range still has to be attempted before the first read.
    CacheBody(bool),
    /// Counterpart of `CacheHeader` that is reported by `is_miss`; followed by
    /// `CacheBodyMiss`.
    CacheHeaderMiss,
    /// The body is being read from the miss body reader. The flag is `true` while a seek
    /// to the requested range still has to be attempted before the first read.
    CacheBodyMiss(bool),
    /// The response has been served completely.
    Done,
    /// The response has been served completely through one of the `*Miss` states.
    DoneMiss,
}
2213
2214impl ServeFromCache {
2215 pub fn new() -> Self {
2216 Self::Off
2217 }
2218
2219 pub fn is_on(&self) -> bool {
2220 !matches!(self, Self::Off)
2221 }
2222
2223 pub fn is_miss(&self) -> bool {
2224 matches!(
2225 self,
2226 Self::CacheHeaderMiss
2227 | Self::CacheHeaderOnlyMiss
2228 | Self::CacheBodyMiss(_)
2229 | Self::DoneMiss
2230 )
2231 }
2232
2233 pub fn is_miss_header(&self) -> bool {
2234 matches!(self, Self::CacheHeaderMiss)
2237 }
2238
2239 pub fn is_miss_body(&self) -> bool {
2240 matches!(self, Self::CacheBodyMiss(_))
2241 }
2242
2243 pub fn should_discard_upstream(&self) -> bool {
2244 self.is_on() && !self.is_miss()
2245 }
2246
2247 pub fn should_send_to_downstream(&self) -> bool {
2248 !self.is_on()
2249 }
2250
2251 pub fn enable(&mut self) {
2252 *self = Self::CacheHeader;
2253 }
2254
2255 pub fn enable_miss(&mut self) {
2256 if !self.is_on() {
2257 *self = Self::CacheHeaderMiss;
2258 }
2259 }
2260
2261 pub fn enable_header_only(&mut self) {
2262 match self {
2263 Self::CacheBody(_) => *self = Self::Done, Self::CacheBodyMiss(_) => *self = Self::DoneMiss,
2265 _ => {
2266 if self.is_miss() {
2267 *self = Self::CacheHeaderOnlyMiss;
2268 } else {
2269 *self = Self::CacheHeaderOnly;
2270 }
2271 }
2272 }
2273 }
2274
2275 pub async fn next_http_task(
2277 &mut self,
2278 cache: &mut HttpCache,
2279 range: &mut RangeBodyFilter,
2280 ) -> Result<HttpTask> {
2281 if !cache.enabled() {
2282 return Error::e_explain(InternalError, "Cache disabled");
2286 }
2287 match self {
2288 Self::Off => panic!("ProxyUseCache not enabled"),
2289 Self::CacheHeader => {
2290 *self = Self::CacheBody(true);
2291 Ok(HttpTask::Header(cache_hit_header(cache), false)) }
2293 Self::CacheHeaderMiss => {
2294 *self = Self::CacheBodyMiss(true);
2295 Ok(HttpTask::Header(cache_hit_header(cache), false)) }
2297 Self::CacheHeaderOnly => {
2298 *self = Self::Done;
2299 Ok(HttpTask::Header(cache_hit_header(cache), true))
2300 }
2301 Self::CacheHeaderOnlyMiss => {
2302 *self = Self::DoneMiss;
2303 Ok(HttpTask::Header(cache_hit_header(cache), true))
2304 }
2305 Self::CacheBody(should_seek) => {
2306 log::trace!("cache body should seek: {should_seek}");
2307 if *should_seek {
2308 self.maybe_seek_hit_handler(cache, range)?;
2309 }
2310 loop {
2311 if let Some(b) = cache.hit_handler().read_body().await? {
2312 return Ok(HttpTask::Body(Some(b), false)); }
2314 if range.should_cache_seek_again() {
2317 self.maybe_seek_hit_handler(cache, range)?;
2318 } else {
2319 *self = Self::Done;
2320 return Ok(HttpTask::Done);
2321 }
2322 }
2323 }
2324 Self::CacheBodyMiss(should_seek) => {
2325 if *should_seek {
2326 self.maybe_seek_miss_handler(cache, range)?;
2327 }
2328 loop {
2330 if let Some(b) = cache.miss_body_reader().unwrap().read_body().await? {
2331 return Ok(HttpTask::Body(Some(b), false)); } else {
2333 if range.should_cache_seek_again() {
2336 self.maybe_seek_miss_handler(cache, range)?;
2337 } else {
2338 *self = Self::DoneMiss;
2339 return Ok(HttpTask::Done);
2340 }
2341 }
2342 }
2343 }
2344 Self::Done => Ok(HttpTask::Done),
2345 Self::DoneMiss => Ok(HttpTask::Done),
2346 }
2347 }
2348
2349 fn maybe_seek_miss_handler(
2350 &mut self,
2351 cache: &mut HttpCache,
2352 range_filter: &mut RangeBodyFilter,
2353 ) -> Result<()> {
2354 match &range_filter.range {
2355 RangeType::Single(range) => {
2356 if cache.miss_body_reader().unwrap().can_seek() {
2358 cache
2359 .miss_body_reader()
2360 .unwrap()
2362 .seek(range.start, Some(range.end))
2363 .or_err(InternalError, "cannot seek miss handler")?;
2364 range_filter.range = RangeType::None;
2367 }
2368 }
2369 RangeType::Multi(_info) => {
2370 if cache.miss_body_reader().unwrap().can_seek_multipart() {
2372 let range = range_filter.next_cache_multipart_range();
2373 cache
2374 .miss_body_reader()
2375 .unwrap()
2376 .seek_multipart(range.start, Some(range.end))
2377 .or_err(InternalError, "cannot seek hit handler for multirange")?;
2378 range_filter.set_current_cursor(range.start);
2381 }
2382 }
2383 _ => {}
2384 }
2385
2386 *self = Self::CacheBodyMiss(false);
2387 Ok(())
2388 }
2389
2390 fn maybe_seek_hit_handler(
2391 &mut self,
2392 cache: &mut HttpCache,
2393 range_filter: &mut RangeBodyFilter,
2394 ) -> Result<()> {
2395 match &range_filter.range {
2396 RangeType::Single(range) => {
2397 if cache.hit_handler().can_seek() {
2398 cache
2399 .hit_handler()
2400 .seek(range.start, Some(range.end))
2401 .or_err(InternalError, "cannot seek hit handler")?;
2402 range_filter.range = RangeType::None;
2405 }
2406 }
2407 RangeType::Multi(_info) => {
2408 if cache.hit_handler().can_seek_multipart() {
2409 let range = range_filter.next_cache_multipart_range();
2410 cache
2411 .hit_handler()
2412 .seek_multipart(range.start, Some(range.end))
2413 .or_err(InternalError, "cannot seek hit handler for multirange")?;
2414 range_filter.set_current_cursor(range.start);
2417 }
2418 }
2419 _ => {}
2420 }
2421 *self = Self::CacheBody(false);
2422 Ok(())
2423 }
2424}