1use std::io::Write;
2
3use brotlic::{BrotliEncoderOptions, CompressorWriter, Quality, WindowSize};
4use haproxy_api::{Core, FilterMethod, FilterResult, Headers, HttpMessage, Txn, UserFilter};
5use mlua::prelude::{Lua, LuaResult, LuaTable, LuaUserData, LuaValue};
6
/// State of one `brotli` filter instance, created by `UserFilter::new` from the
/// filter declaration arguments.
#[derive(Default)]
pub struct BrotliFilter {
    // Decided while processing request headers: true only for GET/POST requests
    // whose Accept-Encoding prefers `br`.
    enabled: bool,
    // Brotli encoder writing into an in-memory buffer; created when a response is
    // selected for compression and taken (finished) at end of message.
    writer: Option<CompressorWriter<Vec<u8>>>,
    // Options parsed from the filter declaration arguments.
    options: BrotliFilterOptions,
}
14
/// Configuration of the `brotli` filter, built from its declaration arguments.
///
/// Derives `FromLua` (and is Lua userdata) so a parsed value can be cached in
/// the argument table and read back on later instantiations.
#[derive(Debug, Clone, mlua::FromLua)]
struct BrotliFilterOptions {
    // Encoder quality; user input is clamped to 0..=11 by `parse_args`.
    quality: u8,
    // Encoder window size in bits; user input is clamped to 10..=24 by `parse_args`.
    window: u8,
    // When set, Accept-Encoding is deleted from requests selected for compression.
    offload: bool,
    // Content-Type prefixes (lower-cased by `parse_args`) eligible for
    // compression; an empty list means every content type is eligible.
    content_types: Vec<String>,
}
23
// Marker impl with no methods: lets `BrotliFilterOptions` be stored in a Lua
// table as userdata, which `BrotliFilter::parse_args` uses to cache parsed
// options at index 0 of the argument table.
impl LuaUserData for BrotliFilterOptions {}
25
26impl Default for BrotliFilterOptions {
27 fn default() -> Self {
28 BrotliFilterOptions {
29 quality: 5,
30 window: WindowSize::default().bits(),
31 offload: false,
32 content_types: Vec::new(),
33 }
34 }
35}
36
37impl BrotliFilter {
38 fn process_request_headers(&mut self, txn: Txn, msg: HttpMessage) -> LuaResult<()> {
39 self.enabled = matches!(&*txn.f.get_str("method", ())?, "GET" | "POST")
42 && Self::prefer_brotli_encoding(msg.get_headers()?)?;
43
44 if self.enabled && self.options.offload {
45 msg.del_header("accept-encoding")?;
46 }
47
48 Ok(())
49 }
50
51 fn process_response_headers(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> LuaResult<()> {
52 if !self.enabled || txn.f.get::<u16>("status", ())? != 200 {
54 return Ok(());
55 }
56
57 let headers = msg.get_headers()?;
58 let mut skip_encoding = headers.get_first::<LuaValue>("content-encoding")?.is_some();
60 skip_encoding |= headers
62 .get::<String>("cache-control")?
63 .iter()
64 .any(|v| v.contains("no-transform"));
65 if !skip_encoding {
67 let content_type = headers
68 .get_first::<String>("content-type")?
69 .unwrap_or_default()
70 .to_ascii_lowercase();
71 skip_encoding = content_type.is_empty() || content_type.starts_with("multipart");
72 if !skip_encoding {
73 let mut found = self.options.content_types.is_empty();
74 for prefix in &self.options.content_types {
75 if content_type.starts_with(prefix) {
76 found = true;
77 break;
78 }
79 }
80 skip_encoding = !found;
81 }
82 }
83 if skip_encoding {
84 return Ok(());
85 }
86
87 match headers.get::<String>("etag")? {
89 etag if etag.len() > 1 => return Ok(()),
90 etag if etag.len() == 1 && etag[0].starts_with('"') => {
91 msg.set_header("etag", format!("W/{}", etag[0]))?;
92 }
93 _ => {}
94 }
95
96 let size_hint = headers
97 .get_first::<u32>("content-length")
98 .unwrap_or(None)
99 .unwrap_or(0);
100
101 let buf = Vec::with_capacity(4096);
103 let encoder = BrotliEncoderOptions::new()
104 .quality(Quality::new(self.options.quality).unwrap_or(Quality::worst()))
105 .window_size(WindowSize::new(self.options.window).unwrap_or(WindowSize::default()))
106 .size_hint(size_hint)
107 .build()
108 .expect("Failed to build brotli encoder");
109 self.writer = Some(CompressorWriter::with_encoder(encoder, buf));
110
111 msg.set_header("content-encoding", "br")?;
113 msg.add_header("vary", "accept-encoding")?;
114 msg.set_body_len(None)?;
116
117 Self::register_data_filter(lua, txn, msg.channel()?)
118 }
119
120 fn prefer_brotli_encoding(headers: Headers) -> LuaResult<bool> {
121 let accept_encoding = headers.get::<String>("accept-encoding")?;
122 let vals = accept_encoding
123 .iter()
124 .flat_map(|v| v.split(',').map(str::trim))
125 .filter_map(|v| {
126 let (enc, qval) = match v.split_once(";q=") {
127 Some((e, q)) => (e, q),
128 None => return Some((v, 1.0f32)),
129 };
130 let qval = match qval.parse::<f32>() {
131 Ok(f) if f <= 1.0 => f, _ => return None,
133 };
134 Some((enc, qval))
135 });
136
137 let (mut preferred_encoding, mut max_qval) = ("", 0.);
138 for (enc, qval) in vals {
139 if qval > max_qval {
140 (preferred_encoding, max_qval) = (enc, qval);
141 } else if qval == max_qval && enc == "br" {
142 preferred_encoding = "br";
143 }
144 }
145 Ok(preferred_encoding == "br")
146 }
147
148 fn parse_args(args: LuaTable) -> LuaResult<BrotliFilterOptions> {
149 if let Ok(options) = args.raw_get::<BrotliFilterOptions>(0) {
151 return Ok(options);
152 }
153
154 let mut options = BrotliFilterOptions::default();
155 for arg in args.clone().sequence_values::<String>() {
156 match &*arg? {
157 "offload" => options.offload = true,
158 arg if arg.starts_with("type:") => {
159 options.content_types = arg[5..]
160 .split(',')
161 .map(|s| s.trim().to_ascii_lowercase())
162 .collect();
163 }
164 arg if arg.starts_with("quality:") => {
165 if let Ok(quality) = arg[8..].trim().parse::<u8>() {
166 options.quality = quality.clamp(0, 11);
167 }
168 }
169 arg if arg.starts_with("window:") => {
170 if let Ok(window) = arg[7..].trim().parse::<u8>() {
171 options.window = window.clamp(10, 24);
172 }
173 }
174 _ => {}
175 }
176 }
177 args.raw_set(0, options.clone())?;
178 Ok(options)
179 }
180}
181
182impl UserFilter for BrotliFilter {
183 const METHODS: u8 = FilterMethod::HTTP_HEADERS | FilterMethod::HTTP_PAYLOAD;
184
185 fn new(_: &Lua, args: LuaTable) -> LuaResult<Self> {
186 Ok(BrotliFilter {
187 options: Self::parse_args(args)?,
188 ..Default::default()
189 })
190 }
191
192 fn http_headers(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> LuaResult<FilterResult> {
193 if !msg.is_resp()? {
194 self.process_request_headers(txn, msg)?;
195 } else {
196 self.process_response_headers(lua, txn, msg)?;
197 }
198 Ok(FilterResult::Continue)
199 }
200
201 fn http_payload(&mut self, _: &Lua, _: Txn, msg: HttpMessage) -> LuaResult<Option<usize>> {
202 if let Some(chunk) = msg.body(None, Some(-1))? {
203 let chunk = chunk.as_bytes();
204 let writer = self.writer.as_mut().expect("Brotli writer must exists");
205 if !chunk.is_empty() {
206 writer
207 .write_all(&chunk)
208 .expect("Failed to write to brotli encoder");
209 writer.flush().expect("Failed to flush brotli encoder");
210 }
211 if !msg.eom()? {
212 if !writer.get_ref().is_empty() {
213 msg.set(writer.get_ref(), None, None)?;
214 writer.get_mut().clear();
215 } else if !chunk.is_empty() {
216 msg.remove(None, None)?;
217 }
218 } else {
219 let data = self.writer.take().unwrap().into_inner().unwrap();
220 msg.set(data, None, None)?;
221 }
222 }
223 Ok(None)
224 }
225}
226
227pub fn register(lua: &Lua, _options: Option<LuaTable>) -> LuaResult<()> {
229 let core = Core::new(lua)?;
230 core.register_filter::<BrotliFilter>("brotli")?;
231 Ok(())
232}