pub mod filereport;
pub mod filerescan;
use crate::filereport::{FileReportData, FileReportRequestResponse};
use crate::filerescan::{FileRescanRequestData, FileRescanRequestResponse};
use std::fmt::{Display, Formatter};
use std::str::FromStr;
use std::string::FromUtf8Error;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::multipart::Form;
use serde::{Deserialize, Serialize};
use zeroize::Zeroizing;
/// Error returned by the request methods of [`VirusTotalClient`].
///
/// It is either deserialized from the error object of a VirusTotal response
/// or built locally from an HTTP transport, UTF-8 decoding or JSON failure.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct VirusTotalError {
    /// Short, human-readable summary; this is the text `Display` prints.
    pub message: String,
    /// Detail for diagnostics: presumably VirusTotal's error code for API
    /// errors (TODO confirm against the response models), or the underlying
    /// error text for locally generated errors.
    pub code: String,
}
/// User-facing rendering: only the `message` field is printed; `code` is left
/// for `Debug` output.
impl Display for VirusTotalError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

/// Lets `VirusTotalError` be used as a `dyn std::error::Error` (e.g. boxed or with `?`).
impl std::error::Error for VirusTotalError {}
impl From<reqwest::Error> for VirusTotalError {
    /// Wraps an HTTP-layer failure.
    ///
    /// `code` becomes `"Error loading <url>: <err>"` when the failing URL is
    /// known and `"Error: <err>"` otherwise.
    fn from(err: reqwest::Error) -> Self {
        // Leading space is part of the fragment so the no-URL case stays "Error: ...".
        let location = err
            .url()
            .map(|url| format!(" loading {}", url.as_str()))
            .unwrap_or_default();
        Self {
            message: "Http error".into(),
            code: format!("Error{location}: {err}"),
        }
    }
}
impl From<serde_json::Error> for VirusTotalError {
    /// Wraps a JSON parsing failure, recording the line reported by `serde_json`.
    fn from(err: serde_json::Error) -> Self {
        let line = err.line();
        let code = format!("Json error at line {line}: {err}");
        Self {
            message: String::from("Json error"),
            code,
        }
    }
}
impl From<FromUtf8Error> for VirusTotalError {
    /// Wraps a failure to decode a response body as UTF-8; `code` carries the
    /// decoder's own description of the invalid bytes.
    fn from(err: FromUtf8Error) -> Self {
        Self {
            message: "UTF-8 decoding error".into(),
            code: err.to_string(),
        }
    }
}
#[derive(Clone)]
pub struct VirusTotalClient {
key: Zeroizing<String>,
}
impl VirusTotalClient {
const API_KEY: &'static str = "x-apikey";
const KEY_LEN: usize = 64;
pub fn new(key: &str) -> Self {
Self {
key: Zeroizing::new(key.to_string()),
}
}
fn header(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert(
VirusTotalClient::API_KEY,
HeaderValue::from_str(&self.key).unwrap(),
);
headers
}
pub async fn get_report(&self, file_hash: &str) -> Result<FileReportData, VirusTotalError> {
let client = reqwest::Client::new();
let body = client
.get(format!(
"https://www.virustotal.com/api/v3/files/{file_hash}"
))
.headers(self.header())
.send()
.await?
.bytes()
.await?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let report: FileReportRequestResponse = serde_json::from_str(&json_response)?;
match report {
FileReportRequestResponse::Data(data) => Ok(data),
FileReportRequestResponse::Error(error) => Err(error),
}
}
pub async fn request_rescan(
&self,
file_hash: &str,
) -> Result<FileRescanRequestData, VirusTotalError> {
let client = reqwest::Client::new();
let body = client
.post(format!(
"https://www.virustotal.com/api/v3/files/{file_hash}/analyse"
))
.headers(self.header())
.header("content-length", "0")
.send()
.await?
.bytes()
.await?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let report: FileRescanRequestResponse = serde_json::from_str(&json_response)?;
match report {
FileRescanRequestResponse::Data(data) => Ok(data),
FileRescanRequestResponse::Error(error) => Err(error),
}
}
pub async fn submit(
&self,
data: Vec<u8>,
name: Option<String>,
) -> Result<FileRescanRequestData, VirusTotalError> {
let client = reqwest::Client::new();
let form = if let Some(file_name) = name {
Form::new().part(
"file",
reqwest::multipart::Part::bytes(data)
.file_name(file_name)
.mime_str("application/octet-stream")?,
)
} else {
Form::new().part(
"file",
reqwest::multipart::Part::bytes(data).mime_str("application/octet-stream")?,
)
};
let body = client
.post("https://www.virustotal.com/api/v3/files")
.headers(self.header())
.header("accept", "application/json")
.header("content-type", "multipart/form-data")
.multipart(form)
.send()
.await?
.bytes()
.await?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let report: FileRescanRequestResponse = serde_json::from_str(&json_response)?;
match report {
FileRescanRequestResponse::Data(data) => Ok(data),
FileRescanRequestResponse::Error(error) => Err(error),
}
}
}
impl FromStr for VirusTotalClient {
    type Err = &'static str;

    /// Builds a client from `key` after basic sanity checks.
    ///
    /// # Errors
    /// - `"Invalid API key length"` if `key` is not exactly `KEY_LEN` bytes.
    /// - `"Invalid API key characters"` if `key` contains anything other than
    ///   graphic ASCII. Such a key could not be a usable HTTP header value.
    ///   VirusTotal keys are presumably hex, so no real key is rejected.
    fn from_str(key: &str) -> Result<Self, Self::Err> {
        if key.len() != VirusTotalClient::KEY_LEN {
            return Err("Invalid API key length");
        }
        if !key.bytes().all(|b| b.is_ascii_graphic()) {
            return Err("Invalid API key characters");
        }
        Ok(Self::new(key))
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Hash of the sample looked up by the live API test.
    const HASH: &str = "fff40032c3dc062147c530e3a0a5c7e6acda4d1f1369fbc994cddd3c19a2de88";

    /// End-to-end check of report lookup, rescan request and file submission
    /// against the live VirusTotal API. Needs network access and the
    /// `VT_API_KEY` environment variable; run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn api() {
        let api_key =
            std::env::var("VT_API_KEY").unwrap_or_else(|_| panic!("`VT_API_KEY` not set!"));
        let client = VirusTotalClient::new(&api_key);

        let report = client
            .get_report(HASH)
            .await
            .expect("failed to get or parse VT scan report");
        assert!(report.attributes.last_analysis_results.len() > 10);

        let rescan = client
            .request_rescan(HASH)
            .await
            .expect("failed to get or parse VT rescan response");
        assert_eq!(rescan.rescan_type, "analysis");

        const ELF: &[u8] = include_bytes!("../testdata/elf_haiku_x86");
        let elf_sample = ELF.to_vec();
        client
            .submit(elf_sample, Some(String::from("elf_haiku_x86")))
            .await
            .unwrap();
    }
}