malwaredb-server 0.3.4

Server data storage logic for MalwareDB.
Documentation
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
use std::str::FromStr;

use fuzzyhash::FuzzyHash;
use malwaredb_lzjd::LZDict;
use rusqlite::functions::FunctionFlags;
use rusqlite::{Connection, Error, Result};
use tlsh_fixed::Tlsh;

/// Newtype around [`tlsh_fixed::TlshError`], used to surface TLSH parse failures
/// from the `tlsh_compare()` SQL function as a boxed `std::error::Error`.
#[derive(Debug)]
struct TlshError(tlsh_fixed::TlshError);

impl std::error::Error for TlshError {}

impl Display for TlshError {
    /// Renders the wrapped TLSH error with its own `Display` text.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        write!(out, "{inner}")
    }
}

/// Newtype around [`malwaredb_lzjd::LZJDError`], used to surface LZJD decoding
/// failures from the `lzjd_compare()` SQL function as a boxed `std::error::Error`.
#[derive(Debug)]
struct LZJDError(malwaredb_lzjd::LZJDError);

impl std::error::Error for LZJDError {}

impl Display for LZJDError {
    /// Renders the wrapped LZJD error using its `Debug` representation.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        write!(out, "{inner:?}")
    }
}

const NUM_ARGS: usize = 2;

/// Create the `lzjd_compare()` and`tlsh_compare()` functions in `SQLite`
pub(crate) fn add_similarity_functions(db: &Connection) -> Result<()> {
    db.create_scalar_function(
        "lzjd_compare",
        2, /*NUM_ARGS*/
        FunctionFlags::SQLITE_UTF8,
        move |ctx| {
            if !ctx.len() == NUM_ARGS {
                return Err(Error::InvalidParameterCount(ctx.len(), NUM_ARGS));
            }

            let lzjd1 = ctx.get_raw(0).as_str().unwrap_or_default();
            let lzjd1 = match LZDict::from_base64_string(lzjd1) {
                Ok(l) => l,
                Err(e) => {
                    return Err(Error::UserFunctionError(Box::new(LZJDError(e))));
                }
            };

            let lzjd2 = ctx.get_raw(1).as_str().unwrap_or_default();
            let lzjd2 = match LZDict::from_base64_string(lzjd2) {
                Ok(l) => l,
                Err(e) => {
                    return Err(Error::UserFunctionError(Box::new(LZJDError(e))));
                }
            };

            Ok(lzjd1.similarity(&lzjd2))
        },
    )?;

    db.create_scalar_function(
        "fuzzy_hash_compare",
        2, /*NUM_ARGS*/
        FunctionFlags::SQLITE_UTF8,
        move |ctx| {
            if !ctx.len() == NUM_ARGS {
                return Err(Error::InvalidParameterCount(ctx.len(), NUM_ARGS));
            }

            let ssdeep1 = ctx.get_raw(0).as_str().unwrap_or_default();
            let ssdeep2 = ctx.get_raw(1).as_str().unwrap_or_default();
            let similarity = FuzzyHash::compare(ssdeep1, ssdeep2)
                .map_err(|e| Error::UserFunctionError(Box::new(e)))?;

            Ok(similarity)
        },
    )?;

    db.create_scalar_function(
        "tlsh_compare",
        2, /*NUM_ARGS*/
        FunctionFlags::SQLITE_UTF8,
        move |ctx| {
            if !ctx.len() == NUM_ARGS {
                return Err(Error::InvalidParameterCount(ctx.len(), NUM_ARGS));
            }

            let tlsh1 = ctx.get_raw(0).as_str().unwrap_or_default();
            let tlsh1 = match Tlsh::from_str(tlsh1) {
                Ok(t) => t,
                Err(e) => {
                    return Err(Error::UserFunctionError(Box::new(TlshError(e))));
                }
            };

            let tlsh2 = ctx.get_raw(1).as_str().unwrap_or_default();
            let tlsh2 = match Tlsh::from_str(tlsh2) {
                Ok(t) => t,
                Err(e) => {
                    return Err(Error::UserFunctionError(Box::new(TlshError(e))));
                }
            };

            Ok(tlsh1.diff(&tlsh2, true))
        },
    )
}