malwaredb-server 0.3.4

Server data storage logic for MalwareDB.
Documentation
// SPDX-License-Identifier: Apache-2.0

use crate::State;
use std::fmt::{Display, Formatter};

use std::process::ExitCode;
use std::time::Duration;

use anyhow::{Context, Result};
use malwaredb_virustotal::errors::VirusTotalError;
use malwaredb_virustotal::filereport::ScanResultAttributes;
use tracing::{debug, info};

/// Pulls samples lacking Virus Total data from the database and asks Virus Total for their AV hits
pub struct VtUpdater {
    /// Whether samples unknown to Virus Total may be uploaded to it
    pub send_samples: bool,

    /// Malware DB state: the database handle together with its configuration
    state: State,
}

impl VtUpdater {
    /// Fetch updates for all samples
    ///
    /// # Errors
    ///
    /// Errors may occur if there's a connection issue to Virus Total or to Postgres for storing the data
    pub async fn updater(&self) -> Result<ExitCode> {
        // TODO: Figure out how to paginate over results
        // Don't paginate now since we don't have a mechanism to know when a file isn't in
        // VT, so we'll have an infinite loop.
        let hashes = self
            .state
            .db_type
            .files_without_vt_records(1000)
            .await
            .context("Failed to retrieve hashes for querying VT")?;

        let vt_client = self.state.vt_client.as_ref().context("Missing VT key")?;

        for hash in hashes {
            match vt_client.get_file_report(&hash).await {
                Ok(result) => {
                    self.state
                        .db_type
                        .store_vt_record(&result.attributes)
                        .await
                        .context("Failed to store VT data")?;
                }
                Err(error) => {
                    if self.send_samples
                        && self.state.directory.is_some()
                        && error == VirusTotalError::NotFoundError
                    {
                        if let Ok(bytes) = self.state.retrieve_bytes(&hash).await {
                            match vt_client.submit_bytes(bytes, hash.clone()).await {
                                Ok(_) => {
                                    info!("Sample {hash} uploaded to VT successfully.");
                                }
                                Err(e) => debug!("Error uploading unknown sample to VT: {e}"),
                            }
                        }
                    } else {
                        debug!("Error getting report for {hash}: {error}");
                    }
                }
            }

            tokio::time::sleep(Duration::from_secs(2)).await; // Don't overload VT
        }

        Ok(ExitCode::SUCCESS)
    }

    /// Add a serialized VT report to the database
    ///
    /// # Errors
    ///
    /// There will be an error response if there's a Postgres connection issue.
    pub async fn loader(&self, report: &ScanResultAttributes) -> Result<()> {
        self.state
            .db_type
            .store_vt_record(report)
            .await
            .context("Failed to store VT data")
    }
}

/// Error produced when no Virus Total API key has been configured
#[derive(Debug, Default, Clone, Copy)]
pub struct VtKeyMissingError;

impl Display for VtKeyMissingError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("No VT API Key")
    }
}

// No source error to chain; the `Display` message says everything.
impl std::error::Error for VtKeyMissingError {}

/// Build a `VtUpdater` from the server state, provided a VT client (API key) is configured
impl TryFrom<State> for VtUpdater {
    type Error = VtKeyMissingError;

    fn try_from(state: State) -> std::result::Result<Self, Self::Error> {
        if state.vt_client.is_some() {
            Ok(Self {
                // Read the upload flag before `state` is moved into the struct.
                send_samples: state.db_config.send_samples_to_vt,
                state,
            })
        } else {
            Err(VtKeyMissingError)
        }
    }
}