use crate::metadata::token::Token;
/// Configuration that decides which method calls the emulation tracer records and how much
/// detail it keeps for them.
///
/// The [`Default`] value has no token list, no name patterns and no depth bounds. It also
/// disables argument/return capture and uses [`InstructionTraceLevel::All`].
/// [`TraceFilter::allow_all`] returns exactly this value.
#[derive(Clone, Debug, Default)]
pub struct TraceFilter {
    /// Metadata tokens of the methods to trace.
    ///
    /// When empty, [`TraceFilter::should_trace`] does not restrict by token.
    pub method_tokens: Vec<Token>,
    /// Ordered include/exclude name patterns consulted by
    /// [`TraceFilter::should_trace_by_name`].
    ///
    /// When empty, no name-based restriction applies.
    pub method_patterns: Vec<MethodPattern>,
    /// Smallest call depth (inclusive) that is traced. `None` means no lower bound.
    pub min_depth: Option<u32>,
    /// Largest call depth (inclusive) that is traced. `None` means no upper bound.
    pub max_depth: Option<u32>,
    /// Whether call arguments should be captured.
    ///
    /// NOTE(review): not read in this file; presumably consumed by the tracer itself — confirm.
    pub capture_arguments: bool,
    /// Whether return values should be captured.
    ///
    /// NOTE(review): not read in this file; presumably consumed by the tracer itself — confirm.
    pub capture_returns: bool,
    /// Granularity of instruction-level tracing. Defaults to [`InstructionTraceLevel::All`].
    pub instructions: InstructionTraceLevel,
}
/// A name-based rule evaluated by [`TraceFilter::should_trace_by_name`].
///
/// Each name component is an optional glob pattern; a component left as `None` matches any
/// value. The pattern matches a method only when every component that is set matches.
#[derive(Clone, Debug)]
pub struct MethodPattern {
    /// Glob for the method's namespace. `None` matches any namespace.
    pub namespace: Option<String>,
    /// Glob for the declaring type's name. `None` matches any type.
    pub type_name: Option<String>,
    /// Glob for the method's name. `None` matches any method.
    pub method_name: Option<String>,
    /// `true` when a matching method should be traced; `false` when it should be excluded.
    pub include: bool,
}
/// How much instruction-level detail the tracer records.
///
/// Defaults to [`InstructionTraceLevel::All`]. The per-variant descriptions below are
/// inferred from the names only; the behaviour is implemented outside this file — confirm
/// against the tracer.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum InstructionTraceLevel {
    /// Presumably no instruction-level events are recorded.
    Off,
    /// Presumably only call instructions are recorded.
    CallsOnly,
    /// Presumably branch and call instructions are recorded.
    BranchesAndCalls,
    /// Presumably every executed instruction is recorded. This is the default.
    #[default]
    All,
}
impl TraceFilter {
    /// Returns a filter that places no restrictions on what is traced.
    ///
    /// This is the same as [`TraceFilter::default`]: no depth bounds and no token or name
    /// restrictions. Argument/return capture is disabled and the instruction level is
    /// [`InstructionTraceLevel::All`].
    #[must_use]
    pub fn allow_all() -> Self {
        Self::default()
    }

    /// Decides whether a call identified by its metadata token should be traced.
    ///
    /// Returns `false` when `call_depth` lies outside the optional inclusive
    /// `[min_depth, max_depth]` window. It also returns `false` when `method_tokens` is
    /// non-empty and does not contain `method_token`. Name patterns are not consulted here;
    /// see [`TraceFilter::should_trace_by_name`].
    #[must_use]
    pub fn should_trace(&self, method_token: Token, call_depth: u32) -> bool {
        self.depth_in_range(call_depth)
            && (self.method_tokens.is_empty() || self.method_tokens.contains(&method_token))
    }

    /// Decides whether a call identified by its fully qualified name should be traced.
    ///
    /// The depth window is applied first, exactly as in [`TraceFilter::should_trace`].
    /// With no `method_patterns`, every in-range call is traced.
    ///
    /// Otherwise the patterns are evaluated in order and the **last** matching pattern decides
    /// via its `include` flag. When no pattern matches:
    /// * if at least one include pattern exists, the call is rejected (allow-list behaviour);
    /// * if every pattern is an exclusion, the call is traced (deny-list behaviour).
    ///
    /// Token restrictions (`method_tokens`) are not consulted here.
    #[must_use]
    pub fn should_trace_by_name(
        &self,
        namespace: &str,
        type_name: &str,
        method_name: &str,
        call_depth: u32,
    ) -> bool {
        if !self.depth_in_range(call_depth) {
            return false;
        }
        if self.method_patterns.is_empty() {
            return true;
        }
        // A pattern list with no include rules is a pure deny-list: only the methods it
        // matches are removed. Starting from "excluded" in that case would reject everything.
        let fallback = !self.method_patterns.iter().any(|pattern| pattern.include);
        // Walking backwards and stopping at the first hit is equivalent to
        // "the last matching pattern wins".
        self.method_patterns
            .iter()
            .rev()
            .find(|pattern| pattern_matches(pattern, namespace, type_name, method_name))
            .map_or(fallback, |pattern| pattern.include)
    }

    /// Returns `true` when `call_depth` satisfies both optional, inclusive depth bounds.
    fn depth_in_range(&self, call_depth: u32) -> bool {
        let too_shallow = matches!(self.min_depth, Some(min) if call_depth < min);
        let too_deep = matches!(self.max_depth, Some(max) if call_depth > max);
        !(too_shallow || too_deep)
    }
}
/// Matches `value` against a glob `pattern` in which `*` stands for any (possibly empty) run
/// of characters. Every other character must match literally.
///
/// * A pattern without `*` requires an exact match.
/// * `"Prefix*"` is a prefix test, and `"*"` matches everything.
/// * `*` may also lead or appear in the middle, e.g. `"*Async"` or `"Get*Async"`.
///
/// There is no escape syntax and no `?` wildcard.
fn glob_matches(pattern: &str, value: &str) -> bool {
    // No wildcard at all: plain equality.
    let (head, tail) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return pattern == value,
    };
    // Text before the first `*` is anchored at the start of `value`.
    let mut remaining = match value.strip_prefix(head) {
        Some(rest) => rest,
        None => return false,
    };
    // Text after the last `*` is anchored at the end; everything in between must appear in
    // order. Consuming each middle segment at its leftmost occurrence leaves the most room
    // for the later segments, so this greedy scan never misses a valid match.
    let (middle, suffix) = tail.rsplit_once('*').unwrap_or(("", tail));
    for segment in middle.split('*') {
        match remaining.find(segment) {
            Some(at) => remaining = &remaining[at + segment.len()..],
            None => return false,
        }
    }
    remaining.ends_with(suffix)
}
/// Reports whether every component set on `pattern` matches the corresponding part of the
/// method's name.
///
/// Components left as `None` act as wildcards. Evaluation runs namespace, then type, then
/// method, and stops at the first mismatch.
fn pattern_matches(
    pattern: &MethodPattern,
    namespace: &str,
    type_name: &str,
    method_name: &str,
) -> bool {
    // An unset component accepts anything; a set one must glob-match its value.
    let component_ok = |glob: &Option<String>, actual: &str| match glob {
        Some(expected) => glob_matches(expected, actual),
        None => true,
    };

    component_ok(&pattern.namespace, namespace)
        && component_ok(&pattern.type_name, type_name)
        && component_ok(&pattern.method_name, method_name)
}
#[cfg(test)]
mod tests {
    use super::{InstructionTraceLevel, MethodPattern, TraceFilter};
    use crate::metadata::token::Token;

    /// An arbitrary MethodDef-table token used as the "method under test".
    fn method_token() -> Token {
        Token::new(0x0600_0001)
    }

    #[test]
    fn test_allow_all_traces_everything() {
        let filter = TraceFilter::allow_all();
        for depth in [0, 100] {
            assert!(filter.should_trace(method_token(), depth), "depth {depth}");
        }
    }

    #[test]
    fn test_depth_limits() {
        let filter = TraceFilter {
            min_depth: Some(2),
            max_depth: Some(5),
            ..TraceFilter::default()
        };
        for (depth, expected) in [(1, false), (3, true), (6, false)] {
            assert_eq!(
                filter.should_trace(method_token(), depth),
                expected,
                "depth {depth}"
            );
        }
    }

    #[test]
    fn test_token_whitelist() {
        let allowed = method_token();
        let filter = TraceFilter {
            method_tokens: vec![allowed],
            ..TraceFilter::default()
        };
        assert!(filter.should_trace(allowed, 0));
        assert!(!filter.should_trace(Token::new(0x0600_0002), 0));
    }

    #[test]
    fn test_name_pattern_matching() {
        let io_namespaces = MethodPattern {
            namespace: Some(String::from("System.IO*")),
            type_name: None,
            method_name: None,
            include: true,
        };
        let filter = TraceFilter {
            method_patterns: vec![io_namespaces],
            ..TraceFilter::default()
        };
        let cases = [
            (("System.IO", "Stream", "Read"), true),
            (("System.IO.Compression", "GZipStream", "Read"), true),
            (("System.Text", "Encoding", "GetBytes"), false),
        ];
        for ((ns, ty, method), expected) in cases {
            assert_eq!(
                filter.should_trace_by_name(ns, ty, method, 0),
                expected,
                "{ns}.{ty}::{method}"
            );
        }
    }

    #[test]
    fn test_default_instruction_level() {
        assert_eq!(
            TraceFilter::default().instructions,
            InstructionTraceLevel::All
        );
    }
}