1use super::State;
4use malwaredb_api::digest::HashType;
5use malwaredb_api::{
6 GetAPIKeyRequest, GetAPIKeyResponse, GetUserInfoResponse, Labels, NewSampleB64, NewSampleBytes,
7 Report, SearchRequest, SearchResponse, ServerError, ServerInfo, ServerResponse,
8 SimilarSamplesRequest, SimilarSamplesResponse, Sources, SupportedFileTypes, YaraSearchRequest,
9 YaraSearchRequestResponse, YaraSearchResponse, MDB_API_HEADER,
10};
11
12use std::fmt::{Display, Formatter};
13use std::io::Cursor;
14use std::iter::once;
15use std::sync::Arc;
16
17use axum::body::Bytes;
18use axum::extract::{DefaultBodyLimit, Path, Request};
19use axum::http::{header, StatusCode};
20use axum::middleware::Next;
21use axum::response::{IntoResponse, Response};
22use axum::routing::{get, post};
23use axum::{middleware, Extension, Json, Router};
24use axum_cbor::Cbor;
25use base64::{engine::general_purpose, Engine as _};
26use constcat::concat;
27use http::{HeaderMap, HeaderName, HeaderValue};
28use sha2::{Digest, Sha256};
29use tower_http::compression::CompressionLayer;
30use tower_http::decompression::DecompressionLayer;
31use tower_http::limit::RequestBodyLimitLayer;
32use tower_http::sensitive_headers::SetSensitiveHeadersLayer;
33use uuid::Uuid;
34
35mod receive;
36
/// Path of the favicon route; it is served without requiring an API key.
const FAVICON_URL: &str = "/favicon.ico";
38
39pub fn app(state: Arc<State>) -> Router {
41 const UPLOAD_OVERHEAD: usize = std::mem::size_of::<Json<NewSampleB64>>() * 2;
43
44 let compression_layer = CompressionLayer::new()
45 .br(true)
46 .deflate(true)
47 .gzip(true)
48 .zstd(true);
49
50 let decompression_layer = DecompressionLayer::new()
51 .br(true)
52 .deflate(true)
53 .gzip(true)
54 .zstd(true);
55
56 let size_limit_layer = RequestBodyLimitLayer::new(UPLOAD_OVERHEAD + state.max_upload);
57
58 Router::new()
59 .route("/", get(health))
60 .route(FAVICON_URL, get(favicon))
61 .route(malwaredb_api::SERVER_INFO_URL, get(get_mdb_info))
62 .route(malwaredb_api::USER_LOGIN_URL, post(user_login))
63 .route(malwaredb_api::USER_LOGOUT_URL, get(user_logout))
64 .route(malwaredb_api::USER_INFO_URL, get(get_user_groups_sources))
65 .route(
66 malwaredb_api::SUPPORTED_FILE_TYPES_URL,
67 get(get_supported_types),
68 )
69 .route(malwaredb_api::LIST_LABELS_URL, get(get_labels))
70 .route(malwaredb_api::LIST_SOURCES_URL, get(get_sources))
71 .route(
72 malwaredb_api::UPLOAD_SAMPLE_JSON_URL,
73 post(upload_new_sample_json),
74 )
75 .route(
76 malwaredb_api::UPLOAD_SAMPLE_CBOR_URL,
77 post(upload_new_sample_cbor),
78 )
79 .route(malwaredb_api::SEARCH_URL, post(sample_search))
80 .route(
81 concat!(malwaredb_api::DOWNLOAD_SAMPLE_URL, "/{hash}"),
82 get(download_sample),
83 )
84 .route(
85 concat!(malwaredb_api::DOWNLOAD_SAMPLE_CART_URL, "/{hash}"),
86 get(download_sample_cart),
87 )
88 .route(
89 concat!(malwaredb_api::SAMPLE_REPORT_URL, "/{hash}"),
90 get(sample_report),
91 )
92 .route(malwaredb_api::SIMILAR_SAMPLES_URL, post(find_similar))
93 .route(malwaredb_api::YARA_SEARCH_URL, post(yara_search))
94 .route(
95 concat!(malwaredb_api::YARA_SEARCH_URL, "/{uuid}"),
96 get(yara_search_get),
97 )
98 .layer(DefaultBodyLimit::max(state.max_upload))
99 .layer(compression_layer)
100 .layer(decompression_layer)
101 .layer(size_limit_layer)
102 .layer(SetSensitiveHeadersLayer::new(once(
103 HeaderName::from_static(MDB_API_HEADER),
104 )))
105 .route_layer(middleware::from_fn(response_header_middleware))
106 .layer(Extension(state))
107}
108
/// Identity of the caller. `response_header_middleware` resolves it from the API key and
/// stores it in the request extensions for the handlers that need it.
#[derive(Debug, Clone, Copy)]
struct UserInfo {
    /// User ID returned by `db_type.get_uid` for the request's API key.
    pub id: u32,
}
113
114async fn response_header_middleware(
117 Extension(state): Extension<Arc<State>>,
118 headers: HeaderMap,
119 mut req: Request,
120 next: Next,
121) -> Result<Response, HttpError> {
122 const ALWAYS_ALLOWED_ENDPOINTS: [&str; 4] = [
123 "/",
124 FAVICON_URL,
125 malwaredb_api::SERVER_INFO_URL,
126 malwaredb_api::USER_LOGIN_URL,
127 ];
128
129 if !ALWAYS_ALLOWED_ENDPOINTS.contains(&req.uri().path()) {
130 let key = headers.get(MDB_API_HEADER).ok_or(HttpError(
131 ServerError::Unauthorized,
132 StatusCode::NOT_ACCEPTABLE,
133 ))?;
134 let key = key
135 .to_str()
136 .map_err(|_| HttpError(ServerError::Unauthorized, StatusCode::NOT_ACCEPTABLE))?;
137 let uid = state.db_type.get_uid(key).await.map_err(|e| {
138 tracing::warn!("Failed to get user ID from API key: {e}");
139 HttpError(ServerError::Unauthorized, StatusCode::UNAUTHORIZED)
140 })?;
141 req.extensions_mut().insert(Arc::new(UserInfo { id: uid }));
142 }
143
144 let mut response = next.run(req).await; response
149 .headers_mut()
150 .insert(header::CACHE_CONTROL, HeaderValue::from_static("no-store"));
151
152 Ok(response)
153}
154
/// Liveness probe for `/`: always answers `200 OK`. `/` is on the API-key allow-list, so no
/// key is needed.
async fn health() -> StatusCode {
    StatusCode::OK
}
158
159async fn favicon() -> Response {
160 const ICON: Bytes = Bytes::from_static(include_bytes!("../../MDB_Logo.ico"));
161
162 let mut bytes = ICON.into_response();
163 bytes.headers_mut().insert(
164 header::CONTENT_TYPE,
165 HeaderValue::from_static("image/vnd.microsoft.icon"),
166 );
167
168 bytes
169}
170
171async fn get_mdb_info(
172 Extension(state): Extension<Arc<State>>,
173) -> Result<Json<ServerResponse<ServerInfo>>, HttpError> {
174 let server_info = ServerResponse::Success(state.get_info().await?);
175 Ok(Json(server_info))
176}
177
178async fn user_login(
179 Extension(state): Extension<Arc<State>>,
180 Json(payload): Json<GetAPIKeyRequest>,
181) -> Result<Json<ServerResponse<GetAPIKeyResponse>>, HttpError> {
182 let api_key = state
183 .db_type
184 .authenticate(&payload.user, &payload.password)
185 .await?;
186
187 Ok(Json(ServerResponse::Success(GetAPIKeyResponse {
188 key: api_key,
189 message: None,
190 })))
191}
192
193async fn user_logout(
194 Extension(state): Extension<Arc<State>>,
195 Extension(user): Extension<Arc<UserInfo>>,
196) -> Result<StatusCode, HttpError> {
197 state.db_type.reset_own_api_key(user.id).await?;
198 Ok(StatusCode::OK)
199}
200
201async fn get_user_groups_sources(
202 Extension(state): Extension<Arc<State>>,
203 Extension(user): Extension<Arc<UserInfo>>,
204) -> Result<Json<ServerResponse<GetUserInfoResponse>>, HttpError> {
205 let groups_sources = ServerResponse::Success(state.db_type.get_user_info(user.id).await?);
206 Ok(Json(groups_sources))
207}
208
209async fn get_supported_types(
210 Extension(state): Extension<Arc<State>>,
211) -> Result<Json<ServerResponse<SupportedFileTypes>>, HttpError> {
212 let data_types = state.db_type.get_known_data_types().await?;
213 let file_types = ServerResponse::Success(SupportedFileTypes {
214 types: data_types.into_iter().map(Into::into).collect(),
215 message: None,
216 });
217 Ok(Json(file_types))
218}
219
220async fn get_labels(
221 Extension(state): Extension<Arc<State>>,
222) -> Result<Json<ServerResponse<Labels>>, HttpError> {
223 let labels = ServerResponse::Success(state.db_type.get_labels().await?);
224 Ok(Json(labels))
225}
226
227async fn get_sources(
228 Extension(state): Extension<Arc<State>>,
229 Extension(user): Extension<Arc<UserInfo>>,
230) -> Result<Json<ServerResponse<Sources>>, HttpError> {
231 let sources = ServerResponse::Success(state.db_type.get_user_sources(user.id).await?);
232 Ok(Json(sources))
233}
234
235async fn upload_new_sample_json(
236 Extension(state): Extension<Arc<State>>,
237 Extension(user): Extension<Arc<UserInfo>>,
238 Json(payload): Json<NewSampleB64>,
239) -> Result<StatusCode, HttpError> {
240 let allowed = state
241 .db_type
242 .allowed_user_source(user.id, payload.source_id)
243 .await?;
244
245 if !allowed {
246 return Err(HttpError(
247 ServerError::Unauthorized,
248 StatusCode::UNAUTHORIZED,
249 ));
250 }
251
252 let received_hash = hex::decode(&payload.sha256)?;
253 let bytes = general_purpose::STANDARD.decode(&payload.file_contents_b64)?;
254
255 let mut hasher = Sha256::new();
256 hasher.update(&bytes);
257 let result = hasher.finalize();
258
259 if result[..] != received_hash[..] {
260 return Err(HttpError(
261 ServerError::Unauthorized,
262 StatusCode::NOT_ACCEPTABLE,
263 ));
264 }
265
266 receive::incoming_sample(
267 state.clone(),
268 bytes,
269 user.id,
270 payload.source_id,
271 payload.file_name,
272 )
273 .await?;
274
275 Ok(StatusCode::OK)
276}
277
278async fn upload_new_sample_cbor(
279 Extension(state): Extension<Arc<State>>,
280 Extension(user): Extension<Arc<UserInfo>>,
281 Cbor(payload): Cbor<NewSampleBytes>,
282) -> Result<StatusCode, HttpError> {
283 let allowed = state
284 .db_type
285 .allowed_user_source(user.id, payload.source_id)
286 .await?;
287
288 if !allowed {
289 return Err(HttpError(
290 ServerError::Unauthorized,
291 StatusCode::UNAUTHORIZED,
292 ));
293 }
294
295 let received_hash = hex::decode(&payload.sha256)?;
296 let bytes = payload.file_contents;
297 let mut hasher = Sha256::new();
298 hasher.update(&bytes);
299 let result = hasher.finalize();
300
301 if result[..] != received_hash[..] {
302 return Err(HttpError(
303 ServerError::Unauthorized,
304 StatusCode::NOT_ACCEPTABLE,
305 ));
306 }
307
308 receive::incoming_sample(
309 state.clone(),
310 bytes,
311 user.id,
312 payload.source_id,
313 payload.file_name,
314 )
315 .await?;
316
317 Ok(StatusCode::OK)
318}
319
320async fn sample_search(
321 Extension(state): Extension<Arc<State>>,
322 Extension(user): Extension<Arc<UserInfo>>,
323 Json(payload): Json<SearchRequest>,
324) -> Result<Json<ServerResponse<SearchResponse>>, HttpError> {
325 let hashes = ServerResponse::Success(state.db_type.partial_search(user.id, payload).await?);
326 Ok(Json(hashes))
327}
328
329async fn download_sample(
330 Path(hash): Path<String>,
331 Extension(user): Extension<Arc<UserInfo>>,
332 Extension(state): Extension<Arc<State>>,
333) -> Result<Response, HttpError> {
334 if state.directory.is_none() {
335 return Err(NO_SAMPLES_STORED_ERROR);
336 }
337
338 let hash = HashType::try_from(hash.as_str())?;
339 let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
340
341 let hash = HashType::try_from(sha256.as_str())?;
343
344 let contents = state.retrieve_bytes(&sha256).await?;
345
346 let mut bytes = Bytes::from(contents).into_response();
347 let name_header_value = format!("attachment; filename=\"{sha256}\"");
348 bytes.headers_mut().insert(
349 header::CONTENT_DISPOSITION,
350 HeaderValue::from_str(&name_header_value)
351 .unwrap_or(HeaderValue::from_static("Unknown.bin")),
352 );
353
354 bytes.headers_mut().insert(
355 "content-digest",
356 HeaderValue::from_str(&hash.content_digest_header())?,
357 );
358
359 Ok(bytes)
360}
361
362async fn download_sample_cart(
363 Path(hash): Path<String>,
364 Extension(user): Extension<Arc<UserInfo>>,
365 Extension(state): Extension<Arc<State>>,
366) -> Result<Response, HttpError> {
367 if state.directory.is_none() {
368 return Err(NO_SAMPLES_STORED_ERROR);
369 }
370
371 let hash = HashType::try_from(hash.as_str())?;
372 let sha256 = state.db_type.retrieve_sample(user.id, &hash).await?;
373 let report = state.db_type.get_sample_report(user.id, &hash).await?;
374
375 let contents = state.retrieve_bytes(&sha256).await?;
376 let contents_cursor = Cursor::new(contents);
377 let mut output_cursor = Cursor::new(vec![]);
378
379 let mut output_metadata = cart_container::JsonMap::new();
380 output_metadata.insert("sha384".into(), report.sha384.into());
381 output_metadata.insert("sha512".into(), report.sha512.into());
382 output_metadata.insert("entropy".into(), report.entropy.into());
383 if let Some(filecmd) = report.filecommand {
384 output_metadata.insert("file".into(), filecmd.into());
385 }
386
387 cart_container::pack_stream(
388 contents_cursor,
389 &mut output_cursor,
390 Some(output_metadata),
391 None,
392 cart_container::digesters::default_digesters(), None,
394 )?;
395
396 let mut hasher = Sha256::new();
397 hasher.update(output_cursor.get_ref());
398 let hash_b64 = general_purpose::STANDARD.encode(hasher.finalize());
399
400 let mut bytes = Bytes::from(output_cursor.into_inner()).into_response();
401 let name_header_value = format!("attachment; filename=\"{sha256}.cart\"");
402 bytes.headers_mut().insert(
403 header::CONTENT_DISPOSITION,
404 HeaderValue::from_str(&name_header_value)
405 .unwrap_or(HeaderValue::from_static("Unknown.cart")),
406 );
407 bytes.headers_mut().insert(
408 "content-digest",
409 HeaderValue::from_str(&format!("sha-256=:{hash_b64}:"))?,
410 );
411
412 Ok(bytes)
413}
414
415async fn sample_report(
416 Path(hash): Path<String>,
417 Extension(user): Extension<Arc<UserInfo>>,
418 Extension(state): Extension<Arc<State>>,
419) -> Result<Json<ServerResponse<Report>>, HttpError> {
420 let hash = HashType::try_from(hash.as_str())?;
421 let report =
422 ServerResponse::<Report>::Success(state.db_type.get_sample_report(user.id, &hash).await?);
423 Ok(Json(report))
424}
425
426async fn find_similar(
427 Extension(state): Extension<Arc<State>>,
428 Extension(user): Extension<Arc<UserInfo>>,
429 Json(payload): Json<SimilarSamplesRequest>,
430) -> Result<Json<ServerResponse<SimilarSamplesResponse>>, HttpError> {
431 let results = state
432 .db_type
433 .find_similar_samples(user.id, &payload.hashes)
434 .await?;
435
436 Ok(Json(ServerResponse::Success(SimilarSamplesResponse {
437 results,
438 message: None,
439 })))
440}
441
442#[allow(unused_variables)]
443async fn yara_search(
444 Extension(state): Extension<Arc<State>>,
445 Extension(user): Extension<Arc<UserInfo>>,
446 Json(payload): Json<YaraSearchRequest>,
447) -> Result<Json<ServerResponse<YaraSearchRequestResponse>>, HttpError> {
448 #[cfg(feature = "yara")]
449 {
450 if state.directory.is_some() {
451 let yara_string = payload.rules.join("\n");
452 let rules = crate::yara::compile_yara_rules(payload)?;
453 let yara_bytes = rules.serialize()?;
454
455 let search_uuid = state
456 .db_type
457 .add_yara_search(user.id, &yara_string, &yara_bytes)
458 .await?;
459
460 Ok(Json(ServerResponse::Success(YaraSearchRequestResponse {
461 uuid: search_uuid,
462 })))
463 } else {
464 Err(NO_SAMPLES_STORED_ERROR)
465 }
466 }
467
468 #[cfg(not(feature = "yara"))]
469 {
470 tracing::warn!("Received Yara search request, but Yara support is not enabled");
471 Err(HttpError(
472 ServerError::Unsupported,
473 StatusCode::INTERNAL_SERVER_ERROR,
474 ))
475 }
476}
477
478#[allow(unused_variables)]
479async fn yara_search_get(
480 Path(uuid): Path<Uuid>,
481 Extension(state): Extension<Arc<State>>,
482 Extension(user): Extension<Arc<UserInfo>>,
483) -> Result<Json<ServerResponse<YaraSearchResponse>>, HttpError> {
484 #[cfg(feature = "yara")]
485 {
486 if state.directory.is_some() {
487 Ok(Json(ServerResponse::Success(
488 state.db_type.get_yara_results(uuid, user.id).await?,
489 )))
490 } else {
491 Err(NO_SAMPLES_STORED_ERROR)
492 }
493 }
494 #[cfg(not(feature = "yara"))]
495 {
496 tracing::warn!("Received Yara search request, but Yara support is not enabled");
497 Err(HttpError(
498 ServerError::Unsupported,
499 StatusCode::INTERNAL_SERVER_ERROR,
500 ))
501 }
502}
503
/// Error returned by the sample-content endpoints when `state.directory` is `None`, i.e.
/// the server keeps no sample files. Sent as `406 Not Acceptable`.
const NO_SAMPLES_STORED_ERROR: HttpError =
    HttpError(ServerError::NoSamples, StatusCode::NOT_ACCEPTABLE);
509
/// Error type returned by the handlers.
///
/// The `ServerError` is sent to the client inside a JSON `ServerResponse::Error` body, and
/// the `StatusCode` becomes the response status.
pub struct HttpError(pub ServerError, pub StatusCode);
512
513impl IntoResponse for HttpError {
515 fn into_response(self) -> Response {
516 let response: ServerResponse<String> = ServerResponse::Error(self.0);
517 match serde_json::to_string(&response) {
518 Ok(json) => (self.1, json).into_response(),
519 Err(_) => self.1.into_response(),
520 }
521 }
522}
523
524impl Display for HttpError {
525 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
526 write!(f, "{}", self.0)
527 }
528}
529
530impl<E> From<E> for HttpError
532where
533 E: Into<anyhow::Error>,
534{
535 fn from(_err: E) -> Self {
536 Self(ServerError::ServerError, StatusCode::INTERNAL_SERVER_ERROR)
537 }
538}
539
540#[cfg(test)]
541mod tests {
542 use super::*;
543 use crate::crypto::{EncryptionOption, FileEncryption};
544 use crate::db::DatabaseType;
545 use crate::StateBuilder;
546 use malwaredb_api::PartialHashSearchType;
547
548 use std::collections::HashMap;
549 use std::path::PathBuf;
550 use std::sync::{Once, RwLock};
551 use std::time::{Instant, SystemTime};
552 use std::{env, fs};
553
554 use anyhow::Context;
555 use axum::body::Body;
556 use axum::http::Request;
557 use chrono::Local;
558 use constcat::concat;
559 use deadpool_postgres::tokio_postgres::{Config, NoTls};
560 use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
561 use http::header::CONTENT_TYPE;
562 use http_body_util::BodyExt;
563 use rstest::rstest;
564 use tower::ServiceExt;
565 use uuid::Uuid;
567
/// Username of the administrator account the tests log in as.
const ADMIN_UNAME: &str = "admin";
/// Password the test fixtures assign to the administrator account.
const ADMIN_PASSWORD: &str = "password12345";
/// ELF test binary uploaded by the client integration test.
const SAMPLE_BYTES: &[u8] = include_bytes!("../../../types/testdata/elf/elf_haiku_x86");

/// Guards `init_tracing` so the global tracing subscriber is installed only once.
static TRACING: Once = Once::new();
573
574 fn init_tracing() {
575 tracing_subscriber::fmt()
576 .with_max_level(tracing::Level::TRACE)
577 .init();
578 }
579
580 async fn state(compress: bool, encrypt: bool) -> (Arc<State>, u32) {
581 TRACING.call_once(init_tracing);
582
583 let mut db_file = env::temp_dir();
585 db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
586 if std::path::Path::new(&db_file).exists() {
587 fs::remove_file(&db_file)
588 .context(format!("failed to delete old SQLite file {db_file:?}"))
589 .unwrap();
590 }
591
592 let db_type =
593 DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
594 .await
595 .context(format!("failed to create SQLite instance for {db_file:?}"))
596 .unwrap();
597 if compress {
598 db_type.enable_compression().await.unwrap();
599 }
600 let keys = if encrypt {
601 let key = FileEncryption::from(EncryptionOption::Xor);
602 let key_id = db_type.add_file_encryption_key(&key).await.unwrap();
603 let mut keys = HashMap::new();
604 keys.insert(key_id, key);
605 keys
606 } else {
607 HashMap::new()
608 };
609 let db_config = db_type.get_config().await.unwrap();
610
611 let state = State {
612 port: 8080,
613 directory: Some(
614 tempfile::TempDir::with_prefix("mdb-temp-samples")
615 .unwrap()
616 .path()
617 .into(),
618 ),
619 max_upload: 10 * 1024 * 1024,
620 ip: "127.0.0.1".parse().unwrap(),
621 db_type: Arc::new(db_type),
622 db_config,
623 keys,
624 started: SystemTime::now(),
625 #[cfg(feature = "vt")]
626 vt_client: None,
627 tls_config: None,
628 mdns: None,
629 };
630
631 state
632 .db_type
633 .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
634 .await
635 .context("Failed to set admin password")
636 .unwrap();
637
638 let source_id = state
639 .db_type
640 .create_source("temp-source", None, None, Local::now(), true, Some(false))
641 .await
642 .unwrap();
643
644 state
645 .db_type
646 .add_group_to_source(0, source_id)
647 .await
648 .unwrap();
649
650 (Arc::new(state), source_id)
651 }
652
653 async fn state_and_token(pem_file: bool, postgres: bool) -> (State, u32, String) {
655 TRACING.call_once(init_tracing);
656
657 let mut builder = if postgres {
658 const CONNECTION_STRING: &str =
659 "user=malwaredbtesting password=malwaredbtesting dbname=malwaredbtesting host=localhost sslmode=disable";
660
661 let pg_config = CONNECTION_STRING.parse::<Config>().unwrap();
662 let mgr_config = ManagerConfig {
663 recycling_method: RecyclingMethod::Fast,
664 };
665 let mgr = Manager::from_config(pg_config, NoTls, mgr_config);
666 let pool = Pool::builder(mgr)
667 .max_size(num_cpus::get().min(16))
668 .build()
669 .unwrap();
670
671 let client = pool.get().await.unwrap();
672 client
673 .batch_execute("DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
674 .await
675 .unwrap();
676
677 StateBuilder::new(concat!("postgres ", CONNECTION_STRING), None)
678 .await
679 .unwrap()
680 } else {
681 let mut db_file = env::temp_dir();
682 db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
683 if std::path::Path::new(&db_file).exists() {
684 fs::remove_file(&db_file)
685 .context(format!("failed to delete old SQLite file {db_file:?}"))
686 .unwrap();
687 }
688
689 StateBuilder::new(&format!("file:{}", db_file.display()), None)
690 .await
691 .unwrap()
692 };
693
694 builder = builder.directory(PathBuf::from(
695 tempfile::TempDir::with_prefix("mdb-temp-samples")
696 .unwrap()
697 .path(),
698 ));
699
700 if pem_file {
701 builder = builder.port(8443);
702 builder = builder
703 .tls(
704 "../../testdata/server_ca_cert.pem".into(),
705 "../../testdata/server_key.pem".into(),
706 )
707 .await
708 .unwrap();
709 } else {
710 builder = builder.port(8444);
711 builder = builder
712 .tls(
713 "../../testdata/server_cert.der".into(),
714 "../../testdata/server_key.der".into(),
715 )
716 .await
717 .unwrap();
718 }
719
720 let state = builder.into_state().await.unwrap();
721
722 state
723 .db_type
724 .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
725 .await
726 .context("Failed to set admin password")
727 .unwrap();
728
729 let source_id = state
730 .db_type
731 .create_source("temp-source", None, None, Local::now(), true, Some(false))
732 .await
733 .unwrap();
734
735 state
736 .db_type
737 .add_group_to_source(0, source_id)
738 .await
739 .unwrap();
740
741 let token = state
742 .db_type
743 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
744 .await
745 .unwrap();
746
747 (state, source_id, token)
748 }
749
750 async fn get_key(state: Arc<State>) -> String {
751 TRACING.call_once(init_tracing);
752
753 let key_request = serde_json::to_string(&GetAPIKeyRequest {
754 user: ADMIN_UNAME.into(),
755 password: ADMIN_PASSWORD.into(),
756 })
757 .context("Failed to convert API key request to JSON")
758 .unwrap();
759
760 let request = Request::builder()
761 .method("POST")
762 .uri(malwaredb_api::USER_LOGIN_URL)
763 .header(CONTENT_TYPE, "application/json")
764 .body(Body::from(key_request))
765 .unwrap();
766
767 let response = app(state)
768 .oneshot(request)
769 .await
770 .context("failed to send/receive login request")
771 .unwrap();
772
773 assert_eq!(response.status(), StatusCode::OK);
774 let bytes = response
775 .into_body()
776 .collect()
777 .await
778 .expect("failed to collect response body to bytes")
779 .to_bytes();
780 let json_response = String::from_utf8(bytes.to_ascii_lowercase())
781 .context("failed to convert response to string")
782 .unwrap();
783
784 let response: ServerResponse<GetAPIKeyResponse> = serde_json::from_str(&json_response)
785 .context("failed to convert json response to object")
786 .unwrap();
787
788 if let ServerResponse::Success(response) = response {
789 let key = response.key.clone();
790 assert_eq!(key.len(), 64);
791
792 key
793 } else {
794 panic!("failed to get API key response")
795 }
796 }
797
/// Logs in as the admin, then checks two things: the user-info endpoint describes the admin
/// account, and the label list of the fresh database is empty.
#[tokio::test]
async fn about_self() {
    let (state, _) = state(false, false).await;
    let api_key = get_key(state.clone()).await;

    // User info for the caller identified by the API key header.
    let request = Request::builder()
        .method("GET")
        .uri(malwaredb_api::USER_INFO_URL)
        .header(MDB_API_HEADER, &api_key)
        .body(Body::empty())
        .unwrap();

    let response = app(state.clone())
        .oneshot(request)
        .await
        .context("failed to send/receive login request")
        .unwrap();

    assert_eq!(response.status(), StatusCode::OK);
    let bytes = response
        .into_body()
        .collect()
        .await
        .expect("failed to collect response body to bytes")
        .to_bytes();
    // NOTE(review): the body is ASCII-lowercased before parsing, which could hide or
    // corrupt case-sensitive values in the response — confirm this is intended.
    let json_response = String::from_utf8(bytes.to_ascii_lowercase())
        .context("failed to convert response to string")
        .unwrap();

    let response: ServerResponse<GetUserInfoResponse> = serde_json::from_str(&json_response)
        .context("failed to convert json response to object")
        .unwrap();

    let response = response.unwrap();
    assert_eq!(response.id, 0);
    assert!(response.is_admin);
    assert!(!response.is_readonly);
    assert_eq!(response.username, "admin");

    // The freshly created database should have no labels yet.
    let request = Request::builder()
        .method("GET")
        .uri(malwaredb_api::LIST_LABELS_URL)
        .header(MDB_API_HEADER, &api_key)
        .body(Body::empty())
        .unwrap();

    let response = app(state)
        .oneshot(request)
        .await
        .context("failed to send/receive login request")
        .unwrap();

    assert_eq!(response.status(), StatusCode::OK);
    let bytes = response
        .into_body()
        .collect()
        .await
        .expect("failed to collect response body to bytes")
        .to_bytes();
    let json_response = String::from_utf8(bytes.to_ascii_lowercase())
        .context("failed to convert response to string")
        .unwrap();

    let response: ServerResponse<Labels> = serde_json::from_str(&json_response)
        .context("failed to convert json response to object")
        .unwrap();

    let response = response.unwrap();
    assert!(response.is_empty());
}
869
/// End-to-end upload test.
///
/// The test performs these steps:
/// 1. Upload a sample through the JSON endpoint.
/// 2. Check the sample was written under the sample directory; when compression is on, the
///    file on disk must be smaller than the original.
/// 3. Fetch the sample's report.
/// 4. Download the sample as CaRT and compare the CaRT header's SHA-384 with the report.
///
/// Cases with `should_fail` expect the upload to be answered with `500`.
#[rstest]
#[case::elf_encrypt_cart(include_bytes!("../../../types/testdata/elf/elf_haiku_x86.cart"), false, true, true, false)]
#[case::pe32(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), false, false, false, false)]
#[case::pdf_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), false, true, false, false)]
#[case::rtf(include_bytes!("../../../types/testdata/rtf/hello.rtf"), false, false, false, false)]
#[case::elf_compress_encrypt(include_bytes!("../../../types/testdata/elf/elf_haiku_x86"), true, true, false, false)]
#[case::pe32_compress(include_bytes!("../../../types/testdata/exe/pe64_win32_gui_x86_64_gnu.exe"), true, false, false, false)]
#[case::pdf_compress_encrypt(include_bytes!("../../../types/testdata/pdf/test.pdf"), true, true, false, false)]
#[case::rtf_compress(include_bytes!("../../../types/testdata/rtf/hello.rtf"), true, false, false, false)]
#[case::icon_unknown_type_proxy(include_bytes!("../../../../MDB_Logo.ico"), false, false, false, true)]
#[tokio::test]
async fn submit_sample(
    #[case] contents: &[u8],
    #[case] compress: bool,
    #[case] encrypt: bool,
    #[case] cart: bool,
    #[case] should_fail: bool,
) {
    let (state, source_id) = state(compress, encrypt).await;
    let api_key = get_key(state.clone()).await;

    // The JSON upload carries base64 contents plus the hex SHA-256 the server verifies.
    let file_contents_b64 = general_purpose::STANDARD.encode(contents);
    let mut hasher = Sha256::new();
    hasher.update(contents);
    let sha256 = hex::encode(hasher.finalize());

    let upload = serde_json::to_string(&NewSampleB64 {
        file_name: "some_sample".into(),
        source_id,
        file_contents_b64,
        sha256: sha256.clone(),
    })
    .context("failed to create upload structure")
    .unwrap();

    let request = Request::builder()
        .method("POST")
        .uri(malwaredb_api::UPLOAD_SAMPLE_JSON_URL)
        .header(CONTENT_TYPE, "application/json")
        .header(MDB_API_HEADER, &api_key)
        .body(Body::from(upload))
        .unwrap();

    let response = app(state.clone())
        .oneshot(request)
        .await
        .context("failed to send/receive upload request/response")
        .unwrap();

    if should_fail {
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        return;
    }

    assert_eq!(response.status(), StatusCode::OK);

    // For a CaRT upload the SHA-256 used for the remaining checks comes from the CaRT
    // footer; presumably the server stores the unpacked payload — confirm in `receive`.
    let sha256 = if cart {
        let mut input_buffer = Cursor::new(contents);
        let mut output_buffer = Cursor::new(vec![]);
        let (_, footer) =
            cart_container::unpack_stream(&mut input_buffer, &mut output_buffer, None)
                .expect("failed to decode CaRT file");
        let footer = footer.expect("CaRT should have had a footer");
        // `to_string` on the footer value yields JSON text with surrounding quotes, which
        // are stripped below.
        let sha256 = footer
            .get("sha256")
            .expect("CaRT footer should have had an entry for SHA-256")
            .to_string();
        sha256.replace('"', "")
    } else {
        sha256
    };

    // Stored samples are expected at `<dir>/<h[0..2]>/<h[2..4]>/<h[4..6]>/<sha256>`.
    if let Some(dir) = &state.directory {
        let mut sample_path = dir.clone();
        sample_path.push(format!(
            "{}/{}/{}/{}",
            &sha256[0..2],
            &sha256[2..4],
            &sha256[4..6],
            sha256
        ));
        eprintln!("Submitted sample should exist at {sample_path:?}.");
        assert!(sample_path.exists());

        if compress {
            let sample_size_on_disk = sample_path.metadata().unwrap().len();
            eprintln!(
                "Original size: {}, compressed: {}",
                contents.len(),
                sample_size_on_disk
            );
            assert!(sample_size_on_disk < contents.len() as u64);
        }
    } else {
        panic!("Directory was set for the state, but is now `None`");
    }

    // Fetch the report for the stored sample.
    let request = Request::builder()
        .method("GET")
        .uri(format!("{}/{sha256}", malwaredb_api::SAMPLE_REPORT_URL))
        .header(MDB_API_HEADER, &api_key)
        .body(Body::empty())
        .unwrap();

    let response = app(state.clone())
        .oneshot(request)
        .await
        .context("failed to send/receive upload request/response")
        .unwrap();

    println!("Response headers: {:?}", response.headers());

    let bytes = response
        .into_body()
        .collect()
        .await
        .expect("failed to collect response body to bytes")
        .to_bytes();
    // NOTE(review): lowercasing the body before parsing alters any case-sensitive report
    // fields — confirm this is intended.
    let json_response = String::from_utf8(bytes.to_ascii_lowercase())
        .context("failed to convert response to string")
        .unwrap();

    let report: ServerResponse<Report> = serde_json::from_str(&json_response)
        .context("failed to convert json response to object")
        .unwrap();
    let report = report.unwrap();

    assert_eq!(report.sha256, sha256);
    println!("Report: {report}");

    // Download as CaRT; its header must carry the same SHA-384 as the report.
    let request = Request::builder()
        .method("GET")
        .uri(format!(
            "{}/{sha256}",
            malwaredb_api::DOWNLOAD_SAMPLE_CART_URL
        ))
        .header(MDB_API_HEADER, api_key.clone())
        .body(Body::empty())
        .unwrap();

    let response = app(state.clone())
        .oneshot(request)
        .await
        .context("failed to send/receive upload request/response for CaRT")
        .unwrap();

    println!("CaRT Response headers: {:?}", response.headers());

    let bytes = response
        .into_body()
        .collect()
        .await
        .expect("failed to collect response body to bytes")
        .to_bytes();

    let bytes = bytes.to_vec();
    let bytes_input = Cursor::new(bytes);
    let output = Cursor::new(vec![]);
    match cart_container::unpack_stream(bytes_input, output, None) {
        Ok((header, _)) => {
            let header = header.unwrap();
            assert_eq!(
                header.get("sha384"),
                Some(&serde_json::to_value(report.sha384).unwrap())
            );
        }
        Err(e) => panic!("{e}"),
    }
}
1041
1042 #[rstest]
1047 #[case::ssl_pem_sqlite(true, false)]
1048 #[case::ssl_der_sqlite(false, false)]
1049 #[ignore = "don't run this in CI"]
1050 #[case::ssl_der_postgres(false, true)]
1051 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1052 async fn client_integration(#[case] pem: bool, #[case] postgres: bool) {
1053 TRACING.call_once(init_tracing);
1054
1055 let (state, source_id, token) = state_and_token(pem, postgres).await;
1058 let state_port = state.port; let server = tokio::spawn(async move {
1061 state
1062 .serve()
1063 .await
1064 .expect("MalwareDB failed to .serve() in tokio::spawn()");
1065 });
1066 assert!(!server.is_finished());
1067
1068 tokio::time::sleep(std::time::Duration::new(1, 0)).await;
1070
1071 println!("SemVer of MalwareDB: {:?}", *crate::MDB_VERSION_SEMVER);
1072 assert_eq!(
1073 *malwaredb_client::MDB_VERSION_SEMVER,
1074 *crate::MDB_VERSION_SEMVER,
1075 "SemVer parsing of MDB version failed"
1076 );
1077
1078 let mdb_client = malwaredb_client::MdbClient::new(
1079 format!("https://127.0.0.1:{state_port}"),
1080 token.clone(),
1081 Some("../../testdata/ca_cert.pem".into()),
1082 )
1083 .unwrap();
1084
1085 assert!(!mdb_client.supported_types().await.unwrap().types.is_empty());
1086 assert!(mdb_client.server_info().await.is_ok());
1087
1088 let start = Instant::now();
1089 assert!(mdb_client
1090 .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
1091 .await
1092 .context("failed to upload test file")
1093 .unwrap());
1094 let duration = start.elapsed();
1095 println!("Initial upload and database record creation via base64 took {duration:?}");
1096
1097 let start = Instant::now();
1098 mdb_client
1099 .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
1100 .await
1101 .context("failed to upload test file")
1102 .unwrap();
1103 let duration = start.elapsed();
1104 println!("Upload again via base64 took {duration:?}");
1105
1106 let start = Instant::now();
1107 mdb_client
1108 .submit_as_cbor(SAMPLE_BYTES, String::from("elf_haiku_x86"), source_id)
1109 .await
1110 .context("failed to upload test file")
1111 .unwrap();
1112 let duration = start.elapsed();
1113 println!("Upload again via cbor took {duration:?}");
1114
1115 let report = mdb_client
1116 .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
1117 .await
1118 .expect("failed to get report for file just submitted");
1119 assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");
1120
1121 let similar = mdb_client
1122 .similar(SAMPLE_BYTES)
1123 .await
1124 .expect("failed to query for files similar to what was just submitted");
1125 assert_eq!(similar.results.len(), 1);
1126
1127 let search = mdb_client
1128 .partial_search(
1129 Some((PartialHashSearchType::MD5, String::from(&report.md5[0..10]))),
1130 None,
1131 PartialHashSearchType::Any,
1132 10,
1133 )
1134 .await
1135 .unwrap();
1136 assert_eq!(search.hashes.len(), 1);
1137
1138 let search = mdb_client
1139 .partial_search(
1140 Some((PartialHashSearchType::Any, "AAAA".into())),
1141 None,
1142 PartialHashSearchType::Any,
1143 10,
1144 )
1145 .await
1146 .unwrap();
1147 assert!(search.hashes.is_empty());
1148
1149 #[cfg(feature = "yara")]
1150 {
1151 let elf_yara = "rule elf_file {
1152 condition:
1153 uint32(0) == 0x464c457f
1154 }";
1155
1156 let mut counter = 0;
1157 let mut successful = false;
1158
1159 let yara_result = mdb_client.yara_search(elf_yara).await.unwrap();
1160 while counter < 5 {
1161 if let Ok(yara_response) = mdb_client.yara_result(yara_result.uuid).await {
1162 if !yara_response.results.is_empty() {
1163 println!("Yara search upload response: {:?}", yara_response.results);
1164 assert_eq!(yara_response.results.get("elf_file").unwrap().len(), 1);
1165 assert_eq!(yara_response.results.len(), 1);
1166 successful = true;
1167 break;
1168 }
1169 }
1170 tokio::time::sleep(std::time::Duration::new(2, 0)).await;
1171 counter += 1;
1172 }
1173
1174 assert!(successful);
1175 }
1176
1177 mdb_client.reset_key().await.expect("failed to reset key");
1178 server.abort();
1179 }
1180
1181 #[allow(clippy::too_many_lines)]
1182 #[test]
1183 #[ignore = "don't run this in CI"]
1184 fn client_integration_blocking() {
1185 TRACING.call_once(init_tracing);
1186 let token = Arc::new(RwLock::new(String::new()));
1187 let token_clone = token.clone();
1188
1189 let thread = std::thread::spawn(move || {
1195 let rt = tokio::runtime::Builder::new_multi_thread()
1196 .enable_all()
1197 .build()
1198 .unwrap();
1199
1200 let mut db_file = env::temp_dir();
1201 db_file.push(format!("testing_sqlite_{}.db", Uuid::new_v4()));
1202 if std::path::Path::new(&db_file).exists() {
1203 fs::remove_file(&db_file)
1204 .context(format!("failed to delete old SQLite file {db_file:?}"))
1205 .unwrap();
1206 }
1207
1208 let (db_type, db_config) = rt.block_on(async {
1209 let db_type =
1210 DatabaseType::from_string(&format!("file:{}", db_file.to_str().unwrap()), None)
1211 .await
1212 .context(format!("failed to create SQLite instance for {db_file:?}"))
1213 .unwrap();
1214
1215 let db_config = db_type.get_config().await.unwrap();
1216 (db_type, db_config)
1217 });
1218
1219 let state = State {
1220 port: 9090,
1221 directory: Some(
1222 tempfile::TempDir::with_prefix("mdb-temp-samples")
1223 .unwrap()
1224 .path()
1225 .into(),
1226 ),
1227 max_upload: 10 * 1024 * 1024,
1228 ip: "127.0.0.1".parse().unwrap(),
1229 db_type: Arc::new(db_type),
1230 db_config,
1231 keys: HashMap::default(),
1232 started: SystemTime::now(),
1233 #[cfg(feature = "vt")]
1234 vt_client: None,
1235 tls_config: None,
1236 mdns: None,
1237 };
1238
1239 rt.block_on(async {
1240 state
1241 .db_type
1242 .set_password(ADMIN_UNAME, ADMIN_PASSWORD)
1243 .await
1244 .context("Failed to set admin password")
1245 .unwrap();
1246
1247 let source_id = state
1248 .db_type
1249 .create_source("temp-source", None, None, Local::now(), true, Some(false))
1250 .await
1251 .unwrap();
1252
1253 state
1254 .db_type
1255 .add_group_to_source(0, source_id)
1256 .await
1257 .unwrap();
1258
1259 let token_string = state
1260 .db_type
1261 .authenticate(ADMIN_UNAME, ADMIN_PASSWORD)
1262 .await
1263 .unwrap();
1264
1265 if let Ok(mut token_lock) = token_clone.write() {
1266 *token_lock = token_string;
1267 }
1268
1269 state
1270 .serve()
1271 .await
1272 .expect("MalwareDB failed to .serve() in tokio::spawn()");
1273 });
1274 });
1275 std::thread::sleep(std::time::Duration::from_secs(1));
1276
1277 let mdb_client = malwaredb_client::blocking::MdbClient::new(
1278 String::from("http://127.0.0.1:9090"),
1279 token.read().unwrap().clone(),
1280 None,
1281 )
1282 .unwrap();
1283
1284 let types = match mdb_client.supported_types() {
1285 Ok(types) => types,
1286 Err(e) => panic!("{e}"),
1287 };
1288 assert!(!types.types.is_empty());
1289
1290 assert!(mdb_client
1291 .submit(SAMPLE_BYTES, String::from("elf_haiku_x86"), 1)
1292 .context("failed to upload test file")
1293 .unwrap());
1294
1295 let report = mdb_client
1296 .report("de10ba5e5402b46ea975b5cb8a45eb7df9e81dc81012fd4efd145ed2dce3a740")
1297 .expect("failed to get report for file just submitted");
1298 assert_eq!(report.md5, "82123011556b0e68801bee7bd71bb345");
1299
1300 let similar = mdb_client
1301 .similar(SAMPLE_BYTES)
1302 .expect("failed to query for files similar to what was just submitted");
1303 assert_eq!(similar.results.len(), 1);
1304
1305 let search = mdb_client
1306 .partial_search(
1307 Some((PartialHashSearchType::Any, "AAAA".into())),
1308 None,
1309 PartialHashSearchType::Any,
1310 10,
1311 )
1312 .unwrap();
1313 assert!(search.hashes.is_empty());
1314
1315 mdb_client.reset_key().expect("failed to reset key");
1316 drop(thread); }
1318}