malwaredb_client/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![deny(missing_docs)]
5#![forbid(unsafe_code)]
6
7mod option_cert_path_serialization;
8
9use malwaredb_lzjd::{LZDict, Murmur3HashState};
10use malwaredb_types::exec::pe32::EXE;
11use malwaredb_types::utils::entropy_calc;
12
13use std::fmt::{Debug, Formatter};
14use std::io::Cursor;
15use std::path::{Path, PathBuf};
16
17use anyhow::{bail, Context, Result};
18use base64::engine::general_purpose;
19use base64::Engine;
20use cart_container::JsonMap;
21use fuzzyhash::FuzzyHash;
22use home::home_dir;
23use reqwest::Certificate;
24use serde::{Deserialize, Serialize};
25use sha2::{Digest, Sha256, Sha384, Sha512};
26use tlsh_fixed::TlshBuilder;
27use tracing::{error, warn};
28use zeroize::{Zeroize, ZeroizeOnDrop};
29
30/// Config file name expected by MalwareDB Client
31const DOT_MDB_CLIENT_TOML: &str = ".mdb_client.toml";
32
33/// MDB version
34pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
35
36/// MalwareDB Client Configuration and connection
37#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
38pub struct MdbClient {
39    /// URL of MalwareDB, including http and port number, ending without a slash
40    pub url: String,
41
42    /// User's API key for MalwareDB
43    api_key: String,
44
45    /// Certificate and Path, if needed
46    /// The path is serialized; deserialization loads & parses the certificate file specified.
47    #[zeroize(skip)]
48    #[serde(default, with = "option_cert_path_serialization")]
49    cert: Option<(Certificate, PathBuf)>,
50}
51
52impl MdbClient {
53    /// MDB Client from components, doesn't test connectivity
54    pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
55        let mut url = url;
56        let url = if url.ends_with('/') {
57            url.pop();
58            url
59        } else {
60            url
61        };
62
63        let cert = if let Some(path) = cert_path {
64            Some((path_load_cert(&path)?, path))
65        } else {
66            None
67        };
68
69        Ok(Self { url, api_key, cert })
70    }
71
72    /// Generate a client which already knows to send the API key, and asks for gzip responses.
73    #[inline]
74    fn client(&self) -> reqwest::Result<reqwest::Client> {
75        let builder = reqwest::ClientBuilder::new()
76            .gzip(true)
77            .zstd(true)
78            .use_rustls_tls()
79            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
80
81        if let Some(cert) = &self.cert {
82            builder.add_root_certificate(cert.0.clone()).build()
83        } else {
84            builder.build()
85        }
86    }
87
88    /// Login to a server, optionally save the config file, and return a client object
89    pub async fn login(
90        url: String,
91        username: String,
92        password: String,
93        save: bool,
94        cert_path: Option<PathBuf>,
95    ) -> Result<Self> {
96        let mut url = url;
97        let url = if url.ends_with('/') {
98            url.pop();
99            url
100        } else {
101            url
102        };
103
104        let api_request = malwaredb_api::GetAPIKeyRequest {
105            user: username,
106            password,
107        };
108
109        let builder = reqwest::ClientBuilder::new()
110            .gzip(true)
111            .zstd(true)
112            .use_rustls_tls()
113            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
114
115        let cert = if let Some(path) = cert_path {
116            Some((path_load_cert(&path)?, path))
117        } else {
118            None
119        };
120
121        let client = if let Some(cert) = &cert {
122            builder.add_root_certificate(cert.0.clone()).build()
123        } else {
124            builder.build()
125        }?;
126
127        let res = client
128            .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
129            .json(&api_request)
130            .send()
131            .await?
132            .json::<malwaredb_api::GetAPIKeyResponse>()
133            .await?;
134
135        if let Some(key) = &res.key {
136            let client = MdbClient {
137                url,
138                api_key: key.clone(),
139                cert,
140            };
141
142            if save {
143                if let Err(e) = client.save() {
144                    error!("Login successful but failed to save config: {e}");
145                    bail!("Login successful but failed to save config: {e}");
146                }
147            }
148            Ok(client)
149        } else {
150            if let Some(msg) = &res.message {
151                error!("Login failed, response: {msg}");
152            }
153            bail!("server error or bad credentials");
154        }
155    }
156
157    /// Reset one's own API key to effectively logout & disable all clients who are using the key
158    pub async fn reset_key(&self) -> Result<()> {
159        let response = self
160            .client()?
161            .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
162            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
163            .send()
164            .await
165            .context("server error, or invalid API key")?;
166        if response.status().is_success() {
167            bail!("failed to reset API key, was it correct?");
168        }
169        Ok(())
170    }
171
172    /// MDB Client loaded from a specified path
173    pub fn from_file(path: &PathBuf) -> Result<Self> {
174        let config = std::fs::read_to_string(path)
175            .context(format!("failed to read config file {}", path.display()))?;
176        let cfg: MdbClient = toml::from_str(&config)
177            .context(format!("failed to parse config file {}", path.display()))?;
178        Ok(cfg)
179    }
180
181    /// MDB Client from user's home directory
182    pub fn load() -> Result<Self> {
183        let config = Path::new("mdb_client.toml");
184        if config.exists() {
185            return Self::from_file(&config.to_path_buf());
186        }
187
188        if let Some(mut home_config) = home_dir() {
189            home_config.push(DOT_MDB_CLIENT_TOML);
190            if home_config.exists() {
191                return Self::from_file(&home_config);
192            }
193        }
194        bail!("config file not found")
195    }
196
197    /// Save MDB Client to the user's home directory
198    pub fn save(&self) -> Result<()> {
199        let toml = toml::to_string(self)?;
200        if let Some(mut home_config) = home_dir() {
201            home_config.push(DOT_MDB_CLIENT_TOML);
202            std::fs::write(&home_config, toml).context(format!(
203                "Unable to write config file at {}",
204                &home_config.display()
205            ))?;
206            return Ok(());
207        }
208
209        std::fs::write("mdb_client.toml", toml).context("failed to write mdb config")
210    }
211
212    /// Delete the MalwareDB client config file
213    pub fn delete(&self) -> Result<()> {
214        if let Some(mut home_config) = home_dir() {
215            home_config.push(DOT_MDB_CLIENT_TOML);
216            if home_config.exists() {
217                std::fs::remove_file(home_config)?;
218            }
219        }
220        Ok(())
221    }
222
223    // Actions of the client
224
225    /// Get information about the server, unauthenticated
226    pub async fn server_info(&self) -> Result<malwaredb_api::ServerInfo> {
227        self.client()?
228            .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO))
229            .send()
230            .await?
231            .json::<malwaredb_api::ServerInfo>()
232            .await
233            .context("failed to receive or decode server info")
234    }
235
236    /// Get file types supported by the server, unauthenticated
237    pub async fn supported_types(&self) -> Result<malwaredb_api::SupportedFileTypes> {
238        self.client()?
239            .get(format!(
240                "{}{}",
241                self.url,
242                malwaredb_api::SUPPORTED_FILE_TYPES
243            ))
244            .send()
245            .await?
246            .json::<malwaredb_api::SupportedFileTypes>()
247            .await
248            .context("failed to receive or decode server-supported file types")
249    }
250
251    /// Get information about the user
252    pub async fn whoami(&self) -> Result<malwaredb_api::GetUserInfoResponse> {
253        self.client()?
254            .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
255            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
256            .send()
257            .await?
258            .json::<malwaredb_api::GetUserInfoResponse>()
259            .await
260            .context("failed to receive or decode user info, or invalid API key")
261    }
262
263    /// Get the sample labels known to the server
264    pub async fn labels(&self) -> Result<malwaredb_api::Labels> {
265        self.client()?
266            .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS))
267            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
268            .send()
269            .await?
270            .json::<malwaredb_api::Labels>()
271            .await
272            .context("failed to receive or decode available labels, or invalid API key")
273    }
274
275    /// Get the sources available to the current user
276    pub async fn sources(&self) -> Result<malwaredb_api::Sources> {
277        self.client()?
278            .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES))
279            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
280            .send()
281            .await?
282            .json::<malwaredb_api::Sources>()
283            .await
284            .context("failed to receive or decode available labels, or invalid API key")
285    }
286
287    /// Submit one file to MalwareDB: provide the contents, file name, and source ID
288    pub async fn submit(
289        &self,
290        contents: impl AsRef<[u8]>,
291        file_name: &str,
292        source_id: u32,
293    ) -> Result<bool> {
294        let mut hasher = Sha256::new();
295        hasher.update(&contents);
296        let result = hasher.finalize();
297
298        let encoded = general_purpose::STANDARD.encode(contents);
299
300        let payload = malwaredb_api::NewSample {
301            file_name: file_name.to_string(),
302            source_id,
303            file_contents_b64: encoded,
304            sha256: hex::encode(result),
305        };
306
307        match self
308            .client()?
309            .post(format!("{}{}", self.url, malwaredb_api::UPLOAD_SAMPLE))
310            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
311            .json(&payload)
312            .send()
313            .await
314        {
315            Ok(res) => {
316                if !res.status().is_success() {
317                    warn!("Code {} sending {file_name}", res.status());
318                }
319                Ok(res.status().is_success())
320            }
321            Err(e) => {
322                let status: String = e
323                    .status()
324                    .map(|s| s.as_str().to_string())
325                    .unwrap_or_default();
326                error!("Error{status} sending {file_name}: {e}");
327                bail!(e.to_string())
328            }
329        }
330    }
331
332    /// Retrieve sample by hash, optionally in the CaRT format
333    pub async fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
334        let api_endpoint = if cart {
335            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART)
336        } else {
337            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE)
338        };
339
340        let res = self
341            .client()?
342            .get(format!("{}{api_endpoint}", self.url))
343            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
344            .send()
345            .await?;
346
347        if !res.status().is_success() {
348            bail!("Received code {}", res.status());
349        }
350
351        let body = res.bytes().await?;
352        Ok(body.to_vec())
353    }
354
355    /// Fetch a report for a sample
356    pub async fn report(&self, hash: &str) -> Result<malwaredb_api::Report> {
357        self.client()?
358            .get(format!(
359                "{}{}/{hash}",
360                self.url,
361                malwaredb_api::SAMPLE_REPORT
362            ))
363            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
364            .send()
365            .await?
366            .json::<malwaredb_api::Report>()
367            .await
368            .context("failed to receive or decode sample report, or invalid API key")
369    }
370
371    /// Find similar samples in MalwareDB based on the contents of a given file.
372    /// This does not submit the sample to MalwareDB.
373    pub async fn similar(&self, contents: &[u8]) -> Result<malwaredb_api::SimilarSamplesResponse> {
374        let mut hashes = vec![];
375        let ssdeep_hash = FuzzyHash::new(contents);
376
377        let build_hasher = Murmur3HashState::default();
378        let lzjd_str =
379            LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
380        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
381        hashes.push((
382            malwaredb_api::SimilarityHashType::SSDeep,
383            ssdeep_hash.to_string(),
384        ));
385
386        let mut builder = TlshBuilder::new(
387            tlsh_fixed::BucketKind::Bucket256,
388            tlsh_fixed::ChecksumKind::ThreeByte,
389            tlsh_fixed::Version::Version4,
390        );
391
392        builder.update(contents);
393        if let Ok(hasher) = builder.build() {
394            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
395        }
396
397        if let Ok(exe) = EXE::from(contents) {
398            if let Some(imports) = exe.imports {
399                hashes.push((
400                    malwaredb_api::SimilarityHashType::ImportHash,
401                    hex::encode(imports.hash()),
402                ));
403                hashes.push((
404                    malwaredb_api::SimilarityHashType::FuzzyImportHash,
405                    imports.fuzzy_hash(),
406                ));
407            }
408        }
409
410        let request = malwaredb_api::SimilarSamplesRequest { hashes };
411
412        self.client()?
413            .post(format!("{}{}", self.url, malwaredb_api::SIMILAR_SAMPLES))
414            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
415            .json(&request)
416            .send()
417            .await?
418            .json::<malwaredb_api::SimilarSamplesResponse>()
419            .await
420            .context("failed to receive or decode similarity response, or invalid API key")
421    }
422}
423
424impl Debug for MdbClient {
425    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
426        writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
427    }
428}
429
430/// Convenience function for encoding bytes into a CaRT file using the default key. This also
431/// adds SHA-384 and SHA-512 hashes plus the entropy of the original file.
432/// See https://github.com/CybercentreCanada/cart for more information.
433pub fn encode_to_cart(data: &[u8]) -> Result<Vec<u8>> {
434    let mut input_buffer = Cursor::new(data);
435    let mut output_buffer = Cursor::new(vec![]);
436    let mut output_metadata = JsonMap::new();
437
438    let mut sha384 = Sha384::new();
439    sha384.update(data);
440    let sha384 = hex::encode(sha384.finalize());
441
442    let mut sha512 = Sha512::new();
443    sha512.update(data);
444    let sha512 = hex::encode(sha512.finalize());
445
446    output_metadata.insert("sha384".into(), sha384.into());
447    output_metadata.insert("sha512".into(), sha512.into());
448    output_metadata.insert("entropy".into(), entropy_calc(data).into());
449    cart_container::pack_stream(
450        &mut input_buffer,
451        &mut output_buffer,
452        Some(output_metadata),
453        None,
454        cart_container::digesters::default_digesters(),
455        None,
456    )?;
457
458    Ok(output_buffer.into_inner())
459}
460
461/// Convenience function for decoding a CaRT file using the default key, returning the bytes plus the
462/// optional header & footer metadata, if present.
463/// See https://github.com/CybercentreCanada/cart for more information.
464pub fn decode_from_cart(data: &[u8]) -> Result<(Vec<u8>, Option<JsonMap>, Option<JsonMap>)> {
465    let mut input_buffer = Cursor::new(data);
466    let mut output_buffer = Cursor::new(vec![]);
467    let (header, footer) =
468        cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)?;
469    Ok((output_buffer.into_inner(), header, footer))
470}
471
472/// Load a certificate from a path
473pub fn path_load_cert(path: &Path) -> Result<Certificate> {
474    if !path.exists() {
475        bail!("Certificate {path:?} does not exist.");
476    }
477    let cert = match path
478        .extension()
479        .expect("can't determine file extension")
480        .to_str()
481        .expect("unable to parse file extension")
482    {
483        "pem" => {
484            let contents = std::fs::read(path)?;
485            Certificate::from_pem(&contents)?
486        }
487        "der" => {
488            let contents = std::fs::read(path)?;
489            Certificate::from_der(&contents)?
490        }
491        ext => {
492            bail!("Unknown extension {ext:?}")
493        }
494    };
495    Ok(cert)
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Decoding a known CaRT fixture must reproduce the original bytes and expose the header
    /// and footer metadata written when it was packed.
    #[test]
    fn cart() {
        const BYTES: &[u8] = include_bytes!("../../crates/types/testdata/elf/elf_haiku_x86.cart");
        const ORIGINAL_SHA256: &str =
            "de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740";

        let (decoded, header, footer) = decode_from_cart(BYTES).unwrap();
        assert_eq!(hex::encode(Sha256::digest(&decoded)), ORIGINAL_SHA256);

        let entropy = header
            .unwrap()
            .get("entropy")
            .unwrap()
            .as_f64()
            .unwrap();
        assert!(entropy > 4.0 && entropy < 4.1);

        let footer = footer.unwrap();
        assert_eq!(footer.get("length").unwrap(), "5093");
        assert_eq!(footer.get("sha256").unwrap(), ORIGINAL_SHA256);
    }
}