use std::any::type_name;
use std::ops::{Deref, DerefMut};
use mlua::{AnyUserData, IntoLua, Lua, Result, Table, TableExt, UserData, Value, Variadic};
use crate::{Channel, Core, HttpMessage, LogLevel, Txn};
/// Namespace for the bit flags that name the individual [`UserFilter`] callbacks.
///
/// The flags are meant to be OR-ed together into `UserFilter::METHODS`; a callback
/// is only exported to Lua when its bit is present in that mask.
pub struct FilterMethod;

impl FilterMethod {
    /// Selects the `start_analyze` callback.
    pub const START_ANALYZE: u8 = 1 << 0;
    /// Selects the `end_analyze` callback.
    pub const END_ANALYZE: u8 = 1 << 1;
    /// Selects the `http_headers` callback.
    pub const HTTP_HEADERS: u8 = 1 << 2;
    /// Selects the `http_payload` callback.
    pub const HTTP_PAYLOAD: u8 = 1 << 3;
    /// Selects the `http_end` callback.
    pub const HTTP_END: u8 = 1 << 4;
    /// Selects every callback (all bits set).
    pub const ALL: u8 = u8::MAX;
}
/// Outcome reported by a filter callback.
///
/// The value is turned into a small integer by `code` before it is handed back
/// to Lua.
pub enum FilterResult {
    /// Reported to Lua as `1`.
    // NOTE(review): presumably means "this filtering step is done, carry on" —
    // confirm against the HAProxy Lua filter documentation.
    Continue,
    /// Reported to Lua as `0`.
    // NOTE(review): presumably asks HAProxy to pause and call the filter again later — confirm.
    Wait,
    /// Reported to Lua as `-1`.
    Error,
}

impl FilterResult {
    /// Returns the integer code that represents this result on the Lua side.
    fn code(&self) -> i8 {
        match *self {
            Self::Continue => 1,
            Self::Wait => 0,
            Self::Error => -1,
        }
    }
}
/// Value stored in the `flags` field of every filter class table.
// NOTE(review): presumably mirrors HAProxy's `FLT_CFG_FL_HTX` flag (filter works
// on HTX streams) — confirm against the HAProxy headers.
const FLT_CFG_FL_HTX: u8 = 0x01;
/// A filter implemented in Rust and exposed to Lua.
///
/// Implementors provide [`UserFilter::new`] and override whichever callbacks they
/// need; every callback has a default implementation that ignores its arguments.
pub trait UserFilter: Sized {
    /// Bitmask of [`FilterMethod`] flags selecting which callbacks are exported
    /// to Lua. Defaults to all of them.
    const METHODS: u8 = FilterMethod::ALL;

    /// When `true` (the default), an error returned by a callback is logged and
    /// processing continues; when `false`, the error is propagated to Lua.
    const CONTINUE_IF_ERROR: bool = true;

    /// Creates a new filter instance.
    ///
    /// `args` is the table read from the `args` field of the Lua object the
    /// constructor is invoked on.
    ///
    /// # Errors
    /// Any error is logged by the wrapper, which then returns `nil` to Lua
    /// instead of a filter instance.
    fn new(lua: &Lua, args: Table) -> Result<Self>;

    /// `start_analyze` callback. The default does nothing and returns
    /// [`FilterResult::Continue`].
    fn start_analyze(&mut self, lua: &Lua, txn: Txn, chn: Channel) -> Result<FilterResult> {
        let _ = (lua, txn, chn);
        Ok(FilterResult::Continue)
    }

    /// `end_analyze` callback. The default does nothing and returns
    /// [`FilterResult::Continue`].
    fn end_analyze(&mut self, lua: &Lua, txn: Txn, chn: Channel) -> Result<FilterResult> {
        let _ = (lua, txn, chn);
        Ok(FilterResult::Continue)
    }

    /// `http_headers` callback. The default does nothing and returns
    /// [`FilterResult::Continue`].
    fn http_headers(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> Result<FilterResult> {
        let _ = (lua, txn, msg);
        Ok(FilterResult::Continue)
    }

    /// `http_payload` callback.
    ///
    /// `Some(len)` is returned to Lua as a single value, `None` returns nothing.
    /// The default returns `None`.
    // NOTE(review): `len` is presumably the number of payload bytes to forward —
    // confirm against the HAProxy Lua filter documentation.
    fn http_payload(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> Result<Option<usize>> {
        let _ = (lua, txn, msg);
        Ok(None)
    }

    /// `http_end` callback. The default does nothing and returns
    /// [`FilterResult::Continue`].
    fn http_end(&mut self, lua: &Lua, txn: Txn, msg: HttpMessage) -> Result<FilterResult> {
        let _ = (lua, txn, msg);
        Ok(FilterResult::Continue)
    }

    /// Calls the Lua global `filter.register_data_filter` with the transaction's
    /// `priv` value and `chn`.
    ///
    /// # Errors
    /// Fails if the `filter` global is not a table or if the Lua call raises.
    fn register_data_filter(lua: &Lua, txn: Txn, chn: Channel) -> Result<()> {
        let filter = lua.globals().raw_get::<_, Table>("filter")?;
        // Returning the call directly lets the declared `Result<()>` fix the
        // generic return type; a bare `call_function(..)?;` statement would only
        // compile through the never-type fallback to `()`.
        filter.call_function("register_data_filter", (txn.r#priv, chn))
    }

    /// Calls the Lua global `filter.unregister_data_filter` with the
    /// transaction's `priv` value and `chn`.
    ///
    /// # Errors
    /// Fails if the `filter` global is not a table or if the Lua call raises.
    fn unregister_data_filter(lua: &Lua, txn: Txn, chn: Channel) -> Result<()> {
        let filter = lua.globals().raw_get::<_, Table>("filter")?;
        // See `register_data_filter` for why the call is the tail expression.
        filter.call_function("unregister_data_filter", (txn.r#priv, chn))
    }
}
/// Newtype that carries a [`UserFilter`] implementation inside Lua userdata.
pub(crate) struct UserFilterWrapper<T>(T);

impl<T> UserFilterWrapper<T>
where
    T: UserFilter + 'static,
{
    /// Builds the Lua class table describing the filter `T`.
    ///
    /// The table receives:
    /// * `__index` pointing at itself, so it can act as the metatable of instances;
    /// * `id` (the Rust type name of `T`) and `flags` (`FLT_CFG_FL_HTX`);
    /// * a `new` function that instantiates `T` from the `args` field of the
    ///   object it is called on;
    /// * one function per callback whose [`FilterMethod`] bit is set in `T::METHODS`.
    ///
    /// An instance is a Lua sequence whose element `1` is the userdata wrapping
    /// `T` and whose metatable is the class table.
    ///
    /// # Errors
    /// Returns any error raised by Lua while creating the tables, functions or
    /// the registry entry.
    pub(crate) fn make_class(lua: &Lua) -> Result<Table> {
        let class = lua.create_table()?;
        class.raw_set("__index", class.clone())?;
        class.raw_set("id", type_name::<T>())?;
        class.raw_set("flags", FLT_CFG_FL_HTX)?;

        // The constructor needs the class table to use as the instance
        // metatable; it is reached through the registry from inside the closure.
        let class_key = lua.create_registry_value(class.clone())?;
        class.raw_set(
            "new",
            lua.create_function(move |lua, class: Table| {
                let args = class.raw_get("args")?;
                let filter = match T::new(lua, args) {
                    Ok(filter) => filter,
                    Err(err) => {
                        // Construction failed: log it and hand `nil` back to Lua.
                        let core = Core::new(lua)?;
                        core.log(LogLevel::Err, Self::error_message(&err))?;
                        return Ok(Value::Nil);
                    }
                };
                let this = lua.create_sequence_from([Self(filter)])?;
                let class = lua.registry_value::<Table>(&class_key)?;
                this.set_metatable(Some(class));
                Ok(Value::Table(this))
            })?,
        )?;

        // Every callback below follows the same pattern: fetch the userdata
        // stored at index 1 of the instance table, borrow it mutably, stash the
        // instance table in `txn.priv` and dispatch to the Rust implementation.

        if T::METHODS & FilterMethod::START_ANALYZE != 0 {
            class.raw_set(
                "start_analyze",
                lua.create_function(|lua, (t, mut txn, chn): (Table, Txn, Channel)| {
                    let ud = t.raw_get::<_, AnyUserData>(1)?;
                    let mut this = ud.borrow_mut::<Self>()?;
                    txn.r#priv = Value::Table(t);
                    Self::process_result(lua, this.start_analyze(lua, txn, chn))
                })?,
            )?;
        }

        if T::METHODS & FilterMethod::END_ANALYZE != 0 {
            class.raw_set(
                "end_analyze",
                lua.create_function(|lua, (t, mut txn, chn): (Table, Txn, Channel)| {
                    let ud = t.raw_get::<_, AnyUserData>(1)?;
                    let mut this = ud.borrow_mut::<Self>()?;
                    txn.r#priv = Value::Table(t);
                    Self::process_result(lua, this.end_analyze(lua, txn, chn))
                })?,
            )?;
        }

        if T::METHODS & FilterMethod::HTTP_HEADERS != 0 {
            class.raw_set(
                "http_headers",
                lua.create_function(|lua, (t, mut txn, msg): (Table, Txn, HttpMessage)| {
                    let ud = t.raw_get::<_, AnyUserData>(1)?;
                    let mut this = ud.borrow_mut::<Self>()?;
                    txn.r#priv = Value::Table(t);
                    Self::process_result(lua, this.http_headers(lua, txn, msg))
                })?,
            )?;
        }

        if T::METHODS & FilterMethod::HTTP_PAYLOAD != 0 {
            class.raw_set(
                "http_payload",
                lua.create_function(|lua, (t, mut txn, msg): (Table, Txn, HttpMessage)| {
                    let ud = t.raw_get::<_, AnyUserData>(1)?;
                    let mut this = ud.borrow_mut::<Self>()?;
                    txn.r#priv = Value::Table(t);
                    // Zero or one return value: the length, when the filter gives one.
                    let mut res = Variadic::new();
                    match this.http_payload(lua, txn, msg) {
                        Ok(Some(len)) => res.push(len.into_lua(lua)?),
                        Ok(None) => {}
                        Err(err) if T::CONTINUE_IF_ERROR => Self::log_error(lua, &err),
                        Err(err) => return Err(err),
                    }
                    Ok(res)
                })?,
            )?;
        }

        if T::METHODS & FilterMethod::HTTP_END != 0 {
            class.raw_set(
                "http_end",
                lua.create_function(|lua, (t, mut txn, msg): (Table, Txn, HttpMessage)| {
                    let ud = t.raw_get::<_, AnyUserData>(1)?;
                    let mut this = ud.borrow_mut::<Self>()?;
                    txn.r#priv = Value::Table(t);
                    Self::process_result(lua, this.http_end(lua, txn, msg))
                })?,
            )?;
        }

        Ok(class)
    }

    /// Converts a callback outcome into the integer code returned to Lua.
    ///
    /// With `T::CONTINUE_IF_ERROR` set, an error is logged (best effort) and
    /// reported as [`FilterResult::Continue`]; otherwise it is propagated.
    #[inline]
    fn process_result(lua: &Lua, res: Result<FilterResult>) -> Result<i8> {
        match res {
            Ok(res) => Ok(res.code()),
            Err(err) if T::CONTINUE_IF_ERROR => {
                Self::log_error(lua, &err);
                Ok(FilterResult::Continue.code())
            }
            Err(err) => Err(err),
        }
    }

    /// Formats the log line used for every error raised by the filter `T`.
    fn error_message(err: impl std::fmt::Display) -> String {
        format!("Filter '{}': {}", type_name::<T>(), err)
    }

    /// Logs `err` at error level, silently giving up if logging itself fails.
    fn log_error(lua: &Lua, err: impl std::fmt::Display) {
        if let Ok(core) = Core::new(lua) {
            let _ = core.log(LogLevel::Err, Self::error_message(err));
        }
    }
}
/// Allows the wrapper to be stored in Lua as userdata. The impl body is empty,
/// so only the trait's default behaviour applies.
impl<T: UserFilter + 'static> UserData for UserFilterWrapper<T> {}
impl<T> Deref for UserFilterWrapper<T> {
type Target = T;
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for UserFilterWrapper<T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}