Skip to main content

sqlx_gen/introspect/
sqlite.rs

1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::SqlitePool;
5
6use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(pool: &SqlitePool, include_views: bool) -> Result<SchemaInfo> {
9    let mut tables = fetch_tables(pool).await?;
10    let mut views = if include_views {
11        fetch_views(pool).await?
12    } else {
13        Vec::new()
14    };
15
16    if !views.is_empty() {
17        resolve_view_nullability(&mut views, &tables);
18        resolve_view_primary_keys(&mut views, &tables);
19    }
20
21    let enums = extract_check_enums(pool, &mut tables).await?;
22
23    Ok(SchemaInfo {
24        tables,
25        views,
26        enums,
27        composite_types: Vec::new(),
28        domains: Vec::new(),
29    })
30}
31
32/// Detect SQLite "implicit enum" columns of the form
33/// `TEXT CHECK (col IN ('a', 'b', 'c'))` by parsing the DDL stored in
34/// `sqlite_master.sql`. Promotes the column's `udt_name` to the enum's
35/// synthesised name (`<table>_<col>_enum`) so the rest of the pipeline
36/// treats it like a real enum (with PgHasArrayType skipped for SQLite).
37async fn extract_check_enums(pool: &SqlitePool, tables: &mut [TableInfo]) -> Result<Vec<EnumInfo>> {
38    let mut enums = Vec::new();
39
40    for table in tables.iter_mut() {
41        let sql: Option<(Option<String>,)> =
42            sqlx::query_as("SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?")
43                .bind(&table.name)
44                .fetch_optional(pool)
45                .await?;
46        let Some((Some(ddl),)) = sql else { continue };
47
48        for col in table.columns.iter_mut() {
49            if let Some(variants) = parse_check_in_variants(&ddl, &col.name) {
50                if variants.is_empty() {
51                    continue;
52                }
53                let enum_name = format!("{}_{}_enum", table.name, col.name);
54                col.udt_name = enum_name.clone();
55                enums.push(EnumInfo {
56                    schema_name: "main".to_string(),
57                    name: enum_name,
58                    variants,
59                    default_variant: None,
60                });
61            }
62        }
63    }
64
65    Ok(enums)
66}
67
/// Parse `CHECK (col IN ('a', 'b', 'c'))` for `column` out of a SQLite
/// `CREATE TABLE` statement.
///
/// Returns the string variants, in declaration order, of the first CHECK
/// constraint containing `<column> IN (<string literals>)`. Returns `None`
/// when no such constraint exists, or when a CHECK body has unbalanced
/// parentheses.
///
/// The DDL is tokenized rather than searched as raw text, which gives the
/// following behavior:
/// * The column must be the identifier immediately before `IN`. It is
///   compared ASCII-case-insensitively, like SQLite identifiers, and may be
///   quoted (`"col"`, `` `col` ``, `[col]`) or table-qualified (`t.col`).
///   `id` does NOT match `valid IN (...)`.
/// * `col NOT IN (...)` and `NOT col IN (...)` are skipped, because they list
///   forbidden values, not allowed ones.
/// * `IN('a')` written without a space is recognised.
/// * SQL-escaped quotes are decoded (`'it''s'` becomes `it's`).
/// * Lists containing anything but string literals (e.g. `IN (0, 1)`) are not
///   enums, and the search continues with later constraints.
/// * `check` inside identifiers (`checked_at`), string literals or comments
///   is not mistaken for the CHECK keyword.
fn parse_check_in_variants(ddl: &str, column: &str) -> Option<Vec<String>> {
    let tokens = tokenize_sql(ddl);
    let column = column.to_ascii_lowercase();

    let mut i = 0;
    while i < tokens.len() {
        let opens_check = matches!(&tokens[i], SqlToken::Word(w) if w == "check")
            && matches!(tokens.get(i + 1), Some(SqlToken::Punct('(')));
        if !opens_check {
            i += 1;
            continue;
        }

        let body_start = i + 2;
        let body_end = matching_paren(&tokens, body_start)?;
        if let Some(variants) = find_in_list(&tokens[body_start..body_end], &column) {
            return Some(variants);
        }
        i = body_end + 1;
    }

    None
}

/// A lexical token of SQLite DDL, just detailed enough to find `CHECK` lists.
#[derive(Debug, Clone, PartialEq, Eq)]
enum SqlToken {
    /// Bare word (keyword, identifier or number), ASCII-lowercased.
    Word(String),
    /// Quoted identifier (`"x"`, `` `x` ``, `[x]`), unquoted and ASCII-lowercased.
    Quoted(String),
    /// Single-quoted string literal with `''` escapes decoded; case preserved.
    Str(String),
    /// Any other single non-whitespace character.
    Punct(char),
}

/// Split `sql` into [`SqlToken`]s, dropping whitespace and `--` / `/* */` comments.
fn tokenize_sql(sql: &str) -> Vec<SqlToken> {
    let mut tokens = Vec::new();
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            _ if c.is_whitespace() => {}
            '-' if chars.peek() == Some(&'-') => {
                // Line comment: discard through the end of the line.
                for n in chars.by_ref() {
                    if n == '\n' {
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                let mut prev = '\0';
                for n in chars.by_ref() {
                    if prev == '*' && n == '/' {
                        break;
                    }
                    prev = n;
                }
            }
            '\'' => tokens.push(SqlToken::Str(read_quoted(&mut chars, '\''))),
            '"' | '`' => {
                let name = read_quoted(&mut chars, c);
                tokens.push(SqlToken::Quoted(name.to_ascii_lowercase()));
            }
            '[' => {
                let name: String = chars.by_ref().take_while(|&n| n != ']').collect();
                tokens.push(SqlToken::Quoted(name.to_ascii_lowercase()));
            }
            _ if is_word_char(c) => {
                let mut word = String::from(c);
                while let Some(&n) = chars.peek() {
                    if !is_word_char(n) {
                        break;
                    }
                    word.push(n);
                    chars.next();
                }
                word.make_ascii_lowercase();
                tokens.push(SqlToken::Word(word));
            }
            _ => tokens.push(SqlToken::Punct(c)),
        }
    }

    tokens
}

/// Read up to the closing `quote` (opening quote already consumed).
///
/// A doubled quote is decoded to one literal quote character. An unterminated
/// literal yields whatever was read before the end of input.
fn read_quoted(chars: &mut std::iter::Peekable<std::str::Chars<'_>>, quote: char) -> String {
    let mut out = String::new();
    while let Some(c) = chars.next() {
        if c != quote {
            out.push(c);
        } else if chars.peek() == Some(&quote) {
            out.push(quote);
            chars.next();
        } else {
            break;
        }
    }
    out
}

/// Characters SQLite accepts inside unquoted identifiers (and numbers).
fn is_word_char(c: char) -> bool {
    c == '_' || c == '$' || c.is_ascii_alphanumeric() || !c.is_ascii()
}

/// Find the index of the `)` closing the `(` that sits just before `start`.
fn matching_paren(tokens: &[SqlToken], start: usize) -> Option<usize> {
    let mut depth = 1usize;
    for (offset, tok) in tokens[start..].iter().enumerate() {
        match tok {
            SqlToken::Punct('(') => depth += 1,
            SqlToken::Punct(')') => {
                depth -= 1;
                if depth == 0 {
                    return Some(start + offset);
                }
            }
            _ => {}
        }
    }
    None
}

/// Within one CHECK body, find `<column> IN ( 'a', 'b', ... )` and return its values.
fn find_in_list(body: &[SqlToken], column: &str) -> Option<Vec<String>> {
    let is_word = |tok: Option<&SqlToken>, kw: &str| matches!(tok, Some(SqlToken::Word(w)) if w == kw);

    for (k, tok) in body.iter().enumerate() {
        let names_column = matches!(tok, SqlToken::Word(w) | SqlToken::Quoted(w) if w == column);
        if !names_column
            || !is_word(body.get(k + 1), "in")
            || !matches!(body.get(k + 2), Some(SqlToken::Punct('(')))
            // NOT binds looser than IN, so `NOT col IN (...)` is also a negation.
            || (k > 0 && is_word(body.get(k - 1), "not"))
        {
            continue;
        }
        if let Some(variants) = string_list(&body[k + 3..]) {
            return Some(variants);
        }
    }

    None
}

/// Parse `'a', 'b', ... )` (opening paren already consumed) into owned strings.
///
/// Returns `None` if any list element is not a string literal.
fn string_list(tokens: &[SqlToken]) -> Option<Vec<String>> {
    let mut variants = Vec::new();
    if matches!(tokens.first(), Some(SqlToken::Punct(')'))) {
        return Some(variants);
    }

    let mut rest = tokens.iter();
    loop {
        let SqlToken::Str(value) = rest.next()? else {
            return None;
        };
        variants.push(value.clone());
        match rest.next()? {
            SqlToken::Punct(',') => {}
            SqlToken::Punct(')') => return Some(variants),
            _ => return None,
        }
    }
}
138
139async fn fetch_tables(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
140    let table_names: Vec<(String,)> = sqlx::query_as(
141        "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
142    )
143    .fetch_all(pool)
144    .await?;
145
146    let mut tables = Vec::new();
147
148    for (table_name,) in table_names {
149        let columns = fetch_columns(pool, &table_name).await?;
150        tables.push(TableInfo {
151            schema_name: "main".to_string(),
152            name: table_name,
153            columns,
154        });
155    }
156
157    Ok(tables)
158}
159
160async fn fetch_views(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
161    let view_names: Vec<(String,)> =
162        sqlx::query_as("SELECT name FROM sqlite_master WHERE type = 'view' ORDER BY name")
163            .fetch_all(pool)
164            .await?;
165
166    let mut views = Vec::new();
167
168    for (view_name,) in view_names {
169        let columns = fetch_columns(pool, &view_name).await?;
170        views.push(TableInfo {
171            schema_name: "main".to_string(),
172            name: view_name,
173            columns,
174        });
175    }
176
177    Ok(views)
178}
179
180async fn fetch_columns(pool: &SqlitePool, table_name: &str) -> Result<Vec<ColumnInfo>> {
181    // PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk
182    let pragma_query = format!("PRAGMA table_info(\"{}\")", table_name.replace('"', "\"\""));
183    let rows: Vec<(i32, String, String, bool, Option<String>, i32)> =
184        sqlx::query_as(&pragma_query).fetch_all(pool).await?;
185
186    Ok(rows
187        .into_iter()
188        .map(|(cid, name, declared_type, notnull, dflt_value, pk)| {
189            let upper = declared_type.to_uppercase();
190            ColumnInfo {
191                name,
192                data_type: upper.clone(),
193                udt_name: upper,
194                udt_schema: None,
195                is_nullable: !notnull,
196                is_primary_key: pk > 0,
197                ordinal_position: cid,
198                schema_name: "main".to_string(),
199                column_default: dflt_value,
200            }
201        })
202        .collect())
203}
204
205/// Resolve view column nullability by matching column names against introspected tables.
206/// If a column name is found in exactly one table and is NOT NULL, propagate that.
207fn resolve_view_nullability(views: &mut [TableInfo], tables: &[TableInfo]) {
208    // Build lookup: column_name -> Vec<is_nullable>
209    let mut col_lookup: HashMap<&str, Vec<bool>> = HashMap::new();
210    for table in tables {
211        for col in &table.columns {
212            col_lookup
213                .entry(&col.name)
214                .or_default()
215                .push(col.is_nullable);
216        }
217    }
218
219    for view in views.iter_mut() {
220        for col in view.columns.iter_mut() {
221            if let Some(nullable_flags) = col_lookup.get(col.name.as_str()) {
222                // Only resolve if column name appears in exactly one table
223                // and that column is NOT nullable
224                if nullable_flags.len() == 1 && !nullable_flags[0] {
225                    col.is_nullable = false;
226                }
227            }
228        }
229    }
230}
231
232/// Resolve view column primary keys by matching column names against introspected tables.
233/// If a column name is found in exactly one table and is a PK, propagate that.
234fn resolve_view_primary_keys(views: &mut [TableInfo], tables: &[TableInfo]) {
235    // Build lookup: column_name -> Vec<is_primary_key>
236    let mut col_lookup: HashMap<&str, Vec<bool>> = HashMap::new();
237    for table in tables {
238        for col in &table.columns {
239            col_lookup
240                .entry(&col.name)
241                .or_default()
242                .push(col.is_primary_key);
243        }
244    }
245
246    for view in views.iter_mut() {
247        for col in view.columns.iter_mut() {
248            if let Some(pk_flags) = col_lookup.get(col.name.as_str()) {
249                // Only resolve if column name appears in exactly one table
250                // and that column is a PK
251                if pk_flags.len() == 1 && pk_flags[0] {
252                    col.is_primary_key = true;
253                }
254            }
255        }
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `TEXT` column at `position` in schema `main` with the given
    /// nullability and primary-key flags.
    fn column(position: usize, name: &str, nullable: bool, primary_key: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "TEXT".to_string(),
            udt_name: "TEXT".to_string(),
            udt_schema: None,
            is_nullable: nullable,
            is_primary_key: primary_key,
            ordinal_position: position as i32,
            schema_name: "main".to_string(),
            column_default: None,
        }
    }

    /// Wrap `columns` in a relation in schema `main`.
    fn relation(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "main".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Table from `(column, is_nullable)` pairs; no column is a primary key.
    fn make_table(name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, (col, nullable))| column(pos, col, nullable, false))
            .collect();
        relation(name, cols)
    }

    /// View whose columns all start out nullable and non-PK.
    fn make_view(name: &str, columns: Vec<&str>) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, col)| column(pos, col, true, false))
            .collect();
        relation(name, cols)
    }

    /// Table from `(column, is_primary_key)` pairs; every column is NOT NULL.
    fn make_table_with_pk(name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, (col, is_pk))| column(pos, col, false, is_pk))
            .collect();
        relation(name, cols)
    }

    /// Owned strings from literals, for comparing parser output.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    // ========== resolve_view_nullability ==========

    #[test]
    fn test_resolve_unique_not_null() {
        let tables = [make_table("users", vec![("id", false), ("name", false)])];
        let mut views = [make_view("my_view", vec!["id", "name"])];
        resolve_view_nullability(&mut views, &tables);
        let flags: Vec<bool> = views[0].columns.iter().map(|c| c.is_nullable).collect();
        assert_eq!(flags, [false, false]);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = [make_table("users", vec![("id", false), ("name", true)])];
        let mut views = [make_view("my_view", vec!["id", "name"])];
        resolve_view_nullability(&mut views, &tables);
        let flags: Vec<bool> = views[0].columns.iter().map(|c| c.is_nullable).collect();
        assert_eq!(flags, [false, true]);
    }

    #[test]
    fn test_resolve_ambiguous_stays_nullable() {
        // "id" exists in two tables, so the source is ambiguous.
        let tables = [
            make_table("users", vec![("id", false)]),
            make_table("orders", vec![("id", false)]),
        ];
        let mut views = [make_view("my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_no_match() {
        let tables = [make_table("users", vec![("id", false)])];
        let mut views = [make_view("my_view", vec!["computed"])];
        resolve_view_nullability(&mut views, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_tables() {
        let mut views = [make_view("my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[]);
        assert!(views[0].columns[0].is_nullable);
    }

    // ========== resolve_view_primary_keys ==========

    #[test]
    fn test_resolve_pk_unique_match() {
        let tables = [make_table_with_pk("users", vec![("id", true), ("name", false)])];
        let mut views = [make_view("my_view", vec!["id", "name"])];
        resolve_view_primary_keys(&mut views, &tables);
        let flags: Vec<bool> = views[0].columns.iter().map(|c| c.is_primary_key).collect();
        assert_eq!(flags, [true, false]);
    }

    #[test]
    fn test_resolve_pk_ambiguous() {
        // "id" is a PK in two tables, so it is not propagated.
        let tables = [
            make_table_with_pk("users", vec![("id", true)]),
            make_table_with_pk("orders", vec![("id", true)]),
        ];
        let mut views = [make_view("my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let tables = [make_table_with_pk("users", vec![("id", true)])];
        let mut views = [make_view("my_view", vec!["computed"])];
        resolve_view_primary_keys(&mut views, &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_empty_tables() {
        let mut views = [make_view("my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &[]);
        assert!(!views[0].columns[0].is_primary_key);
    }

    // ========== parse_check_in_variants ==========

    #[test]
    fn test_parse_check_in_simple() {
        let ddl = "CREATE TABLE t (id INTEGER PRIMARY KEY, status TEXT CHECK (status IN ('active', 'inactive')) NOT NULL)";
        let parsed = parse_check_in_variants(ddl, "status");
        assert_eq!(parsed, Some(owned(&["active", "inactive"])));
    }

    #[test]
    fn test_parse_check_in_three_variants() {
        let ddl = "CREATE TABLE t (priority TEXT CHECK (priority IN ('low','medium','high')))";
        let parsed = parse_check_in_variants(ddl, "priority");
        assert_eq!(parsed, Some(owned(&["low", "medium", "high"])));
    }

    #[test]
    fn test_parse_check_in_returns_none_for_other_column() {
        let ddl = "CREATE TABLE t (status TEXT CHECK (status IN ('a','b')))";
        assert!(parse_check_in_variants(ddl, "other").is_none());
    }

    #[test]
    fn test_parse_check_in_returns_none_without_check() {
        let ddl = "CREATE TABLE t (status TEXT)";
        assert!(parse_check_in_variants(ddl, "status").is_none());
    }

    #[test]
    fn test_parse_check_in_case_insensitive_keyword() {
        let ddl = "CREATE TABLE t (status TEXT check (Status in ('a','b')))";
        let parsed = parse_check_in_variants(ddl, "status");
        assert_eq!(parsed, Some(owned(&["a", "b"])));
    }
}