1use super::*;
16use http::header::{CONTENT_LENGTH, CONTENT_TYPE};
17use http::{Method, StatusCode};
18use pingora_cache::key::CacheHashKey;
19use pingora_cache::lock::LockStatus;
20use pingora_cache::max_file_size::ERR_RESPONSE_TOO_LARGE;
21use pingora_cache::{ForcedInvalidationKind, HitStatus, RespCacheable::*};
22use pingora_core::protocols::http::conditional_filter::to_304;
23use pingora_core::protocols::http::v1::common::header_value_content_length;
24use pingora_core::ErrorType;
25use range_filter::RangeBodyFilter;
26use std::time::SystemTime;
27
impl<SV> HttpProxy<SV> {
    /// Try to serve the current request from cache.
    ///
    /// Steps, in order:
    /// 1. run `request_cache_filter` (an error is only logged);
    /// 2. if caching is enabled, compute the cache key with `cache_key_callback`
    ///    (an error disables caching for this request);
    /// 3. hand purge requests to `proxy_purge` and return its result;
    /// 4. switch to bypass when `cacheable_prediction()` is false;
    /// 5. loop on `cache_lookup`, re-running it whenever a vary re-key or a cache-lock
    ///    wait asks for a retry.
    ///
    /// Returns `None` when the request should continue to the upstream (miss, disabled
    /// cache, lookup error, lock give-up/timeout, or this request must fetch as the cache
    /// lock writer). Returns `Some((reuse, err))` when a response was produced here: the
    /// result of `proxy_cache_hit`, or of `proxy_purge` for purge requests.
    ///
    /// `self` is an `Arc` so a clone can be moved into the background subrequest task
    /// spawned for stale-while-updating (hence also the `'static` bound).
    pub(crate) async fn proxy_cache(
        self: &Arc<Self>,
        session: &mut Session,
        ctx: &mut SV::CTX,
    ) -> Option<(bool, Option<Box<Error>>)>
    where
        SV: ProxyHttp + Send + Sync + 'static,
        SV::CTX: Send + Sync,
    {
        // A failing request cache filter is logged; the request carries on.
        if let Err(e) = self.inner.request_cache_filter(session, ctx) {
            warn!(
                "Fail to request_cache_filter: {e}, {}",
                self.inner.request_summary(session, ctx)
            );
        }

        if session.cache.enabled() {
            match self.inner.cache_key_callback(session, ctx) {
                Ok(key) => {
                    session.cache.set_cache_key(key);
                }
                Err(e) => {
                    // No key means nothing can be looked up: disable caching for this request.
                    session.cache.disable(NoCacheReason::StorageError);
                    warn!(
                        "Fail to cache_key_callback: {e}, {}",
                        self.inner.request_summary(session, ctx)
                    );
                }
            }
        }

        // Purge is checked before any of the enabled/bypass checks below.
        if self.inner.is_purge(session, ctx) {
            return self.proxy_purge(session, ctx).await;
        }

        // Predicted-uncacheable: bypass. `cache_http_task` handles the bypassing case
        // once the real response header is known.
        if session.cache.enabled() && !session.cache.cacheable_prediction() {
            session.cache.bypass();
        }

        if !session.cache.enabled() {
            return None;
        }

        // Every `continue` below re-runs the lookup (after a vary re-key or a cache-lock
        // wait); every `break` yields this function's return value.
        loop {
            match session.cache.cache_lookup().await {
                Ok(res) => {
                    // Stays `None` when the lookup found no asset (plain miss).
                    let mut hit_status_opt = None;
                    if let Some((mut meta, mut handler)) = res {
                        let cache_key = session.cache.cache_key();
                        if let Some(variance) = cache_key.variance_bin() {
                            // The key already carries a variance: the stored asset must
                            // have been written with that same variance.
                            if Some(variance) != meta.variance() {
                                warn!("Cache variance mismatch, {variance:?}, {cache_key:?}");
                                session.cache.disable(NoCacheReason::InternalError);
                                break None;
                            }
                        } else {
                            // No variance on the key: ask the app whether this asset varies.
                            let req_header = session.req_header();
                            let variance = self.inner.cache_vary_filter(&meta, ctx, req_header);
                            if let Some(variance) = variance {
                                // `false` presumably means the key was re-keyed to another
                                // variant, so look it up again.
                                if !session.cache.cache_vary_lookup(variance, &meta) {
                                    continue;
                                }
                            }
                        }

                        let is_fresh = meta.is_fresh(SystemTime::now());
                        // The app may override freshness (force expire / force miss); a
                        // failing hit filter yields `HitStatus::FailedHitFilter`.
                        let hit_status = match self
                            .inner
                            .cache_hit_filter(session, &meta, &mut handler, is_fresh, ctx)
                            .await
                        {
                            Err(e) => {
                                error!(
                                    "Failed to filter cache hit: {e}, {}",
                                    self.inner.request_summary(session, ctx)
                                );
                                HitStatus::FailedHitFilter
                            }
                            Ok(None) => {
                                if is_fresh {
                                    HitStatus::Fresh
                                } else {
                                    HitStatus::Expired
                                }
                            }
                            Ok(Some(ForcedInvalidationKind::ForceExpired)) => {
                                // A forced expiration must not be served stale.
                                meta.disable_serve_stale();
                                HitStatus::ForceExpired
                            }
                            Ok(Some(ForcedInvalidationKind::ForceMiss)) => HitStatus::ForceMiss,
                        };

                        hit_status_opt = Some(hit_status);

                        // Hand the meta and hit handler to the cache state.
                        session.cache.cache_found(meta, handler, hit_status);
                    }

                    // Miss: no asset found, or a hit status that counts as a miss.
                    if hit_status_opt.map_or(true, HitStatus::is_treated_as_miss) {
                        if session.cache.is_cache_locked() {
                            // This request waits on a cache lock (presumably held by another
                            // request filling the same key); then retry or go upstream.
                            let lock_status = session.cache.cache_lock_wait().await;
                            if self.handle_lock_status(session, ctx, lock_status) {
                                continue;
                            } else {
                                break None;
                            }
                        } else {
                            self.inner.cache_miss(session, ctx);
                            break None;
                        }
                    }

                    // The `None` case always took the miss branch above.
                    let hit_status = hit_status_opt.expect("None case handled as miss");

                    if !hit_status.is_fresh() {
                        // Stale asset: go upstream, serve stale while updating, or wait
                        // for the cache lock.
                        if session.cache.is_cache_locked() {
                            // A background-update subrequest (spawned below) carries the
                            // write lock in its context: adopt it and fetch as the writer.
                            if let Some(write_lock) = session
                                .subrequest_ctx
                                .as_mut()
                                .and_then(|ctx| ctx.take_write_lock())
                            {
                                session.cache.set_write_lock(write_lock);
                                session.cache.tag_as_subrequest();
                                break None;
                            }
                            let will_serve_stale = session.cache.can_serve_stale_updating()
                                && self.inner.should_serve_stale(session, ctx, None);
                            if !will_serve_stale {
                                // Stale not allowed: wait for the lock holder.
                                let lock_status = session.cache.cache_lock_wait().await;
                                if self.handle_lock_status(session, ctx, lock_status) {
                                    continue;
                                } else {
                                    break None;
                                }
                            }
                            // Serve the stale asset while someone else updates it.
                            session.cache.set_stale_updating();
                        } else if session.cache.is_cache_lock_writer() {
                            // This request holds the write lock.
                            let will_serve_stale = session.cache.can_serve_stale_updating()
                                && self.inner.should_serve_stale(session, ctx, None);
                            if will_serve_stale {
                                // Serve stale now and move the write lock into a detached
                                // subrequest that performs the upstream update.
                                let subrequest =
                                    Box::new(crate::subrequest::create_dummy_session(session));
                                let new_app = self.clone();
                                let (permit, cache_lock) = session.cache.take_write_lock();
                                let sub_req_ctx = Box::new(SubReqCtx::with_write_lock(
                                    cache_lock,
                                    session.cache.cache_key().clone(),
                                    permit,
                                ));
                                tokio::spawn(async move {
                                    new_app.process_subrequest(subrequest, sub_req_ctx).await;
                                });
                                session.cache.set_stale_updating();
                            } else {
                                // Go upstream as the writer.
                                break None;
                            }
                        } else {
                            // No cache lock involved: go upstream.
                            break None;
                        }
                    }

                    // Fresh hit, or stale hit served while updating.
                    let (reuse, err) = self.proxy_cache_hit(session, ctx).await;
                    if let Some(e) = err.as_ref() {
                        error!(
                            "Fail to serve cache: {e}, {}",
                            self.inner.request_summary(session, ctx)
                        );
                    }
                    break Some((reuse, err));
                }
                Err(e) => {
                    // A failed lookup is treated as a miss.
                    self.inner.cache_miss(session, ctx);
                    warn!(
                        "Fail to cache lookup: {e}, {}",
                        self.inner.request_summary(session, ctx)
                    );
                    break None;
                }
            }
        }
    }

    /// Send the asset held by the cache hit handler to downstream.
    ///
    /// Steps:
    /// 1. Build the header with `cache_hit_header`.
    /// 2. Apply `cache_not_modified_filter` (304) and, if the hit handler can seek and
    ///    downstream ranges are not ignored, `range_header_filter`.
    /// 3. Run `response_filter` and the downstream modules' header filter.
    /// 4. Stream the body through the body filters.
    ///
    /// Returns:
    /// * `(true, None)` when everything was sent;
    /// * `(true, Some(e))` when a header filter failed and a 500 was sent instead;
    /// * `(false, Some(e))` on failures while writing downstream, reading the cached body or
    ///   running body filters.
    pub(crate) async fn proxy_cache_hit(
        &self,
        session: &mut Session,
        ctx: &mut SV::CTX,
    ) -> (bool, Option<Box<Error>>)
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        use range_filter::*;

        let seekable = session.cache.hit_handler().can_seek();
        let mut header = cache_hit_header(&session.cache);

        let req = session.req_header();

        let not_modified = match self.inner.cache_not_modified_filter(session, &header, ctx) {
            Ok(not_modified) => not_modified,
            Err(e) => {
                warn!(
                    "Failed to run cache not modified filter: {e}, {}",
                    self.inner.request_summary(session, ctx)
                );
                false
            }
        };
        if not_modified {
            to_304(&mut header);
        }
        // 304s and HEAD responses carry no body.
        let header_only = not_modified || req.method == http::method::Method::HEAD;

        // Ranges need a seekable hit handler (single ranges are served via `seek`).
        let range_type = if seekable && !session.ignore_downstream_range {
            self.inner.range_header_filter(session, &mut header, ctx)
        } else {
            RangeType::None
        };

        // An unsatisfiable range becomes a header-only 416.
        let header_only = header_only || matches!(range_type, RangeType::Invalid);

        match self.inner.response_filter(session, &mut header, ctx).await {
            Ok(_) => {
                if let Err(e) = session
                    .downstream_modules_ctx
                    .response_header_filter(&mut header, header_only)
                    .await
                {
                    error!(
                        "Failed to run downstream modules response header filter in hit: {e}, {}",
                        self.inner.request_summary(session, ctx)
                    );
                    session
                        .as_mut()
                        .respond_error(500)
                        .await
                        .unwrap_or_else(|e| {
                            error!("failed to send error response to downstream: {e}");
                        });
                    return (true, Some(e));
                }

                if let Err(e) = session
                    .as_mut()
                    .write_response_header(header)
                    .await
                    .map_err(|e| e.into_down())
                {
                    return (false, Some(e));
                }
            }
            Err(e) => {
                error!(
                    "Failed to run response filter in hit: {e}, {}",
                    self.inner.request_summary(session, ctx)
                );
                session
                    .as_mut()
                    .respond_error(500)
                    .await
                    .unwrap_or_else(|e| {
                        error!("failed to send error response to downstream: {e}");
                    });
                return (true, Some(e));
            }
        }
        debug!("finished sending cached header to downstream");

        if !header_only {
            // Single range: seek the hit handler. Multi range: wrap the body in a
            // multipart filter.
            let mut maybe_range_filter = match &range_type {
                RangeType::Single(r) => {
                    if let Err(e) = session.cache.hit_handler().seek(r.start, Some(r.end)) {
                        return (false, Some(e));
                    }
                    None
                }
                RangeType::Multi(_) => {
                    let mut range_filter = RangeBodyFilter::new();
                    range_filter.set(range_type.clone());
                    Some(range_filter)
                }
                // `Invalid` forces `header_only` above, so this branch is never reached.
                RangeType::Invalid => unreachable!(),
                RangeType::None => None,
            };
            loop {
                match session.cache.hit_handler().read_body().await {
                    Ok(raw_body) => {
                        // `None` from the hit handler marks the end of the body.
                        let end = raw_body.is_none();

                        let mut body = if let Some(range_filter) = maybe_range_filter.as_mut() {
                            range_filter.filter_body(raw_body)
                        } else {
                            raw_body
                        };

                        match self
                            .inner
                            .response_body_filter(session, &mut body, end, ctx)
                        {
                            Ok(Some(duration)) => {
                                // The body filter asked to throttle this chunk.
                                trace!("delaying response for {duration:?}");
                                time::sleep(duration).await;
                            }
                            Ok(None) => {}
                            Err(e) => {
                                return (false, Some(e));
                            }
                        }

                        if let Err(e) = session
                            .downstream_modules_ctx
                            .response_body_filter(&mut body, end)
                        {
                            return (false, Some(e));
                        }

                        // Skip empty intermediate chunks; the final chunk is always written.
                        if !end && body.as_ref().map_or(true, |b| b.is_empty()) {
                            continue;
                        }

                        let b = body.unwrap_or_default();
                        if let Err(e) = session
                            .as_mut()
                            .write_response_body(b, end)
                            .await
                            .map_err(|e| e.into_down())
                        {
                            return (false, Some(e));
                        }
                        if end {
                            break;
                        }
                    }
                    Err(e) => return (false, Some(e)),
                }
            }
        }

        // A failure to finish the hit handler is logged but does not fail the response.
        if let Err(e) = session.cache.finish_hit_handler().await {
            warn!("Error during finish_hit_handler: {}", e);
        }

        match session.as_mut().finish_body().await {
            Ok(_) => {
                debug!("finished sending cached body to downstream");
                (true, None)
            }
            Err(e) => (false, Some(e)),
        }
    }

    /// Apply conditional-request handling to a response header headed downstream.
    ///
    /// When `cache_not_modified_filter` reports not-modified, `resp` is rewritten to a 304.
    /// For 304s and HEAD requests only the header should go out. When `use_cache` is
    /// active, it is switched to header-only mode. Filter errors are logged and treated
    /// as "modified".
    pub(crate) fn downstream_response_conditional_filter(
        &self,
        use_cache: &mut ServeFromCache,
        session: &Session,
        resp: &mut ResponseHeader,
        ctx: &mut SV::CTX,
    ) where
        SV: ProxyHttp,
    {
        let req = session.req_header();

        let not_modified = match self.inner.cache_not_modified_filter(session, resp, ctx) {
            Ok(not_modified) => not_modified,
            Err(e) => {
                warn!(
                    "Failed to run cache not modified filter: {e}, {}",
                    self.inner.request_summary(session, ctx)
                );
                false
            }
        };

        if not_modified {
            to_304(resp);
        }
        let header_only = not_modified || req.method == http::method::Method::HEAD;
        if header_only {
            if use_cache.is_on() {
                use_cache.enable_header_only();
            } else {
                // NOTE(review): not serving from cache; this function takes no action here.
            }
        }
    }

    /// Feed one upstream `HttpTask` into the cache (the miss / bypass path).
    ///
    /// * Header: decide cacheability via `response_cache_filter`, enforce the max file size
    ///   when `Content-Length` is known, and set up the miss handler.
    /// * Body: write each chunk to the miss handler, enforcing the max file size as bytes
    ///   arrive.
    /// * Done / end of stream: finish the miss handler.
    ///
    /// Does nothing when caching is neither enabled nor bypassing. 1xx headers other than
    /// 101 are ignored.
    ///
    /// # Errors
    /// * Propagates `response_cache_filter` and miss-handler errors.
    /// * Returns an `ERR_RESPONSE_TOO_LARGE` error when a body chunk would exceed the max
    ///   file size.
    pub(crate) async fn cache_http_task(
        &self,
        session: &mut Session,
        task: &HttpTask,
        ctx: &mut SV::CTX,
        serve_from_cache: &mut ServeFromCache,
    ) -> Result<()>
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        if !session.cache.enabled() && !session.cache.bypassing() {
            return Ok(());
        }

        match task {
            HttpTask::Header(header, end_stream) => {
                // Interim 1xx responses are not cached; 101 still goes through.
                if header.status.is_informational()
                    && header.status != StatusCode::SWITCHING_PROTOCOLS
                {
                    return Ok(());
                }
                match self.inner.response_cache_filter(session, header, ctx)? {
                    Cacheable(meta) => {
                        let mut fill_cache = true;
                        if session.cache.bypassing() {
                            // Bypassed request whose response turned out cacheable.
                            // With a size limit but no Content-Length the size cannot be
                            // checked up front, so give up on caching it.
                            if session.cache.max_file_size_bytes().is_some()
                                && !meta.headers().contains_key(header::CONTENT_LENGTH)
                            {
                                session
                                    .cache
                                    .disable(NoCacheReason::PredictedResponseTooLarge);
                                return Ok(());
                            }

                            session.cache.response_became_cacheable();

                            // Only GET 200 responses are admitted from bypass; anything
                            // else is deferred (caching disabled for this request).
                            if session.req_header().method == Method::GET
                                && meta.response_header().status == StatusCode::OK
                            {
                                self.inner.cache_miss(session, ctx);
                            } else {
                                fill_cache = false;
                                session.cache.disable(NoCacheReason::Deferred);
                            }
                        }

                        if session.cache.enabled() {
                            // Reject up front when the declared length exceeds the limit.
                            if let Some(max_file_size) = session.cache.max_file_size_bytes() {
                                let content_length_hdr = meta.headers().get(header::CONTENT_LENGTH);
                                if let Some(content_length) =
                                    header_value_content_length(content_length_hdr)
                                {
                                    if content_length > max_file_size {
                                        fill_cache = false;
                                        session.cache.response_became_uncacheable(
                                            NoCacheReason::ResponseTooLarge,
                                        );
                                        session.cache.disable(NoCacheReason::ResponseTooLarge);
                                        // NOTE(review): also stops downstream Range handling
                                        // for this session; rationale not visible here.
                                        session.ignore_downstream_range = true;
                                    }
                                }
                            }
                        }
                        if fill_cache {
                            let req_header = session.req_header();
                            let variance = self.inner.cache_vary_filter(&meta, ctx, req_header);
                            session.cache.set_cache_meta(meta);
                            session.cache.update_variance(variance);
                            session.cache.set_miss_handler().await?;
                            // Let downstream be served from the miss handler's body reader
                            // when one is available.
                            if session.cache.miss_body_reader().is_some() {
                                serve_from_cache.enable_miss();
                            }
                            // Header-only response: close out the (empty) cached body now.
                            if *end_stream {
                                session
                                    .cache
                                    .miss_handler()
                                    .unwrap()
                                    .write_body(Bytes::new(), true)
                                    .await?;
                                session.cache.finish_miss_handler().await?;
                            }
                        }
                    }
                    Uncacheable(reason) => {
                        if !session.cache.bypassing() {
                            session.cache.response_became_uncacheable(reason);
                        }
                        session.cache.disable(reason);
                    }
                }
            }
            HttpTask::Body(data, end_stream) => match data {
                Some(d) => {
                    if session.cache.enabled() {
                        // Enforce the size limit for bodies whose length was not declared.
                        let body_size_allowed =
                            session.cache.track_body_bytes_for_max_file_size(d.len());
                        if !body_size_allowed {
                            debug!("chunked response exceeded max cache size, remembering that it is uncacheable");
                            session
                                .cache
                                .response_became_uncacheable(NoCacheReason::ResponseTooLarge);

                            return Error::e_explain(
                                ERR_RESPONSE_TOO_LARGE,
                                format!(
                                    "writing data of size {} bytes would exceed max file size of {} bytes",
                                    d.len(),
                                    session.cache.max_file_size_bytes().expect("max file size bytes must be set to exceed size")
                                ),
                            );
                        }

                        // Assumes the header branch set up a miss handler while caching
                        // stayed enabled.
                        let miss_handler = session.cache.miss_handler().unwrap();

                        miss_handler.write_body(d.clone(), *end_stream).await?;
                        if *end_stream {
                            session.cache.finish_miss_handler().await?;
                        }
                    }
                }
                None => {
                    if session.cache.enabled() && *end_stream {
                        session.cache.finish_miss_handler().await?;
                    }
                }
            },
            HttpTask::Trailer(_) => {}
            HttpTask::Done => {
                if session.cache.enabled() {
                    session.cache.finish_miss_handler().await?;
                }
            }
            HttpTask::Failed(_) => {
                // NOTE(review): upstream failures are not acted on here.
            }
        }
        Ok(())
    }

    /// Handle an upstream response header during revalidation or stale-if-error.
    ///
    /// * 304 with a cached asset: run `upstream_response_filter`, merge the headers,
    ///   re-run `response_cache_filter` and either update the cached meta (keeping the old
    ///   variance) or mark the asset uncacheable. Returns `true`.
    /// * 304 without a cached asset: caching is disabled; returns `false`.
    /// * 5xx: if stale-if-error is allowed, nothing was written downstream yet and
    ///   `should_serve_stale` agrees, release the write lock and return `true`.
    /// * Anything else, or caching disabled: `false`.
    ///
    /// A `true` return presumably tells the caller to serve the cached asset; confirm at
    /// the call site.
    pub(crate) async fn revalidate_or_stale(
        &self,
        session: &mut Session,
        task: &mut HttpTask,
        ctx: &mut SV::CTX,
    ) -> bool
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        if !session.cache.enabled() {
            return false;
        }

        match task {
            HttpTask::Header(resp, _eos) => {
                if resp.status == StatusCode::NOT_MODIFIED {
                    if session.cache.maybe_cache_meta().is_some() {
                        if let Err(err) = self.inner.upstream_response_filter(session, resp, ctx) {
                            error!("upstream response filter error on 304: {err:?}");
                            session.cache.revalidate_uncacheable(
                                *resp.clone(),
                                NoCacheReason::InternalError,
                            );
                            return true;
                        }
                        // Combine the cached header with the fresh 304 header.
                        let merged_header = session.cache.revalidate_merge_header(resp);
                        match self
                            .inner
                            .response_cache_filter(session, &merged_header, ctx)
                        {
                            Ok(Cacheable(mut meta)) => {
                                // Presence was checked above, so the unwrap cannot fail.
                                let old_meta = session.cache.maybe_cache_meta().unwrap();
                                // Keep the asset in the same vary slot.
                                if let Some(old_variance) = old_meta.variance() {
                                    meta.set_variance(old_variance);
                                }
                                // A failed meta update is only logged.
                                if let Err(e) = session.cache.revalidate_cache_meta(meta).await {
                                    warn!("revalidate_cache_meta failed {e:?}");
                                }
                            }
                            Ok(Uncacheable(reason)) => {
                                warn!("Uncacheable {reason:?} 304 received");
                                session.cache.response_became_uncacheable(reason);
                                session.cache.revalidate_uncacheable(merged_header, reason);
                            }
                            Err(e) => {
                                warn!("Error {e:?} response_cache_filter during revalidation");
                                session.cache.revalidate_uncacheable(
                                    merged_header,
                                    NoCacheReason::InternalError,
                                );
                            }
                        }
                        true
                    } else {
                        // Nothing cached to revalidate against.
                        warn!("304 received without cached asset, disable caching");
                        let reason = NoCacheReason::Custom("304 on miss");
                        session.cache.response_became_uncacheable(reason);
                        session.cache.disable(reason);
                        false
                    }
                } else if resp.status.is_server_error() {
                    // Stale-if-error is only possible before anything went downstream.
                    if !session.cache.can_serve_stale_error()
                        || session.response_written().is_some()
                    {
                        return false;
                    }

                    let http_status_error = Error::create(
                        ErrorType::HTTPStatus(resp.status.as_u16()),
                        ErrorSource::Upstream,
                        None,
                        None,
                    );
                    if self
                        .inner
                        .should_serve_stale(session, ctx, Some(&http_status_error))
                    {
                        session
                            .cache
                            .release_write_lock(NoCacheReason::UpstreamError);
                        true
                    } else {
                        false
                    }
                } else {
                    false
                }
            }
            _ => false,
        }
    }

    /// Serve the stale cached asset after an upstream `error`, if allowed.
    ///
    /// Returns `None` in any of these cases:
    /// * stale-if-error is not possible;
    /// * a response was already written downstream;
    /// * `should_serve_stale` declines.
    ///
    /// Otherwise it releases the write lock and returns the result of `proxy_cache_hit`.
    pub(crate) async fn handle_stale_if_error(
        &self,
        session: &mut Session,
        ctx: &mut SV::CTX,
        error: &Error,
    ) -> Option<(bool, Option<Box<Error>>)>
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        if !session.cache.can_serve_stale_error() {
            return None;
        }

        // Too late to switch responses once something went downstream.
        if session.response_written().is_some() {
            return None;
        }

        if !self.inner.should_serve_stale(session, ctx, Some(error)) {
            return None;
        }

        warn!(
            "Fail to proxy: {}, serving stale, {}",
            error,
            self.inner.request_summary(session, ctx)
        );

        session
            .cache
            .release_write_lock(NoCacheReason::UpstreamError);

        Some(self.proxy_cache_hit(session, ctx).await)
    }

    /// Decide what to do after waiting on a cache lock.
    ///
    /// Returns `true` to retry the cache lookup (`Done`, `TransientError`, `Dangling`).
    /// Returns `false` to go upstream with caching disabled (`GiveUp`, `Timeout`).
    ///
    /// # Panics
    /// On `LockStatus::Waiting`, which a finished lock wait must never report.
    fn handle_lock_status(
        &self,
        session: &mut Session,
        ctx: &SV::CTX,
        lock_status: LockStatus,
    ) -> bool
    where
        SV: ProxyHttp,
    {
        debug!("cache unlocked {lock_status:?}");
        match lock_status {
            LockStatus::Done => true,
            LockStatus::TransientError => true,
            LockStatus::GiveUp => {
                session.cache.disable(NoCacheReason::CacheLockGiveUp);
                false
            }
            LockStatus::Dangling => {
                warn!(
                    "Dangling cache lock, {}",
                    self.inner.request_summary(session, ctx)
                );
                true
            }
            LockStatus::Timeout => {
                warn!(
                    "Cache lock timeout, {}",
                    self.inner.request_summary(session, ctx)
                );
                session.cache.disable(NoCacheReason::CacheLockTimeout);
                false
            }
            LockStatus::Waiting => panic!("impossible LockStatus::Waiting"),
        }
    }
}
888
889fn cache_hit_header(cache: &HttpCache) -> Box<ResponseHeader> {
890 let mut header = Box::new(cache.cache_meta().response_header_copy());
891 let no_body = matches!(header.status.as_u16(), 204 | 304);
895
896 if !cache.upstream_used() {
900 let age = cache.cache_meta().age().as_secs();
901 header.insert_header(http::header::AGE, age).unwrap();
902 }
903
904 if !no_body
907 && !header.status.is_informational()
908 && header.headers.get(http::header::CONTENT_LENGTH).is_none()
909 {
910 header
911 .insert_header(http::header::TRANSFER_ENCODING, "chunked")
912 .unwrap();
913 }
914 header
915}
916
917pub mod range_filter {
919 use super::*;
920 use bytes::BytesMut;
921 use http::header::*;
922 use std::ops::Range;
923
924 fn parse_number(input: &[u8]) -> Option<usize> {
926 str::from_utf8(input).ok()?.parse().ok()
927 }
928
929 fn parse_range_header(range: &[u8], content_length: usize) -> RangeType {
930 use regex::Regex;
931
932 static RE_SINGLE_RANGE_PART: Lazy<Regex> =
934 Lazy::new(|| Regex::new(r"(?i)^\s*(?P<start>\d*)-(?P<end>\d*)\s*$").unwrap());
935
936 let range_str = match str::from_utf8(range) {
938 Ok(s) => s,
939 Err(_) => return RangeType::None,
940 };
941
942 let mut parts = range_str.splitn(2, "=");
944
945 let prefix = parts.next();
947 if !prefix.is_some_and(|s| s.eq_ignore_ascii_case("bytes")) {
948 return RangeType::None;
949 }
950
951 let Some(ranges_str) = parts.next() else {
952 return RangeType::None;
954 };
955
956 let mut range_count = 0;
958 for _ in ranges_str.split(',') {
959 range_count += 1;
960 const MAX_RANGES: usize = 100;
962 if range_count >= MAX_RANGES {
963 return RangeType::None;
965 }
966 }
967 let mut ranges: Vec<Range<usize>> = Vec::with_capacity(range_count);
968
969 let mut last_range_end = 0;
971 for part in ranges_str.split(',') {
972 let captured = match RE_SINGLE_RANGE_PART.captures(part) {
973 Some(c) => c,
974 None => {
975 return RangeType::None;
976 }
977 };
978
979 let maybe_start = captured
980 .name("start")
981 .and_then(|s| s.as_str().parse::<usize>().ok());
982 let end = captured
983 .name("end")
984 .and_then(|s| s.as_str().parse::<usize>().ok());
985
986 let range = if let Some(start) = maybe_start {
987 if start >= content_length {
988 continue;
990 }
991 let end = std::cmp::min(end.unwrap_or(content_length - 1), content_length - 1) + 1;
995 if end <= start {
996 continue;
998 }
999 start..end
1000 } else {
1001 if let Some(end) = end {
1004 if content_length >= end {
1005 (content_length - end)..content_length
1006 } else {
1007 0..content_length
1009 }
1010 } else {
1011 continue;
1013 }
1014 };
1015 if range.start < last_range_end {
1018 return RangeType::None;
1019 }
1020 last_range_end = range.end;
1021 ranges.push(range);
1022 }
1023
1024 if ranges.is_empty() {
1034 RangeType::Invalid
1036 } else if ranges.len() == 1 {
1037 RangeType::Single(ranges[0].clone()) } else {
1039 RangeType::Multi(MultiRangeInfo::new(ranges))
1040 }
1041 }
1042 #[test]
1043 fn test_parse_range() {
1044 assert_eq!(
1045 parse_range_header(b"bytes=0-1", 10),
1046 RangeType::new_single(0, 2)
1047 );
1048 assert_eq!(
1049 parse_range_header(b"bYTes=0-9", 10),
1050 RangeType::new_single(0, 10)
1051 );
1052 assert_eq!(
1053 parse_range_header(b"bytes=0-12", 10),
1054 RangeType::new_single(0, 10)
1055 );
1056 assert_eq!(
1057 parse_range_header(b"bytes=0-", 10),
1058 RangeType::new_single(0, 10)
1059 );
1060 assert_eq!(parse_range_header(b"bytes=2-1", 10), RangeType::Invalid);
1061 assert_eq!(parse_range_header(b"bytes=10-11", 10), RangeType::Invalid);
1062 assert_eq!(
1063 parse_range_header(b"bytes=-2", 10),
1064 RangeType::new_single(8, 10)
1065 );
1066 assert_eq!(
1067 parse_range_header(b"bytes=-12", 10),
1068 RangeType::new_single(0, 10)
1069 );
1070 assert_eq!(parse_range_header(b"bytes=-", 10), RangeType::Invalid);
1071 assert_eq!(parse_range_header(b"bytes=", 10), RangeType::None);
1072 }
1073
1074 #[test]
1076 fn test_parse_range_header_multi() {
1077 assert_eq!(
1078 parse_range_header(b"bytes=0-1,4-5", 10)
1079 .get_multirange_info()
1080 .expect("Should have multipart info for Multipart range request")
1081 .ranges,
1082 (vec![Range { start: 0, end: 2 }, Range { start: 4, end: 6 }])
1083 );
1084 assert_eq!(
1086 parse_range_header(b"bytEs=0-99,200-299,400-499", 320)
1087 .get_multirange_info()
1088 .expect("Should have multipart info for Multipart range request")
1089 .ranges,
1090 (vec![
1091 Range { start: 0, end: 100 },
1092 Range {
1093 start: 200,
1094 end: 300
1095 }
1096 ])
1097 );
1098 assert_eq!(
1100 parse_range_header(b"bytEs=0-99,200-299,400-499", 500)
1101 .get_multirange_info()
1102 .expect("Should have multipart info for Multipart range request")
1103 .ranges,
1104 vec![
1105 Range { start: 0, end: 100 },
1106 Range {
1107 start: 200,
1108 end: 300
1109 },
1110 Range {
1111 start: 400,
1112 end: 500
1113 },
1114 ]
1115 );
1116 assert_eq!(parse_range_header(b"bytes=0-,-2", 10), RangeType::None,);
1118 assert!(parse_range_header(b"bytes=0-,-2", 10)
1120 .get_multirange_info()
1121 .is_none());
1122 assert_eq!(parse_range_header(b"bytes=0-3,2-5", 10), RangeType::None,);
1124 assert!(parse_range_header(b"bytes=0-3,2-5", 10)
1125 .get_multirange_info()
1126 .is_none());
1127
1128 assert_eq!(
1130 parse_range_header(b"bytes=0-5,10-", 2),
1131 RangeType::new_single(0, 2)
1132 );
1133 assert!(parse_range_header(b"bytes=0-5,10-", 2)
1134 .get_multirange_info()
1135 .is_none());
1136
1137 assert_eq!(
1139 parse_range_header(b"bytes=0-5, 10-20, 30-18", 200)
1140 .get_multirange_info()
1141 .expect("Should have multipart info for Multipart range request")
1142 .ranges,
1143 vec![Range { start: 0, end: 6 }, Range { start: 10, end: 21 },]
1144 );
1145 assert_eq!(
1147 parse_range_header(b"bytes=5-0, 20-15, 30-25", 200),
1148 RangeType::Invalid
1149 );
1150
1151 fn generate_range_header(count: usize) -> Vec<u8> {
1153 let mut s = String::from("bytes=");
1154 for i in 0..count {
1155 let start = i * 4;
1156 let end = start + 1;
1157 if i > 0 {
1158 s.push(',');
1159 }
1160 s.push_str(&start.to_string());
1161 s.push('-');
1162 s.push_str(&end.to_string());
1163 }
1164 s.into_bytes()
1165 }
1166
1167 let ranges = generate_range_header(101);
1169 assert_eq!(parse_range_header(&ranges, 1000), RangeType::None)
1170 }
1171
    /// Ranges plus framing data needed to emit a `multipart/byteranges` response.
    #[derive(Debug, Eq, PartialEq, Clone)]
    pub struct MultiRangeInfo {
        /// Requested ranges, end-exclusive (ascending and non-overlapping when produced by
        /// `parse_range_header`).
        pub ranges: Vec<Range<usize>>,
        /// Multipart boundary string, randomly generated in `MultiRangeInfo::new`.
        pub boundary: String,
        /// Length of the full representation, reported in each part's `Content-Range`.
        total_length: usize,
        /// Content type of the full representation, repeated in each part when set.
        content_type: Option<String>,
    }
1181
1182 impl MultiRangeInfo {
1183 pub fn new(ranges: Vec<Range<usize>>) -> Self {
1185 Self {
1186 ranges,
1187 boundary: Self::generate_boundary(),
1189 total_length: 0,
1190 content_type: None,
1191 }
1192 }
1193 pub fn set_content_type(&mut self, content_type: String) {
1194 self.content_type = Some(content_type)
1195 }
1196 pub fn set_total_length(&mut self, total_length: usize) {
1197 self.total_length = total_length;
1198 }
1199 fn generate_boundary() -> String {
1204 use rand::Rng;
1205 let mut rng: rand::prelude::ThreadRng = rand::thread_rng();
1206 format!("{:016x}", rng.gen::<u64>())
1207 }
1208 fn calculate_multipart_length(&self) -> usize {
1209 let mut total_length = 0;
1210 let content_type = self.content_type.as_ref();
1211 for range in self.ranges.clone() {
1212 total_length += 4 + self.boundary.len() + 2;
1219 total_length += content_type.map_or(0, |ct| 14 + ct.len() + 2);
1220 total_length += format!(
1221 "Content-Range: bytes {}-{}/{}",
1222 range.start,
1223 range.end - 1,
1224 self.total_length
1225 )
1226 .len()
1227 + 2;
1228 total_length += 2;
1229 total_length += range.end - range.start;
1230 }
1231 total_length += 4 + self.boundary.len() + 4;
1233 total_length
1234 }
1235 }
    /// Outcome of parsing a `Range` request header.
    #[derive(Debug, Eq, PartialEq, Clone)]
    pub enum RangeType {
        /// No range handling: serve the full representation.
        None,
        /// One satisfiable range, end-exclusive; served as a single 206.
        Single(Range<usize>),
        /// Several satisfiable ranges, served as `multipart/byteranges`.
        Multi(MultiRangeInfo),
        /// Well-formed header with nothing satisfiable; answered with a 416.
        Invalid,
    }
1243
1244 impl RangeType {
1245 #[allow(dead_code)]
1247 fn new_single(start: usize, end: usize) -> Self {
1248 RangeType::Single(Range { start, end })
1249 }
1250 #[allow(dead_code)]
1251 pub fn new_multi(ranges: Vec<Range<usize>>) -> Self {
1252 RangeType::Multi(MultiRangeInfo::new(ranges))
1253 }
1254 #[allow(dead_code)]
1255 fn get_multirange_info(&self) -> Option<&MultiRangeInfo> {
1256 match self {
1257 RangeType::Multi(multi_range_info) => Some(multi_range_info),
1258 _ => None,
1259 }
1260 }
1261 #[allow(dead_code)]
1262 fn update_multirange_info(&mut self, content_length: usize, content_type: Option<String>) {
1263 if let RangeType::Multi(multipart_range_info) = self {
1264 multipart_range_info.content_type = content_type;
1265 multipart_range_info.set_total_length(content_length);
1266 }
1267 }
1268 }
1269
1270 pub fn range_header_filter(req: &RequestHeader, resp: &mut ResponseHeader) -> RangeType {
1272 if resp.status != StatusCode::OK {
1276 return RangeType::None;
1277 }
1278
1279 if req.method != http::Method::GET && req.method != http::Method::HEAD {
1281 return RangeType::None;
1282 }
1283
1284 let Some(range_header) = req.headers.get(RANGE) else {
1285 return RangeType::None;
1286 };
1287
1288 let Some(content_length_bytes) = resp.headers.get(CONTENT_LENGTH) else {
1291 return RangeType::None;
1292 };
1293 let Some(content_length) = parse_number(content_length_bytes.as_bytes()) else {
1295 return RangeType::None;
1296 };
1297
1298 if let Some(if_range) = req.headers.get(IF_RANGE) {
1306 let ir = if_range.as_bytes();
1307 let matches = if ir.len() >= 2 && ir.last() == Some(&b'"') {
1308 resp.headers.get(ETAG).is_some_and(|etag| etag == if_range)
1309 } else if let Some(last_modified) = resp.headers.get(LAST_MODIFIED) {
1310 last_modified == if_range
1311 } else {
1312 false
1313 };
1314 if !matches {
1315 return RangeType::None;
1316 }
1317 }
1318
1319 let mut range_type = parse_range_header(range_header.as_bytes(), content_length);
1323
1324 match &mut range_type {
1325 RangeType::None => { }
1326 RangeType::Single(r) => {
1327 resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap();
1329 resp.insert_header(&CONTENT_LENGTH, r.end - r.start)
1330 .unwrap();
1331 resp.insert_header(
1332 &CONTENT_RANGE,
1333 format!("bytes {}-{}/{content_length}", r.start, r.end - 1), )
1335 .unwrap()
1336 }
1337
1338 RangeType::Multi(multi_range_info) => {
1339 let content_type = resp
1340 .headers
1341 .get(CONTENT_TYPE)
1342 .and_then(|v| v.to_str().ok())
1343 .unwrap_or("application/octet-stream");
1344 multi_range_info.set_total_length(content_length);
1346 multi_range_info.set_content_type(content_type.to_string());
1347
1348 let total_length = multi_range_info.calculate_multipart_length();
1349
1350 resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap();
1351 resp.insert_header(CONTENT_LENGTH, total_length).unwrap();
1352 resp.insert_header(
1353 CONTENT_TYPE,
1354 format!(
1355 "multipart/byteranges; boundary={}",
1356 multi_range_info.boundary
1357 ), )
1359 .unwrap();
1360 resp.remove_header(&CONTENT_RANGE);
1361 }
1362 RangeType::Invalid => {
1363 resp.set_status(StatusCode::RANGE_NOT_SATISFIABLE).unwrap();
1365 resp.insert_header(&CONTENT_LENGTH, HeaderValue::from_static("0"))
1367 .unwrap();
1368 resp.remove_header(&CONTENT_TYPE);
1370 resp.insert_header(&CONTENT_RANGE, format!("bytes */{content_length}"))
1371 .unwrap()
1372 }
1373 }
1374
1375 range_type
1376 }
1377
1378 #[test]
1379 fn test_range_filter_single() {
1380 fn gen_req() -> RequestHeader {
1381 RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap()
1382 }
1383 fn gen_resp() -> ResponseHeader {
1384 let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1385 resp.append_header("Content-Length", "10").unwrap();
1386 resp
1387 }
1388
1389 let req = gen_req();
1391 let mut resp = gen_resp();
1392 assert_eq!(RangeType::None, range_header_filter(&req, &mut resp));
1393 assert_eq!(resp.status.as_u16(), 200);
1394
1395 let mut req = gen_req();
1397 req.insert_header("Range", "bytes=0-1").unwrap();
1398 let mut resp = gen_resp();
1399 assert_eq!(
1400 RangeType::new_single(0, 2),
1401 range_header_filter(&req, &mut resp)
1402 );
1403 assert_eq!(resp.status.as_u16(), 206);
1404 assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"2");
1405 assert_eq!(
1406 resp.headers.get("content-range").unwrap().as_bytes(),
1407 b"bytes 0-1/10"
1408 );
1409
1410 let mut req = gen_req();
1412 req.insert_header("Range", "bytes=1-0").unwrap();
1413 let mut resp = gen_resp();
1414 assert_eq!(RangeType::Invalid, range_header_filter(&req, &mut resp));
1415 assert_eq!(resp.status.as_u16(), 416);
1416 assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"0");
1417 assert_eq!(
1418 resp.headers.get("content-range").unwrap().as_bytes(),
1419 b"bytes */10"
1420 );
1421 }
1422
#[test]
fn test_range_filter_multipart() {
    // Request with three disjoint, ascending ranges.
    fn gen_req() -> RequestHeader {
        let mut req: RequestHeader =
            RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
        req.append_header("Range", "bytes=0-1,3-4,6-7").unwrap();
        req
    }
    // Request whose first two ranges overlap (0-3 and 2-5).
    fn gen_req_overlap_range() -> RequestHeader {
        let mut req: RequestHeader =
            RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
        req.append_header("Range", "bytes=0-3,2-5,7-8").unwrap();
        req
    }
    // 200 response advertising a 10-byte body (no Content-Type header).
    fn gen_resp() -> ResponseHeader {
        let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
        resp.append_header("Content-Length", "10").unwrap();
        resp
    }

    // Disjoint ranges yield a Multi selection; inclusive "a-b" becomes half-open a..b+1.
    let req = gen_req();
    let mut resp = gen_resp();
    let result = range_header_filter(&req, &mut resp);
    let mut boundary_str = String::new();

    assert!(matches!(result, RangeType::Multi(_)));
    if let RangeType::Multi(multi_part_info) = result {
        assert_eq!(multi_part_info.ranges.len(), 3);
        assert_eq!(multi_part_info.ranges[0], Range { start: 0, end: 2 });
        assert_eq!(multi_part_info.ranges[1], Range { start: 3, end: 5 });
        assert_eq!(multi_part_info.ranges[2], Range { start: 6, end: 8 });
        // NOTE(review): gen_resp() sets no Content-Type, so this presumably relies on
        // range_header_filter supplying a default content type — confirm.
        assert!(multi_part_info.content_type.is_some());
        assert_eq!(multi_part_info.total_length, 10);
        assert!(!multi_part_info.boundary.is_empty());
        // Keep the generated boundary to check the response Content-Type below.
        boundary_str = multi_part_info.boundary;
    }
    // The response becomes a 206 whose Content-Type carries the generated boundary.
    assert_eq!(resp.status.as_u16(), 206);
    assert_eq!(
        resp.headers.get("content-type").unwrap().to_str().unwrap(),
        format!("multipart/byteranges; boundary={boundary_str}")
    );
    // NOTE(review): the header name below uses an underscore ("content_length"), but
    // gen_resp() sets "Content-Length", so this lookup can never match and the
    // assertion always passes. Confirm the intended Content-Length expectation for
    // multipart responses (absent vs. a computed multipart length) against
    // range_header_filter, then fix the header name.
    assert!(resp.headers.get("content_length").is_none());

    // Overlapping ranges are not honored: the response stays a plain 200.
    let req = gen_req_overlap_range();
    let mut resp = gen_resp();
    let result = range_header_filter(&req, &mut resp);

    assert!(matches!(result, RangeType::None));
    assert_eq!(resp.status.as_u16(), 200);
    assert!(resp.headers.get("content-type").is_none());

    // Every range has start > end, so the whole header is invalid: 416.
    let mut req = gen_req();
    req.insert_header("Range", "bytes=1-0, 12-9, 50-40")
        .unwrap();
    let mut resp = gen_resp();
    let result = range_header_filter(&req, &mut resp);
    assert!(matches!(result, RangeType::Invalid));
    assert_eq!(resp.status.as_u16(), 416);
}
1488
#[test]
fn test_if_range() {
    const DATE: &str = "Fri, 07 Jul 2023 22:03:29 GMT";
    const ETAG: &str = "\"1234\"";

    /// Single-range GET request for the first two bytes.
    fn gen_req() -> RequestHeader {
        let mut req = RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
        req.append_header("Range", "bytes=0-1").unwrap();
        req
    }
    /// Multi-range GET request with three disjoint ranges.
    fn gen_multipart_req() -> RequestHeader {
        let mut req = RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
        // Fail loudly on setup errors instead of silently dropping the Range header.
        req.append_header("Range", "bytes=0-1,3-4,6-7").unwrap();
        req
    }
    /// 200 response whose validators are `DATE` (Last-Modified) and `ETAG`.
    fn gen_resp() -> ResponseHeader {
        let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
        resp.append_header("Content-Length", "10").unwrap();
        resp.append_header("Last-Modified", DATE).unwrap();
        resp.append_header("ETag", ETAG).unwrap();
        resp
    }

    // If-Range date equal to Last-Modified: the range is honored.
    let mut req = gen_req();
    req.insert_header("If-Range", DATE).unwrap();
    let mut resp = gen_resp();
    assert_eq!(
        RangeType::new_single(0, 2),
        range_header_filter(&req, &mut resp)
    );

    // If-Range date differs from Last-Modified: the range is ignored.
    let mut req = gen_req();
    req.insert_header("If-Range", "Fri, 07 Jul 2023 22:03:25 GMT")
        .unwrap();
    let mut resp = gen_resp();
    assert_eq!(RangeType::None, range_header_filter(&req, &mut resp));

    // If-Range ETag equal to the response ETag: the range is honored.
    let mut req = gen_req();
    req.insert_header("If-Range", ETAG).unwrap();
    let mut resp = gen_resp();
    assert_eq!(
        RangeType::new_single(0, 2),
        range_header_filter(&req, &mut resp)
    );

    // Different ETag: the range is ignored.
    let mut req = gen_req();
    req.insert_header("If-Range", "\"4567\"").unwrap();
    let mut resp = gen_resp();
    assert_eq!(RangeType::None, range_header_filter(&req, &mut resp));

    // Unquoted value does not match the quoted ETag: the range is ignored.
    let mut req = gen_req();
    req.insert_header("If-Range", "1234").unwrap();
    let mut resp = gen_resp();
    assert_eq!(RangeType::None, range_header_filter(&req, &mut resp));

    // Matching If-Range date also applies to multi-range requests.
    let mut req = gen_multipart_req();
    req.insert_header("If-Range", DATE).unwrap();
    let mut resp = gen_resp();
    let result = range_header_filter(&req, &mut resp);
    assert!(matches!(result, RangeType::Multi(_)));
    assert_eq!(resp.status.as_u16(), 206);

    // Without If-Range the multi-range request is honored unconditionally.
    let req = gen_multipart_req();
    let mut resp = gen_resp();
    assert!(matches!(
        range_header_filter(&req, &mut resp),
        RangeType::Multi(_)
    ));

    // Mismatched If-Range on a multi-range request: full 200 response.
    let mut req = gen_multipart_req();
    req.insert_header("If-Range", "\"wrong\"").unwrap();
    let mut resp = gen_resp();
    assert_eq!(RangeType::None, range_header_filter(&req, &mut resp));
    assert_eq!(resp.status.as_u16(), 200);
}
1571
1572 pub struct RangeBodyFilter {
1573 pub range: RangeType,
1574 current: usize,
1575 multipart_idx: Option<usize>,
1576 }
1577
1578 impl Default for RangeBodyFilter {
1579 fn default() -> Self {
1580 Self::new()
1581 }
1582 }
1583
1584 impl RangeBodyFilter {
1585 pub fn new() -> Self {
1586 RangeBodyFilter {
1587 range: RangeType::None,
1588 current: 0,
1589 multipart_idx: None,
1590 }
1591 }
1592
1593 pub fn set(&mut self, range: RangeType) {
1594 self.range = range.clone();
1595 if let RangeType::Multi(_) = self.range {
1596 self.multipart_idx = Some(0);
1597 }
1598 }
1599
1600 pub fn finalize(&self, boundary: &String) -> Option<Bytes> {
1602 if let RangeType::Multi(_) = self.range {
1603 Some(Bytes::from(format!("\r\n--{boundary}--\r\n")))
1604 } else {
1605 None
1606 }
1607 }
1608
1609 pub fn filter_body(&mut self, data: Option<Bytes>) -> Option<Bytes> {
1610 match &self.range {
1611 RangeType::None => data,
1612 RangeType::Invalid => None,
1613 RangeType::Single(r) => {
1614 let current = self.current;
1615 self.current += data.as_ref().map_or(0, |d| d.len());
1616 data.and_then(|d| Self::filter_range_data(r.start, r.end, current, d))
1617 }
1618
1619 RangeType::Multi(_) => {
1620 let data = data?;
1621 let current = self.current;
1622 let data_len = data.len();
1623 self.current += data_len;
1624 self.filter_multi_range_body(data, current, data_len)
1625 }
1626 }
1627 }
1628
1629 fn filter_range_data(
1630 start: usize,
1631 end: usize,
1632 current: usize,
1633 data: Bytes,
1634 ) -> Option<Bytes> {
1635 if current + data.len() < start || current >= end {
1636 None
1638 } else if current >= start && current + data.len() <= end {
1639 Some(data)
1641 } else {
1642 let slice_start = start.saturating_sub(current);
1645 let slice_end = std::cmp::min(data.len(), end - current);
1646 Some(data.slice(slice_start..slice_end))
1647 }
1648 }
1649
1650 fn build_multipart_header(
1652 &self,
1653 range: &Range<usize>,
1654 boundary: &str,
1655 total_length: &usize,
1656 content_type: Option<&str>,
1657 ) -> Bytes {
1658 Bytes::from(format!(
1659 "\r\n--{}\r\n{}Content-Range: bytes {}-{}/{}\r\n\r\n",
1660 boundary,
1661 content_type.map_or(String::new(), |ct| format!("Content-Type: {ct}\r\n")),
1662 range.start,
1663 range.end - 1,
1664 total_length
1665 ))
1666 }
1667
1668 fn current_chunk_includes_range_start(
1670 &self,
1671 range: &Range<usize>,
1672 current: usize,
1673 data_len: usize,
1674 ) -> bool {
1675 range.start >= current && range.start < current + data_len
1676 }
1677
1678 fn current_chunk_includes_range_end(
1680 &self,
1681 range: &Range<usize>,
1682 current: usize,
1683 data_len: usize,
1684 ) -> bool {
1685 range.end > current && range.end <= current + data_len
1686 }
1687
1688 fn filter_multi_range_body(
1689 &mut self,
1690 data: Bytes,
1691 current: usize,
1692 data_len: usize,
1693 ) -> Option<Bytes> {
1694 let mut result = BytesMut::new();
1695
1696 let RangeType::Multi(multi_part_info) = &self.range else {
1697 return None;
1698 };
1699
1700 let multipart_idx = self.multipart_idx.expect("must be set on multirange");
1701 let final_range = multi_part_info.ranges.last()?;
1702
1703 let (_, remaining_ranges) = multi_part_info.ranges.as_slice().split_at(multipart_idx);
1704 for range in remaining_ranges {
1707 if let Some(sliced) =
1708 Self::filter_range_data(range.start, range.end, current, data.clone())
1709 {
1710 if self.current_chunk_includes_range_start(range, current, data_len) {
1711 result.extend_from_slice(&self.build_multipart_header(
1712 range,
1713 multi_part_info.boundary.as_ref(),
1714 &multi_part_info.total_length,
1715 multi_part_info.content_type.as_deref(),
1716 ));
1717 }
1718 result.extend_from_slice(&sliced);
1720 if self.current_chunk_includes_range_end(range, current, data_len) {
1721 if range == final_range {
1723 if let Some(final_chunk) = self.finalize(&multi_part_info.boundary) {
1724 result.extend_from_slice(&final_chunk);
1725 }
1726 }
1727 self.multipart_idx = Some(self.multipart_idx.expect("must be set") + 1);
1729 }
1730 } else {
1731 break;
1735 }
1736 }
1737 if result.is_empty() {
1738 None
1739 } else {
1740 Some(result.freeze())
1741 }
1742 }
1743 }
1744
#[test]
fn test_range_body_filter_single() {
    // No range configured: the chunk is passed through untouched.
    let mut passthrough = RangeBodyFilter::new();
    assert_eq!(passthrough.filter_body(Some("123".into())).unwrap(), "123");

    // An invalid range suppresses the chunk entirely.
    let mut invalid = RangeBodyFilter::new();
    invalid.set(RangeType::Invalid);
    assert!(invalid.filter_body(Some("123".into())).is_none());

    // Each case: the half-open range to select, then the chunks fed in order, each
    // paired with the expected output (`None` means the chunk is dropped).
    let cases: [(usize, usize, &[(&'static str, Option<&'static str>)]); 3] = [
        (0, 1, &[("012", Some("0")), ("345", None)]),
        (4, 6, &[("012", None), ("345", Some("45")), ("678", None)]),
        (
            1,
            7,
            &[("012", Some("12")), ("345", Some("345")), ("678", Some("6"))],
        ),
    ];

    for (start, end, steps) in cases {
        let mut filter = RangeBodyFilter::new();
        filter.set(RangeType::new_single(start, end));
        for &(chunk, expected) in steps {
            let output = filter.filter_body(Some(chunk.into()));
            match expected {
                Some(want) => assert_eq!(output.unwrap(), want),
                None => assert!(output.is_none()),
            }
        }
    }
}
1771
#[test]
fn test_range_body_filter_multipart() {
    // Case 1: the whole body arrives in one chunk and no Content-Type is configured,
    // so each part header carries only Content-Range.
    let data = Bytes::from("0123456789");
    let ranges = vec![0..3, 6..9];
    let content_length = data.len();
    let mut body_filter = RangeBodyFilter::new();
    body_filter.set(RangeType::new_multi(ranges.clone()));

    body_filter
        .range
        .update_multirange_info(content_length, None);

    // Snapshot the multipart info (boundary etc.) to build the expected output.
    let multi_range_info = body_filter
        .range
        .get_multirange_info()
        .cloned()
        .expect("Multipart Ranges should have MultiPartInfo struct");

    let output = body_filter.filter_body(Some(data)).unwrap();
    // finalize() is checked on its own here; it is not appended to `output`.
    let footer = body_filter.finalize(&multi_range_info.boundary).unwrap();

    let output_str = str::from_utf8(&output).unwrap();
    let final_boundary = str::from_utf8(&footer).unwrap();
    let boundary = &multi_range_info.boundary;

    // Every range must contribute its part header and its bytes. The inclusive
    // Content-Range end is `range.end - 1` because `ranges` are half-open.
    for (i, range) in ranges.iter().enumerate() {
        let header = &format!(
            "--{}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
            boundary,
            range.start,
            range.end - 1,
            content_length
        );
        assert!(
            output_str.contains(header),
            "Missing part header {} in multipart body",
            i
        );
        let expected_body = &"0123456789"[range.clone()];
        assert!(
            output_str.contains(expected_body),
            "Missing body {} for range {:?}",
            expected_body,
            range
        )
    }
    assert_eq!(final_boundary, format!("\r\n--{}--\r\n", boundary));

    // Case 2: three ranges fed in 3-byte chunks (last chunk 1 byte), with a
    // Content-Type that must appear in every part header.
    let full_body = b"0123456789";
    let ranges = vec![0..2, 4..6, 8..9];
    let content_length = full_body.len();
    let content_type = "text/plain".to_string();
    let mut body_filter = RangeBodyFilter::new();
    body_filter.set(RangeType::new_multi(ranges.clone()));

    body_filter
        .range
        .update_multirange_info(content_length, Some(content_type.clone()));

    let multi_range_info = body_filter
        .range
        .get_multirange_info()
        .cloned()
        .expect("Multipart Ranges should have MultiPartInfo struct");

    let chunk1 = Bytes::from_static(b"012");
    let chunk2 = Bytes::from_static(b"345");
    let chunk3 = Bytes::from_static(b"678");
    let chunk4 = Bytes::from_static(b"9");

    let mut collected_bytes = BytesMut::new();
    for chunk in [chunk1, chunk2, chunk3, chunk4] {
        if let Some(filtered) = body_filter.filter_body(Some(chunk)) {
            collected_bytes.extend_from_slice(&filtered);
        }
    }
    // NOTE(review): filter_body already emits the closing delimiter once the final
    // range ends (see filter_multi_range_body), so this appends a second copy; the
    // `ends_with` check below cannot detect the duplicate. Decide whether callers
    // are meant to call finalize() and tighten the assertion accordingly.
    if let Some(final_boundary) = body_filter.finalize(&multi_range_info.boundary) {
        collected_bytes.extend_from_slice(&final_boundary);
    }

    let output_str = str::from_utf8(&collected_bytes).unwrap();
    let boundary = multi_range_info.boundary;

    // Each part header must be immediately followed by that range's bytes.
    for (i, range) in ranges.iter().enumerate() {
        let header = &format!(
            "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
            boundary,
            content_type,
            range.start,
            range.end - 1,
            content_length
        );
        let expected_body = &full_body[range.clone()];
        let expected_output = format!("{}{}", header, str::from_utf8(expected_body).unwrap());

        assert!(
            output_str.contains(&expected_output),
            "Missing or malformed part {} in multipart body. \n Expected: \n{}\n Got: \n{}",
            i,
            expected_output,
            output_str
        )
    }

    assert!(
        output_str.ends_with(&format!("\r\n--{}--\r\n", boundary)),
        "Missing final boundary"
    );

    // Case 3: one range (2..7) spans three chunks and another (9..11) sits inside
    // the last chunk.
    let full_body = b"abcdefghijkl";
    let ranges = vec![2..7, 9..11];
    let content_length = full_body.len();
    let content_type = "application/octet-stream".to_string();
    let mut body_filter = RangeBodyFilter::new();
    body_filter.set(RangeType::new_multi(ranges.clone()));

    body_filter
        .range
        .update_multirange_info(content_length, Some(content_type.clone()));

    let multi_range_info = body_filter
        .range
        .clone()
        .get_multirange_info()
        .cloned()
        .expect("Multipart Ranges should have MultiPartInfo struct");

    let chunk1 = Bytes::from_static(b"abc");
    let chunk2 = Bytes::from_static(b"def");
    let chunk3 = Bytes::from_static(b"ghi");
    let chunk4 = Bytes::from_static(b"jkl");

    let mut collected_bytes = BytesMut::new();
    for chunk in [chunk1, chunk2, chunk3, chunk4] {
        if let Some(filtered) = body_filter.filter_body(Some(chunk)) {
            collected_bytes.extend_from_slice(&filtered);
        }
    }
    // NOTE(review): as in case 2, this appends a second closing delimiter after the
    // one filter_body already produced.
    if let Some(final_boundary) = body_filter.finalize(&multi_range_info.boundary) {
        collected_bytes.extend_from_slice(&final_boundary);
    }

    let output_str = str::from_utf8(&collected_bytes).unwrap();
    let boundary = &multi_range_info.boundary;

    let header1 = &format!(
        "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
        boundary,
        content_type,
        ranges[0].start,
        ranges[0].end - 1,
        content_length
    );
    let header2 = &format!(
        "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
        boundary,
        content_type,
        ranges[1].start,
        ranges[1].end - 1,
        content_length
    );

    assert!(output_str.contains(header1));
    assert!(output_str.contains(header2));

    // Bytes 2..7 are "cdefg" (spread across chunks) and 9..11 are "jk".
    let expected_body_slices = ["cdefg", "jk"];

    assert!(
        output_str.contains(expected_body_slices[0]),
        "Missing expected sliced body {}",
        expected_body_slices[0]
    );

    assert!(
        output_str.contains(expected_body_slices[1]),
        "Missing expected sliced body {}",
        expected_body_slices[1]
    );

    assert!(
        output_str.ends_with(&format!("\r\n--{}--\r\n", boundary)),
        "Missing final boundary"
    );
}
1966}
1967
/// State machine for producing a response out of the cache.
///
/// `next_http_task` emits one piece of the response per call and advances the
/// state: a header state, then a body state, then `Done`.
#[derive(Debug)]
pub(crate) enum ServeFromCache {
    /// Not serving from cache.
    Off,
    /// Next task is the cached header; the hit-handler body follows.
    CacheHeader,
    /// Next task is the cached header, marked as the end of the response.
    CacheHeaderOnly,
    /// Reading the body from the cache hit handler. The flag is `true` until the
    /// first read, when a single requested range may be seeked to.
    CacheBody(bool),
    /// Next task is the cached header; the body follows via the miss body reader.
    CacheHeaderMiss,
    /// Reading the body from the cache's miss body reader. The flag is `true`
    /// until the first read, when a single requested range may be seeked to.
    CacheBodyMiss(bool),
    /// Everything has been produced; only `HttpTask::Done` is returned from here on.
    Done,
}
1980
1981impl ServeFromCache {
1982 pub fn new() -> Self {
1983 Self::Off
1984 }
1985
1986 pub fn is_on(&self) -> bool {
1987 !matches!(self, Self::Off)
1988 }
1989
1990 pub fn is_miss(&self) -> bool {
1991 matches!(self, Self::CacheHeaderMiss | Self::CacheBodyMiss(_))
1992 }
1993
1994 pub fn is_miss_header(&self) -> bool {
1995 matches!(self, Self::CacheHeaderMiss)
1996 }
1997
1998 pub fn is_miss_body(&self) -> bool {
1999 matches!(self, Self::CacheBodyMiss(_))
2000 }
2001
2002 pub fn should_discard_upstream(&self) -> bool {
2003 self.is_on() && !self.is_miss()
2004 }
2005
2006 pub fn should_send_to_downstream(&self) -> bool {
2007 !self.is_on()
2008 }
2009
2010 pub fn enable(&mut self) {
2011 *self = Self::CacheHeader;
2012 }
2013
2014 pub fn enable_miss(&mut self) {
2015 if !self.is_on() {
2016 *self = Self::CacheHeaderMiss;
2017 }
2018 }
2019
2020 pub fn enable_header_only(&mut self) {
2021 match self {
2022 Self::CacheBody(_) | Self::CacheBodyMiss(_) => *self = Self::Done, _ => *self = Self::CacheHeaderOnly,
2024 }
2025 }
2026
2027 pub async fn next_http_task(
2029 &mut self,
2030 cache: &mut HttpCache,
2031 range: &mut RangeBodyFilter,
2032 ) -> Result<HttpTask> {
2033 if !cache.enabled() {
2034 return Error::e_explain(InternalError, "Cache disabled");
2038 }
2039 match self {
2040 Self::Off => panic!("ProxyUseCache not enabled"),
2041 Self::CacheHeader => {
2042 *self = Self::CacheBody(true);
2043 Ok(HttpTask::Header(cache_hit_header(cache), false)) }
2045 Self::CacheHeaderMiss => {
2046 *self = Self::CacheBodyMiss(true);
2047 Ok(HttpTask::Header(cache_hit_header(cache), false)) }
2049 Self::CacheHeaderOnly => {
2050 *self = Self::Done;
2051 Ok(HttpTask::Header(cache_hit_header(cache), true))
2052 }
2053 Self::CacheBody(should_seek) => {
2054 if *should_seek {
2055 self.maybe_seek_hit_handler(cache, range)?;
2056 }
2057 if let Some(b) = cache.hit_handler().read_body().await? {
2058 Ok(HttpTask::Body(Some(b), false)) } else {
2060 *self = Self::Done;
2061 Ok(HttpTask::Done)
2062 }
2063 }
2064 Self::CacheBodyMiss(should_seek) => {
2065 if *should_seek {
2066 self.maybe_seek_miss_handler(cache, range)?;
2067 }
2068 if let Some(b) = cache.miss_body_reader().unwrap().read_body().await? {
2070 Ok(HttpTask::Body(Some(b), false)) } else {
2072 *self = Self::Done;
2073 Ok(HttpTask::Done)
2074 }
2075 }
2076 Self::Done => Ok(HttpTask::Done),
2077 }
2078 }
2079
2080 fn maybe_seek_miss_handler(
2081 &mut self,
2082 cache: &mut HttpCache,
2083 range_filter: &mut RangeBodyFilter,
2084 ) -> Result<()> {
2085 if let RangeType::Single(range) = &range_filter.range {
2086 if cache.miss_body_reader().unwrap().can_seek() {
2088 cache
2089 .miss_body_reader()
2090 .unwrap()
2092 .seek(range.start, Some(range.end))
2093 .or_err(InternalError, "cannot seek miss handler")?;
2094 range_filter.range = RangeType::None;
2097 }
2098 }
2099 *self = Self::CacheBodyMiss(false);
2100 Ok(())
2101 }
2102
2103 fn maybe_seek_hit_handler(
2104 &mut self,
2105 cache: &mut HttpCache,
2106 range_filter: &mut RangeBodyFilter,
2107 ) -> Result<()> {
2108 match &range_filter.range {
2109 RangeType::Single(range) => {
2110 if cache.hit_handler().can_seek() {
2111 cache
2112 .hit_handler()
2113 .seek(range.start, Some(range.end))
2114 .or_err(InternalError, "cannot seek hit handler")?;
2115 range_filter.range = RangeType::None;
2118 }
2119 }
2120 RangeType::Multi(_) => {
2121 }
2125 _ => {}
2126 }
2127 *self = Self::CacheBody(false);
2128 Ok(())
2129 }
2130}