use std::cell::Cell;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicI32, Ordering};
use parking_lot::Mutex;
use serde::de::DeserializeOwned;
use serde_json::Value;
use tracing::Subscriber;
use tracing::field::{Field, Visit};
use tracing_subscriber::Layer;
use tracing_subscriber::layer::{Context, SubscriberExt};
use tracing_subscriber::registry::LookupSpan;
use super::event::{CapturedEvent, TypedEntry};
use super::invariant::Invariant;
use super::query::TrailQuery;
#[allow(unused_imports)]
use super::query::TrailQueryExt;
/// Number of live [`InstallGuard`]s in the whole process.
///
/// Incremented by [`SimulationLayer::install`] and decremented when the guard it
/// returns is dropped.
static INSTALL_COUNT: AtomicI32 = AtomicI32::new(0);

/// Returns `true` while at least one [`InstallGuard`] is alive.
///
/// The counter is process-global, but `SimulationLayer::install` only makes the
/// layer the default subscriber for the installing thread
/// (`tracing::subscriber::set_default`). This can therefore report `true` on a
/// thread that does not itself have the layer installed.
#[inline]
#[must_use]
pub fn layer_installed() -> bool {
    INSTALL_COUNT.load(Ordering::Relaxed).is_positive()
}
/// Shared storage for every captured event, grouped by trail name.
pub(crate) struct EventStore {
    /// Sequence number given to the next captured event. One counter is shared
    /// across all trails.
    seq_counter: u64,
    /// Captured events for each trail, in capture order.
    by_trail: HashMap<String, Vec<CapturedEvent>>,
    /// Simulated time in milliseconds, stamped onto each newly captured event.
    last_sim_time_ms: u64,
}
impl EventStore {
    /// Creates an empty store with no trails. Sequence numbering and simulated
    /// time both start at zero.
    fn new() -> Self {
        let by_trail = HashMap::default();
        Self {
            by_trail,
            seq_counter: 0,
            last_sim_time_ms: 0,
        }
    }
}
/// A `tracing` layer that records events marked `capture = true` into named
/// trails and runs every registered [`Invariant`] after each capture.
pub struct SimulationLayer {
    events: Arc<Mutex<EventStore>>,
    invariants: Arc<Mutex<Vec<Box<dyn Invariant + Send>>>>,
}
impl SimulationLayer {
    /// Creates a layer with an empty event store and no invariants.
    #[must_use]
    pub fn new() -> Self {
        let events = Arc::new(Mutex::new(EventStore::new()));
        let invariants = Arc::new(Mutex::new(Vec::new()));
        Self { events, invariants }
    }

    /// Returns a handle that shares this layer's event store and invariant list.
    #[must_use]
    pub fn handle(&self) -> SimulationLayerHandle {
        SimulationLayerHandle {
            events: Arc::clone(&self.events),
            invariants: Arc::clone(&self.invariants),
        }
    }

    /// Installs the layer, on top of a fresh `tracing_subscriber` registry, as the
    /// current thread's default subscriber.
    ///
    /// Returns a handle for querying the layer and a guard. Dropping the guard
    /// decrements the install count and restores the thread's previous default
    /// subscriber.
    #[must_use]
    pub fn install(self) -> (SimulationLayerHandle, InstallGuard) {
        let handle = self.handle();
        INSTALL_COUNT.fetch_add(1, Ordering::Relaxed);
        // NOTE(review): this extra dispatcher appears to exist so that `tracing`
        // sees more than one live dispatcher for as long as the guard lives, and
        // so does not derive callsite interest from a single dispatcher. Confirm.
        let anchor = tracing::Dispatch::new(tracing_subscriber::registry());
        let default_guard =
            tracing::subscriber::set_default(tracing_subscriber::registry().with(self));
        // Re-evaluate cached callsite interest now that the new default is active.
        tracing::callsite::rebuild_interest_cache();
        let guard = InstallGuard {
            _guard: default_guard,
            _interest_anchor: anchor,
        };
        (handle, guard)
    }
}
impl Default for SimulationLayer {
    /// Same as [`SimulationLayer::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Keeps the simulation subscriber installed. Dropping it uninstalls the layer.
///
/// `Drop::drop` runs first, then the fields drop in declaration order. The
/// thread-default guard is therefore released before the interest anchor
/// dispatcher, so keep the field order as it is.
pub struct InstallGuard {
    /// Restores this thread's previous default subscriber when dropped.
    _guard: tracing::subscriber::DefaultGuard,
    /// Extra dispatcher created by `SimulationLayer::install` and kept alive for
    /// the lifetime of the installation.
    _interest_anchor: tracing::Dispatch,
}
impl Drop for InstallGuard {
    /// Decrements the process-global install count.
    fn drop(&mut self) {
        let _previous = INSTALL_COUNT.fetch_sub(1, Ordering::Relaxed);
    }
}
/// Cloneable handle onto the shared state of a [`SimulationLayer`].
#[derive(Clone)]
pub struct SimulationLayerHandle {
    events: Arc<Mutex<EventStore>>,
    invariants: Arc<Mutex<Vec<Box<dyn Invariant + Send>>>>,
}
impl SimulationLayerHandle {
    /// Adds an invariant, which is observed after every later captured event.
    ///
    /// Do not call this from inside an invariant's `observe`. The layer holds the
    /// invariant lock (a non-reentrant `parking_lot` mutex) while invariants
    /// observe, so the call would deadlock.
    pub fn register(&self, inv: Box<dyn Invariant + Send>) {
        let mut invariants = self.invariants.lock();
        invariants.push(inv);
    }

    /// Clears all captured events, resets sequence numbering and simulated time to
    /// zero, then calls `reset` on every registered invariant.
    ///
    /// The event lock is released before the invariant lock is taken, so the two
    /// locks are never held together here.
    pub fn reset_for_seed(&self) {
        let mut store = self.events.lock();
        store.by_trail.clear();
        store.seq_counter = 0;
        store.last_sim_time_ms = 0;
        drop(store);

        self.invariants
            .lock()
            .iter_mut()
            .for_each(|invariant| {
                invariant.reset();
            });
    }

    /// Converts every entry of `trail` into a `TypedEntry<T>`.
    ///
    /// Entries that `TypedEntry::deserialize` rejects (returns `None`) are
    /// skipped. An unknown trail yields an empty vector.
    pub fn trail<T: DeserializeOwned>(&self, trail: &str) -> Vec<TypedEntry<T>> {
        let store = self.events.lock();
        store
            .by_trail
            .get(trail)
            .map(|entries| {
                entries
                    .iter()
                    .filter_map(TypedEntry::<T>::deserialize)
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default()
    }

    /// Returns the simulated time in milliseconds that the next captured event
    /// will be stamped with.
    #[must_use]
    pub fn current_sim_time_ms(&self) -> u64 {
        let store = self.events.lock();
        store.last_sim_time_ms
    }

    /// Sets the simulated time in milliseconds stamped onto subsequently captured
    /// events.
    pub fn set_sim_time_ms(&self, ms: u64) {
        let mut store = self.events.lock();
        store.last_sim_time_ms = ms;
    }
}
/// [`TrailQuery`] view over the layer's event store, handed to invariants.
///
/// Each method takes the event lock only for the duration of its own call.
struct LayerQuery {
    events: Arc<Mutex<EventStore>>,
}
impl TrailQuery for LayerQuery {
    /// Number of events captured on `trail`, or zero if the trail does not exist.
    fn len(&self, trail: &str) -> usize {
        let store = self.events.lock();
        store.by_trail.get(trail).map_or(0, Vec::len)
    }

    /// Sequence number of the most recently captured event.
    ///
    /// The subtraction saturates at zero, so this also returns `0` when nothing
    /// has been captured yet.
    fn last_seq(&self) -> u64 {
        self.events.lock().seq_counter.saturating_sub(1)
    }

    /// Returns clones of the events on `trail` from index `cursor` onward. When
    /// anything is returned, `cursor` is advanced to the trail's length.
    ///
    /// `cursor` is left untouched, and an empty vector returned, when the trail is
    /// missing or has no entries at or beyond `cursor`. That includes the case
    /// where `cursor` is past the current length.
    fn drain_since(&self, trail: &str, cursor: &Cell<usize>) -> Vec<CapturedEvent> {
        let store = self.events.lock();
        let entries = match store.by_trail.get(trail) {
            Some(entries) => entries,
            None => return Vec::new(),
        };
        // `get(from..)` is `None` when `from > len` and an empty slice when `from == len`.
        match entries.get(cursor.get()..) {
            Some(unseen) if !unseen.is_empty() => {
                cursor.set(entries.len());
                unseen.to_vec()
            }
            _ => Vec::new(),
        }
    }
}
/// Field visitor that pulls the capture-related fields out of a `tracing` event.
struct CaptureVisitor {
    /// Value of a boolean `capture` field, or `false` if there is none.
    capture: bool,
    /// Value of a string `trail` field, if present.
    trail: Option<String>,
    /// Value of a string `source` field, or empty if there is none.
    source: String,
    /// JSON form of a `valuable` `event` field, if present and serializable.
    payload: Option<Value>,
}
impl CaptureVisitor {
    /// Creates a visitor that has not recorded any fields yet.
    fn new() -> Self {
        Self {
            payload: None,
            source: String::new(),
            trail: None,
            capture: false,
        }
    }
}
impl Visit for CaptureVisitor {
    fn record_bool(&mut self, field: &Field, value: bool) {
        if field.name() != "capture" {
            return;
        }
        self.capture = value;
    }

    fn record_str(&mut self, field: &Field, value: &str) {
        let name = field.name();
        if name == "trail" {
            self.trail = Some(String::from(value));
        } else if name == "source" {
            value.clone_into(&mut self.source);
        }
    }

    fn record_value(&mut self, field: &Field, value: valuable::Value<'_>) {
        if field.name() != "event" {
            return;
        }
        // A serialization failure is ignored and leaves any earlier payload in place.
        if let Ok(json) = serde_json::to_value(valuable_serde::Serializable::new(value)) {
            self.payload = Some(json);
        }
    }

    /// Ignores the field.
    ///
    /// In tracing's `Visit`, this is where every field kind without a dedicated
    /// override above ends up, including values recorded with the `%`/`?` sigils.
    /// A `trail` or `source` given that way is therefore not captured.
    fn record_debug(&mut self, _field: &Field, _value: &dyn std::fmt::Debug) {}
}
impl<S> Layer<S> for SimulationLayer
where
S: Subscriber + for<'a> LookupSpan<'a>,
{
fn register_callsite(
&self,
_metadata: &'static tracing::Metadata<'static>,
) -> tracing::subscriber::Interest {
tracing::subscriber::Interest::sometimes()
}
fn on_event(&self, event: &tracing::Event<'_>, _ctx: Context<'_, S>) {
let mut visitor = CaptureVisitor::new();
event.record(&mut visitor);
if !visitor.capture {
return;
}
let Some(trail) = visitor.trail else {
return;
};
let Some(payload) = visitor.payload else {
return;
};
let sim_time_ms;
{
let mut store = self.events.lock();
let seq = store.seq_counter;
store.seq_counter += 1;
sim_time_ms = store.last_sim_time_ms;
store
.by_trail
.entry(trail.clone())
.or_default()
.push(CapturedEvent {
trail,
time_ms: sim_time_ms,
source: visitor.source,
seq,
payload,
});
}
let query = LayerQuery {
events: self.events.clone(),
};
let invariants = self.invariants.lock();
for inv in invariants.iter() {
inv.observe(&query, sim_time_ms);
}
}
}