malwaredb/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![doc = include_str!("../README.md")]
4#![deny(clippy::all)]
5#![deny(clippy::pedantic)]
6#![forbid(unsafe_code)]
7
8/// `CaRT` file I/O
9pub mod cart;
10
11/// Python wrapper types for some Malware DB API types
12pub mod types;
13
14use std::borrow::Cow;
15use std::path::PathBuf;
16
17use crate::types::{Label, ServerInfo, Source, SupportedFileType, UserInfo};
18use malwaredb_client::blocking::MdbClient;
19
20use anyhow::{anyhow, Result};
21use pyo3::prelude::*;
22
23/// MDB version
24pub const MDB_VERSION: &str = env!("CARGO_PKG_VERSION");
25
26pub const VERSION: &str = concat!(
27    "v",
28    env!("CARGO_PKG_VERSION"),
29    "-",
30    env!("VERGEN_GIT_DESCRIBE"),
31    " ",
32    env!("VERGEN_BUILD_DATE")
33);
34
35/// Malware DB client
36#[pyclass(frozen)]
37pub struct MalwareDBClient {
38    inner: MdbClient,
39}
40
41#[pymethods]
42impl MalwareDBClient {
43    /// Load a configuration from a file if it can be found
44    ///
45    /// # Errors
46    ///
47    /// Returns an error if the configuration file can't be found or isn't valid.
48    #[new]
49    pub fn new() -> PyResult<Self> {
50        Ok(MalwareDBClient {
51            inner: MdbClient::load()?,
52        })
53    }
54
55    /// Login with a username and password
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if the server URL, username, or password were incorrect, or if a network
60    /// issue occurred.
61    #[staticmethod]
62    pub fn login(
63        url: String,
64        username: String,
65        password: String,
66        save: bool,
67        cert_path: Option<PathBuf>,
68    ) -> PyResult<Self> {
69        Ok(MalwareDBClient {
70            inner: MdbClient::login(url, username, password, save, cert_path)?,
71        })
72    }
73
74    /// Connect if an API key is already known
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if a list of certificates was passed and any were not in the expected
79    /// DER or PEM format or could not be parsed.
80    #[staticmethod]
81    pub fn connect(url: String, api_key: String, cert_path: Option<PathBuf>) -> PyResult<Self> {
82        Ok(MalwareDBClient {
83            inner: MdbClient::new(url, api_key, cert_path)?,
84        })
85    }
86
87    /// Connect using a specific configuration file
88    ///
89    /// # Errors
90    ///
91    /// Returns an error if the configuration file cannot be read, possibly because it
92    /// doesn't exist or due to a permission error or a parsing error.
93    #[staticmethod]
94    pub fn from_file(path: PathBuf) -> Result<Self> {
95        Ok(MalwareDBClient {
96            inner: MdbClient::from_file(path)?,
97        })
98    }
99
100    /// Get the server's URL
101    #[getter]
102    #[must_use]
103    pub fn url(&self) -> String {
104        self.inner.url.clone()
105    }
106
107    /// Get the bytes of a sample from the database, optionally as a `CaRT` file.
108    ///
109    /// # Errors
110    ///
111    /// This may return an error if there's a network situation or if the user is not logged in
112    /// or not properly authorized to connect.
113    #[pyo3(signature = (hash, cart = false))]
114    pub fn get_file_bytes(&self, hash: &str, cart: bool) -> Result<Cow<'_, [u8]>> {
115        self.inner.retrieve(hash, cart).map(Cow::from)
116    }
117
118    /// Submit a file to the database, which requires the file name and source ID. Returns true if stored.
119    ///
120    /// # Errors
121    ///
122    /// This may return an error if there's a network situation or if the user is not logged in
123    /// or not properly authorized to connect.
124    pub fn submit_file(
125        &self,
126        contents: Vec<u8>,
127        file_name: String,
128        source_id: u32,
129    ) -> Result<bool> {
130        self.inner.submit(contents, file_name, source_id)
131    }
132
133    /// Search by partial hash and/or partial file name, returning a list of hashes by specified hash type
134    ///
135    /// # Errors
136    ///
137    /// * Invalid hash types will result in an error
138    /// * This may return an error if there's a network situation or if the user is not logged in or the request isn't valid
139    #[pyo3(signature = (hash = None, hash_type="sha256", file_name = None, limit = 100, response_hash = "sha256"))]
140    pub fn partial_search(
141        &self,
142        hash: Option<String>,
143        hash_type: &str,
144        file_name: Option<String>,
145        limit: u32,
146        response_hash: &str,
147    ) -> Result<Vec<String>> {
148        let hash_type = hash_type.try_into().map_err(|e: String| anyhow!(e))?;
149        let response_hash = response_hash.try_into().map_err(|e: String| anyhow!(e))?;
150        self.inner.partial_search(
151            hash.map(|h| (hash_type, h)),
152            file_name,
153            response_hash,
154            limit,
155        )
156    }
157
158    /// Get sources available to the user
159    ///
160    /// # Errors
161    ///
162    /// This may return an error if there's a network situation or if the user is not logged in
163    /// or not properly authorized to connect.
164    pub fn get_sources(&self) -> Result<Vec<Source>> {
165        let sources = self
166            .inner
167            .sources()?
168            .sources
169            .into_iter()
170            .map(Into::into)
171            .collect();
172        Ok(sources)
173    }
174
175    /// Get information about the server
176    ///
177    /// # Errors
178    ///
179    /// This may return an error if there's a network problem or the server is down.
180    pub fn server_info(&self) -> Result<ServerInfo> {
181        Ok(self.inner.server_info()?.into())
182    }
183
184    /// Get supported file types; Malware DB only accepts file types it knows about
185    ///
186    /// # Errors
187    ///
188    /// This may return an error if there's a network problem or the server is down.
189    pub fn get_supported_file_types(&self) -> Result<Vec<SupportedFileType>> {
190        let supported_types = self
191            .inner
192            .supported_types()?
193            .types
194            .into_iter()
195            .map(Into::into)
196            .collect();
197        Ok(supported_types)
198    }
199
200    /// Get information about the user
201    ///
202    /// # Errors
203    ///
204    /// This may return an error if there's a network problem or if the user is not logged in
205    /// or not properly authorized to connect.
206    pub fn whoami(&self) -> Result<UserInfo> {
207        self.inner.whoami().map(Into::into)
208    }
209
210    /// Get labels
211    ///
212    /// # Errors
213    ///
214    /// This may return an error if there's a network problem or if the user is not logged in
215    /// or not properly authorized to connect.
216    pub fn labels(&self) -> Result<Vec<Label>> {
217        self.inner
218            .labels()
219            .map(|labels| labels.0.into_iter().map(Into::into).collect())
220    }
221}
222
223/// Only used by this crate directly to register the module. If this crate is used as a module,
224/// that other crate must register the Rust types with that new Python module.
225#[cfg(not(feature = "rust_lib"))]
226#[pymodule]
227fn malwaredb(m: &Bound<'_, PyModule>) -> PyResult<()> {
228    if let Err(log_error) = pyo3_log::try_init() {
229        eprintln!("Failed to enable logging: {log_error}");
230    }
231
232    m.add_class::<MalwareDBClient>()?;
233    m.add_class::<Label>()?;
234    m.add_class::<ServerInfo>()?;
235    m.add_class::<Source>()?;
236    m.add_class::<SupportedFileType>()?;
237    m.add_class::<UserInfo>()?;
238    cart::register_cart_module(m)?;
239    m.add("__version__", MDB_VERSION)?;
240    m.add("full_version", VERSION)?;
241    Ok(())
242}