#[cfg(any(test, feature = "admin"))]
mod admin;
mod pg;
#[cfg(any(test, feature = "sqlite"))]
mod sqlite;
#[cfg(any(test, feature = "sqlite"))]
mod sqlite_functions;
pub mod types;
#[cfg(any(test, feature = "admin"))]
use crate::crypto::EncryptionOption;
use crate::crypto::FileEncryption;
use crate::db::pg::Postgres;
#[cfg(any(test, feature = "sqlite"))]
use crate::db::sqlite::Sqlite;
use crate::db::types::{FileMetadata, FileType};
use malwaredb_api::{
digest::HashType, GetUserInfoResponse, Labels, SearchRequest, SearchResponse, Sources,
};
use malwaredb_types::KnownType;
use std::collections::HashMap;
use std::path::PathBuf;
use anyhow::{bail, ensure, Result};
use argon2::password_hash::{rand_core::OsRng, SaltString};
use argon2::{Argon2, PasswordHasher};
#[cfg(any(test, feature = "admin"))]
use chrono::Local;
#[cfg(feature = "vt")]
use malwaredb_virustotal::filereport::ScanResultAttributes;
/// Result-count limit used for partial searches; presumably consumed by the
/// backends' `partial_search` implementations — TODO confirm at use sites.
pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
/// Schema-migration action passed to the backend `migrate` implementations.
#[derive(Copy, Clone)]
pub enum Migration {
    // Verify the schema state without changing it.
    Check,
    // Apply pending migrations (admin/test builds only).
    #[cfg(any(test, feature = "admin"))]
    Migrate,
}
/// The supported database backends; every public operation on this type
/// dispatches to the wrapped backend.
#[derive(Debug)]
pub enum DatabaseType {
    // Always available.
    Postgres(Postgres),
    // Compiled in for tests or when the `sqlite` feature is enabled.
    #[cfg(any(test, feature = "sqlite"))]
    SQLite(Sqlite),
}
/// Summary statistics about a database instance, as reported by `db_info`.
#[derive(Debug)]
pub struct DatabaseInformation {
    /// Backend version string.
    pub version: String,
    /// Human-readable database size.
    pub size: String,
    /// Number of stored files.
    pub num_files: u64,
    /// Number of user accounts.
    pub num_users: u32,
    /// Number of groups.
    pub num_groups: u32,
    /// Number of file sources.
    pub num_sources: u32,
}
/// Outcome of an `add_file` call: the file's database id and whether the
/// record was newly created (as opposed to matching an existing file).
//
// Derives added for consistency with the other public types in this module
// (`DatabaseType`, `DatabaseInformation`, `MDBConfig`, `VtStats`), all of
// which are `Debug`; both fields are `Copy`, so `Clone`/`Copy` are free.
#[derive(Debug, Clone, Copy)]
pub struct FileAddedResult {
    /// Database id of the file record.
    pub file_id: u64,
    /// `true` when the file was not already present.
    pub is_new: bool,
}
/// Per-instance configuration, loaded from the database via `get_config`.
#[derive(Debug)]
pub struct MDBConfig {
    /// Instance display name.
    pub name: String,
    /// Whether stored samples are compressed.
    pub compression: bool,
    /// Whether samples may be uploaded to VirusTotal.
    pub send_samples_to_vt: bool,
    /// Whether files of unrecognized type are retained.
    pub keep_unknown_files: bool,
    // Encryption key id used for new files, if encryption is configured.
    pub(crate) default_key: Option<u32>,
}
/// Counts describing VirusTotal record coverage for the stored files.
#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Files whose VT record reports no detections.
    pub clean_records: u32,
    /// Files whose VT record reports detections.
    pub hits_records: u32,
    /// Files with no VT record at all.
    pub files_without_records: u32,
}
impl DatabaseType {
    /// Connect using the connection string `arg` and verify (without applying)
    /// that the schema is up to date.
    pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
        let db = Self::init_from_string(arg, server_ca).await?;
        db.migrate_check(Migration::Check).await?;
        Ok(db)
    }
    /// Connect using the connection string `arg` and apply pending schema
    /// migrations.
    #[cfg(feature = "admin")]
    pub async fn migrate(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
        let db = Self::init_from_string(arg, server_ca).await?;
        db.migrate_check(Migration::Migrate).await?;
        Ok(db)
    }
    /// Select a backend from the connection string prefix: `file:` opens
    /// SQLite, a `postgres` prefix opens Postgres; anything else is an error.
    async fn init_from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
        #[cfg(any(test, feature = "sqlite"))]
        if arg.starts_with("file:") {
            let new_conn_str = arg.trim_start_matches("file:");
            let db = DatabaseType::SQLite(Sqlite::new(new_conn_str)?);
            return Ok(db);
        }
        if arg.starts_with("postgres") {
            // NOTE(review): only the literal "postgres" prefix is stripped; the
            // remainder (e.g. "://host/...") is passed to Postgres::new as-is —
            // confirm Postgres::new expects that form.
            let new_conn_str = arg.trim_start_matches("postgres");
            let db = DatabaseType::Postgres(Postgres::new(new_conn_str, server_ca).await?);
            return Ok(db);
        }
        bail!("unknown database type `{arg}`")
    }
    /// Enable uploading samples to VirusTotal.
    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
    #[cfg(feature = "vt")]
    pub async fn enable_vt_upload(&self) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
        }
    }
    /// Disable uploading samples to VirusTotal.
    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
    #[cfg(feature = "vt")]
    pub async fn disable_vt_upload(&self) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
        }
    }
    /// List up to `limit` file hashes that have no VirusTotal record yet.
    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
    #[cfg(feature = "vt")]
    pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
        match self {
            DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
        }
    }
    /// Store a VirusTotal scan report.
    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
    #[cfg(feature = "vt")]
    pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.store_vt_record(results),
        }
    }
    /// Fetch VirusTotal record-coverage statistics.
    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
    #[cfg(feature = "vt")]
    pub async fn get_vt_stats(&self) -> Result<VtStats> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_vt_stats(),
        }
    }
    /// Register a new YARA search task for user `uid`; returns the task id.
    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
    #[cfg(feature = "yara")]
    pub async fn add_yara_search(
        &self,
        uid: u32,
        yara_string: &str,
        yara_bytes: &[u8],
    ) -> Result<uuid::Uuid> {
        match self {
            DatabaseType::Postgres(pg) => pg.add_yara_search(uid, yara_string, yara_bytes).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.add_yara_search(uid, yara_string, yara_bytes),
        }
    }
    /// List YARA search tasks that have not finished yet.
    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
    #[cfg(feature = "yara")]
    pub async fn get_unfinished_yara_tasks(&self) -> Result<Vec<crate::yara::YaraTask>> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_unfinished_yara_tasks().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_unfinished_yara_tasks(),
        }
    }
    /// Record a YARA rule match for task `id` against the given file.
    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
    #[cfg(feature = "yara")]
    pub async fn add_yara_match(
        &self,
        id: uuid::Uuid,
        rule_name: &str,
        file_sha256: &str,
    ) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.add_yara_match(id, rule_name, file_sha256).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.add_yara_match(id, rule_name, file_sha256),
        }
    }
    /// Mark YARA search task `id` as completed.
    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
    #[cfg(feature = "yara")]
    pub async fn mark_yara_task_as_finished(&self, id: uuid::Uuid) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.mark_yara_task_as_finished(id).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.mark_yara_task_as_finished(id),
        }
    }
    /// Record task `id`'s progress as the next file id to scan.
    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
    #[cfg(feature = "yara")]
    pub async fn yara_add_next_file_id(&self, id: uuid::Uuid, file_id: u64) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.yara_add_next_file_id(id, file_id).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.yara_add_next_file_id(id, file_id),
        }
    }
    /// Fetch the results of YARA search `id`, scoped to what `user_id` may see.
    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
    #[cfg(feature = "yara")]
    pub async fn get_yara_results(
        &self,
        id: uuid::Uuid,
        user_id: u32,
    ) -> Result<malwaredb_api::YaraSearchResponse> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_yara_results(id, user_id).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_yara_results(id, user_id),
        }
    }
    /// Load the instance configuration stored in the database.
    pub async fn get_config(&self) -> Result<MDBConfig> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_config().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_config(),
        }
    }
    /// Authenticate a user by name and password; returns the user's API key.
    pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
        match self {
            DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
        }
    }
    /// Look up the user id owning `apikey`; rejects empty keys up front.
    pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
        ensure!(!apikey.is_empty(), "API key was empty");
        match self {
            DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_uid(apikey),
        }
    }
    /// Fetch summary statistics about the database instance.
    pub async fn db_info(&self) -> Result<DatabaseInformation> {
        match self {
            DatabaseType::Postgres(pg) => pg.db_info().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.db_info(),
        }
    }
    /// Fetch account details for user `uid`.
    pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_user_info(uid),
        }
    }
    /// List the file sources user `uid` may access.
    pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
        }
    }
    /// Regenerate the calling user's own API key.
    pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
        }
    }
    /// List the file types known to the database.
    pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_known_data_types(),
        }
    }
    /// Fetch all labels defined in the database.
    pub async fn get_labels(&self) -> Result<Labels> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_labels().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_labels(),
        }
    }
    /// Determine the file-type id for the given raw bytes.
    pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
        }
    }
    /// Check whether user `uid` may access source `sid`.
    pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
        match self {
            DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
        }
    }
    /// Check whether user `uid` has administrator rights.
    pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
        match self {
            DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
        }
    }
    /// Store a file (metadata plus parsed type info) for user `uid` under
    /// source `sid`; `parent` links derived/extracted files to their origin.
    pub async fn add_file(
        &self,
        meta: &FileMetadata,
        known_type: KnownType<'_>,
        uid: u32,
        sid: u32,
        ftype: u32,
        parent: Option<u64>,
    ) -> Result<FileAddedResult> {
        match self {
            DatabaseType::Postgres(pg) => {
                pg.add_file(meta, known_type, uid, sid, ftype, parent).await
            }
            // NOTE(review): the SQLite backend takes `&known_type` while Postgres
            // consumes it by value — confirm the differing signatures are intended.
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
        }
    }
    /// Run a partial (e.g. truncated-hash) search on behalf of user `uid`.
    pub async fn partial_search(&self, uid: u32, search: SearchRequest) -> Result<SearchResponse> {
        match self {
            DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
        }
    }
    /// Run backend cleanup; returns a count reported by the backend.
    pub async fn cleanup(&self) -> Result<u64> {
        match self {
            DatabaseType::Postgres(pg) => pg.cleanup().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.cleanup(),
        }
    }
    /// Retrieve a sample by hash on behalf of user `uid`.
    pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
        match self {
            DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
        }
    }
    /// Fetch the full report for a sample identified by `hash`.
    pub async fn get_sample_report(
        &self,
        uid: u32,
        hash: &HashType,
    ) -> Result<malwaredb_api::Report> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
        }
    }
    /// Find stored samples similar to the provided similarity hashes.
    pub async fn find_similar_samples(
        &self,
        uid: u32,
        sim: &[(malwaredb_api::SimilarityHashType, String)],
    ) -> Result<Vec<malwaredb_api::SimilarSample>> {
        match self {
            DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
        }
    }
    /// Page through SHA-256 hashes of files visible to `uid`; `next` is the
    /// pagination cursor and the returned `u64` is the next cursor.
    pub async fn user_allowed_files_by_sha256(
        &self,
        uid: u32,
        next: Option<u64>,
    ) -> Result<(Vec<String>, u64)> {
        match self {
            DatabaseType::Postgres(pg) => pg.user_allowed_files_by_sha256(uid, next).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.user_allowed_files_by_sha256(uid, next),
        }
    }
    /// Load all file-encryption keys, keyed by their database id.
    pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
        }
    }
    /// Fetch a file's encryption key id and nonce (either may be absent).
    pub(crate) async fn get_file_encryption_key_id(
        &self,
        hash: &str,
    ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
        }
    }
    /// Store (or clear, with `None`) the encryption nonce for a file.
    pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
        }
    }
    /// Check or apply schema migrations, per `action`.
    pub async fn migrate_check(&self, action: Migration) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.migrate(action).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.migrate(action),
        }
    }
    /// Set the instance display name (admin/test builds only).
    #[cfg(any(test, feature = "admin"))]
    pub async fn set_name(&self, name: &str) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.set_name(name).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.set_name(name),
        }
    }
    /// Enable compression of stored samples.
    #[cfg(any(test, feature = "admin"))]
    pub async fn enable_compression(&self) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.enable_compression().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.enable_compression(),
        }
    }
    /// Disable compression of stored samples.
    #[cfg(any(test, feature = "admin"))]
    pub async fn disable_compression(&self) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.disable_compression().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.disable_compression(),
        }
    }
    /// Enable retention of files whose type is not recognized.
    #[cfg(any(test, feature = "admin"))]
    pub async fn enable_keep_unknown_files(&self) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
        }
    }
    /// Disable retention of files whose type is not recognized.
    #[cfg(any(test, feature = "admin"))]
    pub async fn disable_keep_unknown_files(&self) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
        }
    }
    /// Store a file-encryption key; returns its database id.
    #[cfg(any(test, feature = "admin"))]
    pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
        match self {
            DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
        }
    }
    /// List encryption keys as (id, description) pairs.
    #[cfg(any(test, feature = "admin"))]
    pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
        }
    }
    /// Create a user account; `password: None` creates a passwordless account.
    /// Returns the new user's id.
    #[allow(clippy::too_many_arguments)]
    #[cfg(any(test, feature = "admin"))]
    pub async fn create_user(
        &self,
        uname: &str,
        fname: &str,
        lname: &str,
        email: &str,
        password: Option<String>,
        organisation: Option<&String>,
        readonly: bool,
    ) -> Result<u32> {
        match self {
            DatabaseType::Postgres(pg) => {
                pg.create_user(uname, fname, lname, email, password, organisation, readonly)
                    .await
            }
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => {
                sl.create_user(uname, fname, lname, email, password, organisation, readonly)
            }
        }
    }
    /// Reset all users' API keys; returns the number affected.
    #[cfg(any(test, feature = "admin"))]
    pub async fn reset_api_keys(&self) -> Result<u64> {
        match self {
            DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.reset_api_keys(),
        }
    }
    /// Set (or change) a user's password.
    #[cfg(any(test, feature = "admin"))]
    pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.set_password(uname, password),
        }
    }
    /// List all user accounts.
    #[cfg(any(test, feature = "admin"))]
    pub async fn list_users(&self) -> Result<Vec<admin::User>> {
        match self {
            DatabaseType::Postgres(pg) => pg.list_users().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.list_users(),
        }
    }
    /// Look up a group's id by its name.
    #[cfg(any(test, feature = "admin"))]
    pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
        match self {
            DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
        }
    }
    /// Update a group's name, description, and optional parent group.
    #[cfg(any(test, feature = "admin"))]
    pub async fn edit_group(
        &self,
        gid: u32,
        name: &str,
        desc: &str,
        parent: Option<u32>,
    ) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
        }
    }
    /// List all groups.
    #[cfg(any(test, feature = "admin"))]
    pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
        match self {
            DatabaseType::Postgres(pg) => pg.list_groups().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.list_groups(),
        }
    }
    /// Add user `uid` to group `gid`.
    #[cfg(any(test, feature = "admin"))]
    pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
        }
    }
    /// Grant group `gid` access to source `sid`.
    #[cfg(any(test, feature = "admin"))]
    pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
        }
    }
    /// Create a group (optionally under a parent); returns the new group id.
    #[cfg(any(test, feature = "admin"))]
    pub async fn create_group(
        &self,
        name: &str,
        description: &str,
        parent: Option<u32>,
    ) -> Result<u32> {
        match self {
            DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
        }
    }
    /// List all file sources.
    #[cfg(any(test, feature = "admin"))]
    pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
        match self {
            DatabaseType::Postgres(pg) => pg.list_sources().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.list_sources(),
        }
    }
    /// Create a file source; returns the new source id.
    #[cfg(any(test, feature = "admin"))]
    pub async fn create_source(
        &self,
        name: &str,
        description: Option<&str>,
        url: Option<&str>,
        date: chrono::DateTime<Local>,
        releasable: bool,
        malicious: Option<bool>,
    ) -> Result<u32> {
        match self {
            DatabaseType::Postgres(pg) => {
                pg.create_source(name, description, url, date, releasable, malicious)
                    .await
            }
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => {
                sl.create_source(name, description, url, date, releasable, malicious)
            }
        }
    }
    /// Update a user's account details.
    #[cfg(any(test, feature = "admin"))]
    pub async fn edit_user(
        &self,
        uid: u32,
        uname: &str,
        fname: &str,
        lname: &str,
        email: &str,
        readonly: bool,
    ) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => {
                pg.edit_user(uid, uname, fname, lname, email, readonly)
                    .await
            }
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
        }
    }
    /// Deactivate user `uid`.
    #[cfg(any(test, feature = "admin"))]
    pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
        }
    }
    /// Count stored files per file-type name.
    #[cfg(any(test, feature = "admin"))]
    pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
        match self {
            DatabaseType::Postgres(pg) => pg.file_types_counts().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.file_types_counts(),
        }
    }
    /// Create a label (optionally under a parent); returns the new label id.
    #[cfg(any(test, feature = "admin"))]
    pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
        match self {
            DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.create_label(name, parent),
        }
    }
    /// Update label `id`'s name and optional parent.
    #[cfg(any(test, feature = "admin"))]
    pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
        }
    }
    /// Look up a label's id by name.
    #[cfg(any(test, feature = "admin"))]
    pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
        match self {
            DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
        }
    }
    /// Attach label `label_id` to file `file_id`.
    #[cfg(any(test, feature = "admin"))]
    pub async fn label_file(&self, file_id: u64, label_id: u64) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.label_file(file_id, label_id).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.label_file(file_id, label_id),
        }
    }
}
/// Hash `password` with Argon2 (default parameters) and a freshly generated
/// random salt, returning the PHC-formatted hash string.
///
/// # Errors
/// Returns an error if the underlying Argon2 hashing operation fails.
pub fn hash_password(password: &str) -> Result<String> {
    let hasher = Argon2::default();
    let salt = SaltString::generate(&mut OsRng);
    let hashed = hasher.hash_password(password.as_bytes(), &salt)?;
    Ok(hashed.to_string())
}
#[must_use]
pub fn random_bytes_api_key() -> String {
let key1 = uuid::Uuid::new_v4();
let key2 = uuid::Uuid::new_v4();
let key1 = key1.to_string().replace('-', "");
let key2 = key2.to_string().replace('-', "");
format!("{key1}{key2}")
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "vt")]
use crate::vt::VtUpdater;
use std::fs;
#[cfg(feature = "vt")]
use std::sync::Arc;
#[cfg(feature = "vt")]
use std::time::SystemTime;
use anyhow::Context;
use fuzzyhash::FuzzyHash;
use malwaredb_api::{PartialHashSearchType, SearchRequestParameters, SearchType};
use malwaredb_lzjd::{LZDict, Murmur3HashState};
use tlsh_fixed::TlshBuilder;
use uuid::Uuid;
// Label names used by the test fixtures below.
const MALWARE_LABEL: &str = "malware";
const RANSOMWARE_LABEL: &str = "ransomware";
/// Build a similarity-search request for `data` containing SSDeep, TLSH
/// (when one can be built), and LZJD hashes.
fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
    let mut hashes = vec![];
    hashes.push((
        malwaredb_api::SimilarityHashType::SSDeep,
        FuzzyHash::new(data).to_string(),
    ));
    let mut builder = TlshBuilder::new(
        tlsh_fixed::BucketKind::Bucket256,
        tlsh_fixed::ChecksumKind::ThreeByte,
        tlsh_fixed::Version::Version4,
    );
    builder.update(data);
    // Best effort: include a TLSH hash only if the builder succeeds
    // (presumably it can fail for unsuitable input — TODO confirm).
    if let Ok(hasher) = builder.build() {
        hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
    }
    let build_hasher = Murmur3HashState::default();
    let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
    hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
    malwaredb_api::SimilarSamplesRequest { hashes }
}
/// Connect to the local test Postgres instance, honoring an optional
/// `PG_PORT` environment variable override.
async fn pg_config() -> Postgres {
    const CONNECTION_STRING: &str =
        "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
    if let Ok(pg_port) = std::env::var("PG_PORT") {
        let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
        Postgres::new(&conn_string, None)
            .await
            .context(format!(
                "failed to connect to postgres with specified port {pg_port}"
            ))
            .unwrap()
    } else {
        Postgres::new(CONNECTION_STRING, None).await.unwrap()
    }
}
/// End-to-end test against a live Postgres instance: wipes the database,
/// exercises the shared `everything` scenario, optionally runs the
/// VirusTotal update path, then wipes again.
#[tokio::test]
#[ignore = "don't run this in CI"]
async fn pg() {
    // Start from a clean database.
    let psql = pg_config().await;
    psql.delete().await.unwrap();
    let psql = pg_config().await;
    let db = DatabaseType::Postgres(psql);
    everything(&db).await.unwrap();
    #[cfg(feature = "vt")]
    {
        let db_config = db.get_config().await.unwrap();
        let state = crate::State {
            port: 8080,
            directory: None,
            max_upload: 10 * 1024 * 1024,
            ip: "127.0.0.1".parse().unwrap(),
            db_type: Arc::new(db),
            db_config,
            keys: HashMap::new(),
            started: SystemTime::now(),
            // A VT client is only built when VT_API_KEY is set.
            vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
                Some(malwaredb_virustotal::VirusTotalClient::new(e))
            }),
            tls_config: None,
            mdns: None,
        };
        let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
        vt.updater().await.unwrap();
        println!("PG: Did VT ops!");
        let psql = pg_config().await;
        let vt_stats = psql
            .get_vt_stats()
            .await
            .context("failed to get Postgres VT Stats")
            .unwrap();
        println!("{vt_stats:?}");
        // After the update run, the three VT counters should cover the samples.
        assert!(
            vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
        );
    }
    // Clean up after the run.
    let psql = pg_config().await;
    psql.delete().await.unwrap();
}
/// End-to-end test against a throwaway SQLite file: runs the shared
/// `everything` scenario, optionally the VirusTotal update path, then
/// deletes the database file.
#[tokio::test]
async fn sqlite() {
    const DB_FILE: &str = "testing_sqlite.db";
    // Remove any leftover database from a previous run.
    if std::path::Path::new(DB_FILE).exists() {
        fs::remove_file(DB_FILE)
            .context(format!("failed to delete old SQLite file {DB_FILE}"))
            .unwrap();
    }
    let sqlite = Sqlite::new(DB_FILE)
        .context(format!("failed to create SQLite instance for {DB_FILE}"))
        .unwrap();
    let db = DatabaseType::SQLite(sqlite);
    everything(&db).await.unwrap();
    #[cfg(feature = "vt")]
    {
        let db_config = db.get_config().await.unwrap();
        let state = crate::State {
            port: 8080,
            directory: None,
            max_upload: 10 * 1024 * 1024,
            ip: "127.0.0.1".parse().unwrap(),
            db_type: Arc::new(db),
            db_config,
            keys: HashMap::new(),
            started: SystemTime::now(),
            // A VT client is only built when VT_API_KEY is set.
            vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
                Some(malwaredb_virustotal::VirusTotalClient::new(e))
            }),
            tls_config: None,
            mdns: None,
        };
        // A second handle onto the same file, since `db` was moved into `state`.
        let sqlite_second = Sqlite::new(DB_FILE)
            .context(format!("failed to create SQLite instance for {DB_FILE}"))
            .unwrap();
        let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
        vt.updater().await.unwrap();
        println!("Sqlite: Did VT ops!");
        let vt_stats = sqlite_second
            .get_vt_stats()
            .context("failed to get Sqlite VT Stats")
            .unwrap();
        println!("{vt_stats:?}");
        // After the update run, the three VT counters should cover the samples.
        assert!(
            vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
        );
    }
    // Clean up the database file.
    fs::remove_file(DB_FILE)
        .context(format!("failed to delete SQLite file {DB_FILE}"))
        .unwrap();
}
#[allow(clippy::too_many_lines)]
async fn everything(db: &DatabaseType) -> Result<()> {
const ADMIN_UNAME: &str = "admin";
const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
db.set_name("Testing Database")
.await
.context("setting instance name failed")?;
assert!(
db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
"Authentication without password should have failed."
);
db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
.await
.context("failed to set admin password")?;
let admin_api_key = db
.authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
.await
.context("unable to get api key for admin")?;
println!("API key: {admin_api_key}");
assert_eq!(admin_api_key.len(), 64);
assert_eq!(
db.get_uid(&admin_api_key).await?,
0,
"Unable to get UID given the API key"
);
let admin_api_key_again = db
.authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
.await
.context("unable to get api key a second time for admin")?;
assert_eq!(
admin_api_key, admin_api_key_again,
"API keys didn't match the second time."
);
let bad_password = "this_is_totally_not_my_password!!";
eprintln!("Testing API login with incorrect password.");
assert!(
db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
"Authenticating as admin with a bad password should have failed."
);
let admin_is_admin = db
.user_is_admin(0)
.await
.context("unable to see if admin (uid 0) is an admin")?;
assert!(admin_is_admin);
let new_user_uname = "testuser";
let new_user_email = "test@example.com";
let new_user_password = "some_awesome_password_++";
let new_id = db
.create_user(
new_user_uname,
new_user_uname,
new_user_uname,
new_user_email,
Some(new_user_password.into()),
None,
false,
)
.await
.context(format!("failed to create user {new_user_uname}"))?;
let passwordless_user_id = db
.create_user(
"passwordless_user",
"passwordless_user",
"passwordless_user",
"passwordless_user@example.com",
None,
None,
false,
)
.await
.context("failed to create passwordless_user")?;
for user in &db.list_users().await.context("failed to list users")? {
if user.id == passwordless_user_id {
assert_eq!(user.uname, "passwordless_user");
}
}
db.edit_user(
passwordless_user_id,
"passwordless_user_2",
"passwordless_user_2",
"passwordless_user_2",
"passwordless_user_2@something.com",
false,
)
.await
.context(format!(
"failed to alter 'passwordless' user, id {passwordless_user_id}"
))?;
for user in &db.list_users().await.context("failed to list users")? {
if user.id == passwordless_user_id {
assert_eq!(user.uname, "passwordless_user_2");
}
}
assert!(
new_id > 0,
"Weird UID created for user {new_user_uname}: {new_id}"
);
assert!(
db.create_user(
new_user_uname,
new_user_uname,
new_user_uname,
new_user_email,
Some(new_user_password.into()),
None,
false
)
.await
.is_err(),
"Creating a new user with the same user name should fail"
);
let ro_user_name = "ro_user";
let ro_user_password = "ro_user_password";
db.create_user(
ro_user_name,
"ro_user",
"ro_user",
"ro@example.com",
Some(ro_user_password.into()),
None,
true,
)
.await
.context("failed to create read-only user")?;
let ro_user_api_key = db
.authenticate(ro_user_name, ro_user_password)
.await
.context("unable to get api key for read-only user")?;
let new_user_password_change = "some_new_awesomer_password!_++";
db.set_password(new_user_uname, new_user_password_change)
.await
.context("failed to change the password for testuser")?;
let new_user_api_key = db
.authenticate(new_user_uname, new_user_password_change)
.await
.context("unable to get api key for testuser")?;
eprintln!("{new_user_uname} got API key {new_user_api_key}");
assert_eq!(admin_api_key.len(), new_user_api_key.len());
let users = db.list_users().await.context("failed to list users")?;
assert_eq!(
users.len(),
4,
"Four users were created, yet there are {} users",
users.len()
);
eprintln!("DB has {} users:", users.len());
let mut passwordless_user_found = false;
for user in users {
println!("{user}");
if user.uname == "passwordless_user_2" {
assert!(!user.has_api_key);
assert!(!user.has_password);
passwordless_user_found = true;
} else {
assert!(user.has_api_key);
assert!(user.has_password);
}
}
assert!(passwordless_user_found);
let new_group_name = "some_new_group";
let new_group_desc = "some_new_group_description";
let new_group_id = 1;
assert_eq!(
db.create_group(new_group_name, new_group_desc, None)
.await
.context("failed to create group")?,
new_group_id,
"New group didn't have the expected ID, expected {new_group_id}"
);
assert!(
db.create_group(new_group_name, new_group_desc, None)
.await
.is_err(),
"Duplicate group name should have failed"
);
db.add_user_to_group(1, 1)
.await
.context("Unable to add uid 1 to gid 1")?;
let ro_user_uid = db
.get_uid(&ro_user_api_key)
.await
.context("Unable to get UID for read-only user")?;
db.add_user_to_group(ro_user_uid, 1)
.await
.context("Unable to add uid 2 to gid 1")?;
let new_admin_group_name = "admin_subgroup";
let new_admin_group_desc = "admin_subgroup_description";
let new_admin_group_id = 2;
assert!(
db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
.await
.context("failed to create admin sub-group")?
>= new_admin_group_id,
"New group didn't have the expected ID, expected >= {new_admin_group_id}"
);
let groups = db.list_groups().await.context("failed to list groups")?;
assert_eq!(
groups.len(),
3,
"Three groups were created, yet there are {} groups",
groups.len()
);
eprintln!("DB has {} groups:", groups.len());
for group in groups {
println!("{group}");
if group.id == new_admin_group_id {
assert_eq!(group.parent, Some("admin".to_string()));
}
if group.id == 1 {
let test_user_str = String::from(new_user_uname);
let mut found = false;
for member in group.members {
if member.uname == test_user_str {
found = true;
break;
}
}
assert!(found, "new user {test_user_str} wasn't in the group");
}
}
// Create three sources: two that group 1 will be attached to, and one
// ("empty_source") deliberately left with no groups.
let default_source_name = "default_source".to_string();
let default_source_id = db
.create_source(
&default_source_name,
Some("desc_default_source"),
None,
Local::now(),
true,
Some(false),
)
.await
.context("failed to create source `default_source`")?;
db.add_group_to_source(1, default_source_id)
.await
.context("failed to add group 1 to source 1")?;
let another_source_name = "another_source".to_string();
let another_source_id = db
.create_source(
&another_source_name,
Some("yet another file source"),
None,
Local::now(),
true,
Some(false),
)
.await
.context("failed to create source `another_source`")?;
let empty_source_name = "empty_source".to_string();
db.create_source(
&empty_source_name,
Some("empty and unused file source"),
None,
Local::now(),
true,
Some(false),
)
.await
// Fixed context: previously (incorrectly) referred to `another_source`.
.context("failed to create source `empty_source`")?;
db.add_group_to_source(1, another_source_id)
.await
// Fixed context: previously claimed "source 1", but this attaches `another_source`.
.context("failed to add group 1 to source `another_source`")?;
// No files exist yet; only the two sources attached to group 1 should
// report exactly one group, and the empty source should report zero.
let sources = db.list_sources().await.context("failed to list sources")?;
eprintln!("DB has {} sources:", sources.len());
for source in sources {
println!("{source}");
assert_eq!(source.files, 0);
if source.id == default_source_id || source.id == another_source_id {
assert_eq!(
source.groups, 1,
"default source {default_source_name} should have 1 group"
);
} else {
assert_eq!(source.groups, 0, "groups should be zero (empty)");
}
}
// Resolve the test user's uid from their API key, then fetch the groups
// and sources visible to that user.
let uid = db
.get_uid(&new_user_api_key)
.await
// Fixed context: previously read "failed to user uid" (missing "get").
.context("failed to get user uid from apikey")?;
let user_info = db
.get_user_info(uid)
.await
.context("failed to get user's available groups and sources")?;
assert!(user_info.sources.contains(&default_source_name));
assert!(!user_info.is_admin);
println!("UserInfoResponse: {user_info:?}");
// User 1 is in group 1, which is attached to the default source, but no
// group grants access to source 5.
assert!(
db.allowed_user_source(1, default_source_id)
.await
.context(format!(
"failed to check that user 1 has access to source {default_source_id}"
))?,
"User 1 should have had access to source {default_source_id}"
);
assert!(
!db.allowed_user_source(1, 5)
.await
.context("failed to check that user 1 has access to source 5")?,
"User 1 should not have had access to source 5"
);
// Create a parent label and a child label for ELF samples.
let test_label_id = db
.create_label("TestLabel", None)
.await
.context("failed to create test label")?;
let test_elf_label_id = db
.create_label("TestELF", Some(test_label_id))
.await
// Fixed context: previously duplicated the parent label's message.
.context("failed to create TestELF sub-label")?;
// Load a test ELF from the types crate's test data, insert it under the
// default source as user 1, and attach the TestELF label.
let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
let known_type =
KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
assert!(known_type.is_exec(), "ELF should be executable");
eprintln!("ELF type ID: {elf_type}");
let file_addition = db
.add_file(
&test_elf_meta,
known_type.clone(),
1,
default_source_id,
elf_type,
None,
)
.await
.context("failed to insert a test elf")?;
assert!(file_addition.is_new, "File should have been added");
eprintln!("Added ELF to the DB");
db.label_file(file_addition.file_id, test_elf_label_id)
.await
.context("failed to label file")?;
// Search combining a partial SHA-1, a label, a file type, and magic text;
// it should match exactly the one ELF just inserted.
let partial_search = SearchRequest {
search: SearchType::Search(SearchRequestParameters {
partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
labels: Some(vec![String::from("TestELF")]),
file_type: Some(String::from("ELF")),
magic: Some(String::from("OpenPOWER ELF V2 ABI")),
..Default::default()
}),
};
assert!(partial_search.is_valid());
let partial_search_response = db.partial_search(1, partial_search).await?;
assert_eq!(partial_search_response.hashes.len(), 1);
assert_eq!(
partial_search_response.hashes[0],
"897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
);
// Magic-only search: should still find exactly the one ELF.
let partial_search = SearchRequest {
search: SearchType::Search(SearchRequestParameters {
partial_hash: None,
labels: None,
file_type: None,
magic: Some(String::from("OpenPOWER ELF V2 ABI")),
..Default::default()
}),
};
assert!(partial_search.is_valid());
let partial_search_response = db.partial_search(1, partial_search).await?;
assert_eq!(partial_search_response.hashes.len(), 1);
assert_eq!(
partial_search_response.hashes[0],
"897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
);
// Contradictory criteria (ELF's SHA-1 prefix but PE32 file type): no hits.
let partial_search = SearchRequest {
search: SearchType::Search(SearchRequestParameters {
partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
file_type: Some(String::from("PE32")),
..Default::default()
}),
};
assert!(partial_search.is_valid());
let partial_search_response = db.partial_search(1, partial_search).await?;
assert_eq!(partial_search_response.hashes.len(), 0);
// File-name substring search matches the "ppc64" in the ELF's name.
let partial_search = SearchRequest {
search: SearchType::Search(SearchRequestParameters {
file_name: Some("ppc64".into()),
..Default::default()
}),
};
assert!(partial_search.is_valid());
let partial_search_response = db.partial_search(1, partial_search).await?;
assert_eq!(partial_search_response.hashes.len(), 1);
// A request with no criteria at all is invalid, and yields no hashes.
let partial_search = SearchRequest {
search: SearchType::Search(SearchRequestParameters::default()),
};
assert!(!partial_search.is_valid());
let partial_search_response = db.partial_search(1, partial_search).await?;
assert!(partial_search_response.hashes.is_empty());
// A continuation with an unknown (default) UUID is valid but returns nothing.
let partial_search = SearchRequest {
search: SearchType::Continuation(Uuid::default()),
};
assert!(partial_search.is_valid());
let partial_search_response = db.partial_search(1, partial_search).await?;
assert!(partial_search_response.hashes.is_empty());
// An .ico file is not a recognized type, so type lookup must fail
// (keep_unknown_files has not been enabled yet at this point).
assert!(db
.get_type_id_for_bytes(include_bytes!("../../../../MDB_Logo.ico"))
.await
.is_err());
// The read-only user must be rejected when attempting to add a file.
assert!(
db.add_file(
&test_elf_meta,
known_type.clone(),
ro_user_uid,
default_source_id,
elf_type,
None
)
.await
.is_err(),
"Read-only user should not be able to add a file"
);
// Re-adding the same ELF bytes (different name, different source) must
// succeed but be reported as not new — deduplicated by content.
let mut test_elf_meta_different_name = test_elf_meta.clone();
test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
assert!(
!db.add_file(
&test_elf_meta_different_name,
known_type,
1,
another_source_id,
elf_type,
None
)
.await
.context("failed to insert a test elf again for a different source")?
.is_new
);
// Re-list sources: the ELF was added under two sources, so each of those
// reports one file while the empty source reports none.
let sources = db
.list_sources()
.await
.context("failed to re-list sources")?;
eprintln!(
"DB has {} sources, and a file was added twice:",
sources.len()
);
println!("We should have two sources with one file each, yet only one ELF.");
for source in sources {
println!("{source}");
if source.id == default_source_id || source.id == another_source_id {
assert_eq!(source.files, 1);
} else {
// Fixed message: this assertion is about files, not groups.
assert_eq!(source.files, 0, "files should be zero (empty)");
}
}
// User 1 must have at least one source visible.
assert!(!db
.get_user_sources(1)
.await
.expect("failed to get user 1's sources")
.sources
.is_empty());
// The only type with any files so far is ELF, with a single sample.
let file_types_counts = db
.file_types_counts()
.await
.context("failed to get file types and counts")?;
for (name, count) in file_types_counts {
println!("{name}: {count}");
assert_eq!(name, "ELF");
assert_eq!(count, 1);
}
// Build a near-duplicate of the ELF by appending 16 random bytes, then ask
// the DB for similar samples; the original ELF should be the top match.
let mut test_elf_modified = test_elf.clone();
let random_bytes = Uuid::new_v4();
let mut random_bytes = random_bytes.into_bytes().to_vec();
test_elf_modified.append(&mut random_bytes);
let similarity_request = generate_similarity_request(&test_elf_modified);
let similarity_response = db
.find_similar_samples(1, &similarity_request.hashes)
.await
.context("failed to get similarity response")?;
eprintln!("Similarity response: {similarity_response:?}");
let similarity_response = similarity_response.first().unwrap();
assert_eq!(
similarity_response.sha256,
hex::encode(&test_elf_meta.sha256),
"Similarity response should have had the hash of the original ELF"
);
// Per-algorithm thresholds: LZJD/SSDeep scores are expected high (more
// similar = larger), while the TLSH assertion expects a small value
// (presumably a distance-like score where lower means more similar —
// TODO confirm against the scoring implementation).
for (algo, sim) in &similarity_response.algorithms {
match algo {
malwaredb_api::SimilarityHashType::LZJD => {
assert!(*sim > 0.0f32);
}
malwaredb_api::SimilarityHashType::SSDeep => {
assert!(*sim > 80.0f32);
}
malwaredb_api::SimilarityHashType::TLSH => {
assert!(*sim <= 20f32);
}
_ => {}
}
}
// Look the ELF up by its SHA-1 and confirm the returned SHA-256 matches.
let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1.as_slice())
.context("failed to get `HashType::SHA1` from string")?;
let response_sha256 = db
.retrieve_sample(1, &test_elf_hashtype)
.await
.context("could not get SHA-256 hash from test sample")
.unwrap();
assert_eq!(response_sha256, hex::encode(&test_elf_meta.sha256));
// A syntactically valid but unknown hash must fail retrieval.
let test_bogus_hash =
HashType::try_from("d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0")
.context("failed to get `HashType` from static string")?;
assert!(
db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
"Getting a file with a bogus hash should have failed."
);
// Insert a PDF sample under the default source; it should be new.
let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
let known_type =
KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
assert!(
db.add_file(
&test_pdf_meta,
known_type,
1,
default_source_id,
pdf_type,
None
)
.await
.context("failed to insert a test pdf")?
.is_new
);
eprintln!("Added PDF to the DB");
eprintln!("Added PDF to the DB");
let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
let rtf_type = db
.get_type_id_for_bytes(&test_rtf)
.await
.context("failed to get file type id for rtf")?;
let known_type =
KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
assert!(
db.add_file(
&test_rtf_meta,
known_type,
1,
default_source_id,
rtf_type,
None
)
.await
.context("failed to insert a test rtf")?
.is_new
);
eprintln!("Added RTF to the DB");
let report = db
.get_sample_report(
1,
&HashType::try_from(test_rtf_meta.sha256.as_slice()).unwrap(),
)
.await
.context("failed to get report for test rtf")?;
assert!(report
.clone()
.filecommand
.unwrap()
.contains("Rich Text Format"));
println!("Report: {report}");
assert!(db
.get_sample_report(
999,
&HashType::try_from(test_rtf_meta.sha256.as_slice()).unwrap()
)
.await
.is_err());
#[cfg(feature = "vt")]
{
assert!(report.vt.is_some());
let files_needing_vt = db
.files_without_vt_records(10)
.await
.context("failed to get files without VT records")?;
assert!(files_needing_vt.len() > 2);
println!(
"{} files needing VT data: {files_needing_vt:?}",
files_needing_vt.len()
);
}
#[cfg(not(feature = "vt"))]
{
assert!(report.vt.is_none());
}
// Reset all API keys and print how many were cleared.
let reset = db
.reset_api_keys()
.await
.context("failed to reset all API keys")?;
eprintln!("Cleared {reset} api keys.");
// Smoke-test the informational endpoints: db info, known data types,
// and the source list.
let db_info = db.db_info().await.context("failed to get database info")?;
eprintln!("DB Info: {db_info:?}");
let data_types = db
.get_known_data_types()
.await
.context("failed to get data types")?;
for data_type in data_types {
println!("{data_type:?}");
}
let sources = db
.list_sources()
.await
.context("failed to list sources second time")?;
eprintln!("DB has {} sources:", sources.len());
for source in sources {
println!("{source}");
}
// No Mach-O samples have been inserted yet, so none may appear in counts.
let file_types_counts = db
.file_types_counts()
.await
.context("failed to get file types and counts")?;
for (name, count) in file_types_counts {
println!("{name}: {count}");
assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
}
// Insert a fat (multi-architecture) Mach-O sample; it should be new.
let fatmacho =
include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
.to_vec();
let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
let fatmacho_type = db
.get_type_id_for_bytes(&fatmacho)
.await
.context("failed to get file type for Fat Mach-O")?;
let known_type = KnownType::new(&fatmacho)
.context("failed to parse Fat Mach-O from type crate's test data")?;
assert!(
db.add_file(
&fatmacho_meta,
known_type,
1,
default_source_id,
fatmacho_type,
None
)
.await
.context("failed to insert a test Fat Mach-O")?
.is_new
);
eprintln!("Added Fat Mach-O to the DB");
// The counts should now show 4 Mach-O entries — presumably one per
// architecture slice in the fat binary (TODO confirm how slices are
// counted by `file_types_counts`).
let file_types_counts = db
.file_types_counts()
.await
.context("failed to get file types and counts")?;
for (name, count) in &file_types_counts {
println!("{name}: {count}");
}
assert_eq!(
*file_types_counts.get("Mach-O").unwrap(),
4,
"Expected 4 Mach-O files, got {:?}",
file_types_counts.get("Mach-O")
);
// First page of the user's allowed files by SHA-256: 8 entries; passing
// the returned continuation token back yields an empty second page.
let allowed_files = db
.user_allowed_files_by_sha256(1, None)
.await
.context("failed to get allowed files")?;
assert_eq!(allowed_files.0.len(), 8);
let allowed_files = db
.user_allowed_files_by_sha256(1, Some(allowed_files.1))
.await
.context("failed to get allowed files")?;
assert!(allowed_files.0.is_empty());
// Add a malware label and a ransomware sub-label, then verify the full
// label list (4 total) links the child to its parent by name.
let malware_label_id = db
.create_label(MALWARE_LABEL, None)
.await
.context("failed to create first label")?;
let ransomware_label_id = db
.create_label(RANSOMWARE_LABEL, Some(malware_label_id))
.await
.context("failed to create malware sub-label")?;
let labels = db.get_labels().await.context("failed to get labels")?;
assert_eq!(labels.len(), 4, "Expected 4 labels, got {labels}");
for label in labels.0 {
if label.name == RANSOMWARE_LABEL {
assert_eq!(label.id, ransomware_label_id);
assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
}
}
// Use this very source file as an example of an unrecognized file type.
let source_code = include_bytes!("mod.rs");
let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
let known_type =
// Fixed context: previously read "failed to source code to get".
KnownType::new(source_code).context("failed to parse source code to get `Unknown` type")?;
assert!(matches!(known_type, KnownType::Unknown(_)));
// Find the id of the "unknown" type from the known data types list.
let unknown_type: Vec<FileType> = db
.get_known_data_types()
.await?
.into_iter()
.filter(|t| t.name.eq_ignore_ascii_case("unknown"))
.collect();
let unknown_type_id = unknown_type.first().unwrap().id;
// Unknown content is rejected until keep_unknown_files is enabled, after
// which the lookup resolves to the "unknown" type id.
assert!(db.get_type_id_for_bytes(source_code).await.is_err());
db.enable_keep_unknown_files()
.await
.context("failed to enable keeping of unknown files")?;
let source_type = db
.get_type_id_for_bytes(source_code)
.await
.context("failed to type id for source code unknown type example")?;
assert_eq!(source_type, unknown_type_id);
eprintln!("Unknown file type ID: {source_type}");
assert!(
db.add_file(
&source_meta,
known_type,
1,
default_source_id,
unknown_type_id,
None
)
.await
.context("failed to add Rust source code file")?
.is_new
);
eprintln!("Added Rust source code to the DB");
// No YARA work has been queued, so the task list must be empty.
#[cfg(feature = "yara")]
assert!(db.get_unfinished_yara_tasks().await?.is_empty());
// Teardown: clear uid 0's API key and deactivate the account.
db.reset_own_api_key(0)
.await
.context("failed to clear own API key uid 0")?;
db.deactivate_user(0)
.await
.context("failed to clear password and API key for uid 0")?;
Ok(())
}
}