malwaredb_server/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
#![deny(clippy::all)]
#![deny(clippy::pedantic)]

/// Cryptographic functionality for file storage (encryption of stored samples)
pub mod crypto;

/// Database I/O
pub mod db;

/// HTTP Server
pub mod http;

/// Entropy functions
// NOTE(review): described only as entropy helpers; confirm this still covers everything in `utils`.
pub mod utils;

/// Virus Total communication, only compiled when the `vt` feature is enabled
#[cfg(feature = "vt")]
pub mod vt;
23
24use crate::crypto::FileEncryption;
25use crate::db::MDBConfig;
26use malwaredb_api::ServerInfo;
27//use utils::HashPath;
28
29use std::collections::HashMap;
30use std::fmt::{Debug, Formatter};
31use std::io::{Cursor, Read, Write};
32use std::net::{IpAddr, SocketAddr};
33use std::path::PathBuf;
34use std::sync::Arc;
35use std::time::{Duration, SystemTime};
36
37use anyhow::{bail, ensure, Context, Result};
38use axum_server::tls_rustls::RustlsConfig;
39use chrono::Local;
40use chrono_humanize::{Accuracy, HumanTime, Tense};
41use flate2::read::GzDecoder;
42use flate2::write::GzEncoder;
43use flate2::Compression;
44use sha2::{Digest, Sha256};
45use tokio::net::TcpListener;
46
/// MDB version, taken from this crate's `Cargo.toml` package version at compile time
pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
49
/// The two leading bytes of every gzip stream (RFC 1952 `ID1`, `ID2`), checked to tell whether
/// stored sample data is compressed
pub const GZIP_MAGIC: [u8; 2] = *b"\x1f\x8b";
52
/// State & configuration of the running server instance.
///
/// Built by [`State::new`] and consumed by [`State::serve`], which hands it to the HTTP
/// application wrapped in an `Arc`.
pub struct State {
    /// The port which will be used to listen for connections.
    pub port: u16,

    /// The directory to store malware samples in, if samples are being kept.
    /// When `None`, samples are not written to disk.
    pub directory: Option<PathBuf>,

    /// Maximum upload size
    // NOTE(review): the unit (presumably bytes) is not visible in this file; confirm in `http`.
    pub max_upload: usize,

    /// The IP to use for listening for connections
    pub ip: IpAddr,

    /// Database backend, used for the stored configuration, encryption keys, per-file
    /// encryption metadata and server statistics
    pub db_type: db::DatabaseType,

    /// Time at which this state was created; the server's uptime is measured from it
    pub started: SystemTime,

    /// Configuration which is stored in the database (e.g. whether samples are compressed and
    /// which encryption key is the default), loaded once at start-up
    pub db_config: MDBConfig,

    /// File encryption keys indexed by key id; empty when file encryption isn't configured
    pub(crate) keys: HashMap<u32, FileEncryption>,

    /// Virus Total client, if one was configured
    #[cfg(feature = "vt")]
    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// Https certificate file, optionally containing the CA cert as well (server and CA certs
    /// concatenated)
    cert: Option<PathBuf>,

    /// Https private key file
    key: Option<PathBuf>,
}
90
91impl State {
92    /// New server state object given a few configuration parameters
93    ///
94    /// # Errors
95    /// * If there's a certificate and not a key, or a key and not a certificate; if either don't exist.
96    /// * If the sample storage directory doesn't exist.
97    /// * if the database configuration isn't valid or if the Postgres server can't be reached.
98    #[allow(clippy::too_many_arguments)]
99    pub async fn new(
100        port: u16,
101        directory: Option<PathBuf>,
102        max_upload: usize,
103        ip: IpAddr,
104        db_string: &str,
105        cert: Option<PathBuf>,
106        key: Option<PathBuf>,
107        pg_cert: Option<PathBuf>,
108        #[cfg(feature = "vt")] vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
109    ) -> Result<Self> {
110        if let Some(dir) = &directory {
111            if !dir.exists() {
112                bail!("data directory {dir:?} does not exist!");
113            }
114        }
115
116        if cert.is_some() != key.is_some() {
117            bail!("Either both the https cert & key are provided, or none are provided.");
118        }
119
120        if let Some(cert_file) = &cert {
121            ensure!(
122                cert_file.exists(),
123                "Certificate file {cert_file:?} does not exist!"
124            );
125        }
126
127        if let Some(key_file) = &key {
128            ensure!(key_file.exists(), "Key file {key_file:?} does not exist!");
129        }
130
131        let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
132        // TODO: allow config and keys to be refreshed so changes don't require a restart
133        let db_config = db_type.get_config().await?;
134        let keys = db_type.get_encryption_keys().await?;
135
136        Ok(Self {
137            port,
138            directory,
139            max_upload,
140            ip,
141            db_type,
142            db_config,
143            keys,
144            #[cfg(feature = "vt")]
145            vt_client,
146            started: SystemTime::now(),
147            cert,
148            key,
149        })
150    }
151
152    /// Store the sample with a depth of three based on the sample's SHA-256 hash, even if compressed
153    ///
154    /// # Errors
155    ///
156    /// * If the file can't be written.
157    /// * If a necessary sub-directory can't be created.
158    pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
159        if let Some(dest_path) = &self.directory {
160            let mut hasher = Sha256::new();
161            hasher.update(data);
162            let sha256 = hex::encode(hasher.finalize());
163
164            // Trait `HashPath` needs to be re-worked so it can work with Strings.
165            // This code below ends up making the String into ASCII representations of the hash
166            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
167            let hashed_path = format!(
168                "{}/{}/{}/{}",
169                &sha256[0..2],
170                &sha256[2..4],
171                &sha256[4..6],
172                sha256
173            );
174
175            // The path which has the file name included, with the storage directory prepended.
176            //let hashed_path = result.hashed_path(3);
177            let mut dest_path = dest_path.clone();
178            dest_path.push(hashed_path);
179
180            // Remove the file name so we can just have the directory path.
181            let mut just_the_dir = dest_path.clone();
182            just_the_dir.pop();
183            std::fs::create_dir_all(just_the_dir)?;
184
185            let data = if self.db_config.compression {
186                let mut compressor = GzEncoder::new(Vec::new(), Compression::default());
187                compressor.write_all(data)?;
188                compressor.finish()?
189            } else {
190                data.to_vec()
191            };
192
193            let data = if let Some(key_id) = self.db_config.default_key {
194                if let Some(key) = self.keys.get(&key_id) {
195                    let nonce = key.nonce();
196                    self.db_type
197                        .set_file_nonce(&sha256, nonce.as_deref())
198                        .await?;
199                    key.encrypt(&data, nonce)?
200                } else {
201                    bail!("Key not available!")
202                }
203            } else {
204                data
205            };
206
207            std::fs::write(dest_path, data)?;
208
209            Ok(true)
210        } else {
211            Ok(false)
212        }
213    }
214
215    /// Retrieve a sample given the SHA-256 hash
216    /// Assumes that `MalwareDB` permissions have already been checked to ensure this is permitted.
217    ///
218    /// # Errors
219    ///
220    /// * The file could not be read, maybe because it doesn't exist.
221    /// * Failure to decrypt or decompress (corruption).
222    pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
223        if let Some(dest_path) = &self.directory {
224            let path = format!(
225                "{}/{}/{}/{}",
226                &sha256[0..2],
227                &sha256[2..4],
228                &sha256[4..6],
229                sha256
230            );
231            // Trait `HashPath` needs to be re-worked so it can work with Strings.
232            // This code below ends up making the String into ASCII representations of the hash
233            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
234            //let path = sha256.as_bytes().iter().hashed_path(3);
235            let contents = std::fs::read(dest_path.join(path))?;
236
237            let contents = if self.keys.is_empty() {
238                // We don't have file encryption enabled
239                contents
240            } else {
241                let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
242                if let Some(key_id) = key_id {
243                    if let Some(key) = self.keys.get(&key_id) {
244                        key.decrypt(&contents, nonce)?
245                    } else {
246                        bail!("File was encrypted but we don't have tke key!")
247                    }
248                } else {
249                    // File was not encrypted
250                    contents
251                }
252            };
253
254            if contents.starts_with(&GZIP_MAGIC) {
255                let buff = Cursor::new(contents);
256                let mut decompressor = GzDecoder::new(buff);
257                let mut decompressed: Vec<u8> = vec![];
258                decompressor.read_to_end(&mut decompressed)?;
259                Ok(decompressed)
260            } else {
261                Ok(contents)
262            }
263        } else {
264            bail!("files are not saved")
265        }
266    }
267
268    /// Get the duration for which the server has been running
269    ///
270    /// # Panics
271    ///
272    /// Despite the `unwrap()` this function will not panic as the data used is guaranteed to be valid.
273    #[must_use]
274    pub fn since(&self) -> Duration {
275        let now = SystemTime::now();
276        now.duration_since(self.started).unwrap()
277    }
278
279    /// Get server information
280    ///
281    /// # Errors
282    ///
283    /// An error would occur if the Postgres server could not be reached.
284    pub async fn get_info(&self) -> Result<ServerInfo> {
285        let db_info = self.db_type.db_info().await?;
286        let uptime = Local::now() - self.since();
287        let mem_size = if let Some(mem_size) = app_memory_usage_fetcher::get_memory_usage_bytes() {
288            humansize::SizeFormatter::new(mem_size.get(), humansize::BINARY).to_string()
289        } else {
290            String::new()
291        };
292
293        Ok(ServerInfo {
294            os_name: std::env::consts::OS.into(),
295            memory_used: mem_size,
296            num_samples: db_info.num_files,
297            num_users: db_info.num_users,
298            uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
299            mdb_version: MDB_VERSION.into(),
300            db_version: db_info.version,
301            db_size: db_info.size,
302        })
303    }
304
305    /// The server listens and responds to requests. Does not return unless there's an error.
306    ///
307    /// # Errors
308    ///
309    /// * If the certificate and private key could not be parsed or are not valid.
310    /// * If the IP address and port are already in use.
311    /// * If the service doesn't have permission to open the port.
312    ///
313    /// # Panics
314    ///
315    /// * The `.unwrap()` calls won't panic because the data would have already been validated.
316    pub async fn serve(self) -> Result<()> {
317        let socket = SocketAddr::new(self.ip, self.port);
318
319        if self.cert.is_some() && self.key.is_some() {
320            let cert_path = self.cert.as_ref().unwrap();
321            let key_path = self.key.as_ref().unwrap();
322            let cert_ext_str = cert_path
323                .extension()
324                .context("failed to get certificate extension")?;
325            let key_ext_str = key_path
326                .extension()
327                .context("failed to get key extension")?;
328
329            // Unnecessary for running MalwareDB, but some unit tests fail without this check.
330            if rustls::crypto::CryptoProvider::get_default().is_none() {
331                rustls::crypto::ring::default_provider()
332                    .install_default()
333                    .expect("Failed to load crypto provider");
334            }
335
336            let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem"
337            {
338                RustlsConfig::from_pem_file(cert_path, key_path)
339                    .await
340                    .context("failed to load or parse certificate and key files")?
341            } else if cert_ext_str == "der" && key_ext_str == "der" {
342                let cert_contents =
343                    std::fs::read(cert_path).context("failed to read certificate file")?;
344                let key_contents =
345                    std::fs::read(key_path).context("failed to read certificate file")?;
346                RustlsConfig::from_der(vec![cert_contents], key_contents)
347                    .await
348                    .context("failed to parse certificate and key files as DER")?
349            } else {
350                bail!("Unknown or unmatched certificate and key file extensions {cert_ext_str:?} and {key_ext_str:?}");
351            };
352
353            println!("Listening on https://{socket:?}");
354            axum_server::bind_rustls(socket, config)
355                .serve(http::app(Arc::new(self)).into_make_service())
356                .await?;
357        } else {
358            println!("Listening on http://{socket:?}");
359            let listener = TcpListener::bind(socket)
360                .await
361                .context(format!("failed to bind socket {socket}"))?;
362            axum::serve(listener, http::app(Arc::new(self)).into_make_service()).await?;
363        }
364        Ok(())
365    }
366}
367
368impl Debug for State {
369    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
370        write!(
371            f,
372            "MDB state, port {}, database {:?}",
373            self.port, self.db_type
374        )
375    }
376}