haproxy_api/
filter.rs

1use std::any::type_name;
2use std::ops::{Deref, DerefMut};
3
4use mlua::{AnyUserData, IntoLua, Lua, ObjectLike, Result, Table, UserData, Value, Variadic};
5
6use crate::{Channel, Core, HttpMessage, LogLevel, Txn};
7
/// Bit flags naming the [`UserFilter`] callbacks that are exposed to HAProxy.
///
/// Combine flags with `|` and assign the result to [`UserFilter::METHODS`].
pub struct FilterMethod;

impl FilterMethod {
    /// Enables [`UserFilter::start_analyze`].
    pub const START_ANALYZE: u8 = 1 << 0;
    /// Enables [`UserFilter::end_analyze`].
    pub const END_ANALYZE: u8 = 1 << 1;
    /// Enables [`UserFilter::http_headers`].
    pub const HTTP_HEADERS: u8 = 1 << 2;
    /// Enables [`UserFilter::http_payload`].
    pub const HTTP_PAYLOAD: u8 = 1 << 3;
    /// Enables [`UserFilter::http_end`].
    pub const HTTP_END: u8 = 1 << 4;

    /// Enables every callback (all bits set).
    pub const ALL: u8 = !0;
}
20
/// A code that filter callback functions may return.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterResult {
    /// A filtering step is finished for the filter.
    Continue,
    /// A filtering step must be paused, waiting for more data or for an external event,
    /// depending on the filter.
    Wait,
    /// Triggers an error.
    Error,
}

impl FilterResult {
    /// Returns the integer code that the Lua filter callback hands back for this result.
    ///
    /// The codes are `1` for [`Continue`](Self::Continue), `0` for [`Wait`](Self::Wait)
    /// and `-1` for [`Error`](Self::Error).
    fn code(&self) -> i8 {
        match self {
            FilterResult::Continue => 1,
            FilterResult::Wait => 0,
            FilterResult::Error => -1,
        }
    }
}
40
/// Mirrors HAProxy's `FLT_CFG_FL_HTX` filter configuration flag.
///
/// Setting it on a filter declares that the filter is able to filter HTTP streams.
const FLT_CFG_FL_HTX: u8 = 1 << 0;
44
45/// A trait that defines all required callback functions to implement filters.
46pub trait UserFilter: Sized {
47    /// Sets methods available for this filter.
48    /// By default ALL
49    const METHODS: u8 = FilterMethod::ALL;
50
51    /// Continue execution if a filter callback returns an error.
52    const CONTINUE_IF_ERROR: bool = true;
53
54    /// Creates a new instance of filter.
55    fn new(lua: &Lua, args: Table) -> Result<Self>;
56
57    /// Called when the analysis starts on the channel `chn`.
58    fn start_analyze(&mut self, lua: &Lua, txn: Txn, chn: Channel) -> Result<FilterResult> {
59        let _ = (lua, txn, chn);
60        Ok(FilterResult::Continue)
61    }
62
63    /// Called when the analysis ends on the channel `chn`.
64    fn end_analyze(&mut self, lua: &Lua, txn: Txn, chn: Channel) -> Result<FilterResult> {
65        let _ = (lua, txn, chn);
66        Ok(FilterResult::Continue)
67    }
68
69    /// Called just before the HTTP payload analysis and after any processing on the HTTP message `msg`.
70    fn http_headers(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> Result<FilterResult> {
71        let _ = (lua, txn, msg);
72        Ok(FilterResult::Continue)
73    }
74
75    /// Called during the HTTP payload analysis on the HTTP message `msg`.
76    fn http_payload(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> Result<Option<usize>> {
77        let _ = (lua, txn, msg);
78        Ok(None)
79    }
80
81    /// Called after the HTTP payload analysis on the HTTP message `msg`.
82    fn http_end(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> Result<FilterResult> {
83        let _ = (lua, txn, msg);
84        Ok(FilterResult::Continue)
85    }
86
87    //
88    // HAProxy provided methods
89    //
90
91    /// Enable the data filtering on the channel `chn` for the current filter.
92    /// It may be called at any time from any callback functions proceeding the data analysis.
93    fn register_data_filter(lua: &Lua, txn: Txn, chn: Channel) -> Result<()> {
94        let global_filter = lua.globals().raw_get::<Table>("filter")?;
95        global_filter.call_function::<()>("register_data_filter", (txn.r#priv, chn))?;
96        Ok(())
97    }
98
99    /// Disable the data filtering on the channel `chn` for the current filter.
100    /// It may be called at any time from any callback functions.
101    fn unregister_data_filter(lua: &Lua, txn: Txn, chn: Channel) -> Result<()> {
102        let filter = lua.globals().raw_get::<Table>("filter")?;
103        filter.call_function::<()>("unregister_data_filter", (txn.r#priv, chn))?;
104        Ok(())
105    }
106
107    /// Set the pause timeout to the specified time, defined in milliseconds.
108    fn wake_time(lua: &Lua, milliseconds: u64) -> Result<()> {
109        let filter = lua.globals().raw_get::<Table>("filter")?;
110        filter.call_function::<()>("wake_time", milliseconds)?;
111        Ok(())
112    }
113}
114
115pub(crate) struct UserFilterWrapper<T>(T);
116
117impl<T> UserFilterWrapper<T>
118where
119    T: UserFilter + 'static,
120{
121    pub(crate) fn make_class(lua: &Lua) -> Result<Table> {
122        let class = lua.create_table()?;
123        class.raw_set("__index", &class)?;
124
125        // Attributes
126        class.raw_set("id", type_name::<T>())?;
127        class.raw_set("flags", FLT_CFG_FL_HTX)?;
128
129        //
130        // Methods
131        //
132        let class_key = lua.create_registry_value(&class)?;
133        class.raw_set(
134            "new",
135            lua.create_function(move |lua, class: Table| {
136                let args = class.raw_get("args")?;
137                let filter = match T::new(lua, args) {
138                    Ok(filter) => filter,
139                    Err(err) => {
140                        let core = Core::new(lua)?;
141                        let msg = format!("Filter '{}': {err}", type_name::<T>());
142                        core.log(LogLevel::Err, msg)?;
143                        return Ok(Value::Nil);
144                    }
145                };
146                let this = lua.create_sequence_from([Self(filter)])?;
147                let class = lua.registry_value::<Table>(&class_key)?;
148                this.set_metatable(Some(class))?;
149                Ok(Value::Table(this))
150            })?,
151        )?;
152
153        if T::METHODS & FilterMethod::START_ANALYZE != 0 {
154            class.raw_set(
155                "start_analyze",
156                lua.create_function(|lua, (t, mut txn, chn): (Table, Txn, Channel)| {
157                    let ud = t.raw_get::<AnyUserData>(1)?;
158                    let mut this = ud.borrow_mut::<Self>()?;
159                    txn.r#priv = Value::Table(t);
160                    Self::process_result(lua, this.start_analyze(lua, txn, chn))
161                })?,
162            )?;
163        }
164
165        if T::METHODS & FilterMethod::END_ANALYZE != 0 {
166            class.raw_set(
167                "end_analyze",
168                lua.create_function(|lua, (t, mut txn, chn): (Table, Txn, Channel)| {
169                    let ud = t.raw_get::<AnyUserData>(1)?;
170                    let mut this = ud.borrow_mut::<Self>()?;
171                    txn.r#priv = Value::Table(t);
172                    Self::process_result(lua, this.end_analyze(lua, txn, chn))
173                })?,
174            )?;
175        }
176
177        if T::METHODS & FilterMethod::HTTP_HEADERS != 0 {
178            class.raw_set(
179                "http_headers",
180                lua.create_function(|lua, (t, mut txn, msg): (Table, Txn, HttpMessage)| {
181                    let ud = t.raw_get::<AnyUserData>(1)?;
182                    let mut this = ud.borrow_mut::<Self>()?;
183                    txn.r#priv = Value::Table(t);
184                    Self::process_result(lua, this.http_headers(lua, txn, msg))
185                })?,
186            )?;
187        }
188
189        if T::METHODS & FilterMethod::HTTP_PAYLOAD != 0 {
190            class.raw_set(
191                "http_payload",
192                lua.create_function(|lua, (t, mut txn, msg): (Table, Txn, HttpMessage)| {
193                    let ud = t.raw_get::<AnyUserData>(1)?;
194                    let mut this = ud.borrow_mut::<Self>()?;
195                    txn.r#priv = Value::Table(t);
196                    let mut res = Variadic::new();
197                    match this.http_payload(lua, txn, msg) {
198                        Ok(Some(len)) => {
199                            res.push(len.into_lua(lua)?);
200                        }
201                        Ok(None) => {}
202                        Err(err) if T::CONTINUE_IF_ERROR => {
203                            if let Ok(core) = Core::new(lua) {
204                                let _ = core.log(
205                                    LogLevel::Err,
206                                    format!("Filter '{}': {}", type_name::<T>(), err),
207                                );
208                            }
209                        }
210                        Err(err) => return Err(err),
211                    };
212                    Ok(res)
213                })?,
214            )?;
215        }
216
217        if T::METHODS & FilterMethod::HTTP_END != 0 {
218            class.raw_set(
219                "http_end",
220                lua.create_function(|lua, (t, mut txn, msg): (Table, Txn, HttpMessage)| {
221                    let ud = t.raw_get::<AnyUserData>(1)?;
222                    let mut this = ud.borrow_mut::<Self>()?;
223                    txn.r#priv = Value::Table(t);
224                    Self::process_result(lua, this.http_end(lua, txn, msg))
225                })?,
226            )?;
227        }
228
229        Ok(class)
230    }
231
232    #[inline]
233    fn process_result(lua: &Lua, res: Result<FilterResult>) -> Result<i8> {
234        match res {
235            Ok(res) => Ok(res.code()),
236            Err(err) if T::CONTINUE_IF_ERROR => {
237                if let Ok(core) = Core::new(lua) {
238                    let _ = core.log(
239                        LogLevel::Err,
240                        format!("Filter '{}': {}", type_name::<T>(), err),
241                    );
242                }
243                Ok(FilterResult::Continue.code())
244            }
245            Err(err) => Err(err),
246        }
247    }
248}
249
250impl<T> UserData for UserFilterWrapper<T> where T: UserFilter + 'static {}
251
252impl<T> Deref for UserFilterWrapper<T> {
253    type Target = T;
254
255    #[inline]
256    fn deref(&self) -> &Self::Target {
257        &self.0
258    }
259}
260
261impl<T> DerefMut for UserFilterWrapper<T> {
262    #[inline]
263    fn deref_mut(&mut self) -> &mut Self::Target {
264        &mut self.0
265    }
266}