1use crate::{
5 filters::ReplicationFilter,
6 migration, postgres,
7 table_rules::{QualifiedTable, TableRules},
8};
9use anyhow::{Context, Result};
10use inquire::{Confirm, MultiSelect, Select, Text};
11
/// Position of the interactive wizard in its step sequence.
///
/// The `usize` payload of the per-database steps is an index into the list of
/// *selected* databases (not into the list of all databases on the source).
enum WizardStep {
    /// Step 1: choose which databases to replicate.
    SelectDatabases,
    /// Step 2: choose the tables to include for the selected database at this index.
    SelectTablesForDb(usize),
    /// Step 3: choose which included tables are replicated without data.
    SelectSchemaOnlyForDb(usize),
    /// Step 4: optionally attach time filters to the remaining tables.
    ConfigureTimeFiltersForDb(usize),
    /// Step 5: show a summary and ask for confirmation.
    Review,
}
20
21struct CachedDbTables {
23 all_tables: Vec<migration::TableInfo>,
24 table_display_names: Vec<String>,
25}
26
27pub async fn select_databases_and_tables(
65 source_url: &str,
66) -> Result<(ReplicationFilter, TableRules)> {
67 tracing::info!("Starting interactive database and table selection...");
68 println!();
69
70 tracing::info!("Connecting to source database...");
72 let source_client = postgres::connect_with_retry(source_url)
73 .await
74 .context("Failed to connect to source database")?;
75 tracing::info!("✓ Connected to source");
76 println!();
77
78 tracing::info!("Discovering databases on source...");
80 let all_databases = migration::list_databases(&source_client)
81 .await
82 .context("Failed to list databases on source")?;
83
84 if all_databases.is_empty() {
85 tracing::warn!("⚠ No user databases found on source");
86 tracing::warn!(" Source appears to contain only template databases");
87 return Ok((ReplicationFilter::empty(), TableRules::default()));
88 }
89
90 tracing::info!("✓ Found {} database(s)", all_databases.len());
91 println!();
92
93 let db_names: Vec<String> = all_databases.iter().map(|db| db.name.clone()).collect();
94
95 let mut selected_db_indices: Vec<usize> = Vec::new();
97 let mut current_step = WizardStep::SelectDatabases;
98
99 let mut included_tables_by_db: std::collections::HashMap<String, Vec<String>> =
101 std::collections::HashMap::new();
102 let mut schema_only_by_db: std::collections::HashMap<String, Vec<(String, String)>> =
103 std::collections::HashMap::new(); let mut time_filters_by_db: std::collections::HashMap<
105 String,
106 Vec<(String, String, String, String)>,
107 > = std::collections::HashMap::new(); let mut table_cache: std::collections::HashMap<String, CachedDbTables> =
111 std::collections::HashMap::new();
112
113 loop {
114 match current_step {
115 WizardStep::SelectDatabases => {
116 print_header("Step 1 of 5: Select Databases");
117 println!("Navigation: Space to toggle, Enter to confirm, Esc to cancel");
118 println!();
119
120 let defaults: Vec<usize> = selected_db_indices.clone();
121
122 let selections =
123 MultiSelect::new("Select databases to replicate:", db_names.clone())
124 .with_default(&defaults)
125 .with_help_message("↑↓ navigate, Space toggle, Enter confirm")
126 .prompt();
127
128 match selections {
129 Ok(selected) => {
130 selected_db_indices = selected
132 .iter()
133 .filter_map(|name| db_names.iter().position(|n| n == name))
134 .collect();
135
136 if selected_db_indices.is_empty() {
137 println!();
138 println!("⚠ Please select at least one database");
139 continue;
140 }
141
142 included_tables_by_db.clear();
144 schema_only_by_db.clear();
145 time_filters_by_db.clear();
146 table_cache.clear();
147
148 current_step = WizardStep::SelectTablesForDb(0);
149 }
150 Err(inquire::InquireError::OperationCanceled) => {
151 anyhow::bail!("Operation cancelled by user");
152 }
153 Err(inquire::InquireError::OperationInterrupted) => {
154 anyhow::bail!("Operation interrupted");
155 }
156 Err(e) => return Err(e.into()),
157 }
158 }
159
160 WizardStep::SelectTablesForDb(db_idx) => {
161 let db_name = &db_names[selected_db_indices[db_idx]].clone();
162 print_header(&format!(
163 "Step 2 of 5: Select Tables to Include ({}/{})",
164 db_idx + 1,
165 selected_db_indices.len()
166 ));
167 println!("Database: {}", db_name);
168 println!("Press Enter without selecting to include ALL tables.");
169 println!("Navigation: Space to toggle, Enter to continue, Esc to go back");
170 println!();
171
172 let cached = get_or_cache_tables(&mut table_cache, source_url, db_name).await?;
174
175 if cached.all_tables.is_empty() {
176 println!(" No tables found in database '{}'", db_name);
177 if db_idx + 1 < selected_db_indices.len() {
179 current_step = WizardStep::SelectTablesForDb(db_idx + 1);
180 } else {
181 current_step = WizardStep::SelectSchemaOnlyForDb(0);
182 }
183 continue;
184 }
185
186 let previous_inclusions: Vec<usize> = included_tables_by_db
188 .get(db_name)
189 .map(|included| {
190 included
191 .iter()
192 .filter_map(|t| {
193 let stripped =
195 t.strip_prefix(&format!("{}.", db_name)).unwrap_or(t);
196 cached
197 .table_display_names
198 .iter()
199 .position(|n| n == stripped)
200 })
201 .collect()
202 })
203 .unwrap_or_default();
204
205 let selections = MultiSelect::new(
206 "Select tables to INCLUDE (Enter = include all):",
207 cached.table_display_names.clone(),
208 )
209 .with_default(&previous_inclusions)
210 .with_help_message("Space toggle, Enter confirm, Esc go back")
211 .prompt();
212
213 match selections {
214 Ok(selected_inclusions) => {
215 let db_inclusions: Vec<String> = if selected_inclusions.is_empty() {
217 cached
218 .table_display_names
219 .iter()
220 .map(|table_name| format!("{}.{}", db_name, table_name))
221 .collect()
222 } else {
223 selected_inclusions
224 .iter()
225 .map(|table_name| format!("{}.{}", db_name, table_name))
226 .collect()
227 };
228
229 included_tables_by_db.insert(db_name.clone(), db_inclusions);
231
232 if db_idx + 1 < selected_db_indices.len() {
234 current_step = WizardStep::SelectTablesForDb(db_idx + 1);
235 } else {
236 current_step = WizardStep::SelectSchemaOnlyForDb(0);
237 }
238 }
239 Err(inquire::InquireError::OperationCanceled) => {
240 if db_idx > 0 {
242 current_step = WizardStep::SelectTablesForDb(db_idx - 1);
243 } else {
244 current_step = WizardStep::SelectDatabases;
245 }
246 }
247 Err(inquire::InquireError::OperationInterrupted) => {
248 anyhow::bail!("Operation interrupted");
249 }
250 Err(e) => return Err(e.into()),
251 }
252 }
253
254 WizardStep::SelectSchemaOnlyForDb(db_idx) => {
255 let db_name = &db_names[selected_db_indices[db_idx]].clone();
256 print_header(&format!(
257 "Step 3 of 5: Schema-Only Tables ({}/{})",
258 db_idx + 1,
259 selected_db_indices.len()
260 ));
261 println!("Database: {}", db_name);
262 println!("Schema-only tables replicate structure but NO data.");
263 println!("Navigation: Space to toggle, Enter to continue, Esc to go back");
264 println!();
265
266 let cached = get_or_cache_tables(&mut table_cache, source_url, db_name).await?;
267
268 if cached.all_tables.is_empty() {
269 if db_idx + 1 < selected_db_indices.len() {
271 current_step = WizardStep::SelectSchemaOnlyForDb(db_idx + 1);
272 } else {
273 current_step = WizardStep::ConfigureTimeFiltersForDb(0);
274 }
275 continue;
276 }
277
278 let included = included_tables_by_db.get(db_name);
280 let available_tables: Vec<(usize, String)> = cached
281 .table_display_names
282 .iter()
283 .enumerate()
284 .filter(|(_, name)| {
285 let full_name = format!("{}.{}", db_name, name);
286 included.is_some_and(|inc| inc.contains(&full_name))
287 })
288 .map(|(idx, name)| (idx, name.clone()))
289 .collect();
290
291 if available_tables.is_empty() {
292 println!(" No tables included from '{}'", db_name);
293 if db_idx + 1 < selected_db_indices.len() {
294 current_step = WizardStep::SelectSchemaOnlyForDb(db_idx + 1);
295 } else {
296 current_step = WizardStep::ConfigureTimeFiltersForDb(0);
297 }
298 continue;
299 }
300
301 let available_names: Vec<String> =
302 available_tables.iter().map(|(_, n)| n.clone()).collect();
303
304 let previous_schema_only: Vec<usize> = schema_only_by_db
306 .get(db_name)
307 .map(|selected| {
308 selected
309 .iter()
310 .filter_map(|(schema, table)| {
311 let display = if schema == "public" {
312 table.clone()
313 } else {
314 format!("{}.{}", schema, table)
315 };
316 available_names.iter().position(|n| n == &display)
317 })
318 .collect()
319 })
320 .unwrap_or_default();
321
322 let selections = MultiSelect::new(
323 "Select tables to replicate SCHEMA-ONLY (no data):",
324 available_names.clone(),
325 )
326 .with_default(&previous_schema_only)
327 .with_help_message("Space toggle, Enter confirm, Esc go back")
328 .prompt();
329
330 match selections {
331 Ok(selected_schema_only) => {
332 let schema_only_tables: Vec<(String, String)> = selected_schema_only
334 .iter()
335 .filter_map(|display_name| {
336 available_tables
337 .iter()
338 .find(|(_, n)| n == display_name)
339 .map(|(idx, _)| {
340 let t = &cached.all_tables[*idx];
341 (t.schema.clone(), t.name.clone())
342 })
343 })
344 .collect();
345
346 schema_only_by_db.insert(db_name.clone(), schema_only_tables);
347
348 if db_idx + 1 < selected_db_indices.len() {
349 current_step = WizardStep::SelectSchemaOnlyForDb(db_idx + 1);
350 } else {
351 current_step = WizardStep::ConfigureTimeFiltersForDb(0);
352 }
353 }
354 Err(inquire::InquireError::OperationCanceled) => {
355 if db_idx > 0 {
357 current_step = WizardStep::SelectSchemaOnlyForDb(db_idx - 1);
358 } else {
359 let last_db = selected_db_indices.len().saturating_sub(1);
360 current_step = WizardStep::SelectTablesForDb(last_db);
361 }
362 }
363 Err(inquire::InquireError::OperationInterrupted) => {
364 anyhow::bail!("Operation interrupted");
365 }
366 Err(e) => return Err(e.into()),
367 }
368 }
369
370 WizardStep::ConfigureTimeFiltersForDb(db_idx) => {
371 let db_name = &db_names[selected_db_indices[db_idx]].clone();
372 print_header(&format!(
373 "Step 4 of 5: Time Filters ({}/{})",
374 db_idx + 1,
375 selected_db_indices.len()
376 ));
377 println!("Database: {}", db_name);
378 println!("Time filters limit data to recent records (e.g., last 90 days).");
379 println!();
380
381 let cached = get_or_cache_tables(&mut table_cache, source_url, db_name).await?;
382
383 if cached.all_tables.is_empty() {
384 if db_idx + 1 < selected_db_indices.len() {
385 current_step = WizardStep::ConfigureTimeFiltersForDb(db_idx + 1);
386 } else {
387 current_step = WizardStep::Review;
388 }
389 continue;
390 }
391
392 let included = included_tables_by_db.get(db_name);
394 let schema_only = schema_only_by_db.get(db_name);
395 let available_tables: Vec<(usize, String)> = cached
396 .table_display_names
397 .iter()
398 .enumerate()
399 .filter(|(idx, name)| {
400 let full_name = format!("{}.{}", db_name, name);
401 let is_included = included.is_some_and(|inc| inc.contains(&full_name));
402 let t = &cached.all_tables[*idx];
403 let is_schema_only = schema_only.is_some_and(|so| {
404 so.iter().any(|(s, n)| s == &t.schema && n == &t.name)
405 });
406 is_included && !is_schema_only
407 })
408 .map(|(idx, name)| (idx, name.clone()))
409 .collect();
410
411 if available_tables.is_empty() {
412 println!(" No tables available for time filtering in '{}'", db_name);
413 if db_idx + 1 < selected_db_indices.len() {
414 current_step = WizardStep::ConfigureTimeFiltersForDb(db_idx + 1);
415 } else {
416 current_step = WizardStep::Review;
417 }
418 continue;
419 }
420
421 let configure = Confirm::new("Configure time-based filters for this database?")
423 .with_default(false)
424 .with_help_message("Enter to confirm, Esc to go back")
425 .prompt();
426
427 match configure {
428 Ok(true) => {
429 let available_names: Vec<String> =
431 available_tables.iter().map(|(_, n)| n.clone()).collect();
432
433 let table_selections = MultiSelect::new(
434 "Select tables to apply time filter:",
435 available_names.clone(),
436 )
437 .with_help_message("Space toggle, Enter confirm")
438 .prompt();
439
440 match table_selections {
441 Ok(selected_tables) => {
442 let mut time_filters: Vec<(String, String, String, String)> =
443 Vec::new();
444
445 for display_name in &selected_tables {
446 if let Some((idx, _)) =
447 available_tables.iter().find(|(_, n)| n == display_name)
448 {
449 let t = &cached.all_tables[*idx];
450 let db_url = replace_database_in_url(source_url, db_name)?;
451 let db_client = postgres::connect_with_retry(&db_url)
452 .await
453 .context("Failed to connect for column query")?;
454
455 let columns = migration::get_table_columns(
457 &db_client, &t.schema, &t.name,
458 )
459 .await?;
460
461 let timestamp_columns: Vec<String> = columns
462 .iter()
463 .filter(|c| c.is_timestamp)
464 .map(|c| format!("{} ({})", c.name, c.data_type))
465 .collect();
466
467 println!();
468 println!("Configure time filter for '{}':", display_name);
469
470 let column = if timestamp_columns.is_empty() {
471 println!(
472 " ⚠ No timestamp columns found. Enter column name manually."
473 );
474 Text::new(" Column name:")
475 .with_default("created_at")
476 .prompt()
477 .context("Failed to get column name")?
478 } else {
479 let mut options = timestamp_columns.clone();
480 options.push("[Enter custom column name]".to_string());
481
482 let selection =
483 Select::new(" Select timestamp column:", options)
484 .prompt()
485 .context("Failed to select column")?;
486
487 if selection == "[Enter custom column name]" {
488 Text::new(" Column name:")
489 .prompt()
490 .context("Failed to get column name")?
491 } else {
492 selection
494 .split(" (")
495 .next()
496 .unwrap_or(&selection)
497 .to_string()
498 }
499 };
500
501 let window = Text::new(
502 " Time window (e.g., '90 days', '6 months', '1 year'):",
503 )
504 .with_default("90 days")
505 .prompt()
506 .context("Failed to get time window")?;
507
508 time_filters.push((
509 t.schema.clone(),
510 t.name.clone(),
511 column,
512 window,
513 ));
514 }
515 }
516
517 time_filters_by_db.insert(db_name.clone(), time_filters);
518 }
519 Err(inquire::InquireError::OperationCanceled) => {
520 continue;
522 }
523 Err(inquire::InquireError::OperationInterrupted) => {
524 anyhow::bail!("Operation interrupted");
525 }
526 Err(e) => return Err(e.into()),
527 }
528
529 if db_idx + 1 < selected_db_indices.len() {
530 current_step = WizardStep::ConfigureTimeFiltersForDb(db_idx + 1);
531 } else {
532 current_step = WizardStep::Review;
533 }
534 }
535 Ok(false) => {
536 if db_idx + 1 < selected_db_indices.len() {
538 current_step = WizardStep::ConfigureTimeFiltersForDb(db_idx + 1);
539 } else {
540 current_step = WizardStep::Review;
541 }
542 }
543 Err(inquire::InquireError::OperationCanceled) => {
544 if db_idx > 0 {
546 current_step = WizardStep::ConfigureTimeFiltersForDb(db_idx - 1);
547 } else {
548 let last_db = selected_db_indices.len().saturating_sub(1);
549 current_step = WizardStep::SelectSchemaOnlyForDb(last_db);
550 }
551 }
552 Err(inquire::InquireError::OperationInterrupted) => {
553 anyhow::bail!("Operation interrupted");
554 }
555 Err(e) => return Err(e.into()),
556 }
557 }
558
559 WizardStep::Review => {
560 print_header("Step 5 of 5: Review Configuration");
561
562 let included_tables: Vec<String> =
564 included_tables_by_db.values().flatten().cloned().collect();
565
566 let selected_databases: Vec<String> = selected_db_indices
567 .iter()
568 .map(|&i| db_names[i].clone())
569 .collect();
570
571 println!();
572 println!("Databases to replicate: {}", selected_databases.len());
573 for db in &selected_databases {
574 println!(" ✓ {}", db);
575 }
576 println!();
577
578 println!("Tables to replicate: {}", included_tables.len());
579 if included_tables.len() <= 20 {
580 for table in &included_tables {
581 println!(" ✓ {}", table);
582 }
583 } else {
584 for table in included_tables.iter().take(10) {
586 println!(" ✓ {}", table);
587 }
588 println!(" ... ({} more tables)", included_tables.len() - 15);
589 for table in included_tables.iter().skip(included_tables.len() - 5) {
590 println!(" ✓ {}", table);
591 }
592 }
593 println!();
594
595 let schema_only_count: usize = schema_only_by_db.values().map(|v| v.len()).sum();
597 if schema_only_count > 0 {
598 println!("Schema-only tables (no data): {}", schema_only_count);
599 for (db, tables) in &schema_only_by_db {
600 for (schema, table) in tables {
601 let display = if schema == "public" {
602 format!("{}.{}", db, table)
603 } else {
604 format!("{}.{}.{}", db, schema, table)
605 };
606 println!(" ◇ {}", display);
607 }
608 }
609 println!();
610 } else {
611 println!("Schema-only tables: none");
612 println!();
613 }
614
615 let time_filter_count: usize = time_filters_by_db.values().map(|v| v.len()).sum();
617 if time_filter_count > 0 {
618 println!("Time-filtered tables: {}", time_filter_count);
619 for (db, filters) in &time_filters_by_db {
620 for (schema, table, column, window) in filters {
621 let display = if schema == "public" {
622 format!("{}.{}", db, table)
623 } else {
624 format!("{}.{}.{}", db, schema, table)
625 };
626 println!(" ⏱ {} ({} >= last {})", display, column, window);
627 }
628 }
629 println!();
630 } else {
631 println!("Time filters: none");
632 println!();
633 }
634
635 println!("───────────────────────────────────────────────────────────────");
636 println!();
637
638 let confirmed = Confirm::new("Proceed with this configuration?")
639 .with_default(true)
640 .with_help_message("Enter confirm, Esc go back")
641 .prompt();
642
643 match confirmed {
644 Ok(true) => break, Ok(false) | Err(inquire::InquireError::OperationCanceled) => {
646 let last_db = selected_db_indices.len().saturating_sub(1);
648 current_step = WizardStep::ConfigureTimeFiltersForDb(last_db);
649 }
650 Err(inquire::InquireError::OperationInterrupted) => {
651 anyhow::bail!("Operation interrupted");
652 }
653 Err(e) => return Err(e.into()),
654 }
655 }
656 }
657 }
658
659 let selected_databases: Vec<String> = selected_db_indices
661 .iter()
662 .map(|&i| db_names[i].clone())
663 .collect();
664
665 let included_tables: Vec<String> = included_tables_by_db.values().flatten().cloned().collect();
666
667 tracing::info!("");
668 tracing::info!("✓ Configuration confirmed");
669 tracing::info!("");
670
671 let filter = if included_tables.is_empty() {
673 ReplicationFilter::new(Some(selected_databases), None, None, None)?
674 } else {
675 ReplicationFilter::new(Some(selected_databases), None, Some(included_tables), None)?
676 };
677
678 let mut table_rules = TableRules::default();
680
681 for (db, tables) in &schema_only_by_db {
683 for (schema, table) in tables {
684 let qualified = QualifiedTable::new(Some(db.clone()), schema.clone(), table.clone());
685 table_rules.add_schema_only_table(qualified)?;
686 }
687 }
688
689 for (db, filters) in &time_filters_by_db {
691 for (schema, table, column, window) in filters {
692 let qualified = QualifiedTable::new(Some(db.clone()), schema.clone(), table.clone());
693 table_rules.add_time_filter(qualified, column.clone(), window.clone())?;
694 }
695 }
696
697 Ok((filter, table_rules))
698}
699
700async fn get_or_cache_tables<'a>(
702 cache: &'a mut std::collections::HashMap<String, CachedDbTables>,
703 source_url: &str,
704 db_name: &str,
705) -> Result<&'a CachedDbTables> {
706 if !cache.contains_key(db_name) {
707 let db_url = replace_database_in_url(source_url, db_name)?;
708 let db_client = postgres::connect_with_retry(&db_url)
709 .await
710 .context(format!("Failed to connect to database '{}'", db_name))?;
711
712 let all_tables = migration::list_tables(&db_client)
713 .await
714 .context(format!("Failed to list tables from database '{}'", db_name))?;
715
716 let table_display_names: Vec<String> = all_tables
717 .iter()
718 .map(|t| {
719 if t.schema == "public" {
720 t.name.clone()
721 } else {
722 format!("{}.{}", t.schema, t.name)
723 }
724 })
725 .collect();
726
727 cache.insert(
728 db_name.to_string(),
729 CachedDbTables {
730 all_tables,
731 table_display_names,
732 },
733 );
734 }
735
736 Ok(cache.get(db_name).unwrap())
737}
738
/// Prints `title` inside a double-line box, with a blank line before and after.
///
/// The box is 62 columns wide between its corners. The title row is a
/// two-space indent plus the title left-aligned and padded to 60 columns, so
/// its right border lines up with the corners of the top and bottom rules.
/// Titles longer than 60 characters are not truncated and push the right
/// border outwards.
fn print_header(title: &str) {
    /// Number of columns between the left and right border characters.
    const INNER_WIDTH: usize = 62;
    /// Leading spaces before the title inside the box.
    const TITLE_INDENT: usize = 2;

    let rule = "═".repeat(INNER_WIDTH);

    println!();
    println!("╔{}╗", rule);
    println!(
        "║{:indent$}{:<width$}║",
        "",
        title,
        indent = TITLE_INDENT,
        width = INNER_WIDTH - TITLE_INDENT
    );
    println!("╚{}╝", rule);
    println!();
}
747
748fn replace_database_in_url(url: &str, new_db_name: &str) -> Result<String> {
759 let parts: Vec<&str> = url.splitn(2, '?').collect();
761 let base_url = parts[0];
762 let query_params = parts.get(1);
763
764 let url_parts: Vec<&str> = base_url.rsplitn(2, '/').collect();
766
767 if url_parts.len() != 2 {
768 anyhow::bail!("Invalid connection URL format: cannot replace database name");
769 }
770
771 let new_url = if let Some(params) = query_params {
773 format!("{}/{}?{}", url_parts[1], new_db_name, params)
774 } else {
775 format!("{}/{}", url_parts[1], new_db_name)
776 };
777
778 Ok(new_url)
779}
780
#[cfg(test)]
mod tests {
    use super::*;

    /// The database component is swapped while host, port, credentials and
    /// query string are left untouched.
    #[test]
    fn test_replace_database_in_url() {
        // (input URL, expected URL after switching to "newdb")
        let cases = [
            (
                "postgresql://user:pass@localhost:5432/olddb",
                "postgresql://user:pass@localhost:5432/newdb",
            ),
            (
                "postgresql://user:pass@localhost:5432/olddb?sslmode=require",
                "postgresql://user:pass@localhost:5432/newdb?sslmode=require",
            ),
            (
                "postgresql://user:pass@localhost/olddb",
                "postgresql://user:pass@localhost/newdb",
            ),
        ];

        for (input, expected) in cases {
            let rewritten = replace_database_in_url(input, "newdb").unwrap();
            assert_eq!(rewritten, expected);
        }
    }

    /// Manual smoke test: needs a terminal and a reachable server given by
    /// `TEST_SOURCE_URL`, hence ignored by default.
    #[tokio::test]
    #[ignore]
    async fn test_interactive_selection() {
        let source_url = std::env::var("TEST_SOURCE_URL").unwrap();

        let outcome = select_databases_and_tables(&source_url).await;

        if let Err(e) = &outcome {
            println!("Interactive selection error: {:?}", e);
        } else if let Ok((filter, rules)) = &outcome {
            println!("✓ Interactive selection completed");
            println!("Filter: {:?}", filter);
            println!("Rules: {:?}", rules);
        }
    }
}