malwaredb_client/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![deny(missing_docs)]
6#![deny(clippy::all)]
7#![deny(clippy::pedantic)]
8#![forbid(unsafe_code)]
9
/// Non-async version of the Malware DB client
///
/// Only compiled when the `blocking` Cargo feature is enabled.
#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
#[cfg(feature = "blocking")]
pub mod blocking;
14
15pub use malwaredb_api;
16use malwaredb_api::{
17    digest::HashType, GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType,
18    Report, SearchRequest, SearchRequestParameters, SearchResponse, SearchType, ServerInfo,
19    ServerResponse, SimilarSamplesResponse, Sources, SupportedFileTypes,
20};
21use malwaredb_lzjd::{LZDict, Murmur3HashState};
22use malwaredb_types::exec::pe32::EXE;
23use malwaredb_types::utils::entropy_calc;
24
25use std::collections::HashSet;
26use std::fmt::{Debug, Display, Formatter};
27use std::io::Cursor;
28use std::path::{Path, PathBuf};
29use std::sync::LazyLock;
30
31use anyhow::{bail, ensure, Context, Result};
32use base64::engine::general_purpose;
33use base64::Engine;
34use cart_container::JsonMap;
35use fuzzyhash::FuzzyHash;
36use home::home_dir;
37use mdns_sd::{ServiceDaemon, ServiceEvent};
38use reqwest::Certificate;
39use serde::{Deserialize, Serialize};
40use sha2::{Digest, Sha256, Sha384, Sha512};
41use tlsh_fixed::TlshBuilder;
42use tracing::{debug, error, info, trace, warn};
43use zeroize::{Zeroize, ZeroizeOnDrop};
44
45/// Local directory for the Malware DB client configs
46const MDB_CLIENT_DIR: &str = "malwaredb_client";
47
48/// Error for Anyhow's `context()` function with regard to what is expected just a network error
49pub(crate) const MDB_CLIENT_ERROR_CONTEXT: &str =
50    "Network error connecting to MalwareDB, or failure to decode server response.";
51
52/// Config file name expected by Malware DB client
53const MDB_CLIENT_CONFIG_TOML: &str = "mdb_client.toml";
54
55/// MDB version
56pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
57
58/// MDB version as a semantic version object
59pub static MDB_VERSION_SEMVER: LazyLock<semver::Version> =
60    LazyLock::new(|| semver::Version::parse(MDB_VERSION).unwrap());
61
/// macOS Keychain functionality
#[cfg(target_os = "macos")]
pub(crate) mod macos {
    use crate::CertificateType;

    use anyhow::Result;
    use reqwest::Certificate;
    use security_framework::os::macos::keychain::SecKeychain;
    use tracing::error;

    /// Application identifier for macOS Keychain
    const KEYCHAIN_ID: &str = "malwaredb-client";

    /// Entry ID for the Malware DB server URL
    const KEYCHAIN_URL: &str = "URL";

    /// Entry ID for the user's Malware DB API Key
    const KEYCHAIN_API_KEY: &str = "API_KEY";

    /// Entry ID for the custom PEM-encoded certificate for the Malware DB server
    const KEYCHAIN_CERTIFICATE_PEM: &str = "CERT_PEM";

    /// Entry ID for the custom DER-encoded certificate for the Malware DB server
    const KEYCHAIN_CERTIFICATE_DER: &str = "CERT_DER";

    /// Every entry this module may write, used when clearing stored credentials
    const KEYCHAIN_ENTRIES: [&str; 4] = [
        KEYCHAIN_API_KEY,
        KEYCHAIN_URL,
        KEYCHAIN_CERTIFICATE_PEM,
        KEYCHAIN_CERTIFICATE_DER,
    ];

    /// A custom server certificate as stored in the Keychain
    #[derive(Clone)]
    pub(crate) struct CertificateData {
        /// Encoding of `cert_bytes`
        pub cert_type: CertificateType,
        /// Raw contents of the certificate file
        pub cert_bytes: Vec<u8>,
    }

    impl CertificateData {
        /// Parse the stored bytes as a certificate in the recorded encoding.
        ///
        /// # Errors
        ///
        /// Returns an error if the bytes are not a valid certificate in that encoding.
        pub(crate) fn as_cert(&self) -> Result<Certificate> {
            Ok(match self.cert_type {
                CertificateType::PEM => Certificate::from_pem(&self.cert_bytes)?,
                CertificateType::DER => Certificate::from_der(&self.cert_bytes)?,
            })
        }
    }

    /// Save the elements to the Keychain, replacing anything previously stored there.
    ///
    /// If writing any entry fails, all Malware DB entries are removed again so that no partial
    /// set of credentials is left behind.
    ///
    /// # Errors
    ///
    /// Returns an error if the Keychain cannot be opened or an entry cannot be added.
    pub fn save_credentials(url: &str, key: &str, cert: Option<CertificateData>) -> Result<()> {
        let keychain = SecKeychain::default()?;

        // Adding a generic password fails when the entry already exists, so an earlier login
        // would otherwise keep its URL/key forever. A stale PEM certificate would also shadow a
        // newly saved DER one in `retrieve_credentials`. Start from a clean slate.
        remove_entries(&keychain);

        let saved = add_entries(&keychain, url, key, cert);
        if saved.is_err() {
            // `retrieve_credentials` would otherwise return a URL/key pair without the
            // certificate that belongs to it.
            remove_entries(&keychain);
        }
        saved
    }

    /// Write the URL, API key and optional certificate entries to `keychain`.
    fn add_entries(
        keychain: &SecKeychain,
        url: &str,
        key: &str,
        cert: Option<CertificateData>,
    ) -> Result<()> {
        keychain.add_generic_password(KEYCHAIN_ID, KEYCHAIN_URL, url.as_bytes())?;
        keychain.add_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY, key.as_bytes())?;

        if let Some(cert) = cert {
            let entry = match cert.cert_type {
                CertificateType::PEM => KEYCHAIN_CERTIFICATE_PEM,
                CertificateType::DER => KEYCHAIN_CERTIFICATE_DER,
            };
            keychain.add_generic_password(KEYCHAIN_ID, entry, &cert.cert_bytes)?;
        }

        Ok(())
    }

    /// Delete every Malware DB entry from `keychain`, skipping entries that are not present.
    fn remove_entries(keychain: &SecKeychain) {
        for element in KEYCHAIN_ENTRIES {
            if let Ok((_, item)) = keychain.find_generic_password(KEYCHAIN_ID, element) {
                item.delete();
            }
        }
    }

    /// Return key, url, and optionally the certificate in that order.
    ///
    /// A PEM certificate is preferred over a DER one. Failures while looking up the certificate
    /// are treated as "no certificate".
    ///
    /// # Errors
    ///
    /// Returns an error if the Keychain cannot be opened, or if the API key or URL entries are
    /// missing or not valid UTF-8.
    pub fn retrieve_credentials() -> Result<(String, String, Option<CertificateData>)> {
        let keychain = SecKeychain::default()?;
        let (api_key, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_API_KEY)?;
        let api_key = String::from_utf8(api_key.as_ref().to_vec())?;
        let (url, _item) = keychain.find_generic_password(KEYCHAIN_ID, KEYCHAIN_URL)?;
        let url = String::from_utf8(url.as_ref().to_vec())?;

        let cert = [
            (KEYCHAIN_CERTIFICATE_PEM, CertificateType::PEM),
            (KEYCHAIN_CERTIFICATE_DER, CertificateType::DER),
        ]
        .into_iter()
        .find_map(|(entry, cert_type)| {
            keychain
                .find_generic_password(KEYCHAIN_ID, entry)
                .ok()
                .map(|(cert, _item)| CertificateData {
                    cert_type,
                    cert_bytes: cert.to_vec(),
                })
        });

        Ok((api_key, url, cert))
    }

    /// Delete Malware DB client information from the Keychain
    ///
    /// Failure to open the Keychain is logged, not returned.
    pub fn clear_credentials() {
        match SecKeychain::default() {
            Ok(keychain) => remove_entries(&keychain),
            Err(_) => error!("Failed to get access to the Keychain to clear credentials"),
        }
    }
}
176
/// Encoding of a custom server certificate file
// The variants use the conventional all-caps spelling of the DER/PEM format names, which
// `clippy::upper_case_acronyms` would otherwise reject.
#[allow(clippy::upper_case_acronyms)]
#[derive(Copy, Clone, PartialEq, Eq)]
enum CertificateType {
    /// DER-encoded certificate
    DER,
    /// PEM-encoded certificate
    PEM,
}
183
/// Asynchronous Malware DB Client Configuration and connection
///
/// Only `url` and `api_key` are (de)serialized. The HTTP client and, on macOS, the server
/// certificate are marked `#[serde(skip)]`, so deserialization fills them with their `Default`
/// values. The non-skipped fields are zeroized when the value is dropped.
#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
pub struct MdbClient {
    /// URL of the Malware DB server, including http and port number, ending without a slash
    pub url: String,

    /// User's API key for Malware DB
    api_key: String,

    /// Async http client which stores the optional server certificate (not serialized)
    #[zeroize(skip)]
    #[serde(skip)]
    client: reqwest::Client,

    /// Server's certificate, kept so it can be saved to the Keychain (not serialized)
    #[cfg(target_os = "macos")]
    #[zeroize(skip)]
    #[serde(skip)]
    cert: Option<macos::CertificateData>,
}
204
205impl MdbClient {
206    /// MDB Client from components, doesn't test connectivity
207    ///
208    /// # Errors
209    ///
210    /// Returns an error if a list of certificates was passed and any were not in the expected
211    /// DER or PEM format or could not be parsed.
212    pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
213        let mut url = url;
214        let url = if url.ends_with('/') {
215            url.pop();
216            url
217        } else {
218            url
219        };
220
221        let cert = if let Some(path) = cert_path {
222            Some((path_load_cert(&path)?, path))
223        } else {
224            None
225        };
226
227        let builder = reqwest::ClientBuilder::new()
228            .gzip(true)
229            .zstd(true)
230            .use_rustls_tls()
231            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
232
233        let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
234            builder.add_root_certificate(cert.clone()).build()
235        } else {
236            builder.build()
237        }?;
238
239        #[cfg(target_os = "macos")]
240        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
241            Some(macos::CertificateData {
242                cert_type: *cert_type,
243                cert_bytes: std::fs::read(cert_path)?,
244            })
245        } else {
246            None
247        };
248
249        Ok(Self {
250            url,
251            api_key,
252            client,
253
254            #[cfg(target_os = "macos")]
255            cert,
256        })
257    }
258
259    /// Login to a server, optionally save the configuration file, and return a client object
260    ///
261    /// # Errors
262    ///
263    /// Returns an error if the server URL, username, or password were incorrect, or if a network
264    /// issue occurred.
265    pub async fn login(
266        url: String,
267        username: String,
268        password: String,
269        save: bool,
270        cert_path: Option<PathBuf>,
271    ) -> Result<Self> {
272        let mut url = url;
273        let url = if url.ends_with('/') {
274            url.pop();
275            url
276        } else {
277            url
278        };
279
280        let api_request = malwaredb_api::GetAPIKeyRequest {
281            user: username,
282            password,
283        };
284
285        let builder = reqwest::ClientBuilder::new()
286            .gzip(true)
287            .zstd(true)
288            .use_rustls_tls()
289            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
290
291        let cert = if let Some(path) = cert_path {
292            Some((path_load_cert(&path)?, path))
293        } else {
294            None
295        };
296
297        let client = if let Some(((_cert_type, cert), _cert_path)) = &cert {
298            builder.add_root_certificate(cert.clone()).build()
299        } else {
300            builder.build()
301        }?;
302
303        let res = client
304            .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
305            .json(&api_request)
306            .send()
307            .await?
308            .json::<ServerResponse<GetAPIKeyResponse>>()
309            .await
310            .context(MDB_CLIENT_ERROR_CONTEXT)?;
311
312        let res = match res {
313            ServerResponse::Success(res) => res,
314            ServerResponse::Error(err) => return Err(err.into()),
315        };
316
317        #[cfg(target_os = "macos")]
318        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
319            Some(macos::CertificateData {
320                cert_type: *cert_type,
321                cert_bytes: std::fs::read(cert_path)?,
322            })
323        } else {
324            None
325        };
326
327        let client = MdbClient {
328            url,
329            api_key: res.key.clone(),
330            client,
331
332            #[cfg(target_os = "macos")]
333            cert,
334        };
335
336        let server_info = client.server_info().await?;
337        if server_info.mdb_version > *MDB_VERSION_SEMVER {
338            warn!(
339                "Server version {:?} is newer than client {:?}, consider updating.",
340                server_info.mdb_version, MDB_VERSION_SEMVER
341            );
342        }
343
344        if save {
345            if let Err(e) = client.save() {
346                error!("Login successful but failed to save config: {e}");
347                bail!("Login successful but failed to save config: {e}");
348            }
349        }
350        Ok(client)
351    }
352
353    /// Reset one's own API key to effectively logout & disable all clients who are using the key
354    ///
355    /// # Errors
356    ///
357    /// Returns an error if there was a network issue or the user wasn't properly logged in.
358    pub async fn reset_key(&self) -> Result<()> {
359        let response = self
360            .client
361            .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
362            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
363            .send()
364            .await
365            .context(MDB_CLIENT_ERROR_CONTEXT)?;
366        if !response.status().is_success() {
367            bail!("failed to reset API key, was it correct?");
368        }
369        Ok(())
370    }
371
372    /// Malware DB Client configuration loaded from a specified path
373    ///
374    /// # Errors
375    ///
376    /// Returns an error if the configuration file cannot be read, possibly because it
377    /// doesn't exist or due to a permission error or a parsing error.
378    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
379        let name = path.as_ref().display();
380        let config =
381            std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
382        let cfg: MdbClient =
383            toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
384        Ok(cfg)
385    }
386
387    /// Malware DB Client configuration from user's home directory
388    ///
389    /// On macOS, it will attempt to load this information in the Keychain, which isn't required.
390    ///
391    /// # Errors
392    ///
393    /// Returns an error if the configuration file cannot be read, possibly because it
394    /// doesn't exist or due to a permission error or a parsing error.
395    pub fn load() -> Result<Self> {
396        #[cfg(target_os = "macos")]
397        {
398            if let Ok((api_key, url, cert)) = macos::retrieve_credentials() {
399                let builder = reqwest::ClientBuilder::new()
400                    .gzip(true)
401                    .zstd(true)
402                    .use_rustls_tls()
403                    .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
404
405                let client = if let Some(cert) = &cert {
406                    builder.add_root_certificate(cert.as_cert()?).build()
407                } else {
408                    builder.build()
409                }?;
410
411                return Ok(Self {
412                    url,
413                    api_key,
414                    client,
415                    cert,
416                });
417            }
418        }
419
420        let path = get_config_path(false)?;
421        if path.exists() {
422            return Self::from_file(path);
423        }
424        bail!("config file not found")
425    }
426
427    /// Save Malware DB Client configuration to the user's home directory.
428    ///
429    /// On macOS, it will attempt to save this information in the Keychain, which isn't required.
430    ///
431    /// # Errors
432    ///
433    /// Returns an error if there was a problem saving the configuration file.
434    pub fn save(&self) -> Result<()> {
435        #[cfg(target_os = "macos")]
436        {
437            if macos::save_credentials(&self.url, &self.api_key, self.cert.clone()).is_ok() {
438                return Ok(());
439            }
440        }
441
442        let toml = toml::to_string(self)?;
443        let path = get_config_path(true)?;
444        std::fs::write(&path, toml)
445            .context(format!("failed to write mdb config to {}", path.display()))
446    }
447
448    /// Delete the Malware DB client configuration file
449    ///
450    /// # Errors
451    ///
452    /// Returns an error if there isn't a configuration file to delete, or if it cannot be deleted,
453    /// possibly due to a permissions error.
454    pub fn delete(&self) -> Result<()> {
455        #[cfg(target_os = "macos")]
456        macos::clear_credentials();
457
458        let path = get_config_path(false)?;
459        if path.exists() {
460            std::fs::remove_file(&path).context(format!(
461                "failed to delete client config file {}",
462                path.display()
463            ))?;
464        }
465        Ok(())
466    }
467
468    // Actions of the client
469
470    /// Get information about the server, unauthenticated
471    ///
472    /// # Errors
473    ///
474    /// This may return an error if there's a network situation.
475    pub async fn server_info(&self) -> Result<ServerInfo> {
476        let response = self
477            .client
478            .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
479            .send()
480            .await?
481            .json::<ServerResponse<ServerInfo>>()
482            .await
483            .context(MDB_CLIENT_ERROR_CONTEXT)?;
484
485        match response {
486            ServerResponse::Success(info) => Ok(info),
487            ServerResponse::Error(e) => Err(e.into()),
488        }
489    }
490
491    /// Get file types supported by the server, unauthenticated
492    ///
493    /// # Errors
494    ///
495    /// This may return an error if there's a network situation.
496    pub async fn supported_types(&self) -> Result<SupportedFileTypes> {
497        let response = self
498            .client
499            .get(format!(
500                "{}{}",
501                self.url,
502                malwaredb_api::SUPPORTED_FILE_TYPES_URL
503            ))
504            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
505            .send()
506            .await?
507            .json::<ServerResponse<SupportedFileTypes>>()
508            .await
509            .context(MDB_CLIENT_ERROR_CONTEXT)?;
510
511        match response {
512            ServerResponse::Success(types) => Ok(types),
513            ServerResponse::Error(e) => Err(e.into()),
514        }
515    }
516
517    /// Get information about the user
518    ///
519    /// # Errors
520    ///
521    /// This may return an error if there's a network situation or if the user is not logged in
522    /// or not properly authorized to connect.
523    pub async fn whoami(&self) -> Result<GetUserInfoResponse> {
524        let response = self
525            .client
526            .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
527            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
528            .send()
529            .await?
530            .json::<ServerResponse<GetUserInfoResponse>>()
531            .await
532            .context(MDB_CLIENT_ERROR_CONTEXT)?;
533
534        match response {
535            ServerResponse::Success(info) => Ok(info),
536            ServerResponse::Error(e) => Err(e.into()),
537        }
538    }
539
540    /// Get the sample labels known to the server
541    ///
542    /// # Errors
543    ///
544    /// This may return an error if there's a network situation or if the user is not logged in
545    /// or not properly authorized to connect.
546    pub async fn labels(&self) -> Result<Labels> {
547        let response = self
548            .client
549            .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
550            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
551            .send()
552            .await?
553            .json::<ServerResponse<Labels>>()
554            .await
555            .context(MDB_CLIENT_ERROR_CONTEXT)?;
556
557        match response {
558            ServerResponse::Success(labels) => Ok(labels),
559            ServerResponse::Error(e) => Err(e.into()),
560        }
561    }
562
563    /// Get the sources available to the current user
564    ///
565    /// # Errors
566    ///
567    /// This may return an error if there's a network situation or if the user is not logged in
568    /// or not properly authorized to connect.
569    pub async fn sources(&self) -> Result<Sources> {
570        let response = self
571            .client
572            .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
573            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
574            .send()
575            .await?
576            .json::<ServerResponse<Sources>>()
577            .await
578            .context(MDB_CLIENT_ERROR_CONTEXT)?;
579
580        match response {
581            ServerResponse::Success(sources) => Ok(sources),
582            ServerResponse::Error(e) => Err(e.into()),
583        }
584    }
585
586    /// Submit one file to `MalwareDB`: provide the contents, file name, and source ID
587    ///
588    /// # Errors
589    ///
590    /// This may return an error if there's a network situation or if the user is not logged in
591    /// or not properly authorized to connect.
592    pub async fn submit(
593        &self,
594        contents: impl AsRef<[u8]>,
595        file_name: impl AsRef<str>,
596        source_id: u32,
597    ) -> Result<bool> {
598        let mut hasher = Sha256::new();
599        hasher.update(&contents);
600        let result = hasher.finalize();
601
602        let encoded = general_purpose::STANDARD.encode(contents);
603
604        let payload = malwaredb_api::NewSampleB64 {
605            file_name: file_name.as_ref().to_string(),
606            source_id,
607            file_contents_b64: encoded,
608            sha256: hex::encode(result),
609        };
610
611        match self
612            .client
613            .post(format!(
614                "{}{}",
615                self.url,
616                malwaredb_api::UPLOAD_SAMPLE_JSON_URL
617            ))
618            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
619            .json(&payload)
620            .send()
621            .await
622        {
623            Ok(res) => {
624                if !res.status().is_success() {
625                    info!("Code {} sending {}", res.status(), payload.file_name);
626                }
627                Ok(res.status().is_success())
628            }
629            Err(e) => {
630                let status: String = e
631                    .status()
632                    .map(|s| s.as_str().to_string())
633                    .unwrap_or_default();
634                error!("Error{status} sending {}: {e}", payload.file_name);
635                bail!(e.to_string())
636            }
637        }
638    }
639
640    /// Submit one file to `MalwareDB` as a Cbor object: provide the contents, file name, and source ID
641    /// Experimental! May be removed at any point.
642    ///
643    /// # Errors
644    ///
645    /// This may return an error if there's a network situation or if the user is not logged in
646    /// or not properly authorized to connect.
647    pub async fn submit_as_cbor(
648        &self,
649        contents: impl AsRef<[u8]>,
650        file_name: impl AsRef<str>,
651        source_id: u32,
652    ) -> Result<bool> {
653        let mut hasher = Sha256::new();
654        hasher.update(&contents);
655        let result = hasher.finalize();
656
657        let payload = malwaredb_api::NewSampleBytes {
658            file_name: file_name.as_ref().to_string(),
659            source_id,
660            file_contents: contents.as_ref().to_vec(),
661            sha256: hex::encode(result),
662        };
663
664        let mut bytes = Vec::with_capacity(payload.file_contents.len());
665        ciborium::ser::into_writer(&payload, &mut bytes)?;
666
667        match self
668            .client
669            .post(format!(
670                "{}{}",
671                self.url,
672                malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
673            ))
674            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
675            .header("content-type", "application/cbor")
676            .body(bytes)
677            .send()
678            .await
679        {
680            Ok(res) => {
681                if !res.status().is_success() {
682                    info!("Code {} sending {}", res.status(), payload.file_name);
683                }
684                Ok(res.status().is_success())
685            }
686            Err(e) => {
687                let status: String = e
688                    .status()
689                    .map(|s| s.as_str().to_string())
690                    .unwrap_or_default();
691                error!("Error{status} sending {}: {e}", payload.file_name);
692                bail!(e.to_string())
693            }
694        }
695    }
696
697    /// Search for a file based on partial hash and/or partial file name, returns a list of hashes
698    ///
699    /// # Errors
700    ///
701    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
702    pub async fn partial_search(
703        &self,
704        partial_hash: Option<(PartialHashSearchType, String)>,
705        name: Option<String>,
706        response: PartialHashSearchType,
707        limit: u32,
708    ) -> Result<SearchResponse> {
709        let query = SearchRequest {
710            search: SearchType::Search(SearchRequestParameters {
711                partial_hash,
712                file_name: name,
713                response,
714                limit,
715                labels: None,
716                file_type: None,
717                magic: None,
718            }),
719        };
720
721        self.do_search_request(&query).await
722    }
723
724    /// Search for a file based on partial hash and/or partial file name, labels, file type; returns a list of hashes
725    ///
726    /// # Errors
727    ///
728    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
729    #[allow(clippy::too_many_arguments)]
730    pub async fn partial_search_labels_type(
731        &self,
732        partial_hash: Option<(PartialHashSearchType, String)>,
733        name: Option<String>,
734        response: PartialHashSearchType,
735        labels: Option<Vec<String>>,
736        file_type: Option<String>,
737        magic: Option<String>,
738        limit: u32,
739    ) -> Result<SearchResponse> {
740        let query = SearchRequest {
741            search: SearchType::Search(SearchRequestParameters {
742                partial_hash,
743                file_name: name,
744                response,
745                limit,
746                file_type,
747                magic,
748                labels,
749            }),
750        };
751
752        self.do_search_request(&query).await
753    }
754
755    /// Return the next page from the search result
756    ///
757    /// # Errors
758    ///
759    /// Returns an error if there is a network problem, or pagination not available
760    pub async fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
761        if let Some(uuid) = response.pagination {
762            let request = SearchRequest {
763                search: SearchType::Continuation(uuid),
764            };
765            return self.do_search_request(&request).await;
766        }
767
768        bail!("Pagination not available")
769    }
770
771    async fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
772        ensure!(
773            query.is_valid(),
774            "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
775        );
776
777        let response = self
778            .client
779            .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
780            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
781            .json(query)
782            .send()
783            .await?
784            .json::<ServerResponse<SearchResponse>>()
785            .await
786            .context(MDB_CLIENT_ERROR_CONTEXT)?;
787
788        match response {
789            ServerResponse::Success(search) => Ok(search),
790            ServerResponse::Error(e) => Err(e.into()),
791        }
792    }
793
794    /// Retrieve sample by hash, optionally in the `CaRT` format
795    ///
796    /// # Errors
797    ///
798    /// This may return an error if there's a network situation or if the user is not logged in
799    /// or not properly authorized to connect.
800    pub async fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
801        let api_endpoint = if cart {
802            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
803        } else {
804            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
805        };
806
807        let res = self
808            .client
809            .get(format!("{}{api_endpoint}", self.url))
810            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
811            .send()
812            .await?;
813
814        if !res.status().is_success() {
815            bail!("Received code {}", res.status());
816        }
817
818        let content_digest = res.headers().get("content-digest").map(ToOwned::to_owned);
819        let body = res.bytes().await?;
820        let bytes = body.to_vec();
821
822        // TODO: Make this required in v0.3
823        if let Some(digest) = content_digest {
824            let hash = HashType::from_content_digest_header(digest.to_str()?)?;
825            if hash.verify(&bytes) {
826                trace!("Hash verified for sample {hash}");
827            } else {
828                error!("Hash mismatch for sample {hash}");
829            }
830        } else {
831            warn!("No content digest header received for sample {hash}");
832        }
833
834        Ok(bytes)
835    }
836
837    /// Fetch a report for a sample
838    ///
839    /// # Errors
840    ///
841    /// This may return an error if there's a network situation or if the user is not logged in
842    /// or not properly authorized to connect.
843    pub async fn report(&self, hash: &str) -> Result<Report> {
844        let response = self
845            .client
846            .get(format!(
847                "{}{}/{hash}",
848                self.url,
849                malwaredb_api::SAMPLE_REPORT_URL
850            ))
851            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
852            .send()
853            .await?
854            .json::<ServerResponse<Report>>()
855            .await
856            .context(MDB_CLIENT_ERROR_CONTEXT)?;
857
858        match response {
859            ServerResponse::Success(report) => Ok(report),
860            ServerResponse::Error(e) => Err(e.into()),
861        }
862    }
863
864    /// Find similar samples in `MalwareDB` based on the contents of a given file.
865    /// This does not submit the sample to `MalwareDB`.
866    ///
867    /// # Errors
868    ///
869    /// This may return an error if there's a network situation or if the user is not logged in
870    /// or not properly authorized to connect.
871    pub async fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
872        let mut hashes = vec![];
873        let ssdeep_hash = FuzzyHash::new(contents);
874
875        let build_hasher = Murmur3HashState::default();
876        let lzjd_str =
877            LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
878        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
879        hashes.push((
880            malwaredb_api::SimilarityHashType::SSDeep,
881            ssdeep_hash.to_string(),
882        ));
883
884        let mut builder = TlshBuilder::new(
885            tlsh_fixed::BucketKind::Bucket256,
886            tlsh_fixed::ChecksumKind::ThreeByte,
887            tlsh_fixed::Version::Version4,
888        );
889
890        builder.update(contents);
891        if let Ok(hasher) = builder.build() {
892            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
893        }
894
895        if let Ok(exe) = EXE::from(contents) {
896            if let Some(imports) = exe.imports {
897                hashes.push((
898                    malwaredb_api::SimilarityHashType::ImportHash,
899                    hex::encode(imports.hash()),
900                ));
901                hashes.push((
902                    malwaredb_api::SimilarityHashType::FuzzyImportHash,
903                    imports.fuzzy_hash(),
904                ));
905            }
906        }
907
908        let request = malwaredb_api::SimilarSamplesRequest { hashes };
909
910        let response = self
911            .client
912            .post(format!(
913                "{}{}",
914                self.url,
915                malwaredb_api::SIMILAR_SAMPLES_URL
916            ))
917            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
918            .json(&request)
919            .send()
920            .await?
921            .json::<ServerResponse<SimilarSamplesResponse>>()
922            .await
923            .context(MDB_CLIENT_ERROR_CONTEXT)?;
924
925        match response {
926            ServerResponse::Success(similar) => Ok(similar),
927            ServerResponse::Error(e) => Err(e.into()),
928        }
929    }
930}
931
932impl Debug for MdbClient {
933    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
934        writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
935    }
936}
937
938/// Convenience function for encoding bytes into a `CaRT` file using the default key. This also
939/// adds SHA-384 and SHA-512 hashes plus the entropy of the original file.
940/// See <https://github.com/CybercentreCanada/cart> for more information.
941///
942/// # Errors
943///
944/// There should not be any errors, but the underlying library can't guarantee that. However, any data
945/// passed to it is correct.
946pub fn encode_to_cart(data: &[u8]) -> Result<Vec<u8>> {
947    let mut input_buffer = Cursor::new(data);
948    let mut output_buffer = Cursor::new(vec![]);
949    let mut output_metadata = JsonMap::new();
950
951    let mut sha384 = Sha384::new();
952    sha384.update(data);
953    let sha384 = hex::encode(sha384.finalize());
954
955    let mut sha512 = Sha512::new();
956    sha512.update(data);
957    let sha512 = hex::encode(sha512.finalize());
958
959    output_metadata.insert("sha384".into(), sha384.into());
960    output_metadata.insert("sha512".into(), sha512.into());
961    output_metadata.insert("entropy".into(), entropy_calc(data).into());
962    cart_container::pack_stream(
963        &mut input_buffer,
964        &mut output_buffer,
965        Some(output_metadata),
966        None,
967        cart_container::digesters::default_digesters(),
968        None,
969    )?;
970
971    Ok(output_buffer.into_inner())
972}
973
974/// Convenience function for decoding a `CaRT` file using the default key, returning the bytes plus the
975/// optional header and footer metadata, if present.
976/// See <https://github.com/CybercentreCanada/cart> for more information.
977///
978/// # Errors
979///
980/// Returns an error if the file cannot be parsed or if this `CaRT` file didn't use the default key.
981/// <https://github.com/CybercentreCanada/cart-rs/blob/7ad548143bb85b64f364804e90cfada6c31cf902/cart_container/src/cipher.rs#L14-L17>
982pub fn decode_from_cart(data: &[u8]) -> Result<(Vec<u8>, Option<JsonMap>, Option<JsonMap>)> {
983    let mut input_buffer = Cursor::new(data);
984    let mut output_buffer = Cursor::new(vec![]);
985    let (header, footer) =
986        cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)?;
987    Ok((output_buffer.into_inner(), header, footer))
988}
989
990/// Load a certificate from a path
991///
992/// # Errors
993///
994/// Returns errors if the file cannot be read or if the file isn't an ASN.1 DER file or
995/// base64-encoded ASN.1 PEM file.
996fn path_load_cert(path: &Path) -> Result<(CertificateType, Certificate)> {
997    if !path.exists() {
998        bail!("Certificate {} does not exist.", path.display());
999    }
1000    let cert = match path
1001        .extension()
1002        .context("can't determine file extension")?
1003        .to_str()
1004        .context("unable to parse file extension")?
1005    {
1006        "pem" => {
1007            let contents = std::fs::read(path)?;
1008            (CertificateType::PEM, Certificate::from_pem(&contents)?)
1009        }
1010        "der" => {
1011            let contents = std::fs::read(path)?;
1012            (CertificateType::DER, Certificate::from_der(&contents)?)
1013        }
1014        ext => {
1015            bail!("Unknown extension {ext:?}")
1016        }
1017    };
1018    Ok(cert)
1019}
1020
1021/// Gets the configuration file in the following order:
1022///
1023/// 1. Current working directory: `./mdb_client.toml`
1024/// 2. Haiku-specific directory if on Haiku
1025/// 3. XDG Free Desktop directory if on Unix
1026/// 4. The user's home directory in `~/.config/malwaredb_client/mdb_client.toml`
1027/// 5. `mdb_client.toml` in the current directory (same as the first, but after checking others)
1028#[inline]
1029pub(crate) fn get_config_path(create: bool) -> Result<PathBuf> {
1030    // If there is a config file in the current working directory, use it
1031    let config = PathBuf::from(MDB_CLIENT_CONFIG_TOML);
1032    if config.exists() {
1033        return Ok(config);
1034    }
1035
1036    #[cfg(target_os = "haiku")]
1037    {
1038        let mut settings = PathBuf::from("/boot/home/config/settings/malwaredb");
1039        if create && !settings.exists() {
1040            std::fs::create_dir_all(&settings)?;
1041        }
1042        settings.push(MDB_CLIENT_CONFIG_TOML);
1043        return Ok(settings);
1044    }
1045
1046    #[cfg(unix)]
1047    {
1048        // Obey the Free Desktop standard, check the variable
1049        if let Some(xdg_home) = std::env::var_os("XDG_CONFIG_HOME") {
1050            let mut xdg_config_home = PathBuf::from(xdg_home);
1051            xdg_config_home.push(MDB_CLIENT_DIR);
1052            if create && !xdg_config_home.exists() {
1053                std::fs::create_dir_all(&xdg_config_home)?;
1054            }
1055            xdg_config_home.push(MDB_CLIENT_CONFIG_TOML);
1056            return Ok(xdg_config_home);
1057        }
1058    }
1059
1060    if let Some(mut home_config) = home_dir() {
1061        home_config.push(".config");
1062        home_config.push(MDB_CLIENT_DIR);
1063        if create && !home_config.exists() {
1064            std::fs::create_dir_all(&home_config)?;
1065        }
1066        home_config.push(MDB_CLIENT_CONFIG_TOML);
1067        return Ok(home_config);
1068    }
1069
1070    Ok(PathBuf::from(MDB_CLIENT_CONFIG_TOML))
1071}
1072
/// A Malware DB server advertised over Multicast DNS (also known as Bonjour or Zeroconf).
///
/// Equality and hashing are derived over every field, so repeated announcements of the same
/// server collapse to one entry in a set.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MalwareDBServer {
    /// Host name or IP address of the server
    pub host: String,
    /// Port the server listens on
    pub port: u16,
    /// Whether the server expects an encrypted (`https`) rather than plain (`http`) connection
    pub ssl: bool,
    /// Human-readable name of the Malware DB server
    pub name: String,
}
1088
1089impl Display for MalwareDBServer {
1090    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1091        if self.ssl {
1092            write!(f, "https://{}:{}", self.host, self.port)
1093        } else {
1094            write!(f, "http://{}:{}", self.host, self.port)
1095        }
1096    }
1097}
1098
1099impl MalwareDBServer {
1100    /// Retrieve details about the server
1101    ///
1102    /// # Errors
1103    ///
1104    /// An error will result if the server becomes unreachable or if a specific CA certificate is required
1105    pub async fn server_info(&self) -> Result<ServerInfo> {
1106        let client = reqwest::ClientBuilder::new()
1107            .gzip(true)
1108            .zstd(true)
1109            .use_rustls_tls()
1110            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
1111            .build()?;
1112
1113        let response = client
1114            .get(format!("{self}{}", malwaredb_api::SERVER_INFO_URL))
1115            .send()
1116            .await?
1117            .json::<ServerResponse<ServerInfo>>()
1118            .await
1119            .context(MDB_CLIENT_ERROR_CONTEXT)?;
1120
1121        match response {
1122            ServerResponse::Success(info) => Ok(info),
1123            ServerResponse::Error(e) => Err(e.into()),
1124        }
1125    }
1126
1127    /// Retrieve details about the server
1128    ///
1129    /// # Errors
1130    ///
1131    /// An error will result if the server becomes unreachable or if a specific CA certificate is required
1132    ///
1133    /// # Panics
1134    ///
1135    /// This method panics if called from within an async runtime.
1136    #[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
1137    #[cfg(feature = "blocking")]
1138    pub fn server_info_blocking(&self) -> Result<ServerInfo> {
1139        let client = reqwest::blocking::ClientBuilder::new()
1140            .gzip(true)
1141            .zstd(true)
1142            .use_rustls_tls()
1143            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")))
1144            .build()?;
1145
1146        let response = client
1147            .get(format!("{self}{}", malwaredb_api::SERVER_INFO_URL))
1148            .send()?
1149            .json::<ServerResponse<ServerInfo>>()
1150            .context(MDB_CLIENT_ERROR_CONTEXT)?;
1151
1152        match response {
1153            ServerResponse::Success(similar) => Ok(similar),
1154            ServerResponse::Error(e) => Err(e.into()),
1155        }
1156    }
1157}
1158
1159/// Find servers using Multicast DNS (also known as Bonjour)
1160///
1161/// # Errors
1162///
1163/// This may fail if there's a networking issue.
1164pub fn discover_servers() -> Result<Vec<MalwareDBServer>> {
1165    const MAX_ITERS: usize = 5;
1166    let mdns = ServiceDaemon::new()?;
1167    let mut servers = HashSet::new();
1168    let receiver = mdns.browse(malwaredb_api::MDNS_NAME)?;
1169
1170    let mut counter = 0;
1171    while let Ok(event) = receiver.recv() {
1172        if let ServiceEvent::ServiceResolved(resolved) = event {
1173            let host = resolved.host.replace(".local.", "");
1174            let ssl = if let Some(ssl) = resolved.txt_properties.get("ssl") {
1175                ssl.val_str() == "true"
1176            } else {
1177                debug!(
1178                    "MalwareDB entry for {host}:{} doesn't specify ssl, assuming not",
1179                    resolved.port
1180                );
1181                false
1182            };
1183
1184            let server = MalwareDBServer {
1185                host,
1186                port: resolved.port,
1187                ssl,
1188                name: resolved.fullname.replace(malwaredb_api::MDNS_NAME, ""),
1189            };
1190
1191            servers.insert(server);
1192        }
1193        counter += 1;
1194        if counter > MAX_ITERS {
1195            break;
1196        }
1197    }
1198
1199    Ok(servers.into_iter().collect())
1200}
1201
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a known `CaRT` fixture. It checks the payload's SHA-256, the entropy stored in
    /// the header, and the length and SHA-256 stored in the footer.
    #[test]
    fn cart() {
        const ORIGINAL_SHA256: &str =
            "de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740";
        const BYTES: &[u8] = include_bytes!("../../crates/types/testdata/elf/elf_haiku_x86.cart");

        let (decoded, header, footer) = decode_from_cart(BYTES).unwrap();
        assert_eq!(hex::encode(Sha256::digest(&decoded)), ORIGINAL_SHA256);

        let entropy = header.unwrap().get("entropy").unwrap().as_f64().unwrap();
        assert!(entropy > 4.0 && entropy < 4.1);

        let footer = footer.unwrap();
        assert_eq!(footer.get("sha256").unwrap(), ORIGINAL_SHA256);
        assert_eq!(footer.get("length").unwrap(), "5093");
    }
}