1use super::{
16 HTTP_HEADER_CONTENT_HTML, HTTP_HEADER_CONTENT_JSON,
17 HTTP_HEADER_CONTENT_TEXT, HTTP_HEADER_NO_CACHE, HTTP_HEADER_NO_STORE,
18 HTTP_HEADER_TRANSFER_CHUNKED, HttpHeader, LOG_TARGET, get_super_ts,
19 new_internal_error,
20};
21use bytes::{Bytes, BytesMut};
22use http::StatusCode;
23use http::header;
24use http::{HeaderName, HeaderValue};
25use pingora::http::ResponseHeader;
26use pingora::proxy::Session;
27use serde::Serialize;
28use std::pin::Pin;
29use tokio::io::AsyncReadExt;
30use tracing::error;
31
32fn new_cache_control_header(
46 max_age: Option<u32>,
47 cache_private: Option<bool>,
48) -> HttpHeader {
49 let max_age = match max_age {
51 Some(0) | None => return HTTP_HEADER_NO_CACHE.clone(),
52 Some(age) => age,
53 };
54
55 let category: &[u8] = if cache_private.unwrap_or_default() {
57 b"private"
58 } else {
59 b"public"
60 };
61
62 let mut buf = BytesMut::with_capacity(category.len() + 9 + 10); buf.extend_from_slice(category);
67 buf.extend_from_slice(b", max-age=");
68 buf.extend_from_slice(itoa::Buffer::new().format(max_age).as_bytes());
69
70 if let Ok(value) = HeaderValue::from_bytes(&buf) {
72 return (header::CACHE_CONTROL, value);
73 }
74 HTTP_HEADER_NO_CACHE.clone()
76}
77
78#[derive(Default, Debug)]
82pub struct HttpResponseBuilder {
83 response: HttpResponse,
85}
86
87impl HttpResponseBuilder {
88 pub fn new(status: StatusCode) -> Self {
90 Self {
91 response: HttpResponse {
92 status,
93 ..Default::default()
94 },
95 }
96 }
97
98 pub fn body(mut self, body: impl Into<Bytes>) -> Self {
103 self.response.body = body.into();
104 self
105 }
106
107 pub fn header(mut self, header: HttpHeader) -> Self {
111 self.response
112 .headers
113 .get_or_insert_with(Vec::new)
114 .push(header);
115 self
116 }
117
118 pub fn headers(mut self, headers: Vec<HttpHeader>) -> Self {
120 self.response
121 .headers
122 .get_or_insert_with(Vec::new)
123 .extend(headers);
124 self
125 }
126
127 pub fn max_age(mut self, seconds: u32, is_private: bool) -> Self {
129 self.response.max_age = Some(seconds);
130 self.response.cache_private = Some(is_private);
131 self
132 }
133
134 pub fn no_store(self) -> Self {
136 self.header(HTTP_HEADER_NO_STORE.clone())
137 }
138
139 pub fn finish(self) -> HttpResponse {
141 self.response
142 }
143}
144
145#[derive(Default, Clone, Debug)]
148pub struct HttpResponse {
149 pub status: StatusCode,
151 pub body: Bytes,
153 pub max_age: Option<u32>,
155 pub created_at: Option<u32>,
157 pub cache_private: Option<bool>,
159 pub headers: Option<Vec<HttpHeader>>,
161}
162
163impl HttpResponse {
164 pub fn builder(status: StatusCode) -> HttpResponseBuilder {
166 HttpResponseBuilder::new(status)
167 }
168
169 pub fn no_content() -> Self {
171 Self::builder(StatusCode::NO_CONTENT).no_store().finish()
172 }
173
174 pub fn bad_request(body: impl Into<Bytes>) -> Self {
176 Self::builder(StatusCode::BAD_REQUEST)
177 .body(body)
178 .header(HTTP_HEADER_CONTENT_TEXT.clone())
179 .no_store()
180 .finish()
181 }
182
183 pub fn not_found(body: impl Into<Bytes>) -> Self {
185 Self::builder(StatusCode::NOT_FOUND)
186 .body(body)
187 .header(HTTP_HEADER_CONTENT_TEXT.clone())
188 .no_store()
189 .finish()
190 }
191
192 pub fn unknown_error(body: impl Into<Bytes>) -> Self {
194 Self::builder(StatusCode::INTERNAL_SERVER_ERROR)
195 .body(body)
196 .header(HTTP_HEADER_CONTENT_TEXT.clone())
197 .no_store()
198 .finish()
199 }
200
201 pub fn html(body: impl Into<Bytes>) -> Self {
203 Self::builder(StatusCode::OK)
204 .body(body)
205 .header(HTTP_HEADER_CONTENT_HTML.clone())
206 .header(HTTP_HEADER_NO_CACHE.clone())
207 .finish()
208 }
209
210 pub fn redirect(location: &str) -> pingora::Result<Self> {
212 let value = HeaderValue::from_str(location).map_err(|e| {
214 error!(error = e.to_string(), "to header value fail");
215 new_internal_error(500, e)
216 })?;
217 Ok(Self::builder(StatusCode::TEMPORARY_REDIRECT)
219 .header((header::LOCATION, value))
220 .header(HTTP_HEADER_NO_CACHE.clone())
221 .finish())
222 }
223
224 pub fn text(body: impl Into<Bytes>) -> Self {
226 Self::builder(StatusCode::OK)
227 .body(body)
228 .header(HTTP_HEADER_CONTENT_TEXT.clone())
229 .header(HTTP_HEADER_NO_CACHE.clone())
230 .finish()
231 }
232
233 pub fn try_from_json<T>(value: &T) -> pingora::Result<Self>
235 where
236 T: ?Sized + Serialize,
237 {
238 let buf = serde_json::to_vec(value).map_err(|e| {
240 error!(target: LOG_TARGET, error = e.to_string(), "to json fail");
241 new_internal_error(400, e)
242 })?;
243 Ok(Self::builder(StatusCode::OK)
245 .body(buf)
246 .header(HTTP_HEADER_CONTENT_JSON.clone())
247 .finish())
248 }
249
250 pub fn try_from_json_status<T>(
252 value: &T,
253 status: StatusCode,
254 ) -> pingora::Result<Self>
255 where
256 T: ?Sized + Serialize,
257 {
258 let mut resp = Self::try_from_json(value)?;
260 resp.status = status;
262 Ok(resp)
263 }
264
265 pub fn new_response_header(&self) -> pingora::Result<ResponseHeader> {
267 let mut resp = ResponseHeader::build(self.status, None)?;
269
270 let mut add_header =
272 |name: &HeaderName, value: &HeaderValue| -> pingora::Result<()> {
273 resp.insert_header(name, value)?;
274 Ok(())
275 };
276
277 add_header(
279 &header::CONTENT_LENGTH,
280 &HeaderValue::from(self.body.len()),
281 )?;
282
283 let (name, value) =
285 new_cache_control_header(self.max_age, self.cache_private);
286 add_header(&name, &value)?;
287
288 if let Some(created_at) = self.created_at {
290 let secs = get_super_ts().saturating_sub(created_at);
291 add_header(&header::AGE, &HeaderValue::from(secs))?;
292 }
293
294 if let Some(headers) = &self.headers {
296 for (name, value) in headers {
297 add_header(name, value)?;
298 }
299 }
300 Ok(resp)
301 }
302
303 pub async fn send(self, session: &mut Session) -> pingora::Result<usize> {
305 let header = self.new_response_header()?;
307 let size = self.body.len();
308 session
310 .write_response_header(Box::new(header), false)
311 .await?;
312 session.write_response_body(Some(self.body), true).await?;
314 session.finish_body().await?;
316 Ok(size)
317 }
318}
319
320pub struct HttpChunkResponse<'r, R> {
324 pub reader: Pin<&'r mut R>,
326 pub chunk_size: usize,
328 pub max_age: Option<u32>,
330 pub cache_private: Option<bool>,
332 pub headers: Option<Vec<HttpHeader>>,
334}
335
336const DEFAULT_BUF_SIZE: usize = 8 * 1024;
338
339impl<'r, R> HttpChunkResponse<'r, R>
340where
341 R: tokio::io::AsyncRead + std::marker::Unpin,
343{
344 pub fn new(r: &'r mut R) -> Self {
346 Self {
347 reader: Pin::new(r),
348 chunk_size: DEFAULT_BUF_SIZE,
349 max_age: None,
350 headers: None,
351 cache_private: None,
352 }
353 }
354
355 pub fn get_response_header(&self) -> pingora::Result<ResponseHeader> {
359 let mut resp = ResponseHeader::build(StatusCode::OK, Some(4))?;
361 if let Some(headers) = &self.headers {
363 for (name, value) in headers {
364 resp.insert_header(name.to_owned(), value)?;
365 }
366 }
367
368 let chunked = HTTP_HEADER_TRANSFER_CHUNKED.clone();
370 resp.insert_header(chunked.0, chunked.1)?;
371
372 let cache_control =
374 new_cache_control_header(self.max_age, self.cache_private);
375 resp.insert_header(cache_control.0, cache_control.1)?;
376 Ok(resp)
377 }
378
379 pub async fn send(
384 mut self,
385 session: &mut Session,
386 ) -> pingora::Result<usize> {
387 let header = self.get_response_header()?;
389 session
390 .write_response_header(Box::new(header), false)
391 .await?;
392
393 let mut sent = 0;
394 let chunk_size = self.chunk_size.max(512);
396 let mut buffer = vec![0; chunk_size];
398 loop {
399 let size = self.reader.read(&mut buffer).await.map_err(|e| {
401 error!(error = e.to_string(), "read data fail");
402 new_internal_error(400, e)
403 })?;
404 let end = size < chunk_size;
406 session
408 .write_response_body(
409 Some(Bytes::copy_from_slice(&buffer[..size])),
411 end,
412 )
413 .await?;
414 sent += size;
415 if end {
417 break;
418 }
419 }
420 session.finish_body().await?;
422
423 Ok(sent)
424 }
425}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::convert_headers;
    use bytes::Bytes;
    use http::StatusCode;
    use pretty_assertions::assert_eq;
    use serde::Serialize;
    use std::io::Write;
    use tempfile::NamedTempFile;
    use tokio::fs;

    /// Debug-format snapshot of the generated Cache-Control headers.
    #[test]
    fn test_new_cache_control_header() {
        assert_eq!(
            r###"("cache-control", "private, max-age=3600")"###,
            format!("{:?}", new_cache_control_header(Some(3600), Some(true)))
        );
        assert_eq!(
            r###"("cache-control", "public, max-age=3600")"###,
            format!("{:?}", new_cache_control_header(Some(3600), None))
        );
        assert_eq!(
            r###"("cache-control", "private, no-cache")"###,
            format!("{:?}", new_cache_control_header(Some(0), Some(false)))
        );
        assert_eq!(
            r###"("cache-control", "private, no-cache")"###,
            format!("{:?}", new_cache_control_header(None, None))
        );
    }

    /// Snapshots of the preset constructors and of a generated response header.
    #[test]
    fn test_http_response() {
        assert_eq!(
            r###"HttpResponse { status: 204, body: b"", max_age: None, created_at: None, cache_private: None, headers: Some([("cache-control", "private, no-store")]) }"###,
            format!("{:?}", HttpResponse::no_content())
        );
        assert_eq!(
            r###"HttpResponse { status: 404, body: b"Not Found", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "text/plain; charset=utf-8"), ("cache-control", "private, no-store")]) }"###,
            format!("{:?}", HttpResponse::not_found("Not Found"))
        );
        assert_eq!(
            r###"HttpResponse { status: 500, body: b"Unknown Error", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "text/plain; charset=utf-8"), ("cache-control", "private, no-store")]) }"###,
            format!("{:?}", HttpResponse::unknown_error("Unknown Error"))
        );

        assert_eq!(
            r###"HttpResponse { status: 400, body: b"Bad Request", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "text/plain; charset=utf-8"), ("cache-control", "private, no-store")]) }"###,
            format!("{:?}", HttpResponse::bad_request("Bad Request"))
        );

        assert_eq!(
            r###"HttpResponse { status: 200, body: b"<p>Pingap</p>", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "text/html; charset=utf-8"), ("cache-control", "private, no-cache")]) }"###,
            format!("{:?}", HttpResponse::html("<p>Pingap</p>"))
        );

        assert_eq!(
            r###"HttpResponse { status: 307, body: b"", max_age: None, created_at: None, cache_private: None, headers: Some([("location", "http://example.com/"), ("cache-control", "private, no-cache")]) }"###,
            format!(
                "{:?}",
                HttpResponse::redirect("http://example.com/").unwrap()
            )
        );

        assert_eq!(
            r###"HttpResponse { status: 200, body: b"Hello World!", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "text/plain; charset=utf-8"), ("cache-control", "private, no-cache")]) }"###,
            format!("{:?}", HttpResponse::text("Hello World!"))
        );

        #[derive(Serialize)]
        struct Data {
            message: String,
        }
        let resp = HttpResponse::try_from_json_status(
            &Data {
                message: "Hello World!".to_string(),
            },
            StatusCode::BAD_REQUEST,
        )
        .unwrap();
        assert_eq!(
            r###"HttpResponse { status: 400, body: b"{\"message\":\"Hello World!\"}", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "application/json; charset=utf-8")]) }"###,
            format!("{resp:?}")
        );
        let resp = HttpResponse::try_from_json(&Data {
            message: "Hello World!".to_string(),
        })
        .unwrap();
        assert_eq!(
            r###"HttpResponse { status: 200, body: b"{\"message\":\"Hello World!\"}", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "application/json; charset=utf-8")]) }"###,
            format!("{resp:?}")
        );

        let resp = HttpResponse {
            status: StatusCode::OK,
            body: Bytes::from("Hello world!"),
            max_age: Some(3600),
            created_at: Some(0),
            cache_private: Some(true),
            headers: Some(
                convert_headers(&[
                    "Contont-Type: application/json".to_string(),
                    "Content-Encoding: gzip".to_string(),
                ])
                .unwrap(),
            ),
        };

        let mut header = resp.new_response_header().unwrap();
        // The Age value depends on the current time, so only check presence
        // and then drop it before snapshotting the rest.
        assert!(!header.headers.get("Age").unwrap().is_empty());
        header.remove_header("Age").unwrap();

        assert_eq!(
            r###"ResponseHeader { base: Parts { status: 200, version: HTTP/1.1, headers: {"content-length": "12", "cache-control": "private, max-age=3600", "content-encoding": "gzip", "contont-type": "application/json"} }, header_name_map: Some({"content-length": CaseHeaderName(b"Content-Length"), "cache-control": CaseHeaderName(b"Cache-Control"), "content-encoding": CaseHeaderName(b"Content-Encoding"), "contont-type": CaseHeaderName(b"contont-type")}), reason_phrase: None }"###,
            format!("{header:?}")
        );
    }

    /// Snapshot of the chunked response header built from a file reader.
    #[tokio::test]
    async fn test_http_chunk_response() {
        let file = include_bytes!("../../error.html");
        let mut f = NamedTempFile::new().unwrap();
        f.write_all(file).unwrap();
        let mut f = fs::OpenOptions::new().read(true).open(f).await.unwrap();
        let mut resp = HttpChunkResponse::new(&mut f);
        resp.max_age = Some(3600);
        resp.cache_private = Some(false);
        resp.headers = Some(
            convert_headers(&["Contont-Type: text/html".to_string()]).unwrap(),
        );
        let header = resp.get_response_header().unwrap();
        assert_eq!(
            r###"ResponseHeader { base: Parts { status: 200, version: HTTP/1.1, headers: {"contont-type": "text/html", "transfer-encoding": "chunked", "cache-control": "public, max-age=3600"} }, header_name_map: Some({"contont-type": CaseHeaderName(b"contont-type"), "transfer-encoding": CaseHeaderName(b"Transfer-Encoding"), "cache-control": CaseHeaderName(b"Cache-Control")}), reason_phrase: None }"###,
            format!("{header:?}")
        );
    }

    /// Structural checks of the Cache-Control header name and value.
    #[test]
    fn test_new_cache_control_header_logic() {
        let (name, value) = new_cache_control_header(Some(3600), Some(true));
        assert_eq!(name, header::CACHE_CONTROL);
        assert_eq!(value.to_str().unwrap(), "private, max-age=3600");

        let (name, value) = new_cache_control_header(Some(3600), Some(false));
        assert_eq!(name, header::CACHE_CONTROL);
        assert_eq!(value.to_str().unwrap(), "public, max-age=3600");

        let (name, value) = new_cache_control_header(Some(3600), None);
        assert_eq!(name, header::CACHE_CONTROL);
        assert_eq!(value.to_str().unwrap(), "public, max-age=3600");

        let (name, value) = new_cache_control_header(Some(0), Some(true));
        assert_eq!(name, header::CACHE_CONTROL);
        assert_eq!(value, HTTP_HEADER_NO_CACHE.clone().1);

        let (name, value) = new_cache_control_header(None, Some(false));
        assert_eq!(name, header::CACHE_CONTROL);
        assert_eq!(value, HTTP_HEADER_NO_CACHE.clone().1);
    }

    /// The builder accumulates body, headers and cache settings.
    #[test]
    fn test_http_response_builder_pattern() {
        let etag_header = (header::ETAG, HeaderValue::from_static("\"12345\""));
        let server_header =
            (header::SERVER, HeaderValue::from_static("MyTestServer"));

        let response = HttpResponse::builder(StatusCode::OK)
            .body("Test Body")
            .header(etag_header.clone())
            .headers(vec![server_header.clone()])
            .max_age(60, true)
            .finish();

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, Bytes::from("Test Body"));
        assert_eq!(response.max_age, Some(60));
        assert_eq!(response.cache_private, Some(true));

        let headers = response.headers.unwrap();
        assert_eq!(headers.len(), 2);
        assert!(headers.contains(&etag_header));
        assert!(headers.contains(&server_header));

        let no_store_response = HttpResponse::builder(StatusCode::ACCEPTED)
            .no_store()
            .finish();

        assert_eq!(no_store_response.status, StatusCode::ACCEPTED);
        assert!(
            no_store_response
                .headers
                .unwrap()
                .contains(&HTTP_HEADER_NO_STORE.clone())
        );
    }

    /// A location containing a NUL byte is not a valid header value.
    #[test]
    fn test_http_response_error_cases() {
        let invalid_location = "http://example.com/\0";
        let result = HttpResponse::redirect(invalid_location);
        assert!(result.is_err());
    }

    /// Generated headers include Content-Length, Cache-Control and extras.
    #[test]
    fn test_new_response_header_generation() {
        let resp = HttpResponse {
            status: StatusCode::OK,
            body: Bytes::from("Hello world!"),
            max_age: Some(3600),
            created_at: Some(get_super_ts().saturating_sub(10)),
            cache_private: Some(true),
            headers: Some(vec![(
                header::CONTENT_ENCODING,
                HeaderValue::from_static("gzip"),
            )]),
        };

        let header = resp.new_response_header().unwrap();
        let headers_map: std::collections::HashMap<_, _> =
            header.headers.iter().collect();

        assert_eq!(header.status, StatusCode::OK);
        assert_eq!(
            headers_map
                .get(&header::CONTENT_LENGTH)
                .unwrap()
                .to_str()
                .unwrap(),
            "12"
        );
        assert_eq!(
            headers_map
                .get(&header::CACHE_CONTROL)
                .unwrap()
                .to_str()
                .unwrap(),
            "private, max-age=3600"
        );
        assert_eq!(
            headers_map
                .get(&header::CONTENT_ENCODING)
                .unwrap()
                .to_str()
                .unwrap(),
            "gzip"
        );
    }
}