#![doc = include_str!("../README.md")]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(missing_docs)]
#![deny(clippy::all)]
#![deny(clippy::pedantic)]

// Cryptography support; provides `FileEncryption`, used below to encrypt stored samples.
pub mod crypto;

// Database backends and configuration (`DatabaseType`, `MDBConfig`).
pub mod db;

// HTTP layer; `http::app` builds the application that `State::serve` runs.
pub mod http;

// Utility items (contents not visible from this file).
pub mod utils;

// Presumably the VirusTotal integration; compiled only with the `vt` feature, the same
// feature that enables the `malwaredb_virustotal` client below.
#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
#[cfg(feature = "vt")]
pub mod vt;
25
26use crate::crypto::FileEncryption;
27use crate::db::MDBConfig;
28use malwaredb_api::ServerInfo;
29use std::collections::HashMap;
32use std::fmt::{Debug, Formatter};
33use std::io::{Cursor, Read};
34use std::net::{IpAddr, Ipv4Addr, SocketAddr};
35use std::path::PathBuf;
36use std::sync::{Arc, LazyLock};
37use std::time::{Duration, SystemTime};
38
39use anyhow::{anyhow, bail, ensure, Context, Result};
40use axum_server::tls_rustls::RustlsConfig;
41use chrono::Local;
42use chrono_humanize::{Accuracy, HumanTime, Tense};
43use flate2::read::GzDecoder;
44use mdns_sd::{ServiceDaemon, ServiceInfo};
45use sha2::{Digest, Sha256};
46use tokio::net::TcpListener;
47use tracing::{trace, warn};
48
49pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
51
52pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
54 LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
55
/// Delay between runs of the background database cleanup task started by `State::serve`
/// (one day).
pub(crate) const DB_CLEANUP_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
59
/// First two bytes of every gzip stream (RFC 1952 `ID1`, `ID2`).
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];

/// First four bytes of every zstd frame: the magic number `0xFD2FB528` in little-endian order.
pub const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];
65
/// Builder for [`State`].
///
/// Collects server settings; [`StateBuilder::into_state`] then loads the instance
/// configuration and encryption keys from the database.
pub struct StateBuilder {
    /// TCP port to listen on (default `8080`).
    pub port: u16,

    /// Directory in which uploaded samples are stored. `None` (the default) means samples
    /// are not written to disk.
    pub directory: Option<PathBuf>,

    /// Maximum upload size (default `104_857_600`, i.e. 100 MiB if the unit is bytes —
    /// TODO confirm the unit against the HTTP layer).
    pub max_upload: usize,

    /// IP address to bind to (default `127.0.0.1`).
    pub ip: IpAddr,

    /// Database backend created from the connection string given to [`StateBuilder::new`].
    db_type: db::DatabaseType,

    /// Optional `VirusTotal` client, set with [`StateBuilder::vt_client`].
    #[cfg(feature = "vt")]
    vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// TLS configuration set by [`StateBuilder::tls`]; `None` means plain HTTP.
    tls_config: Option<RustlsConfig>,

    /// Whether an mDNS daemon should be created (see [`StateBuilder::enable_mdns`]).
    mdns: bool,
}
93
94impl StateBuilder {
95 pub async fn new(db_string: &str, pg_cert: Option<PathBuf>) -> Result<Self> {
103 let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
104
105 Ok(Self {
106 port: 8080,
107 directory: None,
108 max_upload: 104_857_600, ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
110 db_type,
111 #[cfg(feature = "vt")]
112 vt_client: None,
113 tls_config: None,
114 mdns: false,
115 })
116 }
117
118 #[must_use]
121 pub fn port(mut self, port: u16) -> Self {
122 self.port = port;
123 self
124 }
125
126 #[must_use]
130 pub fn directory(mut self, directory: PathBuf) -> Self {
131 self.directory = Some(directory);
132 self
133 }
134
135 #[must_use]
138 pub fn max_upload(mut self, max_upload: usize) -> Self {
139 self.max_upload = max_upload;
140 self
141 }
142
143 #[must_use]
146 pub fn ip(mut self, ip: IpAddr) -> Self {
147 self.ip = ip;
148 self
149 }
150
151 #[must_use]
153 #[cfg(feature = "vt")]
154 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
155 pub fn vt_client(mut self, vt_client: malwaredb_virustotal::VirusTotalClient) -> Self {
156 self.vt_client = Some(vt_client);
157 self
158 }
159
160 pub async fn tls(mut self, cert_file: PathBuf, key_file: PathBuf) -> Result<Self> {
167 ensure!(
168 cert_file.exists(),
169 "Certificate file {} does not exist!",
170 cert_file.display()
171 );
172
173 ensure!(
174 key_file.exists(),
175 "Key file {} does not exist!",
176 key_file.display()
177 );
178
179 let cert_ext_str = cert_file
180 .extension()
181 .context("failed to get certificate extension")?;
182 let key_ext_str = key_file
183 .extension()
184 .context("failed to get key extension")?;
185
186 if rustls::crypto::CryptoProvider::get_default().is_none() {
188 rustls::crypto::aws_lc_rs::default_provider()
189 .install_default()
190 .map_err(|_| anyhow!("failed to install AWS-LC crypto provider"))?;
191 }
192
193 let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem" {
194 RustlsConfig::from_pem_file(cert_file, key_file)
195 .await
196 .context("failed to load or parse certificate and key pem files")?
197 } else if cert_ext_str == "der" && key_ext_str == "der" {
198 let cert_contents =
199 std::fs::read(cert_file).context("failed to read certificate file")?;
200 let key_contents =
201 std::fs::read(key_file).context("failed to read private key file")?;
202 RustlsConfig::from_der(vec![cert_contents], key_contents)
203 .await
204 .context("failed to parse certificate and key der files")?
205 } else {
206 bail!(
207 "Unknown or unmatched certificate and key file extensions {} and {}",
208 cert_ext_str.display(),
209 key_ext_str.display()
210 );
211 };
212
213 self.tls_config = Some(config);
214 Ok(self)
215 }
216
217 #[must_use]
220 pub fn enable_mdns(mut self) -> Self {
221 self.mdns = true;
222 self
223 }
224
225 pub async fn into_state(self) -> Result<State> {
231 let db_config = self.db_type.get_config().await?;
232 let keys = self.db_type.get_encryption_keys().await?;
233
234 Ok(State {
235 port: self.port,
236 directory: self.directory,
237 max_upload: self.max_upload,
238 ip: self.ip,
239 db_type: Arc::new(self.db_type),
240 started: SystemTime::now(),
241 db_config,
242 keys,
243 #[cfg(feature = "vt")]
244 vt_client: self.vt_client,
245 tls_config: self.tls_config,
246 mdns: if self.mdns {
247 Some(ServiceDaemon::new()?)
248 } else {
249 None
250 },
251 })
252 }
253}
254
/// Runtime state of a server instance.
///
/// Built by [`StateBuilder::into_state`]. [`State::serve`] wraps it in an [`Arc`] and hands it
/// to the HTTP application.
pub struct State {
    /// TCP port the server listens on.
    pub port: u16,

    /// Directory where samples are stored; `None` means samples are not saved to disk.
    pub directory: Option<PathBuf>,

    /// Maximum upload size (presumably bytes — TODO confirm against the HTTP layer).
    pub max_upload: usize,

    /// IP address the server binds to.
    pub ip: IpAddr,

    /// Database backend. It is behind an [`Arc`] so the background cleanup task can share it.
    pub db_type: Arc<db::DatabaseType>,

    /// Wall-clock time at which this state was created; used to compute uptime.
    pub started: SystemTime,

    /// Instance configuration loaded from the database when the state was built.
    pub db_config: MDBConfig,

    /// File encryption keys loaded from the database, indexed by key id.
    pub(crate) keys: HashMap<u32, FileEncryption>,

    /// Optional `VirusTotal` client.
    #[cfg(feature = "vt")]
    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// TLS configuration; `Some` means the server listens with HTTPS.
    tls_config: Option<RustlsConfig>,

    /// mDNS daemon, present only when mDNS was enabled on the builder.
    mdns: Option<ServiceDaemon>,
}
291
292impl State {
293 pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
300 if let Some(dest_path) = &self.directory {
301 let mut hasher = Sha256::new();
302 hasher.update(data);
303 let sha256 = hex::encode(hasher.finalize());
304
305 let hashed_path = format!(
309 "{}/{}/{}/{}",
310 &sha256[0..2],
311 &sha256[2..4],
312 &sha256[4..6],
313 sha256
314 );
315
316 let mut dest_path = dest_path.clone();
319 dest_path.push(hashed_path);
320
321 let mut just_the_dir = dest_path.clone();
323 just_the_dir.pop();
324 std::fs::create_dir_all(just_the_dir)?;
325
326 let data = if self.db_config.compression {
327 let buff = Cursor::new(data);
328 let mut compressed = Vec::with_capacity(data.len() / 2);
329 zstd::stream::copy_encode(buff, &mut compressed, 4)?;
330 compressed
331 } else {
332 data.to_vec()
333 };
334
335 let data = if let Some(key_id) = self.db_config.default_key {
336 if let Some(key) = self.keys.get(&key_id) {
337 let nonce = key.nonce();
338 self.db_type
339 .set_file_nonce(&sha256, nonce.as_deref())
340 .await?;
341 key.encrypt(&data, nonce)?
342 } else {
343 bail!("Key not available!")
344 }
345 } else {
346 data
347 };
348
349 std::fs::write(dest_path, data)?;
350
351 Ok(true)
352 } else {
353 Ok(false)
354 }
355 }
356
357 pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
365 if let Some(dest_path) = &self.directory {
366 let path = format!(
367 "{}/{}/{}/{}",
368 &sha256[0..2],
369 &sha256[2..4],
370 &sha256[4..6],
371 sha256
372 );
373 let contents = std::fs::read(dest_path.join(path))?;
378
379 let contents = if self.keys.is_empty() {
380 contents
382 } else {
383 let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
384 if let Some(key_id) = key_id {
385 if let Some(key) = self.keys.get(&key_id) {
386 key.decrypt(&contents, nonce)?
387 } else {
388 bail!("File was encrypted but we don't have tke key!")
389 }
390 } else {
391 contents
393 }
394 };
395
396 if contents.starts_with(&GZIP_MAGIC) {
397 let buff = Cursor::new(contents);
398 let mut decompressor = GzDecoder::new(buff);
399 let mut decompressed: Vec<u8> = vec![];
400 decompressor.read_to_end(&mut decompressed)?;
401 Ok(decompressed)
402 } else if contents.starts_with(&ZSTD_MAGIC) {
403 let buff = Cursor::new(contents);
404 let mut decompressed: Vec<u8> = vec![];
405 zstd::stream::copy_decode(buff, &mut decompressed)?;
406 Ok(decompressed)
407 } else {
408 Ok(contents)
409 }
410 } else {
411 bail!("files are not saved")
412 }
413 }
414
415 #[must_use]
421 pub fn since(&self) -> Duration {
422 let now = SystemTime::now();
423 now.duration_since(self.started).unwrap()
424 }
425
426 pub async fn get_info(&self) -> Result<ServerInfo> {
432 let db_info = self.db_type.db_info().await?;
433 let uptime = Local::now() - self.since();
434 let mem_size = app_memory_usage_fetcher::get_memory_usage_string().unwrap_or_default();
435
436 Ok(ServerInfo {
437 os_name: std::env::consts::OS.into(),
438 memory_used: mem_size,
439 num_samples: db_info.num_files,
440 num_users: db_info.num_users,
441 uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
442 mdb_version: MDB_VERSION_SEMVER.clone(),
443 db_version: db_info.version,
444 db_size: db_info.size,
445 instance_name: self.db_config.name.clone(),
446 })
447 }
448
449 pub async fn serve(
457 self,
458 #[cfg(target_family = "windows")] rx: Option<tokio::sync::mpsc::Receiver<()>>,
459 ) -> Result<()> {
460 let socket = SocketAddr::new(self.ip, self.port);
461 let arc_self = Arc::new(self);
462 let db_info = arc_self.db_type.clone();
463
464 tokio::spawn(async move {
465 loop {
466 match db_info.cleanup().await {
467 Ok(removed) => {
468 trace!("Pagination cleanup succeeded, {removed} searches removed");
469 }
470 Err(e) => warn!("Pagination cleanup failed: {e}"),
471 }
472
473 tokio::time::sleep(DB_CLEANUP_INTERVAL).await;
474 }
475 });
476
477 if arc_self.mdns.is_some() && !arc_self.ip.is_loopback() {
478 if let Err(e) = arc_self.mdns_register().await {
479 warn!("Failed to register MDNS service: {e}");
480 }
481 }
482
483 if let Some(tls_config) = arc_self.tls_config.clone() {
484 println!("Listening on https://{socket:?}");
485 let handle = axum_server::Handle::<SocketAddr>::new();
486 let server_future = axum_server::bind_rustls(socket, tls_config)
487 .serve(http::app(arc_self).into_make_service());
488 tokio::select! {
489 () = shutdown_signal(#[cfg(target_family = "windows")]rx) =>
490 handle.graceful_shutdown(Some(Duration::from_secs(30))),
491 res = server_future => res?,
492 }
493 warn!("Terminate signal received");
494 } else {
495 println!("Listening on http://{socket:?}");
496 let listener = TcpListener::bind(socket)
497 .await
498 .context(format!("failed to bind socket {socket}"))?;
499 axum::serve(listener, http::app(arc_self).into_make_service())
500 .with_graceful_shutdown(shutdown_signal(
501 #[cfg(target_family = "windows")]
502 rx,
503 ))
504 .await?;
505 warn!("Terminate signal received");
506 }
507 Ok(())
508 }
509
510 async fn mdns_register(&self) -> Result<()> {
513 if let Some(mdns) = &self.mdns {
514 let db_config = self.db_type.get_config().await?;
515 let host_name = format!("{}.local.", self.ip);
516 let ssl = self.tls_config.is_some();
517 let properties = [("ssl", ssl.to_string())];
518 let service = ServiceInfo::new(
519 malwaredb_api::MDNS_NAME,
520 &db_config.name,
521 &host_name,
522 self.ip.to_string(),
523 self.port,
524 &properties[..],
525 )?;
526 mdns.register(service)?;
527 }
528
529 Ok(())
530 }
531}
532
533impl Debug for State {
534 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
535 let tls_mode = if self.tls_config.is_some() {
536 ", TLS mode"
537 } else {
538 ""
539 };
540 write!(
541 f,
542 "MDB state, port {}, database {:?}{tls_mode}",
543 self.port, self.db_type
544 )
545 }
546}
547
548async fn shutdown_signal(
551 #[cfg(target_family = "windows")] mut rx: Option<tokio::sync::mpsc::Receiver<()>>,
552) {
553 let ctrl_c = async {
554 tokio::signal::ctrl_c()
555 .await
556 .expect("failed to install Ctrl+C handler");
557 };
558
559 #[cfg(unix)]
560 let terminate = async {
561 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
562 .expect("failed to install signal handler")
563 .recv()
564 .await;
565 };
566
567 #[cfg(not(unix))]
568 let terminate = std::future::pending::<()>();
569
570 #[cfg(target_family = "windows")]
571 if let Some(rx_inner) = &mut rx {
572 let terminate_rx = rx_inner.recv();
573
574 tokio::select! {
575 () = ctrl_c => {},
576 () = terminate => {},
577 Some(()) = terminate_rx => {},
578 }
579 } else {
580 tokio::select! {
581 () = ctrl_c => {},
582 () = terminate => {},
583 }
584 }
585
586 #[cfg(not(target_family = "windows"))]
587 tokio::select! {
588 () = ctrl_c => {},
589 () = terminate => {},
590 }
591}