malware_modeler/
sorting.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::ftype::FileType;
4use crate::MAX_RECURSION_DEPTH;
5
6use std::error::Error;
7use std::fmt::{Display, Formatter};
8use std::path::{Path, PathBuf};
9
10use anyhow::Result;
11use clap::ValueEnum;
12use serde::de::IntoDeserializer;
13use serde::{Deserialize, Serialize};
14use sha2::{Digest, Sha256};
15use walkdir::WalkDir;
16
17/// Found file type, either a model type or a non-model type
18/// Think of this as a Union, but they aren't safe.
19#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
20pub struct FileTypeUnion {
21    /// Model type
22    ftype: FileType,
23
24    /// Other file type
25    non_model_type: NonModelTypes,
26}
27
28impl FileTypeUnion {
29    /// Get a file type from bytes
30    #[must_use]
31    pub fn from_bytes(bytes: &[u8]) -> Self {
32        if let Some(ftype) = FileType::from_bytes(bytes) {
33            Self {
34                ftype,
35                non_model_type: NonModelTypes::Unknown,
36            }
37        } else {
38            Self {
39                ftype: FileType::NotSet,
40                non_model_type: NonModelTypes::from_bytes(bytes),
41            }
42        }
43    }
44
45    /// Indicates if the specified file type matches some bytes
46    #[must_use]
47    pub fn matches(&self, bytes: &[u8]) -> bool {
48        if self.ftype == FileType::NotSet {
49            NonModelTypes::from_bytes(bytes) == self.non_model_type
50        } else {
51            self.ftype.matches(bytes)
52        }
53    }
54
55    /// Indicates if the file type isn't known
56    #[must_use]
57    pub fn is_unknown(&self) -> bool {
58        self.ftype == FileType::NotSet
59            || self.ftype == FileType::DsStore
60            || self.non_model_type != NonModelTypes::Unknown
61    }
62}
63
64/// Clap parser function for [`FileTypeUnion`] types.
65/// Uses Serde to get the file type from a string, uses Clap to print allowed values.
66///
67/// # Errors
68///
69/// Returns an error if the string is empty
70pub fn parse_file_type_union(
71    s: &str,
72) -> Result<FileTypeUnion, Box<dyn Error + Send + Sync + 'static>> {
73    if s.is_empty() {
74        return Err("File type cannot be empty.".into());
75    }
76
77    if let Ok(model_type) = crate::dataset::Dataset::file_type_from_line(s) {
78        return Ok(FileTypeUnion {
79            ftype: model_type,
80            non_model_type: NonModelTypes::Unknown,
81        });
82    }
83
84    let Ok(non_model_type) = NonModelTypes::name_from_line(s.to_lowercase().as_str()) else {
85        let mut allowed_variants = Vec::with_capacity(30);
86        for variant in FileType::value_variants() {
87            allowed_variants.push(variant.to_string().to_lowercase());
88        }
89
90        for variant in NonModelTypes::value_variants() {
91            allowed_variants.push(variant.to_string().to_lowercase());
92        }
93
94        return Err(format!(
95            "{s} is not a valid file type.\nAllowed types: {}.",
96            allowed_variants.join(", ")
97        )
98        .into());
99    };
100
101    Ok(FileTypeUnion {
102        ftype: FileType::NotSet,
103        non_model_type,
104    })
105}
106
107impl Display for FileTypeUnion {
108    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
109        if self.ftype == FileType::NotSet {
110            write!(f, "{}", self.non_model_type)
111        } else {
112            write!(f, "{}", self.ftype)
113        }
114    }
115}
116
117/// File types we're likely to encounter, but we don't think are good candidates for a model
118#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
119pub enum NonModelTypes {
120    /// MS-DOS Batch file
121    BATCH,
122
123    /// Microsoft Cabinet archive
124    CAB,
125
126    /// Google Chrome extension
127    ChromeExt,
128
129    /// Java class
130    Class,
131
132    /// MS-DOS COM executable
133    COM,
134
135    /// Android Dalvik Executable (DEX) file
136    DEX,
137
138    /// Free Lossless Audio Codec (FLAC) audio file
139    FLAC,
140
141    /// Flash Video
142    FLV,
143
144    /// Graphics Interchange Format (GIF) image
145    GIF,
146
147    /// Gzip-compressed file
148    GZip,
149
150    /// HyperText Markup Language (HTML) document
151    HTML,
152
153    /// JPEG image
154    JPEG,
155
156    /// JPEG 2000
157    JPEG2K,
158
159    /// Preferred Executable Format file
160    PEF,
161
162    /// PEM-encoded certificate
163    PemCrt,
164
165    /// PEM-encoded certificate signing request
166    PemCsr,
167
168    /// PEM-encoded private key
169    PemKey,
170
171    /// Portable Network Graphics (PNG) image
172    PNG,
173
174    /// Postscript document
175    PS,
176
177    /// Python script
178    Python,
179
180    /// RAR archive
181    RAR,
182
183    /// Shell script
184    Shell,
185
186    /// Tagged Image File Format (TIFF) image
187    TIFF,
188
189    /// Shockwave Flash animation
190    SWF,
191
192    /// WebAssembly binary
193    Wasm,
194
195    /// Windows shortcut
196    WindowsShortcut,
197
198    /// Extensible Markup Language (XML) document
199    XML,
200
201    /// Zip archive
202    Zip,
203
204    /// ASCII text
205    #[serde(skip)]
206    #[clap(skip)]
207    UnknownAscii,
208
209    /// This is used as a convenience type for when a [`FileType`] is available in a [`FileTypeUnion`]
210    #[doc(hidden)]
211    #[serde(skip)]
212    #[clap(skip)]
213    Unknown,
214}
215
216impl NonModelTypes {
217    /// Get a file type from bytes
218    #[must_use]
219    #[allow(clippy::too_many_lines)]
220    pub(crate) fn from_bytes(bytes: &[u8]) -> Self {
221        if bytes.starts_with(b"PK") {
222            return Self::Zip;
223        }
224
225        if bytes.starts_with(b"\x4D\x53\x43\x46") || bytes.starts_with(b"\x4D\x53\x63\x28") {
226            return Self::CAB;
227        }
228
229        if bytes.starts_with(&[0xC9])
230            || bytes.starts_with(&[0xE9])
231            || bytes.starts_with(&[0xE8])
232            || bytes.starts_with(&[0xEB])
233        {
234            return Self::COM;
235        }
236
237        if bytes.starts_with(&[0x43, 0x72, 0x32, 0x34]) {
238            return Self::ChromeExt;
239        }
240
241        if bytes.starts_with(b"\x64\x65\x78\x0A\x30\x33\x35\x00") {
242            return Self::DEX;
243        }
244
245        if bytes.starts_with(b"\x89PNG") {
246            return Self::PNG;
247        }
248
249        if bytes.starts_with(b"-----BEGIN CERTIFICATE-----") {
250            return Self::PemCrt;
251        }
252
253        if bytes.starts_with(b"-----BEGIN CERTIFICATE REQUEST-----") {
254            return Self::PemCsr;
255        }
256
257        if bytes.starts_with(b"-----BEGIN PRIVATE KEY-----")
258            || bytes.starts_with(b"-----BEGIN DSA PRIVATE KEY-----")
259            || bytes.starts_with(b"-----BEGIN RSA PRIVATE KEY-----")
260        {
261            return Self::PemKey;
262        }
263
264        if bytes.starts_with(b"\x46\x4C\x56") {
265            return Self::FLV;
266        }
267
268        if bytes.starts_with(b"GIF87") || bytes.starts_with(b"GIF89") {
269            return Self::GIF;
270        }
271
272        if bytes.starts_with(b"\xFF\xD8\xFF") {
273            return Self::JPEG;
274        }
275
276        if bytes.starts_with(&[0xFF, 0x4F, 0xFF, 0xF1])
277            || bytes.starts_with(&[
278                0x00, 0x00, 0x00, 0x0C, 0x6A, 0x50, 0x20, 0x20, 0x0D, 0x0A, 0x87, 0x0A,
279            ])
280        {
281            return Self::JPEG2K;
282        }
283
284        if bytes.starts_with(&[0x4A, 0x6F, 0x79, 0x21]) {
285            return Self::PEF;
286        }
287
288        if bytes.starts_with(b"%!") {
289            return Self::PS;
290        }
291
292        if bytes.starts_with(&[0x52, 0x61, 0x72, 0x21, 0x1A, 0x07]) {
293            return Self::RAR;
294        }
295
296        if bytes.starts_with(&[0x1F, 0x8B]) {
297            return Self::GZip;
298        }
299
300        if bytes.starts_with(b"CWS") || bytes.starts_with(b"FWS") || bytes.starts_with(b"ZWS") {
301            return Self::SWF;
302        }
303
304        if bytes.starts_with(&[0x49, 0x20, 0x49])
305            || bytes.starts_with(&[0x49, 0x49, 0x2A])
306            || bytes.starts_with(&[0x4D, 0x4D, 0x00])
307        {
308            return Self::TIFF;
309        }
310
311        if bytes.starts_with(b"\x66\x4C\x61\x43") {
312            return Self::FLAC;
313        }
314
315        // If CAFEBABE and we're here, it's not a Mach-O binary
316        if bytes.starts_with(b"\xCA\xFE\xBA\xBE") {
317            let version = u32::from_be_bytes([
318                bytes[0x04],
319                bytes[0x04 + 1],
320                bytes[0x04 + 2],
321                bytes[0x04 + 3],
322            ]);
323            if version >= 0x20 {
324                return Self::Class;
325            }
326        }
327
328        if bytes.starts_with(&[0x00, 0x61, 0x73, 0x6D]) {
329            return Self::Wasm;
330        }
331
332        if bytes.starts_with(&[0x4C, 0x00, 0x00, 0x00, 0x01, 0x14, 0x02, 0x00]) {
333            return Self::WindowsShortcut;
334        }
335
336        if bytes.is_ascii() {
337            let ascii_size = bytes.len().min(50);
338            if let Ok(ascii) = String::from_utf8(bytes[0..ascii_size].to_ascii_lowercase().clone())
339            {
340                if ascii.contains("<?xml") {
341                    return Self::XML;
342                }
343
344                if ascii.contains("<html") || ascii.contains("<!doctype html>") {
345                    return Self::HTML;
346                }
347
348                if ascii.starts_with("#!") {
349                    if ascii.contains("python") {
350                        return Self::Python;
351                    }
352                    return Self::Shell;
353                }
354
355                if ascii.starts_with("@echo off") {
356                    return Self::BATCH;
357                }
358
359                return Self::UnknownAscii;
360            }
361        }
362
363        Self::Unknown
364    }
365
366    /// Use Serde to get the enum from a plain string.
367    pub(crate) fn name_from_line(line: &str) -> Result<Self, serde::de::value::Error> {
368        let line = line.split(':').nth(1).unwrap_or(line).to_uppercase();
369        let ftype: Result<_, serde::de::value::Error> =
370            Self::deserialize(String::from(line.trim()).into_deserializer());
371        ftype
372    }
373}
374
375impl From<NonModelTypes> for &'static str {
376    fn from(ftype: NonModelTypes) -> Self {
377        match ftype {
378            NonModelTypes::BATCH => "BAT",
379            NonModelTypes::CAB => "CAB",
380            NonModelTypes::COM => "COM",
381            NonModelTypes::ChromeExt => "ChromeExt",
382            NonModelTypes::Class => "Class",
383            NonModelTypes::DEX => "DEX",
384            NonModelTypes::FLAC => "FLAC",
385            NonModelTypes::FLV => "FLV",
386            NonModelTypes::GIF => "GIF",
387            NonModelTypes::GZip => "GZip",
388            NonModelTypes::HTML => "HTML",
389            NonModelTypes::JPEG => "JPEG",
390            NonModelTypes::JPEG2K => "JPEG2K",
391            NonModelTypes::PEF => "PEF",
392            NonModelTypes::PemCrt => "PemCrt",
393            NonModelTypes::PemCsr => "PemCsr",
394            NonModelTypes::PemKey => "PemKey",
395            NonModelTypes::PNG => "PNG",
396            NonModelTypes::PS => "PS",
397            NonModelTypes::Python => "Python",
398            NonModelTypes::RAR => "RAR",
399            NonModelTypes::Shell => "Shell",
400            NonModelTypes::SWF => "SWF",
401            NonModelTypes::TIFF => "TIFF",
402            NonModelTypes::Wasm => "Wasm",
403            NonModelTypes::WindowsShortcut => "WindowsShortcut",
404            NonModelTypes::XML => "XML",
405            NonModelTypes::Zip => "Zip",
406            NonModelTypes::Unknown => "Unknown",
407            NonModelTypes::UnknownAscii => "Unknown ASCII",
408        }
409    }
410}
411
412impl Display for NonModelTypes {
413    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
414        let s: &'static str = (*self).into();
415        write!(f, "{s}")
416    }
417}
418
/// Results of sorting files by their file type
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct FileSortingResults {
    /// Total number of regular files encountered during sorting
    pub total_files: usize,

    /// Number of duplicates encountered, i.e. files for which a link with the same content
    /// hash already existed in the destination. Despite the name, no files are deleted.
    pub files_removed: usize,

    /// Number of errors that occurred during sorting
    pub errors: usize,
}
430
431/// Sort files by file type by creating a directory based on the file type and making a symbolic link
432/// to the original file, so no files are actually deleted.
433///
434/// # Errors
435///
436/// An error occurs if files cannot be read or if symbolic links cannot be created in newly-created
437/// directories based on file type.
438pub fn file_sorting<P: AsRef<Path>>(
439    origin: P,
440    destination: P,
441    depth: u8,
442) -> Result<FileSortingResults> {
443    let mut total_files = 0;
444    let mut duplicate_files = 0;
445    let mut errors = 0;
446
447    for entry in WalkDir::new(origin)
448        .max_depth(MAX_RECURSION_DEPTH)
449        .follow_links(true)
450        .into_iter()
451        .flatten()
452    {
453        if entry.path().is_file() {
454            total_files += 1;
455
456            let Ok(contents) = std::fs::read(entry.path()) else {
457                errors += 1;
458                continue;
459            };
460
461            let file_type = FileTypeUnion::from_bytes(&contents);
462            let hash = hex::encode(Sha256::digest(contents));
463
464            let mut destination_file = destination.as_ref().join(file_type.to_string());
465            destination_file.push(hash_depth(&hash, depth));
466
467            std::fs::create_dir_all(&destination_file)?;
468            destination_file.push(hash);
469            if destination_file.exists() {
470                duplicate_files += 1;
471            } else {
472                #[cfg(unix)]
473                std::os::unix::fs::symlink(entry.path(), destination_file)?;
474
475                #[cfg(windows)]
476                std::os::windows::fs::symlink_file(entry.path(), destination_file)?;
477            }
478        }
479    }
480
481    Ok(FileSortingResults {
482        total_files,
483        errors,
484        files_removed: duplicate_files,
485    })
486}
487
/// Build up directories representing the hash but don't include the hash itself so the directory
/// structure can be created before adding the file to the path.
///
/// Each level is the next two characters of `hash`, so `hash_depth("9d6dc1", 2)` is `9d/6d`.
/// The result is empty when `depth` is 0. At most `hash.len() / 2` levels are produced,
/// however large `depth` is.
///
/// Hex digests are always ASCII. For other input, building stops early rather than panicking
/// if a two-byte slice would split a multi-byte character.
#[inline]
#[must_use]
pub fn hash_depth(hash: &str, depth: u8) -> PathBuf {
    // Compute in `usize`: narrowing `hash.len() / 2` to `u8` would wrap for long inputs.
    let levels = usize::from(depth).min(hash.len() / 2);
    (0..levels)
        .map_while(|level| hash.get(level * 2..level * 2 + 2))
        .collect()
}
501
#[test]
fn hash_depth_test() {
    const HASH: &str = "9d6dc11990a109cd82d4dbafb6588b1b18e0e46b";
    const HASH_512: &str = "fedc3e4d500fd9f3a52c05549a53f0f82ae684167033699e87ebe018517ceeb265136de09aa7e1fce5bbce0b8a4ead89170a99a5bdb2b5f7d1f02a81e3178af2";

    // Depth 0 yields an empty path, and pushing an empty path onto a `PathBuf` is expected to
    // leave it unchanged.
    let mut joined = PathBuf::from("MyDir");
    joined.push(hash_depth(HASH, 0));
    joined.push("my_file.txt");
    assert_eq!(joined, PathBuf::from("MyDir/my_file.txt"));

    // (hash, depth, expected directories). The depth-255 rows check that an absurd depth
    // never goes beyond the length of the hash.
    let cases: &[(&str, u8, &str)] = &[
        (HASH, 0, ""),
        (HASH, 1, "9d/"),
        (HASH, 2, "9d/6d/"),
        (HASH, 3, "9d/6d/c1/"),
        (HASH, 4, "9d/6d/c1/19/"),
        (
            HASH,
            255,
            "9d/6d/c1/19/90/a1/09/cd/82/d4/db/af/b6/58/8b/1b/18/e0/e4/6b/",
        ),
        (HASH_512, 0, ""),
        (HASH_512, 1, "fe/"),
        (HASH_512, 2, "fe/dc/"),
        (
            HASH_512,
            255,
            "fe/dc/3e/4d/50/0f/d9/f3/a5/2c/05/54/9a/53/f0/f8/2a/e6/84/16/70/33/69/9e/87/eb/e0/18/51/7c/ee/b2/65/13/6d/e0/9a/a7/e1/fc/e5/bb/ce/0b/8a/4e/ad/89/17/0a/99/a5/bd/b2/b5/f7/d1/f0/2a/81/e3/17/8a/f2",
        ),
    ];

    for &(hash, depth, expected) in cases {
        assert_eq!(hash_depth(hash, depth), PathBuf::from(expected));
    }
}
527    assert_eq!(hash_depth(HASH_512, 2), PathBuf::from("fe/dc/"));
528    assert_eq!(hash_depth(HASH_512, 255), PathBuf::from("fe/dc/3e/4d/50/0f/d9/f3/a5/2c/05/54/9a/53/f0/f8/2a/e6/84/16/70/33/69/9e/87/eb/e0/18/51/7c/ee/b2/65/13/6d/e0/9a/a7/e1/fc/e5/bb/ce/0b/8a/4e/ad/89/17/0a/99/a5/bd/b2/b5/f7/d1/f0/2a/81/e3/17/8a/f2"));
529}