use std::fmt::{Debug, Formatter};
use std::path::{Path, PathBuf};
use crate::{get_config_path, MDB_CLIENT_ERROR_CONTEXT};
use malwaredb_api::{
digest::HashType, GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType,
Report, SearchRequest, SearchRequestParameters, SearchResponse, SearchType, ServerInfo,
ServerResponse, SimilarSamplesResponse, Sources, SupportedFileTypes, YaraSearchRequest,
YaraSearchRequestResponse, YaraSearchResponse,
};
use malwaredb_types::exec::pe32::EXE;
use anyhow::{bail, ensure, Context, Result};
use base64::engine::general_purpose;
use base64::Engine;
use fuzzyhash::FuzzyHash;
use malwaredb_lzjd::{LZDict, Murmur3HashState};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tlsh_fixed::TlshBuilder;
use tracing::{error, info, trace, warn};
use uuid::Uuid;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Blocking (synchronous) HTTP client for a MalwareDB server.
///
/// Instances are created with [`MdbClient::new`] or [`MdbClient::login`], or loaded with
/// [`MdbClient::load`] / [`MdbClient::from_file`]. The struct derives `Serialize` and
/// `Deserialize` so it can be written to and read from the TOML config file, and
/// `ZeroizeOnDrop` so the fields not marked `#[zeroize(skip)]` (`url`, `api_key`) are
/// zeroed when the client is dropped.
#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
pub struct MdbClient {
/// Base URL of the server. `new` and `login` strip a single trailing `/` so endpoint
/// paths can be appended directly.
pub url: String,
/// API key sent in the `MDB_API_HEADER` header. It is not `serde(skip)`ped, so it is
/// written into the config file by `save`.
api_key: String,
/// Underlying HTTP client. Not serialised; on deserialisation serde fills it with
/// `Default::default()`, which does not include the user agent or root certificate that
/// `new`/`login` configure.
#[zeroize(skip)]
#[serde(skip)]
client: reqwest::blocking::Client,
/// macOS only: the custom root certificate (its type and raw bytes), kept so it can be
/// stored alongside the credentials through `crate::macos`. Not serialised.
#[cfg(target_os = "macos")]
#[zeroize(skip)]
#[serde(skip)]
cert: Option<crate::macos::CertificateData>,
}
impl MdbClient {
/// Create a client for the server at `url` using an existing API key.
///
/// A single trailing `/` is removed from `url`. When `cert_path` is given, the certificate
/// at that path is loaded and trusted as an additional root certificate; on macOS its raw
/// bytes are also kept so they can later be saved with the credentials.
///
/// # Errors
///
/// Returns an error if the certificate cannot be loaded (or, on macOS, re-read), or if the
/// HTTP client cannot be built.
pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
    // Drop one trailing slash so endpoint constants can be appended verbatim.
    let mut url = url;
    if url.ends_with('/') {
        url.pop();
    }

    // Keep the path next to the parsed certificate: the macOS branch reads it again.
    let cert = match cert_path {
        Some(path) => Some((crate::path_load_cert(&path)?, path)),
        None => None,
    };

    let mut builder = reqwest::blocking::ClientBuilder::new()
        .gzip(true)
        .zstd(true)
        .use_rustls_tls()
        .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
    if let Some(((_, root_cert), _)) = &cert {
        builder = builder.add_root_certificate(root_cert.clone());
    }
    let client = builder.build()?;

    #[cfg(target_os = "macos")]
    let cert = match &cert {
        Some(((cert_type, _), cert_path)) => Some(crate::macos::CertificateData {
            cert_type: *cert_type,
            cert_bytes: std::fs::read(cert_path)?,
        }),
        None => None,
    };

    Ok(Self {
        url,
        api_key,
        client,
        #[cfg(target_os = "macos")]
        cert,
    })
}
/// Log in to the server at `url` with a username and password, returning a client that
/// holds the API key from the server's reply.
///
/// A single trailing `/` is removed from `url`, and `cert_path` optionally names an extra
/// root certificate to trust. After logging in, the server version is compared with this
/// client's and a warning is logged if the server is newer. That comparison is advisory
/// only: if the version cannot be fetched, a warning is logged and the login still
/// succeeds. When `save` is `true`, the client is persisted with [`MdbClient::save`].
///
/// # Errors
///
/// Returns an error if the certificate cannot be loaded, the HTTP client cannot be built,
/// the login request fails or cannot be decoded, the server rejects the credentials, or
/// (with `save`) the configuration cannot be saved.
pub fn login(
    url: String,
    username: String,
    password: String,
    save: bool,
    cert_path: Option<PathBuf>,
) -> Result<Self> {
    // Drop one trailing slash so endpoint constants can be appended verbatim.
    let mut url = url;
    if url.ends_with('/') {
        url.pop();
    }

    let api_request = malwaredb_api::GetAPIKeyRequest {
        user: username,
        password,
    };

    let builder = reqwest::blocking::ClientBuilder::new()
        .gzip(true)
        .zstd(true)
        .use_rustls_tls()
        .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
    // Keep the path next to the parsed certificate: the macOS branch reads it again.
    let cert = if let Some(path) = cert_path {
        Some((crate::path_load_cert(&path)?, path))
    } else {
        None
    };
    let client = if let Some(((_cert_type, cert), _path)) = &cert {
        builder.add_root_certificate(cert.clone()).build()
    } else {
        builder.build()
    }?;

    let res = client
        .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
        .json(&api_request)
        .send()?
        .json::<ServerResponse<GetAPIKeyResponse>>()
        .context(MDB_CLIENT_ERROR_CONTEXT)?;
    let res = match res {
        ServerResponse::Success(res) => res,
        ServerResponse::Error(err) => return Err(err.into()),
    };

    #[cfg(target_os = "macos")]
    let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
        Some(crate::macos::CertificateData {
            cert_type: *cert_type,
            cert_bytes: std::fs::read(cert_path)?,
        })
    } else {
        None
    };

    let client = MdbClient {
        url,
        api_key: res.key.clone(),
        client,
        #[cfg(target_os = "macos")]
        cert,
    };

    // The login has already succeeded and returned a key at this point. Failing here would
    // discard that key (it would be neither returned nor saved), so the version comparison
    // is advisory only.
    match client.server_info() {
        Ok(server_info) => {
            if server_info.mdb_version > *crate::MDB_VERSION_SEMVER {
                warn!(
                    "Server version {:?} is newer than client {:?}, consider updating.",
                    server_info.mdb_version,
                    crate::MDB_VERSION_SEMVER
                );
            }
        }
        Err(e) => warn!("Logged in, but could not check the server version: {e}"),
    }

    if save {
        if let Err(e) = client.save() {
            error!("Login successful but failed to save config: {e}");
            bail!("Login successful but failed to save config: {e}");
        }
    }
    Ok(client)
}
pub fn reset_key(&self) -> Result<()> {
let response = self
.client
.get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()
.context(MDB_CLIENT_ERROR_CONTEXT)?;
if !response.status().is_success() {
bail!("failed to reset API key, was it correct?");
}
Ok(())
}
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
let name = path.as_ref().display();
let config =
std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
let cfg: MdbClient =
toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
Ok(cfg)
}
pub fn load() -> Result<Self> {
#[cfg(target_os = "macos")]
{
if let Ok((api_key, url, cert)) = crate::macos::retrieve_credentials() {
let builder = reqwest::blocking::ClientBuilder::new()
.gzip(true)
.zstd(true)
.use_rustls_tls()
.user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
let client = if let Some(cert) = &cert {
builder.add_root_certificate(cert.as_cert()?).build()
} else {
builder.build()
}?;
return Ok(Self {
url,
api_key,
client,
cert,
});
}
}
let path = get_config_path(false)?;
if path.exists() {
return Self::from_file(path);
}
bail!("config file not found")
}
pub fn save(&self) -> Result<()> {
#[cfg(target_os = "macos")]
{
if crate::macos::save_credentials(&self.url, &self.api_key, self.cert.clone()).is_ok() {
return Ok(());
}
}
let toml = toml::to_string(self)?;
let path = get_config_path(true)?;
std::fs::write(&path, toml)
.context(format!("failed to write mdb config to {}", path.display()))
}
pub fn delete(&self) -> Result<()> {
#[cfg(target_os = "macos")]
crate::macos::clear_credentials();
let path = get_config_path(false)?;
if path.exists() {
std::fs::remove_file(&path).context(format!(
"failed to delete client config file {}",
path.display()
))?;
}
Ok(())
}
pub fn server_info(&self) -> Result<ServerInfo> {
let response = self
.client
.get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
.send()?
.json::<ServerResponse<ServerInfo>>()
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(info) => Ok(info),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub fn supported_types(&self) -> Result<SupportedFileTypes> {
let response = self
.client
.get(format!(
"{}{}",
self.url,
malwaredb_api::SUPPORTED_FILE_TYPES_URL
))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()?
.json::<ServerResponse<SupportedFileTypes>>()
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(types) => Ok(types),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub fn whoami(&self) -> Result<GetUserInfoResponse> {
let response = self
.client
.get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()?
.json::<ServerResponse<GetUserInfoResponse>>()
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(info) => Ok(info),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub fn labels(&self) -> Result<Labels> {
let response = self
.client
.get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()?
.json::<ServerResponse<Labels>>()
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(labels) => Ok(labels),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub fn sources(&self) -> Result<Sources> {
let response = self
.client
.get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()?
.json::<ServerResponse<Sources>>()
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(sources) => Ok(sources),
ServerResponse::Error(e) => Err(e.into()),
}
}
/// Upload a sample as JSON, with its contents base64-encoded.
///
/// The SHA-256 of `contents` is computed locally and sent with the data. Returns
/// `Ok(true)` when the server answers with a success status, and `Ok(false)` (after
/// logging the status code) otherwise.
///
/// # Errors
///
/// If the request cannot be sent, the failure is logged and the underlying `reqwest`
/// error is returned. It is kept as the error object itself, not its text, so callers
/// can downcast it and see its source chain.
pub fn submit(
    &self,
    contents: impl AsRef<[u8]>,
    file_name: String,
    source_id: u32,
) -> Result<bool> {
    let contents = contents.as_ref();
    let payload = malwaredb_api::NewSampleB64 {
        file_name,
        source_id,
        file_contents_b64: general_purpose::STANDARD.encode(contents),
        sha256: hex::encode(Sha256::digest(contents)),
    };
    let sent = self
        .client
        .post(format!(
            "{}{}",
            self.url,
            malwaredb_api::UPLOAD_SAMPLE_JSON_URL
        ))
        .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
        .json(&payload)
        .send();
    match sent {
        Ok(res) => {
            let accepted = res.status().is_success();
            if !accepted {
                info!("Code {} sending {}", res.status(), payload.file_name);
            }
            Ok(accepted)
        }
        Err(e) => {
            // Prefix a space only when there is a status, giving "Error 500 sending ..."
            // instead of "Error500 sending ...".
            let status = e
                .status()
                .map(|s| format!(" {}", s.as_str()))
                .unwrap_or_default();
            error!("Error{status} sending {}: {e}", payload.file_name);
            Err(e.into())
        }
    }
}
/// Upload a sample as a CBOR-encoded body (`content-type: application/cbor`).
///
/// The SHA-256 of `contents` is computed locally and sent with the raw bytes. Returns
/// `Ok(true)` when the server answers with a success status, and `Ok(false)` (after
/// logging the status code) otherwise.
///
/// # Errors
///
/// Returns an error if CBOR serialisation fails. If the request cannot be sent, the
/// failure is logged and the underlying `reqwest` error is returned as-is, so callers can
/// downcast it and see its source chain.
pub fn submit_as_cbor(
    &self,
    contents: impl AsRef<[u8]>,
    file_name: String,
    source_id: u32,
) -> Result<bool> {
    let contents = contents.as_ref();
    let payload = malwaredb_api::NewSampleBytes {
        file_name,
        source_id,
        file_contents: contents.to_vec(),
        sha256: hex::encode(Sha256::digest(contents)),
    };
    // The encoded body is at least as large as the sample, so reserve that up front.
    let mut body = Vec::with_capacity(payload.file_contents.len());
    ciborium::ser::into_writer(&payload, &mut body)?;
    let sent = self
        .client
        .post(format!(
            "{}{}",
            self.url,
            malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
        ))
        .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
        .header("content-type", "application/cbor")
        .body(body)
        .send();
    match sent {
        Ok(res) => {
            let accepted = res.status().is_success();
            if !accepted {
                info!("Code {} sending {}", res.status(), payload.file_name);
            }
            Ok(accepted)
        }
        Err(e) => {
            // Prefix a space only when there is a status, giving "Error 500 sending ..."
            // instead of "Error500 sending ...".
            let status = e
                .status()
                .map(|s| format!(" {}", s.as_str()))
                .unwrap_or_default();
            error!("Error{status} sending {}: {e}", payload.file_name);
            Err(e.into())
        }
    }
}
/// Search by partial hash and/or file name, without label, file-type or magic filters.
///
/// `response` selects which hash type the results are reported as, and `limit` caps the
/// number of results per page. This is `partial_search_labels_type` with every optional
/// filter set to `None`.
///
/// # Errors
///
/// Returns an error if the query is invalid (see the search validation), the request
/// fails, or the server answers with an error.
pub fn partial_search(
    &self,
    partial_hash: Option<(PartialHashSearchType, String)>,
    name: Option<String>,
    response: PartialHashSearchType,
    limit: u32,
) -> Result<SearchResponse> {
    self.partial_search_labels_type(partial_hash, name, response, None, None, None, limit)
}
/// Search by partial hash and/or file name, optionally filtered by labels, file type and
/// a `magic` string.
///
/// `response` selects which hash type the results are reported as, and `limit` caps the
/// number of results per page.
///
/// # Errors
///
/// Returns an error if the query is invalid, the request fails, or the server answers
/// with an error.
// Every search filter is its own argument, which exceeds clippy's argument-count limit.
#[allow(clippy::too_many_arguments)]
pub fn partial_search_labels_type(
    &self,
    partial_hash: Option<(PartialHashSearchType, String)>,
    name: Option<String>,
    response: PartialHashSearchType,
    labels: Option<Vec<String>>,
    file_type: Option<String>,
    magic: Option<String>,
    limit: u32,
) -> Result<SearchResponse> {
    let parameters = SearchRequestParameters {
        partial_hash,
        file_name: name,
        response,
        labels,
        file_type,
        magic,
        limit,
    };
    let query = SearchRequest {
        search: SearchType::Search(parameters),
    };
    self.do_search_request(&query)
}
/// Request the next page of results for an earlier search `response`.
///
/// # Errors
///
/// Returns "Pagination not available" when `response` carries no pagination token, and
/// otherwise any error from the search request.
pub fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
    match response.pagination {
        Some(uuid) => self.do_search_request(&SearchRequest {
            search: SearchType::Continuation(uuid),
        }),
        None => bail!("Pagination not available"),
    }
}
/// Send a search (or continuation) request and decode the server's reply.
///
/// # Errors
///
/// Returns an error without contacting the server if `query.is_valid()` is false.
/// Otherwise returns an error if the request fails, the reply cannot be decoded, or the
/// server answers with an error.
fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
    ensure!(
        query.is_valid(),
        "Query isn't valid: hash isn't hexadecimal or both the hashes and file name are empty"
    );
    let response = self
        .client
        .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
        .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
        .json(query)
        .send()?
        .json::<ServerResponse<SearchResponse>>()
        .context(MDB_CLIENT_ERROR_CONTEXT)?;
    match response {
        ServerResponse::Success(search) => Ok(search),
        ServerResponse::Error(e) => Err(e.into()),
    }
}
/// Download the sample identified by `hash`.
///
/// With `cart` set, the CaRT download endpoint is used instead of the raw one. If the
/// server sends a `content-digest` header, the body is checked against it and the outcome
/// is logged; a missing header is logged as a warning.
///
/// NOTE(review): a digest mismatch is only logged at `error` level and the bytes are still
/// returned as `Ok`. Callers needing an integrity guarantee must verify the data
/// themselves.
///
/// # Errors
///
/// Returns an error if the request fails, the status is not a success, the body cannot be
/// read, or the `content-digest` header cannot be parsed.
pub fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
    let endpoint = if cart {
        format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
    } else {
        format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
    };
    let res = self
        .client
        .get(format!("{}{endpoint}", self.url))
        .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
        .send()?;

    let status = res.status();
    if !status.is_success() {
        bail!("Received code {status}");
    }

    // Copy the header out first: `bytes()` consumes the response.
    let content_digest = res.headers().get("content-digest").cloned();
    let bytes = res.bytes()?.to_vec();

    match content_digest {
        Some(digest) => {
            let hash = HashType::from_content_digest_header(digest.to_str()?)?;
            if hash.verify(&bytes) {
                trace!("Hash verified for sample {hash}");
            } else {
                error!("Hash mismatch for sample {hash}");
            }
        }
        None => warn!("No content digest header received for sample {hash}"),
    }
    Ok(bytes)
}
pub fn report(&self, hash: &str) -> Result<Report> {
let response = self
.client
.get(format!(
"{}{}/{hash}",
self.url,
malwaredb_api::SAMPLE_REPORT_URL
))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()?
.json::<ServerResponse<Report>>()
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(report) => Ok(report),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
let mut hashes = vec![];
let ssdeep_hash = FuzzyHash::new(contents);
let build_hasher = Murmur3HashState::default();
let lzjd_str =
LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
hashes.push((
malwaredb_api::SimilarityHashType::SSDeep,
ssdeep_hash.to_string(),
));
let mut builder = TlshBuilder::new(
tlsh_fixed::BucketKind::Bucket256,
tlsh_fixed::ChecksumKind::ThreeByte,
tlsh_fixed::Version::Version4,
);
builder.update(contents);
if let Ok(hasher) = builder.build() {
hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
}
if let Ok(exe) = EXE::from(contents) {
if let Some(imports) = exe.imports {
hashes.push((
malwaredb_api::SimilarityHashType::ImportHash,
hex::encode(imports.hash()),
));
hashes.push((
malwaredb_api::SimilarityHashType::FuzzyImportHash,
imports.fuzzy_hash(),
));
}
}
let request = malwaredb_api::SimilarSamplesRequest { hashes };
let response = self
.client
.post(format!(
"{}{}",
self.url,
malwaredb_api::SIMILAR_SAMPLES_URL
))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.json(&request)
.send()?
.json::<ServerResponse<SimilarSamplesResponse>>()
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(similar) => Ok(similar),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub fn yara_search(&self, yara: &str) -> Result<YaraSearchRequestResponse> {
let yara = YaraSearchRequest {
rules: vec![yara.to_string()],
response: PartialHashSearchType::SHA256,
};
let response = self
.client
.post(format!("{}{}", self.url, malwaredb_api::YARA_SEARCH_URL))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.json(&yara)
.send()?
.json::<ServerResponse<YaraSearchRequestResponse>>()?;
match response {
ServerResponse::Success(similar) => Ok(similar),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub fn yara_result(&self, uuid: Uuid) -> Result<YaraSearchResponse> {
let response = self
.client
.get(format!(
"{}{}/{uuid}",
self.url,
malwaredb_api::YARA_SEARCH_URL
))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()?
.json::<ServerResponse<YaraSearchResponse>>()?;
match response {
ServerResponse::Success(sources) => Ok(sources),
ServerResponse::Error(e) => Err(e.into()),
}
}
}
impl Debug for MdbClient {
    /// Formats as `MDB Client v<version>: <url>`; the API key is not included.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        use crate::MDB_VERSION;
        // `write!`, not `writeln!`: Debug output should not end in a newline, or it breaks
        // `{:?}`/`{:#?}` when nested in other values and doubles up under `println!`.
        write!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
    }
}
/// Iterator over the hashes returned by a search, fetching further result pages on demand.
///
/// Hashes are popped from the back of the current page, so within a page they come out in
/// reverse order. When the page is empty and it carries a pagination token, the next page
/// is requested. If that request fails, the error is logged and `None` is returned.
#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
pub struct IterableHashSearchResult<'a> {
    /// The page of results currently being drained.
    pub response: SearchResponse,
    client: &'a MdbClient,
}

impl<'a> IterableHashSearchResult<'a> {
    /// Wrap an initial search `response` so its hashes, and those of later pages, can be
    /// iterated.
    #[must_use]
    pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
        IterableHashSearchResult { response, client }
    }
}

impl Iterator for IterableHashSearchResult<'_> {
    type Item = String;

    fn next(&mut self) -> Option<Self::Item> {
        if let Some(hash) = self.response.hashes.pop() {
            return Some(hash);
        }

        // Current page exhausted: follow the pagination token, if there is one.
        let uuid = self.response.pagination?;
        let continuation = SearchRequest {
            search: SearchType::Continuation(uuid),
        };
        match self.client.do_search_request(&continuation) {
            Ok(page) => {
                self.response = page;
                self.response.hashes.pop()
            }
            Err(e) => {
                warn!("Failed to continue search: {e}");
                None
            }
        }
    }
}
/// Iterator that downloads, one per step, the samples matched by a search, fetching
/// further result pages on demand.
///
/// Hashes are popped from the back of the current page and each sample is fetched with
/// [`MdbClient::retrieve`] (non-CaRT). `None` is returned in three cases:
/// - the results are exhausted;
/// - fetching the next page fails (logged as a warning);
/// - a download fails (logged as an error).
#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
pub struct IterableSampleSearchResult<'a> {
    /// The page of results currently being drained.
    pub response: SearchResponse,
    client: &'a MdbClient,
}

impl<'a> IterableSampleSearchResult<'a> {
    /// Wrap an initial search `response` so its samples can be downloaded lazily.
    #[must_use]
    pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
        Self { response, client }
    }

    /// Next hash to download: taken from the current page, or else from the next page if
    /// the current one carries a pagination token. Logs and returns `None` if fetching that
    /// page fails.
    fn next_hash(&mut self) -> Option<String> {
        if let Some(hash) = self.response.hashes.pop() {
            return Some(hash);
        }
        let uuid = self.response.pagination?;
        let continuation = SearchRequest {
            search: SearchType::Continuation(uuid),
        };
        match self.client.do_search_request(&continuation) {
            Ok(page) => {
                self.response = page;
                self.response.hashes.pop()
            }
            Err(e) => {
                warn!("Failed to continue search: {e}");
                None
            }
        }
    }
}

impl Iterator for IterableSampleSearchResult<'_> {
    type Item = Vec<u8>;

    fn next(&mut self) -> Option<Self::Item> {
        let hash = self.next_hash()?;
        match self.client.retrieve(&hash, false) {
            Ok(binary) => Some(binary),
            Err(e) => {
                error!("Failed to download {hash}: {e}");
                None
            }
        }
    }
}