malwaredb_server/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![deny(missing_docs)]
5#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8/// Cryptographic functionality for file storage
9pub mod crypto;
10
11/// Database I/O
12pub mod db;
13
14/// HTTP Server
15pub mod http;
16
17/// Entropy functions
18pub mod utils;
19
20/// Virus Total communication
21#[cfg(feature = "vt")]
22pub mod vt;
23
24use crate::crypto::FileEncryption;
25use crate::db::MDBConfig;
26use malwaredb_api::ServerInfo;
27//use utils::HashPath;
28
29use std::collections::HashMap;
30use std::fmt::{Debug, Formatter};
31use std::io::{Cursor, Read};
32use std::net::{IpAddr, SocketAddr};
33use std::path::PathBuf;
34use std::sync::Arc;
35use std::time::{Duration, SystemTime};
36
37use anyhow::{bail, ensure, Context, Result};
38use axum_server::tls_rustls::RustlsConfig;
39use chrono::Local;
40use chrono_humanize::{Accuracy, HumanTime, Tense};
41use flate2::read::GzDecoder;
42use sha2::{Digest, Sha256};
43use tokio::net::TcpListener;
44
45/// MDB version
46pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
47
/// Gzip's magic number to see if a file is compressed: the `ID1`/`ID2` header bytes (RFC 1952).
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];

/// Zstd magic number to see if a file is compressed.
///
/// RFC 8878 defines the frame magic as `0xFD2FB528` stored little-endian, which yields the
/// on-disk byte sequence `28 b5 2f fd`.
pub const ZSTD_MAGIC: [u8; 4] = 0xFD2F_B528_u32.to_le_bytes();

54/// State & configuration of the running server instance
55pub struct State {
56    /// The port which will be used to listen for connections.
57    pub port: u16,
58
59    /// The directory to store malware samples if we're keeping them.
60    pub directory: Option<PathBuf>,
61
62    /// Maximum upload size
63    pub max_upload: usize,
64
65    /// The IP to use for listening for connections
66    pub ip: IpAddr,
67
68    /// Handle to the database connection
69    pub db_type: db::DatabaseType,
70
71    /// Start time of the server
72    pub started: SystemTime,
73
74    /// Configuration which is stored in the database
75    pub db_config: MDBConfig,
76
77    /// File encryption keys, may be empty
78    pub(crate) keys: HashMap<u32, FileEncryption>,
79
80    /// Virus Total API key
81    #[cfg(feature = "vt")]
82    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
83
84    /// Https certificate file, optionally containing the CA cert as well (server and CA certs
85    /// concatenated)
86    cert: Option<PathBuf>,
87
88    /// Https private key file
89    key: Option<PathBuf>,
90}
91
92impl State {
93    /// New server state object given a few configuration parameters
94    ///
95    /// # Errors
96    /// * If there's a certificate and not a key, or a key and not a certificate; if either don't exist.
97    /// * If the sample storage directory doesn't exist.
98    /// * if the database configuration isn't valid or if the Postgres server can't be reached.
99    #[allow(clippy::too_many_arguments)]
100    pub async fn new(
101        port: u16,
102        directory: Option<PathBuf>,
103        max_upload: usize,
104        ip: IpAddr,
105        db_string: &str,
106        cert: Option<PathBuf>,
107        key: Option<PathBuf>,
108        pg_cert: Option<PathBuf>,
109        #[cfg(feature = "vt")] vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
110    ) -> Result<Self> {
111        if let Some(dir) = &directory {
112            if !dir.exists() {
113                bail!("data directory {dir:?} does not exist!");
114            }
115        }
116
117        if cert.is_some() != key.is_some() {
118            bail!("Either both the https cert & key are provided, or none are provided.");
119        }
120
121        if let Some(cert_file) = &cert {
122            ensure!(
123                cert_file.exists(),
124                "Certificate file {cert_file:?} does not exist!"
125            );
126        }
127
128        if let Some(key_file) = &key {
129            ensure!(key_file.exists(), "Key file {key_file:?} does not exist!");
130        }
131
132        let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
133        // TODO: allow config and keys to be refreshed so changes don't require a restart
134        let db_config = db_type.get_config().await?;
135        let keys = db_type.get_encryption_keys().await?;
136
137        Ok(Self {
138            port,
139            directory,
140            max_upload,
141            ip,
142            db_type,
143            db_config,
144            keys,
145            #[cfg(feature = "vt")]
146            vt_client,
147            started: SystemTime::now(),
148            cert,
149            key,
150        })
151    }
152
153    /// Store the sample with a depth of three based on the sample's SHA-256 hash, even if compressed
154    ///
155    /// # Errors
156    ///
157    /// * If the file can't be written.
158    /// * If a necessary sub-directory can't be created.
159    pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
160        if let Some(dest_path) = &self.directory {
161            let mut hasher = Sha256::new();
162            hasher.update(data);
163            let sha256 = hex::encode(hasher.finalize());
164
165            // Trait `HashPath` needs to be re-worked so it can work with Strings.
166            // This code below ends up making the String into ASCII representations of the hash
167            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
168            let hashed_path = format!(
169                "{}/{}/{}/{}",
170                &sha256[0..2],
171                &sha256[2..4],
172                &sha256[4..6],
173                sha256
174            );
175
176            // The path which has the file name included, with the storage directory prepended.
177            //let hashed_path = result.hashed_path(3);
178            let mut dest_path = dest_path.clone();
179            dest_path.push(hashed_path);
180
181            // Remove the file name so we can just have the directory path.
182            let mut just_the_dir = dest_path.clone();
183            just_the_dir.pop();
184            std::fs::create_dir_all(just_the_dir)?;
185
186            let data = if self.db_config.compression {
187                let buff = Cursor::new(data);
188                let mut compressed = Vec::with_capacity(data.len() / 2);
189                zstd::stream::copy_encode(buff, &mut compressed, 4)?;
190                compressed
191            } else {
192                data.to_vec()
193            };
194
195            let data = if let Some(key_id) = self.db_config.default_key {
196                if let Some(key) = self.keys.get(&key_id) {
197                    let nonce = key.nonce();
198                    self.db_type
199                        .set_file_nonce(&sha256, nonce.as_deref())
200                        .await?;
201                    key.encrypt(&data, nonce)?
202                } else {
203                    bail!("Key not available!")
204                }
205            } else {
206                data
207            };
208
209            std::fs::write(dest_path, data)?;
210
211            Ok(true)
212        } else {
213            Ok(false)
214        }
215    }
216
217    /// Retrieve a sample given the SHA-256 hash
218    /// Assumes that `MalwareDB` permissions have already been checked to ensure this is permitted.
219    ///
220    /// # Errors
221    ///
222    /// * The file could not be read, maybe because it doesn't exist.
223    /// * Failure to decrypt or decompress (corruption).
224    pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
225        if let Some(dest_path) = &self.directory {
226            let path = format!(
227                "{}/{}/{}/{}",
228                &sha256[0..2],
229                &sha256[2..4],
230                &sha256[4..6],
231                sha256
232            );
233            // Trait `HashPath` needs to be re-worked so it can work with Strings.
234            // This code below ends up making the String into ASCII representations of the hash
235            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
236            //let path = sha256.as_bytes().iter().hashed_path(3);
237            let contents = std::fs::read(dest_path.join(path))?;
238
239            let contents = if self.keys.is_empty() {
240                // We don't have file encryption enabled
241                contents
242            } else {
243                let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
244                if let Some(key_id) = key_id {
245                    if let Some(key) = self.keys.get(&key_id) {
246                        key.decrypt(&contents, nonce)?
247                    } else {
248                        bail!("File was encrypted but we don't have tke key!")
249                    }
250                } else {
251                    // File was not encrypted
252                    contents
253                }
254            };
255
256            if contents.starts_with(&GZIP_MAGIC) {
257                let buff = Cursor::new(contents);
258                let mut decompressor = GzDecoder::new(buff);
259                let mut decompressed: Vec<u8> = vec![];
260                decompressor.read_to_end(&mut decompressed)?;
261                Ok(decompressed)
262            } else if contents.starts_with(&ZSTD_MAGIC) {
263                let buff = Cursor::new(contents);
264                let mut decompressed: Vec<u8> = vec![];
265                zstd::stream::copy_decode(buff, &mut decompressed)?;
266                Ok(decompressed)
267            } else {
268                Ok(contents)
269            }
270        } else {
271            bail!("files are not saved")
272        }
273    }
274
275    /// Get the duration for which the server has been running
276    ///
277    /// # Panics
278    ///
279    /// Despite the `unwrap()` this function will not panic as the data used is guaranteed to be valid.
280    #[must_use]
281    pub fn since(&self) -> Duration {
282        let now = SystemTime::now();
283        now.duration_since(self.started).unwrap()
284    }
285
286    /// Get server information
287    ///
288    /// # Errors
289    ///
290    /// An error would occur if the Postgres server could not be reached.
291    pub async fn get_info(&self) -> Result<ServerInfo> {
292        let db_info = self.db_type.db_info().await?;
293        let uptime = Local::now() - self.since();
294        let mem_size = if let Some(mem_size) = app_memory_usage_fetcher::get_memory_usage_bytes() {
295            humansize::SizeFormatter::new(mem_size.get(), humansize::BINARY).to_string()
296        } else {
297            String::new()
298        };
299
300        Ok(ServerInfo {
301            os_name: std::env::consts::OS.into(),
302            memory_used: mem_size,
303            num_samples: db_info.num_files,
304            num_users: db_info.num_users,
305            uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
306            mdb_version: MDB_VERSION.into(),
307            db_version: db_info.version,
308            db_size: db_info.size,
309        })
310    }
311
312    /// The server listens and responds to requests. Does not return unless there's an error.
313    ///
314    /// # Errors
315    ///
316    /// * If the certificate and private key could not be parsed or are not valid.
317    /// * If the IP address and port are already in use.
318    /// * If the service doesn't have permission to open the port.
319    ///
320    /// # Panics
321    ///
322    /// * The `.unwrap()` calls won't panic because the data would have already been validated.
323    pub async fn serve(self) -> Result<()> {
324        let socket = SocketAddr::new(self.ip, self.port);
325
326        if self.cert.is_some() && self.key.is_some() {
327            let cert_path = self.cert.as_ref().unwrap();
328            let key_path = self.key.as_ref().unwrap();
329            let cert_ext_str = cert_path
330                .extension()
331                .context("failed to get certificate extension")?;
332            let key_ext_str = key_path
333                .extension()
334                .context("failed to get key extension")?;
335
336            // Unnecessary for running MalwareDB, but some unit tests fail without this check.
337            if rustls::crypto::CryptoProvider::get_default().is_none() {
338                rustls::crypto::ring::default_provider()
339                    .install_default()
340                    .expect("Failed to load crypto provider");
341            }
342
343            let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem"
344            {
345                RustlsConfig::from_pem_file(cert_path, key_path)
346                    .await
347                    .context("failed to load or parse certificate and key files")?
348            } else if cert_ext_str == "der" && key_ext_str == "der" {
349                let cert_contents =
350                    std::fs::read(cert_path).context("failed to read certificate file")?;
351                let key_contents =
352                    std::fs::read(key_path).context("failed to read certificate file")?;
353                RustlsConfig::from_der(vec![cert_contents], key_contents)
354                    .await
355                    .context("failed to parse certificate and key files as DER")?
356            } else {
357                bail!("Unknown or unmatched certificate and key file extensions {cert_ext_str:?} and {key_ext_str:?}");
358            };
359
360            println!("Listening on https://{socket:?}");
361            axum_server::bind_rustls(socket, config)
362                .serve(http::app(Arc::new(self)).into_make_service())
363                .await?;
364        } else {
365            println!("Listening on http://{socket:?}");
366            let listener = TcpListener::bind(socket)
367                .await
368                .context(format!("failed to bind socket {socket}"))?;
369            axum::serve(listener, http::app(Arc::new(self)).into_make_service()).await?;
370        }
371        Ok(())
372    }
373}
374
375impl Debug for State {
376    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
377        write!(
378            f,
379            "MDB state, port {}, database {:?}",
380            self.port, self.db_type
381        )
382    }
383}