1use serde::{Deserialize, Serialize};
30
31use crate::is_ident_byte;
32use crate::stmt::{SqlVerb, Statement};
33
/// One table touched by a SQL statement, together with how it is touched.
///
/// As built by this module's `push`, names are ASCII-upper-cased and a dotted
/// name is split at its last `.` into `schema` and `table`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TableAccess {
    /// Qualifier of a dotted name such as `hr.employees` (`"HR"`), else `None`.
    pub schema: Option<String>,
    /// Table name with any schema qualifier removed (e.g. `"EMPLOYEES"`).
    pub table: String,
    /// Whether the statement reads the table or writes to it.
    pub access: AccessKind,
}
43
/// How a statement touches a table; serialised as `"read"` / `"write"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AccessKind {
    /// The table is only read (e.g. a `FROM`, `JOIN` or MERGE `USING` source).
    Read,
    /// The table is a DML target (`INSERT INTO`, `UPDATE`, `DELETE`, `MERGE INTO`).
    Write,
}
50
51#[must_use]
62pub fn extract_table_accesses(stmts: &[Statement]) -> Vec<TableAccess> {
63 extract_table_accesses_bounded(stmts).0
64}
65
66#[must_use]
72pub fn extract_table_accesses_bounded(
73 stmts: &[Statement],
74) -> (Vec<TableAccess>, crate::RecursionOutcome) {
75 let mut out: Vec<TableAccess> = Vec::new();
76 let mut outcome = crate::RecursionOutcome::default();
77 walk_table_accesses(stmts, 0, &mut out, &mut outcome);
78 (dedup(out), outcome)
79}
80
81fn walk_table_accesses(
82 stmts: &[Statement],
83 depth: usize,
84 out: &mut Vec<TableAccess>,
85 outcome: &mut crate::RecursionOutcome,
86) {
87 macro_rules! recurse_body {
92 ($text:expr) => {{
93 if depth + 1 >= crate::MAX_RELOWER_DEPTH {
94 outcome.note_truncated();
95 } else {
96 let lowered = crate::lower_statement_body($text);
97 walk_table_accesses(&lowered, depth + 1, out, outcome);
98 }
99 }};
100 }
101 for stmt in stmts {
102 match stmt {
103 Statement::Sql { verb, raw_text } => {
104 accesses_from_sql(*verb, raw_text, out);
105 }
106 Statement::If {
107 arms,
108 else_body_text,
109 } => {
110 for arm in arms {
111 recurse_body!(&arm.body_text);
112 }
113 if let Some(eb) = else_body_text {
114 recurse_body!(eb);
115 }
116 }
117 Statement::ForLoop {
118 range_text,
119 body_text,
120 ..
121 } => {
122 if let Some(inner) = parenthesised_query(range_text) {
129 recurse_body!(inner);
130 }
131 recurse_body!(body_text);
132 }
133 Statement::WhileLoop { body_text, .. } | Statement::BareLoop { body_text } => {
134 recurse_body!(body_text);
135 }
136 Statement::NestedBlock { body_text } => {
137 let inner = crate::calls::strip_block_wrapper(body_text);
148 if inner != body_text.as_str() {
149 recurse_body!(inner);
150 }
151 }
152 _ => {}
153 }
154 }
155}
156
/// Returns the trimmed text between the outer parentheses of a FOR-loop range
/// such as `(SELECT id FROM src)`.
///
/// Returns `None` when the trimmed range does not both start with `(` and end
/// with `)`, e.g. numeric ranges (`1..10`) or named cursors.
fn parenthesised_query(range_text: &str) -> Option<&str> {
    range_text
        .trim()
        .strip_prefix('(')
        .and_then(|rest| rest.strip_suffix(')'))
        .map(str::trim)
}
166
167fn accesses_from_sql(verb: SqlVerb, raw: &str, out: &mut Vec<TableAccess>) {
168 let upper = crate::fact_emit::mask_string_literals(&raw.to_ascii_uppercase());
175 match verb {
176 SqlVerb::Select => {
177 for t in tables_after(&upper, raw, "FROM") {
178 push(out, t, AccessKind::Read);
179 }
180 for t in tables_after(&upper, raw, "JOIN") {
181 push(out, t, AccessKind::Read);
182 }
183 }
184 SqlVerb::Insert => {
185 for t in tables_after(&upper, raw, "INTO") {
186 push(out, t, AccessKind::Write);
187 }
188 for t in tables_after(&upper, raw, "FROM") {
190 push(out, t, AccessKind::Read);
191 }
192 }
193 SqlVerb::Update => {
194 for t in tables_after(&upper, raw, "UPDATE") {
195 push(out, t, AccessKind::Write);
196 }
197 for t in tables_after(&upper, raw, "FROM") {
198 push(out, t, AccessKind::Read);
199 }
200 }
201 SqlVerb::Delete => {
202 let target = delete_target(raw);
217 let target_folded = target.as_deref().map(folded_name);
218 if let Some(t) = target {
219 push(out, t, AccessKind::Write);
220 }
221 let mut target_consumed = false;
226 for t in tables_after(&upper, raw, "FROM") {
227 if !target_consumed && Some(folded_name(&t)) == target_folded {
228 target_consumed = true;
229 continue;
230 }
231 push(out, t, AccessKind::Read);
232 }
233 for t in tables_after(&upper, raw, "JOIN") {
234 push(out, t, AccessKind::Read);
235 }
236 }
237 SqlVerb::Merge => {
238 for t in tables_after(&upper, raw, "INTO") {
239 push(out, t, AccessKind::Write);
240 }
241 for t in tables_after(&upper, raw, "USING") {
242 push(out, t, AccessKind::Read);
243 }
244 }
245 }
246}
247
/// Case-folds a table name for comparison: ASCII letters are upper-cased and
/// every other character is kept unchanged.
fn folded_name(raw_name: &str) -> String {
    let mut folded = raw_name.to_owned();
    folded.make_ascii_uppercase();
    folded
}
253
254fn delete_target(raw: &str) -> Option<String> {
261 let bytes = raw.as_bytes();
262 let mut i = 0;
264 while i < bytes.len() && (is_ident_byte(bytes[i]) || bytes[i] == b'.') {
265 i += 1;
266 }
267 i = skip_ws(bytes, i);
268 if bytes[i..]
275 .get(..4)
276 .is_some_and(|w| w.eq_ignore_ascii_case(b"FROM"))
277 && (i + 4 >= bytes.len() || !is_ident_byte(bytes[i + 4]))
278 {
279 i = skip_ws(bytes, i + 4);
280 }
281 let start = i;
283 while i < bytes.len() && (is_ident_byte(bytes[i]) || bytes[i] == b'.') {
284 i += 1;
285 }
286 if i > start {
287 Some(raw[start..i].to_string())
288 } else {
289 None
290 }
291}
292
/// Returns the first offset at or after `i` that is not ASCII whitespace.
///
/// Returns `i` unchanged when it is already at or past the end of `bytes`.
fn skip_ws(bytes: &[u8], i: usize) -> usize {
    match bytes.get(i..) {
        Some(rest) => i + rest.iter().take_while(|b| b.is_ascii_whitespace()).count(),
        None => i,
    }
}
299
300fn tables_after(upper: &str, raw: &str, keyword: &str) -> Vec<String> {
310 let mut out = Vec::new();
311 let kw = keyword.to_ascii_uppercase();
312 let traverse_commas = kw == "FROM";
313 let bytes = upper.as_bytes();
314 let mut search = 0;
315 while let Some(rel) = upper[search..].find(&kw) {
316 let abs = search + rel;
317 search = abs + kw.len();
318 let prev_ok = abs == 0 || !is_ident_byte(bytes[abs - 1]);
320 let after = abs + kw.len();
321 let next_ok = after >= bytes.len() || !is_ident_byte(bytes[after]);
322 if !(prev_ok && next_ok) {
323 continue;
324 }
325 let mut i = after;
326 loop {
327 while i < bytes.len() && bytes[i].is_ascii_whitespace() {
329 i += 1;
330 }
331 let start = i;
332 while i < bytes.len() && (is_ident_byte(bytes[i]) || bytes[i] == b'.') {
333 i += 1;
334 }
335 if i == start {
336 break;
337 }
338 out.push(raw[start..i].to_string());
339 if !traverse_commas {
340 break;
341 }
342 while i < bytes.len() && bytes[i].is_ascii_whitespace() {
345 i += 1;
346 }
347 if i < bytes.len() && is_ident_byte(bytes[i]) {
348 while i < bytes.len() && (is_ident_byte(bytes[i]) || bytes[i] == b'.') {
349 i += 1;
350 }
351 }
352 while i < bytes.len() && bytes[i].is_ascii_whitespace() {
355 i += 1;
356 }
357 if i >= bytes.len() || bytes[i] != b',' {
358 break;
359 }
360 i += 1; }
362 }
363 out
364}
365
366fn push(out: &mut Vec<TableAccess>, raw_name: String, access: AccessKind) {
367 let folded = raw_name.to_ascii_uppercase();
368 let (schema, table) = match folded.rsplit_once('.') {
369 Some((s, t)) if !t.is_empty() => (Some(s.to_string()), t.to_string()),
370 _ => (None, folded),
371 };
372 if table.is_empty() || table == "DUAL" {
373 return;
374 }
375 out.push(TableAccess {
376 schema,
377 table,
378 access,
379 });
380}
381
382fn dedup(mut v: Vec<TableAccess>) -> Vec<TableAccess> {
386 let mut seen: std::collections::BTreeSet<(Option<String>, String, AccessKind)> =
387 std::collections::BTreeSet::new();
388 v.retain(|a| seen.insert((a.schema.clone(), a.table.clone(), a.access)));
389 v
390}
391
392impl PartialOrd for AccessKind {
393 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
394 Some(self.cmp(other))
395 }
396}
397impl Ord for AccessKind {
398 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
399 (*self as u8).cmp(&(*other as u8))
400 }
401}
402
// Unit tests. Each case lowers a PL/SQL snippet with `lower_statement_body`
// (or builds statements directly) and checks the accesses reported for it.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lower_statement_body;

    // A plain SELECT … FROM yields exactly one Read of the FROM table.
    #[test]
    fn select_from_is_a_read() {
        let s = lower_statement_body("SELECT id INTO v FROM employees;");
        let a = extract_table_accesses(&s);
        assert_eq!(a.len(), 1);
        assert_eq!(a[0].table, "EMPLOYEES");
        assert_eq!(a[0].access, AccessKind::Read);
    }

    // Comma-separated FROM lists (with or without aliases) must read every table.
    #[test]
    fn legacy_comma_join_reads_every_table() {
        let s = lower_statement_body("SELECT a.x INTO v FROM emp a, dept b WHERE a.d = b.id;");
        let acc = extract_table_accesses(&s);
        for t in ["EMP", "DEPT"] {
            assert!(
                acc.iter()
                    .any(|x| x.table == t && x.access == AccessKind::Read),
                "comma-join must read {t}: {acc:?}"
            );
        }
        let s3 = lower_statement_body("SELECT x INTO v FROM a, b, c WHERE 1 = 1;");
        let acc3 = extract_table_accesses(&s3);
        for t in ["A", "B", "C"] {
            assert!(
                acc3.iter()
                    .any(|x| x.table == t && x.access == AccessKind::Read),
                "comma-join must read {t}: {acc3:?}"
            );
        }
    }

    // Clause keywords quoted inside a string literal must not create tables,
    // while the statement's real target is still reported.
    #[test]
    fn clause_keyword_inside_string_literal_is_not_a_phantom_table() {
        let s = lower_statement_body("UPDATE log SET msg = 'failed to INSERT INTO orders';");
        let acc = extract_table_accesses(&s);
        assert!(
            !acc.iter().any(|x| x.table == "ORDERS"),
            "INTO inside a literal must not mint a phantom ORDERS access: {acc:?}"
        );
        assert!(
            acc.iter()
                .any(|x| x.table == "LOG" && x.access == AccessKind::Write),
            "the real UPDATE target LOG must still be a Write: {acc:?}"
        );
    }

    // Non-ASCII bytes right after DELETE must not cause a panic while the
    // target name is sliced out of the raw text; only absence of a panic is checked.
    #[test]
    fn delete_with_multibyte_first_token_does_not_panic() {
        let s = lower_statement_body("DELETE é★ WHERE x = 1;");
        let _ = extract_table_accesses(&s);
        let s2 = lower_statement_body("DELETE é★"); let _ = extract_table_accesses(&s2);
    }

    // The INSERT INTO target is a Write.
    #[test]
    fn insert_into_is_a_write() {
        let s = lower_statement_body("INSERT INTO audit_log VALUES (1, 2);");
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "AUDIT_LOG" && x.access == AccessKind::Write)
        );
    }

    // INSERT … SELECT writes the target and reads the SELECT's source.
    #[test]
    fn insert_select_records_write_and_read() {
        let s =
            lower_statement_body("INSERT INTO summary SELECT dept_id, COUNT(*) FROM employees;");
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "SUMMARY" && x.access == AccessKind::Write)
        );
        assert!(
            a.iter()
                .any(|x| x.table == "EMPLOYEES" && x.access == AccessKind::Read)
        );
    }

    // The UPDATE target is a Write.
    #[test]
    fn update_is_a_write() {
        let s = lower_statement_body("UPDATE employees SET salary = salary * 1.1;");
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "EMPLOYEES" && x.access == AccessKind::Write)
        );
    }

    // The DELETE FROM target is a Write.
    #[test]
    fn delete_from_is_a_write() {
        let s = lower_statement_body("DELETE FROM stale_rows WHERE id < 100;");
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "STALE_ROWS" && x.access == AccessKind::Write)
        );
    }

    // In DELETE FROM … WHERE … IN (SELECT … FROM staging), only the target is
    // written; the sub-query's table is read and never written.
    #[test]
    fn delete_with_where_subquery_target_is_write_subquery_is_read() {
        let s = lower_statement_body("DELETE FROM t WHERE id IN (SELECT id FROM staging);");
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "T" && x.access == AccessKind::Write),
            "DELETE target T must be a Write: {a:?}"
        );
        assert!(
            a.iter()
                .any(|x| x.table == "STAGING" && x.access == AccessKind::Read),
            "WHERE sub-SELECT table STAGING must be a Read: {a:?}"
        );
        assert!(
            !a.iter()
                .any(|x| x.table == "STAGING" && x.access == AccessKind::Write),
            "STAGING must NEVER be classified as a Write: {a:?}"
        );
    }

    // Oracle's FROM-less `DELETE t` form yields exactly one Write of `t`.
    #[test]
    fn from_less_delete_is_a_write() {
        let s = lower_statement_body("DELETE employees WHERE id = 5;");
        let a = extract_table_accesses(&s);
        assert_eq!(a.len(), 1, "exactly one access expected: {a:?}");
        assert_eq!(a[0].table, "EMPLOYEES");
        assert_eq!(a[0].schema, None);
        assert_eq!(a[0].access, AccessKind::Write);
    }

    // A schema-qualified FROM-less DELETE target is split into schema + table.
    #[test]
    fn from_less_qualified_delete_is_a_write() {
        let s = lower_statement_body("DELETE hr.audit_log WHERE ts < SYSDATE - 30;");
        let a = extract_table_accesses(&s);
        assert_eq!(a.len(), 1, "exactly one access expected: {a:?}");
        assert_eq!(a[0].schema.as_deref(), Some("HR"));
        assert_eq!(a[0].table, "AUDIT_LOG");
        assert_eq!(a[0].access, AccessKind::Write);
    }

    // FROM-less DELETE with a sub-query: target written, sub-query table read only.
    #[test]
    fn from_less_delete_with_where_subquery_target_write_subquery_read() {
        let s = lower_statement_body("DELETE t WHERE id IN (SELECT id FROM staging);");
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "T" && x.access == AccessKind::Write),
            "FROM-less DELETE target T must be a Write: {a:?}"
        );
        assert!(
            a.iter()
                .any(|x| x.table == "STAGING" && x.access == AccessKind::Read),
            "WHERE sub-SELECT table STAGING must be a Read: {a:?}"
        );
        assert!(
            !a.iter()
                .any(|x| x.table == "STAGING" && x.access == AccessKind::Write),
            "STAGING must NEVER be classified as a Write: {a:?}"
        );
    }

    // MERGE writes its INTO target and reads its USING source.
    #[test]
    fn merge_writes_target_reads_source() {
        let s = lower_statement_body(
            "MERGE INTO target t USING source s ON (t.id = s.id) WHEN MATCHED THEN UPDATE SET t.v = s.v;",
        );
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "TARGET" && x.access == AccessKind::Write)
        );
        assert!(
            a.iter()
                .any(|x| x.table == "SOURCE" && x.access == AccessKind::Read)
        );
    }

    // `schema.table` is split into an upper-cased schema and table.
    #[test]
    fn schema_qualified_table_split() {
        let s = lower_statement_body("SELECT 1 INTO v FROM hr.employees;");
        let a = extract_table_accesses(&s);
        assert_eq!(a[0].schema.as_deref(), Some("HR"));
        assert_eq!(a[0].table, "EMPLOYEES");
    }

    // Oracle's DUAL pseudo-table is never reported.
    #[test]
    fn dual_is_filtered_out() {
        let s = lower_statement_body("SELECT SYSDATE INTO v FROM dual;");
        let a = extract_table_accesses(&s);
        assert!(a.is_empty());
    }

    // DML inside a loop body is found via recursion into the body.
    #[test]
    fn loop_body_dml_recursed() {
        let s = lower_statement_body("FOR i IN 1..10 LOOP INSERT INTO log VALUES (i); END LOOP;");
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "LOG" && x.access == AccessKind::Write)
        );
    }

    // A cursor FOR loop's `(SELECT …)` range contributes reads alongside the
    // body's writes.
    #[test]
    fn cursor_for_loop_range_select_table_is_read() {
        let s = lower_statement_body(
            "FOR r IN (SELECT id FROM src) LOOP \
             INSERT INTO dst VALUES (r.id); \
             END LOOP;",
        );
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "SRC" && x.access == AccessKind::Read),
            "cursor-FOR-loop range sub-SELECT read of SRC must be extracted: {a:?}"
        );
        assert!(
            // The range handling must not displace the loop body's own access.
            a.iter()
                .any(|x| x.table == "DST" && x.access == AccessKind::Write),
            "loop body write of DST must still be extracted: {a:?}"
        );
    }

    // A numeric range such as `1..10` must not be mistaken for a query.
    #[test]
    fn numeric_range_for_loop_yields_no_extra_tables() {
        let s = lower_statement_body("FOR i IN 1..10 LOOP NULL; END LOOP;");
        let a = extract_table_accesses(&s);
        assert!(a.is_empty(), "numeric range must not invent tables: {a:?}");
    }

    // The same (schema, table, access) seen twice is reported once.
    #[test]
    fn duplicate_access_triples_dedupe() {
        let s = lower_statement_body("SELECT 1 INTO a FROM t; SELECT 2 INTO b FROM t;");
        let acc = extract_table_accesses(&s);
        assert_eq!(acc.iter().filter(|x| x.table == "T").count(), 1);
    }

    // Tables on both sides of an ANSI JOIN are reads.
    #[test]
    fn join_tables_are_reads() {
        let s = lower_statement_body(
            "SELECT 1 INTO v FROM employees e JOIN departments d ON e.dept = d.id;",
        );
        let a = extract_table_accesses(&s);
        assert!(a.iter().any(|x| x.table == "EMPLOYEES"));
        assert!(a.iter().any(|x| x.table == "DEPARTMENTS"));
        assert!(a.iter().all(|x| x.access == AccessKind::Read));
    }

    // TableAccess round-trips through JSON and AccessKind serialises as snake_case.
    #[test]
    fn serde_round_trip() {
        let s = lower_statement_body("SELECT 1 INTO v FROM t;");
        let a = extract_table_accesses(&s);
        let json = serde_json::to_string(&a[0]).unwrap();
        let back: TableAccess = serde_json::from_str(&json).unwrap();
        assert_eq!(back, a[0]);
        assert!(json.contains("\"access\":\"read\""));
    }

    // A loop body that (presumably) re-lowers to an equivalent loop must stop at
    // the depth cap and report the truncation instead of recursing forever.
    #[test]
    fn non_shrinking_for_update_terminates_and_reports_limit() {
        let stmts = vec![Statement::BareLoop {
            body_text: "FOR UPDATE".to_string(),
        }];
        let (accesses, outcome) = extract_table_accesses_bounded(&stmts);
        assert!(
            outcome.limit_hit,
            "non-shrinking BareLoop must trip the depth cap, \
             outcome={outcome:?}, accesses={accesses:?}"
        );
        assert!(outcome.truncated_bodies >= 1);
        let _ = extract_table_accesses(&stmts);
    }

    // DML inside an anonymous BEGIN … END block is found.
    #[test]
    fn nested_block_update_yields_write_edge() {
        let s = lower_statement_body("BEGIN UPDATE secret_table SET x = 1 WHERE id = 9; END;");
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "SECRET_TABLE" && x.access == AccessKind::Write),
            "a nested-block UPDATE must surface a Write of SECRET_TABLE: {a:?}"
        );
    }

    // DML inside a DECLARE … BEGIN … END block yields both write and read edges.
    #[test]
    fn nested_declare_block_dml_yields_edges() {
        let s = lower_statement_body(
            "DECLARE v NUMBER; BEGIN INSERT INTO audit_log SELECT id FROM staging; END;",
        );
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "AUDIT_LOG" && x.access == AccessKind::Write),
            "nested-block INSERT target must be a Write: {a:?}"
        );
        assert!(
            a.iter()
                .any(|x| x.table == "STAGING" && x.access == AccessKind::Read),
            "nested-block sub-SELECT must be a Read: {a:?}"
        );
    }

    // A nested block inside an IF arm is found through both recursion levels.
    #[test]
    fn if_arm_nested_block_dml_yields_edges() {
        let s = lower_statement_body(
            "IF p_flag = 1 THEN BEGIN UPDATE accounts SET bal = 0 WHERE id = 1; END; END IF;",
        );
        let a = extract_table_accesses(&s);
        assert!(
            a.iter()
                .any(|x| x.table == "ACCOUNTS" && x.access == AccessKind::Write),
            "an IF-arm nested-block UPDATE must surface a Write of ACCOUNTS: {a:?}"
        );
    }
}