malwaredb_server/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![deny(missing_docs)]
5#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8/// Cryptographic functionality for file storage
9pub mod crypto;
10
11/// Database I/O
12pub mod db;
13
14/// HTTP Server
15pub mod http;
16
17/// Entropy functions
18pub mod utils;
19
20/// Virus Total communication
21#[cfg(feature = "vt")]
22pub mod vt;
23
24use crate::crypto::FileEncryption;
25use crate::db::MDBConfig;
26use malwaredb_api::ServerInfo;
27//use utils::HashPath;
28
29use std::collections::HashMap;
30use std::fmt::{Debug, Formatter};
31use std::io::{Cursor, Read};
32use std::net::{IpAddr, SocketAddr};
33use std::path::PathBuf;
34use std::sync::{Arc, LazyLock};
35use std::time::{Duration, SystemTime};
36
37use anyhow::{bail, ensure, Context, Result};
38use axum_server::tls_rustls::RustlsConfig;
39use chrono::Local;
40use chrono_humanize::{Accuracy, HumanTime, Tense};
41use flate2::read::GzDecoder;
42use sha2::{Digest, Sha256};
43use tokio::net::TcpListener;
44
45/// MDB version
46pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
47
48/// MDB version as a semantic version object
49pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
50    LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
51
/// Leading bytes of a gzip stream (RFC 1952: `ID1 = 31`, `ID2 = 139`), used to detect
/// gzip-compressed files
pub const GZIP_MAGIC: [u8; 2] = [31, 139];

/// Leading bytes of a zstd frame (RFC 8878 magic number `0xFD2FB528`, stored little-endian),
/// used to detect zstd-compressed files
pub const ZSTD_MAGIC: [u8; 4] = 0xFD2F_B528_u32.to_le_bytes();
57
58/// State & configuration of the running server instance
59pub struct State {
60    /// The port which will be used to listen for connections.
61    pub port: u16,
62
63    /// The directory to store malware samples if we're keeping them.
64    pub directory: Option<PathBuf>,
65
66    /// Maximum upload size
67    pub max_upload: usize,
68
69    /// The IP to use for listening for connections
70    pub ip: IpAddr,
71
72    /// Handle to the database connection
73    pub db_type: db::DatabaseType,
74
75    /// Start time of the server
76    pub started: SystemTime,
77
78    /// Configuration which is stored in the database
79    pub db_config: MDBConfig,
80
81    /// File encryption keys, may be empty
82    pub(crate) keys: HashMap<u32, FileEncryption>,
83
84    /// Virus Total API key
85    #[cfg(feature = "vt")]
86    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
87
88    /// Https certificate file, optionally containing the CA cert as well (server and CA certs
89    /// concatenated)
90    cert: Option<PathBuf>,
91
92    /// Https private key file
93    key: Option<PathBuf>,
94}
95
96impl State {
97    /// New server state object given a few configuration parameters
98    ///
99    /// # Errors
100    /// * If there's a certificate and not a key, or a key and not a certificate; if either don't exist.
101    /// * If the sample storage directory doesn't exist.
102    /// * if the database configuration isn't valid or if the Postgres server can't be reached.
103    #[allow(clippy::too_many_arguments)]
104    pub async fn new(
105        port: u16,
106        directory: Option<PathBuf>,
107        max_upload: usize,
108        ip: IpAddr,
109        db_string: &str,
110        cert: Option<PathBuf>,
111        key: Option<PathBuf>,
112        pg_cert: Option<PathBuf>,
113        #[cfg(feature = "vt")] vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
114    ) -> Result<Self> {
115        if let Some(dir) = &directory {
116            if !dir.exists() {
117                bail!("data directory {} does not exist!", dir.display());
118            }
119        }
120
121        if cert.is_some() != key.is_some() {
122            bail!("Either both the https cert & key are provided, or none are provided.");
123        }
124
125        if let Some(cert_file) = &cert {
126            ensure!(
127                cert_file.exists(),
128                "Certificate file {} does not exist!",
129                cert_file.display()
130            );
131        }
132
133        if let Some(key_file) = &key {
134            ensure!(
135                key_file.exists(),
136                "Key file {} does not exist!",
137                key_file.display()
138            );
139        }
140
141        let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
142        // TODO: allow config and keys to be refreshed so changes don't require a restart
143        let db_config = db_type.get_config().await?;
144        let keys = db_type.get_encryption_keys().await?;
145
146        Ok(Self {
147            port,
148            directory,
149            max_upload,
150            ip,
151            db_type,
152            db_config,
153            keys,
154            #[cfg(feature = "vt")]
155            vt_client,
156            started: SystemTime::now(),
157            cert,
158            key,
159        })
160    }
161
162    /// Store the sample with a depth of three based on the sample's SHA-256 hash, even if compressed
163    ///
164    /// # Errors
165    ///
166    /// * If the file can't be written.
167    /// * If a necessary sub-directory can't be created.
168    pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
169        if let Some(dest_path) = &self.directory {
170            let mut hasher = Sha256::new();
171            hasher.update(data);
172            let sha256 = hex::encode(hasher.finalize());
173
174            // Trait `HashPath` needs to be re-worked so it can work with Strings.
175            // This code below ends up making the String into ASCII representations of the hash
176            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
177            let hashed_path = format!(
178                "{}/{}/{}/{}",
179                &sha256[0..2],
180                &sha256[2..4],
181                &sha256[4..6],
182                sha256
183            );
184
185            // The path which has the file name included, with the storage directory prepended.
186            //let hashed_path = result.hashed_path(3);
187            let mut dest_path = dest_path.clone();
188            dest_path.push(hashed_path);
189
190            // Remove the file name so we can just have the directory path.
191            let mut just_the_dir = dest_path.clone();
192            just_the_dir.pop();
193            std::fs::create_dir_all(just_the_dir)?;
194
195            let data = if self.db_config.compression {
196                let buff = Cursor::new(data);
197                let mut compressed = Vec::with_capacity(data.len() / 2);
198                zstd::stream::copy_encode(buff, &mut compressed, 4)?;
199                compressed
200            } else {
201                data.to_vec()
202            };
203
204            let data = if let Some(key_id) = self.db_config.default_key {
205                if let Some(key) = self.keys.get(&key_id) {
206                    let nonce = key.nonce();
207                    self.db_type
208                        .set_file_nonce(&sha256, nonce.as_deref())
209                        .await?;
210                    key.encrypt(&data, nonce)?
211                } else {
212                    bail!("Key not available!")
213                }
214            } else {
215                data
216            };
217
218            std::fs::write(dest_path, data)?;
219
220            Ok(true)
221        } else {
222            Ok(false)
223        }
224    }
225
226    /// Retrieve a sample given the SHA-256 hash
227    /// Assumes that `MalwareDB` permissions have already been checked to ensure this is permitted.
228    ///
229    /// # Errors
230    ///
231    /// * The file could not be read, maybe because it doesn't exist.
232    /// * Failure to decrypt or decompress (corruption).
233    pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
234        if let Some(dest_path) = &self.directory {
235            let path = format!(
236                "{}/{}/{}/{}",
237                &sha256[0..2],
238                &sha256[2..4],
239                &sha256[4..6],
240                sha256
241            );
242            // Trait `HashPath` needs to be re-worked so it can work with Strings.
243            // This code below ends up making the String into ASCII representations of the hash
244            // See: https://github.com/malwaredb/malwaredb-rs/issues/60
245            //let path = sha256.as_bytes().iter().hashed_path(3);
246            let contents = std::fs::read(dest_path.join(path))?;
247
248            let contents = if self.keys.is_empty() {
249                // We don't have file encryption enabled
250                contents
251            } else {
252                let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
253                if let Some(key_id) = key_id {
254                    if let Some(key) = self.keys.get(&key_id) {
255                        key.decrypt(&contents, nonce)?
256                    } else {
257                        bail!("File was encrypted but we don't have tke key!")
258                    }
259                } else {
260                    // File was not encrypted
261                    contents
262                }
263            };
264
265            if contents.starts_with(&GZIP_MAGIC) {
266                let buff = Cursor::new(contents);
267                let mut decompressor = GzDecoder::new(buff);
268                let mut decompressed: Vec<u8> = vec![];
269                decompressor.read_to_end(&mut decompressed)?;
270                Ok(decompressed)
271            } else if contents.starts_with(&ZSTD_MAGIC) {
272                let buff = Cursor::new(contents);
273                let mut decompressed: Vec<u8> = vec![];
274                zstd::stream::copy_decode(buff, &mut decompressed)?;
275                Ok(decompressed)
276            } else {
277                Ok(contents)
278            }
279        } else {
280            bail!("files are not saved")
281        }
282    }
283
284    /// Get the duration for which the server has been running
285    ///
286    /// # Panics
287    ///
288    /// Despite the `unwrap()` this function will not panic as the data used is guaranteed to be valid.
289    #[must_use]
290    pub fn since(&self) -> Duration {
291        let now = SystemTime::now();
292        now.duration_since(self.started).unwrap()
293    }
294
295    /// Get server information
296    ///
297    /// # Errors
298    ///
299    /// An error would occur if the Postgres server could not be reached.
300    pub async fn get_info(&self) -> Result<ServerInfo> {
301        let db_info = self.db_type.db_info().await?;
302        let uptime = Local::now() - self.since();
303        let mem_size = if let Some(mem_size) = app_memory_usage_fetcher::get_memory_usage_bytes() {
304            humansize::SizeFormatter::new(mem_size.get(), humansize::BINARY).to_string()
305        } else {
306            String::new()
307        };
308
309        Ok(ServerInfo {
310            os_name: std::env::consts::OS.into(),
311            memory_used: mem_size,
312            num_samples: db_info.num_files,
313            num_users: db_info.num_users,
314            uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
315            mdb_version: MDB_VERSION_SEMVER.clone(),
316            db_version: db_info.version,
317            db_size: db_info.size,
318            instance_name: self.db_config.name.clone(),
319        })
320    }
321
322    /// The server listens and responds to requests. Does not return unless there's an error.
323    ///
324    /// # Errors
325    ///
326    /// * If the certificate and private key could not be parsed or are not valid.
327    /// * If the IP address and port are already in use.
328    /// * If the service doesn't have permission to open the port.
329    ///
330    /// # Panics
331    ///
332    /// * The `.unwrap()` calls won't panic because the data would have already been validated.
333    pub async fn serve(self) -> Result<()> {
334        let socket = SocketAddr::new(self.ip, self.port);
335
336        if self.cert.is_some() && self.key.is_some() {
337            let cert_path = self.cert.as_ref().unwrap();
338            let key_path = self.key.as_ref().unwrap();
339            let cert_ext_str = cert_path
340                .extension()
341                .context("failed to get certificate extension")?;
342            let key_ext_str = key_path
343                .extension()
344                .context("failed to get key extension")?;
345
346            // Unnecessary for running MalwareDB, but some unit tests fail without this check.
347            if rustls::crypto::CryptoProvider::get_default().is_none() {
348                rustls::crypto::aws_lc_rs::default_provider()
349                    .install_default()
350                    .expect("Failed to load crypto provider");
351            }
352
353            let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem"
354            {
355                RustlsConfig::from_pem_file(cert_path, key_path)
356                    .await
357                    .context("failed to load or parse certificate and key files")?
358            } else if cert_ext_str == "der" && key_ext_str == "der" {
359                let cert_contents =
360                    std::fs::read(cert_path).context("failed to read certificate file")?;
361                let key_contents =
362                    std::fs::read(key_path).context("failed to read certificate file")?;
363                RustlsConfig::from_der(vec![cert_contents], key_contents)
364                    .await
365                    .context("failed to parse certificate and key files as DER")?
366            } else {
367                bail!("Unknown or unmatched certificate and key file extensions {cert_ext_str:?} and {key_ext_str:?}");
368            };
369
370            println!("Listening on https://{socket:?}");
371            axum_server::bind_rustls(socket, config)
372                .serve(http::app(Arc::new(self)).into_make_service())
373                .await?;
374        } else {
375            println!("Listening on http://{socket:?}");
376            let listener = TcpListener::bind(socket)
377                .await
378                .context(format!("failed to bind socket {socket}"))?;
379            axum::serve(listener, http::app(Arc::new(self)).into_make_service()).await?;
380        }
381        Ok(())
382    }
383}
384
385impl Debug for State {
386    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
387        write!(
388            f,
389            "MDB state, port {}, database {:?}",
390            self.port, self.db_type
391        )
392    }
393}