use crate::types::{RegionId, TaskId};
use std::collections::BTreeSet;
use std::sync::Arc;
/// Unsized type of a user-supplied filter predicate: it receives an event and
/// returns `true` to keep it. [`TraceFilter`] stores it behind an `Arc`, so
/// cloning a filter shares the same closure. Every predicate must be
/// `Send + Sync`.
type FilterPredicate = dyn Fn(&dyn FilterableEvent) -> bool + Send + Sync;
/// Coarse category of a trace event; the unit used for kind-based filtering.
///
/// The derived ordering follows declaration order, which is also the order
/// returned by [`EventCategory::all`].
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, PartialOrd, Ord)]
pub enum EventCategory {
    Scheduling,
    Time,
    Io,
    /// Listed in [`EventCategory::high_frequency`].
    Rng,
    Region,
    /// Listed in [`EventCategory::high_frequency`].
    Waker,
    Chaos,
    Checkpoint,
}

impl EventCategory {
    /// Returns every category, in declaration order.
    #[must_use]
    pub const fn all() -> &'static [Self] {
        const EVERY: &[EventCategory] = &[
            EventCategory::Scheduling,
            EventCategory::Time,
            EventCategory::Io,
            EventCategory::Rng,
            EventCategory::Region,
            EventCategory::Waker,
            EventCategory::Chaos,
            EventCategory::Checkpoint,
        ];
        EVERY
    }

    /// Returns the categories treated as high-frequency: `Rng` and `Waker`.
    #[must_use]
    pub const fn high_frequency() -> &'static [Self] {
        const HOT: &[EventCategory] = &[EventCategory::Rng, EventCategory::Waker];
        HOT
    }

    /// Returns `true` when this category is one of the
    /// [`high_frequency`](Self::high_frequency) categories.
    #[must_use]
    pub fn is_sampled(&self) -> bool {
        Self::high_frequency().contains(self)
    }
}
/// Minimal view of an event that [`TraceFilter`] needs in order to decide
/// whether the event is recorded.
pub trait FilterableEvent {
/// Category of the event; checked against the include/exclude kind sets.
fn event_kind(&self) -> EventCategory;
/// Task the event is associated with, or `None` if it has no task.
fn task_id(&self) -> Option<TaskId>;
/// Region the event is associated with, or `None` if it has no region.
fn region_id(&self) -> Option<RegionId>;
}
/// Configurable filter deciding which trace events are recorded.
///
/// Combines kind, region and task allow/deny rules with sampling of the
/// high-frequency kinds and an optional custom predicate. See
/// `should_record` for the evaluation order.
#[derive(Clone)]
pub struct TraceFilter {
/// Allow-list of event kinds; an empty set allows every kind.
include_kinds: BTreeSet<EventCategory>,
/// Deny-list of event kinds; checked before `include_kinds`.
exclude_kinds: BTreeSet<EventCategory>,
/// Allow-list of regions, or `None` for no restriction. Events that carry
/// no region id are not restricted by it.
region_filter: Option<BTreeSet<RegionId>>,
/// Deny-list of regions; checked before `region_filter`.
exclude_regions: BTreeSet<RegionId>,
/// Allow-list of tasks, or `None` for no restriction. Events that carry no
/// task id are not restricted by it.
task_filter: Option<BTreeSet<TaskId>>,
/// Fraction of sampled-kind events to keep; sampling only happens when
/// this is below 1.0.
sample_rate: f64,
/// Optional user predicate, evaluated last.
custom: Option<Arc<FilterPredicate>>,
/// State of the sampling generator; advanced on every sampling draw.
sample_state: u64,
}
impl Default for TraceFilter {
    /// Builds the unrestricted filter: no kind, region or task rules, a
    /// sample rate of 1.0 (sampling off), no custom predicate and a zero
    /// sampling state.
    fn default() -> Self {
        Self {
            // Kind rules: empty sets impose nothing.
            include_kinds: BTreeSet::default(),
            exclude_kinds: BTreeSet::default(),
            // Region and task rules: `None` means "no allow-list".
            region_filter: Option::None,
            exclude_regions: BTreeSet::default(),
            task_filter: Option::None,
            // Sampling and custom predicate.
            sample_rate: 1.0,
            custom: Option::None,
            sample_state: 0,
        }
    }
}
impl std::fmt::Debug for TraceFilter {
    /// Formats the filter configuration. The custom predicate is shown only as
    /// the placeholder `"<predicate>"`, and `sample_state` is omitted (the
    /// output is marked non-exhaustive).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Closures are not `Debug`; report only whether one is installed.
        let custom = if self.custom.is_some() {
            Some("<predicate>")
        } else {
            None
        };
        let mut out = f.debug_struct("TraceFilter");
        out.field("include_kinds", &self.include_kinds);
        out.field("exclude_kinds", &self.exclude_kinds);
        out.field("region_filter", &self.region_filter);
        out.field("exclude_regions", &self.exclude_regions);
        out.field("task_filter", &self.task_filter);
        out.field("sample_rate", &self.sample_rate);
        out.field("custom", &custom);
        out.finish_non_exhaustive()
    }
}
impl TraceFilter {
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Preset whose kind allow-list is exactly `Scheduling` and `Chaos`.
#[must_use]
pub fn scheduling_only() -> Self {
    Self::new().include_kinds([EventCategory::Scheduling, EventCategory::Chaos])
}
/// Preset that records every kind except `Rng`.
#[must_use]
pub fn no_rng() -> Self {
    Self::new().exclude_kind(EventCategory::Rng)
}
#[must_use]
pub fn region_subtree(root: RegionId) -> Self {
let mut filter = Self::new();
let mut regions = BTreeSet::new();
regions.insert(root);
filter.region_filter = Some(regions);
filter
}
/// Preset whose kind allow-list is exactly `Io`, `Scheduling` and `Time`.
#[must_use]
pub fn io_focused() -> Self {
    let wanted = [
        EventCategory::Io,
        EventCategory::Scheduling,
        EventCategory::Time,
    ];
    Self::new().include_kinds(wanted)
}
/// Creates an otherwise unrestricted filter with the given sample rate.
///
/// `rate` is clamped to `[0.0, 1.0]`. A NaN rate is stored as `1.0`
/// (sampling disabled): `f64::clamp` would pass NaN through unchanged, and
/// a NaN rate never triggers sampling in `should_record` anyway, so `1.0`
/// keeps the recording behaviour while making the stored value meaningful.
#[must_use]
pub fn with_sampling(rate: f64) -> Self {
    let mut filter = Self::new();
    filter.sample_rate = if rate.is_nan() {
        1.0
    } else {
        rate.clamp(0.0, 1.0)
    };
    filter
}
#[must_use]
pub fn include_kinds<I>(mut self, kinds: I) -> Self
where
I: IntoIterator<Item = EventCategory>,
{
self.include_kinds = kinds.into_iter().collect();
self
}
/// Adds a single kind to the kind allow-list, keeping existing entries.
#[must_use]
pub fn include_kind(mut self, kind: EventCategory) -> Self {
    self.include_kinds.extend([kind]);
    self
}
#[must_use]
pub fn exclude_kinds<I>(mut self, kinds: I) -> Self
where
I: IntoIterator<Item = EventCategory>,
{
self.exclude_kinds = kinds.into_iter().collect();
self
}
/// Adds a single kind to the kind deny-list, keeping existing entries.
#[must_use]
pub fn exclude_kind(mut self, kind: EventCategory) -> Self {
    self.exclude_kinds.extend(std::iter::once(kind));
    self
}
#[must_use]
pub fn filter_regions<I>(mut self, regions: I) -> Self
where
I: IntoIterator<Item = RegionId>,
{
self.region_filter = Some(regions.into_iter().collect());
self
}
/// Adds `region` to the region allow-list, creating the list if none was
/// set yet.
#[must_use]
pub fn include_region(mut self, region: RegionId) -> Self {
    match self.region_filter.as_mut() {
        Some(allowed) => {
            allowed.insert(region);
        }
        None => {
            let mut allowed = BTreeSet::new();
            allowed.insert(region);
            self.region_filter = Some(allowed);
        }
    }
    self
}
#[must_use]
pub fn exclude_region(mut self, region: RegionId) -> Self {
self.exclude_regions.insert(region);
if let Some(ref mut regions) = self.region_filter {
regions.remove(®ion);
}
self
}
#[must_use]
pub fn exclude_region_explicit(self, region: RegionId) -> Self {
self.exclude_region(region)
}
#[must_use]
pub fn filter_tasks<I>(mut self, tasks: I) -> Self
where
I: IntoIterator<Item = TaskId>,
{
self.task_filter = Some(tasks.into_iter().collect());
self
}
#[must_use]
pub fn include_task(mut self, task: TaskId) -> Self {
self.task_filter
.get_or_insert_with(BTreeSet::new)
.insert(task);
self
}
/// Sets the sample rate applied to the sampled kinds.
///
/// `rate` is clamped to `[0.0, 1.0]`. A NaN rate is stored as `1.0`
/// (sampling disabled): `f64::clamp` would pass NaN through unchanged, and
/// a NaN rate never triggers sampling in `should_record` anyway, so `1.0`
/// keeps the recording behaviour while making the stored value meaningful.
#[must_use]
pub fn with_sample_rate(mut self, rate: f64) -> Self {
    self.sample_rate = if rate.is_nan() {
        1.0
    } else {
        rate.clamp(0.0, 1.0)
    };
    self
}
#[must_use]
pub fn with_custom<F>(mut self, predicate: F) -> Self
where
F: Fn(&dyn FilterableEvent) -> bool + Send + Sync + 'static,
{
self.custom = Some(Arc::new(predicate));
self
}
#[must_use]
pub fn with_sample_seed(mut self, seed: u64) -> Self {
self.sample_state = seed;
self
}
pub fn should_record(&mut self, event: &dyn FilterableEvent) -> bool {
let kind = event.event_kind();
if self.exclude_kinds.contains(&kind) {
return false;
}
if !self.include_kinds.is_empty() && !self.include_kinds.contains(&kind) {
return false;
}
if let Some(region) = event.region_id() {
if self.exclude_regions.contains(®ion) {
return false;
}
}
if let Some(ref regions) = self.region_filter {
if let Some(region) = event.region_id() {
if !regions.contains(®ion) {
return false;
}
}
}
if let Some(ref tasks) = self.task_filter {
if let Some(task) = event.task_id() {
if !tasks.contains(&task) {
return false;
}
}
}
if kind.is_sampled() && self.sample_rate < 1.0 && !self.sample() {
return false;
}
if let Some(ref predicate) = self.custom {
if !predicate(event) {
return false;
}
}
true
}
/// Returns `true` when the filter can never reject an event, i.e. no kind,
/// region or task rule is set, sampling is off and no custom predicate is
/// installed.
#[must_use]
pub fn is_pass_through(&self) -> bool {
    // Exact complement of the `sample_rate < 1.0` gate in `should_record`:
    // sampling is skipped for rates >= 1.0 and for NaN. An epsilon
    // comparison would disagree with that gate at the edges.
    let never_samples = self.sample_rate >= 1.0 || self.sample_rate.is_nan();
    let no_kind_rules = self.include_kinds.is_empty() && self.exclude_kinds.is_empty();
    let no_region_rules = self.region_filter.is_none() && self.exclude_regions.is_empty();
    let no_task_rules = self.task_filter.is_none();

    no_kind_rules
        && no_region_rules
        && no_task_rules
        && never_samples
        && self.custom.is_none()
}
/// Advances the sampling generator by one step and returns `true` when the
/// draw falls below `sample_rate`.
///
/// The generator is a 64-bit xorshift step (shifts 13, 7, 17). The new
/// state is scaled by `u64::MAX` to a fraction and compared with the rate.
// The u64 -> f64 conversion may round; that is acceptable for sampling.
#[allow(clippy::cast_precision_loss)]
fn sample(&mut self) -> bool {
    // A zero state would stay zero forever under xorshift, so start from 1.
    let start = match self.sample_state {
        0 => 1,
        state => state,
    };
    let step1 = start ^ (start << 13);
    let step2 = step1 ^ (step1 >> 7);
    let next = step2 ^ (step2 << 17);
    self.sample_state = next;

    let fraction = (next as f64) / (u64::MAX as f64);
    fraction < self.sample_rate
}
/// Returns `true` when `kind` passes the kind rules: it is not in the
/// deny-list and, if the allow-list is non-empty, it is in the allow-list.
#[must_use]
pub fn includes_kind(&self, kind: EventCategory) -> bool {
    // The deny-list wins over the allow-list.
    if self.exclude_kinds.contains(&kind) {
        return false;
    }
    self.include_kinds.is_empty() || self.include_kinds.contains(&kind)
}
/// Returns the configured sample rate.
#[must_use]
pub fn sample_rate(&self) -> f64 {
self.sample_rate
}
/// Returns the kind allow-list; an empty set means every kind is allowed.
#[must_use]
pub fn included_kinds(&self) -> &BTreeSet<EventCategory> {
&self.include_kinds
}
/// Returns the kind deny-list.
#[must_use]
pub fn excluded_kinds(&self) -> &BTreeSet<EventCategory> {
&self.exclude_kinds
}
/// Returns the region deny-list.
#[must_use]
pub fn excluded_regions(&self) -> &BTreeSet<RegionId> {
&self.exclude_regions
}
}
/// Fluent builder that wraps a [`TraceFilter`] and hands it out via `build`.
///
/// Derives `Debug` and `Clone` so the builder can be logged and forked like
/// the filter it wraps.
#[derive(Debug, Clone)]
pub struct FilterBuilder {
    /// Filter under construction.
    filter: TraceFilter,
}
impl FilterBuilder {
    /// Starts a builder from an unrestricted [`TraceFilter`].
    #[must_use]
    pub fn new() -> Self {
        let filter = TraceFilter::new();
        Self { filter }
    }

    /// Replaces the kind allow-list; see [`TraceFilter::include_kinds`].
    #[must_use]
    pub fn include_kinds<I>(self, kinds: I) -> Self
    where
        I: IntoIterator<Item = EventCategory>,
    {
        Self {
            filter: self.filter.include_kinds(kinds),
        }
    }

    /// Replaces the kind deny-list; see [`TraceFilter::exclude_kinds`].
    #[must_use]
    pub fn exclude_kinds<I>(self, kinds: I) -> Self
    where
        I: IntoIterator<Item = EventCategory>,
    {
        Self {
            filter: self.filter.exclude_kinds(kinds),
        }
    }

    /// Excludes the region returned by `RegionId::testing_default()`.
    // NOTE(review): the method name says "root region" but the id comes from
    // a helper named `testing_default` — confirm that this is the real root
    // region id outside of tests.
    #[must_use]
    pub fn exclude_root_region(self) -> Self {
        let root = RegionId::testing_default();
        Self {
            filter: self.filter.exclude_region(root),
        }
    }

    /// Sets the sample rate; see [`TraceFilter::with_sample_rate`].
    #[must_use]
    pub fn sample_rate(self, rate: f64) -> Self {
        Self {
            filter: self.filter.with_sample_rate(rate),
        }
    }

    /// Consumes the builder and returns the configured filter.
    #[must_use]
    pub fn build(self) -> TraceFilter {
        let Self { filter } = self;
        filter
    }
}
impl Default for FilterBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`FilterableEvent`] whose attributes are set directly by tests.
    struct TestEvent {
        kind: EventCategory,
        task: Option<TaskId>,
        region: Option<RegionId>,
    }

    impl FilterableEvent for TestEvent {
        fn event_kind(&self) -> EventCategory {
            self.kind
        }

        fn task_id(&self) -> Option<TaskId> {
            self.task
        }

        fn region_id(&self) -> Option<RegionId> {
            self.region
        }
    }

    fn make_task_id(index: u32) -> TaskId {
        TaskId::new_for_test(index, 0)
    }

    fn make_region_id(index: u32) -> RegionId {
        RegionId::new_for_test(index, 0)
    }

    /// Shorthand for building a [`TestEvent`].
    fn event(kind: EventCategory, task: Option<TaskId>, region: Option<RegionId>) -> TestEvent {
        TestEvent { kind, task, region }
    }

    /// Event of the given kind with neither task nor region attached.
    fn bare(kind: EventCategory) -> TestEvent {
        event(kind, None, None)
    }

    #[test]
    fn default_filter_passes_all() {
        let mut filter = TraceFilter::default();
        assert!(filter.is_pass_through());

        let scheduled = event(
            EventCategory::Scheduling,
            Some(make_task_id(1)),
            Some(make_region_id(0)),
        );
        assert!(filter.should_record(&scheduled));
    }

    #[test]
    fn include_kinds_filter() {
        let mut filter =
            TraceFilter::new().include_kinds([EventCategory::Scheduling, EventCategory::Time]);

        assert!(filter.should_record(&bare(EventCategory::Scheduling)));
        assert!(!filter.should_record(&bare(EventCategory::Io)));
    }

    #[test]
    fn exclude_kinds_filter() {
        let mut filter = TraceFilter::new().exclude_kind(EventCategory::Rng);

        assert!(!filter.should_record(&bare(EventCategory::Rng)));
        assert!(filter.should_record(&bare(EventCategory::Scheduling)));
    }

    #[test]
    fn exclude_takes_precedence_over_include() {
        let mut filter = TraceFilter::new()
            .include_kinds([EventCategory::Scheduling, EventCategory::Rng])
            .exclude_kind(EventCategory::Rng);

        assert!(!filter.should_record(&bare(EventCategory::Rng)));
    }

    #[test]
    fn task_filter() {
        let task1 = make_task_id(1);
        let task2 = make_task_id(2);
        let mut filter = TraceFilter::new().filter_tasks([task1]);

        let event1 = event(EventCategory::Scheduling, Some(task1), None);
        let event2 = event(EventCategory::Scheduling, Some(task2), None);
        let no_task = bare(EventCategory::Time);

        assert!(filter.should_record(&event1));
        assert!(!filter.should_record(&event2));
        // Events without a task id are not restricted by the task allow-list.
        assert!(filter.should_record(&no_task));
    }

    #[test]
    fn region_filter() {
        let region1 = make_region_id(1);
        let region2 = make_region_id(2);
        let mut filter = TraceFilter::new().filter_regions([region1]);

        let event1 = event(EventCategory::Scheduling, None, Some(region1));
        let event2 = event(EventCategory::Scheduling, None, Some(region2));
        let no_region = bare(EventCategory::Time);

        assert!(filter.should_record(&event1));
        assert!(!filter.should_record(&event2));
        // Events without a region id are not restricted by the region allow-list.
        assert!(filter.should_record(&no_region));
    }

    #[test]
    fn exclude_region_blocks_events() {
        let region1 = make_region_id(1);
        let region2 = make_region_id(2);
        let mut filter = TraceFilter::new().exclude_region(region1);

        let excluded = event(EventCategory::Scheduling, None, Some(region1));
        let allowed = event(EventCategory::Scheduling, None, Some(region2));

        assert!(!filter.should_record(&excluded));
        assert!(filter.should_record(&allowed));
        assert!(filter.excluded_regions().contains(&region1));
    }

    #[test]
    fn exclude_region_overrides_region_filter() {
        let region1 = make_region_id(1);
        let region2 = make_region_id(2);
        let mut filter = TraceFilter::new()
            .filter_regions([region1, region2])
            .exclude_region(region2);

        let event1 = event(EventCategory::Scheduling, None, Some(region1));
        let event2 = event(EventCategory::Scheduling, None, Some(region2));

        assert!(filter.should_record(&event1));
        assert!(!filter.should_record(&event2));
    }

    #[test]
    fn sampling() {
        let mut filter = TraceFilter::new()
            .with_sample_rate(0.5)
            .with_sample_seed(42);

        // `Rng` is a sampled kind, so roughly half of these should pass.
        let total = 1000;
        let mut passed = 0;
        for _ in 0..total {
            let rng = bare(EventCategory::Rng);
            if filter.should_record(&rng) {
                passed += 1;
            }
        }
        assert!(passed > 400 && passed < 600, "Passed: {passed}");
    }

    #[test]
    fn no_sampling_for_non_high_frequency() {
        // Even a rate of 0.0 must not drop kinds that are never sampled.
        let mut filter = TraceFilter::new()
            .with_sample_rate(0.0)
            .with_sample_seed(42);

        assert!(filter.should_record(&bare(EventCategory::Scheduling)));
    }

    #[test]
    fn custom_predicate() {
        let mut filter = TraceFilter::new().with_custom(|event| event.task_id().is_some());

        let with_task = event(EventCategory::Scheduling, Some(make_task_id(1)), None);
        let without_task = bare(EventCategory::Time);

        assert!(filter.should_record(&with_task));
        assert!(!filter.should_record(&without_task));
    }

    #[test]
    fn predefined_scheduling_only() {
        let filter = TraceFilter::scheduling_only();
        assert!(filter.includes_kind(EventCategory::Scheduling));
        assert!(filter.includes_kind(EventCategory::Chaos));
        assert!(!filter.includes_kind(EventCategory::Rng));
        assert!(!filter.includes_kind(EventCategory::Io));
    }

    #[test]
    fn predefined_no_rng() {
        let filter = TraceFilter::no_rng();
        assert!(!filter.includes_kind(EventCategory::Rng));
        assert!(filter.includes_kind(EventCategory::Scheduling));
        assert!(filter.includes_kind(EventCategory::Time));
    }

    #[test]
    fn predefined_io_focused() {
        let filter = TraceFilter::io_focused();
        assert!(filter.includes_kind(EventCategory::Io));
        assert!(filter.includes_kind(EventCategory::Scheduling));
        assert!(filter.includes_kind(EventCategory::Time));
        assert!(!filter.includes_kind(EventCategory::Rng));
        assert!(!filter.includes_kind(EventCategory::Waker));
    }

    #[test]
    fn filter_builder() {
        let filter = FilterBuilder::new()
            .include_kinds([EventCategory::Scheduling, EventCategory::Time])
            .exclude_kinds([EventCategory::Waker])
            .sample_rate(0.5)
            .build();

        assert!(filter.includes_kind(EventCategory::Scheduling));
        assert!(!filter.includes_kind(EventCategory::Rng));
        assert!(!filter.includes_kind(EventCategory::Waker));
        assert!((filter.sample_rate() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn filter_builder_exclude_root_region() {
        let mut filter = FilterBuilder::new().exclude_root_region().build();

        let root = event(EventCategory::Scheduling, None, Some(make_region_id(0)));
        let non_root = event(EventCategory::Scheduling, None, Some(make_region_id(7)));

        assert!(!filter.should_record(&root));
        assert!(filter.should_record(&non_root));
    }

    #[test]
    fn is_pass_through() {
        assert!(TraceFilter::default().is_pass_through());
        assert!(!TraceFilter::no_rng().is_pass_through());
        assert!(!TraceFilter::scheduling_only().is_pass_through());
        assert!(!TraceFilter::with_sampling(0.5).is_pass_through());
    }
}