1use super::State;
4use malwaredb_api::digest::HashType;
5use malwaredb_api::{
6 GetAPIKeyRequest, GetAPIKeyResponse, GetUserInfoResponse, Labels, NewSampleB64, NewSampleBytes,
7 Report, SearchRequest, SearchResponse, ServerError, ServerInfo, ServerResponse,
8 SimilarSamplesRequest, SimilarSamplesResponse, Sources, SupportedFileTypes, MDB_API_HEADER,
9};
10
11use std::fmt::{Display, Formatter};
12use std::io::Cursor;
13use std::iter::once;
14use std::sync::Arc;
15
16use axum::body::Bytes;
17use axum::extract::{DefaultBodyLimit, Path, Request};
18use axum::http::{header, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use axum::routing::{get, post};
22use axum::{middleware, Extension, Json, Router};
23use axum_cbor::Cbor;
24use base64::{engine::general_purpose, Engine as _};
25use constcat::concat;
26use http::{HeaderMap, HeaderName, HeaderValue};
27use sha2::{Digest, Sha256};
28use tower_http::compression::CompressionLayer;
29use tower_http::decompression::DecompressionLayer;
30use tower_http::limit::RequestBodyLimitLayer;
31use tower_http::sensitive_headers::SetSensitiveHeadersLayer;
32
33mod receive;
34
35const FAVICON_URL: &str = "/favicon.ico";
36
37pub fn app(state: Arc<State>) -> Router {
39 const UPLOAD_OVERHEAD: usize = std::mem::size_of::<Json<NewSampleB64>>() * 2;
41
42 let compression_layer = CompressionLayer::new()
43 .br(true)
44 .deflate(true)
45 .gzip(true)
46 .zstd(true);
47
48 let decompression_layer = DecompressionLayer::new()
49 .br(true)
50 .deflate(true)
51 .gzip(true)
52 .zstd(true);
53
54 let size_limit_layer = RequestBodyLimitLayer::new(UPLOAD_OVERHEAD + state.max_upload);
55
56 Router::new()
57 .route("/", get(health))
58 .route(FAVICON_URL, get(favicon))
59 .route(malwaredb_api::SERVER_INFO_URL, get(get_mdb_info))
60 .route(malwaredb_api::USER_LOGIN_URL, post(user_login))
61 .route(malwaredb_api::USER_LOGOUT_URL, get(user_logout))
62 .route(malwaredb_api::USER_INFO_URL, get(get_user_groups_sources))
63 .route(
64 malwaredb_api::SUPPORTED_FILE_TYPES_URL,
65 get(get_supported_types),
66 )
67 .route(malwaredb_api::LIST_LABELS_URL, get(get_labels))
68 .route(malwaredb_api::LIST_SOURCES_URL, get(get_sources))
69 .route(
70 malwaredb_api::UPLOAD_SAMPLE_JSON_URL,
71 post(upload_new_sample_json),
72 )
73 .route(
74 malwaredb_api::UPLOAD_SAMPLE_CBOR_URL,
75 post(upload_new_sample_cbor),
76 )
77 .route(malwaredb_api::SEARCH_URL, post(sample_search))
78 .route(
79 concat!(malwaredb_api::DOWNLOAD_SAMPLE_URL, "/{hash}"),
80 get(download_sample),
81 )
82 .route(
83 concat!(malwaredb_api::DOWNLOAD_SAMPLE_CART_URL, "/{hash}"),
84 get(download_sample_cart),
85 )
86 .route(
87 concat!(malwaredb_api::SAMPLE_REPORT_URL, "/{hash}"),
88 get(sample_report),
89 )
90 .route(malwaredb_api::SIMILAR_SAMPLES_URL, post(find_similar))
91 .layer(DefaultBodyLimit::max(state.max_upload))
92 .layer(compression_layer)
93 .layer(decompression_layer)
94 .layer(size_limit_layer)
95 .layer(SetSensitiveHeadersLayer::new(once(
96 HeaderName::from_static(MDB_API_HEADER),
97 )))
98 .route_layer(middleware::from_fn(response_header_middleware))
99 .layer(Extension(state))
100}
101
/// Identity of the caller for the current request.
///
/// Inserted into the request extensions (wrapped in an `Arc`) by
/// `response_header_middleware` once the API key has been resolved. Handlers read it back with
/// `Extension<Arc<UserInfo>>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct UserInfo {
    /// Database user ID that owns the presented API key.
    pub id: u32,
}
106
107async fn response_header_middleware(
110 Extension(state): Extension<Arc<State>>,
111 headers: HeaderMap,
112 mut req: Request,
113 next: Next,
114) -> Result<Response, HttpError> {
115 const ALWAYS_ALLOWED_ENDPOINTS: [&str; 4] = [
116 "/",
117 FAVICON_URL,
118 malwaredb_api::SERVER_INFO_URL,
119 malwaredb_api::USER_LOGIN_URL,
120 ];
121
122 if !ALWAYS_ALLOWED_ENDPOINTS.contains(&req.uri().path()) {
123 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
124 ServerError::Unauthorized,
125 StatusCode::NOT_ACCEPTABLE,
126 ))?;
127 let key = key
128 .to_str()
129 .map_err(|_| HttpError(ServerError::Unauthorized, StatusCode::NOT_ACCEPTABLE))?;
130 let uid = state.db_type.get_uid(key).await.map_err(|e| {
131 tracing::warn!("Failed to get user ID from API key: {e}");
132 HttpError(ServerError::Unauthorized, StatusCode::UNAUTHORIZED)
133 })?;
134 req.extensions_mut().insert(Arc::new(UserInfo { id: uid }));
135 }
136
137 let mut response = next.run(req).await; response
142 .headers_mut()
143 .insert(header::CACHE_CONTROL, HeaderValue::from_static("no-store"));
144
145 Ok(response)
146}
147
148async fn health() -> StatusCode {
149 StatusCode::OK
150}
151
152async fn favicon() -> Response {
153 const ICON: Bytes = Bytes::from_static(include_bytes!("../../MDB_Logo.ico"));
154
155 let mut bytes = ICON.into_response();
156 bytes.headers_mut().insert(
157 header::CONTENT_TYPE,
158 HeaderValue::from_static("image/vnd.microsoft.icon"),
159 );
160
161 bytes
162}
163
164async fn get_mdb_info(
165 Extension(state): Extension<Arc<State>>,
166) -> Result<Json<ServerResponse<ServerInfo>>, HttpError> {
167 let server_info = ServerResponse::Success(state.get_info().await?);
168 Ok(Json(server_info))
169}
170
171async fn user_login(
172 Extension(state): Extension<Arc<State>>,
173 Json(payload): Json<GetAPIKeyRequest>,
174) -> Result<Json<ServerResponse<GetAPIKeyResponse>>, HttpError> {
175 let api_key = state
176 .db_type
177 .authenticate(&payload.user, &payload.password)
178 .await?;
179
180 Ok(Json(ServerResponse::Success(GetAPIKeyResponse {
181 key: api_key,
182 message: None,
183 })))
184}
185
186async fn user_logout(
187 Extension(state): Extension<Arc<State>>,
188 Extension(user): Extension<Arc<UserInfo>>,
189) -> Result<StatusCode, HttpError> {
190 state.db_type.reset_own_api_key(user.id).await?;
191 Ok(StatusCode::OK)
192}
193
194async fn get_user_groups_sources(
195 Extension(state): Extension<Arc<State>>,
196 Extension(user): Extension<Arc<UserInfo>>,
197) -> Result<Json<ServerResponse<GetUserInfoResponse>>, HttpError> {
198 let groups_sources = ServerResponse::Success(state.db_type.get_user_info(user.id).await?);
199 Ok(Json(groups_sources))
200}
201
202async fn get_supported_types(
203 Extension(state): Extension<Arc<State>>,
204) -> Result<Json<ServerResponse<SupportedFileTypes>>, HttpError> {
205 let data_types = state.db_type.get_known_data_types().await?;
206 let file_types = ServerResponse::Success(SupportedFileTypes {
207 types: data_types.into_iter().map(Into::into).collect(),
208 message: None,
209 });
210 Ok(Json(file_types))
211}
212
213async fn get_labels(
214 Extension(state): Extension<Arc<State>>,
215) -> Result<Json<ServerResponse<Labels>>, HttpError> {
216 let labels = ServerResponse::Success(state.db_type.get_labels().await?);
217 Ok(Json(labels))
218}
219
220async fn get_sources(
221 Extension(state): Extension<Arc<State>>,
222 Extension(user): Extension<Arc<UserInfo>>,
223) -> Result<Json<ServerResponse<Sources>>, HttpError> {
224 let sources = ServerResponse::Success(state.db_type.get_user_sources(user.id).await?);
225 Ok(Json(sources))
226}
227
228async fn upload_new_sample_json(
229 Extension(state): Extension<Arc<State>>,
230 Extension(user): Extension<Arc<UserInfo>>,
231 Json(payload): Json<NewSampleB64>,
232) -> Result<StatusCode, HttpError> {
233 let allowed = state
234 .db_type
235 .allowed_user_source(user.id, payload.source_id)
236 .await?;
237
238 if !allowed {
239 return Err(HttpError(
240 ServerError::Unauthorized,
241 StatusCode::UNAUTHORIZED,
242 ));
243 }
244
245 let received_hash = hex::decode(&payload.sha256)?;
246 let bytes = general_purpose::STANDARD.decode(&payload.file_contents_b64)?;
247
248 let mut hasher = Sha256::new();
249 hasher.update(&bytes);
250 let result = hasher.finalize();
251
252 if result[..] != received_hash[..] {
253 return Err(HttpError(
254 ServerError::Unauthorized,
255 StatusCode::NOT_ACCEPTABLE,
256 ));
257 }
258
259 receive::incoming_sample(
260 state.clone(),
261 bytes,
262 user.id,
263 payload.source_id,
264 payload.file_name,
265 )
266 .await?;
267
268 Ok(StatusCode::OK)
269}
270
271async fn upload_new_sample_cbor(
272 Extension(state): Extension<Arc<State>>,
273 Extension(user): Extension<Arc<UserInfo>>,
274 Cbor(payload): Cbor<NewSampleBytes>,
275) -> Result<StatusCode, HttpError> {
276 let allowed = state
277 .db_type
278 .allowed_user_source(user.id, payload.source_id)
279 .await?;
280
281 if !allowed {
282 return Err(HttpError(
283 ServerError::Unauthorized,
284 StatusCode::UNAUTHORIZED,
285 ));
286 }
287
288 let received_hash = hex::decode(&payload.sha256)?;
289 let bytes = payload.file_contents;
290 let mut hasher = Sha256::new();
291 hasher.update(&bytes);
292 let result = hasher.finalize();
293
294 if result[..] != received_hash[..] {
295 return Err(HttpError(
296 ServerError::Unauthorized,
297 StatusCode::NOT_ACCEPTABLE,
298 ));
299 }
300
301 receive::incoming_sample(
302 state.clone(),
303 bytes,
304 user.id,
305 payload.source_id,
306 payload.file_name,
307 )
308 .await?;
309
310 Ok(StatusCode::OK)
311}
312
313async fn sample_search(
314 Extension(state): Extension<Arc<State>>,
315 Extension(user): Extension<Arc<UserInfo>>,
316 Json(payload): Json<SearchRequest>,
317) -> Result<Json<ServerResponse<SearchResponse>>, HttpError> {
318 let hashes = ServerResponse::Success(state.db_type.partial_search(user.id, payload).await?);
319 Ok(Json(hashes))
320}
321
322async fn download_sample(
323 Path(hash): Path<String>,
324 Extension(user): Extension<Arc<UserInfo>>,
325 Extension(state): Extension<Arc<State>>,
326) -> Result<Response, HttpError> {
327 if state.directory.is_none() {
328 return Err(HttpError(
329 ServerError::NoSamples,
330 StatusCode::NOT_ACCEPTABLE,
331 ));
332 }
333
334 let hash = HashType::try_from(hash.as_str())?;
335 let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
336
337 let hash = HashType::try_from(sha256.as_str())?;
339
340 let contents = state.retrieve_bytes(&sha256).await?;
341
342 let mut bytes = Bytes::from(contents).into_response();
343 let name_header_value = format!("attachment; filename=\"{sha256}\"");
344 bytes.headers_mut().insert(
345 header::CONTENT_DISPOSITION,
346 HeaderValue::from_str(&name_header_value)
347 .unwrap_or(HeaderValue::from_static("Unknown.bin")),
348 );
349
350 bytes.headers_mut().insert(
351 "content-digest",
352 HeaderValue::from_str(&hash.content_digest_header())?,
353 );
354
355 Ok(bytes)
356}
357
358async fn download_sample_cart(
359 Path(hash): Path<String>,
360 Extension(user): Extension<Arc<UserInfo>>,
361 Extension(state): Extension<Arc<State>>,
362) -> Result<Response, HttpError> {
363 if state.directory.is_none() {
364 return Err(HttpError(
365 ServerError::NoSamples,
366 StatusCode::NOT_ACCEPTABLE,
367 ));
368 }
369
370 let hash = HashType::try_from(hash.as_str())?;
371 let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
372 let report = state.db_type.get_sample_report(user.id, &hash).await?;
373
374 let contents = state.retrieve_bytes(&sha256).await?;
375 let contents_cursor = Cursor::new(contents);
376 let mut output_cursor = Cursor::new(vec![]);
377
378 let mut output_metadata = cart_container::JsonMap::new();
379 output_metadata.insert("sha384".into(), report.sha384.into());
380 output_metadata.insert("sha512".into(), report.sha512.into());
381 output_metadata.insert("entropy".into(), report.entropy.into());
382 if let Some(filecmd) = report.filecommand {
383 output_metadata.insert("file".into(), filecmd.into());
384 }
385
386 cart_container::pack_stream(
387 contents_cursor,
388 &mut output_cursor,
389 Some(output_metadata),
390 None,
391 cart_container::digesters::default_digesters(), None,
393 )?;
394
395 let mut hasher = Sha256::new();
396 hasher.update(output_cursor.get_ref());
397 let hash_b64 = general_purpose::STANDARD.encode(hasher.finalize());
398
399 let mut bytes = Bytes::from(output_cursor.into_inner()).into_response();
400 let name_header_value = format!("attachment; filename=\"{sha256}.cart\"");
401 bytes.headers_mut().insert(
402 header::CONTENT_DISPOSITION,
403 HeaderValue::from_str(&name_header_value)
404 .unwrap_or(HeaderValue::from_static("Unknown.cart")),
405 );
406 bytes.headers_mut().insert(
407 "content-digest",
408 HeaderValue::from_str(&format!("sha-256=:{hash_b64}:"))?,
409 );
410
411 Ok(bytes)
412}
413
414async fn sample_report(
415 Path(hash): Path<String>,
416 Extension(user): Extension<Arc<UserInfo>>,
417 Extension(state): Extension<Arc<State>>,
418) -> Result<Json<ServerResponse<Report>>, HttpError> {
419 let hash = HashType::try_from(hash.as_str())?;
420 let report =
421 ServerResponse::<Report>::Success(state.db_type.get_sample_report(user.id, &hash).await?);
422 Ok(Json(report))
423}
424
425async fn find_similar(
426 Extension(state): Extension<Arc<State>>,
427 Extension(user): Extension<Arc<UserInfo>>,
428 Json(payload): Json<SimilarSamplesRequest>,
429) -> Result<Json<ServerResponse<SimilarSamplesResponse>>, HttpError> {
430 let results = state
431 .db_type
432 .find_similar_samples(user.id, &payload.hashes)
433 .await?;
434
435 Ok(Json(ServerResponse::Success(SimilarSamplesResponse {
436 results,
437 message: None,
438 })))
439}
440
/// Error returned by every handler and by the authentication middleware.
///
/// Field `.0` is the [`ServerError`] serialized into the `ServerResponse::Error` body. Field `.1`
/// is the HTTP status code sent with it.
pub struct HttpError(pub ServerError, pub StatusCode);
446
447impl IntoResponse for HttpError {
449 fn into_response(self) -> Response {
450 let response: ServerResponse<String> = ServerResponse::Error(self.0);
451 match serde_json::to_string(&response) {
452 Ok(json) => (self.1, json).into_response(),
453 Err(_) => self.1.into_response(),
454 }
455 }
456}
457
458impl Display for HttpError {
459 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
460 write!(f, "{}", self.0)
461 }
462}
463
464impl<E> From<E> for HttpError
466where
467 E: Into<anyhow::Error>,
468{
469 fn from(_err: E) -> Self {
470 Self(ServerError::ServerError, StatusCode::INTERNAL_SERVER_ERROR)
471 }
472}
473
/// In-process tests for the router, plus client/server integration tests that bind real ports.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::{EncryptionOption, FileEncryption};
    use crate::db::DatabaseType;
    use crate::StateBuilder;
    use malwaredb_api::PartialHashSearchType;

    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::sync::{Once, RwLock};
    use std::time::{Instant, SystemTime};
    use std::{env, fs};

    use anyhow::Context;
    use axum::body::Body;
    use axum::http::Request;
    use chrono::Local;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;
    use rstest::rstest;
    use tower::ServiceExt;
    use uuid::Uuid;

    const ADMIN_UNAME: &str = "admin";
    const ADMIN_PASSWORD: &str = "password12345";
    const SAMPLE_BYTES: &[u8] = include_bytes!("../../../types/testdata/elf/elf_haiku_x86");

    /// Ensures the global tracing subscriber is installed only once, since `.init()` may not be
    /// called twice in one process.
    static TRACING: Once = Once::new();

    /// Installs a TRACE-level `tracing_subscriber` formatter. Call only through `TRACING`.
    fn init_tracing() {
        tracing_subscriber::fmt()
            .with_max_level(tracing::Level::TRACE)
            .init();
    }

    /// Builds a `State` backed by a fresh SQLite file in the temp directory.
    ///
    /// The admin password is set and a `temp-source` source is created and linked to group 0.
    /// `compress` enables database-level compression. `encrypt` registers an XOR file-encryption
    /// key. Returns the shared state and the new source's ID.
    async fn state(compress: bool, encrypt: bool) -> (Arc<State>, u32) {
        TRACING.call_once(init_tracing);

        // A random UUID makes collisions unlikely; the delete is a safety net.
        let mut db_file = env::temp_dir();
        db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
        if std::path::Path::new(&db_file).exists() {
            fs::remove_file(&db_file)
                .context(format!("failed to delete old SQLite file {db_file:?}"))
                .unwrap();
        }

        let db_type =
            DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
                .await
                .context(format!("failed to create SQLite instance for {db_file:?}"))
                .unwrap();
        if compress {
            db_type.enable_compression().await.unwrap();
        }
        let keys = if encrypt {
            let key = FileEncryption::from(EncryptionOption::Xor);
            let key_id = db_type.add_file_encryption_key(&key).await.unwrap();
            let mut keys = HashMap::new();
            keys.insert(key_id, key);
            keys
        } else {
            HashMap::new()
        };
        let db_config = db_type.get_config().await.unwrap();

        let state = State {
            port: 8080,
            // NOTE(review): the `TempDir` guard is dropped at the end of this expression, which
            // deletes the directory; only its path survives. This presumably relies on the storage
            // code recreating missing directories, and the directory is never cleaned up
            // afterwards. Confirm.
            directory: Some(
                tempfile::TempDir::with_prefix("mdb-temp-samples")
                    .unwrap()
                    .path()
                    .into(),
            ),
            max_upload: 10 * 1024 * 1024,
            ip: "127.0.0.1".parse().unwrap(),
            db_type: Arc::new(db_type),
            db_config,
            keys,
            started: SystemTime::now(),
            #[cfg(feature = "vt")]
            vt_client: None,
            tls_config: None,
            mdns: None,
        };

        state
            .db_type
            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .context("Failed to set admin password")
            .unwrap();

        let source_id = state
            .db_type
            .create_source("temp-source", None, None, Local::now(), true, Some(false))
            .await
            .unwrap();

        // Group 0 presumably holds the admin user (user ID 0, per `about_self`), so this grants
        // the admin access to the new source. TODO confirm.
        state
            .db_type
            .add_group_to_source(0, source_id)
            .await
            .unwrap();

        (Arc::new(state), source_id)
    }

    /// Like `state`, but built through `StateBuilder` with TLS enabled.
    ///
    /// `pem_file` selects the PEM certificate on port 8443; otherwise a DER certificate is used on
    /// port 8444. The TLS paths are relative, so they resolve from the directory the tests run
    /// in. Returns the state, the new source ID, and an API key for the admin user.
    async fn state_and_token(pem_file: bool) -> (State, u32, String) {
        TRACING.call_once(init_tracing);

        let mut db_file = env::temp_dir();
        db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
        if std::path::Path::new(&db_file).exists() {
            fs::remove_file(&db_file)
                .context(format!("failed to delete old SQLite file {db_file:?}"))
                .unwrap();
        }

        let mut builder = StateBuilder::new(&format!("file:{}", db_file.display()), None)
            .await
            .unwrap();
        // NOTE(review): as in `state`, the `TempDir` is dropped immediately, so only the path is
        // kept.
        builder = builder.directory(PathBuf::from(
            tempfile::TempDir::with_prefix("mdb-temp-samples")
                .unwrap()
                .path(),
        ));
        if pem_file {
            builder = builder.port(8443);
            builder = builder
                .tls(
                    "../../testdata/server_ca_cert.pem".into(),
                    "../../testdata/server_key.pem".into(),
                )
                .await
                .unwrap();
        } else {
            builder = builder.port(8444);
            builder = builder
                .tls(
                    "../../testdata/server_cert.der".into(),
                    "../../testdata/server_key.der".into(),
                )
                .await
                .unwrap();
        }

        let state = builder.into_state().await.unwrap();

        state
            .db_type
            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .context("Failed to set admin password")
            .unwrap();

        let source_id = state
            .db_type
            .create_source("temp-source", None, None, Local::now(), true, Some(false))
            .await
            .unwrap();

        state
            .db_type
            .add_group_to_source(0, source_id)
            .await
            .unwrap();

        // Obtain the API key straight from the database, bypassing the HTTP login route.
        let token = state
            .db_type
            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .unwrap();

        (state, source_id, token)
    }

    /// Logs in as the admin through the HTTP login route and returns the 64-character API key.
    async fn get_key(state: Arc<State>) -> String {
        TRACING.call_once(init_tracing);

        let key_request = serde_json::to_string(&GetAPIKeyRequest {
            user: ADMIN_UNAME.into(),
            password: ADMIN_PASSWORD.into(),
        })
        .context("Failed to convert API key request to JSON")
        .unwrap();

        let request = Request::builder()
            .method("POST")
            .uri(malwaredb_api::USER_LOGIN_URL)
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(key_request))
            .unwrap();

        let response = app(state)
            .oneshot(request)
            .await
            .context("failed to send/receive login request")
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        // NOTE(review): the whole JSON body, including the key, is lowercased before parsing.
        // This assumes serde names and the key are already lowercase. Confirm it is intentional.
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();

        let response: ServerResponse<GetAPIKeyResponse> = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();

        if let ServerResponse::Success(response) = response {
            let key = response.key.clone();
            assert_eq!(key.len(), 64);

            key
        } else {
            panic!("failed to get API key response")
        }
    }

    /// The admin can read their own user info (ID 0, admin, not read-only). The label list
    /// starts empty.
    #[tokio::test]
    async fn about_self() {
        let (state, _) = state(false, false).await;
        let api_key = get_key(state.clone()).await;

        let request = Request::builder()
            .method("GET")
            .uri(malwaredb_api::USER_INFO_URL)
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();

        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive login request")
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();

        let response: ServerResponse<GetUserInfoResponse> = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();

        let response = response.unwrap();
        assert_eq!(response.id, 0);
        assert!(response.is_admin);
        assert!(!response.is_readonly);
        assert_eq!(response.username, "admin");

        let request = Request::builder()
            .method("GET")
            .uri(malwaredb_api::LIST_LABELS_URL)
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();

        let response = app(state)
            .oneshot(request)
            .await
            .context("failed to send/receive login request")
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();

        let response: ServerResponse<Labels> = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();

        let response = response.unwrap();
        assert!(response.is_empty());
    }

    /// Uploads a sample via the JSON endpoint. It then checks where the file lands on disk,
    /// fetches the report, and downloads the sample as CaRT.
    ///
    /// Case flags: `compress` and `encrypt` configure storage. `cart` means the input is a CaRT
    /// container. `should_fail` expects the upload itself to return 500.
    #[rstest]
    #[case::elf_encrypt_cart(include_bytes!("../../../types/testdata/elf/elf_haiku_x86.cart"), false, true, true, false)]
    #[case::pe32(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), false, false, false, false)]
    #[case::pdf_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), false, true, false, false)]
    #[case::rtf(include_bytes!("../../../types/testdata/rtf/hello.rtf"), false, false, false, false)]
    #[case::elf_compress_encrypt(include_bytes!("../../../types/testdata/elf/elf_haiku_x86"), true, true, false, false)]
    #[case::pe32_compress(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), true, false, false, false)]
    #[case::pdf_compress_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), true, true, false, false)]
    #[case::rtf_compress(include_bytes!("../../../types/testdata/rtf/hello.rtf"), true, false, false, false)]
    #[case::icon_unknown_type_proxy(include_bytes!("../../../../MDB_Logo.ico"), false, false, false, true)]
    #[tokio::test]
    async fn submit_sample(
        #[case] contents: &[u8],
        #[case] compress: bool,
        #[case] encrypt: bool,
        #[case] cart: bool,
        #[case] should_fail: bool,
    ) {
        let (state, source_id) = state(compress, encrypt).await;
        let api_key = get_key(state.clone()).await;

        let file_contents_b64 = general_purpose::STANDARD.encode(contents);
        let mut hasher = Sha256::new();
        hasher.update(contents);
        let sha256 = hex::encode(hasher.finalize());

        let upload = serde_json::to_string(&NewSampleB64 {
            file_name: "some_sample".into(),
            source_id,
            file_contents_b64,
            sha256: sha256.clone(),
        })
        .context("failed to create upload structure")
        .unwrap();

        let request = Request::builder()
            .method("POST")
            .uri(malwaredb_api::UPLOAD_SAMPLE_JSON_URL)
            .header(CONTENT_TYPE, "application/json")
            .header(MDB_API_HEADER, &api_key)
            .body(Body::from(upload))
            .unwrap();

        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive upload request/response")
            .unwrap();

        if should_fail {
            assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
            return;
        }

        assert_eq!(response.status(), StatusCode::OK);

        // For CaRT input, the test expects the server to store the sample under the SHA-256 in
        // the CaRT footer, not under the hash of the uploaded container.
        let sha256 = if cart {
            let mut input_buffer = Cursor::new(contents);
            let mut output_buffer = Cursor::new(vec![]);
            let (_, footer) =
                cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)
                    .expect("failed to decode CaRT file");
            let footer = footer.expect("CaRT should have had a footer");
            let sha256 = footer
                .get("sha256")
                .expect("CaRT footer should have had an entry for SHA-256")
                .to_string();
            // `to_string` on the JSON value keeps the surrounding quotes; strip them.
            sha256.replace('"', "")
        } else {
            sha256
        };

        // Samples are expected on disk at `<dir>/aa/bb/cc/<sha256>` (first three byte pairs).
        if let Some(dir) = &state.directory {
            let mut sample_path = dir.clone();
            sample_path.push(format!(
                "{}/{}/{}/{}",
                &sha256[0..2],
                &sha256[2..4],
                &sha256[4..6],
                sha256
            ));
            eprintln!("Submitted sample should exist at {sample_path:?}.");
            assert!(sample_path.exists());

            if compress {
                let sample_size_on_disk = sample_path.metadata().unwrap().len();
                eprintln!(
                    "Original size: {}, compressed: {}",
                    contents.len(),
                    sample_size_on_disk
                );
                assert!(sample_size_on_disk < contents.len() as u64);
            }
        } else {
            panic!("Directory was set for the state, but is now `None`");
        }

        let request = Request::builder()
            .method("GET")
            .uri(format!("{}/{sha256}", malwaredb_api::SAMPLE_REPORT_URL))
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();

        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive upload request/response")
            .unwrap();

        println!("Response headers: {:?}", response.headers());

        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();

        let report: ServerResponse<Report> = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();
        let report = report.unwrap();

        assert_eq!(report.sha256, sha256);
        println!("Report: {report}");

        let request = Request::builder()
            .method("GET")
            .uri(format!(
                "{}/{sha256}",
                malwaredb_api::DOWNLOAD_SAMPLE_CART_URL
            ))
            .header(MDB_API_HEADER, api_key)
            .body(Body::empty())
            .unwrap();

        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive upload request/response for CaRT")
            .unwrap();

        println!("Response headers: {:?}", response.headers());

        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();

        // The CaRT header must carry the report's SHA-384 as metadata.
        let bytes = bytes.to_vec();
        let bytes_input = Cursor::new(bytes);
        let output = Cursor::new(vec![]);
        match cart_container::unpack_stream(bytes_input, output, None) {
            Ok((header, _)) => {
                let header = header.unwrap();
                assert_eq!(
                    header.get("sha384"),
                    Some(&serde_json::to_value(report.sha384).unwrap())
                );
            }
            Err(e) => panic!("{e}"),
        }
    }

    /// Starts a real TLS server (PEM or DER certificate) and drives it with the async
    /// `malwaredb_client`. Covers base64 and CBOR uploads, reports, similarity, search and key
    /// reset.
    #[rstest]
    #[case::ssl_pem(true)]
    #[case::ssl_der(false)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn client_integration(#[case] pem: bool) {
        TRACING.call_once(init_tracing);

        let (state, source_id, token) = state_and_token(pem).await;
        // Read the port before `state` is moved into the server task.
        let state_port = state.port;
        let server = tokio::spawn(async move {
            state
                .serve()
                .await
                .expect("MalwareDB failed to .serve() in tokio::spawn()");
        });
        assert!(!server.is_finished());

        // Give the server time to bind before the client connects.
        tokio::time::sleep(std::time::Duration::new(1, 0)).await;

        println!("SemVer of MalwareDB: {:?}", *crate::MDB_VERSION_SEMVER);
        assert_eq!(
            *malwaredb_client::MDB_VERSION_SEMVER,
            *crate::MDB_VERSION_SEMVER,
            "SemVer parsing of MDB version failed"
        );

        let mdb_client = malwaredb_client::MdbClient::new(
            format!("https://127.0.0.1:{state_port}"),
            token.clone(),
            Some("../../testdata/ca_cert.pem".into()),
        )
        .unwrap();

        assert!(!mdb_client.supported_types().await.unwrap().types.is_empty());
        assert!(mdb_client.server_info().await.is_ok());

        let start = Instant::now();
        assert!(mdb_client
            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
            .await
            .context("failed to upload test file")
            .unwrap());
        let duration = start.elapsed();
        println!("Initial upload and database record creation via base64 took {duration:?}");

        // Re-submitting the same sample is expected to succeed.
        let start = Instant::now();
        mdb_client
            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
            .await
            .context("failed to upload test file")
            .unwrap();
        let duration = start.elapsed();
        println!("Upload again via base64 took {duration:?}");

        let start = Instant::now();
        mdb_client
            .submit_as_cbor(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
            .await
            .context("failed to upload test file")
            .unwrap();
        let duration = start.elapsed();
        println!("Upload again via cbor took {duration:?}");

        // SHA-256 and MD5 of `SAMPLE_BYTES`.
        let report = mdb_client
            .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
            .await
            .expect("failed to get report for file just submitted");
        assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");

        let similar = mdb_client
            .similar(SAMPLE_BYTES)
            .await
            .expect("failed to query for files similar to what was just submitted");
        assert_eq!(similar.results.len(), 1);

        let search = mdb_client
            .partial_search(
                Some((PartialHashSearchType::Any, "AAAA".into())),
                None,
                PartialHashSearchType::Any,
                10,
            )
            .await
            .unwrap();
        assert!(search.hashes.is_empty());

        mdb_client.reset_key().await.expect("failed to reset key");
        server.abort();
    }

    /// Same flow as `client_integration`, but over plain HTTP on port 9090. It uses the blocking
    /// client, with the server running on its own thread and Tokio runtime.
    #[allow(clippy::too_many_lines)]
    #[test]
    #[ignore = "don't run this in CI"]
    fn client_integration_blocking() {
        TRACING.call_once(init_tracing);
        // The server thread writes the admin API key here once it has authenticated.
        let token = Arc::new(RwLock::new(String::new()));
        let token_clone = token.clone();

        let thread = std::thread::spawn(move || {
            let rt = tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .build()
                .unwrap();

            let mut db_file = env::temp_dir();
            db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
            if std::path::Path::new(&db_file).exists() {
                fs::remove_file(&db_file)
                    .context(format!("failed to delete old SQLite file {db_file:?}"))
                    .unwrap();
            }

            let (db_type, db_config) = rt.block_on(async {
                let db_type =
                    DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
                        .await
                        .context(format!("failed to create SQLite instance for {db_file:?}"))
                        .unwrap();

                let db_config = db_type.get_config().await.unwrap();
                (db_type, db_config)
            });

            let state = State {
                port: 9090,
                directory: Some(
                    tempfile::TempDir::with_prefix("mdb-temp-samples")
                        .unwrap()
                        .path()
                        .into(),
                ),
                max_upload: 10 * 1024 * 1024,
                ip: "127.0.0.1".parse().unwrap(),
                db_type: Arc::new(db_type),
                db_config,
                keys: HashMap::default(),
                started: SystemTime::now(),
                #[cfg(feature = "vt")]
                vt_client: None,
                tls_config: None,
                mdns: None,
            };

            rt.block_on(async {
                state
                    .db_type
                    .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
                    .await
                    .context("Failed to set admin password")
                    .unwrap();

                let source_id = state
                    .db_type
                    .create_source("temp-source", None, None, Local::now(), true, Some(false))
                    .await
                    .unwrap();

                state
                    .db_type
                    .add_group_to_source(0, source_id)
                    .await
                    .unwrap();

                let token_string = state
                    .db_type
                    .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
                    .await
                    .unwrap();

                if let Ok(mut token_lock) = token_clone.write() {
                    *token_lock = token_string;
                }

                state
                    .serve()
                    .await
                    .expect("MalwareDB failed to .serve() in tokio::spawn()");
            });
        });
        // Wait for the server thread to set up the database, publish the token and bind.
        std::thread::sleep(std::time::Duration::from_secs(1));

        let mdb_client = malwaredb_client::blocking::MdbClient::new(
            String::from("http://127.0.0.1:9090"),
            token.read().unwrap().clone(),
            None,
        )
        .unwrap();

        let types = match mdb_client.supported_types() {
            Ok(types) => types,
            Err(e) => panic!("{e}"),
        };
        assert!(!types.types.is_empty());

        // NOTE(review): the source ID is hard-coded as `1`, presumably the ID of the first source
        // created above. Confirm, or pass `source_id` out of the server thread.
        assert!(mdb_client
            .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), 1)
            .context("failed to upload test file")
            .unwrap());

        let report = mdb_client
            .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
            .expect("failed to get report for file just submitted");
        assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");

        let similar = mdb_client
            .similar(SAMPLE_BYTES)
            .expect("failed to query for files similar to what was just submitted");
        assert_eq!(similar.results.len(), 1);

        let search = mdb_client
            .partial_search(
                Some((PartialHashSearchType::Any, "AAAA".into())),
                None,
                PartialHashSearchType::Any,
                10,
            )
            .unwrap();
        assert!(search.hashes.is_empty());

        mdb_client.reset_key().expect("failed to reset key");
        // NOTE(review): dropping a `JoinHandle` detaches the thread rather than stopping it. The
        // server keeps running until the test process exits.
        drop(thread);
    }
}