1#[cfg(any(test, feature = "admin"))]
10mod admin;
11mod pg;
13
14#[cfg(any(test, feature = "sqlite"))]
16mod sqlite;
17
18#[cfg(any(test, feature = "sqlite"))]
20mod sqlite_functions;
21
22pub mod types;
24
25#[cfg(any(test, feature = "admin"))]
26use crate::crypto::EncryptionOption;
27use crate::crypto::FileEncryption;
28use crate::db::pg::Postgres;
29#[cfg(any(test, feature = "sqlite"))]
30use crate::db::sqlite::Sqlite;
31use crate::db::types::{FileMetadata, FileType};
32use malwaredb_api::{
33 digest::HashType, GetUserInfoResponse, Labels, SearchRequest, SearchResponse, Sources,
34};
35use malwaredb_types::KnownType;
36
37use std::collections::HashMap;
38use std::path::PathBuf;
39
40use anyhow::{bail, ensure, Result};
41use argon2::password_hash::{rand_core::OsRng, SaltString};
42use argon2::{Argon2, PasswordHasher};
43#[cfg(any(test, feature = "admin"))]
44use chrono::Local;
45#[cfg(feature = "vt")]
46use malwaredb_virustotal::filereport::ScanResultAttributes;
47
/// Numeric limit associated with partial searches.
///
/// NOTE(review): presumably the maximum number of results a partial search
/// may return; the code that consumes this constant lives in the backend
/// modules, so confirm the exact meaning there.
pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
50
51#[derive(Debug)]
53pub enum DatabaseType {
54 Postgres(Postgres),
56
57 #[cfg(any(test, feature = "sqlite"))]
59 SQLite(Sqlite),
60}
61
/// Summary information about the database, as returned by
/// `DatabaseType::db_info`.
#[derive(Debug)]
pub struct DatabaseInformation {
    /// Version string reported by the backend.
    pub version: String,

    /// Size of the database, already rendered as text by the backend.
    pub size: String,

    /// Number of files recorded in the database.
    pub num_files: u64,

    /// Number of user accounts.
    pub num_users: u32,

    /// Number of groups.
    pub num_groups: u32,

    /// Number of sources.
    pub num_sources: u32,
}
83
/// Outcome of `DatabaseType::add_file`.
///
/// Derives `Debug` (like the other public types in this module) so results
/// can be logged and used in assertion messages; the cheap value-type traits
/// are derived as well since both fields are plain `Copy` data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FileAddedResult {
    /// Identifier of the file in the database.
    pub file_id: u64,

    /// Whether the file was new to the database.
    ///
    /// NOTE(review): presumably `false` means the file already existed and
    /// only an association was recorded; confirm against the backends.
    pub is_new: bool,
}
93
/// Instance-wide configuration, as returned by `DatabaseType::get_config`.
#[derive(Debug)]
pub struct MDBConfig {
    /// Name of this instance (see `DatabaseType::set_name`).
    pub name: String,

    /// Whether compression is enabled (see
    /// `DatabaseType::enable_compression` / `disable_compression`).
    pub compression: bool,

    /// Whether samples are sent to VirusTotal (see
    /// `DatabaseType::enable_vt_upload` / `disable_vt_upload`).
    pub send_samples_to_vt: bool,

    /// Whether files of unknown type are kept (see
    /// `DatabaseType::enable_keep_unknown_files` / `disable_keep_unknown_files`).
    pub keep_unknown_files: bool,

    /// Identifier of the default encryption key, if one is set.
    ///
    /// NOTE(review): presumably an ID from `DatabaseType::get_encryption_keys`;
    /// confirm against the backends.
    pub(crate) default_key: Option<u32>,
}
112
/// VirusTotal record counters, as returned by `DatabaseType::get_vt_stats`.
///
/// Only available when the `vt` feature is enabled.
#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Number of records counted as clean.
    pub clean_records: u32,

    /// Number of records counted as having hits.
    pub hits_records: u32,

    /// Number of files that have no VirusTotal record.
    pub files_without_records: u32,
}
127
128impl DatabaseType {
129 pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
138 #[cfg(any(test, feature = "sqlite"))]
139 if arg.starts_with("file:") {
140 let new_conn_str = arg.trim_start_matches("file:");
141 return Ok(DatabaseType::SQLite(Sqlite::new(new_conn_str)?));
142 }
143
144 if arg.starts_with("postgres") {
145 let new_conn_str = arg.trim_start_matches("postgres");
146 return Ok(DatabaseType::Postgres(
147 Postgres::new(new_conn_str, server_ca).await?,
148 ));
149 }
150
151 bail!("unknown database type `{arg}`")
152 }
153
154 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
160 #[cfg(feature = "vt")]
161 pub async fn enable_vt_upload(&self) -> Result<()> {
162 match self {
163 DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
164 #[cfg(any(test, feature = "sqlite"))]
165 DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
166 }
167 }
168
169 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
175 #[cfg(feature = "vt")]
176 pub async fn disable_vt_upload(&self) -> Result<()> {
177 match self {
178 DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
179 #[cfg(any(test, feature = "sqlite"))]
180 DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
181 }
182 }
183
184 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
190 #[cfg(feature = "vt")]
191 pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
192 match self {
193 DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
194 #[cfg(any(test, feature = "sqlite"))]
195 DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
196 }
197 }
198
199 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
205 #[cfg(feature = "vt")]
206 pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
207 match self {
208 DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
209 #[cfg(any(test, feature = "sqlite"))]
210 DatabaseType::SQLite(sl) => sl.store_vt_record(results),
211 }
212 }
213
214 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
220 #[cfg(feature = "vt")]
221 pub async fn get_vt_stats(&self) -> Result<VtStats> {
222 match self {
223 DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
224 #[cfg(any(test, feature = "sqlite"))]
225 DatabaseType::SQLite(sl) => sl.get_vt_stats(),
226 }
227 }
228
229 pub async fn get_config(&self) -> Result<MDBConfig> {
235 match self {
236 DatabaseType::Postgres(pg) => pg.get_config().await,
237 #[cfg(any(test, feature = "sqlite"))]
238 DatabaseType::SQLite(sl) => sl.get_config(),
239 }
240 }
241
242 pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
249 match self {
250 DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
251 #[cfg(any(test, feature = "sqlite"))]
252 DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
253 }
254 }
255
256 pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
263 ensure!(!apikey.is_empty(), "API key was empty");
264 match self {
265 DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
266 #[cfg(any(test, feature = "sqlite"))]
267 DatabaseType::SQLite(sl) => sl.get_uid(apikey),
268 }
269 }
270
271 pub async fn db_info(&self) -> Result<DatabaseInformation> {
277 match self {
278 DatabaseType::Postgres(pg) => pg.db_info().await,
279 #[cfg(any(test, feature = "sqlite"))]
280 DatabaseType::SQLite(sl) => sl.db_info(),
281 }
282 }
283
284 pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
291 match self {
292 DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
293 #[cfg(any(test, feature = "sqlite"))]
294 DatabaseType::SQLite(sl) => sl.get_user_info(uid),
295 }
296 }
297
298 pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
305 match self {
306 DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
307 #[cfg(any(test, feature = "sqlite"))]
308 DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
309 }
310 }
311
312 pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
319 match self {
320 DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
321 #[cfg(any(test, feature = "sqlite"))]
322 DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
323 }
324 }
325
326 pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
332 match self {
333 DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
334 #[cfg(any(test, feature = "sqlite"))]
335 DatabaseType::SQLite(sl) => sl.get_known_data_types(),
336 }
337 }
338
339 pub async fn get_labels(&self) -> Result<Labels> {
345 match self {
346 DatabaseType::Postgres(pg) => pg.get_labels().await,
347 #[cfg(any(test, feature = "sqlite"))]
348 DatabaseType::SQLite(sl) => sl.get_labels(),
349 }
350 }
351
352 pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
358 match self {
359 DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
360 #[cfg(any(test, feature = "sqlite"))]
361 DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
362 }
363 }
364
365 pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
372 match self {
373 DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
374 #[cfg(any(test, feature = "sqlite"))]
375 DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
376 }
377 }
378
379 pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
387 match self {
388 DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
389 #[cfg(any(test, feature = "sqlite"))]
390 DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
391 }
392 }
393
394 pub async fn add_file(
401 &self,
402 meta: &FileMetadata,
403 known_type: KnownType<'_>,
404 uid: u32,
405 sid: u32,
406 ftype: u32,
407 parent: Option<u64>,
408 ) -> Result<FileAddedResult> {
409 match self {
410 DatabaseType::Postgres(pg) => {
411 pg.add_file(meta, known_type, uid, sid, ftype, parent).await
412 }
413 #[cfg(any(test, feature = "sqlite"))]
414 DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
415 }
416 }
417
418 pub async fn partial_search(&self, uid: u32, search: SearchRequest) -> Result<SearchResponse> {
424 match self {
425 DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
426 #[cfg(any(test, feature = "sqlite"))]
427 DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
428 }
429 }
430
431 pub async fn cleanup(&self) -> Result<u64> {
437 match self {
438 DatabaseType::Postgres(pg) => pg.cleanup().await,
439 #[cfg(any(test, feature = "sqlite"))]
440 DatabaseType::SQLite(sl) => sl.cleanup(),
441 }
442 }
443
444 pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
452 match self {
453 DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
454 #[cfg(any(test, feature = "sqlite"))]
455 DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
456 }
457 }
458
459 pub async fn get_sample_report(
466 &self,
467 uid: u32,
468 hash: &HashType,
469 ) -> Result<malwaredb_api::Report> {
470 match self {
471 DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
472 #[cfg(any(test, feature = "sqlite"))]
473 DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
474 }
475 }
476
477 pub async fn find_similar_samples(
483 &self,
484 uid: u32,
485 sim: &[(malwaredb_api::SimilarityHashType, String)],
486 ) -> Result<Vec<malwaredb_api::SimilarSample>> {
487 match self {
488 DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
489 #[cfg(any(test, feature = "sqlite"))]
490 DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
491 }
492 }
493
494 pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
498 match self {
499 DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
500 #[cfg(any(test, feature = "sqlite"))]
501 DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
502 }
503 }
504
505 pub(crate) async fn get_file_encryption_key_id(
507 &self,
508 hash: &str,
509 ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
510 match self {
511 DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
512 #[cfg(any(test, feature = "sqlite"))]
513 DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
514 }
515 }
516
517 pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
519 match self {
520 DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
521 #[cfg(any(test, feature = "sqlite"))]
522 DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
523 }
524 }
525
526 #[cfg(any(test, feature = "admin"))]
534 pub async fn set_name(&self, name: &str) -> Result<()> {
535 match self {
536 DatabaseType::Postgres(pg) => pg.set_name(name).await,
537 #[cfg(any(test, feature = "sqlite"))]
538 DatabaseType::SQLite(sl) => sl.set_name(name),
539 }
540 }
541
542 #[cfg(any(test, feature = "admin"))]
548 pub async fn enable_compression(&self) -> Result<()> {
549 match self {
550 DatabaseType::Postgres(pg) => pg.enable_compression().await,
551 #[cfg(any(test, feature = "sqlite"))]
552 DatabaseType::SQLite(sl) => sl.enable_compression(),
553 }
554 }
555
556 #[cfg(any(test, feature = "admin"))]
562 pub async fn disable_compression(&self) -> Result<()> {
563 match self {
564 DatabaseType::Postgres(pg) => pg.disable_compression().await,
565 #[cfg(any(test, feature = "sqlite"))]
566 DatabaseType::SQLite(sl) => sl.disable_compression(),
567 }
568 }
569
570 #[cfg(any(test, feature = "admin"))]
576 pub async fn enable_keep_unknown_files(&self) -> Result<()> {
577 match self {
578 DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
579 #[cfg(any(test, feature = "sqlite"))]
580 DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
581 }
582 }
583
584 #[cfg(any(test, feature = "admin"))]
590 pub async fn disable_keep_unknown_files(&self) -> Result<()> {
591 match self {
592 DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
593 #[cfg(any(test, feature = "sqlite"))]
594 DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
595 }
596 }
597
598 #[cfg(any(test, feature = "admin"))]
604 pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
605 match self {
606 DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
607 #[cfg(any(test, feature = "sqlite"))]
608 DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
609 }
610 }
611
612 #[cfg(any(test, feature = "admin"))]
618 pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
619 match self {
620 DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
621 #[cfg(any(test, feature = "sqlite"))]
622 DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
623 }
624 }
625
626 #[allow(clippy::too_many_arguments)]
632 #[cfg(any(test, feature = "admin"))]
633 pub async fn create_user(
634 &self,
635 uname: &str,
636 fname: &str,
637 lname: &str,
638 email: &str,
639 password: Option<String>,
640 organisation: Option<&String>,
641 readonly: bool,
642 ) -> Result<u32> {
643 match self {
644 DatabaseType::Postgres(pg) => {
645 pg.create_user(uname, fname, lname, email, password, organisation, readonly)
646 .await
647 }
648 #[cfg(any(test, feature = "sqlite"))]
649 DatabaseType::SQLite(sl) => {
650 sl.create_user(uname, fname, lname, email, password, organisation, readonly)
651 }
652 }
653 }
654
655 #[cfg(any(test, feature = "admin"))]
661 pub async fn reset_api_keys(&self) -> Result<u64> {
662 match self {
663 DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
664 #[cfg(any(test, feature = "sqlite"))]
665 DatabaseType::SQLite(sl) => sl.reset_api_keys(),
666 }
667 }
668
669 #[cfg(any(test, feature = "admin"))]
676 pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
677 match self {
678 DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
679 #[cfg(any(test, feature = "sqlite"))]
680 DatabaseType::SQLite(sl) => sl.set_password(uname, password),
681 }
682 }
683
684 #[cfg(any(test, feature = "admin"))]
690 pub async fn list_users(&self) -> Result<Vec<admin::User>> {
691 match self {
692 DatabaseType::Postgres(pg) => pg.list_users().await,
693 #[cfg(any(test, feature = "sqlite"))]
694 DatabaseType::SQLite(sl) => sl.list_users(),
695 }
696 }
697
698 #[cfg(any(test, feature = "admin"))]
705 pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
706 match self {
707 DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
708 #[cfg(any(test, feature = "sqlite"))]
709 DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
710 }
711 }
712
713 #[cfg(any(test, feature = "admin"))]
720 pub async fn edit_group(
721 &self,
722 gid: u32,
723 name: &str,
724 desc: &str,
725 parent: Option<u32>,
726 ) -> Result<()> {
727 match self {
728 DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
729 #[cfg(any(test, feature = "sqlite"))]
730 DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
731 }
732 }
733
734 #[cfg(any(test, feature = "admin"))]
740 pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
741 match self {
742 DatabaseType::Postgres(pg) => pg.list_groups().await,
743 #[cfg(any(test, feature = "sqlite"))]
744 DatabaseType::SQLite(sl) => sl.list_groups(),
745 }
746 }
747
748 #[cfg(any(test, feature = "admin"))]
755 pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
756 match self {
757 DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
758 #[cfg(any(test, feature = "sqlite"))]
759 DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
760 }
761 }
762
763 #[cfg(any(test, feature = "admin"))]
770 pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
771 match self {
772 DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
773 #[cfg(any(test, feature = "sqlite"))]
774 DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
775 }
776 }
777
778 #[cfg(any(test, feature = "admin"))]
785 pub async fn create_group(
786 &self,
787 name: &str,
788 description: &str,
789 parent: Option<u32>,
790 ) -> Result<u32> {
791 match self {
792 DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
793 #[cfg(any(test, feature = "sqlite"))]
794 DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
795 }
796 }
797
798 #[cfg(any(test, feature = "admin"))]
804 pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
805 match self {
806 DatabaseType::Postgres(pg) => pg.list_sources().await,
807 #[cfg(any(test, feature = "sqlite"))]
808 DatabaseType::SQLite(sl) => sl.list_sources(),
809 }
810 }
811
812 #[cfg(any(test, feature = "admin"))]
819 pub async fn create_source(
820 &self,
821 name: &str,
822 description: Option<&str>,
823 url: Option<&str>,
824 date: chrono::DateTime<Local>,
825 releasable: bool,
826 malicious: Option<bool>,
827 ) -> Result<u32> {
828 match self {
829 DatabaseType::Postgres(pg) => {
830 pg.create_source(name, description, url, date, releasable, malicious)
831 .await
832 }
833 #[cfg(any(test, feature = "sqlite"))]
834 DatabaseType::SQLite(sl) => {
835 sl.create_source(name, description, url, date, releasable, malicious)
836 }
837 }
838 }
839
840 #[cfg(any(test, feature = "admin"))]
847 pub async fn edit_user(
848 &self,
849 uid: u32,
850 uname: &str,
851 fname: &str,
852 lname: &str,
853 email: &str,
854 readonly: bool,
855 ) -> Result<()> {
856 match self {
857 DatabaseType::Postgres(pg) => {
858 pg.edit_user(uid, uname, fname, lname, email, readonly)
859 .await
860 }
861 #[cfg(any(test, feature = "sqlite"))]
862 DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
863 }
864 }
865
866 #[cfg(any(test, feature = "admin"))]
873 pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
874 match self {
875 DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
876 #[cfg(any(test, feature = "sqlite"))]
877 DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
878 }
879 }
880
881 #[cfg(any(test, feature = "admin"))]
887 pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
888 match self {
889 DatabaseType::Postgres(pg) => pg.file_types_counts().await,
890 #[cfg(any(test, feature = "sqlite"))]
891 DatabaseType::SQLite(sl) => sl.file_types_counts(),
892 }
893 }
894
895 #[cfg(any(test, feature = "admin"))]
903 pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
904 match self {
905 DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
906 #[cfg(any(test, feature = "sqlite"))]
907 DatabaseType::SQLite(sl) => sl.create_label(name, parent),
908 }
909 }
910
911 #[cfg(any(test, feature = "admin"))]
919 pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
920 match self {
921 DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
922 #[cfg(any(test, feature = "sqlite"))]
923 DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
924 }
925 }
926
927 #[cfg(any(test, feature = "admin"))]
934 pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
935 match self {
936 DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
937 #[cfg(any(test, feature = "sqlite"))]
938 DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
939 }
940 }
941
942 #[cfg(any(test, feature = "admin"))]
950 pub async fn label_file(&self, file_id: u64, label_id: u64) -> Result<()> {
951 match self {
952 DatabaseType::Postgres(pg) => pg.label_file(file_id, label_id).await,
953 #[cfg(any(test, feature = "sqlite"))]
954 DatabaseType::SQLite(sl) => sl.label_file(file_id, label_id),
955 }
956 }
957}
958
959pub fn hash_password(password: &str) -> Result<String> {
965 let salt = SaltString::generate(&mut OsRng);
966 let argon2 = Argon2::default();
967 Ok(argon2
968 .hash_password(password.as_bytes(), &salt)?
969 .to_string())
970}
971
972#[must_use]
974pub fn random_bytes_api_key() -> String {
975 let key1 = uuid::Uuid::new_v4();
976 let key2 = uuid::Uuid::new_v4();
977 let key1 = key1.to_string().replace('-', "");
978 let key2 = key2.to_string().replace('-', "");
979 format!("{key1}{key2}")
980}
981
982#[cfg(test)]
983mod tests {
984 use super::*;
985 #[cfg(feature = "vt")]
986 use crate::vt::VtUpdater;
987
988 use std::fs;
989 #[cfg(feature = "vt")]
990 use std::sync::Arc;
991 #[cfg(feature = "vt")]
992 use std::time::SystemTime;
993
994 use anyhow::Context;
995 use fuzzyhash::FuzzyHash;
996 use malwaredb_api::{PartialHashSearchType, SearchRequestParameters, SearchType};
997 use malwaredb_lzjd::{LZDict, Murmur3HashState};
998 use tlsh_fixed::TlshBuilder;
999 use uuid::Uuid;
1000
// Label names used by the test suite. Their uses fall outside the part of
// this module visible here.
const MALWARE_LABEL: &str = "malware";
const RANSOMWARE_LABEL: &str = "ransomware";
1003
1004 fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
1005 let mut hashes = vec![];
1006
1007 hashes.push((
1008 malwaredb_api::SimilarityHashType::SSDeep,
1009 FuzzyHash::new(data).to_string(),
1010 ));
1011
1012 let mut builder = TlshBuilder::new(
1013 tlsh_fixed::BucketKind::Bucket256,
1014 tlsh_fixed::ChecksumKind::ThreeByte,
1015 tlsh_fixed::Version::Version4,
1016 );
1017
1018 builder.update(data);
1019
1020 if let Ok(hasher) = builder.build() {
1021 hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
1022 }
1023
1024 let build_hasher = Murmur3HashState::default();
1025 let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
1026 hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
1027
1028 malwaredb_api::SimilarSamplesRequest { hashes }
1029 }
1030
1031 async fn pg_config() -> Postgres {
1032 const CONNECTION_STRING: &str =
1035 "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
1036
1037 if let Ok(pg_port) = std::env::var("PG_PORT") {
1038 let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
1040 Postgres::new(&conn_string, None)
1041 .await
1042 .context(format!(
1043 "failed to connect to postgres with specified port {pg_port}"
1044 ))
1045 .unwrap()
1046 } else {
1047 Postgres::new(CONNECTION_STRING, None).await.unwrap()
1048 }
1049 }
1050
1051 #[tokio::test]
1052 #[ignore = "don't run this in CI"]
1053 async fn pg() {
1054 let psql = pg_config().await;
1055 psql.delete_init().await.unwrap();
1056
1057 let db = DatabaseType::Postgres(psql);
1058 let key = FileEncryption::from(EncryptionOption::Xor);
1059 db.add_file_encryption_key(&key).await.unwrap();
1060 assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
1061 everything(&db).await.unwrap();
1062
1063 #[cfg(feature = "vt")]
1064 {
1065 let db_config = db.get_config().await.unwrap();
1066 let state = crate::State {
1067 port: 8080,
1068 directory: None,
1069 max_upload: 10 * 1024 * 1024,
1070 ip: "127.0.0.1".parse().unwrap(),
1071 db_type: Arc::new(db),
1072 db_config,
1073 keys: HashMap::new(),
1074 started: SystemTime::now(),
1075 vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1076 Some(malwaredb_virustotal::VirusTotalClient::new(e))
1077 }),
1078 tls_config: None,
1079 mdns: None,
1080 };
1081
1082 let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1083
1084 vt.updater().await.unwrap();
1085 println!("PG: Did VT ops!");
1086
1087 let psql = pg_config().await;
1088
1089 let vt_stats = psql
1090 .get_vt_stats()
1091 .await
1092 .context("failed to get Postgres VT Stats")
1093 .unwrap();
1094 println!("{vt_stats:?}");
1095 assert!(
1096 vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1097 );
1098 }
1099
1100 let psql = pg_config().await;
1102 psql.delete_init().await.unwrap();
1103 }
1104
1105 #[tokio::test]
1106 async fn sqlite() {
1107 const DB_FILE: &str = "testing_sqlite.db";
1108 if std::path::Path::new(DB_FILE).exists() {
1109 fs::remove_file(DB_FILE)
1110 .context(format!("failed to delete old SQLite file {DB_FILE}"))
1111 .unwrap();
1112 }
1113
1114 let sqlite = Sqlite::new(DB_FILE)
1115 .context(format!("failed to create SQLite instance for {DB_FILE}"))
1116 .unwrap();
1117
1118 let db = DatabaseType::SQLite(sqlite);
1119 let key = FileEncryption::from(EncryptionOption::Xor);
1120 db.add_file_encryption_key(&key).await.unwrap();
1121 assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
1122 everything(&db).await.unwrap();
1123
1124 #[cfg(feature = "vt")]
1125 {
1126 let db_config = db.get_config().await.unwrap();
1127 let state = crate::State {
1128 port: 8080,
1129 directory: None,
1130 max_upload: 10 * 1024 * 1024,
1131 ip: "127.0.0.1".parse().unwrap(),
1132 db_type: Arc::new(db),
1133 db_config,
1134 keys: HashMap::new(),
1135 started: SystemTime::now(),
1136 vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1137 Some(malwaredb_virustotal::VirusTotalClient::new(e))
1138 }),
1139 tls_config: None,
1140 mdns: None,
1141 };
1142
1143 let sqlite_second = Sqlite::new(DB_FILE)
1144 .context(format!("failed to create SQLite instance for {DB_FILE}"))
1145 .unwrap();
1146
1147 let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1148
1149 vt.updater().await.unwrap();
1150 println!("Sqlite: Did VT ops!");
1151 let vt_stats = sqlite_second
1152 .get_vt_stats()
1153 .context("failed to get Sqlite VT Stats")
1154 .unwrap();
1155 println!("{vt_stats:?}");
1156 assert!(
1157 vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1158 );
1159 }
1160
1161 fs::remove_file(DB_FILE)
1162 .context(format!("failed to delete SQLite file {DB_FILE}"))
1163 .unwrap();
1164 }
1165
1166 #[allow(clippy::too_many_lines)]
1167 async fn everything(db: &DatabaseType) -> Result<()> {
1168 const ADMIN_UNAME: &str = "admin";
1169 const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
1170
1171 db.set_name("Testing Database")
1172 .await
1173 .context("setting instance name failed")?;
1174
1175 assert!(
1176 db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
1177 "Authentication without password should have failed."
1178 );
1179
1180 db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1181 .await
1182 .context("failed to set admin password")?;
1183
1184 let admin_api_key = db
1185 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1186 .await
1187 .context("unable to get api key for admin")?;
1188 println!("API key: {admin_api_key}");
1189 assert_eq!(admin_api_key.len(), 64);
1190
1191 assert_eq!(
1192 db.get_uid(&admin_api_key).await?,
1193 0,
1194 "Unable to get UID given the API key"
1195 );
1196
1197 let admin_api_key_again = db
1198 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1199 .await
1200 .context("unable to get api key a second time for admin")?;
1201
1202 assert_eq!(
1203 admin_api_key, admin_api_key_again,
1204 "API keys didn't match the second time."
1205 );
1206
1207 let bad_password = "this_is_totally_not_my_password!!";
1208 eprintln!("Testing API login with incorrect password.");
1209 assert!(
1210 db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
1211 "Authenticating as admin with a bad password should have failed."
1212 );
1213
1214 let admin_is_admin = db
1215 .user_is_admin(0)
1216 .await
1217 .context("unable to see if admin (uid 0) is an admin")?;
1218 assert!(admin_is_admin);
1219
1220 let new_user_uname = "testuser";
1221 let new_user_email = "test@example.com";
1222 let new_user_password = "some_awesome_password_++";
1223 let new_id = db
1224 .create_user(
1225 new_user_uname,
1226 new_user_uname,
1227 new_user_uname,
1228 new_user_email,
1229 Some(new_user_password.into()),
1230 None,
1231 false,
1232 )
1233 .await
1234 .context(format!("failed to create user {new_user_uname}"))?;
1235
1236 let passwordless_user_id = db
1237 .create_user(
1238 "passwordless_user",
1239 "passwordless_user",
1240 "passwordless_user",
1241 "passwordless_user@example.com",
1242 None,
1243 None,
1244 false,
1245 )
1246 .await
1247 .context("failed to create passwordless_user")?;
1248
1249 for user in &db.list_users().await.context("failed to list users")? {
1250 if user.id == passwordless_user_id {
1251 assert_eq!(user.uname, "passwordless_user");
1252 }
1253 }
1254
1255 db.edit_user(
1256 passwordless_user_id,
1257 "passwordless_user_2",
1258 "passwordless_user_2",
1259 "passwordless_user_2",
1260 "passwordless_user_2@something.com",
1261 false,
1262 )
1263 .await
1264 .context(format!(
1265 "failed to alter 'passwordless' user, id {passwordless_user_id}"
1266 ))?;
1267
1268 for user in &db.list_users().await.context("failed to list users")? {
1269 if user.id == passwordless_user_id {
1270 assert_eq!(user.uname, "passwordless_user_2");
1271 }
1272 }
1273
1274 assert!(
1275 new_id > 0,
1276 "Weird UID created for user {new_user_uname}: {new_id}"
1277 );
1278
1279 assert!(
1280 db.create_user(
1281 new_user_uname,
1282 new_user_uname,
1283 new_user_uname,
1284 new_user_email,
1285 Some(new_user_password.into()),
1286 None,
1287 false
1288 )
1289 .await
1290 .is_err(),
1291 "Creating a new user with the same user name should fail"
1292 );
1293
1294 let ro_user_name = "ro_user";
1295 let ro_user_password = "ro_user_password";
1296 db.create_user(
1297 ro_user_name,
1298 "ro_user",
1299 "ro_user",
1300 "ro@example.com",
1301 Some(ro_user_password.into()),
1302 None,
1303 true,
1304 )
1305 .await
1306 .context("failed to create read-only user")?;
1307
1308 let ro_user_api_key = db
1309 .authenticate(ro_user_name, ro_user_password)
1310 .await
1311 .context("unable to get api key for read-only user")?;
1312
1313 let new_user_password_change = "some_new_awesomer_password!_++";
1314 db.set_password(new_user_uname, new_user_password_change)
1315 .await
1316 .context("failed to change the password for testuser")?;
1317
1318 let new_user_api_key = db
1319 .authenticate(new_user_uname, new_user_password_change)
1320 .await
1321 .context("unable to get api key for testuser")?;
1322 eprintln!("{new_user_uname} got API key {new_user_api_key}");
1323
1324 assert_eq!(admin_api_key.len(), new_user_api_key.len());
1325
1326 let users = db.list_users().await.context("failed to list users")?;
1327 assert_eq!(
1328 users.len(),
1329 4,
1330 "Four users were created, yet there are {} users",
1331 users.len()
1332 );
1333 eprintln!("DB has {} users:", users.len());
1334 let mut passwordless_user_found = false;
1335 for user in users {
1336 println!("{user}");
1337 if user.uname == "passwordless_user_2" {
1338 assert!(!user.has_api_key);
1339 assert!(!user.has_password);
1340 passwordless_user_found = true;
1341 } else {
1342 assert!(user.has_api_key);
1343 assert!(user.has_password);
1344 }
1345 }
1346 assert!(passwordless_user_found);
1347
1348 let new_group_name = "some_new_group";
1349 let new_group_desc = "some_new_group_description";
1350 let new_group_id = 1;
1351 assert_eq!(
1352 db.create_group(new_group_name, new_group_desc, None)
1353 .await
1354 .context("failed to create group")?,
1355 new_group_id,
1356 "New group didn't have the expected ID, expected {new_group_id}"
1357 );
1358
1359 assert!(
1360 db.create_group(new_group_name, new_group_desc, None)
1361 .await
1362 .is_err(),
1363 "Duplicate group name should have failed"
1364 );
1365
1366 db.add_user_to_group(1, 1)
1367 .await
1368 .context("Unable to add uid 1 to gid 1")?;
1369
1370 let ro_user_uid = db
1371 .get_uid(&ro_user_api_key)
1372 .await
1373 .context("Unable to get UID for read-only user")?;
1374 db.add_user_to_group(ro_user_uid, 1)
1375 .await
1376 .context("Unable to add uid 2 to gid 1")?;
1377
1378 let new_admin_group_name = "admin_subgroup";
1379 let new_admin_group_desc = "admin_subgroup_description";
1380 let new_admin_group_id = 2;
1381 assert!(
1383 db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
1384 .await
1385 .context("failed to create admin sub-group")?
1386 >= new_admin_group_id,
1387 "New group didn't have the expected ID, expected >= {new_admin_group_id}"
1388 );
1389
1390 let groups = db.list_groups().await.context("failed to list groups")?;
1391 assert_eq!(
1392 groups.len(),
1393 3,
1394 "Three groups were created, yet there are {} groups",
1395 groups.len()
1396 );
1397 eprintln!("DB has {} groups:", groups.len());
1398 for group in groups {
1399 println!("{group}");
1400 if group.id == new_admin_group_id {
1401 assert_eq!(group.parent, Some("admin".to_string()));
1402 }
1403 if group.id == 1 {
1404 let test_user_str = String::from(new_user_uname);
1405 let mut found = false;
1406 for member in group.members {
1407 if member.uname == test_user_str {
1408 found = true;
1409 break;
1410 }
1411 }
1412 assert!(found, "new user {test_user_str} wasn't in the group");
1413 }
1414 }
1415
1416 let default_source_name = "default_source".to_string();
1417 let default_source_id = db
1418 .create_source(
1419 &default_source_name,
1420 Some("desc_default_source"),
1421 None,
1422 Local::now(),
1423 true,
1424 Some(false),
1425 )
1426 .await
1427 .context("failed to create source `default_source`")?;
1428
1429 db.add_group_to_source(1, default_source_id)
1430 .await
1431 .context("failed to add group 1 to source 1")?;
1432
1433 let another_source_name = "another_source".to_string();
1434 let another_source_id = db
1435 .create_source(
1436 &another_source_name,
1437 Some("yet another file source"),
1438 None,
1439 Local::now(),
1440 true,
1441 Some(false),
1442 )
1443 .await
1444 .context("failed to create source `another_source`")?;
1445
1446 let empty_source_name = "empty_source".to_string();
1447 db.create_source(
1448 &empty_source_name,
1449 Some("empty and unused file source"),
1450 None,
1451 Local::now(),
1452 true,
1453 Some(false),
1454 )
1455 .await
1456 .context("failed to create source `another_source`")?;
1457
1458 db.add_group_to_source(1, another_source_id)
1459 .await
1460 .context("failed to add group 1 to source 1")?;
1461
1462 let sources = db.list_sources().await.context("failed to list sources")?;
1463 eprintln!("DB has {} sources:", sources.len());
1464 for source in sources {
1465 println!("{source}");
1466 assert_eq!(source.files, 0);
1467 if source.id == default_source_id || source.id == another_source_id {
1468 assert_eq!(
1469 source.groups, 1,
1470 "default source {default_source_name} should have 1 group"
1471 );
1472 } else {
1473 assert_eq!(source.groups, 0, "groups should zero (empty)");
1474 }
1475 }
1476
1477 let uid = db
1478 .get_uid(&new_user_api_key)
1479 .await
1480 .context("failed to user uid from apikey")?;
1481 let user_info = db
1482 .get_user_info(uid)
1483 .await
1484 .context("failed to get user's available groups and sources")?;
1485 assert!(user_info.sources.contains(&default_source_name));
1486 assert!(!user_info.is_admin);
1487 println!("UserInfoResponse: {user_info:?}");
1488
1489 assert!(
1490 db.allowed_user_source(1, default_source_id)
1491 .await
1492 .context(format!(
1493 "failed to check that user 1 has access to source {default_source_id}"
1494 ))?,
1495 "User 1 should should have had access to source {default_source_id}"
1496 );
1497
1498 assert!(
1499 !db.allowed_user_source(1, 5)
1500 .await
1501 .context("failed to check that user 1 has access to source 5")?,
1502 "User 1 should should not have had access to source 5"
1503 );
1504
1505 let test_label_id = db
1506 .create_label("TestLabel", None)
1507 .await
1508 .context("failed to create test label")?;
1509 let test_elf_label_id = db
1510 .create_label("TestELF", Some(test_label_id))
1511 .await
1512 .context("failed to create test label")?;
1513
1514 let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
1515 let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
1516 let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
1517
1518 let known_type =
1519 KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
1520 assert!(known_type.is_exec(), "ELF should be executable");
1521 eprintln!("ELF type ID: {elf_type}");
1522
1523 let file_addition = db
1524 .add_file(
1525 &test_elf_meta,
1526 known_type.clone(),
1527 1,
1528 default_source_id,
1529 elf_type,
1530 None,
1531 )
1532 .await
1533 .context("failed to insert a test elf")?;
1534 assert!(file_addition.is_new, "File should have been added");
1535 eprintln!("Added ELF to the DB");
1536 db.label_file(file_addition.file_id, test_elf_label_id)
1537 .await
1538 .context("failed to label file")?;
1539
1540 let partial_search = SearchRequest {
1542 search: SearchType::Search(SearchRequestParameters {
1543 partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1544 labels: Some(vec![String::from("TestELF")]),
1545 file_type: Some(String::from("ELF")),
1546 magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1547 ..Default::default()
1548 }),
1549 };
1550 assert!(partial_search.is_valid());
1551 let partial_search_response = db.partial_search(1, partial_search).await?;
1552 assert_eq!(partial_search_response.hashes.len(), 1);
1553 assert_eq!(
1554 partial_search_response.hashes[0],
1555 "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1556 );
1557
1558 let partial_search = SearchRequest {
1560 search: SearchType::Search(SearchRequestParameters {
1561 partial_hash: None,
1562 labels: None,
1563 file_type: None,
1564 magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1565 ..Default::default()
1566 }),
1567 };
1568 assert!(partial_search.is_valid());
1569 let partial_search_response = db.partial_search(1, partial_search).await?;
1570 assert_eq!(partial_search_response.hashes.len(), 1);
1571 assert_eq!(
1572 partial_search_response.hashes[0],
1573 "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1574 );
1575
1576 let partial_search = SearchRequest {
1578 search: SearchType::Search(SearchRequestParameters {
1579 partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1580 file_type: Some(String::from("PE32")),
1581 ..Default::default()
1582 }),
1583 };
1584 assert!(partial_search.is_valid());
1585 let partial_search_response = db.partial_search(1, partial_search).await?;
1586 assert_eq!(partial_search_response.hashes.len(), 0);
1587
1588 let partial_search = SearchRequest {
1589 search: SearchType::Search(SearchRequestParameters {
1590 file_name: Some("ppc64".into()),
1591 ..Default::default()
1592 }),
1593 };
1594 assert!(partial_search.is_valid());
1595 let partial_search_response = db.partial_search(1, partial_search).await?;
1596 assert_eq!(partial_search_response.hashes.len(), 1);
1597
1598 let partial_search = SearchRequest {
1600 search: SearchType::Search(SearchRequestParameters::default()),
1601 };
1602 assert!(!partial_search.is_valid());
1603 let partial_search_response = db.partial_search(1, partial_search).await?;
1604 assert!(partial_search_response.hashes.is_empty());
1605
1606 let partial_search = SearchRequest {
1608 search: SearchType::Continuation(Uuid::default()),
1609 };
1610 assert!(partial_search.is_valid());
1611 let partial_search_response = db.partial_search(1, partial_search).await?;
1612 assert!(partial_search_response.hashes.is_empty());
1613
1614 assert!(db
1616 .get_type_id_for_bytes(include_bytes!("../../../../MDB_Logo.ico"))
1617 .await
1618 .is_err());
1619
1620 assert!(
1621 db.add_file(
1622 &test_elf_meta,
1623 known_type.clone(),
1624 ro_user_uid,
1625 default_source_id,
1626 elf_type,
1627 None
1628 )
1629 .await
1630 .is_err(),
1631 "Read-only user should not be able to add a file"
1632 );
1633
1634 let mut test_elf_meta_different_name = test_elf_meta.clone();
1635 test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
1636
1637 assert!(
1638 !db.add_file(
1639 &test_elf_meta_different_name,
1640 known_type,
1641 1,
1642 another_source_id,
1643 elf_type,
1644 None
1645 )
1646 .await
1647 .context("failed to insert a test elf again for a different source")?
1648 .is_new
1649 );
1650
1651 let sources = db
1652 .list_sources()
1653 .await
1654 .context("failed to re-list sources")?;
1655 eprintln!(
1656 "DB has {} sources, and a file was added twice:",
1657 sources.len()
1658 );
1659 println!("We should have two sources with one file each, yet only one ELF.");
1660 for source in sources {
1661 println!("{source}");
1662 if source.id == default_source_id || source.id == another_source_id {
1663 assert_eq!(source.files, 1);
1664 } else {
1665 assert_eq!(source.files, 0, "groups should zero (empty)");
1666 }
1667 }
1668
1669 assert!(!db
1670 .get_user_sources(1)
1671 .await
1672 .expect("failed to get user 1's sources")
1673 .sources
1674 .is_empty());
1675
1676 let file_types_counts = db
1677 .file_types_counts()
1678 .await
1679 .context("failed to get file types and counts")?;
1680 for (name, count) in file_types_counts {
1681 println!("{name}: {count}");
1682 assert_eq!(name, "ELF");
1683 assert_eq!(count, 1);
1684 }
1685
1686 let mut test_elf_modified = test_elf.clone();
1687 let random_bytes = Uuid::new_v4();
1688 let mut random_bytes = random_bytes.into_bytes().to_vec();
1689 test_elf_modified.append(&mut random_bytes);
1690 let similarity_request = generate_similarity_request(&test_elf_modified);
1691 let similarity_response = db
1692 .find_similar_samples(1, &similarity_request.hashes)
1693 .await
1694 .context("failed to get similarity response")?;
1695 eprintln!("Similarity response: {similarity_response:?}");
1696 let similarity_response = similarity_response.first().unwrap();
1697 assert_eq!(
1698 similarity_response.sha256, test_elf_meta.sha256,
1699 "Similarity response should have had the hash of the original ELF"
1700 );
1701 for (algo, sim) in &similarity_response.algorithms {
1702 match algo {
1703 malwaredb_api::SimilarityHashType::LZJD => {
1704 assert!(*sim > 0.0f32);
1705 }
1706 malwaredb_api::SimilarityHashType::SSDeep => {
1707 assert!(*sim > 80.0f32);
1708 }
1709 malwaredb_api::SimilarityHashType::TLSH => {
1710 assert!(*sim <= 20f32);
1711 }
1712 _ => {}
1713 }
1714 }
1715
1716 let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1.as_str())
1717 .context("failed to get `HashType::SHA1` from string")?;
1718 let response_sha256 = db
1719 .retrieve_sample(1, &test_elf_hashtype)
1720 .await
1721 .context("could not get SHA-256 hash from test sample")
1722 .unwrap();
1723 assert_eq!(response_sha256, test_elf_meta.sha256);
1724
1725 let test_bogus_hash =
1726 HashType::try_from("d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0")
1727 .context("failed to get `HashType` from static string")?;
1728 assert!(
1729 db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
1730 "Getting a file with a bogus hash should have failed."
1731 );
1732
1733 let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
1734 let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
1735 let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
1736
1737 let known_type =
1738 KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
1739
1740 assert!(
1741 db.add_file(
1742 &test_pdf_meta,
1743 known_type,
1744 1,
1745 default_source_id,
1746 pdf_type,
1747 None
1748 )
1749 .await
1750 .context("failed to insert a test pdf")?
1751 .is_new
1752 );
1753 eprintln!("Added PDF to the DB");
1754
1755 let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
1756 let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
1757 let rtf_type = db
1758 .get_type_id_for_bytes(&test_rtf)
1759 .await
1760 .context("failed to get file type id for rtf")?;
1761
1762 let known_type =
1763 KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
1764
1765 assert!(
1766 db.add_file(
1767 &test_rtf_meta,
1768 known_type,
1769 1,
1770 default_source_id,
1771 rtf_type,
1772 None
1773 )
1774 .await
1775 .context("failed to insert a test rtf")?
1776 .is_new
1777 );
1778 eprintln!("Added RTF to the DB");
1779
1780 let report = db
1781 .get_sample_report(1, &HashType::try_from(test_rtf_meta.sha256.as_str())?)
1782 .await
1783 .context("failed to get report for test rtf")?;
1784 assert!(report
1785 .clone()
1786 .filecommand
1787 .unwrap()
1788 .contains("Rich Text Format"));
1789 println!("Report: {report}");
1790
1791 assert!(db
1792 .get_sample_report(999, &HashType::try_from(test_rtf_meta.sha256.as_str())?)
1793 .await
1794 .is_err());
1795
1796 #[cfg(feature = "vt")]
1797 {
1798 assert!(report.vt.is_some());
1799 let files_needing_vt = db
1800 .files_without_vt_records(10)
1801 .await
1802 .context("failed to get files without VT records")?;
1803 assert!(files_needing_vt.len() > 2);
1804 println!(
1805 "{} files needing VT data: {files_needing_vt:?}",
1806 files_needing_vt.len()
1807 );
1808 }
1809
1810 #[cfg(not(feature = "vt"))]
1811 {
1812 assert!(report.vt.is_none());
1813 }
1814
1815 let reset = db
1816 .reset_api_keys()
1817 .await
1818 .context("failed to reset all API keys")?;
1819 eprintln!("Cleared {reset} api keys.");
1820
1821 let db_info = db.db_info().await.context("failed to get database info")?;
1822 eprintln!("DB Info: {db_info:?}");
1823
1824 let data_types = db
1825 .get_known_data_types()
1826 .await
1827 .context("failed to get data types")?;
1828 for data_type in data_types {
1829 println!("{data_type:?}");
1830 }
1831
1832 let sources = db
1833 .list_sources()
1834 .await
1835 .context("failed to list sources second time")?;
1836 eprintln!("DB has {} sources:", sources.len());
1837 for source in sources {
1838 println!("{source}");
1839 }
1840
1841 let file_types_counts = db
1842 .file_types_counts()
1843 .await
1844 .context("failed to get file types and counts")?;
1845 for (name, count) in file_types_counts {
1846 println!("{name}: {count}");
1847 assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
1848 }
1849
1850 let fatmacho =
1851 include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
1852 .to_vec();
1853 let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
1854 let fatmacho_type = db
1855 .get_type_id_for_bytes(&fatmacho)
1856 .await
1857 .context("failed to get file type for Fat Mach-O")?;
1858 let known_type = KnownType::new(&fatmacho)
1859 .context("failed to parse Fat Mach-O from type crate's test data")?;
1860
1861 assert!(
1862 db.add_file(
1863 &fatmacho_meta,
1864 known_type,
1865 1,
1866 default_source_id,
1867 fatmacho_type,
1868 None
1869 )
1870 .await
1871 .context("failed to insert a test Fat Mach-O")?
1872 .is_new
1873 );
1874 eprintln!("Added Fat Mach-O to the DB");
1875
1876 let file_types_counts = db
1877 .file_types_counts()
1878 .await
1879 .context("failed to get file types and counts")?;
1880 for (name, count) in &file_types_counts {
1881 println!("{name}: {count}");
1882 }
1883
1884 assert_eq!(
1885 *file_types_counts.get("Mach-O").unwrap(),
1886 4,
1887 "Expected 4 Mach-O files, got {:?}",
1888 file_types_counts.get("Mach-O")
1889 );
1890
1891 let malware_label_id = db
1892 .create_label(MALWARE_LABEL, None)
1893 .await
1894 .context("failed to create first label")?;
1895 let ransomware_label_id = db
1896 .create_label(RANSOMWARE_LABEL, Some(malware_label_id))
1897 .await
1898 .context("failed to create malware sub-label")?;
1899 let labels = db.get_labels().await.context("failed to get labels")?;
1900
1901 assert_eq!(labels.len(), 4, "Expected 4 labels, got {labels}");
1902 for label in labels.0 {
1903 if label.name == RANSOMWARE_LABEL {
1904 assert_eq!(label.id, ransomware_label_id);
1905 assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
1906 }
1907 }
1908
1909 let source_code = include_bytes!("mod.rs");
1911 let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
1912 let known_type =
1913 KnownType::new(source_code).context("failed to source code to get `Unknown` type")?;
1914
1915 assert!(matches!(known_type, KnownType::Unknown(_)));
1916
1917 let unknown_type: Vec<FileType> = db
1918 .get_known_data_types()
1919 .await?
1920 .into_iter()
1921 .filter(|t| t.name.eq_ignore_ascii_case("unknown"))
1922 .collect();
1923 let unknown_type_id = unknown_type.first().unwrap().id;
1924 assert!(db.get_type_id_for_bytes(source_code).await.is_err());
1925 db.enable_keep_unknown_files()
1926 .await
1927 .context("failed to enable keeping of unknown files")?;
1928 let source_type = db
1929 .get_type_id_for_bytes(source_code)
1930 .await
1931 .context("failed to type id for source code unknown type example")?;
1932 assert_eq!(source_type, unknown_type_id);
1933 eprintln!("Unknown file type ID: {source_type}");
1934 assert!(
1935 db.add_file(
1936 &source_meta,
1937 known_type,
1938 1,
1939 default_source_id,
1940 unknown_type_id,
1941 None
1942 )
1943 .await
1944 .context("failed to add Rust source code file")?
1945 .is_new
1946 );
1947 eprintln!("Added Rust source code to the DB");
1948
1949 db.reset_own_api_key(0)
1950 .await
1951 .context("failed to clear own API key uid 0")?;
1952
1953 db.deactivate_user(0)
1954 .await
1955 .context("failed to clear password and API key for uid 0")?;
1956
1957 Ok(())
1958 }
1959}