use super::{
ExcludedItem, ExclusionReason, IncludedItem, InclusionReason, SelectionReport,
StageTraceSnapshot, TraceEvent,
};
use crate::model::{ContextBudget, ContextItem};
/// Controls how much detail a [`DiagnosticTraceCollector`] keeps.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum TraceDetailLevel {
/// Keep stage events only; `DiagnosticTraceCollector::record_item_event`
/// discards its input at this level.
Stage,
/// Keep item events in addition to stage events.
Item,
}
/// Receiver for trace events and per-item inclusion/exclusion records.
///
/// Implementors must provide `is_enabled`, `record_stage_event` and
/// `record_item_event`. Every other method has an empty default body, so an
/// implementation that does not override it silently ignores the call.
pub trait TraceCollector {
/// Reports whether this collector is enabled.
// NOTE(review): presumably callers consult this to skip building trace data
// for a disabled collector — confirm at the call sites.
fn is_enabled(&self) -> bool;
/// Hands a stage-level event to the collector.
fn record_stage_event(&mut self, event: TraceEvent);
/// Hands an item-level event to the collector.
fn record_item_event(&mut self, event: TraceEvent);
/// Notes that `_item` was included with the given score and reason.
/// The default implementation does nothing.
fn record_included(&mut self, _item: ContextItem, _score: f64, _reason: InclusionReason) {}
/// Notes that `_item` was excluded with the given score and reason.
/// The default implementation does nothing.
fn record_excluded(&mut self, _item: ContextItem, _score: f64, _reason: ExclusionReason) {}
/// Supplies the candidate count and the token total that was considered.
/// The default implementation does nothing.
fn set_candidates(&mut self, _total: usize, _total_tokens: i64) {}
/// Hook that receives the report, the budget and the per-stage snapshots.
/// The default implementation does nothing.
// NOTE(review): the name suggests this is called once after a pipeline run
// finishes — confirm against the caller.
fn on_pipeline_completed(
&mut self,
_report: &SelectionReport,
_budget: &ContextBudget,
_stage_snapshots: &[StageTraceSnapshot],
) {
}
}
/// A [`TraceCollector`] that records nothing.
///
/// It is a field-less unit struct; its `is_enabled` returns `false` and its
/// event hooks have empty bodies.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NullTraceCollector;
// No-op implementation: reports itself as disabled and discards every event.
// The remaining trait methods keep their empty default bodies.
impl TraceCollector for NullTraceCollector {
/// Always returns `false`.
#[inline]
fn is_enabled(&self) -> bool {
false
}
/// Discards the event.
#[inline]
fn record_stage_event(&mut self, _event: TraceEvent) {}
/// Discards the event.
#[inline]
fn record_item_event(&mut self, _event: TraceEvent) {}
}
/// Writes the excluded entries out as a plain sequence of [`ExcludedItem`]s.
///
/// The `usize` stored next to each item is intentionally left out of the
/// serialized form; only the items themselves are emitted, in slice order.
#[cfg(feature = "serde")]
fn ser_excluded_items<S: serde::Serializer>(
    items: &Vec<(ExcludedItem, usize)>,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    use serde::ser::SerializeSeq;

    let mut sequence = serializer.serialize_seq(Some(items.len()))?;
    // Stop at the first element that fails to serialize and surface its error.
    items
        .iter()
        .try_for_each(|(entry, _index)| sequence.serialize_element(entry))?;
    sequence.end()
}
/// Reads a plain sequence of [`ExcludedItem`]s and rebuilds the indexed pairs.
///
/// Each item is paired with its zero-based position in the deserialized
/// sequence.
#[cfg(feature = "serde")]
fn de_excluded_items<'de, D: serde::Deserializer<'de>>(
    deserializer: D,
) -> Result<Vec<(ExcludedItem, usize)>, D::Error> {
    let restored = <Vec<ExcludedItem> as serde::Deserialize<'de>>::deserialize(deserializer)?;

    let mut indexed = Vec::with_capacity(restored.len());
    for (position, entry) in restored.into_iter().enumerate() {
        indexed.push((entry, position));
    }
    Ok(indexed)
}
/// Boxed closure that is handed a reference to a [`TraceEvent`].
type TraceEventCallback = Box<dyn Fn(&TraceEvent)>;
/// A [`TraceCollector`] that stores what it is given so it can later be turned
/// into a [`SelectionReport`] via `into_report`.
#[non_exhaustive]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DiagnosticTraceCollector {
/// Recorded events, in the order they were pushed.
events: Vec<TraceEvent>,
/// Items reported through `record_included`.
included: Vec<IncludedItem>,
/// Items reported through `record_excluded`, each paired with the length of
/// this vector at the moment it was pushed (its insertion index).
/// With the `serde` feature the index is not serialized; it is reassigned
/// from sequence position when deserializing.
#[cfg_attr(
feature = "serde",
serde(
serialize_with = "ser_excluded_items",
deserialize_with = "de_excluded_items"
)
)]
excluded: Vec<(ExcludedItem, usize)>,
/// Candidate count last supplied through `set_candidates`.
total_candidates: usize,
/// Token total last supplied through `set_candidates`.
total_tokens_considered: i64,
/// Decides whether item-level events are stored or dropped.
detail_level: TraceDetailLevel,
/// Optional observer called with each event just before it is stored.
/// Skipped entirely by serde.
#[cfg_attr(feature = "serde", serde(skip))]
callback: Option<TraceEventCallback>,
}
// `Debug` is written by hand because the boxed callback has no `Debug`
// implementation; it is rendered as a fixed placeholder instead.
impl std::fmt::Debug for DiagnosticTraceCollector {
    /// Formats every field; the callback shows as `Some("<callback>")` when
    /// one is installed and `None` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field without updating this impl
        // becomes a compile error.
        let Self {
            events,
            included,
            excluded,
            total_candidates,
            total_tokens_considered,
            detail_level,
            callback,
        } = self;
        let callback_marker = callback.is_some().then_some("<callback>");

        let mut out = f.debug_struct("DiagnosticTraceCollector");
        out.field("events", events);
        out.field("included", included);
        out.field("excluded", excluded);
        out.field("total_candidates", total_candidates);
        out.field("total_tokens_considered", total_tokens_considered);
        out.field("detail_level", detail_level);
        out.field("callback", &callback_marker);
        out.finish()
    }
}
impl DiagnosticTraceCollector {
    /// Creates an empty collector with the given detail level and no callback.
    pub fn new(detail_level: TraceDetailLevel) -> Self {
        Self {
            detail_level,
            callback: None,
            events: Vec::new(),
            included: Vec::new(),
            excluded: Vec::new(),
            total_candidates: 0,
            total_tokens_considered: 0,
        }
    }

    /// Creates an empty collector that also calls `callback` with every event
    /// it stores.
    pub fn with_callback(detail_level: TraceDetailLevel, callback: TraceEventCallback) -> Self {
        let mut collector = Self::new(detail_level);
        collector.callback = Some(callback);
        collector
    }

    /// Consumes the collector and produces a [`SelectionReport`].
    ///
    /// Excluded items are ordered by score, highest first; items with equal
    /// scores keep the order in which they were recorded. Events and included
    /// items are passed through untouched, and
    /// `count_requirement_shortfalls` starts out empty.
    pub fn into_report(self) -> SelectionReport {
        let Self {
            events,
            included,
            excluded: mut ranked,
            total_candidates,
            total_tokens_considered,
            ..
        } = self;

        ranked.sort_by(|(left, left_idx), (right, right_idx)| {
            // Reversed operands give descending score; the insertion index
            // breaks ties in ascending order.
            right
                .score
                .total_cmp(&left.score)
                .then_with(|| left_idx.cmp(right_idx))
        });

        SelectionReport {
            events,
            included,
            excluded: ranked.into_iter().map(|(entry, _)| entry).collect(),
            total_candidates,
            total_tokens_considered,
            count_requirement_shortfalls: Vec::new(),
        }
    }

    /// Calls the installed callback with `event`; does nothing when no
    /// callback is installed.
    fn invoke_callback(&self, event: &TraceEvent) {
        if let Some(notify) = self.callback.as_ref() {
            notify(event);
        }
    }
}
impl TraceCollector for DiagnosticTraceCollector {
    /// Always returns `true`.
    #[inline]
    fn is_enabled(&self) -> bool {
        true
    }

    /// Notifies the callback (if any), then stores the event.
    fn record_stage_event(&mut self, event: TraceEvent) {
        self.invoke_callback(&event);
        self.events.push(event);
    }

    /// Notifies the callback (if any) and stores the event, but only when the
    /// detail level is [`TraceDetailLevel::Item`]; otherwise the event is
    /// dropped without touching the callback.
    fn record_item_event(&mut self, event: TraceEvent) {
        if self.detail_level == TraceDetailLevel::Item {
            self.invoke_callback(&event);
            self.events.push(event);
        }
    }

    /// Appends an [`IncludedItem`] built from the arguments.
    fn record_included(&mut self, item: ContextItem, score: f64, reason: InclusionReason) {
        let entry = IncludedItem {
            item,
            score,
            reason,
        };
        self.included.push(entry);
    }

    /// Appends an [`ExcludedItem`] built from the arguments, tagged with its
    /// insertion index so that `into_report` can break score ties by
    /// recording order.
    fn record_excluded(&mut self, item: ContextItem, score: f64, reason: ExclusionReason) {
        let insertion_index = self.excluded.len();
        let entry = ExcludedItem {
            item,
            score,
            reason,
        };
        self.excluded.push((entry, insertion_index));
    }

    /// Overwrites the stored candidate count and considered-token total.
    fn set_candidates(&mut self, total: usize, total_tokens: i64) {
        self.total_tokens_considered = total_tokens;
        self.total_candidates = total;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diagnostics::{ExclusionReason, InclusionReason, PipelineStage, TraceEvent};
    use crate::model::ContextItemBuilder;
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    };

    /// Builds a context item with the given content and token count.
    fn make_item(content: &str, tokens: i64) -> ContextItem {
        ContextItemBuilder::new(content, tokens).build().unwrap()
    }

    /// Builds a minimal event for `stage`.
    fn make_event(stage: PipelineStage) -> TraceEvent {
        TraceEvent {
            stage,
            duration_ms: 1.0,
            item_count: 0,
            message: None,
        }
    }

    /// Records a 5-token item named `content` as excluded because the budget
    /// had no tokens left.
    fn exclude_over_budget(collector: &mut impl TraceCollector, content: &str, score: f64) {
        collector.record_excluded(
            make_item(content, 5),
            score,
            ExclusionReason::BudgetExceeded {
                item_tokens: 5,
                available_tokens: 0,
            },
        );
    }

    /// Builds a collector whose callback bumps the returned counter once per
    /// invocation.
    fn counting_collector(level: TraceDetailLevel) -> (DiagnosticTraceCollector, Arc<AtomicU32>) {
        let counter = Arc::new(AtomicU32::new(0));
        let sink = Arc::clone(&counter);
        let collector = DiagnosticTraceCollector::with_callback(
            level,
            Box::new(move |_event| {
                sink.fetch_add(1, Ordering::SeqCst);
            }),
        );
        (collector, counter)
    }

    /// Feeds one stage event and one item event to a fresh collector at
    /// `level` and returns the resulting report.
    fn record_one_of_each(level: TraceDetailLevel) -> SelectionReport {
        let mut collector = DiagnosticTraceCollector::new(level);
        collector.record_stage_event(make_event(PipelineStage::Classify));
        collector.record_item_event(make_event(PipelineStage::Score));
        collector.into_report()
    }

    #[test]
    fn null_is_zst() {
        assert_eq!(std::mem::size_of::<NullTraceCollector>(), 0);
    }

    #[test]
    fn null_is_not_enabled() {
        let collector = NullTraceCollector;
        assert!(!collector.is_enabled());
    }

    #[test]
    fn null_record_methods_are_noop() {
        // Nothing to assert: the test passes as long as no call panics.
        let mut collector = NullTraceCollector;
        let _ = collector.is_enabled();
        collector.record_stage_event(make_event(PipelineStage::Classify));
        collector.record_item_event(make_event(PipelineStage::Score));
        collector.record_included(make_item("inc", 5), 1.0, InclusionReason::Scored);
        exclude_over_budget(&mut collector, "exc", 0.5);
        collector.set_candidates(2, 10);
    }

    #[test]
    fn diagnostic_is_enabled() {
        for level in [TraceDetailLevel::Stage, TraceDetailLevel::Item] {
            assert!(DiagnosticTraceCollector::new(level).is_enabled());
        }
    }

    #[test]
    fn stage_level_only_records_stage_events() {
        let report = record_one_of_each(TraceDetailLevel::Stage);
        assert_eq!(report.events.len(), 1);
    }

    #[test]
    fn item_level_records_both() {
        let report = record_one_of_each(TraceDetailLevel::Item);
        assert_eq!(report.events.len(), 2);
    }

    #[test]
    fn callback_invoked_on_stage_event() {
        let (mut collector, counter) = counting_collector(TraceDetailLevel::Stage);
        collector.record_stage_event(make_event(PipelineStage::Classify));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn callback_not_invoked_when_item_event_filtered() {
        let (mut collector, counter) = counting_collector(TraceDetailLevel::Stage);
        collector.record_item_event(make_event(PipelineStage::Score));
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn into_report_sort_score_desc() {
        let mut collector = DiagnosticTraceCollector::new(TraceDetailLevel::Stage);
        exclude_over_budget(&mut collector, "low", 2.0);
        exclude_over_budget(&mut collector, "high", 5.0);

        let report = collector.into_report();
        assert_eq!(report.excluded.len(), 2);
        assert_eq!(report.excluded[0].score, 5.0);
    }

    #[test]
    fn into_report_sort_stable_on_tie() {
        let mut collector = DiagnosticTraceCollector::new(TraceDetailLevel::Stage);
        exclude_over_budget(&mut collector, "first", 3.0);
        exclude_over_budget(&mut collector, "second", 3.0);

        let report = collector.into_report();
        assert_eq!(report.excluded[0].item.content(), "first");
        assert_eq!(report.excluded[1].item.content(), "second");
    }

    #[test]
    fn item_recording_populates_report_fields() {
        let mut collector = DiagnosticTraceCollector::new(TraceDetailLevel::Stage);
        collector.record_included(make_item("item_a", 30), 4.0, InclusionReason::Scored);
        collector.record_excluded(
            make_item("item_b", 100),
            1.0,
            ExclusionReason::BudgetExceeded {
                item_tokens: 100,
                available_tokens: 50,
            },
        );
        collector.set_candidates(2, 60);

        let report = collector.into_report();
        assert_eq!(report.included.len(), 1);
        assert_eq!(report.excluded.len(), 1);
        assert_eq!(report.total_candidates, 2);
        assert_eq!(report.total_tokens_considered, 60);
    }

    #[test]
    fn into_report_events_in_insertion_order() {
        let mut collector = DiagnosticTraceCollector::new(TraceDetailLevel::Stage);
        collector.record_stage_event(make_event(PipelineStage::Classify));
        collector.record_stage_event(make_event(PipelineStage::Score));

        let report = collector.into_report();
        assert_eq!(report.events[0].stage, PipelineStage::Classify);
        assert_eq!(report.events[1].stage, PipelineStage::Score);
    }

    #[test]
    fn on_pipeline_completed_default_is_noop() {
        use crate::diagnostics::StageTraceSnapshot;
        use crate::model::ContextBudget;
        use std::collections::HashMap;

        // Nothing to assert: the default hook must simply accept any input.
        let budget = ContextBudget::new(1000, 800, 0, HashMap::new(), 0.0).unwrap();
        let report = DiagnosticTraceCollector::new(TraceDetailLevel::Stage).into_report();

        let mut null = NullTraceCollector;
        null.on_pipeline_completed(&report, &budget, &[]);

        let mut diagnostic = DiagnosticTraceCollector::new(TraceDetailLevel::Stage);
        diagnostic.on_pipeline_completed(&report, &budget, &[]);

        let snapshot = StageTraceSnapshot {
            stage: PipelineStage::Classify,
            item_count_in: 3,
            item_count_out: 3,
            duration_ms: 0.1,
            excluded: vec![],
        };
        null.on_pipeline_completed(&report, &budget, &[snapshot]);
    }
}