1use std::fmt::{Debug, Formatter};
4use std::path::{Path, PathBuf};
5
6use crate::{get_config_path, MDB_CLIENT_ERROR_CONTEXT};
7use malwaredb_api::{
8 digest::HashType, GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType,
9 Report, SearchRequest, SearchRequestParameters, SearchResponse, SearchType, ServerInfo,
10 ServerResponse, SimilarSamplesResponse, Sources, SupportedFileTypes, YaraSearchRequest,
11 YaraSearchRequestResponse, YaraSearchResponse,
12};
13use malwaredb_types::exec::pe32::EXE;
14
15use anyhow::{bail, ensure, Context, Result};
16use base64::engine::general_purpose;
17use base64::Engine;
18use fuzzyhash::FuzzyHash;
19use malwaredb_lzjd::{LZDict, Murmur3HashState};
20use serde::{Deserialize, Serialize};
21use sha2::{Digest, Sha256};
22use tlsh_fixed::TlshBuilder;
23use tracing::{error, info, trace, warn};
24use uuid::Uuid;
25use zeroize::{Zeroize, ZeroizeOnDrop};
26
#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
/// Blocking (synchronous) client for a MalwareDB server.
///
/// Only the server URL and API key are (de)serialized, which is what allows the client to be
/// saved to and restored from a TOML config file. Those two fields are zeroized when the client
/// is dropped; the HTTP client and certificate data are skipped.
#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
pub struct MdbClient {
    /// Base URL of the MalwareDB server; endpoint paths are appended to it directly
    /// (the constructors strip a trailing `/`).
    pub url: String,

    /// API key sent in the `malwaredb_api::MDB_API_HEADER` header on authenticated requests.
    api_key: String,

    /// Underlying HTTP client; neither serialized nor zeroized.
    #[zeroize(skip)]
    #[serde(skip)]
    client: reqwest::blocking::Client,

    /// Custom certificate data, kept so it can be handed to
    /// `crate::macos::save_credentials` alongside the URL and API key (macOS only).
    #[cfg(target_os = "macos")]
    #[zeroize(skip)]
    #[serde(skip)]
    cert: Option<crate::macos::CertificateData>,
}
48
49impl MdbClient {
50 pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
61 let mut url = url;
62 let url = if url.ends_with('/') {
63 url.pop();
64 url
65 } else {
66 url
67 };
68
69 let cert = if let Some(path) = cert_path {
70 Some((crate::path_load_cert(&path)?, path))
71 } else {
72 None
73 };
74
75 let builder = reqwest::blocking::ClientBuilder::new()
76 .gzip(true)
77 .zstd(true)
78 .use_rustls_tls()
79 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
80
81 let client = if let Some(((_cert_type, cert), _path)) = &cert {
82 builder.add_root_certificate(cert.clone()).build()
83 } else {
84 builder.build()
85 }?;
86
87 #[cfg(target_os = "macos")]
88 let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
89 Some(crate::macos::CertificateData {
90 cert_type: *cert_type,
91 cert_bytes: std::fs::read(cert_path)?,
92 })
93 } else {
94 None
95 };
96
97 Ok(Self {
98 url,
99 api_key,
100 client,
101
102 #[cfg(target_os = "macos")]
103 cert,
104 })
105 }
106
107 pub fn login(
118 url: String,
119 username: String,
120 password: String,
121 save: bool,
122 cert_path: Option<PathBuf>,
123 ) -> Result<Self> {
124 let mut url = url;
125 let url = if url.ends_with('/') {
126 url.pop();
127 url
128 } else {
129 url
130 };
131
132 let api_request = malwaredb_api::GetAPIKeyRequest {
133 user: username,
134 password,
135 };
136
137 let builder = reqwest::blocking::ClientBuilder::new()
138 .gzip(true)
139 .zstd(true)
140 .use_rustls_tls()
141 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
142
143 let cert = if let Some(path) = cert_path {
144 Some((crate::path_load_cert(&path)?, path))
145 } else {
146 None
147 };
148
149 let client = if let Some(((_cert_type, cert), _path)) = &cert {
150 builder.add_root_certificate(cert.clone()).build()
151 } else {
152 builder.build()
153 }?;
154
155 let res = client
156 .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
157 .json(&api_request)
158 .send()?
159 .json::<ServerResponse<GetAPIKeyResponse>>()
160 .context(MDB_CLIENT_ERROR_CONTEXT)?;
161
162 let res = match res {
163 ServerResponse::Success(res) => res,
164 ServerResponse::Error(err) => return Err(err.into()),
165 };
166
167 #[cfg(target_os = "macos")]
168 let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
169 Some(crate::macos::CertificateData {
170 cert_type: *cert_type,
171 cert_bytes: std::fs::read(cert_path)?,
172 })
173 } else {
174 None
175 };
176
177 let client = MdbClient {
178 url,
179 api_key: res.key.clone(),
180 client,
181
182 #[cfg(target_os = "macos")]
183 cert,
184 };
185
186 let server_info = client.server_info()?;
187 if server_info.mdb_version > *crate::MDB_VERSION_SEMVER {
188 warn!(
189 "Server version {:?} is newer than client {:?}, consider updating.",
190 server_info.mdb_version,
191 crate::MDB_VERSION_SEMVER
192 );
193 }
194
195 if save {
196 if let Err(e) = client.save() {
197 error!("Login successful but failed to save config: {e}");
198 bail!("Login successful but failed to save config: {e}");
199 }
200 }
201 Ok(client)
202 }
203
204 pub fn reset_key(&self) -> Result<()> {
210 let response = self
211 .client
212 .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
213 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
214 .send()
215 .context(MDB_CLIENT_ERROR_CONTEXT)?;
216 if !response.status().is_success() {
217 bail!("failed to reset API key, was it correct?");
218 }
219 Ok(())
220 }
221
222 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
229 let name = path.as_ref().display();
230 let config =
231 std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
232 let cfg: MdbClient =
233 toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
234 Ok(cfg)
235 }
236
237 pub fn load() -> Result<Self> {
250 #[cfg(target_os = "macos")]
251 {
252 if let Ok((api_key, url, cert)) = crate::macos::retrieve_credentials() {
253 let builder = reqwest::blocking::ClientBuilder::new()
254 .gzip(true)
255 .zstd(true)
256 .use_rustls_tls()
257 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
258
259 let client = if let Some(cert) = &cert {
260 builder.add_root_certificate(cert.as_cert()?).build()
261 } else {
262 builder.build()
263 }?;
264
265 return Ok(Self {
266 url,
267 api_key,
268 client,
269 cert,
270 });
271 }
272 }
273
274 let path = get_config_path(false)?;
275 if path.exists() {
276 return Self::from_file(path);
277 }
278 bail!("config file not found")
279 }
280
281 pub fn save(&self) -> Result<()> {
289 #[cfg(target_os = "macos")]
290 {
291 if crate::macos::save_credentials(&self.url, &self.api_key, self.cert.clone()).is_ok() {
292 return Ok(());
293 }
294 }
295
296 let toml = toml::to_string(self)?;
297 let path = get_config_path(true)?;
298 std::fs::write(&path, toml)
299 .context(format!("failed to write mdb config to {}", path.display()))
300 }
301
302 pub fn delete(&self) -> Result<()> {
309 #[cfg(target_os = "macos")]
310 crate::macos::clear_credentials();
311
312 let path = get_config_path(false)?;
313 if path.exists() {
314 std::fs::remove_file(&path).context(format!(
315 "failed to delete client config file {}",
316 path.display()
317 ))?;
318 }
319 Ok(())
320 }
321
322 pub fn server_info(&self) -> Result<ServerInfo> {
330 let response = self
331 .client
332 .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
333 .send()?
334 .json::<ServerResponse<ServerInfo>>()
335 .context(MDB_CLIENT_ERROR_CONTEXT)?;
336
337 match response {
338 ServerResponse::Success(info) => Ok(info),
339 ServerResponse::Error(e) => Err(e.into()),
340 }
341 }
342
343 pub fn supported_types(&self) -> Result<SupportedFileTypes> {
349 let response = self
350 .client
351 .get(format!(
352 "{}{}",
353 self.url,
354 malwaredb_api::SUPPORTED_FILE_TYPES_URL
355 ))
356 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
357 .send()?
358 .json::<ServerResponse<SupportedFileTypes>>()
359 .context(MDB_CLIENT_ERROR_CONTEXT)?;
360
361 match response {
362 ServerResponse::Success(types) => Ok(types),
363 ServerResponse::Error(e) => Err(e.into()),
364 }
365 }
366
367 pub fn whoami(&self) -> Result<GetUserInfoResponse> {
374 let response = self
375 .client
376 .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
377 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
378 .send()?
379 .json::<ServerResponse<GetUserInfoResponse>>()
380 .context(MDB_CLIENT_ERROR_CONTEXT)?;
381
382 match response {
383 ServerResponse::Success(info) => Ok(info),
384 ServerResponse::Error(e) => Err(e.into()),
385 }
386 }
387
388 pub fn labels(&self) -> Result<Labels> {
395 let response = self
396 .client
397 .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
398 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
399 .send()?
400 .json::<ServerResponse<Labels>>()
401 .context(MDB_CLIENT_ERROR_CONTEXT)?;
402
403 match response {
404 ServerResponse::Success(labels) => Ok(labels),
405 ServerResponse::Error(e) => Err(e.into()),
406 }
407 }
408
409 pub fn sources(&self) -> Result<Sources> {
416 let response = self
417 .client
418 .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
419 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
420 .send()?
421 .json::<ServerResponse<Sources>>()
422 .context(MDB_CLIENT_ERROR_CONTEXT)?;
423
424 match response {
425 ServerResponse::Success(sources) => Ok(sources),
426 ServerResponse::Error(e) => Err(e.into()),
427 }
428 }
429
430 pub fn submit(
437 &self,
438 contents: impl AsRef<[u8]>,
439 file_name: String,
440 source_id: u32,
441 ) -> Result<bool> {
442 let mut hasher = Sha256::new();
443 hasher.update(&contents);
444 let result = hasher.finalize();
445
446 let encoded = general_purpose::STANDARD.encode(contents);
447
448 let payload = malwaredb_api::NewSampleB64 {
449 file_name,
450 source_id,
451 file_contents_b64: encoded,
452 sha256: hex::encode(result),
453 };
454
455 match self
456 .client
457 .post(format!(
458 "{}{}",
459 self.url,
460 malwaredb_api::UPLOAD_SAMPLE_JSON_URL
461 ))
462 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
463 .json(&payload)
464 .send()
465 {
466 Ok(res) => {
467 if !res.status().is_success() {
468 info!("Code {} sending {}", res.status(), payload.file_name);
469 }
470 Ok(res.status().is_success())
471 }
472 Err(e) => {
473 let status: String = e
474 .status()
475 .map(|s| s.as_str().to_string())
476 .unwrap_or_default();
477 error!("Error{status} sending {}: {e}", payload.file_name);
478 bail!(e.to_string())
479 }
480 }
481 }
482
483 pub fn submit_as_cbor(
491 &self,
492 contents: impl AsRef<[u8]>,
493 file_name: String,
494 source_id: u32,
495 ) -> Result<bool> {
496 let mut hasher = Sha256::new();
497 hasher.update(&contents);
498 let result = hasher.finalize();
499
500 let payload = malwaredb_api::NewSampleBytes {
501 file_name,
502 source_id,
503 file_contents: contents.as_ref().to_vec(),
504 sha256: hex::encode(result),
505 };
506
507 let mut bytes = Vec::with_capacity(payload.file_contents.len());
508 ciborium::ser::into_writer(&payload, &mut bytes)?;
509
510 match self
511 .client
512 .post(format!(
513 "{}{}",
514 self.url,
515 malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
516 ))
517 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
518 .header("content-type", "application/cbor")
519 .body(bytes)
520 .send()
521 {
522 Ok(res) => {
523 if !res.status().is_success() {
524 info!("Code {} sending {}", res.status(), payload.file_name);
525 }
526 Ok(res.status().is_success())
527 }
528 Err(e) => {
529 let status: String = e
530 .status()
531 .map(|s| s.as_str().to_string())
532 .unwrap_or_default();
533 error!("Error{status} sending {}: {e}", payload.file_name);
534 bail!(e.to_string())
535 }
536 }
537 }
538
539 pub fn partial_search(
545 &self,
546 partial_hash: Option<(PartialHashSearchType, String)>,
547 name: Option<String>,
548 response: PartialHashSearchType,
549 limit: u32,
550 ) -> Result<SearchResponse> {
551 let query = SearchRequest {
552 search: SearchType::Search(SearchRequestParameters {
553 partial_hash,
554 file_name: name,
555 response,
556 limit,
557 labels: None,
558 file_type: None,
559 magic: None,
560 }),
561 };
562
563 self.do_search_request(&query)
564 }
565
566 #[allow(clippy::too_many_arguments)]
572 pub fn partial_search_labels_type(
573 &self,
574 partial_hash: Option<(PartialHashSearchType, String)>,
575 name: Option<String>,
576 response: PartialHashSearchType,
577 labels: Option<Vec<String>>,
578 file_type: Option<String>,
579 magic: Option<String>,
580 limit: u32,
581 ) -> Result<SearchResponse> {
582 let query = SearchRequest {
583 search: SearchType::Search(SearchRequestParameters {
584 partial_hash,
585 file_name: name,
586 response,
587 limit,
588 file_type,
589 magic,
590 labels,
591 }),
592 };
593
594 self.do_search_request(&query)
595 }
596
597 pub fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
603 if let Some(uuid) = response.pagination {
604 let request = SearchRequest {
605 search: SearchType::Continuation(uuid),
606 };
607 return self.do_search_request(&request);
608 }
609
610 bail!("Pagination not available")
611 }
612
613 fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
614 ensure!(
615 query.is_valid(),
616 "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
617 );
618
619 let response = self
620 .client
621 .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
622 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
623 .json(query)
624 .send()?
625 .json::<ServerResponse<SearchResponse>>()
626 .context(MDB_CLIENT_ERROR_CONTEXT)?;
627
628 match response {
629 ServerResponse::Success(search) => Ok(search),
630 ServerResponse::Error(e) => Err(e.into()),
631 }
632 }
633
634 pub fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
641 let api_endpoint = if cart {
642 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
643 } else {
644 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
645 };
646
647 let res = self
648 .client
649 .get(format!("{}{api_endpoint}", self.url))
650 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
651 .send()?;
652
653 if !res.status().is_success() {
654 bail!("Received code {}", res.status());
655 }
656
657 let content_digest = res.headers().get("content-digest").map(ToOwned::to_owned);
658 let body = res.bytes()?;
659 let bytes = body.to_vec();
660
661 if let Some(digest) = content_digest {
663 let hash = HashType::from_content_digest_header(digest.to_str()?)?;
664 if hash.verify(&bytes) {
665 trace!("Hash verified for sample {hash}");
666 } else {
667 error!("Hash mismatch for sample {hash}");
668 }
669 } else {
670 warn!("No content digest header received for sample {hash}");
671 }
672
673 Ok(bytes)
674 }
675
676 pub fn report(&self, hash: &str) -> Result<Report> {
683 let response = self
684 .client
685 .get(format!(
686 "{}{}/{hash}",
687 self.url,
688 malwaredb_api::SAMPLE_REPORT_URL
689 ))
690 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
691 .send()?
692 .json::<ServerResponse<Report>>()
693 .context(MDB_CLIENT_ERROR_CONTEXT)?;
694
695 match response {
696 ServerResponse::Success(report) => Ok(report),
697 ServerResponse::Error(e) => Err(e.into()),
698 }
699 }
700
701 pub fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
709 let mut hashes = vec![];
710 let ssdeep_hash = FuzzyHash::new(contents);
711
712 let build_hasher = Murmur3HashState::default();
713 let lzjd_str =
714 LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
715 hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
716 hashes.push((
717 malwaredb_api::SimilarityHashType::SSDeep,
718 ssdeep_hash.to_string(),
719 ));
720
721 let mut builder = TlshBuilder::new(
722 tlsh_fixed::BucketKind::Bucket256,
723 tlsh_fixed::ChecksumKind::ThreeByte,
724 tlsh_fixed::Version::Version4,
725 );
726
727 builder.update(contents);
728 if let Ok(hasher) = builder.build() {
729 hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
730 }
731
732 if let Ok(exe) = EXE::from(contents) {
733 if let Some(imports) = exe.imports {
734 hashes.push((
735 malwaredb_api::SimilarityHashType::ImportHash,
736 hex::encode(imports.hash()),
737 ));
738 hashes.push((
739 malwaredb_api::SimilarityHashType::FuzzyImportHash,
740 imports.fuzzy_hash(),
741 ));
742 }
743 }
744
745 let request = malwaredb_api::SimilarSamplesRequest { hashes };
746
747 let response = self
748 .client
749 .post(format!(
750 "{}{}",
751 self.url,
752 malwaredb_api::SIMILAR_SAMPLES_URL
753 ))
754 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
755 .json(&request)
756 .send()?
757 .json::<ServerResponse<SimilarSamplesResponse>>()
758 .context(MDB_CLIENT_ERROR_CONTEXT)?;
759
760 match response {
761 ServerResponse::Success(similar) => Ok(similar),
762 ServerResponse::Error(e) => Err(e.into()),
763 }
764 }
765
766 pub fn yara_search(&self, yara: &str) -> Result<YaraSearchRequestResponse> {
772 let yara = YaraSearchRequest {
773 rules: vec![yara.to_string()],
774 response: PartialHashSearchType::SHA256,
775 };
776
777 let response = self
778 .client
779 .post(format!("{}{}", self.url, malwaredb_api::YARA_SEARCH_URL))
780 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
781 .json(&yara)
782 .send()?
783 .json::<ServerResponse<YaraSearchRequestResponse>>()?;
784
785 match response {
786 ServerResponse::Success(similar) => Ok(similar),
787 ServerResponse::Error(e) => Err(e.into()),
788 }
789 }
790
791 pub fn yara_result(&self, uuid: Uuid) -> Result<YaraSearchResponse> {
797 let response = self
798 .client
799 .get(format!(
800 "{}{}/{uuid}",
801 self.url,
802 malwaredb_api::YARA_SEARCH_URL
803 ))
804 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
805 .send()?
806 .json::<ServerResponse<YaraSearchResponse>>()?;
807
808 match response {
809 ServerResponse::Success(sources) => Ok(sources),
810 ServerResponse::Error(e) => Err(e.into()),
811 }
812 }
813}
814
815impl Debug for MdbClient {
816 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
817 use crate::MDB_VERSION;
818
819 writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
820 }
821}
822
823#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
838pub struct IterableHashSearchResult<'a> {
839 pub response: SearchResponse,
841
842 client: &'a MdbClient,
844}
845
846impl<'a> IterableHashSearchResult<'a> {
847 #[must_use]
849 pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
850 Self { response, client }
851 }
852}
853
854impl Iterator for IterableHashSearchResult<'_> {
855 type Item = String;
856
857 fn next(&mut self) -> Option<Self::Item> {
858 if let Some(hash) = self.response.hashes.pop() {
859 Some(hash)
860 } else if let Some(uuid) = self.response.pagination {
861 let request = SearchRequest {
862 search: SearchType::Continuation(uuid),
863 };
864
865 self.response = match self.client.do_search_request(&request) {
866 Ok(response) => response,
867 Err(e) => {
868 warn!("Failed to continue search: {e}");
869 return None;
870 }
871 };
872
873 self.response.hashes.pop()
874 } else {
875 None
876 }
877 }
878}
879
880#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
882pub struct IterableSampleSearchResult<'a> {
883 pub response: SearchResponse,
885
886 client: &'a MdbClient,
888}
889
890impl<'a> IterableSampleSearchResult<'a> {
891 #[must_use]
893 pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
894 Self { response, client }
895 }
896}
897
898impl Iterator for IterableSampleSearchResult<'_> {
899 type Item = Vec<u8>;
900
901 fn next(&mut self) -> Option<Self::Item> {
902 if let Some(hash) = self.response.hashes.pop() {
903 let binary = match self.client.retrieve(&hash, false) {
904 Ok(binary) => binary,
905 Err(e) => {
906 error!("Failed to download {hash}: {e}");
907 return None;
908 }
909 };
910 Some(binary)
911 } else if let Some(uuid) = self.response.pagination {
912 let request = SearchRequest {
913 search: SearchType::Continuation(uuid),
914 };
915
916 self.response = match self.client.do_search_request(&request) {
917 Ok(response) => response,
918 Err(e) => {
919 warn!("Failed to continue search: {e}");
920 return None;
921 }
922 };
923
924 if let Some(hash) = self.response.hashes.pop() {
925 let binary = match self.client.retrieve(&hash, false) {
926 Ok(binary) => binary,
927 Err(e) => {
928 error!("Failed to download {hash}: {e}");
929 return None;
930 }
931 };
932 Some(binary)
933 } else {
934 None
935 }
936 } else {
937 None
938 }
939 }
940}