use super::State;
use malwaredb_api::digest::HashType;
use malwaredb_api::{
GetAPIKeyRequest, GetAPIKeyResponse, GetUserInfoResponse, Labels, NewSampleB64, NewSampleBytes,
Report, SearchRequest, SearchResponse, ServerError, ServerInfo, ServerResponse,
SimilarSamplesRequest, SimilarSamplesResponse, Sources, SupportedFileTypes, YaraSearchRequest,
YaraSearchRequestResponse, YaraSearchResponse, MDB_API_HEADER,
};
use std::fmt::{Display, Formatter};
use std::io::Cursor;
use std::iter::once;
use std::sync::Arc;
use axum::body::Bytes;
use axum::extract::{DefaultBodyLimit, Path, Request};
use axum::http::{header, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{middleware, Extension, Json, Router};
use axum_cbor::Cbor;
use base64::{engine::general_purpose, Engine as _};
use constcat::concat;
use http::{HeaderMap, HeaderName, HeaderValue};
use sha2::{Digest, Sha256};
use tower_http::compression::CompressionLayer;
use tower_http::decompression::DecompressionLayer;
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::sensitive_headers::SetSensitiveHeadersLayer;
use uuid::Uuid;
mod receive;
/// URL path under which the `favicon` handler is routed; also listed among the
/// endpoints that the authentication middleware lets through without an API key.
const FAVICON_URL: &str = "/favicon.ico";
/// Builds the server's [`Router`]: registers every endpoint, then wraps the
/// routes in the body-limit, compression, sensitive-header and authentication
/// layers.
///
/// `state` is handed to all handlers (and the middleware) via an [`Extension`]
/// layer, which is added last.
pub fn app(state: Arc<State>) -> Router {
    // Headroom added on top of `max_upload` for the outer request-size limit.
    const UPLOAD_OVERHEAD: usize = std::mem::size_of::<Json<NewSampleB64>>() * 2;

    let max_upload = state.max_upload;

    let routes = Router::new()
        .route("/", get(health))
        .route(FAVICON_URL, get(favicon))
        .route(malwaredb_api::SERVER_INFO_URL, get(get_mdb_info))
        .route(malwaredb_api::USER_LOGIN_URL, post(user_login))
        .route(malwaredb_api::USER_LOGOUT_URL, get(user_logout))
        .route(malwaredb_api::USER_INFO_URL, get(get_user_groups_sources))
        .route(
            malwaredb_api::SUPPORTED_FILE_TYPES_URL,
            get(get_supported_types),
        )
        .route(malwaredb_api::LIST_LABELS_URL, get(get_labels))
        .route(malwaredb_api::LIST_SOURCES_URL, get(get_sources))
        .route(
            malwaredb_api::UPLOAD_SAMPLE_JSON_URL,
            post(upload_new_sample_json),
        )
        .route(
            malwaredb_api::UPLOAD_SAMPLE_CBOR_URL,
            post(upload_new_sample_cbor),
        )
        .route(malwaredb_api::SEARCH_URL, post(sample_search))
        .route(
            concat!(malwaredb_api::DOWNLOAD_SAMPLE_URL, "/{hash}"),
            get(download_sample),
        )
        .route(
            concat!(malwaredb_api::DOWNLOAD_SAMPLE_CART_URL, "/{hash}"),
            get(download_sample_cart),
        )
        .route(
            concat!(malwaredb_api::SAMPLE_REPORT_URL, "/{hash}"),
            get(sample_report),
        )
        .route(malwaredb_api::SIMILAR_SAMPLES_URL, post(find_similar))
        .route(malwaredb_api::YARA_SEARCH_URL, post(yara_search))
        .route(
            concat!(malwaredb_api::YARA_SEARCH_URL, "/{uuid}"),
            get(yara_search_get),
        );

    // All four encodings are enabled in both directions.
    let response_compression = CompressionLayer::new()
        .br(true)
        .deflate(true)
        .gzip(true)
        .zstd(true);
    let request_decompression = DecompressionLayer::new()
        .br(true)
        .deflate(true)
        .gzip(true)
        .zstd(true);
    let request_size_limit = RequestBodyLimitLayer::new(UPLOAD_OVERHEAD + max_upload);
    // Marks the API-key header as sensitive.
    let sensitive_api_key =
        SetSensitiveHeadersLayer::new(once(HeaderName::from_static(MDB_API_HEADER)));

    // The order of the `.layer(..)` calls below is significant; keep it as is.
    routes
        .layer(DefaultBodyLimit::max(max_upload))
        .layer(response_compression)
        .layer(request_decompression)
        .layer(request_size_limit)
        .layer(sensitive_api_key)
        .route_layer(middleware::from_fn(response_header_middleware))
        .layer(Extension(state))
}
/// Identity of the authenticated caller. The authentication middleware inserts
/// an `Arc<UserInfo>` into the request extensions, from which handlers extract it.
struct UserInfo {
/// User ID obtained from the database lookup of the request's API key.
pub id: u32,
}
/// Authentication and response-header middleware applied to every route.
///
/// Unless the request path is one of the always-allowed endpoints, the API key
/// header must be present and resolve to a user ID; that ID is stored in the
/// request extensions as an `Arc<UserInfo>`. Every response that passes through
/// gets `Cache-Control: no-store`.
///
/// # Errors
///
/// * `406 Not Acceptable` if the API key header is missing or is not valid
///   visible ASCII.
/// * `401 Unauthorized` if the database cannot map the key to a user.
async fn response_header_middleware(
    Extension(state): Extension<Arc<State>>,
    headers: HeaderMap,
    mut req: Request,
    next: Next,
) -> Result<Response, HttpError> {
    const ALWAYS_ALLOWED_ENDPOINTS: [&str; 4] = [
        "/",
        FAVICON_URL,
        malwaredb_api::SERVER_INFO_URL,
        malwaredb_api::USER_LOGIN_URL,
    ];

    let needs_auth = !ALWAYS_ALLOWED_ENDPOINTS.contains(&req.uri().path());
    if needs_auth {
        let Some(raw_key) = headers.get(MDB_API_HEADER) else {
            return Err(HttpError(
                ServerError::Unauthorized,
                StatusCode::NOT_ACCEPTABLE,
            ));
        };
        let Ok(api_key) = raw_key.to_str() else {
            return Err(HttpError(
                ServerError::Unauthorized,
                StatusCode::NOT_ACCEPTABLE,
            ));
        };
        let user_id = match state.db_type.get_uid(api_key).await {
            Ok(id) => id,
            Err(e) => {
                tracing::warn!("Failed to get user ID from API key: {e}");
                return Err(HttpError(
                    ServerError::Unauthorized,
                    StatusCode::UNAUTHORIZED,
                ));
            }
        };
        // Handlers pick this up through `Extension<Arc<UserInfo>>`.
        req.extensions_mut()
            .insert(Arc::new(UserInfo { id: user_id }));
    }

    let mut response = next.run(req).await;
    let no_store = HeaderValue::from_static("no-store");
    response
        .headers_mut()
        .insert(header::CACHE_CONTROL, no_store);
    Ok(response)
}
/// Handler for `/`: unconditionally answers with `StatusCode::OK`.
async fn health() -> StatusCode {
StatusCode::OK
}
/// Serves the icon embedded at compile time, with the
/// `image/vnd.microsoft.icon` content type.
async fn favicon() -> Response {
    static ICON: &[u8] = include_bytes!("../../MDB_Logo.ico");

    let icon_type = HeaderValue::from_static("image/vnd.microsoft.icon");
    let mut response = Bytes::from_static(ICON).into_response();
    // Replace the content type that the byte body set by default.
    response
        .headers_mut()
        .insert(header::CONTENT_TYPE, icon_type);
    response
}
/// Returns the server information produced by `State::get_info`, wrapped in a
/// successful [`ServerResponse`].
///
/// # Errors
///
/// Any failure of `get_info` is converted into an [`HttpError`].
async fn get_mdb_info(
    Extension(state): Extension<Arc<State>>,
) -> Result<Json<ServerResponse<ServerInfo>>, HttpError> {
    let info = state.get_info().await?;
    Ok(Json(ServerResponse::Success(info)))
}
async fn user_login(
Extension(state): Extension<Arc<State>>,
Json(payload): Json<GetAPIKeyRequest>,
) -> Result<Json<ServerResponse<GetAPIKeyResponse>>, HttpError> {
let api_key = state
.db_type
.authenticate(&payload.user, &payload.password)
.await?;
Ok(Json(ServerResponse::Success(GetAPIKeyResponse {
key: api_key,
message: None,
})))
}
/// Logs the caller out by asking the database to reset their own API key.
///
/// # Errors
///
/// Any database failure is converted into an [`HttpError`].
async fn user_logout(
    Extension(state): Extension<Arc<State>>,
    Extension(user): Extension<Arc<UserInfo>>,
) -> Result<StatusCode, HttpError> {
    let user_id = user.id;
    state.db_type.reset_own_api_key(user_id).await?;
    Ok(StatusCode::OK)
}
/// Returns the database's user-info record for the authenticated caller.
///
/// # Errors
///
/// Any database failure is converted into an [`HttpError`].
async fn get_user_groups_sources(
    Extension(state): Extension<Arc<State>>,
    Extension(user): Extension<Arc<UserInfo>>,
) -> Result<Json<ServerResponse<GetUserInfoResponse>>, HttpError> {
    let info = state.db_type.get_user_info(user.id).await?;
    Ok(Json(ServerResponse::Success(info)))
}
/// Lists the data types known to the database, converted into their API
/// representation.
///
/// # Errors
///
/// Any database failure is converted into an [`HttpError`].
async fn get_supported_types(
    Extension(state): Extension<Arc<State>>,
) -> Result<Json<ServerResponse<SupportedFileTypes>>, HttpError> {
    let types = state
        .db_type
        .get_known_data_types()
        .await?
        .into_iter()
        .map(Into::into)
        .collect();

    let supported = SupportedFileTypes {
        types,
        message: None,
    };
    Ok(Json(ServerResponse::Success(supported)))
}
/// Returns the labels stored in the database.
///
/// # Errors
///
/// Any database failure is converted into an [`HttpError`].
async fn get_labels(
    Extension(state): Extension<Arc<State>>,
) -> Result<Json<ServerResponse<Labels>>, HttpError> {
    let found = state.db_type.get_labels().await?;
    Ok(Json(ServerResponse::Success(found)))
}
/// Returns the sources the database reports for the authenticated caller.
///
/// # Errors
///
/// Any database failure is converted into an [`HttpError`].
async fn get_sources(
    Extension(state): Extension<Arc<State>>,
    Extension(user): Extension<Arc<UserInfo>>,
) -> Result<Json<ServerResponse<Sources>>, HttpError> {
    let available = state.db_type.get_user_sources(user.id).await?;
    Ok(Json(ServerResponse::Success(available)))
}
/// Accepts a sample uploaded as JSON with base64-encoded contents.
///
/// The caller must be allowed to use the given source, and the SHA-256 of the
/// decoded contents must equal the hex digest supplied in the payload before the
/// sample is handed to `receive::incoming_sample`.
///
/// # Errors
///
/// * `401 Unauthorized` if the user may not submit to `source_id`.
/// * `406 Not Acceptable` if the supplied digest does not match the contents.
/// * Any other failure (database, malformed hex or base64, ingestion) goes
///   through the blanket `From` conversion into [`HttpError`].
async fn upload_new_sample_json(
    Extension(state): Extension<Arc<State>>,
    Extension(user): Extension<Arc<UserInfo>>,
    Json(payload): Json<NewSampleB64>,
) -> Result<StatusCode, HttpError> {
    let allowed = state
        .db_type
        .allowed_user_source(user.id, payload.source_id)
        .await?;
    if !allowed {
        return Err(HttpError(
            ServerError::Unauthorized,
            StatusCode::UNAUTHORIZED,
        ));
    }

    let claimed_digest = hex::decode(&payload.sha256)?;
    let bytes = general_purpose::STANDARD.decode(&payload.file_contents_b64)?;

    // Integrity check: the upload must hash to the digest the client declared.
    let actual_digest = Sha256::digest(&bytes);
    if actual_digest[..] != claimed_digest[..] {
        return Err(HttpError(
            ServerError::Unauthorized,
            StatusCode::NOT_ACCEPTABLE,
        ));
    }

    // `state` is not needed afterwards, so the `Arc` is moved rather than cloned.
    receive::incoming_sample(
        state,
        bytes,
        user.id,
        payload.source_id,
        payload.file_name,
    )
    .await?;
    Ok(StatusCode::OK)
}
/// Accepts a sample uploaded as CBOR with raw byte contents.
///
/// The caller must be allowed to use the given source, and the SHA-256 of the
/// contents must equal the hex digest supplied in the payload before the sample
/// is handed to `receive::incoming_sample`.
///
/// # Errors
///
/// * `401 Unauthorized` if the user may not submit to `source_id`.
/// * `406 Not Acceptable` if the supplied digest does not match the contents.
/// * Any other failure (database, malformed hex, ingestion) goes through the
///   blanket `From` conversion into [`HttpError`].
async fn upload_new_sample_cbor(
    Extension(state): Extension<Arc<State>>,
    Extension(user): Extension<Arc<UserInfo>>,
    Cbor(payload): Cbor<NewSampleBytes>,
) -> Result<StatusCode, HttpError> {
    let allowed = state
        .db_type
        .allowed_user_source(user.id, payload.source_id)
        .await?;
    if !allowed {
        return Err(HttpError(
            ServerError::Unauthorized,
            StatusCode::UNAUTHORIZED,
        ));
    }

    let claimed_digest = hex::decode(&payload.sha256)?;
    let bytes = payload.file_contents;

    // Integrity check: the upload must hash to the digest the client declared.
    let actual_digest = Sha256::digest(&bytes);
    if actual_digest[..] != claimed_digest[..] {
        return Err(HttpError(
            ServerError::Unauthorized,
            StatusCode::NOT_ACCEPTABLE,
        ));
    }

    // `state` is not needed afterwards, so the `Arc` is moved rather than cloned.
    receive::incoming_sample(
        state,
        bytes,
        user.id,
        payload.source_id,
        payload.file_name,
    )
    .await?;
    Ok(StatusCode::OK)
}
/// Runs the database's partial search for the authenticated caller and returns
/// the result.
///
/// # Errors
///
/// Any database failure is converted into an [`HttpError`].
async fn sample_search(
    Extension(state): Extension<Arc<State>>,
    Extension(user): Extension<Arc<UserInfo>>,
    Json(payload): Json<SearchRequest>,
) -> Result<Json<ServerResponse<SearchResponse>>, HttpError> {
    let found = state.db_type.partial_search(user.id, payload).await?;
    Ok(Json(ServerResponse::Success(found)))
}
/// Streams back the raw bytes of the sample identified by `hash`.
///
/// The response carries a `Content-Disposition` attachment header named after
/// the sample's SHA-256 and a `content-digest` header derived from that hash.
///
/// # Errors
///
/// * [`NO_SAMPLES_STORED_ERROR`] if the server has no sample directory.
/// * Any other failure (unparsable hash, database lookup, reading the bytes,
///   building the digest header) is converted into an [`HttpError`].
async fn download_sample(
    Path(hash): Path<String>,
    Extension(user): Extension<Arc<UserInfo>>,
    Extension(state): Extension<Arc<State>>,
) -> Result<Response, HttpError> {
    if state.directory.is_none() {
        return Err(NO_SAMPLES_STORED_ERROR);
    }

    let requested = HashType::try_from(hash.as_str())?;
    let sha256 = state.db_type.retrieve_sample(user.id, &requested).await?;
    // Re-parse the SHA-256 from the database so the digest header is built from
    // it, regardless of which hash type the caller supplied.
    let digest = HashType::try_from(sha256.as_str())?;
    let contents = state.retrieve_bytes(&sha256).await?;

    // The fallback is itself a well-formed `Content-Disposition` value; the
    // previous bare "Unknown.bin" did not convey a file name.
    let disposition = HeaderValue::from_str(&format!("attachment; filename=\"{sha256}\""))
        .unwrap_or_else(|_| HeaderValue::from_static("attachment; filename=\"Unknown.bin\""));

    let mut response = Bytes::from(contents).into_response();
    response
        .headers_mut()
        .insert(header::CONTENT_DISPOSITION, disposition);
    response.headers_mut().insert(
        "content-digest",
        HeaderValue::from_str(&digest.content_digest_header())?,
    );
    Ok(response)
}
/// Returns the sample identified by `hash` packed in a CaRT container.
///
/// The container metadata holds the report's SHA-384, SHA-512 and entropy, plus
/// the `file` command output when the report has one. The response carries a
/// `Content-Disposition` attachment header (`<sha256>.cart`) and a
/// `content-digest` header with the SHA-256 of the CaRT bytes themselves.
///
/// # Errors
///
/// * [`NO_SAMPLES_STORED_ERROR`] if the server has no sample directory.
/// * Any other failure (unparsable hash, database lookup, reading the bytes,
///   packing, building the digest header) is converted into an [`HttpError`].
async fn download_sample_cart(
    Path(hash): Path<String>,
    Extension(user): Extension<Arc<UserInfo>>,
    Extension(state): Extension<Arc<State>>,
) -> Result<Response, HttpError> {
    if state.directory.is_none() {
        return Err(NO_SAMPLES_STORED_ERROR);
    }

    let requested = HashType::try_from(hash.as_str())?;
    let sha256 = state.db_type.retrieve_sample(user.id, &requested).await?;
    let report = state
        .db_type
        .get_sample_report(user.id, &requested)
        .await?;
    let contents = state.retrieve_bytes(&sha256).await?;

    let mut output_metadata = cart_container::JsonMap::new();
    output_metadata.insert("sha384".into(), report.sha384.into());
    output_metadata.insert("sha512".into(), report.sha512.into());
    output_metadata.insert("entropy".into(), report.entropy.into());
    if let Some(filecmd) = report.filecommand {
        output_metadata.insert("file".into(), filecmd.into());
    }

    let contents_cursor = Cursor::new(contents);
    let mut output_cursor = Cursor::new(vec![]);
    cart_container::pack_stream(
        contents_cursor,
        &mut output_cursor,
        Some(output_metadata),
        None,
        cart_container::digesters::default_digesters(),
        None,
    )?;

    // The digest header describes the CaRT payload being sent, not the sample.
    let hash_b64 = general_purpose::STANDARD.encode(Sha256::digest(output_cursor.get_ref()));

    // The fallback is itself a well-formed `Content-Disposition` value; the
    // previous bare "Unknown.cart" did not convey a file name.
    let disposition = HeaderValue::from_str(&format!("attachment; filename=\"{sha256}.cart\""))
        .unwrap_or_else(|_| HeaderValue::from_static("attachment; filename=\"Unknown.cart\""));

    let mut response = Bytes::from(output_cursor.into_inner()).into_response();
    response
        .headers_mut()
        .insert(header::CONTENT_DISPOSITION, disposition);
    response.headers_mut().insert(
        "content-digest",
        HeaderValue::from_str(&format!("sha-256=:{hash_b64}:"))?,
    );
    Ok(response)
}
/// Returns the database's report for the sample identified by `hash`, as
/// visible to the authenticated caller.
///
/// # Errors
///
/// An unparsable hash or a database failure is converted into an [`HttpError`].
async fn sample_report(
    Path(hash): Path<String>,
    Extension(user): Extension<Arc<UserInfo>>,
    Extension(state): Extension<Arc<State>>,
) -> Result<Json<ServerResponse<Report>>, HttpError> {
    let parsed = HashType::try_from(hash.as_str())?;
    let found: Report = state.db_type.get_sample_report(user.id, &parsed).await?;
    Ok(Json(ServerResponse::Success(found)))
}
/// Asks the database for samples similar to the hashes in the request and
/// returns the matches.
///
/// # Errors
///
/// Any database failure is converted into an [`HttpError`].
async fn find_similar(
    Extension(state): Extension<Arc<State>>,
    Extension(user): Extension<Arc<UserInfo>>,
    Json(payload): Json<SimilarSamplesRequest>,
) -> Result<Json<ServerResponse<SimilarSamplesResponse>>, HttpError> {
    let db = &state.db_type;
    let results = db.find_similar_samples(user.id, &payload.hashes).await?;

    let body = SimilarSamplesResponse {
        results,
        message: None,
    };
    Ok(Json(ServerResponse::Success(body)))
}
/// Registers a new Yara search: compiles the submitted rules, stores the rule
/// text and the serialized compiled rules in the database, and returns the
/// search's UUID.
///
/// # Errors
///
/// * With the `yara` feature: [`NO_SAMPLES_STORED_ERROR`] if the server has no
///   sample directory; compile, serialize and database failures are converted
///   into an [`HttpError`].
/// * Without the `yara` feature: always `ServerError::Unsupported` with a `500`
///   status.
// The parameters are unused when the `yara` feature is disabled.
#[allow(unused_variables)]
async fn yara_search(
    Extension(state): Extension<Arc<State>>,
    Extension(user): Extension<Arc<UserInfo>>,
    Json(payload): Json<YaraSearchRequest>,
) -> Result<Json<ServerResponse<YaraSearchRequestResponse>>, HttpError> {
    #[cfg(feature = "yara")]
    {
        if state.directory.is_none() {
            return Err(NO_SAMPLES_STORED_ERROR);
        }

        // Keep the textual rules; `payload` is consumed by the compiler below.
        let rule_text = payload.rules.join("\n");
        let compiled = crate::yara::compile_yara_rules(payload)?;
        let compiled_bytes = compiled.serialize()?;
        let uuid = state
            .db_type
            .add_yara_search(user.id, &rule_text, &compiled_bytes)
            .await?;
        Ok(Json(ServerResponse::Success(YaraSearchRequestResponse {
            uuid,
        })))
    }
    #[cfg(not(feature = "yara"))]
    {
        tracing::warn!("Received Yara search request, but Yara support is not enabled");
        Err(HttpError(
            ServerError::Unsupported,
            StatusCode::INTERNAL_SERVER_ERROR,
        ))
    }
}
/// Fetches the results the database holds for a previously submitted Yara
/// search, looked up by UUID and the caller's user ID.
///
/// # Errors
///
/// * With the `yara` feature: [`NO_SAMPLES_STORED_ERROR`] if the server has no
///   sample directory; database failures are converted into an [`HttpError`].
/// * Without the `yara` feature: always `ServerError::Unsupported` with a `500`
///   status.
// The parameters are unused when the `yara` feature is disabled.
#[allow(unused_variables)]
async fn yara_search_get(
    Path(uuid): Path<Uuid>,
    Extension(state): Extension<Arc<State>>,
    Extension(user): Extension<Arc<UserInfo>>,
) -> Result<Json<ServerResponse<YaraSearchResponse>>, HttpError> {
    #[cfg(feature = "yara")]
    {
        if state.directory.is_none() {
            return Err(NO_SAMPLES_STORED_ERROR);
        }

        let results = state.db_type.get_yara_results(uuid, user.id).await?;
        Ok(Json(ServerResponse::Success(results)))
    }
    #[cfg(not(feature = "yara"))]
    {
        tracing::warn!("Received Yara search request, but Yara support is not enabled");
        Err(HttpError(
            ServerError::Unsupported,
            StatusCode::INTERNAL_SERVER_ERROR,
        ))
    }
}
/// Error returned by the download and Yara handlers when `State::directory` is
/// `None`: `ServerError::NoSamples` with a `406 Not Acceptable` status.
const NO_SAMPLES_STORED_ERROR: HttpError =
HttpError(ServerError::NoSamples, StatusCode::NOT_ACCEPTABLE);
/// Error type returned by the HTTP handlers: the API-level error (`.0`) that is
/// serialized into the response body, paired with the HTTP status code (`.1`)
/// the response is sent with.
pub struct HttpError(pub ServerError, pub StatusCode);
impl IntoResponse for HttpError {
fn into_response(self) -> Response {
let response: ServerResponse<String> = ServerResponse::Error(self.0);
match serde_json::to_string(&response) {
Ok(json) => (self.1, json).into_response(),
Err(_) => self.1.into_response(),
}
}
}
impl Display for HttpError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl<E> From<E> for HttpError
where
E: Into<anyhow::Error>,
{
fn from(_err: E) -> Self {
Self(ServerError::ServerError, StatusCode::INTERNAL_SERVER_ERROR)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::crypto::{EncryptionOption, FileEncryption};
use crate::db::DatabaseType;
use crate::StateBuilder;
use malwaredb_api::PartialHashSearchType;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Once, RwLock};
use std::time::{Instant, SystemTime};
use std::{env, fs};
use anyhow::Context;
use axum::body::Body;
use axum::http::Request;
use chrono::Local;
use constcat::concat;
use deadpool_postgres::tokio_postgres::{Config, NoTls};
use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
use http::header::CONTENT_TYPE;
use http_body_util::BodyExt;
use rstest::rstest;
use tower::ServiceExt;
use uuid::Uuid;
/// Username the tests log in with.
const ADMIN_UNAME: &str = "admin";
/// Password the tests set for `ADMIN_UNAME` before logging in.
const ADMIN_PASSWORD: &str = "password12345";
/// Sample file embedded at compile time and uploaded by the client integration tests.
const SAMPLE_BYTES: &[u8] = include_bytes!("../../../types/testdata/elf/elf_haiku_x86");
/// Guard used with `call_once` so `init_tracing` runs a single time per test process.
static TRACING: Once = Once::new();
/// Installs a `tracing_subscriber` formatter with the maximum level set to
/// `TRACE`. Invoked through `TRACING.call_once(..)` by the test helpers.
fn init_tracing() {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::TRACE)
.init();
}
async fn state(compress: bool, encrypt: bool) -> (Arc<State>, u32) {
TRACING.call_once(init_tracing);
let mut db_file = env::temp_dir();
db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
if std::path::Path::new(&db_file).exists() {
fs::remove_file(&db_file)
.context(format!("failed to delete old SQLite file {db_file:?}"))
.unwrap();
}
let db_type =
DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
.await
.context(format!("failed to create SQLite instance for {db_file:?}"))
.unwrap();
if compress {
db_type.enable_compression().await.unwrap();
}
let keys = if encrypt {
let key = FileEncryption::from(EncryptionOption::Xor);
let key_id = db_type.add_file_encryption_key(&key).await.unwrap();
let mut keys = HashMap::new();
keys.insert(key_id, key);
keys
} else {
HashMap::new()
};
let db_config = db_type.get_config().await.unwrap();
let state = State {
port: 8080,
directory: Some(
tempfile::TempDir::with_prefix("mdb-temp-samples")
.unwrap()
.path()
.into(),
),
max_upload: 10 * 1024 * 1024,
ip: "127.0.0.1".parse().unwrap(),
db_type: Arc::new(db_type),
db_config,
keys,
started: SystemTime::now(),
#[cfg(feature = "vt")]
vt_client: None,
tls_config: None,
mdns: None,
};
state
.db_type
.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
.await
.context("Failed to set admin password")
.unwrap();
let source_id = state
.db_type
.create_source("temp-source", None, None, Local::now(), true, Some(false))
.await
.unwrap();
state
.db_type
.add_group_to_source(0, source_id)
.await
.unwrap();
(Arc::new(state), source_id)
}
async fn state_and_token(pem_file: bool, postgres: bool) -> (State, u32, String) {
TRACING.call_once(init_tracing);
let mut builder = if postgres {
const CONNECTION_STRING: &str =
"user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
let pg_config = CONNECTION_STRING.parse::<Config>().unwrap();
let mgr_config = ManagerConfig {
recycling_method: RecyclingMethod::Fast,
};
let mgr = Manager::from_config(pg_config, NoTls, mgr_config);
let pool = Pool::builder(mgr)
.max_size(num_cpus::get().min(16))
.build()
.unwrap();
let client = pool.get().await.unwrap();
client
.batch_execute("DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
.await
.unwrap();
StateBuilder::new(concat!("postgres ", CONNECTION_STRING), None)
.await
.unwrap()
} else {
let mut db_file = env::temp_dir();
db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
if std::path::Path::new(&db_file).exists() {
fs::remove_file(&db_file)
.context(format!("failed to delete old SQLite file {db_file:?}"))
.unwrap();
}
StateBuilder::new(&format!("file:{}", db_file.display()), None)
.await
.unwrap()
};
builder = builder.directory(PathBuf::from(
tempfile::TempDir::with_prefix("mdb-temp-samples")
.unwrap()
.path(),
));
if pem_file {
builder = builder.port(8443);
builder = builder
.tls(
"../../testdata/server_ca_cert.pem".into(),
"../../testdata/server_key.pem".into(),
)
.await
.unwrap();
} else {
builder = builder.port(8444);
builder = builder
.tls(
"../../testdata/server_cert.der".into(),
"../../testdata/server_key.der".into(),
)
.await
.unwrap();
}
let state = builder.into_state().await.unwrap();
state
.db_type
.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
.await
.context("Failed to set admin password")
.unwrap();
let source_id = state
.db_type
.create_source("temp-source", None, None, Local::now(), true, Some(false))
.await
.unwrap();
state
.db_type
.add_group_to_source(0, source_id)
.await
.unwrap();
let token = state
.db_type
.authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
.await
.unwrap();
(state, source_id, token)
}
async fn get_key(state: Arc<State>) -> String {
TRACING.call_once(init_tracing);
let key_request = serde_json::to_string(&GetAPIKeyRequest {
user: ADMIN_UNAME.into(),
password: ADMIN_PASSWORD.into(),
})
.context("Failed to convert API key request to JSON")
.unwrap();
let request = Request::builder()
.method("POST")
.uri(malwaredb_api::USER_LOGIN_URL)
.header(CONTENT_TYPE, "application/json")
.body(Body::from(key_request))
.unwrap();
let response = app(state)
.oneshot(request)
.await
.context("failed to send/receive login request")
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let bytes = response
.into_body()
.collect()
.await
.expect("failed to collect response body to bytes")
.to_bytes();
let json_response = String::from_utf8(bytes.to_ascii_lowercase())
.context("failed to convert response to string")
.unwrap();
let response: ServerResponse<GetAPIKeyResponse> = serde_json::from_str(&json_response)
.context("failed to convert json response to object")
.unwrap();
if let ServerResponse::Success(response) = response {
let key = response.key.clone();
assert_eq!(key.len(), 64);
key
} else {
panic!("failed to get API key response")
}
}
#[tokio::test]
async fn about_self() {
let (state, _) = state(false, false).await;
let api_key = get_key(state.clone()).await;
let request = Request::builder()
.method("GET")
.uri(malwaredb_api::USER_INFO_URL)
.header(MDB_API_HEADER, &api_key)
.body(Body::empty())
.unwrap();
let response = app(state.clone())
.oneshot(request)
.await
.context("failed to send/receive login request")
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let bytes = response
.into_body()
.collect()
.await
.expect("failed to collect response body to bytes")
.to_bytes();
let json_response = String::from_utf8(bytes.to_ascii_lowercase())
.context("failed to convert response to string")
.unwrap();
let response: ServerResponse<GetUserInfoResponse> = serde_json::from_str(&json_response)
.context("failed to convert json response to object")
.unwrap();
let response = response.unwrap();
assert_eq!(response.id, 0);
assert!(response.is_admin);
assert!(!response.is_readonly);
assert_eq!(response.username, "admin");
let request = Request::builder()
.method("GET")
.uri(malwaredb_api::LIST_LABELS_URL)
.header(MDB_API_HEADER, &api_key)
.body(Body::empty())
.unwrap();
let response = app(state)
.oneshot(request)
.await
.context("failed to send/receive login request")
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let bytes = response
.into_body()
.collect()
.await
.expect("failed to collect response body to bytes")
.to_bytes();
let json_response = String::from_utf8(bytes.to_ascii_lowercase())
.context("failed to convert response to string")
.unwrap();
let response: ServerResponse<Labels> = serde_json::from_str(&json_response)
.context("failed to convert json response to object")
.unwrap();
let response = response.unwrap();
assert!(response.is_empty());
}
#[rstest]
#[case::elf_encrypt_cart(include_bytes!("../../../types/testdata/elf/elf_haiku_x86.cart"), false, true, true, false)]
#[case::pe32(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), false, false, false, false)]
#[case::pdf_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), false, true, false, false)]
#[case::rtf(include_bytes!("../../../types/testdata/rtf/hello.rtf"), false, false, false, false)]
#[case::elf_compress_encrypt(include_bytes!("../../../types/testdata/elf/elf_haiku_x86"), true, true, false, false)]
#[case::pe32_compress(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), true, false, false, false)]
#[case::pdf_compress_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), true, true, false, false)]
#[case::rtf_compress(include_bytes!("../../../types/testdata/rtf/hello.rtf"), true, false, false, false)]
#[case::icon_unknown_type_proxy(include_bytes!("../../../../MDB_Logo.ico"), false, false, false, true)]
#[tokio::test]
async fn submit_sample(
#[case] contents: &[u8],
#[case] compress: bool,
#[case] encrypt: bool,
#[case] cart: bool,
#[case] should_fail: bool,
) {
let (state, source_id) = state(compress, encrypt).await;
let api_key = get_key(state.clone()).await;
let file_contents_b64 = general_purpose::STANDARD.encode(contents);
let mut hasher = Sha256::new();
hasher.update(contents);
let sha256 = hex::encode(hasher.finalize());
let upload = serde_json::to_string(&NewSampleB64 {
file_name: "some_sample".into(),
source_id,
file_contents_b64,
sha256: sha256.clone(),
})
.context("failed to create upload structure")
.unwrap();
let request = Request::builder()
.method("POST")
.uri(malwaredb_api::UPLOAD_SAMPLE_JSON_URL)
.header(CONTENT_TYPE, "application/json")
.header(MDB_API_HEADER, &api_key)
.body(Body::from(upload))
.unwrap();
let response = app(state.clone())
.oneshot(request)
.await
.context("failed to send/receive upload request/response")
.unwrap();
if should_fail {
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
return;
}
assert_eq!(response.status(), StatusCode::OK);
let sha256 = if cart {
let mut input_buffer = Cursor::new(contents);
let mut output_buffer = Cursor::new(vec![]);
let (_, footer) =
cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)
.expect("failed to decode CaRT file");
let footer = footer.expect("CaRT should have had a footer");
let sha256 = footer
.get("sha256")
.expect("CaRT footer should have had an entry for SHA-256")
.to_string();
sha256.replace('"', "") } else {
sha256
};
if let Some(dir) = &state.directory {
let mut sample_path = dir.clone();
sample_path.push(format!(
"{}/{}/{}/{}",
&sha256[0..2],
&sha256[2..4],
&sha256[4..6],
sha256
));
eprintln!("Submitted sample should exist at {sample_path:?}.");
assert!(sample_path.exists());
if compress {
let sample_size_on_disk = sample_path.metadata().unwrap().len();
eprintln!(
"Original size: {}, compressed: {}",
contents.len(),
sample_size_on_disk
);
assert!(sample_size_on_disk < contents.len() as u64);
}
} else {
panic!("Directory was set for the state, but is now `None`");
}
let request = Request::builder()
.method("GET")
.uri(format!("{}/{sha256}", malwaredb_api::SAMPLE_REPORT_URL))
.header(MDB_API_HEADER, &api_key)
.body(Body::empty())
.unwrap();
let response = app(state.clone())
.oneshot(request)
.await
.context("failed to send/receive upload request/response")
.unwrap();
println!("Response headers: {:?}", response.headers());
let bytes = response
.into_body()
.collect()
.await
.expect("failed to collect response body to bytes")
.to_bytes();
let json_response = String::from_utf8(bytes.to_ascii_lowercase())
.context("failed to convert response to string")
.unwrap();
let report: ServerResponse<Report> = serde_json::from_str(&json_response)
.context("failed to convert json response to object")
.unwrap();
let report = report.unwrap();
assert_eq!(report.sha256, sha256);
println!("Report: {report}");
let request = Request::builder()
.method("GET")
.uri(format!(
"{}/{sha256}",
malwaredb_api::DOWNLOAD_SAMPLE_CART_URL
))
.header(MDB_API_HEADER, api_key.clone())
.body(Body::empty())
.unwrap();
let response = app(state.clone())
.oneshot(request)
.await
.context("failed to send/receive upload request/response for CaRT")
.unwrap();
println!("CaRT Response headers: {:?}", response.headers());
let bytes = response
.into_body()
.collect()
.await
.expect("failed to collect response body to bytes")
.to_bytes();
let bytes = bytes.to_vec();
let bytes_input = Cursor::new(bytes);
let output = Cursor::new(vec![]);
match cart_container::unpack_stream(bytes_input, output, None) {
Ok((header, _)) => {
let header = header.unwrap();
assert_eq!(
header.get("sha384"),
Some(&serde_json::to_value(report.sha384).unwrap())
);
}
Err(e) => panic!("{e}"),
}
}
/// End-to-end test: starts a real TLS server on a spawned task and drives it
/// with the async `malwaredb_client::MdbClient`: type listing, server info,
/// uploads (base64 twice, then CBOR), report, similarity, partial search,
/// optional Yara search, and API-key reset.
///
/// The steps depend on each other (later queries assume the earlier upload
/// succeeded), so their order matters.
// NOTE(review): the position of `#[ignore]` between the `#[case]` attributes is
// presumably deliberate, so that only the Postgres case is ignored; confirm
// against the rstest documentation before reordering these attributes.
#[rstest]
#[case::ssl_pem_sqlite(true, false)]
#[case::ssl_der_sqlite(false, false)]
#[ignore = "don't run this in CI"]
#[case::ssl_der_postgres(false, true)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn client_integration(#[case] pem: bool, #[case] postgres: bool) {
TRACING.call_once(init_tracing);
let (state, source_id, token) = state_and_token(pem, postgres).await;
// Read the port before `state` is moved into the server task.
let state_port = state.port;
let server = tokio::spawn(async move {
state
.serve()
.await
.expect("MalwareDB failed to .serve() in tokio::spawn()");
});
assert!(!server.is_finished());
// Give the server a moment to start before connecting.
tokio::time::sleep(std::time::Duration::new(1, 0)).await;
println!("SemVer of MalwareDB: {:?}", *crate::MDB_VERSION_SEMVER);
assert_eq!(
*malwaredb_client::MDB_VERSION_SEMVER,
*crate::MDB_VERSION_SEMVER,
"SemVer parsing of MDB version failed"
);
let mdb_client = malwaredb_client::MdbClient::new(
format!("https://127.0.0.1:{state_port}"),
token.clone(),
Some("../../testdata/ca_cert.pem".into()),
)
.unwrap();
assert!(!mdb_client.supported_types().await.unwrap().types.is_empty());
assert!(mdb_client.server_info().await.is_ok());
// First upload of the sample; the duration is printed for reference only.
let start = Instant::now();
assert!(mdb_client
.submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
.await
.context("failed to upload test file")
.unwrap());
let duration = start.elapsed();
println!("Initial upload and database record creation via base64 took {duration:?}");
// Second upload of the same bytes through the same endpoint.
let start = Instant::now();
mdb_client
.submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
.await
.context("failed to upload test file")
.unwrap();
let duration = start.elapsed();
println!("Upload again via base64 took {duration:?}");
// Third upload of the same bytes, this time through the CBOR endpoint.
let start = Instant::now();
mdb_client
.submit_as_cbor(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
.await
.context("failed to upload test file")
.unwrap();
let duration = start.elapsed();
println!("Upload again via cbor took {duration:?}");
// Hard-coded hash; presumably the SHA-256 of `SAMPLE_BYTES`.
let report = mdb_client
.report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
.await
.expect("failed to get report for file just submitted");
assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");
let similar = mdb_client
.similar(SAMPLE_BYTES)
.await
.expect("failed to query for files similar to what was just submitted");
assert_eq!(similar.results.len(), 1);
// Partial search using the first ten characters of the MD5 should find the sample.
let search = mdb_client
.partial_search(
Some((PartialHashSearchType::MD5, String::from(&report.md5[0..10]))),
None,
PartialHashSearchType::Any,
10,
)
.await
.unwrap();
assert_eq!(search.hashes.len(), 1);
// A partial hash that matches nothing should give an empty result.
let search = mdb_client
.partial_search(
Some((PartialHashSearchType::Any, "AAAA".into())),
None,
PartialHashSearchType::Any,
10,
)
.await
.unwrap();
assert!(search.hashes.is_empty());
#[cfg(feature = "yara")]
{
let elf_yara = "rule elf_file {
condition:
uint32(0) == 0x464c457f
}";
let mut counter = 0;
let mut successful = false;
let yara_result = mdb_client.yara_search(elf_yara).await.unwrap();
// Poll for results: up to five attempts, two seconds apart.
while counter < 5 {
if let Ok(yara_response) = mdb_client.yara_result(yara_result.uuid).await {
if !yara_response.results.is_empty() {
println!("Yara search upload response: {:?}", yara_response.results);
assert_eq!(yara_response.results.get("elf_file").unwrap().len(), 1);
assert_eq!(yara_response.results.len(), 1);
successful = true;
break;
}
}
tokio::time::sleep(std::time::Duration::new(2, 0)).await;
counter += 1;
}
assert!(successful);
}
mdb_client.reset_key().await.expect("failed to reset key");
// Stop the server task now that the test is done with it.
server.abort();
}
/// End-to-end test of the blocking client: a plain HTTP server is started on
/// port 9090 inside a dedicated thread that owns its own Tokio runtime, and is
/// then driven with `malwaredb_client::blocking::MdbClient`.
///
/// The API token is handed from the server thread to the test thread through a
/// shared `RwLock<String>`.
// One long, linear scenario; splitting it up would not make it clearer.
#[allow(clippy::too_many_lines)]
#[test]
#[ignore = "don't run this in CI"]
fn client_integration_blocking() {
TRACING.call_once(init_tracing);
let token = Arc::new(RwLock::new(String::new()));
let token_clone = token.clone();
let thread = std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let mut db_file = env::temp_dir();
db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
if std::path::Path::new(&db_file).exists() {
fs::remove_file(&db_file)
.context(format!("failed to delete old SQLite file {db_file:?}"))
.unwrap();
}
let (db_type, db_config) = rt.block_on(async {
let db_type =
DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
.await
.context(format!("failed to create SQLite instance for {db_file:?}"))
.unwrap();
let db_config = db_type.get_config().await.unwrap();
(db_type, db_config)
});
let state = State {
port: 9090,
directory: Some(
tempfile::TempDir::with_prefix("mdb-temp-samples")
.unwrap()
.path()
.into(),
),
max_upload: 10 * 1024 * 1024,
ip: "127.0.0.1".parse().unwrap(),
db_type: Arc::new(db_type),
db_config,
keys: HashMap::default(),
started: SystemTime::now(),
#[cfg(feature = "vt")]
vt_client: None,
tls_config: None,
mdns: None,
};
rt.block_on(async {
state
.db_type
.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
.await
.context("Failed to set admin password")
.unwrap();
let source_id = state
.db_type
.create_source("temp-source", None, None, Local::now(), true, Some(false))
.await
.unwrap();
state
.db_type
.add_group_to_source(0, source_id)
.await
.unwrap();
let token_string = state
.db_type
.authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
.await
.unwrap();
// Publish the token for the test thread before starting to serve.
if let Ok(mut token_lock) = token_clone.write() {
*token_lock = token_string;
}
state
.serve()
.await
.expect("MalwareDB failed to .serve() in tokio::spawn()");
});
});
// NOTE(review): the token is read after a fixed one-second wait; this assumes the
// server thread has published it (and started serving) by then.
std::thread::sleep(std::time::Duration::from_secs(1));
let mdb_client = malwaredb_client::blocking::MdbClient::new(
String::from("http://127.0.0.1:9090"),
token.read().unwrap().clone(),
None,
)
.unwrap();
let types = match mdb_client.supported_types() {
Ok(types) => types,
Err(e) => panic!("{e}"),
};
assert!(!types.types.is_empty());
// NOTE(review): the source ID is hard-coded to 1 here rather than taken from
// `create_source` in the server thread; presumably the first source gets ID 1.
assert!(mdb_client
.submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), 1)
.context("failed to upload test file")
.unwrap());
let report = mdb_client
.report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
.expect("failed to get report for file just submitted");
assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");
let similar = mdb_client
.similar(SAMPLE_BYTES)
.expect("failed to query for files similar to what was just submitted");
assert_eq!(similar.results.len(), 1);
// A partial hash that matches nothing should give an empty result.
let search = mdb_client
.partial_search(
Some((PartialHashSearchType::Any, "AAAA".into())),
None,
PartialHashSearchType::Any,
10,
)
.unwrap();
assert!(search.hashes.is_empty());
mdb_client.reset_key().expect("failed to reset key");
// Dropping the handle detaches the server thread instead of joining it.
drop(thread); }
}