malwaredb_server/db/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Postgres is the database used by Malware DB.
//! `SQLite` is used for unit testing, and can also be enabled for regular builds with the `sqlite`
//! feature flag. When using `SQLite`, any similarity functions must be calculated by Malware DB itself.
6
7/// Malware DB Administrative functions
8#[cfg(any(test, feature = "admin"))]
9mod admin;
10/// Postgres functions
11mod pg;
12
13/// `SQLite` functions
14#[cfg(any(test, feature = "sqlite"))]
15mod sqlite;
16
17/// File Metadata convenience data structure
18pub mod types;
19
20#[cfg(any(test, feature = "admin"))]
21use crate::crypto::EncryptionOption;
22use crate::crypto::FileEncryption;
23use crate::db::pg::Postgres;
24#[cfg(any(test, feature = "sqlite"))]
25use crate::db::sqlite::Sqlite;
26use crate::db::types::{FileMetadata, FileType};
27use malwaredb_api::{digest::HashType, GetUserInfoResponse, Labels, SearchRequest, Sources};
28use malwaredb_types::KnownType;
29
30use std::collections::HashMap;
31use std::path::PathBuf;
32
33use anyhow::{bail, ensure, Result};
34use argon2::password_hash::{rand_core::OsRng, SaltString};
35use argon2::{Argon2, PasswordHasher};
36#[cfg(any(test, feature = "admin"))]
37use chrono::Local;
38#[cfg(feature = "vt")]
39use malwaredb_virustotal::filereport::ScanResultAttributes;
40
/// Upper bound on how many results a partial-hash and/or partial-file-name search may return,
/// so that these open-ended queries don't turn into performance problems
pub const PARTIAL_SEARCH_LIMIT: u32 = 100;
43
44/// Database handle for `SQLite` or Postgres
45#[derive(Debug)]
46pub enum DatabaseType {
47    /// Postgres database
48    Postgres(Postgres),
49
50    /// `SQLite` database
51    #[cfg(any(test, feature = "sqlite"))]
52    SQLite(Sqlite),
53}
54
/// Database version details together with a handful of summary counts
#[derive(Debug)]
pub struct DatabaseInformation {
    /// Version string of the database
    pub version: String,

    /// Size of the database, formatted for humans
    pub size: String,

    /// How many file samples Malware DB holds
    pub num_files: u64,

    /// How many user accounts exist
    pub num_users: u32,

    /// How many user groups exist
    pub num_groups: u32,

    /// How many sample sources exist
    pub num_sources: u32,
}
76
/// Instance-wide Malware DB settings which are persisted in the database
#[derive(Debug)]
pub struct MDBConfig {
    /// Display name of this Malware DB instance
    pub name: String,

    /// `true` when samples are stored compressed
    pub compression: bool,

    /// `true` when Malware DB is permitted to send samples to Virus Total
    pub send_samples_to_vt: bool,

    /// `true` when Malware DB should keep unknown files
    pub keep_unknown_files: bool,

    /// Which key to use when samples are to be encrypted
    // NOTE(review): `dead_code` is presumably silenced because this field is only read under
    // some feature combinations; confirm before removing the allow.
    #[allow(dead_code)]
    pub(crate) default_key: Option<u32>,
}
96
/// Summary counts of the Virus Total records kept for files in Malware DB
#[cfg(feature = "vt")]
#[derive(Debug, Clone, Copy)]
pub struct VtStats {
    /// Count of files whose VT record marks them as clean
    pub clean_records: u32,

    /// Count of files whose VT record marks them as malicious
    pub hits_records: u32,

    /// Count of files which don't have a VT record
    pub files_without_records: u32,
}
110
111impl DatabaseType {
112    /// Get a database connection from a configuration string
113    ///
114    /// # Errors
115    ///
116    /// * If there's a connectivity issue to Postgres, an error will result
117    /// * If the `SQLite` file cannot be created or opened, an error will result
118    /// * If `SQLite` is the type but Malware DB wasn't compiled with the sqlite feature, an error will result
119    /// * If the format or database type isn't known, an error will result
120    pub async fn from_string(arg: &str, server_ca: Option<PathBuf>) -> Result<Self> {
121        #[cfg(any(test, feature = "sqlite"))]
122        if arg.starts_with("file:") {
123            let new_conn_str = arg.trim_start_matches("file:");
124            return Ok(DatabaseType::SQLite(Sqlite::new(new_conn_str)?));
125        }
126
127        if arg.starts_with("postgres") {
128            let new_conn_str = arg.trim_start_matches("postgres");
129            return Ok(DatabaseType::Postgres(
130                Postgres::new(new_conn_str, server_ca).await?,
131            ));
132        }
133
134        bail!("unknown database type `{arg}`")
135    }
136
137    /// Set the flag allowing uploads to Virus Total.
138    ///
139    /// # Errors
140    ///
141    /// If there's a connectivity issue to Postgres, an error will result
142    #[cfg(feature = "vt")]
143    pub async fn enable_vt_upload(&self) -> Result<()> {
144        match self {
145            DatabaseType::Postgres(pg) => pg.enable_vt_upload().await,
146            #[cfg(any(test, feature = "sqlite"))]
147            DatabaseType::SQLite(sl) => sl.enable_vt_upload(),
148        }
149    }
150
151    /// Set the flag preventing uploads to Virus Total.
152    ///
153    /// # Errors
154    ///
155    /// If there's a connectivity issue to Postgres, an error will result
156    #[cfg(feature = "vt")]
157    pub async fn disable_vt_upload(&self) -> Result<()> {
158        match self {
159            DatabaseType::Postgres(pg) => pg.disable_vt_upload().await,
160            #[cfg(any(test, feature = "sqlite"))]
161            DatabaseType::SQLite(sl) => sl.disable_vt_upload(),
162        }
163    }
164
165    /// Get the SHA-256 hashes of the files which don't have VT records
166    ///
167    /// # Errors
168    ///
169    /// If there's a connectivity issue to Postgres, an error will result
170    #[cfg(feature = "vt")]
171    pub async fn files_without_vt_records(&self, limit: u32) -> Result<Vec<String>> {
172        match self {
173            DatabaseType::Postgres(pg) => pg.files_without_vt_records(limit).await,
174            #[cfg(any(test, feature = "sqlite"))]
175            DatabaseType::SQLite(sl) => sl.files_without_vt_records(limit),
176        }
177    }
178
179    /// Store the VT results: AV hits and detailed report, or lack of any AV hits
180    ///
181    /// # Errors
182    ///
183    /// If there's a connectivity issue to Postgres, an error will result
184    #[cfg(feature = "vt")]
185    pub async fn store_vt_record(&self, results: &ScanResultAttributes) -> Result<()> {
186        match self {
187            DatabaseType::Postgres(pg) => pg.store_vt_record(results).await,
188            #[cfg(any(test, feature = "sqlite"))]
189            DatabaseType::SQLite(sl) => sl.store_vt_record(results),
190        }
191    }
192
193    /// Quick statistics regarding the data contained for VT information for our samples
194    ///
195    /// # Errors
196    ///
197    /// If there's a connectivity issue to Postgres, an error will result
198    #[cfg(feature = "vt")]
199    pub async fn get_vt_stats(&self) -> Result<VtStats> {
200        match self {
201            DatabaseType::Postgres(pg) => pg.get_vt_stats().await,
202            #[cfg(any(test, feature = "sqlite"))]
203            DatabaseType::SQLite(sl) => sl.get_vt_stats(),
204        }
205    }
206
207    /// Get the configuration which is stored in the database
208    ///
209    /// # Errors
210    ///
211    /// If there's a connectivity issue to Postgres, an error will result
212    pub async fn get_config(&self) -> Result<MDBConfig> {
213        match self {
214            DatabaseType::Postgres(pg) => pg.get_config().await,
215            #[cfg(any(test, feature = "sqlite"))]
216            DatabaseType::SQLite(sl) => sl.get_config(),
217        }
218    }
219
220    /// Check user credentials, return API key. Generate if it doesn't exist.
221    ///
222    /// # Errors
223    ///
224    /// * If there's a connectivity issue to Postgres, an error will result
225    /// * If the user name and/or password aren't correct, an error will result
226    pub async fn authenticate(&self, uname: &str, password: &str) -> Result<String> {
227        match self {
228            DatabaseType::Postgres(pg) => pg.authenticate(uname, password).await,
229            #[cfg(any(test, feature = "sqlite"))]
230            DatabaseType::SQLite(sl) => sl.authenticate(uname, password),
231        }
232    }
233
234    /// Get the user's ID from their API key
235    ///
236    /// # Errors
237    ///
238    /// * If there's a connectivity issue to Postgres, an error will result
239    /// * If the api key isn't valid, an error will result
240    pub async fn get_uid(&self, apikey: &str) -> Result<u32> {
241        ensure!(!apikey.is_empty(), "API key was empty");
242        match self {
243            DatabaseType::Postgres(pg) => pg.get_uid(apikey).await,
244            #[cfg(any(test, feature = "sqlite"))]
245            DatabaseType::SQLite(sl) => sl.get_uid(apikey),
246        }
247    }
248
249    /// Retrieve information about the database
250    ///
251    /// # Errors
252    ///
253    /// * If there's a connectivity issue to Postgres, an error will result
254    pub async fn db_info(&self) -> Result<DatabaseInformation> {
255        match self {
256            DatabaseType::Postgres(pg) => pg.db_info().await,
257            #[cfg(any(test, feature = "sqlite"))]
258            DatabaseType::SQLite(sl) => sl.db_info(),
259        }
260    }
261
262    /// Retrieve the names of the groups and sources the user is part of and has access to
263    ///
264    /// # Errors
265    ///
266    /// * If there's a connectivity issue to Postgres, an error will result
267    /// * If the user ID isn't valid, an error will result
268    pub async fn get_user_info(&self, uid: u32) -> Result<GetUserInfoResponse> {
269        match self {
270            DatabaseType::Postgres(pg) => pg.get_user_info(uid).await,
271            #[cfg(any(test, feature = "sqlite"))]
272            DatabaseType::SQLite(sl) => sl.get_user_info(uid),
273        }
274    }
275
276    /// Retrieve the source information available to the specified user
277    ///
278    /// # Errors
279    ///
280    /// * If there's a connectivity issue to Postgres, an error will result
281    /// * If the user ID isn't valid, an error will result
282    pub async fn get_user_sources(&self, uid: u32) -> Result<Sources> {
283        match self {
284            DatabaseType::Postgres(pg) => pg.get_user_sources(uid).await,
285            #[cfg(any(test, feature = "sqlite"))]
286            DatabaseType::SQLite(sl) => sl.get_user_sources(uid),
287        }
288    }
289
290    /// Let the user clear their own API key to log out from all systems
291    ///
292    /// # Errors
293    ///
294    /// * If there's a connectivity issue to Postgres, an error will result
295    /// * If the user ID isn't valid, an error will result
296    pub async fn reset_own_api_key(&self, uid: u32) -> Result<()> {
297        match self {
298            DatabaseType::Postgres(pg) => pg.reset_own_api_key(uid).await,
299            #[cfg(any(test, feature = "sqlite"))]
300            DatabaseType::SQLite(sl) => sl.reset_own_api_key(uid),
301        }
302    }
303
304    /// Retrieve the supported data type information
305    ///
306    /// # Errors
307    ///
308    /// If there's a connectivity issue to Postgres, an error will result
309    pub async fn get_known_data_types(&self) -> Result<Vec<FileType>> {
310        match self {
311            DatabaseType::Postgres(pg) => pg.get_known_data_types().await,
312            #[cfg(any(test, feature = "sqlite"))]
313            DatabaseType::SQLite(sl) => sl.get_known_data_types(),
314        }
315    }
316
317    /// Get all labels from Malware DB
318    ///
319    /// # Errors
320    ///
321    /// If there's a connectivity issue to Postgres, an error will result
322    pub async fn get_labels(&self) -> Result<Labels> {
323        match self {
324            DatabaseType::Postgres(pg) => pg.get_labels().await,
325            #[cfg(any(test, feature = "sqlite"))]
326            DatabaseType::SQLite(sl) => sl.get_labels(),
327        }
328    }
329
330    /// Get the corresponding type ID for a buffer representing a file
331    ///
332    /// # Errors
333    ///
334    /// If there's a connectivity issue to Postgres, an error will result
335    pub async fn get_type_id_for_bytes(&self, data: &[u8]) -> Result<u32> {
336        match self {
337            DatabaseType::Postgres(pg) => pg.get_type_id_for_bytes(data).await,
338            #[cfg(any(test, feature = "sqlite"))]
339            DatabaseType::SQLite(sl) => sl.get_type_id_for_bytes(data),
340        }
341    }
342
343    /// Check that a user has been granted access data for the specific source
344    ///
345    /// # Errors
346    ///
347    /// * If there's a connectivity issue to Postgres, an error will result
348    /// * If the user or source ID(s) aren't valid, an error will result
349    pub async fn allowed_user_source(&self, uid: u32, sid: u32) -> Result<bool> {
350        match self {
351            DatabaseType::Postgres(pg) => pg.allowed_user_source(uid, sid).await,
352            #[cfg(any(test, feature = "sqlite"))]
353            DatabaseType::SQLite(sl) => sl.allowed_user_source(uid, sid),
354        }
355    }
356
357    /// Check to see if the user is an administrator. The user must be a member of the
358    /// admin group (group ID 0), or a one group below (a group with the parent group id of 0).
359    ///
360    /// # Errors
361    ///
362    /// * If there's a connectivity issue to Postgres, an error will result
363    /// * If the user ID isn't valid, an error will result
364    pub async fn user_is_admin(&self, uid: u32) -> Result<bool> {
365        match self {
366            DatabaseType::Postgres(pg) => pg.user_is_admin(uid).await,
367            #[cfg(any(test, feature = "sqlite"))]
368            DatabaseType::SQLite(sl) => sl.user_is_admin(uid),
369        }
370    }
371
372    /// Add a file's metadata to the database, returning true if this is a new entry
373    ///
374    /// # Errors
375    ///
376    /// * If there's a connectivity issue to Postgres, an error will result
377    /// * If the source doesn't exist or the user isn't part of a member group, an error will result
378    pub async fn add_file(
379        &self,
380        meta: &FileMetadata,
381        known_type: KnownType<'_>,
382        uid: u32,
383        sid: u32,
384        ftype: u32,
385        parent: Option<u64>,
386    ) -> Result<bool> {
387        match self {
388            DatabaseType::Postgres(pg) => {
389                pg.add_file(meta, known_type, uid, sid, ftype, parent).await
390            }
391            #[cfg(any(test, feature = "sqlite"))]
392            DatabaseType::SQLite(sl) => sl.add_file(meta, &known_type, uid, sid, ftype, parent),
393        }
394    }
395
396    /// Search for allowed samples based on partial search and/or file name
397    ///
398    /// # Errors
399    ///
400    /// * If there's a connectivity issue to Postgres, an error will result
401    pub async fn partial_search(&self, uid: u32, search: &SearchRequest) -> Result<Vec<String>> {
402        match self {
403            DatabaseType::Postgres(pg) => pg.partial_search(uid, search).await,
404            #[cfg(any(test, feature = "sqlite"))]
405            DatabaseType::SQLite(sl) => sl.partial_search(uid, search),
406        }
407    }
408
409    /// Retrieve the SHA-256 hash of the sample while checking that the user is permitted
410    /// to access to it
411    ///
412    /// # Errors
413    ///
414    /// * If there's a connectivity issue to Postgres, an error will result
415    /// * If the file doesn't exist or the user isn't allowed access, an error will result
416    pub async fn retrieve_sample(&self, uid: u32, hash: &HashType) -> Result<String> {
417        match self {
418            DatabaseType::Postgres(pg) => pg.retrieve_sample(uid, hash).await,
419            #[cfg(any(test, feature = "sqlite"))]
420            DatabaseType::SQLite(sl) => sl.retrieve_sample(uid, hash),
421        }
422    }
423
424    /// Retrieve a report for a given sample, if allowed.
425    ///
426    /// # Errors
427    ///
428    /// * If there's a connectivity issue to Postgres, an error will result
429    /// * If the file doesn't exist or the user isn't allowed access, an error will result
430    pub async fn get_sample_report(
431        &self,
432        uid: u32,
433        hash: &HashType,
434    ) -> Result<malwaredb_api::Report> {
435        match self {
436            DatabaseType::Postgres(pg) => pg.get_sample_report(uid, hash).await,
437            #[cfg(any(test, feature = "sqlite"))]
438            DatabaseType::SQLite(sl) => sl.get_sample_report(uid, hash),
439        }
440    }
441
442    /// Given a collection of similarity hashes, find samples which are similar.
443    ///
444    /// # Errors
445    ///
446    /// If there's a connectivity issue to Postgres, an error will result
447    pub async fn find_similar_samples(
448        &self,
449        uid: u32,
450        sim: &[(malwaredb_api::SimilarityHashType, String)],
451    ) -> Result<Vec<malwaredb_api::SimilarSample>> {
452        match self {
453            DatabaseType::Postgres(pg) => pg.find_similar_samples(uid, sim).await,
454            #[cfg(any(test, feature = "sqlite"))]
455            DatabaseType::SQLite(sl) => sl.find_similar_samples(uid, sim),
456        }
457    }
458
    // Crate-internal (`pub(crate)`) functions
460
461    /// Get the file encryption keys
462    pub(crate) async fn get_encryption_keys(&self) -> Result<HashMap<u32, FileEncryption>> {
463        match self {
464            DatabaseType::Postgres(pg) => pg.get_encryption_keys().await,
465            #[cfg(any(test, feature = "sqlite"))]
466            DatabaseType::SQLite(sl) => sl.get_encryption_keys(),
467        }
468    }
469
470    /// Get the key and AES nonce for a specific file identified by SHA-256 hash
471    pub(crate) async fn get_file_encryption_key_id(
472        &self,
473        hash: &str,
474    ) -> Result<(Option<u32>, Option<Vec<u8>>)> {
475        match self {
476            DatabaseType::Postgres(pg) => pg.get_file_encryption_key_id(hash).await,
477            #[cfg(any(test, feature = "sqlite"))]
478            DatabaseType::SQLite(sl) => sl.get_file_encryption_key_id(hash),
479        }
480    }
481
482    /// Set the AES nonce for a file specified by SHA-256 hash, a `None` value removes it.
483    pub(crate) async fn set_file_nonce(&self, hash: &str, nonce: Option<&[u8]>) -> Result<()> {
484        match self {
485            DatabaseType::Postgres(pg) => pg.set_file_nonce(hash, nonce).await,
486            #[cfg(any(test, feature = "sqlite"))]
487            DatabaseType::SQLite(sl) => sl.set_file_nonce(hash, nonce),
488        }
489    }
490
491    // Administrative functions
492
493    /// Set the instance name
494    ///
495    /// # Errors
496    ///
497    /// If there's a connectivity issue to Postgres, an error will result
498    #[cfg(any(test, feature = "admin"))]
499    pub async fn set_name(&self, name: &str) -> Result<()> {
500        match self {
501            DatabaseType::Postgres(pg) => pg.set_name(name).await,
502            #[cfg(any(test, feature = "sqlite"))]
503            DatabaseType::SQLite(sl) => sl.set_name(name),
504        }
505    }
506
507    /// Set the compression flag
508    ///
509    /// # Errors
510    ///
511    /// If there's a connectivity issue to Postgres, an error will result
512    #[cfg(any(test, feature = "admin"))]
513    pub async fn enable_compression(&self) -> Result<()> {
514        match self {
515            DatabaseType::Postgres(pg) => pg.enable_compression().await,
516            #[cfg(any(test, feature = "sqlite"))]
517            DatabaseType::SQLite(sl) => sl.enable_compression(),
518        }
519    }
520
521    /// Unset the compression flag, does not go and decompress files already compressed!
522    ///
523    /// # Errors
524    ///
525    /// If there's a connectivity issue to Postgres, an error will result
526    #[cfg(any(test, feature = "admin"))]
527    pub async fn disable_compression(&self) -> Result<()> {
528        match self {
529            DatabaseType::Postgres(pg) => pg.disable_compression().await,
530            #[cfg(any(test, feature = "sqlite"))]
531            DatabaseType::SQLite(sl) => sl.disable_compression(),
532        }
533    }
534
535    /// Set the keep unknown files flag
536    ///
537    /// # Errors
538    ///
539    /// If there's a connectivity issue to Postgres, an error will result
540    #[cfg(any(test, feature = "admin"))]
541    pub async fn enable_keep_unknown_files(&self) -> Result<()> {
542        match self {
543            DatabaseType::Postgres(pg) => pg.enable_keep_unknown_files().await,
544            #[cfg(any(test, feature = "sqlite"))]
545            DatabaseType::SQLite(sl) => sl.enable_keep_unknown_files(),
546        }
547    }
548
549    /// Unset the keep unknown files flag, does not go and remove unknown files!
550    ///
551    /// # Errors
552    ///
553    /// If there's a connectivity issue to Postgres, an error will result
554    #[cfg(any(test, feature = "admin"))]
555    pub async fn disable_keep_unknown_files(&self) -> Result<()> {
556        match self {
557            DatabaseType::Postgres(pg) => pg.disable_keep_unknown_files().await,
558            #[cfg(any(test, feature = "sqlite"))]
559            DatabaseType::SQLite(sl) => sl.disable_keep_unknown_files(),
560        }
561    }
562
563    /// Add an encryption key to the database, set it as the default, and return the key ID
564    ///
565    /// # Errors
566    ///
567    /// * If there is a connectivity with Postgres, an error will result
568    #[cfg(any(test, feature = "admin"))]
569    pub async fn add_file_encryption_key(&self, key: &FileEncryption) -> Result<u32> {
570        match self {
571            DatabaseType::Postgres(pg) => pg.add_file_encryption_key(key).await,
572            #[cfg(any(test, feature = "sqlite"))]
573            DatabaseType::SQLite(sl) => sl.add_file_encryption_key(key),
574        }
575    }
576
577    /// Get the key ID and algorithm names
578    ///
579    /// # Errors
580    ///
581    /// If there is a connectivity with Postgres, an error will result
582    #[cfg(any(test, feature = "admin"))]
583    pub async fn get_encryption_key_names_ids(&self) -> Result<Vec<(u32, EncryptionOption)>> {
584        match self {
585            DatabaseType::Postgres(pg) => pg.get_encryption_key_names_ids().await,
586            #[cfg(any(test, feature = "sqlite"))]
587            DatabaseType::SQLite(sl) => sl.get_encryption_key_names_ids(),
588        }
589    }
590
591    /// Create a user account, return the user ID.
592    ///
593    /// # Errors
594    ///
595    /// If there's a connectivity issue to Postgres, an error will result
596    #[allow(clippy::too_many_arguments)]
597    #[cfg(any(test, feature = "admin"))]
598    pub async fn create_user(
599        &self,
600        uname: &str,
601        fname: &str,
602        lname: &str,
603        email: &str,
604        password: Option<String>,
605        organisation: Option<&String>,
606        readonly: bool,
607    ) -> Result<u32> {
608        match self {
609            DatabaseType::Postgres(pg) => {
610                pg.create_user(uname, fname, lname, email, password, organisation, readonly)
611                    .await
612            }
613            #[cfg(any(test, feature = "sqlite"))]
614            DatabaseType::SQLite(sl) => {
615                sl.create_user(uname, fname, lname, email, password, organisation, readonly)
616            }
617        }
618    }
619
620    /// Clear all API keys, either in case of suspected activity, or part of policy
621    ///
622    /// # Errors
623    ///
624    /// If there's a connectivity issue to Postgres, an error will result
625    #[cfg(any(test, feature = "admin"))]
626    pub async fn reset_api_keys(&self) -> Result<u64> {
627        match self {
628            DatabaseType::Postgres(pg) => pg.reset_api_keys().await,
629            #[cfg(any(test, feature = "sqlite"))]
630            DatabaseType::SQLite(sl) => sl.reset_api_keys(),
631        }
632    }
633
634    /// Set a user's password
635    ///
636    /// # Errors
637    ///
638    /// * If there's a connectivity issue to Postgres, an error will result
639    /// * If the user doesn't exist, an error will result
640    #[cfg(any(test, feature = "admin"))]
641    pub async fn set_password(&self, uname: &str, password: &str) -> Result<()> {
642        match self {
643            DatabaseType::Postgres(pg) => pg.set_password(uname, password).await,
644            #[cfg(any(test, feature = "sqlite"))]
645            DatabaseType::SQLite(sl) => sl.set_password(uname, password),
646        }
647    }
648
649    /// Get the complete list of users
650    ///
651    /// # Errors
652    ///
653    /// If there's a connectivity issue to Postgres, an error will result
654    #[cfg(any(test, feature = "admin"))]
655    pub async fn list_users(&self) -> Result<Vec<admin::User>> {
656        match self {
657            DatabaseType::Postgres(pg) => pg.list_users().await,
658            #[cfg(any(test, feature = "sqlite"))]
659            DatabaseType::SQLite(sl) => sl.list_users(),
660        }
661    }
662
663    /// Get the ID of a group from name
664    ///
665    /// # Errors
666    ///
667    /// * If there's a connectivity issue to Postgres, an error will result
668    /// * If the group name isn't valid, an error will result
669    #[cfg(any(test, feature = "admin"))]
670    pub async fn group_id_from_name(&self, name: &str) -> Result<i32> {
671        match self {
672            DatabaseType::Postgres(pg) => pg.group_id_from_name(name).await,
673            #[cfg(any(test, feature = "sqlite"))]
674            DatabaseType::SQLite(sl) => sl.group_id_from_name(name),
675        }
676    }
677
678    /// Update the record for a group
679    ///
680    /// # Errors
681    ///
682    /// * If there's a connectivity issue to Postgres, an error will result
683    /// * If the group doesn't exist, an error will result
684    #[cfg(any(test, feature = "admin"))]
685    pub async fn edit_group(
686        &self,
687        gid: u32,
688        name: &str,
689        desc: &str,
690        parent: Option<u32>,
691    ) -> Result<()> {
692        match self {
693            DatabaseType::Postgres(pg) => pg.edit_group(gid, name, desc, parent).await,
694            #[cfg(any(test, feature = "sqlite"))]
695            DatabaseType::SQLite(sl) => sl.edit_group(gid, name, desc, parent),
696        }
697    }
698
699    /// Get the complete list of groups
700    ///
701    /// # Errors
702    ///
703    /// If there's a connectivity issue to Postgres, an error will result
704    #[cfg(any(test, feature = "admin"))]
705    pub async fn list_groups(&self) -> Result<Vec<admin::Group>> {
706        match self {
707            DatabaseType::Postgres(pg) => pg.list_groups().await,
708            #[cfg(any(test, feature = "sqlite"))]
709            DatabaseType::SQLite(sl) => sl.list_groups(),
710        }
711    }
712
713    /// Grant a user membership to a group, both by id.
714    ///
715    /// # Errors
716    ///
717    /// * If there's a connectivity issue to Postgres, an error will result
718    /// * If the user or group doesn't exist, an error will result
719    #[cfg(any(test, feature = "admin"))]
720    pub async fn add_user_to_group(&self, uid: u32, gid: u32) -> Result<()> {
721        match self {
722            DatabaseType::Postgres(pg) => pg.add_user_to_group(uid, gid).await,
723            #[cfg(any(test, feature = "sqlite"))]
724            DatabaseType::SQLite(sl) => sl.add_user_to_group(uid, gid),
725        }
726    }
727
728    /// Grand a group access to a source, both by id.
729    ///
730    /// # Errors
731    ///
732    /// * If there's a connectivity issue to Postgres, an error will result
733    /// * If the group or source doesn't exist, an error will result
734    #[cfg(any(test, feature = "admin"))]
735    pub async fn add_group_to_source(&self, gid: u32, sid: u32) -> Result<()> {
736        match self {
737            DatabaseType::Postgres(pg) => pg.add_group_to_source(gid, sid).await,
738            #[cfg(any(test, feature = "sqlite"))]
739            DatabaseType::SQLite(sl) => sl.add_group_to_source(gid, sid),
740        }
741    }
742
743    /// Create a new group, returning the group ID
744    ///
745    /// # Errors
746    ///
747    /// * If there's a connectivity issue to Postgres, an error will result
748    /// * If the group name is already taken, an error will result
749    #[cfg(any(test, feature = "admin"))]
750    pub async fn create_group(
751        &self,
752        name: &str,
753        description: &str,
754        parent: Option<u32>,
755    ) -> Result<u32> {
756        match self {
757            DatabaseType::Postgres(pg) => pg.create_group(name, description, parent).await,
758            #[cfg(any(test, feature = "sqlite"))]
759            DatabaseType::SQLite(sl) => sl.create_group(name, description, parent),
760        }
761    }
762
763    /// Get the complete list of sources
764    ///
765    /// # Errors
766    ///
767    /// If there's a connectivity issue to Postgres, an error will result
768    #[cfg(any(test, feature = "admin"))]
769    pub async fn list_sources(&self) -> Result<Vec<admin::Source>> {
770        match self {
771            DatabaseType::Postgres(pg) => pg.list_sources().await,
772            #[cfg(any(test, feature = "sqlite"))]
773            DatabaseType::SQLite(sl) => sl.list_sources(),
774        }
775    }
776
777    /// Create a source, returning the source ID
778    ///
779    /// # Errors
780    ///
781    /// * If there's a connectivity issue to Postgres, an error will result
782    /// * If the source already exists, an error will result
783    #[cfg(any(test, feature = "admin"))]
784    pub async fn create_source(
785        &self,
786        name: &str,
787        description: Option<&str>,
788        url: Option<&str>,
789        date: chrono::DateTime<Local>,
790        releasable: bool,
791        malicious: Option<bool>,
792    ) -> Result<u32> {
793        match self {
794            DatabaseType::Postgres(pg) => {
795                pg.create_source(name, description, url, date, releasable, malicious)
796                    .await
797            }
798            #[cfg(any(test, feature = "sqlite"))]
799            DatabaseType::SQLite(sl) => {
800                sl.create_source(name, description, url, date, releasable, malicious)
801            }
802        }
803    }
804
805    /// Edit a user account setting the specified field values, primarily used by the Admin gui
806    ///
807    /// # Errors
808    ///
809    /// * If there's a connectivity issue to Postgres, an error will result
810    /// * If the user isn't valid, an error will result
811    #[cfg(any(test, feature = "admin"))]
812    pub async fn edit_user(
813        &self,
814        uid: u32,
815        uname: &str,
816        fname: &str,
817        lname: &str,
818        email: &str,
819        readonly: bool,
820    ) -> Result<()> {
821        match self {
822            DatabaseType::Postgres(pg) => {
823                pg.edit_user(uid, uname, fname, lname, email, readonly)
824                    .await
825            }
826            #[cfg(any(test, feature = "sqlite"))]
827            DatabaseType::SQLite(sl) => sl.edit_user(uid, uname, fname, lname, email, readonly),
828        }
829    }
830
831    /// Deactivate, but don't delete, a user's account
832    ///
833    /// # Errors
834    ///
835    /// * If there's a connectivity issue to Postgres, an error will result
836    /// * If the user ID isn't valid, an error will result
837    #[cfg(any(test, feature = "admin"))]
838    pub async fn deactivate_user(&self, uid: u32) -> Result<()> {
839        match self {
840            DatabaseType::Postgres(pg) => pg.deactivate_user(uid).await,
841            #[cfg(any(test, feature = "sqlite"))]
842            DatabaseType::SQLite(sl) => sl.deactivate_user(uid),
843        }
844    }
845
846    /// File types and number of files per type
847    ///
848    /// # Errors
849    ///
850    /// * If there's a connectivity issue to Postgres, an error will result
851    #[cfg(any(test, feature = "admin"))]
852    pub async fn file_types_counts(&self) -> Result<HashMap<String, u32>> {
853        match self {
854            DatabaseType::Postgres(pg) => pg.file_types_counts().await,
855            #[cfg(any(test, feature = "sqlite"))]
856            DatabaseType::SQLite(sl) => sl.file_types_counts(),
857        }
858    }
859
860    /// Create a new label
861    ///
862    /// # Errors
863    ///
864    /// * If there's a connectivity issue to Postgres, an error will result
865    /// * If the parent label doesn't exist, an error will result
866    /// * If the label name already exists, an error will result
867    #[cfg(any(test, feature = "admin"))]
868    pub async fn create_label(&self, name: &str, parent: Option<u64>) -> Result<u64> {
869        match self {
870            DatabaseType::Postgres(pg) => pg.create_label(name, parent).await,
871            #[cfg(any(test, feature = "sqlite"))]
872            DatabaseType::SQLite(sl) => sl.create_label(name, parent),
873        }
874    }
875
876    /// Edit a label name or parent
877    ///
878    /// # Errors
879    ///
880    /// * If there's a connectivity issue to Postgres, an error will result
881    /// * If the parent label doesn't exist, an error will result
882    /// * If the label name already exists, an error will result
883    #[cfg(any(test, feature = "admin"))]
884    pub async fn edit_label(&self, id: u64, name: &str, parent: Option<u64>) -> Result<()> {
885        match self {
886            DatabaseType::Postgres(pg) => pg.edit_label(id, name, parent).await,
887            #[cfg(any(test, feature = "sqlite"))]
888            DatabaseType::SQLite(sl) => sl.edit_label(id, name, parent),
889        }
890    }
891
892    /// Return the ID for a given label name
893    ///
894    /// # Errors
895    ///
896    /// * If there's a connectivity issue to Postgres, an error will result
897    /// * If the label ID is invalid, an error will result
898    #[cfg(any(test, feature = "admin"))]
899    pub async fn label_id_from_name(&self, name: &str) -> Result<u64> {
900        match self {
901            DatabaseType::Postgres(pg) => pg.label_id_from_name(name).await,
902            #[cfg(any(test, feature = "sqlite"))]
903            DatabaseType::SQLite(sl) => sl.label_id_from_name(name),
904        }
905    }
906}
907
908/// Hash a password string with [Argon2]
909///
910/// # Errors
911///
912/// An error may result if Argon has an error
913pub fn hash_password(password: &str) -> Result<String> {
914    let salt = SaltString::generate(&mut OsRng);
915    let argon2 = Argon2::default();
916    Ok(argon2
917        .hash_password(password.as_bytes(), &salt)?
918        .to_string())
919}
920
921/// Generate a new, random API key
922#[must_use]
923pub fn random_bytes_api_key() -> String {
924    let key1 = uuid::Uuid::new_v4();
925    let key2 = uuid::Uuid::new_v4();
926    let key1 = key1.to_string().replace('-', "");
927    let key2 = key2.to_string().replace('-', "");
928    format!("{key1}{key2}")
929}
930
931#[cfg(test)]
932mod tests {
933    use super::*;
934    #[cfg(feature = "vt")]
935    use crate::vt::VtUpdater;
936
937    use std::fs;
938    #[cfg(feature = "vt")]
939    use std::time::SystemTime;
940
941    use anyhow::Context;
942    use fuzzyhash::FuzzyHash;
943    use malwaredb_api::PartialHashSearchType;
944    use malwaredb_lzjd::{LZDict, Murmur3HashState};
945    use tlsh_fixed::TlshBuilder;
946    use uuid::Uuid;
947
948    const MALWARE_LABEL: &str = "malware";
949    const RANSOMWARE_LABEL: &str = "ransomware";
950
951    fn generate_similarity_request(data: &[u8]) -> malwaredb_api::SimilarSamplesRequest {
952        let mut hashes = vec![];
953
954        hashes.push((
955            malwaredb_api::SimilarityHashType::SSDeep,
956            FuzzyHash::new(data).to_string(),
957        ));
958
959        let mut builder = TlshBuilder::new(
960            tlsh_fixed::BucketKind::Bucket256,
961            tlsh_fixed::ChecksumKind::ThreeByte,
962            tlsh_fixed::Version::Version4,
963        );
964
965        builder.update(data);
966
967        if let Ok(hasher) = builder.build() {
968            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
969        }
970
971        let build_hasher = Murmur3HashState::default();
972        let lzjd_str = LZDict::from_bytes_stream(data.iter().copied(), &build_hasher).to_string();
973        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
974
975        malwaredb_api::SimilarSamplesRequest { hashes }
976    }
977
978    async fn pg_config() -> Postgres {
979        // create user malwaredbtesting with password 'malwaredbtesting';
980        // create database malwaredbtesting owner malwaredbtesting;
981        const CONNECTION_STRING: &str =
982            "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
983
984        if let Ok(pg_port) = std::env::var("PG_PORT") {
985            // Get the port number to run in Github CI
986            let conn_string = format!("{CONNECTION_STRING} port={pg_port}");
987            Postgres::new(&conn_string, None)
988                .await
989                .context(format!(
990                    "failed to connect to postgres with specified port {pg_port}"
991                ))
992                .unwrap()
993        } else {
994            Postgres::new(CONNECTION_STRING, None).await.unwrap()
995        }
996    }
997
998    #[tokio::test]
999    #[ignore = "don't run this in CI"]
1000    async fn pg() {
1001        let psql = pg_config().await;
1002        psql.delete_init().await.unwrap();
1003
1004        let db = DatabaseType::Postgres(psql);
1005        let key = FileEncryption::from(EncryptionOption::Xor);
1006        db.add_file_encryption_key(&key).await.unwrap();
1007        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
1008        everything(&db).await.unwrap();
1009
1010        #[cfg(feature = "vt")]
1011        {
1012            let db_config = db.get_config().await.unwrap();
1013            let state = crate::State {
1014                port: 8080,
1015                directory: None,
1016                max_upload: 10 * 1024 * 1024,
1017                ip: "127.0.0.1".parse().unwrap(),
1018                db_type: db,
1019                db_config,
1020                keys: HashMap::new(),
1021                started: SystemTime::now(),
1022                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1023                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
1024                }),
1025                cert: None,
1026                key: None,
1027            };
1028
1029            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1030
1031            vt.updater().await.unwrap();
1032            println!("PG: Did VT ops!");
1033
1034            let psql = pg_config().await;
1035
1036            let vt_stats = psql
1037                .get_vt_stats()
1038                .await
1039                .context("failed to get Postgres VT Stats")
1040                .unwrap();
1041            println!("{vt_stats:?}");
1042            assert!(
1043                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1044            );
1045        }
1046
1047        // Re-create the Postgres object so we can do some clean-up
1048        let psql = pg_config().await;
1049        psql.delete_init().await.unwrap();
1050    }
1051
1052    #[tokio::test]
1053    async fn sqlite() {
1054        const DB_FILE: &str = "testing_sqlite.db";
1055        if std::path::Path::new(DB_FILE).exists() {
1056            fs::remove_file(DB_FILE)
1057                .context(format!("failed to delete old SQLite file {DB_FILE}"))
1058                .unwrap();
1059        }
1060
1061        let sqlite = Sqlite::new(DB_FILE)
1062            .context(format!("failed to create SQLite instance for {DB_FILE}"))
1063            .unwrap();
1064
1065        let db = DatabaseType::SQLite(sqlite);
1066        let key = FileEncryption::from(EncryptionOption::Xor);
1067        db.add_file_encryption_key(&key).await.unwrap();
1068        assert_eq!(db.get_encryption_keys().await.unwrap().len(), 1);
1069        everything(&db).await.unwrap();
1070
1071        #[cfg(feature = "vt")]
1072        {
1073            let db_config = db.get_config().await.unwrap();
1074            let state = crate::State {
1075                port: 8080,
1076                directory: None,
1077                max_upload: 10 * 1024 * 1024,
1078                ip: "127.0.0.1".parse().unwrap(),
1079                db_type: db,
1080                db_config,
1081                keys: HashMap::new(),
1082                started: SystemTime::now(),
1083                vt_client: std::env::var("VT_API_KEY").map_or(None, |e| {
1084                    Some(malwaredb_virustotal::VirusTotalClient::new(e))
1085                }),
1086                cert: None,
1087                key: None,
1088            };
1089
1090            let sqlite_second = Sqlite::new(DB_FILE)
1091                .context(format!("failed to create SQLite instance for {DB_FILE}"))
1092                .unwrap();
1093
1094            let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
1095
1096            vt.updater().await.unwrap();
1097            println!("Sqlite: Did VT ops!");
1098            let vt_stats = sqlite_second
1099                .get_vt_stats()
1100                .context("failed to get Sqlite VT Stats")
1101                .unwrap();
1102            println!("{vt_stats:?}");
1103            assert!(
1104                vt_stats.files_without_records + vt_stats.clean_records + vt_stats.hits_records > 2
1105            );
1106        }
1107
1108        fs::remove_file(DB_FILE)
1109            .context(format!("failed to delete SQLite file {DB_FILE}"))
1110            .unwrap();
1111    }
1112
1113    #[allow(clippy::too_many_lines)]
1114    async fn everything(db: &DatabaseType) -> Result<()> {
1115        const ADMIN_UNAME: &str = "admin";
1116        const ADMIN_PASSWORD: &str = "super_secure_password_dont_tell_anyone!";
1117
1118        db.set_name("Testing Database")
1119            .await
1120            .context("setting instance name failed")?;
1121
1122        assert!(
1123            db.authenticate(ADMIN_UNAME, ADMIN_PASSWORD).await.is_err(),
1124            "Authentication without password should have failed."
1125        );
1126
1127        db.set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1128            .await
1129            .context("failed to set admin password")?;
1130
1131        let admin_api_key = db
1132            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1133            .await
1134            .context("unable to get api key for admin")?;
1135        println!("API key: {admin_api_key}");
1136        assert_eq!(admin_api_key.len(), 64);
1137
1138        assert_eq!(
1139            db.get_uid(&admin_api_key).await?,
1140            0,
1141            "Unable to get UID given the API key"
1142        );
1143
1144        let admin_api_key_again = db
1145            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1146            .await
1147            .context("unable to get api key a second time for admin")?;
1148
1149        assert_eq!(
1150            admin_api_key, admin_api_key_again,
1151            "API keys didn't match the second time."
1152        );
1153
1154        let bad_password = "this_is_totally_not_my_password!!";
1155        eprintln!("Testing API login with incorrect password.");
1156        assert!(
1157            db.authenticate(ADMIN_UNAME, bad_password).await.is_err(),
1158            "Authenticating as admin with a bad password should have failed."
1159        );
1160
1161        let admin_is_admin = db
1162            .user_is_admin(0)
1163            .await
1164            .context("unable to see if admin (uid 0) is an admin")?;
1165        assert!(admin_is_admin);
1166
1167        let new_user_uname = "testuser";
1168        let new_user_email = "test@example.com";
1169        let new_user_password = "some_awesome_password_++";
1170        let new_id = db
1171            .create_user(
1172                new_user_uname,
1173                new_user_uname,
1174                new_user_uname,
1175                new_user_email,
1176                Some(new_user_password.into()),
1177                None,
1178                false,
1179            )
1180            .await
1181            .context(format!("failed to create user {new_user_uname}"))?;
1182
1183        let passwordless_user_id = db
1184            .create_user(
1185                "passwordless_user",
1186                "passwordless_user",
1187                "passwordless_user",
1188                "passwordless_user@example.com",
1189                None,
1190                None,
1191                false,
1192            )
1193            .await
1194            .context("failed to create passwordless_user")?;
1195
1196        for user in &db.list_users().await.context("failed to list users")? {
1197            if user.id == passwordless_user_id {
1198                assert_eq!(user.uname, "passwordless_user");
1199            }
1200        }
1201
1202        db.edit_user(
1203            passwordless_user_id,
1204            "passwordless_user_2",
1205            "passwordless_user_2",
1206            "passwordless_user_2",
1207            "passwordless_user_2@something.com",
1208            false,
1209        )
1210        .await
1211        .context(format!(
1212            "failed to alter 'passwordless' user, id {passwordless_user_id}"
1213        ))?;
1214
1215        for user in &db.list_users().await.context("failed to list users")? {
1216            if user.id == passwordless_user_id {
1217                assert_eq!(user.uname, "passwordless_user_2");
1218            }
1219        }
1220
1221        assert!(
1222            new_id > 0,
1223            "Weird UID created for user {new_user_uname}: {new_id}"
1224        );
1225
1226        assert!(
1227            db.create_user(
1228                new_user_uname,
1229                new_user_uname,
1230                new_user_uname,
1231                new_user_email,
1232                Some(new_user_password.into()),
1233                None,
1234                false
1235            )
1236            .await
1237            .is_err(),
1238            "Creating a new user with the same user name should fail"
1239        );
1240
1241        let ro_user_name = "ro_user";
1242        let ro_user_password = "ro_user_password";
1243        db.create_user(
1244            ro_user_name,
1245            "ro_user",
1246            "ro_user",
1247            "ro@example.com",
1248            Some(ro_user_password.into()),
1249            None,
1250            true,
1251        )
1252        .await
1253        .context("failed to create read-only user")?;
1254
1255        let ro_user_api_key = db
1256            .authenticate(ro_user_name, ro_user_password)
1257            .await
1258            .context("unable to get api key for read-only user")?;
1259
1260        let new_user_password_change = "some_new_awesomer_password!_++";
1261        db.set_password(new_user_uname, new_user_password_change)
1262            .await
1263            .context("failed to change the password for testuser")?;
1264
1265        let new_user_api_key = db
1266            .authenticate(new_user_uname, new_user_password_change)
1267            .await
1268            .context("unable to get api key for testuser")?;
1269        eprintln!("{new_user_uname} got API key {new_user_api_key}");
1270
1271        assert_eq!(admin_api_key.len(), new_user_api_key.len());
1272
1273        let users = db.list_users().await.context("failed to list users")?;
1274        assert_eq!(
1275            users.len(),
1276            4,
1277            "Four users were created, yet there are {} users",
1278            users.len()
1279        );
1280        eprintln!("DB has {} users:", users.len());
1281        let mut passwordless_user_found = false;
1282        for user in users {
1283            println!("{user}");
1284            if user.uname == "passwordless_user_2" {
1285                assert!(!user.has_api_key);
1286                assert!(!user.has_password);
1287                passwordless_user_found = true;
1288            } else {
1289                assert!(user.has_api_key);
1290                assert!(user.has_password);
1291            }
1292        }
1293        assert!(passwordless_user_found);
1294
1295        let new_group_name = "some_new_group";
1296        let new_group_desc = "some_new_group_description";
1297        let new_group_id = 1;
1298        assert_eq!(
1299            db.create_group(new_group_name, new_group_desc, None)
1300                .await
1301                .context("failed to create group")?,
1302            new_group_id,
1303            "New group didn't have the expected ID, expected {new_group_id}"
1304        );
1305
1306        assert!(
1307            db.create_group(new_group_name, new_group_desc, None)
1308                .await
1309                .is_err(),
1310            "Duplicate group name should have failed"
1311        );
1312
1313        db.add_user_to_group(1, 1)
1314            .await
1315            .context("Unable to add uid 1 to gid 1")?;
1316
1317        let ro_user_uid = db
1318            .get_uid(&ro_user_api_key)
1319            .await
1320            .context("Unable to get UID for read-only user")?;
1321        db.add_user_to_group(ro_user_uid, 1)
1322            .await
1323            .context("Unable to add uid 2 to gid 1")?;
1324
1325        let new_admin_group_name = "admin_subgroup";
1326        let new_admin_group_desc = "admin_subgroup_description";
1327        let new_admin_group_id = 2;
1328        // TODO: Figure out why SQLite makes the group_id = 2, but with Postgres it's 3.
1329        assert!(
1330            db.create_group(new_admin_group_name, new_admin_group_desc, Some(0))
1331                .await
1332                .context("failed to create admin sub-group")?
1333                >= new_admin_group_id,
1334            "New group didn't have the expected ID, expected >= {new_admin_group_id}"
1335        );
1336
1337        let groups = db.list_groups().await.context("failed to list groups")?;
1338        assert_eq!(
1339            groups.len(),
1340            3,
1341            "Three groups were created, yet there are {} groups",
1342            groups.len()
1343        );
1344        eprintln!("DB has {} groups:", groups.len());
1345        for group in groups {
1346            println!("{group}");
1347            if group.id == new_admin_group_id {
1348                assert_eq!(group.parent, Some("admin".to_string()));
1349            }
1350            if group.id == 1 {
1351                let test_user_str = String::from(new_user_uname);
1352                let mut found = false;
1353                for member in group.members {
1354                    if member.uname == test_user_str {
1355                        found = true;
1356                        break;
1357                    }
1358                }
1359                assert!(found, "new user {test_user_str} wasn't in the group");
1360            }
1361        }
1362
1363        let default_source_name = "default_source".to_string();
1364        let default_source_id = db
1365            .create_source(
1366                &default_source_name,
1367                Some("desc_default_source"),
1368                None,
1369                Local::now(),
1370                true,
1371                Some(false),
1372            )
1373            .await
1374            .context("failed to create source `default_source`")?;
1375
1376        db.add_group_to_source(1, default_source_id)
1377            .await
1378            .context("failed to add group 1 to source 1")?;
1379
1380        let another_source_name = "another_source".to_string();
1381        let another_source_id = db
1382            .create_source(
1383                &another_source_name,
1384                Some("yet another file source"),
1385                None,
1386                Local::now(),
1387                true,
1388                Some(false),
1389            )
1390            .await
1391            .context("failed to create source `another_source`")?;
1392
1393        let empty_source_name = "empty_source".to_string();
1394        db.create_source(
1395            &empty_source_name,
1396            Some("empty and unused file source"),
1397            None,
1398            Local::now(),
1399            true,
1400            Some(false),
1401        )
1402        .await
1403        .context("failed to create source `another_source`")?;
1404
1405        db.add_group_to_source(1, another_source_id)
1406            .await
1407            .context("failed to add group 1 to source 1")?;
1408
1409        let sources = db.list_sources().await.context("failed to list sources")?;
1410        eprintln!("DB has {} sources:", sources.len());
1411        for source in sources {
1412            println!("{source}");
1413            assert_eq!(source.files, 0);
1414            if source.id == default_source_id || source.id == another_source_id {
1415                assert_eq!(
1416                    source.groups, 1,
1417                    "default source {default_source_name} should have 1 group"
1418                );
1419            } else {
1420                assert_eq!(source.groups, 0, "groups should zero (empty)");
1421            }
1422        }
1423
1424        let uid = db
1425            .get_uid(&new_user_api_key)
1426            .await
1427            .context("failed to user uid from apikey")?;
1428        let user_info = db
1429            .get_user_info(uid)
1430            .await
1431            .context("failed to get user's available groups and sources")?;
1432        assert!(user_info.sources.contains(&default_source_name));
1433        assert!(!user_info.is_admin);
1434        println!("UserInfoResponse: {user_info:?}");
1435
1436        assert!(
1437            db.allowed_user_source(1, default_source_id)
1438                .await
1439                .context(format!(
1440                    "failed to check that user 1 has access to source {default_source_id}"
1441                ))?,
1442            "User 1 should should have had access to source {default_source_id}"
1443        );
1444
1445        assert!(
1446            !db.allowed_user_source(1, 5)
1447                .await
1448                .context("failed to check that user 1 has access to source 5")?,
1449            "User 1 should should not have had access to source 5"
1450        );
1451
1452        let test_elf = include_bytes!("../../../types/testdata/elf/elf_linux_ppc64le").to_vec();
1453        let test_elf_meta = FileMetadata::new(&test_elf, Some("elf_linux_ppc64le"));
1454        let elf_type = db.get_type_id_for_bytes(&test_elf).await.unwrap();
1455
1456        let known_type =
1457            KnownType::new(&test_elf).context("failed to parse elf from test crate's test data")?;
1458        assert!(known_type.is_exec(), "ELF should be executable");
1459        eprintln!("ELF type ID: {elf_type}");
1460
1461        assert!(db
1462            .add_file(
1463                &test_elf_meta,
1464                known_type.clone(),
1465                1,
1466                default_source_id,
1467                elf_type,
1468                None
1469            )
1470            .await
1471            .context("failed to insert a test elf")?);
1472        eprintln!("Added ELF to the DB");
1473
1474        let partial_search = SearchRequest {
1475            partial_hash: Some((PartialHashSearchType::SHA1, "fe7d0186".into())),
1476            ..Default::default()
1477        };
1478        let partial_search_response = db.partial_search(1, &partial_search).await?;
1479        assert_eq!(partial_search_response.len(), 1);
1480        assert_eq!(
1481            partial_search_response[0],
1482            "897541f9f3c673b3ecc7004ff52c70c0b0440e804c7c3eb4854d72d94c317868"
1483        );
1484
1485        let partial_search = SearchRequest {
1486            file_name: Some("ppc64".into()),
1487            ..Default::default()
1488        };
1489        let partial_search_response = db.partial_search(1, &partial_search).await?;
1490        assert_eq!(partial_search_response.len(), 1);
1491
1492        let partial_search = SearchRequest::default();
1493        let partial_search_response = db.partial_search(1, &partial_search).await?;
1494        assert!(partial_search_response.is_empty());
1495
1496        assert!(
1497            db.add_file(
1498                &test_elf_meta,
1499                known_type.clone(),
1500                ro_user_uid,
1501                default_source_id,
1502                elf_type,
1503                None
1504            )
1505            .await
1506            .is_err(),
1507            "Read-only user should not be able to add a file"
1508        );
1509
1510        let mut test_elf_meta_different_name = test_elf_meta.clone();
1511        test_elf_meta_different_name.name = Some("completely_different_name.bin".into());
1512
1513        assert!(!db
1514            .add_file(
1515                &test_elf_meta_different_name,
1516                known_type,
1517                1,
1518                another_source_id,
1519                elf_type,
1520                None
1521            )
1522            .await
1523            .context("failed to insert a test elf again for a different source")?);
1524
1525        let sources = db
1526            .list_sources()
1527            .await
1528            .context("failed to re-list sources")?;
1529        eprintln!(
1530            "DB has {} sources, and a file was added twice:",
1531            sources.len()
1532        );
1533        println!("We should have two sources with one file each, yet only one ELF.");
1534        for source in sources {
1535            println!("{source}");
1536            if source.id == default_source_id || source.id == another_source_id {
1537                assert_eq!(source.files, 1);
1538            } else {
1539                assert_eq!(source.files, 0, "groups should zero (empty)");
1540            }
1541        }
1542
1543        assert!(!db
1544            .get_user_sources(1)
1545            .await
1546            .expect("failed to get user 1's sources")
1547            .sources
1548            .is_empty());
1549
1550        let file_types_counts = db
1551            .file_types_counts()
1552            .await
1553            .context("failed to get file types and counts")?;
1554        for (name, count) in file_types_counts {
1555            println!("{name}: {count}");
1556            assert_eq!(name, "ELF");
1557            assert_eq!(count, 1);
1558        }
1559
1560        let mut test_elf_modified = test_elf.clone();
1561        let random_bytes = Uuid::new_v4();
1562        let mut random_bytes = random_bytes.into_bytes().to_vec();
1563        test_elf_modified.append(&mut random_bytes);
1564        let similarity_request = generate_similarity_request(&test_elf_modified);
1565        let similarity_response = db
1566            .find_similar_samples(1, &similarity_request.hashes)
1567            .await
1568            .context("failed to get similarity response")?;
1569        eprintln!("Similarity response: {similarity_response:?}");
1570        let similarity_response = similarity_response.first().unwrap();
1571        assert_eq!(
1572            similarity_response.sha256, test_elf_meta.sha256,
1573            "Similarity response should have had the hash of the original ELF"
1574        );
1575        for (algo, sim) in &similarity_response.algorithms {
1576            match algo {
1577                malwaredb_api::SimilarityHashType::LZJD => {
1578                    assert!(*sim > 0.0f32);
1579                }
1580                malwaredb_api::SimilarityHashType::SSDeep => {
1581                    assert!(*sim > 80.0f32);
1582                }
1583                malwaredb_api::SimilarityHashType::TLSH => {
1584                    assert!(*sim <= 20f32);
1585                }
1586                _ => {}
1587            }
1588        }
1589
1590        let test_elf_hashtype = HashType::try_from(test_elf_meta.sha1)
1591            .context("failed to get `HashType::SHA1` from string")?;
1592        let response_sha256 = db
1593            .retrieve_sample(1, &test_elf_hashtype)
1594            .await
1595            .context("could not get SHA-256 hash from test sample")
1596            .unwrap();
1597        assert_eq!(response_sha256, test_elf_meta.sha256);
1598
1599        let test_bogus_hash = HashType::try_from(String::from(
1600            "d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0",
1601        ))
1602        .context("failed to get `HashType` from static string")?;
1603        assert!(
1604            db.retrieve_sample(1, &test_bogus_hash).await.is_err(),
1605            "Getting a file with a bogus hash should have failed."
1606        );
1607
1608        let test_pdf = include_bytes!("../../../types/testdata/pdf/test.pdf").to_vec();
1609        let test_pdf_meta = FileMetadata::new(&test_pdf, Some("test.pdf"));
1610        let pdf_type = db.get_type_id_for_bytes(&test_pdf).await.unwrap();
1611
1612        let known_type =
1613            KnownType::new(&test_pdf).context("failed to parse pdf from test crate's test data")?;
1614
1615        assert!(db
1616            .add_file(
1617                &test_pdf_meta,
1618                known_type,
1619                1,
1620                default_source_id,
1621                pdf_type,
1622                None
1623            )
1624            .await
1625            .context("failed to insert a test pdf")?);
1626        eprintln!("Added PDF to the DB");
1627
1628        let test_rtf = include_bytes!("../../../types/testdata/rtf/hello.rtf").to_vec();
1629        let test_rtf_meta = FileMetadata::new(&test_rtf, Some("test.rtf"));
1630        let rtf_type = db
1631            .get_type_id_for_bytes(&test_rtf)
1632            .await
1633            .context("failed to get file type id for rtf")?;
1634
1635        let known_type =
1636            KnownType::new(&test_rtf).context("failed to parse pdf from test crate's test data")?;
1637
1638        assert!(db
1639            .add_file(
1640                &test_rtf_meta,
1641                known_type,
1642                1,
1643                default_source_id,
1644                rtf_type,
1645                None
1646            )
1647            .await
1648            .context("failed to insert a test rtf")?);
1649        eprintln!("Added RTF to the DB");
1650
1651        let report = db
1652            .get_sample_report(1, &HashType::try_from(test_rtf_meta.sha256.clone())?)
1653            .await
1654            .context("failed to get report for test rtf")?;
1655        assert!(report
1656            .clone()
1657            .filecommand
1658            .unwrap()
1659            .contains("Rich Text Format"));
1660        println!("Report: {report}");
1661
1662        assert!(db
1663            .get_sample_report(999, &HashType::try_from(test_rtf_meta.sha256)?)
1664            .await
1665            .is_err());
1666
1667        #[cfg(feature = "vt")]
1668        {
1669            assert!(report.vt.is_some());
1670            let files_needing_vt = db
1671                .files_without_vt_records(10)
1672                .await
1673                .context("failed to get files without VT records")?;
1674            assert!(files_needing_vt.len() > 2);
1675            println!(
1676                "{} files needing VT data: {files_needing_vt:?}",
1677                files_needing_vt.len()
1678            );
1679        }
1680
1681        #[cfg(not(feature = "vt"))]
1682        {
1683            assert!(report.vt.is_none());
1684        }
1685
1686        let reset = db
1687            .reset_api_keys()
1688            .await
1689            .context("failed to reset all API keys")?;
1690        eprintln!("Cleared {reset} api keys.");
1691
1692        let db_info = db.db_info().await.context("failed to get database info")?;
1693        eprintln!("DB Info: {db_info:?}");
1694
1695        let data_types = db
1696            .get_known_data_types()
1697            .await
1698            .context("failed to get data types")?;
1699        for data_type in data_types {
1700            println!("{data_type:?}");
1701        }
1702
1703        let sources = db
1704            .list_sources()
1705            .await
1706            .context("failed to list sources second time")?;
1707        eprintln!("DB has {} sources:", sources.len());
1708        for source in sources {
1709            println!("{source}");
1710        }
1711
1712        let file_types_counts = db
1713            .file_types_counts()
1714            .await
1715            .context("failed to get file types and counts")?;
1716        for (name, count) in file_types_counts {
1717            println!("{name}: {count}");
1718            assert_ne!(name, "Mach-O", "No Mach-O files have been inserted yet!");
1719        }
1720
1721        let fatmacho =
1722            include_bytes!("../../../types/testdata/macho/macho_fat_arm64_ppc_ppc64_x86_64")
1723                .to_vec();
1724        let fatmacho_meta = FileMetadata::new(&fatmacho, Some("macho_fat_arm64_ppc_ppc64_x86_64"));
1725        let fatmacho_type = db
1726            .get_type_id_for_bytes(&fatmacho)
1727            .await
1728            .context("failed to get file type for Fat Mach-O")?;
1729        let known_type = KnownType::new(&fatmacho)
1730            .context("failed to parse Fat Mach-O from type crate's test data")?;
1731
1732        assert!(db
1733            .add_file(
1734                &fatmacho_meta,
1735                known_type,
1736                1,
1737                default_source_id,
1738                fatmacho_type,
1739                None
1740            )
1741            .await
1742            .context("failed to insert a test Fat Mach-O")?);
1743        eprintln!("Added Fat Mach-O to the DB");
1744
1745        let file_types_counts = db
1746            .file_types_counts()
1747            .await
1748            .context("failed to get file types and counts")?;
1749        for (name, count) in &file_types_counts {
1750            println!("{name}: {count}");
1751        }
1752
1753        assert_eq!(
1754            *file_types_counts.get("Mach-O").unwrap(),
1755            4,
1756            "Expected 4 Mach-O files, got {:?}",
1757            file_types_counts.get("Mach-O")
1758        );
1759
1760        let malware_label_id = db
1761            .create_label(MALWARE_LABEL, None)
1762            .await
1763            .context("failed to create first label")?;
1764        let ransomware_label_id = db
1765            .create_label(RANSOMWARE_LABEL, Some(malware_label_id))
1766            .await
1767            .context("failed to create malware sub-label")?;
1768        let labels = db.get_labels().await.context("failed to get labels")?;
1769
1770        assert_eq!(labels.len(), 2);
1771        for label in labels.0 {
1772            if label.name == RANSOMWARE_LABEL {
1773                assert_eq!(label.id, ransomware_label_id);
1774                assert_eq!(label.parent.unwrap(), MALWARE_LABEL);
1775            }
1776        }
1777
1778        // Use this file as a stand-in for an unknown file type
1779        let source_code = include_bytes!("mod.rs");
1780        let source_meta = FileMetadata::new(source_code, Some("mod.rs"));
1781        let known_type =
1782            KnownType::new(source_code).context("failed to source code to get `Unknown` type")?;
1783
1784        assert!(matches!(known_type, KnownType::Unknown(_)));
1785
1786        let unknown_type: Vec<FileType> = db
1787            .get_known_data_types()
1788            .await?
1789            .into_iter()
1790            .filter(|t| t.name.eq_ignore_ascii_case("unknown"))
1791            .collect();
1792        let unknown_type_id = unknown_type.first().unwrap().id;
1793        assert!(db.get_type_id_for_bytes(source_code).await.is_err());
1794        db.enable_keep_unknown_files()
1795            .await
1796            .context("failed to enable keeping of unknown files")?;
1797        let source_type = db
1798            .get_type_id_for_bytes(source_code)
1799            .await
1800            .context("failed to type id for source code unknown type example")?;
1801        assert_eq!(source_type, unknown_type_id);
1802        eprintln!("Unknown file type ID: {source_type}");
1803        assert!(db
1804            .add_file(
1805                &source_meta,
1806                known_type,
1807                1,
1808                default_source_id,
1809                unknown_type_id,
1810                None
1811            )
1812            .await
1813            .context("failed to add Rust source code file")?);
1814        eprintln!("Added Rust source code to the DB");
1815
1816        db.reset_own_api_key(0)
1817            .await
1818            .context("failed to clear own API key uid 0")?;
1819
1820        db.deactivate_user(0)
1821            .await
1822            .context("failed to clear password and API key for uid 0")?;
1823
1824        Ok(())
1825    }
1826}