// File: malwaredb_server/db/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Postgres is the primary database used by Malware DB.
//! `SQLite` is used for unit testing and may also be used for small instances of Malware DB; this
//! option is enabled with the `sqlite` feature flag. When using `SQLite`, Malware DB calculates the
//! distances between similarity hashes itself.
7
8/// Malware DB Administrative functions
9#[cfg(any(test, feature = "admin"))]
10mod admin;
11/// Postgres functions
12mod pg;
13
14/// `SQLite` functionality
15#[cfg(any(test, feature = "sqlite"))]
16mod sqlite;
17
18/// Custom `SQLite` functions
19#[cfg(any(test, feature = "sqlite"))]
20mod sqlite_functions;
21
22/// File Metadata convenience data structure
23pub mod types;
24
25#[cfg(any(test, feature = "admin"))]
26use crate::crypto::EncryptionOption;
27use crate::crypto::FileEncryption;
28use crate::db::pg::Postgres;
29#[cfg(any(test, feature = "sqlite"))]
30use crate::db::sqlite::Sqlite;
31use crate::db::types::{FileMetadata, FileType};
32use malwaredb_api::{
33    digest::HashType, GetUserInfoResponse, Labels, SearchRequest, SearchResponse, Sources,
34};
35use malwaredb_types::KnownType;
36
37use std::collections::HashMap;
38use std::path::PathBuf;
39
40use anyhow::{bail, ensure, Result};
41use argon2::password_hash::{rand_core::OsRng, SaltString};
42use argon2::{Argon2, PasswordHasher};
43#[cfg(any(test, feature = "admin"))]
44use chrono::Local;
45#[cfg(feature = "vt")]
46use malwaredb_virustotal::filereport::ScanResultAttributes;
47
/// Upper bound on the number of results returned by a partial hash and/or partial file name
/// search, so that broad queries cannot degrade database performance.
pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
50
51/// Database connection handle
52#[derive(Debug)]
53pub enum DatabaseType {
54    /// Postgres database
55    Postgres(Postgres),
56
57    /// `SQLite` database
58    #[cfg(any(test, feature = "sqlite"))]
59    SQLite(Sqlite),
60}
61
/// Basic statistics about the database, along with its version.
#[derive(Debug)]
pub struct DatabaseInformation {
    /// Database server/library version string
    pub version: String,

    /// Size of the database, formatted for humans
    pub size: String,

    /// Count of file samples stored in Malware DB
    pub num_files: u64,

    /// Count of user accounts
    pub num_users: u32,

    /// Count of user groups
    pub num_groups: u32,

    /// Count of sample sources
    pub num_sources: u32,
}
83
/// Data returned when adding a new sample.
///
/// Derives `Debug` like the other public data types in this module, so results can be logged
/// and inspected; it is a small plain-data struct, so it is also `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FileAddedResult {
    /// File ID
    pub file_id: u64,

    /// Whether the file was added as a new entry.
    /// This is false if the sample was already known to Malware DB.
    pub is_new: bool,
}
93
/// Instance-wide Malware DB settings persisted in the database itself.
#[derive(Debug)]
pub struct MDBConfig {
    /// Display name of this Malware DB instance
    pub name: String,

    /// True when samples are stored in compressed form
    pub compression: bool,

    /// True when Malware DB is permitted to upload samples to Virus Total
    pub send_samples_to_vt: bool,

    /// True when files of an unknown type should still be kept
    pub keep_unknown_files: bool,

    /// ID of the key used to encrypt samples, if encryption is enabled
    pub(crate) default_key: Option<u32>,
}
112
/// Summary counts of the Virus Total data held for samples in Malware DB.
#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Count of files that VT reported as clean
    pub clean_records: u32,

    /// Count of files that VT reported as malicious
    pub hits_records: u32,

    /// Count of files with no VT record at all
    pub files_without_records: u32,
}
127
128impl DatabaseType {
129    /// Get a database connection from a configuration string
130    ///
131    /// # Errors
132    ///
133    /// * If there's a connectivity issue to Postgres, an error will result
134    /// * If the `SQLite` file cannot be created or opened, an error will result
135    /// * If `SQLite` is the type but Malware DB wasn't compiled with the sqlite feature, an error will result
136    /// * If the format or database type isn't known, an error will result
137    pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
138        #[cfg(any(test, feature = "sqlite"))]
139        if arg.starts_with("file:") {
140            let new_conn_str = arg.trim_start_matches("file:");
141            return Ok(DatabaseType::SQLite(Sqlite::new(new_conn_str)?));
142        }
143
144        if arg.starts_with("postgres") {
145            let new_conn_str = arg.trim_start_matches("postgres");
146            return Ok(DatabaseType::Postgres(
147                Postgres::new(new_conn_str, server_ca).await?,
148            ));
149        }
150
151        bail!("unknown database type `{arg}`")
152    }
153
154    /// Set the flag allowing uploads to Virus Total.
155    ///
156    /// # Errors
157    ///
158    /// If there's a connectivity issue to Postgres, an error will result
159    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
160    #[cfg(feature = "vt")]
161    pub async fn enable_vt_upload(&self) -> Result<()> {
162        match self {
163            DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
164            #[cfg(any(test, feature = "sqlite"))]
165            DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
166        }
167    }
168
169    /// Set the flag preventing uploads to Virus Total.
170    ///
171    /// # Errors
172    ///
173    /// If there's a connectivity issue to Postgres, an error will result
174    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
175    #[cfg(feature = "vt")]
176    pub async fn disable_vt_upload(&self) -> Result<()> {
177        match self {
178            DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
179            #[cfg(any(test, feature = "sqlite"))]
180            DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
181        }
182    }
183
184    /// Get the SHA-256 hashes of the files which don't have VT records
185    ///
186    /// # Errors
187    ///
188    /// If there's a connectivity issue to Postgres, an error will result
189    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
190    #[cfg(feature = "vt")]
191    pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
192        match self {
193            DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
194            #[cfg(any(test, feature = "sqlite"))]
195            DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
196        }
197    }
198
199    /// Store the VT results: AV hits and detailed report, or lack of any AV hits
200    ///
201    /// # Errors
202    ///
203    /// If there's a connectivity issue to Postgres, an error will result
204    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
205    #[cfg(feature = "vt")]
206    pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
207        match self {
208            DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
209            #[cfg(any(test, feature = "sqlite"))]
210            DatabaseType::SQLite(sl) => sl.store_vt_record(results),
211        }
212    }
213
214    /// Quick statistics regarding the data contained for VT information for our samples
215    ///
216    /// # Errors
217    ///
218    /// If there's a connectivity issue to Postgres, an error will result
219    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
220    #[cfg(feature = "vt")]
221    pub async fn get_vt_stats(&self) -> Result<VtStats> {
222        match self {
223            DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
224            #[cfg(any(test, feature = "sqlite"))]
225            DatabaseType::SQLite(sl) => sl.get_vt_stats(),
226        }
227    }
228
229    /// Get the configuration which is stored in the database
230    ///
231    /// # Errors
232    ///
233    /// If there's a connectivity issue to Postgres, an error will result
234    pub async fn get_config(&self) -> Result<MDBConfig> {
235        match self {
236            DatabaseType::Postgres(pg) => pg.get_config().await,
237            #[cfg(any(test, feature = "sqlite"))]
238            DatabaseType::SQLite(sl) => sl.get_config(),
239        }
240    }
241
242    /// Check user credentials, return the API key. Generate if it doesn't exist.
243    ///
244    /// # Errors
245    ///
246    /// * If there's a connectivity issue to Postgres, an error will result
247    /// * If the username and/or password aren't correct, an error will result
248    pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
249        match self {
250            DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
251            #[cfg(any(test, feature = "sqlite"))]
252            DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
253        }
254    }
255
256    /// Get the user's ID from their API key
257    ///
258    /// # Errors
259    ///
260    /// * If there's a connectivity issue to Postgres, an error will result
261    /// * If the api key isn't valid, an error will result
262    pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
263        ensure!(!apikey.is_empty(), "API key was empty");
264        match self {
265            DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
266            #[cfg(any(test, feature = "sqlite"))]
267            DatabaseType::SQLite(sl) => sl.get_uid(apikey),
268        }
269    }
270
271    /// Retrieve information about the database
272    ///
273    /// # Errors
274    ///
275    /// * If there's a connectivity issue to Postgres, an error will result
276    pub async fn db_info(&self) -> Result<DatabaseInformation> {
277        match self {
278            DatabaseType::Postgres(pg) => pg.db_info().await,
279            #[cfg(any(test, feature = "sqlite"))]
280            DatabaseType::SQLite(sl) => sl.db_info(),
281        }
282    }
283
284    /// Retrieve the names of the groups and sources the user is part of and has access to
285    ///
286    /// # Errors
287    ///
288    /// * If there's a connectivity issue to Postgres, an error will result
289    /// * If the user ID isn't valid, an error will result
290    pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
291        match self {
292            DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
293            #[cfg(any(test, feature = "sqlite"))]
294            DatabaseType::SQLite(sl) => sl.get_user_info(uid),
295        }
296    }
297
298    /// Retrieve the source information available to the specified user
299    ///
300    /// # Errors
301    ///
302    /// * If there's a connectivity issue to Postgres, an error will result
303    /// * If the user ID isn't valid, an error will result
304    pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
305        match self {
306            DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
307            #[cfg(any(test, feature = "sqlite"))]
308            DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
309        }
310    }
311
312    /// Let the user clear their own API key to log out from all systems
313    ///
314    /// # Errors
315    ///
316    /// * If there's a connectivity issue to Postgres, an error will result
317    /// * If the user ID isn't valid, an error will result
318    pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
319        match self {
320            DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
321            #[cfg(any(test, feature = "sqlite"))]
322            DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
323        }
324    }
325
326    /// Retrieve the supported data type information
327    ///
328    /// # Errors
329    ///
330    /// If there's a connectivity issue to Postgres, an error will result
331    pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
332        match self {
333            DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
334            #[cfg(any(test, feature = "sqlite"))]
335            DatabaseType::SQLite(sl) => sl.get_known_data_types(),
336        }
337    }
338
339    /// Get all labels from Malware DB
340    ///
341    /// # Errors
342    ///
343    /// If there's a connectivity issue to Postgres, an error will result
344    pub async fn get_labels(&self) -> Result<Labels> {
345        match self {
346            DatabaseType::Postgres(pg) => pg.get_labels().await,
347            #[cfg(any(test, feature = "sqlite"))]
348            DatabaseType::SQLite(sl) => sl.get_labels(),
349        }
350    }
351
352    /// Get the corresponding type ID for a buffer representing a file
353    ///
354    /// # Errors
355    ///
356    /// If there's a connectivity issue to Postgres, an error will result
357    pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
358        match self {
359            DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
360            #[cfg(any(test, feature = "sqlite"))]
361            DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
362        }
363    }
364
365    /// Check that a user has been granted access data for the specific source
366    ///
367    /// # Errors
368    ///
369    /// * If there's a connectivity issue to Postgres, an error will result
370    /// * If the user or source ID(s) aren't valid, an error will result
371    pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
372        match self {
373            DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
374            #[cfg(any(test, feature = "sqlite"))]
375            DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
376        }
377    }
378
379    /// Check to see if the user is an administrator. The user must be a member of the
380    /// admin group (group ID 0), or a one group below (a group with the parent group id of 0).
381    ///
382    /// # Errors
383    ///
384    /// * If there's a connectivity issue to Postgres, an error will result
385    /// * If the user ID isn't valid, an error will result
386    pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
387        match self {
388            DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
389            #[cfg(any(test, feature = "sqlite"))]
390            DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
391        }
392    }
393
394    /// Add a file's metadata to the database, returning true if this is a new entry
395    ///
396    /// # Errors
397    ///
398    /// * If there's a connectivity issue to Postgres, an error will result
399    /// * If the source doesn't exist or the user isn't part of a member group, an error will result
400    pub async fn add_file(
401        &self,
402        meta: &FileMetadata,
403        known_type: KnownType<'_>,
404        uid: u32,
405        sid: u32,
406        ftype: u32,
407        parent: Option<u64>,
408    ) -> Result<FileAddedResult> {
409        match self {
410            DatabaseType::Postgres(pg) => {
411                pg.add_file(meta, known_type, uid, sid, ftype, parent).await
412            }
413            #[cfg(any(test, feature = "sqlite"))]
414            DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
415        }
416    }
417
418    /// Search for allowed samples based on partial search and/or file name
419    ///
420    /// # Errors
421    ///
422    /// * If there's a connectivity issue to Postgres, an error will result
423    pub async fn partial_search(&self, uid: u32, search: SearchRequest) -> Result<SearchResponse> {
424        match self {
425            DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
426            #[cfg(any(test, feature = "sqlite"))]
427            DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
428        }
429    }
430
431    /// Delete old pagination searches
432    ///
433    /// # Errors
434    ///
435    /// An error would occur if the Postgres server couldn't be reached.
436    pub async fn cleanup(&self) -> Result<u64> {
437        match self {
438            DatabaseType::Postgres(pg) => pg.cleanup().await,
439            #[cfg(any(test, feature = "sqlite"))]
440            DatabaseType::SQLite(sl) => sl.cleanup(),
441        }
442    }
443
444    /// Retrieve the SHA-256 hash of the sample while checking that the user is permitted
445    /// to access to it
446    ///
447    /// # Errors
448    ///
449    /// * If there's a connectivity issue to Postgres, an error will result
450    /// * If the file doesn't exist or the user isn't allowed access, an error will result
451    pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
452        match self {
453            DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
454            #[cfg(any(test, feature = "sqlite"))]
455            DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
456        }
457    }
458
459    /// Retrieve a report for a given sample, if allowed.
460    ///
461    /// # Errors
462    ///
463    /// * If there's a connectivity issue to Postgres, an error will result
464    /// * If the file doesn't exist or the user isn't allowed access, an error will result
465    pub async fn get_sample_report(
466        &self,
467        uid: u32,
468        hash: &HashType,
469    ) -> Result<malwaredb_api::Report> {
470        match self {
471            DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
472            #[cfg(any(test, feature = "sqlite"))]
473            DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
474        }
475    }
476
477    /// Given a collection of similarity hashes, find samples which are similar.
478    ///
479    /// # Errors
480    ///
481    /// If there's a connectivity issue to Postgres, an error will result
482    pub async fn find_similar_samples(
483        &self,
484        uid: u32,
485        sim: &[(malwaredb_api::SimilarityHashType, String)],
486    ) -> Result<Vec<malwaredb_api::SimilarSample>> {
487        match self {
488            DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
489            #[cfg(any(test, feature = "sqlite"))]
490            DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
491        }
492    }
493
    // Crate-internal (`pub(crate)`) functions
495
496    /// Get the file encryption keys
497    pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
498        match self {
499            DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
500            #[cfg(any(test, feature = "sqlite"))]
501            DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
502        }
503    }
504
505    /// Get the key and AES nonce for a specific file identified by SHA-256 hash
506    pub(crate) async fn get_file_encryption_key_id(
507        &self,
508        hash: &str,
509    ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
510        match self {
511            DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
512            #[cfg(any(test, feature = "sqlite"))]
513            DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
514        }
515    }
516
517    /// Set the AES nonce for a file specified by SHA-256 hash, a `None` value removes it.
518    pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
519        match self {
520            DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
521            #[cfg(any(test, feature = "sqlite"))]
522            DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
523        }
524    }
525
526    // Administrative functions
527
528    /// Set the instance name
529    ///
530    /// # Errors
531    ///
532    /// If there's a connectivity issue to Postgres, an error will result
533    #[cfg(any(test, feature = "admin"))]
534    pub async fn set_name(&self, name: &str) -> Result<()> {
535        match self {
536            DatabaseType::Postgres(pg) => pg.set_name(name).await,
537            #[cfg(any(test, feature = "sqlite"))]
538            DatabaseType::SQLite(sl) => sl.set_name(name),
539        }
540    }
541
542    /// Set the compression flag
543    ///
544    /// # Errors
545    ///
546    /// If there's a connectivity issue to Postgres, an error will result
547    #[cfg(any(test, feature = "admin"))]
548    pub async fn enable_compression(&self) -> Result<()> {
549        match self {
550            DatabaseType::Postgres(pg) => pg.enable_compression().await,
551            #[cfg(any(test, feature = "sqlite"))]
552            DatabaseType::SQLite(sl) => sl.enable_compression(),
553        }
554    }
555
556    /// Unset the compression flag, does not go and decompress files already compressed!
557    ///
558    /// # Errors
559    ///
560    /// If there's a connectivity issue to Postgres, an error will result
561    #[cfg(any(test, feature = "admin"))]
562    pub async fn disable_compression(&self) -> Result<()> {
563        match self {
564            DatabaseType::Postgres(pg) => pg.disable_compression().await,
565            #[cfg(any(test, feature = "sqlite"))]
566            DatabaseType::SQLite(sl) => sl.disable_compression(),
567        }
568    }
569
570    /// Set the keep unknown files flag
571    ///
572    /// # Errors
573    ///
574    /// If there's a connectivity issue to Postgres, an error will result
575    #[cfg(any(test, feature = "admin"))]
576    pub async fn enable_keep_unknown_files(&self) -> Result<()> {
577        match self {
578            DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
579            #[cfg(any(test, feature = "sqlite"))]
580            DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
581        }
582    }
583
584    /// Unset the keep unknown files flag, does not go and remove unknown files!
585    ///
586    /// # Errors
587    ///
588    /// If there's a connectivity issue to Postgres, an error will result
589    #[cfg(any(test, feature = "admin"))]
590    pub async fn disable_keep_unknown_files(&self) -> Result<()> {
591        match self {
592            DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
593            #[cfg(any(test, feature = "sqlite"))]
594            DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
595        }
596    }
597
598    /// Add an encryption key to the database, set it as the default, and return the key ID
599    ///
600    /// # Errors
601    ///
602    /// * If there is a connectivity with Postgres, an error will result
603    #[cfg(any(test, feature = "admin"))]
604    pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
605        match self {
606            DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
607            #[cfg(any(test, feature = "sqlite"))]
608            DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
609        }
610    }
611
612    /// Get the key ID and algorithm names
613    ///
614    /// # Errors
615    ///
616    /// If there is a connectivity with Postgres, an error will result
617    #[cfg(any(test, feature = "admin"))]
618    pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
619        match self {
620            DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
621            #[cfg(any(test, feature = "sqlite"))]
622            DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
623        }
624    }
625
626    /// Create a user account, return the user ID.
627    ///
628    /// # Errors
629    ///
630    /// If there's a connectivity issue to Postgres, an error will result
631    #[allow(clippy::too_many_arguments)]
632    #[cfg(any(test, feature = "admin"))]
633    pub async fn create_user(
634        &self,
635        uname: &str,
636        fname: &str,
637        lname: &str,
638        email: &str,
639        password: Option<String>,
640        organisation: Option<&String>,
641        readonly: bool,
642    ) -> Result<u32> {
643        match self {
644            DatabaseType::Postgres(pg) => {
645                pg.create_user(uname, fname, lname, email, password, organisation, readonly)
646                    .await
647            }
648            #[cfg(any(test, feature = "sqlite"))]
649            DatabaseType::SQLite(sl) => {
650                sl.create_user(uname, fname, lname, email, password, organisation, readonly)
651            }
652        }
653    }
654
655    /// Clear all API keys, either in case of suspected activity, or part of policy
656    ///
657    /// # Errors
658    ///
659    /// If there's a connectivity issue to Postgres, an error will result
660    #[cfg(any(test, feature = "admin"))]
661    pub async fn reset_api_keys(&self) -> Result<u64> {
662        match self {
663            DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
664            #[cfg(any(test, feature = "sqlite"))]
665            DatabaseType::SQLite(sl) => sl.reset_api_keys(),
666        }
667    }
668
669    /// Set a user's password
670    ///
671    /// # Errors
672    ///
673    /// * If there's a connectivity issue to Postgres, an error will result
674    /// * If the user doesn't exist, an error will result
675    #[cfg(any(test, feature = "admin"))]
676    pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
677        match self {
678            DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
679            #[cfg(any(test, feature = "sqlite"))]
680            DatabaseType::SQLite(sl) => sl.set_password(uname, password),
681        }
682    }
683
684    /// Get the complete list of users
685    ///
686    /// # Errors
687    ///
688    /// If there's a connectivity issue to Postgres, an error will result
689    #[cfg(any(test, feature = "admin"))]
690    pub async fn list_users(&self) -> Result<Vec<admin::User>> {
691        match self {
692            DatabaseType::Postgres(pg) => pg.list_users().await,
693            #[cfg(any(test, feature = "sqlite"))]
694            DatabaseType::SQLite(sl) => sl.list_users(),
695        }
696    }
697
698    /// Get the ID of a group from name
699    ///
700    /// # Errors
701    ///
702    /// * If there's a connectivity issue to Postgres, an error will result
703    /// * If the group name isn't valid, an error will result
704    #[cfg(any(test, feature = "admin"))]
705    pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
706        match self {
707            DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
708            #[cfg(any(test, feature = "sqlite"))]
709            DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
710        }
711    }
712
713    /// Update the record for a group
714    ///
715    /// # Errors
716    ///
717    /// * If there's a connectivity issue to Postgres, an error will result
718    /// * If the group doesn't exist, an error will result
719    #[cfg(any(test, feature = "admin"))]
720    pub async fn edit_group(
721        &self,
722        gid: u32,
723        name: &str,
724        desc: &str,
725        parent: Option<u32>,
726    ) -> Result<()> {
727        match self {
728            DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
729            #[cfg(any(test, feature = "sqlite"))]
730            DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
731        }
732    }
733
734    /// Get the complete list of groups
735    ///
736    /// # Errors
737    ///
738    /// If there's a connectivity issue to Postgres, an error will result
739    #[cfg(any(test, feature = "admin"))]
740    pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
741        match self {
742            DatabaseType::Postgres(pg) => pg.list_groups().await,
743            #[cfg(any(test, feature = "sqlite"))]
744            DatabaseType::SQLite(sl) => sl.list_groups(),
745        }
746    }
747
748    /// Grant a user membership to a group, both by id.
749    ///
750    /// # Errors
751    ///
752    /// * If there's a connectivity issue to Postgres, an error will result
753    /// * If the user or group doesn't exist, an error will result
754    #[cfg(any(test, feature = "admin"))]
755    pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
756        match self {
757            DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
758            #[cfg(any(test, feature = "sqlite"))]
759            DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
760        }
761    }
762
763    /// Grand a group access to a source, both by id.
764    ///
765    /// # Errors
766    ///
767    /// * If there's a connectivity issue to Postgres, an error will result
768    /// * If the group or source doesn't exist, an error will result
769    #[cfg(any(test, feature = "admin"))]
770    pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
771        match self {
772            DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
773            #[cfg(any(test, feature = "sqlite"))]
774            DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
775        }
776    }
777
778    /// Create a new group, returning the group ID
779    ///
780    /// # Errors
781    ///
782    /// * If there's a connectivity issue to Postgres, an error will result
783    /// * If the group name is already taken, an error will result
784    #[cfg(any(test, feature = "admin"))]
785    pub async fn create_group(
786        &self,
787        name: &str,
788        description: &str,
789        parent: Option<u32>,
790    ) -> Result<u32> {
791        match self {
792            DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
793            #[cfg(any(test, feature = "sqlite"))]
794            DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
795        }
796    }
797
798    /// Get the complete list of sources
799    ///
800    /// # Errors
801    ///
802    /// If there's a connectivity issue to Postgres, an error will result
803    #[cfg(any(test, feature = "admin"))]
804    pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
805        match self {
806            DatabaseType::Postgres(pg) => pg.list_sources().await,
807            #[cfg(any(test, feature = "sqlite"))]
808            DatabaseType::SQLite(sl) => sl.list_sources(),
809        }
810    }
811
812    /// Create a source, returning the source ID
813    ///
814    /// # Errors
815    ///
816    /// * If there's a connectivity issue to Postgres, an error will result
817    /// * If the source already exists, an error will result
818    #[cfg(any(test, feature = "admin"))]
819    pub async fn create_source(
820        &self,
821        name: &str,
822        description: Option<&str>,
823        url: Option<&str>,
824        date: chrono::DateTime<Local>,
825        releasable: bool,
826        malicious: Option<bool>,
827    ) -> Result<u32> {
828        match self {
829            DatabaseType::Postgres(pg) => {
830                pg.create_source(name, description, url, date, releasable, malicious)
831                    .await
832            }
833            #[cfg(any(test, feature = "sqlite"))]
834            DatabaseType::SQLite(sl) => {
835                sl.create_source(name, description, url, date, releasable, malicious)
836            }
837        }
838    }
839
840    /// Edit a user account setting the specified field values, primarily used by the Admin gui
841    ///
842    /// # Errors
843    ///
844    /// * If there's a connectivity issue to Postgres, an error will result
845    /// * If the user isn't valid, an error will result
846    #[cfg(any(test, feature = "admin"))]
847    pub async fn edit_user(
848        &self,
849        uid: u32,
850        uname: &str,
851        fname: &str,
852        lname: &str,
853        email: &str,
854        readonly: bool,
855    ) -> Result<()> {
856        match self {
857            DatabaseType::Postgres(pg) => {
858                pg.edit_user(uid, uname, fname, lname, email, readonly)
859                    .await
860            }
861            #[cfg(any(test, feature = "sqlite"))]
862            DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
863        }
864    }
865
866    /// Deactivate, but don't delete, a user's account
867    ///
868    /// # Errors
869    ///
870    /// * If there's a connectivity issue to Postgres, an error will result
871    /// * If the user ID isn't valid, an error will result
872    #[cfg(any(test, feature = "admin"))]
873    pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
874        match self {
875            DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
876            #[cfg(any(test, feature = "sqlite"))]
877            DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
878        }
879    }
880
881    /// File types and number of files per type
882    ///
883    /// # Errors
884    ///
885    /// * If there's a connectivity issue to Postgres, an error will result
886    #[cfg(any(test, feature = "admin"))]
887    pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
888        match self {
889            DatabaseType::Postgres(pg) => pg.file_types_counts().await,
890            #[cfg(any(test, feature = "sqlite"))]
891            DatabaseType::SQLite(sl) => sl.file_types_counts(),
892        }
893    }
894
895    /// Create a new label, returning the label ID
896    ///
897    /// # Errors
898    ///
899    /// * If there's a connectivity issue to Postgres, an error will result
900    /// * If the parent label doesn't exist, an error will result
901    /// * If the label name already exists, an error will result
902    #[cfg(any(test, feature = "admin"))]
903    pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
904        match self {
905            DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
906            #[cfg(any(test, feature = "sqlite"))]
907            DatabaseType::SQLite(sl) => sl.create_label(name, parent),
908        }
909    }
910
911    /// Edit a label name or parent
912    ///
913    /// # Errors
914    ///
915    /// * If there's a connectivity issue to Postgres, an error will result
916    /// * If the parent label doesn't exist, an error will result
917    /// * If the label name already exists, an error will result
918    #[cfg(any(test, feature = "admin"))]
919    pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
920        match self {
921            DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
922            #[cfg(any(test, feature = "sqlite"))]
923            DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
924        }
925    }
926
927    /// Return the ID for a given label name
928    ///
929    /// # Errors
930    ///
931    /// * If there's a connectivity issue to Postgres, an error will result
932    /// * If the label ID is invalid, an error will result
933    #[cfg(any(test, feature = "admin"))]
934    pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
935        match self {
936            DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
937            #[cfg(any(test, feature = "sqlite"))]
938            DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
939        }
940    }
941
942    /// Associate an existing label by its IDs with a file.
943    ///
944    /// # Errors
945    ///
946    /// * Incorrect IDs will result in an error
947    /// * If the file already has the given label associated
948    /// * If there is a network or connection issue with Postgres
949    #[cfg(any(test, feature = "admin"))]
950    pub async fn label_file(&self, file_id: u64, label_id: u64) -> Result<()> {
951        match self {
952            DatabaseType::Postgres(pg) => pg.label_file(file_id, label_id).await,
953            #[cfg(any(test, feature = "sqlite"))]
954            DatabaseType::SQLite(sl) => sl.label_file(file_id, label_id),
955        }
956    }
957}
958
959/// Hash a password string with [Argon2]
960///
961/// # Errors
962///
963/// An error may result if Argon has an error
964pub fn hash_password(password: &str) -> Result<String> {
965    let salt = SaltString::generate(&mut OsRng);
966    let argon2 = Argon2::default();
967    Ok(argon2
968        .hash_password(password.as_bytes(), &salt)?
969        .to_string())
970}
971
972/// Generate a new, random API key
973#[must_use]
974pub fn random_bytes_api_key() -> String {
975    let key1 = uuid::Uuid::new_v4();
976    let key2 = uuid::Uuid::new_v4();
977    let key1 = key1.to_string().replace('-', "");
978    let key2 = key2.to_string().replace('-', "");
979    format!("{key1}{key2}")
980}
981
982#[cfg(test)]
983mod tests {
984    use super::*;
985    #[cfg(feature = "vt")]
986    use crate::vt::VtUpdater;
987
988    use std::fs;
989    #[cfg(feature = "vt")]
990    use std::sync::Arc;
991    #[cfg(feature = "vt")]
992    use std::time::SystemTime;
993
994    use anyhow::Context;
995    use fuzzyhash::FuzzyHash;
996    use malwaredb_api::{PartialHashSearchType, SearchRequestParameters, SearchType};
997    use malwaredb_lzjd::{LZDict, Murmur3HashState};
998    use tlsh_fixed::TlshBuilder;
999    use uuid::Uuid;
1000
    /// Label name for the test fixtures; presumably used by labeling checks later in this module.
    const MALWARE_LABEL: &str = "malware";
    /// Label name for the test fixtures; presumably used by labeling checks later in this module.
    const RANSOMWARE_LABEL: &str = "ransomware";
1003
1004    fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
1005        let mut hashes = vec![];
1006
1007        hashes.push((
1008            malwaredb_api::SimilarityHashType::SSDeep,
1009            FuzzyHash::new(data).to_string(),
1010        ));
1011
1012        let mut builder = TlshBuilder::new(
1013            tlsh_fixed::BucketKind::Bucket256,
1014            tlsh_fixed::ChecksumKind::ThreeByte,
1015            tlsh_fixed::Version::Version4,
1016        );
1017
1018        builder.update(data);
1019
1020        if let Ok(hasher) = builder.build() {
1021            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
1022        }
1023
1024        let build_hasher = Murmur3HashState::default();
1025        let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
1026        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
1027
1028        malwaredb_api::SimilarSamplesRequest { hashes }
1029    }
1030
1031    async fn pg_config() -> Postgres {
1032        // create user malwaredbtesting with password 'malwaredbtesting';
1033        // create database malwaredbtesting owner malwaredbtesting;
1034        const CONNECTION_STRING: &str =
1035            "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
1036
1037        if let Ok(pg_port) = std::env::var("PG_PORT") {
1038            // Get the port number to run in Github CI
1039            let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
1040            Postgres::new(&conn_string, None)
1041                .await
1042                .context(format!(
1043                    "failed to connect to postgres with specified port {pg_port}"
1044                ))
1045                .unwrap()
1046        } else {
1047            Postgres::new(CONNECTION_STRING, None).await.unwrap()
1048        }
1049    }
1050
1051    #[tokio::test]
1052    #[ignore = "don't run this in CI"]
1053    async fn pg() {
1054        let psql = pg_config().await;
1055        psql.delete_init().await.unwrap();
1056
1057        let db = DatabaseType::Postgres(psql);
1058        let key = FileEncryption::from(EncryptionOption::Xor);
1059        db.add_file_encryption_key(&key).await.unwrap();
1060        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
1061        everything(&db).await.unwrap();
1062
1063        #[cfg(feature = "vt")]
1064        {
1065            let db_config = db.get_config().await.unwrap();
1066            let state = crate::State {
1067                port: 8080,
1068                directory: None,
1069                max_upload: 10 * 1024 * 1024,
1070                ip: "127.0.0.1".parse().unwrap(),
1071                db_type: Arc::new(db),
1072                db_config,
1073                keys: HashMap::new(),
1074                started: SystemTime::now(),
1075                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1076                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
1077                }),
1078                cert: None,
1079                key: None,
1080                mdns: false,
1081            };
1082
1083            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1084
1085            vt.updater().await.unwrap();
1086            println!("PG: Did VT ops!");
1087
1088            let psql = pg_config().await;
1089
1090            let vt_stats = psql
1091                .get_vt_stats()
1092                .await
1093                .context("failed to get Postgres VT Stats")
1094                .unwrap();
1095            println!("{vt_stats:?}");
1096            assert!(
1097                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1098            );
1099        }
1100
1101        // Re-create the Postgres object so we can do some clean-up
1102        let psql = pg_config().await;
1103        psql.delete_init().await.unwrap();
1104    }
1105
1106    #[tokio::test]
1107    async fn sqlite() {
1108        const DB_FILE: &str = "testing_sqlite.db";
1109        if std::path::Path::new(DB_FILE).exists() {
1110            fs::remove_file(DB_FILE)
1111                .context(format!("failed to delete old SQLite file {DB_FILE}"))
1112                .unwrap();
1113        }
1114
1115        let sqlite = Sqlite::new(DB_FILE)
1116            .context(format!("failed to create SQLite instance for {DB_FILE}"))
1117            .unwrap();
1118
1119        let db = DatabaseType::SQLite(sqlite);
1120        let key = FileEncryption::from(EncryptionOption::Xor);
1121        db.add_file_encryption_key(&key).await.unwrap();
1122        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
1123        everything(&db).await.unwrap();
1124
1125        #[cfg(feature = "vt")]
1126        {
1127            let db_config = db.get_config().await.unwrap();
1128            let state = crate::State {
1129                port: 8080,
1130                directory: None,
1131                max_upload: 10 * 1024 * 1024,
1132                ip: "127.0.0.1".parse().unwrap(),
1133                db_type: Arc::new(db),
1134                db_config,
1135                keys: HashMap::new(),
1136                started: SystemTime::now(),
1137                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1138                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
1139                }),
1140                cert: None,
1141                key: None,
1142                mdns: false,
1143            };
1144
1145            let sqlite_second = Sqlite::new(DB_FILE)
1146                .context(format!("failed to create SQLite instance for {DB_FILE}"))
1147                .unwrap();
1148
1149            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1150
1151            vt.updater().await.unwrap();
1152            println!("Sqlite: Did VT ops!");
1153            let vt_stats = sqlite_second
1154                .get_vt_stats()
1155                .context("failed to get Sqlite VT Stats")
1156                .unwrap();
1157            println!("{vt_stats:?}");
1158            assert!(
1159                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1160            );
1161        }
1162
1163        fs::remove_file(DB_FILE)
1164            .context(format!("failed to delete SQLite file {DB_FILE}"))
1165            .unwrap();
1166    }
1167
1168    #[allow(clippy::too_many_lines)]
1169    async fn everything(db: &DatabaseType) -> Result<()> {
1170        const ADMIN_UNAME: &str = "admin";
1171        const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
1172
1173        db.set_name("Testing Database")
1174            .await
1175            .context("setting instance name failed")?;
1176
1177        assert!(
1178            db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
1179            "Authentication without password should have failed."
1180        );
1181
1182        db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1183            .await
1184            .context("failed to set admin password")?;
1185
1186        let admin_api_key = db
1187            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1188            .await
1189            .context("unable to get api key for admin")?;
1190        println!("API key: {admin_api_key}");
1191        assert_eq!(admin_api_key.len(), 64);
1192
1193        assert_eq!(
1194            db.get_uid(&admin_api_key).await?,
1195            0,
1196            "Unable to get UID given the API key"
1197        );
1198
1199        let admin_api_key_again = db
1200            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1201            .await
1202            .context("unable to get api key a second time for admin")?;
1203
1204        assert_eq!(
1205            admin_api_key, admin_api_key_again,
1206            "API keys didn't match the second time."
1207        );
1208
1209        let bad_password = "this_is_totally_not_my_password!!";
1210        eprintln!("Testing API login with incorrect password.");
1211        assert!(
1212            db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
1213            "Authenticating as admin with a bad password should have failed."
1214        );
1215
1216        let admin_is_admin = db
1217            .user_is_admin(0)
1218            .await
1219            .context("unable to see if admin (uid 0) is an admin")?;
1220        assert!(admin_is_admin);
1221
1222        let new_user_uname = "testuser";
1223        let new_user_email = "test@example.com";
1224        let new_user_password = "some_awesome_password_++";
1225        let new_id = db
1226            .create_user(
1227                new_user_uname,
1228                new_user_uname,
1229                new_user_uname,
1230                new_user_email,
1231                Some(new_user_password.into()),
1232                None,
1233                false,
1234            )
1235            .await
1236            .context(format!("failed to create user {new_user_uname}"))?;
1237
1238        let passwordless_user_id = db
1239            .create_user(
1240                "passwordless_user",
1241                "passwordless_user",
1242                "passwordless_user",
1243                "passwordless_user@example.com",
1244                None,
1245                None,
1246                false,
1247            )
1248            .await
1249            .context("failed to create passwordless_user")?;
1250
1251        for user in &db.list_users().await.context("failed to list users")? {
1252            if user.id == passwordless_user_id {
1253                assert_eq!(user.uname, "passwordless_user");
1254            }
1255        }
1256
1257        db.edit_user(
1258            passwordless_user_id,
1259            "passwordless_user_2",
1260            "passwordless_user_2",
1261            "passwordless_user_2",
1262            "passwordless_user_2@something.com",
1263            false,
1264        )
1265        .await
1266        .context(format!(
1267            "failed to alter 'passwordless' user, id {passwordless_user_id}"
1268        ))?;
1269
1270        for user in &db.list_users().await.context("failed to list users")? {
1271            if user.id == passwordless_user_id {
1272                assert_eq!(user.uname, "passwordless_user_2");
1273            }
1274        }
1275
1276        assert!(
1277            new_id > 0,
1278            "Weird UID created for user {new_user_uname}: {new_id}"
1279        );
1280
1281        assert!(
1282            db.create_user(
1283                new_user_uname,
1284                new_user_uname,
1285                new_user_uname,
1286                new_user_email,
1287                Some(new_user_password.into()),
1288                None,
1289                false
1290            )
1291            .await
1292            .is_err(),
1293            "Creating a new user with the same user name should fail"
1294        );
1295
1296        let ro_user_name = "ro_user";
1297        let ro_user_password = "ro_user_password";
1298        db.create_user(
1299            ro_user_name,
1300            "ro_user",
1301            "ro_user",
1302            "ro@example.com",
1303            Some(ro_user_password.into()),
1304            None,
1305            true,
1306        )
1307        .await
1308        .context("failed to create read-only user")?;
1309
1310        let ro_user_api_key = db
1311            .authenticate(ro_user_name, ro_user_password)
1312            .await
1313            .context("unable to get api key for read-only user")?;
1314
1315        let new_user_password_change = "some_new_awesomer_password!_++";
1316        db.set_password(new_user_uname, new_user_password_change)
1317            .await
1318            .context("failed to change the password for testuser")?;
1319
1320        let new_user_api_key = db
1321            .authenticate(new_user_uname, new_user_password_change)
1322            .await
1323            .context("unable to get api key for testuser")?;
1324        eprintln!("{new_user_uname} got API key {new_user_api_key}");
1325
1326        assert_eq!(admin_api_key.len(), new_user_api_key.len());
1327
1328        let users = db.list_users().await.context("failed to list users")?;
1329        assert_eq!(
1330            users.len(),
1331            4,
1332            "Four users were created, yet there are {} users",
1333            users.len()
1334        );
1335        eprintln!("DB has {} users:", users.len());
1336        let mut passwordless_user_found = false;
1337        for user in users {
1338            println!("{user}");
1339            if user.uname == "passwordless_user_2" {
1340                assert!(!user.has_api_key);
1341                assert!(!user.has_password);
1342                passwordless_user_found = true;
1343            } else {
1344                assert!(user.has_api_key);
1345                assert!(user.has_password);
1346            }
1347        }
1348        assert!(passwordless_user_found);
1349
1350        let new_group_name = "some_new_group";
1351        let new_group_desc = "some_new_group_description";
1352        let new_group_id = 1;
1353        assert_eq!(
1354            db.create_group(new_group_name, new_group_desc, None)
1355                .await
1356                .context("failed to create group")?,
1357            new_group_id,
1358            "New group didn't have the expected ID, expected {new_group_id}"
1359        );
1360
1361        assert!(
1362            db.create_group(new_group_name, new_group_desc, None)
1363                .await
1364                .is_err(),
1365            "Duplicate group name should have failed"
1366        );
1367
1368        db.add_user_to_group(1, 1)
1369            .await
1370            .context("Unable to add uid 1 to gid 1")?;
1371
1372        let ro_user_uid = db
1373            .get_uid(&ro_user_api_key)
1374            .await
1375            .context("Unable to get UID for read-only user")?;
1376        db.add_user_to_group(ro_user_uid, 1)
1377            .await
1378            .context("Unable to add uid 2 to gid 1")?;
1379
1380        let new_admin_group_name = "admin_subgroup";
1381        let new_admin_group_desc = "admin_subgroup_description";
1382        let new_admin_group_id = 2;
1383        // TODO: Figure out why SQLite makes the group_id = 2, but with Postgres it's 3.
1384        assert!(
1385            db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
1386                .await
1387                .context("failed to create admin sub-group")?
1388                >= new_admin_group_id,
1389            "New group didn't have the expected ID, expected >= {new_admin_group_id}"
1390        );
1391
1392        let groups = db.list_groups().await.context("failed to list groups")?;
1393        assert_eq!(
1394            groups.len(),
1395            3,
1396            "Three groups were created, yet there are {} groups",
1397            groups.len()
1398        );
1399        eprintln!("DB has {} groups:", groups.len());
1400        for group in groups {
1401            println!("{group}");
1402            if group.id == new_admin_group_id {
1403                assert_eq!(group.parent, Some("admin".to_string()));
1404            }
1405            if group.id == 1 {
1406                let test_user_str = String::from(new_user_uname);
1407                let mut found = false;
1408                for member in group.members {
1409                    if member.uname == test_user_str {
1410                        found = true;
1411                        break;
1412                    }
1413                }
1414                assert!(found, "new user {test_user_str} wasn't in the group");
1415            }
1416        }
1417
1418        let default_source_name = "default_source".to_string();
1419        let default_source_id = db
1420            .create_source(
1421                &default_source_name,
1422                Some("desc_default_source"),
1423                None,
1424                Local::now(),
1425                true,
1426                Some(false),
1427            )
1428            .await
1429            .context("failed to create source `default_source`")?;
1430
1431        db.add_group_to_source(1, default_source_id)
1432            .await
1433            .context("failed to add group 1 to source 1")?;
1434
1435        let another_source_name = "another_source".to_string();
1436        let another_source_id = db
1437            .create_source(
1438                &another_source_name,
1439                Some("yet another file source"),
1440                None,
1441                Local::now(),
1442                true,
1443                Some(false),
1444            )
1445            .await
1446            .context("failed to create source `another_source`")?;
1447
1448        let empty_source_name = "empty_source".to_string();
1449        db.create_source(
1450            &empty_source_name,
1451            Some("empty and unused file source"),
1452            None,
1453            Local::now(),
1454            true,
1455            Some(false),
1456        )
1457        .await
1458        .context("failed to create source `another_source`")?;
1459
1460        db.add_group_to_source(1, another_source_id)
1461            .await
1462            .context("failed to add group 1 to source 1")?;
1463
1464        let sources = db.list_sources().await.context("failed to list sources")?;
1465        eprintln!("DB has {} sources:", sources.len());
1466        for source in sources {
1467            println!("{source}");
1468            assert_eq!(source.files, 0);
1469            if source.id == default_source_id || source.id == another_source_id {
1470                assert_eq!(
1471                    source.groups, 1,
1472                    "default source {default_source_name} should have 1 group"
1473                );
1474            } else {
1475                assert_eq!(source.groups, 0, "groups should zero (empty)");
1476            }
1477        }
1478
1479        let uid = db
1480            .get_uid(&new_user_api_key)
1481            .await
1482            .context("failed to user uid from apikey")?;
1483        let user_info = db
1484            .get_user_info(uid)
1485            .await
1486            .context("failed to get user's available groups and sources")?;
1487        assert!(user_info.sources.contains(&default_source_name));
1488        assert!(!user_info.is_admin);
1489        println!("UserInfoResponse: {user_info:?}");
1490
1491        assert!(
1492            db.allowed_user_source(1, default_source_id)
1493                .await
1494                .context(format!(
1495                    "failed to check that user 1 has access to source {default_source_id}"
1496                ))?,
1497            "User 1 should should have had access to source {default_source_id}"
1498        );
1499
1500        assert!(
1501            !db.allowed_user_source(1, 5)
1502                .await
1503                .context("failed to check that user 1 has access to source 5")?,
1504            "User 1 should should not have had access to source 5"
1505        );
1506
1507        let test_label_id = db
1508            .create_label("TestLabel", None)
1509            .await
1510            .context("failed to create test label")?;
1511        let test_elf_label_id = db
1512            .create_label("TestELF", Some(test_label_id))
1513            .await
1514            .context("failed to create test label")?;
1515
1516        let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
1517        let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
1518        let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
1519
1520        let known_type =
1521            KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
1522        assert!(known_type.is_exec(), "ELF should be executable");
1523        eprintln!("ELF type ID: {elf_type}");
1524
1525        let file_addition = db
1526            .add_file(
1527                &test_elf_meta,
1528                known_type.clone(),
1529                1,
1530                default_source_id,
1531                elf_type,
1532                None,
1533            )
1534            .await
1535            .context("failed to insert a test elf")?;
1536        assert!(file_addition.is_new, "File should have been added");
1537        eprintln!("Added ELF to the DB");
1538        db.label_file(file_addition.file_id, test_elf_label_id)
1539            .await
1540            .context("failed to label file")?;
1541
1542        // Search with several fields
1543        let partial_search = SearchRequest {
1544            search: SearchType::Search(SearchRequestParameters {
1545                partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1546                labels: Some(vec![String::from("TestELF")]),
1547                file_type: Some(String::from("ELF")),
1548                magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1549                ..Default::default()
1550            }),
1551        };
1552        assert!(partial_search.is_valid());
1553        let partial_search_response = db.partial_search(1, partial_search).await?;
1554        assert_eq!(partial_search_response.hashes.len(), 1);
1555        assert_eq!(
1556            partial_search_response.hashes[0],
1557            "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1558        );
1559
1560        // Ensure we get a partial result just with the magic
1561        let partial_search = SearchRequest {
1562            search: SearchType::Search(SearchRequestParameters {
1563                partial_hash: None,
1564                labels: None,
1565                file_type: None,
1566                magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1567                ..Default::default()
1568            }),
1569        };
1570        assert!(partial_search.is_valid());
1571        let partial_search_response = db.partial_search(1, partial_search).await?;
1572        assert_eq!(partial_search_response.hashes.len(), 1);
1573        assert_eq!(
1574            partial_search_response.hashes[0],
1575            "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1576        );
1577
1578        // Should be valid yet return nothing since it's the wrong file type
1579        let partial_search = SearchRequest {
1580            search: SearchType::Search(SearchRequestParameters {
1581                partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1582                file_type: Some(String::from("PE32")),
1583                ..Default::default()
1584            }),
1585        };
1586        assert!(partial_search.is_valid());
1587        let partial_search_response = db.partial_search(1, partial_search).await?;
1588        assert_eq!(partial_search_response.hashes.len(), 0);
1589
1590        let partial_search = SearchRequest {
1591            search: SearchType::Search(SearchRequestParameters {
1592                file_name: Some("ppc64".into()),
1593                ..Default::default()
1594            }),
1595        };
1596        assert!(partial_search.is_valid());
1597        let partial_search_response = db.partial_search(1, partial_search).await?;
1598        assert_eq!(partial_search_response.hashes.len(), 1);
1599
1600        // Invalid search request should return empty results
1601        let partial_search = SearchRequest {
1602            search: SearchType::Search(SearchRequestParameters::default()),
1603        };
1604        assert!(!partial_search.is_valid());
1605        let partial_search_response = db.partial_search(1, partial_search).await?;
1606        assert!(partial_search_response.hashes.is_empty());
1607
1608        // Invalid search pagination should return empty results
1609        let partial_search = SearchRequest {
1610            search: SearchType::Continuation(Uuid::default()),
1611        };
1612        assert!(partial_search.is_valid());
1613        let partial_search_response = db.partial_search(1, partial_search).await?;
1614        assert!(partial_search_response.hashes.is_empty());
1615
1616        // Ensure a type representing an unknown type isn't found
1617        assert!(db
1618            .get_type_id_for_bytes(include_bytes!("../../../../MDB_Logo.ico"))
1619            .await
1620            .is_err());
1621
1622        assert!(
1623            db.add_file(
1624                &test_elf_meta,
1625                known_type.clone(),
1626                ro_user_uid,
1627                default_source_id,
1628                elf_type,
1629                None
1630            )
1631            .await
1632            .is_err(),
1633            "Read-only user should not be able to add a file"
1634        );
1635
1636        let mut test_elf_meta_different_name = test_elf_meta.clone();
1637        test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
1638
1639        assert!(
1640            !db.add_file(
1641                &test_elf_meta_different_name,
1642                known_type,
1643                1,
1644                another_source_id,
1645                elf_type,
1646                None
1647            )
1648            .await
1649            .context("failed to insert a test elf again for a different source")?
1650            .is_new
1651        );
1652
1653        let sources = db
1654            .list_sources()
1655            .await
1656            .context("failed to re-list sources")?;
1657        eprintln!(
1658            "DB has {} sources, and a file was added twice:",
1659            sources.len()
1660        );
1661        println!("We should have two sources with one file each, yet only one ELF.");
1662        for source in sources {
1663            println!("{source}");
1664            if source.id == default_source_id || source.id == another_source_id {
1665                assert_eq!(source.files, 1);
1666            } else {
1667                assert_eq!(source.files, 0, "groups should zero (empty)");
1668            }
1669        }
1670
1671        assert!(!db
1672            .get_user_sources(1)
1673            .await
1674            .expect("failed to get user 1's sources")
1675            .sources
1676            .is_empty());
1677
1678        let file_types_counts = db
1679            .file_types_counts()
1680            .await
1681            .context("failed to get file types and counts")?;
1682        for (name, count) in file_types_counts {
1683            println!("{name}: {count}");
1684            assert_eq!(name, "ELF");
1685            assert_eq!(count, 1);
1686        }
1687
1688        let mut test_elf_modified = test_elf.clone();
1689        let random_bytes = Uuid::new_v4();
1690        let mut random_bytes = random_bytes.into_bytes().to_vec();
1691        test_elf_modified.append(&mut random_bytes);
1692        let similarity_request = generate_similarity_request(&test_elf_modified);
1693        let similarity_response = db
1694            .find_similar_samples(1, &similarity_request.hashes)
1695            .await
1696            .context("failed to get similarity response")?;
1697        eprintln!("Similarity response: {similarity_response:?}");
1698        let similarity_response = similarity_response.first().unwrap();
1699        assert_eq!(
1700            similarity_response.sha256, test_elf_meta.sha256,
1701            "Similarity response should have had the hash of the original ELF"
1702        );
1703        for (algo, sim) in &similarity_response.algorithms {
1704            match algo {
1705                malwaredb_api::SimilarityHashType::LZJD => {
1706                    assert!(*sim > 0.0f32);
1707                }
1708                malwaredb_api::SimilarityHashType::SSDeep => {
1709                    assert!(*sim > 80.0f32);
1710                }
1711                malwaredb_api::SimilarityHashType::TLSH => {
1712                    assert!(*sim <= 20f32);
1713                }
1714                _ => {}
1715            }
1716        }
1717
1718        let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1)
1719            .context("failed to get `HashType::SHA1` from string")?;
1720        let response_sha256 = db
1721            .retrieve_sample(1, &test_elf_hashtype)
1722            .await
1723            .context("could not get SHA-256 hash from test sample")
1724            .unwrap();
1725        assert_eq!(response_sha256, test_elf_meta.sha256);
1726
1727        let test_bogus_hash = HashType::try_from(String::from(
1728            "d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0",
1729        ))
1730        .context("failed to get `HashType` from static string")?;
1731        assert!(
1732            db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
1733            "Getting a file with a bogus hash should have failed."
1734        );
1735
1736        let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
1737        let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
1738        let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
1739
1740        let known_type =
1741            KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
1742
1743        assert!(
1744            db.add_file(
1745                &test_pdf_meta,
1746                known_type,
1747                1,
1748                default_source_id,
1749                pdf_type,
1750                None
1751            )
1752            .await
1753            .context("failed to insert a test pdf")?
1754            .is_new
1755        );
1756        eprintln!("Added PDF to the DB");
1757
1758        let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
1759        let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
1760        let rtf_type = db
1761            .get_type_id_for_bytes(&test_rtf)
1762            .await
1763            .context("failed to get file type id for rtf")?;
1764
1765        let known_type =
1766            KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
1767
1768        assert!(
1769            db.add_file(
1770                &test_rtf_meta,
1771                known_type,
1772                1,
1773                default_source_id,
1774                rtf_type,
1775                None
1776            )
1777            .await
1778            .context("failed to insert a test rtf")?
1779            .is_new
1780        );
1781        eprintln!("Added RTF to the DB");
1782
1783        let report = db
1784            .get_sample_report(1, &HashType::try_from(test_rtf_meta.sha256.clone())?)
1785            .await
1786            .context("failed to get report for test rtf")?;
1787        assert!(report
1788            .clone()
1789            .filecommand
1790            .unwrap()
1791            .contains("Rich Text Format"));
1792        println!("Report: {report}");
1793
1794        assert!(db
1795            .get_sample_report(999, &HashType::try_from(test_rtf_meta.sha256)?)
1796            .await
1797            .is_err());
1798
1799        #[cfg(feature = "vt")]
1800        {
1801            assert!(report.vt.is_some());
1802            let files_needing_vt = db
1803                .files_without_vt_records(10)
1804                .await
1805                .context("failed to get files without VT records")?;
1806            assert!(files_needing_vt.len() > 2);
1807            println!(
1808                "{} files needing VT data: {files_needing_vt:?}",
1809                files_needing_vt.len()
1810            );
1811        }
1812
1813        #[cfg(not(feature = "vt"))]
1814        {
1815            assert!(report.vt.is_none());
1816        }
1817
1818        let reset = db
1819            .reset_api_keys()
1820            .await
1821            .context("failed to reset all API keys")?;
1822        eprintln!("Cleared {reset} api keys.");
1823
1824        let db_info = db.db_info().await.context("failed to get database info")?;
1825        eprintln!("DB Info: {db_info:?}");
1826
1827        let data_types = db
1828            .get_known_data_types()
1829            .await
1830            .context("failed to get data types")?;
1831        for data_type in data_types {
1832            println!("{data_type:?}");
1833        }
1834
1835        let sources = db
1836            .list_sources()
1837            .await
1838            .context("failed to list sources second time")?;
1839        eprintln!("DB has {} sources:", sources.len());
1840        for source in sources {
1841            println!("{source}");
1842        }
1843
1844        let file_types_counts = db
1845            .file_types_counts()
1846            .await
1847            .context("failed to get file types and counts")?;
1848        for (name, count) in file_types_counts {
1849            println!("{name}: {count}");
1850            assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
1851        }
1852
1853        let fatmacho =
1854            include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
1855                .to_vec();
1856        let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
1857        let fatmacho_type = db
1858            .get_type_id_for_bytes(&fatmacho)
1859            .await
1860            .context("failed to get file type for Fat Mach-O")?;
1861        let known_type = KnownType::new(&fatmacho)
1862            .context("failed to parse Fat Mach-O from type crate's test data")?;
1863
1864        assert!(
1865            db.add_file(
1866                &fatmacho_meta,
1867                known_type,
1868                1,
1869                default_source_id,
1870                fatmacho_type,
1871                None
1872            )
1873            .await
1874            .context("failed to insert a test Fat Mach-O")?
1875            .is_new
1876        );
1877        eprintln!("Added Fat Mach-O to the DB");
1878
1879        let file_types_counts = db
1880            .file_types_counts()
1881            .await
1882            .context("failed to get file types and counts")?;
1883        for (name, count) in &file_types_counts {
1884            println!("{name}: {count}");
1885        }
1886
1887        assert_eq!(
1888            *file_types_counts.get("Mach-O").unwrap(),
1889            4,
1890            "Expected 4 Mach-O files, got {:?}",
1891            file_types_counts.get("Mach-O")
1892        );
1893
1894        let malware_label_id = db
1895            .create_label(MALWARE_LABEL, None)
1896            .await
1897            .context("failed to create first label")?;
1898        let ransomware_label_id = db
1899            .create_label(RANSOMWARE_LABEL, Some(malware_label_id))
1900            .await
1901            .context("failed to create malware sub-label")?;
1902        let labels = db.get_labels().await.context("failed to get labels")?;
1903
1904        assert_eq!(labels.len(), 4, "Expected 4 labels, got {labels}");
1905        for label in labels.0 {
1906            if label.name == RANSOMWARE_LABEL {
1907                assert_eq!(label.id, ransomware_label_id);
1908                assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
1909            }
1910        }
1911
1912        // Use this file as a stand-in for an unknown file type
1913        let source_code = include_bytes!("mod.rs");
1914        let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
1915        let known_type =
1916            KnownType::new(source_code).context("failed to source code to get `Unknown` type")?;
1917
1918        assert!(matches!(known_type, KnownType::Unknown(_)));
1919
1920        let unknown_type: Vec<FileType> = db
1921            .get_known_data_types()
1922            .await?
1923            .into_iter()
1924            .filter(|t| t.name.eq_ignore_ascii_case("unknown"))
1925            .collect();
1926        let unknown_type_id = unknown_type.first().unwrap().id;
1927        assert!(db.get_type_id_for_bytes(source_code).await.is_err());
1928        db.enable_keep_unknown_files()
1929            .await
1930            .context("failed to enable keeping of unknown files")?;
1931        let source_type = db
1932            .get_type_id_for_bytes(source_code)
1933            .await
1934            .context("failed to type id for source code unknown type example")?;
1935        assert_eq!(source_type, unknown_type_id);
1936        eprintln!("Unknown file type ID: {source_type}");
1937        assert!(
1938            db.add_file(
1939                &source_meta,
1940                known_type,
1941                1,
1942                default_source_id,
1943                unknown_type_id,
1944                None
1945            )
1946            .await
1947            .context("failed to add Rust source code file")?
1948            .is_new
1949        );
1950        eprintln!("Added Rust source code to the DB");
1951
1952        db.reset_own_api_key(0)
1953            .await
1954            .context("failed to clear own API key uid 0")?;
1955
1956        db.deactivate_user(0)
1957            .await
1958            .context("failed to clear password and API key for uid 0")?;
1959
1960        Ok(())
1961    }
1962}