pingora_proxy/proxy_cache.rs

// Copyright 2026 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::*;
use http::header::{CONTENT_LENGTH, CONTENT_TYPE};
use http::{Method, StatusCode};
use pingora_cache::key::CacheHashKey;
use pingora_cache::lock::LockStatus;
use pingora_cache::max_file_size::ERR_RESPONSE_TOO_LARGE;
use pingora_cache::{ForcedFreshness, HitHandler, HitStatus, RespCacheable::*};
use pingora_core::protocols::http::conditional_filter::to_304;
use pingora_core::protocols::http::v1::common::header_value_content_length;
use pingora_core::ErrorType;
use range_filter::RangeBodyFilter;
use std::time::SystemTime;

impl<SV, C> HttpProxy<SV, C>
where
    C: custom::Connector,
{
    // return bool: whether the server session can be reused, and an error if any
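    // Overview of the flow below: run request_cache_filter and the cache key callback,
    // short-circuit PURGE requests, then loop over cache lookups. A lookup is retried
    // after a vary-key update or after waiting on a cache lock; hits (fresh, or stale
    // while updating) are served via proxy_cache_hit, while misses fall through to
    // proxy to upstream.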
    pub(crate) async fn proxy_cache(
        self: &Arc<Self>,
        session: &mut Session,
        ctx: &mut SV::CTX,
    ) -> Option<(bool, Option<Box<Error>>)>
    // None: continue to proxy, Some: return
    where
        SV: ProxyHttp + Send + Sync + 'static,
        SV::CTX: Send + Sync,
    {
        // Cache logic request phase
        if let Err(e) = self.inner.request_cache_filter(session, ctx) {
            // TODO: handle this error
            warn!(
                "Fail to request_cache_filter: {e}, {}",
                self.inner.request_summary(session, ctx)
            );
        }

        // cache key logic, should this be part of request_cache_filter?
        if session.cache.enabled() {
            match self.inner.cache_key_callback(session, ctx) {
                Ok(key) => {
                    session.cache.set_cache_key(key);
                }
                Err(e) => {
                    // TODO: handle this error
                    session.cache.disable(NoCacheReason::StorageError);
                    warn!(
                        "Fail to cache_key_callback: {e}, {}",
                        self.inner.request_summary(session, ctx)
                    );
                }
            }
        }

        // cache purge logic: PURGE short-circuits rest of request
        if self.inner.is_purge(session, ctx) {
            return self.proxy_purge(session, ctx).await;
        }

        // bypass cache lookup if we predict to be uncacheable
        if session.cache.enabled() && !session.cache.cacheable_prediction() {
            session.cache.bypass();
        }

        if !session.cache.enabled() {
            return None;
        }

        // cache lookup logic
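        // Note on the loop below: `continue` re-runs the lookup (after updating the vary
        // cache key or after a cache lock is released), `break None` falls through to
        // fetching from upstream, and `break Some(..)` means the response was served
        // from cache.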
        loop {
            // for cache lock, TODO: cap the max number of loops
            match session.cache.cache_lookup().await {
                Ok(res) => {
                    let mut hit_status_opt = None;
                    if let Some((mut meta, mut handler)) = res {
                        // Vary logic
                        // Because this branch can be called multiple times in a loop, and we only
                        // need to update the vary once, check if variance is already set to
                        // prevent unnecessary vary lookups.
                        let cache_key = session.cache.cache_key();
                        if let Some(variance) = cache_key.variance_bin() {
                            // We've looked up a secondary slot.
                            // Adhoc double check that the variance found is the variance we want.
                            if Some(variance) != meta.variance() {
                                warn!("Cache variance mismatch, {variance:?}, {cache_key:?}");
                                session.cache.disable(NoCacheReason::InternalError);
                                break None;
                            }
                        } else {
                            // Basic cache key; either variance is off, or this is the primary slot.
                            let req_header = session.req_header();
                            let variance = self.inner.cache_vary_filter(&meta, ctx, req_header);
                            if let Some(variance) = variance {
                                // Variance is on. This is the primary slot.
                                if !session.cache.cache_vary_lookup(variance, &meta) {
                                    // This wasn't the desired variant. The cache key variance has been
                                    // updated; run another lookup to get the desired variant, which
                                    // would be in a secondary slot.
                                    continue;
                                }
                            } // else: vary is not in use
                        }

                        // Either no variance, or the current handler targets the correct variant.

                        // hit
                        // TODO: maybe round and/or cache now()
                        let is_fresh = meta.is_fresh(SystemTime::now());
                        // check if we should force expire or force miss
                        let hit_status = match self
                            .inner
                            .cache_hit_filter(session, &meta, &mut handler, is_fresh, ctx)
                            .await
                        {
                            Err(e) => {
                                error!(
                                    "Failed to filter cache hit: {e}, {}",
                                    self.inner.request_summary(session, ctx)
                                );
                                // this return value will cause us to fetch from upstream
                                HitStatus::FailedHitFilter
                            }
                            Ok(None) => {
                                if is_fresh {
                                    HitStatus::Fresh
                                } else {
                                    HitStatus::Expired
                                }
                            }
                            Ok(Some(ForcedFreshness::ForceExpired)) => {
                                // a force-expired asset should not be served as stale,
                                // because force expiring is usually done to remove data
                                meta.disable_serve_stale();
                                HitStatus::ForceExpired
                            }
                            Ok(Some(ForcedFreshness::ForceMiss)) => HitStatus::ForceMiss,
                            Ok(Some(ForcedFreshness::ForceFresh)) => HitStatus::Fresh,
                        };

                        hit_status_opt = Some(hit_status);

                        // init cache for hit / stale
                        session.cache.cache_found(meta, handler, hit_status);
                    }

                    if hit_status_opt.is_none_or(HitStatus::is_treated_as_miss) {
                        // cache miss
                        if session.cache.is_cache_locked() {
                            // Another request is filling the cache; try waiting until that's done, then retry.
                            let lock_status = session.cache.cache_lock_wait().await;
                            if self.handle_lock_status(session, ctx, lock_status) {
                                continue;
                            } else {
                                break None;
                            }
                        } else {
                            self.inner.cache_miss(session, ctx);
                            break None;
                        }
                    }

                    // Safe because an empty hit status would have broken out
                    // in the block above
                    let hit_status = hit_status_opt.expect("None case handled as miss");

                    if !hit_status.is_fresh() {
                        // expired or force expired asset
                        if session.cache.is_cache_locked() {
                            // first, check if this is the subrequest for the background cache update
                            if let Some(write_lock) = session
                                .subrequest_ctx
                                .as_mut()
                                .and_then(|ctx| ctx.take_write_lock())
                            {
                                // Put the write lock in the request
                                session.cache.set_write_lock(write_lock);
                                session.cache.tag_as_subrequest();
                                // and then let it go to upstream
                                break None;
                            }
                            let will_serve_stale = session.cache.can_serve_stale_updating()
                                && self.inner.should_serve_stale(session, ctx, None);
                            if !will_serve_stale {
                                let lock_status = session.cache.cache_lock_wait().await;
                                if self.handle_lock_status(session, ctx, lock_status) {
                                    continue;
                                } else {
                                    break None;
                                }
                            }
                            // else continue to serve stale
                            session.cache.set_stale_updating();
                        } else if session.cache.is_cache_lock_writer() {
                            // stale while revalidate logic for the writer
                            let will_serve_stale = session.cache.can_serve_stale_updating()
                                && self.inner.should_serve_stale(session, ctx, None);
                            if will_serve_stale {
                                // spawn a background subrequest to do the actual update;
                                // the subrequest handle is only None by this phase in unit tests
                                // that don't go through process_new_http
                                let (permit, cache_lock) = session.cache.take_write_lock();
                                SubrequestSpawner::new(self.clone()).spawn_background_subrequest(
                                    session.as_ref(),
                                    subrequest::Ctx::builder()
                                        .cache_write_lock(
                                            cache_lock,
                                            session.cache.cache_key().clone(),
                                            permit,
                                        )
                                        .build(),
                                );
                                // continue to serve stale for this request
                                session.cache.set_stale_updating();
                            } else {
                                // return to fetch from upstream
                                break None;
                            }
                        } else {
                            // return to fetch from upstream
                            break None;
                        }
                    }

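                    // At this point the hit is either fresh or being served stale while a
                    // background update runs; stream it to downstream.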
                    let (reuse, err) = self.proxy_cache_hit(session, ctx).await;
                    if let Some(e) = err.as_ref() {
                        error!(
                            "Fail to serve cache: {e}, {}",
                            self.inner.request_summary(session, ctx)
                        );
                    }
                    // the response is served from cache, exit
                    break Some((reuse, err));
                }
                Err(e) => {
                    // Allow a cache miss to fill the cache even if the cache lookup errors;
                    // this is mostly to support backward-incompatible metadata updates
                    // TODO: check error types
                    // session.cache.disable();
                    self.inner.cache_miss(session, ctx);
                    warn!(
                        "Fail to cache lookup: {e}, {}",
                        self.inner.request_summary(session, ctx)
                    );
                    break None;
                }
            }
        }
    }

    // return bool: whether the server session can be reused, and an error if any
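    // Serves a cache hit to downstream: build the response header from the cached meta,
    // apply conditional (304) and Range handling, run the response filters, then stream
    // the body from the hit handler (seeking in storage when supported, otherwise using
    // RangeBodyFilter to trim the stream).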
    pub(crate) async fn proxy_cache_hit(
        &self,
        session: &mut Session,
        ctx: &mut SV::CTX,
    ) -> (bool, Option<Box<Error>>)
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        use range_filter::*;

        let seekable = session.cache.hit_handler().can_seek();
        let mut header = cache_hit_header(&session.cache);

        let req = session.req_header();

        let not_modified = match self.inner.cache_not_modified_filter(session, &header, ctx) {
            Ok(not_modified) => not_modified,
            Err(e) => {
                // fail open if cache_not_modified_filter errors,
                // just return the whole original response
                warn!(
                    "Failed to run cache not modified filter: {e}, {}",
                    self.inner.request_summary(session, ctx)
                );
                false
            }
        };
        if not_modified {
            to_304(&mut header);
        }
        let header_only = not_modified || req.method == http::method::Method::HEAD;

        // process range header if the cache storage supports seek
        let range_type = if seekable && !session.ignore_downstream_range {
            self.inner.range_header_filter(session, &mut header, ctx)
        } else {
            RangeType::None
        };

        // return a 416 with an empty body for simplicity
        let header_only = header_only || matches!(range_type, RangeType::Invalid);
        debug!("header: {header:?}");

        // TODO: use ProxyUseCache to replace the logic below
        match self.inner.response_filter(session, &mut header, ctx).await {
            Ok(_) => {
                if let Err(e) = session
                    .downstream_modules_ctx
                    .response_header_filter(&mut header, header_only)
                    .await
                {
                    error!(
                        "Failed to run downstream modules response header filter in hit: {e}, {}",
                        self.inner.request_summary(session, ctx)
                    );
                    session
                        .as_mut()
                        .respond_error(500)
                        .await
                        .unwrap_or_else(|e| {
                            error!("failed to send error response to downstream: {e}");
                        });
                    // we have not written anything to downstream yet, so it is still reusable
                    return (true, Some(e));
                }

                if let Err(e) = session
                    .as_mut()
                    .write_response_header(header)
                    .await
                    .map_err(|e| e.into_down())
                {
                    // downstream connection is bad already
                    return (false, Some(e));
                }
            }
            Err(e) => {
                error!(
                    "Failed to run response filter in hit: {e}, {}",
                    self.inner.request_summary(session, ctx)
                );
                session
                    .as_mut()
                    .respond_error(500)
                    .await
                    .unwrap_or_else(|e| {
                        error!("failed to send error response to downstream: {e}");
                    });
                // we have not written anything to downstream yet, so it is still reusable
                return (true, Some(e));
            }
        }
        debug!("finished sending cached header to downstream");

        // If the function returns an Err, there was an issue seeking from the hit handler.
        //
        // Returning false means that no seeking or state change was done, either because the
        // hit handler doesn't support the seek or because multipart doesn't apply.
        fn seek_multipart(
            hit_handler: &mut HitHandler,
            range_filter: &mut RangeBodyFilter,
        ) -> Result<bool> {
            if !range_filter.is_multipart_range() || !hit_handler.can_seek_multipart() {
                return Ok(false);
            }
            let r = range_filter.next_cache_multipart_range();
            hit_handler.seek_multipart(r.start, Some(r.end))?;
            // we still need RangeBodyFilter's help to transform the byte
            // range into a multipart response.
            range_filter.set_current_cursor(r.start);
            Ok(true)
        }

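        // Body phase: for a single range we prefer seeking directly in the hit handler and
        // fall back to RangeBodyFilter when storage can't seek; multipart always goes through
        // RangeBodyFilter (with per-part seeks when supported). The loop below reads from the
        // hit handler and pushes each chunk through the response body filters before writing
        // it downstream.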
        if !header_only {
            let mut maybe_range_filter = match &range_type {
                RangeType::Single(r) => {
                    if session.cache.hit_handler().can_seek() {
                        if let Err(e) = session.cache.hit_handler().seek(r.start, Some(r.end)) {
                            return (false, Some(e));
                        }
                        None
                    } else {
                        Some(RangeBodyFilter::new_range(range_type.clone()))
                    }
                }
                RangeType::Multi(_) => {
                    let mut range_filter = RangeBodyFilter::new_range(range_type.clone());
                    if let Err(e) = seek_multipart(session.cache.hit_handler(), &mut range_filter) {
                        return (false, Some(e));
                    }
                    Some(range_filter)
                }
                RangeType::Invalid => unreachable!(),
                RangeType::None => None,
            };
            loop {
                match session.cache.hit_handler().read_body().await {
                    Ok(raw_body) => {
                        let end = raw_body.is_none();

                        if end {
                            if let Some(range_filter) = maybe_range_filter.as_mut() {
                                if range_filter.should_cache_seek_again() {
                                    let e = match seek_multipart(
                                        session.cache.hit_handler(),
                                        range_filter,
                                    ) {
                                        Ok(true) => {
                                            // called seek(), read again
                                            continue;
                                        }
                                        Ok(false) => {
                                            // the body reader can no longer seek multipart,
                                            // but the cache wants to continue seeking;
                                            // the body will just end in this case if we pass the
                                            // None through
                                            // (TODO: how might hit handlers want to recover from
                                            // this situation?)
                                            Error::explain(
                                                InternalError,
                                                "hit handler cannot seek for multipart again",
                                            )
                                        }
                                        Err(e) => e,
                                    };
                                    return (false, Some(e));
                                }
                            }
                        }

                        let mut body = if let Some(range_filter) = maybe_range_filter.as_mut() {
                            range_filter.filter_body(raw_body)
                        } else {
                            raw_body
                        };

                        match self
                            .inner
                            .response_body_filter(session, &mut body, end, ctx)
                        {
                            Ok(Some(duration)) => {
                                trace!("delaying response for {duration:?}");
                                time::sleep(duration).await;
                            }
                            Ok(None) => { /* continue */ }
                            Err(e) => {
                                // body is being sent, don't treat downstream as reusable
                                return (false, Some(e));
                            }
                        }

                        if let Err(e) = session
                            .downstream_modules_ctx
                            .response_body_filter(&mut body, end)
                        {
                            // body is being sent, don't treat downstream as reusable
                            return (false, Some(e));
                        }

                        if !end && body.as_ref().is_none_or(|b| b.is_empty()) {
                            // Don't write an empty body, which would end the session;
                            // there are still more hit handler bytes to read
                            continue;
                        }

                        // write to downstream
                        let b = body.unwrap_or_default();
                        if let Err(e) = session
                            .as_mut()
                            .write_response_body(b, end)
                            .await
                            .map_err(|e| e.into_down())
                        {
                            return (false, Some(e));
                        }
                        if end {
                            break;
                        }
                    }
                    Err(e) => return (false, Some(e)),
                }
            }
        }

        if let Err(e) = session.cache.finish_hit_handler().await {
            warn!("Error during finish_hit_handler: {}", e);
        }

        match session.as_mut().finish_body().await {
            Ok(_) => {
                debug!("finished sending cached body to downstream");
                (true, None)
            }
            Err(e) => (false, Some(e)),
        }
    }

    /* Downstream revalidation, only needed when cache is on because otherwise origin
     * will handle it */
    pub(crate) fn downstream_response_conditional_filter(
        &self,
        use_cache: &mut ServeFromCache,
        session: &Session,
        resp: &mut ResponseHeader,
        ctx: &mut SV::CTX,
    ) where
        SV: ProxyHttp,
    {
        // TODO: range
        let req = session.req_header();

        let not_modified = match self.inner.cache_not_modified_filter(session, resp, ctx) {
            Ok(not_modified) => not_modified,
            Err(e) => {
                // fail open if cache_not_modified_filter errors,
                // just return the whole original response
                warn!(
                    "Failed to run cache not modified filter: {e}, {}",
                    self.inner.request_summary(session, ctx)
                );
                false
            }
        };

        if not_modified {
            to_304(resp);
        }
        let header_only = not_modified || req.method == http::method::Method::HEAD;
        if header_only && use_cache.is_on() {
            // tell cache to stop serving downstream after yielding header
            // (misses will continue to allow admitting upstream into cache)
            use_cache.enable_header_only();
        }
    }

    // TODO: cache upstream header filter to add/remove headers

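    // Mirrors each upstream HttpTask into the cache when caching is enabled (or being
    // re-evaluated while bypassing): the Header task decides cacheability and sets up the
    // miss handler, Body tasks are written through it (enforcing the configured max file
    // size), and Done/end-of-stream finalizes the miss handler.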
    pub(crate) async fn cache_http_task(
        &self,
        session: &mut Session,
        task: &HttpTask,
        ctx: &mut SV::CTX,
        serve_from_cache: &mut ServeFromCache,
    ) -> Result<()>
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        if !session.cache.enabled() && !session.cache.bypassing() {
            return Ok(());
        }

        match task {
            HttpTask::Header(header, end_stream) => {
                // decide if cacheable and create cache meta
                // for now, skip 1xxs (should not affect response cache decisions)
                // However 101 is an exception because it is the final response header
                if header.status.is_informational()
                    && header.status != StatusCode::SWITCHING_PROTOCOLS
                {
                    return Ok(());
                }
                match self.inner.response_cache_filter(session, header, ctx)? {
                    Cacheable(meta) => {
                        let mut fill_cache = true;
                        if session.cache.bypassing() {
                            // The cache might have been bypassed because the response exceeded the
                            // maximum cacheable asset size. If that looks like the case (there
                            // is a maximum file size configured and we don't know the content
                            // length up front), attempting to re-enable the cache now would cause
                            // the request to fail when the chunked response exceeds the maximum
                            // file size again.
                            if session.cache.max_file_size_bytes().is_some()
                                && !meta.headers().contains_key(header::CONTENT_LENGTH)
                            {
                                session
                                    .cache
                                    .disable(NoCacheReason::PredictedResponseTooLarge);
                                return Ok(());
                            }

                            session.cache.response_became_cacheable();

                            if session.req_header().method == Method::GET
                                && meta.response_header().status == StatusCode::OK
                            {
                                self.inner.cache_miss(session, ctx);
                                if !session.cache.enabled() {
                                    fill_cache = false;
                                }
                            } else {
                                // we've allowed caching on the next request,
                                // but do not cache _this_ request if bypassed and not 200
                                // (We didn't run upstream request cache filters to strip range or condition headers,
                                // so this could be an uncacheable response e.g. 206 or 304 or HEAD.
                                // Exclude all non-200/GET for simplicity, may expand allowable codes in the future.)
                                fill_cache = false;
                                session.cache.disable(NoCacheReason::Deferred);
                            }
                        }

                        // If the Content-Length is known, and a maximum asset size has been configured
                        // on the cache, validate that the response does not exceed the maximum asset size.
                        if session.cache.enabled() {
                            if let Some(max_file_size) = session.cache.max_file_size_bytes() {
                                let content_length_hdr = meta.headers().get(header::CONTENT_LENGTH);
                                if let Some(content_length) =
                                    header_value_content_length(content_length_hdr)
                                {
                                    if content_length > max_file_size {
                                        fill_cache = false;
                                        session.cache.response_became_uncacheable(
                                            NoCacheReason::ResponseTooLarge,
                                        );
                                        session.cache.disable(NoCacheReason::ResponseTooLarge);
                                        // too large to cache, disable ranging
                                        session.ignore_downstream_range = true;
                                    }
                                }
                                // if the content-length header is not specified, the miss handler
                                // will count the response size on the fly, aborting the request
                                // mid-transfer if the max file size is exceeded
                            }
                        }
                        if fill_cache {
                            let req_header = session.req_header();
                            // Update the variance in the meta via the same callback,
                            // cache_vary_filter(), used in cache lookup for consistency.
                            // Future cache lookups need a matching variance in the meta
                            // with the cache key to pick up the correct variance
                            let variance = self.inner.cache_vary_filter(&meta, ctx, req_header);
                            session.cache.set_cache_meta(meta);
                            session.cache.update_variance(variance);
                            // this sends the meta and header
                            session.cache.set_miss_handler().await?;
                            if session.cache.miss_body_reader().is_some() {
                                serve_from_cache.enable_miss();
                            }
                            if *end_stream {
                                session
                                    .cache
                                    .miss_handler()
                                    .unwrap() // safe, it is set above
                                    .write_body(Bytes::new(), true)
                                    .await?;
                                session.cache.finish_miss_handler().await?;
                            }
                        }
                    }
                    Uncacheable(reason) => {
                        if !session.cache.bypassing() {
                            // mark as uncacheable, so we bypass cache next time
                            session.cache.response_became_uncacheable(reason);
                        }
                        session.cache.disable(reason);
                    }
                }
            }
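            // Stream upstream body chunks into the miss handler; the size check below
            // rejects chunked responses that grow past the configured max file size.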
            HttpTask::Body(data, end_stream) => match data {
                Some(d) => {
                    if session.cache.enabled() {
                        // TODO: do this async
                        // fail if writing the body would exceed the max_file_size_bytes
                        let body_size_allowed =
                            session.cache.track_body_bytes_for_max_file_size(d.len());
                        if !body_size_allowed {
                            debug!("chunked response exceeded max cache size, remembering that it is uncacheable");
                            session
                                .cache
                                .response_became_uncacheable(NoCacheReason::ResponseTooLarge);

                            return Error::e_explain(
                                ERR_RESPONSE_TOO_LARGE,
                                format!(
                                    "writing data of size {} bytes would exceed max file size of {} bytes",
                                    d.len(),
                                    session.cache.max_file_size_bytes().expect("max file size bytes must be set to exceed size")
                                ),
                            );
                        }

                        // this will panic if more data is sent after we see end_stream,
                        // but that should be impossible in the real world
                        let miss_handler = session.cache.miss_handler().unwrap();

                        miss_handler.write_body(d.clone(), *end_stream).await?;
                        if *end_stream {
                            session.cache.finish_miss_handler().await?;
                        }
                    }
                }
                None => {
                    if session.cache.enabled() && *end_stream {
                        session.cache.finish_miss_handler().await?;
                    }
                }
            },
            HttpTask::Trailer(_) => {} // h1 trailer is not supported yet
            HttpTask::Done => {
                if session.cache.enabled() {
                    session.cache.finish_miss_handler().await?;
                }
            }
            HttpTask::Failed(_) => {
                // TODO: handle this failure: delete the temp files?
            }
        }
        Ok(())
    }

    // Decide if local cache can be used according to the upstream response header
    // 1. when upstream returns 304, the local cache is refreshed and served fresh
    // 2. when upstream returns certain HTTP error status, the local cache is served stale
    // Return true if local cache should be used, false otherwise
    pub(crate) async fn revalidate_or_stale(
        &self,
        session: &mut Session,
        task: &mut HttpTask,
        ctx: &mut SV::CTX,
    ) -> bool
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        if !session.cache.enabled() {
            return false;
        }

        match task {
            HttpTask::Header(resp, _eos) => {
                if resp.status == StatusCode::NOT_MODIFIED {
                    if session.cache.maybe_cache_meta().is_some() {
                        // run upstream response filters on upstream 304 first
                        if let Err(err) = self
                            .inner
                            .upstream_response_filter(session, resp, ctx)
                            .await
                        {
                            error!("upstream response filter error on 304: {err:?}");
                            session.cache.revalidate_uncacheable(
                                *resp.clone(),
                                NoCacheReason::InternalError,
                            );
                            // always serve from cache after receiving the 304
                            return true;
                        }
                        // a 304 doesn't contain all the headers; merge the 304 into the cached 200 header
                        // in order for response_cache_filter to run correctly
                        let merged_header = session.cache.revalidate_merge_header(resp);
                        match self
                            .inner
                            .response_cache_filter(session, &merged_header, ctx)
                        {
                            Ok(Cacheable(mut meta)) => {
                                // For simplicity, ignore changes to variance over 304 for now.
                                // Note this means upstream can only update variance via 2xx
                                // (expired response).
                                //
                                // TODO: if we choose to respect changing Vary / variance over 304,
                                // then there are a few cases to consider. See `update_variance` in
                                // the `pingora-cache` module.
                                let old_meta = session.cache.maybe_cache_meta().unwrap(); // safe, checked above
                                if let Some(old_variance) = old_meta.variance() {
                                    meta.set_variance(old_variance);
                                }
                                if let Err(e) = session.cache.revalidate_cache_meta(meta).await {
                                    // Fail open: we can continue to use the revalidated response even
                                    // if the meta failed to write to storage
                                    warn!("revalidate_cache_meta failed {e:?}");
                                }
                            }
                            Ok(Uncacheable(reason)) => {
                                // This response was once cacheable, and upstream tells us it has not changed
                                // but now we decided it is uncacheable!
                                // RFC 9111: still allowed to reuse stored response this time because
                                // it was "successfully validated"
                                // https://www.rfc-editor.org/rfc/rfc9111#constructing.responses.from.caches
                                // Serve the response, but do not update cache

                                // We also want to avoid poisoning downstream's cache with an unsolicited 304
                                // if we did not receive a conditional request from downstream
                                // (downstream may have a different cacheability assessment and could cache the 304)

                                //TODO: log more
                                debug!("Uncacheable {reason:?} 304 received");
                                session.cache.response_became_uncacheable(reason);
                                session.cache.revalidate_uncacheable(merged_header, reason);
                            }
                            Err(e) => {
                                // Error during revalidation: similar to the reasons above
                                // (avoid poisoning downstream cache with passthrough 304),
                                // allow serving the stored response without updating cache
                                warn!("Error {e:?} response_cache_filter during revalidation");
                                session.cache.revalidate_uncacheable(
                                    merged_header,
                                    NoCacheReason::InternalError,
                                );
                                // Assume the next 304 may succeed, so don't mark uncacheable
                            }
                        }
                        // always serve from cache after receiving the 304
                        true
                    } else {
                        //TODO: log more
                        warn!("304 received without cached asset, disable caching");
                        let reason = NoCacheReason::Custom("304 on miss");
                        session.cache.response_became_uncacheable(reason);
                        session.cache.disable(reason);
                        false
                    }
                } else if resp.status.is_server_error() {
                    // stale-if-error logic, 5xx only for now

                    // this is the response header filter phase, so response_written should always be None?
                    if !session.cache.can_serve_stale_error()
                        || session.response_written().is_some()
                    {
                        return false;
                    }

                    // create an error to encode the http status code
                    let http_status_error = Error::create(
                        ErrorType::HTTPStatus(resp.status.as_u16()),
                        ErrorSource::Upstream,
                        None,
                        None,
                    );
                    if self
                        .inner
                        .should_serve_stale(session, ctx, Some(&http_status_error))
                    {
                        // no more need to keep the write lock
                        session
                            .cache
                            .release_write_lock(NoCacheReason::UpstreamError);
                        true
                    } else {
                        false
                    }
                } else {
                    false // not 304, not a stale-if-error status code
                }
            }
            _ => false, // not header
        }
    }

    // None: no stale asset is used, Some(_): a stale asset is sent to downstream
    // bool: can the downstream connection be reused
    pub(crate) async fn handle_stale_if_error(
        &self,
        session: &mut Session,
        ctx: &mut SV::CTX,
        error: &Error,
    ) -> Option<(bool, Option<Box<Error>>)>
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        // the caller might have already checked this as an optimization
        if !session.cache.can_serve_stale_error() {
            return None;
        }

        // the error happened halfway through a regular response to downstream;
        // we can't resend the response
        if session.response_written().is_some() {
            return None;
        }

        // check error types
        if !self.inner.should_serve_stale(session, ctx, Some(error)) {
            return None;
        }

        // log the original error
        warn!(
            "Fail to proxy: {}, serving stale, {}",
            error,
            self.inner.request_summary(session, ctx)
        );

        // no more need to hang onto the cache lock
        session
            .cache
            .release_write_lock(NoCacheReason::UpstreamError);

        Some(self.proxy_cache_hit(session, ctx).await)
    }

    // helper function to decide whether to continue to retry the lock (true) or give up (false)
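    // Used by the lookup loop in proxy_cache after cache_lock_wait() returns.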
    fn handle_lock_status(
        &self,
        session: &mut Session,
        ctx: &SV::CTX,
        lock_status: LockStatus,
    ) -> bool
    where
        SV: ProxyHttp,
    {
        debug!("cache unlocked {lock_status:?}");
        match lock_status {
            // should lookup the cached asset again
            LockStatus::Done => true,
            // should compete to be a new writer
            LockStatus::TransientError => true,
            // the request is uncacheable, go ahead to fetch from the origin
            LockStatus::GiveUp => {
                // TODO: It would be nice for the writer to propagate the real reason
                session.cache.disable(NoCacheReason::CacheLockGiveUp);
                // not cacheable, just go to the origin.
                false
            }
            // treat this the same as TransientError
            LockStatus::Dangling => {
                // software bug, but request can recover from this
                warn!(
                    "Dangling cache lock, {}",
                    self.inner.request_summary(session, ctx)
                );
                true
            }
            // If this reader has spent too long waiting on locks, let the request
            // through while disabling cache (to avoid amplifying disk writes).
            LockStatus::WaitTimeout => {
                warn!(
                    "Cache lock timeout, {}",
                    self.inner.request_summary(session, ctx)
                );
                session.cache.disable(NoCacheReason::CacheLockTimeout);
                // not cacheable, just go to the origin.
                false
            }
            // When a singular cache lock has been held for too long,
            // we should allow requests to recompete for the lock
            // to protect upstreams from load.
            LockStatus::AgeTimeout => true,
            // software bug, this status should be impossible to reach
            LockStatus::Waiting => panic!("impossible LockStatus::Waiting"),
        }
    }
}

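// Builds the downstream response header for a cache hit from the stored meta:
// add an Age header when upstream wasn't contacted, normalize to HTTP/1.1, and
// fall back to chunked transfer encoding when no content-length is stored.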
fn cache_hit_header(cache: &HttpCache) -> Box<ResponseHeader> {
    let mut header = Box::new(cache.cache_meta().response_header_copy());
    // convert cache response

    // these status codes / methods cannot have a body, so no need to add chunked encoding
    let no_body = matches!(header.status.as_u16(), 204 | 304);

    // https://www.rfc-editor.org/rfc/rfc9111#section-4:
    // When a stored response is used to satisfy a request without validation, a cache
    // MUST generate an Age header field
    if !cache.upstream_used() {
        let age = cache.cache_meta().age().as_secs();
        header.insert_header(http::header::AGE, age).unwrap();
    }
    log::debug!("cache header: {header:?} {:?}", cache.phase());

    // currently storage cache is always considered an h1 upstream
    // (header-serde serializes as h1.0 or h1.1)
    // set this header to be h1.1
    header.set_version(Version::HTTP_11);

    /* Add a chunked header to tell downstream to use chunked encoding
     * in the absence of content-length in h2 */
    if !no_body
        && !header.status.is_informational()
        && header.headers.get(http::header::CONTENT_LENGTH).is_none()
    {
        header
            .insert_header(http::header::TRANSFER_ENCODING, "chunked")
            .unwrap();
    }
    header
}

// https://datatracker.ietf.org/doc/html/rfc7233#section-3
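// Range-request support for cache hits: parse the downstream Range header into a
// RangeType and provide the RangeBodyFilter / MultiRangeInfo helpers used to build
// single-range and multipart responses.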
pub mod range_filter {
    use super::*;
    use bytes::BytesMut;
    use http::header::*;
    use std::ops::Range;

    // parse bytes into usize, ignoring parse errors
    fn parse_number(input: &[u8]) -> Option<usize> {
        str::from_utf8(input).ok()?.parse().ok()
    }

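    // Parse a Range header value against the representation's content length, producing
    // RangeType::Single for one satisfiable range, RangeType::Multi for several,
    // RangeType::Invalid when ranges were given but none are satisfiable, and
    // RangeType::None when the header is absent, malformed, or declined. Per the tests
    // below, "bytes=0-1" with content length 10 yields Single(0..2), while "bytes=2-1"
    // yields Invalid.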
    fn parse_range_header(
        range: &[u8],
        content_length: usize,
        max_multipart_ranges: Option<usize>,
    ) -> RangeType {
        use regex::Regex;

        // Match individual range parts (e.g. "0-100", "-5", "1-")
        static RE_SINGLE_RANGE_PART: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)^\s*(?P<start>\d*)-(?P<end>\d*)\s*$").unwrap());

        // Convert bytes to UTF-8 string
        let range_str = match str::from_utf8(range) {
            Ok(s) => s,
            Err(_) => return RangeType::None,
        };

        // Split into "bytes=" and the actual range(s)
        let mut parts = range_str.splitn(2, "=");

        // Check if it starts with "bytes="
        let prefix = parts.next();
        if !prefix.is_some_and(|s| s.eq_ignore_ascii_case("bytes")) {
            return RangeType::None;
        }

        let Some(ranges_str) = parts.next() else {
            // No ranges provided
            return RangeType::None;
        };

        // Count the ranges in the range string (e.g. "100-200,300-400")
        let mut range_count = 0;
        for _ in ranges_str.split(',') {
            range_count += 1;
            if let Some(max_ranges) = max_multipart_ranges {
                if range_count >= max_ranges {
                    // If we reach the configured maximum number of ranges, return None for now to save parsing time
                    return RangeType::None;
                }
            }
        }
        let mut ranges: Vec<Range<usize>> = Vec::with_capacity(range_count);

        // Process each range
        let mut last_range_end = 0;
        for part in ranges_str.split(',') {
            let captured = match RE_SINGLE_RANGE_PART.captures(part) {
                Some(c) => c,
                None => {
                    return RangeType::None;
                }
            };

            let maybe_start = captured
                .name("start")
                .and_then(|s| s.as_str().parse::<usize>().ok());
            let end = captured
                .name("end")
                .and_then(|s| s.as_str().parse::<usize>().ok());

            let range = if let Some(start) = maybe_start {
                if start >= content_length {
                    // Skip the invalid range
                    continue;
                }
                // an open-ended range should end at the last byte
                // an oversized end is allowed but ignored
                // range end is inclusive
                let end = std::cmp::min(end.unwrap_or(content_length - 1), content_length - 1) + 1;
                if end <= start {
                    // Skip the invalid range
                    continue;
                }
                start..end
            } else {
                // start is empty, this changes the meaning of the value of `end`
                // Now it means to read the last `end` bytes
                if let Some(end) = end {
                    if content_length >= end {
                        (content_length - end)..content_length
                    } else {
                        // an oversized end is capped at the full length
                        0..content_length
                    }
                } else {
                    // No start or end, skip the invalid range
                    continue;
                }
            };
            // For now we stick to non-overlapping, ascending ranges for simplicity
            // and parity with nginx
            if range.start < last_range_end {
                return RangeType::None;
            }
            last_range_end = range.end;
            ranges.push(range);
        }

        // Note for future: we can technically coalesce multiple ranges for multipart
        //
        // https://www.rfc-editor.org/rfc/rfc9110#section-17.15
        // "Servers ought to ignore, coalesce, or reject egregious range
        // requests, such as requests for more than two overlapping ranges or
        // for many small ranges in a single set, particularly when the ranges
        // are requested out of order for no apparent reason. Multipart range
        // requests are not designed to support random access."

        if ranges.is_empty() {
            // We got some ranges, processed them but none were valid
            RangeType::Invalid
        } else if ranges.len() == 1 {
            RangeType::Single(ranges[0].clone()) // Only one range
        } else {
            RangeType::Multi(MultiRangeInfo::new(ranges))
        }
    }

    #[test]
    fn test_parse_range() {
        assert_eq!(
            parse_range_header(b"bytes=0-1", 10, None),
            RangeType::new_single(0, 2)
        );
        assert_eq!(
            parse_range_header(b"bYTes=0-9", 10, None),
            RangeType::new_single(0, 10)
        );
        assert_eq!(
            parse_range_header(b"bytes=0-12", 10, None),
            RangeType::new_single(0, 10)
        );
        assert_eq!(
            parse_range_header(b"bytes=0-", 10, None),
            RangeType::new_single(0, 10)
        );
        assert_eq!(
            parse_range_header(b"bytes=2-1", 10, None),
            RangeType::Invalid
        );
        assert_eq!(
            parse_range_header(b"bytes=10-11", 10, None),
            RangeType::Invalid
        );
        assert_eq!(
            parse_range_header(b"bytes=-2", 10, None),
            RangeType::new_single(8, 10)
        );
        assert_eq!(
            parse_range_header(b"bytes=-12", 10, None),
            RangeType::new_single(0, 10)
        );
        assert_eq!(parse_range_header(b"bytes=-", 10, None), RangeType::Invalid);
        assert_eq!(parse_range_header(b"bytes=", 10, None), RangeType::None);
    }
1149
1150    // Add some tests for multi-range too
1151    #[test]
1152    fn test_parse_range_header_multi() {
1153        assert_eq!(
1154            parse_range_header(b"bytes=0-1,4-5", 10, None)
1155                .get_multirange_info()
1156                .expect("Should have multipart info for Multipart range request")
1157                .ranges,
1158            (vec![Range { start: 0, end: 2 }, Range { start: 4, end: 6 }])
1159        );
1160        // Last range is invalid because the content-length is too small
1161        assert_eq!(
1162            parse_range_header(b"bytEs=0-99,200-299,400-499", 320, None)
1163                .get_multirange_info()
1164                .expect("Should have multipart info for Multipart range request")
1165                .ranges,
1166            (vec![
1167                Range { start: 0, end: 100 },
1168                Range {
1169                    start: 200,
1170                    end: 300
1171                }
1172            ])
1173        );
1174        // Same as above but appropriate content length
1175        assert_eq!(
1176            parse_range_header(b"bytEs=0-99,200-299,400-499", 500, None)
1177                .get_multirange_info()
1178                .expect("Should have multipart info for Multipart range request")
1179                .ranges,
1180            vec![
1181                Range { start: 0, end: 100 },
1182                Range {
1183                    start: 200,
1184                    end: 300
1185                },
1186                Range {
1187                    start: 400,
1188                    end: 500
1189                },
1190            ]
1191        );
1192        // Looks like a multi-range request, but the ranges overlap, so we decline to serve ranges
1193        assert_eq!(
1194            parse_range_header(b"bytes=0-,-2", 10, None),
1195            RangeType::None,
1196        );
1197        // Should not have multirange info set
1198        assert!(parse_range_header(b"bytes=0-,-2", 10, None)
1199            .get_multirange_info()
1200            .is_none());
1201        // Overlapping ranges are currently declined
1202        assert_eq!(
1203            parse_range_header(b"bytes=0-3,2-5", 10, None),
1204            RangeType::None,
1205        );
1206        assert!(parse_range_header(b"bytes=0-3,2-5", 10, None)
1207            .get_multirange_info()
1208            .is_none());
1209
1210        // Content length is 2, so the only satisfiable range is 0..2.
1211        assert_eq!(
1212            parse_range_header(b"bytes=0-5,10-", 2, None),
1213            RangeType::new_single(0, 2)
1214        );
1215        assert!(parse_range_header(b"bytes=0-5,10-", 2, None)
1216            .get_multirange_info()
1217            .is_none());
1218
1219        // We should ignore the last incorrect range and return the other acceptable ranges
1220        assert_eq!(
1221            parse_range_header(b"bytes=0-5, 10-20, 30-18", 200, None)
1222                .get_multirange_info()
1223                .expect("Should have multipart info for Multipart range request")
1224                .ranges,
1225            vec![Range { start: 0, end: 6 }, Range { start: 10, end: 21 },]
1226        );
1227        // All invalid ranges
1228        assert_eq!(
1229            parse_range_header(b"bytes=5-0, 20-15, 30-25", 200, None),
1230            RangeType::Invalid
1231        );
1232
1233        // Helper to generate a large number of ranges for the limit check below
1234        fn generate_range_header(count: usize) -> Vec<u8> {
1235            let mut s = String::from("bytes=");
1236            for i in 0..count {
1237                let start = i * 4;
1238                let end = start + 1;
1239                if i > 0 {
1240                    s.push(',');
1241                }
1242                s.push_str(&start.to_string());
1243                s.push('-');
1244                s.push_str(&end.to_string());
1245            }
1246            s.into_bytes()
1247        }
1248
1249        // Exceeding the 200-range parse limit declines ranging entirely (RangeType::None).
1250        let ranges = generate_range_header(201);
1251        assert_eq!(
1252            parse_range_header(&ranges, 1000, Some(200)),
1253            RangeType::None
1254        )
1255    }
1256
1257    // For multipart requests, the boundary, total content length, and content type are needed
1258    // when building both the headers and the body, so we store this information with the ranges.
1259    #[derive(Debug, Eq, PartialEq, Clone)]
1260    pub struct MultiRangeInfo {
1261        pub ranges: Vec<Range<usize>>,
1262        pub boundary: String,
1263        total_length: usize,
1264        content_type: Option<String>,
1265    }
1266
1267    impl MultiRangeInfo {
1268        // Create a new MultiRangeInfo when only the ranges are known
1269        pub fn new(ranges: Vec<Range<usize>>) -> Self {
1270            Self {
1271                ranges,
1272                // Directly create boundary string on initialization
1273                boundary: Self::generate_boundary(),
1274                total_length: 0,
1275                content_type: None,
1276            }
1277        }
1278        pub fn set_content_type(&mut self, content_type: String) {
1279            self.content_type = Some(content_type)
1280        }
1281        pub fn set_total_length(&mut self, total_length: usize) {
1282            self.total_length = total_length;
1283        }
1284        // Per [RFC 9110](https://www.rfc-editor.org/rfc/rfc9110.html#multipart.byteranges),
1285        // we need to generate a boundary string for each body part.
1286        // Per [RFC 2046](https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1), the boundary should be no longer than 70 characters
1287        // and it must not match the body content.
1288        fn generate_boundary() -> String {
1289            use rand::Rng;
1290            let mut rng = rand::thread_rng();
1291            format!("{:016x}", rng.gen::<u64>())
1292        }
1293        pub fn calculate_multipart_length(&self) -> usize {
1294            let mut total_length = 0;
1295            let content_type = self.content_type.as_ref();
1296            for range in &self.ranges {
1297                // Each part should have (see the length-check test after this impl):
1298                // \r\n--boundary\r\n                         --> 4 + boundary.len() (16) + 2 = 22
1299                // Content-Type: original-content-type\r\n    --> 14 + content_type.len() + 2
1300                // Content-Range: bytes start-end/total\r\n   --> Variable +2
1301                // \r\n                                       --> 2
1302                // [data]                                     --> data.len()
1303                total_length += 4 + self.boundary.len() + 2;
1304                total_length += content_type.map_or(0, |ct| 14 + ct.len() + 2);
1305                total_length += format!(
1306                    "Content-Range: bytes {}-{}/{}",
1307                    range.start,
1308                    range.end - 1,
1309                    self.total_length
1310                )
1311                .len()
1312                    + 2;
1313                total_length += 2;
1314                total_length += range.end - range.start;
1315            }
1316            // Final boundary: "\r\n--<boundary>--\r\n"
1317            total_length += 4 + self.boundary.len() + 4;
1318            total_length
1319        }
1320    }
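    // A minimal illustrative check of the length arithmetic above, assuming the
    // 16-hex-character boundary produced by generate_boundary(): for a single
    // range 0..2 of a 10-byte body with no Content-Type part header, the parts
    // add up to 22 ("\r\n--<boundary>\r\n") + 29 ("Content-Range: bytes 0-1/10\r\n")
    // + 2 ("\r\n") + 2 (data) + 24 (final "\r\n--<boundary>--\r\n") = 79 bytes.
    #[test]
    fn test_calculate_multipart_length_sketch() {
        let mut info = MultiRangeInfo::new(vec![Range { start: 0, end: 2 }]);
        info.set_total_length(10);
        assert_eq!(info.calculate_multipart_length(), 79);
    }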
1321    #[derive(Debug, Eq, PartialEq, Clone)]
1322    pub enum RangeType {
1323        None,
1324        Single(Range<usize>),
1325        Multi(MultiRangeInfo),
1326        Invalid,
1327    }
1328
1329    impl RangeType {
1330        // Helper functions for tests
1331        #[allow(dead_code)]
1332        fn new_single(start: usize, end: usize) -> Self {
1333            RangeType::Single(Range { start, end })
1334        }
1335        #[allow(dead_code)]
1336        pub fn new_multi(ranges: Vec<Range<usize>>) -> Self {
1337            RangeType::Multi(MultiRangeInfo::new(ranges))
1338        }
1339        #[allow(dead_code)]
1340        fn get_multirange_info(&self) -> Option<&MultiRangeInfo> {
1341            match self {
1342                RangeType::Multi(multi_range_info) => Some(multi_range_info),
1343                _ => None,
1344            }
1345        }
1346        #[allow(dead_code)]
1347        fn update_multirange_info(&mut self, content_length: usize, content_type: Option<String>) {
1348            if let RangeType::Multi(multipart_range_info) = self {
1349                multipart_range_info.content_type = content_type;
1350                multipart_range_info.set_total_length(content_length);
1351            }
1352        }
1353    }
1354
1355    // Handles both single-range and multipart-range requests
1356    pub fn range_header_filter(
1357        req: &RequestHeader,
1358        resp: &mut ResponseHeader,
1359        max_multipart_ranges: Option<usize>,
1360    ) -> RangeType {
1361        // The Range header field is evaluated after evaluating the precondition
1362        // header fields defined in [RFC7232], and only if the result in absence
1363        // of the Range header field would be a 200 (OK) response
1364        if resp.status != StatusCode::OK {
1365            return RangeType::None;
1366        }
1367
1368        // Content-Length is not required by the RFC, but requiring it matches nginx's behavior
1369        // and makes the implementation simpler.
1370        let Some(content_length_bytes) = resp.headers.get(CONTENT_LENGTH) else {
1371            return RangeType::None;
1372        };
1373        // bail on invalid content length
1374        let Some(content_length) = parse_number(content_length_bytes.as_bytes()) else {
1375            return RangeType::None;
1376        };
1377
1378        // At this point the response is allowed to be served as ranges
1379        // TODO: we could also check the Accept-Ranges header from resp. Nginx gives users the option,
1380        // see proxy_force_ranges
1381
1382        fn request_range_type(
1383            req: &RequestHeader,
1384            resp: &ResponseHeader,
1385            content_length: usize,
1386            max_multipart_ranges: Option<usize>,
1387        ) -> RangeType {
1388            // "A server MUST ignore a Range header field received with a request method other than GET."
1389            if req.method != http::Method::GET && req.method != http::Method::HEAD {
1390                return RangeType::None;
1391            }
1392
1393            let Some(range_header) = req.headers.get(RANGE) else {
1394                return RangeType::None;
1395            };
1396
1397            // If-Range checks whether the Last-Modified / ETag value matches exactly, so that a
1398            // resumable download only continues against an unchanged representation.
1399            // https://datatracker.ietf.org/doc/html/rfc9110#name-if-range
1400            // Note that the RFC wants strong validation, and suggests that
1401            // "A valid entity-tag can be distinguished from a valid HTTP-date
1402            // by examining the first three characters for a DQUOTE,"
1403            // but this current etag matching behavior most closely mirrors nginx.
1404            if let Some(if_range) = req.headers.get(IF_RANGE) {
1405                let ir = if_range.as_bytes();
1406                let matches = if ir.len() >= 2 && ir.last() == Some(&b'"') {
1407                    resp.headers.get(ETAG).is_some_and(|etag| etag == if_range)
1408                } else if let Some(last_modified) = resp.headers.get(LAST_MODIFIED) {
1409                    last_modified == if_range
1410                } else {
1411                    false
1412                };
1413                if !matches {
1414                    return RangeType::None;
1415                }
1416            }
1417
1418            parse_range_header(
1419                range_header.as_bytes(),
1420                content_length,
1421                max_multipart_ranges,
1422            )
1423        }
1424
1425        let mut range_type = request_range_type(req, resp, content_length, max_multipart_ranges);
1426
1427        match &mut range_type {
1428            RangeType::None => {
1429                // At this point, the response is _eligible_ to be served in ranges
1430                // in the future, so add Accept-Ranges, mirroring nginx behavior
1431                resp.insert_header(&ACCEPT_RANGES, "bytes").unwrap();
1432            }
1433            RangeType::Single(r) => {
1434                // 206 response
1435                resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap();
1436                resp.remove_header(&ACCEPT_RANGES);
1437                resp.insert_header(&CONTENT_LENGTH, r.end - r.start)
1438                    .unwrap();
1439                resp.insert_header(
1440                    &CONTENT_RANGE,
1441                    format!("bytes {}-{}/{content_length}", r.start, r.end - 1), // range end is inclusive
1442                )
1443                .unwrap()
1444            }
1445
1446            RangeType::Multi(multi_range_info) => {
1447                let content_type = resp
1448                    .headers
1449                    .get(CONTENT_TYPE)
1450                    .and_then(|v| v.to_str().ok())
1451                    .unwrap_or("application/octet-stream");
1452                // Update multipart info
1453                multi_range_info.set_total_length(content_length);
1454                multi_range_info.set_content_type(content_type.to_string());
1455
1456                let total_length = multi_range_info.calculate_multipart_length();
1457
1458                resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap();
1459                resp.remove_header(&ACCEPT_RANGES);
1460                resp.insert_header(CONTENT_LENGTH, total_length).unwrap();
1461                resp.insert_header(
1462                    CONTENT_TYPE,
1463                    format!(
1464                        "multipart/byteranges; boundary={}",
1465                        multi_range_info.boundary
1466                    ), // RFC 2046
1467                )
1468                .unwrap();
1469                resp.remove_header(&CONTENT_RANGE);
1470            }
1471            RangeType::Invalid => {
1472                // 416 response
1473                resp.set_status(StatusCode::RANGE_NOT_SATISFIABLE).unwrap();
1474                // empty body for simplicity
1475                resp.insert_header(&CONTENT_LENGTH, HeaderValue::from_static("0"))
1476                    .unwrap();
1477                resp.remove_header(&ACCEPT_RANGES);
1478                // TODO: remove other headers like content-encoding
1479                resp.remove_header(&CONTENT_TYPE);
1480                resp.insert_header(&CONTENT_RANGE, format!("bytes */{content_length}"))
1481                    .unwrap()
1482            }
1483        }
1484
1485        range_type
1486    }
1487
1488    #[test]
1489    fn test_range_filter_single() {
1490        fn gen_req() -> RequestHeader {
1491            RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap()
1492        }
1493        fn gen_resp() -> ResponseHeader {
1494            let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1495            resp.append_header("Content-Length", "10").unwrap();
1496            resp
1497        }
1498
1499        // no range
1500        let req = gen_req();
1501        let mut resp = gen_resp();
1502        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1503        assert_eq!(resp.status.as_u16(), 200);
1504        assert_eq!(
1505            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1506            b"bytes"
1507        );
1508
1509        // no range, try HEAD
1510        let mut req = gen_req();
1511        req.method = Method::HEAD;
1512        let mut resp = gen_resp();
1513        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1514        assert_eq!(resp.status.as_u16(), 200);
1515        assert_eq!(
1516            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1517            b"bytes"
1518        );
1519
1520        // regular range
1521        let mut req = gen_req();
1522        req.insert_header("Range", "bytes=0-1").unwrap();
1523        let mut resp = gen_resp();
1524        assert_eq!(
1525            RangeType::new_single(0, 2),
1526            range_header_filter(&req, &mut resp, None)
1527        );
1528        assert_eq!(resp.status.as_u16(), 206);
1529        assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"2");
1530        assert_eq!(
1531            resp.headers.get("content-range").unwrap().as_bytes(),
1532            b"bytes 0-1/10"
1533        );
1534        assert!(resp.headers.get("accept-ranges").is_none());
1535
1536        // regular range, accept-ranges included
1537        let mut req = gen_req();
1538        req.insert_header("Range", "bytes=0-1").unwrap();
1539        let mut resp = gen_resp();
1540        resp.insert_header("Accept-Ranges", "bytes").unwrap();
1541        assert_eq!(
1542            RangeType::new_single(0, 2),
1543            range_header_filter(&req, &mut resp, None)
1544        );
1545        assert_eq!(resp.status.as_u16(), 206);
1546        assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"2");
1547        assert_eq!(
1548            resp.headers.get("content-range").unwrap().as_bytes(),
1549            b"bytes 0-1/10"
1550        );
1551        // accept-ranges stripped
1552        assert!(resp.headers.get("accept-ranges").is_none());
1553
1554        // bad range
1555        let mut req = gen_req();
1556        req.insert_header("Range", "bytes=1-0").unwrap();
1557        let mut resp = gen_resp();
1558        resp.insert_header("Accept-Ranges", "bytes").unwrap();
1559        assert_eq!(
1560            RangeType::Invalid,
1561            range_header_filter(&req, &mut resp, None)
1562        );
1563        assert_eq!(resp.status.as_u16(), 416);
1564        assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"0");
1565        assert_eq!(
1566            resp.headers.get("content-range").unwrap().as_bytes(),
1567            b"bytes */10"
1568        );
1569        assert!(resp.headers.get("accept-ranges").is_none());
1570    }
1571
1572    // Multipart Tests
1573    #[test]
1574    fn test_range_filter_multipart() {
1575        fn gen_req() -> RequestHeader {
1576            let mut req: RequestHeader =
1577                RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1578            req.append_header("Range", "bytes=0-1,3-4,6-7").unwrap();
1579            req
1580        }
1581        fn gen_req_overlap_range() -> RequestHeader {
1582            let mut req: RequestHeader =
1583                RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1584            req.append_header("Range", "bytes=0-3,2-5,7-8").unwrap();
1585            req
1586        }
1587        fn gen_resp() -> ResponseHeader {
1588            let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1589            resp.append_header("Content-Length", "10").unwrap();
1590            resp
1591        }
1592
1593        // valid multipart range
1594        let req = gen_req();
1595        let mut resp = gen_resp();
1596        let result = range_header_filter(&req, &mut resp, None);
1597        let mut boundary_str = String::new();
1598
1599        assert!(matches!(result, RangeType::Multi(_)));
1600        if let RangeType::Multi(multi_part_info) = result {
1601            assert_eq!(multi_part_info.ranges.len(), 3);
1602            assert_eq!(multi_part_info.ranges[0], Range { start: 0, end: 2 });
1603            assert_eq!(multi_part_info.ranges[1], Range { start: 3, end: 5 });
1604            assert_eq!(multi_part_info.ranges[2], Range { start: 6, end: 8 });
1605            // Verify that multipart info has been set
1606            assert!(multi_part_info.content_type.is_some());
1607            assert_eq!(multi_part_info.total_length, 10);
1608            assert!(!multi_part_info.boundary.is_empty());
1609            boundary_str = multi_part_info.boundary;
1610        }
1611        assert_eq!(resp.status.as_u16(), 206);
1612        // Verify that the boundary is the same in the header and in the MultiRangeInfo
1613        assert_eq!(
1614            resp.headers.get("content-type").unwrap().to_str().unwrap(),
1615            format!("multipart/byteranges; boundary={boundary_str}")
1616        );
1617        assert!(resp.headers.get("content-range").is_none()); // Content-Range is stripped for multipart
1618        assert!(resp.headers.get("accept-ranges").is_none());
1619
1620        // overlapping range, multipart range is declined
1621        let req = gen_req_overlap_range();
1622        let mut resp = gen_resp();
1623        let result = range_header_filter(&req, &mut resp, None);
1624
1625        assert!(matches!(result, RangeType::None));
1626        assert_eq!(resp.status.as_u16(), 200);
1627        assert!(resp.headers.get("content-type").is_none());
1628        assert_eq!(
1629            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1630            b"bytes"
1631        );
1632
1633        // bad multipart range
1634        let mut req = gen_req();
1635        req.insert_header("Range", "bytes=1-0, 12-9, 50-40")
1636            .unwrap();
1637        let mut resp = gen_resp();
1638        let result = range_header_filter(&req, &mut resp, None);
1639        assert!(matches!(result, RangeType::Invalid));
1640        assert_eq!(resp.status.as_u16(), 416);
1641        assert!(resp.headers.get("accept-ranges").is_none());
1642    }
1643
1644    #[test]
1645    fn test_if_range() {
1646        const DATE: &str = "Fri, 07 Jul 2023 22:03:29 GMT";
1647        const ETAG: &str = "\"1234\"";
1648
1649        fn gen_req() -> RequestHeader {
1650            let mut req = RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1651            req.append_header("Range", "bytes=0-1").unwrap();
1652            req
1653        }
1654        fn get_multipart_req() -> RequestHeader {
1655            let mut req = RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1656            _ = req.append_header("Range", "bytes=0-1,3-4,6-7");
1657            req
1658        }
1659        fn gen_resp() -> ResponseHeader {
1660            let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1661            resp.append_header("Content-Length", "10").unwrap();
1662            resp.append_header("Last-Modified", DATE).unwrap();
1663            resp.append_header("ETag", ETAG).unwrap();
1664            resp
1665        }
1666
1667        // matching Last-Modified date
1668        let mut req = gen_req();
1669        req.insert_header("If-Range", DATE).unwrap();
1670        let mut resp = gen_resp();
1671        assert_eq!(
1672            RangeType::new_single(0, 2),
1673            range_header_filter(&req, &mut resp, None)
1674        );
1675
1676        // non-matching date
1677        let mut req = gen_req();
1678        req.insert_header("If-Range", "Fri, 07 Jul 2023 22:03:25 GMT")
1679            .unwrap();
1680        let mut resp = gen_resp();
1681        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1682        assert_eq!(resp.status.as_u16(), 200);
1683        assert_eq!(
1684            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1685            b"bytes"
1686        );
1687
1688        // match ETag
1689        let mut req = gen_req();
1690        req.insert_header("If-Range", ETAG).unwrap();
1691        let mut resp = gen_resp();
1692        assert_eq!(
1693            RangeType::new_single(0, 2),
1694            range_header_filter(&req, &mut resp, None)
1695        );
1696        assert_eq!(resp.status.as_u16(), 206);
1697        assert!(resp.headers.get("accept-ranges").is_none());
1698
1699        // non-matching ETags do not result in range
1700        let mut req = gen_req();
1701        req.insert_header("If-Range", "\"4567\"").unwrap();
1702        let mut resp = gen_resp();
1703        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1704        assert_eq!(resp.status.as_u16(), 200);
1705        assert_eq!(
1706            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1707            b"bytes"
1708        );
1709
1710        let mut req = gen_req();
1711        req.insert_header("If-Range", "1234").unwrap();
1712        let mut resp = gen_resp();
1713        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1714        assert_eq!(resp.status.as_u16(), 200);
1715        assert_eq!(
1716            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1717            b"bytes"
1718        );
1719
1720        // multipart range with If-Range
1721        let mut req = get_multipart_req();
1722        req.insert_header("If-Range", DATE).unwrap();
1723        let mut resp = gen_resp();
1724        let result = range_header_filter(&req, &mut resp, None);
1725        assert!(matches!(result, RangeType::Multi(_)));
1726        assert_eq!(resp.status.as_u16(), 206);
1727        assert!(resp.headers.get("accept-ranges").is_none());
1728
1729        // multipart with matching ETag
1730        let req = get_multipart_req();
1731        let mut resp = gen_resp();
1732        assert!(matches!(
1733            range_header_filter(&req, &mut resp, None),
1734            RangeType::Multi(_)
1735        ));
1736
1737        // multipart with non-matching If-Range
1738        let mut req = get_multipart_req();
1739        req.insert_header("If-Range", "\"wrong\"").unwrap();
1740        let mut resp = gen_resp();
1741        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1742        assert_eq!(resp.status.as_u16(), 200);
1743        assert_eq!(
1744            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1745            b"bytes"
1746        );
1747    }
1748
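    /// Applies a parsed `RangeType` to a streaming response body.
    ///
    /// `current` is the offset into the full body of the next incoming byte,
    /// `multipart_idx` is the index of the first multipart range that has not
    /// been fully emitted yet, and `cache_multipart_idx` tracks which multipart
    /// range was most recently requested from the cache body reader via
    /// `seek_multipart`.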
1749    pub struct RangeBodyFilter {
1750        pub range: RangeType,
1751        current: usize,
1752        multipart_idx: Option<usize>,
1753        cache_multipart_idx: Option<usize>,
1754    }
1755
1756    impl Default for RangeBodyFilter {
1757        fn default() -> Self {
1758            Self::new()
1759        }
1760    }
1761
1762    impl RangeBodyFilter {
1763        pub fn new() -> Self {
1764            RangeBodyFilter {
1765                range: RangeType::None,
1766                current: 0,
1767                multipart_idx: None,
1768                cache_multipart_idx: None,
1769            }
1770        }
1771
1772        pub fn new_range(range: RangeType) -> Self {
1773            RangeBodyFilter {
1774                multipart_idx: matches!(range, RangeType::Multi(_)).then_some(0),
1775                range,
1776                ..Default::default()
1777            }
1778        }
1779
1780        pub fn is_multipart_range(&self) -> bool {
1781            matches!(self.range, RangeType::Multi(_))
1782        }
1783
1784        /// Whether we should expect the cache body reader to seek again
1785        /// for a different range.
1786        pub fn should_cache_seek_again(&self) -> bool {
1787            match &self.range {
1788                RangeType::Multi(multipart_info) => self
1789                    .cache_multipart_idx
1790                    .is_some_and(|idx| idx != multipart_info.ranges.len() - 1),
1791                _ => false,
1792            }
1793        }
1794
1795        /// Returns the next multipart range to seek for the cache body reader.
1796        pub fn next_cache_multipart_range(&mut self) -> Range<usize> {
1797            match &self.range {
1798                RangeType::Multi(multipart_info) => {
1799                    match self.cache_multipart_idx.as_mut() {
1800                        Some(v) => *v += 1,
1801                        None => self.cache_multipart_idx = Some(0),
1802                    }
1803                    let cache_multipart_idx = self.cache_multipart_idx.expect("set above");
1804                    let multipart_idx = self.multipart_idx.expect("must be set on multirange");
1805                    // NOTE: currently this assumes once we start seeking multipart from the hit
1806                    // handler, it will continue to return can_seek_multipart true.
1807                    assert_eq!(multipart_idx, cache_multipart_idx,
1808                        "cache multipart idx should match multipart idx, or there is a hit handler bug");
1809                    multipart_info.ranges[cache_multipart_idx].clone()
1810                }
1811                _ => panic!("tried to advance multipart idx on non-multipart range"),
1812            }
1813        }
1814
1815        pub fn set_current_cursor(&mut self, current: usize) {
1816            self.current = current;
1817        }
1818
1819        pub fn set(&mut self, range: RangeType) {
1820            self.multipart_idx = matches!(range, RangeType::Multi(_)).then_some(0);
1821            self.range = range;
1822        }
1823
1824        // Emit final boundary footer for multipart requests
1825        pub fn finalize(&self, boundary: &str) -> Option<Bytes> {
1826            if let RangeType::Multi(_) = self.range {
1827                Some(Bytes::from(format!("\r\n--{boundary}--\r\n")))
1828            } else {
1829                None
1830            }
1831        }
1832
1833        pub fn filter_body(&mut self, data: Option<Bytes>) -> Option<Bytes> {
1834            match &self.range {
1835                RangeType::None => data,
1836                RangeType::Invalid => None,
1837                RangeType::Single(r) => {
1838                    let current = self.current;
1839                    self.current += data.as_ref().map_or(0, |d| d.len());
1840                    data.and_then(|d| Self::filter_range_data(r.start, r.end, current, d))
1841                }
1842
1843                RangeType::Multi(_) => {
1844                    let data = data?;
1845                    let current = self.current;
1846                    let data_len = data.len();
1847                    self.current += data_len;
1848                    self.filter_multi_range_body(data, current, data_len)
1849                }
1850            }
1851        }
1852
1853        fn filter_range_data(
1854            start: usize,
1855            end: usize,
1856            current: usize,
1857            data: Bytes,
1858        ) -> Option<Bytes> {
1859            if current + data.len() < start || current >= end {
1860                // if the current data is outside the desired range, just drop the data
1861                None
1862            } else if current >= start && current + data.len() <= end {
1863                // all data is within the slice
1864                Some(data)
1865            } else {
1866                // data:  current........current+data.len()
1867                // range: start...........end
1868                let slice_start = start.saturating_sub(current);
1869                let slice_end = std::cmp::min(data.len(), end - current);
1870                Some(data.slice(slice_start..slice_end))
1871            }
1872        }
1873
1874        // Returns the multipart header for a given range
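        // e.g. with a 16-char boundary B, range 2..7 of a 12-byte text/plain body
        // this would yield:
        // "\r\n--B\r\nContent-Type: text/plain\r\nContent-Range: bytes 2-6/12\r\n\r\n"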
1875        fn build_multipart_header(
1876            &self,
1877            range: &Range<usize>,
1878            boundary: &str,
1879            total_length: usize,
1880            content_type: Option<&str>,
1881        ) -> Bytes {
1882            Bytes::from(format!(
1883                "\r\n--{}\r\n{}Content-Range: bytes {}-{}/{}\r\n\r\n",
1884                boundary,
1885                content_type.map_or(String::new(), |ct| format!("Content-Type: {ct}\r\n")),
1886                range.start,
1887                range.end - 1,
1888                total_length
1889            ))
1890        }
1891
1892        // Return true if chunk includes the start of the given range
1893        fn current_chunk_includes_range_start(
1894            &self,
1895            range: &Range<usize>,
1896            current: usize,
1897            data_len: usize,
1898        ) -> bool {
1899            range.start >= current && range.start < current + data_len
1900        }
1901
1902        // Return true if chunk includes the end of the given range
1903        fn current_chunk_includes_range_end(
1904            &self,
1905            range: &Range<usize>,
1906            current: usize,
1907            data_len: usize,
1908        ) -> bool {
1909            range.end > current && range.end <= current + data_len
1910        }
1911
1912        fn filter_multi_range_body(
1913            &mut self,
1914            data: Bytes,
1915            current: usize,
1916            data_len: usize,
1917        ) -> Option<Bytes> {
1918            let mut result = BytesMut::new();
1919
1920            let RangeType::Multi(multi_part_info) = &self.range else {
1921                return None;
1922            };
1923
1924            let multipart_idx = self.multipart_idx.expect("must be set on multirange");
1925            let final_range = multi_part_info.ranges.last()?;
1926
1927            let (_, remaining_ranges) = multi_part_info.ranges.as_slice().split_at(multipart_idx);
1928            // NOTE: the current invariant is that the multipart info ranges are disjoint and ascending;
1929            // this code is incorrect if that invariant is not upheld
1930            for range in remaining_ranges {
1931                if let Some(sliced) =
1932                    Self::filter_range_data(range.start, range.end, current, data.clone())
1933                {
1934                    if self.current_chunk_includes_range_start(range, current, data_len) {
1935                        result.extend_from_slice(&self.build_multipart_header(
1936                            range,
1937                            multi_part_info.boundary.as_ref(),
1938                            multi_part_info.total_length,
1939                            multi_part_info.content_type.as_deref(),
1940                        ));
1941                    }
1942                    // Emit the actual data bytes
1943                    result.extend_from_slice(&sliced);
1944                    if self.current_chunk_includes_range_end(range, current, data_len) {
1945                        // If this was the last range, we should emit the final footer too
1946                        if range == final_range {
1947                            if let Some(final_chunk) = self.finalize(&multi_part_info.boundary) {
1948                                result.extend_from_slice(&final_chunk);
1949                            }
1950                        }
1951                        // done with this range
1952                        self.multipart_idx = Some(self.multipart_idx.expect("must be set") + 1);
1953                    }
1954                } else {
1955                    // no part of the data was within this range,
1956                    // so lower bound of this range (and remaining ranges) must be
1957                    // > current + data_len
1958                    break;
1959                }
1960            }
1961            if result.is_empty() {
1962                None
1963            } else {
1964                Some(result.freeze())
1965            }
1966        }
1967    }
1968
1969    #[test]
1970    fn test_range_body_filter_single() {
1971        let mut body_filter = RangeBodyFilter::new_range(RangeType::None);
1972        assert_eq!(body_filter.filter_body(Some("123".into())).unwrap(), "123");
1973
1974        let mut body_filter = RangeBodyFilter::new_range(RangeType::Invalid);
1975        assert!(body_filter.filter_body(Some("123".into())).is_none());
1976
1977        let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(0, 1));
1978        assert_eq!(body_filter.filter_body(Some("012".into())).unwrap(), "0");
1979        assert!(body_filter.filter_body(Some("345".into())).is_none());
1980
1981        let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(4, 6));
1982        assert!(body_filter.filter_body(Some("012".into())).is_none());
1983        assert_eq!(body_filter.filter_body(Some("345".into())).unwrap(), "45");
1984        assert!(body_filter.filter_body(Some("678".into())).is_none());
1985
1986        let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(1, 7));
1987        assert_eq!(body_filter.filter_body(Some("012".into())).unwrap(), "12");
1988        assert_eq!(body_filter.filter_body(Some("345".into())).unwrap(), "345");
1989        assert_eq!(body_filter.filter_body(Some("678".into())).unwrap(), "6");
1990    }
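    // A minimal end-to-end sketch of how the header filter and body filter are
    // combined for a single-range request (the test name and chunking are
    // illustrative only).
    #[test]
    fn test_range_filter_end_to_end_sketch() {
        let mut req = RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
        req.insert_header("Range", "bytes=2-5").unwrap();
        let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
        resp.append_header("Content-Length", "10").unwrap();

        // Rewrite the response header and obtain the parsed range.
        let range = range_header_filter(&req, &mut resp, None);
        assert_eq!(resp.status.as_u16(), 206);

        // Stream the body through the range filter chunk by chunk.
        let mut body_filter = RangeBodyFilter::new_range(range);
        let mut out = BytesMut::new();
        for chunk in ["0123", "4567", "89"] {
            if let Some(b) = body_filter.filter_body(Some(chunk.into())) {
                out.extend_from_slice(&b);
            }
        }
        assert_eq!(&out[..], b"2345");
    }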
1991
1992    #[test]
1993    fn test_range_body_filter_multipart() {
1994        // Test #1 - Test multipart ranges from 1 chunk
1995        let data = Bytes::from("0123456789");
1996        let ranges = vec![0..3, 6..9];
1997        let content_length = data.len();
1998        let mut body_filter = RangeBodyFilter::new();
1999        body_filter.set(RangeType::new_multi(ranges.clone()));
2000
2001        body_filter
2002            .range
2003            .update_multirange_info(content_length, None);
2004
2005        let multi_range_info = body_filter
2006            .range
2007            .get_multirange_info()
2008            .cloned()
2009            .expect("Multipart ranges should have a MultiRangeInfo struct");
2010
2011        // Pass the whole body in one chunk
2012        let output = body_filter.filter_body(Some(data)).unwrap();
2013        let footer = body_filter.finalize(&multi_range_info.boundary).unwrap();
2014
2015        // Convert to &str so that we can inspect the whole response
2016        let output_str = str::from_utf8(&output).unwrap();
2017        let final_boundary = str::from_utf8(&footer).unwrap();
2018        let boundary = &multi_range_info.boundary;
2019
2020        // Check part headers
2021        for (i, range) in ranges.iter().enumerate() {
2022            let header = &format!(
2023                "--{}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2024                boundary,
2025                range.start,
2026                range.end - 1,
2027                content_length
2028            );
2029            assert!(
2030                output_str.contains(header),
2031                "Missing part header {} in multipart body",
2032                i
2033            );
2034            // Check body matches
2035            let expected_body = &"0123456789"[range.clone()];
2036            assert!(
2037                output_str.contains(expected_body),
2038                "Missing body {} for range {:?}",
2039                expected_body,
2040                range
2041            )
2042        }
2043        // Check the final boundary footer
2044        assert_eq!(final_boundary, format!("\r\n--{}--\r\n", boundary));
2045
2046        // Test #2 - Test multipart ranges from multiple chunks
2047        let full_body = b"0123456789";
2048        let ranges = vec![0..2, 4..6, 8..9];
2049        let content_length = full_body.len();
2050        let content_type = "text/plain".to_string();
2051        let mut body_filter = RangeBodyFilter::new();
2052        body_filter.set(RangeType::new_multi(ranges.clone()));
2053
2054        body_filter
2055            .range
2056            .update_multirange_info(content_length, Some(content_type.clone()));
2057
2058        let multi_range_info = body_filter
2059            .range
2060            .get_multirange_info()
2061            .cloned()
2062            .expect("Multipart ranges should have a MultiRangeInfo struct");
2063
2064        // Split the body into 4 chunks
2065        let chunk1 = Bytes::from_static(b"012");
2066        let chunk2 = Bytes::from_static(b"345");
2067        let chunk3 = Bytes::from_static(b"678");
2068        let chunk4 = Bytes::from_static(b"9");
2069
2070        let mut collected_bytes = BytesMut::new();
2071        for chunk in [chunk1, chunk2, chunk3, chunk4] {
2072            if let Some(filtered) = body_filter.filter_body(Some(chunk)) {
2073                collected_bytes.extend_from_slice(&filtered);
2074            }
2075        }
2076        if let Some(final_boundary) = body_filter.finalize(&multi_range_info.boundary) {
2077            collected_bytes.extend_from_slice(&final_boundary);
2078        }
2079
2080        let output_str = str::from_utf8(&collected_bytes).unwrap();
2081        let boundary = multi_range_info.boundary;
2082
2083        for (i, range) in ranges.iter().enumerate() {
2084            let header = &format!(
2085                "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2086                boundary,
2087                content_type,
2088                range.start,
2089                range.end - 1,
2090                content_length
2091            );
2092            let expected_body = &full_body[range.clone()];
2093            let expected_output = format!("{}{}", header, str::from_utf8(expected_body).unwrap());
2094
2095            assert!(
2096                output_str.contains(&expected_output),
2097                "Missing or malformed part {} in multipart body. \n Expected: \n{}\n Got: \n{}",
2098                i,
2099                expected_output,
2100                output_str
2101            )
2102        }
2103
2104        assert!(
2105            output_str.ends_with(&format!("\r\n--{}--\r\n", boundary)),
2106            "Missing final boundary"
2107        );
2108
2109        // Test #3 - Test multipart ranges from multiple chunks, with ranges spanning chunks
2110        let full_body = b"abcdefghijkl";
2111        let ranges = vec![2..7, 9..11];
2112        let content_length = full_body.len();
2113        let content_type = "application/octet-stream".to_string();
2114        let mut body_filter = RangeBodyFilter::new();
2115        body_filter.set(RangeType::new_multi(ranges.clone()));
2116
2117        body_filter
2118            .range
2119            .update_multirange_info(content_length, Some(content_type.clone()));
2120
2121        let multi_range_info = body_filter
2122            .range
2124            .get_multirange_info()
2125            .cloned()
2126            .expect("Multipart ranges should have a MultiRangeInfo struct");
2127
2128        // Split the body into 4 chunks
2129        let chunk1 = Bytes::from_static(b"abc");
2130        let chunk2 = Bytes::from_static(b"def");
2131        let chunk3 = Bytes::from_static(b"ghi");
2132        let chunk4 = Bytes::from_static(b"jkl");
2133
2134        let mut collected_bytes = BytesMut::new();
2135        for chunk in [chunk1, chunk2, chunk3, chunk4] {
2136            if let Some(filtered) = body_filter.filter_body(Some(chunk)) {
2137                collected_bytes.extend_from_slice(&filtered);
2138            }
2139        }
2140        if let Some(final_boundary) = body_filter.finalize(&multi_range_info.boundary) {
2141            collected_bytes.extend_from_slice(&final_boundary);
2142        }
2143
2144        let output_str = str::from_utf8(&collected_bytes).unwrap();
2145        let boundary = &multi_range_info.boundary;
2146
2147        let header1 = &format!(
2148            "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2149            boundary,
2150            content_type,
2151            ranges[0].start,
2152            ranges[0].end - 1,
2153            content_length
2154        );
2155        let header2 = &format!(
2156            "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2157            boundary,
2158            content_type,
2159            ranges[1].start,
2160            ranges[1].end - 1,
2161            content_length
2162        );
2163
2164        assert!(output_str.contains(header1));
2165        assert!(output_str.contains(header2));
2166
2167        let expected_body_slices = ["cdefg", "jk"];
2168
2169        assert!(
2170            output_str.contains(expected_body_slices[0]),
2171            "Missing expected sliced body {}",
2172            expected_body_slices[0]
2173        );
2174
2175        assert!(
2176            output_str.contains(expected_body_slices[1]),
2177            "Missing expected sliced body {}",
2178            expected_body_slices[1]
2179        );
2180
2181        assert!(
2182            output_str.ends_with(&format!("\r\n--{}--\r\n", boundary)),
2183            "Missing final boundary"
2184        );
2185    }
2186}
2187
2188// a state machine for proxy logic to tell when to use cache in the case of
2189// miss/revalidation/error.
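//
// Typical transitions: a cache hit goes Off -> CacheHeader -> CacheBody(_) -> Done,
// while a miss that is being admitted to cache goes Off -> CacheHeaderMiss ->
// CacheBodyMiss(_) -> DoneMiss (or through the HeaderOnly variants when only
// headers are served).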
2190#[derive(Debug)]
2191pub(crate) enum ServeFromCache {
2192    // not using cache
2193    Off,
2194    // should serve cache header
2195    CacheHeader,
2196    // should serve cache header only
2197    CacheHeaderOnly,
2198    // should serve cache header only but upstream response should be admitted to cache
2199    CacheHeaderOnlyMiss,
2200    // should serve cache body; the bool indicates whether we still need to seek the hit handler
2201    CacheBody(bool),
2202    // should serve cache header but upstream response should be admitted to cache
2203    // This is the starting state for misses, which go to CacheBodyMiss or
2204    // CacheHeaderOnlyMiss before ending at DoneMiss
2205    CacheHeaderMiss,
2206    // should serve cache body but upstream response should be admitted to cache; the bool indicates whether we still need to seek the miss body reader
2207    CacheBodyMiss(bool),
2208    // done serving cache body
2209    Done,
2210    // done serving cache body, but upstream response should continue to be admitted to cache
2211    DoneMiss,
2212}
2213
2214impl ServeFromCache {
2215    pub fn new() -> Self {
2216        Self::Off
2217    }
2218
2219    pub fn is_on(&self) -> bool {
2220        !matches!(self, Self::Off)
2221    }
2222
2223    pub fn is_miss(&self) -> bool {
2224        matches!(
2225            self,
2226            Self::CacheHeaderMiss
2227                | Self::CacheHeaderOnlyMiss
2228                | Self::CacheBodyMiss(_)
2229                | Self::DoneMiss
2230        )
2231    }
2232
2233    pub fn is_miss_header(&self) -> bool {
2234        // NOTE: this check is for whether miss was just enabled, so it excludes
2235        // HeaderOnlyMiss
2236        matches!(self, Self::CacheHeaderMiss)
2237    }
2238
2239    pub fn is_miss_body(&self) -> bool {
2240        matches!(self, Self::CacheBodyMiss(_))
2241    }
2242
2243    pub fn should_discard_upstream(&self) -> bool {
2244        self.is_on() && !self.is_miss()
2245    }
2246
2247    pub fn should_send_to_downstream(&self) -> bool {
2248        !self.is_on()
2249    }
2250
2251    pub fn enable(&mut self) {
2252        *self = Self::CacheHeader;
2253    }
2254
2255    pub fn enable_miss(&mut self) {
2256        if !self.is_on() {
2257            *self = Self::CacheHeaderMiss;
2258        }
2259    }
2260
2261    pub fn enable_header_only(&mut self) {
2262        match self {
2263            Self::CacheBody(_) => *self = Self::Done, // TODO: make sure no body is read yet
2264            Self::CacheBodyMiss(_) => *self = Self::DoneMiss,
2265            _ => {
2266                if self.is_miss() {
2267                    *self = Self::CacheHeaderOnlyMiss;
2268                } else {
2269                    *self = Self::CacheHeaderOnly;
2270                }
2271            }
2272        }
2273    }
2274
2275    // This function is (best effort) cancel-safe to be used in select
2276    pub async fn next_http_task(
2277        &mut self,
2278        cache: &mut HttpCache,
2279        range: &mut RangeBodyFilter,
2280    ) -> Result<HttpTask> {
2281        if !cache.enabled() {
2282            // Cache is disabled due to internal error
2283            // TODO: if nothing is sent to eyeball yet, figure out a way to recover by
2284            // fetching from upstream
2285            return Error::e_explain(InternalError, "Cache disabled");
2286        }
2287        match self {
2288            Self::Off => panic!("ProxyUseCache not enabled"),
2289            Self::CacheHeader => {
2290                *self = Self::CacheBody(true);
2291                Ok(HttpTask::Header(cache_hit_header(cache), false)) // false for now
2292            }
2293            Self::CacheHeaderMiss => {
2294                *self = Self::CacheBodyMiss(true);
2295                Ok(HttpTask::Header(cache_hit_header(cache), false)) // false for now
2296            }
2297            Self::CacheHeaderOnly => {
2298                *self = Self::Done;
2299                Ok(HttpTask::Header(cache_hit_header(cache), true))
2300            }
2301            Self::CacheHeaderOnlyMiss => {
2302                *self = Self::DoneMiss;
2303                Ok(HttpTask::Header(cache_hit_header(cache), true))
2304            }
2305            Self::CacheBody(should_seek) => {
2306                log::trace!("cache body should seek: {should_seek}");
2307                if *should_seek {
2308                    self.maybe_seek_hit_handler(cache, range)?;
2309                }
2310                loop {
2311                    if let Some(b) = cache.hit_handler().read_body().await? {
2312                        return Ok(HttpTask::Body(Some(b), false)); // false for now
2313                    }
2314                    // EOF from the hit handler for the requested body
2315                    // if multipart, then seek again
2316                    if range.should_cache_seek_again() {
2317                        self.maybe_seek_hit_handler(cache, range)?;
2318                    } else {
2319                        *self = Self::Done;
2320                        return Ok(HttpTask::Done);
2321                    }
2322                }
2323            }
2324            Self::CacheBodyMiss(should_seek) => {
2325                if *should_seek {
2326                    self.maybe_seek_miss_handler(cache, range)?;
2327                }
2328                // safety: callers of enable_miss() only call it if the async_body_reader exists
2329                loop {
2330                    if let Some(b) = cache.miss_body_reader().unwrap().read_body().await? {
2331                        return Ok(HttpTask::Body(Some(b), false)); // false for now
2332                    } else {
2333                        // EOF from the miss body reader for the requested body
2334                        // if multipart, then seek again
2335                        if range.should_cache_seek_again() {
2336                            self.maybe_seek_miss_handler(cache, range)?;
2337                        } else {
2338                            *self = Self::DoneMiss;
2339                            return Ok(HttpTask::Done);
2340                        }
2341                    }
2342                }
2343            }
2344            Self::Done => Ok(HttpTask::Done),
2345            Self::DoneMiss => Ok(HttpTask::Done),
2346        }
2347    }
2348
2349    fn maybe_seek_miss_handler(
2350        &mut self,
2351        cache: &mut HttpCache,
2352        range_filter: &mut RangeBodyFilter,
2353    ) -> Result<()> {
2354        match &range_filter.range {
2355            RangeType::Single(range) => {
2356                // safety: called only if the async_body_reader exists
2357                if cache.miss_body_reader().unwrap().can_seek() {
2358                    cache
2359                        .miss_body_reader()
2360                        // safety: called only if the async_body_reader exists
2361                        .unwrap()
2362                        .seek(range.start, Some(range.end))
2363                        .or_err(InternalError, "cannot seek miss handler")?;
2364                    // Because the miss body reader is seeking, we no longer need the
2365                    // RangeBodyFilter's help to return the requested byte range.
2366                    range_filter.range = RangeType::None;
2367                }
2368            }
2369            RangeType::Multi(_info) => {
2370                // safety: called only if the async_body_reader exists
2371                if cache.miss_body_reader().unwrap().can_seek_multipart() {
2372                    let range = range_filter.next_cache_multipart_range();
2373                    cache
2374                        .miss_body_reader()
2375                        .unwrap()
2376                        .seek_multipart(range.start, Some(range.end))
2377                        .or_err(InternalError, "cannot seek miss handler for multirange")?;
2378                    // we still need RangeBodyFilter's help to transform the byte
2379                    // range into a multipart response.
2380                    range_filter.set_current_cursor(range.start);
2381                }
2382            }
2383            _ => {}
2384        }
2385
2386        *self = Self::CacheBodyMiss(false);
2387        Ok(())
2388    }
2389
2390    fn maybe_seek_hit_handler(
2391        &mut self,
2392        cache: &mut HttpCache,
2393        range_filter: &mut RangeBodyFilter,
2394    ) -> Result<()> {
2395        match &range_filter.range {
2396            RangeType::Single(range) => {
2397                if cache.hit_handler().can_seek() {
2398                    cache
2399                        .hit_handler()
2400                        .seek(range.start, Some(range.end))
2401                        .or_err(InternalError, "cannot seek hit handler")?;
2402                    // Because the hit handler is seeking, we no longer need the
2403                    // RangeBodyFilter's help to return the requested byte range.
2404                    range_filter.range = RangeType::None;
2405                }
2406            }
2407            RangeType::Multi(_info) => {
2408                if cache.hit_handler().can_seek_multipart() {
2409                    let range = range_filter.next_cache_multipart_range();
2410                    cache
2411                        .hit_handler()
2412                        .seek_multipart(range.start, Some(range.end))
2413                        .or_err(InternalError, "cannot seek hit handler for multirange")?;
2414                    // we still need RangeBodyFilter's help to transform the byte
2415                    // range into a multipart response.
2416                    range_filter.set_current_cursor(range.start);
2417                }
2418            }
2419            _ => {}
2420        }
2421        *self = Self::CacheBody(false);
2422        Ok(())
2423    }
2424}