Skip to main content

malwaredb_server/db/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Postgres is the database used by MalwareDB.
//! However, SQLite is used for unit testing and may be used for small instances of MalwareDB.
//! SQLite support is enabled with the `sqlite` feature flag. When using SQLite, MalwareDB will
//! calculate the distances for the similarity hashes.
7
8/// Malware DB Administrative functions
9#[cfg(any(test, feature = "admin"))]
10mod admin;
11/// Postgres functions
12mod pg;
13
14/// `SQLite` functionality
15#[cfg(any(test, feature = "sqlite"))]
16mod sqlite;
17
18/// Custom `SQLite` functions
19#[cfg(any(test, feature = "sqlite"))]
20mod sqlite_functions;
21
22/// File Metadata convenience data structure
23pub mod types;
24
25#[cfg(any(test, feature = "admin"))]
26use crate::crypto::EncryptionOption;
27use crate::crypto::FileEncryption;
28use crate::db::pg::Postgres;
29#[cfg(any(test, feature = "sqlite"))]
30use crate::db::sqlite::Sqlite;
31use crate::db::types::{FileMetadata, FileType};
32use malwaredb_api::{
33    digest::HashType, GetUserInfoResponse, Labels, SearchRequest, SearchResponse, Sources,
34};
35use malwaredb_types::KnownType;
36
37use std::collections::HashMap;
38use std::path::PathBuf;
39
40use anyhow::{bail, ensure, Result};
41use argon2::password_hash::{rand_core::OsRng, SaltString};
42use argon2::{Argon2, PasswordHasher};
43#[cfg(any(test, feature = "admin"))]
44use chrono::Local;
45#[cfg(feature = "vt")]
46use malwaredb_virustotal::filereport::ScanResultAttributes;
47
/// Upper bound on the number of results returned for a partial hash and/or partial file name
/// search, so that such searches don't cause performance issues
pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
50
/// Action to take regarding database schema migrations
#[derive(Debug, Copy, Clone)]
pub enum Migration {
    /// At run time: only check whether a migration is needed
    Check,

    /// Admin feature: actually perform the migration
    #[cfg(any(test, feature = "admin"))]
    Migrate,
}
61
62/// Database connection handle
63#[derive(Debug)]
64pub enum DatabaseType {
65    /// Postgres database
66    Postgres(Postgres),
67
68    /// `SQLite` database
69    #[cfg(any(test, feature = "sqlite"))]
70    SQLite(Sqlite),
71}
72
/// Database version details together with a few basic statistics
#[derive(Debug)]
pub struct DatabaseInformation {
    /// The database's version string
    pub version: String,

    /// Size of the database in human-readable form
    pub size: String,

    /// How many file samples Malware DB holds
    pub num_files: u64,

    /// How many user accounts exist
    pub num_users: u32,

    /// How many user groups exist
    pub num_groups: u32,

    /// How many sample sources exist
    pub num_sources: u32,
}
94
/// Data returned when adding a new sample
#[derive(Debug)]
pub struct FileAddedResult {
    /// ID of the file
    pub file_id: u64,

    /// Whether the file was added as a new entry.
    /// This is `false` if the sample was already known to Malware DB.
    pub is_new: bool,
}
104
/// The portion of Malware DB's configuration which lives in the database
#[derive(Debug)]
pub struct MDBConfig {
    /// Name given to this Malware DB instance
    pub name: String,

    /// Whether samples are compressed when stored
    pub compression: bool,

    /// Whether Malware DB is allowed to send samples to Virus Total
    pub send_samples_to_vt: bool,

    /// Whether Malware DB should keep files of unknown type
    pub keep_unknown_files: bool,

    /// The key to use if samples are to be encrypted
    pub(crate) default_key: Option<u32>,
}
123
/// Summary of the Virus Total records held for the files in Malware DB
#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Count of files marked as clean
    pub clean_records: u32,

    /// Count of files marked as malicious
    pub hits_records: u32,

    /// Count of files which have no VT record
    pub files_without_records: u32,
}
138
139impl DatabaseType {
140    /// Get a database connection from a configuration string
141    ///
142    /// # Errors
143    ///
144    /// * If there's a connectivity issue to Postgres, an error will result
145    /// * If the `SQLite` file cannot be created or opened, an error will result
146    /// * If `SQLite` is the type but Malware DB wasn't compiled with the sqlite feature, an error will result
147    /// * If the format or database type isn't known, an error will result
148    pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
149        #[cfg(any(test, feature = "sqlite"))]
150        if arg.starts_with("file:") {
151            let new_conn_str = arg.trim_start_matches("file:");
152            let db = DatabaseType::SQLite(Sqlite::new(new_conn_str)?);
153            db.migrate(Migration::Check).await?;
154            return Ok(db);
155        }
156
157        if arg.starts_with("postgres") {
158            let new_conn_str = arg.trim_start_matches("postgres");
159            let db = DatabaseType::Postgres(Postgres::new(new_conn_str, server_ca).await?);
160            db.migrate(Migration::Check).await?;
161            return Ok(db);
162        }
163
164        bail!("unknown database type `{arg}`")
165    }
166
167    /// Set the flag allowing uploads to Virus Total.
168    ///
169    /// # Errors
170    ///
171    /// If there's a connectivity issue to Postgres, an error will result
172    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
173    #[cfg(feature = "vt")]
174    pub async fn enable_vt_upload(&self) -> Result<()> {
175        match self {
176            DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
177            #[cfg(any(test, feature = "sqlite"))]
178            DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
179        }
180    }
181
182    /// Set the flag preventing uploads to Virus Total.
183    ///
184    /// # Errors
185    ///
186    /// If there's a connectivity issue to Postgres, an error will result
187    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
188    #[cfg(feature = "vt")]
189    pub async fn disable_vt_upload(&self) -> Result<()> {
190        match self {
191            DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
192            #[cfg(any(test, feature = "sqlite"))]
193            DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
194        }
195    }
196
197    /// Get the SHA-256 hashes of the files which don't have VT records
198    ///
199    /// # Errors
200    ///
201    /// If there's a connectivity issue to Postgres, an error will result
202    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
203    #[cfg(feature = "vt")]
204    pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
205        match self {
206            DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
207            #[cfg(any(test, feature = "sqlite"))]
208            DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
209        }
210    }
211
212    /// Store the VT results: AV hits and detailed report, or lack of any AV hits
213    ///
214    /// # Errors
215    ///
216    /// If there's a connectivity issue to Postgres, an error will result
217    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
218    #[cfg(feature = "vt")]
219    pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
220        match self {
221            DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
222            #[cfg(any(test, feature = "sqlite"))]
223            DatabaseType::SQLite(sl) => sl.store_vt_record(results),
224        }
225    }
226
227    /// Quick statistics regarding the data contained for VT information for our samples
228    ///
229    /// # Errors
230    ///
231    /// If there's a connectivity issue to Postgres, an error will result
232    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
233    #[cfg(feature = "vt")]
234    pub async fn get_vt_stats(&self) -> Result<VtStats> {
235        match self {
236            DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
237            #[cfg(any(test, feature = "sqlite"))]
238            DatabaseType::SQLite(sl) => sl.get_vt_stats(),
239        }
240    }
241
242    /// Add the Yara search to the database
243    ///
244    /// # Errors
245    ///
246    /// If there's a connectivity issue to Postgres, an error will result
247    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
248    #[cfg(feature = "yara")]
249    pub async fn add_yara_search(
250        &self,
251        uid: u32,
252        yara_string: &str,
253        yara_bytes: &[u8],
254    ) -> Result<uuid::Uuid> {
255        match self {
256            DatabaseType::Postgres(pg) => pg.add_yara_search(uid, yara_string, yara_bytes).await,
257            #[cfg(any(test, feature = "sqlite"))]
258            DatabaseType::SQLite(sl) => sl.add_yara_search(uid, yara_string, yara_bytes),
259        }
260    }
261
262    /// Get unfinished Yara tasks for processing.
263    ///
264    /// # Errors
265    ///
266    /// If there's a connectivity issue to Postgres, an error will result
267    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
268    #[cfg(feature = "yara")]
269    pub async fn get_unfinished_yara_tasks(&self) -> Result<Vec<crate::yara::YaraTask>> {
270        match self {
271            DatabaseType::Postgres(pg) => pg.get_unfinished_yara_tasks().await,
272            #[cfg(any(test, feature = "sqlite"))]
273            DatabaseType::SQLite(sl) => sl.get_unfinished_yara_tasks(),
274        }
275    }
276
277    /// Add a Yara match to the database
278    ///
279    /// # Errors
280    ///
281    /// If there's a connectivity issue to Postgres, an error will result
282    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
283    #[cfg(feature = "yara")]
284    pub async fn add_yara_match(
285        &self,
286        id: uuid::Uuid,
287        rule_name: &str,
288        file_sha256: &str,
289    ) -> Result<()> {
290        match self {
291            DatabaseType::Postgres(pg) => pg.add_yara_match(id, rule_name, file_sha256).await,
292            #[cfg(any(test, feature = "sqlite"))]
293            DatabaseType::SQLite(sl) => sl.add_yara_match(id, rule_name, file_sha256),
294        }
295    }
296
297    /// Indicate that the Yara search task has finished
298    ///
299    /// # Errors
300    ///
301    /// If there's a connectivity issue to Postgres, an error will result
302    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
303    #[cfg(feature = "yara")]
304    pub async fn mark_yara_task_as_finished(&self, id: uuid::Uuid) -> Result<()> {
305        match self {
306            DatabaseType::Postgres(pg) => pg.mark_yara_task_as_finished(id).await,
307            #[cfg(any(test, feature = "sqlite"))]
308            DatabaseType::SQLite(sl) => sl.mark_yara_task_as_finished(id),
309        }
310    }
311
312    /// Add the last file ID for the next iteration
313    ///
314    /// # Errors
315    ///
316    /// If there's a connectivity issue to Postgres, an error will result
317    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
318    #[cfg(feature = "yara")]
319    pub async fn yara_add_next_file_id(&self, id: uuid::Uuid, file_id: u64) -> Result<()> {
320        match self {
321            DatabaseType::Postgres(pg) => pg.yara_add_next_file_id(id, file_id).await,
322            #[cfg(any(test, feature = "sqlite"))]
323            DatabaseType::SQLite(sl) => sl.yara_add_next_file_id(id, file_id),
324        }
325    }
326
327    /// Get the Yara search results
328    ///
329    /// # Errors
330    ///
331    /// If there's a connectivity issue to Postgres, an error will result
332    #[cfg_attr(docsrs, doc(cfg(feature = "yara")))]
333    #[cfg(feature = "yara")]
334    pub async fn get_yara_results(
335        &self,
336        id: uuid::Uuid,
337        user_id: u32,
338    ) -> Result<malwaredb_api::YaraSearchResponse> {
339        match self {
340            DatabaseType::Postgres(pg) => pg.get_yara_results(id, user_id).await,
341            #[cfg(any(test, feature = "sqlite"))]
342            DatabaseType::SQLite(sl) => sl.get_yara_results(id, user_id),
343        }
344    }
345
346    /// Get the configuration which is stored in the database
347    ///
348    /// # Errors
349    ///
350    /// If there's a connectivity issue to Postgres, an error will result
351    pub async fn get_config(&self) -> Result<MDBConfig> {
352        match self {
353            DatabaseType::Postgres(pg) => pg.get_config().await,
354            #[cfg(any(test, feature = "sqlite"))]
355            DatabaseType::SQLite(sl) => sl.get_config(),
356        }
357    }
358
359    /// Check user credentials, return the API key. Generate if it doesn't exist.
360    ///
361    /// # Errors
362    ///
363    /// * If there's a connectivity issue to Postgres, an error will result
364    /// * If the username and/or password aren't correct, an error will result
365    pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
366        match self {
367            DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
368            #[cfg(any(test, feature = "sqlite"))]
369            DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
370        }
371    }
372
373    /// Get the user's ID from their API key
374    ///
375    /// # Errors
376    ///
377    /// * If there's a connectivity issue to Postgres, an error will result
378    /// * If the api key isn't valid, an error will result
379    pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
380        ensure!(!apikey.is_empty(), "API key was empty");
381        match self {
382            DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
383            #[cfg(any(test, feature = "sqlite"))]
384            DatabaseType::SQLite(sl) => sl.get_uid(apikey),
385        }
386    }
387
388    /// Retrieve information about the database
389    ///
390    /// # Errors
391    ///
392    /// * If there's a connectivity issue to Postgres, an error will result
393    pub async fn db_info(&self) -> Result<DatabaseInformation> {
394        match self {
395            DatabaseType::Postgres(pg) => pg.db_info().await,
396            #[cfg(any(test, feature = "sqlite"))]
397            DatabaseType::SQLite(sl) => sl.db_info(),
398        }
399    }
400
401    /// Retrieve the names of the groups and sources the user is part of and has access to
402    ///
403    /// # Errors
404    ///
405    /// * If there's a connectivity issue to Postgres, an error will result
406    /// * If the user ID isn't valid, an error will result
407    pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
408        match self {
409            DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
410            #[cfg(any(test, feature = "sqlite"))]
411            DatabaseType::SQLite(sl) => sl.get_user_info(uid),
412        }
413    }
414
415    /// Retrieve the source information available to the specified user
416    ///
417    /// # Errors
418    ///
419    /// * If there's a connectivity issue to Postgres, an error will result
420    /// * If the user ID isn't valid, an error will result
421    pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
422        match self {
423            DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
424            #[cfg(any(test, feature = "sqlite"))]
425            DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
426        }
427    }
428
429    /// Let the user clear their own API key to log out from all systems
430    ///
431    /// # Errors
432    ///
433    /// * If there's a connectivity issue to Postgres, an error will result
434    /// * If the user ID isn't valid, an error will result
435    pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
436        match self {
437            DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
438            #[cfg(any(test, feature = "sqlite"))]
439            DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
440        }
441    }
442
443    /// Retrieve the supported data type information
444    ///
445    /// # Errors
446    ///
447    /// If there's a connectivity issue to Postgres, an error will result
448    pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
449        match self {
450            DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
451            #[cfg(any(test, feature = "sqlite"))]
452            DatabaseType::SQLite(sl) => sl.get_known_data_types(),
453        }
454    }
455
456    /// Get all labels from Malware DB
457    ///
458    /// # Errors
459    ///
460    /// If there's a connectivity issue to Postgres, an error will result
461    pub async fn get_labels(&self) -> Result<Labels> {
462        match self {
463            DatabaseType::Postgres(pg) => pg.get_labels().await,
464            #[cfg(any(test, feature = "sqlite"))]
465            DatabaseType::SQLite(sl) => sl.get_labels(),
466        }
467    }
468
469    /// Get the corresponding type ID for a buffer representing a file
470    ///
471    /// # Errors
472    ///
473    /// If there's a connectivity issue to Postgres, an error will result
474    pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
475        match self {
476            DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
477            #[cfg(any(test, feature = "sqlite"))]
478            DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
479        }
480    }
481
482    /// Check that a user has been granted access data for the specific source
483    ///
484    /// # Errors
485    ///
486    /// * If there's a connectivity issue to Postgres, an error will result
487    /// * If the user or source ID(s) aren't valid, an error will result
488    pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
489        match self {
490            DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
491            #[cfg(any(test, feature = "sqlite"))]
492            DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
493        }
494    }
495
496    /// Check to see if the user is an administrator. The user must be a member of the
497    /// admin group (group ID 0), or a one group below (a group with the parent group id of 0).
498    ///
499    /// # Errors
500    ///
501    /// * If there's a connectivity issue to Postgres, an error will result
502    /// * If the user ID isn't valid, an error will result
503    pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
504        match self {
505            DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
506            #[cfg(any(test, feature = "sqlite"))]
507            DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
508        }
509    }
510
511    /// Add a file's metadata to the database, returning true if this is a new entry
512    ///
513    /// # Errors
514    ///
515    /// * If there's a connectivity issue to Postgres, an error will result
516    /// * If the source doesn't exist or the user isn't part of a member group, an error will result
517    pub async fn add_file(
518        &self,
519        meta: &FileMetadata,
520        known_type: KnownType<'_>,
521        uid: u32,
522        sid: u32,
523        ftype: u32,
524        parent: Option<u64>,
525    ) -> Result<FileAddedResult> {
526        match self {
527            DatabaseType::Postgres(pg) => {
528                pg.add_file(meta, known_type, uid, sid, ftype, parent).await
529            }
530            #[cfg(any(test, feature = "sqlite"))]
531            DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
532        }
533    }
534
535    /// Search for allowed samples based on partial search and/or file name
536    ///
537    /// # Errors
538    ///
539    /// * If there's a connectivity issue to Postgres, an error will result
540    pub async fn partial_search(&self, uid: u32, search: SearchRequest) -> Result<SearchResponse> {
541        match self {
542            DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
543            #[cfg(any(test, feature = "sqlite"))]
544            DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
545        }
546    }
547
548    /// Delete old pagination searches
549    ///
550    /// # Errors
551    ///
552    /// An error would occur if the Postgres server couldn't be reached.
553    pub async fn cleanup(&self) -> Result<u64> {
554        match self {
555            DatabaseType::Postgres(pg) => pg.cleanup().await,
556            #[cfg(any(test, feature = "sqlite"))]
557            DatabaseType::SQLite(sl) => sl.cleanup(),
558        }
559    }
560
561    /// Retrieve the SHA-256 hash of the sample while checking that the user is permitted
562    /// to access to it
563    ///
564    /// # Errors
565    ///
566    /// * If there's a connectivity issue to Postgres, an error will result
567    /// * If the file doesn't exist or the user isn't allowed access, an error will result
568    pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
569        match self {
570            DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
571            #[cfg(any(test, feature = "sqlite"))]
572            DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
573        }
574    }
575
576    /// Retrieve a report for a given sample, if allowed.
577    ///
578    /// # Errors
579    ///
580    /// * If there's a connectivity issue to Postgres, an error will result
581    /// * If the file doesn't exist or the user isn't allowed access, an error will result
582    pub async fn get_sample_report(
583        &self,
584        uid: u32,
585        hash: &HashType,
586    ) -> Result<malwaredb_api::Report> {
587        match self {
588            DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
589            #[cfg(any(test, feature = "sqlite"))]
590            DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
591        }
592    }
593
594    /// Given a collection of similarity hashes, find samples which are similar.
595    ///
596    /// # Errors
597    ///
598    /// If there's a connectivity issue to Postgres, an error will result
599    pub async fn find_similar_samples(
600        &self,
601        uid: u32,
602        sim: &[(malwaredb_api::SimilarityHashType, String)],
603    ) -> Result<Vec<malwaredb_api::SimilarSample>> {
604        match self {
605            DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
606            #[cfg(any(test, feature = "sqlite"))]
607            DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
608        }
609    }
610
611    /// For a given user ID, return the file hashes the person is allowed to know about and the last
612    /// file ID, which can be provided as `next` to get the next batch of hashes.
613    ///
614    /// # Errors
615    ///
616    /// If there's a connectivity issue to Postgres, an error will result
617    pub async fn user_allowed_files_by_sha256(
618        &self,
619        uid: u32,
620        next: Option<u64>,
621    ) -> Result<(Vec<String>, u64)> {
622        match self {
623            DatabaseType::Postgres(pg) => pg.user_allowed_files_by_sha256(uid, next).await,
624            #[cfg(any(test, feature = "sqlite"))]
625            DatabaseType::SQLite(sl) => sl.user_allowed_files_by_sha256(uid, next),
626        }
627    }
628
629    // Private functions
630
631    /// Get the file encryption keys
632    pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
633        match self {
634            DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
635            #[cfg(any(test, feature = "sqlite"))]
636            DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
637        }
638    }
639
640    /// Get the key and AES nonce for a specific file identified by SHA-256 hash
641    pub(crate) async fn get_file_encryption_key_id(
642        &self,
643        hash: &str,
644    ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
645        match self {
646            DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
647            #[cfg(any(test, feature = "sqlite"))]
648            DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
649        }
650    }
651
652    /// Set the AES nonce for a file specified by SHA-256 hash, a `None` value removes it.
653    pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
654        match self {
655            DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
656            #[cfg(any(test, feature = "sqlite"))]
657            DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
658        }
659    }
660
661    /// Checks if a migration is needed, erroring if action is to check and the schema has changed.
662    ///
663    /// # Errors
664    ///
665    /// If there's a connectivity issue to Postgres, an error will result; if a migration is needed
666    /// and not an administrative action, an error results.
667    pub async fn migrate(&self, action: Migration) -> Result<()> {
668        match self {
669            DatabaseType::Postgres(pg) => pg.migrate(action).await,
670            #[cfg(any(test, feature = "sqlite"))]
671            DatabaseType::SQLite(sl) => sl.migrate(action),
672        }
673    }
674
675    // Administrative functions
676
677    /// Set the instance name
678    ///
679    /// # Errors
680    ///
681    /// If there's a connectivity issue to Postgres, an error will result
682    #[cfg(any(test, feature = "admin"))]
683    pub async fn set_name(&self, name: &str) -> Result<()> {
684        match self {
685            DatabaseType::Postgres(pg) => pg.set_name(name).await,
686            #[cfg(any(test, feature = "sqlite"))]
687            DatabaseType::SQLite(sl) => sl.set_name(name),
688        }
689    }
690
691    /// Set the compression flag
692    ///
693    /// # Errors
694    ///
695    /// If there's a connectivity issue to Postgres, an error will result
696    #[cfg(any(test, feature = "admin"))]
697    pub async fn enable_compression(&self) -> Result<()> {
698        match self {
699            DatabaseType::Postgres(pg) => pg.enable_compression().await,
700            #[cfg(any(test, feature = "sqlite"))]
701            DatabaseType::SQLite(sl) => sl.enable_compression(),
702        }
703    }
704
705    /// Unset the compression flag, does not go and decompress files already compressed!
706    ///
707    /// # Errors
708    ///
709    /// If there's a connectivity issue to Postgres, an error will result
710    #[cfg(any(test, feature = "admin"))]
711    pub async fn disable_compression(&self) -> Result<()> {
712        match self {
713            DatabaseType::Postgres(pg) => pg.disable_compression().await,
714            #[cfg(any(test, feature = "sqlite"))]
715            DatabaseType::SQLite(sl) => sl.disable_compression(),
716        }
717    }
718
719    /// Set the keep unknown files flag
720    ///
721    /// # Errors
722    ///
723    /// If there's a connectivity issue to Postgres, an error will result
724    #[cfg(any(test, feature = "admin"))]
725    pub async fn enable_keep_unknown_files(&self) -> Result<()> {
726        match self {
727            DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
728            #[cfg(any(test, feature = "sqlite"))]
729            DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
730        }
731    }
732
733    /// Unset the keep unknown files flag, does not go and remove unknown files!
734    ///
735    /// # Errors
736    ///
737    /// If there's a connectivity issue to Postgres, an error will result
738    #[cfg(any(test, feature = "admin"))]
739    pub async fn disable_keep_unknown_files(&self) -> Result<()> {
740        match self {
741            DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
742            #[cfg(any(test, feature = "sqlite"))]
743            DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
744        }
745    }
746
747    /// Add an encryption key to the database, set it as the default, and return the key ID
748    ///
749    /// # Errors
750    ///
751    /// * If there is a connectivity with Postgres, an error will result
752    #[cfg(any(test, feature = "admin"))]
753    pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
754        match self {
755            DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
756            #[cfg(any(test, feature = "sqlite"))]
757            DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
758        }
759    }
760
761    /// Get the key ID and algorithm names
762    ///
763    /// # Errors
764    ///
765    /// If there is a connectivity with Postgres, an error will result
766    #[cfg(any(test, feature = "admin"))]
767    pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
768        match self {
769            DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
770            #[cfg(any(test, feature = "sqlite"))]
771            DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
772        }
773    }
774
775    /// Create a user account, return the user ID.
776    ///
777    /// # Errors
778    ///
779    /// If there's a connectivity issue to Postgres, an error will result
780    #[allow(clippy::too_many_arguments)]
781    #[cfg(any(test, feature = "admin"))]
782    pub async fn create_user(
783        &self,
784        uname: &str,
785        fname: &str,
786        lname: &str,
787        email: &str,
788        password: Option<String>,
789        organisation: Option<&String>,
790        readonly: bool,
791    ) -> Result<u32> {
792        match self {
793            DatabaseType::Postgres(pg) => {
794                pg.create_user(uname, fname, lname, email, password, organisation, readonly)
795                    .await
796            }
797            #[cfg(any(test, feature = "sqlite"))]
798            DatabaseType::SQLite(sl) => {
799                sl.create_user(uname, fname, lname, email, password, organisation, readonly)
800            }
801        }
802    }
803
804    /// Clear all API keys, either in case of suspected activity, or part of policy
805    ///
806    /// # Errors
807    ///
808    /// If there's a connectivity issue to Postgres, an error will result
809    #[cfg(any(test, feature = "admin"))]
810    pub async fn reset_api_keys(&self) -> Result<u64> {
811        match self {
812            DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
813            #[cfg(any(test, feature = "sqlite"))]
814            DatabaseType::SQLite(sl) => sl.reset_api_keys(),
815        }
816    }
817
818    /// Set a user's password
819    ///
820    /// # Errors
821    ///
822    /// * If there's a connectivity issue to Postgres, an error will result
823    /// * If the user doesn't exist, an error will result
824    #[cfg(any(test, feature = "admin"))]
825    pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
826        match self {
827            DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
828            #[cfg(any(test, feature = "sqlite"))]
829            DatabaseType::SQLite(sl) => sl.set_password(uname, password),
830        }
831    }
832
833    /// Get the complete list of users
834    ///
835    /// # Errors
836    ///
837    /// If there's a connectivity issue to Postgres, an error will result
838    #[cfg(any(test, feature = "admin"))]
839    pub async fn list_users(&self) -> Result<Vec<admin::User>> {
840        match self {
841            DatabaseType::Postgres(pg) => pg.list_users().await,
842            #[cfg(any(test, feature = "sqlite"))]
843            DatabaseType::SQLite(sl) => sl.list_users(),
844        }
845    }
846
847    /// Get the ID of a group from name
848    ///
849    /// # Errors
850    ///
851    /// * If there's a connectivity issue to Postgres, an error will result
852    /// * If the group name isn't valid, an error will result
853    #[cfg(any(test, feature = "admin"))]
854    pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
855        match self {
856            DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
857            #[cfg(any(test, feature = "sqlite"))]
858            DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
859        }
860    }
861
862    /// Update the record for a group
863    ///
864    /// # Errors
865    ///
866    /// * If there's a connectivity issue to Postgres, an error will result
867    /// * If the group doesn't exist, an error will result
868    #[cfg(any(test, feature = "admin"))]
869    pub async fn edit_group(
870        &self,
871        gid: u32,
872        name: &str,
873        desc: &str,
874        parent: Option<u32>,
875    ) -> Result<()> {
876        match self {
877            DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
878            #[cfg(any(test, feature = "sqlite"))]
879            DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
880        }
881    }
882
883    /// Get the complete list of groups
884    ///
885    /// # Errors
886    ///
887    /// If there's a connectivity issue to Postgres, an error will result
888    #[cfg(any(test, feature = "admin"))]
889    pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
890        match self {
891            DatabaseType::Postgres(pg) => pg.list_groups().await,
892            #[cfg(any(test, feature = "sqlite"))]
893            DatabaseType::SQLite(sl) => sl.list_groups(),
894        }
895    }
896
897    /// Grant a user membership to a group, both by id.
898    ///
899    /// # Errors
900    ///
901    /// * If there's a connectivity issue to Postgres, an error will result
902    /// * If the user or group doesn't exist, an error will result
903    #[cfg(any(test, feature = "admin"))]
904    pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
905        match self {
906            DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
907            #[cfg(any(test, feature = "sqlite"))]
908            DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
909        }
910    }
911
912    /// Grand a group access to a source, both by id.
913    ///
914    /// # Errors
915    ///
916    /// * If there's a connectivity issue to Postgres, an error will result
917    /// * If the group or source doesn't exist, an error will result
918    #[cfg(any(test, feature = "admin"))]
919    pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
920        match self {
921            DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
922            #[cfg(any(test, feature = "sqlite"))]
923            DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
924        }
925    }
926
927    /// Create a new group, returning the group ID
928    ///
929    /// # Errors
930    ///
931    /// * If there's a connectivity issue to Postgres, an error will result
932    /// * If the group name is already taken, an error will result
933    #[cfg(any(test, feature = "admin"))]
934    pub async fn create_group(
935        &self,
936        name: &str,
937        description: &str,
938        parent: Option<u32>,
939    ) -> Result<u32> {
940        match self {
941            DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
942            #[cfg(any(test, feature = "sqlite"))]
943            DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
944        }
945    }
946
947    /// Get the complete list of sources
948    ///
949    /// # Errors
950    ///
951    /// If there's a connectivity issue to Postgres, an error will result
952    #[cfg(any(test, feature = "admin"))]
953    pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
954        match self {
955            DatabaseType::Postgres(pg) => pg.list_sources().await,
956            #[cfg(any(test, feature = "sqlite"))]
957            DatabaseType::SQLite(sl) => sl.list_sources(),
958        }
959    }
960
961    /// Create a source, returning the source ID
962    ///
963    /// # Errors
964    ///
965    /// * If there's a connectivity issue to Postgres, an error will result
966    /// * If the source already exists, an error will result
967    #[cfg(any(test, feature = "admin"))]
968    pub async fn create_source(
969        &self,
970        name: &str,
971        description: Option<&str>,
972        url: Option<&str>,
973        date: chrono::DateTime<Local>,
974        releasable: bool,
975        malicious: Option<bool>,
976    ) -> Result<u32> {
977        match self {
978            DatabaseType::Postgres(pg) => {
979                pg.create_source(name, description, url, date, releasable, malicious)
980                    .await
981            }
982            #[cfg(any(test, feature = "sqlite"))]
983            DatabaseType::SQLite(sl) => {
984                sl.create_source(name, description, url, date, releasable, malicious)
985            }
986        }
987    }
988
989    /// Edit a user account setting the specified field values, primarily used by the Admin gui
990    ///
991    /// # Errors
992    ///
993    /// * If there's a connectivity issue to Postgres, an error will result
994    /// * If the user isn't valid, an error will result
995    #[cfg(any(test, feature = "admin"))]
996    pub async fn edit_user(
997        &self,
998        uid: u32,
999        uname: &str,
1000        fname: &str,
1001        lname: &str,
1002        email: &str,
1003        readonly: bool,
1004    ) -> Result<()> {
1005        match self {
1006            DatabaseType::Postgres(pg) => {
1007                pg.edit_user(uid, uname, fname, lname, email, readonly)
1008                    .await
1009            }
1010            #[cfg(any(test, feature = "sqlite"))]
1011            DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
1012        }
1013    }
1014
1015    /// Deactivate, but don't delete, a user's account
1016    ///
1017    /// # Errors
1018    ///
1019    /// * If there's a connectivity issue to Postgres, an error will result
1020    /// * If the user ID isn't valid, an error will result
1021    #[cfg(any(test, feature = "admin"))]
1022    pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
1023        match self {
1024            DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
1025            #[cfg(any(test, feature = "sqlite"))]
1026            DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
1027        }
1028    }
1029
1030    /// File types and number of files per type
1031    ///
1032    /// # Errors
1033    ///
1034    /// * If there's a connectivity issue to Postgres, an error will result
1035    #[cfg(any(test, feature = "admin"))]
1036    pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
1037        match self {
1038            DatabaseType::Postgres(pg) => pg.file_types_counts().await,
1039            #[cfg(any(test, feature = "sqlite"))]
1040            DatabaseType::SQLite(sl) => sl.file_types_counts(),
1041        }
1042    }
1043
1044    /// Create a new label, returning the label ID
1045    ///
1046    /// # Errors
1047    ///
1048    /// * If there's a connectivity issue to Postgres, an error will result
1049    /// * If the parent label doesn't exist, an error will result
1050    /// * If the label name already exists, an error will result
1051    #[cfg(any(test, feature = "admin"))]
1052    pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
1053        match self {
1054            DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
1055            #[cfg(any(test, feature = "sqlite"))]
1056            DatabaseType::SQLite(sl) => sl.create_label(name, parent),
1057        }
1058    }
1059
1060    /// Edit a label name or parent
1061    ///
1062    /// # Errors
1063    ///
1064    /// * If there's a connectivity issue to Postgres, an error will result
1065    /// * If the parent label doesn't exist, an error will result
1066    /// * If the label name already exists, an error will result
1067    #[cfg(any(test, feature = "admin"))]
1068    pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
1069        match self {
1070            DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
1071            #[cfg(any(test, feature = "sqlite"))]
1072            DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
1073        }
1074    }
1075
1076    /// Return the ID for a given label name
1077    ///
1078    /// # Errors
1079    ///
1080    /// * If there's a connectivity issue to Postgres, an error will result
1081    /// * If the label ID is invalid, an error will result
1082    #[cfg(any(test, feature = "admin"))]
1083    pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
1084        match self {
1085            DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
1086            #[cfg(any(test, feature = "sqlite"))]
1087            DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
1088        }
1089    }
1090
1091    /// Associate an existing label by its IDs with a file.
1092    ///
1093    /// # Errors
1094    ///
1095    /// * Incorrect IDs will result in an error
1096    /// * If the file already has the given label associated
1097    /// * If there is a network or connection issue with Postgres
1098    #[cfg(any(test, feature = "admin"))]
1099    pub async fn label_file(&self, file_id: u64, label_id: u64) -> Result<()> {
1100        match self {
1101            DatabaseType::Postgres(pg) => pg.label_file(file_id, label_id).await,
1102            #[cfg(any(test, feature = "sqlite"))]
1103            DatabaseType::SQLite(sl) => sl.label_file(file_id, label_id),
1104        }
1105    }
1106}
1107
1108/// Hash a password string with [Argon2]
1109///
1110/// # Errors
1111///
1112/// An error may result if Argon has an error
1113pub fn hash_password(password: &str) -> Result<String> {
1114    let salt = SaltString::generate(&mut OsRng);
1115    let argon2 = Argon2::default();
1116    Ok(argon2
1117        .hash_password(password.as_bytes(), &salt)?
1118        .to_string())
1119}
1120
1121/// Generate a new, random API key
1122#[must_use]
1123pub fn random_bytes_api_key() -> String {
1124    let key1 = uuid::Uuid::new_v4();
1125    let key2 = uuid::Uuid::new_v4();
1126    let key1 = key1.to_string().replace('-', "");
1127    let key2 = key2.to_string().replace('-', "");
1128    format!("{key1}{key2}")
1129}
1130
1131#[cfg(test)]
1132mod tests {
1133    use super::*;
1134    #[cfg(feature = "vt")]
1135    use crate::vt::VtUpdater;
1136
1137    use std::fs;
1138    #[cfg(feature = "vt")]
1139    use std::sync::Arc;
1140    #[cfg(feature = "vt")]
1141    use std::time::SystemTime;
1142
1143    use anyhow::Context;
1144    use fuzzyhash::FuzzyHash;
1145    use malwaredb_api::{PartialHashSearchType, SearchRequestParameters, SearchType};
1146    use malwaredb_lzjd::{LZDict, Murmur3HashState};
1147    use tlsh_fixed::TlshBuilder;
1148    use uuid::Uuid;
1149
    // Label names shared by the tests in this module.
    // NOTE(review): presumably created and queried by `everything` — confirm against its body.
    const MALWARE_LABEL: &str = "malware";
    const RANSOMWARE_LABEL: &str = "ransomware";
1152
1153    fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
1154        let mut hashes = vec![];
1155
1156        hashes.push((
1157            malwaredb_api::SimilarityHashType::SSDeep,
1158            FuzzyHash::new(data).to_string(),
1159        ));
1160
1161        let mut builder = TlshBuilder::new(
1162            tlsh_fixed::BucketKind::Bucket256,
1163            tlsh_fixed::ChecksumKind::ThreeByte,
1164            tlsh_fixed::Version::Version4,
1165        );
1166
1167        builder.update(data);
1168
1169        if let Ok(hasher) = builder.build() {
1170            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
1171        }
1172
1173        let build_hasher = Murmur3HashState::default();
1174        let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
1175        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
1176
1177        malwaredb_api::SimilarSamplesRequest { hashes }
1178    }
1179
1180    async fn pg_config() -> Postgres {
1181        // create user malwaredbtesting with password 'malwaredbtesting';
1182        // create database malwaredbtesting owner malwaredbtesting;
1183        const CONNECTION_STRING: &str =
1184            "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
1185
1186        if let Ok(pg_port) = std::env::var("PG_PORT") {
1187            // Get the port number to run in Github CI
1188            let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
1189            Postgres::new(&conn_string, None)
1190                .await
1191                .context(format!(
1192                    "failed to connect to postgres with specified port {pg_port}"
1193                ))
1194                .unwrap()
1195        } else {
1196            Postgres::new(CONNECTION_STRING, None).await.unwrap()
1197        }
1198    }
1199
1200    #[tokio::test]
1201    #[ignore = "don't run this in CI"]
1202    async fn pg() {
1203        let psql = pg_config().await;
1204        psql.delete().await.unwrap();
1205
1206        let psql = pg_config().await;
1207        let db = DatabaseType::Postgres(psql);
1208        everything(&db).await.unwrap();
1209
1210        #[cfg(feature = "vt")]
1211        {
1212            let db_config = db.get_config().await.unwrap();
1213            let state = crate::State {
1214                port: 8080,
1215                directory: None,
1216                max_upload: 10 * 1024 * 1024,
1217                ip: "127.0.0.1".parse().unwrap(),
1218                db_type: Arc::new(db),
1219                db_config,
1220                keys: HashMap::new(),
1221                started: SystemTime::now(),
1222                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1223                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
1224                }),
1225                tls_config: None,
1226                mdns: None,
1227            };
1228
1229            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1230
1231            vt.updater().await.unwrap();
1232            println!("PG: Did VT ops!");
1233
1234            let psql = pg_config().await;
1235
1236            let vt_stats = psql
1237                .get_vt_stats()
1238                .await
1239                .context("failed to get Postgres VT Stats")
1240                .unwrap();
1241            println!("{vt_stats:?}");
1242            assert!(
1243                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1244            );
1245        }
1246
1247        // Re-create the Postgres object so we can do some clean-up
1248        let psql = pg_config().await;
1249        psql.delete().await.unwrap();
1250    }
1251
1252    #[tokio::test]
1253    async fn sqlite() {
1254        const DB_FILE: &str = "testing_sqlite.db";
1255        if std::path::Path::new(DB_FILE).exists() {
1256            fs::remove_file(DB_FILE)
1257                .context(format!("failed to delete old SQLite file {DB_FILE}"))
1258                .unwrap();
1259        }
1260
1261        let sqlite = Sqlite::new(DB_FILE)
1262            .context(format!("failed to create SQLite instance for {DB_FILE}"))
1263            .unwrap();
1264
1265        let db = DatabaseType::SQLite(sqlite);
1266        everything(&db).await.unwrap();
1267
1268        #[cfg(feature = "vt")]
1269        {
1270            let db_config = db.get_config().await.unwrap();
1271            let state = crate::State {
1272                port: 8080,
1273                directory: None,
1274                max_upload: 10 * 1024 * 1024,
1275                ip: "127.0.0.1".parse().unwrap(),
1276                db_type: Arc::new(db),
1277                db_config,
1278                keys: HashMap::new(),
1279                started: SystemTime::now(),
1280                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1281                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
1282                }),
1283                tls_config: None,
1284                mdns: None,
1285            };
1286
1287            let sqlite_second = Sqlite::new(DB_FILE)
1288                .context(format!("failed to create SQLite instance for {DB_FILE}"))
1289                .unwrap();
1290
1291            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1292
1293            vt.updater().await.unwrap();
1294            println!("Sqlite: Did VT ops!");
1295            let vt_stats = sqlite_second
1296                .get_vt_stats()
1297                .context("failed to get Sqlite VT Stats")
1298                .unwrap();
1299            println!("{vt_stats:?}");
1300            assert!(
1301                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1302            );
1303        }
1304
1305        fs::remove_file(DB_FILE)
1306            .context(format!("failed to delete SQLite file {DB_FILE}"))
1307            .unwrap();
1308    }
1309
1310    #[allow(clippy::too_many_lines)]
1311    async fn everything(db: &DatabaseType) -> Result<()> {
1312        const ADMIN_UNAME: &str = "admin";
1313        const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
1314
1315        db.set_name("Testing Database")
1316            .await
1317            .context("setting instance name failed")?;
1318
1319        assert!(
1320            db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
1321            "Authentication without password should have failed."
1322        );
1323
1324        db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1325            .await
1326            .context("failed to set admin password")?;
1327
1328        let admin_api_key = db
1329            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1330            .await
1331            .context("unable to get api key for admin")?;
1332        println!("API key: {admin_api_key}");
1333        assert_eq!(admin_api_key.len(), 64);
1334
1335        assert_eq!(
1336            db.get_uid(&admin_api_key).await?,
1337            0,
1338            "Unable to get UID given the API key"
1339        );
1340
1341        let admin_api_key_again = db
1342            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1343            .await
1344            .context("unable to get api key a second time for admin")?;
1345
1346        assert_eq!(
1347            admin_api_key, admin_api_key_again,
1348            "API keys didn't match the second time."
1349        );
1350
1351        let bad_password = "this_is_totally_not_my_password!!";
1352        eprintln!("Testing API login with incorrect password.");
1353        assert!(
1354            db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
1355            "Authenticating as admin with a bad password should have failed."
1356        );
1357
1358        let admin_is_admin = db
1359            .user_is_admin(0)
1360            .await
1361            .context("unable to see if admin (uid 0) is an admin")?;
1362        assert!(admin_is_admin);
1363
1364        let new_user_uname = "testuser";
1365        let new_user_email = "test@example.com";
1366        let new_user_password = "some_awesome_password_++";
1367        let new_id = db
1368            .create_user(
1369                new_user_uname,
1370                new_user_uname,
1371                new_user_uname,
1372                new_user_email,
1373                Some(new_user_password.into()),
1374                None,
1375                false,
1376            )
1377            .await
1378            .context(format!("failed to create user {new_user_uname}"))?;
1379
1380        let passwordless_user_id = db
1381            .create_user(
1382                "passwordless_user",
1383                "passwordless_user",
1384                "passwordless_user",
1385                "passwordless_user@example.com",
1386                None,
1387                None,
1388                false,
1389            )
1390            .await
1391            .context("failed to create passwordless_user")?;
1392
1393        for user in &db.list_users().await.context("failed to list users")? {
1394            if user.id == passwordless_user_id {
1395                assert_eq!(user.uname, "passwordless_user");
1396            }
1397        }
1398
1399        db.edit_user(
1400            passwordless_user_id,
1401            "passwordless_user_2",
1402            "passwordless_user_2",
1403            "passwordless_user_2",
1404            "passwordless_user_2@something.com",
1405            false,
1406        )
1407        .await
1408        .context(format!(
1409            "failed to alter 'passwordless' user, id {passwordless_user_id}"
1410        ))?;
1411
1412        for user in &db.list_users().await.context("failed to list users")? {
1413            if user.id == passwordless_user_id {
1414                assert_eq!(user.uname, "passwordless_user_2");
1415            }
1416        }
1417
1418        assert!(
1419            new_id > 0,
1420            "Weird UID created for user {new_user_uname}: {new_id}"
1421        );
1422
1423        assert!(
1424            db.create_user(
1425                new_user_uname,
1426                new_user_uname,
1427                new_user_uname,
1428                new_user_email,
1429                Some(new_user_password.into()),
1430                None,
1431                false
1432            )
1433            .await
1434            .is_err(),
1435            "Creating a new user with the same user name should fail"
1436        );
1437
1438        let ro_user_name = "ro_user";
1439        let ro_user_password = "ro_user_password";
1440        db.create_user(
1441            ro_user_name,
1442            "ro_user",
1443            "ro_user",
1444            "ro@example.com",
1445            Some(ro_user_password.into()),
1446            None,
1447            true,
1448        )
1449        .await
1450        .context("failed to create read-only user")?;
1451
1452        let ro_user_api_key = db
1453            .authenticate(ro_user_name, ro_user_password)
1454            .await
1455            .context("unable to get api key for read-only user")?;
1456
1457        let new_user_password_change = "some_new_awesomer_password!_++";
1458        db.set_password(new_user_uname, new_user_password_change)
1459            .await
1460            .context("failed to change the password for testuser")?;
1461
1462        let new_user_api_key = db
1463            .authenticate(new_user_uname, new_user_password_change)
1464            .await
1465            .context("unable to get api key for testuser")?;
1466        eprintln!("{new_user_uname} got API key {new_user_api_key}");
1467
1468        assert_eq!(admin_api_key.len(), new_user_api_key.len());
1469
1470        let users = db.list_users().await.context("failed to list users")?;
1471        assert_eq!(
1472            users.len(),
1473            4,
1474            "Four users were created, yet there are {} users",
1475            users.len()
1476        );
1477        eprintln!("DB has {} users:", users.len());
1478        let mut passwordless_user_found = false;
1479        for user in users {
1480            println!("{user}");
1481            if user.uname == "passwordless_user_2" {
1482                assert!(!user.has_api_key);
1483                assert!(!user.has_password);
1484                passwordless_user_found = true;
1485            } else {
1486                assert!(user.has_api_key);
1487                assert!(user.has_password);
1488            }
1489        }
1490        assert!(passwordless_user_found);
1491
1492        let new_group_name = "some_new_group";
1493        let new_group_desc = "some_new_group_description";
1494        let new_group_id = 1;
1495        assert_eq!(
1496            db.create_group(new_group_name, new_group_desc, None)
1497                .await
1498                .context("failed to create group")?,
1499            new_group_id,
1500            "New group didn't have the expected ID, expected {new_group_id}"
1501        );
1502
1503        assert!(
1504            db.create_group(new_group_name, new_group_desc, None)
1505                .await
1506                .is_err(),
1507            "Duplicate group name should have failed"
1508        );
1509
1510        db.add_user_to_group(1, 1)
1511            .await
1512            .context("Unable to add uid 1 to gid 1")?;
1513
1514        let ro_user_uid = db
1515            .get_uid(&ro_user_api_key)
1516            .await
1517            .context("Unable to get UID for read-only user")?;
1518        db.add_user_to_group(ro_user_uid, 1)
1519            .await
1520            .context("Unable to add uid 2 to gid 1")?;
1521
1522        let new_admin_group_name = "admin_subgroup";
1523        let new_admin_group_desc = "admin_subgroup_description";
1524        let new_admin_group_id = 2;
1525        // TODO: Figure out why SQLite makes the group_id = 2, but with Postgres it's 3.
1526        assert!(
1527            db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
1528                .await
1529                .context("failed to create admin sub-group")?
1530                >= new_admin_group_id,
1531            "New group didn't have the expected ID, expected >= {new_admin_group_id}"
1532        );
1533
1534        let groups = db.list_groups().await.context("failed to list groups")?;
1535        assert_eq!(
1536            groups.len(),
1537            3,
1538            "Three groups were created, yet there are {} groups",
1539            groups.len()
1540        );
1541        eprintln!("DB has {} groups:", groups.len());
1542        for group in groups {
1543            println!("{group}");
1544            if group.id == new_admin_group_id {
1545                assert_eq!(group.parent, Some("admin".to_string()));
1546            }
1547            if group.id == 1 {
1548                let test_user_str = String::from(new_user_uname);
1549                let mut found = false;
1550                for member in group.members {
1551                    if member.uname == test_user_str {
1552                        found = true;
1553                        break;
1554                    }
1555                }
1556                assert!(found, "new user {test_user_str} wasn't in the group");
1557            }
1558        }
1559
1560        let default_source_name = "default_source".to_string();
1561        let default_source_id = db
1562            .create_source(
1563                &default_source_name,
1564                Some("desc_default_source"),
1565                None,
1566                Local::now(),
1567                true,
1568                Some(false),
1569            )
1570            .await
1571            .context("failed to create source `default_source`")?;
1572
1573        db.add_group_to_source(1, default_source_id)
1574            .await
1575            .context("failed to add group 1 to source 1")?;
1576
1577        let another_source_name = "another_source".to_string();
1578        let another_source_id = db
1579            .create_source(
1580                &another_source_name,
1581                Some("yet another file source"),
1582                None,
1583                Local::now(),
1584                true,
1585                Some(false),
1586            )
1587            .await
1588            .context("failed to create source `another_source`")?;
1589
1590        let empty_source_name = "empty_source".to_string();
1591        db.create_source(
1592            &empty_source_name,
1593            Some("empty and unused file source"),
1594            None,
1595            Local::now(),
1596            true,
1597            Some(false),
1598        )
1599        .await
1600        .context("failed to create source `another_source`")?;
1601
1602        db.add_group_to_source(1, another_source_id)
1603            .await
1604            .context("failed to add group 1 to source 1")?;
1605
1606        let sources = db.list_sources().await.context("failed to list sources")?;
1607        eprintln!("DB has {} sources:", sources.len());
1608        for source in sources {
1609            println!("{source}");
1610            assert_eq!(source.files, 0);
1611            if source.id == default_source_id || source.id == another_source_id {
1612                assert_eq!(
1613                    source.groups, 1,
1614                    "default source {default_source_name} should have 1 group"
1615                );
1616            } else {
1617                assert_eq!(source.groups, 0, "groups should zero (empty)");
1618            }
1619        }
1620
1621        let uid = db
1622            .get_uid(&new_user_api_key)
1623            .await
1624            .context("failed to user uid from apikey")?;
1625        let user_info = db
1626            .get_user_info(uid)
1627            .await
1628            .context("failed to get user's available groups and sources")?;
1629        assert!(user_info.sources.contains(&default_source_name));
1630        assert!(!user_info.is_admin);
1631        println!("UserInfoResponse: {user_info:?}");
1632
1633        assert!(
1634            db.allowed_user_source(1, default_source_id)
1635                .await
1636                .context(format!(
1637                    "failed to check that user 1 has access to source {default_source_id}"
1638                ))?,
1639            "User 1 should should have had access to source {default_source_id}"
1640        );
1641
1642        assert!(
1643            !db.allowed_user_source(1, 5)
1644                .await
1645                .context("failed to check that user 1 has access to source 5")?,
1646            "User 1 should should not have had access to source 5"
1647        );
1648
1649        let test_label_id = db
1650            .create_label("TestLabel", None)
1651            .await
1652            .context("failed to create test label")?;
1653        let test_elf_label_id = db
1654            .create_label("TestELF", Some(test_label_id))
1655            .await
1656            .context("failed to create test label")?;
1657
1658        let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
1659        let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
1660        let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
1661
1662        let known_type =
1663            KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
1664        assert!(known_type.is_exec(), "ELF should be executable");
1665        eprintln!("ELF type ID: {elf_type}");
1666
1667        let file_addition = db
1668            .add_file(
1669                &test_elf_meta,
1670                known_type.clone(),
1671                1,
1672                default_source_id,
1673                elf_type,
1674                None,
1675            )
1676            .await
1677            .context("failed to insert a test elf")?;
1678        assert!(file_addition.is_new, "File should have been added");
1679        eprintln!("Added ELF to the DB");
1680        db.label_file(file_addition.file_id, test_elf_label_id)
1681            .await
1682            .context("failed to label file")?;
1683
1684        // Search with several fields
1685        let partial_search = SearchRequest {
1686            search: SearchType::Search(SearchRequestParameters {
1687                partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1688                labels: Some(vec![String::from("TestELF")]),
1689                file_type: Some(String::from("ELF")),
1690                magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1691                ..Default::default()
1692            }),
1693        };
1694        assert!(partial_search.is_valid());
1695        let partial_search_response = db.partial_search(1, partial_search).await?;
1696        assert_eq!(partial_search_response.hashes.len(), 1);
1697        assert_eq!(
1698            partial_search_response.hashes[0],
1699            "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1700        );
1701
1702        // Ensure we get a partial result just with the magic
1703        let partial_search = SearchRequest {
1704            search: SearchType::Search(SearchRequestParameters {
1705                partial_hash: None,
1706                labels: None,
1707                file_type: None,
1708                magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1709                ..Default::default()
1710            }),
1711        };
1712        assert!(partial_search.is_valid());
1713        let partial_search_response = db.partial_search(1, partial_search).await?;
1714        assert_eq!(partial_search_response.hashes.len(), 1);
1715        assert_eq!(
1716            partial_search_response.hashes[0],
1717            "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1718        );
1719
1720        // Should be valid yet return nothing since it's the wrong file type
1721        let partial_search = SearchRequest {
1722            search: SearchType::Search(SearchRequestParameters {
1723                partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1724                file_type: Some(String::from("PE32")),
1725                ..Default::default()
1726            }),
1727        };
1728        assert!(partial_search.is_valid());
1729        let partial_search_response = db.partial_search(1, partial_search).await?;
1730        assert_eq!(partial_search_response.hashes.len(), 0);
1731
1732        let partial_search = SearchRequest {
1733            search: SearchType::Search(SearchRequestParameters {
1734                file_name: Some("ppc64".into()),
1735                ..Default::default()
1736            }),
1737        };
1738        assert!(partial_search.is_valid());
1739        let partial_search_response = db.partial_search(1, partial_search).await?;
1740        assert_eq!(partial_search_response.hashes.len(), 1);
1741
1742        // Invalid search request should return empty results
1743        let partial_search = SearchRequest {
1744            search: SearchType::Search(SearchRequestParameters::default()),
1745        };
1746        assert!(!partial_search.is_valid());
1747        let partial_search_response = db.partial_search(1, partial_search).await?;
1748        assert!(partial_search_response.hashes.is_empty());
1749
1750        // Invalid search pagination should return empty results
1751        let partial_search = SearchRequest {
1752            search: SearchType::Continuation(Uuid::default()),
1753        };
1754        assert!(partial_search.is_valid());
1755        let partial_search_response = db.partial_search(1, partial_search).await?;
1756        assert!(partial_search_response.hashes.is_empty());
1757
1758        // Ensure a type representing an unknown type isn't found
1759        assert!(db
1760            .get_type_id_for_bytes(include_bytes!("../../../../MDB_Logo.ico"))
1761            .await
1762            .is_err());
1763
1764        assert!(
1765            db.add_file(
1766                &test_elf_meta,
1767                known_type.clone(),
1768                ro_user_uid,
1769                default_source_id,
1770                elf_type,
1771                None
1772            )
1773            .await
1774            .is_err(),
1775            "Read-only user should not be able to add a file"
1776        );
1777
1778        let mut test_elf_meta_different_name = test_elf_meta.clone();
1779        test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
1780
1781        assert!(
1782            !db.add_file(
1783                &test_elf_meta_different_name,
1784                known_type,
1785                1,
1786                another_source_id,
1787                elf_type,
1788                None
1789            )
1790            .await
1791            .context("failed to insert a test elf again for a different source")?
1792            .is_new
1793        );
1794
1795        let sources = db
1796            .list_sources()
1797            .await
1798            .context("failed to re-list sources")?;
1799        eprintln!(
1800            "DB has {} sources, and a file was added twice:",
1801            sources.len()
1802        );
1803        println!("We should have two sources with one file each, yet only one ELF.");
1804        for source in sources {
1805            println!("{source}");
1806            if source.id == default_source_id || source.id == another_source_id {
1807                assert_eq!(source.files, 1);
1808            } else {
1809                assert_eq!(source.files, 0, "groups should zero (empty)");
1810            }
1811        }
1812
1813        assert!(!db
1814            .get_user_sources(1)
1815            .await
1816            .expect("failed to get user 1's sources")
1817            .sources
1818            .is_empty());
1819
1820        let file_types_counts = db
1821            .file_types_counts()
1822            .await
1823            .context("failed to get file types and counts")?;
1824        for (name, count) in file_types_counts {
1825            println!("{name}: {count}");
1826            assert_eq!(name, "ELF");
1827            assert_eq!(count, 1);
1828        }
1829
1830        let mut test_elf_modified = test_elf.clone();
1831        let random_bytes = Uuid::new_v4();
1832        let mut random_bytes = random_bytes.into_bytes().to_vec();
1833        test_elf_modified.append(&mut random_bytes);
1834        let similarity_request = generate_similarity_request(&test_elf_modified);
1835        let similarity_response = db
1836            .find_similar_samples(1, &similarity_request.hashes)
1837            .await
1838            .context("failed to get similarity response")?;
1839        eprintln!("Similarity response: {similarity_response:?}");
1840        let similarity_response = similarity_response.first().unwrap();
1841        assert_eq!(
1842            similarity_response.sha256,
1843            hex::encode(&test_elf_meta.sha256),
1844            "Similarity response should have had the hash of the original ELF"
1845        );
1846        for (algo, sim) in &similarity_response.algorithms {
1847            match algo {
1848                malwaredb_api::SimilarityHashType::LZJD => {
1849                    assert!(*sim > 0.0f32);
1850                }
1851                malwaredb_api::SimilarityHashType::SSDeep => {
1852                    assert!(*sim > 80.0f32);
1853                }
1854                malwaredb_api::SimilarityHashType::TLSH => {
1855                    assert!(*sim <= 20f32);
1856                }
1857                _ => {}
1858            }
1859        }
1860
1861        let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1.as_slice())
1862            .context("failed to get `HashType::SHA1` from string")?;
1863        let response_sha256 = db
1864            .retrieve_sample(1, &test_elf_hashtype)
1865            .await
1866            .context("could not get SHA-256 hash from test sample")
1867            .unwrap();
1868        assert_eq!(response_sha256, hex::encode(&test_elf_meta.sha256));
1869
1870        let test_bogus_hash =
1871            HashType::try_from("d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0")
1872                .context("failed to get `HashType` from static string")?;
1873        assert!(
1874            db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
1875            "Getting a file with a bogus hash should have failed."
1876        );
1877
1878        let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
1879        let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
1880        let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
1881
1882        let known_type =
1883            KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
1884
1885        assert!(
1886            db.add_file(
1887                &test_pdf_meta,
1888                known_type,
1889                1,
1890                default_source_id,
1891                pdf_type,
1892                None
1893            )
1894            .await
1895            .context("failed to insert a test pdf")?
1896            .is_new
1897        );
1898        eprintln!("Added PDF to the DB");
1899
1900        let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
1901        let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
1902        let rtf_type = db
1903            .get_type_id_for_bytes(&test_rtf)
1904            .await
1905            .context("failed to get file type id for rtf")?;
1906
1907        let known_type =
1908            KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
1909
1910        assert!(
1911            db.add_file(
1912                &test_rtf_meta,
1913                known_type,
1914                1,
1915                default_source_id,
1916                rtf_type,
1917                None
1918            )
1919            .await
1920            .context("failed to insert a test rtf")?
1921            .is_new
1922        );
1923        eprintln!("Added RTF to the DB");
1924
1925        let report = db
1926            .get_sample_report(
1927                1,
1928                &HashType::try_from(test_rtf_meta.sha256.as_slice()).unwrap(),
1929            )
1930            .await
1931            .context("failed to get report for test rtf")?;
1932        assert!(report
1933            .clone()
1934            .filecommand
1935            .unwrap()
1936            .contains("Rich Text Format"));
1937        println!("Report: {report}");
1938
1939        assert!(db
1940            .get_sample_report(
1941                999,
1942                &HashType::try_from(test_rtf_meta.sha256.as_slice()).unwrap()
1943            )
1944            .await
1945            .is_err());
1946
1947        #[cfg(feature = "vt")]
1948        {
1949            assert!(report.vt.is_some());
1950            let files_needing_vt = db
1951                .files_without_vt_records(10)
1952                .await
1953                .context("failed to get files without VT records")?;
1954            assert!(files_needing_vt.len() > 2);
1955            println!(
1956                "{} files needing VT data: {files_needing_vt:?}",
1957                files_needing_vt.len()
1958            );
1959        }
1960
1961        #[cfg(not(feature = "vt"))]
1962        {
1963            assert!(report.vt.is_none());
1964        }
1965
1966        let reset = db
1967            .reset_api_keys()
1968            .await
1969            .context("failed to reset all API keys")?;
1970        eprintln!("Cleared {reset} api keys.");
1971
1972        let db_info = db.db_info().await.context("failed to get database info")?;
1973        eprintln!("DB Info: {db_info:?}");
1974
1975        let data_types = db
1976            .get_known_data_types()
1977            .await
1978            .context("failed to get data types")?;
1979        for data_type in data_types {
1980            println!("{data_type:?}");
1981        }
1982
1983        let sources = db
1984            .list_sources()
1985            .await
1986            .context("failed to list sources second time")?;
1987        eprintln!("DB has {} sources:", sources.len());
1988        for source in sources {
1989            println!("{source}");
1990        }
1991
1992        let file_types_counts = db
1993            .file_types_counts()
1994            .await
1995            .context("failed to get file types and counts")?;
1996        for (name, count) in file_types_counts {
1997            println!("{name}: {count}");
1998            assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
1999        }
2000
2001        let fatmacho =
2002            include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
2003                .to_vec();
2004        let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
2005        let fatmacho_type = db
2006            .get_type_id_for_bytes(&fatmacho)
2007            .await
2008            .context("failed to get file type for Fat Mach-O")?;
2009        let known_type = KnownType::new(&fatmacho)
2010            .context("failed to parse Fat Mach-O from type crate's test data")?;
2011
2012        assert!(
2013            db.add_file(
2014                &fatmacho_meta,
2015                known_type,
2016                1,
2017                default_source_id,
2018                fatmacho_type,
2019                None
2020            )
2021            .await
2022            .context("failed to insert a test Fat Mach-O")?
2023            .is_new
2024        );
2025        eprintln!("Added Fat Mach-O to the DB");
2026
2027        let file_types_counts = db
2028            .file_types_counts()
2029            .await
2030            .context("failed to get file types and counts")?;
2031        for (name, count) in &file_types_counts {
2032            println!("{name}: {count}");
2033        }
2034
2035        assert_eq!(
2036            *file_types_counts.get("Mach-O").unwrap(),
2037            4,
2038            "Expected 4 Mach-O files, got {:?}",
2039            file_types_counts.get("Mach-O")
2040        );
2041
2042        let allowed_files = db
2043            .user_allowed_files_by_sha256(1, None)
2044            .await
2045            .context("failed to get allowed files")?;
2046        assert_eq!(allowed_files.0.len(), 8);
2047
2048        let allowed_files = db
2049            .user_allowed_files_by_sha256(1, Some(allowed_files.1))
2050            .await
2051            .context("failed to get allowed files")?;
2052        assert!(allowed_files.0.is_empty());
2053
2054        let malware_label_id = db
2055            .create_label(MALWARE_LABEL, None)
2056            .await
2057            .context("failed to create first label")?;
2058        let ransomware_label_id = db
2059            .create_label(RANSOMWARE_LABEL, Some(malware_label_id))
2060            .await
2061            .context("failed to create malware sub-label")?;
2062        let labels = db.get_labels().await.context("failed to get labels")?;
2063
2064        assert_eq!(labels.len(), 4, "Expected 4 labels, got {labels}");
2065        for label in labels.0 {
2066            if label.name == RANSOMWARE_LABEL {
2067                assert_eq!(label.id, ransomware_label_id);
2068                assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
2069            }
2070        }
2071
2072        // Use this file as a stand-in for an unknown file type
2073        let source_code = include_bytes!("mod.rs");
2074        let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
2075        let known_type =
2076            KnownType::new(source_code).context("failed to source code to get `Unknown` type")?;
2077
2078        assert!(matches!(known_type, KnownType::Unknown(_)));
2079
2080        let unknown_type: Vec<FileType> = db
2081            .get_known_data_types()
2082            .await?
2083            .into_iter()
2084            .filter(|t| t.name.eq_ignore_ascii_case("unknown"))
2085            .collect();
2086        let unknown_type_id = unknown_type.first().unwrap().id;
2087        assert!(db.get_type_id_for_bytes(source_code).await.is_err());
2088        db.enable_keep_unknown_files()
2089            .await
2090            .context("failed to enable keeping of unknown files")?;
2091        let source_type = db
2092            .get_type_id_for_bytes(source_code)
2093            .await
2094            .context("failed to type id for source code unknown type example")?;
2095        assert_eq!(source_type, unknown_type_id);
2096        eprintln!("Unknown file type ID: {source_type}");
2097        assert!(
2098            db.add_file(
2099                &source_meta,
2100                known_type,
2101                1,
2102                default_source_id,
2103                unknown_type_id,
2104                None
2105            )
2106            .await
2107            .context("failed to add Rust source code file")?
2108            .is_new
2109        );
2110        eprintln!("Added Rust source code to the DB");
2111
2112        #[cfg(feature = "yara")]
2113        assert!(db.get_unfinished_yara_tasks().await?.is_empty());
2114
2115        db.reset_own_api_key(0)
2116            .await
2117            .context("failed to clear own API key uid 0")?;
2118
2119        db.deactivate_user(0)
2120            .await
2121            .context("failed to clear password and API key for uid 0")?;
2122
2123        Ok(())
2124    }
2125}