use super::State;
use crate::db::types::FileTypes;
use std::fmt::{Display, Formatter};
use std::io::Cursor;
use std::sync::Arc;
use axum::body::Bytes;
use axum::extract::{DefaultBodyLimit, Path};
use axum::http::{header, StatusCode};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Extension, Json, Router};
use base64::{engine::general_purpose, Engine as _};
use constcat::concat;
use http::HeaderMap;
use sha2::{Digest, Sha256};
use malwaredb_api::digest::HashType;
use malwaredb_api::{
GetAPIKeyRequest, GetAPIKeyResponse, GetUserInfoResponse, Labels, NewSample, Report,
ServerInfo, Sources, SupportedFileTypes, MDB_API_HEADER,
};
mod receive;
pub fn app(state: Arc<State>) -> Router {
Router::new()
.route("/", get(health))
.route(malwaredb_api::SERVER_INFO, get(get_mdb_info))
.route(malwaredb_api::USER_LOGIN_URL, post(user_login))
.route(malwaredb_api::USER_LOGOUT_URL, get(user_logout))
.route(malwaredb_api::USER_INFO_URL, get(get_user_groups_sources))
.route(
malwaredb_api::SUPPORTED_FILE_TYPES,
get(get_supported_types),
)
.route(malwaredb_api::LIST_LABELS, get(get_labels))
.route(malwaredb_api::LIST_SOURCES, get(get_sources))
.route(malwaredb_api::UPLOAD_SAMPLE, post(get_new_sample))
.route(
concat!(malwaredb_api::DOWNLOAD_SAMPLE, "/:hash"),
get(download_sample),
)
.route(
concat!(malwaredb_api::DOWNLOAD_SAMPLE_CART, "/:hash"),
get(download_sample_cart),
)
.route(
concat!(malwaredb_api::SAMPLE_REPORT, "/:hash"),
get(sample_report),
)
.route(malwaredb_api::SIMILAR_SAMPLES, post(find_similar))
.layer(DefaultBodyLimit::max(state.max_upload))
.layer(Extension(state))
}
/// Liveness probe served at `GET /`.
///
/// Always answers `200 OK` and performs no database or storage access.
async fn health() -> StatusCode {
    StatusCode::OK
}
async fn get_mdb_info(
Extension(state): Extension<Arc<State>>,
) -> Result<Json<ServerInfo>, HttpError> {
let server_info = state.get_info().await?;
Ok(Json(server_info))
}
async fn user_login(
Extension(state): Extension<Arc<State>>,
Json(payload): Json<GetAPIKeyRequest>,
) -> Result<Json<GetAPIKeyResponse>, HttpError> {
let api_key = state
.db_type
.authenticate(&payload.user, &payload.password)
.await?;
Ok(Json(GetAPIKeyResponse {
key: Some(api_key),
message: None,
}))
}
async fn user_logout(
Extension(state): Extension<Arc<State>>,
headers: HeaderMap,
) -> Result<StatusCode, HttpError> {
let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
anyhow::Error::msg("Missing API key"),
StatusCode::NOT_ACCEPTABLE,
))?;
let uid = state.db_type.get_uid(key.to_str().unwrap()).await?;
state.db_type.reset_own_api_key(uid).await?;
Ok(StatusCode::OK)
}
async fn get_user_groups_sources(
Extension(state): Extension<Arc<State>>,
headers: HeaderMap,
) -> Result<Json<GetUserInfoResponse>, HttpError> {
let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
anyhow::Error::msg("Missing API key"),
StatusCode::NOT_ACCEPTABLE,
))?;
let uid = state.db_type.get_uid(key.to_str().unwrap()).await?;
let groups_sources = state.db_type.get_user_info(uid).await?;
Ok(Json(groups_sources))
}
/// Lists the data types known to the database as a [`SupportedFileTypes`] payload.
///
/// No API key is required. The raw database rows are wrapped in the crate's
/// [`FileTypes`] newtype, which converts into the API type.
async fn get_supported_types(
    Extension(state): Extension<Arc<State>>,
) -> Result<Json<SupportedFileTypes>, HttpError> {
    let supported: SupportedFileTypes =
        FileTypes(state.db_type.get_known_data_types().await?).into();
    Ok(Json(supported))
}
async fn get_labels(
Extension(state): Extension<Arc<State>>,
headers: HeaderMap,
) -> Result<Json<Labels>, HttpError> {
let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
anyhow::Error::msg("Missing API key"),
StatusCode::NOT_ACCEPTABLE,
))?;
let _uid = state.db_type.get_uid(key.to_str().unwrap()).await?;
let labels = state.db_type.get_labels().await?;
Ok(Json(labels))
}
async fn get_sources(
Extension(state): Extension<Arc<State>>,
headers: HeaderMap,
) -> Result<Json<Sources>, HttpError> {
let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
anyhow::Error::msg("Missing API key"),
StatusCode::NOT_ACCEPTABLE,
))?;
let uid = state.db_type.get_uid(key.to_str().unwrap()).await?;
let sources = state.db_type.get_user_sources(uid).await?;
Ok(Json(sources))
}
async fn get_new_sample(
Extension(state): Extension<Arc<State>>,
headers: HeaderMap,
Json(payload): Json<NewSample>,
) -> Result<StatusCode, HttpError> {
let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
anyhow::Error::msg("Missing API key"),
StatusCode::NOT_ACCEPTABLE,
))?;
let uid = state.db_type.get_uid(key.to_str().unwrap()).await?;
let allowed = state
.db_type
.allowed_user_source(uid, payload.source_id as i32)
.await?;
if !allowed {
return Err(HttpError(
anyhow::Error::msg("Unauthorized"),
StatusCode::UNAUTHORIZED,
));
}
let received_hash = hex::decode(&payload.sha256)?;
let bytes = general_purpose::STANDARD.decode(&payload.file_contents_b64)?;
let mut hasher = Sha256::new();
hasher.update(&bytes);
let result = hasher.finalize();
if result[..] != received_hash[..] {
return Err(HttpError(
anyhow::Error::msg("Hash mismatch"),
StatusCode::NOT_ACCEPTABLE,
));
}
receive::incoming_sample(
state.clone(),
bytes,
uid,
payload.source_id as i32,
payload.file_name,
)
.await?;
Ok(StatusCode::OK)
}
async fn download_sample(
Path(hash): Path<String>,
headers: HeaderMap,
Extension(state): Extension<Arc<State>>,
) -> Result<impl IntoResponse, HttpError> {
if state.directory.is_none() {
return Err(HttpError(
anyhow::Error::msg("Server does not store samples"),
StatusCode::NOT_ACCEPTABLE,
));
}
let hash = HashType::try_from(hash)?;
let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
anyhow::Error::msg("Missing API key"),
StatusCode::NOT_ACCEPTABLE,
))?;
let uid = state.db_type.get_uid(key.to_str().unwrap()).await?;
let sha256 = state.db_type.retrieve_sample(uid, &hash).await?;
let contents = state.retrieve_bytes(&sha256).await?;
let mut bytes = Bytes::from(contents).into_response();
let name_header_value = format!("attachment; filename=\"{sha256}\"");
bytes.headers_mut().insert(
header::CONTENT_DISPOSITION,
name_header_value.parse().unwrap(),
);
Ok(bytes)
}
async fn download_sample_cart(
Path(hash): Path<String>,
headers: HeaderMap,
Extension(state): Extension<Arc<State>>,
) -> Result<impl IntoResponse, HttpError> {
if state.directory.is_none() {
return Err(HttpError(
anyhow::Error::msg("Server does not store samples"),
StatusCode::NOT_ACCEPTABLE,
));
}
let hash = HashType::try_from(hash)?;
let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
anyhow::Error::msg("Missing API key"),
StatusCode::NOT_ACCEPTABLE,
))?;
let uid = state.db_type.get_uid(key.to_str().unwrap()).await?;
let sha256 = state.db_type.retrieve_sample(uid, &hash).await?;
let report = state.db_type.get_sample_report(uid, &hash).await?;
let contents = state.retrieve_bytes(&sha256).await?;
let contents_cursor = Cursor::new(contents);
let mut output_cursor = Cursor::new(vec![]);
let mut output_metadata = cart::cart::JsonMap::new();
output_metadata.insert("sha384".into(), report.sha384.into());
output_metadata.insert("sha512".into(), report.sha512.into());
output_metadata.insert("entropy".into(), report.entropy.into());
if let Some(filecmd) = report.filecommand {
output_metadata.insert("file".into(), filecmd.into());
}
cart::cart::pack_stream(
contents_cursor,
&mut output_cursor,
Some(output_metadata),
None,
cart::digesters::default_digesters(), None,
)?;
let mut bytes = Bytes::from(output_cursor.into_inner()).into_response();
let name_header_value = format!("attachment; filename=\"{sha256}.cart\"");
bytes.headers_mut().insert(
header::CONTENT_DISPOSITION,
name_header_value.parse().unwrap(),
);
Ok(bytes)
}
async fn sample_report(
Path(hash): Path<String>,
headers: HeaderMap,
Extension(state): Extension<Arc<State>>,
) -> Result<Json<Report>, HttpError> {
let hash = HashType::try_from(hash)?;
let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
anyhow::Error::msg("Missing API key"),
StatusCode::NOT_ACCEPTABLE,
))?;
let uid = state.db_type.get_uid(key.to_str().unwrap()).await?;
let report = state.db_type.get_sample_report(uid, &hash).await?;
Ok(Json(report))
}
async fn find_similar(
Extension(state): Extension<Arc<State>>,
headers: HeaderMap,
Json(payload): Json<malwaredb_api::SimilarSamplesRequest>,
) -> Result<Json<malwaredb_api::SimilarSamplesResponse>, HttpError> {
let key = headers.get(malwaredb_api::MDB_API_HEADER).ok_or(HttpError(
anyhow::Error::msg("Missing API key"),
StatusCode::NOT_ACCEPTABLE,
))?;
let uid = state.db_type.get_uid(key.to_str().unwrap()).await?;
let results = state
.db_type
.find_similar_samples(uid, &payload.hashes)
.await?;
Ok(Json(malwaredb_api::SimilarSamplesResponse {
results,
message: None,
}))
}
/// Error type returned by every handler in this module.
///
/// Fields:
/// * `.0` — the underlying error; its `Display` text goes into the response body.
/// * `.1` — the HTTP status code sent to the client.
#[derive(Debug)]
pub struct HttpError(pub anyhow::Error, pub StatusCode);

impl IntoResponse for HttpError {
    /// Renders the error as a plain-text body `MDB error: <message>` with the
    /// stored status code.
    fn into_response(self) -> axum::response::Response {
        (self.1, format!("MDB error: {}", self.0)).into_response()
    }
}

impl Display for HttpError {
    /// Shows only the wrapped error's message; the status code is omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Converts any error that `anyhow` accepts into a `500 Internal Server Error`.
///
/// This is what lets handlers use `?` on database and I/O results.
impl<E> From<E> for HttpError
where
    E: Into<anyhow::Error>,
{
    fn from(err: E) -> Self {
        Self(err.into(), StatusCode::INTERNAL_SERVER_ERROR)
    }
}
#[cfg(test)]
mod tests {
    use malwaredb_client::MdbClient;
    use super::*;
    use crate::crypto::{EncryptionOption, FileEncryption};
    use crate::db::DatabaseType;
    use std::collections::HashMap;
    use std::time::SystemTime;
    use std::{env, fs};
    use anyhow::Context;
    use axum::body::Body;
    use axum::http::Request;
    use chrono::Local;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;
    use rstest::rstest;
    use tower::ServiceExt;
    use uuid::Uuid;

    const ADMIN_UNAME: &str = "admin";
    const ADMIN_PASSWORD: &str = "password12345";

    /// Builds a fresh test `State` backed by a new SQLite file in the temp dir.
    ///
    /// Setup steps:
    /// * optionally enables compression and/or registers an XOR encryption key;
    /// * sets the admin password;
    /// * creates a source named `temp-source` and grants group 0 access to it.
    ///
    /// Returns the shared state and the new source's ID.
    async fn state(compress: bool, encrypt: bool) -> (Arc<State>, i32) {
        let mut db_file = env::temp_dir();
        db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
        if std::path::Path::new(&db_file).exists() {
            fs::remove_file(&db_file)
                .context(format!("failed to delete old SQLite file {db_file:?}"))
                .unwrap();
        }
        let db_type = DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()))
            .await
            .context(format!("failed to create SQLite instance for {db_file:?}"))
            .unwrap();
        if compress {
            db_type.enable_compression().await.unwrap();
        }
        let keys = if encrypt {
            let key = FileEncryption::from(EncryptionOption::Xor);
            let key_id = db_type.add_file_encryption_key(&key).await.unwrap();
            let mut keys = HashMap::new();
            keys.insert(key_id, key);
            keys
        } else {
            HashMap::new()
        };
        // Read the config only after compression/encryption have been applied.
        let db_config = db_type.get_config().await.unwrap();
        let state = State {
            port: 8080,
            // `into_path` keeps the directory on disk after the test finishes.
            directory: Some(
                tempfile::TempDir::with_prefix("mdb-temp-samples")
                    .unwrap()
                    .into_path(),
            ),
            max_upload: 10 * 1024 * 1024,
            ip: "127.0.0.1".parse().unwrap(),
            db_type,
            db_config,
            keys,
            started: SystemTime::now(),
            #[cfg(feature = "vt")]
            vt_client: None,
        };
        state
            .db_type
            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .context("Failed to set admin password")
            .unwrap();
        let source_id = state
            .db_type
            .create_source("temp-source", None, None, Local::now(), true, Some(false))
            .await
            .unwrap();
        state
            .db_type
            .add_group_to_source(0, source_id)
            .await
            .unwrap();
        (Arc::new(state), source_id)
    }

    /// Same setup as `state(false, false)`, additionally returning an admin API token.
    ///
    /// The `State` is returned owned rather than in an `Arc`, so the caller can
    /// run `State::serve()` on it. The source ID is converted to the client's
    /// `u32` form.
    async fn state_and_token() -> (State, u32, String) {
        let (state, source_id) = state(false, false).await;
        let state = match Arc::try_unwrap(state) {
            Ok(state) => state,
            Err(_) => panic!("test state should not be shared before the server starts"),
        };
        let token = state
            .db_type
            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .unwrap();
        (state, source_id as u32, token)
    }

    /// Logs in as the admin user through the HTTP login endpoint and returns the API key.
    async fn get_key(state: Arc<State>) -> String {
        let key_request = serde_json::to_string(&GetAPIKeyRequest {
            user: ADMIN_UNAME.into(),
            password: ADMIN_PASSWORD.into(),
        })
        .context("Failed to convert API key request to JSON")
        .unwrap();
        let request = Request::builder()
            .method("POST")
            .uri(malwaredb_api::USER_LOGIN_URL)
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(key_request))
            .unwrap();
        let response = app(state)
            .oneshot(request)
            .await
            .context("failed to send/receive login request")
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        // NOTE(review): this lowercases the whole body, key included; it only
        // works if keys are issued in lowercase. Confirm against `authenticate`.
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();
        let response: GetAPIKeyResponse = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();
        let key = response.key.clone().unwrap();
        assert_eq!(key.len(), 64);
        key
    }

    #[tokio::test]
    async fn about_self() {
        let (state, _) = state(false, false).await;
        let api_key = get_key(state.clone()).await;
        let request = Request::builder()
            .method("GET")
            .uri(malwaredb_api::USER_INFO_URL)
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();
        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive login request")
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();
        let response: GetUserInfoResponse = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();
        assert_eq!(response.id, 0);
        assert!(response.is_admin);
        assert_eq!(response.username, "admin");
        let request = Request::builder()
            .method("GET")
            .uri(malwaredb_api::LIST_LABELS)
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();
        let response = app(state)
            .oneshot(request)
            .await
            .context("failed to send/receive login request")
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();
        let response: Labels = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();
        assert!(response.is_empty());
    }

    /// Uploads a sample, checks it landed on disk, then fetches its report and CaRT download.
    ///
    /// Case parameters:
    /// * `compress` / `encrypt` configure the server state;
    /// * `cart` marks input that is itself a CaRT file, whose stored SHA-256 is
    ///   that of the unpacked content.
    #[rstest]
    #[case::elf_encrypt_cart(include_bytes!("../../../types/testdata/elf/elf_haiku_x86.cart"), false, true, true)]
    #[case::pe32(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), false, false, false)]
    #[case::pdf_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), false, true, false)]
    #[case::rtf(include_bytes!("../../../types/testdata/rtf/hello.rtf"), false, false, false)]
    #[case::elf_compress_encrypt(include_bytes!("../../../types/testdata/elf/elf_haiku_x86"), true, true, false)]
    #[case::pe32_compress(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), true, false, false)]
    #[case::pdf_compress_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), true, true, false)]
    #[case::rtf_compress(include_bytes!("../../../types/testdata/rtf/hello.rtf"), true, false, false)]
    #[tokio::test]
    async fn submit_sample(
        #[case] contents: &[u8],
        #[case] compress: bool,
        #[case] encrypt: bool,
        #[case] cart: bool,
    ) {
        let (state, source_id) = state(compress, encrypt).await;
        let api_key = get_key(state.clone()).await;
        let file_contents_b64 = general_purpose::STANDARD.encode(contents);
        let mut hasher = Sha256::new();
        hasher.update(contents);
        let sha256 = hex::encode(hasher.finalize());
        let upload = serde_json::to_string(&NewSample {
            file_name: "some_sample".into(),
            source_id: source_id as u32,
            file_contents_b64,
            sha256: sha256.clone(),
        })
        .context("failed to create upload structure")
        .unwrap();
        let request = Request::builder()
            .method("POST")
            .uri(malwaredb_api::UPLOAD_SAMPLE)
            .header(CONTENT_TYPE, "application/json")
            .header(MDB_API_HEADER, &api_key)
            .body(Body::from(upload))
            .unwrap();
        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive upload request/response")
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // For CaRT input the server stores the unpacked content, so look up the
        // SHA-256 recorded in the CaRT footer instead of the upload's digest.
        let sha256 = if cart {
            let mut input_buffer = Cursor::new(contents);
            let mut output_buffer = Cursor::new(vec![]);
            let (_, footer) =
                cart::cart::unpack_stream(&mut input_buffer, &mut output_buffer, None)
                    .expect("failed to decode CaRT file");
            let footer = footer.expect("CaRT should have had a footer");
            let sha256 = footer
                .get("sha256")
                .expect("CaRT footer should have had an entry for SHA-256")
                .to_string();
            // `to_string` on a JSON string value keeps the surrounding quotes.
            sha256.replace('"', "")
        } else {
            sha256
        };

        if let Some(dir) = &state.directory {
            let mut sample_path = dir.clone();
            sample_path.push(format!(
                "{}/{}/{}/{}",
                &sha256[0..2],
                &sha256[2..4],
                &sha256[4..6],
                sha256
            ));
            eprintln!("Submitted sample should exist at {sample_path:?}.");
            assert!(sample_path.exists());
            if compress {
                let sample_size_on_disk = sample_path.metadata().unwrap().len();
                eprintln!(
                    "Original size: {}, compressed: {}",
                    contents.len(),
                    sample_size_on_disk
                );
                assert!(sample_size_on_disk < contents.len() as u64);
            }
        } else {
            panic!("Directory was set for the state, but is now `None`");
        }

        let request = Request::builder()
            .method("GET")
            .uri(format!("{}/{sha256}", malwaredb_api::SAMPLE_REPORT))
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();
        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive upload request/response")
            .unwrap();
        // Check the status first so a server error fails with a clear message
        // instead of a JSON parse failure.
        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();
        let report: Report = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();
        assert_eq!(report.sha256, sha256);
        println!("Report: {report}");

        let request = Request::builder()
            .method("GET")
            .uri(format!("{}/{sha256}", malwaredb_api::DOWNLOAD_SAMPLE_CART))
            .header(MDB_API_HEADER, api_key)
            .body(Body::empty())
            .unwrap();
        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive upload request/response for CaRT")
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        let bytes = bytes.to_vec();
        let bytes_input = Cursor::new(bytes);
        let output = Cursor::new(vec![]);
        match cart::cart::unpack_stream(bytes_input, output, None) {
            Ok((header, _)) => {
                let header = header.unwrap();
                assert_eq!(
                    header.get("sha384"),
                    Some(&serde_json::to_value(report.sha384).unwrap())
                );
            }
            Err(e) => panic!("{e}"),
        }
    }

    /// End-to-end test: runs the real server and drives it through `MdbClient`.
    ///
    /// NOTE(review): the server binds the fixed port 8080 from the test state, so
    /// this test fails if that port is already in use.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn client_integration() {
        let (state, source_id, token) = state_and_token().await;
        let server = tokio::spawn(async move {
            state
                .serve()
                .await
                .expect("MalwareDB failed to .serve() in tokio::spawn()")
        });
        assert!(!server.is_finished());
        // Give the server time to bind before the client connects.
        tokio::time::sleep(std::time::Duration::new(1, 0)).await;
        let mdb_client = MdbClient::new("http://127.0.0.1:8080".to_string(), token);
        let contents = include_bytes!("../../../types/testdata/elf/elf_haiku_x86");
        assert!(mdb_client
            .submit(contents, "elf_haiku_x86", source_id)
            .await
            .expect("failed to upload test file"));
        let report = mdb_client
            .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
            .await
            .expect("failed to get report for file just submitted");
        assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");
        let similar = mdb_client
            .similar(contents)
            .await
            .expect("failed to query for files similar to what was just submitted");
        assert_eq!(similar.results.len(), 1);
        server.abort();
    }
}