malwaredb_server/http/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use super::State;
4use malwaredb_api::digest::HashType;
5use malwaredb_api::{
6    GetAPIKeyRequest, GetAPIKeyResponse, GetUserInfoResponse, Labels, NewSampleB64, NewSampleBytes,
7    Report, SearchRequest, SearchResponse, ServerError, ServerInfo, ServerResponse,
8    SimilarSamplesRequest, SimilarSamplesResponse, Sources, SupportedFileTypes, MDB_API_HEADER,
9};
10
11use std::fmt::{Display, Formatter};
12use std::io::Cursor;
13use std::iter::once;
14use std::sync::Arc;
15
16use axum::body::Bytes;
17use axum::extract::{DefaultBodyLimit, Path, Request};
18use axum::http::{header, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use axum::routing::{get, post};
22use axum::{middleware, Extension, Json, Router};
23use axum_cbor::Cbor;
24use base64::{engine::general_purpose, Engine as _};
25use constcat::concat;
26use http::{HeaderMap, HeaderName, HeaderValue};
27use sha2::{Digest, Sha256};
28use tower_http::compression::CompressionLayer;
29use tower_http::decompression::DecompressionLayer;
30use tower_http::limit::RequestBodyLimitLayer;
31use tower_http::sensitive_headers::SetSensitiveHeadersLayer;
32
33mod receive;
34
35const FAVICON_URL: &str = "/favicon.ico";
36
37/// Build the web service from the initial server state object
38pub fn app(state: Arc<State>) -> Router {
39    // This is an estimate, so err on the side of usability.
40    const UPLOAD_OVERHEAD: usize = std::mem::size_of::<Json<NewSampleB64>>() * 2;
41
42    let compression_layer = CompressionLayer::new()
43        .br(true)
44        .deflate(true)
45        .gzip(true)
46        .zstd(true);
47
48    let decompression_layer = DecompressionLayer::new()
49        .br(true)
50        .deflate(true)
51        .gzip(true)
52        .zstd(true);
53
54    let size_limit_layer = RequestBodyLimitLayer::new(UPLOAD_OVERHEAD + state.max_upload);
55
56    Router::new()
57        .route("/", get(health))
58        .route(FAVICON_URL, get(favicon))
59        .route(malwaredb_api::SERVER_INFO_URL, get(get_mdb_info))
60        .route(malwaredb_api::USER_LOGIN_URL, post(user_login))
61        .route(malwaredb_api::USER_LOGOUT_URL, get(user_logout))
62        .route(malwaredb_api::USER_INFO_URL, get(get_user_groups_sources))
63        .route(
64            malwaredb_api::SUPPORTED_FILE_TYPES_URL,
65            get(get_supported_types),
66        )
67        .route(malwaredb_api::LIST_LABELS_URL, get(get_labels))
68        .route(malwaredb_api::LIST_SOURCES_URL, get(get_sources))
69        .route(
70            malwaredb_api::UPLOAD_SAMPLE_JSON_URL,
71            post(upload_new_sample_json),
72        )
73        .route(
74            malwaredb_api::UPLOAD_SAMPLE_CBOR_URL,
75            post(upload_new_sample_cbor),
76        )
77        .route(malwaredb_api::SEARCH_URL, post(sample_search))
78        .route(
79            concat!(malwaredb_api::DOWNLOAD_SAMPLE_URL, "/{hash}"),
80            get(download_sample),
81        )
82        .route(
83            concat!(malwaredb_api::DOWNLOAD_SAMPLE_CART_URL, "/{hash}"),
84            get(download_sample_cart),
85        )
86        .route(
87            concat!(malwaredb_api::SAMPLE_REPORT_URL, "/{hash}"),
88            get(sample_report),
89        )
90        .route(malwaredb_api::SIMILAR_SAMPLES_URL, post(find_similar))
91        .layer(DefaultBodyLimit::max(state.max_upload))
92        .layer(compression_layer)
93        .layer(decompression_layer)
94        .layer(size_limit_layer)
95        .layer(SetSensitiveHeadersLayer::new(once(
96            HeaderName::from_static(MDB_API_HEADER),
97        )))
98        .route_layer(middleware::from_fn(response_header_middleware))
99        .layer(Extension(state))
100}
101
/// Identity of the authenticated caller. `response_header_middleware` resolves it from the
/// API key and attaches it to the request as an extension.
struct UserInfo {
    /// Database ID of the user who owns the presented API key.
    pub id: u32,
}
106
107/// * Ask that content from Malware DB not be cached
108/// * One central place to get the API key and get the user's ID
109async fn response_header_middleware(
110    Extension(state): Extension<Arc<State>>,
111    headers: HeaderMap,
112    mut req: Request,
113    next: Next,
114) -> Result<Response, HttpError> {
115    const ALWAYS_ALLOWED_ENDPOINTS: [&str; 4] = [
116        "/",
117        FAVICON_URL,
118        malwaredb_api::SERVER_INFO_URL,
119        malwaredb_api::USER_LOGIN_URL,
120    ];
121
122    if !ALWAYS_ALLOWED_ENDPOINTS.contains(&req.uri().path()) {
123        let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
124            ServerError::Unauthorized,
125            StatusCode::NOT_ACCEPTABLE,
126        ))?;
127        let key = key
128            .to_str()
129            .map_err(|_| HttpError(ServerError::Unauthorized, StatusCode::NOT_ACCEPTABLE))?;
130        let uid = state.db_type.get_uid(key).await.map_err(|e| {
131            tracing::warn!("Failed to get user ID from API key: {e}");
132            HttpError(ServerError::Unauthorized, StatusCode::UNAUTHORIZED)
133        })?;
134        req.extensions_mut().insert(Arc::new(UserInfo { id: uid }));
135    }
136
137    let mut response = next.run(req).await; // Run the next service in the chain
138
139    // `no-store` so an intermediate proxy doesn't save any data
140    // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Cache-Control
141    response
142        .headers_mut()
143        .insert(header::CACHE_CONTROL, HeaderValue::from_static("no-store"));
144
145    Ok(response)
146}
147
148async fn health() -> StatusCode {
149    StatusCode::OK
150}
151
152async fn favicon() -> Response {
153    const ICON: Bytes = Bytes::from_static(include_bytes!("../../MDB_Logo.ico"));
154
155    let mut bytes = ICON.into_response();
156    bytes.headers_mut().insert(
157        header::CONTENT_TYPE,
158        HeaderValue::from_static("image/vnd.microsoft.icon"),
159    );
160
161    bytes
162}
163
164async fn get_mdb_info(
165    Extension(state): Extension<Arc<State>>,
166) -> Result<Json<ServerResponse<ServerInfo>>, HttpError> {
167    let server_info = ServerResponse::Success(state.get_info().await?);
168    Ok(Json(server_info))
169}
170
171async fn user_login(
172    Extension(state): Extension<Arc<State>>,
173    Json(payload): Json<GetAPIKeyRequest>,
174) -> Result<Json<ServerResponse<GetAPIKeyResponse>>, HttpError> {
175    let api_key = state
176        .db_type
177        .authenticate(&payload.user, &payload.password)
178        .await?;
179
180    Ok(Json(ServerResponse::Success(GetAPIKeyResponse {
181        key: api_key,
182        message: None,
183    })))
184}
185
186async fn user_logout(
187    Extension(state): Extension<Arc<State>>,
188    Extension(user): Extension<Arc<UserInfo>>,
189) -> Result<StatusCode, HttpError> {
190    state.db_type.reset_own_api_key(user.id).await?;
191    Ok(StatusCode::OK)
192}
193
194async fn get_user_groups_sources(
195    Extension(state): Extension<Arc<State>>,
196    Extension(user): Extension<Arc<UserInfo>>,
197) -> Result<Json<ServerResponse<GetUserInfoResponse>>, HttpError> {
198    let groups_sources = ServerResponse::Success(state.db_type.get_user_info(user.id).await?);
199    Ok(Json(groups_sources))
200}
201
202async fn get_supported_types(
203    Extension(state): Extension<Arc<State>>,
204) -> Result<Json<ServerResponse<SupportedFileTypes>>, HttpError> {
205    let data_types = state.db_type.get_known_data_types().await?;
206    let file_types = ServerResponse::Success(SupportedFileTypes {
207        types: data_types.into_iter().map(Into::into).collect(),
208        message: None,
209    });
210    Ok(Json(file_types))
211}
212
213async fn get_labels(
214    Extension(state): Extension<Arc<State>>,
215) -> Result<Json<ServerResponse<Labels>>, HttpError> {
216    let labels = ServerResponse::Success(state.db_type.get_labels().await?);
217    Ok(Json(labels))
218}
219
220async fn get_sources(
221    Extension(state): Extension<Arc<State>>,
222    Extension(user): Extension<Arc<UserInfo>>,
223) -> Result<Json<ServerResponse<Sources>>, HttpError> {
224    let sources = ServerResponse::Success(state.db_type.get_user_sources(user.id).await?);
225    Ok(Json(sources))
226}
227
228async fn upload_new_sample_json(
229    Extension(state): Extension<Arc<State>>,
230    Extension(user): Extension<Arc<UserInfo>>,
231    Json(payload): Json<NewSampleB64>,
232) -> Result<StatusCode, HttpError> {
233    let allowed = state
234        .db_type
235        .allowed_user_source(user.id, payload.source_id)
236        .await?;
237
238    if !allowed {
239        return Err(HttpError(
240            ServerError::Unauthorized,
241            StatusCode::UNAUTHORIZED,
242        ));
243    }
244
245    let received_hash = hex::decode(&payload.sha256)?;
246    let bytes = general_purpose::STANDARD.decode(&payload.file_contents_b64)?;
247
248    let mut hasher = Sha256::new();
249    hasher.update(&bytes);
250    let result = hasher.finalize();
251
252    if result[..] != received_hash[..] {
253        return Err(HttpError(
254            ServerError::Unauthorized,
255            StatusCode::NOT_ACCEPTABLE,
256        ));
257    }
258
259    receive::incoming_sample(
260        state.clone(),
261        bytes,
262        user.id,
263        payload.source_id,
264        payload.file_name,
265    )
266    .await?;
267
268    Ok(StatusCode::OK)
269}
270
271async fn upload_new_sample_cbor(
272    Extension(state): Extension<Arc<State>>,
273    Extension(user): Extension<Arc<UserInfo>>,
274    Cbor(payload): Cbor<NewSampleBytes>,
275) -> Result<StatusCode, HttpError> {
276    let allowed = state
277        .db_type
278        .allowed_user_source(user.id, payload.source_id)
279        .await?;
280
281    if !allowed {
282        return Err(HttpError(
283            ServerError::Unauthorized,
284            StatusCode::UNAUTHORIZED,
285        ));
286    }
287
288    let received_hash = hex::decode(&payload.sha256)?;
289    let bytes = payload.file_contents;
290    let mut hasher = Sha256::new();
291    hasher.update(&bytes);
292    let result = hasher.finalize();
293
294    if result[..] != received_hash[..] {
295        return Err(HttpError(
296            ServerError::Unauthorized,
297            StatusCode::NOT_ACCEPTABLE,
298        ));
299    }
300
301    receive::incoming_sample(
302        state.clone(),
303        bytes,
304        user.id,
305        payload.source_id,
306        payload.file_name,
307    )
308    .await?;
309
310    Ok(StatusCode::OK)
311}
312
313async fn sample_search(
314    Extension(state): Extension<Arc<State>>,
315    Extension(user): Extension<Arc<UserInfo>>,
316    Json(payload): Json<SearchRequest>,
317) -> Result<Json<ServerResponse<SearchResponse>>, HttpError> {
318    let hashes = ServerResponse::Success(state.db_type.partial_search(user.id, payload).await?);
319    Ok(Json(hashes))
320}
321
322async fn download_sample(
323    Path(hash): Path<String>,
324    Extension(user): Extension<Arc<UserInfo>>,
325    Extension(state): Extension<Arc<State>>,
326) -> Result<Response, HttpError> {
327    if state.directory.is_none() {
328        return Err(HttpError(
329            ServerError::NoSamples,
330            StatusCode::NOT_ACCEPTABLE,
331        ));
332    }
333
334    let hash = HashType::try_from(hash.as_str())?;
335    let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
336
337    // Ensure we're sending the content-digest as SHA-256, since we could have received MD5 or SHA1.
338    let hash = HashType::try_from(sha256.as_str())?;
339
340    let contents = state.retrieve_bytes(&sha256).await?;
341
342    let mut bytes = Bytes::from(contents).into_response();
343    let name_header_value = format!("attachment; filename=\"{sha256}\"");
344    bytes.headers_mut().insert(
345        header::CONTENT_DISPOSITION,
346        HeaderValue::from_str(&name_header_value)
347            .unwrap_or(HeaderValue::from_static("Unknown.bin")),
348    );
349
350    bytes.headers_mut().insert(
351        "content-digest",
352        HeaderValue::from_str(&hash.content_digest_header())?,
353    );
354
355    Ok(bytes)
356}
357
358async fn download_sample_cart(
359    Path(hash): Path<String>,
360    Extension(user): Extension<Arc<UserInfo>>,
361    Extension(state): Extension<Arc<State>>,
362) -> Result<Response, HttpError> {
363    if state.directory.is_none() {
364        return Err(HttpError(
365            ServerError::NoSamples,
366            StatusCode::NOT_ACCEPTABLE,
367        ));
368    }
369
370    let hash = HashType::try_from(hash.as_str())?;
371    let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
372    let report = state.db_type.get_sample_report(user.id, &hash).await?;
373
374    let contents = state.retrieve_bytes(&sha256).await?;
375    let contents_cursor = Cursor::new(contents);
376    let mut output_cursor = Cursor::new(vec![]);
377
378    let mut output_metadata = cart_container::JsonMap::new();
379    output_metadata.insert("sha384".into(), report.sha384.into());
380    output_metadata.insert("sha512".into(), report.sha512.into());
381    output_metadata.insert("entropy".into(), report.entropy.into());
382    if let Some(filecmd) = report.filecommand {
383        output_metadata.insert("file".into(), filecmd.into());
384    }
385
386    cart_container::pack_stream(
387        contents_cursor,
388        &mut output_cursor,
389        Some(output_metadata),
390        None,
391        cart_container::digesters::default_digesters(), // MD-5, SHA-1, SHA-256 applied here
392        None,
393    )?;
394
395    let mut hasher = Sha256::new();
396    hasher.update(output_cursor.get_ref());
397    let hash_b64 = general_purpose::STANDARD.encode(hasher.finalize());
398
399    let mut bytes = Bytes::from(output_cursor.into_inner()).into_response();
400    let name_header_value = format!("attachment; filename=\"{sha256}.cart\"");
401    bytes.headers_mut().insert(
402        header::CONTENT_DISPOSITION,
403        HeaderValue::from_str(&name_header_value)
404            .unwrap_or(HeaderValue::from_static("Unknown.cart")),
405    );
406    bytes.headers_mut().insert(
407        "content-digest",
408        HeaderValue::from_str(&format!("sha-256=:{hash_b64}:"))?,
409    );
410
411    Ok(bytes)
412}
413
414async fn sample_report(
415    Path(hash): Path<String>,
416    Extension(user): Extension<Arc<UserInfo>>,
417    Extension(state): Extension<Arc<State>>,
418) -> Result<Json<ServerResponse<Report>>, HttpError> {
419    let hash = HashType::try_from(hash.as_str())?;
420    let report =
421        ServerResponse::<Report>::Success(state.db_type.get_sample_report(user.id, &hash).await?);
422    Ok(Json(report))
423}
424
425async fn find_similar(
426    Extension(state): Extension<Arc<State>>,
427    Extension(user): Extension<Arc<UserInfo>>,
428    Json(payload): Json<SimilarSamplesRequest>,
429) -> Result<Json<ServerResponse<SimilarSamplesResponse>>, HttpError> {
430    let results = state
431        .db_type
432        .find_similar_samples(user.id, &payload.hashes)
433        .await?;
434
435    Ok(Json(ServerResponse::Success(SimilarSamplesResponse {
436        results,
437        message: None,
438    })))
439}
440
441// How to use `anyhow::Error` with `axum`:
442// https://github.com/tokio-rs/axum/blob/c97967252de9741b602f400dc2b25c8a33216039/examples/anyhow-error-response/src/main.rs
443
444/// Custom error type to support Axum error handling
445pub struct HttpError(pub ServerError, pub StatusCode);
446
447/// Convert Anyhow error into an Axum response object
448impl IntoResponse for HttpError {
449    fn into_response(self) -> Response {
450        let response: ServerResponse<String> = ServerResponse::Error(self.0);
451        match serde_json::to_string(&response) {
452            Ok(json) => (self.1, json).into_response(),
453            Err(_) => self.1.into_response(),
454        }
455    }
456}
457
458impl Display for HttpError {
459    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
460        write!(f, "{}", self.0)
461    }
462}
463
464/// Enable the use of the ? operator in Axum handler functions
465impl<E> From<E> for HttpError
466where
467    E: Into<anyhow::Error>,
468{
469    fn from(_err: E) -> Self {
470        Self(ServerError::ServerError, StatusCode::INTERNAL_SERVER_ERROR)
471    }
472}
473
474#[cfg(test)]
475mod tests {
476    use super::*;
477    use crate::crypto::{EncryptionOption, FileEncryption};
478    use crate::db::DatabaseType;
479    use crate::StateBuilder;
480    use malwaredb_api::PartialHashSearchType;
481
482    use std::collections::HashMap;
483    use std::path::PathBuf;
484    use std::sync::{Once, RwLock};
485    use std::time::{Instant, SystemTime};
486    use std::{env, fs};
487
488    use anyhow::Context;
489    use axum::body::Body;
490    use axum::http::Request;
491    use chrono::Local;
492    use http::header::CONTENT_TYPE;
493    use http_body_util::BodyExt;
494    use rstest::rstest;
495    use tower::ServiceExt;
496    // for `app.oneshot()`
497    use uuid::Uuid;
498
499    const ADMIN_UNAME: &str = "admin";
500    const ADMIN_PASSWORD: &str = "password12345";
501    const SAMPLE_BYTES: &[u8] = include_bytes!("../../../types/testdata/elf/elf_haiku_x86");
502
503    static TRACING: Once = Once::new();
504
505    fn init_tracing() {
506        tracing_subscriber::fmt()
507            .with_max_level(tracing::Level::TRACE)
508            .init();
509    }
510
511    async fn state(compress: bool, encrypt: bool) -> (Arc<State>, u32) {
512        TRACING.call_once(init_tracing);
513
514        // Each test needs a separate file, or else they'll clobber each other.
515        let mut db_file = env::temp_dir();
516        db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
517        if std::path::Path::new(&db_file).exists() {
518            fs::remove_file(&db_file)
519                .context(format!("failed to delete old SQLite file {db_file:?}"))
520                .unwrap();
521        }
522
523        let db_type =
524            DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
525                .await
526                .context(format!("failed to create SQLite instance for {db_file:?}"))
527                .unwrap();
528        if compress {
529            db_type.enable_compression().await.unwrap();
530        }
531        let keys = if encrypt {
532            let key = FileEncryption::from(EncryptionOption::Xor);
533            let key_id = db_type.add_file_encryption_key(&key).await.unwrap();
534            let mut keys = HashMap::new();
535            keys.insert(key_id, key);
536            keys
537        } else {
538            HashMap::new()
539        };
540        let db_config = db_type.get_config().await.unwrap();
541
542        let state = State {
543            port: 8080,
544            directory: Some(
545                tempfile::TempDir::with_prefix("mdb-temp-samples")
546                    .unwrap()
547                    .path()
548                    .into(),
549            ),
550            max_upload: 10 * 1024 * 1024,
551            ip: "127.0.0.1".parse().unwrap(),
552            db_type: Arc::new(db_type),
553            db_config,
554            keys,
555            started: SystemTime::now(),
556            #[cfg(feature = "vt")]
557            vt_client: None,
558            tls_config: None,
559            mdns: None,
560        };
561
562        state
563            .db_type
564            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
565            .await
566            .context("Failed to set admin password")
567            .unwrap();
568
569        let source_id = state
570            .db_type
571            .create_source("temp-source", None, None, Local::now(), true, Some(false))
572            .await
573            .unwrap();
574
575        state
576            .db_type
577            .add_group_to_source(0, source_id)
578            .await
579            .unwrap();
580
581        (Arc::new(state), source_id)
582    }
583
584    /// Get a new state and generate an API token
585    async fn state_and_token(pem_file: bool) -> (State, u32, String) {
586        TRACING.call_once(init_tracing);
587
588        let mut db_file = env::temp_dir();
589        db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
590        if std::path::Path::new(&db_file).exists() {
591            fs::remove_file(&db_file)
592                .context(format!("failed to delete old SQLite file {db_file:?}"))
593                .unwrap();
594        }
595
596        let mut builder = StateBuilder::new(&format!("file:{}", db_file.display()), None)
597            .await
598            .unwrap();
599        builder = builder.directory(PathBuf::from(
600            tempfile::TempDir::with_prefix("mdb-temp-samples")
601                .unwrap()
602                .path(),
603        ));
604        if pem_file {
605            builder = builder.port(8443);
606            builder = builder
607                .tls(
608                    "../../testdata/server_ca_cert.pem".into(),
609                    "../../testdata/server_key.pem".into(),
610                )
611                .await
612                .unwrap();
613        } else {
614            builder = builder.port(8444);
615            builder = builder
616                .tls(
617                    "../../testdata/server_cert.der".into(),
618                    "../../testdata/server_key.der".into(),
619                )
620                .await
621                .unwrap();
622        }
623
624        let state = builder.into_state().await.unwrap();
625
626        state
627            .db_type
628            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
629            .await
630            .context("Failed to set admin password")
631            .unwrap();
632
633        let source_id = state
634            .db_type
635            .create_source("temp-source", None, None, Local::now(), true, Some(false))
636            .await
637            .unwrap();
638
639        state
640            .db_type
641            .add_group_to_source(0, source_id)
642            .await
643            .unwrap();
644
645        let token = state
646            .db_type
647            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
648            .await
649            .unwrap();
650
651        (state, source_id, token)
652    }
653
654    async fn get_key(state: Arc<State>) -> String {
655        TRACING.call_once(init_tracing);
656
657        let key_request = serde_json::to_string(&GetAPIKeyRequest {
658            user: ADMIN_UNAME.into(),
659            password: ADMIN_PASSWORD.into(),
660        })
661        .context("Failed to convert API key request to JSON")
662        .unwrap();
663
664        let request = Request::builder()
665            .method("POST")
666            .uri(malwaredb_api::USER_LOGIN_URL)
667            .header(CONTENT_TYPE, "application/json")
668            .body(Body::from(key_request))
669            .unwrap();
670
671        let response = app(state)
672            .oneshot(request)
673            .await
674            .context("failed to send/receive login request")
675            .unwrap();
676
677        assert_eq!(response.status(), StatusCode::OK);
678        let bytes = response
679            .into_body()
680            .collect()
681            .await
682            .expect("failed to collect response body to bytes")
683            .to_bytes();
684        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
685            .context("failed to convert response to string")
686            .unwrap();
687
688        let response: ServerResponse<GetAPIKeyResponse> = serde_json::from_str(&json_response)
689            .context("failed to convert json response to object")
690            .unwrap();
691
692        if let ServerResponse::Success(response) = response {
693            let key = response.key.clone();
694            assert_eq!(key.len(), 64);
695
696            key
697        } else {
698            panic!("failed to get API key response")
699        }
700    }
701
702    #[tokio::test]
703    async fn about_self() {
704        let (state, _) = state(false, false).await;
705        let api_key = get_key(state.clone()).await;
706
707        let request = Request::builder()
708            .method("GET")
709            .uri(malwaredb_api::USER_INFO_URL)
710            .header(MDB_API_HEADER, &api_key)
711            .body(Body::empty())
712            .unwrap();
713
714        let response = app(state.clone())
715            .oneshot(request)
716            .await
717            .context("failed to send/receive login request")
718            .unwrap();
719
720        assert_eq!(response.status(), StatusCode::OK);
721        let bytes = response
722            .into_body()
723            .collect()
724            .await
725            .expect("failed to collect response body to bytes")
726            .to_bytes();
727        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
728            .context("failed to convert response to string")
729            .unwrap();
730
731        let response: ServerResponse<GetUserInfoResponse> = serde_json::from_str(&json_response)
732            .context("failed to convert json response to object")
733            .unwrap();
734
735        let response = response.unwrap();
736        assert_eq!(response.id, 0);
737        assert!(response.is_admin);
738        assert!(!response.is_readonly);
739        assert_eq!(response.username, "admin");
740
741        // Check labels, should be empty
742        let request = Request::builder()
743            .method("GET")
744            .uri(malwaredb_api::LIST_LABELS_URL)
745            .header(MDB_API_HEADER, &api_key)
746            .body(Body::empty())
747            .unwrap();
748
749        let response = app(state)
750            .oneshot(request)
751            .await
752            .context("failed to send/receive login request")
753            .unwrap();
754
755        assert_eq!(response.status(), StatusCode::OK);
756        let bytes = response
757            .into_body()
758            .collect()
759            .await
760            .expect("failed to collect response body to bytes")
761            .to_bytes();
762        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
763            .context("failed to convert response to string")
764            .unwrap();
765
766        let response: ServerResponse<Labels> = serde_json::from_str(&json_response)
767            .context("failed to convert json response to object")
768            .unwrap();
769
770        let response = response.unwrap();
771        assert!(response.is_empty());
772    }
773
774    #[rstest]
775    #[case::elf_encrypt_cart(include_bytes!("../../../types/testdata/elf/elf_haiku_x86.cart"), false, true, true, false)]
776    #[case::pe32(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), false, false, false, false)]
777    #[case::pdf_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), false, true, false, false)]
778    #[case::rtf(include_bytes!("../../../types/testdata/rtf/hello.rtf"), false, false, false, false)]
779    #[case::elf_compress_encrypt(include_bytes!("../../../types/testdata/elf/elf_haiku_x86"), true, true, false, false)]
780    #[case::pe32_compress(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), true, false, false, false)]
781    #[case::pdf_compress_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), true, true, false, false)]
782    #[case::rtf_compress(include_bytes!("../../../types/testdata/rtf/hello.rtf"), true, false, false, false)]
783    #[case::icon_unknown_type_proxy(include_bytes!("../../../../MDB_Logo.ico"), false, false, false, true)]
784    #[tokio::test]
785    async fn submit_sample(
786        #[case] contents: &[u8],
787        #[case] compress: bool,
788        #[case] encrypt: bool,
789        #[case] cart: bool,
790        #[case] should_fail: bool,
791    ) {
792        let (state, source_id) = state(compress, encrypt).await;
793        let api_key = get_key(state.clone()).await;
794
795        let file_contents_b64 = general_purpose::STANDARD.encode(contents);
796        let mut hasher = Sha256::new();
797        hasher.update(contents);
798        let sha256 = hex::encode(hasher.finalize());
799
800        let upload = serde_json::to_string(&NewSampleB64 {
801            file_name: "some_sample".into(),
802            source_id,
803            file_contents_b64,
804            sha256: sha256.clone(),
805        })
806        .context("failed to create upload structure")
807        .unwrap();
808
809        let request = Request::builder()
810            .method("POST")
811            .uri(malwaredb_api::UPLOAD_SAMPLE_JSON_URL)
812            .header(CONTENT_TYPE, "application/json")
813            .header(MDB_API_HEADER, &api_key)
814            .body(Body::from(upload))
815            .unwrap();
816
817        let response = app(state.clone())
818            .oneshot(request)
819            .await
820            .context("failed to send/receive upload request/response")
821            .unwrap();
822
823        if should_fail {
824            assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
825            return;
826        }
827
828        assert_eq!(response.status(), StatusCode::OK);
829
830        let sha256 = if cart {
831            // We have to get the original SHA-256 hash, not the hash of the CaRT container.
832            // This way we can check that the file exists on disk, and request the file again later in this test.
833            let mut input_buffer = Cursor::new(contents);
834            let mut output_buffer = Cursor::new(vec![]);
835            let (_, footer) =
836                cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)
837                    .expect("failed to decode CaRT file");
838            let footer = footer.expect("CaRT should have had a footer");
839            let sha256 = footer
840                .get("sha256")
841                .expect("CaRT footer should have had an entry for SHA-256")
842                .to_string();
843            sha256.replace('"', "") // Remove the quotes that get added by JSON
844        } else {
845            sha256
846        };
847
848        if let Some(dir) = &state.directory {
849            let mut sample_path = dir.clone();
850            sample_path.push(format!(
851                "{}/{}/{}/{}",
852                &sha256[0..2],
853                &sha256[2..4],
854                &sha256[4..6],
855                sha256
856            ));
857            eprintln!("Submitted sample should exist at {sample_path:?}.");
858            assert!(sample_path.exists());
859
860            if compress {
861                let sample_size_on_disk = sample_path.metadata().unwrap().len();
862                eprintln!(
863                    "Original size: {}, compressed: {}",
864                    contents.len(),
865                    sample_size_on_disk
866                );
867                assert!(sample_size_on_disk < contents.len() as u64);
868            }
869        } else {
870            panic!("Directory was set for the state, but is now `None`");
871        }
872
873        let request = Request::builder()
874            .method("GET")
875            .uri(format!("{}/{sha256}", malwaredb_api::SAMPLE_REPORT_URL))
876            .header(MDB_API_HEADER, &api_key)
877            .body(Body::empty())
878            .unwrap();
879
880        let response = app(state.clone())
881            .oneshot(request)
882            .await
883            .context("failed to send/receive upload request/response")
884            .unwrap();
885
886        println!("Response headers: {:?}", response.headers());
887
888        let bytes = response
889            .into_body()
890            .collect()
891            .await
892            .expect("failed to collect response body to bytes")
893            .to_bytes();
894        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
895            .context("failed to convert response to string")
896            .unwrap();
897
898        let report: ServerResponse<Report> = serde_json::from_str(&json_response)
899            .context("failed to convert json response to object")
900            .unwrap();
901        let report = report.unwrap();
902
903        assert_eq!(report.sha256, sha256);
904        println!("Report: {report}");
905
906        let request = Request::builder()
907            .method("GET")
908            .uri(format!(
909                "{}/{sha256}",
910                malwaredb_api::DOWNLOAD_SAMPLE_CART_URL
911            ))
912            .header(MDB_API_HEADER, api_key)
913            .body(Body::empty())
914            .unwrap();
915
916        let response = app(state.clone())
917            .oneshot(request)
918            .await
919            .context("failed to send/receive upload request/response for CaRT")
920            .unwrap();
921
922        println!("Response headers: {:?}", response.headers());
923
924        let bytes = response
925            .into_body()
926            .collect()
927            .await
928            .expect("failed to collect response body to bytes")
929            .to_bytes();
930
931        let bytes = bytes.to_vec();
932        let bytes_input = Cursor::new(bytes);
933        let output = Cursor::new(vec![]);
934        match cart_container::unpack_stream(bytes_input, output, None) {
935            Ok((header, _)) => {
936                let header = header.unwrap();
937                assert_eq!(
938                    header.get("sha384"),
939                    Some(&serde_json::to_value(report.sha384).unwrap())
940                );
941            }
942            Err(e) => panic!("{e}"),
943        }
944    }
945
946    /// Integration test between Malware DB Server & Client. This was under "/tests" in the project
947    /// but `cargo hack` would disable needed features, and there wasn't a way to use `cfg` to
948    /// put a condition on the test. Additionally, such an integration test required additional
949    /// fields of the Server's state to be public or to implement `Default`, which isn't ideal.
950    #[rstest]
951    #[case::ssl_pem(true)]
952    #[case::ssl_der(false)]
953    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
954    async fn client_integration(#[case] pem: bool) {
955        TRACING.call_once(init_tracing);
956
957        // This other get state function was needed, or else there were complaints of
958        // `state` being moved and unable to be used in `tokio::spawn()`
959        let (state, source_id, token) = state_and_token(pem).await;
960        let state_port = state.port; // copy before move below
961
962        let server = tokio::spawn(async move {
963            state
964                .serve()
965                .await
966                .expect("MalwareDB failed to .serve() in tokio::spawn()");
967        });
968        assert!(!server.is_finished());
969
970        // In case start-up time is needed
971        tokio::time::sleep(std::time::Duration::new(1, 0)).await;
972
973        println!("SemVer of MalwareDB: {:?}", *crate::MDB_VERSION_SEMVER);
974        assert_eq!(
975            *malwaredb_client::MDB_VERSION_SEMVER,
976            *crate::MDB_VERSION_SEMVER,
977            "SemVer parsing of MDB version failed"
978        );
979
980        let mdb_client = malwaredb_client::MdbClient::new(
981            format!("https://127.0.0.1:{state_port}"),
982            token.clone(),
983            Some("../../testdata/ca_cert.pem".into()),
984        )
985        .unwrap();
986
987        assert!(!mdb_client.supported_types().await.unwrap().types.is_empty());
988        assert!(mdb_client.server_info().await.is_ok());
989
990        let start = Instant::now();
991        assert!(mdb_client
992            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
993            .await
994            .context("failed to upload test file")
995            .unwrap());
996        let duration = start.elapsed();
997        println!("Initial upload and database record creation via base64 took {duration:?}");
998
999        let start = Instant::now();
1000        mdb_client
1001            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
1002            .await
1003            .context("failed to upload test file")
1004            .unwrap();
1005        let duration = start.elapsed();
1006        println!("Upload again via base64 took {duration:?}");
1007
1008        let start = Instant::now();
1009        mdb_client
1010            .submit_as_cbor(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
1011            .await
1012            .context("failed to upload test file")
1013            .unwrap();
1014        let duration = start.elapsed();
1015        println!("Upload again via cbor took {duration:?}");
1016
1017        let report = mdb_client
1018            .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
1019            .await
1020            .expect("failed to get report for file just submitted");
1021        assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");
1022
1023        let similar = mdb_client
1024            .similar(SAMPLE_BYTES)
1025            .await
1026            .expect("failed to query for files similar to what was just submitted");
1027        assert_eq!(similar.results.len(), 1);
1028
1029        let search = mdb_client
1030            .partial_search(
1031                Some((PartialHashSearchType::Any, "AAAA".into())),
1032                None,
1033                PartialHashSearchType::Any,
1034                10,
1035            )
1036            .await
1037            .unwrap();
1038        assert!(search.hashes.is_empty());
1039
1040        mdb_client.reset_key().await.expect("failed to reset key");
1041        server.abort();
1042    }
1043
1044    #[allow(clippy::too_many_lines)]
1045    #[test]
1046    #[ignore = "don't run this in CI"]
1047    fn client_integration_blocking() {
1048        TRACING.call_once(init_tracing);
1049        let token = Arc::new(RwLock::new(String::new()));
1050        let token_clone = token.clone();
1051
1052        // There has to be a better way to do this. We have to create a new thread to run Tokio,
1053        // otherwise the closing of the functions used by the blocking client will cause
1054        // Tokio to panic with this error:
1055        // Cannot drop a runtime in a context where blocking is not allowed. This happens when a runtime is dropped from within an asynchronous context.
1056        // So we wrap that all in another thread.
1057        let thread = std::thread::spawn(move || {
1058            let rt = tokio::runtime::Builder::new_multi_thread()
1059                .enable_all()
1060                .build()
1061                .unwrap();
1062
1063            let mut db_file = env::temp_dir();
1064            db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
1065            if std::path::Path::new(&db_file).exists() {
1066                fs::remove_file(&db_file)
1067                    .context(format!("failed to delete old SQLite file {db_file:?}"))
1068                    .unwrap();
1069            }
1070
1071            let (db_type, db_config) = rt.block_on(async {
1072                let db_type =
1073                    DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
1074                        .await
1075                        .context(format!("failed to create SQLite instance for {db_file:?}"))
1076                        .unwrap();
1077
1078                let db_config = db_type.get_config().await.unwrap();
1079                (db_type, db_config)
1080            });
1081
1082            let state = State {
1083                port: 9090,
1084                directory: Some(
1085                    tempfile::TempDir::with_prefix("mdb-temp-samples")
1086                        .unwrap()
1087                        .path()
1088                        .into(),
1089                ),
1090                max_upload: 10 * 1024 * 1024,
1091                ip: "127.0.0.1".parse().unwrap(),
1092                db_type: Arc::new(db_type),
1093                db_config,
1094                keys: HashMap::default(),
1095                started: SystemTime::now(),
1096                #[cfg(feature = "vt")]
1097                vt_client: None,
1098                tls_config: None,
1099                mdns: None,
1100            };
1101
1102            rt.block_on(async {
1103                state
1104                    .db_type
1105                    .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1106                    .await
1107                    .context("Failed to set admin password")
1108                    .unwrap();
1109
1110                let source_id = state
1111                    .db_type
1112                    .create_source("temp-source", None, None, Local::now(), true, Some(false))
1113                    .await
1114                    .unwrap();
1115
1116                state
1117                    .db_type
1118                    .add_group_to_source(0, source_id)
1119                    .await
1120                    .unwrap();
1121
1122                let token_string = state
1123                    .db_type
1124                    .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1125                    .await
1126                    .unwrap();
1127
1128                if let Ok(mut token_lock) = token_clone.write() {
1129                    *token_lock = token_string;
1130                }
1131
1132                state
1133                    .serve()
1134                    .await
1135                    .expect("MalwareDB failed to .serve() in tokio::spawn()");
1136            });
1137        });
1138        std::thread::sleep(std::time::Duration::from_secs(1));
1139
1140        let mdb_client = malwaredb_client::blocking::MdbClient::new(
1141            String::from("http://127.0.0.1:9090"),
1142            token.read().unwrap().clone(),
1143            None,
1144        )
1145        .unwrap();
1146
1147        let types = match mdb_client.supported_types() {
1148            Ok(types) => types,
1149            Err(e) => panic!("{e}"),
1150        };
1151        assert!(!types.types.is_empty());
1152
1153        assert!(mdb_client
1154            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), 1)
1155            .context("failed to upload test file")
1156            .unwrap());
1157
1158        let report = mdb_client
1159            .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
1160            .expect("failed to get report for file just submitted");
1161        assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");
1162
1163        let similar = mdb_client
1164            .similar(SAMPLE_BYTES)
1165            .expect("failed to query for files similar to what was just submitted");
1166        assert_eq!(similar.results.len(), 1);
1167
1168        let search = mdb_client
1169            .partial_search(
1170                Some((PartialHashSearchType::Any, "AAAA".into())),
1171                None,
1172                PartialHashSearchType::Any,
1173                10,
1174            )
1175            .unwrap();
1176        assert!(search.hashes.is_empty());
1177
1178        mdb_client.reset_key().expect("failed to reset key");
1179        drop(thread); // force shutdown of the thread, as `thread.join()` just waits forever.
1180    }
1181}