use crate::observation::Observation;
/// Value returned by [`ObservationSink::on_observation`] after a sink has
/// handled one observation.
///
/// NOTE(review): the code that interprets this value is not part of this
/// file; the variant names suggest "keep delivering" versus "stop
/// delivering" — confirm against the caller.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ControlFlow {
    /// The sink does not ask for delivery to stop.
    Continue,
    /// The sink asks for delivery to stop.
    Halt,
}
/// Receiver that is handed [`Observation`] values one at a time.
///
/// `Send` is a supertrait, so every implementor must be transferable to
/// another thread.
pub trait ObservationSink: Send {
    /// Handles a single borrowed observation and returns a [`ControlFlow`]
    /// value telling the caller whether the sink asks for delivery to halt.
    fn on_observation(&mut self, observation: &Observation) -> ControlFlow;

    /// Optional hook; the provided implementation does nothing.
    ///
    /// NOTE(review): presumably called once after the final observation —
    /// the call site is not in this file, confirm against the caller.
    fn on_complete(&mut self) {}
}
/// Sink that keeps clones of the observations it receives, up to a fixed
/// maximum count.
pub struct BatchCollector {
    /// Observations stored so far, in the order they were pushed.
    observations: Vec<Observation>,
    /// Cap on the number of stored observations: its `ObservationSink` impl
    /// only pushes while `observations.len()` is below this value.
    max_observations: usize,
}
impl BatchCollector {
    /// Creates a collector with no stored observations and the given cap.
    ///
    /// No capacity is reserved up front; the backing vector starts empty.
    #[must_use]
    pub fn new(max_observations: usize) -> Self {
        let observations = Vec::new();
        Self {
            max_observations,
            observations,
        }
    }

    /// Consumes the collector and returns the observations it stored.
    #[must_use]
    pub fn into_observations(self) -> Vec<Observation> {
        let Self { observations, .. } = self;
        observations
    }
}
impl ObservationSink for BatchCollector {
    /// Stores a clone of `observation` while fewer than `max_observations`
    /// are held; observations arriving after that are dropped.
    ///
    /// Always returns [`ControlFlow::Continue`], including when the cap has
    /// been reached.
    fn on_observation(&mut self, observation: &Observation) -> ControlFlow {
        // Full: skip the observation but do not ask the caller to halt.
        if self.observations.len() >= self.max_observations {
            return ControlFlow::Continue;
        }

        self.observations.push(observation.clone());
        ControlFlow::Continue
    }
}
/// Sink that records observations and carries a caller-supplied predicate
/// over them.
pub struct EarlyStopSink<F>
where
    F: FnMut(&Observation) -> bool,
{
    /// Callable evaluated against a borrowed observation; yields a `bool`.
    predicate: F,
    /// Observations recorded so far, in the order they were pushed.
    observations: Vec<Observation>,
}
impl<F: FnMut(&Observation) -> bool> EarlyStopSink<F> {
    /// Creates a sink that holds `predicate` and has no recorded
    /// observations yet.
    ///
    /// Marked `#[must_use]` to match `BatchCollector::new`: constructing a
    /// sink and discarding it does nothing useful.
    #[must_use]
    pub fn new(predicate: F) -> Self {
        Self {
            predicate,
            observations: Vec::new(),
        }
    }

    /// Consumes the sink and returns the observations it recorded; the
    /// predicate is dropped.
    ///
    /// Marked `#[must_use]` to match `BatchCollector::into_observations`:
    /// the returned vector is the only result of the call.
    #[must_use]
    pub fn into_observations(self) -> Vec<Observation> {
        self.observations
    }
}
impl<F> ObservationSink for EarlyStopSink<F>
where
    F: FnMut(&Observation) -> bool + Send,
{
    /// Records a clone of `observation`, then evaluates the predicate on it.
    ///
    /// Returns [`ControlFlow::Halt`] when the predicate yields `true` and
    /// [`ControlFlow::Continue`] otherwise.
    fn on_observation(&mut self, observation: &Observation) -> ControlFlow {
        // Record before testing, so the observation that triggers the halt
        // is part of the recorded list as well.
        self.observations.push(observation.clone());

        let should_halt = (self.predicate)(observation);
        if should_halt {
            return ControlFlow::Halt;
        }
        ControlFlow::Continue
    }
}
/// Sink that tallies observations per kind instead of storing them.
///
/// All counters start at zero via `Default`.
#[derive(Default, Debug)]
pub struct CountingSink {
    /// Count of `Observation::ApiCall` observations.
    pub api_calls: u64,
    /// Count of `Observation::DomMutation` observations.
    pub dom_mutations: u64,
    /// Count of `Observation::NetworkRequest` observations.
    pub network_requests: u64,
    /// Count of `Observation::DynamicCodeExec` observations.
    pub dynamic_code: u64,
    /// Count of `Observation::CookieAccess` observations.
    pub cookie_access: u64,
    /// Count of `Observation::Error` observations.
    pub errors: u64,
    /// Count of `Observation::WasmInstantiation` observations.
    pub wasm_instantiations: u64,
    /// Count of `Observation::FingerprintAccess` observations.
    pub fingerprint_access: u64,
    /// Count of `Observation::ContextMessage` observations.
    pub context_messages: u64,
    /// Count of `Observation::ResourceLimit` observations.
    pub resource_limits: u64,
    /// Count of every observation received, whatever its kind.
    pub total: u64,
}
impl ObservationSink for CountingSink {
fn on_observation(&mut self, observation: &Observation) -> ControlFlow {
self.total = self.total.saturating_add(1);
match observation {
Observation::ApiCall { .. } => self.api_calls = self.api_calls.saturating_add(1),
Observation::DomMutation { .. } => {
self.dom_mutations = self.dom_mutations.saturating_add(1)
}
Observation::NetworkRequest { .. } => {
self.network_requests = self.network_requests.saturating_add(1)
}
Observation::DynamicCodeExec { .. } => {
self.dynamic_code = self.dynamic_code.saturating_add(1)
}
Observation::CookieAccess { .. } => {
self.cookie_access = self.cookie_access.saturating_add(1)
}
Observation::Error { .. } => self.errors = self.errors.saturating_add(1),
Observation::WasmInstantiation { .. } => {
self.wasm_instantiations = self.wasm_instantiations.saturating_add(1)
}
Observation::FingerprintAccess { .. } => {
self.fingerprint_access = self.fingerprint_access.saturating_add(1)
}
Observation::ContextMessage { .. } => {
self.context_messages = self.context_messages.saturating_add(1)
}
Observation::ResourceLimit { .. } => {
self.resource_limits = self.resource_limits.saturating_add(1)
}
_ => {}
}
ControlFlow::Continue
}
}