use std::fmt;
use std::mem::{self, MaybeUninit};
use std::ops::Range;
use std::ptr;
use std::slice;
use derive_more::{Deref, From, Into};
use foreign_types::{foreign_type, ForeignType, ForeignTypeRef};
use crate::{
chimera::{error::AsResult, ffi, DatabaseRef},
Result,
};
foreign_type! {
/// An owned chimera scratch space (`ffi::ch_scratch_t`); `ScratchRef` is the
/// borrowed form generated alongside it.
///
/// Dropping it calls `free_scratch` (which `expect`s `ch_free_scratch` to succeed);
/// cloning it calls `clone_scratch`, which obtains a separate allocation from
/// `ch_clone_scratch`. The type is declared `Send` only (not `Sync`).
pub unsafe type Scratch: Send {
type CType = ffi::ch_scratch_t;
fn drop = free_scratch;
fn clone = clone_scratch;
}
}
/// `drop` hook for [`Scratch`]: hands the scratch space at `s` back to chimera.
///
/// A failure status from `ch_free_scratch` is treated as a broken invariant and
/// reported through `expect("free scratch")`.
unsafe fn free_scratch(s: *mut ffi::ch_scratch_t) {
    let status = ffi::ch_free_scratch(s);
    status.expect("free scratch");
}
/// `clone` hook for [`Scratch`]: asks chimera to duplicate the scratch space at `s`
/// and returns the pointer to the new allocation.
///
/// A failure status from `ch_clone_scratch` is reported through
/// `expect("clone scratch")`, so the returned pointer is only produced on success.
unsafe fn clone_scratch(s: *mut ffi::ch_scratch_t) -> *mut ffi::ch_scratch_t {
    // Out-parameter filled in by `ch_clone_scratch`; only read after success.
    let mut duplicate: *mut ffi::ch_scratch_t = ptr::null_mut();
    ffi::ch_clone_scratch(s, &mut duplicate).expect("clone scratch");
    duplicate
}
impl ScratchRef {
    /// Returns the size of this scratch space as reported by `ch_scratch_size`.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `ch_scratch_size` reports a non-success status.
    pub fn size(&self) -> Result<usize> {
        let mut reported = MaybeUninit::uninit();
        unsafe {
            let status = ffi::ch_scratch_size(self.as_ptr(), reported.as_mut_ptr());
            // `reported` is only read once the call has signalled success.
            status.map(|_| reported.assume_init())
        }
    }
}
impl DatabaseRef {
pub fn alloc_scratch(&self) -> Result<Scratch> {
let mut s = MaybeUninit::zeroed();
unsafe { ffi::ch_alloc_scratch(self.as_ptr(), s.as_mut_ptr()).map(|_| Scratch::from_ptr(s.assume_init())) }
}
pub fn realloc_scratch(&self, s: &mut Scratch) -> Result<&ScratchRef> {
let mut p = s.as_ptr();
unsafe {
ffi::ch_alloc_scratch(self.as_ptr(), &mut p).map(|_| {
s.0 = ptr::NonNull::new_unchecked(p);
ScratchRef::from_ptr(p)
})
}
}
}
/// Verdict a match or error handler returns to chimera.
///
/// Each discriminant is the raw `ffi::CH_CALLBACK_*` code, so the trampolines can
/// hand a `Matching` straight back to the C library with an `as` cast.
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Matching {
/// `CH_CALLBACK_CONTINUE`: keep scanning.
Continue = ffi::CH_CALLBACK_CONTINUE,
/// `CH_CALLBACK_TERMINATE`: presumably stops the scan — confirm against chimera docs.
Terminate = ffi::CH_CALLBACK_TERMINATE,
/// `CH_CALLBACK_SKIP_PATTERN`: presumably skips further matches for the current
/// pattern — confirm against chimera docs.
Skip = ffi::CH_CALLBACK_SKIP_PATTERN,
}
impl Default for Matching {
fn default() -> Self {
Matching::Continue
}
}
/// Error events chimera reports to an [`ErrorEventHandler`] during a scan.
///
/// Discriminants are the raw `ffi::CH_ERROR_*` codes; `on_error_trampoline`
/// reinterprets the C value directly as this enum.
#[repr(u32)]
#[derive(Clone, Copy, Debug, From, PartialEq, Eq)]
pub enum Error {
/// `CH_ERROR_MATCHLIMIT` (presumably the PCRE match limit was reached — confirm).
MatchLimit = ffi::CH_ERROR_MATCHLIMIT,
/// `CH_ERROR_RECURSIONLIMIT` (presumably the PCRE recursion limit was reached — confirm).
RecursionLimit = ffi::CH_ERROR_RECURSIONLIMIT,
}
/// One capture-group entry passed to a match callback.
///
/// `#[repr(transparent)]` over `ffi::ch_capture`, which is what lets
/// `on_match_trampoline` view the C capture array as `&[Capture]`. The raw fields
/// (`flags`, `from`, `to`) are reachable through `Deref`.
#[repr(transparent)]
#[derive(Clone, Copy, From, Into, Deref, PartialEq, Eq)]
pub struct Capture(ffi::ch_capture);
impl fmt::Debug for Capture {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Capture")
.field("is_active", &self.is_active())
.field("from", &self.from)
.field("to", &self.to)
.finish()
}
}
impl From<Capture> for Range<usize> {
fn from(capture: Capture) -> Self {
capture.range()
}
}
impl Capture {
pub fn is_active(&self) -> bool {
self.flags == ffi::CH_CAPTURE_FLAG_ACTIVE
}
pub fn range(&self) -> Range<usize> {
self.from as usize..self.to as usize
}
}
/// Receiver for match events produced by `DatabaseRef::scan`.
///
/// Implemented for `()` (no callback), [`Matching`] (the same verdict for every
/// match), and closures `FnMut(id, from, to, flags, captures) -> Matching`.
pub trait MatchEventHandler<'a> {
/// Returns the C callback to register plus a pointer to `self`, which `scan` stores in
/// slot 0 of the context pair it passes to chimera.
///
/// # Safety
///
/// The returned pointer aliases `self`; it is only valid while `self` stays borrowed
/// and is not moved. The callback must only be invoked with a context pointing at the
/// `(match_data, error_data)` pair that `scan` builds.
unsafe fn split(&mut self) -> (ffi::ch_match_event_handler, *mut libc::c_void);
}
impl MatchEventHandler<'_> for () {
    /// No match callback: yields no handler and a null context slot.
    unsafe fn split(&mut self) -> (ffi::ch_match_event_handler, *mut libc::c_void) {
        let no_handler: ffi::ch_match_event_handler = None;
        let no_context: *mut libc::c_void = ptr::null_mut();
        (no_handler, no_context)
    }
}
/// A bare [`Matching`] value answers every match event with that same verdict.
impl MatchEventHandler<'_> for Matching {
unsafe fn split(&mut self) -> (ffi::ch_match_event_handler, *mut libc::c_void) {
// Ignores all event details; `ctx` is the `(match_data, error_data)` pair built by
// `scan`, whose slot 0 is the `*mut Matching` returned below.
// NOTE(review): reading that pair as `(&mut Matching, *mut ())` assumes it has the
// same layout as `scan`'s `(*mut c_void, *mut c_void)` tuple; Rust tuple layout is
// unspecified, so a `#[repr(C)]` struct would make this guaranteed.
unsafe extern "C" fn trampoline(
_id: u32,
_from: u64,
_to: u64,
_flags: u32,
_size: u32,
_captured: *const ffi::ch_capture_t,
ctx: *mut ::libc::c_void,
) -> ::libc::c_int {
*(*(ctx as *mut (&mut Matching, *mut ()))).0 as _
}
(Some(trampoline), self as *mut _ as *mut _)
}
}
impl<'a, F> MatchEventHandler<'a> for F
where
    F: FnMut(u32, u64, u64, u32, Option<&'a [Capture]>) -> Matching,
{
    /// Pairs the monomorphised `on_match_trampoline::<F>` with a pointer to this closure.
    unsafe fn split(&mut self) -> (ffi::ch_match_event_handler, *mut libc::c_void) {
        let closure: *mut F = self;
        let handler: ffi::ch_match_event_handler = Some(on_match_trampoline::<'a, F>);
        (handler, closure.cast())
    }
}
/// C-ABI match callback used for closure handlers.
///
/// `ctx` is the `(match_data, error_data)` pair built by `DatabaseRef::scan`; slot 0
/// holds the `*mut F` produced by `MatchEventHandler::split`. The capture array is
/// passed as `None` when it is null or `size == 0`, otherwise as a `&[Capture]` of
/// `size` elements. The closure's `Matching` verdict is returned as the C code.
///
/// # Safety
///
/// Must only be invoked by chimera with the context described above while the closure
/// it points to is alive.
///
/// NOTE(review): the `(&mut F, *mut ())` cast assumes the same layout as the
/// `(*mut c_void, *mut c_void)` tuple `scan` builds (Rust tuple layout is unspecified).
/// The capture slice also receives the caller-chosen lifetime `'a`, although the C
/// array is presumably only valid for the duration of this call — confirm.
unsafe extern "C" fn on_match_trampoline<'a, F>(
id: u32,
from: u64,
to: u64,
flags: u32,
size: u32,
captured: *const ffi::ch_capture_t,
ctx: *mut ::libc::c_void,
) -> ffi::ch_callback_t
where
F: FnMut(u32, u64, u64, u32, Option<&'a [Capture]>) -> Matching,
{
// Borrow the closure out of slot 0 of the context pair.
let &mut (ref mut callback, _) = &mut *(ctx as *mut (&mut F, *mut ()));
callback(
id,
from,
to,
flags,
if captured.is_null() || size == 0 {
None
} else {
// `Capture` is `repr(transparent)` over the C capture struct, so the element
// type can be reinterpreted in place.
Some(slice::from_raw_parts(captured as *const _, size as usize))
},
) as i32
}
/// Receiver for error events produced by `DatabaseRef::scan`.
///
/// Implemented for `()` (no callback), [`Matching`] (the same verdict for every
/// error), and closures `FnMut(Error, id) -> Matching`.
pub trait ErrorEventHandler {
/// Returns the C callback to register plus a pointer to `self`, which `scan` stores in
/// slot 1 of the context pair it passes to chimera.
///
/// # Safety
///
/// The returned pointer aliases `self`; it is only valid while `self` stays borrowed
/// and is not moved. The callback must only be invoked with a context pointing at the
/// `(match_data, error_data)` pair that `scan` builds.
unsafe fn split(&mut self) -> (ffi::ch_error_event_handler, *mut libc::c_void);
}
impl ErrorEventHandler for () {
    /// No error callback: yields no handler and a null context slot.
    unsafe fn split(&mut self) -> (ffi::ch_error_event_handler, *mut libc::c_void) {
        let no_handler: ffi::ch_error_event_handler = None;
        let no_context: *mut libc::c_void = ptr::null_mut();
        (no_handler, no_context)
    }
}
/// A bare [`Matching`] value answers every error event with that same verdict.
impl ErrorEventHandler for Matching {
unsafe fn split(&mut self) -> (ffi::ch_error_event_handler, *mut libc::c_void) {
// Ignores the error details; `ctx` is the `(match_data, error_data)` pair built by
// `scan`, whose slot 1 is the `*mut Matching` returned below.
// NOTE(review): relies on `(*mut (), &mut Matching)` matching the layout of `scan`'s
// `(*mut c_void, *mut c_void)` tuple; Rust tuple layout is unspecified.
unsafe extern "C" fn trampoline(
_error_type: ffi::ch_error_event_t,
_id: u32,
_info: *mut ::libc::c_void,
ctx: *mut ::libc::c_void,
) -> ffi::ch_callback_t {
*(*(ctx as *mut (*mut (), &mut Matching))).1 as _
}
(Some(trampoline), self as *mut _ as *mut _)
}
}
impl<F> ErrorEventHandler for F
where
    F: FnMut(Error, u32) -> Matching,
{
    /// Pairs the monomorphised `on_error_trampoline::<F>` with a pointer to this closure.
    unsafe fn split(&mut self) -> (ffi::ch_error_event_handler, *mut libc::c_void) {
        let closure: *mut F = self;
        let handler: ffi::ch_error_event_handler = Some(on_error_trampoline::<F>);
        (handler, closure.cast())
    }
}
/// C-ABI error callback used for closure handlers.
///
/// `ctx` is the `(match_data, error_data)` pair built by `DatabaseRef::scan`; slot 1
/// holds the `*mut F` produced by `ErrorEventHandler::split`. The closure's `Matching`
/// verdict is returned as the C code; `_info` is ignored.
///
/// # Safety
///
/// Must only be invoked by chimera with the context described above while the closure
/// it points to is alive.
///
/// NOTE(review): `mem::transmute(error_type)` is only sound if chimera never reports a
/// value other than `CH_ERROR_MATCHLIMIT` / `CH_ERROR_RECURSIONLIMIT`; any other code
/// would be an invalid `Error` discriminant (undefined behavior). A checked `match`
/// would be safer once a policy for unknown codes is chosen.
unsafe extern "C" fn on_error_trampoline<F>(
error_type: ffi::ch_error_event_t,
id: u32,
_info: *mut ::libc::c_void,
ctx: *mut ::libc::c_void,
) -> ffi::ch_callback_t
where
F: FnMut(Error, u32) -> Matching,
{
// Borrow the closure out of slot 1 of the context pair.
let &mut (_, ref mut callback) = &mut *(ctx as *mut (*mut (), &mut F));
callback(mem::transmute(error_type), id) as i32
}
impl DatabaseRef {
    /// Scans `data` against this database using `scratch`.
    ///
    /// Match events go to `on_match_event` and error events go to `on_error_event`.
    /// Either handler may be `()` (no callback), a fixed [`Matching`] verdict, or a
    /// closure (see [`MatchEventHandler`] and [`ErrorEventHandler`]).
    ///
    /// # Errors
    ///
    /// Returns `Err` when `ch_scan` reports a non-success status.
    ///
    /// # Panics
    ///
    /// Panics if `data` is longer than `u32::MAX` bytes. `ch_scan` takes a 32-bit
    /// length, and truncating it would silently scan only a prefix of `data`.
    pub fn scan<'a, T, F, E>(
        &self,
        data: T,
        scratch: &'a ScratchRef,
        mut on_match_event: F,
        mut on_error_event: E,
    ) -> Result<()>
    where
        T: AsRef<[u8]>,
        F: MatchEventHandler<'a>,
        E: ErrorEventHandler,
    {
        let data = data.as_ref();
        // Checked instead of `as`, which would wrap lengths >= 4 GiB.
        let len: u32 = std::convert::TryFrom::try_from(data.len())
            .expect("data is too long for ch_scan (at most u32::MAX bytes)");

        unsafe {
            let (on_match_callback, on_match_data) = on_match_event.split();
            let (on_error_callback, on_error_data) = on_error_event.split();
            // The trampolines read slot 0 (match) and slot 1 (error) of this pair. It and
            // both handlers live on this stack frame until `ch_scan` has returned.
            let mut userdata = (on_match_data, on_error_data);
            ffi::ch_scan(
                self.as_ptr(),
                data.as_ptr() as *const _,
                len as _,
                0,
                scratch.as_ptr(),
                on_match_callback,
                on_error_callback,
                &mut userdata as *mut _ as *mut _,
            )
            .ok()
        }
    }
}
#[cfg(test)]
pub mod tests {
    use std::ptr;

    use foreign_types::ForeignType;

    use crate::chimera::prelude::*;

    /// Lower bound the scratch sizes below are expected to exceed.
    const SCRATCH_SIZE: usize = 2000;

    #[test]
    fn test_scratch() {
        let db: Database = "test".parse().unwrap();
        let original = db.alloc_scratch().unwrap();
        assert!(original.size().unwrap() > SCRATCH_SIZE);

        // A clone must be a separate allocation with a comparable size.
        let mut cloned = original.clone();
        assert!(!ptr::eq(original.as_ptr(), cloned.as_ptr()));
        assert!(cloned.size().unwrap() > SCRATCH_SIZE);

        // After growing the clone for a second database it must still be separate
        // from the original and at least as large.
        let other_db: Database = "foobar".parse().unwrap();
        other_db.realloc_scratch(&mut cloned).unwrap();
        assert!(!ptr::eq(original.as_ptr(), cloned.as_ptr()));
        assert!(cloned.size().unwrap() >= original.size().unwrap());
    }
}