1mod config;
9mod reservoir;
10
11pub use config::{DefaultClassifier, GlobalTableMode, SampleYamlConfig, TableClassification};
12pub use reservoir::Reservoir;
13
14use crate::parser::mysql_insert::{hash_pk_tuple, parse_mysql_insert_rows, ParsedRow, PkHashSet};
15use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
16use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
17use crate::schema::{SchemaBuilder, SchemaGraph, TableId};
18use crate::splitter::Splitter;
19use ahash::AHashMap;
20use indicatif::{ProgressBar, ProgressStyle};
21use rand::rngs::StdRng;
22use rand::{Rng, SeedableRng};
23use std::fs::{self, File};
24use std::io::{BufWriter, Write};
25use std::path::{Path, PathBuf};
26use tempfile::TempDir;
27
/// How many rows to keep from each table.
#[derive(Debug, Clone, Copy)]
pub enum SampleMode {
    /// Keep approximately this percentage of rows (independent Bernoulli draw per row).
    Percent(u32),
    /// Keep exactly up to this many rows via reservoir sampling.
    Rows(usize),
}
36
/// Options controlling a sampling run (see [`run`]).
#[derive(Debug)]
pub struct SampleConfig {
    /// Path to the SQL dump to sample.
    pub input: PathBuf,
    /// Output path; `None` writes the sampled dump to stdout.
    pub output: Option<PathBuf>,
    /// SQL dialect of both the input dump and the generated output.
    pub dialect: SqlDialect,
    /// Default sampling mode, unless overridden per table by YAML config.
    pub mode: SampleMode,
    /// When true, rows whose FK parents were not selected are rejected.
    pub preserve_relations: bool,
    /// If set, only tables in this list are sampled (case-insensitive).
    pub tables_filter: Option<Vec<String>>,
    /// Tables to skip entirely (case-insensitive).
    pub exclude: Vec<String>,
    /// Tables to force-classify as sampling roots (case-insensitive).
    pub root_tables: Vec<String>,
    /// Policy for lookup/global tables (skip, keep lookups, keep all).
    pub include_global: GlobalTableMode,
    /// RNG seed; reproducible runs use the same seed.
    pub seed: u64,
    /// Collect statistics only; write no output.
    pub dry_run: bool,
    /// Emit progress bar / status lines on stderr.
    pub progress: bool,
    /// Optional YAML file with per-table overrides.
    pub config_file: Option<PathBuf>,
    /// Stop sampling once this many rows (across all tables) would be exceeded.
    pub max_total_rows: Option<usize>,
    /// Treat any FK orphan as a hard error instead of silently dropping the row.
    pub strict_fk: bool,
    /// Include CREATE/ALTER statements for sampled tables in the output.
    pub include_schema: bool,
}
73
74impl Default for SampleConfig {
75 fn default() -> Self {
76 Self {
77 input: PathBuf::new(),
78 output: None,
79 dialect: SqlDialect::MySql,
80 mode: SampleMode::Percent(10),
81 preserve_relations: false,
82 tables_filter: None,
83 exclude: Vec::new(),
84 root_tables: Vec::new(),
85 include_global: GlobalTableMode::Lookups,
86 seed: rand::random(),
87 dry_run: false,
88 progress: false,
89 config_file: None,
90 max_total_rows: None,
91 strict_fk: false,
92 include_schema: true,
93 }
94 }
95}
96
/// Aggregate results of a sampling run, serializable for reporting.
#[derive(Debug, Default, serde::Serialize)]
pub struct SampleStats {
    /// Number of tables that contributed row statistics.
    pub tables_sampled: usize,
    /// Number of tables skipped by filters/classification.
    pub tables_skipped: usize,
    /// Total rows written across all tables.
    pub total_rows_selected: u64,
    /// Total rows scanned across all tables.
    pub total_rows_seen: u64,
    /// Per-table breakdown, in processing order.
    pub table_stats: Vec<TableSampleStats>,
    /// Human-readable warnings (FK cycles, row-limit hits, ...).
    pub warnings: Vec<String>,
    /// Rows dropped because a referenced parent row was not selected.
    pub fk_orphans_rejected: u64,
}
115
/// Sampling outcome for a single table.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TableSampleStats {
    pub name: String,
    /// Rows scanned in this table's split file.
    pub rows_seen: u64,
    /// Rows actually selected for output.
    pub rows_selected: u64,
    /// How the table was classified (root/lookup/system/...).
    pub classification: TableClassification,
}
124
/// Per-table mutable state accumulated while sampling in dependency order.
struct TableRuntime {
    /// Table name as it appears in the dump.
    name: String,
    /// Hashes of the PK tuples selected so far; children consult this set
    /// when checking FK membership.
    pk_set: PkHashSet,
    /// Rows scanned for this table.
    rows_seen: u64,
    /// Rows selected for this table.
    rows_selected: u64,
    /// True when the table is excluded by filters/classification.
    skip: bool,
    classification: TableClassification,
    /// Rows rejected because of a missing FK parent.
    fk_orphans: u64,
    /// Temp file holding the selected raw rows, if any were written.
    selected_temp_path: Option<PathBuf>,
}
144
/// A parsed row originating from either an INSERT statement or a COPY block,
/// unified so FK checks can treat both identically.
enum UnifiedRow {
    Insert(ParsedRow),
    Copy(ParsedCopyRow),
}
150
/// Original syntax of a buffered row; stored as the first byte of each
/// temp-file record (0 = Insert, 1 = Copy) and used to pick the right
/// conversion when rebuilding INSERT statements.
#[derive(Debug, Clone, Copy, PartialEq)]
enum RowFormat {
    Insert,
    Copy,
}
157
158impl UnifiedRow {
159 fn fk_values(
160 &self,
161 ) -> &[(
162 crate::parser::mysql_insert::FkRef,
163 smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>,
164 )] {
165 match self {
166 UnifiedRow::Insert(r) => &r.fk_values,
167 UnifiedRow::Copy(r) => &r.fk_values,
168 }
169 }
170}
171
/// Execute a full sampling run: split the dump into per-table files, build
/// the FK schema graph, sample each table in dependency order (parents
/// before children, so FK membership checks see final parent PK sets), and
/// finally write the combined output unless `dry_run` is set.
pub fn run(config: SampleConfig) -> anyhow::Result<SampleStats> {
    // Optional per-table overrides from YAML.
    let yaml_config = if let Some(ref path) = config.config_file {
        Some(SampleYamlConfig::load(path)?)
    } else {
        None
    };

    // Single seeded RNG drives all random decisions for reproducibility.
    let mut rng = StdRng::seed_from_u64(config.seed);
    let mut stats = SampleStats::default();

    let file_size = std::fs::metadata(&config.input)?.len();

    let progress_bar = if config.progress {
        let pb = ProgressBar::new(file_size);
        pb.set_style(
            ProgressStyle::with_template(
                "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%) {msg}",
            )
            .unwrap()
            .progress_chars("█▓▒░ ")
            .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
        );
        pb.enable_steady_tick(std::time::Duration::from_millis(100));
        pb.set_message("Splitting dump...");
        Some(pb)
    } else {
        None
    };

    // All intermediate files live under a TempDir and vanish on drop.
    let temp_dir = TempDir::new()?;
    let tables_dir = temp_dir.path().join("tables");

    // Phase 1: split the monolithic dump into one .sql file per table.
    let mut splitter = Splitter::new(config.input.clone(), tables_dir.clone())
        .with_dialect(config.dialect)
        .with_content_filter(ContentFilter::All);

    if let Some(ref pb) = progress_bar {
        let pb_clone = pb.clone();
        splitter = splitter.with_progress(move |bytes| {
            pb_clone.set_position(bytes);
        });
    }

    let split_stats = splitter.split()?;

    if let Some(ref pb) = progress_bar {
        pb.finish_and_clear();
    }

    if config.progress {
        eprintln!(
            "Split complete: {} tables, {} statements",
            split_stats.tables_found, split_stats.statements_processed
        );
    }

    if config.progress {
        eprintln!("Building schema graph...");
    }

    // Phase 2: derive the FK dependency graph from CREATE/ALTER statements.
    let graph = build_schema_graph(&tables_dir, &config)?;

    let (topo_order, cyclic_tables) = graph.processing_order();

    // FK cycles cannot be topologically ordered; FK enforcement is disabled
    // within each cycle and the user is warned.
    if !cyclic_tables.is_empty() {
        let names: Vec<_> = cyclic_tables
            .iter()
            .filter_map(|&id| graph.table_name(id))
            .collect();
        let msg = format!(
            "Warning: {} tables have FK cycles (intra-cycle FK enforcement disabled): {:?}",
            cyclic_tables.len(),
            names
        );
        if config.progress {
            eprintln!("{}", msg);
        }
        stats.warnings.push(msg);
    }

    let cyclic_set: ahash::AHashSet<TableId> = cyclic_tables.iter().copied().collect();

    let explicit_roots: ahash::AHashSet<String> = config
        .root_tables
        .iter()
        .map(|s| s.to_lowercase())
        .collect();

    // Initialize per-table runtime state (classification + skip decision).
    let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
    for table in graph.schema.iter() {
        let classification =
            determine_classification(&table.name, &graph, table.id, &yaml_config, &explicit_roots);
        let skip = should_skip_table(&table.name, &config, &yaml_config, classification);

        runtimes.insert(
            table.id,
            TableRuntime {
                name: table.name.clone(),
                pk_set: PkHashSet::default(),
                rows_seen: 0,
                rows_selected: 0,
                skip,
                classification,
                fk_orphans: 0,
                selected_temp_path: None,
            },
        );
    }

    let selected_dir = temp_dir.path().join("selected");
    fs::create_dir_all(&selected_dir)?;

    if config.progress {
        eprintln!(
            "Sampling {} tables in dependency order...",
            topo_order.len()
        );
    }

    // Cyclic tables are processed last, after every acyclic table.
    let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();

    let mut total_selected: u64 = 0;

    // Phase 3: sample each table. Parents run before children, so by the
    // time a child is sampled its parents' pk_sets are complete.
    for table_id in &all_tables {
        let table_schema = match graph.schema.table(*table_id) {
            Some(s) => s,
            None => continue,
        };

        // Copy out what we need so `runtimes` can be borrowed immutably
        // during sampling and mutably afterwards.
        let (should_skip, table_name, classification) = {
            let runtime = match runtimes.get(table_id) {
                Some(r) => r,
                None => continue,
            };
            (runtime.skip, runtime.name.clone(), runtime.classification)
        };

        if should_skip {
            stats.tables_skipped += 1;
            continue;
        }

        // Classification may override the sampling mode: lookups are kept
        // whole (or skipped), system tables are always skipped.
        let sample_mode = match classification {
            TableClassification::Lookup => {
                match config.include_global {
                    GlobalTableMode::None => {
                        stats.tables_skipped += 1;
                        continue;
                    }
                    GlobalTableMode::Lookups | GlobalTableMode::All => {
                        SampleMode::Percent(100)
                    }
                }
            }
            TableClassification::System => {
                stats.tables_skipped += 1;
                continue;
            }
            _ => get_table_sample_mode(&table_name, &config, &yaml_config),
        };

        let table_file = tables_dir.join(format!("{}.sql", table_name));
        if !table_file.exists() {
            continue;
        }

        let result = sample_table_streaming(
            &table_file,
            table_schema,
            *table_id,
            &table_name,
            sample_mode,
            &config,
            &runtimes,
            &cyclic_set,
            &selected_dir,
            &mut rng,
        )?;

        // Global row budget: stop before this table would push us past it.
        // (The table's rows were already written to its temp file, but its
        // runtime stays at rows_selected == 0, so write_output ignores it.)
        if let Some(max) = config.max_total_rows {
            if total_selected + result.rows_selected > max as u64 {
                let msg = format!(
                    "Warning: Reached max_total_rows limit ({}) at table '{}'",
                    max, table_name
                );
                stats.warnings.push(msg);
                break;
            }
        }

        total_selected += result.rows_selected;

        let runtime = runtimes
            .get_mut(table_id)
            .expect("runtime must exist - checked at loop start");
        runtime.rows_seen = result.rows_seen;
        runtime.rows_selected = result.rows_selected;
        runtime.fk_orphans = result.fk_orphans;

        // Record selected PKs so child tables can validate their FKs.
        for pk_hash in result.pk_hashes {
            runtime.pk_set.insert(pk_hash);
        }

        if result.rows_selected > 0 {
            let temp_path = selected_dir.join(format!("{}.rows", table_name));
            if temp_path.exists() {
                runtime.selected_temp_path = Some(temp_path);
            }
        }

        stats.fk_orphans_rejected += result.fk_orphans;

        stats.table_stats.push(TableSampleStats {
            name: runtime.name.clone(),
            rows_seen: result.rows_seen,
            rows_selected: result.rows_selected,
            classification: runtime.classification,
        });
    }

    // Roll per-table numbers up into the aggregate totals.
    for table_stats in &stats.table_stats {
        stats.total_rows_seen += table_stats.rows_seen;
        stats.total_rows_selected += table_stats.rows_selected;
    }
    stats.tables_sampled = stats.table_stats.len();

    if config.progress {
        eprintln!("Sampling complete");
    }

    if config.dry_run {
        return Ok(stats);
    }

    if config.progress {
        eprintln!("Writing output...");
    }

    // Phase 4: emit header, schema, data, and footer.
    write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;

    Ok(stats)
}
439
440fn build_schema_graph(tables_dir: &Path, config: &SampleConfig) -> anyhow::Result<SchemaGraph> {
442 let mut builder = SchemaBuilder::new();
443
444 for entry in fs::read_dir(tables_dir)? {
445 let entry = entry?;
446 let path = entry.path();
447
448 if path.extension().map(|e| e == "sql").unwrap_or(false) {
449 let file = File::open(&path)?;
450 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
451
452 while let Some(stmt) = parser.read_statement()? {
453 let stmt_str = String::from_utf8_lossy(&stmt);
454 let (stmt_type, _) =
455 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
456
457 match stmt_type {
458 StatementType::CreateTable => {
459 builder.parse_create_table(&stmt_str);
460 }
461 StatementType::AlterTable => {
462 builder.parse_alter_table(&stmt_str);
463 }
464 _ => {}
465 }
466 }
467 }
468 }
469
470 Ok(SchemaGraph::from_schema(builder.build()))
471}
472
473fn determine_classification(
475 name: &str,
476 graph: &SchemaGraph,
477 table_id: TableId,
478 yaml_config: &Option<SampleYamlConfig>,
479 explicit_roots: &ahash::AHashSet<String>,
480) -> TableClassification {
481 if explicit_roots.contains(&name.to_lowercase()) {
483 return TableClassification::Root;
484 }
485
486 if let Some(ref config) = yaml_config {
488 let class = config.get_classification(name);
489 if class != TableClassification::Normal {
490 return class;
491 }
492 }
493
494 if graph.parents[table_id.0 as usize].is_empty() {
496 return TableClassification::Root;
497 }
498
499 DefaultClassifier::classify(name)
501}
502
503fn should_skip_table(
505 name: &str,
506 config: &SampleConfig,
507 yaml_config: &Option<SampleYamlConfig>,
508 classification: TableClassification,
509) -> bool {
510 let name_lower = name.to_lowercase();
511
512 if config
514 .exclude
515 .iter()
516 .any(|e| e.to_lowercase() == name_lower)
517 {
518 return true;
519 }
520
521 if let Some(ref yc) = yaml_config {
523 if yc.should_skip(name) {
524 return true;
525 }
526 }
527
528 if let Some(ref filter) = config.tables_filter {
530 if !filter.iter().any(|f| f.to_lowercase() == name_lower) {
531 return true;
532 }
533 }
534
535 if classification == TableClassification::System {
537 return true;
538 }
539
540 false
541}
542
543fn get_table_sample_mode(
545 name: &str,
546 config: &SampleConfig,
547 yaml_config: &Option<SampleYamlConfig>,
548) -> SampleMode {
549 if let Some(ref yc) = yaml_config {
551 if let Some(rows) = yc.get_rows(name) {
552 return SampleMode::Rows(rows);
553 }
554 if let Some(percent) = yc.get_percent(name) {
555 return SampleMode::Percent(percent);
556 }
557 }
558
559 config.mode
561}
562
/// Outcome of sampling one table's split file.
struct StreamingSampleResult {
    /// Rows scanned (FK-orphaned rows included).
    rows_seen: u64,
    /// Rows written to the table's temp row file.
    rows_selected: u64,
    /// Rows rejected because a referenced parent row was not selected.
    fk_orphans: u64,
    /// Hashes of the selected rows' PK tuples, for child FK checks.
    pk_hashes: Vec<u64>,
}
571
/// Sample one table's split file according to `sample_mode`.
///
/// Percent mode is a single streaming pass with an independent Bernoulli
/// draw per row. Rows mode uses a reservoir over row *indices* in a first
/// pass, then a second pass re-reads the file and writes out the rows whose
/// indices were chosen. In both modes, when `preserve_relations` is set,
/// rows whose FK parents were not selected are dropped (or, with
/// `strict_fk`, turned into a hard error) before they can enter the sample.
///
/// Selected rows are written to `<selected_dir>/<table>.rows` as
/// newline-terminated records: a 1-byte format tag (0 = INSERT row,
/// 1 = COPY row) followed by the raw row bytes.
/// NOTE(review): this framing assumes raw row bytes contain no embedded
/// newline — confirm the upstream parsers guarantee that.
#[allow(clippy::too_many_arguments)]
fn sample_table_streaming(
    table_file: &Path,
    table_schema: &crate::schema::TableSchema,
    table_id: TableId,
    table_name: &str,
    sample_mode: SampleMode,
    config: &SampleConfig,
    runtimes: &AHashMap<TableId, TableRuntime>,
    cyclic_set: &ahash::AHashSet<TableId>,
    selected_dir: &Path,
    rng: &mut StdRng,
) -> anyhow::Result<StreamingSampleResult> {
    let mut rows_seen = 0u64;
    let mut rows_selected = 0u64;
    let mut fk_orphans = 0u64;

    let temp_path = selected_dir.join(format!("{}.rows", table_name));
    // Created lazily so empty samples leave no file behind.
    let mut temp_writer: Option<BufWriter<File>> = None;

    let mut selected_pk_hashes: Vec<u64> = Vec::new();

    // Column list from the most recent COPY header; applies to the
    // following COPY data block.
    let mut copy_columns: Vec<String> = Vec::new();

    match sample_mode {
        SampleMode::Percent(p) => {
            let prob = p as f64 / 100.0;

            let file = File::open(table_file)?;
            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);

            while let Some(stmt) = parser.read_statement()? {
                let (stmt_type, _) =
                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

                match stmt_type {
                    StatementType::Insert => {
                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
                        for row in rows {
                            rows_seen += 1;

                            // FK filtering happens before the random draw so
                            // orphans can never be selected.
                            if config.preserve_relations {
                                let unified = UnifiedRow::Insert(row.clone());
                                let (passes, orphan) = check_unified_fk_membership(
                                    &unified,
                                    table_schema,
                                    runtimes,
                                    cyclic_set,
                                    &table_id,
                                );
                                if !passes {
                                    fk_orphans += 1;
                                    if orphan && config.strict_fk {
                                        anyhow::bail!(
                                            "FK integrity violation in table '{}': row references missing parent",
                                            table_name
                                        );
                                    }
                                    continue;
                                }
                            }

                            // Independent Bernoulli draw per row.
                            if rng.random::<f64>() < prob {
                                if temp_writer.is_none() {
                                    temp_writer = Some(BufWriter::new(File::create(&temp_path)?));
                                }
                                let writer = temp_writer.as_mut().unwrap();
                                // Tag byte 0 marks an INSERT-format row.
                                writer.write_all(&[0u8])?;
                                writer.write_all(&row.raw)?;
                                writer.write_all(b"\n")?;

                                if let Some(pk) = &row.pk {
                                    selected_pk_hashes.push(hash_pk_tuple(pk));
                                }
                                rows_selected += 1;
                            }
                        }
                    }
                    StatementType::Copy => {
                        // COPY header: remember the column order for the data block.
                        let header = String::from_utf8_lossy(&stmt);
                        copy_columns = parse_copy_columns(&header);
                    }
                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
                        // A COPY data block ends with the `\.` terminator.
                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                            let rows = parse_postgres_copy_rows(
                                &stmt,
                                table_schema,
                                copy_columns.clone(),
                            )?;
                            for row in rows {
                                rows_seen += 1;

                                if config.preserve_relations {
                                    let unified = UnifiedRow::Copy(row.clone());
                                    let (passes, orphan) = check_unified_fk_membership(
                                        &unified,
                                        table_schema,
                                        runtimes,
                                        cyclic_set,
                                        &table_id,
                                    );
                                    if !passes {
                                        fk_orphans += 1;
                                        if orphan && config.strict_fk {
                                            anyhow::bail!(
                                                "FK integrity violation in table '{}': row references missing parent",
                                                table_name
                                            );
                                        }
                                        continue;
                                    }
                                }

                                if rng.random::<f64>() < prob {
                                    if temp_writer.is_none() {
                                        temp_writer =
                                            Some(BufWriter::new(File::create(&temp_path)?));
                                    }
                                    let writer = temp_writer.as_mut().unwrap();
                                    // Tag byte 1 marks a COPY-format row.
                                    writer.write_all(&[1u8])?;
                                    writer.write_all(&row.raw)?;
                                    writer.write_all(b"\n")?;

                                    if let Some(pk) = &row.pk {
                                        selected_pk_hashes.push(hash_pk_tuple(pk));
                                    }
                                    rows_selected += 1;
                                }
                            }
                        }
                    }
                    _ => {}
                }
            }
        }
        SampleMode::Rows(n) => {
            // Pass 1: reservoir-sample (row index, format, pk hash) triples.
            // Only indices are buffered, so memory stays O(n).
            let mut reservoir: Reservoir<(u64, RowFormat, Option<u64>)> =
                Reservoir::new(n, StdRng::from_rng(&mut *rng));

            let file = File::open(table_file)?;
            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);

            while let Some(stmt) = parser.read_statement()? {
                let (stmt_type, _) =
                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

                match stmt_type {
                    StatementType::Insert => {
                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
                        for row in rows {
                            // Index is claimed before FK filtering, matching
                            // pass 2, which counts every row unconditionally.
                            let current_idx = rows_seen;
                            rows_seen += 1;

                            if config.preserve_relations {
                                let unified = UnifiedRow::Insert(row.clone());
                                let (passes, orphan) = check_unified_fk_membership(
                                    &unified,
                                    table_schema,
                                    runtimes,
                                    cyclic_set,
                                    &table_id,
                                );
                                if !passes {
                                    fk_orphans += 1;
                                    if orphan && config.strict_fk {
                                        anyhow::bail!(
                                            "FK integrity violation in table '{}': row references missing parent",
                                            table_name
                                        );
                                    }
                                    continue;
                                }
                            }

                            let pk_hash = row.pk.as_ref().map(hash_pk_tuple);
                            reservoir.consider((current_idx, RowFormat::Insert, pk_hash));
                        }
                    }
                    StatementType::Copy => {
                        let header = String::from_utf8_lossy(&stmt);
                        copy_columns = parse_copy_columns(&header);
                    }
                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                            let rows = parse_postgres_copy_rows(
                                &stmt,
                                table_schema,
                                copy_columns.clone(),
                            )?;
                            for row in rows {
                                let current_idx = rows_seen;
                                rows_seen += 1;

                                if config.preserve_relations {
                                    let unified = UnifiedRow::Copy(row.clone());
                                    let (passes, orphan) = check_unified_fk_membership(
                                        &unified,
                                        table_schema,
                                        runtimes,
                                        cyclic_set,
                                        &table_id,
                                    );
                                    if !passes {
                                        fk_orphans += 1;
                                        if orphan && config.strict_fk {
                                            anyhow::bail!(
                                                "FK integrity violation in table '{}': row references missing parent",
                                                table_name
                                            );
                                        }
                                        continue;
                                    }
                                }

                                let pk_hash = row.pk.as_ref().map(hash_pk_tuple);
                                reservoir.consider((current_idx, RowFormat::Copy, pk_hash));
                            }
                        }
                    }
                    _ => {}
                }
            }

            let selected_items = reservoir.into_items();
            if selected_items.is_empty() {
                return Ok(StreamingSampleResult {
                    rows_seen,
                    rows_selected: 0,
                    fk_orphans,
                    pk_hashes: Vec::new(),
                });
            }

            // Sort indices ascending so pass 2 can match them with a single
            // forward scan of the file.
            let mut selected_indices: Vec<(u64, RowFormat)> =
                Vec::with_capacity(selected_items.len());
            for (idx, format, pk_hash) in selected_items {
                if let Some(h) = pk_hash {
                    selected_pk_hashes.push(h);
                }
                selected_indices.push((idx, format));
            }
            selected_indices.sort_by_key(|(idx, _)| *idx);

            // Pass 2: re-read the file and copy out the chosen rows.
            let file = File::open(table_file)?;
            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
            let mut current_row_idx = 0u64;
            let mut select_iter = selected_indices.iter().peekable();

            temp_writer = Some(BufWriter::new(File::create(&temp_path)?));
            let writer = temp_writer.as_mut().unwrap();

            while let Some(stmt) = parser.read_statement()? {
                // All chosen rows emitted: stop scanning early.
                if select_iter.peek().is_none() {
                    break; }

                let (stmt_type, _) =
                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

                match stmt_type {
                    StatementType::Insert => {
                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
                        for row in rows {
                            if let Some((next_idx, _)) = select_iter.peek() {
                                if current_row_idx == *next_idx {
                                    writer.write_all(&[0u8])?;
                                    writer.write_all(&row.raw)?;
                                    writer.write_all(b"\n")?;
                                    rows_selected += 1;
                                    select_iter.next();
                                }
                            }
                            // Counts every row, including FK orphans, so
                            // indices line up with pass 1.
                            current_row_idx += 1;
                        }
                    }
                    StatementType::Copy => {
                        let header = String::from_utf8_lossy(&stmt);
                        copy_columns = parse_copy_columns(&header);
                    }
                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                            let rows = parse_postgres_copy_rows(
                                &stmt,
                                table_schema,
                                copy_columns.clone(),
                            )?;
                            for row in rows {
                                if let Some((next_idx, _)) = select_iter.peek() {
                                    if current_row_idx == *next_idx {
                                        writer.write_all(&[1u8])?;
                                        writer.write_all(&row.raw)?;
                                        writer.write_all(b"\n")?;
                                        rows_selected += 1;
                                        select_iter.next();
                                    }
                                }
                                current_row_idx += 1;
                            }
                        }
                    }
                    _ => {}
                }
            }
        }
    }

    // Flush explicitly; Drop would swallow any write error.
    if let Some(mut writer) = temp_writer {
        writer.flush()?;
    }

    Ok(StreamingSampleResult {
        rows_seen,
        rows_selected,
        fk_orphans,
        pk_hashes: selected_pk_hashes,
    })
}
909
910fn check_unified_fk_membership(
913 row: &UnifiedRow,
914 table_schema: &crate::schema::TableSchema,
915 runtimes: &AHashMap<TableId, TableRuntime>,
916 cyclic_set: &ahash::AHashSet<TableId>,
917 current_table_id: &TableId,
918) -> (bool, bool) {
919 let mut passes = true;
920 let mut is_orphan = false;
921
922 for (fk_ref, fk_tuple) in row.fk_values() {
923 if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
924 if let Some(parent_id) = fk.referenced_table_id {
925 if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id) {
927 continue;
928 }
929
930 if let Some(parent_runtime) = runtimes.get(&parent_id) {
932 let fk_hash = hash_pk_tuple(fk_tuple);
933 if !parent_runtime.pk_set.contains(&fk_hash) {
934 passes = false;
935 is_orphan = true;
936 break;
937 }
938 }
939 }
940 }
941 }
942
943 (passes, is_orphan)
944}
945
946fn write_output(
948 config: &SampleConfig,
949 _graph: &SchemaGraph,
950 table_order: &[TableId],
951 runtimes: &AHashMap<TableId, TableRuntime>,
952 tables_dir: &Path,
953 stats: &SampleStats,
954) -> anyhow::Result<()> {
955 let mut writer: Box<dyn Write> = match &config.output {
956 Some(path) => {
957 if let Some(parent) = path.parent() {
958 fs::create_dir_all(parent)?;
959 }
960 Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
961 }
962 None => Box::new(BufWriter::new(std::io::stdout())),
963 };
964
965 write_header(&mut writer, config, stats)?;
967
968 write_dialect_header(&mut writer, config.dialect)?;
970
971 if config.include_schema {
973 for &table_id in table_order {
974 let runtime = match runtimes.get(&table_id) {
975 Some(r) if !r.skip && r.rows_selected > 0 => r,
976 _ => continue,
977 };
978
979 let table_file = tables_dir.join(format!("{}.sql", runtime.name));
980 if !table_file.exists() {
981 continue;
982 }
983
984 let file = File::open(&table_file)?;
986 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
987
988 while let Some(stmt) = parser.read_statement()? {
989 let (stmt_type, _) =
990 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
991
992 if stmt_type.is_schema() {
993 writer.write_all(&stmt)?;
994 writer.write_all(b"\n")?;
995 }
996 }
997 }
998 }
999
1000 for &table_id in table_order {
1002 let runtime = match runtimes.get(&table_id) {
1003 Some(r) if !r.skip && r.rows_selected > 0 && r.selected_temp_path.is_some() => r,
1004 _ => continue,
1005 };
1006
1007 let table_name = &runtime.name;
1008 let row_count = runtime.rows_selected;
1009
1010 writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;
1011
1012 let quoted_name = match config.dialect {
1014 SqlDialect::MySql => format!("`{}`", table_name),
1015 SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
1016 SqlDialect::Mssql => format!("[{}]", table_name),
1017 };
1018
1019 let temp_path = runtime.selected_temp_path.as_ref().unwrap();
1021 let temp_file = File::open(temp_path)?;
1022 let reader = std::io::BufReader::new(temp_file);
1023 use std::io::BufRead;
1024
1025 const CHUNK_SIZE: usize = 1000;
1026 let mut chunk_buffer: Vec<(RowFormat, Vec<u8>)> = Vec::with_capacity(CHUNK_SIZE);
1027
1028 for line in reader.lines() {
1029 let line = line?;
1030 if line.is_empty() {
1031 continue;
1032 }
1033
1034 let bytes = line.as_bytes();
1035 if bytes.is_empty() {
1036 continue;
1037 }
1038
1039 let format = if bytes[0] == 0 {
1041 RowFormat::Insert
1042 } else {
1043 RowFormat::Copy
1044 };
1045 let row_bytes = bytes[1..].to_vec();
1046
1047 chunk_buffer.push((format, row_bytes));
1048
1049 if chunk_buffer.len() >= CHUNK_SIZE {
1050 write_insert_chunk(&mut writer, "ed_name, &chunk_buffer, config.dialect)?;
1051 chunk_buffer.clear();
1052 }
1053 }
1054
1055 if !chunk_buffer.is_empty() {
1057 write_insert_chunk(&mut writer, "ed_name, &chunk_buffer, config.dialect)?;
1058 }
1059 }
1060
1061 write_dialect_footer(&mut writer, config.dialect)?;
1063
1064 writer.flush()?;
1065
1066 Ok(())
1067}
1068
1069fn write_header<W: Write>(
1071 writer: &mut W,
1072 config: &SampleConfig,
1073 stats: &SampleStats,
1074) -> std::io::Result<()> {
1075 writeln!(writer, "-- Sampled from: {}", config.input.display())?;
1076 writeln!(
1077 writer,
1078 "-- Date: {}",
1079 chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
1080 )?;
1081 writeln!(
1082 writer,
1083 "-- Mode: {:?}{}",
1084 config.mode,
1085 if config.preserve_relations {
1086 ", preserve-relations"
1087 } else {
1088 ""
1089 }
1090 )?;
1091 writeln!(writer, "-- Seed: {}", config.seed)?;
1092 writeln!(writer, "-- Dialect: {}", config.dialect)?;
1093 writeln!(writer, "--")?;
1094 writeln!(writer, "-- Statistics:")?;
1095 writeln!(writer, "-- Tables sampled: {}", stats.tables_sampled)?;
1096 writeln!(writer, "-- Tables skipped: {}", stats.tables_skipped)?;
1097
1098 let percent = if stats.total_rows_seen > 0 {
1099 (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
1100 } else {
1101 0.0
1102 };
1103 writeln!(
1104 writer,
1105 "-- Total rows: {} (from {} original, {:.1}%)",
1106 stats.total_rows_selected, stats.total_rows_seen, percent
1107 )?;
1108
1109 if stats.fk_orphans_rejected > 0 {
1110 writeln!(
1111 writer,
1112 "-- FK orphans rejected: {}",
1113 stats.fk_orphans_rejected
1114 )?;
1115 }
1116
1117 if !stats.warnings.is_empty() {
1118 writeln!(writer, "-- Warnings: {}", stats.warnings.len())?;
1119 }
1120
1121 writeln!(writer)?;
1122
1123 Ok(())
1124}
1125
1126fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1128 match dialect {
1129 SqlDialect::MySql => {
1130 writeln!(writer, "SET NAMES utf8mb4;")?;
1131 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
1132 }
1133 SqlDialect::Postgres => {
1134 writeln!(writer, "SET client_encoding = 'UTF8';")?;
1135 writeln!(writer, "SET session_replication_role = replica;")?;
1136 }
1137 SqlDialect::Sqlite => {
1138 writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
1139 }
1140 SqlDialect::Mssql => {
1141 writeln!(writer, "SET ANSI_NULLS ON;")?;
1142 writeln!(writer, "SET QUOTED_IDENTIFIER ON;")?;
1143 writeln!(writer, "SET NOCOUNT ON;")?;
1144 }
1145 }
1146 writeln!(writer)?;
1147 Ok(())
1148}
1149
1150fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1152 writeln!(writer)?;
1153 match dialect {
1154 SqlDialect::MySql => {
1155 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
1156 }
1157 SqlDialect::Postgres => {
1158 writeln!(writer, "SET session_replication_role = DEFAULT;")?;
1159 }
1160 SqlDialect::Sqlite => {
1161 writeln!(writer, "PRAGMA foreign_keys = ON;")?;
1162 }
1163 SqlDialect::Mssql => {
1164 }
1166 }
1167 Ok(())
1168}
1169
1170fn write_insert_chunk<W: Write>(
1172 writer: &mut W,
1173 quoted_name: &str,
1174 chunk: &[(RowFormat, Vec<u8>)],
1175 dialect: SqlDialect,
1176) -> std::io::Result<()> {
1177 writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;
1178
1179 for (i, (format, row_bytes)) in chunk.iter().enumerate() {
1180 if i > 0 {
1181 writer.write_all(b",\n")?;
1182 }
1183
1184 let values = match format {
1185 RowFormat::Insert => match dialect {
1186 SqlDialect::Postgres => convert_row_to_postgres(row_bytes),
1187 _ => row_bytes.clone(),
1188 },
1189 RowFormat::Copy => convert_copy_to_insert_values(row_bytes, dialect),
1190 };
1191 writer.write_all(&values)?;
1192 }
1193
1194 writer.write_all(b";\n")?;
1195 Ok(())
1196}
1197
/// Rewrite a MySQL-quoted row tuple for Postgres by converting the
/// backslash-escaped quote `\'` into the SQL-standard doubled quote `''`.
///
/// Escaped backslashes (`\\`) are copied through as an opaque pair so that a
/// quote following them is correctly treated as a string delimiter rather
/// than being doubled (the previous version misread `\\` + `'` as `\\` then
/// `\'` and corrupted the terminating quote).
///
/// NOTE(review): other MySQL escapes (`\n`, `\t`, ...) are passed through
/// unchanged — presumably acceptable for the dumps handled here; confirm.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(row.len());
    let mut i = 0;

    while i < row.len() {
        if row[i] == b'\\' && i + 1 < row.len() {
            match row[i + 1] {
                // \' -> '' (standard-conforming Postgres quoting).
                b'\'' => {
                    result.push(b'\'');
                    result.push(b'\'');
                    i += 2;
                }
                // \\ -> copied verbatim as a pair, so the byte after it is
                // never misinterpreted as the start of an escape.
                b'\\' => {
                    result.push(b'\\');
                    result.push(b'\\');
                    i += 2;
                }
                _ => {
                    result.push(row[i]);
                    i += 1;
                }
            }
        } else {
            result.push(row[i]);
            i += 1;
        }
    }

    result
}
1219
1220fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
1222 let mut result = Vec::with_capacity(row.len() + 20);
1223 result.push(b'(');
1224
1225 let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
1226
1227 for (i, field) in fields.iter().enumerate() {
1228 if i > 0 {
1229 result.extend_from_slice(b", ");
1230 }
1231
1232 if *field == b"\\N" {
1234 result.extend_from_slice(b"NULL");
1235 } else if field.is_empty() {
1236 match dialect {
1238 SqlDialect::MySql => result.extend_from_slice(b"''"),
1239 SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {
1240 result.extend_from_slice(b"''")
1241 }
1242 }
1243 } else if is_numeric(field) {
1244 result.extend_from_slice(field);
1246 } else {
1247 result.push(b'\'');
1249 for &b in *field {
1250 match b {
1251 b'\'' => {
1252 match dialect {
1254 SqlDialect::MySql => result.extend_from_slice(b"\\'"),
1255 SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {
1256 result.extend_from_slice(b"''")
1257 }
1258 }
1259 }
1260 b'\\' if dialect == SqlDialect::MySql => {
1261 result.extend_from_slice(b"\\\\");
1263 }
1264 _ => result.push(b),
1265 }
1266 }
1267 result.push(b'\'');
1268 }
1269 }
1270
1271 result.push(b')');
1272 result
1273}
1274
/// Whether a COPY field can be emitted as an unquoted SQL numeric literal:
/// optional sign, digits with at most one decimal point, optionally followed
/// by an exponent (`e`/`E`, optional sign, at least one digit).
///
/// Fixes the previous version, which accepted malformed strings such as
/// `1e2e3` or a trailing bare `e` (`1e`) — those would have been emitted
/// unquoted, producing invalid SQL — and rejected the valid form `1e+5`.
fn is_numeric(s: &[u8]) -> bool {
    let mut i = 0;

    // Optional leading sign.
    if matches!(s.first(), Some(&b'+') | Some(&b'-')) {
        i = 1;
    }

    // Mantissa: digits with at most one '.'.
    let mut has_digit = false;
    let mut has_dot = false;
    while i < s.len() {
        match s[i] {
            b'0'..=b'9' => has_digit = true,
            b'.' if !has_dot => has_dot = true,
            _ => break,
        }
        i += 1;
    }
    if !has_digit {
        return false;
    }
    if i == s.len() {
        return true;
    }

    // Exponent: 'e'/'E', optional sign, then one or more digits and nothing else.
    if s[i] != b'e' && s[i] != b'E' {
        return false;
    }
    i += 1;
    if i < s.len() && (s[i] == b'+' || s[i] == b'-') {
        i += 1;
    }
    i < s.len() && s[i..].iter().all(|b| b.is_ascii_digit())
}