malwaredb_server/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8
9/// Cryptographic functionality for file storage
10pub mod crypto;
11
12/// Database I/O
13pub mod db;
14
15/// HTTP Server
16pub mod http;
17
18/// Entropy functions
19pub mod utils;
20
21/// Virus Total integration
22#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
23#[cfg(feature = "vt")]
24pub mod vt;
25
26use crate::crypto::FileEncryption;
27use crate::db::MDBConfig;
28use malwaredb_api::ServerInfo;
29//use utils::HashPath;
30
31use std::collections::HashMap;
32use std::fmt::{Debug, Formatter};
33use std::io::{Cursor, Read};
34use std::net::{IpAddr, Ipv4Addr, SocketAddr};
35use std::path::PathBuf;
36use std::sync::{Arc, LazyLock};
37use std::time::{Duration, SystemTime};
38
39use anyhow::{anyhow, bail, ensure, Context, Result};
40use axum_server::tls_rustls::RustlsConfig;
41use chrono::Local;
42use chrono_humanize::{Accuracy, HumanTime, Tense};
43use flate2::read::GzDecoder;
44use mdns_sd::{ServiceDaemon, ServiceInfo};
45use sha2::{Digest, Sha256};
46use tokio::net::TcpListener;
47use tracing::{trace, warn};
48
49/// MDB version
50pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
51
52/// MDB version as a semantic version object
53pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
54    LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
55
56/// How often stale pagination searches should be cleaned up: 1 day
57/// Could be used if future cleanup operations are needed.
58pub(crate) const DB_CLEANUP_INTERVAL: Duration = Duration::from_secs(60 * 60 * 24);
59
60/// Gzip's magic number to see if a file is compressed
61pub const GZIP_MAGIC: [u8; 2] = [0x1fu8, 0x8bu8];
62
63/// Zstd magic number to see if a file is compressed
64pub const ZSTD_MAGIC: [u8; 4] = [0x28u8, 0xb5u8, 0x2fu8, 0xfdu8];
65
/// Builder for server configuration.
///
/// Created with [`StateBuilder::new`], adjusted with the chained setter methods, and finally
/// converted into a [`State`] with [`StateBuilder::into_state`].
pub struct StateBuilder {
    /// The port which will be used to listen for connections.
    pub port: u16,

    /// The directory to store malware samples if we're keeping them.
    pub directory: Option<PathBuf>,

    /// Maximum upload size, in bytes
    pub max_upload: usize,

    /// The IP to use for listening for connections
    pub ip: IpAddr,

    /// Handle to the database connection
    db_type: db::DatabaseType,

    /// Virus Total client, if Virus Total integration has been configured
    #[cfg(feature = "vt")]
    vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// TLS configuration constructed from certificate and private key files
    tls_config: Option<RustlsConfig>,

    /// If Malware DB should be advertised via Multicast DNS (also known as Bonjour or Zeroconf)
    mdns: bool,
}
93
94impl StateBuilder {
95    /// Create the builder starting with the database configuration, and optionally, the
96    /// certificate for communicating with Postgres.
97    ///
98    /// # Errors
99    ///
100    /// An error occurs if the database configuration isn't valid or if an error occurs connecting
101    /// to the database.
102    pub async fn new(db_string: &str, pg_cert: Option<PathBuf>) -> Result<Self> {
103        let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
104
105        Ok(Self {
106            port: 8080,
107            directory: None,
108            max_upload: 104_857_600, /* 100 MiB */
109            ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
110            db_type,
111            #[cfg(feature = "vt")]
112            vt_client: None,
113            tls_config: None,
114            mdns: false,
115        })
116    }
117
118    /// Specify the port to listen on.
119    /// Default: `8080`
120    #[must_use]
121    pub fn port(mut self, port: u16) -> Self {
122        self.port = port;
123        self
124    }
125
126    /// Specify the directory to store malware samples if we're keeping them.
127    ///
128    /// Default: No directory, no file saving
129    #[must_use]
130    pub fn directory(mut self, directory: PathBuf) -> Self {
131        self.directory = Some(directory);
132        self
133    }
134
135    /// Specify the maximum upload size in bytes.
136    /// Default is 100 MiB.
137    #[must_use]
138    pub fn max_upload(mut self, max_upload: usize) -> Self {
139        self.max_upload = max_upload;
140        self
141    }
142
143    /// Indicate the IP address the server will list on.
144    /// Default: 127.0.0.1
145    #[must_use]
146    pub fn ip(mut self, ip: IpAddr) -> Self {
147        self.ip = ip;
148        self
149    }
150
151    /// Provide the Virus Total API key.
152    #[must_use]
153    #[cfg(feature = "vt")]
154    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
155    pub fn vt_client(mut self, vt_client: malwaredb_virustotal::VirusTotalClient) -> Self {
156        self.vt_client = Some(vt_client);
157        self
158    }
159
160    /// Provide the certificate and private key for TLS mode.
161    /// Files must match: both as PEM or both as DER.
162    ///
163    /// # Errors
164    ///
165    /// An error results if either file doesn't exist, not in the same format, or cannot be parsed.
166    pub async fn tls(mut self, cert_file: PathBuf, key_file: PathBuf) -> Result<Self> {
167        ensure!(
168            cert_file.exists(),
169            "Certificate file {} does not exist!",
170            cert_file.display()
171        );
172
173        ensure!(
174            key_file.exists(),
175            "Key file {} does not exist!",
176            key_file.display()
177        );
178
179        let cert_ext_str = cert_file
180            .extension()
181            .context("failed to get certificate extension")?;
182        let key_ext_str = key_file
183            .extension()
184            .context("failed to get key extension")?;
185
186        // Unnecessary for running MalwareDB, but some unit tests fail without this check.
187        if rustls::crypto::CryptoProvider::get_default().is_none() {
188            rustls::crypto::aws_lc_rs::default_provider()
189                .install_default()
190                .map_err(|_| anyhow!("failed to install AWS-LC crypto provider"))?;
191        }
192
193        let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem" {
194            RustlsConfig::from_pem_file(cert_file, key_file)
195                .await
196                .context("failed to load or parse certificate and key pem files")?
197        } else if cert_ext_str == "der" && key_ext_str == "der" {
198            let cert_contents =
199                std::fs::read(cert_file).context("failed to read certificate file")?;
200            let key_contents =
201                std::fs::read(key_file).context("failed to read private key file")?;
202            RustlsConfig::from_der(vec![cert_contents], key_contents)
203                .await
204                .context("failed to parse certificate and key der files")?
205        } else {
206            bail!(
207                "Unknown or unmatched certificate and key file extensions {} and {}",
208                cert_ext_str.display(),
209                key_ext_str.display()
210            );
211        };
212
213        self.tls_config = Some(config);
214        Ok(self)
215    }
216
217    /// Indicate that Malware DB should advertise itself via multicast DNS.
218    /// Default is false.
219    #[must_use]
220    pub fn enable_mdns(mut self) -> Self {
221        self.mdns = true;
222        self
223    }
224
225    /// Generate the state object.
226    ///
227    /// # Errors
228    ///
229    /// An error occurs if the database can't be reached.
230    pub async fn into_state(self) -> Result<State> {
231        let db_config = self.db_type.get_config().await?;
232        let keys = self.db_type.get_encryption_keys().await?;
233
234        Ok(State {
235            port: self.port,
236            directory: self.directory,
237            max_upload: self.max_upload,
238            ip: self.ip,
239            db_type: Arc::new(self.db_type),
240            started: SystemTime::now(),
241            db_config,
242            keys,
243            #[cfg(feature = "vt")]
244            vt_client: self.vt_client,
245            tls_config: self.tls_config,
246            mdns: if self.mdns {
247                Some(ServiceDaemon::new()?)
248            } else {
249                None
250            },
251        })
252    }
253}
254
/// State & configuration of the running server instance.
///
/// Built by [`StateBuilder::into_state`] and shared (behind an `Arc`) with the HTTP handlers
/// once [`State::serve`] is called.
pub struct State {
    /// The port which will be used to listen for connections.
    pub port: u16,

    /// The directory to store malware samples if we're keeping them.
    pub directory: Option<PathBuf>,

    /// Maximum upload size, in bytes
    pub max_upload: usize,

    /// The IP to use for listening for connections
    pub ip: IpAddr,

    /// Handle to the database connection
    pub db_type: Arc<db::DatabaseType>,

    /// Time this state was created, used as the server start time when reporting uptime
    pub started: SystemTime,

    /// Configuration which is stored in the database
    pub db_config: MDBConfig,

    /// File encryption keys indexed by key ID; empty when file encryption is not in use
    pub(crate) keys: HashMap<u32, FileEncryption>,

    /// Virus Total client, if Virus Total integration has been configured
    #[cfg(feature = "vt")]
    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// TLS configuration constructed from certificate and private key files
    tls_config: Option<RustlsConfig>,

    /// Multicast DNS daemon, present only if Malware DB should be advertised via
    /// Multicast DNS (also known as Bonjour or Zeroconf)
    mdns: Option<ServiceDaemon>,
}
291
292impl State {
293    /// Store the sample with a depth of three based on the sample's SHA-256 hash, even if compressed
294    ///
295    /// # Errors
296    ///
297    /// * If the file can't be written.
298    /// * If a necessary sub-directory can't be created.
299    pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
300        if let Some(dest_path) = &self.directory {
301            let mut hasher = Sha256::new();
302            hasher.update(data);
303            let sha256 = hex::encode(hasher.finalize());
304
305            // Trait `HashPath` needs to be re-worked so it can work with Strings.
306            // This code below ends up making the String into ASCII representations of the hash
307            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
308            let hashed_path = format!(
309                "{}/{}/{}/{}",
310                &sha256[0..2],
311                &sha256[2..4],
312                &sha256[4..6],
313                sha256
314            );
315
316            // The path which has the file name included, with the storage directory prepended.
317            //let hashed_path = result.hashed_path(3);
318            let mut dest_path = dest_path.clone();
319            dest_path.push(hashed_path);
320
321            // Remove the file name so we can just have the directory path.
322            let mut just_the_dir = dest_path.clone();
323            just_the_dir.pop();
324            std::fs::create_dir_all(just_the_dir)?;
325
326            let data = if self.db_config.compression {
327                let buff = Cursor::new(data);
328                let mut compressed = Vec::with_capacity(data.len() / 2);
329                zstd::stream::copy_encode(buff, &mut compressed, 4)?;
330                compressed
331            } else {
332                data.to_vec()
333            };
334
335            let data = if let Some(key_id) = self.db_config.default_key {
336                if let Some(key) = self.keys.get(&key_id) {
337                    let nonce = key.nonce();
338                    self.db_type
339                        .set_file_nonce(&sha256, nonce.as_deref())
340                        .await?;
341                    key.encrypt(&data, nonce)?
342                } else {
343                    bail!("Key not available!")
344                }
345            } else {
346                data
347            };
348
349            std::fs::write(dest_path, data)?;
350
351            Ok(true)
352        } else {
353            Ok(false)
354        }
355    }
356
357    /// Retrieve a sample given the SHA-256 hash
358    /// Assumes that `MalwareDB` permissions have already been checked to ensure this is permitted.
359    ///
360    /// # Errors
361    ///
362    /// * The file could not be read, maybe because it doesn't exist.
363    /// * Failure to decrypt or decompress (corruption).
364    pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
365        if let Some(dest_path) = &self.directory {
366            let path = format!(
367                "{}/{}/{}/{}",
368                &sha256[0..2],
369                &sha256[2..4],
370                &sha256[4..6],
371                sha256
372            );
373            // Trait `HashPath` needs to be re-worked so it can work with Strings.
374            // This code below ends up making the String into ASCII representations of the hash
375            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
376            //let path = sha256.as_bytes().iter().hashed_path(3);
377            let contents = std::fs::read(dest_path.join(path))?;
378
379            let contents = if self.keys.is_empty() {
380                // We don't have file encryption enabled
381                contents
382            } else {
383                let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
384                if let Some(key_id) = key_id {
385                    if let Some(key) = self.keys.get(&key_id) {
386                        key.decrypt(&contents, nonce)?
387                    } else {
388                        bail!("File was encrypted but we don't have tke key!")
389                    }
390                } else {
391                    // File was not encrypted
392                    contents
393                }
394            };
395
396            if contents.starts_with(&GZIP_MAGIC) {
397                let buff = Cursor::new(contents);
398                let mut decompressor = GzDecoder::new(buff);
399                let mut decompressed: Vec<u8> = vec![];
400                decompressor.read_to_end(&mut decompressed)?;
401                Ok(decompressed)
402            } else if contents.starts_with(&ZSTD_MAGIC) {
403                let buff = Cursor::new(contents);
404                let mut decompressed: Vec<u8> = vec![];
405                zstd::stream::copy_decode(buff, &mut decompressed)?;
406                Ok(decompressed)
407            } else {
408                Ok(contents)
409            }
410        } else {
411            bail!("files are not saved")
412        }
413    }
414
415    /// Get the duration for which the server has been running
416    ///
417    /// # Panics
418    ///
419    /// Despite the `unwrap()` this function will not panic as the data used is guaranteed to be valid.
420    #[must_use]
421    pub fn since(&self) -> Duration {
422        let now = SystemTime::now();
423        now.duration_since(self.started).unwrap()
424    }
425
426    /// Get server information
427    ///
428    /// # Errors
429    ///
430    /// An error would occur if the Postgres server could not be reached.
431    pub async fn get_info(&self) -> Result<ServerInfo> {
432        let db_info = self.db_type.db_info().await?;
433        let uptime = Local::now() - self.since();
434        let mem_size = app_memory_usage_fetcher::get_memory_usage_string().unwrap_or_default();
435
436        Ok(ServerInfo {
437            os_name: std::env::consts::OS.into(),
438            memory_used: mem_size,
439            num_samples: db_info.num_files,
440            num_users: db_info.num_users,
441            uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
442            mdb_version: MDB_VERSION_SEMVER.clone(),
443            db_version: db_info.version,
444            db_size: db_info.size,
445            instance_name: self.db_config.name.clone(),
446        })
447    }
448
449    /// The server listens and responds to requests. Does not return unless there's an error.
450    ///
451    /// # Errors
452    ///
453    /// * If the certificate and private key could not be parsed or are not valid.
454    /// * If the IP address and port are already in use.
455    /// * If the service doesn't have permission to open the port.
456    pub async fn serve(
457        self,
458        #[cfg(target_family = "windows")] rx: Option<tokio::sync::mpsc::Receiver<()>>,
459    ) -> Result<()> {
460        let socket = SocketAddr::new(self.ip, self.port);
461        let arc_self = Arc::new(self);
462        let db_info = arc_self.db_type.clone();
463
464        tokio::spawn(async move {
465            loop {
466                match db_info.cleanup().await {
467                    Ok(removed) => {
468                        trace!("Pagination cleanup succeeded, {removed} searches removed");
469                    }
470                    Err(e) => warn!("Pagination cleanup failed: {e}"),
471                }
472
473                tokio::time::sleep(DB_CLEANUP_INTERVAL).await;
474            }
475        });
476
477        if arc_self.mdns.is_some() && !arc_self.ip.is_loopback() {
478            if let Err(e) = arc_self.mdns_register().await {
479                warn!("Failed to register MDNS service: {e}");
480            }
481        }
482
483        if let Some(tls_config) = arc_self.tls_config.clone() {
484            println!("Listening on https://{socket:?}");
485            let handle = axum_server::Handle::<SocketAddr>::new();
486            let server_future = axum_server::bind_rustls(socket, tls_config)
487                .serve(http::app(arc_self).into_make_service());
488            tokio::select! {
489                () = shutdown_signal(#[cfg(target_family = "windows")]rx) =>
490                    handle.graceful_shutdown(Some(Duration::from_secs(30))),
491                res = server_future => res?,
492            }
493            warn!("Terminate signal received");
494        } else {
495            println!("Listening on http://{socket:?}");
496            let listener = TcpListener::bind(socket)
497                .await
498                .context(format!("failed to bind socket {socket}"))?;
499            axum::serve(listener, http::app(arc_self).into_make_service())
500                .with_graceful_shutdown(shutdown_signal(
501                    #[cfg(target_family = "windows")]
502                    rx,
503                ))
504                .await?;
505            warn!("Terminate signal received");
506        }
507        Ok(())
508    }
509
510    /// mdns registration function as a separate function so in the future data can be updated
511    /// without a restart
512    async fn mdns_register(&self) -> Result<()> {
513        if let Some(mdns) = &self.mdns {
514            let db_config = self.db_type.get_config().await?;
515            let host_name = format!("{}.local.", self.ip);
516            let ssl = self.tls_config.is_some();
517            let properties = [("ssl", ssl.to_string())];
518            let service = ServiceInfo::new(
519                malwaredb_api::MDNS_NAME,
520                &db_config.name,
521                &host_name,
522                self.ip.to_string(),
523                self.port,
524                &properties[..],
525            )?;
526            mdns.register(service)?;
527        }
528
529        Ok(())
530    }
531}
532
533impl Debug for State {
534    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
535        let tls_mode = if self.tls_config.is_some() {
536            ", TLS mode"
537        } else {
538            ""
539        };
540        write!(
541            f,
542            "MDB state, port {}, database {:?}{tls_mode}",
543            self.port, self.db_type
544        )
545    }
546}
547
548/// Enable graceful shutdown
549/// <https://github.com/tokio-rs/axum/discussions/1500>
550async fn shutdown_signal(
551    #[cfg(target_family = "windows")] mut rx: Option<tokio::sync::mpsc::Receiver<()>>,
552) {
553    let ctrl_c = async {
554        tokio::signal::ctrl_c()
555            .await
556            .expect("failed to install Ctrl+C handler");
557    };
558
559    #[cfg(unix)]
560    let terminate = async {
561        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
562            .expect("failed to install signal handler")
563            .recv()
564            .await;
565    };
566
567    #[cfg(not(unix))]
568    let terminate = std::future::pending::<()>();
569
570    #[cfg(target_family = "windows")]
571    if let Some(rx_inner) = &mut rx {
572        let terminate_rx = rx_inner.recv();
573
574        tokio::select! {
575            () = ctrl_c => {},
576            () = terminate => {},
577            Some(()) = terminate_rx => {},
578        }
579    } else {
580        tokio::select! {
581            () = ctrl_c => {},
582            () = terminate => {},
583        }
584    }
585
586    #[cfg(not(target_family = "windows"))]
587    tokio::select! {
588        () = ctrl_c => {},
589        () = terminate => {},
590    }
591}