use super::processing;
use super::State;
use crate::db::types::FileTypes;
use malwaredb_api::{
DownloadSampleRequest, EmptyAuthenticatingPost, GetAPIKeyRequest, GetAPIKeyResponse,
GetUserInfoResponse, Labels, NewSample, Report, ServerInfo, SupportedFileTypes,
};
use std::sync::Arc;
use axum::body::Bytes;
use axum::extract::DefaultBodyLimit;
use axum::http::{header, StatusCode};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Extension, Json, Router};
use base64::{engine::general_purpose, Engine as _};
use sha2::{Digest, Sha256};
use tracing::debug;
pub fn app(state: Arc<State>) -> Router {
Router::new()
.route("/", get(health))
.route(malwaredb_api::SERVER_INFO, get(get_mdb_info))
.route(malwaredb_api::USER_LOGIN_URL, post(user_login))
.route(malwaredb_api::USER_LOGOUT_URL, post(user_logout))
.route(malwaredb_api::USER_INFO_URL, post(get_user_groups_sources))
.route(
malwaredb_api::SUPPORTED_FILE_TYPES,
get(get_supported_types),
)
.route(malwaredb_api::LIST_LABELS, post(get_labels))
.route(malwaredb_api::UPLOAD_SAMPLE, post(get_new_sample))
.route(malwaredb_api::DOWNLOAD_SAMPLE, post(download_sample))
.route(malwaredb_api::SAMPLE_REPORT, post(sample_report))
.route(malwaredb_api::SIMILAR_SAMPLES, post(find_similar))
.layer(DefaultBodyLimit::max(state.max_upload))
.layer(Extension(state))
}
/// Health check served at `/`: always answers `200 OK` without touching the shared state.
async fn health() -> StatusCode {
    StatusCode::OK
}
/// Returns the server information produced by `State::get_info` as JSON.
///
/// # Errors
///
/// Responds with `500 Internal Server Error` when the information cannot be gathered.
async fn get_mdb_info(
    Extension(state): Extension<Arc<State>>,
) -> Result<Json<ServerInfo>, StatusCode> {
    // Propagate the mapped status code with `?`; unwrapping the mapped result would panic
    // inside the handler instead of answering the request.
    let server_info = state.get_info().await.map_err(|e| {
        debug!("API: ServerInfo Error: {e}");
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    Ok(Json(server_info))
}
/// Exchanges a user name and password for an API key.
///
/// # Errors
///
/// Responds with `401 Unauthorized` when the database rejects the credentials or the
/// authentication call fails for any other reason.
async fn user_login(
    Extension(state): Extension<Arc<State>>,
    Json(payload): Json<GetAPIKeyRequest>,
) -> Result<Json<GetAPIKeyResponse>, StatusCode> {
    let outcome = state
        .db_type
        .authenticate(&payload.user, &payload.password)
        .await;

    match outcome {
        Ok(api_key) => Ok(Json(GetAPIKeyResponse {
            key: Some(api_key),
            message: None,
        })),
        Err(e) => {
            debug!("API: Authentication Error: {e}");
            Err(StatusCode::UNAUTHORIZED)
        }
    }
}
/// Logs the caller out by resetting the API key of the user that owns the supplied key.
///
/// # Errors
///
/// * `401 Unauthorized` when no user id can be resolved from the API key.
/// * `500 Internal Server Error` when the key could not be reset.
async fn user_logout(
    Extension(state): Extension<Arc<State>>,
    Json(payload): Json<EmptyAuthenticatingPost>,
) -> Result<StatusCode, StatusCode> {
    let uid = match state.db_type.get_uid(&payload.key).await {
        Ok(uid) => uid,
        Err(e) => {
            debug!("API Error: {e}");
            return Err(StatusCode::UNAUTHORIZED);
        }
    };

    if let Err(e) = state.db_type.reset_own_api_key(uid).await {
        debug!("API Error clearing own API key: {e}");
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    }

    Ok(StatusCode::OK)
}
async fn get_user_groups_sources(
Extension(state): Extension<Arc<State>>,
Json(payload): Json<EmptyAuthenticatingPost>,
) -> Result<Json<GetUserInfoResponse>, StatusCode> {
let uid = state.db_type.get_uid(&payload.key).await.map_err(|e| {
debug!("API Error: {e}");
StatusCode::UNAUTHORIZED
})?;
let groups_sources = state.db_type.get_user_info(uid).await.map_err(|e| {
debug!("API Error: {e}");
StatusCode::UNAUTHORIZED
})?;
Ok(Json(groups_sources))
}
/// Lists the data types known to the database, converted to the API's `SupportedFileTypes`.
///
/// # Errors
///
/// Responds with `500 Internal Server Error` when the database query fails.
async fn get_supported_types(
    Extension(state): Extension<Arc<State>>,
) -> Result<Json<SupportedFileTypes>, StatusCode> {
    let data_types = match state.db_type.get_known_data_types().await {
        Ok(types) => types,
        Err(e) => {
            debug!("API Error: get_supported_types {e}");
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
    };

    // `FileTypes` is the database-side wrapper that converts into the API type.
    let supported: SupportedFileTypes = FileTypes(data_types).into();
    Ok(Json(supported))
}
async fn get_labels(
Extension(state): Extension<Arc<State>>,
Json(payload): Json<EmptyAuthenticatingPost>,
) -> Result<Json<Labels>, StatusCode> {
let _uid = state.db_type.get_uid(&payload.key).await.map_err(|e| {
debug!("API Error: could not get uid for {} {e}", payload.key);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let labels = state.db_type.get_labels().await.map_err(|e| {
debug!("API Error: could not get list of labels {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(labels))
}
async fn get_new_sample(
Extension(state): Extension<Arc<State>>,
Json(payload): Json<NewSample>,
) -> Result<StatusCode, StatusCode> {
let uid = state.db_type.get_uid(&payload.key).await.map_err(|e| {
debug!("API Error: could not get uid for {} {e}", payload.key);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let allowed = state
.db_type
.allowed_user_source(uid, payload.source_id as i32)
.await
.map_err(|e| {
debug!(
"API Error: could not check access for uid {uid} to sid {}: {e}",
payload.source_id
);
StatusCode::INTERNAL_SERVER_ERROR
})?;
if !allowed {
return Err(StatusCode::UNAUTHORIZED);
}
let received_hash = hex::decode(&payload.sha256).map_err(|e| {
debug!(
"API Error: Failed to convert SHA-256 hash {} to bytes: {e}",
payload.sha256
);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let bytes = general_purpose::STANDARD
.decode(&payload.file_contents_b64)
.map_err(|e| {
debug!("API Error: could not decode base64 for {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let mut hasher = Sha256::new();
hasher.update(&bytes);
let result = hasher.finalize();
if result[..] != received_hash[..] {
return Err(StatusCode::NOT_ACCEPTABLE);
}
processing::receive::incoming_sample(
state.clone(),
bytes,
uid,
payload.source_id as i32,
payload.file_name,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(StatusCode::OK)
}
async fn download_sample(
Extension(state): Extension<Arc<State>>,
Json(payload): Json<DownloadSampleRequest>,
) -> Result<impl IntoResponse, StatusCode> {
if state.directory.is_none() {
return Err(StatusCode::NOT_ACCEPTABLE);
}
let uid = state.db_type.get_uid(&payload.key).await.map_err(|e| {
debug!("API Error: could not get UID from API key for {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let sha256 = state
.db_type
.retrieve_sample(uid, payload.hash)
.await
.map_err(|e| {
debug!("API Error: could not get SHA-256 from the `HashType` {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let contents = state.retrieve_bytes(&sha256).map_err(|e| {
debug!("API Error: could not read the file from disk {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let mut bytes = Bytes::from(contents).into_response();
let name_header_value = format!("attachment; filename=\"{sha256}\"");
bytes.headers_mut().insert(
header::CONTENT_DISPOSITION,
name_header_value.parse().unwrap(),
);
Ok(bytes)
}
async fn sample_report(
Extension(state): Extension<Arc<State>>,
Json(payload): Json<DownloadSampleRequest>,
) -> Result<Json<Report>, StatusCode> {
let uid = state.db_type.get_uid(&payload.key).await.map_err(|e| {
debug!("API Error: could not get UID from API key for {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let report = state
.db_type
.get_sample_report(uid, payload.hash)
.await
.map_err(|e| {
debug!("API Error: could not get SHA-256 from the `HashType` {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(report))
}
async fn find_similar(
Extension(state): Extension<Arc<State>>,
Json(payload): Json<malwaredb_api::SimilarSamplesRequest>,
) -> Result<Json<malwaredb_api::SimilarSamplesResponse>, StatusCode> {
let uid = state.db_type.get_uid(&payload.key).await.map_err(|e| {
debug!("API Error: could not get uid for {} {e}", payload.key);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let results = state
.db_type
.find_similar_samples(uid, &payload.hash)
.await
.map_err(|e| {
debug!("API Error: could not find similar files {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(malwaredb_api::SimilarSamplesResponse {
results,
message: None,
}))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::sqlite::Sqlite;
use crate::db::DatabaseType;
use crate::State;
use std::time::SystemTime;
use std::{env, fs};
use anyhow::Context;
use axum::body::Body;
use axum::http::Request;
use chrono::Local;
use http::header::CONTENT_TYPE;
use malwaredb_api::digest::HashType;
use rstest::rstest;
use tower::ServiceExt; use uuid::Uuid;
const ADMIN_UNAME: &str = "admin";
const ADMIN_PASSWORD: &str = "password12345";
/// Creates a fresh server state for a test, backed by a new SQLite file and a new sample
/// directory.
///
/// Returns the state together with the id of a source named `temp-source` that group `0` has
/// been given access to. The admin password is set to [`ADMIN_PASSWORD`].
async fn state() -> (Arc<State>, i32) {
    let db_file = env::temp_dir().join(format!("testing_sqlite_{}.db", Uuid::new_v4()));
    if db_file.exists() {
        fs::remove_file(&db_file)
            .context(format!("failed to delete old SQLite file {db_file:?}"))
            .unwrap();
    }

    // `into_path` keeps the directory on disk after the `TempDir` guard is gone.
    let sample_dir = tempfile::TempDir::with_prefix("mdb-temp-samples")
        .unwrap()
        .into_path();

    let sqlite = Sqlite::new(db_file.to_str().unwrap())
        .context(format!("failed to create SQLite instance for {db_file:?}"))
        .unwrap();

    let state = State {
        port: 8080,
        directory: Some(sample_dir),
        max_upload: 10 * 1024 * 1024,
        ip: "127.0.0.1".parse().unwrap(),
        db_type: DatabaseType::SQLite(sqlite),
        started: SystemTime::now(),
    };

    let db = &state.db_type;
    db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
        .await
        .context("Failed to set admin password")
        .unwrap();

    let source_id = db
        .create_source("temp-source", None, None, Local::now(), true)
        .await
        .unwrap();

    db.add_group_to_source(0, source_id).await.unwrap();

    (Arc::new(state), source_id)
}
/// Logs in as the admin user through the HTTP API and returns the issued API key.
///
/// Panics if the request fails, the response is not `200 OK`, or the key is not 36 characters
/// long.
async fn get_key(state: Arc<State>) -> String {
    let key_request = serde_json::to_string(&GetAPIKeyRequest {
        user: ADMIN_UNAME.into(),
        password: ADMIN_PASSWORD.into(),
    })
    .context("Failed to convert API key request to JSON")
    .unwrap();

    let request = Request::builder()
        .method("POST")
        .uri(malwaredb_api::USER_LOGIN_URL)
        .header(CONTENT_TYPE, "application/json")
        .body(Body::from(key_request))
        .unwrap();

    let response = app(state)
        .oneshot(request)
        .await
        .context("failed to send/receive login request")
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    // Parse the body exactly as received: lower-casing it first would alter case-sensitive
    // values such as the key itself and hide casing problems from the test.
    let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
    let response: GetAPIKeyResponse = serde_json::from_slice(&body)
        .context("failed to convert json response to object")
        .unwrap();

    let key = response.key.unwrap();
    assert_eq!(key.len(), 36);
    key
}
/// Checks that the admin user can query their own account information and that a fresh
/// database reports no labels.
#[tokio::test]
async fn about_self() {
    let (state, _) = state().await;
    let api_key = get_key(state.clone()).await;

    let auth = serde_json::to_string(&EmptyAuthenticatingPost { key: api_key })
        .context("failed to make JSON from Auth struct")
        .unwrap();

    // Who am I?
    let request = Request::builder()
        .method("POST")
        .uri(malwaredb_api::USER_INFO_URL)
        .header(CONTENT_TYPE, "application/json")
        .body(Body::from(auth.clone()))
        .unwrap();

    let response = app(state.clone())
        .oneshot(request)
        .await
        .context("failed to send/receive user info request")
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    // Parse the body exactly as received; lower-casing it first would alter case-sensitive
    // values before they are checked.
    let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
    let response: GetUserInfoResponse = serde_json::from_slice(&body)
        .context("failed to convert json response to object")
        .unwrap();

    assert_eq!(response.id, 0);
    assert!(response.is_admin);
    assert_eq!(response.username, "admin");

    // Which labels exist?
    let request = Request::builder()
        .method("POST")
        .uri(malwaredb_api::LIST_LABELS)
        .header(CONTENT_TYPE, "application/json")
        .body(Body::from(auth))
        .unwrap();

    let response = app(state)
        .oneshot(request)
        .await
        .context("failed to send/receive labels request")
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
    let response: Labels = serde_json::from_slice(&body)
        .context("failed to convert json response to object")
        .unwrap();

    assert!(response.is_empty());
}
#[rstest]
#[case(include_bytes!("../../../types/testdata/elf/elf_haiku_x86"))]
#[case(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"))]
#[case(include_bytes!("../../../types/testdata/pdf/test.pdf"))]
#[case(include_bytes!("../../../types/testdata/rtf/hello.rtf"))]
#[tokio::test]
async fn submit_sample(#[case] contents: &[u8]) {
let (state, source_id) = state().await;
let api_key = get_key(state.clone()).await;
let file_contents_b64 = general_purpose::STANDARD.encode(contents);
let mut hasher = Sha256::new();
hasher.update(contents);
let sha256 = hex::encode(hasher.finalize());
let upload = serde_json::to_string(&NewSample {
file_name: "some_sample".into(),
key: api_key.clone(),
source_id: source_id as u32,
file_contents_b64,
sha256: sha256.clone(),
})
.context("failed to create upload structure")
.unwrap();
let request = Request::builder()
.method("POST")
.uri(malwaredb_api::UPLOAD_SAMPLE)
.header(CONTENT_TYPE, "application/json")
.body(Body::from(upload))
.unwrap();
let response = app(state.clone())
.oneshot(request)
.await
.context("failed to send/receive upload request/response")
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
if let Some(dir) = &state.directory {
let mut sample_path = dir.clone();
sample_path.push(format!(
"{}/{}/{}/{}",
&sha256[0..2],
&sha256[2..4],
&sha256[4..6],
sha256
));
eprintln!("Submitted sample should exist at {sample_path:?}.");
assert!(sample_path.exists());
} else {
panic!("Directory was set for the state, but is now `None`");
}
let request = serde_json::to_string(&DownloadSampleRequest {
hash: HashType::try_from(sha256.clone()).expect("failed to get HashType from string"),
key: api_key,
})
.context("failed to create report request structure")
.unwrap();
let request = Request::builder()
.method("POST")
.uri(malwaredb_api::SAMPLE_REPORT)
.header(CONTENT_TYPE, "application/json")
.body(Body::from(request))
.unwrap();
let response = app(state.clone())
.oneshot(request)
.await
.context("failed to send/receive upload request/response")
.unwrap();
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
let json_response = String::from_utf8(body.to_ascii_lowercase())
.context("failed to convert response to string")
.unwrap();
let report: Report = serde_json::from_str(&json_response)
.context("failed to convert json response to object")
.unwrap();
assert_eq!(report.sha256, sha256);
println!("Report: {report}");
}
}