1#[cfg(any(test, feature = "admin"))]
10mod admin;
11mod pg;
13
14#[cfg(any(test, feature = "sqlite"))]
16mod sqlite;
17
18#[cfg(any(test, feature = "sqlite"))]
20mod sqlite_functions;
21
22pub mod types;
24
25#[cfg(any(test, feature = "admin"))]
26use crate::crypto::EncryptionOption;
27use crate::crypto::FileEncryption;
28use crate::db::pg::Postgres;
29#[cfg(any(test, feature = "sqlite"))]
30use crate::db::sqlite::Sqlite;
31use crate::db::types::{FileMetadata, FileType};
32use malwaredb_api::{
33 digest::HashType, GetUserInfoResponse, Labels, SearchRequest, SearchResponse, Sources,
34};
35use malwaredb_types::KnownType;
36
37use std::collections::HashMap;
38use std::path::PathBuf;
39
40use anyhow::{bail, ensure, Result};
41use argon2::password_hash::{rand_core::OsRng, SaltString};
42use argon2::{Argon2, PasswordHasher};
43#[cfg(any(test, feature = "admin"))]
44use chrono::Local;
45#[cfg(feature = "vt")]
46use malwaredb_virustotal::filereport::ScanResultAttributes;
47
/// Limit value (100) associated with partial searches.
///
/// NOTE(review): this constant is not referenced in the visible part of the
/// file; presumably it caps the number of results a partial search returns —
/// confirm against the backend implementations.
pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
50
/// Action passed to [`DatabaseType::migrate_check`], which forwards it to the
/// backend's `migrate` routine.
///
/// `Debug` is derived so the value can appear in logs and assertion messages,
/// matching the other public types in this module.
#[derive(Debug, Copy, Clone)]
pub enum Migration {
    /// Used by `DatabaseType::from_string` when opening a database.
    Check,

    /// Used by `DatabaseType::migrate`; only compiled for tests or with the
    /// `admin` feature.
    #[cfg(any(test, feature = "admin"))]
    Migrate,
}
61
/// Handle to the configured database backend.
///
/// Every method on this type matches on the variant and forwards the call to
/// the wrapped backend object.
#[derive(Debug)]
pub enum DatabaseType {
    /// Postgres backend; always compiled in.
    Postgres(Postgres),

    /// SQLite backend; only compiled for tests or with the `sqlite` feature.
    #[cfg(any(test, feature = "sqlite"))]
    SQLite(Sqlite),
}
72
/// Summary information about a database, as returned by
/// [`DatabaseType::db_info`].
///
/// NOTE(review): the field meanings below are read off the field names; the
/// code that fills them lives in the backends and is not visible here.
#[derive(Debug)]
pub struct DatabaseInformation {
    /// Version string reported by the backend (format is backend-defined).
    pub version: String,

    /// Database size as a string; units/formatting are chosen by the backend.
    pub size: String,

    /// Presumably the number of stored files.
    pub num_files: u64,

    /// Presumably the number of user accounts.
    pub num_users: u32,

    /// Presumably the number of groups.
    pub num_groups: u32,

    /// Presumably the number of sources.
    pub num_sources: u32,
}
94
/// Outcome of [`DatabaseType::add_file`].
///
/// `Debug` is derived for consistency with the other public structs in this
/// module, so results can be logged and used in assertion messages.
#[derive(Debug)]
pub struct FileAddedResult {
    /// Identifier of the file record in the database.
    pub file_id: u64,

    /// Flag set by the backend; presumably `true` when the file was not already
    /// present — confirm against the backend implementation.
    pub is_new: bool,
}
104
/// Instance configuration as returned by [`DatabaseType::get_config`].
///
/// NOTE(review): field meanings are inferred from their names and from the
/// matching `set_name` / `enable_*` / `disable_*` methods on `DatabaseType`;
/// the storage details live in the backends.
#[derive(Debug)]
pub struct MDBConfig {
    /// Name of this instance.
    pub name: String,

    /// Whether compression is enabled.
    pub compression: bool,

    /// Whether samples are sent to VirusTotal.
    pub send_samples_to_vt: bool,

    /// Whether files of unknown type are kept.
    pub keep_unknown_files: bool,

    /// Presumably the ID of the default file-encryption key, if any — TODO confirm.
    pub(crate) default_key: Option<u32>,
}
123
/// VirusTotal record counters as returned by `DatabaseType::get_vt_stats`.
///
/// Only compiled with the `vt` feature.
///
/// NOTE(review): the exact definition of each counter is decided by the
/// backend queries, which are not visible here.
#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Presumably the number of VT records with no detections.
    pub clean_records: u32,

    /// Presumably the number of VT records with at least one detection.
    pub hits_records: u32,

    /// Presumably the number of files that have no VT record yet.
    pub files_without_records: u32,
}
138
139impl DatabaseType {
140 pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
149 let db = Self::init_from_string(arg, server_ca).await?;
150 db.migrate_check(Migration::Check).await?;
151 Ok(db)
152 }
153
154 #[cfg(feature = "admin")]
163 pub async fn migrate(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
164 let db = Self::init_from_string(arg, server_ca).await?;
165 db.migrate_check(Migration::Migrate).await?;
166 Ok(db)
167 }
168
169 async fn init_from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
170 #[cfg(any(test, feature = "sqlite"))]
171 if arg.starts_with("file:") {
172 let new_conn_str = arg.trim_start_matches("file:");
173 let db = DatabaseType::SQLite(Sqlite::new(new_conn_str)?);
174 return Ok(db);
175 }
176
177 if arg.starts_with("postgres") {
178 let new_conn_str = arg.trim_start_matches("postgres");
179 let db = DatabaseType::Postgres(Postgres::new(new_conn_str, server_ca).await?);
180
181 return Ok(db);
182 }
183
184 bail!("unknown database type `{arg}`")
185 }
186
187 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
193 #[cfg(feature = "vt")]
194 pub async fn enable_vt_upload(&self) -> Result<()> {
195 match self {
196 DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
197 #[cfg(any(test, feature = "sqlite"))]
198 DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
199 }
200 }
201
202 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
208 #[cfg(feature = "vt")]
209 pub async fn disable_vt_upload(&self) -> Result<()> {
210 match self {
211 DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
212 #[cfg(any(test, feature = "sqlite"))]
213 DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
214 }
215 }
216
217 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
223 #[cfg(feature = "vt")]
224 pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
225 match self {
226 DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
227 #[cfg(any(test, feature = "sqlite"))]
228 DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
229 }
230 }
231
232 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
238 #[cfg(feature = "vt")]
239 pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
240 match self {
241 DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
242 #[cfg(any(test, feature = "sqlite"))]
243 DatabaseType::SQLite(sl) => sl.store_vt_record(results),
244 }
245 }
246
247 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
253 #[cfg(feature = "vt")]
254 pub async fn get_vt_stats(&self) -> Result<VtStats> {
255 match self {
256 DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
257 #[cfg(any(test, feature = "sqlite"))]
258 DatabaseType::SQLite(sl) => sl.get_vt_stats(),
259 }
260 }
261
262 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
268 #[cfg(feature = "yara")]
269 pub async fn add_yara_search(
270 &self,
271 uid: u32,
272 yara_string: &str,
273 yara_bytes: &[u8],
274 ) -> Result<uuid::Uuid> {
275 match self {
276 DatabaseType::Postgres(pg) => pg.add_yara_search(uid, yara_string, yara_bytes).await,
277 #[cfg(any(test, feature = "sqlite"))]
278 DatabaseType::SQLite(sl) => sl.add_yara_search(uid, yara_string, yara_bytes),
279 }
280 }
281
282 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
288 #[cfg(feature = "yara")]
289 pub async fn get_unfinished_yara_tasks(&self) -> Result<Vec<crate::yara::YaraTask>> {
290 match self {
291 DatabaseType::Postgres(pg) => pg.get_unfinished_yara_tasks().await,
292 #[cfg(any(test, feature = "sqlite"))]
293 DatabaseType::SQLite(sl) => sl.get_unfinished_yara_tasks(),
294 }
295 }
296
297 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
303 #[cfg(feature = "yara")]
304 pub async fn add_yara_match(
305 &self,
306 id: uuid::Uuid,
307 rule_name: &str,
308 file_sha256: &str,
309 ) -> Result<()> {
310 match self {
311 DatabaseType::Postgres(pg) => pg.add_yara_match(id, rule_name, file_sha256).await,
312 #[cfg(any(test, feature = "sqlite"))]
313 DatabaseType::SQLite(sl) => sl.add_yara_match(id, rule_name, file_sha256),
314 }
315 }
316
317 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
323 #[cfg(feature = "yara")]
324 pub async fn mark_yara_task_as_finished(&self, id: uuid::Uuid) -> Result<()> {
325 match self {
326 DatabaseType::Postgres(pg) => pg.mark_yara_task_as_finished(id).await,
327 #[cfg(any(test, feature = "sqlite"))]
328 DatabaseType::SQLite(sl) => sl.mark_yara_task_as_finished(id),
329 }
330 }
331
332 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
338 #[cfg(feature = "yara")]
339 pub async fn yara_add_next_file_id(&self, id: uuid::Uuid, file_id: u64) -> Result<()> {
340 match self {
341 DatabaseType::Postgres(pg) => pg.yara_add_next_file_id(id, file_id).await,
342 #[cfg(any(test, feature = "sqlite"))]
343 DatabaseType::SQLite(sl) => sl.yara_add_next_file_id(id, file_id),
344 }
345 }
346
347 #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
353 #[cfg(feature = "yara")]
354 pub async fn get_yara_results(
355 &self,
356 id: uuid::Uuid,
357 user_id: u32,
358 ) -> Result<malwaredb_api::YaraSearchResponse> {
359 match self {
360 DatabaseType::Postgres(pg) => pg.get_yara_results(id, user_id).await,
361 #[cfg(any(test, feature = "sqlite"))]
362 DatabaseType::SQLite(sl) => sl.get_yara_results(id, user_id),
363 }
364 }
365
366 pub async fn get_config(&self) -> Result<MDBConfig> {
372 match self {
373 DatabaseType::Postgres(pg) => pg.get_config().await,
374 #[cfg(any(test, feature = "sqlite"))]
375 DatabaseType::SQLite(sl) => sl.get_config(),
376 }
377 }
378
379 pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
386 match self {
387 DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
388 #[cfg(any(test, feature = "sqlite"))]
389 DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
390 }
391 }
392
393 pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
400 ensure!(!apikey.is_empty(), "API key was empty");
401 match self {
402 DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
403 #[cfg(any(test, feature = "sqlite"))]
404 DatabaseType::SQLite(sl) => sl.get_uid(apikey),
405 }
406 }
407
408 pub async fn db_info(&self) -> Result<DatabaseInformation> {
414 match self {
415 DatabaseType::Postgres(pg) => pg.db_info().await,
416 #[cfg(any(test, feature = "sqlite"))]
417 DatabaseType::SQLite(sl) => sl.db_info(),
418 }
419 }
420
421 pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
428 match self {
429 DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
430 #[cfg(any(test, feature = "sqlite"))]
431 DatabaseType::SQLite(sl) => sl.get_user_info(uid),
432 }
433 }
434
435 pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
442 match self {
443 DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
444 #[cfg(any(test, feature = "sqlite"))]
445 DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
446 }
447 }
448
449 pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
456 match self {
457 DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
458 #[cfg(any(test, feature = "sqlite"))]
459 DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
460 }
461 }
462
463 pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
469 match self {
470 DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
471 #[cfg(any(test, feature = "sqlite"))]
472 DatabaseType::SQLite(sl) => sl.get_known_data_types(),
473 }
474 }
475
476 pub async fn get_labels(&self) -> Result<Labels> {
482 match self {
483 DatabaseType::Postgres(pg) => pg.get_labels().await,
484 #[cfg(any(test, feature = "sqlite"))]
485 DatabaseType::SQLite(sl) => sl.get_labels(),
486 }
487 }
488
489 pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
495 match self {
496 DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
497 #[cfg(any(test, feature = "sqlite"))]
498 DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
499 }
500 }
501
502 pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
509 match self {
510 DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
511 #[cfg(any(test, feature = "sqlite"))]
512 DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
513 }
514 }
515
516 pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
524 match self {
525 DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
526 #[cfg(any(test, feature = "sqlite"))]
527 DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
528 }
529 }
530
531 pub async fn add_file(
538 &self,
539 meta: &FileMetadata,
540 known_type: KnownType<'_>,
541 uid: u32,
542 sid: u32,
543 ftype: u32,
544 parent: Option<u64>,
545 ) -> Result<FileAddedResult> {
546 match self {
547 DatabaseType::Postgres(pg) => {
548 pg.add_file(meta, known_type, uid, sid, ftype, parent).await
549 }
550 #[cfg(any(test, feature = "sqlite"))]
551 DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
552 }
553 }
554
555 pub async fn partial_search(&self, uid: u32, search: SearchRequest) -> Result<SearchResponse> {
561 match self {
562 DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
563 #[cfg(any(test, feature = "sqlite"))]
564 DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
565 }
566 }
567
568 pub async fn cleanup(&self) -> Result<u64> {
574 match self {
575 DatabaseType::Postgres(pg) => pg.cleanup().await,
576 #[cfg(any(test, feature = "sqlite"))]
577 DatabaseType::SQLite(sl) => sl.cleanup(),
578 }
579 }
580
581 pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
589 match self {
590 DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
591 #[cfg(any(test, feature = "sqlite"))]
592 DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
593 }
594 }
595
596 pub async fn get_sample_report(
603 &self,
604 uid: u32,
605 hash: &HashType,
606 ) -> Result<malwaredb_api::Report> {
607 match self {
608 DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
609 #[cfg(any(test, feature = "sqlite"))]
610 DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
611 }
612 }
613
614 pub async fn find_similar_samples(
620 &self,
621 uid: u32,
622 sim: &[(malwaredb_api::SimilarityHashType, String)],
623 ) -> Result<Vec<malwaredb_api::SimilarSample>> {
624 match self {
625 DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
626 #[cfg(any(test, feature = "sqlite"))]
627 DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
628 }
629 }
630
631 pub async fn user_allowed_files_by_sha256(
638 &self,
639 uid: u32,
640 next: Option<u64>,
641 ) -> Result<(Vec<String>, u64)> {
642 match self {
643 DatabaseType::Postgres(pg) => pg.user_allowed_files_by_sha256(uid, next).await,
644 #[cfg(any(test, feature = "sqlite"))]
645 DatabaseType::SQLite(sl) => sl.user_allowed_files_by_sha256(uid, next),
646 }
647 }
648
649 pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
653 match self {
654 DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
655 #[cfg(any(test, feature = "sqlite"))]
656 DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
657 }
658 }
659
660 pub(crate) async fn get_file_encryption_key_id(
662 &self,
663 hash: &str,
664 ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
665 match self {
666 DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
667 #[cfg(any(test, feature = "sqlite"))]
668 DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
669 }
670 }
671
672 pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
674 match self {
675 DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
676 #[cfg(any(test, feature = "sqlite"))]
677 DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
678 }
679 }
680
681 pub async fn migrate_check(&self, action: Migration) -> Result<()> {
688 match self {
689 DatabaseType::Postgres(pg) => pg.migrate(action).await,
690 #[cfg(any(test, feature = "sqlite"))]
691 DatabaseType::SQLite(sl) => sl.migrate(action),
692 }
693 }
694
695 #[cfg(any(test, feature = "admin"))]
703 pub async fn set_name(&self, name: &str) -> Result<()> {
704 match self {
705 DatabaseType::Postgres(pg) => pg.set_name(name).await,
706 #[cfg(any(test, feature = "sqlite"))]
707 DatabaseType::SQLite(sl) => sl.set_name(name),
708 }
709 }
710
711 #[cfg(any(test, feature = "admin"))]
717 pub async fn enable_compression(&self) -> Result<()> {
718 match self {
719 DatabaseType::Postgres(pg) => pg.enable_compression().await,
720 #[cfg(any(test, feature = "sqlite"))]
721 DatabaseType::SQLite(sl) => sl.enable_compression(),
722 }
723 }
724
725 #[cfg(any(test, feature = "admin"))]
731 pub async fn disable_compression(&self) -> Result<()> {
732 match self {
733 DatabaseType::Postgres(pg) => pg.disable_compression().await,
734 #[cfg(any(test, feature = "sqlite"))]
735 DatabaseType::SQLite(sl) => sl.disable_compression(),
736 }
737 }
738
739 #[cfg(any(test, feature = "admin"))]
745 pub async fn enable_keep_unknown_files(&self) -> Result<()> {
746 match self {
747 DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
748 #[cfg(any(test, feature = "sqlite"))]
749 DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
750 }
751 }
752
753 #[cfg(any(test, feature = "admin"))]
759 pub async fn disable_keep_unknown_files(&self) -> Result<()> {
760 match self {
761 DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
762 #[cfg(any(test, feature = "sqlite"))]
763 DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
764 }
765 }
766
767 #[cfg(any(test, feature = "admin"))]
773 pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
774 match self {
775 DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
776 #[cfg(any(test, feature = "sqlite"))]
777 DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
778 }
779 }
780
781 #[cfg(any(test, feature = "admin"))]
787 pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
788 match self {
789 DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
790 #[cfg(any(test, feature = "sqlite"))]
791 DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
792 }
793 }
794
795 #[allow(clippy::too_many_arguments)]
801 #[cfg(any(test, feature = "admin"))]
802 pub async fn create_user(
803 &self,
804 uname: &str,
805 fname: &str,
806 lname: &str,
807 email: &str,
808 password: Option<String>,
809 organisation: Option<&String>,
810 readonly: bool,
811 ) -> Result<u32> {
812 match self {
813 DatabaseType::Postgres(pg) => {
814 pg.create_user(uname, fname, lname, email, password, organisation, readonly)
815 .await
816 }
817 #[cfg(any(test, feature = "sqlite"))]
818 DatabaseType::SQLite(sl) => {
819 sl.create_user(uname, fname, lname, email, password, organisation, readonly)
820 }
821 }
822 }
823
824 #[cfg(any(test, feature = "admin"))]
830 pub async fn reset_api_keys(&self) -> Result<u64> {
831 match self {
832 DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
833 #[cfg(any(test, feature = "sqlite"))]
834 DatabaseType::SQLite(sl) => sl.reset_api_keys(),
835 }
836 }
837
838 #[cfg(any(test, feature = "admin"))]
845 pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
846 match self {
847 DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
848 #[cfg(any(test, feature = "sqlite"))]
849 DatabaseType::SQLite(sl) => sl.set_password(uname, password),
850 }
851 }
852
853 #[cfg(any(test, feature = "admin"))]
859 pub async fn list_users(&self) -> Result<Vec<admin::User>> {
860 match self {
861 DatabaseType::Postgres(pg) => pg.list_users().await,
862 #[cfg(any(test, feature = "sqlite"))]
863 DatabaseType::SQLite(sl) => sl.list_users(),
864 }
865 }
866
867 #[cfg(any(test, feature = "admin"))]
874 pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
875 match self {
876 DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
877 #[cfg(any(test, feature = "sqlite"))]
878 DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
879 }
880 }
881
882 #[cfg(any(test, feature = "admin"))]
889 pub async fn edit_group(
890 &self,
891 gid: u32,
892 name: &str,
893 desc: &str,
894 parent: Option<u32>,
895 ) -> Result<()> {
896 match self {
897 DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
898 #[cfg(any(test, feature = "sqlite"))]
899 DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
900 }
901 }
902
903 #[cfg(any(test, feature = "admin"))]
909 pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
910 match self {
911 DatabaseType::Postgres(pg) => pg.list_groups().await,
912 #[cfg(any(test, feature = "sqlite"))]
913 DatabaseType::SQLite(sl) => sl.list_groups(),
914 }
915 }
916
917 #[cfg(any(test, feature = "admin"))]
924 pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
925 match self {
926 DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
927 #[cfg(any(test, feature = "sqlite"))]
928 DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
929 }
930 }
931
932 #[cfg(any(test, feature = "admin"))]
939 pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
940 match self {
941 DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
942 #[cfg(any(test, feature = "sqlite"))]
943 DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
944 }
945 }
946
947 #[cfg(any(test, feature = "admin"))]
954 pub async fn create_group(
955 &self,
956 name: &str,
957 description: &str,
958 parent: Option<u32>,
959 ) -> Result<u32> {
960 match self {
961 DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
962 #[cfg(any(test, feature = "sqlite"))]
963 DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
964 }
965 }
966
967 #[cfg(any(test, feature = "admin"))]
973 pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
974 match self {
975 DatabaseType::Postgres(pg) => pg.list_sources().await,
976 #[cfg(any(test, feature = "sqlite"))]
977 DatabaseType::SQLite(sl) => sl.list_sources(),
978 }
979 }
980
981 #[cfg(any(test, feature = "admin"))]
988 pub async fn create_source(
989 &self,
990 name: &str,
991 description: Option<&str>,
992 url: Option<&str>,
993 date: chrono::DateTime<Local>,
994 releasable: bool,
995 malicious: Option<bool>,
996 ) -> Result<u32> {
997 match self {
998 DatabaseType::Postgres(pg) => {
999 pg.create_source(name, description, url, date, releasable, malicious)
1000 .await
1001 }
1002 #[cfg(any(test, feature = "sqlite"))]
1003 DatabaseType::SQLite(sl) => {
1004 sl.create_source(name, description, url, date, releasable, malicious)
1005 }
1006 }
1007 }
1008
1009 #[cfg(any(test, feature = "admin"))]
1016 pub async fn edit_user(
1017 &self,
1018 uid: u32,
1019 uname: &str,
1020 fname: &str,
1021 lname: &str,
1022 email: &str,
1023 readonly: bool,
1024 ) -> Result<()> {
1025 match self {
1026 DatabaseType::Postgres(pg) => {
1027 pg.edit_user(uid, uname, fname, lname, email, readonly)
1028 .await
1029 }
1030 #[cfg(any(test, feature = "sqlite"))]
1031 DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
1032 }
1033 }
1034
1035 #[cfg(any(test, feature = "admin"))]
1042 pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
1043 match self {
1044 DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
1045 #[cfg(any(test, feature = "sqlite"))]
1046 DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
1047 }
1048 }
1049
1050 #[cfg(any(test, feature = "admin"))]
1056 pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
1057 match self {
1058 DatabaseType::Postgres(pg) => pg.file_types_counts().await,
1059 #[cfg(any(test, feature = "sqlite"))]
1060 DatabaseType::SQLite(sl) => sl.file_types_counts(),
1061 }
1062 }
1063
1064 #[cfg(any(test, feature = "admin"))]
1072 pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
1073 match self {
1074 DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
1075 #[cfg(any(test, feature = "sqlite"))]
1076 DatabaseType::SQLite(sl) => sl.create_label(name, parent),
1077 }
1078 }
1079
1080 #[cfg(any(test, feature = "admin"))]
1088 pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
1089 match self {
1090 DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
1091 #[cfg(any(test, feature = "sqlite"))]
1092 DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
1093 }
1094 }
1095
1096 #[cfg(any(test, feature = "admin"))]
1103 pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
1104 match self {
1105 DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
1106 #[cfg(any(test, feature = "sqlite"))]
1107 DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
1108 }
1109 }
1110
1111 #[cfg(any(test, feature = "admin"))]
1119 pub async fn label_file(&self, file_id: u64, label_id: u64) -> Result<()> {
1120 match self {
1121 DatabaseType::Postgres(pg) => pg.label_file(file_id, label_id).await,
1122 #[cfg(any(test, feature = "sqlite"))]
1123 DatabaseType::SQLite(sl) => sl.label_file(file_id, label_id),
1124 }
1125 }
1126}
1127
1128pub fn hash_password(password: &str) -> Result<String> {
1134 let salt = SaltString::generate(&mut OsRng);
1135 let argon2 = Argon2::default();
1136 Ok(argon2
1137 .hash_password(password.as_bytes(), &salt)?
1138 .to_string())
1139}
1140
1141#[must_use]
1143pub fn random_bytes_api_key() -> String {
1144 let key1 = uuid::Uuid::new_v4();
1145 let key2 = uuid::Uuid::new_v4();
1146 let key1 = key1.to_string().replace('-', "");
1147 let key2 = key2.to_string().replace('-', "");
1148 format!("{key1}{key2}")
1149}
1150
1151#[cfg(test)]
1152mod tests {
1153 use super::*;
1154 #[cfg(feature = "vt")]
1155 use crate::vt::VtUpdater;
1156
1157 use std::fs;
1158 #[cfg(feature = "vt")]
1159 use std::sync::Arc;
1160 #[cfg(feature = "vt")]
1161 use std::time::SystemTime;
1162
1163 use anyhow::Context;
1164 use fuzzyhash::FuzzyHash;
1165 use malwaredb_api::{PartialHashSearchType, SearchRequestParameters, SearchType};
1166 use malwaredb_lzjd::{LZDict, Murmur3HashState};
1167 use tlsh_fixed::TlshBuilder;
1168 use uuid::Uuid;
1169
// Label names for the tests. NOTE(review): their use sites are beyond the
// visible part of this file.
const MALWARE_LABEL: &str = "malware";
const RANSOMWARE_LABEL: &str = "ransomware";
1172
1173 fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
1174 let mut hashes = vec![];
1175
1176 hashes.push((
1177 malwaredb_api::SimilarityHashType::SSDeep,
1178 FuzzyHash::new(data).to_string(),
1179 ));
1180
1181 let mut builder = TlshBuilder::new(
1182 tlsh_fixed::BucketKind::Bucket256,
1183 tlsh_fixed::ChecksumKind::ThreeByte,
1184 tlsh_fixed::Version::Version4,
1185 );
1186
1187 builder.update(data);
1188
1189 if let Ok(hasher) = builder.build() {
1190 hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
1191 }
1192
1193 let build_hasher = Murmur3HashState::default();
1194 let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
1195 hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
1196
1197 malwaredb_api::SimilarSamplesRequest { hashes }
1198 }
1199
1200 async fn pg_config() -> Postgres {
1201 const CONNECTION_STRING: &str =
1204 "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
1205
1206 if let Ok(pg_port) = std::env::var("PG_PORT") {
1207 let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
1209 Postgres::new(&conn_string, None)
1210 .await
1211 .context(format!(
1212 "failed to connect to postgres with specified port {pg_port}"
1213 ))
1214 .unwrap()
1215 } else {
1216 Postgres::new(CONNECTION_STRING, None).await.unwrap()
1217 }
1218 }
1219
1220 #[tokio::test]
1221 #[ignore = "don't run this in CI"]
1222 async fn pg() {
1223 let psql = pg_config().await;
1224 psql.delete().await.unwrap();
1225
1226 let psql = pg_config().await;
1227 let db = DatabaseType::Postgres(psql);
1228 everything(&db).await.unwrap();
1229
1230 #[cfg(feature = "vt")]
1231 {
1232 let db_config = db.get_config().await.unwrap();
1233 let state = crate::State {
1234 port: 8080,
1235 directory: None,
1236 max_upload: 10 * 1024 * 1024,
1237 ip: "127.0.0.1".parse().unwrap(),
1238 db_type: Arc::new(db),
1239 db_config,
1240 keys: HashMap::new(),
1241 started: SystemTime::now(),
1242 vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1243 Some(malwaredb_virustotal::VirusTotalClient::new(e))
1244 }),
1245 tls_config: None,
1246 mdns: None,
1247 };
1248
1249 let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1250
1251 vt.updater().await.unwrap();
1252 println!("PG: Did VT ops!");
1253
1254 let psql = pg_config().await;
1255
1256 let vt_stats = psql
1257 .get_vt_stats()
1258 .await
1259 .context("failed to get Postgres VT Stats")
1260 .unwrap();
1261 println!("{vt_stats:?}");
1262 assert!(
1263 vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1264 );
1265 }
1266
1267 let psql = pg_config().await;
1269 psql.delete().await.unwrap();
1270 }
1271
1272 #[tokio::test]
1273 async fn sqlite() {
1274 const DB_FILE: &str = "testing_sqlite.db";
1275 if std::path::Path::new(DB_FILE).exists() {
1276 fs::remove_file(DB_FILE)
1277 .context(format!("failed to delete old SQLite file {DB_FILE}"))
1278 .unwrap();
1279 }
1280
1281 let sqlite = Sqlite::new(DB_FILE)
1282 .context(format!("failed to create SQLite instance for {DB_FILE}"))
1283 .unwrap();
1284
1285 let db = DatabaseType::SQLite(sqlite);
1286 everything(&db).await.unwrap();
1287
1288 #[cfg(feature = "vt")]
1289 {
1290 let db_config = db.get_config().await.unwrap();
1291 let state = crate::State {
1292 port: 8080,
1293 directory: None,
1294 max_upload: 10 * 1024 * 1024,
1295 ip: "127.0.0.1".parse().unwrap(),
1296 db_type: Arc::new(db),
1297 db_config,
1298 keys: HashMap::new(),
1299 started: SystemTime::now(),
1300 vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1301 Some(malwaredb_virustotal::VirusTotalClient::new(e))
1302 }),
1303 tls_config: None,
1304 mdns: None,
1305 };
1306
1307 let sqlite_second = Sqlite::new(DB_FILE)
1308 .context(format!("failed to create SQLite instance for {DB_FILE}"))
1309 .unwrap();
1310
1311 let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1312
1313 vt.updater().await.unwrap();
1314 println!("Sqlite: Did VT ops!");
1315 let vt_stats = sqlite_second
1316 .get_vt_stats()
1317 .context("failed to get Sqlite VT Stats")
1318 .unwrap();
1319 println!("{vt_stats:?}");
1320 assert!(
1321 vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1322 );
1323 }
1324
1325 fs::remove_file(DB_FILE)
1326 .context(format!("failed to delete SQLite file {DB_FILE}"))
1327 .unwrap();
1328 }
1329
1330 #[allow(clippy::too_many_lines)]
1331 async fn everything(db: &DatabaseType) -> Result<()> {
1332 const ADMIN_UNAME: &str = "admin";
1333 const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
1334
1335 db.set_name("Testing Database")
1336 .await
1337 .context("setting instance name failed")?;
1338
1339 assert!(
1340 db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
1341 "Authentication without password should have failed."
1342 );
1343
1344 db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1345 .await
1346 .context("failed to set admin password")?;
1347
1348 let admin_api_key = db
1349 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1350 .await
1351 .context("unable to get api key for admin")?;
1352 println!("API key: {admin_api_key}");
1353 assert_eq!(admin_api_key.len(), 64);
1354
1355 assert_eq!(
1356 db.get_uid(&admin_api_key).await?,
1357 0,
1358 "Unable to get UID given the API key"
1359 );
1360
1361 let admin_api_key_again = db
1362 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1363 .await
1364 .context("unable to get api key a second time for admin")?;
1365
1366 assert_eq!(
1367 admin_api_key, admin_api_key_again,
1368 "API keys didn't match the second time."
1369 );
1370
1371 let bad_password = "this_is_totally_not_my_password!!";
1372 eprintln!("Testing API login with incorrect password.");
1373 assert!(
1374 db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
1375 "Authenticating as admin with a bad password should have failed."
1376 );
1377
1378 let admin_is_admin = db
1379 .user_is_admin(0)
1380 .await
1381 .context("unable to see if admin (uid 0) is an admin")?;
1382 assert!(admin_is_admin);
1383
1384 let new_user_uname = "testuser";
1385 let new_user_email = "test@example.com";
1386 let new_user_password = "some_awesome_password_++";
1387 let new_id = db
1388 .create_user(
1389 new_user_uname,
1390 new_user_uname,
1391 new_user_uname,
1392 new_user_email,
1393 Some(new_user_password.into()),
1394 None,
1395 false,
1396 )
1397 .await
1398 .context(format!("failed to create user {new_user_uname}"))?;
1399
1400 let passwordless_user_id = db
1401 .create_user(
1402 "passwordless_user",
1403 "passwordless_user",
1404 "passwordless_user",
1405 "passwordless_user@example.com",
1406 None,
1407 None,
1408 false,
1409 )
1410 .await
1411 .context("failed to create passwordless_user")?;
1412
1413 for user in &db.list_users().await.context("failed to list users")? {
1414 if user.id == passwordless_user_id {
1415 assert_eq!(user.uname, "passwordless_user");
1416 }
1417 }
1418
1419 db.edit_user(
1420 passwordless_user_id,
1421 "passwordless_user_2",
1422 "passwordless_user_2",
1423 "passwordless_user_2",
1424 "passwordless_user_2@something.com",
1425 false,
1426 )
1427 .await
1428 .context(format!(
1429 "failed to alter 'passwordless' user, id {passwordless_user_id}"
1430 ))?;
1431
1432 for user in &db.list_users().await.context("failed to list users")? {
1433 if user.id == passwordless_user_id {
1434 assert_eq!(user.uname, "passwordless_user_2");
1435 }
1436 }
1437
1438 assert!(
1439 new_id > 0,
1440 "Weird UID created for user {new_user_uname}: {new_id}"
1441 );
1442
1443 assert!(
1444 db.create_user(
1445 new_user_uname,
1446 new_user_uname,
1447 new_user_uname,
1448 new_user_email,
1449 Some(new_user_password.into()),
1450 None,
1451 false
1452 )
1453 .await
1454 .is_err(),
1455 "Creating a new user with the same user name should fail"
1456 );
1457
1458 let ro_user_name = "ro_user";
1459 let ro_user_password = "ro_user_password";
1460 db.create_user(
1461 ro_user_name,
1462 "ro_user",
1463 "ro_user",
1464 "ro@example.com",
1465 Some(ro_user_password.into()),
1466 None,
1467 true,
1468 )
1469 .await
1470 .context("failed to create read-only user")?;
1471
1472 let ro_user_api_key = db
1473 .authenticate(ro_user_name, ro_user_password)
1474 .await
1475 .context("unable to get api key for read-only user")?;
1476
1477 let new_user_password_change = "some_new_awesomer_password!_++";
1478 db.set_password(new_user_uname, new_user_password_change)
1479 .await
1480 .context("failed to change the password for testuser")?;
1481
1482 let new_user_api_key = db
1483 .authenticate(new_user_uname, new_user_password_change)
1484 .await
1485 .context("unable to get api key for testuser")?;
1486 eprintln!("{new_user_uname} got API key {new_user_api_key}");
1487
1488 assert_eq!(admin_api_key.len(), new_user_api_key.len());
1489
1490 let users = db.list_users().await.context("failed to list users")?;
1491 assert_eq!(
1492 users.len(),
1493 4,
1494 "Four users were created, yet there are {} users",
1495 users.len()
1496 );
1497 eprintln!("DB has {} users:", users.len());
1498 let mut passwordless_user_found = false;
1499 for user in users {
1500 println!("{user}");
1501 if user.uname == "passwordless_user_2" {
1502 assert!(!user.has_api_key);
1503 assert!(!user.has_password);
1504 passwordless_user_found = true;
1505 } else {
1506 assert!(user.has_api_key);
1507 assert!(user.has_password);
1508 }
1509 }
1510 assert!(passwordless_user_found);
1511
1512 let new_group_name = "some_new_group";
1513 let new_group_desc = "some_new_group_description";
1514 let new_group_id = 1;
1515 assert_eq!(
1516 db.create_group(new_group_name, new_group_desc, None)
1517 .await
1518 .context("failed to create group")?,
1519 new_group_id,
1520 "New group didn't have the expected ID, expected {new_group_id}"
1521 );
1522
1523 assert!(
1524 db.create_group(new_group_name, new_group_desc, None)
1525 .await
1526 .is_err(),
1527 "Duplicate group name should have failed"
1528 );
1529
1530 db.add_user_to_group(1, 1)
1531 .await
1532 .context("Unable to add uid 1 to gid 1")?;
1533
1534 let ro_user_uid = db
1535 .get_uid(&ro_user_api_key)
1536 .await
1537 .context("Unable to get UID for read-only user")?;
1538 db.add_user_to_group(ro_user_uid, 1)
1539 .await
1540 .context("Unable to add uid 2 to gid 1")?;
1541
1542 let new_admin_group_name = "admin_subgroup";
1543 let new_admin_group_desc = "admin_subgroup_description";
1544 let new_admin_group_id = 2;
1545 assert!(
1547 db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
1548 .await
1549 .context("failed to create admin sub-group")?
1550 >= new_admin_group_id,
1551 "New group didn't have the expected ID, expected >= {new_admin_group_id}"
1552 );
1553
1554 let groups = db.list_groups().await.context("failed to list groups")?;
1555 assert_eq!(
1556 groups.len(),
1557 3,
1558 "Three groups were created, yet there are {} groups",
1559 groups.len()
1560 );
1561 eprintln!("DB has {} groups:", groups.len());
1562 for group in groups {
1563 println!("{group}");
1564 if group.id == new_admin_group_id {
1565 assert_eq!(group.parent, Some("admin".to_string()));
1566 }
1567 if group.id == 1 {
1568 let test_user_str = String::from(new_user_uname);
1569 let mut found = false;
1570 for member in group.members {
1571 if member.uname == test_user_str {
1572 found = true;
1573 break;
1574 }
1575 }
1576 assert!(found, "new user {test_user_str} wasn't in the group");
1577 }
1578 }
1579
1580 let default_source_name = "default_source".to_string();
1581 let default_source_id = db
1582 .create_source(
1583 &default_source_name,
1584 Some("desc_default_source"),
1585 None,
1586 Local::now(),
1587 true,
1588 Some(false),
1589 )
1590 .await
1591 .context("failed to create source `default_source`")?;
1592
1593 db.add_group_to_source(1, default_source_id)
1594 .await
1595 .context("failed to add group 1 to source 1")?;
1596
1597 let another_source_name = "another_source".to_string();
1598 let another_source_id = db
1599 .create_source(
1600 &another_source_name,
1601 Some("yet another file source"),
1602 None,
1603 Local::now(),
1604 true,
1605 Some(false),
1606 )
1607 .await
1608 .context("failed to create source `another_source`")?;
1609
1610 let empty_source_name = "empty_source".to_string();
1611 db.create_source(
1612 &empty_source_name,
1613 Some("empty and unused file source"),
1614 None,
1615 Local::now(),
1616 true,
1617 Some(false),
1618 )
1619 .await
1620 .context("failed to create source `another_source`")?;
1621
1622 db.add_group_to_source(1, another_source_id)
1623 .await
1624 .context("failed to add group 1 to source 1")?;
1625
1626 let sources = db.list_sources().await.context("failed to list sources")?;
1627 eprintln!("DB has {} sources:", sources.len());
1628 for source in sources {
1629 println!("{source}");
1630 assert_eq!(source.files, 0);
1631 if source.id == default_source_id || source.id == another_source_id {
1632 assert_eq!(
1633 source.groups, 1,
1634 "default source {default_source_name} should have 1 group"
1635 );
1636 } else {
1637 assert_eq!(source.groups, 0, "groups should zero (empty)");
1638 }
1639 }
1640
1641 let uid = db
1642 .get_uid(&new_user_api_key)
1643 .await
1644 .context("failed to user uid from apikey")?;
1645 let user_info = db
1646 .get_user_info(uid)
1647 .await
1648 .context("failed to get user's available groups and sources")?;
1649 assert!(user_info.sources.contains(&default_source_name));
1650 assert!(!user_info.is_admin);
1651 println!("UserInfoResponse: {user_info:?}");
1652
1653 assert!(
1654 db.allowed_user_source(1, default_source_id)
1655 .await
1656 .context(format!(
1657 "failed to check that user 1 has access to source {default_source_id}"
1658 ))?,
1659 "User 1 should should have had access to source {default_source_id}"
1660 );
1661
1662 assert!(
1663 !db.allowed_user_source(1, 5)
1664 .await
1665 .context("failed to check that user 1 has access to source 5")?,
1666 "User 1 should should not have had access to source 5"
1667 );
1668
1669 let test_label_id = db
1670 .create_label("TestLabel", None)
1671 .await
1672 .context("failed to create test label")?;
1673 let test_elf_label_id = db
1674 .create_label("TestELF", Some(test_label_id))
1675 .await
1676 .context("failed to create test label")?;
1677
1678 let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
1679 let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
1680 let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
1681
1682 let known_type =
1683 KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
1684 assert!(known_type.is_exec(), "ELF should be executable");
1685 eprintln!("ELF type ID: {elf_type}");
1686
1687 let file_addition = db
1688 .add_file(
1689 &test_elf_meta,
1690 known_type.clone(),
1691 1,
1692 default_source_id,
1693 elf_type,
1694 None,
1695 )
1696 .await
1697 .context("failed to insert a test elf")?;
1698 assert!(file_addition.is_new, "File should have been added");
1699 eprintln!("Added ELF to the DB");
1700 db.label_file(file_addition.file_id, test_elf_label_id)
1701 .await
1702 .context("failed to label file")?;
1703
1704 let partial_search = SearchRequest {
1706 search: SearchType::Search(SearchRequestParameters {
1707 partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1708 labels: Some(vec![String::from("TestELF")]),
1709 file_type: Some(String::from("ELF")),
1710 magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1711 ..Default::default()
1712 }),
1713 };
1714 assert!(partial_search.is_valid());
1715 let partial_search_response = db.partial_search(1, partial_search).await?;
1716 assert_eq!(partial_search_response.hashes.len(), 1);
1717 assert_eq!(
1718 partial_search_response.hashes[0],
1719 "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1720 );
1721
1722 let partial_search = SearchRequest {
1724 search: SearchType::Search(SearchRequestParameters {
1725 partial_hash: None,
1726 labels: None,
1727 file_type: None,
1728 magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1729 ..Default::default()
1730 }),
1731 };
1732 assert!(partial_search.is_valid());
1733 let partial_search_response = db.partial_search(1, partial_search).await?;
1734 assert_eq!(partial_search_response.hashes.len(), 1);
1735 assert_eq!(
1736 partial_search_response.hashes[0],
1737 "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1738 );
1739
1740 let partial_search = SearchRequest {
1742 search: SearchType::Search(SearchRequestParameters {
1743 partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1744 file_type: Some(String::from("PE32")),
1745 ..Default::default()
1746 }),
1747 };
1748 assert!(partial_search.is_valid());
1749 let partial_search_response = db.partial_search(1, partial_search).await?;
1750 assert_eq!(partial_search_response.hashes.len(), 0);
1751
1752 let partial_search = SearchRequest {
1753 search: SearchType::Search(SearchRequestParameters {
1754 file_name: Some("ppc64".into()),
1755 ..Default::default()
1756 }),
1757 };
1758 assert!(partial_search.is_valid());
1759 let partial_search_response = db.partial_search(1, partial_search).await?;
1760 assert_eq!(partial_search_response.hashes.len(), 1);
1761
1762 let partial_search = SearchRequest {
1764 search: SearchType::Search(SearchRequestParameters::default()),
1765 };
1766 assert!(!partial_search.is_valid());
1767 let partial_search_response = db.partial_search(1, partial_search).await?;
1768 assert!(partial_search_response.hashes.is_empty());
1769
1770 let partial_search = SearchRequest {
1772 search: SearchType::Continuation(Uuid::default()),
1773 };
1774 assert!(partial_search.is_valid());
1775 let partial_search_response = db.partial_search(1, partial_search).await?;
1776 assert!(partial_search_response.hashes.is_empty());
1777
1778 assert!(db
1780 .get_type_id_for_bytes(include_bytes!("../../../../MDB_Logo.ico"))
1781 .await
1782 .is_err());
1783
1784 assert!(
1785 db.add_file(
1786 &test_elf_meta,
1787 known_type.clone(),
1788 ro_user_uid,
1789 default_source_id,
1790 elf_type,
1791 None
1792 )
1793 .await
1794 .is_err(),
1795 "Read-only user should not be able to add a file"
1796 );
1797
1798 let mut test_elf_meta_different_name = test_elf_meta.clone();
1799 test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
1800
1801 assert!(
1802 !db.add_file(
1803 &test_elf_meta_different_name,
1804 known_type,
1805 1,
1806 another_source_id,
1807 elf_type,
1808 None
1809 )
1810 .await
1811 .context("failed to insert a test elf again for a different source")?
1812 .is_new
1813 );
1814
1815 let sources = db
1816 .list_sources()
1817 .await
1818 .context("failed to re-list sources")?;
1819 eprintln!(
1820 "DB has {} sources, and a file was added twice:",
1821 sources.len()
1822 );
1823 println!("We should have two sources with one file each, yet only one ELF.");
1824 for source in sources {
1825 println!("{source}");
1826 if source.id == default_source_id || source.id == another_source_id {
1827 assert_eq!(source.files, 1);
1828 } else {
1829 assert_eq!(source.files, 0, "groups should zero (empty)");
1830 }
1831 }
1832
1833 assert!(!db
1834 .get_user_sources(1)
1835 .await
1836 .expect("failed to get user 1's sources")
1837 .sources
1838 .is_empty());
1839
1840 let file_types_counts = db
1841 .file_types_counts()
1842 .await
1843 .context("failed to get file types and counts")?;
1844 for (name, count) in file_types_counts {
1845 println!("{name}: {count}");
1846 assert_eq!(name, "ELF");
1847 assert_eq!(count, 1);
1848 }
1849
1850 let mut test_elf_modified = test_elf.clone();
1851 let random_bytes = Uuid::new_v4();
1852 let mut random_bytes = random_bytes.into_bytes().to_vec();
1853 test_elf_modified.append(&mut random_bytes);
1854 let similarity_request = generate_similarity_request(&test_elf_modified);
1855 let similarity_response = db
1856 .find_similar_samples(1, &similarity_request.hashes)
1857 .await
1858 .context("failed to get similarity response")?;
1859 eprintln!("Similarity response: {similarity_response:?}");
1860 let similarity_response = similarity_response.first().unwrap();
1861 assert_eq!(
1862 similarity_response.sha256,
1863 hex::encode(&test_elf_meta.sha256),
1864 "Similarity response should have had the hash of the original ELF"
1865 );
1866 for (algo, sim) in &similarity_response.algorithms {
1867 match algo {
1868 malwaredb_api::SimilarityHashType::LZJD => {
1869 assert!(*sim > 0.0f32);
1870 }
1871 malwaredb_api::SimilarityHashType::SSDeep => {
1872 assert!(*sim > 80.0f32);
1873 }
1874 malwaredb_api::SimilarityHashType::TLSH => {
1875 assert!(*sim <= 20f32);
1876 }
1877 _ => {}
1878 }
1879 }
1880
1881 let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1.as_slice())
1882 .context("failed to get `HashType::SHA1` from string")?;
1883 let response_sha256 = db
1884 .retrieve_sample(1, &test_elf_hashtype)
1885 .await
1886 .context("could not get SHA-256 hash from test sample")
1887 .unwrap();
1888 assert_eq!(response_sha256, hex::encode(&test_elf_meta.sha256));
1889
1890 let test_bogus_hash =
1891 HashType::try_from("d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0")
1892 .context("failed to get `HashType` from static string")?;
1893 assert!(
1894 db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
1895 "Getting a file with a bogus hash should have failed."
1896 );
1897
1898 let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
1899 let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
1900 let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
1901
1902 let known_type =
1903 KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
1904
1905 assert!(
1906 db.add_file(
1907 &test_pdf_meta,
1908 known_type,
1909 1,
1910 default_source_id,
1911 pdf_type,
1912 None
1913 )
1914 .await
1915 .context("failed to insert a test pdf")?
1916 .is_new
1917 );
1918 eprintln!("Added PDF to the DB");
1919
1920 let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
1921 let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
1922 let rtf_type = db
1923 .get_type_id_for_bytes(&test_rtf)
1924 .await
1925 .context("failed to get file type id for rtf")?;
1926
1927 let known_type =
1928 KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
1929
1930 assert!(
1931 db.add_file(
1932 &test_rtf_meta,
1933 known_type,
1934 1,
1935 default_source_id,
1936 rtf_type,
1937 None
1938 )
1939 .await
1940 .context("failed to insert a test rtf")?
1941 .is_new
1942 );
1943 eprintln!("Added RTF to the DB");
1944
1945 let report = db
1946 .get_sample_report(
1947 1,
1948 &HashType::try_from(test_rtf_meta.sha256.as_slice()).unwrap(),
1949 )
1950 .await
1951 .context("failed to get report for test rtf")?;
1952 assert!(report
1953 .clone()
1954 .filecommand
1955 .unwrap()
1956 .contains("Rich Text Format"));
1957 println!("Report: {report}");
1958
1959 assert!(db
1960 .get_sample_report(
1961 999,
1962 &HashType::try_from(test_rtf_meta.sha256.as_slice()).unwrap()
1963 )
1964 .await
1965 .is_err());
1966
1967 #[cfg(feature = "vt")]
1968 {
1969 assert!(report.vt.is_some());
1970 let files_needing_vt = db
1971 .files_without_vt_records(10)
1972 .await
1973 .context("failed to get files without VT records")?;
1974 assert!(files_needing_vt.len() > 2);
1975 println!(
1976 "{} files needing VT data: {files_needing_vt:?}",
1977 files_needing_vt.len()
1978 );
1979 }
1980
1981 #[cfg(not(feature = "vt"))]
1982 {
1983 assert!(report.vt.is_none());
1984 }
1985
1986 let reset = db
1987 .reset_api_keys()
1988 .await
1989 .context("failed to reset all API keys")?;
1990 eprintln!("Cleared {reset} api keys.");
1991
1992 let db_info = db.db_info().await.context("failed to get database info")?;
1993 eprintln!("DB Info: {db_info:?}");
1994
1995 let data_types = db
1996 .get_known_data_types()
1997 .await
1998 .context("failed to get data types")?;
1999 for data_type in data_types {
2000 println!("{data_type:?}");
2001 }
2002
2003 let sources = db
2004 .list_sources()
2005 .await
2006 .context("failed to list sources second time")?;
2007 eprintln!("DB has {} sources:", sources.len());
2008 for source in sources {
2009 println!("{source}");
2010 }
2011
2012 let file_types_counts = db
2013 .file_types_counts()
2014 .await
2015 .context("failed to get file types and counts")?;
2016 for (name, count) in file_types_counts {
2017 println!("{name}: {count}");
2018 assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
2019 }
2020
2021 let fatmacho =
2022 include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
2023 .to_vec();
2024 let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
2025 let fatmacho_type = db
2026 .get_type_id_for_bytes(&fatmacho)
2027 .await
2028 .context("failed to get file type for Fat Mach-O")?;
2029 let known_type = KnownType::new(&fatmacho)
2030 .context("failed to parse Fat Mach-O from type crate's test data")?;
2031
2032 assert!(
2033 db.add_file(
2034 &fatmacho_meta,
2035 known_type,
2036 1,
2037 default_source_id,
2038 fatmacho_type,
2039 None
2040 )
2041 .await
2042 .context("failed to insert a test Fat Mach-O")?
2043 .is_new
2044 );
2045 eprintln!("Added Fat Mach-O to the DB");
2046
2047 let file_types_counts = db
2048 .file_types_counts()
2049 .await
2050 .context("failed to get file types and counts")?;
2051 for (name, count) in &file_types_counts {
2052 println!("{name}: {count}");
2053 }
2054
2055 assert_eq!(
2056 *file_types_counts.get("Mach-O").unwrap(),
2057 4,
2058 "Expected 4 Mach-O files, got {:?}",
2059 file_types_counts.get("Mach-O")
2060 );
2061
2062 let allowed_files = db
2063 .user_allowed_files_by_sha256(1, None)
2064 .await
2065 .context("failed to get allowed files")?;
2066 assert_eq!(allowed_files.0.len(), 8);
2067
2068 let allowed_files = db
2069 .user_allowed_files_by_sha256(1, Some(allowed_files.1))
2070 .await
2071 .context("failed to get allowed files")?;
2072 assert!(allowed_files.0.is_empty());
2073
2074 let malware_label_id = db
2075 .create_label(MALWARE_LABEL, None)
2076 .await
2077 .context("failed to create first label")?;
2078 let ransomware_label_id = db
2079 .create_label(RANSOMWARE_LABEL, Some(malware_label_id))
2080 .await
2081 .context("failed to create malware sub-label")?;
2082 let labels = db.get_labels().await.context("failed to get labels")?;
2083
2084 assert_eq!(labels.len(), 4, "Expected 4 labels, got {labels}");
2085 for label in labels.0 {
2086 if label.name == RANSOMWARE_LABEL {
2087 assert_eq!(label.id, ransomware_label_id);
2088 assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
2089 }
2090 }
2091
2092 let source_code = include_bytes!("mod.rs");
2094 let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
2095 let known_type =
2096 KnownType::new(source_code).context("failed to source code to get `Unknown` type")?;
2097
2098 assert!(matches!(known_type, KnownType::Unknown(_)));
2099
2100 let unknown_type: Vec<FileType> = db
2101 .get_known_data_types()
2102 .await?
2103 .into_iter()
2104 .filter(|t| t.name.eq_ignore_ascii_case("unknown"))
2105 .collect();
2106 let unknown_type_id = unknown_type.first().unwrap().id;
2107 assert!(db.get_type_id_for_bytes(source_code).await.is_err());
2108 db.enable_keep_unknown_files()
2109 .await
2110 .context("failed to enable keeping of unknown files")?;
2111 let source_type = db
2112 .get_type_id_for_bytes(source_code)
2113 .await
2114 .context("failed to type id for source code unknown type example")?;
2115 assert_eq!(source_type, unknown_type_id);
2116 eprintln!("Unknown file type ID: {source_type}");
2117 assert!(
2118 db.add_file(
2119 &source_meta,
2120 known_type,
2121 1,
2122 default_source_id,
2123 unknown_type_id,
2124 None
2125 )
2126 .await
2127 .context("failed to add Rust source code file")?
2128 .is_new
2129 );
2130 eprintln!("Added Rust source code to the DB");
2131
2132 #[cfg(feature = "yara")]
2133 assert!(db.get_unfinished_yara_tasks().await?.is_empty());
2134
2135 db.reset_own_api_key(0)
2136 .await
2137 .context("failed to clear own API key uid 0")?;
2138
2139 db.deactivate_user(0)
2140 .await
2141 .context("failed to clear password and API key for uid 0")?;
2142
2143 Ok(())
2144 }
2145}