use std::collections::HashMap;
use std::future::Future;
use mlua::{
AnyUserData, AsChunk, ExternalError, FromLuaMulti, Function, IntoLua, IntoLuaMulti, Lua,
Result, Table, TableExt, Value,
};
use crate::filter::UserFilterWrapper;
use crate::{Proxy, UserFilter};
/// Handle to HAProxy's global `core` Lua table.
///
/// Bundles the Lua state with the `core` table (read from the globals by
/// [`Core::new`]) so every method can call straight into `core`.
#[derive(Clone)]
pub struct Core<'lua> {
    /// The global `core` table.
    class: Table<'lua>,
    /// Lua state used to create functions and load chunks before registering them.
    lua: &'lua Lua,
}
/// A timestamp split into whole seconds and microseconds, as produced by `Core::now`
/// from the `sec` and `usec` fields of the table returned by `core.now()`.
///
/// The derived ordering compares `sec` first and then `usec`, which is chronological
/// as long as `usec` stays below one million.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Time {
    /// Whole seconds (the `sec` field of the Lua table).
    pub sec: u64,
    /// Microseconds (the `usec` field of the Lua table).
    pub usec: u64,
}
/// Contexts in which an action registered with `Core::register_action` may run.
///
/// Each variant is passed to `core.register_action` under its string name.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Action {
    /// Sent to Lua as `"tcp-req"`.
    TcpReq,
    /// Sent to Lua as `"tcp-res"`.
    TcpRes,
    /// Sent to Lua as `"http-req"`.
    HttpReq,
    /// Sent to Lua as `"http-res"`.
    HttpRes,
}
/// Mode of a service registered with `Core::register_lua_service`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ServiceMode {
    /// Sent to `core.register_service` as `"tcp"`.
    Tcp,
    /// Sent to `core.register_service` as `"http"`.
    Http,
}
/// Severity passed to `Core::log`.
///
/// Converted to a Lua integer by its `IntoLua` impl: `Emerg` is 0 and each following
/// variant increases by one, up to `Debug` = 7.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum LogLevel {
    /// Level 0.
    Emerg,
    /// Level 1.
    Alert,
    /// Level 2.
    Crit,
    /// Level 3.
    Err,
    /// Level 4.
    Warning,
    /// Level 5.
    Notice,
    /// Level 6.
    Info,
    /// Level 7.
    Debug,
}
impl<'lua> Core<'lua> {
    /// Creates a handle to the global `core` table.
    ///
    /// # Errors
    ///
    /// Fails if the global `core` cannot be converted to a table.
    #[inline]
    pub fn new(lua: &'lua Lua) -> Result<Self> {
        let class: Table = lua.globals().get("core")?;
        Ok(Core { lua, class })
    }

    /// Reads `core.proxies` as a map from its string keys to [`Proxy`] values.
    #[inline]
    pub fn proxies(&self) -> Result<HashMap<String, Proxy<'lua>>> {
        self.class.get("proxies")
    }

    /// Reads `core.backends` as a map from its string keys to [`Proxy`] values.
    #[inline]
    pub fn backends(&self) -> Result<HashMap<String, Proxy<'lua>>> {
        self.class.get("backends")
    }

    /// Reads `core.frontends` as a map from its string keys to [`Proxy`] values.
    #[inline]
    pub fn frontends(&self) -> Result<HashMap<String, Proxy<'lua>>> {
        self.class.get("frontends")
    }

    /// Reads `core.thread` as a `u16`.
    #[inline]
    pub fn thread(&self) -> Result<u16> {
        self.class.get("thread")
    }

    /// Calls `core.log(level, msg)`.
    ///
    /// `level` reaches Lua through its `IntoLua` conversion (an integer code).
    #[inline]
    pub fn log(&self, level: LogLevel, msg: impl AsRef<str>) -> Result<()> {
        let msg = msg.as_ref();
        self.class.call_function("log", (level, msg))
    }

    /// Calls `core.add_acl(filename, key)`.
    #[inline]
    pub fn add_acl(&self, filename: &str, key: &str) -> Result<()> {
        self.class.call_function("add_acl", (filename, key))
    }

    /// Calls `core.del_acl(filename, key)`.
    #[inline]
    pub fn del_acl(&self, filename: &str, key: &str) -> Result<()> {
        self.class.call_function("del_acl", (filename, key))
    }

    /// Calls `core.del_map(filename, key)`.
    #[inline]
    pub fn del_map(&self, filename: &str, key: &str) -> Result<()> {
        self.class.call_function("del_map", (filename, key))
    }

    /// Calls `core.set_map(filename, key, value)`.
    #[inline]
    pub fn set_map(&self, filename: &str, key: &str, value: &str) -> Result<()> {
        self.class.call_function("set_map", (filename, key, value))
    }

    /// Calls `core.get_info()` and collects the returned values as strings.
    #[inline]
    pub fn get_info(&self) -> Result<Vec<String>> {
        self.class.call_function("get_info", ())
    }

    /// Calls `core.now()` and reads the `sec` and `usec` fields of the returned table.
    #[inline]
    pub fn now(&self) -> Result<Time> {
        let time: Table = self.class.call_function("now", ())?;
        Ok(Time {
            sec: time.get("sec")?,
            usec: time.get("usec")?,
        })
    }

    /// Parses `date` with `core.http_date`.
    ///
    /// # Errors
    ///
    /// Returns an `invalid date` error when Lua returns `nil`.
    #[inline]
    pub fn http_date(&self, date: &str) -> Result<u64> {
        self.call_date_parser("http_date", date)
    }

    /// Parses `date` with `core.imf_date`.
    ///
    /// # Errors
    ///
    /// Returns an `invalid date` error when Lua returns `nil`.
    #[inline]
    pub fn imf_date(&self, date: &str) -> Result<u64> {
        self.call_date_parser("imf_date", date)
    }

    /// Parses `date` with `core.rfc850_date`.
    ///
    /// # Errors
    ///
    /// Returns an `invalid date` error when Lua returns `nil`.
    #[inline]
    pub fn rfc850_date(&self, date: &str) -> Result<u64> {
        self.call_date_parser("rfc850_date", date)
    }

    /// Parses `date` with `core.asctime_date`.
    ///
    /// # Errors
    ///
    /// Returns an `invalid date` error when Lua returns `nil`.
    #[inline]
    pub fn asctime_date(&self, date: &str) -> Result<u64> {
        self.call_date_parser("asctime_date", date)
    }

    /// Calls the `core` date parser named `func` with `date`.
    ///
    /// A `nil` result is turned into an `invalid date` error, so every `*_date`
    /// method reports parse failures the same way.
    fn call_date_parser(&self, func: &str, date: &str) -> Result<u64> {
        let parsed: Option<u64> = self.class.call_function(func, date)?;
        parsed.ok_or_else(|| "invalid date".into_lua_err())
    }

    /// Returns the string name `core.register_action` expects for `action`.
    fn action_name(action: Action) -> &'static str {
        match action {
            Action::TcpReq => "tcp-req",
            Action::TcpRes => "tcp-res",
            Action::HttpReq => "http-req",
            Action::HttpRes => "http-res",
        }
    }

    /// Registers the Rust function `func` as the action `name` via
    /// `core.register_action(name, actions, func, nb_args)`.
    ///
    /// Each [`Action`] is translated to its string form (`"tcp-req"`, `"tcp-res"`,
    /// `"http-req"`, `"http-res"`) before being passed to Lua.
    pub fn register_action<A, F>(
        &self,
        name: &str,
        actions: &[Action],
        nb_args: usize,
        func: F,
    ) -> Result<()>
    where
        A: FromLuaMulti<'lua>,
        F: Fn(&'lua Lua, A) -> Result<()> + Send + 'static,
    {
        let func = self.lua.create_function(func)?;
        let actions: Vec<&str> = actions.iter().copied().map(Self::action_name).collect();
        self.class
            .call_function("register_action", (name, actions, func, nb_args))
    }

    /// Registers the async Rust function `func` as the action `name` via
    /// `core.register_action`; the function is built with [`create_async_function`].
    ///
    /// `actions` are passed to Lua unchanged as strings.
    pub fn register_async_action<A, F, FR>(
        &self,
        name: &str,
        actions: &[&str],
        nb_args: usize,
        func: F,
    ) -> Result<()>
    where
        A: FromLuaMulti<'lua>,
        F: Fn(&'lua Lua, A) -> FR + Send + 'static,
        FR: Future<Output = Result<()>> + 'lua,
    {
        let func = create_async_function(self.lua, func)?;
        self.class
            .call_function("register_action", (name, actions.to_vec(), func, nb_args))
    }

    /// Compiles the Lua chunk `code` into a function and registers it as the action
    /// `name` via `core.register_action`.
    ///
    /// `actions` are passed to Lua unchanged as strings.
    pub fn register_lua_action<'a, S>(
        &self,
        name: &str,
        actions: &[&str],
        nb_args: usize,
        code: S,
    ) -> Result<()>
    where
        S: AsChunk<'lua, 'a>,
    {
        let func = self.lua.load(code).into_function()?;
        self.class
            .call_function("register_action", (name, actions.to_vec(), func, nb_args))
    }

    /// Registers the Rust function `func` as the converter `name` via
    /// `core.register_converters(name, func)`.
    pub fn register_converters<A, R, F>(&self, name: &str, func: F) -> Result<()>
    where
        A: FromLuaMulti<'lua>,
        R: IntoLua<'lua>,
        F: Fn(&'lua Lua, A) -> Result<R> + Send + 'static,
    {
        let func = self.lua.create_function(func)?;
        self.class
            .call_function("register_converters", (name, func))
    }

    /// Registers an async converter built with [`create_async_function`].
    #[deprecated(note = "haproxy does not support async converters")]
    pub fn register_async_converters<A, R, F, FR>(&self, name: &str, func: F) -> Result<()>
    where
        A: FromLuaMulti<'lua>,
        R: IntoLua<'lua>,
        F: Fn(&'lua Lua, A) -> FR + Send + 'static,
        FR: Future<Output = Result<R>> + 'lua,
    {
        let func = create_async_function(self.lua, func)?;
        self.class
            .call_function("register_converters", (name, func))
    }

    /// Compiles the Lua chunk `code` into a function and registers it as the
    /// converter `name` via `core.register_converters`.
    pub fn register_lua_converters<'a, S>(&self, name: &str, code: S) -> Result<()>
    where
        S: AsChunk<'lua, 'a>,
    {
        let func = self.lua.load(code).into_function()?;
        self.class
            .call_function("register_converters", (name, func))
    }

    /// Registers the Rust function `func` as the sample fetch `name` via
    /// `core.register_fetches(name, func)`.
    pub fn register_fetches<A, R, F>(&self, name: &str, func: F) -> Result<()>
    where
        A: FromLuaMulti<'lua>,
        R: IntoLua<'lua>,
        F: Fn(&'lua Lua, A) -> Result<R> + Send + 'static,
    {
        let func = self.lua.create_function(func)?;
        self.class.call_function("register_fetches", (name, func))
    }

    /// Registers an async sample fetch built with [`create_async_function`].
    #[deprecated(note = "haproxy does not support async fetches")]
    pub fn register_async_fetches<A, R, F, FR>(&self, name: &str, func: F) -> Result<()>
    where
        A: FromLuaMulti<'lua>,
        R: IntoLua<'lua>,
        F: Fn(&'lua Lua, A) -> FR + Send + 'static,
        FR: Future<Output = Result<R>> + 'lua,
    {
        let func = create_async_function(self.lua, func)?;
        self.class.call_function("register_fetches", (name, func))
    }

    /// Compiles the Lua chunk `code` into a function and registers it as the sample
    /// fetch `name` via `core.register_fetches`.
    pub fn register_lua_fetches<'a, S>(&self, name: &str, code: S) -> Result<()>
    where
        S: AsChunk<'lua, 'a>,
    {
        let func = self.lua.load(code).into_function()?;
        self.class.call_function("register_fetches", (name, func))
    }

    /// Registers the filter type `T` as `name` via
    /// `core.register_filter(name, class, func)`.
    ///
    /// The class comes from `UserFilterWrapper::<T>::make_class`. The `func` callback
    /// stores its `args` table into the class's `args` field (raw set) and returns the
    /// class.
    ///
    /// # Errors
    ///
    /// Fails if the callback or the filter class cannot be created, or if
    /// `core.register_filter` raises an error.
    pub fn register_filter<T: UserFilter + 'static>(&self, name: &str) -> Result<()> {
        let lua = self.lua;
        // `?` propagates a failure to create the callback instead of handing a
        // `Result` value to Lua in place of the function.
        let func = lua.create_function(|_, (class, args): (Table, Table)| {
            class.raw_set("args", args)?;
            Ok(class)
        })?;
        let filter_class = UserFilterWrapper::<T>::make_class(lua)?;
        self.class
            .call_function("register_filter", (name, filter_class, func))
    }

    /// Compiles the Lua chunk `code` into a function and registers it as the service
    /// `name` via `core.register_service(name, mode, func)`.
    ///
    /// `mode` is passed to Lua as `"tcp"` or `"http"`.
    pub fn register_lua_service<'a, S>(&self, name: &str, mode: ServiceMode, code: S) -> Result<()>
    where
        S: AsChunk<'lua, 'a>,
    {
        let func = self.lua.load(code).into_function()?;
        let mode = match mode {
            ServiceMode::Tcp => "tcp",
            ServiceMode::Http => "http",
        };
        self.class
            .call_function("register_service", (name, mode, func))
    }

    /// Wraps `func` in a Lua function taking no arguments and registers it with
    /// `core.register_init`.
    pub fn register_init<F>(&self, func: F) -> Result<()>
    where
        F: Fn(&'lua Lua) -> Result<()> + Send + 'static,
    {
        let func = self.lua.create_function(move |lua, ()| func(lua))?;
        self.class.call_function("register_init", func)
    }

    /// Wraps `func` in a Lua function taking no arguments and registers it with
    /// `core.register_task`.
    pub fn register_task<F>(&self, func: F) -> Result<()>
    where
        F: Fn(&'lua Lua) -> Result<()> + Send + 'static,
    {
        let func = self.lua.create_function(move |lua, ()| func(lua))?;
        self.class.call_function("register_task", func)
    }

    /// Wraps the async `func` with [`create_async_function`] (taking no arguments)
    /// and registers it with `core.register_task`.
    pub fn register_async_task<F, FR>(&self, func: F) -> Result<()>
    where
        F: Fn(&'lua Lua) -> FR + Send + 'static,
        FR: Future<Output = Result<()>> + 'lua,
    {
        let func = create_async_function(self.lua, move |lua, ()| func(lua))?;
        self.class.call_function("register_task", func)
    }

    /// Compiles the Lua chunk `code` into a function and registers it with
    /// `core.register_task`.
    pub fn register_lua_task<'a, S>(&self, code: S) -> Result<()>
    where
        S: AsChunk<'lua, 'a>,
    {
        let func = self.lua.load(code).into_function()?;
        self.class.call_function("register_task", func)
    }

    /// Compiles the Lua chunk `code` into a function and registers it via
    /// `core.register_cli(path, usage, func)`.
    pub fn register_lua_cli<'a, S>(&self, path: &[&str], usage: &str, code: S) -> Result<()>
    where
        S: AsChunk<'lua, 'a>,
    {
        let func = self.lua.load(code).into_function()?;
        self.class
            .call_function("register_cli", (path, usage, func))
    }

    /// Calls `core.set_nice(nice)`.
    #[inline]
    pub fn set_nice(&self, nice: i32) -> Result<()> {
        self.class.call_function("set_nice", nice)
    }

    /// Calls `core.parse_addr(addr)` and returns the resulting userdata.
    #[inline]
    pub fn parse_addr(&self, addr: &str) -> Result<AnyUserData<'lua>> {
        self.class.call_function("parse_addr", addr)
    }

    /// Calls `core.match_addr(addr1, addr2)` and returns its boolean result.
    ///
    /// NOTE(review): the arguments are presumably values obtained from
    /// [`Core::parse_addr`]; confirm.
    #[inline]
    pub fn match_addr(&self, addr1: AnyUserData, addr2: AnyUserData) -> Result<bool> {
        self.class.call_function("match_addr", (addr1, addr2))
    }
}
impl<'lua> IntoLua<'lua> for LogLevel {
    /// Converts the level into its integer code.
    ///
    /// The codes run from `Emerg` = 0 through `Debug` = 7.
    #[inline]
    fn into_lua(self, lua: &'lua Lua) -> Result<Value<'lua>> {
        let code: i32 = match self {
            LogLevel::Emerg => 0,
            LogLevel::Alert => 1,
            LogLevel::Crit => 2,
            LogLevel::Err => 3,
            LogLevel::Warning => 4,
            LogLevel::Notice => 5,
            LogLevel::Info => 6,
            LogLevel::Debug => 7,
        };
        code.into_lua(lua)
    }
}
/// Creates an async Lua function from `func` while the `YieldFixUp` replacement of
/// `coroutine.yield` is installed.
///
/// The replacement is installed first. The function is then created with
/// `Lua::create_async_function`. Finally the original `coroutine.yield` is restored
/// by dropping the guard.
///
/// # Errors
///
/// Fails if the replacement cannot be installed or the function cannot be created.
pub fn create_async_function<'lua, A, R, F, FR>(lua: &'lua Lua, func: F) -> Result<Function<'lua>>
where
    A: FromLuaMulti<'lua>,
    R: IntoLuaMulti<'lua>,
    F: 'static + Send + Fn(&'lua Lua, A) -> FR,
    FR: 'lua + Future<Output = Result<R>>,
{
    let guard = YieldFixUp::new(lua)?;
    let function = lua.create_async_function(func);
    // Put the original `coroutine.yield` back now that creation is done.
    drop(guard);
    function
}
/// Guard that temporarily replaces the global `coroutine.yield`.
///
/// Field `0` is the Lua state and field `1` is the original `coroutine.yield` saved
/// by `YieldFixUp::new`. Dropping the guard writes field `1` back to
/// `coroutine.yield`.
///
/// NOTE(review): presumably used so that functions created while the guard is alive
/// (see `create_async_function`) capture the replacement yield. Confirm against the
/// mlua version in use.
struct YieldFixUp<'lua>(&'lua Lua, Function<'lua>);
impl<'lua> YieldFixUp<'lua> {
    /// Replaces the global `coroutine.yield` with a custom Lua function.
    ///
    /// Returns a guard that remembers the original `coroutine.yield`.
    ///
    /// The replacement calls `core.yield()` on its first invocation and `core.msleep(1)`
    /// on every later invocation.
    ///
    /// NOTE(review): the counter `i` is an upvalue of the returned function. It is
    /// therefore shared by every call of this particular replacement and never reset.
    /// Confirm this is intended.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - `coroutine` or `coroutine.yield` cannot be read;
    /// - the chunk fails to load or run;
    /// - the new function cannot be stored.
    fn new(lua: &'lua Lua) -> Result<Self> {
        let coroutine: Table = lua.globals().get("coroutine")?;
        // Save the original so the `Drop` impl can put it back.
        let orig_yield: Function = coroutine.get("yield")?;
        // Running the chunk reads `core.yield` / `core.msleep` once and returns the
        // replacement function.
        let new_yield: Function = lua
            .load(
                r#"
local yield, msleep = core.yield, core.msleep
local i = 0
return function()
if i == 0 then
i = 1
yield()
else
msleep(1)
end
end
"#,
            )
            .call(())?;
        // The replacement is installed only after everything above succeeded. An
        // earlier error therefore leaves the original `coroutine.yield` untouched.
        coroutine.set("yield", new_yield)?;
        Ok(YieldFixUp(lua, orig_yield))
    }
}
impl<'lua> Drop for YieldFixUp<'lua> {
fn drop(&mut self) {
if let Err(e) = (|| {
let coroutine: Table = self.0.globals().get("coroutine")?;
coroutine.set("yield", self.1.clone())
})() {
println!("Error in YieldFixUp destructor: {}", e);
}
}
}