malwaredb_server/http/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use super::State;
4use malwaredb_api::digest::HashType;
5use malwaredb_api::{
6    GetAPIKeyRequest, GetAPIKeyResponse, GetUserInfoResponse, Labels, NewSampleB64, NewSampleBytes,
7    Report, SearchRequest, SearchResponse, ServerError, ServerInfo, ServerResponse,
8    SimilarSamplesRequest, SimilarSamplesResponse, Sources, SupportedFileTypes, MDB_API_HEADER,
9};
10
11use std::fmt::{Display, Formatter};
12use std::io::Cursor;
13use std::iter::once;
14use std::sync::Arc;
15
16use axum::body::Bytes;
17use axum::extract::{DefaultBodyLimit, Path, Request};
18use axum::http::{header, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use axum::routing::{get, post};
22use axum::{middleware, Extension, Json, Router};
23use axum_cbor::Cbor;
24use base64::{engine::general_purpose, Engine as _};
25use constcat::concat;
26use http::{HeaderMap, HeaderName, HeaderValue};
27use sha2::{Digest, Sha256};
28use tower_http::compression::CompressionLayer;
29use tower_http::decompression::DecompressionLayer;
30use tower_http::limit::RequestBodyLimitLayer;
31use tower_http::sensitive_headers::SetSensitiveHeadersLayer;
32
33mod receive;
34
35const FAVICON_URL: &str = "/favicon.ico";
36
37/// Build the web service from the initial server state object
38pub fn app(state: Arc<State>) -> Router {
39    // This is an estimate, so err on the side of usability.
40    const UPLOAD_OVERHEAD: usize = std::mem::size_of::<Json<NewSampleB64>>() * 2;
41
42    let compression_layer = CompressionLayer::new()
43        .br(true)
44        .deflate(true)
45        .gzip(true)
46        .zstd(true);
47
48    let decompression_layer = DecompressionLayer::new()
49        .br(true)
50        .deflate(true)
51        .gzip(true)
52        .zstd(true);
53
54    let size_limit_layer = RequestBodyLimitLayer::new(UPLOAD_OVERHEAD + state.max_upload);
55
56    Router::new()
57        .route("/", get(health))
58        .route(FAVICON_URL, get(favicon))
59        .route(malwaredb_api::SERVER_INFO_URL, get(get_mdb_info))
60        .route(malwaredb_api::USER_LOGIN_URL, post(user_login))
61        .route(malwaredb_api::USER_LOGOUT_URL, get(user_logout))
62        .route(malwaredb_api::USER_INFO_URL, get(get_user_groups_sources))
63        .route(
64            malwaredb_api::SUPPORTED_FILE_TYPES_URL,
65            get(get_supported_types),
66        )
67        .route(malwaredb_api::LIST_LABELS_URL, get(get_labels))
68        .route(malwaredb_api::LIST_SOURCES_URL, get(get_sources))
69        .route(
70            malwaredb_api::UPLOAD_SAMPLE_JSON_URL,
71            post(upload_new_sample_json),
72        )
73        .route(
74            malwaredb_api::UPLOAD_SAMPLE_CBOR_URL,
75            post(upload_new_sample_cbor),
76        )
77        .route(malwaredb_api::SEARCH_URL, post(sample_search))
78        .route(
79            concat!(malwaredb_api::DOWNLOAD_SAMPLE_URL, "/{hash}"),
80            get(download_sample),
81        )
82        .route(
83            concat!(malwaredb_api::DOWNLOAD_SAMPLE_CART_URL, "/{hash}"),
84            get(download_sample_cart),
85        )
86        .route(
87            concat!(malwaredb_api::SAMPLE_REPORT_URL, "/{hash}"),
88            get(sample_report),
89        )
90        .route(malwaredb_api::SIMILAR_SAMPLES_URL, post(find_similar))
91        .layer(DefaultBodyLimit::max(state.max_upload))
92        .layer(compression_layer)
93        .layer(decompression_layer)
94        .layer(size_limit_layer)
95        .layer(SetSensitiveHeadersLayer::new(once(
96            HeaderName::from_static(MDB_API_HEADER),
97        )))
98        .route_layer(middleware::from_fn(response_header_middleware))
99        .layer(Extension(state))
100}
101
/// User ID from the API key, generated by the middleware function
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct UserInfo {
    /// Database ID of the user who owns the API key sent with the request.
    pub id: u32,
}
106
107/// * Ask that content from Malware DB not be cached
108/// * One central place to get the API key and get the user's ID
109async fn response_header_middleware(
110    Extension(state): Extension<Arc<State>>,
111    headers: HeaderMap,
112    mut req: Request,
113    next: Next,
114) -> Result<Response, HttpError> {
115    const ALWAYS_ALLOWED_ENDPOINTS: [&str; 4] = [
116        "/",
117        FAVICON_URL,
118        malwaredb_api::SERVER_INFO_URL,
119        malwaredb_api::USER_LOGIN_URL,
120    ];
121
122    if !ALWAYS_ALLOWED_ENDPOINTS.contains(&req.uri().path()) {
123        let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
124            ServerError::Unauthorized,
125            StatusCode::NOT_ACCEPTABLE,
126        ))?;
127        let key = key
128            .to_str()
129            .map_err(|_| HttpError(ServerError::Unauthorized, StatusCode::NOT_ACCEPTABLE))?;
130        let uid = state.db_type.get_uid(key).await.map_err(|e| {
131            tracing::warn!("Failed to get user ID from API key: {e}");
132            HttpError(ServerError::Unauthorized, StatusCode::UNAUTHORIZED)
133        })?;
134        req.extensions_mut().insert(Arc::new(UserInfo { id: uid }));
135    }
136
137    let mut response = next.run(req).await; // Run the next service in the chain
138
139    // `no-store` so an intermediate proxy doesn't save any data
140    // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Cache-Control
141    response
142        .headers_mut()
143        .insert(header::CACHE_CONTROL, HeaderValue::from_static("no-store"));
144
145    Ok(response)
146}
147
148async fn health() -> StatusCode {
149    StatusCode::OK
150}
151
152async fn favicon() -> Response {
153    const ICON: Bytes = Bytes::from_static(include_bytes!("../../MDB_Logo.ico"));
154
155    let mut bytes = ICON.into_response();
156    bytes.headers_mut().insert(
157        header::CONTENT_TYPE,
158        HeaderValue::from_static("image/vnd.microsoft.icon"),
159    );
160
161    bytes
162}
163
164async fn get_mdb_info(
165    Extension(state): Extension<Arc<State>>,
166) -> Result<Json<ServerResponse<ServerInfo>>, HttpError> {
167    let server_info = ServerResponse::Success(state.get_info().await?);
168    Ok(Json(server_info))
169}
170
171async fn user_login(
172    Extension(state): Extension<Arc<State>>,
173    Json(payload): Json<GetAPIKeyRequest>,
174) -> Result<Json<ServerResponse<GetAPIKeyResponse>>, HttpError> {
175    let api_key = state
176        .db_type
177        .authenticate(&payload.user, &payload.password)
178        .await?;
179
180    Ok(Json(ServerResponse::Success(GetAPIKeyResponse {
181        key: api_key,
182        message: None,
183    })))
184}
185
186async fn user_logout(
187    Extension(state): Extension<Arc<State>>,
188    Extension(user): Extension<Arc<UserInfo>>,
189) -> Result<StatusCode, HttpError> {
190    state.db_type.reset_own_api_key(user.id).await?;
191    Ok(StatusCode::OK)
192}
193
194async fn get_user_groups_sources(
195    Extension(state): Extension<Arc<State>>,
196    Extension(user): Extension<Arc<UserInfo>>,
197) -> Result<Json<ServerResponse<GetUserInfoResponse>>, HttpError> {
198    let groups_sources = ServerResponse::Success(state.db_type.get_user_info(user.id).await?);
199    Ok(Json(groups_sources))
200}
201
202async fn get_supported_types(
203    Extension(state): Extension<Arc<State>>,
204) -> Result<Json<ServerResponse<SupportedFileTypes>>, HttpError> {
205    let data_types = state.db_type.get_known_data_types().await?;
206    let file_types = ServerResponse::Success(SupportedFileTypes {
207        types: data_types.into_iter().map(Into::into).collect(),
208        message: None,
209    });
210    Ok(Json(file_types))
211}
212
213async fn get_labels(
214    Extension(state): Extension<Arc<State>>,
215) -> Result<Json<ServerResponse<Labels>>, HttpError> {
216    let labels = ServerResponse::Success(state.db_type.get_labels().await?);
217    Ok(Json(labels))
218}
219
220async fn get_sources(
221    Extension(state): Extension<Arc<State>>,
222    Extension(user): Extension<Arc<UserInfo>>,
223) -> Result<Json<ServerResponse<Sources>>, HttpError> {
224    let sources = ServerResponse::Success(state.db_type.get_user_sources(user.id).await?);
225    Ok(Json(sources))
226}
227
228async fn upload_new_sample_json(
229    Extension(state): Extension<Arc<State>>,
230    Extension(user): Extension<Arc<UserInfo>>,
231    Json(payload): Json<NewSampleB64>,
232) -> Result<StatusCode, HttpError> {
233    let allowed = state
234        .db_type
235        .allowed_user_source(user.id, payload.source_id)
236        .await?;
237
238    if !allowed {
239        return Err(HttpError(
240            ServerError::Unauthorized,
241            StatusCode::UNAUTHORIZED,
242        ));
243    }
244
245    let received_hash = hex::decode(&payload.sha256)?;
246    let bytes = general_purpose::STANDARD.decode(&payload.file_contents_b64)?;
247
248    let mut hasher = Sha256::new();
249    hasher.update(&bytes);
250    let result = hasher.finalize();
251
252    if result[..] != received_hash[..] {
253        return Err(HttpError(
254            ServerError::Unauthorized,
255            StatusCode::NOT_ACCEPTABLE,
256        ));
257    }
258
259    receive::incoming_sample(
260        state.clone(),
261        bytes,
262        user.id,
263        payload.source_id,
264        payload.file_name,
265    )
266    .await?;
267
268    Ok(StatusCode::OK)
269}
270
271async fn upload_new_sample_cbor(
272    Extension(state): Extension<Arc<State>>,
273    Extension(user): Extension<Arc<UserInfo>>,
274    Cbor(payload): Cbor<NewSampleBytes>,
275) -> Result<StatusCode, HttpError> {
276    let allowed = state
277        .db_type
278        .allowed_user_source(user.id, payload.source_id)
279        .await?;
280
281    if !allowed {
282        return Err(HttpError(
283            ServerError::Unauthorized,
284            StatusCode::UNAUTHORIZED,
285        ));
286    }
287
288    let received_hash = hex::decode(&payload.sha256)?;
289    let bytes = payload.file_contents;
290    let mut hasher = Sha256::new();
291    hasher.update(&bytes);
292    let result = hasher.finalize();
293
294    if result[..] != received_hash[..] {
295        return Err(HttpError(
296            ServerError::Unauthorized,
297            StatusCode::NOT_ACCEPTABLE,
298        ));
299    }
300
301    receive::incoming_sample(
302        state.clone(),
303        bytes,
304        user.id,
305        payload.source_id,
306        payload.file_name,
307    )
308    .await?;
309
310    Ok(StatusCode::OK)
311}
312
313async fn sample_search(
314    Extension(state): Extension<Arc<State>>,
315    Extension(user): Extension<Arc<UserInfo>>,
316    Json(payload): Json<SearchRequest>,
317) -> Result<Json<ServerResponse<SearchResponse>>, HttpError> {
318    let hashes = ServerResponse::Success(state.db_type.partial_search(user.id, payload).await?);
319    Ok(Json(hashes))
320}
321
322async fn download_sample(
323    Path(hash): Path<String>,
324    Extension(user): Extension<Arc<UserInfo>>,
325    Extension(state): Extension<Arc<State>>,
326) -> Result<Response, HttpError> {
327    if state.directory.is_none() {
328        return Err(HttpError(
329            ServerError::NoSamples,
330            StatusCode::NOT_ACCEPTABLE,
331        ));
332    }
333
334    let hash = HashType::try_from(hash)?;
335    let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
336
337    let contents = state.retrieve_bytes(&sha256).await?;
338
339    let mut bytes = Bytes::from(contents).into_response();
340    let name_header_value = format!("attachment; filename=\"{sha256}\"");
341    bytes.headers_mut().insert(
342        header::CONTENT_DISPOSITION,
343        HeaderValue::from_str(&name_header_value)
344            .unwrap_or(HeaderValue::from_static("Unknown.bin")),
345    );
346
347    Ok(bytes)
348}
349
350async fn download_sample_cart(
351    Path(hash): Path<String>,
352    Extension(user): Extension<Arc<UserInfo>>,
353    Extension(state): Extension<Arc<State>>,
354) -> Result<Response, HttpError> {
355    if state.directory.is_none() {
356        return Err(HttpError(
357            ServerError::NoSamples,
358            StatusCode::NOT_ACCEPTABLE,
359        ));
360    }
361
362    let hash = HashType::try_from(hash)?;
363    let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
364    let report = state.db_type.get_sample_report(user.id, &hash).await?;
365
366    let contents = state.retrieve_bytes(&sha256).await?;
367    let contents_cursor = Cursor::new(contents);
368    let mut output_cursor = Cursor::new(vec![]);
369
370    let mut output_metadata = cart_container::JsonMap::new();
371    output_metadata.insert("sha384".into(), report.sha384.into());
372    output_metadata.insert("sha512".into(), report.sha512.into());
373    output_metadata.insert("entropy".into(), report.entropy.into());
374    if let Some(filecmd) = report.filecommand {
375        output_metadata.insert("file".into(), filecmd.into());
376    }
377
378    cart_container::pack_stream(
379        contents_cursor,
380        &mut output_cursor,
381        Some(output_metadata),
382        None,
383        cart_container::digesters::default_digesters(), // MD-5, SHA-1, SHA-256 applied here
384        None,
385    )?;
386
387    let mut bytes = Bytes::from(output_cursor.into_inner()).into_response();
388    let name_header_value = format!("attachment; filename=\"{sha256}.cart\"");
389    bytes.headers_mut().insert(
390        header::CONTENT_DISPOSITION,
391        HeaderValue::from_str(&name_header_value)
392            .unwrap_or(HeaderValue::from_static("Unknown.cart")),
393    );
394
395    Ok(bytes)
396}
397
398async fn sample_report(
399    Path(hash): Path<String>,
400    Extension(user): Extension<Arc<UserInfo>>,
401    Extension(state): Extension<Arc<State>>,
402) -> Result<Json<ServerResponse<Report>>, HttpError> {
403    let hash = HashType::try_from(hash)?;
404    let report =
405        ServerResponse::<Report>::Success(state.db_type.get_sample_report(user.id, &hash).await?);
406    Ok(Json(report))
407}
408
409async fn find_similar(
410    Extension(state): Extension<Arc<State>>,
411    Extension(user): Extension<Arc<UserInfo>>,
412    Json(payload): Json<SimilarSamplesRequest>,
413) -> Result<Json<ServerResponse<SimilarSamplesResponse>>, HttpError> {
414    let results = state
415        .db_type
416        .find_similar_samples(user.id, &payload.hashes)
417        .await?;
418
419    Ok(Json(ServerResponse::Success(SimilarSamplesResponse {
420        results,
421        message: None,
422    })))
423}
424
425// How to use `anyhow::Error` with `axum`:
426// https://github.com/tokio-rs/axum/blob/c97967252de9741b602f400dc2b25c8a33216039/examples/anyhow-error-response/src/main.rs
427
428/// Custom error type to support Axum error handling
429pub struct HttpError(pub ServerError, pub StatusCode);
430
431/// Convert Anyhow error into an Axum response object
432impl IntoResponse for HttpError {
433    fn into_response(self) -> Response {
434        let response: ServerResponse<String> = ServerResponse::Error(self.0);
435        match serde_json::to_string(&response) {
436            Ok(json) => (self.1, json).into_response(),
437            Err(_) => self.1.into_response(),
438        }
439    }
440}
441
442impl Display for HttpError {
443    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
444        write!(f, "{}", self.0)
445    }
446}
447
448/// Enable the use of the ? operator in Axum handler functions
449impl<E> From<E> for HttpError
450where
451    E: Into<anyhow::Error>,
452{
453    fn from(_err: E) -> Self {
454        Self(ServerError::ServerError, StatusCode::INTERNAL_SERVER_ERROR)
455    }
456}
457
458#[cfg(test)]
459mod tests {
460    use super::*;
461    use crate::crypto::{EncryptionOption, FileEncryption};
462    use crate::db::DatabaseType;
463    use malwaredb_api::PartialHashSearchType;
464
465    use std::collections::HashMap;
466    use std::sync::{Once, RwLock};
467    use std::time::{Instant, SystemTime};
468    use std::{env, fs};
469
470    use anyhow::Context;
471    use axum::body::Body;
472    use axum::http::Request;
473    use chrono::Local;
474    use http::header::CONTENT_TYPE;
475    use http_body_util::BodyExt;
476    use rstest::rstest;
477    use tower::ServiceExt;
478    // for `app.oneshot()`
479    use uuid::Uuid;
480
481    const ADMIN_UNAME: &str = "admin";
482    const ADMIN_PASSWORD: &str = "password12345";
483    const SAMPLE_BYTES: &[u8] = include_bytes!("../../../types/testdata/elf/elf_haiku_x86");
484
485    static TRACING: Once = Once::new();
486
487    fn init_tracing() {
488        tracing_subscriber::fmt()
489            .with_max_level(tracing::Level::TRACE)
490            .init();
491    }
492
493    async fn state(compress: bool, encrypt: bool) -> (Arc<State>, u32) {
494        // Each test needs a separate file, or else they'll clobber each other.
495        let mut db_file = env::temp_dir();
496        db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
497        if std::path::Path::new(&db_file).exists() {
498            fs::remove_file(&db_file)
499                .context(format!("failed to delete old SQLite file {db_file:?}"))
500                .unwrap();
501        }
502
503        let db_type =
504            DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
505                .await
506                .context(format!("failed to create SQLite instance for {db_file:?}"))
507                .unwrap();
508        if compress {
509            db_type.enable_compression().await.unwrap();
510        }
511        let keys = if encrypt {
512            let key = FileEncryption::from(EncryptionOption::Xor);
513            let key_id = db_type.add_file_encryption_key(&key).await.unwrap();
514            let mut keys = HashMap::new();
515            keys.insert(key_id, key);
516            keys
517        } else {
518            HashMap::new()
519        };
520        let db_config = db_type.get_config().await.unwrap();
521
522        let state = State {
523            port: 8080,
524            directory: Some(
525                tempfile::TempDir::with_prefix("mdb-temp-samples")
526                    .unwrap()
527                    .path()
528                    .into(),
529            ),
530            max_upload: 10 * 1024 * 1024,
531            ip: "127.0.0.1".parse().unwrap(),
532            db_type: Arc::new(db_type),
533            db_config,
534            keys,
535            started: SystemTime::now(),
536            #[cfg(feature = "vt")]
537            vt_client: None,
538            cert: None,
539            key: None,
540            mdns: false,
541        };
542
543        state
544            .db_type
545            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
546            .await
547            .context("Failed to set admin password")
548            .unwrap();
549
550        let source_id = state
551            .db_type
552            .create_source("temp-source", None, None, Local::now(), true, Some(false))
553            .await
554            .unwrap();
555
556        state
557            .db_type
558            .add_group_to_source(0, source_id)
559            .await
560            .unwrap();
561
562        (Arc::new(state), source_id)
563    }
564
565    /// Get a new state and generate an API token
566    async fn state_and_token(pem_file: bool) -> (State, u32, String) {
567        let mut db_file = env::temp_dir();
568        db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
569        if std::path::Path::new(&db_file).exists() {
570            fs::remove_file(&db_file)
571                .context(format!("failed to delete old SQLite file {db_file:?}"))
572                .unwrap();
573        }
574
575        let db_type =
576            DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
577                .await
578                .context(format!("failed to create SQLite instance for {db_file:?}"))
579                .unwrap();
580
581        let db_config = db_type.get_config().await.unwrap();
582
583        let state = if pem_file {
584            State {
585                port: 8443,
586                directory: Some(
587                    tempfile::TempDir::with_prefix("mdb-temp-samples")
588                        .unwrap()
589                        .path()
590                        .into(),
591                ),
592                max_upload: 10 * 1024 * 1024,
593                ip: "127.0.0.1".parse().unwrap(),
594                db_type: Arc::new(db_type),
595                db_config,
596                keys: HashMap::default(),
597                started: SystemTime::now(),
598                #[cfg(feature = "vt")]
599                vt_client: None,
600                cert: Some("../../testdata/server_ca_cert.pem".into()),
601                key: Some("../../testdata/server_key.pem".into()),
602                mdns: false,
603            }
604        } else {
605            State {
606                port: 8444,
607                directory: Some(
608                    tempfile::TempDir::with_prefix("mdb-temp-samples")
609                        .unwrap()
610                        .path()
611                        .into(),
612                ),
613                max_upload: 10 * 1024 * 1024,
614                ip: "127.0.0.1".parse().unwrap(),
615                db_type: Arc::new(db_type),
616                db_config,
617                keys: HashMap::default(),
618                started: SystemTime::now(),
619                #[cfg(feature = "vt")]
620                vt_client: None,
621                cert: Some("../../testdata/server_cert.der".into()),
622                key: Some("../../testdata/server_key.der".into()),
623                mdns: false,
624            }
625        };
626
627        state
628            .db_type
629            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
630            .await
631            .context("Failed to set admin password")
632            .unwrap();
633
634        let source_id = state
635            .db_type
636            .create_source("temp-source", None, None, Local::now(), true, Some(false))
637            .await
638            .unwrap();
639
640        state
641            .db_type
642            .add_group_to_source(0, source_id)
643            .await
644            .unwrap();
645
646        let token = state
647            .db_type
648            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
649            .await
650            .unwrap();
651
652        (state, source_id, token)
653    }
654
655    async fn get_key(state: Arc<State>) -> String {
656        let key_request = serde_json::to_string(&GetAPIKeyRequest {
657            user: ADMIN_UNAME.into(),
658            password: ADMIN_PASSWORD.into(),
659        })
660        .context("Failed to convert API key request to JSON")
661        .unwrap();
662
663        let request = Request::builder()
664            .method("POST")
665            .uri(malwaredb_api::USER_LOGIN_URL)
666            .header(CONTENT_TYPE, "application/json")
667            .body(Body::from(key_request))
668            .unwrap();
669
670        let response = app(state)
671            .oneshot(request)
672            .await
673            .context("failed to send/receive login request")
674            .unwrap();
675
676        assert_eq!(response.status(), StatusCode::OK);
677        let bytes = response
678            .into_body()
679            .collect()
680            .await
681            .expect("failed to collect response body to bytes")
682            .to_bytes();
683        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
684            .context("failed to convert response to string")
685            .unwrap();
686
687        let response: ServerResponse<GetAPIKeyResponse> = serde_json::from_str(&json_response)
688            .context("failed to convert json response to object")
689            .unwrap();
690
691        if let ServerResponse::Success(response) = response {
692            let key = response.key.clone();
693            assert_eq!(key.len(), 64);
694
695            key
696        } else {
697            panic!("failed to get API key response")
698        }
699    }
700
701    #[tokio::test]
702    async fn about_self() {
703        let (state, _) = state(false, false).await;
704        let api_key = get_key(state.clone()).await;
705
706        let request = Request::builder()
707            .method("GET")
708            .uri(malwaredb_api::USER_INFO_URL)
709            .header(MDB_API_HEADER, &api_key)
710            .body(Body::empty())
711            .unwrap();
712
713        let response = app(state.clone())
714            .oneshot(request)
715            .await
716            .context("failed to send/receive login request")
717            .unwrap();
718
719        assert_eq!(response.status(), StatusCode::OK);
720        let bytes = response
721            .into_body()
722            .collect()
723            .await
724            .expect("failed to collect response body to bytes")
725            .to_bytes();
726        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
727            .context("failed to convert response to string")
728            .unwrap();
729
730        let response: ServerResponse<GetUserInfoResponse> = serde_json::from_str(&json_response)
731            .context("failed to convert json response to object")
732            .unwrap();
733
734        let response = response.unwrap();
735        assert_eq!(response.id, 0);
736        assert!(response.is_admin);
737        assert!(!response.is_readonly);
738        assert_eq!(response.username, "admin");
739
740        // Check labels, should be empty
741        let request = Request::builder()
742            .method("GET")
743            .uri(malwaredb_api::LIST_LABELS_URL)
744            .header(MDB_API_HEADER, &api_key)
745            .body(Body::empty())
746            .unwrap();
747
748        let response = app(state)
749            .oneshot(request)
750            .await
751            .context("failed to send/receive login request")
752            .unwrap();
753
754        assert_eq!(response.status(), StatusCode::OK);
755        let bytes = response
756            .into_body()
757            .collect()
758            .await
759            .expect("failed to collect response body to bytes")
760            .to_bytes();
761        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
762            .context("failed to convert response to string")
763            .unwrap();
764
765        let response: ServerResponse<Labels> = serde_json::from_str(&json_response)
766            .context("failed to convert json response to object")
767            .unwrap();
768
769        let response = response.unwrap();
770        assert!(response.is_empty());
771    }
772
773    #[rstest]
774    #[case::elf_encrypt_cart(include_bytes!("../../../types/testdata/elf/elf_haiku_x86.cart"), false, true, true, false)]
775    #[case::pe32(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), false, false, false, false)]
776    #[case::pdf_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), false, true, false, false)]
777    #[case::rtf(include_bytes!("../../../types/testdata/rtf/hello.rtf"), false, false, false, false)]
778    #[case::elf_compress_encrypt(include_bytes!("../../../types/testdata/elf/elf_haiku_x86"), true, true, false, false)]
779    #[case::pe32_compress(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), true, false, false, false)]
780    #[case::pdf_compress_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), true, true, false, false)]
781    #[case::rtf_compress(include_bytes!("../../../types/testdata/rtf/hello.rtf"), true, false, false, false)]
782    #[case::icon_unknown_type_proxy(include_bytes!("../../../../MDB_Logo.ico"), false, false, false, true)]
783    #[tokio::test]
784    async fn submit_sample(
785        #[case] contents: &[u8],
786        #[case] compress: bool,
787        #[case] encrypt: bool,
788        #[case] cart: bool,
789        #[case] should_fail: bool,
790    ) {
791        let (state, source_id) = state(compress, encrypt).await;
792        let api_key = get_key(state.clone()).await;
793
794        let file_contents_b64 = general_purpose::STANDARD.encode(contents);
795        let mut hasher = Sha256::new();
796        hasher.update(contents);
797        let sha256 = hex::encode(hasher.finalize());
798
799        let upload = serde_json::to_string(&NewSampleB64 {
800            file_name: "some_sample".into(),
801            source_id,
802            file_contents_b64,
803            sha256: sha256.clone(),
804        })
805        .context("failed to create upload structure")
806        .unwrap();
807
808        let request = Request::builder()
809            .method("POST")
810            .uri(malwaredb_api::UPLOAD_SAMPLE_JSON_URL)
811            .header(CONTENT_TYPE, "application/json")
812            .header(MDB_API_HEADER, &api_key)
813            .body(Body::from(upload))
814            .unwrap();
815
816        let response = app(state.clone())
817            .oneshot(request)
818            .await
819            .context("failed to send/receive upload request/response")
820            .unwrap();
821
822        if should_fail {
823            assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
824            return;
825        }
826
827        assert_eq!(response.status(), StatusCode::OK);
828
829        let sha256 = if cart {
830            // We have to get the original SHA-256 hash, not the hash of the CaRT container.
831            // This way we can check that the file exists on disk, and request the file again later in this test.
832            let mut input_buffer = Cursor::new(contents);
833            let mut output_buffer = Cursor::new(vec![]);
834            let (_, footer) =
835                cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)
836                    .expect("failed to decode CaRT file");
837            let footer = footer.expect("CaRT should have had a footer");
838            let sha256 = footer
839                .get("sha256")
840                .expect("CaRT footer should have had an entry for SHA-256")
841                .to_string();
842            sha256.replace('"', "") // Remove the quotes that get added by JSON
843        } else {
844            sha256
845        };
846
847        if let Some(dir) = &state.directory {
848            let mut sample_path = dir.clone();
849            sample_path.push(format!(
850                "{}/{}/{}/{}",
851                &sha256[0..2],
852                &sha256[2..4],
853                &sha256[4..6],
854                sha256
855            ));
856            eprintln!("Submitted sample should exist at {sample_path:?}.");
857            assert!(sample_path.exists());
858
859            if compress {
860                let sample_size_on_disk = sample_path.metadata().unwrap().len();
861                eprintln!(
862                    "Original size: {}, compressed: {}",
863                    contents.len(),
864                    sample_size_on_disk
865                );
866                assert!(sample_size_on_disk < contents.len() as u64);
867            }
868        } else {
869            panic!("Directory was set for the state, but is now `None`");
870        }
871
872        let request = Request::builder()
873            .method("GET")
874            .uri(format!("{}/{sha256}", malwaredb_api::SAMPLE_REPORT_URL))
875            .header(MDB_API_HEADER, &api_key)
876            .body(Body::empty())
877            .unwrap();
878
879        let response = app(state.clone())
880            .oneshot(request)
881            .await
882            .context("failed to send/receive upload request/response")
883            .unwrap();
884
885        let bytes = response
886            .into_body()
887            .collect()
888            .await
889            .expect("failed to collect response body to bytes")
890            .to_bytes();
891        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
892            .context("failed to convert response to string")
893            .unwrap();
894
895        let report: ServerResponse<Report> = serde_json::from_str(&json_response)
896            .context("failed to convert json response to object")
897            .unwrap();
898        let report = report.unwrap();
899
900        assert_eq!(report.sha256, sha256);
901        println!("Report: {report}");
902
903        let request = Request::builder()
904            .method("GET")
905            .uri(format!(
906                "{}/{sha256}",
907                malwaredb_api::DOWNLOAD_SAMPLE_CART_URL
908            ))
909            .header(MDB_API_HEADER, api_key)
910            .body(Body::empty())
911            .unwrap();
912
913        let response = app(state.clone())
914            .oneshot(request)
915            .await
916            .context("failed to send/receive upload request/response for CaRT")
917            .unwrap();
918
919        let bytes = response
920            .into_body()
921            .collect()
922            .await
923            .expect("failed to collect response body to bytes")
924            .to_bytes();
925
926        let bytes = bytes.to_vec();
927        let bytes_input = Cursor::new(bytes);
928        let output = Cursor::new(vec![]);
929        match cart_container::unpack_stream(bytes_input, output, None) {
930            Ok((header, _)) => {
931                let header = header.unwrap();
932                assert_eq!(
933                    header.get("sha384"),
934                    Some(&serde_json::to_value(report.sha384).unwrap())
935                );
936            }
937            Err(e) => panic!("{e}"),
938        }
939    }
940
941    /// Integration test between Malware DB Server & Client. This was under "/tests" in the project
942    /// but `cargo hack` would disable needed features, and there wasn't a way to use `cfg` to
943    /// put a condition on the test. Additionally, such an integration test required additional
944    /// fields of the Server's state to be public or to implement `Default`, which isn't ideal.
945    #[rstest]
946    #[case::ssl_pem(true)]
947    #[case::ssl_der(false)]
948    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
949    async fn client_integration(#[case] pem: bool) {
950        TRACING.call_once(init_tracing);
951
952        // This other get state function was needed, or else there were complaints of
953        // `state` being moved and unable to be used in `tokio::spawn()`
954        let (state, source_id, token) = state_and_token(pem).await;
955        let state_port = state.port; // copy before move below
956
957        let server = tokio::spawn(async move {
958            state
959                .serve()
960                .await
961                .expect("MalwareDB failed to .serve() in tokio::spawn()");
962        });
963        assert!(!server.is_finished());
964
965        // In case start-up time is needed
966        tokio::time::sleep(std::time::Duration::new(1, 0)).await;
967
968        println!("SemVer of MalwareDB: {:?}", *crate::MDB_VERSION_SEMVER);
969        assert_eq!(
970            *malwaredb_client::MDB_VERSION_SEMVER,
971            *crate::MDB_VERSION_SEMVER,
972            "SemVer parsing of MDB version failed"
973        );
974
975        let mdb_client = malwaredb_client::MdbClient::new(
976            format!("https://127.0.0.1:{state_port}"),
977            token.clone(),
978            Some("../../testdata/ca_cert.pem".into()),
979        )
980        .unwrap();
981
982        assert!(!mdb_client.supported_types().await.unwrap().types.is_empty());
983        assert!(mdb_client.server_info().await.is_ok());
984
985        let start = Instant::now();
986        assert!(mdb_client
987            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
988            .await
989            .context("failed to upload test file")
990            .unwrap());
991        let duration = start.elapsed();
992        println!("Initial upload and database record creation via base64 took {duration:?}");
993
994        let start = Instant::now();
995        mdb_client
996            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
997            .await
998            .context("failed to upload test file")
999            .unwrap();
1000        let duration = start.elapsed();
1001        println!("Upload again via base64 took {duration:?}");
1002
1003        let start = Instant::now();
1004        mdb_client
1005            .submit_as_cbor(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
1006            .await
1007            .context("failed to upload test file")
1008            .unwrap();
1009        let duration = start.elapsed();
1010        println!("Upload again via cbor took {duration:?}");
1011
1012        let report = mdb_client
1013            .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
1014            .await
1015            .expect("failed to get report for file just submitted");
1016        assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");
1017
1018        let similar = mdb_client
1019            .similar(SAMPLE_BYTES)
1020            .await
1021            .expect("failed to query for files similar to what was just submitted");
1022        assert_eq!(similar.results.len(), 1);
1023
1024        let search = mdb_client
1025            .partial_search(
1026                Some((PartialHashSearchType::Any, "AAAA".into())),
1027                None,
1028                PartialHashSearchType::Any,
1029                10,
1030            )
1031            .await
1032            .unwrap();
1033        assert!(search.hashes.is_empty());
1034
1035        mdb_client.reset_key().await.expect("failed to reset key");
1036        server.abort();
1037    }
1038
1039    #[allow(clippy::too_many_lines)]
1040    #[test]
1041    #[ignore = "don't run this in CI"]
1042    fn client_integration_blocking() {
1043        TRACING.call_once(init_tracing);
1044        let token = Arc::new(RwLock::new(String::new()));
1045        let token_clone = token.clone();
1046
1047        // There has to be a better way to do this. We have to create a new thread to run Tokio,
1048        // otherwise the closing of the functions used by the blocking client will cause
1049        // Tokio to panic with this error:
1050        // Cannot drop a runtime in a context where blocking is not allowed. This happens when a runtime is dropped from within an asynchronous context.
1051        // So we wrap that all in another thread.
1052        let thread = std::thread::spawn(move || {
1053            let rt = tokio::runtime::Builder::new_multi_thread()
1054                .enable_all()
1055                .build()
1056                .unwrap();
1057
1058            let mut db_file = env::temp_dir();
1059            db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
1060            if std::path::Path::new(&db_file).exists() {
1061                fs::remove_file(&db_file)
1062                    .context(format!("failed to delete old SQLite file {db_file:?}"))
1063                    .unwrap();
1064            }
1065
1066            let (db_type, db_config) = rt.block_on(async {
1067                let db_type =
1068                    DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
1069                        .await
1070                        .context(format!("failed to create SQLite instance for {db_file:?}"))
1071                        .unwrap();
1072
1073                let db_config = db_type.get_config().await.unwrap();
1074                (db_type, db_config)
1075            });
1076
1077            let state = State {
1078                port: 9090,
1079                directory: Some(
1080                    tempfile::TempDir::with_prefix("mdb-temp-samples")
1081                        .unwrap()
1082                        .path()
1083                        .into(),
1084                ),
1085                max_upload: 10 * 1024 * 1024,
1086                ip: "127.0.0.1".parse().unwrap(),
1087                db_type: Arc::new(db_type),
1088                db_config,
1089                keys: HashMap::default(),
1090                started: SystemTime::now(),
1091                #[cfg(feature = "vt")]
1092                vt_client: None,
1093                cert: None,
1094                key: None,
1095                mdns: false,
1096            };
1097
1098            rt.block_on(async {
1099                state
1100                    .db_type
1101                    .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1102                    .await
1103                    .context("Failed to set admin password")
1104                    .unwrap();
1105
1106                let source_id = state
1107                    .db_type
1108                    .create_source("temp-source", None, None, Local::now(), true, Some(false))
1109                    .await
1110                    .unwrap();
1111
1112                state
1113                    .db_type
1114                    .add_group_to_source(0, source_id)
1115                    .await
1116                    .unwrap();
1117
1118                let token_string = state
1119                    .db_type
1120                    .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1121                    .await
1122                    .unwrap();
1123
1124                if let Ok(mut token_lock) = token_clone.write() {
1125                    *token_lock = token_string;
1126                }
1127
1128                state
1129                    .serve()
1130                    .await
1131                    .expect("MalwareDB failed to .serve() in tokio::spawn()");
1132            });
1133        });
1134        std::thread::sleep(std::time::Duration::from_secs(1));
1135
1136        let mdb_client = malwaredb_client::blocking::MdbClient::new(
1137            String::from("http://127.0.0.1:9090"),
1138            token.read().unwrap().clone(),
1139            None,
1140        )
1141        .unwrap();
1142
1143        let types = match mdb_client.supported_types() {
1144            Ok(types) => types,
1145            Err(e) => panic!("{e}"),
1146        };
1147        assert!(!types.types.is_empty());
1148
1149        assert!(mdb_client
1150            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), 1)
1151            .context("failed to upload test file")
1152            .unwrap());
1153
1154        let report = mdb_client
1155            .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
1156            .expect("failed to get report for file just submitted");
1157        assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");
1158
1159        let similar = mdb_client
1160            .similar(SAMPLE_BYTES)
1161            .expect("failed to query for files similar to what was just submitted");
1162        assert_eq!(similar.results.len(), 1);
1163
1164        let search = mdb_client
1165            .partial_search(
1166                Some((PartialHashSearchType::Any, "AAAA".into())),
1167                None,
1168                PartialHashSearchType::Any,
1169                10,
1170            )
1171            .unwrap();
1172        assert!(search.hashes.is_empty());
1173
1174        mdb_client.reset_key().expect("failed to reset key");
1175        drop(thread); // force shutdown of the thread, as `thread.join()` just waits forever.
1176    }
1177}