1#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8
9pub mod crypto;
11
12pub mod db;
14
15pub mod http;
17
18pub mod utils;
20
21#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
23#[cfg(feature = "vt")]
24pub mod vt;
25
26use crate::crypto::FileEncryption;
27use crate::db::MDBConfig;
28use malwaredb_api::ServerInfo;
29use std::collections::HashMap;
32use std::fmt::{Debug, Formatter};
33use std::io::{Cursor, Read};
34use std::net::{IpAddr, SocketAddr};
35use std::path::PathBuf;
36use std::sync::{Arc, LazyLock};
37use std::time::{Duration, SystemTime};
38
39use anyhow::{bail, ensure, Context, Result};
40use axum_server::tls_rustls::RustlsConfig;
41use chrono::Local;
42use chrono_humanize::{Accuracy, HumanTime, Tense};
43use flate2::read::GzDecoder;
44use mdns_sd::{ServiceDaemon, ServiceInfo};
45use sha2::{Digest, Sha256};
46use tokio::net::TcpListener;
47use tracing::{debug, error, trace, warn};
48
49pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
51
52pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
54 LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
55
/// Pause between runs of the background pagination/search cleanup task: one day.
pub(crate) const DB_CLEANUP_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
59
/// First two bytes of every gzip stream; used to recognise gzip-compressed stored samples.
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];

/// First four bytes of every Zstandard frame; used to recognise zstd-compressed stored samples.
pub const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];
65
/// Shared state of a running server instance.
///
/// Built by [`State::new`] and consumed by [`State::serve`], which wraps it in an [`Arc`] and
/// hands it to the HTTP layer.
pub struct State {
    /// TCP port the server listens on.
    pub port: u16,

    /// Directory where sample files are written; `None` means samples are not saved to disk.
    pub directory: Option<PathBuf>,

    /// Maximum upload size.
    /// NOTE(review): presumably in bytes and enforced by the HTTP layer (not visible here).
    pub max_upload: usize,

    /// IP address the server binds to; also advertised over MDNS when enabled.
    pub ip: IpAddr,

    /// Database backend, shared with the background cleanup task via [`Arc`].
    pub db_type: Arc<db::DatabaseType>,

    /// Wall-clock time at which this state was created; used to compute uptime.
    pub started: SystemTime,

    /// Configuration loaded from the database (compression flag, default encryption key id,
    /// instance name, ...).
    pub db_config: MDBConfig,

    /// File-encryption keys loaded from the database, indexed by key id.
    pub(crate) keys: HashMap<u32, FileEncryption>,

    /// Optional `VirusTotal` client.
    #[cfg(feature = "vt")]
    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// Path to the TLS certificate; either both `cert` and `key` are set or neither is.
    cert: Option<PathBuf>,

    /// Path to the TLS private key; either both `cert` and `key` are set or neither is.
    key: Option<PathBuf>,

    /// Whether to advertise this server over MDNS.
    mdns: bool,
}
106
107impl State {
108 #[allow(clippy::too_many_arguments)]
115 pub async fn new(
116 port: u16,
117 directory: Option<PathBuf>,
118 max_upload: usize,
119 ip: IpAddr,
120 db_string: &str,
121 cert: Option<PathBuf>,
122 key: Option<PathBuf>,
123 pg_cert: Option<PathBuf>,
124 mdns: bool,
125 #[cfg(feature = "vt")] vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
126 ) -> Result<Self> {
127 if let Some(dir) = &directory {
128 if !dir.exists() {
129 bail!("data directory {} does not exist!", dir.display());
130 }
131 }
132
133 if cert.is_some() != key.is_some() {
134 bail!("Either both the https cert & key are provided, or none are provided.");
135 }
136
137 if let Some(cert_file) = &cert {
138 ensure!(
139 cert_file.exists(),
140 "Certificate file {} does not exist!",
141 cert_file.display()
142 );
143 }
144
145 if let Some(key_file) = &key {
146 ensure!(
147 key_file.exists(),
148 "Key file {} does not exist!",
149 key_file.display()
150 );
151 }
152
153 let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
154 let db_config = db_type.get_config().await?;
156 let keys = db_type.get_encryption_keys().await?;
157
158 Ok(Self {
159 port,
160 directory,
161 max_upload,
162 ip,
163 db_type: Arc::new(db_type),
164 db_config,
165 keys,
166 #[cfg(feature = "vt")]
167 vt_client,
168 started: SystemTime::now(),
169 cert,
170 key,
171 mdns,
172 })
173 }
174
175 pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
182 if let Some(dest_path) = &self.directory {
183 let mut hasher = Sha256::new();
184 hasher.update(data);
185 let sha256 = hex::encode(hasher.finalize());
186
187 let hashed_path = format!(
191 "{}/{}/{}/{}",
192 &sha256[0..2],
193 &sha256[2..4],
194 &sha256[4..6],
195 sha256
196 );
197
198 let mut dest_path = dest_path.clone();
201 dest_path.push(hashed_path);
202
203 let mut just_the_dir = dest_path.clone();
205 just_the_dir.pop();
206 std::fs::create_dir_all(just_the_dir)?;
207
208 let data = if self.db_config.compression {
209 let buff = Cursor::new(data);
210 let mut compressed = Vec::with_capacity(data.len() / 2);
211 zstd::stream::copy_encode(buff, &mut compressed, 4)?;
212 compressed
213 } else {
214 data.to_vec()
215 };
216
217 let data = if let Some(key_id) = self.db_config.default_key {
218 if let Some(key) = self.keys.get(&key_id) {
219 let nonce = key.nonce();
220 self.db_type
221 .set_file_nonce(&sha256, nonce.as_deref())
222 .await?;
223 key.encrypt(&data, nonce)?
224 } else {
225 bail!("Key not available!")
226 }
227 } else {
228 data
229 };
230
231 std::fs::write(dest_path, data)?;
232
233 Ok(true)
234 } else {
235 Ok(false)
236 }
237 }
238
239 pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
247 if let Some(dest_path) = &self.directory {
248 let path = format!(
249 "{}/{}/{}/{}",
250 &sha256[0..2],
251 &sha256[2..4],
252 &sha256[4..6],
253 sha256
254 );
255 let contents = std::fs::read(dest_path.join(path))?;
260
261 let contents = if self.keys.is_empty() {
262 contents
264 } else {
265 let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
266 if let Some(key_id) = key_id {
267 if let Some(key) = self.keys.get(&key_id) {
268 key.decrypt(&contents, nonce)?
269 } else {
270 bail!("File was encrypted but we don't have tke key!")
271 }
272 } else {
273 contents
275 }
276 };
277
278 if contents.starts_with(&GZIP_MAGIC) {
279 let buff = Cursor::new(contents);
280 let mut decompressor = GzDecoder::new(buff);
281 let mut decompressed: Vec<u8> = vec![];
282 decompressor.read_to_end(&mut decompressed)?;
283 Ok(decompressed)
284 } else if contents.starts_with(&ZSTD_MAGIC) {
285 let buff = Cursor::new(contents);
286 let mut decompressed: Vec<u8> = vec![];
287 zstd::stream::copy_decode(buff, &mut decompressed)?;
288 Ok(decompressed)
289 } else {
290 Ok(contents)
291 }
292 } else {
293 bail!("files are not saved")
294 }
295 }
296
297 #[must_use]
303 pub fn since(&self) -> Duration {
304 let now = SystemTime::now();
305 now.duration_since(self.started).unwrap()
306 }
307
308 pub async fn get_info(&self) -> Result<ServerInfo> {
314 let db_info = self.db_type.db_info().await?;
315 let uptime = Local::now() - self.since();
316 let mem_size = if let Some(mem_size) = app_memory_usage_fetcher::get_memory_usage_bytes() {
317 humansize::SizeFormatter::new(mem_size.get(), humansize::BINARY).to_string()
318 } else {
319 String::new()
320 };
321
322 Ok(ServerInfo {
323 os_name: std::env::consts::OS.into(),
324 memory_used: mem_size,
325 num_samples: db_info.num_files,
326 num_users: db_info.num_users,
327 uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
328 mdb_version: MDB_VERSION_SEMVER.clone(),
329 db_version: db_info.version,
330 db_size: db_info.size,
331 instance_name: self.db_config.name.clone(),
332 })
333 }
334
335 pub async fn serve(
347 self,
348 #[cfg(target_family = "windows")] rx: Option<tokio::sync::mpsc::Receiver<()>>,
349 ) -> Result<()> {
350 let socket = SocketAddr::new(self.ip, self.port);
351 let arc_self = Arc::new(self);
352 let db_ifo = arc_self.db_type.clone();
353
354 tokio::spawn(async move {
355 loop {
356 match db_ifo.cleanup().await {
357 Ok(removed) => {
358 trace!("Pagination cleanup succeeded, {removed} searches removed");
359 }
360 Err(e) => warn!("Pagination cleanup failed: {e}"),
361 }
362
363 tokio::time::sleep(DB_CLEANUP_INTERVAL).await;
364 }
365 });
366
367 if arc_self.mdns {
368 if arc_self.ip.is_loopback() {
369 debug!("Refusing to start mdns responder for localhost");
370 } else {
371 trace!("Enabling MDNS advertising");
372 match ServiceDaemon::new() {
373 Ok(mdns) => {
374 let host_name = format!("{}.local.", arc_self.ip);
375 let ssl = arc_self.cert.is_some() && arc_self.key.is_some();
376 let properties = [("ssl", ssl.to_string())];
377 match ServiceInfo::new(
378 malwaredb_api::MDNS_NAME,
379 &arc_self.db_config.name,
380 &host_name,
381 arc_self.ip.to_string(),
382 arc_self.port,
383 &properties[..],
384 ) {
385 Ok(service) => match mdns.register(service) {
386 Ok(()) => trace!("MalwareDB mdns registered"),
387 Err(e) => error!("Failed to register service: {e}"),
388 },
389 Err(e) => error!("Failed to publish MDNS service: {e}"),
390 }
391 }
392 Err(e) => error!("Failed to open port for MDNS responder: {e}"),
393 }
394 }
395 }
396
397 if arc_self.cert.is_some() && arc_self.key.is_some() {
398 let cert_path = arc_self.cert.as_ref().unwrap();
399 let key_path = arc_self.key.as_ref().unwrap();
400 let cert_ext_str = cert_path
401 .extension()
402 .context("failed to get certificate extension")?;
403 let key_ext_str = key_path
404 .extension()
405 .context("failed to get key extension")?;
406
407 if rustls::crypto::CryptoProvider::get_default().is_none() {
409 rustls::crypto::aws_lc_rs::default_provider()
410 .install_default()
411 .expect("Failed to load crypto provider");
412 }
413
414 let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem"
415 {
416 RustlsConfig::from_pem_file(cert_path, key_path)
417 .await
418 .context("failed to load or parse certificate and key files")?
419 } else if cert_ext_str == "der" && key_ext_str == "der" {
420 let cert_contents =
421 std::fs::read(cert_path).context("failed to read certificate file")?;
422 let key_contents =
423 std::fs::read(key_path).context("failed to read certificate file")?;
424 RustlsConfig::from_der(vec![cert_contents], key_contents)
425 .await
426 .context("failed to parse certificate and key files as DER")?
427 } else {
428 bail!("Unknown or unmatched certificate and key file extensions {cert_ext_str:?} and {key_ext_str:?}");
429 };
430
431 println!("Listening on https://{socket:?}");
432 let handle = axum_server::Handle::new();
433 let server_future = axum_server::bind_rustls(socket, config)
434 .serve(http::app(arc_self).into_make_service());
435 tokio::select! {
436 () = shutdown_signal(#[cfg(target_family = "windows")]rx) =>
437 handle.graceful_shutdown(Some(Duration::from_secs(30))),
438 res = server_future => res?,
439 }
440 warn!("Terminate signal received");
441 } else {
442 println!("Listening on http://{socket:?}");
443 let listener = TcpListener::bind(socket)
444 .await
445 .context(format!("failed to bind socket {socket}"))?;
446 axum::serve(listener, http::app(arc_self).into_make_service())
447 .with_graceful_shutdown(shutdown_signal(
448 #[cfg(target_family = "windows")]
449 rx,
450 ))
451 .await?;
452 warn!("Terminate signal received");
453 }
454 Ok(())
455 }
456}
457
458impl Debug for State {
459 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
460 write!(
461 f,
462 "MDB state, port {}, database {:?}",
463 self.port, self.db_type
464 )
465 }
466}
467
468async fn shutdown_signal(
471 #[cfg(target_family = "windows")] mut rx: Option<tokio::sync::mpsc::Receiver<()>>,
472) {
473 let ctrl_c = async {
474 tokio::signal::ctrl_c()
475 .await
476 .expect("failed to install Ctrl+C handler");
477 };
478
479 #[cfg(unix)]
480 let terminate = async {
481 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
482 .expect("failed to install signal handler")
483 .recv()
484 .await;
485 };
486
487 #[cfg(not(unix))]
488 let terminate = std::future::pending::<()>();
489
490 #[cfg(target_family = "windows")]
491 if let Some(rx_inner) = &mut rx {
492 let terminate_rx = rx_inner.recv();
493
494 tokio::select! {
495 () = ctrl_c => {},
496 () = terminate => {},
497 Some(()) = terminate_rx => {},
498 }
499 } else {
500 tokio::select! {
501 () = ctrl_c => {},
502 () = terminate => {},
503 }
504 }
505
506 #[cfg(not(target_family = "windows"))]
507 tokio::select! {
508 () = ctrl_c => {},
509 () = terminate => {},
510 }
511}