use std::cell::UnsafeCell;
use std::ffi::{CStr, CString};
use std::marker::PhantomData;
use std::os::raw::c_void;
use coraza_sys::*;
use crate::callbacks::{debug_log_trampoline, error_trampoline, CallbackContext};
use crate::error::Error;
use crate::matched_rule::{LogLevel, MatchedRule};
use crate::transaction::Transaction;
/// Builder for a Coraza WAF: collects rule directives and optional callbacks
/// before [`WafConfig::build`] turns them into a [`Waf`].
pub struct WafConfig {
    /// Raw Coraza config handle; set to zero by `build` once a WAF was created.
    handle: coraza_waf_config_t,
    /// Callback closures gathered by the builder methods; taken by `build`.
    callback_ctx: Option<Box<CallbackContext>>,
    /// `UnsafeCell` is `!Sync`, so this marker keeps `WafConfig` from being
    /// shared between threads by reference.
    _phantom: PhantomData<UnsafeCell<i32>>,
}
impl WafConfig {
pub fn new() -> Result<Self, Error> {
let handle = unsafe { coraza_new_waf_config() };
if handle == 0 {
return Err(Error::InvalidConfig);
}
Ok(Self {
handle,
callback_ctx: Some(Box::new(CallbackContext::new())),
_phantom: PhantomData,
})
}
pub fn with_directives(self, directives: &str) -> Self {
let c_directives = CString::new(directives).expect("directive string contains null byte");
unsafe {
coraza_rules_add(self.handle, c_directives.as_ptr());
}
self
}
pub fn with_directives_from_file(self, path: &str) -> Self {
let c_path = CString::new(path).expect("path contains null byte");
unsafe {
coraza_rules_add_file(self.handle, c_path.as_ptr());
}
self
}
pub fn with_debug_log_callback<F>(mut self, f: F) -> Self
where
F: Fn(LogLevel, &str, &str) + Send + 'static,
{
if let Some(ref mut ctx) = self.callback_ctx {
ctx.debug_log = Some(Box::new(f));
}
self
}
pub fn with_error_callback<F>(mut self, f: F) -> Self
where
F: Fn(MatchedRule) + Send + 'static,
{
if let Some(ref mut ctx) = self.callback_ctx {
ctx.error = Some(Box::new(f));
}
self
}
pub fn build(mut self) -> Result<Waf, Error> {
let ctx = self.callback_ctx.take();
if let Some(ctx) = ctx {
let ctx_ptr = Box::into_raw(ctx) as *mut c_void;
let has_debug = unsafe { &*ctx_ptr.cast::<CallbackContext>() }
.debug_log
.is_some();
let has_error = unsafe { &*ctx_ptr.cast::<CallbackContext>() }
.error
.is_some();
if has_debug {
unsafe {
coraza_add_debug_log_callback(self.handle, Some(debug_log_trampoline), ctx_ptr);
}
}
if has_error {
unsafe {
coraza_add_error_callback(self.handle, Some(error_trampoline), ctx_ptr);
}
}
}
let mut err: *mut std::os::raw::c_char = std::ptr::null_mut();
let handle = unsafe { coraza_new_waf(self.handle, &mut err) };
if handle == 0 {
let msg = if err.is_null() {
"unknown error".to_string()
} else {
let s = unsafe { CStr::from_ptr(err) }
.to_string_lossy()
.into_owned();
unsafe { coraza_free_string(err) };
s
};
return Err(Error::WafCreation(msg));
}
let _ = std::mem::replace(&mut self.handle, 0);
Ok(Waf { handle })
}
}
impl Drop for WafConfig {
    /// Frees the Coraza config handle, unless `build` set it to zero after a
    /// successful WAF creation.
    fn drop(&mut self) {
        if self.handle == 0 {
            return;
        }
        // SAFETY: the handle is non-zero and is freed only here, exactly once.
        unsafe { coraza_free_waf_config(self.handle) };
    }
}
/// A Coraza WAF instance produced by [`WafConfig::build`].
///
/// The underlying handle is released with `coraza_free_waf` when this value is dropped.
pub struct Waf {
    /// Raw Coraza WAF handle.
    handle: coraza_waf_t,
}
impl std::fmt::Debug for Waf {
    /// Formats the WAF as `Waf { rules_count: N }`; the raw handle is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rules_count = self.rules_count();
        let mut out = f.debug_struct("Waf");
        out.field("rules_count", &rules_count);
        out.finish()
    }
}
impl Waf {
    /// Returns the rule count reported by `coraza_rules_count` for this WAF.
    pub fn rules_count(&self) -> i32 {
        // SAFETY: `self.handle` is the WAF handle owned by this value.
        let count = unsafe { coraza_rules_count(self.handle) };
        count
    }

    /// Starts a new transaction on this WAF without an explicit transaction ID.
    pub fn new_transaction(&self) -> Transaction {
        // SAFETY: `self.handle` is the WAF handle owned by this value.
        let tx_handle = unsafe { coraza_new_transaction(self.handle) };
        Transaction::new(tx_handle)
    }

    /// Starts a new transaction on this WAF using `id` as the transaction ID.
    ///
    /// # Panics
    ///
    /// Panics if `id` contains an interior NUL byte.
    pub fn new_transaction_with_id(&self, id: &str) -> Transaction {
        let owned_id = CString::new(id).expect("ID contains null byte");
        // SAFETY: `self.handle` is the WAF handle owned by this value and
        // `owned_id` stays alive for the whole call.
        let tx_handle =
            unsafe { coraza_new_transaction_with_id(self.handle, owned_id.as_ptr()) };
        Transaction::new(tx_handle)
    }
}
impl Drop for Waf {
    /// Releases the Coraza WAF handle; a zero handle is skipped.
    fn drop(&mut self) {
        if self.handle == 0 {
            return;
        }
        // SAFETY: the handle is non-zero and is freed only here, exactly once.
        unsafe { coraza_free_waf(self.handle) };
    }
}