Skip to main content

pingora_proxy/
proxy_cache.rs

// Copyright 2026 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use super::*;
16use http::header::{CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING};
17use http::{Method, StatusCode};
18use pingora_cache::key::CacheHashKey;
19use pingora_cache::lock::LockStatus;
20use pingora_cache::max_file_size::ERR_RESPONSE_TOO_LARGE;
21use pingora_cache::{ForcedFreshness, HitHandler, HitStatus, RespCacheable::*};
22use pingora_core::protocols::http::conditional_filter::to_304;
23use pingora_core::protocols::http::v1::common::header_value_content_length;
24use pingora_core::ErrorType;
25use range_filter::RangeBodyFilter;
26use std::time::SystemTime;
27
28impl<SV, C> HttpProxy<SV, C>
29where
30    C: custom::Connector,
31{
    /// Request-phase cache logic: decides whether this request can be answered from cache.
    ///
    /// In order, this:
    /// 1. runs `request_cache_filter` (an error is only logged),
    /// 2. computes the cache key via `cache_key_callback` when the cache is enabled
    ///    (a callback error disables the cache for this request),
    /// 3. short-circuits PURGE requests into `proxy_purge`,
    /// 4. bypasses the lookup when `cacheable_prediction()` predicts an uncacheable response,
    /// 5. loops over `cache_lookup()`, resolving `Vary` secondary slots, running
    ///    `cache_hit_filter`, waiting on cache locks and handling stale assets, until the
    ///    request is either served from cache or handed back to go upstream.
    ///
    /// `self` is taken as `&Arc<Self>` so that a clone can be handed to `SubrequestSpawner`
    /// for the background revalidation of stale assets.
    ///
    /// Returns `None` when the caller should continue proxying to upstream, and
    /// `Some((reusable, err))` when the response was served from cache here: `reusable`
    /// tells whether the downstream session can be reused and `err` is any error hit while
    /// serving. For PURGE requests, the result of `proxy_purge` is returned unchanged.
    pub(crate) async fn proxy_cache(
        self: &Arc<Self>,
        session: &mut Session,
        ctx: &mut SV::CTX,
    ) -> Option<(bool, Option<Box<Error>>)>
    // None: continue to proxy, Some: return
    where
        SV: ProxyHttp + Send + Sync + 'static,
        SV::CTX: Send + Sync,
    {
        // Cache logic request phase
        if let Err(e) = self.inner.request_cache_filter(session, ctx) {
            // TODO: handle this error
            // The error is only logged; the request continues.
            warn!(
                "Fail to request_cache_filter: {e}, {}",
                self.inner.request_summary(session, ctx)
            );
        }

        // cache key logic, should this be part of request_cache_filter?
        if session.cache.enabled() {
            match self.inner.cache_key_callback(session, ctx) {
                Ok(key) => {
                    session.cache.set_cache_key(key);
                }
                Err(e) => {
                    // TODO: handle this error
                    // Without a key no lookup can happen, so caching is disabled for this request.
                    // NOTE(review): StorageError is an unusual reason for a key-callback
                    // failure; confirm this is the intended NoCacheReason.
                    session.cache.disable(NoCacheReason::StorageError);
                    warn!(
                        "Fail to cache_key_callback: {e}, {}",
                        self.inner.request_summary(session, ctx)
                    );
                }
            }
        }

        // cache purge logic: PURGE short-circuits rest of request
        if self.inner.is_purge(session, ctx) {
            return self.proxy_purge(session, ctx).await;
        }

        // bypass cache lookup if we predict to be uncacheable
        if session.cache.enabled() && !session.cache.cacheable_prediction() {
            session.cache.bypass();
        }

        // If the cache is not enabled at this point, skip the lookup and go upstream.
        if !session.cache.enabled() {
            return None;
        }

        // cache lookup logic
        // Each iteration performs one lookup. `continue` re-runs the lookup (after a Vary
        // key update, or after waiting on a cache lock); `break` ends the loop and its value
        // becomes this function's result.
        loop {
            // for cache lock, TODO: cap the max number of loops
            match session.cache.cache_lookup().await {
                Ok(res) => {
                    // Stays None when the lookup found no asset.
                    let mut hit_status_opt = None;
                    if let Some((mut meta, mut handler)) = res {
                        // Vary logic
                        // Because this branch can be called multiple times in a loop, and we only
                        // need to update the vary once, check if variance is already set to
                        // prevent unnecessary vary lookups.
                        let cache_key = session.cache.cache_key();
                        if let Some(variance) = cache_key.variance_bin() {
                            // We've looked up a secondary slot.
                            // Ad hoc double check that the variance found is the variance we want.
                            if Some(variance) != meta.variance() {
                                warn!("Cache variance mismatch, {variance:?}, {cache_key:?}");
                                session.cache.disable(NoCacheReason::InternalError);
                                break None;
                            }
                        } else {
                            // Basic cache key; either variance is off, or this is the primary slot.
                            let req_header = session.req_header();
                            let variance = self.inner.cache_vary_filter(&meta, ctx, req_header);
                            if let Some(variance) = variance {
                                // Variance is on. This is the primary slot.
                                if !session.cache.cache_vary_lookup(variance, &meta) {
                                    // This wasn't the desired variant. The cache key variance was
                                    // updated, so look up again to get the desired variant, which
                                    // would be in a secondary slot.
                                    continue;
                                }
                            } // else: vary is not in use
                        }

                        // Either no variance, or the current handler targets the correct variant.

                        // hit
                        // TODO: maybe round and/or cache now()
                        let is_fresh = meta.is_fresh(SystemTime::now());
                        // check if we should force expire or force miss
                        let hit_status = match self
                            .inner
                            .cache_hit_filter(session, &meta, &mut handler, is_fresh, ctx)
                            .await
                        {
                            Err(e) => {
                                error!(
                                    "Failed to filter cache hit: {e}, {}",
                                    self.inner.request_summary(session, ctx)
                                );
                                // this return value will cause us to fetch from upstream
                                HitStatus::FailedHitFilter
                            }
                            Ok(None) => {
                                if is_fresh {
                                    HitStatus::Fresh
                                } else {
                                    HitStatus::Expired
                                }
                            }
                            Ok(Some(ForcedFreshness::ForceExpired)) => {
                                // a force-expired asset should not be served as stale
                                // because force expire is usually to remove data
                                meta.disable_serve_stale();
                                HitStatus::ForceExpired
                            }
                            Ok(Some(ForcedFreshness::ForceMiss)) => HitStatus::ForceMiss,
                            Ok(Some(ForcedFreshness::ForceFresh)) => HitStatus::Fresh,
                        };

                        hit_status_opt = Some(hit_status);

                        // init cache for hit / stale
                        session.cache.cache_found(meta, handler, hit_status);
                    }

                    // No asset was found, or the hit status says to treat it as a miss.
                    if hit_status_opt.is_none_or(HitStatus::is_treated_as_miss) {
                        // cache miss
                        if session.cache.is_cache_locked() {
                            // Another request is filling the cache; try waiting until that's done and retry.
                            // handle_lock_status decides between retrying the lookup and going upstream.
                            let lock_status = session.cache.cache_lock_wait().await;
                            if self.handle_lock_status(session, ctx, lock_status) {
                                continue;
                            } else {
                                break None;
                            }
                        } else {
                            self.inner.cache_miss(session, ctx);
                            break None;
                        }
                    }

                    // Safe because an empty hit status would have broken out
                    // in the block above
                    let hit_status = hit_status_opt.expect("None case handled as miss");

                    if !hit_status.is_fresh() {
                        // expired or force expired asset
                        if session.cache.is_cache_locked() {
                            // first if this is the sub request for the background cache update
                            if let Some(write_lock) = session
                                .subrequest_ctx
                                .as_mut()
                                .and_then(|ctx| ctx.take_write_lock())
                            {
                                // Put the write lock in the request
                                session.cache.set_write_lock(write_lock);
                                session.cache.tag_as_subrequest();
                                // and then let it go to upstream
                                break None;
                            }
                            // Otherwise, serve stale while another request updates the asset,
                            // if both the cache and the should_serve_stale hook allow it.
                            let will_serve_stale = session.cache.can_serve_stale_updating()
                                && self.inner.should_serve_stale(session, ctx, None);
                            if !will_serve_stale {
                                let lock_status = session.cache.cache_lock_wait().await;
                                if self.handle_lock_status(session, ctx, lock_status) {
                                    continue;
                                } else {
                                    break None;
                                }
                            }
                            // else continue to serve stale
                            session.cache.set_stale_updating();
                        } else if session.cache.is_cache_lock_writer() {
                            // stale while revalidate logic for the writer
                            let will_serve_stale = session.cache.can_serve_stale_updating()
                                && self.inner.should_serve_stale(session, ctx, None);
                            if will_serve_stale {
                                // create a background thread to do the actual update
                                // the subrequest handle is only None by this phase in unit tests
                                // that don't go through process_new_http
                                // The write lock is moved into the background subrequest's context.
                                let (permit, cache_lock) = session.cache.take_write_lock();
                                SubrequestSpawner::new(self.clone()).spawn_background_subrequest(
                                    session.as_ref(),
                                    subrequest::Ctx::builder()
                                        .cache_write_lock(
                                            cache_lock,
                                            session.cache.cache_key().clone(),
                                            permit,
                                        )
                                        .build(),
                                );
                                // continue to serve stale for this request
                                session.cache.set_stale_updating();
                            } else {
                                // return to fetch from upstream
                                break None;
                            }
                        } else {
                            // return to fetch from upstream
                            break None;
                        }
                    }

                    // Fresh hit, or a stale hit being served while it is updated.
                    let (reuse, err) = self.proxy_cache_hit(session, ctx).await;
                    if let Some(e) = err.as_ref() {
                        error!(
                            "Fail to serve cache: {e}, {}",
                            self.inner.request_summary(session, ctx)
                        );
                    }
                    // response is served from cache, exit
                    break Some((reuse, err));
                }
                Err(e) => {
                    // Allow cache miss to fill cache even if cache lookup errors
                    // this is mostly to support backward incompatible metadata update
                    // TODO: check error types
                    // session.cache.disable();
                    self.inner.cache_miss(session, ctx);
                    warn!(
                        "Fail to cache lookup: {e}, {}",
                        self.inner.request_summary(session, ctx)
                    );
                    break None;
                }
            }
        }
    }
262
263    // return bool: server_session can be reused, and error if any
264    pub(crate) async fn proxy_cache_hit(
265        &self,
266        session: &mut Session,
267        ctx: &mut SV::CTX,
268    ) -> (bool, Option<Box<Error>>)
269    where
270        SV: ProxyHttp + Send + Sync,
271        SV::CTX: Send + Sync,
272    {
273        use range_filter::*;
274
275        let seekable = session.cache.hit_handler().can_seek();
276        let mut header = cache_hit_header(&session.cache);
277
278        let req = session.req_header();
279
280        let not_modified = match self.inner.cache_not_modified_filter(session, &header, ctx) {
281            Ok(not_modified) => not_modified,
282            Err(e) => {
283                // fail open if cache_not_modified_filter errors,
284                // just return the whole original response
285                warn!(
286                    "Failed to run cache not modified filter: {e}, {}",
287                    self.inner.request_summary(session, ctx)
288                );
289                false
290            }
291        };
292        if not_modified {
293            to_304(&mut header);
294        }
295        let header_only = not_modified || req.method == http::method::Method::HEAD;
296
297        // process range header if the cache storage supports seek
298        let range_type = if seekable && !session.ignore_downstream_range {
299            self.inner.range_header_filter(session, &mut header, ctx)
300        } else {
301            RangeType::None
302        };
303
304        // return a 416 with an empty body for simplicity
305        let header_only = header_only || matches!(range_type, RangeType::Invalid);
306        debug!("header: {header:?}");
307
308        // TODO: use ProxyUseCache to replace the logic below
309        match self.inner.response_filter(session, &mut header, ctx).await {
310            Ok(_) => {
311                if let Err(e) = session
312                    .downstream_modules_ctx
313                    .response_header_filter(&mut header, header_only)
314                    .await
315                {
316                    error!(
317                        "Failed to run downstream modules response header filter in hit: {e}, {}",
318                        self.inner.request_summary(session, ctx)
319                    );
320                    session
321                        .as_mut()
322                        .respond_error(500)
323                        .await
324                        .unwrap_or_else(|e| {
325                            error!("failed to send error response to downstream: {e}");
326                        });
327                    // we have not write anything dirty to downstream, it is still reusable
328                    return (true, Some(e));
329                }
330
331                if let Err(e) = session
332                    .as_mut()
333                    .write_response_header(header)
334                    .await
335                    .map_err(|e| e.into_down())
336                {
337                    // downstream connection is bad already
338                    return (false, Some(e));
339                }
340            }
341            Err(e) => {
342                error!(
343                    "Failed to run response filter in hit: {e}, {}",
344                    self.inner.request_summary(session, ctx)
345                );
346                session
347                    .as_mut()
348                    .respond_error(500)
349                    .await
350                    .unwrap_or_else(|e| {
351                        error!("failed to send error response to downstream: {e}");
352                    });
353                // we have not write anything dirty to downstream, it is still reusable
354                return (true, Some(e));
355            }
356        }
357        debug!("finished sending cached header to downstream");
358
359        // If the function returns an Err, there was an issue seeking from the hit handler.
360        //
361        // Returning false means that no seeking or state change was done, either because the
362        // hit handler doesn't support the seek or because multipart doesn't apply.
363        fn seek_multipart(
364            hit_handler: &mut HitHandler,
365            range_filter: &mut RangeBodyFilter,
366        ) -> Result<bool> {
367            if !range_filter.is_multipart_range() || !hit_handler.can_seek_multipart() {
368                return Ok(false);
369            }
370            let r = range_filter.next_cache_multipart_range();
371            hit_handler.seek_multipart(r.start, Some(r.end))?;
372            // we still need RangeBodyFilter's help to transform the byte
373            // range into a multipart response.
374            range_filter.set_current_cursor(r.start);
375            Ok(true)
376        }
377
378        if !header_only {
379            let mut maybe_range_filter = match &range_type {
380                RangeType::Single(r) => {
381                    if session.cache.hit_handler().can_seek() {
382                        if let Err(e) = session.cache.hit_handler().seek(r.start, Some(r.end)) {
383                            return (false, Some(e));
384                        }
385                        None
386                    } else {
387                        Some(RangeBodyFilter::new_range(range_type.clone()))
388                    }
389                }
390                RangeType::Multi(_) => {
391                    let mut range_filter = RangeBodyFilter::new_range(range_type.clone());
392                    if let Err(e) = seek_multipart(session.cache.hit_handler(), &mut range_filter) {
393                        return (false, Some(e));
394                    }
395                    Some(range_filter)
396                }
397                RangeType::Invalid => unreachable!(),
398                RangeType::None => None,
399            };
400            loop {
401                match session.cache.hit_handler().read_body().await {
402                    Ok(raw_body) => {
403                        let end = raw_body.is_none();
404
405                        if end {
406                            if let Some(range_filter) = maybe_range_filter.as_mut() {
407                                if range_filter.should_cache_seek_again() {
408                                    let e = match seek_multipart(
409                                        session.cache.hit_handler(),
410                                        range_filter,
411                                    ) {
412                                        Ok(true) => {
413                                            // called seek(), read again
414                                            continue;
415                                        }
416                                        Ok(false) => {
417                                            // body reader can no longer seek multipart,
418                                            // but cache wants to continue seeking
419                                            // the body will just end in this case if we pass the
420                                            // None through
421                                            // (TODO: how might hit handlers want to recover from
422                                            // this situation)?
423                                            Error::explain(
424                                                InternalError,
425                                                "hit handler cannot seek for multipart again",
426                                            )
427                                            // the body will just end in this case.
428                                        }
429                                        Err(e) => e,
430                                    };
431                                    return (false, Some(e));
432                                }
433                            }
434                        }
435
436                        let mut body = if let Some(range_filter) = maybe_range_filter.as_mut() {
437                            range_filter.filter_body(raw_body)
438                        } else {
439                            raw_body
440                        };
441
442                        match self
443                            .inner
444                            .response_body_filter(session, &mut body, end, ctx)
445                        {
446                            Ok(Some(duration)) => {
447                                trace!("delaying response for {duration:?}");
448                                time::sleep(duration).await;
449                            }
450                            Ok(None) => { /* continue */ }
451                            Err(e) => {
452                                // body is being sent, don't treat downstream as reusable
453                                return (false, Some(e));
454                            }
455                        }
456
457                        if let Err(e) = session
458                            .downstream_modules_ctx
459                            .response_body_filter(&mut body, end)
460                        {
461                            // body is being sent, don't treat downstream as reusable
462                            return (false, Some(e));
463                        }
464
465                        if !end && body.as_ref().is_none_or(|b| b.is_empty()) {
466                            // Don't write empty body which will end session,
467                            // still more hit handler bytes to read
468                            continue;
469                        }
470
471                        // write to downstream
472                        let b = body.unwrap_or_default();
473                        if let Err(e) = session
474                            .as_mut()
475                            .write_response_body(b, end)
476                            .await
477                            .map_err(|e| e.into_down())
478                        {
479                            return (false, Some(e));
480                        }
481                        if end {
482                            break;
483                        }
484                    }
485                    Err(e) => return (false, Some(e)),
486                }
487            }
488        }
489
490        if let Err(e) = session.cache.finish_hit_handler().await {
491            warn!("Error during finish_hit_handler: {}", e);
492        }
493
494        match session.as_mut().finish_body().await {
495            Ok(_) => {
496                debug!("finished sending cached body to downstream");
497                (true, None)
498            }
499            Err(e) => (false, Some(e)),
500        }
501    }
502
503    /* Downstream revalidation, only needed when cache is on because otherwise origin
504     * will handle it */
505    pub(crate) fn downstream_response_conditional_filter(
506        &self,
507        use_cache: &mut ServeFromCache,
508        session: &Session,
509        resp: &mut ResponseHeader,
510        ctx: &mut SV::CTX,
511    ) where
512        SV: ProxyHttp,
513    {
514        // TODO: range
515        let req = session.req_header();
516
517        let not_modified = match self.inner.cache_not_modified_filter(session, resp, ctx) {
518            Ok(not_modified) => not_modified,
519            Err(e) => {
520                // fail open if cache_not_modified_filter errors,
521                // just return the whole original response
522                warn!(
523                    "Failed to run cache not modified filter: {e}, {}",
524                    self.inner.request_summary(session, ctx)
525                );
526                false
527            }
528        };
529
530        if not_modified {
531            to_304(resp);
532        }
533        let header_only = not_modified || req.method == http::method::Method::HEAD;
534        if header_only && use_cache.is_on() {
535            // tell cache to stop serving downstream after yielding header
536            // (misses will continue to allow admitting upstream into cache)
537            use_cache.enable_header_only();
538        }
539    }
540
    // TODO: cache upstream header filter to add/remove headers
542
543    pub(crate) async fn cache_http_task(
544        &self,
545        session: &mut Session,
546        task: &HttpTask,
547        ctx: &mut SV::CTX,
548        serve_from_cache: &mut ServeFromCache,
549    ) -> Result<()>
550    where
551        SV: ProxyHttp + Send + Sync,
552        SV::CTX: Send + Sync,
553    {
554        if !session.cache.enabled() && !session.cache.bypassing() {
555            return Ok(());
556        }
557
558        match task {
559            HttpTask::Header(header, end_stream) => {
560                // decide if cacheable and create cache meta
561                // for now, skip 1xxs (should not affect response cache decisions)
562                // However 101 is an exception because it is the final response header
563                if header.status.is_informational()
564                    && header.status != StatusCode::SWITCHING_PROTOCOLS
565                {
566                    return Ok(());
567                }
568                match self.inner.response_cache_filter(session, header, ctx)? {
569                    Cacheable(meta) => {
570                        let mut fill_cache = true;
571                        if session.cache.bypassing() {
572                            // The cache might have been bypassed because the response exceeded the
573                            // maximum cacheable asset size. If that looks like the case (there
574                            // is a maximum file size configured and we don't know the content
575                            // length up front), attempting to re-enable the cache now would cause
576                            // the request to fail when the chunked response exceeds the maximum
577                            // file size again.
578                            if session.cache.max_file_size_bytes().is_some()
579                                && !meta.headers().contains_key(header::CONTENT_LENGTH)
580                            {
581                                session
582                                    .cache
583                                    .disable(NoCacheReason::PredictedResponseTooLarge);
584                                return Ok(());
585                            }
586
587                            session.cache.response_became_cacheable();
588
589                            if session.req_header().method == Method::GET
590                                && meta.response_header().status == StatusCode::OK
591                            {
592                                self.inner.cache_miss(session, ctx);
593                                if !session.cache.enabled() {
594                                    fill_cache = false;
595                                }
596                            } else {
597                                // we've allowed caching on the next request,
598                                // but do not cache _this_ request if bypassed and not 200
599                                // (We didn't run upstream request cache filters to strip range or condition headers,
600                                // so this could be an uncacheable response e.g. 206 or 304 or HEAD.
601                                // Exclude all non-200/GET for simplicity, may expand allowable codes in the future.)
602                                fill_cache = false;
603                                session.cache.disable(NoCacheReason::Deferred);
604                            }
605                        }
606
607                        // If the Content-Length is known, and a maximum asset size has been configured
608                        // on the cache, validate that the response does not exceed the maximum asset size.
609                        if session.cache.enabled() {
610                            if let Some(max_file_size) = session.cache.max_file_size_bytes() {
611                                let content_length_hdr = meta.headers().get(header::CONTENT_LENGTH);
612                                if let Some(content_length) =
613                                    header_value_content_length(content_length_hdr)
614                                {
615                                    if content_length > max_file_size {
616                                        fill_cache = false;
617                                        session.cache.response_became_uncacheable(
618                                            NoCacheReason::ResponseTooLarge,
619                                        );
620                                        session.cache.disable(NoCacheReason::ResponseTooLarge);
621                                        // too large to cache, disable ranging
622                                        session.ignore_downstream_range = true;
623                                    }
624                                }
625                                // if the content-length header is not specified, the miss handler
626                                // will count the response size on the fly, aborting the request
627                                // mid-transfer if the max file size is exceeded
628                            }
629                        }
630                        if fill_cache {
631                            let req_header = session.req_header();
632                            // Update the variance in the meta via the same callback,
633                            // cache_vary_filter(), used in cache lookup for consistency.
634                            // Future cache lookups need a matching variance in the meta
635                            // with the cache key to pick up the correct variance
636                            let variance = self.inner.cache_vary_filter(&meta, ctx, req_header);
637                            session.cache.set_cache_meta(meta);
638                            session.cache.update_variance(variance);
639                            // this sends the meta and header
640                            session.cache.set_miss_handler().await?;
641                            if session.cache.miss_body_reader().is_some() {
642                                serve_from_cache.enable_miss();
643                            }
644                            if *end_stream {
645                                session
646                                    .cache
647                                    .miss_handler()
648                                    .unwrap() // safe, it is set above
649                                    .write_body(Bytes::new(), true)
650                                    .await?;
651                                session.cache.finish_miss_handler().await?;
652                            }
653                        }
654                    }
655                    Uncacheable(reason) => {
656                        if !session.cache.bypassing() {
657                            // mark as uncacheable, so we bypass cache next time
658                            session.cache.response_became_uncacheable(reason);
659                        }
660                        session.cache.disable(reason);
661                    }
662                }
663            }
664            HttpTask::Body(data, end_stream) | HttpTask::UpgradedBody(data, end_stream) => {
665                // It is not normally advisable to cache upgraded responses
666                // e.g. they are essentially close-delimited, so they are easily truncated
667                // but the framework still allows for it
668                match data {
669                    Some(d) => {
670                        if session.cache.enabled() {
671                            // TODO: do this async
672                            // fail if writing the body would exceed the max_file_size_bytes
673                            let body_size_allowed =
674                                session.cache.track_body_bytes_for_max_file_size(d.len());
675                            if !body_size_allowed {
676                                debug!("chunked response exceeded max cache size, remembering that it is uncacheable");
677                                session
678                                    .cache
679                                    .response_became_uncacheable(NoCacheReason::ResponseTooLarge);
680
681                                return Error::e_explain(
682                                    ERR_RESPONSE_TOO_LARGE,
683                                    format!(
684                                        "writing data of size {} bytes would exceed max file size of {} bytes",
685                                        d.len(),
686                                        session.cache.max_file_size_bytes().expect("max file size bytes must be set to exceed size")
687                                    ),
688                                );
689                            }
690
691                            // this will panic if more data is sent after we see end_stream
692                            // but should be impossible in real world
693                            let miss_handler = session.cache.miss_handler().unwrap();
694
695                            miss_handler.write_body(d.clone(), *end_stream).await?;
696                            if *end_stream {
697                                session.cache.finish_miss_handler().await?;
698                            }
699                        }
700                    }
701                    None => {
702                        if session.cache.enabled() && *end_stream {
703                            session.cache.finish_miss_handler().await?;
704                        }
705                    }
706                }
707            }
708            HttpTask::Trailer(_) => {} // h1 trailer is not supported yet
709            HttpTask::Done => {
710                if session.cache.enabled() {
711                    session.cache.finish_miss_handler().await?;
712                }
713            }
714            HttpTask::Failed(_) => {
715                // TODO: handle this failure: delete the temp files?
716            }
717        }
718        Ok(())
719    }
720
    /// Decides, from the upstream response header in `task`, whether the locally
    /// cached asset should be served instead of the upstream response.
    ///
    /// 1. When upstream returns 304 and an asset is cached, upstream response filters
    ///    run on the 304, the 304 is merged into the cached header and re-evaluated by
    ///    `response_cache_filter`; if still cacheable the stored meta is refreshed
    ///    (a failure to write it is only logged). The cached asset is served in every
    ///    one of these sub-cases.
    /// 2. When upstream returns 304 but nothing is cached, caching is disabled and the
    ///    304 is passed through (`false`).
    /// 3. When upstream returns a 5xx status, the cached asset is served stale if
    ///    stale-if-error is allowed, nothing was written downstream yet and
    ///    `should_serve_stale` agrees; the cache write lock is then released.
    ///
    /// Returns `true` if the local cache should be used, `false` otherwise (including
    /// when cache is disabled or `task` is not a header).
    pub(crate) async fn revalidate_or_stale(
        &self,
        session: &mut Session,
        task: &mut HttpTask,
        ctx: &mut SV::CTX,
    ) -> bool
    where
        SV: ProxyHttp + Send + Sync,
        SV::CTX: Send + Sync,
    {
        // cache is disabled for this request: always use the upstream response
        if !session.cache.enabled() {
            return false;
        }

        match task {
            HttpTask::Header(resp, _eos) => {
                if resp.status == StatusCode::NOT_MODIFIED {
                    if session.cache.maybe_cache_meta().is_some() {
                        // run upstream response filters on upstream 304 first
                        if let Err(err) = self
                            .inner
                            .upstream_response_filter(session, resp, ctx)
                            .await
                        {
                            error!("upstream response filter error on 304: {err:?}");
                            // same handling as a response_cache_filter error below,
                            // but with the raw (unmerged) 304 header
                            session.cache.revalidate_uncacheable(
                                *resp.clone(),
                                NoCacheReason::InternalError,
                            );
                            // always serve from cache after receiving the 304
                            return true;
                        }
                        // 304 doesn't contain all the headers, merge 304 into cached 200 header
                        // in order for response_cache_filter to run correctly
                        let merged_header = session.cache.revalidate_merge_header(resp);
                        match self
                            .inner
                            .response_cache_filter(session, &merged_header, ctx)
                        {
                            Ok(Cacheable(mut meta)) => {
                                // For simplicity, ignore changes to variance over 304 for now.
                                // Note this means upstream can only update variance via 2xx
                                // (expired response).
                                //
                                // TODO: if we choose to respect changing Vary / variance over 304,
                                // then there are a few cases to consider. See `update_variance` in
                                // the `pingora-cache` module.
                                let old_meta = session.cache.maybe_cache_meta().unwrap(); // safe, checked above
                                if let Some(old_variance) = old_meta.variance() {
                                    meta.set_variance(old_variance);
                                }
                                if let Err(e) = session.cache.revalidate_cache_meta(meta).await {
                                    // Fail open: we can continue use the revalidated response even
                                    // if the meta failed to write to storage
                                    warn!("revalidate_cache_meta failed {e:?}");
                                }
                            }
                            Ok(Uncacheable(reason)) => {
                                // This response was once cacheable, and upstream tells us it has not changed
                                // but now we decided it is uncacheable!
                                // RFC 9111: still allowed to reuse stored response this time because
                                // it was "successfully validated"
                                // https://www.rfc-editor.org/rfc/rfc9111#constructing.responses.from.caches
                                // Serve the response, but do not update cache

                                // We also want to avoid poisoning downstream's cache with an unsolicited 304
                                // if we did not receive a conditional request from downstream
                                // (downstream may have a different cacheability assessment and could cache the 304)

                                //TODO: log more
                                debug!("Uncacheable {reason:?} 304 received");
                                session.cache.response_became_uncacheable(reason);
                                session.cache.revalidate_uncacheable(merged_header, reason);
                            }
                            Err(e) => {
                                // Error during revalidation, similarly to the reasons above
                                // (avoid poisoning downstream cache with passthrough 304),
                                // allow serving the stored response without updating cache
                                warn!("Error {e:?} response_cache_filter during revalidation");
                                session.cache.revalidate_uncacheable(
                                    merged_header,
                                    NoCacheReason::InternalError,
                                );
                                // Assume the next 304 may succeed, so don't mark uncacheable
                            }
                        }
                        // always serve from cache after receiving the 304
                        true
                    } else {
                        // A 304 with nothing cached cannot be served from cache; remember
                        // this and disable caching so the 304 is passed through.
                        //TODO: log more
                        warn!("304 received without cached asset, disable caching");
                        let reason = NoCacheReason::Custom("304 on miss");
                        session.cache.response_became_uncacheable(reason);
                        session.cache.disable(reason);
                        false
                    }
                } else if resp.status.is_server_error() {
                    // stale if error logic, 5xx only for now

                    // this is response header filter, response_written should always be None?
                    if !session.cache.can_serve_stale_error()
                        || session.response_written().is_some()
                    {
                        return false;
                    }

                    // create an error to encode the http status code
                    let http_status_error = Error::create(
                        ErrorType::HTTPStatus(resp.status.as_u16()),
                        ErrorSource::Upstream,
                        None,
                        None,
                    );
                    if self
                        .inner
                        .should_serve_stale(session, ctx, Some(&http_status_error))
                    {
                        // no more need to keep the write lock
                        session
                            .cache
                            .release_write_lock(NoCacheReason::UpstreamError);
                        true
                    } else {
                        false
                    }
                } else {
                    false // not 304, not stale if error status code
                }
            }
            _ => false, // not header
        }
    }
857
858    // None: no staled asset is used, Some(_): staled asset is sent to downstream
859    // bool: can the downstream connection be reused
860    pub(crate) async fn handle_stale_if_error(
861        &self,
862        session: &mut Session,
863        ctx: &mut SV::CTX,
864        error: &Error,
865    ) -> Option<(bool, Option<Box<Error>>)>
866    where
867        SV: ProxyHttp + Send + Sync,
868        SV::CTX: Send + Sync,
869    {
870        // the caller might already checked this as an optimization
871        if !session.cache.can_serve_stale_error() {
872            return None;
873        }
874
875        // the error happen halfway through a regular response to downstream
876        // can't resend the response
877        if session.response_written().is_some() {
878            return None;
879        }
880
881        // check error types
882        if !self.inner.should_serve_stale(session, ctx, Some(error)) {
883            return None;
884        }
885
886        // log the original error
887        warn!(
888            "Fail to proxy: {}, serving stale, {}",
889            error,
890            self.inner.request_summary(session, ctx)
891        );
892
893        // no more need to hang onto the cache lock
894        session
895            .cache
896            .release_write_lock(NoCacheReason::UpstreamError);
897
898        Some(self.proxy_cache_hit(session, ctx).await)
899    }
900
901    // helper function to check when to continue to retry lock (true) or give up (false)
902    fn handle_lock_status(
903        &self,
904        session: &mut Session,
905        ctx: &SV::CTX,
906        lock_status: LockStatus,
907    ) -> bool
908    where
909        SV: ProxyHttp,
910    {
911        debug!("cache unlocked {lock_status:?}");
912        match lock_status {
913            // should lookup the cached asset again
914            LockStatus::Done => true,
915            // should compete to be a new writer
916            LockStatus::TransientError => true,
917            // the request is uncacheable, go ahead to fetch from the origin
918            LockStatus::GiveUp => {
919                // TODO: It will be nice for the writer to propagate the real reason
920                session.cache.disable(NoCacheReason::CacheLockGiveUp);
921                // not cacheable, just go to the origin.
922                false
923            }
924            // treat this the same as TransientError
925            LockStatus::Dangling => {
926                // software bug, but request can recover from this
927                warn!(
928                    "Dangling cache lock, {}",
929                    self.inner.request_summary(session, ctx)
930                );
931                true
932            }
933            // If this reader has spent too long waiting on locks, let the request
934            // through while disabling cache (to avoid amplifying disk writes).
935            LockStatus::WaitTimeout => {
936                warn!(
937                    "Cache lock timeout, {}",
938                    self.inner.request_summary(session, ctx)
939                );
940                session.cache.disable(NoCacheReason::CacheLockTimeout);
941                // not cacheable, just go to the origin.
942                false
943            }
944            // When a singular cache lock has been held for too long,
945            // we should allow requests to recompete for the lock
946            // to protect upstreams from load.
947            LockStatus::AgeTimeout => true,
948            // software bug, this status should be impossible to reach
949            LockStatus::Waiting => panic!("impossible LockStatus::Waiting"),
950        }
951    }
952}
953
954fn cache_hit_header(cache: &HttpCache) -> Box<ResponseHeader> {
955    let mut header = Box::new(cache.cache_meta().response_header_copy());
956    // convert cache response
957
958    // these status codes / method cannot have body, so no need to add chunked encoding
959    let no_body = matches!(header.status.as_u16(), 204 | 304);
960
961    // https://www.rfc-editor.org/rfc/rfc9111#section-4:
962    // When a stored response is used to satisfy a request without validation, a cache
963    // MUST generate an Age header field
964    if !cache.upstream_used() {
965        let age = cache.cache_meta().age().as_secs();
966        header.insert_header(http::header::AGE, age).unwrap();
967    }
968    log::debug!("cache header: {header:?} {:?}", cache.phase());
969
970    // currently storage cache is always considered an h1 upstream
971    // (header-serde serializes as h1.0 or h1.1)
972    // set this header to be h1.1
973    header.set_version(Version::HTTP_11);
974
975    /* Add chunked header to tell downstream to use chunked encoding
976     * during the absent of content-length in h2 */
977    if !no_body
978        && !header.status.is_informational()
979        && header.headers.get(http::header::CONTENT_LENGTH).is_none()
980    {
981        header
982            .insert_header(http::header::TRANSFER_ENCODING, "chunked")
983            .unwrap();
984    }
985    header
986}
987
988// https://datatracker.ietf.org/doc/html/rfc7233#section-3
989pub mod range_filter {
990    use super::*;
991    use bytes::BytesMut;
992    use http::header::*;
993    use std::ops::Range;
994
995    // parse bytes into usize, ignores specific error
996    fn parse_number(input: &[u8]) -> Option<usize> {
997        str::from_utf8(input).ok()?.parse().ok()
998    }
999
1000    fn parse_range_header(
1001        range: &[u8],
1002        content_length: usize,
1003        max_multipart_ranges: Option<usize>,
1004    ) -> RangeType {
1005        use regex::Regex;
1006
1007        // Match individual range parts, (e.g. "0-100", "-5", "1-")
1008        static RE_SINGLE_RANGE_PART: Lazy<Regex> =
1009            Lazy::new(|| Regex::new(r"(?i)^\s*(?P<start>\d*)-(?P<end>\d*)\s*$").unwrap());
1010
1011        // Convert bytes to UTF-8 string
1012        let range_str = match str::from_utf8(range) {
1013            Ok(s) => s,
1014            Err(_) => return RangeType::None,
1015        };
1016
1017        // Split into "bytes=" and the actual range(s)
1018        let mut parts = range_str.splitn(2, "=");
1019
1020        // Check if it starts with "bytes="
1021        let prefix = parts.next();
1022        if !prefix.is_some_and(|s| s.eq_ignore_ascii_case("bytes")) {
1023            return RangeType::None;
1024        }
1025
1026        let Some(ranges_str) = parts.next() else {
1027            // No ranges provided
1028            return RangeType::None;
1029        };
1030
1031        // "bytes=" with an empty (or whitespace-only) range-set is syntactically a
1032        // range request with zero satisfiable range-specs, so return 416.
1033        if ranges_str.trim().is_empty() {
1034            return RangeType::Invalid;
1035        }
1036
1037        // Get the actual range string (e.g."100-200,300-400")
1038        let mut range_count = 0;
1039        for _ in ranges_str.split(',') {
1040            range_count += 1;
1041            if let Some(max_ranges) = max_multipart_ranges {
1042                if range_count >= max_ranges {
1043                    // If we get more than max configured ranges, return None for now to save parsing time
1044                    return RangeType::None;
1045                }
1046            }
1047        }
1048        let mut ranges: Vec<Range<usize>> = Vec::with_capacity(range_count);
1049
1050        // Process each range
1051        let mut last_range_end = 0;
1052        for part in ranges_str.split(',') {
1053            let captured = match RE_SINGLE_RANGE_PART.captures(part) {
1054                Some(c) => c,
1055                None => {
1056                    return RangeType::None;
1057                }
1058            };
1059
1060            let maybe_start = captured
1061                .name("start")
1062                .and_then(|s| s.as_str().parse::<usize>().ok());
1063            let end = captured
1064                .name("end")
1065                .and_then(|s| s.as_str().parse::<usize>().ok());
1066
1067            let range = if let Some(start) = maybe_start {
1068                if start >= content_length {
1069                    // Skip the invalid range
1070                    continue;
1071                }
1072                // open-ended range should end at the last byte
1073                // over sized end is allowed but ignored
1074                // range end is inclusive
1075                let end = std::cmp::min(end.unwrap_or(content_length - 1), content_length - 1) + 1;
1076                if end <= start {
1077                    // Skip the invalid range
1078                    continue;
1079                }
1080                start..end
1081            } else {
1082                // start is empty, this changes the meaning of the value of `end`
1083                // Now it means to read the last `end` bytes
1084                if let Some(end) = end {
1085                    if content_length >= end {
1086                        (content_length - end)..content_length
1087                    } else {
1088                        // over sized end is allowed but ignored
1089                        0..content_length
1090                    }
1091                } else {
1092                    // No start or end, skip the invalid range
1093                    continue;
1094                }
1095            };
1096            // For now we stick to non-overlapping, ascending ranges for simplicity
1097            // and parity with nginx
1098            if range.start < last_range_end {
1099                return RangeType::None;
1100            }
1101            last_range_end = range.end;
1102            ranges.push(range);
1103        }
1104
1105        // Note for future: we can technically coalesce multiple ranges for multipart
1106        //
1107        // https://www.rfc-editor.org/rfc/rfc9110#section-17.15
1108        // "Servers ought to ignore, coalesce, or reject egregious range
1109        // requests, such as requests for more than two overlapping ranges or
1110        // for many small ranges in a single set, particularly when the ranges
1111        // are requested out of order for no apparent reason. Multipart range
1112        // requests are not designed to support random access."
1113
1114        if ranges.is_empty() {
1115            // We got some ranges, processed them but none were valid
1116            RangeType::Invalid
1117        } else if ranges.len() == 1 {
1118            RangeType::Single(ranges[0].clone()) // Only 1 index
1119        } else {
1120            RangeType::Multi(MultiRangeInfo::new(ranges))
1121        }
1122    }
1123    #[test]
1124    fn test_parse_range() {
1125        assert_eq!(
1126            parse_range_header(b"bytes=0-1", 10, None),
1127            RangeType::new_single(0, 2)
1128        );
1129        assert_eq!(
1130            parse_range_header(b"bYTes=0-9", 10, None),
1131            RangeType::new_single(0, 10)
1132        );
1133        assert_eq!(
1134            parse_range_header(b"bytes=0-12", 10, None),
1135            RangeType::new_single(0, 10)
1136        );
1137        assert_eq!(
1138            parse_range_header(b"bytes=0-", 10, None),
1139            RangeType::new_single(0, 10)
1140        );
1141        assert_eq!(
1142            parse_range_header(b"bytes=2-1", 10, None),
1143            RangeType::Invalid
1144        );
1145        assert_eq!(
1146            parse_range_header(b"bytes=10-11", 10, None),
1147            RangeType::Invalid
1148        );
1149        assert_eq!(
1150            parse_range_header(b"bytes=-2", 10, None),
1151            RangeType::new_single(8, 10)
1152        );
1153        assert_eq!(
1154            parse_range_header(b"bytes=-12", 10, None),
1155            RangeType::new_single(0, 10)
1156        );
1157        assert_eq!(parse_range_header(b"bytes=-", 10, None), RangeType::Invalid);
1158        assert_eq!(parse_range_header(b"bytes=", 10, None), RangeType::Invalid);
1159        assert_eq!(
1160            parse_range_header(b"bytes=  ", 10, None),
1161            RangeType::Invalid
1162        );
1163    }
1164
1165    // Add some tests for multi-range too
1166    #[test]
1167    fn test_parse_range_header_multi() {
1168        assert_eq!(
1169            parse_range_header(b"bytes=0-1,4-5", 10, None)
1170                .get_multirange_info()
1171                .expect("Should have multipart info for Multipart range request")
1172                .ranges,
1173            (vec![Range { start: 0, end: 2 }, Range { start: 4, end: 6 }])
1174        );
1175        // Last range is invalid because the content-length is too small
1176        assert_eq!(
1177            parse_range_header(b"bytEs=0-99,200-299,400-499", 320, None)
1178                .get_multirange_info()
1179                .expect("Should have multipart info for Multipart range request")
1180                .ranges,
1181            (vec![
1182                Range { start: 0, end: 100 },
1183                Range {
1184                    start: 200,
1185                    end: 300
1186                }
1187            ])
1188        );
1189        // Same as above but appropriate content length
1190        assert_eq!(
1191            parse_range_header(b"bytEs=0-99,200-299,400-499", 500, None)
1192                .get_multirange_info()
1193                .expect("Should have multipart info for Multipart range request")
1194                .ranges,
1195            vec![
1196                Range { start: 0, end: 100 },
1197                Range {
1198                    start: 200,
1199                    end: 300
1200                },
1201                Range {
1202                    start: 400,
1203                    end: 500
1204                },
1205            ]
1206        );
1207        // Looks like a range request but it is continuous, we decline to range
1208        assert_eq!(
1209            parse_range_header(b"bytes=0-,-2", 10, None),
1210            RangeType::None,
1211        );
1212        // Should not have multirange info set
1213        assert!(parse_range_header(b"bytes=0-,-2", 10, None)
1214            .get_multirange_info()
1215            .is_none());
1216        // Overlapping ranges, these ranges are currently declined
1217        assert_eq!(
1218            parse_range_header(b"bytes=0-3,2-5", 10, None),
1219            RangeType::None,
1220        );
1221        assert!(parse_range_header(b"bytes=0-3,2-5", 10, None)
1222            .get_multirange_info()
1223            .is_none());
1224
1225        // Content length is 2, so only range is 0-2.
1226        assert_eq!(
1227            parse_range_header(b"bytes=0-5,10-", 2, None),
1228            RangeType::new_single(0, 2)
1229        );
1230        assert!(parse_range_header(b"bytes=0-5,10-", 2, None)
1231            .get_multirange_info()
1232            .is_none());
1233
1234        // We should ignore the last incorrect range and return the other acceptable ranges
1235        assert_eq!(
1236            parse_range_header(b"bytes=0-5, 10-20, 30-18", 200, None)
1237                .get_multirange_info()
1238                .expect("Should have multipart info for Multipart range request")
1239                .ranges,
1240            vec![Range { start: 0, end: 6 }, Range { start: 10, end: 21 },]
1241        );
1242        // All invalid ranges
1243        assert_eq!(
1244            parse_range_header(b"bytes=5-0, 20-15, 30-25", 200, None),
1245            RangeType::Invalid
1246        );
1247
1248        // Helper function to generate a large number of ranges for the next test
1249        fn generate_range_header(count: usize) -> Vec<u8> {
1250            let mut s = String::from("bytes=");
1251            for i in 0..count {
1252                let start = i * 4;
1253                let end = start + 1;
1254                if i > 0 {
1255                    s.push(',');
1256                }
1257                s.push_str(&start.to_string());
1258                s.push('-');
1259                s.push_str(&end.to_string());
1260            }
1261            s.into_bytes()
1262        }
1263
1264        // Test 200 range limit for parsing.
1265        let ranges = generate_range_header(201);
1266        assert_eq!(
1267            parse_range_header(&ranges, 1000, Some(200)),
1268            RangeType::None
1269        )
1270    }
1271
    /// State needed to render a `multipart/byteranges` response.
    ///
    /// For multipart requests the boundary, the total content length and the content
    /// type are needed in both the headers and the body, so they are stored here
    /// alongside the requested ranges.
    #[derive(Debug, Eq, PartialEq, Clone)]
    pub struct MultiRangeInfo {
        /// Requested byte ranges, half-open (`start..end`); ascending and
        /// non-overlapping when produced by `parse_range_header`.
        pub ranges: Vec<Range<usize>>,
        /// Random multipart boundary (16 lowercase hex digits), generated on creation.
        pub boundary: String,
        /// Length of the complete representation; 0 until set.
        total_length: usize,
        /// Content-Type of the complete representation, if known.
        content_type: Option<String>,
    }
1281
1282    impl MultiRangeInfo {
1283        // Create a new MultiRangeInfo, when we just have the ranges
1284        pub fn new(ranges: Vec<Range<usize>>) -> Self {
1285            Self {
1286                ranges,
1287                // Directly create boundary string on initialization
1288                boundary: Self::generate_boundary(),
1289                total_length: 0,
1290                content_type: None,
1291            }
1292        }
1293        pub fn set_content_type(&mut self, content_type: String) {
1294            self.content_type = Some(content_type)
1295        }
1296        pub fn set_total_length(&mut self, total_length: usize) {
1297            self.total_length = total_length;
1298        }
1299        // Per [RFC 9110](https://www.rfc-editor.org/rfc/rfc9110.html#multipart.byteranges),
1300        // we need generate a boundary string for each body part.
1301        // Per [RFC 2046](https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1), the boundary should be no longer than 70 characters
1302        // and it must not match the body content.
1303        fn generate_boundary() -> String {
1304            use rand::Rng;
1305            let mut rng: rand::prelude::ThreadRng = rand::thread_rng();
1306            format!("{:016x}", rng.gen::<u64>())
1307        }
1308        pub fn calculate_multipart_length(&self) -> usize {
1309            let mut total_length = 0;
1310            let content_type = self.content_type.as_ref();
1311            for range in self.ranges.clone() {
1312                // Each part should have
1313                // \r\n--boundary\r\n                         --> 4 + boundary.len() (16) + 2 = 20
1314                // Content-Type: original-content-type\r\n    --> 14 + content_type.len() + 2
1315                // Content-Range: bytes start-end/total\r\n   --> Variable +2
1316                // \r\n                                       --> 2
1317                // [data]                                     --> data.len()
1318                total_length += 4 + self.boundary.len() + 2;
1319                total_length += content_type.map_or(0, |ct| 14 + ct.len() + 2);
1320                total_length += format!(
1321                    "Content-Range: bytes {}-{}/{}",
1322                    range.start,
1323                    range.end - 1,
1324                    self.total_length
1325                )
1326                .len()
1327                    + 2;
1328                total_length += 2;
1329                total_length += range.end - range.start;
1330            }
1331            // Final boundary: "\r\n--<boundary>--\r\n"
1332            total_length += 4 + self.boundary.len() + 4;
1333            total_length
1334        }
1335    }
    /// Outcome of evaluating a `Range` request header.
    #[derive(Debug, Eq, PartialEq, Clone)]
    pub enum RangeType {
        /// Not a usable range request: the response is not ranged.
        None,
        /// One satisfiable byte range, half-open (`start..end`).
        Single(Range<usize>),
        /// Several satisfiable byte ranges, served as `multipart/byteranges`.
        Multi(MultiRangeInfo),
        /// A range request with no satisfiable range (416 Range Not Satisfiable).
        Invalid,
    }
1343
1344    impl RangeType {
1345        // Helper functions for tests
1346        #[allow(dead_code)]
1347        fn new_single(start: usize, end: usize) -> Self {
1348            RangeType::Single(Range { start, end })
1349        }
1350        #[allow(dead_code)]
1351        pub fn new_multi(ranges: Vec<Range<usize>>) -> Self {
1352            RangeType::Multi(MultiRangeInfo::new(ranges))
1353        }
1354        #[allow(dead_code)]
1355        fn get_multirange_info(&self) -> Option<&MultiRangeInfo> {
1356            match self {
1357                RangeType::Multi(multi_range_info) => Some(multi_range_info),
1358                _ => None,
1359            }
1360        }
1361        #[allow(dead_code)]
1362        fn update_multirange_info(&mut self, content_length: usize, content_type: Option<String>) {
1363            if let RangeType::Multi(multipart_range_info) = self {
1364                multipart_range_info.content_type = content_type;
1365                multipart_range_info.set_total_length(content_length);
1366            }
1367        }
1368    }
1369
1370    // Handles both single-range and multipart-range requests
1371    pub fn range_header_filter(
1372        req: &RequestHeader,
1373        resp: &mut ResponseHeader,
1374        max_multipart_ranges: Option<usize>,
1375    ) -> RangeType {
1376        // The Range header field is evaluated after evaluating the precondition
1377        // header fields defined in [RFC7232], and only if the result in absence
1378        // of the Range header field would be a 200 (OK) response
1379        if resp.status != StatusCode::OK {
1380            return RangeType::None;
1381        }
1382
1383        // Content-Length is not required by RFC but it is what nginx does and easier to implement
1384        // with this header present.
1385        let Some(content_length_bytes) = resp.headers.get(CONTENT_LENGTH) else {
1386            return RangeType::None;
1387        };
1388        // bail on invalid content length
1389        let Some(content_length) = parse_number(content_length_bytes.as_bytes()) else {
1390            return RangeType::None;
1391        };
1392
1393        // At this point the response is allowed to be served as ranges
1394        // TODO: we can also check Accept-Range header from resp. Nginx gives uses the option
1395        // see proxy_force_ranges
1396
1397        fn request_range_type(
1398            req: &RequestHeader,
1399            resp: &ResponseHeader,
1400            content_length: usize,
1401            max_multipart_ranges: Option<usize>,
1402        ) -> RangeType {
1403            // "A server MUST ignore a Range header field received with a request method other than GET."
1404            if req.method != http::Method::GET && req.method != http::Method::HEAD {
1405                return RangeType::None;
1406            }
1407
1408            let Some(range_header) = req.headers.get(RANGE) else {
1409                return RangeType::None;
1410            };
1411
1412            // if-range wants to understand if the Last-Modified / ETag value matches exactly for use
1413            // with resumable downloads.
1414            // https://datatracker.ietf.org/doc/html/rfc9110#name-if-range
1415            // Note that the RFC wants strong validation, and suggests that
1416            // "A valid entity-tag can be distinguished from a valid HTTP-date
1417            // by examining the first three characters for a DQUOTE,"
1418            // but this current etag matching behavior most closely mirrors nginx.
1419            if let Some(if_range) = req.headers.get(IF_RANGE) {
1420                let ir = if_range.as_bytes();
1421                let matches = if ir.len() >= 2 && ir.last() == Some(&b'"') {
1422                    resp.headers.get(ETAG).is_some_and(|etag| etag == if_range)
1423                } else if let Some(last_modified) = resp.headers.get(LAST_MODIFIED) {
1424                    last_modified == if_range
1425                } else {
1426                    false
1427                };
1428                if !matches {
1429                    return RangeType::None;
1430                }
1431            }
1432
1433            parse_range_header(
1434                range_header.as_bytes(),
1435                content_length,
1436                max_multipart_ranges,
1437            )
1438        }
1439
1440        let mut range_type = request_range_type(req, resp, content_length, max_multipart_ranges);
1441
1442        match &mut range_type {
1443            RangeType::None => {
1444                // At this point, the response is _eligible_ to be served in ranges
1445                // in the future, so add Accept-Ranges, mirroring nginx behavior
1446                resp.insert_header(&ACCEPT_RANGES, "bytes").unwrap();
1447            }
1448            RangeType::Single(r) => {
1449                // 206 response
1450                resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap();
1451                resp.remove_header(&ACCEPT_RANGES);
1452                resp.insert_header(&CONTENT_LENGTH, r.end - r.start)
1453                    .unwrap();
1454                resp.insert_header(
1455                    &CONTENT_RANGE,
1456                    format!("bytes {}-{}/{content_length}", r.start, r.end - 1), // range end is inclusive
1457                )
1458                .unwrap()
1459            }
1460
1461            RangeType::Multi(multi_range_info) => {
1462                let content_type = resp
1463                    .headers
1464                    .get(CONTENT_TYPE)
1465                    .and_then(|v| v.to_str().ok())
1466                    .unwrap_or("application/octet-stream");
1467                // Update multipart info
1468                multi_range_info.set_total_length(content_length);
1469                multi_range_info.set_content_type(content_type.to_string());
1470
1471                let total_length = multi_range_info.calculate_multipart_length();
1472
1473                resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap();
1474                resp.remove_header(&ACCEPT_RANGES);
1475                resp.insert_header(CONTENT_LENGTH, total_length).unwrap();
1476                resp.insert_header(
1477                    CONTENT_TYPE,
1478                    format!(
1479                        "multipart/byteranges; boundary={}",
1480                        multi_range_info.boundary
1481                    ), // RFC 2046
1482                )
1483                .unwrap();
1484                resp.remove_header(&CONTENT_RANGE);
1485            }
1486            RangeType::Invalid => {
1487                // 416 response
1488                resp.set_status(StatusCode::RANGE_NOT_SATISFIABLE).unwrap();
1489                // empty body for simplicity
1490                resp.insert_header(&CONTENT_LENGTH, HeaderValue::from_static("0"))
1491                    .unwrap();
1492                resp.remove_header(&ACCEPT_RANGES);
1493                resp.remove_header(&CONTENT_TYPE);
1494                resp.remove_header(&CONTENT_ENCODING);
1495                resp.remove_header(&TRANSFER_ENCODING);
1496                resp.insert_header(&CONTENT_RANGE, format!("bytes */{content_length}"))
1497                    .unwrap()
1498            }
1499        }
1500
1501        range_type
1502    }
1503
1504    #[test]
1505    fn test_range_filter_single() {
1506        fn gen_req() -> RequestHeader {
1507            RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap()
1508        }
1509        fn gen_resp() -> ResponseHeader {
1510            let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1511            resp.append_header("Content-Length", "10").unwrap();
1512            resp
1513        }
1514
1515        // no range
1516        let req = gen_req();
1517        let mut resp = gen_resp();
1518        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1519        assert_eq!(resp.status.as_u16(), 200);
1520        assert_eq!(
1521            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1522            b"bytes"
1523        );
1524
1525        // no range, try HEAD
1526        let mut req = gen_req();
1527        req.method = Method::HEAD;
1528        let mut resp = gen_resp();
1529        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1530        assert_eq!(resp.status.as_u16(), 200);
1531        assert_eq!(
1532            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1533            b"bytes"
1534        );
1535
1536        // regular range
1537        let mut req = gen_req();
1538        req.insert_header("Range", "bytes=0-1").unwrap();
1539        let mut resp = gen_resp();
1540        assert_eq!(
1541            RangeType::new_single(0, 2),
1542            range_header_filter(&req, &mut resp, None)
1543        );
1544        assert_eq!(resp.status.as_u16(), 206);
1545        assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"2");
1546        assert_eq!(
1547            resp.headers.get("content-range").unwrap().as_bytes(),
1548            b"bytes 0-1/10"
1549        );
1550        assert!(resp.headers.get("accept-ranges").is_none());
1551
1552        // regular range, accept-ranges included
1553        let mut req = gen_req();
1554        req.insert_header("Range", "bytes=0-1").unwrap();
1555        let mut resp = gen_resp();
1556        resp.insert_header("Accept-Ranges", "bytes").unwrap();
1557        assert_eq!(
1558            RangeType::new_single(0, 2),
1559            range_header_filter(&req, &mut resp, None)
1560        );
1561        assert_eq!(resp.status.as_u16(), 206);
1562        assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"2");
1563        assert_eq!(
1564            resp.headers.get("content-range").unwrap().as_bytes(),
1565            b"bytes 0-1/10"
1566        );
1567        // accept-ranges stripped
1568        assert!(resp.headers.get("accept-ranges").is_none());
1569
1570        // bad range
1571        let mut req = gen_req();
1572        req.insert_header("Range", "bytes=1-0").unwrap();
1573        let mut resp = gen_resp();
1574        resp.insert_header("Accept-Ranges", "bytes").unwrap();
1575        resp.insert_header("Content-Encoding", "gzip").unwrap();
1576        resp.insert_header("Transfer-Encoding", "chunked").unwrap();
1577        assert_eq!(
1578            RangeType::Invalid,
1579            range_header_filter(&req, &mut resp, None)
1580        );
1581        assert_eq!(resp.status.as_u16(), 416);
1582        assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"0");
1583        assert_eq!(
1584            resp.headers.get("content-range").unwrap().as_bytes(),
1585            b"bytes */10"
1586        );
1587        assert!(resp.headers.get("accept-ranges").is_none());
1588        assert!(resp.headers.get("content-encoding").is_none());
1589        assert!(resp.headers.get("transfer-encoding").is_none());
1590    }
1591
1592    // Multipart Tests
1593    #[test]
1594    fn test_range_filter_multipart() {
1595        fn gen_req() -> RequestHeader {
1596            let mut req: RequestHeader =
1597                RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1598            req.append_header("Range", "bytes=0-1,3-4,6-7").unwrap();
1599            req
1600        }
1601        fn gen_req_overlap_range() -> RequestHeader {
1602            let mut req: RequestHeader =
1603                RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1604            req.append_header("Range", "bytes=0-3,2-5,7-8").unwrap();
1605            req
1606        }
1607        fn gen_resp() -> ResponseHeader {
1608            let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1609            resp.append_header("Content-Length", "10").unwrap();
1610            resp
1611        }
1612
1613        // valid multipart range
1614        let req = gen_req();
1615        let mut resp = gen_resp();
1616        let result = range_header_filter(&req, &mut resp, None);
1617        let mut boundary_str = String::new();
1618
1619        assert!(matches!(result, RangeType::Multi(_)));
1620        if let RangeType::Multi(multi_part_info) = result {
1621            assert_eq!(multi_part_info.ranges.len(), 3);
1622            assert_eq!(multi_part_info.ranges[0], Range { start: 0, end: 2 });
1623            assert_eq!(multi_part_info.ranges[1], Range { start: 3, end: 5 });
1624            assert_eq!(multi_part_info.ranges[2], Range { start: 6, end: 8 });
1625            // Verify that multipart info has been set
1626            assert!(multi_part_info.content_type.is_some());
1627            assert_eq!(multi_part_info.total_length, 10);
1628            assert!(!multi_part_info.boundary.is_empty());
1629            boundary_str = multi_part_info.boundary;
1630        }
1631        assert_eq!(resp.status.as_u16(), 206);
1632        // Verify that boundary is the same in header and in multipartinfo
1633        assert_eq!(
1634            resp.headers.get("content-type").unwrap().to_str().unwrap(),
1635            format!("multipart/byteranges; boundary={boundary_str}")
1636        );
1637        assert!(resp.headers.get("content_length").is_none());
1638        assert!(resp.headers.get("accept-ranges").is_none());
1639
1640        // overlapping range, multipart range is declined
1641        let req = gen_req_overlap_range();
1642        let mut resp = gen_resp();
1643        let result = range_header_filter(&req, &mut resp, None);
1644
1645        assert!(matches!(result, RangeType::None));
1646        assert_eq!(resp.status.as_u16(), 200);
1647        assert!(resp.headers.get("content-type").is_none());
1648        assert_eq!(
1649            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1650            b"bytes"
1651        );
1652
1653        // bad multipart range
1654        let mut req = gen_req();
1655        req.insert_header("Range", "bytes=1-0, 12-9, 50-40")
1656            .unwrap();
1657        let mut resp = gen_resp();
1658        resp.insert_header("Content-Encoding", "br").unwrap();
1659        resp.insert_header("Transfer-Encoding", "chunked").unwrap();
1660        let result = range_header_filter(&req, &mut resp, None);
1661        assert!(matches!(result, RangeType::Invalid));
1662        assert_eq!(resp.status.as_u16(), 416);
1663        assert!(resp.headers.get("accept-ranges").is_none());
1664        assert!(resp.headers.get("content-encoding").is_none());
1665        assert!(resp.headers.get("transfer-encoding").is_none());
1666    }
1667
1668    #[test]
1669    fn test_if_range() {
1670        const DATE: &str = "Fri, 07 Jul 2023 22:03:29 GMT";
1671        const ETAG: &str = "\"1234\"";
1672
1673        fn gen_req() -> RequestHeader {
1674            let mut req = RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1675            req.append_header("Range", "bytes=0-1").unwrap();
1676            req
1677        }
1678        fn get_multipart_req() -> RequestHeader {
1679            let mut req = RequestHeader::build(http::Method::GET, b"/", Some(1)).unwrap();
1680            _ = req.append_header("Range", "bytes=0-1,3-4,6-7");
1681            req
1682        }
1683        fn gen_resp() -> ResponseHeader {
1684            let mut resp = ResponseHeader::build(200, Some(1)).unwrap();
1685            resp.append_header("Content-Length", "10").unwrap();
1686            resp.append_header("Last-Modified", DATE).unwrap();
1687            resp.append_header("ETag", ETAG).unwrap();
1688            resp
1689        }
1690
1691        // matching Last-Modified date
1692        let mut req = gen_req();
1693        req.insert_header("If-Range", DATE).unwrap();
1694        let mut resp = gen_resp();
1695        assert_eq!(
1696            RangeType::new_single(0, 2),
1697            range_header_filter(&req, &mut resp, None)
1698        );
1699
1700        // non-matching date
1701        let mut req = gen_req();
1702        req.insert_header("If-Range", "Fri, 07 Jul 2023 22:03:25 GMT")
1703            .unwrap();
1704        let mut resp = gen_resp();
1705        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1706        assert_eq!(resp.status.as_u16(), 200);
1707        assert_eq!(
1708            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1709            b"bytes"
1710        );
1711
1712        // match ETag
1713        let mut req = gen_req();
1714        req.insert_header("If-Range", ETAG).unwrap();
1715        let mut resp = gen_resp();
1716        assert_eq!(
1717            RangeType::new_single(0, 2),
1718            range_header_filter(&req, &mut resp, None)
1719        );
1720        assert_eq!(resp.status.as_u16(), 206);
1721        assert!(resp.headers.get("accept-ranges").is_none());
1722
1723        // non-matching ETags do not result in range
1724        let mut req = gen_req();
1725        req.insert_header("If-Range", "\"4567\"").unwrap();
1726        let mut resp = gen_resp();
1727        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1728        assert_eq!(resp.status.as_u16(), 200);
1729        assert_eq!(
1730            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1731            b"bytes"
1732        );
1733
1734        let mut req = gen_req();
1735        req.insert_header("If-Range", "1234").unwrap();
1736        let mut resp = gen_resp();
1737        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1738        assert_eq!(resp.status.as_u16(), 200);
1739        assert_eq!(
1740            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1741            b"bytes"
1742        );
1743
1744        // multipart range with If-Range
1745        let mut req = get_multipart_req();
1746        req.insert_header("If-Range", DATE).unwrap();
1747        let mut resp = gen_resp();
1748        let result = range_header_filter(&req, &mut resp, None);
1749        assert!(matches!(result, RangeType::Multi(_)));
1750        assert_eq!(resp.status.as_u16(), 206);
1751        assert!(resp.headers.get("accept-ranges").is_none());
1752
1753        // multipart with matching ETag
1754        let req = get_multipart_req();
1755        let mut resp = gen_resp();
1756        assert!(matches!(
1757            range_header_filter(&req, &mut resp, None),
1758            RangeType::Multi(_)
1759        ));
1760
1761        // multipart with non-matching If-Range
1762        let mut req = get_multipart_req();
1763        req.insert_header("If-Range", "\"wrong\"").unwrap();
1764        let mut resp = gen_resp();
1765        assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None));
1766        assert_eq!(resp.status.as_u16(), 200);
1767        assert_eq!(
1768            resp.headers.get("accept-ranges").unwrap().as_bytes(),
1769            b"bytes"
1770        );
1771    }
1772
1773    pub struct RangeBodyFilter {
1774        pub range: RangeType,
1775        current: usize,
1776        multipart_idx: Option<usize>,
1777        cache_multipart_idx: Option<usize>,
1778    }
1779
1780    impl Default for RangeBodyFilter {
1781        fn default() -> Self {
1782            Self::new()
1783        }
1784    }
1785
1786    impl RangeBodyFilter {
1787        pub fn new() -> Self {
1788            RangeBodyFilter {
1789                range: RangeType::None,
1790                current: 0,
1791                multipart_idx: None,
1792                cache_multipart_idx: None,
1793            }
1794        }
1795
1796        pub fn new_range(range: RangeType) -> Self {
1797            RangeBodyFilter {
1798                multipart_idx: matches!(range, RangeType::Multi(_)).then_some(0),
1799                range,
1800                ..Default::default()
1801            }
1802        }
1803
1804        pub fn is_multipart_range(&self) -> bool {
1805            matches!(self.range, RangeType::Multi(_))
1806        }
1807
1808        /// Whether we should expect the cache body reader to seek again
1809        /// for a different range.
1810        pub fn should_cache_seek_again(&self) -> bool {
1811            match &self.range {
1812                RangeType::Multi(multipart_info) => self
1813                    .cache_multipart_idx
1814                    .is_some_and(|idx| idx != multipart_info.ranges.len() - 1),
1815                _ => false,
1816            }
1817        }
1818
1819        /// Returns the next multipart range to seek for the cache body reader.
1820        pub fn next_cache_multipart_range(&mut self) -> Range<usize> {
1821            match &self.range {
1822                RangeType::Multi(multipart_info) => {
1823                    match self.cache_multipart_idx.as_mut() {
1824                        Some(v) => *v += 1,
1825                        None => self.cache_multipart_idx = Some(0),
1826                    }
1827                    let cache_multipart_idx = self.cache_multipart_idx.expect("set above");
1828                    let multipart_idx = self.multipart_idx.expect("must be set on multirange");
1829                    // NOTE: currently this assumes once we start seeking multipart from the hit
1830                    // handler, it will continue to return can_seek_multipart true.
1831                    assert_eq!(multipart_idx, cache_multipart_idx,
1832                        "cache multipart idx should match multipart idx, or there is a hit handler bug");
1833                    multipart_info.ranges[cache_multipart_idx].clone()
1834                }
1835                _ => panic!("tried to advance multipart idx on non-multipart range"),
1836            }
1837        }
1838
1839        pub fn set_current_cursor(&mut self, current: usize) {
1840            self.current = current;
1841        }
1842
1843        pub fn set(&mut self, range: RangeType) {
1844            self.multipart_idx = matches!(range, RangeType::Multi(_)).then_some(0);
1845            self.range = range;
1846        }
1847
1848        // Emit final boundary footer for multipart requests
1849        pub fn finalize(&self, boundary: &String) -> Option<Bytes> {
1850            if let RangeType::Multi(_) = self.range {
1851                Some(Bytes::from(format!("\r\n--{boundary}--\r\n")))
1852            } else {
1853                None
1854            }
1855        }
1856
1857        pub fn filter_body(&mut self, data: Option<Bytes>) -> Option<Bytes> {
1858            match &self.range {
1859                RangeType::None => data,
1860                RangeType::Invalid => None,
1861                RangeType::Single(r) => {
1862                    let current = self.current;
1863                    self.current += data.as_ref().map_or(0, |d| d.len());
1864                    data.and_then(|d| Self::filter_range_data(r.start, r.end, current, d))
1865                }
1866
1867                RangeType::Multi(_) => {
1868                    let data = data?;
1869                    let current = self.current;
1870                    let data_len = data.len();
1871                    self.current += data_len;
1872                    self.filter_multi_range_body(data, current, data_len)
1873                }
1874            }
1875        }
1876
1877        fn filter_range_data(
1878            start: usize,
1879            end: usize,
1880            current: usize,
1881            data: Bytes,
1882        ) -> Option<Bytes> {
1883            if current + data.len() < start || current >= end {
1884                // if the current data is out side the desired range, just drop the data
1885                None
1886            } else if current >= start && current + data.len() <= end {
1887                // all data is within the slice
1888                Some(data)
1889            } else {
1890                // data:  current........current+data.len()
1891                // range: start...........end
1892                let slice_start = start.saturating_sub(current);
1893                let slice_end = std::cmp::min(data.len(), end - current);
1894                Some(data.slice(slice_start..slice_end))
1895            }
1896        }
1897
1898        // Returns the multipart header for a given range
1899        fn build_multipart_header(
1900            &self,
1901            range: &Range<usize>,
1902            boundary: &str,
1903            total_length: &usize,
1904            content_type: Option<&str>,
1905        ) -> Bytes {
1906            Bytes::from(format!(
1907                "\r\n--{}\r\n{}Content-Range: bytes {}-{}/{}\r\n\r\n",
1908                boundary,
1909                content_type.map_or(String::new(), |ct| format!("Content-Type: {ct}\r\n")),
1910                range.start,
1911                range.end - 1,
1912                total_length
1913            ))
1914        }
1915
1916        // Return true if chunk includes the start of the given range
1917        fn current_chunk_includes_range_start(
1918            &self,
1919            range: &Range<usize>,
1920            current: usize,
1921            data_len: usize,
1922        ) -> bool {
1923            range.start >= current && range.start < current + data_len
1924        }
1925
1926        // Return true if chunk includes the end of the given range
1927        fn current_chunk_includes_range_end(
1928            &self,
1929            range: &Range<usize>,
1930            current: usize,
1931            data_len: usize,
1932        ) -> bool {
1933            range.end > current && range.end <= current + data_len
1934        }
1935
1936        fn filter_multi_range_body(
1937            &mut self,
1938            data: Bytes,
1939            current: usize,
1940            data_len: usize,
1941        ) -> Option<Bytes> {
1942            let mut result = BytesMut::new();
1943
1944            let RangeType::Multi(multi_part_info) = &self.range else {
1945                return None;
1946            };
1947
1948            let multipart_idx = self.multipart_idx.expect("must be set on multirange");
1949            let final_range = multi_part_info.ranges.last()?;
1950
1951            let (_, remaining_ranges) = multi_part_info.ranges.as_slice().split_at(multipart_idx);
1952            // NOTE: current invariant is that the multipart info ranges are disjoint ascending
1953            // this code is invalid if this invariant is not upheld
1954            for range in remaining_ranges {
1955                if let Some(sliced) =
1956                    Self::filter_range_data(range.start, range.end, current, data.clone())
1957                {
1958                    if self.current_chunk_includes_range_start(range, current, data_len) {
1959                        result.extend_from_slice(&self.build_multipart_header(
1960                            range,
1961                            multi_part_info.boundary.as_ref(),
1962                            &multi_part_info.total_length,
1963                            multi_part_info.content_type.as_deref(),
1964                        ));
1965                    }
1966                    // Emit the actual data bytes
1967                    result.extend_from_slice(&sliced);
1968                    if self.current_chunk_includes_range_end(range, current, data_len) {
1969                        // If this was the last range, we should emit the final footer too
1970                        if range == final_range {
1971                            if let Some(final_chunk) = self.finalize(&multi_part_info.boundary) {
1972                                result.extend_from_slice(&final_chunk);
1973                            }
1974                        }
1975                        // done with this range
1976                        self.multipart_idx = Some(self.multipart_idx.expect("must be set") + 1);
1977                    }
1978                } else {
1979                    // no part of the data was within this range,
1980                    // so lower bound of this range (and remaining ranges) must be
1981                    // > current + data_len
1982                    break;
1983                }
1984            }
1985            if result.is_empty() {
1986                None
1987            } else {
1988                Some(result.freeze())
1989            }
1990        }
1991    }
1992
1993    #[test]
1994    fn test_range_body_filter_single() {
1995        let mut body_filter = RangeBodyFilter::new_range(RangeType::None);
1996        assert_eq!(body_filter.filter_body(Some("123".into())).unwrap(), "123");
1997
1998        let mut body_filter = RangeBodyFilter::new_range(RangeType::Invalid);
1999        assert!(body_filter.filter_body(Some("123".into())).is_none());
2000
2001        let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(0, 1));
2002        assert_eq!(body_filter.filter_body(Some("012".into())).unwrap(), "0");
2003        assert!(body_filter.filter_body(Some("345".into())).is_none());
2004
2005        let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(4, 6));
2006        assert!(body_filter.filter_body(Some("012".into())).is_none());
2007        assert_eq!(body_filter.filter_body(Some("345".into())).unwrap(), "45");
2008        assert!(body_filter.filter_body(Some("678".into())).is_none());
2009
2010        let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(1, 7));
2011        assert_eq!(body_filter.filter_body(Some("012".into())).unwrap(), "12");
2012        assert_eq!(body_filter.filter_body(Some("345".into())).unwrap(), "345");
2013        assert_eq!(body_filter.filter_body(Some("678".into())).unwrap(), "6");
2014    }
2015
2016    #[test]
2017    fn test_range_body_filter_multipart() {
2018        // Test #1 - Test multipart ranges from 1 chunk
2019        let data = Bytes::from("0123456789");
2020        let ranges = vec![0..3, 6..9];
2021        let content_length = data.len();
2022        let mut body_filter = RangeBodyFilter::new();
2023        body_filter.set(RangeType::new_multi(ranges.clone()));
2024
2025        body_filter
2026            .range
2027            .update_multirange_info(content_length, None);
2028
2029        let multi_range_info = body_filter
2030            .range
2031            .get_multirange_info()
2032            .cloned()
2033            .expect("Multipart Ranges should have MultiPartInfo struct");
2034
2035        // Pass the whole body in one chunk
2036        let output = body_filter.filter_body(Some(data)).unwrap();
2037        let footer = body_filter.finalize(&multi_range_info.boundary).unwrap();
2038
2039        // Convert to String so that we can inspect whole response
2040        let output_str = str::from_utf8(&output).unwrap();
2041        let final_boundary = str::from_utf8(&footer).unwrap();
2042        let boundary = &multi_range_info.boundary;
2043
2044        // Check part headers
2045        for (i, range) in ranges.iter().enumerate() {
2046            let header = &format!(
2047                "--{}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2048                boundary,
2049                range.start,
2050                range.end - 1,
2051                content_length
2052            );
2053            assert!(
2054                output_str.contains(header),
2055                "Missing part header {} in multipart body",
2056                i
2057            );
2058            // Check body matches
2059            let expected_body = &"0123456789"[range.clone()];
2060            assert!(
2061                output_str.contains(expected_body),
2062                "Missing body {} for range {:?}",
2063                expected_body,
2064                range
2065            )
2066        }
2067        // Check the final boundary footer
2068        assert_eq!(final_boundary, format!("\r\n--{}--\r\n", boundary));
2069
2070        // Test #2 - Test multipart ranges from multiple chunks
2071        let full_body = b"0123456789";
2072        let ranges = vec![0..2, 4..6, 8..9];
2073        let content_length = full_body.len();
2074        let content_type = "text/plain".to_string();
2075        let mut body_filter = RangeBodyFilter::new();
2076        body_filter.set(RangeType::new_multi(ranges.clone()));
2077
2078        body_filter
2079            .range
2080            .update_multirange_info(content_length, Some(content_type.clone()));
2081
2082        let multi_range_info = body_filter
2083            .range
2084            .get_multirange_info()
2085            .cloned()
2086            .expect("Multipart Ranges should have MultiPartInfo struct");
2087
2088        // Split the body into 4 chunks
2089        let chunk1 = Bytes::from_static(b"012");
2090        let chunk2 = Bytes::from_static(b"345");
2091        let chunk3 = Bytes::from_static(b"678");
2092        let chunk4 = Bytes::from_static(b"9");
2093
2094        let mut collected_bytes = BytesMut::new();
2095        for chunk in [chunk1, chunk2, chunk3, chunk4] {
2096            if let Some(filtered) = body_filter.filter_body(Some(chunk)) {
2097                collected_bytes.extend_from_slice(&filtered);
2098            }
2099        }
2100        if let Some(final_boundary) = body_filter.finalize(&multi_range_info.boundary) {
2101            collected_bytes.extend_from_slice(&final_boundary);
2102        }
2103
2104        let output_str = str::from_utf8(&collected_bytes).unwrap();
2105        let boundary = multi_range_info.boundary;
2106
2107        for (i, range) in ranges.iter().enumerate() {
2108            let header = &format!(
2109                "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2110                boundary,
2111                content_type,
2112                range.start,
2113                range.end - 1,
2114                content_length
2115            );
2116            let expected_body = &full_body[range.clone()];
2117            let expected_output = format!("{}{}", header, str::from_utf8(expected_body).unwrap());
2118
2119            assert!(
2120                output_str.contains(&expected_output),
2121                "Missing or malformed part {} in multipart body. \n Expected: \n{}\n Got: \n{}",
2122                i,
2123                expected_output,
2124                output_str
2125            )
2126        }
2127
2128        assert!(
2129            output_str.ends_with(&format!("\r\n--{}--\r\n", boundary)),
2130            "Missing final boundary"
2131        );
2132
2133        // Test #3 - Test multipart ranges from multiple chunks, with ranges spanning chunks
2134        let full_body = b"abcdefghijkl";
2135        let ranges = vec![2..7, 9..11];
2136        let content_length = full_body.len();
2137        let content_type = "application/octet-stream".to_string();
2138        let mut body_filter = RangeBodyFilter::new();
2139        body_filter.set(RangeType::new_multi(ranges.clone()));
2140
2141        body_filter
2142            .range
2143            .update_multirange_info(content_length, Some(content_type.clone()));
2144
2145        let multi_range_info = body_filter
2146            .range
2147            .clone()
2148            .get_multirange_info()
2149            .cloned()
2150            .expect("Multipart Ranges should have MultiPartInfo struct");
2151
2152        // Split the body into 4 chunks
2153        let chunk1 = Bytes::from_static(b"abc");
2154        let chunk2 = Bytes::from_static(b"def");
2155        let chunk3 = Bytes::from_static(b"ghi");
2156        let chunk4 = Bytes::from_static(b"jkl");
2157
2158        let mut collected_bytes = BytesMut::new();
2159        for chunk in [chunk1, chunk2, chunk3, chunk4] {
2160            if let Some(filtered) = body_filter.filter_body(Some(chunk)) {
2161                collected_bytes.extend_from_slice(&filtered);
2162            }
2163        }
2164        if let Some(final_boundary) = body_filter.finalize(&multi_range_info.boundary) {
2165            collected_bytes.extend_from_slice(&final_boundary);
2166        }
2167
2168        let output_str = str::from_utf8(&collected_bytes).unwrap();
2169        let boundary = &multi_range_info.boundary;
2170
2171        let header1 = &format!(
2172            "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2173            boundary,
2174            content_type,
2175            ranges[0].start,
2176            ranges[0].end - 1,
2177            content_length
2178        );
2179        let header2 = &format!(
2180            "--{}\r\nContent-Type: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
2181            boundary,
2182            content_type,
2183            ranges[1].start,
2184            ranges[1].end - 1,
2185            content_length
2186        );
2187
2188        assert!(output_str.contains(header1));
2189        assert!(output_str.contains(header2));
2190
2191        let expected_body_slices = ["cdefg", "jk"];
2192
2193        assert!(
2194            output_str.contains(expected_body_slices[0]),
2195            "Missing expected sliced body {}",
2196            expected_body_slices[0]
2197        );
2198
2199        assert!(
2200            output_str.contains(expected_body_slices[1]),
2201            "Missing expected sliced body {}",
2202            expected_body_slices[1]
2203        );
2204
2205        assert!(
2206            output_str.ends_with(&format!("\r\n--{}--\r\n", boundary)),
2207            "Missing final boundary"
2208        );
2209    }
2210}
2211
/// A state machine for proxy logic to tell when to use cache in the case of
/// miss/revalidation/error.
///
/// Hit flow: `CacheHeader` -> `CacheBody(true)` -> `Done`, or `CacheHeaderOnly` -> `Done`.
/// Miss flow: `CacheHeaderMiss` -> `CacheBodyMiss(true)` or `CacheHeaderOnlyMiss`,
/// ending at `DoneMiss`.
#[derive(Debug)]
pub(crate) enum ServeFromCache {
    /// Not using cache.
    Off,
    /// Should serve the cached header next, followed by the cached body.
    CacheHeader,
    /// Should serve the cached header only (no body).
    CacheHeaderOnly,
    /// Should serve the cached header only, but the upstream response should still be
    /// admitted to cache.
    CacheHeaderOnlyMiss,
    /// Should serve the cached body from the hit handler.
    ///
    /// The flag is `true` while the hit handler still has to be seeked to the requested
    /// range before reading (this is how the state is entered from `CacheHeader`), and
    /// `false` once that seek step has run.
    CacheBody(bool),
    /// Should serve the cached header, but the upstream response should be admitted to cache.
    ///
    /// This is the starting state for misses, which go to `CacheBodyMiss` or
    /// `CacheHeaderOnlyMiss` before ending at `DoneMiss`.
    CacheHeaderMiss,
    /// Should serve the body through the miss body reader while the upstream response is
    /// admitted to cache.
    ///
    /// The flag has the same meaning as in `CacheBody`: `true` while the seek step for the
    /// requested range is still pending, `false` once it has run.
    CacheBodyMiss(bool),
    /// Done serving the cached body.
    Done,
    /// Done serving the cached body, but the upstream response should continue to be
    /// admitted to cache.
    DoneMiss,
}
2237
2238impl ServeFromCache {
2239    pub fn new() -> Self {
2240        Self::Off
2241    }
2242
2243    pub fn is_on(&self) -> bool {
2244        !matches!(self, Self::Off)
2245    }
2246
2247    pub fn is_miss(&self) -> bool {
2248        matches!(
2249            self,
2250            Self::CacheHeaderMiss
2251                | Self::CacheHeaderOnlyMiss
2252                | Self::CacheBodyMiss(_)
2253                | Self::DoneMiss
2254        )
2255    }
2256
2257    pub fn is_miss_header(&self) -> bool {
2258        // NOTE: this check is for checking if miss was just enabled, so it is excluding
2259        // HeaderOnlyMiss
2260        matches!(self, Self::CacheHeaderMiss)
2261    }
2262
2263    pub fn is_miss_body(&self) -> bool {
2264        matches!(self, Self::CacheBodyMiss(_))
2265    }
2266
2267    pub fn should_discard_upstream(&self) -> bool {
2268        self.is_on() && !self.is_miss()
2269    }
2270
2271    pub fn should_send_to_downstream(&self) -> bool {
2272        !self.is_on()
2273    }
2274
2275    pub fn enable(&mut self) {
2276        *self = Self::CacheHeader;
2277    }
2278
2279    pub fn enable_miss(&mut self) {
2280        if !self.is_on() {
2281            *self = Self::CacheHeaderMiss;
2282        }
2283    }
2284
2285    pub fn enable_header_only(&mut self) {
2286        match self {
2287            Self::CacheBody(_) => *self = Self::Done, // TODO: make sure no body is read yet
2288            Self::CacheBodyMiss(_) => *self = Self::DoneMiss,
2289            _ => {
2290                if self.is_miss() {
2291                    *self = Self::CacheHeaderOnlyMiss;
2292                } else {
2293                    *self = Self::CacheHeaderOnly;
2294                }
2295            }
2296        }
2297    }
2298
2299    // This function is (best effort) cancel-safe to be used in select
2300    pub async fn next_http_task(
2301        &mut self,
2302        cache: &mut HttpCache,
2303        range: &mut RangeBodyFilter,
2304        upgraded: bool,
2305    ) -> Result<HttpTask> {
2306        fn body_task(data: Bytes, upgraded: bool) -> HttpTask {
2307            if upgraded {
2308                HttpTask::UpgradedBody(Some(data), false)
2309            } else {
2310                HttpTask::Body(Some(data), false)
2311            }
2312        }
2313
2314        if !cache.enabled() {
2315            // Cache is disabled due to internal error
2316            // TODO: if nothing is sent to eyeball yet, figure out a way to recovery by
2317            // fetching from upstream
2318            return Error::e_explain(InternalError, "Cache disabled");
2319        }
2320        match self {
2321            Self::Off => panic!("ProxyUseCache not enabled"),
2322            Self::CacheHeader => {
2323                *self = Self::CacheBody(true);
2324                Ok(HttpTask::Header(cache_hit_header(cache), false)) // false for now
2325            }
2326            Self::CacheHeaderMiss => {
2327                *self = Self::CacheBodyMiss(true);
2328                Ok(HttpTask::Header(cache_hit_header(cache), false)) // false for now
2329            }
2330            Self::CacheHeaderOnly => {
2331                *self = Self::Done;
2332                Ok(HttpTask::Header(cache_hit_header(cache), true))
2333            }
2334            Self::CacheHeaderOnlyMiss => {
2335                *self = Self::DoneMiss;
2336                Ok(HttpTask::Header(cache_hit_header(cache), true))
2337            }
2338            Self::CacheBody(should_seek) => {
2339                log::trace!("cache body should seek: {should_seek}");
2340                if *should_seek {
2341                    self.maybe_seek_hit_handler(cache, range)?;
2342                }
2343                loop {
2344                    if let Some(b) = cache.hit_handler().read_body().await? {
2345                        return Ok(body_task(b, upgraded));
2346                    }
2347                    // EOF from hit handler for body requested
2348                    // if multipart, then seek again
2349                    if range.should_cache_seek_again() {
2350                        self.maybe_seek_hit_handler(cache, range)?;
2351                    } else {
2352                        *self = Self::Done;
2353                        return Ok(HttpTask::Done);
2354                    }
2355                }
2356            }
2357            Self::CacheBodyMiss(should_seek) => {
2358                if *should_seek {
2359                    self.maybe_seek_miss_handler(cache, range)?;
2360                }
2361                // safety: caller of enable_miss() call it only if the async_body_reader exist
2362                loop {
2363                    if let Some(b) = cache.miss_body_reader().unwrap().read_body().await? {
2364                        return Ok(body_task(b, upgraded));
2365                    } else {
2366                        // EOF from hit handler for body requested
2367                        // if multipart, then seek again
2368                        if range.should_cache_seek_again() {
2369                            self.maybe_seek_miss_handler(cache, range)?;
2370                        } else {
2371                            *self = Self::DoneMiss;
2372                            return Ok(HttpTask::Done);
2373                        }
2374                    }
2375                }
2376            }
2377            Self::Done => Ok(HttpTask::Done),
2378            Self::DoneMiss => Ok(HttpTask::Done),
2379        }
2380    }
2381
2382    fn maybe_seek_miss_handler(
2383        &mut self,
2384        cache: &mut HttpCache,
2385        range_filter: &mut RangeBodyFilter,
2386    ) -> Result<()> {
2387        match &range_filter.range {
2388            RangeType::Single(range) => {
2389                // safety: called only if the async_body_reader exists
2390                if cache.miss_body_reader().unwrap().can_seek() {
2391                    cache
2392                        .miss_body_reader()
2393                        // safety: called only if the async_body_reader exists
2394                        .unwrap()
2395                        .seek(range.start, Some(range.end))
2396                        .or_err(InternalError, "cannot seek miss handler")?;
2397                    // Because the miss body reader is seeking, we no longer need the
2398                    // RangeBodyFilter's help to return the requested byte range.
2399                    range_filter.range = RangeType::None;
2400                }
2401            }
2402            RangeType::Multi(_info) => {
2403                // safety: called only if the async_body_reader exists
2404                if cache.miss_body_reader().unwrap().can_seek_multipart() {
2405                    let range = range_filter.next_cache_multipart_range();
2406                    cache
2407                        .miss_body_reader()
2408                        .unwrap()
2409                        .seek_multipart(range.start, Some(range.end))
2410                        .or_err(InternalError, "cannot seek hit handler for multirange")?;
2411                    // we still need RangeBodyFilter's help to transform the byte
2412                    // range into a multipart response.
2413                    range_filter.set_current_cursor(range.start);
2414                }
2415            }
2416            _ => {}
2417        }
2418
2419        *self = Self::CacheBodyMiss(false);
2420        Ok(())
2421    }
2422
2423    fn maybe_seek_hit_handler(
2424        &mut self,
2425        cache: &mut HttpCache,
2426        range_filter: &mut RangeBodyFilter,
2427    ) -> Result<()> {
2428        match &range_filter.range {
2429            RangeType::Single(range) => {
2430                if cache.hit_handler().can_seek() {
2431                    cache
2432                        .hit_handler()
2433                        .seek(range.start, Some(range.end))
2434                        .or_err(InternalError, "cannot seek hit handler")?;
2435                    // Because the hit handler is seeking, we no longer need the
2436                    // RangeBodyFilter's help to return the requested byte range.
2437                    range_filter.range = RangeType::None;
2438                }
2439            }
2440            RangeType::Multi(_info) => {
2441                if cache.hit_handler().can_seek_multipart() {
2442                    let range = range_filter.next_cache_multipart_range();
2443                    cache
2444                        .hit_handler()
2445                        .seek_multipart(range.start, Some(range.end))
2446                        .or_err(InternalError, "cannot seek hit handler for multirange")?;
2447                    // we still need RangeBodyFilter's help to transform the byte
2448                    // range into a multipart response.
2449                    range_filter.set_current_cursor(range.start);
2450                }
2451            }
2452            _ => {}
2453        }
2454        *self = Self::CacheBody(false);
2455        Ok(())
2456    }
2457}