use std::collections::HashMap;

use prax_query::introspection::{
    ColumnInfo, DatabaseSchema, EnumInfo, ForeignKeyInfo, IndexColumn, IndexInfo,
    ReferentialAction, SortOrder, TableInfo, ViewInfo,
    generate_prax_schema, normalize_type, queries,
};
use prax_query::sql::DatabaseType;

use crate::config::Config;
use crate::error::{CliError, CliResult};

#[derive(Debug, Clone)]
pub struct IntrospectionOptions {
    pub schema: Option<String>,
    pub include_views: bool,
    pub include_materialized_views: bool,
    pub table_filter: Option<String>,
    pub exclude_pattern: Option<String>,
    pub include_comments: bool,
    pub sample_size: usize,
}

impl Default for IntrospectionOptions {
    fn default() -> Self {
        Self {
            schema: None,
            include_views: false,
            include_materialized_views: false,
            table_filter: None,
            exclude_pattern: None,
            include_comments: true,
            sample_size: 100,
        }
    }
}

#[allow(async_fn_in_trait)]
pub trait Introspector {
    async fn introspect(&self, options: &IntrospectionOptions) -> CliResult<DatabaseSchema>;
}

pub fn get_database_type(provider: &str) -> CliResult<DatabaseType> {
    match provider.to_lowercase().as_str() {
        "postgresql" | "postgres" | "pg" => Ok(DatabaseType::PostgreSQL),
        "mysql" | "mariadb" => Ok(DatabaseType::MySQL),
        "sqlite" | "sqlite3" => Ok(DatabaseType::SQLite),
        "mssql" | "sqlserver" | "sql_server" => Ok(DatabaseType::MSSQL),
        _ => Err(CliError::Config(format!(
            "Unsupported database provider: {}",
            provider
        ))),
    }
}

pub fn default_schema(db_type: DatabaseType) -> &'static str {
    match db_type {
        DatabaseType::PostgreSQL => "public",
        DatabaseType::MySQL => "",
        DatabaseType::SQLite => "",
        DatabaseType::MSSQL => "dbo",
    }
}

#[cfg(feature = "postgres")]
pub mod postgres {
    use super::*;
    use tokio_postgres::{Client, NoTls};

    pub struct PostgresIntrospector {
        connection_string: String,
    }

    impl PostgresIntrospector {
        pub fn new(connection_string: String) -> Self {
            Self { connection_string }
        }

        async fn connect(&self) -> CliResult<Client> {
            let (client, connection) = tokio_postgres::connect(&self.connection_string, NoTls)
                .await
                .map_err(|e| CliError::Database(format!("Failed to connect: {}", e)))?;

            tokio::spawn(async move {
                if let Err(e) = connection.await {
                    eprintln!("Connection error: {}", e);
                }
            });

            Ok(client)
        }
    }

    impl Introspector for PostgresIntrospector {
        async fn introspect(&self, options: &IntrospectionOptions) -> CliResult<DatabaseSchema> {
            let client = self.connect().await?;
            let schema_name = options.schema.as_deref().unwrap_or("public");

            let mut db_schema = DatabaseSchema {
                name: "database".to_string(),
                schema: Some(schema_name.to_string()),
                ..Default::default()
            };

            let tables_sql = queries::tables_query(DatabaseType::PostgreSQL, Some(schema_name));
            let table_rows = client
                .query(&tables_sql, &[])
                .await
                .map_err(|e| CliError::Database(format!("Failed to query tables: {}", e)))?;

            for row in table_rows {
                let table_name: String = row.get(0);

                if let Some(ref pattern) = options.table_filter {
                    if !matches_pattern(&table_name, pattern) {
                        continue;
                    }
                }
                if let Some(ref exclude) = options.exclude_pattern {
                    if matches_pattern(&table_name, exclude) {
                        continue;
                    }
                }

                let comment: Option<String> = row.try_get(1).ok();

                let mut table = TableInfo {
                    name: table_name.clone(),
                    schema: Some(schema_name.to_string()),
                    comment: if options.include_comments { comment } else { None },
                    ..Default::default()
                };

                let cols_sql = queries::columns_query(DatabaseType::PostgreSQL, &table_name, Some(schema_name));
                let col_rows = client
                    .query(&cols_sql, &[])
                    .await
                    .map_err(|e| CliError::Database(format!("Failed to query columns: {}", e)))?;

                for col_row in col_rows {
                    let col_name: String = col_row.get(0);
                    let data_type: String = col_row.get(1);
                    let udt_name: String = col_row.get(2);
                    let nullable: bool = col_row.get(3);
                    let default: Option<String> = col_row.try_get(4).ok();
                    let max_length: Option<i32> = col_row.try_get(5).ok();
                    let precision: Option<i32> = col_row.try_get(6).ok();
                    let scale: Option<i32> = col_row.try_get(7).ok();
                    let comment: Option<String> = col_row.try_get(8).ok();
                    let auto_increment: bool = col_row.try_get(9).unwrap_or(false);

                    let normalized = normalize_type(
                        DatabaseType::PostgreSQL,
                        &udt_name,
                        max_length,
                        precision,
                        scale,
                    );

                    table.columns.push(ColumnInfo {
                        name: col_name,
                        db_type: data_type,
                        normalized_type: normalized,
                        nullable,
                        default,
                        auto_increment,
                        max_length,
                        precision,
                        scale,
                        comment: if options.include_comments { comment } else { None },
                        ..Default::default()
                    });
                }

                let pk_sql = queries::primary_keys_query(DatabaseType::PostgreSQL, &table_name, Some(schema_name));
                let pk_rows = client
                    .query(&pk_sql, &[])
                    .await
                    .map_err(|e| CliError::Database(format!("Failed to query primary keys: {}", e)))?;

                for pk_row in pk_rows {
                    let col_name: String = pk_row.get(0);
                    table.primary_key.push(col_name.clone());

                    if let Some(col) = table.columns.iter_mut().find(|c| c.name == col_name) {
                        col.is_primary_key = true;
                    }
                }

                let fk_sql = queries::foreign_keys_query(DatabaseType::PostgreSQL, &table_name, Some(schema_name));
                let fk_rows = client
                    .query(&fk_sql, &[])
                    .await
                    .map_err(|e| CliError::Database(format!("Failed to query foreign keys: {}", e)))?;

                let mut fk_map: HashMap<String, ForeignKeyInfo> = HashMap::new();
                for fk_row in fk_rows {
                    let constraint_name: String = fk_row.get(0);
                    let column_name: String = fk_row.get(1);
                    let ref_table: String = fk_row.get(2);
                    let ref_schema: Option<String> = fk_row.try_get(3).ok();
                    let ref_column: String = fk_row.get(4);
                    let delete_rule: String = fk_row.get(5);
                    let update_rule: String = fk_row.get(6);

                    let fk = fk_map.entry(constraint_name.clone()).or_insert_with(|| {
                        ForeignKeyInfo {
                            name: constraint_name,
                            columns: Vec::new(),
                            referenced_table: ref_table,
                            referenced_schema: ref_schema,
                            referenced_columns: Vec::new(),
                            on_delete: ReferentialAction::from_str(&delete_rule),
                            on_update: ReferentialAction::from_str(&update_rule),
                        }
                    });

                    fk.columns.push(column_name);
                    fk.referenced_columns.push(ref_column);
                }

                table.foreign_keys = fk_map.into_values().collect();

                let idx_sql = queries::indexes_query(DatabaseType::PostgreSQL, &table_name, Some(schema_name));
                let idx_rows = client
                    .query(&idx_sql, &[])
                    .await
                    .map_err(|e| CliError::Database(format!("Failed to query indexes: {}", e)))?;

                let mut idx_map: HashMap<String, IndexInfo> = HashMap::new();
                for idx_row in idx_rows {
                    let idx_name: String = idx_row.get(0);
                    let col_name: String = idx_row.get(1);
                    let is_unique: bool = idx_row.get(2);
                    let is_primary: bool = idx_row.get(3);
                    let idx_type: Option<String> = idx_row.try_get(4).ok();
                    let filter: Option<String> = idx_row.try_get(5).ok();

                    let idx = idx_map.entry(idx_name.clone()).or_insert_with(|| {
                        IndexInfo {
                            name: idx_name,
                            columns: Vec::new(),
                            is_unique,
                            is_primary,
                            index_type: idx_type,
                            filter,
                        }
                    });

                    idx.columns.push(IndexColumn {
                        name: col_name,
                        order: SortOrder::Asc,
                        ..Default::default()
                    });
                }

                table.indexes = idx_map.into_values().collect();

                db_schema.tables.push(table);
            }

            let enums_sql = queries::enums_query(Some(schema_name));
            let enum_rows = client
                .query(&enums_sql, &[])
                .await
                .map_err(|e| CliError::Database(format!("Failed to query enums: {}", e)))?;

            let mut enum_map: HashMap<String, EnumInfo> = HashMap::new();
            for enum_row in enum_rows {
                let enum_name: String = enum_row.get(0);
                let enum_value: String = enum_row.get(1);

                let enum_info = enum_map.entry(enum_name.clone()).or_insert_with(|| {
                    EnumInfo {
                        name: enum_name,
                        schema: Some(schema_name.to_string()),
                        values: Vec::new(),
                    }
                });

                enum_info.values.push(enum_value);
            }

            db_schema.enums = enum_map.into_values().collect();

            if options.include_views || options.include_materialized_views {
                let views_sql = queries::views_query(DatabaseType::PostgreSQL, Some(schema_name));
                let view_rows = client
                    .query(&views_sql, &[])
                    .await
                    .map_err(|e| CliError::Database(format!("Failed to query views: {}", e)))?;

                for view_row in view_rows {
                    let view_name: String = view_row.get(0);
                    let definition: Option<String> = view_row.try_get(1).ok();
                    let is_materialized: bool = view_row.get(2);

                    if is_materialized && !options.include_materialized_views {
                        continue;
                    }
                    if !is_materialized && !options.include_views {
                        continue;
                    }

                    db_schema.views.push(ViewInfo {
                        name: view_name,
                        schema: Some(schema_name.to_string()),
                        definition,
                        is_materialized,
                        columns: Vec::new(),
                    });
                }
            }

            Ok(db_schema)
        }
    }

    fn matches_pattern(name: &str, pattern: &str) -> bool {
        if pattern == "*" {
            return true;
        }

        if pattern.starts_with('*') && pattern.ends_with('*') {
            let middle = &pattern[1..pattern.len() - 1];
            return name.contains(middle);
        }

        if pattern.starts_with('*') {
            let suffix = &pattern[1..];
            return name.ends_with(suffix);
        }

        if pattern.ends_with('*') {
            let prefix = &pattern[..pattern.len() - 1];
            return name.starts_with(prefix);
        }

        name == pattern
    }
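
    // A minimal test sketch for the wildcard matching used by `table_filter` /
    // `exclude_pattern`. It exercises only the private helper above and needs no
    // database connection; the expected results follow directly from the match
    // rules implemented in `matches_pattern`. The table names are made up.
    #[cfg(test)]
    mod pattern_tests {
        use super::*;

        #[test]
        fn matches_exact_names_and_wildcards() {
            // Bare "*" matches everything.
            assert!(matches_pattern("anything", "*"));
            // Exact match when no wildcard is present.
            assert!(matches_pattern("users", "users"));
            // Prefix, suffix, and contains patterns.
            assert!(matches_pattern("user_accounts", "user*"));
            assert!(matches_pattern("audit_log", "*_log"));
            assert!(matches_pattern("app_users_v2", "*users*"));
            // Non-matching prefix.
            assert!(!matches_pattern("orders", "user*"));
        }
    }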
}

pub fn format_as_prax(schema: &DatabaseSchema, config: &Config) -> String {
    let mut output = String::new();

    output.push_str("// Generated by `prax db pull`\n");
    output.push_str("// Edit this file to customize your schema\n\n");

    output.push_str("datasource db {\n");
    output.push_str(&format!(" provider = \"{}\"\n", config.database.provider));
    output.push_str(" url = env(\"DATABASE_URL\")\n");
    output.push_str("}\n\n");

    output.push_str("generator client {\n");
    output.push_str(" provider = \"prax-client-rust\"\n");
    output.push_str(" output = \"./src/generated\"\n");
    output.push_str("}\n\n");

    output.push_str(&generate_prax_schema(schema));

    output
}

pub fn format_as_json(schema: &DatabaseSchema) -> CliResult<String> {
    serde_json::to_string_pretty(schema)
        .map_err(|e| CliError::Config(format!("Failed to serialize schema: {}", e)))
}

pub fn format_as_sql(schema: &DatabaseSchema, db_type: DatabaseType) -> String {
    let mut output = String::new();

    output.push_str("-- Generated by `prax db pull`\n");
    output.push_str(&format!("-- Database: {}\n\n", db_type_name(db_type)));

    if db_type == DatabaseType::PostgreSQL {
        for enum_info in &schema.enums {
            output.push_str(&format!("CREATE TYPE {} AS ENUM (\n", enum_info.name));
            let values: Vec<String> = enum_info.values.iter().map(|v| format!(" '{}'", v)).collect();
            output.push_str(&values.join(",\n"));
            output.push_str("\n);\n\n");
        }
    }

    for table in &schema.tables {
        output.push_str(&format!("CREATE TABLE {} (\n", quote_identifier(&table.name, db_type)));

        let mut col_defs: Vec<String> = Vec::new();

        for col in &table.columns {
            let mut def = format!(
                " {} {}",
                quote_identifier(&col.name, db_type),
                col.db_type
            );

            if !col.nullable {
                def.push_str(" NOT NULL");
            }

            if let Some(ref default) = col.default {
                def.push_str(&format!(" DEFAULT {}", default));
            }

            col_defs.push(def);
        }

        if !table.primary_key.is_empty() {
            let pk_cols: Vec<String> = table
                .primary_key
                .iter()
                .map(|c| quote_identifier(c, db_type))
                .collect();
            col_defs.push(format!(" PRIMARY KEY ({})", pk_cols.join(", ")));
        }

        output.push_str(&col_defs.join(",\n"));
        output.push_str("\n);\n\n");

        for idx in &table.indexes {
            if idx.is_primary {
                continue;
            }

            let unique = if idx.is_unique { "UNIQUE " } else { "" };
            let cols: Vec<String> = idx.columns.iter().map(|c| quote_identifier(&c.name, db_type)).collect();

            output.push_str(&format!(
                "CREATE {}INDEX {} ON {} ({});\n",
                unique,
                quote_identifier(&idx.name, db_type),
                quote_identifier(&table.name, db_type),
                cols.join(", ")
            ));
        }

        output.push('\n');
    }

    output
}

fn db_type_name(db_type: DatabaseType) -> &'static str {
    match db_type {
        DatabaseType::PostgreSQL => "PostgreSQL",
        DatabaseType::MySQL => "MySQL",
        DatabaseType::SQLite => "SQLite",
        DatabaseType::MSSQL => "SQL Server",
    }
}

fn quote_identifier(name: &str, db_type: DatabaseType) -> String {
    match db_type {
        DatabaseType::PostgreSQL => format!("\"{}\"", name),
        DatabaseType::MySQL => format!("`{}`", name),
        DatabaseType::SQLite => format!("\"{}\"", name),
        DatabaseType::MSSQL => format!("[{}]", name),
    }
}
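
// A small test sketch for the pure helpers in this module. It assumes the
// `prax_query` introspection structs (`DatabaseSchema`, `TableInfo`, `ColumnInfo`)
// expose public fields and `Default` impls, as the code above already relies on;
// the table and column names used here are made up for the test.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn maps_provider_strings_to_database_types() {
        assert!(matches!(get_database_type("postgres"), Ok(DatabaseType::PostgreSQL)));
        assert!(matches!(get_database_type("MariaDB"), Ok(DatabaseType::MySQL)));
        assert!(matches!(get_database_type("sqlite3"), Ok(DatabaseType::SQLite)));
        assert!(get_database_type("oracle").is_err());
    }

    #[test]
    fn default_schemas_and_identifier_quoting() {
        assert_eq!(default_schema(DatabaseType::PostgreSQL), "public");
        assert_eq!(default_schema(DatabaseType::MSSQL), "dbo");
        assert_eq!(db_type_name(DatabaseType::MSSQL), "SQL Server");
        assert_eq!(quote_identifier("users", DatabaseType::PostgreSQL), "\"users\"");
        assert_eq!(quote_identifier("users", DatabaseType::MySQL), "`users`");
        assert_eq!(quote_identifier("users", DatabaseType::MSSQL), "[users]");
    }

    #[test]
    fn format_as_sql_emits_create_table_with_primary_key() {
        let table = TableInfo {
            name: "users".to_string(),
            columns: vec![ColumnInfo {
                name: "id".to_string(),
                db_type: "integer".to_string(),
                nullable: false,
                ..Default::default()
            }],
            primary_key: vec!["id".to_string()],
            ..Default::default()
        };
        let schema = DatabaseSchema {
            tables: vec![table],
            ..Default::default()
        };

        let sql = format_as_sql(&schema, DatabaseType::PostgreSQL);
        assert!(sql.contains("CREATE TABLE \"users\""));
        assert!(sql.contains("\"id\" integer NOT NULL"));
        assert!(sql.contains("PRIMARY KEY (\"id\")"));
    }
}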