haproxy_brotli/
lib.rs

1use std::io::Write;
2
3use brotlic::{BrotliEncoderOptions, CompressorWriter, Quality, WindowSize};
4use haproxy_api::{Core, FilterMethod, FilterResult, Headers, HttpMessage, Txn, UserFilter};
5use mlua::prelude::{Lua, LuaResult, LuaTable, LuaUserData, LuaValue};
6
7/// A Lua filter that applies Brotli compression to HTTP responses.
8#[derive(Default)]
9pub struct BrotliFilter {
10    enabled: bool,
11    writer: Option<CompressorWriter<Vec<u8>>>,
12    options: BrotliFilterOptions,
13}
14
15/// Options for the Brotli filter.
16#[derive(Debug, Clone, mlua::FromLua)]
17struct BrotliFilterOptions {
18    quality: u8,
19    window: u8,
20    offload: bool,
21    content_types: Vec<String>,
22}
23
24impl LuaUserData for BrotliFilterOptions {}
25
26impl Default for BrotliFilterOptions {
27    fn default() -> Self {
28        BrotliFilterOptions {
29            quality: 5,
30            window: WindowSize::default().bits(),
31            offload: false,
32            content_types: Vec::new(),
33        }
34    }
35}
36
37impl BrotliFilter {
38    fn process_request_headers(&mut self, txn: Txn, msg: HttpMessage) -> LuaResult<()> {
39        // Check if we can prefer brotli over other encodings
40        // We support only GET and POST methods
41        self.enabled = matches!(&*txn.f.get_str("method", ())?, "GET" | "POST")
42            && Self::prefer_brotli_encoding(msg.get_headers()?)?;
43
44        if self.enabled && self.options.offload {
45            msg.del_header("accept-encoding")?;
46        }
47
48        Ok(())
49    }
50
51    fn process_response_headers(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> LuaResult<()> {
52        // We encode only "200" responses
53        if !self.enabled || txn.f.get::<u16>("status", ())? != 200 {
54            return Ok(());
55        }
56
57        let headers = msg.get_headers()?;
58        // Do not encode when `content-encoding` already present
59        let mut skip_encoding = headers.get_first::<LuaValue>("content-encoding")?.is_some();
60        // Do not encode when `cache-control` includes `no-transform`
61        skip_encoding |= headers
62            .get::<String>("cache-control")?
63            .iter()
64            .any(|v| v.contains("no-transform"));
65        // Check content type
66        if !skip_encoding {
67            let content_type = headers
68                .get_first::<String>("content-type")?
69                .unwrap_or_default()
70                .to_ascii_lowercase();
71            skip_encoding = content_type.is_empty() || content_type.starts_with("multipart");
72            if !skip_encoding {
73                let mut found = self.options.content_types.is_empty();
74                for prefix in &self.options.content_types {
75                    if content_type.starts_with(prefix) {
76                        found = true;
77                        break;
78                    }
79                }
80                skip_encoding = !found;
81            }
82        }
83        if skip_encoding {
84            return Ok(());
85        }
86
87        // Update ETag
88        match headers.get::<String>("etag")? {
89            etag if etag.len() > 1 => return Ok(()),
90            etag if etag.len() == 1 && etag[0].starts_with('"') => {
91                msg.set_header("etag", format!("W/{}", etag[0]))?;
92            }
93            _ => {}
94        }
95
96        let size_hint = headers
97            .get_first::<u32>("content-length")
98            .unwrap_or(None)
99            .unwrap_or(0);
100
101        // Initialize brotli encoder
102        let buf = Vec::with_capacity(4096);
103        let encoder = BrotliEncoderOptions::new()
104            .quality(Quality::new(self.options.quality).unwrap_or(Quality::worst()))
105            .window_size(WindowSize::new(self.options.window).unwrap_or(WindowSize::default()))
106            .size_hint(size_hint)
107            .build()
108            .expect("Failed to build brotli encoder");
109        self.writer = Some(CompressorWriter::with_encoder(encoder, buf));
110
111        // Update response headers
112        msg.set_header("content-encoding", "br")?;
113        msg.add_header("vary", "accept-encoding")?;
114        // switch to chunked transfer encoding
115        msg.set_body_len(None)?;
116
117        Self::register_data_filter(lua, txn, msg.channel()?)
118    }
119
120    fn prefer_brotli_encoding(headers: Headers) -> LuaResult<bool> {
121        let accept_encoding = headers.get::<String>("accept-encoding")?;
122        let vals = accept_encoding
123            .iter()
124            .flat_map(|v| v.split(',').map(str::trim))
125            .filter_map(|v| {
126                let (enc, qval) = match v.split_once(";q=") {
127                    Some((e, q)) => (e, q),
128                    None => return Some((v, 1.0f32)),
129                };
130                let qval = match qval.parse::<f32>() {
131                    Ok(f) if f <= 1.0 => f, // q-values over 1 are unacceptable,
132                    _ => return None,
133                };
134                Some((enc, qval))
135            });
136
137        let (mut preferred_encoding, mut max_qval) = ("", 0.);
138        for (enc, qval) in vals {
139            if qval > max_qval {
140                (preferred_encoding, max_qval) = (enc, qval);
141            } else if qval == max_qval && enc == "br" {
142                preferred_encoding = "br";
143            }
144        }
145        Ok(preferred_encoding == "br")
146    }
147
148    fn parse_args(args: LuaTable) -> LuaResult<BrotliFilterOptions> {
149        // Fetch ready parsed options
150        if let Ok(options) = args.raw_get::<BrotliFilterOptions>(0) {
151            return Ok(options);
152        }
153
154        let mut options = BrotliFilterOptions::default();
155        for arg in args.clone().sequence_values::<String>() {
156            match &*arg? {
157                "offload" => options.offload = true,
158                arg if arg.starts_with("type:") => {
159                    options.content_types = arg[5..]
160                        .split(',')
161                        .map(|s| s.trim().to_ascii_lowercase())
162                        .collect();
163                }
164                arg if arg.starts_with("quality:") => {
165                    if let Ok(quality) = arg[8..].trim().parse::<u8>() {
166                        options.quality = quality.clamp(0, 11);
167                    }
168                }
169                arg if arg.starts_with("window:") => {
170                    if let Ok(window) = arg[7..].trim().parse::<u8>() {
171                        options.window = window.clamp(10, 24);
172                    }
173                }
174                _ => {}
175            }
176        }
177        args.raw_set(0, options.clone())?;
178        Ok(options)
179    }
180}
181
182impl UserFilter for BrotliFilter {
183    const METHODS: u8 = FilterMethod::HTTP_HEADERS | FilterMethod::HTTP_PAYLOAD;
184
185    fn new(_: &Lua, args: LuaTable) -> LuaResult<Self> {
186        Ok(BrotliFilter {
187            options: Self::parse_args(args)?,
188            ..Default::default()
189        })
190    }
191
192    fn http_headers(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> LuaResult<FilterResult> {
193        if !msg.is_resp()? {
194            self.process_request_headers(txn, msg)?;
195        } else {
196            self.process_response_headers(lua, txn, msg)?;
197        }
198        Ok(FilterResult::Continue)
199    }
200
201    fn http_payload(&mut self, _: &Lua, _: Txn, msg: HttpMessage) -> LuaResult<Option<usize>> {
202        if let Some(chunk) = msg.body(None, Some(-1))? {
203            let chunk = chunk.as_bytes();
204            let writer = self.writer.as_mut().expect("Brotli writer must exists");
205            if !chunk.is_empty() {
206                writer
207                    .write_all(&chunk)
208                    .expect("Failed to write to brotli encoder");
209                writer.flush().expect("Failed to flush brotli encoder");
210            }
211            if !msg.eom()? {
212                if !writer.get_ref().is_empty() {
213                    msg.set(writer.get_ref(), None, None)?;
214                    writer.get_mut().clear();
215                } else if !chunk.is_empty() {
216                    msg.remove(None, None)?;
217                }
218            } else {
219                let data = self.writer.take().unwrap().into_inner().unwrap();
220                msg.set(data, None, None)?;
221            }
222        }
223        Ok(None)
224    }
225}
226
227/// Registers a "brotli" filter in the given haproxy context.
228pub fn register(lua: &Lua, _options: Option<LuaTable>) -> LuaResult<()> {
229    let core = Core::new(lua)?;
230    core.register_filter::<BrotliFilter>("brotli")?;
231    Ok(())
232}