use crate::types::{Edge, EdgeType, NodeType};
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use super::helpers::generate_node_id;
/// Bookkeeping for relations that are produced, declared and consumed across
/// the statements of one analysis run.
///
/// Every key is the name passed in as `canonical` by the recording methods;
/// the tracker itself never normalises names.
#[derive(Debug)]
pub(crate) struct CrossStatementTracker {
    /// Relation name -> index of the statement that most recently produced it
    /// (a later `record_produced` overwrites the earlier index).
    pub(crate) produced_tables: HashMap<String, usize>,
    /// Names recorded through `record_view_produced`.
    pub(crate) produced_views: HashSet<String>,
    /// Names declared as views, via `declare_view` or `record_view_produced`.
    pub(crate) declared_views: HashSet<String>,
    /// Names declared as tables; a name is dropped from here once it is
    /// declared or produced as a view.
    pub(crate) declared_tables: HashSet<String>,
    /// Names declared through `declare_ephemeral`; these resolve to CTE nodes.
    pub(crate) declared_ephemerals: HashSet<String>,
    /// Relation name -> statement indices that consumed it, in recording order.
    pub(crate) consumed_tables: HashMap<String, Vec<usize>>,
    /// Every relation name seen by the produce/declare/consume methods.
    /// `remove` does not prune this set.
    pub(crate) all_relations: HashSet<String>,
    /// Every CTE name seen by `record_cte` or `declare_ephemeral`.
    pub(crate) all_ctes: HashSet<String>,
}
impl CrossStatementTracker {
    /// Creates a tracker with every collection empty.
    pub(crate) fn new() -> Self {
        Self {
            produced_tables: HashMap::new(),
            produced_views: HashSet::new(),
            declared_views: HashSet::new(),
            declared_tables: HashSet::new(),
            declared_ephemerals: HashSet::new(),
            consumed_tables: HashMap::new(),
            all_relations: HashSet::new(),
            all_ctes: HashSet::new(),
        }
    }

    /// Records that statement `statement_index` produced `canonical`.
    ///
    /// A later call for the same name replaces the stored producer index.
    pub(crate) fn record_produced(&mut self, canonical: &str, statement_index: usize) {
        self.produced_tables
            .insert(canonical.to_string(), statement_index);
        self.all_relations.insert(canonical.to_string());
    }

    /// Records that statement `statement_index` produced `canonical` as a view.
    ///
    /// The name is marked as both a produced and a declared view, any table
    /// declaration for it is dropped, and it is registered as produced.
    pub(crate) fn record_view_produced(&mut self, canonical: &str, statement_index: usize) {
        self.produced_views.insert(canonical.to_string());
        self.declared_views.insert(canonical.to_string());
        self.declared_tables.remove(canonical);
        self.record_produced(canonical, statement_index);
    }

    /// Declares `canonical` as an ephemeral relation and registers the name
    /// as a CTE.
    pub(crate) fn declare_ephemeral(&mut self, canonical: &str) {
        self.declared_ephemerals.insert(canonical.to_string());
        self.all_ctes.insert(canonical.to_string());
    }

    /// Declares `canonical` as a view, replacing any table declaration for it.
    pub(crate) fn declare_view(&mut self, canonical: &str) {
        self.declared_views.insert(canonical.to_string());
        self.declared_tables.remove(canonical);
        self.all_relations.insert(canonical.to_string());
    }

    /// Declares `canonical` as a table unless it is already a declared view.
    ///
    /// The name is added to `all_relations` either way.
    pub(crate) fn declare_table(&mut self, canonical: &str) {
        // A view declaration wins over a table declaration for the same name.
        if !self.declared_views.contains(canonical) {
            self.declared_tables.insert(canonical.to_string());
        }
        self.all_relations.insert(canonical.to_string());
    }

    /// Records that statement `statement_index` consumed `canonical`.
    pub(crate) fn record_consumed(&mut self, canonical: &str, statement_index: usize) {
        self.consumed_tables
            .entry(canonical.to_string())
            .or_default()
            .push(statement_index);
        self.all_relations.insert(canonical.to_string());
    }

    /// Registers `cte_name` in the set of known CTE names.
    pub(crate) fn record_cte(&mut self, cte_name: &str) {
        self.all_ctes.insert(cte_name.to_string());
    }

    /// Test-only accessor for the view classification of `canonical`.
    #[cfg(test)]
    pub(crate) fn is_view(&self, canonical: &str) -> bool {
        self.is_view_relation(canonical)
    }

    /// Returns `true` when `canonical` was produced or declared as a view.
    fn is_view_relation(&self, canonical: &str) -> bool {
        self.produced_views.contains(canonical) || self.declared_views.contains(canonical)
    }

    /// Returns `true` when `canonical` was declared ephemeral.
    fn is_ephemeral_relation(&self, canonical: &str) -> bool {
        self.declared_ephemerals.contains(canonical)
    }

    /// Returns `true` when some statement has been recorded as producing
    /// `canonical`.
    pub(crate) fn was_produced(&self, canonical: &str) -> bool {
        self.produced_tables.contains_key(canonical)
    }

    /// Returns `true` when `canonical` has been declared as a view, a table
    /// or an ephemeral relation.
    pub(crate) fn is_declared(&self, canonical: &str) -> bool {
        self.declared_views.contains(canonical)
            || self.declared_tables.contains(canonical)
            || self.declared_ephemerals.contains(canonical)
    }

    /// Test-only lookup of the statement index that produced `canonical`.
    #[cfg(test)]
    pub(crate) fn producer_index(&self, canonical: &str) -> Option<usize> {
        self.produced_tables.get(canonical).copied()
    }

    /// Forgets the produced and declared state of `canonical`.
    ///
    /// `consumed_tables`, `all_relations` and `all_ctes` are left untouched.
    pub(crate) fn remove(&mut self, canonical: &str) {
        self.produced_tables.remove(canonical);
        self.produced_views.remove(canonical);
        self.declared_views.remove(canonical);
        self.declared_tables.remove(canonical);
        self.declared_ephemerals.remove(canonical);
    }

    /// Classifies `canonical` and returns the node-id prefix together with the
    /// matching node type.
    ///
    /// Ephemeral takes precedence over view; anything else is a table.
    fn relation_kind(&self, canonical: &str) -> (&'static str, NodeType) {
        if self.is_ephemeral_relation(canonical) {
            ("cte", NodeType::Cte)
        } else if self.is_view_relation(canonical) {
            ("view", NodeType::View)
        } else {
            ("table", NodeType::Table)
        }
    }

    /// Returns the node id and node type shared by every use of `canonical`.
    ///
    /// Names unknown to the tracker are treated as tables.
    pub(crate) fn relation_identity(&self, canonical: &str) -> (Arc<str>, NodeType) {
        let (prefix, node_type) = self.relation_kind(canonical);
        (generate_node_id(prefix, canonical), node_type)
    }

    /// Returns the node id and node type for one aliased use of `canonical`.
    ///
    /// When `alias` equals `canonical` or its simple name, the shared identity
    /// from `relation_identity` is returned. Otherwise the id is derived from
    /// `canonical`, `alias` and `scope_id`, while the node type still follows
    /// the classification of `canonical`.
    pub(crate) fn relation_instance_identity(
        &self,
        canonical: &str,
        alias: &str,
        scope_id: usize,
    ) -> (Arc<str>, NodeType) {
        let simple_name = crate::analyzer::helpers::extract_simple_name(canonical);
        if alias == canonical || alias == simple_name {
            return self.relation_identity(canonical);
        }
        let instance_key = format!("{canonical}::{alias}::scope_{scope_id}");
        let (prefix, node_type) = self.relation_kind(canonical);
        (generate_node_id(prefix, &instance_key), node_type)
    }

    /// Returns only the node id from `relation_identity`.
    pub(crate) fn relation_node_id(&self, canonical: &str) -> Arc<str> {
        self.relation_identity(canonical).0
    }

    /// Builds one `CrossStatement` edge for each (producer, later consumer)
    /// statement pair of every relation that was both produced and consumed.
    ///
    /// Consumers at or before the producer index yield no edge. Each edge is a
    /// self-edge on the relation's node and carries
    /// `[producer_idx, consumer_idx]` in `statement_ids`.
    ///
    /// Edges are emitted in relation-name order and then by consumer index, so
    /// the result does not depend on `HashMap` iteration order. A statement
    /// that recorded consumption of the same relation more than once yields a
    /// single edge, because the repeats would share one edge id.
    pub(crate) fn build_cross_statement_edges(&self) -> Vec<Edge> {
        let mut linked: Vec<(&String, usize, &Vec<usize>)> = self
            .consumed_tables
            .iter()
            .filter_map(|(table_name, consumers)| {
                self.produced_tables
                    .get(table_name)
                    .map(|&producer_idx| (table_name, producer_idx, consumers))
            })
            .collect();
        // Map keys are unique, so an unstable sort is fully deterministic here.
        linked.sort_unstable_by(|left, right| left.0.cmp(right.0));

        let mut edges = Vec::new();
        for (table_name, producer_idx, consumers) in linked {
            let mut later_consumers: Vec<usize> = consumers
                .iter()
                .copied()
                .filter(|&consumer_idx| consumer_idx > producer_idx)
                .collect();
            later_consumers.sort_unstable();
            later_consumers.dedup();
            if later_consumers.is_empty() {
                continue;
            }

            let node_id = self.relation_node_id(table_name);
            for consumer_idx in later_consumers {
                // The id hashes the relation name as well as the statement
                // pair, so two relations linking the same statements differ.
                let mut hasher = DefaultHasher::new();
                table_name.hash(&mut hasher);
                producer_idx.hash(&mut hasher);
                consumer_idx.hash(&mut hasher);
                let edge_id = format!("cross_{:016x}", hasher.finish());

                let mut edge = Edge::new(
                    edge_id,
                    node_id.clone(),
                    node_id.clone(),
                    EdgeType::CrossStatement,
                );
                edge.statement_ids = vec![producer_idx, consumer_idx];
                edges.push(edge);
            }
        }
        edges
    }
}
impl Default for CrossStatementTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Producing and consuming a relation shows up in both lookup maps.
    #[test]
    fn test_record_produced_consumed() {
        const USERS: &str = "public.users";
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced(USERS, 0);
        for consumer in [1, 2] {
            tracker.record_consumed(USERS, consumer);
        }

        assert!(tracker.was_produced(USERS));
        assert_eq!(tracker.producer_index(USERS), Some(0));
        let expected_consumers: Vec<usize> = vec![1, 2];
        assert_eq!(
            tracker.consumed_tables.get(USERS),
            Some(&expected_consumers)
        );
    }

    /// Tables and views resolve to different id prefixes and node types.
    #[test]
    fn test_view_vs_table() {
        const TABLE: &str = "public.my_table";
        const VIEW: &str = "public.my_view";
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced(TABLE, 0);
        tracker.record_view_produced(VIEW, 1);

        assert!(!tracker.is_view(TABLE));
        assert!(tracker.is_view(VIEW));

        let (table_id, table_type) = tracker.relation_identity(TABLE);
        assert!(table_id.starts_with("table_"));
        assert_eq!(table_type, NodeType::Table);

        let (view_id, view_type) = tracker.relation_identity(VIEW);
        assert!(view_id.starts_with("view_"));
        assert_eq!(view_type, NodeType::View);
    }

    /// A declared view is treated as a view without any producer recorded.
    #[test]
    fn test_declared_view_uses_view_identity_before_producer_runs() {
        const VIEW: &str = "models.future_view";
        let mut tracker = CrossStatementTracker::new();
        tracker.declare_view(VIEW);

        assert!(tracker.is_view(VIEW));
        let (view_id, view_type) = tracker.relation_identity(VIEW);
        assert!(view_id.starts_with("view_"));
        assert_eq!(view_type, NodeType::View);
    }

    /// A declared ephemeral relation resolves to a CTE node.
    #[test]
    fn test_declared_ephemeral_uses_cte_identity_before_producer_runs() {
        const EPHEMERAL: &str = "models.future_ephemeral";
        let mut tracker = CrossStatementTracker::new();
        tracker.declare_ephemeral(EPHEMERAL);

        let (node_id, node_type) = tracker.relation_identity(EPHEMERAL);
        assert!(node_id.starts_with("cte_"));
        assert_eq!(node_type, NodeType::Cte);
    }

    /// One producer with two later consumers yields two cross-statement edges.
    #[test]
    fn test_cross_statement_edges() {
        const TEMP: &str = "staging.temp";
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced(TEMP, 0);
        tracker.record_consumed(TEMP, 1);
        tracker.record_consumed(TEMP, 2);

        let edges = tracker.build_cross_statement_edges();
        assert_eq!(edges.len(), 2);
        assert!(edges
            .iter()
            .all(|edge| edge.edge_type == EdgeType::CrossStatement));

        let first_pair = vec![0usize, 1usize];
        let second_pair = vec![0usize, 2usize];
        assert!(edges.iter().any(|edge| edge.statement_ids == first_pair));
        assert!(edges.iter().any(|edge| edge.statement_ids == second_pair));
    }

    /// Removing a relation clears its view and producer state.
    #[test]
    fn test_remove() {
        const VIEW: &str = "public.temp_view";
        let mut tracker = CrossStatementTracker::new();
        tracker.record_view_produced(VIEW, 0);
        assert!(tracker.is_view(VIEW));

        tracker.remove(VIEW);
        assert!(!tracker.is_view(VIEW));
        assert!(!tracker.was_produced(VIEW));
    }

    /// A produced relation that nobody consumes yields no edges.
    #[test]
    fn test_no_cross_statement_edges_for_unconsumed_table() {
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced("staging.temp", 0);

        assert!(tracker.build_cross_statement_edges().is_empty());
    }

    /// A consumed relation without a recorded producer yields no edges.
    #[test]
    fn test_no_cross_statement_edges_for_external_table() {
        let mut tracker = CrossStatementTracker::new();
        for consumer in [0, 1] {
            tracker.record_consumed("external.source", consumer);
        }

        assert!(tracker.build_cross_statement_edges().is_empty());
    }

    /// A consumer that runs before the producer is not linked.
    #[test]
    fn test_no_edge_when_consumer_before_producer() {
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced("staging.temp", 1);
        tracker.record_consumed("staging.temp", 0);

        assert!(tracker.build_cross_statement_edges().is_empty());
    }

    /// Two relations consumed by the same statement yield one edge each.
    #[test]
    fn test_multiple_tables_cross_statement() {
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced("staging.a", 0);
        tracker.record_produced("staging.b", 1);
        tracker.record_consumed("staging.a", 2);
        tracker.record_consumed("staging.b", 2);

        let edge_count = tracker.build_cross_statement_edges().len();
        assert_eq!(edge_count, 2);
    }

    /// Recorded CTE names are stored in `all_ctes`.
    #[test]
    fn test_record_cte() {
        let mut tracker = CrossStatementTracker::new();
        tracker.record_cte("my_cte");
        tracker.record_cte("another_cte");

        assert!(tracker.all_ctes.contains("my_cte"));
        assert!(tracker.all_ctes.contains("another_cte"));
        assert_eq!(tracker.all_ctes.len(), 2);
    }

    /// Produced, consumed and view-produced names all land in `all_relations`.
    #[test]
    fn test_all_relations_tracking() {
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced("staging.a", 0);
        tracker.record_consumed("external.b", 1);
        tracker.record_view_produced("staging.v", 2);

        assert!(tracker.all_relations.contains("staging.a"));
        assert!(tracker.all_relations.contains("external.b"));
        assert!(tracker.all_relations.contains("staging.v"));
        assert_eq!(tracker.all_relations.len(), 3);
    }

    /// `Default` produces an empty tracker.
    #[test]
    fn test_default_trait() {
        let fresh = CrossStatementTracker::default();

        assert!(fresh.produced_tables.is_empty());
        assert!(fresh.consumed_tables.is_empty());
        assert!(fresh.produced_views.is_empty());
        assert!(fresh.declared_views.is_empty());
        assert!(fresh.declared_ephemerals.is_empty());
    }

    /// `relation_node_id` uses the prefix matching the relation kind.
    #[test]
    fn test_relation_node_id() {
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced("public.users", 0);
        tracker.record_view_produced("public.user_view", 1);

        let table_id = tracker.relation_node_id("public.users");
        let view_id = tracker.relation_node_id("public.user_view");
        assert!(table_id.starts_with("table_"));
        assert!(view_id.starts_with("view_"));
        assert_ne!(table_id, view_id);
    }

    /// A cross-statement edge is a self-edge carrying both statement indices.
    #[test]
    fn test_cross_statement_edge_attributes() {
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced("staging.temp", 0);
        tracker.record_consumed("staging.temp", 1);

        let edges = tracker.build_cross_statement_edges();
        assert_eq!(edges.len(), 1);

        let only_edge = &edges[0];
        assert!(only_edge.id.starts_with("cross_"));
        assert_eq!(only_edge.from, only_edge.to);
        assert_eq!(only_edge.statement_ids, vec![0usize, 1usize]);
        assert!(only_edge.metadata.is_none());
    }

    /// Producing the same relation again replaces the stored producer index.
    #[test]
    fn test_producer_overwrite() {
        const DATA: &str = "staging.data";
        let mut tracker = CrossStatementTracker::new();

        tracker.record_produced(DATA, 0);
        assert_eq!(tracker.producer_index(DATA), Some(0));

        tracker.record_produced(DATA, 2);
        assert_eq!(tracker.producer_index(DATA), Some(2));
    }

    /// A statement that both produces and consumes a relation is not linked.
    #[test]
    fn test_same_statement_producer_consumer() {
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced("staging.data", 0);
        tracker.record_consumed("staging.data", 0);

        assert!(tracker.build_cross_statement_edges().is_empty());
    }

    /// `remove` leaves the name in `all_relations`.
    #[test]
    fn test_remove_preserves_all_relations() {
        const TEMP: &str = "staging.temp";
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced(TEMP, 0);
        assert!(tracker.all_relations.contains(TEMP));

        tracker.remove(TEMP);
        assert!(tracker.all_relations.contains(TEMP));
    }

    /// Removing an unknown name is a no-op.
    #[test]
    fn test_remove_nonexistent_table() {
        let mut tracker = CrossStatementTracker::new();
        tracker.remove("nonexistent.table");

        assert!(!tracker.was_produced("nonexistent.table"));
    }

    /// Edges for a produced view are anchored on the view node.
    #[test]
    fn test_view_edge_type() {
        const SUMMARY: &str = "analytics.user_summary";
        let mut tracker = CrossStatementTracker::new();
        tracker.record_view_produced(SUMMARY, 0);
        tracker.record_consumed(SUMMARY, 1);

        let edges = tracker.build_cross_statement_edges();
        assert_eq!(edges.len(), 1);

        let view_edge = &edges[0];
        assert!(view_edge.from.starts_with("view_"));
        assert_eq!(view_edge.edge_type, EdgeType::CrossStatement);
    }

    /// A three-step pipeline links each staging relation to its consumer.
    #[test]
    fn test_complex_etl_pattern() {
        let mut tracker = CrossStatementTracker::new();
        // Statement 0 reads the external source and writes staging.raw.
        tracker.record_consumed("external.source", 0);
        tracker.record_produced("staging.raw", 0);
        // Statement 1 reads staging.raw and writes staging.cleaned.
        tracker.record_consumed("staging.raw", 1);
        tracker.record_produced("staging.cleaned", 1);
        // Statement 2 reads staging.cleaned and writes mart.final.
        tracker.record_consumed("staging.cleaned", 2);
        tracker.record_produced("mart.final", 2);

        let edges = tracker.build_cross_statement_edges();
        assert_eq!(edges.len(), 2);

        let raw_node_id = tracker.relation_node_id("staging.raw");
        let cleaned_node_id = tracker.relation_node_id("staging.cleaned");
        let raw_edge = edges.iter().find(|edge| edge.from == raw_node_id);
        let cleaned_edge = edges.iter().find(|edge| edge.from == cleaned_node_id);
        assert!(raw_edge.is_some());
        assert!(cleaned_edge.is_some());

        let raw_edge = raw_edge.unwrap();
        assert_eq!(raw_edge.statement_ids, vec![0usize, 1usize]);
    }

    /// Every edge for a shared relation starts at the producing statement.
    #[test]
    fn test_multiple_consumers_same_table() {
        const SHARED: &str = "shared.data";
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced(SHARED, 0);
        for consumer in [1, 2, 3] {
            tracker.record_consumed(SHARED, consumer);
        }

        let edges = tracker.build_cross_statement_edges();
        assert_eq!(edges.len(), 3);
        for edge in edges.iter() {
            let producer = edge.statement_ids.first().copied();
            assert_eq!(producer, Some(0usize));
        }
    }

    /// Names the tracker has never seen are treated as tables.
    #[test]
    fn test_unknown_relation_identity() {
        let tracker = CrossStatementTracker::new();

        let (id, node_type) = tracker.relation_identity("unknown.table");
        assert!(id.starts_with("table_"));
        assert_eq!(node_type, NodeType::Table);
    }

    /// Recording the same CTE twice stores it once.
    #[test]
    fn test_duplicate_cte_recording() {
        let mut tracker = CrossStatementTracker::new();
        for _ in 0..2 {
            tracker.record_cte("my_cte");
        }

        assert_eq!(tracker.all_ctes.len(), 1);
    }

    /// Distinct (relation, producer, consumer) triples get distinct edge ids.
    #[test]
    fn test_edge_id_uniqueness() {
        let mut tracker = CrossStatementTracker::new();
        tracker.record_produced("table_a", 0);
        tracker.record_produced("table_b", 1);
        tracker.record_consumed("table_a", 2);
        tracker.record_consumed("table_b", 2);
        tracker.record_consumed("table_a", 3);

        let edges = tracker.build_cross_statement_edges();
        assert_eq!(edges.len(), 3);

        let distinct_ids: std::collections::HashSet<_> =
            edges.iter().map(|edge| &edge.id).collect();
        assert_eq!(edges.len(), distinct_ids.len());
    }

    /// Two relations linking the same statement pair still get different ids.
    #[test]
    fn edge_ids_differ_for_same_statement_pairs() {
        let mut tracker = CrossStatementTracker::new();
        for name in ["table_a", "table_b"] {
            tracker.record_produced(name, 0);
            tracker.record_consumed(name, 1);
        }

        let edges = tracker.build_cross_statement_edges();
        assert_eq!(edges.len(), 2);

        let ids: std::collections::HashSet<_> =
            edges.iter().map(|edge| edge.id.clone()).collect();
        assert_eq!(ids.len(), 2, "expected unique edge IDs for each table");
    }
}