1use std::fmt::Debug;
2use std::future::Ready;
3
4use bytes::Bytes;
5use chrono::Utc;
6use futures::FutureExt;
7use futures::future::BoxFuture;
8use hitbox::{
9 CachePolicy, CacheValue, CacheableResponse, EntityPolicyConfig, predicate::PredicateResult,
10};
11use http::{HeaderMap, Response, response::Parts};
12use hyper::body::Body as HttpBody;
13use serde::{Deserialize, Serialize};
14
15use crate::CacheableSubject;
16use crate::body::BufferedBody;
17use crate::predicates::header::HasHeaders;
18use crate::predicates::version::HasVersion;
19
20#[derive(Debug)]
66pub struct CacheableHttpResponse<ResBody>
67where
68 ResBody: HttpBody,
69{
70 pub parts: Parts,
72 pub body: BufferedBody<ResBody>,
74}
75
76impl<ResBody> CacheableHttpResponse<ResBody>
77where
78 ResBody: HttpBody,
79{
80 pub fn from_response(response: Response<BufferedBody<ResBody>>) -> Self {
85 let (parts, body) = response.into_parts();
86 CacheableHttpResponse { parts, body }
87 }
88
89 pub fn into_response(self) -> Response<BufferedBody<ResBody>> {
93 Response::from_parts(self.parts, self.body)
94 }
95}
96
97impl<ResBody> CacheableSubject for CacheableHttpResponse<ResBody>
98where
99 ResBody: HttpBody,
100{
101 type Body = ResBody;
102 type Parts = Parts;
103
104 fn into_parts(self) -> (Self::Parts, BufferedBody<Self::Body>) {
105 (self.parts, self.body)
106 }
107
108 fn from_parts(parts: Self::Parts, body: BufferedBody<Self::Body>) -> Self {
109 Self { parts, body }
110 }
111}
112
113impl<ResBody> HasHeaders for CacheableHttpResponse<ResBody>
114where
115 ResBody: HttpBody,
116{
117 fn headers(&self) -> &http::HeaderMap {
118 &self.parts.headers
119 }
120}
121
122impl<ResBody> HasVersion for CacheableHttpResponse<ResBody>
123where
124 ResBody: HttpBody,
125{
126 fn http_version(&self) -> http::Version {
127 self.parts.version
128 }
129}
130
/// Error type shared by the rkyv field adapters below.
#[cfg(feature = "rkyv_format")]
mod rkyv_error {
    use std::fmt;

    /// Reasons a response field could not be converted to or from its rkyv form.
    #[derive(Debug)]
    pub(super) enum InvalidArchivedData {
        /// The response uses an HTTP version that has no archived tag.
        UnsupportedHttpVersion,
        /// The archived version tag does not map to a known HTTP version.
        UnknownVersionByte(u8),
        /// The archived number was rejected by `http::StatusCode::from_u16`.
        InvalidStatusCode(u16),
        /// An archived header name is not a valid header name (payload: the name).
        InvalidHeaderName(String),
        /// An archived header value is not a valid header value (payload: the header's name).
        InvalidHeaderValue(String),
    }

    impl fmt::Display for InvalidArchivedData {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                InvalidArchivedData::UnsupportedHttpVersion => {
                    f.write_str("unsupported HTTP version")
                }
                InvalidArchivedData::UnknownVersionByte(byte) => {
                    write!(f, "unknown HTTP version byte: {byte}")
                }
                InvalidArchivedData::InvalidStatusCode(code) => {
                    write!(f, "invalid HTTP status code: {code}")
                }
                InvalidArchivedData::InvalidHeaderName(header) => {
                    write!(f, "invalid header name: {header}")
                }
                InvalidArchivedData::InvalidHeaderValue(header) => {
                    write!(f, "invalid header value for: {header}")
                }
            }
        }
    }

    impl std::error::Error for InvalidArchivedData {}
}
160
/// rkyv adapter that stores an [`http::Version`] as a single tag byte.
#[cfg(feature = "rkyv_format")]
mod rkyv_version {
    use http::Version;
    use rkyv::{
        Place,
        rancor::{Fallible, Source},
        with::{ArchiveWith, DeserializeWith, SerializeWith},
    };

    use super::rkyv_error::InvalidArchivedData;

    /// `rkyv` wrapper (`#[rkyv(with = VersionAsU8)]`) archiving a `Version` as a `u8` tag.
    pub(super) struct VersionAsU8;

    /// The single source of truth for the version <-> tag mapping, used in both directions.
    const VERSION_TAGS: [(Version, u8); 5] = [
        (Version::HTTP_09, 9),
        (Version::HTTP_10, 10),
        (Version::HTTP_11, 11),
        (Version::HTTP_2, 20),
        (Version::HTTP_3, 30),
    ];

    impl ArchiveWith<Version> for VersionAsU8 {
        type Archived = rkyv::Archived<u8>;
        type Resolver = rkyv::Resolver<u8>;

        fn resolve_with(field: &Version, resolver: Self::Resolver, out: Place<Self::Archived>) {
            // `serialize_with` already rejected unmapped versions, so the 0 fallback
            // should not be reached in practice.
            let tag = version_to_u8(*field).unwrap_or(0);
            rkyv::Archive::resolve(&tag, resolver, out);
        }
    }

    impl<S> SerializeWith<Version, S> for VersionAsU8
    where
        S: Fallible + rkyv::ser::Writer + ?Sized,
        S::Error: Source,
    {
        fn serialize_with(field: &Version, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
            match version_to_u8(*field) {
                Some(tag) => rkyv::Serialize::serialize(&tag, serializer),
                None => Err(S::Error::new(InvalidArchivedData::UnsupportedHttpVersion)),
            }
        }
    }

    impl<D> DeserializeWith<rkyv::Archived<u8>, Version, D> for VersionAsU8
    where
        D: Fallible + ?Sized,
        D::Error: Source,
    {
        fn deserialize_with(
            field: &rkyv::Archived<u8>,
            deserializer: &mut D,
        ) -> Result<Version, D::Error> {
            let tag: u8 = rkyv::Deserialize::deserialize(field, deserializer)?;
            match u8_to_version(tag) {
                Some(version) => Ok(version),
                None => Err(D::Error::new(InvalidArchivedData::UnknownVersionByte(tag))),
            }
        }
    }

    /// Returns the archived tag for `version`, or `None` if it has no mapping.
    fn version_to_u8(version: Version) -> Option<u8> {
        VERSION_TAGS
            .iter()
            .find(|(known, _)| *known == version)
            .map(|&(_, tag)| tag)
    }

    /// Returns the version encoded by `value`, or `None` for an unknown tag.
    fn u8_to_version(value: u8) -> Option<Version> {
        VERSION_TAGS
            .iter()
            .find(|&&(_, tag)| tag == value)
            .map(|&(version, _)| version)
    }
}
236
/// rkyv adapter that stores an [`http::StatusCode`] as its numeric value.
#[cfg(feature = "rkyv_format")]
mod rkyv_status_code {
    use http::StatusCode;
    use rkyv::{
        Place,
        rancor::{Fallible, Source},
        with::{ArchiveWith, DeserializeWith, SerializeWith},
    };

    use super::rkyv_error::InvalidArchivedData;

    /// `rkyv` wrapper (`#[rkyv(with = StatusCodeAsU16)]`) archiving a `StatusCode` as `u16`.
    pub(super) struct StatusCodeAsU16;

    impl ArchiveWith<StatusCode> for StatusCodeAsU16 {
        type Archived = rkyv::Archived<u16>;
        type Resolver = rkyv::Resolver<u16>;

        fn resolve_with(field: &StatusCode, resolver: Self::Resolver, out: Place<Self::Archived>) {
            rkyv::Archive::resolve(&field.as_u16(), resolver, out);
        }
    }

    impl<S> SerializeWith<StatusCode, S> for StatusCodeAsU16
    where
        S: Fallible + rkyv::ser::Writer + ?Sized,
    {
        fn serialize_with(
            field: &StatusCode,
            serializer: &mut S,
        ) -> Result<Self::Resolver, S::Error> {
            let code = field.as_u16();
            rkyv::Serialize::serialize(&code, serializer)
        }
    }

    impl<D> DeserializeWith<rkyv::Archived<u16>, StatusCode, D> for StatusCodeAsU16
    where
        D: Fallible + ?Sized,
        D::Error: Source,
    {
        /// Restores the status code, rejecting numbers `StatusCode::from_u16` refuses.
        fn deserialize_with(
            field: &rkyv::Archived<u16>,
            deserializer: &mut D,
        ) -> Result<StatusCode, D::Error> {
            let code: u16 = rkyv::Deserialize::deserialize(field, deserializer)?;
            match StatusCode::from_u16(code) {
                Ok(status) => Ok(status),
                Err(_) => Err(D::Error::new(InvalidArchivedData::InvalidStatusCode(code))),
            }
        }
    }
}
284
/// rkyv adapter that stores an [`http::HeaderMap`] as a list of
/// `(name, raw value bytes)` pairs.
#[cfg(feature = "rkyv_format")]
mod rkyv_header_map {
    use http::HeaderMap;
    use rkyv::{
        Place,
        rancor::{Fallible, Source},
        with::{ArchiveWith, DeserializeWith, SerializeWith},
    };

    use super::rkyv_error::InvalidArchivedData;

    /// Owned wire form of a header map: one `(name, raw value bytes)` pair per
    /// header value, in `HeaderMap::iter` order (a repeated name appears once per value).
    type HeaderVec = Vec<(String, Vec<u8>)>;

    /// Upper bound on the capacity reserved up front when deserializing.
    ///
    /// The archived length comes from stored cache data, and
    /// `HeaderMap::with_capacity` panics when asked for more than the map's
    /// maximum capacity. The length is therefore only used as a bounded hint;
    /// the map still grows as entries are appended.
    const MAX_PREALLOCATED_HEADERS: usize = 1024;

    /// `rkyv` wrapper (`#[rkyv(with = AsHeaderVec)]`) for `HeaderMap` fields.
    pub(super) struct AsHeaderVec;

    /// Flattens a header map into its owned wire form.
    fn to_header_vec(headers: &HeaderMap) -> HeaderVec {
        headers
            .iter()
            .map(|(name, value)| (name.as_str().to_owned(), value.as_bytes().to_vec()))
            .collect()
    }

    impl ArchiveWith<HeaderMap> for AsHeaderVec {
        type Archived = rkyv::Archived<HeaderVec>;
        type Resolver = rkyv::Resolver<HeaderVec>;

        fn resolve_with(field: &HeaderMap, resolver: Self::Resolver, out: Place<Self::Archived>) {
            // The element data was already written by `serialize_with`. Resolving
            // the archived vector only needs its length, so the headers are not
            // rebuilt and re-allocated here. `HeaderMap::len` counts values, which
            // matches the number of pairs `to_header_vec` produced.
            rkyv::vec::ArchivedVec::resolve_from_len(field.len(), resolver, out);
        }
    }

    impl<S> SerializeWith<HeaderMap, S> for AsHeaderVec
    where
        S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized,
        S::Error: Source,
    {
        fn serialize_with(
            field: &HeaderMap,
            serializer: &mut S,
        ) -> Result<Self::Resolver, S::Error> {
            rkyv::Serialize::serialize(&to_header_vec(field), serializer)
        }
    }

    impl<D> DeserializeWith<rkyv::Archived<HeaderVec>, HeaderMap, D> for AsHeaderVec
    where
        D: Fallible + ?Sized,
        D::Error: Source,
    {
        /// Rebuilds the header map.
        ///
        /// Fails with `InvalidHeaderName` or `InvalidHeaderValue` if a stored
        /// pair is no longer a valid HTTP header.
        fn deserialize_with(
            field: &rkyv::Archived<HeaderVec>,
            _deserializer: &mut D,
        ) -> Result<HeaderMap, D::Error> {
            let mut map = HeaderMap::with_capacity(field.len().min(MAX_PREALLOCATED_HEADERS));

            for item in field.iter() {
                let name_str: &str = item.0.as_str();
                let value_slice: &[u8] = item.1.as_slice();

                let header_name = http::header::HeaderName::from_bytes(name_str.as_bytes())
                    .map_err(|_| {
                        D::Error::new(InvalidArchivedData::InvalidHeaderName(name_str.to_string()))
                    })?;
                let header_value =
                    http::header::HeaderValue::from_bytes(value_slice).map_err(|_| {
                        D::Error::new(InvalidArchivedData::InvalidHeaderValue(
                            name_str.to_string(),
                        ))
                    })?;
                // `append` (not `insert`) keeps every value of a repeated header name.
                map.append(header_name, header_value);
            }
            Ok(map)
        }
    }
}
359
360#[derive(Serialize, Deserialize, Debug, Clone)]
382#[cfg_attr(
383 feature = "rkyv_format",
384 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
385)]
386pub struct SerializableHttpResponse {
387 #[serde(with = "http_serde::status_code")]
388 #[cfg_attr(feature = "rkyv_format", rkyv(with = rkyv_status_code::StatusCodeAsU16))]
389 status: http::StatusCode,
390 #[serde(with = "http_serde::version")]
391 #[cfg_attr(feature = "rkyv_format", rkyv(with = rkyv_version::VersionAsU8))]
392 version: http::Version,
393 body: Bytes,
394 #[serde(with = "http_serde::header_map")]
395 #[cfg_attr(feature = "rkyv_format", rkyv(with = rkyv_header_map::AsHeaderVec))]
396 headers: HeaderMap,
397}
398
399impl<ResBody> CacheableResponse for CacheableHttpResponse<ResBody>
400where
401 ResBody: HttpBody + Send + 'static,
402 ResBody::Error: Send,
403 ResBody::Data: Send,
404{
405 type Cached = SerializableHttpResponse;
406 type Subject = Self;
407 type IntoCachedFuture = BoxFuture<'static, CachePolicy<Self::Cached, Self>>;
408 type FromCachedFuture = Ready<Self>;
409
410 async fn cache_policy<P>(
411 self,
412 predicates: P,
413 config: &EntityPolicyConfig,
414 ) -> hitbox::ResponseCachePolicy<Self>
415 where
416 P: hitbox::Predicate<Subject = Self::Subject> + Send + Sync,
417 {
418 match predicates.check(self).await {
419 PredicateResult::Cacheable(cacheable) => match cacheable.into_cached().await {
420 CachePolicy::Cacheable(res) => CachePolicy::Cacheable(CacheValue::new(
421 res,
422 config.ttl.map(|duration| Utc::now() + duration),
423 config.stale_ttl.map(|duration| Utc::now() + duration),
424 )),
425 CachePolicy::NonCacheable(res) => CachePolicy::NonCacheable(res),
426 },
427 PredicateResult::NonCacheable(res) => CachePolicy::NonCacheable(res),
428 }
429 }
430
431 fn into_cached(self) -> Self::IntoCachedFuture {
432 async move {
433 let body_bytes = match self.body.collect().await {
434 Ok(bytes) => bytes,
435 Err(error_body) => {
436 return CachePolicy::NonCacheable(CacheableHttpResponse {
438 parts: self.parts,
439 body: error_body,
440 });
441 }
442 };
443
444 CachePolicy::Cacheable(SerializableHttpResponse {
447 status: self.parts.status,
448 version: self.parts.version,
449 body: body_bytes,
450 headers: self.parts.headers,
451 })
452 }
453 .boxed()
454 }
455
456 fn from_cached(cached: Self::Cached) -> Self::FromCachedFuture {
457 let body = BufferedBody::Complete(Some(cached.body));
458 let mut response = Response::new(body);
459 *response.status_mut() = cached.status;
460 *response.version_mut() = cached.version;
461 *response.headers_mut() = cached.headers;
462
463 std::future::ready(CacheableHttpResponse::from_response(response))
464 }
465}