Skip to main content

prax_cli/commands/
introspect.rs

1//! Database introspection implementation.
2//!
3//! This module provides the actual database introspection functionality
4//! using the `prax-query` introspection types.
5
6use std::collections::HashMap;
7
8use prax_query::introspection::{
9    ColumnInfo, DatabaseSchema, EnumInfo, ForeignKeyInfo, IndexColumn, IndexInfo,
10    ReferentialAction, SortOrder, TableInfo, ViewInfo, generate_prax_schema, normalize_type,
11    queries,
12};
13use prax_query::sql::DatabaseType;
14
15use crate::config::Config;
16use crate::error::{CliError, CliResult};
17
/// Settings that control what `prax db pull` reads from the database.
///
/// [`Default::default`] gives the standard settings: no explicit schema, no
/// table filters, views and materialized views excluded, comments included,
/// and a sample size of 100.
#[derive(Debug, Clone)]
pub struct IntrospectionOptions {
    /// Schema/namespace to introspect. `None` means the backend's default
    /// (the PostgreSQL introspector falls back to `public`).
    pub schema: Option<String>,
    /// Whether regular (non-materialized) views are introspected.
    pub include_views: bool,
    /// Whether materialized views are introspected.
    pub include_materialized_views: bool,
    /// Glob-style pattern a table name must match to be introspected.
    pub table_filter: Option<String>,
    /// Glob-style pattern; tables whose name matches it are skipped.
    pub exclude_pattern: Option<String>,
    /// Whether table and column comments are copied into the result.
    pub include_comments: bool,
    /// Sample size used for MongoDB introspection.
    pub sample_size: usize,
}

impl Default for IntrospectionOptions {
    fn default() -> Self {
        // Everything is opt-in except comments, which are kept unless disabled.
        let include_comments = true;
        let sample_size = 100;

        Self {
            sample_size,
            include_comments,
            exclude_pattern: None,
            table_filter: None,
            include_materialized_views: false,
            include_views: false,
            schema: None,
        }
    }
}
50
/// Database introspector trait.
///
/// An implementor reads the structure of a live database and describes it as
/// a [`DatabaseSchema`], applying the filters carried by
/// [`IntrospectionOptions`].
// NOTE(review): `async_fn_in_trait` is allowed here. The lint exists because an
// `async fn` in a public trait gives callers no way to require the returned
// future to be `Send`; confirm no call site needs to spawn it on a
// multi-threaded executor.
#[allow(async_fn_in_trait)]
pub trait Introspector {
    /// Introspect the database and return schema information.
    ///
    /// # Errors
    ///
    /// Implementation-defined. The PostgreSQL implementation in this module
    /// returns `CliError::Database` when connecting or running a catalog
    /// query fails.
    async fn introspect(&self, options: &IntrospectionOptions) -> CliResult<DatabaseSchema>;
}
57
58/// Get the database type from provider string.
59pub fn get_database_type(provider: &str) -> CliResult<DatabaseType> {
60    match provider.to_lowercase().as_str() {
61        "postgresql" | "postgres" | "pg" => Ok(DatabaseType::PostgreSQL),
62        "mysql" | "mariadb" => Ok(DatabaseType::MySQL),
63        "sqlite" | "sqlite3" => Ok(DatabaseType::SQLite),
64        "mssql" | "sqlserver" | "sql_server" => Ok(DatabaseType::MSSQL),
65        _ => Err(CliError::Config(format!(
66            "Unsupported database provider: {}",
67            provider
68        ))),
69    }
70}
71
72/// Get default schema for database type.
73pub fn default_schema(db_type: DatabaseType) -> &'static str {
74    match db_type {
75        DatabaseType::PostgreSQL => "public",
76        DatabaseType::MySQL => "",
77        DatabaseType::SQLite => "",
78        DatabaseType::MSSQL => "dbo",
79    }
80}
81
82// ============================================================================
83// PostgreSQL Introspector
84// ============================================================================
85
#[cfg(feature = "postgres")]
pub mod postgres {
    use super::*;
    use tokio_postgres::{Client, NoTls, Row};

    /// PostgreSQL introspector.
    ///
    /// Each [`Introspector::introspect`] call opens its own connection and
    /// reads tables, columns, keys, indexes, enums and (optionally) views
    /// from a single schema.
    pub struct PostgresIntrospector {
        connection_string: String,
    }

    impl PostgresIntrospector {
        /// Create a new PostgreSQL introspector for `connection_string`.
        pub fn new(connection_string: String) -> Self {
            Self { connection_string }
        }

        /// Connect to the database without TLS.
        ///
        /// The connection driver future is spawned with `tokio::spawn`. Errors
        /// it hits after the handshake are only printed to stderr.
        async fn connect(&self) -> CliResult<Client> {
            let (client, connection) = tokio_postgres::connect(&self.connection_string, NoTls)
                .await
                .map_err(|e| CliError::Database(format!("Failed to connect: {}", e)))?;

            // Spawn the connection handler
            tokio::spawn(async move {
                if let Err(e) = connection.await {
                    eprintln!("Connection error: {}", e);
                }
            });

            Ok(client)
        }
    }

    impl Introspector for PostgresIntrospector {
        /// Introspect one schema (`options.schema`, defaulting to `public`).
        ///
        /// Tables are filtered by `options.table_filter` and
        /// `options.exclude_pattern`. Enums are always read. Views are read only
        /// when `include_views` or `include_materialized_views` is set.
        ///
        /// # Errors
        ///
        /// Returns `CliError::Database` if connecting or any catalog query fails.
        async fn introspect(&self, options: &IntrospectionOptions) -> CliResult<DatabaseSchema> {
            let client = self.connect().await?;
            let schema_name = options.schema.as_deref().unwrap_or("public");

            // NOTE(review): the database name is a fixed placeholder; it is not
            // read from the connection.
            let mut db_schema = DatabaseSchema {
                name: "database".to_string(),
                schema: Some(schema_name.to_string()),
                ..Default::default()
            };

            // Tables (and, per table, columns / keys / indexes).
            let tables_sql = queries::tables_query(DatabaseType::PostgreSQL, Some(schema_name));
            for row in query_rows(&client, &tables_sql, "tables").await? {
                let table_name: String = row.get(0);
                if !is_table_selected(&table_name, options) {
                    continue;
                }

                let comment: Option<String> = row.try_get(1).ok();
                let table =
                    introspect_table(&client, schema_name, &table_name, comment, options).await?;
                db_schema.tables.push(table);
            }

            // Enums: one row per (enum, value); fold into one entry per enum.
            let enums_sql = queries::enums_query(Some(schema_name));
            let mut enums = OrderedGroups::new();
            for enum_row in query_rows(&client, &enums_sql, "enums").await? {
                let enum_name: String = enum_row.get(0);
                let enum_value: String = enum_row.get(1);

                enums
                    .get_or_insert_with(&enum_name, || EnumInfo {
                        name: enum_name.clone(),
                        schema: Some(schema_name.to_string()),
                        values: Vec::new(),
                    })
                    .values
                    .push(enum_value);
            }
            db_schema.enums = enums.into_values().collect();

            // Views
            if options.include_views || options.include_materialized_views {
                let views_sql = queries::views_query(DatabaseType::PostgreSQL, Some(schema_name));
                for view_row in query_rows(&client, &views_sql, "views").await? {
                    let view_name: String = view_row.get(0);
                    let definition: Option<String> = view_row.try_get(1).ok();
                    let is_materialized: bool = view_row.get(2);

                    // The query returns both kinds; keep only the requested one(s).
                    let wanted = if is_materialized {
                        options.include_materialized_views
                    } else {
                        options.include_views
                    };
                    if !wanted {
                        continue;
                    }

                    db_schema.views.push(ViewInfo {
                        name: view_name,
                        schema: Some(schema_name.to_string()),
                        definition,
                        is_materialized,
                        columns: Vec::new(),
                    });
                }
            }

            Ok(db_schema)
        }
    }

    /// Load columns, primary key, foreign keys and indexes for one table.
    async fn introspect_table(
        client: &Client,
        schema_name: &str,
        table_name: &str,
        comment: Option<String>,
        options: &IntrospectionOptions,
    ) -> CliResult<TableInfo> {
        let mut table = TableInfo {
            name: table_name.to_string(),
            schema: Some(schema_name.to_string()),
            comment: comment.filter(|_| options.include_comments),
            ..Default::default()
        };

        // Columns
        let cols_sql =
            queries::columns_query(DatabaseType::PostgreSQL, table_name, Some(schema_name));
        for col_row in query_rows(client, &cols_sql, "columns").await? {
            let col_name: String = col_row.get(0);
            let data_type: String = col_row.get(1);
            let udt_name: String = col_row.get(2);
            let nullable: bool = col_row.get(3);
            let default: Option<String> = col_row.try_get(4).ok();
            let max_length: Option<i32> = col_row.try_get(5).ok();
            let precision: Option<i32> = col_row.try_get(6).ok();
            let scale: Option<i32> = col_row.try_get(7).ok();
            let comment: Option<String> = col_row.try_get(8).ok();
            let auto_increment: bool = col_row.try_get(9).unwrap_or(false);

            // Normalization works from the underlying type name (column 2),
            // while `db_type` keeps the reported data type (column 1).
            let normalized_type = normalize_type(
                DatabaseType::PostgreSQL,
                &udt_name,
                max_length,
                precision,
                scale,
            );

            table.columns.push(ColumnInfo {
                name: col_name,
                db_type: data_type,
                normalized_type,
                nullable,
                default,
                auto_increment,
                max_length,
                precision,
                scale,
                comment: comment.filter(|_| options.include_comments),
                ..Default::default()
            });
        }

        // Primary key: record the column order and flag matching columns.
        let pk_sql =
            queries::primary_keys_query(DatabaseType::PostgreSQL, table_name, Some(schema_name));
        for pk_row in query_rows(client, &pk_sql, "primary keys").await? {
            let col_name: String = pk_row.get(0);
            if let Some(col) = table.columns.iter_mut().find(|c| c.name == col_name) {
                col.is_primary_key = true;
            }
            table.primary_key.push(col_name);
        }

        // Foreign keys: one row per constrained column; group by constraint.
        let fk_sql =
            queries::foreign_keys_query(DatabaseType::PostgreSQL, table_name, Some(schema_name));
        let mut foreign_keys = OrderedGroups::new();
        for fk_row in query_rows(client, &fk_sql, "foreign keys").await? {
            let constraint_name: String = fk_row.get(0);
            let column_name: String = fk_row.get(1);
            let ref_table: String = fk_row.get(2);
            let ref_schema: Option<String> = fk_row.try_get(3).ok();
            let ref_column: String = fk_row.get(4);
            let delete_rule: String = fk_row.get(5);
            let update_rule: String = fk_row.get(6);

            let fk = foreign_keys.get_or_insert_with(&constraint_name, || ForeignKeyInfo {
                name: constraint_name.clone(),
                columns: Vec::new(),
                referenced_table: ref_table,
                referenced_schema: ref_schema,
                referenced_columns: Vec::new(),
                on_delete: ReferentialAction::from_str(&delete_rule),
                on_update: ReferentialAction::from_str(&update_rule),
            });
            fk.columns.push(column_name);
            fk.referenced_columns.push(ref_column);
        }
        table.foreign_keys = foreign_keys.into_values().collect();

        // Indexes: one row per indexed column; group by index name.
        let idx_sql =
            queries::indexes_query(DatabaseType::PostgreSQL, table_name, Some(schema_name));
        let mut indexes = OrderedGroups::new();
        for idx_row in query_rows(client, &idx_sql, "indexes").await? {
            let idx_name: String = idx_row.get(0);
            let col_name: String = idx_row.get(1);
            let is_unique: bool = idx_row.get(2);
            let is_primary: bool = idx_row.get(3);
            let idx_type: Option<String> = idx_row.try_get(4).ok();
            let filter: Option<String> = idx_row.try_get(5).ok();

            let idx = indexes.get_or_insert_with(&idx_name, || IndexInfo {
                name: idx_name.clone(),
                columns: Vec::new(),
                is_unique,
                is_primary,
                index_type: idx_type,
                filter,
            });
            // NOTE(review): no sort direction is read from the query result;
            // every index column is recorded as ascending.
            idx.columns.push(IndexColumn {
                name: col_name,
                order: SortOrder::Asc,
                ..Default::default()
            });
        }
        table.indexes = indexes.into_values().collect();

        Ok(table)
    }

    /// Run a parameterless catalog query, mapping any failure to
    /// `CliError::Database("Failed to query <what>: <error>")`.
    async fn query_rows(client: &Client, sql: &str, what: &str) -> CliResult<Vec<Row>> {
        client
            .query(sql, &[])
            .await
            .map_err(|e| CliError::Database(format!("Failed to query {}: {}", what, e)))
    }

    /// Whether `table_name` passes the include (`table_filter`) and exclude
    /// (`exclude_pattern`) filters in `options`.
    fn is_table_selected(table_name: &str, options: &IntrospectionOptions) -> bool {
        if let Some(pattern) = options.table_filter.as_deref() {
            if !matches_pattern(table_name, pattern) {
                return false;
            }
        }
        if let Some(pattern) = options.exclude_pattern.as_deref() {
            if matches_pattern(table_name, pattern) {
                return false;
            }
        }
        true
    }

    /// Groups rows by name and remembers the order in which each name was
    /// first seen.
    ///
    /// Catalog queries return one row per FK column, index column or enum
    /// value, and these are folded into one entry per name. Draining a plain
    /// `HashMap` would emit those entries in a per-process random order, making
    /// the generated schema differ between runs. This keeps the database's row
    /// order instead.
    struct OrderedGroups<T> {
        positions: HashMap<String, usize>,
        items: Vec<T>,
    }

    impl<T> OrderedGroups<T> {
        /// Create an empty grouping.
        fn new() -> Self {
            Self {
                positions: HashMap::new(),
                items: Vec::new(),
            }
        }

        /// Return the entry for `key`, creating it with `make` the first time
        /// `key` is seen.
        fn get_or_insert_with(&mut self, key: &str, make: impl FnOnce() -> T) -> &mut T {
            let index = match self.positions.get(key) {
                Some(&index) => index,
                None => {
                    let index = self.items.len();
                    self.items.push(make());
                    self.positions.insert(key.to_string(), index);
                    index
                }
            };
            &mut self.items[index]
        }

        /// Consume the grouping, yielding entries in first-seen order.
        fn into_values(self) -> std::vec::IntoIter<T> {
            self.items.into_iter()
        }
    }

    /// Glob-style match of `name` against `pattern`.
    ///
    /// `*` matches any (possibly empty) run of characters and may appear
    /// anywhere, any number of times (`user_*`, `*_log`, `*tmp*`, `app_*_v2`).
    /// Every other character matches itself. A pattern without `*` must equal
    /// `name` exactly.
    fn matches_pattern(name: &str, pattern: &str) -> bool {
        let mut segments = pattern.split('*');

        // `split` always yields at least one segment: the literal prefix.
        let first = segments.next().unwrap_or("");
        let Some(mut rest) = name.strip_prefix(first) else {
            return false;
        };

        let mut middle: Vec<&str> = segments.collect();
        let Some(last) = middle.pop() else {
            // No `*` at all: the prefix must have consumed the whole name.
            return rest.is_empty();
        };

        // Greedy leftmost matching of the inner segments is optimal for `*`
        // globs: taking the earliest occurrence leaves the most room after it.
        for segment in middle {
            match rest.find(segment) {
                Some(pos) => rest = &rest[pos + segment.len()..],
                None => return false,
            }
        }

        // The final segment is a literal suffix of what remains.
        rest.ends_with(last)
    }
}
400
401// ============================================================================
402// Output Formatters
403// ============================================================================
404
405/// Generate Prax schema output.
406pub fn format_as_prax(schema: &DatabaseSchema, config: &Config) -> String {
407    let mut output = String::new();
408
409    output.push_str("// Generated by `prax db pull`\n");
410    output.push_str("// Edit this file to customize your schema\n\n");
411
412    output.push_str("datasource db {\n");
413    output.push_str(&format!(
414        "    provider = \"{}\"\n",
415        config.database.provider
416    ));
417    output.push_str("    url      = env(\"DATABASE_URL\")\n");
418    output.push_str("}\n\n");
419
420    output.push_str("generator client {\n");
421    output.push_str("    provider = \"prax-client-rust\"\n");
422    output.push_str("    output   = \"./src/generated\"\n");
423    output.push_str("}\n\n");
424
425    // Use the generate_prax_schema function
426    output.push_str(&generate_prax_schema(schema));
427
428    output
429}
430
431/// Generate JSON output.
432pub fn format_as_json(schema: &DatabaseSchema) -> CliResult<String> {
433    serde_json::to_string_pretty(schema)
434        .map_err(|e| CliError::Config(format!("Failed to serialize schema: {}", e)))
435}
436
437/// Generate SQL DDL output.
438pub fn format_as_sql(schema: &DatabaseSchema, db_type: DatabaseType) -> String {
439    let mut output = String::new();
440
441    output.push_str("-- Generated by `prax db pull`\n");
442    output.push_str(&format!("-- Database: {}\n\n", db_type_name(db_type)));
443
444    // Generate enums (PostgreSQL only)
445    if db_type == DatabaseType::PostgreSQL {
446        for enum_info in &schema.enums {
447            output.push_str(&format!("CREATE TYPE {} AS ENUM (\n", enum_info.name));
448            let values: Vec<String> = enum_info
449                .values
450                .iter()
451                .map(|v| format!("    '{}'", v))
452                .collect();
453            output.push_str(&values.join(",\n"));
454            output.push_str("\n);\n\n");
455        }
456    }
457
458    // Generate tables
459    for table in &schema.tables {
460        output.push_str(&format!(
461            "CREATE TABLE {} (\n",
462            quote_identifier(&table.name, db_type)
463        ));
464
465        let mut col_defs: Vec<String> = Vec::new();
466
467        for col in &table.columns {
468            let mut def = format!(
469                "    {} {}",
470                quote_identifier(&col.name, db_type),
471                col.db_type
472            );
473
474            if !col.nullable {
475                def.push_str(" NOT NULL");
476            }
477
478            if let Some(ref default) = col.default {
479                def.push_str(&format!(" DEFAULT {}", default));
480            }
481
482            col_defs.push(def);
483        }
484
485        // Primary key
486        if !table.primary_key.is_empty() {
487            let pk_cols: Vec<String> = table
488                .primary_key
489                .iter()
490                .map(|c| quote_identifier(c, db_type))
491                .collect();
492            col_defs.push(format!("    PRIMARY KEY ({})", pk_cols.join(", ")));
493        }
494
495        output.push_str(&col_defs.join(",\n"));
496        output.push_str("\n);\n\n");
497
498        // Indexes
499        for idx in &table.indexes {
500            if idx.is_primary {
501                continue;
502            }
503
504            let unique = if idx.is_unique { "UNIQUE " } else { "" };
505            let cols: Vec<String> = idx
506                .columns
507                .iter()
508                .map(|c| quote_identifier(&c.name, db_type))
509                .collect();
510
511            output.push_str(&format!(
512                "CREATE {}INDEX {} ON {} ({});\n",
513                unique,
514                quote_identifier(&idx.name, db_type),
515                quote_identifier(&table.name, db_type),
516                cols.join(", ")
517            ));
518        }
519
520        output.push('\n');
521    }
522
523    output
524}
525
526fn db_type_name(db_type: DatabaseType) -> &'static str {
527    match db_type {
528        DatabaseType::PostgreSQL => "PostgreSQL",
529        DatabaseType::MySQL => "MySQL",
530        DatabaseType::SQLite => "SQLite",
531        DatabaseType::MSSQL => "SQL Server",
532    }
533}
534
535fn quote_identifier(name: &str, db_type: DatabaseType) -> String {
536    match db_type {
537        DatabaseType::PostgreSQL => format!("\"{}\"", name),
538        DatabaseType::MySQL => format!("`{}`", name),
539        DatabaseType::SQLite => format!("\"{}\"", name),
540        DatabaseType::MSSQL => format!("[{}]", name),
541    }
542}