1use std::fmt::{Debug, Formatter};
4use std::path::{Path, PathBuf};
5
6use crate::{get_config_path, MDB_CLIENT_ERROR_CONTEXT};
7use malwaredb_api::{
8 digest::HashType, GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType,
9 Report, SearchRequest, SearchRequestParameters, SearchResponse, SearchType, ServerInfo,
10 ServerResponse, SimilarSamplesResponse, Sources, SupportedFileTypes,
11};
12use malwaredb_types::exec::pe32::EXE;
13
14use anyhow::{bail, ensure, Context, Result};
15use base64::engine::general_purpose;
16use base64::Engine;
17use fuzzyhash::FuzzyHash;
18use malwaredb_lzjd::{LZDict, Murmur3HashState};
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21use tlsh_fixed::TlshBuilder;
22use tracing::{error, info, trace, warn};
23use zeroize::{Zeroize, ZeroizeOnDrop};
24
25#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
27#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
28pub struct MdbClient {
29 pub url: String,
31
32 api_key: String,
34
35 #[zeroize(skip)]
37 #[serde(skip)]
38 client: reqwest::blocking::Client,
39
40 #[cfg(target_os = "macos")]
42 #[zeroize(skip)]
43 #[serde(skip)]
44 cert: Option<crate::macos::CertificateData>,
45}
46
47impl MdbClient {
48 pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
59 let mut url = url;
60 let url = if url.ends_with('/') {
61 url.pop();
62 url
63 } else {
64 url
65 };
66
67 let cert = if let Some(path) = cert_path {
68 Some((crate::path_load_cert(&path)?, path))
69 } else {
70 None
71 };
72
73 let builder = reqwest::blocking::ClientBuilder::new()
74 .gzip(true)
75 .zstd(true)
76 .use_rustls_tls()
77 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
78
79 let client = if let Some(((_cert_type, cert), _path)) = &cert {
80 builder.add_root_certificate(cert.clone()).build()
81 } else {
82 builder.build()
83 }?;
84
85 #[cfg(target_os = "macos")]
86 let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
87 Some(crate::macos::CertificateData {
88 cert_type: *cert_type,
89 cert_bytes: std::fs::read(cert_path)?,
90 })
91 } else {
92 None
93 };
94
95 Ok(Self {
96 url,
97 api_key,
98 client,
99
100 #[cfg(target_os = "macos")]
101 cert,
102 })
103 }
104
105 pub fn login(
116 url: String,
117 username: String,
118 password: String,
119 save: bool,
120 cert_path: Option<PathBuf>,
121 ) -> Result<Self> {
122 let mut url = url;
123 let url = if url.ends_with('/') {
124 url.pop();
125 url
126 } else {
127 url
128 };
129
130 let api_request = malwaredb_api::GetAPIKeyRequest {
131 user: username,
132 password,
133 };
134
135 let builder = reqwest::blocking::ClientBuilder::new()
136 .gzip(true)
137 .zstd(true)
138 .use_rustls_tls()
139 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
140
141 let cert = if let Some(path) = cert_path {
142 Some((crate::path_load_cert(&path)?, path))
143 } else {
144 None
145 };
146
147 let client = if let Some(((_cert_type, cert), _path)) = &cert {
148 builder.add_root_certificate(cert.clone()).build()
149 } else {
150 builder.build()
151 }?;
152
153 let res = client
154 .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
155 .json(&api_request)
156 .send()?
157 .json::<ServerResponse<GetAPIKeyResponse>>()
158 .context(MDB_CLIENT_ERROR_CONTEXT)?;
159
160 let res = match res {
161 ServerResponse::Success(res) => res,
162 ServerResponse::Error(err) => return Err(err.into()),
163 };
164
165 #[cfg(target_os = "macos")]
166 let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
167 Some(crate::macos::CertificateData {
168 cert_type: *cert_type,
169 cert_bytes: std::fs::read(cert_path)?,
170 })
171 } else {
172 None
173 };
174
175 let client = MdbClient {
176 url,
177 api_key: res.key.clone(),
178 client,
179
180 #[cfg(target_os = "macos")]
181 cert,
182 };
183
184 let server_info = client.server_info()?;
185 if server_info.mdb_version > *crate::MDB_VERSION_SEMVER {
186 warn!(
187 "Server version {:?} is newer than client {:?}, consider updating.",
188 server_info.mdb_version,
189 crate::MDB_VERSION_SEMVER
190 );
191 }
192
193 if save {
194 if let Err(e) = client.save() {
195 error!("Login successful but failed to save config: {e}");
196 bail!("Login successful but failed to save config: {e}");
197 }
198 }
199 Ok(client)
200 }
201
202 pub fn reset_key(&self) -> Result<()> {
208 let response = self
209 .client
210 .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
211 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
212 .send()
213 .context(MDB_CLIENT_ERROR_CONTEXT)?;
214 if !response.status().is_success() {
215 bail!("failed to reset API key, was it correct?");
216 }
217 Ok(())
218 }
219
220 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
227 let name = path.as_ref().display();
228 let config =
229 std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
230 let cfg: MdbClient =
231 toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
232 Ok(cfg)
233 }
234
235 pub fn load() -> Result<Self> {
248 #[cfg(target_os = "macos")]
249 {
250 if let Ok((api_key, url, cert)) = crate::macos::retrieve_credentials() {
251 let builder = reqwest::blocking::ClientBuilder::new()
252 .gzip(true)
253 .zstd(true)
254 .use_rustls_tls()
255 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
256
257 let client = if let Some(cert) = &cert {
258 builder.add_root_certificate(cert.as_cert()?).build()
259 } else {
260 builder.build()
261 }?;
262
263 return Ok(Self {
264 url,
265 api_key,
266 client,
267 cert,
268 });
269 }
270 }
271
272 let path = get_config_path(false)?;
273 if path.exists() {
274 return Self::from_file(path);
275 }
276 bail!("config file not found")
277 }
278
279 pub fn save(&self) -> Result<()> {
287 #[cfg(target_os = "macos")]
288 {
289 if crate::macos::save_credentials(&self.url, &self.api_key, self.cert.clone()).is_ok() {
290 return Ok(());
291 }
292 }
293
294 let toml = toml::to_string(self)?;
295 let path = get_config_path(true)?;
296 std::fs::write(&path, toml)
297 .context(format!("failed to write mdb config to {}", path.display()))
298 }
299
300 pub fn delete(&self) -> Result<()> {
307 #[cfg(target_os = "macos")]
308 crate::macos::clear_credentials();
309
310 let path = get_config_path(false)?;
311 if path.exists() {
312 std::fs::remove_file(&path).context(format!(
313 "failed to delete client config file {}",
314 path.display()
315 ))?;
316 }
317 Ok(())
318 }
319
320 pub fn server_info(&self) -> Result<ServerInfo> {
328 let response = self
329 .client
330 .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
331 .send()?
332 .json::<ServerResponse<ServerInfo>>()
333 .context(MDB_CLIENT_ERROR_CONTEXT)?;
334
335 match response {
336 ServerResponse::Success(info) => Ok(info),
337 ServerResponse::Error(e) => Err(e.into()),
338 }
339 }
340
341 pub fn supported_types(&self) -> Result<SupportedFileTypes> {
347 let response = self
348 .client
349 .get(format!(
350 "{}{}",
351 self.url,
352 malwaredb_api::SUPPORTED_FILE_TYPES_URL
353 ))
354 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
355 .send()?
356 .json::<ServerResponse<SupportedFileTypes>>()
357 .context(MDB_CLIENT_ERROR_CONTEXT)?;
358
359 match response {
360 ServerResponse::Success(types) => Ok(types),
361 ServerResponse::Error(e) => Err(e.into()),
362 }
363 }
364
365 pub fn whoami(&self) -> Result<GetUserInfoResponse> {
372 let response = self
373 .client
374 .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
375 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
376 .send()?
377 .json::<ServerResponse<GetUserInfoResponse>>()
378 .context(MDB_CLIENT_ERROR_CONTEXT)?;
379
380 match response {
381 ServerResponse::Success(info) => Ok(info),
382 ServerResponse::Error(e) => Err(e.into()),
383 }
384 }
385
386 pub fn labels(&self) -> Result<Labels> {
393 let response = self
394 .client
395 .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
396 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
397 .send()?
398 .json::<ServerResponse<Labels>>()
399 .context(MDB_CLIENT_ERROR_CONTEXT)?;
400
401 match response {
402 ServerResponse::Success(labels) => Ok(labels),
403 ServerResponse::Error(e) => Err(e.into()),
404 }
405 }
406
407 pub fn sources(&self) -> Result<Sources> {
414 let response = self
415 .client
416 .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
417 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
418 .send()?
419 .json::<ServerResponse<Sources>>()
420 .context(MDB_CLIENT_ERROR_CONTEXT)?;
421
422 match response {
423 ServerResponse::Success(sources) => Ok(sources),
424 ServerResponse::Error(e) => Err(e.into()),
425 }
426 }
427
428 pub fn submit(
435 &self,
436 contents: impl AsRef<[u8]>,
437 file_name: String,
438 source_id: u32,
439 ) -> Result<bool> {
440 let mut hasher = Sha256::new();
441 hasher.update(&contents);
442 let result = hasher.finalize();
443
444 let encoded = general_purpose::STANDARD.encode(contents);
445
446 let payload = malwaredb_api::NewSampleB64 {
447 file_name,
448 source_id,
449 file_contents_b64: encoded,
450 sha256: hex::encode(result),
451 };
452
453 match self
454 .client
455 .post(format!(
456 "{}{}",
457 self.url,
458 malwaredb_api::UPLOAD_SAMPLE_JSON_URL
459 ))
460 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
461 .json(&payload)
462 .send()
463 {
464 Ok(res) => {
465 if !res.status().is_success() {
466 info!("Code {} sending {}", res.status(), payload.file_name);
467 }
468 Ok(res.status().is_success())
469 }
470 Err(e) => {
471 let status: String = e
472 .status()
473 .map(|s| s.as_str().to_string())
474 .unwrap_or_default();
475 error!("Error{status} sending {}: {e}", payload.file_name);
476 bail!(e.to_string())
477 }
478 }
479 }
480
481 pub fn submit_as_cbor(
489 &self,
490 contents: impl AsRef<[u8]>,
491 file_name: String,
492 source_id: u32,
493 ) -> Result<bool> {
494 let mut hasher = Sha256::new();
495 hasher.update(&contents);
496 let result = hasher.finalize();
497
498 let payload = malwaredb_api::NewSampleBytes {
499 file_name,
500 source_id,
501 file_contents: contents.as_ref().to_vec(),
502 sha256: hex::encode(result),
503 };
504
505 let mut bytes = Vec::with_capacity(payload.file_contents.len());
506 ciborium::ser::into_writer(&payload, &mut bytes)?;
507
508 match self
509 .client
510 .post(format!(
511 "{}{}",
512 self.url,
513 malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
514 ))
515 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
516 .header("content-type", "application/cbor")
517 .body(bytes)
518 .send()
519 {
520 Ok(res) => {
521 if !res.status().is_success() {
522 info!("Code {} sending {}", res.status(), payload.file_name);
523 }
524 Ok(res.status().is_success())
525 }
526 Err(e) => {
527 let status: String = e
528 .status()
529 .map(|s| s.as_str().to_string())
530 .unwrap_or_default();
531 error!("Error{status} sending {}: {e}", payload.file_name);
532 bail!(e.to_string())
533 }
534 }
535 }
536
537 pub fn partial_search(
543 &self,
544 partial_hash: Option<(PartialHashSearchType, String)>,
545 name: Option<String>,
546 response: PartialHashSearchType,
547 limit: u32,
548 ) -> Result<SearchResponse> {
549 let query = SearchRequest {
550 search: SearchType::Search(SearchRequestParameters {
551 partial_hash,
552 file_name: name,
553 response,
554 limit,
555 labels: None,
556 file_type: None,
557 magic: None,
558 }),
559 };
560
561 self.do_search_request(&query)
562 }
563
564 #[allow(clippy::too_many_arguments)]
570 pub fn partial_search_labels_type(
571 &self,
572 partial_hash: Option<(PartialHashSearchType, String)>,
573 name: Option<String>,
574 response: PartialHashSearchType,
575 labels: Option<Vec<String>>,
576 file_type: Option<String>,
577 magic: Option<String>,
578 limit: u32,
579 ) -> Result<SearchResponse> {
580 let query = SearchRequest {
581 search: SearchType::Search(SearchRequestParameters {
582 partial_hash,
583 file_name: name,
584 response,
585 limit,
586 file_type,
587 magic,
588 labels,
589 }),
590 };
591
592 self.do_search_request(&query)
593 }
594
595 pub fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
601 if let Some(uuid) = response.pagination {
602 let request = SearchRequest {
603 search: SearchType::Continuation(uuid),
604 };
605 return self.do_search_request(&request);
606 }
607
608 bail!("Pagination not available")
609 }
610
611 fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
612 ensure!(
613 query.is_valid(),
614 "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
615 );
616
617 let response = self
618 .client
619 .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
620 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
621 .json(query)
622 .send()?
623 .json::<ServerResponse<SearchResponse>>()
624 .context(MDB_CLIENT_ERROR_CONTEXT)?;
625
626 match response {
627 ServerResponse::Success(search) => Ok(search),
628 ServerResponse::Error(e) => Err(e.into()),
629 }
630 }
631
632 pub fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
639 let api_endpoint = if cart {
640 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
641 } else {
642 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
643 };
644
645 let res = self
646 .client
647 .get(format!("{}{api_endpoint}", self.url))
648 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
649 .send()?;
650
651 if !res.status().is_success() {
652 bail!("Received code {}", res.status());
653 }
654
655 let content_digest = res.headers().get("content-digest").map(ToOwned::to_owned);
656 let body = res.bytes()?;
657 let bytes = body.to_vec();
658
659 if let Some(digest) = content_digest {
661 let hash = HashType::from_content_digest_header(digest.to_str()?)?;
662 if hash.verify(&bytes) {
663 trace!("Hash verified for sample {hash}");
664 } else {
665 error!("Hash mismatch for sample {hash}");
666 }
667 } else {
668 warn!("No content digest header received for sample {hash}");
669 }
670
671 Ok(bytes)
672 }
673
674 pub fn report(&self, hash: &str) -> Result<Report> {
681 let response = self
682 .client
683 .get(format!(
684 "{}{}/{hash}",
685 self.url,
686 malwaredb_api::SAMPLE_REPORT_URL
687 ))
688 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
689 .send()?
690 .json::<ServerResponse<Report>>()
691 .context(MDB_CLIENT_ERROR_CONTEXT)?;
692
693 match response {
694 ServerResponse::Success(report) => Ok(report),
695 ServerResponse::Error(e) => Err(e.into()),
696 }
697 }
698
699 pub fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
707 let mut hashes = vec![];
708 let ssdeep_hash = FuzzyHash::new(contents);
709
710 let build_hasher = Murmur3HashState::default();
711 let lzjd_str =
712 LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
713 hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
714 hashes.push((
715 malwaredb_api::SimilarityHashType::SSDeep,
716 ssdeep_hash.to_string(),
717 ));
718
719 let mut builder = TlshBuilder::new(
720 tlsh_fixed::BucketKind::Bucket256,
721 tlsh_fixed::ChecksumKind::ThreeByte,
722 tlsh_fixed::Version::Version4,
723 );
724
725 builder.update(contents);
726 if let Ok(hasher) = builder.build() {
727 hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
728 }
729
730 if let Ok(exe) = EXE::from(contents) {
731 if let Some(imports) = exe.imports {
732 hashes.push((
733 malwaredb_api::SimilarityHashType::ImportHash,
734 hex::encode(imports.hash()),
735 ));
736 hashes.push((
737 malwaredb_api::SimilarityHashType::FuzzyImportHash,
738 imports.fuzzy_hash(),
739 ));
740 }
741 }
742
743 let request = malwaredb_api::SimilarSamplesRequest { hashes };
744
745 let response = self
746 .client
747 .post(format!(
748 "{}{}",
749 self.url,
750 malwaredb_api::SIMILAR_SAMPLES_URL
751 ))
752 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
753 .json(&request)
754 .send()?
755 .json::<ServerResponse<SimilarSamplesResponse>>()
756 .context(MDB_CLIENT_ERROR_CONTEXT)?;
757
758 match response {
759 ServerResponse::Success(similar) => Ok(similar),
760 ServerResponse::Error(e) => Err(e.into()),
761 }
762 }
763}
764
765impl Debug for MdbClient {
766 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
767 use crate::MDB_VERSION;
768
769 writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
770 }
771}
772
773#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
788pub struct IterableHashSearchResult<'a> {
789 pub response: SearchResponse,
791
792 client: &'a MdbClient,
794}
795
796impl<'a> IterableHashSearchResult<'a> {
797 #[must_use]
799 pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
800 Self { response, client }
801 }
802}
803
804impl Iterator for IterableHashSearchResult<'_> {
805 type Item = String;
806
807 fn next(&mut self) -> Option<Self::Item> {
808 if let Some(hash) = self.response.hashes.pop() {
809 Some(hash)
810 } else if let Some(uuid) = self.response.pagination {
811 let request = SearchRequest {
812 search: SearchType::Continuation(uuid),
813 };
814
815 self.response = match self.client.do_search_request(&request) {
816 Ok(response) => response,
817 Err(e) => {
818 warn!("Failed to continue search: {e}");
819 return None;
820 }
821 };
822
823 self.response.hashes.pop()
824 } else {
825 None
826 }
827 }
828}
829
830#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
832pub struct IterableSampleSearchResult<'a> {
833 pub response: SearchResponse,
835
836 client: &'a MdbClient,
838}
839
840impl<'a> IterableSampleSearchResult<'a> {
841 #[must_use]
843 pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
844 Self { response, client }
845 }
846}
847
848impl Iterator for IterableSampleSearchResult<'_> {
849 type Item = Vec<u8>;
850
851 fn next(&mut self) -> Option<Self::Item> {
852 if let Some(hash) = self.response.hashes.pop() {
853 let binary = match self.client.retrieve(&hash, false) {
854 Ok(binary) => binary,
855 Err(e) => {
856 error!("Failed to download {hash}: {e}");
857 return None;
858 }
859 };
860 Some(binary)
861 } else if let Some(uuid) = self.response.pagination {
862 let request = SearchRequest {
863 search: SearchType::Continuation(uuid),
864 };
865
866 self.response = match self.client.do_search_request(&request) {
867 Ok(response) => response,
868 Err(e) => {
869 warn!("Failed to continue search: {e}");
870 return None;
871 }
872 };
873
874 if let Some(hash) = self.response.hashes.pop() {
875 let binary = match self.client.retrieve(&hash, false) {
876 Ok(binary) => binary,
877 Err(e) => {
878 error!("Failed to download {hash}: {e}");
879 return None;
880 }
881 };
882 Some(binary)
883 } else {
884 None
885 }
886 } else {
887 None
888 }
889 }
890}