Skip to main content

fraiseql_core/db/postgres/
introspector.rs

//! PostgreSQL database introspection for fact tables.
2use async_trait::async_trait;
3use deadpool_postgres::Pool;
4use tokio_postgres::Row;
5
6use crate::{
7    compiler::fact_table::{DatabaseIntrospector, DatabaseType},
8    error::{FraiseQLError, Result},
9};
10
11/// PostgreSQL introspector for fact table metadata.
12pub struct PostgresIntrospector {
13    pool: Pool,
14}
15
16impl PostgresIntrospector {
17    /// Create new PostgreSQL introspector from connection pool.
18    #[must_use]
19    pub const fn new(pool: Pool) -> Self {
20        Self { pool }
21    }
22}
23
24#[async_trait]
25impl DatabaseIntrospector for PostgresIntrospector {
26    async fn list_fact_tables(&self) -> Result<Vec<String>> {
27        let client = self.pool.get().await.map_err(|e| FraiseQLError::ConnectionPool {
28            message: format!("Failed to acquire connection: {e}"),
29        })?;
30
31        // Query information_schema for tables matching tf_* pattern
32        let query = r"
33            SELECT table_name
34            FROM information_schema.tables
35            WHERE table_schema = 'public'
36              AND table_type = 'BASE TABLE'
37              AND table_name LIKE 'tf_%'
38            ORDER BY table_name
39        ";
40
41        let rows: Vec<Row> =
42            client.query(query, &[]).await.map_err(|e| FraiseQLError::Database {
43                message:   format!("Failed to list fact tables: {e}"),
44                sql_state: e.code().map(|c| c.code().to_string()),
45            })?;
46
47        let tables = rows
48            .into_iter()
49            .map(|row| {
50                let name: String = row.get(0);
51                name
52            })
53            .collect();
54
55        Ok(tables)
56    }
57
58    async fn get_columns(&self, table_name: &str) -> Result<Vec<(String, String, bool)>> {
59        let client = self.pool.get().await.map_err(|e| FraiseQLError::ConnectionPool {
60            message: format!("Failed to acquire connection: {e}"),
61        })?;
62
63        // Query information_schema for column information
64        let query = r"
65            SELECT
66                column_name,
67                data_type,
68                is_nullable = 'YES' as is_nullable
69            FROM information_schema.columns
70            WHERE table_name = $1
71            AND table_schema = 'public'
72            ORDER BY ordinal_position
73        ";
74
75        let rows: Vec<Row> =
76            client.query(query, &[&table_name]).await.map_err(|e| FraiseQLError::Database {
77                message:   format!("Failed to query column information: {e}"),
78                sql_state: e.code().map(|c| c.code().to_string()),
79            })?;
80
81        let columns = rows
82            .into_iter()
83            .map(|row| {
84                let name: String = row.get(0);
85                let data_type: String = row.get(1);
86                let is_nullable: bool = row.get(2);
87                (name, data_type, is_nullable)
88            })
89            .collect();
90
91        Ok(columns)
92    }
93
94    async fn get_indexed_columns(&self, table_name: &str) -> Result<Vec<String>> {
95        let client = self.pool.get().await.map_err(|e| FraiseQLError::ConnectionPool {
96            message: format!("Failed to acquire connection: {e}"),
97        })?;
98
99        // Query pg_indexes for indexed columns
100        let query = r"
101            SELECT DISTINCT
102                a.attname as column_name
103            FROM
104                pg_index i
105                JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
106                JOIN pg_class t ON t.oid = i.indrelid
107                JOIN pg_namespace n ON n.oid = t.relnamespace
108            WHERE
109                t.relname = $1
110                AND n.nspname = 'public'
111                AND a.attnum > 0
112            ORDER BY column_name
113        ";
114
115        let rows: Vec<Row> =
116            client.query(query, &[&table_name]).await.map_err(|e| FraiseQLError::Database {
117                message:   format!("Failed to query index information: {e}"),
118                sql_state: e.code().map(|c| c.code().to_string()),
119            })?;
120
121        let indexed_columns = rows
122            .into_iter()
123            .map(|row| {
124                let name: String = row.get(0);
125                name
126            })
127            .collect();
128
129        Ok(indexed_columns)
130    }
131
132    fn database_type(&self) -> DatabaseType {
133        DatabaseType::PostgreSQL
134    }
135
136    async fn get_sample_jsonb(
137        &self,
138        table_name: &str,
139        column_name: &str,
140    ) -> Result<Option<serde_json::Value>> {
141        let client = self.pool.get().await.map_err(|e| FraiseQLError::ConnectionPool {
142            message: format!("Failed to acquire connection: {e}"),
143        })?;
144
145        // Query for a sample row with non-null JSON data
146        // Use format! for identifiers (safe because we validate table_name pattern)
147        let query = format!(
148            r#"
149            SELECT "{column}"::text
150            FROM "{table}"
151            WHERE "{column}" IS NOT NULL
152            LIMIT 1
153            "#,
154            table = table_name,
155            column = column_name
156        );
157
158        let rows: Vec<Row> =
159            client.query(&query, &[]).await.map_err(|e| FraiseQLError::Database {
160                message:   format!("Failed to query sample JSONB: {e}"),
161                sql_state: e.code().map(|c| c.code().to_string()),
162            })?;
163
164        if rows.is_empty() {
165            return Ok(None);
166        }
167
168        let json_text: Option<String> = rows[0].get(0);
169        if let Some(text) = json_text {
170            let value: serde_json::Value =
171                serde_json::from_str(&text).map_err(|e| FraiseQLError::Parse {
172                    message:  format!("Failed to parse JSONB sample: {e}"),
173                    location: format!("{table_name}.{column_name}"),
174                })?;
175            Ok(Some(value))
176        } else {
177            Ok(None)
178        }
179    }
180}
181
182impl PostgresIntrospector {
183    /// Get indexed columns for a view/table that match the nested path naming convention.
184    ///
185    /// This method introspects the database to find columns that follow the FraiseQL
186    /// indexed column naming conventions:
187    /// - Human-readable: `items__product__category__code` (double underscore separated)
188    /// - Entity ID format: `f{entity_id}__{field_name}` (e.g., `f200100__code`)
189    ///
190    /// These columns are created by DBAs to optimize filtering on nested GraphQL paths
191    /// by avoiding JSONB extraction at runtime.
192    ///
193    /// # Arguments
194    ///
195    /// * `view_name` - Name of the view or table to introspect
196    ///
197    /// # Returns
198    ///
199    /// Set of column names that match the indexed column naming conventions.
200    ///
201    /// # Example
202    ///
203    /// ```rust,ignore
204    /// let introspector = PostgresIntrospector::new(pool);
205    /// let indexed_cols = introspector.get_indexed_nested_columns("v_order_items").await?;
206    /// // Returns: {"items__product__category__code", "f200100__code", ...}
207    /// ```
208    pub async fn get_indexed_nested_columns(
209        &self,
210        view_name: &str,
211    ) -> Result<std::collections::HashSet<String>> {
212        let client = self.pool.get().await.map_err(|e| FraiseQLError::ConnectionPool {
213            message: format!("Failed to acquire connection: {e}"),
214        })?;
215
216        // Query information_schema for columns matching __ pattern
217        // This works for both views and tables
218        let query = r"
219            SELECT column_name
220            FROM information_schema.columns
221            WHERE table_name = $1
222              AND table_schema = 'public'
223              AND column_name LIKE '%__%'
224            ORDER BY column_name
225        ";
226
227        let rows: Vec<Row> =
228            client.query(query, &[&view_name]).await.map_err(|e| FraiseQLError::Database {
229                message:   format!("Failed to query view columns: {e}"),
230                sql_state: e.code().map(|c| c.code().to_string()),
231            })?;
232
233        let indexed_columns: std::collections::HashSet<String> = rows
234            .into_iter()
235            .map(|row| {
236                let name: String = row.get(0);
237                name
238            })
239            .filter(|name| {
240                // Filter to only columns that match our naming conventions:
241                // 1. Human-readable: path__to__field (at least one __ separator)
242                // 2. Entity ID: f{digits}__field_name
243                Self::is_indexed_column_name(name)
244            })
245            .collect();
246
247        Ok(indexed_columns)
248    }
249
250    /// Check if a column name matches the indexed column naming convention.
251    ///
252    /// Valid patterns:
253    /// - `items__product__category__code` (human-readable nested path)
254    /// - `f200100__code` (entity ID format)
255    fn is_indexed_column_name(name: &str) -> bool {
256        // Must contain at least one double underscore
257        if !name.contains("__") {
258            return false;
259        }
260
261        // Check for entity ID format: f{digits}__field
262        if let Some(rest) = name.strip_prefix('f') {
263            if let Some(underscore_pos) = rest.find("__") {
264                let digits = &rest[..underscore_pos];
265                if digits.chars().all(|c| c.is_ascii_digit()) && !digits.is_empty() {
266                    // Verify the field part is valid
267                    let field_part = &rest[underscore_pos + 2..];
268                    if !field_part.is_empty()
269                        && field_part.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
270                        && !field_part.starts_with(|c: char| c.is_ascii_digit())
271                    {
272                        return true;
273                    }
274                }
275            }
276        }
277
278        // Human-readable format: split by __ and check each segment is valid identifier
279        // Must have at least 2 segments, and first segment must NOT be just 'f'
280        let segments: Vec<&str> = name.split("__").collect();
281        if segments.len() < 2 {
282            return false;
283        }
284
285        // Reject if first segment is just 'f' (reserved for entity ID format)
286        if segments[0] == "f" {
287            return false;
288        }
289
290        // Each segment should be a valid identifier (alphanumeric + underscore, not starting with
291        // digit)
292        segments.iter().all(|s| {
293            !s.is_empty()
294                && s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
295                && !s.starts_with(|c: char| c.is_ascii_digit())
296        })
297    }
298
299    /// Get all column names for a view/table.
300    ///
301    /// # Arguments
302    ///
303    /// * `view_name` - Name of the view or table to introspect
304    ///
305    /// # Returns
306    ///
307    /// List of all column names in the view/table.
308    pub async fn get_view_columns(&self, view_name: &str) -> Result<Vec<String>> {
309        let client = self.pool.get().await.map_err(|e| FraiseQLError::ConnectionPool {
310            message: format!("Failed to acquire connection: {e}"),
311        })?;
312
313        let query = r"
314            SELECT column_name
315            FROM information_schema.columns
316            WHERE table_name = $1
317              AND table_schema = 'public'
318            ORDER BY ordinal_position
319        ";
320
321        let rows: Vec<Row> =
322            client.query(query, &[&view_name]).await.map_err(|e| FraiseQLError::Database {
323                message:   format!("Failed to query view columns: {e}"),
324                sql_state: e.code().map(|c| c.code().to_string()),
325            })?;
326
327        let columns = rows
328            .into_iter()
329            .map(|row| {
330                let name: String = row.get(0);
331                name
332            })
333            .collect();
334
335        Ok(columns)
336    }
337}
338
/// Unit tests that don't require a PostgreSQL connection.
#[cfg(test)]
mod unit_tests {
    use super::*;

    /// Shorthand for the predicate under test.
    fn matches_convention(name: &str) -> bool {
        PostgresIntrospector::is_indexed_column_name(name)
    }

    #[test]
    fn test_is_indexed_column_name_human_readable() {
        // Valid human-readable patterns
        let accepted = [
            "items__product",
            "items__product__category",
            "items__product__category__code",
            "order_items__product_name",
        ];
        for name in &accepted {
            assert!(matches_convention(name));
        }

        // Invalid patterns
        let rejected = [
            "items",
            "items_product", // single underscore
            "__items",       // empty first segment
            "items__",       // empty last segment
        ];
        for name in &rejected {
            assert!(!matches_convention(name));
        }
    }

    #[test]
    fn test_is_indexed_column_name_entity_id() {
        // Valid entity ID patterns
        for name in &["f200100__code", "f1__name", "f123456789__field"] {
            assert!(matches_convention(name));
        }

        // No digits after 'f', and a bare 'f' segment is reserved, so this matches
        // neither the entity ID form nor the human-readable form.
        assert!(!matches_convention("f__code"));

        // `fx123` is not an entity ID, but it is a valid identifier, so the name is
        // accepted as a human-readable path.
        assert!(matches_convention("fx123__code"));
    }
}
373
/// Integration tests that require a PostgreSQL connection.
#[cfg(all(test, feature = "test-postgres"))]
mod integration_tests {
    use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
    use tokio_postgres::NoTls;

    use super::*;
    use crate::db::postgres::PostgresAdapter;

    const TEST_DB_URL: &str =
        "postgresql://fraiseql_test:fraiseql_test_password@localhost:5433/test_fraiseql";

    /// Build an introspector backed by a fresh pool pointing at the test database.
    async fn create_test_introspector() -> PostgresIntrospector {
        // The adapter is constructed but not used further: it does not expose its
        // pool here, so a separate pool is created below.
        let _adapter =
            PostgresAdapter::new(TEST_DB_URL).await.expect("Failed to create test adapter");

        let mut pool_config = Config::new();
        pool_config.url = Some(TEST_DB_URL.to_string());
        pool_config.manager = Some(ManagerConfig {
            recycling_method: RecyclingMethod::Fast,
        });
        pool_config.pool = Some(deadpool_postgres::PoolConfig::new(10));

        let pool = pool_config
            .create_pool(Some(Runtime::Tokio1), NoTls)
            .expect("Failed to create pool");

        PostgresIntrospector::new(pool)
    }

    #[tokio::test]
    async fn test_get_columns_tf_sales() {
        let introspector = create_test_introspector().await;

        let columns = introspector.get_columns("tf_sales").await.expect("Failed to get columns");

        // Expected columns: id, revenue, quantity, cost, discount, data, customer_id,
        // product_id, occurred_at, created_at
        assert!(columns.len() >= 10, "Expected at least 10 columns, got {}", columns.len());

        // Spot-check the key columns by name
        for expected in &["revenue", "quantity", "data", "customer_id"] {
            assert!(columns.iter().any(|(name, _, _)| name.as_str() == *expected));
        }
    }

    #[tokio::test]
    async fn test_get_indexed_columns_tf_sales() {
        let introspector = create_test_introspector().await;

        let indexed = introspector
            .get_indexed_columns("tf_sales")
            .await
            .expect("Failed to get indexed columns");

        // Expected indexes: id (PK), customer_id, product_id, occurred_at, data (GIN)
        assert!(indexed.len() >= 4, "Expected at least 4 indexed columns, got {}", indexed.len());

        for expected in &["customer_id", "product_id", "occurred_at"] {
            assert!(indexed.iter().any(|name| name.as_str() == *expected));
        }
    }

    #[tokio::test]
    async fn test_database_type() {
        let introspector = create_test_introspector().await;
        assert_eq!(introspector.database_type(), DatabaseType::PostgreSQL);
    }
}
446}