malwaredb_virustotal/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![deny(missing_docs)]
5#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7#![forbid(unsafe_code)]
8
/// Data types shared by several of the report and request types
pub mod common;
/// Logic for parsing the domain report data from Virus Total
pub mod domainreport;
/// Pre-defined error types for Virus Total allowing for error comparison.
/// <https://virustotal.readme.io/reference/errors>
pub mod errors;
/// Logic for parsing the file report data from Virus Total
pub mod filereport;
/// Logic for searching for files based on types, submission, and attributes
pub mod filesearch;
/// Logic for parsing the IP report data from Virus Total
pub mod ipreport;
22
23use crate::common::{RecordType, ReportRequestResponse, ReportResponseHeader, RescanRequestData};
24use crate::domainreport::DomainAttributes;
25use crate::errors::VirusTotalError;
26use crate::filereport::ScanResultAttributes;
27use crate::filesearch::FileSearchResponse;
28use crate::ipreport::IPAttributes;
29
30use std::borrow::Cow;
31use std::fmt::{Debug, Display, Formatter};
32use std::path::Path;
33use std::str::FromStr;
34
35use bytes::Bytes;
36use reqwest::header::{HeaderMap, HeaderValue};
37use reqwest::multipart::{Form, Part};
38use serde::{Deserialize, Serialize, Serializer};
39use zeroize::{Zeroize, ZeroizeOnDrop};
40
/// 32 MiB. Submissions of at least this many bytes are sent to the one-time URL obtained from
/// `VirusTotalClient::get_upload_url` instead of the regular `/files` endpoint.
const THIRTY_TWO_MEGABYTES: u64 = 32 << 20;
42
/// Virus Total client object
// The only state is the API key. `Zeroize`/`ZeroizeOnDrop` wipe it from memory when the client
// is dropped, and `Debug`/`Serialize` are implemented by hand so the key is not printed and, by
// default, not serialized.
#[derive(Clone, Deserialize, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct VirusTotalClient {
    /// The API key used to interact with Virus Total
    // With the `clap` feature the key comes from `--key` or the `VT_API_KEY` environment
    // variable; when deserializing, `vt_api_key` is accepted as an alternative field name.
    #[cfg_attr(feature = "clap", arg(long, env = "VT_API_KEY"))]
    #[serde(alias = "vt_api_key")]
    key: String,
}
52
53impl Debug for VirusTotalClient {
54    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55        write!(f, "VirusTotal Client v{}", env!("CARGO_PKG_VERSION"))
56    }
57}
58
59impl Serialize for VirusTotalClient {
60    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
61    where
62        S: Serializer,
63    {
64        #[cfg(feature = "unsafe-serialization")]
65        return serializer.serialize_str(&self.key);
66
67        #[cfg(not(feature = "unsafe-serialization"))]
68        serializer.serialize_str("your-api-key-here")
69    }
70}
71
72impl VirusTotalClient {
73    /// Header used to send the API key to Virus Total
74    const API_KEY: &'static str = "x-apikey";
75
76    /// Length of the API key
77    pub const KEY_LEN: usize = 64;
78
79    /// New Virus Total client given an API key which is assumed to be valid.
80    #[must_use]
81    pub fn new(key: String) -> Self {
82        Self { key }
83    }
84
85    /// Generate a client which already knows to send the API key, and asks for gzip responses.
86    #[inline]
87    fn client(&self) -> Result<reqwest::Client, VirusTotalError> {
88        let mut headers = HeaderMap::new();
89        headers.insert(
90            VirusTotalClient::API_KEY,
91            HeaderValue::from_str(&self.key).unwrap(),
92        );
93
94        reqwest::ClientBuilder::new()
95            .gzip(true)
96            .default_headers(headers)
97            .build()
98            .map_err(|e| {
99                #[cfg(feature = "tracing")]
100                tracing::error!("Error creating VirusTotal client: {e}");
101                e.into()
102            })
103    }
104
105    /// Get the unparsed report from Virus Total for a known type.
106    ///
107    /// File: report given an MD5, SHA-1, or SHA-256 hash
108    /// Domain: a fully qualified domain name
109    /// IP address: an ip address
110    ///
111    /// # Errors
112    ///
113    /// Will return an error if there is a networking problem.
114    #[inline]
115    pub async fn get_report_raw(
116        &self,
117        record_type: RecordType,
118        resource: &str,
119    ) -> Result<Bytes, VirusTotalError> {
120        self.other(&format!("{record_type}/{resource}")).await
121    }
122
123    /// Get a parsed file report from Virus Total for an MD5, SHA-1, or SHA-256 hash, which is assumed to be valid.
124    ///
125    /// # Errors
126    ///
127    /// Will return an error if there is a networking problem or if the response wasn't expected.
128    pub async fn get_file_report(
129        &self,
130        file_hash: &str,
131    ) -> Result<ReportResponseHeader<ScanResultAttributes>, VirusTotalError> {
132        let body = self.get_report_raw(RecordType::File, file_hash).await?;
133        let json_response = String::from_utf8(body.to_ascii_lowercase())
134            .map_err(|_e| VirusTotalError::UTF8Error(body.to_vec()))?;
135        let report: ReportRequestResponse<ReportResponseHeader<ScanResultAttributes>> =
136            VirusTotalError::parse_json(&json_response)?;
137
138        match report {
139            ReportRequestResponse::Data(data) => Ok(data),
140            ReportRequestResponse::Error(error) => Err(error),
141        }
142    }
143
144    /// Request Virus Total rescan a file for an MD5, SHA-1, or SHA-256 hash, and receive the unparsed response
145    ///
146    /// # Errors
147    ///
148    /// Will return an error if there is a networking problem.
149    #[inline]
150    pub async fn request_file_rescan_raw(&self, file_hash: &str) -> Result<Bytes, VirusTotalError> {
151        self.request_rescan_raw(RecordType::File, file_hash).await
152    }
153
154    /// Request Virus Total rescan a file for an MD5, SHA-1, or SHA-256 hash, which is assumed to be valid.
155    ///
156    /// ```rust,no_run
157    /// use malwaredb_virustotal::VirusTotalClient;
158    ///
159    /// // Use of `.unwrap()` for demonstration, don't actually do this.
160    /// let client = VirusTotalClient::new(std::env::var("VT_API_KEY").unwrap());
161    /// # tokio_test::block_on(async {
162    /// let response = client.request_file_rescan("abc91ba39ea3220d23458f8049ed900c16ce1023").await.unwrap();
163    /// assert_eq!(response.rescan_type, "analysis");
164    /// # })
165    /// ```
166    ///
167    /// # Errors
168    ///
169    /// Will return an error if there is a networking problem or if the response wasn't expected.
170    pub async fn request_file_rescan(
171        &self,
172        file_hash: &str,
173    ) -> Result<RescanRequestData, VirusTotalError> {
174        let body = self.request_file_rescan_raw(file_hash).await?;
175        let json_response = String::from_utf8(body.to_ascii_lowercase())
176            .map_err(|_e| VirusTotalError::UTF8Error(body.to_vec()))?;
177        let report: ReportRequestResponse<RescanRequestData> =
178            VirusTotalError::parse_json(&json_response)?;
179
180        match report {
181            ReportRequestResponse::Data(data) => Ok(data),
182            ReportRequestResponse::Error(error) => Err(error),
183        }
184    }
185
186    /// Submit a file by path to Virus Total and receive the unparsed response.
187    ///
188    /// # Errors
189    ///
190    /// Will return an error if there is a networking problem.
191    #[inline]
192    pub async fn submit_file_path_raw<P>(&self, path: P) -> Result<Bytes, VirusTotalError>
193    where
194        P: AsRef<Path>,
195    {
196        let client = self.client()?;
197
198        #[cfg(feature = "tokio")]
199        let file = tokio::fs::File::open(&path).await.map_err(|e| {
200            #[cfg(feature = "tracing")]
201            tracing::error!("Error opening file for VirusTotal submission: {e}");
202            VirusTotalError::IOError(e.to_string())
203        })?;
204
205        #[cfg(not(feature = "tokio"))]
206        let file = std::fs::File::open(&path).map_err(|e| {
207            #[cfg(feature = "tracing")]
208            tracing::error!("Error opening file for VirusTotal submission: {e}");
209            VirusTotalError::IOError(e.to_string())
210        })?;
211
212        #[cfg(feature = "tokio")]
213        let size = file
214            .metadata()
215            .await
216            .map_err(|e| {
217                #[cfg(feature = "tracing")]
218                tracing::error!("Error getting file size: {e}");
219                VirusTotalError::IOError(e.to_string())
220            })?
221            .len();
222
223        #[cfg(not(feature = "tokio"))]
224        let size = file
225            .metadata()
226            .map_err(|e| {
227                #[cfg(feature = "tracing")]
228                tracing::error!("Error getting file size: {e}");
229                VirusTotalError::IOError(e.to_string())
230            })?
231            .len();
232
233        let url = if size >= THIRTY_TWO_MEGABYTES {
234            self.get_upload_url().await?
235        } else {
236            "https://www.virustotal.com/api/v3/files".to_string()
237        };
238
239        let form = Form::new()
240            .file("file", path)
241            .await
242            .map_err(|e| VirusTotalError::IOError(e.to_string()))?;
243
244        client
245            .post(url)
246            .header("accept", "application/json")
247            .multipart(form)
248            .send()
249            .await
250            .map_err(|e| {
251                #[cfg(feature = "tracing")]
252                tracing::error!("Error submitting VirusTotal file: {e}");
253                e
254            })?
255            .bytes()
256            .await
257            .map_err(|e| {
258                #[cfg(feature = "tracing")]
259                tracing::error!("Error parsing VirusTotal file submission response: {e}");
260                e.into()
261            })
262    }
263
264    /// Submit a file by path to Virus Total and receive a parsed response.
265    ///
266    /// # Errors
267    ///
268    /// Will return an error if there is a networking problem or if the response wasn't expected.
269    pub async fn submit_file_path<P: AsRef<Path>>(
270        &self,
271        path: P,
272    ) -> Result<RescanRequestData, VirusTotalError> {
273        let body = self.submit_file_path_raw(path).await?;
274        let json_response = String::from_utf8(body.to_ascii_lowercase())
275            .map_err(|_e| VirusTotalError::UTF8Error(body.to_vec()))?;
276        let report: ReportRequestResponse<RescanRequestData> =
277            VirusTotalError::parse_json(&json_response)?;
278
279        match report {
280            ReportRequestResponse::Data(data) => Ok(data),
281            ReportRequestResponse::Error(error) => Err(error),
282        }
283    }
284
285    /// Submit bytes to Virus Total and receive the unparsed response.
286    ///
287    /// # Errors
288    ///
289    /// Will return an error if there is a networking problem.
290    #[inline]
291    pub async fn submit_bytes_raw<N: Into<Cow<'static, str>>>(
292        &self,
293        data: Vec<u8>,
294        name: N,
295    ) -> Result<Bytes, VirusTotalError> {
296        let client = self.client()?;
297
298        // It's unfortunate that we had to take ownership of the bytes. This is because `Path::new()`
299        // is private in `reqwest`. There is no other way to get the size.
300        let url = if data.len() as u64 >= THIRTY_TWO_MEGABYTES {
301            self.get_upload_url().await?
302        } else {
303            "https://www.virustotal.com/api/v3/files".to_string()
304        };
305
306        let form = Form::new().part(
307            "file",
308            Part::bytes(data)
309                .file_name(name)
310                .mime_str("application/octet-stream")?,
311        );
312
313        client
314            .post(url)
315            .header("accept", "application/json")
316            .multipart(form)
317            .send()
318            .await
319            .map_err(|e| {
320                #[cfg(feature = "tracing")]
321                tracing::error!("Error submitting VirusTotal bytes: {e}");
322                e
323            })?
324            .bytes()
325            .await
326            .map_err(|e| {
327                #[cfg(feature = "tracing")]
328                tracing::error!("Error parsing VirusTotal bytes submission response: {e}");
329                e.into()
330            })
331    }
332
333    /// Submit bytes to Virus Total and receive a parsed response.
334    ///
335    /// # Errors
336    ///
337    /// Will return an error if there is a networking problem or if the response wasn't expected.
338    pub async fn submit_bytes<N: Into<Cow<'static, str>>>(
339        &self,
340        data: Vec<u8>,
341        name: N,
342    ) -> Result<RescanRequestData, VirusTotalError> {
343        let body = self.submit_bytes_raw(data, name).await?;
344        let json_response = String::from_utf8(body.to_ascii_lowercase())
345            .map_err(|_e| VirusTotalError::UTF8Error(body.to_vec()))?;
346        let report: ReportRequestResponse<RescanRequestData> =
347            VirusTotalError::parse_json(&json_response)?;
348
349        match report {
350            ReportRequestResponse::Data(data) => Ok(data),
351            ReportRequestResponse::Error(error) => Err(error),
352        }
353    }
354
355    /// Get a special one-time URL endpoint for submitting files larger than 32 MB
356    ///
357    /// # Errors
358    ///
359    /// Will return an error if there is a networking problem or if the response wasn't expected.
360    #[inline]
361    pub async fn get_upload_url(&self) -> Result<String, VirusTotalError> {
362        let response = self.other("files/upload_url").await?;
363        let response = String::from_utf8(response.to_vec())
364            .map_err(|_e| VirusTotalError::UTF8Error(response.to_vec()))?;
365        let response = serde_json::from_str::<serde_json::Value>(&response)
366            .map_err(|_e| VirusTotalError::JsonError(response))?;
367        let url = response["data"]
368            .as_str()
369            .ok_or(VirusTotalError::NoURLReturned)?;
370        Ok(url.to_string())
371    }
372
373    /// Download a file from Virus Total, requires Virus Total Premium!
374    ///
375    /// ```rust,no_run
376    /// use malwaredb_virustotal::VirusTotalClient;
377    ///
378    /// // Use of `.unwrap()` for demonstration, don't actually do this.
379    /// let client = VirusTotalClient::new(std::env::var("VT_API_KEY").unwrap());
380    /// # tokio_test::block_on(async {
381    /// let file_contents = client.download("abc91ba39ea3220d23458f8049ed900c16ce1023").await.unwrap();
382    /// # })
383    /// ```
384    ///
385    /// # Errors
386    ///
387    /// Will return an error if there is a networking problem.
388    pub async fn download(&self, file_hash: &str) -> Result<Vec<u8>, VirusTotalError> {
389        let client = self.client()?;
390        let response = client
391            .get(format!(
392                "https://www.virustotal.com/api/v3/files/{file_hash}/download"
393            ))
394            .send()
395            .await?;
396
397        if !response.status().is_success() {
398            let body = response.bytes().await?;
399            let json_response = String::from_utf8(body.to_ascii_lowercase())
400                .map_err(|_e| VirusTotalError::UTF8Error(body.to_vec()))?;
401
402            let error: ReportRequestResponse<RescanRequestData> =
403                VirusTotalError::parse_json(&json_response)?;
404            return if let ReportRequestResponse::Error(error) = error {
405                Err(error)
406            } else {
407                // Should never happen, since we're only here if some error occurred.
408                Err(VirusTotalError::UnknownError)
409            };
410        }
411
412        Ok(response
413            .bytes()
414            .await
415            .map_err(|e| {
416                #[cfg(feature = "tracing")]
417                tracing::error!("Error parsing VirusTotal file response: {e}");
418                e
419            })?
420            .to_vec())
421    }
422
423    /// Search Virus Total for files matching some search parameters, receive unparsed response.
424    /// Requires VT Premium!
425    ///
426    /// # Errors
427    ///
428    /// Will return an error if there is a networking problem.
429    #[inline]
430    pub async fn search_raw<Q: Display>(&self, query: Q) -> Result<Bytes, VirusTotalError> {
431        let url = format!(
432            "https://www.virustotal.com/vtapi/v2/file/search?apikey={}&query={query}",
433            self.key.as_str()
434        );
435
436        self.client()?
437            .get(url)
438            .send()
439            .await?
440            .bytes()
441            .await
442            .map_err(|e| {
443                #[cfg(feature = "tracing")]
444                tracing::error!("Error parsing VirusTotal search result: {e}");
445                e.into()
446            })
447    }
448
449    /// Search Virus Total for files matching some search parameters. Requires VT Premium!
450    /// For more information see <https://virustotal.readme.io/v2.0/reference/file-search>.
451    /// Note: This uses the V2 API.
452    /// Example:
453    ///
454    /// ```rust,no_run
455    /// use malwaredb_virustotal::{VirusTotalClient, filesearch::flags};
456    ///
457    /// // Use of `.unwrap()` for demonstration, don't actually do this.
458    /// let client = VirusTotalClient::new(std::env::var("VT_API_KEY").unwrap());
459    /// // Find PDFs, which are benign, have a fill-able form, and Javascript, first seen yesterday
460    /// # tokio_test::block_on(async {
461    /// #[cfg(not(feature = "chrono"))]
462    /// let result = client.search(flags::FileType::Pdf + flags::BENIGN + flags::Tag::PdfForm + flags::Tag::PdfJs).await.unwrap();
463    /// #[cfg(feature = "chrono")]
464    /// let result = client.search(flags::FileType::Pdf + flags::BENIGN + flags::Tag::PdfForm + flags::Tag::PdfJs + flags::FirstSubmission::days(1)).await.unwrap();
465    /// # })
466    /// ```
467    ///
468    /// # Errors
469    ///
470    /// Will return an error if there is a networking problem or if the response wasn't expected.
471    pub async fn search<Q: Display>(
472        &self,
473        query: Q,
474    ) -> Result<FileSearchResponse, VirusTotalError> {
475        let body = self.search_raw(&query).await?;
476        let json_response = String::from_utf8(body.to_ascii_lowercase())
477            .map_err(|_e| VirusTotalError::UTF8Error(body.to_vec()))?;
478        let response: FileSearchResponse = VirusTotalError::parse_json(&json_response)?;
479
480        let response = FileSearchResponse {
481            response_code: response.response_code,
482            offset: response.offset,
483            hashes: response.hashes,
484            query: query.to_string(),
485            verbose_msg: response.verbose_msg,
486        };
487        Ok(response)
488    }
489
490    /// Search Virus Total for files matching some search parameters. Requires VT Premium!
491    /// Use this to continue from a prior search for the next 300 results. Requires parsed response
492    /// via [`Self::search()`]
493    ///
494    /// # Errors
495    ///
496    /// Will return an error if there is a networking problem or if the response wasn't expected.
497    pub async fn search_offset(
498        &self,
499        prior: &FileSearchResponse,
500    ) -> Result<FileSearchResponse, VirusTotalError> {
501        if let Some(offset) = prior.offset.as_ref() {
502            let url = format!(
503                "https://www.virustotal.com/vtapi/v2/file/search?apikey={}&query={}&offset={}",
504                self.key.as_str(),
505                prior.query,
506                offset
507            );
508
509            let body = self.client()?.get(url).send().await?.bytes().await?;
510            let json_response = String::from_utf8(body.to_ascii_lowercase())
511                .map_err(|_e| VirusTotalError::UTF8Error(body.to_vec()))?;
512            let response: FileSearchResponse = VirusTotalError::parse_json(&json_response)?;
513
514            let response = FileSearchResponse {
515                response_code: response.response_code,
516                offset: response.offset,
517                hashes: response.hashes,
518                query: prior.query.clone(),
519                verbose_msg: response.verbose_msg,
520            };
521            Ok(response)
522        } else {
523            Err(VirusTotalError::NonPaginatedResults)
524        }
525    }
526
527    /// Get a Virus Total report for a domain, returning the parsed response
528    ///
529    /// # Errors
530    ///
531    /// Will return an error if there is a networking problem or if the response wasn't expected.
532    pub async fn get_domain_report(
533        &self,
534        domain: &str,
535    ) -> Result<ReportResponseHeader<DomainAttributes>, VirusTotalError> {
536        let body = self.get_report_raw(RecordType::Domain, domain).await?;
537        let json_response = String::from_utf8(body.to_ascii_lowercase())
538            .map_err(|_e| VirusTotalError::UTF8Error(body.to_vec()))?;
539        let report: ReportRequestResponse<ReportResponseHeader<DomainAttributes>> =
540            VirusTotalError::parse_json(&json_response)?;
541
542        match report {
543            ReportRequestResponse::Data(data) => Ok(data),
544            ReportRequestResponse::Error(error) => Err(error),
545        }
546    }
547
548    /// Request rescan of a domain and receive parsed response
549    ///
550    /// # Errors
551    ///
552    /// Will return an error if there is a networking problem or if the response wasn't expected.
553    pub async fn request_domain_rescan(
554        &self,
555        domain: &str,
556    ) -> Result<RescanRequestData, VirusTotalError> {
557        let body = self.request_domain_rescan_raw(domain).await?;
558        let json_response = String::from_utf8(body.to_ascii_lowercase())
559            .map_err(|_e| VirusTotalError::UTF8Error(body.to_vec()))?;
560        let report: ReportRequestResponse<RescanRequestData> =
561            VirusTotalError::parse_json(&json_response)?;
562
563        match report {
564            ReportRequestResponse::Data(data) => Ok(data),
565            ReportRequestResponse::Error(error) => Err(error),
566        }
567    }
568
569    /// Request rescan of a domain and receive the unparsed response
570    ///
571    /// # Errors
572    ///
573    /// Will return an error if there is a networking problem.
574    #[inline]
575    pub async fn request_domain_rescan_raw(&self, domain: &str) -> Result<Bytes, VirusTotalError> {
576        self.request_rescan_raw(RecordType::Domain, domain).await
577    }
578
579    /// Request Virus Total rescan of a file or domain, internally used
580    ///
581    /// # Errors
582    ///
583    /// Will return an error if there is a networking problem.
584    #[inline]
585    async fn request_rescan_raw(
586        &self,
587        rescan_type: RecordType,
588        identifier: &str,
589    ) -> Result<Bytes, VirusTotalError> {
590        self.client()?
591            .post(format!(
592                "https://www.virustotal.com/api/v3/{rescan_type}/{identifier}/analyse"
593            ))
594            .header("content-length", "0")
595            .send()
596            .await
597            .map_err(|e| {
598                #[cfg(feature = "tracing")]
599                tracing::error!("Error requesting VirusTotal rescan: {e}");
600                e
601            })?
602            .bytes()
603            .await
604            .map_err(|e| {
605                #[cfg(feature = "tracing")]
606                tracing::error!("Error parsing VirusTotal rescan response: {e}");
607                e.into()
608            })
609    }
610
611    /// Get a Virus Total report for an IP address, returning the parsed response
612    ///
613    /// # Errors
614    ///
615    /// Will return an error if there is a networking problem or if the response wasn't expected.
616    pub async fn get_ip_report(
617        &self,
618        ip: &str,
619    ) -> Result<ReportResponseHeader<IPAttributes>, VirusTotalError> {
620        let body = self.get_report_raw(RecordType::IPAddress, ip).await?;
621        let json_response = String::from_utf8(body.to_ascii_lowercase())
622            .map_err(|_e| VirusTotalError::UTF8Error(body.to_vec()))?;
623        let report: ReportRequestResponse<ReportResponseHeader<IPAttributes>> =
624            VirusTotalError::parse_json(&json_response)?;
625
626        match report {
627            ReportRequestResponse::Data(data) => Ok(data),
628            ReportRequestResponse::Error(error) => Err(error),
629        }
630    }
631
632    /// Request rescan of an IP address and receive the unparsed response
633    ///
634    /// # Errors
635    ///
636    /// Will return an error if there is a networking problem.
637    #[inline]
638    pub async fn request_ip_rescan_raw(&self, ip: &str) -> Result<Bytes, VirusTotalError> {
639        self.request_rescan_raw(RecordType::IPAddress, ip).await
640    }
641
642    /// Request rescan of an IP address and receive parsed response
643    ///
644    /// # Errors
645    ///
646    /// Will return an error if there is a networking problem or if the response wasn't expected.
647    pub async fn request_ip_rescan(&self, ip: &str) -> Result<RescanRequestData, VirusTotalError> {
648        let body = self.request_ip_rescan_raw(ip).await?;
649        let json_response = String::from_utf8(body.to_ascii_lowercase())
650            .map_err(|_e| VirusTotalError::UTF8Error(body.to_vec()))?;
651        let report: ReportRequestResponse<RescanRequestData> =
652            VirusTotalError::parse_json(&json_response)?;
653
654        match report {
655            ReportRequestResponse::Data(data) => Ok(data),
656            ReportRequestResponse::Error(error) => Err(error),
657        }
658    }
659
660    /// Since this crate doesn't support every Virus Total feature, this function can receive a
661    /// URL fragment and return the response.
662    ///
663    /// # Errors
664    ///
665    /// Will return an error if there is a networking problem.
666    #[inline]
667    pub async fn other(&self, url: &str) -> Result<Bytes, VirusTotalError> {
668        let client = self.client()?;
669        client
670            .get(format!("https://www.virustotal.com/api/v3/{url}"))
671            .send()
672            .await
673            .map_err(|e| {
674                #[cfg(feature = "tracing")]
675                tracing::error!("Error requesting VirusTotal other: {e}");
676                VirusTotalError::NetworkError(e.to_string())
677            })?
678            .bytes()
679            .await
680            .map_err(|e| {
681                #[cfg(feature = "tracing")]
682                tracing::error!("Error parsing VirusTotal other response: {e}");
683                VirusTotalError::NetworkError(e.to_string())
684            })
685    }
686}
687
688/// Get a Virus Total client from a key, checking that the key is the expected length.
689impl FromStr for VirusTotalClient {
690    type Err = &'static str;
691
692    fn from_str(key: &str) -> Result<Self, Self::Err> {
693        if key.len() == VirusTotalClient::KEY_LEN {
694            Ok(Self {
695                key: key.to_string(),
696            })
697        } else {
698            Err("Invalid API key length")
699        }
700    }
701}
702
703impl From<String> for VirusTotalClient {
704    fn from(value: String) -> Self {
705        VirusTotalClient::new(value)
706    }
707}
708
#[cfg(test)]
mod test {
    use super::*;
    use sha2::{Digest, Sha256};

    const ELF: &[u8] = include_bytes!("../testdata/elf_haiku_x86");

    /// SHA-256 of a sample that Virus Total is expected to already have a report for.
    const HASH: &str = "fff40032c3dc062147c530e3a0a5c7e6acda4d1f1369fbc994cddd3c19a2de88";

    /// Live round-trip against the Virus Total API. Ignored by default because it needs a real
    /// key in `VT_API_KEY`; run it explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "don't test with VT API key in CI"]
    async fn api() {
        let Ok(api_key) = std::env::var("VT_API_KEY") else {
            panic!("`VT_API_KEY` not set!")
        };
        let client = VirusTotalClient::new(api_key);

        // Known sample: fetch its report, then ask for a rescan.
        let report = client
            .get_file_report(HASH)
            .await
            .expect("failed to get or parse VT scan report");
        assert!(report.attributes.last_analysis_results.len() > 10);

        let rescan = client
            .request_file_rescan(HASH)
            .await
            .expect("failed to get or parse VT rescan response");
        assert_eq!(rescan.rescan_type, "analysis");

        // Upload the bundled test binary.
        client
            .submit_bytes(ELF.to_vec(), "elf_haiku_x86".to_string())
            .await
            .unwrap();

        // A nonsense hash must be reported as not found.
        let Err(err) = client.get_file_report("AABBCCDD").await else {
            unreachable!("No way this should work");
        };
        assert_eq!(err, VirusTotalError::NotFoundError);

        // Downloads need Premium; without it the API is expected to answer "forbidden".
        match client
            .download("abc91ba39ea3220d23458f8049ed900c16ce1023")
            .await
        {
            Ok(bytes) => assert_eq!(
                hex::encode(Sha256::digest(&bytes)),
                "de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740"
            ),
            Err(e) => assert_eq!(e, VirusTotalError::ForbiddenError),
        }

        let domain = client
            .get_domain_report("haiku-os.org")
            .await
            .unwrap_or_else(|e| panic!("Domain report error: {e}"));
        println!("{:?}", domain.attributes.extra);
        assert!(domain.attributes.extra.is_empty());

        let domain_rescan = client
            .request_domain_rescan("haiku-os.org")
            .await
            .unwrap_or_else(|e| panic!("Domain rescan error: {e}"));
        assert!(!domain_rescan.links.is_empty());

        let ip = client
            .get_ip_report("23.53.35.49" /* phobos.apple.com */)
            .await
            .unwrap_or_else(|e| panic!("IP address report error: {e}"));
        println!("{:?}", ip.attributes.extra);
        assert!(ip.attributes.extra.is_empty());
    }
}