1use super::State;
4use malwaredb_api::digest::HashType;
5use malwaredb_api::{
6 GetAPIKeyRequest, GetAPIKeyResponse, GetUserInfoResponse, Labels, NewSampleB64, NewSampleBytes,
7 Report, SearchRequest, SearchResponse, ServerError, ServerInfo, ServerResponse,
8 SimilarSamplesRequest, SimilarSamplesResponse, Sources, SupportedFileTypes, MDB_API_HEADER,
9};
10
11use std::fmt::{Display, Formatter};
12use std::io::Cursor;
13use std::iter::once;
14use std::sync::Arc;
15
16use axum::body::Bytes;
17use axum::extract::{DefaultBodyLimit, Path, Request};
18use axum::http::{header, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use axum::routing::{get, post};
22use axum::{middleware, Extension, Json, Router};
23use axum_cbor::Cbor;
24use base64::{engine::general_purpose, Engine as _};
25use constcat::concat;
26use http::{HeaderMap, HeaderName, HeaderValue};
27use sha2::{Digest, Sha256};
28use tower_http::compression::CompressionLayer;
29use tower_http::decompression::DecompressionLayer;
30use tower_http::limit::RequestBodyLimitLayer;
31use tower_http::sensitive_headers::SetSensitiveHeadersLayer;
32
33mod receive;
34
/// Path at which the bundled favicon is served.
///
/// It is also one of the endpoints the authentication middleware lets through
/// without an API key.
const FAVICON_URL: &str = "/favicon.ico";
36
37pub fn app(state: Arc<State>) -> Router {
39 const UPLOAD_OVERHEAD: usize = std::mem::size_of::<Json<NewSampleB64>>() * 2;
41
42 let compression_layer = CompressionLayer::new()
43 .br(true)
44 .deflate(true)
45 .gzip(true)
46 .zstd(true);
47
48 let decompression_layer = DecompressionLayer::new()
49 .br(true)
50 .deflate(true)
51 .gzip(true)
52 .zstd(true);
53
54 let size_limit_layer = RequestBodyLimitLayer::new(UPLOAD_OVERHEAD + state.max_upload);
55
56 Router::new()
57 .route("/", get(health))
58 .route(FAVICON_URL, get(favicon))
59 .route(malwaredb_api::SERVER_INFO_URL, get(get_mdb_info))
60 .route(malwaredb_api::USER_LOGIN_URL, post(user_login))
61 .route(malwaredb_api::USER_LOGOUT_URL, get(user_logout))
62 .route(malwaredb_api::USER_INFO_URL, get(get_user_groups_sources))
63 .route(
64 malwaredb_api::SUPPORTED_FILE_TYPES_URL,
65 get(get_supported_types),
66 )
67 .route(malwaredb_api::LIST_LABELS_URL, get(get_labels))
68 .route(malwaredb_api::LIST_SOURCES_URL, get(get_sources))
69 .route(
70 malwaredb_api::UPLOAD_SAMPLE_JSON_URL,
71 post(upload_new_sample_json),
72 )
73 .route(
74 malwaredb_api::UPLOAD_SAMPLE_CBOR_URL,
75 post(upload_new_sample_cbor),
76 )
77 .route(malwaredb_api::SEARCH_URL, post(sample_search))
78 .route(
79 concat!(malwaredb_api::DOWNLOAD_SAMPLE_URL, "/{hash}"),
80 get(download_sample),
81 )
82 .route(
83 concat!(malwaredb_api::DOWNLOAD_SAMPLE_CART_URL, "/{hash}"),
84 get(download_sample_cart),
85 )
86 .route(
87 concat!(malwaredb_api::SAMPLE_REPORT_URL, "/{hash}"),
88 get(sample_report),
89 )
90 .route(malwaredb_api::SIMILAR_SAMPLES_URL, post(find_similar))
91 .layer(DefaultBodyLimit::max(state.max_upload))
92 .layer(compression_layer)
93 .layer(decompression_layer)
94 .layer(size_limit_layer)
95 .layer(SetSensitiveHeadersLayer::new(once(
96 HeaderName::from_static(MDB_API_HEADER),
97 )))
98 .route_layer(middleware::from_fn(response_header_middleware))
99 .layer(Extension(state))
100}
101
/// Identity of the authenticated caller.
///
/// `response_header_middleware` inserts it into the request extensions as
/// `Arc<UserInfo>`; handlers extract it from there.
#[derive(Debug, Clone, Copy)]
struct UserInfo {
    /// User ID that the database's `get_uid` resolved from the request's API key.
    pub id: u32,
}
106
107async fn response_header_middleware(
110 Extension(state): Extension<Arc<State>>,
111 headers: HeaderMap,
112 mut req: Request,
113 next: Next,
114) -> Result<Response, HttpError> {
115 const ALWAYS_ALLOWED_ENDPOINTS: [&str; 4] = [
116 "/",
117 FAVICON_URL,
118 malwaredb_api::SERVER_INFO_URL,
119 malwaredb_api::USER_LOGIN_URL,
120 ];
121
122 if !ALWAYS_ALLOWED_ENDPOINTS.contains(&req.uri().path()) {
123 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
124 ServerError::Unauthorized,
125 StatusCode::NOT_ACCEPTABLE,
126 ))?;
127 let key = key
128 .to_str()
129 .map_err(|_| HttpError(ServerError::Unauthorized, StatusCode::NOT_ACCEPTABLE))?;
130 let uid = state.db_type.get_uid(key).await.map_err(|e| {
131 tracing::warn!("Failed to get user ID from API key: {e}");
132 HttpError(ServerError::Unauthorized, StatusCode::UNAUTHORIZED)
133 })?;
134 req.extensions_mut().insert(Arc::new(UserInfo { id: uid }));
135 }
136
137 let mut response = next.run(req).await; response
142 .headers_mut()
143 .insert(header::CACHE_CONTROL, HeaderValue::from_static("no-store"));
144
145 Ok(response)
146}
147
/// Health check served at `/`: always replies `200 OK` with an empty body.
///
/// `/` is on the middleware's always-allowed list, so no API key is required.
async fn health() -> StatusCode {
    StatusCode::OK
}
151
152async fn favicon() -> Response {
153 const ICON: Bytes = Bytes::from_static(include_bytes!("../../MDB_Logo.ico"));
154
155 let mut bytes = ICON.into_response();
156 bytes.headers_mut().insert(
157 header::CONTENT_TYPE,
158 HeaderValue::from_static("image/vnd.microsoft.icon"),
159 );
160
161 bytes
162}
163
164async fn get_mdb_info(
165 Extension(state): Extension<Arc<State>>,
166) -> Result<Json<ServerResponse<ServerInfo>>, HttpError> {
167 let server_info = ServerResponse::Success(state.get_info().await?);
168 Ok(Json(server_info))
169}
170
171async fn user_login(
172 Extension(state): Extension<Arc<State>>,
173 Json(payload): Json<GetAPIKeyRequest>,
174) -> Result<Json<ServerResponse<GetAPIKeyResponse>>, HttpError> {
175 let api_key = state
176 .db_type
177 .authenticate(&payload.user, &payload.password)
178 .await?;
179
180 Ok(Json(ServerResponse::Success(GetAPIKeyResponse {
181 key: api_key,
182 message: None,
183 })))
184}
185
186async fn user_logout(
187 Extension(state): Extension<Arc<State>>,
188 Extension(user): Extension<Arc<UserInfo>>,
189) -> Result<StatusCode, HttpError> {
190 state.db_type.reset_own_api_key(user.id).await?;
191 Ok(StatusCode::OK)
192}
193
194async fn get_user_groups_sources(
195 Extension(state): Extension<Arc<State>>,
196 Extension(user): Extension<Arc<UserInfo>>,
197) -> Result<Json<ServerResponse<GetUserInfoResponse>>, HttpError> {
198 let groups_sources = ServerResponse::Success(state.db_type.get_user_info(user.id).await?);
199 Ok(Json(groups_sources))
200}
201
202async fn get_supported_types(
203 Extension(state): Extension<Arc<State>>,
204) -> Result<Json<ServerResponse<SupportedFileTypes>>, HttpError> {
205 let data_types = state.db_type.get_known_data_types().await?;
206 let file_types = ServerResponse::Success(SupportedFileTypes {
207 types: data_types.into_iter().map(Into::into).collect(),
208 message: None,
209 });
210 Ok(Json(file_types))
211}
212
213async fn get_labels(
214 Extension(state): Extension<Arc<State>>,
215) -> Result<Json<ServerResponse<Labels>>, HttpError> {
216 let labels = ServerResponse::Success(state.db_type.get_labels().await?);
217 Ok(Json(labels))
218}
219
220async fn get_sources(
221 Extension(state): Extension<Arc<State>>,
222 Extension(user): Extension<Arc<UserInfo>>,
223) -> Result<Json<ServerResponse<Sources>>, HttpError> {
224 let sources = ServerResponse::Success(state.db_type.get_user_sources(user.id).await?);
225 Ok(Json(sources))
226}
227
228async fn upload_new_sample_json(
229 Extension(state): Extension<Arc<State>>,
230 Extension(user): Extension<Arc<UserInfo>>,
231 Json(payload): Json<NewSampleB64>,
232) -> Result<StatusCode, HttpError> {
233 let allowed = state
234 .db_type
235 .allowed_user_source(user.id, payload.source_id)
236 .await?;
237
238 if !allowed {
239 return Err(HttpError(
240 ServerError::Unauthorized,
241 StatusCode::UNAUTHORIZED,
242 ));
243 }
244
245 let received_hash = hex::decode(&payload.sha256)?;
246 let bytes = general_purpose::STANDARD.decode(&payload.file_contents_b64)?;
247
248 let mut hasher = Sha256::new();
249 hasher.update(&bytes);
250 let result = hasher.finalize();
251
252 if result[..] != received_hash[..] {
253 return Err(HttpError(
254 ServerError::Unauthorized,
255 StatusCode::NOT_ACCEPTABLE,
256 ));
257 }
258
259 receive::incoming_sample(
260 state.clone(),
261 bytes,
262 user.id,
263 payload.source_id,
264 payload.file_name,
265 )
266 .await?;
267
268 Ok(StatusCode::OK)
269}
270
271async fn upload_new_sample_cbor(
272 Extension(state): Extension<Arc<State>>,
273 Extension(user): Extension<Arc<UserInfo>>,
274 Cbor(payload): Cbor<NewSampleBytes>,
275) -> Result<StatusCode, HttpError> {
276 let allowed = state
277 .db_type
278 .allowed_user_source(user.id, payload.source_id)
279 .await?;
280
281 if !allowed {
282 return Err(HttpError(
283 ServerError::Unauthorized,
284 StatusCode::UNAUTHORIZED,
285 ));
286 }
287
288 let received_hash = hex::decode(&payload.sha256)?;
289 let bytes = payload.file_contents;
290 let mut hasher = Sha256::new();
291 hasher.update(&bytes);
292 let result = hasher.finalize();
293
294 if result[..] != received_hash[..] {
295 return Err(HttpError(
296 ServerError::Unauthorized,
297 StatusCode::NOT_ACCEPTABLE,
298 ));
299 }
300
301 receive::incoming_sample(
302 state.clone(),
303 bytes,
304 user.id,
305 payload.source_id,
306 payload.file_name,
307 )
308 .await?;
309
310 Ok(StatusCode::OK)
311}
312
313async fn sample_search(
314 Extension(state): Extension<Arc<State>>,
315 Extension(user): Extension<Arc<UserInfo>>,
316 Json(payload): Json<SearchRequest>,
317) -> Result<Json<ServerResponse<SearchResponse>>, HttpError> {
318 let hashes = ServerResponse::Success(state.db_type.partial_search(user.id, payload).await?);
319 Ok(Json(hashes))
320}
321
322async fn download_sample(
323 Path(hash): Path<String>,
324 Extension(user): Extension<Arc<UserInfo>>,
325 Extension(state): Extension<Arc<State>>,
326) -> Result<Response, HttpError> {
327 if state.directory.is_none() {
328 return Err(HttpError(
329 ServerError::NoSamples,
330 StatusCode::NOT_ACCEPTABLE,
331 ));
332 }
333
334 let hash = HashType::try_from(hash)?;
335 let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
336
337 let contents = state.retrieve_bytes(&sha256).await?;
338
339 let mut bytes = Bytes::from(contents).into_response();
340 let name_header_value = format!("attachment; filename=\"{sha256}\"");
341 bytes.headers_mut().insert(
342 header::CONTENT_DISPOSITION,
343 HeaderValue::from_str(&name_header_value)
344 .unwrap_or(HeaderValue::from_static("Unknown.bin")),
345 );
346
347 Ok(bytes)
348}
349
350async fn download_sample_cart(
351 Path(hash): Path<String>,
352 Extension(user): Extension<Arc<UserInfo>>,
353 Extension(state): Extension<Arc<State>>,
354) -> Result<Response, HttpError> {
355 if state.directory.is_none() {
356 return Err(HttpError(
357 ServerError::NoSamples,
358 StatusCode::NOT_ACCEPTABLE,
359 ));
360 }
361
362 let hash = HashType::try_from(hash)?;
363 let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
364 let report = state.db_type.get_sample_report(user.id, &hash).await?;
365
366 let contents = state.retrieve_bytes(&sha256).await?;
367 let contents_cursor = Cursor::new(contents);
368 let mut output_cursor = Cursor::new(vec![]);
369
370 let mut output_metadata = cart_container::JsonMap::new();
371 output_metadata.insert("sha384".into(), report.sha384.into());
372 output_metadata.insert("sha512".into(), report.sha512.into());
373 output_metadata.insert("entropy".into(), report.entropy.into());
374 if let Some(filecmd) = report.filecommand {
375 output_metadata.insert("file".into(), filecmd.into());
376 }
377
378 cart_container::pack_stream(
379 contents_cursor,
380 &mut output_cursor,
381 Some(output_metadata),
382 None,
383 cart_container::digesters::default_digesters(), None,
385 )?;
386
387 let mut bytes = Bytes::from(output_cursor.into_inner()).into_response();
388 let name_header_value = format!("attachment; filename=\"{sha256}.cart\"");
389 bytes.headers_mut().insert(
390 header::CONTENT_DISPOSITION,
391 HeaderValue::from_str(&name_header_value)
392 .unwrap_or(HeaderValue::from_static("Unknown.cart")),
393 );
394
395 Ok(bytes)
396}
397
398async fn sample_report(
399 Path(hash): Path<String>,
400 Extension(user): Extension<Arc<UserInfo>>,
401 Extension(state): Extension<Arc<State>>,
402) -> Result<Json<ServerResponse<Report>>, HttpError> {
403 let hash = HashType::try_from(hash)?;
404 let report =
405 ServerResponse::<Report>::Success(state.db_type.get_sample_report(user.id, &hash).await?);
406 Ok(Json(report))
407}
408
409async fn find_similar(
410 Extension(state): Extension<Arc<State>>,
411 Extension(user): Extension<Arc<UserInfo>>,
412 Json(payload): Json<SimilarSamplesRequest>,
413) -> Result<Json<ServerResponse<SimilarSamplesResponse>>, HttpError> {
414 let results = state
415 .db_type
416 .find_similar_samples(user.id, &payload.hashes)
417 .await?;
418
419 Ok(Json(ServerResponse::Success(SimilarSamplesResponse {
420 results,
421 message: None,
422 })))
423}
424
/// Error type returned by the HTTP handlers and the authentication middleware.
///
/// Field `.0` is the `ServerError` placed in the response body (as
/// `ServerResponse::Error`). Field `.1` is the HTTP status code sent with it.
pub struct HttpError(pub ServerError, pub StatusCode);
430
431impl IntoResponse for HttpError {
433 fn into_response(self) -> Response {
434 let response: ServerResponse<String> = ServerResponse::Error(self.0);
435 match serde_json::to_string(&response) {
436 Ok(json) => (self.1, json).into_response(),
437 Err(_) => self.1.into_response(),
438 }
439 }
440}
441
442impl Display for HttpError {
443 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
444 write!(f, "{}", self.0)
445 }
446}
447
448impl<E> From<E> for HttpError
450where
451 E: Into<anyhow::Error>,
452{
453 fn from(_err: E) -> Self {
454 Self(ServerError::ServerError, StatusCode::INTERNAL_SERVER_ERROR)
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::{EncryptionOption, FileEncryption};
    use crate::db::DatabaseType;
    use malwaredb_api::PartialHashSearchType;

    use std::collections::HashMap;
    use std::sync::{mpsc, Once};
    use std::time::{Instant, SystemTime};
    use std::{env, fs};

    use anyhow::Context;
    use axum::body::Body;
    use axum::http::Request;
    use chrono::Local;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;
    use rstest::rstest;
    use tempfile::TempDir;
    use tower::ServiceExt;
    use uuid::Uuid;

    const ADMIN_UNAME: &str = "admin";
    const ADMIN_PASSWORD: &str = "password12345";
    const SAMPLE_BYTES: &[u8] = include_bytes!("../../../types/testdata/elf/elf_haiku_x86");

    static TRACING: Once = Once::new();

    fn init_tracing() {
        tracing_subscriber::fmt()
            .with_max_level(tracing::Level::TRACE)
            .init();
    }

    /// Opens a new, uniquely named SQLite database in the system temp directory.
    async fn temp_database() -> DatabaseType {
        let db_file = env::temp_dir().join(format!("testing_sqlite_{}.db", Uuid::new_v4()));
        if db_file.exists() {
            fs::remove_file(&db_file)
                .context(format!("failed to delete old SQLite file {db_file:?}"))
                .unwrap();
        }

        DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
            .await
            .context(format!("failed to create SQLite instance for {db_file:?}"))
            .unwrap()
    }

    /// Creates the sample storage directory for a test server.
    ///
    /// The directory is deleted when the returned guard is dropped. Keep the guard
    /// alive (bound to a named variable) for as long as the server may use it.
    fn samples_dir() -> TempDir {
        TempDir::with_prefix("mdb-temp-samples").unwrap()
    }

    /// Sets the admin password, creates a source and gives group 0 access to it.
    ///
    /// Returns the new source's ID.
    async fn seed_admin_and_source(state: &State) -> u32 {
        state
            .db_type
            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .context("Failed to set admin password")
            .unwrap();

        let source_id = state
            .db_type
            .create_source("temp-source", None, None, Local::now(), true, Some(false))
            .await
            .unwrap();

        state
            .db_type
            .add_group_to_source(0, source_id)
            .await
            .unwrap();

        source_id
    }

    /// Sends `request` through a fresh router built for `state`.
    async fn send(state: &Arc<State>, request: Request<Body>, what: &'static str) -> Response {
        app(Arc::clone(state))
            .oneshot(request)
            .await
            .context(what)
            .unwrap()
    }

    /// Builds a `GET` request for `uri` authenticated with `api_key`.
    fn authed_get(uri: &str, api_key: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri(uri)
            .header(MDB_API_HEADER, api_key)
            .body(Body::empty())
            .unwrap()
    }

    /// Collects a response body into an ASCII-lowercased string.
    async fn body_text(response: Response) -> String {
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap()
    }

    /// Builds a server state backed by a fresh database, with the admin password
    /// set and one source accessible to group 0.
    ///
    /// Returns the state, the source ID, and the sample-directory guard from
    /// [`samples_dir`].
    async fn state(compress: bool, encrypt: bool) -> (Arc<State>, u32, TempDir) {
        let db_type = temp_database().await;
        if compress {
            db_type.enable_compression().await.unwrap();
        }

        let mut keys = HashMap::new();
        if encrypt {
            let key = FileEncryption::from(EncryptionOption::Xor);
            let key_id = db_type.add_file_encryption_key(&key).await.unwrap();
            keys.insert(key_id, key);
        }
        let db_config = db_type.get_config().await.unwrap();
        let samples = samples_dir();

        let state = State {
            port: 8080,
            directory: Some(samples.path().into()),
            max_upload: 10 * 1024 * 1024,
            ip: "127.0.0.1".parse().unwrap(),
            db_type: Arc::new(db_type),
            db_config,
            keys,
            started: SystemTime::now(),
            #[cfg(feature = "vt")]
            vt_client: None,
            cert: None,
            key: None,
            mdns: false,
        };

        let source_id = seed_admin_and_source(&state).await;
        (Arc::new(state), source_id, samples)
    }

    /// Like [`state`], but configured for TLS (PEM or DER files). Also returns an
    /// admin API key obtained directly from the database.
    async fn state_and_token(pem_file: bool) -> (State, u32, String, TempDir) {
        let db_type = temp_database().await;
        let db_config = db_type.get_config().await.unwrap();
        let samples = samples_dir();

        let (port, cert, key) = if pem_file {
            (
                8443,
                "../../testdata/server_ca_cert.pem",
                "../../testdata/server_key.pem",
            )
        } else {
            (
                8444,
                "../../testdata/server_cert.der",
                "../../testdata/server_key.der",
            )
        };

        let state = State {
            port,
            directory: Some(samples.path().into()),
            max_upload: 10 * 1024 * 1024,
            ip: "127.0.0.1".parse().unwrap(),
            db_type: Arc::new(db_type),
            db_config,
            keys: HashMap::default(),
            started: SystemTime::now(),
            #[cfg(feature = "vt")]
            vt_client: None,
            cert: Some(cert.into()),
            key: Some(key.into()),
            mdns: false,
        };

        let source_id = seed_admin_and_source(&state).await;
        let token = state
            .db_type
            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .unwrap();

        (state, source_id, token, samples)
    }

    /// Logs in as the admin user through the HTTP API and returns the API key.
    async fn get_key(state: Arc<State>) -> String {
        let key_request = serde_json::to_string(&GetAPIKeyRequest {
            user: ADMIN_UNAME.into(),
            password: ADMIN_PASSWORD.into(),
        })
        .context("Failed to convert API key request to JSON")
        .unwrap();

        let request = Request::builder()
            .method("POST")
            .uri(malwaredb_api::USER_LOGIN_URL)
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(key_request))
            .unwrap();
        let response = send(&state, request, "failed to send/receive login request").await;
        assert_eq!(response.status(), StatusCode::OK);

        let response: ServerResponse<GetAPIKeyResponse> =
            serde_json::from_str(&body_text(response).await)
                .context("failed to convert json response to object")
                .unwrap();
        let ServerResponse::Success(response) = response else {
            panic!("failed to get API key response")
        };
        assert_eq!(response.key.len(), 64);
        response.key
    }

    #[tokio::test]
    async fn about_self() {
        let (state, _, _samples) = state(false, false).await;
        let api_key = get_key(state.clone()).await;

        let response = send(
            &state,
            authed_get(malwaredb_api::USER_INFO_URL, &api_key),
            "failed to send/receive user info request",
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
        let response: ServerResponse<GetUserInfoResponse> =
            serde_json::from_str(&body_text(response).await)
                .context("failed to convert json response to object")
                .unwrap();
        let response = response.unwrap();
        assert_eq!(response.id, 0);
        assert!(response.is_admin);
        assert!(!response.is_readonly);
        assert_eq!(response.username, "admin");

        let response = send(
            &state,
            authed_get(malwaredb_api::LIST_LABELS_URL, &api_key),
            "failed to send/receive labels request",
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
        let response: ServerResponse<Labels> = serde_json::from_str(&body_text(response).await)
            .context("failed to convert json response to object")
            .unwrap();
        assert!(response.unwrap().is_empty());
    }

    #[rstest]
    #[case::elf_encrypt_cart(include_bytes!("../../../types/testdata/elf/elf_haiku_x86.cart"), false, true, true, false)]
    #[case::pe32(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), false, false, false, false)]
    #[case::pdf_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), false, true, false, false)]
    #[case::rtf(include_bytes!("../../../types/testdata/rtf/hello.rtf"), false, false, false, false)]
    #[case::elf_compress_encrypt(include_bytes!("../../../types/testdata/elf/elf_haiku_x86"), true, true, false, false)]
    #[case::pe32_compress(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), true, false, false, false)]
    #[case::pdf_compress_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), true, true, false, false)]
    #[case::rtf_compress(include_bytes!("../../../types/testdata/rtf/hello.rtf"), true, false, false, false)]
    #[case::icon_unknown_type_proxy(include_bytes!("../../../../MDB_Logo.ico"), false, false, false, true)]
    #[tokio::test]
    async fn submit_sample(
        #[case] contents: &[u8],
        #[case] compress: bool,
        #[case] encrypt: bool,
        #[case] cart: bool,
        #[case] should_fail: bool,
    ) {
        let (state, source_id, _samples) = state(compress, encrypt).await;
        let api_key = get_key(state.clone()).await;

        let file_contents_b64 = general_purpose::STANDARD.encode(contents);
        let sha256 = hex::encode(Sha256::digest(contents));

        let upload = serde_json::to_string(&NewSampleB64 {
            file_name: "some_sample".into(),
            source_id,
            file_contents_b64,
            sha256: sha256.clone(),
        })
        .context("failed to create upload structure")
        .unwrap();

        let request = Request::builder()
            .method("POST")
            .uri(malwaredb_api::UPLOAD_SAMPLE_JSON_URL)
            .header(CONTENT_TYPE, "application/json")
            .header(MDB_API_HEADER, &api_key)
            .body(Body::from(upload))
            .unwrap();
        let response = send(
            &state,
            request,
            "failed to send/receive upload request/response",
        )
        .await;

        if should_fail {
            assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
            return;
        }
        assert_eq!(response.status(), StatusCode::OK);

        // The test expects CaRT uploads to be stored under the SHA-256 recorded in
        // the CaRT footer rather than the digest of the container itself.
        let sha256 = if cart {
            let mut input_buffer = Cursor::new(contents);
            let mut output_buffer = Cursor::new(vec![]);
            let (_, footer) =
                cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)
                    .expect("failed to decode CaRT file");
            let footer = footer.expect("CaRT should have had a footer");
            footer
                .get("sha256")
                .expect("CaRT footer should have had an entry for SHA-256")
                .to_string()
                .replace('"', "")
        } else {
            sha256
        };

        let Some(dir) = &state.directory else {
            panic!("Directory was set for the state, but is now `None`");
        };
        let sample_path = dir
            .join(&sha256[0..2])
            .join(&sha256[2..4])
            .join(&sha256[4..6])
            .join(&sha256);
        eprintln!("Submitted sample should exist at {sample_path:?}.");
        assert!(sample_path.exists());

        if compress {
            let sample_size_on_disk = sample_path.metadata().unwrap().len();
            eprintln!(
                "Original size: {}, compressed: {}",
                contents.len(),
                sample_size_on_disk
            );
            assert!(sample_size_on_disk < contents.len() as u64);
        }

        let response = send(
            &state,
            authed_get(
                &format!("{}/{sha256}", malwaredb_api::SAMPLE_REPORT_URL),
                &api_key,
            ),
            "failed to send/receive report request/response",
        )
        .await;
        let report: ServerResponse<Report> = serde_json::from_str(&body_text(response).await)
            .context("failed to convert json response to object")
            .unwrap();
        let report = report.unwrap();

        assert_eq!(report.sha256, sha256);
        println!("Report: {report}");

        let response = send(
            &state,
            authed_get(
                &format!("{}/{sha256}", malwaredb_api::DOWNLOAD_SAMPLE_CART_URL),
                &api_key,
            ),
            "failed to send/receive upload request/response for CaRT",
        )
        .await;
        let cart_bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes()
            .to_vec();

        let (header, _) =
            cart_container::unpack_stream(Cursor::new(cart_bytes), Cursor::new(vec![]), None)
                .unwrap_or_else(|e| panic!("{e}"));
        let header = header.unwrap();
        assert_eq!(
            header.get("sha384"),
            Some(&serde_json::to_value(report.sha384).unwrap())
        );
    }

    #[rstest]
    #[case::ssl_pem(true)]
    #[case::ssl_der(false)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn client_integration(#[case] pem: bool) {
        TRACING.call_once(init_tracing);

        let (state, source_id, token, _samples) = state_and_token(pem).await;
        let state_port = state.port;
        let server = tokio::spawn(async move {
            state
                .serve()
                .await
                .expect("MalwareDB failed to .serve() in tokio::spawn()");
        });
        assert!(!server.is_finished());

        // Give the server time to bind its listener before the client connects.
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        println!("SemVer of MalwareDB: {:?}", *crate::MDB_VERSION_SEMVER);
        assert_eq!(
            *malwaredb_client::MDB_VERSION_SEMVER,
            *crate::MDB_VERSION_SEMVER,
            "SemVer parsing of MDB version failed"
        );

        let mdb_client = malwaredb_client::MdbClient::new(
            format!("https://127.0.0.1:{state_port}"),
            token.clone(),
            Some("../../testdata/ca_cert.pem".into()),
        )
        .unwrap();

        assert!(!mdb_client.supported_types().await.unwrap().types.is_empty());
        assert!(mdb_client.server_info().await.is_ok());

        let start = Instant::now();
        assert!(mdb_client
            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
            .await
            .context("failed to upload test file")
            .unwrap());
        let duration = start.elapsed();
        println!("Initial upload and database record creation via base64 took {duration:?}");

        let start = Instant::now();
        mdb_client
            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
            .await
            .context("failed to upload test file")
            .unwrap();
        let duration = start.elapsed();
        println!("Upload again via base64 took {duration:?}");

        let start = Instant::now();
        mdb_client
            .submit_as_cbor(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
            .await
            .context("failed to upload test file")
            .unwrap();
        let duration = start.elapsed();
        println!("Upload again via cbor took {duration:?}");

        let report = mdb_client
            .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
            .await
            .expect("failed to get report for file just submitted");
        assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");

        let similar = mdb_client
            .similar(SAMPLE_BYTES)
            .await
            .expect("failed to query for files similar to what was just submitted");
        assert_eq!(similar.results.len(), 1);

        let search = mdb_client
            .partial_search(
                Some((PartialHashSearchType::Any, "AAAA".into())),
                None,
                PartialHashSearchType::Any,
                10,
            )
            .await
            .unwrap();
        assert!(search.hashes.is_empty());

        mdb_client.reset_key().await.expect("failed to reset key");
        server.abort();
    }

    #[allow(clippy::too_many_lines)]
    #[test]
    #[ignore = "don't run this in CI"]
    fn client_integration_blocking() {
        TRACING.call_once(init_tracing);

        // The server thread reports the admin token and the source ID once they exist,
        // so the client never starts with an empty token or a guessed source ID.
        let (ready_tx, ready_rx) = mpsc::channel::<(String, u32)>();

        // `serve()` is expected to run until the process exits, so the thread is
        // never joined; dropping the handle at the end of the test detaches it.
        let _server_thread = std::thread::spawn(move || {
            let rt = tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .build()
                .unwrap();
            // Lives for as long as this thread, i.e. while the server runs.
            let samples = samples_dir();

            rt.block_on(async {
                let db_type = temp_database().await;
                let db_config = db_type.get_config().await.unwrap();

                let state = State {
                    port: 9090,
                    directory: Some(samples.path().into()),
                    max_upload: 10 * 1024 * 1024,
                    ip: "127.0.0.1".parse().unwrap(),
                    db_type: Arc::new(db_type),
                    db_config,
                    keys: HashMap::default(),
                    started: SystemTime::now(),
                    #[cfg(feature = "vt")]
                    vt_client: None,
                    cert: None,
                    key: None,
                    mdns: false,
                };

                let source_id = seed_admin_and_source(&state).await;
                let token = state
                    .db_type
                    .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
                    .await
                    .unwrap();
                ready_tx
                    .send((token, source_id))
                    .expect("test thread stopped waiting for the server");

                state
                    .serve()
                    .await
                    .expect("MalwareDB failed to .serve() in tokio::spawn()");
            });
        });

        let (token, source_id) = ready_rx
            .recv()
            .expect("server thread failed before it was ready");
        // Give the server time to bind its listener before the client connects.
        std::thread::sleep(std::time::Duration::from_secs(1));

        let mdb_client = malwaredb_client::blocking::MdbClient::new(
            String::from("http://127.0.0.1:9090"),
            token,
            None,
        )
        .unwrap();

        let types = match mdb_client.supported_types() {
            Ok(types) => types,
            Err(e) => panic!("{e}"),
        };
        assert!(!types.types.is_empty());

        assert!(mdb_client
            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
            .context("failed to upload test file")
            .unwrap());

        let report = mdb_client
            .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
            .expect("failed to get report for file just submitted");
        assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");

        let similar = mdb_client
            .similar(SAMPLE_BYTES)
            .expect("failed to query for files similar to what was just submitted");
        assert_eq!(similar.results.len(), 1);

        let search = mdb_client
            .partial_search(
                Some((PartialHashSearchType::Any, "AAAA".into())),
                None,
                PartialHashSearchType::Any,
                10,
            )
            .unwrap();
        assert!(search.hashes.is_empty());

        mdb_client.reset_key().expect("failed to reset key");
    }
}