Skip to main content

malware_modeler/
sorting.rs

// SPDX-License-Identifier: Apache-2.0
2
3use crate::MAX_RECURSION_DEPTH;
4use crate::ftype::FileType;
5
6use std::error::Error;
7use std::fmt::{Display, Formatter};
8use std::path::{Path, PathBuf};
9
10use anyhow::Result;
11use clap::ValueEnum;
12use serde::de::IntoDeserializer;
13use serde::{Deserialize, Serialize};
14use sha2::{Digest, Sha256};
15use walkdir::WalkDir;
16
/// Found file type, either a model type or a non-model type.
///
/// Conceptually this is a union of the two; a Rust `union` isn't used because reading one is
/// `unsafe`. The constructors in this module keep exactly one side meaningful: when `ftype` is
/// [`FileType::NotSet`] the file is described by `non_model_type`, otherwise `ftype` holds the
/// model type and `non_model_type` is left as [`NonModelTypes::Unknown`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct FileTypeUnion {
    /// Model type, or [`FileType::NotSet`] when the file is described by `non_model_type`
    ftype: FileType,

    /// Other file type; only meaningful when `ftype` is [`FileType::NotSet`]
    non_model_type: NonModelTypes,
}
27
28impl FileTypeUnion {
29    /// Get a file type from bytes
30    #[must_use]
31    pub fn from_bytes(bytes: &[u8]) -> Self {
32        if let Some(ftype) = FileType::from_bytes(bytes) {
33            Self {
34                ftype,
35                non_model_type: NonModelTypes::Unknown,
36            }
37        } else {
38            Self {
39                ftype: FileType::NotSet,
40                non_model_type: NonModelTypes::from_bytes(bytes),
41            }
42        }
43    }
44
45    /// Indicates if the specified file type matches some bytes
46    #[must_use]
47    pub fn matches(&self, bytes: &[u8]) -> bool {
48        if self.ftype == FileType::NotSet {
49            NonModelTypes::from_bytes(bytes) == self.non_model_type
50        } else {
51            self.ftype.matches(bytes)
52        }
53    }
54
55    /// Indicates if the file type isn't known
56    #[must_use]
57    pub fn is_unknown(&self) -> bool {
58        self.ftype == FileType::DsStore
59            || (self.ftype == FileType::NotSet && self.non_model_type == NonModelTypes::Unknown)
60    }
61}
62
63/// Clap parser function for [`FileTypeUnion`] types.
64/// Uses Serde to get the file type from a string, uses Clap to print allowed values.
65///
66/// # Errors
67///
68/// Returns an error if the string is empty
69pub fn parse_file_type_union(
70    s: &str,
71) -> Result<FileTypeUnion, Box<dyn Error + Send + Sync + 'static>> {
72    if s.is_empty() {
73        return Err("File type cannot be empty.".into());
74    }
75
76    if let Ok(model_type) = crate::dataset::Dataset::file_type_from_line(s) {
77        return Ok(FileTypeUnion {
78            ftype: model_type,
79            non_model_type: NonModelTypes::Unknown,
80        });
81    }
82
83    let Ok(non_model_type) = NonModelTypes::name_from_line(s.to_lowercase().as_str()) else {
84        let mut allowed_variants = Vec::with_capacity(30);
85        for variant in FileType::value_variants() {
86            allowed_variants.push(variant.to_string().to_lowercase());
87        }
88
89        for variant in NonModelTypes::value_variants() {
90            allowed_variants.push(variant.to_string().to_lowercase());
91        }
92
93        return Err(format!(
94            "{s} is not a valid file type.\nAllowed types: {}.",
95            allowed_variants.join(", ")
96        )
97        .into());
98    };
99
100    Ok(FileTypeUnion {
101        ftype: FileType::NotSet,
102        non_model_type,
103    })
104}
105
106impl Display for FileTypeUnion {
107    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108        if self.ftype == FileType::NotSet {
109            write!(f, "{}", self.non_model_type)
110        } else {
111            write!(f, "{}", self.ftype)
112        }
113    }
114}
115
/// File types we're likely to encounter, but we don't think are good candidates for a model.
///
/// The variant identifiers are the names Serde (de)serializes. The [`Display`] names come from
/// the `From<NonModelTypes> for &'static str` conversion and match the identifiers, except for
/// [`Self::BATCH`], which is displayed as `BAT`.
#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NonModelTypes {
    /// MS-DOS Batch file
    BATCH,

    /// Microsoft Cabinet archive
    CAB,

    /// Google Chrome extension
    ChromeExt,

    /// Java class
    Class,

    /// Common Object File Format which could be an object file or static library on Windows
    COFF,

    /// MS-DOS COM executable
    COM,

    /// Android Dalvik Executable (DEX) file
    DEX,

    /// Free Lossless Audio Codec (FLAC) audio file
    FLAC,

    /// Flash Video
    FLV,

    /// Graphics Interchange Format (GIF) image
    GIF,

    /// Gzip-compressed file
    GZip,

    /// HyperText Markup Language (HTML) document
    HTML,

    /// JPEG image
    JPEG,

    /// JPEG 2000
    JPEG2K,

    /// Preferred Executable Format file
    PEF,

    /// PEM-encoded certificate
    PemCrt,

    /// PEM-encoded certificate signing request
    PemCsr,

    /// PEM-encoded private key
    PemKey,

    /// Perl script
    Perl,

    /// PHP script
    PHP,

    /// Portable Network Graphics (PNG) image
    PNG,

    /// Postscript document
    PS,

    /// Python script
    Python,

    /// RAR archive
    RAR,

    /// Shell script
    Shell,

    /// Tagged Image File Format (TIFF) image
    TIFF,

    /// Shockwave Flash animation
    SWF,

    /// WebAssembly binary
    Wasm,

    /// Windows Registry data (sometimes called a Hive)
    WindowsRegistry,

    /// Windows shortcut
    WindowsShortcut,

    /// Extensible Markup Language (XML) document
    XML,

    /// Zip archive
    Zip,

    /// Unidentified ASCII text (detection only considers data whose bytes are all ASCII).
    /// Not selectable by name: skipped by both Serde and Clap.
    #[serde(skip)]
    #[clap(skip)]
    UnknownText,

    /// Placeholder used when a [`FileType`] is available in a [`FileTypeUnion`]; also the result
    /// of byte-based detection when nothing matched. Not selectable by name.
    #[doc(hidden)]
    #[serde(skip)]
    #[clap(skip)]
    Unknown,
}
226
227impl NonModelTypes {
228    /// Get a file type from bytes
229    #[must_use]
230    #[allow(clippy::too_many_lines)]
231    pub(crate) fn from_bytes(bytes: &[u8]) -> Self {
232        if bytes.starts_with(b"PK") {
233            return Self::Zip;
234        }
235
236        if bytes.starts_with(b"\x4D\x53\x43\x46") || bytes.starts_with(b"\x4D\x53\x63\x28") {
237            return Self::CAB;
238        }
239
240        if bytes.starts_with(&[0x4C, 0x01])
241            || bytes.starts_with(&[0x64, 0x86])
242            || bytes.starts_with(&[0x00, 0x02])
243        {
244            return Self::COFF;
245        }
246
247        if bytes.starts_with(&[0xC9])
248            || bytes.starts_with(&[0xE9])
249            || bytes.starts_with(&[0xE8])
250            || bytes.starts_with(&[0xEB])
251        {
252            return Self::COM;
253        }
254
255        if bytes.starts_with(&[0x43, 0x72, 0x32, 0x34]) {
256            return Self::ChromeExt;
257        }
258
259        if bytes.starts_with(b"\x64\x65\x78\x0A\x30\x33\x35\x00") {
260            return Self::DEX;
261        }
262
263        if bytes.starts_with(b"\x89PNG") {
264            return Self::PNG;
265        }
266
267        if bytes.starts_with(b"-----BEGIN CERTIFICATE-----") {
268            return Self::PemCrt;
269        }
270
271        if bytes.starts_with(b"-----BEGIN CERTIFICATE REQUEST-----") {
272            return Self::PemCsr;
273        }
274
275        if bytes.starts_with(b"-----BEGIN PRIVATE KEY-----")
276            || bytes.starts_with(b"-----BEGIN DSA PRIVATE KEY-----")
277            || bytes.starts_with(b"-----BEGIN RSA PRIVATE KEY-----")
278        {
279            return Self::PemKey;
280        }
281
282        if bytes.starts_with(b"\x46\x4C\x56") {
283            return Self::FLV;
284        }
285
286        if bytes.starts_with(b"GIF87") || bytes.starts_with(b"GIF89") {
287            return Self::GIF;
288        }
289
290        if bytes.starts_with(b"\xFF\xD8\xFF") {
291            return Self::JPEG;
292        }
293
294        if bytes.starts_with(&[0xFF, 0x4F, 0xFF, 0xF1])
295            || bytes.starts_with(&[
296                0x00, 0x00, 0x00, 0x0C, 0x6A, 0x50, 0x20, 0x20, 0x0D, 0x0A, 0x87, 0x0A,
297            ])
298        {
299            return Self::JPEG2K;
300        }
301
302        if bytes.starts_with(&[0x4A, 0x6F, 0x79, 0x21]) {
303            return Self::PEF;
304        }
305
306        if bytes.starts_with(b"%!") {
307            return Self::PS;
308        }
309
310        if bytes.starts_with(&[0x52, 0x61, 0x72, 0x21, 0x1A, 0x07]) {
311            return Self::RAR;
312        }
313
314        if bytes.starts_with(&[0x1F, 0x8B]) {
315            return Self::GZip;
316        }
317
318        if bytes.starts_with(b"CWS") || bytes.starts_with(b"FWS") || bytes.starts_with(b"ZWS") {
319            return Self::SWF;
320        }
321
322        if bytes.starts_with(&[0x49, 0x20, 0x49])
323            || bytes.starts_with(&[0x49, 0x49, 0x2A])
324            || bytes.starts_with(&[0x4D, 0x4D, 0x00])
325        {
326            return Self::TIFF;
327        }
328
329        if bytes.starts_with(b"\x66\x4C\x61\x43") {
330            return Self::FLAC;
331        }
332
333        // If CAFEBABE and we're here, it's not a Mach-O binary
334        if bytes.starts_with(b"\xCA\xFE\xBA\xBE") {
335            let version = u32::from_be_bytes([
336                bytes[0x04],
337                bytes[0x04 + 1],
338                bytes[0x04 + 2],
339                bytes[0x04 + 3],
340            ]);
341            if version >= 0x20 {
342                return Self::Class;
343            }
344        }
345
346        if bytes.starts_with(&[0x00, 0x61, 0x73, 0x6D]) {
347            return Self::Wasm;
348        }
349
350        if bytes.starts_with(&[0x72, 0x65, 0x67, 0x66]) || bytes.starts_with(b"REGEDIT") {
351            return Self::WindowsRegistry;
352        }
353
354        if bytes.starts_with(&[0x4C, 0x00, 0x00, 0x00, 0x01, 0x14, 0x02, 0x00]) {
355            return Self::WindowsShortcut;
356        }
357
358        if bytes.is_ascii() {
359            if bytes.starts_with(b"<?xml") {
360                return Self::XML;
361            }
362
363            // This will detect executable scripts, but not scripts which aren't executed directly.
364            if bytes.starts_with(b"#!") {
365                let string_size = bytes.len().min(25);
366                if let Ok(string) =
367                    String::from_utf8(bytes[0..string_size].to_ascii_lowercase().clone())
368                {
369                    if string.contains("python") {
370                        return Self::Python;
371                    }
372
373                    if string.contains("perl") {
374                        return Self::Perl;
375                    }
376                }
377
378                return Self::Shell;
379            }
380
381            let string_size = bytes.len().min(1000);
382            if let Ok(string) =
383                String::from_utf8(bytes[0..string_size].to_ascii_lowercase().clone())
384            {
385                if string.contains("<?xml") {
386                    return Self::XML;
387                }
388
389                if string.contains("<?php") {
390                    return Self::PHP;
391                }
392
393                if string.contains("<html") || string.contains("<!doctype html>") {
394                    return Self::HTML;
395                }
396
397                if string.contains("@echo off") {
398                    return Self::BATCH;
399                }
400            }
401
402            return Self::UnknownText;
403        }
404
405        Self::Unknown
406    }
407
408    /// Use Serde to get the enum from a plain string.
409    pub(crate) fn name_from_line(line: &str) -> Result<Self, serde::de::value::Error> {
410        let line = line.split(':').nth(1).unwrap_or(line).to_uppercase();
411        let ftype: Result<_, serde::de::value::Error> =
412            Self::deserialize(String::from(line.trim()).into_deserializer());
413        ftype
414    }
415}
416
417impl From<NonModelTypes> for &'static str {
418    fn from(ftype: NonModelTypes) -> Self {
419        match ftype {
420            NonModelTypes::BATCH => "BAT",
421            NonModelTypes::CAB => "CAB",
422            NonModelTypes::COFF => "COFF",
423            NonModelTypes::COM => "COM",
424            NonModelTypes::ChromeExt => "ChromeExt",
425            NonModelTypes::Class => "Class",
426            NonModelTypes::DEX => "DEX",
427            NonModelTypes::FLAC => "FLAC",
428            NonModelTypes::FLV => "FLV",
429            NonModelTypes::GIF => "GIF",
430            NonModelTypes::GZip => "GZip",
431            NonModelTypes::HTML => "HTML",
432            NonModelTypes::JPEG => "JPEG",
433            NonModelTypes::JPEG2K => "JPEG2K",
434            NonModelTypes::PEF => "PEF",
435            NonModelTypes::PemCrt => "PemCrt",
436            NonModelTypes::PemCsr => "PemCsr",
437            NonModelTypes::PemKey => "PemKey",
438            NonModelTypes::Perl => "Perl",
439            NonModelTypes::PHP => "PHP",
440            NonModelTypes::PNG => "PNG",
441            NonModelTypes::PS => "PS",
442            NonModelTypes::Python => "Python",
443            NonModelTypes::RAR => "RAR",
444            NonModelTypes::Shell => "Shell",
445            NonModelTypes::SWF => "SWF",
446            NonModelTypes::TIFF => "TIFF",
447            NonModelTypes::Wasm => "Wasm",
448            NonModelTypes::WindowsRegistry => "WindowsRegistry",
449            NonModelTypes::WindowsShortcut => "WindowsShortcut",
450            NonModelTypes::XML => "XML",
451            NonModelTypes::Zip => "Zip",
452            NonModelTypes::Unknown => "Unknown",
453            NonModelTypes::UnknownText => "UnknownText",
454        }
455    }
456}
457
458impl Display for NonModelTypes {
459    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
460        let s: &'static str = (*self).into();
461        write!(f, "{s}")
462    }
463}
464
/// Results of sorting files by their file type, as returned by [`file_sorting`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct FileSortingResults {
    /// Total files encountered during sorting
    pub total_files: usize,

    /// Duplicates encountered: files whose content hash already had an entry in the destination.
    /// Despite the name, nothing is deleted; a duplicate simply isn't linked a second time.
    pub files_removed: usize,

    /// Number of problems that were counted instead of aborting the sort (for example, files
    /// that couldn't be read)
    pub errors: usize,
}
476
477/// Sort files by file type by creating a directory based on the file type and making a symbolic link
478/// to the original file, so no files are actually deleted.
479///
480/// The `find` and `file` commands can be used to verify:
481/// `find . -type f -exec file {} \;`
482///
483/// The [`file`](https://github.com/file/file) command will not be 100% correct, and neither will we.
484///
485/// # Errors
486///
487/// An error occurs if files cannot be read or if symbolic links cannot be created in newly-created
488/// directories based on file type.
489pub fn file_sorting<P: AsRef<Path>>(
490    origin: P,
491    destination: P,
492    depth: u8,
493) -> Result<FileSortingResults> {
494    #[cfg(not(any(unix, windows)))]
495    anyhow::bail!(
496        "File Sorting is not supported on this platform {} due to missing symbolic link support.\nPlease file an issue at https://github.com/rjzak/malware-modeler-rs/issues/new/choose.",
497        std::env::consts::OS
498    );
499
500    let mut total_files = 0;
501    let mut duplicate_files = 0;
502    let mut errors = 0;
503
504    for entry in WalkDir::new(origin)
505        .max_depth(MAX_RECURSION_DEPTH)
506        .follow_links(true)
507        .into_iter()
508        .flatten()
509    {
510        if entry.path().is_file() {
511            total_files += 1;
512
513            let Ok(contents) = std::fs::read(entry.path()) else {
514                errors += 1;
515                continue;
516            };
517
518            let file_type = FileTypeUnion::from_bytes(&contents);
519            let hash = hex::encode(Sha256::digest(contents));
520
521            let mut destination_file = destination.as_ref().join(file_type.to_string());
522            destination_file.push(hash_depth(&hash, depth));
523
524            std::fs::create_dir_all(&destination_file)?;
525            destination_file.push(hash);
526            if destination_file.exists() {
527                duplicate_files += 1;
528            } else {
529                #[cfg(unix)]
530                std::os::unix::fs::symlink(entry.path(), destination_file)?;
531
532                #[cfg(windows)]
533                std::os::windows::fs::symlink_file(entry.path(), destination_file)?;
534            }
535        }
536    }
537
538    Ok(FileSortingResults {
539        total_files,
540        errors,
541        files_removed: duplicate_files,
542    })
543}
544
/// Build up directories representing the hash but don't include the hash itself so the directory
/// structure can be created before adding the file to the path.
///
/// Each level consumes the next two characters of `hash`, so a depth of 2 on `"9d6dc1..."` yields
/// `9d/6d`. `depth` is capped at the number of complete two-character pairs in `hash`, and a depth
/// of zero yields an empty path (a no-op when pushed onto another path).
///
/// Hex digests are ASCII; for other input, building stops early (rather than panicking) at the
/// first pair that would split a multi-byte character.
#[inline]
#[must_use]
pub fn hash_depth(hash: &str, depth: u8) -> PathBuf {
    // Compare in `usize`: narrowing `hash.len() / 2` to `u8` would wrap for long strings
    // (e.g. 512 characters -> 0) and silently produce a much shallower path.
    let levels = usize::from(depth).min(hash.len() / 2);
    (0..levels)
        .map_while(|level| hash.get(level * 2..level * 2 + 2))
        .collect()
}
558
#[test]
fn hash_depth_test() {
    const SHA1_HEX: &str = "9d6dc11990a109cd82d4dbafb6588b1b18e0e46b";
    const SHA512_HEX: &str = "fedc3e4d500fd9f3a52c05549a53f0f82ae684167033699e87ebe018517ceeb265136de09aa7e1fce5bbce0b8a4ead89170a99a5bdb2b5f7d1f02a81e3178af2";

    // Pushing the empty path produced by a depth of zero must not add a stray component
    let mut joined = PathBuf::from("MyDir");
    joined.push(hash_depth(SHA1_HEX, 0));
    joined.push("my_file.txt");
    assert_eq!(joined, PathBuf::from("MyDir/my_file.txt"));

    // (hash, depth, expected path); depth 255 checks the result never runs past the hash
    let cases: [(&str, u8, &str); 10] = [
        (SHA1_HEX, 0, ""),
        (SHA1_HEX, 1, "9d/"),
        (SHA1_HEX, 2, "9d/6d/"),
        (SHA1_HEX, 3, "9d/6d/c1/"),
        (SHA1_HEX, 4, "9d/6d/c1/19/"),
        (
            SHA1_HEX,
            255,
            "9d/6d/c1/19/90/a1/09/cd/82/d4/db/af/b6/58/8b/1b/18/e0/e4/6b/",
        ),
        (SHA512_HEX, 0, ""),
        (SHA512_HEX, 1, "fe/"),
        (SHA512_HEX, 2, "fe/dc/"),
        (
            SHA512_HEX,
            255,
            "fe/dc/3e/4d/50/0f/d9/f3/a5/2c/05/54/9a/53/f0/f8/2a/e6/84/16/70/33/69/9e/87/eb/e0/18/51/7c/ee/b2/65/13/6d/e0/9a/a7/e1/fc/e5/bb/ce/0b/8a/4e/ad/89/17/0a/99/a5/bd/b2/b5/f7/d1/f0/2a/81/e3/17/8a/f2",
        ),
    ];

    for (hash, depth, expected) in cases {
        assert_eq!(
            hash_depth(hash, depth),
            PathBuf::from(expected),
            "hash_depth({hash}, {depth})"
        );
    }
}