Skip to main content

malwaredb_client/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8#![forbid(unsafe_code)]
9
10/// Non-async version of the Malware DB client
11#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
12#[cfg(feature = "blocking")]
13pub mod blocking;
14
15pub use malwaredb_api;
16use malwaredb_api::{
17    digest::HashType, GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType,
18    Report, SearchRequest, SearchRequestParameters, SearchResponse, SearchType, ServerInfo,
19    ServerResponse, SimilarSamplesResponse, Sources, SupportedFileTypes, YaraSearchRequest,
20    YaraSearchRequestResponse, YaraSearchResponse,
21};
22use malwaredb_lzjd::{LZDict, Murmur3HashState};
23use malwaredb_types::exec::pe32::EXE;
24use malwaredb_types::utils::entropy_calc;
25
26use std::collections::HashSet;
27use std::fmt::{Debug, Display, Formatter};
28use std::io::Cursor;
29use std::path::{Path, PathBuf};
30use std::sync::LazyLock;
31
32use anyhow::{bail, ensure, Context, Result};
33use base64::engine::general_purpose;
34use base64::Engine;
35use cart_container::JsonMap;
36use fuzzyhash::FuzzyHash;
37use home::home_dir;
38use mdns_sd::{ServiceDaemon, ServiceEvent};
39use reqwest::Certificate;
40use serde::{Deserialize, Serialize};
41use sha2::{Digest, Sha256, Sha384, Sha512};
42use tlsh_fixed::TlshBuilder;
43use tracing::{debug, error, info, trace, warn};
44use uuid::Uuid;
45use zeroize::{Zeroize, ZeroizeOnDrop};
46
/// Local directory for the Malware DB client configs
const MDB_CLIENT_DIR: &str = "malwaredb_client";

/// Message passed to Anyhow's `context()` when a request fails, either because of a network
/// error or because the server's response could not be decoded
pub(crate) const MDB_CLIENT_ERROR_CONTEXT: &str =
    "Network error connecting to MalwareDB, or failure to decode server response.";

/// Config file name expected by Malware DB client
const MDB_CLIENT_CONFIG_TOML: &str = "mdb_client.toml";

/// MDB version, taken from this crate's `CARGO_PKG_VERSION` at compile time
pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");

/// MDB version as a semantic version object
// Parsed on first access; the `unwrap()` panics if `MDB_VERSION` is not valid semver.
pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
    LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
63
/// macOS Keychain functionality
#[cfg(target_os = "macos")]
pub(crate) mod macos {
    use crate::CertificateType;

    use anyhow::Result;
    use reqwest::Certificate;
    use security_framework::os::macos::keychain::SecKeychain;
    use tracing::error;

    /// Application identifier for macOS Keychain
    const KEYCHAIN_ID: &str = "malwaredb-client";

    /// Entry ID for the Malware DB server URL
    const KEYCHAIN_URL: &str = "URL";

    /// Entry ID for the user's Malware DB API Key
    const KEYCHAIN_API_KEY: &str = "API_KEY";

    /// Entry ID for the custom PEM-encoded certificate for the Malware DB server
    const KEYCHAIN_CERTIFICATE_PEM: &str = "CERT_PEM";

    /// Entry ID for the custom DER-encoded certificate for the Malware DB server
    const KEYCHAIN_CERTIFICATE_DER: &str = "CERT_DER";

    /// Raw bytes of a server certificate together with the encoding they are in
    #[derive(Clone)]
    pub(crate) struct CertificateData {
        /// Encoding of `cert_bytes`
        pub cert_type: CertificateType,
        /// Certificate contents exactly as stored (file or Keychain entry)
        pub cert_bytes: Vec<u8>,
    }

    impl CertificateData {
        /// Parse the stored bytes into a `reqwest` certificate according to `cert_type`.
        ///
        /// # Errors
        ///
        /// Returns an error if the bytes cannot be parsed in the recorded encoding.
        pub(crate) fn as_cert(&self) -> Result<Certificate> {
            Ok(match self.cert_type {
                CertificateType::PEM => Certificate::from_pem(&self.cert_bytes)?,
                CertificateType::DER => Certificate::from_der(&self.cert_bytes)?,
            })
        }
    }

    /// Keychain entry ID under which a certificate of the given encoding is stored
    fn certificate_entry(cert_type: CertificateType) -> &'static str {
        match cert_type {
            CertificateType::PEM => KEYCHAIN_CERTIFICATE_PEM,
            CertificateType::DER => KEYCHAIN_CERTIFICATE_DER,
        }
    }

    /// Save the elements to the Keychain, replacing any values stored previously.
    ///
    /// Certificate entries that do not match the certificate being saved (the other encoding,
    /// or both when `cert` is `None`) are removed so a stale certificate is not picked up by
    /// [`retrieve_credentials`].
    ///
    /// # Errors
    ///
    /// Returns an error if the default Keychain cannot be opened or an entry cannot be written.
    pub fn save_credentials(url: &str, key: &str, cert: Option<CertificateData>) -> Result<()> {
        let keychain = SecKeychain::default()?;

        // `set_generic_password` updates an existing entry or adds a new one, whereas
        // `add_generic_password` fails when the entry already exists (e.g. on a second login).
        keychain.set_generic_password(KEYCHAIN_ID, KEYCHAIN_URL, url.as_bytes())?;
        keychain.set_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY, key.as_bytes())?;

        let current_entry = cert.as_ref().map(|cert| certificate_entry(cert.cert_type));
        for entry in [KEYCHAIN_CERTIFICATE_PEM, KEYCHAIN_CERTIFICATE_DER] {
            if Some(entry) != current_entry {
                if let Ok((_, item)) = keychain.find_generic_password(KEYCHAIN_ID, entry) {
                    item.delete();
                }
            }
        }

        if let Some(cert) = cert {
            keychain.set_generic_password(
                KEYCHAIN_ID,
                certificate_entry(cert.cert_type),
                &cert.cert_bytes,
            )?;
        }

        Ok(())
    }

    /// Return key, url, and optionally the certificate in that order.
    ///
    /// The PEM certificate entry is checked before the DER entry. Failures while looking up a
    /// certificate are ignored and result in `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if the default Keychain cannot be opened, or if the API key or URL
    /// entry is missing or not valid UTF-8.
    pub fn retrieve_credentials() -> Result<(String, String, Option<CertificateData>)> {
        let keychain = SecKeychain::default()?;
        let (api_key, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY)?;
        let api_key = String::from_utf8(api_key.as_ref().to_vec())?;
        let (url, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_URL)?;
        let url = String::from_utf8(url.as_ref().to_vec())?;

        let cert = [CertificateType::PEM, CertificateType::DER]
            .into_iter()
            .find_map(|cert_type| {
                keychain
                    .find_generic_password(KEYCHAIN_ID, certificate_entry(cert_type))
                    .ok()
                    .map(|(cert, _item)| CertificateData {
                        cert_type,
                        cert_bytes: cert.to_vec(),
                    })
            });

        Ok((api_key, url, cert))
    }

    /// Delete Malware DB client information from the Keychain
    ///
    /// Entries which cannot be found are skipped; failure to open the Keychain is only logged.
    pub fn clear_credentials() {
        if let Ok(keychain) = SecKeychain::default() {
            for element in [
                KEYCHAIN_API_KEY,
                KEYCHAIN_URL,
                KEYCHAIN_CERTIFICATE_PEM,
                KEYCHAIN_CERTIFICATE_DER,
            ] {
                if let Ok((_, item)) = keychain.find_generic_password(KEYCHAIN_ID, element) {
                    item.delete();
                }
            }
        } else {
            error!("Failed to get access to the Keychain to clear credentials");
        }
    }
}
178
/// Encoding of a server certificate
// The variant names are acronyms, hence the lint exception below.
#[allow(clippy::upper_case_acronyms)]
#[derive(Copy, Clone, PartialEq, Eq)]
enum CertificateType {
    /// DER-encoded certificate
    DER,
    /// PEM-encoded certificate
    PEM,
}
185
/// Asynchronous Malware DB Client Configuration and connection
///
/// Only `url` and `api_key` are (de)serialized; the remaining fields are skipped by Serde and
/// excluded from zeroization.
#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
pub struct MdbClient {
    /// URL of the Malware DB server, including http and port number, ending without a slash
    pub url: String,

    /// User's API key for Malware DB
    api_key: String,

    /// Async http client which stores the optional server certificate
    // Skipped by Serde, so deserialization fills this with the type's `Default` value.
    #[zeroize(skip)]
    #[serde(skip)]
    client: reqwest::Client,

    /// Server's certificate
    // Only present on macOS, where it is kept so it can be saved to the Keychain.
    #[cfg(target_os = "macos")]
    #[zeroize(skip)]
    #[serde(skip)]
    cert: Option<macos::CertificateData>,
}
206
207impl MdbClient {
208    /// MDB Client from components, doesn't test connectivity
209    ///
210    /// # Errors
211    ///
212    /// Returns an error if a list of certificates was passed and any were not in the expected
213    /// DER or PEM format or could not be parsed.
214    pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
215        let mut url = url;
216        let url = if url.ends_with('/') {
217            url.pop();
218            url
219        } else {
220            url
221        };
222
223        let cert = if let Some(path) = cert_path {
224            Some((path_load_cert(&path)?, path))
225        } else {
226            None
227        };
228
229        let builder = reqwest::ClientBuilder::new()
230            .gzip(true)
231            .zstd(true)
232            .use_rustls_tls()
233            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
234
235        let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
236            builder.add_root_certificate(cert.clone()).build()
237        } else {
238            builder.build()
239        }?;
240
241        #[cfg(target_os = "macos")]
242        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
243            Some(macos::CertificateData {
244                cert_type: *cert_type,
245                cert_bytes: std::fs::read(cert_path)?,
246            })
247        } else {
248            None
249        };
250
251        Ok(Self {
252            url,
253            api_key,
254            client,
255
256            #[cfg(target_os = "macos")]
257            cert,
258        })
259    }
260
261    /// Login to a server, optionally save the configuration file, and return a client object
262    ///
263    /// # Errors
264    ///
265    /// Returns an error if the server URL, username, or password were incorrect, or if a network
266    /// issue occurred.
267    pub async fn login(
268        url: String,
269        username: String,
270        password: String,
271        save: bool,
272        cert_path: Option<PathBuf>,
273    ) -> Result<Self> {
274        let mut url = url;
275        let url = if url.ends_with('/') {
276            url.pop();
277            url
278        } else {
279            url
280        };
281
282        let api_request = malwaredb_api::GetAPIKeyRequest {
283            user: username,
284            password,
285        };
286
287        let builder = reqwest::ClientBuilder::new()
288            .gzip(true)
289            .zstd(true)
290            .use_rustls_tls()
291            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
292
293        let cert = if let Some(path) = cert_path {
294            Some((path_load_cert(&path)?, path))
295        } else {
296            None
297        };
298
299        let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
300            builder.add_root_certificate(cert.clone()).build()
301        } else {
302            builder.build()
303        }?;
304
305        let res = client
306            .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
307            .json(&api_request)
308            .send()
309            .await?
310            .json::<ServerResponse<GetAPIKeyResponse>>()
311            .await
312            .context(MDB_CLIENT_ERROR_CONTEXT)?;
313
314        let res = match res {
315            ServerResponse::Success(res) => res,
316            ServerResponse::Error(err) => return Err(err.into()),
317        };
318
319        #[cfg(target_os = "macos")]
320        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
321            Some(macos::CertificateData {
322                cert_type: *cert_type,
323                cert_bytes: std::fs::read(cert_path)?,
324            })
325        } else {
326            None
327        };
328
329        let client = MdbClient {
330            url,
331            api_key: res.key.clone(),
332            client,
333
334            #[cfg(target_os = "macos")]
335            cert,
336        };
337
338        let server_info = client.server_info().await?;
339        if server_info.mdb_version > *MDB_VERSION_SEMVER {
340            warn!(
341                "Server version {:?} is newer than client {:?}, consider updating.",
342                server_info.mdb_version, MDB_VERSION_SEMVER
343            );
344        }
345
346        if save {
347            if let Err(e) = client.save() {
348                error!("Login successful but failed to save config: {e}");
349                bail!("Login successful but failed to save config: {e}");
350            }
351        }
352        Ok(client)
353    }
354
355    /// Reset one's own API key to effectively logout & disable all clients who are using the key
356    ///
357    /// # Errors
358    ///
359    /// Returns an error if there was a network issue or the user wasn't properly logged in.
360    pub async fn reset_key(&self) -> Result<()> {
361        let response = self
362            .client
363            .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
364            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
365            .send()
366            .await
367            .context(MDB_CLIENT_ERROR_CONTEXT)?;
368        if !response.status().is_success() {
369            bail!("failed to reset API key, was it correct?");
370        }
371        Ok(())
372    }
373
374    /// Malware DB Client configuration loaded from a specified path
375    ///
376    /// # Errors
377    ///
378    /// Returns an error if the configuration file cannot be read, possibly because it
379    /// doesn't exist or due to a permission error or a parsing error.
380    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
381        let name = path.as_ref().display();
382        let config =
383            std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
384        let cfg: MdbClient =
385            toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
386        Ok(cfg)
387    }
388
389    /// Malware DB Client configuration from user's home directory
390    ///
391    /// On macOS, it will attempt to load this information in the Keychain, which isn't required.
392    ///
393    /// # Errors
394    ///
395    /// Returns an error if the configuration file cannot be read, possibly because it
396    /// doesn't exist or due to a permission error or a parsing error.
397    pub fn load() -> Result<Self> {
398        #[cfg(target_os = "macos")]
399        {
400            if let Ok((api_key, url, cert)) = macos::retrieve_credentials() {
401                let builder = reqwest::ClientBuilder::new()
402                    .gzip(true)
403                    .zstd(true)
404                    .use_rustls_tls()
405                    .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
406
407                let client = if let Some(cert) = &cert {
408                    builder.add_root_certificate(cert.as_cert()?).build()
409                } else {
410                    builder.build()
411                }?;
412
413                return Ok(Self {
414                    url,
415                    api_key,
416                    client,
417                    cert,
418                });
419            }
420        }
421
422        let path = get_config_path(false)?;
423        if path.exists() {
424            return Self::from_file(path);
425        }
426        bail!("config file not found")
427    }
428
429    /// Save Malware DB Client configuration to the user's home directory.
430    ///
431    /// On macOS, it will attempt to save this information in the Keychain, which isn't required.
432    ///
433    /// # Errors
434    ///
435    /// Returns an error if there was a problem saving the configuration file.
436    pub fn save(&self) -> Result<()> {
437        #[cfg(target_os = "macos")]
438        {
439            if macos::save_credentials(&self.url, &self.api_key, self.cert.clone()).is_ok() {
440                return Ok(());
441            }
442        }
443
444        let toml = toml::to_string(self)?;
445        let path = get_config_path(true)?;
446        std::fs::write(&path, toml)
447            .context(format!("failed to write mdb config to {}", path.display()))
448    }
449
450    /// Delete the Malware DB client configuration file
451    ///
452    /// # Errors
453    ///
454    /// Returns an error if there isn't a configuration file to delete, or if it cannot be deleted,
455    /// possibly due to a permissions error.
456    pub fn delete(&self) -> Result<()> {
457        #[cfg(target_os = "macos")]
458        macos::clear_credentials();
459
460        let path = get_config_path(false)?;
461        if path.exists() {
462            std::fs::remove_file(&path).context(format!(
463                "failed to delete client config file {}",
464                path.display()
465            ))?;
466        }
467        Ok(())
468    }
469
470    // Actions of the client
471
472    /// Get information about the server, unauthenticated
473    ///
474    /// # Errors
475    ///
476    /// This may return an error if there's a network situation.
477    pub async fn server_info(&self) -> Result<ServerInfo> {
478        let response = self
479            .client
480            .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
481            .send()
482            .await?
483            .json::<ServerResponse<ServerInfo>>()
484            .await
485            .context(MDB_CLIENT_ERROR_CONTEXT)?;
486
487        match response {
488            ServerResponse::Success(info) => Ok(info),
489            ServerResponse::Error(e) => Err(e.into()),
490        }
491    }
492
493    /// Get file types supported by the server, unauthenticated
494    ///
495    /// # Errors
496    ///
497    /// This may return an error if there's a network situation.
498    pub async fn supported_types(&self) -> Result<SupportedFileTypes> {
499        let response = self
500            .client
501            .get(format!(
502                "{}{}",
503                self.url,
504                malwaredb_api::SUPPORTED_FILE_TYPES_URL
505            ))
506            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
507            .send()
508            .await?
509            .json::<ServerResponse<SupportedFileTypes>>()
510            .await
511            .context(MDB_CLIENT_ERROR_CONTEXT)?;
512
513        match response {
514            ServerResponse::Success(types) => Ok(types),
515            ServerResponse::Error(e) => Err(e.into()),
516        }
517    }
518
519    /// Get information about the user
520    ///
521    /// # Errors
522    ///
523    /// This may return an error if there's a network situation or if the user is not logged in
524    /// or not properly authorized to connect.
525    pub async fn whoami(&self) -> Result<GetUserInfoResponse> {
526        let response = self
527            .client
528            .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
529            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
530            .send()
531            .await?
532            .json::<ServerResponse<GetUserInfoResponse>>()
533            .await
534            .context(MDB_CLIENT_ERROR_CONTEXT)?;
535
536        match response {
537            ServerResponse::Success(info) => Ok(info),
538            ServerResponse::Error(e) => Err(e.into()),
539        }
540    }
541
542    /// Get the sample labels known to the server
543    ///
544    /// # Errors
545    ///
546    /// This may return an error if there's a network situation or if the user is not logged in
547    /// or not properly authorized to connect.
548    pub async fn labels(&self) -> Result<Labels> {
549        let response = self
550            .client
551            .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
552            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
553            .send()
554            .await?
555            .json::<ServerResponse<Labels>>()
556            .await
557            .context(MDB_CLIENT_ERROR_CONTEXT)?;
558
559        match response {
560            ServerResponse::Success(labels) => Ok(labels),
561            ServerResponse::Error(e) => Err(e.into()),
562        }
563    }
564
565    /// Get the sources available to the current user
566    ///
567    /// # Errors
568    ///
569    /// This may return an error if there's a network situation or if the user is not logged in
570    /// or not properly authorized to connect.
571    pub async fn sources(&self) -> Result<Sources> {
572        let response = self
573            .client
574            .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
575            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
576            .send()
577            .await?
578            .json::<ServerResponse<Sources>>()
579            .await
580            .context(MDB_CLIENT_ERROR_CONTEXT)?;
581
582        match response {
583            ServerResponse::Success(sources) => Ok(sources),
584            ServerResponse::Error(e) => Err(e.into()),
585        }
586    }
587
588    /// Submit one file to `MalwareDB`: provide the contents, file name, and source ID
589    ///
590    /// # Errors
591    ///
592    /// This may return an error if there's a network situation or if the user is not logged in
593    /// or not properly authorized to connect.
594    pub async fn submit(
595        &self,
596        contents: impl AsRef<[u8]>,
597        file_name: impl AsRef<str>,
598        source_id: u32,
599    ) -> Result<bool> {
600        let mut hasher = Sha256::new();
601        hasher.update(&contents);
602        let result = hasher.finalize();
603
604        let encoded = general_purpose::STANDARD.encode(contents);
605
606        let payload = malwaredb_api::NewSampleB64 {
607            file_name: file_name.as_ref().to_string(),
608            source_id,
609            file_contents_b64: encoded,
610            sha256: hex::encode(result),
611        };
612
613        match self
614            .client
615            .post(format!(
616                "{}{}",
617                self.url,
618                malwaredb_api::UPLOAD_SAMPLE_JSON_URL
619            ))
620            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
621            .json(&payload)
622            .send()
623            .await
624        {
625            Ok(res) => {
626                if !res.status().is_success() {
627                    info!("Code {} sending {}", res.status(), payload.file_name);
628                }
629                Ok(res.status().is_success())
630            }
631            Err(e) => {
632                let status: String = e
633                    .status()
634                    .map(|s| s.as_str().to_string())
635                    .unwrap_or_default();
636                error!("Error{status} sending {}: {e}", payload.file_name);
637                bail!(e.to_string())
638            }
639        }
640    }
641
642    /// Submit one file to `MalwareDB` as a Cbor object: provide the contents, file name, and source ID
643    /// Experimental! May be removed at any point.
644    ///
645    /// # Errors
646    ///
647    /// This may return an error if there's a network situation or if the user is not logged in
648    /// or not properly authorized to connect.
649    pub async fn submit_as_cbor(
650        &self,
651        contents: impl AsRef<[u8]>,
652        file_name: impl AsRef<str>,
653        source_id: u32,
654    ) -> Result<bool> {
655        let mut hasher = Sha256::new();
656        hasher.update(&contents);
657        let result = hasher.finalize();
658
659        let payload = malwaredb_api::NewSampleBytes {
660            file_name: file_name.as_ref().to_string(),
661            source_id,
662            file_contents: contents.as_ref().to_vec(),
663            sha256: hex::encode(result),
664        };
665
666        let mut bytes = Vec::with_capacity(payload.file_contents.len());
667        ciborium::ser::into_writer(&payload, &mut bytes)?;
668
669        match self
670            .client
671            .post(format!(
672                "{}{}",
673                self.url,
674                malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
675            ))
676            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
677            .header("content-type", "application/cbor")
678            .body(bytes)
679            .send()
680            .await
681        {
682            Ok(res) => {
683                if !res.status().is_success() {
684                    info!("Code {} sending {}", res.status(), payload.file_name);
685                }
686                Ok(res.status().is_success())
687            }
688            Err(e) => {
689                let status: String = e
690                    .status()
691                    .map(|s| s.as_str().to_string())
692                    .unwrap_or_default();
693                error!("Error{status} sending {}: {e}", payload.file_name);
694                bail!(e.to_string())
695            }
696        }
697    }
698
699    /// Search for a file based on partial hash and/or partial file name, returns a list of hashes
700    ///
701    /// # Errors
702    ///
703    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
704    pub async fn partial_search(
705        &self,
706        partial_hash: Option<(PartialHashSearchType, String)>,
707        name: Option<String>,
708        response: PartialHashSearchType,
709        limit: u32,
710    ) -> Result<SearchResponse> {
711        let query = SearchRequest {
712            search: SearchType::Search(SearchRequestParameters {
713                partial_hash,
714                file_name: name,
715                response,
716                limit,
717                labels: None,
718                file_type: None,
719                magic: None,
720            }),
721        };
722
723        self.do_search_request(&query).await
724    }
725
726    /// Search for a file based on partial hash and/or partial file name, labels, file type; returns a list of hashes
727    ///
728    /// # Errors
729    ///
730    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
731    #[allow(clippy::too_many_arguments)]
732    pub async fn partial_search_labels_type(
733        &self,
734        partial_hash: Option<(PartialHashSearchType, String)>,
735        name: Option<String>,
736        response: PartialHashSearchType,
737        labels: Option<Vec<String>>,
738        file_type: Option<String>,
739        magic: Option<String>,
740        limit: u32,
741    ) -> Result<SearchResponse> {
742        let query = SearchRequest {
743            search: SearchType::Search(SearchRequestParameters {
744                partial_hash,
745                file_name: name,
746                response,
747                limit,
748                file_type,
749                magic,
750                labels,
751            }),
752        };
753
754        self.do_search_request(&query).await
755    }
756
757    /// Return the next page from the search result
758    ///
759    /// # Errors
760    ///
761    /// Returns an error if there is a network problem, or pagination not available
762    pub async fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
763        if let Some(uuid) = response.pagination {
764            let request = SearchRequest {
765                search: SearchType::Continuation(uuid),
766            };
767            return self.do_search_request(&request).await;
768        }
769
770        bail!("Pagination not available")
771    }
772
773    async fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
774        ensure!(
775            query.is_valid(),
776            "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
777        );
778
779        let response = self
780            .client
781            .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
782            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
783            .json(query)
784            .send()
785            .await?
786            .json::<ServerResponse<SearchResponse>>()
787            .await
788            .context(MDB_CLIENT_ERROR_CONTEXT)?;
789
790        match response {
791            ServerResponse::Success(search) => Ok(search),
792            ServerResponse::Error(e) => Err(e.into()),
793        }
794    }
795
796    /// Retrieve sample by hash, optionally in the `CaRT` format
797    ///
798    /// # Errors
799    ///
800    /// This may return an error if there's a network situation or if the user is not logged in
801    /// or not properly authorized to connect.
802    pub async fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
803        let api_endpoint = if cart {
804            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
805        } else {
806            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
807        };
808
809        let res = self
810            .client
811            .get(format!("{}{api_endpoint}", self.url))
812            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
813            .send()
814            .await?;
815
816        if !res.status().is_success() {
817            bail!("Received code {}", res.status());
818        }
819
820        let content_digest = res.headers().get("content-digest").map(ToOwned::to_owned);
821        let body = res.bytes().await?;
822        let bytes = body.to_vec();
823
824        // TODO: Make this required in v0.3
825        if let Some(digest) = content_digest {
826            let hash = HashType::from_content_digest_header(digest.to_str()?)?;
827            if hash.verify(&bytes) {
828                trace!("Hash verified for sample {hash}");
829            } else {
830                error!("Hash mismatch for sample {hash}");
831            }
832        } else {
833            warn!("No content digest header received for sample {hash}");
834        }
835
836        Ok(bytes)
837    }
838
839    /// Fetch a report for a sample
840    ///
841    /// # Errors
842    ///
843    /// This may return an error if there's a network situation or if the user is not logged in
844    /// or not properly authorized to connect.
845    pub async fn report(&self, hash: &str) -> Result<Report> {
846        let response = self
847            .client
848            .get(format!(
849                "{}{}/{hash}",
850                self.url,
851                malwaredb_api::SAMPLE_REPORT_URL
852            ))
853            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
854            .send()
855            .await?
856            .json::<ServerResponse<Report>>()
857            .await
858            .context(MDB_CLIENT_ERROR_CONTEXT)?;
859
860        match response {
861            ServerResponse::Success(report) => Ok(report),
862            ServerResponse::Error(e) => Err(e.into()),
863        }
864    }
865
866    /// Find similar samples in `MalwareDB` based on the contents of a given file.
867    /// This does not submit the sample to `MalwareDB`.
868    ///
869    /// # Errors
870    ///
871    /// This may return an error if there's a network situation or if the user is not logged in
872    /// or not properly authorized to connect.
873    pub async fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
874        let mut hashes = vec![];
875        let ssdeep_hash = FuzzyHash::new(contents);
876
877        let build_hasher = Murmur3HashState::default();
878        let lzjd_str =
879            LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
880        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
881        hashes.push((
882            malwaredb_api::SimilarityHashType::SSDeep,
883            ssdeep_hash.to_string(),
884        ));
885
886        let mut builder = TlshBuilder::new(
887            tlsh_fixed::BucketKind::Bucket256,
888            tlsh_fixed::ChecksumKind::ThreeByte,
889            tlsh_fixed::Version::Version4,
890        );
891
892        builder.update(contents);
893        if let Ok(hasher) = builder.build() {
894            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
895        }
896
897        if let Ok(exe) = EXE::from(contents) {
898            if let Some(imports) = exe.imports {
899                hashes.push((
900                    malwaredb_api::SimilarityHashType::ImportHash,
901                    hex::encode(imports.hash()),
902                ));
903                hashes.push((
904                    malwaredb_api::SimilarityHashType::FuzzyImportHash,
905                    imports.fuzzy_hash(),
906                ));
907            }
908        }
909
910        let request = malwaredb_api::SimilarSamplesRequest { hashes };
911
912        let response = self
913            .client
914            .post(format!(
915                "{}{}",
916                self.url,
917                malwaredb_api::SIMILAR_SAMPLES_URL
918            ))
919            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
920            .json(&request)
921            .send()
922            .await?
923            .json::<ServerResponse<SimilarSamplesResponse>>()
924            .await
925            .context(MDB_CLIENT_ERROR_CONTEXT)?;
926
927        match response {
928            ServerResponse::Success(similar) => Ok(similar),
929            ServerResponse::Error(e) => Err(e.into()),
930        }
931    }
932
933    /// Submit a Yara rule and return the UUID of the search for later retrieval.
934    ///
935    /// # Errors
936    ///
937    /// Network or authentication errors
938    pub async fn yara_search(&self, yara: &str) -> Result<YaraSearchRequestResponse> {
939        let yara = YaraSearchRequest {
940            rules: vec![yara.to_string()],
941            response: PartialHashSearchType::SHA256,
942        };
943
944        let response = self
945            .client
946            .post(format!("{}{}", self.url, malwaredb_api::YARA_SEARCH_URL))
947            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
948            .json(&yara)
949            .send()
950            .await?
951            .json::<ServerResponse<YaraSearchRequestResponse>>()
952            .await?;
953
954        match response {
955            ServerResponse::Success(sources) => Ok(sources),
956            ServerResponse::Error(e) => Err(e.into()),
957        }
958    }
959
960    /// Get the result from a Yara search
961    ///
962    /// # Errors
963    ///
964    /// Network or authentication errors
965    pub async fn yara_result(&self, uuid: Uuid) -> Result<YaraSearchResponse> {
966        let response = self
967            .client
968            .get(format!(
969                "{}{}/{uuid}",
970                self.url,
971                malwaredb_api::YARA_SEARCH_URL
972            ))
973            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
974            .send()
975            .await?
976            .json::<ServerResponse<YaraSearchResponse>>()
977            .await?;
978
979        match response {
980            ServerResponse::Success(sources) => Ok(sources),
981            ServerResponse::Error(e) => Err(e.into()),
982        }
983    }
984}
985
986impl Debug for MdbClient {
987    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
988        writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
989    }
990}
991
992/// Convenience function for encoding bytes into a `CaRT` file using the default key. This also
993/// adds SHA-384 and SHA-512 hashes plus the entropy of the original file.
994/// See <https://github.com/CybercentreCanada/cart> for more information.
995///
996/// # Errors
997///
998/// There should not be any errors, but the underlying library can't guarantee that. However, any data
999/// passed to it is correct.
1000pub fn encode_to_cart(data: &[u8]) -> Result<Vec<u8>> {
1001    let mut input_buffer = Cursor::new(data);
1002    let mut output_buffer = Cursor::new(vec![]);
1003    let mut output_metadata = JsonMap::new();
1004
1005    let mut sha384 = Sha384::new();
1006    sha384.update(data);
1007    let sha384 = hex::encode(sha384.finalize());
1008
1009    let mut sha512 = Sha512::new();
1010    sha512.update(data);
1011    let sha512 = hex::encode(sha512.finalize());
1012
1013    output_metadata.insert("sha384".into(), sha384.into());
1014    output_metadata.insert("sha512".into(), sha512.into());
1015    output_metadata.insert("entropy".into(), entropy_calc(data).into());
1016    cart_container::pack_stream(
1017        &mut input_buffer,
1018        &mut output_buffer,
1019        Some(output_metadata),
1020        None,
1021        cart_container::digesters::default_digesters(),
1022        None,
1023    )?;
1024
1025    Ok(output_buffer.into_inner())
1026}
1027
1028/// Convenience function for decoding a `CaRT` file using the default key, returning the bytes plus the
1029/// optional header and footer metadata, if present.
1030/// See <https://github.com/CybercentreCanada/cart> for more information.
1031///
1032/// # Errors
1033///
1034/// Returns an error if the file cannot be parsed or if this `CaRT` file didn't use the default key.
1035/// <https://github.com/CybercentreCanada/cart-rs/blob/7ad548143bb85b64f364804e90cfada6c31cf902/cart_container/src/cipher.rs#L14-L17>
1036pub fn decode_from_cart(data: &[u8]) -> Result<(Vec<u8>, Option<JsonMap>, Option<JsonMap>)> {
1037    let mut input_buffer = Cursor::new(data);
1038    let mut output_buffer = Cursor::new(vec![]);
1039    let (header, footer) =
1040        cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)?;
1041    Ok((output_buffer.into_inner(), header, footer))
1042}
1043
1044/// Load a certificate from a path
1045///
1046/// # Errors
1047///
1048/// Returns errors if the file cannot be read or if the file isn't an ASN.1 DER file or
1049/// base64-encoded ASN.1 PEM file.
1050fn path_load_cert(path: &Path) -> Result<(CertificateType, Certificate)> {
1051    if !path.exists() {
1052        bail!("Certificate {} does not exist.", path.display());
1053    }
1054    let cert = match path
1055        .extension()
1056        .context("can't determine file extension")?
1057        .to_str()
1058        .context("unable to parse file extension")?
1059    {
1060        "pem" => {
1061            let contents = std::fs::read(path)?;
1062            (CertificateType::PEM, Certificate::from_pem(&contents)?)
1063        }
1064        "der" => {
1065            let contents = std::fs::read(path)?;
1066            (CertificateType::DER, Certificate::from_der(&contents)?)
1067        }
1068        ext => {
1069            bail!("Unknown extension {ext:?}")
1070        }
1071    };
1072    Ok(cert)
1073}
1074
1075/// Gets the configuration file in the following order:
1076///
1077/// 1. Current working directory: `./mdb_client.toml`
1078/// 2. Haiku-specific directory if on Haiku
1079/// 3. XDG Free Desktop directory if on Unix
1080/// 4. The user's home directory in `~/.config/malwaredb_client/mdb_client.toml`
1081/// 5. `mdb_client.toml` in the current directory (same as the first, but after checking others)
1082#[inline]
1083pub(crate) fn get_config_path(create: bool) -> Result<PathBuf> {
1084    // If there is a config file in the current working directory, use it
1085    let config = PathBuf::from(MDB_CLIENT_CONFIG_TOML);
1086    if config.exists() {
1087        return Ok(config);
1088    }
1089
1090    #[cfg(target_os = "haiku")]
1091    {
1092        let mut settings = PathBuf::from("/boot/home/config/settings/malwaredb");
1093        if create && !settings.exists() {
1094            std::fs::create_dir_all(&settings)?;
1095        }
1096        settings.push(MDB_CLIENT_CONFIG_TOML);
1097        return Ok(settings);
1098    }
1099
1100    #[cfg(unix)]
1101    {
1102        // Obey the Free Desktop standard, check the variable
1103        if let Some(xdg_home) = std::env::var_os("XDG_CONFIG_HOME") {
1104            let mut xdg_config_home = PathBuf::from(xdg_home);
1105            xdg_config_home.push(MDB_CLIENT_DIR);
1106            if create && !xdg_config_home.exists() {
1107                std::fs::create_dir_all(&xdg_config_home)?;
1108            }
1109            xdg_config_home.push(MDB_CLIENT_CONFIG_TOML);
1110            return Ok(xdg_config_home);
1111        }
1112    }
1113
1114    if let Some(mut home_config) = home_dir() {
1115        home_config.push(".config");
1116        home_config.push(MDB_CLIENT_DIR);
1117        if create && !home_config.exists() {
1118            std::fs::create_dir_all(&home_config)?;
1119        }
1120        home_config.push(MDB_CLIENT_CONFIG_TOML);
1121        return Ok(home_config);
1122    }
1123
1124    Ok(PathBuf::from(MDB_CLIENT_CONFIG_TOML))
1125}
1126
/// A Malware DB server entry found by Multicast DNS (also known as Bonjour or Zeroconf).
///
/// The `Display` form of an entry is its base URL: `http://host:port`, or `https://host:port`
/// when `ssl` is set.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MalwareDBServer {
    /// Server IP or domain; `discover_servers` removes `.local.` from the advertised host
    pub host: String,

    /// Server port
    pub port: u16,

    /// If the server expects an encrypted connection; selects `https` over `http` in the URL
    pub ssl: bool,

    /// Malware DB server name; `discover_servers` removes the mDNS service name from it
    pub name: String,
}
1142
1143impl Display for MalwareDBServer {
1144    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1145        if self.ssl {
1146            write!(f, "https://{}:{}", self.host, self.port)
1147        } else {
1148            write!(f, "http://{}:{}", self.host, self.port)
1149        }
1150    }
1151}
1152
1153impl MalwareDBServer {
1154    /// Retrieve details about the server
1155    ///
1156    /// # Errors
1157    ///
1158    /// An error will result if the server becomes unreachable or if a specific CA certificate is required
1159    pub async fn server_info(&self) -> Result<ServerInfo> {
1160        let client = reqwest::ClientBuilder::new()
1161            .gzip(true)
1162            .zstd(true)
1163            .use_rustls_tls()
1164            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
1165            .build()?;
1166
1167        let response = client
1168            .get(format!("{self}{}", malwaredb_api::SERVER_INFO_URL))
1169            .send()
1170            .await?
1171            .json::<ServerResponse<ServerInfo>>()
1172            .await
1173            .context(MDB_CLIENT_ERROR_CONTEXT)?;
1174
1175        match response {
1176            ServerResponse::Success(info) => Ok(info),
1177            ServerResponse::Error(e) => Err(e.into()),
1178        }
1179    }
1180
1181    /// Retrieve details about the server
1182    ///
1183    /// # Errors
1184    ///
1185    /// An error will result if the server becomes unreachable or if a specific CA certificate is required
1186    ///
1187    /// # Panics
1188    ///
1189    /// This method panics if called from within an async runtime.
1190    #[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
1191    #[cfg(feature = "blocking")]
1192    pub fn server_info_blocking(&self) -> Result<ServerInfo> {
1193        let client = reqwest::blocking::ClientBuilder::new()
1194            .gzip(true)
1195            .zstd(true)
1196            .use_rustls_tls()
1197            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
1198            .build()?;
1199
1200        let response = client
1201            .get(format!("{self}{}", malwaredb_api::SERVER_INFO_URL))
1202            .send()?
1203            .json::<ServerResponse<ServerInfo>>()
1204            .context(MDB_CLIENT_ERROR_CONTEXT)?;
1205
1206        match response {
1207            ServerResponse::Success(similar) => Ok(similar),
1208            ServerResponse::Error(e) => Err(e.into()),
1209        }
1210    }
1211}
1212
1213/// Find servers using Multicast DNS (also known as Bonjour)
1214///
1215/// # Errors
1216///
1217/// This may fail if there's a networking issue.
1218pub fn discover_servers() -> Result<Vec<MalwareDBServer>> {
1219    const MAX_ITERS: usize = 5;
1220    let mdns = ServiceDaemon::new()?;
1221    let mut servers = HashSet::new();
1222    let receiver = mdns.browse(malwaredb_api::MDNS_NAME)?;
1223
1224    let mut counter = 0;
1225    while let Ok(event) = receiver.recv() {
1226        if let ServiceEvent::ServiceResolved(resolved) = event {
1227            let host = resolved.host.replace(".local.", "");
1228            let ssl = if let Some(ssl) = resolved.txt_properties.get("ssl") {
1229                ssl.val_str() == "true"
1230            } else {
1231                debug!(
1232                    "MalwareDB entry for {host}:{} doesn't specify ssl, assuming not",
1233                    resolved.port
1234                );
1235                false
1236            };
1237
1238            let server = MalwareDBServer {
1239                host,
1240                port: resolved.port,
1241                ssl,
1242                name: resolved.fullname.replace(malwaredb_api::MDNS_NAME, ""),
1243            };
1244
1245            servers.insert(server);
1246        }
1247        counter += 1;
1248        if counter > MAX_ITERS {
1249            break;
1250        }
1251    }
1252
1253    Ok(servers.into_iter().collect())
1254}
1255
#[cfg(test)]
mod tests {
    use super::*;

    /// Decoding a known `CaRT` fixture must yield the original bytes along with the expected
    /// header and footer metadata.
    #[test]
    fn cart() {
        const ORIGINAL_SHA256: &str =
            "de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740";
        const BYTES: &[u8] = include_bytes!("../../crates/types/testdata/elf/elf_haiku_x86.cart");

        let (decoded, header, footer) = decode_from_cart(BYTES).unwrap();

        // The payload must hash back to the original sample.
        let digest = hex::encode(Sha256::digest(&decoded));
        assert_eq!(digest, ORIGINAL_SHA256);

        // Header metadata carries the entropy of the original file.
        let header = header.unwrap();
        let entropy = header.get("entropy").unwrap().as_f64().unwrap();
        assert!(entropy > 4.0 && entropy < 4.1);

        // Footer metadata carries the length and SHA-256 of the original file.
        let footer = footer.unwrap();
        assert_eq!(footer.get("length").unwrap(), "5093");
        assert_eq!(footer.get("sha256").unwrap(), ORIGINAL_SHA256);
    }
}