1use super::State;
4use crate::db::types::FileTypes;
5use malwaredb_api::digest::HashType;
6use malwaredb_api::{
7 GetAPIKeyRequest, GetAPIKeyResponse, GetUserInfoResponse, Labels, NewSample, Report,
8 SearchRequest, ServerInfo, Sources, SupportedFileTypes, MDB_API_HEADER,
9};
10
11use std::fmt::{Display, Formatter};
12use std::io::Cursor;
13use std::iter::once;
14use std::sync::Arc;
15
16use axum::body::Bytes;
17use axum::extract::{DefaultBodyLimit, Path};
18use axum::http::{header, StatusCode};
19use axum::response::IntoResponse;
20use axum::routing::{get, post};
21use axum::{Extension, Json, Router};
22use base64::{engine::general_purpose, Engine as _};
23use constcat::concat;
24use http::{HeaderMap, HeaderName};
25use sha2::{Digest, Sha256};
26use tower_http::compression::CompressionLayer;
27use tower_http::decompression::DecompressionLayer;
28use tower_http::limit::RequestBodyLimitLayer;
29use tower_http::sensitive_headers::SetSensitiveHeadersLayer;
30
31mod receive;
32
33pub fn app(state: Arc<State>) -> Router {
35 const UPLOAD_OVERHEAD: usize = std::mem::size_of::<Json<NewSample>>() * 2;
37
38 let compression_layer = CompressionLayer::new()
39 .br(true)
40 .deflate(true)
41 .gzip(true)
42 .zstd(true);
43
44 let decompression_layer = DecompressionLayer::new()
45 .br(true)
46 .deflate(true)
47 .gzip(true)
48 .zstd(true);
49
50 let size_limit_layer = RequestBodyLimitLayer::new(UPLOAD_OVERHEAD + state.max_upload);
51
52 Router::new()
53 .route("/", get(health))
54 .route(malwaredb_api::SERVER_INFO, get(get_mdb_info))
55 .route(malwaredb_api::USER_LOGIN_URL, post(user_login))
56 .route(malwaredb_api::USER_LOGOUT_URL, get(user_logout))
57 .route(malwaredb_api::USER_INFO_URL, get(get_user_groups_sources))
58 .route(
59 malwaredb_api::SUPPORTED_FILE_TYPES,
60 get(get_supported_types),
61 )
62 .route(malwaredb_api::LIST_LABELS, get(get_labels))
63 .route(malwaredb_api::LIST_SOURCES, get(get_sources))
64 .route(malwaredb_api::UPLOAD_SAMPLE, post(get_new_sample))
65 .route(malwaredb_api::SEARCH, post(sample_search))
66 .route(
67 concat!(malwaredb_api::DOWNLOAD_SAMPLE, "/{hash}"),
68 get(download_sample),
69 )
70 .route(
71 concat!(malwaredb_api::DOWNLOAD_SAMPLE_CART, "/{hash}"),
72 get(download_sample_cart),
73 )
74 .route(
75 concat!(malwaredb_api::SAMPLE_REPORT, "/{hash}"),
76 get(sample_report),
77 )
78 .route(malwaredb_api::SIMILAR_SAMPLES, post(find_similar))
79 .layer(DefaultBodyLimit::max(state.max_upload))
80 .layer(compression_layer)
81 .layer(decompression_layer)
82 .layer(size_limit_layer)
83 .layer(SetSensitiveHeadersLayer::new(once(
84 HeaderName::from_static(MDB_API_HEADER),
85 )))
86 .layer(Extension(state))
87}
88
89async fn health() -> StatusCode {
90 StatusCode::OK
91}
92
93async fn get_mdb_info(
94 Extension(state): Extension<Arc<State>>,
95) -> Result<Json<ServerInfo>, HttpError> {
96 let server_info = state.get_info().await?;
97 Ok(Json(server_info))
98}
99
100async fn user_login(
101 Extension(state): Extension<Arc<State>>,
102 Json(payload): Json<GetAPIKeyRequest>,
103) -> Result<Json<GetAPIKeyResponse>, HttpError> {
104 let api_key = state
105 .db_type
106 .authenticate(&payload.user, &payload.password)
107 .await?;
108
109 Ok(Json(GetAPIKeyResponse {
110 key: Some(api_key),
111 message: None,
112 }))
113}
114
115async fn user_logout(
116 Extension(state): Extension<Arc<State>>,
117 headers: HeaderMap,
118) -> Result<StatusCode, HttpError> {
119 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
120 anyhow::Error::msg("Missing API key"),
121 StatusCode::NOT_ACCEPTABLE,
122 ))?;
123 let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
124 state.db_type.reset_own_api_key(uid).await?;
125 Ok(StatusCode::OK)
126}
127
128async fn get_user_groups_sources(
129 Extension(state): Extension<Arc<State>>,
130 headers: HeaderMap,
131) -> Result<Json<GetUserInfoResponse>, HttpError> {
132 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
133 anyhow::Error::msg("Missing API key"),
134 StatusCode::NOT_ACCEPTABLE,
135 ))?;
136 let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
137 let groups_sources = state.db_type.get_user_info(uid).await?;
138 Ok(Json(groups_sources))
139}
140
141async fn get_supported_types(
142 Extension(state): Extension<Arc<State>>,
143) -> Result<Json<SupportedFileTypes>, HttpError> {
144 let data_types = state.db_type.get_known_data_types().await?;
145 let file_types = FileTypes(data_types);
146 Ok(Json(file_types.into()))
147}
148
149async fn get_labels(
150 Extension(state): Extension<Arc<State>>,
151 headers: HeaderMap,
152) -> Result<Json<Labels>, HttpError> {
153 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
154 anyhow::Error::msg("Missing API key"),
155 StatusCode::NOT_ACCEPTABLE,
156 ))?;
157 let _uid = state.db_type.get_uid(key.to_str().unwrap()).await?;
158 let labels = state.db_type.get_labels().await?;
159 Ok(Json(labels))
160}
161
162async fn get_sources(
163 Extension(state): Extension<Arc<State>>,
164 headers: HeaderMap,
165) -> Result<Json<Sources>, HttpError> {
166 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
167 anyhow::Error::msg("Missing API key"),
168 StatusCode::NOT_ACCEPTABLE,
169 ))?;
170 let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
171 let sources = state.db_type.get_user_sources(uid).await?;
172 Ok(Json(sources))
173}
174
175async fn get_new_sample(
176 Extension(state): Extension<Arc<State>>,
177 headers: HeaderMap,
178 Json(payload): Json<NewSample>,
179) -> Result<StatusCode, HttpError> {
180 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
181 anyhow::Error::msg("Missing API key"),
182 StatusCode::NOT_ACCEPTABLE,
183 ))?;
184 let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
185
186 let allowed = state
187 .db_type
188 .allowed_user_source(uid, payload.source_id)
189 .await?;
190
191 if !allowed {
192 return Err(HttpError(
193 anyhow::Error::msg("Unauthorized"),
194 StatusCode::UNAUTHORIZED,
195 ));
196 }
197
198 let received_hash = hex::decode(&payload.sha256)?;
199 let bytes = general_purpose::STANDARD.decode(&payload.file_contents_b64)?;
200
201 let mut hasher = Sha256::new();
202 hasher.update(&bytes);
203 let result = hasher.finalize();
204
205 if result[..] != received_hash[..] {
206 return Err(HttpError(
207 anyhow::Error::msg("Hash mismatch"),
208 StatusCode::NOT_ACCEPTABLE,
209 ));
210 }
211
212 receive::incoming_sample(
213 state.clone(),
214 bytes,
215 uid,
216 payload.source_id,
217 payload.file_name,
218 )
219 .await?;
220
221 Ok(StatusCode::OK)
222}
223
224async fn sample_search(
225 Extension(state): Extension<Arc<State>>,
226 headers: HeaderMap,
227 Json(payload): Json<SearchRequest>,
228) -> Result<Json<Vec<String>>, HttpError> {
229 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
230 anyhow::Error::msg("Missing API key"),
231 StatusCode::NOT_ACCEPTABLE,
232 ))?;
233 let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
234
235 let hashes = state.db_type.partial_search(uid, &payload).await?;
236 Ok(Json(hashes))
237}
238
239async fn download_sample(
240 Path(hash): Path<String>,
241 headers: HeaderMap,
242 Extension(state): Extension<Arc<State>>,
243) -> Result<impl IntoResponse, HttpError> {
244 if state.directory.is_none() {
245 return Err(HttpError(
246 anyhow::Error::msg("Server does not store samples"),
247 StatusCode::NOT_ACCEPTABLE,
248 ));
249 }
250
251 let hash = HashType::try_from(hash)?;
252 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
253 anyhow::Error::msg("Missing API key"),
254 StatusCode::NOT_ACCEPTABLE,
255 ))?;
256
257 let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
258 let sha256 = state.db_type.retrieve_sample(uid, &hash).await?;
259
260 let contents = state.retrieve_bytes(&sha256).await?;
261
262 let mut bytes = Bytes::from(contents).into_response();
263 let name_header_value = format!("attachment; filename=\"{sha256}\"");
264 bytes.headers_mut().insert(
265 header::CONTENT_DISPOSITION,
266 http::HeaderValue::from_str(&name_header_value)
267 .unwrap_or(http::HeaderValue::from_static("Unknown.bin")),
268 );
269
270 Ok(bytes)
271}
272
273async fn download_sample_cart(
274 Path(hash): Path<String>,
275 headers: HeaderMap,
276 Extension(state): Extension<Arc<State>>,
277) -> Result<impl IntoResponse, HttpError> {
278 if state.directory.is_none() {
279 return Err(HttpError(
280 anyhow::Error::msg("Server does not store samples"),
281 StatusCode::NOT_ACCEPTABLE,
282 ));
283 }
284
285 let hash = HashType::try_from(hash)?;
286 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
287 anyhow::Error::msg("Missing API key"),
288 StatusCode::NOT_ACCEPTABLE,
289 ))?;
290
291 let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
292 let sha256 = state.db_type.retrieve_sample(uid, &hash).await?;
293 let report = state.db_type.get_sample_report(uid, &hash).await?;
294
295 let contents = state.retrieve_bytes(&sha256).await?;
296 let contents_cursor = Cursor::new(contents);
297 let mut output_cursor = Cursor::new(vec![]);
298
299 let mut output_metadata = cart_container::JsonMap::new();
300 output_metadata.insert("sha384".into(), report.sha384.into());
301 output_metadata.insert("sha512".into(), report.sha512.into());
302 output_metadata.insert("entropy".into(), report.entropy.into());
303 if let Some(filecmd) = report.filecommand {
304 output_metadata.insert("file".into(), filecmd.into());
305 }
306
307 cart_container::pack_stream(
308 contents_cursor,
309 &mut output_cursor,
310 Some(output_metadata),
311 None,
312 cart_container::digesters::default_digesters(), None,
314 )?;
315
316 let mut bytes = Bytes::from(output_cursor.into_inner()).into_response();
317 let name_header_value = format!("attachment; filename=\"{sha256}.cart\"");
318 bytes.headers_mut().insert(
319 header::CONTENT_DISPOSITION,
320 http::HeaderValue::from_str(&name_header_value)
321 .unwrap_or(http::HeaderValue::from_static("Unknown.cart")),
322 );
323
324 Ok(bytes)
325}
326
327async fn sample_report(
328 Path(hash): Path<String>,
329 headers: HeaderMap,
330 Extension(state): Extension<Arc<State>>,
331) -> Result<Json<Report>, HttpError> {
332 let hash = HashType::try_from(hash)?;
333 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
334 anyhow::Error::msg("Missing API key"),
335 StatusCode::NOT_ACCEPTABLE,
336 ))?;
337
338 let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
339 let report = state.db_type.get_sample_report(uid, &hash).await?;
340 Ok(Json(report))
341}
342
343async fn find_similar(
344 Extension(state): Extension<Arc<State>>,
345 headers: HeaderMap,
346 Json(payload): Json<malwaredb_api::SimilarSamplesRequest>,
347) -> Result<Json<malwaredb_api::SimilarSamplesResponse>, HttpError> {
348 let key = headers.get(malwaredb_api::MDB_API_HEADER).ok_or(HttpError(
349 anyhow::Error::msg("Missing API key"),
350 StatusCode::NOT_ACCEPTABLE,
351 ))?;
352
353 let uid = state.db_type.get_uid(key.to_str().unwrap_or("")).await?;
354
355 let results = state
356 .db_type
357 .find_similar_samples(uid, &payload.hashes)
358 .await?;
359
360 Ok(Json(malwaredb_api::SimilarSamplesResponse {
361 results,
362 message: None,
363 }))
364}
365
/// Error type returned by every handler in this module: an error paired with the HTTP
/// status to respond with.
///
/// * Field `0` is the underlying error; its message becomes the response body.
/// * Field `1` is the status code sent to the client.
/// * Errors converted through the blanket `From` impl below carry
///   `500 Internal Server Error`.
pub struct HttpError(pub anyhow::Error, pub StatusCode);
371
372impl IntoResponse for HttpError {
374 fn into_response(self) -> axum::response::Response {
375 (self.1, format!("MDB error: {}", self.0)).into_response()
376 }
377}
378
379impl Display for HttpError {
380 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
381 write!(f, "{}", self.0)
382 }
383}
384
385impl<E> From<E> for HttpError
387where
388 E: Into<anyhow::Error>,
389{
390 fn from(err: E) -> Self {
391 Self(err.into(), StatusCode::INTERNAL_SERVER_ERROR)
392 }
393}
394
/// Router-level tests (driven in-process through `tower::ServiceExt::oneshot`) plus a TLS
/// integration test that runs the real server and talks to it with `MdbClient`.
#[cfg(test)]
mod tests {
    use malwaredb_client::MdbClient;

    use super::*;
    use crate::crypto::{EncryptionOption, FileEncryption};
    use crate::db::DatabaseType;
    use malwaredb_api::PartialHashSearchType;

    use std::collections::HashMap;
    use std::sync::Once;
    use std::time::SystemTime;
    use std::{env, fs};

    use anyhow::Context;
    use axum::body::Body;
    use axum::http::Request;
    use chrono::Local;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;
    use rstest::rstest;
    use tower::ServiceExt;
    use uuid::Uuid;

    /// Username the fixtures assign a password to and log in as.
    const ADMIN_UNAME: &str = "admin";
    /// Password the fixtures set for [`ADMIN_UNAME`].
    const ADMIN_PASSWORD: &str = "password12345";

    /// Guards one-time installation of the global tracing subscriber.
    static TRACING: Once = Once::new();

    /// Install a TRACE-level `tracing_subscriber` as the global default.
    /// Invoke via `TRACING.call_once` so it runs at most once per test process.
    fn init_tracing() {
        tracing_subscriber::fmt()
            .with_max_level(tracing::Level::TRACE)
            .init();
    }

    /// Build a server [`State`] backed by a fresh SQLite file in the system temp dir.
    ///
    /// * `compress` — call `enable_compression()` on the database before reading its config.
    /// * `encrypt` — register an XOR file-encryption key and put it in the state's key map.
    ///
    /// The fixture also:
    /// * sets the admin password;
    /// * creates a source named `temp-source`;
    /// * grants group id 0 access to that source (presumably the admin's group; TODO confirm).
    ///
    /// Returns the shared state and the new source's id.
    async fn state(compress: bool, encrypt: bool) -> (Arc<State>, u32) {
        let mut db_file = env::temp_dir();
        db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
        // The UUID name should never collide; remove any leftover file just in case.
        if std::path::Path::new(&db_file).exists() {
            fs::remove_file(&db_file)
                .context(format!("failed to delete old SQLite file {db_file:?}"))
                .unwrap();
        }

        let db_type =
            DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
                .await
                .context(format!("failed to create SQLite instance for {db_file:?}"))
                .unwrap();
        if compress {
            db_type.enable_compression().await.unwrap();
        }
        let keys = if encrypt {
            let key = FileEncryption::from(EncryptionOption::Xor);
            let key_id = db_type.add_file_encryption_key(&key).await.unwrap();
            let mut keys = HashMap::new();
            keys.insert(key_id, key);
            keys
        } else {
            HashMap::new()
        };
        // Read after the optional `enable_compression()`, presumably so the config reflects it.
        let db_config = db_type.get_config().await.unwrap();

        let state = State {
            port: 8080,
            // NOTE(review): the `TempDir` guard is a temporary dropped at the end of this
            // `let` statement. Dropping it deletes the directory just created; only the path
            // survives. The upload path evidently recreates it (`submit_sample` finds the stored
            // file), but nothing removes it afterwards. Confirm this is intended.
            directory: Some(
                tempfile::TempDir::with_prefix("mdb-temp-samples")
                    .unwrap()
                    .path()
                    .into(),
            ),
            max_upload: 10 * 1024 * 1024,
            ip: "127.0.0.1".parse().unwrap(),
            db_type,
            db_config,
            keys,
            started: SystemTime::now(),
            #[cfg(feature = "vt")]
            vt_client: None,
            cert: None,
            key: None,
        };

        state
            .db_type
            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .context("Failed to set admin password")
            .unwrap();

        let source_id = state
            .db_type
            .create_source("temp-source", None, None, Local::now(), true, Some(false))
            .await
            .unwrap();

        state
            .db_type
            .add_group_to_source(0, source_id)
            .await
            .unwrap();

        (Arc::new(state), source_id)
    }

    /// Like [`state`], but configured for TLS and returning an admin API token directly
    /// (obtained from `authenticate`, not through HTTP).
    ///
    /// * `pem_file == true` uses port 8443 with the PEM certificate and key from `testdata`.
    /// * `pem_file == false` uses port 8444 with the DER equivalents.
    ///
    /// Returns the owned state, the new source's id, and the admin token.
    async fn state_and_token(pem_file: bool) -> (State, u32, String) {
        let mut db_file = env::temp_dir();
        db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
        if std::path::Path::new(&db_file).exists() {
            fs::remove_file(&db_file)
                .context(format!("failed to delete old SQLite file {db_file:?}"))
                .unwrap();
        }

        let db_type =
            DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
                .await
                .context(format!("failed to create SQLite instance for {db_file:?}"))
                .unwrap();

        let db_config = db_type.get_config().await.unwrap();

        // The two branches differ only in port and certificate/key files.
        let state = if pem_file {
            State {
                port: 8443,
                directory: Some(
                    tempfile::TempDir::with_prefix("mdb-temp-samples")
                        .unwrap()
                        .path()
                        .into(),
                ),
                max_upload: 10 * 1024 * 1024,
                ip: "127.0.0.1".parse().unwrap(),
                db_type,
                db_config,
                keys: HashMap::default(),
                started: SystemTime::now(),
                #[cfg(feature = "vt")]
                vt_client: None,
                cert: Some("../../testdata/server_ca_cert.pem".into()),
                key: Some("../../testdata/server_key.pem".into()),
            }
        } else {
            State {
                port: 8444,
                directory: Some(
                    tempfile::TempDir::with_prefix("mdb-temp-samples")
                        .unwrap()
                        .path()
                        .into(),
                ),
                max_upload: 10 * 1024 * 1024,
                ip: "127.0.0.1".parse().unwrap(),
                db_type,
                db_config,
                keys: HashMap::default(),
                started: SystemTime::now(),
                #[cfg(feature = "vt")]
                vt_client: None,
                cert: Some("../../testdata/server_cert.der".into()),
                key: Some("../../testdata/server_key.der".into()),
            }
        };

        state
            .db_type
            .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .context("Failed to set admin password")
            .unwrap();

        let source_id = state
            .db_type
            .create_source("temp-source", None, None, Local::now(), true, Some(false))
            .await
            .unwrap();

        state
            .db_type
            .add_group_to_source(0, source_id)
            .await
            .unwrap();

        let token = state
            .db_type
            .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
            .await
            .unwrap();

        (state, source_id, token)
    }

    /// Log in as the admin through the HTTP login route and return the API key.
    ///
    /// Asserts a `200 OK` response and a 64-character key.
    async fn get_key(state: Arc<State>) -> String {
        let key_request = serde_json::to_string(&GetAPIKeyRequest {
            user: ADMIN_UNAME.into(),
            password: ADMIN_PASSWORD.into(),
        })
        .context("Failed to convert API key request to JSON")
        .unwrap();

        let request = Request::builder()
            .method("POST")
            .uri(malwaredb_api::USER_LOGIN_URL)
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(key_request))
            .unwrap();

        let response = app(state)
            .oneshot(request)
            .await
            .context("failed to send/receive login request")
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        // NOTE(review): the whole JSON body, including the key itself, is ASCII-lowercased
        // before parsing. This only round-trips if keys never contain uppercase letters.
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();

        let response: GetAPIKeyResponse = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();

        let key = response.key.clone().unwrap();
        assert_eq!(key.len(), 64);

        key
    }

    /// The user-info route reports the admin account (id 0, admin, not read-only).
    /// On a fresh database the label list is empty.
    #[tokio::test]
    async fn about_self() {
        let (state, _) = state(false, false).await;
        let api_key = get_key(state.clone()).await;

        let request = Request::builder()
            .method("GET")
            .uri(malwaredb_api::USER_INFO_URL)
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();

        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive login request")
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();

        let response: GetUserInfoResponse = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();

        assert_eq!(response.id, 0);
        assert!(response.is_admin);
        assert!(!response.is_readonly);
        assert_eq!(response.username, "admin");

        let request = Request::builder()
            .method("GET")
            .uri(malwaredb_api::LIST_LABELS)
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();

        let response = app(state)
            .oneshot(request)
            .await
            .context("failed to send/receive login request")
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();

        let response: Labels = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();

        assert!(response.is_empty());
    }

    /// Upload a sample, then check:
    /// * it lands in the sharded storage path;
    /// * it is smaller on disk when compression is on;
    /// * its report round-trips;
    /// * the CaRT download carries the report's SHA-384 in its header.
    ///
    /// Case parameters:
    /// * `compress` / `encrypt` configure the state;
    /// * `cart` marks `contents` as a CaRT container rather than a raw file.
    #[rstest]
    #[case::elf_encrypt_cart(include_bytes!("../../../types/testdata/elf/elf_haiku_x86.cart"), false, true, true)]
    #[case::pe32(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), false, false, false)]
    #[case::pdf_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), false, true, false)]
    #[case::rtf(include_bytes!("../../../types/testdata/rtf/hello.rtf"), false, false, false)]
    #[case::elf_compress_encrypt(include_bytes!("../../../types/testdata/elf/elf_haiku_x86"), true, true, false)]
    #[case::pe32_compress(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), true, false, false)]
    #[case::pdf_compress_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), true, true, false)]
    #[case::rtf_compress(include_bytes!("../../../types/testdata/rtf/hello.rtf"), true, false, false)]
    #[tokio::test]
    async fn submit_sample(
        #[case] contents: &[u8],
        #[case] compress: bool,
        #[case] encrypt: bool,
        #[case] cart: bool,
    ) {
        let (state, source_id) = state(compress, encrypt).await;
        let api_key = get_key(state.clone()).await;

        let file_contents_b64 = general_purpose::STANDARD.encode(contents);
        let mut hasher = Sha256::new();
        hasher.update(contents);
        let sha256 = hex::encode(hasher.finalize());

        let upload = serde_json::to_string(&NewSample {
            file_name: "some_sample".into(),
            source_id,
            file_contents_b64,
            sha256: sha256.clone(),
        })
        .context("failed to create upload structure")
        .unwrap();

        let request = Request::builder()
            .method("POST")
            .uri(malwaredb_api::UPLOAD_SAMPLE)
            .header(CONTENT_TYPE, "application/json")
            .header(MDB_API_HEADER, &api_key)
            .body(Body::from(upload))
            .unwrap();

        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive upload request/response")
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // For a CaRT upload, the rest of the test expects the *inner* file's SHA-256, read
        // from the CaRT footer, rather than the hash of the uploaded container bytes.
        let sha256 = if cart {
            let mut input_buffer = Cursor::new(contents);
            let mut output_buffer = Cursor::new(vec![]);
            let (_, footer) =
                cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)
                    .expect("failed to decode CaRT file");
            let footer = footer.expect("CaRT should have had a footer");
            let sha256 = footer
                .get("sha256")
                .expect("CaRT footer should have had an entry for SHA-256")
                .to_string();
            // `to_string()` on the JSON value keeps the surrounding quotes; strip them.
            sha256.replace('"', "")
        } else {
            sha256
        };

        if let Some(dir) = &state.directory {
            // Expected storage layout: <dir>/<hex[0..2]>/<hex[2..4]>/<hex[4..6]>/<full hex>.
            let mut sample_path = dir.clone();
            sample_path.push(format!(
                "{}/{}/{}/{}",
                &sha256[0..2],
                &sha256[2..4],
                &sha256[4..6],
                sha256
            ));
            eprintln!("Submitted sample should exist at {sample_path:?}.");
            assert!(sample_path.exists());

            if compress {
                let sample_size_on_disk = sample_path.metadata().unwrap().len();
                eprintln!(
                    "Original size: {}, compressed: {}",
                    contents.len(),
                    sample_size_on_disk
                );
                assert!(sample_size_on_disk < contents.len() as u64);
            }
        } else {
            panic!("Directory was set for the state, but is now `None`");
        }

        let request = Request::builder()
            .method("GET")
            .uri(format!("{}/{sha256}", malwaredb_api::SAMPLE_REPORT))
            .header(MDB_API_HEADER, &api_key)
            .body(Body::empty())
            .unwrap();

        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive upload request/response")
            .unwrap();

        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();
        let json_response = String::from_utf8(bytes.to_ascii_lowercase())
            .context("failed to convert response to string")
            .unwrap();

        let report: Report = serde_json::from_str(&json_response)
            .context("failed to convert json response to object")
            .unwrap();

        assert_eq!(report.sha256, sha256);
        println!("Report: {report}");

        let request = Request::builder()
            .method("GET")
            .uri(format!("{}/{sha256}", malwaredb_api::DOWNLOAD_SAMPLE_CART))
            .header(MDB_API_HEADER, api_key)
            .body(Body::empty())
            .unwrap();

        let response = app(state.clone())
            .oneshot(request)
            .await
            .context("failed to send/receive upload request/response for CaRT")
            .unwrap();

        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("failed to collect response body to bytes")
            .to_bytes();

        // The downloaded CaRT header should carry the SHA-384 the report advertised.
        let bytes = bytes.to_vec();
        let bytes_input = Cursor::new(bytes);
        let output = Cursor::new(vec![]);
        match cart_container::unpack_stream(bytes_input, output, None) {
            Ok((header, _)) => {
                let header = header.unwrap();
                assert_eq!(
                    header.get("sha384"),
                    Some(&serde_json::to_value(report.sha384).unwrap())
                );
            }
            Err(e) => panic!("{e}"),
        }
    }

    /// Start the real TLS server (PEM or DER credentials) and exercise it through
    /// `MdbClient`:
    /// * version check;
    /// * upload;
    /// * report;
    /// * similarity search;
    /// * partial search;
    /// * key reset on the last client.
    #[rstest]
    #[case::ssl_pem(true)]
    #[case::ssl_der(false)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn client_integration(#[case] pem: bool) {
        TRACING.call_once(init_tracing);

        let (state, source_id, token) = state_and_token(pem).await;
        // Copy the port out before `state` is moved into the server task.
        let state_port = state.port;
        let server = tokio::spawn(async move {
            state
                .serve()
                .await
                .expect("MalwareDB failed to .serve() in tokio::spawn()");
        });
        assert!(!server.is_finished());

        // Presumably gives the server time to bind before the clients connect.
        tokio::time::sleep(std::time::Duration::new(1, 0)).await;

        println!("SemVer of MalwareDB: {:?}", *crate::MDB_VERSION_SEMVER);
        assert_eq!(
            *malwaredb_client::MDB_VERSION_SEMVER,
            *crate::MDB_VERSION_SEMVER,
            "SemVer parsing of MDB version failed"
        );

        // NOTE(review): despite the names, both clients are built with `MdbClient::new`,
        // so the loop runs the same client type twice.
        let mdb_client_async = MdbClient::new(
            format!("https://127.0.0.1:{state_port}"),
            token.clone(),
            Some("../../testdata/ca_cert.pem".into()),
        )
        .unwrap();
        let mdb_client_blocking = MdbClient::new(
            format!("https://127.0.0.1:{state_port}"),
            token,
            Some("../../testdata/ca_cert.pem".into()),
        )
        .unwrap();

        for (index, mdb_client) in [mdb_client_async, mdb_client_blocking]
            .into_iter()
            .enumerate()
        {
            let contents = include_bytes!("../../../types/testdata/elf/elf_haiku_x86");

            assert!(mdb_client
                .submit(contents, String::from("elf_haiku_x86"), source_id)
                .await
                .context("failed to upload test file")
                .unwrap());

            let report = mdb_client
                .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
                .await
                .expect("failed to get report for file just submitted");

            assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");

            let similar = mdb_client
                .similar(contents)
                .await
                .expect("failed to query for files similar to what was just submitted");

            assert_eq!(similar.results.len(), 1);

            let search = mdb_client
                .partial_search(
                    Some((PartialHashSearchType::Any, "AAAA".into())),
                    None,
                    PartialHashSearchType::Any,
                    10,
                )
                .await
                .unwrap();
            assert!(search.is_empty());

            // Reset only on the final iteration: both clients share the same token, so an
            // earlier reset would presumably invalidate the next iteration's requests.
            if index > 0 {
                mdb_client.reset_key().await.expect("failed to reset key");
            }
        }
        server.abort();
    }
}