1use std::num::NonZeroUsize;
91use std::sync::{Arc, Mutex as StdMutex};
92
93use async_trait::async_trait;
94use limbo::params::Params as LimboParams;
95use limbo::Builder;
96use tokio::sync::Mutex as TokioMutex;
97
98const STMT_CACHE_CAPACITY: usize = 128;
103
104use oxisql_core::{
105 ColumnInfo, Connection, ForeignKeyInfo, IndexInfo, OxiSqlError, PreparedStatement, Row,
106 TableInfo, TableType, ToSqlValue, Transaction, Value,
107};
108
109use crate::error::SqliteCompatError;
110use crate::types::{limbo_to_core, rewrite_params, split_statements};
111
112type StmtCache = Arc<StdMutex<lru::LruCache<String, limbo::Statement>>>;
121
122fn new_stmt_cache() -> StmtCache {
124 let cap = NonZeroUsize::new(STMT_CACHE_CAPACITY).unwrap_or(NonZeroUsize::MIN);
127 Arc::new(StdMutex::new(lru::LruCache::new(cap)))
128}
129
130async fn exec_rewritten(
139 conn: &limbo::Connection,
140 sql: &str,
141 limbo_params: Vec<limbo::Value>,
142 cache: Option<&StmtCache>,
143) -> Result<u64, SqliteCompatError> {
144 let lp = if limbo_params.is_empty() {
145 LimboParams::None
146 } else {
147 LimboParams::Positional(limbo_params)
148 };
149
150 if let Some(c) = cache {
164 let is_cached = c
166 .lock()
167 .map_err(|e| SqliteCompatError::Other(format!("stmt_cache lock poisoned: {e}")))?
168 .contains(sql);
169
170 if !is_cached {
171 let fresh = conn.prepare(sql).await.map_err(SqliteCompatError::from)?;
173 c.lock()
174 .map_err(|e| SqliteCompatError::Other(format!("stmt_cache lock poisoned: {e}")))?
175 .put(sql.to_owned(), fresh);
176 }
177 }
178
179 conn.execute(sql, lp)
181 .await
182 .map_err(SqliteCompatError::from)?;
183
184 let changes = fetch_scalar_i64(conn, "SELECT changes()").await?;
188 Ok(changes.max(0) as u64)
189}
190
191async fn query_rewritten(
194 conn: &limbo::Connection,
195 sql: &str,
196 limbo_params: Vec<limbo::Value>,
197) -> Result<Vec<Row>, SqliteCompatError> {
198 let lp = if limbo_params.is_empty() {
199 LimboParams::None
200 } else {
201 LimboParams::Positional(limbo_params)
202 };
203
204 let mut stmt = conn.prepare(sql).await.map_err(SqliteCompatError::from)?;
205 let cols: Vec<String> = stmt.columns().iter().map(|c| c.name().to_owned()).collect();
206 let mut rows_iter = stmt.query(lp).await.map_err(SqliteCompatError::from)?;
207
208 let mut rows: Vec<Row> = Vec::new();
209 while let Some(limbo_row) = rows_iter.next().await.map_err(SqliteCompatError::from)? {
210 let mut values: Vec<Value> = Vec::with_capacity(cols.len());
211 for idx in 0..limbo_row.column_count() {
212 let raw = limbo_row.get_value(idx).map_err(SqliteCompatError::from)?;
213 values.push(limbo_to_core(raw)?);
214 }
215 rows.push(Row::new(cols.clone(), values));
216 }
217 Ok(rows)
218}
219
220async fn fetch_scalar_i64(conn: &limbo::Connection, sql: &str) -> Result<i64, SqliteCompatError> {
222 let rows = query_rewritten(conn, sql, vec![]).await?;
223 if let Some(row) = rows.first() {
224 match row.get_by_index(0) {
225 Some(Value::I64(n)) => return Ok(*n),
226 Some(Value::Null) => return Ok(0),
227 Some(other) => {
228 return Err(SqliteCompatError::TypeMap(format!(
229 "expected i64 from scalar query, got {other:?}"
230 )))
231 }
232 None => {}
233 }
234 }
235 Ok(0)
236}
237
/// A connection to a limbo database, opened from a path string.
///
/// Cloning copies the limbo handle and the path and shares the same transaction lock
/// and statement cache, since both are held behind `Arc`.
#[derive(Clone)]
pub struct SqliteConnection {
    /// Underlying limbo connection handle used for every statement.
    conn: limbo::Connection,
    /// Async mutex locked by `transaction()`; the guard lives inside the returned
    /// transaction object, so only one transaction holds it at a time.
    txn_lock: Arc<TokioMutex<()>>,
    /// Statement cache keyed by rewritten SQL text, shared with transactions and
    /// prepared statements created from this connection.
    stmt_cache: StmtCache,
    /// The path string this connection was opened with (returned by `path()`).
    path: String,
}
259
260impl std::fmt::Debug for SqliteConnection {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 let cache_len = self.stmt_cache.lock().map(|g| g.len()).unwrap_or(0);
263 f.debug_struct("SqliteConnection")
264 .field("path", &self.path)
265 .field("stmt_cache_len", &cache_len)
266 .finish_non_exhaustive()
267 }
268}
269
270impl SqliteConnection {
271 pub async fn open(path: &str) -> Result<Self, OxiSqlError> {
280 let db = Builder::new_local(path)
281 .build()
282 .await
283 .map_err(|e| OxiSqlError::Other(format!("limbo open error: {e}")))?;
284 let conn = db
285 .connect()
286 .map_err(|e| OxiSqlError::Other(format!("limbo connect error: {e}")))?;
287 Ok(Self {
288 conn,
289 txn_lock: Arc::new(TokioMutex::new(())),
290 stmt_cache: new_stmt_cache(),
291 path: path.to_owned(),
292 })
293 }
294
295 pub async fn open_memory() -> Result<Self, OxiSqlError> {
301 Self::open(":memory:").await
302 }
303
304 pub fn path(&self) -> &str {
306 &self.path
307 }
308}
309
310#[async_trait]
313impl Connection for SqliteConnection {
314 async fn execute(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<u64, OxiSqlError> {
315 let (rewritten, limbo_params) = rewrite_params(sql, params).map_err(OxiSqlError::from)?;
316 exec_rewritten(&self.conn, &rewritten, limbo_params, Some(&self.stmt_cache))
317 .await
318 .map_err(OxiSqlError::from)
319 }
320
321 async fn query(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<Vec<Row>, OxiSqlError> {
322 let (rewritten, limbo_params) = rewrite_params(sql, params).map_err(OxiSqlError::from)?;
323 query_rewritten(&self.conn, &rewritten, limbo_params)
324 .await
325 .map_err(OxiSqlError::from)
326 }
327
328 async fn transaction(&self) -> Result<Box<dyn Transaction + '_>, OxiSqlError> {
329 let guard = self.txn_lock.lock().await;
333 self.conn
334 .execute("BEGIN", LimboParams::None)
335 .await
336 .map_err(|e| OxiSqlError::Other(format!("BEGIN failed: {e}")))?;
337 Ok(Box::new(SqliteTransaction {
338 conn: self.conn.clone(),
339 stmt_cache: Arc::clone(&self.stmt_cache),
342 _guard: guard,
345 done: false,
346 }))
347 }
348
349 async fn execute_batch(&self, sql: &str) -> Result<u64, OxiSqlError> {
350 let stmts = split_statements(sql);
353 let mut total = 0u64;
354 for stmt in stmts {
355 total += self.execute(stmt, &[]).await?;
356 }
357 Ok(total)
358 }
359
360 async fn ping(&self) -> Result<(), OxiSqlError> {
361 self.query("SELECT 1", &[]).await?;
362 Ok(())
363 }
364
365 async fn prepare(&self, sql: &str) -> Result<Box<dyn PreparedStatement + '_>, OxiSqlError> {
366 Ok(Box::new(SqlitePrepared {
367 conn: &self.conn,
368 stmt_cache: Arc::clone(&self.stmt_cache),
369 sql: sql.to_owned(),
370 }))
371 }
372
373 async fn tables(&self) -> Result<Vec<TableInfo>, OxiSqlError> {
376 let rows = self
377 .query(
378 "SELECT name, type FROM sqlite_master \
379 WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' \
380 ORDER BY name",
381 &[],
382 )
383 .await?;
384
385 let infos = rows
386 .into_iter()
387 .map(|row| {
388 let name = row
389 .get_by_index(0)
390 .and_then(|v| {
391 if let Value::Text(s) = v {
392 Some(s.clone())
393 } else {
394 None
395 }
396 })
397 .unwrap_or_default();
398 let ttype_str = row
399 .get_by_index(1)
400 .and_then(|v| {
401 if let Value::Text(s) = v {
402 Some(s.as_str())
403 } else {
404 None
405 }
406 })
407 .unwrap_or("table");
408 let table_type = match ttype_str {
409 "view" => TableType::View,
410 _ => TableType::Base,
411 };
412 TableInfo {
413 name,
414 schema: None,
415 table_type,
416 }
417 })
418 .collect();
419 Ok(infos)
420 }
421
422 async fn columns(&self, table: &str) -> Result<Vec<ColumnInfo>, OxiSqlError> {
423 let sql = format!("PRAGMA table_info(\"{table}\")");
425 let rows = self.query(&sql, &[]).await?;
426
427 let infos = rows
428 .into_iter()
429 .map(|row| {
430 let text_at = |r: &Row, idx: usize| -> String {
432 r.get_by_index(idx)
433 .and_then(|v| match v {
434 Value::Text(s) => Some(s.clone()),
435 Value::I64(n) => Some(n.to_string()),
436 Value::Null => Some(String::new()),
437 _ => None,
438 })
439 .unwrap_or_default()
440 };
441 let i64_at = |r: &Row, idx: usize| -> i64 {
442 r.get_by_index(idx)
443 .and_then(|v| {
444 if let Value::I64(n) = v {
445 Some(*n)
446 } else {
447 None
448 }
449 })
450 .unwrap_or(0)
451 };
452
453 let ordinal = i64_at(&row, 0) as u32 + 1; let name = text_at(&row, 1);
455 let data_type = text_at(&row, 2);
456 let notnull = i64_at(&row, 3) != 0;
457 let default_val = row.get_by_index(4).and_then(|v| match v {
458 Value::Text(s) => Some(s.clone()),
459 Value::Null => None,
460 other => Some(format!("{other:?}")),
461 });
462
463 ColumnInfo {
464 name,
465 ordinal_position: ordinal,
466 data_type,
467 nullable: !notnull,
468 default: default_val,
469 max_length: None,
470 numeric_precision: None,
471 numeric_scale: None,
472 }
473 })
474 .collect();
475 Ok(infos)
476 }
477
478 async fn indexes(&self, table: &str) -> Result<Vec<IndexInfo>, OxiSqlError> {
479 let sql = "SELECT name, sql FROM sqlite_master \
484 WHERE type='index' AND tbl_name=$1 AND name NOT LIKE 'sqlite_%'";
485 let rows = self.query(sql, &[&table]).await?;
486
487 let mut infos: Vec<IndexInfo> = Vec::new();
488 for row in rows {
489 let name = row
490 .get_by_index(0)
491 .and_then(|v| {
492 if let Value::Text(s) = v {
493 Some(s.clone())
494 } else {
495 None
496 }
497 })
498 .unwrap_or_default();
499 let idx_sql = row
500 .get_by_index(1)
501 .and_then(|v| {
502 if let Value::Text(s) = v {
503 Some(s.clone())
504 } else {
505 None
506 }
507 })
508 .unwrap_or_default();
509
510 let upper = idx_sql.to_ascii_uppercase();
512 let unique = upper.contains("UNIQUE");
513
514 let columns: Vec<String> =
516 if let (Some(open), Some(close)) = (idx_sql.rfind('('), idx_sql.rfind(')')) {
517 idx_sql[open + 1..close]
518 .split(',')
519 .map(|c| c.trim().to_string())
520 .filter(|c| !c.is_empty())
521 .collect()
522 } else {
523 vec![]
524 };
525
526 infos.push(IndexInfo {
527 name,
528 columns,
529 unique,
530 primary: false,
531 });
532 }
533 Ok(infos)
534 }
535
536 async fn foreign_keys(&self, table: &str) -> Result<Vec<ForeignKeyInfo>, OxiSqlError> {
537 let sql = "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?";
541 let rows = query_rewritten(&self.conn, sql, vec![limbo::Value::Text(table.into())])
542 .await
543 .map_err(OxiSqlError::from)?;
544
545 let ddl = match rows.first() {
546 Some(row) => match row.get_by_index(0) {
547 Some(Value::Text(s)) if !s.is_empty() => s.clone(),
548 _ => return Ok(vec![]),
549 },
550 None => return Ok(vec![]),
551 };
552
553 Ok(parse_foreign_keys(&ddl, table))
554 }
555}
556
/// Trims `s` and removes one pair of surrounding SQL identifier quotes.
///
/// Recognised pairs are `"…"`, `` `…` `` and `[…]`. When the text does not start with
/// an opening quote, or the matching closing character is not the last character,
/// the trimmed text is returned unchanged. Doubled quotes inside are not unescaped.
fn strip_sql_quotes(s: &str) -> &str {
    let trimmed = s.trim();
    let closer = match trimmed.chars().next() {
        Some('"') => '"',
        Some('`') => '`',
        Some('[') => ']',
        _ => return trimmed,
    };
    // The opening quote is a single ASCII byte, so index 1 is a char boundary.
    trimmed[1..].strip_suffix(closer).unwrap_or(trimmed)
}
580
/// Splits `text` at commas that are not nested inside parentheses.
///
/// Pieces are returned untrimmed and in order; the result always has at least one
/// element (the whole text when there is no top-level comma). An unmatched `)` does
/// not drive the nesting level below zero.
fn split_top_level_commas(text: &str) -> Vec<&str> {
    let mut pieces: Vec<&str> = Vec::new();
    let mut nesting = 0usize;
    let mut piece_start = 0usize;

    // The three delimiters are ASCII, so scanning bytes is safe: they never occur
    // inside a multi-byte UTF-8 sequence, and their offsets are char boundaries.
    for (pos, byte) in text.bytes().enumerate() {
        if byte == b'(' {
            nesting += 1;
        } else if byte == b')' {
            nesting = nesting.saturating_sub(1);
        } else if byte == b',' && nesting == 0 {
            pieces.push(&text[piece_start..pos]);
            piece_start = pos + 1;
        }
    }

    pieces.push(&text[piece_start..]);
    pieces
}
602
/// Locates the first `(` at or after byte `offset` in `s` and returns the byte range
/// of the text between it and its matching `)`.
///
/// The returned pair is `(start, end)` such that `&s[start..end]` is the content
/// without the enclosing parentheses; nested parentheses are balanced.
///
/// Returns `None` when there is no `(` at or after `offset`, when the parenthesis is
/// never closed, or when `offset` is past the end of `s` or not on a character
/// boundary (instead of panicking on the slice).
fn find_paren_content(s: &str, offset: usize) -> Option<(usize, usize)> {
    let tail = s.get(offset..)?;
    let open = offset + tail.find('(')?;

    let mut depth = 0usize;
    for (i, ch) in s[open..].char_indices() {
        match ch {
            '(' => depth += 1,
            ')' => {
                // The scan starts on '(' so depth is at least 1 here; saturating keeps
                // that invariant from ever turning into an underflow panic.
                depth = depth.saturating_sub(1);
                if depth == 0 {
                    return Some((open + 1, open + i));
                }
            }
            _ => {}
        }
    }
    None
}
627
628fn parse_references_clause(
633 upper: &str,
634 original: &str,
635 pos: usize,
636) -> Option<(String, Vec<String>)> {
637 let rest_upper = upper[pos..].trim_start();
638 if !rest_upper.starts_with("REFERENCES") {
639 return None;
640 }
641 let consumed_ws = upper[pos..].len() - upper[pos..].trim_start().len();
642 let after_ref = pos + consumed_ws + "REFERENCES".len();
643
644 let rest_orig = original[after_ref..].trim_start();
646 let ws_skip = original[after_ref..].len() - original[after_ref..].trim_start().len();
647 let table_start = after_ref + ws_skip;
648
649 let table_end = rest_orig
651 .find(|c: char| c.is_whitespace() || c == '(' || c == ',' || c == ')')
652 .map(|p| table_start + p)
653 .unwrap_or(original.len());
654
655 let raw_table = strip_sql_quotes(&original[table_start..table_end]).to_owned();
656 if raw_table.is_empty() {
657 return None;
658 }
659
660 let mut cols: Vec<String> = Vec::new();
662 let paren_search_start = table_end;
663 let rest_after_table = upper[paren_search_start..].trim_start();
664 if rest_after_table.starts_with('(') {
665 let ws2 =
666 upper[paren_search_start..].len() - upper[paren_search_start..].trim_start().len();
667 let abs_open_search = paren_search_start + ws2;
668 if let Some((inner_start, inner_end)) = find_paren_content(original, abs_open_search) {
669 let inner = &original[inner_start..inner_end];
670 for part in split_top_level_commas(inner) {
671 let col = strip_sql_quotes(part.trim()).to_owned();
672 if !col.is_empty() {
673 cols.push(col);
674 }
675 }
676 }
677 }
678
679 Some((raw_table, cols))
680}
681
682fn parse_foreign_keys(ddl: &str, table: &str) -> Vec<ForeignKeyInfo> {
691 let ddl = ddl.replace('\r', " ");
694
695 let body_range = match find_paren_content(&ddl, 0) {
697 Some(r) => r,
698 None => return vec![],
699 };
700 let body = &ddl[body_range.0..body_range.1];
701 let body_upper = body.to_ascii_uppercase();
702
703 let mut results: Vec<ForeignKeyInfo> = Vec::new();
704
705 let mut search_pos = 0usize;
708 while let Some(rel) = body_upper[search_pos..].find("FOREIGN KEY") {
709 let fk_pos = search_pos + rel;
710
711 let constraint_name: Option<String> = {
713 let before = body[..fk_pos].trim_end();
714 let before_upper = before.to_ascii_uppercase();
715 if let Some(c_rel) = before_upper.rfind("CONSTRAINT") {
716 let after_constraint = before[c_rel + "CONSTRAINT".len()..].trim_start();
717 let name_end = after_constraint
718 .find(|c: char| c.is_whitespace() || c == '(' || c == ',')
719 .unwrap_or(after_constraint.len());
720 let raw = strip_sql_quotes(&after_constraint[..name_end]);
721 if !raw.is_empty() {
722 Some(raw.to_owned())
723 } else {
724 None
725 }
726 } else {
727 None
728 }
729 };
730
731 let after_fk = fk_pos + "FOREIGN KEY".len();
733 let paren_start_search = {
734 let ws = body[after_fk..].len() - body[after_fk..].trim_start().len();
735 after_fk + ws
736 };
737
738 let (local_cols, refs_search_start) =
739 if let Some((inner_s, inner_e)) = find_paren_content(body, paren_start_search) {
740 let cols: Vec<String> = split_top_level_commas(&body[inner_s..inner_e])
741 .into_iter()
742 .map(|c| strip_sql_quotes(c.trim()).to_owned())
743 .filter(|c| !c.is_empty())
744 .collect();
745 (cols, inner_e + 1)
746 } else {
747 search_pos = fk_pos + "FOREIGN KEY".len();
748 continue;
749 };
750
751 let refs_pos = {
753 let ws = body_upper[refs_search_start..].len()
754 - body_upper[refs_search_start..].trim_start().len();
755 refs_search_start + ws
756 };
757
758 let (foreign_table, foreign_cols) =
760 match parse_references_clause(&body_upper, body, refs_pos) {
761 Some(v) => v,
762 None => {
763 search_pos = fk_pos + "FOREIGN KEY".len();
764 continue;
765 }
766 };
767
768 let first_col = local_cols.first().map(String::as_str).unwrap_or("col");
772 let shared_cname = constraint_name
773 .clone()
774 .unwrap_or_else(|| format!("fk_{table}_{first_col}"));
775
776 for (idx, local_col) in local_cols.iter().enumerate() {
778 let foreign_col = foreign_cols.get(idx).cloned().unwrap_or_default();
779 results.push(ForeignKeyInfo {
780 constraint_name: shared_cname.clone(),
781 column: local_col.clone(),
782 foreign_table: foreign_table.clone(),
783 foreign_column: foreign_col,
784 });
785 }
786
787 search_pos = fk_pos + "FOREIGN KEY".len();
788 }
789
790 for segment in split_top_level_commas(body) {
795 let seg_trimmed = segment.trim();
796 let seg_upper = seg_trimmed.to_ascii_uppercase();
797
798 if seg_upper.trim_start().starts_with("FOREIGN KEY")
800 || seg_upper.trim_start().starts_with("CONSTRAINT")
801 || seg_upper.trim_start().starts_with("PRIMARY KEY")
802 || seg_upper.trim_start().starts_with("UNIQUE")
803 || seg_upper.trim_start().starts_with("CHECK")
804 {
805 continue;
806 }
807
808 let ref_rel = match seg_upper.find("REFERENCES") {
810 Some(p) => p,
811 None => continue,
812 };
813
814 let col_name = {
816 let first_token_end = seg_trimmed
817 .find(|c: char| c.is_whitespace())
818 .unwrap_or(seg_trimmed.len());
819 strip_sql_quotes(&seg_trimmed[..first_token_end]).to_owned()
820 };
821 if col_name.is_empty() {
822 continue;
823 }
824
825 let (foreign_table, foreign_cols) =
827 match parse_references_clause(&seg_upper, seg_trimmed, ref_rel) {
828 Some(v) => v,
829 None => continue,
830 };
831
832 let foreign_col = foreign_cols.into_iter().next().unwrap_or_default();
833 let cname = format!("fk_{table}_{col_name}");
834
835 let already = results.iter().any(|r| r.column == col_name);
837 if !already {
838 results.push(ForeignKeyInfo {
839 constraint_name: cname,
840 column: col_name,
841 foreign_table,
842 foreign_column: foreign_col,
843 });
844 }
845 }
846
847 results
848}
849
/// An open transaction created by `SqliteConnection::transaction`, which has already
/// issued `BEGIN` on the connection.
///
/// Dropping the value without a prior `commit()` or `rollback()` call triggers a
/// best-effort `ROLLBACK` (see the `Drop` impl).
pub struct SqliteTransaction<'a> {
    /// Clone of the parent connection's limbo handle; all statements run on it.
    conn: limbo::Connection,
    /// Statement cache shared with the parent connection.
    stmt_cache: StmtCache,
    /// Guard of the parent's transaction lock, held for as long as this value lives.
    _guard: tokio::sync::MutexGuard<'a, ()>,
    /// Set by `commit()` and `rollback()`; when true, `Drop` does nothing.
    done: bool,
}
874
875impl<'a> Drop for SqliteTransaction<'a> {
876 fn drop(&mut self) {
877 if !self.done {
878 let conn = self.conn.clone();
888 tokio::spawn(async move {
889 if let Err(e) = conn.execute("ROLLBACK", LimboParams::None).await {
890 log::warn!(
891 "SqliteTransaction drop: ROLLBACK failed (expected with limbo \
892 0.0.22 which does not implement ROLLBACK): {e}"
893 );
894 }
895 });
896 }
897 }
898}
899
900#[async_trait]
901impl<'a> Transaction for SqliteTransaction<'a> {
902 async fn execute(&mut self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<u64, OxiSqlError> {
903 let (rewritten, limbo_params) = rewrite_params(sql, params).map_err(OxiSqlError::from)?;
904 exec_rewritten(&self.conn, &rewritten, limbo_params, Some(&self.stmt_cache))
905 .await
906 .map_err(OxiSqlError::from)
907 }
908
909 async fn query(
910 &mut self,
911 sql: &str,
912 params: &[&dyn ToSqlValue],
913 ) -> Result<Vec<Row>, OxiSqlError> {
914 let (rewritten, limbo_params) = rewrite_params(sql, params).map_err(OxiSqlError::from)?;
915 query_rewritten(&self.conn, &rewritten, limbo_params)
916 .await
917 .map_err(OxiSqlError::from)
918 }
919
920 async fn commit(mut self: Box<Self>) -> Result<(), OxiSqlError> {
921 self.done = true;
922 self.conn
923 .execute("COMMIT", LimboParams::None)
924 .await
925 .map_err(|e| OxiSqlError::Other(format!("COMMIT failed: {e}")))?;
926 Ok(())
927 }
928
929 async fn rollback(mut self: Box<Self>) -> Result<(), OxiSqlError> {
930 self.done = true;
932 Err(OxiSqlError::Other(
938 "ROLLBACK is not supported by the limbo 0.0.22 engine; \
939 this transaction cannot be rolled back — upgrade to limbo 0.1+ \
940 when available"
941 .to_owned(),
942 ))
943 }
944}
945
/// Prepared-statement handle created by `SqliteConnection::prepare`.
///
/// It stores the SQL text as given by the caller; the placeholders are rewritten
/// each time the statement is executed or queried.
pub struct SqlitePrepared<'a> {
    /// Borrowed limbo handle of the connection that created this statement.
    conn: &'a limbo::Connection,
    /// Statement cache shared with the parent connection.
    stmt_cache: StmtCache,
    /// SQL text as passed to `prepare`, before placeholder rewriting.
    sql: String,
}
959
960#[async_trait]
961impl<'a> PreparedStatement for SqlitePrepared<'a> {
962 async fn execute(&mut self, params: &[&dyn ToSqlValue]) -> Result<u64, OxiSqlError> {
963 let (rewritten, limbo_params) =
964 rewrite_params(&self.sql, params).map_err(OxiSqlError::from)?;
965 exec_rewritten(self.conn, &rewritten, limbo_params, Some(&self.stmt_cache))
966 .await
967 .map_err(OxiSqlError::from)
968 }
969
970 async fn query(&mut self, params: &[&dyn ToSqlValue]) -> Result<Vec<Row>, OxiSqlError> {
971 let (rewritten, limbo_params) =
972 rewrite_params(&self.sql, params).map_err(OxiSqlError::from)?;
973 query_rewritten(self.conn, &rewritten, limbo_params)
974 .await
975 .map_err(OxiSqlError::from)
976 }
977
978 fn sql(&self) -> &str {
979 &self.sql
980 }
981}
982
/// Unit tests for the `CREATE TABLE` foreign-key scanner.
#[cfg(test)]
mod fk_tests {
    use super::parse_foreign_keys;

    /// A single column-level `REFERENCES` clause yields one entry.
    #[test]
    fn test_single_column_level_fk() {
        let ddl = "CREATE TABLE orders (\
                   id INTEGER PRIMARY KEY,\
                   customer_id INTEGER REFERENCES customers(id)\
                   )";
        let fks = parse_foreign_keys(ddl, "orders");
        assert_eq!(fks.len(), 1, "expected 1 FK, got {fks:?}");
        let fk = &fks[0];
        assert_eq!(fk.column, "customer_id");
        assert_eq!(fk.foreign_table, "customers");
        assert_eq!(fk.foreign_column, "id");
    }

    /// A table-level `FOREIGN KEY (...) REFERENCES ...` clause yields one entry.
    #[test]
    fn test_table_level_fk() {
        let ddl = "CREATE TABLE orders (\
                   id INTEGER PRIMARY KEY,\
                   cust_id INTEGER,\
                   FOREIGN KEY (cust_id) REFERENCES customers(id)\
                   )";
        let fks = parse_foreign_keys(ddl, "orders");
        assert_eq!(fks.len(), 1, "expected 1 FK, got {fks:?}");
        let fk = &fks[0];
        assert_eq!(fk.column, "cust_id");
        assert_eq!(fk.foreign_table, "customers");
        assert_eq!(fk.foreign_column, "id");
    }

    /// A composite key yields one entry per column, all sharing a constraint name.
    #[test]
    fn test_composite_fk() {
        let ddl = "CREATE TABLE orders (\
                   a INTEGER,\
                   b INTEGER,\
                   FOREIGN KEY (a, b) REFERENCES parent(x, y)\
                   )";
        let fks = parse_foreign_keys(ddl, "orders");
        assert_eq!(
            fks.len(),
            2,
            "expected 2 entries for composite FK, got {fks:?}"
        );
        let (first, second) = (&fks[0], &fks[1]);
        assert_eq!(first.column, "a");
        assert_eq!(first.foreign_column, "x");
        assert_eq!(second.column, "b");
        assert_eq!(second.foreign_column, "y");
        assert_eq!(first.constraint_name, second.constraint_name);
    }

    /// Two independent column-level references are both reported.
    #[test]
    fn test_multiple_fks() {
        let ddl = "CREATE TABLE items (\
                   id INTEGER PRIMARY KEY,\
                   category_id INTEGER REFERENCES categories(id),\
                   supplier_id INTEGER REFERENCES suppliers(sid)\
                   )";
        let fks = parse_foreign_keys(ddl, "items");
        assert_eq!(fks.len(), 2, "expected 2 FKs, got {fks:?}");
        let col_names: Vec<&str> = fks.iter().map(|f| f.column.as_str()).collect();
        assert!(col_names.contains(&"category_id"), "missing category_id FK");
        assert!(col_names.contains(&"supplier_id"), "missing supplier_id FK");
    }

    /// Double-quoted and backtick-quoted identifiers are unquoted in the output.
    #[test]
    fn test_quoted_identifiers() {
        let ddl = r#"CREATE TABLE "orders" (
            "cust_id" INTEGER REFERENCES `customers`("id")
        )"#;
        let fks = parse_foreign_keys(ddl, "orders");
        assert_eq!(
            fks.len(),
            1,
            "expected 1 FK from quoted identifiers, got {fks:?}"
        );
        let fk = &fks[0];
        assert_eq!(fk.column, "cust_id");
        assert_eq!(fk.foreign_table, "customers");
        assert_eq!(fk.foreign_column, "id");
    }

    /// Trailing actions such as `ON DELETE CASCADE` do not disturb the parsed names.
    #[test]
    fn test_on_delete_cascade() {
        let ddl = "CREATE TABLE orders (\
                   id INTEGER PRIMARY KEY,\
                   customer_id INTEGER NOT NULL REFERENCES customers(id) ON DELETE CASCADE\
                   )";
        let fks = parse_foreign_keys(ddl, "orders");
        assert_eq!(
            fks.len(),
            1,
            "ON DELETE CASCADE must not corrupt output; got {fks:?}"
        );
        let fk = &fks[0];
        assert_eq!(fk.column, "customer_id");
        assert_eq!(fk.foreign_table, "customers");
        assert_eq!(fk.foreign_column, "id");
    }

    /// An explicit `CONSTRAINT <name>` is carried into the result.
    #[test]
    fn test_constraint_name() {
        let ddl = "CREATE TABLE orders (\
                   id INTEGER PRIMARY KEY,\
                   cust_id INTEGER,\
                   CONSTRAINT fk_orders_cust FOREIGN KEY (cust_id) REFERENCES customers(id)\
                   )";
        let fks = parse_foreign_keys(ddl, "orders");
        assert_eq!(fks.len(), 1, "expected 1 FK, got {fks:?}");
        assert_eq!(
            fks[0].constraint_name, "fk_orders_cust",
            "constraint name should be preserved"
        );
    }

    /// `REFERENCES table` without a column list leaves the foreign column empty.
    #[test]
    fn test_implicit_pk_ref() {
        let ddl = "CREATE TABLE orders (\
                   id INTEGER PRIMARY KEY,\
                   customer_id INTEGER REFERENCES customers\
                   )";
        let fks = parse_foreign_keys(ddl, "orders");
        assert_eq!(
            fks.len(),
            1,
            "expected 1 FK for implicit PK ref, got {fks:?}"
        );
        let fk = &fks[0];
        assert_eq!(fk.foreign_table, "customers");
        assert_eq!(
            fk.foreign_column, "",
            "implicit PK ref should have empty foreign_column"
        );
    }

    /// Parenthesised type arguments are not mistaken for a reference.
    #[test]
    fn test_decimal_type_no_false_fk() {
        let ddl = "CREATE TABLE products (\
                   id INTEGER PRIMARY KEY,\
                   price DECIMAL(10,2) NOT NULL\
                   )";
        let fks = parse_foreign_keys(ddl, "products");
        assert!(
            fks.is_empty(),
            "DECIMAL(10,2) must not be mistaken for a FK, got {fks:?}"
        );
    }
}