malwaredb_client/
blocking.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use std::fmt::{Debug, Formatter};
4use std::path::{Path, PathBuf};
5
6use crate::{get_config_path, MDB_CLIENT_ERROR_CONTEXT};
7use malwaredb_types::exec::pe32::EXE;
8
9use anyhow::{bail, ensure, Context, Result};
10use base64::engine::general_purpose;
11use base64::Engine;
12use fuzzyhash::FuzzyHash;
13use malwaredb_api::{
14    GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType, Report, SearchRequest,
15    SearchRequestParameters, SearchResponse, SearchType, ServerInfo, ServerResponse,
16    SimilarSamplesResponse, Sources, SupportedFileTypes,
17};
18use malwaredb_lzjd::{LZDict, Murmur3HashState};
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21use tlsh_fixed::TlshBuilder;
22use tracing::{error, info, warn};
23use zeroize::{Zeroize, ZeroizeOnDrop};
24
25/// Blocking Malware DB Client Configuration and connection which requires the `blocking` feature
26#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
27#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
28pub struct MdbClient {
29    /// URL of the Malware DB server, including http and port number, ending without a slash
30    pub url: String,
31
32    /// User's API key for Malware DB
33    api_key: String,
34
35    /// Blocking http client which stores the optional server certificate
36    #[zeroize(skip)]
37    #[serde(skip)]
38    client: reqwest::blocking::Client,
39
40    /// Server's certificate
41    #[cfg(target_os = "macos")]
42    #[zeroize(skip)]
43    #[serde(skip)]
44    cert: Option<crate::macos::CertificateData>,
45}
46
47impl MdbClient {
48    /// MDB Client from components, doesn't test connectivity
49    ///
50    /// # Errors
51    ///
52    /// Returns an error if a list of certificates was passed and any were not in the expected
53    /// DER or PEM format or could not be parsed.
54    ///
55    /// # Panics
56    ///
57    /// This method panics if called from within an async runtime.
58    pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
59        let mut url = url;
60        let url = if url.ends_with('/') {
61            url.pop();
62            url
63        } else {
64            url
65        };
66
67        let cert = if let Some(path) = cert_path {
68            Some((crate::path_load_cert(&path)?, path))
69        } else {
70            None
71        };
72
73        let builder = reqwest::blocking::ClientBuilder::new()
74            .gzip(true)
75            .zstd(true)
76            .use_rustls_tls()
77            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
78
79        let client = if let Some(((_cert_type, cert), _path)) = &cert {
80            builder.add_root_certificate(cert.clone()).build()
81        } else {
82            builder.build()
83        }?;
84
85        #[cfg(target_os = "macos")]
86        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
87            Some(crate::macos::CertificateData {
88                cert_type: *cert_type,
89                cert_bytes: std::fs::read(cert_path)?,
90            })
91        } else {
92            None
93        };
94
95        Ok(Self {
96            url,
97            api_key,
98            client,
99
100            #[cfg(target_os = "macos")]
101            cert,
102        })
103    }
104
105    /// Login to a server, optionally save the configuration file, and return a client object
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if the server URL, username, or password were incorrect, or if a network
110    /// issue occurred.
111    ///
112    /// # Panics
113    ///
114    /// This method panics if called from within an async runtime.
115    pub fn login(
116        url: String,
117        username: String,
118        password: String,
119        save: bool,
120        cert_path: Option<PathBuf>,
121    ) -> Result<Self> {
122        let mut url = url;
123        let url = if url.ends_with('/') {
124            url.pop();
125            url
126        } else {
127            url
128        };
129
130        let api_request = malwaredb_api::GetAPIKeyRequest {
131            user: username,
132            password,
133        };
134
135        let builder = reqwest::blocking::ClientBuilder::new()
136            .gzip(true)
137            .zstd(true)
138            .use_rustls_tls()
139            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
140
141        let cert = if let Some(path) = cert_path {
142            Some((crate::path_load_cert(&path)?, path))
143        } else {
144            None
145        };
146
147        let client = if let Some(((_cert_type, cert), _path)) = &cert {
148            builder.add_root_certificate(cert.clone()).build()
149        } else {
150            builder.build()
151        }?;
152
153        let res = client
154            .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
155            .json(&api_request)
156            .send()?
157            .json::<ServerResponse<GetAPIKeyResponse>>()
158            .context(MDB_CLIENT_ERROR_CONTEXT)?;
159
160        let res = match res {
161            ServerResponse::Success(res) => res,
162            ServerResponse::Error(err) => return Err(err.into()),
163        };
164
165        #[cfg(target_os = "macos")]
166        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
167            Some(crate::macos::CertificateData {
168                cert_type: *cert_type,
169                cert_bytes: std::fs::read(cert_path)?,
170            })
171        } else {
172            None
173        };
174
175        let client = MdbClient {
176            url,
177            api_key: res.key.clone(),
178            client,
179
180            #[cfg(target_os = "macos")]
181            cert,
182        };
183
184        let server_info = client.server_info()?;
185        if server_info.mdb_version > *crate::MDB_VERSION_SEMVER {
186            warn!(
187                "Server version {:?} is newer than client {:?}, consider updating.",
188                server_info.mdb_version,
189                crate::MDB_VERSION_SEMVER
190            );
191        }
192
193        if save {
194            if let Err(e) = client.save() {
195                error!("Login successful but failed to save config: {e}");
196                bail!("Login successful but failed to save config: {e}");
197            }
198        }
199        Ok(client)
200    }
201
202    /// Reset one's own API key to effectively logout & disable all clients who are using the key
203    ///
204    /// # Errors
205    ///
206    /// Returns an error if there was a network issue or the user wasn't properly logged in.
207    pub fn reset_key(&self) -> Result<()> {
208        let response = self
209            .client
210            .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
211            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
212            .send()
213            .context(MDB_CLIENT_ERROR_CONTEXT)?;
214        if !response.status().is_success() {
215            bail!("failed to reset API key, was it correct?");
216        }
217        Ok(())
218    }
219
220    /// Malware DB Client configuration loaded from a specified path
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if the configuration file cannot be read, possibly because it
225    /// doesn't exist or due to a permission error or a parsing error.
226    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
227        let name = path.as_ref().display();
228        let config =
229            std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
230        let cfg: MdbClient =
231            toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
232        Ok(cfg)
233    }
234
235    /// Malware DB Client configuration from user's home directory
236    ///
237    /// On macOS, it will attempt to load this information in the Keychain, which isn't required.
238    ///
239    /// # Errors
240    ///
241    /// Returns an error if the configuration file cannot be read, possibly because it
242    /// doesn't exist or due to a permission error or a parsing error.
243    ///
244    /// # Panics
245    ///
246    /// This method panics if called from within an async runtime.
247    pub fn load() -> Result<Self> {
248        #[cfg(target_os = "macos")]
249        {
250            if let Ok((api_key, url, cert)) = crate::macos::retrieve_credentials() {
251                let builder = reqwest::blocking::ClientBuilder::new()
252                    .gzip(true)
253                    .zstd(true)
254                    .use_rustls_tls()
255                    .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
256
257                let client = if let Some(cert) = &cert {
258                    builder.add_root_certificate(cert.as_cert()?).build()
259                } else {
260                    builder.build()
261                }?;
262
263                return Ok(Self {
264                    url,
265                    api_key,
266                    client,
267                    cert,
268                });
269            }
270        }
271
272        let path = get_config_path(false)?;
273        if path.exists() {
274            return Self::from_file(path);
275        }
276        bail!("config file not found")
277    }
278
279    /// Save Malware DB Client configuration to the user's home directory
280    ///
281    /// On macOS, it will attempt to save this information in the Keychain, which isn't required.
282    ///
283    /// # Errors
284    ///
285    /// Returns an error if there was a problem saving the configuration file.
286    pub fn save(&self) -> Result<()> {
287        let toml = toml::to_string(self)?;
288        let path = get_config_path(true)?;
289        std::fs::write(&path, toml)
290            .context(format!("failed to write mdb config to {}", path.display()))
291    }
292
293    /// Delete the Malware DB client configuration file
294    ///
295    /// # Errors
296    ///
297    /// Returns an error if there isn't a configuration file to delete, or if it cannot be deleted,
298    /// possibly due to a permissions error.
299    pub fn delete(&self) -> Result<()> {
300        #[cfg(target_os = "macos")]
301        crate::macos::clear_credentials();
302
303        let path = get_config_path(false)?;
304        if path.exists() {
305            std::fs::remove_file(&path).context(format!(
306                "failed to delete client config file {}",
307                path.display()
308            ))?;
309        }
310        Ok(())
311    }
312
313    // Actions of the client
314
315    /// Get information about the server, unauthenticated
316    ///
317    /// # Errors
318    ///
319    /// This may return an error if there's a network situation.
320    pub fn server_info(&self) -> Result<ServerInfo> {
321        let response = self
322            .client
323            .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
324            .send()?
325            .json::<ServerResponse<ServerInfo>>()
326            .context(MDB_CLIENT_ERROR_CONTEXT)?;
327
328        match response {
329            ServerResponse::Success(info) => Ok(info),
330            ServerResponse::Error(e) => Err(e.into()),
331        }
332    }
333
334    /// Get file types supported by the server, unauthenticated
335    ///
336    /// # Errors
337    ///
338    /// This may return an error if there's a network situation.
339    pub fn supported_types(&self) -> Result<SupportedFileTypes> {
340        let response = self
341            .client
342            .get(format!(
343                "{}{}",
344                self.url,
345                malwaredb_api::SUPPORTED_FILE_TYPES_URL
346            ))
347            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
348            .send()?
349            .json::<ServerResponse<SupportedFileTypes>>()
350            .context(MDB_CLIENT_ERROR_CONTEXT)?;
351
352        match response {
353            ServerResponse::Success(types) => Ok(types),
354            ServerResponse::Error(e) => Err(e.into()),
355        }
356    }
357
358    /// Get information about the user
359    ///
360    /// # Errors
361    ///
362    /// This may return an error if there's a network situation or if the user is not logged in
363    /// or not properly authorized to connect.
364    pub fn whoami(&self) -> Result<GetUserInfoResponse> {
365        let response = self
366            .client
367            .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
368            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
369            .send()?
370            .json::<ServerResponse<GetUserInfoResponse>>()
371            .context(MDB_CLIENT_ERROR_CONTEXT)?;
372
373        match response {
374            ServerResponse::Success(info) => Ok(info),
375            ServerResponse::Error(e) => Err(e.into()),
376        }
377    }
378
379    /// Get the sample labels known to the server
380    ///
381    /// # Errors
382    ///
383    /// This may return an error if there's a network situation or if the user is not logged in
384    /// or not properly authorized to connect.
385    pub fn labels(&self) -> Result<Labels> {
386        let response = self
387            .client
388            .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
389            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
390            .send()?
391            .json::<ServerResponse<Labels>>()
392            .context(MDB_CLIENT_ERROR_CONTEXT)?;
393
394        match response {
395            ServerResponse::Success(labels) => Ok(labels),
396            ServerResponse::Error(e) => Err(e.into()),
397        }
398    }
399
400    /// Get the sources available to the current user
401    ///
402    /// # Errors
403    ///
404    /// This may return an error if there's a network situation or if the user is not logged in
405    /// or not properly authorized to connect.
406    pub fn sources(&self) -> Result<Sources> {
407        let response = self
408            .client
409            .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
410            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
411            .send()?
412            .json::<ServerResponse<Sources>>()
413            .context(MDB_CLIENT_ERROR_CONTEXT)?;
414
415        match response {
416            ServerResponse::Success(sources) => Ok(sources),
417            ServerResponse::Error(e) => Err(e.into()),
418        }
419    }
420
421    /// Submit one file to Malware DB: provide the contents, file name, and source ID
422    ///
423    /// # Errors
424    ///
425    /// This may return an error if there's a network situation or if the user is not logged in
426    /// or not properly authorized to connect.
427    pub fn submit(
428        &self,
429        contents: impl AsRef<[u8]>,
430        file_name: String,
431        source_id: u32,
432    ) -> Result<bool> {
433        let mut hasher = Sha256::new();
434        hasher.update(&contents);
435        let result = hasher.finalize();
436
437        let encoded = general_purpose::STANDARD.encode(contents);
438
439        let payload = malwaredb_api::NewSampleB64 {
440            file_name,
441            source_id,
442            file_contents_b64: encoded,
443            sha256: hex::encode(result),
444        };
445
446        match self
447            .client
448            .post(format!(
449                "{}{}",
450                self.url,
451                malwaredb_api::UPLOAD_SAMPLE_JSON_URL
452            ))
453            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
454            .json(&payload)
455            .send()
456        {
457            Ok(res) => {
458                if !res.status().is_success() {
459                    info!("Code {} sending {}", res.status(), payload.file_name);
460                }
461                Ok(res.status().is_success())
462            }
463            Err(e) => {
464                let status: String = e
465                    .status()
466                    .map(|s| s.as_str().to_string())
467                    .unwrap_or_default();
468                error!("Error{status} sending {}: {e}", payload.file_name);
469                bail!(e.to_string())
470            }
471        }
472    }
473
474    /// Submit one file to Malware DB: provide the contents, file name, and source ID
475    /// Experimental! May be removed at any point.
476    ///
477    /// # Errors
478    ///
479    /// This may return an error if there's a network situation or if the user is not logged in
480    /// or not properly authorized to connect.
481    pub fn submit_as_cbor(
482        &self,
483        contents: impl AsRef<[u8]>,
484        file_name: String,
485        source_id: u32,
486    ) -> Result<bool> {
487        let mut hasher = Sha256::new();
488        hasher.update(&contents);
489        let result = hasher.finalize();
490
491        let payload = malwaredb_api::NewSampleBytes {
492            file_name,
493            source_id,
494            file_contents: contents.as_ref().to_vec(),
495            sha256: hex::encode(result),
496        };
497
498        let mut bytes = Vec::with_capacity(payload.file_contents.len());
499        ciborium::ser::into_writer(&payload, &mut bytes)?;
500
501        match self
502            .client
503            .post(format!(
504                "{}{}",
505                self.url,
506                malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
507            ))
508            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
509            .header("content-type", "application/cbor")
510            .body(bytes)
511            .send()
512        {
513            Ok(res) => {
514                if !res.status().is_success() {
515                    info!("Code {} sending {}", res.status(), payload.file_name);
516                }
517                Ok(res.status().is_success())
518            }
519            Err(e) => {
520                let status: String = e
521                    .status()
522                    .map(|s| s.as_str().to_string())
523                    .unwrap_or_default();
524                error!("Error{status} sending {}: {e}", payload.file_name);
525                bail!(e.to_string())
526            }
527        }
528    }
529
530    /// Search for a file based on partial hash and/or partial file name, returns a list of hashes
531    ///
532    /// # Errors
533    ///
534    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
535    pub fn partial_search(
536        &self,
537        partial_hash: Option<(PartialHashSearchType, String)>,
538        name: Option<String>,
539        response: PartialHashSearchType,
540        limit: u32,
541    ) -> Result<SearchResponse> {
542        let query = SearchRequest {
543            search: SearchType::Search(SearchRequestParameters {
544                partial_hash,
545                file_name: name,
546                response,
547                limit,
548                labels: None,
549                file_type: None,
550                magic: None,
551            }),
552        };
553
554        self.do_search_request(&query)
555    }
556
557    /// Search for a file based on partial hash and/or partial file name, labels, file type; returns a list of hashes
558    ///
559    /// # Errors
560    ///
561    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
562    #[allow(clippy::too_many_arguments)]
563    pub fn partial_search_labels_type(
564        &self,
565        partial_hash: Option<(PartialHashSearchType, String)>,
566        name: Option<String>,
567        response: PartialHashSearchType,
568        labels: Option<Vec<String>>,
569        file_type: Option<String>,
570        magic: Option<String>,
571        limit: u32,
572    ) -> Result<SearchResponse> {
573        let query = SearchRequest {
574            search: SearchType::Search(SearchRequestParameters {
575                partial_hash,
576                file_name: name,
577                response,
578                limit,
579                file_type,
580                magic,
581                labels,
582            }),
583        };
584
585        self.do_search_request(&query)
586    }
587
588    /// Return the next page from the search result
589    ///
590    /// # Errors
591    ///
592    /// Returns an error if there is a network problem, or pagination not available
593    pub fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
594        if let Some(uuid) = response.pagination {
595            let request = SearchRequest {
596                search: SearchType::Continuation(uuid),
597            };
598            return self.do_search_request(&request);
599        }
600
601        bail!("Pagination not available")
602    }
603
604    fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
605        ensure!(
606            query.is_valid(),
607            "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
608        );
609
610        let response = self
611            .client
612            .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
613            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
614            .json(query)
615            .send()?
616            .json::<ServerResponse<SearchResponse>>()
617            .context(MDB_CLIENT_ERROR_CONTEXT)?;
618
619        match response {
620            ServerResponse::Success(search) => Ok(search),
621            ServerResponse::Error(e) => Err(e.into()),
622        }
623    }
624
625    /// Retrieve sample by hash, optionally in the `CaRT` format
626    ///
627    /// # Errors
628    ///
629    /// This may return an error if there's a network situation or if the user is not logged in
630    /// or not properly authorized to connect.
631    pub fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
632        let api_endpoint = if cart {
633            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
634        } else {
635            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
636        };
637
638        let res = self
639            .client
640            .get(format!("{}{api_endpoint}", self.url))
641            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
642            .send()?;
643
644        if !res.status().is_success() {
645            bail!("Received code {}", res.status());
646        }
647
648        let body = res.bytes()?;
649        Ok(body.to_vec())
650    }
651
652    /// Fetch a report for a sample
653    ///
654    /// # Errors
655    ///
656    /// This may return an error if there's a network situation or if the user is not logged in
657    /// or not properly authorized to connect.
658    pub fn report(&self, hash: &str) -> Result<Report> {
659        let response = self
660            .client
661            .get(format!(
662                "{}{}/{hash}",
663                self.url,
664                malwaredb_api::SAMPLE_REPORT_URL
665            ))
666            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
667            .send()?
668            .json::<ServerResponse<Report>>()
669            .context(MDB_CLIENT_ERROR_CONTEXT)?;
670
671        match response {
672            ServerResponse::Success(report) => Ok(report),
673            ServerResponse::Error(e) => Err(e.into()),
674        }
675    }
676
677    /// Find similar samples in `MalwareDB` based on the contents of a given file.
678    /// This does not submit the sample to `MalwareDB`.
679    ///
680    /// # Errors
681    ///
682    /// This may return an error if there's a network situation or if the user is not logged in
683    /// or not properly authorized to connect.
684    pub fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
685        let mut hashes = vec![];
686        let ssdeep_hash = FuzzyHash::new(contents);
687
688        let build_hasher = Murmur3HashState::default();
689        let lzjd_str =
690            LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
691        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
692        hashes.push((
693            malwaredb_api::SimilarityHashType::SSDeep,
694            ssdeep_hash.to_string(),
695        ));
696
697        let mut builder = TlshBuilder::new(
698            tlsh_fixed::BucketKind::Bucket256,
699            tlsh_fixed::ChecksumKind::ThreeByte,
700            tlsh_fixed::Version::Version4,
701        );
702
703        builder.update(contents);
704        if let Ok(hasher) = builder.build() {
705            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
706        }
707
708        if let Ok(exe) = EXE::from(contents) {
709            if let Some(imports) = exe.imports {
710                hashes.push((
711                    malwaredb_api::SimilarityHashType::ImportHash,
712                    hex::encode(imports.hash()),
713                ));
714                hashes.push((
715                    malwaredb_api::SimilarityHashType::FuzzyImportHash,
716                    imports.fuzzy_hash(),
717                ));
718            }
719        }
720
721        let request = malwaredb_api::SimilarSamplesRequest { hashes };
722
723        let response = self
724            .client
725            .post(format!(
726                "{}{}",
727                self.url,
728                malwaredb_api::SIMILAR_SAMPLES_URL
729            ))
730            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
731            .json(&request)
732            .send()?
733            .json::<ServerResponse<SimilarSamplesResponse>>()
734            .context(MDB_CLIENT_ERROR_CONTEXT)?;
735
736        match response {
737            ServerResponse::Success(similar) => Ok(similar),
738            ServerResponse::Error(e) => Err(e.into()),
739        }
740    }
741}
742
743impl Debug for MdbClient {
744    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
745        use crate::MDB_VERSION;
746
747        writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
748    }
749}
750
751/// Wrapper around search results for iterating over resulting hashes
752pub struct IterableHashSearchResult<'a> {
753    /// Server search result
754    pub response: SearchResponse,
755
756    client: &'a MdbClient,
757}
758
759impl<'a> IterableHashSearchResult<'a> {
760    /// Iterate over the hashes from a search result
761    #[must_use]
762    pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
763        Self { response, client }
764    }
765}
766
767impl Iterator for IterableHashSearchResult<'_> {
768    type Item = String;
769
770    fn next(&mut self) -> Option<Self::Item> {
771        if let Some(hash) = self.response.hashes.pop() {
772            Some(hash)
773        } else if let Some(uuid) = self.response.pagination {
774            let request = SearchRequest {
775                search: SearchType::Continuation(uuid),
776            };
777
778            self.response = match self.client.do_search_request(&request) {
779                Ok(response) => response,
780                Err(e) => {
781                    warn!("Failed to continue search: {e}");
782                    return None;
783                }
784            };
785
786            self.response.hashes.pop()
787        } else {
788            None
789        }
790    }
791}
792
793/// Wrapper around search results for iterating over resulting binaries
794pub struct IterableSampleSearchResult<'a> {
795    /// Server search result
796    pub response: SearchResponse,
797
798    client: &'a MdbClient,
799}
800
801impl<'a> IterableSampleSearchResult<'a> {
802    /// Iterate over the hashes from a search result returning the binary
803    #[must_use]
804    pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
805        Self { response, client }
806    }
807}
808
809impl Iterator for IterableSampleSearchResult<'_> {
810    type Item = Vec<u8>;
811
812    fn next(&mut self) -> Option<Self::Item> {
813        if let Some(hash) = self.response.hashes.pop() {
814            let binary = match self.client.retrieve(&hash, false) {
815                Ok(binary) => binary,
816                Err(e) => {
817                    error!("Failed to download {hash}: {e}");
818                    return None;
819                }
820            };
821            Some(binary)
822        } else if let Some(uuid) = self.response.pagination {
823            let request = SearchRequest {
824                search: SearchType::Continuation(uuid),
825            };
826
827            self.response = match self.client.do_search_request(&request) {
828                Ok(response) => response,
829                Err(e) => {
830                    warn!("Failed to continue search: {e}");
831                    return None;
832                }
833            };
834
835            if let Some(hash) = self.response.hashes.pop() {
836                let binary = match self.client.retrieve(&hash, false) {
837                    Ok(binary) => binary,
838                    Err(e) => {
839                        error!("Failed to download {hash}: {e}");
840                        return None;
841                    }
842                };
843                Some(binary)
844            } else {
845                None
846            }
847        } else {
848            None
849        }
850    }
851}