1#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8
9pub mod crypto;
11
12pub mod db;
14
15pub mod http;
17
18pub mod utils;
20
21#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
23#[cfg(feature = "vt")]
24pub mod vt;
25
26use crate::crypto::FileEncryption;
27use crate::db::MDBConfig;
28use malwaredb_api::ServerInfo;
29use std::collections::HashMap;
32use std::fmt::{Debug, Formatter};
33use std::io::{Cursor, Read};
34use std::net::{IpAddr, SocketAddr};
35use std::path::PathBuf;
36use std::sync::{Arc, LazyLock};
37use std::time::{Duration, SystemTime};
38
39use anyhow::{bail, ensure, Context, Result};
40use axum_server::tls_rustls::RustlsConfig;
41use chrono::Local;
42use chrono_humanize::{Accuracy, HumanTime, Tense};
43use flate2::read::GzDecoder;
44use mdns_sd::{ServiceDaemon, ServiceInfo};
45use sha2::{Digest, Sha256};
46use tokio::net::TcpListener;
47use tracing::{debug, error, trace, warn};
48
49pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
51
52pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
54 LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
55
/// Pause between two runs of the periodic database cleanup task: 24 hours.
pub(crate) const DB_CLEANUP_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);

/// Leading bytes of a gzip stream (`1f 8b`), used to recognise gzip-compressed files.
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];

/// Leading bytes of a Zstandard frame (`28 b5 2f fd`), used to recognise
/// zstd-compressed files.
pub const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];
65
66pub struct State {
68 pub port: u16,
70
71 pub directory: Option<PathBuf>,
73
74 pub max_upload: usize,
76
77 pub ip: IpAddr,
79
80 pub db_type: Arc<db::DatabaseType>,
82
83 pub started: SystemTime,
85
86 pub db_config: MDBConfig,
88
89 pub(crate) keys: HashMap<u32, FileEncryption>,
91
92 #[cfg(feature = "vt")]
94 pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
95
96 cert: Option<PathBuf>,
99
100 key: Option<PathBuf>,
102
103 mdns: bool,
105}
106
107impl State {
108 #[allow(clippy::too_many_arguments)]
115 pub async fn new(
116 port: u16,
117 directory: Option<PathBuf>,
118 max_upload: usize,
119 ip: IpAddr,
120 db_string: &str,
121 cert: Option<PathBuf>,
122 key: Option<PathBuf>,
123 pg_cert: Option<PathBuf>,
124 mdns: bool,
125 #[cfg(feature = "vt")] vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
126 ) -> Result<Self> {
127 if let Some(dir) = &directory {
128 if !dir.exists() {
129 bail!("data directory {} does not exist!", dir.display());
130 }
131 }
132
133 if cert.is_some() != key.is_some() {
134 bail!("Either both the https cert & key are provided, or none are provided.");
135 }
136
137 if let Some(cert_file) = &cert {
138 ensure!(
139 cert_file.exists(),
140 "Certificate file {} does not exist!",
141 cert_file.display()
142 );
143 }
144
145 if let Some(key_file) = &key {
146 ensure!(
147 key_file.exists(),
148 "Key file {} does not exist!",
149 key_file.display()
150 );
151 }
152
153 let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
154 let db_config = db_type.get_config().await?;
156 let keys = db_type.get_encryption_keys().await?;
157
158 Ok(Self {
159 port,
160 directory,
161 max_upload,
162 ip,
163 db_type: Arc::new(db_type),
164 db_config,
165 keys,
166 #[cfg(feature = "vt")]
167 vt_client,
168 started: SystemTime::now(),
169 cert,
170 key,
171 mdns,
172 })
173 }
174
175 pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
182 if let Some(dest_path) = &self.directory {
183 let mut hasher = Sha256::new();
184 hasher.update(data);
185 let sha256 = hex::encode(hasher.finalize());
186
187 let hashed_path = format!(
191 "{}/{}/{}/{}",
192 &sha256[0..2],
193 &sha256[2..4],
194 &sha256[4..6],
195 sha256
196 );
197
198 let mut dest_path = dest_path.clone();
201 dest_path.push(hashed_path);
202
203 let mut just_the_dir = dest_path.clone();
205 just_the_dir.pop();
206 std::fs::create_dir_all(just_the_dir)?;
207
208 let data = if self.db_config.compression {
209 let buff = Cursor::new(data);
210 let mut compressed = Vec::with_capacity(data.len() / 2);
211 zstd::stream::copy_encode(buff, &mut compressed, 4)?;
212 compressed
213 } else {
214 data.to_vec()
215 };
216
217 let data = if let Some(key_id) = self.db_config.default_key {
218 if let Some(key) = self.keys.get(&key_id) {
219 let nonce = key.nonce();
220 self.db_type
221 .set_file_nonce(&sha256, nonce.as_deref())
222 .await?;
223 key.encrypt(&data, nonce)?
224 } else {
225 bail!("Key not available!")
226 }
227 } else {
228 data
229 };
230
231 std::fs::write(dest_path, data)?;
232
233 Ok(true)
234 } else {
235 Ok(false)
236 }
237 }
238
239 pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
247 if let Some(dest_path) = &self.directory {
248 let path = format!(
249 "{}/{}/{}/{}",
250 &sha256[0..2],
251 &sha256[2..4],
252 &sha256[4..6],
253 sha256
254 );
255 let contents = std::fs::read(dest_path.join(path))?;
260
261 let contents = if self.keys.is_empty() {
262 contents
264 } else {
265 let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
266 if let Some(key_id) = key_id {
267 if let Some(key) = self.keys.get(&key_id) {
268 key.decrypt(&contents, nonce)?
269 } else {
270 bail!("File was encrypted but we don't have tke key!")
271 }
272 } else {
273 contents
275 }
276 };
277
278 if contents.starts_with(&GZIP_MAGIC) {
279 let buff = Cursor::new(contents);
280 let mut decompressor = GzDecoder::new(buff);
281 let mut decompressed: Vec<u8> = vec![];
282 decompressor.read_to_end(&mut decompressed)?;
283 Ok(decompressed)
284 } else if contents.starts_with(&ZSTD_MAGIC) {
285 let buff = Cursor::new(contents);
286 let mut decompressed: Vec<u8> = vec![];
287 zstd::stream::copy_decode(buff, &mut decompressed)?;
288 Ok(decompressed)
289 } else {
290 Ok(contents)
291 }
292 } else {
293 bail!("files are not saved")
294 }
295 }
296
297 #[must_use]
303 pub fn since(&self) -> Duration {
304 let now = SystemTime::now();
305 now.duration_since(self.started).unwrap()
306 }
307
308 pub async fn get_info(&self) -> Result<ServerInfo> {
314 let db_info = self.db_type.db_info().await?;
315 let uptime = Local::now() - self.since();
316 let mem_size = if let Some(mem_size) = app_memory_usage_fetcher::get_memory_usage_bytes() {
317 humansize::SizeFormatter::new(mem_size.get(), humansize::BINARY).to_string()
318 } else {
319 String::new()
320 };
321
322 Ok(ServerInfo {
323 os_name: std::env::consts::OS.into(),
324 memory_used: mem_size,
325 num_samples: db_info.num_files,
326 num_users: db_info.num_users,
327 uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
328 mdb_version: MDB_VERSION_SEMVER.clone(),
329 db_version: db_info.version,
330 db_size: db_info.size,
331 instance_name: self.db_config.name.clone(),
332 })
333 }
334
335 pub async fn serve(self) -> Result<()> {
347 let socket = SocketAddr::new(self.ip, self.port);
348 let arc_self = Arc::new(self);
349 let db_ifo = arc_self.db_type.clone();
350
351 tokio::spawn(async move {
352 loop {
353 match db_ifo.cleanup().await {
354 Ok(removed) => {
355 trace!("Pagination cleanup succeeded, {removed} searches removed");
356 }
357 Err(e) => warn!("Pagination cleanup failed: {e}"),
358 }
359
360 tokio::time::sleep(DB_CLEANUP_INTERVAL).await;
361 }
362 });
363
364 if arc_self.mdns {
365 if arc_self.ip.is_loopback() {
366 debug!("Refusing to start mdns responder for localhost");
367 } else {
368 trace!("Enabling MDNS advertising");
369 match ServiceDaemon::new() {
370 Ok(mdns) => {
371 let host_name = format!("{}.local.", arc_self.ip);
372 let ssl = arc_self.cert.is_some() && arc_self.key.is_some();
373 let properties = [("ssl", ssl.to_string())];
374 match ServiceInfo::new(
375 malwaredb_api::MDNS_NAME,
376 &arc_self.db_config.name,
377 &host_name,
378 arc_self.ip.to_string(),
379 arc_self.port,
380 &properties[..],
381 ) {
382 Ok(service) => match mdns.register(service) {
383 Ok(()) => trace!("MalwareDB mdns registered"),
384 Err(e) => error!("Failed to register service: {e}"),
385 },
386 Err(e) => error!("Failed to publish MDNS service: {e}"),
387 }
388 }
389 Err(e) => error!("Failed to open port for MDNS responder: {e}"),
390 }
391 }
392 }
393
394 if arc_self.cert.is_some() && arc_self.key.is_some() {
395 let cert_path = arc_self.cert.as_ref().unwrap();
396 let key_path = arc_self.key.as_ref().unwrap();
397 let cert_ext_str = cert_path
398 .extension()
399 .context("failed to get certificate extension")?;
400 let key_ext_str = key_path
401 .extension()
402 .context("failed to get key extension")?;
403
404 if rustls::crypto::CryptoProvider::get_default().is_none() {
406 rustls::crypto::aws_lc_rs::default_provider()
407 .install_default()
408 .expect("Failed to load crypto provider");
409 }
410
411 let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem"
412 {
413 RustlsConfig::from_pem_file(cert_path, key_path)
414 .await
415 .context("failed to load or parse certificate and key files")?
416 } else if cert_ext_str == "der" && key_ext_str == "der" {
417 let cert_contents =
418 std::fs::read(cert_path).context("failed to read certificate file")?;
419 let key_contents =
420 std::fs::read(key_path).context("failed to read certificate file")?;
421 RustlsConfig::from_der(vec![cert_contents], key_contents)
422 .await
423 .context("failed to parse certificate and key files as DER")?
424 } else {
425 bail!("Unknown or unmatched certificate and key file extensions {cert_ext_str:?} and {key_ext_str:?}");
426 };
427
428 println!("Listening on https://{socket:?}");
429 axum_server::bind_rustls(socket, config)
430 .serve(http::app(arc_self).into_make_service())
431 .await?;
432 } else {
433 println!("Listening on http://{socket:?}");
434 let listener = TcpListener::bind(socket)
435 .await
436 .context(format!("failed to bind socket {socket}"))?;
437 axum::serve(listener, http::app(arc_self).into_make_service()).await?;
438 }
439 Ok(())
440 }
441}
442
443impl Debug for State {
444 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
445 write!(
446 f,
447 "MDB state, port {}, database {:?}",
448 self.port, self.db_type
449 )
450 }
451}