Skip to main content

malwaredb_server/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8
9/// Cryptographic functionality for file storage
10pub mod crypto;
11
12/// Database I/O
13pub mod db;
14
15/// HTTP Server
16pub mod http;
17
18/// Entropy functions
19pub mod utils;
20
21/// Virus Total integration
22#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
23#[cfg(feature = "vt")]
24pub mod vt;
25
26/// Yara-related functionality
27#[cfg(feature = "yara")]
28pub mod yara;
29
30use crate::crypto::FileEncryption;
31use crate::db::MDBConfig;
32use malwaredb_api::ServerInfo;
33//use utils::HashPath;
34
35use std::collections::HashMap;
36use std::fmt::{Debug, Formatter};
37use std::io::{Cursor, Read};
38use std::net::{IpAddr, Ipv4Addr, SocketAddr};
39use std::path::PathBuf;
40use std::sync::{Arc, LazyLock};
41use std::time::{Duration, SystemTime};
42
43use anyhow::{anyhow, bail, ensure, Context, Result};
44use axum_server::tls_rustls::RustlsConfig;
45use chrono::Local;
46use chrono_humanize::{Accuracy, HumanTime, Tense};
47use flate2::read::GzDecoder;
48use mdns_sd::{ServiceDaemon, ServiceInfo};
49use sha2::{Digest, Sha256};
50use tokio::net::TcpListener;
51use tracing::{trace, warn};
52
53/// MDB version
54pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
55
56/// MDB version as a semantic version object
57pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
58    LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
59
/// Interval between database housekeeping passes (currently: purging stale pagination
/// searches). One day; other periodic cleanup could reuse it.
pub(crate) const DB_CLEANUP_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);

/// Leading bytes of a gzip stream, used to detect gzip-compressed stored samples.
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];

/// Leading bytes of a zstd frame, used to detect zstd-compressed stored samples.
pub const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];
69
70/// Builder for server configuration
71pub struct StateBuilder {
72    /// The port which will be used to listen for connections.
73    pub port: u16,
74
75    /// The directory to store malware samples if we're keeping them.
76    pub directory: Option<PathBuf>,
77
78    /// Maximum upload size
79    pub max_upload: usize,
80
81    /// The IP to use for listening for connections
82    pub ip: IpAddr,
83
84    /// Handle to the database connection
85    db_type: db::DatabaseType,
86
87    /// Virus Total API key
88    #[cfg(feature = "vt")]
89    vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
90
91    /// TLS configuration constructed from certificate and private key files
92    tls_config: Option<RustlsConfig>,
93
94    /// If Malware DB should be advertised via Multicast DNS (also known as Bonjour or Zeroconf)
95    mdns: bool,
96}
97
98impl StateBuilder {
99    /// Create the builder starting with the database configuration, and optionally, the
100    /// certificate for communicating with Postgres.
101    ///
102    /// # Errors
103    ///
104    /// An error occurs if the database configuration isn't valid or if an error occurs connecting
105    /// to the database.
106    pub async fn new(db_string: &str, pg_cert: Option<PathBuf>) -> Result<Self> {
107        let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
108
109        Ok(Self {
110            port: 8080,
111            directory: None,
112            max_upload: 104_857_600, /* 100 MiB */
113            ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
114            db_type,
115            #[cfg(feature = "vt")]
116            vt_client: None,
117            tls_config: None,
118            mdns: false,
119        })
120    }
121
122    /// Specify the port to listen on.
123    /// Default: `8080`
124    #[must_use]
125    pub fn port(mut self, port: u16) -> Self {
126        self.port = port;
127        self
128    }
129
130    /// Specify the directory to store malware samples if we're keeping them.
131    ///
132    /// Default: No directory, no file saving
133    #[must_use]
134    pub fn directory(mut self, directory: PathBuf) -> Self {
135        self.directory = Some(directory);
136        self
137    }
138
139    /// Specify the maximum upload size in bytes.
140    /// Default is 100 MiB.
141    #[must_use]
142    pub fn max_upload(mut self, max_upload: usize) -> Self {
143        self.max_upload = max_upload;
144        self
145    }
146
147    /// Indicate the IP address the server will list on.
148    /// Default: 127.0.0.1
149    #[must_use]
150    pub fn ip(mut self, ip: IpAddr) -> Self {
151        self.ip = ip;
152        self
153    }
154
155    /// Provide the Virus Total API key.
156    #[must_use]
157    #[cfg(feature = "vt")]
158    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
159    pub fn vt_client(mut self, vt_client: malwaredb_virustotal::VirusTotalClient) -> Self {
160        self.vt_client = Some(vt_client);
161        self
162    }
163
164    /// Provide the certificate and private key for TLS mode.
165    /// Files must match: both as PEM or both as DER.
166    ///
167    /// # Errors
168    ///
169    /// An error results if either file doesn't exist, not in the same format, or cannot be parsed.
170    pub async fn tls(mut self, cert_file: PathBuf, key_file: PathBuf) -> Result<Self> {
171        ensure!(
172            cert_file.exists(),
173            "Certificate file {} does not exist!",
174            cert_file.display()
175        );
176
177        ensure!(
178            key_file.exists(),
179            "Key file {} does not exist!",
180            key_file.display()
181        );
182
183        let cert_ext_str = cert_file
184            .extension()
185            .context("failed to get certificate extension")?;
186        let key_ext_str = key_file
187            .extension()
188            .context("failed to get key extension")?;
189
190        // Unnecessary for running MalwareDB, but some unit tests fail without this check.
191        if rustls::crypto::CryptoProvider::get_default().is_none() {
192            rustls::crypto::aws_lc_rs::default_provider()
193                .install_default()
194                .map_err(|_| anyhow!("failed to install AWS-LC crypto provider"))?;
195        }
196
197        let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem" {
198            RustlsConfig::from_pem_file(cert_file, key_file)
199                .await
200                .context("failed to load or parse certificate and key pem files")?
201        } else if cert_ext_str == "der" && key_ext_str == "der" {
202            let cert_contents =
203                std::fs::read(cert_file).context("failed to read certificate file")?;
204            let key_contents =
205                std::fs::read(key_file).context("failed to read private key file")?;
206            RustlsConfig::from_der(vec![cert_contents], key_contents)
207                .await
208                .context("failed to parse certificate and key der files")?
209        } else {
210            bail!(
211                "Unknown or unmatched certificate and key file extensions {} and {}",
212                cert_ext_str.display(),
213                key_ext_str.display()
214            );
215        };
216
217        self.tls_config = Some(config);
218        Ok(self)
219    }
220
221    /// Indicate that Malware DB should advertise itself via multicast DNS.
222    /// Default is false.
223    #[must_use]
224    pub fn enable_mdns(mut self) -> Self {
225        self.mdns = true;
226        self
227    }
228
229    /// Generate the state object.
230    ///
231    /// # Errors
232    ///
233    /// An error occurs if the database can't be reached.
234    pub async fn into_state(self) -> Result<State> {
235        let db_config = self.db_type.get_config().await?;
236        let keys = self.db_type.get_encryption_keys().await?;
237
238        Ok(State {
239            port: self.port,
240            directory: self.directory,
241            max_upload: self.max_upload,
242            ip: self.ip,
243            db_type: Arc::new(self.db_type),
244            started: SystemTime::now(),
245            db_config,
246            keys,
247            #[cfg(feature = "vt")]
248            vt_client: self.vt_client,
249            tls_config: self.tls_config,
250            mdns: if self.mdns {
251                Some(ServiceDaemon::new()?)
252            } else {
253                None
254            },
255        })
256    }
257}
258
259/// State & configuration of the running server instance
260pub struct State {
261    /// The port which will be used to listen for connections.
262    pub port: u16,
263
264    /// The directory to store malware samples if we're keeping them.
265    pub directory: Option<PathBuf>,
266
267    /// Maximum upload size
268    pub max_upload: usize,
269
270    /// The IP to use for listening for connections
271    pub ip: IpAddr,
272
273    /// Handle to the database connection
274    pub db_type: Arc<db::DatabaseType>,
275
276    /// Start time of the server
277    pub started: SystemTime,
278
279    /// Configuration which is stored in the database
280    pub db_config: MDBConfig,
281
282    /// File encryption keys, may be empty
283    pub(crate) keys: HashMap<u32, FileEncryption>,
284
285    /// Virus Total API key
286    #[cfg(feature = "vt")]
287    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
288
289    /// TLS configuration constructed from certificate and private key files
290    tls_config: Option<RustlsConfig>,
291
292    /// If Malware DB should be advertised via Multicast DNS (also known as Bonjour or Zeroconf)
293    mdns: Option<ServiceDaemon>,
294}
295
296impl State {
297    /// Store the sample with a depth of three based on the sample's SHA-256 hash, even if compressed
298    ///
299    /// # Errors
300    ///
301    /// * If the file can't be written.
302    /// * If a necessary sub-directory can't be created.
303    pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
304        if let Some(dest_path) = &self.directory {
305            let mut hasher = Sha256::new();
306            hasher.update(data);
307            let sha256 = hex::encode(hasher.finalize());
308
309            // Trait `HashPath` needs to be re-worked so it can work with Strings.
310            // This code below ends up making the String into ASCII representations of the hash
311            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
312            let hashed_path = format!(
313                "{}/{}/{}/{}",
314                &sha256[0..2],
315                &sha256[2..4],
316                &sha256[4..6],
317                sha256
318            );
319
320            // The path which has the file name included, with the storage directory prepended.
321            //let hashed_path = result.hashed_path(3);
322            let mut dest_path = dest_path.clone();
323            dest_path.push(hashed_path);
324
325            // Remove the file name so we can just have the directory path.
326            let mut just_the_dir = dest_path.clone();
327            just_the_dir.pop();
328            std::fs::create_dir_all(just_the_dir)?;
329
330            let data = if self.db_config.compression {
331                let buff = Cursor::new(data);
332                let mut compressed = Vec::with_capacity(data.len() / 2);
333                zstd::stream::copy_encode(buff, &mut compressed, 4)?;
334                compressed
335            } else {
336                data.to_vec()
337            };
338
339            let data = if let Some(key_id) = self.db_config.default_key {
340                if let Some(key) = self.keys.get(&key_id) {
341                    let nonce = key.nonce();
342                    self.db_type
343                        .set_file_nonce(&sha256, nonce.as_deref())
344                        .await?;
345                    key.encrypt(&data, nonce)?
346                } else {
347                    bail!("Key not available!")
348                }
349            } else {
350                data
351            };
352
353            std::fs::write(dest_path, data)?;
354
355            Ok(true)
356        } else {
357            Ok(false)
358        }
359    }
360
361    /// Retrieve a sample given the SHA-256 hash
362    /// Assumes that `MalwareDB` permissions have already been checked to ensure this is permitted.
363    ///
364    /// # Errors
365    ///
366    /// * The file could not be read, maybe because it doesn't exist.
367    /// * Failure to decrypt or decompress (corruption).
368    pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
369        if let Some(dest_path) = &self.directory {
370            let path = format!(
371                "{}/{}/{}/{}",
372                &sha256[0..2],
373                &sha256[2..4],
374                &sha256[4..6],
375                sha256
376            );
377            // Trait `HashPath` needs to be re-worked so it can work with Strings.
378            // This code below ends up making the String into ASCII representations of the hash
379            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
380            //let path = sha256.as_bytes().iter().hashed_path(3);
381            let contents = std::fs::read(dest_path.join(path))?;
382
383            let contents = if self.keys.is_empty() {
384                // We don't have file encryption enabled
385                contents
386            } else {
387                let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
388                if let Some(key_id) = key_id {
389                    if let Some(key) = self.keys.get(&key_id) {
390                        key.decrypt(&contents, nonce)?
391                    } else {
392                        bail!("File was encrypted but we don't have tke key!")
393                    }
394                } else {
395                    // File was not encrypted
396                    contents
397                }
398            };
399
400            if contents.starts_with(&GZIP_MAGIC) {
401                let buff = Cursor::new(contents);
402                let mut decompressor = GzDecoder::new(buff);
403                let mut decompressed: Vec<u8> = vec![];
404                decompressor.read_to_end(&mut decompressed)?;
405                Ok(decompressed)
406            } else if contents.starts_with(&ZSTD_MAGIC) {
407                let buff = Cursor::new(contents);
408                let mut decompressed: Vec<u8> = vec![];
409                zstd::stream::copy_decode(buff, &mut decompressed)?;
410                Ok(decompressed)
411            } else {
412                Ok(contents)
413            }
414        } else {
415            bail!("files are not saved")
416        }
417    }
418
419    /// Get the duration for which the server has been running
420    ///
421    /// # Panics
422    ///
423    /// Despite the `unwrap()` this function will not panic as the data used is guaranteed to be valid.
424    #[must_use]
425    pub fn since(&self) -> Duration {
426        let now = SystemTime::now();
427        now.duration_since(self.started).unwrap()
428    }
429
430    /// Get server information
431    ///
432    /// # Errors
433    ///
434    /// An error would occur if the Postgres server could not be reached.
435    pub async fn get_info(&self) -> Result<ServerInfo> {
436        let db_info = self.db_type.db_info().await?;
437        let uptime = Local::now() - self.since();
438        let mem_size = app_memory_usage_fetcher::get_memory_usage_string().unwrap_or_default();
439
440        Ok(ServerInfo {
441            os_name: std::env::consts::OS.into(),
442            memory_used: mem_size,
443            num_samples: db_info.num_files,
444            num_users: db_info.num_users,
445            uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
446            mdb_version: MDB_VERSION_SEMVER.clone(),
447            db_version: db_info.version,
448            db_size: db_info.size,
449            instance_name: self.db_config.name.clone(),
450        })
451    }
452
453    /// The server listens and responds to requests. Does not return unless there's an error.
454    ///
455    /// # Errors
456    ///
457    /// * If the certificate and private key could not be parsed or are not valid.
458    /// * If the IP address and port are already in use.
459    /// * If the service doesn't have permission to open the port.
460    pub async fn serve(
461        self,
462        #[cfg(target_family = "windows")] rx: Option<tokio::sync::mpsc::Receiver<()>>,
463    ) -> Result<()> {
464        let socket = SocketAddr::new(self.ip, self.port);
465        let arc_self = Arc::new(self);
466        let db_info = arc_self.db_type.clone();
467
468        #[cfg(feature = "yara")]
469        {
470            if arc_self.directory.is_some() {
471                start_yara_process(arc_self.clone());
472            }
473        }
474
475        tokio::spawn(async move {
476            loop {
477                match db_info.cleanup().await {
478                    Ok(removed) => {
479                        trace!("Pagination cleanup succeeded, {removed} searches removed");
480                    }
481                    Err(e) => warn!("Pagination cleanup failed: {e}"),
482                }
483
484                tokio::time::sleep(DB_CLEANUP_INTERVAL).await;
485            }
486        });
487
488        if arc_self.mdns.is_some() && !arc_self.ip.is_loopback() {
489            if let Err(e) = arc_self.mdns_register().await {
490                warn!("Failed to register MDNS service: {e}");
491            }
492        }
493
494        if let Some(tls_config) = arc_self.tls_config.clone() {
495            println!("Listening on https://{socket:?}");
496            let handle = axum_server::Handle::<SocketAddr>::new();
497            let server_future = axum_server::bind_rustls(socket, tls_config)
498                .serve(http::app(arc_self).into_make_service());
499            tokio::select! {
500                () = shutdown_signal(#[cfg(target_family = "windows")]rx) =>
501                    handle.graceful_shutdown(Some(Duration::from_secs(30))),
502                res = server_future => res?,
503            }
504            warn!("Terminate signal received");
505        } else {
506            println!("Listening on http://{socket:?}");
507            let listener = TcpListener::bind(socket)
508                .await
509                .context(format!("failed to bind socket {socket}"))?;
510            axum::serve(listener, http::app(arc_self).into_make_service())
511                .with_graceful_shutdown(shutdown_signal(
512                    #[cfg(target_family = "windows")]
513                    rx,
514                ))
515                .await?;
516            warn!("Terminate signal received");
517        }
518        Ok(())
519    }
520
521    /// mdns registration function as a separate function so in the future data can be updated
522    /// without a restart
523    async fn mdns_register(&self) -> Result<()> {
524        if let Some(mdns) = &self.mdns {
525            let db_config = self.db_type.get_config().await?;
526            let host_name = format!("{}.local.", self.ip);
527            let ssl = self.tls_config.is_some();
528            let properties = [("ssl", ssl.to_string())];
529            let service = ServiceInfo::new(
530                malwaredb_api::MDNS_NAME,
531                &db_config.name,
532                &host_name,
533                self.ip.to_string(),
534                self.port,
535                &properties[..],
536            )?;
537            mdns.register(service)?;
538        }
539
540        Ok(())
541    }
542}
543
544impl Debug for State {
545    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
546        let tls_mode = if self.tls_config.is_some() {
547            ", TLS mode"
548        } else {
549            ""
550        };
551        write!(
552            f,
553            "MDB state, port {}, database {:?}{tls_mode}",
554            self.port, self.db_type
555        )
556    }
557}
558
/// Spawn the background loop that runs pending Yara tasks against stored samples.
///
/// Each round fetches the unfinished tasks. For every task it takes the next batch of files the
/// task's user may access. If the batch is empty the task is marked finished. Otherwise the batch
/// is scanned on a worker task, matches are recorded, and the task's next file ID is advanced.
/// All workers of a round are awaited before sleeping. This ensures a batch still being scanned
/// is not fetched and processed a second time by the next round.
#[cfg(feature = "yara")]
fn start_yara_process(state: Arc<State>) {
    /// Pause between rounds, including after a failed attempt to fetch tasks.
    const POLL_INTERVAL: Duration = Duration::from_secs(5);

    tokio::spawn(async move {
        loop {
            let tasks = match state.db_type.get_unfinished_yara_tasks().await {
                Ok(tasks) => tasks,
                Err(e) => {
                    warn!("Failed to get Yara tasks: {e}");
                    // Back off before retrying instead of spinning against a failing database.
                    tokio::time::sleep(POLL_INTERVAL).await;
                    continue;
                }
            };

            let mut workers = tokio::task::JoinSet::new();
            for task in tasks {
                let (hashes, last_file_id) = match state
                    .db_type
                    .user_allowed_files_by_sha256(task.user_id, task.last_file_id)
                    .await
                {
                    Ok(batch) => batch,
                    Err(e) => {
                        warn!("Failed to get user allowed files: {e}");
                        continue;
                    }
                };

                if hashes.is_empty() {
                    if let Err(e) = state.db_type.mark_yara_task_as_finished(task.id).await {
                        warn!("Failed to mark yara task as finished: {e}");
                    }
                    continue;
                }

                let worker_state = Arc::clone(&state);
                workers.spawn(async move {
                    for hash in hashes {
                        let bytes = match worker_state.retrieve_bytes(&hash).await {
                            Ok(bytes) => bytes,
                            Err(e) => {
                                warn!("Failed to retrieve bytes for hash {hash}: {e}");
                                continue;
                            }
                        };
                        let matches = task.process_yara_rules(&bytes).unwrap_or_else(|e| {
                            warn!("Failed to process Yara rules: {e}");
                            Vec::new()
                        });

                        for match_ in matches {
                            if let Err(e) = worker_state
                                .db_type
                                .add_yara_match(task.id, &match_, &hash)
                                .await
                            {
                                warn!("Failed to add Yara match: {e}");
                            }
                        }
                    }
                    if let Err(e) = worker_state
                        .db_type
                        .yara_add_next_file_id(task.id, last_file_id)
                        .await
                    {
                        warn!("Failed to update yara task next file id: {e}");
                    }
                });
            }

            // Wait for this round's batches so the next round sees the advanced file IDs.
            while let Some(joined) = workers.join_next().await {
                if let Err(e) = joined {
                    warn!("Yara worker task failed: {e}");
                }
            }

            tokio::time::sleep(POLL_INTERVAL).await;
        }
    });
}
639
640/// Enable graceful shutdown
641/// <https://github.com/tokio-rs/axum/discussions/1500>
642async fn shutdown_signal(
643    #[cfg(target_family = "windows")] mut rx: Option<tokio::sync::mpsc::Receiver<()>>,
644) {
645    let ctrl_c = async {
646        tokio::signal::ctrl_c()
647            .await
648            .expect("failed to install Ctrl+C handler");
649    };
650
651    #[cfg(unix)]
652    let terminate = async {
653        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
654            .expect("failed to install signal handler")
655            .recv()
656            .await;
657    };
658
659    #[cfg(not(unix))]
660    let terminate = std::future::pending::<()>();
661
662    #[cfg(target_family = "windows")]
663    if let Some(rx_inner) = &mut rx {
664        let terminate_rx = rx_inner.recv();
665
666        tokio::select! {
667            () = ctrl_c => {},
668            () = terminate => {},
669            Some(()) = terminate_rx => {},
670        }
671    } else {
672        tokio::select! {
673            () = ctrl_c => {},
674            () = terminate => {},
675        }
676    }
677
678    #[cfg(not(target_family = "windows"))]
679    tokio::select! {
680        () = ctrl_c => {},
681        () = terminate => {},
682    }
683}