1#![doc = include_str!("../README.md")]
4#![deny(missing_docs)]
5#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8pub mod crypto;
10
11pub mod db;
13
14pub mod http;
16
17pub mod utils;
19
20#[cfg(feature = "vt")]
22pub mod vt;
23
24use crate::crypto::FileEncryption;
25use crate::db::MDBConfig;
26use malwaredb_api::ServerInfo;
27use std::collections::HashMap;
30use std::fmt::{Debug, Formatter};
31use std::io::{Cursor, Read, Write};
32use std::net::{IpAddr, SocketAddr};
33use std::path::PathBuf;
34use std::sync::Arc;
35use std::time::{Duration, SystemTime};
36
37use anyhow::{bail, ensure, Context, Result};
38use axum_server::tls_rustls::RustlsConfig;
39use chrono::Local;
40use chrono_humanize::{Accuracy, HumanTime, Tense};
41use flate2::read::GzDecoder;
42use flate2::write::GzEncoder;
43use flate2::Compression;
44use sha2::{Digest, Sha256};
45use tokio::net::TcpListener;
46
47pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
49
/// The two magic bytes (`1f 8b`) at the start of every gzip stream; used to recognise stored
/// samples that were compressed.
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];
52
/// Server state: network settings, sample storage location, database handle, loaded encryption
/// keys and optional TLS material.
///
/// Created with [`State::new`] and consumed by [`State::serve`], which hands it to the HTTP
/// application wrapped in an `Arc`.
pub struct State {
    /// TCP port the server listens on.
    pub port: u16,

    /// Directory where sample files are stored; `None` disables file storage.
    pub directory: Option<PathBuf>,

    /// Maximum upload size. Not used in this file; presumably enforced by the HTTP layer and
    /// measured in bytes (TODO confirm).
    pub max_upload: usize,

    /// IP address the server binds to.
    pub ip: IpAddr,

    /// Database handle, used for configuration, server statistics and per-file encryption data.
    pub db_type: db::DatabaseType,

    /// Time the state was created; uptime is measured from here.
    pub started: SystemTime,

    /// Configuration loaded from the database, including whether samples are compressed and
    /// which key (if any) encrypts newly stored samples.
    pub db_config: MDBConfig,

    /// Encryption keys loaded from the database, indexed by key ID.
    pub(crate) keys: HashMap<u32, FileEncryption>,

    /// Optional `VirusTotal` client.
    #[cfg(feature = "vt")]
    pub(crate) vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// Path to the TLS certificate; `State::new` only accepts it together with `key`.
    cert: Option<PathBuf>,

    /// Path to the TLS private key; `State::new` only accepts it together with `cert`.
    key: Option<PathBuf>,
}
90
91impl State {
92 #[allow(clippy::too_many_arguments)]
99 pub async fn new(
100 port: u16,
101 directory: Option<PathBuf>,
102 max_upload: usize,
103 ip: IpAddr,
104 db_string: &str,
105 cert: Option<PathBuf>,
106 key: Option<PathBuf>,
107 pg_cert: Option<PathBuf>,
108 #[cfg(feature = "vt")] vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
109 ) -> Result<Self> {
110 if let Some(dir) = &directory {
111 if !dir.exists() {
112 bail!("data directory {dir:?} does not exist!");
113 }
114 }
115
116 if cert.is_some() != key.is_some() {
117 bail!("Either both the https cert & key are provided, or none are provided.");
118 }
119
120 if let Some(cert_file) = &cert {
121 ensure!(
122 cert_file.exists(),
123 "Certificate file {cert_file:?} does not exist!"
124 );
125 }
126
127 if let Some(key_file) = &key {
128 ensure!(key_file.exists(), "Key file {key_file:?} does not exist!");
129 }
130
131 let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
132 let db_config = db_type.get_config().await?;
134 let keys = db_type.get_encryption_keys().await?;
135
136 Ok(Self {
137 port,
138 directory,
139 max_upload,
140 ip,
141 db_type,
142 db_config,
143 keys,
144 #[cfg(feature = "vt")]
145 vt_client,
146 started: SystemTime::now(),
147 cert,
148 key,
149 })
150 }
151
152 pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
159 if let Some(dest_path) = &self.directory {
160 let mut hasher = Sha256::new();
161 hasher.update(data);
162 let sha256 = hex::encode(hasher.finalize());
163
164 let hashed_path = format!(
168 "{}/{}/{}/{}",
169 &sha256[0..2],
170 &sha256[2..4],
171 &sha256[4..6],
172 sha256
173 );
174
175 let mut dest_path = dest_path.clone();
178 dest_path.push(hashed_path);
179
180 let mut just_the_dir = dest_path.clone();
182 just_the_dir.pop();
183 std::fs::create_dir_all(just_the_dir)?;
184
185 let data = if self.db_config.compression {
186 let mut compressor = GzEncoder::new(Vec::new(), Compression::default());
187 compressor.write_all(data)?;
188 compressor.finish()?
189 } else {
190 data.to_vec()
191 };
192
193 let data = if let Some(key_id) = self.db_config.default_key {
194 if let Some(key) = self.keys.get(&key_id) {
195 let nonce = key.nonce();
196 self.db_type
197 .set_file_nonce(&sha256, nonce.as_deref())
198 .await?;
199 key.encrypt(&data, nonce)?
200 } else {
201 bail!("Key not available!")
202 }
203 } else {
204 data
205 };
206
207 std::fs::write(dest_path, data)?;
208
209 Ok(true)
210 } else {
211 Ok(false)
212 }
213 }
214
215 pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
223 if let Some(dest_path) = &self.directory {
224 let path = format!(
225 "{}/{}/{}/{}",
226 &sha256[0..2],
227 &sha256[2..4],
228 &sha256[4..6],
229 sha256
230 );
231 let contents = std::fs::read(dest_path.join(path))?;
236
237 let contents = if self.keys.is_empty() {
238 contents
240 } else {
241 let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
242 if let Some(key_id) = key_id {
243 if let Some(key) = self.keys.get(&key_id) {
244 key.decrypt(&contents, nonce)?
245 } else {
246 bail!("File was encrypted but we don't have tke key!")
247 }
248 } else {
249 contents
251 }
252 };
253
254 if contents.starts_with(&GZIP_MAGIC) {
255 let buff = Cursor::new(contents);
256 let mut decompressor = GzDecoder::new(buff);
257 let mut decompressed: Vec<u8> = vec![];
258 decompressor.read_to_end(&mut decompressed)?;
259 Ok(decompressed)
260 } else {
261 Ok(contents)
262 }
263 } else {
264 bail!("files are not saved")
265 }
266 }
267
268 #[must_use]
274 pub fn since(&self) -> Duration {
275 let now = SystemTime::now();
276 now.duration_since(self.started).unwrap()
277 }
278
279 pub async fn get_info(&self) -> Result<ServerInfo> {
285 let db_info = self.db_type.db_info().await?;
286 let uptime = Local::now() - self.since();
287 let mem_size = if let Some(mem_size) = app_memory_usage_fetcher::get_memory_usage_bytes() {
288 humansize::SizeFormatter::new(mem_size.get(), humansize::BINARY).to_string()
289 } else {
290 String::new()
291 };
292
293 Ok(ServerInfo {
294 os_name: std::env::consts::OS.into(),
295 memory_used: mem_size,
296 num_samples: db_info.num_files,
297 num_users: db_info.num_users,
298 uptime: HumanTime::from(uptime).to_text_en(Accuracy::Rough, Tense::Present),
299 mdb_version: MDB_VERSION.into(),
300 db_version: db_info.version,
301 db_size: db_info.size,
302 })
303 }
304
305 pub async fn serve(self) -> Result<()> {
317 let socket = SocketAddr::new(self.ip, self.port);
318
319 if self.cert.is_some() && self.key.is_some() {
320 let cert_path = self.cert.as_ref().unwrap();
321 let key_path = self.key.as_ref().unwrap();
322 let cert_ext_str = cert_path
323 .extension()
324 .context("failed to get certificate extension")?;
325 let key_ext_str = key_path
326 .extension()
327 .context("failed to get key extension")?;
328
329 if rustls::crypto::CryptoProvider::get_default().is_none() {
331 rustls::crypto::ring::default_provider()
332 .install_default()
333 .expect("Failed to load crypto provider");
334 }
335
336 let config = if (cert_ext_str == "pem" || cert_ext_str == "crt") && key_ext_str == "pem"
337 {
338 RustlsConfig::from_pem_file(cert_path, key_path)
339 .await
340 .context("failed to load or parse certificate and key files")?
341 } else if cert_ext_str == "der" && key_ext_str == "der" {
342 let cert_contents =
343 std::fs::read(cert_path).context("failed to read certificate file")?;
344 let key_contents =
345 std::fs::read(key_path).context("failed to read certificate file")?;
346 RustlsConfig::from_der(vec![cert_contents], key_contents)
347 .await
348 .context("failed to parse certificate and key files as DER")?
349 } else {
350 bail!("Unknown or unmatched certificate and key file extensions {cert_ext_str:?} and {key_ext_str:?}");
351 };
352
353 println!("Listening on https://{socket:?}");
354 axum_server::bind_rustls(socket, config)
355 .serve(http::app(Arc::new(self)).into_make_service())
356 .await?;
357 } else {
358 println!("Listening on http://{socket:?}");
359 let listener = TcpListener::bind(socket)
360 .await
361 .context(format!("failed to bind socket {socket}"))?;
362 axum::serve(listener, http::app(Arc::new(self)).into_make_service()).await?;
363 }
364 Ok(())
365 }
366}
367
368impl Debug for State {
369 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
370 write!(
371 f,
372 "MDB state, port {}, database {:?}",
373 self.port, self.db_type
374 )
375 }
376}