1use std::fmt::{Debug, Formatter};
4use std::path::{Path, PathBuf};
5
6use crate::{get_config_path, MDB_CLIENT_ERROR_CONTEXT};
7use malwaredb_types::exec::pe32::EXE;
8
9use anyhow::{bail, ensure, Context, Result};
10use base64::engine::general_purpose;
11use base64::Engine;
12use fuzzyhash::FuzzyHash;
13use malwaredb_api::{
14 GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType, Report, SearchRequest,
15 SearchRequestParameters, SearchResponse, SearchType, ServerInfo, ServerResponse,
16 SimilarSamplesResponse, Sources, SupportedFileTypes,
17};
18use malwaredb_lzjd::{LZDict, Murmur3HashState};
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21use tlsh_fixed::TlshBuilder;
22use tracing::{error, info, warn};
23use zeroize::{Zeroize, ZeroizeOnDrop};
24
25#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
27#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
28pub struct MdbClient {
29 pub url: String,
31
32 api_key: String,
34
35 #[zeroize(skip)]
37 #[serde(skip)]
38 client: reqwest::blocking::Client,
39
40 #[cfg(target_os = "macos")]
42 #[zeroize(skip)]
43 #[serde(skip)]
44 cert: Option<crate::macos::CertificateData>,
45}
46
47impl MdbClient {
48 pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
59 let mut url = url;
60 let url = if url.ends_with('/') {
61 url.pop();
62 url
63 } else {
64 url
65 };
66
67 let cert = if let Some(path) = cert_path {
68 Some((crate::path_load_cert(&path)?, path))
69 } else {
70 None
71 };
72
73 let builder = reqwest::blocking::ClientBuilder::new()
74 .gzip(true)
75 .zstd(true)
76 .use_rustls_tls()
77 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
78
79 let client = if let Some(((_cert_type, cert), _path)) = &cert {
80 builder.add_root_certificate(cert.clone()).build()
81 } else {
82 builder.build()
83 }?;
84
85 #[cfg(target_os = "macos")]
86 let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
87 Some(crate::macos::CertificateData {
88 cert_type: *cert_type,
89 cert_bytes: std::fs::read(cert_path)?,
90 })
91 } else {
92 None
93 };
94
95 Ok(Self {
96 url,
97 api_key,
98 client,
99
100 #[cfg(target_os = "macos")]
101 cert,
102 })
103 }
104
105 pub fn login(
116 url: String,
117 username: String,
118 password: String,
119 save: bool,
120 cert_path: Option<PathBuf>,
121 ) -> Result<Self> {
122 let mut url = url;
123 let url = if url.ends_with('/') {
124 url.pop();
125 url
126 } else {
127 url
128 };
129
130 let api_request = malwaredb_api::GetAPIKeyRequest {
131 user: username,
132 password,
133 };
134
135 let builder = reqwest::blocking::ClientBuilder::new()
136 .gzip(true)
137 .zstd(true)
138 .use_rustls_tls()
139 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
140
141 let cert = if let Some(path) = cert_path {
142 Some((crate::path_load_cert(&path)?, path))
143 } else {
144 None
145 };
146
147 let client = if let Some(((_cert_type, cert), _path)) = &cert {
148 builder.add_root_certificate(cert.clone()).build()
149 } else {
150 builder.build()
151 }?;
152
153 let res = client
154 .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
155 .json(&api_request)
156 .send()?
157 .json::<ServerResponse<GetAPIKeyResponse>>()
158 .context(MDB_CLIENT_ERROR_CONTEXT)?;
159
160 let res = match res {
161 ServerResponse::Success(res) => res,
162 ServerResponse::Error(err) => return Err(err.into()),
163 };
164
165 #[cfg(target_os = "macos")]
166 let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
167 Some(crate::macos::CertificateData {
168 cert_type: *cert_type,
169 cert_bytes: std::fs::read(cert_path)?,
170 })
171 } else {
172 None
173 };
174
175 let client = MdbClient {
176 url,
177 api_key: res.key.clone(),
178 client,
179
180 #[cfg(target_os = "macos")]
181 cert,
182 };
183
184 let server_info = client.server_info()?;
185 if server_info.mdb_version > *crate::MDB_VERSION_SEMVER {
186 warn!(
187 "Server version {:?} is newer than client {:?}, consider updating.",
188 server_info.mdb_version,
189 crate::MDB_VERSION_SEMVER
190 );
191 }
192
193 if save {
194 if let Err(e) = client.save() {
195 error!("Login successful but failed to save config: {e}");
196 bail!("Login successful but failed to save config: {e}");
197 }
198 }
199 Ok(client)
200 }
201
202 pub fn reset_key(&self) -> Result<()> {
208 let response = self
209 .client
210 .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
211 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
212 .send()
213 .context(MDB_CLIENT_ERROR_CONTEXT)?;
214 if !response.status().is_success() {
215 bail!("failed to reset API key, was it correct?");
216 }
217 Ok(())
218 }
219
220 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
227 let name = path.as_ref().display();
228 let config =
229 std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
230 let cfg: MdbClient =
231 toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
232 Ok(cfg)
233 }
234
235 pub fn load() -> Result<Self> {
248 #[cfg(target_os = "macos")]
249 {
250 if let Ok((api_key, url, cert)) = crate::macos::retrieve_credentials() {
251 let builder = reqwest::blocking::ClientBuilder::new()
252 .gzip(true)
253 .zstd(true)
254 .use_rustls_tls()
255 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
256
257 let client = if let Some(cert) = &cert {
258 builder.add_root_certificate(cert.as_cert()?).build()
259 } else {
260 builder.build()
261 }?;
262
263 return Ok(Self {
264 url,
265 api_key,
266 client,
267 cert,
268 });
269 }
270 }
271
272 let path = get_config_path(false)?;
273 if path.exists() {
274 return Self::from_file(path);
275 }
276 bail!("config file not found")
277 }
278
279 pub fn save(&self) -> Result<()> {
287 let toml = toml::to_string(self)?;
288 let path = get_config_path(true)?;
289 std::fs::write(&path, toml)
290 .context(format!("failed to write mdb config to {}", path.display()))
291 }
292
293 pub fn delete(&self) -> Result<()> {
300 #[cfg(target_os = "macos")]
301 crate::macos::clear_credentials();
302
303 let path = get_config_path(false)?;
304 if path.exists() {
305 std::fs::remove_file(&path).context(format!(
306 "failed to delete client config file {}",
307 path.display()
308 ))?;
309 }
310 Ok(())
311 }
312
313 pub fn server_info(&self) -> Result<ServerInfo> {
321 let response = self
322 .client
323 .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
324 .send()?
325 .json::<ServerResponse<ServerInfo>>()
326 .context(MDB_CLIENT_ERROR_CONTEXT)?;
327
328 match response {
329 ServerResponse::Success(info) => Ok(info),
330 ServerResponse::Error(e) => Err(e.into()),
331 }
332 }
333
334 pub fn supported_types(&self) -> Result<SupportedFileTypes> {
340 let response = self
341 .client
342 .get(format!(
343 "{}{}",
344 self.url,
345 malwaredb_api::SUPPORTED_FILE_TYPES_URL
346 ))
347 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
348 .send()?
349 .json::<ServerResponse<SupportedFileTypes>>()
350 .context(MDB_CLIENT_ERROR_CONTEXT)?;
351
352 match response {
353 ServerResponse::Success(types) => Ok(types),
354 ServerResponse::Error(e) => Err(e.into()),
355 }
356 }
357
358 pub fn whoami(&self) -> Result<GetUserInfoResponse> {
365 let response = self
366 .client
367 .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
368 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
369 .send()?
370 .json::<ServerResponse<GetUserInfoResponse>>()
371 .context(MDB_CLIENT_ERROR_CONTEXT)?;
372
373 match response {
374 ServerResponse::Success(info) => Ok(info),
375 ServerResponse::Error(e) => Err(e.into()),
376 }
377 }
378
379 pub fn labels(&self) -> Result<Labels> {
386 let response = self
387 .client
388 .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
389 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
390 .send()?
391 .json::<ServerResponse<Labels>>()
392 .context(MDB_CLIENT_ERROR_CONTEXT)?;
393
394 match response {
395 ServerResponse::Success(labels) => Ok(labels),
396 ServerResponse::Error(e) => Err(e.into()),
397 }
398 }
399
400 pub fn sources(&self) -> Result<Sources> {
407 let response = self
408 .client
409 .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
410 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
411 .send()?
412 .json::<ServerResponse<Sources>>()
413 .context(MDB_CLIENT_ERROR_CONTEXT)?;
414
415 match response {
416 ServerResponse::Success(sources) => Ok(sources),
417 ServerResponse::Error(e) => Err(e.into()),
418 }
419 }
420
421 pub fn submit(
428 &self,
429 contents: impl AsRef<[u8]>,
430 file_name: String,
431 source_id: u32,
432 ) -> Result<bool> {
433 let mut hasher = Sha256::new();
434 hasher.update(&contents);
435 let result = hasher.finalize();
436
437 let encoded = general_purpose::STANDARD.encode(contents);
438
439 let payload = malwaredb_api::NewSampleB64 {
440 file_name,
441 source_id,
442 file_contents_b64: encoded,
443 sha256: hex::encode(result),
444 };
445
446 match self
447 .client
448 .post(format!(
449 "{}{}",
450 self.url,
451 malwaredb_api::UPLOAD_SAMPLE_JSON_URL
452 ))
453 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
454 .json(&payload)
455 .send()
456 {
457 Ok(res) => {
458 if !res.status().is_success() {
459 info!("Code {} sending {}", res.status(), payload.file_name);
460 }
461 Ok(res.status().is_success())
462 }
463 Err(e) => {
464 let status: String = e
465 .status()
466 .map(|s| s.as_str().to_string())
467 .unwrap_or_default();
468 error!("Error{status} sending {}: {e}", payload.file_name);
469 bail!(e.to_string())
470 }
471 }
472 }
473
474 pub fn submit_as_cbor(
482 &self,
483 contents: impl AsRef<[u8]>,
484 file_name: String,
485 source_id: u32,
486 ) -> Result<bool> {
487 let mut hasher = Sha256::new();
488 hasher.update(&contents);
489 let result = hasher.finalize();
490
491 let payload = malwaredb_api::NewSampleBytes {
492 file_name,
493 source_id,
494 file_contents: contents.as_ref().to_vec(),
495 sha256: hex::encode(result),
496 };
497
498 let mut bytes = Vec::with_capacity(payload.file_contents.len());
499 ciborium::ser::into_writer(&payload, &mut bytes)?;
500
501 match self
502 .client
503 .post(format!(
504 "{}{}",
505 self.url,
506 malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
507 ))
508 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
509 .header("content-type", "application/cbor")
510 .body(bytes)
511 .send()
512 {
513 Ok(res) => {
514 if !res.status().is_success() {
515 info!("Code {} sending {}", res.status(), payload.file_name);
516 }
517 Ok(res.status().is_success())
518 }
519 Err(e) => {
520 let status: String = e
521 .status()
522 .map(|s| s.as_str().to_string())
523 .unwrap_or_default();
524 error!("Error{status} sending {}: {e}", payload.file_name);
525 bail!(e.to_string())
526 }
527 }
528 }
529
530 pub fn partial_search(
536 &self,
537 partial_hash: Option<(PartialHashSearchType, String)>,
538 name: Option<String>,
539 response: PartialHashSearchType,
540 limit: u32,
541 ) -> Result<SearchResponse> {
542 let query = SearchRequest {
543 search: SearchType::Search(SearchRequestParameters {
544 partial_hash,
545 file_name: name,
546 response,
547 limit,
548 labels: None,
549 file_type: None,
550 magic: None,
551 }),
552 };
553
554 self.do_search_request(&query)
555 }
556
557 #[allow(clippy::too_many_arguments)]
563 pub fn partial_search_labels_type(
564 &self,
565 partial_hash: Option<(PartialHashSearchType, String)>,
566 name: Option<String>,
567 response: PartialHashSearchType,
568 labels: Option<Vec<String>>,
569 file_type: Option<String>,
570 magic: Option<String>,
571 limit: u32,
572 ) -> Result<SearchResponse> {
573 let query = SearchRequest {
574 search: SearchType::Search(SearchRequestParameters {
575 partial_hash,
576 file_name: name,
577 response,
578 limit,
579 file_type,
580 magic,
581 labels,
582 }),
583 };
584
585 self.do_search_request(&query)
586 }
587
588 pub fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
594 if let Some(uuid) = response.pagination {
595 let request = SearchRequest {
596 search: SearchType::Continuation(uuid),
597 };
598 return self.do_search_request(&request);
599 }
600
601 bail!("Pagination not available")
602 }
603
604 fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
605 ensure!(
606 query.is_valid(),
607 "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
608 );
609
610 let response = self
611 .client
612 .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
613 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
614 .json(query)
615 .send()?
616 .json::<ServerResponse<SearchResponse>>()
617 .context(MDB_CLIENT_ERROR_CONTEXT)?;
618
619 match response {
620 ServerResponse::Success(search) => Ok(search),
621 ServerResponse::Error(e) => Err(e.into()),
622 }
623 }
624
625 pub fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
632 let api_endpoint = if cart {
633 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
634 } else {
635 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
636 };
637
638 let res = self
639 .client
640 .get(format!("{}{api_endpoint}", self.url))
641 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
642 .send()?;
643
644 if !res.status().is_success() {
645 bail!("Received code {}", res.status());
646 }
647
648 let body = res.bytes()?;
649 Ok(body.to_vec())
650 }
651
652 pub fn report(&self, hash: &str) -> Result<Report> {
659 let response = self
660 .client
661 .get(format!(
662 "{}{}/{hash}",
663 self.url,
664 malwaredb_api::SAMPLE_REPORT_URL
665 ))
666 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
667 .send()?
668 .json::<ServerResponse<Report>>()
669 .context(MDB_CLIENT_ERROR_CONTEXT)?;
670
671 match response {
672 ServerResponse::Success(report) => Ok(report),
673 ServerResponse::Error(e) => Err(e.into()),
674 }
675 }
676
677 pub fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
685 let mut hashes = vec![];
686 let ssdeep_hash = FuzzyHash::new(contents);
687
688 let build_hasher = Murmur3HashState::default();
689 let lzjd_str =
690 LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
691 hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
692 hashes.push((
693 malwaredb_api::SimilarityHashType::SSDeep,
694 ssdeep_hash.to_string(),
695 ));
696
697 let mut builder = TlshBuilder::new(
698 tlsh_fixed::BucketKind::Bucket256,
699 tlsh_fixed::ChecksumKind::ThreeByte,
700 tlsh_fixed::Version::Version4,
701 );
702
703 builder.update(contents);
704 if let Ok(hasher) = builder.build() {
705 hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
706 }
707
708 if let Ok(exe) = EXE::from(contents) {
709 if let Some(imports) = exe.imports {
710 hashes.push((
711 malwaredb_api::SimilarityHashType::ImportHash,
712 hex::encode(imports.hash()),
713 ));
714 hashes.push((
715 malwaredb_api::SimilarityHashType::FuzzyImportHash,
716 imports.fuzzy_hash(),
717 ));
718 }
719 }
720
721 let request = malwaredb_api::SimilarSamplesRequest { hashes };
722
723 let response = self
724 .client
725 .post(format!(
726 "{}{}",
727 self.url,
728 malwaredb_api::SIMILAR_SAMPLES_URL
729 ))
730 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
731 .json(&request)
732 .send()?
733 .json::<ServerResponse<SimilarSamplesResponse>>()
734 .context(MDB_CLIENT_ERROR_CONTEXT)?;
735
736 match response {
737 ServerResponse::Success(similar) => Ok(similar),
738 ServerResponse::Error(e) => Err(e.into()),
739 }
740 }
741}
742
743impl Debug for MdbClient {
744 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
745 use crate::MDB_VERSION;
746
747 writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
748 }
749}
750
751pub struct IterableHashSearchResult<'a> {
753 pub response: SearchResponse,
755
756 client: &'a MdbClient,
757}
758
759impl<'a> IterableHashSearchResult<'a> {
760 #[must_use]
762 pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
763 Self { response, client }
764 }
765}
766
767impl Iterator for IterableHashSearchResult<'_> {
768 type Item = String;
769
770 fn next(&mut self) -> Option<Self::Item> {
771 if let Some(hash) = self.response.hashes.pop() {
772 Some(hash)
773 } else if let Some(uuid) = self.response.pagination {
774 let request = SearchRequest {
775 search: SearchType::Continuation(uuid),
776 };
777
778 self.response = match self.client.do_search_request(&request) {
779 Ok(response) => response,
780 Err(e) => {
781 warn!("Failed to continue search: {e}");
782 return None;
783 }
784 };
785
786 self.response.hashes.pop()
787 } else {
788 None
789 }
790 }
791}
792
793pub struct IterableSampleSearchResult<'a> {
795 pub response: SearchResponse,
797
798 client: &'a MdbClient,
799}
800
801impl<'a> IterableSampleSearchResult<'a> {
802 #[must_use]
804 pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
805 Self { response, client }
806 }
807}
808
809impl Iterator for IterableSampleSearchResult<'_> {
810 type Item = Vec<u8>;
811
812 fn next(&mut self) -> Option<Self::Item> {
813 if let Some(hash) = self.response.hashes.pop() {
814 let binary = match self.client.retrieve(&hash, false) {
815 Ok(binary) => binary,
816 Err(e) => {
817 error!("Failed to download {hash}: {e}");
818 return None;
819 }
820 };
821 Some(binary)
822 } else if let Some(uuid) = self.response.pagination {
823 let request = SearchRequest {
824 search: SearchType::Continuation(uuid),
825 };
826
827 self.response = match self.client.do_search_request(&request) {
828 Ok(response) => response,
829 Err(e) => {
830 warn!("Failed to continue search: {e}");
831 return None;
832 }
833 };
834
835 if let Some(hash) = self.response.hashes.pop() {
836 let binary = match self.client.retrieve(&hash, false) {
837 Ok(binary) => binary,
838 Err(e) => {
839 error!("Failed to download {hash}: {e}");
840 return None;
841 }
842 };
843 Some(binary)
844 } else {
845 None
846 }
847 } else {
848 None
849 }
850 }
851}