1#[cfg(any(test, feature = "admin"))]
10mod admin;
11mod pg;
13
14#[cfg(any(test, feature = "sqlite"))]
16mod sqlite;
17
18#[cfg(any(test, feature = "sqlite"))]
20mod sqlite_functions;
21
22pub mod types;
24
25#[cfg(any(test, feature = "admin"))]
26use crate::crypto::EncryptionOption;
27use crate::crypto::FileEncryption;
28use crate::db::pg::Postgres;
29#[cfg(any(test, feature = "sqlite"))]
30use crate::db::sqlite::Sqlite;
31use crate::db::types::{FileMetadata, FileType};
32use malwaredb_api::{
33 digest::HashType, GetUserInfoResponse, Labels, SearchRequest, SearchResponse, Sources,
34};
35use malwaredb_types::KnownType;
36
37use std::collections::HashMap;
38use std::path::PathBuf;
39
40use anyhow::{bail, ensure, Result};
41use argon2::password_hash::{rand_core::OsRng, SaltString};
42use argon2::{Argon2, PasswordHasher};
43#[cfg(any(test, feature = "admin"))]
44use chrono::Local;
45#[cfg(feature = "vt")]
46use malwaredb_virustotal::filereport::ScanResultAttributes;
47
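/// Maximum number of results returned for a single partial-search request.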
48pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
50
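/// What to do about schema migrations when opening a database: `Check` verifies
/// the schema is current, while `Migrate` (tests and admin builds) applies any
/// pending migrations.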
51#[derive(Copy, Clone)]
53pub enum Migration {
54 Check,
56
57 #[cfg(any(test, feature = "admin"))]
59 Migrate,
60}
61
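/// Handle to the active database backend: Postgres, or SQLite when built with
/// the `sqlite` feature (or under test). Every operation dispatches to the
/// matching backend implementation.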
62#[derive(Debug)]
64pub enum DatabaseType {
65 Postgres(Postgres),
67
68 #[cfg(any(test, feature = "sqlite"))]
70 SQLite(Sqlite),
71}
72
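/// Summary of the database as reported by [`DatabaseType::db_info`]: server
/// version, on-disk size, and counts of files, users, groups, and sources.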
73#[derive(Debug)]
75pub struct DatabaseInformation {
76 pub version: String,
78
79 pub size: String,
81
82 pub num_files: u64,
84
85 pub num_users: u32,
87
88 pub num_groups: u32,
90
91 pub num_sources: u32,
93}
94
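/// Result of adding a file: the file's id and whether the sample was new to
/// the database.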
95pub struct FileAddedResult {
97 pub file_id: u64,
99
100 pub is_new: bool,
103}
104
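/// Server configuration stored in the database: instance name, whether stored
/// samples are compressed, whether samples are forwarded to VirusTotal, whether
/// files of unknown type are kept, and the default encryption key id (if any).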
105#[derive(Debug)]
107pub struct MDBConfig {
108 pub name: String,
110
111 pub compression: bool,
113
114 pub send_samples_to_vt: bool,
116
117 pub keep_unknown_files: bool,
119
120 pub(crate) default_key: Option<u32>,
122}
123
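/// VirusTotal coverage counts: samples reported clean, samples with detections,
/// and samples that have no VirusTotal record yet.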
124#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
126#[cfg(feature = "vt")]
127#[derive(Debug, Clone, Copy)]
128pub struct VtStats {
129 pub clean_records: u32,
131
132 pub hits_records: u32,
134
135 pub files_without_records: u32,
137}
138
139impl DatabaseType {
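    /// Open a database from a connection string: a `file:` prefix selects
    /// SQLite (when the `sqlite` feature or tests are enabled), and a string
    /// starting with `postgres` selects Postgres. A schema migration check
    /// runs before the handle is returned.
    ///
    /// A minimal sketch (assuming an SQLite-enabled build):
    /// ```ignore
    /// let db = DatabaseType::from_string("file:malwaredb.sqlite", None).await?;
    /// ```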
140 pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
149 #[cfg(any(test, feature = "sqlite"))]
150 if arg.starts_with("file:") {
151 let new_conn_str = arg.trim_start_matches("file:");
152 let db = DatabaseType::SQLite(Sqlite::new(new_conn_str)?);
153 db.migrate(Migration::Check).await?;
154 return Ok(db);
155 }
156
157 if arg.starts_with("postgres") {
158 let new_conn_str = arg.trim_start_matches("postgres");
159 let db = DatabaseType::Postgres(Postgres::new(new_conn_str, server_ca).await?);
160 db.migrate(Migration::Check).await?;
161 return Ok(db);
162 }
163
164 bail!("unknown database type `{arg}`")
165 }
166
167 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
173 #[cfg(feature = "vt")]
174 pub async fn enable_vt_upload(&self) -> Result<()> {
175 match self {
176 DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
177 #[cfg(any(test, feature = "sqlite"))]
178 DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
179 }
180 }
181
182 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
188 #[cfg(feature = "vt")]
189 pub async fn disable_vt_upload(&self) -> Result<()> {
190 match self {
191 DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
192 #[cfg(any(test, feature = "sqlite"))]
193 DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
194 }
195 }
196
197 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
203 #[cfg(feature = "vt")]
204 pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
205 match self {
206 DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
207 #[cfg(any(test, feature = "sqlite"))]
208 DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
209 }
210 }
211
212 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
218 #[cfg(feature = "vt")]
219 pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
220 match self {
221 DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
222 #[cfg(any(test, feature = "sqlite"))]
223 DatabaseType::SQLite(sl) => sl.store_vt_record(results),
224 }
225 }
226
227 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
233 #[cfg(feature = "vt")]
234 pub async fn get_vt_stats(&self) -> Result<VtStats> {
235 match self {
236 DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
237 #[cfg(any(test, feature = "sqlite"))]
238 DatabaseType::SQLite(sl) => sl.get_vt_stats(),
239 }
240 }
241
242 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
248 #[cfg(feature = "yara")]
249 pub async fn add_yara_search(
250 &self,
251 uid: u32,
252 yara_string: &str,
253 yara_bytes: &[u8],
254 ) -> Result<uuid::Uuid> {
255 match self {
256 DatabaseType::Postgres(pg) => pg.add_yara_search(uid, yara_string, yara_bytes).await,
257 #[cfg(any(test, feature = "sqlite"))]
258 DatabaseType::SQLite(sl) => sl.add_yara_search(uid, yara_string, yara_bytes),
259 }
260 }
261
262 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
268 #[cfg(feature = "yara")]
269 pub async fn get_unfinished_yara_tasks(&self) -> Result<Vec<crate::yara::YaraTask>> {
270 match self {
271 DatabaseType::Postgres(pg) => pg.get_unfinished_yara_tasks().await,
272 #[cfg(any(test, feature = "sqlite"))]
273 DatabaseType::SQLite(sl) => sl.get_unfinished_yara_tasks(),
274 }
275 }
276
277 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
283 #[cfg(feature = "yara")]
284 pub async fn add_yara_match(
285 &self,
286 id: uuid::Uuid,
287 rule_name: &str,
288 file_sha256: &str,
289 ) -> Result<()> {
290 match self {
291 DatabaseType::Postgres(pg) => pg.add_yara_match(id, rule_name, file_sha256).await,
292 #[cfg(any(test, feature = "sqlite"))]
293 DatabaseType::SQLite(sl) => sl.add_yara_match(id, rule_name, file_sha256),
294 }
295 }
296
297 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
303 #[cfg(feature = "yara")]
304 pub async fn mark_yara_task_as_finished(&self, id: uuid::Uuid) -> Result<()> {
305 match self {
306 DatabaseType::Postgres(pg) => pg.mark_yara_task_as_finished(id).await,
307 #[cfg(any(test, feature = "sqlite"))]
308 DatabaseType::SQLite(sl) => sl.mark_yara_task_as_finished(id),
309 }
310 }
311
312 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
318 #[cfg(feature = "yara")]
319 pub async fn yara_add_next_file_id(&self, id: uuid::Uuid, file_id: u64) -> Result<()> {
320 match self {
321 DatabaseType::Postgres(pg) => pg.yara_add_next_file_id(id, file_id).await,
322 #[cfg(any(test, feature = "sqlite"))]
323 DatabaseType::SQLite(sl) => sl.yara_add_next_file_id(id, file_id),
324 }
325 }
326
327 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
333 #[cfg(feature = "yara")]
334 pub async fn get_yara_results(
335 &self,
336 id: uuid::Uuid,
337 user_id: u32,
338 ) -> Result<malwaredb_api::YaraSearchResponse> {
339 match self {
340 DatabaseType::Postgres(pg) => pg.get_yara_results(id, user_id).await,
341 #[cfg(any(test, feature = "sqlite"))]
342 DatabaseType::SQLite(sl) => sl.get_yara_results(id, user_id),
343 }
344 }
345
346 pub async fn get_config(&self) -> Result<MDBConfig> {
352 match self {
353 DatabaseType::Postgres(pg) => pg.get_config().await,
354 #[cfg(any(test, feature = "sqlite"))]
355 DatabaseType::SQLite(sl) => sl.get_config(),
356 }
357 }
358
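    /// Verify a username and password and return the user's API key; repeated
    /// logins return the same key until it is reset.
    ///
    /// A minimal sketch (assuming the user already exists):
    /// ```ignore
    /// let api_key = db.authenticate("admin", "hunter2").await?;
    /// ```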
359 pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
366 match self {
367 DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
368 #[cfg(any(test, feature = "sqlite"))]
369 DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
370 }
371 }
372
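    /// Look up the user id that owns an API key; an empty or unrecognized key
    /// is an error.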
373 pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
380 ensure!(!apikey.is_empty(), "API key was empty");
381 match self {
382 DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
383 #[cfg(any(test, feature = "sqlite"))]
384 DatabaseType::SQLite(sl) => sl.get_uid(apikey),
385 }
386 }
387
388 pub async fn db_info(&self) -> Result<DatabaseInformation> {
394 match self {
395 DatabaseType::Postgres(pg) => pg.db_info().await,
396 #[cfg(any(test, feature = "sqlite"))]
397 DatabaseType::SQLite(sl) => sl.db_info(),
398 }
399 }
400
401 pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
408 match self {
409 DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
410 #[cfg(any(test, feature = "sqlite"))]
411 DatabaseType::SQLite(sl) => sl.get_user_info(uid),
412 }
413 }
414
415 pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
422 match self {
423 DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
424 #[cfg(any(test, feature = "sqlite"))]
425 DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
426 }
427 }
428
429 pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
436 match self {
437 DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
438 #[cfg(any(test, feature = "sqlite"))]
439 DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
440 }
441 }
442
443 pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
449 match self {
450 DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
451 #[cfg(any(test, feature = "sqlite"))]
452 DatabaseType::SQLite(sl) => sl.get_known_data_types(),
453 }
454 }
455
456 pub async fn get_labels(&self) -> Result<Labels> {
462 match self {
463 DatabaseType::Postgres(pg) => pg.get_labels().await,
464 #[cfg(any(test, feature = "sqlite"))]
465 DatabaseType::SQLite(sl) => sl.get_labels(),
466 }
467 }
468
469 pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
475 match self {
476 DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
477 #[cfg(any(test, feature = "sqlite"))]
478 DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
479 }
480 }
481
482 pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
489 match self {
490 DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
491 #[cfg(any(test, feature = "sqlite"))]
492 DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
493 }
494 }
495
496 pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
504 match self {
505 DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
506 #[cfg(any(test, feature = "sqlite"))]
507 DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
508 }
509 }
510
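    /// Store a file and its metadata for the given user, source, and file type,
    /// returning the file id and whether the sample was new. Re-adding a known
    /// sample under another source records the association without storing a
    /// duplicate.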
511 pub async fn add_file(
518 &self,
519 meta: &FileMetadata,
520 known_type: KnownType<'_>,
521 uid: u32,
522 sid: u32,
523 ftype: u32,
524 parent: Option<u64>,
525 ) -> Result<FileAddedResult> {
526 match self {
527 DatabaseType::Postgres(pg) => {
528 pg.add_file(meta, known_type, uid, sid, ftype, parent).await
529 }
530 #[cfg(any(test, feature = "sqlite"))]
531 DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
532 }
533 }
534
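    /// Search by partial hash, label, file type, file name, or libmagic output,
    /// or continue a previous paged search by its UUID. Results are returned a
    /// page at a time (see [`PARTIAL_SEARCH_LIMIT`]).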
535 pub async fn partial_search(&self, uid: u32, search: SearchRequest) -> Result<SearchResponse> {
541 match self {
542 DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
543 #[cfg(any(test, feature = "sqlite"))]
544 DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
545 }
546 }
547
548 pub async fn cleanup(&self) -> Result<u64> {
554 match self {
555 DatabaseType::Postgres(pg) => pg.cleanup().await,
556 #[cfg(any(test, feature = "sqlite"))]
557 DatabaseType::SQLite(sl) => sl.cleanup(),
558 }
559 }
560
561 pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
569 match self {
570 DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
571 #[cfg(any(test, feature = "sqlite"))]
572 DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
573 }
574 }
575
576 pub async fn get_sample_report(
583 &self,
584 uid: u32,
585 hash: &HashType,
586 ) -> Result<malwaredb_api::Report> {
587 match self {
588 DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
589 #[cfg(any(test, feature = "sqlite"))]
590 DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
591 }
592 }
593
594 pub async fn find_similar_samples(
600 &self,
601 uid: u32,
602 sim: &[(malwaredb_api::SimilarityHashType, String)],
603 ) -> Result<Vec<malwaredb_api::SimilarSample>> {
604 match self {
605 DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
606 #[cfg(any(test, feature = "sqlite"))]
607 DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
608 }
609 }
610
611 pub async fn user_allowed_files_by_sha256(
618 &self,
619 uid: u32,
620 next: Option<u64>,
621 ) -> Result<(Vec<String>, u64)> {
622 match self {
623 DatabaseType::Postgres(pg) => pg.user_allowed_files_by_sha256(uid, next).await,
624 #[cfg(any(test, feature = "sqlite"))]
625 DatabaseType::SQLite(sl) => sl.user_allowed_files_by_sha256(uid, next),
626 }
627 }
628
629 pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
633 match self {
634 DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
635 #[cfg(any(test, feature = "sqlite"))]
636 DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
637 }
638 }
639
640 pub(crate) async fn get_file_encryption_key_id(
642 &self,
643 hash: &str,
644 ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
645 match self {
646 DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
647 #[cfg(any(test, feature = "sqlite"))]
648 DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
649 }
650 }
651
652 pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
654 match self {
655 DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
656 #[cfg(any(test, feature = "sqlite"))]
657 DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
658 }
659 }
660
661 pub async fn migrate(&self, action: Migration) -> Result<()> {
668 match self {
669 DatabaseType::Postgres(pg) => pg.migrate(action).await,
670 #[cfg(any(test, feature = "sqlite"))]
671 DatabaseType::SQLite(sl) => sl.migrate(action),
672 }
673 }
674
675 #[cfg(any(test, feature = "admin"))]
683 pub async fn set_name(&self, name: &str) -> Result<()> {
684 match self {
685 DatabaseType::Postgres(pg) => pg.set_name(name).await,
686 #[cfg(any(test, feature = "sqlite"))]
687 DatabaseType::SQLite(sl) => sl.set_name(name),
688 }
689 }
690
691 #[cfg(any(test, feature = "admin"))]
697 pub async fn enable_compression(&self) -> Result<()> {
698 match self {
699 DatabaseType::Postgres(pg) => pg.enable_compression().await,
700 #[cfg(any(test, feature = "sqlite"))]
701 DatabaseType::SQLite(sl) => sl.enable_compression(),
702 }
703 }
704
705 #[cfg(any(test, feature = "admin"))]
711 pub async fn disable_compression(&self) -> Result<()> {
712 match self {
713 DatabaseType::Postgres(pg) => pg.disable_compression().await,
714 #[cfg(any(test, feature = "sqlite"))]
715 DatabaseType::SQLite(sl) => sl.disable_compression(),
716 }
717 }
718
719 #[cfg(any(test, feature = "admin"))]
725 pub async fn enable_keep_unknown_files(&self) -> Result<()> {
726 match self {
727 DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
728 #[cfg(any(test, feature = "sqlite"))]
729 DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
730 }
731 }
732
733 #[cfg(any(test, feature = "admin"))]
739 pub async fn disable_keep_unknown_files(&self) -> Result<()> {
740 match self {
741 DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
742 #[cfg(any(test, feature = "sqlite"))]
743 DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
744 }
745 }
746
747 #[cfg(any(test, feature = "admin"))]
753 pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
754 match self {
755 DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
756 #[cfg(any(test, feature = "sqlite"))]
757 DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
758 }
759 }
760
761 #[cfg(any(test, feature = "admin"))]
767 pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
768 match self {
769 DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
770 #[cfg(any(test, feature = "sqlite"))]
771 DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
772 }
773 }
774
775 #[allow(clippy::too_many_arguments)]
781 #[cfg(any(test, feature = "admin"))]
782 pub async fn create_user(
783 &self,
784 uname: &str,
785 fname: &str,
786 lname: &str,
787 email: &str,
788 password: Option<String>,
789 organisation: Option<&String>,
790 readonly: bool,
791 ) -> Result<u32> {
792 match self {
793 DatabaseType::Postgres(pg) => {
794 pg.create_user(uname, fname, lname, email, password, organisation, readonly)
795 .await
796 }
797 #[cfg(any(test, feature = "sqlite"))]
798 DatabaseType::SQLite(sl) => {
799 sl.create_user(uname, fname, lname, email, password, organisation, readonly)
800 }
801 }
802 }
803
804 #[cfg(any(test, feature = "admin"))]
810 pub async fn reset_api_keys(&self) -> Result<u64> {
811 match self {
812 DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
813 #[cfg(any(test, feature = "sqlite"))]
814 DatabaseType::SQLite(sl) => sl.reset_api_keys(),
815 }
816 }
817
818 #[cfg(any(test, feature = "admin"))]
825 pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
826 match self {
827 DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
828 #[cfg(any(test, feature = "sqlite"))]
829 DatabaseType::SQLite(sl) => sl.set_password(uname, password),
830 }
831 }
832
833 #[cfg(any(test, feature = "admin"))]
839 pub async fn list_users(&self) -> Result<Vec<admin::User>> {
840 match self {
841 DatabaseType::Postgres(pg) => pg.list_users().await,
842 #[cfg(any(test, feature = "sqlite"))]
843 DatabaseType::SQLite(sl) => sl.list_users(),
844 }
845 }
846
847 #[cfg(any(test, feature = "admin"))]
854 pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
855 match self {
856 DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
857 #[cfg(any(test, feature = "sqlite"))]
858 DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
859 }
860 }
861
862 #[cfg(any(test, feature = "admin"))]
869 pub async fn edit_group(
870 &self,
871 gid: u32,
872 name: &str,
873 desc: &str,
874 parent: Option<u32>,
875 ) -> Result<()> {
876 match self {
877 DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
878 #[cfg(any(test, feature = "sqlite"))]
879 DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
880 }
881 }
882
883 #[cfg(any(test, feature = "admin"))]
889 pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
890 match self {
891 DatabaseType::Postgres(pg) => pg.list_groups().await,
892 #[cfg(any(test, feature = "sqlite"))]
893 DatabaseType::SQLite(sl) => sl.list_groups(),
894 }
895 }
896
897 #[cfg(any(test, feature = "admin"))]
904 pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
905 match self {
906 DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
907 #[cfg(any(test, feature = "sqlite"))]
908 DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
909 }
910 }
911
912 #[cfg(any(test, feature = "admin"))]
919 pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
920 match self {
921 DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
922 #[cfg(any(test, feature = "sqlite"))]
923 DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
924 }
925 }
926
927 #[cfg(any(test, feature = "admin"))]
934 pub async fn create_group(
935 &self,
936 name: &str,
937 description: &str,
938 parent: Option<u32>,
939 ) -> Result<u32> {
940 match self {
941 DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
942 #[cfg(any(test, feature = "sqlite"))]
943 DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
944 }
945 }
946
947 #[cfg(any(test, feature = "admin"))]
953 pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
954 match self {
955 DatabaseType::Postgres(pg) => pg.list_sources().await,
956 #[cfg(any(test, feature = "sqlite"))]
957 DatabaseType::SQLite(sl) => sl.list_sources(),
958 }
959 }
960
961 #[cfg(any(test, feature = "admin"))]
968 pub async fn create_source(
969 &self,
970 name: &str,
971 description: Option<&str>,
972 url: Option<&str>,
973 date: chrono::DateTime<Local>,
974 releasable: bool,
975 malicious: Option<bool>,
976 ) -> Result<u32> {
977 match self {
978 DatabaseType::Postgres(pg) => {
979 pg.create_source(name, description, url, date, releasable, malicious)
980 .await
981 }
982 #[cfg(any(test, feature = "sqlite"))]
983 DatabaseType::SQLite(sl) => {
984 sl.create_source(name, description, url, date, releasable, malicious)
985 }
986 }
987 }
988
989 #[cfg(any(test, feature = "admin"))]
996 pub async fn edit_user(
997 &self,
998 uid: u32,
999 uname: &str,
1000 fname: &str,
1001 lname: &str,
1002 email: &str,
1003 readonly: bool,
1004 ) -> Result<()> {
1005 match self {
1006 DatabaseType::Postgres(pg) => {
1007 pg.edit_user(uid, uname, fname, lname, email, readonly)
1008 .await
1009 }
1010 #[cfg(any(test, feature = "sqlite"))]
1011 DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
1012 }
1013 }
1014
1015 #[cfg(any(test, feature = "admin"))]
1022 pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
1023 match self {
1024 DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
1025 #[cfg(any(test, feature = "sqlite"))]
1026 DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
1027 }
1028 }
1029
1030 #[cfg(any(test, feature = "admin"))]
1036 pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
1037 match self {
1038 DatabaseType::Postgres(pg) => pg.file_types_counts().await,
1039 #[cfg(any(test, feature = "sqlite"))]
1040 DatabaseType::SQLite(sl) => sl.file_types_counts(),
1041 }
1042 }
1043
1044 #[cfg(any(test, feature = "admin"))]
1052 pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
1053 match self {
1054 DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
1055 #[cfg(any(test, feature = "sqlite"))]
1056 DatabaseType::SQLite(sl) => sl.create_label(name, parent),
1057 }
1058 }
1059
1060 #[cfg(any(test, feature = "admin"))]
1068 pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
1069 match self {
1070 DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
1071 #[cfg(any(test, feature = "sqlite"))]
1072 DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
1073 }
1074 }
1075
1076 #[cfg(any(test, feature = "admin"))]
1083 pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
1084 match self {
1085 DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
1086 #[cfg(any(test, feature = "sqlite"))]
1087 DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
1088 }
1089 }
1090
1091 #[cfg(any(test, feature = "admin"))]
1099 pub async fn label_file(&self, file_id: u64, label_id: u64) -> Result<()> {
1100 match self {
1101 DatabaseType::Postgres(pg) => pg.label_file(file_id, label_id).await,
1102 #[cfg(any(test, feature = "sqlite"))]
1103 DatabaseType::SQLite(sl) => sl.label_file(file_id, label_id),
1104 }
1105 }
1106}
1107
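/// Hash a password with Argon2 (default parameters) and a freshly generated
/// random salt, returning the PHC-format string for storage.
///
/// A minimal sketch:
/// ```ignore
/// let stored = hash_password("hunter2")?;
/// assert!(stored.starts_with("$argon2"));
/// ```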
1108pub fn hash_password(password: &str) -> Result<String> {
1114 let salt = SaltString::generate(&mut OsRng);
1115 let argon2 = Argon2::default();
1116 Ok(argon2
1117 .hash_password(password.as_bytes(), &salt)?
1118 .to_string())
1119}
1120
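/// Generate a random API key: two v4 UUIDs concatenated with their hyphens
/// removed, yielding a 64-character hex string.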
1121#[must_use]
1123pub fn random_bytes_api_key() -> String {
1124 let key1 = uuid::Uuid::new_v4();
1125 let key2 = uuid::Uuid::new_v4();
1126 let key1 = key1.to_string().replace('-', "");
1127 let key2 = key2.to_string().replace('-', "");
1128 format!("{key1}{key2}")
1129}
1130
1131#[cfg(test)]
1132mod tests {
1133 use super::*;
1134 #[cfg(feature = "vt")]
1135 use crate::vt::VtUpdater;
1136
1137 use std::fs;
1138 #[cfg(feature = "vt")]
1139 use std::sync::Arc;
1140 #[cfg(feature = "vt")]
1141 use std::time::SystemTime;
1142
1143 use anyhow::Context;
1144 use fuzzyhash::FuzzyHash;
1145 use malwaredb_api::{PartialHashSearchType, SearchRequestParameters, SearchType};
1146 use malwaredb_lzjd::{LZDict, Murmur3HashState};
1147 use tlsh_fixed::TlshBuilder;
1148 use uuid::Uuid;
1149
1150 const MALWARE_LABEL: &str = "malware";
1151 const RANSOMWARE_LABEL: &str = "ransomware";
1152
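    /// Build a similarity-search request for `data` from ssdeep, TLSH (when the
    /// input is large enough for the builder to produce a hash), and LZJD digests.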
1153 fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
1154 let mut hashes = vec![];
1155
1156 hashes.push((
1157 malwaredb_api::SimilarityHashType::SSDeep,
1158 FuzzyHash::new(data).to_string(),
1159 ));
1160
1161 let mut builder = TlshBuilder::new(
1162 tlsh_fixed::BucketKind::Bucket256,
1163 tlsh_fixed::ChecksumKind::ThreeByte,
1164 tlsh_fixed::Version::Version4,
1165 );
1166
1167 builder.update(data);
1168
1169 if let Ok(hasher) = builder.build() {
1170 hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
1171 }
1172
1173 let build_hasher = Murmur3HashState::default();
1174 let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
1175 hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
1176
1177 malwaredb_api::SimilarSamplesRequest { hashes }
1178 }
1179
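    /// Connect to the local testing Postgres instance, honoring `PG_PORT` when set.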
1180 async fn pg_config() -> Postgres {
1181 const CONNECTION_STRING: &str =
1184 "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
1185
1186 if let Ok(pg_port) = std::env::var("PG_PORT") {
1187 let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
1189 Postgres::new(&conn_string, None)
1190 .await
1191 .context(format!(
1192 "failed to connect to postgres with specified port {pg_port}"
1193 ))
1194 .unwrap()
1195 } else {
1196 Postgres::new(CONNECTION_STRING, None).await.unwrap()
1197 }
1198 }
1199
1200 #[tokio::test]
1201 #[ignore = "don't run this in CI"]
1202 async fn pg() {
1203 let psql = pg_config().await;
1204 psql.delete().await.unwrap();
1205
1206 let psql = pg_config().await;
1207 let db = DatabaseType::Postgres(psql);
1208 everything(&db).await.unwrap();
1209
1210 #[cfg(feature = "vt")]
1211 {
1212 let db_config = db.get_config().await.unwrap();
1213 let state = crate::State {
1214 port: 8080,
1215 directory: None,
1216 max_upload: 10 * 1024 * 1024,
1217 ip: "127.0.0.1".parse().unwrap(),
1218 db_type: Arc::new(db),
1219 db_config,
1220 keys: HashMap::new(),
1221 started: SystemTime::now(),
1222 vt_client: std::env::var("VT_API_KEY")
1223 .ok()
1224 .map(malwaredb_virustotal::VirusTotalClient::new),
1225 tls_config: None,
1226 mdns: None,
1227 };
1228
1229 let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1230
1231 vt.updater().await.unwrap();
1232 println!("PG: Did VT ops!");
1233
1234 let psql = pg_config().await;
1235
1236 let vt_stats = psql
1237 .get_vt_stats()
1238 .await
1239 .context("failed to get Postgres VT Stats")
1240 .unwrap();
1241 println!("{vt_stats:?}");
1242 assert!(
1243 vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1244 );
1245 }
1246
1247 let psql = pg_config().await;
1249 psql.delete().await.unwrap();
1250 }
1251
1252 #[tokio::test]
1253 async fn sqlite() {
1254 const DB_FILE: &str = "testing_sqlite.db";
1255 if std::path::Path::new(DB_FILE).exists() {
1256 fs::remove_file(DB_FILE)
1257 .context(format!("failed to delete old SQLite file {DB_FILE}"))
1258 .unwrap();
1259 }
1260
1261 let sqlite = Sqlite::new(DB_FILE)
1262 .context(format!("failed to create SQLite instance for {DB_FILE}"))
1263 .unwrap();
1264
1265 let db = DatabaseType::SQLite(sqlite);
1266 everything(&db).await.unwrap();
1267
1268 #[cfg(feature = "vt")]
1269 {
1270 let db_config = db.get_config().await.unwrap();
1271 let state = crate::State {
1272 port: 8080,
1273 directory: None,
1274 max_upload: 10 * 1024 * 1024,
1275 ip: "127.0.0.1".parse().unwrap(),
1276 db_type: Arc::new(db),
1277 db_config,
1278 keys: HashMap::new(),
1279 started: SystemTime::now(),
1280 vt_client: std::env::var("VT_API_KEY")
1281 .ok()
1282 .map(malwaredb_virustotal::VirusTotalClient::new),
1283 tls_config: None,
1284 mdns: None,
1285 };
1286
1287 let sqlite_second = Sqlite::new(DB_FILE)
1288 .context(format!("failed to create SQLite instance for {DB_FILE}"))
1289 .unwrap();
1290
1291 let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1292
1293 vt.updater().await.unwrap();
1294 println!("Sqlite: Did VT ops!");
1295 let vt_stats = sqlite_second
1296 .get_vt_stats()
1297 .context("failed to get Sqlite VT Stats")
1298 .unwrap();
1299 println!("{vt_stats:?}");
1300 assert!(
1301 vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1302 );
1303 }
1304
1305 fs::remove_file(DB_FILE)
1306 .context(format!("failed to delete SQLite file {DB_FILE}"))
1307 .unwrap();
1308 }
1309
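    /// End-to-end exercise shared by the Postgres and SQLite tests: users,
    /// groups, sources, file ingest, search, similarity hashing, labels,
    /// reports, and API-key management.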
1310 #[allow(clippy::too_many_lines)]
1311 async fn everything(db: &DatabaseType) -> Result<()> {
1312 const ADMIN_UNAME: &str = "admin";
1313 const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
1314
1315 db.set_name("Testing Database")
1316 .await
1317 .context("setting instance name failed")?;
1318
1319 assert!(
1320 db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
1321 "Authentication without password should have failed."
1322 );
1323
1324 db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1325 .await
1326 .context("failed to set admin password")?;
1327
1328 let admin_api_key = db
1329 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1330 .await
1331 .context("unable to get api key for admin")?;
1332 println!("API key: {admin_api_key}");
1333 assert_eq!(admin_api_key.len(), 64);
1334
1335 assert_eq!(
1336 db.get_uid(&admin_api_key).await?,
1337 0,
1338 "Unable to get UID given the API key"
1339 );
1340
1341 let admin_api_key_again = db
1342 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1343 .await
1344 .context("unable to get api key a second time for admin")?;
1345
1346 assert_eq!(
1347 admin_api_key, admin_api_key_again,
1348 "API keys didn't match the second time."
1349 );
1350
1351 let bad_password = "this_is_totally_not_my_password!!";
1352 eprintln!("Testing API login with incorrect password.");
1353 assert!(
1354 db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
1355 "Authenticating as admin with a bad password should have failed."
1356 );
1357
1358 let admin_is_admin = db
1359 .user_is_admin(0)
1360 .await
1361 .context("unable to see if admin (uid 0) is an admin")?;
1362 assert!(admin_is_admin);
1363
1364 let new_user_uname = "testuser";
1365 let new_user_email = "test@example.com";
1366 let new_user_password = "some_awesome_password_++";
1367 let new_id = db
1368 .create_user(
1369 new_user_uname,
1370 new_user_uname,
1371 new_user_uname,
1372 new_user_email,
1373 Some(new_user_password.into()),
1374 None,
1375 false,
1376 )
1377 .await
1378 .context(format!("failed to create user {new_user_uname}"))?;
1379
1380 let passwordless_user_id = db
1381 .create_user(
1382 "passwordless_user",
1383 "passwordless_user",
1384 "passwordless_user",
1385 "passwordless_user@example.com",
1386 None,
1387 None,
1388 false,
1389 )
1390 .await
1391 .context("failed to create passwordless_user")?;
1392
1393 for user in &db.list_users().await.context("failed to list users")? {
1394 if user.id == passwordless_user_id {
1395 assert_eq!(user.uname, "passwordless_user");
1396 }
1397 }
1398
1399 db.edit_user(
1400 passwordless_user_id,
1401 "passwordless_user_2",
1402 "passwordless_user_2",
1403 "passwordless_user_2",
1404 "passwordless_user_2@something.com",
1405 false,
1406 )
1407 .await
1408 .context(format!(
1409 "failed to alter 'passwordless' user, id {passwordless_user_id}"
1410 ))?;
1411
1412 for user in &db.list_users().await.context("failed to list users")? {
1413 if user.id == passwordless_user_id {
1414 assert_eq!(user.uname, "passwordless_user_2");
1415 }
1416 }
1417
1418 assert!(
1419 new_id > 0,
1420 "Weird UID created for user {new_user_uname}: {new_id}"
1421 );
1422
1423 assert!(
1424 db.create_user(
1425 new_user_uname,
1426 new_user_uname,
1427 new_user_uname,
1428 new_user_email,
1429 Some(new_user_password.into()),
1430 None,
1431 false
1432 )
1433 .await
1434 .is_err(),
1435 "Creating a new user with the same user name should fail"
1436 );
1437
1438 let ro_user_name = "ro_user";
1439 let ro_user_password = "ro_user_password";
1440 db.create_user(
1441 ro_user_name,
1442 "ro_user",
1443 "ro_user",
1444 "ro@example.com",
1445 Some(ro_user_password.into()),
1446 None,
1447 true,
1448 )
1449 .await
1450 .context("failed to create read-only user")?;
1451
1452 let ro_user_api_key = db
1453 .authenticate(ro_user_name, ro_user_password)
1454 .await
1455 .context("unable to get api key for read-only user")?;
1456
1457 let new_user_password_change = "some_new_awesomer_password!_++";
1458 db.set_password(new_user_uname, new_user_password_change)
1459 .await
1460 .context("failed to change the password for testuser")?;
1461
1462 let new_user_api_key = db
1463 .authenticate(new_user_uname, new_user_password_change)
1464 .await
1465 .context("unable to get api key for testuser")?;
1466 eprintln!("{new_user_uname} got API key {new_user_api_key}");
1467
1468 assert_eq!(admin_api_key.len(), new_user_api_key.len());
1469
1470 let users = db.list_users().await.context("failed to list users")?;
1471 assert_eq!(
1472 users.len(),
1473 4,
1474 "Four users were created, yet there are {} users",
1475 users.len()
1476 );
1477 eprintln!("DB has {} users:", users.len());
1478 let mut passwordless_user_found = false;
1479 for user in users {
1480 println!("{user}");
1481 if user.uname == "passwordless_user_2" {
1482 assert!(!user.has_api_key);
1483 assert!(!user.has_password);
1484 passwordless_user_found = true;
1485 } else {
1486 assert!(user.has_api_key);
1487 assert!(user.has_password);
1488 }
1489 }
1490 assert!(passwordless_user_found);
1491
1492 let new_group_name = "some_new_group";
1493 let new_group_desc = "some_new_group_description";
1494 let new_group_id = 1;
1495 assert_eq!(
1496 db.create_group(new_group_name, new_group_desc, None)
1497 .await
1498 .context("failed to create group")?,
1499 new_group_id,
1500 "New group didn't have the expected ID, expected {new_group_id}"
1501 );
1502
1503 assert!(
1504 db.create_group(new_group_name, new_group_desc, None)
1505 .await
1506 .is_err(),
1507 "Duplicate group name should have failed"
1508 );
1509
1510 db.add_user_to_group(1, 1)
1511 .await
1512 .context("Unable to add uid 1 to gid 1")?;
1513
1514 let ro_user_uid = db
1515 .get_uid(&ro_user_api_key)
1516 .await
1517 .context("Unable to get UID for read-only user")?;
1518 db.add_user_to_group(ro_user_uid, 1)
1519 .await
1520 .context("Unable to add uid 2 to gid 1")?;
1521
1522 let new_admin_group_name = "admin_subgroup";
1523 let new_admin_group_desc = "admin_subgroup_description";
1524 let new_admin_group_id = 2;
1525 assert!(
1527 db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
1528 .await
1529 .context("failed to create admin sub-group")?
1530 >= new_admin_group_id,
1531 "New group didn't have the expected ID, expected >= {new_admin_group_id}"
1532 );
1533
1534 let groups = db.list_groups().await.context("failed to list groups")?;
1535 assert_eq!(
1536 groups.len(),
1537 3,
1538 "Three groups were created, yet there are {} groups",
1539 groups.len()
1540 );
1541 eprintln!("DB has {} groups:", groups.len());
1542 for group in groups {
1543 println!("{group}");
1544 if group.id == new_admin_group_id {
1545 assert_eq!(group.parent, Some("admin".to_string()));
1546 }
1547 if group.id == 1 {
1548 let test_user_str = String::from(new_user_uname);
1549 let mut found = false;
1550 for member in group.members {
1551 if member.uname == test_user_str {
1552 found = true;
1553 break;
1554 }
1555 }
1556 assert!(found, "new user {test_user_str} wasn't in the group");
1557 }
1558 }
1559
1560 let default_source_name = "default_source".to_string();
1561 let default_source_id = db
1562 .create_source(
1563 &default_source_name,
1564 Some("desc_default_source"),
1565 None,
1566 Local::now(),
1567 true,
1568 Some(false),
1569 )
1570 .await
1571 .context("failed to create source `default_source`")?;
1572
1573 db.add_group_to_source(1, default_source_id)
1574 .await
1575 .context("failed to add group 1 to source 1")?;
1576
1577 let another_source_name = "another_source".to_string();
1578 let another_source_id = db
1579 .create_source(
1580 &another_source_name,
1581 Some("yet another file source"),
1582 None,
1583 Local::now(),
1584 true,
1585 Some(false),
1586 )
1587 .await
1588 .context("failed to create source `another_source`")?;
1589
1590 let empty_source_name = "empty_source".to_string();
1591 db.create_source(
1592 &empty_source_name,
1593 Some("empty and unused file source"),
1594 None,
1595 Local::now(),
1596 true,
1597 Some(false),
1598 )
1599 .await
1600 .context("failed to create source `another_source`")?;
1601
1602 db.add_group_to_source(1, another_source_id)
1603 .await
1604 .context("failed to add group 1 to source 1")?;
1605
1606 let sources = db.list_sources().await.context("failed to list sources")?;
1607 eprintln!("DB has {} sources:", sources.len());
1608 for source in sources {
1609 println!("{source}");
1610 assert_eq!(source.files, 0);
1611 if source.id == default_source_id || source.id == another_source_id {
1612 assert_eq!(
1613 source.groups, 1,
1614 "default source {default_source_name} should have 1 group"
1615 );
1616 } else {
1617 assert_eq!(source.groups, 0, "groups should be zero (empty)");
1618 }
1619 }
1620
1621 let uid = db
1622 .get_uid(&new_user_api_key)
1623 .await
1624 .context("failed to user uid from apikey")?;
1625 let user_info = db
1626 .get_user_info(uid)
1627 .await
1628 .context("failed to get user's available groups and sources")?;
1629 assert!(user_info.sources.contains(&default_source_name));
1630 assert!(!user_info.is_admin);
1631 println!("UserInfoResponse: {user_info:?}");
1632
1633 assert!(
1634 db.allowed_user_source(1, default_source_id)
1635 .await
1636 .context(format!(
1637 "failed to check that user 1 has access to source {default_source_id}"
1638 ))?,
1639 "User 1 should should have had access to source {default_source_id}"
1640 );
1641
1642 assert!(
1643 !db.allowed_user_source(1, 5)
1644 .await
1645 .context("failed to check that user 1 has access to source 5")?,
1646 "User 1 should should not have had access to source 5"
1647 );
1648
1649 let test_label_id = db
1650 .create_label("TestLabel", None)
1651 .await
1652 .context("failed to create test label")?;
1653 let test_elf_label_id = db
1654 .create_label("TestELF", Some(test_label_id))
1655 .await
1656 .context("failed to create test label")?;
1657
1658 let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
1659 let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
1660 let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
1661
1662 let known_type =
1663 KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
1664 assert!(known_type.is_exec(), "ELF should be executable");
1665 eprintln!("ELF type ID: {elf_type}");
1666
1667 let file_addition = db
1668 .add_file(
1669 &test_elf_meta,
1670 known_type.clone(),
1671 1,
1672 default_source_id,
1673 elf_type,
1674 None,
1675 )
1676 .await
1677 .context("failed to insert a test elf")?;
1678 assert!(file_addition.is_new, "File should have been added");
1679 eprintln!("Added ELF to the DB");
1680 db.label_file(file_addition.file_id, test_elf_label_id)
1681 .await
1682 .context("failed to label file")?;
1683
1684 let partial_search = SearchRequest {
1686 search: SearchType::Search(SearchRequestParameters {
1687 partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1688 labels: Some(vec![String::from("TestELF")]),
1689 file_type: Some(String::from("ELF")),
1690 magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1691 ..Default::default()
1692 }),
1693 };
1694 assert!(partial_search.is_valid());
1695 let partial_search_response = db.partial_search(1, partial_search).await?;
1696 assert_eq!(partial_search_response.hashes.len(), 1);
1697 assert_eq!(
1698 partial_search_response.hashes[0],
1699 "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1700 );
1701
1702 let partial_search = SearchRequest {
1704 search: SearchType::Search(SearchRequestParameters {
1705 partial_hash: None,
1706 labels: None,
1707 file_type: None,
1708 magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1709 ..Default::default()
1710 }),
1711 };
1712 assert!(partial_search.is_valid());
1713 let partial_search_response = db.partial_search(1, partial_search).await?;
1714 assert_eq!(partial_search_response.hashes.len(), 1);
1715 assert_eq!(
1716 partial_search_response.hashes[0],
1717 "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1718 );
1719
1720 let partial_search = SearchRequest {
1722 search: SearchType::Search(SearchRequestParameters {
1723 partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1724 file_type: Some(String::from("PE32")),
1725 ..Default::default()
1726 }),
1727 };
1728 assert!(partial_search.is_valid());
1729 let partial_search_response = db.partial_search(1, partial_search).await?;
1730 assert_eq!(partial_search_response.hashes.len(), 0);
1731
1732 let partial_search = SearchRequest {
1733 search: SearchType::Search(SearchRequestParameters {
1734 file_name: Some("ppc64".into()),
1735 ..Default::default()
1736 }),
1737 };
1738 assert!(partial_search.is_valid());
1739 let partial_search_response = db.partial_search(1, partial_search).await?;
1740 assert_eq!(partial_search_response.hashes.len(), 1);
1741
1742 let partial_search = SearchRequest {
1744 search: SearchType::Search(SearchRequestParameters::default()),
1745 };
1746 assert!(!partial_search.is_valid());
1747 let partial_search_response = db.partial_search(1, partial_search).await?;
1748 assert!(partial_search_response.hashes.is_empty());
1749
1750 let partial_search = SearchRequest {
1752 search: SearchType::Continuation(Uuid::default()),
1753 };
1754 assert!(partial_search.is_valid());
1755 let partial_search_response = db.partial_search(1, partial_search).await?;
1756 assert!(partial_search_response.hashes.is_empty());
1757
1758 assert!(db
1760 .get_type_id_for_bytes(include_bytes!("../../../../MDB_Logo.ico"))
1761 .await
1762 .is_err());
1763
1764 assert!(
1765 db.add_file(
1766 &test_elf_meta,
1767 known_type.clone(),
1768 ro_user_uid,
1769 default_source_id,
1770 elf_type,
1771 None
1772 )
1773 .await
1774 .is_err(),
1775 "Read-only user should not be able to add a file"
1776 );
1777
1778 let mut test_elf_meta_different_name = test_elf_meta.clone();
1779 test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
1780
1781 assert!(
1782 !db.add_file(
1783 &test_elf_meta_different_name,
1784 known_type,
1785 1,
1786 another_source_id,
1787 elf_type,
1788 None
1789 )
1790 .await
1791 .context("failed to insert a test elf again for a different source")?
1792 .is_new
1793 );
1794
1795 let sources = db
1796 .list_sources()
1797 .await
1798 .context("failed to re-list sources")?;
1799 eprintln!(
1800 "DB has {} sources, and a file was added twice:",
1801 sources.len()
1802 );
1803 println!("We should have two sources with one file each, yet only one ELF.");
1804 for source in sources {
1805 println!("{source}");
1806 if source.id == default_source_id || source.id == another_source_id {
1807 assert_eq!(source.files, 1);
1808 } else {
1809 assert_eq!(source.files, 0, "files should be zero (empty)");
1810 }
1811 }
1812
1813 assert!(!db
1814 .get_user_sources(1)
1815 .await
1816 .expect("failed to get user 1's sources")
1817 .sources
1818 .is_empty());
1819
1820 let file_types_counts = db
1821 .file_types_counts()
1822 .await
1823 .context("failed to get file types and counts")?;
1824 for (name, count) in file_types_counts {
1825 println!("{name}: {count}");
1826 assert_eq!(name, "ELF");
1827 assert_eq!(count, 1);
1828 }
1829
1830 let mut test_elf_modified = test_elf.clone();
1831 let random_bytes = Uuid::new_v4();
1832 let mut random_bytes = random_bytes.into_bytes().to_vec();
1833 test_elf_modified.append(&mut random_bytes);
1834 let similarity_request = generate_similarity_request(&test_elf_modified);
1835 let similarity_response = db
1836 .find_similar_samples(1, &similarity_request.hashes)
1837 .await
1838 .context("failed to get similarity response")?;
1839 eprintln!("Similarity response: {similarity_response:?}");
1840 let similarity_response = similarity_response.first().unwrap();
1841 assert_eq!(
1842 similarity_response.sha256,
1843 hex::encode(&test_elf_meta.sha256),
1844 "Similarity response should have had the hash of the original ELF"
1845 );
1846 for (algo, sim) in &similarity_response.algorithms {
1847 match algo {
1848 malwaredb_api::SimilarityHashType::LZJD => {
1849 assert!(*sim > 0.0f32);
1850 }
1851 malwaredb_api::SimilarityHashType::SSDeep => {
1852 assert!(*sim > 80.0f32);
1853 }
1854 malwaredb_api::SimilarityHashType::TLSH => {
1855 assert!(*sim <= 20f32);
1856 }
1857 _ => {}
1858 }
1859 }
1860
1861 let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1.as_slice())
1862 .context("failed to get `HashType::SHA1` from string")?;
1863 let response_sha256 = db
1864 .retrieve_sample(1, &test_elf_hashtype)
1865 .await
1866 .context("could not get SHA-256 hash from test sample")
1867 .unwrap();
1868 assert_eq!(response_sha256, hex::encode(&test_elf_meta.sha256));
1869
1870 let test_bogus_hash =
1871 HashType::try_from("d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0")
1872 .context("failed to get `HashType` from static string")?;
1873 assert!(
1874 db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
1875 "Getting a file with a bogus hash should have failed."
1876 );
1877
1878 let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
1879 let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
1880 let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
1881
1882 let known_type =
1883 KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
1884
1885 assert!(
1886 db.add_file(
1887 &test_pdf_meta,
1888 known_type,
1889 1,
1890 default_source_id,
1891 pdf_type,
1892 None
1893 )
1894 .await
1895 .context("failed to insert a test pdf")?
1896 .is_new
1897 );
1898 eprintln!("Added PDF to the DB");
1899
1900 let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
1901 let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
1902 let rtf_type = db
1903 .get_type_id_for_bytes(&test_rtf)
1904 .await
1905 .context("failed to get file type id for rtf")?;
1906
1907 let known_type =
1908 KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
1909
1910 assert!(
1911 db.add_file(
1912 &test_rtf_meta,
1913 known_type,
1914 1,
1915 default_source_id,
1916 rtf_type,
1917 None
1918 )
1919 .await
1920 .context("failed to insert a test rtf")?
1921 .is_new
1922 );
1923 eprintln!("Added RTF to the DB");
1924
1925 let report = db
1926 .get_sample_report(
1927 1,
1928 &HashType::try_from(test_rtf_meta.sha256.as_slice()).unwrap(),
1929 )
1930 .await
1931 .context("failed to get report for test rtf")?;
1932 assert!(report
1933 .clone()
1934 .filecommand
1935 .unwrap()
1936 .contains("Rich Text Format"));
1937 println!("Report: {report}");
1938
1939 assert!(db
1940 .get_sample_report(
1941 999,
1942 &HashType::try_from(test_rtf_meta.sha256.as_slice()).unwrap()
1943 )
1944 .await
1945 .is_err());
1946
1947 #[cfg(feature = "vt")]
1948 {
1949 assert!(report.vt.is_some());
1950 let files_needing_vt = db
1951 .files_without_vt_records(10)
1952 .await
1953 .context("failed to get files without VT records")?;
1954 assert!(files_needing_vt.len() > 2);
1955 println!(
1956 "{} files needing VT data: {files_needing_vt:?}",
1957 files_needing_vt.len()
1958 );
1959 }
1960
1961 #[cfg(not(feature = "vt"))]
1962 {
1963 assert!(report.vt.is_none());
1964 }
1965
1966 let reset = db
1967 .reset_api_keys()
1968 .await
1969 .context("failed to reset all API keys")?;
1970 eprintln!("Cleared {reset} api keys.");
1971
1972 let db_info = db.db_info().await.context("failed to get database info")?;
1973 eprintln!("DB Info: {db_info:?}");
1974
1975 let data_types = db
1976 .get_known_data_types()
1977 .await
1978 .context("failed to get data types")?;
1979 for data_type in data_types {
1980 println!("{data_type:?}");
1981 }
1982
1983 let sources = db
1984 .list_sources()
1985 .await
1986 .context("failed to list sources second time")?;
1987 eprintln!("DB has {} sources:", sources.len());
1988 for source in sources {
1989 println!("{source}");
1990 }
1991
1992 let file_types_counts = db
1993 .file_types_counts()
1994 .await
1995 .context("failed to get file types and counts")?;
1996 for (name, count) in file_types_counts {
1997 println!("{name}: {count}");
1998 assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
1999 }
2000
2001 let fatmacho =
2002 include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
2003 .to_vec();
2004 let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
2005 let fatmacho_type = db
2006 .get_type_id_for_bytes(&fatmacho)
2007 .await
2008 .context("failed to get file type for Fat Mach-O")?;
2009 let known_type = KnownType::new(&fatmacho)
2010 .context("failed to parse Fat Mach-O from type crate's test data")?;
2011
2012 assert!(
2013 db.add_file(
2014 &fatmacho_meta,
2015 known_type,
2016 1,
2017 default_source_id,
2018 fatmacho_type,
2019 None
2020 )
2021 .await
2022 .context("failed to insert a test Fat Mach-O")?
2023 .is_new
2024 );
2025 eprintln!("Added Fat Mach-O to the DB");
2026
2027 let file_types_counts = db
2028 .file_types_counts()
2029 .await
2030 .context("failed to get file types and counts")?;
2031 for (name, count) in &file_types_counts {
2032 println!("{name}: {count}");
2033 }
2034
2035 assert_eq!(
2036 *file_types_counts.get("Mach-O").unwrap(),
2037 4,
2038 "Expected 4 Mach-O files, got {:?}",
2039 file_types_counts.get("Mach-O")
2040 );
2041
2042 let allowed_files = db
2043 .user_allowed_files_by_sha256(1, None)
2044 .await
2045 .context("failed to get allowed files")?;
2046 assert_eq!(allowed_files.0.len(), 8);
2047
2048 let allowed_files = db
2049 .user_allowed_files_by_sha256(1, Some(allowed_files.1))
2050 .await
2051 .context("failed to get allowed files")?;
2052 assert!(allowed_files.0.is_empty());
2053
2054 let malware_label_id = db
2055 .create_label(MALWARE_LABEL, None)
2056 .await
2057 .context("failed to create first label")?;
2058 let ransomware_label_id = db
2059 .create_label(RANSOMWARE_LABEL, Some(malware_label_id))
2060 .await
2061 .context("failed to create malware sub-label")?;
2062 let labels = db.get_labels().await.context("failed to get labels")?;
2063
2064 assert_eq!(labels.len(), 4, "Expected 4 labels, got {labels}");
2065 for label in labels.0 {
2066 if label.name == RANSOMWARE_LABEL {
2067 assert_eq!(label.id, ransomware_label_id);
2068 assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
2069 }
2070 }
2071
2072 let source_code = include_bytes!("mod.rs");
2074 let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
2075 let known_type =
2076 KnownType::new(source_code).context("failed to parse source code to get `Unknown` type")?;
2077
2078 assert!(matches!(known_type, KnownType::Unknown(_)));
2079
2080 let unknown_type: Vec<FileType> = db
2081 .get_known_data_types()
2082 .await?
2083 .into_iter()
2084 .filter(|t| t.name.eq_ignore_ascii_case("unknown"))
2085 .collect();
2086 let unknown_type_id = unknown_type.first().unwrap().id;
2087 assert!(db.get_type_id_for_bytes(source_code).await.is_err());
2088 db.enable_keep_unknown_files()
2089 .await
2090 .context("failed to enable keeping of unknown files")?;
2091 let source_type = db
2092 .get_type_id_for_bytes(source_code)
2093 .await
2094 .context("failed to type id for source code unknown type example")?;
2095 assert_eq!(source_type, unknown_type_id);
2096 eprintln!("Unknown file type ID: {source_type}");
2097 assert!(
2098 db.add_file(
2099 &source_meta,
2100 known_type,
2101 1,
2102 default_source_id,
2103 unknown_type_id,
2104 None
2105 )
2106 .await
2107 .context("failed to add Rust source code file")?
2108 .is_new
2109 );
2110 eprintln!("Added Rust source code to the DB");
2111
2112 #[cfg(feature = "yara")]
2113 assert!(db.get_unfinished_yara_tasks().await?.is_empty());
2114
2115 db.reset_own_api_key(0)
2116 .await
2117 .context("failed to clear own API key uid 0")?;
2118
2119 db.deactivate_user(0)
2120 .await
2121 .context("failed to clear password and API key for uid 0")?;
2122
2123 Ok(())
2124 }
2125}