use std::collections::{BTreeMap, HashMap, hash_map};
use std::fmt::{Debug, Formatter};
use std::fs;
use std::io::Read;
use std::mem::transmute;
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::slice::Iter;
use std::sync::Once;
use std::sync::atomic::AtomicU64;
use std::time::Duration;
use bitvec::prelude::*;
use memmap2::{Mmap, MmapOptions};
use protobuf::{CodedInputStream, MessageDyn};
use thiserror::Error;
use crate::compiler::{RuleId, Rules};
use crate::models::Rule;
use crate::modules::{BUILTIN_MODULES, Module, ModuleError};
use crate::scanner::context::create_wasm_store_and_ctx;
use crate::types::{Struct, TypeValue};
use crate::variables::VariableError;
use crate::wasm::MATCHING_RULES_BITMAP_BASE;
use crate::wasm::runtime::Store;
use crate::{Variable, modules};
pub(crate) use crate::scanner::context::RuntimeObject;
pub(crate) use crate::scanner::context::RuntimeObjectHandle;
pub(crate) use crate::scanner::context::ScanContext;
pub(crate) use crate::scanner::context::ScanState;
pub(crate) use crate::scanner::matches::Match;
mod context;
mod matches;
pub mod blocks;
#[cfg(test)]
mod tests;
/// Errors that can be returned while configuring a [`Scanner`] or scanning.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ScanError {
/// The scan timed out. NOTE(review): presumably reported when the timeout
/// set with [`Scanner::set_timeout`] expires — confirm in the scan context.
#[error("timeout")]
Timeout,
/// The file to scan could not be opened. `load_file` also reports read
/// failures on an already-opened file with this variant.
#[error("can not open `{path}`: {err}")]
OpenError {
/// Path of the file.
path: PathBuf,
/// Underlying I/O error.
err: std::io::Error,
},
/// The file to scan could not be memory-mapped.
#[error("can not map `{path}`: {err}")]
MapError {
/// Path of the file.
path: PathBuf,
/// Underlying I/O error.
err: std::io::Error,
},
/// Bytes passed to [`Scanner::set_module_output_raw`] could not be parsed
/// as the module's root protobuf message.
#[error(
"can not deserialize protobuf message for YARA module `{module}`: {err}"
)]
ProtoError {
/// Module name as given by the caller.
module: String,
/// Parsing error reported by the `protobuf` crate.
err: protobuf::Error,
},
/// A module name or protobuf message does not correspond to any builtin
/// module.
#[error("unknown module `{module}`")]
UnknownModule {
/// The unrecognized module name or protobuf message full name.
module: String,
},
/// The main function of a module returned an error during a scan.
#[error("error in module `{module}`: {err}")]
ModuleError {
/// Name of the module.
module: String,
/// Error returned by the module's main function.
err: ModuleError,
},
}
// NOTE(review): neither static is referenced in this file; they are presumably
// used by the `context` submodule (e.g. to drive scan timeouts) — confirm.
// Shared counter, starts at 0.
static HEARTBEAT_COUNTER: AtomicU64 = AtomicU64::new(0);
// One-time initialization guard, presumably for whatever updates the counter.
static INIT_HEARTBEAT: Once = Once::new();
/// The data being scanned: borrowed from the caller, owned in a buffer, or
/// memory-mapped from a file.
pub enum ScannedData<'d> {
/// Bytes borrowed from the caller (e.g. [`Scanner::scan`]).
Slice(&'d [u8]),
/// File contents read into memory by `Scanner::load_file`.
Vec(Vec<u8>),
/// File mapping created by `Scanner::load_file` with
/// `MmapOptions::map_copy_read_only`.
Mmap(Mmap),
}
impl AsRef<[u8]> for ScannedData<'_> {
fn as_ref(&self) -> &[u8] {
match self {
ScannedData::Slice(s) => s,
ScannedData::Vec(v) => v.as_ref(),
ScannedData::Mmap(m) => m.as_ref(),
}
}
}
/// Borrows a byte slice as scan input.
///
/// The conversion never fails; its error type is [`ScanError`] so it composes
/// with `?` in scanning code. It is implemented as `TryFrom` rather than
/// `TryInto`, as the standard library recommends: the blanket impl still
/// provides `TryInto<ScannedData>` for `&[u8]`, so `.try_into()` callers are
/// unaffected.
impl<'d> TryFrom<&'d [u8]> for ScannedData<'d> {
    type Error = ScanError;

    fn try_from(data: &'d [u8]) -> Result<Self, Self::Error> {
        Ok(ScannedData::Slice(data))
    }
}

/// Borrows a fixed-size byte array as scan input (never fails). See the
/// slice conversion above for why this is `TryFrom`.
impl<'d, const N: usize> TryFrom<&'d [u8; N]> for ScannedData<'d> {
    type Error = ScanError;

    fn try_from(data: &'d [u8; N]) -> Result<Self, Self::Error> {
        Ok(ScannedData::Slice(data))
    }
}
/// Timing information for one rule, returned by [`Scanner::slowest_rules`]
/// when the `rules-profiling` feature is enabled.
#[cfg(feature = "rules-profiling")]
pub struct ProfilingData<'r> {
/// Namespace the rule belongs to.
pub namespace: &'r str,
/// Rule identifier.
pub rule: &'r str,
/// Time spent evaluating the rule's condition.
/// NOTE(review): presumably accumulated across scans until
/// `clear_profiling_data` is called — confirm.
pub condition_exec_time: Duration,
/// Time attributed to pattern matching for this rule.
pub pattern_matching_time: Duration,
}
/// Per-scan options accepted by [`Scanner::scan_with_options`] and
/// [`Scanner::scan_file_with_options`].
#[derive(Debug, Default)]
pub struct ScanOptions<'a> {
    /// Metadata handed to each module's main function, keyed by module name.
    module_metadata: HashMap<&'a str, &'a [u8]>,
}

impl<'a> ScanOptions<'a> {
    /// Creates an empty set of options (no module metadata).
    pub fn new() -> Self {
        Self::default()
    }

    /// Attaches `metadata` for the module called `module_name`.
    ///
    /// Calling this again for the same module replaces the earlier metadata.
    pub fn set_module_metadata(
        mut self,
        module_name: &'a str,
        metadata: &'a [u8],
    ) -> Self {
        self.module_metadata
            .entry(module_name)
            .and_modify(|existing| *existing = metadata)
            .or_insert(metadata);
        self
    }
}
/// Scans data or files against a set of compiled [`Rules`].
pub struct Scanner<'r> {
/// The compiled rules. Not read in this file; it ties the scanner to `'r`.
_rules: &'r Rules,
/// WASM store whose data is the scan context. The context is stored with
/// `'static` lifetimes and reinterpreted as `'r` by `scan_context` /
/// `scan_context_mut`. NOTE(review): presumably pinned because the context
/// keeps a pointer back to the store (see `ctx.wasm_store` use in
/// `NonMatchingRules::new`) — confirm.
wasm_store: Pin<Box<Store<ScanContext<'static, 'static>>>>,
/// Whether `load_file` may memory-map large files.
use_mmap: bool,
}
impl<'r> Scanner<'r> {
    /// Creates a scanner that evaluates the given compiled `rules`.
    ///
    /// Memory-mapping of large files is enabled by default (see
    /// [`Scanner::use_mmap`]).
    pub fn new(rules: &'r Rules) -> Self {
        let wasm_store = create_wasm_store_and_ctx(rules);
        Self { _rules: rules, wasm_store, use_mmap: true }
    }

    /// Sets the timeout applied to scans performed by this scanner.
    ///
    /// The value is forwarded to the scan context; a scan that exceeds it
    /// presumably fails with [`ScanError::Timeout`].
    pub fn set_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.scan_context_mut().set_timeout(timeout);
        self
    }

    /// Sets the maximum number of matches recorded for each pattern.
    pub fn max_matches_per_pattern(&mut self, n: usize) -> &mut Self {
        self.scan_context_mut().pattern_matches.max_matches_per_pattern(n);
        self
    }

    /// Enables or disables memory-mapping in [`Scanner::scan_file`] and
    /// [`Scanner::scan_file_with_options`].
    ///
    /// Even when enabled, only files larger than 500,000,000 bytes are
    /// mapped; smaller files are always read into memory.
    pub fn use_mmap(&mut self, yes: bool) -> &mut Self {
        self.use_mmap = yes;
        self
    }

    /// Registers a callback that receives console log messages.
    ///
    /// NOTE(review): presumably invoked for messages emitted by rules through
    /// the `console` module — confirm.
    pub fn console_log<F>(&mut self, callback: F) -> &mut Self
    where
        F: FnMut(String) + 'r,
    {
        self.scan_context_mut().console_log = Some(Box::new(callback));
        self
    }

    /// Scans an in-memory byte slice.
    ///
    /// # Errors
    ///
    /// Returns [`ScanError::ModuleError`] if the main function of an imported
    /// module fails, or any error produced while evaluating the rule
    /// conditions (presumably [`ScanError::Timeout`]).
    pub fn scan<'a>(
        &'a mut self,
        data: &'a [u8],
    ) -> Result<ScanResults<'a, 'r>, ScanError> {
        // Borrowing a slice cannot fail, so build the variant directly (as
        // `scan_with_options` does) instead of a fallible conversion.
        self.scan_impl(ScannedData::Slice(data), None)
    }

    /// Scans the file at `target`.
    ///
    /// # Errors
    ///
    /// [`ScanError::OpenError`] if the file cannot be opened or read,
    /// [`ScanError::MapError`] if memory-mapping fails, plus the errors
    /// documented in [`Scanner::scan`].
    pub fn scan_file<'a, P>(
        &'a mut self,
        target: P,
    ) -> Result<ScanResults<'a, 'r>, ScanError>
    where
        P: AsRef<Path>,
    {
        let data = self.load_file(target.as_ref())?;
        self.scan_impl(data, None)
    }

    /// Like [`Scanner::scan`], but with per-scan [`ScanOptions`] such as
    /// module metadata.
    pub fn scan_with_options<'a, 'opts>(
        &'a mut self,
        data: &'a [u8],
        options: ScanOptions<'opts>,
    ) -> Result<ScanResults<'a, 'r>, ScanError> {
        self.scan_impl(ScannedData::Slice(data), Some(options))
    }

    /// Like [`Scanner::scan_file`], but with per-scan [`ScanOptions`].
    pub fn scan_file_with_options<'opts, P>(
        &mut self,
        target: P,
        options: ScanOptions<'opts>,
    ) -> Result<ScanResults<'_, 'r>, ScanError>
    where
        P: AsRef<Path>,
    {
        let data = self.load_file(target.as_ref())?;
        self.scan_impl(data, Some(options))
    }

    /// Sets the value of the global variable `ident` in the scan context.
    ///
    /// # Errors
    ///
    /// Returns the [`VariableError`] reported by the scan context, including
    /// errors converting `value` into a [`Variable`].
    pub fn set_global<T: TryInto<Variable>>(
        &mut self,
        ident: &str,
        value: T,
    ) -> Result<&mut Self, VariableError>
    where
        VariableError: From<<T as TryInto<Variable>>::Error>,
    {
        self.scan_context_mut().set_global(ident, value)?;
        Ok(self)
    }

    /// Provides the output of a module for the next scan.
    ///
    /// The message type identifies the module: its full name must match the
    /// root message of a builtin module. During the next scan, this output
    /// is used instead of calling the module's main function. Outputs are
    /// discarded when that scan finishes, whether or not they were used.
    ///
    /// # Errors
    ///
    /// [`ScanError::UnknownModule`] if no builtin module has this root
    /// message.
    pub fn set_module_output(
        &mut self,
        data: Box<dyn MessageDyn>,
    ) -> Result<&mut Self, ScanError> {
        let descriptor = data.descriptor_dyn();
        let full_name = descriptor.full_name();

        let is_builtin = BUILTIN_MODULES
            .values()
            .any(|module| module.root_struct_descriptor.full_name() == full_name);

        if !is_builtin {
            return Err(ScanError::UnknownModule {
                module: full_name.to_string(),
            });
        }

        self.scan_context_mut()
            .user_provided_module_outputs
            .insert(full_name.to_string(), data);

        Ok(self)
    }

    /// Like [`Scanner::set_module_output`], but takes the serialized
    /// protobuf message.
    ///
    /// `name` may be either the module name (as used in `import`) or the full
    /// name of the module's root protobuf message.
    ///
    /// # Errors
    ///
    /// [`ScanError::UnknownModule`] if `name` matches no builtin module, and
    /// [`ScanError::ProtoError`] if `data` cannot be parsed.
    pub fn set_module_output_raw(
        &mut self,
        name: &str,
        data: &[u8],
    ) -> Result<&mut Self, ScanError> {
        let descriptor = BUILTIN_MODULES
            .get(name)
            .or_else(|| {
                BUILTIN_MODULES.values().find(|module| {
                    module.root_struct_descriptor.full_name() == name
                })
            })
            .map(|module| &module.root_struct_descriptor)
            .ok_or_else(|| ScanError::UnknownModule {
                module: name.to_string(),
            })?;

        let mut is = CodedInputStream::from_bytes(data);
        // Bound message nesting depth while parsing caller-supplied bytes.
        is.set_recursion_limit(500);

        let output = descriptor.parse_from(&mut is).map_err(|err| {
            ScanError::ProtoError { module: name.to_string(), err }
        })?;

        self.set_module_output(output)
    }

    /// Returns profiling data for the `n` slowest rules.
    #[cfg(feature = "rules-profiling")]
    pub fn slowest_rules(&self, n: usize) -> Vec<ProfilingData<'_>> {
        self.scan_context().slowest_rules(n)
    }

    /// Clears the profiling data collected so far.
    #[cfg(feature = "rules-profiling")]
    pub fn clear_profiling_data(&mut self) {
        self.scan_context_mut().clear_profiling_data()
    }
}
impl<'r> Scanner<'r> {
/// Returns a shared reference to the scan context stored as the WASM store's
/// data, with its lifetimes narrowed from `'static` to `'r`.
#[cfg(feature = "rules-profiling")]
#[inline]
fn scan_context<'a>(&self) -> &ScanContext<'r, 'a> {
// SAFETY (review note): the store keeps the context typed as
// `ScanContext<'static, 'static>`; reinterpreting it with shorter lifetimes
// presumably relies on the scanner (which holds `&'r Rules`) never
// outliving `'r` — confirm against `create_wasm_store_and_ctx`.
unsafe {
transmute::<&ScanContext<'static, 'static>, &ScanContext<'r, '_>>(
self.wasm_store.data(),
)
}
}
/// Mutable counterpart of `scan_context`. It is used by the public setters
/// and by `scan_impl`.
#[inline]
fn scan_context_mut<'a>(&mut self) -> &mut ScanContext<'r, 'a> {
// SAFETY (review note): as in `scan_context`. Through this reference,
// values that only live for `'r` (e.g. the `console_log` callback) are
// stored in a context typed as `'static`, which is only sound while the
// store does not outlive `'r` — confirm.
unsafe {
transmute::<
&mut ScanContext<'static, 'static>,
&mut ScanContext<'r, '_>,
>(self.wasm_store.data_mut())
}
}
/// Opens the file at `path` and returns its contents as scan input.
///
/// When `use_mmap` is enabled and the file is larger than 500_000_000 bytes,
/// it is memory-mapped; otherwise it is read entirely into a `Vec`.
///
/// Errors: `ScanError::OpenError` if the file cannot be opened or read,
/// `ScanError::MapError` if mapping fails.
fn load_file<'a>(
&self,
path: &Path,
) -> Result<ScannedData<'a>, ScanError> {
let mut file = fs::File::open(path).map_err(|err| {
ScanError::OpenError { path: path.to_path_buf(), err }
})?;
// If the metadata cannot be read the size is taken as 0, so the file is
// read into memory rather than mapped.
let size = file.metadata().map(|m| m.len()).unwrap_or(0);
// Declared up front and initialized in exactly one branch below.
let mut buffered_file;
let mapped_file;
let data = if self.use_mmap && size > 500_000_000 {
// NOTE(review): memmap2 marks mapping as unsafe because the file could
// be modified while mapped; nothing here prevents that.
mapped_file = unsafe {
MmapOptions::new().map_copy_read_only(&file).map_err(|err| {
ScanError::MapError { path: path.to_path_buf(), err }
})
}?;
ScannedData::Mmap(mapped_file)
} else {
// `size` is only a capacity hint. NOTE(review): `as usize` truncates on
// 32-bit targets for files over 4 GiB.
buffered_file = Vec::with_capacity(size as usize);
// Read failures are reported as `OpenError` too.
file.read_to_end(&mut buffered_file).map_err(|err| {
ScanError::OpenError { path: path.to_path_buf(), err }
})?;
ScannedData::Vec(buffered_file)
};
Ok(data)
}
/// Runs a complete scan over `data`.
///
/// The steps are:
/// 1. Reset the scan context and store the data.
/// 2. Build the output of every module imported by the rules, using a
///    user-provided output when one was set, or otherwise calling the
///    module's main function with the metadata from `options`.
/// 3. Evaluate the rule conditions.
/// 4. Move the data into the finished scan state.
///
/// Panics if the rules import a module missing from `BUILTIN_MODULES`.
fn scan_impl<'a, 'opts>(
&'a mut self,
data: ScannedData<'a>,
options: Option<ScanOptions<'opts>>,
) -> Result<ScanResults<'a, 'r>, ScanError> {
let ctx = self.scan_context_mut();
ctx.reset();
ctx.set_filesize(data.as_ref().len() as i64);
ctx.scan_state = ScanState::ScanningData(data);
for module_name in ctx.compiled_rules.imports() {
let module = modules::BUILTIN_MODULES
.get(module_name)
.unwrap_or_else(|| panic!("module `{module_name}` not found"));
let root_struct_name = module.root_struct_descriptor.full_name();
let module_output;
// Outputs supplied via `set_module_output` take precedence over the
// module's main function.
if let Some(output) =
ctx.user_provided_module_outputs.remove(root_struct_name)
{
module_output = Some(output);
} else {
let meta: Option<&'opts [u8]> =
options.as_ref().and_then(|options| {
options.module_metadata.get(module_name).copied()
});
// Modules without a main function produce no output.
if let Some(main_fn) = module.main_fn {
module_output = Some(
main_fn(ctx.scanned_data().unwrap(), meta).map_err(
|err| ScanError::ModuleError {
module: module_name.to_string(),
err,
},
)?,
);
} else {
module_output = None;
}
}
// Debug-only sanity checks on the output's message type and required
// fields.
if let Some(module_output) = &module_output {
debug_assert_eq!(
module_output.descriptor_dyn().full_name(),
module.root_struct_descriptor.full_name(),
"main function of module `{}` must return `{}`, but returned `{}`",
module_name,
module.root_struct_descriptor.full_name(),
module_output.descriptor_dyn().full_name(),
);
debug_assert!(
module_output.is_initialized_dyn(),
"module `{}` returned a protobuf `{}` where some required fields are not initialized ",
module_name,
module.root_struct_descriptor.full_name()
);
}
// Enum fields are generated only when the `constant-folding` feature is
// disabled.
let generate_fields_for_enums =
!cfg!(feature = "constant-folding");
let module_struct = Struct::from_proto_descriptor_and_msg(
&module.root_struct_descriptor,
module_output.as_deref(),
generate_fields_for_enums,
);
if let Some(module_output) = module_output {
ctx.module_outputs
.insert(root_struct_name.to_string(), module_output);
}
ctx.root_struct
.add_field(module_name, TypeValue::Struct(module_struct));
}
// User-provided outputs for modules that were not imported are dropped.
ctx.user_provided_module_outputs.clear();
// NOTE(review): presumably lets condition evaluation trigger the pattern
// search on demand — confirm in `ScanContext`.
ctx.set_pattern_search_done(false);
ctx.eval_conditions()?;
// Move the scanned data into the finished state so results can read it.
let data = match ctx.scan_state.take() {
ScanState::ScanningData(data) => data,
_ => unreachable!(),
};
ctx.scan_state = ScanState::Finished(DataSnippets::SingleBlock(data));
Ok(ScanResults::new(ctx))
}
}
/// Data retained after a scan finishes (see `Scanner::scan_impl`), from which
/// byte ranges can be retrieved with [`DataSnippets::get`].
pub(crate) enum DataSnippets<'d> {
    /// The whole scanned data as one contiguous block.
    SingleBlock(ScannedData<'d>),
    /// Non-contiguous snippets keyed by the absolute offset where each starts.
    MultiBlock(BTreeMap<usize, Vec<u8>>),
}

impl DataSnippets<'_> {
    /// Returns the bytes covered by `range` (absolute offsets), or `None` if
    /// they are not fully available.
    ///
    /// For `MultiBlock`, the range must lie entirely inside the snippet with
    /// the greatest start offset `<= range.start`. An inverted range
    /// (`end < start`) yields `None` for both variants instead of panicking.
    pub(crate) fn get(&self, range: Range<usize>) -> Option<&[u8]> {
        match self {
            Self::SingleBlock(data) => data.as_ref().get(range),
            Self::MultiBlock(btree) => {
                let (snippet_offset, snippet_data) =
                    btree.range(..=range.start).next_back()?;
                // `snippet_offset <= range.start` by construction, but an
                // inverted range can end before the snippet starts; a
                // checked subtraction returns `None` there instead of
                // overflowing.
                let start = range.start - snippet_offset;
                let end = range.end.checked_sub(*snippet_offset)?;
                snippet_data.get(start..end)
            }
        }
    }
}
/// Results of a scan. They borrow the scanner that produced them for `'a`.
pub struct ScanResults<'a, 'r> {
    ctx: &'a ScanContext<'r, 'a>,
}

impl Debug for ScanResults<'_, '_> {
    /// Prints only the type name; the scan context is not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "ScanResults")
    }
}

impl<'a, 'r> ScanResults<'a, 'r> {
    fn new(ctx: &'a ScanContext<'r, 'a>) -> Self {
        ScanResults { ctx }
    }

    /// Returns an iterator over the rules that matched. Private rules are
    /// excluded unless requested with [`MatchingRules::include_private`].
    pub fn matching_rules(&self) -> MatchingRules<'_, 'r> {
        MatchingRules::new(self.ctx)
    }

    /// Returns an iterator over the rules that did not match. Private rules
    /// are excluded unless requested with
    /// [`NonMatchingRules::include_private`].
    pub fn non_matching_rules(&self) -> NonMatchingRules<'_, 'r> {
        NonMatchingRules::new(self.ctx)
    }

    /// Returns the output produced by the builtin module `module_name`, or
    /// `None` if the module is unknown or produced no output in this scan.
    pub fn module_output(
        &self,
        module_name: &str,
    ) -> Option<&'a dyn MessageDyn> {
        let ctx: &'a ScanContext<'r, 'a> = self.ctx;
        let module = BUILTIN_MODULES.get(module_name)?;
        ctx.module_outputs
            .get(module.root_struct_descriptor.full_name())
            .map(|output| &**output)
    }

    /// Returns an iterator over `(module name, output)` pairs for every
    /// module that produced output in this scan.
    pub fn module_outputs(&self) -> ModuleOutputs<'a, 'r> {
        ModuleOutputs::new(self.ctx)
    }
}
/// Iterator over the rules that matched during a scan, returned by
/// [`ScanResults::matching_rules`]. Private rules are skipped unless
/// [`MatchingRules::include_private`] is called with `true`.
pub struct MatchingRules<'a, 'r> {
    ctx: &'a ScanContext<'r, 'a>,
    /// IDs of the matching rules, in the order stored in the scan context.
    iterator: Iter<'a, RuleId>,
    /// Matching non-private rules not yet consumed.
    len_non_private: usize,
    /// Matching private rules not yet consumed.
    len_private: usize,
    include_private: bool,
}

impl<'a, 'r> MatchingRules<'a, 'r> {
    fn new(ctx: &'a ScanContext<'r, 'a>) -> Self {
        let matching = &ctx.matching_rules;
        let len_private = ctx.num_matching_private_rules;
        let len_non_private = matching.len() - len_private;
        Self {
            ctx,
            iterator: matching.iter(),
            len_non_private,
            len_private,
            include_private: false,
        }
    }

    /// Chooses whether private rules are yielded as well (default: no).
    pub fn include_private(self, yes: bool) -> Self {
        Self { include_private: yes, ..self }
    }
}
impl<'a, 'r> Iterator for MatchingRules<'a, 'r> {
    type Item = Rule<'a, 'r>;

    /// Yields the next matching rule, skipping private rules unless
    /// `include_private(true)` was requested. The remaining-count fields are
    /// updated for every rule consumed, whether it is yielded or skipped.
    fn next(&mut self) -> Option<Self::Item> {
        let rules = self.ctx.compiled_rules;
        loop {
            let rule_id = *self.iterator.next()?;
            let rule_info = rules.get(rule_id);
            if rule_info.is_private {
                self.len_private -= 1;
            } else {
                self.len_non_private -= 1;
            }
            if self.include_private || !rule_info.is_private {
                return Some(Rule { ctx: Some(self.ctx), rule_info, rules });
            }
        }
    }

    /// Returns the exact number of remaining items.
    ///
    /// `ExactSizeIterator` requires `size_hint` to be exact. The default
    /// `(0, None)` would contradict `len()` and defeat pre-allocation in
    /// `collect`.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = ExactSizeIterator::len(self);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for MatchingRules<'_, '_> {
    /// Number of rules still to be yielded, honoring `include_private`.
    #[inline]
    fn len(&self) -> usize {
        if self.include_private {
            self.len_non_private + self.len_private
        } else {
            self.len_non_private
        }
    }
}
/// Iterator over the rules that did not match during a scan, returned by
/// [`ScanResults::non_matching_rules`]. Private rules are skipped unless
/// [`NonMatchingRules::include_private`] is called with `true`.
pub struct NonMatchingRules<'a, 'r> {
ctx: &'a ScanContext<'r, 'a>,
/// Positions of the zero bits in the matching-rules bitmap; each position
/// is converted to a `RuleId`.
iterator: bitvec::slice::IterZeros<'a, u8, Lsb0>,
include_private: bool,
/// Non-matching private rules not yet consumed.
len_private: usize,
/// Non-matching non-private rules not yet consumed.
len_non_private: usize,
}
impl<'a, 'r> NonMatchingRules<'a, 'r> {
/// Builds the iterator from the matching-rules bitmap that lives in the
/// WASM main memory at `MATCHING_RULES_BITMAP_BASE` (one bit per rule,
/// least-significant bit first).
fn new(ctx: &'a ScanContext<'r, 'a>) -> Self {
let num_rules = ctx.compiled_rules.num_rules();
// NOTE(review): unsafe deref of the context's pointer to the WASM store;
// presumably valid because the store is pinned inside the scanner that
// `ctx` borrows from — confirm.
let main_memory = ctx
.wasm_main_memory
.unwrap()
.data(unsafe { ctx.wasm_store.as_ref() });
let base = MATCHING_RULES_BITMAP_BASE as usize;
// `num_rules / 8 + 1` bytes always cover `num_rules` bits.
let matching_rules_bitmap = BitSlice::<_, Lsb0>::from_slice(
&main_memory[base..base + num_rules / 8 + 1],
);
// Drop the padding bits past the last rule so they are not reported.
let matching_rules_bitmap = &matching_rules_bitmap[0..num_rules];
Self {
ctx,
iterator: matching_rules_bitmap.iter_zeros(),
include_private: false,
len_non_private: ctx.compiled_rules.num_rules()
- ctx.matching_rules.len()
- ctx.num_non_matching_private_rules,
len_private: ctx.num_non_matching_private_rules,
}
}
/// Chooses whether private rules are yielded as well (default: no).
pub fn include_private(mut self, yes: bool) -> Self {
self.include_private = yes;
self
}
}
impl<'a, 'r> Iterator for NonMatchingRules<'a, 'r> {
    type Item = Rule<'a, 'r>;

    /// Yields the next non-matching rule, skipping private rules unless
    /// `include_private(true)` was requested. The remaining-count fields are
    /// updated for every rule consumed, whether it is yielded or skipped.
    fn next(&mut self) -> Option<Self::Item> {
        let rules = self.ctx.compiled_rules;
        loop {
            let rule_id = RuleId::from(self.iterator.next()?);
            let rule_info = rules.get(rule_id);
            if rule_info.is_private {
                self.len_private -= 1;
            } else {
                self.len_non_private -= 1;
            }
            if self.include_private || !rule_info.is_private {
                return Some(Rule { ctx: Some(self.ctx), rule_info, rules });
            }
        }
    }

    /// Returns the exact number of remaining items.
    ///
    /// `ExactSizeIterator` requires `size_hint` to be exact. The default
    /// `(0, None)` would contradict `len()` and defeat pre-allocation in
    /// `collect`.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = ExactSizeIterator::len(self);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for NonMatchingRules<'_, '_> {
    /// Number of rules still to be yielded, honoring `include_private`.
    #[inline]
    fn len(&self) -> usize {
        if self.include_private {
            self.len_non_private + self.len_private
        } else {
            self.len_non_private
        }
    }
}
/// Iterator over `(module name, output)` pairs for the modules that produced
/// output in a scan, returned by [`ScanResults::module_outputs`].
pub struct ModuleOutputs<'a, 'r> {
    ctx: &'a ScanContext<'r, 'a>,
    /// Number of outputs not yet yielded.
    len: usize,
    iterator: hash_map::Iter<'a, &'a str, Module>,
}

impl<'a, 'r> ModuleOutputs<'a, 'r> {
    fn new(ctx: &'a ScanContext<'r, 'a>) -> Self {
        Self {
            ctx,
            len: ctx.module_outputs.len(),
            iterator: BUILTIN_MODULES.iter(),
        }
    }
}

impl ExactSizeIterator for ModuleOutputs<'_, '_> {
    /// Number of module outputs still to be yielded.
    #[inline]
    fn len(&self) -> usize {
        self.len
    }
}

impl<'a> Iterator for ModuleOutputs<'a, '_> {
    type Item = (&'a str, &'a dyn MessageDyn);

    /// Walks the builtin modules and yields those with an output in this
    /// scan.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let (name, module) = self.iterator.next()?;
            if let Some(module_output) = self
                .ctx
                .module_outputs
                .get(module.root_struct_descriptor.full_name())
            {
                // Keep `len()` equal to the number of items still to come.
                // The subtraction saturates so that two builtin modules that
                // share a root message cannot underflow the counter.
                self.len = self.len.saturating_sub(1);
                return Some((*name, module_output.as_ref()));
            }
        }
    }

    /// Returns the exact remaining count, as `ExactSizeIterator` requires.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }
}
#[cfg(test)]
mod snippet_tests {
    use super::DataSnippets;
    use std::collections::BTreeMap;

    /// Checks `DataSnippets::get` on a two-snippet `MultiBlock`, including
    /// ranges that run past a snippet's end.
    #[test]
    fn snippets() {
        let snippets = DataSnippets::MultiBlock(BTreeMap::from([
            (0, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]),
            (50, vec![51, 52, 53, 54]),
        ]));

        let cases: [(std::ops::Range<usize>, Option<&[u8]>); 8] = [
            (0..2, Some(&[1u8, 2][..])),
            (1..3, Some(&[2u8, 3][..])),
            (8..9, Some(&[9u8][..])),
            (9..10, None),
            (50..51, Some(&[51u8][..])),
            (50..54, Some(&[51u8, 52, 53, 54][..])),
            (52..54, Some(&[53u8, 54][..])),
            (50..56, None),
        ];

        for (range, expected) in cases {
            assert_eq!(snippets.get(range.clone()), expected, "range {:?}", range);
        }
    }
}