malwaredb_server/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8
9/// Cryptographic functionality for file storage
10pub mod crypto;
11
12/// Database I/O
13pub mod db;
14
15/// HTTP Server
16pub mod http;
17
18/// Entropy functions
19pub mod utils;
20
21/// Virus Total integration
22#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
23#[cfg(feature = "vt")]
24pub mod vt;
25
26use crate::crypto::FileEncryption;
27use crate::db::MDBConfig;
28use malwaredb_api::ServerInfo;
29//use utils::HashPath;
30
31use std::collections::HashMap;
32use std::fmt::{Debug, Formatter};
33use std::io::{Cursor, Read};
34use std::net::{IpAddr, SocketAddr};
35use std::path::PathBuf;
36use std::sync::{Arc, LazyLock};
37use std::time::{Duration, SystemTime};
38
39use anyhow::{bail, ensure, Context, Result};
40use axum_server::tls_rustls::RustlsConfig;
41use chrono::Local;
42use chrono_humanize::{Accuracy, HumanTime, Tense};
43use flate2::read::GzDecoder;
44use mdns_sd::{ServiceDaemon, ServiceInfo};
45use sha2::{Digest, Sha256};
46use tokio::net::TcpListener;
47use tracing::{debug, error, trace, warn};
48
49/// MDB version
50pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
51
52/// MDB version as a semantic version object
53pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
54    LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
55
/// Interval between cleanups of stale pagination searches (one day, in seconds).
/// Other periodic database maintenance could reuse this interval in the future.
pub(crate) const DB_CLEANUP_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);

/// Leading bytes of a gzip stream, used to detect gzip-compressed files
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];

/// Leading bytes of a zstd frame, used to detect zstd-compressed files
pub const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];
65
66/// State & configuration of the running server instance
67pub struct State {
68    /// The port which will be used to listen for connections.
69    pub port: u16,
70
71    /// The directory to store malware samples if we're keeping them.
72    pub directory: Option<PathBuf>,
73
74    /// Maximum upload size
75    pub max_upload: usize,
76
77    /// The IP to use for listening for connections
78    pub ip: IpAddr,
79
80    /// Handle to the database connection
81    pub db_type: Arc<db::DatabaseType>,
82
83    /// Start time of the server
84    pub started: SystemTime,
85
86    /// Configuration which is stored in the database
87    pub db_config: MDBConfig,
88
89    /// File encryption keys, may be empty
90    pub(crate) keys: HashMap<u32, FileEncryption>,
91
92    /// Virus Total API key
93    #[cfg(feature = "vt")]
94    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
95
96    /// Https certificate file, optionally containing the CA cert as well (server and CA certs
97    /// concatenated)
98    cert: Option<PathBuf>,
99
100    /// Https private key file
101    key: Option<PathBuf>,
102
103    /// If Malware DB should be advertised via Multicast DNS (also known as Bonjour or Zeroconf)
104    mdns: bool,
105}
106
107impl State {
108    /// New server state object given a few configuration parameters
109    ///
110    /// # Errors
111    /// * If there's a certificate and not a key, or a key and not a certificate; if either don't exist.
112    /// * If the sample storage directory doesn't exist.
113    /// * if the database configuration isn't valid or if the Postgres server can't be reached.
114    #[allow(clippy::too_many_arguments)]
115    pub async fn new(
116        port: u16,
117        directory: Option<PathBuf>,
118        max_upload: usize,
119        ip: IpAddr,
120        db_string: &str,
121        cert: Option<PathBuf>,
122        key: Option<PathBuf>,
123        pg_cert: Option<PathBuf>,
124        mdns: bool,
125        #[cfg(feature = "vt")] vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
126    ) -> Result<Self> {
127        if let Some(dir) = &directory {
128            if !dir.exists() {
129                bail!("data directory {} does not exist!", dir.display());
130            }
131        }
132
133        if cert.is_some() != key.is_some() {
134            bail!("Either both the https cert & key are provided, or none are provided.");
135        }
136
137        if let Some(cert_file) = &cert {
138            ensure!(
139                cert_file.exists(),
140                "Certificate file {} does not exist!",
141                cert_file.display()
142            );
143        }
144
145        if let Some(key_file) = &key {
146            ensure!(
147                key_file.exists(),
148                "Key file {} does not exist!",
149                key_file.display()
150            );
151        }
152
153        let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
154        // TODO: allow config and keys to be refreshed so changes don't require a restart
155        let db_config = db_type.get_config().await?;
156        let keys = db_type.get_encryption_keys().await?;
157
158        Ok(Self {
159            port,
160            directory,
161            max_upload,
162            ip,
163            db_type: Arc::new(db_type),
164            db_config,
165            keys,
166            #[cfg(feature = "vt")]
167            vt_client,
168            started: SystemTime::now(),
169            cert,
170            key,
171            mdns,
172        })
173    }
174
175    /// Store the sample with a depth of three based on the sample's SHA-256 hash, even if compressed
176    ///
177    /// # Errors
178    ///
179    /// * If the file can't be written.
180    /// * If a necessary sub-directory can't be created.
181    pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
182        if let Some(dest_path) = &self.directory {
183            let mut hasher = Sha256::new();
184            hasher.update(data);
185            let sha256 = hex::encode(hasher.finalize());
186
187            // Trait `HashPath` needs to be re-worked so it can work with Strings.
188            // This code below ends up making the String into ASCII representations of the hash
189            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
190            let hashed_path = format!(
191                "{}/{}/{}/{}",
192                &sha256[0..2],
193                &sha256[2..4],
194                &sha256[4..6],
195                sha256
196            );
197
198            // The path which has the file name included, with the storage directory prepended.
199            //let hashed_path = result.hashed_path(3);
200            let mut dest_path = dest_path.clone();
201            dest_path.push(hashed_path);
202
203            // Remove the file name so we can just have the directory path.
204            let mut just_the_dir = dest_path.clone();
205            just_the_dir.pop();
206            std::fs::create_dir_all(just_the_dir)?;
207
208            let data = if self.db_config.compression {
209                let buff = Cursor::new(data);
210                let mut compressed = Vec::with_capacity(data.len() / 2);
211                zstd::stream::copy_encode(buff, &mut compressed, 4)?;
212                compressed
213            } else {
214                data.to_vec()
215            };
216
217            let data = if let Some(key_id) = self.db_config.default_key {
218                if let Some(key) = self.keys.get(&key_id) {
219                    let nonce = key.nonce();
220                    self.db_type
221                        .set_file_nonce(&sha256, nonce.as_deref())
222                        .await?;
223                    key.encrypt(&data, nonce)?
224                } else {
225                    bail!("Key not available!")
226                }
227            } else {
228                data
229            };
230
231            std::fs::write(dest_path, data)?;
232
233            Ok(true)
234        } else {
235            Ok(false)
236        }
237    }
238
239    /// Retrieve a sample given the SHA-256 hash
240    /// Assumes that `MalwareDB` permissions have already been checked to ensure this is permitted.
241    ///
242    /// # Errors
243    ///
244    /// * The file could not be read, maybe because it doesn't exist.
245    /// * Failure to decrypt or decompress (corruption).
246    pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
247        if let Some(dest_path) = &self.directory {
248            let path = format!(
249                "{}/{}/{}/{}",
250                &sha256[0..2],
251                &sha256[2..4],
252                &sha256[4..6],
253                sha256
254            );
255            // Trait `HashPath` needs to be re-worked so it can work with Strings.
256            // This code below ends up making the String into ASCII representations of the hash
257            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
258            //let path = sha256.as_bytes().iter().hashed_path(3);
259            let contents = std::fs::read(dest_path.join(path))?;
260
261            let contents = if self.keys.is_empty() {
262                // We don't have file encryption enabled
263                contents
264            } else {
265                let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
266                if let Some(key_id) = key_id {
267                    if let Some(key) = self.keys.get(&key_id) {
268                        key.decrypt(&contents, nonce)?
269                    } else {
270                        bail!("File was encrypted but we don't have tke key!")
271                    }
272                } else {
273                    // File was not encrypted
274                    contents
275                }
276            };
277
278            if contents.starts_with(&GZIP_MAGIC) {
279                let buff = Cursor::new(contents);
280                let mut decompressor = GzDecoder::new(buff);
281                let mut decompressed: Vec<u8> = vec![];
282                decompressor.read_to_end(&mut decompressed)?;
283                Ok(decompressed)
284            } else if contents.starts_with(&ZSTD_MAGIC) {
285                let buff = Cursor::new(contents);
286                let mut decompressed: Vec<u8> = vec![];
287                zstd::stream::copy_decode(buff, &mut decompressed)?;
288                Ok(decompressed)
289            } else {
290                Ok(contents)
291            }
292        } else {
293            bail!("files are not saved")
294        }
295    }
296
297    /// Get the duration for which the server has been running
298    ///
299    /// # Panics
300    ///
301    /// Despite the `unwrap()` this function will not panic as the data used is guaranteed to be valid.
302    #[must_use]
303    pub fn since(&self) -> Duration {
304        let now = SystemTime::now();
305        now.duration_since(self.started).unwrap()
306    }
307
308    /// Get server information
309    ///
310    /// # Errors
311    ///
312    /// An error would occur if the Postgres server could not be reached.
313    pub async fn get_info(&self) -> Result<ServerInfo> {
314        let db_info = self.db_type.db_info().await?;
315        let uptime = Local::now() - self.since();
316        let mem_size = if let Some(mem_size) = app_memory_usage_fetcher::get_memory_usage_bytes() {
317            humansize::SizeFormatter::new(mem_size.get(), humansize::BINARY).to_string()
318        } else {
319            String::new()
320        };
321
322        Ok(ServerInfo {
323            os_name: std::env::consts::OS.into(),
324            memory_used: mem_size,
325            num_samples: db_info.num_files,
326            num_users: db_info.num_users,
327            uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
328            mdb_version: MDB_VERSION_SEMVER.clone(),
329            db_version: db_info.version,
330            db_size: db_info.size,
331            instance_name: self.db_config.name.clone(),
332        })
333    }
334
335    /// The server listens and responds to requests. Does not return unless there's an error.
336    ///
337    /// # Errors
338    ///
339    /// * If the certificate and private key could not be parsed or are not valid.
340    /// * If the IP address and port are already in use.
341    /// * If the service doesn't have permission to open the port.
342    ///
343    /// # Panics
344    ///
345    /// * The `.unwrap()` calls won't panic because the data would have already been validated.
346    pub async fn serve(self) -> Result<()> {
347        let socket = SocketAddr::new(self.ip, self.port);
348        let arc_self = Arc::new(self);
349        let db_ifo = arc_self.db_type.clone();
350
351        tokio::spawn(async move {
352            loop {
353                match db_ifo.cleanup().await {
354                    Ok(removed) => {
355                        trace!("Pagination cleanup succeeded, {removed} searches removed");
356                    }
357                    Err(e) => warn!("Pagination cleanup failed: {e}"),
358                }
359
360                tokio::time::sleep(DB_CLEANUP_INTERVAL).await;
361            }
362        });
363
364        if arc_self.mdns {
365            if arc_self.ip.is_loopback() {
366                debug!("Refusing to start mdns responder for localhost");
367            } else {
368                trace!("Enabling MDNS advertising");
369                match ServiceDaemon::new() {
370                    Ok(mdns) => {
371                        let host_name = format!("{}.local.", arc_self.ip);
372                        let ssl = arc_self.cert.is_some() && arc_self.key.is_some();
373                        let properties = [("ssl", ssl.to_string())];
374                        match ServiceInfo::new(
375                            malwaredb_api::MDNS_NAME,
376                            &arc_self.db_config.name,
377                            &host_name,
378                            arc_self.ip.to_string(),
379                            arc_self.port,
380                            &properties[..],
381                        ) {
382                            Ok(service) => match mdns.register(service) {
383                                Ok(()) => trace!("MalwareDB mdns registered"),
384                                Err(e) => error!("Failed to register service: {e}"),
385                            },
386                            Err(e) => error!("Failed to publish MDNS service: {e}"),
387                        }
388                    }
389                    Err(e) => error!("Failed to open port for MDNS responder: {e}"),
390                }
391            }
392        }
393
394        if arc_self.cert.is_some() && arc_self.key.is_some() {
395            let cert_path = arc_self.cert.as_ref().unwrap();
396            let key_path = arc_self.key.as_ref().unwrap();
397            let cert_ext_str = cert_path
398                .extension()
399                .context("failed to get certificate extension")?;
400            let key_ext_str = key_path
401                .extension()
402                .context("failed to get key extension")?;
403
404            // Unnecessary for running MalwareDB, but some unit tests fail without this check.
405            if rustls::crypto::CryptoProvider::get_default().is_none() {
406                rustls::crypto::aws_lc_rs::default_provider()
407                    .install_default()
408                    .expect("Failed to load crypto provider");
409            }
410
411            let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem"
412            {
413                RustlsConfig::from_pem_file(cert_path, key_path)
414                    .await
415                    .context("failed to load or parse certificate and key files")?
416            } else if cert_ext_str == "der" && key_ext_str == "der" {
417                let cert_contents =
418                    std::fs::read(cert_path).context("failed to read certificate file")?;
419                let key_contents =
420                    std::fs::read(key_path).context("failed to read certificate file")?;
421                RustlsConfig::from_der(vec![cert_contents], key_contents)
422                    .await
423                    .context("failed to parse certificate and key files as DER")?
424            } else {
425                bail!("Unknown or unmatched certificate and key file extensions {cert_ext_str:?} and {key_ext_str:?}");
426            };
427
428            println!("Listening on https://{socket:?}");
429            axum_server::bind_rustls(socket, config)
430                .serve(http::app(arc_self).into_make_service())
431                .await?;
432        } else {
433            println!("Listening on http://{socket:?}");
434            let listener = TcpListener::bind(socket)
435                .await
436                .context(format!("failed to bind socket {socket}"))?;
437            axum::serve(listener, http::app(arc_self).into_make_service()).await?;
438        }
439        Ok(())
440    }
441}
442
443impl Debug for State {
444    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
445        write!(
446            f,
447            "MDB state, port {}, database {:?}",
448            self.port, self.db_type
449        )
450    }
451}