1#[cfg(any(test, feature = "admin"))]
10mod admin;
11mod pg;
13
14#[cfg(any(test, feature = "sqlite"))]
16mod sqlite;
17
18#[cfg(any(test, feature = "sqlite"))]
20mod sqlite_functions;
21
22pub mod types;
24
25#[cfg(any(test, feature = "admin"))]
26use crate::crypto::EncryptionOption;
27use crate::crypto::FileEncryption;
28use crate::db::pg::Postgres;
29#[cfg(any(test, feature = "sqlite"))]
30use crate::db::sqlite::Sqlite;
31use crate::db::types::{FileMetadata, FileType};
32use malwaredb_api::{
33 digest::HashType, GetUserInfoResponse, Labels, SearchRequest, SearchResponse, Sources,
34};
35use malwaredb_types::KnownType;
36
37use std::collections::HashMap;
38use std::path::PathBuf;
39
40use anyhow::{bail, ensure, Result};
41use argon2::password_hash::{rand_core::OsRng, SaltString};
42use argon2::{Argon2, PasswordHasher};
43#[cfg(any(test, feature = "admin"))]
44use chrono::Local;
45#[cfg(feature = "vt")]
46use malwaredb_virustotal::filereport::ScanResultAttributes;
47
/// Limit associated with partial searches, fixed at 100.
///
/// NOTE(review): presumably the maximum number of results a single
/// `partial_search` call returns; the consumers are in the backend modules,
/// so confirm there.
pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
50
/// The database backend in use.
///
/// Every method on this type matches on the variant and forwards the call to
/// the wrapped backend handle.
#[derive(Debug)]
pub enum DatabaseType {
    /// Postgres backend; always compiled in.
    Postgres(Postgres),

    /// SQLite backend; only compiled for tests or with the `sqlite` feature.
    #[cfg(any(test, feature = "sqlite"))]
    SQLite(Sqlite),
}
61
/// Summary information about the database, as returned by `DatabaseType::db_info`.
#[derive(Debug)]
pub struct DatabaseInformation {
    /// Version string (presumably of the database engine; confirm in the backends).
    pub version: String,

    /// Size of the database, as text (presumably human-readable; confirm in the backends).
    pub size: String,

    /// Number of files.
    pub num_files: u64,

    /// Number of users.
    pub num_users: u32,

    /// Number of groups.
    pub num_groups: u32,

    /// Number of sources.
    pub num_sources: u32,
}
83
/// Outcome of `DatabaseType::add_file`.
///
/// Derives `Debug` like the other public types in this module so that it can
/// be logged and used in assertion messages.
#[derive(Debug)]
pub struct FileAddedResult {
    /// Database ID of the file.
    pub file_id: u64,

    /// Whether the file is new.
    ///
    /// NOTE(review): presumably `false` means the file was already stored
    /// before this call; confirm against the backend implementations.
    pub is_new: bool,
}
93
/// Instance-wide configuration, as returned by `DatabaseType::get_config`.
#[derive(Debug)]
pub struct MDBConfig {
    /// Name of this instance (see `DatabaseType::set_name`).
    pub name: String,

    /// Compression setting (toggled by `enable_compression` / `disable_compression`).
    pub compression: bool,

    /// Whether samples are sent to VirusTotal (presumably toggled by
    /// `enable_vt_upload` / `disable_vt_upload`; confirm in the backends).
    pub send_samples_to_vt: bool,

    /// Keep-unknown-files setting (toggled by `enable_keep_unknown_files` /
    /// `disable_keep_unknown_files`).
    pub keep_unknown_files: bool,

    /// Presumably the ID of the default file-encryption key, if one is set;
    /// confirm against the backends.
    pub(crate) default_key: Option<u32>,
}
112
/// VirusTotal record counters, as returned by `DatabaseType::get_vt_stats`.
///
/// Only available with the `vt` feature.
#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Number of clean records (presumably VirusTotal reports without detections).
    pub clean_records: u32,

    /// Number of records with hits (presumably VirusTotal reports with detections).
    pub hits_records: u32,

    /// Number of files which have no VirusTotal record.
    pub files_without_records: u32,
}
127
128impl DatabaseType {
129 pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
138 #[cfg(any(test, feature = "sqlite"))]
139 if arg.starts_with("file:") {
140 let new_conn_str = arg.trim_start_matches("file:");
141 return Ok(DatabaseType::SQLite(Sqlite::new(new_conn_str)?));
142 }
143
144 if arg.starts_with("postgres") {
145 let new_conn_str = arg.trim_start_matches("postgres");
146 return Ok(DatabaseType::Postgres(
147 Postgres::new(new_conn_str, server_ca).await?,
148 ));
149 }
150
151 bail!("unknown database type `{arg}`")
152 }
153
154 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
160 #[cfg(feature = "vt")]
161 pub async fn enable_vt_upload(&self) -> Result<()> {
162 match self {
163 DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
164 #[cfg(any(test, feature = "sqlite"))]
165 DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
166 }
167 }
168
169 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
175 #[cfg(feature = "vt")]
176 pub async fn disable_vt_upload(&self) -> Result<()> {
177 match self {
178 DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
179 #[cfg(any(test, feature = "sqlite"))]
180 DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
181 }
182 }
183
184 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
190 #[cfg(feature = "vt")]
191 pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
192 match self {
193 DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
194 #[cfg(any(test, feature = "sqlite"))]
195 DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
196 }
197 }
198
199 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
205 #[cfg(feature = "vt")]
206 pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
207 match self {
208 DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
209 #[cfg(any(test, feature = "sqlite"))]
210 DatabaseType::SQLite(sl) => sl.store_vt_record(results),
211 }
212 }
213
214 #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
220 #[cfg(feature = "vt")]
221 pub async fn get_vt_stats(&self) -> Result<VtStats> {
222 match self {
223 DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
224 #[cfg(any(test, feature = "sqlite"))]
225 DatabaseType::SQLite(sl) => sl.get_vt_stats(),
226 }
227 }
228
229 pub async fn get_config(&self) -> Result<MDBConfig> {
235 match self {
236 DatabaseType::Postgres(pg) => pg.get_config().await,
237 #[cfg(any(test, feature = "sqlite"))]
238 DatabaseType::SQLite(sl) => sl.get_config(),
239 }
240 }
241
242 pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
249 match self {
250 DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
251 #[cfg(any(test, feature = "sqlite"))]
252 DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
253 }
254 }
255
256 pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
263 ensure!(!apikey.is_empty(), "API key was empty");
264 match self {
265 DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
266 #[cfg(any(test, feature = "sqlite"))]
267 DatabaseType::SQLite(sl) => sl.get_uid(apikey),
268 }
269 }
270
271 pub async fn db_info(&self) -> Result<DatabaseInformation> {
277 match self {
278 DatabaseType::Postgres(pg) => pg.db_info().await,
279 #[cfg(any(test, feature = "sqlite"))]
280 DatabaseType::SQLite(sl) => sl.db_info(),
281 }
282 }
283
284 pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
291 match self {
292 DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
293 #[cfg(any(test, feature = "sqlite"))]
294 DatabaseType::SQLite(sl) => sl.get_user_info(uid),
295 }
296 }
297
298 pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
305 match self {
306 DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
307 #[cfg(any(test, feature = "sqlite"))]
308 DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
309 }
310 }
311
312 pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
319 match self {
320 DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
321 #[cfg(any(test, feature = "sqlite"))]
322 DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
323 }
324 }
325
326 pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
332 match self {
333 DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
334 #[cfg(any(test, feature = "sqlite"))]
335 DatabaseType::SQLite(sl) => sl.get_known_data_types(),
336 }
337 }
338
339 pub async fn get_labels(&self) -> Result<Labels> {
345 match self {
346 DatabaseType::Postgres(pg) => pg.get_labels().await,
347 #[cfg(any(test, feature = "sqlite"))]
348 DatabaseType::SQLite(sl) => sl.get_labels(),
349 }
350 }
351
352 pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
358 match self {
359 DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
360 #[cfg(any(test, feature = "sqlite"))]
361 DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
362 }
363 }
364
365 pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
372 match self {
373 DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
374 #[cfg(any(test, feature = "sqlite"))]
375 DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
376 }
377 }
378
379 pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
387 match self {
388 DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
389 #[cfg(any(test, feature = "sqlite"))]
390 DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
391 }
392 }
393
394 pub async fn add_file(
401 &self,
402 meta: &FileMetadata,
403 known_type: KnownType<'_>,
404 uid: u32,
405 sid: u32,
406 ftype: u32,
407 parent: Option<u64>,
408 ) -> Result<FileAddedResult> {
409 match self {
410 DatabaseType::Postgres(pg) => {
411 pg.add_file(meta, known_type, uid, sid, ftype, parent).await
412 }
413 #[cfg(any(test, feature = "sqlite"))]
414 DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
415 }
416 }
417
418 pub async fn partial_search(&self, uid: u32, search: SearchRequest) -> Result<SearchResponse> {
424 match self {
425 DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
426 #[cfg(any(test, feature = "sqlite"))]
427 DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
428 }
429 }
430
431 pub async fn cleanup(&self) -> Result<u64> {
437 match self {
438 DatabaseType::Postgres(pg) => pg.cleanup().await,
439 #[cfg(any(test, feature = "sqlite"))]
440 DatabaseType::SQLite(sl) => sl.cleanup(),
441 }
442 }
443
444 pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
452 match self {
453 DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
454 #[cfg(any(test, feature = "sqlite"))]
455 DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
456 }
457 }
458
459 pub async fn get_sample_report(
466 &self,
467 uid: u32,
468 hash: &HashType,
469 ) -> Result<malwaredb_api::Report> {
470 match self {
471 DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
472 #[cfg(any(test, feature = "sqlite"))]
473 DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
474 }
475 }
476
477 pub async fn find_similar_samples(
483 &self,
484 uid: u32,
485 sim: &[(malwaredb_api::SimilarityHashType, String)],
486 ) -> Result<Vec<malwaredb_api::SimilarSample>> {
487 match self {
488 DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
489 #[cfg(any(test, feature = "sqlite"))]
490 DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
491 }
492 }
493
494 pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
498 match self {
499 DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
500 #[cfg(any(test, feature = "sqlite"))]
501 DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
502 }
503 }
504
505 pub(crate) async fn get_file_encryption_key_id(
507 &self,
508 hash: &str,
509 ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
510 match self {
511 DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
512 #[cfg(any(test, feature = "sqlite"))]
513 DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
514 }
515 }
516
517 pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
519 match self {
520 DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
521 #[cfg(any(test, feature = "sqlite"))]
522 DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
523 }
524 }
525
526 #[cfg(any(test, feature = "admin"))]
534 pub async fn set_name(&self, name: &str) -> Result<()> {
535 match self {
536 DatabaseType::Postgres(pg) => pg.set_name(name).await,
537 #[cfg(any(test, feature = "sqlite"))]
538 DatabaseType::SQLite(sl) => sl.set_name(name),
539 }
540 }
541
542 #[cfg(any(test, feature = "admin"))]
548 pub async fn enable_compression(&self) -> Result<()> {
549 match self {
550 DatabaseType::Postgres(pg) => pg.enable_compression().await,
551 #[cfg(any(test, feature = "sqlite"))]
552 DatabaseType::SQLite(sl) => sl.enable_compression(),
553 }
554 }
555
556 #[cfg(any(test, feature = "admin"))]
562 pub async fn disable_compression(&self) -> Result<()> {
563 match self {
564 DatabaseType::Postgres(pg) => pg.disable_compression().await,
565 #[cfg(any(test, feature = "sqlite"))]
566 DatabaseType::SQLite(sl) => sl.disable_compression(),
567 }
568 }
569
570 #[cfg(any(test, feature = "admin"))]
576 pub async fn enable_keep_unknown_files(&self) -> Result<()> {
577 match self {
578 DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
579 #[cfg(any(test, feature = "sqlite"))]
580 DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
581 }
582 }
583
584 #[cfg(any(test, feature = "admin"))]
590 pub async fn disable_keep_unknown_files(&self) -> Result<()> {
591 match self {
592 DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
593 #[cfg(any(test, feature = "sqlite"))]
594 DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
595 }
596 }
597
598 #[cfg(any(test, feature = "admin"))]
604 pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
605 match self {
606 DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
607 #[cfg(any(test, feature = "sqlite"))]
608 DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
609 }
610 }
611
612 #[cfg(any(test, feature = "admin"))]
618 pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
619 match self {
620 DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
621 #[cfg(any(test, feature = "sqlite"))]
622 DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
623 }
624 }
625
626 #[allow(clippy::too_many_arguments)]
632 #[cfg(any(test, feature = "admin"))]
633 pub async fn create_user(
634 &self,
635 uname: &str,
636 fname: &str,
637 lname: &str,
638 email: &str,
639 password: Option<String>,
640 organisation: Option<&String>,
641 readonly: bool,
642 ) -> Result<u32> {
643 match self {
644 DatabaseType::Postgres(pg) => {
645 pg.create_user(uname, fname, lname, email, password, organisation, readonly)
646 .await
647 }
648 #[cfg(any(test, feature = "sqlite"))]
649 DatabaseType::SQLite(sl) => {
650 sl.create_user(uname, fname, lname, email, password, organisation, readonly)
651 }
652 }
653 }
654
655 #[cfg(any(test, feature = "admin"))]
661 pub async fn reset_api_keys(&self) -> Result<u64> {
662 match self {
663 DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
664 #[cfg(any(test, feature = "sqlite"))]
665 DatabaseType::SQLite(sl) => sl.reset_api_keys(),
666 }
667 }
668
669 #[cfg(any(test, feature = "admin"))]
676 pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
677 match self {
678 DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
679 #[cfg(any(test, feature = "sqlite"))]
680 DatabaseType::SQLite(sl) => sl.set_password(uname, password),
681 }
682 }
683
684 #[cfg(any(test, feature = "admin"))]
690 pub async fn list_users(&self) -> Result<Vec<admin::User>> {
691 match self {
692 DatabaseType::Postgres(pg) => pg.list_users().await,
693 #[cfg(any(test, feature = "sqlite"))]
694 DatabaseType::SQLite(sl) => sl.list_users(),
695 }
696 }
697
698 #[cfg(any(test, feature = "admin"))]
705 pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
706 match self {
707 DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
708 #[cfg(any(test, feature = "sqlite"))]
709 DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
710 }
711 }
712
713 #[cfg(any(test, feature = "admin"))]
720 pub async fn edit_group(
721 &self,
722 gid: u32,
723 name: &str,
724 desc: &str,
725 parent: Option<u32>,
726 ) -> Result<()> {
727 match self {
728 DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
729 #[cfg(any(test, feature = "sqlite"))]
730 DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
731 }
732 }
733
734 #[cfg(any(test, feature = "admin"))]
740 pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
741 match self {
742 DatabaseType::Postgres(pg) => pg.list_groups().await,
743 #[cfg(any(test, feature = "sqlite"))]
744 DatabaseType::SQLite(sl) => sl.list_groups(),
745 }
746 }
747
748 #[cfg(any(test, feature = "admin"))]
755 pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
756 match self {
757 DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
758 #[cfg(any(test, feature = "sqlite"))]
759 DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
760 }
761 }
762
763 #[cfg(any(test, feature = "admin"))]
770 pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
771 match self {
772 DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
773 #[cfg(any(test, feature = "sqlite"))]
774 DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
775 }
776 }
777
778 #[cfg(any(test, feature = "admin"))]
785 pub async fn create_group(
786 &self,
787 name: &str,
788 description: &str,
789 parent: Option<u32>,
790 ) -> Result<u32> {
791 match self {
792 DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
793 #[cfg(any(test, feature = "sqlite"))]
794 DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
795 }
796 }
797
798 #[cfg(any(test, feature = "admin"))]
804 pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
805 match self {
806 DatabaseType::Postgres(pg) => pg.list_sources().await,
807 #[cfg(any(test, feature = "sqlite"))]
808 DatabaseType::SQLite(sl) => sl.list_sources(),
809 }
810 }
811
812 #[cfg(any(test, feature = "admin"))]
819 pub async fn create_source(
820 &self,
821 name: &str,
822 description: Option<&str>,
823 url: Option<&str>,
824 date: chrono::DateTime<Local>,
825 releasable: bool,
826 malicious: Option<bool>,
827 ) -> Result<u32> {
828 match self {
829 DatabaseType::Postgres(pg) => {
830 pg.create_source(name, description, url, date, releasable, malicious)
831 .await
832 }
833 #[cfg(any(test, feature = "sqlite"))]
834 DatabaseType::SQLite(sl) => {
835 sl.create_source(name, description, url, date, releasable, malicious)
836 }
837 }
838 }
839
840 #[cfg(any(test, feature = "admin"))]
847 pub async fn edit_user(
848 &self,
849 uid: u32,
850 uname: &str,
851 fname: &str,
852 lname: &str,
853 email: &str,
854 readonly: bool,
855 ) -> Result<()> {
856 match self {
857 DatabaseType::Postgres(pg) => {
858 pg.edit_user(uid, uname, fname, lname, email, readonly)
859 .await
860 }
861 #[cfg(any(test, feature = "sqlite"))]
862 DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
863 }
864 }
865
866 #[cfg(any(test, feature = "admin"))]
873 pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
874 match self {
875 DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
876 #[cfg(any(test, feature = "sqlite"))]
877 DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
878 }
879 }
880
881 #[cfg(any(test, feature = "admin"))]
887 pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
888 match self {
889 DatabaseType::Postgres(pg) => pg.file_types_counts().await,
890 #[cfg(any(test, feature = "sqlite"))]
891 DatabaseType::SQLite(sl) => sl.file_types_counts(),
892 }
893 }
894
895 #[cfg(any(test, feature = "admin"))]
903 pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
904 match self {
905 DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
906 #[cfg(any(test, feature = "sqlite"))]
907 DatabaseType::SQLite(sl) => sl.create_label(name, parent),
908 }
909 }
910
911 #[cfg(any(test, feature = "admin"))]
919 pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
920 match self {
921 DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
922 #[cfg(any(test, feature = "sqlite"))]
923 DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
924 }
925 }
926
927 #[cfg(any(test, feature = "admin"))]
934 pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
935 match self {
936 DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
937 #[cfg(any(test, feature = "sqlite"))]
938 DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
939 }
940 }
941
942 #[cfg(any(test, feature = "admin"))]
950 pub async fn label_file(&self, file_id: u64, label_id: u64) -> Result<()> {
951 match self {
952 DatabaseType::Postgres(pg) => pg.label_file(file_id, label_id).await,
953 #[cfg(any(test, feature = "sqlite"))]
954 DatabaseType::SQLite(sl) => sl.label_file(file_id, label_id),
955 }
956 }
957}
958
959pub fn hash_password(password: &str) -> Result<String> {
965 let salt = SaltString::generate(&mut OsRng);
966 let argon2 = Argon2::default();
967 Ok(argon2
968 .hash_password(password.as_bytes(), &salt)?
969 .to_string())
970}
971
972#[must_use]
974pub fn random_bytes_api_key() -> String {
975 let key1 = uuid::Uuid::new_v4();
976 let key2 = uuid::Uuid::new_v4();
977 let key1 = key1.to_string().replace('-', "");
978 let key2 = key2.to_string().replace('-', "");
979 format!("{key1}{key2}")
980}
981
982#[cfg(test)]
983mod tests {
984 use super::*;
985 #[cfg(feature = "vt")]
986 use crate::vt::VtUpdater;
987
988 use std::fs;
989 #[cfg(feature = "vt")]
990 use std::sync::Arc;
991 #[cfg(feature = "vt")]
992 use std::time::SystemTime;
993
994 use anyhow::Context;
995 use fuzzyhash::FuzzyHash;
996 use malwaredb_api::{PartialHashSearchType, SearchRequestParameters, SearchType};
997 use malwaredb_lzjd::{LZDict, Murmur3HashState};
998 use tlsh_fixed::TlshBuilder;
999 use uuid::Uuid;
1000
1001 const MALWARE_LABEL: &str = "malware";
1002 const RANSOMWARE_LABEL: &str = "ransomware";
1003
1004 fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
1005 let mut hashes = vec![];
1006
1007 hashes.push((
1008 malwaredb_api::SimilarityHashType::SSDeep,
1009 FuzzyHash::new(data).to_string(),
1010 ));
1011
1012 let mut builder = TlshBuilder::new(
1013 tlsh_fixed::BucketKind::Bucket256,
1014 tlsh_fixed::ChecksumKind::ThreeByte,
1015 tlsh_fixed::Version::Version4,
1016 );
1017
1018 builder.update(data);
1019
1020 if let Ok(hasher) = builder.build() {
1021 hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
1022 }
1023
1024 let build_hasher = Murmur3HashState::default();
1025 let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
1026 hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
1027
1028 malwaredb_api::SimilarSamplesRequest { hashes }
1029 }
1030
1031 async fn pg_config() -> Postgres {
1032 const CONNECTION_STRING: &str =
1035 "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
1036
1037 if let Ok(pg_port) = std::env::var("PG_PORT") {
1038 let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
1040 Postgres::new(&conn_string, None)
1041 .await
1042 .context(format!(
1043 "failed to connect to postgres with specified port {pg_port}"
1044 ))
1045 .unwrap()
1046 } else {
1047 Postgres::new(CONNECTION_STRING, None).await.unwrap()
1048 }
1049 }
1050
    /// End-to-end test of the Postgres backend.
    ///
    /// Ignored by default because it needs the Postgres instance that
    /// `pg_config` connects to.
    #[tokio::test]
    #[ignore = "don't run this in CI"]
    async fn pg() {
        // NOTE(review): `delete_init` presumably wipes and re-initialises the
        // schema so the test starts from a clean database; confirm in `pg`.
        let psql = pg_config().await;
        psql.delete_init().await.unwrap();

        let db = DatabaseType::Postgres(psql);
        // Register exactly one encryption key before running the shared suite.
        let key = FileEncryption::from(EncryptionOption::Xor);
        db.add_file_encryption_key(&key).await.unwrap();
        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
        everything(&db).await.unwrap();

        #[cfg(feature = "vt")]
        {
            // The configuration must be read first: `db` is moved into the
            // server state below.
            let db_config = db.get_config().await.unwrap();
            let state = crate::State {
                port: 8080,
                directory: None,
                max_upload: 10 * 1024 * 1024,
                ip: "127.0.0.1".parse().unwrap(),
                db_type: Arc::new(db),
                db_config,
                keys: HashMap::new(),
                started: SystemTime::now(),
                // A VirusTotal client is only created when `VT_API_KEY` is set.
                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
                }),
                cert: None,
                key: None,
                mdns: false,
            };

            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");

            vt.updater().await.unwrap();
            println!("PG: Did VT ops!");

            // `db` now belongs to the state, so open a second connection to
            // read the statistics.
            let psql = pg_config().await;

            let vt_stats = psql
                .get_vt_stats()
                .await
                .context("failed to get Postgres VT Stats")
                .unwrap();
            println!("{vt_stats:?}");
            assert!(
                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
            );
        }

        // Final `delete_init` through a fresh connection (presumably cleans up
        // the data this test created).
        let psql = pg_config().await;
        psql.delete_init().await.unwrap();
    }
1105
    /// End-to-end test of the SQLite backend, using a database file in the
    /// current directory which is deleted again at the end.
    #[tokio::test]
    async fn sqlite() {
        const DB_FILE: &str = "testing_sqlite.db";
        // Remove any database file left over from an earlier run.
        if std::path::Path::new(DB_FILE).exists() {
            fs::remove_file(DB_FILE)
                .context(format!("failed to delete old SQLite file {DB_FILE}"))
                .unwrap();
        }

        let sqlite = Sqlite::new(DB_FILE)
            .context(format!("failed to create SQLite instance for {DB_FILE}"))
            .unwrap();

        let db = DatabaseType::SQLite(sqlite);
        // Register exactly one encryption key before running the shared suite.
        let key = FileEncryption::from(EncryptionOption::Xor);
        db.add_file_encryption_key(&key).await.unwrap();
        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
        everything(&db).await.unwrap();

        #[cfg(feature = "vt")]
        {
            // The configuration must be read first: `db` is moved into the
            // server state below.
            let db_config = db.get_config().await.unwrap();
            let state = crate::State {
                port: 8080,
                directory: None,
                max_upload: 10 * 1024 * 1024,
                ip: "127.0.0.1".parse().unwrap(),
                db_type: Arc::new(db),
                db_config,
                keys: HashMap::new(),
                started: SystemTime::now(),
                // A VirusTotal client is only created when `VT_API_KEY` is set.
                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
                }),
                cert: None,
                key: None,
                mdns: false,
            };

            // `db` now belongs to the state, so open a second handle on the
            // same file to read the statistics afterwards.
            let sqlite_second = Sqlite::new(DB_FILE)
                .context(format!("failed to create SQLite instance for {DB_FILE}"))
                .unwrap();

            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");

            vt.updater().await.unwrap();
            println!("Sqlite: Did VT ops!");
            let vt_stats = sqlite_second
                .get_vt_stats()
                .context("failed to get Sqlite VT Stats")
                .unwrap();
            println!("{vt_stats:?}");
            assert!(
                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
            );
        }

        // Delete the database file created by this test.
        fs::remove_file(DB_FILE)
            .context(format!("failed to delete SQLite file {DB_FILE}"))
            .unwrap();
    }
1167
1168 #[allow(clippy::too_many_lines)]
1169 async fn everything(db: &DatabaseType) -> Result<()> {
1170 const ADMIN_UNAME: &str = "admin";
1171 const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
1172
1173 db.set_name("Testing Database")
1174 .await
1175 .context("setting instance name failed")?;
1176
1177 assert!(
1178 db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
1179 "Authentication without password should have failed."
1180 );
1181
1182 db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1183 .await
1184 .context("failed to set admin password")?;
1185
1186 let admin_api_key = db
1187 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1188 .await
1189 .context("unable to get api key for admin")?;
1190 println!("API key: {admin_api_key}");
1191 assert_eq!(admin_api_key.len(), 64);
1192
1193 assert_eq!(
1194 db.get_uid(&admin_api_key).await?,
1195 0,
1196 "Unable to get UID given the API key"
1197 );
1198
1199 let admin_api_key_again = db
1200 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1201 .await
1202 .context("unable to get api key a second time for admin")?;
1203
1204 assert_eq!(
1205 admin_api_key, admin_api_key_again,
1206 "API keys didn't match the second time."
1207 );
1208
1209 let bad_password = "this_is_totally_not_my_password!!";
1210 eprintln!("Testing API login with incorrect password.");
1211 assert!(
1212 db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
1213 "Authenticating as admin with a bad password should have failed."
1214 );
1215
1216 let admin_is_admin = db
1217 .user_is_admin(0)
1218 .await
1219 .context("unable to see if admin (uid 0) is an admin")?;
1220 assert!(admin_is_admin);
1221
1222 let new_user_uname = "testuser";
1223 let new_user_email = "test@example.com";
1224 let new_user_password = "some_awesome_password_++";
1225 let new_id = db
1226 .create_user(
1227 new_user_uname,
1228 new_user_uname,
1229 new_user_uname,
1230 new_user_email,
1231 Some(new_user_password.into()),
1232 None,
1233 false,
1234 )
1235 .await
1236 .context(format!("failed to create user {new_user_uname}"))?;
1237
1238 let passwordless_user_id = db
1239 .create_user(
1240 "passwordless_user",
1241 "passwordless_user",
1242 "passwordless_user",
1243 "passwordless_user@example.com",
1244 None,
1245 None,
1246 false,
1247 )
1248 .await
1249 .context("failed to create passwordless_user")?;
1250
1251 for user in &db.list_users().await.context("failed to list users")? {
1252 if user.id == passwordless_user_id {
1253 assert_eq!(user.uname, "passwordless_user");
1254 }
1255 }
1256
1257 db.edit_user(
1258 passwordless_user_id,
1259 "passwordless_user_2",
1260 "passwordless_user_2",
1261 "passwordless_user_2",
1262 "passwordless_user_2@something.com",
1263 false,
1264 )
1265 .await
1266 .context(format!(
1267 "failed to alter 'passwordless' user, id {passwordless_user_id}"
1268 ))?;
1269
1270 for user in &db.list_users().await.context("failed to list users")? {
1271 if user.id == passwordless_user_id {
1272 assert_eq!(user.uname, "passwordless_user_2");
1273 }
1274 }
1275
1276 assert!(
1277 new_id > 0,
1278 "Weird UID created for user {new_user_uname}: {new_id}"
1279 );
1280
1281 assert!(
1282 db.create_user(
1283 new_user_uname,
1284 new_user_uname,
1285 new_user_uname,
1286 new_user_email,
1287 Some(new_user_password.into()),
1288 None,
1289 false
1290 )
1291 .await
1292 .is_err(),
1293 "Creating a new user with the same user name should fail"
1294 );
1295
1296 let ro_user_name = "ro_user";
1297 let ro_user_password = "ro_user_password";
1298 db.create_user(
1299 ro_user_name,
1300 "ro_user",
1301 "ro_user",
1302 "ro@example.com",
1303 Some(ro_user_password.into()),
1304 None,
1305 true,
1306 )
1307 .await
1308 .context("failed to create read-only user")?;
1309
1310 let ro_user_api_key = db
1311 .authenticate(ro_user_name, ro_user_password)
1312 .await
1313 .context("unable to get api key for read-only user")?;
1314
1315 let new_user_password_change = "some_new_awesomer_password!_++";
1316 db.set_password(new_user_uname, new_user_password_change)
1317 .await
1318 .context("failed to change the password for testuser")?;
1319
1320 let new_user_api_key = db
1321 .authenticate(new_user_uname, new_user_password_change)
1322 .await
1323 .context("unable to get api key for testuser")?;
1324 eprintln!("{new_user_uname} got API key {new_user_api_key}");
1325
1326 assert_eq!(admin_api_key.len(), new_user_api_key.len());
1327
1328 let users = db.list_users().await.context("failed to list users")?;
1329 assert_eq!(
1330 users.len(),
1331 4,
1332 "Four users were created, yet there are {} users",
1333 users.len()
1334 );
1335 eprintln!("DB has {} users:", users.len());
1336 let mut passwordless_user_found = false;
1337 for user in users {
1338 println!("{user}");
1339 if user.uname == "passwordless_user_2" {
1340 assert!(!user.has_api_key);
1341 assert!(!user.has_password);
1342 passwordless_user_found = true;
1343 } else {
1344 assert!(user.has_api_key);
1345 assert!(user.has_password);
1346 }
1347 }
1348 assert!(passwordless_user_found);
1349
1350 let new_group_name = "some_new_group";
1351 let new_group_desc = "some_new_group_description";
1352 let new_group_id = 1;
1353 assert_eq!(
1354 db.create_group(new_group_name, new_group_desc, None)
1355 .await
1356 .context("failed to create group")?,
1357 new_group_id,
1358 "New group didn't have the expected ID, expected {new_group_id}"
1359 );
1360
1361 assert!(
1362 db.create_group(new_group_name, new_group_desc, None)
1363 .await
1364 .is_err(),
1365 "Duplicate group name should have failed"
1366 );
1367
1368 db.add_user_to_group(1, 1)
1369 .await
1370 .context("Unable to add uid 1 to gid 1")?;
1371
1372 let ro_user_uid = db
1373 .get_uid(&ro_user_api_key)
1374 .await
1375 .context("Unable to get UID for read-only user")?;
1376 db.add_user_to_group(ro_user_uid, 1)
1377 .await
1378 .context("Unable to add uid 2 to gid 1")?;
1379
1380 let new_admin_group_name = "admin_subgroup";
1381 let new_admin_group_desc = "admin_subgroup_description";
1382 let new_admin_group_id = 2;
1383 assert!(
1385 db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
1386 .await
1387 .context("failed to create admin sub-group")?
1388 >= new_admin_group_id,
1389 "New group didn't have the expected ID, expected >= {new_admin_group_id}"
1390 );
1391
1392 let groups = db.list_groups().await.context("failed to list groups")?;
1393 assert_eq!(
1394 groups.len(),
1395 3,
1396 "Three groups were created, yet there are {} groups",
1397 groups.len()
1398 );
1399 eprintln!("DB has {} groups:", groups.len());
1400 for group in groups {
1401 println!("{group}");
1402 if group.id == new_admin_group_id {
1403 assert_eq!(group.parent, Some("admin".to_string()));
1404 }
1405 if group.id == 1 {
1406 let test_user_str = String::from(new_user_uname);
1407 let mut found = false;
1408 for member in group.members {
1409 if member.uname == test_user_str {
1410 found = true;
1411 break;
1412 }
1413 }
1414 assert!(found, "new user {test_user_str} wasn't in the group");
1415 }
1416 }
1417
1418 let default_source_name = "default_source".to_string();
1419 let default_source_id = db
1420 .create_source(
1421 &default_source_name,
1422 Some("desc_default_source"),
1423 None,
1424 Local::now(),
1425 true,
1426 Some(false),
1427 )
1428 .await
1429 .context("failed to create source `default_source`")?;
1430
1431 db.add_group_to_source(1, default_source_id)
1432 .await
1433 .context("failed to add group 1 to source 1")?;
1434
1435 let another_source_name = "another_source".to_string();
1436 let another_source_id = db
1437 .create_source(
1438 &another_source_name,
1439 Some("yet another file source"),
1440 None,
1441 Local::now(),
1442 true,
1443 Some(false),
1444 )
1445 .await
1446 .context("failed to create source `another_source`")?;
1447
1448 let empty_source_name = "empty_source".to_string();
1449 db.create_source(
1450 &empty_source_name,
1451 Some("empty and unused file source"),
1452 None,
1453 Local::now(),
1454 true,
1455 Some(false),
1456 )
1457 .await
1458 .context("failed to create source `another_source`")?;
1459
1460 db.add_group_to_source(1, another_source_id)
1461 .await
1462 .context("failed to add group 1 to source 1")?;
1463
1464 let sources = db.list_sources().await.context("failed to list sources")?;
1465 eprintln!("DB has {} sources:", sources.len());
1466 for source in sources {
1467 println!("{source}");
1468 assert_eq!(source.files, 0);
1469 if source.id == default_source_id || source.id == another_source_id {
1470 assert_eq!(
1471 source.groups, 1,
1472 "default source {default_source_name} should have 1 group"
1473 );
1474 } else {
1475 assert_eq!(source.groups, 0, "groups should zero (empty)");
1476 }
1477 }
1478
1479 let uid = db
1480 .get_uid(&new_user_api_key)
1481 .await
1482 .context("failed to user uid from apikey")?;
1483 let user_info = db
1484 .get_user_info(uid)
1485 .await
1486 .context("failed to get user's available groups and sources")?;
1487 assert!(user_info.sources.contains(&default_source_name));
1488 assert!(!user_info.is_admin);
1489 println!("UserInfoResponse: {user_info:?}");
1490
1491 assert!(
1492 db.allowed_user_source(1, default_source_id)
1493 .await
1494 .context(format!(
1495 "failed to check that user 1 has access to source {default_source_id}"
1496 ))?,
1497 "User 1 should should have had access to source {default_source_id}"
1498 );
1499
1500 assert!(
1501 !db.allowed_user_source(1, 5)
1502 .await
1503 .context("failed to check that user 1 has access to source 5")?,
1504 "User 1 should should not have had access to source 5"
1505 );
1506
1507 let test_label_id = db
1508 .create_label("TestLabel", None)
1509 .await
1510 .context("failed to create test label")?;
1511 let test_elf_label_id = db
1512 .create_label("TestELF", Some(test_label_id))
1513 .await
1514 .context("failed to create test label")?;
1515
1516 let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
1517 let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
1518 let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
1519
1520 let known_type =
1521 KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
1522 assert!(known_type.is_exec(), "ELF should be executable");
1523 eprintln!("ELF type ID: {elf_type}");
1524
1525 let file_addition = db
1526 .add_file(
1527 &test_elf_meta,
1528 known_type.clone(),
1529 1,
1530 default_source_id,
1531 elf_type,
1532 None,
1533 )
1534 .await
1535 .context("failed to insert a test elf")?;
1536 assert!(file_addition.is_new, "File should have been added");
1537 eprintln!("Added ELF to the DB");
1538 db.label_file(file_addition.file_id, test_elf_label_id)
1539 .await
1540 .context("failed to label file")?;
1541
1542 let partial_search = SearchRequest {
1544 search: SearchType::Search(SearchRequestParameters {
1545 partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1546 labels: Some(vec![String::from("TestELF")]),
1547 file_type: Some(String::from("ELF")),
1548 magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1549 ..Default::default()
1550 }),
1551 };
1552 assert!(partial_search.is_valid());
1553 let partial_search_response = db.partial_search(1, partial_search).await?;
1554 assert_eq!(partial_search_response.hashes.len(), 1);
1555 assert_eq!(
1556 partial_search_response.hashes[0],
1557 "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1558 );
1559
1560 let partial_search = SearchRequest {
1562 search: SearchType::Search(SearchRequestParameters {
1563 partial_hash: None,
1564 labels: None,
1565 file_type: None,
1566 magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1567 ..Default::default()
1568 }),
1569 };
1570 assert!(partial_search.is_valid());
1571 let partial_search_response = db.partial_search(1, partial_search).await?;
1572 assert_eq!(partial_search_response.hashes.len(), 1);
1573 assert_eq!(
1574 partial_search_response.hashes[0],
1575 "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1576 );
1577
1578 let partial_search = SearchRequest {
1580 search: SearchType::Search(SearchRequestParameters {
1581 partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1582 file_type: Some(String::from("PE32")),
1583 ..Default::default()
1584 }),
1585 };
1586 assert!(partial_search.is_valid());
1587 let partial_search_response = db.partial_search(1, partial_search).await?;
1588 assert_eq!(partial_search_response.hashes.len(), 0);
1589
1590 let partial_search = SearchRequest {
1591 search: SearchType::Search(SearchRequestParameters {
1592 file_name: Some("ppc64".into()),
1593 ..Default::default()
1594 }),
1595 };
1596 assert!(partial_search.is_valid());
1597 let partial_search_response = db.partial_search(1, partial_search).await?;
1598 assert_eq!(partial_search_response.hashes.len(), 1);
1599
1600 let partial_search = SearchRequest {
1602 search: SearchType::Search(SearchRequestParameters::default()),
1603 };
1604 assert!(!partial_search.is_valid());
1605 let partial_search_response = db.partial_search(1, partial_search).await?;
1606 assert!(partial_search_response.hashes.is_empty());
1607
1608 let partial_search = SearchRequest {
1610 search: SearchType::Continuation(Uuid::default()),
1611 };
1612 assert!(partial_search.is_valid());
1613 let partial_search_response = db.partial_search(1, partial_search).await?;
1614 assert!(partial_search_response.hashes.is_empty());
1615
1616 assert!(db
1618 .get_type_id_for_bytes(include_bytes!("../../../../MDB_Logo.ico"))
1619 .await
1620 .is_err());
1621
1622 assert!(
1623 db.add_file(
1624 &test_elf_meta,
1625 known_type.clone(),
1626 ro_user_uid,
1627 default_source_id,
1628 elf_type,
1629 None
1630 )
1631 .await
1632 .is_err(),
1633 "Read-only user should not be able to add a file"
1634 );
1635
1636 let mut test_elf_meta_different_name = test_elf_meta.clone();
1637 test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
1638
1639 assert!(
1640 !db.add_file(
1641 &test_elf_meta_different_name,
1642 known_type,
1643 1,
1644 another_source_id,
1645 elf_type,
1646 None
1647 )
1648 .await
1649 .context("failed to insert a test elf again for a different source")?
1650 .is_new
1651 );
1652
1653 let sources = db
1654 .list_sources()
1655 .await
1656 .context("failed to re-list sources")?;
1657 eprintln!(
1658 "DB has {} sources, and a file was added twice:",
1659 sources.len()
1660 );
1661 println!("We should have two sources with one file each, yet only one ELF.");
1662 for source in sources {
1663 println!("{source}");
1664 if source.id == default_source_id || source.id == another_source_id {
1665 assert_eq!(source.files, 1);
1666 } else {
1667 assert_eq!(source.files, 0, "groups should zero (empty)");
1668 }
1669 }
1670
1671 assert!(!db
1672 .get_user_sources(1)
1673 .await
1674 .expect("failed to get user 1's sources")
1675 .sources
1676 .is_empty());
1677
1678 let file_types_counts = db
1679 .file_types_counts()
1680 .await
1681 .context("failed to get file types and counts")?;
1682 for (name, count) in file_types_counts {
1683 println!("{name}: {count}");
1684 assert_eq!(name, "ELF");
1685 assert_eq!(count, 1);
1686 }
1687
1688 let mut test_elf_modified = test_elf.clone();
1689 let random_bytes = Uuid::new_v4();
1690 let mut random_bytes = random_bytes.into_bytes().to_vec();
1691 test_elf_modified.append(&mut random_bytes);
1692 let similarity_request = generate_similarity_request(&test_elf_modified);
1693 let similarity_response = db
1694 .find_similar_samples(1, &similarity_request.hashes)
1695 .await
1696 .context("failed to get similarity response")?;
1697 eprintln!("Similarity response: {similarity_response:?}");
1698 let similarity_response = similarity_response.first().unwrap();
1699 assert_eq!(
1700 similarity_response.sha256, test_elf_meta.sha256,
1701 "Similarity response should have had the hash of the original ELF"
1702 );
1703 for (algo, sim) in &similarity_response.algorithms {
1704 match algo {
1705 malwaredb_api::SimilarityHashType::LZJD => {
1706 assert!(*sim > 0.0f32);
1707 }
1708 malwaredb_api::SimilarityHashType::SSDeep => {
1709 assert!(*sim > 80.0f32);
1710 }
1711 malwaredb_api::SimilarityHashType::TLSH => {
1712 assert!(*sim <= 20f32);
1713 }
1714 _ => {}
1715 }
1716 }
1717
1718 let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1)
1719 .context("failed to get `HashType::SHA1` from string")?;
1720 let response_sha256 = db
1721 .retrieve_sample(1, &test_elf_hashtype)
1722 .await
1723 .context("could not get SHA-256 hash from test sample")
1724 .unwrap();
1725 assert_eq!(response_sha256, test_elf_meta.sha256);
1726
1727 let test_bogus_hash = HashType::try_from(String::from(
1728 "d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0",
1729 ))
1730 .context("failed to get `HashType` from static string")?;
1731 assert!(
1732 db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
1733 "Getting a file with a bogus hash should have failed."
1734 );
1735
1736 let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
1737 let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
1738 let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
1739
1740 let known_type =
1741 KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
1742
1743 assert!(
1744 db.add_file(
1745 &test_pdf_meta,
1746 known_type,
1747 1,
1748 default_source_id,
1749 pdf_type,
1750 None
1751 )
1752 .await
1753 .context("failed to insert a test pdf")?
1754 .is_new
1755 );
1756 eprintln!("Added PDF to the DB");
1757
1758 let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
1759 let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
1760 let rtf_type = db
1761 .get_type_id_for_bytes(&test_rtf)
1762 .await
1763 .context("failed to get file type id for rtf")?;
1764
1765 let known_type =
1766 KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
1767
1768 assert!(
1769 db.add_file(
1770 &test_rtf_meta,
1771 known_type,
1772 1,
1773 default_source_id,
1774 rtf_type,
1775 None
1776 )
1777 .await
1778 .context("failed to insert a test rtf")?
1779 .is_new
1780 );
1781 eprintln!("Added RTF to the DB");
1782
1783 let report = db
1784 .get_sample_report(1, &HashType::try_from(test_rtf_meta.sha256.clone())?)
1785 .await
1786 .context("failed to get report for test rtf")?;
1787 assert!(report
1788 .clone()
1789 .filecommand
1790 .unwrap()
1791 .contains("Rich Text Format"));
1792 println!("Report: {report}");
1793
1794 assert!(db
1795 .get_sample_report(999, &HashType::try_from(test_rtf_meta.sha256)?)
1796 .await
1797 .is_err());
1798
1799 #[cfg(feature = "vt")]
1800 {
1801 assert!(report.vt.is_some());
1802 let files_needing_vt = db
1803 .files_without_vt_records(10)
1804 .await
1805 .context("failed to get files without VT records")?;
1806 assert!(files_needing_vt.len() > 2);
1807 println!(
1808 "{} files needing VT data: {files_needing_vt:?}",
1809 files_needing_vt.len()
1810 );
1811 }
1812
1813 #[cfg(not(feature = "vt"))]
1814 {
1815 assert!(report.vt.is_none());
1816 }
1817
1818 let reset = db
1819 .reset_api_keys()
1820 .await
1821 .context("failed to reset all API keys")?;
1822 eprintln!("Cleared {reset} api keys.");
1823
1824 let db_info = db.db_info().await.context("failed to get database info")?;
1825 eprintln!("DB Info: {db_info:?}");
1826
1827 let data_types = db
1828 .get_known_data_types()
1829 .await
1830 .context("failed to get data types")?;
1831 for data_type in data_types {
1832 println!("{data_type:?}");
1833 }
1834
1835 let sources = db
1836 .list_sources()
1837 .await
1838 .context("failed to list sources second time")?;
1839 eprintln!("DB has {} sources:", sources.len());
1840 for source in sources {
1841 println!("{source}");
1842 }
1843
1844 let file_types_counts = db
1845 .file_types_counts()
1846 .await
1847 .context("failed to get file types and counts")?;
1848 for (name, count) in file_types_counts {
1849 println!("{name}: {count}");
1850 assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
1851 }
1852
1853 let fatmacho =
1854 include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
1855 .to_vec();
1856 let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
1857 let fatmacho_type = db
1858 .get_type_id_for_bytes(&fatmacho)
1859 .await
1860 .context("failed to get file type for Fat Mach-O")?;
1861 let known_type = KnownType::new(&fatmacho)
1862 .context("failed to parse Fat Mach-O from type crate's test data")?;
1863
1864 assert!(
1865 db.add_file(
1866 &fatmacho_meta,
1867 known_type,
1868 1,
1869 default_source_id,
1870 fatmacho_type,
1871 None
1872 )
1873 .await
1874 .context("failed to insert a test Fat Mach-O")?
1875 .is_new
1876 );
1877 eprintln!("Added Fat Mach-O to the DB");
1878
1879 let file_types_counts = db
1880 .file_types_counts()
1881 .await
1882 .context("failed to get file types and counts")?;
1883 for (name, count) in &file_types_counts {
1884 println!("{name}: {count}");
1885 }
1886
1887 assert_eq!(
1888 *file_types_counts.get("Mach-O").unwrap(),
1889 4,
1890 "Expected 4 Mach-O files, got {:?}",
1891 file_types_counts.get("Mach-O")
1892 );
1893
1894 let malware_label_id = db
1895 .create_label(MALWARE_LABEL, None)
1896 .await
1897 .context("failed to create first label")?;
1898 let ransomware_label_id = db
1899 .create_label(RANSOMWARE_LABEL, Some(malware_label_id))
1900 .await
1901 .context("failed to create malware sub-label")?;
1902 let labels = db.get_labels().await.context("failed to get labels")?;
1903
1904 assert_eq!(labels.len(), 4, "Expected 4 labels, got {labels}");
1905 for label in labels.0 {
1906 if label.name == RANSOMWARE_LABEL {
1907 assert_eq!(label.id, ransomware_label_id);
1908 assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
1909 }
1910 }
1911
1912 let source_code = include_bytes!("mod.rs");
1914 let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
1915 let known_type =
1916 KnownType::new(source_code).context("failed to source code to get `Unknown` type")?;
1917
1918 assert!(matches!(known_type, KnownType::Unknown(_)));
1919
1920 let unknown_type: Vec<FileType> = db
1921 .get_known_data_types()
1922 .await?
1923 .into_iter()
1924 .filter(|t| t.name.eq_ignore_ascii_case("unknown"))
1925 .collect();
1926 let unknown_type_id = unknown_type.first().unwrap().id;
1927 assert!(db.get_type_id_for_bytes(source_code).await.is_err());
1928 db.enable_keep_unknown_files()
1929 .await
1930 .context("failed to enable keeping of unknown files")?;
1931 let source_type = db
1932 .get_type_id_for_bytes(source_code)
1933 .await
1934 .context("failed to type id for source code unknown type example")?;
1935 assert_eq!(source_type, unknown_type_id);
1936 eprintln!("Unknown file type ID: {source_type}");
1937 assert!(
1938 db.add_file(
1939 &source_meta,
1940 known_type,
1941 1,
1942 default_source_id,
1943 unknown_type_id,
1944 None
1945 )
1946 .await
1947 .context("failed to add Rust source code file")?
1948 .is_new
1949 );
1950 eprintln!("Added Rust source code to the DB");
1951
1952 db.reset_own_api_key(0)
1953 .await
1954 .context("failed to clear own API key uid 0")?;
1955
1956 db.deactivate_user(0)
1957 .await
1958 .context("failed to clear password and API key for uid 0")?;
1959
1960 Ok(())
1961 }
1962}