#[cfg(any(test, feature = "admin"))]
mod admin;
mod pg;
#[cfg(any(test, feature = "sqlite"))]
mod sqlite;
pub mod types;
#[cfg(any(test, feature = "admin"))]
use crate::crypto::EncryptionOption;
use crate::crypto::FileEncryption;
use crate::db::pg::Postgres;
#[cfg(any(test, feature = "sqlite"))]
use crate::db::sqlite::Sqlite;
use crate::db::types::{FileMetadata, FileType};
use malwaredb_api::{digest::HashType, GetUserInfoResponse, Labels, Sources};
use malwaredb_types::KnownType;
use std::collections::HashMap;
use anyhow::{bail, Result};
use argon2::password_hash::{rand_core::OsRng, SaltString};
use argon2::{Argon2, PasswordHasher};
#[cfg(any(test, feature = "admin"))]
use chrono::Local;
#[cfg(feature = "vt")]
use malwaredb_virustotal::filereport::ScanResultAttributes;
/// Handle to the configured storage backend. Every `DatabaseType` method
/// dispatches to the wrapped backend's method of the same name.
#[derive(Debug)]
pub enum DatabaseType {
    /// PostgreSQL backend.
    Postgres(Postgres),
    /// SQLite backend; only compiled in for tests or with the `sqlite` feature.
    #[cfg(any(test, feature = "sqlite"))]
    SQLite(Sqlite),
}
/// Summary statistics about the database, as returned by `DatabaseType::db_info`.
#[derive(Debug)]
pub struct DatabaseInformation {
    /// Version string reported by the backend (presumably the database server's version — TODO confirm).
    pub version: String,
    /// Database size as a preformatted string; units/format are chosen by the backend — TODO confirm.
    pub size: String,
    /// Number of stored files.
    pub num_files: u64,
    /// Number of user accounts.
    pub num_users: u32,
    /// Number of groups.
    pub num_groups: u32,
    /// Number of sample sources.
    pub num_sources: u32,
}
/// Server-wide settings, as returned by `DatabaseType::get_config`.
#[derive(Debug)]
pub struct MDBConfig {
    /// Configured instance name.
    pub name: String,
    /// Whether compression is on; presumably the flag toggled by
    /// `enable_compression` / `disable_compression` — verify against the backends.
    pub compression: bool,
    /// Whether samples are sent to VirusTotal; presumably the flag toggled by
    /// `enable_vt_upload` / `disable_vt_upload` — verify against the backends.
    pub send_samples_to_vt: bool,
    /// Presumably the ID of the default file-encryption key (the same `u32` IDs that key
    /// the map from `get_encryption_keys`) — TODO confirm.
    // NOTE(review): `dead_code` is allowed without a stated reason; it may only be read
    // under some feature flags — confirm and record why here.
    #[allow(dead_code)]
    pub(crate) default_key: Option<u32>,
}
/// VirusTotal coverage counters, as returned by `DatabaseType::get_vt_stats`.
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Records presumably without any detections — TODO confirm.
    pub clean_records: u32,
    /// Records presumably with at least one detection — TODO confirm.
    pub hits_records: u32,
    /// Files that have no VirusTotal record stored yet.
    pub files_without_records: u32,
}
impl DatabaseType {
    /// Opens a backend from a user-supplied connection argument.
    ///
    /// * `file:<path>` opens the SQLite database at `<path>`. This is only available in
    ///   tests or with the `sqlite` feature; otherwise such strings fall through to the error.
    /// * Anything starting with `postgres` selects PostgreSQL. The literal `postgres` prefix
    ///   is removed once, and the remainder is passed to `Postgres::new`.
    ///
    /// # Errors
    /// Fails if the prefix is not recognised or the backend cannot be opened.
    pub async fn from_string(arg: &str) -> Result<Self> {
        // `strip_prefix` removes the prefix exactly once. `trim_start_matches` would also
        // strip repetitions, turning e.g. `file:file:x.db` into `x.db`.
        #[cfg(any(test, feature = "sqlite"))]
        if let Some(path) = arg.strip_prefix("file:") {
            return Ok(DatabaseType::SQLite(Sqlite::new(path)?));
        }

        // NOTE(review): only the bare word `postgres` is removed, so a URL such as
        // `postgres://host/db` reaches `Postgres::new` as `://host/db` — confirm which
        // format `Postgres::new` expects.
        if let Some(conn_str) = arg.strip_prefix("postgres") {
            return Ok(DatabaseType::Postgres(Postgres::new(conn_str).await?));
        }

        bail!("unknown database type `{arg}`")
    }

    /// Enables compression via the backend's `enable_compression`.
    pub async fn enable_compression(&self) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.enable_compression().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.enable_compression(),
        }
    }

    /// Disables compression via the backend's `disable_compression`.
    pub async fn disable_compression(&self) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.disable_compression().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.disable_compression(),
        }
    }

    /// Enables uploading samples to VirusTotal via the backend's `enable_vt_upload`.
    #[cfg(feature = "vt")]
    pub async fn enable_vt_upload(&self) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
        }
    }

    /// Disables uploading samples to VirusTotal via the backend's `disable_vt_upload`.
    #[cfg(feature = "vt")]
    pub async fn disable_vt_upload(&self) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
        }
    }

    /// Returns up to `limit` identifiers of files that have no VirusTotal record yet.
    /// Presumably these are hashes — TODO confirm.
    #[cfg(feature = "vt")]
    pub async fn files_without_vt_records(&self, limit: i32) -> Result<Vec<String>> {
        match self {
            DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
        }
    }

    /// Stores a VirusTotal scan result.
    #[cfg(feature = "vt")]
    pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.store_vt_record(results),
        }
    }

    /// Returns counts describing VirusTotal record coverage.
    #[cfg(feature = "vt")]
    pub async fn get_vt_stats(&self) -> Result<VtStats> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_vt_stats(),
        }
    }

    /// Reads the server-wide configuration.
    pub async fn get_config(&self) -> Result<MDBConfig> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_config().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_config(),
        }
    }

    /// Checks `uname`/`password` and returns the user's API key.
    ///
    /// # Errors
    /// Fails on a wrong password, or when the user has no password set.
    pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
        match self {
            DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
        }
    }

    /// Resolves an API key to the owning user's ID.
    pub async fn get_uid(&self, apikey: &str) -> Result<i32> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_uid(apikey),
        }
    }

    /// Returns summary statistics about the database.
    pub async fn db_info(&self) -> Result<DatabaseInformation> {
        match self {
            DatabaseType::Postgres(pg) => pg.db_info().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.db_info(),
        }
    }

    /// Returns account information for user `uid`, including the sources the user can access.
    pub async fn get_user_info(&self, uid: i32) -> Result<GetUserInfoResponse> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_user_info(uid),
        }
    }

    /// Returns the sources available to user `uid`.
    pub async fn get_user_sources(&self, uid: i32) -> Result<Sources> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
        }
    }

    /// Resets the API key of user `uid`.
    pub async fn reset_own_api_key(&self, uid: i32) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
        }
    }

    /// Lists the file types the database knows about.
    pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_known_data_types(),
        }
    }

    /// Returns all labels.
    pub async fn get_labels(&self) -> Result<Labels> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_labels().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_labels(),
        }
    }

    /// Returns the file-type ID the backend associates with `data`.
    pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<i32> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
        }
    }

    /// Returns whether user `uid` may use source `sid`.
    pub async fn allowed_user_source(&self, uid: i32, sid: i32) -> Result<bool> {
        match self {
            DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
        }
    }

    /// Returns whether user `uid` is an administrator.
    pub async fn user_is_admin(&self, uid: i32) -> Result<bool> {
        match self {
            DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
        }
    }

    /// Records a file submitted by user `uid` from source `sid` with file-type ID `ftype`.
    /// `parent` is an optional parent file ID.
    ///
    /// The tests expect `true` for a first insert and `false` when the same file is added
    /// again (e.g. under another name and source).
    pub async fn add_file(
        &self,
        meta: &FileMetadata,
        known_type: KnownType<'_>,
        uid: i32,
        sid: i32,
        ftype: i32,
        parent: Option<i64>,
    ) -> Result<bool> {
        match self {
            DatabaseType::Postgres(pg) => {
                pg.add_file(meta, known_type, uid, sid, ftype, parent).await
            }
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.add_file(meta, known_type, uid, sid, ftype, parent),
        }
    }

    /// Looks up a sample by any supported hash on behalf of `uid` and returns its SHA-256.
    ///
    /// # Errors
    /// Fails if no matching sample is found.
    pub async fn retrieve_sample(&self, uid: i32, hash: &HashType) -> Result<String> {
        match self {
            DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
        }
    }

    /// Builds the report for the sample identified by `hash` on behalf of `uid`.
    ///
    /// # Errors
    /// Fails when the sample cannot be found for that user. The tests use an unknown
    /// `uid` to trigger this.
    pub async fn get_sample_report(
        &self,
        uid: i32,
        hash: &HashType,
    ) -> Result<malwaredb_api::Report> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
        }
    }

    /// Finds samples visible to `uid` that are similar under the given similarity hashes.
    pub async fn find_similar_samples(
        &self,
        uid: i32,
        sim: &[(malwaredb_api::SimilarityHashType, String)],
    ) -> Result<Vec<malwaredb_api::SimilarSample>> {
        match self {
            DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
        }
    }

    /// Returns all file-encryption keys, keyed by key ID.
    pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
        }
    }

    /// Returns the encryption-key ID for the file with `hash`, plus an optional byte value.
    /// The bytes are presumably the nonce written by `set_file_nonce` — TODO confirm.
    pub(crate) async fn get_file_encryption_key_id(
        &self,
        hash: &str,
    ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
        }
    }

    /// Stores (or clears, with `None`) the nonce for the file with `hash`.
    pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: &Option<Vec<u8>>) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
        }
    }

    /// Adds a file-encryption key and returns its new ID.
    #[cfg(any(test, feature = "admin"))]
    pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
        match self {
            DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
        }
    }

    /// Lists encryption-key IDs together with their algorithm.
    #[cfg(any(test, feature = "admin"))]
    pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
        match self {
            DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
        }
    }

    /// Creates a user and returns its ID. `password` may be `None`, which creates a
    /// passwordless account.
    ///
    /// # Errors
    /// Fails if the user name is already taken.
    #[cfg(any(test, feature = "admin"))]
    pub async fn create_user(
        &self,
        uname: &str,
        fname: &str,
        lname: &str,
        email: &str,
        password: Option<String>,
        organisation: Option<String>,
    ) -> Result<u64> {
        match self {
            DatabaseType::Postgres(pg) => {
                pg.create_user(uname, fname, lname, email, password, organisation)
                    .await
            }
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => {
                sl.create_user(uname, fname, lname, email, password, organisation)
            }
        }
    }

    /// Clears every user's API key and returns how many were cleared.
    #[cfg(any(test, feature = "admin"))]
    pub async fn reset_api_keys(&self) -> Result<u64> {
        match self {
            DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.reset_api_keys(),
        }
    }

    /// Sets the password of user `uname`.
    #[cfg(any(test, feature = "admin"))]
    pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.set_password(uname, password),
        }
    }

    /// Lists all users.
    #[cfg(any(test, feature = "admin"))]
    pub async fn list_users(&self) -> Result<Vec<admin::User>> {
        match self {
            DatabaseType::Postgres(pg) => pg.list_users().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.list_users(),
        }
    }

    /// Resolves a group name to its ID.
    #[cfg(any(test, feature = "admin"))]
    pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
        match self {
            DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
        }
    }

    /// Updates name, description and parent of group `gid`.
    #[cfg(any(test, feature = "admin"))]
    pub async fn edit_group(
        &self,
        gid: i32,
        name: &str,
        desc: &str,
        parent: Option<i32>,
    ) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
        }
    }

    /// Lists all groups.
    #[cfg(any(test, feature = "admin"))]
    pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
        match self {
            DatabaseType::Postgres(pg) => pg.list_groups().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.list_groups(),
        }
    }

    /// Adds user `uid` to group `gid`.
    #[cfg(any(test, feature = "admin"))]
    pub async fn add_user_to_group(&self, uid: i32, gid: i32) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
        }
    }

    /// Grants group `gid` access to source `sid`.
    #[cfg(any(test, feature = "admin"))]
    pub async fn add_group_to_source(&self, gid: i32, sid: i32) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
        }
    }

    /// Creates a group, optionally under `parent`, and returns its ID.
    ///
    /// # Errors
    /// Fails if the group name already exists.
    #[cfg(any(test, feature = "admin"))]
    pub async fn create_group(
        &self,
        name: &str,
        description: &str,
        parent: Option<i32>,
    ) -> Result<i32> {
        match self {
            DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
        }
    }

    /// Lists all sources.
    #[cfg(any(test, feature = "admin"))]
    pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
        match self {
            DatabaseType::Postgres(pg) => pg.list_sources().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.list_sources(),
        }
    }

    /// Creates a sample source and returns its ID.
    #[cfg(any(test, feature = "admin"))]
    pub async fn create_source(
        &self,
        name: &str,
        description: Option<&str>,
        url: Option<&str>,
        date: chrono::DateTime<Local>,
        releasable: bool,
        malicious: Option<bool>,
    ) -> Result<i32> {
        match self {
            DatabaseType::Postgres(pg) => {
                pg.create_source(name, description, url, date, releasable, malicious)
                    .await
            }
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => {
                sl.create_source(name, description, url, date, releasable, malicious)
            }
        }
    }

    /// Updates the user name, first/last name and email of user `uid`.
    #[cfg(any(test, feature = "admin"))]
    pub async fn edit_user(
        &self,
        uid: i32,
        uname: &str,
        fname: &str,
        lname: &str,
        email: &str,
    ) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.edit_user(uid, uname, fname, lname, email).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email),
        }
    }

    /// Deactivates user `uid`. The tests treat this as clearing the password and API key.
    #[cfg(any(test, feature = "admin"))]
    pub async fn deactivate_user(&self, uid: i32) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
        }
    }

    /// Returns the number of stored files per file-type name.
    #[cfg(any(test, feature = "admin"))]
    pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
        match self {
            DatabaseType::Postgres(pg) => pg.file_types_counts().await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.file_types_counts(),
        }
    }

    /// Creates a label, optionally under `parent`, and returns its ID.
    #[cfg(any(test, feature = "admin"))]
    pub async fn create_label(&self, name: &str, parent: Option<i64>) -> Result<u64> {
        match self {
            DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.create_label(name, parent),
        }
    }

    /// Renames label `id` and sets its parent.
    #[cfg(any(test, feature = "admin"))]
    pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
        match self {
            DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
        }
    }

    /// Resolves a label name to its ID.
    #[cfg(any(test, feature = "admin"))]
    pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
        match self {
            DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
            #[cfg(any(test, feature = "sqlite"))]
            DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
        }
    }
}
/// Hashes `password` with Argon2 (library-default parameters) and a freshly generated
/// random salt, returning the resulting hash as a string.
///
/// # Errors
/// Propagates any error reported by the Argon2 hasher.
pub fn hash_password(password: &str) -> Result<String> {
    let hasher = Argon2::default();
    let random_salt = SaltString::generate(&mut OsRng);
    let encoded = hasher.hash_password(password.as_bytes(), &random_salt)?;
    Ok(encoded.to_string())
}
pub fn random_bytes_api_key() -> String {
let key1 = uuid::Uuid::new_v4();
let key2 = uuid::Uuid::new_v4();
let key1 = key1.to_string().replace('-', "");
let key2 = key2.to_string().replace('-', "");
format!("{key1}{key2}")
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "vt")]
    use crate::vt::VtUpdater;

    use std::fs;
    #[cfg(feature = "vt")]
    use std::time::SystemTime;

    use anyhow::Context;
    use fuzzyhash::FuzzyHash;
    use malwaredb_lzjd::{LZDict, Murmur3HashState};
    use tlsh_fixed::TlshBuilder;
    use uuid::Uuid;

    /// Builds a similarity-search request for `data`. It always contains an SSDeep and an
    /// LZJD hash, plus a TLSH hash when the TLSH builder succeeds for this input.
    fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
        let mut hashes = vec![(
            malwaredb_api::SimilarityHashType::SSDeep,
            FuzzyHash::new(data).to_string(),
        )];

        let mut builder = TlshBuilder::new(
            tlsh_fixed::BucketKind::Bucket256,
            tlsh_fixed::ChecksumKind::ThreeByte,
            tlsh_fixed::Version::Version4,
        );
        builder.update(data);
        // `build()` is fallible, so TLSH is only included when it could be computed.
        if let Ok(hasher) = builder.build() {
            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
        }

        let build_hasher = Murmur3HashState::default();
        let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));

        malwaredb_api::SimilarSamplesRequest { hashes }
    }

    /// Connects to the local `malwaredbtesting` Postgres database. The port can be
    /// overridden with the `PG_PORT` environment variable. Panics if the connection fails.
    async fn pg_config() -> Postgres {
        const CONNECTION_STRING: &str =
            "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost";

        if let Ok(pg_port) = std::env::var("PG_PORT") {
            let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
            Postgres::new(&conn_string)
                .await
                .with_context(|| {
                    format!("failed to connect to postgres with specified port {pg_port}")
                })
                .unwrap()
        } else {
            Postgres::new(CONNECTION_STRING).await.unwrap()
        }
    }

    /// Runs the full suite against Postgres. Ignored by default because it needs a server.
    #[tokio::test]
    #[ignore]
    async fn pg() {
        let psql = pg_config().await;
        psql.delete_init().await.unwrap();
        let db = DatabaseType::Postgres(psql);

        let key = FileEncryption::from(EncryptionOption::Xor);
        db.add_file_encryption_key(&key).await.unwrap();
        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);

        everything(&db).await.unwrap();

        #[cfg(feature = "vt")]
        {
            let db_config = db.get_config().await.unwrap();
            let state = crate::State {
                port: 8080,
                directory: None,
                max_upload: 10 * 1024 * 1024,
                ip: "127.0.0.1".parse().unwrap(),
                db_type: db,
                db_config,
                keys: HashMap::new(),
                started: SystemTime::now(),
                vt_client: std::env::var("VT_API_KEY")
                    .ok()
                    .map(malwaredb_virustotal::VirusTotalClient::new),
            };

            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
            vt.updater().await.unwrap();
            println!("PG: Did VT ops!");

            // `db` was moved into `state`, so use a fresh connection to read the stats.
            let psql = pg_config().await;
            let vt_stats = psql
                .get_vt_stats()
                .await
                .context("failed to get Postgres VT Stats")
                .unwrap();
            println!("{vt_stats:?}");
            assert!(
                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
            );
        }

        // Reconnect for cleanup; `db` may have been moved into the VT state above.
        let psql = pg_config().await;
        psql.delete_init().await.unwrap();
    }

    /// Runs the full suite against a throwaway SQLite file in the working directory.
    #[tokio::test]
    async fn sqlite() {
        const DB_FILE: &str = "testing_sqlite.db";

        // Start clean in case a previous run failed before its cleanup.
        if std::path::Path::new(DB_FILE).exists() {
            fs::remove_file(DB_FILE)
                .with_context(|| format!("failed to delete old SQLite file {DB_FILE}"))
                .unwrap();
        }

        let sqlite = Sqlite::new(DB_FILE)
            .with_context(|| format!("failed to create SQLite instance for {DB_FILE}"))
            .unwrap();
        let db = DatabaseType::SQLite(sqlite);

        let key = FileEncryption::from(EncryptionOption::Xor);
        db.add_file_encryption_key(&key).await.unwrap();
        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);

        everything(&db).await.unwrap();

        #[cfg(feature = "vt")]
        {
            let db_config = db.get_config().await.unwrap();
            let state = crate::State {
                port: 8080,
                directory: None,
                max_upload: 10 * 1024 * 1024,
                ip: "127.0.0.1".parse().unwrap(),
                db_type: db,
                db_config,
                keys: HashMap::new(),
                started: SystemTime::now(),
                vt_client: std::env::var("VT_API_KEY")
                    .ok()
                    .map(malwaredb_virustotal::VirusTotalClient::new),
            };

            // `db` is moved into `state`, so open a second handle to read the stats.
            let sqlite_second = Sqlite::new(DB_FILE)
                .with_context(|| format!("failed to create SQLite instance for {DB_FILE}"))
                .unwrap();

            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
            vt.updater().await.unwrap();
            println!("Sqlite: Did VT ops!");

            let vt_stats = sqlite_second
                .get_vt_stats()
                .context("failed to get Sqlite VT Stats")
                .unwrap();
            println!("{vt_stats:?}");
            assert!(
                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
            );
        }

        fs::remove_file(DB_FILE)
            .with_context(|| format!("failed to delete SQLite file {DB_FILE}"))
            .unwrap();
    }

    /// Exercises backend-independent behaviour against a freshly initialised database.
    /// It covers authentication, users, groups, sources, file insertion, similarity
    /// search, reports, file-type counts, labels and key resets.
    ///
    /// Expects only the built-in `admin` user (uid 0) and its `admin` group to exist.
    async fn everything(db: &DatabaseType) -> Result<()> {
        const ADMIN_UNAME: &str = "admin";
        const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";

        // --- Authentication -------------------------------------------------
        assert!(
            db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
            "Authentication without password should have failed."
        );

        db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .context("failed to set admin password")?;

        let admin_api_key = db
            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .context("unable to get api key for admin")?;
        println!("API key: {admin_api_key}");
        assert_eq!(admin_api_key.len(), 64);

        assert_eq!(
            db.get_uid(&admin_api_key).await?,
            0,
            "Unable to get UID given the API key"
        );

        // Authenticating again must return the existing key, not a new one.
        let admin_api_key_again = db
            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .context("unable to get api key a second time for admin")?;
        assert_eq!(
            admin_api_key, admin_api_key_again,
            "API keys didn't match the second time."
        );

        let bad_password = "this_is_totally_not_my_password!!";
        eprintln!("Testing API login with incorrect password.");
        assert!(
            db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
            "Authenticating as admin with a bad password should have failed."
        );

        let admin_is_admin = db
            .user_is_admin(0)
            .await
            .context("unable to see if admin (uid 0) is an admin")?;
        assert!(admin_is_admin);

        // --- Users ----------------------------------------------------------
        let new_user_uname = "testuser";
        let new_user_email = "test@example.com";
        let new_user_password = "some_awesome_password_++";
        let new_id = db
            .create_user(
                new_user_uname,
                new_user_uname,
                new_user_uname,
                new_user_email,
                Some(new_user_password.into()),
                None,
            )
            .await
            .with_context(|| format!("failed to create user {new_user_uname}"))?;

        let passwordless_user_id = db
            .create_user(
                "passwordless_user",
                "passwordless_user",
                "passwordless_user",
                "passwordless_user@example.com",
                None,
                None,
            )
            .await
            .context("failed to create passwordless_user")?;

        // The new user must be listed; fail instead of silently skipping when it is missing.
        let users = db.list_users().await.context("failed to list users")?;
        let passwordless_user = users
            .iter()
            .find(|user| user.id == passwordless_user_id as i32)
            .context("passwordless_user is missing from the user list")?;
        assert_eq!(passwordless_user.uname, "passwordless_user");

        db.edit_user(
            passwordless_user_id as i32,
            "passwordless_user_2",
            "passwordless_user_2",
            "passwordless_user_2",
            "passwordless_user_2@something.com",
        )
        .await
        .with_context(|| {
            format!("failed to alter 'passwordless' user, id {passwordless_user_id}")
        })?;

        let users = db.list_users().await.context("failed to list users")?;
        let passwordless_user = users
            .iter()
            .find(|user| user.id == passwordless_user_id as i32)
            .context("edited passwordless_user is missing from the user list")?;
        assert_eq!(passwordless_user.uname, "passwordless_user_2");

        assert!(
            new_id > 0,
            "Weird UID created for user {new_user_uname}: {new_id}"
        );

        assert!(
            db.create_user(
                new_user_uname,
                new_user_uname,
                new_user_uname,
                new_user_email,
                Some(new_user_password.into()),
                None,
            )
            .await
            .is_err(),
            "Creating a new user with the same user name should fail"
        );

        let new_user_password_change = "some_new_awesomer_password!_++";
        db.set_password(new_user_uname, new_user_password_change)
            .await
            .context("failed to change the password for testuser")?;
        let new_user_api_key = db
            .authenticate(new_user_uname, new_user_password_change)
            .await
            .context("unable to get api key for testuser")?;
        eprintln!("{new_user_uname} got API key {new_user_api_key}");
        assert_eq!(admin_api_key.len(), new_user_api_key.len());

        let users = db.list_users().await.context("failed to list users")?;
        assert_eq!(
            users.len(),
            3,
            "Expected three users (admin plus two created), yet there are {} users",
            users.len()
        );
        eprintln!("DB has {} users:", users.len());
        let mut passwordless_user_found = false;
        for user in users {
            println!("{user}");
            if user.uname == "passwordless_user_2" {
                assert!(!user.has_api_key);
                assert!(!user.has_password);
                passwordless_user_found = true;
            } else {
                assert!(user.has_api_key);
                assert!(user.has_password);
            }
        }
        assert!(passwordless_user_found);

        // --- Groups ---------------------------------------------------------
        let new_group_name = "some_new_group";
        let new_group_desc = "some_new_group_description";
        let new_group_id = 1;
        assert_eq!(
            db.create_group(new_group_name, new_group_desc, None)
                .await
                .context("failed to create group")?,
            new_group_id,
            "New group didn't have the expected ID, expected {new_group_id}"
        );

        assert!(
            db.create_group(new_group_name, new_group_desc, None)
                .await
                .is_err(),
            "Duplicate group name should have failed"
        );

        db.add_user_to_group(1, 1)
            .await
            .context("Unable to add uid 1 to gid 1")?;

        // Keep the ID that was actually assigned so the parent check below always runs.
        let new_admin_group_name = "admin_subgroup";
        let new_admin_group_desc = "admin_subgroup_description";
        let new_admin_group_id = db
            .create_group(new_admin_group_name, new_admin_group_desc, Some(0))
            .await
            .context("failed to create admin sub-group")?;
        assert!(
            new_admin_group_id >= 2,
            "New group didn't have the expected ID, expected >= 2, got {new_admin_group_id}"
        );

        let groups = db.list_groups().await.context("failed to list groups")?;
        assert_eq!(
            groups.len(),
            3,
            "Expected three groups (admin plus two created), yet there are {} groups",
            groups.len()
        );
        eprintln!("DB has {} groups:", groups.len());
        for group in groups {
            println!("{group}");
            if group.id == new_admin_group_id {
                assert_eq!(group.parent, Some("admin".to_string()));
            }
            if group.id == new_group_id {
                let found = group
                    .members
                    .into_iter()
                    .any(|member| member.uname == new_user_uname);
                assert!(found, "new user {new_user_uname} wasn't in the group");
            }
        }

        // --- Sources --------------------------------------------------------
        let default_source_name = "default_source".to_string();
        let default_source_id = db
            .create_source(
                &default_source_name,
                Some("desc_default_source"),
                None,
                Local::now(),
                true,
                Some(false),
            )
            .await
            .context("failed to create source `default_source`")?;
        db.add_group_to_source(1, default_source_id)
            .await
            .context("failed to add group 1 to source `default_source`")?;

        let another_source_name = "another_source".to_string();
        let another_source_id = db
            .create_source(
                &another_source_name,
                Some("yet another file source"),
                None,
                Local::now(),
                true,
                Some(false),
            )
            .await
            .context("failed to create source `another_source`")?;

        let empty_source_name = "empty_source".to_string();
        db.create_source(
            &empty_source_name,
            Some("empty and unused file source"),
            None,
            Local::now(),
            true,
            Some(false),
        )
        .await
        .context("failed to create source `empty_source`")?;

        db.add_group_to_source(1, another_source_id)
            .await
            .context("failed to add group 1 to source `another_source`")?;

        let sources = db.list_sources().await.context("failed to list sources")?;
        eprintln!("DB has {} sources:", sources.len());
        for source in sources {
            println!("{source}");
            assert_eq!(source.files, 0);
            if source.id == default_source_id || source.id == another_source_id {
                assert_eq!(
                    source.groups, 1,
                    "source id {} should have exactly 1 group",
                    source.id
                );
            } else {
                assert_eq!(source.groups, 0, "groups should be zero (empty)");
            }
        }

        let uid = db
            .get_uid(&new_user_api_key)
            .await
            .context("failed to get user uid from apikey")?;
        let user_info = db
            .get_user_info(uid)
            .await
            .context("failed to get user's available groups and sources")?;
        assert!(user_info.sources.contains(&default_source_name));
        assert!(!user_info.is_admin);
        println!("UserInfoResponse: {user_info:?}");

        assert!(
            db.allowed_user_source(1, default_source_id)
                .await
                .with_context(|| format!(
                    "failed to check that user 1 has access to source {default_source_id}"
                ))?,
            "User 1 should have had access to source {default_source_id}"
        );

        assert!(
            !db.allowed_user_source(1, 5)
                .await
                .context("failed to check that user 1 has access to source 5")?,
            "User 1 should not have had access to source 5"
        );

        // --- Files ----------------------------------------------------------
        let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
        let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
        let elf_type = db
            .get_type_id_for_bytes(&test_elf)
            .await
            .context("failed to get file type id for elf")?;
        let known_type =
            KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;

        assert!(db
            .add_file(
                &test_elf_meta,
                known_type.clone(),
                1,
                default_source_id,
                elf_type,
                None
            )
            .await
            .context("failed to insert a test elf")?);
        eprintln!("Added ELF to the DB");

        // The same bytes under another name and source must not become a second file.
        let mut test_elf_meta_different_name = test_elf_meta.clone();
        test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
        assert!(!db
            .add_file(
                &test_elf_meta_different_name,
                known_type,
                1,
                another_source_id,
                elf_type,
                None
            )
            .await
            .context("failed to insert a test elf again for a different source")?);

        let sources = db
            .list_sources()
            .await
            .context("failed to re-list sources")?;
        eprintln!(
            "DB has {} sources, and a file was added twice:",
            sources.len()
        );
        println!("We should have two sources with one file each, yet only one ELF.");
        for source in sources {
            println!("{source}");
            if source.id == default_source_id || source.id == another_source_id {
                assert_eq!(source.files, 1);
            } else {
                assert_eq!(source.files, 0, "unused sources should have no files");
            }
        }

        assert!(!db
            .get_user_sources(1)
            .await
            .context("failed to get user 1's sources")?
            .sources
            .is_empty());

        let file_types_counts = db
            .file_types_counts()
            .await
            .context("failed to get file types and counts")?;
        for (name, count) in file_types_counts {
            println!("{name}: {count}");
            assert_eq!(name, "ELF");
            assert_eq!(count, 1);
        }

        // --- Similarity search ---------------------------------------------
        // Appending random bytes gives a file similar, but not identical, to the ELF.
        let mut test_elf_modified = test_elf;
        test_elf_modified.extend_from_slice(Uuid::new_v4().as_bytes());
        let similarity_request = generate_similarity_request(&test_elf_modified);
        let similarity_response = db
            .find_similar_samples(1, &similarity_request.hashes)
            .await
            .context("failed to get similarity response")?;
        eprintln!("Similarity response: {similarity_response:?}");
        let similarity_response = similarity_response
            .first()
            .context("similarity search returned no samples")?;
        assert_eq!(
            similarity_response.sha256, test_elf_meta.sha256,
            "Similarity response should have had the hash of the original ELF"
        );
        for (algo, sim) in similarity_response.algorithms.iter() {
            match *algo {
                malwaredb_api::SimilarityHashType::LZJD => {
                    assert!(*sim > 0.0f32);
                }
                malwaredb_api::SimilarityHashType::SSDeep => {
                    assert!(*sim > 80.0f32);
                }
                malwaredb_api::SimilarityHashType::TLSH => {
                    assert!(*sim <= 20f32);
                }
                _ => {}
            }
        }

        // --- Retrieval by hash ---------------------------------------------
        let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1)
            .context("failed to get `HashType::SHA1` from string")?;
        let response_sha256 = db
            .retrieve_sample(1, &test_elf_hashtype)
            .await
            .context("could not get SHA-256 hash from test sample")?;
        assert_eq!(response_sha256, test_elf_meta.sha256);

        let test_bogus_hash = HashType::try_from(String::from(
            "d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0",
        ))
        .context("failed to get `HashType` from static string")?;
        assert!(
            db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
            "Getting a file with a bogus hash should have failed."
        );

        let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
        let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
        let pdf_type = db
            .get_type_id_for_bytes(&test_pdf)
            .await
            .context("failed to get file type id for pdf")?;
        let known_type =
            KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
        assert!(db
            .add_file(
                &test_pdf_meta,
                known_type,
                1,
                default_source_id,
                pdf_type,
                None
            )
            .await
            .context("failed to insert a test pdf")?);
        eprintln!("Added PDF to the DB");

        let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
        let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
        let rtf_type = db
            .get_type_id_for_bytes(&test_rtf)
            .await
            .context("failed to get file type id for rtf")?;
        let known_type =
            KnownType::new(&test_rtf).context("failed to parse rtf from test crate's test data")?;
        assert!(db
            .add_file(
                &test_rtf_meta,
                known_type,
                1,
                default_source_id,
                rtf_type,
                None
            )
            .await
            .context("failed to insert a test rtf")?);
        eprintln!("Added RTF to the DB");

        // --- Reports --------------------------------------------------------
        let report = db
            .get_sample_report(1, &HashType::try_from(test_rtf_meta.sha256.clone())?)
            .await
            .context("failed to get report for test rtf")?;
        assert!(report
            .clone()
            .filecommand
            .unwrap()
            .contains("Rich Text Format"));
        println!("Report: {report}");

        // An unknown user must not be able to fetch the report.
        assert!(db
            .get_sample_report(999, &HashType::try_from(test_rtf_meta.sha256)?)
            .await
            .is_err());

        #[cfg(feature = "vt")]
        {
            assert!(report.vt.is_some());
            let files_needing_vt = db
                .files_without_vt_records(10)
                .await
                .context("failed to get files without VT records")?;
            assert!(files_needing_vt.len() > 2);
            println!(
                "{} files needing VT data: {files_needing_vt:?}",
                files_needing_vt.len()
            );
        }

        #[cfg(not(feature = "vt"))]
        {
            assert!(report.vt.is_none());
        }

        // --- Misc admin queries --------------------------------------------
        let reset = db
            .reset_api_keys()
            .await
            .context("failed to reset all API keys")?;
        eprintln!("Cleared {reset} api keys.");

        let db_info = db.db_info().await.context("failed to get database info")?;
        eprintln!("DB Info: {db_info:?}");

        let data_types = db
            .get_known_data_types()
            .await
            .context("failed to get data types")?;
        for data_type in data_types {
            println!("{data_type:?}");
        }

        let sources = db
            .list_sources()
            .await
            .context("failed to list sources second time")?;
        eprintln!("DB has {} sources:", sources.len());
        for source in sources {
            println!("{source}");
        }

        let file_types_counts = db
            .file_types_counts()
            .await
            .context("failed to get file types and counts")?;
        for (name, count) in file_types_counts {
            println!("{name}: {count}");
            assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
        }

        // --- Fat Mach-O -----------------------------------------------------
        let fatmacho =
            include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
                .to_vec();
        let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
        let fatmacho_type = db
            .get_type_id_for_bytes(&fatmacho)
            .await
            .context("failed to get file type for Fat Mach-O")?;
        let known_type = KnownType::new(&fatmacho)
            .context("failed to parse Fat Mach-O from test crate's test data")?;
        assert!(db
            .add_file(
                &fatmacho_meta,
                known_type,
                1,
                default_source_id,
                fatmacho_type,
                None
            )
            .await
            .context("failed to insert a test Fat Mach-O")?);
        eprintln!("Added Fat Mach-O to the DB");

        let file_types_counts = db
            .file_types_counts()
            .await
            .context("failed to get file types and counts")?;
        for (name, count) in &file_types_counts {
            println!("{name}: {count}");
        }
        // The fat binary bundles four architectures (per its file name). The expected count
        // of 4 presumably means each embedded binary is stored as its own Mach-O file.
        // Compare `Option`s so a missing entry reports this message instead of panicking
        // in `unwrap()`.
        assert_eq!(
            file_types_counts.get("Mach-O").copied(),
            Some(4),
            "Expected 4 Mach-O files, got {:?}",
            file_types_counts.get("Mach-O")
        );

        // --- Labels ---------------------------------------------------------
        const MALWARE_LABEL: &str = "malware";
        const RANSOMWARE_LABEL: &str = "ransomware";

        let malware_label_id = db
            .create_label(MALWARE_LABEL, None)
            .await
            .context("failed to create first label")?;
        let ransomware_label_id = db
            .create_label(RANSOMWARE_LABEL, Some(malware_label_id as i64))
            .await
            .context("failed to create malware sub-label")?;

        let labels = db.get_labels().await.context("failed to get labels")?;
        assert_eq!(labels.len(), 2);
        for label in labels.0 {
            if label.name == RANSOMWARE_LABEL {
                assert_eq!(label.id, ransomware_label_id);
                assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
            }
        }

        db.reset_own_api_key(0)
            .await
            .context("failed to clear own API key uid 0")?;

        db.deactivate_user(0)
            .await
            .context("failed to clear password and API key for uid 0")?;

        Ok(())
    }
}