use std::collections::HashMap;
use std::fmt;
use std::fs;
use std::path::Path;
use serde::de::{self, Deserializer, MapAccess, Visitor};
use serde::{Deserialize, Serialize};
use serde_yaml::Value as Yaml;
use crate::document::Document;
use crate::optimiser::{self, Optimisations};
use crate::parser::{self, Expression};
use crate::solver;
use crate::tokeniser::{ModSym, Token, Tokeniser};
/// The detection section of a rule: a `condition` expression plus the named
/// identifiers that the condition refers to.
///
/// Both a parsed form (`expression`, `identifiers`) and the raw source form
/// (`expression_raw`, `identifiers_raw`) are kept. Only the raw form is
/// serialised; the parsed fields are skipped.
#[derive(Clone, Serialize)]
pub struct Detection {
/// The parsed `condition` expression.
#[serde(skip_serializing)]
pub expression: Expression,
/// Parsed identifier expressions, keyed by identifier name.
#[serde(skip_serializing)]
pub identifiers: HashMap<String, Expression>,
/// The `condition` string as written in the source; serialised under the
/// key `condition`.
#[serde(rename = "condition")]
expression_raw: String,
/// The raw YAML value of each identifier. Flattened, so each identifier is
/// serialised as its own key next to `condition`.
#[serde(flatten)]
identifiers_raw: HashMap<String, Yaml>,
}
impl fmt::Debug for Detection {
    /// Formats the detection using its raw source form: the `condition`
    /// string is shown as `expression` and the raw identifier YAML as
    /// `identifiers`. The parsed forms are deliberately not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Detection");
        builder.field("expression", &self.expression_raw);
        builder.field("identifiers", &self.identifiers_raw);
        builder.finish()
    }
}
impl<'de> Deserialize<'de> for Detection {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct DetectionVisitor;
impl<'de> Visitor<'de> for DetectionVisitor {
type Value = Detection;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("struct Detection")
}
fn visit_map<V>(self, mut map: V) -> Result<Detection, V::Error>
where
V: MapAccess<'de>,
{
let mut identifiers: HashMap<String, Expression> = HashMap::new();
let mut identifiers_raw: HashMap<String, Yaml> = HashMap::new();
let mut expression = None;
while let Some(key) = map.next_key::<String>()? {
match key.as_ref() {
"condition" => {
if expression.is_some() {
return Err(de::Error::duplicate_field("condition"));
}
expression = Some(map.next_value::<String>()?);
}
_ => {
if identifiers.get(&key).is_some() {
return Err(de::Error::custom(format_args!(
"duplicate field `{}`",
key
)));
}
let v: Yaml = map.next_value()?;
identifiers.insert(
key.to_string(),
parser::parse_identifier(&v).map_err(|e| {
de::Error::custom(format!(
"failed to parse identifier - {:?}",
e
))
})?,
);
identifiers_raw.insert(key.to_string(), v.clone());
}
}
}
let expression_raw =
expression.ok_or_else(|| de::Error::missing_field("condition"))?;
let tokens = match expression_raw.tokenise() {
Ok(tokens) => tokens,
Err(err) => {
return Err(de::Error::custom(format_args!(
"invalid value: condition, failed to tokenise - {}",
err
)));
}
};
let mut i = 0;
for token in &tokens {
if i > 1 {
if let Token::Modifier(m) = &tokens[i - 2] {
match m {
ModSym::Int | ModSym::Not | ModSym::Str => {
i += 1;
continue;
}
}
}
}
if let Token::Identifier(id) = token {
if !identifiers.contains_key(id) {
return Err(de::Error::custom(format_args!(
"invalid condition: identifier not found - {}",
id
)));
}
}
i += 1;
}
let expression = match parser::parse(&tokens) {
Ok(expression) => expression,
Err(err) => {
return Err(de::Error::custom(format_args!(
"invalid value: condition, failed to parse - {}",
err
)));
}
};
if !expression.is_solvable() {
return Err(de::Error::custom(format_args!(
"invalid value: condition, not solveable - {}",
expression
)));
}
Ok(Detection {
expression,
identifiers,
expression_raw,
identifiers_raw,
})
}
}
const FIELDS: &[&str] = &["identifiers", "condition"];
deserializer.deserialize_struct("Detection", FIELDS, DetectionVisitor)
}
}
/// A rule: a detection plus example documents it must and must not match.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Rule {
/// Set once `optimise` has run, so later calls return early.
/// Defaults to `false` when absent from the input.
#[serde(default)]
optimised: bool,
/// The detection logic evaluated against documents.
pub detection: Detection,
/// Example documents the detection must match; checked by `validate`.
pub true_positives: Vec<Yaml>,
/// Example documents the detection must not match; checked by `validate`.
pub true_negatives: Vec<Yaml>,
}
impl Rule {
    /// Reads the file at `path` and parses it as a YAML rule.
    ///
    /// # Errors
    ///
    /// Returns a rule-invalid error if the file cannot be read or its
    /// contents do not deserialise into a `Rule`.
    pub fn load(path: &Path) -> crate::Result<Self> {
        let contents = fs::read_to_string(path).map_err(crate::error::rule_invalid)?;
        Self::from_str(&contents)
    }

    /// Parses a rule from a YAML string.
    ///
    /// # Errors
    ///
    /// Returns a rule-invalid error if `s` does not deserialise into a `Rule`.
    pub fn from_str(s: &str) -> crate::Result<Self> {
        serde_yaml::from_str(s).map_err(crate::error::rule_invalid)
    }

    /// Builds a rule from an already-parsed YAML value.
    ///
    /// # Errors
    ///
    /// Returns a rule-invalid error if `value` does not deserialise into a `Rule`.
    pub fn from_value(value: serde_yaml::Value) -> crate::Result<Self> {
        serde_yaml::from_value(value).map_err(crate::error::rule_invalid)
    }

    /// Applies the optimisation passes enabled in `options`, in the order
    /// coalesce, shake, rewrite, matrix, then marks the rule as optimised.
    ///
    /// If the rule is already marked optimised, it is returned unchanged.
    pub fn optimise(mut self, options: Optimisations) -> Self {
        if self.optimised {
            return self;
        }
        if options.coalesce {
            // The identifiers are passed into `coalesce` and the map is then
            // cleared. This presumes coalescing folds them into the
            // expression (TODO confirm against `optimiser::coalesce`).
            self.detection.expression =
                optimiser::coalesce(self.detection.expression, &self.detection.identifiers);
            self.detection.identifiers.clear();
        }
        if options.shake {
            self.detection.expression = optimiser::shake(self.detection.expression);
            self.detection.identifiers = self
                .detection
                .identifiers
                .into_iter()
                .map(|(k, v)| (k, optimiser::shake(v)))
                .collect();
        }
        if options.rewrite {
            self.detection.expression = optimiser::rewrite(self.detection.expression);
            self.detection.identifiers = self
                .detection
                .identifiers
                .into_iter()
                .map(|(k, v)| (k, optimiser::rewrite(v)))
                .collect();
        }
        if options.matrix {
            self.detection.expression = optimiser::matrix(self.detection.expression);
            self.detection.identifiers = self
                .detection
                .identifiers
                .into_iter()
                .map(|(k, v)| (k, optimiser::matrix(v)))
                .collect();
        }
        self.optimised = true;
        self
    }

    /// Returns whether `document` satisfies this rule's detection.
    #[inline]
    pub fn matches(&self, document: &dyn Document) -> bool {
        solver::solve(&self.detection, document)
    }

    /// Checks the detection against the rule's own example documents.
    ///
    /// Every true positive must match and every true negative must not.
    /// Returns `Ok(true)` when all examples behave as expected.
    ///
    /// # Errors
    ///
    /// Returns a `Validation` error whose message joins, with `;`, one entry
    /// per failing example. An example is also reported as failing if it is
    /// not a YAML mapping and so cannot be evaluated.
    pub fn validate(&self) -> crate::Result<bool> {
        let mut errors = vec![];
        for test in &self.true_positives {
            // Examples come straight from user-supplied YAML, so a non-mapping
            // entry is reported as a validation failure rather than panicking.
            match test.as_mapping() {
                Some(document) => {
                    if !solver::solve(&self.detection, document) {
                        errors.push(format!(
                            "failed to validate true positive check '{:?}'",
                            test
                        ));
                    }
                }
                None => errors.push(format!(
                    "failed to validate true positive check, expected a mapping '{:?}'",
                    test
                )),
            }
        }
        for test in &self.true_negatives {
            match test.as_mapping() {
                Some(document) => {
                    if solver::solve(&self.detection, document) {
                        errors.push(format!(
                            "failed to validate true negative check '{:?}'",
                            test
                        ));
                    }
                }
                None => errors.push(format!(
                    "failed to validate true negative check, expected a mapping '{:?}'",
                    test
                )),
            }
        }
        if !errors.is_empty() {
            return Err(crate::Error::new(crate::error::Kind::Validation).with(errors.join(";")));
        }
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A rule whose detection requires both identifiers `A` and `B` must
    /// accept its true positive and reject its true negative.
    #[test]
    fn rule() {
        // The YAML fixture is kept at column 0 inside the raw string so its
        // nesting is explicit and independent of the Rust indentation.
        let rule = r#"
detection:
  A:
    foo: 'foo*'
    bar: '*bar'
  B:
    foobar:
    - foobar
    - foobaz
  condition: A and B
true_positives:
- foo: foobar
  bar: foobar
  foobar: foobar
true_negatives:
- foo: bar
  bar: foo
  foobar: barfoo
"#;
        let rule = Rule::from_str(rule).unwrap();
        assert!(rule.validate().unwrap());
    }
}