prax_cli/commands/
introspect.rs

//! Database introspection implementation.
//!
//! This module provides the actual database introspection functionality
//! using the `prax-query` introspection types.
5
6use std::collections::HashMap;
7
8use prax_query::introspection::{
9    ColumnInfo, DatabaseSchema, EnumInfo, ForeignKeyInfo, IndexColumn, IndexInfo,
10    ReferentialAction, SortOrder, TableInfo, ViewInfo,
11    generate_prax_schema, normalize_type, queries,
12};
13use prax_query::sql::DatabaseType;
14
15use crate::config::Config;
16use crate::error::{CliError, CliResult};
17
/// Settings that control what an introspection run collects.
#[derive(Debug, Clone)]
pub struct IntrospectionOptions {
    /// Schema (namespace) to inspect. When `None`, the introspector picks
    /// its own fallback (the PostgreSQL introspector uses `public`).
    pub schema: Option<String>,
    /// Whether regular (non-materialized) views are collected.
    pub include_views: bool,
    /// Whether materialized views are collected.
    pub include_materialized_views: bool,
    /// Glob-style pattern a table name must match to be kept.
    pub table_filter: Option<String>,
    /// Glob-style pattern; tables whose name matches it are skipped.
    pub exclude_pattern: Option<String>,
    /// Whether table and column comments are kept in the result.
    pub include_comments: bool,
    /// Sample size for MongoDB introspection.
    pub sample_size: usize,
}

impl Default for IntrospectionOptions {
    /// Defaults: no schema override, no filters, views and materialized
    /// views excluded, comments included, sample size of 100.
    fn default() -> Self {
        const DEFAULT_SAMPLE_SIZE: usize = 100;

        IntrospectionOptions {
            // Switched on by default.
            include_comments: true,
            sample_size: DEFAULT_SAMPLE_SIZE,
            // Everything else starts out unset / disabled.
            schema: None,
            table_filter: None,
            exclude_pattern: None,
            include_views: false,
            include_materialized_views: false,
        }
    }
}
50
/// Database introspector trait.
///
/// Implemented by the per-database introspectors in this module (for
/// example the PostgreSQL one behind the `postgres` feature).
// `async fn` in a public trait trips the `async_fn_in_trait` lint because
// users of the trait cannot add auto-trait bounds such as `Send` to the
// returned future. NOTE(review): presumably acceptable because the trait is
// only consumed inside this CLI — confirm before exposing it more widely.
#[allow(async_fn_in_trait)]
pub trait Introspector {
    /// Introspect the database and return schema information.
    ///
    /// `options` selects the schema, the table filters, and whether views
    /// and comments are included. Failures are reported as a `CliError`
    /// through `CliResult`.
    async fn introspect(&self, options: &IntrospectionOptions) -> CliResult<DatabaseSchema>;
}
57
58/// Get the database type from provider string.
59pub fn get_database_type(provider: &str) -> CliResult<DatabaseType> {
60    match provider.to_lowercase().as_str() {
61        "postgresql" | "postgres" | "pg" => Ok(DatabaseType::PostgreSQL),
62        "mysql" | "mariadb" => Ok(DatabaseType::MySQL),
63        "sqlite" | "sqlite3" => Ok(DatabaseType::SQLite),
64        "mssql" | "sqlserver" | "sql_server" => Ok(DatabaseType::MSSQL),
65        _ => Err(CliError::Config(format!(
66            "Unsupported database provider: {}",
67            provider
68        ))),
69    }
70}
71
72/// Get default schema for database type.
73pub fn default_schema(db_type: DatabaseType) -> &'static str {
74    match db_type {
75        DatabaseType::PostgreSQL => "public",
76        DatabaseType::MySQL => "",
77        DatabaseType::SQLite => "",
78        DatabaseType::MSSQL => "dbo",
79    }
80}
81
82// ============================================================================
83// PostgreSQL Introspector
84// ============================================================================
85
86#[cfg(feature = "postgres")]
87pub mod postgres {
88    use super::*;
89    use tokio_postgres::{Client, NoTls};
90
91    /// PostgreSQL introspector.
92    pub struct PostgresIntrospector {
93        connection_string: String,
94    }
95
96    impl PostgresIntrospector {
97        /// Create a new PostgreSQL introspector.
98        pub fn new(connection_string: String) -> Self {
99            Self { connection_string }
100        }
101
102        /// Connect to the database.
103        async fn connect(&self) -> CliResult<Client> {
104            let (client, connection) = tokio_postgres::connect(&self.connection_string, NoTls)
105                .await
106                .map_err(|e| CliError::Database(format!("Failed to connect: {}", e)))?;
107
108            // Spawn the connection handler
109            tokio::spawn(async move {
110                if let Err(e) = connection.await {
111                    eprintln!("Connection error: {}", e);
112                }
113            });
114
115            Ok(client)
116        }
117    }
118
119    impl Introspector for PostgresIntrospector {
120        async fn introspect(&self, options: &IntrospectionOptions) -> CliResult<DatabaseSchema> {
121            let client = self.connect().await?;
122            let schema_name = options.schema.as_deref().unwrap_or("public");
123
124            let mut db_schema = DatabaseSchema {
125                name: "database".to_string(),
126                schema: Some(schema_name.to_string()),
127                ..Default::default()
128            };
129
130            // Get tables
131            let tables_sql = queries::tables_query(DatabaseType::PostgreSQL, Some(schema_name));
132            let table_rows = client
133                .query(&tables_sql, &[])
134                .await
135                .map_err(|e| CliError::Database(format!("Failed to query tables: {}", e)))?;
136
137            for row in table_rows {
138                let table_name: String = row.get(0);
139
140                // Apply filters
141                if let Some(ref pattern) = options.table_filter {
142                    if !matches_pattern(&table_name, pattern) {
143                        continue;
144                    }
145                }
146                if let Some(ref exclude) = options.exclude_pattern {
147                    if matches_pattern(&table_name, exclude) {
148                        continue;
149                    }
150                }
151
152                let comment: Option<String> = row.try_get(1).ok();
153
154                let mut table = TableInfo {
155                    name: table_name.clone(),
156                    schema: Some(schema_name.to_string()),
157                    comment: if options.include_comments { comment } else { None },
158                    ..Default::default()
159                };
160
161                // Get columns
162                let cols_sql = queries::columns_query(DatabaseType::PostgreSQL, &table_name, Some(schema_name));
163                let col_rows = client
164                    .query(&cols_sql, &[])
165                    .await
166                    .map_err(|e| CliError::Database(format!("Failed to query columns: {}", e)))?;
167
168                for col_row in col_rows {
169                    let col_name: String = col_row.get(0);
170                    let data_type: String = col_row.get(1);
171                    let udt_name: String = col_row.get(2);
172                    let nullable: bool = col_row.get(3);
173                    let default: Option<String> = col_row.try_get(4).ok();
174                    let max_length: Option<i32> = col_row.try_get(5).ok();
175                    let precision: Option<i32> = col_row.try_get(6).ok();
176                    let scale: Option<i32> = col_row.try_get(7).ok();
177                    let comment: Option<String> = col_row.try_get(8).ok();
178                    let auto_increment: bool = col_row.try_get(9).unwrap_or(false);
179
180                    let normalized = normalize_type(
181                        DatabaseType::PostgreSQL,
182                        &udt_name,
183                        max_length,
184                        precision,
185                        scale,
186                    );
187
188                    table.columns.push(ColumnInfo {
189                        name: col_name,
190                        db_type: data_type,
191                        normalized_type: normalized,
192                        nullable,
193                        default,
194                        auto_increment,
195                        max_length,
196                        precision,
197                        scale,
198                        comment: if options.include_comments { comment } else { None },
199                        ..Default::default()
200                    });
201                }
202
203                // Get primary keys
204                let pk_sql = queries::primary_keys_query(DatabaseType::PostgreSQL, &table_name, Some(schema_name));
205                let pk_rows = client
206                    .query(&pk_sql, &[])
207                    .await
208                    .map_err(|e| CliError::Database(format!("Failed to query primary keys: {}", e)))?;
209
210                for pk_row in pk_rows {
211                    let col_name: String = pk_row.get(0);
212                    table.primary_key.push(col_name.clone());
213
214                    // Mark column as primary key
215                    if let Some(col) = table.columns.iter_mut().find(|c| c.name == col_name) {
216                        col.is_primary_key = true;
217                    }
218                }
219
220                // Get foreign keys
221                let fk_sql = queries::foreign_keys_query(DatabaseType::PostgreSQL, &table_name, Some(schema_name));
222                let fk_rows = client
223                    .query(&fk_sql, &[])
224                    .await
225                    .map_err(|e| CliError::Database(format!("Failed to query foreign keys: {}", e)))?;
226
227                let mut fk_map: HashMap<String, ForeignKeyInfo> = HashMap::new();
228                for fk_row in fk_rows {
229                    let constraint_name: String = fk_row.get(0);
230                    let column_name: String = fk_row.get(1);
231                    let ref_table: String = fk_row.get(2);
232                    let ref_schema: Option<String> = fk_row.try_get(3).ok();
233                    let ref_column: String = fk_row.get(4);
234                    let delete_rule: String = fk_row.get(5);
235                    let update_rule: String = fk_row.get(6);
236
237                    let fk = fk_map.entry(constraint_name.clone()).or_insert_with(|| {
238                        ForeignKeyInfo {
239                            name: constraint_name,
240                            columns: Vec::new(),
241                            referenced_table: ref_table,
242                            referenced_schema: ref_schema,
243                            referenced_columns: Vec::new(),
244                            on_delete: ReferentialAction::from_str(&delete_rule),
245                            on_update: ReferentialAction::from_str(&update_rule),
246                        }
247                    });
248
249                    fk.columns.push(column_name);
250                    fk.referenced_columns.push(ref_column);
251                }
252
253                table.foreign_keys = fk_map.into_values().collect();
254
255                // Get indexes
256                let idx_sql = queries::indexes_query(DatabaseType::PostgreSQL, &table_name, Some(schema_name));
257                let idx_rows = client
258                    .query(&idx_sql, &[])
259                    .await
260                    .map_err(|e| CliError::Database(format!("Failed to query indexes: {}", e)))?;
261
262                let mut idx_map: HashMap<String, IndexInfo> = HashMap::new();
263                for idx_row in idx_rows {
264                    let idx_name: String = idx_row.get(0);
265                    let col_name: String = idx_row.get(1);
266                    let is_unique: bool = idx_row.get(2);
267                    let is_primary: bool = idx_row.get(3);
268                    let idx_type: Option<String> = idx_row.try_get(4).ok();
269                    let filter: Option<String> = idx_row.try_get(5).ok();
270
271                    let idx = idx_map.entry(idx_name.clone()).or_insert_with(|| {
272                        IndexInfo {
273                            name: idx_name,
274                            columns: Vec::new(),
275                            is_unique,
276                            is_primary,
277                            index_type: idx_type,
278                            filter,
279                        }
280                    });
281
282                    idx.columns.push(IndexColumn {
283                        name: col_name,
284                        order: SortOrder::Asc,
285                        ..Default::default()
286                    });
287                }
288
289                table.indexes = idx_map.into_values().collect();
290
291                db_schema.tables.push(table);
292            }
293
294            // Get enums
295            let enums_sql = queries::enums_query(Some(schema_name));
296            let enum_rows = client
297                .query(&enums_sql, &[])
298                .await
299                .map_err(|e| CliError::Database(format!("Failed to query enums: {}", e)))?;
300
301            let mut enum_map: HashMap<String, EnumInfo> = HashMap::new();
302            for enum_row in enum_rows {
303                let enum_name: String = enum_row.get(0);
304                let enum_value: String = enum_row.get(1);
305
306                let enum_info = enum_map.entry(enum_name.clone()).or_insert_with(|| {
307                    EnumInfo {
308                        name: enum_name,
309                        schema: Some(schema_name.to_string()),
310                        values: Vec::new(),
311                    }
312                });
313
314                enum_info.values.push(enum_value);
315            }
316
317            db_schema.enums = enum_map.into_values().collect();
318
319            // Get views
320            if options.include_views || options.include_materialized_views {
321                let views_sql = queries::views_query(DatabaseType::PostgreSQL, Some(schema_name));
322                let view_rows = client
323                    .query(&views_sql, &[])
324                    .await
325                    .map_err(|e| CliError::Database(format!("Failed to query views: {}", e)))?;
326
327                for view_row in view_rows {
328                    let view_name: String = view_row.get(0);
329                    let definition: Option<String> = view_row.try_get(1).ok();
330                    let is_materialized: bool = view_row.get(2);
331
332                    if is_materialized && !options.include_materialized_views {
333                        continue;
334                    }
335                    if !is_materialized && !options.include_views {
336                        continue;
337                    }
338
339                    db_schema.views.push(ViewInfo {
340                        name: view_name,
341                        schema: Some(schema_name.to_string()),
342                        definition,
343                        is_materialized,
344                        columns: Vec::new(),
345                    });
346                }
347            }
348
349            Ok(db_schema)
350        }
351    }
352
353    /// Simple glob-style pattern matching.
354    fn matches_pattern(name: &str, pattern: &str) -> bool {
355        if pattern == "*" {
356            return true;
357        }
358
359        if pattern.starts_with('*') && pattern.ends_with('*') {
360            let middle = &pattern[1..pattern.len() - 1];
361            return name.contains(middle);
362        }
363
364        if pattern.starts_with('*') {
365            let suffix = &pattern[1..];
366            return name.ends_with(suffix);
367        }
368
369        if pattern.ends_with('*') {
370            let prefix = &pattern[..pattern.len() - 1];
371            return name.starts_with(prefix);
372        }
373
374        name == pattern
375    }
376}
377
378// ============================================================================
379// Output Formatters
380// ============================================================================
381
382/// Generate Prax schema output.
383pub fn format_as_prax(schema: &DatabaseSchema, config: &Config) -> String {
384    let mut output = String::new();
385
386    output.push_str("// Generated by `prax db pull`\n");
387    output.push_str("// Edit this file to customize your schema\n\n");
388
389    output.push_str("datasource db {\n");
390    output.push_str(&format!("    provider = \"{}\"\n", config.database.provider));
391    output.push_str("    url      = env(\"DATABASE_URL\")\n");
392    output.push_str("}\n\n");
393
394    output.push_str("generator client {\n");
395    output.push_str("    provider = \"prax-client-rust\"\n");
396    output.push_str("    output   = \"./src/generated\"\n");
397    output.push_str("}\n\n");
398
399    // Use the generate_prax_schema function
400    output.push_str(&generate_prax_schema(schema));
401
402    output
403}
404
405/// Generate JSON output.
406pub fn format_as_json(schema: &DatabaseSchema) -> CliResult<String> {
407    serde_json::to_string_pretty(schema)
408        .map_err(|e| CliError::Config(format!("Failed to serialize schema: {}", e)))
409}
410
411/// Generate SQL DDL output.
412pub fn format_as_sql(schema: &DatabaseSchema, db_type: DatabaseType) -> String {
413    let mut output = String::new();
414
415    output.push_str("-- Generated by `prax db pull`\n");
416    output.push_str(&format!("-- Database: {}\n\n", db_type_name(db_type)));
417
418    // Generate enums (PostgreSQL only)
419    if db_type == DatabaseType::PostgreSQL {
420        for enum_info in &schema.enums {
421            output.push_str(&format!("CREATE TYPE {} AS ENUM (\n", enum_info.name));
422            let values: Vec<String> = enum_info.values.iter().map(|v| format!("    '{}'", v)).collect();
423            output.push_str(&values.join(",\n"));
424            output.push_str("\n);\n\n");
425        }
426    }
427
428    // Generate tables
429    for table in &schema.tables {
430        output.push_str(&format!("CREATE TABLE {} (\n", quote_identifier(&table.name, db_type)));
431
432        let mut col_defs: Vec<String> = Vec::new();
433
434        for col in &table.columns {
435            let mut def = format!(
436                "    {} {}",
437                quote_identifier(&col.name, db_type),
438                col.db_type
439            );
440
441            if !col.nullable {
442                def.push_str(" NOT NULL");
443            }
444
445            if let Some(ref default) = col.default {
446                def.push_str(&format!(" DEFAULT {}", default));
447            }
448
449            col_defs.push(def);
450        }
451
452        // Primary key
453        if !table.primary_key.is_empty() {
454            let pk_cols: Vec<String> = table
455                .primary_key
456                .iter()
457                .map(|c| quote_identifier(c, db_type))
458                .collect();
459            col_defs.push(format!("    PRIMARY KEY ({})", pk_cols.join(", ")));
460        }
461
462        output.push_str(&col_defs.join(",\n"));
463        output.push_str("\n);\n\n");
464
465        // Indexes
466        for idx in &table.indexes {
467            if idx.is_primary {
468                continue;
469            }
470
471            let unique = if idx.is_unique { "UNIQUE " } else { "" };
472            let cols: Vec<String> = idx.columns.iter().map(|c| quote_identifier(&c.name, db_type)).collect();
473
474            output.push_str(&format!(
475                "CREATE {}INDEX {} ON {} ({});\n",
476                unique,
477                quote_identifier(&idx.name, db_type),
478                quote_identifier(&table.name, db_type),
479                cols.join(", ")
480            ));
481        }
482
483        output.push('\n');
484    }
485
486    output
487}
488
489fn db_type_name(db_type: DatabaseType) -> &'static str {
490    match db_type {
491        DatabaseType::PostgreSQL => "PostgreSQL",
492        DatabaseType::MySQL => "MySQL",
493        DatabaseType::SQLite => "SQLite",
494        DatabaseType::MSSQL => "SQL Server",
495    }
496}
497
498fn quote_identifier(name: &str, db_type: DatabaseType) -> String {
499    match db_type {
500        DatabaseType::PostgreSQL => format!("\"{}\"", name),
501        DatabaseType::MySQL => format!("`{}`", name),
502        DatabaseType::SQLite => format!("\"{}\"", name),
503        DatabaseType::MSSQL => format!("[{}]", name),
504    }
505}
506