malware-modeler 0.0.1

Train logistic regression models for benign vs. malicious files based on byte n-grams and publish research.
Documentation
// SPDX-License-Identifier: Apache-2.0

#![doc = include_str!("../readme.md")]
#![deny(clippy::all)]
//#![deny(clippy::cargo)]
#![deny(clippy::pedantic)]
#![allow(clippy::doc_markdown)] // Clippy has issues with some names in the research list
#![deny(missing_docs)]
#![forbid(unsafe_code)]

/// Data structures and logic for storing training/inference data
pub mod dataset;

/// Data structure and logic for training a model and calculating predictions
pub mod model;

use std::fs::File;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::io::{Read, Seek, SeekFrom};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};

use anyhow::{ensure, Result};
use dashmap::{DashMap, DashSet};
use rayon::prelude::*;
use walkdir::WalkDir;

/// Malware Modeler version string, shaped as `v<crate version>-<git describe> <build date>`.
///
/// Every piece is taken from the compile-time environment: `CARGO_PKG_VERSION` comes from
/// Cargo, while the `VERGEN_*` values are presumably emitted by a `vergen` build script
/// (NOTE(review): confirm in `build.rs`). Compilation fails if any of them is unset.
pub const VERSION: &str = concat!(
    "v", env!("CARGO_PKG_VERSION"),
    "-", env!("VERGEN_GIT_DESCRIBE"),
    " ", env!("VERGEN_BUILD_DATE")
);

/// Convenience alias for an owned byte sequence (used as the n-gram key type).
pub type Bytes = Vec<u8>;

/// Maximum recursion depth when walking a directory structure.
pub const MAX_RECURSION_DEPTH: usize = 10;

/// Read-buffer size, in bytes, used when scanning files for n-grams.
const NGRAM_BUFFER_SIZE: usize = 4096;

/// Hash `t` with the standard library's [`DefaultHasher`].
///
/// Every `DefaultHasher::new()` starts from the same state, so equal values hash equally
/// for the lifetime of the program. The algorithm is not guaranteed to be stable across
/// Rust releases, so these hashes should not be persisted.
///
/// `T` may be unsized, so byte slices can be hashed directly. A `Vec<u8>` hashes
/// identically to the `[u8]` slice of its contents.
fn calculate_hash<T: Hash + ?Sized>(t: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    t.hash(&mut hasher);
    hasher.finish()
}

/// Byte n-gram extractor over a fixed set of files.
///
/// Built by [`Ngrammer::new`], which collects the regular files found under a directory;
/// [`Ngrammer::ngrams`] then selects frequently occurring n-grams across those files.
#[derive(Debug, Clone)]
pub struct Ngrammer {
    /// Length, in bytes, of each n-gram.
    n: u16,

    /// Number of n-grams to keep.
    k: usize,

    /// Files to scan for n-grams.
    paths: Vec<PathBuf>,
}

impl Ngrammer {
    const COUNTS: usize = 1_000_000;

    /// Return an n-gram calculation object
    ///
    /// # Errors
    ///
    /// An error will occur if the provided directories don't exist or can't be traversed.
    pub fn new(dir: &Path, n: u16, k: usize) -> Result<Self> {
        let mut paths = Vec::new();

        for entry in WalkDir::new(dir)
            .max_depth(MAX_RECURSION_DEPTH)
            .follow_links(true)
            .into_iter()
            .flatten()
        {
            if entry.file_type().is_file() {
                paths.push(entry.into_path());
            }
        }

        ensure!(!paths.is_empty(), "No files found!");
        Ok(Self { n, k, paths })
    }

    /// Return the bytes from the discovered paths along with their occurrence counts
    #[allow(clippy::cast_possible_truncation)]
    pub fn ngrams(&self) -> DashMap<Bytes, usize> {
        let data: Vec<AtomicUsize> = vec![0usize; Self::COUNTS]
            .into_iter()
            .map(AtomicUsize::new)
            .collect();

        self.paths.par_iter().for_each(|p| {
            for ngrams in self.find_ngram(p).unwrap_or_default() {
                let index = calculate_hash(&ngrams) as usize % Self::COUNTS;
                data[index].fetch_add(1, Ordering::Relaxed);
            }
        });

        let min_count = if self.k < Self::COUNTS {
            let mut sorted: Vec<usize> = data
                .iter()
                .map(|v| v.load(Ordering::Relaxed))
                .collect::<Vec<usize>>();
            sorted.par_sort();
            sorted[sorted.len() - self.k]
        } else {
            1
        };

        let kept_ngrams = DashMap::with_capacity(self.k);
        self.paths.par_iter().for_each(|p| {
            let file_size = match p.metadata() {
                Ok(metadata) => metadata.len(),
                Err(_) => return,
            };

            let Ok(mut file) = File::open(p) else { return };
            let mut buffer = vec![0; NGRAM_BUFFER_SIZE];
            let n = u64::from(self.n);

            loop {
                let read_count = file.read(&mut buffer).unwrap_or(0);
                let position = file.stream_position().unwrap_or_default();
                if position < n {
                    // Skip files which are too short or if we failed to read
                    break;
                }
                file.seek(SeekFrom::Start(position - n)).unwrap_or_default();

                for index in 0..read_count - self.n as usize {
                    let bytes = &buffer[index..index + self.n as usize];
                    let index = calculate_hash(&bytes) as usize % Self::COUNTS;
                    let count = data[index].load(Ordering::Relaxed);
                    if count >= min_count {
                        kept_ngrams.insert(bytes.to_vec(), count);
                        if kept_ngrams.len() >= self.k {
                            break;
                        }
                    }
                }

                if position >= file_size {
                    break;
                }
            }
        });

        kept_ngrams
    }

    fn find_ngram(&self, path: &Path) -> Result<DashSet<Bytes>> {
        let ngrams = DashSet::new();
        let file_size = path.metadata()?.len();

        let mut buffer = vec![0; NGRAM_BUFFER_SIZE];
        let mut file = File::open(path)?;
        let n = u64::from(self.n);

        loop {
            let read_count = file.read(&mut buffer)?;
            let position = file.stream_position()?;
            if position < n {
                // Skip files which are too short
                break;
            }
            file.seek(SeekFrom::Start(position - n))?;

            for index in 0..read_count - self.n as usize {
                let bytes = &buffer[index..index + self.n as usize];
                ngrams.insert(Vec::from(bytes));
            }

            if position >= file_size {
                break;
            }
        }

        Ok(ngrams)
    }
}