1#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8
9pub mod crypto;
11
12pub mod db;
14
15pub mod http;
17
18pub mod utils;
20
21#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
23#[cfg(feature = "vt")]
24pub mod vt;
25
26#[cfg(feature = "yara")]
28pub mod yara;
29
30use crate::crypto::FileEncryption;
31use crate::db::MDBConfig;
32use malwaredb_api::ServerInfo;
33use std::collections::HashMap;
36use std::fmt::{Debug, Formatter};
37use std::io::{Cursor, Read};
38use std::net::{IpAddr, Ipv4Addr, SocketAddr};
39use std::path::PathBuf;
40use std::sync::{Arc, LazyLock};
41use std::time::{Duration, SystemTime};
42
43use anyhow::{anyhow, bail, ensure, Context, Result};
44use axum_server::tls_rustls::RustlsConfig;
45use chrono::Local;
46use chrono_humanize::{Accuracy, HumanTime, Tense};
47use flate2::read::GzDecoder;
48use mdns_sd::{ServiceDaemon, ServiceInfo};
49use sha2::{Digest, Sha256};
50use tokio::net::TcpListener;
51use tracing::{trace, warn};
52
53pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
55
56pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
58 LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
59
/// Delay between runs of the background database cleanup task (24 hours).
pub(crate) const DB_CLEANUP_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);

/// Leading bytes of a gzip stream, used to detect gzip-compressed stored files.
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];

/// Leading bytes of a zstd frame, used to detect zstd-compressed stored files.
pub const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];
69
/// Builder for [`State`]: collects server options, then [`StateBuilder::into_state`]
/// loads the configuration and encryption keys from the database.
pub struct StateBuilder {
    /// TCP port to listen on (default 8080).
    pub port: u16,

    /// Directory where sample files are stored (default `None`, meaning samples are not
    /// written to disk).
    pub directory: Option<PathBuf>,

    /// Maximum upload size (default 104,857,600 — 100 MiB if the unit is bytes; TODO confirm
    /// against the HTTP layer).
    pub max_upload: usize,

    /// IP address to bind to (default `127.0.0.1`).
    pub ip: IpAddr,

    /// Database connection created by [`StateBuilder::new`].
    db_type: db::DatabaseType,

    /// Optional VirusTotal client, set with `vt_client()` (only with the `vt` feature).
    #[cfg(feature = "vt")]
    vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// TLS configuration, set with [`StateBuilder::tls`]; `None` means plain HTTP.
    tls_config: Option<RustlsConfig>,

    /// Whether to advertise the server via mDNS, set with [`StateBuilder::enable_mdns`].
    mdns: bool,
}
97
98impl StateBuilder {
99 pub async fn new(db_string: &str, pg_cert: Option<PathBuf>) -> Result<Self> {
107 let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
108
109 Ok(Self {
110 port: 8080,
111 directory: None,
112 max_upload: 104_857_600, ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
114 db_type,
115 #[cfg(feature = "vt")]
116 vt_client: None,
117 tls_config: None,
118 mdns: false,
119 })
120 }
121
122 #[must_use]
125 pub fn port(mut self, port: u16) -> Self {
126 self.port = port;
127 self
128 }
129
130 #[must_use]
134 pub fn directory(mut self, directory: PathBuf) -> Self {
135 self.directory = Some(directory);
136 self
137 }
138
139 #[must_use]
142 pub fn max_upload(mut self, max_upload: usize) -> Self {
143 self.max_upload = max_upload;
144 self
145 }
146
147 #[must_use]
150 pub fn ip(mut self, ip: IpAddr) -> Self {
151 self.ip = ip;
152 self
153 }
154
155 #[must_use]
157 #[cfg(feature = "vt")]
158 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
159 pub fn vt_client(mut self, vt_client: malwaredb_virustotal::VirusTotalClient) -> Self {
160 self.vt_client = Some(vt_client);
161 self
162 }
163
164 pub async fn tls(mut self, cert_file: PathBuf, key_file: PathBuf) -> Result<Self> {
171 ensure!(
172 cert_file.exists(),
173 "Certificate file {} does not exist!",
174 cert_file.display()
175 );
176
177 ensure!(
178 key_file.exists(),
179 "Key file {} does not exist!",
180 key_file.display()
181 );
182
183 let cert_ext_str = cert_file
184 .extension()
185 .context("failed to get certificate extension")?;
186 let key_ext_str = key_file
187 .extension()
188 .context("failed to get key extension")?;
189
190 if rustls::crypto::CryptoProvider::get_default().is_none() {
192 rustls::crypto::aws_lc_rs::default_provider()
193 .install_default()
194 .map_err(|_| anyhow!("failed to install AWS-LC crypto provider"))?;
195 }
196
197 let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem" {
198 RustlsConfig::from_pem_file(cert_file, key_file)
199 .await
200 .context("failed to load or parse certificate and key pem files")?
201 } else if cert_ext_str == "der" && key_ext_str == "der" {
202 let cert_contents =
203 std::fs::read(cert_file).context("failed to read certificate file")?;
204 let key_contents =
205 std::fs::read(key_file).context("failed to read private key file")?;
206 RustlsConfig::from_der(vec![cert_contents], key_contents)
207 .await
208 .context("failed to parse certificate and key der files")?
209 } else {
210 bail!(
211 "Unknown or unmatched certificate and key file extensions {} and {}",
212 cert_ext_str.display(),
213 key_ext_str.display()
214 );
215 };
216
217 self.tls_config = Some(config);
218 Ok(self)
219 }
220
221 #[must_use]
224 pub fn enable_mdns(mut self) -> Self {
225 self.mdns = true;
226 self
227 }
228
229 pub async fn into_state(self) -> Result<State> {
235 let db_config = self.db_type.get_config().await?;
236 let keys = self.db_type.get_encryption_keys().await?;
237
238 Ok(State {
239 port: self.port,
240 directory: self.directory,
241 max_upload: self.max_upload,
242 ip: self.ip,
243 db_type: Arc::new(self.db_type),
244 started: SystemTime::now(),
245 db_config,
246 keys,
247 #[cfg(feature = "vt")]
248 vt_client: self.vt_client,
249 tls_config: self.tls_config,
250 mdns: if self.mdns {
251 Some(ServiceDaemon::new()?)
252 } else {
253 None
254 },
255 })
256 }
257}
258
/// Runtime state shared by the HTTP handlers and background tasks; created with
/// [`StateBuilder::into_state`].
pub struct State {
    /// TCP port the server listens on.
    pub port: u16,

    /// Root directory for stored samples; `None` means sample bytes are not kept on disk.
    pub directory: Option<PathBuf>,

    /// Maximum upload size (presumably bytes — TODO confirm against the HTTP layer).
    pub max_upload: usize,

    /// IP address the server binds to.
    pub ip: IpAddr,

    /// Database handle; wrapped in `Arc` so background tasks can share it.
    pub db_type: Arc<db::DatabaseType>,

    /// Wall-clock time at which this state was created, used for uptime reporting.
    pub started: SystemTime,

    /// Server configuration read from the database when the state was built.
    pub db_config: MDBConfig,

    /// File encryption keys loaded from the database, keyed by key id.
    pub(crate) keys: HashMap<u32, FileEncryption>,

    /// Optional VirusTotal client (only with the `vt` feature).
    #[cfg(feature = "vt")]
    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// TLS configuration; when set, the server listens with HTTPS.
    tls_config: Option<RustlsConfig>,

    /// mDNS daemon, present only when mDNS advertisement was enabled on the builder.
    mdns: Option<ServiceDaemon>,
}
295
296impl State {
297 pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
304 if let Some(dest_path) = &self.directory {
305 let mut hasher = Sha256::new();
306 hasher.update(data);
307 let sha256 = hex::encode(hasher.finalize());
308
309 let hashed_path = format!(
313 "{}/{}/{}/{}",
314 &sha256[0..2],
315 &sha256[2..4],
316 &sha256[4..6],
317 sha256
318 );
319
320 let mut dest_path = dest_path.clone();
323 dest_path.push(hashed_path);
324
325 let mut just_the_dir = dest_path.clone();
327 just_the_dir.pop();
328 std::fs::create_dir_all(just_the_dir)?;
329
330 let data = if self.db_config.compression {
331 let buff = Cursor::new(data);
332 let mut compressed = Vec::with_capacity(data.len() / 2);
333 zstd::stream::copy_encode(buff, &mut compressed, 4)?;
334 compressed
335 } else {
336 data.to_vec()
337 };
338
339 let data = if let Some(key_id) = self.db_config.default_key {
340 if let Some(key) = self.keys.get(&key_id) {
341 let nonce = key.nonce();
342 self.db_type
343 .set_file_nonce(&sha256, nonce.as_deref())
344 .await?;
345 key.encrypt(&data, nonce)?
346 } else {
347 bail!("Key not available!")
348 }
349 } else {
350 data
351 };
352
353 std::fs::write(dest_path, data)?;
354
355 Ok(true)
356 } else {
357 Ok(false)
358 }
359 }
360
361 pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
369 if let Some(dest_path) = &self.directory {
370 let path = format!(
371 "{}/{}/{}/{}",
372 &sha256[0..2],
373 &sha256[2..4],
374 &sha256[4..6],
375 sha256
376 );
377 let contents = std::fs::read(dest_path.join(path))?;
382
383 let contents = if self.keys.is_empty() {
384 contents
386 } else {
387 let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
388 if let Some(key_id) = key_id {
389 if let Some(key) = self.keys.get(&key_id) {
390 key.decrypt(&contents, nonce)?
391 } else {
392 bail!("File was encrypted but we don't have tke key!")
393 }
394 } else {
395 contents
397 }
398 };
399
400 if contents.starts_with(&GZIP_MAGIC) {
401 let buff = Cursor::new(contents);
402 let mut decompressor = GzDecoder::new(buff);
403 let mut decompressed: Vec<u8> = vec![];
404 decompressor.read_to_end(&mut decompressed)?;
405 Ok(decompressed)
406 } else if contents.starts_with(&ZSTD_MAGIC) {
407 let buff = Cursor::new(contents);
408 let mut decompressed: Vec<u8> = vec![];
409 zstd::stream::copy_decode(buff, &mut decompressed)?;
410 Ok(decompressed)
411 } else {
412 Ok(contents)
413 }
414 } else {
415 bail!("files are not saved")
416 }
417 }
418
419 #[must_use]
425 pub fn since(&self) -> Duration {
426 let now = SystemTime::now();
427 now.duration_since(self.started).unwrap()
428 }
429
430 pub async fn get_info(&self) -> Result<ServerInfo> {
436 let db_info = self.db_type.db_info().await?;
437 let uptime = Local::now() - self.since();
438 let mem_size = app_memory_usage_fetcher::get_memory_usage_string().unwrap_or_default();
439
440 Ok(ServerInfo {
441 os_name: std::env::consts::OS.into(),
442 memory_used: mem_size,
443 num_samples: db_info.num_files,
444 num_users: db_info.num_users,
445 uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
446 mdb_version: MDB_VERSION_SEMVER.clone(),
447 db_version: db_info.version,
448 db_size: db_info.size,
449 instance_name: self.db_config.name.clone(),
450 })
451 }
452
453 pub async fn serve(
461 self,
462 #[cfg(target_family = "windows")] rx: Option<tokio::sync::mpsc::Receiver<()>>,
463 ) -> Result<()> {
464 let socket = SocketAddr::new(self.ip, self.port);
465 let arc_self = Arc::new(self);
466 let db_info = arc_self.db_type.clone();
467
468 #[cfg(feature = "yara")]
469 {
470 if arc_self.directory.is_some() {
471 start_yara_process(arc_self.clone());
472 }
473 }
474
475 tokio::spawn(async move {
476 loop {
477 match db_info.cleanup().await {
478 Ok(removed) => {
479 trace!("Pagination cleanup succeeded, {removed} searches removed");
480 }
481 Err(e) => warn!("Pagination cleanup failed: {e}"),
482 }
483
484 tokio::time::sleep(DB_CLEANUP_INTERVAL).await;
485 }
486 });
487
488 if arc_self.mdns.is_some() && !arc_self.ip.is_loopback() {
489 if let Err(e) = arc_self.mdns_register().await {
490 warn!("Failed to register MDNS service: {e}");
491 }
492 }
493
494 if let Some(tls_config) = arc_self.tls_config.clone() {
495 println!("Listening on https://{socket:?}");
496 let handle = axum_server::Handle::<SocketAddr>::new();
497 let server_future = axum_server::bind_rustls(socket, tls_config)
498 .serve(http::app(arc_self).into_make_service());
499 tokio::select! {
500 () = shutdown_signal(#[cfg(target_family = "windows")]rx) =>
501 handle.graceful_shutdown(Some(Duration::from_secs(30))),
502 res = server_future => res?,
503 }
504 warn!("Terminate signal received");
505 } else {
506 println!("Listening on http://{socket:?}");
507 let listener = TcpListener::bind(socket)
508 .await
509 .context(format!("failed to bind socket {socket}"))?;
510 axum::serve(listener, http::app(arc_self).into_make_service())
511 .with_graceful_shutdown(shutdown_signal(
512 #[cfg(target_family = "windows")]
513 rx,
514 ))
515 .await?;
516 warn!("Terminate signal received");
517 }
518 Ok(())
519 }
520
521 async fn mdns_register(&self) -> Result<()> {
524 if let Some(mdns) = &self.mdns {
525 let db_config = self.db_type.get_config().await?;
526 let host_name = format!("{}.local.", self.ip);
527 let ssl = self.tls_config.is_some();
528 let properties = [("ssl", ssl.to_string())];
529 let service = ServiceInfo::new(
530 malwaredb_api::MDNS_NAME,
531 &db_config.name,
532 &host_name,
533 self.ip.to_string(),
534 self.port,
535 &properties[..],
536 )?;
537 mdns.register(service)?;
538 }
539
540 Ok(())
541 }
542}
543
544impl Debug for State {
545 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
546 let tls_mode = if self.tls_config.is_some() {
547 ", TLS mode"
548 } else {
549 ""
550 };
551 write!(
552 f,
553 "MDB state, port {}, database {:?}{tls_mode}",
554 self.port, self.db_type
555 )
556 }
557}
558
/// Spawns the background Yara scanner.
///
/// Every five seconds the task fetches unfinished Yara tasks from the database. For each one
/// it fetches the next batch of files allowed for the task's user, starting after
/// `task.last_file_id`. An empty batch marks the task as finished. Otherwise the batch is
/// scanned, matches are recorded, and the task's next file id is advanced.
///
/// The batches of one poll are scanned concurrently. The next poll starts only after all of
/// them complete, so a batch whose progress hasn't been saved yet is not fetched and scanned
/// a second time. Errors are logged and the loop keeps going, always sleeping between polls.
#[cfg(feature = "yara")]
fn start_yara_process(state: Arc<State>) {
    // How long to wait between polls for unfinished Yara tasks.
    const POLL_INTERVAL: Duration = Duration::from_secs(5);

    tokio::spawn(async move {
        loop {
            match state.db_type.get_unfinished_yara_tasks().await {
                Ok(tasks) => {
                    let mut workers = tokio::task::JoinSet::new();

                    for task in tasks {
                        let (hashes, last_file_id) = match state
                            .db_type
                            .user_allowed_files_by_sha256(task.user_id, task.last_file_id)
                            .await
                        {
                            Ok(batch) => batch,
                            Err(e) => {
                                warn!("Failed to get user allowed files: {e}");
                                continue;
                            }
                        };

                        if hashes.is_empty() {
                            // No files left after `last_file_id`: the task is complete.
                            if let Err(e) =
                                state.db_type.mark_yara_task_as_finished(task.id).await
                            {
                                warn!("Failed to mark yara task as finished: {e}");
                            }
                            continue;
                        }

                        let worker_state = Arc::clone(&state);
                        workers.spawn(async move {
                            for hash in hashes {
                                let bytes = match worker_state.retrieve_bytes(&hash).await {
                                    Ok(bytes) => bytes,
                                    Err(e) => {
                                        warn!("Failed to retrieve bytes for hash {hash}: {e}");
                                        continue;
                                    }
                                };
                                // NOTE(review): rule matching is synchronous and runs on a
                                // runtime worker thread; consider `spawn_blocking` if slow.
                                let matches =
                                    task.process_yara_rules(&bytes).unwrap_or_else(|e| {
                                        warn!("Failed to process Yara rules: {e}");
                                        Vec::new()
                                    });

                                for match_ in matches {
                                    if let Err(e) = worker_state
                                        .db_type
                                        .add_yara_match(task.id, &match_, &hash)
                                        .await
                                    {
                                        warn!("Failed to add Yara match: {e}");
                                    }
                                }
                            }

                            // Save progress so the next poll continues after this batch.
                            if let Err(e) = worker_state
                                .db_type
                                .yara_add_next_file_id(task.id, last_file_id)
                                .await
                            {
                                warn!("Failed to update yara task next file id: {e}");
                            }
                        });
                    }

                    // Wait for every batch of this poll before polling again.
                    while let Some(joined) = workers.join_next().await {
                        if let Err(e) = joined {
                            warn!("Yara worker task failed: {e}");
                        }
                    }
                }
                Err(e) => warn!("Failed to get Yara tasks: {e}"),
            }

            // Sleep on every path, including errors, so a failing database isn't hammered.
            tokio::time::sleep(POLL_INTERVAL).await;
        }
    });
}
639
640async fn shutdown_signal(
643 #[cfg(target_family = "windows")] mut rx: Option<tokio::sync::mpsc::Receiver<()>>,
644) {
645 let ctrl_c = async {
646 tokio::signal::ctrl_c()
647 .await
648 .expect("failed to install Ctrl+C handler");
649 };
650
651 #[cfg(unix)]
652 let terminate = async {
653 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
654 .expect("failed to install signal handler")
655 .recv()
656 .await;
657 };
658
659 #[cfg(not(unix))]
660 let terminate = std::future::pending::<()>();
661
662 #[cfg(target_family = "windows")]
663 if let Some(rx_inner) = &mut rx {
664 let terminate_rx = rx_inner.recv();
665
666 tokio::select! {
667 () = ctrl_c => {},
668 () = terminate => {},
669 Some(()) = terminate_rx => {},
670 }
671 } else {
672 tokio::select! {
673 () = ctrl_c => {},
674 () = terminate => {},
675 }
676 }
677
678 #[cfg(not(target_family = "windows"))]
679 tokio::select! {
680 () = ctrl_c => {},
681 () = terminate => {},
682 }
683}