1#![doc = include_str!("../README.md")]
4#![deny(missing_docs)]
5#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
/// Cryptography support; this file uses its `FileEncryption` type for the per-key
/// encryption and decryption of stored samples.
pub mod crypto;

/// Database layer; this file uses its `DatabaseType` handle and `MDBConfig` settings.
pub mod db;

/// HTTP layer; this file uses its `app` function to build the service that
/// `State::serve` runs.
pub mod http;

/// Utility helpers (nothing from this module is used in this file).
pub mod utils;

/// Only compiled when the `vt` feature is enabled.
/// NOTE(review): presumably the VirusTotal integration, given the feature also gates
/// `State::vt_client` -- confirm against the module itself.
#[cfg(feature = "vt")]
pub mod vt;
23
24use crate::crypto::FileEncryption;
25use crate::db::MDBConfig;
26use malwaredb_api::ServerInfo;
27use std::collections::HashMap;
30use std::fmt::{Debug, Formatter};
31use std::io::{Cursor, Read};
32use std::net::{IpAddr, SocketAddr};
33use std::path::PathBuf;
34use std::sync::Arc;
35use std::time::{Duration, SystemTime};
36
37use anyhow::{bail, ensure, Context, Result};
38use axum_server::tls_rustls::RustlsConfig;
39use chrono::Local;
40use chrono_humanize::{Accuracy, HumanTime, Tense};
41use flate2::read::GzDecoder;
42use sha2::{Digest, Sha256};
43use tokio::net::TcpListener;
44
/// Version of this crate, taken from `CARGO_PKG_VERSION` at compile time.
pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
47
/// Leading bytes used to recognise gzip-compressed content; `State::retrieve_bytes`
/// runs a gzip decoder on stored data that starts with them.
pub const GZIP_MAGIC: [u8; 2] = [0x1fu8, 0x8bu8];
50
/// Leading bytes used to recognise Zstandard-compressed content; `State::retrieve_bytes`
/// runs a zstd decoder on stored data that starts with them.
pub const ZSTD_MAGIC: [u8; 4] = [0x28u8, 0xb5u8, 0x2fu8, 0xfdu8];
53
/// Everything the running server needs: listen address, storage location, database
/// handle and settings, encryption keys, and optional TLS material.
pub struct State {
    /// Port the server listens on (combined with `ip` in `State::serve`).
    pub port: u16,

    /// Root directory for stored samples. When `None`, `store_bytes` writes nothing and
    /// `retrieve_bytes` fails.
    pub directory: Option<PathBuf>,

    /// Upload size limit. It is only stored here; enforcement is not in this file.
    /// NOTE(review): presumably a byte count -- confirm in the HTTP layer.
    pub max_upload: usize,

    /// IP address the server listens on.
    pub ip: IpAddr,

    /// Handle to the database, created from the connection string given to `State::new`.
    pub db_type: db::DatabaseType,

    /// Wall-clock time at which this state was created; basis for `State::since`.
    pub started: SystemTime,

    /// Configuration read from the database in `State::new`. This file uses its
    /// `compression` and `default_key` settings when storing samples.
    pub db_config: MDBConfig,

    /// Encryption keys read from the database in `State::new`, indexed by key ID.
    pub(crate) keys: HashMap<u32, FileEncryption>,

    /// Optional VirusTotal client; the field only exists with the `vt` feature.
    #[cfg(feature = "vt")]
    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// Path to the HTTPS certificate file. `State::new` requires `cert` and `key` to be
    /// either both present or both absent.
    cert: Option<PathBuf>,

    /// Path to the HTTPS private key file that goes with `cert`.
    key: Option<PathBuf>,
}
91
92impl State {
93 #[allow(clippy::too_many_arguments)]
100 pub async fn new(
101 port: u16,
102 directory: Option<PathBuf>,
103 max_upload: usize,
104 ip: IpAddr,
105 db_string: &str,
106 cert: Option<PathBuf>,
107 key: Option<PathBuf>,
108 pg_cert: Option<PathBuf>,
109 #[cfg(feature = "vt")] vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
110 ) -> Result<Self> {
111 if let Some(dir) = &directory {
112 if !dir.exists() {
113 bail!("data directory {dir:?} does not exist!");
114 }
115 }
116
117 if cert.is_some() != key.is_some() {
118 bail!("Either both the https cert & key are provided, or none are provided.");
119 }
120
121 if let Some(cert_file) = &cert {
122 ensure!(
123 cert_file.exists(),
124 "Certificate file {cert_file:?} does not exist!"
125 );
126 }
127
128 if let Some(key_file) = &key {
129 ensure!(key_file.exists(), "Key file {key_file:?} does not exist!");
130 }
131
132 let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
133 let db_config = db_type.get_config().await?;
135 let keys = db_type.get_encryption_keys().await?;
136
137 Ok(Self {
138 port,
139 directory,
140 max_upload,
141 ip,
142 db_type,
143 db_config,
144 keys,
145 #[cfg(feature = "vt")]
146 vt_client,
147 started: SystemTime::now(),
148 cert,
149 key,
150 })
151 }
152
153 pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
160 if let Some(dest_path) = &self.directory {
161 let mut hasher = Sha256::new();
162 hasher.update(data);
163 let sha256 = hex::encode(hasher.finalize());
164
165 let hashed_path = format!(
169 "{}/{}/{}/{}",
170 &sha256[0..2],
171 &sha256[2..4],
172 &sha256[4..6],
173 sha256
174 );
175
176 let mut dest_path = dest_path.clone();
179 dest_path.push(hashed_path);
180
181 let mut just_the_dir = dest_path.clone();
183 just_the_dir.pop();
184 std::fs::create_dir_all(just_the_dir)?;
185
186 let data = if self.db_config.compression {
187 let buff = Cursor::new(data);
188 let mut compressed = Vec::with_capacity(data.len() / 2);
189 zstd::stream::copy_encode(buff, &mut compressed, 4)?;
190 compressed
191 } else {
192 data.to_vec()
193 };
194
195 let data = if let Some(key_id) = self.db_config.default_key {
196 if let Some(key) = self.keys.get(&key_id) {
197 let nonce = key.nonce();
198 self.db_type
199 .set_file_nonce(&sha256, nonce.as_deref())
200 .await?;
201 key.encrypt(&data, nonce)?
202 } else {
203 bail!("Key not available!")
204 }
205 } else {
206 data
207 };
208
209 std::fs::write(dest_path, data)?;
210
211 Ok(true)
212 } else {
213 Ok(false)
214 }
215 }
216
217 pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
225 if let Some(dest_path) = &self.directory {
226 let path = format!(
227 "{}/{}/{}/{}",
228 &sha256[0..2],
229 &sha256[2..4],
230 &sha256[4..6],
231 sha256
232 );
233 let contents = std::fs::read(dest_path.join(path))?;
238
239 let contents = if self.keys.is_empty() {
240 contents
242 } else {
243 let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
244 if let Some(key_id) = key_id {
245 if let Some(key) = self.keys.get(&key_id) {
246 key.decrypt(&contents, nonce)?
247 } else {
248 bail!("File was encrypted but we don't have tke key!")
249 }
250 } else {
251 contents
253 }
254 };
255
256 if contents.starts_with(&GZIP_MAGIC) {
257 let buff = Cursor::new(contents);
258 let mut decompressor = GzDecoder::new(buff);
259 let mut decompressed: Vec<u8> = vec![];
260 decompressor.read_to_end(&mut decompressed)?;
261 Ok(decompressed)
262 } else if contents.starts_with(&ZSTD_MAGIC) {
263 let buff = Cursor::new(contents);
264 let mut decompressed: Vec<u8> = vec![];
265 zstd::stream::copy_decode(buff, &mut decompressed)?;
266 Ok(decompressed)
267 } else {
268 Ok(contents)
269 }
270 } else {
271 bail!("files are not saved")
272 }
273 }
274
275 #[must_use]
281 pub fn since(&self) -> Duration {
282 let now = SystemTime::now();
283 now.duration_since(self.started).unwrap()
284 }
285
286 pub async fn get_info(&self) -> Result<ServerInfo> {
292 let db_info = self.db_type.db_info().await?;
293 let uptime = Local::now() - self.since();
294 let mem_size = if let Some(mem_size) = app_memory_usage_fetcher::get_memory_usage_bytes() {
295 humansize::SizeFormatter::new(mem_size.get(), humansize::BINARY).to_string()
296 } else {
297 String::new()
298 };
299
300 Ok(ServerInfo {
301 os_name: std::env::consts::OS.into(),
302 memory_used: mem_size,
303 num_samples: db_info.num_files,
304 num_users: db_info.num_users,
305 uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
306 mdb_version: MDB_VERSION.into(),
307 db_version: db_info.version,
308 db_size: db_info.size,
309 })
310 }
311
312 pub async fn serve(self) -> Result<()> {
324 let socket = SocketAddr::new(self.ip, self.port);
325
326 if self.cert.is_some() && self.key.is_some() {
327 let cert_path = self.cert.as_ref().unwrap();
328 let key_path = self.key.as_ref().unwrap();
329 let cert_ext_str = cert_path
330 .extension()
331 .context("failed to get certificate extension")?;
332 let key_ext_str = key_path
333 .extension()
334 .context("failed to get key extension")?;
335
336 if rustls::crypto::CryptoProvider::get_default().is_none() {
338 rustls::crypto::ring::default_provider()
339 .install_default()
340 .expect("Failed to load crypto provider");
341 }
342
343 let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem"
344 {
345 RustlsConfig::from_pem_file(cert_path, key_path)
346 .await
347 .context("failed to load or parse certificate and key files")?
348 } else if cert_ext_str == "der" && key_ext_str == "der" {
349 let cert_contents =
350 std::fs::read(cert_path).context("failed to read certificate file")?;
351 let key_contents =
352 std::fs::read(key_path).context("failed to read certificate file")?;
353 RustlsConfig::from_der(vec![cert_contents], key_contents)
354 .await
355 .context("failed to parse certificate and key files as DER")?
356 } else {
357 bail!("Unknown or unmatched certificate and key file extensions {cert_ext_str:?} and {key_ext_str:?}");
358 };
359
360 println!("Listening on https://{socket:?}");
361 axum_server::bind_rustls(socket, config)
362 .serve(http::app(Arc::new(self)).into_make_service())
363 .await?;
364 } else {
365 println!("Listening on http://{socket:?}");
366 let listener = TcpListener::bind(socket)
367 .await
368 .context(format!("failed to bind socket {socket}"))?;
369 axum::serve(listener, http::app(Arc::new(self)).into_make_service()).await?;
370 }
371 Ok(())
372 }
373}
374
375impl Debug for State {
376 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
377 write!(
378 f,
379 "MDB state, port {}, database {:?}",
380 self.port, self.db_type
381 )
382 }
383}