#![deny(missing_docs, clippy::pedantic)]
#![allow(clippy::cast_sign_loss)]
use libsqlite3_sys as ffi;
use log::error;
use regex::Regex;
use std::sync::Arc;
/// Name of the SQL function, stored with a trailing NUL byte so that `register` can pass it
/// to SQLite as a C string.
static FN_NAME: &[u8] = b"regexp\0";
/// Registers the two-argument `regexp` SQL function on the given connection handle.
///
/// The function is declared with UTF-8 text encoding and as deterministic, and is backed by
/// `sqlite3_regexp_func`. No user data, aggregate callbacks or destructor are installed.
///
/// Returns the result code of `sqlite3_create_function_v2` unchanged.
///
/// Note: `sqlite3` is handed to SQLite without any validation here.
/// NOTE(review): callers presumably must pass a live connection handle — confirm at call sites.
pub fn register(sqlite3: *mut ffi::sqlite3) -> i32 {
    let name = FN_NAME.as_ptr().cast::<std::ffi::c_char>();
    let flags = ffi::SQLITE_UTF8 | ffi::SQLITE_DETERMINISTIC;
    let user_data = std::ptr::null_mut::<std::ffi::c_void>();

    // SAFETY: `name` points at a static, NUL-terminated byte string; the validity of
    // `sqlite3` itself is the caller's responsibility.
    unsafe {
        ffi::sqlite3_create_function_v2(
            sqlite3,
            name,
            2,
            flags,
            user_data,
            Some(sqlite3_regexp_func),
            None,
            None,
            None,
        )
    }
}
unsafe extern "C" fn sqlite3_regexp_func(
ctx: *mut ffi::sqlite3_context,
n_arg: i32,
args: *mut *mut ffi::sqlite3_value,
) {
if n_arg != 2 {
eprintln!("n_arg expected to be 2, is {n_arg}");
ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION);
return;
}
let Some(regex) = get_regex_from_arg(ctx, *args.offset(0), 0) else {
return;
};
let Some(value) = get_text_from_arg(ctx, *args.offset(1)) else {
return;
};
if regex.is_match(value) {
ffi::sqlite3_result_int(ctx, 1);
} else {
ffi::sqlite3_result_int(ctx, 0);
}
}
unsafe fn get_regex_from_arg(
ctx: *mut ffi::sqlite3_context,
arg: *mut ffi::sqlite3_value,
index: i32,
) -> Option<Arc<Regex>> {
let ptr = ffi::sqlite3_get_auxdata(ctx, index);
if !ptr.is_null() {
let ptr = ptr as *const Regex;
Arc::increment_strong_count(ptr);
return Some(Arc::from_raw(ptr));
}
let value = get_text_from_arg(ctx, arg)?;
let regex = match Regex::new(value) {
Ok(regex) => Arc::new(regex),
Err(e) => {
error!("Invalid regex {value:?}: {e:?}");
ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION);
return None;
}
};
ffi::sqlite3_set_auxdata(
ctx,
index,
Arc::into_raw(Arc::clone(®ex)) as *mut _,
Some(cleanup_arc_regex_pointer),
);
Some(regex)
}
unsafe fn get_text_from_arg<'a>(
ctx: *mut ffi::sqlite3_context,
arg: *mut ffi::sqlite3_value,
) -> Option<&'a str> {
let ty = ffi::sqlite3_value_type(arg);
if ty == ffi::SQLITE_TEXT {
let ptr = ffi::sqlite3_value_text(arg);
let len = ffi::sqlite3_value_bytes(arg);
let slice = std::slice::from_raw_parts(ptr.cast(), len as usize);
match std::str::from_utf8(slice) {
Ok(result) => Some(result),
Err(e) => {
log::error!("Incoming text is not valid UTF8: {e:?}");
ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION);
None
}
}
} else {
None
}
}
/// Destructor registered alongside the auxiliary data: releases the one strong reference
/// that was stored there.
///
/// # Safety
///
/// `ptr` must have been produced by `Arc::<Regex>::into_raw`, and the strong reference it
/// represents must not have been released already.
unsafe extern "C" fn cleanup_arc_regex_pointer(ptr: *mut std::ffi::c_void) {
    let regex = ptr.cast::<Regex>().cast_const();
    // Rebuilding the `Arc` and dropping it gives up exactly one strong count.
    drop(Arc::from_raw(regex));
}
#[cfg(test)]
mod tests {
    use sqlx::{ConnectOptions, Row};
    use std::str::FromStr;

    /// Opens an in-memory database with the regexp function enabled and fills a `test`
    /// table with the rows `value 0` through `value 9`.
    async fn test_db() -> crate::SqliteConnection {
        let options = crate::SqliteConnectOptions::from_str("sqlite://:memory:")
            .unwrap()
            .with_regexp();
        let mut db = options.connect().await.unwrap();

        sqlx::query("CREATE TABLE test (col TEXT NOT NULL)")
            .execute(&mut db)
            .await
            .unwrap();

        for i in 0..10 {
            let text = format!("value {i}");
            sqlx::query("INSERT INTO test VALUES (?)")
                .bind(text)
                .execute(&mut db)
                .await
                .unwrap();
        }

        db
    }

    /// A pattern that matches no row runs successfully and yields nothing.
    #[sqlx::test]
    async fn test_regexp_does_not_fail() {
        let mut db = test_db().await;

        let rows = sqlx::query("SELECT col FROM test WHERE col REGEXP 'foo.*bar'")
            .fetch_all(&mut db)
            .await
            .expect("Could not execute query");

        assert!(rows.is_empty());
    }

    /// Only rows matching the pattern are returned; an anchored pattern that fits no row
    /// returns nothing.
    #[sqlx::test]
    async fn test_regexp_filters_correctly() {
        let mut db = test_db().await;

        let ending_in_two = sqlx::query("SELECT col FROM test WHERE col REGEXP '.*2'")
            .fetch_all(&mut db)
            .await
            .expect("Could not execute query");
        assert_eq!(ending_in_two.len(), 1);
        assert_eq!(
            ending_in_two[0].get::<String, usize>(0),
            String::from("value 2")
        );

        let starting_with_three = sqlx::query("SELECT col FROM test WHERE col REGEXP '^3'")
            .fetch_all(&mut db)
            .await
            .expect("Could not execute query");
        assert!(starting_with_three.is_empty());
    }

    /// A pattern that does not compile surfaces as a database error.
    #[sqlx::test]
    async fn test_invalid_regexp_should_fail() {
        let mut db = test_db().await;

        let outcome = sqlx::query("SELECT col from test WHERE col REGEXP '(?:?)'")
            .execute(&mut db)
            .await;

        assert!(matches!(outcome, Err(sqlx::Error::Database(_))));
    }
}