#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
pub mod errors;
pub mod filereport;
pub mod filerescan;
pub mod filesearch;
use crate::filereport::{FileReportData, FileReportRequestResponse};
use crate::filerescan::{FileRescanRequestData, FileRescanRequestResponse};
use crate::filesearch::FileSearchResponse;
use std::borrow::Cow;
use std::fmt::{Debug, Display, Formatter};
use std::str::FromStr;
use std::string::FromUtf8Error;
use bytes::Bytes;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::multipart::Form;
use serde::{Deserialize, Serialize, Serializer};
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Error type returned by the fallible [`VirusTotalClient`] operations.
///
/// Error objects sent back by the VirusTotal service are deserialized directly
/// into this type; transport, JSON and UTF-8 failures are converted into it
/// through its `From` implementations.
#[derive(Clone, Debug, Eq, Serialize, Deserialize)]
pub struct VirusTotalError {
    /// Human-readable description of what went wrong.
    pub message: String,

    /// Machine-oriented error identifier. Equality between two errors compares
    /// only this field, ignoring case; `message` is not considered.
    pub code: String,
}
/// Two errors are equal when their `code`s match case-insensitively.
///
/// `message` is deliberately ignored, so an error received from the service can
/// be compared against a predefined error value (e.g. `errors::NOT_FOUND_ERROR`)
/// regardless of its wording.
impl PartialEq for VirusTotalError {
    fn eq(&self, other: &VirusTotalError) -> bool {
        let (mine, theirs) = (self.code.to_lowercase(), other.code.to_lowercase());
        mine == theirs
    }
}
impl Display for VirusTotalError {
    /// Renders only the human-readable `message`; `code` is not shown and any
    /// width/fill flags are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

/// Marker impl so the error works with `?` into `Box<dyn std::error::Error>`.
impl std::error::Error for VirusTotalError {}
impl From<reqwest::Error> for VirusTotalError {
    fn from(err: reqwest::Error) -> Self {
        let url = if let Some(url) = err.url() {
            format!(" loading {url}")
        } else {
            "".into()
        };
        Self {
            message: "Http error".into(),
            code: format!("Error{url} {err}"),
        }
    }
}
/// Wrap a JSON (de)serialization failure; `code` carries the line number and
/// serde_json's description of the problem.
impl From<serde_json::Error> for VirusTotalError {
    fn from(err: serde_json::Error) -> Self {
        let line = err.line();
        let code = format!("Json error at line {line}: {err}");
        Self {
            code,
            message: String::from("Json error"),
        }
    }
}
/// Wrap a failure to decode a response body as UTF-8; `code` carries the
/// decoder's description of the offending byte sequence.
impl From<FromUtf8Error> for VirusTotalError {
    fn from(err: FromUtf8Error) -> Self {
        let code = err.to_string();
        Self {
            code,
            message: String::from("UTF-8 decoding error"),
        }
    }
}
/// Asynchronous client for the VirusTotal API.
///
/// It holds only the API key, which is zeroized when the client is dropped and
/// is never printed by the `Debug` implementation. With the `clap` feature the
/// struct can be flattened into a command-line parser, taking the key from
/// `--key` or the `VT_API_KEY` environment variable.
#[derive(Clone, Deserialize, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct VirusTotalClient {
    /// VirusTotal API key.
    #[cfg_attr(feature = "clap", arg(long, env = "VT_API_KEY"))]
    key: String,
}
impl Debug for VirusTotalClient {
    /// Prints only the crate version, so the API key never shows up in logs
    /// or panic messages.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(concat!("VirusTotal Client v", env!("CARGO_PKG_VERSION")))
    }
}
impl Serialize for VirusTotalClient {
    /// Serializes the client as a single string.
    ///
    /// Only with the `unsafe-serialization` feature is that string the real API
    /// key; otherwise a fixed placeholder is written so the secret is not
    /// persisted by accident.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        #[cfg(feature = "unsafe-serialization")]
        let exposed: &str = &self.key;
        #[cfg(not(feature = "unsafe-serialization"))]
        let exposed: &str = "your-api-key-here";
        serializer.serialize_str(exposed)
    }
}
impl VirusTotalClient {
    pub const API_KEY: &'static str = "x-apikey";
    pub const KEY_LEN: usize = 64;
    pub fn new(key: String) -> Self {
        Self { key }
    }
    fn header(&self) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            VirusTotalClient::API_KEY,
            HeaderValue::from_str(&self.key).unwrap(),
        );
        headers
    }
    fn client(&self) -> reqwest::Result<reqwest::Client> {
        reqwest::ClientBuilder::new()
            .gzip(true)
            .default_headers(self.header())
            .build()
    }
    pub async fn get_report_raw(&self, file_hash: &str) -> Result<Bytes, VirusTotalError> {
        let client = self.client()?;
        let bytes = client
            .get(format!(
                "https://www.virustotal.com/api/v3/files/{file_hash}"
            ))
            .send()
            .await?
            .bytes()
            .await?;
        Ok(bytes)
    }
    pub async fn get_report(&self, file_hash: &str) -> Result<FileReportData, VirusTotalError> {
        let body = self.get_report_raw(file_hash).await?;
        let json_response = String::from_utf8(body.to_ascii_lowercase())?;
        let report: FileReportRequestResponse = serde_json::from_str(&json_response)?;
        match report {
            FileReportRequestResponse::Data(data) => Ok(data),
            FileReportRequestResponse::Error(error) => Err(error),
        }
    }
    pub async fn request_rescan_raw(&self, file_hash: &str) -> Result<Bytes, VirusTotalError> {
        let client = self.client()?;
        let bytes = client
            .post(format!(
                "https://www.virustotal.com/api/v3/files/{file_hash}/analyse"
            ))
            .header("content-length", "0")
            .send()
            .await?
            .bytes()
            .await?;
        Ok(bytes)
    }
    pub async fn request_rescan(
        &self,
        file_hash: &str,
    ) -> Result<FileRescanRequestData, VirusTotalError> {
        let body = self.request_rescan_raw(file_hash).await?;
        let json_response = String::from_utf8(body.to_ascii_lowercase())?;
        let report: FileRescanRequestResponse = serde_json::from_str(&json_response)?;
        match report {
            FileRescanRequestResponse::Data(data) => Ok(data),
            FileRescanRequestResponse::Error(error) => Err(error),
        }
    }
    pub async fn submit_raw<D, N>(&self, data: D, name: Option<N>) -> Result<Bytes, VirusTotalError>
    where
        D: Into<Cow<'static, [u8]>>,
        N: Into<Cow<'static, str>>,
    {
        let client = self.client()?;
        let form = if let Some(file_name) = name {
            Form::new().part(
                "file",
                reqwest::multipart::Part::bytes(data)
                    .file_name(file_name)
                    .mime_str("application/octet-stream")?,
            )
        } else {
            Form::new().part(
                "file",
                reqwest::multipart::Part::bytes(data).mime_str("application/octet-stream")?,
            )
        };
        let bytes = client
            .post("https://www.virustotal.com/api/v3/files")
            .header("accept", "application/json")
            .header("content-type", "multipart/form-data")
            .multipart(form)
            .send()
            .await?
            .bytes()
            .await?;
        Ok(bytes)
    }
    pub async fn submit<D, N>(
        &self,
        data: D,
        name: Option<N>,
    ) -> Result<FileRescanRequestData, VirusTotalError>
    where
        D: Into<Cow<'static, [u8]>>,
        N: Into<Cow<'static, str>>,
    {
        let body = self.submit_raw(data, name).await?;
        let json_response = String::from_utf8(body.to_ascii_lowercase())?;
        let report: FileRescanRequestResponse = serde_json::from_str(&json_response)?;
        match report {
            FileRescanRequestResponse::Data(data) => Ok(data),
            FileRescanRequestResponse::Error(error) => Err(error),
        }
    }
    pub async fn download(&self, file_hash: &str) -> Result<Vec<u8>, VirusTotalError> {
        let client = self.client()?;
        let response = client
            .get(format!(
                "https://www.virustotal.com/api/v3/files/{file_hash}/download"
            ))
            .send()
            .await?;
        if !response.status().is_success() {
            let body = response.bytes().await?;
            let json_response = String::from_utf8(body.to_ascii_lowercase())?;
            let error: FileRescanRequestResponse = serde_json::from_str(&json_response)?;
            return if let FileRescanRequestResponse::Error(error) = error {
                Err(error)
            } else {
                Err(VirusTotalError {
                    message: json_response,
                    code: "VTError".into(),
                })
            };
        }
        let body = response.bytes().await?;
        Ok(body.to_vec())
    }
    pub async fn search_raw<Q>(&self, query: Q) -> Result<Bytes, VirusTotalError>
    where
        Q: Display,
    {
        let url = format!(
            "https://www.virustotal.com/vtapi/v2/file/search?apikey={}&query={query}",
            self.key.as_str()
        );
        let body = self.client()?.get(url).send().await?.bytes().await?;
        Ok(body)
    }
    pub async fn search<Q>(&self, query: Q) -> Result<FileSearchResponse, VirusTotalError>
    where
        Q: Display,
    {
        let body = self.search_raw(&query).await?;
        let json_response = String::from_utf8(body.to_ascii_lowercase())?;
        let response: FileSearchResponse = serde_json::from_str(&json_response)?;
        let response = FileSearchResponse {
            response_code: response.response_code,
            offset: response.offset,
            hashes: response.hashes,
            query: query.to_string(),
        };
        Ok(response)
    }
    pub async fn search_offset(
        &self,
        prior: &FileSearchResponse,
    ) -> Result<FileSearchResponse, VirusTotalError> {
        let url = format!(
            "https://www.virustotal.com/vtapi/v2/file/search?apikey={}&query={}&offset={}",
            self.key.as_str(),
            prior.query,
            prior.offset
        );
        let body = self.client()?.get(url).send().await?.bytes().await?;
        let json_response = String::from_utf8(body.to_ascii_lowercase())?;
        let response: FileSearchResponse = serde_json::from_str(&json_response)?;
        let response = FileSearchResponse {
            response_code: response.response_code,
            offset: response.offset,
            hashes: response.hashes,
            query: prior.query.clone(),
        };
        Ok(response)
    }
    pub async fn other(&self, url: &str) -> reqwest::Result<Bytes> {
        let client = self.client()?;
        client
            .get(format!("https://www.virustotal.com/api/v3/{url}"))
            .send()
            .await?
            .bytes()
            .await
    }
}
impl FromStr for VirusTotalClient {
    type Err = &'static str;
    fn from_str(key: &str) -> Result<Self, Self::Err> {
        if key.len() != VirusTotalClient::KEY_LEN {
            Err("Invalid API key length")
        } else {
            Ok(Self {
                key: key.to_string(),
            })
        }
    }
}
impl From<String> for VirusTotalClient {
    fn from(value: String) -> Self {
        VirusTotalClient::new(value)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// End-to-end check against the live API. It needs `VT_API_KEY` and network
    /// access, so it only runs when ignored tests are requested.
    #[tokio::test]
    #[ignore]
    async fn api() {
        const HASH: &str = "fff40032c3dc062147c530e3a0a5c7e6acda4d1f1369fbc994cddd3c19a2de88";
        const ELF: &[u8] = include_bytes!("../testdata/elf_haiku_x86");

        let api_key = match std::env::var("VT_API_KEY") {
            Ok(key) => key,
            Err(_) => panic!("`VT_API_KEY` not set!"),
        };
        let client = VirusTotalClient::new(api_key);

        // The report for HASH is expected to list more than ten engine results.
        let report = client
            .get_report(HASH)
            .await
            .expect("failed to get or parse VT scan report");
        assert!(report.attributes.last_analysis_results.len() > 10);

        let rescan = client
            .request_rescan(HASH)
            .await
            .expect("failed to get or parse VT rescan response");
        assert_eq!(rescan.rescan_type, "analysis");

        client
            .submit(Vec::from(ELF), Some("elf_haiku_x86".to_string()))
            .await
            .unwrap();

        let not_found = match client.get_report("AABBCCDD").await {
            Ok(_) => unreachable!("No way this should work"),
            Err(err) => err,
        };
        assert_eq!(not_found, *crate::errors::NOT_FOUND_ERROR);

        // Downloading is expected to be refused for non-premium keys.
        let forbidden = match client
            .download("abc91ba39ea3220d23458f8049ed900c16ce1023")
            .await
        {
            Ok(_) => unreachable!("This shouldn't work, unless you have VT Premium"),
            Err(e) => e,
        };
        assert_eq!(forbidden, *crate::errors::FORBIDDEN_ERROR);
    }
}