malwaredb_server/http/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use super::State;
4use crate::db::types::FileTypes;
5use malwaredb_api::digest::HashType;
6use malwaredb_api::{
7    GetAPIKeyRequest, GetAPIKeyResponse, GetUserInfoResponse, Labels, NewSample, Report,
8    SearchRequest, ServerInfo, Sources, SupportedFileTypes, MDB_API_HEADER,
9};
10
11use std::fmt::{Display, Formatter};
12use std::io::Cursor;
13use std::iter::once;
14use std::sync::Arc;
15
16use axum::body::Bytes;
17use axum::extract::{DefaultBodyLimit, Path};
18use axum::http::{header, StatusCode};
19use axum::response::IntoResponse;
20use axum::routing::{get, post};
21use axum::{Extension, Json, Router};
22use base64::{engine::general_purpose, Engine as _};
23use constcat::concat;
24use http::{HeaderMap, HeaderName};
25use sha2::{Digest, Sha256};
26use tower_http::compression::CompressionLayer;
27use tower_http::decompression::DecompressionLayer;
28use tower_http::limit::RequestBodyLimitLayer;
29use tower_http::sensitive_headers::SetSensitiveHeadersLayer;
30
31mod receive;
32
33/// Build the web service from the initial server state object
34pub fn app(state: Arc<State>) -> Router {
35    // This is an estimate, so err on the side of usability.
36    const UPLOAD_OVERHEAD: usize = std::mem::size_of::<Json<NewSample>>() * 2;
37
38    let compression_layer = CompressionLayer::new()
39        .br(true)
40        .deflate(true)
41        .gzip(true)
42        .zstd(true);
43
44    let decompression_layer = DecompressionLayer::new()
45        .br(true)
46        .deflate(true)
47        .gzip(true)
48        .zstd(true);
49
50    let size_limit_layer = RequestBodyLimitLayer::new(UPLOAD_OVERHEAD + state.max_upload);
51
52    Router::new()
53        .route("/", get(health))
54        .route(malwaredb_api::SERVER_INFO, get(get_mdb_info))
55        .route(malwaredb_api::USER_LOGIN_URL, post(user_login))
56        .route(malwaredb_api::USER_LOGOUT_URL, get(user_logout))
57        .route(malwaredb_api::USER_INFO_URL, get(get_user_groups_sources))
58        .route(
59            malwaredb_api::SUPPORTED_FILE_TYPES,
60            get(get_supported_types),
61        )
62        .route(malwaredb_api::LIST_LABELS, get(get_labels))
63        .route(malwaredb_api::LIST_SOURCES, get(get_sources))
64        .route(malwaredb_api::UPLOAD_SAMPLE, post(get_new_sample))
65        .route(malwaredb_api::SEARCH, post(sample_search))
66        .route(
67            concat!(malwaredb_api::DOWNLOAD_SAMPLE, "/{hash}"),
68            get(download_sample),
69        )
70        .route(
71            concat!(malwaredb_api::DOWNLOAD_SAMPLE_CART, "/{hash}"),
72            get(download_sample_cart),
73        )
74        .route(
75            concat!(malwaredb_api::SAMPLE_REPORT, "/{hash}"),
76            get(sample_report),
77        )
78        .route(malwaredb_api::SIMILAR_SAMPLES, post(find_similar))
79        .layer(DefaultBodyLimit::max(state.max_upload))
80        .layer(compression_layer)
81        .layer(decompression_layer)
82        .layer(size_limit_layer)
83        .layer(SetSensitiveHeadersLayer::new(once(
84            HeaderName::from_static(MDB_API_HEADER),
85        )))
86        .layer(Extension(state))
87}
88
89async fn health() -> StatusCode {
90    StatusCode::OK
91}
92
93async fn get_mdb_info(
94    Extension(state): Extension<Arc<State>>,
95) -> Result<Json<ServerInfo>, HttpError> {
96    let server_info = state.get_info().await?;
97    Ok(Json(server_info))
98}
99
100async fn user_login(
101    Extension(state): Extension<Arc<State>>,
102    Json(payload): Json<GetAPIKeyRequest>,
103) -> Result<Json<GetAPIKeyResponse>, HttpError> {
104    let api_key = state
105        .db_type
106        .authenticate(&payload.user, &payload.password)
107        .await?;
108
109    Ok(Json(GetAPIKeyResponse {
110        key: Some(api_key),
111        message: None,
112    }))
113}
114
115async fn user_logout(
116    Extension(state): Extension<Arc<State>>,
117    headers: HeaderMap,
118) -> Result<StatusCode, HttpError> {
119    let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
120        anyhow::Error::msg("Missing API key"),
121        StatusCode::NOT_ACCEPTABLE,
122    ))?;
123    let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
124    state.db_type.reset_own_api_key(uid).await?;
125    Ok(StatusCode::OK)
126}
127
128async fn get_user_groups_sources(
129    Extension(state): Extension<Arc<State>>,
130    headers: HeaderMap,
131) -> Result<Json<GetUserInfoResponse>, HttpError> {
132    let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
133        anyhow::Error::msg("Missing API key"),
134        StatusCode::NOT_ACCEPTABLE,
135    ))?;
136    let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
137    let groups_sources = state.db_type.get_user_info(uid).await?;
138    Ok(Json(groups_sources))
139}
140
141async fn get_supported_types(
142    Extension(state): Extension<Arc<State>>,
143) -> Result<Json<SupportedFileTypes>, HttpError> {
144    let data_types = state.db_type.get_known_data_types().await?;
145    let file_types = FileTypes(data_types);
146    Ok(Json(file_types.into()))
147}
148
149async fn get_labels(
150    Extension(state): Extension<Arc<State>>,
151    headers: HeaderMap,
152) -> Result<Json<Labels>, HttpError> {
153    let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
154        anyhow::Error::msg("Missing API key"),
155        StatusCode::NOT_ACCEPTABLE,
156    ))?;
157    let _uid = state.db_type.get_uid(key.to_str().unwrap()).await?;
158    let labels = state.db_type.get_labels().await?;
159    Ok(Json(labels))
160}
161
162async fn get_sources(
163    Extension(state): Extension<Arc<State>>,
164    headers: HeaderMap,
165) -> Result<Json<Sources>, HttpError> {
166    let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
167        anyhow::Error::msg("Missing API key"),
168        StatusCode::NOT_ACCEPTABLE,
169    ))?;
170    let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
171    let sources = state.db_type.get_user_sources(uid).await?;
172    Ok(Json(sources))
173}
174
175async fn get_new_sample(
176    Extension(state): Extension<Arc<State>>,
177    headers: HeaderMap,
178    Json(payload): Json<NewSample>,
179) -> Result<StatusCode, HttpError> {
180    let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
181        anyhow::Error::msg("Missing API key"),
182        StatusCode::NOT_ACCEPTABLE,
183    ))?;
184    let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
185
186    let allowed = state
187        .db_type
188        .allowed_user_source(uid, payload.source_id)
189        .await?;
190
191    if !allowed {
192        return Err(HttpError(
193            anyhow::Error::msg("Unauthorized"),
194            StatusCode::UNAUTHORIZED,
195        ));
196    }
197
198    let received_hash = hex::decode(&payload.sha256)?;
199    let bytes = general_purpose::STANDARD.decode(&payload.file_contents_b64)?;
200
201    let mut hasher = Sha256::new();
202    hasher.update(&bytes);
203    let result = hasher.finalize();
204
205    if result[..] != received_hash[..] {
206        return Err(HttpError(
207            anyhow::Error::msg("Hash mismatch"),
208            StatusCode::NOT_ACCEPTABLE,
209        ));
210    }
211
212    receive::incoming_sample(
213        state.clone(),
214        bytes,
215        uid,
216        payload.source_id,
217        payload.file_name,
218    )
219    .await?;
220
221    Ok(StatusCode::OK)
222}
223
224async fn sample_search(
225    Extension(state): Extension<Arc<State>>,
226    headers: HeaderMap,
227    Json(payload): Json<SearchRequest>,
228) -> Result<Json<Vec<String>>, HttpError> {
229    let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
230        anyhow::Error::msg("Missing API key"),
231        StatusCode::NOT_ACCEPTABLE,
232    ))?;
233    let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
234
235    let hashes = state.db_type.partial_search(uid, &payload).await?;
236    Ok(Json(hashes))
237}
238
239async fn download_sample(
240    Path(hash): Path<String>,
241    headers: HeaderMap,
242    Extension(state): Extension<Arc<State>>,
243) -> Result<impl IntoResponse, HttpError> {
244    if state.directory.is_none() {
245        return Err(HttpError(
246            anyhow::Error::msg("Server does not store samples"),
247            StatusCode::NOT_ACCEPTABLE,
248        ));
249    }
250
251    let hash = HashType::try_from(hash)?;
252    let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
253        anyhow::Error::msg("Missing API key"),
254        StatusCode::NOT_ACCEPTABLE,
255    ))?;
256
257    let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
258    let sha256 = state.db_type.retrieve_sample(uid, &hash).await?;
259
260    let contents = state.retrieve_bytes(&sha256).await?;
261
262    let mut bytes = Bytes::from(contents).into_response();
263    let name_header_value = format!("attachment; filename=\"{sha256}\"");
264    bytes.headers_mut().insert(
265        header::CONTENT_DISPOSITION,
266        http::HeaderValue::from_str(&name_header_value)
267            .unwrap_or(http::HeaderValue::from_static("Unknown.bin")),
268    );
269
270    Ok(bytes)
271}
272
273async fn download_sample_cart(
274    Path(hash): Path<String>,
275    headers: HeaderMap,
276    Extension(state): Extension<Arc<State>>,
277) -> Result<impl IntoResponse, HttpError> {
278    if state.directory.is_none() {
279        return Err(HttpError(
280            anyhow::Error::msg("Server does not store samples"),
281            StatusCode::NOT_ACCEPTABLE,
282        ));
283    }
284
285    let hash = HashType::try_from(hash)?;
286    let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
287        anyhow::Error::msg("Missing API key"),
288        StatusCode::NOT_ACCEPTABLE,
289    ))?;
290
291    let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
292    let sha256 = state.db_type.retrieve_sample(uid, &hash).await?;
293    let report = state.db_type.get_sample_report(uid, &hash).await?;
294
295    let contents = state.retrieve_bytes(&sha256).await?;
296    let contents_cursor = Cursor::new(contents);
297    let mut output_cursor = Cursor::new(vec![]);
298
299    let mut output_metadata = cart_container::JsonMap::new();
300    output_metadata.insert("sha384".into(), report.sha384.into());
301    output_metadata.insert("sha512".into(), report.sha512.into());
302    output_metadata.insert("entropy".into(), report.entropy.into());
303    if let Some(filecmd) = report.filecommand {
304        output_metadata.insert("file".into(), filecmd.into());
305    }
306
307    cart_container::pack_stream(
308        contents_cursor,
309        &mut output_cursor,
310        Some(output_metadata),
311        None,
312        cart_container::digesters::default_digesters(), // MD-5, SHA-1, SHA-256 applied here
313        None,
314    )?;
315
316    let mut bytes = Bytes::from(output_cursor.into_inner()).into_response();
317    let name_header_value = format!("attachment; filename=\"{sha256}.cart\"");
318    bytes.headers_mut().insert(
319        header::CONTENT_DISPOSITION,
320        http::HeaderValue::from_str(&name_header_value)
321            .unwrap_or(http::HeaderValue::from_static("Unknown.cart")),
322    );
323
324    Ok(bytes)
325}
326
327async fn sample_report(
328    Path(hash): Path<String>,
329    headers: HeaderMap,
330    Extension(state): Extension<Arc<State>>,
331) -> Result<Json<Report>, HttpError> {
332    let hash = HashType::try_from(hash)?;
333    let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
334        anyhow::Error::msg("Missing API key"),
335        StatusCode::NOT_ACCEPTABLE,
336    ))?;
337
338    let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
339    let report = state.db_type.get_sample_report(uid, &hash).await?;
340    Ok(Json(report))
341}
342
343async fn find_similar(
344    Extension(state): Extension<Arc<State>>,
345    headers: HeaderMap,
346    Json(payload): Json<malwaredb_api::SimilarSamplesRequest>,
347) -> Result<Json<malwaredb_api::SimilarSamplesResponse>, HttpError> {
348    let key = headers.get(malwaredb_api::MDB_API_HEADER).ok_or(HttpError(
349        anyhow::Error::msg("Missing API key"),
350        StatusCode::NOT_ACCEPTABLE,
351    ))?;
352
353    let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
354
355    let results = state
356        .db_type
357        .find_similar_samples(uid, &payload.hashes)
358        .await?;
359
360    Ok(Json(malwaredb_api::SimilarSamplesResponse {
361        results,
362        message: None,
363    }))
364}
365
366// How to use `anyhow::Error` with `axum`:
367// https://github.com/tokio-rs/axum/blob/c97967252de9741b602f400dc2b25c8a33216039/examples/anyhow-error-response/src/main.rs
368
369/// Anyhow wrapper to support Axum error handling
370pub struct HttpError(pub anyhow::Error, pub StatusCode);
371
372/// Convert Anyhow error into an Axum response object
373impl IntoResponse for HttpError {
374    fn into_response(self) -> axum::response::Response {
375        (self.1, format!("MDB error: {}", self.0)).into_response()
376    }
377}
378
379impl Display for HttpError {
380    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
381        write!(f, "{}", self.0)
382    }
383}
384
385/// Enable the use of the ? operator in Axum handler functions
386impl<E> From<E> for HttpError
387where
388    E: Into<anyhow::Error>,
389{
390    fn from(err: E) -> Self {
391        Self(err.into(), StatusCode::INTERNAL_SERVER_ERROR)
392    }
393}
394
#[cfg(test)]
mod tests {
    use malwaredb_client::MdbClient;

    use super::*;
    use crate::crypto::{EncryptionOption, FileEncryption};
    use crate::db::DatabaseType;
    use malwaredb_api::PartialHashSearchType;

    use std::collections::HashMap;
    use std::sync::Once;
    use std::time::SystemTime;
    use std::{env, fs};

    use anyhow::Context;
    use axum::body::Body;
    use axum::http::Request;
    use chrono::Local;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;
    use rstest::rstest;
    use tower::ServiceExt; // for `app.oneshot()`
    use uuid::Uuid;

    const ADMIN_UNAME: &str = "admin";
    const ADMIN_PASSWORD: &str = "password12345";

    static TRACING: Once = Once::new();

    fn init_tracing() {
        tracing_subscriber::fmt()
            .with_max_level(tracing::Level::TRACE)
            .init();
    }

    /// Pick a unique SQLite file path in the temp directory, deleting any stale file found there.
    ///
    /// Each test needs a separate file, or else they'll clobber each other.
    fn fresh_db_path() -> std::path::PathBuf {
        let db_file = env::temp_dir().join(format!("testing_sqlite_{}.db", Uuid::new_v4()));
        if db_file.exists() {
            fs::remove_file(&db_file)
                .context(format!("failed to delete old SQLite file {db_file:?}"))
                .unwrap();
        }
        db_file
    }

    /// Open the SQLite database stored at `db_file`.
    async fn open_db(db_file: &std::path::Path) -> DatabaseType {
        DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
            .await
            .context(format!("failed to create SQLite instance for {db_file:?}"))
            .unwrap()
    }

    /// Path of a freshly created temporary directory for stored samples.
    ///
    /// NOTE(review): the `TempDir` guard is dropped before this returns, which deletes the
    /// directory and keeps only its path — presumably the sample store re-creates it; confirm.
    fn sample_dir() -> std::path::PathBuf {
        tempfile::TempDir::with_prefix("mdb-temp-samples")
            .unwrap()
            .path()
            .into()
    }

    /// Set the admin password, create a source, and add group 0 to it. Returns the source id.
    async fn seed_admin_and_source(state: &State) -> u32 {
        state
            .db_type
            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .context("Failed to set admin password")
            .unwrap();

        let source_id = state
            .db_type
            .create_source("temp-source", None, None, Local::now(), true, Some(false))
            .await
            .unwrap();

        state
            .db_type
            .add_group_to_source(0, source_id)
            .await
            .unwrap();

        source_id
    }

    /// Route `request` through a new app instance built from `state`.
    async fn send(
        state: Arc<State>,
        request: Request<Body>,
        context: &'static str,
    ) -> axum::response::Response {
        app(state).oneshot(request).await.context(context).unwrap()
    }

    /// Collect a response body, ASCII-lowercase every byte, and decode the result as UTF-8.
    async fn lowercase_body(response: axum::response::Response) -> String {
        let raw = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        String::from_utf8(raw.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap()
    }

    async fn state(compress: bool, encrypt: bool) -> (Arc<State>, u32) {
        let db_file = fresh_db_path();
        let db_type = open_db(&db_file).await;
        if compress {
            db_type.enable_compression().await.unwrap();
        }

        let mut keys = HashMap::new();
        if encrypt {
            let key = FileEncryption::from(EncryptionOption::Xor);
            let key_id = db_type.add_file_encryption_key(&key).await.unwrap();
            keys.insert(key_id, key);
        }
        let db_config = db_type.get_config().await.unwrap();

        let state = State {
            port: 8080,
            directory: Some(sample_dir()),
            max_upload: 10 * 1024 * 1024,
            ip: "127.0.0.1".parse().unwrap(),
            db_type,
            db_config,
            keys,
            started: SystemTime::now(),
            #[cfg(feature = "vt")]
            vt_client: None,
            cert: None,
            key: None,
        };

        let source_id = seed_admin_and_source(&state).await;
        (Arc::new(state), source_id)
    }

    /// Get a new state and generate an API token
    async fn state_and_token(pem_file: bool) -> (State, u32, String) {
        let db_file = fresh_db_path();
        let db_type = open_db(&db_file).await;
        let db_config = db_type.get_config().await.unwrap();

        // The PEM and DER variants differ only in port and certificate/key files.
        let (port, cert, key) = if pem_file {
            (
                8443,
                "../../testdata/server_ca_cert.pem",
                "../../testdata/server_key.pem",
            )
        } else {
            (
                8444,
                "../../testdata/server_cert.der",
                "../../testdata/server_key.der",
            )
        };

        let state = State {
            port,
            directory: Some(sample_dir()),
            max_upload: 10 * 1024 * 1024,
            ip: "127.0.0.1".parse().unwrap(),
            db_type,
            db_config,
            keys: HashMap::default(),
            started: SystemTime::now(),
            #[cfg(feature = "vt")]
            vt_client: None,
            cert: Some(cert.into()),
            key: Some(key.into()),
        };

        let source_id = seed_admin_and_source(&state).await;

        let token = state
            .db_type
            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .unwrap();

        (state, source_id, token)
    }

    /// Log in as the admin through the HTTP API and return the issued API key.
    async fn get_key(state: Arc<State>) -> String {
        let key_request = serde_json::to_string(&GetAPIKeyRequest {
            user: ADMIN_UNAME.into(),
            password: ADMIN_PASSWORD.into(),
        })
        .context("Failed to convert API key request to JSON")
        .unwrap();

        let request = Request::builder()
            .method("POST")
            .uri(malwaredb_api::USER_LOGIN_URL)
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(key_request))
            .unwrap();

        let response = send(state, request, "failed to send/receive login request").await;
        assert_eq!(response.status(), StatusCode::OK);

        let login: GetAPIKeyResponse = serde_json::from_str(&lowercase_body(response).await)
            .context("failed to convert json response to object")
            .unwrap();

        let key = login.key.unwrap();
        assert_eq!(key.len(), 64);
        key
    }

    #[tokio::test]
    async fn about_self() {
        let (state, _) = state(false, false).await;
        let api_key = get_key(state.clone()).await;

        let user_info_request = Request::builder()
            .method("GET")
            .uri(malwaredb_api::USER_INFO_URL)
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();
        let response = send(
            state.clone(),
            user_info_request,
            "failed to send/receive login request",
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);

        let info: GetUserInfoResponse = serde_json::from_str(&lowercase_body(response).await)
            .context("failed to convert json response to object")
            .unwrap();

        assert_eq!(info.id, 0);
        assert!(info.is_admin);
        assert!(!info.is_readonly);
        assert_eq!(info.username, "admin");

        // Check labels, should be empty
        let labels_request = Request::builder()
            .method("GET")
            .uri(malwaredb_api::LIST_LABELS)
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();
        let response = send(state, labels_request, "failed to send/receive login request").await;
        assert_eq!(response.status(), StatusCode::OK);

        let labels: Labels = serde_json::from_str(&lowercase_body(response).await)
            .context("failed to convert json response to object")
            .unwrap();

        assert!(labels.is_empty());
    }

    #[rstest]
    #[case::elf_encrypt_cart(include_bytes!("../../../types/testdata/elf/elf_haiku_x86.cart"), false, true, true)]
    #[case::pe32(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), false, false, false)]
    #[case::pdf_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), false, true, false)]
    #[case::rtf(include_bytes!("../../../types/testdata/rtf/hello.rtf"), false, false, false)]
    #[case::elf_compress_encrypt(include_bytes!("../../../types/testdata/elf/elf_haiku_x86"), true, true, false)]
    #[case::pe32_compress(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), true, false, false)]
    #[case::pdf_compress_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), true, true, false)]
    #[case::rtf_compress(include_bytes!("../../../types/testdata/rtf/hello.rtf"), true, false, false)]
    #[tokio::test]
    async fn submit_sample(
        #[case] contents: &[u8],
        #[case] compress: bool,
        #[case] encrypt: bool,
        #[case] cart: bool,
    ) {
        let (state, source_id) = state(compress, encrypt).await;
        let api_key = get_key(state.clone()).await;

        let sha256 = hex::encode(Sha256::digest(contents));

        let upload = serde_json::to_string(&NewSample {
            file_name: "some_sample".into(),
            source_id,
            file_contents_b64: general_purpose::STANDARD.encode(contents),
            sha256: sha256.clone(),
        })
        .context("failed to create upload structure")
        .unwrap();

        let upload_request = Request::builder()
            .method("POST")
            .uri(malwaredb_api::UPLOAD_SAMPLE)
            .header(CONTENT_TYPE, "application/json")
            .header(MDB_API_HEADER, &api_key)
            .body(Body::from(upload))
            .unwrap();
        let response = send(
            state.clone(),
            upload_request,
            "failed to send/receive upload request/response",
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);

        let sha256 = if cart {
            // Use the original file's SHA-256 recorded in the CaRT footer, not the hash of the
            // container, so the stored file can be found on disk and requested again below.
            let mut packed = Cursor::new(contents);
            let mut unpacked = Cursor::new(vec![]);
            let (_, footer) = cart_container::unpack_stream(&mut packed, &mut unpacked, None)
                .expect("failed to decode CaRT file");
            let footer = footer.expect("CaRT should have had a footer");
            let original_sha256 = footer
                .get("sha256")
                .expect("CaRT footer should have had an entry for SHA-256")
                .to_string();
            original_sha256.replace('"', "") // Remove the quotes that get added by JSON
        } else {
            sha256
        };

        let Some(dir) = &state.directory else {
            panic!("Directory was set for the state, but is now `None`");
        };
        let sample_path = dir.join(format!(
            "{}/{}/{}/{}",
            &sha256[0..2],
            &sha256[2..4],
            &sha256[4..6],
            sha256
        ));
        eprintln!("Submitted sample should exist at {sample_path:?}.");
        assert!(sample_path.exists());

        if compress {
            let sample_size_on_disk = sample_path.metadata().unwrap().len();
            eprintln!(
                "Original size: {}, compressed: {}",
                contents.len(),
                sample_size_on_disk
            );
            assert!(sample_size_on_disk < contents.len() as u64);
        }

        let report_request = Request::builder()
            .method("GET")
            .uri(format!("{}/{sha256}", malwaredb_api::SAMPLE_REPORT))
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();
        let response = send(
            state.clone(),
            report_request,
            "failed to send/receive upload request/response",
        )
        .await;

        let report: Report = serde_json::from_str(&lowercase_body(response).await)
            .context("failed to convert json response to object")
            .unwrap();

        assert_eq!(report.sha256, sha256);
        println!("Report: {report}");

        let cart_request = Request::builder()
            .method("GET")
            .uri(format!("{}/{sha256}", malwaredb_api::DOWNLOAD_SAMPLE_CART))
            .header(MDB_API_HEADER, api_key)
            .body(Body::empty())
            .unwrap();
        let response = send(
            state.clone(),
            cart_request,
            "failed to send/receive upload request/response for CaRT",
        )
        .await;

        let cart_bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes()
            .to_vec();

        let (header, _) = cart_container::unpack_stream(
            Cursor::new(cart_bytes),
            Cursor::new(vec![]),
            None,
        )
        .unwrap_or_else(|e| panic!("{e}"));
        let header = header.unwrap();
        assert_eq!(
            header.get("sha384"),
            Some(&serde_json::to_value(report.sha384).unwrap())
        );
    }

    /// Integration test between Malware DB Server & Client. This was under "/tests" in the project
    /// but `cargo hack` would disable needed features, and there wasn't a way to use `cfg` to
    /// put a condition on the test. Additionally, such an integration test required additional
    /// fields of the Server's state to be public or to implement `Default`, which isn't ideal.
    #[rstest]
    #[case::ssl_pem(true)]
    #[case::ssl_der(false)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn client_integration(#[case] pem: bool) {
        TRACING.call_once(init_tracing);

        // This other get state function was needed, or else there were complaints of
        // `state` being moved and unable to be used in `tokio::spawn()`
        let (state, source_id, token) = state_and_token(pem).await;
        let state_port = state.port; // copy before move below
        let base_url = format!("https://127.0.0.1:{state_port}");

        let server = tokio::spawn(async move {
            state
                .serve()
                .await
                .expect("MalwareDB failed to .serve() in tokio::spawn()");
        });
        assert!(!server.is_finished());

        // In case start-up time is needed
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        println!("SemVer of MalwareDB: {:?}", *crate::MDB_VERSION_SEMVER);
        assert_eq!(
            *malwaredb_client::MDB_VERSION_SEMVER,
            *crate::MDB_VERSION_SEMVER,
            "SemVer parsing of MDB version failed"
        );

        let clients = [
            MdbClient::new(
                base_url.clone(),
                token.clone(),
                Some("../../testdata/ca_cert.pem".into()),
            )
            .unwrap(),
            MdbClient::new(
                base_url,
                token,
                Some("../../testdata/ca_cert.pem".into()),
            )
            .unwrap(),
        ];

        let contents = include_bytes!("../../../types/testdata/elf/elf_haiku_x86");
        for (index, mdb_client) in clients.into_iter().enumerate() {
            assert!(mdb_client
                .submit(contents, String::from("elf_haiku_x86"), source_id)
                .await
                .context("failed to upload test file")
                .unwrap());

            let report = mdb_client
                .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
                .await
                .expect("failed to get report for file just submitted");
            assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");

            let similar = mdb_client
                .similar(contents)
                .await
                .expect("failed to query for files similar to what was just submitted");
            assert_eq!(similar.results.len(), 1);

            let search = mdb_client
                .partial_search(
                    Some((PartialHashSearchType::Any, "AAAA".into())),
                    None,
                    PartialHashSearchType::Any,
                    10,
                )
                .await
                .unwrap();
            assert!(search.is_empty());

            // Only the last client resets the (shared) key, so the first one is unaffected.
            if index > 0 {
                mdb_client.reset_key().await.expect("failed to reset key");
            }
        }
        server.abort();
    }
}