1use crate::MAX_RECURSION_DEPTH;
4use crate::ftype::FileType;
5
6use std::error::Error;
7use std::fmt::{Display, Formatter};
8use std::path::{Path, PathBuf};
9
10use anyhow::Result;
11use clap::ValueEnum;
12use serde::de::IntoDeserializer;
13use serde::{Deserialize, Serialize};
14use sha2::{Digest, Sha256};
15use walkdir::WalkDir;
16
17#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
20pub struct FileTypeUnion {
21 ftype: FileType,
23
24 non_model_type: NonModelTypes,
26}
27
28impl FileTypeUnion {
29 #[must_use]
31 pub fn from_bytes(bytes: &[u8]) -> Self {
32 if let Some(ftype) = FileType::from_bytes(bytes) {
33 Self {
34 ftype,
35 non_model_type: NonModelTypes::Unknown,
36 }
37 } else {
38 Self {
39 ftype: FileType::NotSet,
40 non_model_type: NonModelTypes::from_bytes(bytes),
41 }
42 }
43 }
44
45 #[must_use]
47 pub fn matches(&self, bytes: &[u8]) -> bool {
48 if self.ftype == FileType::NotSet {
49 NonModelTypes::from_bytes(bytes) == self.non_model_type
50 } else {
51 self.ftype.matches(bytes)
52 }
53 }
54
55 #[must_use]
57 pub fn is_unknown(&self) -> bool {
58 self.ftype == FileType::DsStore
59 || (self.ftype == FileType::NotSet && self.non_model_type == NonModelTypes::Unknown)
60 }
61}
62
63pub fn parse_file_type_union(
70 s: &str,
71) -> Result<FileTypeUnion, Box<dyn Error + Send + Sync + 'static>> {
72 if s.is_empty() {
73 return Err("File type cannot be empty.".into());
74 }
75
76 if let Ok(model_type) = crate::dataset::Dataset::file_type_from_line(s) {
77 return Ok(FileTypeUnion {
78 ftype: model_type,
79 non_model_type: NonModelTypes::Unknown,
80 });
81 }
82
83 let Ok(non_model_type) = NonModelTypes::name_from_line(s.to_lowercase().as_str()) else {
84 let mut allowed_variants = Vec::with_capacity(30);
85 for variant in FileType::value_variants() {
86 allowed_variants.push(variant.to_string().to_lowercase());
87 }
88
89 for variant in NonModelTypes::value_variants() {
90 allowed_variants.push(variant.to_string().to_lowercase());
91 }
92
93 return Err(format!(
94 "{s} is not a valid file type.\nAllowed types: {}.",
95 allowed_variants.join(", ")
96 )
97 .into());
98 };
99
100 Ok(FileTypeUnion {
101 ftype: FileType::NotSet,
102 non_model_type,
103 })
104}
105
106impl Display for FileTypeUnion {
107 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108 if self.ftype == FileType::NotSet {
109 write!(f, "{}", self.non_model_type)
110 } else {
111 write!(f, "{}", self.ftype)
112 }
113 }
114}
115
116#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
118pub enum NonModelTypes {
119 BATCH,
121
122 CAB,
124
125 ChromeExt,
127
128 Class,
130
131 COFF,
133
134 COM,
136
137 DEX,
139
140 FLAC,
142
143 FLV,
145
146 GIF,
148
149 GZip,
151
152 HTML,
154
155 JPEG,
157
158 JPEG2K,
160
161 PEF,
163
164 PemCrt,
166
167 PemCsr,
169
170 PemKey,
172
173 Perl,
175
176 PHP,
178
179 PNG,
181
182 PS,
184
185 Python,
187
188 RAR,
190
191 Shell,
193
194 TIFF,
196
197 SWF,
199
200 Wasm,
202
203 WindowsRegistry,
205
206 WindowsShortcut,
208
209 XML,
211
212 Zip,
214
215 #[serde(skip)]
217 #[clap(skip)]
218 UnknownText,
219
220 #[doc(hidden)]
222 #[serde(skip)]
223 #[clap(skip)]
224 Unknown,
225}
226
227impl NonModelTypes {
228 #[must_use]
230 #[allow(clippy::too_many_lines)]
231 pub(crate) fn from_bytes(bytes: &[u8]) -> Self {
232 if bytes.starts_with(b"PK") {
233 return Self::Zip;
234 }
235
236 if bytes.starts_with(b"\x4D\x53\x43\x46") || bytes.starts_with(b"\x4D\x53\x63\x28") {
237 return Self::CAB;
238 }
239
240 if bytes.starts_with(&[0x4C, 0x01])
241 || bytes.starts_with(&[0x64, 0x86])
242 || bytes.starts_with(&[0x00, 0x02])
243 {
244 return Self::COFF;
245 }
246
247 if bytes.starts_with(&[0xC9])
248 || bytes.starts_with(&[0xE9])
249 || bytes.starts_with(&[0xE8])
250 || bytes.starts_with(&[0xEB])
251 {
252 return Self::COM;
253 }
254
255 if bytes.starts_with(&[0x43, 0x72, 0x32, 0x34]) {
256 return Self::ChromeExt;
257 }
258
259 if bytes.starts_with(b"\x64\x65\x78\x0A\x30\x33\x35\x00") {
260 return Self::DEX;
261 }
262
263 if bytes.starts_with(b"\x89PNG") {
264 return Self::PNG;
265 }
266
267 if bytes.starts_with(b"-----BEGIN CERTIFICATE-----") {
268 return Self::PemCrt;
269 }
270
271 if bytes.starts_with(b"-----BEGIN CERTIFICATE REQUEST-----") {
272 return Self::PemCsr;
273 }
274
275 if bytes.starts_with(b"-----BEGIN PRIVATE KEY-----")
276 || bytes.starts_with(b"-----BEGIN DSA PRIVATE KEY-----")
277 || bytes.starts_with(b"-----BEGIN RSA PRIVATE KEY-----")
278 {
279 return Self::PemKey;
280 }
281
282 if bytes.starts_with(b"\x46\x4C\x56") {
283 return Self::FLV;
284 }
285
286 if bytes.starts_with(b"GIF87") || bytes.starts_with(b"GIF89") {
287 return Self::GIF;
288 }
289
290 if bytes.starts_with(b"\xFF\xD8\xFF") {
291 return Self::JPEG;
292 }
293
294 if bytes.starts_with(&[0xFF, 0x4F, 0xFF, 0xF1])
295 || bytes.starts_with(&[
296 0x00, 0x00, 0x00, 0x0C, 0x6A, 0x50, 0x20, 0x20, 0x0D, 0x0A, 0x87, 0x0A,
297 ])
298 {
299 return Self::JPEG2K;
300 }
301
302 if bytes.starts_with(&[0x4A, 0x6F, 0x79, 0x21]) {
303 return Self::PEF;
304 }
305
306 if bytes.starts_with(b"%!") {
307 return Self::PS;
308 }
309
310 if bytes.starts_with(&[0x52, 0x61, 0x72, 0x21, 0x1A, 0x07]) {
311 return Self::RAR;
312 }
313
314 if bytes.starts_with(&[0x1F, 0x8B]) {
315 return Self::GZip;
316 }
317
318 if bytes.starts_with(b"CWS") || bytes.starts_with(b"FWS") || bytes.starts_with(b"ZWS") {
319 return Self::SWF;
320 }
321
322 if bytes.starts_with(&[0x49, 0x20, 0x49])
323 || bytes.starts_with(&[0x49, 0x49, 0x2A])
324 || bytes.starts_with(&[0x4D, 0x4D, 0x00])
325 {
326 return Self::TIFF;
327 }
328
329 if bytes.starts_with(b"\x66\x4C\x61\x43") {
330 return Self::FLAC;
331 }
332
333 if bytes.starts_with(b"\xCA\xFE\xBA\xBE") {
335 let version = u32::from_be_bytes([
336 bytes[0x04],
337 bytes[0x04 + 1],
338 bytes[0x04 + 2],
339 bytes[0x04 + 3],
340 ]);
341 if version >= 0x20 {
342 return Self::Class;
343 }
344 }
345
346 if bytes.starts_with(&[0x00, 0x61, 0x73, 0x6D]) {
347 return Self::Wasm;
348 }
349
350 if bytes.starts_with(&[0x72, 0x65, 0x67, 0x66]) || bytes.starts_with(b"REGEDIT") {
351 return Self::WindowsRegistry;
352 }
353
354 if bytes.starts_with(&[0x4C, 0x00, 0x00, 0x00, 0x01, 0x14, 0x02, 0x00]) {
355 return Self::WindowsShortcut;
356 }
357
358 if bytes.is_ascii() {
359 if bytes.starts_with(b"<?xml") {
360 return Self::XML;
361 }
362
363 if bytes.starts_with(b"#!") {
365 let string_size = bytes.len().min(25);
366 if let Ok(string) =
367 String::from_utf8(bytes[0..string_size].to_ascii_lowercase().clone())
368 {
369 if string.contains("python") {
370 return Self::Python;
371 }
372
373 if string.contains("perl") {
374 return Self::Perl;
375 }
376 }
377
378 return Self::Shell;
379 }
380
381 let string_size = bytes.len().min(1000);
382 if let Ok(string) =
383 String::from_utf8(bytes[0..string_size].to_ascii_lowercase().clone())
384 {
385 if string.contains("<?xml") {
386 return Self::XML;
387 }
388
389 if string.contains("<?php") {
390 return Self::PHP;
391 }
392
393 if string.contains("<html") || string.contains("<!doctype html>") {
394 return Self::HTML;
395 }
396
397 if string.contains("@echo off") {
398 return Self::BATCH;
399 }
400 }
401
402 return Self::UnknownText;
403 }
404
405 Self::Unknown
406 }
407
408 pub(crate) fn name_from_line(line: &str) -> Result<Self, serde::de::value::Error> {
410 let line = line.split(':').nth(1).unwrap_or(line).to_uppercase();
411 let ftype: Result<_, serde::de::value::Error> =
412 Self::deserialize(String::from(line.trim()).into_deserializer());
413 ftype
414 }
415}
416
417impl From<NonModelTypes> for &'static str {
418 fn from(ftype: NonModelTypes) -> Self {
419 match ftype {
420 NonModelTypes::BATCH => "BAT",
421 NonModelTypes::CAB => "CAB",
422 NonModelTypes::COFF => "COFF",
423 NonModelTypes::COM => "COM",
424 NonModelTypes::ChromeExt => "ChromeExt",
425 NonModelTypes::Class => "Class",
426 NonModelTypes::DEX => "DEX",
427 NonModelTypes::FLAC => "FLAC",
428 NonModelTypes::FLV => "FLV",
429 NonModelTypes::GIF => "GIF",
430 NonModelTypes::GZip => "GZip",
431 NonModelTypes::HTML => "HTML",
432 NonModelTypes::JPEG => "JPEG",
433 NonModelTypes::JPEG2K => "JPEG2K",
434 NonModelTypes::PEF => "PEF",
435 NonModelTypes::PemCrt => "PemCrt",
436 NonModelTypes::PemCsr => "PemCsr",
437 NonModelTypes::PemKey => "PemKey",
438 NonModelTypes::Perl => "Perl",
439 NonModelTypes::PHP => "PHP",
440 NonModelTypes::PNG => "PNG",
441 NonModelTypes::PS => "PS",
442 NonModelTypes::Python => "Python",
443 NonModelTypes::RAR => "RAR",
444 NonModelTypes::Shell => "Shell",
445 NonModelTypes::SWF => "SWF",
446 NonModelTypes::TIFF => "TIFF",
447 NonModelTypes::Wasm => "Wasm",
448 NonModelTypes::WindowsRegistry => "WindowsRegistry",
449 NonModelTypes::WindowsShortcut => "WindowsShortcut",
450 NonModelTypes::XML => "XML",
451 NonModelTypes::Zip => "Zip",
452 NonModelTypes::Unknown => "Unknown",
453 NonModelTypes::UnknownText => "UnknownText",
454 }
455 }
456}
457
458impl Display for NonModelTypes {
459 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
460 let s: &'static str = (*self).into();
461 write!(f, "{s}")
462 }
463}
464
/// Counters describing the outcome of a file sorting run.
///
/// In this file the struct is produced by `file_sorting`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileSortingResults {
    /// Number of regular files encountered.
    pub total_files: usize,

    /// Number of files for which no link was created because the destination
    /// entry already existed (`file_sorting` stores its duplicate count here).
    pub files_removed: usize,

    /// Number of files that could not be processed.
    pub errors: usize,
}
476
477pub fn file_sorting<P: AsRef<Path>>(
490 origin: P,
491 destination: P,
492 depth: u8,
493) -> Result<FileSortingResults> {
494 #[cfg(not(any(unix, windows)))]
495 anyhow::bail!(
496 "File Sorting is not supported on this platform {} due to missing symbolic link support.\nPlease file an issue at https://github.com/rjzak/malware-modeler-rs/issues/new/choose.",
497 std::env::consts::OS
498 );
499
500 let mut total_files = 0;
501 let mut duplicate_files = 0;
502 let mut errors = 0;
503
504 for entry in WalkDir::new(origin)
505 .max_depth(MAX_RECURSION_DEPTH)
506 .follow_links(true)
507 .into_iter()
508 .flatten()
509 {
510 if entry.path().is_file() {
511 total_files += 1;
512
513 let Ok(contents) = std::fs::read(entry.path()) else {
514 errors += 1;
515 continue;
516 };
517
518 let file_type = FileTypeUnion::from_bytes(&contents);
519 let hash = hex::encode(Sha256::digest(contents));
520
521 let mut destination_file = destination.as_ref().join(file_type.to_string());
522 destination_file.push(hash_depth(&hash, depth));
523
524 std::fs::create_dir_all(&destination_file)?;
525 destination_file.push(hash);
526 if destination_file.exists() {
527 duplicate_files += 1;
528 } else {
529 #[cfg(unix)]
530 std::os::unix::fs::symlink(entry.path(), destination_file)?;
531
532 #[cfg(windows)]
533 std::os::windows::fs::symlink_file(entry.path(), destination_file)?;
534 }
535 }
536 }
537
538 Ok(FileSortingResults {
539 total_files,
540 errors,
541 files_removed: duplicate_files,
542 })
543}
544
/// Builds a nested relative directory path from the leading characters of `hash`.
///
/// Each level is the next two characters of `hash`, so `hash_depth("9d6dc1", 2)`
/// is `9d/6d`. The number of levels is `depth`, limited to the number of
/// complete two-character pairs in `hash`; a `depth` of 0 or an empty `hash`
/// gives an empty path. The limit is computed in `usize`, so hashes of 512 or
/// more characters are handled correctly.
///
/// # Panics
///
/// Panics if a two-byte step does not fall on a UTF-8 character boundary,
/// which cannot happen for ASCII input such as hexadecimal digests.
#[inline]
#[must_use]
pub fn hash_depth(hash: &str, depth: u8) -> PathBuf {
    // Clamp in `usize`: narrowing `hash.len() / 2` to `u8` would wrap for long
    // inputs (256 pairs would become 0).
    let levels = usize::from(depth).min(hash.len() / 2);

    let mut path = PathBuf::new();
    for level in 0..levels {
        let start = level * 2;
        path.push(&hash[start..start + 2]);
    }

    path
}
558
/// Checks `hash_depth` for several depths on a 40-character and a
/// 128-character hash, including depths larger than the hash allows.
#[test]
fn hash_depth_test() {
    const HASH: &str = "9d6dc11990a109cd82d4dbafb6588b1b18e0e46b";
    const HASH_512: &str = "fedc3e4d500fd9f3a52c05549a53f0f82ae684167033699e87ebe018517ceeb265136de09aa7e1fce5bbce0b8a4ead89170a99a5bdb2b5f7d1f02a81e3178af2";

    // A depth of zero must not disturb the surrounding path.
    let dummy = PathBuf::from("MyDir")
        .join(hash_depth(HASH, 0))
        .join("my_file.txt");
    assert_eq!(dummy, PathBuf::from("MyDir/my_file.txt"));

    // (hash, depth, expected path)
    let cases: [(&str, u8, &str); 10] = [
        (HASH, 0, ""),
        (HASH, 1, "9d/"),
        (HASH, 2, "9d/6d/"),
        (HASH, 3, "9d/6d/c1/"),
        (HASH, 4, "9d/6d/c1/19/"),
        // More levels requested than the hash has pairs: every pair is used.
        (
            HASH,
            255,
            "9d/6d/c1/19/90/a1/09/cd/82/d4/db/af/b6/58/8b/1b/18/e0/e4/6b/",
        ),
        (HASH_512, 0, ""),
        (HASH_512, 1, "fe/"),
        (HASH_512, 2, "fe/dc/"),
        (
            HASH_512,
            255,
            "fe/dc/3e/4d/50/0f/d9/f3/a5/2c/05/54/9a/53/f0/f8/2a/e6/84/16/70/33/69/9e/87/eb/e0/18/51/7c/ee/b2/65/13/6d/e0/9a/a7/e1/fc/e5/bb/ce/0b/8a/4e/ad/89/17/0a/99/a5/bd/b2/b5/f7/d1/f0/2a/81/e3/17/8a/f2",
        ),
    ];

    for &(hash, depth, expected) in &cases {
        assert_eq!(hash_depth(hash, depth), PathBuf::from(expected));
    }
}