1use std::fmt::Debug;
2use std::future::Ready;
3
4use bytes::Bytes;
5use chrono::Utc;
6use futures::FutureExt;
7use futures::future::BoxFuture;
8use hitbox::{
9 CachePolicy, CacheValue, CacheableResponse, EntityPolicyConfig, predicate::PredicateResult,
10};
11use http::{HeaderMap, Response, response::Parts};
12use hyper::body::Body as HttpBody;
13use serde::{Deserialize, Serialize};
14
15use crate::CacheableSubject;
16use crate::body::BufferedBody;
17use crate::predicates::header::HasHeaders;
18use crate::predicates::version::HasVersion;
19
20#[derive(Debug)]
66pub struct CacheableHttpResponse<ResBody>
67where
68 ResBody: HttpBody,
69{
70 pub parts: Parts,
72 pub body: BufferedBody<ResBody>,
74}
75
76impl<ResBody> CacheableHttpResponse<ResBody>
77where
78 ResBody: HttpBody,
79{
80 pub fn from_response(response: Response<BufferedBody<ResBody>>) -> Self {
85 let (parts, body) = response.into_parts();
86 CacheableHttpResponse { parts, body }
87 }
88
89 pub fn into_response(self) -> Response<BufferedBody<ResBody>> {
93 Response::from_parts(self.parts, self.body)
94 }
95}
96
97impl<ResBody> CacheableSubject for CacheableHttpResponse<ResBody>
98where
99 ResBody: HttpBody,
100{
101 type Body = ResBody;
102 type Parts = Parts;
103
104 fn into_parts(self) -> (Self::Parts, BufferedBody<Self::Body>) {
105 (self.parts, self.body)
106 }
107
108 fn from_parts(parts: Self::Parts, body: BufferedBody<Self::Body>) -> Self {
109 Self { parts, body }
110 }
111}
112
113impl<ResBody> HasHeaders for CacheableHttpResponse<ResBody>
114where
115 ResBody: HttpBody,
116{
117 fn headers(&self) -> &http::HeaderMap {
118 &self.parts.headers
119 }
120}
121
122impl<ResBody> HasVersion for CacheableHttpResponse<ResBody>
123where
124 ResBody: HttpBody,
125{
126 fn http_version(&self) -> http::Version {
127 self.parts.version
128 }
129}
130
#[cfg(feature = "rkyv_format")]
mod rkyv_version {
    //! rkyv adapter that archives an [`http::Version`] as a single byte code.

    use http::Version;
    use rkyv::{
        Place,
        rancor::Fallible,
        with::{ArchiveWith, DeserializeWith, SerializeWith},
    };

    /// `#[rkyv(with = ...)]` wrapper storing a `Version` as a `u8`.
    pub struct VersionAsU8;

    /// Byte code assigned to every known HTTP version. Anything not listed is
    /// written as `11` and any unknown code is read back as HTTP/1.1.
    const VERSION_CODES: [(Version, u8); 5] = [
        (Version::HTTP_09, 9),
        (Version::HTTP_10, 10),
        (Version::HTTP_11, 11),
        (Version::HTTP_2, 20),
        (Version::HTTP_3, 30),
    ];

    impl ArchiveWith<Version> for VersionAsU8 {
        type Archived = rkyv::Archived<u8>;
        type Resolver = rkyv::Resolver<u8>;

        fn resolve_with(field: &Version, resolver: Self::Resolver, out: Place<Self::Archived>) {
            let code = version_to_u8(*field);
            rkyv::Archive::resolve(&code, resolver, out);
        }
    }

    impl<S: Fallible + rkyv::ser::Writer + ?Sized> SerializeWith<Version, S> for VersionAsU8 {
        fn serialize_with(field: &Version, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
            let code = version_to_u8(*field);
            rkyv::Serialize::serialize(&code, serializer)
        }
    }

    impl<D: Fallible + ?Sized> DeserializeWith<rkyv::Archived<u8>, Version, D> for VersionAsU8 {
        fn deserialize_with(
            field: &rkyv::Archived<u8>,
            deserializer: &mut D,
        ) -> Result<Version, D::Error> {
            let code: u8 = rkyv::Deserialize::deserialize(field, deserializer)?;
            Ok(u8_to_version(code))
        }
    }

    /// Looks up the byte code for `version`, defaulting to `11`.
    fn version_to_u8(version: Version) -> u8 {
        VERSION_CODES
            .iter()
            .find(|(known, _)| *known == version)
            .map_or(11, |&(_, code)| code)
    }

    /// Looks up the version for byte `value`, defaulting to HTTP/1.1.
    fn u8_to_version(value: u8) -> Version {
        VERSION_CODES
            .iter()
            .find(|(_, code)| *code == value)
            .map_or(Version::HTTP_11, |&(version, _)| version)
    }
}
191
#[cfg(feature = "rkyv_format")]
mod rkyv_status_code {
    //! rkyv adapter that archives an [`http::StatusCode`] as its numeric `u16`.

    use http::StatusCode;
    use rkyv::{
        Place,
        rancor::Fallible,
        with::{ArchiveWith, DeserializeWith, SerializeWith},
    };

    /// `#[rkyv(with = ...)]` wrapper storing a `StatusCode` as a `u16`.
    pub struct StatusCodeAsU16;

    impl ArchiveWith<StatusCode> for StatusCodeAsU16 {
        type Archived = rkyv::Archived<u16>;
        type Resolver = rkyv::Resolver<u16>;

        fn resolve_with(field: &StatusCode, resolver: Self::Resolver, out: Place<Self::Archived>) {
            rkyv::Archive::resolve(&field.as_u16(), resolver, out);
        }
    }

    impl<S: Fallible + rkyv::ser::Writer + ?Sized> SerializeWith<StatusCode, S> for StatusCodeAsU16 {
        fn serialize_with(
            field: &StatusCode,
            serializer: &mut S,
        ) -> Result<Self::Resolver, S::Error> {
            let numeric = field.as_u16();
            rkyv::Serialize::serialize(&numeric, serializer)
        }
    }

    impl<D: Fallible + ?Sized> DeserializeWith<rkyv::Archived<u16>, StatusCode, D> for StatusCodeAsU16 {
        /// Reads the stored number back into a `StatusCode`. A number that is not
        /// a valid status code becomes `500 Internal Server Error` instead of an
        /// error.
        fn deserialize_with(
            field: &rkyv::Archived<u16>,
            deserializer: &mut D,
        ) -> Result<StatusCode, D::Error> {
            let raw: u16 = rkyv::Deserialize::deserialize(field, deserializer)?;
            let status = match StatusCode::from_u16(raw) {
                Ok(code) => code,
                Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
            };
            Ok(status)
        }
    }
}
233
#[cfg(feature = "rkyv_format")]
mod rkyv_header_map {
    //! rkyv adapter that archives an [`http::HeaderMap`] as a list of
    //! `(name, value bytes)` pairs, one pair per header value.

    use http::HeaderMap;
    use rkyv::{
        Place,
        rancor::Fallible,
        with::{ArchiveWith, DeserializeWith, SerializeWith},
    };

    /// Upper bound on the capacity pre-allocated when rebuilding a `HeaderMap`.
    ///
    /// The archived entry count comes from cache storage, and
    /// `HeaderMap::with_capacity` panics when asked for more than the map's
    /// maximum capacity. The count is therefore only a bounded hint; the map
    /// still grows past this value as entries are appended.
    const MAX_CAPACITY_HINT: usize = 128;

    /// `#[rkyv(with = ...)]` wrapper storing a `HeaderMap` as
    /// `Vec<(String, Vec<u8>)>`.
    pub struct AsHeaderVec;

    impl ArchiveWith<HeaderMap> for AsHeaderVec {
        type Archived = rkyv::Archived<Vec<(String, Vec<u8>)>>;
        type Resolver = rkyv::Resolver<Vec<(String, Vec<u8>)>>;

        fn resolve_with(field: &HeaderMap, resolver: Self::Resolver, out: Place<Self::Archived>) {
            // Resolving an archived vec only needs its length. `HeaderMap::len`
            // counts every value (a name with N values counts N times), which is
            // exactly the number of pairs `HeaderMap::iter` yields in
            // `serialize_with`. There is no need to copy every header again here.
            rkyv::vec::ArchivedVec::resolve_from_len(field.len(), resolver, out);
        }
    }

    impl<S> SerializeWith<HeaderMap, S> for AsHeaderVec
    where
        S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized,
        S::Error: rkyv::rancor::Source,
    {
        fn serialize_with(
            field: &HeaderMap,
            serializer: &mut S,
        ) -> Result<Self::Resolver, S::Error> {
            rkyv::Serialize::serialize(&header_pairs(field), serializer)
        }
    }

    impl<D> DeserializeWith<rkyv::Archived<Vec<(String, Vec<u8>)>>, HeaderMap, D> for AsHeaderVec
    where
        D: Fallible + ?Sized,
    {
        /// Rebuilds a `HeaderMap` from the archived pairs, appending values in
        /// archived order. Pairs whose name or value no longer parse as a valid
        /// header are skipped rather than failing the whole read.
        fn deserialize_with(
            field: &rkyv::Archived<Vec<(String, Vec<u8>)>>,
            _deserializer: &mut D,
        ) -> Result<HeaderMap, D::Error> {
            let mut map = HeaderMap::with_capacity(field.len().min(MAX_CAPACITY_HINT));

            for item in field.iter() {
                let name_bytes: &[u8] = item.0.as_str().as_bytes();
                let value_bytes: &[u8] = item.1.as_slice();

                if let (Ok(header_name), Ok(header_value)) = (
                    http::header::HeaderName::from_bytes(name_bytes),
                    http::header::HeaderValue::from_bytes(value_bytes),
                ) {
                    map.append(header_name, header_value);
                }
            }
            Ok(map)
        }
    }

    /// Copies every `(name, value)` pair out of `headers`, one entry per value,
    /// in `HeaderMap::iter` order.
    fn header_pairs(headers: &HeaderMap) -> Vec<(String, Vec<u8>)> {
        headers
            .iter()
            .map(|(name, value)| (name.as_str().to_owned(), value.as_bytes().to_vec()))
            .collect()
    }
}
303
304#[derive(Serialize, Deserialize, Debug, Clone)]
326#[cfg_attr(
327 feature = "rkyv_format",
328 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
329)]
330pub struct SerializableHttpResponse {
331 #[serde(with = "http_serde::status_code")]
332 #[cfg_attr(feature = "rkyv_format", rkyv(with = rkyv_status_code::StatusCodeAsU16))]
333 status: http::StatusCode,
334 #[serde(with = "http_serde::version")]
335 #[cfg_attr(feature = "rkyv_format", rkyv(with = rkyv_version::VersionAsU8))]
336 version: http::Version,
337 body: Bytes,
338 #[serde(with = "http_serde::header_map")]
339 #[cfg_attr(feature = "rkyv_format", rkyv(with = rkyv_header_map::AsHeaderVec))]
340 headers: HeaderMap,
341}
342
343impl<ResBody> CacheableResponse for CacheableHttpResponse<ResBody>
344where
345 ResBody: HttpBody + Send + 'static,
346 ResBody::Error: Send,
347 ResBody::Data: Send,
348{
349 type Cached = SerializableHttpResponse;
350 type Subject = Self;
351 type IntoCachedFuture = BoxFuture<'static, CachePolicy<Self::Cached, Self>>;
352 type FromCachedFuture = Ready<Self>;
353
354 async fn cache_policy<P>(
355 self,
356 predicates: P,
357 config: &EntityPolicyConfig,
358 ) -> hitbox::ResponseCachePolicy<Self>
359 where
360 P: hitbox::Predicate<Subject = Self::Subject> + Send + Sync,
361 {
362 match predicates.check(self).await {
363 PredicateResult::Cacheable(cacheable) => match cacheable.into_cached().await {
364 CachePolicy::Cacheable(res) => CachePolicy::Cacheable(CacheValue::new(
365 res,
366 config.ttl.map(|duration| Utc::now() + duration),
367 config.stale_ttl.map(|duration| Utc::now() + duration),
368 )),
369 CachePolicy::NonCacheable(res) => CachePolicy::NonCacheable(res),
370 },
371 PredicateResult::NonCacheable(res) => CachePolicy::NonCacheable(res),
372 }
373 }
374
375 fn into_cached(self) -> Self::IntoCachedFuture {
376 async move {
377 let body_bytes = match self.body.collect().await {
378 Ok(bytes) => bytes,
379 Err(error_body) => {
380 return CachePolicy::NonCacheable(CacheableHttpResponse {
382 parts: self.parts,
383 body: error_body,
384 });
385 }
386 };
387
388 CachePolicy::Cacheable(SerializableHttpResponse {
391 status: self.parts.status,
392 version: self.parts.version,
393 body: body_bytes,
394 headers: self.parts.headers,
395 })
396 }
397 .boxed()
398 }
399
400 fn from_cached(cached: Self::Cached) -> Self::FromCachedFuture {
401 let body = BufferedBody::Complete(Some(cached.body));
402 let mut response = Response::new(body);
403 *response.status_mut() = cached.status;
404 *response.version_mut() = cached.version;
405 *response.headers_mut() = cached.headers;
406
407 std::future::ready(CacheableHttpResponse::from_response(response))
408 }
409}