use crate::parser::ast::{MagicRule, OffsetSpec, Operator, StrengthModifier, TypeKind, Value};
/// Upper bound applied to every strength produced by `calculate_default_strength`
/// and `apply_strength_modifier`.
pub const MAX_STRENGTH: i32 = 255;
/// Lower bound applied to every strength produced by `calculate_default_strength`
/// and `apply_strength_modifier`.
pub const MIN_STRENGTH: i32 = 0;
/// Computes a rule's heuristic strength from the rule's own shape, ignoring any
/// explicit strength modifier.
///
/// The score is the sum of four independent components:
///
/// * data type: `String` scores 20 (25 when `max_length` is set), `Quad`/`Double` 16,
///   `Long`/`Float` 15, `Short` 10 and `Byte` 5;
/// * operator: `Equal` 10, `BitwiseAndMask` 7, ordering comparisons 6, `NotEqual` 5,
///   `BitwiseXor`/`BitwiseNot` 4, `BitwiseAnd` 3 and `AnyValue` 1;
/// * offset: absolute 10, from-end 8, indirect 5 and relative 3;
/// * value: one point per byte of a `String`/`Bytes` value, capped at 20 (numeric
///   values add nothing).
///
/// The total is clamped to `MIN_STRENGTH..=MAX_STRENGTH`.
#[must_use]
pub fn calculate_default_strength(rule: &MagicRule) -> i32 {
    // Longest value length (in bytes) that still earns additional points.
    const VALUE_LENGTH_CAP: usize = 20;

    let type_score: i32 = match &rule.typ {
        TypeKind::String {
            max_length: Some(_),
        } => 25,
        TypeKind::String { max_length: None } => 20,
        TypeKind::Quad { .. } | TypeKind::Double { .. } => 16,
        TypeKind::Long { .. } | TypeKind::Float { .. } => 15,
        TypeKind::Short { .. } => 10,
        TypeKind::Byte { .. } => 5,
    };

    let operator_score: i32 = match &rule.op {
        Operator::Equal => 10,
        Operator::BitwiseAndMask(_) => 7,
        Operator::LessThan
        | Operator::GreaterThan
        | Operator::LessEqual
        | Operator::GreaterEqual => 6,
        Operator::NotEqual => 5,
        Operator::BitwiseXor | Operator::BitwiseNot => 4,
        Operator::BitwiseAnd => 3,
        Operator::AnyValue => 1,
    };

    let offset_score: i32 = match &rule.offset {
        OffsetSpec::Absolute(_) => 10,
        OffsetSpec::FromEnd(_) => 8,
        OffsetSpec::Indirect { .. } => 5,
        OffsetSpec::Relative(_) => 3,
    };

    let value_len = match &rule.value {
        Value::String(text) => Some(text.len()),
        Value::Bytes(bytes) => Some(bytes.len()),
        Value::Uint(_) | Value::Int(_) | Value::Float(_) => None,
    };
    // The capped length never exceeds 20, so the conversion cannot actually fail.
    let value_score = value_len.map_or(0, |len| {
        i32::try_from(len.min(VALUE_LENGTH_CAP)).unwrap_or(20)
    });

    (type_score + operator_score + offset_score + value_score).clamp(MIN_STRENGTH, MAX_STRENGTH)
}
/// Applies an explicit strength modifier to `base_strength`.
///
/// * `Add`, `Subtract` and `Multiply` saturate at the `i32` bounds instead of
///   overflowing.
/// * `Divide(0)` is ignored and leaves `base_strength` unchanged. Any other divisor
///   uses saturating division, so `i32::MIN / -1` yields `i32::MAX` instead of
///   panicking.
/// * `Set(n)` replaces the strength with `n`.
///
/// The result is always clamped to `MIN_STRENGTH..=MAX_STRENGTH`.
#[must_use]
pub fn apply_strength_modifier(base_strength: i32, modifier: &StrengthModifier) -> i32 {
    let result = match modifier {
        StrengthModifier::Add(n) => base_strength.saturating_add(*n),
        StrengthModifier::Subtract(n) => base_strength.saturating_sub(*n),
        StrengthModifier::Multiply(n) => base_strength.saturating_mul(*n),
        // A zero divisor is treated as "no change" rather than a panic.
        StrengthModifier::Divide(0) => base_strength,
        // Plain `/` panics on `i32::MIN / -1` (even in release builds); saturate instead.
        StrengthModifier::Divide(n) => base_strength.saturating_div(*n),
        StrengthModifier::Set(n) => *n,
    };
    result.clamp(MIN_STRENGTH, MAX_STRENGTH)
}
/// Returns the effective strength of `rule`.
///
/// This is the rule's default strength, passed through `apply_strength_modifier`
/// when the rule carries an explicit strength modifier.
#[must_use]
pub fn calculate_rule_strength(rule: &MagicRule) -> i32 {
    let default_strength = calculate_default_strength(rule);
    rule.strength_modifier
        .as_ref()
        .map_or(default_strength, |modifier| {
            apply_strength_modifier(default_strength, modifier)
        })
}
/// Sorts `rules` in place from strongest to weakest effective strength.
///
/// The sort is stable: rules with equal strength keep their relative order.
pub fn sort_rules_by_strength(rules: &mut [MagicRule]) {
    // `Reverse` turns the ascending key sort into a descending one.
    rules.sort_by_key(|rule| std::cmp::Reverse(calculate_rule_strength(rule)));
}
/// Consumes `rules` and returns them ordered from strongest to weakest, using the
/// same stable ordering as `sort_rules_by_strength`.
#[must_use]
pub fn into_sorted_by_strength(rules: Vec<MagicRule>) -> Vec<MagicRule> {
    let mut ordered = rules;
    sort_rules_by_strength(&mut ordered);
    ordered
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ast::Endianness;

    /// Assembles a level-0 rule with no children, no strength modifier and the
    /// placeholder message `"test"`.
    fn build_rule(typ: TypeKind, op: Operator, offset: OffsetSpec, value: Value) -> MagicRule {
        MagicRule {
            offset,
            typ,
            op,
            value,
            message: "test".to_string(),
            children: vec![],
            level: 0,
            strength_modifier: None,
        }
    }

    /// Signed byte type, the baseline type most tests use.
    fn signed_byte() -> TypeKind {
        TypeKind::Byte { signed: true }
    }

    /// Reference rule: a signed byte compared with `Equal` at absolute offset 0
    /// against `Uint(0)`.
    fn byte_rule() -> MagicRule {
        build_rule(
            signed_byte(),
            Operator::Equal,
            OffsetSpec::Absolute(0),
            Value::Uint(0),
        )
    }

    /// Default strength of the rule assembled from the given parts.
    fn strength_of(typ: TypeKind, op: Operator, offset: OffsetSpec, value: Value) -> i32 {
        calculate_default_strength(&build_rule(typ, op, offset, value))
    }

    /// Default strength of an `Equal` rule at absolute offset 0.
    fn equal_at_start(typ: TypeKind, value: Value) -> i32 {
        strength_of(typ, Operator::Equal, OffsetSpec::Absolute(0), value)
    }

    /// Default strength of a signed-byte rule against `Uint(0)`.
    fn byte_with(op: Operator, offset: OffsetSpec) -> i32 {
        strength_of(signed_byte(), op, offset, Value::Uint(0))
    }

    /// Returns `rule` with its message replaced, so sort order can be asserted.
    fn labelled(mut rule: MagicRule, message: &str) -> MagicRule {
        rule.message = message.to_string();
        rule
    }

    /// Returns `rule` carrying the given explicit strength modifier.
    fn with_modifier(mut rule: MagicRule, modifier: StrengthModifier) -> MagicRule {
        rule.strength_modifier = Some(modifier);
        rule
    }

    /// Messages of `rules`, in order.
    fn messages(rules: &[MagicRule]) -> Vec<&str> {
        rules.iter().map(|rule| rule.message.as_str()).collect()
    }

    #[test]
    fn test_strength_type_byte() {
        assert_eq!(calculate_default_strength(&byte_rule()), 25);
    }

    #[test]
    fn test_strength_type_short() {
        let short = TypeKind::Short {
            endian: Endianness::Little,
            signed: false,
        };
        assert_eq!(equal_at_start(short, Value::Uint(0)), 30);
    }

    #[test]
    fn test_strength_type_long() {
        let long = TypeKind::Long {
            endian: Endianness::Big,
            signed: false,
        };
        assert_eq!(equal_at_start(long, Value::Uint(0)), 35);
    }

    #[test]
    fn test_strength_type_quad() {
        let quad = TypeKind::Quad {
            endian: Endianness::Little,
            signed: false,
        };
        assert_eq!(equal_at_start(quad, Value::Uint(0)), 36);
    }

    #[test]
    fn test_strength_type_string() {
        let unbounded = TypeKind::String { max_length: None };
        assert_eq!(equal_at_start(unbounded, Value::String("ELF".to_string())), 43);
    }

    #[test]
    fn test_strength_type_string_with_max_length() {
        let bounded = TypeKind::String {
            max_length: Some(10),
        };
        assert_eq!(equal_at_start(bounded, Value::String("TEST".to_string())), 49);
    }

    #[test]
    fn test_strength_operator_not_equal() {
        assert_eq!(byte_with(Operator::NotEqual, OffsetSpec::Absolute(0)), 20);
    }

    #[test]
    fn test_strength_operator_bitwise_and() {
        assert_eq!(byte_with(Operator::BitwiseAnd, OffsetSpec::Absolute(0)), 18);
    }

    #[test]
    fn test_strength_operator_bitwise_and_mask() {
        assert_eq!(
            byte_with(Operator::BitwiseAndMask(0xFF), OffsetSpec::Absolute(0)),
            22
        );
    }

    #[test]
    fn test_strength_comparison_operators() {
        for op in [
            Operator::LessThan,
            Operator::GreaterThan,
            Operator::LessEqual,
            Operator::GreaterEqual,
        ] {
            let strength = byte_with(op.clone(), OffsetSpec::Absolute(0));
            assert_eq!(strength, 21, "Failed for operator: {op:?}");
        }
    }

    #[test]
    fn test_strength_offset_indirect() {
        let indirect = OffsetSpec::Indirect {
            base_offset: 0,
            pointer_type: TypeKind::Long {
                endian: Endianness::Little,
                signed: false,
            },
            adjustment: 0,
            endian: Endianness::Little,
        };
        assert_eq!(byte_with(Operator::Equal, indirect), 20);
    }

    #[test]
    fn test_strength_offset_relative() {
        assert_eq!(byte_with(Operator::Equal, OffsetSpec::Relative(4)), 18);
    }

    #[test]
    fn test_strength_offset_from_end() {
        assert_eq!(byte_with(Operator::Equal, OffsetSpec::FromEnd(-4)), 23);
    }

    #[test]
    fn test_strength_value_bytes() {
        let elf_magic = Value::Bytes(vec![0x7f, 0x45, 0x4c, 0x46]);
        assert_eq!(equal_at_start(signed_byte(), elf_magic), 29);
    }

    #[test]
    fn test_strength_value_long_string() {
        let long_text =
            Value::String("This is a very long string that exceeds the cap".to_string());
        assert_eq!(
            equal_at_start(TypeKind::String { max_length: None }, long_text),
            60
        );
    }

    #[test]
    fn test_apply_modifier_add() {
        assert_eq!(apply_strength_modifier(50, &StrengthModifier::Add(10)), 60);
    }

    #[test]
    fn test_apply_modifier_subtract() {
        assert_eq!(apply_strength_modifier(50, &StrengthModifier::Subtract(10)), 40);
    }

    #[test]
    fn test_apply_modifier_multiply() {
        assert_eq!(apply_strength_modifier(50, &StrengthModifier::Multiply(2)), 100);
    }

    #[test]
    fn test_apply_modifier_divide() {
        assert_eq!(apply_strength_modifier(50, &StrengthModifier::Divide(2)), 25);
    }

    #[test]
    fn test_apply_modifier_set() {
        assert_eq!(apply_strength_modifier(50, &StrengthModifier::Set(75)), 75);
    }

    #[test]
    fn test_apply_modifier_add_overflow() {
        let clamped = apply_strength_modifier(250, &StrengthModifier::Add(100));
        assert_eq!(clamped, MAX_STRENGTH);
    }

    #[test]
    fn test_apply_modifier_subtract_underflow() {
        let clamped = apply_strength_modifier(10, &StrengthModifier::Subtract(100));
        assert_eq!(clamped, MIN_STRENGTH);
    }

    #[test]
    fn test_apply_modifier_multiply_overflow() {
        let clamped = apply_strength_modifier(200, &StrengthModifier::Multiply(10));
        assert_eq!(clamped, MAX_STRENGTH);
    }

    #[test]
    fn test_apply_modifier_divide_by_zero() {
        assert_eq!(apply_strength_modifier(50, &StrengthModifier::Divide(0)), 50);
    }

    #[test]
    fn test_apply_modifier_set_negative() {
        let clamped = apply_strength_modifier(50, &StrengthModifier::Set(-10));
        assert_eq!(clamped, MIN_STRENGTH);
    }

    #[test]
    fn test_apply_modifier_set_over_max() {
        let clamped = apply_strength_modifier(50, &StrengthModifier::Set(1000));
        assert_eq!(clamped, MAX_STRENGTH);
    }

    #[test]
    fn test_rule_strength_without_modifier() {
        assert_eq!(calculate_rule_strength(&byte_rule()), 25);
    }

    #[test]
    fn test_rule_strength_with_add_modifier() {
        let rule = with_modifier(byte_rule(), StrengthModifier::Add(20));
        assert_eq!(calculate_rule_strength(&rule), 45);
    }

    #[test]
    fn test_rule_strength_with_multiply_modifier() {
        let rule = with_modifier(byte_rule(), StrengthModifier::Multiply(2));
        assert_eq!(calculate_rule_strength(&rule), 50);
    }

    #[test]
    fn test_rule_strength_with_set_modifier() {
        let rule = with_modifier(byte_rule(), StrengthModifier::Set(100));
        assert_eq!(calculate_rule_strength(&rule), 100);
    }

    #[test]
    fn test_sort_rules_by_strength_basic() {
        let string_rule = build_rule(
            TypeKind::String { max_length: None },
            Operator::Equal,
            OffsetSpec::Absolute(0),
            Value::String("MAGIC".to_string()),
        );
        let mut rules = vec![
            labelled(byte_rule(), "byte rule"),
            labelled(string_rule, "string rule"),
        ];
        sort_rules_by_strength(&mut rules);
        assert_eq!(messages(&rules), ["string rule", "byte rule"]);
    }

    #[test]
    fn test_sort_rules_by_strength_with_modifier() {
        let string_rule = build_rule(
            TypeKind::String { max_length: None },
            Operator::Equal,
            OffsetSpec::Absolute(0),
            Value::String("TEST".to_string()),
        );
        let mut rules = vec![
            labelled(
                with_modifier(string_rule, StrengthModifier::Set(10)),
                "string rule",
            ),
            labelled(
                with_modifier(byte_rule(), StrengthModifier::Set(100)),
                "byte rule",
            ),
        ];
        sort_rules_by_strength(&mut rules);
        assert_eq!(messages(&rules), ["byte rule", "string rule"]);
    }

    #[test]
    fn test_sort_rules_empty() {
        let mut rules: Vec<MagicRule> = Vec::new();
        sort_rules_by_strength(&mut rules);
        assert!(rules.is_empty());
    }

    #[test]
    fn test_sort_rules_single() {
        let mut rules = vec![byte_rule()];
        sort_rules_by_strength(&mut rules);
        assert_eq!(rules.len(), 1);
    }

    #[test]
    fn test_into_sorted_by_strength() {
        let long_rule = build_rule(
            TypeKind::Long {
                endian: Endianness::Big,
                signed: false,
            },
            Operator::Equal,
            OffsetSpec::Absolute(0),
            Value::Uint(0),
        );
        let sorted = into_sorted_by_strength(vec![
            labelled(byte_rule(), "byte rule"),
            labelled(long_rule, "long rule"),
        ]);
        assert_eq!(messages(&sorted), ["long rule", "byte rule"]);
    }

    #[test]
    fn test_strength_comparison_string_vs_byte() {
        let string_strength = calculate_rule_strength(&build_rule(
            TypeKind::String { max_length: None },
            Operator::Equal,
            OffsetSpec::Absolute(0),
            Value::String("AB".to_string()),
        ));
        let byte_strength = calculate_rule_strength(&build_rule(
            signed_byte(),
            Operator::Equal,
            OffsetSpec::Absolute(0),
            Value::Uint(0x7f),
        ));
        assert!(
            string_strength > byte_strength,
            "String strength {string_strength} should be > byte strength {byte_strength}"
        );
    }

    #[test]
    fn test_strength_comparison_absolute_vs_relative_offset() {
        let strength_at = |offset: OffsetSpec| {
            let rule = build_rule(signed_byte(), Operator::Equal, offset, Value::Uint(0x7f));
            calculate_rule_strength(&rule)
        };
        let absolute_strength = strength_at(OffsetSpec::Absolute(0));
        let relative_strength = strength_at(OffsetSpec::Relative(4));
        assert!(
            absolute_strength > relative_strength,
            "Absolute strength {absolute_strength} should be > relative strength {relative_strength}"
        );
    }
}