Skip to main content

dbrest_sqlite/
introspector.rs

1//! SQLite schema introspector — implements [`DbIntrospector`] for `sqlx::SqlitePool`.
2//!
3//! Uses `sqlite_master`, `PRAGMA table_info()`, and `PRAGMA foreign_key_list()`
4//! to discover tables, columns, and relationships.
5
6use async_trait::async_trait;
7use sqlx::Row;
8
9use dbrest_core::error::Error;
10use dbrest_core::schema_cache::db::{
11    ComputedFieldRow, DbIntrospector, RelationshipRow, RoutineRow, TableRow,
12};
13
14use crate::executor::map_sqlx_error;
15
16/// SQLite introspector backed by `sqlx::SqlitePool`.
17pub struct SqliteIntrospector<'a> {
18    pool: &'a sqlx::SqlitePool,
19}
20
21impl<'a> SqliteIntrospector<'a> {
22    pub fn new(pool: &'a sqlx::SqlitePool) -> Self {
23        Self { pool }
24    }
25}
26
27#[async_trait]
28impl DbIntrospector for SqliteIntrospector<'_> {
29    async fn query_tables(&self, _schemas: &[String]) -> Result<Vec<TableRow>, Error> {
30        // SQLite has no schemas (we treat everything as "main").
31        // Query sqlite_master for tables and views.
32        let rows = sqlx::query(
33            r#"
34            SELECT
35                type,
36                name
37            FROM sqlite_master
38            WHERE type IN ('table', 'view')
39              AND name NOT LIKE 'sqlite_%'
40              AND name NOT LIKE '_dbrest_%'
41            ORDER BY name
42            "#,
43        )
44        .fetch_all(self.pool)
45        .await
46        .map_err(map_sqlx_error)?;
47
48        let mut tables = Vec::with_capacity(rows.len());
49
50        for row in &rows {
51            let obj_type: String = row.try_get("type").unwrap_or_default();
52            let name: String = row.try_get("name").unwrap_or_default();
53            let is_view = obj_type == "view";
54
55            // Get column info via PRAGMA
56            let pragma_sql = format!("PRAGMA table_info(\"{}\")", name.replace('"', "\"\""));
57            let col_rows = sqlx::query(&pragma_sql)
58                .fetch_all(self.pool)
59                .await
60                .map_err(map_sqlx_error)?;
61
62            let mut pk_cols = Vec::new();
63            let mut columns_json_parts = Vec::new();
64
65            for col in &col_rows {
66                let col_name: String = col.try_get("name").unwrap_or_default();
67                let col_type: String = col.try_get("type").unwrap_or_default();
68                let not_null: bool = col.try_get::<i32, _>("notnull").unwrap_or(0) != 0;
69                let pk: i32 = col.try_get("pk").unwrap_or(0);
70                let dflt: Option<String> = col.try_get("dflt_value").ok();
71
72                if pk > 0 {
73                    pk_cols.push(col_name.clone());
74                }
75
76                // Build a JSON object for each column matching the expected format
77                let col_json = serde_json::json!({
78                    "name": col_name,
79                    "data_type": normalize_sqlite_type(&col_type),
80                    "nominal_type": col_type,
81                    "nullable": !not_null,
82                    "default": dflt,
83                    "max_length": null,
84                    "description": null,
85                    "enum_values": [],
86                    "is_composite": false,
87                });
88                columns_json_parts.push(col_json);
89            }
90
91            let columns_json =
92                serde_json::to_string(&columns_json_parts).unwrap_or_else(|_| "[]".to_string());
93
94            tables.push(TableRow {
95                table_schema: "main".to_string(),
96                table_name: name,
97                table_description: None,
98                is_view,
99                insertable: !is_view,
100                updatable: !is_view,
101                deletable: !is_view,
102                readable: true,
103                pk_cols,
104                columns_json,
105            });
106        }
107
108        Ok(tables)
109    }
110
111    async fn query_relationships(&self) -> Result<Vec<RelationshipRow>, Error> {
112        // Discover foreign keys from all tables using PRAGMA foreign_key_list().
113        let table_rows = sqlx::query(
114            r#"
115            SELECT name FROM sqlite_master
116            WHERE type = 'table'
117              AND name NOT LIKE 'sqlite_%'
118              AND name NOT LIKE '_dbrest_%'
119            "#,
120        )
121        .fetch_all(self.pool)
122        .await
123        .map_err(map_sqlx_error)?;
124
125        let mut relationships = Vec::new();
126
127        for table_row in &table_rows {
128            let table_name: String = table_row.try_get("name").unwrap_or_default();
129            let pragma_sql = format!(
130                "PRAGMA foreign_key_list(\"{}\")",
131                table_name.replace('"', "\"\"")
132            );
133            let fk_rows = sqlx::query(&pragma_sql)
134                .fetch_all(self.pool)
135                .await
136                .map_err(map_sqlx_error)?;
137
138            // Group by constraint id
139            let mut fk_groups: std::collections::HashMap<i32, Vec<(String, String, String)>> =
140                std::collections::HashMap::new();
141            for fk in &fk_rows {
142                let id: i32 = fk.try_get("id").unwrap_or(0);
143                let foreign_table: String = fk.try_get("table").unwrap_or_default();
144                let from_col: String = fk.try_get("from").unwrap_or_default();
145                let to_col: String = fk.try_get("to").unwrap_or_default();
146                fk_groups
147                    .entry(id)
148                    .or_default()
149                    .push((foreign_table, from_col, to_col));
150            }
151
152            for (id, cols) in &fk_groups {
153                if cols.is_empty() {
154                    continue;
155                }
156                let foreign_table_name = &cols[0].0;
157                let cols_and_fcols: Vec<(String, String)> = cols
158                    .iter()
159                    .map(|(_, f, t)| (f.clone(), t.clone()))
160                    .collect();
161
162                let is_self = table_name == *foreign_table_name;
163                let constraint_name = format!("fk_{}_{}", table_name, id);
164
165                relationships.push(RelationshipRow {
166                    table_schema: "main".to_string(),
167                    table_name: table_name.clone(),
168                    foreign_table_schema: "main".to_string(),
169                    foreign_table_name: foreign_table_name.clone(),
170                    is_self,
171                    constraint_name,
172                    cols_and_fcols,
173                    one_to_one: false, // Conservative default
174                });
175            }
176        }
177
178        Ok(relationships)
179    }
180
181    async fn query_routines(&self, _schemas: &[String]) -> Result<Vec<RoutineRow>, Error> {
182        // SQLite has no user-defined stored routines.
183        Ok(vec![])
184    }
185
186    async fn query_computed_fields(
187        &self,
188        _schemas: &[String],
189    ) -> Result<Vec<ComputedFieldRow>, Error> {
190        // SQLite has no computed fields via functions.
191        Ok(vec![])
192    }
193
194    async fn query_timezones(&self) -> Result<Vec<String>, Error> {
195        // SQLite has no timezone catalog. Return an empty list.
196        Ok(vec![])
197    }
198}
199
/// Map a declared SQLite column type onto a coarse type name.
///
/// Matching is case-insensitive substring search, checked in a fixed order
/// modeled on SQLite's affinity rules (https://www.sqlite.org/datatype3.html):
/// anything mentioning `INT` is an integer, then character/text types, then
/// blobs (including an empty declaration), then floating point. Booleans and
/// JSON get their own names; dates/times and everything else fall back to text.
fn normalize_sqlite_type(raw: &str) -> String {
    /// True when `haystack` contains at least one of `needles`.
    fn mentions_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().copied().any(|needle| haystack.contains(needle))
    }

    let declared = raw.to_uppercase();

    let affinity = if mentions_any(&declared, &["INT"]) {
        "integer"
    } else if mentions_any(&declared, &["CHAR", "CLOB", "TEXT"]) {
        "text"
    } else if declared.is_empty() || mentions_any(&declared, &["BLOB"]) {
        "blob"
    } else if mentions_any(&declared, &["REAL", "FLOA", "DOUB"]) {
        "real"
    } else if mentions_any(&declared, &["BOOL"]) {
        "boolean"
    } else if mentions_any(&declared, &["DATE", "TIME"]) {
        // SQLite stores dates as text; checked before JSON on purpose.
        "text"
    } else if mentions_any(&declared, &["JSON"]) {
        "json"
    } else {
        "text"
    };

    affinity.to_string()
}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts every `(declared, expected)` pair, naming the failing input.
    fn check(cases: &[(&str, &str)]) {
        for &(declared, expected) in cases {
            assert_eq!(
                normalize_sqlite_type(declared),
                expected,
                "declared type {:?}",
                declared
            );
        }
    }

    #[test]
    fn test_normalize_sqlite_type() {
        check(&[
            ("INTEGER", "integer"),
            ("INT", "integer"),
            ("BIGINT", "integer"),
            ("TEXT", "text"),
            ("VARCHAR(255)", "text"),
            ("REAL", "real"),
            ("DOUBLE PRECISION", "real"),
            ("BLOB", "blob"),
            ("BOOLEAN", "boolean"),
            ("DATETIME", "text"),
            ("JSON", "json"),
            ("", "blob"),
        ]);
    }

    #[test]
    fn test_normalize_sqlite_type_is_case_insensitive() {
        check(&[
            ("integer", "integer"),
            ("varchar(20)", "text"),
            ("float", "real"),
            ("bool", "boolean"),
            ("jsonb", "json"),
        ]);
    }

    #[test]
    fn test_normalize_sqlite_type_precedence() {
        // `INT` wins over every other rule, mirroring SQLite's affinity rules.
        check(&[("POINT", "integer"), ("BOOLINT", "integer")]);
        // Date/time names are checked before JSON.
        check(&[("JSONTIME", "text")]);
    }

    #[test]
    fn test_normalize_sqlite_type_falls_back_to_text() {
        check(&[("NUMERIC", "text"), ("DECIMAL(10,5)", "text"), ("TIMESTAMP", "text")]);
    }
}