malwaredb_server/db/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Postgres is the database used by MalwareDB.
//! However, SQLite is used for unit testing and may be used for small instances of MalwareDB.
//! SQLite support is enabled with the `sqlite` feature flag. When using SQLite, MalwareDB itself
//! calculates the distances for the similarity hashes.
7
8/// Malware DB Administrative functions
9#[cfg(any(test, feature = "admin"))]
10mod admin;
11/// Postgres functions
12mod pg;
13
14/// `SQLite` functionality
15#[cfg(any(test, feature = "sqlite"))]
16mod sqlite;
17
18/// Custom `SQLite` functions
19#[cfg(any(test, feature = "sqlite"))]
20mod sqlite_functions;
21
22/// File Metadata convenience data structure
23pub mod types;
24
25#[cfg(any(test, feature = "admin"))]
26use crate::crypto::EncryptionOption;
27use crate::crypto::FileEncryption;
28use crate::db::pg::Postgres;
29#[cfg(any(test, feature = "sqlite"))]
30use crate::db::sqlite::Sqlite;
31use crate::db::types::{FileMetadata, FileType};
32use malwaredb_api::{
33    digest::HashType, GetUserInfoResponse, Labels, SearchRequest, SearchResponse, Sources,
34};
35use malwaredb_types::KnownType;
36
37use std::collections::HashMap;
38use std::path::PathBuf;
39
40use anyhow::{bail, ensure, Result};
41use argon2::password_hash::{rand_core::OsRng, SaltString};
42use argon2::{Argon2, PasswordHasher};
43#[cfg(any(test, feature = "admin"))]
44use chrono::Local;
45#[cfg(feature = "vt")]
46use malwaredb_virustotal::filereport::ScanResultAttributes;
47
/// The maximum number of results returned for a partial hash and/or partial file name search,
/// capped to prevent performance issues
pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
50
/// Database connection handle
///
/// Each variant wraps the connection type of one supported database backend.
#[derive(Debug)]
pub enum DatabaseType {
    /// Postgres database
    Postgres(Postgres),

    /// `SQLite` database, only compiled for tests or when the `sqlite` feature is enabled
    #[cfg(any(test, feature = "sqlite"))]
    SQLite(Sqlite),
}
61
/// Version information and basic stats for the database
#[derive(Debug)]
pub struct DatabaseInformation {
    /// Version string of the database
    pub version: String,

    /// Human-readable database size (already formatted as text, not a byte count)
    pub size: String,

    /// Number of file samples in Malware DB
    pub num_files: u64,

    /// Number of user accounts
    pub num_users: u32,

    /// Number of user groups
    pub num_groups: u32,

    /// Number of sample sources
    pub num_sources: u32,
}
83
/// Data returned when adding a new sample
///
/// The type is a small plain-data record, so it derives the common traits which allow it to be
/// logged, copied, and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FileAddedResult {
    /// File ID
    pub file_id: u64,

    /// Whether the file was added as a new entry.
    /// This is false if the sample was already known to Malware DB.
    pub is_new: bool,
}
93
/// Malware DB configuration which is stored in the database
#[derive(Debug)]
pub struct MDBConfig {
    /// The name of this instance of Malware DB
    pub name: String,

    /// Whether samples are stored compressed
    pub compression: bool,

    /// Whether Malware DB can send samples to Virus Total
    pub send_samples_to_vt: bool,

    /// If Malware DB should keep unknown files
    pub keep_unknown_files: bool,

    /// If samples are to be encrypted, the ID of the key to use.
    // NOTE(review): presumably `None` means samples are not encrypted -- confirm in the backends.
    pub(crate) default_key: Option<u32>,
}
112
/// VT record information for files in Malware DB
///
/// Only compiled when the `vt` feature is enabled.
#[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Files marked as clean
    pub clean_records: u32,

    /// Files marked as malicious
    pub hits_records: u32,

    /// Files without VT records
    pub files_without_records: u32,
}
127
128impl DatabaseType {
129    /// Get a database connection from a configuration string
130    ///
131    /// # Errors
132    ///
133    /// * If there's a connectivity issue to Postgres, an error will result
134    /// * If the `SQLite` file cannot be created or opened, an error will result
135    /// * If `SQLite` is the type but Malware DB wasn't compiled with the sqlite feature, an error will result
136    /// * If the format or database type isn't known, an error will result
137    pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
138        #[cfg(any(test, feature = "sqlite"))]
139        if arg.starts_with("file:") {
140            let new_conn_str = arg.trim_start_matches("file:");
141            return Ok(DatabaseType::SQLite(Sqlite::new(new_conn_str)?));
142        }
143
144        if arg.starts_with("postgres") {
145            let new_conn_str = arg.trim_start_matches("postgres");
146            return Ok(DatabaseType::Postgres(
147                Postgres::new(new_conn_str, server_ca).await?,
148            ));
149        }
150
151        bail!("unknown database type `{arg}`")
152    }
153
154    /// Set the flag allowing uploads to Virus Total.
155    ///
156    /// # Errors
157    ///
158    /// If there's a connectivity issue to Postgres, an error will result
159    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
160    #[cfg(feature = "vt")]
161    pub async fn enable_vt_upload(&self) -> Result<()> {
162        match self {
163            DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
164            #[cfg(any(test, feature = "sqlite"))]
165            DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
166        }
167    }
168
169    /// Set the flag preventing uploads to Virus Total.
170    ///
171    /// # Errors
172    ///
173    /// If there's a connectivity issue to Postgres, an error will result
174    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
175    #[cfg(feature = "vt")]
176    pub async fn disable_vt_upload(&self) -> Result<()> {
177        match self {
178            DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
179            #[cfg(any(test, feature = "sqlite"))]
180            DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
181        }
182    }
183
184    /// Get the SHA-256 hashes of the files which don't have VT records
185    ///
186    /// # Errors
187    ///
188    /// If there's a connectivity issue to Postgres, an error will result
189    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
190    #[cfg(feature = "vt")]
191    pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
192        match self {
193            DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
194            #[cfg(any(test, feature = "sqlite"))]
195            DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
196        }
197    }
198
199    /// Store the VT results: AV hits and detailed report, or lack of any AV hits
200    ///
201    /// # Errors
202    ///
203    /// If there's a connectivity issue to Postgres, an error will result
204    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
205    #[cfg(feature = "vt")]
206    pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
207        match self {
208            DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
209            #[cfg(any(test, feature = "sqlite"))]
210            DatabaseType::SQLite(sl) => sl.store_vt_record(results),
211        }
212    }
213
214    /// Quick statistics regarding the data contained for VT information for our samples
215    ///
216    /// # Errors
217    ///
218    /// If there's a connectivity issue to Postgres, an error will result
219    #[cfg_attr(docsrs, doc(cfg(feature = "vt")))]
220    #[cfg(feature = "vt")]
221    pub async fn get_vt_stats(&self) -> Result<VtStats> {
222        match self {
223            DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
224            #[cfg(any(test, feature = "sqlite"))]
225            DatabaseType::SQLite(sl) => sl.get_vt_stats(),
226        }
227    }
228
229    /// Get the configuration which is stored in the database
230    ///
231    /// # Errors
232    ///
233    /// If there's a connectivity issue to Postgres, an error will result
234    pub async fn get_config(&self) -> Result<MDBConfig> {
235        match self {
236            DatabaseType::Postgres(pg) => pg.get_config().await,
237            #[cfg(any(test, feature = "sqlite"))]
238            DatabaseType::SQLite(sl) => sl.get_config(),
239        }
240    }
241
242    /// Check user credentials, return the API key. Generate if it doesn't exist.
243    ///
244    /// # Errors
245    ///
246    /// * If there's a connectivity issue to Postgres, an error will result
247    /// * If the username and/or password aren't correct, an error will result
248    pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
249        match self {
250            DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
251            #[cfg(any(test, feature = "sqlite"))]
252            DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
253        }
254    }
255
256    /// Get the user's ID from their API key
257    ///
258    /// # Errors
259    ///
260    /// * If there's a connectivity issue to Postgres, an error will result
261    /// * If the api key isn't valid, an error will result
262    pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
263        ensure!(!apikey.is_empty(), "API key was empty");
264        match self {
265            DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
266            #[cfg(any(test, feature = "sqlite"))]
267            DatabaseType::SQLite(sl) => sl.get_uid(apikey),
268        }
269    }
270
271    /// Retrieve information about the database
272    ///
273    /// # Errors
274    ///
275    /// * If there's a connectivity issue to Postgres, an error will result
276    pub async fn db_info(&self) -> Result<DatabaseInformation> {
277        match self {
278            DatabaseType::Postgres(pg) => pg.db_info().await,
279            #[cfg(any(test, feature = "sqlite"))]
280            DatabaseType::SQLite(sl) => sl.db_info(),
281        }
282    }
283
284    /// Retrieve the names of the groups and sources the user is part of and has access to
285    ///
286    /// # Errors
287    ///
288    /// * If there's a connectivity issue to Postgres, an error will result
289    /// * If the user ID isn't valid, an error will result
290    pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
291        match self {
292            DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
293            #[cfg(any(test, feature = "sqlite"))]
294            DatabaseType::SQLite(sl) => sl.get_user_info(uid),
295        }
296    }
297
298    /// Retrieve the source information available to the specified user
299    ///
300    /// # Errors
301    ///
302    /// * If there's a connectivity issue to Postgres, an error will result
303    /// * If the user ID isn't valid, an error will result
304    pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
305        match self {
306            DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
307            #[cfg(any(test, feature = "sqlite"))]
308            DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
309        }
310    }
311
312    /// Let the user clear their own API key to log out from all systems
313    ///
314    /// # Errors
315    ///
316    /// * If there's a connectivity issue to Postgres, an error will result
317    /// * If the user ID isn't valid, an error will result
318    pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
319        match self {
320            DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
321            #[cfg(any(test, feature = "sqlite"))]
322            DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
323        }
324    }
325
326    /// Retrieve the supported data type information
327    ///
328    /// # Errors
329    ///
330    /// If there's a connectivity issue to Postgres, an error will result
331    pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
332        match self {
333            DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
334            #[cfg(any(test, feature = "sqlite"))]
335            DatabaseType::SQLite(sl) => sl.get_known_data_types(),
336        }
337    }
338
339    /// Get all labels from Malware DB
340    ///
341    /// # Errors
342    ///
343    /// If there's a connectivity issue to Postgres, an error will result
344    pub async fn get_labels(&self) -> Result<Labels> {
345        match self {
346            DatabaseType::Postgres(pg) => pg.get_labels().await,
347            #[cfg(any(test, feature = "sqlite"))]
348            DatabaseType::SQLite(sl) => sl.get_labels(),
349        }
350    }
351
352    /// Get the corresponding type ID for a buffer representing a file
353    ///
354    /// # Errors
355    ///
356    /// If there's a connectivity issue to Postgres, an error will result
357    pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
358        match self {
359            DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
360            #[cfg(any(test, feature = "sqlite"))]
361            DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
362        }
363    }
364
365    /// Check that a user has been granted access data for the specific source
366    ///
367    /// # Errors
368    ///
369    /// * If there's a connectivity issue to Postgres, an error will result
370    /// * If the user or source ID(s) aren't valid, an error will result
371    pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
372        match self {
373            DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
374            #[cfg(any(test, feature = "sqlite"))]
375            DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
376        }
377    }
378
379    /// Check to see if the user is an administrator. The user must be a member of the
380    /// admin group (group ID 0), or a one group below (a group with the parent group id of 0).
381    ///
382    /// # Errors
383    ///
384    /// * If there's a connectivity issue to Postgres, an error will result
385    /// * If the user ID isn't valid, an error will result
386    pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
387        match self {
388            DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
389            #[cfg(any(test, feature = "sqlite"))]
390            DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
391        }
392    }
393
394    /// Add a file's metadata to the database, returning true if this is a new entry
395    ///
396    /// # Errors
397    ///
398    /// * If there's a connectivity issue to Postgres, an error will result
399    /// * If the source doesn't exist or the user isn't part of a member group, an error will result
400    pub async fn add_file(
401        &self,
402        meta: &FileMetadata,
403        known_type: KnownType<'_>,
404        uid: u32,
405        sid: u32,
406        ftype: u32,
407        parent: Option<u64>,
408    ) -> Result<FileAddedResult> {
409        match self {
410            DatabaseType::Postgres(pg) => {
411                pg.add_file(meta, known_type, uid, sid, ftype, parent).await
412            }
413            #[cfg(any(test, feature = "sqlite"))]
414            DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
415        }
416    }
417
418    /// Search for allowed samples based on partial search and/or file name
419    ///
420    /// # Errors
421    ///
422    /// * If there's a connectivity issue to Postgres, an error will result
423    pub async fn partial_search(&self, uid: u32, search: SearchRequest) -> Result<SearchResponse> {
424        match self {
425            DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
426            #[cfg(any(test, feature = "sqlite"))]
427            DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
428        }
429    }
430
431    /// Delete old pagination searches
432    ///
433    /// # Errors
434    ///
435    /// An error would occur if the Postgres server couldn't be reached.
436    pub async fn cleanup(&self) -> Result<u64> {
437        match self {
438            DatabaseType::Postgres(pg) => pg.cleanup().await,
439            #[cfg(any(test, feature = "sqlite"))]
440            DatabaseType::SQLite(sl) => sl.cleanup(),
441        }
442    }
443
444    /// Retrieve the SHA-256 hash of the sample while checking that the user is permitted
445    /// to access to it
446    ///
447    /// # Errors
448    ///
449    /// * If there's a connectivity issue to Postgres, an error will result
450    /// * If the file doesn't exist or the user isn't allowed access, an error will result
451    pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
452        match self {
453            DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
454            #[cfg(any(test, feature = "sqlite"))]
455            DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
456        }
457    }
458
459    /// Retrieve a report for a given sample, if allowed.
460    ///
461    /// # Errors
462    ///
463    /// * If there's a connectivity issue to Postgres, an error will result
464    /// * If the file doesn't exist or the user isn't allowed access, an error will result
465    pub async fn get_sample_report(
466        &self,
467        uid: u32,
468        hash: &HashType,
469    ) -> Result<malwaredb_api::Report> {
470        match self {
471            DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
472            #[cfg(any(test, feature = "sqlite"))]
473            DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
474        }
475    }
476
477    /// Given a collection of similarity hashes, find samples which are similar.
478    ///
479    /// # Errors
480    ///
481    /// If there's a connectivity issue to Postgres, an error will result
482    pub async fn find_similar_samples(
483        &self,
484        uid: u32,
485        sim: &[(malwaredb_api::SimilarityHashType, String)],
486    ) -> Result<Vec<malwaredb_api::SimilarSample>> {
487        match self {
488            DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
489            #[cfg(any(test, feature = "sqlite"))]
490            DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
491        }
492    }
493
494    // Private functions
495
496    /// Get the file encryption keys
497    pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
498        match self {
499            DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
500            #[cfg(any(test, feature = "sqlite"))]
501            DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
502        }
503    }
504
505    /// Get the key and AES nonce for a specific file identified by SHA-256 hash
506    pub(crate) async fn get_file_encryption_key_id(
507        &self,
508        hash: &str,
509    ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
510        match self {
511            DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
512            #[cfg(any(test, feature = "sqlite"))]
513            DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
514        }
515    }
516
517    /// Set the AES nonce for a file specified by SHA-256 hash, a `None` value removes it.
518    pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
519        match self {
520            DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
521            #[cfg(any(test, feature = "sqlite"))]
522            DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
523        }
524    }
525
526    // Administrative functions
527
528    /// Set the instance name
529    ///
530    /// # Errors
531    ///
532    /// If there's a connectivity issue to Postgres, an error will result
533    #[cfg(any(test, feature = "admin"))]
534    pub async fn set_name(&self, name: &str) -> Result<()> {
535        match self {
536            DatabaseType::Postgres(pg) => pg.set_name(name).await,
537            #[cfg(any(test, feature = "sqlite"))]
538            DatabaseType::SQLite(sl) => sl.set_name(name),
539        }
540    }
541
542    /// Set the compression flag
543    ///
544    /// # Errors
545    ///
546    /// If there's a connectivity issue to Postgres, an error will result
547    #[cfg(any(test, feature = "admin"))]
548    pub async fn enable_compression(&self) -> Result<()> {
549        match self {
550            DatabaseType::Postgres(pg) => pg.enable_compression().await,
551            #[cfg(any(test, feature = "sqlite"))]
552            DatabaseType::SQLite(sl) => sl.enable_compression(),
553        }
554    }
555
556    /// Unset the compression flag, does not go and decompress files already compressed!
557    ///
558    /// # Errors
559    ///
560    /// If there's a connectivity issue to Postgres, an error will result
561    #[cfg(any(test, feature = "admin"))]
562    pub async fn disable_compression(&self) -> Result<()> {
563        match self {
564            DatabaseType::Postgres(pg) => pg.disable_compression().await,
565            #[cfg(any(test, feature = "sqlite"))]
566            DatabaseType::SQLite(sl) => sl.disable_compression(),
567        }
568    }
569
570    /// Set the keep unknown files flag
571    ///
572    /// # Errors
573    ///
574    /// If there's a connectivity issue to Postgres, an error will result
575    #[cfg(any(test, feature = "admin"))]
576    pub async fn enable_keep_unknown_files(&self) -> Result<()> {
577        match self {
578            DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
579            #[cfg(any(test, feature = "sqlite"))]
580            DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
581        }
582    }
583
584    /// Unset the keep unknown files flag, does not go and remove unknown files!
585    ///
586    /// # Errors
587    ///
588    /// If there's a connectivity issue to Postgres, an error will result
589    #[cfg(any(test, feature = "admin"))]
590    pub async fn disable_keep_unknown_files(&self) -> Result<()> {
591        match self {
592            DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
593            #[cfg(any(test, feature = "sqlite"))]
594            DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
595        }
596    }
597
598    /// Add an encryption key to the database, set it as the default, and return the key ID
599    ///
600    /// # Errors
601    ///
602    /// * If there is a connectivity with Postgres, an error will result
603    #[cfg(any(test, feature = "admin"))]
604    pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
605        match self {
606            DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
607            #[cfg(any(test, feature = "sqlite"))]
608            DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
609        }
610    }
611
612    /// Get the key ID and algorithm names
613    ///
614    /// # Errors
615    ///
616    /// If there is a connectivity with Postgres, an error will result
617    #[cfg(any(test, feature = "admin"))]
618    pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
619        match self {
620            DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
621            #[cfg(any(test, feature = "sqlite"))]
622            DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
623        }
624    }
625
626    /// Create a user account, return the user ID.
627    ///
628    /// # Errors
629    ///
630    /// If there's a connectivity issue to Postgres, an error will result
631    #[allow(clippy::too_many_arguments)]
632    #[cfg(any(test, feature = "admin"))]
633    pub async fn create_user(
634        &self,
635        uname: &str,
636        fname: &str,
637        lname: &str,
638        email: &str,
639        password: Option<String>,
640        organisation: Option<&String>,
641        readonly: bool,
642    ) -> Result<u32> {
643        match self {
644            DatabaseType::Postgres(pg) => {
645                pg.create_user(uname, fname, lname, email, password, organisation, readonly)
646                    .await
647            }
648            #[cfg(any(test, feature = "sqlite"))]
649            DatabaseType::SQLite(sl) => {
650                sl.create_user(uname, fname, lname, email, password, organisation, readonly)
651            }
652        }
653    }
654
655    /// Clear all API keys, either in case of suspected activity, or part of policy
656    ///
657    /// # Errors
658    ///
659    /// If there's a connectivity issue to Postgres, an error will result
660    #[cfg(any(test, feature = "admin"))]
661    pub async fn reset_api_keys(&self) -> Result<u64> {
662        match self {
663            DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
664            #[cfg(any(test, feature = "sqlite"))]
665            DatabaseType::SQLite(sl) => sl.reset_api_keys(),
666        }
667    }
668
669    /// Set a user's password
670    ///
671    /// # Errors
672    ///
673    /// * If there's a connectivity issue to Postgres, an error will result
674    /// * If the user doesn't exist, an error will result
675    #[cfg(any(test, feature = "admin"))]
676    pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
677        match self {
678            DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
679            #[cfg(any(test, feature = "sqlite"))]
680            DatabaseType::SQLite(sl) => sl.set_password(uname, password),
681        }
682    }
683
684    /// Get the complete list of users
685    ///
686    /// # Errors
687    ///
688    /// If there's a connectivity issue to Postgres, an error will result
689    #[cfg(any(test, feature = "admin"))]
690    pub async fn list_users(&self) -> Result<Vec<admin::User>> {
691        match self {
692            DatabaseType::Postgres(pg) => pg.list_users().await,
693            #[cfg(any(test, feature = "sqlite"))]
694            DatabaseType::SQLite(sl) => sl.list_users(),
695        }
696    }
697
698    /// Get the ID of a group from name
699    ///
700    /// # Errors
701    ///
702    /// * If there's a connectivity issue to Postgres, an error will result
703    /// * If the group name isn't valid, an error will result
704    #[cfg(any(test, feature = "admin"))]
705    pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
706        match self {
707            DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
708            #[cfg(any(test, feature = "sqlite"))]
709            DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
710        }
711    }
712
713    /// Update the record for a group
714    ///
715    /// # Errors
716    ///
717    /// * If there's a connectivity issue to Postgres, an error will result
718    /// * If the group doesn't exist, an error will result
719    #[cfg(any(test, feature = "admin"))]
720    pub async fn edit_group(
721        &self,
722        gid: u32,
723        name: &str,
724        desc: &str,
725        parent: Option<u32>,
726    ) -> Result<()> {
727        match self {
728            DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
729            #[cfg(any(test, feature = "sqlite"))]
730            DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
731        }
732    }
733
734    /// Get the complete list of groups
735    ///
736    /// # Errors
737    ///
738    /// If there's a connectivity issue to Postgres, an error will result
739    #[cfg(any(test, feature = "admin"))]
740    pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
741        match self {
742            DatabaseType::Postgres(pg) => pg.list_groups().await,
743            #[cfg(any(test, feature = "sqlite"))]
744            DatabaseType::SQLite(sl) => sl.list_groups(),
745        }
746    }
747
748    /// Grant a user membership to a group, both by id.
749    ///
750    /// # Errors
751    ///
752    /// * If there's a connectivity issue to Postgres, an error will result
753    /// * If the user or group doesn't exist, an error will result
754    #[cfg(any(test, feature = "admin"))]
755    pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
756        match self {
757            DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
758            #[cfg(any(test, feature = "sqlite"))]
759            DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
760        }
761    }
762
763    /// Grand a group access to a source, both by id.
764    ///
765    /// # Errors
766    ///
767    /// * If there's a connectivity issue to Postgres, an error will result
768    /// * If the group or source doesn't exist, an error will result
769    #[cfg(any(test, feature = "admin"))]
770    pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
771        match self {
772            DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
773            #[cfg(any(test, feature = "sqlite"))]
774            DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
775        }
776    }
777
778    /// Create a new group, returning the group ID
779    ///
780    /// # Errors
781    ///
782    /// * If there's a connectivity issue to Postgres, an error will result
783    /// * If the group name is already taken, an error will result
784    #[cfg(any(test, feature = "admin"))]
785    pub async fn create_group(
786        &self,
787        name: &str,
788        description: &str,
789        parent: Option<u32>,
790    ) -> Result<u32> {
791        match self {
792            DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
793            #[cfg(any(test, feature = "sqlite"))]
794            DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
795        }
796    }
797
798    /// Get the complete list of sources
799    ///
800    /// # Errors
801    ///
802    /// If there's a connectivity issue to Postgres, an error will result
803    #[cfg(any(test, feature = "admin"))]
804    pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
805        match self {
806            DatabaseType::Postgres(pg) => pg.list_sources().await,
807            #[cfg(any(test, feature = "sqlite"))]
808            DatabaseType::SQLite(sl) => sl.list_sources(),
809        }
810    }
811
812    /// Create a source, returning the source ID
813    ///
814    /// # Errors
815    ///
816    /// * If there's a connectivity issue to Postgres, an error will result
817    /// * If the source already exists, an error will result
818    #[cfg(any(test, feature = "admin"))]
819    pub async fn create_source(
820        &self,
821        name: &str,
822        description: Option<&str>,
823        url: Option<&str>,
824        date: chrono::DateTime<Local>,
825        releasable: bool,
826        malicious: Option<bool>,
827    ) -> Result<u32> {
828        match self {
829            DatabaseType::Postgres(pg) => {
830                pg.create_source(name, description, url, date, releasable, malicious)
831                    .await
832            }
833            #[cfg(any(test, feature = "sqlite"))]
834            DatabaseType::SQLite(sl) => {
835                sl.create_source(name, description, url, date, releasable, malicious)
836            }
837        }
838    }
839
840    /// Edit a user account setting the specified field values, primarily used by the Admin gui
841    ///
842    /// # Errors
843    ///
844    /// * If there's a connectivity issue to Postgres, an error will result
845    /// * If the user isn't valid, an error will result
846    #[cfg(any(test, feature = "admin"))]
847    pub async fn edit_user(
848        &self,
849        uid: u32,
850        uname: &str,
851        fname: &str,
852        lname: &str,
853        email: &str,
854        readonly: bool,
855    ) -> Result<()> {
856        match self {
857            DatabaseType::Postgres(pg) => {
858                pg.edit_user(uid, uname, fname, lname, email, readonly)
859                    .await
860            }
861            #[cfg(any(test, feature = "sqlite"))]
862            DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
863        }
864    }
865
866    /// Deactivate, but don't delete, a user's account
867    ///
868    /// # Errors
869    ///
870    /// * If there's a connectivity issue to Postgres, an error will result
871    /// * If the user ID isn't valid, an error will result
872    #[cfg(any(test, feature = "admin"))]
873    pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
874        match self {
875            DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
876            #[cfg(any(test, feature = "sqlite"))]
877            DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
878        }
879    }
880
881    /// File types and number of files per type
882    ///
883    /// # Errors
884    ///
885    /// * If there's a connectivity issue to Postgres, an error will result
886    #[cfg(any(test, feature = "admin"))]
887    pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
888        match self {
889            DatabaseType::Postgres(pg) => pg.file_types_counts().await,
890            #[cfg(any(test, feature = "sqlite"))]
891            DatabaseType::SQLite(sl) => sl.file_types_counts(),
892        }
893    }
894
895    /// Create a new label, returning the label ID
896    ///
897    /// # Errors
898    ///
899    /// * If there's a connectivity issue to Postgres, an error will result
900    /// * If the parent label doesn't exist, an error will result
901    /// * If the label name already exists, an error will result
902    #[cfg(any(test, feature = "admin"))]
903    pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
904        match self {
905            DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
906            #[cfg(any(test, feature = "sqlite"))]
907            DatabaseType::SQLite(sl) => sl.create_label(name, parent),
908        }
909    }
910
911    /// Edit a label name or parent
912    ///
913    /// # Errors
914    ///
915    /// * If there's a connectivity issue to Postgres, an error will result
916    /// * If the parent label doesn't exist, an error will result
917    /// * If the label name already exists, an error will result
918    #[cfg(any(test, feature = "admin"))]
919    pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
920        match self {
921            DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
922            #[cfg(any(test, feature = "sqlite"))]
923            DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
924        }
925    }
926
927    /// Return the ID for a given label name
928    ///
929    /// # Errors
930    ///
931    /// * If there's a connectivity issue to Postgres, an error will result
932    /// * If the label ID is invalid, an error will result
933    #[cfg(any(test, feature = "admin"))]
934    pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
935        match self {
936            DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
937            #[cfg(any(test, feature = "sqlite"))]
938            DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
939        }
940    }
941
942    /// Associate an existing label by its IDs with a file.
943    ///
944    /// # Errors
945    ///
946    /// * Incorrect IDs will result in an error
947    /// * If the file already has the given label associated
948    /// * If there is a network or connection issue with Postgres
949    #[cfg(any(test, feature = "admin"))]
950    pub async fn label_file(&self, file_id: u64, label_id: u64) -> Result<()> {
951        match self {
952            DatabaseType::Postgres(pg) => pg.label_file(file_id, label_id).await,
953            #[cfg(any(test, feature = "sqlite"))]
954            DatabaseType::SQLite(sl) => sl.label_file(file_id, label_id),
955        }
956    }
957}
958
959/// Hash a password string with [Argon2]
960///
961/// # Errors
962///
963/// An error may result if Argon has an error
964pub fn hash_password(password: &str) -> Result<String> {
965    let salt = SaltString::generate(&mut OsRng);
966    let argon2 = Argon2::default();
967    Ok(argon2
968        .hash_password(password.as_bytes(), &salt)?
969        .to_string())
970}
971
972/// Generate a new, random API key
973#[must_use]
974pub fn random_bytes_api_key() -> String {
975    let key1 = uuid::Uuid::new_v4();
976    let key2 = uuid::Uuid::new_v4();
977    let key1 = key1.to_string().replace('-', "");
978    let key2 = key2.to_string().replace('-', "");
979    format!("{key1}{key2}")
980}
981
982#[cfg(test)]
983mod tests {
984    use super::*;
985    #[cfg(feature = "vt")]
986    use crate::vt::VtUpdater;
987
988    use std::fs;
989    #[cfg(feature = "vt")]
990    use std::sync::Arc;
991    #[cfg(feature = "vt")]
992    use std::time::SystemTime;
993
994    use anyhow::Context;
995    use fuzzyhash::FuzzyHash;
996    use malwaredb_api::{PartialHashSearchType, SearchRequestParameters, SearchType};
997    use malwaredb_lzjd::{LZDict, Murmur3HashState};
998    use tlsh_fixed::TlshBuilder;
999    use uuid::Uuid;
1000
1001    const MALWARE_LABEL: &str = "malware";
1002    const RANSOMWARE_LABEL: &str = "ransomware";
1003
1004    fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
1005        let mut hashes = vec![];
1006
1007        hashes.push((
1008            malwaredb_api::SimilarityHashType::SSDeep,
1009            FuzzyHash::new(data).to_string(),
1010        ));
1011
1012        let mut builder = TlshBuilder::new(
1013            tlsh_fixed::BucketKind::Bucket256,
1014            tlsh_fixed::ChecksumKind::ThreeByte,
1015            tlsh_fixed::Version::Version4,
1016        );
1017
1018        builder.update(data);
1019
1020        if let Ok(hasher) = builder.build() {
1021            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
1022        }
1023
1024        let build_hasher = Murmur3HashState::default();
1025        let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
1026        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
1027
1028        malwaredb_api::SimilarSamplesRequest { hashes }
1029    }
1030
1031    async fn pg_config() -> Postgres {
1032        // create user malwaredbtesting with password 'malwaredbtesting';
1033        // create database malwaredbtesting owner malwaredbtesting;
1034        const CONNECTION_STRING: &str =
1035            "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
1036
1037        if let Ok(pg_port) = std::env::var("PG_PORT") {
1038            // Get the port number to run in Github CI
1039            let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
1040            Postgres::new(&conn_string, None)
1041                .await
1042                .context(format!(
1043                    "failed to connect to postgres with specified port {pg_port}"
1044                ))
1045                .unwrap()
1046        } else {
1047            Postgres::new(CONNECTION_STRING, None).await.unwrap()
1048        }
1049    }
1050
1051    #[tokio::test]
1052    #[ignore = "don't run this in CI"]
1053    async fn pg() {
1054        let psql = pg_config().await;
1055        psql.delete_init().await.unwrap();
1056
1057        let db = DatabaseType::Postgres(psql);
1058        let key = FileEncryption::from(EncryptionOption::Xor);
1059        db.add_file_encryption_key(&key).await.unwrap();
1060        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
1061        everything(&db).await.unwrap();
1062
1063        #[cfg(feature = "vt")]
1064        {
1065            let db_config = db.get_config().await.unwrap();
1066            let state = crate::State {
1067                port: 8080,
1068                directory: None,
1069                max_upload: 10 * 1024 * 1024,
1070                ip: "127.0.0.1".parse().unwrap(),
1071                db_type: Arc::new(db),
1072                db_config,
1073                keys: HashMap::new(),
1074                started: SystemTime::now(),
1075                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1076                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
1077                }),
1078                tls_config: None,
1079                mdns: None,
1080            };
1081
1082            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1083
1084            vt.updater().await.unwrap();
1085            println!("PG: Did VT ops!");
1086
1087            let psql = pg_config().await;
1088
1089            let vt_stats = psql
1090                .get_vt_stats()
1091                .await
1092                .context("failed to get Postgres VT Stats")
1093                .unwrap();
1094            println!("{vt_stats:?}");
1095            assert!(
1096                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1097            );
1098        }
1099
1100        // Re-create the Postgres object so we can do some clean-up
1101        let psql = pg_config().await;
1102        psql.delete_init().await.unwrap();
1103    }
1104
1105    #[tokio::test]
1106    async fn sqlite() {
1107        const DB_FILE: &str = "testing_sqlite.db";
1108        if std::path::Path::new(DB_FILE).exists() {
1109            fs::remove_file(DB_FILE)
1110                .context(format!("failed to delete old SQLite file {DB_FILE}"))
1111                .unwrap();
1112        }
1113
1114        let sqlite = Sqlite::new(DB_FILE)
1115            .context(format!("failed to create SQLite instance for {DB_FILE}"))
1116            .unwrap();
1117
1118        let db = DatabaseType::SQLite(sqlite);
1119        let key = FileEncryption::from(EncryptionOption::Xor);
1120        db.add_file_encryption_key(&key).await.unwrap();
1121        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
1122        everything(&db).await.unwrap();
1123
1124        #[cfg(feature = "vt")]
1125        {
1126            let db_config = db.get_config().await.unwrap();
1127            let state = crate::State {
1128                port: 8080,
1129                directory: None,
1130                max_upload: 10 * 1024 * 1024,
1131                ip: "127.0.0.1".parse().unwrap(),
1132                db_type: Arc::new(db),
1133                db_config,
1134                keys: HashMap::new(),
1135                started: SystemTime::now(),
1136                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1137                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
1138                }),
1139                tls_config: None,
1140                mdns: None,
1141            };
1142
1143            let sqlite_second = Sqlite::new(DB_FILE)
1144                .context(format!("failed to create SQLite instance for {DB_FILE}"))
1145                .unwrap();
1146
1147            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1148
1149            vt.updater().await.unwrap();
1150            println!("Sqlite: Did VT ops!");
1151            let vt_stats = sqlite_second
1152                .get_vt_stats()
1153                .context("failed to get Sqlite VT Stats")
1154                .unwrap();
1155            println!("{vt_stats:?}");
1156            assert!(
1157                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1158            );
1159        }
1160
1161        fs::remove_file(DB_FILE)
1162            .context(format!("failed to delete SQLite file {DB_FILE}"))
1163            .unwrap();
1164    }
1165
1166    #[allow(clippy::too_many_lines)]
1167    async fn everything(db: &DatabaseType) -> Result<()> {
1168        const ADMIN_UNAME: &str = "admin";
1169        const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
1170
1171        db.set_name("Testing Database")
1172            .await
1173            .context("setting instance name failed")?;
1174
1175        assert!(
1176            db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
1177            "Authentication without password should have failed."
1178        );
1179
1180        db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1181            .await
1182            .context("failed to set admin password")?;
1183
1184        let admin_api_key = db
1185            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1186            .await
1187            .context("unable to get api key for admin")?;
1188        println!("API key: {admin_api_key}");
1189        assert_eq!(admin_api_key.len(), 64);
1190
1191        assert_eq!(
1192            db.get_uid(&admin_api_key).await?,
1193            0,
1194            "Unable to get UID given the API key"
1195        );
1196
1197        let admin_api_key_again = db
1198            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1199            .await
1200            .context("unable to get api key a second time for admin")?;
1201
1202        assert_eq!(
1203            admin_api_key, admin_api_key_again,
1204            "API keys didn't match the second time."
1205        );
1206
1207        let bad_password = "this_is_totally_not_my_password!!";
1208        eprintln!("Testing API login with incorrect password.");
1209        assert!(
1210            db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
1211            "Authenticating as admin with a bad password should have failed."
1212        );
1213
1214        let admin_is_admin = db
1215            .user_is_admin(0)
1216            .await
1217            .context("unable to see if admin (uid 0) is an admin")?;
1218        assert!(admin_is_admin);
1219
1220        let new_user_uname = "testuser";
1221        let new_user_email = "test@example.com";
1222        let new_user_password = "some_awesome_password_++";
1223        let new_id = db
1224            .create_user(
1225                new_user_uname,
1226                new_user_uname,
1227                new_user_uname,
1228                new_user_email,
1229                Some(new_user_password.into()),
1230                None,
1231                false,
1232            )
1233            .await
1234            .context(format!("failed to create user {new_user_uname}"))?;
1235
1236        let passwordless_user_id = db
1237            .create_user(
1238                "passwordless_user",
1239                "passwordless_user",
1240                "passwordless_user",
1241                "passwordless_user@example.com",
1242                None,
1243                None,
1244                false,
1245            )
1246            .await
1247            .context("failed to create passwordless_user")?;
1248
1249        for user in &db.list_users().await.context("failed to list users")? {
1250            if user.id == passwordless_user_id {
1251                assert_eq!(user.uname, "passwordless_user");
1252            }
1253        }
1254
1255        db.edit_user(
1256            passwordless_user_id,
1257            "passwordless_user_2",
1258            "passwordless_user_2",
1259            "passwordless_user_2",
1260            "passwordless_user_2@something.com",
1261            false,
1262        )
1263        .await
1264        .context(format!(
1265            "failed to alter 'passwordless' user, id {passwordless_user_id}"
1266        ))?;
1267
1268        for user in &db.list_users().await.context("failed to list users")? {
1269            if user.id == passwordless_user_id {
1270                assert_eq!(user.uname, "passwordless_user_2");
1271            }
1272        }
1273
1274        assert!(
1275            new_id > 0,
1276            "Weird UID created for user {new_user_uname}: {new_id}"
1277        );
1278
1279        assert!(
1280            db.create_user(
1281                new_user_uname,
1282                new_user_uname,
1283                new_user_uname,
1284                new_user_email,
1285                Some(new_user_password.into()),
1286                None,
1287                false
1288            )
1289            .await
1290            .is_err(),
1291            "Creating a new user with the same user name should fail"
1292        );
1293
1294        let ro_user_name = "ro_user";
1295        let ro_user_password = "ro_user_password";
1296        db.create_user(
1297            ro_user_name,
1298            "ro_user",
1299            "ro_user",
1300            "ro@example.com",
1301            Some(ro_user_password.into()),
1302            None,
1303            true,
1304        )
1305        .await
1306        .context("failed to create read-only user")?;
1307
1308        let ro_user_api_key = db
1309            .authenticate(ro_user_name, ro_user_password)
1310            .await
1311            .context("unable to get api key for read-only user")?;
1312
1313        let new_user_password_change = "some_new_awesomer_password!_++";
1314        db.set_password(new_user_uname, new_user_password_change)
1315            .await
1316            .context("failed to change the password for testuser")?;
1317
1318        let new_user_api_key = db
1319            .authenticate(new_user_uname, new_user_password_change)
1320            .await
1321            .context("unable to get api key for testuser")?;
1322        eprintln!("{new_user_uname} got API key {new_user_api_key}");
1323
1324        assert_eq!(admin_api_key.len(), new_user_api_key.len());
1325
1326        let users = db.list_users().await.context("failed to list users")?;
1327        assert_eq!(
1328            users.len(),
1329            4,
1330            "Four users were created, yet there are {} users",
1331            users.len()
1332        );
1333        eprintln!("DB has {} users:", users.len());
1334        let mut passwordless_user_found = false;
1335        for user in users {
1336            println!("{user}");
1337            if user.uname == "passwordless_user_2" {
1338                assert!(!user.has_api_key);
1339                assert!(!user.has_password);
1340                passwordless_user_found = true;
1341            } else {
1342                assert!(user.has_api_key);
1343                assert!(user.has_password);
1344            }
1345        }
1346        assert!(passwordless_user_found);
1347
1348        let new_group_name = "some_new_group";
1349        let new_group_desc = "some_new_group_description";
1350        let new_group_id = 1;
1351        assert_eq!(
1352            db.create_group(new_group_name, new_group_desc, None)
1353                .await
1354                .context("failed to create group")?,
1355            new_group_id,
1356            "New group didn't have the expected ID, expected {new_group_id}"
1357        );
1358
1359        assert!(
1360            db.create_group(new_group_name, new_group_desc, None)
1361                .await
1362                .is_err(),
1363            "Duplicate group name should have failed"
1364        );
1365
1366        db.add_user_to_group(1, 1)
1367            .await
1368            .context("Unable to add uid 1 to gid 1")?;
1369
1370        let ro_user_uid = db
1371            .get_uid(&ro_user_api_key)
1372            .await
1373            .context("Unable to get UID for read-only user")?;
1374        db.add_user_to_group(ro_user_uid, 1)
1375            .await
1376            .context("Unable to add uid 2 to gid 1")?;
1377
1378        let new_admin_group_name = "admin_subgroup";
1379        let new_admin_group_desc = "admin_subgroup_description";
1380        let new_admin_group_id = 2;
1381        // TODO: Figure out why SQLite makes the group_id = 2, but with Postgres it's 3.
1382        assert!(
1383            db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
1384                .await
1385                .context("failed to create admin sub-group")?
1386                >= new_admin_group_id,
1387            "New group didn't have the expected ID, expected >= {new_admin_group_id}"
1388        );
1389
1390        let groups = db.list_groups().await.context("failed to list groups")?;
1391        assert_eq!(
1392            groups.len(),
1393            3,
1394            "Three groups were created, yet there are {} groups",
1395            groups.len()
1396        );
1397        eprintln!("DB has {} groups:", groups.len());
1398        for group in groups {
1399            println!("{group}");
1400            if group.id == new_admin_group_id {
1401                assert_eq!(group.parent, Some("admin".to_string()));
1402            }
1403            if group.id == 1 {
1404                let test_user_str = String::from(new_user_uname);
1405                let mut found = false;
1406                for member in group.members {
1407                    if member.uname == test_user_str {
1408                        found = true;
1409                        break;
1410                    }
1411                }
1412                assert!(found, "new user {test_user_str} wasn't in the group");
1413            }
1414        }
1415
1416        let default_source_name = "default_source".to_string();
1417        let default_source_id = db
1418            .create_source(
1419                &default_source_name,
1420                Some("desc_default_source"),
1421                None,
1422                Local::now(),
1423                true,
1424                Some(false),
1425            )
1426            .await
1427            .context("failed to create source `default_source`")?;
1428
1429        db.add_group_to_source(1, default_source_id)
1430            .await
1431            .context("failed to add group 1 to source 1")?;
1432
1433        let another_source_name = "another_source".to_string();
1434        let another_source_id = db
1435            .create_source(
1436                &another_source_name,
1437                Some("yet another file source"),
1438                None,
1439                Local::now(),
1440                true,
1441                Some(false),
1442            )
1443            .await
1444            .context("failed to create source `another_source`")?;
1445
1446        let empty_source_name = "empty_source".to_string();
1447        db.create_source(
1448            &empty_source_name,
1449            Some("empty and unused file source"),
1450            None,
1451            Local::now(),
1452            true,
1453            Some(false),
1454        )
1455        .await
1456        .context("failed to create source `another_source`")?;
1457
1458        db.add_group_to_source(1, another_source_id)
1459            .await
1460            .context("failed to add group 1 to source 1")?;
1461
1462        let sources = db.list_sources().await.context("failed to list sources")?;
1463        eprintln!("DB has {} sources:", sources.len());
1464        for source in sources {
1465            println!("{source}");
1466            assert_eq!(source.files, 0);
1467            if source.id == default_source_id || source.id == another_source_id {
1468                assert_eq!(
1469                    source.groups, 1,
1470                    "default source {default_source_name} should have 1 group"
1471                );
1472            } else {
1473                assert_eq!(source.groups, 0, "groups should zero (empty)");
1474            }
1475        }
1476
1477        let uid = db
1478            .get_uid(&new_user_api_key)
1479            .await
1480            .context("failed to user uid from apikey")?;
1481        let user_info = db
1482            .get_user_info(uid)
1483            .await
1484            .context("failed to get user's available groups and sources")?;
1485        assert!(user_info.sources.contains(&default_source_name));
1486        assert!(!user_info.is_admin);
1487        println!("UserInfoResponse: {user_info:?}");
1488
1489        assert!(
1490            db.allowed_user_source(1, default_source_id)
1491                .await
1492                .context(format!(
1493                    "failed to check that user 1 has access to source {default_source_id}"
1494                ))?,
1495            "User 1 should should have had access to source {default_source_id}"
1496        );
1497
1498        assert!(
1499            !db.allowed_user_source(1, 5)
1500                .await
1501                .context("failed to check that user 1 has access to source 5")?,
1502            "User 1 should should not have had access to source 5"
1503        );
1504
1505        let test_label_id = db
1506            .create_label("TestLabel", None)
1507            .await
1508            .context("failed to create test label")?;
1509        let test_elf_label_id = db
1510            .create_label("TestELF", Some(test_label_id))
1511            .await
1512            .context("failed to create test label")?;
1513
1514        let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
1515        let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
1516        let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
1517
1518        let known_type =
1519            KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
1520        assert!(known_type.is_exec(), "ELF should be executable");
1521        eprintln!("ELF type ID: {elf_type}");
1522
1523        let file_addition = db
1524            .add_file(
1525                &test_elf_meta,
1526                known_type.clone(),
1527                1,
1528                default_source_id,
1529                elf_type,
1530                None,
1531            )
1532            .await
1533            .context("failed to insert a test elf")?;
1534        assert!(file_addition.is_new, "File should have been added");
1535        eprintln!("Added ELF to the DB");
1536        db.label_file(file_addition.file_id, test_elf_label_id)
1537            .await
1538            .context("failed to label file")?;
1539
1540        // Search with several fields
1541        let partial_search = SearchRequest {
1542            search: SearchType::Search(SearchRequestParameters {
1543                partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1544                labels: Some(vec![String::from("TestELF")]),
1545                file_type: Some(String::from("ELF")),
1546                magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1547                ..Default::default()
1548            }),
1549        };
1550        assert!(partial_search.is_valid());
1551        let partial_search_response = db.partial_search(1, partial_search).await?;
1552        assert_eq!(partial_search_response.hashes.len(), 1);
1553        assert_eq!(
1554            partial_search_response.hashes[0],
1555            "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1556        );
1557
1558        // Ensure we get a partial result just with the magic
1559        let partial_search = SearchRequest {
1560            search: SearchType::Search(SearchRequestParameters {
1561                partial_hash: None,
1562                labels: None,
1563                file_type: None,
1564                magic: Some(String::from("OpenPOWER ELF V2 ABI")),
1565                ..Default::default()
1566            }),
1567        };
1568        assert!(partial_search.is_valid());
1569        let partial_search_response = db.partial_search(1, partial_search).await?;
1570        assert_eq!(partial_search_response.hashes.len(), 1);
1571        assert_eq!(
1572            partial_search_response.hashes[0],
1573            "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1574        );
1575
1576        // Should be valid yet return nothing since it's the wrong file type
1577        let partial_search = SearchRequest {
1578            search: SearchType::Search(SearchRequestParameters {
1579                partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1580                file_type: Some(String::from("PE32")),
1581                ..Default::default()
1582            }),
1583        };
1584        assert!(partial_search.is_valid());
1585        let partial_search_response = db.partial_search(1, partial_search).await?;
1586        assert_eq!(partial_search_response.hashes.len(), 0);
1587
1588        let partial_search = SearchRequest {
1589            search: SearchType::Search(SearchRequestParameters {
1590                file_name: Some("ppc64".into()),
1591                ..Default::default()
1592            }),
1593        };
1594        assert!(partial_search.is_valid());
1595        let partial_search_response = db.partial_search(1, partial_search).await?;
1596        assert_eq!(partial_search_response.hashes.len(), 1);
1597
1598        // Invalid search request should return empty results
1599        let partial_search = SearchRequest {
1600            search: SearchType::Search(SearchRequestParameters::default()),
1601        };
1602        assert!(!partial_search.is_valid());
1603        let partial_search_response = db.partial_search(1, partial_search).await?;
1604        assert!(partial_search_response.hashes.is_empty());
1605
1606        // Invalid search pagination should return empty results
1607        let partial_search = SearchRequest {
1608            search: SearchType::Continuation(Uuid::default()),
1609        };
1610        assert!(partial_search.is_valid());
1611        let partial_search_response = db.partial_search(1, partial_search).await?;
1612        assert!(partial_search_response.hashes.is_empty());
1613
1614        // Ensure a type representing an unknown type isn't found
1615        assert!(db
1616            .get_type_id_for_bytes(include_bytes!("../../../../MDB_Logo.ico"))
1617            .await
1618            .is_err());
1619
1620        assert!(
1621            db.add_file(
1622                &test_elf_meta,
1623                known_type.clone(),
1624                ro_user_uid,
1625                default_source_id,
1626                elf_type,
1627                None
1628            )
1629            .await
1630            .is_err(),
1631            "Read-only user should not be able to add a file"
1632        );
1633
1634        let mut test_elf_meta_different_name = test_elf_meta.clone();
1635        test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
1636
1637        assert!(
1638            !db.add_file(
1639                &test_elf_meta_different_name,
1640                known_type,
1641                1,
1642                another_source_id,
1643                elf_type,
1644                None
1645            )
1646            .await
1647            .context("failed to insert a test elf again for a different source")?
1648            .is_new
1649        );
1650
1651        let sources = db
1652            .list_sources()
1653            .await
1654            .context("failed to re-list sources")?;
1655        eprintln!(
1656            "DB has {} sources, and a file was added twice:",
1657            sources.len()
1658        );
1659        println!("We should have two sources with one file each, yet only one ELF.");
1660        for source in sources {
1661            println!("{source}");
1662            if source.id == default_source_id || source.id == another_source_id {
1663                assert_eq!(source.files, 1);
1664            } else {
1665                assert_eq!(source.files, 0, "groups should zero (empty)");
1666            }
1667        }
1668
1669        assert!(!db
1670            .get_user_sources(1)
1671            .await
1672            .expect("failed to get user 1's sources")
1673            .sources
1674            .is_empty());
1675
1676        let file_types_counts = db
1677            .file_types_counts()
1678            .await
1679            .context("failed to get file types and counts")?;
1680        for (name, count) in file_types_counts {
1681            println!("{name}: {count}");
1682            assert_eq!(name, "ELF");
1683            assert_eq!(count, 1);
1684        }
1685
1686        let mut test_elf_modified = test_elf.clone();
1687        let random_bytes = Uuid::new_v4();
1688        let mut random_bytes = random_bytes.into_bytes().to_vec();
1689        test_elf_modified.append(&mut random_bytes);
1690        let similarity_request = generate_similarity_request(&test_elf_modified);
1691        let similarity_response = db
1692            .find_similar_samples(1, &similarity_request.hashes)
1693            .await
1694            .context("failed to get similarity response")?;
1695        eprintln!("Similarity response: {similarity_response:?}");
1696        let similarity_response = similarity_response.first().unwrap();
1697        assert_eq!(
1698            similarity_response.sha256, test_elf_meta.sha256,
1699            "Similarity response should have had the hash of the original ELF"
1700        );
1701        for (algo, sim) in &similarity_response.algorithms {
1702            match algo {
1703                malwaredb_api::SimilarityHashType::LZJD => {
1704                    assert!(*sim > 0.0f32);
1705                }
1706                malwaredb_api::SimilarityHashType::SSDeep => {
1707                    assert!(*sim > 80.0f32);
1708                }
1709                malwaredb_api::SimilarityHashType::TLSH => {
1710                    assert!(*sim <= 20f32);
1711                }
1712                _ => {}
1713            }
1714        }
1715
1716        let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1.as_str())
1717            .context("failed to get `HashType::SHA1` from string")?;
1718        let response_sha256 = db
1719            .retrieve_sample(1, &test_elf_hashtype)
1720            .await
1721            .context("could not get SHA-256 hash from test sample")
1722            .unwrap();
1723        assert_eq!(response_sha256, test_elf_meta.sha256);
1724
1725        let test_bogus_hash =
1726            HashType::try_from("d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0")
1727                .context("failed to get `HashType` from static string")?;
1728        assert!(
1729            db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
1730            "Getting a file with a bogus hash should have failed."
1731        );
1732
1733        let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
1734        let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
1735        let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
1736
1737        let known_type =
1738            KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
1739
1740        assert!(
1741            db.add_file(
1742                &test_pdf_meta,
1743                known_type,
1744                1,
1745                default_source_id,
1746                pdf_type,
1747                None
1748            )
1749            .await
1750            .context("failed to insert a test pdf")?
1751            .is_new
1752        );
1753        eprintln!("Added PDF to the DB");
1754
1755        let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
1756        let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
1757        let rtf_type = db
1758            .get_type_id_for_bytes(&test_rtf)
1759            .await
1760            .context("failed to get file type id for rtf")?;
1761
1762        let known_type =
1763            KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
1764
1765        assert!(
1766            db.add_file(
1767                &test_rtf_meta,
1768                known_type,
1769                1,
1770                default_source_id,
1771                rtf_type,
1772                None
1773            )
1774            .await
1775            .context("failed to insert a test rtf")?
1776            .is_new
1777        );
1778        eprintln!("Added RTF to the DB");
1779
1780        let report = db
1781            .get_sample_report(1, &HashType::try_from(test_rtf_meta.sha256.as_str())?)
1782            .await
1783            .context("failed to get report for test rtf")?;
1784        assert!(report
1785            .clone()
1786            .filecommand
1787            .unwrap()
1788            .contains("Rich Text Format"));
1789        println!("Report: {report}");
1790
1791        assert!(db
1792            .get_sample_report(999, &HashType::try_from(test_rtf_meta.sha256.as_str())?)
1793            .await
1794            .is_err());
1795
1796        #[cfg(feature = "vt")]
1797        {
1798            assert!(report.vt.is_some());
1799            let files_needing_vt = db
1800                .files_without_vt_records(10)
1801                .await
1802                .context("failed to get files without VT records")?;
1803            assert!(files_needing_vt.len() > 2);
1804            println!(
1805                "{} files needing VT data: {files_needing_vt:?}",
1806                files_needing_vt.len()
1807            );
1808        }
1809
1810        #[cfg(not(feature = "vt"))]
1811        {
1812            assert!(report.vt.is_none());
1813        }
1814
1815        let reset = db
1816            .reset_api_keys()
1817            .await
1818            .context("failed to reset all API keys")?;
1819        eprintln!("Cleared {reset} api keys.");
1820
1821        let db_info = db.db_info().await.context("failed to get database info")?;
1822        eprintln!("DB Info: {db_info:?}");
1823
1824        let data_types = db
1825            .get_known_data_types()
1826            .await
1827            .context("failed to get data types")?;
1828        for data_type in data_types {
1829            println!("{data_type:?}");
1830        }
1831
1832        let sources = db
1833            .list_sources()
1834            .await
1835            .context("failed to list sources second time")?;
1836        eprintln!("DB has {} sources:", sources.len());
1837        for source in sources {
1838            println!("{source}");
1839        }
1840
1841        let file_types_counts = db
1842            .file_types_counts()
1843            .await
1844            .context("failed to get file types and counts")?;
1845        for (name, count) in file_types_counts {
1846            println!("{name}: {count}");
1847            assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
1848        }
1849
1850        let fatmacho =
1851            include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
1852                .to_vec();
1853        let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
1854        let fatmacho_type = db
1855            .get_type_id_for_bytes(&fatmacho)
1856            .await
1857            .context("failed to get file type for Fat Mach-O")?;
1858        let known_type = KnownType::new(&fatmacho)
1859            .context("failed to parse Fat Mach-O from type crate's test data")?;
1860
1861        assert!(
1862            db.add_file(
1863                &fatmacho_meta,
1864                known_type,
1865                1,
1866                default_source_id,
1867                fatmacho_type,
1868                None
1869            )
1870            .await
1871            .context("failed to insert a test Fat Mach-O")?
1872            .is_new
1873        );
1874        eprintln!("Added Fat Mach-O to the DB");
1875
1876        let file_types_counts = db
1877            .file_types_counts()
1878            .await
1879            .context("failed to get file types and counts")?;
1880        for (name, count) in &file_types_counts {
1881            println!("{name}: {count}");
1882        }
1883
1884        assert_eq!(
1885            *file_types_counts.get("Mach-O").unwrap(),
1886            4,
1887            "Expected 4 Mach-O files, got {:?}",
1888            file_types_counts.get("Mach-O")
1889        );
1890
1891        let malware_label_id = db
1892            .create_label(MALWARE_LABEL, None)
1893            .await
1894            .context("failed to create first label")?;
1895        let ransomware_label_id = db
1896            .create_label(RANSOMWARE_LABEL, Some(malware_label_id))
1897            .await
1898            .context("failed to create malware sub-label")?;
1899        let labels = db.get_labels().await.context("failed to get labels")?;
1900
1901        assert_eq!(labels.len(), 4, "Expected 4 labels, got {labels}");
1902        for label in labels.0 {
1903            if label.name == RANSOMWARE_LABEL {
1904                assert_eq!(label.id, ransomware_label_id);
1905                assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
1906            }
1907        }
1908
1909        // Use this file as a stand-in for an unknown file type
1910        let source_code = include_bytes!("mod.rs");
1911        let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
1912        let known_type =
1913            KnownType::new(source_code).context("failed to source code to get `Unknown` type")?;
1914
1915        assert!(matches!(known_type, KnownType::Unknown(_)));
1916
1917        let unknown_type: Vec<FileType> = db
1918            .get_known_data_types()
1919            .await?
1920            .into_iter()
1921            .filter(|t| t.name.eq_ignore_ascii_case("unknown"))
1922            .collect();
1923        let unknown_type_id = unknown_type.first().unwrap().id;
1924        assert!(db.get_type_id_for_bytes(source_code).await.is_err());
1925        db.enable_keep_unknown_files()
1926            .await
1927            .context("failed to enable keeping of unknown files")?;
1928        let source_type = db
1929            .get_type_id_for_bytes(source_code)
1930            .await
1931            .context("failed to type id for source code unknown type example")?;
1932        assert_eq!(source_type, unknown_type_id);
1933        eprintln!("Unknown file type ID: {source_type}");
1934        assert!(
1935            db.add_file(
1936                &source_meta,
1937                known_type,
1938                1,
1939                default_source_id,
1940                unknown_type_id,
1941                None
1942            )
1943            .await
1944            .context("failed to add Rust source code file")?
1945            .is_new
1946        );
1947        eprintln!("Added Rust source code to the DB");
1948
1949        db.reset_own_api_key(0)
1950            .await
1951            .context("failed to clear own API key uid 0")?;
1952
1953        db.deactivate_user(0)
1954            .await
1955            .context("failed to clear password and API key for uid 0")?;
1956
1957        Ok(())
1958    }
1959}