use crate::MAX_RECURSION_DEPTH;
use crate::ftype::FileType;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::path::{Path, PathBuf};
use anyhow::Result;
use clap::ValueEnum;
use serde::de::IntoDeserializer;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use walkdir::WalkDir;
/// A detected file type coming from one of two classifiers: the model-backed
/// [`FileType`] or the signature-based [`NonModelTypes`].
///
/// The constructors in this module keep only one side meaningful at a time:
/// when `ftype` is `FileType::NotSet`, `non_model_type` holds the result;
/// otherwise `non_model_type` is `NonModelTypes::Unknown`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct FileTypeUnion {
    // Model-backed type, or `FileType::NotSet` when the non-model classifier is used.
    ftype: FileType,
    // Signature-based type; `NonModelTypes::Unknown` whenever `ftype` is set.
    non_model_type: NonModelTypes,
}
impl FileTypeUnion {
    /// Classifies `bytes`, preferring the model-backed `FileType` detector.
    ///
    /// When `FileType::from_bytes` recognises the data, that type is stored and
    /// `non_model_type` is `NonModelTypes::Unknown`. Otherwise `ftype` is
    /// `FileType::NotSet` and the result of `NonModelTypes::from_bytes` is
    /// stored instead. The non-model classifier only runs in that fallback case.
    #[must_use]
    pub fn from_bytes(bytes: &[u8]) -> Self {
        match FileType::from_bytes(bytes) {
            Some(ftype) => Self {
                ftype,
                non_model_type: NonModelTypes::Unknown,
            },
            None => Self {
                ftype: FileType::NotSet,
                non_model_type: NonModelTypes::from_bytes(bytes),
            },
        }
    }

    /// Checks whether `bytes` match this type.
    ///
    /// Model-backed types defer to `FileType::matches`. For non-model types the
    /// signature check `NonModelTypes::from_bytes` is re-run and its result is
    /// compared with the stored one.
    #[must_use]
    pub fn matches(&self, bytes: &[u8]) -> bool {
        if self.ftype != FileType::NotSet {
            return self.ftype.matches(bytes);
        }
        self.non_model_type == NonModelTypes::from_bytes(bytes)
    }

    /// Returns `true` for `FileType::DsStore`, or when neither classifier
    /// produced a result (`FileType::NotSet` together with `NonModelTypes::Unknown`).
    ///
    /// `NonModelTypes::UnknownText` is not considered unknown here.
    #[must_use]
    pub fn is_unknown(&self) -> bool {
        let unclassified =
            self.ftype == FileType::NotSet && self.non_model_type == NonModelTypes::Unknown;
        unclassified || self.ftype == FileType::DsStore
    }
}
/// Parses a user-supplied file type name into a [`FileTypeUnion`].
///
/// Model-backed names (as understood by `Dataset::file_type_from_line`) take
/// precedence. Otherwise the lower-cased input is handed to
/// `NonModelTypes::name_from_line`.
///
/// # Errors
/// Returns an error when `s` is empty, or when it names neither a model nor a
/// non-model type. In the latter case the message lists every allowed name in
/// lower case.
pub fn parse_file_type_union(
    s: &str,
) -> Result<FileTypeUnion, Box<dyn Error + Send + Sync + 'static>> {
    if s.is_empty() {
        return Err("File type cannot be empty.".into());
    }

    if let Ok(ftype) = crate::dataset::Dataset::file_type_from_line(s) {
        return Ok(FileTypeUnion {
            ftype,
            non_model_type: NonModelTypes::Unknown,
        });
    }

    match NonModelTypes::name_from_line(&s.to_lowercase()) {
        Ok(non_model_type) => Ok(FileTypeUnion {
            ftype: FileType::NotSet,
            non_model_type,
        }),
        Err(_) => {
            let model_names = FileType::value_variants()
                .iter()
                .map(|variant| variant.to_string().to_lowercase());
            let non_model_names = NonModelTypes::value_variants()
                .iter()
                .map(|variant| variant.to_string().to_lowercase());
            let allowed: Vec<String> = model_names.chain(non_model_names).collect();
            Err(format!(
                "{s} is not a valid file type.\nAllowed types: {}.",
                allowed.join(", ")
            )
            .into())
        }
    }
}
impl Display for FileTypeUnion {
    /// Writes the model-backed type when one is set, otherwise the non-model type.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let has_model_type = self.ftype != FileType::NotSet;
        if has_model_type {
            return write!(f, "{}", self.ftype);
        }
        write!(f, "{}", self.non_model_type)
    }
}
// File types recognised by plain signature / content sniffing in
// `NonModelTypes::from_bytes`, as opposed to the model-backed `FileType`.
//
// Display names come from the `From<NonModelTypes> for &'static str` impl;
// note that `BATCH` is displayed as "BAT".
//
// NOTE(review): plain `//` comments are used deliberately. clap's `ValueEnum`
// derive may turn `///` doc comments on variants into CLI help text.
#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NonModelTypes {
    BATCH,
    CAB,
    ChromeExt,
    Class,
    COFF,
    COM,
    DEX,
    FLAC,
    FLV,
    GIF,
    GZip,
    HTML,
    JPEG,
    JPEG2K,
    PEF,
    PemCrt,
    PemCsr,
    PemKey,
    Perl,
    PHP,
    PNG,
    PS,
    Python,
    RAR,
    Shell,
    TIFF,
    SWF,
    Wasm,
    WindowsRegistry,
    WindowsShortcut,
    XML,
    Zip,
    // All-ASCII input that matched no text heuristic. Skipped by serde and clap,
    // so it cannot be selected by name; it is only produced by classification.
    #[serde(skip)]
    #[clap(skip)]
    UnknownText,
    // Input matching no signature at all. Also used as the placeholder
    // `non_model_type` when a model-backed `FileType` was detected.
    #[doc(hidden)]
    #[serde(skip)]
    #[clap(skip)]
    Unknown,
}
impl NonModelTypes {
    /// Classifies `bytes` by signature ("magic number") or leading text.
    ///
    /// Checks run in a fixed order and the first match wins. Binary signatures
    /// are tried first. If none match and the whole input is ASCII, text
    /// heuristics are applied: XML and shebang prefixes, then a keyword search
    /// in the first 1000 bytes, falling back to `UnknownText`. Anything else
    /// yields `Unknown`.
    ///
    /// Never panics, including on empty or truncated input.
    #[must_use]
    #[allow(clippy::too_many_lines)]
    pub(crate) fn from_bytes(bytes: &[u8]) -> Self {
        if bytes.starts_with(b"PK") {
            return Self::Zip;
        }
        if bytes.starts_with(b"\x4D\x53\x43\x46") || bytes.starts_with(b"\x4D\x53\x63\x28") {
            return Self::CAB;
        }
        if bytes.starts_with(&[0x4C, 0x01])
            || bytes.starts_with(&[0x64, 0x86])
            || bytes.starts_with(&[0x00, 0x02])
        {
            return Self::COFF;
        }
        if bytes.starts_with(&[0xC9])
            || bytes.starts_with(&[0xE9])
            || bytes.starts_with(&[0xE8])
            || bytes.starts_with(&[0xEB])
        {
            return Self::COM;
        }
        if bytes.starts_with(&[0x43, 0x72, 0x32, 0x34]) {
            return Self::ChromeExt;
        }
        if bytes.starts_with(b"\x64\x65\x78\x0A\x30\x33\x35\x00") {
            return Self::DEX;
        }
        if bytes.starts_with(b"\x89PNG") {
            return Self::PNG;
        }
        if bytes.starts_with(b"-----BEGIN CERTIFICATE-----") {
            return Self::PemCrt;
        }
        if bytes.starts_with(b"-----BEGIN CERTIFICATE REQUEST-----") {
            return Self::PemCsr;
        }
        if bytes.starts_with(b"-----BEGIN PRIVATE KEY-----")
            || bytes.starts_with(b"-----BEGIN DSA PRIVATE KEY-----")
            || bytes.starts_with(b"-----BEGIN RSA PRIVATE KEY-----")
        {
            return Self::PemKey;
        }
        if bytes.starts_with(b"\x46\x4C\x56") {
            return Self::FLV;
        }
        if bytes.starts_with(b"GIF87") || bytes.starts_with(b"GIF89") {
            return Self::GIF;
        }
        if bytes.starts_with(b"\xFF\xD8\xFF") {
            return Self::JPEG;
        }
        if bytes.starts_with(&[0xFF, 0x4F, 0xFF, 0xF1])
            || bytes.starts_with(&[
                0x00, 0x00, 0x00, 0x0C, 0x6A, 0x50, 0x20, 0x20, 0x0D, 0x0A, 0x87, 0x0A,
            ])
        {
            return Self::JPEG2K;
        }
        if bytes.starts_with(&[0x4A, 0x6F, 0x79, 0x21]) {
            return Self::PEF;
        }
        if bytes.starts_with(b"%!") {
            return Self::PS;
        }
        if bytes.starts_with(&[0x52, 0x61, 0x72, 0x21, 0x1A, 0x07]) {
            return Self::RAR;
        }
        if bytes.starts_with(&[0x1F, 0x8B]) {
            return Self::GZip;
        }
        if bytes.starts_with(b"CWS") || bytes.starts_with(b"FWS") || bytes.starts_with(b"ZWS") {
            return Self::SWF;
        }
        if bytes.starts_with(&[0x49, 0x20, 0x49])
            || bytes.starts_with(&[0x49, 0x49, 0x2A])
            || bytes.starts_with(&[0x4D, 0x4D, 0x00])
        {
            return Self::TIFF;
        }
        if bytes.starts_with(b"\x66\x4C\x61\x43") {
            return Self::FLAC;
        }
        if bytes.starts_with(b"\xCA\xFE\xBA\xBE") {
            // The version check presumably separates Java class files from other
            // formats sharing this magic (such as Mach-O universal binaries).
            // `get` keeps a header shorter than 8 bytes from panicking; such input
            // simply falls through to the remaining checks.
            if let Some(&[b0, b1, b2, b3]) = bytes.get(4..8) {
                let version = u32::from_be_bytes([b0, b1, b2, b3]);
                if version >= 0x20 {
                    return Self::Class;
                }
            }
        }
        if bytes.starts_with(&[0x00, 0x61, 0x73, 0x6D]) {
            return Self::Wasm;
        }
        if bytes.starts_with(&[0x72, 0x65, 0x67, 0x66]) || bytes.starts_with(b"REGEDIT") {
            return Self::WindowsRegistry;
        }
        if bytes.starts_with(&[0x4C, 0x00, 0x00, 0x00, 0x01, 0x14, 0x02, 0x00]) {
            return Self::WindowsShortcut;
        }
        if bytes.is_ascii() {
            if bytes.starts_with(b"<?xml") {
                return Self::XML;
            }
            if bytes.starts_with(b"#!") {
                // Only the first 25 bytes are inspected for the interpreter name.
                let head = bytes[..bytes.len().min(25)].to_ascii_lowercase();
                if let Ok(head) = String::from_utf8(head) {
                    if head.contains("python") {
                        return Self::Python;
                    }
                    if head.contains("perl") {
                        return Self::Perl;
                    }
                }
                return Self::Shell;
            }
            // Keyword search is limited to the first 1000 bytes (case-insensitive).
            let head = bytes[..bytes.len().min(1000)].to_ascii_lowercase();
            if let Ok(head) = String::from_utf8(head) {
                if head.contains("<?xml") {
                    return Self::XML;
                }
                if head.contains("<?php") {
                    return Self::PHP;
                }
                if head.contains("<html") || head.contains("<!doctype html>") {
                    return Self::HTML;
                }
                if head.contains("@echo off") {
                    return Self::BATCH;
                }
            }
            return Self::UnknownText;
        }
        Self::Unknown
    }

    /// Parses a non-model type name, ignoring ASCII case.
    ///
    /// If `line` contains a `:`, only the text between the first and second
    /// colon is used (e.g. `"label:Python"` becomes `"Python"`). Surrounding
    /// whitespace is ignored. Either the variant name (`BATCH`, `ChromeExt`, ...)
    /// or its display name (`BAT`) is accepted. `UnknownText` and `Unknown`
    /// are never selectable.
    ///
    /// # Errors
    /// Returns a serde error when the name matches no selectable variant.
    pub(crate) fn name_from_line(line: &str) -> Result<Self, serde::de::value::Error> {
        let name = line.split(':').nth(1).unwrap_or(line).trim();
        // `value_variants` excludes the `#[clap(skip)]` placeholders, which are
        // also `#[serde(skip)]`, so the selectable set is unchanged.
        let matched = Self::value_variants().iter().copied().find(|variant| {
            let display: &'static str = (*variant).into();
            display.eq_ignore_ascii_case(name) || format!("{variant:?}").eq_ignore_ascii_case(name)
        });
        match matched {
            Some(variant) => Ok(variant),
            // No case-insensitive match: let serde produce its descriptive
            // "unknown variant" error.
            None => Self::deserialize(name.to_uppercase().into_deserializer()),
        }
    }
}
impl From<NonModelTypes> for &'static str {
fn from(ftype: NonModelTypes) -> Self {
match ftype {
NonModelTypes::BATCH => "BAT",
NonModelTypes::CAB => "CAB",
NonModelTypes::COFF => "COFF",
NonModelTypes::COM => "COM",
NonModelTypes::ChromeExt => "ChromeExt",
NonModelTypes::Class => "Class",
NonModelTypes::DEX => "DEX",
NonModelTypes::FLAC => "FLAC",
NonModelTypes::FLV => "FLV",
NonModelTypes::GIF => "GIF",
NonModelTypes::GZip => "GZip",
NonModelTypes::HTML => "HTML",
NonModelTypes::JPEG => "JPEG",
NonModelTypes::JPEG2K => "JPEG2K",
NonModelTypes::PEF => "PEF",
NonModelTypes::PemCrt => "PemCrt",
NonModelTypes::PemCsr => "PemCsr",
NonModelTypes::PemKey => "PemKey",
NonModelTypes::Perl => "Perl",
NonModelTypes::PHP => "PHP",
NonModelTypes::PNG => "PNG",
NonModelTypes::PS => "PS",
NonModelTypes::Python => "Python",
NonModelTypes::RAR => "RAR",
NonModelTypes::Shell => "Shell",
NonModelTypes::SWF => "SWF",
NonModelTypes::TIFF => "TIFF",
NonModelTypes::Wasm => "Wasm",
NonModelTypes::WindowsRegistry => "WindowsRegistry",
NonModelTypes::WindowsShortcut => "WindowsShortcut",
NonModelTypes::XML => "XML",
NonModelTypes::Zip => "Zip",
NonModelTypes::Unknown => "Unknown",
NonModelTypes::UnknownText => "UnknownText",
}
}
}
impl Display for NonModelTypes {
    /// Writes the display name given by the `&'static str` conversion.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(<&'static str>::from(*self))
    }
}
/// Summary counters returned by `file_sorting`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FileSortingResults {
    /// Number of regular files encountered during the walk.
    pub total_files: usize,
    /// Number of files skipped because a link for identical content (same
    /// SHA-256) already existed. Despite the name, no files are deleted.
    pub files_removed: usize,
    /// Number of problems encountered (e.g. files that could not be read).
    pub errors: usize,
}
/// Builds a content-addressed view of every file under `origin` inside `destination`.
///
/// The walk follows symlinks and is limited to `MAX_RECURSION_DEPTH` levels.
/// Each regular file it finds goes through these steps:
/// - it is read, classified with `FileTypeUnion::from_bytes`, and hashed with SHA-256;
/// - a symbolic link to the file is created at
///   `destination/<file type>/<hash_depth(hash, depth)>/<hash>`;
/// - if a live entry already exists at that path, the file is counted as a
///   duplicate and skipped.
///
/// # Errors
/// Returns an error in these cases:
/// - the platform has no symlink support;
/// - a destination directory cannot be created;
/// - a stale link cannot be removed or a new link cannot be created.
///
/// Unreadable files and directory-walk failures do not abort the run. They are
/// counted in `FileSortingResults::errors`.
pub fn file_sorting<P: AsRef<Path>>(
    origin: P,
    destination: P,
    depth: u8,
) -> Result<FileSortingResults> {
    #[cfg(not(any(unix, windows)))]
    anyhow::bail!(
        "File Sorting is not supported on this platform {} due to missing symbolic link support.\nPlease file an issue at https://github.com/rjzak/malware-modeler-rs/issues/new/choose.",
        std::env::consts::OS
    );

    // Symlink targets are resolved relative to the directory containing the
    // link, not the current working directory, so walking a relative `origin`
    // would create dangling links. Walk an absolute path instead. If it cannot
    // be canonicalized (e.g. it does not exist), keep the original path so the
    // walk behaves as before.
    let origin = std::fs::canonicalize(origin.as_ref())
        .unwrap_or_else(|_| origin.as_ref().to_path_buf());
    let destination = destination.as_ref();

    let mut total_files = 0;
    let mut duplicate_files = 0;
    let mut errors = 0;
    for entry in WalkDir::new(&origin)
        .max_depth(MAX_RECURSION_DEPTH)
        .follow_links(true)
    {
        // Permission problems, symlink loops, etc. are counted instead of
        // being silently dropped.
        let Ok(entry) = entry else {
            errors += 1;
            continue;
        };
        // With `follow_links(true)` this is the type of the link target.
        if !entry.file_type().is_file() {
            continue;
        }
        total_files += 1;
        let Ok(contents) = std::fs::read(entry.path()) else {
            errors += 1;
            continue;
        };
        let file_type = FileTypeUnion::from_bytes(&contents);
        let hash = hex::encode(Sha256::digest(contents));
        let mut destination_file = destination.join(file_type.to_string());
        destination_file.push(hash_depth(&hash, depth));
        std::fs::create_dir_all(&destination_file)?;
        destination_file.push(hash);

        if destination_file.exists() {
            duplicate_files += 1;
            continue;
        }
        // `exists()` follows symlinks. A link left by an earlier run whose target
        // has since moved therefore reports `false`, yet still occupies the path.
        // Replace it rather than aborting with `AlreadyExists`.
        if matches!(destination_file.symlink_metadata(), Ok(meta) if meta.file_type().is_symlink())
        {
            std::fs::remove_file(&destination_file)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(entry.path(), &destination_file)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_file(entry.path(), &destination_file)?;
    }
    Ok(FileSortingResults {
        total_files,
        errors,
        files_removed: duplicate_files,
    })
}
/// Builds a nested directory path from the leading characters of `hash`.
///
/// Each level is the next two-character chunk of `hash`, e.g. depth 3 of
/// `"9d6dc1..."` gives `9d/6d/c1`. The depth is capped at `hash.len() / 2`
/// levels. Depth 0 or an empty hash yields an empty path.
///
/// Construction stops early, without panicking, if a chunk would split a
/// multi-byte UTF-8 character. That cannot happen for hex digests.
#[inline]
#[must_use]
pub fn hash_depth(hash: &str, depth: u8) -> PathBuf {
    // Compute in `usize`: casting `hash.len() / 2` down to `u8` would wrap for
    // inputs of 512+ bytes and silently produce the wrong depth.
    let levels = usize::from(depth).min(hash.len() / 2);
    (0..levels)
        .map_while(|level| hash.get(level * 2..level * 2 + 2))
        .collect()
}
/// Pins `hash_depth` output for short (40-char) and long (128-char) hex hashes.
#[test]
fn hash_depth_test() {
    const HASH: &str = "9d6dc11990a109cd82d4dbafb6588b1b18e0e46b";
    const HASH_512: &str = "fedc3e4d500fd9f3a52c05549a53f0f82ae684167033699e87ebe018517ceeb265136de09aa7e1fce5bbce0b8a4ead89170a99a5bdb2b5f7d1f02a81e3178af2";

    // Depth 0 contributes no components, so the file sits directly in the parent.
    let mut joined = PathBuf::from("MyDir");
    joined.push(hash_depth(HASH, 0));
    joined.push("my_file.txt");
    assert_eq!(joined, PathBuf::from("MyDir/my_file.txt"));

    // (hash, depth, expected path); trailing slashes do not affect `Path` equality.
    let cases: [(&str, u8, &str); 9] = [
        (HASH, 0, ""),
        (HASH, 1, "9d/"),
        (HASH, 2, "9d/6d/"),
        (HASH, 3, "9d/6d/c1/"),
        (HASH, 4, "9d/6d/c1/19/"),
        (
            HASH,
            255,
            "9d/6d/c1/19/90/a1/09/cd/82/d4/db/af/b6/58/8b/1b/18/e0/e4/6b/",
        ),
        (HASH_512, 0, ""),
        (HASH_512, 1, "fe/"),
        (HASH_512, 2, "fe/dc/"),
    ];
    for &(hash, depth, expected) in &cases {
        assert_eq!(hash_depth(hash, depth), PathBuf::from(expected));
    }

    // Depth beyond the hash length is capped at one level per byte pair.
    assert_eq!(
        hash_depth(HASH_512, 255),
        PathBuf::from(
            "fe/dc/3e/4d/50/0f/d9/f3/a5/2c/05/54/9a/53/f0/f8/2a/e6/84/16/70/33/69/9e/87/eb/e0/18/51/7c/ee/b2/65/13/6d/e0/9a/a7/e1/fc/e5/bb/ce/0b/8a/4e/ad/89/17/0a/99/a5/bd/b2/b5/f7/d1/f0/2a/81/e3/17/8a/f2"
        )
    );
}