1#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8
9pub mod crypto;
11
12pub mod db;
14
15pub mod http;
17
18pub mod utils;
20
21#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
23#[cfg(feature = "vt")]
24pub mod vt;
25
26#[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
28#[cfg(feature = "yara")]
29pub mod yara;
30
31use crate::crypto::FileEncryption;
32use crate::db::MDBConfig;
33use malwaredb_api::ServerInfo;
34use std::collections::HashMap;
37use std::fmt::{Debug, Formatter};
38use std::io::{Cursor, Read};
39use std::net::{IpAddr, Ipv4Addr, SocketAddr};
40use std::path::PathBuf;
41use std::sync::{Arc, LazyLock};
42use std::time::{Duration, SystemTime};
43
44use anyhow::{anyhow, bail, ensure, Context, Result};
45use axum_server::tls_rustls::RustlsConfig;
46use chrono::Local;
47use chrono_humanize::{Accuracy, HumanTime, Tense};
48use flate2::read::GzDecoder;
49use mdns_sd::{ServiceDaemon, ServiceInfo};
50use sha2::{Digest, Sha256};
51use tokio::net::TcpListener;
52use tracing::{trace, warn};
53
/// Version of this crate as a string, taken from `CARGO_PKG_VERSION` at compile time.
pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");

/// [`MDB_VERSION`] parsed into a [`semver::Version`] on first access.
///
/// Parsing uses `unwrap()`, so the first access panics if the package version
/// string is not valid semver.
pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
    LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());

/// Time slept between two database cleanup passes in the background task
/// spawned by [`State::serve`] (24 hours).
pub(crate) const DB_CLEANUP_INTERVAL: Duration = Duration::from_hours(24);

/// Leading bytes used by [`State::retrieve_bytes`] to recognise Gzip-compressed contents.
pub const GZIP_MAGIC: [u8; 2] = [0x1fu8, 0x8bu8];

/// Leading bytes used by [`State::retrieve_bytes`] to recognise Zstd-compressed contents.
pub const ZSTD_MAGIC: [u8; 4] = [0x28u8, 0xb5u8, 0x2fu8, 0xfdu8];
70
/// Builder which collects the server settings and is turned into a [`State`]
/// by [`StateBuilder::into_state`].
pub struct StateBuilder {
    /// Port used for the listening socket; [`StateBuilder::new`] sets it to 8080.
    pub port: u16,

    /// Directory under which sample files are stored. [`StateBuilder::new`]
    /// sets it to `None`, in which case [`State::store_bytes`] writes nothing.
    pub directory: Option<PathBuf>,

    /// Maximum upload size; [`StateBuilder::new`] sets it to 104,857,600.
    /// NOTE(review): presumably a byte count (100 MiB) enforced by the HTTP layer — confirm.
    pub max_upload: usize,

    /// IP address used for the listening socket; [`StateBuilder::new`] sets it to IPv4 localhost.
    pub ip: IpAddr,

    /// Database handle created from the connection string passed to [`StateBuilder::new`].
    db_type: db::DatabaseType,

    /// Optional Virus Total client, set through [`StateBuilder::vt_client`].
    #[cfg(feature = "vt")]
    vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// TLS configuration loaded by [`StateBuilder::tls`]; `None` means plain HTTP.
    tls_config: Option<RustlsConfig>,

    /// Whether an mDNS service daemon is created in [`StateBuilder::into_state`].
    mdns: bool,
}
98
99impl StateBuilder {
100 pub async fn new(db_string: &str, pg_cert: Option<PathBuf>) -> Result<Self> {
108 let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
109
110 Ok(Self {
111 port: 8080,
112 directory: None,
113 max_upload: 104_857_600, ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
115 db_type,
116 #[cfg(feature = "vt")]
117 vt_client: None,
118 tls_config: None,
119 mdns: false,
120 })
121 }
122
123 #[must_use]
126 pub fn port(mut self, port: u16) -> Self {
127 self.port = port;
128 self
129 }
130
131 #[must_use]
135 pub fn directory(mut self, directory: PathBuf) -> Self {
136 self.directory = Some(directory);
137 self
138 }
139
140 #[must_use]
143 pub fn max_upload(mut self, max_upload: usize) -> Self {
144 self.max_upload = max_upload;
145 self
146 }
147
148 #[must_use]
151 pub fn ip(mut self, ip: IpAddr) -> Self {
152 self.ip = ip;
153 self
154 }
155
156 #[must_use]
158 #[cfg(feature = "vt")]
159 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
160 pub fn vt_client(mut self, vt_client: malwaredb_virustotal::VirusTotalClient) -> Self {
161 self.vt_client = Some(vt_client);
162 self
163 }
164
165 pub async fn tls(mut self, cert_file: PathBuf, key_file: PathBuf) -> Result<Self> {
172 ensure!(
173 cert_file.exists(),
174 "Certificate file {} does not exist!",
175 cert_file.display()
176 );
177
178 ensure!(
179 key_file.exists(),
180 "Key file {} does not exist!",
181 key_file.display()
182 );
183
184 let cert_ext_str = cert_file
185 .extension()
186 .context("failed to get certificate extension")?;
187 let key_ext_str = key_file
188 .extension()
189 .context("failed to get key extension")?;
190
191 if rustls::crypto::CryptoProvider::get_default().is_none() {
193 rustls::crypto::aws_lc_rs::default_provider()
194 .install_default()
195 .map_err(|_| anyhow!("failed to install AWS-LC crypto provider"))?;
196 }
197
198 let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem" {
199 RustlsConfig::from_pem_file(cert_file, key_file)
200 .await
201 .context("failed to load or parse certificate and key pem files")?
202 } else if cert_ext_str == "der" && key_ext_str == "der" {
203 let cert_contents =
204 std::fs::read(cert_file).context("failed to read certificate file")?;
205 let key_contents =
206 std::fs::read(key_file).context("failed to read private key file")?;
207 RustlsConfig::from_der(vec![cert_contents], key_contents)
208 .await
209 .context("failed to parse certificate and key der files")?
210 } else {
211 bail!(
212 "Unknown or unmatched certificate and key file extensions {} and {}",
213 cert_ext_str.display(),
214 key_ext_str.display()
215 );
216 };
217
218 self.tls_config = Some(config);
219 Ok(self)
220 }
221
222 #[must_use]
225 pub fn enable_mdns(mut self) -> Self {
226 self.mdns = true;
227 self
228 }
229
230 pub async fn into_state(self) -> Result<State> {
236 let db_config = self.db_type.get_config().await?;
237 let keys = self.db_type.get_encryption_keys().await?;
238
239 Ok(State {
240 port: self.port,
241 directory: self.directory,
242 max_upload: self.max_upload,
243 ip: self.ip,
244 db_type: Arc::new(self.db_type),
245 started: SystemTime::now(),
246 db_config,
247 keys,
248 #[cfg(feature = "vt")]
249 vt_client: self.vt_client,
250 tls_config: self.tls_config,
251 mdns: if self.mdns {
252 Some(ServiceDaemon::new()?)
253 } else {
254 None
255 },
256 })
257 }
258}
259
/// Running server state, produced by [`StateBuilder::into_state`].
pub struct State {
    /// Port used for the listening socket in [`State::serve`].
    pub port: u16,

    /// Directory under which sample files are stored; with `None`,
    /// [`State::store_bytes`] writes nothing and [`State::retrieve_bytes`] fails.
    pub directory: Option<PathBuf>,

    /// Maximum upload size.
    /// NOTE(review): presumably a byte count enforced by the HTTP layer — confirm.
    pub max_upload: usize,

    /// IP address used for the listening socket in [`State::serve`].
    pub ip: IpAddr,

    /// Shared database handle.
    pub db_type: Arc<db::DatabaseType>,

    /// Time at which this state was created; the basis for [`State::since`].
    pub started: SystemTime,

    /// Configuration fetched from the database when the state was created.
    pub db_config: MDBConfig,

    /// File encryption keys fetched from the database, looked up by key ID.
    pub(crate) keys: HashMap<u32, FileEncryption>,

    /// Optional Virus Total client.
    #[cfg(feature = "vt")]
    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// TLS configuration; when present [`State::serve`] listens with HTTPS.
    tls_config: Option<RustlsConfig>,

    /// mDNS daemon; when present [`State::serve`] registers the service with it.
    mdns: Option<ServiceDaemon>,
}
296
297impl State {
298 pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
305 if let Some(dest_path) = &self.directory {
306 let mut hasher = Sha256::new();
307 hasher.update(data);
308 let sha256 = hex::encode(hasher.finalize());
309
310 let hashed_path = format!(
314 "{}/{}/{}/{}",
315 &sha256[0..2],
316 &sha256[2..4],
317 &sha256[4..6],
318 sha256
319 );
320
321 let mut dest_path = dest_path.clone();
324 dest_path.push(hashed_path);
325
326 let mut just_the_dir = dest_path.clone();
328 just_the_dir.pop();
329 std::fs::create_dir_all(just_the_dir)?;
330
331 let data = if self.db_config.compression {
332 let buff = Cursor::new(data);
333 let mut compressed = Vec::with_capacity(data.len() / 2);
334 zstd::stream::copy_encode(buff, &mut compressed, 4)?;
335 compressed
336 } else {
337 data.to_vec()
338 };
339
340 let data = if let Some(key_id) = self.db_config.default_key {
341 if let Some(key) = self.keys.get(&key_id) {
342 let nonce = key.nonce();
343 self.db_type
344 .set_file_nonce(&sha256, nonce.as_deref())
345 .await?;
346 key.encrypt(&data, nonce)?
347 } else {
348 bail!("Key not available!")
349 }
350 } else {
351 data
352 };
353
354 std::fs::write(dest_path, data)?;
355
356 Ok(true)
357 } else {
358 Ok(false)
359 }
360 }
361
362 pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
370 if let Some(dest_path) = &self.directory {
371 let path = format!(
372 "{}/{}/{}/{}",
373 &sha256[0..2],
374 &sha256[2..4],
375 &sha256[4..6],
376 sha256
377 );
378 let contents = std::fs::read(dest_path.join(path))?;
383
384 let contents = if self.keys.is_empty() {
385 contents
387 } else {
388 let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
389 if let Some(key_id) = key_id {
390 if let Some(key) = self.keys.get(&key_id) {
391 key.decrypt(&contents, nonce)?
392 } else {
393 bail!("File was encrypted but we don't have tke key!")
394 }
395 } else {
396 contents
398 }
399 };
400
401 if contents.starts_with(&GZIP_MAGIC) {
402 let buff = Cursor::new(contents);
403 let mut decompressor = GzDecoder::new(buff);
404 let mut decompressed: Vec<u8> = vec![];
405 decompressor.read_to_end(&mut decompressed)?;
406 Ok(decompressed)
407 } else if contents.starts_with(&ZSTD_MAGIC) {
408 let buff = Cursor::new(contents);
409 let mut decompressed: Vec<u8> = vec![];
410 zstd::stream::copy_decode(buff, &mut decompressed)?;
411 Ok(decompressed)
412 } else {
413 Ok(contents)
414 }
415 } else {
416 bail!("files are not saved")
417 }
418 }
419
420 #[must_use]
426 pub fn since(&self) -> Duration {
427 let now = SystemTime::now();
428 now.duration_since(self.started).unwrap()
429 }
430
431 pub async fn get_info(&self) -> Result<ServerInfo> {
437 let db_info = self.db_type.db_info().await?;
438 let uptime = Local::now() - self.since();
439 let mem_size = app_memory_usage_fetcher::get_memory_usage_string().unwrap_or_default();
440
441 Ok(ServerInfo {
442 os_name: std::env::consts::OS.into(),
443 memory_used: mem_size,
444 num_samples: db_info.num_files,
445 num_users: db_info.num_users,
446 uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
447 mdb_version: MDB_VERSION_SEMVER.clone(),
448 db_version: db_info.version,
449 db_size: db_info.size,
450 instance_name: self.db_config.name.clone(),
451 vt_support: cfg!(feature = "vt"),
452 yara_enabled: cfg!(feature = "yara"),
453 })
454 }
455
456 pub async fn serve(
464 self,
465 #[cfg(target_family = "windows")] rx: Option<tokio::sync::mpsc::Receiver<()>>,
466 ) -> Result<()> {
467 let socket = SocketAddr::new(self.ip, self.port);
468 let arc_self = Arc::new(self);
469 let db_info = arc_self.db_type.clone();
470
471 #[cfg(feature = "yara")]
472 {
473 if arc_self.directory.is_some() {
474 start_yara_process(arc_self.clone());
475 }
476 }
477
478 tokio::spawn(async move {
479 loop {
480 match db_info.cleanup().await {
481 Ok(removed) => {
482 trace!("Pagination cleanup succeeded, {removed} searches removed");
483 }
484 Err(e) => warn!("Pagination cleanup failed: {e}"),
485 }
486
487 tokio::time::sleep(DB_CLEANUP_INTERVAL).await;
488 }
489 });
490
491 if let Some(mdns) = &arc_self.mdns {
492 let host_name = format!("{}.local.", arc_self.ip);
493 let ssl = arc_self.tls_config.is_some();
494 let properties = [("ssl", ssl.to_string()), ("version", MDB_VERSION.into())];
495 let service = {
496 let mut service = ServiceInfo::new(
497 malwaredb_api::MDNS_NAME,
498 &arc_self.db_config.name,
499 &host_name,
500 &arc_self.ip,
501 arc_self.port,
502 &properties[..],
503 )?;
504 if arc_self.ip.is_unspecified() {
505 service = service.enable_addr_auto();
506 }
507 service
508 };
509 trace!("Registering MDNS service...");
510 mdns.register(service)?;
511 }
512
513 if let Some(tls_config) = arc_self.tls_config.clone() {
514 println!("Listening on https://{socket:?}");
515 let handle = axum_server::Handle::<SocketAddr>::new();
516 let server_future = axum_server::bind_rustls(socket, tls_config)
517 .serve(http::app(arc_self).into_make_service());
518 tokio::select! {
519 () = shutdown_signal(#[cfg(target_family = "windows")]rx) =>
520 handle.graceful_shutdown(Some(Duration::from_secs(30))),
521 res = server_future => res?,
522 }
523 warn!("Terminate signal received");
524 } else {
525 println!("Listening on http://{socket:?}");
526 let listener = TcpListener::bind(socket)
527 .await
528 .context(format!("failed to bind socket {socket}"))?;
529 axum::serve(listener, http::app(arc_self).into_make_service())
530 .with_graceful_shutdown(shutdown_signal(
531 #[cfg(target_family = "windows")]
532 rx,
533 ))
534 .await?;
535 warn!("Terminate signal received");
536 }
537 Ok(())
538 }
539}
540
541impl Debug for State {
542 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
543 let tls_mode = if self.tls_config.is_some() {
544 ", TLS mode"
545 } else {
546 ""
547 };
548 write!(
549 f,
550 "MDB state, port {}, database {:?}{tls_mode}",
551 self.port, self.db_type
552 )
553 }
554}
555
/// Spawns the background worker which polls the database for unfinished Yara
/// tasks every five seconds and scans the files each task's user may access.
///
/// For every task, the next batch of file hashes is fetched. An empty batch
/// marks the task as finished. Otherwise the batch is scanned in its own
/// spawned task: each file is retrieved, matched against the task's rules,
/// matches are recorded, and finally the task's next file ID is advanced.
/// All failures are logged and skipped; nothing is returned to the caller.
#[cfg(feature = "yara")]
fn start_yara_process(state: Arc<State>) {
    const POLL_INTERVAL: Duration = Duration::from_secs(5);

    tokio::spawn(async move {
        loop {
            match state.db_type.get_unfinished_yara_tasks().await {
                Ok(tasks) => {
                    for task in tasks {
                        let (hashes, last_file_id) = match state
                            .db_type
                            .user_allowed_files_by_sha256(task.user_id, task.last_file_id)
                            .await
                        {
                            Ok(batch) => batch,
                            Err(e) => {
                                warn!("Failed to get user allowed files: {e}");
                                continue;
                            }
                        };

                        if hashes.is_empty() {
                            if let Err(e) =
                                state.db_type.mark_yara_task_as_finished(task.id).await
                            {
                                warn!("Failed to mark yara task as finished: {e}");
                            }
                            continue;
                        }

                        let worker_state = Arc::clone(&state);
                        tokio::spawn(async move {
                            for hash in hashes {
                                let bytes = match worker_state.retrieve_bytes(&hash).await {
                                    Ok(bytes) => bytes,
                                    Err(e) => {
                                        warn!("Failed to retrieve bytes for hash {hash}: {e}");
                                        continue;
                                    }
                                };
                                let matches =
                                    task.process_yara_rules(&bytes).unwrap_or_else(|e| {
                                        warn!("Failed to process Yara rules: {e}");
                                        Vec::new()
                                    });

                                for match_ in matches {
                                    if let Err(e) = worker_state
                                        .db_type
                                        .add_yara_match(task.id, &match_, &hash)
                                        .await
                                    {
                                        warn!("Failed to add Yara match: {e}");
                                    }
                                }
                            }
                            if let Err(e) = worker_state
                                .db_type
                                .yara_add_next_file_id(task.id, last_file_id)
                                .await
                            {
                                warn!("Failed to update yara task next file id: {e}");
                            }
                        });
                    }
                }
                // Fall through to the sleep below: retrying immediately would
                // hammer a failing database in a tight loop.
                Err(e) => warn!("Failed to get Yara tasks: {e}"),
            }

            tokio::time::sleep(POLL_INTERVAL).await;
        }
    });
}
636
637async fn shutdown_signal(
640 #[cfg(target_family = "windows")] mut rx: Option<tokio::sync::mpsc::Receiver<()>>,
641) {
642 let ctrl_c = async {
643 tokio::signal::ctrl_c()
644 .await
645 .expect("failed to install Ctrl+C handler");
646 };
647
648 #[cfg(unix)]
649 let terminate = async {
650 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
651 .expect("failed to install signal handler")
652 .recv()
653 .await;
654 };
655
656 #[cfg(not(unix))]
657 let terminate = std::future::pending::<()>();
658
659 #[cfg(target_family = "windows")]
660 if let Some(rx_inner) = &mut rx {
661 let terminate_rx = rx_inner.recv();
662
663 tokio::select! {
664 () = ctrl_c => {},
665 () = terminate => {},
666 Some(()) = terminate_rx => {},
667 }
668 } else {
669 tokio::select! {
670 () = ctrl_c => {},
671 () = terminate => {},
672 }
673 }
674
675 #[cfg(not(target_family = "windows"))]
676 tokio::select! {
677 () = ctrl_c => {},
678 () = terminate => {},
679 }
680}