// Crate-level documentation is taken verbatim from the README.
#![doc = include_str!("../README.md")]
// Every public item must be documented; an undocumented public item is a compile error.
#![deny(missing_docs)]
// No `unsafe` code is allowed anywhere in this crate.
#![forbid(unsafe_code)]

// Serde `with` module used by `MdbClient::cert` to (de)serialize the optional
// `(Certificate, PathBuf)` pair.
mod option_cert_path_serialization;
8
9use malwaredb_lzjd::{LZDict, Murmur3HashState};
10use malwaredb_types::exec::pe32::EXE;
11use malwaredb_types::utils::entropy_calc;
12
13use std::fmt::{Debug, Formatter};
14use std::io::Cursor;
15use std::path::{Path, PathBuf};
16
17use anyhow::{bail, Context, Result};
18use base64::engine::general_purpose;
19use base64::Engine;
20use cart_container::JsonMap;
21use fuzzyhash::FuzzyHash;
22use home::home_dir;
23use reqwest::Certificate;
24use serde::{Deserialize, Serialize};
25use sha2::{Digest, Sha256, Sha384, Sha512};
26use tlsh_fixed::TlshBuilder;
27use tracing::{error, warn};
28use zeroize::{Zeroize, ZeroizeOnDrop};
29
/// File name of the client configuration that `MdbClient::load`, `save`, and `delete` use
/// inside the user's home directory.
const DOT_MDB_CLIENT_TOML: &str = ".mdb_client.toml";

/// Version of this client crate, taken from `CARGO_PKG_VERSION` at compile time.
pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
35
/// Client for a MalwareDB server.
///
/// Holds the server URL, the user's API key, and an optional extra trusted root certificate
/// that is added to every HTTP client this type builds. The URL and API key are zeroized when
/// the client is dropped (the certificate is skipped).
#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
pub struct MdbClient {
    /// Base URL of the MalwareDB server; `new` and `login` strip trailing `/` from it.
    pub url: String,

    /// API key sent in the `MDB_API_HEADER` header on authenticated requests.
    api_key: String,

    /// Optional root certificate together with the path it was loaded from.
    ///
    /// Not zeroized; serialized through `option_cert_path_serialization`.
    #[zeroize(skip)]
    #[serde(default, with = "option_cert_path_serialization")]
    cert: Option<(Certificate, PathBuf)>,
}
51
52impl MdbClient {
53 pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
55 let mut url = url;
56 let url = if url.ends_with('/') {
57 url.pop();
58 url
59 } else {
60 url
61 };
62
63 let cert = if let Some(path) = cert_path {
64 Some((path_load_cert(&path)?, path))
65 } else {
66 None
67 };
68
69 Ok(Self { url, api_key, cert })
70 }
71
72 #[inline]
74 fn client(&self) -> reqwest::Result<reqwest::Client> {
75 let builder = reqwest::ClientBuilder::new()
76 .gzip(true)
77 .zstd(true)
78 .use_rustls_tls()
79 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
80
81 if let Some(cert) = &self.cert {
82 builder.add_root_certificate(cert.0.clone()).build()
83 } else {
84 builder.build()
85 }
86 }
87
88 pub async fn login(
90 url: String,
91 username: String,
92 password: String,
93 save: bool,
94 cert_path: Option<PathBuf>,
95 ) -> Result<Self> {
96 let mut url = url;
97 let url = if url.ends_with('/') {
98 url.pop();
99 url
100 } else {
101 url
102 };
103
104 let api_request = malwaredb_api::GetAPIKeyRequest {
105 user: username,
106 password,
107 };
108
109 let builder = reqwest::ClientBuilder::new()
110 .gzip(true)
111 .zstd(true)
112 .use_rustls_tls()
113 .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
114
115 let cert = if let Some(path) = cert_path {
116 Some((path_load_cert(&path)?, path))
117 } else {
118 None
119 };
120
121 let client = if let Some(cert) = &cert {
122 builder.add_root_certificate(cert.0.clone()).build()
123 } else {
124 builder.build()
125 }?;
126
127 let res = client
128 .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
129 .json(&api_request)
130 .send()
131 .await?
132 .json::<malwaredb_api::GetAPIKeyResponse>()
133 .await?;
134
135 if let Some(key) = &res.key {
136 let client = MdbClient {
137 url,
138 api_key: key.clone(),
139 cert,
140 };
141
142 if save {
143 if let Err(e) = client.save() {
144 error!("Login successful but failed to save config: {e}");
145 bail!("Login successful but failed to save config: {e}");
146 }
147 }
148 Ok(client)
149 } else {
150 if let Some(msg) = &res.message {
151 error!("Login failed, response: {msg}");
152 }
153 bail!("server error or bad credentials");
154 }
155 }
156
157 pub async fn reset_key(&self) -> Result<()> {
159 let response = self
160 .client()?
161 .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
162 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
163 .send()
164 .await
165 .context("server error, or invalid API key")?;
166 if response.status().is_success() {
167 bail!("failed to reset API key, was it correct?");
168 }
169 Ok(())
170 }
171
172 pub fn from_file(path: &PathBuf) -> Result<Self> {
174 let config = std::fs::read_to_string(path)
175 .context(format!("failed to read config file {}", path.display()))?;
176 let cfg: MdbClient = toml::from_str(&config)
177 .context(format!("failed to parse config file {}", path.display()))?;
178 Ok(cfg)
179 }
180
181 pub fn load() -> Result<Self> {
183 let config = Path::new("mdb_client.toml");
184 if config.exists() {
185 return Self::from_file(&config.to_path_buf());
186 }
187
188 if let Some(mut home_config) = home_dir() {
189 home_config.push(DOT_MDB_CLIENT_TOML);
190 if home_config.exists() {
191 return Self::from_file(&home_config);
192 }
193 }
194 bail!("config file not found")
195 }
196
197 pub fn save(&self) -> Result<()> {
199 let toml = toml::to_string(self)?;
200 if let Some(mut home_config) = home_dir() {
201 home_config.push(DOT_MDB_CLIENT_TOML);
202 std::fs::write(&home_config, toml).context(format!(
203 "Unable to write config file at {}",
204 &home_config.display()
205 ))?;
206 return Ok(());
207 }
208
209 std::fs::write("mdb_client.toml", toml).context("failed to write mdb config")
210 }
211
212 pub fn delete(&self) -> Result<()> {
214 if let Some(mut home_config) = home_dir() {
215 home_config.push(DOT_MDB_CLIENT_TOML);
216 if home_config.exists() {
217 std::fs::remove_file(home_config)?;
218 }
219 }
220 Ok(())
221 }
222
223 pub async fn server_info(&self) -> Result<malwaredb_api::ServerInfo> {
227 self.client()?
228 .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO))
229 .send()
230 .await?
231 .json::<malwaredb_api::ServerInfo>()
232 .await
233 .context("failed to receive or decode server info")
234 }
235
236 pub async fn supported_types(&self) -> Result<malwaredb_api::SupportedFileTypes> {
238 self.client()?
239 .get(format!(
240 "{}{}",
241 self.url,
242 malwaredb_api::SUPPORTED_FILE_TYPES
243 ))
244 .send()
245 .await?
246 .json::<malwaredb_api::SupportedFileTypes>()
247 .await
248 .context("failed to receive or decode server-supported file types")
249 }
250
251 pub async fn whoami(&self) -> Result<malwaredb_api::GetUserInfoResponse> {
253 self.client()?
254 .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
255 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
256 .send()
257 .await?
258 .json::<malwaredb_api::GetUserInfoResponse>()
259 .await
260 .context("failed to receive or decode user info, or invalid API key")
261 }
262
263 pub async fn labels(&self) -> Result<malwaredb_api::Labels> {
265 self.client()?
266 .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS))
267 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
268 .send()
269 .await?
270 .json::<malwaredb_api::Labels>()
271 .await
272 .context("failed to receive or decode available labels, or invalid API key")
273 }
274
275 pub async fn sources(&self) -> Result<malwaredb_api::Sources> {
277 self.client()?
278 .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES))
279 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
280 .send()
281 .await?
282 .json::<malwaredb_api::Sources>()
283 .await
284 .context("failed to receive or decode available labels, or invalid API key")
285 }
286
287 pub async fn submit(
289 &self,
290 contents: impl AsRef<[u8]>,
291 file_name: &str,
292 source_id: u32,
293 ) -> Result<bool> {
294 let mut hasher = Sha256::new();
295 hasher.update(&contents);
296 let result = hasher.finalize();
297
298 let encoded = general_purpose::STANDARD.encode(contents);
299
300 let payload = malwaredb_api::NewSample {
301 file_name: file_name.to_string(),
302 source_id,
303 file_contents_b64: encoded,
304 sha256: hex::encode(result),
305 };
306
307 match self
308 .client()?
309 .post(format!("{}{}", self.url, malwaredb_api::UPLOAD_SAMPLE))
310 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
311 .json(&payload)
312 .send()
313 .await
314 {
315 Ok(res) => {
316 if !res.status().is_success() {
317 warn!("Code {} sending {file_name}", res.status());
318 }
319 Ok(res.status().is_success())
320 }
321 Err(e) => {
322 let status: String = e
323 .status()
324 .map(|s| s.as_str().to_string())
325 .unwrap_or_default();
326 error!("Error{status} sending {file_name}: {e}");
327 bail!(e.to_string())
328 }
329 }
330 }
331
332 pub async fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
334 let api_endpoint = if cart {
335 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART)
336 } else {
337 format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE)
338 };
339
340 let res = self
341 .client()?
342 .get(format!("{}{api_endpoint}", self.url))
343 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
344 .send()
345 .await?;
346
347 if !res.status().is_success() {
348 bail!("Received code {}", res.status());
349 }
350
351 let body = res.bytes().await?;
352 Ok(body.to_vec())
353 }
354
355 pub async fn report(&self, hash: &str) -> Result<malwaredb_api::Report> {
357 self.client()?
358 .get(format!(
359 "{}{}/{hash}",
360 self.url,
361 malwaredb_api::SAMPLE_REPORT
362 ))
363 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
364 .send()
365 .await?
366 .json::<malwaredb_api::Report>()
367 .await
368 .context("failed to receive or decode sample report, or invalid API key")
369 }
370
371 pub async fn similar(&self, contents: &[u8]) -> Result<malwaredb_api::SimilarSamplesResponse> {
374 let mut hashes = vec![];
375 let ssdeep_hash = FuzzyHash::new(contents);
376
377 let build_hasher = Murmur3HashState::default();
378 let lzjd_str =
379 LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
380 hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
381 hashes.push((
382 malwaredb_api::SimilarityHashType::SSDeep,
383 ssdeep_hash.to_string(),
384 ));
385
386 let mut builder = TlshBuilder::new(
387 tlsh_fixed::BucketKind::Bucket256,
388 tlsh_fixed::ChecksumKind::ThreeByte,
389 tlsh_fixed::Version::Version4,
390 );
391
392 builder.update(contents);
393 if let Ok(hasher) = builder.build() {
394 hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
395 }
396
397 if let Ok(exe) = EXE::from(contents) {
398 if let Some(imports) = exe.imports {
399 hashes.push((
400 malwaredb_api::SimilarityHashType::ImportHash,
401 hex::encode(imports.hash()),
402 ));
403 hashes.push((
404 malwaredb_api::SimilarityHashType::FuzzyImportHash,
405 imports.fuzzy_hash(),
406 ));
407 }
408 }
409
410 let request = malwaredb_api::SimilarSamplesRequest { hashes };
411
412 self.client()?
413 .post(format!("{}{}", self.url, malwaredb_api::SIMILAR_SAMPLES))
414 .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
415 .json(&request)
416 .send()
417 .await?
418 .json::<malwaredb_api::SimilarSamplesResponse>()
419 .await
420 .context("failed to receive or decode similarity response, or invalid API key")
421 }
422}
423
424impl Debug for MdbClient {
425 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
426 writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
427 }
428}
429
430pub fn encode_to_cart(data: &[u8]) -> Result<Vec<u8>> {
434 let mut input_buffer = Cursor::new(data);
435 let mut output_buffer = Cursor::new(vec![]);
436 let mut output_metadata = JsonMap::new();
437
438 let mut sha384 = Sha384::new();
439 sha384.update(data);
440 let sha384 = hex::encode(sha384.finalize());
441
442 let mut sha512 = Sha512::new();
443 sha512.update(data);
444 let sha512 = hex::encode(sha512.finalize());
445
446 output_metadata.insert("sha384".into(), sha384.into());
447 output_metadata.insert("sha512".into(), sha512.into());
448 output_metadata.insert("entropy".into(), entropy_calc(data).into());
449 cart_container::pack_stream(
450 &mut input_buffer,
451 &mut output_buffer,
452 Some(output_metadata),
453 None,
454 cart_container::digesters::default_digesters(),
455 None,
456 )?;
457
458 Ok(output_buffer.into_inner())
459}
460
461pub fn decode_from_cart(data: &[u8]) -> Result<(Vec<u8>, Option<JsonMap>, Option<JsonMap>)> {
465 let mut input_buffer = Cursor::new(data);
466 let mut output_buffer = Cursor::new(vec![]);
467 let (header, footer) =
468 cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)?;
469 Ok((output_buffer.into_inner(), header, footer))
470}
471
472pub fn path_load_cert(path: &Path) -> Result<Certificate> {
474 if !path.exists() {
475 bail!("Certificate {path:?} does not exist.");
476 }
477 let cert = match path
478 .extension()
479 .expect("can't determine file extension")
480 .to_str()
481 .expect("unable to parse file extension")
482 {
483 "pem" => {
484 let contents = std::fs::read(path)?;
485 Certificate::from_pem(&contents)?
486 }
487 "der" => {
488 let contents = std::fs::read(path)?;
489 Certificate::from_der(&contents)?
490 }
491 ext => {
492 bail!("Unknown extension {ext:?}")
493 }
494 };
495 Ok(cert)
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Decoding a known CaRT file must reproduce the original bytes and the expected metadata.
    #[test]
    fn cart() {
        const BYTES: &[u8] = include_bytes!("../../crates/types/testdata/elf/elf_haiku_x86.cart");
        const ORIGINAL_SHA256: &str =
            "de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740";

        let (payload, header_meta, footer_meta) =
            decode_from_cart(BYTES).expect("test CaRT file should decode");

        let payload_sha256 = hex::encode(Sha256::digest(&payload));
        assert_eq!(payload_sha256, ORIGINAL_SHA256);

        let header_meta = header_meta.expect("CaRT header metadata should be present");
        let entropy = header_meta
            .get("entropy")
            .expect("header should contain entropy")
            .as_f64()
            .expect("entropy should be numeric");
        assert!(entropy > 4.0 && entropy < 4.1, "unexpected entropy {entropy}");

        let footer_meta = footer_meta.expect("CaRT footer metadata should be present");
        assert_eq!(footer_meta.get("length").unwrap(), "5093");
        assert_eq!(footer_meta.get("sha256").unwrap(), ORIGINAL_SHA256);
    }
}