1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::SqlitePool;
5
6use super::{ColumnInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(pool: &SqlitePool, include_views: bool) -> Result<SchemaInfo> {
9 let tables = fetch_tables(pool).await?;
10 let mut views = if include_views {
11 fetch_views(pool).await?
12 } else {
13 Vec::new()
14 };
15
16 if !views.is_empty() {
17 resolve_view_nullability(&mut views, &tables);
18 }
19
20 Ok(SchemaInfo {
21 tables,
22 views,
23 enums: Vec::new(),
24 composite_types: Vec::new(),
25 domains: Vec::new(),
26 })
27}
28
29async fn fetch_tables(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
30 let table_names: Vec<(String,)> = sqlx::query_as(
31 "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
32 )
33 .fetch_all(pool)
34 .await?;
35
36 let mut tables = Vec::new();
37
38 for (table_name,) in table_names {
39 let columns = fetch_columns(pool, &table_name).await?;
40 tables.push(TableInfo {
41 schema_name: "main".to_string(),
42 name: table_name,
43 columns,
44 });
45 }
46
47 Ok(tables)
48}
49
50async fn fetch_views(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
51 let view_names: Vec<(String,)> = sqlx::query_as(
52 "SELECT name FROM sqlite_master WHERE type = 'view' ORDER BY name",
53 )
54 .fetch_all(pool)
55 .await?;
56
57 let mut views = Vec::new();
58
59 for (view_name,) in view_names {
60 let columns = fetch_columns(pool, &view_name).await?;
61 views.push(TableInfo {
62 schema_name: "main".to_string(),
63 name: view_name,
64 columns,
65 });
66 }
67
68 Ok(views)
69}
70
71async fn fetch_columns(pool: &SqlitePool, table_name: &str) -> Result<Vec<ColumnInfo>> {
72 let pragma_query = format!("PRAGMA table_info(\"{}\")", table_name.replace('"', "\"\""));
74 let rows: Vec<(i32, String, String, bool, Option<String>, i32)> =
75 sqlx::query_as(&pragma_query).fetch_all(pool).await?;
76
77 Ok(rows
78 .into_iter()
79 .map(|(cid, name, declared_type, notnull, dflt_value, pk)| {
80 let upper = declared_type.to_uppercase();
81 ColumnInfo {
82 name,
83 data_type: upper.clone(),
84 udt_name: upper,
85 is_nullable: !notnull,
86 is_primary_key: pk > 0,
87 ordinal_position: cid,
88 schema_name: "main".to_string(),
89 column_default: dflt_value,
90 }
91 })
92 .collect())
93}
94
95fn resolve_view_nullability(views: &mut [TableInfo], tables: &[TableInfo]) {
98 let mut col_lookup: HashMap<&str, Vec<bool>> = HashMap::new();
100 for table in tables {
101 for col in &table.columns {
102 col_lookup.entry(&col.name).or_default().push(col.is_nullable);
103 }
104 }
105
106 for view in views.iter_mut() {
107 for col in view.columns.iter_mut() {
108 if let Some(nullable_flags) = col_lookup.get(col.name.as_str()) {
109 if nullable_flags.len() == 1 && !nullable_flags[0] {
112 col.is_nullable = false;
113 }
114 }
115 }
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one TEXT column in the `main` schema at the given zero-based position.
    fn text_column(position: usize, name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "TEXT".to_string(),
            udt_name: "TEXT".to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: position as i32,
            schema_name: "main".to_string(),
            column_default: None,
        }
    }

    /// Builds a `main`-schema table whose columns are given as `(name, is_nullable)` pairs.
    fn make_table(name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let columns = columns
            .into_iter()
            .enumerate()
            .map(|(position, (col, nullable))| text_column(position, col, nullable))
            .collect();
        TableInfo {
            schema_name: "main".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a view whose columns all start out nullable.
    fn make_view(name: &str, columns: Vec<&str>) -> TableInfo {
        make_table(name, columns.into_iter().map(|col| (col, true)).collect())
    }

    /// Nullability flags of the first view's columns, in column order.
    fn first_view_nullability(views: &[TableInfo]) -> Vec<bool> {
        views[0].columns.iter().map(|col| col.is_nullable).collect()
    }

    #[test]
    fn test_resolve_unique_not_null() {
        let tables = vec![make_table("users", vec![("id", false), ("name", false)])];
        let mut views = vec![make_view("my_view", vec!["id", "name"])];
        resolve_view_nullability(&mut views, &tables);
        assert_eq!(first_view_nullability(&views), vec![false, false]);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = vec![make_table("users", vec![("id", false), ("name", true)])];
        let mut views = vec![make_view("my_view", vec!["id", "name"])];
        resolve_view_nullability(&mut views, &tables);
        assert_eq!(first_view_nullability(&views), vec![false, true]);
    }

    #[test]
    fn test_resolve_ambiguous_stays_nullable() {
        // `id` exists in two tables, so its origin cannot be determined by name alone.
        let tables = vec![
            make_table("users", vec![("id", false)]),
            make_table("orders", vec![("id", false)]),
        ];
        let mut views = vec![make_view("my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &tables);
        assert_eq!(first_view_nullability(&views), vec![true]);
    }

    #[test]
    fn test_resolve_no_match() {
        let tables = vec![make_table("users", vec![("id", false)])];
        let mut views = vec![make_view("my_view", vec!["computed"])];
        resolve_view_nullability(&mut views, &tables);
        assert_eq!(first_view_nullability(&views), vec![true]);
    }

    #[test]
    fn test_resolve_empty_tables() {
        let mut views = vec![make_view("my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[]);
        assert_eq!(first_view_nullability(&views), vec![true]);
    }
}