1mod config;
9mod reservoir;
10
11pub use config::{DefaultClassifier, GlobalTableMode, SampleYamlConfig, TableClassification};
12pub use reservoir::Reservoir;
13
14use crate::parser::mysql_insert::{hash_pk_tuple, parse_mysql_insert_rows, ParsedRow, PkHashSet};
15use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
16use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
17use crate::schema::{SchemaBuilder, SchemaGraph, TableId};
18use crate::splitter::Splitter;
19use ahash::AHashMap;
20use indicatif::{ProgressBar, ProgressStyle};
21use rand::rngs::StdRng;
22use rand::{Rng, SeedableRng};
23use std::fs::{self, File};
24use std::io::{BufWriter, Write};
25use std::path::{Path, PathBuf};
26use tempfile::TempDir;
27
/// How many rows to keep per table.
#[derive(Debug, Clone, Copy)]
pub enum SampleMode {
    /// Keep roughly this percentage of rows (0-100), decided per row at random.
    Percent(u32),
    /// Keep at most this many rows, chosen via reservoir sampling.
    Rows(usize),
}
36
/// Options controlling a sampling run (consumed by [`run`]).
#[derive(Debug)]
pub struct SampleConfig {
    /// Path of the SQL dump to sample from.
    pub input: PathBuf,
    /// Output file; `None` writes to stdout.
    pub output: Option<PathBuf>,
    /// SQL dialect used both for parsing the dump and for the emitted statements.
    pub dialect: SqlDialect,
    /// Global sampling mode (may be overridden per table via the YAML config).
    pub mode: SampleMode,
    /// When true, rows whose FK parents were not selected are rejected.
    pub preserve_relations: bool,
    /// If set, only tables on this list are sampled (case-insensitive).
    pub tables_filter: Option<Vec<String>>,
    /// Tables to exclude entirely (case-insensitive).
    pub exclude: Vec<String>,
    /// Tables to force-classify as roots (case-insensitive).
    pub root_tables: Vec<String>,
    /// How lookup/global tables are handled.
    pub include_global: GlobalTableMode,
    /// RNG seed; the same seed and input reproduce the same sample.
    pub seed: u64,
    /// Collect statistics only; skip writing output.
    pub dry_run: bool,
    /// Emit progress bars / status messages to stderr.
    pub progress: bool,
    /// Optional YAML config with per-table overrides.
    pub config_file: Option<PathBuf>,
    /// Stop sampling once adding a table would push the total selected rows
    /// past this limit.
    pub max_total_rows: Option<usize>,
    /// Fail hard on FK integrity violations instead of dropping the row.
    pub strict_fk: bool,
    /// Include schema (CREATE/ALTER) statements in the output.
    pub include_schema: bool,
}
73
impl Default for SampleConfig {
    /// Defaults: 10% random sample, MySQL dialect, a freshly randomized seed,
    /// lookup tables included, schema statements emitted.
    fn default() -> Self {
        Self {
            input: PathBuf::new(),
            output: None,
            dialect: SqlDialect::MySql,
            mode: SampleMode::Percent(10),
            preserve_relations: false,
            tables_filter: None,
            exclude: Vec::new(),
            root_tables: Vec::new(),
            include_global: GlobalTableMode::Lookups,
            // Random by default; callers set an explicit seed for reproducibility.
            seed: rand::random(),
            dry_run: false,
            progress: false,
            config_file: None,
            max_total_rows: None,
            strict_fk: false,
            include_schema: true,
        }
    }
}
96
/// Serializable aggregate results of a sampling run.
#[derive(Debug, Default, serde::Serialize)]
pub struct SampleStats {
    /// Number of tables that contributed an entry to `table_stats`.
    pub tables_sampled: usize,
    /// Tables skipped by filters, classification, or global-table mode.
    pub tables_skipped: usize,
    /// Total rows kept across all tables.
    pub total_rows_selected: u64,
    /// Total rows encountered across all tables.
    pub total_rows_seen: u64,
    /// Per-table breakdown.
    pub table_stats: Vec<TableSampleStats>,
    /// Human-readable warnings (FK cycles, row-cap hit, ...).
    pub warnings: Vec<String>,
    /// Rows dropped because a referenced parent row was not selected.
    pub fk_orphans_rejected: u64,
}
115
/// Per-table row counts reported in [`SampleStats::table_stats`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct TableSampleStats {
    /// Table name.
    pub name: String,
    /// Rows encountered in the dump for this table.
    pub rows_seen: u64,
    /// Rows kept by the sampler.
    pub rows_selected: u64,
    /// Classification that drove the sampling decision.
    pub classification: TableClassification,
}
124
/// Mutable per-table state threaded through the sampling pass.
struct TableRuntime {
    /// Table name (used for file lookups and reporting).
    name: String,
    /// Hashes of the PK tuples selected so far; consulted by child tables
    /// for FK membership checks when `preserve_relations` is on.
    pk_set: PkHashSet,
    /// Rows encountered for this table.
    rows_seen: u64,
    /// Rows kept for this table.
    rows_selected: u64,
    /// True when the table is excluded from sampling entirely.
    skip: bool,
    /// How the table was classified (root / lookup / system / normal).
    classification: TableClassification,
    /// Rows rejected because a referenced parent row was missing.
    fk_orphans: u64,
    /// Temp file holding the selected raw rows, if any were written.
    selected_temp_path: Option<PathBuf>,
}
144
/// A parsed data row from either source syntax, so PK/FK handling can be
/// shared between MySQL INSERT statements and Postgres COPY blocks.
enum UnifiedRow {
    /// Row parsed from an INSERT statement.
    Insert(ParsedRow),
    /// Row parsed from a COPY data block.
    Copy(ParsedCopyRow),
}
150
/// Which source syntax a stored raw row uses; persisted as a one-byte tag
/// at the start of each line in the per-table temp files.
#[derive(Debug, Clone, Copy, PartialEq)]
enum RowFormat {
    /// Raw value tuple captured from an INSERT statement (tag byte 0).
    Insert,
    /// Raw tab-separated line captured from a COPY block (tag byte 1).
    Copy,
}
157
impl UnifiedRow {
    /// Primary-key tuple of the row, if the table has one.
    // NOTE(review): appears unused within this file — possibly used elsewhere
    // or kept for symmetry with `fk_values`.
    fn pk(&self) -> Option<&smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>> {
        match self {
            UnifiedRow::Insert(r) => r.pk.as_ref(),
            UnifiedRow::Copy(r) => r.pk.as_ref(),
        }
    }

    /// Foreign-key values of the row: one `(FkRef, value tuple)` entry per
    /// foreign key present in the row.
    fn fk_values(
        &self,
    ) -> &[(
        crate::parser::mysql_insert::FkRef,
        smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>,
    )] {
        match self {
            UnifiedRow::Insert(r) => &r.fk_values,
            UnifiedRow::Copy(r) => &r.fk_values,
        }
    }
}
178
/// Execute a complete sampling run.
///
/// Pipeline: (1) split the monolithic dump into one temp file per table,
/// (2) build the FK schema graph from the DDL, (3) sample each table in
/// dependency order (parents before children, cyclic tables last) so FK
/// membership checks can see the parents' selected PKs, and (4) unless
/// `dry_run` is set, write the combined output.
///
/// Returns the accumulated [`SampleStats`]; errors propagate from I/O,
/// parsing, or (with `strict_fk`) FK integrity violations.
pub fn run(config: SampleConfig) -> anyhow::Result<SampleStats> {
    // Optional YAML file with per-table overrides (skip lists, rates, classes).
    let yaml_config = if let Some(ref path) = config.config_file {
        Some(SampleYamlConfig::load(path)?)
    } else {
        None
    };

    // Seeded RNG: identical seed + input reproduce the same sample.
    let mut rng = StdRng::seed_from_u64(config.seed);
    let mut stats = SampleStats::default();

    // Input size drives the byte-position progress bar for the split phase.
    let file_size = std::fs::metadata(&config.input)?.len();

    let progress_bar = if config.progress {
        let pb = ProgressBar::new(file_size);
        pb.set_style(
            ProgressStyle::with_template(
                "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%) {msg}",
            )
            .unwrap()
            .progress_chars("█▓▒░ ")
            .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
        );
        pb.enable_steady_tick(std::time::Duration::from_millis(100));
        pb.set_message("Splitting dump...");
        Some(pb)
    } else {
        None
    };

    // Phase 1: split the dump into per-table .sql files under a temp dir.
    let temp_dir = TempDir::new()?;
    let tables_dir = temp_dir.path().join("tables");

    let mut splitter = Splitter::new(config.input.clone(), tables_dir.clone())
        .with_dialect(config.dialect)
        .with_content_filter(ContentFilter::All);

    if let Some(ref pb) = progress_bar {
        let pb_clone = pb.clone();
        splitter = splitter.with_progress(move |bytes| {
            pb_clone.set_position(bytes);
        });
    }

    let split_stats = splitter.split()?;

    if let Some(ref pb) = progress_bar {
        pb.finish_and_clear();
    }

    if config.progress {
        eprintln!(
            "Split complete: {} tables, {} statements",
            split_stats.tables_found, split_stats.statements_processed
        );
    }

    if config.progress {
        eprintln!("Building schema graph...");
    }

    // Phase 2: schema graph and processing order (topological by FK edges).
    let graph = build_schema_graph(&tables_dir, &config)?;

    let (topo_order, cyclic_tables) = graph.processing_order();

    // FK edges inside a cycle cannot be enforced in a single pass; warn once.
    if !cyclic_tables.is_empty() {
        let names: Vec<_> = cyclic_tables
            .iter()
            .filter_map(|&id| graph.table_name(id))
            .collect();
        let msg = format!(
            "Warning: {} tables have FK cycles (intra-cycle FK enforcement disabled): {:?}",
            cyclic_tables.len(),
            names
        );
        if config.progress {
            eprintln!("{}", msg);
        }
        stats.warnings.push(msg);
    }

    let cyclic_set: ahash::AHashSet<TableId> = cyclic_tables.iter().copied().collect();

    // CLI --root-tables, lowercased for case-insensitive matching.
    let explicit_roots: ahash::AHashSet<String> = config
        .root_tables
        .iter()
        .map(|s| s.to_lowercase())
        .collect();

    // Initialize per-table runtime state (classification + skip decision).
    let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
    for table in graph.schema.iter() {
        let classification =
            determine_classification(&table.name, &graph, table.id, &yaml_config, &explicit_roots);
        let skip = should_skip_table(&table.name, &config, &yaml_config, classification);

        runtimes.insert(
            table.id,
            TableRuntime {
                name: table.name.clone(),
                pk_set: PkHashSet::default(),
                rows_seen: 0,
                rows_selected: 0,
                skip,
                classification,
                fk_orphans: 0,
                selected_temp_path: None,
            },
        );
    }

    // Selected rows are staged here as `<table>.rows` files.
    let selected_dir = temp_dir.path().join("selected");
    fs::create_dir_all(&selected_dir)?;

    if config.progress {
        eprintln!(
            "Sampling {} tables in dependency order...",
            topo_order.len()
        );
    }

    // Cyclic tables are processed after all acyclic ones.
    let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();

    let mut total_selected: u64 = 0;

    // Phase 3: sample each table.
    for table_id in &all_tables {
        let table_schema = match graph.schema.table(*table_id) {
            Some(s) => s,
            None => continue,
        };

        // Copy out what we need so `runtimes` can be borrowed immutably by
        // the sampler below (child tables read parents' pk_sets).
        let (should_skip, table_name, classification) = {
            let runtime = match runtimes.get(table_id) {
                Some(r) => r,
                None => continue,
            };
            (runtime.skip, runtime.name.clone(), runtime.classification)
        };

        if should_skip {
            stats.tables_skipped += 1;
            continue;
        }

        // Classification decides the effective mode: lookups are copied in
        // full (or skipped), system tables are always skipped.
        let sample_mode = match classification {
            TableClassification::Lookup => {
                match config.include_global {
                    GlobalTableMode::None => {
                        stats.tables_skipped += 1;
                        continue;
                    }
                    GlobalTableMode::Lookups | GlobalTableMode::All => {
                        SampleMode::Percent(100)
                    }
                }
            }
            TableClassification::System => {
                stats.tables_skipped += 1;
                continue;
            }
            _ => get_table_sample_mode(&table_name, &config, &yaml_config),
        };

        let table_file = tables_dir.join(format!("{}.sql", table_name));
        if !table_file.exists() {
            continue;
        }

        let result = sample_table_streaming(
            &table_file,
            table_schema,
            *table_id,
            &table_name,
            sample_mode,
            &config,
            &runtimes,
            &cyclic_set,
            &selected_dir,
            &mut rng,
        )?;

        // Row cap: stop before this table's rows would exceed the limit.
        // NOTE(review): the already-sampled result for this table is
        // discarded (never recorded), so the cap is a hard ceiling.
        if let Some(max) = config.max_total_rows {
            if total_selected + result.rows_selected > max as u64 {
                let msg = format!(
                    "Warning: Reached max_total_rows limit ({}) at table '{}'",
                    max, table_name
                );
                stats.warnings.push(msg);
                break;
            }
        }

        total_selected += result.rows_selected;

        // Record results; the PK hashes feed later FK checks by child tables.
        let runtime = runtimes
            .get_mut(table_id)
            .expect("runtime must exist - checked at loop start");
        runtime.rows_seen = result.rows_seen;
        runtime.rows_selected = result.rows_selected;
        runtime.fk_orphans = result.fk_orphans;

        for pk_hash in result.pk_hashes {
            runtime.pk_set.insert(pk_hash);
        }

        if result.rows_selected > 0 {
            let temp_path = selected_dir.join(format!("{}.rows", table_name));
            if temp_path.exists() {
                runtime.selected_temp_path = Some(temp_path);
            }
        }

        stats.fk_orphans_rejected += result.fk_orphans;

        stats.table_stats.push(TableSampleStats {
            name: runtime.name.clone(),
            rows_seen: result.rows_seen,
            rows_selected: result.rows_selected,
            classification: runtime.classification,
        });
    }

    // Roll up totals from the per-table stats.
    for table_stats in &stats.table_stats {
        stats.total_rows_seen += table_stats.rows_seen;
        stats.total_rows_selected += table_stats.rows_selected;
    }
    stats.tables_sampled = stats.table_stats.len();

    if config.progress {
        eprintln!("Sampling complete");
    }

    if config.dry_run {
        return Ok(stats);
    }

    if config.progress {
        eprintln!("Writing output...");
    }

    // Phase 4: assemble the final sampled dump.
    write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;

    Ok(stats)
}
446
447fn build_schema_graph(tables_dir: &Path, config: &SampleConfig) -> anyhow::Result<SchemaGraph> {
449 let mut builder = SchemaBuilder::new();
450
451 for entry in fs::read_dir(tables_dir)? {
452 let entry = entry?;
453 let path = entry.path();
454
455 if path.extension().map(|e| e == "sql").unwrap_or(false) {
456 let file = File::open(&path)?;
457 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
458
459 while let Some(stmt) = parser.read_statement()? {
460 let stmt_str = String::from_utf8_lossy(&stmt);
461 let (stmt_type, _) =
462 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
463
464 match stmt_type {
465 StatementType::CreateTable => {
466 builder.parse_create_table(&stmt_str);
467 }
468 StatementType::AlterTable => {
469 builder.parse_alter_table(&stmt_str);
470 }
471 _ => {}
472 }
473 }
474 }
475 }
476
477 Ok(SchemaGraph::from_schema(builder.build()))
478}
479
480fn determine_classification(
482 name: &str,
483 graph: &SchemaGraph,
484 table_id: TableId,
485 yaml_config: &Option<SampleYamlConfig>,
486 explicit_roots: &ahash::AHashSet<String>,
487) -> TableClassification {
488 if explicit_roots.contains(&name.to_lowercase()) {
490 return TableClassification::Root;
491 }
492
493 if let Some(ref config) = yaml_config {
495 let class = config.get_classification(name);
496 if class != TableClassification::Normal {
497 return class;
498 }
499 }
500
501 if graph.parents[table_id.0 as usize].is_empty() {
503 return TableClassification::Root;
504 }
505
506 DefaultClassifier::classify(name)
508}
509
510fn should_skip_table(
512 name: &str,
513 config: &SampleConfig,
514 yaml_config: &Option<SampleYamlConfig>,
515 classification: TableClassification,
516) -> bool {
517 let name_lower = name.to_lowercase();
518
519 if config
521 .exclude
522 .iter()
523 .any(|e| e.to_lowercase() == name_lower)
524 {
525 return true;
526 }
527
528 if let Some(ref yc) = yaml_config {
530 if yc.should_skip(name) {
531 return true;
532 }
533 }
534
535 if let Some(ref filter) = config.tables_filter {
537 if !filter.iter().any(|f| f.to_lowercase() == name_lower) {
538 return true;
539 }
540 }
541
542 if classification == TableClassification::System {
544 return true;
545 }
546
547 false
548}
549
550fn get_table_sample_mode(
552 name: &str,
553 config: &SampleConfig,
554 yaml_config: &Option<SampleYamlConfig>,
555) -> SampleMode {
556 if let Some(ref yc) = yaml_config {
558 if let Some(rows) = yc.get_rows(name) {
559 return SampleMode::Rows(rows);
560 }
561 if let Some(percent) = yc.get_percent(name) {
562 return SampleMode::Percent(percent);
563 }
564 }
565
566 config.mode
568}
569
/// Outcome of sampling a single table file.
struct StreamingSampleResult {
    /// Rows encountered in the table file.
    rows_seen: u64,
    /// Rows written to the per-table temp file.
    rows_selected: u64,
    /// Rows rejected by the FK membership check.
    fk_orphans: u64,
    /// PK hashes of the selected rows (merged into the table's `pk_set`).
    pk_hashes: Vec<u64>,
}
578
/// Sample one table's file, writing the selected raw rows to
/// `<selected_dir>/<table>.rows` (one line per row: a 1-byte format tag —
/// 0 = INSERT tuple, 1 = COPY line — followed by the raw row bytes).
///
/// `Percent` mode is a single streaming pass with an independent Bernoulli
/// draw per row. `Rows` mode is two passes: a reservoir over row indices,
/// then a re-read extracting the chosen rows in order.
///
/// With `preserve_relations`, each row's FK values are checked against the
/// already-sampled parents' PK sets (`runtimes`); failing rows are dropped
/// (or, with `strict_fk`, abort the run). FK edges between two cyclic
/// tables are not enforced.
#[allow(clippy::too_many_arguments)]
fn sample_table_streaming(
    table_file: &Path,
    table_schema: &crate::schema::TableSchema,
    table_id: TableId,
    table_name: &str,
    sample_mode: SampleMode,
    config: &SampleConfig,
    runtimes: &AHashMap<TableId, TableRuntime>,
    cyclic_set: &ahash::AHashSet<TableId>,
    selected_dir: &Path,
    rng: &mut StdRng,
) -> anyhow::Result<StreamingSampleResult> {
    let mut rows_seen = 0u64;
    let mut rows_selected = 0u64;
    let mut fk_orphans = 0u64;

    // Created lazily on the first selected row so empty selections leave
    // no temp file behind.
    let temp_path = selected_dir.join(format!("{}.rows", table_name));
    let mut temp_writer: Option<BufWriter<File>> = None;

    let mut selected_pk_hashes: Vec<u64> = Vec::new();

    // Column list from the most recent COPY header; applies to the data
    // block that follows it.
    let mut copy_columns: Vec<String> = Vec::new();

    match sample_mode {
        SampleMode::Percent(p) => {
            let prob = p as f64 / 100.0;

            let file = File::open(table_file)?;
            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);

            while let Some(stmt) = parser.read_statement()? {
                let (stmt_type, _) =
                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

                match stmt_type {
                    StatementType::Insert => {
                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
                        for row in rows {
                            rows_seen += 1;

                            // FK membership gate (before the random draw, so
                            // orphans never count as selected).
                            if config.preserve_relations {
                                let unified = UnifiedRow::Insert(row.clone());
                                let (passes, orphan) = check_unified_fk_membership(
                                    &unified,
                                    table_schema,
                                    runtimes,
                                    cyclic_set,
                                    &table_id,
                                );
                                if !passes {
                                    fk_orphans += 1;
                                    if orphan && config.strict_fk {
                                        anyhow::bail!(
                                            "FK integrity violation in table '{}': row references missing parent",
                                            table_name
                                        );
                                    }
                                    continue;
                                }
                            }

                            // Independent Bernoulli draw per row.
                            if rng.random::<f64>() < prob {
                                if temp_writer.is_none() {
                                    temp_writer = Some(BufWriter::new(File::create(&temp_path)?));
                                }
                                let writer = temp_writer.as_mut().unwrap();
                                // Tag byte 0 = INSERT-format row.
                                writer.write_all(&[0u8])?;
                                writer.write_all(&row.raw)?;
                                writer.write_all(b"\n")?;

                                if let Some(pk) = &row.pk {
                                    selected_pk_hashes.push(hash_pk_tuple(pk));
                                }
                                rows_selected += 1;
                            }
                        }
                    }
                    StatementType::Copy => {
                        // COPY header: remember the column order for the
                        // data block that follows.
                        let header = String::from_utf8_lossy(&stmt);
                        copy_columns = parse_copy_columns(&header);
                    }
                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
                        // A COPY data block ends with the `\.` terminator.
                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                            let rows = parse_postgres_copy_rows(
                                &stmt,
                                table_schema,
                                copy_columns.clone(),
                            )?;
                            for row in rows {
                                rows_seen += 1;

                                if config.preserve_relations {
                                    let unified = UnifiedRow::Copy(row.clone());
                                    let (passes, orphan) = check_unified_fk_membership(
                                        &unified,
                                        table_schema,
                                        runtimes,
                                        cyclic_set,
                                        &table_id,
                                    );
                                    if !passes {
                                        fk_orphans += 1;
                                        if orphan && config.strict_fk {
                                            anyhow::bail!(
                                                "FK integrity violation in table '{}': row references missing parent",
                                                table_name
                                            );
                                        }
                                        continue;
                                    }
                                }

                                if rng.random::<f64>() < prob {
                                    if temp_writer.is_none() {
                                        temp_writer =
                                            Some(BufWriter::new(File::create(&temp_path)?));
                                    }
                                    let writer = temp_writer.as_mut().unwrap();
                                    // Tag byte 1 = COPY-format row.
                                    writer.write_all(&[1u8])?;
                                    writer.write_all(&row.raw)?;
                                    writer.write_all(b"\n")?;

                                    if let Some(pk) = &row.pk {
                                        selected_pk_hashes.push(hash_pk_tuple(pk));
                                    }
                                    rows_selected += 1;
                                }
                            }
                        }
                    }
                    _ => {}
                }
            }
        }
        SampleMode::Rows(n) => {
            // Pass 1: reservoir-sample row *indices* (plus format and PK
            // hash) so at most `n` rows survive without buffering row bytes.
            let mut reservoir: Reservoir<(u64, RowFormat, Option<u64>)> =
                Reservoir::new(n, StdRng::from_rng(&mut *rng));

            let file = File::open(table_file)?;
            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);

            while let Some(stmt) = parser.read_statement()? {
                let (stmt_type, _) =
                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

                match stmt_type {
                    StatementType::Insert => {
                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
                        for row in rows {
                            let current_idx = rows_seen;
                            rows_seen += 1;

                            // Same FK gate as in Percent mode.
                            if config.preserve_relations {
                                let unified = UnifiedRow::Insert(row.clone());
                                let (passes, orphan) = check_unified_fk_membership(
                                    &unified,
                                    table_schema,
                                    runtimes,
                                    cyclic_set,
                                    &table_id,
                                );
                                if !passes {
                                    fk_orphans += 1;
                                    if orphan && config.strict_fk {
                                        anyhow::bail!(
                                            "FK integrity violation in table '{}': row references missing parent",
                                            table_name
                                        );
                                    }
                                    continue;
                                }
                            }

                            let pk_hash = row.pk.as_ref().map(hash_pk_tuple);
                            reservoir.consider((current_idx, RowFormat::Insert, pk_hash));
                        }
                    }
                    StatementType::Copy => {
                        let header = String::from_utf8_lossy(&stmt);
                        copy_columns = parse_copy_columns(&header);
                    }
                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                            let rows = parse_postgres_copy_rows(
                                &stmt,
                                table_schema,
                                copy_columns.clone(),
                            )?;
                            for row in rows {
                                let current_idx = rows_seen;
                                rows_seen += 1;

                                if config.preserve_relations {
                                    let unified = UnifiedRow::Copy(row.clone());
                                    let (passes, orphan) = check_unified_fk_membership(
                                        &unified,
                                        table_schema,
                                        runtimes,
                                        cyclic_set,
                                        &table_id,
                                    );
                                    if !passes {
                                        fk_orphans += 1;
                                        if orphan && config.strict_fk {
                                            anyhow::bail!(
                                                "FK integrity violation in table '{}': row references missing parent",
                                                table_name
                                            );
                                        }
                                        continue;
                                    }
                                }

                                let pk_hash = row.pk.as_ref().map(hash_pk_tuple);
                                reservoir.consider((current_idx, RowFormat::Copy, pk_hash));
                            }
                        }
                    }
                    _ => {}
                }
            }

            let selected_items = reservoir.into_items();
            if selected_items.is_empty() {
                return Ok(StreamingSampleResult {
                    rows_seen,
                    rows_selected: 0,
                    fk_orphans,
                    pk_hashes: Vec::new(),
                });
            }

            // Sort indices ascending so pass 2 can match them in stream order.
            let mut selected_indices: Vec<(u64, RowFormat)> =
                Vec::with_capacity(selected_items.len());
            for (idx, format, pk_hash) in selected_items {
                if let Some(h) = pk_hash {
                    selected_pk_hashes.push(h);
                }
                selected_indices.push((idx, format));
            }
            selected_indices.sort_by_key(|(idx, _)| *idx);

            // Pass 2: re-read the file and copy out exactly the chosen rows.
            // Note: FK-rejected rows were still counted in `rows_seen`, so
            // indices line up with this unfiltered re-enumeration only for
            // rows that passed; rejected rows simply advance the counter.
            let file = File::open(table_file)?;
            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
            let mut current_row_idx = 0u64;
            let mut select_iter = selected_indices.iter().peekable();

            temp_writer = Some(BufWriter::new(File::create(&temp_path)?));
            let writer = temp_writer.as_mut().unwrap();

            while let Some(stmt) = parser.read_statement()? {
                // All selected rows emitted — stop reading early.
                if select_iter.peek().is_none() {
                    break; }

                let (stmt_type, _) =
                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

                match stmt_type {
                    StatementType::Insert => {
                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
                        for row in rows {
                            if let Some((next_idx, _)) = select_iter.peek() {
                                if current_row_idx == *next_idx {
                                    writer.write_all(&[0u8])?;
                                    writer.write_all(&row.raw)?;
                                    writer.write_all(b"\n")?;
                                    rows_selected += 1;
                                    select_iter.next();
                                }
                            }
                            current_row_idx += 1;
                        }
                    }
                    StatementType::Copy => {
                        let header = String::from_utf8_lossy(&stmt);
                        copy_columns = parse_copy_columns(&header);
                    }
                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                            let rows = parse_postgres_copy_rows(
                                &stmt,
                                table_schema,
                                copy_columns.clone(),
                            )?;
                            for row in rows {
                                if let Some((next_idx, _)) = select_iter.peek() {
                                    if current_row_idx == *next_idx {
                                        writer.write_all(&[1u8])?;
                                        writer.write_all(&row.raw)?;
                                        writer.write_all(b"\n")?;
                                        rows_selected += 1;
                                        select_iter.next();
                                    }
                                }
                                current_row_idx += 1;
                            }
                        }
                    }
                    _ => {}
                }
            }
        }
    }

    // Flush explicitly: Drop would swallow write errors.
    if let Some(mut writer) = temp_writer {
        writer.flush()?;
    }

    Ok(StreamingSampleResult {
        rows_seen,
        rows_selected,
        fk_orphans,
        pk_hashes: selected_pk_hashes,
    })
}
916
917fn check_unified_fk_membership(
920 row: &UnifiedRow,
921 table_schema: &crate::schema::TableSchema,
922 runtimes: &AHashMap<TableId, TableRuntime>,
923 cyclic_set: &ahash::AHashSet<TableId>,
924 current_table_id: &TableId,
925) -> (bool, bool) {
926 let mut passes = true;
927 let mut is_orphan = false;
928
929 for (fk_ref, fk_tuple) in row.fk_values() {
930 if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
931 if let Some(parent_id) = fk.referenced_table_id {
932 if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id) {
934 continue;
935 }
936
937 if let Some(parent_runtime) = runtimes.get(&parent_id) {
939 let fk_hash = hash_pk_tuple(fk_tuple);
940 if !parent_runtime.pk_set.contains(&fk_hash) {
941 passes = false;
942 is_orphan = true;
943 break;
944 }
945 }
946 }
947 }
948 }
949
950 (passes, is_orphan)
951}
952
953fn write_output(
955 config: &SampleConfig,
956 _graph: &SchemaGraph,
957 table_order: &[TableId],
958 runtimes: &AHashMap<TableId, TableRuntime>,
959 tables_dir: &Path,
960 stats: &SampleStats,
961) -> anyhow::Result<()> {
962 let mut writer: Box<dyn Write> = match &config.output {
963 Some(path) => {
964 if let Some(parent) = path.parent() {
965 fs::create_dir_all(parent)?;
966 }
967 Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
968 }
969 None => Box::new(BufWriter::new(std::io::stdout())),
970 };
971
972 write_header(&mut writer, config, stats)?;
974
975 write_dialect_header(&mut writer, config.dialect)?;
977
978 if config.include_schema {
980 for &table_id in table_order {
981 let runtime = match runtimes.get(&table_id) {
982 Some(r) if !r.skip && r.rows_selected > 0 => r,
983 _ => continue,
984 };
985
986 let table_file = tables_dir.join(format!("{}.sql", runtime.name));
987 if !table_file.exists() {
988 continue;
989 }
990
991 let file = File::open(&table_file)?;
993 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
994
995 while let Some(stmt) = parser.read_statement()? {
996 let (stmt_type, _) =
997 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
998
999 if stmt_type.is_schema() {
1000 writer.write_all(&stmt)?;
1001 writer.write_all(b"\n")?;
1002 }
1003 }
1004 }
1005 }
1006
1007 for &table_id in table_order {
1009 let runtime = match runtimes.get(&table_id) {
1010 Some(r) if !r.skip && r.rows_selected > 0 && r.selected_temp_path.is_some() => r,
1011 _ => continue,
1012 };
1013
1014 let table_name = &runtime.name;
1015 let row_count = runtime.rows_selected;
1016
1017 writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;
1018
1019 let quoted_name = match config.dialect {
1021 SqlDialect::MySql => format!("`{}`", table_name),
1022 SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
1023 SqlDialect::Mssql => format!("[{}]", table_name),
1024 };
1025
1026 let temp_path = runtime.selected_temp_path.as_ref().unwrap();
1028 let temp_file = File::open(temp_path)?;
1029 let reader = std::io::BufReader::new(temp_file);
1030 use std::io::BufRead;
1031
1032 const CHUNK_SIZE: usize = 1000;
1033 let mut chunk_buffer: Vec<(RowFormat, Vec<u8>)> = Vec::with_capacity(CHUNK_SIZE);
1034
1035 for line in reader.lines() {
1036 let line = line?;
1037 if line.is_empty() {
1038 continue;
1039 }
1040
1041 let bytes = line.as_bytes();
1042 if bytes.is_empty() {
1043 continue;
1044 }
1045
1046 let format = if bytes[0] == 0 {
1048 RowFormat::Insert
1049 } else {
1050 RowFormat::Copy
1051 };
1052 let row_bytes = bytes[1..].to_vec();
1053
1054 chunk_buffer.push((format, row_bytes));
1055
1056 if chunk_buffer.len() >= CHUNK_SIZE {
1057 write_insert_chunk(&mut writer, "ed_name, &chunk_buffer, config.dialect)?;
1058 chunk_buffer.clear();
1059 }
1060 }
1061
1062 if !chunk_buffer.is_empty() {
1064 write_insert_chunk(&mut writer, "ed_name, &chunk_buffer, config.dialect)?;
1065 }
1066 }
1067
1068 write_dialect_footer(&mut writer, config.dialect)?;
1070
1071 writer.flush()?;
1072
1073 Ok(())
1074}
1075
1076fn write_header<W: Write>(
1078 writer: &mut W,
1079 config: &SampleConfig,
1080 stats: &SampleStats,
1081) -> std::io::Result<()> {
1082 writeln!(writer, "-- Sampled from: {}", config.input.display())?;
1083 writeln!(
1084 writer,
1085 "-- Date: {}",
1086 chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
1087 )?;
1088 writeln!(
1089 writer,
1090 "-- Mode: {:?}{}",
1091 config.mode,
1092 if config.preserve_relations {
1093 ", preserve-relations"
1094 } else {
1095 ""
1096 }
1097 )?;
1098 writeln!(writer, "-- Seed: {}", config.seed)?;
1099 writeln!(writer, "-- Dialect: {}", config.dialect)?;
1100 writeln!(writer, "--")?;
1101 writeln!(writer, "-- Statistics:")?;
1102 writeln!(writer, "-- Tables sampled: {}", stats.tables_sampled)?;
1103 writeln!(writer, "-- Tables skipped: {}", stats.tables_skipped)?;
1104
1105 let percent = if stats.total_rows_seen > 0 {
1106 (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
1107 } else {
1108 0.0
1109 };
1110 writeln!(
1111 writer,
1112 "-- Total rows: {} (from {} original, {:.1}%)",
1113 stats.total_rows_selected, stats.total_rows_seen, percent
1114 )?;
1115
1116 if stats.fk_orphans_rejected > 0 {
1117 writeln!(
1118 writer,
1119 "-- FK orphans rejected: {}",
1120 stats.fk_orphans_rejected
1121 )?;
1122 }
1123
1124 if !stats.warnings.is_empty() {
1125 writeln!(writer, "-- Warnings: {}", stats.warnings.len())?;
1126 }
1127
1128 writeln!(writer)?;
1129
1130 Ok(())
1131}
1132
1133fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1135 match dialect {
1136 SqlDialect::MySql => {
1137 writeln!(writer, "SET NAMES utf8mb4;")?;
1138 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
1139 }
1140 SqlDialect::Postgres => {
1141 writeln!(writer, "SET client_encoding = 'UTF8';")?;
1142 writeln!(writer, "SET session_replication_role = replica;")?;
1143 }
1144 SqlDialect::Sqlite => {
1145 writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
1146 }
1147 SqlDialect::Mssql => {
1148 writeln!(writer, "SET ANSI_NULLS ON;")?;
1149 writeln!(writer, "SET QUOTED_IDENTIFIER ON;")?;
1150 writeln!(writer, "SET NOCOUNT ON;")?;
1151 }
1152 }
1153 writeln!(writer)?;
1154 Ok(())
1155}
1156
1157fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1159 writeln!(writer)?;
1160 match dialect {
1161 SqlDialect::MySql => {
1162 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
1163 }
1164 SqlDialect::Postgres => {
1165 writeln!(writer, "SET session_replication_role = DEFAULT;")?;
1166 }
1167 SqlDialect::Sqlite => {
1168 writeln!(writer, "PRAGMA foreign_keys = ON;")?;
1169 }
1170 SqlDialect::Mssql => {
1171 }
1173 }
1174 Ok(())
1175}
1176
1177fn write_insert_chunk<W: Write>(
1179 writer: &mut W,
1180 quoted_name: &str,
1181 chunk: &[(RowFormat, Vec<u8>)],
1182 dialect: SqlDialect,
1183) -> std::io::Result<()> {
1184 writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;
1185
1186 for (i, (format, row_bytes)) in chunk.iter().enumerate() {
1187 if i > 0 {
1188 writer.write_all(b",\n")?;
1189 }
1190
1191 let values = match format {
1192 RowFormat::Insert => match dialect {
1193 SqlDialect::Postgres => convert_row_to_postgres(row_bytes),
1194 _ => row_bytes.clone(),
1195 },
1196 RowFormat::Copy => convert_copy_to_insert_values(row_bytes, dialect),
1197 };
1198 writer.write_all(&values)?;
1199 }
1200
1201 writer.write_all(b";\n")?;
1202 Ok(())
1203}
1204
/// Rewrite MySQL's backslash-escaped quotes (`\'`) into the SQL-standard
/// doubled form (`''`) so an INSERT value tuple is valid for Postgres.
/// All other bytes pass through untouched.
// NOTE(review): a literal backslash immediately followed by a quote (`\\'`)
// also has its trailing `\'` rewritten — presumed rare; confirm upstream
// escaping rules.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(row.len());
    let mut bytes = row.iter().copied().peekable();

    while let Some(b) = bytes.next() {
        if b == b'\\' && bytes.peek() == Some(&b'\'') {
            // Consume the escaped quote and emit the doubled form.
            bytes.next();
            out.extend_from_slice(b"''");
        } else {
            out.push(b);
        }
    }

    out
}
1226
1227fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
1229 let mut result = Vec::with_capacity(row.len() + 20);
1230 result.push(b'(');
1231
1232 let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
1233
1234 for (i, field) in fields.iter().enumerate() {
1235 if i > 0 {
1236 result.extend_from_slice(b", ");
1237 }
1238
1239 if *field == b"\\N" {
1241 result.extend_from_slice(b"NULL");
1242 } else if field.is_empty() {
1243 match dialect {
1245 SqlDialect::MySql => result.extend_from_slice(b"''"),
1246 SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {
1247 result.extend_from_slice(b"''")
1248 }
1249 }
1250 } else if is_numeric(field) {
1251 result.extend_from_slice(field);
1253 } else {
1254 result.push(b'\'');
1256 for &b in *field {
1257 match b {
1258 b'\'' => {
1259 match dialect {
1261 SqlDialect::MySql => result.extend_from_slice(b"\\'"),
1262 SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {
1263 result.extend_from_slice(b"''")
1264 }
1265 }
1266 }
1267 b'\\' if dialect == SqlDialect::MySql => {
1268 result.extend_from_slice(b"\\\\");
1270 }
1271 _ => result.push(b),
1272 }
1273 }
1274 result.push(b'\'');
1275 }
1276 }
1277
1278 result.push(b')');
1279 result
1280}
1281
/// Returns `true` when `s` is a SQL-safe numeric literal — optional sign,
/// digits with at most one decimal point, and an optional exponent
/// (`e`/`E`, optional sign, at least one digit) — so it may be emitted
/// unquoted in a VALUES tuple.
///
/// Fixes the previous scan, which bare-skipped `e`/`E` and therefore
/// accepted non-numbers like `"e1"`, `"1e"`, `"1e5e5"`, and `"1e.5"`
/// (emitted unquoted as invalid SQL) while rejecting valid `"1e-5"`.
fn is_numeric(s: &[u8]) -> bool {
    let n = s.len();
    let mut i = 0;

    // Optional leading sign.
    if i < n && (s[i] == b'-' || s[i] == b'+') {
        i += 1;
    }

    // Integer part.
    let int_start = i;
    while i < n && s[i].is_ascii_digit() {
        i += 1;
    }
    let mut has_digit = i > int_start;

    // Optional fractional part; at least one digit overall is required,
    // on either side of the dot (".5" and "1." both count).
    if i < n && s[i] == b'.' {
        i += 1;
        let frac_start = i;
        while i < n && s[i].is_ascii_digit() {
            i += 1;
        }
        has_digit |= i > frac_start;
    }

    if !has_digit {
        return false;
    }

    // Optional exponent: marker, optional sign, then at least one digit.
    if i < n && (s[i] == b'e' || s[i] == b'E') {
        i += 1;
        if i < n && (s[i] == b'-' || s[i] == b'+') {
            i += 1;
        }
        let exp_start = i;
        while i < n && s[i].is_ascii_digit() {
            i += 1;
        }
        if i == exp_start {
            return false;
        }
    }

    // Anything left over means the field is not purely numeric.
    i == n
}