malwaredb_client/
blocking.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use std::fmt::{Debug, Formatter};
4use std::path::{Path, PathBuf};
5
6use crate::{get_config_path, MDB_CLIENT_ERROR_CONTEXT};
7use malwaredb_api::{
8    digest::HashType, GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType,
9    Report, SearchRequest, SearchRequestParameters, SearchResponse, SearchType, ServerInfo,
10    ServerResponse, SimilarSamplesResponse, Sources, SupportedFileTypes,
11};
12use malwaredb_types::exec::pe32::EXE;
13
14use anyhow::{bail, ensure, Context, Result};
15use base64::engine::general_purpose;
16use base64::Engine;
17use fuzzyhash::FuzzyHash;
18use malwaredb_lzjd::{LZDict, Murmur3HashState};
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21use tlsh_fixed::TlshBuilder;
22use tracing::{error, info, trace, warn};
23use zeroize::{Zeroize, ZeroizeOnDrop};
24
/// Blocking Malware DB Client Configuration and connection which requires the `blocking` feature
///
/// The `url` and `api_key` fields are what gets serialized to (and read back from) the client
/// configuration file; the HTTP client and the macOS certificate data are skipped by serde.
/// The API key is zeroized when the client is dropped.
#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
pub struct MdbClient {
    /// URL of the Malware DB server, including http and port number, ending without a slash
    /// (endpoint paths are appended to it directly)
    pub url: String,

    /// User's API key for Malware DB; sent as the `MDB_API_HEADER` header on authenticated
    /// requests and zeroized on drop
    api_key: String,

    /// Blocking http client which stores the optional server certificate
    ///
    /// Not serialized; when the struct is deserialized serde fills this in with
    /// `Default::default()`, so constructors must set it up explicitly if custom settings matter.
    #[zeroize(skip)]
    #[serde(skip)]
    client: reqwest::blocking::Client,

    /// Server's certificate, kept on macOS so it can be stored in / restored from the Keychain
    #[cfg(target_os = "macos")]
    #[zeroize(skip)]
    #[serde(skip)]
    cert: Option<crate::macos::CertificateData>,
}
46
47impl MdbClient {
48    /// MDB Client from components, doesn't test connectivity
49    ///
50    /// # Errors
51    ///
52    /// Returns an error if a list of certificates was passed and any were not in the expected
53    /// DER or PEM format or could not be parsed.
54    ///
55    /// # Panics
56    ///
57    /// This method panics if called from within an async runtime.
58    pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
59        let mut url = url;
60        let url = if url.ends_with('/') {
61            url.pop();
62            url
63        } else {
64            url
65        };
66
67        let cert = if let Some(path) = cert_path {
68            Some((crate::path_load_cert(&path)?, path))
69        } else {
70            None
71        };
72
73        let builder = reqwest::blocking::ClientBuilder::new()
74            .gzip(true)
75            .zstd(true)
76            .use_rustls_tls()
77            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
78
79        let client = if let Some(((_cert_type, cert), _path)) = &cert {
80            builder.add_root_certificate(cert.clone()).build()
81        } else {
82            builder.build()
83        }?;
84
85        #[cfg(target_os = "macos")]
86        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
87            Some(crate::macos::CertificateData {
88                cert_type: *cert_type,
89                cert_bytes: std::fs::read(cert_path)?,
90            })
91        } else {
92            None
93        };
94
95        Ok(Self {
96            url,
97            api_key,
98            client,
99
100            #[cfg(target_os = "macos")]
101            cert,
102        })
103    }
104
105    /// Login to a server, optionally save the configuration file, and return a client object
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if the server URL, username, or password were incorrect, or if a network
110    /// issue occurred.
111    ///
112    /// # Panics
113    ///
114    /// This method panics if called from within an async runtime.
115    pub fn login(
116        url: String,
117        username: String,
118        password: String,
119        save: bool,
120        cert_path: Option<PathBuf>,
121    ) -> Result<Self> {
122        let mut url = url;
123        let url = if url.ends_with('/') {
124            url.pop();
125            url
126        } else {
127            url
128        };
129
130        let api_request = malwaredb_api::GetAPIKeyRequest {
131            user: username,
132            password,
133        };
134
135        let builder = reqwest::blocking::ClientBuilder::new()
136            .gzip(true)
137            .zstd(true)
138            .use_rustls_tls()
139            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
140
141        let cert = if let Some(path) = cert_path {
142            Some((crate::path_load_cert(&path)?, path))
143        } else {
144            None
145        };
146
147        let client = if let Some(((_cert_type, cert), _path)) = &cert {
148            builder.add_root_certificate(cert.clone()).build()
149        } else {
150            builder.build()
151        }?;
152
153        let res = client
154            .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
155            .json(&api_request)
156            .send()?
157            .json::<ServerResponse<GetAPIKeyResponse>>()
158            .context(MDB_CLIENT_ERROR_CONTEXT)?;
159
160        let res = match res {
161            ServerResponse::Success(res) => res,
162            ServerResponse::Error(err) => return Err(err.into()),
163        };
164
165        #[cfg(target_os = "macos")]
166        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
167            Some(crate::macos::CertificateData {
168                cert_type: *cert_type,
169                cert_bytes: std::fs::read(cert_path)?,
170            })
171        } else {
172            None
173        };
174
175        let client = MdbClient {
176            url,
177            api_key: res.key.clone(),
178            client,
179
180            #[cfg(target_os = "macos")]
181            cert,
182        };
183
184        let server_info = client.server_info()?;
185        if server_info.mdb_version > *crate::MDB_VERSION_SEMVER {
186            warn!(
187                "Server version {:?} is newer than client {:?}, consider updating.",
188                server_info.mdb_version,
189                crate::MDB_VERSION_SEMVER
190            );
191        }
192
193        if save {
194            if let Err(e) = client.save() {
195                error!("Login successful but failed to save config: {e}");
196                bail!("Login successful but failed to save config: {e}");
197            }
198        }
199        Ok(client)
200    }
201
202    /// Reset one's own API key to effectively logout & disable all clients who are using the key
203    ///
204    /// # Errors
205    ///
206    /// Returns an error if there was a network issue or the user wasn't properly logged in.
207    pub fn reset_key(&self) -> Result<()> {
208        let response = self
209            .client
210            .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
211            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
212            .send()
213            .context(MDB_CLIENT_ERROR_CONTEXT)?;
214        if !response.status().is_success() {
215            bail!("failed to reset API key, was it correct?");
216        }
217        Ok(())
218    }
219
220    /// Malware DB Client configuration loaded from a specified path
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if the configuration file cannot be read, possibly because it
225    /// doesn't exist or due to a permission error or a parsing error.
226    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
227        let name = path.as_ref().display();
228        let config =
229            std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
230        let cfg: MdbClient =
231            toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
232        Ok(cfg)
233    }
234
235    /// Malware DB Client configuration from user's home directory
236    ///
237    /// On macOS, it will attempt to load this information in the Keychain, which isn't required.
238    ///
239    /// # Errors
240    ///
241    /// Returns an error if the configuration file cannot be read, possibly because it
242    /// doesn't exist or due to a permission error or a parsing error.
243    ///
244    /// # Panics
245    ///
246    /// This method panics if called from within an async runtime.
247    pub fn load() -> Result<Self> {
248        #[cfg(target_os = "macos")]
249        {
250            if let Ok((api_key, url, cert)) = crate::macos::retrieve_credentials() {
251                let builder = reqwest::blocking::ClientBuilder::new()
252                    .gzip(true)
253                    .zstd(true)
254                    .use_rustls_tls()
255                    .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
256
257                let client = if let Some(cert) = &cert {
258                    builder.add_root_certificate(cert.as_cert()?).build()
259                } else {
260                    builder.build()
261                }?;
262
263                return Ok(Self {
264                    url,
265                    api_key,
266                    client,
267                    cert,
268                });
269            }
270        }
271
272        let path = get_config_path(false)?;
273        if path.exists() {
274            return Self::from_file(path);
275        }
276        bail!("config file not found")
277    }
278
279    /// Save Malware DB Client configuration to the user's home directory
280    ///
281    /// On macOS, it will attempt to save this information in the Keychain, which isn't required.
282    ///
283    /// # Errors
284    ///
285    /// Returns an error if there was a problem saving the configuration file.
286    pub fn save(&self) -> Result<()> {
287        #[cfg(target_os = "macos")]
288        {
289            if crate::macos::save_credentials(&self.url, &self.api_key, self.cert.clone()).is_ok() {
290                return Ok(());
291            }
292        }
293
294        let toml = toml::to_string(self)?;
295        let path = get_config_path(true)?;
296        std::fs::write(&path, toml)
297            .context(format!("failed to write mdb config to {}", path.display()))
298    }
299
300    /// Delete the Malware DB client configuration file
301    ///
302    /// # Errors
303    ///
304    /// Returns an error if there isn't a configuration file to delete, or if it cannot be deleted,
305    /// possibly due to a permissions error.
306    pub fn delete(&self) -> Result<()> {
307        #[cfg(target_os = "macos")]
308        crate::macos::clear_credentials();
309
310        let path = get_config_path(false)?;
311        if path.exists() {
312            std::fs::remove_file(&path).context(format!(
313                "failed to delete client config file {}",
314                path.display()
315            ))?;
316        }
317        Ok(())
318    }
319
320    // Actions of the client
321
322    /// Get information about the server, unauthenticated
323    ///
324    /// # Errors
325    ///
326    /// This may return an error if there's a network situation.
327    pub fn server_info(&self) -> Result<ServerInfo> {
328        let response = self
329            .client
330            .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
331            .send()?
332            .json::<ServerResponse<ServerInfo>>()
333            .context(MDB_CLIENT_ERROR_CONTEXT)?;
334
335        match response {
336            ServerResponse::Success(info) => Ok(info),
337            ServerResponse::Error(e) => Err(e.into()),
338        }
339    }
340
341    /// Get file types supported by the server, unauthenticated
342    ///
343    /// # Errors
344    ///
345    /// This may return an error if there's a network situation.
346    pub fn supported_types(&self) -> Result<SupportedFileTypes> {
347        let response = self
348            .client
349            .get(format!(
350                "{}{}",
351                self.url,
352                malwaredb_api::SUPPORTED_FILE_TYPES_URL
353            ))
354            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
355            .send()?
356            .json::<ServerResponse<SupportedFileTypes>>()
357            .context(MDB_CLIENT_ERROR_CONTEXT)?;
358
359        match response {
360            ServerResponse::Success(types) => Ok(types),
361            ServerResponse::Error(e) => Err(e.into()),
362        }
363    }
364
365    /// Get information about the user
366    ///
367    /// # Errors
368    ///
369    /// This may return an error if there's a network situation or if the user is not logged in
370    /// or not properly authorized to connect.
371    pub fn whoami(&self) -> Result<GetUserInfoResponse> {
372        let response = self
373            .client
374            .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
375            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
376            .send()?
377            .json::<ServerResponse<GetUserInfoResponse>>()
378            .context(MDB_CLIENT_ERROR_CONTEXT)?;
379
380        match response {
381            ServerResponse::Success(info) => Ok(info),
382            ServerResponse::Error(e) => Err(e.into()),
383        }
384    }
385
386    /// Get the sample labels known to the server
387    ///
388    /// # Errors
389    ///
390    /// This may return an error if there's a network situation or if the user is not logged in
391    /// or not properly authorized to connect.
392    pub fn labels(&self) -> Result<Labels> {
393        let response = self
394            .client
395            .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
396            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
397            .send()?
398            .json::<ServerResponse<Labels>>()
399            .context(MDB_CLIENT_ERROR_CONTEXT)?;
400
401        match response {
402            ServerResponse::Success(labels) => Ok(labels),
403            ServerResponse::Error(e) => Err(e.into()),
404        }
405    }
406
407    /// Get the sources available to the current user
408    ///
409    /// # Errors
410    ///
411    /// This may return an error if there's a network situation or if the user is not logged in
412    /// or not properly authorized to connect.
413    pub fn sources(&self) -> Result<Sources> {
414        let response = self
415            .client
416            .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
417            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
418            .send()?
419            .json::<ServerResponse<Sources>>()
420            .context(MDB_CLIENT_ERROR_CONTEXT)?;
421
422        match response {
423            ServerResponse::Success(sources) => Ok(sources),
424            ServerResponse::Error(e) => Err(e.into()),
425        }
426    }
427
428    /// Submit one file to Malware DB: provide the contents, file name, and source ID
429    ///
430    /// # Errors
431    ///
432    /// This may return an error if there's a network situation or if the user is not logged in
433    /// or not properly authorized to connect.
434    pub fn submit(
435        &self,
436        contents: impl AsRef<[u8]>,
437        file_name: String,
438        source_id: u32,
439    ) -> Result<bool> {
440        let mut hasher = Sha256::new();
441        hasher.update(&contents);
442        let result = hasher.finalize();
443
444        let encoded = general_purpose::STANDARD.encode(contents);
445
446        let payload = malwaredb_api::NewSampleB64 {
447            file_name,
448            source_id,
449            file_contents_b64: encoded,
450            sha256: hex::encode(result),
451        };
452
453        match self
454            .client
455            .post(format!(
456                "{}{}",
457                self.url,
458                malwaredb_api::UPLOAD_SAMPLE_JSON_URL
459            ))
460            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
461            .json(&payload)
462            .send()
463        {
464            Ok(res) => {
465                if !res.status().is_success() {
466                    info!("Code {} sending {}", res.status(), payload.file_name);
467                }
468                Ok(res.status().is_success())
469            }
470            Err(e) => {
471                let status: String = e
472                    .status()
473                    .map(|s| s.as_str().to_string())
474                    .unwrap_or_default();
475                error!("Error{status} sending {}: {e}", payload.file_name);
476                bail!(e.to_string())
477            }
478        }
479    }
480
481    /// Submit one file to Malware DB: provide the contents, file name, and source ID
482    /// Experimental! May be removed at any point.
483    ///
484    /// # Errors
485    ///
486    /// This may return an error if there's a network situation or if the user is not logged in
487    /// or not properly authorized to connect.
488    pub fn submit_as_cbor(
489        &self,
490        contents: impl AsRef<[u8]>,
491        file_name: String,
492        source_id: u32,
493    ) -> Result<bool> {
494        let mut hasher = Sha256::new();
495        hasher.update(&contents);
496        let result = hasher.finalize();
497
498        let payload = malwaredb_api::NewSampleBytes {
499            file_name,
500            source_id,
501            file_contents: contents.as_ref().to_vec(),
502            sha256: hex::encode(result),
503        };
504
505        let mut bytes = Vec::with_capacity(payload.file_contents.len());
506        ciborium::ser::into_writer(&payload, &mut bytes)?;
507
508        match self
509            .client
510            .post(format!(
511                "{}{}",
512                self.url,
513                malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
514            ))
515            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
516            .header("content-type", "application/cbor")
517            .body(bytes)
518            .send()
519        {
520            Ok(res) => {
521                if !res.status().is_success() {
522                    info!("Code {} sending {}", res.status(), payload.file_name);
523                }
524                Ok(res.status().is_success())
525            }
526            Err(e) => {
527                let status: String = e
528                    .status()
529                    .map(|s| s.as_str().to_string())
530                    .unwrap_or_default();
531                error!("Error{status} sending {}: {e}", payload.file_name);
532                bail!(e.to_string())
533            }
534        }
535    }
536
537    /// Search for a file based on partial hash and/or partial file name, returns a list of hashes
538    ///
539    /// # Errors
540    ///
541    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
542    pub fn partial_search(
543        &self,
544        partial_hash: Option<(PartialHashSearchType, String)>,
545        name: Option<String>,
546        response: PartialHashSearchType,
547        limit: u32,
548    ) -> Result<SearchResponse> {
549        let query = SearchRequest {
550            search: SearchType::Search(SearchRequestParameters {
551                partial_hash,
552                file_name: name,
553                response,
554                limit,
555                labels: None,
556                file_type: None,
557                magic: None,
558            }),
559        };
560
561        self.do_search_request(&query)
562    }
563
564    /// Search for a file based on partial hash and/or partial file name, labels, file type; returns a list of hashes
565    ///
566    /// # Errors
567    ///
568    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
569    #[allow(clippy::too_many_arguments)]
570    pub fn partial_search_labels_type(
571        &self,
572        partial_hash: Option<(PartialHashSearchType, String)>,
573        name: Option<String>,
574        response: PartialHashSearchType,
575        labels: Option<Vec<String>>,
576        file_type: Option<String>,
577        magic: Option<String>,
578        limit: u32,
579    ) -> Result<SearchResponse> {
580        let query = SearchRequest {
581            search: SearchType::Search(SearchRequestParameters {
582                partial_hash,
583                file_name: name,
584                response,
585                limit,
586                file_type,
587                magic,
588                labels,
589            }),
590        };
591
592        self.do_search_request(&query)
593    }
594
595    /// Return the next page from the search result
596    ///
597    /// # Errors
598    ///
599    /// Returns an error if there is a network problem, or pagination not available
600    pub fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
601        if let Some(uuid) = response.pagination {
602            let request = SearchRequest {
603                search: SearchType::Continuation(uuid),
604            };
605            return self.do_search_request(&request);
606        }
607
608        bail!("Pagination not available")
609    }
610
611    fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
612        ensure!(
613            query.is_valid(),
614            "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
615        );
616
617        let response = self
618            .client
619            .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
620            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
621            .json(query)
622            .send()?
623            .json::<ServerResponse<SearchResponse>>()
624            .context(MDB_CLIENT_ERROR_CONTEXT)?;
625
626        match response {
627            ServerResponse::Success(search) => Ok(search),
628            ServerResponse::Error(e) => Err(e.into()),
629        }
630    }
631
632    /// Retrieve sample by hash, optionally in the `CaRT` format
633    ///
634    /// # Errors
635    ///
636    /// This may return an error if there's a network situation or if the user is not logged in
637    /// or not properly authorized to connect.
638    pub fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
639        let api_endpoint = if cart {
640            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
641        } else {
642            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
643        };
644
645        let res = self
646            .client
647            .get(format!("{}{api_endpoint}", self.url))
648            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
649            .send()?;
650
651        if !res.status().is_success() {
652            bail!("Received code {}", res.status());
653        }
654
655        let content_digest = res.headers().get("content-digest").map(ToOwned::to_owned);
656        let body = res.bytes()?;
657        let bytes = body.to_vec();
658
659        // TODO: Make this required in v0.3
660        if let Some(digest) = content_digest {
661            let hash = HashType::from_content_digest_header(digest.to_str()?)?;
662            if hash.verify(&bytes) {
663                trace!("Hash verified for sample {hash}");
664            } else {
665                error!("Hash mismatch for sample {hash}");
666            }
667        } else {
668            warn!("No content digest header received for sample {hash}");
669        }
670
671        Ok(bytes)
672    }
673
674    /// Fetch a report for a sample
675    ///
676    /// # Errors
677    ///
678    /// This may return an error if there's a network situation or if the user is not logged in
679    /// or not properly authorized to connect.
680    pub fn report(&self, hash: &str) -> Result<Report> {
681        let response = self
682            .client
683            .get(format!(
684                "{}{}/{hash}",
685                self.url,
686                malwaredb_api::SAMPLE_REPORT_URL
687            ))
688            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
689            .send()?
690            .json::<ServerResponse<Report>>()
691            .context(MDB_CLIENT_ERROR_CONTEXT)?;
692
693        match response {
694            ServerResponse::Success(report) => Ok(report),
695            ServerResponse::Error(e) => Err(e.into()),
696        }
697    }
698
699    /// Find similar samples in `MalwareDB` based on the contents of a given file.
700    /// This does not submit the sample to `MalwareDB`.
701    ///
702    /// # Errors
703    ///
704    /// This may return an error if there's a network situation or if the user is not logged in
705    /// or not properly authorized to connect.
706    pub fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
707        let mut hashes = vec![];
708        let ssdeep_hash = FuzzyHash::new(contents);
709
710        let build_hasher = Murmur3HashState::default();
711        let lzjd_str =
712            LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
713        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
714        hashes.push((
715            malwaredb_api::SimilarityHashType::SSDeep,
716            ssdeep_hash.to_string(),
717        ));
718
719        let mut builder = TlshBuilder::new(
720            tlsh_fixed::BucketKind::Bucket256,
721            tlsh_fixed::ChecksumKind::ThreeByte,
722            tlsh_fixed::Version::Version4,
723        );
724
725        builder.update(contents);
726        if let Ok(hasher) = builder.build() {
727            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
728        }
729
730        if let Ok(exe) = EXE::from(contents) {
731            if let Some(imports) = exe.imports {
732                hashes.push((
733                    malwaredb_api::SimilarityHashType::ImportHash,
734                    hex::encode(imports.hash()),
735                ));
736                hashes.push((
737                    malwaredb_api::SimilarityHashType::FuzzyImportHash,
738                    imports.fuzzy_hash(),
739                ));
740            }
741        }
742
743        let request = malwaredb_api::SimilarSamplesRequest { hashes };
744
745        let response = self
746            .client
747            .post(format!(
748                "{}{}",
749                self.url,
750                malwaredb_api::SIMILAR_SAMPLES_URL
751            ))
752            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
753            .json(&request)
754            .send()?
755            .json::<ServerResponse<SimilarSamplesResponse>>()
756            .context(MDB_CLIENT_ERROR_CONTEXT)?;
757
758        match response {
759            ServerResponse::Success(similar) => Ok(similar),
760            ServerResponse::Error(e) => Err(e.into()),
761        }
762    }
763}
764
765impl Debug for MdbClient {
766    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
767        use crate::MDB_VERSION;
768
769        writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
770    }
771}
772
773/// Wrapper around search results for iterating over resulting hashes with the blocking client
774///
775/// ```rust,no_run
776/// use malwaredb_client::blocking::{MdbClient, IterableHashSearchResult};
777/// use malwaredb_client::malwaredb_api::{PartialHashSearchType, SearchType};
778///
779/// let client = MdbClient::load().expect("Failed to load client or parse config file");
780///
781/// // Get the first 100 files where the file name contains "foo", returning hashes as SHA-256
782/// let search_result = client.partial_search(None, Some("foo".into()), PartialHashSearchType::SHA256, 100).unwrap();
783/// for hash in IterableHashSearchResult::from(search_result, &client) {
784///     println!("{hash}");
785/// }
786/// ```
787#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
788pub struct IterableHashSearchResult<'a> {
789    /// Server search result
790    pub response: SearchResponse,
791
792    /// Blocking client
793    client: &'a MdbClient,
794}
795
796impl<'a> IterableHashSearchResult<'a> {
797    /// Iterate over the hashes from a search result and blocking client
798    #[must_use]
799    pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
800        Self { response, client }
801    }
802}
803
804impl Iterator for IterableHashSearchResult<'_> {
805    type Item = String;
806
807    fn next(&mut self) -> Option<Self::Item> {
808        if let Some(hash) = self.response.hashes.pop() {
809            Some(hash)
810        } else if let Some(uuid) = self.response.pagination {
811            let request = SearchRequest {
812                search: SearchType::Continuation(uuid),
813            };
814
815            self.response = match self.client.do_search_request(&request) {
816                Ok(response) => response,
817                Err(e) => {
818                    warn!("Failed to continue search: {e}");
819                    return None;
820                }
821            };
822
823            self.response.hashes.pop()
824        } else {
825            None
826        }
827    }
828}
829
830/// Wrapper around search results for iterating over resulting binaries with the blocking client
831#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
832pub struct IterableSampleSearchResult<'a> {
833    /// Server search result
834    pub response: SearchResponse,
835
836    /// Blocking client
837    client: &'a MdbClient,
838}
839
840impl<'a> IterableSampleSearchResult<'a> {
841    /// Iterate over the hashes from a search result and blocking client, returning the binary
842    #[must_use]
843    pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
844        Self { response, client }
845    }
846}
847
848impl Iterator for IterableSampleSearchResult<'_> {
849    type Item = Vec<u8>;
850
851    fn next(&mut self) -> Option<Self::Item> {
852        if let Some(hash) = self.response.hashes.pop() {
853            let binary = match self.client.retrieve(&hash, false) {
854                Ok(binary) => binary,
855                Err(e) => {
856                    error!("Failed to download {hash}: {e}");
857                    return None;
858                }
859            };
860            Some(binary)
861        } else if let Some(uuid) = self.response.pagination {
862            let request = SearchRequest {
863                search: SearchType::Continuation(uuid),
864            };
865
866            self.response = match self.client.do_search_request(&request) {
867                Ok(response) => response,
868                Err(e) => {
869                    warn!("Failed to continue search: {e}");
870                    return None;
871                }
872            };
873
874            if let Some(hash) = self.response.hashes.pop() {
875                let binary = match self.client.retrieve(&hash, false) {
876                    Ok(binary) => binary,
877                    Err(e) => {
878                        error!("Failed to download {hash}: {e}");
879                        return None;
880                    }
881                };
882                Some(binary)
883            } else {
884                None
885            }
886        } else {
887            None
888        }
889    }
890}