Skip to main content

malwaredb_server/http/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use super::State;
4use malwaredb_api::digest::HashType;
5use malwaredb_api::{
6    GetAPIKeyRequest, GetAPIKeyResponse, GetUserInfoResponse, Labels, NewSampleB64, NewSampleBytes,
7    Report, SearchRequest, SearchResponse, ServerError, ServerInfo, ServerResponse,
8    SimilarSamplesRequest, SimilarSamplesResponse, Sources, SupportedFileTypes, YaraSearchRequest,
9    YaraSearchRequestResponse, YaraSearchResponse, MDB_API_HEADER,
10};
11
12use std::fmt::{Display, Formatter};
13use std::io::Cursor;
14use std::iter::once;
15use std::sync::Arc;
16
17use axum::body::Bytes;
18use axum::extract::{DefaultBodyLimit, Path, Request};
19use axum::http::{header, StatusCode};
20use axum::middleware::Next;
21use axum::response::{IntoResponse, Response};
22use axum::routing::{get, post};
23use axum::{middleware, Extension, Json, Router};
24use axum_cbor::Cbor;
25use base64::{engine::general_purpose, Engine as _};
26use constcat::concat;
27use http::{HeaderMap, HeaderName, HeaderValue};
28use sha2::{Digest, Sha256};
29use tower_http::compression::CompressionLayer;
30use tower_http::decompression::DecompressionLayer;
31use tower_http::limit::RequestBodyLimitLayer;
32use tower_http::sensitive_headers::SetSensitiveHeadersLayer;
33use uuid::Uuid;
34
35mod receive;
36
/// Route that serves the embedded Malware DB icon; it is on the middleware's allow-list,
/// so no API key is required.
const FAVICON_URL: &str = "/favicon.ico";
38
39/// Build the web service from the initial server state object
40pub fn app(state: Arc<State>) -> Router {
41    // This is an estimate, so err on the side of usability.
42    const UPLOAD_OVERHEAD: usize = std::mem::size_of::<Json<NewSampleB64>>() * 2;
43
44    let compression_layer = CompressionLayer::new()
45        .br(true)
46        .deflate(true)
47        .gzip(true)
48        .zstd(true);
49
50    let decompression_layer = DecompressionLayer::new()
51        .br(true)
52        .deflate(true)
53        .gzip(true)
54        .zstd(true);
55
56    let size_limit_layer = RequestBodyLimitLayer::new(UPLOAD_OVERHEAD + state.max_upload);
57
58    Router::new()
59        .route("/", get(health))
60        .route(FAVICON_URL, get(favicon))
61        .route(malwaredb_api::SERVER_INFO_URL, get(get_mdb_info))
62        .route(malwaredb_api::USER_LOGIN_URL, post(user_login))
63        .route(malwaredb_api::USER_LOGOUT_URL, get(user_logout))
64        .route(malwaredb_api::USER_INFO_URL, get(get_user_groups_sources))
65        .route(
66            malwaredb_api::SUPPORTED_FILE_TYPES_URL,
67            get(get_supported_types),
68        )
69        .route(malwaredb_api::LIST_LABELS_URL, get(get_labels))
70        .route(malwaredb_api::LIST_SOURCES_URL, get(get_sources))
71        .route(
72            malwaredb_api::UPLOAD_SAMPLE_JSON_URL,
73            post(upload_new_sample_json),
74        )
75        .route(
76            malwaredb_api::UPLOAD_SAMPLE_CBOR_URL,
77            post(upload_new_sample_cbor),
78        )
79        .route(malwaredb_api::SEARCH_URL, post(sample_search))
80        .route(
81            concat!(malwaredb_api::DOWNLOAD_SAMPLE_URL, "/{hash}"),
82            get(download_sample),
83        )
84        .route(
85            concat!(malwaredb_api::DOWNLOAD_SAMPLE_CART_URL, "/{hash}"),
86            get(download_sample_cart),
87        )
88        .route(
89            concat!(malwaredb_api::SAMPLE_REPORT_URL, "/{hash}"),
90            get(sample_report),
91        )
92        .route(malwaredb_api::SIMILAR_SAMPLES_URL, post(find_similar))
93        .route(malwaredb_api::YARA_SEARCH_URL, post(yara_search))
94        .route(
95            concat!(malwaredb_api::YARA_SEARCH_URL, "/{uuid}"),
96            get(yara_search_get),
97        )
98        .layer(DefaultBodyLimit::max(state.max_upload))
99        .layer(compression_layer)
100        .layer(decompression_layer)
101        .layer(size_limit_layer)
102        .layer(SetSensitiveHeadersLayer::new(once(
103            HeaderName::from_static(MDB_API_HEADER),
104        )))
105        .route_layer(middleware::from_fn(response_header_middleware))
106        .layer(Extension(state))
107}
108
/// Identity of the caller, resolved from the API key by `response_header_middleware`
/// and attached to the request as an `Arc<UserInfo>` extension for handlers to read.
struct UserInfo {
    /// User ID returned by the database for the supplied API key.
    pub id: u32,
}
113
114/// * Ask that content from Malware DB not be cached
115/// * One central place to get the API key and get the user's ID
116async fn response_header_middleware(
117    Extension(state): Extension<Arc<State>>,
118    headers: HeaderMap,
119    mut req: Request,
120    next: Next,
121) -> Result<Response, HttpError> {
122    const ALWAYS_ALLOWED_ENDPOINTS: [&str; 4] = [
123        "/",
124        FAVICON_URL,
125        malwaredb_api::SERVER_INFO_URL,
126        malwaredb_api::USER_LOGIN_URL,
127    ];
128
129    if !ALWAYS_ALLOWED_ENDPOINTS.contains(&req.uri().path()) {
130        let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
131            ServerError::Unauthorized,
132            StatusCode::NOT_ACCEPTABLE,
133        ))?;
134        let key = key
135            .to_str()
136            .map_err(|_| HttpError(ServerError::Unauthorized, StatusCode::NOT_ACCEPTABLE))?;
137        let uid = state.db_type.get_uid(key).await.map_err(|e| {
138            tracing::warn!("Failed to get user ID from API key: {e}");
139            HttpError(ServerError::Unauthorized, StatusCode::UNAUTHORIZED)
140        })?;
141        req.extensions_mut().insert(Arc::new(UserInfo { id: uid }));
142    }
143
144    let mut response = next.run(req).await; // Run the next service in the chain
145
146    // `no-store` so an intermediate proxy doesn't save any data
147    // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Cache-Control
148    response
149        .headers_mut()
150        .insert(header::CACHE_CONTROL, HeaderValue::from_static("no-store"));
151
152    Ok(response)
153}
154
155async fn health() -> StatusCode {
156    StatusCode::OK
157}
158
159async fn favicon() -> Response {
160    const ICON: Bytes = Bytes::from_static(include_bytes!("../../MDB_Logo.ico"));
161
162    let mut bytes = ICON.into_response();
163    bytes.headers_mut().insert(
164        header::CONTENT_TYPE,
165        HeaderValue::from_static("image/vnd.microsoft.icon"),
166    );
167
168    bytes
169}
170
171async fn get_mdb_info(
172    Extension(state): Extension<Arc<State>>,
173) -> Result<Json<ServerResponse<ServerInfo>>, HttpError> {
174    let server_info = ServerResponse::Success(state.get_info().await?);
175    Ok(Json(server_info))
176}
177
178async fn user_login(
179    Extension(state): Extension<Arc<State>>,
180    Json(payload): Json<GetAPIKeyRequest>,
181) -> Result<Json<ServerResponse<GetAPIKeyResponse>>, HttpError> {
182    let api_key = state
183        .db_type
184        .authenticate(&payload.user, &payload.password)
185        .await?;
186
187    Ok(Json(ServerResponse::Success(GetAPIKeyResponse {
188        key: api_key,
189        message: None,
190    })))
191}
192
193async fn user_logout(
194    Extension(state): Extension<Arc<State>>,
195    Extension(user): Extension<Arc<UserInfo>>,
196) -> Result<StatusCode, HttpError> {
197    state.db_type.reset_own_api_key(user.id).await?;
198    Ok(StatusCode::OK)
199}
200
201async fn get_user_groups_sources(
202    Extension(state): Extension<Arc<State>>,
203    Extension(user): Extension<Arc<UserInfo>>,
204) -> Result<Json<ServerResponse<GetUserInfoResponse>>, HttpError> {
205    let groups_sources = ServerResponse::Success(state.db_type.get_user_info(user.id).await?);
206    Ok(Json(groups_sources))
207}
208
209async fn get_supported_types(
210    Extension(state): Extension<Arc<State>>,
211) -> Result<Json<ServerResponse<SupportedFileTypes>>, HttpError> {
212    let data_types = state.db_type.get_known_data_types().await?;
213    let file_types = ServerResponse::Success(SupportedFileTypes {
214        types: data_types.into_iter().map(Into::into).collect(),
215        message: None,
216    });
217    Ok(Json(file_types))
218}
219
220async fn get_labels(
221    Extension(state): Extension<Arc<State>>,
222) -> Result<Json<ServerResponse<Labels>>, HttpError> {
223    let labels = ServerResponse::Success(state.db_type.get_labels().await?);
224    Ok(Json(labels))
225}
226
227async fn get_sources(
228    Extension(state): Extension<Arc<State>>,
229    Extension(user): Extension<Arc<UserInfo>>,
230) -> Result<Json<ServerResponse<Sources>>, HttpError> {
231    let sources = ServerResponse::Success(state.db_type.get_user_sources(user.id).await?);
232    Ok(Json(sources))
233}
234
235async fn upload_new_sample_json(
236    Extension(state): Extension<Arc<State>>,
237    Extension(user): Extension<Arc<UserInfo>>,
238    Json(payload): Json<NewSampleB64>,
239) -> Result<StatusCode, HttpError> {
240    let allowed = state
241        .db_type
242        .allowed_user_source(user.id, payload.source_id)
243        .await?;
244
245    if !allowed {
246        return Err(HttpError(
247            ServerError::Unauthorized,
248            StatusCode::UNAUTHORIZED,
249        ));
250    }
251
252    let received_hash = hex::decode(&payload.sha256)?;
253    let bytes = general_purpose::STANDARD.decode(&payload.file_contents_b64)?;
254
255    let mut hasher = Sha256::new();
256    hasher.update(&bytes);
257    let result = hasher.finalize();
258
259    if result[..] != received_hash[..] {
260        return Err(HttpError(
261            ServerError::Unauthorized,
262            StatusCode::NOT_ACCEPTABLE,
263        ));
264    }
265
266    receive::incoming_sample(
267        state.clone(),
268        bytes,
269        user.id,
270        payload.source_id,
271        payload.file_name,
272    )
273    .await?;
274
275    Ok(StatusCode::OK)
276}
277
278async fn upload_new_sample_cbor(
279    Extension(state): Extension<Arc<State>>,
280    Extension(user): Extension<Arc<UserInfo>>,
281    Cbor(payload): Cbor<NewSampleBytes>,
282) -> Result<StatusCode, HttpError> {
283    let allowed = state
284        .db_type
285        .allowed_user_source(user.id, payload.source_id)
286        .await?;
287
288    if !allowed {
289        return Err(HttpError(
290            ServerError::Unauthorized,
291            StatusCode::UNAUTHORIZED,
292        ));
293    }
294
295    let received_hash = hex::decode(&payload.sha256)?;
296    let bytes = payload.file_contents;
297    let mut hasher = Sha256::new();
298    hasher.update(&bytes);
299    let result = hasher.finalize();
300
301    if result[..] != received_hash[..] {
302        return Err(HttpError(
303            ServerError::Unauthorized,
304            StatusCode::NOT_ACCEPTABLE,
305        ));
306    }
307
308    receive::incoming_sample(
309        state.clone(),
310        bytes,
311        user.id,
312        payload.source_id,
313        payload.file_name,
314    )
315    .await?;
316
317    Ok(StatusCode::OK)
318}
319
320async fn sample_search(
321    Extension(state): Extension<Arc<State>>,
322    Extension(user): Extension<Arc<UserInfo>>,
323    Json(payload): Json<SearchRequest>,
324) -> Result<Json<ServerResponse<SearchResponse>>, HttpError> {
325    let hashes = ServerResponse::Success(state.db_type.partial_search(user.id, payload).await?);
326    Ok(Json(hashes))
327}
328
329async fn download_sample(
330    Path(hash): Path<String>,
331    Extension(user): Extension<Arc<UserInfo>>,
332    Extension(state): Extension<Arc<State>>,
333) -> Result<Response, HttpError> {
334    if state.directory.is_none() {
335        return Err(NO_SAMPLES_STORED_ERROR);
336    }
337
338    let hash = HashType::try_from(hash.as_str())?;
339    let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
340
341    // Ensure we're sending the content-digest as SHA-256, since we could have received MD5 or SHA1.
342    let hash = HashType::try_from(sha256.as_str())?;
343
344    let contents = state.retrieve_bytes(&sha256).await?;
345
346    let mut bytes = Bytes::from(contents).into_response();
347    let name_header_value = format!("attachment; filename=\"{sha256}\"");
348    bytes.headers_mut().insert(
349        header::CONTENT_DISPOSITION,
350        HeaderValue::from_str(&name_header_value)
351            .unwrap_or(HeaderValue::from_static("Unknown.bin")),
352    );
353
354    bytes.headers_mut().insert(
355        "content-digest",
356        HeaderValue::from_str(&hash.content_digest_header())?,
357    );
358
359    Ok(bytes)
360}
361
362async fn download_sample_cart(
363    Path(hash): Path<String>,
364    Extension(user): Extension<Arc<UserInfo>>,
365    Extension(state): Extension<Arc<State>>,
366) -> Result<Response, HttpError> {
367    if state.directory.is_none() {
368        return Err(NO_SAMPLES_STORED_ERROR);
369    }
370
371    let hash = HashType::try_from(hash.as_str())?;
372    let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
373    let report = state.db_type.get_sample_report(user.id, &hash).await?;
374
375    let contents = state.retrieve_bytes(&sha256).await?;
376    let contents_cursor = Cursor::new(contents);
377    let mut output_cursor = Cursor::new(vec![]);
378
379    let mut output_metadata = cart_container::JsonMap::new();
380    output_metadata.insert("sha384".into(), report.sha384.into());
381    output_metadata.insert("sha512".into(), report.sha512.into());
382    output_metadata.insert("entropy".into(), report.entropy.into());
383    if let Some(filecmd) = report.filecommand {
384        output_metadata.insert("file".into(), filecmd.into());
385    }
386
387    cart_container::pack_stream(
388        contents_cursor,
389        &mut output_cursor,
390        Some(output_metadata),
391        None,
392        cart_container::digesters::default_digesters(), // MD-5, SHA-1, SHA-256 applied here
393        None,
394    )?;
395
396    let mut hasher = Sha256::new();
397    hasher.update(output_cursor.get_ref());
398    let hash_b64 = general_purpose::STANDARD.encode(hasher.finalize());
399
400    let mut bytes = Bytes::from(output_cursor.into_inner()).into_response();
401    let name_header_value = format!("attachment; filename=\"{sha256}.cart\"");
402    bytes.headers_mut().insert(
403        header::CONTENT_DISPOSITION,
404        HeaderValue::from_str(&name_header_value)
405            .unwrap_or(HeaderValue::from_static("Unknown.cart")),
406    );
407    bytes.headers_mut().insert(
408        "content-digest",
409        HeaderValue::from_str(&format!("sha-256=:{hash_b64}:"))?,
410    );
411
412    Ok(bytes)
413}
414
415async fn sample_report(
416    Path(hash): Path<String>,
417    Extension(user): Extension<Arc<UserInfo>>,
418    Extension(state): Extension<Arc<State>>,
419) -> Result<Json<ServerResponse<Report>>, HttpError> {
420    let hash = HashType::try_from(hash.as_str())?;
421    let report =
422        ServerResponse::<Report>::Success(state.db_type.get_sample_report(user.id, &hash).await?);
423    Ok(Json(report))
424}
425
426async fn find_similar(
427    Extension(state): Extension<Arc<State>>,
428    Extension(user): Extension<Arc<UserInfo>>,
429    Json(payload): Json<SimilarSamplesRequest>,
430) -> Result<Json<ServerResponse<SimilarSamplesResponse>>, HttpError> {
431    let results = state
432        .db_type
433        .find_similar_samples(user.id, &payload.hashes)
434        .await?;
435
436    Ok(Json(ServerResponse::Success(SimilarSamplesResponse {
437        results,
438        message: None,
439    })))
440}
441
442#[allow(unused_variables)]
443async fn yara_search(
444    Extension(state): Extension<Arc<State>>,
445    Extension(user): Extension<Arc<UserInfo>>,
446    Json(payload): Json<YaraSearchRequest>,
447) -> Result<Json<ServerResponse<YaraSearchRequestResponse>>, HttpError> {
448    #[cfg(feature = "yara")]
449    {
450        if state.directory.is_some() {
451            let yara_string = payload.rules.join("\n");
452            let rules = crate::yara::compile_yara_rules(payload)?;
453            let yara_bytes = rules.serialize()?;
454
455            let search_uuid = state
456                .db_type
457                .add_yara_search(user.id, &yara_string, &yara_bytes)
458                .await?;
459
460            Ok(Json(ServerResponse::Success(YaraSearchRequestResponse {
461                uuid: search_uuid,
462            })))
463        } else {
464            Err(NO_SAMPLES_STORED_ERROR)
465        }
466    }
467
468    #[cfg(not(feature = "yara"))]
469    {
470        tracing::warn!("Received Yara search request, but Yara support is not enabled");
471        Err(HttpError(
472            ServerError::Unsupported,
473            StatusCode::INTERNAL_SERVER_ERROR,
474        ))
475    }
476}
477
478#[allow(unused_variables)]
479async fn yara_search_get(
480    Path(uuid): Path<Uuid>,
481    Extension(state): Extension<Arc<State>>,
482    Extension(user): Extension<Arc<UserInfo>>,
483) -> Result<Json<ServerResponse<YaraSearchResponse>>, HttpError> {
484    #[cfg(feature = "yara")]
485    {
486        if state.directory.is_some() {
487            Ok(Json(ServerResponse::Success(
488                state.db_type.get_yara_results(uuid, user.id).await?,
489            )))
490        } else {
491            Err(NO_SAMPLES_STORED_ERROR)
492        }
493    }
494    #[cfg(not(feature = "yara"))]
495    {
496        tracing::warn!("Received Yara search request, but Yara support is not enabled");
497        Err(HttpError(
498            ServerError::Unsupported,
499            StatusCode::INTERNAL_SERVER_ERROR,
500        ))
501    }
502}
503
504// How to use `anyhow::Error` with `axum`:
505// https://github.com/tokio-rs/axum/blob/c97967252de9741b602f400dc2b25c8a33216039/examples/anyhow-error-response/src/main.rs
506
507const NO_SAMPLES_STORED_ERROR: HttpError =
508    HttpError(ServerError::NoSamples, StatusCode::NOT_ACCEPTABLE);
509
510/// Custom error type to support Axum error handling
511pub struct HttpError(pub ServerError, pub StatusCode);
512
513/// Convert Anyhow error into an Axum response object
514impl IntoResponse for HttpError {
515    fn into_response(self) -> Response {
516        let response: ServerResponse<String> = ServerResponse::Error(self.0);
517        match serde_json::to_string(&response) {
518            Ok(json) => (self.1, json).into_response(),
519            Err(_) => self.1.into_response(),
520        }
521    }
522}
523
524impl Display for HttpError {
525    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
526        write!(f, "{}", self.0)
527    }
528}
529
530/// Enable the use of the ? operator in Axum handler functions
531impl<E> From<E> for HttpError
532where
533    E: Into<anyhow::Error>,
534{
535    fn from(_err: E) -> Self {
536        Self(ServerError::ServerError, StatusCode::INTERNAL_SERVER_ERROR)
537    }
538}
539
540#[cfg(test)]
541mod tests {
542    use super::*;
543    use crate::crypto::{EncryptionOption, FileEncryption};
544    use crate::db::DatabaseType;
545    use crate::StateBuilder;
546    use malwaredb_api::PartialHashSearchType;
547
548    use std::collections::HashMap;
549    use std::path::PathBuf;
550    use std::sync::{Once, RwLock};
551    use std::time::{Instant, SystemTime};
552    use std::{env, fs};
553
554    use anyhow::Context;
555    use axum::body::Body;
556    use axum::http::Request;
557    use chrono::Local;
558    use constcat::concat;
559    use deadpool_postgres::tokio_postgres::{Config, NoTls};
560    use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
561    use http::header::CONTENT_TYPE;
562    use http_body_util::BodyExt;
563    use rstest::rstest;
564    use tower::ServiceExt;
565    // for `app.oneshot()`
566    use uuid::Uuid;
567
568    const ADMIN_UNAME: &str = "admin";
569    const ADMIN_PASSWORD: &str = "password12345";
570    const SAMPLE_BYTES: &[u8] = include_bytes!("../../../types/testdata/elf/elf_haiku_x86");
571
572    static TRACING: Once = Once::new();
573
574    fn init_tracing() {
575        tracing_subscriber::fmt()
576            .with_max_level(tracing::Level::TRACE)
577            .init();
578    }
579
580    async fn state(compress: bool, encrypt: bool) -> (Arc<State>, u32) {
581        TRACING.call_once(init_tracing);
582
583        // Each test needs a separate file, or else they'll clobber each other.
584        let mut db_file = env::temp_dir();
585        db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
586        if std::path::Path::new(&db_file).exists() {
587            fs::remove_file(&db_file)
588                .context(format!("failed to delete old SQLite file {db_file:?}"))
589                .unwrap();
590        }
591
592        let db_type =
593            DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
594                .await
595                .context(format!("failed to create SQLite instance for {db_file:?}"))
596                .unwrap();
597        if compress {
598            db_type.enable_compression().await.unwrap();
599        }
600        let keys = if encrypt {
601            let key = FileEncryption::from(EncryptionOption::Xor);
602            let key_id = db_type.add_file_encryption_key(&key).await.unwrap();
603            let mut keys = HashMap::new();
604            keys.insert(key_id, key);
605            keys
606        } else {
607            HashMap::new()
608        };
609        let db_config = db_type.get_config().await.unwrap();
610
611        let state = State {
612            port: 8080,
613            directory: Some(
614                tempfile::TempDir::with_prefix("mdb-temp-samples")
615                    .unwrap()
616                    .path()
617                    .into(),
618            ),
619            max_upload: 10 * 1024 * 1024,
620            ip: "127.0.0.1".parse().unwrap(),
621            db_type: Arc::new(db_type),
622            db_config,
623            keys,
624            started: SystemTime::now(),
625            #[cfg(feature = "vt")]
626            vt_client: None,
627            tls_config: None,
628            mdns: None,
629        };
630
631        state
632            .db_type
633            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
634            .await
635            .context("Failed to set admin password")
636            .unwrap();
637
638        let source_id = state
639            .db_type
640            .create_source("temp-source", None, None, Local::now(), true, Some(false))
641            .await
642            .unwrap();
643
644        state
645            .db_type
646            .add_group_to_source(0, source_id)
647            .await
648            .unwrap();
649
650        (Arc::new(state), source_id)
651    }
652
653    /// Get a new state and generate an API token
654    async fn state_and_token(pem_file: bool, postgres: bool) -> (State, u32, String) {
655        TRACING.call_once(init_tracing);
656
657        let mut builder = if postgres {
658            const CONNECTION_STRING: &str =
659                "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
660
661            let pg_config = CONNECTION_STRING.parse::<Config>().unwrap();
662            let mgr_config = ManagerConfig {
663                recycling_method: RecyclingMethod::Fast,
664            };
665            let mgr = Manager::from_config(pg_config, NoTls, mgr_config);
666            let pool = Pool::builder(mgr)
667                .max_size(num_cpus::get().min(16))
668                .build()
669                .unwrap();
670
671            let client = pool.get().await.unwrap();
672            client
673                .batch_execute("DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
674                .await
675                .unwrap();
676
677            StateBuilder::new(concat!("postgres ", CONNECTION_STRING), None)
678                .await
679                .unwrap()
680        } else {
681            let mut db_file = env::temp_dir();
682            db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
683            if std::path::Path::new(&db_file).exists() {
684                fs::remove_file(&db_file)
685                    .context(format!("failed to delete old SQLite file {db_file:?}"))
686                    .unwrap();
687            }
688
689            StateBuilder::new(&format!("file:{}", db_file.display()), None)
690                .await
691                .unwrap()
692        };
693
694        builder = builder.directory(PathBuf::from(
695            tempfile::TempDir::with_prefix("mdb-temp-samples")
696                .unwrap()
697                .path(),
698        ));
699
700        if pem_file {
701            builder = builder.port(8443);
702            builder = builder
703                .tls(
704                    "../../testdata/server_ca_cert.pem".into(),
705                    "../../testdata/server_key.pem".into(),
706                )
707                .await
708                .unwrap();
709        } else {
710            builder = builder.port(8444);
711            builder = builder
712                .tls(
713                    "../../testdata/server_cert.der".into(),
714                    "../../testdata/server_key.der".into(),
715                )
716                .await
717                .unwrap();
718        }
719
720        let state = builder.into_state().await.unwrap();
721
722        state
723            .db_type
724            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
725            .await
726            .context("Failed to set admin password")
727            .unwrap();
728
729        let source_id = state
730            .db_type
731            .create_source("temp-source", None, None, Local::now(), true, Some(false))
732            .await
733            .unwrap();
734
735        state
736            .db_type
737            .add_group_to_source(0, source_id)
738            .await
739            .unwrap();
740
741        let token = state
742            .db_type
743            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
744            .await
745            .unwrap();
746
747        (state, source_id, token)
748    }
749
750    async fn get_key(state: Arc<State>) -> String {
751        TRACING.call_once(init_tracing);
752
753        let key_request = serde_json::to_string(&GetAPIKeyRequest {
754            user: ADMIN_UNAME.into(),
755            password: ADMIN_PASSWORD.into(),
756        })
757        .context("Failed to convert API key request to JSON")
758        .unwrap();
759
760        let request = Request::builder()
761            .method("POST")
762            .uri(malwaredb_api::USER_LOGIN_URL)
763            .header(CONTENT_TYPE, "application/json")
764            .body(Body::from(key_request))
765            .unwrap();
766
767        let response = app(state)
768            .oneshot(request)
769            .await
770            .context("failed to send/receive login request")
771            .unwrap();
772
773        assert_eq!(response.status(), StatusCode::OK);
774        let bytes = response
775            .into_body()
776            .collect()
777            .await
778            .expect("failed to collect response body to bytes")
779            .to_bytes();
780        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
781            .context("failed to convert response to string")
782            .unwrap();
783
784        let response: ServerResponse<GetAPIKeyResponse> = serde_json::from_str(&json_response)
785            .context("failed to convert json response to object")
786            .unwrap();
787
788        if let ServerResponse::Success(response) = response {
789            let key = response.key.clone();
790            assert_eq!(key.len(), 64);
791
792            key
793        } else {
794            panic!("failed to get API key response")
795        }
796    }
797
798    #[tokio::test]
799    async fn about_self() {
800        let (state, _) = state(false, false).await;
801        let api_key = get_key(state.clone()).await;
802
803        let request = Request::builder()
804            .method("GET")
805            .uri(malwaredb_api::USER_INFO_URL)
806            .header(MDB_API_HEADER, &api_key)
807            .body(Body::empty())
808            .unwrap();
809
810        let response = app(state.clone())
811            .oneshot(request)
812            .await
813            .context("failed to send/receive login request")
814            .unwrap();
815
816        assert_eq!(response.status(), StatusCode::OK);
817        let bytes = response
818            .into_body()
819            .collect()
820            .await
821            .expect("failed to collect response body to bytes")
822            .to_bytes();
823        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
824            .context("failed to convert response to string")
825            .unwrap();
826
827        let response: ServerResponse<GetUserInfoResponse> = serde_json::from_str(&json_response)
828            .context("failed to convert json response to object")
829            .unwrap();
830
831        let response = response.unwrap();
832        assert_eq!(response.id, 0);
833        assert!(response.is_admin);
834        assert!(!response.is_readonly);
835        assert_eq!(response.username, "admin");
836
837        // Check labels, should be empty
838        let request = Request::builder()
839            .method("GET")
840            .uri(malwaredb_api::LIST_LABELS_URL)
841            .header(MDB_API_HEADER, &api_key)
842            .body(Body::empty())
843            .unwrap();
844
845        let response = app(state)
846            .oneshot(request)
847            .await
848            .context("failed to send/receive login request")
849            .unwrap();
850
851        assert_eq!(response.status(), StatusCode::OK);
852        let bytes = response
853            .into_body()
854            .collect()
855            .await
856            .expect("failed to collect response body to bytes")
857            .to_bytes();
858        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
859            .context("failed to convert response to string")
860            .unwrap();
861
862        let response: ServerResponse<Labels> = serde_json::from_str(&json_response)
863            .context("failed to convert json response to object")
864            .unwrap();
865
866        let response = response.unwrap();
867        assert!(response.is_empty());
868    }
869
870    #[rstest]
871    #[case::elf_encrypt_cart(include_bytes!("../../../types/testdata/elf/elf_haiku_x86.cart"), false, true, true, false)]
872    #[case::pe32(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), false, false, false, false)]
873    #[case::pdf_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), false, true, false, false)]
874    #[case::rtf(include_bytes!("../../../types/testdata/rtf/hello.rtf"), false, false, false, false)]
875    #[case::elf_compress_encrypt(include_bytes!("../../../types/testdata/elf/elf_haiku_x86"), true, true, false, false)]
876    #[case::pe32_compress(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), true, false, false, false)]
877    #[case::pdf_compress_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), true, true, false, false)]
878    #[case::rtf_compress(include_bytes!("../../../types/testdata/rtf/hello.rtf"), true, false, false, false)]
879    #[case::icon_unknown_type_proxy(include_bytes!("../../../../MDB_Logo.ico"), false, false, false, true)]
880    #[tokio::test]
881    async fn submit_sample(
882        #[case] contents: &[u8],
883        #[case] compress: bool,
884        #[case] encrypt: bool,
885        #[case] cart: bool,
886        #[case] should_fail: bool,
887    ) {
888        let (state, source_id) = state(compress, encrypt).await;
889        let api_key = get_key(state.clone()).await;
890
891        let file_contents_b64 = general_purpose::STANDARD.encode(contents);
892        let mut hasher = Sha256::new();
893        hasher.update(contents);
894        let sha256 = hex::encode(hasher.finalize());
895
896        let upload = serde_json::to_string(&NewSampleB64 {
897            file_name: "some_sample".into(),
898            source_id,
899            file_contents_b64,
900            sha256: sha256.clone(),
901        })
902        .context("failed to create upload structure")
903        .unwrap();
904
905        let request = Request::builder()
906            .method("POST")
907            .uri(malwaredb_api::UPLOAD_SAMPLE_JSON_URL)
908            .header(CONTENT_TYPE, "application/json")
909            .header(MDB_API_HEADER, &api_key)
910            .body(Body::from(upload))
911            .unwrap();
912
913        let response = app(state.clone())
914            .oneshot(request)
915            .await
916            .context("failed to send/receive upload request/response")
917            .unwrap();
918
919        if should_fail {
920            assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
921            return;
922        }
923
924        assert_eq!(response.status(), StatusCode::OK);
925
926        let sha256 = if cart {
927            // We have to get the original SHA-256 hash, not the hash of the CaRT container.
928            // This way we can check that the file exists on disk, and request the file again later in this test.
929            let mut input_buffer = Cursor::new(contents);
930            let mut output_buffer = Cursor::new(vec![]);
931            let (_, footer) =
932                cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)
933                    .expect("failed to decode CaRT file");
934            let footer = footer.expect("CaRT should have had a footer");
935            let sha256 = footer
936                .get("sha256")
937                .expect("CaRT footer should have had an entry for SHA-256")
938                .to_string();
939            sha256.replace('"', "") // Remove the quotes that get added by JSON
940        } else {
941            sha256
942        };
943
944        if let Some(dir) = &state.directory {
945            let mut sample_path = dir.clone();
946            sample_path.push(format!(
947                "{}/{}/{}/{}",
948                &sha256[0..2],
949                &sha256[2..4],
950                &sha256[4..6],
951                sha256
952            ));
953            eprintln!("Submitted sample should exist at {sample_path:?}.");
954            assert!(sample_path.exists());
955
956            if compress {
957                let sample_size_on_disk = sample_path.metadata().unwrap().len();
958                eprintln!(
959                    "Original size: {}, compressed: {}",
960                    contents.len(),
961                    sample_size_on_disk
962                );
963                assert!(sample_size_on_disk < contents.len() as u64);
964            }
965        } else {
966            panic!("Directory was set for the state, but is now `None`");
967        }
968
969        let request = Request::builder()
970            .method("GET")
971            .uri(format!("{}/{sha256}", malwaredb_api::SAMPLE_REPORT_URL))
972            .header(MDB_API_HEADER, &api_key)
973            .body(Body::empty())
974            .unwrap();
975
976        let response = app(state.clone())
977            .oneshot(request)
978            .await
979            .context("failed to send/receive upload request/response")
980            .unwrap();
981
982        println!("Response headers: {:?}", response.headers());
983
984        let bytes = response
985            .into_body()
986            .collect()
987            .await
988            .expect("failed to collect response body to bytes")
989            .to_bytes();
990        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
991            .context("failed to convert response to string")
992            .unwrap();
993
994        let report: ServerResponse<Report> = serde_json::from_str(&json_response)
995            .context("failed to convert json response to object")
996            .unwrap();
997        let report = report.unwrap();
998
999        assert_eq!(report.sha256, sha256);
1000        println!("Report: {report}");
1001
1002        let request = Request::builder()
1003            .method("GET")
1004            .uri(format!(
1005                "{}/{sha256}",
1006                malwaredb_api::DOWNLOAD_SAMPLE_CART_URL
1007            ))
1008            .header(MDB_API_HEADER, api_key.clone())
1009            .body(Body::empty())
1010            .unwrap();
1011
1012        let response = app(state.clone())
1013            .oneshot(request)
1014            .await
1015            .context("failed to send/receive upload request/response for CaRT")
1016            .unwrap();
1017
1018        println!("CaRT Response headers: {:?}", response.headers());
1019
1020        let bytes = response
1021            .into_body()
1022            .collect()
1023            .await
1024            .expect("failed to collect response body to bytes")
1025            .to_bytes();
1026
1027        let bytes = bytes.to_vec();
1028        let bytes_input = Cursor::new(bytes);
1029        let output = Cursor::new(vec![]);
1030        match cart_container::unpack_stream(bytes_input, output, None) {
1031            Ok((header, _)) => {
1032                let header = header.unwrap();
1033                assert_eq!(
1034                    header.get("sha384"),
1035                    Some(&serde_json::to_value(report.sha384).unwrap())
1036                );
1037            }
1038            Err(e) => panic!("{e}"),
1039        }
1040    }
1041
1042    /// Integration test between Malware DB Server & Client. This was under "/tests" in the project
1043    /// but `cargo hack` would disable needed features, and there wasn't a way to use `cfg` to
1044    /// put a condition on the test. Additionally, such an integration test required additional
1045    /// fields of the Server's state to be public or to implement `Default`, which isn't ideal.
1046    #[rstest]
1047    #[case::ssl_pem_sqlite(true, false)]
1048    #[case::ssl_der_sqlite(false, false)]
1049    #[ignore = "don't run this in CI"]
1050    #[case::ssl_der_postgres(false, true)]
1051    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1052    async fn client_integration(#[case] pem: bool, #[case] postgres: bool) {
1053        TRACING.call_once(init_tracing);
1054
1055        // This other get state function was needed, or else there were complaints of
1056        // `state` being moved and unable to be used in `tokio::spawn()`
1057        let (state, source_id, token) = state_and_token(pem, postgres).await;
1058        let state_port = state.port; // copy before move below
1059
1060        let server = tokio::spawn(async move {
1061            state
1062                .serve()
1063                .await
1064                .expect("MalwareDB failed to .serve() in tokio::spawn()");
1065        });
1066        assert!(!server.is_finished());
1067
1068        // In case start-up time is needed
1069        tokio::time::sleep(std::time::Duration::new(1, 0)).await;
1070
1071        println!("SemVer of MalwareDB: {:?}", *crate::MDB_VERSION_SEMVER);
1072        assert_eq!(
1073            *malwaredb_client::MDB_VERSION_SEMVER,
1074            *crate::MDB_VERSION_SEMVER,
1075            "SemVer parsing of MDB version failed"
1076        );
1077
1078        let mdb_client = malwaredb_client::MdbClient::new(
1079            format!("https://127.0.0.1:{state_port}"),
1080            token.clone(),
1081            Some("../../testdata/ca_cert.pem".into()),
1082        )
1083        .unwrap();
1084
1085        assert!(!mdb_client.supported_types().await.unwrap().types.is_empty());
1086        assert!(mdb_client.server_info().await.is_ok());
1087
1088        let start = Instant::now();
1089        assert!(mdb_client
1090            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
1091            .await
1092            .context("failed to upload test file")
1093            .unwrap());
1094        let duration = start.elapsed();
1095        println!("Initial upload and database record creation via base64 took {duration:?}");
1096
1097        let start = Instant::now();
1098        mdb_client
1099            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
1100            .await
1101            .context("failed to upload test file")
1102            .unwrap();
1103        let duration = start.elapsed();
1104        println!("Upload again via base64 took {duration:?}");
1105
1106        let start = Instant::now();
1107        mdb_client
1108            .submit_as_cbor(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
1109            .await
1110            .context("failed to upload test file")
1111            .unwrap();
1112        let duration = start.elapsed();
1113        println!("Upload again via cbor took {duration:?}");
1114
1115        let report = mdb_client
1116            .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
1117            .await
1118            .expect("failed to get report for file just submitted");
1119        assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");
1120
1121        let similar = mdb_client
1122            .similar(SAMPLE_BYTES)
1123            .await
1124            .expect("failed to query for files similar to what was just submitted");
1125        assert_eq!(similar.results.len(), 1);
1126
1127        let search = mdb_client
1128            .partial_search(
1129                Some((PartialHashSearchType::MD5, String::from(&report.md5[0..10]))),
1130                None,
1131                PartialHashSearchType::Any,
1132                10,
1133            )
1134            .await
1135            .unwrap();
1136        assert_eq!(search.hashes.len(), 1);
1137
1138        let search = mdb_client
1139            .partial_search(
1140                Some((PartialHashSearchType::Any, "AAAA".into())),
1141                None,
1142                PartialHashSearchType::Any,
1143                10,
1144            )
1145            .await
1146            .unwrap();
1147        assert!(search.hashes.is_empty());
1148
1149        #[cfg(feature = "yara")]
1150        {
1151            let elf_yara = "rule elf_file {
1152              condition:
1153                uint32(0) == 0x464c457f
1154            }";
1155
1156            let mut counter = 0;
1157            let mut successful = false;
1158
1159            let yara_result = mdb_client.yara_search(elf_yara).await.unwrap();
1160            while counter < 5 {
1161                if let Ok(yara_response) = mdb_client.yara_result(yara_result.uuid).await {
1162                    if !yara_response.results.is_empty() {
1163                        println!("Yara search upload response: {:?}", yara_response.results);
1164                        assert_eq!(yara_response.results.get("elf_file").unwrap().len(), 1);
1165                        assert_eq!(yara_response.results.len(), 1);
1166                        successful = true;
1167                        break;
1168                    }
1169                }
1170                tokio::time::sleep(std::time::Duration::new(2, 0)).await;
1171                counter += 1;
1172            }
1173
1174            assert!(successful);
1175        }
1176
1177        mdb_client.reset_key().await.expect("failed to reset key");
1178        server.abort();
1179    }
1180
1181    #[allow(clippy::too_many_lines)]
1182    #[test]
1183    #[ignore = "don't run this in CI"]
1184    fn client_integration_blocking() {
1185        TRACING.call_once(init_tracing);
1186        let token = Arc::new(RwLock::new(String::new()));
1187        let token_clone = token.clone();
1188
1189        // There has to be a better way to do this. We have to create a new thread to run Tokio,
1190        // otherwise the closing of the functions used by the blocking client will cause
1191        // Tokio to panic with this error:
1192        // Cannot drop a runtime in a context where blocking is not allowed. This happens when a runtime is dropped from within an asynchronous context.
1193        // So we wrap that all in another thread.
1194        let thread = std::thread::spawn(move || {
1195            let rt = tokio::runtime::Builder::new_multi_thread()
1196                .enable_all()
1197                .build()
1198                .unwrap();
1199
1200            let mut db_file = env::temp_dir();
1201            db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
1202            if std::path::Path::new(&db_file).exists() {
1203                fs::remove_file(&db_file)
1204                    .context(format!("failed to delete old SQLite file {db_file:?}"))
1205                    .unwrap();
1206            }
1207
1208            let (db_type, db_config) = rt.block_on(async {
1209                let db_type =
1210                    DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
1211                        .await
1212                        .context(format!("failed to create SQLite instance for {db_file:?}"))
1213                        .unwrap();
1214
1215                let db_config = db_type.get_config().await.unwrap();
1216                (db_type, db_config)
1217            });
1218
1219            let state = State {
1220                port: 9090,
1221                directory: Some(
1222                    tempfile::TempDir::with_prefix("mdb-temp-samples")
1223                        .unwrap()
1224                        .path()
1225                        .into(),
1226                ),
1227                max_upload: 10 * 1024 * 1024,
1228                ip: "127.0.0.1".parse().unwrap(),
1229                db_type: Arc::new(db_type),
1230                db_config,
1231                keys: HashMap::default(),
1232                started: SystemTime::now(),
1233                #[cfg(feature = "vt")]
1234                vt_client: None,
1235                tls_config: None,
1236                mdns: None,
1237            };
1238
1239            rt.block_on(async {
1240                state
1241                    .db_type
1242                    .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1243                    .await
1244                    .context("Failed to set admin password")
1245                    .unwrap();
1246
1247                let source_id = state
1248                    .db_type
1249                    .create_source("temp-source", None, None, Local::now(), true, Some(false))
1250                    .await
1251                    .unwrap();
1252
1253                state
1254                    .db_type
1255                    .add_group_to_source(0, source_id)
1256                    .await
1257                    .unwrap();
1258
1259                let token_string = state
1260                    .db_type
1261                    .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1262                    .await
1263                    .unwrap();
1264
1265                if let Ok(mut token_lock) = token_clone.write() {
1266                    *token_lock = token_string;
1267                }
1268
1269                state
1270                    .serve()
1271                    .await
1272                    .expect("MalwareDB failed to .serve() in tokio::spawn()");
1273            });
1274        });
1275        std::thread::sleep(std::time::Duration::from_secs(1));
1276
1277        let mdb_client = malwaredb_client::blocking::MdbClient::new(
1278            String::from("http://127.0.0.1:9090"),
1279            token.read().unwrap().clone(),
1280            None,
1281        )
1282        .unwrap();
1283
1284        let types = match mdb_client.supported_types() {
1285            Ok(types) => types,
1286            Err(e) => panic!("{e}"),
1287        };
1288        assert!(!types.types.is_empty());
1289
1290        assert!(mdb_client
1291            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), 1)
1292            .context("failed to upload test file")
1293            .unwrap());
1294
1295        let report = mdb_client
1296            .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
1297            .expect("failed to get report for file just submitted");
1298        assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");
1299
1300        let similar = mdb_client
1301            .similar(SAMPLE_BYTES)
1302            .expect("failed to query for files similar to what was just submitted");
1303        assert_eq!(similar.results.len(), 1);
1304
1305        let search = mdb_client
1306            .partial_search(
1307                Some((PartialHashSearchType::Any, "AAAA".into())),
1308                None,
1309                PartialHashSearchType::Any,
1310                10,
1311            )
1312            .unwrap();
1313        assert!(search.hashes.is_empty());
1314
1315        mdb_client.reset_key().expect("failed to reset key");
1316        drop(thread); // force shutdown of the thread, as `thread.join()` just waits forever.
1317    }
1318}