use std::collections::HashMap;

use prax_query::introspection::{
    ColumnInfo, DatabaseSchema, EnumInfo, ForeignKeyInfo, IndexColumn, IndexInfo,
    ReferentialAction, SortOrder, TableInfo, ViewInfo, generate_prax_schema, normalize_type,
    queries,
};
use prax_query::sql::DatabaseType;

use crate::config::Config;
use crate::error::{CliError, CliResult};

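/// Options that control how a database schema is introspected.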
#[derive(Debug, Clone)]
pub struct IntrospectionOptions {
    /// Schema to introspect; falls back to the provider default when `None`.
    pub schema: Option<String>,
    /// Whether to include regular views.
    pub include_views: bool,
    /// Whether to include materialized views.
    pub include_materialized_views: bool,
    /// Optional glob-style pattern; only matching tables are introspected.
    pub table_filter: Option<String>,
    /// Optional glob-style pattern; matching tables are skipped.
    pub exclude_pattern: Option<String>,
    /// Whether to carry table and column comments into the result.
    pub include_comments: bool,
    /// Number of rows to sample per table during introspection.
    pub sample_size: usize,
}

impl Default for IntrospectionOptions {
    fn default() -> Self {
        Self {
            schema: None,
            include_views: false,
            include_materialized_views: false,
            table_filter: None,
            exclude_pattern: None,
            include_comments: true,
            sample_size: 100,
        }
    }
}

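/// A database-specific introspector that reads a live schema into a `DatabaseSchema`.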
#[allow(async_fn_in_trait)]
pub trait Introspector {
    async fn introspect(&self, options: &IntrospectionOptions) -> CliResult<DatabaseSchema>;
}

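/// Maps a provider string from the configuration to a `DatabaseType`.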
pub fn get_database_type(provider: &str) -> CliResult<DatabaseType> {
    match provider.to_lowercase().as_str() {
        "postgresql" | "postgres" | "pg" => Ok(DatabaseType::PostgreSQL),
        "mysql" | "mariadb" => Ok(DatabaseType::MySQL),
        "sqlite" | "sqlite3" => Ok(DatabaseType::SQLite),
        "mssql" | "sqlserver" | "sql_server" => Ok(DatabaseType::MSSQL),
        _ => Err(CliError::Config(format!(
            "Unsupported database provider: {}",
            provider
        ))),
    }
}

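/// Returns the default schema name for each database type; an empty string is
/// returned where there is no fixed default.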
pub fn default_schema(db_type: DatabaseType) -> &'static str {
    match db_type {
        DatabaseType::PostgreSQL => "public",
        DatabaseType::MySQL => "",
        DatabaseType::SQLite => "",
        DatabaseType::MSSQL => "dbo",
    }
}

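/// PostgreSQL introspection backed by `tokio-postgres`.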
#[cfg(feature = "postgres")]
pub mod postgres {
    use super::*;
    use tokio_postgres::{Client, NoTls};

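    /// Introspects a live PostgreSQL database over a `tokio-postgres` connection.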
    pub struct PostgresIntrospector {
        connection_string: String,
    }

    impl PostgresIntrospector {
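        /// Creates an introspector for the given connection string.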
        pub fn new(connection_string: String) -> Self {
            Self { connection_string }
        }

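        /// Opens a connection and returns the client; the connection task is
        /// spawned in the background.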
        async fn connect(&self) -> CliResult<Client> {
            let (client, connection) = tokio_postgres::connect(&self.connection_string, NoTls)
                .await
                .map_err(|e| CliError::Database(format!("Failed to connect: {}", e)))?;

            // The connection half performs the actual communication with the server
            // and must be awaited for the client to make progress, so drive it on a
            // background task.
            tokio::spawn(async move {
                if let Err(e) = connection.await {
                    eprintln!("Connection error: {}", e);
                }
            });

            Ok(client)
        }
    }

    impl Introspector for PostgresIntrospector {
        async fn introspect(&self, options: &IntrospectionOptions) -> CliResult<DatabaseSchema> {
            let client = self.connect().await?;
            let schema_name = options.schema.as_deref().unwrap_or("public");

            let mut db_schema = DatabaseSchema {
                name: "database".to_string(),
                schema: Some(schema_name.to_string()),
                ..Default::default()
            };

            let tables_sql = queries::tables_query(DatabaseType::PostgreSQL, Some(schema_name));
            let table_rows = client
                .query(tables_sql.as_str(), &[])
                .await
                .map_err(|e| CliError::Database(format!("Failed to query tables: {}", e)))?;

            for row in table_rows {
                let table_name: String = row.get(0);

                // Apply the include/exclude patterns before any per-table queries.
                if let Some(ref pattern) = options.table_filter {
                    if !matches_pattern(&table_name, pattern) {
                        continue;
                    }
                }
                if let Some(ref exclude) = options.exclude_pattern {
                    if matches_pattern(&table_name, exclude) {
                        continue;
                    }
                }

                let comment: Option<String> = row.try_get(1).ok();

                let mut table = TableInfo {
                    name: table_name.clone(),
                    schema: Some(schema_name.to_string()),
                    comment: if options.include_comments {
                        comment
                    } else {
                        None
                    },
                    ..Default::default()
                };

                let cols_sql = queries::columns_query(
                    DatabaseType::PostgreSQL,
                    &table_name,
                    Some(schema_name),
                );
                let col_rows = client
                    .query(cols_sql.as_str(), &[])
                    .await
                    .map_err(|e| CliError::Database(format!("Failed to query columns: {}", e)))?;

                for col_row in col_rows {
                    let col_name: String = col_row.get(0);
                    let data_type: String = col_row.get(1);
                    let udt_name: String = col_row.get(2);
                    let nullable: bool = col_row.get(3);
                    let default: Option<String> = col_row.try_get(4).ok();
                    let max_length: Option<i32> = col_row.try_get(5).ok();
                    let precision: Option<i32> = col_row.try_get(6).ok();
                    let scale: Option<i32> = col_row.try_get(7).ok();
                    let comment: Option<String> = col_row.try_get(8).ok();
                    let auto_increment: bool = col_row.try_get(9).unwrap_or(false);

                    let normalized = normalize_type(
                        DatabaseType::PostgreSQL,
                        &udt_name,
                        max_length,
                        precision,
                        scale,
                    );

                    table.columns.push(ColumnInfo {
                        name: col_name,
                        db_type: data_type,
                        normalized_type: normalized,
                        nullable,
                        default,
                        auto_increment,
                        max_length,
                        precision,
                        scale,
                        comment: if options.include_comments {
                            comment
                        } else {
                            None
                        },
                        ..Default::default()
                    });
                }

                let pk_sql = queries::primary_keys_query(
                    DatabaseType::PostgreSQL,
                    &table_name,
                    Some(schema_name),
                );
                let pk_rows = client.query(pk_sql.as_str(), &[]).await.map_err(|e| {
                    CliError::Database(format!("Failed to query primary keys: {}", e))
                })?;

                for pk_row in pk_rows {
                    let col_name: String = pk_row.get(0);
                    table.primary_key.push(col_name.clone());

                    // Mark the matching column as part of the primary key.
                    if let Some(col) = table.columns.iter_mut().find(|c| c.name == col_name) {
                        col.is_primary_key = true;
                    }
                }

                let fk_sql = queries::foreign_keys_query(
                    DatabaseType::PostgreSQL,
                    &table_name,
                    Some(schema_name),
                );
                let fk_rows = client.query(fk_sql.as_str(), &[]).await.map_err(|e| {
                    CliError::Database(format!("Failed to query foreign keys: {}", e))
                })?;

                // Group rows by constraint name so composite foreign keys end up as
                // a single entry with multiple columns.
                let mut fk_map: HashMap<String, ForeignKeyInfo> = HashMap::new();
                for fk_row in fk_rows {
                    let constraint_name: String = fk_row.get(0);
                    let column_name: String = fk_row.get(1);
                    let ref_table: String = fk_row.get(2);
                    let ref_schema: Option<String> = fk_row.try_get(3).ok();
                    let ref_column: String = fk_row.get(4);
                    let delete_rule: String = fk_row.get(5);
                    let update_rule: String = fk_row.get(6);

                    let fk = fk_map
                        .entry(constraint_name.clone())
                        .or_insert_with(|| ForeignKeyInfo {
                            name: constraint_name,
                            columns: Vec::new(),
                            referenced_table: ref_table,
                            referenced_schema: ref_schema,
                            referenced_columns: Vec::new(),
                            on_delete: ReferentialAction::from_str(&delete_rule),
                            on_update: ReferentialAction::from_str(&update_rule),
                        });

                    fk.columns.push(column_name);
                    fk.referenced_columns.push(ref_column);
                }

                table.foreign_keys = fk_map.into_values().collect();

                let idx_sql = queries::indexes_query(
                    DatabaseType::PostgreSQL,
                    &table_name,
                    Some(schema_name),
                );
                let idx_rows = client
                    .query(idx_sql.as_str(), &[])
                    .await
                    .map_err(|e| CliError::Database(format!("Failed to query indexes: {}", e)))?;

                // Group rows by index name so multi-column indexes collect all of
                // their columns.
                let mut idx_map: HashMap<String, IndexInfo> = HashMap::new();
                for idx_row in idx_rows {
                    let idx_name: String = idx_row.get(0);
                    let col_name: String = idx_row.get(1);
                    let is_unique: bool = idx_row.get(2);
                    let is_primary: bool = idx_row.get(3);
                    let idx_type: Option<String> = idx_row.try_get(4).ok();
                    let filter: Option<String> = idx_row.try_get(5).ok();

                    let idx = idx_map
                        .entry(idx_name.clone())
                        .or_insert_with(|| IndexInfo {
                            name: idx_name,
                            columns: Vec::new(),
                            is_unique,
                            is_primary,
                            index_type: idx_type,
                            filter,
                        });

                    idx.columns.push(IndexColumn {
                        name: col_name,
                        order: SortOrder::Asc,
                        ..Default::default()
                    });
                }

                table.indexes = idx_map.into_values().collect();

                db_schema.tables.push(table);
            }

            let enums_sql = queries::enums_query(Some(schema_name));
            let enum_rows = client
                .query(enums_sql.as_str(), &[])
                .await
                .map_err(|e| CliError::Database(format!("Failed to query enums: {}", e)))?;

            // Each row is one (enum name, value) pair; collect the values per enum.
            let mut enum_map: HashMap<String, EnumInfo> = HashMap::new();
            for enum_row in enum_rows {
                let enum_name: String = enum_row.get(0);
                let enum_value: String = enum_row.get(1);

                let enum_info = enum_map
                    .entry(enum_name.clone())
                    .or_insert_with(|| EnumInfo {
                        name: enum_name,
                        schema: Some(schema_name.to_string()),
                        values: Vec::new(),
                    });

                enum_info.values.push(enum_value);
            }

            db_schema.enums = enum_map.into_values().collect();

            if options.include_views || options.include_materialized_views {
                let views_sql = queries::views_query(DatabaseType::PostgreSQL, Some(schema_name));
                let view_rows = client
                    .query(views_sql.as_str(), &[])
                    .await
                    .map_err(|e| CliError::Database(format!("Failed to query views: {}", e)))?;

                for view_row in view_rows {
                    let view_name: String = view_row.get(0);
                    let definition: Option<String> = view_row.try_get(1).ok();
                    let is_materialized: bool = view_row.get(2);

                    // Honor the individual view and materialized-view flags.
                    if is_materialized && !options.include_materialized_views {
                        continue;
                    }
                    if !is_materialized && !options.include_views {
                        continue;
                    }

                    db_schema.views.push(ViewInfo {
                        name: view_name,
                        schema: Some(schema_name.to_string()),
                        definition,
                        is_materialized,
                        columns: Vec::new(),
                    });
                }
            }

            Ok(db_schema)
        }
    }

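    /// Matches a table name against a simple glob pattern: a leading or
    /// trailing `*` acts as a wildcard, and any other pattern must match
    /// exactly.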
    fn matches_pattern(name: &str, pattern: &str) -> bool {
        if pattern == "*" {
            return true;
        }

        if pattern.starts_with('*') && pattern.ends_with('*') {
            let middle = &pattern[1..pattern.len() - 1];
            return name.contains(middle);
        }

        if pattern.starts_with('*') {
            let suffix = &pattern[1..];
            return name.ends_with(suffix);
        }

        if pattern.ends_with('*') {
            let prefix = &pattern[..pattern.len() - 1];
            return name.starts_with(prefix);
        }

        name == pattern
    }
}

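/// Renders the introspected schema as a Prax schema file, prefixed with
/// `datasource` and `generator` blocks.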
pub fn format_as_prax(schema: &DatabaseSchema, config: &Config) -> String {
    let mut output = String::new();

    output.push_str("// Generated by `prax db pull`\n");
    output.push_str("// Edit this file to customize your schema\n\n");

    output.push_str("datasource db {\n");
    output.push_str(&format!(
        "    provider = \"{}\"\n",
        config.database.provider
    ));
    output.push_str("    url = env(\"DATABASE_URL\")\n");
    output.push_str("}\n\n");

    output.push_str("generator client {\n");
    output.push_str("    provider = \"prax-client-rust\"\n");
    output.push_str("    output = \"./src/generated\"\n");
    output.push_str("}\n\n");

    output.push_str(&generate_prax_schema(schema));

    output
}

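/// Serializes the introspected schema as pretty-printed JSON.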
pub fn format_as_json(schema: &DatabaseSchema) -> CliResult<String> {
    serde_json::to_string_pretty(schema)
        .map_err(|e| CliError::Config(format!("Failed to serialize schema: {}", e)))
}

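/// Renders the introspected schema as SQL DDL (`CREATE TYPE`, `CREATE TABLE`,
/// and `CREATE INDEX` statements) for the given database type.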
pub fn format_as_sql(schema: &DatabaseSchema, db_type: DatabaseType) -> String {
    let mut output = String::new();

    output.push_str("-- Generated by `prax db pull`\n");
    output.push_str(&format!("-- Database: {}\n\n", db_type_name(db_type)));

    // PostgreSQL enums become standalone CREATE TYPE statements.
    if db_type == DatabaseType::PostgreSQL {
        for enum_info in &schema.enums {
            output.push_str(&format!("CREATE TYPE {} AS ENUM (\n", enum_info.name));
            let values: Vec<String> = enum_info
                .values
                .iter()
                .map(|v| format!("    '{}'", v))
                .collect();
            output.push_str(&values.join(",\n"));
            output.push_str("\n);\n\n");
        }
    }

    for table in &schema.tables {
        output.push_str(&format!(
            "CREATE TABLE {} (\n",
            quote_identifier(&table.name, db_type)
        ));

        let mut col_defs: Vec<String> = Vec::new();

        for col in &table.columns {
            let mut def = format!(
                "    {} {}",
                quote_identifier(&col.name, db_type),
                col.db_type
            );

            if !col.nullable {
                def.push_str(" NOT NULL");
            }

            if let Some(ref default) = col.default {
                def.push_str(&format!(" DEFAULT {}", default));
            }

            col_defs.push(def);
        }

        if !table.primary_key.is_empty() {
            let pk_cols: Vec<String> = table
                .primary_key
                .iter()
                .map(|c| quote_identifier(c, db_type))
                .collect();
            col_defs.push(format!("    PRIMARY KEY ({})", pk_cols.join(", ")));
        }

        output.push_str(&col_defs.join(",\n"));
        output.push_str("\n);\n\n");

        for idx in &table.indexes {
            // The primary key is already emitted as part of CREATE TABLE above.
            if idx.is_primary {
                continue;
            }

            let unique = if idx.is_unique { "UNIQUE " } else { "" };
            let cols: Vec<String> = idx
                .columns
                .iter()
                .map(|c| quote_identifier(&c.name, db_type))
                .collect();

            output.push_str(&format!(
                "CREATE {}INDEX {} ON {} ({});\n",
                unique,
                quote_identifier(&idx.name, db_type),
                quote_identifier(&table.name, db_type),
                cols.join(", ")
            ));
        }

        output.push('\n');
    }

    output
}

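/// Human-readable name for a database type, used in generated SQL comments.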
fn db_type_name(db_type: DatabaseType) -> &'static str {
    match db_type {
        DatabaseType::PostgreSQL => "PostgreSQL",
        DatabaseType::MySQL => "MySQL",
        DatabaseType::SQLite => "SQLite",
        DatabaseType::MSSQL => "SQL Server",
    }
}

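/// Quotes an identifier using the target database's identifier quoting style.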
fn quote_identifier(name: &str, db_type: DatabaseType) -> String {
    match db_type {
        DatabaseType::PostgreSQL => format!("\"{}\"", name),
        DatabaseType::MySQL => format!("`{}`", name),
        DatabaseType::SQLite => format!("\"{}\"", name),
        DatabaseType::MSSQL => format!("[{}]", name),
    }
}
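
// A minimal test sketch for the pure helpers in this module, assuming only the
// items defined or imported in this file; the schema it builds is hypothetical
// sample data rather than output from a real database.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn provider_names_map_to_database_types() {
        assert!(matches!(
            get_database_type("postgres"),
            Ok(DatabaseType::PostgreSQL)
        ));
        assert!(matches!(
            get_database_type("MariaDB"),
            Ok(DatabaseType::MySQL)
        ));
        assert!(get_database_type("oracle").is_err());
    }

    #[test]
    fn default_schemas_follow_database_conventions() {
        assert_eq!(default_schema(DatabaseType::PostgreSQL), "public");
        assert_eq!(default_schema(DatabaseType::MSSQL), "dbo");
        assert_eq!(default_schema(DatabaseType::SQLite), "");
    }

    #[test]
    fn format_as_sql_emits_create_table_with_primary_key() {
        // Build a one-table schema by hand; the introspection structs all
        // implement `Default`, so only the interesting fields are filled in.
        let schema = DatabaseSchema {
            tables: vec![TableInfo {
                name: "users".to_string(),
                columns: vec![ColumnInfo {
                    name: "id".to_string(),
                    db_type: "integer".to_string(),
                    nullable: false,
                    ..Default::default()
                }],
                primary_key: vec!["id".to_string()],
                ..Default::default()
            }],
            ..Default::default()
        };

        let sql = format_as_sql(&schema, DatabaseType::PostgreSQL);
        assert!(sql.contains("CREATE TABLE \"users\""));
        assert!(sql.contains("\"id\" integer NOT NULL"));
        assert!(sql.contains("PRIMARY KEY (\"id\")"));
    }
}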