#![doc = include_str!("../README.md")]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(missing_docs)]
#![deny(clippy::all)]
#![deny(clippy::pedantic)]
pub mod crypto;
pub mod db;
pub mod http;
pub mod utils;
#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
#[cfg(feature = "vt")]
pub mod vt;
#[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
#[cfg(feature = "yara")]
pub mod yara;
use crate::crypto::FileEncryption;
use crate::db::MDBConfig;
use malwaredb_api::ServerInfo;
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::io::{Cursor, Read};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::PathBuf;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, SystemTime};
use anyhow::{anyhow, bail, ensure, Context, Result};
use axum_server::tls_rustls::RustlsConfig;
use chrono::Local;
use chrono_humanize::{Accuracy, HumanTime, Tense};
use flate2::read::GzDecoder;
use mdns_sd::{ServiceDaemon, ServiceInfo};
use sha2::{Digest, Sha256};
use tokio::net::TcpListener;
use tracing::{trace, warn};
/// MalwareDB version string, taken from this crate's `Cargo.toml` at compile time.
pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
/// [`MDB_VERSION`] parsed as a semantic version, computed on first use.
pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> = LazyLock::new(|| {
    // Cargo rejects package versions that are not valid semver, so this cannot fail.
    semver::Version::parse(MDB_VERSION).expect("CARGO_PKG_VERSION is always valid semver")
});
/// Delay between runs of the background database pagination cleanup (one day).
pub(crate) const DB_CLEANUP_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
/// Magic bytes at the start of a gzip stream.
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];
/// Magic bytes at the start of a zstd frame.
pub const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];
/// Builder for [`State`]: collects server settings, then loads the database-backed
/// configuration and keys in [`StateBuilder::into_state`].
pub struct StateBuilder {
    /// TCP port to listen on.
    pub port: u16,
    /// Directory in which sample files are stored; when `None`, sample bytes are not saved.
    pub directory: Option<PathBuf>,
    /// Maximum upload size (presumably in bytes — TODO confirm).
    pub max_upload: usize,
    /// IP address to bind to.
    pub ip: IpAddr,
    /// Database backend.
    db_type: db::DatabaseType,
    /// Optional VirusTotal client.
    #[cfg(feature = "vt")]
    vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
    /// TLS configuration; when `None` the server uses plain HTTP.
    tls_config: Option<RustlsConfig>,
    /// Whether an mDNS daemon should be started to advertise the server.
    mdns: bool,
}
impl StateBuilder {
    /// Creates a builder for the database identified by `db_string`.
    ///
    /// Defaults: port 8080 on `127.0.0.1`, a maximum upload of 104,857,600 (100 MiB), no
    /// sample directory, no VirusTotal client, plain HTTP, and mDNS disabled.
    ///
    /// `pg_cert` is passed through to `db::DatabaseType::from_string`; presumably a
    /// certificate for PostgreSQL TLS connections — TODO confirm.
    ///
    /// # Errors
    ///
    /// Returns any error from `db::DatabaseType::from_string` for `db_string`.
    pub async fn new(db_string: &str, pg_cert: Option<PathBuf>) -> Result<Self> {
        let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
        Ok(Self {
            port: 8080,
            directory: None,
            max_upload: 104_857_600,
            ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
            db_type,
            #[cfg(feature = "vt")]
            vt_client: None,
            tls_config: None,
            mdns: false,
        })
    }

    /// Sets the TCP port to listen on.
    #[must_use]
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets the directory where sample files are stored. Without one, samples are not
    /// saved to disk.
    #[must_use]
    pub fn directory(mut self, directory: PathBuf) -> Self {
        self.directory = Some(directory);
        self
    }

    /// Sets the maximum upload size.
    #[must_use]
    pub fn max_upload(mut self, max_upload: usize) -> Self {
        self.max_upload = max_upload;
        self
    }

    /// Sets the IP address to bind to.
    #[must_use]
    pub fn ip(mut self, ip: IpAddr) -> Self {
        self.ip = ip;
        self
    }

    /// Sets the VirusTotal client.
    #[must_use]
    #[cfg(feature = "vt")]
    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
    pub fn vt_client(mut self, vt_client: malwaredb_virustotal::VirusTotalClient) -> Self {
        self.vt_client = Some(vt_client);
        self
    }

    /// Enables HTTPS using the given certificate and private key files.
    ///
    /// Supported pairs are PEM (certificate extension `pem` or `crt`, key extension `pem`)
    /// and DER (both `der`). Extensions are matched case-insensitively. If no process-wide
    /// rustls crypto provider is installed yet, AWS-LC is installed as the default.
    ///
    /// # Errors
    ///
    /// Returns an error if either file is missing or has no extension, the extensions are
    /// not a supported pair, the files cannot be read or parsed, or no crypto provider can
    /// be installed.
    pub async fn tls(mut self, cert_file: PathBuf, key_file: PathBuf) -> Result<Self> {
        ensure!(
            cert_file.exists(),
            "Certificate file {} does not exist!",
            cert_file.display()
        );
        ensure!(
            key_file.exists(),
            "Key file {} does not exist!",
            key_file.display()
        );

        let cert_ext = cert_file
            .extension()
            .context("failed to get certificate extension")?;
        let key_ext = key_file
            .extension()
            .context("failed to get key extension")?;

        if rustls::crypto::CryptoProvider::get_default().is_none() {
            // `install_default` only fails when a default provider is already set, which can
            // happen if another task installed one between the check above and this call.
            // A usable provider then exists, so only fail if there is still none.
            let install_failed = rustls::crypto::aws_lc_rs::default_provider()
                .install_default()
                .is_err();
            if install_failed && rustls::crypto::CryptoProvider::get_default().is_none() {
                bail!("failed to install AWS-LC crypto provider");
            }
        }

        // Case-insensitive so that e.g. `server.PEM` / `server.DER` are accepted.
        let cert_is_pem =
            cert_ext.eq_ignore_ascii_case("pem") || cert_ext.eq_ignore_ascii_case("crt");
        let config = if cert_is_pem && key_ext.eq_ignore_ascii_case("pem") {
            RustlsConfig::from_pem_file(cert_file, key_file)
                .await
                .context("failed to load or parse certificate and key pem files")?
        } else if cert_ext.eq_ignore_ascii_case("der") && key_ext.eq_ignore_ascii_case("der") {
            let cert_contents =
                std::fs::read(cert_file).context("failed to read certificate file")?;
            let key_contents =
                std::fs::read(key_file).context("failed to read private key file")?;
            RustlsConfig::from_der(vec![cert_contents], key_contents)
                .await
                .context("failed to parse certificate and key der files")?
        } else {
            bail!(
                "Unknown or unmatched certificate and key file extensions {} and {}",
                cert_ext.display(),
                key_ext.display()
            );
        };
        self.tls_config = Some(config);
        Ok(self)
    }

    /// Requests that the server be advertised over mDNS when it starts serving.
    #[must_use]
    pub fn enable_mdns(mut self) -> Self {
        self.mdns = true;
        self
    }

    /// Consumes the builder and produces the server [`State`]. This loads the instance
    /// configuration and the file encryption keys from the database.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration or keys cannot be loaded from the database, or
    /// if mDNS was enabled and the mDNS daemon cannot be created.
    pub async fn into_state(self) -> Result<State> {
        let db_config = self.db_type.get_config().await?;
        let keys = self.db_type.get_encryption_keys().await?;
        Ok(State {
            port: self.port,
            directory: self.directory,
            max_upload: self.max_upload,
            ip: self.ip,
            db_type: Arc::new(self.db_type),
            started: SystemTime::now(),
            db_config,
            keys,
            #[cfg(feature = "vt")]
            vt_client: self.vt_client,
            tls_config: self.tls_config,
            mdns: if self.mdns {
                Some(ServiceDaemon::new()?)
            } else {
                None
            },
        })
    }
}
/// Runtime state of the server, handed to the HTTP application and background tasks.
pub struct State {
    /// TCP port the server listens on.
    pub port: u16,
    /// Directory in which sample files are stored; when `None`, sample bytes are not saved.
    pub directory: Option<PathBuf>,
    /// Maximum upload size (presumably in bytes — TODO confirm).
    pub max_upload: usize,
    /// IP address the server binds to.
    pub ip: IpAddr,
    /// Database backend, shared with background tasks via `Arc`.
    pub db_type: Arc<db::DatabaseType>,
    /// Wall-clock time at which this state was created; used to compute uptime.
    pub started: SystemTime,
    /// Instance configuration loaded from the database.
    pub db_config: MDBConfig,
    /// File encryption keys loaded from the database, indexed by key id.
    pub(crate) keys: HashMap<u32, FileEncryption>,
    /// Optional VirusTotal client.
    #[cfg(feature = "vt")]
    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
    /// TLS configuration; when `None` the server uses plain HTTP.
    tls_config: Option<RustlsConfig>,
    /// mDNS daemon, present only when mDNS was enabled on the builder.
    mdns: Option<ServiceDaemon>,
}
impl State {
pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
if let Some(dest_path) = &self.directory {
let mut hasher = Sha256::new();
hasher.update(data);
let sha256 = hex::encode(hasher.finalize());
let hashed_path = format!(
"{}/{}/{}/{}",
&sha256[0..2],
&sha256[2..4],
&sha256[4..6],
sha256
);
let mut dest_path = dest_path.clone();
dest_path.push(hashed_path);
let mut just_the_dir = dest_path.clone();
just_the_dir.pop();
std::fs::create_dir_all(just_the_dir)?;
let data = if self.db_config.compression {
let buff = Cursor::new(data);
let mut compressed = Vec::with_capacity(data.len() / 2);
zstd::stream::copy_encode(buff, &mut compressed, 4)?;
compressed
} else {
data.to_vec()
};
let data = if let Some(key_id) = self.db_config.default_key {
if let Some(key) = self.keys.get(&key_id) {
let nonce = key.nonce();
self.db_type
.set_file_nonce(&sha256, nonce.as_deref())
.await?;
key.encrypt(&data, nonce)?
} else {
bail!("Key not available!")
}
} else {
data
};
std::fs::write(dest_path, data)?;
Ok(true)
} else {
Ok(false)
}
}
pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
if let Some(dest_path) = &self.directory {
let path = format!(
"{}/{}/{}/{}",
&sha256[0..2],
&sha256[2..4],
&sha256[4..6],
sha256
);
let contents = std::fs::read(dest_path.join(path))?;
let contents = if self.keys.is_empty() {
contents
} else {
let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
if let Some(key_id) = key_id {
if let Some(key) = self.keys.get(&key_id) {
key.decrypt(&contents, nonce)?
} else {
bail!("File was encrypted but we don't have tke key!")
}
} else {
contents
}
};
if contents.starts_with(&GZIP_MAGIC) {
let buff = Cursor::new(contents);
let mut decompressor = GzDecoder::new(buff);
let mut decompressed: Vec<u8> = vec![];
decompressor.read_to_end(&mut decompressed)?;
Ok(decompressed)
} else if contents.starts_with(&ZSTD_MAGIC) {
let buff = Cursor::new(contents);
let mut decompressed: Vec<u8> = vec![];
zstd::stream::copy_decode(buff, &mut decompressed)?;
Ok(decompressed)
} else {
Ok(contents)
}
} else {
bail!("files are not saved")
}
}
#[must_use]
pub fn since(&self) -> Duration {
let now = SystemTime::now();
now.duration_since(self.started).unwrap()
}
pub async fn get_info(&self) -> Result<ServerInfo> {
let db_info = self.db_type.db_info().await?;
let uptime = Local::now() - self.since();
let mem_size = app_memory_usage_fetcher::get_memory_usage_string().unwrap_or_default();
Ok(ServerInfo {
os_name: std::env::consts::OS.into(),
memory_used: mem_size,
num_samples: db_info.num_files,
num_users: db_info.num_users,
uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
mdb_version: MDB_VERSION_SEMVER.clone(),
db_version: db_info.version,
db_size: db_info.size,
instance_name: self.db_config.name.clone(),
vt_support: cfg!(feature = "vt"),
yara_enabled: cfg!(feature = "yara"),
})
}
pub async fn serve(
self,
#[cfg(target_family = "windows")] rx: Option<tokio::sync::mpsc::Receiver<()>>,
) -> Result<()> {
let socket = SocketAddr::new(self.ip, self.port);
let arc_self = Arc::new(self);
let db_info = arc_self.db_type.clone();
#[cfg(feature = "yara")]
{
if arc_self.directory.is_some() {
start_yara_process(arc_self.clone());
}
}
tokio::spawn(async move {
loop {
match db_info.cleanup().await {
Ok(removed) => {
trace!("Pagination cleanup succeeded, {removed} searches removed");
}
Err(e) => warn!("Pagination cleanup failed: {e}"),
}
tokio::time::sleep(DB_CLEANUP_INTERVAL).await;
}
});
if let Some(mdns) = &arc_self.mdns {
let host_name = format!("{}.local.", arc_self.ip);
let ssl = arc_self.tls_config.is_some();
let properties = [("ssl", ssl.to_string()), ("version", MDB_VERSION.into())];
let service = {
let mut service = ServiceInfo::new(
malwaredb_api::MDNS_NAME,
&arc_self.db_config.name,
&host_name,
&arc_self.ip,
arc_self.port,
&properties[..],
)?;
if arc_self.ip.is_unspecified() {
service = service.enable_addr_auto();
}
service
};
trace!("Registering MDNS service...");
mdns.register(service)?;
}
if let Some(tls_config) = arc_self.tls_config.clone() {
println!("Listening on https://{socket:?}");
let handle = axum_server::Handle::<SocketAddr>::new();
let server_future = axum_server::bind_rustls(socket, tls_config)
.serve(http::app(arc_self).into_make_service());
tokio::select! {
() = shutdown_signal(#[cfg(target_family = "windows")]rx) =>
handle.graceful_shutdown(Some(Duration::from_secs(30))),
res = server_future => res?,
}
warn!("Terminate signal received");
} else {
println!("Listening on http://{socket:?}");
let listener = TcpListener::bind(socket)
.await
.context(format!("failed to bind socket {socket}"))?;
axum::serve(listener, http::app(arc_self).into_make_service())
.with_graceful_shutdown(shutdown_signal(
#[cfg(target_family = "windows")]
rx,
))
.await?;
warn!("Terminate signal received");
}
Ok(())
}
}
impl Debug for State {
    /// Summarises the state as its port and database, followed by a `TLS mode` marker
    /// when TLS is configured.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "MDB state, port {}, database {:?}", self.port, self.db_type)?;
        match self.tls_config {
            Some(_) => f.write_str(", TLS mode"),
            None => Ok(()),
        }
    }
}
/// Starts the background Yara scanner, which runs one scanning round every 5 seconds.
///
/// Errors are logged and never stop the loop. The 5-second sleep also follows a failed
/// round, so an unavailable database is not polled in a tight loop.
#[cfg(feature = "yara")]
fn start_yara_process(state: Arc<State>) {
    tokio::spawn(async move {
        loop {
            run_yara_round(&state).await;
            tokio::time::sleep(Duration::from_secs(5)).await;
        }
    });
}

/// Runs one scanning round over all unfinished Yara tasks.
///
/// For each task, the database returns the next batch of files its user may access,
/// starting from the task's `last_file_id`. An empty batch marks the task finished.
/// Otherwise the batch is scanned on its own Tokio task. Each rule match is recorded, and
/// the batch's last file id is then saved so the next round continues from there.
///
/// The round waits for all of its batches to finish. Without that wait, a batch still
/// running when the next round starts would be fetched and scanned again, because its
/// position is only saved once the batch completes.
#[cfg(feature = "yara")]
async fn run_yara_round(state: &Arc<State>) {
    let tasks = match state.db_type.get_unfinished_yara_tasks().await {
        Ok(tasks) => tasks,
        Err(e) => {
            warn!("Failed to get Yara tasks: {e}");
            return;
        }
    };

    let mut workers = Vec::new();
    for task in tasks {
        let (hashes, last_file_id) = match state
            .db_type
            .user_allowed_files_by_sha256(task.user_id, task.last_file_id)
            .await
        {
            Ok(batch) => batch,
            Err(e) => {
                warn!("Failed to get user allowed files: {e}");
                continue;
            }
        };

        if hashes.is_empty() {
            // Nothing left to scan for this task.
            if let Err(e) = state.db_type.mark_yara_task_as_finished(task.id).await {
                warn!("Failed to mark yara task as finished: {e}");
            }
            continue;
        }

        let worker_state = Arc::clone(state);
        workers.push(tokio::spawn(async move {
            for hash in hashes {
                let bytes = match worker_state.retrieve_bytes(&hash).await {
                    Ok(bytes) => bytes,
                    Err(e) => {
                        warn!("Failed to retrieve bytes for hash {hash}: {e}");
                        continue;
                    }
                };
                let matches = task.process_yara_rules(&bytes).unwrap_or_else(|e| {
                    warn!("Failed to process Yara rules: {e}");
                    Vec::new()
                });
                for match_ in matches {
                    if let Err(e) = worker_state
                        .db_type
                        .add_yara_match(task.id, &match_, &hash)
                        .await
                    {
                        warn!("Failed to add Yara match: {e}");
                    }
                }
            }
            if let Err(e) = worker_state
                .db_type
                .yara_add_next_file_id(task.id, last_file_id)
                .await
            {
                warn!("Failed to update yara task next file id: {e}");
            }
        }));
    }

    // Wait for this round's batches so none of them is fetched again while still running.
    for worker in workers {
        if let Err(e) = worker.await {
            warn!("Yara scanning task failed: {e}");
        }
    }
}
/// Resolves once the process has been asked to stop.
///
/// Triggers are Ctrl+C on every platform, SIGTERM on Unix, and, on Windows, a message
/// received on `rx` when a receiver is supplied. An absent `rx`, or one whose sender has
/// been dropped, never triggers on its own.
///
/// # Panics
///
/// Panics if the Ctrl+C or SIGTERM handler cannot be installed.
async fn shutdown_signal(
    #[cfg(target_family = "windows")] mut rx: Option<tokio::sync::mpsc::Receiver<()>>,
) {
    let ctrl_c = async {
        tokio::signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        let mut sigterm =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("failed to install signal handler");
        sigterm.recv().await;
    };
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    // Windows only: an optional caller-supplied channel can also request a stop.
    #[cfg(target_family = "windows")]
    let external_stop = async move {
        let received = match rx.as_mut() {
            Some(receiver) => receiver.recv().await.is_some(),
            None => false,
        };
        if !received {
            // No channel, or its sender was dropped: leave the decision to the other signals.
            std::future::pending::<()>().await;
        }
    };
    #[cfg(not(target_family = "windows"))]
    let external_stop = std::future::pending::<()>();

    tokio::select! {
        () = ctrl_c => {},
        () = terminate => {},
        () = external_stop => {},
    }
}