use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Counts relationship loads per `(parent type, relationship)` pair and remembers
/// where each load was recorded, so repeated single loads can be reported as a
/// likely N+1 query pattern.
#[derive(Debug)]
pub struct N1QueryTracker {
/// Number of recorded loads for each `(parent_type, relationship)` key.
counts: HashMap<(&'static str, &'static str), AtomicUsize>,
/// Load count at which `record_load` emits its warning; `stats` counts keys at or above it.
threshold: usize,
/// When `false`, `record_load` returns immediately without recording anything.
enabled: bool,
/// One entry per recorded load, in the order the loads were recorded.
call_sites: Vec<CallSite>,
}
impl Default for N1QueryTracker {
/// Returns the same tracker as [`N1QueryTracker::new`].
fn default() -> Self {
Self::new()
}
}
/// Describes one recorded relationship load and the source location it was recorded from.
///
/// The field descriptions below reflect how `N1QueryTracker::record_load` fills them in.
#[derive(Debug, Clone)]
pub struct CallSite {
/// Parent type name that was passed to `record_load`.
pub parent_type: &'static str,
/// Relationship name that was passed to `record_load`.
pub relationship: &'static str,
/// Source file of the caller, taken from `std::panic::Location::caller()`.
pub file: &'static str,
/// Line number of the caller within `file`.
pub line: u32,
/// `Instant::now()` captured when the load was recorded.
pub timestamp: std::time::Instant,
}
/// Aggregate snapshot of load activity, as produced by `N1QueryTracker::stats`.
///
/// Implements `PartialEq`/`Eq` so that two snapshots (for example one taken before
/// and one taken after a unit of work) can be compared directly with `==` or
/// `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct N1Stats {
    /// Sum of the load counts of every tracked `(parent type, relationship)` pair.
    pub total_loads: usize,
    /// Number of distinct `(parent type, relationship)` pairs that have a counter.
    pub relationships_loaded: usize,
    /// Number of pairs whose load count is at or above the tracker's threshold.
    pub potential_n1: usize,
}
impl N1QueryTracker {
    /// Creates an enabled tracker with no recorded loads and a warning threshold of 3.
    #[must_use]
    pub fn new() -> Self {
        Self {
            counts: HashMap::new(),
            threshold: 3,
            enabled: true,
            call_sites: Vec::new(),
        }
    }

    /// Returns the tracker with its warning threshold replaced by `threshold`.
    #[must_use]
    pub fn with_threshold(mut self, threshold: usize) -> Self {
        self.threshold = threshold;
        self
    }

    /// Returns the configured warning threshold.
    #[must_use]
    pub fn threshold(&self) -> usize {
        self.threshold
    }

    /// Returns `true` while the tracker records loads.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Stops recording: subsequent `record_load` calls become no-ops.
    pub fn disable(&mut self) {
        self.enabled = false;
    }

    /// Resumes recording after a `disable`.
    pub fn enable(&mut self) {
        self.enabled = true;
    }

    /// Records one load of `relationship` on `parent_type`.
    ///
    /// Does nothing while the tracker is disabled. Otherwise the counter for the
    /// `(parent_type, relationship)` pair is incremented and a [`CallSite`] holding
    /// the caller's file and line (this method is `#[track_caller]`) and the current
    /// `Instant` is appended. A warning is emitted once per pair, on the load that
    /// makes the count reach the threshold.
    ///
    /// A threshold of `0` is treated as `1` for the purpose of the warning, so the
    /// first load warns instead of the warning being skipped forever.
    #[track_caller]
    pub fn record_load(&mut self, parent_type: &'static str, relationship: &'static str) {
        if !self.enabled {
            return;
        }

        // `fetch_add` returns the previous value, hence the `+ 1`.
        let count = self
            .counts
            .entry((parent_type, relationship))
            .or_insert_with(|| AtomicUsize::new(0))
            .fetch_add(1, Ordering::Relaxed)
            + 1;

        // Must be called directly in this `#[track_caller]` body (not in a closure)
        // so that it reports the caller of `record_load`.
        let caller = std::panic::Location::caller();
        self.call_sites.push(CallSite {
            parent_type,
            relationship,
            file: caller.file(),
            line: caller.line(),
            timestamp: std::time::Instant::now(),
        });

        // The count starts at 1, so a literal `== 0` comparison could never match;
        // `stats` already counts every pair as `>= 0`, so warn on the first load
        // to stay consistent with it.
        if count == self.threshold.max(1) {
            self.emit_warning(parent_type, relationship, count);
        }
    }

    /// Emits the N+1 warning for one pair, followed by up to five of the recorded
    /// call sites for that pair at debug level.
    fn emit_warning(&self, parent_type: &'static str, relationship: &'static str, count: usize) {
        tracing::warn!(
            target: "sqlmodel::n1",
            parent = parent_type,
            relationship = relationship,
            queries = count,
            threshold = self.threshold,
            "N+1 QUERY PATTERN DETECTED! Consider using Session::load_many() for batch loading."
        );

        // Iterate lazily: only the first five matching sites are needed, so there
        // is no reason to collect them into a temporary vector first.
        let matching_sites = self
            .call_sites
            .iter()
            .filter(|site| site.parent_type == parent_type && site.relationship == relationship)
            .take(5);
        for (i, site) in matching_sites.enumerate() {
            tracing::debug!(
                target: "sqlmodel::n1",
                index = i,
                file = site.file,
                line = site.line,
                " [{}] {}:{}",
                i,
                site.file,
                site.line
            );
        }
    }

    /// Clears all counters and recorded call sites.
    ///
    /// The threshold and the enabled flag are left untouched.
    pub fn reset(&mut self) {
        self.counts.clear();
        self.call_sites.clear();
    }

    /// Returns how many loads were recorded for the pair, or `0` if none were.
    #[must_use]
    pub fn count_for(&self, parent_type: &'static str, relationship: &'static str) -> usize {
        self.counts
            .get(&(parent_type, relationship))
            .map_or(0, |c| c.load(Ordering::Relaxed))
    }

    /// Builds an [`N1Stats`] snapshot from the current counters.
    ///
    /// `potential_n1` counts the pairs whose load count is at or above the threshold.
    #[must_use]
    pub fn stats(&self) -> N1Stats {
        N1Stats {
            total_loads: self
                .counts
                .values()
                .map(|c| c.load(Ordering::Relaxed))
                .sum(),
            relationships_loaded: self.counts.len(),
            potential_n1: self
                .counts
                .values()
                .filter(|c| c.load(Ordering::Relaxed) >= self.threshold)
                .count(),
        }
    }

    /// Returns every recorded call site, in recording order.
    #[must_use]
    pub fn call_sites(&self) -> &[CallSite] {
        &self.call_sites
    }
}
/// Compares tracker statistics taken at the start of a unit of work with the
/// statistics passed to `log_summary` at its end.
///
/// Derives `Debug` like the other public types in this file, so a scope can be
/// logged or embedded in other `Debug` types.
#[derive(Debug)]
pub struct N1DetectionScope {
    /// Baseline statistics subtracted from the final statistics in `log_summary`.
    initial_stats: N1Stats,
    /// Threshold value reported in the scope's log events.
    threshold: usize,
    /// When `true`, a summary without issues is logged at info instead of debug level.
    verbose: bool,
}
impl N1DetectionScope {
    /// Opens a scope measured against `initial_stats` and logs its start at debug level.
    ///
    /// The returned scope is non-verbose.
    #[must_use]
    pub fn new(initial_stats: N1Stats, threshold: usize) -> Self {
        tracing::debug!(
            target: "sqlmodel::n1",
            threshold = threshold,
            "N+1 detection scope started"
        );
        let verbose = false;
        Self {
            initial_stats,
            threshold,
            verbose,
        }
    }

    /// Opens a scope whose baseline and threshold are read from `tracker`.
    #[must_use]
    pub fn from_tracker(tracker: &N1QueryTracker) -> Self {
        let baseline = tracker.stats();
        let threshold = tracker.threshold();
        Self::new(baseline, threshold)
    }

    /// Returns the scope with verbose reporting switched on.
    #[must_use]
    pub fn verbose(self) -> Self {
        Self {
            verbose: true,
            ..self
        }
    }

    /// Logs how `final_stats` differs from the baseline captured at construction.
    ///
    /// Each delta is a saturating subtraction, so a final value below the baseline
    /// yields `0`. A positive `potential_n1` delta is logged as a warning; otherwise
    /// the summary goes out at info level for verbose scopes and debug level for
    /// all others.
    pub fn log_summary(&self, final_stats: &N1Stats) {
        let baseline = &self.initial_stats;
        let total_loads = final_stats.total_loads.saturating_sub(baseline.total_loads);
        let relationships = final_stats
            .relationships_loaded
            .saturating_sub(baseline.relationships_loaded);
        let potential_n1 = final_stats
            .potential_n1
            .saturating_sub(baseline.potential_n1);

        if potential_n1 > 0 {
            tracing::warn!(
                target: "sqlmodel::n1",
                potential_n1 = potential_n1,
                total_loads = total_loads,
                relationships = relationships,
                threshold = self.threshold,
                "N+1 ISSUES DETECTED in this scope! Consider using Session::load_many() for batch loading."
            );
            return;
        }

        // No new issues: only the log level depends on the verbose flag.
        if self.verbose {
            tracing::info!(
                target: "sqlmodel::n1",
                total_loads = total_loads,
                relationships = relationships,
                "N+1 detection scope completed (no issues)"
            );
        } else {
            tracing::debug!(
                target: "sqlmodel::n1",
                total_loads = total_loads,
                relationships = relationships,
                "N+1 detection scope completed (no issues)"
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh tracker uses threshold 3 and starts enabled.
    #[test]
    fn test_tracker_new_defaults() {
        let tracker = N1QueryTracker::new();
        assert_eq!(tracker.threshold(), 3);
        assert!(tracker.is_enabled());
    }

    /// `with_threshold` replaces the default threshold.
    #[test]
    fn test_tracker_with_threshold() {
        let tracker = N1QueryTracker::new().with_threshold(5);
        assert_eq!(tracker.threshold(), 5);
    }

    /// `disable` and `enable` toggle the enabled flag.
    #[test]
    fn test_tracker_enable_disable() {
        let mut tracker = N1QueryTracker::new();
        assert!(tracker.is_enabled());
        tracker.disable();
        assert!(!tracker.is_enabled());
        tracker.enable();
        assert!(tracker.is_enabled());
    }

    /// One recorded load yields a count of one.
    #[test]
    fn test_tracker_records_single_load() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        assert_eq!(tracker.count_for("Hero", "team"), 1);
    }

    /// Repeated loads of the same pair accumulate.
    #[test]
    fn test_tracker_records_multiple_loads() {
        let mut tracker = N1QueryTracker::new().with_threshold(10);
        for _ in 0..5 {
            tracker.record_load("Hero", "team");
        }
        assert_eq!(tracker.count_for("Hero", "team"), 5);
    }

    /// Each `(parent, relationship)` pair is counted independently.
    #[test]
    fn test_tracker_records_multiple_relationships() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "powers");
        tracker.record_load("Team", "heroes");
        assert_eq!(tracker.count_for("Hero", "team"), 2);
        assert_eq!(tracker.count_for("Hero", "powers"), 1);
        assert_eq!(tracker.count_for("Team", "heroes"), 1);
    }

    /// A disabled tracker ignores `record_load`.
    #[test]
    fn test_tracker_disabled_no_recording() {
        let mut tracker = N1QueryTracker::new();
        tracker.disable();
        tracker.record_load("Hero", "team");
        assert_eq!(tracker.count_for("Hero", "team"), 0);
    }

    /// `reset` clears both the counters and the recorded call sites.
    #[test]
    fn test_tracker_reset_clears_counts() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team");
        assert_eq!(tracker.count_for("Hero", "team"), 2);
        tracker.reset();
        assert_eq!(tracker.count_for("Hero", "team"), 0);
        assert!(tracker.call_sites().is_empty());
    }

    /// The recorded call site carries the arguments and the caller's location.
    #[test]
    fn test_callsite_captures_location() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        assert_eq!(tracker.call_sites().len(), 1);
        let site = &tracker.call_sites()[0];
        assert_eq!(site.parent_type, "Hero");
        assert_eq!(site.relationship, "team");
        // `record_load` is `#[track_caller]`, so the recorded file is this one.
        // Comparing against `file!()` keeps the test valid if the file is renamed.
        assert_eq!(site.file, file!());
        assert!(site.line > 0);
    }

    /// Call-site timestamps never go backwards between consecutive loads.
    #[test]
    fn test_callsite_timestamp_monotonic() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team");
        let sites = tracker.call_sites();
        assert!(sites[1].timestamp >= sites[0].timestamp);
    }

    /// `total_loads` sums the loads across all pairs.
    #[test]
    fn test_stats_total_loads_accurate() {
        let mut tracker = N1QueryTracker::new().with_threshold(10);
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "powers");
        let stats = tracker.stats();
        assert_eq!(stats.total_loads, 3);
    }

    /// `relationships_loaded` counts distinct pairs.
    #[test]
    fn test_stats_relationships_count() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "powers");
        tracker.record_load("Team", "heroes");
        let stats = tracker.stats();
        assert_eq!(stats.relationships_loaded, 3);
    }

    /// Only pairs at or above the threshold count as potential N+1.
    #[test]
    fn test_stats_potential_n1_count() {
        let mut tracker = N1QueryTracker::new().with_threshold(2);
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "powers");
        let stats = tracker.stats();
        assert_eq!(stats.potential_n1, 1);
    }

    /// Default statistics are all zero.
    #[test]
    fn test_stats_default() {
        let stats = N1Stats::default();
        assert_eq!(stats.total_loads, 0);
        assert_eq!(stats.relationships_loaded, 0);
        assert_eq!(stats.potential_n1, 0);
    }

    /// `N1DetectionScope::new` stores the baseline and the threshold it is given.
    #[test]
    fn test_scope_new_captures_initial_state() {
        let initial = N1Stats {
            total_loads: 5,
            relationships_loaded: 2,
            potential_n1: 1,
        };
        // `initial` is not used afterwards, so it is moved rather than cloned.
        let scope = N1DetectionScope::new(initial, 3);
        assert_eq!(scope.initial_stats.total_loads, 5);
        assert_eq!(scope.threshold, 3);
    }

    /// `from_tracker` snapshots the tracker's current stats and threshold.
    #[test]
    fn test_scope_from_tracker() {
        let mut tracker = N1QueryTracker::new().with_threshold(5);
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team");
        let scope = N1DetectionScope::from_tracker(&tracker);
        assert_eq!(scope.threshold, 5);
        assert_eq!(scope.initial_stats.total_loads, 2);
    }

    /// Scopes start non-verbose and `verbose()` switches the flag on.
    #[test]
    fn test_scope_verbose_flag() {
        let initial = N1Stats::default();
        let scope = N1DetectionScope::new(initial, 3);
        assert!(!scope.verbose);
        let verbose_scope = scope.verbose();
        assert!(verbose_scope.verbose);
    }

    /// Smoke test: summarising a scope without new issues does not panic.
    #[test]
    fn test_scope_log_summary_no_issues() {
        let initial = N1Stats::default();
        let scope = N1DetectionScope::new(initial, 3);
        let final_stats = N1Stats {
            total_loads: 2,
            relationships_loaded: 1,
            potential_n1: 0,
        };
        scope.log_summary(&final_stats);
    }

    /// Smoke test: summarising a scope with new issues does not panic.
    #[test]
    fn test_scope_log_summary_with_issues() {
        let initial = N1Stats::default();
        let scope = N1DetectionScope::new(initial, 3);
        let final_stats = N1Stats {
            total_loads: 10,
            relationships_loaded: 2,
            potential_n1: 1,
        };
        scope.log_summary(&final_stats);
    }

    /// Smoke test: summarising against a non-zero baseline does not panic.
    #[test]
    fn test_scope_calculates_delta() {
        let initial = N1Stats {
            total_loads: 5,
            relationships_loaded: 2,
            potential_n1: 0,
        };
        let scope = N1DetectionScope::new(initial, 3);
        let final_stats = N1Stats {
            total_loads: 15,
            relationships_loaded: 4,
            potential_n1: 2,
        };
        scope.log_summary(&final_stats);
    }
}