1use anyhow::{anyhow, Context, Result};
6use flowscope_core::{ColumnSchema, SchemaMetadata, SchemaTable};
7use sqlx::{any::AnyPoolOptions, AnyPool, Row};
8use std::sync::Once;
9use std::time::Duration;
10
/// Upper bound on pooled connections; passed to `AnyPoolOptions::max_connections`
/// when the pool is built in `SqlxMetadataProvider::connect`.
const MAX_CONNECTIONS: u32 = 2;

/// Seconds to wait when acquiring a connection from the pool; converted to a
/// `Duration` for `AnyPoolOptions::acquire_timeout`.
const ACQUIRE_TIMEOUT_SECS: u64 = 10;

/// Maximum identifier length this module works with: the MySQL query truncates
/// identifier columns to this many characters via `LEFT(..)`, and SQLite table
/// names longer than this many bytes are rejected by
/// `validate_sqlite_table_name`.
const IDENTIFIER_SAFE_LENGTH: usize = 255;

/// Guards the call to `sqlx::any::install_default_drivers` so that repeated
/// `connect` calls register the drivers only once.
static INSTALL_DRIVERS: Once = Once::new();
32
/// Returns the text that precedes the first `://` in `url`.
///
/// When `url` contains no `://` separator the whole input is returned
/// unchanged, so callers must not assume the result is free of sensitive data.
fn url_scheme(url: &str) -> &str {
    match url.split_once("://") {
        Some((scheme, _)) => scheme,
        None => url,
    }
}
37
/// Renders `url` in a form that is safe to embed in logs and error messages.
///
/// * `sqlite:` URLs are reduced to `sqlite://<path>` or `sqlite:<path>`; the
///   filesystem path is never echoed.
/// * For any other `scheme://...` URL, everything up to the **last** `@` is
///   treated as credentials and replaced with `<redacted>`. The last `@` is
///   used on purpose so that passwords containing unescaped `@` characters are
///   removed in full.
/// * Values of `user` / `password` query parameters are replaced with
///   `<redacted>`; other query parameters are preserved for debugging.
/// * Input without a `://` separator (and not starting with `sqlite:`) is
///   returned unchanged.
fn redact_url(url: &str) -> String {
    // SQLite is handled before the credential check: a database path such as
    // `/home/me@work/app.db` contains '@' but no userinfo, and treating it as
    // credentials would leak the tail of the path.
    if url.starts_with("sqlite:") {
        let masked = if url.starts_with("sqlite://") {
            "sqlite://<path>"
        } else {
            "sqlite:<path>"
        };
        return masked.to_string();
    }

    let (scheme, rest) = match url.split_once("://") {
        Some(parts) => parts,
        None => return url.to_string(),
    };

    match rest.rfind('@') {
        Some(at_pos) => format!(
            "{}://<redacted>@{}",
            scheme,
            redact_query_secrets(&rest[at_pos + 1..])
        ),
        None => format!("{}://{}", scheme, redact_query_secrets(rest)),
    }
}

/// Masks the values of credential-bearing query parameters (`user`,
/// `password`, matched case-insensitively) in the host/path/query part of a URL.
fn redact_query_secrets(tail: &str) -> String {
    let (location, query) = match tail.split_once('?') {
        Some(parts) => parts,
        None => return tail.to_string(),
    };

    let masked: Vec<String> = query
        .split('&')
        .map(|pair| match pair.split_once('=') {
            Some((key, _))
                if key.eq_ignore_ascii_case("password") || key.eq_ignore_ascii_case("user") =>
            {
                format!("{}=<redacted>", key)
            }
            _ => pair.to_string(),
        })
        .collect();

    format!("{}?{}", location, masked.join("&"))
}
61
/// Database backends this module knows how to introspect.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DatabaseType {
    /// Selected by `postgres://` and `postgresql://` URLs.
    Postgres,
    /// Selected by `mysql://` and `mariadb://` URLs.
    Mysql,
    /// Selected by any URL starting with `sqlite:` (with or without `//`).
    Sqlite,
}

impl DatabaseType {
    /// Determines the backend from the scheme of a connection URL.
    ///
    /// Matching is case-sensitive. Postgres and MySQL/MariaDB require the
    /// `scheme://` form, while SQLite accepts any `sqlite:` prefix (for example
    /// `sqlite::memory:`). Returns `None` for everything else.
    pub fn from_url(url: &str) -> Option<Self> {
        let (scheme, rest) = url.split_once(':')?;
        let has_authority_marker = rest.starts_with("//");

        match scheme {
            "postgres" | "postgresql" if has_authority_marker => Some(Self::Postgres),
            "mysql" | "mariadb" if has_authority_marker => Some(Self::Mysql),
            "sqlite" => Some(Self::Sqlite),
            _ => None,
        }
    }
}
84
/// Reads table and column metadata from a live database through an sqlx
/// `AnyPool`.
pub struct SqlxMetadataProvider {
    /// Connection pool every metadata query is executed against.
    pool: AnyPool,
    /// Backend detected from the connection URL; selects the fetch routine.
    db_type: DatabaseType,
    /// Optional schema to restrict the lookup to. When absent, the Postgres
    /// query uses `"public"` and the MySQL query uses `DATABASE()`; the SQLite
    /// fetch routine does not consult it. When present it is also reported as
    /// the default schema.
    schema_filter: Option<String>,
}
92
93impl SqlxMetadataProvider {
94 pub async fn connect(url: &str, schema_filter: Option<String>) -> Result<Self> {
103 let db_type = DatabaseType::from_url(url)
104 .ok_or_else(|| anyhow!("Unsupported database URL scheme: {}", url_scheme(url)))?;
105
106 INSTALL_DRIVERS.call_once(sqlx::any::install_default_drivers);
108
109 let pool = AnyPoolOptions::new()
112 .max_connections(MAX_CONNECTIONS)
113 .acquire_timeout(Duration::from_secs(ACQUIRE_TIMEOUT_SECS))
114 .connect(url)
115 .await
116 .with_context(|| format!("Failed to connect to database: {}", redact_url(url)))?;
117
118 Ok(Self {
119 pool,
120 db_type,
121 schema_filter,
122 })
123 }
124
125 pub async fn fetch_schema_async(&self) -> Result<SchemaMetadata> {
127 let tables = match self.db_type {
128 DatabaseType::Postgres => self.fetch_postgres_schema().await?,
129 DatabaseType::Mysql => self.fetch_mysql_schema().await?,
130 DatabaseType::Sqlite => self.fetch_sqlite_schema().await?,
131 };
132
133 let default_schema = self.resolve_default_schema().await?;
134
135 Ok(SchemaMetadata {
136 default_catalog: None,
137 default_schema,
138 search_path: None,
139 case_sensitivity: None,
140 tables,
141 allow_implied: false,
142 })
143 }
144
145 async fn fetch_postgres_schema(&self) -> Result<Vec<SchemaTable>> {
147 let schema_filter = self.schema_filter.as_deref().unwrap_or("public");
148
149 let query = r#"
151 SELECT
152 c.table_schema::text AS table_schema,
153 c.table_name::text AS table_name,
154 c.column_name::text AS column_name,
155 c.data_type::text AS data_type,
156 CASE WHEN pk.column_name IS NOT NULL THEN true ELSE false END AS is_primary_key
157 FROM information_schema.columns c
158 LEFT JOIN (
159 SELECT kcu.table_schema, kcu.table_name, kcu.column_name
160 FROM information_schema.table_constraints tc
161 JOIN information_schema.key_column_usage kcu
162 ON tc.constraint_name = kcu.constraint_name
163 AND tc.table_schema = kcu.table_schema
164 WHERE tc.constraint_type = 'PRIMARY KEY'
165 ) pk ON c.table_schema = pk.table_schema
166 AND c.table_name = pk.table_name
167 AND c.column_name = pk.column_name
168 WHERE c.table_schema = $1
169 ORDER BY c.table_schema, c.table_name, c.ordinal_position
170 "#;
171
172 let rows = sqlx::query(query)
173 .bind(schema_filter)
174 .fetch_all(&self.pool)
175 .await?;
176
177 self.rows_to_tables(rows)
178 }
179
180 async fn fetch_mysql_schema(&self) -> Result<Vec<SchemaTable>> {
182 let limit = IDENTIFIER_SAFE_LENGTH;
187 let query = if self.schema_filter.is_some() {
188 format!(
189 r#"
190 SELECT
191 LEFT(TABLE_SCHEMA, {limit}) as table_schema,
192 LEFT(TABLE_NAME, {limit}) as table_name,
193 LEFT(COLUMN_NAME, {limit}) as column_name,
194 LEFT(DATA_TYPE, {limit}) as data_type,
195 CASE WHEN COLUMN_KEY = 'PRI' THEN 1 ELSE 0 END AS is_primary_key
196 FROM information_schema.COLUMNS
197 WHERE TABLE_SCHEMA = ?
198 ORDER BY TABLE_SCHEMA, TABLE_NAME, ORDINAL_POSITION
199 "#
200 )
201 } else {
202 format!(
203 r#"
204 SELECT
205 LEFT(TABLE_SCHEMA, {limit}) as table_schema,
206 LEFT(TABLE_NAME, {limit}) as table_name,
207 LEFT(COLUMN_NAME, {limit}) as column_name,
208 LEFT(DATA_TYPE, {limit}) as data_type,
209 CASE WHEN COLUMN_KEY = 'PRI' THEN 1 ELSE 0 END AS is_primary_key
210 FROM information_schema.COLUMNS
211 WHERE TABLE_SCHEMA = DATABASE()
212 ORDER BY TABLE_SCHEMA, TABLE_NAME, ORDINAL_POSITION
213 "#
214 )
215 };
216
217 let rows = if let Some(ref schema) = self.schema_filter {
218 sqlx::query(&query)
219 .bind(schema)
220 .fetch_all(&self.pool)
221 .await?
222 } else {
223 sqlx::query(&query).fetch_all(&self.pool).await?
224 };
225
226 self.rows_to_tables(rows)
227 }
228
229 fn validate_sqlite_table_name(name: &str) -> Result<()> {
242 if name.is_empty() || name.len() > IDENTIFIER_SAFE_LENGTH {
243 return Err(anyhow!("Invalid table name length: {}", name.len()));
244 }
245 if !name
247 .chars()
248 .all(|c| c.is_alphanumeric() || c == '_' || c == '.')
249 {
250 return Err(anyhow!("Table name contains invalid characters: {}", name));
251 }
252 Ok(())
253 }
254
255 async fn fetch_sqlite_schema(&self) -> Result<Vec<SchemaTable>> {
257 let tables_query = r#"
259 SELECT name FROM sqlite_master
260 WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
261 ORDER BY name
262 "#;
263
264 let table_rows = sqlx::query(tables_query).fetch_all(&self.pool).await?;
265
266 let mut tables = Vec::new();
267
268 for table_row in table_rows {
269 let table_name: String = table_row.get("name");
270
271 if let Err(err) = Self::validate_sqlite_table_name(&table_name) {
273 eprintln!(
274 "flowscope: warning: Skipping SQLite table '{table_name}' due to unsupported identifier characters: {err}"
275 );
276 continue;
277 }
278
279 let columns_query = format!("PRAGMA table_info('{}')", table_name.replace('\'', "''"));
287
288 let column_rows = sqlx::query(&columns_query).fetch_all(&self.pool).await?;
289
290 let columns: Vec<ColumnSchema> = column_rows
291 .iter()
292 .map(|row| {
293 let name: String = row.get("name");
294 let data_type: String = row.get("type");
295 let pk: i32 = row.get("pk");
296
297 ColumnSchema {
298 name,
299 data_type: if data_type.is_empty() {
300 None
301 } else {
302 Some(data_type)
303 },
304 is_primary_key: if pk > 0 { Some(true) } else { None },
305 foreign_key: None,
306 }
307 })
308 .collect();
309
310 tables.push(SchemaTable {
311 catalog: None,
312 schema: None, name: table_name,
314 columns,
315 });
316 }
317
318 Ok(tables)
319 }
320
321 async fn resolve_default_schema(&self) -> Result<Option<String>> {
323 if let Some(schema) = &self.schema_filter {
324 return Ok(Some(schema.clone()));
325 }
326
327 match self.db_type {
328 DatabaseType::Mysql => self.fetch_mysql_default_schema().await,
329 _ => Ok(None),
330 }
331 }
332
333 async fn fetch_mysql_default_schema(&self) -> Result<Option<String>> {
335 let schema: Option<String> = sqlx::query_scalar("SELECT DATABASE()")
336 .fetch_one(&self.pool)
337 .await?;
338
339 Ok(schema)
340 }
341
342 fn rows_to_tables(&self, rows: Vec<sqlx::any::AnyRow>) -> Result<Vec<SchemaTable>> {
345 use std::collections::HashMap;
346
347 let mut table_map: HashMap<(String, String), Vec<ColumnSchema>> = HashMap::new();
349
350 for row in rows {
351 let table_schema: String = row.get("table_schema");
352 let table_name: String = row.get("table_name");
353 let column_name: String = row.get("column_name");
354 let data_type: String = row.get("data_type");
355
356 let is_primary_key = self.get_primary_key_from_row(&row);
358
359 let column = ColumnSchema {
360 name: column_name,
361 data_type: Some(data_type),
362 is_primary_key: if is_primary_key { Some(true) } else { None },
363 foreign_key: None,
364 };
365
366 table_map
367 .entry((table_schema, table_name))
368 .or_default()
369 .push(column);
370 }
371
372 let mut tables: Vec<SchemaTable> = table_map
374 .into_iter()
375 .map(|((schema, name), columns)| SchemaTable {
376 catalog: None,
377 schema: Some(schema),
378 name,
379 columns,
380 })
381 .collect();
382
383 tables.sort_by(|a, b| {
385 let schema_cmp = a.schema.cmp(&b.schema);
386 if schema_cmp == std::cmp::Ordering::Equal {
387 a.name.cmp(&b.name)
388 } else {
389 schema_cmp
390 }
391 });
392
393 Ok(tables)
394 }
395
396 fn get_primary_key_from_row(&self, row: &sqlx::any::AnyRow) -> bool {
398 if let Ok(val) = row.try_get::<bool, _>("is_primary_key") {
400 return val;
401 }
402 if let Ok(val) = row.try_get::<i32, _>("is_primary_key") {
403 return val != 0;
404 }
405 if let Ok(val) = row.try_get::<i64, _>("is_primary_key") {
406 return val != 0;
407 }
408 false
409 }
410}
411
412pub fn fetch_metadata_from_database(
423 url: &str,
424 schema_filter: Option<String>,
425) -> Result<SchemaMetadata> {
426 let rt = tokio::runtime::Runtime::new().context("Failed to create async runtime")?;
427 rt.block_on(async {
428 let provider = SqlxMetadataProvider::connect(url, schema_filter).await?;
429 provider.fetch_schema_async().await
430 })
431}
432
// Unit tests for URL classification, URL redaction and SQLite table-name
// validation, plus one async test that attempts a real connection.
#[cfg(test)]
mod tests {
    use super::*;

    // Each supported scheme maps to its backend; an unknown scheme yields None.
    #[test]
    fn test_database_type_from_url() {
        assert_eq!(
            DatabaseType::from_url("postgres://localhost/db"),
            Some(DatabaseType::Postgres)
        );
        assert_eq!(
            DatabaseType::from_url("postgresql://localhost/db"),
            Some(DatabaseType::Postgres)
        );
        assert_eq!(
            DatabaseType::from_url("mysql://localhost/db"),
            Some(DatabaseType::Mysql)
        );
        assert_eq!(
            DatabaseType::from_url("mariadb://localhost/db"),
            Some(DatabaseType::Mysql)
        );
        assert_eq!(
            DatabaseType::from_url("sqlite://path/to/db"),
            Some(DatabaseType::Sqlite)
        );
        // SQLite is also accepted without the `//` part.
        assert_eq!(
            DatabaseType::from_url("sqlite::memory:"),
            Some(DatabaseType::Sqlite)
        );
        assert_eq!(DatabaseType::from_url("unknown://localhost/db"), None);
    }

    // Credentials before the host are replaced by `<redacted>`.
    #[test]
    fn test_redact_url_with_credentials() {
        assert_eq!(
            redact_url("postgres://user:password@localhost:5432/mydb"),
            "postgres://<redacted>@localhost:5432/mydb"
        );

        // Password containing unescaped special characters, including '@'.
        assert_eq!(
            redact_url("mysql://admin:s3cr3t!@#$@db.example.com/prod"),
            "mysql://<redacted>@db.example.com/prod"
        );
    }

    // A URL without credentials passes through unchanged.
    #[test]
    fn test_redact_url_without_credentials() {
        assert_eq!(
            redact_url("postgres://localhost:5432/mydb"),
            "postgres://localhost:5432/mydb"
        );
    }

    // SQLite paths are never echoed, in either URL form.
    #[test]
    fn test_redact_url_sqlite() {
        assert_eq!(
            redact_url("sqlite:///path/to/secret/database.db"),
            "sqlite://<path>"
        );

        assert_eq!(redact_url("sqlite::memory:"), "sqlite:<path>");
        assert_eq!(redact_url("sqlite:path/to/db"), "sqlite:<path>");
    }

    // Input without a scheme separator is returned as-is.
    #[test]
    fn test_redact_url_invalid() {
        assert_eq!(redact_url("not-a-url"), "not-a-url");
        assert_eq!(redact_url("unknown"), "unknown");
    }

    // The scheme is the text before `://`; without a separator the whole input
    // comes back.
    #[test]
    fn test_url_scheme() {
        assert_eq!(url_scheme("postgres://localhost/db"), "postgres");
        assert_eq!(url_scheme("mysql://localhost/db"), "mysql");
        assert_eq!(url_scheme("sqlite://path"), "sqlite");
        assert_eq!(url_scheme("not-a-url"), "not-a-url");
    }

    // Plain alphanumeric names are accepted, including a leading digit.
    #[test]
    fn test_validate_sqlite_table_name_valid_simple() {
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("users").is_ok());
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("Users").is_ok());
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("USERS").is_ok());
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("users123").is_ok());
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("123users").is_ok());
    }

    // Underscores are allowed in any position.
    #[test]
    fn test_validate_sqlite_table_name_valid_with_underscore() {
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("user_accounts").is_ok());
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("_private").is_ok());
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("table_").is_ok());
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("__double__").is_ok());
    }

    // Dots are allowed, so qualified-looking names pass.
    #[test]
    fn test_validate_sqlite_table_name_valid_with_dot() {
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("main.users").is_ok());
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("schema.table").is_ok());
        assert!(SqlxMetadataProvider::validate_sqlite_table_name("db.schema.table").is_ok());
    }

    // The empty name fails the length check.
    #[test]
    fn test_validate_sqlite_table_name_rejects_empty() {
        let result = SqlxMetadataProvider::validate_sqlite_table_name("");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("length"));
    }

    // 256 bytes is one past the limit; exactly 255 bytes is still accepted.
    #[test]
    fn test_validate_sqlite_table_name_rejects_too_long() {
        let long_name = "a".repeat(256);
        let result = SqlxMetadataProvider::validate_sqlite_table_name(&long_name);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("length"));

        let max_name = "a".repeat(255);
        assert!(SqlxMetadataProvider::validate_sqlite_table_name(&max_name).is_ok());
    }

    // Whitespace is reported as an invalid character.
    #[test]
    fn test_validate_sqlite_table_name_rejects_spaces() {
        let result = SqlxMetadataProvider::validate_sqlite_table_name("user accounts");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("invalid characters"));
    }

    // Single quotes, double quotes and backticks are all rejected.
    #[test]
    fn test_validate_sqlite_table_name_rejects_quotes() {
        let result = SqlxMetadataProvider::validate_sqlite_table_name("users'--");
        assert!(result.is_err());

        let result = SqlxMetadataProvider::validate_sqlite_table_name("users\"table");
        assert!(result.is_err());

        let result = SqlxMetadataProvider::validate_sqlite_table_name("users`table");
        assert!(result.is_err());
    }

    // A statement separator (and the space after it) is rejected.
    #[test]
    fn test_validate_sqlite_table_name_rejects_semicolon() {
        let result = SqlxMetadataProvider::validate_sqlite_table_name("users;DROP TABLE");
        assert!(result.is_err());
    }

    // Sweep over punctuation and control characters outside the allowed set.
    #[test]
    fn test_validate_sqlite_table_name_rejects_special_chars() {
        let invalid_names = [
            "users@domain",
            "users#tag",
            "users$var",
            "users%percent",
            "users&",
            "users*star",
            "users(paren",
            "users)paren",
            "users+plus",
            "users=equals",
            "users[bracket",
            "users]bracket",
            "users{brace",
            "users}brace",
            "users|pipe",
            "users\\backslash",
            "users/slash",
            "users?question",
            "users<less",
            "users>greater",
            "users,comma",
            "users:colon",
            "users!bang",
            "users~tilde",
            "users\ttab",
            "users\nnewline",
        ];

        for name in invalid_names {
            let result = SqlxMetadataProvider::validate_sqlite_table_name(name);
            assert!(
                result.is_err(),
                "Expected '{}' to be rejected but it was accepted",
                name
            );
        }
    }

    // Pins the identifier limit so an accidental change is noticed.
    #[test]
    fn test_identifier_safe_length_constant() {
        assert_eq!(IDENTIFIER_SAFE_LENGTH, 255);

        // Evaluated at compile time: the limit must stay within [64, 256].
        const _: () = {
            assert!(IDENTIFIER_SAFE_LENGTH <= 256);
            assert!(IDENTIFIER_SAFE_LENGTH >= 64);
        };
    }

    // The redacted form keeps the host but drops user name, password and any
    // SQLite path.
    #[test]
    fn test_error_context_uses_redacted_url() {
        let pg_url = "postgres://admin:super_secret_password@db.example.com:5432/production";
        let redacted = redact_url(pg_url);
        assert!(
            redacted.contains("db.example.com"),
            "Redacted URL should preserve host for debugging"
        );
        assert!(
            !redacted.contains("super_secret_password"),
            "Redacted URL must not contain password"
        );
        assert!(
            !redacted.contains("admin"),
            "Redacted URL should not contain username"
        );

        let mysql_url = "mysql://root:mysql_root_pw@mysql.internal:3306/app_db";
        let redacted = redact_url(mysql_url);
        assert!(redacted.contains("mysql.internal"));
        assert!(!redacted.contains("mysql_root_pw"));
        assert!(!redacted.contains("root"));

        let sqlite_url = "sqlite:///home/user/secrets/private.db";
        let redacted = redact_url(sqlite_url);
        assert!(!redacted.contains("/home/user/secrets"));
        assert!(redacted.contains("sqlite"));
    }

    // Redaction cuts at the last '@', so a password containing '@' is removed
    // in full.
    #[test]
    fn test_redact_url_with_at_sign_in_password() {
        let url = "postgres://user:p@ss@word@localhost/db";
        let redacted = redact_url(url);
        assert_eq!(redacted, "postgres://<redacted>@localhost/db");
        assert!(!redacted.contains("p@ss@word"));
    }

    // Everything after the credentials (port, database) survives redaction.
    #[test]
    fn test_redact_url_preserves_port_and_database() {
        let url = "postgres://user:pass@host:5433/mydb?sslmode=require";
        let redacted = redact_url(url);
        assert!(
            redacted.contains("5433"),
            "Port should be preserved for debugging"
        );
        assert!(
            redacted.contains("mydb"),
            "Database name should be preserved for debugging"
        );
    }

    // Attempts a real connection to a host that is expected not to resolve and
    // inspects the resulting error message.
    // NOTE(review): `e.to_string()` presumably renders only the outermost
    // error context; if the whole cause chain should be checked for leaked
    // credentials, the alternate format (`{:#}`) may be needed — confirm.
    #[tokio::test]
    async fn test_connection_error_includes_redacted_url() {
        let url = "postgres://secret_user:secret_password@nonexistent.invalid:5432/testdb";

        let result = SqlxMetadataProvider::connect(url, None).await;

        let error_message = match result {
            Ok(_) => panic!("Connection to nonexistent host should fail"),
            Err(e) => e.to_string(),
        };

        // The host is kept so the failure can be diagnosed.
        assert!(
            error_message.contains("nonexistent.invalid"),
            "Error should include host for debugging: {}",
            error_message
        );

        // Neither part of the credentials may appear in the message.
        assert!(
            !error_message.contains("secret_user"),
            "Error must not expose username: {}",
            error_message
        );
        assert!(
            !error_message.contains("secret_password"),
            "Error must not expose password: {}",
            error_message
        );

        // Loose check that the message describes a connection problem.
        assert!(
            error_message.contains("Failed to connect")
                || error_message.contains("connect")
                || error_message.contains("database"),
            "Error should indicate connection failure: {}",
            error_message
        );
    }
}