1use std::any::type_name;
2use std::ops::{Deref, DerefMut};
3
4use mlua::{AnyUserData, IntoLua, Lua, ObjectLike, Result, Table, UserData, Value, Variadic};
5
6use crate::{Channel, Core, HttpMessage, LogLevel, Txn};
7
/// Namespace for the bit flags that select which [`UserFilter`] callbacks get
/// installed on the Lua filter class.
///
/// Flags are combined with `|` and assigned to [`UserFilter::METHODS`].
pub struct FilterMethod;

impl FilterMethod {
    /// Selects the `start_analyze` callback.
    pub const START_ANALYZE: u8 = 1 << 0;
    /// Selects the `end_analyze` callback.
    pub const END_ANALYZE: u8 = 1 << 1;
    /// Selects the `http_headers` callback.
    pub const HTTP_HEADERS: u8 = 1 << 2;
    /// Selects the `http_payload` callback.
    pub const HTTP_PAYLOAD: u8 = 1 << 3;
    /// Selects the `http_end` callback.
    pub const HTTP_END: u8 = 1 << 4;

    /// Selects every callback (all bits set).
    pub const ALL: u8 = u8::MAX;
}
20
/// Outcome reported by a filter callback.
///
/// Each variant maps to a fixed integer (see `code`) that is returned to the
/// Lua caller of the callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterResult {
    /// Reported as `1`.
    Continue,
    /// Reported as `0`.
    Wait,
    /// Reported as `-1`.
    Error,
}

impl FilterResult {
    /// Returns the integer code associated with this result:
    /// `Continue` => `1`, `Wait` => `0`, `Error` => `-1`.
    fn code(&self) -> i8 {
        match *self {
            FilterResult::Continue => 1,
            FilterResult::Wait => 0,
            FilterResult::Error => -1,
        }
    }
}
40
/// Value stored in the `flags` field of the filter class table built by
/// `UserFilterWrapper::make_class`.
// NOTE(review): presumably mirrors HAProxy's `FLT_CFG_FL_HTX` flag (filter
// supports HTX streams) -- confirm against the HAProxy headers.
const FLT_CFG_FL_HTX: u8 = 0x01;
44
45pub trait UserFilter: Sized {
47 const METHODS: u8 = FilterMethod::ALL;
50
51 const CONTINUE_IF_ERROR: bool = true;
53
54 fn new(lua: &Lua, args: Table) -> Result<Self>;
56
57 fn start_analyze(&mut self, lua: &Lua, txn: Txn, chn: Channel) -> Result<FilterResult> {
59 let _ = (lua, txn, chn);
60 Ok(FilterResult::Continue)
61 }
62
63 fn end_analyze(&mut self, lua: &Lua, txn: Txn, chn: Channel) -> Result<FilterResult> {
65 let _ = (lua, txn, chn);
66 Ok(FilterResult::Continue)
67 }
68
69 fn http_headers(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> Result<FilterResult> {
71 let _ = (lua, txn, msg);
72 Ok(FilterResult::Continue)
73 }
74
75 fn http_payload(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> Result<Option<usize>> {
77 let _ = (lua, txn, msg);
78 Ok(None)
79 }
80
81 fn http_end(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> Result<FilterResult> {
83 let _ = (lua, txn, msg);
84 Ok(FilterResult::Continue)
85 }
86
87 fn register_data_filter(lua: &Lua, txn: Txn, chn: Channel) -> Result<()> {
94 let global_filter = lua.globals().raw_get::<Table>("filter")?;
95 global_filter.call_function::<()>("register_data_filter", (txn.r#priv, chn))?;
96 Ok(())
97 }
98
99 fn unregister_data_filter(lua: &Lua, txn: Txn, chn: Channel) -> Result<()> {
102 let filter = lua.globals().raw_get::<Table>("filter")?;
103 filter.call_function::<()>("unregister_data_filter", (txn.r#priv, chn))?;
104 Ok(())
105 }
106
107 fn wake_time(lua: &Lua, milliseconds: u64) -> Result<()> {
109 let filter = lua.globals().raw_get::<Table>("filter")?;
110 filter.call_function::<()>("wake_time", milliseconds)?;
111 Ok(())
112 }
113}
114
115pub(crate) struct UserFilterWrapper<T>(T);
116
117impl<T> UserFilterWrapper<T>
118where
119 T: UserFilter + 'static,
120{
121 pub(crate) fn make_class(lua: &Lua) -> Result<Table> {
122 let class = lua.create_table()?;
123 class.raw_set("__index", &class)?;
124
125 class.raw_set("id", type_name::<T>())?;
127 class.raw_set("flags", FLT_CFG_FL_HTX)?;
128
129 let class_key = lua.create_registry_value(&class)?;
133 class.raw_set(
134 "new",
135 lua.create_function(move |lua, class: Table| {
136 let args = class.raw_get("args")?;
137 let filter = match T::new(lua, args) {
138 Ok(filter) => filter,
139 Err(err) => {
140 let core = Core::new(lua)?;
141 let msg = format!("Filter '{}': {err}", type_name::<T>());
142 core.log(LogLevel::Err, msg)?;
143 return Ok(Value::Nil);
144 }
145 };
146 let this = lua.create_sequence_from([Self(filter)])?;
147 let class = lua.registry_value::<Table>(&class_key)?;
148 this.set_metatable(Some(class))?;
149 Ok(Value::Table(this))
150 })?,
151 )?;
152
153 if T::METHODS & FilterMethod::START_ANALYZE != 0 {
154 class.raw_set(
155 "start_analyze",
156 lua.create_function(|lua, (t, mut txn, chn): (Table, Txn, Channel)| {
157 let ud = t.raw_get::<AnyUserData>(1)?;
158 let mut this = ud.borrow_mut::<Self>()?;
159 txn.r#priv = Value::Table(t);
160 Self::process_result(lua, this.start_analyze(lua, txn, chn))
161 })?,
162 )?;
163 }
164
165 if T::METHODS & FilterMethod::END_ANALYZE != 0 {
166 class.raw_set(
167 "end_analyze",
168 lua.create_function(|lua, (t, mut txn, chn): (Table, Txn, Channel)| {
169 let ud = t.raw_get::<AnyUserData>(1)?;
170 let mut this = ud.borrow_mut::<Self>()?;
171 txn.r#priv = Value::Table(t);
172 Self::process_result(lua, this.end_analyze(lua, txn, chn))
173 })?,
174 )?;
175 }
176
177 if T::METHODS & FilterMethod::HTTP_HEADERS != 0 {
178 class.raw_set(
179 "http_headers",
180 lua.create_function(|lua, (t, mut txn, msg): (Table, Txn, HttpMessage)| {
181 let ud = t.raw_get::<AnyUserData>(1)?;
182 let mut this = ud.borrow_mut::<Self>()?;
183 txn.r#priv = Value::Table(t);
184 Self::process_result(lua, this.http_headers(lua, txn, msg))
185 })?,
186 )?;
187 }
188
189 if T::METHODS & FilterMethod::HTTP_PAYLOAD != 0 {
190 class.raw_set(
191 "http_payload",
192 lua.create_function(|lua, (t, mut txn, msg): (Table, Txn, HttpMessage)| {
193 let ud = t.raw_get::<AnyUserData>(1)?;
194 let mut this = ud.borrow_mut::<Self>()?;
195 txn.r#priv = Value::Table(t);
196 let mut res = Variadic::new();
197 match this.http_payload(lua, txn, msg) {
198 Ok(Some(len)) => {
199 res.push(len.into_lua(lua)?);
200 }
201 Ok(None) => {}
202 Err(err) if T::CONTINUE_IF_ERROR => {
203 if let Ok(core) = Core::new(lua) {
204 let _ = core.log(
205 LogLevel::Err,
206 format!("Filter '{}': {}", type_name::<T>(), err),
207 );
208 }
209 }
210 Err(err) => return Err(err),
211 };
212 Ok(res)
213 })?,
214 )?;
215 }
216
217 if T::METHODS & FilterMethod::HTTP_END != 0 {
218 class.raw_set(
219 "http_end",
220 lua.create_function(|lua, (t, mut txn, msg): (Table, Txn, HttpMessage)| {
221 let ud = t.raw_get::<AnyUserData>(1)?;
222 let mut this = ud.borrow_mut::<Self>()?;
223 txn.r#priv = Value::Table(t);
224 Self::process_result(lua, this.http_end(lua, txn, msg))
225 })?,
226 )?;
227 }
228
229 Ok(class)
230 }
231
232 #[inline]
233 fn process_result(lua: &Lua, res: Result<FilterResult>) -> Result<i8> {
234 match res {
235 Ok(res) => Ok(res.code()),
236 Err(err) if T::CONTINUE_IF_ERROR => {
237 if let Ok(core) = Core::new(lua) {
238 let _ = core.log(
239 LogLevel::Err,
240 format!("Filter '{}': {}", type_name::<T>(), err),
241 );
242 }
243 Ok(FilterResult::Continue.code())
244 }
245 Err(err) => Err(err),
246 }
247 }
248}
249
250impl<T> UserData for UserFilterWrapper<T> where T: UserFilter + 'static {}
251
252impl<T> Deref for UserFilterWrapper<T> {
253 type Target = T;
254
255 #[inline]
256 fn deref(&self) -> &Self::Target {
257 &self.0
258 }
259}
260
261impl<T> DerefMut for UserFilterWrapper<T> {
262 #[inline]
263 fn deref_mut(&mut self) -> &mut Self::Target {
264 &mut self.0
265 }
266}