1#[cfg(any(test, feature = "admin"))]
9mod admin;
10mod pg;
12
13#[cfg(any(test, feature = "sqlite"))]
15mod sqlite;
16
17pub mod types;
19
20#[cfg(any(test, feature = "admin"))]
21use crate::crypto::EncryptionOption;
22use crate::crypto::FileEncryption;
23use crate::db::pg::Postgres;
24#[cfg(any(test, feature = "sqlite"))]
25use crate::db::sqlite::Sqlite;
26use crate::db::types::{FileMetadata, FileType};
27use malwaredb_api::{digest::HashType, GetUserInfoResponse, Labels, SearchRequest, Sources};
28use malwaredb_types::KnownType;
29
30use std::collections::HashMap;
31use std::path::PathBuf;
32
33use anyhow::{bail, ensure, Result};
34use argon2::password_hash::{rand_core::OsRng, SaltString};
35use argon2::{Argon2, PasswordHasher};
36#[cfg(any(test, feature = "admin"))]
37use chrono::Local;
38#[cfg(feature = "vt")]
39use malwaredb_virustotal::filereport::ScanResultAttributes;
40
/// Presumably the maximum number of results a partial search may return.
// NOTE(review): not referenced in this file; presumably applied by the backends'
// `partial_search` implementations — confirm in `pg`/`sqlite`.
pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
43
/// Handle to one of the supported database backends.
///
/// Every method in `impl DatabaseType` forwards to the same-named method on the wrapped
/// backend value.
#[derive(Debug)]
pub enum DatabaseType {
    /// PostgreSQL backend.
    Postgres(Postgres),

    /// SQLite backend; only compiled in for tests or with the `sqlite` feature.
    #[cfg(any(test, feature = "sqlite"))]
    SQLite(Sqlite),
}
54
/// Summary statistics about the database, as returned by [`DatabaseType::db_info`].
#[derive(Debug)]
pub struct DatabaseInformation {
    /// Version string reported by the backend (presumably the database engine's version).
    pub version: String,

    /// Size of the database, already formatted as text (format is backend-defined).
    pub size: String,

    /// Number of stored files.
    pub num_files: u64,

    /// Number of user accounts.
    pub num_users: u32,

    /// Number of groups.
    pub num_groups: u32,

    /// Number of sample sources.
    pub num_sources: u32,
}
76
/// Instance-wide configuration stored in the database, as returned by
/// [`DatabaseType::get_config`].
#[derive(Debug)]
pub struct MDBConfig {
    /// Presumably the display name of this MalwareDB instance.
    pub name: String,

    /// Whether sample compression is enabled; presumably toggled by
    /// `enable_compression` / `disable_compression` — confirm in the backends.
    pub compression: bool,

    /// Whether samples are sent to `VirusTotal`; presumably toggled by
    /// `enable_vt_upload` / `disable_vt_upload`.
    pub send_samples_to_vt: bool,

    /// Whether files of an unrecognised type are kept; presumably toggled by
    /// `enable_keep_unknown_files` / `disable_keep_unknown_files`.
    pub keep_unknown_files: bool,

    /// Presumably the ID of the default file-encryption key, if one is configured
    /// (cf. `add_file_encryption_key`, which returns a `u32` key ID).
    // NOTE(review): `dead_code` is allowed because this field is apparently not read in every
    // build configuration — confirm which code reads it and tighten the allow if possible.
    #[allow(dead_code)]
    pub(crate) default_key: Option<u32>,
}
96
/// Counts describing `VirusTotal` coverage of stored samples, as returned by
/// [`DatabaseType::get_vt_stats`]. Only compiled with the `vt` feature.
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Presumably the number of stored VT records with no detections.
    pub clean_records: u32,

    /// Presumably the number of stored VT records with at least one detection.
    pub hits_records: u32,

    /// Number of files that do not yet have a VT record.
    pub files_without_records: u32,
}
110
111impl DatabaseType {
112 pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
121 #[cfg(any(test, feature = "sqlite"))]
122 if arg.starts_with("file:") {
123 let new_conn_str = arg.trim_start_matches("file:");
124 return Ok(DatabaseType::SQLite(Sqlite::new(new_conn_str)?));
125 }
126
127 if arg.starts_with("postgres") {
128 let new_conn_str = arg.trim_start_matches("postgres");
129 return Ok(DatabaseType::Postgres(
130 Postgres::new(new_conn_str, server_ca).await?,
131 ));
132 }
133
134 bail!("unknown database type `{arg}`")
135 }
136
137 #[cfg(feature = "vt")]
143 pub async fn enable_vt_upload(&self) -> Result<()> {
144 match self {
145 DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
146 #[cfg(any(test, feature = "sqlite"))]
147 DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
148 }
149 }
150
151 #[cfg(feature = "vt")]
157 pub async fn disable_vt_upload(&self) -> Result<()> {
158 match self {
159 DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
160 #[cfg(any(test, feature = "sqlite"))]
161 DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
162 }
163 }
164
165 #[cfg(feature = "vt")]
171 pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
172 match self {
173 DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
174 #[cfg(any(test, feature = "sqlite"))]
175 DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
176 }
177 }
178
179 #[cfg(feature = "vt")]
185 pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
186 match self {
187 DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
188 #[cfg(any(test, feature = "sqlite"))]
189 DatabaseType::SQLite(sl) => sl.store_vt_record(results),
190 }
191 }
192
193 #[cfg(feature = "vt")]
199 pub async fn get_vt_stats(&self) -> Result<VtStats> {
200 match self {
201 DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
202 #[cfg(any(test, feature = "sqlite"))]
203 DatabaseType::SQLite(sl) => sl.get_vt_stats(),
204 }
205 }
206
207 pub async fn get_config(&self) -> Result<MDBConfig> {
213 match self {
214 DatabaseType::Postgres(pg) => pg.get_config().await,
215 #[cfg(any(test, feature = "sqlite"))]
216 DatabaseType::SQLite(sl) => sl.get_config(),
217 }
218 }
219
220 pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
227 match self {
228 DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
229 #[cfg(any(test, feature = "sqlite"))]
230 DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
231 }
232 }
233
234 pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
241 ensure!(!apikey.is_empty(), "API key was empty");
242 match self {
243 DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
244 #[cfg(any(test, feature = "sqlite"))]
245 DatabaseType::SQLite(sl) => sl.get_uid(apikey),
246 }
247 }
248
249 pub async fn db_info(&self) -> Result<DatabaseInformation> {
255 match self {
256 DatabaseType::Postgres(pg) => pg.db_info().await,
257 #[cfg(any(test, feature = "sqlite"))]
258 DatabaseType::SQLite(sl) => sl.db_info(),
259 }
260 }
261
262 pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
269 match self {
270 DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
271 #[cfg(any(test, feature = "sqlite"))]
272 DatabaseType::SQLite(sl) => sl.get_user_info(uid),
273 }
274 }
275
276 pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
283 match self {
284 DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
285 #[cfg(any(test, feature = "sqlite"))]
286 DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
287 }
288 }
289
290 pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
297 match self {
298 DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
299 #[cfg(any(test, feature = "sqlite"))]
300 DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
301 }
302 }
303
304 pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
310 match self {
311 DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
312 #[cfg(any(test, feature = "sqlite"))]
313 DatabaseType::SQLite(sl) => sl.get_known_data_types(),
314 }
315 }
316
317 pub async fn get_labels(&self) -> Result<Labels> {
323 match self {
324 DatabaseType::Postgres(pg) => pg.get_labels().await,
325 #[cfg(any(test, feature = "sqlite"))]
326 DatabaseType::SQLite(sl) => sl.get_labels(),
327 }
328 }
329
330 pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
336 match self {
337 DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
338 #[cfg(any(test, feature = "sqlite"))]
339 DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
340 }
341 }
342
343 pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
350 match self {
351 DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
352 #[cfg(any(test, feature = "sqlite"))]
353 DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
354 }
355 }
356
357 pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
365 match self {
366 DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
367 #[cfg(any(test, feature = "sqlite"))]
368 DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
369 }
370 }
371
372 pub async fn add_file(
379 &self,
380 meta: &FileMetadata,
381 known_type: KnownType<'_>,
382 uid: u32,
383 sid: u32,
384 ftype: u32,
385 parent: Option<u64>,
386 ) -> Result<bool> {
387 match self {
388 DatabaseType::Postgres(pg) => {
389 pg.add_file(meta, known_type, uid, sid, ftype, parent).await
390 }
391 #[cfg(any(test, feature = "sqlite"))]
392 DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
393 }
394 }
395
396 pub async fn partial_search(&self, uid: u32, search: &SearchRequest) -> Result<Vec<String>> {
402 match self {
403 DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
404 #[cfg(any(test, feature = "sqlite"))]
405 DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
406 }
407 }
408
409 pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
417 match self {
418 DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
419 #[cfg(any(test, feature = "sqlite"))]
420 DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
421 }
422 }
423
424 pub async fn get_sample_report(
431 &self,
432 uid: u32,
433 hash: &HashType,
434 ) -> Result<malwaredb_api::Report> {
435 match self {
436 DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
437 #[cfg(any(test, feature = "sqlite"))]
438 DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
439 }
440 }
441
442 pub async fn find_similar_samples(
448 &self,
449 uid: u32,
450 sim: &[(malwaredb_api::SimilarityHashType, String)],
451 ) -> Result<Vec<malwaredb_api::SimilarSample>> {
452 match self {
453 DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
454 #[cfg(any(test, feature = "sqlite"))]
455 DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
456 }
457 }
458
459 pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
463 match self {
464 DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
465 #[cfg(any(test, feature = "sqlite"))]
466 DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
467 }
468 }
469
470 pub(crate) async fn get_file_encryption_key_id(
472 &self,
473 hash: &str,
474 ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
475 match self {
476 DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
477 #[cfg(any(test, feature = "sqlite"))]
478 DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
479 }
480 }
481
482 pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
484 match self {
485 DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
486 #[cfg(any(test, feature = "sqlite"))]
487 DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
488 }
489 }
490
491 #[cfg(any(test, feature = "admin"))]
499 pub async fn enable_compression(&self) -> Result<()> {
500 match self {
501 DatabaseType::Postgres(pg) => pg.enable_compression().await,
502 #[cfg(any(test, feature = "sqlite"))]
503 DatabaseType::SQLite(sl) => sl.enable_compression(),
504 }
505 }
506
507 #[cfg(any(test, feature = "admin"))]
513 pub async fn disable_compression(&self) -> Result<()> {
514 match self {
515 DatabaseType::Postgres(pg) => pg.disable_compression().await,
516 #[cfg(any(test, feature = "sqlite"))]
517 DatabaseType::SQLite(sl) => sl.disable_compression(),
518 }
519 }
520
521 #[cfg(any(test, feature = "admin"))]
527 pub async fn enable_keep_unknown_files(&self) -> Result<()> {
528 match self {
529 DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
530 #[cfg(any(test, feature = "sqlite"))]
531 DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
532 }
533 }
534
535 #[cfg(any(test, feature = "admin"))]
541 pub async fn disable_keep_unknown_files(&self) -> Result<()> {
542 match self {
543 DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
544 #[cfg(any(test, feature = "sqlite"))]
545 DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
546 }
547 }
548
549 #[cfg(any(test, feature = "admin"))]
555 pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
556 match self {
557 DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
558 #[cfg(any(test, feature = "sqlite"))]
559 DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
560 }
561 }
562
563 #[cfg(any(test, feature = "admin"))]
569 pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
570 match self {
571 DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
572 #[cfg(any(test, feature = "sqlite"))]
573 DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
574 }
575 }
576
577 #[allow(clippy::too_many_arguments)]
583 #[cfg(any(test, feature = "admin"))]
584 pub async fn create_user(
585 &self,
586 uname: &str,
587 fname: &str,
588 lname: &str,
589 email: &str,
590 password: Option<String>,
591 organisation: Option<&String>,
592 readonly: bool,
593 ) -> Result<u32> {
594 match self {
595 DatabaseType::Postgres(pg) => {
596 pg.create_user(uname, fname, lname, email, password, organisation, readonly)
597 .await
598 }
599 #[cfg(any(test, feature = "sqlite"))]
600 DatabaseType::SQLite(sl) => {
601 sl.create_user(uname, fname, lname, email, password, organisation, readonly)
602 }
603 }
604 }
605
606 #[cfg(any(test, feature = "admin"))]
612 pub async fn reset_api_keys(&self) -> Result<u64> {
613 match self {
614 DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
615 #[cfg(any(test, feature = "sqlite"))]
616 DatabaseType::SQLite(sl) => sl.reset_api_keys(),
617 }
618 }
619
620 #[cfg(any(test, feature = "admin"))]
627 pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
628 match self {
629 DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
630 #[cfg(any(test, feature = "sqlite"))]
631 DatabaseType::SQLite(sl) => sl.set_password(uname, password),
632 }
633 }
634
635 #[cfg(any(test, feature = "admin"))]
641 pub async fn list_users(&self) -> Result<Vec<admin::User>> {
642 match self {
643 DatabaseType::Postgres(pg) => pg.list_users().await,
644 #[cfg(any(test, feature = "sqlite"))]
645 DatabaseType::SQLite(sl) => sl.list_users(),
646 }
647 }
648
649 #[cfg(any(test, feature = "admin"))]
656 pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
657 match self {
658 DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
659 #[cfg(any(test, feature = "sqlite"))]
660 DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
661 }
662 }
663
664 #[cfg(any(test, feature = "admin"))]
671 pub async fn edit_group(
672 &self,
673 gid: u32,
674 name: &str,
675 desc: &str,
676 parent: Option<u32>,
677 ) -> Result<()> {
678 match self {
679 DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
680 #[cfg(any(test, feature = "sqlite"))]
681 DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
682 }
683 }
684
685 #[cfg(any(test, feature = "admin"))]
691 pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
692 match self {
693 DatabaseType::Postgres(pg) => pg.list_groups().await,
694 #[cfg(any(test, feature = "sqlite"))]
695 DatabaseType::SQLite(sl) => sl.list_groups(),
696 }
697 }
698
699 #[cfg(any(test, feature = "admin"))]
706 pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
707 match self {
708 DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
709 #[cfg(any(test, feature = "sqlite"))]
710 DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
711 }
712 }
713
714 #[cfg(any(test, feature = "admin"))]
721 pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
722 match self {
723 DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
724 #[cfg(any(test, feature = "sqlite"))]
725 DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
726 }
727 }
728
729 #[cfg(any(test, feature = "admin"))]
736 pub async fn create_group(
737 &self,
738 name: &str,
739 description: &str,
740 parent: Option<u32>,
741 ) -> Result<u32> {
742 match self {
743 DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
744 #[cfg(any(test, feature = "sqlite"))]
745 DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
746 }
747 }
748
749 #[cfg(any(test, feature = "admin"))]
755 pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
756 match self {
757 DatabaseType::Postgres(pg) => pg.list_sources().await,
758 #[cfg(any(test, feature = "sqlite"))]
759 DatabaseType::SQLite(sl) => sl.list_sources(),
760 }
761 }
762
763 #[cfg(any(test, feature = "admin"))]
770 pub async fn create_source(
771 &self,
772 name: &str,
773 description: Option<&str>,
774 url: Option<&str>,
775 date: chrono::DateTime<Local>,
776 releasable: bool,
777 malicious: Option<bool>,
778 ) -> Result<u32> {
779 match self {
780 DatabaseType::Postgres(pg) => {
781 pg.create_source(name, description, url, date, releasable, malicious)
782 .await
783 }
784 #[cfg(any(test, feature = "sqlite"))]
785 DatabaseType::SQLite(sl) => {
786 sl.create_source(name, description, url, date, releasable, malicious)
787 }
788 }
789 }
790
791 #[cfg(any(test, feature = "admin"))]
798 pub async fn edit_user(
799 &self,
800 uid: u32,
801 uname: &str,
802 fname: &str,
803 lname: &str,
804 email: &str,
805 readonly: bool,
806 ) -> Result<()> {
807 match self {
808 DatabaseType::Postgres(pg) => {
809 pg.edit_user(uid, uname, fname, lname, email, readonly)
810 .await
811 }
812 #[cfg(any(test, feature = "sqlite"))]
813 DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
814 }
815 }
816
817 #[cfg(any(test, feature = "admin"))]
824 pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
825 match self {
826 DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
827 #[cfg(any(test, feature = "sqlite"))]
828 DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
829 }
830 }
831
832 #[cfg(any(test, feature = "admin"))]
838 pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
839 match self {
840 DatabaseType::Postgres(pg) => pg.file_types_counts().await,
841 #[cfg(any(test, feature = "sqlite"))]
842 DatabaseType::SQLite(sl) => sl.file_types_counts(),
843 }
844 }
845
846 #[cfg(any(test, feature = "admin"))]
854 pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
855 match self {
856 DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
857 #[cfg(any(test, feature = "sqlite"))]
858 DatabaseType::SQLite(sl) => sl.create_label(name, parent),
859 }
860 }
861
862 #[cfg(any(test, feature = "admin"))]
870 pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
871 match self {
872 DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
873 #[cfg(any(test, feature = "sqlite"))]
874 DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
875 }
876 }
877
878 #[cfg(any(test, feature = "admin"))]
885 pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
886 match self {
887 DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
888 #[cfg(any(test, feature = "sqlite"))]
889 DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
890 }
891 }
892}
893
894pub fn hash_password(password: &str) -> Result<String> {
900 let salt = SaltString::generate(&mut OsRng);
901 let argon2 = Argon2::default();
902 Ok(argon2
903 .hash_password(password.as_bytes(), &salt)?
904 .to_string())
905}
906
907#[must_use]
909pub fn random_bytes_api_key() -> String {
910 let key1 = uuid::Uuid::new_v4();
911 let key2 = uuid::Uuid::new_v4();
912 let key1 = key1.to_string().replace('-', "");
913 let key2 = key2.to_string().replace('-', "");
914 format!("{key1}{key2}")
915}
916
917#[cfg(test)]
918mod tests {
919 use super::*;
920 #[cfg(feature = "vt")]
921 use crate::vt::VtUpdater;
922
923 use std::fs;
924 #[cfg(feature = "vt")]
925 use std::time::SystemTime;
926
927 use anyhow::Context;
928 use fuzzyhash::FuzzyHash;
929 use malwaredb_api::PartialHashSearchType;
930 use malwaredb_lzjd::{LZDict, Murmur3HashState};
931 use tlsh_fixed::TlshBuilder;
932 use uuid::Uuid;
933
934 const MALWARE_LABEL: &str = "malware";
935 const RANSOMWARE_LABEL: &str = "ransomware";
936
937 fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
938 let mut hashes = vec![];
939
940 hashes.push((
941 malwaredb_api::SimilarityHashType::SSDeep,
942 FuzzyHash::new(data).to_string(),
943 ));
944
945 let mut builder = TlshBuilder::new(
946 tlsh_fixed::BucketKind::Bucket256,
947 tlsh_fixed::ChecksumKind::ThreeByte,
948 tlsh_fixed::Version::Version4,
949 );
950
951 builder.update(data);
952
953 if let Ok(hasher) = builder.build() {
954 hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
955 }
956
957 let build_hasher = Murmur3HashState::default();
958 let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
959 hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
960
961 malwaredb_api::SimilarSamplesRequest { hashes }
962 }
963
964 async fn pg_config() -> Postgres {
965 const CONNECTION_STRING: &str =
968 "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
969
970 if let Ok(pg_port) = std::env::var("PG_PORT") {
971 let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
973 Postgres::new(&conn_string, None)
974 .await
975 .context(format!(
976 "failed to connect to postgres with specified port {pg_port}"
977 ))
978 .unwrap()
979 } else {
980 Postgres::new(CONNECTION_STRING, None).await.unwrap()
981 }
982 }
983
    /// End-to-end test of the Postgres backend against a local `malwaredbtesting` database.
    ///
    /// Ignored by default because it needs a running Postgres server (see `pg_config`).
    /// It performs these steps:
    /// 1. Resets the database via `delete_init`.
    /// 2. Adds one XOR encryption key.
    /// 3. Runs the shared `everything` scenario.
    /// 4. With the `vt` feature, runs the `VirusTotal` updater and checks the VT stats.
    /// 5. Resets the database again.
    #[tokio::test]
    #[ignore = "don't run this in CI"]
    async fn pg() {
        let psql = pg_config().await;
        // NOTE(review): presumably wipes and re-initialises the schema so the test starts from
        // a known state — confirm in `pg.rs`.
        psql.delete_init().await.unwrap();

        let db = DatabaseType::Postgres(psql);
        let key = FileEncryption::from(EncryptionOption::Xor);
        db.add_file_encryption_key(&key).await.unwrap();
        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
        everything(&db).await.unwrap();

        #[cfg(feature = "vt")]
        {
            let db_config = db.get_config().await.unwrap();
            // `db` is moved into the server state here, so VT stats are read through a fresh
            // connection below.
            let state = crate::State {
                port: 8080,
                directory: None,
                max_upload: 10 * 1024 * 1024,
                ip: "127.0.0.1".parse().unwrap(),
                db_type: db,
                db_config,
                keys: HashMap::new(),
                started: SystemTime::now(),
                // A VT client is only configured when `VT_API_KEY` is set in the environment.
                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
                }),
                cert: None,
                key: None,
            };

            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");

            vt.updater().await.unwrap();
            println!("PG: Did VT ops!");

            let psql = pg_config().await;

            let vt_stats = psql
                .get_vt_stats()
                .await
                .context("failed to get Postgres VT Stats")
                .unwrap();
            println!("{vt_stats:?}");
            assert!(
                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
            );
        }

        // Reconnect (the original handle may have been moved above) and reset the database.
        let psql = pg_config().await;
        psql.delete_init().await.unwrap();
    }
1037
    /// End-to-end test of the SQLite backend using the file `testing_sqlite.db` in the current
    /// directory.
    ///
    /// It performs these steps:
    /// 1. Removes any leftover database file.
    /// 2. Creates a fresh database and adds one XOR encryption key.
    /// 3. Runs the shared `everything` scenario.
    /// 4. With the `vt` feature, runs the `VirusTotal` updater and checks the VT stats.
    /// 5. Deletes the database file.
    #[tokio::test]
    async fn sqlite() {
        const DB_FILE: &str = "testing_sqlite.db";
        // A previous failed run may have left the file behind; start from scratch.
        if std::path::Path::new(DB_FILE).exists() {
            fs::remove_file(DB_FILE)
                .context(format!("failed to delete old SQLite file {DB_FILE}"))
                .unwrap();
        }

        let sqlite = Sqlite::new(DB_FILE)
            .context(format!("failed to create SQLite instance for {DB_FILE}"))
            .unwrap();

        let db = DatabaseType::SQLite(sqlite);
        let key = FileEncryption::from(EncryptionOption::Xor);
        db.add_file_encryption_key(&key).await.unwrap();
        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
        everything(&db).await.unwrap();

        #[cfg(feature = "vt")]
        {
            let db_config = db.get_config().await.unwrap();
            // `db` is moved into the server state here, so a second connection to the same
            // file is opened below to read the VT stats.
            let state = crate::State {
                port: 8080,
                directory: None,
                max_upload: 10 * 1024 * 1024,
                ip: "127.0.0.1".parse().unwrap(),
                db_type: db,
                db_config,
                keys: HashMap::new(),
                started: SystemTime::now(),
                // A VT client is only configured when `VT_API_KEY` is set in the environment.
                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
                }),
                cert: None,
                key: None,
            };

            let sqlite_second = Sqlite::new(DB_FILE)
                .context(format!("failed to create SQLite instance for {DB_FILE}"))
                .unwrap();

            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");

            vt.updater().await.unwrap();
            println!("Sqlite: Did VT ops!");
            let vt_stats = sqlite_second
                .get_vt_stats()
                .context("failed to get Sqlite VT Stats")
                .unwrap();
            println!("{vt_stats:?}");
            assert!(
                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
            );
        }

        // Only reached on success; a failed run leaves the file for the next run to remove.
        fs::remove_file(DB_FILE)
            .context(format!("failed to delete SQLite file {DB_FILE}"))
            .unwrap();
    }
1098
1099 #[allow(clippy::too_many_lines)]
1100 async fn everything(db: &DatabaseType) -> Result<()> {
1101 const ADMIN_UNAME: &str = "admin";
1102 const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
1103
1104 assert!(
1105 db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
1106 "Authentication without password should have failed."
1107 );
1108
1109 db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1110 .await
1111 .context("failed to set admin password")?;
1112
1113 let admin_api_key = db
1114 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1115 .await
1116 .context("unable to get api key for admin")?;
1117 println!("API key: {admin_api_key}");
1118 assert_eq!(admin_api_key.len(), 64);
1119
1120 assert_eq!(
1121 db.get_uid(&admin_api_key).await?,
1122 0,
1123 "Unable to get UID given the API key"
1124 );
1125
1126 let admin_api_key_again = db
1127 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1128 .await
1129 .context("unable to get api key a second time for admin")?;
1130
1131 assert_eq!(
1132 admin_api_key, admin_api_key_again,
1133 "API keys didn't match the second time."
1134 );
1135
1136 let bad_password = "this_is_totally_not_my_password!!";
1137 eprintln!("Testing API login with incorrect password.");
1138 assert!(
1139 db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
1140 "Authenticating as admin with a bad password should have failed."
1141 );
1142
1143 let admin_is_admin = db
1144 .user_is_admin(0)
1145 .await
1146 .context("unable to see if admin (uid 0) is an admin")?;
1147 assert!(admin_is_admin);
1148
1149 let new_user_uname = "testuser";
1150 let new_user_email = "test@example.com";
1151 let new_user_password = "some_awesome_password_++";
1152 let new_id = db
1153 .create_user(
1154 new_user_uname,
1155 new_user_uname,
1156 new_user_uname,
1157 new_user_email,
1158 Some(new_user_password.into()),
1159 None,
1160 false,
1161 )
1162 .await
1163 .context(format!("failed to create user {new_user_uname}"))?;
1164
1165 let passwordless_user_id = db
1166 .create_user(
1167 "passwordless_user",
1168 "passwordless_user",
1169 "passwordless_user",
1170 "passwordless_user@example.com",
1171 None,
1172 None,
1173 false,
1174 )
1175 .await
1176 .context("failed to create passwordless_user")?;
1177
1178 for user in &db.list_users().await.context("failed to list users")? {
1179 if user.id == passwordless_user_id {
1180 assert_eq!(user.uname, "passwordless_user");
1181 }
1182 }
1183
1184 db.edit_user(
1185 passwordless_user_id,
1186 "passwordless_user_2",
1187 "passwordless_user_2",
1188 "passwordless_user_2",
1189 "passwordless_user_2@something.com",
1190 false,
1191 )
1192 .await
1193 .context(format!(
1194 "failed to alter 'passwordless' user, id {passwordless_user_id}"
1195 ))?;
1196
1197 for user in &db.list_users().await.context("failed to list users")? {
1198 if user.id == passwordless_user_id {
1199 assert_eq!(user.uname, "passwordless_user_2");
1200 }
1201 }
1202
1203 assert!(
1204 new_id > 0,
1205 "Weird UID created for user {new_user_uname}: {new_id}"
1206 );
1207
1208 assert!(
1209 db.create_user(
1210 new_user_uname,
1211 new_user_uname,
1212 new_user_uname,
1213 new_user_email,
1214 Some(new_user_password.into()),
1215 None,
1216 false
1217 )
1218 .await
1219 .is_err(),
1220 "Creating a new user with the same user name should fail"
1221 );
1222
1223 let ro_user_name = "ro_user";
1224 let ro_user_password = "ro_user_password";
1225 db.create_user(
1226 ro_user_name,
1227 "ro_user",
1228 "ro_user",
1229 "ro@example.com",
1230 Some(ro_user_password.into()),
1231 None,
1232 true,
1233 )
1234 .await
1235 .context("failed to create read-only user")?;
1236
1237 let ro_user_api_key = db
1238 .authenticate(ro_user_name, ro_user_password)
1239 .await
1240 .context("unable to get api key for read-only user")?;
1241
1242 let new_user_password_change = "some_new_awesomer_password!_++";
1243 db.set_password(new_user_uname, new_user_password_change)
1244 .await
1245 .context("failed to change the password for testuser")?;
1246
1247 let new_user_api_key = db
1248 .authenticate(new_user_uname, new_user_password_change)
1249 .await
1250 .context("unable to get api key for testuser")?;
1251 eprintln!("{new_user_uname} got API key {new_user_api_key}");
1252
1253 assert_eq!(admin_api_key.len(), new_user_api_key.len());
1254
1255 let users = db.list_users().await.context("failed to list users")?;
1256 assert_eq!(
1257 users.len(),
1258 4,
1259 "Four users were created, yet there are {} users",
1260 users.len()
1261 );
1262 eprintln!("DB has {} users:", users.len());
1263 let mut passwordless_user_found = false;
1264 for user in users {
1265 println!("{user}");
1266 if user.uname == "passwordless_user_2" {
1267 assert!(!user.has_api_key);
1268 assert!(!user.has_password);
1269 passwordless_user_found = true;
1270 } else {
1271 assert!(user.has_api_key);
1272 assert!(user.has_password);
1273 }
1274 }
1275 assert!(passwordless_user_found);
1276
1277 let new_group_name = "some_new_group";
1278 let new_group_desc = "some_new_group_description";
1279 let new_group_id = 1;
1280 assert_eq!(
1281 db.create_group(new_group_name, new_group_desc, None)
1282 .await
1283 .context("failed to create group")?,
1284 new_group_id,
1285 "New group didn't have the expected ID, expected {new_group_id}"
1286 );
1287
1288 assert!(
1289 db.create_group(new_group_name, new_group_desc, None)
1290 .await
1291 .is_err(),
1292 "Duplicate group name should have failed"
1293 );
1294
1295 db.add_user_to_group(1, 1)
1296 .await
1297 .context("Unable to add uid 1 to gid 1")?;
1298
1299 let ro_user_uid = db
1300 .get_uid(&ro_user_api_key)
1301 .await
1302 .context("Unable to get UID for read-only user")?;
1303 db.add_user_to_group(ro_user_uid, 1)
1304 .await
1305 .context("Unable to add uid 2 to gid 1")?;
1306
1307 let new_admin_group_name = "admin_subgroup";
1308 let new_admin_group_desc = "admin_subgroup_description";
1309 let new_admin_group_id = 2;
1310 assert!(
1312 db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
1313 .await
1314 .context("failed to create admin sub-group")?
1315 >= new_admin_group_id,
1316 "New group didn't have the expected ID, expected >= {new_admin_group_id}"
1317 );
1318
1319 let groups = db.list_groups().await.context("failed to list groups")?;
1320 assert_eq!(
1321 groups.len(),
1322 3,
1323 "Three groups were created, yet there are {} groups",
1324 groups.len()
1325 );
1326 eprintln!("DB has {} groups:", groups.len());
1327 for group in groups {
1328 println!("{group}");
1329 if group.id == new_admin_group_id {
1330 assert_eq!(group.parent, Some("admin".to_string()));
1331 }
1332 if group.id == 1 {
1333 let test_user_str = String::from(new_user_uname);
1334 let mut found = false;
1335 for member in group.members {
1336 if member.uname == test_user_str {
1337 found = true;
1338 break;
1339 }
1340 }
1341 assert!(found, "new user {test_user_str} wasn't in the group");
1342 }
1343 }
1344
1345 let default_source_name = "default_source".to_string();
1346 let default_source_id = db
1347 .create_source(
1348 &default_source_name,
1349 Some("desc_default_source"),
1350 None,
1351 Local::now(),
1352 true,
1353 Some(false),
1354 )
1355 .await
1356 .context("failed to create source `default_source`")?;
1357
1358 db.add_group_to_source(1, default_source_id)
1359 .await
1360 .context("failed to add group 1 to source 1")?;
1361
1362 let another_source_name = "another_source".to_string();
1363 let another_source_id = db
1364 .create_source(
1365 &another_source_name,
1366 Some("yet another file source"),
1367 None,
1368 Local::now(),
1369 true,
1370 Some(false),
1371 )
1372 .await
1373 .context("failed to create source `another_source`")?;
1374
1375 let empty_source_name = "empty_source".to_string();
1376 db.create_source(
1377 &empty_source_name,
1378 Some("empty and unused file source"),
1379 None,
1380 Local::now(),
1381 true,
1382 Some(false),
1383 )
1384 .await
1385 .context("failed to create source `another_source`")?;
1386
1387 db.add_group_to_source(1, another_source_id)
1388 .await
1389 .context("failed to add group 1 to source 1")?;
1390
1391 let sources = db.list_sources().await.context("failed to list sources")?;
1392 eprintln!("DB has {} sources:", sources.len());
1393 for source in sources {
1394 println!("{source}");
1395 assert_eq!(source.files, 0);
1396 if source.id == default_source_id || source.id == another_source_id {
1397 assert_eq!(
1398 source.groups, 1,
1399 "default source {default_source_name} should have 1 group"
1400 );
1401 } else {
1402 assert_eq!(source.groups, 0, "groups should zero (empty)");
1403 }
1404 }
1405
1406 let uid = db
1407 .get_uid(&new_user_api_key)
1408 .await
1409 .context("failed to user uid from apikey")?;
1410 let user_info = db
1411 .get_user_info(uid)
1412 .await
1413 .context("failed to get user's available groups and sources")?;
1414 assert!(user_info.sources.contains(&default_source_name));
1415 assert!(!user_info.is_admin);
1416 println!("UserInfoResponse: {user_info:?}");
1417
1418 assert!(
1419 db.allowed_user_source(1, default_source_id)
1420 .await
1421 .context(format!(
1422 "failed to check that user 1 has access to source {default_source_id}"
1423 ))?,
1424 "User 1 should should have had access to source {default_source_id}"
1425 );
1426
1427 assert!(
1428 !db.allowed_user_source(1, 5)
1429 .await
1430 .context("failed to check that user 1 has access to source 5")?,
1431 "User 1 should should not have had access to source 5"
1432 );
1433
1434 let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
1435 let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
1436 let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
1437
1438 let known_type =
1439 KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
1440 assert!(known_type.is_exec(), "ELF should be executable");
1441 eprintln!("ELF type ID: {elf_type}");
1442
1443 assert!(db
1444 .add_file(
1445 &test_elf_meta,
1446 known_type.clone(),
1447 1,
1448 default_source_id,
1449 elf_type,
1450 None
1451 )
1452 .await
1453 .context("failed to insert a test elf")?);
1454 eprintln!("Added ELF to the DB");
1455
1456 let partial_search = SearchRequest {
1457 partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1458 ..Default::default()
1459 };
1460 let partial_search_response = db.partial_search(1, &partial_search).await?;
1461 assert_eq!(partial_search_response.len(), 1);
1462 assert_eq!(
1463 partial_search_response[0],
1464 "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1465 );
1466
1467 let partial_search = SearchRequest {
1468 file_name: Some("ppc64".into()),
1469 ..Default::default()
1470 };
1471 let partial_search_response = db.partial_search(1, &partial_search).await?;
1472 assert_eq!(partial_search_response.len(), 1);
1473
1474 let partial_search = SearchRequest::default();
1475 let partial_search_response = db.partial_search(1, &partial_search).await?;
1476 assert!(partial_search_response.is_empty());
1477
1478 assert!(
1479 db.add_file(
1480 &test_elf_meta,
1481 known_type.clone(),
1482 ro_user_uid,
1483 default_source_id,
1484 elf_type,
1485 None
1486 )
1487 .await
1488 .is_err(),
1489 "Read-only user should not be able to add a file"
1490 );
1491
1492 let mut test_elf_meta_different_name = test_elf_meta.clone();
1493 test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
1494
1495 assert!(!db
1496 .add_file(
1497 &test_elf_meta_different_name,
1498 known_type,
1499 1,
1500 another_source_id,
1501 elf_type,
1502 None
1503 )
1504 .await
1505 .context("failed to insert a test elf again for a different source")?);
1506
1507 let sources = db
1508 .list_sources()
1509 .await
1510 .context("failed to re-list sources")?;
1511 eprintln!(
1512 "DB has {} sources, and a file was added twice:",
1513 sources.len()
1514 );
1515 println!("We should have two sources with one file each, yet only one ELF.");
1516 for source in sources {
1517 println!("{source}");
1518 if source.id == default_source_id || source.id == another_source_id {
1519 assert_eq!(source.files, 1);
1520 } else {
1521 assert_eq!(source.files, 0, "groups should zero (empty)");
1522 }
1523 }
1524
1525 assert!(!db
1526 .get_user_sources(1)
1527 .await
1528 .expect("failed to get user 1's sources")
1529 .sources
1530 .is_empty());
1531
1532 let file_types_counts = db
1533 .file_types_counts()
1534 .await
1535 .context("failed to get file types and counts")?;
1536 for (name, count) in file_types_counts {
1537 println!("{name}: {count}");
1538 assert_eq!(name, "ELF");
1539 assert_eq!(count, 1);
1540 }
1541
1542 let mut test_elf_modified = test_elf.clone();
1543 let random_bytes = Uuid::new_v4();
1544 let mut random_bytes = random_bytes.into_bytes().to_vec();
1545 test_elf_modified.append(&mut random_bytes);
1546 let similarity_request = generate_similarity_request(&test_elf_modified);
1547 let similarity_response = db
1548 .find_similar_samples(1, &similarity_request.hashes)
1549 .await
1550 .context("failed to get similarity response")?;
1551 eprintln!("Similarity response: {similarity_response:?}");
1552 let similarity_response = similarity_response.first().unwrap();
1553 assert_eq!(
1554 similarity_response.sha256, test_elf_meta.sha256,
1555 "Similarity response should have had the hash of the original ELF"
1556 );
1557 for (algo, sim) in &similarity_response.algorithms {
1558 match algo {
1559 malwaredb_api::SimilarityHashType::LZJD => {
1560 assert!(*sim > 0.0f32);
1561 }
1562 malwaredb_api::SimilarityHashType::SSDeep => {
1563 assert!(*sim > 80.0f32);
1564 }
1565 malwaredb_api::SimilarityHashType::TLSH => {
1566 assert!(*sim <= 20f32);
1567 }
1568 _ => {}
1569 }
1570 }
1571
1572 let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1)
1573 .context("failed to get `HashType::SHA1` from string")?;
1574 let response_sha256 = db
1575 .retrieve_sample(1, &test_elf_hashtype)
1576 .await
1577 .context("could not get SHA-256 hash from test sample")
1578 .unwrap();
1579 assert_eq!(response_sha256, test_elf_meta.sha256);
1580
1581 let test_bogus_hash = HashType::try_from(String::from(
1582 "d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0",
1583 ))
1584 .context("failed to get `HashType` from static string")?;
1585 assert!(
1586 db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
1587 "Getting a file with a bogus hash should have failed."
1588 );
1589
1590 let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
1591 let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
1592 let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
1593
1594 let known_type =
1595 KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
1596
1597 assert!(db
1598 .add_file(
1599 &test_pdf_meta,
1600 known_type,
1601 1,
1602 default_source_id,
1603 pdf_type,
1604 None
1605 )
1606 .await
1607 .context("failed to insert a test pdf")?);
1608 eprintln!("Added PDF to the DB");
1609
1610 let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
1611 let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
1612 let rtf_type = db
1613 .get_type_id_for_bytes(&test_rtf)
1614 .await
1615 .context("failed to get file type id for rtf")?;
1616
1617 let known_type =
1618 KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
1619
1620 assert!(db
1621 .add_file(
1622 &test_rtf_meta,
1623 known_type,
1624 1,
1625 default_source_id,
1626 rtf_type,
1627 None
1628 )
1629 .await
1630 .context("failed to insert a test rtf")?);
1631 eprintln!("Added RTF to the DB");
1632
1633 let report = db
1634 .get_sample_report(1, &HashType::try_from(test_rtf_meta.sha256.clone())?)
1635 .await
1636 .context("failed to get report for test rtf")?;
1637 assert!(report
1638 .clone()
1639 .filecommand
1640 .unwrap()
1641 .contains("Rich Text Format"));
1642 println!("Report: {report}");
1643
1644 assert!(db
1645 .get_sample_report(999, &HashType::try_from(test_rtf_meta.sha256)?)
1646 .await
1647 .is_err());
1648
1649 #[cfg(feature = "vt")]
1650 {
1651 assert!(report.vt.is_some());
1652 let files_needing_vt = db
1653 .files_without_vt_records(10)
1654 .await
1655 .context("failed to get files without VT records")?;
1656 assert!(files_needing_vt.len() > 2);
1657 println!(
1658 "{} files needing VT data: {files_needing_vt:?}",
1659 files_needing_vt.len()
1660 );
1661 }
1662
1663 #[cfg(not(feature = "vt"))]
1664 {
1665 assert!(report.vt.is_none());
1666 }
1667
1668 let reset = db
1669 .reset_api_keys()
1670 .await
1671 .context("failed to reset all API keys")?;
1672 eprintln!("Cleared {reset} api keys.");
1673
1674 let db_info = db.db_info().await.context("failed to get database info")?;
1675 eprintln!("DB Info: {db_info:?}");
1676
1677 let data_types = db
1678 .get_known_data_types()
1679 .await
1680 .context("failed to get data types")?;
1681 for data_type in data_types {
1682 println!("{data_type:?}");
1683 }
1684
1685 let sources = db
1686 .list_sources()
1687 .await
1688 .context("failed to list sources second time")?;
1689 eprintln!("DB has {} sources:", sources.len());
1690 for source in sources {
1691 println!("{source}");
1692 }
1693
1694 let file_types_counts = db
1695 .file_types_counts()
1696 .await
1697 .context("failed to get file types and counts")?;
1698 for (name, count) in file_types_counts {
1699 println!("{name}: {count}");
1700 assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
1701 }
1702
1703 let fatmacho =
1704 include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
1705 .to_vec();
1706 let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
1707 let fatmacho_type = db
1708 .get_type_id_for_bytes(&fatmacho)
1709 .await
1710 .context("failed to get file type for Fat Mach-O")?;
1711 let known_type = KnownType::new(&fatmacho)
1712 .context("failed to parse Fat Mach-O from type crate's test data")?;
1713
1714 assert!(db
1715 .add_file(
1716 &fatmacho_meta,
1717 known_type,
1718 1,
1719 default_source_id,
1720 fatmacho_type,
1721 None
1722 )
1723 .await
1724 .context("failed to insert a test Fat Mach-O")?);
1725 eprintln!("Added Fat Mach-O to the DB");
1726
1727 let file_types_counts = db
1728 .file_types_counts()
1729 .await
1730 .context("failed to get file types and counts")?;
1731 for (name, count) in &file_types_counts {
1732 println!("{name}: {count}");
1733 }
1734
1735 assert_eq!(
1736 *file_types_counts.get("Mach-O").unwrap(),
1737 4,
1738 "Expected 4 Mach-O files, got {:?}",
1739 file_types_counts.get("Mach-O")
1740 );
1741
1742 let malware_label_id = db
1743 .create_label(MALWARE_LABEL, None)
1744 .await
1745 .context("failed to create first label")?;
1746 let ransomware_label_id = db
1747 .create_label(RANSOMWARE_LABEL, Some(malware_label_id))
1748 .await
1749 .context("failed to create malware sub-label")?;
1750 let labels = db.get_labels().await.context("failed to get labels")?;
1751
1752 assert_eq!(labels.len(), 2);
1753 for label in labels.0 {
1754 if label.name == RANSOMWARE_LABEL {
1755 assert_eq!(label.id, ransomware_label_id);
1756 assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
1757 }
1758 }
1759
1760 let source_code = include_bytes!("mod.rs");
1762 let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
1763 let known_type =
1764 KnownType::new(source_code).context("failed to source code to get `Unknown` type")?;
1765
1766 assert!(matches!(known_type, KnownType::Unknown(_)));
1767
1768 let unknown_type: Vec<FileType> = db
1769 .get_known_data_types()
1770 .await?
1771 .into_iter()
1772 .filter(|t| t.name.eq_ignore_ascii_case("unknown"))
1773 .collect();
1774 let unknown_type_id = unknown_type.first().unwrap().id;
1775 assert!(db.get_type_id_for_bytes(source_code).await.is_err());
1776 db.enable_keep_unknown_files()
1777 .await
1778 .context("failed to enable keeping of unknown files")?;
1779 let source_type = db
1780 .get_type_id_for_bytes(source_code)
1781 .await
1782 .context("failed to type id for source code unknown type example")?;
1783 assert_eq!(source_type, unknown_type_id);
1784 eprintln!("Unknown file type ID: {source_type}");
1785 assert!(db
1786 .add_file(
1787 &source_meta,
1788 known_type,
1789 1,
1790 default_source_id,
1791 unknown_type_id,
1792 None
1793 )
1794 .await
1795 .context("failed to add Rust source code file")?);
1796 eprintln!("Added Rust source code to the DB");
1797
1798 db.reset_own_api_key(0)
1799 .await
1800 .context("failed to clear own API key uid 0")?;
1801
1802 db.deactivate_user(0)
1803 .await
1804 .context("failed to clear password and API key for uid 0")?;
1805
1806 Ok(())
1807 }
1808}