malware-modeler 0.0.2

Train logistic regression models for benign vs. malicious files based on byte n-grams and publish research.
Documentation
// SPDX-License-Identifier: Apache-2.0

use crate::dataset::Dataset;
use crate::ftype::FileType;
use crate::{Bytes, MAX_RECURSION_DEPTH};

use std::collections::HashMap;
use std::fs::File;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::io::{self, BufRead, BufReader, Lines, Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};

use anyhow::{bail, ensure, Result};
use dashmap::{DashMap, DashSet};
use rayon::prelude::*;
use walkdir::WalkDir;

/// Size, in bytes, of the buffer used when reading a file in chunks to extract n-grams.
pub(crate) const NGRAM_BUFFER_SIZE: usize = 4096;

/// Hash `t` with a freshly created [`DefaultHasher`] and return the 64-bit digest.
#[inline]
fn calculate_hash<T: Hash>(t: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    Hash::hash(t, &mut hasher);
    Hasher::finish(&hasher)
}

// Open `filename` and hand back a lazy iterator over its lines.
// Opening may fail, hence the outer `io::Result`; every yielded line is itself an
// `io::Result<String>` because reading can fail part-way through the file.
// Pattern taken from:
// https://doc.rust-lang.org/rust-by-example/std_misc/file/read_lines.html
fn read_lines<P: AsRef<Path>>(filename: P) -> io::Result<Lines<BufReader<File>>> {
    File::open(filename).map(|handle| BufReader::new(handle).lines())
}

/// N-grams as read from a file
///
/// Built by `NgramsFile::load` from a plain text file holding one hex-encoded
/// n-gram per line plus a `# File type:` header line.
#[derive(Clone)]
pub struct NgramsFile {
    /// N-grams and their index position
    ///
    /// The index is assigned by `load` in the order the n-grams appear in the file.
    pub ngrams: HashMap<Bytes, usize>,

    /// File type represented by the n-grams
    pub ftype: FileType,

    /// Convenience: keep track of the size of the n-grams (value of `n`)
    ///
    /// Measured in bytes, i.e. half the number of hex characters per line.
    pub n: usize,
}

impl NgramsFile {
    /// Load the n-grams from a plain text file
    ///
    /// # Errors
    ///
    /// An error results if:
    /// * The file cannot be read
    /// * The file type is not specified
    /// * The n-grams aren't hexidecimal
    /// * The n-grams are different sizes
    pub fn load<P: AsRef<Path>>(ngrams_file: P) -> Result<Self> {
        let mut file_type = FileType::NotSet;
        let mut n = 0;
        let mut ngrams = HashMap::new();

        let lines = read_lines(&ngrams_file)?;
        let mut ngram_counter = 0;
        let mut line_counter = 0u32;
        for line in lines.map_while(Result::ok) {
            if line.starts_with("# File type:") {
                if let Ok(ftype) = Dataset::file_type_from_line(&line) {
                    file_type = ftype;
                }
                line_counter += 1;
                continue;
            }

            let line = if let Some(l) = line.split(',').next() {
                l
            } else {
                &line
            };

            if !line.len().is_multiple_of(2) {
                bail!("Line {line_counter} {line} has odd number of characters.");
            }

            if n == 0 {
                n = line.len() / 2;
            } else if line.len() / 2 != n {
                bail!(
                    "Line {line_counter} {line} has unexpected length of {} bytes, expected {n}",
                    line.len() / 2
                );
            }

            match hex::decode(line) {
                Ok(ngram) => {
                    ngrams.insert(ngram, ngram_counter);
                    ngram_counter += 1;
                }
                Err(e) => bail!("Line {line_counter} has non-hexidecimal ngram {line}: {e}"),
            }
            line_counter += 1;
        }

        ensure!(
            !ngrams.is_empty(),
            "No n-grams read from {}.",
            ngrams_file.as_ref().display()
        );
        ensure!(
            file_type != FileType::NotSet,
            "No file type specified in n-grams file."
        );

        Ok(Self {
            ngrams,
            ftype: file_type,
            n,
        })
    }

    /// Convert the n-gram map into a vector, preserving their order
    ///
    /// Every n-gram is placed at the position given by its stored index. When the
    /// indices are the dense range `0..ngrams.len()` the result holds exactly the
    /// n-grams in index order. If the indices have gaps (the `ngrams` field is public
    /// and may be edited by callers), the vector is extended to the highest index and
    /// unused positions are left as empty byte vectors instead of panicking.
    #[must_use]
    pub fn into_vec(self) -> Vec<Bytes> {
        // Size by the largest index as well as the entry count so that indexing below
        // can never run past the end of the vector.
        let highest = self.ngrams.values().copied().max();
        let len = highest.map_or(0, |index| self.ngrams.len().max(index + 1));

        let mut ngrams_vec = vec![Vec::new(); len];
        for (ngram, index) in self.ngrams {
            ngrams_vec[index] = ngram;
        }
        ngrams_vec
    }
}

/// N-gramming object
///
/// Created with `Ngrammer::new`, which collects the files to inspect; `find`
/// then fills `ngrams`, and `save` writes them out as text.
pub struct Ngrammer {
    /// Size of the byte sequence
    n: u16,

    /// Number of n-grams to keep
    k: usize,

    /// File paths to be viewed
    paths: Vec<PathBuf>,

    /// Expected file type
    ftype: FileType,

    /// N-grams discovered
    ///
    /// Empty until `find` runs; the value stored for each n-gram is the tally of
    /// the hash bucket it fell into.
    ngrams: DashMap<Bytes, usize>,
}

impl Ngrammer {
    /// Number of hash buckets `find` uses to tally n-gram occurrences.
    #[cfg(not(feature = "low-memory"))]
    const COUNTS: usize = 2_000_000;
    /// Smaller bucket table used when the `low-memory` feature is enabled.
    #[cfg(feature = "low-memory")]
    const COUNTS: usize = 100_000;

    /// Return an n-gram calculation object
    ///
    /// # Errors
    ///
    /// An error will occur if the provided directories don't exist or can't be traversed.
    pub fn new(ftype: Option<FileType>, dir: &Path, n: u16, k: usize) -> Result<Self> {
        let mut paths = Vec::new();
        let mut ftype = ftype;

        for entry in WalkDir::new(dir)
            .max_depth(MAX_RECURSION_DEPTH)
            .follow_links(true)
            .into_iter()
            .flatten()
        {
            if entry.file_type().is_file() {
                match ftype {
                    Some(file_type) => {
                        if !file_type.matches_path(entry.path())? {
                            bail!(
                                "File {} does not match expected type {file_type:?}",
                                entry.path().display()
                            );
                        }
                    }
                    None => {
                        if let Some(detected_type) = FileType::from_path(entry.path())? {
                            ftype = Some(detected_type);
                        } else {
                            bail!("Unknown file type for {}", entry.path().display());
                        }
                    }
                }

                paths.push(entry.into_path());
            }
        }

        ensure!(!paths.is_empty(), "No files found!");
        if let Some(ftype) = ftype {
            Ok(Self {
                n,
                k,
                paths,
                ftype,
                ngrams: DashMap::new(),
            })
        } else {
            bail!("File type not provided and not detected!");
        }
    }

    /// Find the n-grams storing the results internally
    #[allow(clippy::cast_possible_truncation)]
    pub fn find(&mut self) {
        let data: Vec<AtomicUsize> = vec![0usize; Self::COUNTS]
            .into_iter()
            .map(AtomicUsize::new)
            .collect();

        self.paths.par_iter().for_each(|p| {
            for ngrams in self.find_ngram(p).unwrap_or_default() {
                let index = calculate_hash(&ngrams) as usize % Self::COUNTS;
                data[index].fetch_add(1, Ordering::Relaxed);
            }
        });

        let min_count = if self.k < Self::COUNTS {
            let mut sorted: Vec<usize> = data
                .iter()
                .map(|v| v.load(Ordering::Relaxed))
                .collect::<Vec<usize>>();
            sorted.par_sort();
            sorted[sorted.len() - self.k]
        } else {
            1
        };

        let kept_ngrams = DashMap::with_capacity(self.k);
        self.paths.par_iter().for_each(|p| {
            let file_size = match p.metadata() {
                Ok(metadata) => metadata.len(),
                Err(_) => return,
            };

            let Ok(mut file) = File::open(p) else { return };
            let mut buffer = vec![0; NGRAM_BUFFER_SIZE];
            let n = u64::from(self.n);

            loop {
                let read_count = file.read(&mut buffer).unwrap_or(0);
                let position = file.stream_position().unwrap_or_default();
                if position < n {
                    // Skip files which are too short or if we failed to read
                    break;
                }
                file.seek(SeekFrom::Start(position - n)).unwrap_or_default();

                for index in 0..read_count - self.n as usize {
                    let bytes = &buffer[index..index + self.n as usize];
                    let index = calculate_hash(&bytes) as usize % Self::COUNTS;
                    let count = data[index].load(Ordering::Relaxed);
                    if count >= min_count {
                        kept_ngrams.insert(bytes.to_vec(), count);
                        if kept_ngrams.len() >= self.k {
                            break;
                        }
                    }
                }

                if position >= file_size {
                    break;
                }
            }
        });

        self.ngrams = kept_ngrams;
    }

    fn find_ngram(&self, path: &Path) -> Result<DashSet<Bytes>> {
        let ngrams = DashSet::new();
        let file_size = path.metadata()?.len();

        let mut buffer = vec![0; NGRAM_BUFFER_SIZE];
        let mut file = File::open(path)?;
        let n = u64::from(self.n);

        loop {
            let read_count = file.read(&mut buffer)?;
            let position = file.stream_position()?;
            if position < n {
                // Skip files which are too short
                break;
            }
            file.seek(SeekFrom::Start(position - n))?;

            for index in 0..read_count - self.n as usize {
                let bytes = &buffer[index..index + self.n as usize];
                ngrams.insert(Vec::from(bytes));
            }

            if position >= file_size {
                break;
            }
        }

        Ok(ngrams)
    }

    /// Save the n-grams as a plain text file
    ///
    /// # Errors
    ///
    /// An error occurs if the file cannot be opened or cannot be written.
    pub fn save<P: AsRef<Path>>(&self, path: P, counts: bool) -> Result<()> {
        let mut output = File::create(path)?;
        writeln!(output, "# File type: {}", self.ftype)?;
        for entry in &self.ngrams {
            if counts {
                writeln!(output, "{},{}", hex::encode(entry.key()), entry.value())?;
            } else {
                writeln!(output, "{}", hex::encode(entry.key()))?;
            }
        }
        output.flush()?;
        Ok(())
    }

    /// Return the n-grams discovered
    ///
    /// The map is empty until `find` has been called; each value is the count
    /// recorded for the n-gram by `find`.
    #[inline]
    #[must_use]
    pub fn ngrams(&self) -> &DashMap<Bytes, usize> {
        &self.ngrams
    }

    /// Return the expected or discovered file type
    ///
    /// This is the type passed to `new`, or, when none was passed, the type `new`
    /// detected from the files it collected.
    #[inline]
    #[must_use]
    pub fn ftype(&self) -> FileType {
        self.ftype
    }
}