// Row ids, counts and sizes move between SQLite's signed integer columns
// (i32/i64) and the unsigned types used by the API (u32/u64) via `as` casts
// throughout this module, so the pedantic cast lints are silenced file-wide.
#![allow(
clippy::cast_sign_loss,
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap
)]
#[cfg(any(test, feature = "admin"))]
use super::admin;
#[cfg(any(test, feature = "admin"))]
use super::hash_password;
use super::{
random_bytes_api_key, DatabaseInformation, FileAddedResult, MDBConfig, Migration,
PARTIAL_SEARCH_LIMIT,
};
use crate::crypto::{EncryptionOption, FileEncryption};
use crate::db::types::{FileMetadata, FileType};
use crate::{MDB_VERSION, MDB_VERSION_SEMVER};
use malwaredb_api::{
digest::HashType, GetUserInfoResponse, Label, Labels, PartialHashSearchType, SearchRequest,
SearchResponse, SearchType, SimilarSample, SimilarityHashType, SourceInfo, Sources,
};
use malwaredb_types::KnownType;
use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
use std::path::PathBuf;
use anyhow::{anyhow, bail, Context, Result};
use argon2::{Argon2, PasswordHash, PasswordVerifier};
#[cfg(any(test, feature = "admin"))]
use chrono::Local;
use chrono::{DateTime, Utc};
use humansize::{make_format, DECIMAL};
#[cfg(feature = "vt")]
use malwaredb_virustotal::filereport::ScanResultAttributes;
use rusqlite::fallible_iterator::FallibleIterator;
use rusqlite::params;
use rusqlite::types::Type;
use rusqlite::{Batch, Connection};
use semver::Version;
use serde_json::json;
use tracing::{debug, error, instrument, warn};
#[cfg(feature = "yara")]
use uuid::Uuid;
/// Separator used to pack several values into one TEXT column
/// (magic numbers, section names/entropies, file names).
const ARRAY_DELIMITER: &str = "|";
/// Schema executed to initialise a database file that does not exist yet.
const SQLITE_SQL: &str = include_str!("malwaredb_sqlite.sql");
/// SQLite-backed MalwareDB storage.
pub struct Sqlite {
/// Open connection to the database file.
conn: Connection,
/// Path of the database file; its metadata is used to report the on-disk size.
file_path: PathBuf,
}
impl Sqlite {
/// Open the MalwareDB SQLite database at `file_name`, creating it if needed.
///
/// A missing file is created on a short-lived connection, populated from the
/// bundled schema, and stamped with this binary's version. The file is then
/// opened again, the similarity SQL functions are registered, and the stored
/// version is compared with this binary's: a newer major version is an error,
/// an older one only logs a warning.
///
/// # Errors
/// Fails if the file cannot be opened or initialised, the stored version is not
/// valid semver, or the stored major version is newer than this binary's.
pub fn new(file_name: &str) -> Result<Self> {
    if !std::path::Path::new(file_name).exists() {
        let setup = Connection::open(file_name)?;
        let mut schema = Batch::new(&setup, SQLITE_SQL);
        while let Some(mut step) = schema.next()? {
            step.execute([])?;
        }
        setup.execute("update mdbconfig set version = ?1", [&crate::MDB_VERSION])?;
    }

    let conn = Connection::open(file_name)?;
    super::sqlite_functions::add_similarity_functions(&conn)?;

    let mdb_db_version: String = conn
        .prepare("select version from mdbconfig")?
        .query_one([], |row| row.get(0))?;
    let mdb_db_version = Version::parse(&mdb_db_version).map_err(|_| {
        anyhow!("Failed to parse MalwareDB version in the database: {mdb_db_version}")
    })?;
    match mdb_db_version.major.cmp(&MDB_VERSION_SEMVER.major) {
        std::cmp::Ordering::Greater => {
            bail!("MalwareDB database schema {mdb_db_version} is newer than this binary version {MDB_VERSION}, a new binary is likely needed.");
        }
        std::cmp::Ordering::Less => {
            warn!("MalwareDB database schema {mdb_db_version} is older than this binary version {MDB_VERSION}, a migration may be needed.");
        }
        std::cmp::Ordering::Equal => {}
    }

    Ok(Self {
        conn,
        file_path: file_name.into(),
    })
}
/// Turn on uploading of samples to VirusTotal in the stored configuration.
#[cfg(feature = "vt")]
pub fn enable_vt_upload(&self) -> Result<()> {
    let sql = "update mdbconfig set send_samples_to_vt = true";
    self.conn
        .execute(sql, ())
        .map(drop)
        .map_err(anyhow::Error::from)
}
/// Turn off uploading of samples to VirusTotal in the stored configuration.
#[cfg(feature = "vt")]
pub fn disable_vt_upload(&self) -> Result<()> {
    let sql = "update mdbconfig set send_samples_to_vt = false";
    self.conn
        .execute(sql, ())
        .map(drop)
        .map_err(anyhow::Error::from)
}
/// Return up to `limit` SHA-256 hashes of files that have no VirusTotal record.
///
/// Rows that fail to decode are skipped rather than reported.
#[cfg(feature = "vt")]
pub fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
    let mut stmt = self.conn.prepare(
        "select file.sha256 from file where file.id not in (select fileid from vtdata) limit ?1",
    )?;
    let hashes: Vec<String> = stmt
        .query_map([limit as i32], |row| row.get::<_, String>(0))?
        .flatten()
        .collect();
    Ok(hashes)
}
/// Store a VirusTotal report for the file whose SHA-256 is `results.sha256`.
///
/// The row records the analysis date (`tstamp`), the malicious-verdict count
/// (`hits`), the number of engines (`total`) and the per-engine results as JSON
/// (`vtdetail`).
///
/// # Errors
/// Fails if no file with that hash exists, the report cannot be serialised, or
/// the insert fails.
#[cfg(feature = "vt")]
pub fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
    let mut statement = self.conn.prepare("select id from file where sha256 = ?1")?;
    let fid = statement.query_row([&results.sha256], |row| {
        let fid: Option<i32> = row.get(0)?;
        Ok(fid)
    })?;
    if let Some(id) = fid {
        let report = serde_json::to_string(&results.last_analysis_results)?;
        // Values must follow the column list order: the timestamp is the second
        // column, so it is bound as ?2 (it was previously bound last, shifting
        // every value into the wrong column).
        self.conn.execute(
            "insert into vtdata(fileid, tstamp, hits, total, vtdetail) values(?1, ?2, ?3, ?4, ?5)",
            (
                id,
                &results.last_analysis_date,
                results.last_analysis_stats.malicious,
                results.last_analysis_stats.av_count(),
                report,
            ),
        )?;
    }
    Ok(())
}
/// Summarise VirusTotal coverage: reports with zero hits, reports with at least
/// one hit, and files that have no report at all.
#[cfg(feature = "vt")]
pub fn get_vt_stats(&self) -> Result<super::VtStats> {
    // Run a single-value `count(1)` query and widen it to u32.
    let count = |sql: &str| -> Result<u32> {
        let n: i32 = self.conn.prepare(sql)?.query_row([], |row| row.get(0))?;
        Ok(n as u32)
    };
    Ok(super::VtStats {
        clean_records: count("select count(1) from vtdata where hits = 0")?,
        hits_records: count("select count(1) from vtdata where hits > 0")?,
        files_without_records: count(
            "select count(1) from file where file.id not in (select fileid from vtdata)",
        )?,
    })
}
/// Queue a Yara search for user `uid` and return its new identifier.
///
/// The compiled rules are stored base64-encoded next to the rule source text.
#[cfg(feature = "yara")]
pub fn add_yara_search(&self, uid: u32, yara_string: &str, yara_bytes: &[u8]) -> Result<Uuid> {
    use base64::Engine;
    let id = Uuid::now_v7();
    let encoded = base64::engine::general_purpose::STANDARD.encode(yara_bytes);
    self.conn
        .prepare(
            "insert into yara_search(id, yara_text, yara_compiled, userid) values(?1, ?2, ?3, ?4)",
        )?
        .execute((id.to_string(), yara_string, encoded, uid))?;
    Ok(id)
}
/// Load up to `crate::yara::MAX_YARA_PROCESSES` Yara searches without a
/// completion time, ordered by id.
///
/// # Errors
/// Fails on query errors, or if a row holds invalid base64 rules or an invalid UUID.
#[cfg(feature = "yara")]
pub fn get_unfinished_yara_tasks(&self) -> Result<Vec<crate::yara::YaraTask>> {
    use base64::Engine;
    // Wrap a decode failure for `column` as a rusqlite conversion error.
    let conversion_error = |column: usize, message: String| {
        rusqlite::Error::FromSqlConversionFailure(column, Type::Text, Box::from(message))
    };
    let mut stmt = self.conn.prepare("select id, yara_text, yara_compiled, userid, last_fileid from yara_search where completed is null order by id asc limit ?1")?;
    let tasks = stmt
        .query_map([crate::yara::MAX_YARA_PROCESSES], |row| {
            let uuid: String = row.get(0)?;
            let yara_bytes: String = row.get(2)?;
            let decoded = base64::engine::general_purpose::STANDARD
                .decode(&yara_bytes)
                .map_err(|e| {
                    conversion_error(2, format!("Invalid base64 `{yara_bytes}` error: {e}"))
                })?;
            let id = Uuid::parse_str(&uuid)
                .map_err(|e| conversion_error(0, format!("Invalid uuid `{uuid}` error: {e}")))?;
            Ok(crate::yara::YaraTask {
                id,
                yara_string: row.get(1)?,
                yara_bytes: decoded,
                user_id: row.get(3)?,
                last_file_id: row.get(4)?,
            })
        })?
        .collect::<Result<Vec<_>, _>>()?;
    Ok(tasks)
}
/// Append a `rule_name|file_sha256` match to the results of Yara search `id`.
///
/// Results are stored as a comma-separated list of `rule|sha256` pairs.
///
/// # Errors
/// Fails if the search does not exist or the update does not change exactly one row.
#[cfg(feature = "yara")]
pub fn add_yara_match(&self, id: Uuid, rule_name: &str, file_sha256: &str) -> Result<()> {
    let id = id.to_string();
    let existing: Option<String> = self
        .conn
        .prepare("select results from yara_search where id = ?1")?
        .query_row([&id], |row| row.get(0))?;
    let result = match existing {
        Some(results) => format!("{results},{rule_name}|{file_sha256}"),
        None => format!("{rule_name}|{file_sha256}"),
    };
    let updated = self
        .conn
        .prepare("update yara_search set results = ?1 where id = ?2")?
        .execute(params![&result, &id])?;
    anyhow::ensure!(
        updated == 1,
        "Failed to update Yara task {id} with results {result}"
    );
    Ok(())
}
/// Stamp Yara search `id` as completed at the database's current time.
#[cfg(feature = "yara")]
pub fn mark_yara_task_as_finished(&self, id: Uuid) -> Result<()> {
    self.conn
        .prepare("update yara_search set completed = current_timestamp where id = ?1")?
        .execute(params![id.to_string()])?;
    Ok(())
}
/// Record `file_id` as the last file processed by Yara search `id`.
///
/// # Errors
/// Fails unless exactly one search row is updated.
#[cfg(feature = "yara")]
pub fn yara_add_next_file_id(&self, id: Uuid, file_id: u64) -> Result<()> {
    let id = id.to_string();
    let changed = self
        .conn
        .prepare("update yara_search set last_fileid = ?1 where id = ?2")?
        .execute(params![&file_id, &id])?;
    if changed != 1 {
        bail!("Failed to update Yara task {id} with last file id {file_id}");
    }
    Ok(())
}
/// Collect the matches of Yara search `id` owned by `user_id`, grouped by rule name.
///
/// Stored results are a comma-separated list of `rule|sha256` entries. A search
/// with no recorded matches (NULL or empty results) yields an empty map.
///
/// # Errors
/// Fails if the search does not exist for this user, or an entry is malformed
/// (missing `|`, invalid hex, or a digest of the wrong length).
#[cfg(feature = "yara")]
pub fn get_yara_results(
    &self,
    id: Uuid,
    user_id: u32,
) -> Result<malwaredb_api::YaraSearchResponse> {
    let id = id.to_string();
    let mut statement = self
        .conn
        .prepare("select results from yara_search where id = ?1 and userid = ?2")?;
    // `results` stays NULL until the first match is appended, so read it as optional.
    let results: Option<String> =
        statement.query_row(params![&id, &user_id], |row| row.get(0))?;
    let mut results_map = HashMap::new();
    for result in results
        .as_deref()
        .unwrap_or_default()
        .split(',')
        .filter(|entry| !entry.is_empty())
    {
        // Report a malformed entry as an error instead of panicking on indexing.
        let (rule_name, file_sha256) = result
            .split_once('|')
            .ok_or_else(|| anyhow!("Malformed Yara result entry `{result}` for task {id}"))?;
        let file_sha256 = hex::decode(file_sha256)?;
        results_map
            .entry(rule_name.to_string())
            .or_insert_with(Vec::new)
            .push(HashType::SHA256(file_sha256.try_into()?));
    }
    Ok(malwaredb_api::YaraSearchResponse {
        results: results_map,
    })
}
/// Read the MalwareDB configuration from the first row of `mdbconfig`.
pub fn get_config(&self) -> Result<MDBConfig> {
    let config = self
        .conn
        .prepare("select name, compress, send_samples_to_vt, keep_unknown_files, defaultKey from mdbconfig")?
        .query_row([], |row| {
            Ok(MDBConfig {
                name: row.get(0)?,
                compression: row.get(1)?,
                send_samples_to_vt: row.get(2)?,
                keep_unknown_files: row.get(3)?,
                default_key: row.get(4)?,
            })
        })?;
    Ok(config)
}
pub fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
let mut statement = self
.conn
.prepare("select password from person where uname = ?1")?;
let results = statement.query_row([uname], |row| {
let db_password: Option<String> = row.get(0)?;
Ok(db_password)
})?;
if let Some(password_hash) = results {
let password_hashed = PasswordHash::new(&password_hash)?;
Argon2::default().verify_password(password.as_ref(), &password_hashed)?;
} else {
bail!("Password not set");
}
let mut statement = self
.conn
.prepare("select apikey from person where uname = ?1")?;
let result = statement.query_row([uname], |row| {
let apikey: Option<String> = row.get(0)?;
let Some(key) = apikey else {
return Ok(None);
};
Ok(Some(key))
})?;
if let Some(apikey) = result {
return Ok(apikey);
}
let apikey = random_bytes_api_key();
self.conn.execute(
"update person set apikey = ?1 where uname = ?2",
(&apikey, &uname),
)?;
Ok(apikey)
}
/// Resolve an API key to the id of the user that owns it.
///
/// # Errors
/// Fails if no user has this key or the stored id is NULL.
pub fn get_uid(&self, apikey: &str) -> Result<u32> {
    let uid: Option<i32> = self
        .conn
        .prepare("select id from person where apikey = ?1")?
        .query_row([apikey], |row| row.get(0))?;
    match uid {
        Some(id) => Ok(id as u32),
        None => Err(anyhow!("unable to get user ID")),
    }
}
/// Gather statistics about this database: SQLite version, on-disk file size
/// (binary units), and the number of users, files, groups and sources.
pub fn db_info(&self) -> Result<DatabaseInformation> {
    let bytes = self.file_path.metadata()?.len();
    let size = humansize::SizeFormatter::new(bytes, humansize::BINARY).to_string();

    let version: String = self
        .conn
        .prepare("SELECT sqlite_version();")?
        .query_row([], |row| row.get(0))?;
    let num_users: u32 = self
        .conn
        .prepare("SELECT count(1) from person")?
        .query_row([], |row| row.get(0))?;
    let num_files: u64 = self
        .conn
        .prepare("SELECT count(1) from file")?
        .query_row([], |row| row.get(0))?;
    let num_groups: u32 = self
        .conn
        .prepare("SELECT count(1) from grp")?
        .query_row([], |row| row.get(0))?;
    let num_sources: u32 = self
        .conn
        .prepare("SELECT count(1) from source")?
        .query_row([], |row| row.get(0))?;

    Ok(DatabaseInformation {
        version: format!("SQLite {version}"),
        size,
        num_files,
        num_users,
        num_groups,
        num_sources,
    })
}
/// Describe user `uid`: name, creation time, read-only flag, group and source
/// memberships, and admin status.
///
/// The user is an admin when one of their groups has id 0 or has group 0 as its
/// parent. A NULL read-only flag is treated as read-only.
///
/// # Errors
/// Fails if the user does not exist, lacks a name or creation date, or a query fails.
pub fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
    let (username, created, readonly): (Option<String>, Option<DateTime<Utc>>, Option<bool>) =
        self.conn
            .prepare("select uname, created, readonly from person where id = ?1")?
            .query_row([uid], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?;
    let username = username.ok_or_else(|| anyhow!("unable to get username"))?;
    let created = created.ok_or_else(|| anyhow!("unable to get created date"))?;
    let is_readonly = readonly.unwrap_or(true);

    let mut group_stmt = self.conn.prepare("select grp.name, grp.id, grp.parent from grp, usergroup where grp.id = usergroup.gid and usergroup.pid = ?1")?;
    // Each membership is (group name, whether that group confers admin rights).
    let memberships = group_stmt
        .query_map([uid], |row| {
            let name: String = row.get(0)?;
            let id: i32 = row.get(1)?;
            let parent: Option<i32> = row.get(2)?;
            Ok((name, id == 0 || parent == Some(0)))
        })?
        .collect::<Result<Vec<_>, _>>()?;
    let is_admin = memberships.iter().any(|(_, grants_admin)| *grants_admin);
    let groups = memberships.into_iter().map(|(name, _)| name).collect();

    let mut source_stmt = self.conn.prepare("select source.name from source, usergroup, groupsource where source.id = groupsource.sourceid and groupsource.gid = usergroup.gid and usergroup.pid = ?1")?;
    let sources = source_stmt
        .query_map([uid], |row| row.get::<_, String>(0))?
        .collect::<Result<Vec<_>, _>>()?;

    Ok(GetUserInfoResponse {
        id: uid,
        username,
        groups,
        sources,
        is_admin,
        created,
        is_readonly,
    })
}
pub fn get_user_sources(&self, uid: u32) -> Result<Sources> {
let mut statement = self.conn.prepare("select source.id, source.name, source.description, source.url, source.firstacquisition, source.malicious from source, usergroup, groupsource where source.id = groupsource.sourceid and groupsource.gid = usergroup.gid and usergroup.pid = ?1")?;
let mut sources = vec![];
let rows = statement.query_map([uid], |row| {
let id: i32 = row.get(0)?;
let name: String = row.get(1)?;
let description: Option<String> = row.get(2)?;
let url: Option<String> = row.get(3)?;
let date: String = row.get(4).expect("failed to get date");
let date = DateTime::parse_from_rfc3339(&date)
.expect("failed to get source date in in Sqlite::list_groups");
let first_acquisition = date.with_timezone(&Utc);
let malicious: Option<bool> = row.get(5)?;
Ok(SourceInfo {
id: id as u32,
name,
description,
url,
first_acquisition,
malicious,
})
})?;
for row in rows {
sources.push(row?);
}
Ok(Sources {
sources,
message: None,
})
}
pub fn reset_own_api_key(&self, uid: u32) -> Result<()> {
self.conn.execute(
"update person set apikey = NULL where id = ?1",
params![uid],
)?;
Ok(())
}
/// Load every known file type, ordered by id.
///
/// The `magic` column holds one or more hex-encoded magic numbers separated by
/// `ARRAY_DELIMITER`. A NULL or non-positive `executable` value means "not executable".
///
/// # Errors
/// Fails on query errors or if a magic number is not valid hex.
pub fn get_known_data_types(&self) -> Result<Vec<FileType>> {
    let mut statement = self
        .conn
        .prepare("select id, name, description, magic, executable from filetype order by 1")?;
    let file_types = statement
        .query_map([], |row| {
            let magic: String = row.get(3)?;
            // `split` yields the whole string when the delimiter is absent, so
            // single and multiple magic numbers share one path. Errors point at
            // the `magic` column (index 3).
            let magic = magic
                .split(ARRAY_DELIMITER)
                .map(|magic_hex| {
                    hex::decode(magic_hex).map_err(|e| {
                        rusqlite::Error::FromSqlConversionFailure(3, Type::Text, Box::new(e))
                    })
                })
                .collect::<Result<Vec<_>, _>>()?;
            let exec: Option<i32> = row.get(4)?;
            Ok(FileType {
                id: row.get(0)?,
                name: row.get(1)?,
                description: row.get(2)?,
                magic,
                executable: exec.unwrap_or_default() > 0,
            })
        })?
        .collect::<Result<Vec<_>, _>>()?;
    Ok(file_types)
}
/// List all labels, each with the name of its parent label if it has one.
///
/// # Errors
/// Fails on query errors or if a label row cannot be decoded.
pub fn get_labels(&self) -> Result<Labels> {
    let mut statement = self.conn.prepare("select label.id, label.name, parent.name from label left outer join label parent on (parent.id = label.parent)")?;
    // Propagate row errors (as the other listing methods do) instead of silently
    // dropping undecodable labels.
    let labels = statement
        .query_map([], |row| {
            let id: i64 = row.get(0)?;
            Ok(Label {
                id: id as u64,
                name: row.get(1)?,
                parent: row.get(2)?,
            })
        })?
        .collect::<Result<Vec<_>, _>>()?;
    Ok(Labels(labels))
}
/// Identify the file type id of `data` from its leading magic bytes.
///
/// Empty magic numbers never match. When nothing matches and the configuration
/// keeps unknown files, the type named "unknown" (case-insensitive) is used.
///
/// # Errors
/// Fails if the types or configuration cannot be loaded, or no type applies.
pub fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
    let db_file_types = self.get_known_data_types()?;
    let matched = db_file_types.iter().find(|file_type| {
        file_type
            .magic
            .iter()
            .any(|magic| !magic.is_empty() && data.starts_with(magic))
    });
    if let Some(file_type) = matched {
        return Ok(file_type.id);
    }

    if self.get_config()?.keep_unknown_files {
        debug!("Keeping unknown file");
        let unknown = db_file_types
            .iter()
            .find(|file_type| file_type.name.eq_ignore_ascii_case("unknown"));
        if let Some(unknown) = unknown {
            return Ok(unknown.id);
        }
    }
    bail!("File type not found")
}
/// Whether user `uid` belongs to at least one group linked to source `sid`.
///
/// # Errors
/// Fails if the membership query fails.
pub fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
    let mut statement = self.conn.prepare(
        "select usergroup.gid from usergroup, groupsource where usergroup.gid = groupsource.gid and usergroup.pid = ?1 and groupsource.sourceid = ?2")?;
    // `exists` only asks whether a matching row is present and reports query
    // errors, rather than letting an iterator silently skip failed rows.
    let allowed = statement.exists([uid as i32, sid as i32])?;
    Ok(allowed)
}
/// Whether user `uid` is an admin: a member of group 0, or of a group whose
/// parent is group 0.
///
/// # Errors
/// Fails if the membership query fails.
pub fn user_is_admin(&self, uid: u32) -> Result<bool> {
    let mut statement = self.conn.prepare(
        "select usergroup.gid from usergroup join grp on (usergroup.gid = grp.id) where usergroup.pid = ?1 and (usergroup.gid = 0 or grp.parent = 0)")?;
    // Report query errors instead of silently treating them as "not admin".
    let admin = statement.exists([uid as i32])?;
    Ok(admin)
}
/// Add a file on behalf of user `uid` under source `sid`, recursing into any
/// child files it contains.
///
/// If a file with the same SHA-512 already exists, its id is reused. Otherwise
/// the following are inserted:
/// - the file row;
/// - the link to `parent`, best-effort (failures are logged);
/// - document or executable details;
/// - each child whose type can be identified, added with this file as parent
///   (child failures are logged, not returned).
///
/// The file is then linked to `sid` if it is not already. `meta.name`, if set,
/// is appended to that link's `ARRAY_DELIMITER`-separated file-name list unless
/// the exact name is already listed.
///
/// # Errors
/// Fails if the user may not upload to `sid`, is read-only, or a database
/// operation for this file fails.
// A single transactional flow; splitting it would scatter the shared ids.
#[allow(clippy::too_many_lines)]
#[instrument]
pub fn add_file(
    &self,
    meta: &FileMetadata,
    known_type: &KnownType<'_>,
    uid: u32,
    sid: u32,
    ftype: u32,
    parent: Option<u64>,
) -> Result<FileAddedResult> {
    if !self.allowed_user_source(uid, sid)? {
        bail!("uid {uid} not allowed to upload to sid {sid}");
    }
    let mut statement = self
        .conn
        .prepare("select readonly from person where id = ?1")?;
    let readonly = statement.query_row([uid as i32], |row| {
        let readonly: bool = row.get(0)?;
        Ok(readonly)
    })?;
    if readonly {
        bail!("user {uid} is read-only!");
    }
    let creation_date = known_type.created().map(|d| d.to_rfc3339());
    let mut statement = self
        .conn
        .prepare("select count(1) from file where sha512 = ?1")?;
    let sha512 = hex::encode(&meta.sha512);
    let result = statement.query_map([&sha512], |row| {
        let counts: i32 = row.get(0)?;
        Ok(counts)
    })?;
    let exists = result
        .flatten()
        .next()
        .context("Failed to get file count from SQLite")?;
    let (fid, new_file) = if exists > 0 {
        // Known file: reuse its id; the row and its details already exist.
        let mut statement = self.conn.prepare("select id from file where sha512 = ?1")?;
        let result = statement.query_map([&sha512], |row| {
            let fid: i32 = row.get(0)?;
            Ok(fid)
        })?;
        let fid = result
            .flatten()
            .next()
            .context("SQLite file ID retrieval failed")?;
        (fid, false)
    } else {
        let ftype = ftype as i32;
        let size = meta.size as i64;
        // Store the MD5 string form without '-' separators.
        let md5 = meta.md5.to_string().replace('-', "");
        let sha1 = hex::encode(&meta.sha1);
        let sha256 = hex::encode(&meta.sha256);
        let sha384 = hex::encode(&meta.sha384);
        self.conn.execute(
            "insert into file(sha1, sha256, sha384, sha512, md5, lzjd, ssdeep, tlsh, humanhash, filecommand, createddate, filetypeid, size, entropy) \
             values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)",
            (
                &sha1,
                &sha256,
                &sha384,
                &sha512,
                &md5,
                &meta.lzjd,
                &meta.ssdeep,
                &meta.tlsh,
                &meta.humanhash,
                &meta.file_command,
                &creation_date,
                &ftype,
                &size,
                &meta.entropy,
            ),
        )?;
        let mut statement = self.conn.prepare("select id from file where sha512 = ?1")?;
        let result = statement.query_map([&sha512], |row| {
            let fid: i32 = row.get(0)?;
            Ok(fid)
        })?;
        let fid = result
            .flatten()
            .next()
            .context("SQLite file record insertion failed")?;
        if let Some(parent_id) = parent {
            let fid = fid as u64;
            let mut statement = self
                .conn
                .prepare("update file set parent = ?1 where id = ?2")?;
            if let Err(e) = statement.execute([&parent_id, &fid]) {
                error!("SQLite: failed to set parent id for a file: {e}");
            }
        }
        if known_type.is_doc() {
            if let Some(doc) = known_type.clone().doc() {
                self.conn.execute("insert into pdf(fileid, title, author, pages, javascript, forms) values(?1, ?2, ?3, ?4, ?5, ?6)",
                    (&fid, &doc.title(), &doc.author(), &doc.pages(), &doc.has_javascript(), &doc.has_form()))?;
            }
        } else if known_type.is_exec() {
            if let Some(exec) = known_type.clone().exec() {
                let section_names: Option<String> = exec.sections().map(|s| {
                    s.iter()
                        .map(|sn| sn.name.clone())
                        .collect::<Vec<String>>()
                        .join(ARRAY_DELIMITER)
                });
                let section_entropies: Option<String> = exec.sections().map(|s| {
                    s.iter()
                        .map(|sn| format!("{}", sn.entropy))
                        .collect::<Vec<String>>()
                        .join(ARRAY_DELIMITER)
                });
                let section_exec: Option<String> = exec.sections().map(|s| {
                    s.iter()
                        .map(|sn| format!("{}", sn.is_executable))
                        .collect::<Vec<String>>()
                        .join(ARRAY_DELIMITER)
                });
                let architecture = exec.architecture().map(|a| a.to_string());
                self.conn.execute(
                    "insert into executable(fileid, architecture, operatingsystem, sections, sectionnames, sectionentropies, sectionexec, importhash, importhashfuzzy) values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
                    (
                        &fid,
                        &architecture,
                        &exec.operating_system().to_string(),
                        &exec.num_sections(),
                        &section_names,
                        &section_entropies,
                        &section_exec,
                        &exec.import_hash().map(|h| h.to_string().replace('-', "")),
                        &exec.fuzzy_imports(),
                    ),
                )?;
            }
        }
        if let Some(children) = known_type.children() {
            for child in children {
                let child_meta = FileMetadata::new(child.contents(), None);
                if let Ok(type_id) = self.get_type_id_for_bytes(child.contents()) {
                    if let Err(e) =
                        self.add_file(&child_meta, &child, uid, sid, type_id, Some(fid as u64))
                    {
                        error!("Sqlite: failed to insert record for child file: {e}");
                    }
                }
            }
        }
        (fid, true)
    };
    let mut statement = self
        .conn
        .prepare("select count(*) from filesource where fileid = ?1 and sourceid = ?2")?;
    let sid = sid as i32;
    let count = statement.query_one([&fid, &sid], |row| {
        let count: i32 = row.get(0)?;
        Ok(count)
    })?;
    if count == 0 {
        self.conn.execute(
            "insert into filesource(fileid, sourceid, userid) values(?1, ?2, ?3)",
            (&fid, &sid, &uid),
        )?;
    }
    if let Some(fname) = &meta.name {
        let mut statement = self
            .conn
            .prepare("select filename from filesource where fileid = ?1 and sourceid = ?2")?;
        // The link row exists (ensured just above); a NULL name list counts as empty.
        let stored: Option<String> = statement
            .query_row([&fid, &sid], |row| row.get(0))
            .context("Failed to read file names for the file source record")?;
        let stored = stored.unwrap_or_default();
        // Compare whole entries: a substring check would treat "a.exe" as
        // already present when "data.exe" is stored.
        let updated_names = if stored.is_empty() {
            Some(fname.clone())
        } else if stored
            .split(ARRAY_DELIMITER)
            .any(|name| name == fname.as_str())
        {
            None
        } else {
            Some(format!("{stored}{ARRAY_DELIMITER}{fname}"))
        };
        if let Some(file_names) = updated_names {
            self.conn.execute(
                "update filesource set filename = ?1 where fileid = ?2 and sourceid = ?3",
                (&file_names, &fid, &sid),
            )?;
        }
    }
    Ok(FileAddedResult {
        file_id: fid as u64,
        is_new: new_file,
    })
}
#[allow(clippy::too_many_lines)]
pub fn partial_search(&self, uid: u32, search: SearchRequest) -> Result<SearchResponse> {
let uid = uid as i32;
if !search.is_valid() {
return Ok(SearchResponse {
hashes: vec![],
pagination: None,
total_results: 0,
message: Some(String::from("Invalid search request")),
});
}
let (parameters, pagination_uuid, next_file_id, total_results, did_query) = match search
.search
{
SearchType::Search(parameters) => (parameters, uuid::Uuid::new_v4(), 0, 0, false),
SearchType::Continuation(pagination_uuid) => {
let mut statement = self.conn.prepare("select query, next_fileid, total_results from pagination where id = ?1 and userid = ?2 and type = 'search'")?;
let result = statement.query_one(params![pagination_uuid, uid], |row| {
let parameters: String = row.get(0)?;
let next_file_id: i64 = row.get(1)?;
let total_results: i64 = row.get(2)?;
Ok((parameters, next_file_id, total_results))
});
let (parameters, next_file_id, total_results) = match result {
Ok((p, n, t)) => (p, n, t),
Err(e) => {
warn!("Failed to get continuation result for user {uid}: {e}");
return Ok(SearchResponse {
hashes: vec![],
pagination: None,
total_results: 0,
message: Some(String::from("Invalid pagination identifier")),
});
}
};
(
serde_json::from_str(¶meters)
.context("Failed to deserialize continuation parameters")?,
pagination_uuid,
next_file_id,
total_results,
true,
)
}
};
let response_hash_type = {
if parameters.response == PartialHashSearchType::Any {
"sha256"
} else {
¶meters.response.to_string()
}
};
let file_name = if let Some(file_name) = ¶meters.file_name {
if file_name.is_empty() {
String::new()
} else {
format!("and filesource.filename like '%{file_name}%'")
}
} else {
String::new()
};
let file_hash = if let Some((hash_type, file_hash)) = ¶meters.partial_hash {
if *hash_type == PartialHashSearchType::Any {
format!("and (file.md5 like '%{file_hash}%' or file.sha1 like '%{file_hash}%' or file.sha256 like '%{file_hash}%' or file.sha384 like '%{file_hash}%' or file.sha512 like '%{file_hash}%')").to_string()
} else {
format!("and file.{hash_type} like '%{file_hash}%'").to_string()
}
} else {
String::new()
};
let labels_join = if parameters.labels.is_some() {
"join filelabel on (filelabel.fileid = file.id) join label on (filelabel.labelid = label.id)"
} else {
""
};
let labels_clause = if let Some(labels) = ¶meters.labels {
let labels: Vec<String> = labels.iter().map(|l| format!("'{l}'")).collect();
format!("and (label.name in ({}))", labels.join(","))
} else {
String::new()
};
let file_type_join = if parameters.file_type.is_some() {
"join filetype on (filetype.id = file.filetypeid)"
} else {
""
};
let file_type_clause = if let Some(file_type) = ¶meters.file_type {
format!("and filetype.name = '{file_type}'")
} else {
String::new()
};
let magic_clause = if let Some(magic) = ¶meters.magic {
format!("and file.filecommand like '%{magic}%'")
} else {
String::new()
};
let limit = parameters.limit.min(PARTIAL_SEARCH_LIMIT);
let mut statement = self.conn.prepare(&format!("select distinct {response_hash_type}, file.id from file \
join filesource on (file.id = filesource.fileid) \
join groupsource on (groupsource.sourceid = filesource.sourceid)\
join usergroup on (groupsource.gid = usergroup.gid) {labels_join} {file_type_join}\
where usergroup.pid = ?1 {file_hash} {file_name} {labels_clause} {file_type_clause} {magic_clause}\
and file.id > ?3 order by file.id limit ?2"))?;
let mut last_id = 0;
let hashes = statement
.query_map(params![uid, limit, next_file_id], |row| {
let hash: String = row.get(0)?;
last_id = row.get(1)?;
Ok(hash)
})
.map_err(|e| anyhow!(e.to_string()))?;
let hashes = hashes.collect::<Result<Vec<_>, _>>()?;
let mut returned_uuid = None;
let total_results = if total_results == 0 {
let mut statement = self.conn.prepare(&format!("select count(distinct {response_hash_type}) from file \
join filesource on (file.id = filesource.fileid) \
join groupsource on (groupsource.sourceid = filesource.sourceid)\
join usergroup on (groupsource.gid = usergroup.gid) {labels_join} {file_type_join}\
where usergroup.pid = ?1 {file_hash} {file_name} {labels_clause} {file_type_clause} {magic_clause}"))?;
let total_results = statement
.query_one(params![uid], |row| {
let count: i64 = row.get(0)?;
Ok(count)
})
.map_err(|e| anyhow!(e.to_string()))?;
if total_results as usize > hashes.len() {
returned_uuid = Some(pagination_uuid);
let mut statement = self.conn.prepare("insert into pagination(id, userid, type, query, next_fileid, total_results) values(?1, ?2, 'search', ?3, ?4, ?5)")?;
statement.execute(params![
pagination_uuid,
uid,
serde_json::to_string(¶meters)?,
last_id,
total_results
])?;
}
total_results
} else {
if did_query {
let mut statement = self.conn.prepare("update pagination set next_fileid = ?1 where uuid = ?2 and userid = ?3 and type = 'search'")?;
statement.execute(params![last_id, pagination_uuid, uid])?;
}
total_results
};
Ok(SearchResponse {
hashes,
pagination: returned_uuid,
total_results: total_results as u64,
message: None,
})
}
/// Delete stale rows older than `crate::DB_CLEANUP_INTERVAL`:
/// - `pagination` records;
/// - with the `yara` feature, Yara searches completed before the cutoff.
///
/// Returns the total number of rows removed.
pub fn cleanup(&self) -> Result<u64> {
    let cutoff = Utc::now() - crate::DB_CLEANUP_INTERVAL;
    let removed = self
        .conn
        .prepare("delete from pagination where created < ?1")?
        .execute(params![cutoff])?;
    #[cfg(feature = "yara")]
    let removed = self
        .conn
        .prepare("delete from yara_search where completed < ?1")?
        .execute(params![cutoff])?
        + removed;
    Ok(removed as u64)
}
/// Look up the SHA-256 of a file matching `hash` that user `uid` can access.
///
/// # Errors
/// Fails (with the database error text) if no accessible file matches.
pub fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
    let sql = format!("select distinct sha256 from file \
        join filesource on (file.id = filesource.fileid) \
        join groupsource on (groupsource.sourceid = filesource.sourceid)\
        join usergroup on (groupsource.gid = usergroup.gid)\
        where file.{} = ?1 and usergroup.pid = ?2", hash.name());
    let mut statement = self.conn.prepare(&sql)?;
    let found = statement.query_row(params![hash.the_hash(), uid], |row| {
        row.get::<_, String>(0)
    });
    found.map_err(|e| anyhow!(e.to_string()))
}
/// Build the report for a file matching `hash` that user `uid` can access.
///
/// With the `vt` feature, the VirusTotal record with the latest timestamp (if
/// any) is included; missing hit/total counts default to zero.
///
/// # Errors
/// Fails (with the database error text) if no accessible file matches or a
/// column cannot be decoded.
pub fn get_sample_report(&self, uid: u32, hash: &HashType) -> Result<malwaredb_api::Report> {
    let mut statement = if cfg!(feature = "vt") {
        self.conn.prepare(&format!("select md5, sha1, sha256, sha384, sha512, lzjd, tlsh, ssdeep, humanhash, filecommand, size, entropy, vtdata.hits, vtdata.total, vtdata.vtdetail, vtdata.tstamp from file \
        join filesource on (file.id = filesource.fileid) \
        join groupsource on (groupsource.sourceid = filesource.sourceid) \
        join usergroup on (groupsource.gid = usergroup.gid) \
        left outer join vtdata on (vtdata.fileid = file.id) \
        where file.{} = ?1 and usergroup.pid = ?2 \
        order by vtdata.tstamp desc limit 1", hash.name()))?
    } else {
        self.conn.prepare(&format!("select md5, sha1, sha256, sha384, sha512, lzjd, tlsh, ssdeep, humanhash, filecommand, size, entropy from file \
        join filesource on (file.id = filesource.fileid) \
        join groupsource on (groupsource.sourceid = filesource.sourceid) \
        join usergroup on (groupsource.gid = usergroup.gid) \
        where file.{} = ?1 and usergroup.pid = ?2", hash.name()))?
    };
    statement
        .query_row(params![&hash.the_hash(), &uid], |row| {
            // `size` is written as an i64 by `add_file`. Reading it as i32 failed
            // for files of 2 GiB or more, and `as u32` truncated sizes above 4 GiB.
            let bytes: i64 = row.get(10)?;
            let bytes = bytes as u64;
            let formatter = make_format(DECIMAL);
            let vt = if cfg!(feature = "vt") {
                let hits: Option<u32> = row.get(12)?;
                let total: Option<u32> = row.get(13)?;
                let detail: Option<String> = row.get(14)?;
                Some(malwaredb_api::VirusTotalSummary {
                    hits: hits.unwrap_or_default(),
                    total: total.unwrap_or_default(),
                    detail: detail.map(|d| json!(d)),
                    last_analysis_date: row.get(15)?,
                })
            } else {
                None
            };
            Ok(malwaredb_api::Report {
                md5: row.get(0)?,
                sha1: row.get(1)?,
                sha256: row.get(2)?,
                sha384: row.get(3)?,
                sha512: row.get(4)?,
                lzjd: row.get(5)?,
                tlsh: row.get(6)?,
                ssdeep: row.get(7)?,
                humanhash: row.get(8)?,
                filecommand: row.get(9)?,
                bytes,
                size: formatter(bytes),
                entropy: row.get(11)?,
                vt,
            })
        })
        .map_err(|e| anyhow!(e.to_string()))
}
/// Find files accessible to user `uid` that are similar to the given hashes.
///
/// Algorithms with a SQL similarity function match as follows:
/// - TLSH on function values below 500;
/// - other algorithms on values above 0.
///
/// Algorithms without a function match on exact equality of the executable
/// column and get a score of 100. Results are grouped by SHA-256, with one
/// `(algorithm, score)` entry per matching algorithm.
///
/// # Errors
/// Fails if a similarity query fails or a row cannot be decoded.
#[instrument]
pub fn find_similar_samples(
    &self,
    uid: u32,
    sim: &[(SimilarityHashType, String)],
) -> Result<Vec<SimilarSample>> {
    let mut results = HashMap::<String, Vec<(SimilarityHashType, f32)>>::new();
    for (algo, (table_field, hash_func), hash_value) in sim
        .iter()
        .map(|(sim, val)| (*sim, sim.get_table_field_simfunc(), val))
    {
        // `distinct` collapses duplicate rows for a file reachable through
        // several of the user's groups/sources.
        let sql = if let Some(hash_func) = hash_func {
            let threshold = if algo == SimilarityHashType::TLSH {
                "< 500"
            } else {
                "> 0"
            };
            format!(
                "select distinct sha256, CAST({hash_func}({table_field}, ?2) as REAL) as similarity from file \
                 join filesource on (file.id = filesource.fileid) \
                 join groupsource on (groupsource.sourceid = filesource.sourceid) \
                 join usergroup on (groupsource.gid = usergroup.gid) \
                 where {hash_func}({table_field}, ?2) {threshold} and usergroup.pid = ?1"
            )
        } else {
            // The executable table's key column is `fileid` (see `add_file`).
            format!(
                "select distinct file.sha256, 100.0 \
                 from file join executable on (file.id = executable.fileid) \
                 join filesource on (file.id = filesource.fileid) \
                 join groupsource on (groupsource.sourceid = filesource.sourceid) \
                 join usergroup on (groupsource.gid = usergroup.gid) \
                 where {table_field} = ?2 and usergroup.pid = ?1"
            )
        };
        let mut statement = self.conn.prepare(&sql)?;
        let rows = statement.query_map(params![&uid, &hash_value], |row| {
            let sha256: String = row.get(0)?;
            let similarity: f32 = row.get(1)?;
            Ok((sha256, similarity))
        })?;
        for row in rows {
            let (sha256, similarity) = row?;
            results
                .entry(sha256)
                .or_default()
                .push((algo, similarity));
        }
    }
    Ok(results
        .into_iter()
        .map(|(sha256, algorithms)| SimilarSample { sha256, algorithms })
        .collect())
}
/// Returns one page of SHA-256 hashes for the files user `uid` is allowed to see.
///
/// Files are ordered by database id, and at most `PARTIAL_SEARCH_LIMIT` hashes are returned.
/// The second element is the id of the last file in the page; pass it back as `next` to
/// fetch the following page. It is `0` when the page is empty.
///
/// # Errors
/// Fails if the query cannot be prepared or a row cannot be read.
pub fn user_allowed_files_by_sha256(
    &self,
    uid: u32,
    next: Option<u64>,
) -> Result<(Vec<String>, u64)> {
    // `after` is an integer, so interpolating it into the SQL cannot inject anything.
    let id_filter = next.map_or_else(String::new, |after| format!(" and file.id > {after}"));
    let query = format!(
        "select distinct file.sha256, file.id \
        from file \
        join filesource on (file.id = filesource.fileid) \
        join groupsource on (groupsource.sourceid = filesource.sourceid) \
        join usergroup on (groupsource.gid = usergroup.gid) \
        where usergroup.pid = ?1{id_filter} \
        order by file.id asc \
        limit ?2"
    );
    let mut statement = self.conn.prepare(&query)?;
    let rows = statement.query_map(params![&uid, &PARTIAL_SEARCH_LIMIT], |row| {
        Ok((row.get::<_, String>(0)?, row.get::<_, u64>(1)?))
    })?;
    let mut hashes = Vec::with_capacity(PARTIAL_SEARCH_LIMIT as usize);
    let mut last_id = 0;
    for row in rows {
        let (sha256, id) = row?;
        last_id = id;
        hashes.push(sha256);
    }
    Ok((hashes, last_id))
}
/// Loads every stored file-encryption key, indexed by its database id.
///
/// Key material is stored hex-encoded in `encryptionkey.bytes`, and the algorithm name in
/// `encryptionkey.name`.
///
/// # Errors
/// Fails if the query fails, a key is not valid hex, its algorithm name is not a known
/// `EncryptionOption`, or `FileEncryption::new` rejects it.
pub(crate) fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
    let mut statement = self
        .conn
        .prepare("select id, name, bytes from encryptionkey")?;
    let rows = statement.query_map([], |row| {
        Ok((
            row.get::<_, u32>(0)?,
            row.get::<_, String>(1)?,
            row.get::<_, String>(2)?,
        ))
    })?;
    let keys = rows
        .map(|row| -> Result<(u32, FileEncryption)> {
            let (id, algorithm, hex_key) = row?;
            let key_bytes = hex::decode(hex_key)?;
            let option = EncryptionOption::try_from(algorithm.as_str())?;
            Ok((id, FileEncryption::new(option, key_bytes)?))
        })
        .collect::<Result<HashMap<_, _>>>()?;
    Ok(keys)
}
/// Fetches the encryption key id and decoded nonce recorded for the file with SHA-256 `hash`.
///
/// Either value is `None` when the corresponding column is NULL. The nonce is stored
/// hex-encoded and is returned as raw bytes.
///
/// # Errors
/// Fails if no file has this hash, the query fails, or the stored nonce is not valid hex.
pub(crate) fn get_file_encryption_key_id(
    &self,
    hash: &str,
) -> Result<(Option<u32>, Option<Vec<u8>>)> {
    let mut statement = self
        .conn
        .prepare("select key, nonce from file where sha256 = ?1")?;
    let (key_id, hex_nonce) = statement.query_row([hash], |row| {
        Ok((
            row.get::<_, Option<u32>>(0)?,
            row.get::<_, Option<String>>(1)?,
        ))
    })?;
    let nonce = hex_nonce.map(hex::decode).transpose()?;
    Ok((key_id, nonce))
}
/// Stores `nonce` hex-encoded for the file with SHA-256 `hash`, or clears it when `None`.
///
/// The number of updated rows is not checked, so an unknown hash is not an error.
pub(crate) fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
    let hex_nonce: Option<String> = nonce.map(hex::encode);
    let _rows_changed = self.conn.execute(
        "update file set nonce = ?2 where sha256 = ?1",
        (&hash, &hex_nonce),
    )?;
    Ok(())
}
/// Brings the database schema up to date with this build of MalwareDB.
///
/// * A stored version older than 0.2.0 is rejected as too old to migrate automatically.
/// * A stored version older than 0.3.0 lacks the `yara_search` table. With
///   `Migration::Migrate` (test or `admin` builds only), the table is created and the
///   stored version set to `0.3.0`. With `Migration::Check`, or in builds without the
///   `admin` feature, an error reports that migration is needed.
/// * In every non-error case, the stored version is finally overwritten with `MDB_VERSION`.
///
/// # Errors
/// Fails if the stored version cannot be read or parsed as semver, if the database is too
/// old, if a needed migration is not performed, or if any SQL statement fails.
///
/// NOTE(review): the final version update also runs for `Migration::Check`, and there is no
/// upper-bound check, so a database written by a newer MalwareDB would be relabelled with
/// this (older) version. Confirm this is intended.
/// NOTE(review): the `CREATE TABLE` and the version bump are not wrapped in a transaction.
/// If the bump fails, the next attempt fails on `CREATE TABLE` because the table exists.
// `action` is only read when the test/admin migration code is compiled in.
#[cfg_attr(not(test), allow(unused_variables))]
pub fn migrate(&self, action: Migration) -> Result<()> {
// Read the schema version currently recorded in the database.
let mut statement = self.conn.prepare("select version from mdbconfig")?;
let mdb_db_version: String = statement.query_one([], |row| row.get(0))?;
let mdb_db_version = Version::parse(&mdb_db_version).map_err(|_| {
anyhow!("Failed to parse MalwareDB version in the database: {mdb_db_version}")
})?;
if mdb_db_version < Version::new(0, 2, 0) {
bail!("MalwareDB database version is too old to migrate automatically");
}
// 0.3.0 introduced the `yara_search` table.
if mdb_db_version < Version::new(0, 3, 0) {
#[cfg(any(test, feature = "admin"))]
{
match action {
Migration::Migrate => {
self.conn.execute(
"CREATE TABLE yara_search (
id text NOT NULL,
yara_text text NOT NULL,
yara_compiled text,
userid INTEGER NOT NULL REFERENCES person(id),
results text,
last_fileid INTEGER REFERENCES file(id),
completed timestamp with time zone,
created timestamp with time zone NOT NULL DEFAULT current_timestamp,
PRIMARY KEY (id)
);",
[],
)?;
self.conn
.execute("update mdbconfig set version = '0.3.0';", [])?;
}
Migration::Check => {
bail!("MalwareDB database needs migration.");
}
}
}
// Without the admin tooling compiled in, an old schema can only be reported.
#[cfg(not(any(test, feature = "admin")))]
bail!("MalwareDB database needs migration.");
}
// Record the version of this build as the database's schema version.
self.conn
.execute("update mdbconfig set version = ?1;", params![&MDB_VERSION])?;
Ok(())
}
/// Stores `name` as the instance name in the `mdbconfig` table.
pub fn set_name(&self, name: &str) -> Result<()> {
    let sql = "update mdbconfig set name = ?1";
    let _rows_changed = self.conn.execute(sql, params![name])?;
    Ok(())
}
/// Sets the `compress` flag in the `mdbconfig` table to true.
#[cfg(any(test, feature = "admin"))]
pub fn enable_compression(&self) -> Result<()> {
    let sql = "update mdbconfig set compress = true";
    let _rows_changed = self.conn.execute(sql, ())?;
    Ok(())
}
/// Sets the `compress` flag in the `mdbconfig` table to false.
#[cfg(any(test, feature = "admin"))]
pub fn disable_compression(&self) -> Result<()> {
    let sql = "update mdbconfig set compress = false";
    let _rows_changed = self.conn.execute(sql, ())?;
    Ok(())
}
/// Sets the `keep_unknown_files` flag in the `mdbconfig` table to true.
#[cfg(any(test, feature = "admin"))]
pub fn enable_keep_unknown_files(&self) -> Result<()> {
    let sql = "update mdbconfig set keep_unknown_files = true";
    let _rows_changed = self.conn.execute(sql, ())?;
    Ok(())
}
/// Sets the `keep_unknown_files` flag in the `mdbconfig` table to false.
#[cfg(any(test, feature = "admin"))]
pub fn disable_keep_unknown_files(&self) -> Result<()> {
    let sql = "update mdbconfig set keep_unknown_files = false";
    let _rows_changed = self.conn.execute(sql, ())?;
    Ok(())
}
/// Stores a new file-encryption key and makes it the default key in `mdbconfig`.
///
/// The key material is stored hex-encoded, together with the algorithm name reported by
/// `key.name()`. Returns the database id of the inserted key.
///
/// # Errors
/// Fails if any of the insert, lookup or update statements fail.
#[cfg(any(test, feature = "admin"))]
pub fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
    let bytes = hex::encode(key.key());
    let name = key.name();
    self.conn.execute(
        "insert into encryptionkey(name, bytes) values(?1, ?2)",
        (&name, &bytes),
    )?;
    // Match on both columns and take the highest id. If identical key bytes were stored
    // before (possibly under another algorithm), a lookup on `bytes` alone could return that
    // older row and make it the default instead of the key just inserted.
    let mut statement = self.conn.prepare(
        "select id from encryptionkey where bytes = ?1 and name = ?2 order by id desc limit 1",
    )?;
    let key_id: u32 = statement.query_row((&bytes, &name), |row| row.get(0))?;
    self.conn
        .execute("update mdbconfig set defaultKey = ?1 ", params![key_id])?;
    Ok(key_id)
}
/// Lists the id and algorithm of every stored encryption key, without reading key material.
///
/// # Errors
/// Fails if the query fails or a stored name is not a known `EncryptionOption`.
#[cfg(any(test, feature = "admin"))]
pub fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
    let mut statement = self.conn.prepare("select id, name from encryptionkey")?;
    let rows = statement.query_map([], |row| {
        Ok((row.get::<_, u32>(0)?, row.get::<_, String>(1)?))
    })?;
    let mut keys = Vec::new();
    for row in rows {
        let (id, algorithm) = row?;
        keys.push((id, EncryptionOption::try_from(algorithm.as_str())?));
    }
    Ok(keys)
}
/// Creates a user account and returns its id.
///
/// When `password` is `Some`, it is hashed with `hash_password` before being stored. When it
/// is `None`, the `password` column is left out of the insert. The creation time is stored
/// in RFC 3339 form.
///
/// # Errors
/// Fails with "User already exists" if `uname` is taken, with "User not created" if the new
/// row cannot be found afterwards, and on any hashing or database error.
#[allow(clippy::too_many_arguments)]
#[cfg(any(test, feature = "admin"))]
pub fn create_user(
    &self,
    uname: &str,
    fname: &str,
    lname: &str,
    email: &str,
    password: Option<String>,
    organisation: Option<&String>,
    readonly: bool,
) -> Result<u32> {
    // `count(1)` always yields exactly one row. Any error reading it is propagated, because
    // silently skipping the row would let a duplicate `uname` slip past this check.
    let existing: u32 = self.conn.query_row(
        "select count(1) from person where uname = ?1",
        [uname],
        |row| row.get(0),
    )?;
    if existing != 0 {
        bail!("User already exists");
    }
    let now: DateTime<Utc> = Utc::now();
    match password {
        None => {
            self.conn
                .execute("insert into person(email, uname, firstname, lastname, organisation, created, readonly) values (?1, ?2, ?3, ?4, ?5, ?6, ?7);", (&email, &uname, &fname, &lname, &organisation, &now.to_rfc3339(), &readonly))?;
        }
        Some(pass) => {
            let password = hash_password(&pass)?;
            self.conn
                .execute("insert into person(email, uname, firstname, lastname, organisation, password, created, readonly) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8);", (&email, &uname, &fname, &lname, &organisation, &password, &now.to_rfc3339(), &readonly))?;
        }
    }
    let mut statement = self
        .conn
        .prepare("select id from person where uname = ?1")?;
    // `transpose()?` propagates a row read error; only a missing row means "not created".
    let uid = statement
        .query_map([uname], |row| row.get::<_, u64>(0))?
        .next()
        .transpose()?;
    match uid {
        Some(uid) => Ok(uid as u32),
        None => bail!("User not created"),
    }
}
/// Clears the API key of every user and returns how many rows were updated.
#[cfg(any(test, feature = "admin"))]
pub fn reset_api_keys(&self) -> Result<u64> {
    let sql = "update person set apikey = NULL";
    let cleared = self.conn.execute(sql, ())?;
    Ok(cleared as u64)
}
/// Replaces the password of user `uname` with a hash of `password`.
///
/// # Errors
/// Fails if hashing fails, the update fails, or the update does not change exactly one row
/// (for example, because no user is called `uname`).
#[cfg(any(test, feature = "admin"))]
pub fn set_password(&self, uname: &str, password: &str) -> Result<()> {
    let password = hash_password(password)?;
    // The update must run unconditionally. Do not move it into `debug_assert_eq!`: that
    // macro only evaluates its arguments when debug assertions are enabled, so release
    // builds would silently skip the update.
    let updated = self.conn.execute(
        "update person set password = ?1 where uname = ?2",
        (&password, &uname),
    )?;
    if updated != 1 {
        bail!("Failed to set password for {uname}: {updated} users updated");
    }
    Ok(())
}
/// Lists every user account.
///
/// `has_password` and `has_api_key` only report whether a non-empty value is stored; the
/// password hash and API key themselves are not read.
///
/// # Errors
/// Fails if the query cannot be prepared. A row that cannot be converted yields
/// "Failed to fetch user data", with the underlying database error attached as its cause.
#[cfg(any(test, feature = "admin"))]
pub fn list_users(&self) -> Result<Vec<admin::User>> {
    let mut statement = self.conn.prepare("select id,email,uname,firstname,lastname,password is not null and length(password)>0,apikey is not null and length(apikey)>0,organisation,phone,created,readonly from person")?;
    let rows = statement.query_map([], |row| {
        Ok(admin::User {
            id: row.get(0)?,
            email: row.get(1)?,
            uname: row.get(2)?,
            fname: row.get(3)?,
            lname: row.get(4)?,
            has_password: row.get(5)?,
            has_api_key: row.get(6)?,
            org: row.get(7)?,
            phone: row.get(8)?,
            created: row.get(9)?,
            is_readonly: row.get(10)?,
        })
    })?;
    let mut users = Vec::new();
    for result in rows {
        // Same top-level message as before, but the rusqlite error is kept as the cause.
        users.push(result.context("Failed to fetch user data")?);
    }
    Ok(users)
}
/// Returns the id of the group called `name`.
///
/// # Errors
/// Fails if no group has that name or the query fails.
#[cfg(any(test, feature = "admin"))]
pub fn group_id_from_name(&self, name: &str) -> Result<i32> {
    let gid = self
        .conn
        .prepare("select id from grp where name = ?1")?
        .query_row(params![name], |row| row.get::<_, i32>(0))?;
    Ok(gid)
}
/// Overwrites the name, description and parent of group `gid`.
///
/// `parent: None` stores NULL. The number of updated rows is not checked.
#[cfg(any(test, feature = "admin"))]
pub fn edit_group(&self, gid: u32, name: &str, desc: &str, parent: Option<u32>) -> Result<()> {
    let sql = "update grp set name = ?1, description = ?2, parent = ?3 where id = ?4";
    let _rows_changed = self.conn.execute(sql, params![&name, &desc, &parent, &gid])?;
    Ok(())
}
/// Lists every group with its parent's name, its members, the sources linked to it (each
/// with its own file and group counts) and the number of distinct files reachable through
/// those sources.
///
/// # Errors
/// Fails if a statement cannot be prepared. A row that cannot be read is reported as
/// "Failed to fetch group data: …" instead of panicking. This includes a source whose
/// `firstacquisition` value is not an RFC 3339 date.
#[cfg(any(test, feature = "admin"))]
pub fn list_groups(&self) -> Result<Vec<admin::Group>> {
    let mut statement = self.conn.prepare("select grp.id, grp.name, grp.description, parent.name from grp left join grp parent on grp.parent = parent.id order by 1")?;
    // The per-group queries are prepared once and re-executed for every group row.
    let mut members_statement = self.conn.prepare("select person.id, person.uname, person.email, person.firstname, person.lastname, person.organisation, person.phone, person.password is not null and length(person.password)>0,person.apikey is not null and length(person.apikey)>0, person.created, person.readonly from person, usergroup where person.id = usergroup.pid and usergroup.gid = ?1")?;
    let mut sources_statement = self.conn.prepare("select source.id, source.name, source.description, source.url, source.firstacquisition, source.malicious, parent_source.name from source left join source as parent_source on (source.parent = parent_source.id), groupsource where source.id = groupsource.sourceid and groupsource.gid = ?1")?;
    let mut source_counts_statement = self.conn.prepare("select (select count(1) from filesource where sourceid = ?1), (select count(1) from groupsource where sourceid = ?2)")?;
    // Files reach a group through its sources, so `groupsource` must be joined on the
    // source id (not the group id). `distinct` counts a file once, even when several of the
    // group's sources contain it.
    let mut files_statement = self.conn.prepare("select count(distinct filesource.fileid) from filesource join groupsource on groupsource.sourceid = filesource.sourceid where groupsource.gid = ?1")?;
    let rows = statement.query_map([], |row| {
        let id: i32 = row.get(0)?;
        let members = members_statement
            .query_map([id], |member_row| {
                Ok(admin::User {
                    id: member_row.get(0)?,
                    uname: member_row.get(1)?,
                    email: member_row.get(2)?,
                    fname: member_row.get(3)?,
                    lname: member_row.get(4)?,
                    org: member_row.get(5)?,
                    phone: member_row.get(6)?,
                    has_password: member_row.get(7)?,
                    has_api_key: member_row.get(8)?,
                    created: member_row.get(9)?,
                    is_readonly: member_row.get(10)?,
                })
            })?
            .collect::<rusqlite::Result<Vec<_>>>()?;
        let sources = sources_statement
            .query_map([id], |source_row| {
                let source_id = source_row.get(0)?;
                let date: String = source_row.get(4)?;
                // A malformed date is reported as a conversion error on column 4 instead of
                // panicking.
                let date = chrono::DateTime::parse_from_rfc3339(&date)
                    .map_err(|e| {
                        rusqlite::Error::FromSqlConversionFailure(4, Type::Text, Box::new(e))
                    })?
                    .with_timezone(&Local);
                let (files, groups) =
                    source_counts_statement.query_row([source_id, source_id], |counts_row| {
                        let files: u64 = counts_row.get(0)?;
                        let groups: u32 = counts_row.get(1)?;
                        Ok((files, groups))
                    })?;
                Ok(admin::Source {
                    id: source_id,
                    name: source_row.get(1)?,
                    description: source_row.get(2)?,
                    url: source_row.get(3)?,
                    date,
                    files,
                    groups,
                    parent: source_row.get(6)?,
                    malicious: source_row.get(5)?,
                })
            })?
            .collect::<rusqlite::Result<Vec<_>>>()?;
        let files: u32 = files_statement.query_row([id], |files_row| files_row.get(0))?;
        Ok(admin::Group {
            id: id as u32,
            name: row.get(1)?,
            description: row.get(2)?,
            parent: row.get(3)?,
            members,
            sources,
            files,
        })
    })?;
    let mut groups = Vec::new();
    for result in rows {
        match result {
            Ok(group) => groups.push(group),
            Err(e) => bail!("Failed to fetch group data: {e}"),
        }
    }
    Ok(groups)
}
/// Adds user `uid` to group `gid` by inserting a `usergroup` row.
#[cfg(any(test, feature = "admin"))]
pub fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
    let membership = (uid, gid);
    let _rows_changed = self
        .conn
        .execute("insert into usergroup(pid, gid) values(?1, ?2)", membership)?;
    Ok(())
}
/// Links group `gid` to source `sid` by inserting a `groupsource` row.
#[cfg(any(test, feature = "admin"))]
pub fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
    let link = (gid, sid);
    let _rows_changed = self
        .conn
        .execute("insert into groupsource(gid, sourceid) values(?1, ?2)", link)?;
    Ok(())
}
/// Creates a group, optionally as a child of group `parent`, and returns its id.
///
/// The id is looked up by `name` after the insert.
///
/// # Errors
/// Fails if the insert does not create exactly one row, if no group with `name` is found
/// afterwards ("unable to get group ID"), or on any database error.
#[cfg(any(test, feature = "admin"))]
pub fn create_group(&self, name: &str, description: &str, parent: Option<u32>) -> Result<u32> {
    let result = if let Some(pid) = parent {
        self.conn.execute(
            "insert into grp(name, description, parent) values (?1, ?2, ?3);",
            params![&name, &description, &pid],
        )
    } else {
        self.conn.execute(
            "insert into grp(name, description) values (?1, ?2);",
            [&name, &description],
        )
    }?;
    if result != 1 {
        bail!("failed to create group {name}");
    }
    let mut statement = self.conn.prepare("select id from grp where name = ?1")?;
    // `transpose()?` propagates a failed row read instead of discarding it, so only a
    // genuinely missing row is reported as "unable to get group ID".
    let gid = statement
        .query_map([name], |row| row.get::<_, i32>(0))?
        .next()
        .transpose()?;
    gid.map(|g| g as u32)
        .ok_or_else(|| anyhow!("unable to get group ID"))
}
/// Lists every source with its parent's name, its first-acquisition date (converted to
/// local time), and how many files and groups are linked to it.
///
/// # Errors
/// Fails if a statement cannot be prepared. A row that cannot be read, including one whose
/// `firstacquisition` is not an RFC 3339 date, yields "Failed to fetch source data", with
/// the database error attached as its cause.
#[cfg(any(test, feature = "admin"))]
pub fn list_sources(&self) -> Result<Vec<admin::Source>> {
    let mut statement = self
        .conn
        .prepare("select source.id, source.name, source.description, source.url, source.firstacquisition, source.malicious, parent_source.name from source left join source as parent_source on (source.parent = parent_source.id)")?;
    // Prepared once and re-run for each source, instead of being re-prepared per row.
    let mut counts_statement = self.conn.prepare("select (select count(1) from filesource where sourceid = ?1), (select count(1) from groupsource where sourceid = ?1)")?;
    let rows = statement.query_map([], |row| {
        let id = row.get(0)?;
        let (files, groups) = counts_statement.query_row([id], |counts_row| {
            let files: u64 = counts_row.get(0)?;
            let groups: u32 = counts_row.get(1)?;
            Ok((files, groups))
        })?;
        let date: String = row.get(4)?;
        // A malformed date is reported as a conversion error on column 4 instead of
        // panicking.
        let date = chrono::DateTime::parse_from_rfc3339(&date)
            .map_err(|e| rusqlite::Error::FromSqlConversionFailure(4, Type::Text, Box::new(e)))?
            .with_timezone(&Local);
        Ok(admin::Source {
            id,
            name: row.get(1)?,
            description: row.get(2)?,
            url: row.get(3)?,
            date,
            files,
            groups,
            parent: row.get(6)?,
            malicious: row.get(5)?,
        })
    })?;
    let mut sources = Vec::new();
    for result in rows {
        sources.push(result.context("Failed to fetch source data")?);
    }
    Ok(sources)
}
/// Creates a source and returns its id, looked up by `name` after the insert.
///
/// `date` is stored as RFC 3339. `malicious` is stored as 1/0, or NULL when unknown.
///
/// # Errors
/// Fails if the insert does not create exactly one row, no source with `name` is found
/// afterwards, or any statement fails.
#[cfg(any(test, feature = "admin"))]
pub fn create_source(
    &self,
    name: &str,
    description: Option<&str>,
    url: Option<&str>,
    date: DateTime<Local>,
    releasable: bool,
    malicious: Option<bool>,
) -> Result<u32> {
    let malicious_flag: Option<u32> = malicious.map(u32::from);
    let first_acquisition = date.to_rfc3339();
    let inserted = self.conn.execute(
        "insert into source(name, description, url, firstacquisition, releasable, malicious) values (?1, ?2, ?3, ?4, ?5, ?6);",
        params![&name, &description, url, first_acquisition, releasable, malicious_flag],
    )?;
    if inserted != 1 {
        bail!("failed to create source {name}");
    }
    let sid = self
        .conn
        .prepare("select id from source where name = ?1")?
        .query_row([name], |row| row.get::<_, i32>(0))?;
    Ok(sid as u32)
}
/// Overwrites the username, names, email and read-only flag of user `uid`.
///
/// The number of updated rows is not checked.
#[cfg(any(test, feature = "admin"))]
pub fn edit_user(
    &self,
    uid: u32,
    uname: &str,
    fname: &str,
    lname: &str,
    email: &str,
    readonly: bool,
) -> Result<()> {
    let sql = "update person set uname = ?1, email = ?2, firstname = ?3, lastname = ?4, readonly = ?5 where id = ?6;";
    let _rows_changed =
        self.conn
            .execute(sql, params![&uname, &email, &fname, &lname, &readonly, &uid])?;
    Ok(())
}
/// Deactivates user `uid` by clearing both its password and its API key.
#[cfg(any(test, feature = "admin"))]
pub fn deactivate_user(&self, uid: u32) -> Result<()> {
    let sql = "update person set password = null, apikey = null where id = ?1;";
    let _rows_changed = self.conn.execute(sql, params![&uid])?;
    Ok(())
}
/// Counts stored files per file-type name.
///
/// # Errors
/// Fails if the query fails or any row cannot be read. Failed rows are no longer silently
/// skipped, which previously produced incomplete counts.
#[cfg(any(test, feature = "admin"))]
pub fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
    let mut statement = self.conn.prepare("SELECT filetype.name, count(file.id) from file join filetype on (file.filetypeid = filetype.id) group by 1")?;
    let rows = statement.query_map([], |row| {
        let name: String = row.get(0)?;
        let count: i32 = row.get(1)?;
        Ok((name, count as u32))
    })?;
    // Collecting into a `Result` stops at, and reports, the first row error.
    let types_counts = rows.collect::<rusqlite::Result<HashMap<_, _>>>()?;
    Ok(types_counts)
}
/// Creates a label, optionally under `parent`, and returns its id, looked up by `name`.
///
/// # Errors
/// Fails if the insert fails or no label with `name` is found afterwards.
#[cfg(any(test, feature = "admin"))]
pub fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
    let _rows_changed = self.conn.execute(
        "insert into label(name, parent) values (?1, ?2);",
        params![&name, &parent],
    )?;
    // Same lookup as `label_id_from_name`: `select id from label where name = ?1`.
    self.label_id_from_name(name)
}
/// Overwrites the name and parent of label `id`. The number of updated rows is not checked.
#[cfg(any(test, feature = "admin"))]
pub fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
    let sql = "update label set name = ?1, parent = ?2 where id = ?3";
    let _rows_changed = self.conn.execute(sql, params![&name, &parent, &id])?;
    Ok(())
}
/// Returns the id of the label called `name`.
///
/// # Errors
/// Fails if no label has that name or the query fails.
#[cfg(any(test, feature = "admin"))]
pub fn label_id_from_name(&self, name: &str) -> Result<u64> {
    let lid: i64 = self
        .conn
        .prepare("select id from label where name = ?1")?
        .query_row([name], |row| row.get(0))?;
    Ok(lid as u64)
}
/// Attaches label `label_id` to file `file_id` by inserting a `filelabel` row.
#[cfg(any(test, feature = "admin"))]
pub fn label_file(&self, file_id: u64, label_id: u64) -> Result<()> {
    let sql = "insert into filelabel(fileid, labelid) values (?1, ?2)";
    let _rows_changed = self.conn.execute(sql, params![file_id, label_id])?;
    Ok(())
}
}
// NOTE(review): `Sqlite` holds only a `rusqlite::Connection` and a `PathBuf`. rusqlite
// documents `Connection` as `Send` but not `Sync` (it relies on interior mutability, e.g. its
// statement cache); verify against the rusqlite version in use. If so, the `Send` impl is
// redundant, and the `Sync` impl lets several threads use one connection through `&Sqlite`
// without synchronisation. Confirm that callers serialise access (e.g. behind a `Mutex`).
unsafe impl Send for Sqlite {}
unsafe impl Sync for Sqlite {}
impl Display for Sqlite {
    /// Writes the fixed label `SQLite client`; no path or connection details are shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("SQLite client")
    }
}
impl Debug for Sqlite {
    /// Writes the same fixed label as `Display`, so debug output never exposes the connection.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("SQLite client")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::Path;

    /// Deletes `path` if it exists, e.g. when left behind by an earlier run.
    fn remove_if_present(path: &str) {
        if Path::new(path).exists() {
            fs::remove_file(path).unwrap();
        }
    }

    /// Loads the v0.2.2 schema labelled as version 0.2.0, migrates it, then reopens the file
    /// and checks that no further migration is needed.
    #[test]
    fn migration() {
        const DB_FILE: &str = "testing_sqlite_migration.db";
        remove_if_present(DB_FILE);

        let conn = Connection::open(DB_FILE).unwrap();
        let mut batch = Batch::new(&conn, include_str!("../../testdata/sqlite/v0.2.2.sql"));
        while let Some(mut stmt) = batch.next().unwrap() {
            stmt.execute([]).unwrap();
        }
        conn.execute("update mdbconfig set version = '0.2.0';", [])
            .unwrap();

        let legacy = Sqlite {
            conn,
            file_path: DB_FILE.into(),
        };
        legacy.migrate(Migration::Migrate).unwrap();

        let reopened = Sqlite::new(DB_FILE).unwrap();
        reopened.migrate(Migration::Check).unwrap();

        remove_if_present(DB_FILE);
    }
}