use std::marker::PhantomData;
use std::os::raw::c_int;
use thiserror::Error;
use tre_sys::{
compile as compile_raw, default_regaparams, safe_reganexec, safe_regerror, ExecOutcome,
OwnedRegex, REG_EXTENDED, REG_ICASE,
};
#[derive(Debug, Error)]
pub enum TreError {
#[error("tre compile failed: {0}")]
Compile(String),
#[error("invalid pattern options: {0}")]
InvalidOptions(String),
#[error("internal tre error: {0}")]
Internal(String),
}
/// Tuning knobs for approximate matching, stored on a compiled pattern and applied to
/// every match it performs.
///
/// The `Default` impl requests exact matching (`max_errors == 0`) with unit edit costs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TreMatchOpts {
    /// Maximum number of edits allowed in a match. It also caps each edit kind
    /// (insertions, deletions, substitutions) individually. `0` means exact matching.
    pub max_errors: u16,
    /// Cost charged per inserted character.
    pub cost_ins: u16,
    /// Cost charged per deleted character.
    pub cost_del: u16,
    /// Cost charged per substituted character.
    pub cost_subst: u16,
    /// Upper bound on the total edit cost. `0` derives the budget as
    /// `max_errors * max(cost_ins, cost_del, cost_subst, 1)`.
    pub max_cost: u16,
    /// Compile the pattern with `REG_ICASE` (case-insensitive matching).
    pub case_insensitive: bool,
}
impl Default for TreMatchOpts {
fn default() -> Self {
Self {
max_errors: 0,
cost_ins: 1,
cost_del: 1,
cost_subst: 1,
max_cost: 0,
case_insensitive: false,
}
}
}
/// Outcome of a successful approximate match, with the edit statistics reported by TRE.
///
/// The counters are converted to `u32`, saturating at `u32::MAX` if TRE's value does not fit.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct TreMatch {
    /// Start byte offset of the reported span.
    /// NOTE(review): `TreCompiledPattern::matches` currently always reports `0` here,
    /// so this does not locate the matched region.
    pub start: usize,
    /// End byte offset (exclusive) of the reported span.
    /// NOTE(review): currently always the length of the searched text.
    pub end: usize,
    /// Total edit cost of the match.
    pub cost: u32,
    /// Number of insertions used by the match.
    pub n_ins: u32,
    /// Number of deletions used by the match.
    pub n_del: u32,
    /// Number of substitutions used by the match.
    pub n_subst: u32,
}
/// A compiled TRE pattern bundled with the match options it was compiled with.
///
/// The `PhantomData<*const ()>` marker opts the type out of the `Send` and `Sync` auto
/// traits, so a compiled pattern cannot be moved to or shared with another thread.
pub struct TreCompiledPattern {
    /// The compiled TRE regex.
    raw: OwnedRegex,
    /// Options that were validated at compile time and are turned into TRE parameters
    /// for each match.
    opts: TreMatchOpts,
    /// Raw-pointer marker that removes the `Send`/`Sync` impls.
    _not_send_or_sync: PhantomData<*const ()>,
}

impl std::fmt::Debug for TreCompiledPattern {
    /// Prints the `nsub` count reported by the compiled regex and the stored options.
    /// The compiled regex itself is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let nsub = self.raw.nsub();
        let mut out = f.debug_struct("TreCompiledPattern");
        out.field("nsub", &nsub);
        out.field("opts", &self.opts);
        out.finish()
    }
}
impl TreCompiledPattern {
pub fn compile(pattern: &[u8], opts: TreMatchOpts) -> Result<Self, TreError> {
validate_opts(opts)?;
let mut cflags = REG_EXTENDED;
if opts.case_insensitive {
cflags |= REG_ICASE;
}
match compile_raw(pattern, cflags) {
Ok(raw) => Ok(Self {
raw,
opts,
_not_send_or_sync: PhantomData,
}),
Err(rc) => Err(TreError::Compile(safe_regerror(rc, None))),
}
}
#[must_use]
pub fn is_match(&self, text: &[u8]) -> bool {
self.matches(text).is_some()
}
#[must_use]
pub fn matches(&self, text: &[u8]) -> Option<TreMatch> {
let params = self.build_params();
match safe_reganexec(self.raw.as_raw(), text, params, 0) {
ExecOutcome::Match(m) => Some(TreMatch {
start: 0,
end: text.len(),
cost: u32::try_from(m.cost).unwrap_or(u32::MAX),
n_ins: u32::try_from(m.num_ins).unwrap_or(u32::MAX),
n_del: u32::try_from(m.num_del).unwrap_or(u32::MAX),
n_subst: u32::try_from(m.num_subst).unwrap_or(u32::MAX),
}),
ExecOutcome::NoMatch => None,
ExecOutcome::Error(other) => {
debug_assert!(false, "tre_reganexec returned unexpected code {other}");
None
}
}
}
#[must_use]
pub fn opts(&self) -> TreMatchOpts {
self.opts
}
fn build_params(&self) -> tre_sys::regaparams_t {
let mut params = default_regaparams();
params.cost_ins = c_int::from(self.opts.cost_ins);
params.cost_del = c_int::from(self.opts.cost_del);
params.cost_subst = c_int::from(self.opts.cost_subst);
params.max_err = c_int::from(self.opts.max_errors);
params.max_ins = c_int::from(self.opts.max_errors);
params.max_del = c_int::from(self.opts.max_errors);
params.max_subst = c_int::from(self.opts.max_errors);
let max_cost = if self.opts.max_cost == 0 {
let largest = self
.opts
.cost_ins
.max(self.opts.cost_del)
.max(self.opts.cost_subst);
let budget = u32::from(self.opts.max_errors) * u32::from(largest.max(1));
c_int::try_from(budget).unwrap_or(c_int::MAX)
} else {
c_int::from(self.opts.max_cost)
};
params.max_cost = max_cost;
params
}
}
/// Checks `opts` for inconsistent combinations before a pattern is compiled.
///
/// # Errors
///
/// Returns [`TreError::InvalidOptions`] when `max_errors > 0` but every edit cost
/// (`cost_ins`, `cost_del`, `cost_subst`) is zero.
fn validate_opts(opts: TreMatchOpts) -> Result<(), TreError> {
    let every_cost_zero = [opts.cost_ins, opts.cost_del, opts.cost_subst]
        .iter()
        .all(|&cost| cost == 0);
    if opts.max_errors == 0 || !every_cost_zero {
        return Ok(());
    }
    Err(TreError::InvalidOptions(
        "max_errors > 0 requires at least one non-zero edit cost".into(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_opts_are_zero_error_unit_costs() {
        let opts = TreMatchOpts::default();
        assert_eq!(opts.max_errors, 0);
        assert_eq!(opts.cost_ins, 1);
        assert_eq!(opts.cost_del, 1);
        assert_eq!(opts.cost_subst, 1);
        assert_eq!(opts.max_cost, 0);
        assert!(!opts.case_insensitive);
    }

    #[test]
    fn validate_opts_rejects_max_errors_with_zero_costs() {
        let bad = TreMatchOpts {
            max_errors: 1,
            cost_ins: 0,
            cost_del: 0,
            cost_subst: 0,
            ..TreMatchOpts::default()
        };
        assert!(matches!(
            validate_opts(bad),
            Err(TreError::InvalidOptions(_))
        ));
    }

    #[test]
    fn validate_opts_accepts_consistent_options() {
        assert!(validate_opts(TreMatchOpts::default()).is_ok());
        // Zero costs are accepted when no errors are allowed.
        let exact_zero_costs = TreMatchOpts {
            max_errors: 0,
            cost_ins: 0,
            cost_del: 0,
            cost_subst: 0,
            ..TreMatchOpts::default()
        };
        assert!(validate_opts(exact_zero_costs).is_ok());
        // A single non-zero cost is enough.
        let one_cost = TreMatchOpts {
            max_errors: 2,
            cost_ins: 0,
            cost_del: 0,
            cost_subst: 3,
            ..TreMatchOpts::default()
        };
        assert!(validate_opts(one_cost).is_ok());
    }

    #[test]
    fn approximate_matching_respects_error_budget() {
        let exact = TreCompiledPattern::compile(b"abc", TreMatchOpts::default())
            .expect("`abc` is a valid extended regex");
        assert!(exact.is_match(b"xxabcxx"));
        assert!(!exact.is_match(b"abd"));

        let fuzzy = TreCompiledPattern::compile(
            b"abc",
            TreMatchOpts {
                max_errors: 1,
                ..TreMatchOpts::default()
            },
        )
        .expect("`abc` is a valid extended regex");
        assert!(fuzzy.is_match(b"abd"));
    }

    #[test]
    fn compiled_pattern_phantom_is_not_send_or_sync() {
        // Compile-time negative trait assertions (the technique behind
        // `static_assertions::assert_not_impl_any!`). The first impl applies to every type.
        // The second applies only if the type is Send (resp. Sync). If both applied, the
        // `_` in the path below would be ambiguous and the crate would fail to compile.
        trait AmbiguousIfSend<A> {
            fn some_item() {}
        }
        impl<T: ?Sized> AmbiguousIfSend<()> for T {}
        impl<T: ?Sized + Send> AmbiguousIfSend<u8> for T {}
        let _ = <TreCompiledPattern as AmbiguousIfSend<_>>::some_item;

        trait AmbiguousIfSync<A> {
            fn some_item() {}
        }
        impl<T: ?Sized> AmbiguousIfSync<()> for T {}
        impl<T: ?Sized + Sync> AmbiguousIfSync<u8> for T {}
        let _ = <TreCompiledPattern as AmbiguousIfSync<_>>::some_item;
    }
}