malwaredb_server/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8
9/// Cryptographic functionality for file storage
10pub mod crypto;
11
12/// Database I/O
13pub mod db;
14
15/// HTTP Server
16pub mod http;
17
18/// Entropy functions
19pub mod utils;
20
21/// Virus Total integration
22#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
23#[cfg(feature = "vt")]
24pub mod vt;
25
26use crate::crypto::FileEncryption;
27use crate::db::MDBConfig;
28use malwaredb_api::ServerInfo;
29//use utils::HashPath;
30
31use std::collections::HashMap;
32use std::fmt::{Debug, Formatter};
33use std::io::{Cursor, Read};
34use std::net::{IpAddr, SocketAddr};
35use std::path::PathBuf;
36use std::sync::{Arc, LazyLock};
37use std::time::{Duration, SystemTime};
38
39use anyhow::{bail, ensure, Context, Result};
40use axum_server::tls_rustls::RustlsConfig;
41use chrono::Local;
42use chrono_humanize::{Accuracy, HumanTime, Tense};
43use flate2::read::GzDecoder;
44use mdns_sd::{ServiceDaemon, ServiceInfo};
45use sha2::{Digest, Sha256};
46use tokio::net::TcpListener;
47use tracing::{debug, error, trace, warn};
48
49/// MDB version
50pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
51
52/// MDB version as a semantic version object
53pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
54    LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
55
/// Pause between two clean-ups of stale pagination searches: one day.
/// Any future periodic database clean-up work could share this interval.
pub(crate) const DB_CLEANUP_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
59
/// Leading bytes of a Gzip stream, used to detect Gzip-compressed files
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];

/// Leading bytes of a Zstandard frame, used to detect Zstd-compressed files
pub const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];
65
/// State & configuration of the running server instance
///
/// Built by `State::new()` and consumed by `State::serve()`.
pub struct State {
    /// The port which will be used to listen for connections.
    pub port: u16,

    /// The directory to store malware samples if we're keeping them.
    /// When `None`, `store_bytes()` stores nothing and `retrieve_bytes()` returns an error.
    pub directory: Option<PathBuf>,

    /// Maximum upload size
    // NOTE(review): the unit is not visible in this file, presumably bytes; confirm against
    // the HTTP layer.
    pub max_upload: usize,

    /// The IP to use for listening for connections
    pub ip: IpAddr,

    /// Handle to the database connection
    pub db_type: Arc<db::DatabaseType>,

    /// Start time of the server, recorded when the state object is created
    pub started: SystemTime,

    /// Configuration which is stored in the database.
    /// Loaded once when the state is created; changes currently require a restart.
    pub db_config: MDBConfig,

    /// File encryption keys indexed by key ID, may be empty
    pub(crate) keys: HashMap<u32, FileEncryption>,

    /// Virus Total client, if one was provided
    #[cfg(feature = "vt")]
    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// Https certificate file, optionally containing the CA cert as well (server and CA certs
    /// concatenated)
    cert: Option<PathBuf>,

    /// Https private key file
    key: Option<PathBuf>,

    /// If Malware DB should be advertised via Multicast DNS (also known as Bonjour or Zeroconf)
    mdns: bool,
}
106
107impl State {
108    /// New server state object given a few configuration parameters
109    ///
110    /// # Errors
111    /// * If there's a certificate and not a key, or a key and not a certificate; if either don't exist.
112    /// * If the sample storage directory doesn't exist.
113    /// * if the database configuration isn't valid or if the Postgres server can't be reached.
114    #[allow(clippy::too_many_arguments)]
115    pub async fn new(
116        port: u16,
117        directory: Option<PathBuf>,
118        max_upload: usize,
119        ip: IpAddr,
120        db_string: &str,
121        cert: Option<PathBuf>,
122        key: Option<PathBuf>,
123        pg_cert: Option<PathBuf>,
124        mdns: bool,
125        #[cfg(feature = "vt")] vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
126    ) -> Result<Self> {
127        if let Some(dir) = &directory {
128            if !dir.exists() {
129                bail!("data directory {} does not exist!", dir.display());
130            }
131        }
132
133        if cert.is_some() != key.is_some() {
134            bail!("Either both the https cert & key are provided, or none are provided.");
135        }
136
137        if let Some(cert_file) = &cert {
138            ensure!(
139                cert_file.exists(),
140                "Certificate file {} does not exist!",
141                cert_file.display()
142            );
143        }
144
145        if let Some(key_file) = &key {
146            ensure!(
147                key_file.exists(),
148                "Key file {} does not exist!",
149                key_file.display()
150            );
151        }
152
153        let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
154        // TODO: allow config and keys to be refreshed so changes don't require a restart
155        let db_config = db_type.get_config().await?;
156        let keys = db_type.get_encryption_keys().await?;
157
158        Ok(Self {
159            port,
160            directory,
161            max_upload,
162            ip,
163            db_type: Arc::new(db_type),
164            db_config,
165            keys,
166            #[cfg(feature = "vt")]
167            vt_client,
168            started: SystemTime::now(),
169            cert,
170            key,
171            mdns,
172        })
173    }
174
175    /// Store the sample with a depth of three based on the sample's SHA-256 hash, even if compressed
176    ///
177    /// # Errors
178    ///
179    /// * If the file can't be written.
180    /// * If a necessary sub-directory can't be created.
181    pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
182        if let Some(dest_path) = &self.directory {
183            let mut hasher = Sha256::new();
184            hasher.update(data);
185            let sha256 = hex::encode(hasher.finalize());
186
187            // Trait `HashPath` needs to be re-worked so it can work with Strings.
188            // This code below ends up making the String into ASCII representations of the hash
189            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
190            let hashed_path = format!(
191                "{}/{}/{}/{}",
192                &sha256[0..2],
193                &sha256[2..4],
194                &sha256[4..6],
195                sha256
196            );
197
198            // The path which has the file name included, with the storage directory prepended.
199            //let hashed_path = result.hashed_path(3);
200            let mut dest_path = dest_path.clone();
201            dest_path.push(hashed_path);
202
203            // Remove the file name so we can just have the directory path.
204            let mut just_the_dir = dest_path.clone();
205            just_the_dir.pop();
206            std::fs::create_dir_all(just_the_dir)?;
207
208            let data = if self.db_config.compression {
209                let buff = Cursor::new(data);
210                let mut compressed = Vec::with_capacity(data.len() / 2);
211                zstd::stream::copy_encode(buff, &mut compressed, 4)?;
212                compressed
213            } else {
214                data.to_vec()
215            };
216
217            let data = if let Some(key_id) = self.db_config.default_key {
218                if let Some(key) = self.keys.get(&key_id) {
219                    let nonce = key.nonce();
220                    self.db_type
221                        .set_file_nonce(&sha256, nonce.as_deref())
222                        .await?;
223                    key.encrypt(&data, nonce)?
224                } else {
225                    bail!("Key not available!")
226                }
227            } else {
228                data
229            };
230
231            std::fs::write(dest_path, data)?;
232
233            Ok(true)
234        } else {
235            Ok(false)
236        }
237    }
238
239    /// Retrieve a sample given the SHA-256 hash
240    /// Assumes that `MalwareDB` permissions have already been checked to ensure this is permitted.
241    ///
242    /// # Errors
243    ///
244    /// * The file could not be read, maybe because it doesn't exist.
245    /// * Failure to decrypt or decompress (corruption).
246    pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
247        if let Some(dest_path) = &self.directory {
248            let path = format!(
249                "{}/{}/{}/{}",
250                &sha256[0..2],
251                &sha256[2..4],
252                &sha256[4..6],
253                sha256
254            );
255            // Trait `HashPath` needs to be re-worked so it can work with Strings.
256            // This code below ends up making the String into ASCII representations of the hash
257            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
258            //let path = sha256.as_bytes().iter().hashed_path(3);
259            let contents = std::fs::read(dest_path.join(path))?;
260
261            let contents = if self.keys.is_empty() {
262                // We don't have file encryption enabled
263                contents
264            } else {
265                let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
266                if let Some(key_id) = key_id {
267                    if let Some(key) = self.keys.get(&key_id) {
268                        key.decrypt(&contents, nonce)?
269                    } else {
270                        bail!("File was encrypted but we don't have tke key!")
271                    }
272                } else {
273                    // File was not encrypted
274                    contents
275                }
276            };
277
278            if contents.starts_with(&GZIP_MAGIC) {
279                let buff = Cursor::new(contents);
280                let mut decompressor = GzDecoder::new(buff);
281                let mut decompressed: Vec<u8> = vec![];
282                decompressor.read_to_end(&mut decompressed)?;
283                Ok(decompressed)
284            } else if contents.starts_with(&ZSTD_MAGIC) {
285                let buff = Cursor::new(contents);
286                let mut decompressed: Vec<u8> = vec![];
287                zstd::stream::copy_decode(buff, &mut decompressed)?;
288                Ok(decompressed)
289            } else {
290                Ok(contents)
291            }
292        } else {
293            bail!("files are not saved")
294        }
295    }
296
297    /// Get the duration for which the server has been running
298    ///
299    /// # Panics
300    ///
301    /// Despite the `unwrap()` this function will not panic as the data used is guaranteed to be valid.
302    #[must_use]
303    pub fn since(&self) -> Duration {
304        let now = SystemTime::now();
305        now.duration_since(self.started).unwrap()
306    }
307
308    /// Get server information
309    ///
310    /// # Errors
311    ///
312    /// An error would occur if the Postgres server could not be reached.
313    pub async fn get_info(&self) -> Result<ServerInfo> {
314        let db_info = self.db_type.db_info().await?;
315        let uptime = Local::now() - self.since();
316        let mem_size = if let Some(mem_size) = app_memory_usage_fetcher::get_memory_usage_bytes() {
317            humansize::SizeFormatter::new(mem_size.get(), humansize::BINARY).to_string()
318        } else {
319            String::new()
320        };
321
322        Ok(ServerInfo {
323            os_name: std::env::consts::OS.into(),
324            memory_used: mem_size,
325            num_samples: db_info.num_files,
326            num_users: db_info.num_users,
327            uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
328            mdb_version: MDB_VERSION_SEMVER.clone(),
329            db_version: db_info.version,
330            db_size: db_info.size,
331            instance_name: self.db_config.name.clone(),
332        })
333    }
334
335    /// The server listens and responds to requests. Does not return unless there's an error.
336    ///
337    /// # Errors
338    ///
339    /// * If the certificate and private key could not be parsed or are not valid.
340    /// * If the IP address and port are already in use.
341    /// * If the service doesn't have permission to open the port.
342    ///
343    /// # Panics
344    ///
345    /// * The `.unwrap()` calls won't panic because the data would have already been validated.
346    pub async fn serve(
347        self,
348        #[cfg(target_family = "windows")] rx: Option<tokio::sync::mpsc::Receiver<()>>,
349    ) -> Result<()> {
350        let socket = SocketAddr::new(self.ip, self.port);
351        let arc_self = Arc::new(self);
352        let db_ifo = arc_self.db_type.clone();
353
354        tokio::spawn(async move {
355            loop {
356                match db_ifo.cleanup().await {
357                    Ok(removed) => {
358                        trace!("Pagination cleanup succeeded, {removed} searches removed");
359                    }
360                    Err(e) => warn!("Pagination cleanup failed: {e}"),
361                }
362
363                tokio::time::sleep(DB_CLEANUP_INTERVAL).await;
364            }
365        });
366
367        if arc_self.mdns {
368            if arc_self.ip.is_loopback() {
369                debug!("Refusing to start mdns responder for localhost");
370            } else {
371                trace!("Enabling MDNS advertising");
372                match ServiceDaemon::new() {
373                    Ok(mdns) => {
374                        let host_name = format!("{}.local.", arc_self.ip);
375                        let ssl = arc_self.cert.is_some() && arc_self.key.is_some();
376                        let properties = [("ssl", ssl.to_string())];
377                        match ServiceInfo::new(
378                            malwaredb_api::MDNS_NAME,
379                            &arc_self.db_config.name,
380                            &host_name,
381                            arc_self.ip.to_string(),
382                            arc_self.port,
383                            &properties[..],
384                        ) {
385                            Ok(service) => match mdns.register(service) {
386                                Ok(()) => trace!("MalwareDB mdns registered"),
387                                Err(e) => error!("Failed to register service: {e}"),
388                            },
389                            Err(e) => error!("Failed to publish MDNS service: {e}"),
390                        }
391                    }
392                    Err(e) => error!("Failed to open port for MDNS responder: {e}"),
393                }
394            }
395        }
396
397        if arc_self.cert.is_some() && arc_self.key.is_some() {
398            let cert_path = arc_self.cert.as_ref().unwrap();
399            let key_path = arc_self.key.as_ref().unwrap();
400            let cert_ext_str = cert_path
401                .extension()
402                .context("failed to get certificate extension")?;
403            let key_ext_str = key_path
404                .extension()
405                .context("failed to get key extension")?;
406
407            // Unnecessary for running MalwareDB, but some unit tests fail without this check.
408            if rustls::crypto::CryptoProvider::get_default().is_none() {
409                rustls::crypto::aws_lc_rs::default_provider()
410                    .install_default()
411                    .expect("Failed to load crypto provider");
412            }
413
414            let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem"
415            {
416                RustlsConfig::from_pem_file(cert_path, key_path)
417                    .await
418                    .context("failed to load or parse certificate and key files")?
419            } else if cert_ext_str == "der" && key_ext_str == "der" {
420                let cert_contents =
421                    std::fs::read(cert_path).context("failed to read certificate file")?;
422                let key_contents =
423                    std::fs::read(key_path).context("failed to read certificate file")?;
424                RustlsConfig::from_der(vec![cert_contents], key_contents)
425                    .await
426                    .context("failed to parse certificate and key files as DER")?
427            } else {
428                bail!("Unknown or unmatched certificate and key file extensions {cert_ext_str:?} and {key_ext_str:?}");
429            };
430
431            println!("Listening on https://{socket:?}");
432            let handle = axum_server::Handle::new();
433            let server_future = axum_server::bind_rustls(socket, config)
434                .serve(http::app(arc_self).into_make_service());
435            tokio::select! {
436                () = shutdown_signal(#[cfg(target_family = "windows")]rx) =>
437                    handle.graceful_shutdown(Some(Duration::from_secs(30))),
438                res = server_future => res?,
439            }
440            warn!("Terminate signal received");
441        } else {
442            println!("Listening on http://{socket:?}");
443            let listener = TcpListener::bind(socket)
444                .await
445                .context(format!("failed to bind socket {socket}"))?;
446            axum::serve(listener, http::app(arc_self).into_make_service())
447                .with_graceful_shutdown(shutdown_signal(
448                    #[cfg(target_family = "windows")]
449                    rx,
450                ))
451                .await?;
452            warn!("Terminate signal received");
453        }
454        Ok(())
455    }
456}
457
458impl Debug for State {
459    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
460        write!(
461            f,
462            "MDB state, port {}, database {:?}",
463            self.port, self.db_type
464        )
465    }
466}
467
468/// Enable graceful shutdown
469/// <https://github.com/tokio-rs/axum/discussions/1500>
470async fn shutdown_signal(
471    #[cfg(target_family = "windows")] mut rx: Option<tokio::sync::mpsc::Receiver<()>>,
472) {
473    let ctrl_c = async {
474        tokio::signal::ctrl_c()
475            .await
476            .expect("failed to install Ctrl+C handler");
477    };
478
479    #[cfg(unix)]
480    let terminate = async {
481        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
482            .expect("failed to install signal handler")
483            .recv()
484            .await;
485    };
486
487    #[cfg(not(unix))]
488    let terminate = std::future::pending::<()>();
489
490    #[cfg(target_family = "windows")]
491    if let Some(rx_inner) = &mut rx {
492        let terminate_rx = rx_inner.recv();
493
494        tokio::select! {
495            () = ctrl_c => {},
496            () = terminate => {},
497            Some(()) = terminate_rx => {},
498        }
499    } else {
500        tokio::select! {
501            () = ctrl_c => {},
502            () = terminate => {},
503        }
504    }
505
506    #[cfg(not(target_family = "windows"))]
507    tokio::select! {
508        () = ctrl_c => {},
509        () = terminate => {},
510    }
511}