// sql_splitter/sample/mod.rs
1//! Sample command for creating reduced datasets from SQL dumps.
2//!
3//! The sample command creates reduced datasets while optionally preserving
4//! foreign key integrity through dependency-aware FK chain resolution.
5//!
6//! Supports MySQL, PostgreSQL, and SQLite dialects.
7
8mod config;
9mod reservoir;
10
11pub use config::{DefaultClassifier, GlobalTableMode, SampleYamlConfig, TableClassification};
12pub use reservoir::Reservoir;
13
14use crate::parser::mysql_insert::{hash_pk_tuple, parse_mysql_insert_rows, ParsedRow, PkHashSet};
15use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
16use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
17use crate::schema::{SchemaBuilder, SchemaGraph, TableId};
18use crate::splitter::Splitter;
19use ahash::AHashMap;
20use indicatif::{ProgressBar, ProgressStyle};
21use rand::rngs::StdRng;
22use rand::{Rng, SeedableRng};
23use std::fs::{self, File};
24use std::io::{BufWriter, Write};
25use std::path::{Path, PathBuf};
26use tempfile::TempDir;
27
/// Sampling mode selecting how many rows to keep per table.
///
/// `PartialEq`/`Eq` are derived so modes can be compared directly in
/// tests and configuration merging.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SampleMode {
    /// Sample N% of rows from each table
    Percent(u32),
    /// Sample up to N rows from each table
    Rows(usize),
}
36
/// Configuration for the sample command.
///
/// Construct via [`Default::default`] and override fields as needed; the
/// `Default` impl in this module supplies the baseline values.
#[derive(Debug)]
pub struct SampleConfig {
    /// Input SQL file
    pub input: PathBuf,
    /// Output SQL file (None for stdout)
    pub output: Option<PathBuf>,
    /// SQL dialect (MySQL, PostgreSQL, or SQLite)
    pub dialect: SqlDialect,
    /// Sampling mode (percentage or absolute row cap per table)
    pub mode: SampleMode,
    /// Preserve foreign key relationships: reject child rows whose
    /// referenced parent rows were not themselves selected
    pub preserve_relations: bool,
    /// Only sample these tables (None = all); matched case-insensitively
    pub tables_filter: Option<Vec<String>>,
    /// Exclude these tables (matched case-insensitively)
    pub exclude: Vec<String>,
    /// Root tables for sampling (start from these)
    pub root_tables: Vec<String>,
    /// How to handle global/lookup tables
    pub include_global: GlobalTableMode,
    /// Random seed for reproducibility: the same seed yields the same sample
    pub seed: u64,
    /// Dry run mode (collect stats only; output synthesis is skipped)
    pub dry_run: bool,
    /// Show progress (progress bar plus milestone messages on stderr)
    pub progress: bool,
    /// YAML config file path for per-table overrides
    pub config_file: Option<PathBuf>,
    /// Maximum total rows to sample (explosion guard)
    pub max_total_rows: Option<usize>,
    /// Fail if any FK integrity issues detected
    pub strict_fk: bool,
    /// Include schema statements in output
    pub include_schema: bool,
}
73
74impl Default for SampleConfig {
75    fn default() -> Self {
76        Self {
77            input: PathBuf::new(),
78            output: None,
79            dialect: SqlDialect::MySql,
80            mode: SampleMode::Percent(10),
81            preserve_relations: false,
82            tables_filter: None,
83            exclude: Vec::new(),
84            root_tables: Vec::new(),
85            include_global: GlobalTableMode::Lookups,
86            seed: rand::random(),
87            dry_run: false,
88            progress: false,
89            config_file: None,
90            max_total_rows: None,
91            strict_fk: false,
92            include_schema: true,
93        }
94    }
95}
96
/// Statistics from sample operation
#[derive(Debug, Default, serde::Serialize)]
pub struct SampleStats {
    /// Number of tables sampled (one per entry in `table_stats`)
    pub tables_sampled: usize,
    /// Number of tables skipped (excluded, filtered out, system tables,
    /// or lookup tables when global tables are disabled)
    pub tables_skipped: usize,
    /// Total rows selected across all sampled tables
    pub total_rows_selected: u64,
    /// Total rows seen across all sampled tables
    pub total_rows_seen: u64,
    /// Per-table statistics
    pub table_stats: Vec<TableSampleStats>,
    /// Warning messages (FK cycles, max_total_rows limit reached, ...)
    pub warnings: Vec<String>,
    /// FK orphan count (rows rejected due to missing parents)
    pub fk_orphans_rejected: u64,
}
115
/// Per-table sampling statistics
#[derive(Debug, Clone, serde::Serialize)]
pub struct TableSampleStats {
    /// Table name as seen in the dump
    pub name: String,
    /// Rows encountered while scanning this table
    pub rows_seen: u64,
    /// Rows that passed FK checks and were kept by the sampler
    pub rows_selected: u64,
    /// Classification used when deciding how to sample this table
    pub classification: TableClassification,
}
124
/// Runtime state for a table during sampling.
///
/// One instance exists per table in the schema graph; tables are processed
/// in dependency order so a parent's `pk_set` is populated before any
/// child performs FK membership checks against it.
struct TableRuntime {
    /// Table name
    name: String,
    /// Primary key hashes of selected rows, used for FK membership
    /// checks by child tables (compact: 8 bytes per key)
    pk_set: PkHashSet,
    /// Rows seen count
    rows_seen: u64,
    /// Rows selected count
    rows_selected: u64,
    /// Whether to skip this table (exclude list, filter, YAML skip, system)
    skip: bool,
    /// Table classification
    classification: TableClassification,
    /// FK orphans rejected for this table
    fk_orphans: u64,
    /// Path to temp file containing selected row bytes (None if no rows selected yet)
    selected_temp_path: Option<PathBuf>,
}
144
/// Combined row representation for both MySQL INSERT and PostgreSQL COPY,
/// letting FK checks treat rows from either source uniformly.
enum UnifiedRow {
    /// Row parsed from a MySQL-style INSERT statement
    Insert(ParsedRow),
    /// Row parsed from a PostgreSQL COPY data block
    Copy(ParsedCopyRow),
}
150
/// Row format indicator for output: distinguishes MySQL INSERT row bytes
/// from PostgreSQL COPY row bytes in the selected-rows temp files.
///
/// `Eq` is derived alongside `PartialEq` since equality on unit variants
/// is trivially total.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RowFormat {
    /// MySQL-style INSERT tuple
    Insert,
    /// PostgreSQL COPY data line
    Copy,
}
157
158impl UnifiedRow {
159    fn fk_values(
160        &self,
161    ) -> &[(
162        crate::parser::mysql_insert::FkRef,
163        smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>,
164    )] {
165        match self {
166            UnifiedRow::Insert(r) => &r.fk_values,
167            UnifiedRow::Copy(r) => &r.fk_values,
168        }
169    }
170}
171
/// Run the sample command.
///
/// Pipeline:
/// 1. Phase 0 — split the input dump into per-table temp files.
/// 2. Phase 1 — build the FK schema graph, detect cycles, and classify
///    every table (root / lookup / system / normal).
/// 3. Phase 2 — sample each table in dependency order (parents before
///    children) so FK membership checks can consult selected parent PKs.
/// 4. Phase 3 — synthesize the output (skipped entirely for dry runs).
///
/// The RNG is seeded from `config.seed`, so results are deterministic
/// for a fixed seed and input.
///
/// Returns aggregate [`SampleStats`]; errors propagate from I/O, parsing,
/// and (when `strict_fk` is set) FK integrity violations.
pub fn run(config: SampleConfig) -> anyhow::Result<SampleStats> {
    // Load YAML config if provided
    let yaml_config = if let Some(ref path) = config.config_file {
        Some(SampleYamlConfig::load(path)?)
    } else {
        None
    };

    let mut rng = StdRng::seed_from_u64(config.seed);
    let mut stats = SampleStats::default();

    // Get file size for progress tracking
    let file_size = std::fs::metadata(&config.input)?.len();

    // Progress bar setup - byte-based for the split phase
    let progress_bar = if config.progress {
        let pb = ProgressBar::new(file_size);
        pb.set_style(
            ProgressStyle::with_template(
                "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%) {msg}",
            )
            .unwrap()
            .progress_chars("█▓▒░  ")
            .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
        );
        pb.enable_steady_tick(std::time::Duration::from_millis(100));
        pb.set_message("Splitting dump...");
        Some(pb)
    } else {
        None
    };

    // Phase 0: Split into temp per-table files.
    // The TempDir handle keeps the directory alive for the whole run;
    // it is removed when `temp_dir` drops at function exit.
    let temp_dir = TempDir::new()?;
    let tables_dir = temp_dir.path().join("tables");

    let mut splitter = Splitter::new(config.input.clone(), tables_dir.clone())
        .with_dialect(config.dialect)
        .with_content_filter(ContentFilter::All);

    if let Some(ref pb) = progress_bar {
        let pb_clone = pb.clone();
        splitter = splitter.with_progress(move |bytes| {
            pb_clone.set_position(bytes);
        });
    }

    let split_stats = splitter.split()?;

    // Finish byte-based progress, switch to milestone messages
    if let Some(ref pb) = progress_bar {
        pb.finish_and_clear();
    }

    if config.progress {
        eprintln!(
            "Split complete: {} tables, {} statements",
            split_stats.tables_found, split_stats.statements_processed
        );
    }

    // Phase 1: Build schema graph
    if config.progress {
        eprintln!("Building schema graph...");
    }

    let graph = build_schema_graph(&tables_dir, &config)?;

    let (topo_order, cyclic_tables) = graph.processing_order();

    // Tables involved in FK cycles cannot be ordered parent-first, so
    // FK enforcement between them is disabled; surface this as a warning.
    if !cyclic_tables.is_empty() {
        let names: Vec<_> = cyclic_tables
            .iter()
            .filter_map(|&id| graph.table_name(id))
            .collect();
        let msg = format!(
            "Warning: {} tables have FK cycles (intra-cycle FK enforcement disabled): {:?}",
            cyclic_tables.len(),
            names
        );
        if config.progress {
            eprintln!("{}", msg);
        }
        stats.warnings.push(msg);
    }

    // Build set of cyclic table IDs for quick lookup
    let cyclic_set: ahash::AHashSet<TableId> = cyclic_tables.iter().copied().collect();

    // Determine root tables (names are lower-cased for case-insensitive matching)
    let explicit_roots: ahash::AHashSet<String> = config
        .root_tables
        .iter()
        .map(|s| s.to_lowercase())
        .collect();

    // Initialize table runtimes with classification
    let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
    for table in graph.schema.iter() {
        let classification =
            determine_classification(&table.name, &graph, table.id, &yaml_config, &explicit_roots);
        let skip = should_skip_table(&table.name, &config, &yaml_config, classification);

        runtimes.insert(
            table.id,
            TableRuntime {
                name: table.name.clone(),
                pk_set: PkHashSet::default(),
                rows_seen: 0,
                rows_selected: 0,
                skip,
                classification,
                fk_orphans: 0,
                selected_temp_path: None,
            },
        );
    }

    // Create directory for selected row temp files
    let selected_dir = temp_dir.path().join("selected");
    fs::create_dir_all(&selected_dir)?;

    // Phase 2: Process tables in dependency order
    if config.progress {
        eprintln!(
            "Sampling {} tables in dependency order...",
            topo_order.len()
        );
    }

    // Process acyclic tables first, then cyclic tables
    let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();

    let mut total_selected: u64 = 0;

    for table_id in &all_tables {
        let table_schema = match graph.schema.table(*table_id) {
            Some(s) => s,
            None => continue,
        };

        // Check if we should skip this table
        let (should_skip, table_name, classification) = {
            let runtime = match runtimes.get(table_id) {
                Some(r) => r,
                None => continue,
            };
            (runtime.skip, runtime.name.clone(), runtime.classification)
        };

        if should_skip {
            stats.tables_skipped += 1;
            continue;
        }

        // Handle lookup/global tables specially
        let sample_mode = match classification {
            TableClassification::Lookup => {
                match config.include_global {
                    GlobalTableMode::None => {
                        stats.tables_skipped += 1;
                        continue;
                    }
                    GlobalTableMode::Lookups | GlobalTableMode::All => {
                        // Include all rows
                        SampleMode::Percent(100)
                    }
                }
            }
            TableClassification::System => {
                stats.tables_skipped += 1;
                continue;
            }
            _ => get_table_sample_mode(&table_name, &config, &yaml_config),
        };

        let table_file = tables_dir.join(format!("{}.sql", table_name));
        if !table_file.exists() {
            continue;
        }

        // Process table with streaming sampling - rows go directly to temp file
        let result = sample_table_streaming(
            &table_file,
            table_schema,
            *table_id,
            &table_name,
            sample_mode,
            &config,
            &runtimes,
            &cyclic_set,
            &selected_dir,
            &mut rng,
        )?;

        // Check max_total_rows guard.
        // NOTE: this runs before the table's results are committed, so the
        // table that would cross the limit is dropped from the output
        // entirely (its temp file path is never recorded).
        if let Some(max) = config.max_total_rows {
            if total_selected + result.rows_selected > max as u64 {
                let msg = format!(
                    "Warning: Reached max_total_rows limit ({}) at table '{}'",
                    max, table_name
                );
                stats.warnings.push(msg);
                break;
            }
        }

        // Update total count
        total_selected += result.rows_selected;

        // Update runtime state and add PK hashes for FK checks by children.
        // Safe: runtime existence was verified at the top of this loop
        // iteration (the skip-check block continues when it is missing).
        let runtime = runtimes
            .get_mut(table_id)
            .expect("runtime must exist - checked at loop start");
        runtime.rows_seen = result.rows_seen;
        runtime.rows_selected = result.rows_selected;
        runtime.fk_orphans = result.fk_orphans;

        // Add PK hashes for FK membership checks by child tables
        for pk_hash in result.pk_hashes {
            runtime.pk_set.insert(pk_hash);
        }

        // Set the temp file path if we selected any rows
        if result.rows_selected > 0 {
            let temp_path = selected_dir.join(format!("{}.rows", table_name));
            if temp_path.exists() {
                runtime.selected_temp_path = Some(temp_path);
            }
        }

        stats.fk_orphans_rejected += result.fk_orphans;

        stats.table_stats.push(TableSampleStats {
            name: runtime.name.clone(),
            rows_seen: result.rows_seen,
            rows_selected: result.rows_selected,
            classification: runtime.classification,
        });
    }

    // Calculate totals
    for table_stats in &stats.table_stats {
        stats.total_rows_seen += table_stats.rows_seen;
        stats.total_rows_selected += table_stats.rows_selected;
    }
    stats.tables_sampled = stats.table_stats.len();

    if config.progress {
        eprintln!("Sampling complete");
    }

    // Phase 3: Output synthesis (dry runs stop after collecting stats)
    if config.dry_run {
        return Ok(stats);
    }

    if config.progress {
        eprintln!("Writing output...");
    }

    write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;

    Ok(stats)
}
439
440/// Build schema graph from split table files
441fn build_schema_graph(tables_dir: &Path, config: &SampleConfig) -> anyhow::Result<SchemaGraph> {
442    let mut builder = SchemaBuilder::new();
443
444    for entry in fs::read_dir(tables_dir)? {
445        let entry = entry?;
446        let path = entry.path();
447
448        if path.extension().map(|e| e == "sql").unwrap_or(false) {
449            let file = File::open(&path)?;
450            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
451
452            while let Some(stmt) = parser.read_statement()? {
453                let stmt_str = String::from_utf8_lossy(&stmt);
454                let (stmt_type, _) =
455                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
456
457                match stmt_type {
458                    StatementType::CreateTable => {
459                        builder.parse_create_table(&stmt_str);
460                    }
461                    StatementType::AlterTable => {
462                        builder.parse_alter_table(&stmt_str);
463                    }
464                    _ => {}
465                }
466            }
467        }
468    }
469
470    Ok(SchemaGraph::from_schema(builder.build()))
471}
472
473/// Determine table classification
474fn determine_classification(
475    name: &str,
476    graph: &SchemaGraph,
477    table_id: TableId,
478    yaml_config: &Option<SampleYamlConfig>,
479    explicit_roots: &ahash::AHashSet<String>,
480) -> TableClassification {
481    // Check explicit roots first
482    if explicit_roots.contains(&name.to_lowercase()) {
483        return TableClassification::Root;
484    }
485
486    // Check YAML config
487    if let Some(ref config) = yaml_config {
488        let class = config.get_classification(name);
489        if class != TableClassification::Normal {
490            return class;
491        }
492    }
493
494    // Check if it's a graph root (no parents)
495    if graph.parents[table_id.0 as usize].is_empty() {
496        return TableClassification::Root;
497    }
498
499    // Use default classifier
500    DefaultClassifier::classify(name)
501}
502
503/// Check if a table should be skipped
504fn should_skip_table(
505    name: &str,
506    config: &SampleConfig,
507    yaml_config: &Option<SampleYamlConfig>,
508    classification: TableClassification,
509) -> bool {
510    let name_lower = name.to_lowercase();
511
512    // Check exclude list
513    if config
514        .exclude
515        .iter()
516        .any(|e| e.to_lowercase() == name_lower)
517    {
518        return true;
519    }
520
521    // Check YAML skip
522    if let Some(ref yc) = yaml_config {
523        if yc.should_skip(name) {
524            return true;
525        }
526    }
527
528    // Check include filter
529    if let Some(ref filter) = config.tables_filter {
530        if !filter.iter().any(|f| f.to_lowercase() == name_lower) {
531            return true;
532        }
533    }
534
535    // Skip system tables by default
536    if classification == TableClassification::System {
537        return true;
538    }
539
540    false
541}
542
543/// Get sample mode for a specific table
544fn get_table_sample_mode(
545    name: &str,
546    config: &SampleConfig,
547    yaml_config: &Option<SampleYamlConfig>,
548) -> SampleMode {
549    // Check YAML config first
550    if let Some(ref yc) = yaml_config {
551        if let Some(rows) = yc.get_rows(name) {
552            return SampleMode::Rows(rows);
553        }
554        if let Some(percent) = yc.get_percent(name) {
555            return SampleMode::Percent(percent);
556        }
557    }
558
559    // Fall back to global config
560    config.mode
561}
562
/// Result from streaming sampling of a single table.
struct StreamingSampleResult {
    /// Rows encountered while scanning the table file
    rows_seen: u64,
    /// Rows that passed FK checks and were chosen by the sampler
    rows_selected: u64,
    /// Rows rejected because a referenced parent row was not selected
    fk_orphans: u64,
    /// PK hashes of selected rows (for FK checks by children)
    pk_hashes: Vec<u64>,
}
571
572/// Stream-sample a table: parse rows, apply FK checks, sample inline, write to temp file.
573/// Returns StreamingSampleResult with stats and PK hashes.
574/// Uses Bernoulli sampling for --percent mode (single pass).
575/// For --rows mode, we use reservoir sampling on row indices with a second pass.
576#[allow(clippy::too_many_arguments)]
577fn sample_table_streaming(
578    table_file: &Path,
579    table_schema: &crate::schema::TableSchema,
580    table_id: TableId,
581    table_name: &str,
582    sample_mode: SampleMode,
583    config: &SampleConfig,
584    runtimes: &AHashMap<TableId, TableRuntime>,
585    cyclic_set: &ahash::AHashSet<TableId>,
586    selected_dir: &Path,
587    rng: &mut StdRng,
588) -> anyhow::Result<StreamingSampleResult> {
589    let mut rows_seen = 0u64;
590    let mut rows_selected = 0u64;
591    let mut fk_orphans = 0u64;
592
593    // Temp file for selected rows
594    let temp_path = selected_dir.join(format!("{}.rows", table_name));
595    let mut temp_writer: Option<BufWriter<File>> = None;
596
597    // Track PKs of selected rows (for children's FK checks)
598    let mut selected_pk_hashes: Vec<u64> = Vec::new();
599
600    // For PostgreSQL COPY, track the current column order
601    let mut copy_columns: Vec<String> = Vec::new();
602
603    match sample_mode {
604        SampleMode::Percent(p) => {
605            // Bernoulli sampling: decide immediately for each row
606            let prob = p as f64 / 100.0;
607
608            let file = File::open(table_file)?;
609            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
610
611            while let Some(stmt) = parser.read_statement()? {
612                let (stmt_type, _) =
613                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
614
615                match stmt_type {
616                    StatementType::Insert => {
617                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
618                        for row in rows {
619                            rows_seen += 1;
620
621                            // FK check
622                            if config.preserve_relations {
623                                let unified = UnifiedRow::Insert(row.clone());
624                                let (passes, orphan) = check_unified_fk_membership(
625                                    &unified,
626                                    table_schema,
627                                    runtimes,
628                                    cyclic_set,
629                                    &table_id,
630                                );
631                                if !passes {
632                                    fk_orphans += 1;
633                                    if orphan && config.strict_fk {
634                                        anyhow::bail!(
635                                            "FK integrity violation in table '{}': row references missing parent",
636                                            table_name
637                                        );
638                                    }
639                                    continue;
640                                }
641                            }
642
643                            // Bernoulli sample
644                            if rng.random::<f64>() < prob {
645                                // Write to temp file
646                                if temp_writer.is_none() {
647                                    temp_writer = Some(BufWriter::new(File::create(&temp_path)?));
648                                }
649                                let writer = temp_writer.as_mut().unwrap();
650                                // Format: 1-byte type (0=insert, 1=copy), then row bytes, then newline
651                                writer.write_all(&[0u8])?;
652                                writer.write_all(&row.raw)?;
653                                writer.write_all(b"\n")?;
654
655                                // Track PK hash
656                                if let Some(pk) = &row.pk {
657                                    selected_pk_hashes.push(hash_pk_tuple(pk));
658                                }
659                                rows_selected += 1;
660                            }
661                        }
662                    }
663                    StatementType::Copy => {
664                        let header = String::from_utf8_lossy(&stmt);
665                        copy_columns = parse_copy_columns(&header);
666                    }
667                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
668                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
669                            let rows = parse_postgres_copy_rows(
670                                &stmt,
671                                table_schema,
672                                copy_columns.clone(),
673                            )?;
674                            for row in rows {
675                                rows_seen += 1;
676
677                                if config.preserve_relations {
678                                    let unified = UnifiedRow::Copy(row.clone());
679                                    let (passes, orphan) = check_unified_fk_membership(
680                                        &unified,
681                                        table_schema,
682                                        runtimes,
683                                        cyclic_set,
684                                        &table_id,
685                                    );
686                                    if !passes {
687                                        fk_orphans += 1;
688                                        if orphan && config.strict_fk {
689                                            anyhow::bail!(
690                                                "FK integrity violation in table '{}': row references missing parent",
691                                                table_name
692                                            );
693                                        }
694                                        continue;
695                                    }
696                                }
697
698                                if rng.random::<f64>() < prob {
699                                    if temp_writer.is_none() {
700                                        temp_writer =
701                                            Some(BufWriter::new(File::create(&temp_path)?));
702                                    }
703                                    let writer = temp_writer.as_mut().unwrap();
704                                    writer.write_all(&[1u8])?;
705                                    writer.write_all(&row.raw)?;
706                                    writer.write_all(b"\n")?;
707
708                                    if let Some(pk) = &row.pk {
709                                        selected_pk_hashes.push(hash_pk_tuple(pk));
710                                    }
711                                    rows_selected += 1;
712                                }
713                            }
714                        }
715                    }
716                    _ => {}
717                }
718            }
719        }
720        SampleMode::Rows(n) => {
721            // Reservoir sampling: collect eligible row indices in first pass,
722            // then write selected rows in second pass
723            let mut reservoir: Reservoir<(u64, RowFormat, Option<u64>)> =
724                Reservoir::new(n, StdRng::from_rng(&mut *rng));
725
726            // First pass: build reservoir of (row_index, format, pk_hash)
727            let file = File::open(table_file)?;
728            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
729
730            while let Some(stmt) = parser.read_statement()? {
731                let (stmt_type, _) =
732                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
733
734                match stmt_type {
735                    StatementType::Insert => {
736                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
737                        for row in rows {
738                            let current_idx = rows_seen;
739                            rows_seen += 1;
740
741                            if config.preserve_relations {
742                                let unified = UnifiedRow::Insert(row.clone());
743                                let (passes, orphan) = check_unified_fk_membership(
744                                    &unified,
745                                    table_schema,
746                                    runtimes,
747                                    cyclic_set,
748                                    &table_id,
749                                );
750                                if !passes {
751                                    fk_orphans += 1;
752                                    if orphan && config.strict_fk {
753                                        anyhow::bail!(
754                                            "FK integrity violation in table '{}': row references missing parent",
755                                            table_name
756                                        );
757                                    }
758                                    continue;
759                                }
760                            }
761
762                            let pk_hash = row.pk.as_ref().map(hash_pk_tuple);
763                            reservoir.consider((current_idx, RowFormat::Insert, pk_hash));
764                        }
765                    }
766                    StatementType::Copy => {
767                        let header = String::from_utf8_lossy(&stmt);
768                        copy_columns = parse_copy_columns(&header);
769                    }
770                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
771                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
772                            let rows = parse_postgres_copy_rows(
773                                &stmt,
774                                table_schema,
775                                copy_columns.clone(),
776                            )?;
777                            for row in rows {
778                                let current_idx = rows_seen;
779                                rows_seen += 1;
780
781                                if config.preserve_relations {
782                                    let unified = UnifiedRow::Copy(row.clone());
783                                    let (passes, orphan) = check_unified_fk_membership(
784                                        &unified,
785                                        table_schema,
786                                        runtimes,
787                                        cyclic_set,
788                                        &table_id,
789                                    );
790                                    if !passes {
791                                        fk_orphans += 1;
792                                        if orphan && config.strict_fk {
793                                            anyhow::bail!(
794                                                "FK integrity violation in table '{}': row references missing parent",
795                                                table_name
796                                            );
797                                        }
798                                        continue;
799                                    }
800                                }
801
802                                let pk_hash = row.pk.as_ref().map(hash_pk_tuple);
803                                reservoir.consider((current_idx, RowFormat::Copy, pk_hash));
804                            }
805                        }
806                    }
807                    _ => {}
808                }
809            }
810
811            // Extract selected indices and PKs from reservoir
812            let selected_items = reservoir.into_items();
813            if selected_items.is_empty() {
814                return Ok(StreamingSampleResult {
815                    rows_seen,
816                    rows_selected: 0,
817                    fk_orphans,
818                    pk_hashes: Vec::new(),
819                });
820            }
821
822            // Collect PK hashes and sort indices for second pass
823            let mut selected_indices: Vec<(u64, RowFormat)> =
824                Vec::with_capacity(selected_items.len());
825            for (idx, format, pk_hash) in selected_items {
826                if let Some(h) = pk_hash {
827                    selected_pk_hashes.push(h);
828                }
829                selected_indices.push((idx, format));
830            }
831            selected_indices.sort_by_key(|(idx, _)| *idx);
832
833            // Second pass: write selected rows to temp file
834            let file = File::open(table_file)?;
835            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
836            let mut current_row_idx = 0u64;
837            let mut select_iter = selected_indices.iter().peekable();
838
839            temp_writer = Some(BufWriter::new(File::create(&temp_path)?));
840            let writer = temp_writer.as_mut().unwrap();
841
842            while let Some(stmt) = parser.read_statement()? {
843                if select_iter.peek().is_none() {
844                    break; // All selected rows written
845                }
846
847                let (stmt_type, _) =
848                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
849
850                match stmt_type {
851                    StatementType::Insert => {
852                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
853                        for row in rows {
854                            if let Some((next_idx, _)) = select_iter.peek() {
855                                if current_row_idx == *next_idx {
856                                    writer.write_all(&[0u8])?;
857                                    writer.write_all(&row.raw)?;
858                                    writer.write_all(b"\n")?;
859                                    rows_selected += 1;
860                                    select_iter.next();
861                                }
862                            }
863                            current_row_idx += 1;
864                        }
865                    }
866                    StatementType::Copy => {
867                        let header = String::from_utf8_lossy(&stmt);
868                        copy_columns = parse_copy_columns(&header);
869                    }
870                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
871                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
872                            let rows = parse_postgres_copy_rows(
873                                &stmt,
874                                table_schema,
875                                copy_columns.clone(),
876                            )?;
877                            for row in rows {
878                                if let Some((next_idx, _)) = select_iter.peek() {
879                                    if current_row_idx == *next_idx {
880                                        writer.write_all(&[1u8])?;
881                                        writer.write_all(&row.raw)?;
882                                        writer.write_all(b"\n")?;
883                                        rows_selected += 1;
884                                        select_iter.next();
885                                    }
886                                }
887                                current_row_idx += 1;
888                            }
889                        }
890                    }
891                    _ => {}
892                }
893            }
894        }
895    }
896
897    // Flush temp file
898    if let Some(mut writer) = temp_writer {
899        writer.flush()?;
900    }
901
902    Ok(StreamingSampleResult {
903        rows_seen,
904        rows_selected,
905        fk_orphans,
906        pk_hashes: selected_pk_hashes,
907    })
908}
909
910/// Check FK membership for a unified row (works with both INSERT and COPY rows)
911/// Uses hash-based lookup for memory efficiency.
912fn check_unified_fk_membership(
913    row: &UnifiedRow,
914    table_schema: &crate::schema::TableSchema,
915    runtimes: &AHashMap<TableId, TableRuntime>,
916    cyclic_set: &ahash::AHashSet<TableId>,
917    current_table_id: &TableId,
918) -> (bool, bool) {
919    let mut passes = true;
920    let mut is_orphan = false;
921
922    for (fk_ref, fk_tuple) in row.fk_values() {
923        if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
924            if let Some(parent_id) = fk.referenced_table_id {
925                // Skip FK check for cyclic tables
926                if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id) {
927                    continue;
928                }
929
930                // Check if parent row exists in parent's pk_set using hash lookup
931                if let Some(parent_runtime) = runtimes.get(&parent_id) {
932                    let fk_hash = hash_pk_tuple(fk_tuple);
933                    if !parent_runtime.pk_set.contains(&fk_hash) {
934                        passes = false;
935                        is_orphan = true;
936                        break;
937                    }
938                }
939            }
940        }
941    }
942
943    (passes, is_orphan)
944}
945
946/// Write sampled output
947fn write_output(
948    config: &SampleConfig,
949    _graph: &SchemaGraph,
950    table_order: &[TableId],
951    runtimes: &AHashMap<TableId, TableRuntime>,
952    tables_dir: &Path,
953    stats: &SampleStats,
954) -> anyhow::Result<()> {
955    let mut writer: Box<dyn Write> = match &config.output {
956        Some(path) => {
957            if let Some(parent) = path.parent() {
958                fs::create_dir_all(parent)?;
959            }
960            Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
961        }
962        None => Box::new(BufWriter::new(std::io::stdout())),
963    };
964
965    // Write header comment
966    write_header(&mut writer, config, stats)?;
967
968    // Write dialect-specific header
969    write_dialect_header(&mut writer, config.dialect)?;
970
971    // Write schema for each table (if enabled)
972    if config.include_schema {
973        for &table_id in table_order {
974            let runtime = match runtimes.get(&table_id) {
975                Some(r) if !r.skip && r.rows_selected > 0 => r,
976                _ => continue,
977            };
978
979            let table_file = tables_dir.join(format!("{}.sql", runtime.name));
980            if !table_file.exists() {
981                continue;
982            }
983
984            // Copy schema statements from table file
985            let file = File::open(&table_file)?;
986            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
987
988            while let Some(stmt) = parser.read_statement()? {
989                let (stmt_type, _) =
990                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
991
992                if stmt_type.is_schema() {
993                    writer.write_all(&stmt)?;
994                    writer.write_all(b"\n")?;
995                }
996            }
997        }
998    }
999
1000    // Write data for each table (reading from temp files instead of memory)
1001    for &table_id in table_order {
1002        let runtime = match runtimes.get(&table_id) {
1003            Some(r) if !r.skip && r.rows_selected > 0 && r.selected_temp_path.is_some() => r,
1004            _ => continue,
1005        };
1006
1007        let table_name = &runtime.name;
1008        let row_count = runtime.rows_selected;
1009
1010        writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;
1011
1012        // Get the table name quoting based on dialect
1013        let quoted_name = match config.dialect {
1014            SqlDialect::MySql => format!("`{}`", table_name),
1015            SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
1016            SqlDialect::Mssql => format!("[{}]", table_name),
1017        };
1018
1019        // Read rows from temp file and write INSERTs in chunks
1020        let temp_path = runtime.selected_temp_path.as_ref().unwrap();
1021        let temp_file = File::open(temp_path)?;
1022        let reader = std::io::BufReader::new(temp_file);
1023        use std::io::BufRead;
1024
1025        const CHUNK_SIZE: usize = 1000;
1026        let mut chunk_buffer: Vec<(RowFormat, Vec<u8>)> = Vec::with_capacity(CHUNK_SIZE);
1027
1028        for line in reader.lines() {
1029            let line = line?;
1030            if line.is_empty() {
1031                continue;
1032            }
1033
1034            let bytes = line.as_bytes();
1035            if bytes.is_empty() {
1036                continue;
1037            }
1038
1039            // First byte is format indicator (0=insert, 1=copy)
1040            let format = if bytes[0] == 0 {
1041                RowFormat::Insert
1042            } else {
1043                RowFormat::Copy
1044            };
1045            let row_bytes = bytes[1..].to_vec();
1046
1047            chunk_buffer.push((format, row_bytes));
1048
1049            if chunk_buffer.len() >= CHUNK_SIZE {
1050                write_insert_chunk(&mut writer, &quoted_name, &chunk_buffer, config.dialect)?;
1051                chunk_buffer.clear();
1052            }
1053        }
1054
1055        // Write remaining rows
1056        if !chunk_buffer.is_empty() {
1057            write_insert_chunk(&mut writer, &quoted_name, &chunk_buffer, config.dialect)?;
1058        }
1059    }
1060
1061    // Write dialect-specific footer
1062    write_dialect_footer(&mut writer, config.dialect)?;
1063
1064    writer.flush()?;
1065
1066    Ok(())
1067}
1068
1069/// Write header comment
1070fn write_header<W: Write>(
1071    writer: &mut W,
1072    config: &SampleConfig,
1073    stats: &SampleStats,
1074) -> std::io::Result<()> {
1075    writeln!(writer, "-- Sampled from: {}", config.input.display())?;
1076    writeln!(
1077        writer,
1078        "-- Date: {}",
1079        chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
1080    )?;
1081    writeln!(
1082        writer,
1083        "-- Mode: {:?}{}",
1084        config.mode,
1085        if config.preserve_relations {
1086            ", preserve-relations"
1087        } else {
1088            ""
1089        }
1090    )?;
1091    writeln!(writer, "-- Seed: {}", config.seed)?;
1092    writeln!(writer, "-- Dialect: {}", config.dialect)?;
1093    writeln!(writer, "--")?;
1094    writeln!(writer, "-- Statistics:")?;
1095    writeln!(writer, "--   Tables sampled: {}", stats.tables_sampled)?;
1096    writeln!(writer, "--   Tables skipped: {}", stats.tables_skipped)?;
1097
1098    let percent = if stats.total_rows_seen > 0 {
1099        (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
1100    } else {
1101        0.0
1102    };
1103    writeln!(
1104        writer,
1105        "--   Total rows: {} (from {} original, {:.1}%)",
1106        stats.total_rows_selected, stats.total_rows_seen, percent
1107    )?;
1108
1109    if stats.fk_orphans_rejected > 0 {
1110        writeln!(
1111            writer,
1112            "--   FK orphans rejected: {}",
1113            stats.fk_orphans_rejected
1114        )?;
1115    }
1116
1117    if !stats.warnings.is_empty() {
1118        writeln!(writer, "--   Warnings: {}", stats.warnings.len())?;
1119    }
1120
1121    writeln!(writer)?;
1122
1123    Ok(())
1124}
1125
1126/// Write dialect-specific header
1127fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1128    match dialect {
1129        SqlDialect::MySql => {
1130            writeln!(writer, "SET NAMES utf8mb4;")?;
1131            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
1132        }
1133        SqlDialect::Postgres => {
1134            writeln!(writer, "SET client_encoding = 'UTF8';")?;
1135            writeln!(writer, "SET session_replication_role = replica;")?;
1136        }
1137        SqlDialect::Sqlite => {
1138            writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
1139        }
1140        SqlDialect::Mssql => {
1141            writeln!(writer, "SET ANSI_NULLS ON;")?;
1142            writeln!(writer, "SET QUOTED_IDENTIFIER ON;")?;
1143            writeln!(writer, "SET NOCOUNT ON;")?;
1144        }
1145    }
1146    writeln!(writer)?;
1147    Ok(())
1148}
1149
1150/// Write dialect-specific footer
1151fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1152    writeln!(writer)?;
1153    match dialect {
1154        SqlDialect::MySql => {
1155            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
1156        }
1157        SqlDialect::Postgres => {
1158            writeln!(writer, "SET session_replication_role = DEFAULT;")?;
1159        }
1160        SqlDialect::Sqlite => {
1161            writeln!(writer, "PRAGMA foreign_keys = ON;")?;
1162        }
1163        SqlDialect::Mssql => {
1164            // No footer needed
1165        }
1166    }
1167    Ok(())
1168}
1169
1170/// Write a chunk of rows as an INSERT statement
1171fn write_insert_chunk<W: Write>(
1172    writer: &mut W,
1173    quoted_name: &str,
1174    chunk: &[(RowFormat, Vec<u8>)],
1175    dialect: SqlDialect,
1176) -> std::io::Result<()> {
1177    writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;
1178
1179    for (i, (format, row_bytes)) in chunk.iter().enumerate() {
1180        if i > 0 {
1181            writer.write_all(b",\n")?;
1182        }
1183
1184        let values = match format {
1185            RowFormat::Insert => match dialect {
1186                SqlDialect::Postgres => convert_row_to_postgres(row_bytes),
1187                _ => row_bytes.clone(),
1188            },
1189            RowFormat::Copy => convert_copy_to_insert_values(row_bytes, dialect),
1190        };
1191        writer.write_all(&values)?;
1192    }
1193
1194    writer.write_all(b";\n")?;
1195    Ok(())
1196}
1197
/// Convert a MySQL-style row to PostgreSQL syntax.
///
/// Rewrites MySQL quote escapes (`\'`) to SQL-standard doubled quotes (`''`).
/// Escaped backslashes (`\\`) are consumed as a unit so the byte that follows
/// them is never misread as the start of a new escape — previously the input
/// `\\'` (literal backslash, then a string-terminating quote) was corrupted
/// into `\''`, breaking the row's quote structure.
/// A full implementation would handle more edge cases.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(row.len());
    let mut i = 0;

    while i < row.len() {
        if row[i] == b'\\' && i + 1 < row.len() {
            match row[i + 1] {
                // MySQL: \' -> PostgreSQL: ''
                b'\'' => {
                    result.extend_from_slice(b"''");
                    i += 2;
                }
                // Escaped backslash: copy both bytes and skip past them so
                // a following quote is treated as a real quote (bug fix).
                b'\\' => {
                    result.extend_from_slice(b"\\\\");
                    i += 2;
                }
                _ => {
                    result.push(row[i]);
                    i += 1;
                }
            }
        } else {
            result.push(row[i]);
            i += 1;
        }
    }

    result
}
1219
1220/// Convert PostgreSQL COPY format (tab-separated) to INSERT VALUES format
1221fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
1222    let mut result = Vec::with_capacity(row.len() + 20);
1223    result.push(b'(');
1224
1225    let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
1226
1227    for (i, field) in fields.iter().enumerate() {
1228        if i > 0 {
1229            result.extend_from_slice(b", ");
1230        }
1231
1232        // Check for NULL marker
1233        if *field == b"\\N" {
1234            result.extend_from_slice(b"NULL");
1235        } else if field.is_empty() {
1236            // Empty string
1237            match dialect {
1238                SqlDialect::MySql => result.extend_from_slice(b"''"),
1239                SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {
1240                    result.extend_from_slice(b"''")
1241                }
1242            }
1243        } else if is_numeric(field) {
1244            // Numeric value - no quotes needed
1245            result.extend_from_slice(field);
1246        } else {
1247            // String value - needs quoting
1248            result.push(b'\'');
1249            for &b in *field {
1250                match b {
1251                    b'\'' => {
1252                        // Escape single quote
1253                        match dialect {
1254                            SqlDialect::MySql => result.extend_from_slice(b"\\'"),
1255                            SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {
1256                                result.extend_from_slice(b"''")
1257                            }
1258                        }
1259                    }
1260                    b'\\' if dialect == SqlDialect::MySql => {
1261                        // Escape backslash in MySQL
1262                        result.extend_from_slice(b"\\\\");
1263                    }
1264                    _ => result.push(b),
1265                }
1266            }
1267            result.push(b'\'');
1268        }
1269    }
1270
1271    result.push(b')');
1272    result
1273}
1274
/// Check if a byte slice represents a numeric value.
///
/// Accepts an optional leading sign, digits with at most one decimal point,
/// and an optional exponent (`e`/`E`, optional sign, at least one digit).
/// Previously the exponent marker was simply skipped, so garbage such as
/// `e5`, `1e`, or `1e2e3` was classified as numeric (producing broken
/// unquoted SQL) while valid forms like `1e-5` were rejected.
fn is_numeric(s: &[u8]) -> bool {
    let mut i = 0;

    // Optional leading sign on the mantissa.
    if matches!(s.first(), Some(b'+') | Some(b'-')) {
        i = 1;
    }

    // Mantissa: digits with at most one '.'; stop at an exponent marker.
    let mut mantissa_has_digit = false;
    let mut seen_dot = false;
    while i < s.len() {
        match s[i] {
            b'0'..=b'9' => mantissa_has_digit = true,
            b'.' if !seen_dot => seen_dot = true,
            b'e' | b'E' => break,
            _ => return false,
        }
        i += 1;
    }
    if !mantissa_has_digit {
        return false;
    }
    if i == s.len() {
        return true; // no exponent part
    }

    // Exponent: skip the 'e'/'E', allow an optional sign, then require at
    // least one digit and nothing else.
    i += 1;
    if i < s.len() && (s[i] == b'+' || s[i] == b'-') {
        i += 1;
    }
    let mut exp_has_digit = false;
    while i < s.len() {
        if !s[i].is_ascii_digit() {
            return false;
        }
        exp_has_digit = true;
        i += 1;
    }
    exp_has_digit
}