1mod config;
9mod reservoir;
10
11pub use config::{DefaultClassifier, GlobalTableMode, SampleYamlConfig, TableClassification};
12pub use reservoir::Reservoir;
13
14use crate::parser::mysql_insert::{hash_pk_tuple, parse_mysql_insert_rows, ParsedRow, PkHashSet};
15use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
16use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
17use crate::schema::{SchemaBuilder, SchemaGraph, TableId};
18use crate::splitter::Splitter;
19use ahash::AHashMap;
20use indicatif::{ProgressBar, ProgressStyle};
21use rand::rngs::StdRng;
22use rand::{Rng, SeedableRng};
23use std::fs::{self, File};
24use std::io::{BufWriter, Write};
25use std::path::{Path, PathBuf};
26use tempfile::TempDir;
27
/// How many rows to keep from a table.
#[derive(Debug, Clone, Copy)]
pub enum SampleMode {
    /// Keep each row independently with probability `p`/100 (Bernoulli sampling).
    Percent(u32),
    /// Keep at most `n` rows, chosen by reservoir sampling.
    Rows(usize),
}
36
/// Options controlling a sampling run; consumed by [`run`].
#[derive(Debug)]
pub struct SampleConfig {
    /// Path to the SQL dump to sample.
    pub input: PathBuf,
    /// Output path; `None` writes to stdout.
    pub output: Option<PathBuf>,
    /// SQL dialect of the input dump (also used when generating output).
    pub dialect: SqlDialect,
    /// Default sampling rate, unless overridden per table.
    pub mode: SampleMode,
    /// When true, rows whose FK parents were not selected are rejected.
    pub preserve_relations: bool,
    /// Allow-list of table names (case-insensitive); `None` means all tables.
    pub tables_filter: Option<Vec<String>>,
    /// Tables to exclude (case-insensitive).
    pub exclude: Vec<String>,
    /// Tables forced to the `Root` classification.
    pub root_tables: Vec<String>,
    /// How globally-shared (lookup/system) tables are handled.
    pub include_global: GlobalTableMode,
    /// RNG seed; reruns with the same seed select the same rows.
    pub seed: u64,
    /// Compute statistics only; write no output.
    pub dry_run: bool,
    /// Emit progress bars / status lines on stderr.
    pub progress: bool,
    /// Optional YAML file with per-table overrides.
    pub config_file: Option<PathBuf>,
    /// Hard cap on total selected rows across all tables.
    pub max_total_rows: Option<usize>,
    /// Abort the run (instead of just counting) on FK integrity violations.
    pub strict_fk: bool,
    /// Include CREATE/ALTER schema statements in the output.
    pub include_schema: bool,
}
73
74impl Default for SampleConfig {
75 fn default() -> Self {
76 Self {
77 input: PathBuf::new(),
78 output: None,
79 dialect: SqlDialect::MySql,
80 mode: SampleMode::Percent(10),
81 preserve_relations: false,
82 tables_filter: None,
83 exclude: Vec::new(),
84 root_tables: Vec::new(),
85 include_global: GlobalTableMode::Lookups,
86 seed: rand::random(),
87 dry_run: false,
88 progress: false,
89 config_file: None,
90 max_total_rows: None,
91 strict_fk: false,
92 include_schema: true,
93 }
94 }
95}
96
/// Aggregate results of a sampling run, returned by [`run`].
#[derive(Debug, Default)]
pub struct SampleStats {
    /// Number of tables that produced a per-table stats entry.
    pub tables_sampled: usize,
    /// Tables skipped by filters, classification, or global-table mode.
    pub tables_skipped: usize,
    /// Sum of `rows_selected` over all sampled tables.
    pub total_rows_selected: u64,
    /// Sum of `rows_seen` over all sampled tables.
    pub total_rows_seen: u64,
    /// One entry per sampled table, in sampling order.
    pub table_stats: Vec<TableSampleStats>,
    /// Human-readable warnings (FK cycles, row-cap hits, ...).
    pub warnings: Vec<String>,
    /// Rows dropped because a referenced parent row was not selected.
    pub fk_orphans_rejected: u64,
}
115
/// Per-table sampling summary.
#[derive(Debug, Clone)]
pub struct TableSampleStats {
    /// Table name as it appears in the dump.
    pub name: String,
    /// Rows parsed from the table's data statements.
    pub rows_seen: u64,
    /// Rows kept by the sampler.
    pub rows_selected: u64,
    /// Classification that drove the sampling decision.
    pub classification: TableClassification,
}
124
/// Mutable per-table state threaded through the sampling loop.
struct TableRuntime {
    /// Table name (also used to derive per-table temp-file names).
    name: String,
    /// Hashes of selected PK tuples; child tables check FK values against
    /// this set when `preserve_relations` is enabled.
    pk_set: PkHashSet,
    /// Rows parsed for this table.
    rows_seen: u64,
    /// Rows selected for this table.
    rows_selected: u64,
    /// True when the table is excluded from sampling entirely.
    skip: bool,
    /// Classification computed before sampling started.
    classification: TableClassification,
    /// FK-orphan rows rejected for this table.
    fk_orphans: u64,
    /// Temp file holding the selected raw rows, if any were written.
    selected_temp_path: Option<PathBuf>,
}
144
/// A parsed data row from either source format, so PK/FK handling can be
/// format-agnostic.
enum UnifiedRow {
    /// One tuple from a MySQL-style INSERT statement.
    Insert(ParsedRow),
    /// One line from a Postgres COPY data block.
    Copy(ParsedCopyRow),
}
150
/// Origin format of a spooled row. Persisted as the first byte of each
/// temp-file line (0 = Insert, 1 = Copy) so `write_output` knows how to
/// convert the raw bytes back to SQL.
#[derive(Debug, Clone, Copy, PartialEq)]
enum RowFormat {
    Insert,
    Copy,
}
157
impl UnifiedRow {
    /// Primary-key tuple of the row, if one was extracted.
    // NOTE(review): appears unused within this file — confirm before removal.
    fn pk(&self) -> Option<&smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>> {
        match self {
            UnifiedRow::Insert(r) => r.pk.as_ref(),
            UnifiedRow::Copy(r) => r.pk.as_ref(),
        }
    }

    /// All (FK reference, value tuple) pairs extracted from the row,
    /// regardless of source format.
    fn fk_values(
        &self,
    ) -> &[(
        crate::parser::mysql_insert::FkRef,
        smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>,
    )] {
        match self {
            UnifiedRow::Insert(r) => &r.fk_values,
            UnifiedRow::Copy(r) => &r.fk_values,
        }
    }
}
178
/// End-to-end sampling pipeline: split the dump into per-table files,
/// build the FK dependency graph, sample each table in dependency order,
/// and write the reduced dump (unless `dry_run`).
///
/// Returns per-table and aggregate statistics; recoverable issues
/// (FK cycles, row-cap hits) become entries in `stats.warnings`.
pub fn run(config: SampleConfig) -> anyhow::Result<SampleStats> {
    // Optional YAML config supplies per-table overrides (classification,
    // rates, skip directives).
    let yaml_config = if let Some(ref path) = config.config_file {
        Some(SampleYamlConfig::load(path)?)
    } else {
        None
    };

    // Seeded RNG makes runs reproducible for a given `config.seed`.
    let mut rng = StdRng::seed_from_u64(config.seed);
    let mut stats = SampleStats::default();

    let file_size = std::fs::metadata(&config.input)?.len();

    // Byte-based progress bar over the input dump during the split stage.
    let progress_bar = if config.progress {
        let pb = ProgressBar::new(file_size);
        pb.set_style(
            ProgressStyle::with_template(
                "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%) {msg}",
            )
            .unwrap()
            .progress_chars("█▓▒░ ")
            .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
        );
        pb.enable_steady_tick(std::time::Duration::from_millis(100));
        pb.set_message("Splitting dump...");
        Some(pb)
    } else {
        None
    };

    // Stage 1: split the monolithic dump into one .sql file per table under
    // a temp dir (removed when `temp_dir` drops).
    let temp_dir = TempDir::new()?;
    let tables_dir = temp_dir.path().join("tables");

    let mut splitter = Splitter::new(config.input.clone(), tables_dir.clone())
        .with_dialect(config.dialect)
        .with_content_filter(ContentFilter::All);

    if let Some(ref pb) = progress_bar {
        let pb_clone = pb.clone();
        splitter = splitter.with_progress(move |bytes| {
            pb_clone.set_position(bytes);
        });
    }

    let split_stats = splitter.split()?;

    if let Some(ref pb) = progress_bar {
        pb.finish_and_clear();
    }

    if config.progress {
        eprintln!(
            "Split complete: {} tables, {} statements",
            split_stats.tables_found, split_stats.statements_processed
        );
    }

    if config.progress {
        eprintln!("Building schema graph...");
    }

    // Stage 2: recover the FK graph from CREATE/ALTER statements.
    let graph = build_schema_graph(&tables_dir, &config)?;

    // Topological order guarantees parents are sampled before children,
    // so child FK checks can consult already-populated parent PK sets.
    let (topo_order, cyclic_tables) = graph.processing_order();

    if !cyclic_tables.is_empty() {
        let names: Vec<_> = cyclic_tables
            .iter()
            .filter_map(|&id| graph.table_name(id))
            .collect();
        let msg = format!(
            "Warning: {} tables have FK cycles (intra-cycle FK enforcement disabled): {:?}",
            cyclic_tables.len(),
            names
        );
        if config.progress {
            eprintln!("{}", msg);
        }
        stats.warnings.push(msg);
    }

    let cyclic_set: ahash::AHashSet<TableId> = cyclic_tables.iter().copied().collect();

    let explicit_roots: ahash::AHashSet<String> = config
        .root_tables
        .iter()
        .map(|s| s.to_lowercase())
        .collect();

    // Per-table runtime state: classification, skip flag, and the PK-hash
    // set that children consult for FK membership.
    let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
    for table in graph.schema.iter() {
        let classification =
            determine_classification(&table.name, &graph, table.id, &yaml_config, &explicit_roots);
        let skip = should_skip_table(&table.name, &config, &yaml_config, classification);

        runtimes.insert(
            table.id,
            TableRuntime {
                name: table.name.clone(),
                pk_set: PkHashSet::default(),
                rows_seen: 0,
                rows_selected: 0,
                skip,
                classification,
                fk_orphans: 0,
                selected_temp_path: None,
            },
        );
    }

    // Selected rows are spooled to per-table temp files, not held in memory.
    let selected_dir = temp_dir.path().join("selected");
    fs::create_dir_all(&selected_dir)?;

    if config.progress {
        eprintln!("Sampling {} tables in dependency order...", topo_order.len());
    }

    // Cyclic tables are processed after the acyclic topological order.
    let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();

    let mut total_selected: u64 = 0;

    // Stage 3: sample each table.
    for table_id in &all_tables {
        let table_schema = match graph.schema.table(*table_id) {
            Some(s) => s,
            None => continue,
        };

        // Copy out the small bits we need so `runtimes` stays immutably
        // borrowable by sample_table_streaming (it reads parent PK sets).
        let (should_skip, table_name, classification) = {
            let runtime = match runtimes.get(table_id) {
                Some(r) => r,
                None => continue,
            };
            (runtime.skip, runtime.name.clone(), runtime.classification)
        };

        if should_skip {
            stats.tables_skipped += 1;
            continue;
        }

        // Classification drives the effective sampling mode: lookups are
        // copied in full (or skipped), system tables are always skipped.
        let sample_mode = match classification {
            TableClassification::Lookup => {
                match config.include_global {
                    GlobalTableMode::None => {
                        stats.tables_skipped += 1;
                        continue;
                    }
                    GlobalTableMode::Lookups | GlobalTableMode::All => {
                        SampleMode::Percent(100)
                    }
                }
            }
            TableClassification::System => {
                stats.tables_skipped += 1;
                continue;
            }
            _ => get_table_sample_mode(&table_name, &config, &yaml_config),
        };

        let table_file = tables_dir.join(format!("{}.sql", table_name));
        if !table_file.exists() {
            continue;
        }

        let result = sample_table_streaming(
            &table_file,
            table_schema,
            *table_id,
            &table_name,
            sample_mode,
            &config,
            &runtimes,
            &cyclic_set,
            &selected_dir,
            &mut rng,
        )?;

        // Global row cap. Note: when the cap is hit, this table's results
        // are discarded (its runtime is never updated, so write_output will
        // skip it) and sampling stops entirely.
        if let Some(max) = config.max_total_rows {
            if total_selected + result.rows_selected > max as u64 {
                let msg = format!(
                    "Warning: Reached max_total_rows limit ({}) at table '{}'",
                    max, table_name
                );
                stats.warnings.push(msg);
                break;
            }
        }

        total_selected += result.rows_selected;

        let runtime = runtimes.get_mut(table_id).unwrap();
        runtime.rows_seen = result.rows_seen;
        runtime.rows_selected = result.rows_selected;
        runtime.fk_orphans = result.fk_orphans;

        // Record selected PK hashes so child tables can enforce FK
        // membership against this table.
        for pk_hash in result.pk_hashes {
            runtime.pk_set.insert(pk_hash);
        }

        if result.rows_selected > 0 {
            let temp_path = selected_dir.join(format!("{}.rows", table_name));
            if temp_path.exists() {
                runtime.selected_temp_path = Some(temp_path);
            }
        }

        stats.fk_orphans_rejected += result.fk_orphans;

        stats.table_stats.push(TableSampleStats {
            name: runtime.name.clone(),
            rows_seen: result.rows_seen,
            rows_selected: result.rows_selected,
            classification: runtime.classification,
        });
    }

    // Aggregate totals from the per-table entries.
    for table_stats in &stats.table_stats {
        stats.total_rows_seen += table_stats.rows_seen;
        stats.total_rows_selected += table_stats.rows_selected;
    }
    stats.tables_sampled = stats.table_stats.len();

    if config.progress {
        eprintln!("Sampling complete");
    }

    // Dry run: report statistics without writing any output.
    if config.dry_run {
        return Ok(stats);
    }

    if config.progress {
        eprintln!("Writing output...");
    }

    write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;

    Ok(stats)
}
440
441fn build_schema_graph(tables_dir: &Path, config: &SampleConfig) -> anyhow::Result<SchemaGraph> {
443 let mut builder = SchemaBuilder::new();
444
445 for entry in fs::read_dir(tables_dir)? {
446 let entry = entry?;
447 let path = entry.path();
448
449 if path.extension().map(|e| e == "sql").unwrap_or(false) {
450 let file = File::open(&path)?;
451 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
452
453 while let Some(stmt) = parser.read_statement()? {
454 let stmt_str = String::from_utf8_lossy(&stmt);
455 let (stmt_type, _) =
456 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
457
458 match stmt_type {
459 StatementType::CreateTable => {
460 builder.parse_create_table(&stmt_str);
461 }
462 StatementType::AlterTable => {
463 builder.parse_alter_table(&stmt_str);
464 }
465 _ => {}
466 }
467 }
468 }
469 }
470
471 Ok(SchemaGraph::from_schema(builder.build()))
472}
473
474fn determine_classification(
476 name: &str,
477 graph: &SchemaGraph,
478 table_id: TableId,
479 yaml_config: &Option<SampleYamlConfig>,
480 explicit_roots: &ahash::AHashSet<String>,
481) -> TableClassification {
482 if explicit_roots.contains(&name.to_lowercase()) {
484 return TableClassification::Root;
485 }
486
487 if let Some(ref config) = yaml_config {
489 let class = config.get_classification(name);
490 if class != TableClassification::Normal {
491 return class;
492 }
493 }
494
495 if graph.parents[table_id.0 as usize].is_empty() {
497 return TableClassification::Root;
498 }
499
500 DefaultClassifier::classify(name)
502}
503
504fn should_skip_table(
506 name: &str,
507 config: &SampleConfig,
508 yaml_config: &Option<SampleYamlConfig>,
509 classification: TableClassification,
510) -> bool {
511 let name_lower = name.to_lowercase();
512
513 if config
515 .exclude
516 .iter()
517 .any(|e| e.to_lowercase() == name_lower)
518 {
519 return true;
520 }
521
522 if let Some(ref yc) = yaml_config {
524 if yc.should_skip(name) {
525 return true;
526 }
527 }
528
529 if let Some(ref filter) = config.tables_filter {
531 if !filter.iter().any(|f| f.to_lowercase() == name_lower) {
532 return true;
533 }
534 }
535
536 if classification == TableClassification::System {
538 return true;
539 }
540
541 false
542}
543
544fn get_table_sample_mode(
546 name: &str,
547 config: &SampleConfig,
548 yaml_config: &Option<SampleYamlConfig>,
549) -> SampleMode {
550 if let Some(ref yc) = yaml_config {
552 if let Some(rows) = yc.get_rows(name) {
553 return SampleMode::Rows(rows);
554 }
555 if let Some(percent) = yc.get_percent(name) {
556 return SampleMode::Percent(percent);
557 }
558 }
559
560 config.mode
562}
563
/// Outcome of one `sample_table_streaming` call.
struct StreamingSampleResult {
    /// Rows parsed from the table file (including FK-rejected ones).
    rows_seen: u64,
    /// Rows actually selected (and, for non-empty results, spooled to the
    /// table's temp `.rows` file).
    rows_selected: u64,
    /// Rows rejected because a referenced parent row was not selected.
    fk_orphans: u64,
    /// Hashes of selected rows' PK tuples, for children's FK checks.
    pk_hashes: Vec<u64>,
}
572
/// Stream one table's `.sql` file and spool the selected rows to
/// `<selected_dir>/<table>.rows`.
///
/// Temp-file format: one row per line — a single marker byte (`0` =
/// INSERT tuple, `1` = COPY row) followed by the raw row bytes.
/// NOTE(review): a raw row containing an embedded newline would break this
/// line-oriented format — confirm the parsers never emit one.
///
/// Two strategies:
/// - `Percent(p)`: single pass; each row is kept independently with
///   probability `p`/100.
/// - `Rows(n)`: two passes. Pass 1 reservoir-samples row *indices*
///   (FK-rejected rows still consume an index, which keeps the two passes
///   in sync); pass 2 re-reads the file and writes the chosen rows.
///
/// FK filtering (when `config.preserve_relations` is set) consults the
/// parent tables' already-populated PK-hash sets; with `strict_fk` an
/// orphan aborts the run instead of being skipped.
#[allow(clippy::too_many_arguments)]
fn sample_table_streaming(
    table_file: &Path,
    table_schema: &crate::schema::TableSchema,
    table_id: TableId,
    table_name: &str,
    sample_mode: SampleMode,
    config: &SampleConfig,
    runtimes: &AHashMap<TableId, TableRuntime>,
    cyclic_set: &ahash::AHashSet<TableId>,
    selected_dir: &Path,
    rng: &mut StdRng,
) -> anyhow::Result<StreamingSampleResult> {
    let mut rows_seen = 0u64;
    let mut rows_selected = 0u64;
    let mut fk_orphans = 0u64;

    // Created lazily on the first selected row (Percent mode) or at the
    // start of pass 2 (Rows mode), so empty results leave no file behind.
    let temp_path = selected_dir.join(format!("{}.rows", table_name));
    let mut temp_writer: Option<BufWriter<File>> = None;

    let mut selected_pk_hashes: Vec<u64> = Vec::new();

    // Column list from the most recent COPY header; needed to parse the
    // COPY data block that follows it.
    let mut copy_columns: Vec<String> = Vec::new();

    match sample_mode {
        SampleMode::Percent(p) => {
            let prob = p as f64 / 100.0;

            let file = File::open(table_file)?;
            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);

            while let Some(stmt) = parser.read_statement()? {
                let (stmt_type, _) =
                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

                match stmt_type {
                    StatementType::Insert => {
                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
                        for row in rows {
                            rows_seen += 1;

                            // Reject rows whose FK parents were not selected.
                            if config.preserve_relations {
                                let unified = UnifiedRow::Insert(row.clone());
                                let (passes, orphan) = check_unified_fk_membership(
                                    &unified,
                                    table_schema,
                                    runtimes,
                                    cyclic_set,
                                    &table_id,
                                );
                                if !passes {
                                    fk_orphans += 1;
                                    if orphan && config.strict_fk {
                                        anyhow::bail!(
                                            "FK integrity violation in table '{}': row references missing parent",
                                            table_name
                                        );
                                    }
                                    continue;
                                }
                            }

                            // Independent Bernoulli draw per row.
                            if rng.gen::<f64>() < prob {
                                if temp_writer.is_none() {
                                    temp_writer =
                                        Some(BufWriter::new(File::create(&temp_path)?));
                                }
                                let writer = temp_writer.as_mut().unwrap();
                                // Marker byte 0 = INSERT-format row.
                                writer.write_all(&[0u8])?;
                                writer.write_all(&row.raw)?;
                                writer.write_all(b"\n")?;

                                if let Some(pk) = &row.pk {
                                    selected_pk_hashes.push(hash_pk_tuple(pk));
                                }
                                rows_selected += 1;
                            }
                        }
                    }
                    StatementType::Copy => {
                        // COPY header: remember the column order for the
                        // data block that follows.
                        let header = String::from_utf8_lossy(&stmt);
                        copy_columns = parse_copy_columns(&header);
                    }
                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
                        // Postgres COPY data arrives as an Unknown statement
                        // terminated by the `\.` end-of-data marker.
                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                            let rows = parse_postgres_copy_rows(
                                &stmt,
                                table_schema,
                                copy_columns.clone(),
                            )?;
                            for row in rows {
                                rows_seen += 1;

                                if config.preserve_relations {
                                    let unified = UnifiedRow::Copy(row.clone());
                                    let (passes, orphan) = check_unified_fk_membership(
                                        &unified,
                                        table_schema,
                                        runtimes,
                                        cyclic_set,
                                        &table_id,
                                    );
                                    if !passes {
                                        fk_orphans += 1;
                                        if orphan && config.strict_fk {
                                            anyhow::bail!(
                                                "FK integrity violation in table '{}': row references missing parent",
                                                table_name
                                            );
                                        }
                                        continue;
                                    }
                                }

                                if rng.gen::<f64>() < prob {
                                    if temp_writer.is_none() {
                                        temp_writer =
                                            Some(BufWriter::new(File::create(&temp_path)?));
                                    }
                                    let writer = temp_writer.as_mut().unwrap();
                                    // Marker byte 1 = COPY-format row.
                                    writer.write_all(&[1u8])?;
                                    writer.write_all(&row.raw)?;
                                    writer.write_all(b"\n")?;

                                    if let Some(pk) = &row.pk {
                                        selected_pk_hashes.push(hash_pk_tuple(pk));
                                    }
                                    rows_selected += 1;
                                }
                            }
                        }
                    }
                    _ => {}
                }
            }
        }
        SampleMode::Rows(n) => {
            // Pass 1: reservoir-sample (row index, format, pk hash) triples
            // instead of row bytes, keeping memory bounded by `n`.
            // The reservoir gets its own RNG split off the run RNG.
            let mut reservoir: Reservoir<(u64, RowFormat, Option<u64>)> =
                Reservoir::new(n, StdRng::from_rng(&mut *rng)?);

            let file = File::open(table_file)?;
            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);

            while let Some(stmt) = parser.read_statement()? {
                let (stmt_type, _) =
                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

                match stmt_type {
                    StatementType::Insert => {
                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
                        for row in rows {
                            // Every parsed row consumes an index, even if it
                            // is rejected below — pass 2 counts the same way.
                            let current_idx = rows_seen;
                            rows_seen += 1;

                            if config.preserve_relations {
                                let unified = UnifiedRow::Insert(row.clone());
                                let (passes, orphan) = check_unified_fk_membership(
                                    &unified,
                                    table_schema,
                                    runtimes,
                                    cyclic_set,
                                    &table_id,
                                );
                                if !passes {
                                    fk_orphans += 1;
                                    if orphan && config.strict_fk {
                                        anyhow::bail!(
                                            "FK integrity violation in table '{}': row references missing parent",
                                            table_name
                                        );
                                    }
                                    continue;
                                }
                            }

                            let pk_hash = row.pk.as_ref().map(hash_pk_tuple);
                            reservoir.consider((current_idx, RowFormat::Insert, pk_hash));
                        }
                    }
                    StatementType::Copy => {
                        let header = String::from_utf8_lossy(&stmt);
                        copy_columns = parse_copy_columns(&header);
                    }
                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                            let rows = parse_postgres_copy_rows(
                                &stmt,
                                table_schema,
                                copy_columns.clone(),
                            )?;
                            for row in rows {
                                let current_idx = rows_seen;
                                rows_seen += 1;

                                if config.preserve_relations {
                                    let unified = UnifiedRow::Copy(row.clone());
                                    let (passes, orphan) = check_unified_fk_membership(
                                        &unified,
                                        table_schema,
                                        runtimes,
                                        cyclic_set,
                                        &table_id,
                                    );
                                    if !passes {
                                        fk_orphans += 1;
                                        if orphan && config.strict_fk {
                                            anyhow::bail!(
                                                "FK integrity violation in table '{}': row references missing parent",
                                                table_name
                                            );
                                        }
                                        continue;
                                    }
                                }

                                let pk_hash = row.pk.as_ref().map(hash_pk_tuple);
                                reservoir.consider((current_idx, RowFormat::Copy, pk_hash));
                            }
                        }
                    }
                    _ => {}
                }
            }

            let selected_items = reservoir.into_items();
            if selected_items.is_empty() {
                // Nothing selected: skip pass 2 entirely.
                return Ok(StreamingSampleResult {
                    rows_seen,
                    rows_selected: 0,
                    fk_orphans,
                    pk_hashes: Vec::new(),
                });
            }

            // Sort winners by index so pass 2 can match them with a single
            // forward scan of the file.
            let mut selected_indices: Vec<(u64, RowFormat)> = Vec::with_capacity(selected_items.len());
            for (idx, format, pk_hash) in selected_items {
                if let Some(h) = pk_hash {
                    selected_pk_hashes.push(h);
                }
                selected_indices.push((idx, format));
            }
            selected_indices.sort_by_key(|(idx, _)| *idx);

            // Pass 2: re-read the file and spool the rows whose indices won.
            let file = File::open(table_file)?;
            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
            let mut current_row_idx = 0u64;
            let mut select_iter = selected_indices.iter().peekable();

            temp_writer = Some(BufWriter::new(File::create(&temp_path)?));
            let writer = temp_writer.as_mut().unwrap();

            while let Some(stmt) = parser.read_statement()? {
                // All winners written: stop scanning early.
                if select_iter.peek().is_none() {
                    break;
                }

                let (stmt_type, _) =
                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

                match stmt_type {
                    StatementType::Insert => {
                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
                        for row in rows {
                            if let Some((next_idx, _)) = select_iter.peek() {
                                if current_row_idx == *next_idx {
                                    writer.write_all(&[0u8])?;
                                    writer.write_all(&row.raw)?;
                                    writer.write_all(b"\n")?;
                                    rows_selected += 1;
                                    select_iter.next();
                                }
                            }
                            current_row_idx += 1;
                        }
                    }
                    StatementType::Copy => {
                        let header = String::from_utf8_lossy(&stmt);
                        copy_columns = parse_copy_columns(&header);
                    }
                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                            let rows = parse_postgres_copy_rows(
                                &stmt,
                                table_schema,
                                copy_columns.clone(),
                            )?;
                            for row in rows {
                                if let Some((next_idx, _)) = select_iter.peek() {
                                    if current_row_idx == *next_idx {
                                        writer.write_all(&[1u8])?;
                                        writer.write_all(&row.raw)?;
                                        writer.write_all(b"\n")?;
                                        rows_selected += 1;
                                        select_iter.next();
                                    }
                                }
                                current_row_idx += 1;
                            }
                        }
                    }
                    _ => {}
                }
            }
        }
    }

    // Flush the spool file so write_output sees complete data.
    if let Some(mut writer) = temp_writer {
        writer.flush()?;
    }

    Ok(StreamingSampleResult {
        rows_seen,
        rows_selected,
        fk_orphans,
        pk_hashes: selected_pk_hashes,
    })
}
910
911fn check_unified_fk_membership(
914 row: &UnifiedRow,
915 table_schema: &crate::schema::TableSchema,
916 runtimes: &AHashMap<TableId, TableRuntime>,
917 cyclic_set: &ahash::AHashSet<TableId>,
918 current_table_id: &TableId,
919) -> (bool, bool) {
920 let mut passes = true;
921 let mut is_orphan = false;
922
923 for (fk_ref, fk_tuple) in row.fk_values() {
924 if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
925 if let Some(parent_id) = fk.referenced_table_id {
926 if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id) {
928 continue;
929 }
930
931 if let Some(parent_runtime) = runtimes.get(&parent_id) {
933 let fk_hash = hash_pk_tuple(fk_tuple);
934 if !parent_runtime.pk_set.contains(&fk_hash) {
935 passes = false;
936 is_orphan = true;
937 break;
938 }
939 }
940 }
941 }
942 }
943
944 (passes, is_orphan)
945}
946
947fn write_output(
949 config: &SampleConfig,
950 _graph: &SchemaGraph,
951 table_order: &[TableId],
952 runtimes: &AHashMap<TableId, TableRuntime>,
953 tables_dir: &Path,
954 stats: &SampleStats,
955) -> anyhow::Result<()> {
956 let mut writer: Box<dyn Write> = match &config.output {
957 Some(path) => {
958 if let Some(parent) = path.parent() {
959 fs::create_dir_all(parent)?;
960 }
961 Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
962 }
963 None => Box::new(BufWriter::new(std::io::stdout())),
964 };
965
966 write_header(&mut writer, config, stats)?;
968
969 write_dialect_header(&mut writer, config.dialect)?;
971
972 if config.include_schema {
974 for &table_id in table_order {
975 let runtime = match runtimes.get(&table_id) {
976 Some(r) if !r.skip && r.rows_selected > 0 => r,
977 _ => continue,
978 };
979
980 let table_file = tables_dir.join(format!("{}.sql", runtime.name));
981 if !table_file.exists() {
982 continue;
983 }
984
985 let file = File::open(&table_file)?;
987 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
988
989 while let Some(stmt) = parser.read_statement()? {
990 let (stmt_type, _) =
991 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
992
993 if stmt_type.is_schema() {
994 writer.write_all(&stmt)?;
995 writer.write_all(b"\n")?;
996 }
997 }
998 }
999 }
1000
1001 for &table_id in table_order {
1003 let runtime = match runtimes.get(&table_id) {
1004 Some(r) if !r.skip && r.rows_selected > 0 && r.selected_temp_path.is_some() => r,
1005 _ => continue,
1006 };
1007
1008 let table_name = &runtime.name;
1009 let row_count = runtime.rows_selected;
1010
1011 writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;
1012
1013 let quoted_name = match config.dialect {
1015 SqlDialect::MySql => format!("`{}`", table_name),
1016 SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
1017 };
1018
1019 let temp_path = runtime.selected_temp_path.as_ref().unwrap();
1021 let temp_file = File::open(temp_path)?;
1022 let reader = std::io::BufReader::new(temp_file);
1023 use std::io::BufRead;
1024
1025 const CHUNK_SIZE: usize = 1000;
1026 let mut chunk_buffer: Vec<(RowFormat, Vec<u8>)> = Vec::with_capacity(CHUNK_SIZE);
1027
1028 for line in reader.lines() {
1029 let line = line?;
1030 if line.is_empty() {
1031 continue;
1032 }
1033
1034 let bytes = line.as_bytes();
1035 if bytes.is_empty() {
1036 continue;
1037 }
1038
1039 let format = if bytes[0] == 0 {
1041 RowFormat::Insert
1042 } else {
1043 RowFormat::Copy
1044 };
1045 let row_bytes = bytes[1..].to_vec();
1046
1047 chunk_buffer.push((format, row_bytes));
1048
1049 if chunk_buffer.len() >= CHUNK_SIZE {
1050 write_insert_chunk(&mut writer, "ed_name, &chunk_buffer, config.dialect)?;
1051 chunk_buffer.clear();
1052 }
1053 }
1054
1055 if !chunk_buffer.is_empty() {
1057 write_insert_chunk(&mut writer, "ed_name, &chunk_buffer, config.dialect)?;
1058 }
1059 }
1060
1061 write_dialect_footer(&mut writer, config.dialect)?;
1063
1064 writer.flush()?;
1065
1066 Ok(())
1067}
1068
1069fn write_header<W: Write>(
1071 writer: &mut W,
1072 config: &SampleConfig,
1073 stats: &SampleStats,
1074) -> std::io::Result<()> {
1075 writeln!(writer, "-- Sampled from: {}", config.input.display())?;
1076 writeln!(
1077 writer,
1078 "-- Date: {}",
1079 chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
1080 )?;
1081 writeln!(
1082 writer,
1083 "-- Mode: {:?}{}",
1084 config.mode,
1085 if config.preserve_relations {
1086 ", preserve-relations"
1087 } else {
1088 ""
1089 }
1090 )?;
1091 writeln!(writer, "-- Seed: {}", config.seed)?;
1092 writeln!(writer, "-- Dialect: {}", config.dialect)?;
1093 writeln!(writer, "--")?;
1094 writeln!(writer, "-- Statistics:")?;
1095 writeln!(writer, "-- Tables sampled: {}", stats.tables_sampled)?;
1096 writeln!(writer, "-- Tables skipped: {}", stats.tables_skipped)?;
1097
1098 let percent = if stats.total_rows_seen > 0 {
1099 (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
1100 } else {
1101 0.0
1102 };
1103 writeln!(
1104 writer,
1105 "-- Total rows: {} (from {} original, {:.1}%)",
1106 stats.total_rows_selected, stats.total_rows_seen, percent
1107 )?;
1108
1109 if stats.fk_orphans_rejected > 0 {
1110 writeln!(
1111 writer,
1112 "-- FK orphans rejected: {}",
1113 stats.fk_orphans_rejected
1114 )?;
1115 }
1116
1117 if !stats.warnings.is_empty() {
1118 writeln!(writer, "-- Warnings: {}", stats.warnings.len())?;
1119 }
1120
1121 writeln!(writer)?;
1122
1123 Ok(())
1124}
1125
1126fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1128 match dialect {
1129 SqlDialect::MySql => {
1130 writeln!(writer, "SET NAMES utf8mb4;")?;
1131 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
1132 }
1133 SqlDialect::Postgres => {
1134 writeln!(writer, "SET client_encoding = 'UTF8';")?;
1135 writeln!(writer, "SET session_replication_role = replica;")?;
1136 }
1137 SqlDialect::Sqlite => {
1138 writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
1139 }
1140 }
1141 writeln!(writer)?;
1142 Ok(())
1143}
1144
1145fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1147 writeln!(writer)?;
1148 match dialect {
1149 SqlDialect::MySql => {
1150 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
1151 }
1152 SqlDialect::Postgres => {
1153 writeln!(writer, "SET session_replication_role = DEFAULT;")?;
1154 }
1155 SqlDialect::Sqlite => {
1156 writeln!(writer, "PRAGMA foreign_keys = ON;")?;
1157 }
1158 }
1159 Ok(())
1160}
1161
1162fn write_insert_chunk<W: Write>(
1164 writer: &mut W,
1165 quoted_name: &str,
1166 chunk: &[(RowFormat, Vec<u8>)],
1167 dialect: SqlDialect,
1168) -> std::io::Result<()> {
1169 writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;
1170
1171 for (i, (format, row_bytes)) in chunk.iter().enumerate() {
1172 if i > 0 {
1173 writer.write_all(b",\n")?;
1174 }
1175
1176 let values = match format {
1177 RowFormat::Insert => {
1178 match dialect {
1179 SqlDialect::Postgres => convert_row_to_postgres(row_bytes),
1180 _ => row_bytes.clone(),
1181 }
1182 }
1183 RowFormat::Copy => convert_copy_to_insert_values(row_bytes, dialect),
1184 };
1185 writer.write_all(&values)?;
1186 }
1187
1188 writer.write_all(b";\n")?;
1189 Ok(())
1190}
1191
/// Re-escape a MySQL-style INSERT tuple for Postgres output: every
/// backslash-escaped quote (`\'`) becomes a doubled quote (`''`). All
/// other bytes pass through unchanged.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(row.len());
    let mut bytes = row.iter().copied().peekable();

    while let Some(b) = bytes.next() {
        if b == b'\\' && bytes.peek() == Some(&b'\'') {
            // Consume the quote and emit the SQL-standard '' escape.
            bytes.next();
            result.extend_from_slice(b"''");
        } else {
            result.push(b);
        }
    }

    result
}
1213
1214fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
1216 let mut result = Vec::with_capacity(row.len() + 20);
1217 result.push(b'(');
1218
1219 let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
1220
1221 for (i, field) in fields.iter().enumerate() {
1222 if i > 0 {
1223 result.extend_from_slice(b", ");
1224 }
1225
1226 if *field == b"\\N" {
1228 result.extend_from_slice(b"NULL");
1229 } else if field.is_empty() {
1230 match dialect {
1232 SqlDialect::MySql => result.extend_from_slice(b"''"),
1233 SqlDialect::Postgres | SqlDialect::Sqlite => result.extend_from_slice(b"''"),
1234 }
1235 } else if is_numeric(field) {
1236 result.extend_from_slice(field);
1238 } else {
1239 result.push(b'\'');
1241 for &b in *field {
1242 match b {
1243 b'\'' => {
1244 match dialect {
1246 SqlDialect::MySql => result.extend_from_slice(b"\\'"),
1247 SqlDialect::Postgres | SqlDialect::Sqlite => {
1248 result.extend_from_slice(b"''")
1249 }
1250 }
1251 }
1252 b'\\' if dialect == SqlDialect::MySql => {
1253 result.extend_from_slice(b"\\\\");
1255 }
1256 _ => result.push(b),
1257 }
1258 }
1259 result.push(b'\'');
1260 }
1261 }
1262
1263 result.push(b')');
1264 result
1265}
1266
/// Decide whether a COPY field can be emitted unquoted as a SQL numeric
/// literal.
///
/// Accepts an optional leading sign, a mantissa of digits with at most one
/// decimal point (at least one digit required), and an optional exponent:
/// `e`/`E`, optional sign, one or more digits.
///
/// The previous version treated every `e`/`E` as a no-op, so it accepted
/// malformed values such as `e5`, `1e` and `1e1e1` (which would then be
/// emitted unquoted, producing invalid SQL) while rejecting the valid
/// `1e-5` / `1e+5` forms.
fn is_numeric(s: &[u8]) -> bool {
    // Optional sign on the mantissa.
    let s = match s.first() {
        Some(b'+') | Some(b'-') => &s[1..],
        _ => s,
    };
    if s.is_empty() {
        return false;
    }

    // Split off the exponent at the first 'e'/'E', if present.
    let (mantissa, exponent) = match s.iter().position(|&b| b == b'e' || b == b'E') {
        Some(i) => (&s[..i], Some(&s[i + 1..])),
        None => (s, None),
    };

    // Mantissa: digits with at most one '.', and at least one digit.
    let mut has_digit = false;
    let mut has_dot = false;
    for &b in mantissa {
        match b {
            b'0'..=b'9' => has_digit = true,
            b'.' if !has_dot => has_dot = true,
            _ => return false,
        }
    }
    if !has_digit {
        return false;
    }

    // Exponent (when present): optional sign followed by >= 1 digit.
    match exponent {
        None => true,
        Some(exp) => {
            let exp = match exp.first() {
                Some(b'+') | Some(b'-') => &exp[1..],
                _ => exp,
            };
            !exp.is_empty() && exp.iter().all(|b| b.is_ascii_digit())
        }
    }
}