#![allow(
clippy::cast_sign_loss,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap
)]
#[cfg(any(test, feature = "admin"))]
use super::{admin, hash_password};
use super::{
random_bytes_api_key, DatabaseInformation, FileAddedResult, MDBConfig, Migration,
PARTIAL_SEARCH_LIMIT,
};
use crate::crypto::{EncryptionOption, FileEncryption};
use crate::db::types::{FileMetadata, FileType};
use malwaredb_api::{
digest::HashType, GetUserInfoResponse, Label, Labels, PartialHashSearchType, SearchRequest,
SearchResponse, SearchType, SimilarSample, SimilarityHashType, SourceInfo, Sources,
};
use malwaredb_types::KnownType;
use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
use std::path::PathBuf;
use anyhow::{anyhow, bail, Context, Result};
use argon2::{Argon2, PasswordHash, PasswordVerifier};
#[cfg(any(test, feature = "admin"))]
use chrono::Local;
use deadpool_postgres::tokio_postgres::types::{FromSql, ToSql};
use deadpool_postgres::tokio_postgres::Config;
use deadpool_postgres::{GenericClient, Manager, ManagerConfig, Pool, RecyclingMethod};
use humansize::{make_format, DECIMAL};
#[cfg(feature = "vt")]
use malwaredb_virustotal::filereport::ScanResultAttributes;
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::CertificateDer;
use rustls::ClientConfig;
use rustls_platform_verifier::BuilderVerifierExt;
use semver::Version;
use serde_json::json;
use tracing::{debug, error, instrument, warn};
use uuid::Uuid;
/// Schema-creation SQL, executed by `Postgres::new` when the `file` table does not yet exist.
const PG_SQL: &str = include_str!("malwaredb_pg.sql");
/// Which Postgres similarity-hashing extension functions were found in the catalog.
struct WhichHashesInstalled {
    // lzjd_compare(text,text) is present.
    pub lzjd: bool,
    // tlsh_compare(text,text) is present.
    pub tlsh: bool,
    // fuzzy_hash_compare(text,text) is present (ssdeep).
    pub ssdeep: bool,
}
impl WhichHashesInstalled {
    /// True only when every similarity-hash extension (LZJD, ssdeep, TLSH) is present.
    pub fn all_installed(&self) -> bool {
        [self.lzjd, self.ssdeep, self.tlsh].iter().all(|&flag| flag)
    }
}
impl Display for WhichHashesInstalled {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let none_installed = !(self.lzjd || self.ssdeep || self.tlsh);
        if self.all_installed() {
            write!(f, "all similarity hashing extensions are installed")
        } else if none_installed {
            write!(f, "no similarity hashing extensions are installed")
        } else {
            // Partial install: list each present extension, then "installed".
            for (present, label) in [
                (self.lzjd, "lzjd "),
                (self.ssdeep, "ssdeep "),
                (self.tlsh, "tlsh "),
            ] {
                if present {
                    write!(f, "{label}")?;
                }
            }
            write!(f, "installed")
        }
    }
}
/// Discriminates rows in the `pagination` table by the kind of search a
/// continuation token belongs to (maps to the Postgres enum
/// `search_pagination_type`, lowercased variant names).
#[derive(Debug, Copy, Clone, Eq, PartialEq, ToSql, FromSql)]
#[postgres(name = "search_pagination_type", rename_all = "lowercase")]
enum PaginationType {
    /// Partial/metadata search continuations (used by `partial_search`).
    Search,
    /// Similarity-hash search continuations (variant not referenced in this chunk).
    SimilaritySearch,
}
/// Postgres-backed storage handle for MalwareDB.
pub struct Postgres {
    // Deadpool connection pool; every method checks out a connection from here.
    pool: Pool,
    // True when the lzjd, ssdeep, and tlsh comparison functions all exist;
    // probed once during `new` via `check_similarity_functions`.
    has_hash_extensions: bool,
}
impl Postgres {
/// Connect to Postgres over TLS, bootstrap the schema on first run, check the
/// schema version against this binary, and detect similarity-hash extensions.
///
/// `server_ca`: optional PEM CA certificate for the server; when absent the
/// platform trust store is used instead.
pub async fn new(connection_string: &str, server_ca: Option<PathBuf>) -> Result<Self> {
    // Fast recycling: reuse pooled connections without a validation round-trip.
    let mgr_config = ManagerConfig {
        recycling_method: RecyclingMethod::Fast,
    };
    // Install the AWS-LC crypto provider once, process-wide, unless another
    // default provider was already registered.
    if rustls::crypto::CryptoProvider::get_default().is_none() {
        if let Err(_e) = rustls::crypto::aws_lc_rs::default_provider().install_default() {
            warn!("Failed to install AWS-LC crypto provider");
        }
    }
    // TLS trust anchors: either the caller-supplied CA file, or the platform verifier.
    let config = if let Some(ca_path) = server_ca {
        let mut certs = rustls::RootCertStore::empty();
        certs.add(CertificateDer::from_pem_file(&ca_path).context(format!(
            "Failed to read/parse Postgres CA certificate file {}",
            ca_path.display()
        ))?)?;
        ClientConfig::builder()
            .with_root_certificates(certs)
            .with_no_client_auth()
    } else {
        let arc_crypto_provider =
            std::sync::Arc::new(rustls::crypto::aws_lc_rs::default_provider());
        ClientConfig::builder_with_provider(arc_crypto_provider)
            .with_safe_default_protocol_versions()
            .context("TLS client configuration error for Postgres")?
            .with_platform_verifier()
            .context("TLS platform verifier error for Postgres")?
            .with_no_client_auth()
    };
    let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config);
    let pg_config = connection_string.parse::<Config>()?;
    let mgr = Manager::from_config(pg_config, tls, mgr_config);
    // Pool sized to the CPU count, capped at 16 connections.
    let pool = Pool::builder(mgr)
        .max_size(num_cpus::get().min(16))
        .build()?;
    let client = pool.get().await?;
    // Fresh database detection: if the `file` table is missing, run the full
    // schema script and record this binary's version.
    let exists = client
        .query_one(
            "SELECT EXISTS (
SELECT FROM
information_schema.tables
WHERE
table_schema LIKE 'public' AND
table_type LIKE 'BASE TABLE' AND
table_name = 'file'
)",
            &[],
        )
        .await?;
    let exists: bool = exists.get(0);
    if !exists {
        client
            .batch_execute(PG_SQL)
            .await
            .context("failed to create postgres tables")?;
        client
            .execute("update mdbconfig set version = $1", &[&crate::MDB_VERSION])
            .await?;
    }
    client.execute("SET TIME ZONE 'UTC'", &[]).await?;
    // Compare the schema's recorded version with this binary: a newer major
    // schema is fatal, an older one only warns (migration may be needed).
    let mdb_db_version = client
        .query_one("select version from mdbconfig", &[])
        .await?;
    let mdb_db_version: String = mdb_db_version.get(0);
    let mdb_db_version = Version::parse(&mdb_db_version).map_err(|_| {
        anyhow!("Failed to parse MalwareDB version in the database: {mdb_db_version}")
    })?;
    if mdb_db_version.major > crate::MDB_VERSION_SEMVER.major {
        bail!("MalwareDB database schema {mdb_db_version} is newer than this binary version {}, a new binary is likely needed.", crate::MDB_VERSION);
    }
    if mdb_db_version.major < crate::MDB_VERSION_SEMVER.major {
        warn!("MalwareDB database schema {mdb_db_version} is older than this binary version {}, a migration may be needed.", crate::MDB_VERSION);
    }
    // Build a provisional handle so check_similarity_functions can run, then
    // rebuild with the real extension flag.
    let pgclient = Self {
        pool,
        has_hash_extensions: false,
    };
    let has_hash_extensions = pgclient.check_similarity_functions().await?.all_installed();
    Ok(Self {
        pool: pgclient.pool,
        has_hash_extensions,
    })
}
#[cfg(test)]
pub(crate) async fn delete(&self) -> Result<()> {
    // Test-only: wipe and recreate the public schema, dropping every table.
    self.pool
        .get()
        .await?
        .batch_execute("DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
        .await?;
    Ok(())
}
#[cfg(feature = "vt")]
pub async fn enable_vt_upload(&self) -> Result<()> {
    // Turn on the global flag that permits forwarding samples to VirusTotal.
    self.pool
        .get()
        .await?
        .execute("update mdbconfig set send_samples_to_vt = true", &[])
        .await?;
    Ok(())
}
#[cfg(feature = "vt")]
pub async fn disable_vt_upload(&self) -> Result<()> {
    // Turn off the global flag that permits forwarding samples to VirusTotal.
    self.pool
        .get()
        .await?
        .execute("update mdbconfig set send_samples_to_vt = false", &[])
        .await?;
    Ok(())
}
#[cfg(feature = "vt")]
pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
    // Up to `limit` SHA-256 values for files that have no VirusTotal record yet.
    let client = self.pool.get().await?;
    let capped = i64::from(limit);
    let rows = client
        .query("select file.sha256 from file where file.id not in (select fileid from vtdata) limit $1", &[&capped])
        .await?;
    let mut hashes = Vec::with_capacity(rows.len());
    for row in &rows {
        let sha256: String = row.get(0);
        hashes.push(sha256);
    }
    Ok(hashes)
}
#[cfg(feature = "vt")]
pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
    // Persist a VirusTotal scan report for the file matching the report's SHA-256.
    let client = self.pool.get().await?;
    let row = client
        .query_one("select id from file where sha256 = $1", &[&results.sha256])
        .await?;
    let file_id: i64 = row.get(0);
    // The per-engine results are stored as a JSON string value.
    let detail = serde_json::Value::from(serde_json::to_string(&results.last_analysis_results)?);
    let malicious_hits = results.last_analysis_stats.malicious as i32;
    let engines_total = results.last_analysis_stats.av_count() as i32;
    client
        .execute(
            "insert into vtdata(fileid, tstamp, hits, total, vtdetail) values($1, $2, $3, $4, $5)",
            &[&file_id, &results.last_analysis_date, &malicious_hits, &engines_total, &detail],
        )
        .await?;
    Ok(())
}
#[cfg(feature = "vt")]
pub async fn get_vt_stats(&self) -> Result<super::VtStats> {
    // Summarize VirusTotal coverage: clean reports, reports with hits, and
    // files with no report at all.
    let client = self.pool.get().await?;
    let clean: i64 = client
        .query_one("select count(1) from vtdata where hits = 0", &[])
        .await?
        .get(0);
    let with_hits: i64 = client
        .query_one("select count(1) from vtdata where hits > 0", &[])
        .await?
        .get(0);
    let missing: i64 = client
        .query_one(
            "select count(1) from file where file.id not in (select fileid from vtdata)",
            &[],
        )
        .await?
        .get(0);
    Ok(super::VtStats {
        clean_records: clean as u32,
        hits_records: with_hits as u32,
        files_without_records: missing as u32,
    })
}
#[cfg(feature = "yara")]
pub async fn add_yara_search(
    &self,
    uid: u32,
    yara_string: &str,
    yara_bytes: &[u8],
) -> Result<Uuid> {
    // Queue a new YARA search owned by `uid`; the UUIDv7 id is time-ordered,
    // so workers processing "order by id" pick up oldest searches first.
    let client = self.pool.get().await?;
    let owner = uid as i32;
    let search_uuid = Uuid::now_v7();
    client.execute("insert into yara_search(id, yara_text, yara_compiled, userid) values($1, $2, $3, $4)", &[&search_uuid, &yara_string, &yara_bytes, &owner]).await?;
    Ok(search_uuid)
}
#[cfg(feature = "yara")]
pub async fn get_unfinished_yara_tasks(&self) -> Result<Vec<crate::yara::YaraTask>> {
    // Oldest-first batch of YARA searches not yet completed, capped at the
    // configured number of worker processes.
    let client = self.pool.get().await?;
    let batch = crate::yara::MAX_YARA_PROCESSES as i64;
    let rows = client.query("select id, yara_text, yara_compiled, userid, last_fileid from yara_search where completed is null order by id asc limit $1", &[&batch]).await?;
    let tasks = rows
        .iter()
        .map(|row| {
            let user_id: i32 = row.get(3);
            let last_file_id: Option<i64> = row.get(4);
            crate::yara::YaraTask {
                id: row.get(0),
                yara_string: row.get(1),
                yara_bytes: row.get(2),
                last_file_id: last_file_id.map(|id| id as u64),
                user_id: user_id as u32,
            }
        })
        .collect();
    Ok(tasks)
}
#[cfg(feature = "yara")]
pub async fn add_yara_match(&self, id: Uuid, rule_name: &str, file_sha256: &str) -> Result<()> {
    // Matches are stored as "rule|sha256" strings appended to an array column.
    let client = self.pool.get().await?;
    let entry = format!("{rule_name}|{file_sha256}");
    client
        .execute(
            "update yara_search set results = array_append(results, $1) where id = $2",
            &[&entry, &id],
        )
        .await?;
    Ok(())
}
#[cfg(feature = "yara")]
pub async fn mark_yara_task_as_finished(&self, id: Uuid) -> Result<()> {
    // Stamp the search's completion time (UTC) so cleanup() can reap it later.
    self.pool
        .get()
        .await?
        .execute(
            "update yara_search set completed = timezone('utc', now()) where id = $1",
            &[&id],
        )
        .await?;
    Ok(())
}
#[cfg(feature = "yara")]
pub async fn yara_add_next_file_id(&self, id: Uuid, file_id: u64) -> Result<()> {
    // Persist the scan cursor so an interrupted search can resume where it stopped.
    let cursor = file_id as i64;
    self.pool
        .get()
        .await?
        .execute(
            "update yara_search set last_fileid = $1 where id = $2",
            &[&cursor, &id],
        )
        .await?;
    Ok(())
}
/// Fetch the stored "rule|sha256" match strings for this search (scoped to
/// the requesting user) and group the decoded SHA-256 hashes by rule name.
///
/// Fix: the previous version split on '|' and indexed `parts[1]`
/// unconditionally, which panics on a malformed stored entry; `split_once`
/// now turns that case into a proper error.
#[cfg(feature = "yara")]
pub async fn get_yara_results(
    &self,
    id: Uuid,
    user_id: u32,
) -> Result<malwaredb_api::YaraSearchResponse> {
    let client = self.pool.get().await?;
    let user_id = user_id as i32;
    let result = client
        .query_one(
            "select results from yara_search where id = $1 and userid = $2",
            &[&id, &user_id],
        )
        .await?;
    let entries: Option<Vec<String>> = result.get(0);
    let mut results_map: HashMap<String, Vec<HashType>> = HashMap::new();
    for entry in entries.unwrap_or_default() {
        // Entries are written by add_yara_match() as "<rule name>|<sha256 hex>".
        let (rule_name, hash_hex) = entry
            .split_once('|')
            .with_context(|| format!("Malformed yara result entry: {entry}"))?;
        let file_sha256 = hex::decode(hash_hex)?;
        results_map
            .entry(rule_name.to_string())
            .or_default()
            .push(HashType::SHA256(file_sha256.try_into()?));
    }
    Ok(malwaredb_api::YaraSearchResponse {
        results: results_map,
    })
}
pub async fn get_config(&self) -> Result<MDBConfig> {
    // Load the single-row server configuration from mdbconfig.
    let client = self.pool.get().await?;
    let row = client
        .query_one(
            "select name, compress, send_samples_to_vt, keep_unknown_files, defaultKey from mdbconfig",
            &[],
        )
        .await?;
    let default_key: Option<i32> = row.get(4);
    Ok(MDBConfig {
        name: row.get(0),
        compression: row.get(1),
        send_samples_to_vt: row.get(2),
        keep_unknown_files: row.get(3),
        default_key: default_key.map(|key| key as u32),
    })
}
pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
    // Verify the stored Argon2 hash, then return the user's API key —
    // minting and persisting a new one if the account has none yet.
    let client = self.pool.get().await?;
    let row = client
        .query_one("select password from person where uname = $1", &[&uname])
        .await?;
    let stored: Option<String> = row.get(0);
    match stored {
        Some(hashed) => {
            let parsed = PasswordHash::new(&hashed)?;
            Argon2::default().verify_password(password.as_ref(), &parsed)?;
        }
        None => bail!("Password not set"),
    }
    let row = client
        .query_one("select apikey from person where uname = $1", &[&uname])
        .await?;
    let existing: Option<String> = row.get(0);
    match existing {
        Some(key) => Ok(key),
        None => {
            let key = random_bytes_api_key();
            client
                .execute(
                    "update person set apikey = $1 where uname = $2",
                    &[&key, &uname],
                )
                .await?;
            Ok(key)
        }
    }
}
pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
    // Map an API key to its owner's user id; errors if the key is unknown.
    let client = self.pool.get().await?;
    let row = client
        .query_one("SELECT id from person where apikey = $1", &[&apikey])
        .await?;
    let id: i32 = row.get(0);
    Ok(id as u32)
}
pub async fn db_info(&self) -> Result<DatabaseInformation> {
    // Headline statistics about the database: size, server version plus the
    // similarity-extension status, and row counts of the main tables.
    let client = self.pool.get().await?;
    let size: String = client
        .query_one(
            "SELECT pg_size_pretty(pg_database_size(current_database()))",
            &[],
        )
        .await?
        .get(0);
    let version: String = client.query_one("SELECT version();", &[]).await?.get(0);
    let extensions = self.check_similarity_functions().await?;
    let num_users: i64 = client
        .query_one("SELECT count(1) from person", &[])
        .await?
        .get(0);
    let num_files: i64 = client
        .query_one("SELECT count(1) from file", &[])
        .await?
        .get(0);
    let num_groups: i64 = client
        .query_one("SELECT count(1) from grp", &[])
        .await?
        .get(0);
    let num_sources: i64 = client
        .query_one("SELECT count(1) from source", &[])
        .await?
        .get(0);
    Ok(DatabaseInformation {
        version: format!("{version}, {extensions}"),
        size,
        num_files: num_files as u64,
        num_users: num_users as u32,
        num_groups: num_groups as u32,
        num_sources: num_sources as u32,
    })
}
pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
    // Collect a user's profile: account row, group memberships (admin when
    // in group 0 or a direct child of it), and sources those groups expose.
    let client = self.pool.get().await?;
    let uid = uid as i32;
    let person = client
        .query_one(
            "select uname, created, readonly from person where id = $1",
            &[&uid],
        )
        .await?;
    let username = person.get(0);
    let created = person.get(1);
    let is_readonly = person.get(2);
    let group_rows = client.query("select grp.name, grp.id, grp.parent from grp, usergroup where grp.id = usergroup.gid and usergroup.pid = $1", &[&uid]).await?;
    let mut groups = Vec::with_capacity(group_rows.len());
    let mut is_admin = false;
    for row in group_rows {
        groups.push(row.get(0));
        let gid: i32 = row.get(1);
        let parent: Option<i32> = row.get(2);
        // Group 0 is the admin group; membership in it, or in any group
        // whose parent is 0, confers admin rights.
        if gid == 0 || parent == Some(0) {
            is_admin = true;
        }
    }
    let source_rows = client.query("select source.name from source, usergroup, groupsource where source.id = groupsource.sourceid and groupsource.gid = usergroup.gid and usergroup.pid = $1", &[&uid]).await?;
    let mut sources = Vec::with_capacity(source_rows.len());
    for row in source_rows {
        sources.push(row.get(0));
    }
    Ok(GetUserInfoResponse {
        id: uid as u32,
        username,
        groups,
        sources,
        is_admin,
        created,
        is_readonly,
    })
}
pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
    // Every source reachable through any of the user's group memberships.
    let client = self.pool.get().await?;
    let uid = uid as i32;
    let rows = client.query("select source.id, source.name, source.description, source.url, source.firstacquisition, source.malicious from source, usergroup, groupsource where source.id = groupsource.sourceid and groupsource.gid = usergroup.gid and usergroup.pid = $1", &[&uid]).await?;
    let mut sources = Vec::with_capacity(rows.len());
    for row in rows {
        let id: i32 = row.get(0);
        let name: String = row.get(1);
        let description: Option<String> = row.get(2);
        let url: Option<String> = row.get(3);
        let first_acquisition: chrono::DateTime<chrono::Utc> = row.get(4);
        let malicious: Option<bool> = row.get(5);
        sources.push(SourceInfo {
            id: id as u32,
            name,
            description,
            url,
            first_acquisition,
            malicious,
        });
    }
    Ok(Sources {
        sources,
        message: None,
    })
}
pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
    // NULL out the key; `authenticate` mints a new one on the next login.
    let owner = uid as i32;
    self.pool
        .get()
        .await?
        .execute("update person set apikey = NULL where id = $1", &[&owner])
        .await?;
    Ok(())
}
pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
    // All registered file types with their magic-byte prefixes, ordered by id.
    let client = self.pool.get().await?;
    let rows = client
        .query(
            "select id, name, description, magic, executable from filetype order by id",
            &[],
        )
        .await?;
    let mut file_types = Vec::with_capacity(rows.len());
    for row in rows {
        let id: i32 = row.get(0);
        let executable: Option<bool> = row.get(4);
        file_types.push(FileType {
            id: id as u32,
            name: row.get(1),
            description: row.get(2),
            magic: row.get(3),
            // A NULL executable column is treated as "not executable".
            executable: executable.unwrap_or(false),
        });
    }
    Ok(file_types)
}
pub async fn get_labels(&self) -> Result<Labels> {
    // All labels, each with its (optional) parent label's name via a self-join.
    let client = self.pool.get().await?;
    let rows = client.query("select label.id, label.name, parent.name from label left outer join label parent on (parent.id = label.parent)", &[]).await?;
    let mut labels = Vec::with_capacity(rows.len());
    for row in rows {
        let id: i64 = row.get(0);
        labels.push(Label {
            id: id as u64,
            name: row.get(1),
            parent: row.get(2),
        });
    }
    Ok(Labels(labels))
}
/// Identify the file type id by matching the file's leading bytes against
/// the magic prefixes registered in the database, optionally falling back to
/// the catch-all "unknown" type when the server keeps unrecognized files.
///
/// Fix: replaced the `match … Err(e) => return Err(e)` boilerplate with `?`.
pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
    let db_file_types = self.get_known_data_types().await?;
    for db_file_type in &db_file_types {
        for magic in &db_file_type.magic {
            // Empty magic entries never match anything.
            if !magic.is_empty() && data.starts_with(magic) {
                return Ok(db_file_type.id);
            }
        }
    }
    // No magic matched: keep the file under the "unknown" type if configured to.
    let config = self.get_config().await?;
    if config.keep_unknown_files {
        debug!("Keeping unknown file");
        if let Some(unknown) = db_file_types
            .iter()
            .find(|&db_file_type| db_file_type.name.eq_ignore_ascii_case("unknown"))
        {
            return Ok(unknown.id);
        }
    }
    bail!("File type not found")
}
pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
    // A user may touch a source iff one of their groups is linked to it.
    let client = self.pool.get().await?;
    let uid = uid as i32;
    let sid = sid as i32;
    let rows = client
        .query(
            "select usergroup.gid from usergroup, groupsource where usergroup.gid = groupsource.gid and usergroup.pid = $1 and groupsource.sourceid = $2",
            &[&uid, &sid],
        )
        .await?;
    Ok(!rows.is_empty())
}
pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
    // Admin = member of group 0, or of any group whose parent is group 0.
    let client = self.pool.get().await?;
    let uid = uid as i32;
    let rows = client
        .query(
            "select usergroup.gid from usergroup join grp on (usergroup.gid = grp.id) where usergroup.pid = $1 and (usergroup.gid = 0 or grp.parent = 0)",
            &[&uid],
        )
        .await?;
    Ok(!rows.is_empty())
}
/// Store a file (and any embedded child files) for `uid` under source `sid`.
///
/// Checks source permission and the account's read-only flag, then inserts
/// the parent and all extractable children in one transaction, rolling back
/// on any failure. Returns the parent's file id and whether it was new.
#[instrument]
pub async fn add_file(
    &self,
    meta: &FileMetadata,
    known_type: KnownType<'_>,
    uid: u32,
    sid: u32,
    ftype: u32,
    parent: Option<u64>,
) -> Result<FileAddedResult> {
    // The uploader must be allowed to write to this source at all.
    if !self.allowed_user_source(uid, sid).await? {
        bail!("uid {uid} not allowed to upload to sid {sid}");
    }
    let mut client = self.pool.get().await?;
    let uid_signed = uid as i32;
    // Read-only accounts may query but never insert.
    let readonly_query = client
        .query_one("select readonly from person where id = $1", &[&uid_signed])
        .await?;
    let readonly: bool = readonly_query.get(0);
    if readonly {
        bail!("user {uid} is read-only!");
    }
    let transaction = client.transaction().await?;
    // Clone kept so children can still be walked after the parent insert
    // consumes `known_type`.
    let known_type_clone = known_type.clone();
    let (fid, new_file) = match self
        .add_file_transaction(meta, known_type, uid, sid, ftype, parent, &transaction)
        .await
    {
        Ok(result) => result,
        Err(e) => {
            error!("Postgres: Insertion for parent file failed {e}, aborting transaction");
            transaction.rollback().await?;
            return Err(e);
        }
    };
    // Only a brand-new parent triggers child extraction; children whose
    // type cannot be determined are silently skipped.
    if new_file {
        if let Some(children) = known_type_clone.children() {
            for child in children {
                let child_meta = FileMetadata::new(child.contents(), None);
                if let Ok(type_id) = self.get_type_id_for_bytes(child.contents()).await {
                    if let Err(e) = self
                        .add_file_transaction(
                            &child_meta,
                            child,
                            uid,
                            sid,
                            type_id,
                            Some(fid),
                            &transaction,
                        )
                        .await
                    {
                        error!("Postgres: failed to insert record for child file: {e}");
                        transaction.rollback().await?;
                        return Err(e);
                    }
                }
            }
        }
    }
    transaction.commit().await?;
    Ok(FileAddedResult {
        file_id: fid,
        is_new: new_file,
    })
}
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
#[instrument]
async fn add_file_transaction(
&self,
meta: &FileMetadata,
known_type: KnownType<'_>,
uid: u32,
sid: u32,
ftype: u32,
parent: Option<u64>,
transaction: &deadpool_postgres::Transaction<'_>,
) -> Result<(u64, bool)> {
let result = transaction
.query("select id from file where sha512 = $1", &[&meta.sha512])
.await?;
let (fid, new_file) = if result.is_empty() {
let size = meta.size as i64;
let creation_date = known_type.created();
let ftype = ftype as i32;
transaction.execute("insert into file(sha1, sha256, sha384, sha512, md5, lzjd, ssdeep, tlsh, humanhash, filecommand, createddate, filetypeid, size, entropy)\
values($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)", &[&meta.sha1, &meta.sha256, &meta.sha384, &meta.sha512, &meta.md5, &meta.lzjd, &meta.ssdeep, &meta.tlsh, &meta.humanhash, &meta.file_command, &creation_date, &ftype, &size, &meta.entropy]).await?;
let result = transaction
.query_one("select id from file where sha512 = $1", &[&meta.sha512])
.await?;
let fid: i64 = result.get(0);
if let Some(parent_id) = parent {
let parent_id = parent_id as i64;
transaction
.execute(
"update file set parent = $1 where id = $2",
&[&parent_id, &fid],
)
.await?;
}
if known_type.is_doc() {
if let Some(doc) = known_type.doc() {
let pages = doc.pages() as i32; transaction.execute("insert into pdf(fileid, title, author, pages, javascript, forms) values($1, $2, $3, $4, $5, $6)",
&[&fid, &doc.title(), &doc.author(), &pages, &doc.has_javascript(), &doc.has_form()]).await?;
}
} else if known_type.is_exec() {
if let Some(exec) = known_type.exec() {
let sections = exec.num_sections() as i32;
let section_entropies: Option<Vec<f32>> = exec
.sections()
.map(|s| s.iter().map(|sn| sn.entropy).collect());
let section_exec: Option<Vec<bool>> = exec
.sections()
.map(|s| s.iter().map(|sn| sn.is_executable).collect());
let architecture = exec.architecture().map(|a| a.to_string());
transaction
.execute(
"insert into executable(fileid, architecture, operatingsystem, sections, sectionentropies, sectionexec, importhash, importhashfuzzy) values($1, $2, $3, $4, $5, $6, $7, $8)",
&[&fid, &architecture, &exec.operating_system().to_string(), §ions, §ion_entropies, §ion_exec, &exec.import_hash(), &exec.fuzzy_imports()],
)
.await?;
}
}
(fid, true)
} else if let Some(fid) = result.first() {
let fid: i64 = fid.get(0);
(fid, false)
} else {
eprintln!(
"Postgres: data integrity failure, {} entries for {}",
result.len(),
hex::encode(&meta.sha512)
);
bail!(
"Postgres: data integrity failure, {} entries for {}",
result.len(),
hex::encode(&meta.sha512)
);
};
let sid = sid as i32;
let result = transaction
.query_one(
"select count(*) from filesource where fileid = $1 and sourceid = $2",
&[&fid, &sid],
)
.await?;
let count: i64 = result.get(0);
let uid = uid as i32;
if count == 0 {
transaction
.execute(
"insert into filesource(fileid, sourceid, userid) values($1, $2, $3)",
&[&fid, &sid, &uid],
)
.await?;
}
let result = transaction
.query_one(
"select filename from filesource where fileid = $1 and sourceid = $2",
&[&fid, &sid],
)
.await?;
let file_names: Option<Vec<String>> = result.get(0);
let mut file_names = file_names.unwrap_or_default();
if let Some(fname) = &meta.name {
if !file_names.contains(fname) {
file_names.push(fname.clone());
transaction
.execute(
"update filesource set filename = $1 where fileid = $2 and sourceid = $3",
&[&file_names, &fid, &sid],
)
.await?;
}
}
Ok((fid as u64, new_file))
}
#[allow(clippy::too_many_lines)]
pub async fn partial_search(&self, uid: u32, search: SearchRequest) -> Result<SearchResponse> {
let client = self.pool.get().await?;
let uid = uid as i32;
if !search.is_valid() {
return Ok(SearchResponse {
hashes: vec![],
pagination: None,
total_results: 0,
message: Some(String::from("Invalid search request")),
});
}
let (parameters, pagination_uuid, next_file_id, total_results, did_query) = match search
.search
{
SearchType::Search(parameters) => (parameters, Uuid::new_v4(), 0, 0, false),
SearchType::Continuation(pagination_uuid) => {
match client.query_one("select query, next_fileid, total_results from pagination where id = $1 and userid = $2 and type = $3", &[&pagination_uuid, &uid, &PaginationType::Search]).await {
Ok(row) => {
let parameters: String = row.get(0);
let next_file_id: i64 = row.get(1);
let total_results: i64 = row.get(2);
(
serde_json::from_str(¶meters)
.context("Failed to deserialize continuation parameters")?,
pagination_uuid,
next_file_id,
total_results,
true,
)
},
Err(e) => {
warn!("Failed to get continuation result for user {uid}: {e}");
return Ok(SearchResponse {
hashes: vec![],
pagination: None,
total_results: 0,
message: Some(String::from("Invalid pagination identifier")),
});
}
}
}
};
let response_hash_type = {
if parameters.response == PartialHashSearchType::Any {
"sha256"
} else {
¶meters.response.to_string()
}
};
let file_name = if let Some(file_name) = ¶meters.file_name {
if file_name.is_empty() {
String::new()
} else {
format!("and array_to_string(filesource.filename, ',') like '%{file_name}%'")
}
} else {
String::new()
};
let file_hash = if let Some((hash_type, file_hash)) = ¶meters.partial_hash {
if *hash_type == PartialHashSearchType::Any {
format!("and encode(file.md5, 'hex') like '%{file_hash}%' or encode(file.sha1, 'hex') like '%{file_hash}%' or encode(file.sha256, 'hex') like '%{file_hash}%' or encode(file.sha384, 'hex') like '%{file_hash}%' or encode(file.sha512, 'hex) like '%{file_hash}%')").to_string()
} else {
format!("and encode(file.{hash_type}, 'hex') like '%{file_hash}%'").to_string()
}
} else {
String::new()
};
let labels_join = if parameters.labels.is_some() {
"join filelabel on (filelabel.fileid = file.id) join label on (filelabel.labelid = label.id)"
} else {
""
};
let labels_clause = if let Some(labels) = ¶meters.labels {
let labels: Vec<String> = labels.iter().map(|l| format!("'{l}'")).collect();
format!("and (label.name in ({}))", labels.join(","))
} else {
String::new()
};
let file_type_join = if parameters.file_type.is_some() {
"join filetype on (filetype.id = file.filetypeid)"
} else {
""
};
let file_type_clause = if let Some(file_type) = ¶meters.file_type {
format!("and filetype.name = '{file_type}'")
} else {
String::new()
};
let magic_clause = if let Some(magic) = ¶meters.magic {
format!("and file.filecommand like '%{magic}%'")
} else {
String::new()
};
let limit = i64::from(parameters.limit.min(PARTIAL_SEARCH_LIMIT));
let result = client
.query(&format!("select distinct {response_hash_type}, file.id from file \
join filesource on (file.id = filesource.fileid) \
join groupsource on (groupsource.sourceid = filesource.sourceid)\
join usergroup on (groupsource.gid = usergroup.gid) {labels_join} {file_type_join}\
where usergroup.pid = $1 {file_hash} {file_name} {labels_clause} {file_type_clause} {magic_clause}\
and file.id > $3 order by file.id limit $2"),
&[&uid, &limit, &next_file_id],
)
.await?;
let mut last_id: i64 = 0;
let hashes = result
.iter()
.map(|s| {
last_id = s.get(1);
let hash: Vec<u8> = s.get(0);
hex::encode(hash)
})
.collect::<Vec<String>>();
let mut returned_uuid = None;
let total_results = if total_results == 0 {
let result = client
.query_one(&format!("select count(distinct {response_hash_type}) from file \
join filesource on (file.id = filesource.fileid) \
join groupsource on (groupsource.sourceid = filesource.sourceid)\
join usergroup on (groupsource.gid = usergroup.gid) {labels_join} {file_type_join}\
where usergroup.pid = $1 {file_hash} {file_name} {labels_clause} {file_type_clause} {magic_clause}"),
&[&uid],
)
.await?;
let total_results: i64 = result.get(0);
if total_results as usize > hashes.len() {
returned_uuid = Some(pagination_uuid);
client.execute("insert into pagination(uuid, userid, type, query, next_fileid, total_results) values($1, $2, $3, $4, $5, $6)", &[&pagination_uuid, &uid, &PaginationType::Search, &serde_json::to_string(¶meters).unwrap(), &last_id, &total_results]).await?;
}
total_results
} else {
if did_query {
client.execute("update pagination set next_fileid = $1 where uuid = $2 and userid = $3 and type = $4", &[&last_id, &pagination_uuid, &uid, &PaginationType::Search]).await?;
}
total_results
};
Ok(SearchResponse {
hashes,
pagination: returned_uuid,
total_results: total_results as u64,
message: None,
})
}
pub async fn cleanup(&self) -> Result<u64> {
    // Purge stale pagination tokens (and, with YARA support, completed YARA
    // searches) older than the cleanup interval; returns rows removed.
    let client = self.pool.get().await?;
    let cutoff = chrono::Utc::now() - crate::DB_CLEANUP_INTERVAL;
    let removed = client
        .execute("delete from pagination where created < $1", &[&cutoff])
        .await?;
    #[cfg(not(feature = "yara"))]
    return Ok(removed);
    #[cfg(feature = "yara")]
    {
        let yara_removed = client
            .execute("delete from yara_search where completed < $1", &[&cutoff])
            .await?;
        Ok(yara_removed + removed)
    }
}
pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
let client = self.pool.get().await?;
let uid = uid as i32;
let result = client
.query(&format!("select distinct sha256 from file \
join filesource on (file.id = filesource.fileid) \
join groupsource on (groupsource.sourceid = filesource.sourceid)\
join usergroup on (groupsource.gid = usergroup.gid)\
where file.{} = $1 and usergroup.pid = $2", hash.name()),
&[&hash.bytes(), &uid],
)
.await?;
if result.is_empty() {
bail!("Doesn't exist or not allowed");
}
if result.len() != 1 {
eprintln!(
"Postgres: data integrity failure, {} entries for {hash}",
result.len()
);
bail!(
"Postgres: data integrity failure, {} entries for {hash}",
result.len()
);
}
if let Some(sha256) = result.first() {
let sha256: Vec<u8> = sha256.get(0);
Ok(hex::encode(sha256))
} else {
bail!("Unexpected error getting SHA-256");
}
}
/// Build the full metadata report for a sample identified by any hash,
/// restricted to files the user can access.
///
/// Fix: the human-readable size previously cast the byte count to `u32`,
/// silently truncating sizes above 4 GiB; it now formats the full value.
pub async fn get_sample_report(
    &self,
    uid: u32,
    hash: &HashType,
) -> Result<malwaredb_api::Report> {
    let client = self.pool.get().await?;
    let uid = uid as i32;
    let result = client
        .query_one(&format!("select md5, sha1, sha256, sha384, sha512, lzjd, tlsh, ssdeep, humanhash, filecommand, size, entropy from file \
join filesource on (file.id = filesource.fileid) \
join groupsource on (groupsource.sourceid = filesource.sourceid)\
join usergroup on (groupsource.gid = usergroup.gid)\
where file.{} = $1 and usergroup.pid = $2", hash.name()),
        &[&hash.bytes(), &uid],
    )
    .await?;
    let bytes: i64 = result.get(10);
    let formatter = make_format(DECIMAL);
    // Most recent VirusTotal summary, when built with VT support.
    let vt_summary = if cfg!(feature = "vt") {
        let vt_result = client.query(&format!("select vtdata.hits, vtdata.total, vtdata.vtdetail, vtdata.tstamp from vtdata join file on (file.id = vtdata.fileid) \
where file.{} = $1 order by vtdata.tstamp desc limit 1", hash.name()), &[&hash.bytes()]).await?;
        Some(
            vt_result
                .first()
                .map(|row| {
                    let hits: i32 = row.get(0);
                    let total: i32 = row.get(1);
                    let detail: Option<String> = row.get(2);
                    malwaredb_api::VirusTotalSummary {
                        hits: hits as u32,
                        total: total as u32,
                        detail: detail.map(|d| json!(d)),
                        last_analysis_date: row.get(3),
                    }
                })
                .unwrap_or_default(),
        )
    } else {
        None
    };
    let md5: Uuid = result.get(0);
    let sha1: Vec<u8> = result.get(1);
    let sha256: Vec<u8> = result.get(2);
    let sha384: Vec<u8> = result.get(3);
    let sha512: Vec<u8> = result.get(4);
    Ok(malwaredb_api::Report {
        // MD5 is stored as a UUID; strip the dashes to get plain hex.
        md5: md5.to_string().replace('-', ""),
        sha1: hex::encode(sha1),
        sha256: hex::encode(sha256),
        sha384: hex::encode(sha384),
        sha512: hex::encode(sha512),
        lzjd: result.get(5),
        tlsh: result.get(6),
        ssdeep: result.get(7),
        humanhash: result.get(8),
        filecommand: result.get(9),
        bytes: bytes as u64,
        // Was `bytes as u32`, which truncated files larger than 4 GiB.
        size: formatter(bytes as u64),
        entropy: result.get(11),
        vt: vt_summary,
    })
}
async fn check_similarity_functions(&self) -> Result<WhichHashesInstalled> {
    // Probe the catalog for the comparison functions the LZJD, ssdeep, and
    // TLSH Postgres extensions provide outside the system schemas.
    let client = self.pool.get().await?;
    let rows = client
        .query(
            "select CAST(p.oid::regprocedure AS text)
from pg_proc p join pg_namespace n on p.pronamespace = n.oid
where n.nspname not in ('pg_catalog', 'information_schema');",
            &[],
        )
        .await?;
    let mut installed = WhichHashesInstalled {
        lzjd: false,
        tlsh: false,
        ssdeep: false,
    };
    for row in rows {
        let signature: String = row.get(0);
        match signature.as_str() {
            "lzjd_compare(text,text)" => installed.lzjd = true,
            "fuzzy_hash_compare(text,text)" => installed.ssdeep = true,
            "tlsh_compare(text,text)" => installed.tlsh = true,
            _ => {}
        }
    }
    Ok(installed)
}
/// Find files similar to the supplied similarity hashes, scoped to files the
/// user may access. Returns one entry per matching SHA-256 together with the
/// per-algorithm similarity scores that matched it.
#[instrument]
pub async fn find_similar_samples(
    &self,
    uid: u32,
    sim: &[(SimilarityHashType, String)],
) -> Result<Vec<SimilarSample>> {
    // Requires the lzjd/ssdeep/tlsh Postgres extensions detected at startup.
    if !self.has_hash_extensions {
        bail!("similarity search without similarity extensions not yet supported");
    }
    let client = self.pool.get().await?;
    let uid = uid as i32;
    // sha256 (hex) -> scores from each algorithm that considered it similar.
    let mut results = HashMap::<String, Vec<(SimilarityHashType, f32)>>::new();
    for (algo, (table_field, hash_func), hash_value) in sim
        .iter()
        .map(|(sim, val)| (*sim, sim.get_table_field_simfunc(), val))
    {
        let rows = if let Some(hash_func) = hash_func {
            if algo == SimilarityHashType::TLSH {
                // TLSH scores are kept below 500 (TLSH is a distance metric,
                // so smaller means more similar).
                client
                    .query(
                        &format!(
                            "select file.sha256, {hash_func}({table_field}, $1)::float \
from file join executable on (file.id = executable.fileid) \
join filesource on (file.id = filesource.fileid) \
join groupsource on (groupsource.sourceid = filesource.sourceid) \
join usergroup on (groupsource.gid = usergroup.gid)\
where {hash_func}({table_field}, $1) < 500 and usergroup.pid = $2",
                        ),
                        &[hash_value, &uid],
                    )
                    .await?
            } else {
                // Other comparators yield similarity scores; keep anything > 0.
                client
                    .query(
                        &format!(
                            "select file.sha256, {hash_func}({table_field}, $1)::float \
from file join executable on (file.id = executable.fileid) \
join filesource on (file.id = filesource.fileid) \
join groupsource on (groupsource.sourceid = filesource.sourceid)\
join usergroup on (groupsource.gid = usergroup.gid)\
where {hash_func}({table_field}, $1) > 0 and usergroup.pid = $2",
                        ),
                        &[hash_value, &uid],
                    )
                    .await?
            }
        } else {
            // No comparison function: exact equality is reported as 100.0.
            client
                .query(
                    &format!(
                        "select file.sha256, 100.0 as algo\
from file join executable on (file.id = executable.fileid)\
join filesource on (file.id = filesource.fileid) \
join groupsource on (groupsource.sourceid = filesource.sourceid)\
join usergroup on (groupsource.gid = usergroup.gid)\
where {table_field} = $1 and usergroup.pid = $2"
                    ),
                    &[hash_value, &uid],
                )
                .await?
        };
        // Merge this algorithm's hits into the per-file score lists.
        for row in rows {
            let sha256: Vec<u8> = row.get(0);
            let sha256 = hex::encode(sha256);
            let similarity: f64 = row.get(1);
            if let Some(already) = results.get_mut(&sha256) {
                already.push((algo, similarity as f32));
            } else {
                results.insert(sha256, vec![(algo, similarity as f32)]);
            }
        }
    }
    Ok(results
        .into_iter()
        .map(|(sha256, algorithms)| SimilarSample { sha256, algorithms })
        .collect())
}
/// Return up to `PARTIAL_SEARCH_LIMIT` hex-encoded SHA-256 hashes of files
/// visible to user `uid`, paginated by file id.
///
/// The returned integer is a cursor (the last file id of this page, 0 when
/// the page is empty); pass it back as `next` to fetch the following page.
pub async fn user_allowed_files_by_sha256(
    &self,
    uid: u32,
    next: Option<u64>,
) -> Result<(Vec<String>, u64)> {
    let client = self.pool.get().await?;
    // Cursor pagination: only rows past the last-seen file id. `next` is an
    // integer, so interpolating it into the SQL is injection-safe.
    let pagination = next.map_or_else(String::new, |n| format!(" and file.id > {n}"));
    let uid = uid as i32;
    let limit = i64::from(PARTIAL_SEARCH_LIMIT);
    let rows = client
        .query(
            &format!(
                "select distinct file.sha256, file.id \
                from file \
                join filesource on (file.id = filesource.fileid) \
                join groupsource on (groupsource.sourceid = filesource.sourceid) \
                join usergroup on (groupsource.gid = usergroup.gid) \
                where usergroup.pid = $1{pagination} \
                order by file.id asc \
                limit $2"
            ),
            &[&uid, &limit],
        )
        .await?;
    let mut hashes = Vec::with_capacity(rows.len());
    for row in &rows {
        let sha256: Vec<u8> = row.get(0);
        hashes.push(hex::encode(sha256));
    }
    let cursor: i64 = rows.last().map_or(0, |row| row.get(1));
    Ok((hashes, cursor as u64))
}
/// Load every file-encryption key stored in the database, keyed by its row id.
pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
    let client = self.pool.get().await?;
    let rows = client
        .query("select id, name, bytes from encryptionkey", &[])
        .await?;
    let mut keys = HashMap::with_capacity(rows.len());
    for row in rows {
        let id: i32 = row.get(0);
        let kind: EncryptionOption = row.get(1);
        let material: Vec<u8> = row.get(2);
        keys.insert(id as u32, FileEncryption::new(kind, material)?);
    }
    Ok(keys)
}
/// Look up the encryption key id and nonce recorded for the file with the
/// given SHA-256 hash; either column may be NULL for unencrypted files.
pub(crate) async fn get_file_encryption_key_id(
    &self,
    hash: &str,
) -> Result<(Option<u32>, Option<Vec<u8>>)> {
    let client = self.pool.get().await?;
    let row = client
        .query_one("select key, nonce from file where sha256 = $1", &[&hash])
        .await?;
    let key: Option<i32> = row.get(0);
    let nonce: Option<Vec<u8>> = row.get(1);
    Ok((key.map(|k| k as u32), nonce))
}
/// Record (or clear, with `None`) the encryption nonce for the file whose
/// SHA-256 matches `hash`.
pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
    self.pool
        .get()
        .await?
        .execute(
            "update file set nonce = $2 where sha256 = $1",
            &[&hash, &nonce],
        )
        .await?;
    Ok(())
}
/// Inspect the schema version recorded in `mdbconfig` and either apply any
/// pending migration (`Migration::Migrate`) or merely report that one is
/// needed (`Migration::Check`).
///
/// Versions older than 0.2.0 cannot be migrated automatically. The
/// 0.2.x -> 0.3.0 migration runs inside a single transaction: it creates the
/// `yara_search` table, converts the PE-hash/import-hash columns from text to
/// UUID, and converts the file hash columns from hex text to bytea (md5 to
/// UUID). On success the stored version is bumped to the current crate
/// version.
///
/// # Errors
/// Fails when the stored version is too old, when `action` is `Check` and a
/// migration is pending, or on any database error; all schema changes are
/// rolled back on failure.
#[allow(clippy::too_many_lines)]
#[cfg_attr(not(test), allow(unused_variables))]
pub async fn migrate(&self, action: Migration) -> Result<()> {
    let client = self.pool.get().await?;
    let mdb_db_version = client
        .query_one("select version from mdbconfig", &[])
        .await?;
    let mdb_db_version: String = mdb_db_version.get(0);
    let mdb_db_version = Version::parse(&mdb_db_version)?;
    if mdb_db_version < Version::new(0, 2, 0) {
        bail!("MalwareDB database version is too old to migrate automatically");
    }
    if mdb_db_version < Version::new(0, 3, 0) {
        // Migration logic is only compiled in for tests or the admin feature;
        // regular server builds refuse to run against an old schema (below).
        #[cfg(any(test, feature = "admin"))]
        {
            match action {
                Migration::Migrate => {
                    let mut client = self.pool.get().await?;
                    let transaction = client.transaction().await?;
                    if let Err(e) = transaction.execute("CREATE TABLE yara_search (
                        id uuid NOT NULL, -- Unique identifier for this paginated search
                        yara_text text NOT NULL,
                        yara_compiled bytea,
                        userid int NOT NULL REFERENCES person(id), -- User who initiated the search
                        results text[], -- Yara name|Hash
                        last_fileid bigint REFERENCES file(id), -- File ID of the last file in the search of allowed files
                        completed timestamp with time zone,
                        created timestamp with time zone NOT NULL DEFAULT timezone('utc', now()), -- When the search was initiated so it can be deleted
                        PRIMARY KEY (id)
                    );", &[]).await {
                        transaction.rollback().await?;
                        bail!("Failed to create Yara search table: {e}");
                    }
                    // Stage 1: PE hash and import hash, text -> UUID, done via
                    // temporary *_new columns that are swapped in afterwards.
                    if let Err(e) = transaction
                        .batch_execute(
                            "ALTER TABLE executable ADD COLUMN pehash_new UUID; \
                            ALTER TABLE executable ADD COLUMN importhash_new UUID;",
                        )
                        .await
                    {
                        transaction.rollback().await?;
                        bail!("Failed to add new PE hash and import hash columns: {e}");
                    }
                    let executable_entries = transaction
                        .query("select fileid, pehash, importhash from executable", &[])
                        .await?;
                    for entry in executable_entries {
                        let fileid: i64 = entry.get(0);
                        let pehash: Option<String> = entry.get(1);
                        let importhash: Option<String> = entry.get(2);
                        if pehash.is_some() || importhash.is_some() {
                            // Propagate parse failures instead of panicking;
                            // the early return drops the transaction, which
                            // rolls it back.
                            let new_pehash =
                                pehash.map(|h| Uuid::parse_str(&h)).transpose()?;
                            let new_importhash =
                                importhash.map(|h| Uuid::parse_str(&h)).transpose()?;
                            transaction.execute("update executable set pehash_new = $1, importhash_new = $2 where fileid = $3", &[&new_pehash, &new_importhash, &fileid]).await?;
                        }
                    }
                    if let Err(e) = transaction.batch_execute("ALTER TABLE executable DROP COLUMN pehash; ALTER TABLE executable DROP COLUMN importhash; \
                        ALTER TABLE executable RENAME COLUMN pehash_new TO pehash; \
                        ALTER TABLE executable RENAME COLUMN importhash_new TO importhash; ").await {
                        transaction.rollback().await?;
                        bail!("Failed to rename PE hash and import hash columns: {e}");
                    }
                    // Stage 2: file hashes, hex text -> bytea (md5 -> UUID),
                    // using the same temporary-column swap technique.
                    if let Err(e) = transaction
                        .batch_execute(
                            "ALTER TABLE file ADD COLUMN sha1_new BYTEA; \
                            ALTER TABLE file ADD COLUMN sha256_new BYTEA; \
                            ALTER TABLE file ADD COLUMN sha384_new BYTEA; \
                            ALTER TABLE file ADD COLUMN sha512_new BYTEA; \
                            ALTER TABLE file ADD COLUMN md5_new UUID;",
                        )
                        .await
                    {
                        transaction.rollback().await?;
                        bail!("Failed to add new hash columns: {e}");
                    }
                    let file_entries = transaction
                        .query(
                            "select id, md5, sha1, sha256, sha384, sha512 from file",
                            &[],
                        )
                        .await?;
                    for entry in file_entries {
                        let fileid: i64 = entry.get(0);
                        let md5: String = entry.get(1);
                        let sha1: String = entry.get(2);
                        let sha256: String = entry.get(3);
                        let sha384: String = entry.get(4);
                        let sha512: String = entry.get(5);
                        let new_md5 = Uuid::parse_str(&md5)?;
                        let new_sha1 = hex::decode(sha1)?;
                        let new_sha256 = hex::decode(sha256)?;
                        let new_sha384 = hex::decode(sha384)?;
                        let new_sha512 = hex::decode(sha512)?;
                        if let Err(e) = transaction.execute("update file set md5_new = $1, sha1_new = $2, sha256_new = $3, sha384_new = $4, sha512_new = $5 where id = $6",
                            &[&new_md5, &new_sha1, &new_sha256, &new_sha384, &new_sha512, &fileid]).await {
                            transaction.rollback().await?;
                            bail!("Failed to update file hashes from text to bytea for file_id {fileid}: {e}");
                        }
                    }
                    if let Err(e) = transaction
                        .batch_execute(
                            "ALTER TABLE file DROP COLUMN md5;\
                            ALTER TABLE file DROP COLUMN sha1; \
                            ALTER TABLE file DROP COLUMN sha256; \
                            ALTER TABLE file DROP COLUMN sha384; \
                            ALTER TABLE file DROP COLUMN sha512; \
                            ALTER TABLE file RENAME COLUMN md5_new TO md5; \
                            ALTER TABLE file RENAME COLUMN sha1_new TO sha1; \
                            ALTER TABLE file RENAME COLUMN sha256_new TO sha256; \
                            ALTER TABLE file RENAME COLUMN sha384_new TO sha384; \
                            ALTER TABLE file RENAME COLUMN sha512_new TO sha512; \
                            ALTER TABLE file ALTER COLUMN md5 SET NOT NULL; \
                            ALTER TABLE file ALTER COLUMN sha1 SET NOT NULL; \
                            ALTER TABLE file ALTER COLUMN sha256 SET NOT NULL; \
                            ALTER TABLE file ALTER COLUMN sha384 SET NOT NULL; \
                            ALTER TABLE file ALTER COLUMN sha512 SET NOT NULL; \
                            ALTER TABLE file ADD CONSTRAINT unique_sha512 UNIQUE (sha512);",
                        )
                        .await
                    {
                        transaction.rollback().await?;
                        // Error message previously copy-pasted from the PE
                        // hash stage; this failure concerns the file hashes.
                        bail!("Failed to rename file hash columns: {e}");
                    }
                    if let Err(e) = transaction
                        .execute("update mdbconfig set version = '0.3.0';", &[])
                        .await
                    {
                        transaction.rollback().await?;
                        bail!("Failed to update MalwareDB version in Postgres: {e}");
                    }
                    transaction.commit().await?;
                }
                Migration::Check => {
                    bail!("MalwareDB database needs migration.");
                }
            }
        }
        #[cfg(not(any(test, feature = "admin")))]
        bail!("MalwareDB database needs migration.");
    }
    // Schema is current (or was just migrated): record the software version.
    client
        .execute("update mdbconfig set version = $1;", &[&crate::MDB_VERSION])
        .await?;
    Ok(())
}
/// Set the human-readable name of this MalwareDB instance.
pub async fn set_name(&self, name: &str) -> Result<()> {
    self.pool
        .get()
        .await?
        .execute("update mdbconfig set name = $1", &[&name])
        .await?;
    Ok(())
}
/// Turn on compression of stored file contents.
#[cfg(any(test, feature = "admin"))]
pub async fn enable_compression(&self) -> Result<()> {
    self.pool
        .get()
        .await?
        .execute("update mdbconfig set compress = true", &[])
        .await?;
    Ok(())
}
/// Turn off compression of stored file contents.
#[cfg(any(test, feature = "admin"))]
pub async fn disable_compression(&self) -> Result<()> {
    self.pool
        .get()
        .await?
        .execute("update mdbconfig set compress = false", &[])
        .await?;
    Ok(())
}
/// Configure the server to retain files whose type is not recognized.
#[cfg(any(test, feature = "admin"))]
pub async fn enable_keep_unknown_files(&self) -> Result<()> {
    self.pool
        .get()
        .await?
        .execute("update mdbconfig set keep_unknown_files = true", &[])
        .await?;
    Ok(())
}
/// Configure the server to discard files whose type is not recognized.
#[cfg(any(test, feature = "admin"))]
pub async fn disable_keep_unknown_files(&self) -> Result<()> {
    self.pool
        .get()
        .await?
        .execute("update mdbconfig set keep_unknown_files = false", &[])
        .await?;
    Ok(())
}
/// Insert a new file-encryption key and make it the default for newly stored
/// files. Returns the database id of the new key.
///
/// Both statements run in one transaction, so a failure leaves the previous
/// default key in place.
#[cfg(any(test, feature = "admin"))]
pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
    let mut client = self.pool.get().await?;
    let transaction = client.transaction().await?;
    // RETURNING yields the generated id directly, avoiding a second query
    // (and any ambiguity should identical key bytes ever be stored twice).
    let row = match transaction
        .query_one(
            "insert into encryptionkey(name, bytes) values($1, $2) returning id",
            &[&key.key_type(), &key.key()],
        )
        .await
    {
        Ok(r) => r,
        Err(e) => {
            error!("Failed to insert encryption key: {e}");
            transaction.rollback().await?;
            bail!("Failed to insert encryption key: {e}");
        }
    };
    let key_id: i32 = row.get(0);
    if let Err(e) = transaction
        .execute("update mdbconfig set defaultKey = $1", &[&key_id])
        .await
    {
        error!(
            "Failed to set encryption key {} as default: {e}",
            key.name()
        );
        transaction.rollback().await?;
        bail!(
            "Failed to set encryption key {} as default: {e}",
            key.name()
        );
    }
    transaction.commit().await?;
    Ok(key_id as u32)
}
/// List the id and algorithm name of every stored file-encryption key
/// (the key material itself is not returned).
#[cfg(any(test, feature = "admin"))]
pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
    let client = self.pool.get().await?;
    let rows = client
        .query("select id, name from encryptionkey", &[])
        .await?;
    Ok(rows
        .into_iter()
        .map(|row| {
            let id: i32 = row.get(0);
            (id as u32, row.get(1))
        })
        .collect())
}
/// Create a new user account, optionally with a password (hashed before
/// storage). Returns the new user's id.
///
/// # Errors
/// Fails when the username or email address is already taken, when password
/// hashing fails, or on any database error.
#[allow(clippy::too_many_arguments)]
#[cfg(any(test, feature = "admin"))]
pub async fn create_user(
    &self,
    uname: &str,
    fname: &str,
    lname: &str,
    email: &str,
    password: Option<String>,
    organisation: Option<&String>,
    readonly: bool,
) -> Result<u32> {
    let client = self.pool.get().await?;
    let result = client
        .query_one("select count(1) from person where uname = $1", &[&uname])
        .await?;
    let count: i64 = result.get(0);
    if count != 0 {
        bail!("username already taken");
    }
    let result = client
        .query_one("select count(1) from person where email = $1", &[&email])
        .await?;
    let count: i64 = result.get(0);
    if count != 0 {
        bail!("email address already taken");
    }
    // RETURNING id avoids a separate lookup query after the insert.
    let row = match password {
        Some(password) => {
            let password = hash_password(&password)?;
            client.query_one("insert into person(email, uname, firstname, lastname, organisation, password, readonly) values ($1, $2, $3, $4, $5, $6, $7) returning id;", &[&email, &uname, &fname, &lname, &organisation, &password, &readonly]).await?
        }
        None => {
            client.query_one("insert into person(email, uname, firstname, lastname, organisation, readonly) values ($1, $2, $3, $4, $5, $6) returning id;", &[&email, &uname, &fname, &lname, &organisation, &readonly]).await?
        }
    };
    let uid: i32 = row.get(0);
    Ok(uid as u32)
}
/// Clear the API key of every user; returns how many rows were affected.
#[cfg(any(test, feature = "admin"))]
pub async fn reset_api_keys(&self) -> Result<u64> {
    let affected = self
        .pool
        .get()
        .await?
        .execute("update person set apikey = NULL", &[])
        .await?;
    Ok(affected)
}
/// Set (replace) the password for the user `uname`; the password is hashed
/// before it is stored.
#[cfg(any(test, feature = "admin"))]
pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
    let hashed = hash_password(password)?;
    self.pool
        .get()
        .await?
        .execute(
            "update person set password = $1 where uname = $2",
            &[&hashed, &uname],
        )
        .await?;
    Ok(())
}
/// Set or clear the read-only flag on the user with id `uid`.
#[cfg(any(test, feature = "admin"))]
pub async fn set_user_readonly(&self, uid: u32, readonly: bool) -> Result<()> {
    let uid = uid as i32;
    self.pool
        .get()
        .await?
        .execute(
            "update person set readonly = $1 where id = $2",
            &[&readonly, &uid],
        )
        .await?;
    Ok(())
}
/// List every user account, reporting whether a password and/or API key is
/// set (the secrets themselves are never returned).
#[cfg(any(test, feature = "admin"))]
pub async fn list_users(&self) -> Result<Vec<admin::User>> {
    let client = self.pool.get().await?;
    let rows = client
        .query("select id,email,uname,firstname,lastname,password is not null and length(password)>0,apikey is not null and length(apikey)>0,organisation,phone, created, readonly from person", &[])
        .await?;
    Ok(rows
        .into_iter()
        .map(|row| {
            let id: i32 = row.get(0);
            admin::User {
                id: id as u32,
                email: row.get(1),
                uname: row.get(2),
                fname: row.get(3),
                lname: row.get(4),
                has_password: row.get(5),
                has_api_key: row.get(6),
                org: row.get(7),
                phone: row.get(8),
                created: row.get(9),
                is_readonly: row.get(10),
            }
        })
        .collect())
}
/// Look up the id of the group named `name`.
#[cfg(any(test, feature = "admin"))]
pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
    let row = self
        .pool
        .get()
        .await?
        .query_one("select id from grp where name = $1", &[&name])
        .await?;
    Ok(row.get(0))
}
/// Update the name, description, and optional parent of group `gid`.
#[cfg(any(test, feature = "admin"))]
pub async fn edit_group(
    &self,
    gid: u32,
    name: &str,
    desc: &str,
    parent: Option<u32>,
) -> Result<()> {
    let gid = gid as i32;
    let parent = parent.map(|p| p as i32);
    self.pool
        .get()
        .await?
        .execute(
            "update grp set name = $1, description = $2, parent = $3 where id = $4",
            &[&name, &desc, &parent, &gid],
        )
        .await?;
    Ok(())
}
/// List every group with its members, its sources, and a count of files
/// reachable through those sources.
#[cfg(any(test, feature = "admin"))]
pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
    let mut groups = Vec::new();
    let client = self.pool.get().await?;
    let results = client
        .query("select grp.id, grp.name, grp.description, parent.name from grp left outer join grp parent on (grp.parent = parent.id) order by 1", &[])
        .await?;
    for result in results {
        let id: i32 = result.get(0);
        // All users who are members of this group.
        let members = {
            let members_results = client
                .query("select person.id, person.uname, person.email, person.firstname, person.lastname, person.organisation, person.phone, person.password is not null and length(person.password)>0,person.apikey is not null and length(person.apikey)>0, person.created, person.readonly from person, usergroup where person.id = usergroup.pid and usergroup.gid = $1", &[&id])
                .await?;
            members_results
                .into_iter()
                .map(|x| {
                    let id: i32 = x.get(0);
                    admin::User {
                        id: id as u32,
                        uname: x.get(1),
                        email: x.get(2),
                        fname: x.get(3),
                        lname: x.get(4),
                        org: x.get(5),
                        phone: x.get(6),
                        has_password: x.get(7),
                        has_api_key: x.get(8),
                        created: x.get(9),
                        is_readonly: x.get(10),
                    }
                })
                .collect()
        };
        // All sources accessible to this group. `count(distinct ...)` is
        // required: the two left joins form a cross product per source, which
        // would otherwise multiply the file and group counts together.
        let sources = {
            let id: i32 = result.get(0);
            let sources_results = client
                .query("select source.id, source.name, source.description, source.url, source.firstacquisition, source.malicious, parent_source.name, count(distinct filesource.fileid), count(distinct gs2.gid) \
                    from source join groupsource as gs1 on (source.id = gs1.sourceid) left join filesource on (filesource.sourceid = source.id) \
                    left join groupsource as gs2 on (gs2.sourceid = source.id) \
                    left join source as parent_source on (source.parent = parent_source.id) \
                    where gs1.gid = $1 group by source.id, parent_source.name", &[&id])
                .await.context(format!("failed to list sources for group {id}"))?;
            sources_results
                .into_iter()
                .map(|x| {
                    let id: i32 = x.get(0);
                    let files_count: i64 = x.get(7);
                    let groups_count: i64 = x.get(8);
                    admin::Source {
                        id: id as u32,
                        name: x.get(1),
                        description: x.get(2),
                        url: x.get(3),
                        date: x.get(4),
                        files: files_count as u64,
                        groups: groups_count as u32,
                        parent: x.get(6),
                        malicious: x.get(5),
                    }
                })
                .collect()
        };
        // Count the files visible through this group's sources. The previous
        // join condition compared a group id to a source id
        // (groupsource.gid = filesource.sourceid), which is incorrect; the
        // tables relate via sourceid. Distinct avoids double-counting files
        // reachable through more than one source.
        let files_row = client
            .query_one(
                "select count(distinct filesource.fileid) \
                from filesource join groupsource on (groupsource.sourceid = filesource.sourceid) \
                where groupsource.gid = $1",
                &[&id],
            )
            .await
            .context(format!("failed to count files for group id {id}"))?;
        let files_count: i64 = files_row.get(0);
        groups.push(admin::Group {
            id: id as u32,
            name: result.get(1),
            description: result.get(2),
            parent: result.get(3),
            members,
            sources,
            files: files_count as u32,
        });
    }
    Ok(groups)
}
/// Make user `uid` a member of group `gid`.
#[cfg(any(test, feature = "admin"))]
pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
    let uid = uid as i32;
    let gid = gid as i32;
    self.pool
        .get()
        .await?
        .execute(
            "insert into usergroup(pid, gid) values($1, $2)",
            &[&uid, &gid],
        )
        .await?;
    Ok(())
}
/// Grant group `gid` access to source `sid`.
#[cfg(any(test, feature = "admin"))]
pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
    let gid = gid as i32;
    let sid = sid as i32;
    self.pool
        .get()
        .await?
        .execute(
            "insert into groupsource(gid, sourceid) values($1, $2)",
            &[&gid, &sid],
        )
        .await?;
    Ok(())
}
/// Create a new group, optionally beneath a parent group. Returns the id of
/// the new group.
#[cfg(any(test, feature = "admin"))]
pub async fn create_group(
    &self,
    name: &str,
    description: &str,
    parent: Option<u32>,
) -> Result<u32> {
    let client = self.pool.get().await?;
    let parent = parent.map(|p| p as i32);
    // RETURNING id avoids a separate name lookup after the insert.
    let row = client
        .query_one(
            "insert into grp(name, description, parent) values ($1, $2, $3) returning id;",
            &[&name, &description, &parent],
        )
        .await?;
    let gid: i32 = row.get(0);
    Ok(gid as u32)
}
/// List every source with the number of files it contributed and the number
/// of groups with access to it.
#[cfg(any(test, feature = "admin"))]
pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
    let mut sources = Vec::new();
    let client = self.pool.get().await?;
    // `count(distinct ...)` is required here: the two left joins produce a
    // files-by-groups cross product per source, so plain count() would
    // report files*groups for both columns.
    let results = client
        .query(
            "select source.id, source.name, source.description, source.url, source.firstacquisition, source.malicious, parent_source.name, count(distinct filesource.fileid), count(distinct groupsource.gid) \
            from source left join filesource on (source.id = filesource.sourceid) left join groupsource on (groupsource.sourceid = source.id) left join source as parent_source on (source.parent = parent_source.id) \
            group by 1, 7",
            &[],
        )
        .await.context("failed to list sources")?;
    for result in results {
        let id: i32 = result.get(0);
        let files_count: i64 = result.get(7);
        let groups_count: i64 = result.get(8);
        sources.push(admin::Source {
            id: id as u32,
            name: result.get(1),
            description: result.get(2),
            url: result.get(3),
            date: result.get(4),
            files: files_count as u64,
            groups: groups_count as u32,
            parent: result.get(6),
            malicious: result.get(5),
        });
    }
    Ok(sources)
}
/// Create a new sample source. Returns the id of the new source.
#[cfg(any(test, feature = "admin"))]
pub async fn create_source(
    &self,
    name: &str,
    description: Option<&str>,
    url: Option<&str>,
    date: chrono::DateTime<Local>,
    releasable: bool,
    malicious: Option<bool>,
) -> Result<u32> {
    let client = self.pool.get().await?;
    // RETURNING id avoids a separate name lookup after the insert.
    let row = client.query_one(
        "insert into source(name, description, url, firstacquisition, releasable, malicious) values ($1, $2, $3, $4, $5, $6) returning id;",
        &[&name, &description, &url, &date, &releasable, &malicious],
    ).await?;
    let sid: i32 = row.get(0);
    Ok(sid as u32)
}
/// Update the account details of user `uid`.
#[cfg(any(test, feature = "admin"))]
pub async fn edit_user(
    &self,
    uid: u32,
    uname: &str,
    fname: &str,
    lname: &str,
    email: &str,
    readonly: bool,
) -> Result<()> {
    let uid = uid as i32;
    self.pool
        .get()
        .await?
        .execute("update person set uname = $1, email = $2, firstname = $3, lastname = $4, readonly = $5 where id = $6;", &[&uname, &email, &fname, &lname, &readonly, &uid])
        .await?;
    Ok(())
}
/// Disable the account of user `uid` by clearing both credentials (password
/// and API key); the row itself is retained.
#[cfg(any(test, feature = "admin"))]
pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
    let uid = uid as i32;
    self.pool
        .get()
        .await?
        .execute(
            "update person set password = null, apikey = null where id = $1;",
            &[&uid],
        )
        .await?;
    Ok(())
}
/// Count stored files per file-type name.
#[cfg(any(test, feature = "admin"))]
pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
    let client = self.pool.get().await?;
    let rows = client.query("SELECT filetype.name, count(file.id) from file join filetype on (file.filetypeid = filetype.id) group by 1", &[]).await?;
    Ok(rows
        .into_iter()
        .map(|row| {
            let count: i64 = row.get(1);
            (row.get(0), count as u32)
        })
        .collect())
}
/// Create a new label, optionally beneath a parent label. Returns the id of
/// the new label.
#[cfg(any(test, feature = "admin"))]
pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
    let client = self.pool.get().await?;
    let parent = parent.map(|p| p as i64);
    // RETURNING id avoids a separate name lookup after the insert.
    let row = client
        .query_one(
            "insert into label(name, parent) values ($1, $2) returning id;",
            &[&name, &parent],
        )
        .await?;
    let lid: i64 = row.get(0);
    Ok(lid as u64)
}
/// Update the name and optional parent of label `id`.
#[cfg(any(test, feature = "admin"))]
pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
    let id = id as i64;
    let parent = parent.map(|p| p as i64);
    self.pool
        .get()
        .await?
        .execute(
            "update label set name = $1, parent = $2 where id = $3",
            &[&name, &parent, &id],
        )
        .await?;
    Ok(())
}
/// Look up the id of the label named `name`.
#[cfg(any(test, feature = "admin"))]
pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
    let row = self
        .pool
        .get()
        .await?
        .query_one("SELECT id from label where name = $1", &[&name])
        .await?;
    let label_id: i64 = row.get(0);
    Ok(label_id as u64)
}
/// Attach label `label_id` to file `file_id`.
#[cfg(any(test, feature = "admin"))]
pub async fn label_file(&self, file_id: u64, label_id: u64) -> Result<()> {
    let file_id = file_id as i64;
    let label_id = label_id as i64;
    self.pool
        .get()
        .await?
        .execute(
            "insert into filelabel(fileid, labelid) values ($1, $2)",
            &[&file_id, &label_id],
        )
        .await?;
    Ok(())
}
}
impl Display for Postgres {
    /// Short human-readable description; deliberately exposes no connection
    /// details.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("Postgres client")
    }
}
impl Debug for Postgres {
    /// Same opaque text as `Display`; the pool holds credentials, so nothing
    /// internal is printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("Postgres client")
    }
}
#[cfg(test)]
mod tests {
use super::*;
use deadpool_postgres::tokio_postgres::Config;
use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
use md5::Md5;
use postgres::NoTls;
use sha1::{Digest, Sha1};
use sha2::{Sha256, Sha384, Sha512};
// Integration test for the 0.2.x -> 0.3.0 schema migration. It requires a
// live local Postgres instance (see CONNECTION_STRING below), so it is
// ignored by default and must be run explicitly.
#[tokio::test]
#[ignore = "don't run this in CI"]
async fn migration() {
const CONNECTION_STRING: &str = "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtestingmigration host=localhost sslmode=disable";
let pg_config = CONNECTION_STRING.parse::<Config>().unwrap();
let mgr_config = ManagerConfig {
recycling_method: RecyclingMethod::Fast,
};
let mgr = Manager::from_config(pg_config, NoTls, mgr_config);
let pool = Pool::builder(mgr)
.max_size(num_cpus::get().min(16))
.build()
.unwrap();
let client = pool.get().await.unwrap();
// Start from a clean slate, then load the old v0.2.2 schema and record that
// version so migrate() sees a pre-0.3.0 database.
client
.batch_execute("DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
.await
.unwrap();
client
.batch_execute(include_str!("../../testdata/pg/v0.2.2.sql"))
.await
.unwrap();
client
.execute("update mdbconfig set version = '0.2.2'", &[])
.await
.unwrap();
// Insert one file row using the old representation: all hashes stored as
// hex-encoded text (including md5).
let test_rtf = include_bytes!("../../testdata/empty.rtf");
let rtf_sha1 = hex::encode(Sha1::digest(test_rtf));
let rtf_sha256 = hex::encode(Sha256::digest(test_rtf));
let rtf_sha384 = hex::encode(Sha384::digest(test_rtf));
let rtf_sha512 = hex::encode(Sha512::digest(test_rtf));
let rtf_md5 = hex::encode(Md5::digest(test_rtf));
let rtf_id = 1i64;
let rtf_size = test_rtf.len() as i64;
let rtf_entropy = 2.0f32;
let rtf_type_id: i32 = client
.query_one("select id from filetype where name = 'RTF'", &[])
.await
.unwrap()
.get(0);
let rows = client.execute("insert into file(id, sha1, sha256, sha384, sha512, md5, size, entropy, filetypeid) values($1, $2, $3, $4, $5, $6, $7, $8, $9)",
&[&rtf_id, &rtf_sha1, &rtf_sha256, &rtf_sha384, &rtf_sha512, &rtf_md5, &rtf_size, &rtf_entropy, &rtf_type_id]).await.unwrap();
assert_eq!(rows, 1);
// Run the migration, then verify via a fresh connection that a Check
// reports the schema as current (i.e. no further migration needed).
let pg_db = Postgres {
pool,
has_hash_extensions: false,
};
pg_db.migrate(Migration::Migrate).await.unwrap();
let pg_db = Postgres::new(CONNECTION_STRING, None).await.unwrap();
pg_db.migrate(Migration::Check).await.unwrap();
let client = pg_db.pool.get().await.unwrap();
// After migration the hash columns are bytea (md5 is a UUID); confirm the
// stored values round-trip back to the original hex digests.
let result = client
.query_one(
"select sha1, sha256, sha384, sha512, md5 from file where id = $1",
&[&rtf_id],
)
.await
.unwrap();
let sha1: Vec<u8> = result.get(0);
let sha256: Vec<u8> = result.get(1);
let sha384: Vec<u8> = result.get(2);
let sha512: Vec<u8> = result.get(3);
let md5: Uuid = result.get(4);
assert_eq!(hex::encode(sha1), rtf_sha1);
assert_eq!(hex::encode(sha256), rtf_sha256);
assert_eq!(hex::encode(sha384), rtf_sha384);
assert_eq!(hex::encode(sha512), rtf_sha512);
// A UUID's canonical form is hyphenated hex; strip hyphens to compare with
// the plain hex MD5 digest.
assert_eq!(md5.to_string().replace('-', ""), rtf_md5);
pg_db.delete().await.unwrap();
}
}