malwaredb_server/vt/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::State;
4use std::fmt::{Display, Formatter};
5
6use std::process::ExitCode;
7use std::time::Duration;
8
9use anyhow::{Context, Result};
10use malwaredb_virustotal::errors::VirusTotalError;
11use malwaredb_virustotal::filereport::ScanResultAttributes;
12use tracing::{debug, info};
13
14/// Logic for getting records from the database and querying Virus Total for AV hits
15pub struct VtUpdater {
16    /// Database handle and configuration for Malware DB
17    state: State,
18
19    /// If we're allowed to upload samples to Virus Total
20    pub send_samples: bool,
21}
22
23impl VtUpdater {
24    /// Fetch updates for all samples
25    ///
26    /// # Errors
27    ///
28    /// Errors may occur if there's a connection issue to Virus Total or to Postgres for storing the data
29    pub async fn updater(&self) -> Result<ExitCode> {
30        // TODO: Figure out how to paginate over results
31        // Don't paginate now since we don't have a mechanism to know when a file isn't in
32        // VT, so we'll have an infinite loop.
33        let hashes = self
34            .state
35            .db_type
36            .files_without_vt_records(1000)
37            .await
38            .context("Failed to retrieve hashes for querying VT")?;
39
40        let vt_client = self.state.vt_client.as_ref().context("Missing VT key")?;
41
42        for hash in hashes {
43            match vt_client.get_file_report(&hash).await {
44                Ok(result) => {
45                    self.state
46                        .db_type
47                        .store_vt_record(&result.attributes)
48                        .await
49                        .context("Failed to store VT data")?;
50                }
51                Err(error) => {
52                    if self.send_samples
53                        && self.state.directory.is_some()
54                        && error == VirusTotalError::NotFoundError
55                    {
56                        if let Ok(bytes) = self.state.retrieve_bytes(&hash).await {
57                            match vt_client.submit_bytes(bytes, hash.clone()).await {
58                                Ok(_) => {
59                                    info!("Sample {hash} uploaded to VT successfully.");
60                                }
61                                Err(e) => debug!("Error uploading unknown sample to VT: {e}"),
62                            }
63                        }
64                    } else {
65                        debug!("Error getting report for {hash}: {error}");
66                    }
67                }
68            }
69
70            tokio::time::sleep(Duration::from_secs(2)).await; // Don't overload VT
71        }
72
73        Ok(ExitCode::SUCCESS)
74    }
75
76    /// Add a serialized VT report to the database
77    ///
78    /// # Errors
79    ///
80    /// There will be an error response if there's a Postgres connection issue.
81    pub async fn loader(&self, report: &ScanResultAttributes) -> Result<()> {
82        self.state
83            .db_type
84            .store_vt_record(report)
85            .await
86            .context("Failed to store VT data")
87    }
88}
89
/// Error reported when no Virus Total API key is available
#[derive(Clone, Copy, Debug, Default)]
pub struct VtKeyMissingError;
93
94impl Display for VtKeyMissingError {
95    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
96        write!(f, "No VT API Key")
97    }
98}
99
100impl std::error::Error for VtKeyMissingError {}
101
102/// Get a `VtUpdater` object if we have a VT API key
103impl TryFrom<State> for VtUpdater {
104    type Error = VtKeyMissingError;
105
106    fn try_from(state: State) -> std::result::Result<Self, Self::Error> {
107        if state.vt_client.is_none() {
108            return Err(VtKeyMissingError);
109        }
110        let send_samples = state.db_config.send_samples_to_vt;
111
112        Ok(VtUpdater {
113            state,
114            send_samples,
115        })
116    }
117}