Skip to main content

malwaredb_client/
blocking.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use std::fmt::{Debug, Formatter};
4use std::path::{Path, PathBuf};
5
6use crate::{get_config_path, MDB_CLIENT_ERROR_CONTEXT};
7use malwaredb_api::{
8    digest::HashType, GetAPIKeyResponse, GetUserInfoResponse, Labels, PartialHashSearchType,
9    Report, SearchRequest, SearchRequestParameters, SearchResponse, SearchType, ServerInfo,
10    ServerResponse, SimilarSamplesResponse, Sources, SupportedFileTypes, YaraSearchRequest,
11    YaraSearchRequestResponse, YaraSearchResponse,
12};
13use malwaredb_types::exec::pe32::EXE;
14
15use anyhow::{bail, ensure, Context, Result};
16use base64::engine::general_purpose;
17use base64::Engine;
18use fuzzyhash::FuzzyHash;
19use malwaredb_lzjd::{LZDict, Murmur3HashState};
20use serde::{Deserialize, Serialize};
21use sha2::{Digest, Sha256};
22use tlsh_fixed::TlshBuilder;
23use tracing::{error, info, trace, warn};
24use uuid::Uuid;
25use zeroize::{Zeroize, ZeroizeOnDrop};
26
27/// Blocking Malware DB Client Configuration and connection which requires the `blocking` feature
28#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
29#[derive(Deserialize, Serialize, Zeroize, ZeroizeOnDrop)]
30pub struct MdbClient {
31    /// URL of the Malware DB server, including http and port number, ending without a slash
32    pub url: String,
33
34    /// User's API key for Malware DB
35    api_key: String,
36
37    /// Blocking http client which stores the optional server certificate
38    #[zeroize(skip)]
39    #[serde(skip)]
40    client: reqwest::blocking::Client,
41
42    /// Server's certificate
43    #[cfg(target_os = "macos")]
44    #[zeroize(skip)]
45    #[serde(skip)]
46    cert: Option<crate::macos::CertificateData>,
47}
48
49impl MdbClient {
50    /// MDB Client from components, doesn't test connectivity
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if a list of certificates was passed and any were not in the expected
55    /// DER or PEM format or could not be parsed.
56    ///
57    /// # Panics
58    ///
59    /// This method panics if called from within an async runtime.
60    pub fn new(url: String, api_key: String, cert_path: Option<PathBuf>) -> Result<Self> {
61        let mut url = url;
62        let url = if url.ends_with('/') {
63            url.pop();
64            url
65        } else {
66            url
67        };
68
69        let cert = if let Some(path) = cert_path {
70            Some((crate::path_load_cert(&path)?, path))
71        } else {
72            None
73        };
74
75        let builder = reqwest::blocking::ClientBuilder::new()
76            .gzip(true)
77            .zstd(true)
78            .use_rustls_tls()
79            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
80
81        let client = if let Some(((_cert_type, cert), _path)) = &cert {
82            builder.add_root_certificate(cert.clone()).build()
83        } else {
84            builder.build()
85        }?;
86
87        #[cfg(target_os = "macos")]
88        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
89            Some(crate::macos::CertificateData {
90                cert_type: *cert_type,
91                cert_bytes: std::fs::read(cert_path)?,
92            })
93        } else {
94            None
95        };
96
97        Ok(Self {
98            url,
99            api_key,
100            client,
101
102            #[cfg(target_os = "macos")]
103            cert,
104        })
105    }
106
107    /// Login to a server, optionally save the configuration file, and return a client object
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if the server URL, username, or password were incorrect, or if a network
112    /// issue occurred.
113    ///
114    /// # Panics
115    ///
116    /// This method panics if called from within an async runtime.
117    pub fn login(
118        url: String,
119        username: String,
120        password: String,
121        save: bool,
122        cert_path: Option<PathBuf>,
123    ) -> Result<Self> {
124        let mut url = url;
125        let url = if url.ends_with('/') {
126            url.pop();
127            url
128        } else {
129            url
130        };
131
132        let api_request = malwaredb_api::GetAPIKeyRequest {
133            user: username,
134            password,
135        };
136
137        let builder = reqwest::blocking::ClientBuilder::new()
138            .gzip(true)
139            .zstd(true)
140            .use_rustls_tls()
141            .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
142
143        let cert = if let Some(path) = cert_path {
144            Some((crate::path_load_cert(&path)?, path))
145        } else {
146            None
147        };
148
149        let client = if let Some(((_cert_type, cert), _path)) = &cert {
150            builder.add_root_certificate(cert.clone()).build()
151        } else {
152            builder.build()
153        }?;
154
155        let res = client
156            .post(format!("{url}{}", malwaredb_api::USER_LOGIN_URL))
157            .json(&api_request)
158            .send()?
159            .json::<ServerResponse<GetAPIKeyResponse>>()
160            .context(MDB_CLIENT_ERROR_CONTEXT)?;
161
162        let res = match res {
163            ServerResponse::Success(res) => res,
164            ServerResponse::Error(err) => return Err(err.into()),
165        };
166
167        #[cfg(target_os = "macos")]
168        let cert = if let Some(((cert_type, _cert), cert_path)) = &cert {
169            Some(crate::macos::CertificateData {
170                cert_type: *cert_type,
171                cert_bytes: std::fs::read(cert_path)?,
172            })
173        } else {
174            None
175        };
176
177        let client = MdbClient {
178            url,
179            api_key: res.key.clone(),
180            client,
181
182            #[cfg(target_os = "macos")]
183            cert,
184        };
185
186        let server_info = client.server_info()?;
187        if server_info.mdb_version > *crate::MDB_VERSION_SEMVER {
188            warn!(
189                "Server version {:?} is newer than client {:?}, consider updating.",
190                server_info.mdb_version,
191                crate::MDB_VERSION_SEMVER
192            );
193        }
194
195        if save {
196            if let Err(e) = client.save() {
197                error!("Login successful but failed to save config: {e}");
198                bail!("Login successful but failed to save config: {e}");
199            }
200        }
201        Ok(client)
202    }
203
204    /// Reset one's own API key to effectively logout & disable all clients who are using the key
205    ///
206    /// # Errors
207    ///
208    /// Returns an error if there was a network issue or the user wasn't properly logged in.
209    pub fn reset_key(&self) -> Result<()> {
210        let response = self
211            .client
212            .get(format!("{}{}", self.url, malwaredb_api::USER_LOGOUT_URL))
213            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
214            .send()
215            .context(MDB_CLIENT_ERROR_CONTEXT)?;
216        if !response.status().is_success() {
217            bail!("failed to reset API key, was it correct?");
218        }
219        Ok(())
220    }
221
222    /// Malware DB Client configuration loaded from a specified path
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if the configuration file cannot be read, possibly because it
227    /// doesn't exist or due to a permission error or a parsing error.
228    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
229        let name = path.as_ref().display();
230        let config =
231            std::fs::read_to_string(&path).context(format!("failed to read config file {name}"))?;
232        let cfg: MdbClient =
233            toml::from_str(&config).context(format!("failed to parse config file {name}"))?;
234        Ok(cfg)
235    }
236
237    /// Malware DB Client configuration from user's home directory
238    ///
239    /// On macOS, it will attempt to load this information in the Keychain, which isn't required.
240    ///
241    /// # Errors
242    ///
243    /// Returns an error if the configuration file cannot be read, possibly because it
244    /// doesn't exist or due to a permission error or a parsing error.
245    ///
246    /// # Panics
247    ///
248    /// This method panics if called from within an async runtime.
249    pub fn load() -> Result<Self> {
250        #[cfg(target_os = "macos")]
251        {
252            if let Ok((api_key, url, cert)) = crate::macos::retrieve_credentials() {
253                let builder = reqwest::blocking::ClientBuilder::new()
254                    .gzip(true)
255                    .zstd(true)
256                    .use_rustls_tls()
257                    .user_agent(concat!("mdb_client/", env!("CARGO_PKG_VERSION")));
258
259                let client = if let Some(cert) = &cert {
260                    builder.add_root_certificate(cert.as_cert()?).build()
261                } else {
262                    builder.build()
263                }?;
264
265                return Ok(Self {
266                    url,
267                    api_key,
268                    client,
269                    cert,
270                });
271            }
272        }
273
274        let path = get_config_path(false)?;
275        if path.exists() {
276            return Self::from_file(path);
277        }
278        bail!("config file not found")
279    }
280
281    /// Save Malware DB Client configuration to the user's home directory
282    ///
283    /// On macOS, it will attempt to save this information in the Keychain, which isn't required.
284    ///
285    /// # Errors
286    ///
287    /// Returns an error if there was a problem saving the configuration file.
288    pub fn save(&self) -> Result<()> {
289        #[cfg(target_os = "macos")]
290        {
291            if crate::macos::save_credentials(&self.url, &self.api_key, self.cert.clone()).is_ok() {
292                return Ok(());
293            }
294        }
295
296        let toml = toml::to_string(self)?;
297        let path = get_config_path(true)?;
298        std::fs::write(&path, toml)
299            .context(format!("failed to write mdb config to {}", path.display()))
300    }
301
302    /// Delete the Malware DB client configuration file
303    ///
304    /// # Errors
305    ///
306    /// Returns an error if there isn't a configuration file to delete, or if it cannot be deleted,
307    /// possibly due to a permissions error.
308    pub fn delete(&self) -> Result<()> {
309        #[cfg(target_os = "macos")]
310        crate::macos::clear_credentials();
311
312        let path = get_config_path(false)?;
313        if path.exists() {
314            std::fs::remove_file(&path).context(format!(
315                "failed to delete client config file {}",
316                path.display()
317            ))?;
318        }
319        Ok(())
320    }
321
322    // Actions of the client
323
324    /// Get information about the server, unauthenticated
325    ///
326    /// # Errors
327    ///
328    /// This may return an error if there's a network situation.
329    pub fn server_info(&self) -> Result<ServerInfo> {
330        let response = self
331            .client
332            .get(format!("{}{}", self.url, malwaredb_api::SERVER_INFO_URL))
333            .send()?
334            .json::<ServerResponse<ServerInfo>>()
335            .context(MDB_CLIENT_ERROR_CONTEXT)?;
336
337        match response {
338            ServerResponse::Success(info) => Ok(info),
339            ServerResponse::Error(e) => Err(e.into()),
340        }
341    }
342
343    /// Get file types supported by the server, unauthenticated
344    ///
345    /// # Errors
346    ///
347    /// This may return an error if there's a network situation.
348    pub fn supported_types(&self) -> Result<SupportedFileTypes> {
349        let response = self
350            .client
351            .get(format!(
352                "{}{}",
353                self.url,
354                malwaredb_api::SUPPORTED_FILE_TYPES_URL
355            ))
356            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
357            .send()?
358            .json::<ServerResponse<SupportedFileTypes>>()
359            .context(MDB_CLIENT_ERROR_CONTEXT)?;
360
361        match response {
362            ServerResponse::Success(types) => Ok(types),
363            ServerResponse::Error(e) => Err(e.into()),
364        }
365    }
366
367    /// Get information about the user
368    ///
369    /// # Errors
370    ///
371    /// This may return an error if there's a network situation or if the user is not logged in
372    /// or not properly authorized to connect.
373    pub fn whoami(&self) -> Result<GetUserInfoResponse> {
374        let response = self
375            .client
376            .get(format!("{}{}", self.url, malwaredb_api::USER_INFO_URL))
377            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
378            .send()?
379            .json::<ServerResponse<GetUserInfoResponse>>()
380            .context(MDB_CLIENT_ERROR_CONTEXT)?;
381
382        match response {
383            ServerResponse::Success(info) => Ok(info),
384            ServerResponse::Error(e) => Err(e.into()),
385        }
386    }
387
388    /// Get the sample labels known to the server
389    ///
390    /// # Errors
391    ///
392    /// This may return an error if there's a network situation or if the user is not logged in
393    /// or not properly authorized to connect.
394    pub fn labels(&self) -> Result<Labels> {
395        let response = self
396            .client
397            .get(format!("{}{}", self.url, malwaredb_api::LIST_LABELS_URL))
398            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
399            .send()?
400            .json::<ServerResponse<Labels>>()
401            .context(MDB_CLIENT_ERROR_CONTEXT)?;
402
403        match response {
404            ServerResponse::Success(labels) => Ok(labels),
405            ServerResponse::Error(e) => Err(e.into()),
406        }
407    }
408
409    /// Get the sources available to the current user
410    ///
411    /// # Errors
412    ///
413    /// This may return an error if there's a network situation or if the user is not logged in
414    /// or not properly authorized to connect.
415    pub fn sources(&self) -> Result<Sources> {
416        let response = self
417            .client
418            .get(format!("{}{}", self.url, malwaredb_api::LIST_SOURCES_URL))
419            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
420            .send()?
421            .json::<ServerResponse<Sources>>()
422            .context(MDB_CLIENT_ERROR_CONTEXT)?;
423
424        match response {
425            ServerResponse::Success(sources) => Ok(sources),
426            ServerResponse::Error(e) => Err(e.into()),
427        }
428    }
429
430    /// Submit one file to Malware DB: provide the contents, file name, and source ID
431    ///
432    /// # Errors
433    ///
434    /// This may return an error if there's a network situation or if the user is not logged in
435    /// or not properly authorized to connect.
436    pub fn submit(
437        &self,
438        contents: impl AsRef<[u8]>,
439        file_name: String,
440        source_id: u32,
441    ) -> Result<bool> {
442        let mut hasher = Sha256::new();
443        hasher.update(&contents);
444        let result = hasher.finalize();
445
446        let encoded = general_purpose::STANDARD.encode(contents);
447
448        let payload = malwaredb_api::NewSampleB64 {
449            file_name,
450            source_id,
451            file_contents_b64: encoded,
452            sha256: hex::encode(result),
453        };
454
455        match self
456            .client
457            .post(format!(
458                "{}{}",
459                self.url,
460                malwaredb_api::UPLOAD_SAMPLE_JSON_URL
461            ))
462            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
463            .json(&payload)
464            .send()
465        {
466            Ok(res) => {
467                if !res.status().is_success() {
468                    info!("Code {} sending {}", res.status(), payload.file_name);
469                }
470                Ok(res.status().is_success())
471            }
472            Err(e) => {
473                let status: String = e
474                    .status()
475                    .map(|s| s.as_str().to_string())
476                    .unwrap_or_default();
477                error!("Error{status} sending {}: {e}", payload.file_name);
478                bail!(e.to_string())
479            }
480        }
481    }
482
483    /// Submit one file to Malware DB: provide the contents, file name, and source ID
484    /// Experimental! May be removed at any point.
485    ///
486    /// # Errors
487    ///
488    /// This may return an error if there's a network situation or if the user is not logged in
489    /// or not properly authorized to connect.
490    pub fn submit_as_cbor(
491        &self,
492        contents: impl AsRef<[u8]>,
493        file_name: String,
494        source_id: u32,
495    ) -> Result<bool> {
496        let mut hasher = Sha256::new();
497        hasher.update(&contents);
498        let result = hasher.finalize();
499
500        let payload = malwaredb_api::NewSampleBytes {
501            file_name,
502            source_id,
503            file_contents: contents.as_ref().to_vec(),
504            sha256: hex::encode(result),
505        };
506
507        let mut bytes = Vec::with_capacity(payload.file_contents.len());
508        ciborium::ser::into_writer(&payload, &mut bytes)?;
509
510        match self
511            .client
512            .post(format!(
513                "{}{}",
514                self.url,
515                malwaredb_api::UPLOAD_SAMPLE_CBOR_URL
516            ))
517            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
518            .header("content-type", "application/cbor")
519            .body(bytes)
520            .send()
521        {
522            Ok(res) => {
523                if !res.status().is_success() {
524                    info!("Code {} sending {}", res.status(), payload.file_name);
525                }
526                Ok(res.status().is_success())
527            }
528            Err(e) => {
529                let status: String = e
530                    .status()
531                    .map(|s| s.as_str().to_string())
532                    .unwrap_or_default();
533                error!("Error{status} sending {}: {e}", payload.file_name);
534                bail!(e.to_string())
535            }
536        }
537    }
538
539    /// Search for a file based on partial hash and/or partial file name, returns a list of hashes
540    ///
541    /// # Errors
542    ///
543    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
544    pub fn partial_search(
545        &self,
546        partial_hash: Option<(PartialHashSearchType, String)>,
547        name: Option<String>,
548        response: PartialHashSearchType,
549        limit: u32,
550    ) -> Result<SearchResponse> {
551        let query = SearchRequest {
552            search: SearchType::Search(SearchRequestParameters {
553                partial_hash,
554                file_name: name,
555                response,
556                limit,
557                labels: None,
558                file_type: None,
559                magic: None,
560            }),
561        };
562
563        self.do_search_request(&query)
564    }
565
566    /// Search for a file based on partial hash and/or partial file name, labels, file type; returns a list of hashes
567    ///
568    /// # Errors
569    ///
570    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
571    #[allow(clippy::too_many_arguments)]
572    pub fn partial_search_labels_type(
573        &self,
574        partial_hash: Option<(PartialHashSearchType, String)>,
575        name: Option<String>,
576        response: PartialHashSearchType,
577        labels: Option<Vec<String>>,
578        file_type: Option<String>,
579        magic: Option<String>,
580        limit: u32,
581    ) -> Result<SearchResponse> {
582        let query = SearchRequest {
583            search: SearchType::Search(SearchRequestParameters {
584                partial_hash,
585                file_name: name,
586                response,
587                limit,
588                file_type,
589                magic,
590                labels,
591            }),
592        };
593
594        self.do_search_request(&query)
595    }
596
597    /// Return the next page from the search result
598    ///
599    /// # Errors
600    ///
601    /// Returns an error if there is a network problem, or pagination not available
602    pub fn next_page_search(&self, response: &SearchResponse) -> Result<SearchResponse> {
603        if let Some(uuid) = response.pagination {
604            let request = SearchRequest {
605                search: SearchType::Continuation(uuid),
606            };
607            return self.do_search_request(&request);
608        }
609
610        bail!("Pagination not available")
611    }
612
613    fn do_search_request(&self, query: &SearchRequest) -> Result<SearchResponse> {
614        ensure!(
615            query.is_valid(),
616            "Query isn't valid: hash isn't hexidecimal or both the hashes and file name are empty"
617        );
618
619        let response = self
620            .client
621            .post(format!("{}{}", self.url, malwaredb_api::SEARCH_URL))
622            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
623            .json(query)
624            .send()?
625            .json::<ServerResponse<SearchResponse>>()
626            .context(MDB_CLIENT_ERROR_CONTEXT)?;
627
628        match response {
629            ServerResponse::Success(search) => Ok(search),
630            ServerResponse::Error(e) => Err(e.into()),
631        }
632    }
633
634    /// Retrieve sample by hash, optionally in the `CaRT` format
635    ///
636    /// # Errors
637    ///
638    /// This may return an error if there's a network situation or if the user is not logged in
639    /// or not properly authorized to connect.
640    pub fn retrieve(&self, hash: &str, cart: bool) -> Result<Vec<u8>> {
641        let api_endpoint = if cart {
642            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL)
643        } else {
644            format!("{}{hash}", malwaredb_api::DOWNLOAD_SAMPLE_URL)
645        };
646
647        let res = self
648            .client
649            .get(format!("{}{api_endpoint}", self.url))
650            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
651            .send()?;
652
653        if !res.status().is_success() {
654            bail!("Received code {}", res.status());
655        }
656
657        let content_digest = res.headers().get("content-digest").map(ToOwned::to_owned);
658        let body = res.bytes()?;
659        let bytes = body.to_vec();
660
661        // TODO: Make this required in v0.3
662        if let Some(digest) = content_digest {
663            let hash = HashType::from_content_digest_header(digest.to_str()?)?;
664            if hash.verify(&bytes) {
665                trace!("Hash verified for sample {hash}");
666            } else {
667                error!("Hash mismatch for sample {hash}");
668            }
669        } else {
670            warn!("No content digest header received for sample {hash}");
671        }
672
673        Ok(bytes)
674    }
675
676    /// Fetch a report for a sample
677    ///
678    /// # Errors
679    ///
680    /// This may return an error if there's a network situation or if the user is not logged in
681    /// or not properly authorized to connect.
682    pub fn report(&self, hash: &str) -> Result<Report> {
683        let response = self
684            .client
685            .get(format!(
686                "{}{}/{hash}",
687                self.url,
688                malwaredb_api::SAMPLE_REPORT_URL
689            ))
690            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
691            .send()?
692            .json::<ServerResponse<Report>>()
693            .context(MDB_CLIENT_ERROR_CONTEXT)?;
694
695        match response {
696            ServerResponse::Success(report) => Ok(report),
697            ServerResponse::Error(e) => Err(e.into()),
698        }
699    }
700
701    /// Find similar samples in `MalwareDB` based on the contents of a given file.
702    /// This does not submit the sample to `MalwareDB`.
703    ///
704    /// # Errors
705    ///
706    /// This may return an error if there's a network situation or if the user is not logged in
707    /// or not properly authorized to connect.
708    pub fn similar(&self, contents: &[u8]) -> Result<SimilarSamplesResponse> {
709        let mut hashes = vec![];
710        let ssdeep_hash = FuzzyHash::new(contents);
711
712        let build_hasher = Murmur3HashState::default();
713        let lzjd_str =
714            LZDict::from_bytes_stream(contents.iter().copied(), &build_hasher).to_string();
715        hashes.push((malwaredb_api::SimilarityHashType::LZJD, lzjd_str));
716        hashes.push((
717            malwaredb_api::SimilarityHashType::SSDeep,
718            ssdeep_hash.to_string(),
719        ));
720
721        let mut builder = TlshBuilder::new(
722            tlsh_fixed::BucketKind::Bucket256,
723            tlsh_fixed::ChecksumKind::ThreeByte,
724            tlsh_fixed::Version::Version4,
725        );
726
727        builder.update(contents);
728        if let Ok(hasher) = builder.build() {
729            hashes.push((malwaredb_api::SimilarityHashType::TLSH, hasher.hash()));
730        }
731
732        if let Ok(exe) = EXE::from(contents) {
733            if let Some(imports) = exe.imports {
734                hashes.push((
735                    malwaredb_api::SimilarityHashType::ImportHash,
736                    hex::encode(imports.hash()),
737                ));
738                hashes.push((
739                    malwaredb_api::SimilarityHashType::FuzzyImportHash,
740                    imports.fuzzy_hash(),
741                ));
742            }
743        }
744
745        let request = malwaredb_api::SimilarSamplesRequest { hashes };
746
747        let response = self
748            .client
749            .post(format!(
750                "{}{}",
751                self.url,
752                malwaredb_api::SIMILAR_SAMPLES_URL
753            ))
754            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
755            .json(&request)
756            .send()?
757            .json::<ServerResponse<SimilarSamplesResponse>>()
758            .context(MDB_CLIENT_ERROR_CONTEXT)?;
759
760        match response {
761            ServerResponse::Success(similar) => Ok(similar),
762            ServerResponse::Error(e) => Err(e.into()),
763        }
764    }
765
766    /// Submit a Yara rule and return the UUID of the search for later retrieval.
767    ///
768    /// # Errors
769    ///
770    /// Network or authentication errors
771    pub fn yara_search(&self, yara: &str) -> Result<YaraSearchRequestResponse> {
772        let yara = YaraSearchRequest {
773            rules: vec![yara.to_string()],
774            response: PartialHashSearchType::SHA256,
775        };
776
777        let response = self
778            .client
779            .post(format!("{}{}", self.url, malwaredb_api::YARA_SEARCH_URL))
780            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
781            .json(&yara)
782            .send()?
783            .json::<ServerResponse<YaraSearchRequestResponse>>()?;
784
785        match response {
786            ServerResponse::Success(similar) => Ok(similar),
787            ServerResponse::Error(e) => Err(e.into()),
788        }
789    }
790
791    /// Get the result from a Yara search
792    ///
793    /// # Errors
794    ///
795    /// Network or authentication errors
796    pub fn yara_result(&self, uuid: Uuid) -> Result<YaraSearchResponse> {
797        let response = self
798            .client
799            .get(format!(
800                "{}{}/{uuid}",
801                self.url,
802                malwaredb_api::YARA_SEARCH_URL
803            ))
804            .header(malwaredb_api::MDB_API_HEADER, &self.api_key)
805            .send()?
806            .json::<ServerResponse<YaraSearchResponse>>()?;
807
808        match response {
809            ServerResponse::Success(sources) => Ok(sources),
810            ServerResponse::Error(e) => Err(e.into()),
811        }
812    }
813}
814
815impl Debug for MdbClient {
816    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
817        use crate::MDB_VERSION;
818
819        writeln!(f, "MDB Client v{MDB_VERSION}: {}", self.url)
820    }
821}
822
823/// Wrapper around search results for iterating over resulting hashes with the blocking client
824///
825/// ```rust,no_run
826/// use malwaredb_client::blocking::{MdbClient, IterableHashSearchResult};
827/// use malwaredb_client::malwaredb_api::{PartialHashSearchType, SearchType};
828///
829/// let client = MdbClient::load().expect("Failed to load client or parse config file");
830///
831/// // Get the first 100 files where the file name contains "foo", returning hashes as SHA-256
832/// let search_result = client.partial_search(None, Some("foo".into()), PartialHashSearchType::SHA256, 100).unwrap();
833/// for hash in IterableHashSearchResult::from(search_result, &client) {
834///     println!("{hash}");
835/// }
836/// ```
837#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
838pub struct IterableHashSearchResult<'a> {
839    /// Server search result
840    pub response: SearchResponse,
841
842    /// Blocking client
843    client: &'a MdbClient,
844}
845
846impl<'a> IterableHashSearchResult<'a> {
847    /// Iterate over the hashes from a search result and blocking client
848    #[must_use]
849    pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
850        Self { response, client }
851    }
852}
853
854impl Iterator for IterableHashSearchResult<'_> {
855    type Item = String;
856
857    fn next(&mut self) -> Option<Self::Item> {
858        if let Some(hash) = self.response.hashes.pop() {
859            Some(hash)
860        } else if let Some(uuid) = self.response.pagination {
861            let request = SearchRequest {
862                search: SearchType::Continuation(uuid),
863            };
864
865            self.response = match self.client.do_search_request(&request) {
866                Ok(response) => response,
867                Err(e) => {
868                    warn!("Failed to continue search: {e}");
869                    return None;
870                }
871            };
872
873            self.response.hashes.pop()
874        } else {
875            None
876        }
877    }
878}
879
880/// Wrapper around search results for iterating over resulting binaries with the blocking client
881#[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
882pub struct IterableSampleSearchResult<'a> {
883    /// Server search result
884    pub response: SearchResponse,
885
886    /// Blocking client
887    client: &'a MdbClient,
888}
889
890impl<'a> IterableSampleSearchResult<'a> {
891    /// Iterate over the hashes from a search result and blocking client, returning the binary
892    #[must_use]
893    pub fn from(response: SearchResponse, client: &'a MdbClient) -> Self {
894        Self { response, client }
895    }
896}
897
898impl Iterator for IterableSampleSearchResult<'_> {
899    type Item = Vec<u8>;
900
901    fn next(&mut self) -> Option<Self::Item> {
902        if let Some(hash) = self.response.hashes.pop() {
903            let binary = match self.client.retrieve(&hash, false) {
904                Ok(binary) => binary,
905                Err(e) => {
906                    error!("Failed to download {hash}: {e}");
907                    return None;
908                }
909            };
910            Some(binary)
911        } else if let Some(uuid) = self.response.pagination {
912            let request = SearchRequest {
913                search: SearchType::Continuation(uuid),
914            };
915
916            self.response = match self.client.do_search_request(&request) {
917                Ok(response) => response,
918                Err(e) => {
919                    warn!("Failed to continue search: {e}");
920                    return None;
921                }
922            };
923
924            if let Some(hash) = self.response.hashes.pop() {
925                let binary = match self.client.retrieve(&hash, false) {
926                    Ok(binary) => binary,
927                    Err(e) => {
928                        error!("Failed to download {hash}: {e}");
929                        return None;
930                    }
931                };
932                Some(binary)
933            } else {
934                None
935            }
936        } else {
937            None
938        }
939    }
940}