#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
pub mod errors;
pub mod filereport;
pub mod filerescan;
pub mod filesearch;
use crate::filereport::{FileReportData, FileReportRequestResponse};
use crate::filerescan::{FileRescanRequestData, FileRescanRequestResponse};
use crate::filesearch::FileSearchResponse;
use std::borrow::Cow;
use std::fmt::{Debug, Display, Formatter};
use std::io::Error;
use std::path::Path;
use std::str::FromStr;
use std::string::FromUtf8Error;
use bytes::Bytes;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::multipart::{Form, Part};
use serde::{Deserialize, Serialize, Serializer};
use tokio::fs::File;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Size threshold (32 MiB = 33 554 432 bytes) at or above which submissions
/// are posted to a one-time URL obtained from `get_upload_url` instead of the
/// regular `/files` endpoint.
const THIRTY_TWO_MEGABYTES: u64 = 32 << 20;
/// Error type returned by every fallible [`VirusTotalClient`] operation.
///
/// It either carries an error object returned by VirusTotal itself, or wraps
/// a local failure (HTTP transport, JSON, UTF-8 decoding or I/O). Equality
/// compares only [`code`](Self::code), ignoring ASCII/Unicode case, and ignores
/// [`message`](Self::message).
#[derive(Clone, Debug, Eq, Serialize, Deserialize)]
pub struct VirusTotalError {
/// Human-readable description; this is what `Display` prints.
pub message: String,
/// Error identifier. For errors converted from reqwest, serde_json, UTF-8 or
/// I/O failures this holds the underlying error's text.
pub code: String,
}
impl PartialEq for VirusTotalError {
fn eq(&self, other: &VirusTotalError) -> bool {
self.code.to_lowercase() == other.code.to_lowercase()
}
}
impl Display for VirusTotalError {
    /// Print only the human-readable `message`; the `code` is not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

/// Makes `VirusTotalError` usable as a standard error; it reports no source.
impl std::error::Error for VirusTotalError {}
impl From<reqwest::Error> for VirusTotalError {
    /// Wrap an HTTP-layer failure. The generic text goes in `message`; the
    /// detailed error (prefixed by the URL being loaded, if known) goes in
    /// `code`.
    fn from(err: reqwest::Error) -> Self {
        let location = err
            .url()
            .map(|u| format!(" loading {u}"))
            .unwrap_or_default();
        let code = format!("Error{location} {err}");
        Self {
            code,
            message: String::from("Http error"),
        }
    }
}
impl From<serde_json::Error> for VirusTotalError {
    /// Wrap a JSON (de)serialization failure, recording the line it occurred on.
    fn from(err: serde_json::Error) -> Self {
        let line = err.line();
        Self {
            code: format!("Json error at line {line}: {err}"),
            message: String::from("Json error"),
        }
    }
}
impl From<FromUtf8Error> for VirusTotalError {
    /// Wrap a failure to decode a response body as UTF-8.
    fn from(err: FromUtf8Error) -> Self {
        let detail = err.to_string();
        Self {
            code: detail,
            message: String::from("UTF-8 decoding error"),
        }
    }
}
impl From<std::io::Error> for VirusTotalError {
    /// Wrap a local I/O failure (e.g. opening a file to submit).
    fn from(io_err: Error) -> Self {
        let detail = io_err.to_string();
        Self {
            code: detail,
            message: String::from("IO error"),
        }
    }
}
/// Asynchronous client for the VirusTotal API, holding the API key.
///
/// The key is zeroized when the client is dropped (`ZeroizeOnDrop`), is never
/// shown by `Debug`, and is replaced by a placeholder when serialized unless
/// the `unsafe-serialization` feature is enabled.
#[derive(Clone, Deserialize, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct VirusTotalClient {
// API key; with the `clap` feature it is read from `--key` or `VT_API_KEY`.
#[cfg_attr(feature = "clap", arg(long, env = "VT_API_KEY"))]
key: String,
}
impl Debug for VirusTotalClient {
    /// Identify the client by crate version only, so the API key never
    /// appears in debug output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(concat!("VirusTotal Client v", env!("CARGO_PKG_VERSION")))
    }
}
impl Serialize for VirusTotalClient {
    /// Serialize the client as a single string.
    ///
    /// The real key is written only when the `unsafe-serialization` feature is
    /// enabled; otherwise a fixed placeholder is emitted so that serializing a
    /// client does not expose the key.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        #[cfg(feature = "unsafe-serialization")]
        let exported: &str = self.key.as_str();
        #[cfg(not(feature = "unsafe-serialization"))]
        let exported: &str = "your-api-key-here";
        serializer.serialize_str(exported)
    }
}
impl VirusTotalClient {
const API_KEY: &'static str = "x-apikey";
pub const KEY_LEN: usize = 64;
pub fn new(key: String) -> Self {
Self { key }
}
#[inline]
fn client(&self) -> reqwest::Result<reqwest::Client> {
let mut headers = HeaderMap::new();
headers.insert(
VirusTotalClient::API_KEY,
HeaderValue::from_str(&self.key).unwrap(),
);
reqwest::ClientBuilder::new()
.gzip(true)
.default_headers(headers)
.build()
}
#[inline]
pub async fn get_report_raw(&self, file_hash: &str) -> Result<Bytes, VirusTotalError> {
let client = self.client()?;
let bytes = client
.get(format!(
"https://www.virustotal.com/api/v3/files/{file_hash}"
))
.send()
.await?
.bytes()
.await?;
Ok(bytes)
}
pub async fn get_report(&self, file_hash: &str) -> Result<FileReportData, VirusTotalError> {
let body = self.get_report_raw(file_hash).await?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let report: FileReportRequestResponse = serde_json::from_str(&json_response)?;
match report {
FileReportRequestResponse::Data(data) => Ok(data),
FileReportRequestResponse::Error(error) => Err(error),
}
}
#[inline]
pub async fn request_rescan_raw(&self, file_hash: &str) -> Result<Bytes, VirusTotalError> {
let client = self.client()?;
let bytes = client
.post(format!(
"https://www.virustotal.com/api/v3/files/{file_hash}/analyse"
))
.header("content-length", "0")
.send()
.await?
.bytes()
.await?;
Ok(bytes)
}
pub async fn request_rescan(
&self,
file_hash: &str,
) -> Result<FileRescanRequestData, VirusTotalError> {
let body = self.request_rescan_raw(file_hash).await?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let report: FileRescanRequestResponse = serde_json::from_str(&json_response)?;
match report {
FileRescanRequestResponse::Data(data) => Ok(data),
FileRescanRequestResponse::Error(error) => Err(error),
}
}
#[inline]
pub async fn submit_file_path_raw<P>(&self, path: P) -> Result<Bytes, VirusTotalError>
where
P: AsRef<Path>,
{
let client = self.client()?;
let file = File::open(&path).await?;
let size = file.metadata().await?.len();
let url = if size >= THIRTY_TWO_MEGABYTES {
self.get_upload_url().await?
} else {
"https://www.virustotal.com/api/v3/files".to_string()
};
let form = Form::new().file("file", path).await?;
let bytes = client
.post(url)
.header("accept", "application/json")
.multipart(form)
.send()
.await?
.bytes()
.await?;
Ok(bytes)
}
pub async fn submit_file_path<P>(
&self,
path: P,
) -> Result<FileRescanRequestData, VirusTotalError>
where
P: AsRef<Path>,
{
let body = self.submit_file_path_raw(path).await?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let report: FileRescanRequestResponse = serde_json::from_str(&json_response)?;
match report {
FileRescanRequestResponse::Data(data) => Ok(data),
FileRescanRequestResponse::Error(error) => Err(error),
}
}
#[inline]
pub async fn submit_bytes_raw<N>(
&self,
data: Vec<u8>,
name: N,
) -> Result<Bytes, VirusTotalError>
where
N: Into<Cow<'static, str>>,
{
let client = self.client()?;
let url = if data.len() as u64 >= THIRTY_TWO_MEGABYTES {
self.get_upload_url().await?
} else {
"https://www.virustotal.com/api/v3/files".to_string()
};
let form = Form::new().part(
"file",
Part::bytes(data)
.file_name(name)
.mime_str("application/octet-stream")?,
);
let bytes = client
.post(url)
.header("accept", "application/json")
.multipart(form)
.send()
.await?
.bytes()
.await?;
Ok(bytes)
}
pub async fn submit_bytes<N>(
&self,
data: Vec<u8>,
name: N,
) -> Result<FileRescanRequestData, VirusTotalError>
where
N: Into<Cow<'static, str>>,
{
let body = self.submit_bytes_raw(data, name).await?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let report: FileRescanRequestResponse = serde_json::from_str(&json_response)?;
match report {
FileRescanRequestResponse::Data(data) => Ok(data),
FileRescanRequestResponse::Error(error) => Err(error),
}
}
#[inline]
pub async fn get_upload_url(&self) -> Result<String, VirusTotalError> {
let response = self.other("files/upload_url").await?;
let response = String::from_utf8(response.to_vec())?;
let response = serde_json::from_str::<serde_json::Value>(&response)?;
let url = response["data"].as_str().ok_or(VirusTotalError {
message: "No URL returned".to_string(),
code: "NoURLReturned".to_string(),
})?;
Ok(url.to_string())
}
pub async fn download(&self, file_hash: &str) -> Result<Vec<u8>, VirusTotalError> {
let client = self.client()?;
let response = client
.get(format!(
"https://www.virustotal.com/api/v3/files/{file_hash}/download"
))
.send()
.await?;
if !response.status().is_success() {
let body = response.bytes().await?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let error: FileRescanRequestResponse = serde_json::from_str(&json_response)?;
return if let FileRescanRequestResponse::Error(error) = error {
Err(error)
} else {
Err(VirusTotalError {
message: json_response,
code: "VTError".into(),
})
};
}
let body = response.bytes().await?;
Ok(body.to_vec())
}
#[inline]
pub async fn search_raw<Q>(&self, query: Q) -> Result<Bytes, VirusTotalError>
where
Q: Display,
{
let url = format!(
"https://www.virustotal.com/vtapi/v2/file/search?apikey={}&query={query}",
self.key.as_str()
);
let body = self.client()?.get(url).send().await?.bytes().await?;
Ok(body)
}
pub async fn search<Q>(&self, query: Q) -> Result<FileSearchResponse, VirusTotalError>
where
Q: Display,
{
let body = self.search_raw(&query).await?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let response: FileSearchResponse = serde_json::from_str(&json_response)?;
let response = FileSearchResponse {
response_code: response.response_code,
offset: response.offset,
hashes: response.hashes,
query: query.to_string(),
verbose_msg: response.verbose_msg,
};
Ok(response)
}
pub async fn search_offset(
&self,
prior: &FileSearchResponse,
) -> Result<FileSearchResponse, VirusTotalError> {
if prior.offset.is_none() {
return Err(VirusTotalError {
message: "Cannot continue a search without an offset code".to_string(),
code: "NonPaginatedResults".to_string(),
});
}
let url = format!(
"https://www.virustotal.com/vtapi/v2/file/search?apikey={}&query={}&offset={}",
self.key.as_str(),
prior.query,
prior.offset.as_ref().unwrap()
);
let body = self.client()?.get(url).send().await?.bytes().await?;
let json_response = String::from_utf8(body.to_ascii_lowercase())?;
let response: FileSearchResponse = serde_json::from_str(&json_response)?;
let response = FileSearchResponse {
response_code: response.response_code,
offset: response.offset,
hashes: response.hashes,
query: prior.query.clone(),
verbose_msg: response.verbose_msg,
};
Ok(response)
}
pub async fn other(&self, url: &str) -> reqwest::Result<Bytes> {
let client = self.client()?;
client
.get(format!("https://www.virustotal.com/api/v3/{url}"))
.send()
.await?
.bytes()
.await
}
}
impl FromStr for VirusTotalClient {
    type Err = &'static str;

    /// Build a client from `key`, rejecting keys whose byte length is not
    /// exactly [`VirusTotalClient::KEY_LEN`].
    fn from_str(key: &str) -> Result<Self, Self::Err> {
        if key.len() == VirusTotalClient::KEY_LEN {
            Ok(Self {
                key: key.to_owned(),
            })
        } else {
            Err("Invalid API key length")
        }
    }
}
impl From<String> for VirusTotalClient {
fn from(value: String) -> Self {
VirusTotalClient::new(value)
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// End-to-end check against the live VirusTotal API.
    ///
    /// Ignored by default because it needs network access and a real key in
    /// the `VT_API_KEY` environment variable; run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn api() {
        const HASH: &str = "fff40032c3dc062147c530e3a0a5c7e6acda4d1f1369fbc994cddd3c19a2de88";
        const ELF: &[u8] = include_bytes!("../testdata/elf_haiku_x86");

        let api_key = match std::env::var("VT_API_KEY") {
            Ok(key) => key,
            Err(_) => panic!("`VT_API_KEY` not set!"),
        };
        let client = VirusTotalClient::new(api_key);

        // The report for HASH is expected to list more than ten engine results.
        let report = client
            .get_report(HASH)
            .await
            .expect("failed to get or parse VT scan report");
        assert!(report.attributes.last_analysis_results.len() > 10);

        let rescan = client
            .request_rescan(HASH)
            .await
            .expect("failed to get or parse VT rescan response");
        assert_eq!(rescan.rescan_type, "analysis");

        client
            .submit_bytes(ELF.to_vec(), "elf_haiku_x86".to_string())
            .await
            .unwrap();

        // A bogus hash must come back as VirusTotal's not-found error.
        match client.get_report("AABBCCDD").await {
            Ok(_) => unreachable!("No way this should work"),
            Err(err) => assert_eq!(err, *crate::errors::NOT_FOUND_ERROR),
        }

        // Without a premium key the download endpoint is expected to refuse.
        match client
            .download("abc91ba39ea3220d23458f8049ed900c16ce1023")
            .await
        {
            Ok(_) => unreachable!("This shouldn't work, unless you have VT Premium"),
            Err(e) => assert_eq!(e, *crate::errors::FORBIDDEN_ERROR),
        }
    }
}