malwaredb_client/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![deny(missing_docs)]
5#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7#![forbid(unsafe_code)]
8
9/// Non-async version of the Malware DB client
10#[cfg(feature = "blocking")]
11pub mod blocking;
12pub(crate) mod option_cert_path_serialization;
13
14pub use malwaredb_api;
15use malwaredb_lzjd::{LZDict, Murmur3HashState};
16use malwaredb_types::exec::pe32::EXE;
17use malwaredb_types::utils::entropy_calc;
18
19use std::fmt::{Debug, Formatter};
20use std::io::Cursor;
21use std::path::{Path, PathBuf};
22
23use anyhow::{bail, ensure, Context, Result};
24use base64::engine::general_purpose;
25use base64::Engine;
26use cart_container::JsonMap;
27use fuzzyhash::FuzzyHash;
28use home::home_dir;
29use malwaredb_api::{PartialHashSearchType, SearchRequest};
30use reqwest::Certificate;
31use serde::{Deserialize, Serialize};
32use sha2::{Digest, Sha256, Sha384, Sha512};
33use tlsh_fixed::TlshBuilder;
34use tracing::{error, info};
35use zeroize::{Zeroize, ZeroizeOnDrop};
36
/// Name of the directory, inside the user's configuration directory, that holds the client config.
const MDB_CLIENT_DIR: &str = "malwaredb_client";

/// File name of the Malware DB client configuration (TOML).
const MDB_CLIENT_CONFIG_TOML: &str = "mdb_client.toml";
42
43/// MDB version
44pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
45
46/// Asynchronous Malware DB Client Configuration and connection
47#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
48pub struct MdbClient {
49    /// URL of the Malware DB server, including http and port number, ending without a slash
50    pub url: String,
51
52    /// User's API key for Malware DB
53    api_key: String,
54
55    /// Certificate and Path, if needed
56    /// The path is serialized; deserialization loads & parses the certificate file specified.
57    #[zeroize(skip)]
58    #[serde(default, with = "option_cert_path_serialization")]
59    cert: Option<(Certificate, PathBuf)>,
60}
61
62impl MdbClient {
63    /// MDB Client from components, doesn't test connectivity
64    ///
65    /// # Errors
66    ///
67    /// Returns an error if a list of certificates was passed and any were not in the expected
68    /// DER or PEM format or could not be parsed.
69    pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
70        let mut url = url;
71        let url = if url.ends_with('/') {
72            url.pop();
73            url
74        } else {
75            url
76        };
77
78        let cert = if let Some(path) = cert_path {
79            Some((path_load_cert(&path)?, path))
80        } else {
81            None
82        };
83
84        Ok(Self { url, api_key, cert })
85    }
86
87    /// Generate a client which already knows to send the API key and asks for gzip or zstd responses.
88    #[inline]
89    fn client(&self) -> reqwest::Result<reqwest::Client> {
90        let builder = reqwest::ClientBuilder::new()
91            .gzip(true)
92            .zstd(true)
93            .use_rustls_tls()
94            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
95
96        if let Some(cert) = &self.cert {
97            builder.add_root_certificate(cert.0.clone()).build()
98        } else {
99            builder.build()
100        }
101    }
102
103    /// Login to a server, optionally save the config file, and return a client object
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the server URL, username, or password were incorrect, or if a network
108    /// issue occurred.
109    pub async fn login(
110        url: String,
111        username: String,
112        password: String,
113        save: bool,
114        cert_path: Option<PathBuf>,
115    ) -> Result<Self> {
116        let mut url = url;
117        let url = if url.ends_with('/') {
118            url.pop();
119            url
120        } else {
121            url
122        };
123
124        let api_request = malwaredb_api::GetAPIKeyRequest {
125            user: username,
126            password,
127        };
128
129        let builder = reqwest::ClientBuilder::new()
130            .gzip(true)
131            .zstd(true)
132            .use_rustls_tls()
133            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
134
135        let cert = if let Some(path) = cert_path {
136            Some((path_load_cert(&path)?, path))
137        } else {
138            None
139        };
140
141        let client = if let Some(cert) = &cert {
142            builder.add_root_certificate(cert.0.clone()).build()
143        } else {
144            builder.build()
145        }?;
146
147        let res = client
148            .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
149            .json(&api_request)
150            .send()
151            .await?
152            .json::<malwaredb_api::GetAPIKeyResponse>()
153            .await?;
154
155        if let Some(key) = &res.key {
156            let client = MdbClient {
157                url,
158                api_key: key.clone(),
159                cert,
160            };
161
162            if save {
163                if let Err(e) = client.save() {
164                    error!("Login successful but failed to save config: {e}");
165                    bail!("Login successful but failed to save config: {e}");
166                }
167            }
168            Ok(client)
169        } else {
170            if let Some(msg) = &res.message {
171                error!("Login failed, response: {msg}");
172            }
173            bail!("server error or bad credentials");
174        }
175    }
176
177    /// Reset one's own API key to effectively logout & disable all clients who are using the key
178    ///
179    /// # Errors
180    ///
181    /// Returns an error if there was a network issue or the user wasn't properly logged in.
182    pub async fn reset_key(&self) -> Result<()> {
183        let response = self
184            .client()?
185            .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
186            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
187            .send()
188            .await
189            .context("server error, or invalid API key")?;
190        if !response.status().is_success() {
191            bail!("failed to reset API key, was it correct?");
192        }
193        Ok(())
194    }
195
196    /// MDB Client loaded from a specified path
197    ///
198    /// # Errors
199    ///
200    /// Returns an error if the configuration file cannot be read, possibly because it
201    /// doesn't exist or due to a permission error or a parsing error.
202    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
203        let name = path.as_ref().display();
204        let config =
205            std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
206        let cfg: MdbClient =
207            toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
208        Ok(cfg)
209    }
210
211    /// MDB Client from user's home directory
212    ///
213    /// # Errors
214    ///
215    /// Returns an error if the configuration file cannot be read, possibly because it
216    /// doesn't exist or due to a permission error or a parsing error.
217    pub fn load() -> Result<Self> {
218        let path = get_config_path(false)?;
219        if path.exists() {
220            return Self::from_file(path);
221        }
222        bail!("config file not found")
223    }
224
225    /// Save MDB Client to the user's home directory
226    ///
227    /// # Errors
228    ///
229    /// Returns an error if there was a problem saving the configuration file.
230    pub fn save(&self) -> Result<()> {
231        let toml = toml::to_string(self)?;
232        let path = get_config_path(true)?;
233        std::fs::write(&path, toml)
234            .context(format!("failed to write mdb config to {}", path.display()))
235    }
236
237    /// Delete the `MalwareDB` client config file
238    ///
239    /// # Errors
240    ///
241    /// Returns an error if there isn't a configuration file to delete, or if it cannot be deleted,
242    /// possibly due to a permissions error.
243    pub fn delete(&self) -> Result<()> {
244        let path = get_config_path(false)?;
245        if path.exists() {
246            std::fs::remove_file(&path).context(format!(
247                "failed to delete client config file {}",
248                path.display()
249            ))?;
250        }
251        Ok(())
252    }
253
254    // Actions of the client
255
256    /// Get information about the server, unauthenticated
257    ///
258    /// # Errors
259    ///
260    /// This may return an error if there's a network situation.
261    pub async fn server_info(&self) -> Result<malwaredb_api::ServerInfo> {
262        self.client()?
263            .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO))
264            .send()
265            .await?
266            .json::<malwaredb_api::ServerInfo>()
267            .await
268            .context("failed to receive or decode server info")
269    }
270
271    /// Get file types supported by the server, unauthenticated
272    ///
273    /// # Errors
274    ///
275    /// This may return an error if there's a network situation.
276    pub async fn supported_types(&self) -> Result<malwaredb_api::SupportedFileTypes> {
277        self.client()?
278            .get(format!(
279                "{}{}",
280                self.url,
281                malwaredb_api::SUPPORTED_FILE_TYPES
282            ))
283            .send()
284            .await?
285            .json::<malwaredb_api::SupportedFileTypes>()
286            .await
287            .context("failed to receive or decode server-supported file types")
288    }
289
290    /// Get information about the user
291    ///
292    /// # Errors
293    ///
294    /// This may return an error if there's a network situation or if the user is not logged in
295    /// or not properly authorized to connect.
296    pub async fn whoami(&self) -> Result<malwaredb_api::GetUserInfoResponse> {
297        self.client()?
298            .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
299            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
300            .send()
301            .await?
302            .json::<malwaredb_api::GetUserInfoResponse>()
303            .await
304            .context("failed to receive or decode user info, or invalid API key")
305    }
306
307    /// Get the sample labels known to the server
308    ///
309    /// # Errors
310    ///
311    /// This may return an error if there's a network situation or if the user is not logged in
312    /// or not properly authorized to connect.
313    pub async fn labels(&self) -> Result<malwaredb_api::Labels> {
314        self.client()?
315            .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS))
316            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
317            .send()
318            .await?
319            .json::<malwaredb_api::Labels>()
320            .await
321            .context("failed to receive or decode available labels, or invalid API key")
322    }
323
324    /// Get the sources available to the current user
325    ///
326    /// # Errors
327    ///
328    /// This may return an error if there's a network situation or if the user is not logged in
329    /// or not properly authorized to connect.
330    pub async fn sources(&self) -> Result<malwaredb_api::Sources> {
331        self.client()?
332            .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES))
333            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
334            .send()
335            .await?
336            .json::<malwaredb_api::Sources>()
337            .await
338            .context("failed to receive or decode available labels, or invalid API key")
339    }
340
341    /// Submit one file to `MalwareDB`: provide the contents, file name, and source ID
342    ///
343    /// # Errors
344    ///
345    /// This may return an error if there's a network situation or if the user is not logged in
346    /// or not properly authorized to connect.
347    pub async fn submit(
348        &self,
349        contents: impl AsRef<[u8]>,
350        file_name: impl AsRef<str>,
351        source_id: u32,
352    ) -> Result<bool> {
353        let mut hasher = Sha256::new();
354        hasher.update(&contents);
355        let result = hasher.finalize();
356
357        let encoded = general_purpose::STANDARD.encode(contents);
358
359        let payload = malwaredb_api::NewSample {
360            file_name: file_name.as_ref().to_string(),
361            source_id,
362            file_contents_b64: encoded,
363            sha256: hex::encode(result),
364        };
365
366        match self
367            .client()?
368            .post(format!("{}{}", self.url, malwaredb_api::UPLOAD_SAMPLE))
369            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
370            .json(&payload)
371            .send()
372            .await
373        {
374            Ok(res) => {
375                if !res.status().is_success() {
376                    info!("Code {} sending {}", res.status(), payload.file_name);
377                }
378                Ok(res.status().is_success())
379            }
380            Err(e) => {
381                let status: String = e
382                    .status()
383                    .map(|s| s.as_str().to_string())
384                    .unwrap_or_default();
385                error!("Error{status} sending {}: {e}", payload.file_name);
386                bail!(e.to_string())
387            }
388        }
389    }
390
391    /// Search for a file based on partial hash and/or partial file name, returns a list of hashes
392    ///
393    /// # Errors
394    ///
395    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
396    pub async fn partial_search(
397        &self,
398        partial_hash: Option<(PartialHashSearchType, String)>,
399        name: Option<String>,
400        response: PartialHashSearchType,
401        limit: u32,
402    ) -> Result<Vec<String>> {
403        let query = SearchRequest {
404            partial_hash,
405            file_name: name,
406            response,
407            limit,
408        };
409
410        ensure!(
411            query.is_valid(),
412            "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
413        );
414
415        self.client()?
416            .post(format!("{}{}", self.url, malwaredb_api::SEARCH))
417            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
418            .json(&query)
419            .send()
420            .await?
421            .json::<Vec<String>>()
422            .await
423            .context("failed to receive or decode hash list, or invalid API key")
424    }
425
426    /// Retrieve sample by hash, optionally in the `CaRT` format
427    ///
428    /// # Errors
429    ///
430    /// This may return an error if there's a network situation or if the user is not logged in
431    /// or not properly authorized to connect.
432    pub async fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
433        let api_endpoint = if cart {
434            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART)
435        } else {
436            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE)
437        };
438
439        let res = self
440            .client()?
441            .get(format!("{}{api_endpoint}", self.url))
442            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
443            .send()
444            .await?;
445
446        if !res.status().is_success() {
447            bail!("Received code {}", res.status());
448        }
449
450        let body = res.bytes().await?;
451        Ok(body.to_vec())
452    }
453
454    /// Fetch a report for a sample
455    ///
456    /// # Errors
457    ///
458    /// This may return an error if there's a network situation or if the user is not logged in
459    /// or not properly authorized to connect.
460    pub async fn report(&self, hash: &str) -> Result<malwaredb_api::Report> {
461        self.client()?
462            .get(format!(
463                "{}{}/{hash}",
464                self.url,
465                malwaredb_api::SAMPLE_REPORT
466            ))
467            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
468            .send()
469            .await?
470            .json::<malwaredb_api::Report>()
471            .await
472            .context("failed to receive or decode sample report, or invalid API key")
473    }
474
475    /// Find similar samples in `MalwareDB` based on the contents of a given file.
476    /// This does not submit the sample to `MalwareDB`.
477    ///
478    /// # Errors
479    ///
480    /// This may return an error if there's a network situation or if the user is not logged in
481    /// or not properly authorized to connect.
482    pub async fn similar(&self, contents: &[u8]) -> Result<malwaredb_api::SimilarSamplesResponse> {
483        let mut hashes = vec![];
484        let ssdeep_hash = FuzzyHash::new(contents);
485
486        let build_hasher = Murmur3HashState::default();
487        let lzjd_str =
488            LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
489        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
490        hashes.push((
491            malwaredb_api::SimilarityHashType::SSDeep,
492            ssdeep_hash.to_string(),
493        ));
494
495        let mut builder = TlshBuilder::new(
496            tlsh_fixed::BucketKind::Bucket256,
497            tlsh_fixed::ChecksumKind::ThreeByte,
498            tlsh_fixed::Version::Version4,
499        );
500
501        builder.update(contents);
502        if let Ok(hasher) = builder.build() {
503            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
504        }
505
506        if let Ok(exe) = EXE::from(contents) {
507            if let Some(imports) = exe.imports {
508                hashes.push((
509                    malwaredb_api::SimilarityHashType::ImportHash,
510                    hex::encode(imports.hash()),
511                ));
512                hashes.push((
513                    malwaredb_api::SimilarityHashType::FuzzyImportHash,
514                    imports.fuzzy_hash(),
515                ));
516            }
517        }
518
519        let request = malwaredb_api::SimilarSamplesRequest { hashes };
520
521        self.client()?
522            .post(format!("{}{}", self.url, malwaredb_api::SIMILAR_SAMPLES))
523            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
524            .json(&request)
525            .send()
526            .await?
527            .json::<malwaredb_api::SimilarSamplesResponse>()
528            .await
529            .context("failed to receive or decode similarity response, or invalid API key")
530    }
531}
532
533impl Debug for MdbClient {
534    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
535        writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
536    }
537}
538
539/// Convenience function for encoding bytes into a `CaRT` file using the default key. This also
540/// adds SHA-384 and SHA-512 hashes plus the entropy of the original file.
541/// See <https://github.com/CybercentreCanada/cart> for more information.
542///
543/// # Errors
544///
545/// There should not be any errors, but the underlying library can't guarantee that. However, any data
546/// passed to it is correct.
547pub fn encode_to_cart(data: &[u8]) -> Result<Vec<u8>> {
548    let mut input_buffer = Cursor::new(data);
549    let mut output_buffer = Cursor::new(vec![]);
550    let mut output_metadata = JsonMap::new();
551
552    let mut sha384 = Sha384::new();
553    sha384.update(data);
554    let sha384 = hex::encode(sha384.finalize());
555
556    let mut sha512 = Sha512::new();
557    sha512.update(data);
558    let sha512 = hex::encode(sha512.finalize());
559
560    output_metadata.insert("sha384".into(), sha384.into());
561    output_metadata.insert("sha512".into(), sha512.into());
562    output_metadata.insert("entropy".into(), entropy_calc(data).into());
563    cart_container::pack_stream(
564        &mut input_buffer,
565        &mut output_buffer,
566        Some(output_metadata),
567        None,
568        cart_container::digesters::default_digesters(),
569        None,
570    )?;
571
572    Ok(output_buffer.into_inner())
573}
574
575/// Convenience function for decoding a `CaRT` file using the default key, returning the bytes plus the
576/// optional header and footer metadata, if present.
577/// See <https://github.com/CybercentreCanada/cart> for more information.
578///
579/// # Errors
580///
581/// Returns an error if the file cannot be parsed or if this `CaRT` file didn't use the default key.
582/// <https://github.com/CybercentreCanada/cart-rs/blob/7ad548143bb85b64f364804e90cfada6c31cf902/cart_container/src/cipher.rs#L14-L17>
583pub fn decode_from_cart(data: &[u8]) -> Result<(Vec<u8>, Option<JsonMap>, Option<JsonMap>)> {
584    let mut input_buffer = Cursor::new(data);
585    let mut output_buffer = Cursor::new(vec![]);
586    let (header, footer) =
587        cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)?;
588    Ok((output_buffer.into_inner(), header, footer))
589}
590
591/// Load a certificate from a path
592///
593/// # Errors
594///
595/// Returns errors if the file cannot be read or if the file isn't an ASN.1 DER file or
596/// base64-encoded ASN.1 PEM file.
597pub fn path_load_cert(path: &Path) -> Result<Certificate> {
598    if !path.exists() {
599        bail!("Certificate {path:?} does not exist.");
600    }
601    let cert = match path
602        .extension()
603        .context("can't determine file extension")?
604        .to_str()
605        .context("unable to parse file extension")?
606    {
607        "pem" => {
608            let contents = std::fs::read(path)?;
609            Certificate::from_pem(&contents)?
610        }
611        "der" => {
612            let contents = std::fs::read(path)?;
613            Certificate::from_der(&contents)?
614        }
615        ext => {
616            bail!("Unknown extension {ext:?}")
617        }
618    };
619    Ok(cert)
620}
621
622/// Gets the configuration file in the following order:
623///
624/// 1. Current working directory: `./mdb_client.toml`
625/// 2. Haiku-specific directory if on Haiku
626/// 3. XDG Free Desktop directory if on Unix
627/// 4. The user's home directory in `~/.config/malwaredb_client/mdb_client.toml`
628/// 5. `mdb_client.toml` in the current directory (same as the first, but after checking others)
629#[inline]
630pub(crate) fn get_config_path(create: bool) -> Result<PathBuf> {
631    // If there is a config file in the current working directory, use it
632    let config = PathBuf::from(MDB_CLIENT_CONFIG_TOML);
633    if config.exists() {
634        return Ok(config);
635    }
636
637    #[cfg(target_os = "haiku")]
638    {
639        let mut settings = PathBuf::from("/boot/home/config/settings/malwaredb");
640        if create && !settings.exists() {
641            std::fs::create_dir_all(&settings)?;
642        }
643        settings.push(MDB_CLIENT_CONFIG_TOML);
644        return Ok(settings);
645    }
646
647    #[cfg(unix)]
648    {
649        // Obey the Free Desktop standard, check the variable
650        if let Some(xdg_home) = std::env::var_os("XDG_CONFIG_HOME") {
651            let mut xdg_config_home = PathBuf::from(xdg_home);
652            xdg_config_home.push(MDB_CLIENT_DIR);
653            if create && !xdg_config_home.exists() {
654                std::fs::create_dir_all(&xdg_config_home)?;
655            }
656            xdg_config_home.push(MDB_CLIENT_CONFIG_TOML);
657            return Ok(xdg_config_home);
658        }
659    }
660
661    if let Some(mut home_config) = home_dir() {
662        home_config.push(".config");
663        home_config.push(MDB_CLIENT_DIR);
664        if create && !home_config.exists() {
665            std::fs::create_dir_all(&home_config)?;
666        }
667        home_config.push(MDB_CLIENT_CONFIG_TOML);
668        return Ok(home_config);
669    }
670
671    Ok(PathBuf::from(MDB_CLIENT_CONFIG_TOML))
672}
673
#[cfg(test)]
mod tests {
    use super::*;

    /// Decoding a known `CaRT` fixture yields the original bytes and the expected metadata.
    #[test]
    fn cart() {
        const BYTES: &[u8] = include_bytes!("../../crates/types/testdata/elf/elf_haiku_x86.cart");
        const ORIGINAL_SHA256: &str =
            "de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740";

        let (decoded, header, footer) = decode_from_cart(BYTES).unwrap();

        let sha256 = hex::encode(Sha256::digest(&decoded));
        assert_eq!(sha256, ORIGINAL_SHA256);

        let header = header.unwrap();
        let entropy = header.get("entropy").unwrap().as_f64().unwrap();
        assert!((4.0..4.1).contains(&entropy), "unexpected entropy {entropy}");

        let footer = footer.unwrap();
        assert_eq!(footer.get("length").unwrap(), "5093");
        assert_eq!(footer.get("sha256").unwrap(), ORIGINAL_SHA256);
    }

    /// Encoding then decoding returns the original bytes, and the header carries the metadata
    /// added by `encode_to_cart`.
    #[test]
    fn cart_roundtrip() {
        const DATA: &[u8] = b"Malware DB CaRT round trip test data";

        let encoded = encode_to_cart(DATA).unwrap();
        assert_ne!(encoded.as_slice(), DATA);

        let (decoded, header, _footer) = decode_from_cart(&encoded).unwrap();
        assert_eq!(decoded, DATA);

        let header = header.unwrap();
        let expected_sha384 = hex::encode(Sha384::digest(DATA));
        let expected_sha512 = hex::encode(Sha512::digest(DATA));
        assert_eq!(header.get("sha384").unwrap(), expected_sha384.as_str());
        assert_eq!(header.get("sha512").unwrap(), expected_sha512.as_str());
        assert!(header.contains_key("entropy"));
    }
}