Skip to main content

malwaredb_server/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8
9/// Cryptographic functionality for file storage
10pub mod crypto;
11
12/// Database I/O
13pub mod db;
14
15/// HTTP Server
16pub mod http;
17
18/// Entropy functions
19pub mod utils;
20
21/// Virus Total integration
22#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
23#[cfg(feature = "vt")]
24pub mod vt;
25
26/// Yara-related functionality
27#[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
28#[cfg(feature = "yara")]
29pub mod yara;
30
31use crate::crypto::FileEncryption;
32use crate::db::MDBConfig;
33use malwaredb_api::ServerInfo;
34//use utils::HashPath;
35
36use std::collections::HashMap;
37use std::fmt::{Debug, Formatter};
38use std::io::{Cursor, Read};
39use std::net::{IpAddr, Ipv4Addr, SocketAddr};
40use std::path::PathBuf;
41use std::sync::{Arc, LazyLock};
42use std::time::{Duration, SystemTime};
43
44use anyhow::{anyhow, bail, ensure, Context, Result};
45use axum_server::tls_rustls::RustlsConfig;
46use chrono::Local;
47use chrono_humanize::{Accuracy, HumanTime, Tense};
48use flate2::read::GzDecoder;
49use mdns_sd::{ServiceDaemon, ServiceInfo};
50use sha2::{Digest, Sha256};
51use tokio::net::TcpListener;
52use tracing::{trace, warn};
53
54/// MDB version
55pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
56
57/// MDB version as a semantic version object
58pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
59    LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
60
61/// How often stale pagination searches should be cleaned up: 1 day
62/// Could be used if future cleanup operations are needed.
63pub(crate) const DB_CLEANUP_INTERVAL: Duration = Duration::from_hours(24);
64
65/// Gzip's magic number to see if a file is compressed
66pub const GZIP_MAGIC: [u8; 2] = [0x1fu8, 0x8bu8];
67
68/// Zstd magic number to see if a file is compressed
69pub const ZSTD_MAGIC: [u8; 4] = [0x28u8, 0xb5u8, 0x2fu8, 0xfdu8];
70
71/// Builder for server configuration
72pub struct StateBuilder {
73    /// The port which will be used to listen for connections.
74    pub port: u16,
75
76    /// The directory to store malware samples if we're keeping them.
77    pub directory: Option<PathBuf>,
78
79    /// Maximum upload size
80    pub max_upload: usize,
81
82    /// The IP to use for listening for connections
83    pub ip: IpAddr,
84
85    /// Handle to the database connection
86    db_type: db::DatabaseType,
87
88    /// Virus Total API key
89    #[cfg(feature = "vt")]
90    vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
91
92    /// TLS configuration constructed from certificate and private key files
93    tls_config: Option<RustlsConfig>,
94
95    /// If Malware DB should be advertised via Multicast DNS (also known as Bonjour or Zeroconf)
96    mdns: bool,
97}
98
99impl StateBuilder {
100    /// Create the builder starting with the database configuration, and optionally, the
101    /// certificate for communicating with Postgres.
102    ///
103    /// # Errors
104    ///
105    /// An error occurs if the database configuration isn't valid or if an error occurs connecting
106    /// to the database.
107    pub async fn new(db_string: &str, pg_cert: Option<PathBuf>) -> Result<Self> {
108        let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
109
110        Ok(Self {
111            port: 8080,
112            directory: None,
113            max_upload: 104_857_600, /* 100 MiB */
114            ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
115            db_type,
116            #[cfg(feature = "vt")]
117            vt_client: None,
118            tls_config: None,
119            mdns: false,
120        })
121    }
122
123    /// Specify the port to listen on.
124    /// Default: `8080`
125    #[must_use]
126    pub fn port(mut self, port: u16) -> Self {
127        self.port = port;
128        self
129    }
130
131    /// Specify the directory to store malware samples if we're keeping them.
132    ///
133    /// Default: No directory, no file saving
134    #[must_use]
135    pub fn directory(mut self, directory: PathBuf) -> Self {
136        self.directory = Some(directory);
137        self
138    }
139
140    /// Specify the maximum upload size in bytes.
141    /// Default is 100 MiB.
142    #[must_use]
143    pub fn max_upload(mut self, max_upload: usize) -> Self {
144        self.max_upload = max_upload;
145        self
146    }
147
148    /// Indicate the IP address the server will list on.
149    /// Default: 127.0.0.1
150    #[must_use]
151    pub fn ip(mut self, ip: IpAddr) -> Self {
152        self.ip = ip;
153        self
154    }
155
156    /// Provide the Virus Total API key.
157    #[must_use]
158    #[cfg(feature = "vt")]
159    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
160    pub fn vt_client(mut self, vt_client: malwaredb_virustotal::VirusTotalClient) -> Self {
161        self.vt_client = Some(vt_client);
162        self
163    }
164
165    /// Provide the certificate and private key for TLS mode.
166    /// Files must match: both as PEM or both as DER.
167    ///
168    /// # Errors
169    ///
170    /// An error results if either file doesn't exist, not in the same format, or cannot be parsed.
171    pub async fn tls(mut self, cert_file: PathBuf, key_file: PathBuf) -> Result<Self> {
172        ensure!(
173            cert_file.exists(),
174            "Certificate file {} does not exist!",
175            cert_file.display()
176        );
177
178        ensure!(
179            key_file.exists(),
180            "Key file {} does not exist!",
181            key_file.display()
182        );
183
184        let cert_ext_str = cert_file
185            .extension()
186            .context("failed to get certificate extension")?;
187        let key_ext_str = key_file
188            .extension()
189            .context("failed to get key extension")?;
190
191        // Unnecessary for running MalwareDB, but some unit tests fail without this check.
192        if rustls::crypto::CryptoProvider::get_default().is_none() {
193            rustls::crypto::aws_lc_rs::default_provider()
194                .install_default()
195                .map_err(|_| anyhow!("failed to install AWS-LC crypto provider"))?;
196        }
197
198        let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem" {
199            RustlsConfig::from_pem_file(cert_file, key_file)
200                .await
201                .context("failed to load or parse certificate and key pem files")?
202        } else if cert_ext_str == "der" && key_ext_str == "der" {
203            let cert_contents =
204                std::fs::read(cert_file).context("failed to read certificate file")?;
205            let key_contents =
206                std::fs::read(key_file).context("failed to read private key file")?;
207            RustlsConfig::from_der(vec![cert_contents], key_contents)
208                .await
209                .context("failed to parse certificate and key der files")?
210        } else {
211            bail!(
212                "Unknown or unmatched certificate and key file extensions {} and {}",
213                cert_ext_str.display(),
214                key_ext_str.display()
215            );
216        };
217
218        self.tls_config = Some(config);
219        Ok(self)
220    }
221
222    /// Indicate that Malware DB should advertise itself via multicast DNS.
223    /// Default is false.
224    #[must_use]
225    pub fn enable_mdns(mut self) -> Self {
226        self.mdns = true;
227        self
228    }
229
230    /// Generate the state object.
231    ///
232    /// # Errors
233    ///
234    /// An error occurs if the database can't be reached.
235    pub async fn into_state(self) -> Result<State> {
236        let db_config = self.db_type.get_config().await?;
237        let keys = self.db_type.get_encryption_keys().await?;
238
239        Ok(State {
240            port: self.port,
241            directory: self.directory,
242            max_upload: self.max_upload,
243            ip: self.ip,
244            db_type: Arc::new(self.db_type),
245            started: SystemTime::now(),
246            db_config,
247            keys,
248            #[cfg(feature = "vt")]
249            vt_client: self.vt_client,
250            tls_config: self.tls_config,
251            mdns: if self.mdns {
252                Some(ServiceDaemon::new()?)
253            } else {
254                None
255            },
256        })
257    }
258}
259
260/// State & configuration of the running server instance
261pub struct State {
262    /// The port which will be used to listen for connections.
263    pub port: u16,
264
265    /// The directory to store malware samples if we're keeping them.
266    pub directory: Option<PathBuf>,
267
268    /// Maximum upload size
269    pub max_upload: usize,
270
271    /// The IP to use for listening for connections
272    pub ip: IpAddr,
273
274    /// Handle to the database connection
275    pub db_type: Arc<db::DatabaseType>,
276
277    /// Start time of the server
278    pub started: SystemTime,
279
280    /// Configuration which is stored in the database
281    pub db_config: MDBConfig,
282
283    /// File encryption keys, may be empty
284    pub(crate) keys: HashMap<u32, FileEncryption>,
285
286    /// Virus Total API key
287    #[cfg(feature = "vt")]
288    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
289
290    /// TLS configuration constructed from certificate and private key files
291    tls_config: Option<RustlsConfig>,
292
293    /// If Malware DB should be advertised via Multicast DNS (also known as Bonjour or Zeroconf)
294    mdns: Option<ServiceDaemon>,
295}
296
297impl State {
298    /// Store the sample with a depth of three based on the sample's SHA-256 hash, even if compressed
299    ///
300    /// # Errors
301    ///
302    /// * If the file can't be written.
303    /// * If a necessary sub-directory can't be created.
304    pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
305        if let Some(dest_path) = &self.directory {
306            let mut hasher = Sha256::new();
307            hasher.update(data);
308            let sha256 = hex::encode(hasher.finalize());
309
310            // Trait `HashPath` needs to be re-worked so it can work with Strings.
311            // This code below ends up making the String into ASCII representations of the hash
312            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
313            let hashed_path = format!(
314                "{}/{}/{}/{}",
315                &sha256[0..2],
316                &sha256[2..4],
317                &sha256[4..6],
318                sha256
319            );
320
321            // The path which has the file name included, with the storage directory prepended.
322            //let hashed_path = result.hashed_path(3);
323            let mut dest_path = dest_path.clone();
324            dest_path.push(hashed_path);
325
326            // Remove the file name so we can just have the directory path.
327            let mut just_the_dir = dest_path.clone();
328            just_the_dir.pop();
329            std::fs::create_dir_all(just_the_dir)?;
330
331            let data = if self.db_config.compression {
332                let buff = Cursor::new(data);
333                let mut compressed = Vec::with_capacity(data.len() / 2);
334                zstd::stream::copy_encode(buff, &mut compressed, 4)?;
335                compressed
336            } else {
337                data.to_vec()
338            };
339
340            let data = if let Some(key_id) = self.db_config.default_key {
341                if let Some(key) = self.keys.get(&key_id) {
342                    let nonce = key.nonce();
343                    self.db_type
344                        .set_file_nonce(&sha256, nonce.as_deref())
345                        .await?;
346                    key.encrypt(&data, nonce)?
347                } else {
348                    bail!("Key not available!")
349                }
350            } else {
351                data
352            };
353
354            std::fs::write(dest_path, data)?;
355
356            Ok(true)
357        } else {
358            Ok(false)
359        }
360    }
361
362    /// Retrieve a sample given the SHA-256 hash
363    /// Assumes that `MalwareDB` permissions have already been checked to ensure this is permitted.
364    ///
365    /// # Errors
366    ///
367    /// * The file could not be read, maybe because it doesn't exist.
368    /// * Failure to decrypt or decompress (corruption).
369    pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
370        if let Some(dest_path) = &self.directory {
371            let path = format!(
372                "{}/{}/{}/{}",
373                &sha256[0..2],
374                &sha256[2..4],
375                &sha256[4..6],
376                sha256
377            );
378            // Trait `HashPath` needs to be re-worked so it can work with Strings.
379            // This code below ends up making the String into ASCII representations of the hash
380            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
381            //let path = sha256.as_bytes().iter().hashed_path(3);
382            let contents = std::fs::read(dest_path.join(path))?;
383
384            let contents = if self.keys.is_empty() {
385                // We don't have file encryption enabled
386                contents
387            } else {
388                let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
389                if let Some(key_id) = key_id {
390                    if let Some(key) = self.keys.get(&key_id) {
391                        key.decrypt(&contents, nonce)?
392                    } else {
393                        bail!("File was encrypted but we don't have tke key!")
394                    }
395                } else {
396                    // File was not encrypted
397                    contents
398                }
399            };
400
401            if contents.starts_with(&GZIP_MAGIC) {
402                let buff = Cursor::new(contents);
403                let mut decompressor = GzDecoder::new(buff);
404                let mut decompressed: Vec<u8> = vec![];
405                decompressor.read_to_end(&mut decompressed)?;
406                Ok(decompressed)
407            } else if contents.starts_with(&ZSTD_MAGIC) {
408                let buff = Cursor::new(contents);
409                let mut decompressed: Vec<u8> = vec![];
410                zstd::stream::copy_decode(buff, &mut decompressed)?;
411                Ok(decompressed)
412            } else {
413                Ok(contents)
414            }
415        } else {
416            bail!("files are not saved")
417        }
418    }
419
420    /// Get the duration for which the server has been running
421    ///
422    /// # Panics
423    ///
424    /// Despite the `unwrap()` this function will not panic as the data used is guaranteed to be valid.
425    #[must_use]
426    pub fn since(&self) -> Duration {
427        let now = SystemTime::now();
428        now.duration_since(self.started).unwrap()
429    }
430
431    /// Get server information
432    ///
433    /// # Errors
434    ///
435    /// An error would occur if the Postgres server could not be reached.
436    pub async fn get_info(&self) -> Result<ServerInfo> {
437        let db_info = self.db_type.db_info().await?;
438        let uptime = Local::now() - self.since();
439        let mem_size = app_memory_usage_fetcher::get_memory_usage_string().unwrap_or_default();
440
441        Ok(ServerInfo {
442            os_name: std::env::consts::OS.into(),
443            memory_used: mem_size,
444            num_samples: db_info.num_files,
445            num_users: db_info.num_users,
446            uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
447            mdb_version: MDB_VERSION_SEMVER.clone(),
448            db_version: db_info.version,
449            db_size: db_info.size,
450            instance_name: self.db_config.name.clone(),
451            vt_support: cfg!(feature = "vt"),
452            yara_enabled: cfg!(feature = "yara"),
453        })
454    }
455
456    /// The server listens and responds to requests. Does not return unless there's an error.
457    ///
458    /// # Errors
459    ///
460    /// * If the certificate and private key could not be parsed or are not valid.
461    /// * If the IP address and port are already in use.
462    /// * If the service doesn't have permission to open the port.
463    pub async fn serve(
464        self,
465        #[cfg(target_family = "windows")] rx: Option<tokio::sync::mpsc::Receiver<()>>,
466    ) -> Result<()> {
467        let socket = SocketAddr::new(self.ip, self.port);
468        let arc_self = Arc::new(self);
469        let db_info = arc_self.db_type.clone();
470
471        #[cfg(feature = "yara")]
472        {
473            if arc_self.directory.is_some() {
474                start_yara_process(arc_self.clone());
475            }
476        }
477
478        tokio::spawn(async move {
479            loop {
480                match db_info.cleanup().await {
481                    Ok(removed) => {
482                        trace!("Pagination cleanup succeeded, {removed} searches removed");
483                    }
484                    Err(e) => warn!("Pagination cleanup failed: {e}"),
485                }
486
487                tokio::time::sleep(DB_CLEANUP_INTERVAL).await;
488            }
489        });
490
491        if let Some(mdns) = &arc_self.mdns {
492            let host_name = format!("{}.local.", arc_self.ip);
493            let ssl = arc_self.tls_config.is_some();
494            let properties = [("ssl", ssl.to_string()), ("version", MDB_VERSION.into())];
495            let service = {
496                let mut service = ServiceInfo::new(
497                    malwaredb_api::MDNS_NAME,
498                    &arc_self.db_config.name,
499                    &host_name,
500                    &arc_self.ip,
501                    arc_self.port,
502                    &properties[..],
503                )?;
504                if arc_self.ip.is_unspecified() {
505                    service = service.enable_addr_auto();
506                }
507                service
508            };
509            trace!("Registering MDNS service...");
510            mdns.register(service)?;
511        }
512
513        if let Some(tls_config) = arc_self.tls_config.clone() {
514            println!("Listening on https://{socket:?}");
515            let handle = axum_server::Handle::<SocketAddr>::new();
516            let server_future = axum_server::bind_rustls(socket, tls_config)
517                .serve(http::app(arc_self).into_make_service());
518            tokio::select! {
519                () = shutdown_signal(#[cfg(target_family = "windows")]rx) =>
520                    handle.graceful_shutdown(Some(Duration::from_secs(30))),
521                res = server_future => res?,
522            }
523            warn!("Terminate signal received");
524        } else {
525            println!("Listening on http://{socket:?}");
526            let listener = TcpListener::bind(socket)
527                .await
528                .context(format!("failed to bind socket {socket}"))?;
529            axum::serve(listener, http::app(arc_self).into_make_service())
530                .with_graceful_shutdown(shutdown_signal(
531                    #[cfg(target_family = "windows")]
532                    rx,
533                ))
534                .await?;
535            warn!("Terminate signal received");
536        }
537        Ok(())
538    }
539}
540
541impl Debug for State {
542    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
543        let tls_mode = if self.tls_config.is_some() {
544            ", TLS mode"
545        } else {
546            ""
547        };
548        write!(
549            f,
550            "MDB state, port {}, database {:?}{tls_mode}",
551            self.port, self.db_type
552        )
553    }
554}
555
/// Spawn the background task which polls the database for unfinished Yara tasks and processes
/// them.
///
/// Every polling round fetches the unfinished tasks. For each task, the next batch of file
/// hashes the task's user may access is fetched; an empty batch marks the task as finished,
/// otherwise a separate task is spawned which runs the Yara rules against each file, records
/// the matches, and stores the next file ID for the task. All failures are logged and skipped.
#[cfg(feature = "yara")]
fn start_yara_process(state: Arc<State>) {
    /// Pause between polling rounds, also applied after a failed poll.
    const POLL_INTERVAL: Duration = Duration::from_secs(5);

    tokio::spawn(async move {
        loop {
            let tasks = match state.db_type.get_unfinished_yara_tasks().await {
                Ok(tasks) => tasks,
                Err(e) => {
                    warn!("Failed to get Yara tasks: {e}");
                    // Back off before retrying so a failing database isn't hammered in a
                    // tight loop.
                    tokio::time::sleep(POLL_INTERVAL).await;
                    continue;
                }
            };
            for task in tasks {
                let (hashes, last_file_id) = match state
                    .db_type
                    .user_allowed_files_by_sha256(task.user_id, task.last_file_id)
                    .await
                {
                    Ok(hashes) => hashes,
                    Err(e) => {
                        warn!("Failed to get user allowed files: {e}");
                        continue;
                    }
                };

                if hashes.is_empty() {
                    if let Err(e) = state.db_type.mark_yara_task_as_finished(task.id).await {
                        warn!("Failed to mark yara task as finished: {e}");
                    }
                    continue;
                }

                // NOTE(review): batches are processed in detached tasks; if a batch takes longer
                // than the polling interval, the same task may be picked up again before its
                // next file ID is stored. Confirm against the database layer.
                let state = Arc::clone(&state);
                tokio::spawn(async move {
                    for hash in hashes {
                        let bytes = match state.retrieve_bytes(&hash).await {
                            Ok(bytes) => bytes,
                            Err(e) => {
                                warn!("Failed to retrieve bytes for hash {hash}: {e}");
                                continue;
                            }
                        };
                        let matches = task.process_yara_rules(&bytes).unwrap_or_else(|e| {
                            warn!("Failed to process Yara rules: {e}");
                            Vec::new()
                        });

                        for match_ in matches {
                            if let Err(e) = state
                                .db_type
                                .add_yara_match(task.id, &match_, &hash)
                                .await
                            {
                                warn!("Failed to add Yara match: {e}");
                            }
                        }
                    }
                    if let Err(e) = state
                        .db_type
                        .yara_add_next_file_id(task.id, last_file_id)
                        .await
                    {
                        warn!("Failed to update yara task next file id: {e}");
                    }
                });
            }
            tokio::time::sleep(POLL_INTERVAL).await;
        }
    });
}
636
637/// Enable graceful shutdown
638/// <https://github.com/tokio-rs/axum/discussions/1500>
639async fn shutdown_signal(
640    #[cfg(target_family = "windows")] mut rx: Option<tokio::sync::mpsc::Receiver<()>>,
641) {
642    let ctrl_c = async {
643        tokio::signal::ctrl_c()
644            .await
645            .expect("failed to install Ctrl+C handler");
646    };
647
648    #[cfg(unix)]
649    let terminate = async {
650        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
651            .expect("failed to install signal handler")
652            .recv()
653            .await;
654    };
655
656    #[cfg(not(unix))]
657    let terminate = std::future::pending::<()>();
658
659    #[cfg(target_family = "windows")]
660    if let Some(rx_inner) = &mut rx {
661        let terminate_rx = rx_inner.recv();
662
663        tokio::select! {
664            () = ctrl_c => {},
665            () = terminate => {},
666            Some(()) = terminate_rx => {},
667        }
668    } else {
669        tokio::select! {
670            () = ctrl_c => {},
671            () = terminate => {},
672        }
673    }
674
675    #[cfg(not(target_family = "windows"))]
676    tokio::select! {
677        () = ctrl_c => {},
678        () = terminate => {},
679    }
680}