1#![doc = include_str!("../README.md")]
4#![deny(missing_docs)]
5#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8pub mod crypto;
10
11pub mod db;
13
14pub mod http;
16
17pub mod utils;
19
20#[cfg(feature = "vt")]
22pub mod vt;
23
24use crate::crypto::FileEncryption;
25use crate::db::MDBConfig;
26use malwaredb_api::ServerInfo;
27use std::collections::HashMap;
30use std::fmt::{Debug, Formatter};
31use std::io::{Cursor, Read};
32use std::net::{IpAddr, SocketAddr};
33use std::path::PathBuf;
34use std::sync::{Arc, LazyLock};
35use std::time::{Duration, SystemTime};
36
37use anyhow::{bail, ensure, Context, Result};
38use axum_server::tls_rustls::RustlsConfig;
39use chrono::Local;
40use chrono_humanize::{Accuracy, HumanTime, Tense};
41use flate2::read::GzDecoder;
42use sha2::{Digest, Sha256};
43use tokio::net::TcpListener;
44
45pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
47
48pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
50 LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
51
/// Leading bytes of a gzip stream; used to detect gzip-compressed stored files.
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];

/// Leading bytes of a Zstandard frame (magic number `0xFD2FB528`, little-endian); used to
/// detect zstd-compressed stored files.
pub const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];
57
58pub struct State {
60 pub port: u16,
62
63 pub directory: Option<PathBuf>,
65
66 pub max_upload: usize,
68
69 pub ip: IpAddr,
71
72 pub db_type: db::DatabaseType,
74
75 pub started: SystemTime,
77
78 pub db_config: MDBConfig,
80
81 pub(crate) keys: HashMap<u32, FileEncryption>,
83
84 #[cfg(feature = "vt")]
86 pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
87
88 cert: Option<PathBuf>,
91
92 key: Option<PathBuf>,
94}
95
96impl State {
97 #[allow(clippy::too_many_arguments)]
104 pub async fn new(
105 port: u16,
106 directory: Option<PathBuf>,
107 max_upload: usize,
108 ip: IpAddr,
109 db_string: &str,
110 cert: Option<PathBuf>,
111 key: Option<PathBuf>,
112 pg_cert: Option<PathBuf>,
113 #[cfg(feature = "vt")] vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
114 ) -> Result<Self> {
115 if let Some(dir) = &directory {
116 if !dir.exists() {
117 bail!("data directory {} does not exist!", dir.display());
118 }
119 }
120
121 if cert.is_some() != key.is_some() {
122 bail!("Either both the https cert & key are provided, or none are provided.");
123 }
124
125 if let Some(cert_file) = &cert {
126 ensure!(
127 cert_file.exists(),
128 "Certificate file {} does not exist!",
129 cert_file.display()
130 );
131 }
132
133 if let Some(key_file) = &key {
134 ensure!(
135 key_file.exists(),
136 "Key file {} does not exist!",
137 key_file.display()
138 );
139 }
140
141 let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
142 let db_config = db_type.get_config().await?;
144 let keys = db_type.get_encryption_keys().await?;
145
146 Ok(Self {
147 port,
148 directory,
149 max_upload,
150 ip,
151 db_type,
152 db_config,
153 keys,
154 #[cfg(feature = "vt")]
155 vt_client,
156 started: SystemTime::now(),
157 cert,
158 key,
159 })
160 }
161
162 pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
169 if let Some(dest_path) = &self.directory {
170 let mut hasher = Sha256::new();
171 hasher.update(data);
172 let sha256 = hex::encode(hasher.finalize());
173
174 let hashed_path = format!(
178 "{}/{}/{}/{}",
179 &sha256[0..2],
180 &sha256[2..4],
181 &sha256[4..6],
182 sha256
183 );
184
185 let mut dest_path = dest_path.clone();
188 dest_path.push(hashed_path);
189
190 let mut just_the_dir = dest_path.clone();
192 just_the_dir.pop();
193 std::fs::create_dir_all(just_the_dir)?;
194
195 let data = if self.db_config.compression {
196 let buff = Cursor::new(data);
197 let mut compressed = Vec::with_capacity(data.len() / 2);
198 zstd::stream::copy_encode(buff, &mut compressed, 4)?;
199 compressed
200 } else {
201 data.to_vec()
202 };
203
204 let data = if let Some(key_id) = self.db_config.default_key {
205 if let Some(key) = self.keys.get(&key_id) {
206 let nonce = key.nonce();
207 self.db_type
208 .set_file_nonce(&sha256, nonce.as_deref())
209 .await?;
210 key.encrypt(&data, nonce)?
211 } else {
212 bail!("Key not available!")
213 }
214 } else {
215 data
216 };
217
218 std::fs::write(dest_path, data)?;
219
220 Ok(true)
221 } else {
222 Ok(false)
223 }
224 }
225
226 pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
234 if let Some(dest_path) = &self.directory {
235 let path = format!(
236 "{}/{}/{}/{}",
237 &sha256[0..2],
238 &sha256[2..4],
239 &sha256[4..6],
240 sha256
241 );
242 let contents = std::fs::read(dest_path.join(path))?;
247
248 let contents = if self.keys.is_empty() {
249 contents
251 } else {
252 let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
253 if let Some(key_id) = key_id {
254 if let Some(key) = self.keys.get(&key_id) {
255 key.decrypt(&contents, nonce)?
256 } else {
257 bail!("File was encrypted but we don't have tke key!")
258 }
259 } else {
260 contents
262 }
263 };
264
265 if contents.starts_with(&GZIP_MAGIC) {
266 let buff = Cursor::new(contents);
267 let mut decompressor = GzDecoder::new(buff);
268 let mut decompressed: Vec<u8> = vec![];
269 decompressor.read_to_end(&mut decompressed)?;
270 Ok(decompressed)
271 } else if contents.starts_with(&ZSTD_MAGIC) {
272 let buff = Cursor::new(contents);
273 let mut decompressed: Vec<u8> = vec![];
274 zstd::stream::copy_decode(buff, &mut decompressed)?;
275 Ok(decompressed)
276 } else {
277 Ok(contents)
278 }
279 } else {
280 bail!("files are not saved")
281 }
282 }
283
284 #[must_use]
290 pub fn since(&self) -> Duration {
291 let now = SystemTime::now();
292 now.duration_since(self.started).unwrap()
293 }
294
295 pub async fn get_info(&self) -> Result<ServerInfo> {
301 let db_info = self.db_type.db_info().await?;
302 let uptime = Local::now() - self.since();
303 let mem_size = if let Some(mem_size) = app_memory_usage_fetcher::get_memory_usage_bytes() {
304 humansize::SizeFormatter::new(mem_size.get(), humansize::BINARY).to_string()
305 } else {
306 String::new()
307 };
308
309 Ok(ServerInfo {
310 os_name: std::env::consts::OS.into(),
311 memory_used: mem_size,
312 num_samples: db_info.num_files,
313 num_users: db_info.num_users,
314 uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
315 mdb_version: MDB_VERSION_SEMVER.clone(),
316 db_version: db_info.version,
317 db_size: db_info.size,
318 instance_name: self.db_config.name.clone(),
319 })
320 }
321
322 pub async fn serve(self) -> Result<()> {
334 let socket = SocketAddr::new(self.ip, self.port);
335
336 if self.cert.is_some() && self.key.is_some() {
337 let cert_path = self.cert.as_ref().unwrap();
338 let key_path = self.key.as_ref().unwrap();
339 let cert_ext_str = cert_path
340 .extension()
341 .context("failed to get certificate extension")?;
342 let key_ext_str = key_path
343 .extension()
344 .context("failed to get key extension")?;
345
346 if rustls::crypto::CryptoProvider::get_default().is_none() {
348 rustls::crypto::aws_lc_rs::default_provider()
349 .install_default()
350 .expect("Failed to load crypto provider");
351 }
352
353 let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem"
354 {
355 RustlsConfig::from_pem_file(cert_path, key_path)
356 .await
357 .context("failed to load or parse certificate and key files")?
358 } else if cert_ext_str == "der" && key_ext_str == "der" {
359 let cert_contents =
360 std::fs::read(cert_path).context("failed to read certificate file")?;
361 let key_contents =
362 std::fs::read(key_path).context("failed to read certificate file")?;
363 RustlsConfig::from_der(vec![cert_contents], key_contents)
364 .await
365 .context("failed to parse certificate and key files as DER")?
366 } else {
367 bail!("Unknown or unmatched certificate and key file extensions {cert_ext_str:?} and {key_ext_str:?}");
368 };
369
370 println!("Listening on https://{socket:?}");
371 axum_server::bind_rustls(socket, config)
372 .serve(http::app(Arc::new(self)).into_make_service())
373 .await?;
374 } else {
375 println!("Listening on http://{socket:?}");
376 let listener = TcpListener::bind(socket)
377 .await
378 .context(format!("failed to bind socket {socket}"))?;
379 axum::serve(listener, http::app(Arc::new(self)).into_make_service()).await?;
380 }
381 Ok(())
382 }
383}
384
385impl Debug for State {
386 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
387 write!(
388 f,
389 "MDB state, port {}, database {:?}",
390 self.port, self.db_type
391 )
392 }
393}