1use std::fmt::{Debug, Formatter};
4use std::fs::OpenOptions;
5use std::io::Write;
6use std::path::{Path, PathBuf};
7
8use crate::{MDB_CLIENT_ERROR_CONTEXT, get_config_path};
9use malwaredb_api::{
10 GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType, Report, SearchRequest,
11 SearchRequestParameters, SearchResponse, SearchType, ServerInfo, ServerResponse,
12 SimilarSamplesResponse, Sources, SupportedFileTypes, YaraSearchRequest,
13 YaraSearchRequestResponse, YaraSearchResponse, digest::HashType,
14};
15use malwaredb_types::exec::pe32::EXE;
16
17use anyhow::{Context, Result, bail, ensure};
18use base64::Engine;
19use base64::engine::general_purpose;
20use fuzzyhash::FuzzyHash;
21use malwaredb_lzjd::{LZDict, Murmur3HashState};
22use serde::{Deserialize, Serialize};
23use sha2::{Digest, Sha256};
24use tlsh_fixed::TlshBuilder;
25use tracing::{error, info, trace, warn};
26use uuid::Uuid;
27use zeroize::{Zeroize, ZeroizeOnDrop};
28
/// Blocking client for a MalwareDB server.
///
/// Holds the server base URL, the user's API key and a configured HTTP client. The struct is
/// serialized to / deserialized from a TOML config file (see `save` and `from_file`). Thanks to
/// `ZeroizeOnDrop`, the `url` and `api_key` fields are zeroized when the client is dropped; the
/// fields marked `#[zeroize(skip)]` are not.
#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
pub struct MdbClient {
    /// Base URL of the MalwareDB server; `new` and `login` strip a trailing `/` so endpoint
    /// paths can be appended directly.
    pub url: String,

    /// API key sent in the `malwaredb_api::MDB_API_HEADER` header on authenticated requests.
    api_key: String,

    /// HTTP client used for all requests. It is not part of the config file: `#[serde(skip)]`
    /// means serde fills it with `Default::default()` when deserializing.
    #[zeroize(skip)]
    #[serde(skip)]
    client: reqwest::blocking::Client,

    /// macOS only: custom certificate data, handed to the keychain helpers by `save` so it can be
    /// restored by `load`.
    #[cfg(target_os = "macos")]
    #[zeroize(skip)]
    #[serde(skip)]
    cert: Option<crate::macos::CertificateData>,
}
50
51impl MdbClient {
52 pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
63 let mut url = url;
64 let url = if url.ends_with('/') {
65 url.pop();
66 url
67 } else {
68 url
69 };
70
71 let cert = if let Some(path) = cert_path {
72 Some((crate::path_load_cert(&path)?, path))
73 } else {
74 None
75 };
76
77 let builder = reqwest::blocking::ClientBuilder::new()
78 .gzip(true)
79 .zstd(true)
80 .use_rustls_tls()
81 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
82
83 let client = if let Some(((_cert_type, cert), _path)) = &cert {
84 builder.add_root_certificate(cert.clone()).build()
85 } else {
86 builder.build()
87 }?;
88
89 #[cfg(target_os = "macos")]
90 let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
91 Some(crate::macos::CertificateData {
92 cert_type: *cert_type,
93 cert_bytes: std::fs::read(cert_path)?,
94 })
95 } else {
96 None
97 };
98
99 Ok(Self {
100 url,
101 api_key,
102 client,
103
104 #[cfg(target_os = "macos")]
105 cert,
106 })
107 }
108
109 pub fn login(
120 url: String,
121 username: String,
122 password: String,
123 save: bool,
124 cert_path: Option<PathBuf>,
125 ) -> Result<Self> {
126 let mut url = url;
127 let url = if url.ends_with('/') {
128 url.pop();
129 url
130 } else {
131 url
132 };
133
134 let api_request = malwaredb_api::GetAPIKeyRequest {
135 user: username,
136 password,
137 };
138
139 let builder = reqwest::blocking::ClientBuilder::new()
140 .gzip(true)
141 .zstd(true)
142 .use_rustls_tls()
143 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
144
145 let cert = if let Some(path) = cert_path {
146 Some((crate::path_load_cert(&path)?, path))
147 } else {
148 None
149 };
150
151 let client = if let Some(((_cert_type, cert), _path)) = &cert {
152 builder.add_root_certificate(cert.clone()).build()
153 } else {
154 builder.build()
155 }?;
156
157 let res = client
158 .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
159 .json(&api_request)
160 .send()?
161 .json::<ServerResponse<GetAPIKeyResponse>>()
162 .context(MDB_CLIENT_ERROR_CONTEXT)?;
163
164 let res = match res {
165 ServerResponse::Success(res) => res,
166 ServerResponse::Error(err) => return Err(err.into()),
167 };
168
169 #[cfg(target_os = "macos")]
170 let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
171 Some(crate::macos::CertificateData {
172 cert_type: *cert_type,
173 cert_bytes: std::fs::read(cert_path)?,
174 })
175 } else {
176 None
177 };
178
179 let client = MdbClient {
180 url,
181 api_key: res.key.clone(),
182 client,
183
184 #[cfg(target_os = "macos")]
185 cert,
186 };
187
188 let server_info = client.server_info()?;
189 if server_info.mdb_version > *crate::MDB_VERSION_SEMVER {
190 warn!(
191 "Server version {:?} is newer than client {:?}, consider updating.",
192 server_info.mdb_version,
193 crate::MDB_VERSION_SEMVER
194 );
195 }
196
197 if save && let Err(e) = client.save() {
198 error!("Login successful but failed to save config: {e}");
199 bail!("Login successful but failed to save config: {e}");
200 }
201 Ok(client)
202 }
203
204 pub fn reset_key(&self) -> Result<()> {
210 let response = self
211 .client
212 .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
213 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
214 .send()
215 .context(MDB_CLIENT_ERROR_CONTEXT)?;
216 if !response.status().is_success() {
217 bail!("failed to reset API key, was it correct?");
218 }
219 Ok(())
220 }
221
222 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
229 let name = path.as_ref().display();
230 let config =
231 std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
232 let cfg: MdbClient =
233 toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
234 Ok(cfg)
235 }
236
237 pub fn load() -> Result<Self> {
250 #[cfg(target_os = "macos")]
251 {
252 if let Ok((api_key, url, cert)) = crate::macos::retrieve_credentials() {
253 let builder = reqwest::blocking::ClientBuilder::new()
254 .gzip(true)
255 .zstd(true)
256 .use_rustls_tls()
257 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
258
259 let client = if let Some(cert) = &cert {
260 builder.add_root_certificate(cert.as_cert()?).build()
261 } else {
262 builder.build()
263 }?;
264
265 return Ok(Self {
266 url,
267 api_key,
268 client,
269 cert,
270 });
271 }
272 }
273
274 let path = get_config_path(false)?;
275 if path.exists() {
276 return Self::from_file(path);
277 }
278 bail!("config file not found")
279 }
280
281 pub fn save(&self) -> Result<()> {
289 #[cfg(target_os = "macos")]
290 {
291 if crate::macos::save_credentials(&self.url, &self.api_key, self.cert.clone()).is_ok() {
292 return Ok(());
293 }
294 }
295
296 let toml = toml::to_string(self)?;
297 let path = get_config_path(true)?;
298
299 let mut options = OpenOptions::new();
300 options
301 .write(true)
302 .create(true)
303 .append(false)
304 .truncate(false);
305
306 #[cfg(target_family = "unix")]
307 {
308 use std::os::unix::fs::OpenOptionsExt;
309
310 options.mode(0o600);
311 }
312
313 let mut file = options.open(&path)?;
314 write!(file, "{toml}").context(format!("failed to write mdb config to {}", path.display()))
315 }
316
317 pub fn delete(&self) -> Result<()> {
324 #[cfg(target_os = "macos")]
325 crate::macos::clear_credentials();
326
327 let path = get_config_path(false)?;
328 if path.exists() {
329 std::fs::remove_file(&path)
330 .context(format!("failed to delete client config file {}", path.display()))?;
331 }
332 Ok(())
333 }
334
335 pub fn server_info(&self) -> Result<ServerInfo> {
343 let response = self
344 .client
345 .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
346 .send()?
347 .json::<ServerResponse<ServerInfo>>()
348 .context(MDB_CLIENT_ERROR_CONTEXT)?;
349
350 match response {
351 ServerResponse::Success(info) => Ok(info),
352 ServerResponse::Error(e) => Err(e.into()),
353 }
354 }
355
356 pub fn supported_types(&self) -> Result<SupportedFileTypes> {
362 let response = self
363 .client
364 .get(format!("{}{}", self.url, malwaredb_api::SUPPORTED_FILE_TYPES_URL))
365 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
366 .send()?
367 .json::<ServerResponse<SupportedFileTypes>>()
368 .context(MDB_CLIENT_ERROR_CONTEXT)?;
369
370 match response {
371 ServerResponse::Success(types) => Ok(types),
372 ServerResponse::Error(e) => Err(e.into()),
373 }
374 }
375
376 pub fn whoami(&self) -> Result<GetUserInfoResponse> {
383 let response = self
384 .client
385 .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
386 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
387 .send()?
388 .json::<ServerResponse<GetUserInfoResponse>>()
389 .context(MDB_CLIENT_ERROR_CONTEXT)?;
390
391 match response {
392 ServerResponse::Success(info) => Ok(info),
393 ServerResponse::Error(e) => Err(e.into()),
394 }
395 }
396
397 pub fn labels(&self) -> Result<Labels> {
404 let response = self
405 .client
406 .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
407 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
408 .send()?
409 .json::<ServerResponse<Labels>>()
410 .context(MDB_CLIENT_ERROR_CONTEXT)?;
411
412 match response {
413 ServerResponse::Success(labels) => Ok(labels),
414 ServerResponse::Error(e) => Err(e.into()),
415 }
416 }
417
418 pub fn sources(&self) -> Result<Sources> {
425 let response = self
426 .client
427 .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
428 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
429 .send()?
430 .json::<ServerResponse<Sources>>()
431 .context(MDB_CLIENT_ERROR_CONTEXT)?;
432
433 match response {
434 ServerResponse::Success(sources) => Ok(sources),
435 ServerResponse::Error(e) => Err(e.into()),
436 }
437 }
438
439 pub fn submit(
446 &self,
447 contents: impl AsRef<[u8]>,
448 file_name: String,
449 source_id: u32,
450 ) -> Result<bool> {
451 let mut hasher = Sha256::new();
452 hasher.update(&contents);
453 let result = hasher.finalize();
454
455 let encoded = general_purpose::STANDARD.encode(contents);
456
457 let payload = malwaredb_api::NewSampleB64 {
458 file_name,
459 source_id,
460 file_contents_b64: encoded,
461 sha256: hex::encode(result),
462 };
463
464 match self
465 .client
466 .post(format!("{}{}", self.url, malwaredb_api::UPLOAD_SAMPLE_JSON_URL))
467 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
468 .json(&payload)
469 .send()
470 {
471 Ok(res) => {
472 if !res.status().is_success() {
473 info!("Code {} sending {}", res.status(), payload.file_name);
474 }
475 Ok(res.status().is_success())
476 }
477 Err(e) => {
478 let status: String = e
479 .status()
480 .map(|s| s.as_str().to_string())
481 .unwrap_or_default();
482 error!("Error{status} sending {}: {e}", payload.file_name);
483 bail!(e.to_string())
484 }
485 }
486 }
487
488 pub fn submit_as_cbor(
496 &self,
497 contents: impl AsRef<[u8]>,
498 file_name: String,
499 source_id: u32,
500 ) -> Result<bool> {
501 let mut hasher = Sha256::new();
502 hasher.update(&contents);
503 let result = hasher.finalize();
504
505 let payload = malwaredb_api::NewSampleBytes {
506 file_name,
507 source_id,
508 file_contents: contents.as_ref().to_vec(),
509 sha256: hex::encode(result),
510 };
511
512 let mut bytes = Vec::with_capacity(payload.file_contents.len());
513 ciborium::ser::into_writer(&payload, &mut bytes)?;
514
515 match self
516 .client
517 .post(format!("{}{}", self.url, malwaredb_api::UPLOAD_SAMPLE_CBOR_URL))
518 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
519 .header("content-type", "application/cbor")
520 .body(bytes)
521 .send()
522 {
523 Ok(res) => {
524 if !res.status().is_success() {
525 info!("Code {} sending {}", res.status(), payload.file_name);
526 }
527 Ok(res.status().is_success())
528 }
529 Err(e) => {
530 let status: String = e
531 .status()
532 .map(|s| s.as_str().to_string())
533 .unwrap_or_default();
534 error!("Error{status} sending {}: {e}", payload.file_name);
535 bail!(e.to_string())
536 }
537 }
538 }
539
540 pub fn partial_search(
546 &self,
547 partial_hash: Option<(PartialHashSearchType, String)>,
548 name: Option<String>,
549 response: PartialHashSearchType,
550 limit: u32,
551 ) -> Result<SearchResponse> {
552 let query = SearchRequest {
553 search: SearchType::Search(SearchRequestParameters {
554 partial_hash,
555 file_name: name,
556 response,
557 limit,
558 labels: None,
559 file_type: None,
560 magic: None,
561 }),
562 };
563
564 self.do_search_request(&query)
565 }
566
567 #[allow(clippy::too_many_arguments)]
573 pub fn partial_search_labels_type(
574 &self,
575 partial_hash: Option<(PartialHashSearchType, String)>,
576 name: Option<String>,
577 response: PartialHashSearchType,
578 labels: Option<Vec<String>>,
579 file_type: Option<String>,
580 magic: Option<String>,
581 limit: u32,
582 ) -> Result<SearchResponse> {
583 let query = SearchRequest {
584 search: SearchType::Search(SearchRequestParameters {
585 partial_hash,
586 file_name: name,
587 response,
588 limit,
589 file_type,
590 magic,
591 labels,
592 }),
593 };
594
595 self.do_search_request(&query)
596 }
597
598 pub fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
604 if let Some(uuid) = response.pagination {
605 let request = SearchRequest {
606 search: SearchType::Continuation(uuid),
607 };
608 return self.do_search_request(&request);
609 }
610
611 bail!("Pagination not available")
612 }
613
614 fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
615 ensure!(
616 query.is_valid(),
617 "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
618 );
619
620 let response = self
621 .client
622 .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
623 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
624 .json(query)
625 .send()?
626 .json::<ServerResponse<SearchResponse>>()
627 .context(MDB_CLIENT_ERROR_CONTEXT)?;
628
629 match response {
630 ServerResponse::Success(search) => Ok(search),
631 ServerResponse::Error(e) => Err(e.into()),
632 }
633 }
634
635 pub fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
642 let api_endpoint = if cart {
643 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
644 } else {
645 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
646 };
647
648 let res = self
649 .client
650 .get(format!("{}{api_endpoint}", self.url))
651 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
652 .send()?;
653
654 if !res.status().is_success() {
655 bail!("Received code {}", res.status());
656 }
657
658 let content_digest = res.headers().get("content-digest").map(ToOwned::to_owned);
659 let body = res.bytes()?;
660 let bytes = body.to_vec();
661
662 if let Some(digest) = content_digest {
664 let hash = HashType::from_content_digest_header(digest.to_str()?)?;
665 if hash.verify(&bytes) {
666 trace!("Hash verified for sample {hash}");
667 } else {
668 error!("Hash mismatch for sample {hash}");
669 }
670 } else {
671 warn!("No content digest header received for sample {hash}");
672 }
673
674 Ok(bytes)
675 }
676
677 pub fn report(&self, hash: &str) -> Result<Report> {
684 let response = self
685 .client
686 .get(format!("{}{}/{hash}", self.url, malwaredb_api::SAMPLE_REPORT_URL))
687 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
688 .send()?
689 .json::<ServerResponse<Report>>()
690 .context(MDB_CLIENT_ERROR_CONTEXT)?;
691
692 match response {
693 ServerResponse::Success(report) => Ok(report),
694 ServerResponse::Error(e) => Err(e.into()),
695 }
696 }
697
698 pub fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
706 let mut hashes = vec![];
707 let ssdeep_hash = FuzzyHash::new(contents);
708
709 let build_hasher = Murmur3HashState::default();
710 let lzjd_str =
711 LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
712 hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
713 hashes.push((malwaredb_api::SimilarityHashType::SSDeep, ssdeep_hash.to_string()));
714
715 let mut builder = TlshBuilder::new(
716 tlsh_fixed::BucketKind::Bucket256,
717 tlsh_fixed::ChecksumKind::ThreeByte,
718 tlsh_fixed::Version::Version4,
719 );
720
721 builder.update(contents);
722 if let Ok(hasher) = builder.build() {
723 hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
724 }
725
726 if let Ok(exe) = EXE::from(contents)
727 && let Some(imports) = exe.imports
728 {
729 hashes
730 .push((malwaredb_api::SimilarityHashType::ImportHash, hex::encode(imports.hash())));
731 hashes.push((malwaredb_api::SimilarityHashType::FuzzyImportHash, imports.fuzzy_hash()));
732 }
733
734 let request = malwaredb_api::SimilarSamplesRequest { hashes };
735
736 let response = self
737 .client
738 .post(format!("{}{}", self.url, malwaredb_api::SIMILAR_SAMPLES_URL))
739 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
740 .json(&request)
741 .send()?
742 .json::<ServerResponse<SimilarSamplesResponse>>()
743 .context(MDB_CLIENT_ERROR_CONTEXT)?;
744
745 match response {
746 ServerResponse::Success(similar) => Ok(similar),
747 ServerResponse::Error(e) => Err(e.into()),
748 }
749 }
750
751 pub fn yara_search(&self, yara: &str) -> Result<YaraSearchRequestResponse> {
757 let yara = YaraSearchRequest {
758 rules: vec![yara.to_string()],
759 response: PartialHashSearchType::SHA256,
760 };
761
762 let response = self
763 .client
764 .post(format!("{}{}", self.url, malwaredb_api::YARA_SEARCH_URL))
765 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
766 .json(&yara)
767 .send()?
768 .json::<ServerResponse<YaraSearchRequestResponse>>()?;
769
770 match response {
771 ServerResponse::Success(similar) => Ok(similar),
772 ServerResponse::Error(e) => Err(e.into()),
773 }
774 }
775
776 pub fn yara_result(&self, uuid: Uuid) -> Result<YaraSearchResponse> {
782 let response = self
783 .client
784 .get(format!("{}{}/{uuid}", self.url, malwaredb_api::YARA_SEARCH_URL))
785 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
786 .send()?
787 .json::<ServerResponse<YaraSearchResponse>>()?;
788
789 match response {
790 ServerResponse::Success(sources) => Ok(sources),
791 ServerResponse::Error(e) => Err(e.into()),
792 }
793 }
794}
795
796impl Debug for MdbClient {
797 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
798 use crate::MDB_VERSION;
799
800 writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
801 }
802}
803
804#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
819pub struct IterableHashSearchResult<'a> {
820 pub response: SearchResponse,
822
823 client: &'a MdbClient,
825}
826
827impl<'a> IterableHashSearchResult<'a> {
828 #[must_use]
830 pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
831 Self { response, client }
832 }
833}
834
835impl Iterator for IterableHashSearchResult<'_> {
836 type Item = String;
837
838 fn next(&mut self) -> Option<Self::Item> {
839 if let Some(hash) = self.response.hashes.pop() {
840 Some(hash)
841 } else if let Some(uuid) = self.response.pagination {
842 let request = SearchRequest {
843 search: SearchType::Continuation(uuid),
844 };
845
846 self.response = match self.client.do_search_request(&request) {
847 Ok(response) => response,
848 Err(e) => {
849 warn!("Failed to continue search: {e}");
850 return None;
851 }
852 };
853
854 self.response.hashes.pop()
855 } else {
856 None
857 }
858 }
859}
860
861#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
863pub struct IterableSampleSearchResult<'a> {
864 pub response: SearchResponse,
866
867 client: &'a MdbClient,
869}
870
871impl<'a> IterableSampleSearchResult<'a> {
872 #[must_use]
874 pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
875 Self { response, client }
876 }
877}
878
879impl Iterator for IterableSampleSearchResult<'_> {
880 type Item = Vec<u8>;
881
882 fn next(&mut self) -> Option<Self::Item> {
883 if let Some(hash) = self.response.hashes.pop() {
884 let binary = match self.client.retrieve(&hash, false) {
885 Ok(binary) => binary,
886 Err(e) => {
887 error!("Failed to download {hash}: {e}");
888 return None;
889 }
890 };
891 Some(binary)
892 } else if let Some(uuid) = self.response.pagination {
893 let request = SearchRequest {
894 search: SearchType::Continuation(uuid),
895 };
896
897 self.response = match self.client.do_search_request(&request) {
898 Ok(response) => response,
899 Err(e) => {
900 warn!("Failed to continue search: {e}");
901 return None;
902 }
903 };
904
905 if let Some(hash) = self.response.hashes.pop() {
906 let binary = match self.client.retrieve(&hash, false) {
907 Ok(binary) => binary,
908 Err(e) => {
909 error!("Failed to download {hash}: {e}");
910 return None;
911 }
912 };
913 Some(binary)
914 } else {
915 None
916 }
917 } else {
918 None
919 }
920 }
921}