malwaredb_server/db/
mod.rs

// SPDX-License-Identifier: Apache-2.0
2
//! Postgres is the database used by Malware DB.
//! `SQLite` is also supported, primarily for unit testing, and can be enabled with the `sqlite`
//! feature flag. When using `SQLite`, any similarity functions must be calculated by Malware DB.
6
7/// Malware DB Administrative functions
8#[cfg(any(test, feature = "admin"))]
9mod admin;
10/// Postgres functions
11mod pg;
12
13/// `SQLite` functions
14#[cfg(any(test, feature = "sqlite"))]
15mod sqlite;
16
17/// File Metadata convenience data structure
18pub mod types;
19
20#[cfg(any(test, feature = "admin"))]
21use crate::crypto::EncryptionOption;
22use crate::crypto::FileEncryption;
23use crate::db::pg::Postgres;
24#[cfg(any(test, feature = "sqlite"))]
25use crate::db::sqlite::Sqlite;
26use crate::db::types::{FileMetadata, FileType};
27use malwaredb_api::{digest::HashType, GetUserInfoResponse, Labels, SearchRequest, Sources};
28use malwaredb_types::KnownType;
29
30use std::collections::HashMap;
31use std::path::PathBuf;
32
33use anyhow::{bail, ensure, Result};
34use argon2::password_hash::{rand_core::OsRng, SaltString};
35use argon2::{Argon2, PasswordHasher};
36#[cfg(any(test, feature = "admin"))]
37use chrono::Local;
38#[cfg(feature = "vt")]
39use malwaredb_virustotal::filereport::ScanResultAttributes;
40
/// Upper bound on how many results a partial-hash and/or partial-file-name search may return,
/// so that such open-ended queries cannot degrade performance.
pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
43
44/// Database handle for `SQLite` or Postgres
45#[derive(Debug)]
46pub enum DatabaseType {
47    /// Postgres database
48    Postgres(Postgres),
49
50    /// `SQLite` database
51    #[cfg(any(test, feature = "sqlite"))]
52    SQLite(Sqlite),
53}
54
/// Database version details together with a few headline statistics.
#[derive(Debug)]
pub struct DatabaseInformation {
    /// Version string reported by the database
    pub version: String,

    /// Database size, formatted for humans
    pub size: String,

    /// How many file samples Malware DB holds
    pub num_files: u64,

    /// How many user accounts exist
    pub num_users: u32,

    /// How many user groups exist
    pub num_groups: u32,

    /// How many sample sources exist
    pub num_sources: u32,
}
76
/// Instance-wide Malware DB settings, as persisted in the database.
#[derive(Debug)]
pub struct MDBConfig {
    /// Display name of this Malware DB instance
    pub name: String,

    /// True when samples are stored compressed
    pub compression: bool,

    /// True when Malware DB is permitted to send samples to Virus Total
    pub send_samples_to_vt: bool,

    /// True when files of an unknown type are kept
    pub keep_unknown_files: bool,

    /// ID of the encryption key used for samples, if encryption is configured
    // NOTE(review): `dead_code` is allowed presumably because this field is only read in some
    // build configurations — confirm and document the reason.
    #[allow(dead_code)]
    pub(crate) default_key: Option<u32>,
}
96
/// Counts summarising the Virus Total records held for samples in Malware DB.
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Samples whose VT record shows them as clean
    pub clean_records: u32,

    /// Samples whose VT record shows them as malicious
    pub hits_records: u32,

    /// Samples that have no VT record at all
    pub files_without_records: u32,
}
110
111impl DatabaseType {
112    /// Get a database connection from a configuration string
113    ///
114    /// # Errors
115    ///
116    /// * If there's a connectivity issue to Postgres, an error will result
117    /// * If the `SQLite` file cannot be created or opened, an error will result
118    /// * If `SQLite` is the type but Malware DB wasn't compiled with the sqlite feature, an error will result
119    /// * If the format or database type isn't known, an error will result
120    pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
121        #[cfg(any(test, feature = "sqlite"))]
122        if arg.starts_with("file:") {
123            let new_conn_str = arg.trim_start_matches("file:");
124            return Ok(DatabaseType::SQLite(Sqlite::new(new_conn_str)?));
125        }
126
127        if arg.starts_with("postgres") {
128            let new_conn_str = arg.trim_start_matches("postgres");
129            return Ok(DatabaseType::Postgres(
130                Postgres::new(new_conn_str, server_ca).await?,
131            ));
132        }
133
134        bail!("unknown database type `{arg}`")
135    }
136
137    /// Set the flag allowing uploads to Virus Total.
138    ///
139    /// # Errors
140    ///
141    /// If there's a connectivity issue to Postgres, an error will result
142    #[cfg(feature = "vt")]
143    pub async fn enable_vt_upload(&self) -> Result<()> {
144        match self {
145            DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
146            #[cfg(any(test, feature = "sqlite"))]
147            DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
148        }
149    }
150
151    /// Set the flag preventing uploads to Virus Total.
152    ///
153    /// # Errors
154    ///
155    /// If there's a connectivity issue to Postgres, an error will result
156    #[cfg(feature = "vt")]
157    pub async fn disable_vt_upload(&self) -> Result<()> {
158        match self {
159            DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
160            #[cfg(any(test, feature = "sqlite"))]
161            DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
162        }
163    }
164
165    /// Get the SHA-256 hashes of the files which don't have VT records
166    ///
167    /// # Errors
168    ///
169    /// If there's a connectivity issue to Postgres, an error will result
170    #[cfg(feature = "vt")]
171    pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
172        match self {
173            DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
174            #[cfg(any(test, feature = "sqlite"))]
175            DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
176        }
177    }
178
179    /// Store the VT results: AV hits and detailed report, or lack of any AV hits
180    ///
181    /// # Errors
182    ///
183    /// If there's a connectivity issue to Postgres, an error will result
184    #[cfg(feature = "vt")]
185    pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
186        match self {
187            DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
188            #[cfg(any(test, feature = "sqlite"))]
189            DatabaseType::SQLite(sl) => sl.store_vt_record(results),
190        }
191    }
192
193    /// Quick statistics regarding the data contained for VT information for our samples
194    ///
195    /// # Errors
196    ///
197    /// If there's a connectivity issue to Postgres, an error will result
198    #[cfg(feature = "vt")]
199    pub async fn get_vt_stats(&self) -> Result<VtStats> {
200        match self {
201            DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
202            #[cfg(any(test, feature = "sqlite"))]
203            DatabaseType::SQLite(sl) => sl.get_vt_stats(),
204        }
205    }
206
207    /// Get the configuration which is stored in the database
208    ///
209    /// # Errors
210    ///
211    /// If there's a connectivity issue to Postgres, an error will result
212    pub async fn get_config(&self) -> Result<MDBConfig> {
213        match self {
214            DatabaseType::Postgres(pg) => pg.get_config().await,
215            #[cfg(any(test, feature = "sqlite"))]
216            DatabaseType::SQLite(sl) => sl.get_config(),
217        }
218    }
219
220    /// Check user credentials, return API key. Generate if it doesn't exist.
221    ///
222    /// # Errors
223    ///
224    /// * If there's a connectivity issue to Postgres, an error will result
225    /// * If the user name and/or password aren't correct, an error will result
226    pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
227        match self {
228            DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
229            #[cfg(any(test, feature = "sqlite"))]
230            DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
231        }
232    }
233
234    /// Get the user's ID from their API key
235    ///
236    /// # Errors
237    ///
238    /// * If there's a connectivity issue to Postgres, an error will result
239    /// * If the api key isn't valid, an error will result
240    pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
241        ensure!(!apikey.is_empty(), "API key was empty");
242        match self {
243            DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
244            #[cfg(any(test, feature = "sqlite"))]
245            DatabaseType::SQLite(sl) => sl.get_uid(apikey),
246        }
247    }
248
249    /// Retrieve information about the database
250    ///
251    /// # Errors
252    ///
253    /// * If there's a connectivity issue to Postgres, an error will result
254    pub async fn db_info(&self) -> Result<DatabaseInformation> {
255        match self {
256            DatabaseType::Postgres(pg) => pg.db_info().await,
257            #[cfg(any(test, feature = "sqlite"))]
258            DatabaseType::SQLite(sl) => sl.db_info(),
259        }
260    }
261
262    /// Retrieve the names of the groups and sources the user is part of and has access to
263    ///
264    /// # Errors
265    ///
266    /// * If there's a connectivity issue to Postgres, an error will result
267    /// * If the user ID isn't valid, an error will result
268    pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
269        match self {
270            DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
271            #[cfg(any(test, feature = "sqlite"))]
272            DatabaseType::SQLite(sl) => sl.get_user_info(uid),
273        }
274    }
275
276    /// Retrieve the source information available to the specified user
277    ///
278    /// # Errors
279    ///
280    /// * If there's a connectivity issue to Postgres, an error will result
281    /// * If the user ID isn't valid, an error will result
282    pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
283        match self {
284            DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
285            #[cfg(any(test, feature = "sqlite"))]
286            DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
287        }
288    }
289
290    /// Let the user clear their own API key to log out from all systems
291    ///
292    /// # Errors
293    ///
294    /// * If there's a connectivity issue to Postgres, an error will result
295    /// * If the user ID isn't valid, an error will result
296    pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
297        match self {
298            DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
299            #[cfg(any(test, feature = "sqlite"))]
300            DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
301        }
302    }
303
304    /// Retrieve the supported data type information
305    ///
306    /// # Errors
307    ///
308    /// If there's a connectivity issue to Postgres, an error will result
309    pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
310        match self {
311            DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
312            #[cfg(any(test, feature = "sqlite"))]
313            DatabaseType::SQLite(sl) => sl.get_known_data_types(),
314        }
315    }
316
317    /// Get all labels from Malware DB
318    ///
319    /// # Errors
320    ///
321    /// If there's a connectivity issue to Postgres, an error will result
322    pub async fn get_labels(&self) -> Result<Labels> {
323        match self {
324            DatabaseType::Postgres(pg) => pg.get_labels().await,
325            #[cfg(any(test, feature = "sqlite"))]
326            DatabaseType::SQLite(sl) => sl.get_labels(),
327        }
328    }
329
330    /// Get the corresponding type ID for a buffer representing a file
331    ///
332    /// # Errors
333    ///
334    /// If there's a connectivity issue to Postgres, an error will result
335    pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
336        match self {
337            DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
338            #[cfg(any(test, feature = "sqlite"))]
339            DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
340        }
341    }
342
343    /// Check that a user has been granted access data for the specific source
344    ///
345    /// # Errors
346    ///
347    /// * If there's a connectivity issue to Postgres, an error will result
348    /// * If the user or source ID(s) aren't valid, an error will result
349    pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
350        match self {
351            DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
352            #[cfg(any(test, feature = "sqlite"))]
353            DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
354        }
355    }
356
357    /// Check to see if the user is an administrator. The user must be a member of the
358    /// admin group (group ID 0), or a one group below (a group with the parent group id of 0).
359    ///
360    /// # Errors
361    ///
362    /// * If there's a connectivity issue to Postgres, an error will result
363    /// * If the user ID isn't valid, an error will result
364    pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
365        match self {
366            DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
367            #[cfg(any(test, feature = "sqlite"))]
368            DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
369        }
370    }
371
372    /// Add a file's metadata to the database, returning true if this is a new entry
373    ///
374    /// # Errors
375    ///
376    /// * If there's a connectivity issue to Postgres, an error will result
377    /// * If the source doesn't exist or the user isn't part of a member group, an error will result
378    pub async fn add_file(
379        &self,
380        meta: &FileMetadata,
381        known_type: KnownType<'_>,
382        uid: u32,
383        sid: u32,
384        ftype: u32,
385        parent: Option<u64>,
386    ) -> Result<bool> {
387        match self {
388            DatabaseType::Postgres(pg) => {
389                pg.add_file(meta, known_type, uid, sid, ftype, parent).await
390            }
391            #[cfg(any(test, feature = "sqlite"))]
392            DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
393        }
394    }
395
396    /// Search for allowed samples based on partial search and/or file name
397    ///
398    /// # Errors
399    ///
400    /// * If there's a connectivity issue to Postgres, an error will result
401    pub async fn partial_search(&self, uid: u32, search: &SearchRequest) -> Result<Vec<String>> {
402        match self {
403            DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
404            #[cfg(any(test, feature = "sqlite"))]
405            DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
406        }
407    }
408
409    /// Retrieve the SHA-256 hash of the sample while checking that the user is permitted
410    /// to access to it
411    ///
412    /// # Errors
413    ///
414    /// * If there's a connectivity issue to Postgres, an error will result
415    /// * If the file doesn't exist or the user isn't allowed access, an error will result
416    pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
417        match self {
418            DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
419            #[cfg(any(test, feature = "sqlite"))]
420            DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
421        }
422    }
423
424    /// Retrieve a report for a given sample, if allowed.
425    ///
426    /// # Errors
427    ///
428    /// * If there's a connectivity issue to Postgres, an error will result
429    /// * If the file doesn't exist or the user isn't allowed access, an error will result
430    pub async fn get_sample_report(
431        &self,
432        uid: u32,
433        hash: &HashType,
434    ) -> Result<malwaredb_api::Report> {
435        match self {
436            DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
437            #[cfg(any(test, feature = "sqlite"))]
438            DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
439        }
440    }
441
442    /// Given a collection of similarity hashes, find samples which are similar.
443    ///
444    /// # Errors
445    ///
446    /// If there's a connectivity issue to Postgres, an error will result
447    pub async fn find_similar_samples(
448        &self,
449        uid: u32,
450        sim: &[(malwaredb_api::SimilarityHashType, String)],
451    ) -> Result<Vec<malwaredb_api::SimilarSample>> {
452        match self {
453            DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
454            #[cfg(any(test, feature = "sqlite"))]
455            DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
456        }
457    }
458
    // Crate-internal functions (`pub(crate)`), not part of the public API
460
461    /// Get the file encryption keys
462    pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
463        match self {
464            DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
465            #[cfg(any(test, feature = "sqlite"))]
466            DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
467        }
468    }
469
470    /// Get the key and AES nonce for a specific file identified by SHA-256 hash
471    pub(crate) async fn get_file_encryption_key_id(
472        &self,
473        hash: &str,
474    ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
475        match self {
476            DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
477            #[cfg(any(test, feature = "sqlite"))]
478            DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
479        }
480    }
481
482    /// Set the AES nonce for a file specified by SHA-256 hash, a `None` value removes it.
483    pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
484        match self {
485            DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
486            #[cfg(any(test, feature = "sqlite"))]
487            DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
488        }
489    }
490
491    // Administrative functions
492
493    /// Set the compression flag
494    ///
495    /// # Errors
496    ///
497    /// If there's a connectivity issue to Postgres, an error will result
498    #[cfg(any(test, feature = "admin"))]
499    pub async fn enable_compression(&self) -> Result<()> {
500        match self {
501            DatabaseType::Postgres(pg) => pg.enable_compression().await,
502            #[cfg(any(test, feature = "sqlite"))]
503            DatabaseType::SQLite(sl) => sl.enable_compression(),
504        }
505    }
506
507    /// Unset the compression flag, does not go and decompress files already compressed!
508    ///
509    /// # Errors
510    ///
511    /// If there's a connectivity issue to Postgres, an error will result
512    #[cfg(any(test, feature = "admin"))]
513    pub async fn disable_compression(&self) -> Result<()> {
514        match self {
515            DatabaseType::Postgres(pg) => pg.disable_compression().await,
516            #[cfg(any(test, feature = "sqlite"))]
517            DatabaseType::SQLite(sl) => sl.disable_compression(),
518        }
519    }
520
521    /// Set the keep unknown files flag
522    ///
523    /// # Errors
524    ///
525    /// If there's a connectivity issue to Postgres, an error will result
526    #[cfg(any(test, feature = "admin"))]
527    pub async fn enable_keep_unknown_files(&self) -> Result<()> {
528        match self {
529            DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
530            #[cfg(any(test, feature = "sqlite"))]
531            DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
532        }
533    }
534
535    /// Unset the keep unknown files flag, does not go and remove unknown files!
536    ///
537    /// # Errors
538    ///
539    /// If there's a connectivity issue to Postgres, an error will result
540    #[cfg(any(test, feature = "admin"))]
541    pub async fn disable_keep_unknown_files(&self) -> Result<()> {
542        match self {
543            DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
544            #[cfg(any(test, feature = "sqlite"))]
545            DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
546        }
547    }
548
549    /// Add an encryption key to the database, set it as the default, and return the key ID
550    ///
551    /// # Errors
552    ///
553    /// * If there is a connectivity with Postgres, an error will result
554    #[cfg(any(test, feature = "admin"))]
555    pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
556        match self {
557            DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
558            #[cfg(any(test, feature = "sqlite"))]
559            DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
560        }
561    }
562
563    /// Get the key ID and algorithm names
564    ///
565    /// # Errors
566    ///
567    /// If there is a connectivity with Postgres, an error will result
568    #[cfg(any(test, feature = "admin"))]
569    pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
570        match self {
571            DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
572            #[cfg(any(test, feature = "sqlite"))]
573            DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
574        }
575    }
576
577    /// Create a user account, return the user ID.
578    ///
579    /// # Errors
580    ///
581    /// If there's a connectivity issue to Postgres, an error will result
582    #[allow(clippy::too_many_arguments)]
583    #[cfg(any(test, feature = "admin"))]
584    pub async fn create_user(
585        &self,
586        uname: &str,
587        fname: &str,
588        lname: &str,
589        email: &str,
590        password: Option<String>,
591        organisation: Option<&String>,
592        readonly: bool,
593    ) -> Result<u32> {
594        match self {
595            DatabaseType::Postgres(pg) => {
596                pg.create_user(uname, fname, lname, email, password, organisation, readonly)
597                    .await
598            }
599            #[cfg(any(test, feature = "sqlite"))]
600            DatabaseType::SQLite(sl) => {
601                sl.create_user(uname, fname, lname, email, password, organisation, readonly)
602            }
603        }
604    }
605
606    /// Clear all API keys, either in case of suspected activity, or part of policy
607    ///
608    /// # Errors
609    ///
610    /// If there's a connectivity issue to Postgres, an error will result
611    #[cfg(any(test, feature = "admin"))]
612    pub async fn reset_api_keys(&self) -> Result<u64> {
613        match self {
614            DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
615            #[cfg(any(test, feature = "sqlite"))]
616            DatabaseType::SQLite(sl) => sl.reset_api_keys(),
617        }
618    }
619
620    /// Set a user's password
621    ///
622    /// # Errors
623    ///
624    /// * If there's a connectivity issue to Postgres, an error will result
625    /// * If the user doesn't exist, an error will result
626    #[cfg(any(test, feature = "admin"))]
627    pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
628        match self {
629            DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
630            #[cfg(any(test, feature = "sqlite"))]
631            DatabaseType::SQLite(sl) => sl.set_password(uname, password),
632        }
633    }
634
635    /// Get the complete list of users
636    ///
637    /// # Errors
638    ///
639    /// If there's a connectivity issue to Postgres, an error will result
640    #[cfg(any(test, feature = "admin"))]
641    pub async fn list_users(&self) -> Result<Vec<admin::User>> {
642        match self {
643            DatabaseType::Postgres(pg) => pg.list_users().await,
644            #[cfg(any(test, feature = "sqlite"))]
645            DatabaseType::SQLite(sl) => sl.list_users(),
646        }
647    }
648
649    /// Get the ID of a group from name
650    ///
651    /// # Errors
652    ///
653    /// * If there's a connectivity issue to Postgres, an error will result
654    /// * If the group name isn't valid, an error will result
655    #[cfg(any(test, feature = "admin"))]
656    pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
657        match self {
658            DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
659            #[cfg(any(test, feature = "sqlite"))]
660            DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
661        }
662    }
663
664    /// Update the record for a group
665    ///
666    /// # Errors
667    ///
668    /// * If there's a connectivity issue to Postgres, an error will result
669    /// * If the group doesn't exist, an error will result
670    #[cfg(any(test, feature = "admin"))]
671    pub async fn edit_group(
672        &self,
673        gid: u32,
674        name: &str,
675        desc: &str,
676        parent: Option<u32>,
677    ) -> Result<()> {
678        match self {
679            DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
680            #[cfg(any(test, feature = "sqlite"))]
681            DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
682        }
683    }
684
685    /// Get the complete list of groups
686    ///
687    /// # Errors
688    ///
689    /// If there's a connectivity issue to Postgres, an error will result
690    #[cfg(any(test, feature = "admin"))]
691    pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
692        match self {
693            DatabaseType::Postgres(pg) => pg.list_groups().await,
694            #[cfg(any(test, feature = "sqlite"))]
695            DatabaseType::SQLite(sl) => sl.list_groups(),
696        }
697    }
698
699    /// Grant a user membership to a group, both by id.
700    ///
701    /// # Errors
702    ///
703    /// * If there's a connectivity issue to Postgres, an error will result
704    /// * If the user or group doesn't exist, an error will result
705    #[cfg(any(test, feature = "admin"))]
706    pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
707        match self {
708            DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
709            #[cfg(any(test, feature = "sqlite"))]
710            DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
711        }
712    }
713
714    /// Grand a group access to a source, both by id.
715    ///
716    /// # Errors
717    ///
718    /// * If there's a connectivity issue to Postgres, an error will result
719    /// * If the group or source doesn't exist, an error will result
720    #[cfg(any(test, feature = "admin"))]
721    pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
722        match self {
723            DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
724            #[cfg(any(test, feature = "sqlite"))]
725            DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
726        }
727    }
728
729    /// Create a new group, returning the group ID
730    ///
731    /// # Errors
732    ///
733    /// * If there's a connectivity issue to Postgres, an error will result
734    /// * If the group name is already taken, an error will result
735    #[cfg(any(test, feature = "admin"))]
736    pub async fn create_group(
737        &self,
738        name: &str,
739        description: &str,
740        parent: Option<u32>,
741    ) -> Result<u32> {
742        match self {
743            DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
744            #[cfg(any(test, feature = "sqlite"))]
745            DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
746        }
747    }
748
749    /// Get the complete list of sources
750    ///
751    /// # Errors
752    ///
753    /// If there's a connectivity issue to Postgres, an error will result
754    #[cfg(any(test, feature = "admin"))]
755    pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
756        match self {
757            DatabaseType::Postgres(pg) => pg.list_sources().await,
758            #[cfg(any(test, feature = "sqlite"))]
759            DatabaseType::SQLite(sl) => sl.list_sources(),
760        }
761    }
762
763    /// Create a source, returning the source ID
764    ///
765    /// # Errors
766    ///
767    /// * If there's a connectivity issue to Postgres, an error will result
768    /// * If the source already exists, an error will result
769    #[cfg(any(test, feature = "admin"))]
770    pub async fn create_source(
771        &self,
772        name: &str,
773        description: Option<&str>,
774        url: Option<&str>,
775        date: chrono::DateTime<Local>,
776        releasable: bool,
777        malicious: Option<bool>,
778    ) -> Result<u32> {
779        match self {
780            DatabaseType::Postgres(pg) => {
781                pg.create_source(name, description, url, date, releasable, malicious)
782                    .await
783            }
784            #[cfg(any(test, feature = "sqlite"))]
785            DatabaseType::SQLite(sl) => {
786                sl.create_source(name, description, url, date, releasable, malicious)
787            }
788        }
789    }
790
791    /// Edit a user account setting the specified field values, primarily used by the Admin gui
792    ///
793    /// # Errors
794    ///
795    /// * If there's a connectivity issue to Postgres, an error will result
796    /// * If the user isn't valid, an error will result
797    #[cfg(any(test, feature = "admin"))]
798    pub async fn edit_user(
799        &self,
800        uid: u32,
801        uname: &str,
802        fname: &str,
803        lname: &str,
804        email: &str,
805        readonly: bool,
806    ) -> Result<()> {
807        match self {
808            DatabaseType::Postgres(pg) => {
809                pg.edit_user(uid, uname, fname, lname, email, readonly)
810                    .await
811            }
812            #[cfg(any(test, feature = "sqlite"))]
813            DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
814        }
815    }
816
817    /// Deactivate, but don't delete, a user's account
818    ///
819    /// # Errors
820    ///
821    /// * If there's a connectivity issue to Postgres, an error will result
822    /// * If the user ID isn't valid, an error will result
823    #[cfg(any(test, feature = "admin"))]
824    pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
825        match self {
826            DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
827            #[cfg(any(test, feature = "sqlite"))]
828            DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
829        }
830    }
831
832    /// File types and number of files per type
833    ///
834    /// # Errors
835    ///
836    /// * If there's a connectivity issue to Postgres, an error will result
837    #[cfg(any(test, feature = "admin"))]
838    pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
839        match self {
840            DatabaseType::Postgres(pg) => pg.file_types_counts().await,
841            #[cfg(any(test, feature = "sqlite"))]
842            DatabaseType::SQLite(sl) => sl.file_types_counts(),
843        }
844    }
845
846    /// Create a new label
847    ///
848    /// # Errors
849    ///
850    /// * If there's a connectivity issue to Postgres, an error will result
851    /// * If the parent label doesn't exist, an error will result
852    /// * If the label name already exists, an error will result
853    #[cfg(any(test, feature = "admin"))]
854    pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
855        match self {
856            DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
857            #[cfg(any(test, feature = "sqlite"))]
858            DatabaseType::SQLite(sl) => sl.create_label(name, parent),
859        }
860    }
861
862    /// Edit a label name or parent
863    ///
864    /// # Errors
865    ///
866    /// * If there's a connectivity issue to Postgres, an error will result
867    /// * If the parent label doesn't exist, an error will result
868    /// * If the label name already exists, an error will result
869    #[cfg(any(test, feature = "admin"))]
870    pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
871        match self {
872            DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
873            #[cfg(any(test, feature = "sqlite"))]
874            DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
875        }
876    }
877
878    /// Return the ID for a given label name
879    ///
880    /// # Errors
881    ///
882    /// * If there's a connectivity issue to Postgres, an error will result
883    /// * If the label ID is invalid, an error will result
884    #[cfg(any(test, feature = "admin"))]
885    pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
886        match self {
887            DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
888            #[cfg(any(test, feature = "sqlite"))]
889            DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
890        }
891    }
892}
893
894/// Hash a password string with [Argon2]
895///
896/// # Errors
897///
898/// An error may result if Argon has an error
899pub fn hash_password(password: &str) -> Result<String> {
900    let salt = SaltString::generate(&mut OsRng);
901    let argon2 = Argon2::default();
902    Ok(argon2
903        .hash_password(password.as_bytes(), &salt)?
904        .to_string())
905}
906
907/// Generate a new, random API key
908#[must_use]
909pub fn random_bytes_api_key() -> String {
910    let key1 = uuid::Uuid::new_v4();
911    let key2 = uuid::Uuid::new_v4();
912    let key1 = key1.to_string().replace('-', "");
913    let key2 = key2.to_string().replace('-', "");
914    format!("{key1}{key2}")
915}
916
917#[cfg(test)]
918mod tests {
919    use super::*;
920    #[cfg(feature = "vt")]
921    use crate::vt::VtUpdater;
922
923    use std::fs;
924    #[cfg(feature = "vt")]
925    use std::time::SystemTime;
926
927    use anyhow::Context;
928    use fuzzyhash::FuzzyHash;
929    use malwaredb_api::PartialHashSearchType;
930    use malwaredb_lzjd::{LZDict, Murmur3HashState};
931    use tlsh_fixed::TlshBuilder;
932    use uuid::Uuid;
933
934    const MALWARE_LABEL: &str = "malware";
935    const RANSOMWARE_LABEL: &str = "ransomware";
936
937    fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
938        let mut hashes = vec![];
939
940        hashes.push((
941            malwaredb_api::SimilarityHashType::SSDeep,
942            FuzzyHash::new(data).to_string(),
943        ));
944
945        let mut builder = TlshBuilder::new(
946            tlsh_fixed::BucketKind::Bucket256,
947            tlsh_fixed::ChecksumKind::ThreeByte,
948            tlsh_fixed::Version::Version4,
949        );
950
951        builder.update(data);
952
953        if let Ok(hasher) = builder.build() {
954            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
955        }
956
957        let build_hasher = Murmur3HashState::default();
958        let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
959        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
960
961        malwaredb_api::SimilarSamplesRequest { hashes }
962    }
963
964    async fn pg_config() -> Postgres {
965        // create user malwaredbtesting with password 'malwaredbtesting';
966        // create database malwaredbtesting owner malwaredbtesting;
967        const CONNECTION_STRING: &str =
968            "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
969
970        if let Ok(pg_port) = std::env::var("PG_PORT") {
971            // Get the port number to run in Github CI
972            let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
973            Postgres::new(&conn_string, None)
974                .await
975                .context(format!(
976                    "failed to connect to postgres with specified port {pg_port}"
977                ))
978                .unwrap()
979        } else {
980            Postgres::new(CONNECTION_STRING, None).await.unwrap()
981        }
982    }
983
984    #[tokio::test]
985    #[ignore = "don't run this in CI"]
986    async fn pg() {
987        let psql = pg_config().await;
988        psql.delete_init().await.unwrap();
989
990        let db = DatabaseType::Postgres(psql);
991        let key = FileEncryption::from(EncryptionOption::Xor);
992        db.add_file_encryption_key(&key).await.unwrap();
993        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
994        everything(&db).await.unwrap();
995
996        #[cfg(feature = "vt")]
997        {
998            let db_config = db.get_config().await.unwrap();
999            let state = crate::State {
1000                port: 8080,
1001                directory: None,
1002                max_upload: 10 * 1024 * 1024,
1003                ip: "127.0.0.1".parse().unwrap(),
1004                db_type: db,
1005                db_config,
1006                keys: HashMap::new(),
1007                started: SystemTime::now(),
1008                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1009                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
1010                }),
1011                cert: None,
1012                key: None,
1013            };
1014
1015            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1016
1017            vt.updater().await.unwrap();
1018            println!("PG: Did VT ops!");
1019
1020            let psql = pg_config().await;
1021
1022            let vt_stats = psql
1023                .get_vt_stats()
1024                .await
1025                .context("failed to get Postgres VT Stats")
1026                .unwrap();
1027            println!("{vt_stats:?}");
1028            assert!(
1029                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1030            );
1031        }
1032
1033        // Re-create the Postgres object so we can do some clean-up
1034        let psql = pg_config().await;
1035        psql.delete_init().await.unwrap();
1036    }
1037
1038    #[tokio::test]
1039    async fn sqlite() {
1040        const DB_FILE: &str = "testing_sqlite.db";
1041        if std::path::Path::new(DB_FILE).exists() {
1042            fs::remove_file(DB_FILE)
1043                .context(format!("failed to delete old SQLite file {DB_FILE}"))
1044                .unwrap();
1045        }
1046
1047        let sqlite = Sqlite::new(DB_FILE)
1048            .context(format!("failed to create SQLite instance for {DB_FILE}"))
1049            .unwrap();
1050
1051        let db = DatabaseType::SQLite(sqlite);
1052        let key = FileEncryption::from(EncryptionOption::Xor);
1053        db.add_file_encryption_key(&key).await.unwrap();
1054        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
1055        everything(&db).await.unwrap();
1056
1057        #[cfg(feature = "vt")]
1058        {
1059            let db_config = db.get_config().await.unwrap();
1060            let state = crate::State {
1061                port: 8080,
1062                directory: None,
1063                max_upload: 10 * 1024 * 1024,
1064                ip: "127.0.0.1".parse().unwrap(),
1065                db_type: db,
1066                db_config,
1067                keys: HashMap::new(),
1068                started: SystemTime::now(),
1069                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1070                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
1071                }),
1072                cert: None,
1073                key: None,
1074            };
1075
1076            let sqlite_second = Sqlite::new(DB_FILE)
1077                .context(format!("failed to create SQLite instance for {DB_FILE}"))
1078                .unwrap();
1079
1080            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1081
1082            vt.updater().await.unwrap();
1083            println!("Sqlite: Did VT ops!");
1084            let vt_stats = sqlite_second
1085                .get_vt_stats()
1086                .context("failed to get Sqlite VT Stats")
1087                .unwrap();
1088            println!("{vt_stats:?}");
1089            assert!(
1090                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1091            );
1092        }
1093
1094        fs::remove_file(DB_FILE)
1095            .context(format!("failed to delete SQLite file {DB_FILE}"))
1096            .unwrap();
1097    }
1098
1099    #[allow(clippy::too_many_lines)]
1100    async fn everything(db: &DatabaseType) -> Result<()> {
1101        const ADMIN_UNAME: &str = "admin";
1102        const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
1103
1104        assert!(
1105            db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
1106            "Authentication without password should have failed."
1107        );
1108
1109        db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1110            .await
1111            .context("failed to set admin password")?;
1112
1113        let admin_api_key = db
1114            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1115            .await
1116            .context("unable to get api key for admin")?;
1117        println!("API key: {admin_api_key}");
1118        assert_eq!(admin_api_key.len(), 64);
1119
1120        assert_eq!(
1121            db.get_uid(&admin_api_key).await?,
1122            0,
1123            "Unable to get UID given the API key"
1124        );
1125
1126        let admin_api_key_again = db
1127            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1128            .await
1129            .context("unable to get api key a second time for admin")?;
1130
1131        assert_eq!(
1132            admin_api_key, admin_api_key_again,
1133            "API keys didn't match the second time."
1134        );
1135
1136        let bad_password = "this_is_totally_not_my_password!!";
1137        eprintln!("Testing API login with incorrect password.");
1138        assert!(
1139            db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
1140            "Authenticating as admin with a bad password should have failed."
1141        );
1142
1143        let admin_is_admin = db
1144            .user_is_admin(0)
1145            .await
1146            .context("unable to see if admin (uid 0) is an admin")?;
1147        assert!(admin_is_admin);
1148
1149        let new_user_uname = "testuser";
1150        let new_user_email = "test@example.com";
1151        let new_user_password = "some_awesome_password_++";
1152        let new_id = db
1153            .create_user(
1154                new_user_uname,
1155                new_user_uname,
1156                new_user_uname,
1157                new_user_email,
1158                Some(new_user_password.into()),
1159                None,
1160                false,
1161            )
1162            .await
1163            .context(format!("failed to create user {new_user_uname}"))?;
1164
1165        let passwordless_user_id = db
1166            .create_user(
1167                "passwordless_user",
1168                "passwordless_user",
1169                "passwordless_user",
1170                "passwordless_user@example.com",
1171                None,
1172                None,
1173                false,
1174            )
1175            .await
1176            .context("failed to create passwordless_user")?;
1177
1178        for user in &db.list_users().await.context("failed to list users")? {
1179            if user.id == passwordless_user_id {
1180                assert_eq!(user.uname, "passwordless_user");
1181            }
1182        }
1183
1184        db.edit_user(
1185            passwordless_user_id,
1186            "passwordless_user_2",
1187            "passwordless_user_2",
1188            "passwordless_user_2",
1189            "passwordless_user_2@something.com",
1190            false,
1191        )
1192        .await
1193        .context(format!(
1194            "failed to alter 'passwordless' user, id {passwordless_user_id}"
1195        ))?;
1196
1197        for user in &db.list_users().await.context("failed to list users")? {
1198            if user.id == passwordless_user_id {
1199                assert_eq!(user.uname, "passwordless_user_2");
1200            }
1201        }
1202
1203        assert!(
1204            new_id > 0,
1205            "Weird UID created for user {new_user_uname}: {new_id}"
1206        );
1207
1208        assert!(
1209            db.create_user(
1210                new_user_uname,
1211                new_user_uname,
1212                new_user_uname,
1213                new_user_email,
1214                Some(new_user_password.into()),
1215                None,
1216                false
1217            )
1218            .await
1219            .is_err(),
1220            "Creating a new user with the same user name should fail"
1221        );
1222
1223        let ro_user_name = "ro_user";
1224        let ro_user_password = "ro_user_password";
1225        db.create_user(
1226            ro_user_name,
1227            "ro_user",
1228            "ro_user",
1229            "ro@example.com",
1230            Some(ro_user_password.into()),
1231            None,
1232            true,
1233        )
1234        .await
1235        .context("failed to create read-only user")?;
1236
1237        let ro_user_api_key = db
1238            .authenticate(ro_user_name, ro_user_password)
1239            .await
1240            .context("unable to get api key for read-only user")?;
1241
1242        let new_user_password_change = "some_new_awesomer_password!_++";
1243        db.set_password(new_user_uname, new_user_password_change)
1244            .await
1245            .context("failed to change the password for testuser")?;
1246
1247        let new_user_api_key = db
1248            .authenticate(new_user_uname, new_user_password_change)
1249            .await
1250            .context("unable to get api key for testuser")?;
1251        eprintln!("{new_user_uname} got API key {new_user_api_key}");
1252
1253        assert_eq!(admin_api_key.len(), new_user_api_key.len());
1254
1255        let users = db.list_users().await.context("failed to list users")?;
1256        assert_eq!(
1257            users.len(),
1258            4,
1259            "Four users were created, yet there are {} users",
1260            users.len()
1261        );
1262        eprintln!("DB has {} users:", users.len());
1263        let mut passwordless_user_found = false;
1264        for user in users {
1265            println!("{user}");
1266            if user.uname == "passwordless_user_2" {
1267                assert!(!user.has_api_key);
1268                assert!(!user.has_password);
1269                passwordless_user_found = true;
1270            } else {
1271                assert!(user.has_api_key);
1272                assert!(user.has_password);
1273            }
1274        }
1275        assert!(passwordless_user_found);
1276
1277        let new_group_name = "some_new_group";
1278        let new_group_desc = "some_new_group_description";
1279        let new_group_id = 1;
1280        assert_eq!(
1281            db.create_group(new_group_name, new_group_desc, None)
1282                .await
1283                .context("failed to create group")?,
1284            new_group_id,
1285            "New group didn't have the expected ID, expected {new_group_id}"
1286        );
1287
1288        assert!(
1289            db.create_group(new_group_name, new_group_desc, None)
1290                .await
1291                .is_err(),
1292            "Duplicate group name should have failed"
1293        );
1294
1295        db.add_user_to_group(1, 1)
1296            .await
1297            .context("Unable to add uid 1 to gid 1")?;
1298
1299        let ro_user_uid = db
1300            .get_uid(&ro_user_api_key)
1301            .await
1302            .context("Unable to get UID for read-only user")?;
1303        db.add_user_to_group(ro_user_uid, 1)
1304            .await
1305            .context("Unable to add uid 2 to gid 1")?;
1306
1307        let new_admin_group_name = "admin_subgroup";
1308        let new_admin_group_desc = "admin_subgroup_description";
1309        let new_admin_group_id = 2;
1310        // TODO: Figure out why SQLite makes the group_id = 2, but with Postgres it's 3.
1311        assert!(
1312            db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
1313                .await
1314                .context("failed to create admin sub-group")?
1315                >= new_admin_group_id,
1316            "New group didn't have the expected ID, expected >= {new_admin_group_id}"
1317        );
1318
1319        let groups = db.list_groups().await.context("failed to list groups")?;
1320        assert_eq!(
1321            groups.len(),
1322            3,
1323            "Three groups were created, yet there are {} groups",
1324            groups.len()
1325        );
1326        eprintln!("DB has {} groups:", groups.len());
1327        for group in groups {
1328            println!("{group}");
1329            if group.id == new_admin_group_id {
1330                assert_eq!(group.parent, Some("admin".to_string()));
1331            }
1332            if group.id == 1 {
1333                let test_user_str = String::from(new_user_uname);
1334                let mut found = false;
1335                for member in group.members {
1336                    if member.uname == test_user_str {
1337                        found = true;
1338                        break;
1339                    }
1340                }
1341                assert!(found, "new user {test_user_str} wasn't in the group");
1342            }
1343        }
1344
1345        let default_source_name = "default_source".to_string();
1346        let default_source_id = db
1347            .create_source(
1348                &default_source_name,
1349                Some("desc_default_source"),
1350                None,
1351                Local::now(),
1352                true,
1353                Some(false),
1354            )
1355            .await
1356            .context("failed to create source `default_source`")?;
1357
1358        db.add_group_to_source(1, default_source_id)
1359            .await
1360            .context("failed to add group 1 to source 1")?;
1361
1362        let another_source_name = "another_source".to_string();
1363        let another_source_id = db
1364            .create_source(
1365                &another_source_name,
1366                Some("yet another file source"),
1367                None,
1368                Local::now(),
1369                true,
1370                Some(false),
1371            )
1372            .await
1373            .context("failed to create source `another_source`")?;
1374
1375        let empty_source_name = "empty_source".to_string();
1376        db.create_source(
1377            &empty_source_name,
1378            Some("empty and unused file source"),
1379            None,
1380            Local::now(),
1381            true,
1382            Some(false),
1383        )
1384        .await
1385        .context("failed to create source `another_source`")?;
1386
1387        db.add_group_to_source(1, another_source_id)
1388            .await
1389            .context("failed to add group 1 to source 1")?;
1390
1391        let sources = db.list_sources().await.context("failed to list sources")?;
1392        eprintln!("DB has {} sources:", sources.len());
1393        for source in sources {
1394            println!("{source}");
1395            assert_eq!(source.files, 0);
1396            if source.id == default_source_id || source.id == another_source_id {
1397                assert_eq!(
1398                    source.groups, 1,
1399                    "default source {default_source_name} should have 1 group"
1400                );
1401            } else {
1402                assert_eq!(source.groups, 0, "groups should zero (empty)");
1403            }
1404        }
1405
1406        let uid = db
1407            .get_uid(&new_user_api_key)
1408            .await
1409            .context("failed to user uid from apikey")?;
1410        let user_info = db
1411            .get_user_info(uid)
1412            .await
1413            .context("failed to get user's available groups and sources")?;
1414        assert!(user_info.sources.contains(&default_source_name));
1415        assert!(!user_info.is_admin);
1416        println!("UserInfoResponse: {user_info:?}");
1417
1418        assert!(
1419            db.allowed_user_source(1, default_source_id)
1420                .await
1421                .context(format!(
1422                    "failed to check that user 1 has access to source {default_source_id}"
1423                ))?,
1424            "User 1 should should have had access to source {default_source_id}"
1425        );
1426
1427        assert!(
1428            !db.allowed_user_source(1, 5)
1429                .await
1430                .context("failed to check that user 1 has access to source 5")?,
1431            "User 1 should should not have had access to source 5"
1432        );
1433
1434        let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
1435        let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
1436        let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
1437
1438        let known_type =
1439            KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
1440        assert!(known_type.is_exec(), "ELF should be executable");
1441        eprintln!("ELF type ID: {elf_type}");
1442
1443        assert!(db
1444            .add_file(
1445                &test_elf_meta,
1446                known_type.clone(),
1447                1,
1448                default_source_id,
1449                elf_type,
1450                None
1451            )
1452            .await
1453            .context("failed to insert a test elf")?);
1454        eprintln!("Added ELF to the DB");
1455
1456        let partial_search = SearchRequest {
1457            partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1458            ..Default::default()
1459        };
1460        let partial_search_response = db.partial_search(1, &partial_search).await?;
1461        assert_eq!(partial_search_response.len(), 1);
1462        assert_eq!(
1463            partial_search_response[0],
1464            "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1465        );
1466
1467        let partial_search = SearchRequest {
1468            file_name: Some("ppc64".into()),
1469            ..Default::default()
1470        };
1471        let partial_search_response = db.partial_search(1, &partial_search).await?;
1472        assert_eq!(partial_search_response.len(), 1);
1473
1474        let partial_search = SearchRequest::default();
1475        let partial_search_response = db.partial_search(1, &partial_search).await?;
1476        assert!(partial_search_response.is_empty());
1477
1478        assert!(
1479            db.add_file(
1480                &test_elf_meta,
1481                known_type.clone(),
1482                ro_user_uid,
1483                default_source_id,
1484                elf_type,
1485                None
1486            )
1487            .await
1488            .is_err(),
1489            "Read-only user should not be able to add a file"
1490        );
1491
1492        let mut test_elf_meta_different_name = test_elf_meta.clone();
1493        test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
1494
1495        assert!(!db
1496            .add_file(
1497                &test_elf_meta_different_name,
1498                known_type,
1499                1,
1500                another_source_id,
1501                elf_type,
1502                None
1503            )
1504            .await
1505            .context("failed to insert a test elf again for a different source")?);
1506
1507        let sources = db
1508            .list_sources()
1509            .await
1510            .context("failed to re-list sources")?;
1511        eprintln!(
1512            "DB has {} sources, and a file was added twice:",
1513            sources.len()
1514        );
1515        println!("We should have two sources with one file each, yet only one ELF.");
1516        for source in sources {
1517            println!("{source}");
1518            if source.id == default_source_id || source.id == another_source_id {
1519                assert_eq!(source.files, 1);
1520            } else {
1521                assert_eq!(source.files, 0, "groups should zero (empty)");
1522            }
1523        }
1524
1525        assert!(!db
1526            .get_user_sources(1)
1527            .await
1528            .expect("failed to get user 1's sources")
1529            .sources
1530            .is_empty());
1531
1532        let file_types_counts = db
1533            .file_types_counts()
1534            .await
1535            .context("failed to get file types and counts")?;
1536        for (name, count) in file_types_counts {
1537            println!("{name}: {count}");
1538            assert_eq!(name, "ELF");
1539            assert_eq!(count, 1);
1540        }
1541
1542        let mut test_elf_modified = test_elf.clone();
1543        let random_bytes = Uuid::new_v4();
1544        let mut random_bytes = random_bytes.into_bytes().to_vec();
1545        test_elf_modified.append(&mut random_bytes);
1546        let similarity_request = generate_similarity_request(&test_elf_modified);
1547        let similarity_response = db
1548            .find_similar_samples(1, &similarity_request.hashes)
1549            .await
1550            .context("failed to get similarity response")?;
1551        eprintln!("Similarity response: {similarity_response:?}");
1552        let similarity_response = similarity_response.first().unwrap();
1553        assert_eq!(
1554            similarity_response.sha256, test_elf_meta.sha256,
1555            "Similarity response should have had the hash of the original ELF"
1556        );
1557        for (algo, sim) in &similarity_response.algorithms {
1558            match algo {
1559                malwaredb_api::SimilarityHashType::LZJD => {
1560                    assert!(*sim > 0.0f32);
1561                }
1562                malwaredb_api::SimilarityHashType::SSDeep => {
1563                    assert!(*sim > 80.0f32);
1564                }
1565                malwaredb_api::SimilarityHashType::TLSH => {
1566                    assert!(*sim <= 20f32);
1567                }
1568                _ => {}
1569            }
1570        }
1571
1572        let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1)
1573            .context("failed to get `HashType::SHA1` from string")?;
1574        let response_sha256 = db
1575            .retrieve_sample(1, &test_elf_hashtype)
1576            .await
1577            .context("could not get SHA-256 hash from test sample")
1578            .unwrap();
1579        assert_eq!(response_sha256, test_elf_meta.sha256);
1580
1581        let test_bogus_hash = HashType::try_from(String::from(
1582            "d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0",
1583        ))
1584        .context("failed to get `HashType` from static string")?;
1585        assert!(
1586            db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
1587            "Getting a file with a bogus hash should have failed."
1588        );
1589
1590        let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
1591        let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
1592        let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
1593
1594        let known_type =
1595            KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
1596
1597        assert!(db
1598            .add_file(
1599                &test_pdf_meta,
1600                known_type,
1601                1,
1602                default_source_id,
1603                pdf_type,
1604                None
1605            )
1606            .await
1607            .context("failed to insert a test pdf")?);
1608        eprintln!("Added PDF to the DB");
1609
1610        let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
1611        let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
1612        let rtf_type = db
1613            .get_type_id_for_bytes(&test_rtf)
1614            .await
1615            .context("failed to get file type id for rtf")?;
1616
1617        let known_type =
1618            KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
1619
1620        assert!(db
1621            .add_file(
1622                &test_rtf_meta,
1623                known_type,
1624                1,
1625                default_source_id,
1626                rtf_type,
1627                None
1628            )
1629            .await
1630            .context("failed to insert a test rtf")?);
1631        eprintln!("Added RTF to the DB");
1632
1633        let report = db
1634            .get_sample_report(1, &HashType::try_from(test_rtf_meta.sha256.clone())?)
1635            .await
1636            .context("failed to get report for test rtf")?;
1637        assert!(report
1638            .clone()
1639            .filecommand
1640            .unwrap()
1641            .contains("Rich Text Format"));
1642        println!("Report: {report}");
1643
1644        assert!(db
1645            .get_sample_report(999, &HashType::try_from(test_rtf_meta.sha256)?)
1646            .await
1647            .is_err());
1648
1649        #[cfg(feature = "vt")]
1650        {
1651            assert!(report.vt.is_some());
1652            let files_needing_vt = db
1653                .files_without_vt_records(10)
1654                .await
1655                .context("failed to get files without VT records")?;
1656            assert!(files_needing_vt.len() > 2);
1657            println!(
1658                "{} files needing VT data: {files_needing_vt:?}",
1659                files_needing_vt.len()
1660            );
1661        }
1662
1663        #[cfg(not(feature = "vt"))]
1664        {
1665            assert!(report.vt.is_none());
1666        }
1667
1668        let reset = db
1669            .reset_api_keys()
1670            .await
1671            .context("failed to reset all API keys")?;
1672        eprintln!("Cleared {reset} api keys.");
1673
1674        let db_info = db.db_info().await.context("failed to get database info")?;
1675        eprintln!("DB Info: {db_info:?}");
1676
1677        let data_types = db
1678            .get_known_data_types()
1679            .await
1680            .context("failed to get data types")?;
1681        for data_type in data_types {
1682            println!("{data_type:?}");
1683        }
1684
1685        let sources = db
1686            .list_sources()
1687            .await
1688            .context("failed to list sources second time")?;
1689        eprintln!("DB has {} sources:", sources.len());
1690        for source in sources {
1691            println!("{source}");
1692        }
1693
1694        let file_types_counts = db
1695            .file_types_counts()
1696            .await
1697            .context("failed to get file types and counts")?;
1698        for (name, count) in file_types_counts {
1699            println!("{name}: {count}");
1700            assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
1701        }
1702
1703        let fatmacho =
1704            include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
1705                .to_vec();
1706        let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
1707        let fatmacho_type = db
1708            .get_type_id_for_bytes(&fatmacho)
1709            .await
1710            .context("failed to get file type for Fat Mach-O")?;
1711        let known_type = KnownType::new(&fatmacho)
1712            .context("failed to parse Fat Mach-O from type crate's test data")?;
1713
1714        assert!(db
1715            .add_file(
1716                &fatmacho_meta,
1717                known_type,
1718                1,
1719                default_source_id,
1720                fatmacho_type,
1721                None
1722            )
1723            .await
1724            .context("failed to insert a test Fat Mach-O")?);
1725        eprintln!("Added Fat Mach-O to the DB");
1726
1727        let file_types_counts = db
1728            .file_types_counts()
1729            .await
1730            .context("failed to get file types and counts")?;
1731        for (name, count) in &file_types_counts {
1732            println!("{name}: {count}");
1733        }
1734
1735        assert_eq!(
1736            *file_types_counts.get("Mach-O").unwrap(),
1737            4,
1738            "Expected 4 Mach-O files, got {:?}",
1739            file_types_counts.get("Mach-O")
1740        );
1741
1742        let malware_label_id = db
1743            .create_label(MALWARE_LABEL, None)
1744            .await
1745            .context("failed to create first label")?;
1746        let ransomware_label_id = db
1747            .create_label(RANSOMWARE_LABEL, Some(malware_label_id))
1748            .await
1749            .context("failed to create malware sub-label")?;
1750        let labels = db.get_labels().await.context("failed to get labels")?;
1751
1752        assert_eq!(labels.len(), 2);
1753        for label in labels.0 {
1754            if label.name == RANSOMWARE_LABEL {
1755                assert_eq!(label.id, ransomware_label_id);
1756                assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
1757            }
1758        }
1759
        // Use this file as a stand-in for an unknown file type
1761        let source_code = include_bytes!("mod.rs");
1762        let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
1763        let known_type =
1764            KnownType::new(source_code).context("failed to source code to get `Unknown` type")?;
1765
1766        assert!(matches!(known_type, KnownType::Unknown(_)));
1767
1768        let unknown_type: Vec<FileType> = db
1769            .get_known_data_types()
1770            .await?
1771            .into_iter()
1772            .filter(|t| t.name.eq_ignore_ascii_case("unknown"))
1773            .collect();
1774        let unknown_type_id = unknown_type.first().unwrap().id;
1775        assert!(db.get_type_id_for_bytes(source_code).await.is_err());
1776        db.enable_keep_unknown_files()
1777            .await
1778            .context("failed to enable keeping of unknown files")?;
1779        let source_type = db
1780            .get_type_id_for_bytes(source_code)
1781            .await
1782            .context("failed to type id for source code unknown type example")?;
1783        assert_eq!(source_type, unknown_type_id);
1784        eprintln!("Unknown file type ID: {source_type}");
1785        assert!(db
1786            .add_file(
1787                &source_meta,
1788                known_type,
1789                1,
1790                default_source_id,
1791                unknown_type_id,
1792                None
1793            )
1794            .await
1795            .context("failed to add Rust source code file")?);
1796        eprintln!("Added Rust source code to the DB");
1797
1798        db.reset_own_api_key(0)
1799            .await
1800            .context("failed to clear own API key uid 0")?;
1801
1802        db.deactivate_user(0)
1803            .await
1804            .context("failed to clear password and API key for uid 0")?;
1805
1806        Ok(())
1807    }
1808}