1#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8#![forbid(unsafe_code)]
9
10#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
12#[cfg(feature = "blocking")]
13pub mod blocking;
14
15pub use malwaredb_api;
16use malwaredb_api::{
17 digest::HashType, GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType,
18 Report, SearchRequest, SearchRequestParameters, SearchResponse, SearchType, ServerInfo,
19 ServerResponse, SimilarSamplesResponse, Sources, SupportedFileTypes,
20};
21use malwaredb_lzjd::{LZDict, Murmur3HashState};
22use malwaredb_types::exec::pe32::EXE;
23use malwaredb_types::utils::entropy_calc;
24
25use std::collections::HashSet;
26use std::fmt::{Debug, Display, Formatter};
27use std::io::Cursor;
28use std::path::{Path, PathBuf};
29use std::sync::LazyLock;
30
31use anyhow::{bail, ensure, Context, Result};
32use base64::engine::general_purpose;
33use base64::Engine;
34use cart_container::JsonMap;
35use fuzzyhash::FuzzyHash;
36use home::home_dir;
37use mdns_sd::{ServiceDaemon, ServiceEvent};
38use reqwest::Certificate;
39use serde::{Deserialize, Serialize};
40use sha2::{Digest, Sha256, Sha384, Sha512};
41use tlsh_fixed::TlshBuilder;
42use tracing::{debug, error, info, trace, warn};
43use zeroize::{Zeroize, ZeroizeOnDrop};
44
45const MDB_CLIENT_DIR: &str = "malwaredb_client";
47
48pub(crate) const MDB_CLIENT_ERROR_CONTEXT: &str =
50 "Network error connecting to MalwareDB, or failure to decode server response.";
51
52const MDB_CLIENT_CONFIG_TOML: &str = "mdb_client.toml";
54
55pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
57
58pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
60 LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
61
/// Storage of the client's URL, API key, and optional certificate in the macOS Keychain.
#[cfg(target_os = "macos")]
pub(crate) mod macos {
    use crate::CertificateType;

    use anyhow::Result;
    use reqwest::Certificate;
    use security_framework::os::macos::keychain::SecKeychain;
    use tracing::error;

    /// Keychain service name shared by every entry stored by this client.
    const KEYCHAIN_ID: &str = "malwaredb-client";

    /// Keychain account name of the entry holding the server URL.
    const KEYCHAIN_URL: &str = "URL";

    /// Keychain account name of the entry holding the API key.
    const KEYCHAIN_API_KEY: &str = "API_KEY";

    /// Keychain account name of the entry holding a PEM certificate.
    const KEYCHAIN_CERTIFICATE_PEM: &str = "CERT_PEM";

    /// Keychain account name of the entry holding a DER certificate.
    const KEYCHAIN_CERTIFICATE_DER: &str = "CERT_DER";

    /// Raw certificate bytes together with their encoding.
    #[derive(Clone)]
    pub(crate) struct CertificateData {
        /// Encoding of `cert_bytes`.
        pub cert_type: CertificateType,

        /// The certificate exactly as it was read from disk or from the Keychain.
        pub cert_bytes: Vec<u8>,
    }

    impl CertificateData {
        /// Parse the stored bytes into a certificate usable by `reqwest`.
        ///
        /// # Errors
        ///
        /// Returns an error if the bytes aren't a valid certificate of the recorded encoding.
        pub(crate) fn as_cert(&self) -> Result<Certificate> {
            let cert = match self.cert_type {
                CertificateType::PEM => Certificate::from_pem(&self.cert_bytes)?,
                CertificateType::DER => Certificate::from_der(&self.cert_bytes)?,
            };
            Ok(cert)
        }
    }

    /// Delete the named entries from the Keychain; entries which can't be found are skipped.
    fn delete_entries(keychain: &SecKeychain, accounts: &[&str]) {
        for &account in accounts {
            if let Ok((_, item)) = keychain.find_generic_password(KEYCHAIN_ID, account) {
                item.delete();
            }
        }
    }

    /// Store the server URL, the API key, and optionally the certificate in the default
    /// Keychain, replacing whatever this client stored there before.
    ///
    /// # Errors
    ///
    /// Returns an error if the default Keychain can't be opened or an entry can't be written.
    pub fn save_credentials(url: &str, key: &str, cert: Option<CertificateData>) -> Result<()> {
        let keychain = SecKeychain::default()?;

        // `set_generic_password` creates the entry or updates an existing one. Adding would
        // fail on the second save, leaving stale credentials for `retrieve_credentials`.
        keychain.set_generic_password(KEYCHAIN_ID, KEYCHAIN_URL, url.as_bytes())?;
        keychain.set_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY, key.as_bytes())?;

        // Drop certificates from an earlier save so a leftover entry of the other encoding
        // (or of a configuration which had a certificate) can't be picked up later.
        delete_entries(
            &keychain,
            &[KEYCHAIN_CERTIFICATE_PEM, KEYCHAIN_CERTIFICATE_DER],
        );

        if let Some(cert) = cert {
            let account = match cert.cert_type {
                CertificateType::PEM => KEYCHAIN_CERTIFICATE_PEM,
                CertificateType::DER => KEYCHAIN_CERTIFICATE_DER,
            };
            keychain.set_generic_password(KEYCHAIN_ID, account, &cert.cert_bytes)?;
        }

        Ok(())
    }

    /// Read the API key, the server URL, and the certificate (if any) from the default Keychain.
    ///
    /// The tuple is ordered `(api_key, url, certificate)`. A PEM certificate is preferred over
    /// a DER certificate if both are present.
    ///
    /// # Errors
    ///
    /// Returns an error if the default Keychain can't be opened, if the API key or URL entry
    /// is missing, or if either isn't valid UTF-8.
    pub fn retrieve_credentials() -> Result<(String, String, Option<CertificateData>)> {
        let keychain = SecKeychain::default()?;
        let (api_key, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY)?;
        let api_key = String::from_utf8(api_key.as_ref().to_vec())?;
        let (url, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_URL)?;
        let url = String::from_utf8(url.as_ref().to_vec())?;

        // The certificate is optional, so a failed lookup just means "not stored".
        let cert = [
            (KEYCHAIN_CERTIFICATE_PEM, CertificateType::PEM),
            (KEYCHAIN_CERTIFICATE_DER, CertificateType::DER),
        ]
        .into_iter()
        .find_map(|(account, cert_type)| {
            keychain
                .find_generic_password(KEYCHAIN_ID, account)
                .ok()
                .map(|(cert, _item)| CertificateData {
                    cert_type,
                    cert_bytes: cert.to_vec(),
                })
        });

        Ok((api_key, url, cert))
    }

    /// Remove every entry this client may have stored in the default Keychain.
    ///
    /// Failure to open the Keychain is logged and otherwise ignored.
    pub fn clear_credentials() {
        let Ok(keychain) = SecKeychain::default() else {
            error!("Failed to get access to the Keychain to clear credentials");
            return;
        };

        delete_entries(
            &keychain,
            &[
                KEYCHAIN_API_KEY,
                KEYCHAIN_URL,
                KEYCHAIN_CERTIFICATE_PEM,
                KEYCHAIN_CERTIFICATE_DER,
            ],
        );
    }
}
176
/// Encoding of a custom certificate file.
// The variants are spelled like the encodings' names, hence the lint exception.
#[allow(clippy::upper_case_acronyms)]
#[derive(Copy, Clone, PartialEq, Eq)]
enum CertificateType {
    /// Certificate parsed with `Certificate::from_der` (`.der` files).
    DER,
    /// Certificate parsed with `Certificate::from_pem` (`.pem` files).
    PEM,
}
183
/// Asynchronous client for a Malware DB server.
///
/// Only the URL and the API key are (de)serialized; both are zeroized when the client is dropped.
#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
pub struct MdbClient {
    /// Base URL of the server; API paths are appended to it directly.
    pub url: String,

    /// API key sent with every authenticated request.
    api_key: String,

    /// HTTP client used for all requests. Not serialized, so a client loaded from a
    /// config file gets the default `reqwest::Client`.
    #[zeroize(skip)]
    #[serde(skip)]
    client: reqwest::Client,

    /// Custom certificate, kept so it can be stored in the Keychain when saving.
    #[cfg(target_os = "macos")]
    #[zeroize(skip)]
    #[serde(skip)]
    cert: Option<macos::CertificateData>,
}
204
205impl MdbClient {
206 pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
213 let mut url = url;
214 let url = if url.ends_with('/') {
215 url.pop();
216 url
217 } else {
218 url
219 };
220
221 let cert = if let Some(path) = cert_path {
222 Some((path_load_cert(&path)?, path))
223 } else {
224 None
225 };
226
227 let builder = reqwest::ClientBuilder::new()
228 .gzip(true)
229 .zstd(true)
230 .use_rustls_tls()
231 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
232
233 let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
234 builder.add_root_certificate(cert.clone()).build()
235 } else {
236 builder.build()
237 }?;
238
239 #[cfg(target_os = "macos")]
240 let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
241 Some(macos::CertificateData {
242 cert_type: *cert_type,
243 cert_bytes: std::fs::read(cert_path)?,
244 })
245 } else {
246 None
247 };
248
249 Ok(Self {
250 url,
251 api_key,
252 client,
253
254 #[cfg(target_os = "macos")]
255 cert,
256 })
257 }
258
259 pub async fn login(
266 url: String,
267 username: String,
268 password: String,
269 save: bool,
270 cert_path: Option<PathBuf>,
271 ) -> Result<Self> {
272 let mut url = url;
273 let url = if url.ends_with('/') {
274 url.pop();
275 url
276 } else {
277 url
278 };
279
280 let api_request = malwaredb_api::GetAPIKeyRequest {
281 user: username,
282 password,
283 };
284
285 let builder = reqwest::ClientBuilder::new()
286 .gzip(true)
287 .zstd(true)
288 .use_rustls_tls()
289 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
290
291 let cert = if let Some(path) = cert_path {
292 Some((path_load_cert(&path)?, path))
293 } else {
294 None
295 };
296
297 let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
298 builder.add_root_certificate(cert.clone()).build()
299 } else {
300 builder.build()
301 }?;
302
303 let res = client
304 .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
305 .json(&api_request)
306 .send()
307 .await?
308 .json::<ServerResponse<GetAPIKeyResponse>>()
309 .await
310 .context(MDB_CLIENT_ERROR_CONTEXT)?;
311
312 let res = match res {
313 ServerResponse::Success(res) => res,
314 ServerResponse::Error(err) => return Err(err.into()),
315 };
316
317 #[cfg(target_os = "macos")]
318 let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
319 Some(macos::CertificateData {
320 cert_type: *cert_type,
321 cert_bytes: std::fs::read(cert_path)?,
322 })
323 } else {
324 None
325 };
326
327 let client = MdbClient {
328 url,
329 api_key: res.key.clone(),
330 client,
331
332 #[cfg(target_os = "macos")]
333 cert,
334 };
335
336 let server_info = client.server_info().await?;
337 if server_info.mdb_version > *MDB_VERSION_SEMVER {
338 warn!(
339 "Server version {:?} is newer than client {:?}, consider updating.",
340 server_info.mdb_version, MDB_VERSION_SEMVER
341 );
342 }
343
344 if save {
345 if let Err(e) = client.save() {
346 error!("Login successful but failed to save config: {e}");
347 bail!("Login successful but failed to save config: {e}");
348 }
349 }
350 Ok(client)
351 }
352
353 pub async fn reset_key(&self) -> Result<()> {
359 let response = self
360 .client
361 .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
362 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
363 .send()
364 .await
365 .context(MDB_CLIENT_ERROR_CONTEXT)?;
366 if !response.status().is_success() {
367 bail!("failed to reset API key, was it correct?");
368 }
369 Ok(())
370 }
371
372 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
379 let name = path.as_ref().display();
380 let config =
381 std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
382 let cfg: MdbClient =
383 toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
384 Ok(cfg)
385 }
386
387 pub fn load() -> Result<Self> {
396 #[cfg(target_os = "macos")]
397 {
398 if let Ok((api_key, url, cert)) = macos::retrieve_credentials() {
399 let builder = reqwest::ClientBuilder::new()
400 .gzip(true)
401 .zstd(true)
402 .use_rustls_tls()
403 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
404
405 let client = if let Some(cert) = &cert {
406 builder.add_root_certificate(cert.as_cert()?).build()
407 } else {
408 builder.build()
409 }?;
410
411 return Ok(Self {
412 url,
413 api_key,
414 client,
415 cert,
416 });
417 }
418 }
419
420 let path = get_config_path(false)?;
421 if path.exists() {
422 return Self::from_file(path);
423 }
424 bail!("config file not found")
425 }
426
427 pub fn save(&self) -> Result<()> {
435 #[cfg(target_os = "macos")]
436 {
437 if macos::save_credentials(&self.url, &self.api_key, self.cert.clone()).is_ok() {
438 return Ok(());
439 }
440 }
441
442 let toml = toml::to_string(self)?;
443 let path = get_config_path(true)?;
444 std::fs::write(&path, toml)
445 .context(format!("failed to write mdb config to {}", path.display()))
446 }
447
448 pub fn delete(&self) -> Result<()> {
455 #[cfg(target_os = "macos")]
456 macos::clear_credentials();
457
458 let path = get_config_path(false)?;
459 if path.exists() {
460 std::fs::remove_file(&path).context(format!(
461 "failed to delete client config file {}",
462 path.display()
463 ))?;
464 }
465 Ok(())
466 }
467
468 pub async fn server_info(&self) -> Result<ServerInfo> {
476 let response = self
477 .client
478 .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
479 .send()
480 .await?
481 .json::<ServerResponse<ServerInfo>>()
482 .await
483 .context(MDB_CLIENT_ERROR_CONTEXT)?;
484
485 match response {
486 ServerResponse::Success(info) => Ok(info),
487 ServerResponse::Error(e) => Err(e.into()),
488 }
489 }
490
491 pub async fn supported_types(&self) -> Result<SupportedFileTypes> {
497 let response = self
498 .client
499 .get(format!(
500 "{}{}",
501 self.url,
502 malwaredb_api::SUPPORTED_FILE_TYPES_URL
503 ))
504 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
505 .send()
506 .await?
507 .json::<ServerResponse<SupportedFileTypes>>()
508 .await
509 .context(MDB_CLIENT_ERROR_CONTEXT)?;
510
511 match response {
512 ServerResponse::Success(types) => Ok(types),
513 ServerResponse::Error(e) => Err(e.into()),
514 }
515 }
516
517 pub async fn whoami(&self) -> Result<GetUserInfoResponse> {
524 let response = self
525 .client
526 .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
527 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
528 .send()
529 .await?
530 .json::<ServerResponse<GetUserInfoResponse>>()
531 .await
532 .context(MDB_CLIENT_ERROR_CONTEXT)?;
533
534 match response {
535 ServerResponse::Success(info) => Ok(info),
536 ServerResponse::Error(e) => Err(e.into()),
537 }
538 }
539
540 pub async fn labels(&self) -> Result<Labels> {
547 let response = self
548 .client
549 .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
550 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
551 .send()
552 .await?
553 .json::<ServerResponse<Labels>>()
554 .await
555 .context(MDB_CLIENT_ERROR_CONTEXT)?;
556
557 match response {
558 ServerResponse::Success(labels) => Ok(labels),
559 ServerResponse::Error(e) => Err(e.into()),
560 }
561 }
562
563 pub async fn sources(&self) -> Result<Sources> {
570 let response = self
571 .client
572 .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
573 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
574 .send()
575 .await?
576 .json::<ServerResponse<Sources>>()
577 .await
578 .context(MDB_CLIENT_ERROR_CONTEXT)?;
579
580 match response {
581 ServerResponse::Success(sources) => Ok(sources),
582 ServerResponse::Error(e) => Err(e.into()),
583 }
584 }
585
586 pub async fn submit(
593 &self,
594 contents: impl AsRef<[u8]>,
595 file_name: impl AsRef<str>,
596 source_id: u32,
597 ) -> Result<bool> {
598 let mut hasher = Sha256::new();
599 hasher.update(&contents);
600 let result = hasher.finalize();
601
602 let encoded = general_purpose::STANDARD.encode(contents);
603
604 let payload = malwaredb_api::NewSampleB64 {
605 file_name: file_name.as_ref().to_string(),
606 source_id,
607 file_contents_b64: encoded,
608 sha256: hex::encode(result),
609 };
610
611 match self
612 .client
613 .post(format!(
614 "{}{}",
615 self.url,
616 malwaredb_api::UPLOAD_SAMPLE_JSON_URL
617 ))
618 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
619 .json(&payload)
620 .send()
621 .await
622 {
623 Ok(res) => {
624 if !res.status().is_success() {
625 info!("Code {} sending {}", res.status(), payload.file_name);
626 }
627 Ok(res.status().is_success())
628 }
629 Err(e) => {
630 let status: String = e
631 .status()
632 .map(|s| s.as_str().to_string())
633 .unwrap_or_default();
634 error!("Error{status} sending {}: {e}", payload.file_name);
635 bail!(e.to_string())
636 }
637 }
638 }
639
640 pub async fn submit_as_cbor(
648 &self,
649 contents: impl AsRef<[u8]>,
650 file_name: impl AsRef<str>,
651 source_id: u32,
652 ) -> Result<bool> {
653 let mut hasher = Sha256::new();
654 hasher.update(&contents);
655 let result = hasher.finalize();
656
657 let payload = malwaredb_api::NewSampleBytes {
658 file_name: file_name.as_ref().to_string(),
659 source_id,
660 file_contents: contents.as_ref().to_vec(),
661 sha256: hex::encode(result),
662 };
663
664 let mut bytes = Vec::with_capacity(payload.file_contents.len());
665 ciborium::ser::into_writer(&payload, &mut bytes)?;
666
667 match self
668 .client
669 .post(format!(
670 "{}{}",
671 self.url,
672 malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
673 ))
674 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
675 .header("content-type", "application/cbor")
676 .body(bytes)
677 .send()
678 .await
679 {
680 Ok(res) => {
681 if !res.status().is_success() {
682 info!("Code {} sending {}", res.status(), payload.file_name);
683 }
684 Ok(res.status().is_success())
685 }
686 Err(e) => {
687 let status: String = e
688 .status()
689 .map(|s| s.as_str().to_string())
690 .unwrap_or_default();
691 error!("Error{status} sending {}: {e}", payload.file_name);
692 bail!(e.to_string())
693 }
694 }
695 }
696
697 pub async fn partial_search(
703 &self,
704 partial_hash: Option<(PartialHashSearchType, String)>,
705 name: Option<String>,
706 response: PartialHashSearchType,
707 limit: u32,
708 ) -> Result<SearchResponse> {
709 let query = SearchRequest {
710 search: SearchType::Search(SearchRequestParameters {
711 partial_hash,
712 file_name: name,
713 response,
714 limit,
715 labels: None,
716 file_type: None,
717 magic: None,
718 }),
719 };
720
721 self.do_search_request(&query).await
722 }
723
724 #[allow(clippy::too_many_arguments)]
730 pub async fn partial_search_labels_type(
731 &self,
732 partial_hash: Option<(PartialHashSearchType, String)>,
733 name: Option<String>,
734 response: PartialHashSearchType,
735 labels: Option<Vec<String>>,
736 file_type: Option<String>,
737 magic: Option<String>,
738 limit: u32,
739 ) -> Result<SearchResponse> {
740 let query = SearchRequest {
741 search: SearchType::Search(SearchRequestParameters {
742 partial_hash,
743 file_name: name,
744 response,
745 limit,
746 file_type,
747 magic,
748 labels,
749 }),
750 };
751
752 self.do_search_request(&query).await
753 }
754
755 pub async fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
761 if let Some(uuid) = response.pagination {
762 let request = SearchRequest {
763 search: SearchType::Continuation(uuid),
764 };
765 return self.do_search_request(&request).await;
766 }
767
768 bail!("Pagination not available")
769 }
770
771 async fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
772 ensure!(
773 query.is_valid(),
774 "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
775 );
776
777 let response = self
778 .client
779 .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
780 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
781 .json(query)
782 .send()
783 .await?
784 .json::<ServerResponse<SearchResponse>>()
785 .await
786 .context(MDB_CLIENT_ERROR_CONTEXT)?;
787
788 match response {
789 ServerResponse::Success(search) => Ok(search),
790 ServerResponse::Error(e) => Err(e.into()),
791 }
792 }
793
794 pub async fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
801 let api_endpoint = if cart {
802 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
803 } else {
804 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
805 };
806
807 let res = self
808 .client
809 .get(format!("{}{api_endpoint}", self.url))
810 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
811 .send()
812 .await?;
813
814 if !res.status().is_success() {
815 bail!("Received code {}", res.status());
816 }
817
818 let content_digest = res.headers().get("content-digest").map(ToOwned::to_owned);
819 let body = res.bytes().await?;
820 let bytes = body.to_vec();
821
822 if let Some(digest) = content_digest {
824 let hash = HashType::from_content_digest_header(digest.to_str()?)?;
825 if hash.verify(&bytes) {
826 trace!("Hash verified for sample {hash}");
827 } else {
828 error!("Hash mismatch for sample {hash}");
829 }
830 } else {
831 warn!("No content digest header received for sample {hash}");
832 }
833
834 Ok(bytes)
835 }
836
837 pub async fn report(&self, hash: &str) -> Result<Report> {
844 let response = self
845 .client
846 .get(format!(
847 "{}{}/{hash}",
848 self.url,
849 malwaredb_api::SAMPLE_REPORT_URL
850 ))
851 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
852 .send()
853 .await?
854 .json::<ServerResponse<Report>>()
855 .await
856 .context(MDB_CLIENT_ERROR_CONTEXT)?;
857
858 match response {
859 ServerResponse::Success(report) => Ok(report),
860 ServerResponse::Error(e) => Err(e.into()),
861 }
862 }
863
864 pub async fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
872 let mut hashes = vec![];
873 let ssdeep_hash = FuzzyHash::new(contents);
874
875 let build_hasher = Murmur3HashState::default();
876 let lzjd_str =
877 LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
878 hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
879 hashes.push((
880 malwaredb_api::SimilarityHashType::SSDeep,
881 ssdeep_hash.to_string(),
882 ));
883
884 let mut builder = TlshBuilder::new(
885 tlsh_fixed::BucketKind::Bucket256,
886 tlsh_fixed::ChecksumKind::ThreeByte,
887 tlsh_fixed::Version::Version4,
888 );
889
890 builder.update(contents);
891 if let Ok(hasher) = builder.build() {
892 hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
893 }
894
895 if let Ok(exe) = EXE::from(contents) {
896 if let Some(imports) = exe.imports {
897 hashes.push((
898 malwaredb_api::SimilarityHashType::ImportHash,
899 hex::encode(imports.hash()),
900 ));
901 hashes.push((
902 malwaredb_api::SimilarityHashType::FuzzyImportHash,
903 imports.fuzzy_hash(),
904 ));
905 }
906 }
907
908 let request = malwaredb_api::SimilarSamplesRequest { hashes };
909
910 let response = self
911 .client
912 .post(format!(
913 "{}{}",
914 self.url,
915 malwaredb_api::SIMILAR_SAMPLES_URL
916 ))
917 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
918 .json(&request)
919 .send()
920 .await?
921 .json::<ServerResponse<SimilarSamplesResponse>>()
922 .await
923 .context(MDB_CLIENT_ERROR_CONTEXT)?;
924
925 match response {
926 ServerResponse::Success(similar) => Ok(similar),
927 ServerResponse::Error(e) => Err(e.into()),
928 }
929 }
930}
931
932impl Debug for MdbClient {
933 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
934 writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
935 }
936}
937
938pub fn encode_to_cart(data: &[u8]) -> Result<Vec<u8>> {
947 let mut input_buffer = Cursor::new(data);
948 let mut output_buffer = Cursor::new(vec![]);
949 let mut output_metadata = JsonMap::new();
950
951 let mut sha384 = Sha384::new();
952 sha384.update(data);
953 let sha384 = hex::encode(sha384.finalize());
954
955 let mut sha512 = Sha512::new();
956 sha512.update(data);
957 let sha512 = hex::encode(sha512.finalize());
958
959 output_metadata.insert("sha384".into(), sha384.into());
960 output_metadata.insert("sha512".into(), sha512.into());
961 output_metadata.insert("entropy".into(), entropy_calc(data).into());
962 cart_container::pack_stream(
963 &mut input_buffer,
964 &mut output_buffer,
965 Some(output_metadata),
966 None,
967 cart_container::digesters::default_digesters(),
968 None,
969 )?;
970
971 Ok(output_buffer.into_inner())
972}
973
974pub fn decode_from_cart(data: &[u8]) -> Result<(Vec<u8>, Option<JsonMap>, Option<JsonMap>)> {
983 let mut input_buffer = Cursor::new(data);
984 let mut output_buffer = Cursor::new(vec![]);
985 let (header, footer) =
986 cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)?;
987 Ok((output_buffer.into_inner(), header, footer))
988}
989
990fn path_load_cert(path: &Path) -> Result<(CertificateType, Certificate)> {
997 if !path.exists() {
998 bail!("Certificate {} does not exist.", path.display());
999 }
1000 let cert = match path
1001 .extension()
1002 .context("can't determine file extension")?
1003 .to_str()
1004 .context("unable to parse file extension")?
1005 {
1006 "pem" => {
1007 let contents = std::fs::read(path)?;
1008 (CertificateType::PEM, Certificate::from_pem(&contents)?)
1009 }
1010 "der" => {
1011 let contents = std::fs::read(path)?;
1012 (CertificateType::DER, Certificate::from_der(&contents)?)
1013 }
1014 ext => {
1015 bail!("Unknown extension {ext:?}")
1016 }
1017 };
1018 Ok(cert)
1019}
1020
1021#[inline]
1029pub(crate) fn get_config_path(create: bool) -> Result<PathBuf> {
1030 let config = PathBuf::from(MDB_CLIENT_CONFIG_TOML);
1032 if config.exists() {
1033 return Ok(config);
1034 }
1035
1036 #[cfg(target_os = "haiku")]
1037 {
1038 let mut settings = PathBuf::from("/boot/home/config/settings/malwaredb");
1039 if create && !settings.exists() {
1040 std::fs::create_dir_all(&settings)?;
1041 }
1042 settings.push(MDB_CLIENT_CONFIG_TOML);
1043 return Ok(settings);
1044 }
1045
1046 #[cfg(unix)]
1047 {
1048 if let Some(xdg_home) = std::env::var_os("XDG_CONFIG_HOME") {
1050 let mut xdg_config_home = PathBuf::from(xdg_home);
1051 xdg_config_home.push(MDB_CLIENT_DIR);
1052 if create && !xdg_config_home.exists() {
1053 std::fs::create_dir_all(&xdg_config_home)?;
1054 }
1055 xdg_config_home.push(MDB_CLIENT_CONFIG_TOML);
1056 return Ok(xdg_config_home);
1057 }
1058 }
1059
1060 if let Some(mut home_config) = home_dir() {
1061 home_config.push(".config");
1062 home_config.push(MDB_CLIENT_DIR);
1063 if create && !home_config.exists() {
1064 std::fs::create_dir_all(&home_config)?;
1065 }
1066 home_config.push(MDB_CLIENT_CONFIG_TOML);
1067 return Ok(home_config);
1068 }
1069
1070 Ok(PathBuf::from(MDB_CLIENT_CONFIG_TOML))
1071}
1072
/// A Malware DB server, described by what is needed to build its base URL.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MalwareDBServer {
    /// Host name of the server.
    pub host: String,

    /// Port on which the server listens.
    pub port: u16,

    /// Whether the server is reached over `https` rather than `http`.
    pub ssl: bool,

    /// Name of the server.
    pub name: String,
}

/// Formats the server as its base URL: scheme, host, and port, without a trailing slash.
impl Display for MalwareDBServer {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let scheme = if self.ssl { "https" } else { "http" };
        write!(f, "{scheme}://{}:{}", self.host, self.port)
    }
}
1098
1099impl MalwareDBServer {
1100 pub async fn server_info(&self) -> Result<ServerInfo> {
1106 let client = reqwest::ClientBuilder::new()
1107 .gzip(true)
1108 .zstd(true)
1109 .use_rustls_tls()
1110 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
1111 .build()?;
1112
1113 let response = client
1114 .get(format!("{self}{}", malwaredb_api::SERVER_INFO_URL))
1115 .send()
1116 .await?
1117 .json::<ServerResponse<ServerInfo>>()
1118 .await
1119 .context(MDB_CLIENT_ERROR_CONTEXT)?;
1120
1121 match response {
1122 ServerResponse::Success(info) => Ok(info),
1123 ServerResponse::Error(e) => Err(e.into()),
1124 }
1125 }
1126
1127 #[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
1137 #[cfg(feature = "blocking")]
1138 pub fn server_info_blocking(&self) -> Result<ServerInfo> {
1139 let client = reqwest::blocking::ClientBuilder::new()
1140 .gzip(true)
1141 .zstd(true)
1142 .use_rustls_tls()
1143 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
1144 .build()?;
1145
1146 let response = client
1147 .get(format!("{self}{}", malwaredb_api::SERVER_INFO_URL))
1148 .send()?
1149 .json::<ServerResponse<ServerInfo>>()
1150 .context(MDB_CLIENT_ERROR_CONTEXT)?;
1151
1152 match response {
1153 ServerResponse::Success(similar) => Ok(similar),
1154 ServerResponse::Error(e) => Err(e.into()),
1155 }
1156 }
1157}
1158
1159pub fn discover_servers() -> Result<Vec<MalwareDBServer>> {
1165 const MAX_ITERS: usize = 5;
1166 let mdns = ServiceDaemon::new()?;
1167 let mut servers = HashSet::new();
1168 let receiver = mdns.browse(malwaredb_api::MDNS_NAME)?;
1169
1170 let mut counter = 0;
1171 while let Ok(event) = receiver.recv() {
1172 if let ServiceEvent::ServiceResolved(resolved) = event {
1173 let host = resolved.host.replace(".local.", "");
1174 let ssl = if let Some(ssl) = resolved.txt_properties.get("ssl") {
1175 ssl.val_str() == "true"
1176 } else {
1177 debug!(
1178 "MalwareDB entry for {host}:{} doesn't specify ssl, assuming not",
1179 resolved.port
1180 );
1181 false
1182 };
1183
1184 let server = MalwareDBServer {
1185 host,
1186 port: resolved.port,
1187 ssl,
1188 name: resolved.fullname.replace(malwaredb_api::MDNS_NAME, ""),
1189 };
1190
1191 servers.insert(server);
1192 }
1193 counter += 1;
1194 if counter > MAX_ITERS {
1195 break;
1196 }
1197 }
1198
1199 Ok(servers.into_iter().collect())
1200}
1201
#[cfg(test)]
mod tests {
    use super::*;

    /// Unpack a bundled `CaRT` file and check the payload hash and the container metadata.
    #[test]
    fn cart() {
        const ORIGINAL_SHA256: &str =
            "de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740";
        const BYTES: &[u8] = include_bytes!("../../crates/types/testdata/elf/elf_haiku_x86.cart");

        let (decoded, header, footer) = decode_from_cart(BYTES).unwrap();

        // The unpacked payload must be the original file.
        let mut hasher = Sha256::new();
        hasher.update(&decoded);
        assert_eq!(hex::encode(hasher.finalize()), ORIGINAL_SHA256);

        // The header carries the entropy stored when the file was packed.
        let entropy = header
            .unwrap()
            .get("entropy")
            .unwrap()
            .as_f64()
            .unwrap();
        assert!(entropy > 4.0 && entropy < 4.1);

        // The footer carries the length and hash of the payload.
        let footer = footer.unwrap();
        assert_eq!(footer.get("length").unwrap(), "5093");
        assert_eq!(footer.get("sha256").unwrap(), ORIGINAL_SHA256);
    }
}