1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::MySqlPool;
5
6use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9 pool: &MySqlPool,
10 schemas: &[String],
11 include_views: bool,
12) -> Result<SchemaInfo> {
13 let tables = fetch_tables(pool, schemas).await?;
14 let mut views = if include_views {
15 fetch_views(pool, schemas).await?
16 } else {
17 Vec::new()
18 };
19
20 if !views.is_empty() {
21 let sources = fetch_view_column_sources(pool, schemas).await?;
22 resolve_view_nullability(&mut views, &sources, &tables);
23 }
24
25 let enums = extract_enums(&tables);
26
27 Ok(SchemaInfo {
28 tables,
29 views,
30 enums,
31 composite_types: Vec::new(),
32 domains: Vec::new(),
33 })
34}
35
36async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
37 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
39 let query = format!(
40 r#"
41 SELECT
42 c.TABLE_SCHEMA,
43 c.TABLE_NAME,
44 c.COLUMN_NAME,
45 c.DATA_TYPE,
46 c.COLUMN_TYPE,
47 c.IS_NULLABLE,
48 c.ORDINAL_POSITION,
49 c.COLUMN_KEY
50 FROM information_schema.COLUMNS c
51 JOIN information_schema.TABLES t
52 ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
53 AND t.TABLE_NAME = c.TABLE_NAME
54 AND t.TABLE_TYPE = 'BASE TABLE'
55 WHERE c.TABLE_SCHEMA IN ({})
56 ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
57 "#,
58 placeholders.join(",")
59 );
60
61 let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32, String)>(&query);
62 for schema in schemas {
63 q = q.bind(schema);
64 }
65 let rows = q.fetch_all(pool).await?;
66
67 let mut tables: Vec<TableInfo> = Vec::new();
68 let mut current_key: Option<(String, String)> = None;
69
70 for (schema, table, col_name, data_type, column_type, nullable, ordinal, column_key) in rows {
71 let key = (schema.clone(), table.clone());
72 if current_key.as_ref() != Some(&key) {
73 current_key = Some(key);
74 tables.push(TableInfo {
75 schema_name: schema.clone(),
76 name: table.clone(),
77 columns: Vec::new(),
78 });
79 }
80 tables.last_mut().unwrap().columns.push(ColumnInfo {
81 name: col_name,
82 data_type,
83 udt_name: column_type,
84 is_nullable: nullable == "YES",
85 is_primary_key: column_key == "PRI",
86 ordinal_position: ordinal as i32,
87 schema_name: schema,
88 column_default: None,
89 });
90 }
91
92 Ok(tables)
93}
94
95async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
96 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
97 let query = format!(
98 r#"
99 SELECT
100 c.TABLE_SCHEMA,
101 c.TABLE_NAME,
102 c.COLUMN_NAME,
103 c.DATA_TYPE,
104 c.COLUMN_TYPE,
105 c.IS_NULLABLE,
106 c.ORDINAL_POSITION
107 FROM information_schema.COLUMNS c
108 JOIN information_schema.TABLES t
109 ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
110 AND t.TABLE_NAME = c.TABLE_NAME
111 AND t.TABLE_TYPE = 'VIEW'
112 WHERE c.TABLE_SCHEMA IN ({})
113 ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
114 "#,
115 placeholders.join(",")
116 );
117
118 let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32)>(&query);
119 for schema in schemas {
120 q = q.bind(schema);
121 }
122 let rows = q.fetch_all(pool).await?;
123
124 let mut views: Vec<TableInfo> = Vec::new();
125 let mut current_key: Option<(String, String)> = None;
126
127 for (schema, table, col_name, data_type, column_type, nullable, ordinal) in rows {
128 let key = (schema.clone(), table.clone());
129 if current_key.as_ref() != Some(&key) {
130 current_key = Some(key);
131 views.push(TableInfo {
132 schema_name: schema.clone(),
133 name: table.clone(),
134 columns: Vec::new(),
135 });
136 }
137 views.last_mut().unwrap().columns.push(ColumnInfo {
138 name: col_name,
139 data_type,
140 udt_name: column_type,
141 is_nullable: nullable == "YES",
142 is_primary_key: false,
143 ordinal_position: ordinal as i32,
144 schema_name: schema,
145 column_default: None,
146 });
147 }
148
149 Ok(views)
150}
151
/// One row linking a view to a column of an underlying table it uses.
///
/// Each instance pairs the view's identity (`view_schema`, `view_name`) with
/// the identity of a referenced column (`table_schema`, `table_name`,
/// `column_name`).
struct ViewColumnSource {
    /// Schema that contains the view.
    view_schema: String,
    /// Name of the view.
    view_name: String,
    /// Schema that contains the referenced table.
    table_schema: String,
    /// Name of the referenced table.
    table_name: String,
    /// Name of the referenced column in that table.
    column_name: String,
}
159
160async fn fetch_view_column_sources(
161 pool: &MySqlPool,
162 schemas: &[String],
163) -> Result<Vec<ViewColumnSource>> {
164 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
165 let query = format!(
166 r#"
167 SELECT
168 vcu.VIEW_SCHEMA,
169 vcu.VIEW_NAME,
170 vcu.TABLE_SCHEMA,
171 vcu.TABLE_NAME,
172 vcu.COLUMN_NAME
173 FROM INFORMATION_SCHEMA.VIEW_COLUMN_USAGE vcu
174 WHERE vcu.VIEW_SCHEMA IN ({})
175 "#,
176 placeholders.join(",")
177 );
178
179 let mut q = sqlx::query_as::<_, (String, String, String, String, String)>(&query);
180 for schema in schemas {
181 q = q.bind(schema);
182 }
183
184 match q.fetch_all(pool).await {
185 Ok(rows) => Ok(rows
186 .into_iter()
187 .map(
188 |(view_schema, view_name, table_schema, table_name, column_name)| {
189 ViewColumnSource {
190 view_schema,
191 view_name,
192 table_schema,
193 table_name,
194 column_name,
195 }
196 },
197 )
198 .collect()),
199 Err(_) => {
200 Ok(Vec::new())
202 }
203 }
204}
205
206fn resolve_view_nullability(
207 views: &mut [TableInfo],
208 sources: &[ViewColumnSource],
209 tables: &[TableInfo],
210) {
211 let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
213 for table in tables {
214 for col in &table.columns {
215 table_lookup.insert(
216 (&table.schema_name, &table.name, &col.name),
217 col.is_nullable,
218 );
219 }
220 }
221
222 let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
224 for src in sources {
225 if let Some(&is_nullable) =
226 table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str()))
227 {
228 view_lookup
229 .entry((&src.view_schema, &src.view_name, &src.column_name))
230 .or_default()
231 .push(is_nullable);
232 }
233 }
234
235 for view in views.iter_mut() {
236 for col in view.columns.iter_mut() {
237 if let Some(nullable_flags) = view_lookup.get(&(
238 view.schema_name.as_str(),
239 view.name.as_str(),
240 col.name.as_str(),
241 )) {
242 if !nullable_flags.is_empty() && nullable_flags.iter().all(|&n| !n) {
244 col.is_nullable = false;
245 }
246 }
247 }
248 }
249}
250
251fn extract_enums(tables: &[TableInfo]) -> Vec<EnumInfo> {
255 let mut enums = Vec::new();
256
257 for table in tables {
258 for col in &table.columns {
259 if col.udt_name.starts_with("enum(") {
260 let variants = parse_enum_variants(&col.udt_name);
261 if !variants.is_empty() {
262 let enum_name = format!("{}_{}", table.name, col.name);
263 enums.push(EnumInfo {
264 schema_name: table.schema_name.clone(),
265 name: enum_name,
266 variants,
267 default_variant: None,
268 });
269 }
270 }
271 }
272 }
273
274 enums
275}
276
/// Parses the variant labels out of a MySQL column type such as
/// `enum('a','b')`.
///
/// The input must start with the lowercase prefix `enum(` and end with `)`;
/// anything else (including an uppercase `ENUM(`) yields an empty list.
///
/// Labels are separated on commas that sit *outside* single quotes, so a
/// label containing a comma (`enum('a,b','c')`) stays in one piece. For a
/// quoted label the surrounding quotes are removed and a doubled quote (`''`)
/// is unescaped to a single `'`. Whitespace around a label's quotes is
/// ignored, and labels that end up empty are dropped.
fn parse_enum_variants(column_type: &str) -> Vec<String> {
    let inner = match column_type
        .strip_prefix("enum(")
        .and_then(|rest| rest.strip_suffix(')'))
    {
        Some(inner) => inner,
        None => return Vec::new(),
    };

    // Split on top-level commas only. Toggling on every quote also handles the
    // `''` escape: the two quotes flip the state twice and leave it unchanged.
    let mut pieces: Vec<&str> = Vec::new();
    let mut in_quotes = false;
    let mut start = 0;
    for (idx, ch) in inner.char_indices() {
        match ch {
            '\'' => in_quotes = !in_quotes,
            ',' if !in_quotes => {
                pieces.push(&inner[start..idx]);
                // `,` is one byte wide, so the next piece starts right after it.
                start = idx + 1;
            }
            _ => {}
        }
    }
    pieces.push(&inner[start..]);

    pieces
        .into_iter()
        .map(unquote_enum_variant)
        .filter(|label| !label.is_empty())
        .collect()
}

/// Strips the quoting from one raw enum label and unescapes doubled quotes.
fn unquote_enum_variant(raw: &str) -> String {
    let trimmed = raw.trim();
    let quoted = trimmed
        .strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''));
    match quoted {
        Some(body) => body.replace("''", "'"),
        // Not a well-formed quoted label: fall back to stripping stray quotes.
        None => trimmed.trim_matches('\'').to_string(),
    }
}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `varchar` column fixture that is never a primary key and has
    /// no default.
    fn column(
        schema: &str,
        name: &str,
        udt_name: &str,
        is_nullable: bool,
        ordinal_position: i32,
    ) -> ColumnInfo {
        ColumnInfo {
            name: name.to_owned(),
            data_type: "varchar".to_owned(),
            udt_name: udt_name.to_owned(),
            is_nullable,
            is_primary_key: false,
            ordinal_position,
            schema_name: schema.to_owned(),
            column_default: None,
        }
    }

    /// Wraps `columns` in a table that lives in the `test_db` schema.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "test_db".to_owned(),
            name: name.to_owned(),
            columns,
        }
    }

    /// Non-nullable `test_db` column at ordinal 0 with the given `udt_name`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        column("test_db", name, udt_name, false, 0)
    }

    // ---- parse_enum_variants ----

    #[test]
    fn test_parse_simple() {
        let parsed = parse_enum_variants("enum('a','b','c')");
        assert_eq!(parsed, ["a", "b", "c"]);
    }

    #[test]
    fn test_parse_single_variant() {
        let parsed = parse_enum_variants("enum('only')");
        assert_eq!(parsed, ["only"]);
    }

    #[test]
    fn test_parse_with_spaces() {
        let parsed = parse_enum_variants("enum( 'a' , 'b' )");
        assert_eq!(parsed, ["a", "b"]);
    }

    #[test]
    fn test_parse_empty_parens() {
        assert!(parse_enum_variants("enum()").is_empty());
    }

    #[test]
    fn test_parse_varchar_not_enum() {
        assert!(parse_enum_variants("varchar(255)").is_empty());
    }

    #[test]
    fn test_parse_int_not_enum() {
        assert!(parse_enum_variants("int").is_empty());
    }

    #[test]
    fn test_parse_with_spaces_in_value() {
        let parsed = parse_enum_variants("enum('with space','no')");
        assert_eq!(parsed, ["with space", "no"]);
    }

    #[test]
    fn test_parse_empty_variant_filtered() {
        let parsed = parse_enum_variants("enum('a','','c')");
        assert_eq!(parsed, ["a", "c"]);
    }

    #[test]
    fn test_parse_uppercase_enum_not_matched() {
        // Only the lowercase `enum(` prefix is recognised.
        assert!(parse_enum_variants("ENUM('a','b')").is_empty());
    }

    // ---- extract_enums ----

    #[test]
    fn test_extract_from_enum_column() {
        let status = make_col("status", "enum('active','inactive')");
        let tables = vec![make_table("users", vec![status])];

        let found = extract_enums(&tables);

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].variants, ["active", "inactive"]);
    }

    #[test]
    fn test_extract_enum_name_format() {
        let tables = vec![make_table("users", vec![make_col("status", "enum('a')")])];

        let found = extract_enums(&tables);

        assert_eq!(found[0].name, "users_status");
    }

    #[test]
    fn test_extract_no_enums() {
        let plain_columns = vec![make_col("id", "int"), make_col("name", "varchar(255)")];
        let tables = vec![make_table("users", plain_columns)];

        assert!(extract_enums(&tables).is_empty());
    }

    #[test]
    fn test_extract_two_enum_columns_same_table() {
        let enum_columns = vec![
            make_col("status", "enum('active','inactive')"),
            make_col("role", "enum('admin','user')"),
        ];
        let tables = vec![make_table("users", enum_columns)];

        let found = extract_enums(&tables);

        assert_eq!(found.len(), 2);
        assert_eq!(found[0].name, "users_status");
        assert_eq!(found[1].name, "users_role");
    }

    #[test]
    fn test_extract_enums_from_multiple_tables() {
        let users = make_table("users", vec![make_col("status", "enum('a')")]);
        let posts = make_table("posts", vec![make_col("state", "enum('b')")]);

        let found = extract_enums(&[users, posts]);

        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_extract_non_enum_column_ignored() {
        let mixed_columns = vec![make_col("id", "int(11)"), make_col("status", "enum('a')")];
        let tables = vec![make_table("users", mixed_columns)];

        let found = extract_enums(&tables);

        assert_eq!(found.len(), 1);
    }

    // ---- resolve_view_nullability ----

    /// View fixture whose columns all start out nullable, numbered from 0.
    fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
        let mut built = Vec::with_capacity(columns.len());
        for (position, col_name) in columns.into_iter().enumerate() {
            built.push(column(schema, col_name, "varchar(255)", true, position as i32));
        }
        TableInfo {
            schema_name: schema.to_owned(),
            name: name.to_owned(),
            columns: built,
        }
    }

    /// Table fixture with an explicit nullability per column, numbered from 0.
    fn make_table_with_nullability(
        schema: &str,
        name: &str,
        columns: Vec<(&str, bool)>,
    ) -> TableInfo {
        let mut built = Vec::with_capacity(columns.len());
        for (position, (col_name, nullable)) in columns.into_iter().enumerate() {
            built.push(column(
                schema,
                col_name,
                "varchar(255)",
                nullable,
                position as i32,
            ));
        }
        TableInfo {
            schema_name: schema.to_owned(),
            name: name.to_owned(),
            columns: built,
        }
    }

    /// Source-row fixture linking a view to one table column.
    fn make_source(
        view_schema: &str,
        view_name: &str,
        table_schema: &str,
        table_name: &str,
        column_name: &str,
    ) -> ViewColumnSource {
        ViewColumnSource {
            view_schema: view_schema.to_owned(),
            view_name: view_name.to_owned(),
            table_schema: table_schema.to_owned(),
            table_name: table_name.to_owned(),
            column_name: column_name.to_owned(),
        }
    }

    #[test]
    fn test_resolve_not_null_column() {
        let users = make_table_with_nullability("db", "users", vec![("id", false), ("name", false)]);
        let tables = vec![users];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];

        resolve_view_nullability(&mut views, &sources, &tables);

        let resolved = &views[0].columns;
        assert!(!resolved[0].is_nullable);
        assert!(!resolved[1].is_nullable);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let users = make_table_with_nullability("db", "users", vec![("id", false), ("name", true)]);
        let tables = vec![users];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];

        resolve_view_nullability(&mut views, &sources, &tables);

        let resolved = &views[0].columns;
        assert!(!resolved[0].is_nullable);
        assert!(resolved[1].is_nullable);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let tables = vec![make_table_with_nullability("db", "users", vec![("id", false)])];
        let mut views = vec![make_view("db", "my_view", vec!["computed"])];
        let sources: Vec<ViewColumnSource> = Vec::new();

        resolve_view_nullability(&mut views, &sources, &tables);

        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_sources() {
        let tables: Vec<TableInfo> = Vec::new();
        let mut views = vec![make_view("db", "my_view", vec!["id"])];

        resolve_view_nullability(&mut views, &[], &tables);

        assert!(views[0].columns[0].is_nullable);
    }
}