use malwaredb_virustotal::VirusTotalClient;
use std::path::PathBuf;
use std::process::ExitCode;
use anyhow::{bail, Result};
use clap::{Parser, Subcommand};
use sha2::{Digest, Sha256};
// Top-level command line: the API key plus the subcommand to run.
// (Plain `//` comment on purpose: clap turns `///` doc comments into help text.)
#[derive(Parser)]
#[command(author, about, version)]
struct Args {
    /// VirusTotal API key
    // `hide_env_values` keeps clap from echoing the key taken from VT_API_KEY in `--help`.
    #[arg(long, env = "VT_API_KEY", hide_env_values = true)]
    pub key: String,
    #[command(subcommand)]
    pub action: Action,
}
// Arguments for `submit`: a single positional path naming the file whose bytes are uploaded.
#[derive(Clone, Parser)]
struct SubmitFileArg {
    pub file: PathBuf,
}
// Arguments for `get-report`. At least one of `file` / `hash` must be supplied; that is
// enforced at run time by `Action::execute`, not by clap.
#[derive(Parser, Clone)]
struct FileReportArg {
    /// Local file to look up; its SHA-256 is computed locally and only the hash is queried
    #[arg(short, long)]
    pub file: Option<PathBuf>,
    /// File hash to look up
    #[arg(long)]
    pub hash: Option<String>,
    /// Save the full JSON report to this path (one report only: --file takes precedence over --hash)
    #[arg(short, long)]
    pub output: Option<PathBuf>,
}
// Arguments for subcommands that act on a sample identified by its hash
// (`rescan`, `download`). Callers check it with `HashArg::valid` before use.
#[derive(Clone, Parser)]
struct HashArg {
    pub hash: String,
}
// Arguments for `search`: the free-form query forwarded to the VirusTotal client.
#[derive(Clone, Parser)]
struct StringArg {
    pub search: String,
}
impl HashArg {
    /// Returns `true` if `hash` has the shape of an MD5, SHA-1, or SHA-256 hex digest:
    /// exactly 32, 40, or 64 characters, every one an ASCII hex digit (either case).
    ///
    /// The character check matters beyond input hygiene: `download` writes the sample to a
    /// file named after the hash, so values containing `/`, `.` or other non-hex characters
    /// must be rejected rather than used as a path.
    pub fn valid(&self) -> bool {
        // All-ASCII check first makes byte length equal character length.
        self.hash.bytes().all(|b| b.is_ascii_hexdigit()) && matches!(self.hash.len(), 32 | 40 | 64)
    }
}
// Subcommands; the `///` comments on each variant become the subcommand help in `--help`.
#[derive(Subcommand, Clone)]
enum Action {
    /// Upload a file to VirusTotal for analysis
    Submit(SubmitFileArg),
    /// Show how many antivirus engines flag a file or hash
    GetReport(FileReportArg),
    /// Ask VirusTotal to re-analyse a sample identified by its hash
    Rescan(HashArg),
    /// Download a sample by hash into a file named after the hash in the current directory
    Download(HashArg),
    /// Search VirusTotal and print the hashes of matching files
    Search(StringArg),
}
impl Action {
    /// Runs the selected subcommand against the VirusTotal API using `client`.
    ///
    /// # Errors
    ///
    /// Returns an error if a local file cannot be read or written, if a hash argument fails
    /// [`HashArg::valid`], if `get-report` is given neither a file nor a hash, if a report
    /// cannot be serialized to JSON, or if any VirusTotal client call fails.
    async fn execute(&self, client: &VirusTotalClient) -> Result<()> {
        match self {
            Action::Submit(arg) => {
                let contents = std::fs::read(&arg.file)?;
                // `to_string_lossy` rather than `to_str().unwrap()`: a file name that is not
                // valid UTF-8 must not crash the program; bad sequences become U+FFFD.
                let file_name = arg
                    .file
                    .file_name()
                    .map(|name| name.to_string_lossy().into_owned());
                let response = client.submit(contents, file_name).await?;
                println!("Submitted, request id {}", response.id);
            }
            Action::GetReport(arg) => {
                if arg.file.is_none() && arg.hash.is_none() {
                    bail!("Nothing to do, neither file path nor hash were specified.");
                }
                let output = arg.output.as_deref();
                if let Some(input_file) = &arg.file {
                    // Only the file's SHA-256 is sent; the contents are not uploaded.
                    let sha256 = sha256_hex_of_file(input_file)?;
                    report_for_hash(client, &sha256, &format!("{input_file:?}"), output).await?;
                }
                if let Some(input_hash) = &arg.hash {
                    if output.is_some() && arg.file.is_some() {
                        // `--output` holds one report and `--file` takes precedence; say so
                        // instead of silently dropping the hash the user asked about.
                        eprintln!(
                            "Warning: --hash {input_hash} ignored; --output holds a single report, which was written for --file."
                        );
                    } else {
                        report_for_hash(client, input_hash, input_hash, output).await?;
                    }
                }
            }
            Action::Rescan(arg) => {
                if !arg.valid() {
                    bail!("Hash {} isn't an MD5, SHA-1, or SHA-256 hash.", arg.hash);
                }
                let response = client.request_rescan(&arg.hash).await?;
                println!("Rescan for {} requested: {}", arg.hash, response.id);
            }
            Action::Download(arg) => {
                // Validate first: the hash is used verbatim as the output file name below.
                if !arg.valid() {
                    bail!("Hash {} isn't an MD5, SHA-1, or SHA-256 hash.", arg.hash);
                }
                let response = client.download(&arg.hash).await?;
                std::fs::write(&arg.hash, response)?;
            }
            Action::Search(arg) => {
                let response = client.search(&arg.search).await?;
                if response.hashes.is_empty() {
                    println!("Nothing found.");
                } else {
                    for hash in response.hashes {
                        println!("{hash}");
                    }
                }
            }
        }
        Ok(())
    }
}

/// Reads the file at `path` and returns the hex-encoded SHA-256 digest of its contents.
fn sha256_hex_of_file(path: &std::path::Path) -> Result<String> {
    let contents = std::fs::read(path)?;
    Ok(hex::encode(Sha256::digest(&contents)))
}

/// Fetches the VirusTotal report for `hash` and prints its detection summary tagged with
/// `label`. When `output` is `Some`, the report is first saved there as JSON.
async fn report_for_hash(
    client: &VirusTotalClient,
    hash: &str,
    label: &str,
    output: Option<&std::path::Path>,
) -> Result<()> {
    let response = client.get_report(hash).await?;
    if let Some(dest) = output {
        let report = serde_json::to_string(&response)?;
        std::fs::write(dest, report)?;
    }
    println!(
        "AVs with detection: {} of {} for {label}",
        response.attributes.last_analysis_stats.malicious,
        response.attributes.last_analysis_stats.av_count()
    );
    Ok(())
}
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<ExitCode> {
let args = Args::parse();
let client: VirusTotalClient = args.key.into();
args.action.execute(&client).await?;
Ok(ExitCode::SUCCESS)
}