use std::fmt::Display;
use std::str::FromStr;
use fuzzyhash::FuzzyHash;
use malwaredb_lzjd::LZDict;
use rusqlite::functions::FunctionFlags;
use rusqlite::{Connection, Error, Result};
use tlsh_fixed::Tlsh;
/// Newtype wrapper that lets a [`tlsh_fixed::TlshError`] be boxed as a
/// `std::error::Error`, e.g. inside `rusqlite::Error::UserFunctionError`.
#[derive(Debug)]
struct TlshError(tlsh_fixed::TlshError);

impl Display for TlshError {
    /// Writes the wrapped error using its own `Display` output.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        write!(out, "{}", inner)
    }
}

// `Debug` is derived and `Display` is implemented above, so the default
// `Error` trait methods are sufficient.
impl std::error::Error for TlshError {}
/// Newtype wrapper that lets a [`malwaredb_lzjd::LZJDError`] be boxed as a
/// `std::error::Error`, e.g. inside `rusqlite::Error::UserFunctionError`.
#[derive(Debug)]
struct LZJDError(malwaredb_lzjd::LZJDError);

impl Display for LZJDError {
    /// Writes the wrapped error using its `Debug` (`{:?}`) representation.
    // NOTE(review): presumably Debug is used because the inner type has no
    // Display impl — confirm against the malwaredb_lzjd crate.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        write!(out, "{:?}", inner)
    }
}

// `Debug` is derived and `Display` is implemented above, so the default
// `Error` trait methods are sufficient.
impl std::error::Error for LZJDError {}
const NUM_ARGS: usize = 2;
pub(crate) fn add_similarity_functions(db: &Connection) -> Result<()> {
db.create_scalar_function(
"lzjd_compare",
2,
FunctionFlags::SQLITE_UTF8,
move |ctx| {
if !ctx.len() == NUM_ARGS {
return Err(Error::InvalidParameterCount(ctx.len(), NUM_ARGS));
}
let lzjd1 = ctx.get_raw(0).as_str().unwrap_or_default();
let lzjd1 = match LZDict::from_base64_string(lzjd1) {
Ok(l) => l,
Err(e) => {
return Err(Error::UserFunctionError(Box::new(LZJDError(e))));
}
};
let lzjd2 = ctx.get_raw(1).as_str().unwrap_or_default();
let lzjd2 = match LZDict::from_base64_string(lzjd2) {
Ok(l) => l,
Err(e) => {
return Err(Error::UserFunctionError(Box::new(LZJDError(e))));
}
};
Ok(lzjd1.similarity(&lzjd2))
},
)?;
db.create_scalar_function(
"fuzzy_hash_compare",
2,
FunctionFlags::SQLITE_UTF8,
move |ctx| {
if !ctx.len() == NUM_ARGS {
return Err(Error::InvalidParameterCount(ctx.len(), NUM_ARGS));
}
let ssdeep1 = ctx.get_raw(0).as_str().unwrap_or_default();
let ssdeep2 = ctx.get_raw(1).as_str().unwrap_or_default();
let similarity = FuzzyHash::compare(ssdeep1, ssdeep2)
.map_err(|e| Error::UserFunctionError(Box::new(e)))?;
Ok(similarity)
},
)?;
db.create_scalar_function(
"tlsh_compare",
2,
FunctionFlags::SQLITE_UTF8,
move |ctx| {
if !ctx.len() == NUM_ARGS {
return Err(Error::InvalidParameterCount(ctx.len(), NUM_ARGS));
}
let tlsh1 = ctx.get_raw(0).as_str().unwrap_or_default();
let tlsh1 = match Tlsh::from_str(tlsh1) {
Ok(t) => t,
Err(e) => {
return Err(Error::UserFunctionError(Box::new(TlshError(e))));
}
};
let tlsh2 = ctx.get_raw(1).as_str().unwrap_or_default();
let tlsh2 = match Tlsh::from_str(tlsh2) {
Ok(t) => t,
Err(e) => {
return Err(Error::UserFunctionError(Box::new(TlshError(e))));
}
};
Ok(tlsh1.diff(&tlsh2, true))
},
)
}