malwaredb_client/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8#![forbid(unsafe_code)]
9
/// Blocking (non-async) version of the Malware DB client.
///
/// Only compiled when the `blocking` feature is enabled.
#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
#[cfg(feature = "blocking")]
pub mod blocking;
14
15pub use malwaredb_api;
16use malwaredb_lzjd::{LZDict, Murmur3HashState};
17use malwaredb_types::exec::pe32::EXE;
18use malwaredb_types::utils::entropy_calc;
19use std::collections::HashSet;
20
21use anyhow::{bail, ensure, Context, Result};
22use base64::engine::general_purpose;
23use base64::Engine;
24use cart_container::JsonMap;
25use fuzzyhash::FuzzyHash;
26use home::home_dir;
27use malwaredb_api::{
28    GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType, Report, SearchRequest,
29    SearchRequestParameters, SearchResponse, SearchType, ServerInfo, ServerResponse,
30    SimilarSamplesResponse, Sources, SupportedFileTypes,
31};
32use mdns_sd::{ServiceDaemon, ServiceEvent};
33use reqwest::Certificate;
34use serde::{Deserialize, Serialize};
35use sha2::{Digest, Sha256, Sha384, Sha512};
36use std::fmt::{Debug, Display, Formatter};
37use std::io::Cursor;
38use std::path::{Path, PathBuf};
39use std::sync::LazyLock;
40use tlsh_fixed::TlshBuilder;
41use tracing::{debug, error, info, warn};
42use zeroize::{Zeroize, ZeroizeOnDrop};
43
44/// Local directory for the Malware DB client configs
45const MDB_CLIENT_DIR: &str = "malwaredb_client";
46
47/// Error for Anyhow's `context()` function with regard to what is expected just a network error
48pub(crate) const MDB_CLIENT_ERROR_CONTEXT: &str =
49    "Network error connecting to MalwareDB, or failure to decode server response.";
50
51/// Config file name expected by Malware DB client
52const MDB_CLIENT_CONFIG_TOML: &str = "mdb_client.toml";
53
54/// MDB version
55pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
56
57/// MDB version as a semantic version object
58pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
59    LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
60
/// macOS Keychain functionality
#[cfg(target_os = "macos")]
pub(crate) mod macos {
    use crate::CertificateType;

    use anyhow::Result;
    use reqwest::Certificate;
    use security_framework::os::macos::keychain::SecKeychain;
    use tracing::error;

    /// Application identifier for macOS Keychain
    const KEYCHAIN_ID: &str = "malwaredb-client";

    /// Entry ID for the Malware DB server URL
    const KEYCHAIN_URL: &str = "URL";

    /// Entry ID for the user's Malware DB API Key
    const KEYCHAIN_API_KEY: &str = "API_KEY";

    /// Entry ID for the custom PEM-encoded certificate for the Malware DB server
    const KEYCHAIN_CERTIFICATE_PEM: &str = "CERT_PEM";

    /// Entry ID for the custom DER-encoded certificate for the Malware DB server
    const KEYCHAIN_CERTIFICATE_DER: &str = "CERT_DER";

    /// Raw custom server certificate together with its encoding
    pub(crate) struct CertificateData {
        /// Encoding of `cert_bytes`
        pub cert_type: CertificateType,
        /// Certificate contents, in the encoding given by `cert_type`
        pub cert_bytes: Vec<u8>,
    }

    impl CertificateData {
        /// Parse the stored bytes into a `reqwest` certificate according to `cert_type`.
        ///
        /// # Errors
        ///
        /// Returns an error if the bytes are not a valid certificate in the recorded encoding.
        pub(crate) fn as_cert(&self) -> Result<Certificate> {
            Ok(match self.cert_type {
                CertificateType::PEM => Certificate::from_pem(&self.cert_bytes)?,
                CertificateType::DER => Certificate::from_der(&self.cert_bytes)?,
            })
        }
    }

    /// Delete one Malware DB client entry from the Keychain; a missing entry is ignored.
    fn delete_entry(keychain: &SecKeychain, account: &str) {
        if let Ok((_, item)) = keychain.find_generic_password(KEYCHAIN_ID, account) {
            item.delete();
        }
    }

    /// Save the server URL, API key, and optionally the server certificate to the Keychain.
    ///
    /// Existing entries are overwritten, so saving again (e.g. after a new login) updates the
    /// stored values. When a certificate is given, a stored certificate of the *other* encoding
    /// is removed so `retrieve_credentials` cannot return a stale one. When `cert` is `None`,
    /// stored certificate entries are left untouched.
    ///
    /// # Errors
    ///
    /// Returns an error if the default Keychain cannot be opened or an entry cannot be written.
    pub fn save_credentials(url: &str, key: &str, cert: Option<CertificateData>) -> Result<()> {
        let keychain = SecKeychain::default()?;

        // `set_generic_password` updates an existing item; `add_generic_password` fails with a
        // duplicate-item error once the entry exists, which broke every save after the first.
        keychain.set_generic_password(KEYCHAIN_ID, KEYCHAIN_URL, url.as_bytes())?;
        keychain.set_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY, key.as_bytes())?;

        if let Some(cert) = cert {
            let (account, stale_account) = match cert.cert_type {
                CertificateType::PEM => (KEYCHAIN_CERTIFICATE_PEM, KEYCHAIN_CERTIFICATE_DER),
                CertificateType::DER => (KEYCHAIN_CERTIFICATE_DER, KEYCHAIN_CERTIFICATE_PEM),
            };
            keychain.set_generic_password(KEYCHAIN_ID, account, &cert.cert_bytes)?;
            // `retrieve_credentials` prefers PEM, so an old entry of the other encoding
            // could otherwise shadow (or outlive) the certificate just saved.
            delete_entry(&keychain, stale_account);
        }

        Ok(())
    }

    /// Return the API key, the URL, and optionally the server certificate, in that order.
    ///
    /// A PEM certificate is preferred when both encodings are stored. A missing or unreadable
    /// certificate entry is not an error; it yields `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if the Keychain cannot be opened, if the API key or URL entry cannot be
    /// found, or if either of them is not valid UTF-8.
    pub fn retrieve_credentials() -> Result<(String, String, Option<CertificateData>)> {
        let keychain = SecKeychain::default()?;
        let (api_key, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY)?;
        let api_key = String::from_utf8(api_key.as_ref().to_vec())?;
        let (url, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_URL)?;
        let url = String::from_utf8(url.as_ref().to_vec())?;

        let cert = [
            (CertificateType::PEM, KEYCHAIN_CERTIFICATE_PEM),
            (CertificateType::DER, KEYCHAIN_CERTIFICATE_DER),
        ]
        .into_iter()
        .find_map(|(cert_type, account)| {
            keychain
                .find_generic_password(KEYCHAIN_ID, account)
                .ok()
                .map(|(cert_bytes, _item)| CertificateData {
                    cert_type,
                    cert_bytes: cert_bytes.to_vec(),
                })
        });

        Ok((api_key, url, cert))
    }

    /// Delete all Malware DB client information from the Keychain.
    ///
    /// Entries that do not exist are skipped. Failure to open the Keychain is logged, not returned.
    pub fn clear_credentials() {
        let Ok(keychain) = SecKeychain::default() else {
            error!("Failed to get access to the Keychain to clear credentials");
            return;
        };

        for account in [
            KEYCHAIN_API_KEY,
            KEYCHAIN_URL,
            KEYCHAIN_CERTIFICATE_PEM,
            KEYCHAIN_CERTIFICATE_DER,
        ] {
            delete_entry(&keychain, account);
        }
    }
}
174
/// Encoding of a custom server certificate
// Variant names keep the conventional upper-case spelling of the certificate formats.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CertificateType {
    /// DER-encoded certificate
    DER,
    /// PEM-encoded certificate
    PEM,
}
181
182/// Asynchronous Malware DB Client Configuration and connection
183#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
184pub struct MdbClient {
185    /// URL of the Malware DB server, including http and port number, ending without a slash
186    pub url: String,
187
188    /// User's API key for Malware DB
189    api_key: String,
190
191    /// Async http client which stores the optional server certificate
192    #[zeroize(skip)]
193    #[serde(skip)]
194    client: reqwest::Client,
195
196    /// Server's certificate
197    #[cfg(target_os = "macos")]
198    #[zeroize(skip)]
199    #[serde(skip)]
200    cert: Option<macos::CertificateData>,
201}
202
203impl MdbClient {
204    /// MDB Client from components, doesn't test connectivity
205    ///
206    /// # Errors
207    ///
208    /// Returns an error if a list of certificates was passed and any were not in the expected
209    /// DER or PEM format or could not be parsed.
210    pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
211        let mut url = url;
212        let url = if url.ends_with('/') {
213            url.pop();
214            url
215        } else {
216            url
217        };
218
219        let cert = if let Some(path) = cert_path {
220            Some((path_load_cert(&path)?, path))
221        } else {
222            None
223        };
224
225        let builder = reqwest::ClientBuilder::new()
226            .gzip(true)
227            .zstd(true)
228            .use_rustls_tls()
229            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
230
231        let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
232            builder.add_root_certificate(cert.clone()).build()
233        } else {
234            builder.build()
235        }?;
236
237        #[cfg(target_os = "macos")]
238        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
239            Some(macos::CertificateData {
240                cert_type: *cert_type,
241                cert_bytes: std::fs::read(cert_path)?,
242            })
243        } else {
244            None
245        };
246
247        Ok(Self {
248            url,
249            api_key,
250            client,
251
252            #[cfg(target_os = "macos")]
253            cert,
254        })
255    }
256
257    /// Login to a server, optionally save the configuration file, and return a client object
258    ///
259    /// # Errors
260    ///
261    /// Returns an error if the server URL, username, or password were incorrect, or if a network
262    /// issue occurred.
263    pub async fn login(
264        url: String,
265        username: String,
266        password: String,
267        save: bool,
268        cert_path: Option<PathBuf>,
269    ) -> Result<Self> {
270        let mut url = url;
271        let url = if url.ends_with('/') {
272            url.pop();
273            url
274        } else {
275            url
276        };
277
278        let api_request = malwaredb_api::GetAPIKeyRequest {
279            user: username,
280            password,
281        };
282
283        let builder = reqwest::ClientBuilder::new()
284            .gzip(true)
285            .zstd(true)
286            .use_rustls_tls()
287            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
288
289        let cert = if let Some(path) = cert_path {
290            Some((path_load_cert(&path)?, path))
291        } else {
292            None
293        };
294
295        let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
296            builder.add_root_certificate(cert.clone()).build()
297        } else {
298            builder.build()
299        }?;
300
301        let res = client
302            .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
303            .json(&api_request)
304            .send()
305            .await?
306            .json::<ServerResponse<GetAPIKeyResponse>>()
307            .await
308            .context(MDB_CLIENT_ERROR_CONTEXT)?;
309
310        let res = match res {
311            ServerResponse::Success(res) => res,
312            ServerResponse::Error(err) => return Err(err.into()),
313        };
314
315        #[cfg(target_os = "macos")]
316        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
317            Some(macos::CertificateData {
318                cert_type: *cert_type,
319                cert_bytes: std::fs::read(cert_path)?,
320            })
321        } else {
322            None
323        };
324
325        let client = MdbClient {
326            url,
327            api_key: res.key.clone(),
328            client,
329
330            #[cfg(target_os = "macos")]
331            cert,
332        };
333
334        let server_info = client.server_info().await?;
335        if server_info.mdb_version > *MDB_VERSION_SEMVER {
336            warn!(
337                "Server version {:?} is newer than client {:?}, consider updating.",
338                server_info.mdb_version, MDB_VERSION_SEMVER
339            );
340        }
341
342        if save {
343            if let Err(e) = client.save() {
344                error!("Login successful but failed to save config: {e}");
345                bail!("Login successful but failed to save config: {e}");
346            }
347        }
348        Ok(client)
349    }
350
351    /// Reset one's own API key to effectively logout & disable all clients who are using the key
352    ///
353    /// # Errors
354    ///
355    /// Returns an error if there was a network issue or the user wasn't properly logged in.
356    pub async fn reset_key(&self) -> Result<()> {
357        let response = self
358            .client
359            .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
360            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
361            .send()
362            .await
363            .context(MDB_CLIENT_ERROR_CONTEXT)?;
364        if !response.status().is_success() {
365            bail!("failed to reset API key, was it correct?");
366        }
367        Ok(())
368    }
369
370    /// Malware DB Client configuration loaded from a specified path
371    ///
372    /// # Errors
373    ///
374    /// Returns an error if the configuration file cannot be read, possibly because it
375    /// doesn't exist or due to a permission error or a parsing error.
376    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
377        let name = path.as_ref().display();
378        let config =
379            std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
380        let cfg: MdbClient =
381            toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
382        Ok(cfg)
383    }
384
385    /// Malware DB Client configuration from user's home directory
386    ///
387    /// On macOS, it will attempt to load this information in the Keychain, which isn't required.
388    ///
389    /// # Errors
390    ///
391    /// Returns an error if the configuration file cannot be read, possibly because it
392    /// doesn't exist or due to a permission error or a parsing error.
393    pub fn load() -> Result<Self> {
394        #[cfg(target_os = "macos")]
395        {
396            if let Ok((api_key, url, cert)) = macos::retrieve_credentials() {
397                let builder = reqwest::ClientBuilder::new()
398                    .gzip(true)
399                    .zstd(true)
400                    .use_rustls_tls()
401                    .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
402
403                let client = if let Some(cert) = &cert {
404                    builder.add_root_certificate(cert.as_cert()?).build()
405                } else {
406                    builder.build()
407                }?;
408
409                return Ok(Self {
410                    url,
411                    api_key,
412                    client,
413                    cert,
414                });
415            }
416        }
417
418        let path = get_config_path(false)?;
419        if path.exists() {
420            return Self::from_file(path);
421        }
422        bail!("config file not found")
423    }
424
425    /// Save Malware DB Client configuration to the user's home directory.
426    ///
427    /// On macOS, it will attempt to save this information in the Keychain, which isn't required.
428    ///
429    /// # Errors
430    ///
431    /// Returns an error if there was a problem saving the configuration file.
432    pub fn save(&self) -> Result<()> {
433        #[cfg(target_os = "macos")]
434        {
435            if macos::save_credentials(&self.url, &self.api_key, None).is_ok() {
436                return Ok(());
437            }
438        }
439
440        let toml = toml::to_string(self)?;
441        let path = get_config_path(true)?;
442        std::fs::write(&path, toml)
443            .context(format!("failed to write mdb config to {}", path.display()))
444    }
445
446    /// Delete the Malware DB client configuration file
447    ///
448    /// # Errors
449    ///
450    /// Returns an error if there isn't a configuration file to delete, or if it cannot be deleted,
451    /// possibly due to a permissions error.
452    pub fn delete(&self) -> Result<()> {
453        #[cfg(target_os = "macos")]
454        macos::clear_credentials();
455
456        let path = get_config_path(false)?;
457        if path.exists() {
458            std::fs::remove_file(&path).context(format!(
459                "failed to delete client config file {}",
460                path.display()
461            ))?;
462        }
463        Ok(())
464    }
465
466    // Actions of the client
467
468    /// Get information about the server, unauthenticated
469    ///
470    /// # Errors
471    ///
472    /// This may return an error if there's a network situation.
473    pub async fn server_info(&self) -> Result<ServerInfo> {
474        let response = self
475            .client
476            .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
477            .send()
478            .await?
479            .json::<ServerResponse<ServerInfo>>()
480            .await
481            .context(MDB_CLIENT_ERROR_CONTEXT)?;
482
483        match response {
484            ServerResponse::Success(info) => Ok(info),
485            ServerResponse::Error(e) => Err(e.into()),
486        }
487    }
488
489    /// Get file types supported by the server, unauthenticated
490    ///
491    /// # Errors
492    ///
493    /// This may return an error if there's a network situation.
494    pub async fn supported_types(&self) -> Result<SupportedFileTypes> {
495        let response = self
496            .client
497            .get(format!(
498                "{}{}",
499                self.url,
500                malwaredb_api::SUPPORTED_FILE_TYPES_URL
501            ))
502            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
503            .send()
504            .await?
505            .json::<ServerResponse<SupportedFileTypes>>()
506            .await
507            .context(MDB_CLIENT_ERROR_CONTEXT)?;
508
509        match response {
510            ServerResponse::Success(types) => Ok(types),
511            ServerResponse::Error(e) => Err(e.into()),
512        }
513    }
514
515    /// Get information about the user
516    ///
517    /// # Errors
518    ///
519    /// This may return an error if there's a network situation or if the user is not logged in
520    /// or not properly authorized to connect.
521    pub async fn whoami(&self) -> Result<GetUserInfoResponse> {
522        let response = self
523            .client
524            .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
525            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
526            .send()
527            .await?
528            .json::<ServerResponse<GetUserInfoResponse>>()
529            .await
530            .context(MDB_CLIENT_ERROR_CONTEXT)?;
531
532        match response {
533            ServerResponse::Success(info) => Ok(info),
534            ServerResponse::Error(e) => Err(e.into()),
535        }
536    }
537
538    /// Get the sample labels known to the server
539    ///
540    /// # Errors
541    ///
542    /// This may return an error if there's a network situation or if the user is not logged in
543    /// or not properly authorized to connect.
544    pub async fn labels(&self) -> Result<Labels> {
545        let response = self
546            .client
547            .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
548            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
549            .send()
550            .await?
551            .json::<ServerResponse<Labels>>()
552            .await
553            .context(MDB_CLIENT_ERROR_CONTEXT)?;
554
555        match response {
556            ServerResponse::Success(labels) => Ok(labels),
557            ServerResponse::Error(e) => Err(e.into()),
558        }
559    }
560
561    /// Get the sources available to the current user
562    ///
563    /// # Errors
564    ///
565    /// This may return an error if there's a network situation or if the user is not logged in
566    /// or not properly authorized to connect.
567    pub async fn sources(&self) -> Result<Sources> {
568        let response = self
569            .client
570            .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
571            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
572            .send()
573            .await?
574            .json::<ServerResponse<Sources>>()
575            .await
576            .context(MDB_CLIENT_ERROR_CONTEXT)?;
577
578        match response {
579            ServerResponse::Success(sources) => Ok(sources),
580            ServerResponse::Error(e) => Err(e.into()),
581        }
582    }
583
584    /// Submit one file to `MalwareDB`: provide the contents, file name, and source ID
585    ///
586    /// # Errors
587    ///
588    /// This may return an error if there's a network situation or if the user is not logged in
589    /// or not properly authorized to connect.
590    pub async fn submit(
591        &self,
592        contents: impl AsRef<[u8]>,
593        file_name: impl AsRef<str>,
594        source_id: u32,
595    ) -> Result<bool> {
596        let mut hasher = Sha256::new();
597        hasher.update(&contents);
598        let result = hasher.finalize();
599
600        let encoded = general_purpose::STANDARD.encode(contents);
601
602        let payload = malwaredb_api::NewSampleB64 {
603            file_name: file_name.as_ref().to_string(),
604            source_id,
605            file_contents_b64: encoded,
606            sha256: hex::encode(result),
607        };
608
609        match self
610            .client
611            .post(format!(
612                "{}{}",
613                self.url,
614                malwaredb_api::UPLOAD_SAMPLE_JSON_URL
615            ))
616            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
617            .json(&payload)
618            .send()
619            .await
620        {
621            Ok(res) => {
622                if !res.status().is_success() {
623                    info!("Code {} sending {}", res.status(), payload.file_name);
624                }
625                Ok(res.status().is_success())
626            }
627            Err(e) => {
628                let status: String = e
629                    .status()
630                    .map(|s| s.as_str().to_string())
631                    .unwrap_or_default();
632                error!("Error{status} sending {}: {e}", payload.file_name);
633                bail!(e.to_string())
634            }
635        }
636    }
637
638    /// Submit one file to `MalwareDB` as a Cbor object: provide the contents, file name, and source ID
639    /// Experimental! May be removed at any point.
640    ///
641    /// # Errors
642    ///
643    /// This may return an error if there's a network situation or if the user is not logged in
644    /// or not properly authorized to connect.
645    pub async fn submit_as_cbor(
646        &self,
647        contents: impl AsRef<[u8]>,
648        file_name: impl AsRef<str>,
649        source_id: u32,
650    ) -> Result<bool> {
651        let mut hasher = Sha256::new();
652        hasher.update(&contents);
653        let result = hasher.finalize();
654
655        let payload = malwaredb_api::NewSampleBytes {
656            file_name: file_name.as_ref().to_string(),
657            source_id,
658            file_contents: contents.as_ref().to_vec(),
659            sha256: hex::encode(result),
660        };
661
662        let mut bytes = Vec::with_capacity(payload.file_contents.len());
663        ciborium::ser::into_writer(&payload, &mut bytes)?;
664
665        match self
666            .client
667            .post(format!(
668                "{}{}",
669                self.url,
670                malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
671            ))
672            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
673            .header("content-type", "application/cbor")
674            .body(bytes)
675            .send()
676            .await
677        {
678            Ok(res) => {
679                if !res.status().is_success() {
680                    info!("Code {} sending {}", res.status(), payload.file_name);
681                }
682                Ok(res.status().is_success())
683            }
684            Err(e) => {
685                let status: String = e
686                    .status()
687                    .map(|s| s.as_str().to_string())
688                    .unwrap_or_default();
689                error!("Error{status} sending {}: {e}", payload.file_name);
690                bail!(e.to_string())
691            }
692        }
693    }
694
695    /// Search for a file based on partial hash and/or partial file name, returns a list of hashes
696    ///
697    /// # Errors
698    ///
699    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
700    pub async fn partial_search(
701        &self,
702        partial_hash: Option<(PartialHashSearchType, String)>,
703        name: Option<String>,
704        response: PartialHashSearchType,
705        limit: u32,
706    ) -> Result<SearchResponse> {
707        let query = SearchRequest {
708            search: SearchType::Search(SearchRequestParameters {
709                partial_hash,
710                file_name: name,
711                response,
712                limit,
713                labels: None,
714                file_type: None,
715                magic: None,
716            }),
717        };
718
719        self.do_search_request(&query).await
720    }
721
722    /// Search for a file based on partial hash and/or partial file name, labels, file type; returns a list of hashes
723    ///
724    /// # Errors
725    ///
726    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
727    #[allow(clippy::too_many_arguments)]
728    pub async fn partial_search_labels_type(
729        &self,
730        partial_hash: Option<(PartialHashSearchType, String)>,
731        name: Option<String>,
732        response: PartialHashSearchType,
733        labels: Option<Vec<String>>,
734        file_type: Option<String>,
735        magic: Option<String>,
736        limit: u32,
737    ) -> Result<SearchResponse> {
738        let query = SearchRequest {
739            search: SearchType::Search(SearchRequestParameters {
740                partial_hash,
741                file_name: name,
742                response,
743                limit,
744                file_type,
745                magic,
746                labels,
747            }),
748        };
749
750        self.do_search_request(&query).await
751    }
752
753    /// Return the next page from the search result
754    ///
755    /// # Errors
756    ///
757    /// Returns an error if there is a network problem, or pagination not available
758    pub async fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
759        if let Some(uuid) = response.pagination {
760            let request = SearchRequest {
761                search: SearchType::Continuation(uuid),
762            };
763            return self.do_search_request(&request).await;
764        }
765
766        bail!("Pagination not available")
767    }
768
769    async fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
770        ensure!(
771            query.is_valid(),
772            "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
773        );
774
775        let response = self
776            .client
777            .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
778            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
779            .json(query)
780            .send()
781            .await?
782            .json::<ServerResponse<SearchResponse>>()
783            .await
784            .context(MDB_CLIENT_ERROR_CONTEXT)?;
785
786        match response {
787            ServerResponse::Success(search) => Ok(search),
788            ServerResponse::Error(e) => Err(e.into()),
789        }
790    }
791
792    /// Retrieve sample by hash, optionally in the `CaRT` format
793    ///
794    /// # Errors
795    ///
796    /// This may return an error if there's a network situation or if the user is not logged in
797    /// or not properly authorized to connect.
798    pub async fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
799        let api_endpoint = if cart {
800            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
801        } else {
802            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
803        };
804
805        let res = self
806            .client
807            .get(format!("{}{api_endpoint}", self.url))
808            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
809            .send()
810            .await?;
811
812        if !res.status().is_success() {
813            bail!("Received code {}", res.status());
814        }
815
816        let body = res.bytes().await?;
817        Ok(body.to_vec())
818    }
819
820    /// Fetch a report for a sample
821    ///
822    /// # Errors
823    ///
824    /// This may return an error if there's a network situation or if the user is not logged in
825    /// or not properly authorized to connect.
826    pub async fn report(&self, hash: &str) -> Result<Report> {
827        let response = self
828            .client
829            .get(format!(
830                "{}{}/{hash}",
831                self.url,
832                malwaredb_api::SAMPLE_REPORT_URL
833            ))
834            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
835            .send()
836            .await?
837            .json::<ServerResponse<Report>>()
838            .await
839            .context(MDB_CLIENT_ERROR_CONTEXT)?;
840
841        match response {
842            ServerResponse::Success(report) => Ok(report),
843            ServerResponse::Error(e) => Err(e.into()),
844        }
845    }
846
847    /// Find similar samples in `MalwareDB` based on the contents of a given file.
848    /// This does not submit the sample to `MalwareDB`.
849    ///
850    /// # Errors
851    ///
852    /// This may return an error if there's a network situation or if the user is not logged in
853    /// or not properly authorized to connect.
854    pub async fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
855        let mut hashes = vec![];
856        let ssdeep_hash = FuzzyHash::new(contents);
857
858        let build_hasher = Murmur3HashState::default();
859        let lzjd_str =
860            LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
861        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
862        hashes.push((
863            malwaredb_api::SimilarityHashType::SSDeep,
864            ssdeep_hash.to_string(),
865        ));
866
867        let mut builder = TlshBuilder::new(
868            tlsh_fixed::BucketKind::Bucket256,
869            tlsh_fixed::ChecksumKind::ThreeByte,
870            tlsh_fixed::Version::Version4,
871        );
872
873        builder.update(contents);
874        if let Ok(hasher) = builder.build() {
875            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
876        }
877
878        if let Ok(exe) = EXE::from(contents) {
879            if let Some(imports) = exe.imports {
880                hashes.push((
881                    malwaredb_api::SimilarityHashType::ImportHash,
882                    hex::encode(imports.hash()),
883                ));
884                hashes.push((
885                    malwaredb_api::SimilarityHashType::FuzzyImportHash,
886                    imports.fuzzy_hash(),
887                ));
888            }
889        }
890
891        let request = malwaredb_api::SimilarSamplesRequest { hashes };
892
893        let response = self
894            .client
895            .post(format!(
896                "{}{}",
897                self.url,
898                malwaredb_api::SIMILAR_SAMPLES_URL
899            ))
900            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
901            .json(&request)
902            .send()
903            .await?
904            .json::<ServerResponse<SimilarSamplesResponse>>()
905            .await
906            .context(MDB_CLIENT_ERROR_CONTEXT)?;
907
908        match response {
909            ServerResponse::Success(similar) => Ok(similar),
910            ServerResponse::Error(e) => Err(e.into()),
911        }
912    }
913}
914
915impl Debug for MdbClient {
916    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
917        writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
918    }
919}
920
921/// Convenience function for encoding bytes into a `CaRT` file using the default key. This also
922/// adds SHA-384 and SHA-512 hashes plus the entropy of the original file.
923/// See <https://github.com/CybercentreCanada/cart> for more information.
924///
925/// # Errors
926///
927/// There should not be any errors, but the underlying library can't guarantee that. However, any data
928/// passed to it is correct.
929pub fn encode_to_cart(data: &[u8]) -> Result<Vec<u8>> {
930    let mut input_buffer = Cursor::new(data);
931    let mut output_buffer = Cursor::new(vec![]);
932    let mut output_metadata = JsonMap::new();
933
934    let mut sha384 = Sha384::new();
935    sha384.update(data);
936    let sha384 = hex::encode(sha384.finalize());
937
938    let mut sha512 = Sha512::new();
939    sha512.update(data);
940    let sha512 = hex::encode(sha512.finalize());
941
942    output_metadata.insert("sha384".into(), sha384.into());
943    output_metadata.insert("sha512".into(), sha512.into());
944    output_metadata.insert("entropy".into(), entropy_calc(data).into());
945    cart_container::pack_stream(
946        &mut input_buffer,
947        &mut output_buffer,
948        Some(output_metadata),
949        None,
950        cart_container::digesters::default_digesters(),
951        None,
952    )?;
953
954    Ok(output_buffer.into_inner())
955}
956
957/// Convenience function for decoding a `CaRT` file using the default key, returning the bytes plus the
958/// optional header and footer metadata, if present.
959/// See <https://github.com/CybercentreCanada/cart> for more information.
960///
961/// # Errors
962///
963/// Returns an error if the file cannot be parsed or if this `CaRT` file didn't use the default key.
964/// <https://github.com/CybercentreCanada/cart-rs/blob/7ad548143bb85b64f364804e90cfada6c31cf902/cart_container/src/cipher.rs#L14-L17>
965pub fn decode_from_cart(data: &[u8]) -> Result<(Vec<u8>, Option<JsonMap>, Option<JsonMap>)> {
966    let mut input_buffer = Cursor::new(data);
967    let mut output_buffer = Cursor::new(vec![]);
968    let (header, footer) =
969        cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)?;
970    Ok((output_buffer.into_inner(), header, footer))
971}
972
973/// Load a certificate from a path
974///
975/// # Errors
976///
977/// Returns errors if the file cannot be read or if the file isn't an ASN.1 DER file or
978/// base64-encoded ASN.1 PEM file.
979fn path_load_cert(path: &Path) -> Result<(CertificateType, Certificate)> {
980    if !path.exists() {
981        bail!("Certificate {} does not exist.", path.display());
982    }
983    let cert = match path
984        .extension()
985        .context("can't determine file extension")?
986        .to_str()
987        .context("unable to parse file extension")?
988    {
989        "pem" => {
990            let contents = std::fs::read(path)?;
991            (CertificateType::PEM, Certificate::from_pem(&contents)?)
992        }
993        "der" => {
994            let contents = std::fs::read(path)?;
995            (CertificateType::DER, Certificate::from_der(&contents)?)
996        }
997        ext => {
998            bail!("Unknown extension {ext:?}")
999        }
1000    };
1001    Ok(cert)
1002}
1003
1004/// Gets the configuration file in the following order:
1005///
1006/// 1. Current working directory: `./mdb_client.toml`
1007/// 2. Haiku-specific directory if on Haiku
1008/// 3. XDG Free Desktop directory if on Unix
1009/// 4. The user's home directory in `~/.config/malwaredb_client/mdb_client.toml`
1010/// 5. `mdb_client.toml` in the current directory (same as the first, but after checking others)
1011#[inline]
1012pub(crate) fn get_config_path(create: bool) -> Result<PathBuf> {
1013    // If there is a config file in the current working directory, use it
1014    let config = PathBuf::from(MDB_CLIENT_CONFIG_TOML);
1015    if config.exists() {
1016        return Ok(config);
1017    }
1018
1019    #[cfg(target_os = "haiku")]
1020    {
1021        let mut settings = PathBuf::from("/boot/home/config/settings/malwaredb");
1022        if create && !settings.exists() {
1023            std::fs::create_dir_all(&settings)?;
1024        }
1025        settings.push(MDB_CLIENT_CONFIG_TOML);
1026        return Ok(settings);
1027    }
1028
1029    #[cfg(unix)]
1030    {
1031        // Obey the Free Desktop standard, check the variable
1032        if let Some(xdg_home) = std::env::var_os("XDG_CONFIG_HOME") {
1033            let mut xdg_config_home = PathBuf::from(xdg_home);
1034            xdg_config_home.push(MDB_CLIENT_DIR);
1035            if create && !xdg_config_home.exists() {
1036                std::fs::create_dir_all(&xdg_config_home)?;
1037            }
1038            xdg_config_home.push(MDB_CLIENT_CONFIG_TOML);
1039            return Ok(xdg_config_home);
1040        }
1041    }
1042
1043    if let Some(mut home_config) = home_dir() {
1044        home_config.push(".config");
1045        home_config.push(MDB_CLIENT_DIR);
1046        if create && !home_config.exists() {
1047            std::fs::create_dir_all(&home_config)?;
1048        }
1049        home_config.push(MDB_CLIENT_CONFIG_TOML);
1050        return Ok(home_config);
1051    }
1052
1053    Ok(PathBuf::from(MDB_CLIENT_CONFIG_TOML))
1054}
1055
/// Malware DB entries found by Multicast DNS (also known as Bonjour or Zeroconf)
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MalwareDBServer {
    /// Server IP or domain
    pub host: String,

    /// Server port
    pub port: u16,

    /// If the server expects an encrypted connection
    pub ssl: bool,

    /// Malware DB server name
    pub name: String,
}

impl Display for MalwareDBServer {
    /// Formats the server as a base URL: `https://host:port` when `ssl` is set, otherwise
    /// `http://host:port`.
    ///
    /// A host containing `:` is treated as an IPv6 literal and wrapped in brackets
    /// (`http://[fe80::1]:8080`), as URLs require; otherwise its colons would be ambiguous with
    /// the port separator. Hosts that are already bracketed are left as-is.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let scheme = if self.ssl { "https" } else { "http" };
        if self.host.contains(':') && !self.host.starts_with('[') {
            write!(f, "{scheme}://[{}]:{}", self.host, self.port)
        } else {
            write!(f, "{scheme}://{}:{}", self.host, self.port)
        }
    }
}
1081
1082impl MalwareDBServer {
1083    /// Retrieve details about the server
1084    ///
1085    /// # Errors
1086    ///
1087    /// An error will result if the server becomes unreachable or if a specific CA certificate is required
1088    pub async fn server_info(&self) -> Result<ServerInfo> {
1089        let client = reqwest::ClientBuilder::new()
1090            .gzip(true)
1091            .zstd(true)
1092            .use_rustls_tls()
1093            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
1094            .build()?;
1095
1096        let response = client
1097            .get(format!("{self}{}", malwaredb_api::SERVER_INFO_URL))
1098            .send()
1099            .await?
1100            .json::<ServerResponse<ServerInfo>>()
1101            .await
1102            .context(MDB_CLIENT_ERROR_CONTEXT)?;
1103
1104        match response {
1105            ServerResponse::Success(info) => Ok(info),
1106            ServerResponse::Error(e) => Err(e.into()),
1107        }
1108    }
1109
1110    /// Retrieve details about the server
1111    ///
1112    /// # Errors
1113    ///
1114    /// An error will result if the server becomes unreachable or if a specific CA certificate is required
1115    ///
1116    /// # Panics
1117    ///
1118    /// This method panics if called from within an async runtime.
1119    #[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
1120    #[cfg(feature = "blocking")]
1121    pub fn server_info_blocking(&self) -> Result<ServerInfo> {
1122        let client = reqwest::blocking::ClientBuilder::new()
1123            .gzip(true)
1124            .zstd(true)
1125            .use_rustls_tls()
1126            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
1127            .build()?;
1128
1129        let response = client
1130            .get(format!("{self}{}", malwaredb_api::SERVER_INFO_URL))
1131            .send()?
1132            .json::<ServerResponse<ServerInfo>>()
1133            .context(MDB_CLIENT_ERROR_CONTEXT)?;
1134
1135        match response {
1136            ServerResponse::Success(similar) => Ok(similar),
1137            ServerResponse::Error(e) => Err(e.into()),
1138        }
1139    }
1140}
1141
1142/// Find servers using Multicast DNS (also known as Bonjour)
1143///
1144/// # Errors
1145///
1146/// This may fail if there's a networking issue.
1147pub fn discover_servers() -> Result<Vec<MalwareDBServer>> {
1148    const MAX_ITERS: usize = 5;
1149    let mdns = ServiceDaemon::new()?;
1150    let mut servers = HashSet::new();
1151    let receiver = mdns.browse(malwaredb_api::MDNS_NAME)?;
1152
1153    let mut counter = 0;
1154    while let Ok(event) = receiver.recv() {
1155        if let ServiceEvent::ServiceResolved(resolved) = event {
1156            let host = resolved.host.replace(".local.", "");
1157            let ssl = if let Some(ssl) = resolved.txt_properties.get("ssl") {
1158                ssl.val_str() == "true"
1159            } else {
1160                debug!(
1161                    "MalwareDB entry for {host}:{} doesn't specify ssl, assuming not",
1162                    resolved.port
1163                );
1164                false
1165            };
1166
1167            let server = MalwareDBServer {
1168                host,
1169                port: resolved.port,
1170                ssl,
1171                name: resolved.fullname.replace(malwaredb_api::MDNS_NAME, ""),
1172            };
1173
1174            servers.insert(server);
1175        }
1176        counter += 1;
1177        if counter > MAX_ITERS {
1178            break;
1179        }
1180    }
1181
1182    Ok(servers.into_iter().collect())
1183}
1184
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a known `CaRT` fixture and checks the payload digest plus the header and footer
    /// metadata written by the encoder.
    #[test]
    fn cart() {
        const BYTES: &[u8] = include_bytes!("../../crates/types/testdata/elf/elf_haiku_x86.cart");
        const ORIGINAL_SHA256: &str =
            "de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740";

        let (decoded, header, footer) = decode_from_cart(BYTES).unwrap();
        assert_eq!(hex::encode(Sha256::digest(&decoded)), ORIGINAL_SHA256);

        let entropy = header.unwrap().get("entropy").unwrap().as_f64().unwrap();
        assert!(entropy > 4.0 && entropy < 4.1);

        let footer = footer.unwrap();
        assert_eq!(footer.get("length").unwrap(), "5093");
        assert_eq!(footer.get("sha256").unwrap(), ORIGINAL_SHA256);
    }
}