1#![cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
2use anyhow::Result;
53use async_trait::async_trait;
54use bytes::Bytes;
55use flate2::{
56 Compression as GzLevel,
57 write::{DeflateEncoder, GzEncoder},
58};
59use http::{
60 HeaderValue, StatusCode,
61 header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, VARY},
62};
63use http_body_util::BodyExt;
64use std::io::{Read, Write};
65
66pub mod brotli_stream;
67pub mod deflate_stream;
68pub mod gzip_stream;
69pub mod zstd_stream;
70
71#[cfg(feature = "zstd")]
72use zstd::stream::encode_all as zstd_encode;
73
74#[cfg(feature = "zstd")]
75use crate::plugins::compression::zstd_stream::stream_zstd;
76use crate::{
77 body::TakoBody,
78 middleware::Next,
79 plugins::{
80 TakoPlugin,
81 compression::{
82 brotli_stream::stream_brotli, deflate_stream::stream_deflate, gzip_stream::stream_gzip,
83 },
84 },
85 responder::Responder,
86 router::Router,
87 types::{Request, Response},
88};
89
/// A content coding the compression plugin can apply to response bodies.
///
/// Each variant maps to the token sent in `Content-Encoding` and matched
/// against the request's `Accept-Encoding` header.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Encoding {
    /// The `gzip` content coding.
    Gzip,
    /// The `br` (Brotli) content coding.
    Brotli,
    /// The `deflate` content coding.
    Deflate,
    /// The `zstd` (Zstandard) content coding; requires the `zstd` feature.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}
104
105impl Encoding {
106 fn as_str(&self) -> &'static str {
108 match self {
109 Encoding::Gzip => "gzip",
110 Encoding::Brotli => "br",
111 Encoding::Deflate => "deflate",
112 #[cfg(feature = "zstd")]
113 Encoding::Zstd => "zstd",
114 }
115 }
116}
117
118#[derive(Clone)]
120pub struct Config {
121 pub enabled: Vec<Encoding>,
123 pub min_size: usize,
125 pub gzip_level: u32,
127 pub brotli_level: u32,
129 pub deflate_level: u32,
131 #[cfg(feature = "zstd")]
133 pub zstd_level: i32,
134 pub stream: bool,
136}
137
138impl Default for Config {
139 fn default() -> Self {
141 Self {
142 enabled: vec![Encoding::Gzip, Encoding::Brotli, Encoding::Deflate],
143 min_size: 1024,
144 gzip_level: 5,
145 brotli_level: 5,
146 deflate_level: 5,
147 #[cfg(feature = "zstd")]
148 zstd_level: 3,
149 stream: false,
150 }
151 }
152}
153
154pub struct CompressionBuilder(Config);
181
182impl CompressionBuilder {
183 pub fn new() -> Self {
185 Self(Config::default())
186 }
187
188 pub fn enable_gzip(mut self, yes: bool) -> Self {
190 if yes && !self.0.enabled.contains(&Encoding::Gzip) {
191 self.0.enabled.push(Encoding::Gzip)
192 }
193 if !yes {
194 self.0.enabled.retain(|e| *e != Encoding::Gzip)
195 }
196 self
197 }
198
199 pub fn enable_brotli(mut self, yes: bool) -> Self {
201 if yes && !self.0.enabled.contains(&Encoding::Brotli) {
202 self.0.enabled.push(Encoding::Brotli)
203 }
204 if !yes {
205 self.0.enabled.retain(|e| *e != Encoding::Brotli)
206 }
207 self
208 }
209
210 pub fn enable_deflate(mut self, yes: bool) -> Self {
212 if yes && !self.0.enabled.contains(&Encoding::Deflate) {
213 self.0.enabled.push(Encoding::Deflate)
214 }
215 if !yes {
216 self.0.enabled.retain(|e| *e != Encoding::Deflate)
217 }
218 self
219 }
220
221 #[cfg(feature = "zstd")]
223 #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
224 pub fn enable_zstd(mut self, yes: bool) -> Self {
225 if yes && !self.0.enabled.contains(&Encoding::Zstd) {
226 self.0.enabled.push(Encoding::Zstd)
227 }
228 if !yes {
229 self.0.enabled.retain(|e| *e != Encoding::Zstd)
230 }
231 self
232 }
233
234 pub fn enable_stream(mut self, stream: bool) -> Self {
236 self.0.stream = stream;
237 self
238 }
239
240 pub fn min_size(mut self, bytes: usize) -> Self {
242 self.0.min_size = bytes;
243 self
244 }
245
246 pub fn gzip_level(mut self, lvl: u32) -> Self {
248 self.0.gzip_level = lvl.min(9);
249 self
250 }
251
252 pub fn brotli_level(mut self, lvl: u32) -> Self {
254 self.0.brotli_level = lvl.min(11);
255 self
256 }
257
258 pub fn deflate_level(mut self, lvl: u32) -> Self {
260 self.0.deflate_level = lvl.min(9);
261 self
262 }
263
264 #[cfg(feature = "zstd")]
266 #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
267 pub fn zstd_level(mut self, lvl: i32) -> Self {
268 self.0.zstd_level = lvl.clamp(1, 22);
269 self
270 }
271
272 pub fn build(self) -> CompressionPlugin {
274 CompressionPlugin { cfg: self.0 }
275 }
276}
277
278pub enum CompressionResponse<R>
279where
280 R: Responder,
281{
282 Plain(R),
284 Stream(R),
286}
287
288impl<R> Responder for CompressionResponse<R>
289where
290 R: Responder,
291{
292 fn into_response(self) -> Response {
293 match self {
294 CompressionResponse::Plain(r) => r.into_response(),
295 CompressionResponse::Stream(r) => r.into_response(),
296 }
297 }
298}
299
300#[derive(Clone)]
328#[doc(alias = "compression")]
329#[doc(alias = "gzip")]
330#[doc(alias = "brotli")]
331#[doc(alias = "deflate")]
332pub struct CompressionPlugin {
333 cfg: Config,
334}
335
336impl Default for CompressionPlugin {
337 fn default() -> Self {
339 Self {
340 cfg: Config::default(),
341 }
342 }
343}
344
345#[async_trait]
346impl TakoPlugin for CompressionPlugin {
347 fn name(&self) -> &'static str {
349 "CompressionPlugin"
350 }
351
352 fn setup(&self, router: &Router) -> Result<()> {
354 let cfg = self.cfg.clone();
355 router.middleware(move |req, next| {
356 let cfg = cfg.clone();
357 let stream = cfg.stream.clone();
358 async move {
359 if stream == false {
360 return CompressionResponse::Plain(
361 compress_middleware(req, next, cfg).await.into_response(),
362 );
363 } else {
364 return CompressionResponse::Stream(
365 compress_stream_middleware(req, next, cfg)
366 .await
367 .into_response(),
368 );
369 }
370 }
371 });
372 Ok(())
373 }
374}
375
376async fn compress_middleware(req: Request, next: Next, cfg: Config) -> impl Responder {
382 let accepted = req
384 .headers()
385 .get(ACCEPT_ENCODING)
386 .and_then(|v| v.to_str().ok())
387 .unwrap_or("")
388 .to_ascii_lowercase();
389
390 let mut resp = next.run(req).await;
392 let chosen = choose_encoding(&accepted, &cfg.enabled);
393
394 let status = resp.status();
396 if !(status.is_success() || status == StatusCode::NOT_MODIFIED) {
397 return resp.into_response();
398 }
399
400 if resp.headers().contains_key(CONTENT_ENCODING) {
401 return resp.into_response();
402 }
403
404 if let Some(ct) = resp.headers().get(CONTENT_TYPE) {
406 let ct = ct.to_str().unwrap_or("");
407 if !(ct.starts_with("text/")
408 || ct.contains("json")
409 || ct.contains("javascript")
410 || ct.contains("xml"))
411 {
412 return resp.into_response();
413 }
414 }
415
416 let body_bytes = resp.body_mut().collect().await.unwrap().to_bytes();
418 if body_bytes.len() < cfg.min_size {
419 *resp.body_mut() = TakoBody::from(Bytes::from(body_bytes));
420 return resp.into_response();
421 }
422
423 if let Some(enc) = chosen {
425 let compressed = match enc {
426 Encoding::Gzip => {
427 compress_gzip(&body_bytes, cfg.gzip_level).unwrap_or_else(|_| body_bytes.to_vec())
428 }
429 Encoding::Brotli => {
430 compress_brotli(&body_bytes, cfg.brotli_level).unwrap_or_else(|_| body_bytes.to_vec())
431 }
432 Encoding::Deflate => {
433 compress_deflate(&body_bytes, cfg.deflate_level).unwrap_or_else(|_| body_bytes.to_vec())
434 }
435 #[cfg(feature = "zstd")]
436 Encoding::Zstd => {
437 compress_zstd(&body_bytes, cfg.zstd_level).unwrap_or_else(|_| body_bytes.to_vec())
438 }
439 };
440 *resp.body_mut() = TakoBody::from(Bytes::from(compressed));
441 resp
442 .headers_mut()
443 .insert(CONTENT_ENCODING, HeaderValue::from_static(enc.as_str()));
444 resp.headers_mut().remove(CONTENT_LENGTH);
445 resp
446 .headers_mut()
447 .insert(VARY, HeaderValue::from_static("Accept-Encoding"));
448 } else {
449 *resp.body_mut() = TakoBody::from(Bytes::from(body_bytes));
450 }
451
452 resp.into_response()
453}
454
455pub async fn compress_stream_middleware(req: Request, next: Next, cfg: Config) -> impl Responder {
461 let accepted = req
463 .headers()
464 .get(ACCEPT_ENCODING)
465 .and_then(|v| v.to_str().ok())
466 .unwrap_or("")
467 .to_ascii_lowercase();
468
469 let mut resp = next.run(req).await;
471 let chosen = choose_encoding(&accepted, &cfg.enabled);
472
473 let status = resp.status();
475 if !(status.is_success() || status == StatusCode::NOT_MODIFIED) {
476 return resp.into_response();
477 }
478
479 if resp.headers().contains_key(CONTENT_ENCODING) {
480 return resp.into_response();
481 }
482
483 if let Some(ct) = resp.headers().get(CONTENT_TYPE) {
485 let ct = ct.to_str().unwrap_or("");
486 if !(ct.starts_with("text/")
487 || ct.contains("json")
488 || ct.contains("javascript")
489 || ct.contains("xml"))
490 {
491 return resp.into_response();
492 }
493 }
494
495 if let Some(len) = resp
497 .headers()
498 .get(CONTENT_LENGTH)
499 .and_then(|v| v.to_str().ok())
500 .and_then(|v| v.parse::<usize>().ok())
501 {
502 if len < cfg.min_size {
503 return resp.into_response();
504 }
505 }
506
507 if let Some(enc) = chosen {
508 let body = std::mem::replace(resp.body_mut(), TakoBody::empty());
509 let new_body = match enc {
510 Encoding::Gzip => stream_gzip(body, cfg.gzip_level),
511 Encoding::Brotli => stream_brotli(body, cfg.brotli_level),
512 Encoding::Deflate => stream_deflate(body, cfg.deflate_level),
513 #[cfg(feature = "zstd")]
514 Encoding::Zstd => stream_zstd(body, cfg.zstd_level),
515 };
516 *resp.body_mut() = new_body;
517 resp
518 .headers_mut()
519 .insert(CONTENT_ENCODING, HeaderValue::from_static(enc.as_str()));
520 resp.headers_mut().remove(CONTENT_LENGTH);
521 resp
522 .headers_mut()
523 .insert(VARY, HeaderValue::from_static("Accept-Encoding"));
524 }
525
526 resp.into_response()
527}
528
529fn choose_encoding(header: &str, enabled: &[Encoding]) -> Option<Encoding> {
535 let header = header.to_ascii_lowercase();
536 let test = |e: Encoding| header.contains(e.as_str()) && enabled.contains(&e);
537 if test(Encoding::Brotli) {
538 Some(Encoding::Brotli)
539 } else if test(Encoding::Gzip) {
540 Some(Encoding::Gzip)
541 } else if test(Encoding::Deflate) {
542 Some(Encoding::Deflate)
543 } else {
544 #[cfg(feature = "zstd")]
545 {
546 if test(Encoding::Zstd) {
547 return Some(Encoding::Zstd);
548 }
549 }
550 None
551 }
552}
553
554fn compress_gzip(data: &[u8], lvl: u32) -> std::io::Result<Vec<u8>> {
556 let mut enc = GzEncoder::new(Vec::new(), GzLevel::new(lvl));
557 enc.write_all(data)?;
558 enc.finish()
559}
560
561fn compress_brotli(data: &[u8], lvl: u32) -> std::io::Result<Vec<u8>> {
563 let mut out = Vec::new();
564 brotli::CompressorReader::new(data, 4096, lvl, 22)
565 .read_to_end(&mut out)
566 .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Failed to compress data"))?;
567 Ok(out)
568}
569
570fn compress_deflate(data: &[u8], lvl: u32) -> std::io::Result<Vec<u8>> {
572 let mut enc = DeflateEncoder::new(Vec::new(), flate2::Compression::new(lvl));
573 enc.write_all(data)?;
574 enc.finish()
575}
576
/// Zstandard-compresses `data` in one shot at level `lvl`.
///
/// The builder clamps the level to 1–22. Encoder I/O errors are returned to
/// the caller.
#[cfg(feature = "zstd")]
#[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
fn compress_zstd(data: &[u8], lvl: i32) -> std::io::Result<Vec<u8>> {
    let compressed = zstd_encode(data, lvl)?;
    Ok(compressed)
}