use serde::{Deserialize, Serialize};
use crate::is_ident_byte;
use crate::stmt::{SqlVerb, Statement};
/// One table reference found in a SQL statement, together with how the
/// statement touches it.
///
/// Values built by `push` in this file carry ASCII-uppercased names.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TableAccess {
/// Qualifier in front of the table name (everything before the last `.`),
/// or `None` when the reference was unqualified.
pub schema: Option<String>,
/// Table name (the part after the last `.` of a qualified reference).
pub table: String,
/// Whether the statement reads or writes the table.
pub access: AccessKind,
}
/// Direction of a table access.
///
/// Serialized in `snake_case` (`"read"` / `"write"`). The declaration order
/// matters: the `Ord` impl in this file compares the variants' `u8`
/// discriminants, so `Read` sorts before `Write`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AccessKind {
/// The statement reads from the table.
Read,
/// The statement modifies the table.
Write,
}
/// Collects the deduplicated table accesses of `stmts`.
///
/// Convenience wrapper around [`extract_table_accesses_bounded`] that
/// discards the recursion outcome and keeps only the access list.
#[must_use]
pub fn extract_table_accesses(stmts: &[Statement]) -> Vec<TableAccess> {
    let (accesses, _outcome) = extract_table_accesses_bounded(stmts);
    accesses
}
/// Collects the deduplicated table accesses of `stmts` and reports how the
/// bounded recursion went.
///
/// Returns the accesses in first-seen order (duplicates of the same
/// schema/table/access triple removed) together with the
/// `RecursionOutcome` filled in while walking nested bodies.
#[must_use]
pub fn extract_table_accesses_bounded(
    stmts: &[Statement],
) -> (Vec<TableAccess>, crate::RecursionOutcome) {
    let mut outcome = crate::RecursionOutcome::default();
    let mut collected = Vec::new();
    // The walk starts at depth 0; nested bodies are visited one level deeper.
    walk_table_accesses(stmts, 0, &mut collected, &mut outcome);
    (dedup(collected), outcome)
}
/// Walks `stmts`, appending every table access found to `out`.
///
/// Plain SQL statements are scanned directly. Compound statements (IF arms,
/// loops, nested blocks) have their body text re-lowered and walked one level
/// deeper. When the next level would reach `crate::MAX_RELOWER_DEPTH` the
/// body is skipped and the truncation is noted on `outcome` instead.
fn walk_table_accesses(
    stmts: &[Statement],
    depth: usize,
    out: &mut Vec<TableAccess>,
    outcome: &mut crate::RecursionOutcome,
) {
    /// Re-lowers `text` and walks it at `depth + 1`, or records a truncated
    /// body when that level is at or beyond the depth cap.
    fn descend(
        text: &str,
        depth: usize,
        out: &mut Vec<TableAccess>,
        outcome: &mut crate::RecursionOutcome,
    ) {
        let next_depth = depth + 1;
        if next_depth < crate::MAX_RELOWER_DEPTH {
            let lowered = crate::lower_statement_body(text);
            walk_table_accesses(&lowered, next_depth, out, outcome);
        } else {
            outcome.note_truncated();
        }
    }

    for stmt in stmts {
        match stmt {
            Statement::Sql { verb, raw_text } => accesses_from_sql(*verb, raw_text, out),
            Statement::If {
                arms,
                else_body_text,
            } => {
                for arm in arms {
                    descend(&arm.body_text, depth, out, outcome);
                }
                if let Some(else_text) = else_body_text {
                    descend(else_text, depth, out, outcome);
                }
            }
            Statement::ForLoop {
                range_text,
                body_text,
                ..
            } => {
                // A cursor FOR loop iterates over `(SELECT ...)`; that query
                // is walked before the loop body.
                if let Some(query) = parenthesised_query(range_text) {
                    descend(query, depth, out, outcome);
                }
                descend(body_text, depth, out, outcome);
            }
            Statement::WhileLoop { body_text, .. } | Statement::BareLoop { body_text } => {
                descend(body_text, depth, out, outcome);
            }
            Statement::NestedBlock { body_text } => {
                // Only recurse when stripping the wrapper actually changed
                // the text.
                let unwrapped = crate::calls::strip_block_wrapper(body_text);
                if unwrapped != body_text.as_str() {
                    descend(unwrapped, depth, out, outcome);
                }
            }
            // Every other statement kind contributes no table access here.
            _ => {}
        }
    }
}
/// Returns the trimmed text between the outermost parentheses of
/// `range_text`, or `None` when the trimmed text is not wrapped in `(`…`)`.
fn parenthesised_query(range_text: &str) -> Option<&str> {
    range_text
        .trim()
        .strip_prefix('(')
        .and_then(|rest| rest.strip_suffix(')'))
        .map(str::trim)
}
fn accesses_from_sql(verb: SqlVerb, raw: &str, out: &mut Vec<TableAccess>) {
let upper = crate::fact_emit::mask_string_literals(&raw.to_ascii_uppercase());
match verb {
SqlVerb::Select => {
for t in tables_after(&upper, raw, "FROM") {
push(out, t, AccessKind::Read);
}
for t in tables_after(&upper, raw, "JOIN") {
push(out, t, AccessKind::Read);
}
}
SqlVerb::Insert => {
for t in tables_after(&upper, raw, "INTO") {
push(out, t, AccessKind::Write);
}
for t in tables_after(&upper, raw, "FROM") {
push(out, t, AccessKind::Read);
}
}
SqlVerb::Update => {
for t in tables_after(&upper, raw, "UPDATE") {
push(out, t, AccessKind::Write);
}
for t in tables_after(&upper, raw, "FROM") {
push(out, t, AccessKind::Read);
}
}
SqlVerb::Delete => {
let target = delete_target(raw);
let target_folded = target.as_deref().map(folded_name);
if let Some(t) = target {
push(out, t, AccessKind::Write);
}
let mut target_consumed = false;
for t in tables_after(&upper, raw, "FROM") {
if !target_consumed && Some(folded_name(&t)) == target_folded {
target_consumed = true;
continue;
}
push(out, t, AccessKind::Read);
}
for t in tables_after(&upper, raw, "JOIN") {
push(out, t, AccessKind::Read);
}
}
SqlVerb::Merge => {
for t in tables_after(&upper, raw, "INTO") {
push(out, t, AccessKind::Write);
}
for t in tables_after(&upper, raw, "USING") {
push(out, t, AccessKind::Read);
}
}
}
}
/// Returns `raw_name` with its ASCII letters uppercased, for
/// case-insensitive comparison of table names.
fn folded_name(raw_name: &str) -> String {
    let mut folded = raw_name.to_owned();
    folded.make_ascii_uppercase();
    folded
}
/// Extracts the target table name of a `DELETE` statement.
///
/// `raw` is expected to start (after optional ASCII whitespace) with the
/// statement verb. The verb token is skipped, then an optional `FROM`
/// keyword, and the following run of identifier bytes and `.` is returned
/// exactly as written in `raw`. Returns `None` when no such run follows
/// (for example a parenthesised sub-query or a non-identifier character).
fn delete_target(raw: &str) -> Option<String> {
    let bytes = raw.as_bytes();
    // Index just past the run of identifier bytes / dots starting at `at`.
    let skip_name = |mut at: usize| {
        while at < bytes.len() && (is_ident_byte(bytes[at]) || bytes[at] == b'.') {
            at += 1;
        }
        at
    };

    // Skip leading whitespace first; otherwise the verb itself would be
    // mistaken for the target table.
    let verb_start = skip_ws(bytes, 0);
    let mut i = skip_ws(bytes, skip_name(verb_start));

    // `FROM` is optional (`DELETE t WHERE …`). It only counts as the keyword
    // when it is not the prefix of a longer identifier.
    let from_follows = bytes[i..]
        .get(..4)
        .is_some_and(|word| word.eq_ignore_ascii_case(b"FROM"))
        && !bytes.get(i + 4).is_some_and(|&b| is_ident_byte(b));
    if from_follows {
        i = skip_ws(bytes, i + 4);
    }

    let start = i;
    let end = skip_name(start);
    (end > start).then(|| raw[start..end].to_string())
}
/// Returns the index of the first byte at or after `i` that is not ASCII
/// whitespace, or `bytes.len()` when only whitespace remains. An `i` past
/// the end is returned unchanged.
fn skip_ws(bytes: &[u8], i: usize) -> usize {
    match bytes.get(i..) {
        Some(rest) => {
            let blanks = rest
                .iter()
                .take_while(|byte| byte.is_ascii_whitespace())
                .count();
            i + blanks
        }
        None => i,
    }
}
/// Returns the table names that follow each stand-alone occurrence of
/// `keyword` in a statement.
///
/// `upper` is the text that is scanned and `raw` the text the names are
/// sliced from; both must share the same byte offsets. A keyword match only
/// counts when it is not embedded in a longer identifier. For `FROM` the
/// scan continues over a comma-separated list (`FROM a x, b y`), skipping an
/// optional alias after each name; for every other keyword only the first
/// name is taken.
fn tables_after(upper: &str, raw: &str, keyword: &str) -> Vec<String> {
    let kw = keyword.to_ascii_uppercase();
    let follow_commas = kw == "FROM";
    let bytes = upper.as_bytes();
    let len = bytes.len();

    // Cursor helpers: both return the first index at or after `at` that does
    // not belong to the run being skipped.
    let skip_spaces = |mut at: usize| {
        while at < len && bytes[at].is_ascii_whitespace() {
            at += 1;
        }
        at
    };
    let skip_name = |mut at: usize| {
        while at < len && (is_ident_byte(bytes[at]) || bytes[at] == b'.') {
            at += 1;
        }
        at
    };

    let mut found = Vec::new();
    for (at, _) in upper.match_indices(kw.as_str()) {
        let end = at + kw.len();
        let bounded_left = at == 0 || !is_ident_byte(bytes[at - 1]);
        let bounded_right = end >= len || !is_ident_byte(bytes[end]);
        if !bounded_left || !bounded_right {
            // Part of a longer identifier such as `FROM_X` or `XFROM`.
            continue;
        }

        let mut cursor = end;
        loop {
            let start = skip_spaces(cursor);
            let stop = skip_name(start);
            if stop == start {
                break;
            }
            found.push(raw[start..stop].to_string());
            if !follow_commas {
                break;
            }

            // Step over an optional alias, then require a comma to go on.
            cursor = skip_spaces(stop);
            if cursor < len && is_ident_byte(bytes[cursor]) {
                cursor = skip_name(cursor);
            }
            cursor = skip_spaces(cursor);
            if cursor >= len || bytes[cursor] != b',' {
                break;
            }
            cursor += 1;
        }
    }
    found
}
/// Appends one access for `raw_name` to `out`.
///
/// The name is ASCII-uppercased and split at its last `.` into schema and
/// table; a name without a dot (or with nothing after the last dot) is kept
/// whole as the table with no schema. Empty table names and the `DUAL`
/// pseudo-table are dropped.
fn push(out: &mut Vec<TableAccess>, raw_name: String, access: AccessKind) {
    let mut folded = raw_name;
    folded.make_ascii_uppercase();

    let qualified = folded
        .rsplit_once('.')
        .filter(|(_, name)| !name.is_empty())
        .map(|(owner, name)| (owner.to_string(), name.to_string()));
    let (schema, table) = match qualified {
        Some((owner, name)) => (Some(owner), name),
        None => (None, folded),
    };

    if !table.is_empty() && table != "DUAL" {
        out.push(TableAccess {
            schema,
            table,
            access,
        });
    }
}
/// Removes repeated (schema, table, access) triples from `v`, keeping the
/// first occurrence of each and preserving the original order.
fn dedup(v: Vec<TableAccess>) -> Vec<TableAccess> {
    let mut seen = std::collections::BTreeSet::new();
    v.into_iter()
        .filter(|entry| seen.insert((entry.schema.clone(), entry.table.clone(), entry.access)))
        .collect()
}
/// Partial ordering that always agrees with the total order from `Ord`.
impl PartialOrd for AccessKind {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Total order by declaration position: `Read` sorts before `Write`.
impl Ord for AccessKind {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Compare the variants' discriminants.
        let (lhs, rhs) = (*self as u8, *other as u8);
        lhs.cmp(&rhs)
    }
}
// Unit tests: each one lowers a PL/SQL snippet with `lower_statement_body`
// (or builds a `Statement` by hand) and checks the extracted accesses.
#[cfg(test)]
mod tests {
use super::*;
use crate::lower_statement_body;
// A single-table SELECT yields exactly one Read.
#[test]
fn select_from_is_a_read() {
let s = lower_statement_body("SELECT id INTO v FROM employees;");
let a = extract_table_accesses(&s);
assert_eq!(a.len(), 1);
assert_eq!(a[0].table, "EMPLOYEES");
assert_eq!(a[0].access, AccessKind::Read);
}
// Every table of a comma-separated FROM list is a Read, with or without aliases.
#[test]
fn legacy_comma_join_reads_every_table() {
let s = lower_statement_body("SELECT a.x INTO v FROM emp a, dept b WHERE a.d = b.id;");
let acc = extract_table_accesses(&s);
for t in ["EMP", "DEPT"] {
assert!(
acc.iter()
.any(|x| x.table == t && x.access == AccessKind::Read),
"comma-join must read {t}: {acc:?}"
);
}
let s3 = lower_statement_body("SELECT x INTO v FROM a, b, c WHERE 1 = 1;");
let acc3 = extract_table_accesses(&s3);
for t in ["A", "B", "C"] {
assert!(
acc3.iter()
.any(|x| x.table == t && x.access == AccessKind::Read),
"comma-join must read {t}: {acc3:?}"
);
}
}
// Keywords inside a string literal must not produce table accesses.
#[test]
fn clause_keyword_inside_string_literal_is_not_a_phantom_table() {
let s = lower_statement_body("UPDATE log SET msg = 'failed to INSERT INTO orders';");
let acc = extract_table_accesses(&s);
assert!(
!acc.iter().any(|x| x.table == "ORDERS"),
"INTO inside a literal must not mint a phantom ORDERS access: {acc:?}"
);
assert!(
acc.iter()
.any(|x| x.table == "LOG" && x.access == AccessKind::Write),
"the real UPDATE target LOG must still be a Write: {acc:?}"
);
}
// Non-ASCII text right after DELETE must not panic; the result is ignored.
#[test]
fn delete_with_multibyte_first_token_does_not_panic() {
let s = lower_statement_body("DELETE é★ WHERE x = 1;");
let _ = extract_table_accesses(&s);
let s2 = lower_statement_body("DELETE é★"); let _ = extract_table_accesses(&s2);
}
// The INSERT target is a Write.
#[test]
fn insert_into_is_a_write() {
let s = lower_statement_body("INSERT INTO audit_log VALUES (1, 2);");
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "AUDIT_LOG" && x.access == AccessKind::Write)
);
}
// INSERT … SELECT writes the target and reads the source.
#[test]
fn insert_select_records_write_and_read() {
let s =
lower_statement_body("INSERT INTO summary SELECT dept_id, COUNT(*) FROM employees;");
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "SUMMARY" && x.access == AccessKind::Write)
);
assert!(
a.iter()
.any(|x| x.table == "EMPLOYEES" && x.access == AccessKind::Read)
);
}
// The UPDATE target is a Write.
#[test]
fn update_is_a_write() {
let s = lower_statement_body("UPDATE employees SET salary = salary * 1.1;");
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "EMPLOYEES" && x.access == AccessKind::Write)
);
}
// The DELETE FROM target is a Write.
#[test]
fn delete_from_is_a_write() {
let s = lower_statement_body("DELETE FROM stale_rows WHERE id < 100;");
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "STALE_ROWS" && x.access == AccessKind::Write)
);
}
// In DELETE FROM … WHERE (sub-SELECT) only the target is written; the
// sub-query table is a Read and never a Write.
#[test]
fn delete_with_where_subquery_target_is_write_subquery_is_read() {
let s = lower_statement_body("DELETE FROM t WHERE id IN (SELECT id FROM staging);");
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "T" && x.access == AccessKind::Write),
"DELETE target T must be a Write: {a:?}"
);
assert!(
a.iter()
.any(|x| x.table == "STAGING" && x.access == AccessKind::Read),
"WHERE sub-SELECT table STAGING must be a Read: {a:?}"
);
assert!(
!a.iter()
.any(|x| x.table == "STAGING" && x.access == AccessKind::Write),
"STAGING must NEVER be classified as a Write: {a:?}"
);
}
// DELETE without FROM still yields exactly one unqualified Write.
#[test]
fn from_less_delete_is_a_write() {
let s = lower_statement_body("DELETE employees WHERE id = 5;");
let a = extract_table_accesses(&s);
assert_eq!(a.len(), 1, "exactly one access expected: {a:?}");
assert_eq!(a[0].table, "EMPLOYEES");
assert_eq!(a[0].schema, None);
assert_eq!(a[0].access, AccessKind::Write);
}
// DELETE without FROM on a schema-qualified name splits schema and table.
#[test]
fn from_less_qualified_delete_is_a_write() {
let s = lower_statement_body("DELETE hr.audit_log WHERE ts < SYSDATE - 30;");
let a = extract_table_accesses(&s);
assert_eq!(a.len(), 1, "exactly one access expected: {a:?}");
assert_eq!(a[0].schema.as_deref(), Some("HR"));
assert_eq!(a[0].table, "AUDIT_LOG");
assert_eq!(a[0].access, AccessKind::Write);
}
// Same as the DELETE FROM sub-query test, for the FROM-less form.
#[test]
fn from_less_delete_with_where_subquery_target_write_subquery_read() {
let s = lower_statement_body("DELETE t WHERE id IN (SELECT id FROM staging);");
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "T" && x.access == AccessKind::Write),
"FROM-less DELETE target T must be a Write: {a:?}"
);
assert!(
a.iter()
.any(|x| x.table == "STAGING" && x.access == AccessKind::Read),
"WHERE sub-SELECT table STAGING must be a Read: {a:?}"
);
assert!(
!a.iter()
.any(|x| x.table == "STAGING" && x.access == AccessKind::Write),
"STAGING must NEVER be classified as a Write: {a:?}"
);
}
// MERGE writes the INTO table and reads the USING table.
#[test]
fn merge_writes_target_reads_source() {
let s = lower_statement_body(
"MERGE INTO target t USING source s ON (t.id = s.id) WHEN MATCHED THEN UPDATE SET t.v = s.v;",
);
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "TARGET" && x.access == AccessKind::Write)
);
assert!(
a.iter()
.any(|x| x.table == "SOURCE" && x.access == AccessKind::Read)
);
}
// `schema.table` is split into its two uppercased parts.
#[test]
fn schema_qualified_table_split() {
let s = lower_statement_body("SELECT 1 INTO v FROM hr.employees;");
let a = extract_table_accesses(&s);
assert_eq!(a[0].schema.as_deref(), Some("HR"));
assert_eq!(a[0].table, "EMPLOYEES");
}
// DUAL never shows up as an access.
#[test]
fn dual_is_filtered_out() {
let s = lower_statement_body("SELECT SYSDATE INTO v FROM dual;");
let a = extract_table_accesses(&s);
assert!(a.is_empty());
}
// DML inside a loop body is found through recursion.
#[test]
fn loop_body_dml_recursed() {
let s = lower_statement_body("FOR i IN 1..10 LOOP INSERT INTO log VALUES (i); END LOOP;");
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "LOG" && x.access == AccessKind::Write)
);
}
// A cursor FOR loop reads the table of its range query and still reports
// the writes of its body.
#[test]
fn cursor_for_loop_range_select_table_is_read() {
let s = lower_statement_body(
"FOR r IN (SELECT id FROM src) LOOP \
INSERT INTO dst VALUES (r.id); \
END LOOP;",
);
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "SRC" && x.access == AccessKind::Read),
"cursor-FOR-loop range sub-SELECT read of SRC must be extracted: {a:?}"
);
assert!(
a.iter()
.any(|x| x.table == "DST" && x.access == AccessKind::Write),
"loop body write of DST must still be extracted: {a:?}"
);
}
// A numeric loop range is not a query and adds no accesses.
#[test]
fn numeric_range_for_loop_yields_no_extra_tables() {
let s = lower_statement_body("FOR i IN 1..10 LOOP NULL; END LOOP;");
let a = extract_table_accesses(&s);
assert!(a.is_empty(), "numeric range must not invent tables: {a:?}");
}
// The same (schema, table, access) triple is reported once.
#[test]
fn duplicate_access_triples_dedupe() {
let s = lower_statement_body("SELECT 1 INTO a FROM t; SELECT 2 INTO b FROM t;");
let acc = extract_table_accesses(&s);
assert_eq!(acc.iter().filter(|x| x.table == "T").count(), 1);
}
// Both sides of an ANSI JOIN are Reads.
#[test]
fn join_tables_are_reads() {
let s = lower_statement_body(
"SELECT 1 INTO v FROM employees e JOIN departments d ON e.dept = d.id;",
);
let a = extract_table_accesses(&s);
assert!(a.iter().any(|x| x.table == "EMPLOYEES"));
assert!(a.iter().any(|x| x.table == "DEPARTMENTS"));
assert!(a.iter().all(|x| x.access == AccessKind::Read));
}
// JSON round-trip is lossless and uses the snake_case access name.
#[test]
fn serde_round_trip() {
let s = lower_statement_body("SELECT 1 INTO v FROM t;");
let a = extract_table_accesses(&s);
let json = serde_json::to_string(&a[0]).unwrap();
let back: TableAccess = serde_json::from_str(&json).unwrap();
assert_eq!(back, a[0]);
assert!(json.contains("\"access\":\"read\""));
}
// A hand-built loop whose body does not shrink when re-lowered must stop at
// the depth cap and report the truncation instead of recursing forever.
#[test]
fn non_shrinking_for_update_terminates_and_reports_limit() {
let stmts = vec![Statement::BareLoop {
body_text: "FOR UPDATE".to_string(),
}];
let (accesses, outcome) = extract_table_accesses_bounded(&stmts);
assert!(
outcome.limit_hit,
"non-shrinking BareLoop must trip the depth cap, \
outcome={outcome:?}, accesses={accesses:?}"
);
assert!(outcome.truncated_bodies >= 1);
let _ = extract_table_accesses(&stmts);
}
// DML inside a BEGIN … END block is found.
#[test]
fn nested_block_update_yields_write_edge() {
let s = lower_statement_body("BEGIN UPDATE secret_table SET x = 1 WHERE id = 9; END;");
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "SECRET_TABLE" && x.access == AccessKind::Write),
"a nested-block UPDATE must surface a Write of SECRET_TABLE: {a:?}"
);
}
// DML inside a DECLARE … BEGIN … END block is found, both write and read.
#[test]
fn nested_declare_block_dml_yields_edges() {
let s = lower_statement_body(
"DECLARE v NUMBER; BEGIN INSERT INTO audit_log SELECT id FROM staging; END;",
);
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "AUDIT_LOG" && x.access == AccessKind::Write),
"nested-block INSERT target must be a Write: {a:?}"
);
assert!(
a.iter()
.any(|x| x.table == "STAGING" && x.access == AccessKind::Read),
"nested-block sub-SELECT must be a Read: {a:?}"
);
}
// A nested block inside an IF arm is still walked.
#[test]
fn if_arm_nested_block_dml_yields_edges() {
let s = lower_statement_body(
"IF p_flag = 1 THEN BEGIN UPDATE accounts SET bal = 0 WHERE id = 1; END; END IF;",
);
let a = extract_table_accesses(&s);
assert!(
a.iter()
.any(|x| x.table == "ACCOUNTS" && x.access == AccessKind::Write),
"an IF-arm nested-block UPDATE must surface a Write of ACCOUNTS: {a:?}"
);
}
}