#![doc = include_str!("../README.md")]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(missing_docs)]
#![deny(clippy::all)]
#![deny(clippy::pedantic)]
#![forbid(unsafe_code)]
#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
#[cfg(feature = "blocking")]
pub mod blocking;
pub use malwaredb_api;
use malwaredb_api::{
digest::HashType, GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType,
Report, SearchRequest, SearchRequestParameters, SearchResponse, SearchType, ServerInfo,
ServerResponse, SimilarSamplesResponse, Sources, SupportedFileTypes, YaraSearchRequest,
YaraSearchRequestResponse, YaraSearchResponse,
};
use malwaredb_lzjd::{LZDict, Murmur3HashState};
use malwaredb_types::exec::pe32::EXE;
use malwaredb_types::utils::entropy_calc;
use std::collections::HashSet;
use std::fmt::{Debug, Display, Formatter};
use std::io::Cursor;
use std::path::{Path, PathBuf};
use std::sync::LazyLock;
use anyhow::{bail, ensure, Context, Result};
use base64::engine::general_purpose;
use base64::Engine;
use cart_container::JsonMap;
use fuzzyhash::FuzzyHash;
use home::home_dir;
use mdns_sd::{ServiceDaemon, ServiceEvent};
use reqwest::Certificate;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256, Sha384, Sha512};
use tlsh_fixed::TlshBuilder;
use tracing::{debug, error, info, trace, warn};
use uuid::Uuid;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Name of the directory (inside the user's config root) that holds the client configuration.
const MDB_CLIENT_DIR: &str = "malwaredb_client";

/// Context message attached to errors from sending a request or decoding a server response.
pub(crate) const MDB_CLIENT_ERROR_CONTEXT: &str =
    "Network error connecting to MalwareDB, or failure to decode server response.";

/// File name of the client configuration file.
const MDB_CLIENT_CONFIG_TOML: &str = "mdb_client.toml";

/// Version of this client crate, as reported by Cargo at build time.
pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");

/// [`MDB_VERSION`] parsed as a semantic version, used to compare against the server's version.
pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> = LazyLock::new(|| {
    // Cargo itself rejects manifests whose version is not valid semver, so this cannot fail.
    semver::Version::parse(MDB_VERSION)
        .expect("CARGO_PKG_VERSION is always a valid semantic version")
});
#[cfg(target_os = "macos")]
pub(crate) mod macos {
    //! Storage of the client's server URL, API key and optional custom root
    //! certificate in the default macOS Keychain.

    use crate::CertificateType;
    use anyhow::Result;
    use reqwest::Certificate;
    use security_framework::os::macos::keychain::SecKeychain;
    use tracing::error;

    /// Keychain service name under which every entry is stored.
    const KEYCHAIN_ID: &str = "malwaredb-client";
    /// Account name of the server URL entry.
    const KEYCHAIN_URL: &str = "URL";
    /// Account name of the API key entry.
    const KEYCHAIN_API_KEY: &str = "API_KEY";
    /// Account name of a PEM-encoded root certificate entry.
    const KEYCHAIN_CERTIFICATE_PEM: &str = "CERT_PEM";
    /// Account name of a DER-encoded root certificate entry.
    const KEYCHAIN_CERTIFICATE_DER: &str = "CERT_DER";

    /// Raw bytes of a custom root certificate together with their encoding.
    #[derive(Clone)]
    pub(crate) struct CertificateData {
        /// Encoding of `cert_bytes`.
        pub cert_type: CertificateType,
        /// Contents of the certificate file.
        pub cert_bytes: Vec<u8>,
    }

    impl CertificateData {
        /// Parse the stored bytes into a `reqwest` certificate, using the parser that
        /// matches `cert_type`.
        pub(crate) fn as_cert(&self) -> Result<Certificate> {
            Ok(match self.cert_type {
                CertificateType::PEM => Certificate::from_pem(&self.cert_bytes)?,
                CertificateType::DER => Certificate::from_der(&self.cert_bytes)?,
            })
        }
    }

    /// Keychain account name used to store a certificate of the given encoding.
    fn certificate_account(cert_type: CertificateType) -> &'static str {
        match cert_type {
            CertificateType::PEM => KEYCHAIN_CERTIFICATE_PEM,
            CertificateType::DER => KEYCHAIN_CERTIFICATE_DER,
        }
    }

    /// Delete the entry stored under `account` if one exists; missing entries are ignored.
    fn delete_entry(keychain: &SecKeychain, account: &str) {
        if let Ok((_, item)) = keychain.find_generic_password(KEYCHAIN_ID, account) {
            item.delete();
        }
    }

    /// Store the server URL, API key and optional root certificate in the Keychain.
    ///
    /// Existing entries are overwritten. Any previously saved certificate is removed
    /// before the new one (if any) is written.
    pub fn save_credentials(url: &str, key: &str, cert: Option<CertificateData>) -> Result<()> {
        let keychain = SecKeychain::default()?;
        // `set_generic_password` updates an existing item instead of failing with a
        // duplicate-item error the way `add_generic_password` does, so saving again
        // (e.g. after a second login) keeps working.
        keychain.set_generic_password(KEYCHAIN_ID, KEYCHAIN_URL, url.as_bytes())?;
        keychain.set_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY, key.as_bytes())?;

        // `retrieve_credentials` prefers PEM over DER. A stale certificate of either
        // encoding would otherwise shadow, or outlive, the one being saved now.
        delete_entry(&keychain, KEYCHAIN_CERTIFICATE_PEM);
        delete_entry(&keychain, KEYCHAIN_CERTIFICATE_DER);
        if let Some(cert) = cert {
            keychain.set_generic_password(
                KEYCHAIN_ID,
                certificate_account(cert.cert_type),
                &cert.cert_bytes,
            )?;
        }
        Ok(())
    }

    /// Load saved credentials as `(api_key, url, certificate)`.
    ///
    /// Fails if the API key or URL entry is missing or not valid UTF-8. The
    /// certificate is optional; a PEM entry takes precedence over a DER entry.
    pub fn retrieve_credentials() -> Result<(String, String, Option<CertificateData>)> {
        let keychain = SecKeychain::default()?;
        let (api_key, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY)?;
        let api_key = String::from_utf8(api_key.as_ref().to_vec())?;
        let (url, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_URL)?;
        let url = String::from_utf8(url.as_ref().to_vec())?;

        let cert = [CertificateType::PEM, CertificateType::DER]
            .into_iter()
            .find_map(|cert_type| {
                keychain
                    .find_generic_password(KEYCHAIN_ID, certificate_account(cert_type))
                    .ok()
                    .map(|(bytes, _item)| CertificateData {
                        cert_type,
                        cert_bytes: bytes.to_vec(),
                    })
            });
        Ok((api_key, url, cert))
    }

    /// Remove every entry this module may have stored.
    ///
    /// This is best-effort: failures are logged, not returned.
    pub fn clear_credentials() {
        let Ok(keychain) = SecKeychain::default() else {
            error!("Failed to get access to the Keychain to clear credentials");
            return;
        };
        for account in [
            KEYCHAIN_API_KEY,
            KEYCHAIN_URL,
            KEYCHAIN_CERTIFICATE_PEM,
            KEYCHAIN_CERTIFICATE_DER,
        ] {
            delete_entry(&keychain, account);
        }
    }
}
/// Encoding of a custom root certificate, chosen from the certificate file's extension.
///
/// It is kept alongside the raw bytes so the certificate can later be re-parsed with
/// the matching `reqwest::Certificate` constructor.
// The variant names deliberately mirror the `.der` / `.pem` extensions as acronyms.
#[derive(Copy, Clone, PartialEq, Eq)]
#[allow(clippy::upper_case_acronyms)]
enum CertificateType {
    /// DER encoding (`.der` files).
    DER,
    /// PEM encoding (`.pem` files).
    PEM,
}
/// Client for a `MalwareDB` server: the server URL, the user's API key and the HTTP
/// client used to send requests.
///
/// The struct is read from and written to the TOML configuration file via serde. Fields
/// marked `#[serde(skip)]` are not part of that file. `url` and `api_key` are zeroized
/// when the client is dropped (`ZeroizeOnDrop`).
#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
pub struct MdbClient {
/// Base URL of the server; endpoint paths are appended to it directly.
pub url: String,
// API key sent in the `MDB_API_HEADER` header of authenticated requests.
api_key: String,
// HTTP client; not part of the serialized configuration.
#[zeroize(skip)]
#[serde(skip)]
client: reqwest::Client,
// Custom root certificate bytes, kept so they can be saved to the macOS Keychain.
#[cfg(target_os = "macos")]
#[zeroize(skip)]
#[serde(skip)]
cert: Option<macos::CertificateData>,
}
impl MdbClient {
pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
let mut url = url;
let url = if url.ends_with('/') {
url.pop();
url
} else {
url
};
let cert = if let Some(path) = cert_path {
Some((path_load_cert(&path)?, path))
} else {
None
};
let builder = reqwest::ClientBuilder::new()
.gzip(true)
.zstd(true)
.use_rustls_tls()
.user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
builder.add_root_certificate(cert.clone()).build()
} else {
builder.build()
}?;
#[cfg(target_os = "macos")]
let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
Some(macos::CertificateData {
cert_type: *cert_type,
cert_bytes: std::fs::read(cert_path)?,
})
} else {
None
};
Ok(Self {
url,
api_key,
client,
#[cfg(target_os = "macos")]
cert,
})
}
pub async fn login(
url: String,
username: String,
password: String,
save: bool,
cert_path: Option<PathBuf>,
) -> Result<Self> {
let mut url = url;
let url = if url.ends_with('/') {
url.pop();
url
} else {
url
};
let api_request = malwaredb_api::GetAPIKeyRequest {
user: username,
password,
};
let builder = reqwest::ClientBuilder::new()
.gzip(true)
.zstd(true)
.use_rustls_tls()
.user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
let cert = if let Some(path) = cert_path {
Some((path_load_cert(&path)?, path))
} else {
None
};
let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
builder.add_root_certificate(cert.clone()).build()
} else {
builder.build()
}?;
let res = client
.post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
.json(&api_request)
.send()
.await?
.json::<ServerResponse<GetAPIKeyResponse>>()
.await
.context(MDB_CLIENT_ERROR_CONTEXT)?;
let res = match res {
ServerResponse::Success(res) => res,
ServerResponse::Error(err) => return Err(err.into()),
};
#[cfg(target_os = "macos")]
let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
Some(macos::CertificateData {
cert_type: *cert_type,
cert_bytes: std::fs::read(cert_path)?,
})
} else {
None
};
let client = MdbClient {
url,
api_key: res.key.clone(),
client,
#[cfg(target_os = "macos")]
cert,
};
let server_info = client.server_info().await?;
if server_info.mdb_version > *MDB_VERSION_SEMVER {
warn!(
"Server version {:?} is newer than client {:?}, consider updating.",
server_info.mdb_version, MDB_VERSION_SEMVER
);
}
if save {
if let Err(e) = client.save() {
error!("Login successful but failed to save config: {e}");
bail!("Login successful but failed to save config: {e}");
}
}
Ok(client)
}
pub async fn reset_key(&self) -> Result<()> {
let response = self
.client
.get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()
.await
.context(MDB_CLIENT_ERROR_CONTEXT)?;
if !response.status().is_success() {
bail!("failed to reset API key, was it correct?");
}
Ok(())
}
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
let name = path.as_ref().display();
let config =
std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
let cfg: MdbClient =
toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
Ok(cfg)
}
pub fn load() -> Result<Self> {
#[cfg(target_os = "macos")]
{
if let Ok((api_key, url, cert)) = macos::retrieve_credentials() {
let builder = reqwest::ClientBuilder::new()
.gzip(true)
.zstd(true)
.use_rustls_tls()
.user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
let client = if let Some(cert) = &cert {
builder.add_root_certificate(cert.as_cert()?).build()
} else {
builder.build()
}?;
return Ok(Self {
url,
api_key,
client,
cert,
});
}
}
let path = get_config_path(false)?;
if path.exists() {
return Self::from_file(path);
}
bail!("config file not found")
}
pub fn save(&self) -> Result<()> {
#[cfg(target_os = "macos")]
{
if macos::save_credentials(&self.url, &self.api_key, self.cert.clone()).is_ok() {
return Ok(());
}
}
let toml = toml::to_string(self)?;
let path = get_config_path(true)?;
std::fs::write(&path, toml)
.context(format!("failed to write mdb config to {}", path.display()))
}
pub fn delete(&self) -> Result<()> {
#[cfg(target_os = "macos")]
macos::clear_credentials();
let path = get_config_path(false)?;
if path.exists() {
std::fs::remove_file(&path).context(format!(
"failed to delete client config file {}",
path.display()
))?;
}
Ok(())
}
pub async fn server_info(&self) -> Result<ServerInfo> {
let response = self
.client
.get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
.send()
.await?
.json::<ServerResponse<ServerInfo>>()
.await
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(info) => Ok(info),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub async fn supported_types(&self) -> Result<SupportedFileTypes> {
let response = self
.client
.get(format!(
"{}{}",
self.url,
malwaredb_api::SUPPORTED_FILE_TYPES_URL
))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()
.await?
.json::<ServerResponse<SupportedFileTypes>>()
.await
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(types) => Ok(types),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub async fn whoami(&self) -> Result<GetUserInfoResponse> {
let response = self
.client
.get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()
.await?
.json::<ServerResponse<GetUserInfoResponse>>()
.await
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(info) => Ok(info),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub async fn labels(&self) -> Result<Labels> {
let response = self
.client
.get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()
.await?
.json::<ServerResponse<Labels>>()
.await
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(labels) => Ok(labels),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub async fn sources(&self) -> Result<Sources> {
let response = self
.client
.get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()
.await?
.json::<ServerResponse<Sources>>()
.await
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(sources) => Ok(sources),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub async fn submit(
&self,
contents: impl AsRef<[u8]>,
file_name: impl AsRef<str>,
source_id: u32,
) -> Result<bool> {
let mut hasher = Sha256::new();
hasher.update(&contents);
let result = hasher.finalize();
let encoded = general_purpose::STANDARD.encode(contents);
let payload = malwaredb_api::NewSampleB64 {
file_name: file_name.as_ref().to_string(),
source_id,
file_contents_b64: encoded,
sha256: hex::encode(result),
};
match self
.client
.post(format!(
"{}{}",
self.url,
malwaredb_api::UPLOAD_SAMPLE_JSON_URL
))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.json(&payload)
.send()
.await
{
Ok(res) => {
if !res.status().is_success() {
info!("Code {} sending {}", res.status(), payload.file_name);
}
Ok(res.status().is_success())
}
Err(e) => {
let status: String = e
.status()
.map(|s| s.as_str().to_string())
.unwrap_or_default();
error!("Error{status} sending {}: {e}", payload.file_name);
bail!(e.to_string())
}
}
}
pub async fn submit_as_cbor(
&self,
contents: impl AsRef<[u8]>,
file_name: impl AsRef<str>,
source_id: u32,
) -> Result<bool> {
let mut hasher = Sha256::new();
hasher.update(&contents);
let result = hasher.finalize();
let payload = malwaredb_api::NewSampleBytes {
file_name: file_name.as_ref().to_string(),
source_id,
file_contents: contents.as_ref().to_vec(),
sha256: hex::encode(result),
};
let mut bytes = Vec::with_capacity(payload.file_contents.len());
ciborium::ser::into_writer(&payload, &mut bytes)?;
match self
.client
.post(format!(
"{}{}",
self.url,
malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.header("content-type", "application/cbor")
.body(bytes)
.send()
.await
{
Ok(res) => {
if !res.status().is_success() {
info!("Code {} sending {}", res.status(), payload.file_name);
}
Ok(res.status().is_success())
}
Err(e) => {
let status: String = e
.status()
.map(|s| s.as_str().to_string())
.unwrap_or_default();
error!("Error{status} sending {}: {e}", payload.file_name);
bail!(e.to_string())
}
}
}
pub async fn partial_search(
&self,
partial_hash: Option<(PartialHashSearchType, String)>,
name: Option<String>,
response: PartialHashSearchType,
limit: u32,
) -> Result<SearchResponse> {
let query = SearchRequest {
search: SearchType::Search(SearchRequestParameters {
partial_hash,
file_name: name,
response,
limit,
labels: None,
file_type: None,
magic: None,
}),
};
self.do_search_request(&query).await
}
#[allow(clippy::too_many_arguments)]
pub async fn partial_search_labels_type(
&self,
partial_hash: Option<(PartialHashSearchType, String)>,
name: Option<String>,
response: PartialHashSearchType,
labels: Option<Vec<String>>,
file_type: Option<String>,
magic: Option<String>,
limit: u32,
) -> Result<SearchResponse> {
let query = SearchRequest {
search: SearchType::Search(SearchRequestParameters {
partial_hash,
file_name: name,
response,
limit,
file_type,
magic,
labels,
}),
};
self.do_search_request(&query).await
}
pub async fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
if let Some(uuid) = response.pagination {
let request = SearchRequest {
search: SearchType::Continuation(uuid),
};
return self.do_search_request(&request).await;
}
bail!("Pagination not available")
}
async fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
ensure!(
query.is_valid(),
"Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
);
let response = self
.client
.post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.json(query)
.send()
.await?
.json::<ServerResponse<SearchResponse>>()
.await
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(search) => Ok(search),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub async fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
let api_endpoint = if cart {
format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
} else {
format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
};
let res = self
.client
.get(format!("{}{api_endpoint}", self.url))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()
.await?;
if !res.status().is_success() {
bail!("Received code {}", res.status());
}
let content_digest = res.headers().get("content-digest").map(ToOwned::to_owned);
let body = res.bytes().await?;
let bytes = body.to_vec();
if let Some(digest) = content_digest {
let hash = HashType::from_content_digest_header(digest.to_str()?)?;
if hash.verify(&bytes) {
trace!("Hash verified for sample {hash}");
} else {
error!("Hash mismatch for sample {hash}");
}
} else {
warn!("No content digest header received for sample {hash}");
}
Ok(bytes)
}
pub async fn report(&self, hash: &str) -> Result<Report> {
let response = self
.client
.get(format!(
"{}{}/{hash}",
self.url,
malwaredb_api::SAMPLE_REPORT_URL
))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()
.await?
.json::<ServerResponse<Report>>()
.await
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(report) => Ok(report),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub async fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
let mut hashes = vec![];
let ssdeep_hash = FuzzyHash::new(contents);
let build_hasher = Murmur3HashState::default();
let lzjd_str =
LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
hashes.push((
malwaredb_api::SimilarityHashType::SSDeep,
ssdeep_hash.to_string(),
));
let mut builder = TlshBuilder::new(
tlsh_fixed::BucketKind::Bucket256,
tlsh_fixed::ChecksumKind::ThreeByte,
tlsh_fixed::Version::Version4,
);
builder.update(contents);
if let Ok(hasher) = builder.build() {
hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
}
if let Ok(exe) = EXE::from(contents) {
if let Some(imports) = exe.imports {
hashes.push((
malwaredb_api::SimilarityHashType::ImportHash,
hex::encode(imports.hash()),
));
hashes.push((
malwaredb_api::SimilarityHashType::FuzzyImportHash,
imports.fuzzy_hash(),
));
}
}
let request = malwaredb_api::SimilarSamplesRequest { hashes };
let response = self
.client
.post(format!(
"{}{}",
self.url,
malwaredb_api::SIMILAR_SAMPLES_URL
))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.json(&request)
.send()
.await?
.json::<ServerResponse<SimilarSamplesResponse>>()
.await
.context(MDB_CLIENT_ERROR_CONTEXT)?;
match response {
ServerResponse::Success(similar) => Ok(similar),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub async fn yara_search(&self, yara: &str) -> Result<YaraSearchRequestResponse> {
let yara = YaraSearchRequest {
rules: vec![yara.to_string()],
response: PartialHashSearchType::SHA256,
};
let response = self
.client
.post(format!("{}{}", self.url, malwaredb_api::YARA_SEARCH_URL))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.json(&yara)
.send()
.await?
.json::<ServerResponse<YaraSearchRequestResponse>>()
.await?;
match response {
ServerResponse::Success(sources) => Ok(sources),
ServerResponse::Error(e) => Err(e.into()),
}
}
pub async fn yara_result(&self, uuid: Uuid) -> Result<YaraSearchResponse> {
let response = self
.client
.get(format!(
"{}{}/{uuid}",
self.url,
malwaredb_api::YARA_SEARCH_URL
))
.header(malwaredb_api::MDB_API_HEADER, &self.api_key)
.send()
.await?
.json::<ServerResponse<YaraSearchResponse>>()
.await?;
match response {
ServerResponse::Success(sources) => Ok(sources),
ServerResponse::Error(e) => Err(e.into()),
}
}
}
impl Debug for MdbClient {
    /// Formats the client as its version and server URL.
    ///
    /// The API key is not included, which keeps it out of log output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // `write!` rather than `writeln!`: Debug output is usually embedded inline in
        // larger messages, so it must not end with a newline.
        write!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
    }
}
/// Pack `data` into a `CaRT` container.
///
/// The container's header metadata records the hex-encoded SHA-384 and SHA-512 digests
/// and the entropy of the original bytes.
///
/// # Errors
///
/// Returns an error if the `CaRT` encoder reports one.
pub fn encode_to_cart(data: &[u8]) -> Result<Vec<u8>> {
    let mut metadata = JsonMap::new();
    metadata.insert("sha384".into(), hex::encode(Sha384::digest(data)).into());
    metadata.insert("sha512".into(), hex::encode(Sha512::digest(data)).into());
    metadata.insert("entropy".into(), entropy_calc(data).into());

    let mut reader = Cursor::new(data);
    let mut writer = Cursor::new(Vec::new());
    cart_container::pack_stream(
        &mut reader,
        &mut writer,
        Some(metadata),
        None,
        cart_container::digesters::default_digesters(),
        None,
    )?;
    Ok(writer.into_inner())
}
/// Unpack a `CaRT` container.
///
/// Returns the decoded bytes together with the container's optional header and footer
/// metadata.
///
/// # Errors
///
/// Returns an error if the container can't be unpacked.
pub fn decode_from_cart(data: &[u8]) -> Result<(Vec<u8>, Option<JsonMap>, Option<JsonMap>)> {
    let mut decoded = Cursor::new(Vec::new());
    let (header, footer) =
        cart_container::unpack_stream(&mut Cursor::new(data), &mut decoded, None)?;
    Ok((decoded.into_inner(), header, footer))
}
/// Load a root certificate from `path`, choosing the parser from the file extension.
///
/// `.pem` and `.der` are accepted in any letter case (e.g. `cert.PEM`).
///
/// # Errors
///
/// Returns an error if the file doesn't exist, has no or a non-UTF-8 extension, has an
/// unknown extension, can't be read, or doesn't contain a valid certificate.
fn path_load_cert(path: &Path) -> Result<(CertificateType, Certificate)> {
    if !path.exists() {
        bail!("Certificate {} does not exist.", path.display());
    }

    let extension = path
        .extension()
        .context("can't determine file extension")?
        .to_str()
        .context("unable to parse file extension")?;
    let cert_type = match extension.to_ascii_lowercase().as_str() {
        "pem" => CertificateType::PEM,
        "der" => CertificateType::DER,
        _ => bail!("Unknown extension {extension:?}"),
    };

    let contents = std::fs::read(path)?;
    let cert = match cert_type {
        CertificateType::PEM => Certificate::from_pem(&contents)?,
        CertificateType::DER => Certificate::from_der(&contents)?,
    };
    Ok((cert_type, cert))
}
/// Determine where the client configuration file lives.
///
/// Lookup order:
/// 1. `mdb_client.toml` in the current directory, if it exists;
/// 2. on Haiku, `/boot/home/config/settings/malwaredb/`;
/// 3. on other Unix systems, `$XDG_CONFIG_HOME/malwaredb_client/`, when that variable
///    holds an absolute path;
/// 4. `~/.config/malwaredb_client/`;
/// 5. `mdb_client.toml` relative to the current directory.
///
/// When `create` is true, the chosen directory is created if missing.
///
/// # Errors
///
/// Returns an error if the directory has to be created and creation fails.
#[inline]
pub(crate) fn get_config_path(create: bool) -> Result<PathBuf> {
    let local = PathBuf::from(MDB_CLIENT_CONFIG_TOML);
    if local.exists() {
        return Ok(local);
    }

    #[cfg(target_os = "haiku")]
    {
        config_file_in(PathBuf::from("/boot/home/config/settings/malwaredb"), create)
    }

    #[cfg(not(target_os = "haiku"))]
    {
        #[cfg(unix)]
        {
            // Per the XDG Base Directory spec, an empty or relative value is invalid and
            // must be ignored (falling back to ~/.config).
            if let Some(xdg_home) = std::env::var_os("XDG_CONFIG_HOME")
                .map(PathBuf::from)
                .filter(|dir| dir.is_absolute())
            {
                return config_file_in(xdg_home.join(MDB_CLIENT_DIR), create);
            }
        }

        if let Some(home) = home_dir() {
            return config_file_in(home.join(".config").join(MDB_CLIENT_DIR), create);
        }

        Ok(PathBuf::from(MDB_CLIENT_CONFIG_TOML))
    }
}

/// Return `dir/mdb_client.toml`, first creating `dir` if `create` is set and it is missing.
fn config_file_in(dir: PathBuf, create: bool) -> Result<PathBuf> {
    if create && !dir.exists() {
        std::fs::create_dir_all(&dir)?;
    }
    Ok(dir.join(MDB_CLIENT_CONFIG_TOML))
}
/// A `MalwareDB` server found on the local network by `discover_servers`.
///
/// Its `Display` form is the server's base URL.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MalwareDBServer {
/// Host name as advertised, with the `.local.` suffix removed.
pub host: String,
/// TCP port the server listens on.
pub port: u16,
/// Whether the server uses HTTPS, from its `ssl` TXT property; false when absent.
pub ssl: bool,
/// Service instance name: the advertised full name with the service type removed.
pub name: String,
}
impl Display for MalwareDBServer {
    /// Renders the server's base URL, e.g. `https://host:8443`, taking the scheme from `ssl`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let scheme = if self.ssl { "https" } else { "http" };
        write!(f, "{scheme}://{}:{}", self.host, self.port)
    }
}
impl MalwareDBServer {
    /// Fetch the server's public information. No API key is needed.
    ///
    /// # Errors
    ///
    /// Returns an error if the HTTP client can't be built, the request or decoding fails,
    /// or the server returns an error.
    pub async fn server_info(&self) -> Result<ServerInfo> {
        let endpoint = format!("{self}{}", malwaredb_api::SERVER_INFO_URL);
        let http = reqwest::ClientBuilder::new()
            .gzip(true)
            .zstd(true)
            .use_rustls_tls()
            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
            .build()?;

        let reply: ServerResponse<ServerInfo> = http
            .get(endpoint)
            .send()
            .await?
            .json()
            .await
            .context(MDB_CLIENT_ERROR_CONTEXT)?;

        match reply {
            ServerResponse::Error(e) => Err(e.into()),
            ServerResponse::Success(info) => Ok(info),
        }
    }

    /// Blocking variant of [`MalwareDBServer::server_info`].
    ///
    /// # Errors
    ///
    /// Returns an error if the HTTP client can't be built, the request or decoding fails,
    /// or the server returns an error.
    #[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
    #[cfg(feature = "blocking")]
    pub fn server_info_blocking(&self) -> Result<ServerInfo> {
        let endpoint = format!("{self}{}", malwaredb_api::SERVER_INFO_URL);
        let http = reqwest::blocking::ClientBuilder::new()
            .gzip(true)
            .zstd(true)
            .use_rustls_tls()
            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
            .build()?;

        let reply: ServerResponse<ServerInfo> = http
            .get(endpoint)
            .send()?
            .json()
            .context(MDB_CLIENT_ERROR_CONTEXT)?;

        match reply {
            ServerResponse::Error(e) => Err(e.into()),
            ServerResponse::Success(info) => Ok(info),
        }
    }
}
/// Discover `MalwareDB` servers advertised over `mDNS` on the local network.
///
/// This blocks while browsing. It returns after more than `MAX_ITERS` service events
/// (of any kind) have been received, or when the event channel closes. Servers are
/// de-duplicated.
///
/// # Errors
///
/// Returns an error if the `mDNS` daemon can't be started or browsing can't begin.
pub fn discover_servers() -> Result<Vec<MalwareDBServer>> {
    const MAX_ITERS: usize = 5;

    let mdns = ServiceDaemon::new()?;
    let receiver = mdns.browse(malwaredb_api::MDNS_NAME)?;
    let mut servers = HashSet::new();
    let mut counter = 0;

    while let Ok(event) = receiver.recv() {
        if let ServiceEvent::ServiceResolved(resolved) = event {
            let host = resolved.host.replace(".local.", "");
            let ssl = if let Some(ssl) = resolved.txt_properties.get("ssl") {
                ssl.val_str() == "true"
            } else {
                debug!(
                    "MalwareDB entry for {host}:{} doesn't specify ssl, assuming not",
                    resolved.port
                );
                false
            };
            servers.insert(MalwareDBServer {
                host,
                port: resolved.port,
                ssl,
                name: resolved.fullname.replace(malwaredb_api::MDNS_NAME, ""),
            });
        }

        counter += 1;
        if counter > MAX_ITERS {
            break;
        }
    }

    // Failing to shut the daemon down doesn't invalidate the servers already found,
    // so log it instead of returning an error.
    if let Err(e) = mdns.shutdown() {
        debug!("Failed to shut down the mDNS daemon: {e}");
    }

    Ok(servers.into_iter().collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decoding the bundled `CaRT` fixture must give back the original ELF and the
    /// expected header/footer metadata.
    #[test]
    fn cart() {
        const BYTES: &[u8] = include_bytes!("../../crates/types/testdata/elf/elf_haiku_x86.cart");
        const ORIGINAL_SHA256: &str =
            "de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740";

        let (decoded, header, footer) = decode_from_cart(BYTES).unwrap();
        assert_eq!(hex::encode(Sha256::digest(&decoded)), ORIGINAL_SHA256);

        let entropy = header.unwrap().get("entropy").unwrap().as_f64().unwrap();
        assert!(entropy > 4.0 && entropy < 4.1);

        let footer = footer.unwrap();
        assert_eq!(footer.get("length").unwrap(), "5093");
        assert_eq!(footer.get("sha256").unwrap(), ORIGINAL_SHA256);
    }
}