Skip to main content

malwaredb_server/db/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Postgres is the primary database used by Malware DB.
//! `SQLite` can also be used for unit testing or for small instances of Malware DB; this option
//! is enabled with the `sqlite` feature flag. When using `SQLite`, Malware DB itself calculates
//! the distances for the similarity hashes.
7
8/// Malware DB Administrative functions
9#[cfg(any(test, feature = "admin"))]
10mod admin;
11/// Postgres functions
12mod pg;
13
14/// `SQLite` functionality
15#[cfg(any(test, feature = "sqlite"))]
16mod sqlite;
17
18/// Custom `SQLite` functions
19#[cfg(any(test, feature = "sqlite"))]
20mod sqlite_functions;
21
22/// File Metadata convenience data structure
23pub mod types;
24
25#[cfg(any(test, feature = "admin"))]
26use crate::crypto::EncryptionOption;
27use crate::crypto::FileEncryption;
28use crate::db::pg::Postgres;
29#[cfg(any(test, feature = "sqlite"))]
30use crate::db::sqlite::Sqlite;
31use crate::db::types::{FileMetadata, FileType};
32use malwaredb_api::{
33    digest::HashType, GetUserInfoResponse, Labels, SearchRequest, SearchResponse, Sources,
34};
35use malwaredb_types::KnownType;
36
37use std::collections::HashMap;
38use std::path::PathBuf;
39
40use anyhow::{bail, ensure, Result};
41use argon2::password_hash::{rand_core::OsRng, SaltString};
42use argon2::{Argon2, PasswordHasher};
43#[cfg(any(test, feature = "admin"))]
44use chrono::Local;
45#[cfg(feature = "vt")]
46use malwaredb_virustotal::filereport::ScanResultAttributes;
47
/// Upper bound on the number of results returned by partial-hash and/or partial-file-name
/// searches, so that broad queries cannot degrade performance.
pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
50
/// Migration action
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Migration {
    /// At runtime: only check whether a migration is needed
    Check,

    /// Admin feature: perform the migration
    #[cfg(any(test, feature = "admin"))]
    Migrate,
}
61
62/// Database connection handle
63#[derive(Debug)]
64pub enum DatabaseType {
65    /// Postgres database
66    Postgres(Postgres),
67
68    /// `SQLite` database
69    #[cfg(any(test, feature = "sqlite"))]
70    SQLite(Sqlite),
71}
72
/// Database version details together with a few basic counts.
#[derive(Debug)]
pub struct DatabaseInformation {
    /// Version string reported by the database
    pub version: String,

    /// Size of the database, formatted for humans
    pub size: String,

    /// How many file samples Malware DB holds
    pub num_files: u64,

    /// How many user accounts exist
    pub num_users: u32,

    /// How many user groups exist
    pub num_groups: u32,

    /// How many sample sources exist
    pub num_sources: u32,
}
94
/// Data returned when adding a new sample
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FileAddedResult {
    /// File ID
    pub file_id: u64,

    /// Whether the file was added as a new entry.
    /// This is false if the sample was already known to Malware DB.
    pub is_new: bool,
}
104
/// Instance-wide Malware DB settings persisted in the database itself.
#[derive(Debug)]
pub struct MDBConfig {
    /// Display name of this Malware DB instance
    pub name: String,

    /// True when samples are stored compressed
    pub compression: bool,

    /// True when Malware DB is permitted to send samples to Virus Total
    pub send_samples_to_vt: bool,

    /// True when Malware DB retains files of unknown type
    pub keep_unknown_files: bool,

    /// ID of the key used to encrypt samples, or `None` when encryption is not configured
    pub(crate) default_key: Option<u32>,
}
123
/// Counts summarising the Virus Total record coverage of the samples in Malware DB.
#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Number of files Virus Total reports as clean
    pub clean_records: u32,

    /// Number of files Virus Total reports as malicious
    pub hits_records: u32,

    /// Number of files with no Virus Total record yet
    pub files_without_records: u32,
}
138
139impl DatabaseType {
140    /// Get a database connection from a configuration string
141    ///
142    /// # Errors
143    ///
144    /// * If there's a connectivity issue to Postgres, an error will result
145    /// * If the `SQLite` file cannot be created or opened, an error will result
146    /// * If `SQLite` is the type but Malware DB wasn't compiled with the sqlite feature, an error will result
147    /// * If the format or database type isn't known, an error will result
148    pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
149        let db = Self::init_from_string(arg, server_ca).await?;
150        db.migrate_check(Migration::Check).await?;
151        Ok(db)
152    }
153
154    /// Get a database connection from a configuration string and perform a migration, if needed
155    ///
156    /// # Errors
157    ///
158    /// * If there's a connectivity issue to Postgres, an error will result
159    /// * If the `SQLite` file cannot be created or opened, an error will result
160    /// * If `SQLite` is the type but Malware DB wasn't compiled with the sqlite feature, an error will result
161    /// * If the format or database type isn't known, an error will result
162    #[cfg(feature = "admin")]
163    pub async fn migrate(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
164        let db = Self::init_from_string(arg, server_ca).await?;
165        db.migrate_check(Migration::Migrate).await?;
166        Ok(db)
167    }
168
169    async fn init_from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
170        #[cfg(any(test, feature = "sqlite"))]
171        if arg.starts_with("file:") {
172            let new_conn_str = arg.trim_start_matches("file:");
173            let db = DatabaseType::SQLite(Sqlite::new(new_conn_str)?);
174            return Ok(db);
175        }
176
177        if arg.starts_with("postgres") {
178            let new_conn_str = arg.trim_start_matches("postgres");
179            let db = DatabaseType::Postgres(Postgres::new(new_conn_str, server_ca).await?);
180
181            return Ok(db);
182        }
183
184        bail!("unknown database type `{arg}`")
185    }
186
187    /// Set the flag allowing uploads to Virus Total.
188    ///
189    /// # Errors
190    ///
191    /// If there's a connectivity issue to Postgres, an error will result
192    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
193    #[cfg(feature = "vt")]
194    pub async fn enable_vt_upload(&self) -> Result<()> {
195        match self {
196            DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
197            #[cfg(any(test, feature = "sqlite"))]
198            DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
199        }
200    }
201
202    /// Set the flag preventing uploads to Virus Total.
203    ///
204    /// # Errors
205    ///
206    /// If there's a connectivity issue to Postgres, an error will result
207    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
208    #[cfg(feature = "vt")]
209    pub async fn disable_vt_upload(&self) -> Result<()> {
210        match self {
211            DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
212            #[cfg(any(test, feature = "sqlite"))]
213            DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
214        }
215    }
216
217    /// Get the SHA-256 hashes of the files which don't have VT records
218    ///
219    /// # Errors
220    ///
221    /// If there's a connectivity issue to Postgres, an error will result
222    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
223    #[cfg(feature = "vt")]
224    pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
225        match self {
226            DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
227            #[cfg(any(test, feature = "sqlite"))]
228            DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
229        }
230    }
231
232    /// Store the VT results: AV hits and detailed report, or lack of any AV hits
233    ///
234    /// # Errors
235    ///
236    /// If there's a connectivity issue to Postgres, an error will result
237    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
238    #[cfg(feature = "vt")]
239    pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
240        match self {
241            DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
242            #[cfg(any(test, feature = "sqlite"))]
243            DatabaseType::SQLite(sl) => sl.store_vt_record(results),
244        }
245    }
246
247    /// Quick statistics regarding the data contained for VT information for our samples
248    ///
249    /// # Errors
250    ///
251    /// If there's a connectivity issue to Postgres, an error will result
252    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
253    #[cfg(feature = "vt")]
254    pub async fn get_vt_stats(&self) -> Result<VtStats> {
255        match self {
256            DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
257            #[cfg(any(test, feature = "sqlite"))]
258            DatabaseType::SQLite(sl) => sl.get_vt_stats(),
259        }
260    }
261
262    /// Add the Yara search to the database
263    ///
264    /// # Errors
265    ///
266    /// If there's a connectivity issue to Postgres, an error will result
267    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
268    #[cfg(feature = "yara")]
269    pub async fn add_yara_search(
270        &self,
271        uid: u32,
272        yara_string: &str,
273        yara_bytes: &[u8],
274    ) -> Result<uuid::Uuid> {
275        match self {
276            DatabaseType::Postgres(pg) => pg.add_yara_search(uid, yara_string, yara_bytes).await,
277            #[cfg(any(test, feature = "sqlite"))]
278            DatabaseType::SQLite(sl) => sl.add_yara_search(uid, yara_string, yara_bytes),
279        }
280    }
281
282    /// Get unfinished Yara tasks for processing.
283    ///
284    /// # Errors
285    ///
286    /// If there's a connectivity issue to Postgres, an error will result
287    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
288    #[cfg(feature = "yara")]
289    pub async fn get_unfinished_yara_tasks(&self) -> Result<Vec<crate::yara::YaraTask>> {
290        match self {
291            DatabaseType::Postgres(pg) => pg.get_unfinished_yara_tasks().await,
292            #[cfg(any(test, feature = "sqlite"))]
293            DatabaseType::SQLite(sl) => sl.get_unfinished_yara_tasks(),
294        }
295    }
296
297    /// Add a Yara match to the database
298    ///
299    /// # Errors
300    ///
301    /// If there's a connectivity issue to Postgres, an error will result
302    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
303    #[cfg(feature = "yara")]
304    pub async fn add_yara_match(
305        &self,
306        id: uuid::Uuid,
307        rule_name: &str,
308        file_sha256: &str,
309    ) -> Result<()> {
310        match self {
311            DatabaseType::Postgres(pg) => pg.add_yara_match(id, rule_name, file_sha256).await,
312            #[cfg(any(test, feature = "sqlite"))]
313            DatabaseType::SQLite(sl) => sl.add_yara_match(id, rule_name, file_sha256),
314        }
315    }
316
317    /// Indicate that the Yara search task has finished
318    ///
319    /// # Errors
320    ///
321    /// If there's a connectivity issue to Postgres, an error will result
322    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
323    #[cfg(feature = "yara")]
324    pub async fn mark_yara_task_as_finished(&self, id: uuid::Uuid) -> Result<()> {
325        match self {
326            DatabaseType::Postgres(pg) => pg.mark_yara_task_as_finished(id).await,
327            #[cfg(any(test, feature = "sqlite"))]
328            DatabaseType::SQLite(sl) => sl.mark_yara_task_as_finished(id),
329        }
330    }
331
332    /// Add the last file ID for the next iteration
333    ///
334    /// # Errors
335    ///
336    /// If there's a connectivity issue to Postgres, an error will result
337    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
338    #[cfg(feature = "yara")]
339    pub async fn yara_add_next_file_id(&self, id: uuid::Uuid, file_id: u64) -> Result<()> {
340        match self {
341            DatabaseType::Postgres(pg) => pg.yara_add_next_file_id(id, file_id).await,
342            #[cfg(any(test, feature = "sqlite"))]
343            DatabaseType::SQLite(sl) => sl.yara_add_next_file_id(id, file_id),
344        }
345    }
346
347    /// Get the Yara search results
348    ///
349    /// # Errors
350    ///
351    /// If there's a connectivity issue to Postgres, an error will result
352    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
353    #[cfg(feature = "yara")]
354    pub async fn get_yara_results(
355        &self,
356        id: uuid::Uuid,
357        user_id: u32,
358    ) -> Result<malwaredb_api::YaraSearchResponse> {
359        match self {
360            DatabaseType::Postgres(pg) => pg.get_yara_results(id, user_id).await,
361            #[cfg(any(test, feature = "sqlite"))]
362            DatabaseType::SQLite(sl) => sl.get_yara_results(id, user_id),
363        }
364    }
365
366    /// Get the configuration which is stored in the database
367    ///
368    /// # Errors
369    ///
370    /// If there's a connectivity issue to Postgres, an error will result
371    pub async fn get_config(&self) -> Result<MDBConfig> {
372        match self {
373            DatabaseType::Postgres(pg) => pg.get_config().await,
374            #[cfg(any(test, feature = "sqlite"))]
375            DatabaseType::SQLite(sl) => sl.get_config(),
376        }
377    }
378
379    /// Check user credentials, return the API key. Generate if it doesn't exist.
380    ///
381    /// # Errors
382    ///
383    /// * If there's a connectivity issue to Postgres, an error will result
384    /// * If the username and/or password aren't correct, an error will result
385    pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
386        match self {
387            DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
388            #[cfg(any(test, feature = "sqlite"))]
389            DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
390        }
391    }
392
393    /// Get the user's ID from their API key
394    ///
395    /// # Errors
396    ///
397    /// * If there's a connectivity issue to Postgres, an error will result
398    /// * If the api key isn't valid, an error will result
399    pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
400        ensure!(!apikey.is_empty(), "API key was empty");
401        match self {
402            DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
403            #[cfg(any(test, feature = "sqlite"))]
404            DatabaseType::SQLite(sl) => sl.get_uid(apikey),
405        }
406    }
407
408    /// Retrieve information about the database
409    ///
410    /// # Errors
411    ///
412    /// * If there's a connectivity issue to Postgres, an error will result
413    pub async fn db_info(&self) -> Result<DatabaseInformation> {
414        match self {
415            DatabaseType::Postgres(pg) => pg.db_info().await,
416            #[cfg(any(test, feature = "sqlite"))]
417            DatabaseType::SQLite(sl) => sl.db_info(),
418        }
419    }
420
421    /// Retrieve the names of the groups and sources the user is part of and has access to
422    ///
423    /// # Errors
424    ///
425    /// * If there's a connectivity issue to Postgres, an error will result
426    /// * If the user ID isn't valid, an error will result
427    pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
428        match self {
429            DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
430            #[cfg(any(test, feature = "sqlite"))]
431            DatabaseType::SQLite(sl) => sl.get_user_info(uid),
432        }
433    }
434
435    /// Retrieve the source information available to the specified user
436    ///
437    /// # Errors
438    ///
439    /// * If there's a connectivity issue to Postgres, an error will result
440    /// * If the user ID isn't valid, an error will result
441    pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
442        match self {
443            DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
444            #[cfg(any(test, feature = "sqlite"))]
445            DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
446        }
447    }
448
449    /// Let the user clear their own API key to log out from all systems
450    ///
451    /// # Errors
452    ///
453    /// * If there's a connectivity issue to Postgres, an error will result
454    /// * If the user ID isn't valid, an error will result
455    pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
456        match self {
457            DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
458            #[cfg(any(test, feature = "sqlite"))]
459            DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
460        }
461    }
462
463    /// Retrieve the supported data type information
464    ///
465    /// # Errors
466    ///
467    /// If there's a connectivity issue to Postgres, an error will result
468    pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
469        match self {
470            DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
471            #[cfg(any(test, feature = "sqlite"))]
472            DatabaseType::SQLite(sl) => sl.get_known_data_types(),
473        }
474    }
475
476    /// Get all labels from Malware DB
477    ///
478    /// # Errors
479    ///
480    /// If there's a connectivity issue to Postgres, an error will result
481    pub async fn get_labels(&self) -> Result<Labels> {
482        match self {
483            DatabaseType::Postgres(pg) => pg.get_labels().await,
484            #[cfg(any(test, feature = "sqlite"))]
485            DatabaseType::SQLite(sl) => sl.get_labels(),
486        }
487    }
488
489    /// Get the corresponding type ID for a buffer representing a file
490    ///
491    /// # Errors
492    ///
493    /// If there's a connectivity issue to Postgres, an error will result
494    pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
495        match self {
496            DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
497            #[cfg(any(test, feature = "sqlite"))]
498            DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
499        }
500    }
501
502    /// Check that a user has been granted access data for the specific source
503    ///
504    /// # Errors
505    ///
506    /// * If there's a connectivity issue to Postgres, an error will result
507    /// * If the user or source ID(s) aren't valid, an error will result
508    pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
509        match self {
510            DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
511            #[cfg(any(test, feature = "sqlite"))]
512            DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
513        }
514    }
515
516    /// Check to see if the user is an administrator. The user must be a member of the
517    /// admin group (group ID 0), or a one group below (a group with the parent group id of 0).
518    ///
519    /// # Errors
520    ///
521    /// * If there's a connectivity issue to Postgres, an error will result
522    /// * If the user ID isn't valid, an error will result
523    pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
524        match self {
525            DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
526            #[cfg(any(test, feature = "sqlite"))]
527            DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
528        }
529    }
530
531    /// Add a file's metadata to the database, returning true if this is a new entry
532    ///
533    /// # Errors
534    ///
535    /// * If there's a connectivity issue to Postgres, an error will result
536    /// * If the source doesn't exist or the user isn't part of a member group, an error will result
537    pub async fn add_file(
538        &self,
539        meta: &FileMetadata,
540        known_type: KnownType<'_>,
541        uid: u32,
542        sid: u32,
543        ftype: u32,
544        parent: Option<u64>,
545    ) -> Result<FileAddedResult> {
546        match self {
547            DatabaseType::Postgres(pg) => {
548                pg.add_file(meta, known_type, uid, sid, ftype, parent).await
549            }
550            #[cfg(any(test, feature = "sqlite"))]
551            DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
552        }
553    }
554
555    /// Search for allowed samples based on partial search and/or file name
556    ///
557    /// # Errors
558    ///
559    /// * If there's a connectivity issue to Postgres, an error will result
560    pub async fn partial_search(&self, uid: u32, search: SearchRequest) -> Result<SearchResponse> {
561        match self {
562            DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
563            #[cfg(any(test, feature = "sqlite"))]
564            DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
565        }
566    }
567
568    /// Delete old pagination searches
569    ///
570    /// # Errors
571    ///
572    /// An error would occur if the Postgres server couldn't be reached.
573    pub async fn cleanup(&self) -> Result<u64> {
574        match self {
575            DatabaseType::Postgres(pg) => pg.cleanup().await,
576            #[cfg(any(test, feature = "sqlite"))]
577            DatabaseType::SQLite(sl) => sl.cleanup(),
578        }
579    }
580
581    /// Retrieve the SHA-256 hash of the sample while checking that the user is permitted
582    /// to access to it
583    ///
584    /// # Errors
585    ///
586    /// * If there's a connectivity issue to Postgres, an error will result
587    /// * If the file doesn't exist or the user isn't allowed access, an error will result
588    pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
589        match self {
590            DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
591            #[cfg(any(test, feature = "sqlite"))]
592            DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
593        }
594    }
595
596    /// Retrieve a report for a given sample, if allowed.
597    ///
598    /// # Errors
599    ///
600    /// * If there's a connectivity issue to Postgres, an error will result
601    /// * If the file doesn't exist or the user isn't allowed access, an error will result
602    pub async fn get_sample_report(
603        &self,
604        uid: u32,
605        hash: &HashType,
606    ) -> Result<malwaredb_api::Report> {
607        match self {
608            DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
609            #[cfg(any(test, feature = "sqlite"))]
610            DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
611        }
612    }
613
614    /// Given a collection of similarity hashes, find samples which are similar.
615    ///
616    /// # Errors
617    ///
618    /// If there's a connectivity issue to Postgres, an error will result
619    pub async fn find_similar_samples(
620        &self,
621        uid: u32,
622        sim: &[(malwaredb_api::SimilarityHashType, String)],
623    ) -> Result<Vec<malwaredb_api::SimilarSample>> {
624        match self {
625            DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
626            #[cfg(any(test, feature = "sqlite"))]
627            DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
628        }
629    }
630
631    /// For a given user ID, return the file hashes the person is allowed to know about and the last
632    /// file ID, which can be provided as `next` to get the next batch of hashes.
633    ///
634    /// # Errors
635    ///
636    /// If there's a connectivity issue to Postgres, an error will result
637    pub async fn user_allowed_files_by_sha256(
638        &self,
639        uid: u32,
640        next: Option<u64>,
641    ) -> Result<(Vec<String>, u64)> {
642        match self {
643            DatabaseType::Postgres(pg) => pg.user_allowed_files_by_sha256(uid, next).await,
644            #[cfg(any(test, feature = "sqlite"))]
645            DatabaseType::SQLite(sl) => sl.user_allowed_files_by_sha256(uid, next),
646        }
647    }
648
649    // Private functions
650
651    /// Get the file encryption keys
652    pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
653        match self {
654            DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
655            #[cfg(any(test, feature = "sqlite"))]
656            DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
657        }
658    }
659
660    /// Get the key and AES nonce for a specific file identified by SHA-256 hash
661    pub(crate) async fn get_file_encryption_key_id(
662        &self,
663        hash: &str,
664    ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
665        match self {
666            DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
667            #[cfg(any(test, feature = "sqlite"))]
668            DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
669        }
670    }
671
672    /// Set the AES nonce for a file specified by SHA-256 hash, a `None` value removes it.
673    pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
674        match self {
675            DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
676            #[cfg(any(test, feature = "sqlite"))]
677            DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
678        }
679    }
680
681    /// Checks if a migration is needed, erroring if action is to check and the schema has changed.
682    ///
683    /// # Errors
684    ///
685    /// If there's a connectivity issue to Postgres, an error will result; if a migration is needed
686    /// and not an administrative action, an error results.
687    pub async fn migrate_check(&self, action: Migration) -> Result<()> {
688        match self {
689            DatabaseType::Postgres(pg) => pg.migrate(action).await,
690            #[cfg(any(test, feature = "sqlite"))]
691            DatabaseType::SQLite(sl) => sl.migrate(action),
692        }
693    }
694
695    // Administrative functions
696
697    /// Set the instance name
698    ///
699    /// # Errors
700    ///
701    /// If there's a connectivity issue to Postgres, an error will result
702    #[cfg(any(test, feature = "admin"))]
703    pub async fn set_name(&self, name: &str) -> Result<()> {
704        match self {
705            DatabaseType::Postgres(pg) => pg.set_name(name).await,
706            #[cfg(any(test, feature = "sqlite"))]
707            DatabaseType::SQLite(sl) => sl.set_name(name),
708        }
709    }
710
711    /// Set the compression flag
712    ///
713    /// # Errors
714    ///
715    /// If there's a connectivity issue to Postgres, an error will result
716    #[cfg(any(test, feature = "admin"))]
717    pub async fn enable_compression(&self) -> Result<()> {
718        match self {
719            DatabaseType::Postgres(pg) => pg.enable_compression().await,
720            #[cfg(any(test, feature = "sqlite"))]
721            DatabaseType::SQLite(sl) => sl.enable_compression(),
722        }
723    }
724
725    /// Unset the compression flag, does not go and decompress files already compressed!
726    ///
727    /// # Errors
728    ///
729    /// If there's a connectivity issue to Postgres, an error will result
730    #[cfg(any(test, feature = "admin"))]
731    pub async fn disable_compression(&self) -> Result<()> {
732        match self {
733            DatabaseType::Postgres(pg) => pg.disable_compression().await,
734            #[cfg(any(test, feature = "sqlite"))]
735            DatabaseType::SQLite(sl) => sl.disable_compression(),
736        }
737    }
738
739    /// Set the keep unknown files flag
740    ///
741    /// # Errors
742    ///
743    /// If there's a connectivity issue to Postgres, an error will result
744    #[cfg(any(test, feature = "admin"))]
745    pub async fn enable_keep_unknown_files(&self) -> Result<()> {
746        match self {
747            DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
748            #[cfg(any(test, feature = "sqlite"))]
749            DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
750        }
751    }
752
753    /// Unset the keep unknown files flag, does not go and remove unknown files!
754    ///
755    /// # Errors
756    ///
757    /// If there's a connectivity issue to Postgres, an error will result
758    #[cfg(any(test, feature = "admin"))]
759    pub async fn disable_keep_unknown_files(&self) -> Result<()> {
760        match self {
761            DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
762            #[cfg(any(test, feature = "sqlite"))]
763            DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
764        }
765    }
766
767    /// Add an encryption key to the database, set it as the default, and return the key ID
768    ///
769    /// # Errors
770    ///
771    /// * If there is a connectivity with Postgres, an error will result
772    #[cfg(any(test, feature = "admin"))]
773    pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
774        match self {
775            DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
776            #[cfg(any(test, feature = "sqlite"))]
777            DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
778        }
779    }
780
781    /// Get the key ID and algorithm names
782    ///
783    /// # Errors
784    ///
785    /// If there is a connectivity with Postgres, an error will result
786    #[cfg(any(test, feature = "admin"))]
787    pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
788        match self {
789            DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
790            #[cfg(any(test, feature = "sqlite"))]
791            DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
792        }
793    }
794
795    /// Create a user account, return the user ID.
796    ///
797    /// # Errors
798    ///
799    /// If there's a connectivity issue to Postgres, an error will result
800    #[allow(clippy::too_many_arguments)]
801    #[cfg(any(test, feature = "admin"))]
802    pub async fn create_user(
803        &self,
804        uname: &str,
805        fname: &str,
806        lname: &str,
807        email: &str,
808        password: Option<String>,
809        organisation: Option<&String>,
810        readonly: bool,
811    ) -> Result<u32> {
812        match self {
813            DatabaseType::Postgres(pg) => {
814                pg.create_user(uname, fname, lname, email, password, organisation, readonly)
815                    .await
816            }
817            #[cfg(any(test, feature = "sqlite"))]
818            DatabaseType::SQLite(sl) => {
819                sl.create_user(uname, fname, lname, email, password, organisation, readonly)
820            }
821        }
822    }
823
824    /// Clear all API keys, either in case of suspected activity, or part of policy
825    ///
826    /// # Errors
827    ///
828    /// If there's a connectivity issue to Postgres, an error will result
829    #[cfg(any(test, feature = "admin"))]
830    pub async fn reset_api_keys(&self) -> Result<u64> {
831        match self {
832            DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
833            #[cfg(any(test, feature = "sqlite"))]
834            DatabaseType::SQLite(sl) => sl.reset_api_keys(),
835        }
836    }
837
838    /// Set a user's password
839    ///
840    /// # Errors
841    ///
842    /// * If there's a connectivity issue to Postgres, an error will result
843    /// * If the user doesn't exist, an error will result
844    #[cfg(any(test, feature = "admin"))]
845    pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
846        match self {
847            DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
848            #[cfg(any(test, feature = "sqlite"))]
849            DatabaseType::SQLite(sl) => sl.set_password(uname, password),
850        }
851    }
852
853    /// Get the complete list of users
854    ///
855    /// # Errors
856    ///
857    /// If there's a connectivity issue to Postgres, an error will result
858    #[cfg(any(test, feature = "admin"))]
859    pub async fn list_users(&self) -> Result<Vec<admin::User>> {
860        match self {
861            DatabaseType::Postgres(pg) => pg.list_users().await,
862            #[cfg(any(test, feature = "sqlite"))]
863            DatabaseType::SQLite(sl) => sl.list_users(),
864        }
865    }
866
867    /// Get the ID of a group from name
868    ///
869    /// # Errors
870    ///
871    /// * If there's a connectivity issue to Postgres, an error will result
872    /// * If the group name isn't valid, an error will result
873    #[cfg(any(test, feature = "admin"))]
874    pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
875        match self {
876            DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
877            #[cfg(any(test, feature = "sqlite"))]
878            DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
879        }
880    }
881
882    /// Update the record for a group
883    ///
884    /// # Errors
885    ///
886    /// * If there's a connectivity issue to Postgres, an error will result
887    /// * If the group doesn't exist, an error will result
888    #[cfg(any(test, feature = "admin"))]
889    pub async fn edit_group(
890        &self,
891        gid: u32,
892        name: &str,
893        desc: &str,
894        parent: Option<u32>,
895    ) -> Result<()> {
896        match self {
897            DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
898            #[cfg(any(test, feature = "sqlite"))]
899            DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
900        }
901    }
902
903    /// Get the complete list of groups
904    ///
905    /// # Errors
906    ///
907    /// If there's a connectivity issue to Postgres, an error will result
908    #[cfg(any(test, feature = "admin"))]
909    pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
910        match self {
911            DatabaseType::Postgres(pg) => pg.list_groups().await,
912            #[cfg(any(test, feature = "sqlite"))]
913            DatabaseType::SQLite(sl) => sl.list_groups(),
914        }
915    }
916
917    /// Grant a user membership to a group, both by id.
918    ///
919    /// # Errors
920    ///
921    /// * If there's a connectivity issue to Postgres, an error will result
922    /// * If the user or group doesn't exist, an error will result
923    #[cfg(any(test, feature = "admin"))]
924    pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
925        match self {
926            DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
927            #[cfg(any(test, feature = "sqlite"))]
928            DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
929        }
930    }
931
932    /// Grand a group access to a source, both by id.
933    ///
934    /// # Errors
935    ///
936    /// * If there's a connectivity issue to Postgres, an error will result
937    /// * If the group or source doesn't exist, an error will result
938    #[cfg(any(test, feature = "admin"))]
939    pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
940        match self {
941            DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
942            #[cfg(any(test, feature = "sqlite"))]
943            DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
944        }
945    }
946
947    /// Create a new group, returning the group ID
948    ///
949    /// # Errors
950    ///
951    /// * If there's a connectivity issue to Postgres, an error will result
952    /// * If the group name is already taken, an error will result
953    #[cfg(any(test, feature = "admin"))]
954    pub async fn create_group(
955        &self,
956        name: &str,
957        description: &str,
958        parent: Option<u32>,
959    ) -> Result<u32> {
960        match self {
961            DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
962            #[cfg(any(test, feature = "sqlite"))]
963            DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
964        }
965    }
966
967    /// Get the complete list of sources
968    ///
969    /// # Errors
970    ///
971    /// If there's a connectivity issue to Postgres, an error will result
972    #[cfg(any(test, feature = "admin"))]
973    pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
974        match self {
975            DatabaseType::Postgres(pg) => pg.list_sources().await,
976            #[cfg(any(test, feature = "sqlite"))]
977            DatabaseType::SQLite(sl) => sl.list_sources(),
978        }
979    }
980
981    /// Create a source, returning the source ID
982    ///
983    /// # Errors
984    ///
985    /// * If there's a connectivity issue to Postgres, an error will result
986    /// * If the source already exists, an error will result
987    #[cfg(any(test, feature = "admin"))]
988    pub async fn create_source(
989        &self,
990        name: &str,
991        description: Option<&str>,
992        url: Option<&str>,
993        date: chrono::DateTime<Local>,
994        releasable: bool,
995        malicious: Option<bool>,
996    ) -> Result<u32> {
997        match self {
998            DatabaseType::Postgres(pg) => {
999                pg.create_source(name, description, url, date, releasable, malicious)
1000                    .await
1001            }
1002            #[cfg(any(test, feature = "sqlite"))]
1003            DatabaseType::SQLite(sl) => {
1004                sl.create_source(name, description, url, date, releasable, malicious)
1005            }
1006        }
1007    }
1008
1009    /// Edit a user account setting the specified field values, primarily used by the Admin gui
1010    ///
1011    /// # Errors
1012    ///
1013    /// * If there's a connectivity issue to Postgres, an error will result
1014    /// * If the user isn't valid, an error will result
1015    #[cfg(any(test, feature = "admin"))]
1016    pub async fn edit_user(
1017        &self,
1018        uid: u32,
1019        uname: &str,
1020        fname: &str,
1021        lname: &str,
1022        email: &str,
1023        readonly: bool,
1024    ) -> Result<()> {
1025        match self {
1026            DatabaseType::Postgres(pg) => {
1027                pg.edit_user(uid, uname, fname, lname, email, readonly)
1028                    .await
1029            }
1030            #[cfg(any(test, feature = "sqlite"))]
1031            DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
1032        }
1033    }
1034
1035    /// Deactivate, but don't delete, a user's account
1036    ///
1037    /// # Errors
1038    ///
1039    /// * If there's a connectivity issue to Postgres, an error will result
1040    /// * If the user ID isn't valid, an error will result
1041    #[cfg(any(test, feature = "admin"))]
1042    pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
1043        match self {
1044            DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
1045            #[cfg(any(test, feature = "sqlite"))]
1046            DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
1047        }
1048    }
1049
1050    /// File types and number of files per type
1051    ///
1052    /// # Errors
1053    ///
1054    /// * If there's a connectivity issue to Postgres, an error will result
1055    #[cfg(any(test, feature = "admin"))]
1056    pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
1057        match self {
1058            DatabaseType::Postgres(pg) => pg.file_types_counts().await,
1059            #[cfg(any(test, feature = "sqlite"))]
1060            DatabaseType::SQLite(sl) => sl.file_types_counts(),
1061        }
1062    }
1063
1064    /// Create a new label, returning the label ID
1065    ///
1066    /// # Errors
1067    ///
1068    /// * If there's a connectivity issue to Postgres, an error will result
1069    /// * If the parent label doesn't exist, an error will result
1070    /// * If the label name already exists, an error will result
1071    #[cfg(any(test, feature = "admin"))]
1072    pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
1073        match self {
1074            DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
1075            #[cfg(any(test, feature = "sqlite"))]
1076            DatabaseType::SQLite(sl) => sl.create_label(name, parent),
1077        }
1078    }
1079
1080    /// Edit a label name or parent
1081    ///
1082    /// # Errors
1083    ///
1084    /// * If there's a connectivity issue to Postgres, an error will result
1085    /// * If the parent label doesn't exist, an error will result
1086    /// * If the label name already exists, an error will result
1087    #[cfg(any(test, feature = "admin"))]
1088    pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
1089        match self {
1090            DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
1091            #[cfg(any(test, feature = "sqlite"))]
1092            DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
1093        }
1094    }
1095
1096    /// Return the ID for a given label name
1097    ///
1098    /// # Errors
1099    ///
1100    /// * If there's a connectivity issue to Postgres, an error will result
1101    /// * If the label ID is invalid, an error will result
1102    #[cfg(any(test, feature = "admin"))]
1103    pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
1104        match self {
1105            DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
1106            #[cfg(any(test, feature = "sqlite"))]
1107            DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
1108        }
1109    }
1110
1111    /// Associate an existing label by its IDs with a file.
1112    ///
1113    /// # Errors
1114    ///
1115    /// * Incorrect IDs will result in an error
1116    /// * If the file already has the given label associated
1117    /// * If there is a network or connection issue with Postgres
1118    #[cfg(any(test, feature = "admin"))]
1119    pub async fn label_file(&self, file_id: u64, label_id: u64) -> Result<()> {
1120        match self {
1121            DatabaseType::Postgres(pg) => pg.label_file(file_id, label_id).await,
1122            #[cfg(any(test, feature = "sqlite"))]
1123            DatabaseType::SQLite(sl) => sl.label_file(file_id, label_id),
1124        }
1125    }
1126}
1127
1128/// Hash a password string with [Argon2]
1129///
1130/// # Errors
1131///
1132/// An error may result if Argon has an error
1133pub fn hash_password(password: &str) -> Result<String> {
1134    let salt = SaltString::generate(&mut OsRng);
1135    let argon2 = Argon2::default();
1136    Ok(argon2
1137        .hash_password(password.as_bytes(), &salt)?
1138        .to_string())
1139}
1140
1141/// Generate a new, random API key
1142#[must_use]
1143pub fn random_bytes_api_key() -> String {
1144    let key1 = uuid::Uuid::new_v4();
1145    let key2 = uuid::Uuid::new_v4();
1146    let key1 = key1.to_string().replace('-', "");
1147    let key2 = key2.to_string().replace('-', "");
1148    format!("{key1}{key2}")
1149}
1150
1151#[cfg(test)]
1152mod tests {
1153    use super::*;
1154    #[cfg(feature = "vt")]
1155    use crate::vt::VtUpdater;
1156
1157    use std::fs;
1158    #[cfg(feature = "vt")]
1159    use std::sync::Arc;
1160    #[cfg(feature = "vt")]
1161    use std::time::SystemTime;
1162
1163    use anyhow::Context;
1164    use fuzzyhash::FuzzyHash;
1165    use malwaredb_api::{PartialHashSearchType, SearchRequestParameters, SearchType};
1166    use malwaredb_lzjd::{LZDict, Murmur3HashState};
1167    use tlsh_fixed::TlshBuilder;
1168    use uuid::Uuid;
1169
    /// Label name for test fixtures (presumably used by the scenario further down — confirm).
    const MALWARE_LABEL: &str = "malware";
    /// Label name for test fixtures (presumably used by the scenario further down — confirm).
    const RANSOMWARE_LABEL: &str = "ransomware";
1172
1173    fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
1174        let mut hashes = vec![];
1175
1176        hashes.push((
1177            malwaredb_api::SimilarityHashType::SSDeep,
1178            FuzzyHash::new(data).to_string(),
1179        ));
1180
1181        let mut builder = TlshBuilder::new(
1182            tlsh_fixed::BucketKind::Bucket256,
1183            tlsh_fixed::ChecksumKind::ThreeByte,
1184            tlsh_fixed::Version::Version4,
1185        );
1186
1187        builder.update(data);
1188
1189        if let Ok(hasher) = builder.build() {
1190            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
1191        }
1192
1193        let build_hasher = Murmur3HashState::default();
1194        let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
1195        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
1196
1197        malwaredb_api::SimilarSamplesRequest { hashes }
1198    }
1199
1200    async fn pg_config() -> Postgres {
1201        // create user malwaredbtesting with password 'malwaredbtesting';
1202        // create database malwaredbtesting owner malwaredbtesting;
1203        const CONNECTION_STRING: &str =
1204            "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
1205
1206        if let Ok(pg_port) = std::env::var("PG_PORT") {
1207            // Get the port number to run in Github CI
1208            let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
1209            Postgres::new(&conn_string, None)
1210                .await
1211                .context(format!(
1212                    "failed to connect to postgres with specified port {pg_port}"
1213                ))
1214                .unwrap()
1215        } else {
1216            Postgres::new(CONNECTION_STRING, None).await.unwrap()
1217        }
1218    }
1219
    /// End-to-end test of the Postgres backend against the database described in `pg_config`.
    ///
    /// Cleans the database, runs the shared `everything` scenario, optionally exercises the
    /// VirusTotal updater (when the `vt` feature is enabled), then cleans up again. It is marked
    /// `#[ignore]`, so it only runs when ignored tests are requested explicitly.
    #[tokio::test]
    #[ignore = "don't run this in CI"]
    async fn pg() {
        // Start from a clean slate: `delete` presumably wipes data left by a previous run (confirm).
        let psql = pg_config().await;
        psql.delete().await.unwrap();

        // Open a fresh connection for the actual scenario.
        let psql = pg_config().await;
        let db = DatabaseType::Postgres(psql);
        everything(&db).await.unwrap();

        #[cfg(feature = "vt")]
        {
            // Build a minimal server `State` around the database so a `VtUpdater` can be
            // constructed from it; note this moves `db` into the state.
            let db_config = db.get_config().await.unwrap();
            let state = crate::State {
                port: 8080,
                directory: None,
                max_upload: 10 * 1024 * 1024,
                ip: "127.0.0.1".parse().unwrap(),
                db_type: Arc::new(db),
                db_config,
                keys: HashMap::new(),
                started: SystemTime::now(),
                // Only `Some` when the `VT_API_KEY` environment variable is set.
                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
                }),
                tls_config: None,
                mdns: None,
            };

            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");

            vt.updater().await.unwrap();
            println!("PG: Did VT ops!");

            // `db` was moved into `state`, so use a separate connection to read the VT stats.
            let psql = pg_config().await;

            let vt_stats = psql
                .get_vt_stats()
                .await
                .context("failed to get Postgres VT Stats")
                .unwrap();
            println!("{vt_stats:?}");
            // More than two files are expected across the three VT status counters.
            assert!(
                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
            );
        }

        // Re-create the Postgres object so we can do some clean-up
        let psql = pg_config().await;
        psql.delete().await.unwrap();
    }
1271
    /// End-to-end test of the `SQLite` backend using a throwaway database file.
    ///
    /// Runs the shared `everything` scenario, optionally exercises the VirusTotal updater (when
    /// the `vt` feature is enabled), and deletes the database file at the end.
    #[tokio::test]
    async fn sqlite() {
        // Relative path, so the file is created in the current working directory.
        const DB_FILE: &str = "testing_sqlite.db";
        // A file left behind by an earlier run (e.g. one that panicked before cleanup) is
        // removed first so the scenario starts from an empty database.
        if std::path::Path::new(DB_FILE).exists() {
            fs::remove_file(DB_FILE)
                .context(format!("failed to delete old SQLite file {DB_FILE}"))
                .unwrap();
        }

        let sqlite = Sqlite::new(DB_FILE)
            .context(format!("failed to create SQLite instance for {DB_FILE}"))
            .unwrap();

        let db = DatabaseType::SQLite(sqlite);
        everything(&db).await.unwrap();

        #[cfg(feature = "vt")]
        {
            // Build a minimal server `State` around the database so a `VtUpdater` can be
            // constructed from it; note this moves `db` into the state.
            let db_config = db.get_config().await.unwrap();
            let state = crate::State {
                port: 8080,
                directory: None,
                max_upload: 10 * 1024 * 1024,
                ip: "127.0.0.1".parse().unwrap(),
                db_type: Arc::new(db),
                db_config,
                keys: HashMap::new(),
                started: SystemTime::now(),
                // Only `Some` when the `VT_API_KEY` environment variable is set.
                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
                }),
                tls_config: None,
                mdns: None,
            };

            // `db` now lives inside `state`, so open a second handle on the same file to read
            // the VT statistics afterwards.
            let sqlite_second = Sqlite::new(DB_FILE)
                .context(format!("failed to create SQLite instance for {DB_FILE}"))
                .unwrap();

            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");

            vt.updater().await.unwrap();
            println!("Sqlite: Did VT ops!");
            let vt_stats = sqlite_second
                .get_vt_stats()
                .context("failed to get Sqlite VT Stats")
                .unwrap();
            println!("{vt_stats:?}");
            // More than two files are expected across the three VT status counters.
            assert!(
                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
            );
        }

        // Clean up the database file created by this test.
        fs::remove_file(DB_FILE)
            .context(format!("failed to delete SQLite file {DB_FILE}"))
            .unwrap();
    }
1329
1330    #[allow(clippy::too_many_lines)]
1331    async fn everything(db: &DatabaseType) -> Result<()> {
1332        const ADMIN_UNAME: &str = "admin";
1333        const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
1334
1335        db.set_name("Testing Database")
1336            .await
1337            .context("setting instance name failed")?;
1338
1339        assert!(
1340            db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
1341            "Authentication without password should have failed."
1342        );
1343
1344        db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1345            .await
1346            .context("failed to set admin password")?;
1347
1348        let admin_api_key = db
1349            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1350            .await
1351            .context("unable to get api key for admin")?;
1352        println!("API key: {admin_api_key}");
1353        assert_eq!(admin_api_key.len(), 64);
1354
1355        assert_eq!(
1356            db.get_uid(&admin_api_key).await?,
1357            0,
1358            "Unable to get UID given the API key"
1359        );
1360
1361        let admin_api_key_again = db
1362            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1363            .await
1364            .context("unable to get api key a second time for admin")?;
1365
1366        assert_eq!(
1367            admin_api_key, admin_api_key_again,
1368            "API keys didn't match the second time."
1369        );
1370
1371        let bad_password = "this_is_totally_not_my_password!!";
1372        eprintln!("Testing API login with incorrect password.");
1373        assert!(
1374            db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
1375            "Authenticating as admin with a bad password should have failed."
1376        );
1377
1378        let admin_is_admin = db
1379            .user_is_admin(0)
1380            .await
1381            .context("unable to see if admin (uid 0) is an admin")?;
1382        assert!(admin_is_admin);
1383
1384        let new_user_uname = "testuser";
1385        let new_user_email = "test@example.com";
1386        let new_user_password = "some_awesome_password_++";
1387        let new_id = db
1388            .create_user(
1389                new_user_uname,
1390                new_user_uname,
1391                new_user_uname,
1392                new_user_email,
1393                Some(new_user_password.into()),
1394                None,
1395                false,
1396            )
1397            .await
1398            .context(format!("failed to create user {new_user_uname}"))?;
1399
1400        let passwordless_user_id = db
1401            .create_user(
1402                "passwordless_user",
1403                "passwordless_user",
1404                "passwordless_user",
1405                "passwordless_user@example.com",
1406                None,
1407                None,
1408                false,
1409            )
1410            .await
1411            .context("failed to create passwordless_user")?;
1412
1413        for user in &db.list_users().await.context("failed to list users")? {
1414            if user.id == passwordless_user_id {
1415                assert_eq!(user.uname, "passwordless_user");
1416            }
1417        }
1418
1419        db.edit_user(
1420            passwordless_user_id,
1421            "passwordless_user_2",
1422            "passwordless_user_2",
1423            "passwordless_user_2",
1424            "passwordless_user_2@something.com",
1425            false,
1426        )
1427        .await
1428        .context(format!(
1429            "failed to alter 'passwordless' user, id {passwordless_user_id}"
1430        ))?;
1431
1432        for user in &db.list_users().await.context("failed to list users")? {
1433            if user.id == passwordless_user_id {
1434                assert_eq!(user.uname, "passwordless_user_2");
1435            }
1436        }
1437
1438        assert!(
1439            new_id > 0,
1440            "Weird UID created for user {new_user_uname}: {new_id}"
1441        );
1442
1443        assert!(
1444            db.create_user(
1445                new_user_uname,
1446                new_user_uname,
1447                new_user_uname,
1448                new_user_email,
1449                Some(new_user_password.into()),
1450                None,
1451                false
1452            )
1453            .await
1454            .is_err(),
1455            "Creating a new user with the same user name should fail"
1456        );
1457
1458        let ro_user_name = "ro_user";
1459        let ro_user_password = "ro_user_password";
1460        db.create_user(
1461            ro_user_name,
1462            "ro_user",
1463            "ro_user",
1464            "ro@example.com",
1465            Some(ro_user_password.into()),
1466            None,
1467            true,
1468        )
1469        .await
1470        .context("failed to create read-only user")?;
1471
1472        let ro_user_api_key = db
1473            .authenticate(ro_user_name, ro_user_password)
1474            .await
1475            .context("unable to get api key for read-only user")?;
1476
1477        let new_user_password_change = "some_new_awesomer_password!_++";
1478        db.set_password(new_user_uname, new_user_password_change)
1479            .await
1480            .context("failed to change the password for testuser")?;
1481
1482        let new_user_api_key = db
1483            .authenticate(new_user_uname, new_user_password_change)
1484            .await
1485            .context("unable to get api key for testuser")?;
1486        eprintln!("{new_user_uname} got API key {new_user_api_key}");
1487
1488        assert_eq!(admin_api_key.len(), new_user_api_key.len());
1489
1490        let users = db.list_users().await.context("failed to list users")?;
1491        assert_eq!(
1492            users.len(),
1493            4,
1494            "Four users were created, yet there are {} users",
1495            users.len()
1496        );
1497        eprintln!("DB has {} users:", users.len());
1498        let mut passwordless_user_found = false;
1499        for user in users {
1500            println!("{user}");
1501            if user.uname == "passwordless_user_2" {
1502                assert!(!user.has_api_key);
1503                assert!(!user.has_password);
1504                passwordless_user_found = true;
1505            } else {
1506                assert!(user.has_api_key);
1507                assert!(user.has_password);
1508            }
1509        }
1510        assert!(passwordless_user_found);
1511
1512        let new_group_name = "some_new_group";
1513        let new_group_desc = "some_new_group_description";
1514        let new_group_id = 1;
1515        assert_eq!(
1516            db.create_group(new_group_name, new_group_desc, None)
1517                .await
1518                .context("failed to create group")?,
1519            new_group_id,
1520            "New group didn't have the expected ID, expected {new_group_id}"
1521        );
1522
1523        assert!(
1524            db.create_group(new_group_name, new_group_desc, None)
1525                .await
1526                .is_err(),
1527            "Duplicate group name should have failed"
1528        );
1529
1530        db.add_user_to_group(1, 1)
1531            .await
1532            .context("Unable to add uid 1 to gid 1")?;
1533
1534        let ro_user_uid = db
1535            .get_uid(&ro_user_api_key)
1536            .await
1537            .context("Unable to get UID for read-only user")?;
1538        db.add_user_to_group(ro_user_uid, 1)
1539            .await
1540            .context("Unable to add uid 2 to gid 1")?;
1541
1542        let new_admin_group_name = "admin_subgroup";
1543        let new_admin_group_desc = "admin_subgroup_description";
1544        let new_admin_group_id = 2;
1545        // TODO: Figure out why SQLite makes the group_id = 2, but with Postgres it's 3.
1546        assert!(
1547            db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
1548                .await
1549                .context("failed to create admin sub-group")?
1550                >= new_admin_group_id,
1551            "New group didn't have the expected ID, expected >= {new_admin_group_id}"
1552        );
1553
1554        let groups = db.list_groups().await.context("failed to list groups")?;
1555        assert_eq!(
1556            groups.len(),
1557            3,
1558            "Three groups were created, yet there are {} groups",
1559            groups.len()
1560        );
1561        eprintln!("DB has {} groups:", groups.len());
1562        for group in groups {
1563            println!("{group}");
1564            if group.id == new_admin_group_id {
1565                assert_eq!(group.parent, Some("admin".to_string()));
1566            }
1567            if group.id == 1 {
1568                let test_user_str = String::from(new_user_uname);
1569                let mut found = false;
1570                for member in group.members {
1571                    if member.uname == test_user_str {
1572                        found = true;
1573                        break;
1574                    }
1575                }
1576                assert!(found, "new user {test_user_str} wasn't in the group");
1577            }
1578        }
1579
1580        let default_source_name = "default_source".to_string();
1581        let default_source_id = db
1582            .create_source(
1583                &default_source_name,
1584                Some("desc_default_source"),
1585                None,
1586                Local::now(),
1587                true,
1588                Some(false),
1589            )
1590            .await
1591            .context("failed to create source `default_source`")?;
1592
1593        db.add_group_to_source(1, default_source_id)
1594            .await
1595            .context("failed to add group 1 to source 1")?;
1596
1597        let another_source_name = "another_source".to_string();
1598        let another_source_id = db
1599            .create_source(
1600                &another_source_name,
1601                Some("yet another file source"),
1602                None,
1603                Local::now(),
1604                true,
1605                Some(false),
1606            )
1607            .await
1608            .context("failed to create source `another_source`")?;
1609
1610        let empty_source_name = "empty_source".to_string();
1611        db.create_source(
1612            &empty_source_name,
1613            Some("empty and unused file source"),
1614            None,
1615            Local::now(),
1616            true,
1617            Some(false),
1618        )
1619        .await
1620        .context("failed to create source `another_source`")?;
1621
1622        db.add_group_to_source(1, another_source_id)
1623            .await
1624            .context("failed to add group 1 to source 1")?;
1625
1626        let sources = db.list_sources().await.context("failed to list sources")?;
1627        eprintln!("DB has {} sources:", sources.len());
1628        for source in sources {
1629            println!("{source}");
1630            assert_eq!(source.files, 0);
1631            if source.id == default_source_id || source.id == another_source_id {
1632                assert_eq!(
1633                    source.groups, 1,
1634                    "default source {default_source_name} should have 1 group"
1635                );
1636            } else {
1637                assert_eq!(source.groups, 0, "groups should zero (empty)");
1638            }
1639        }
1640
1641        let uid = db
1642            .get_uid(&new_user_api_key)
1643            .await
1644            .context("failed to user uid from apikey")?;
1645        let user_info = db
1646            .get_user_info(uid)
1647            .await
1648            .context("failed to get user's available groups and sources")?;
1649        assert!(user_info.sources.contains(&default_source_name));
1650        assert!(!user_info.is_admin);
1651        println!("UserInfoResponse: {user_info:?}");
1652
1653        assert!(
1654            db.allowed_user_source(1, default_source_id)
1655                .await
1656                .context(format!(
1657                    "failed to check that user 1 has access to source {default_source_id}"
1658                ))?,
1659            "User 1 should should have had access to source {default_source_id}"
1660        );
1661
1662        assert!(
1663            !db.allowed_user_source(1, 5)
1664                .await
1665                .context("failed to check that user 1 has access to source 5")?,
1666            "User 1 should should not have had access to source 5"
1667        );
1668
1669        let test_label_id = db
1670            .create_label("TestLabel", None)
1671            .await
1672            .context("failed to create test label")?;
1673        let test_elf_label_id = db
1674            .create_label("TestELF", Some(test_label_id))
1675            .await
1676            .context("failed to create test label")?;
1677
1678        let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
1679        let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
1680        let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
1681
1682        let known_type =
1683            KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
1684        assert!(known_type.is_exec(), "ELF should be executable");
1685        eprintln!("ELF type ID: {elf_type}");
1686
1687        let file_addition = db
1688            .add_file(
1689                &test_elf_meta,
1690                known_type.clone(),
1691                1,
1692                default_source_id,
1693                elf_type,
1694                None,
1695            )
1696            .await
1697            .context("failed to insert a test elf")?;
1698        assert!(file_addition.is_new, "File should have been added");
1699        eprintln!("Added ELF to the DB");
1700        db.label_file(file_addition.file_id, test_elf_label_id)
1701            .await
1702            .context("failed to label file")?;
1703
1704        // Search with several fields
1705        let partial_search = SearchRequest {
1706            search: SearchType::Search(SearchRequestParameters {
1707                partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1708                labels: Some(vec![String::from("TestELF")]),
1709                file_type: Some(String::from("ELF")),
1710                magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1711                ..Default::default()
1712            }),
1713        };
1714        assert!(partial_search.is_valid());
1715        let partial_search_response = db.partial_search(1, partial_search).await?;
1716        assert_eq!(partial_search_response.hashes.len(), 1);
1717        assert_eq!(
1718            partial_search_response.hashes[0],
1719            "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1720        );
1721
1722        // Ensure we get a partial result just with the magic
1723        let partial_search = SearchRequest {
1724            search: SearchType::Search(SearchRequestParameters {
1725                partial_hash: None,
1726                labels: None,
1727                file_type: None,
1728                magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1729                ..Default::default()
1730            }),
1731        };
1732        assert!(partial_search.is_valid());
1733        let partial_search_response = db.partial_search(1, partial_search).await?;
1734        assert_eq!(partial_search_response.hashes.len(), 1);
1735        assert_eq!(
1736            partial_search_response.hashes[0],
1737            "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1738        );
1739
1740        // Should be valid yet return nothing since it's the wrong file type
1741        let partial_search = SearchRequest {
1742            search: SearchType::Search(SearchRequestParameters {
1743                partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1744                file_type: Some(String::from("PE32")),
1745                ..Default::default()
1746            }),
1747        };
1748        assert!(partial_search.is_valid());
1749        let partial_search_response = db.partial_search(1, partial_search).await?;
1750        assert_eq!(partial_search_response.hashes.len(), 0);
1751
1752        let partial_search = SearchRequest {
1753            search: SearchType::Search(SearchRequestParameters {
1754                file_name: Some("ppc64".into()),
1755                ..Default::default()
1756            }),
1757        };
1758        assert!(partial_search.is_valid());
1759        let partial_search_response = db.partial_search(1, partial_search).await?;
1760        assert_eq!(partial_search_response.hashes.len(), 1);
1761
1762        // Invalid search request should return empty results
1763        let partial_search = SearchRequest {
1764            search: SearchType::Search(SearchRequestParameters::default()),
1765        };
1766        assert!(!partial_search.is_valid());
1767        let partial_search_response = db.partial_search(1, partial_search).await?;
1768        assert!(partial_search_response.hashes.is_empty());
1769
1770        // Invalid search pagination should return empty results
1771        let partial_search = SearchRequest {
1772            search: SearchType::Continuation(Uuid::default()),
1773        };
1774        assert!(partial_search.is_valid());
1775        let partial_search_response = db.partial_search(1, partial_search).await?;
1776        assert!(partial_search_response.hashes.is_empty());
1777
1778        // Ensure a type representing an unknown type isn't found
1779        assert!(db
1780            .get_type_id_for_bytes(include_bytes!("../../../../MDB_Logo.ico"))
1781            .await
1782            .is_err());
1783
1784        assert!(
1785            db.add_file(
1786                &test_elf_meta,
1787                known_type.clone(),
1788                ro_user_uid,
1789                default_source_id,
1790                elf_type,
1791                None
1792            )
1793            .await
1794            .is_err(),
1795            "Read-only user should not be able to add a file"
1796        );
1797
1798        let mut test_elf_meta_different_name = test_elf_meta.clone();
1799        test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
1800
1801        assert!(
1802            !db.add_file(
1803                &test_elf_meta_different_name,
1804                known_type,
1805                1,
1806                another_source_id,
1807                elf_type,
1808                None
1809            )
1810            .await
1811            .context("failed to insert a test elf again for a different source")?
1812            .is_new
1813        );
1814
1815        let sources = db
1816            .list_sources()
1817            .await
1818            .context("failed to re-list sources")?;
1819        eprintln!(
1820            "DB has {} sources, and a file was added twice:",
1821            sources.len()
1822        );
1823        println!("We should have two sources with one file each, yet only one ELF.");
1824        for source in sources {
1825            println!("{source}");
1826            if source.id == default_source_id || source.id == another_source_id {
1827                assert_eq!(source.files, 1);
1828            } else {
1829                assert_eq!(source.files, 0, "groups should zero (empty)");
1830            }
1831        }
1832
1833        assert!(!db
1834            .get_user_sources(1)
1835            .await
1836            .expect("failed to get user 1's sources")
1837            .sources
1838            .is_empty());
1839
1840        let file_types_counts = db
1841            .file_types_counts()
1842            .await
1843            .context("failed to get file types and counts")?;
1844        for (name, count) in file_types_counts {
1845            println!("{name}: {count}");
1846            assert_eq!(name, "ELF");
1847            assert_eq!(count, 1);
1848        }
1849
1850        let mut test_elf_modified = test_elf.clone();
1851        let random_bytes = Uuid::new_v4();
1852        let mut random_bytes = random_bytes.into_bytes().to_vec();
1853        test_elf_modified.append(&mut random_bytes);
1854        let similarity_request = generate_similarity_request(&test_elf_modified);
1855        let similarity_response = db
1856            .find_similar_samples(1, &similarity_request.hashes)
1857            .await
1858            .context("failed to get similarity response")?;
1859        eprintln!("Similarity response: {similarity_response:?}");
1860        let similarity_response = similarity_response.first().unwrap();
1861        assert_eq!(
1862            similarity_response.sha256,
1863            hex::encode(&test_elf_meta.sha256),
1864            "Similarity response should have had the hash of the original ELF"
1865        );
1866        for (algo, sim) in &similarity_response.algorithms {
1867            match algo {
1868                malwaredb_api::SimilarityHashType::LZJD => {
1869                    assert!(*sim > 0.0f32);
1870                }
1871                malwaredb_api::SimilarityHashType::SSDeep => {
1872                    assert!(*sim > 80.0f32);
1873                }
1874                malwaredb_api::SimilarityHashType::TLSH => {
1875                    assert!(*sim <= 20f32);
1876                }
1877                _ => {}
1878            }
1879        }
1880
1881        let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1.as_slice())
1882            .context("failed to get `HashType::SHA1` from string")?;
1883        let response_sha256 = db
1884            .retrieve_sample(1, &test_elf_hashtype)
1885            .await
1886            .context("could not get SHA-256 hash from test sample")
1887            .unwrap();
1888        assert_eq!(response_sha256, hex::encode(&test_elf_meta.sha256));
1889
1890        let test_bogus_hash =
1891            HashType::try_from("d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0")
1892                .context("failed to get `HashType` from static string")?;
1893        assert!(
1894            db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
1895            "Getting a file with a bogus hash should have failed."
1896        );
1897
1898        let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
1899        let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
1900        let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
1901
1902        let known_type =
1903            KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
1904
1905        assert!(
1906            db.add_file(
1907                &test_pdf_meta,
1908                known_type,
1909                1,
1910                default_source_id,
1911                pdf_type,
1912                None
1913            )
1914            .await
1915            .context("failed to insert a test pdf")?
1916            .is_new
1917        );
1918        eprintln!("Added PDF to the DB");
1919
1920        let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
1921        let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
1922        let rtf_type = db
1923            .get_type_id_for_bytes(&test_rtf)
1924            .await
1925            .context("failed to get file type id for rtf")?;
1926
1927        let known_type =
1928            KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
1929
1930        assert!(
1931            db.add_file(
1932                &test_rtf_meta,
1933                known_type,
1934                1,
1935                default_source_id,
1936                rtf_type,
1937                None
1938            )
1939            .await
1940            .context("failed to insert a test rtf")?
1941            .is_new
1942        );
1943        eprintln!("Added RTF to the DB");
1944
1945        let report = db
1946            .get_sample_report(
1947                1,
1948                &HashType::try_from(test_rtf_meta.sha256.as_slice()).unwrap(),
1949            )
1950            .await
1951            .context("failed to get report for test rtf")?;
1952        assert!(report
1953            .clone()
1954            .filecommand
1955            .unwrap()
1956            .contains("Rich Text Format"));
1957        println!("Report: {report}");
1958
1959        assert!(db
1960            .get_sample_report(
1961                999,
1962                &HashType::try_from(test_rtf_meta.sha256.as_slice()).unwrap()
1963            )
1964            .await
1965            .is_err());
1966
1967        #[cfg(feature = "vt")]
1968        {
1969            assert!(report.vt.is_some());
1970            let files_needing_vt = db
1971                .files_without_vt_records(10)
1972                .await
1973                .context("failed to get files without VT records")?;
1974            assert!(files_needing_vt.len() > 2);
1975            println!(
1976                "{} files needing VT data: {files_needing_vt:?}",
1977                files_needing_vt.len()
1978            );
1979        }
1980
1981        #[cfg(not(feature = "vt"))]
1982        {
1983            assert!(report.vt.is_none());
1984        }
1985
1986        let reset = db
1987            .reset_api_keys()
1988            .await
1989            .context("failed to reset all API keys")?;
1990        eprintln!("Cleared {reset} api keys.");
1991
1992        let db_info = db.db_info().await.context("failed to get database info")?;
1993        eprintln!("DB Info: {db_info:?}");
1994
1995        let data_types = db
1996            .get_known_data_types()
1997            .await
1998            .context("failed to get data types")?;
1999        for data_type in data_types {
2000            println!("{data_type:?}");
2001        }
2002
2003        let sources = db
2004            .list_sources()
2005            .await
2006            .context("failed to list sources second time")?;
2007        eprintln!("DB has {} sources:", sources.len());
2008        for source in sources {
2009            println!("{source}");
2010        }
2011
2012        let file_types_counts = db
2013            .file_types_counts()
2014            .await
2015            .context("failed to get file types and counts")?;
2016        for (name, count) in file_types_counts {
2017            println!("{name}: {count}");
2018            assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
2019        }
2020
2021        let fatmacho =
2022            include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
2023                .to_vec();
2024        let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
2025        let fatmacho_type = db
2026            .get_type_id_for_bytes(&fatmacho)
2027            .await
2028            .context("failed to get file type for Fat Mach-O")?;
2029        let known_type = KnownType::new(&fatmacho)
2030            .context("failed to parse Fat Mach-O from type crate's test data")?;
2031
2032        assert!(
2033            db.add_file(
2034                &fatmacho_meta,
2035                known_type,
2036                1,
2037                default_source_id,
2038                fatmacho_type,
2039                None
2040            )
2041            .await
2042            .context("failed to insert a test Fat Mach-O")?
2043            .is_new
2044        );
2045        eprintln!("Added Fat Mach-O to the DB");
2046
2047        let file_types_counts = db
2048            .file_types_counts()
2049            .await
2050            .context("failed to get file types and counts")?;
2051        for (name, count) in &file_types_counts {
2052            println!("{name}: {count}");
2053        }
2054
2055        assert_eq!(
2056            *file_types_counts.get("Mach-O").unwrap(),
2057            4,
2058            "Expected 4 Mach-O files, got {:?}",
2059            file_types_counts.get("Mach-O")
2060        );
2061
2062        let allowed_files = db
2063            .user_allowed_files_by_sha256(1, None)
2064            .await
2065            .context("failed to get allowed files")?;
2066        assert_eq!(allowed_files.0.len(), 8);
2067
2068        let allowed_files = db
2069            .user_allowed_files_by_sha256(1, Some(allowed_files.1))
2070            .await
2071            .context("failed to get allowed files")?;
2072        assert!(allowed_files.0.is_empty());
2073
2074        let malware_label_id = db
2075            .create_label(MALWARE_LABEL, None)
2076            .await
2077            .context("failed to create first label")?;
2078        let ransomware_label_id = db
2079            .create_label(RANSOMWARE_LABEL, Some(malware_label_id))
2080            .await
2081            .context("failed to create malware sub-label")?;
2082        let labels = db.get_labels().await.context("failed to get labels")?;
2083
2084        assert_eq!(labels.len(), 4, "Expected 4 labels, got {labels}");
2085        for label in labels.0 {
2086            if label.name == RANSOMWARE_LABEL {
2087                assert_eq!(label.id, ransomware_label_id);
2088                assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
2089            }
2090        }
2091
        // Use this file as a stand-in for an unknown file type
2093        let source_code = include_bytes!("mod.rs");
2094        let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
2095        let known_type =
2096            KnownType::new(source_code).context("failed to source code to get `Unknown` type")?;
2097
2098        assert!(matches!(known_type, KnownType::Unknown(_)));
2099
2100        let unknown_type: Vec<FileType> = db
2101            .get_known_data_types()
2102            .await?
2103            .into_iter()
2104            .filter(|t| t.name.eq_ignore_ascii_case("unknown"))
2105            .collect();
2106        let unknown_type_id = unknown_type.first().unwrap().id;
2107        assert!(db.get_type_id_for_bytes(source_code).await.is_err());
2108        db.enable_keep_unknown_files()
2109            .await
2110            .context("failed to enable keeping of unknown files")?;
2111        let source_type = db
2112            .get_type_id_for_bytes(source_code)
2113            .await
2114            .context("failed to type id for source code unknown type example")?;
2115        assert_eq!(source_type, unknown_type_id);
2116        eprintln!("Unknown file type ID: {source_type}");
2117        assert!(
2118            db.add_file(
2119                &source_meta,
2120                known_type,
2121                1,
2122                default_source_id,
2123                unknown_type_id,
2124                None
2125            )
2126            .await
2127            .context("failed to add Rust source code file")?
2128            .is_new
2129        );
2130        eprintln!("Added Rust source code to the DB");
2131
2132        #[cfg(feature = "yara")]
2133        assert!(db.get_unfinished_yara_tasks().await?.is_empty());
2134
2135        db.reset_own_api_key(0)
2136            .await
2137            .context("failed to clear own API key uid 0")?;
2138
2139        db.deactivate_user(0)
2140            .await
2141            .context("failed to clear password and API key for uid 0")?;
2142
2143        Ok(())
2144    }
2145}