1#![doc = include_str!("../README.md")]
4#![deny(clippy::all)]
5#![deny(clippy::pedantic)]
6#![forbid(unsafe_code)]
7
8pub mod cart;
10
11pub mod types;
13
14use crate::types::{
15 DiscoveredServer, Label, SearchResults, ServerInfo, Source, SupportedFileType, UserInfo,
16};
17use malwaredb_client::blocking::MdbClient;
18
19use std::borrow::Cow;
20use std::path::PathBuf;
21
22use anyhow::{anyhow, Result};
23use pyo3::prelude::*;
24
25pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
27
28pub const VERSION: &str = concat!(
29 "v",
30 env!("CARGO_PKG_VERSION"),
31 "-",
32 env!("VERGEN_GIT_SHA"),
33 " ",
34 env!("VERGEN_BUILD_DATE")
35);
36
37#[pyclass(frozen)]
39pub struct MalwareDBClient {
40 inner: MdbClient,
41}
42
43#[pymethods]
44impl MalwareDBClient {
45 #[new]
51 pub fn new() -> PyResult<Self> {
52 Ok(MalwareDBClient {
53 inner: MdbClient::load()?,
54 })
55 }
56
57 #[staticmethod]
64 pub fn login(
65 url: String,
66 username: String,
67 password: String,
68 save: bool,
69 cert_path: Option<PathBuf>,
70 ) -> PyResult<Self> {
71 Ok(MalwareDBClient {
72 inner: MdbClient::login(url, username, password, save, cert_path)?,
73 })
74 }
75
76 #[staticmethod]
82 pub fn connect(url: String, api_key: String, cert_path: Option<PathBuf>) -> PyResult<Self> {
83 Ok(MalwareDBClient {
84 inner: MdbClient::new(url, api_key, cert_path)?,
85 })
86 }
87
88 #[staticmethod]
94 pub fn discover() -> Result<Vec<DiscoveredServer>> {
95 malwaredb_client::discover_servers().map(|s| s.into_iter().map(Into::into).collect())
96 }
97
98 #[staticmethod]
105 pub fn from_file(path: PathBuf) -> Result<Self> {
106 Ok(MalwareDBClient {
107 inner: MdbClient::from_file(path)?,
108 })
109 }
110
111 #[getter]
113 #[must_use]
114 pub fn url(&self) -> String {
115 self.inner.url.clone()
116 }
117
118 #[pyo3(signature = (hash, cart = false))]
125 pub fn get_file_bytes(&self, hash: &str, cart: bool) -> Result<Cow<'_, [u8]>> {
126 self.inner.retrieve(hash, cart).map(Cow::from)
127 }
128
129 pub fn submit_file(
136 &self,
137 contents: Vec<u8>,
138 file_name: String,
139 source_id: u32,
140 ) -> Result<bool> {
141 self.inner.submit(contents, file_name, source_id)
142 }
143
144 #[allow(clippy::too_many_arguments)]
153 #[pyo3(signature = (hash = None, hash_type = "sha256", file_name = None, labels = None, file_type = None, magic = None, response_hash = "sha256", limit = 100))]
154 pub fn search(
155 &self,
156 hash: Option<String>,
157 hash_type: &str,
158 file_name: Option<String>,
159 labels: Option<Vec<String>>,
160 file_type: Option<String>,
161 magic: Option<String>,
162 response_hash: &str,
163 limit: u32,
164 ) -> Result<SearchResults> {
165 let hash_type = hash_type.try_into().map_err(|e: String| anyhow!(e))?;
166 let response_hash = response_hash.try_into().map_err(|e: String| anyhow!(e))?;
167 self.inner
168 .partial_search_labels_type(
169 hash.map(|h| (hash_type, h)),
170 file_name,
171 response_hash,
172 labels,
173 file_type,
174 magic,
175 limit,
176 )
177 .map(Into::into)
178 }
179
180 pub fn get_sources(&self) -> Result<Vec<Source>> {
186 let sources = self
187 .inner
188 .sources()?
189 .sources
190 .into_iter()
191 .map(Into::into)
192 .collect();
193 Ok(sources)
194 }
195
196 pub fn server_info(&self) -> Result<ServerInfo> {
202 Ok(self.inner.server_info()?.into())
203 }
204
205 pub fn get_supported_file_types(&self) -> Result<Vec<SupportedFileType>> {
211 let supported_types = self
212 .inner
213 .supported_types()?
214 .types
215 .into_iter()
216 .map(Into::into)
217 .collect();
218 Ok(supported_types)
219 }
220
221 pub fn whoami(&self) -> Result<UserInfo> {
228 self.inner.whoami().map(Into::into)
229 }
230
231 pub fn labels(&self) -> Result<Vec<Label>> {
238 self.inner
239 .labels()
240 .map(|labels| labels.0.into_iter().map(Into::into).collect())
241 }
242}
243
244#[cfg(not(feature = "rust_lib"))]
247#[pymodule]
248fn malwaredb(m: &Bound<'_, PyModule>) -> PyResult<()> {
249 if let Err(log_error) = pyo3_log::try_init() {
250 eprintln!("Failed to enable logging: {log_error}");
251 }
252
253 m.add_class::<MalwareDBClient>()?;
254 m.add_class::<Label>()?;
255 m.add_class::<ServerInfo>()?;
256 m.add_class::<Source>()?;
257 m.add_class::<SupportedFileType>()?;
258 m.add_class::<UserInfo>()?;
259 cart::register_cart_module(m)?;
260 m.add("__version__", MDB_VERSION)?;
261 m.add("full_version", VERSION)?;
262 Ok(())
263}