Skip to main content

malwaredb_client/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8#![forbid(unsafe_code)]
9
10/// Non-async version of the Malware DB client
11#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
12#[cfg(feature = "blocking")]
13pub mod blocking;
14
15pub use malwaredb_api;
16use malwaredb_api::{
17    GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType, Report, SearchRequest,
18    SearchRequestParameters, SearchResponse, SearchType, ServerInfo, ServerResponse,
19    SimilarSamplesResponse, Sources, SupportedFileTypes, YaraSearchRequest,
20    YaraSearchRequestResponse, YaraSearchResponse, digest::HashType,
21};
22use malwaredb_lzjd::{LZDict, Murmur3HashState};
23use malwaredb_types::exec::pe32::EXE;
24use malwaredb_types::utils::entropy_calc;
25
26use std::collections::HashSet;
27use std::fmt::{Debug, Display, Formatter};
28use std::fs::OpenOptions;
29use std::io::{Cursor, Write};
30use std::path::{Path, PathBuf};
31use std::sync::LazyLock;
32
33use anyhow::{Context, Result, bail, ensure};
34use base64::Engine;
35use base64::engine::general_purpose;
36use cart_container::JsonMap;
37use fuzzyhash::FuzzyHash;
38use home::home_dir;
39use mdns_sd::{ServiceDaemon, ServiceEvent};
40use reqwest::Certificate;
41use serde::{Deserialize, Serialize};
42use sha2::{Digest, Sha256, Sha384, Sha512};
43use tlsh_fixed::TlshBuilder;
44use tracing::{debug, error, info, trace, warn};
45use uuid::Uuid;
46use zeroize::{Zeroize, ZeroizeOnDrop};
47
48/// Local directory for the Malware DB client configs
49const MDB_CLIENT_DIR: &str = "malwaredb_client";
50
51/// Error for Anyhow's `context()` function with regard to what is expected just a network error
52pub(crate) const MDB_CLIENT_ERROR_CONTEXT: &str =
53    "Network error connecting to MalwareDB, or failure to decode server response.";
54
55/// Config file name expected by Malware DB client
56const MDB_CLIENT_CONFIG_TOML: &str = "mdb_client.toml";
57
58/// MDB version
59pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
60
61/// MDB version as a semantic version object
62pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
63    LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
64
65/// macOS Keychain functionality
66#[cfg(target_os = "macos")]
67pub(crate) mod macos {
68    use crate::CertificateType;
69
70    use anyhow::Result;
71    use reqwest::Certificate;
72    use security_framework::os::macos::keychain::SecKeychain;
73    use tracing::error;
74
75    /// Application identifier for macOS Keychain
76    const KEYCHAIN_ID: &str = "malwaredb-client";
77
78    /// Entry ID for the Malware DB server URL
79    const KEYCHAIN_URL: &str = "URL";
80
81    /// Entry ID for the user's Malware DB API Key
82    const KEYCHAIN_API_KEY: &str = "API_KEY";
83
84    /// Entry ID for the custom PEM-encoded certificate for the Malware DB server
85    const KEYCHAIN_CERTIFICATE_PEM: &str = "CERT_PEM";
86
87    /// Entry ID for the custom DER-encoded certificate for the Malware DB server
88    const KEYCHAIN_CERTIFICATE_DER: &str = "CERT_DER";
89
90    #[derive(Clone)]
91    pub(crate) struct CertificateData {
92        pub cert_type: CertificateType,
93        pub cert_bytes: Vec<u8>,
94    }
95
96    impl CertificateData {
97        pub(crate) fn as_cert(&self) -> Result<Certificate> {
98            Ok(match self.cert_type {
99                CertificateType::PEM => Certificate::from_pem(&self.cert_bytes)?,
100                CertificateType::DER => Certificate::from_der(&self.cert_bytes)?,
101            })
102        }
103    }
104
105    /// Save the elements to they Keychain
106    pub fn save_credentials(url: &str, key: &str, cert: Option<CertificateData>) -> Result<()> {
107        let keychain = SecKeychain::default()?;
108
109        keychain.add_generic_password(KEYCHAIN_ID, KEYCHAIN_URL, url.as_bytes())?;
110        keychain.add_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY, key.as_bytes())?;
111
112        if let Some(cert) = cert {
113            match cert.cert_type {
114                CertificateType::PEM => keychain.add_generic_password(
115                    KEYCHAIN_ID,
116                    KEYCHAIN_CERTIFICATE_PEM,
117                    &cert.cert_bytes,
118                )?,
119                CertificateType::DER => keychain.add_generic_password(
120                    KEYCHAIN_ID,
121                    KEYCHAIN_CERTIFICATE_DER,
122                    &cert.cert_bytes,
123                )?,
124            }
125        }
126
127        Ok(())
128    }
129
130    /// Return key, url, and optionally the certificate in that order. Errors silently discarded.
131    pub fn retrieve_credentials() -> Result<(String, String, Option<CertificateData>)> {
132        let keychain = SecKeychain::default()?;
133        let (api_key, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY)?;
134        let api_key = String::from_utf8(api_key.as_ref().to_vec())?;
135        let (url, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_URL)?;
136        let url = String::from_utf8(url.as_ref().to_vec())?;
137
138        if let Ok((cert, _item)) =
139            keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_CERTIFICATE_PEM)
140        {
141            let cert = CertificateData {
142                cert_type: CertificateType::PEM,
143                cert_bytes: cert.to_vec(),
144            };
145            return Ok((api_key, url, Some(cert)));
146        }
147
148        if let Ok((cert, _item)) =
149            keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_CERTIFICATE_DER)
150        {
151            let cert = CertificateData {
152                cert_type: CertificateType::DER,
153                cert_bytes: cert.to_vec(),
154            };
155            return Ok((api_key, url, Some(cert)));
156        }
157
158        Ok((api_key, url, None))
159    }
160
161    /// Delete Malware DB client information from the Keychain
162    pub fn clear_credentials() {
163        if let Ok(keychain) = SecKeychain::default() {
164            for element in [
165                KEYCHAIN_API_KEY,
166                KEYCHAIN_URL,
167                KEYCHAIN_CERTIFICATE_PEM,
168                KEYCHAIN_CERTIFICATE_DER,
169            ] {
170                if let Ok((_, item)) = keychain.find_generic_password(KEYCHAIN_ID, element) {
171                    item.delete();
172                }
173            }
174        } else {
175            error!("Failed to get access to the Keychain to clear credentials");
176        }
177    }
178}
179
/// Encoding of a server certificate
// DER and PEM are conventionally written in capitals, hence the lint exception.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Eq, PartialEq)]
enum CertificateType {
    /// DER-encoded certificate
    DER,

    /// PEM-encoded certificate
    PEM,
}
186
187/// Asynchronous Malware DB Client Configuration and connection
188#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
189pub struct MdbClient {
190    /// URL of the Malware DB server, including http and port number, ending without a slash
191    pub url: String,
192
193    /// User's API key for Malware DB
194    api_key: String,
195
196    /// Async http client which stores the optional server certificate
197    #[zeroize(skip)]
198    #[serde(skip)]
199    client: reqwest::Client,
200
201    /// Server's certificate
202    #[cfg(target_os = "macos")]
203    #[zeroize(skip)]
204    #[serde(skip)]
205    cert: Option<macos::CertificateData>,
206}
207
208impl MdbClient {
209    /// MDB Client from components, doesn't test connectivity
210    ///
211    /// # Errors
212    ///
213    /// Returns an error if a list of certificates was passed and any were not in the expected
214    /// DER or PEM format or could not be parsed.
215    pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
216        let mut url = url;
217        let url = if url.ends_with('/') {
218            url.pop();
219            url
220        } else {
221            url
222        };
223
224        let cert = if let Some(path) = cert_path {
225            Some((path_load_cert(&path)?, path))
226        } else {
227            None
228        };
229
230        let builder = reqwest::ClientBuilder::new()
231            .gzip(true)
232            .zstd(true)
233            .use_rustls_tls()
234            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
235
236        let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
237            builder.add_root_certificate(cert.clone()).build()
238        } else {
239            builder.build()
240        }?;
241
242        #[cfg(target_os = "macos")]
243        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
244            Some(macos::CertificateData {
245                cert_type: *cert_type,
246                cert_bytes: std::fs::read(cert_path)?,
247            })
248        } else {
249            None
250        };
251
252        Ok(Self {
253            url,
254            api_key,
255            client,
256
257            #[cfg(target_os = "macos")]
258            cert,
259        })
260    }
261
262    /// Login to a server, optionally save the configuration file, and return a client object
263    ///
264    /// # Errors
265    ///
266    /// Returns an error if the server URL, username, or password were incorrect, or if a network
267    /// issue occurred.
268    pub async fn login(
269        url: String,
270        username: String,
271        password: String,
272        save: bool,
273        cert_path: Option<PathBuf>,
274    ) -> Result<Self> {
275        let mut url = url;
276        let url = if url.ends_with('/') {
277            url.pop();
278            url
279        } else {
280            url
281        };
282
283        let api_request = malwaredb_api::GetAPIKeyRequest {
284            user: username,
285            password,
286        };
287
288        let builder = reqwest::ClientBuilder::new()
289            .gzip(true)
290            .zstd(true)
291            .use_rustls_tls()
292            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
293
294        let cert = if let Some(path) = cert_path {
295            Some((path_load_cert(&path)?, path))
296        } else {
297            None
298        };
299
300        let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
301            builder.add_root_certificate(cert.clone()).build()
302        } else {
303            builder.build()
304        }?;
305
306        let res = client
307            .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
308            .json(&api_request)
309            .send()
310            .await?
311            .json::<ServerResponse<GetAPIKeyResponse>>()
312            .await
313            .context(MDB_CLIENT_ERROR_CONTEXT)?;
314
315        let res = match res {
316            ServerResponse::Success(res) => res,
317            ServerResponse::Error(err) => return Err(err.into()),
318        };
319
320        #[cfg(target_os = "macos")]
321        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
322            Some(macos::CertificateData {
323                cert_type: *cert_type,
324                cert_bytes: std::fs::read(cert_path)?,
325            })
326        } else {
327            None
328        };
329
330        let client = MdbClient {
331            url,
332            api_key: res.key.clone(),
333            client,
334
335            #[cfg(target_os = "macos")]
336            cert,
337        };
338
339        let server_info = client.server_info().await?;
340        if server_info.mdb_version > *MDB_VERSION_SEMVER {
341            warn!(
342                "Server version {:?} is newer than client {:?}, consider updating.",
343                server_info.mdb_version, MDB_VERSION_SEMVER
344            );
345        }
346
347        if save && let Err(e) = client.save() {
348            error!("Login successful but failed to save config: {e}");
349            bail!("Login successful but failed to save config: {e}");
350        }
351        Ok(client)
352    }
353
354    /// Reset one's own API key to effectively logout & disable all clients who are using the key
355    ///
356    /// # Errors
357    ///
358    /// Returns an error if there was a network issue or the user wasn't properly logged in.
359    pub async fn reset_key(&self) -> Result<()> {
360        let response = self
361            .client
362            .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
363            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
364            .send()
365            .await
366            .context(MDB_CLIENT_ERROR_CONTEXT)?;
367        if !response.status().is_success() {
368            bail!("failed to reset API key, was it correct?");
369        }
370        Ok(())
371    }
372
373    /// Malware DB Client configuration loaded from a specified path
374    ///
375    /// # Errors
376    ///
377    /// Returns an error if the configuration file cannot be read, possibly because it
378    /// doesn't exist or due to a permission error or a parsing error.
379    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
380        let name = path.as_ref().display();
381        let config =
382            std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
383        let cfg: MdbClient =
384            toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
385        Ok(cfg)
386    }
387
388    /// Malware DB Client configuration from user's home directory
389    ///
390    /// On macOS, it will attempt to load this information in the Keychain, which isn't required.
391    ///
392    /// # Errors
393    ///
394    /// Returns an error if the configuration file cannot be read, possibly because it
395    /// doesn't exist or due to a permission error or a parsing error.
396    pub fn load() -> Result<Self> {
397        #[cfg(target_os = "macos")]
398        {
399            if let Ok((api_key, url, cert)) = macos::retrieve_credentials() {
400                let builder = reqwest::ClientBuilder::new()
401                    .gzip(true)
402                    .zstd(true)
403                    .use_rustls_tls()
404                    .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
405
406                let client = if let Some(cert) = &cert {
407                    builder.add_root_certificate(cert.as_cert()?).build()
408                } else {
409                    builder.build()
410                }?;
411
412                return Ok(Self {
413                    url,
414                    api_key,
415                    client,
416                    cert,
417                });
418            }
419        }
420
421        let path = get_config_path(false)?;
422        if path.exists() {
423            return Self::from_file(path);
424        }
425        bail!("config file not found")
426    }
427
428    /// Save Malware DB Client configuration to the user's home directory.
429    ///
430    /// On macOS, it will attempt to save this information in the Keychain, which isn't required.
431    ///
432    /// # Errors
433    ///
434    /// Returns an error if there was a problem saving the configuration file.
435    pub fn save(&self) -> Result<()> {
436        #[cfg(target_os = "macos")]
437        {
438            if macos::save_credentials(&self.url, &self.api_key, self.cert.clone()).is_ok() {
439                return Ok(());
440            }
441        }
442
443        let toml = toml::to_string(self)?;
444        let path = get_config_path(true)?;
445
446        let mut options = OpenOptions::new();
447        options
448            .write(true)
449            .create(true)
450            .append(false)
451            .truncate(false);
452
453        #[cfg(target_family = "unix")]
454        {
455            use std::os::unix::fs::OpenOptionsExt;
456
457            options.mode(0o600);
458        }
459
460        let mut file = options.open(&path)?;
461        write!(file, "{toml}").context(format!("failed to write mdb config to {}", path.display()))
462    }
463
464    /// Delete the Malware DB client configuration file
465    ///
466    /// # Errors
467    ///
468    /// Returns an error if there isn't a configuration file to delete, or if it cannot be deleted,
469    /// possibly due to a permissions error.
470    pub fn delete(&self) -> Result<()> {
471        #[cfg(target_os = "macos")]
472        macos::clear_credentials();
473
474        let path = get_config_path(false)?;
475        if path.exists() {
476            std::fs::remove_file(&path)
477                .context(format!("failed to delete client config file {}", path.display()))?;
478        }
479        Ok(())
480    }
481
482    // Actions of the client
483
484    /// Get information about the server, unauthenticated
485    ///
486    /// # Errors
487    ///
488    /// This may return an error if there's a network situation.
489    pub async fn server_info(&self) -> Result<ServerInfo> {
490        let response = self
491            .client
492            .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
493            .send()
494            .await?
495            .json::<ServerResponse<ServerInfo>>()
496            .await
497            .context(MDB_CLIENT_ERROR_CONTEXT)?;
498
499        match response {
500            ServerResponse::Success(info) => Ok(info),
501            ServerResponse::Error(e) => Err(e.into()),
502        }
503    }
504
505    /// Get file types supported by the server, unauthenticated
506    ///
507    /// # Errors
508    ///
509    /// This may return an error if there's a network situation.
510    pub async fn supported_types(&self) -> Result<SupportedFileTypes> {
511        let response = self
512            .client
513            .get(format!("{}{}", self.url, malwaredb_api::SUPPORTED_FILE_TYPES_URL))
514            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
515            .send()
516            .await?
517            .json::<ServerResponse<SupportedFileTypes>>()
518            .await
519            .context(MDB_CLIENT_ERROR_CONTEXT)?;
520
521        match response {
522            ServerResponse::Success(types) => Ok(types),
523            ServerResponse::Error(e) => Err(e.into()),
524        }
525    }
526
527    /// Get information about the user
528    ///
529    /// # Errors
530    ///
531    /// This may return an error if there's a network situation or if the user is not logged in
532    /// or not properly authorized to connect.
533    pub async fn whoami(&self) -> Result<GetUserInfoResponse> {
534        let response = self
535            .client
536            .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
537            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
538            .send()
539            .await?
540            .json::<ServerResponse<GetUserInfoResponse>>()
541            .await
542            .context(MDB_CLIENT_ERROR_CONTEXT)?;
543
544        match response {
545            ServerResponse::Success(info) => Ok(info),
546            ServerResponse::Error(e) => Err(e.into()),
547        }
548    }
549
550    /// Get the sample labels known to the server
551    ///
552    /// # Errors
553    ///
554    /// This may return an error if there's a network situation or if the user is not logged in
555    /// or not properly authorized to connect.
556    pub async fn labels(&self) -> Result<Labels> {
557        let response = self
558            .client
559            .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
560            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
561            .send()
562            .await?
563            .json::<ServerResponse<Labels>>()
564            .await
565            .context(MDB_CLIENT_ERROR_CONTEXT)?;
566
567        match response {
568            ServerResponse::Success(labels) => Ok(labels),
569            ServerResponse::Error(e) => Err(e.into()),
570        }
571    }
572
573    /// Get the sources available to the current user
574    ///
575    /// # Errors
576    ///
577    /// This may return an error if there's a network situation or if the user is not logged in
578    /// or not properly authorized to connect.
579    pub async fn sources(&self) -> Result<Sources> {
580        let response = self
581            .client
582            .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
583            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
584            .send()
585            .await?
586            .json::<ServerResponse<Sources>>()
587            .await
588            .context(MDB_CLIENT_ERROR_CONTEXT)?;
589
590        match response {
591            ServerResponse::Success(sources) => Ok(sources),
592            ServerResponse::Error(e) => Err(e.into()),
593        }
594    }
595
596    /// Submit one file to `MalwareDB`: provide the contents, file name, and source ID
597    ///
598    /// # Errors
599    ///
600    /// This may return an error if there's a network situation or if the user is not logged in
601    /// or not properly authorized to connect.
602    pub async fn submit(
603        &self,
604        contents: impl AsRef<[u8]>,
605        file_name: impl AsRef<str>,
606        source_id: u32,
607    ) -> Result<bool> {
608        let mut hasher = Sha256::new();
609        hasher.update(&contents);
610        let result = hasher.finalize();
611
612        let encoded = general_purpose::STANDARD.encode(contents);
613
614        let payload = malwaredb_api::NewSampleB64 {
615            file_name: file_name.as_ref().to_string(),
616            source_id,
617            file_contents_b64: encoded,
618            sha256: hex::encode(result),
619        };
620
621        match self
622            .client
623            .post(format!("{}{}", self.url, malwaredb_api::UPLOAD_SAMPLE_JSON_URL))
624            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
625            .json(&payload)
626            .send()
627            .await
628        {
629            Ok(res) => {
630                if !res.status().is_success() {
631                    info!("Code {} sending {}", res.status(), payload.file_name);
632                }
633                Ok(res.status().is_success())
634            }
635            Err(e) => {
636                let status: String = e
637                    .status()
638                    .map(|s| s.as_str().to_string())
639                    .unwrap_or_default();
640                error!("Error{status} sending {}: {e}", payload.file_name);
641                bail!(e.to_string())
642            }
643        }
644    }
645
646    /// Submit one file to `MalwareDB` as a Cbor object: provide the contents, file name, and source ID
647    /// Experimental! May be removed at any point.
648    ///
649    /// # Errors
650    ///
651    /// This may return an error if there's a network situation or if the user is not logged in
652    /// or not properly authorized to connect.
653    pub async fn submit_as_cbor(
654        &self,
655        contents: impl AsRef<[u8]>,
656        file_name: impl AsRef<str>,
657        source_id: u32,
658    ) -> Result<bool> {
659        let mut hasher = Sha256::new();
660        hasher.update(&contents);
661        let result = hasher.finalize();
662
663        let payload = malwaredb_api::NewSampleBytes {
664            file_name: file_name.as_ref().to_string(),
665            source_id,
666            file_contents: contents.as_ref().to_vec(),
667            sha256: hex::encode(result),
668        };
669
670        let mut bytes = Vec::with_capacity(payload.file_contents.len());
671        ciborium::ser::into_writer(&payload, &mut bytes)?;
672
673        match self
674            .client
675            .post(format!("{}{}", self.url, malwaredb_api::UPLOAD_SAMPLE_CBOR_URL))
676            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
677            .header("content-type", "application/cbor")
678            .body(bytes)
679            .send()
680            .await
681        {
682            Ok(res) => {
683                if !res.status().is_success() {
684                    info!("Code {} sending {}", res.status(), payload.file_name);
685                }
686                Ok(res.status().is_success())
687            }
688            Err(e) => {
689                let status: String = e
690                    .status()
691                    .map(|s| s.as_str().to_string())
692                    .unwrap_or_default();
693                error!("Error{status} sending {}: {e}", payload.file_name);
694                bail!(e.to_string())
695            }
696        }
697    }
698
699    /// Search for a file based on partial hash and/or partial file name, returns a list of hashes
700    ///
701    /// # Errors
702    ///
703    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
704    pub async fn partial_search(
705        &self,
706        partial_hash: Option<(PartialHashSearchType, String)>,
707        name: Option<String>,
708        response: PartialHashSearchType,
709        limit: u32,
710    ) -> Result<SearchResponse> {
711        let query = SearchRequest {
712            search: SearchType::Search(SearchRequestParameters {
713                partial_hash,
714                file_name: name,
715                response,
716                limit,
717                labels: None,
718                file_type: None,
719                magic: None,
720            }),
721        };
722
723        self.do_search_request(&query).await
724    }
725
726    /// Search for a file based on partial hash and/or partial file name, labels, file type; returns a list of hashes
727    ///
728    /// # Errors
729    ///
730    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
731    #[allow(clippy::too_many_arguments)]
732    pub async fn partial_search_labels_type(
733        &self,
734        partial_hash: Option<(PartialHashSearchType, String)>,
735        name: Option<String>,
736        response: PartialHashSearchType,
737        labels: Option<Vec<String>>,
738        file_type: Option<String>,
739        magic: Option<String>,
740        limit: u32,
741    ) -> Result<SearchResponse> {
742        let query = SearchRequest {
743            search: SearchType::Search(SearchRequestParameters {
744                partial_hash,
745                file_name: name,
746                response,
747                limit,
748                file_type,
749                magic,
750                labels,
751            }),
752        };
753
754        self.do_search_request(&query).await
755    }
756
757    /// Return the next page from the search result
758    ///
759    /// # Errors
760    ///
761    /// Returns an error if there is a network problem, or pagination not available
762    pub async fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
763        if let Some(uuid) = response.pagination {
764            let request = SearchRequest {
765                search: SearchType::Continuation(uuid),
766            };
767            return self.do_search_request(&request).await;
768        }
769
770        bail!("Pagination not available")
771    }
772
773    async fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
774        ensure!(
775            query.is_valid(),
776            "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
777        );
778
779        let response = self
780            .client
781            .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
782            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
783            .json(query)
784            .send()
785            .await?
786            .json::<ServerResponse<SearchResponse>>()
787            .await
788            .context(MDB_CLIENT_ERROR_CONTEXT)?;
789
790        match response {
791            ServerResponse::Success(search) => Ok(search),
792            ServerResponse::Error(e) => Err(e.into()),
793        }
794    }
795
796    /// Retrieve sample by hash, optionally in the `CaRT` format
797    ///
798    /// # Errors
799    ///
800    /// This may return an error if there's a network situation or if the user is not logged in
801    /// or not properly authorized to connect.
802    pub async fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
803        let api_endpoint = if cart {
804            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
805        } else {
806            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
807        };
808
809        let res = self
810            .client
811            .get(format!("{}{api_endpoint}", self.url))
812            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
813            .send()
814            .await?;
815
816        if !res.status().is_success() {
817            bail!("Received code {}", res.status());
818        }
819
820        let content_digest = res.headers().get("content-digest").map(ToOwned::to_owned);
821        let body = res.bytes().await?;
822        let bytes = body.to_vec();
823
824        // TODO: Make this required in v0.3
825        if let Some(digest) = content_digest {
826            let hash = HashType::from_content_digest_header(digest.to_str()?)?;
827            if hash.verify(&bytes) {
828                trace!("Hash verified for sample {hash}");
829            } else {
830                error!("Hash mismatch for sample {hash}");
831            }
832        } else {
833            warn!("No content digest header received for sample {hash}");
834        }
835
836        Ok(bytes)
837    }
838
839    /// Fetch a report for a sample
840    ///
841    /// # Errors
842    ///
843    /// This may return an error if there's a network situation or if the user is not logged in
844    /// or not properly authorized to connect.
845    pub async fn report(&self, hash: &str) -> Result<Report> {
846        let response = self
847            .client
848            .get(format!("{}{}/{hash}", self.url, malwaredb_api::SAMPLE_REPORT_URL))
849            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
850            .send()
851            .await?
852            .json::<ServerResponse<Report>>()
853            .await
854            .context(MDB_CLIENT_ERROR_CONTEXT)?;
855
856        match response {
857            ServerResponse::Success(report) => Ok(report),
858            ServerResponse::Error(e) => Err(e.into()),
859        }
860    }
861
862    /// Find similar samples in `MalwareDB` based on the contents of a given file.
863    /// This does not submit the sample to `MalwareDB`.
864    ///
865    /// # Errors
866    ///
867    /// This may return an error if there's a network situation or if the user is not logged in
868    /// or not properly authorized to connect.
869    pub async fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
870        let mut hashes = vec![];
871        let ssdeep_hash = FuzzyHash::new(contents);
872
873        let build_hasher = Murmur3HashState::default();
874        let lzjd_str =
875            LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
876        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
877        hashes.push((malwaredb_api::SimilarityHashType::SSDeep, ssdeep_hash.to_string()));
878
879        let mut builder = TlshBuilder::new(
880            tlsh_fixed::BucketKind::Bucket256,
881            tlsh_fixed::ChecksumKind::ThreeByte,
882            tlsh_fixed::Version::Version4,
883        );
884
885        builder.update(contents);
886        if let Ok(hasher) = builder.build() {
887            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
888        }
889
890        if let Ok(exe) = EXE::from(contents)
891            && let Some(imports) = exe.imports
892        {
893            hashes
894                .push((malwaredb_api::SimilarityHashType::ImportHash, hex::encode(imports.hash())));
895            hashes.push((malwaredb_api::SimilarityHashType::FuzzyImportHash, imports.fuzzy_hash()));
896        }
897
898        let request = malwaredb_api::SimilarSamplesRequest { hashes };
899
900        let response = self
901            .client
902            .post(format!("{}{}", self.url, malwaredb_api::SIMILAR_SAMPLES_URL))
903            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
904            .json(&request)
905            .send()
906            .await?
907            .json::<ServerResponse<SimilarSamplesResponse>>()
908            .await
909            .context(MDB_CLIENT_ERROR_CONTEXT)?;
910
911        match response {
912            ServerResponse::Success(similar) => Ok(similar),
913            ServerResponse::Error(e) => Err(e.into()),
914        }
915    }
916
917    /// Submit a Yara rule and return the UUID of the search for later retrieval.
918    ///
919    /// # Errors
920    ///
921    /// Network or authentication errors
922    pub async fn yara_search(&self, yara: &str) -> Result<YaraSearchRequestResponse> {
923        let yara = YaraSearchRequest {
924            rules: vec![yara.to_string()],
925            response: PartialHashSearchType::SHA256,
926        };
927
928        let response = self
929            .client
930            .post(format!("{}{}", self.url, malwaredb_api::YARA_SEARCH_URL))
931            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
932            .json(&yara)
933            .send()
934            .await?
935            .json::<ServerResponse<YaraSearchRequestResponse>>()
936            .await?;
937
938        match response {
939            ServerResponse::Success(sources) => Ok(sources),
940            ServerResponse::Error(e) => Err(e.into()),
941        }
942    }
943
944    /// Get the result from a Yara search
945    ///
946    /// # Errors
947    ///
948    /// Network or authentication errors
949    pub async fn yara_result(&self, uuid: Uuid) -> Result<YaraSearchResponse> {
950        let response = self
951            .client
952            .get(format!("{}{}/{uuid}", self.url, malwaredb_api::YARA_SEARCH_URL))
953            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
954            .send()
955            .await?
956            .json::<ServerResponse<YaraSearchResponse>>()
957            .await?;
958
959        match response {
960            ServerResponse::Success(sources) => Ok(sources),
961            ServerResponse::Error(e) => Err(e.into()),
962        }
963    }
964}
965
966impl Debug for MdbClient {
967    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
968        writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
969    }
970}
971
972/// Convenience function for encoding bytes into a `CaRT` file using the default key. This also
973/// adds SHA-384 and SHA-512 hashes plus the entropy of the original file.
974/// See <https://github.com/CybercentreCanada/cart> for more information.
975///
976/// # Errors
977///
978/// There should not be any errors, but the underlying library can't guarantee that. However, any data
979/// passed to it is correct.
980pub fn encode_to_cart(data: &[u8]) -> Result<Vec<u8>> {
981    let mut input_buffer = Cursor::new(data);
982    let mut output_buffer = Cursor::new(vec![]);
983    let mut output_metadata = JsonMap::new();
984
985    let mut sha384 = Sha384::new();
986    sha384.update(data);
987    let sha384 = hex::encode(sha384.finalize());
988
989    let mut sha512 = Sha512::new();
990    sha512.update(data);
991    let sha512 = hex::encode(sha512.finalize());
992
993    output_metadata.insert("sha384".into(), sha384.into());
994    output_metadata.insert("sha512".into(), sha512.into());
995    output_metadata.insert("entropy".into(), entropy_calc(data).into());
996    cart_container::pack_stream(
997        &mut input_buffer,
998        &mut output_buffer,
999        Some(output_metadata),
1000        None,
1001        cart_container::digesters::default_digesters(),
1002        None,
1003    )?;
1004
1005    Ok(output_buffer.into_inner())
1006}
1007
1008/// Convenience function for decoding a `CaRT` file using the default key, returning the bytes plus the
1009/// optional header and footer metadata, if present.
1010/// See <https://github.com/CybercentreCanada/cart> for more information.
1011///
1012/// # Errors
1013///
1014/// Returns an error if the file cannot be parsed or if this `CaRT` file didn't use the default key.
1015/// <https://github.com/CybercentreCanada/cart-rs/blob/7ad548143bb85b64f364804e90cfada6c31cf902/cart_container/src/cipher.rs#L14-L17>
1016pub fn decode_from_cart(data: &[u8]) -> Result<(Vec<u8>, Option<JsonMap>, Option<JsonMap>)> {
1017    let mut input_buffer = Cursor::new(data);
1018    let mut output_buffer = Cursor::new(vec![]);
1019    let (header, footer) =
1020        cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)?;
1021    Ok((output_buffer.into_inner(), header, footer))
1022}
1023
1024/// Load a certificate from a path
1025///
1026/// # Errors
1027///
1028/// Returns errors if the file cannot be read or if the file isn't an ASN.1 DER file or
1029/// base64-encoded ASN.1 PEM file.
1030fn path_load_cert(path: &Path) -> Result<(CertificateType, Certificate)> {
1031    if !path.exists() {
1032        bail!("Certificate {} does not exist.", path.display());
1033    }
1034    let cert = match path
1035        .extension()
1036        .context("can't determine file extension")?
1037        .to_str()
1038        .context("unable to parse file extension")?
1039    {
1040        "pem" => {
1041            let contents = std::fs::read(path)?;
1042            (CertificateType::PEM, Certificate::from_pem(&contents)?)
1043        }
1044        "der" => {
1045            let contents = std::fs::read(path)?;
1046            (CertificateType::DER, Certificate::from_der(&contents)?)
1047        }
1048        ext => {
1049            bail!("Unknown extension {ext:?}")
1050        }
1051    };
1052    Ok(cert)
1053}
1054
1055/// Gets the configuration file in the following order:
1056///
1057/// 1. Current working directory: `./mdb_client.toml`
1058/// 2. Haiku-specific directory if on Haiku
1059/// 3. XDG Free Desktop directory if on Unix
1060/// 4. The user's home directory in `~/.config/malwaredb_client/mdb_client.toml`
1061/// 5. `mdb_client.toml` in the current directory (same as the first, but after checking others)
1062#[inline]
1063pub(crate) fn get_config_path(create: bool) -> Result<PathBuf> {
1064    // If there is a config file in the current working directory, use it
1065    let config = PathBuf::from(MDB_CLIENT_CONFIG_TOML);
1066    if config.exists() {
1067        return Ok(config);
1068    }
1069
1070    #[cfg(target_os = "haiku")]
1071    {
1072        let mut settings = PathBuf::from("/boot/home/config/settings/malwaredb");
1073        if create && !settings.exists() {
1074            std::fs::create_dir_all(&settings)?;
1075        }
1076        settings.push(MDB_CLIENT_CONFIG_TOML);
1077        return Ok(settings);
1078    }
1079
1080    #[cfg(unix)]
1081    {
1082        // Obey the Free Desktop standard, check the variable
1083        if let Some(xdg_home) = std::env::var_os("XDG_CONFIG_HOME") {
1084            let mut xdg_config_home = PathBuf::from(xdg_home);
1085            xdg_config_home.push(MDB_CLIENT_DIR);
1086            if create && !xdg_config_home.exists() {
1087                std::fs::create_dir_all(&xdg_config_home)?;
1088            }
1089            xdg_config_home.push(MDB_CLIENT_CONFIG_TOML);
1090            return Ok(xdg_config_home);
1091        }
1092    }
1093
1094    if let Some(mut home_config) = home_dir() {
1095        home_config.push(".config");
1096        home_config.push(MDB_CLIENT_DIR);
1097        if create && !home_config.exists() {
1098            std::fs::create_dir_all(&home_config)?;
1099        }
1100        home_config.push(MDB_CLIENT_CONFIG_TOML);
1101        return Ok(home_config);
1102    }
1103
1104    Ok(PathBuf::from(MDB_CLIENT_CONFIG_TOML))
1105}
1106
/// A Malware DB server found by Multicast DNS (also known as Bonjour or Zeroconf)
///
/// Formatting a value with `Display` yields the server's base URL: `https://host:port` when
/// `ssl` is set and `http://host:port` otherwise.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MalwareDBServer {
    /// IP address or domain name of the server
    pub host: String,

    /// Port the server listens on
    pub port: u16,

    /// Whether the server expects an encrypted connection
    pub ssl: bool,

    /// Name of the Malware DB server
    pub name: String,
}

impl Display for MalwareDBServer {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { host, port, ssl, .. } = self;
        let scheme = if *ssl { "https" } else { "http" };
        write!(f, "{scheme}://{host}:{port}")
    }
}
1132
1133impl MalwareDBServer {
1134    /// Retrieve details about the server
1135    ///
1136    /// # Errors
1137    ///
1138    /// An error will result if the server becomes unreachable or if a specific CA certificate is required
1139    pub async fn server_info(&self) -> Result<ServerInfo> {
1140        let client = reqwest::ClientBuilder::new()
1141            .gzip(true)
1142            .zstd(true)
1143            .use_rustls_tls()
1144            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
1145            .build()?;
1146
1147        let response = client
1148            .get(format!("{self}{}", malwaredb_api::SERVER_INFO_URL))
1149            .send()
1150            .await?
1151            .json::<ServerResponse<ServerInfo>>()
1152            .await
1153            .context(MDB_CLIENT_ERROR_CONTEXT)?;
1154
1155        match response {
1156            ServerResponse::Success(info) => Ok(info),
1157            ServerResponse::Error(e) => Err(e.into()),
1158        }
1159    }
1160
1161    /// Retrieve details about the server
1162    ///
1163    /// # Errors
1164    ///
1165    /// An error will result if the server becomes unreachable or if a specific CA certificate is required
1166    ///
1167    /// # Panics
1168    ///
1169    /// This method panics if called from within an async runtime.
1170    #[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
1171    #[cfg(feature = "blocking")]
1172    pub fn server_info_blocking(&self) -> Result<ServerInfo> {
1173        let client = reqwest::blocking::ClientBuilder::new()
1174            .gzip(true)
1175            .zstd(true)
1176            .use_rustls_tls()
1177            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
1178            .build()?;
1179
1180        let response = client
1181            .get(format!("{self}{}", malwaredb_api::SERVER_INFO_URL))
1182            .send()?
1183            .json::<ServerResponse<ServerInfo>>()
1184            .context(MDB_CLIENT_ERROR_CONTEXT)?;
1185
1186        match response {
1187            ServerResponse::Success(similar) => Ok(similar),
1188            ServerResponse::Error(e) => Err(e.into()),
1189        }
1190    }
1191}
1192
1193/// Find servers using Multicast DNS (also known as Bonjour)
1194///
1195/// # Errors
1196///
1197/// This may fail if there's a networking issue.
1198pub fn discover_servers() -> Result<Vec<MalwareDBServer>> {
1199    const MAX_ITERS: usize = 5;
1200    let mdns = ServiceDaemon::new()?;
1201    let mut servers = HashSet::new();
1202    let receiver = mdns.browse(malwaredb_api::MDNS_NAME)?;
1203
1204    let mut counter = 0;
1205    while let Ok(event) = receiver.recv() {
1206        if let ServiceEvent::ServiceResolved(resolved) = event {
1207            let host = resolved.host.replace(".local.", "");
1208            let ssl = if let Some(ssl) = resolved.txt_properties.get("ssl") {
1209                ssl.val_str() == "true"
1210            } else {
1211                debug!(
1212                    "MalwareDB entry for {host}:{} doesn't specify ssl, assuming not",
1213                    resolved.port
1214                );
1215                false
1216            };
1217
1218            let server = MalwareDBServer {
1219                host,
1220                port: resolved.port,
1221                ssl,
1222                name: resolved.fullname.replace(malwaredb_api::MDNS_NAME, ""),
1223            };
1224
1225            servers.insert(server);
1226        }
1227        counter += 1;
1228        if counter > MAX_ITERS {
1229            break;
1230        }
1231    }
1232
1233    if mdns.shutdown().is_err() {
1234        // Pass
1235    }
1236    Ok(servers.into_iter().collect())
1237}
1238
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a bundled `CaRT` sample, then checks the payload hash and the metadata found in
    /// the header and footer.
    #[test]
    fn cart() {
        const BYTES: &[u8] = include_bytes!("../../crates/types/testdata/elf/elf_haiku_x86.cart");
        const ORIGINAL_SHA256: &str =
            "de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740";

        let (payload, header, footer) = decode_from_cart(BYTES).unwrap();

        // The decoded payload must hash to the known value of the original file.
        let payload_sha256 = hex::encode(Sha256::digest(&payload));
        assert_eq!(payload_sha256, ORIGINAL_SHA256);

        let header = header.unwrap();
        let entropy = header.get("entropy").unwrap().as_f64().unwrap();
        assert!(entropy > 4.0 && entropy < 4.1);

        let footer = footer.unwrap();
        assert_eq!(footer.get("length").unwrap(), "5093");
        assert_eq!(footer.get("sha256").unwrap(), ORIGINAL_SHA256);
    }
}