Skip to main content

sql_splitter/sample/
mod.rs

1//! Sample command for creating reduced datasets from SQL dumps.
2//!
3//! The sample command creates reduced datasets while optionally preserving
4//! foreign key integrity through dependency-aware FK chain resolution.
5//!
6//! Supports MySQL, PostgreSQL, and SQLite dialects.
7
8mod config;
9mod reservoir;
10
11pub use config::{DefaultClassifier, GlobalTableMode, SampleYamlConfig, TableClassification};
12pub use reservoir::Reservoir;
13
14use crate::parser::mysql_insert::{hash_pk_tuple, parse_mysql_insert_rows, ParsedRow, PkHashSet};
15use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
16use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
17use crate::schema::{SchemaBuilder, SchemaGraph, TableId};
18use crate::splitter::Splitter;
19use ahash::AHashMap;
20use indicatif::{ProgressBar, ProgressStyle};
21use rand::rngs::StdRng;
22use rand::{Rng, SeedableRng};
23use std::fs::{self, File};
24use std::io::{BufWriter, Write};
25use std::path::{Path, PathBuf};
26use tempfile::TempDir;
27
/// Sampling mode applied to a table.
///
/// Derives `PartialEq`/`Eq` so modes can be compared directly in tests
/// and match guards (the enum is a plain value type).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SampleMode {
    /// Sample N% of rows from each table
    Percent(u32),
    /// Sample up to N rows from each table
    Rows(usize),
}
36
/// Configuration for the sample command.
///
/// Typically populated from CLI arguments; use [`SampleConfig::default`]
/// as a baseline and override individual fields.
#[derive(Debug)]
pub struct SampleConfig {
    /// Input SQL file
    pub input: PathBuf,
    /// Output SQL file (None for stdout)
    pub output: Option<PathBuf>,
    /// SQL dialect
    pub dialect: SqlDialect,
    /// Sampling mode used for tables without a per-table YAML override
    pub mode: SampleMode,
    /// Preserve foreign key relationships (reject rows whose parents
    /// were not selected)
    pub preserve_relations: bool,
    /// Only sample these tables (None = all)
    pub tables_filter: Option<Vec<String>>,
    /// Exclude these tables
    pub exclude: Vec<String>,
    /// Root tables for sampling (start from these)
    pub root_tables: Vec<String>,
    /// How to handle global/lookup tables
    pub include_global: GlobalTableMode,
    /// Random seed for reproducibility
    pub seed: u64,
    /// Dry run mode (show stats only)
    pub dry_run: bool,
    /// Show progress
    pub progress: bool,
    /// YAML config file path
    pub config_file: Option<PathBuf>,
    /// Maximum total rows to sample (explosion guard)
    pub max_total_rows: Option<usize>,
    /// Fail if any FK integrity issues detected
    pub strict_fk: bool,
    /// Include schema statements in output
    pub include_schema: bool,
}
73
74impl Default for SampleConfig {
75    fn default() -> Self {
76        Self {
77            input: PathBuf::new(),
78            output: None,
79            dialect: SqlDialect::MySql,
80            mode: SampleMode::Percent(10),
81            preserve_relations: false,
82            tables_filter: None,
83            exclude: Vec::new(),
84            root_tables: Vec::new(),
85            include_global: GlobalTableMode::Lookups,
86            seed: rand::random(),
87            dry_run: false,
88            progress: false,
89            config_file: None,
90            max_total_rows: None,
91            strict_fk: false,
92            include_schema: true,
93        }
94    }
95}
96
/// Statistics from sample operation.
///
/// Derives `serde::Serialize` for machine-readable reporting.
#[derive(Debug, Default, serde::Serialize)]
pub struct SampleStats {
    /// Number of tables sampled
    pub tables_sampled: usize,
    /// Number of tables skipped
    pub tables_skipped: usize,
    /// Total rows selected
    pub total_rows_selected: u64,
    /// Total rows seen
    pub total_rows_seen: u64,
    /// Per-table statistics
    pub table_stats: Vec<TableSampleStats>,
    /// Warning messages
    pub warnings: Vec<String>,
    /// FK orphan count (rows rejected due to missing parents)
    pub fk_orphans_rejected: u64,
}
115
/// Per-table sampling statistics
#[derive(Debug, Clone, serde::Serialize)]
pub struct TableSampleStats {
    /// Table name.
    pub name: String,
    /// Number of rows encountered while scanning this table.
    pub rows_seen: u64,
    /// Number of rows selected into the sample.
    pub rows_selected: u64,
    /// Classification the table was sampled under.
    pub classification: TableClassification,
}
124
/// Runtime state for a table during sampling.
///
/// One instance per table. Tables are processed parent-first, so a
/// parent's `pk_set` is already populated when its children run FK
/// membership checks against it.
struct TableRuntime {
    /// Table name
    name: String,
    /// Primary key hashes for FK membership checks (compact: 8 bytes per key)
    pk_set: PkHashSet,
    /// Rows seen count
    rows_seen: u64,
    /// Rows selected count
    rows_selected: u64,
    /// Whether to skip this table
    skip: bool,
    /// Table classification
    classification: TableClassification,
    /// FK orphans rejected for this table
    fk_orphans: u64,
    /// Path to temp file containing selected row bytes (None if no rows selected yet)
    selected_temp_path: Option<PathBuf>,
}
144
/// Combined row representation for both MySQL INSERT and PostgreSQL COPY,
/// so FK membership checks can treat either format uniformly.
enum UnifiedRow {
    /// Row parsed from a MySQL `INSERT` statement.
    Insert(ParsedRow),
    /// Row parsed from a PostgreSQL `COPY` data block.
    Copy(ParsedCopyRow),
}
150
/// Row format indicator for output.
///
/// Fieldless value enum: derive `Eq` alongside `PartialEq` so it can be
/// used wherever full equivalence is required (clippy:
/// `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RowFormat {
    /// MySQL-style INSERT row.
    Insert,
    /// PostgreSQL COPY row.
    Copy,
}
157
158impl UnifiedRow {
159    fn pk(&self) -> Option<&smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>> {
160        match self {
161            UnifiedRow::Insert(r) => r.pk.as_ref(),
162            UnifiedRow::Copy(r) => r.pk.as_ref(),
163        }
164    }
165
166    fn fk_values(
167        &self,
168    ) -> &[(
169        crate::parser::mysql_insert::FkRef,
170        smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>,
171    )] {
172        match self {
173            UnifiedRow::Insert(r) => &r.fk_values,
174            UnifiedRow::Copy(r) => &r.fk_values,
175        }
176    }
177}
178
/// Run the sample command.
///
/// Pipeline:
/// 0. Split the input dump into per-table temp files.
/// 1. Parse DDL from the split files into an FK dependency graph.
/// 2. Sample tables in dependency (parent-first) order so child rows can
///    be checked against the PKs of already-selected parent rows.
/// 3. Synthesize the output (skipped when `dry_run` is set).
///
/// Returns aggregate [`SampleStats`]. Errors propagate from I/O, parsing,
/// YAML loading, and (with `strict_fk`) FK integrity violations.
pub fn run(config: SampleConfig) -> anyhow::Result<SampleStats> {
    // Load YAML config if provided
    let yaml_config = if let Some(ref path) = config.config_file {
        Some(SampleYamlConfig::load(path)?)
    } else {
        None
    };

    // Seeded RNG: same seed + same input => same sample.
    let mut rng = StdRng::seed_from_u64(config.seed);
    let mut stats = SampleStats::default();

    // Get file size for progress tracking
    let file_size = std::fs::metadata(&config.input)?.len();

    // Progress bar setup - byte-based for the split phase
    let progress_bar = if config.progress {
        let pb = ProgressBar::new(file_size);
        pb.set_style(
            ProgressStyle::with_template(
                "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%) {msg}",
            )
            .unwrap()
            .progress_chars("█▓▒░  ")
            .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
        );
        pb.enable_steady_tick(std::time::Duration::from_millis(100));
        pb.set_message("Splitting dump...");
        Some(pb)
    } else {
        None
    };

    // Phase 0: Split into temp per-table files
    // (TempDir removes everything when dropped at the end of this function.)
    let temp_dir = TempDir::new()?;
    let tables_dir = temp_dir.path().join("tables");

    let mut splitter = Splitter::new(config.input.clone(), tables_dir.clone())
        .with_dialect(config.dialect)
        .with_content_filter(ContentFilter::All);

    if let Some(ref pb) = progress_bar {
        let pb_clone = pb.clone();
        splitter = splitter.with_progress(move |bytes| {
            pb_clone.set_position(bytes);
        });
    }

    let split_stats = splitter.split()?;

    // Finish byte-based progress, switch to milestone messages
    if let Some(ref pb) = progress_bar {
        pb.finish_and_clear();
    }

    if config.progress {
        eprintln!(
            "Split complete: {} tables, {} statements",
            split_stats.tables_found, split_stats.statements_processed
        );
    }

    // Phase 1: Build schema graph
    if config.progress {
        eprintln!("Building schema graph...");
    }

    let graph = build_schema_graph(&tables_dir, &config)?;

    // Topological order for acyclic tables; cyclic tables come back
    // separately and are appended after the acyclic ones below.
    let (topo_order, cyclic_tables) = graph.processing_order();

    if !cyclic_tables.is_empty() {
        let names: Vec<_> = cyclic_tables
            .iter()
            .filter_map(|&id| graph.table_name(id))
            .collect();
        let msg = format!(
            "Warning: {} tables have FK cycles (intra-cycle FK enforcement disabled): {:?}",
            cyclic_tables.len(),
            names
        );
        if config.progress {
            eprintln!("{}", msg);
        }
        stats.warnings.push(msg);
    }

    // Build set of cyclic table IDs for quick lookup
    let cyclic_set: ahash::AHashSet<TableId> = cyclic_tables.iter().copied().collect();

    // Determine root tables (lowercased for case-insensitive matching)
    let explicit_roots: ahash::AHashSet<String> = config
        .root_tables
        .iter()
        .map(|s| s.to_lowercase())
        .collect();

    // Initialize table runtimes with classification
    let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
    for table in graph.schema.iter() {
        let classification =
            determine_classification(&table.name, &graph, table.id, &yaml_config, &explicit_roots);
        let skip = should_skip_table(&table.name, &config, &yaml_config, classification);

        runtimes.insert(
            table.id,
            TableRuntime {
                name: table.name.clone(),
                pk_set: PkHashSet::default(),
                rows_seen: 0,
                rows_selected: 0,
                skip,
                classification,
                fk_orphans: 0,
                selected_temp_path: None,
            },
        );
    }

    // Create directory for selected row temp files
    let selected_dir = temp_dir.path().join("selected");
    fs::create_dir_all(&selected_dir)?;

    // Phase 2: Process tables in dependency order
    if config.progress {
        eprintln!(
            "Sampling {} tables in dependency order...",
            topo_order.len()
        );
    }

    // Process acyclic tables first, then cyclic tables
    let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();

    let mut total_selected: u64 = 0;

    for table_id in &all_tables {
        let table_schema = match graph.schema.table(*table_id) {
            Some(s) => s,
            None => continue,
        };

        // Check if we should skip this table. Copy out what we need so the
        // shared borrow of `runtimes` ends before the mutable borrow below.
        let (should_skip, table_name, classification) = {
            let runtime = match runtimes.get(table_id) {
                Some(r) => r,
                None => continue,
            };
            (runtime.skip, runtime.name.clone(), runtime.classification)
        };

        if should_skip {
            stats.tables_skipped += 1;
            continue;
        }

        // Handle lookup/global tables specially
        let sample_mode = match classification {
            TableClassification::Lookup => {
                match config.include_global {
                    GlobalTableMode::None => {
                        stats.tables_skipped += 1;
                        continue;
                    }
                    GlobalTableMode::Lookups | GlobalTableMode::All => {
                        // Include all rows
                        SampleMode::Percent(100)
                    }
                }
            }
            TableClassification::System => {
                stats.tables_skipped += 1;
                continue;
            }
            _ => get_table_sample_mode(&table_name, &config, &yaml_config),
        };

        // Tables may have no data file from the split; nothing to sample.
        let table_file = tables_dir.join(format!("{}.sql", table_name));
        if !table_file.exists() {
            continue;
        }

        // Process table with streaming sampling - rows go directly to temp file
        let result = sample_table_streaming(
            &table_file,
            table_schema,
            *table_id,
            &table_name,
            sample_mode,
            &config,
            &runtimes,
            &cyclic_set,
            &selected_dir,
            &mut rng,
        )?;

        // Check max_total_rows guard. Note this breaks *before* registering
        // the table's results, so its temp rows are never recorded in the
        // runtime (selected_temp_path stays None) and its stats are dropped.
        if let Some(max) = config.max_total_rows {
            if total_selected + result.rows_selected > max as u64 {
                let msg = format!(
                    "Warning: Reached max_total_rows limit ({}) at table '{}'",
                    max, table_name
                );
                stats.warnings.push(msg);
                break;
            }
        }

        // Update total count
        total_selected += result.rows_selected;

        // Update runtime state and add PK hashes for FK checks by children
        // Safe: runtime existence was checked at the top of this iteration
        let runtime = runtimes
            .get_mut(table_id)
            .expect("runtime must exist - checked at loop start");
        runtime.rows_seen = result.rows_seen;
        runtime.rows_selected = result.rows_selected;
        runtime.fk_orphans = result.fk_orphans;

        // Add PK hashes for FK membership checks by child tables
        for pk_hash in result.pk_hashes {
            runtime.pk_set.insert(pk_hash);
        }

        // Set the temp file path if we selected any rows
        if result.rows_selected > 0 {
            let temp_path = selected_dir.join(format!("{}.rows", table_name));
            if temp_path.exists() {
                runtime.selected_temp_path = Some(temp_path);
            }
        }

        stats.fk_orphans_rejected += result.fk_orphans;

        stats.table_stats.push(TableSampleStats {
            name: runtime.name.clone(),
            rows_seen: result.rows_seen,
            rows_selected: result.rows_selected,
            classification: runtime.classification,
        });
    }

    // Calculate totals
    for table_stats in &stats.table_stats {
        stats.total_rows_seen += table_stats.rows_seen;
        stats.total_rows_selected += table_stats.rows_selected;
    }
    stats.tables_sampled = stats.table_stats.len();

    if config.progress {
        eprintln!("Sampling complete");
    }

    // Phase 3: Output synthesis
    if config.dry_run {
        return Ok(stats);
    }

    if config.progress {
        eprintln!("Writing output...");
    }

    write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;

    Ok(stats)
}
446
447/// Build schema graph from split table files
448fn build_schema_graph(tables_dir: &Path, config: &SampleConfig) -> anyhow::Result<SchemaGraph> {
449    let mut builder = SchemaBuilder::new();
450
451    for entry in fs::read_dir(tables_dir)? {
452        let entry = entry?;
453        let path = entry.path();
454
455        if path.extension().map(|e| e == "sql").unwrap_or(false) {
456            let file = File::open(&path)?;
457            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
458
459            while let Some(stmt) = parser.read_statement()? {
460                let stmt_str = String::from_utf8_lossy(&stmt);
461                let (stmt_type, _) =
462                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
463
464                match stmt_type {
465                    StatementType::CreateTable => {
466                        builder.parse_create_table(&stmt_str);
467                    }
468                    StatementType::AlterTable => {
469                        builder.parse_alter_table(&stmt_str);
470                    }
471                    _ => {}
472                }
473            }
474        }
475    }
476
477    Ok(SchemaGraph::from_schema(builder.build()))
478}
479
480/// Determine table classification
481fn determine_classification(
482    name: &str,
483    graph: &SchemaGraph,
484    table_id: TableId,
485    yaml_config: &Option<SampleYamlConfig>,
486    explicit_roots: &ahash::AHashSet<String>,
487) -> TableClassification {
488    // Check explicit roots first
489    if explicit_roots.contains(&name.to_lowercase()) {
490        return TableClassification::Root;
491    }
492
493    // Check YAML config
494    if let Some(ref config) = yaml_config {
495        let class = config.get_classification(name);
496        if class != TableClassification::Normal {
497            return class;
498        }
499    }
500
501    // Check if it's a graph root (no parents)
502    if graph.parents[table_id.0 as usize].is_empty() {
503        return TableClassification::Root;
504    }
505
506    // Use default classifier
507    DefaultClassifier::classify(name)
508}
509
510/// Check if a table should be skipped
511fn should_skip_table(
512    name: &str,
513    config: &SampleConfig,
514    yaml_config: &Option<SampleYamlConfig>,
515    classification: TableClassification,
516) -> bool {
517    let name_lower = name.to_lowercase();
518
519    // Check exclude list
520    if config
521        .exclude
522        .iter()
523        .any(|e| e.to_lowercase() == name_lower)
524    {
525        return true;
526    }
527
528    // Check YAML skip
529    if let Some(ref yc) = yaml_config {
530        if yc.should_skip(name) {
531            return true;
532        }
533    }
534
535    // Check include filter
536    if let Some(ref filter) = config.tables_filter {
537        if !filter.iter().any(|f| f.to_lowercase() == name_lower) {
538            return true;
539        }
540    }
541
542    // Skip system tables by default
543    if classification == TableClassification::System {
544        return true;
545    }
546
547    false
548}
549
550/// Get sample mode for a specific table
551fn get_table_sample_mode(
552    name: &str,
553    config: &SampleConfig,
554    yaml_config: &Option<SampleYamlConfig>,
555) -> SampleMode {
556    // Check YAML config first
557    if let Some(ref yc) = yaml_config {
558        if let Some(rows) = yc.get_rows(name) {
559            return SampleMode::Rows(rows);
560        }
561        if let Some(percent) = yc.get_percent(name) {
562            return SampleMode::Percent(percent);
563        }
564    }
565
566    // Fall back to global config
567    config.mode
568}
569
/// Result from streaming sampling of a single table.
struct StreamingSampleResult {
    /// Rows encountered while scanning the table file.
    rows_seen: u64,
    /// Rows selected into the sample.
    rows_selected: u64,
    /// Rows rejected because a referenced parent row was not selected.
    fk_orphans: u64,
    /// PK hashes of selected rows (for FK checks by children)
    pk_hashes: Vec<u64>,
}
578
579/// Stream-sample a table: parse rows, apply FK checks, sample inline, write to temp file.
580/// Returns StreamingSampleResult with stats and PK hashes.
581/// Uses Bernoulli sampling for --percent mode (single pass).
582/// For --rows mode, we use reservoir sampling on row indices with a second pass.
583#[allow(clippy::too_many_arguments)]
584fn sample_table_streaming(
585    table_file: &Path,
586    table_schema: &crate::schema::TableSchema,
587    table_id: TableId,
588    table_name: &str,
589    sample_mode: SampleMode,
590    config: &SampleConfig,
591    runtimes: &AHashMap<TableId, TableRuntime>,
592    cyclic_set: &ahash::AHashSet<TableId>,
593    selected_dir: &Path,
594    rng: &mut StdRng,
595) -> anyhow::Result<StreamingSampleResult> {
596    let mut rows_seen = 0u64;
597    let mut rows_selected = 0u64;
598    let mut fk_orphans = 0u64;
599
600    // Temp file for selected rows
601    let temp_path = selected_dir.join(format!("{}.rows", table_name));
602    let mut temp_writer: Option<BufWriter<File>> = None;
603
604    // Track PKs of selected rows (for children's FK checks)
605    let mut selected_pk_hashes: Vec<u64> = Vec::new();
606
607    // For PostgreSQL COPY, track the current column order
608    let mut copy_columns: Vec<String> = Vec::new();
609
610    match sample_mode {
611        SampleMode::Percent(p) => {
612            // Bernoulli sampling: decide immediately for each row
613            let prob = p as f64 / 100.0;
614
615            let file = File::open(table_file)?;
616            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
617
618            while let Some(stmt) = parser.read_statement()? {
619                let (stmt_type, _) =
620                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
621
622                match stmt_type {
623                    StatementType::Insert => {
624                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
625                        for row in rows {
626                            rows_seen += 1;
627
628                            // FK check
629                            if config.preserve_relations {
630                                let unified = UnifiedRow::Insert(row.clone());
631                                let (passes, orphan) = check_unified_fk_membership(
632                                    &unified,
633                                    table_schema,
634                                    runtimes,
635                                    cyclic_set,
636                                    &table_id,
637                                );
638                                if !passes {
639                                    fk_orphans += 1;
640                                    if orphan && config.strict_fk {
641                                        anyhow::bail!(
642                                            "FK integrity violation in table '{}': row references missing parent",
643                                            table_name
644                                        );
645                                    }
646                                    continue;
647                                }
648                            }
649
650                            // Bernoulli sample
651                            if rng.random::<f64>() < prob {
652                                // Write to temp file
653                                if temp_writer.is_none() {
654                                    temp_writer = Some(BufWriter::new(File::create(&temp_path)?));
655                                }
656                                let writer = temp_writer.as_mut().unwrap();
657                                // Format: 1-byte type (0=insert, 1=copy), then row bytes, then newline
658                                writer.write_all(&[0u8])?;
659                                writer.write_all(&row.raw)?;
660                                writer.write_all(b"\n")?;
661
662                                // Track PK hash
663                                if let Some(pk) = &row.pk {
664                                    selected_pk_hashes.push(hash_pk_tuple(pk));
665                                }
666                                rows_selected += 1;
667                            }
668                        }
669                    }
670                    StatementType::Copy => {
671                        let header = String::from_utf8_lossy(&stmt);
672                        copy_columns = parse_copy_columns(&header);
673                    }
674                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
675                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
676                            let rows = parse_postgres_copy_rows(
677                                &stmt,
678                                table_schema,
679                                copy_columns.clone(),
680                            )?;
681                            for row in rows {
682                                rows_seen += 1;
683
684                                if config.preserve_relations {
685                                    let unified = UnifiedRow::Copy(row.clone());
686                                    let (passes, orphan) = check_unified_fk_membership(
687                                        &unified,
688                                        table_schema,
689                                        runtimes,
690                                        cyclic_set,
691                                        &table_id,
692                                    );
693                                    if !passes {
694                                        fk_orphans += 1;
695                                        if orphan && config.strict_fk {
696                                            anyhow::bail!(
697                                                "FK integrity violation in table '{}': row references missing parent",
698                                                table_name
699                                            );
700                                        }
701                                        continue;
702                                    }
703                                }
704
705                                if rng.random::<f64>() < prob {
706                                    if temp_writer.is_none() {
707                                        temp_writer =
708                                            Some(BufWriter::new(File::create(&temp_path)?));
709                                    }
710                                    let writer = temp_writer.as_mut().unwrap();
711                                    writer.write_all(&[1u8])?;
712                                    writer.write_all(&row.raw)?;
713                                    writer.write_all(b"\n")?;
714
715                                    if let Some(pk) = &row.pk {
716                                        selected_pk_hashes.push(hash_pk_tuple(pk));
717                                    }
718                                    rows_selected += 1;
719                                }
720                            }
721                        }
722                    }
723                    _ => {}
724                }
725            }
726        }
727        SampleMode::Rows(n) => {
728            // Reservoir sampling: collect eligible row indices in first pass,
729            // then write selected rows in second pass
730            let mut reservoir: Reservoir<(u64, RowFormat, Option<u64>)> =
731                Reservoir::new(n, StdRng::from_rng(&mut *rng));
732
733            // First pass: build reservoir of (row_index, format, pk_hash)
734            let file = File::open(table_file)?;
735            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
736
737            while let Some(stmt) = parser.read_statement()? {
738                let (stmt_type, _) =
739                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
740
741                match stmt_type {
742                    StatementType::Insert => {
743                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
744                        for row in rows {
745                            let current_idx = rows_seen;
746                            rows_seen += 1;
747
748                            if config.preserve_relations {
749                                let unified = UnifiedRow::Insert(row.clone());
750                                let (passes, orphan) = check_unified_fk_membership(
751                                    &unified,
752                                    table_schema,
753                                    runtimes,
754                                    cyclic_set,
755                                    &table_id,
756                                );
757                                if !passes {
758                                    fk_orphans += 1;
759                                    if orphan && config.strict_fk {
760                                        anyhow::bail!(
761                                            "FK integrity violation in table '{}': row references missing parent",
762                                            table_name
763                                        );
764                                    }
765                                    continue;
766                                }
767                            }
768
769                            let pk_hash = row.pk.as_ref().map(hash_pk_tuple);
770                            reservoir.consider((current_idx, RowFormat::Insert, pk_hash));
771                        }
772                    }
773                    StatementType::Copy => {
774                        let header = String::from_utf8_lossy(&stmt);
775                        copy_columns = parse_copy_columns(&header);
776                    }
777                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
778                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
779                            let rows = parse_postgres_copy_rows(
780                                &stmt,
781                                table_schema,
782                                copy_columns.clone(),
783                            )?;
784                            for row in rows {
785                                let current_idx = rows_seen;
786                                rows_seen += 1;
787
788                                if config.preserve_relations {
789                                    let unified = UnifiedRow::Copy(row.clone());
790                                    let (passes, orphan) = check_unified_fk_membership(
791                                        &unified,
792                                        table_schema,
793                                        runtimes,
794                                        cyclic_set,
795                                        &table_id,
796                                    );
797                                    if !passes {
798                                        fk_orphans += 1;
799                                        if orphan && config.strict_fk {
800                                            anyhow::bail!(
801                                                "FK integrity violation in table '{}': row references missing parent",
802                                                table_name
803                                            );
804                                        }
805                                        continue;
806                                    }
807                                }
808
809                                let pk_hash = row.pk.as_ref().map(hash_pk_tuple);
810                                reservoir.consider((current_idx, RowFormat::Copy, pk_hash));
811                            }
812                        }
813                    }
814                    _ => {}
815                }
816            }
817
818            // Extract selected indices and PKs from reservoir
819            let selected_items = reservoir.into_items();
820            if selected_items.is_empty() {
821                return Ok(StreamingSampleResult {
822                    rows_seen,
823                    rows_selected: 0,
824                    fk_orphans,
825                    pk_hashes: Vec::new(),
826                });
827            }
828
829            // Collect PK hashes and sort indices for second pass
830            let mut selected_indices: Vec<(u64, RowFormat)> =
831                Vec::with_capacity(selected_items.len());
832            for (idx, format, pk_hash) in selected_items {
833                if let Some(h) = pk_hash {
834                    selected_pk_hashes.push(h);
835                }
836                selected_indices.push((idx, format));
837            }
838            selected_indices.sort_by_key(|(idx, _)| *idx);
839
840            // Second pass: write selected rows to temp file
841            let file = File::open(table_file)?;
842            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
843            let mut current_row_idx = 0u64;
844            let mut select_iter = selected_indices.iter().peekable();
845
846            temp_writer = Some(BufWriter::new(File::create(&temp_path)?));
847            let writer = temp_writer.as_mut().unwrap();
848
849            while let Some(stmt) = parser.read_statement()? {
850                if select_iter.peek().is_none() {
851                    break; // All selected rows written
852                }
853
854                let (stmt_type, _) =
855                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
856
857                match stmt_type {
858                    StatementType::Insert => {
859                        let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
860                        for row in rows {
861                            if let Some((next_idx, _)) = select_iter.peek() {
862                                if current_row_idx == *next_idx {
863                                    writer.write_all(&[0u8])?;
864                                    writer.write_all(&row.raw)?;
865                                    writer.write_all(b"\n")?;
866                                    rows_selected += 1;
867                                    select_iter.next();
868                                }
869                            }
870                            current_row_idx += 1;
871                        }
872                    }
873                    StatementType::Copy => {
874                        let header = String::from_utf8_lossy(&stmt);
875                        copy_columns = parse_copy_columns(&header);
876                    }
877                    StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
878                        if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
879                            let rows = parse_postgres_copy_rows(
880                                &stmt,
881                                table_schema,
882                                copy_columns.clone(),
883                            )?;
884                            for row in rows {
885                                if let Some((next_idx, _)) = select_iter.peek() {
886                                    if current_row_idx == *next_idx {
887                                        writer.write_all(&[1u8])?;
888                                        writer.write_all(&row.raw)?;
889                                        writer.write_all(b"\n")?;
890                                        rows_selected += 1;
891                                        select_iter.next();
892                                    }
893                                }
894                                current_row_idx += 1;
895                            }
896                        }
897                    }
898                    _ => {}
899                }
900            }
901        }
902    }
903
904    // Flush temp file
905    if let Some(mut writer) = temp_writer {
906        writer.flush()?;
907    }
908
909    Ok(StreamingSampleResult {
910        rows_seen,
911        rows_selected,
912        fk_orphans,
913        pk_hashes: selected_pk_hashes,
914    })
915}
916
917/// Check FK membership for a unified row (works with both INSERT and COPY rows)
918/// Uses hash-based lookup for memory efficiency.
919fn check_unified_fk_membership(
920    row: &UnifiedRow,
921    table_schema: &crate::schema::TableSchema,
922    runtimes: &AHashMap<TableId, TableRuntime>,
923    cyclic_set: &ahash::AHashSet<TableId>,
924    current_table_id: &TableId,
925) -> (bool, bool) {
926    let mut passes = true;
927    let mut is_orphan = false;
928
929    for (fk_ref, fk_tuple) in row.fk_values() {
930        if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
931            if let Some(parent_id) = fk.referenced_table_id {
932                // Skip FK check for cyclic tables
933                if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id) {
934                    continue;
935                }
936
937                // Check if parent row exists in parent's pk_set using hash lookup
938                if let Some(parent_runtime) = runtimes.get(&parent_id) {
939                    let fk_hash = hash_pk_tuple(fk_tuple);
940                    if !parent_runtime.pk_set.contains(&fk_hash) {
941                        passes = false;
942                        is_orphan = true;
943                        break;
944                    }
945                }
946            }
947        }
948    }
949
950    (passes, is_orphan)
951}
952
953/// Write sampled output
954fn write_output(
955    config: &SampleConfig,
956    _graph: &SchemaGraph,
957    table_order: &[TableId],
958    runtimes: &AHashMap<TableId, TableRuntime>,
959    tables_dir: &Path,
960    stats: &SampleStats,
961) -> anyhow::Result<()> {
962    let mut writer: Box<dyn Write> = match &config.output {
963        Some(path) => {
964            if let Some(parent) = path.parent() {
965                fs::create_dir_all(parent)?;
966            }
967            Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
968        }
969        None => Box::new(BufWriter::new(std::io::stdout())),
970    };
971
972    // Write header comment
973    write_header(&mut writer, config, stats)?;
974
975    // Write dialect-specific header
976    write_dialect_header(&mut writer, config.dialect)?;
977
978    // Write schema for each table (if enabled)
979    if config.include_schema {
980        for &table_id in table_order {
981            let runtime = match runtimes.get(&table_id) {
982                Some(r) if !r.skip && r.rows_selected > 0 => r,
983                _ => continue,
984            };
985
986            let table_file = tables_dir.join(format!("{}.sql", runtime.name));
987            if !table_file.exists() {
988                continue;
989            }
990
991            // Copy schema statements from table file
992            let file = File::open(&table_file)?;
993            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
994
995            while let Some(stmt) = parser.read_statement()? {
996                let (stmt_type, _) =
997                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
998
999                if stmt_type.is_schema() {
1000                    writer.write_all(&stmt)?;
1001                    writer.write_all(b"\n")?;
1002                }
1003            }
1004        }
1005    }
1006
1007    // Write data for each table (reading from temp files instead of memory)
1008    for &table_id in table_order {
1009        let runtime = match runtimes.get(&table_id) {
1010            Some(r) if !r.skip && r.rows_selected > 0 && r.selected_temp_path.is_some() => r,
1011            _ => continue,
1012        };
1013
1014        let table_name = &runtime.name;
1015        let row_count = runtime.rows_selected;
1016
1017        writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;
1018
1019        // Get the table name quoting based on dialect
1020        let quoted_name = match config.dialect {
1021            SqlDialect::MySql => format!("`{}`", table_name),
1022            SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
1023            SqlDialect::Mssql => format!("[{}]", table_name),
1024        };
1025
1026        // Read rows from temp file and write INSERTs in chunks
1027        let temp_path = runtime.selected_temp_path.as_ref().unwrap();
1028        let temp_file = File::open(temp_path)?;
1029        let reader = std::io::BufReader::new(temp_file);
1030        use std::io::BufRead;
1031
1032        const CHUNK_SIZE: usize = 1000;
1033        let mut chunk_buffer: Vec<(RowFormat, Vec<u8>)> = Vec::with_capacity(CHUNK_SIZE);
1034
1035        for line in reader.lines() {
1036            let line = line?;
1037            if line.is_empty() {
1038                continue;
1039            }
1040
1041            let bytes = line.as_bytes();
1042            if bytes.is_empty() {
1043                continue;
1044            }
1045
1046            // First byte is format indicator (0=insert, 1=copy)
1047            let format = if bytes[0] == 0 {
1048                RowFormat::Insert
1049            } else {
1050                RowFormat::Copy
1051            };
1052            let row_bytes = bytes[1..].to_vec();
1053
1054            chunk_buffer.push((format, row_bytes));
1055
1056            if chunk_buffer.len() >= CHUNK_SIZE {
1057                write_insert_chunk(&mut writer, &quoted_name, &chunk_buffer, config.dialect)?;
1058                chunk_buffer.clear();
1059            }
1060        }
1061
1062        // Write remaining rows
1063        if !chunk_buffer.is_empty() {
1064            write_insert_chunk(&mut writer, &quoted_name, &chunk_buffer, config.dialect)?;
1065        }
1066    }
1067
1068    // Write dialect-specific footer
1069    write_dialect_footer(&mut writer, config.dialect)?;
1070
1071    writer.flush()?;
1072
1073    Ok(())
1074}
1075
1076/// Write header comment
1077fn write_header<W: Write>(
1078    writer: &mut W,
1079    config: &SampleConfig,
1080    stats: &SampleStats,
1081) -> std::io::Result<()> {
1082    writeln!(writer, "-- Sampled from: {}", config.input.display())?;
1083    writeln!(
1084        writer,
1085        "-- Date: {}",
1086        chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
1087    )?;
1088    writeln!(
1089        writer,
1090        "-- Mode: {:?}{}",
1091        config.mode,
1092        if config.preserve_relations {
1093            ", preserve-relations"
1094        } else {
1095            ""
1096        }
1097    )?;
1098    writeln!(writer, "-- Seed: {}", config.seed)?;
1099    writeln!(writer, "-- Dialect: {}", config.dialect)?;
1100    writeln!(writer, "--")?;
1101    writeln!(writer, "-- Statistics:")?;
1102    writeln!(writer, "--   Tables sampled: {}", stats.tables_sampled)?;
1103    writeln!(writer, "--   Tables skipped: {}", stats.tables_skipped)?;
1104
1105    let percent = if stats.total_rows_seen > 0 {
1106        (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
1107    } else {
1108        0.0
1109    };
1110    writeln!(
1111        writer,
1112        "--   Total rows: {} (from {} original, {:.1}%)",
1113        stats.total_rows_selected, stats.total_rows_seen, percent
1114    )?;
1115
1116    if stats.fk_orphans_rejected > 0 {
1117        writeln!(
1118            writer,
1119            "--   FK orphans rejected: {}",
1120            stats.fk_orphans_rejected
1121        )?;
1122    }
1123
1124    if !stats.warnings.is_empty() {
1125        writeln!(writer, "--   Warnings: {}", stats.warnings.len())?;
1126    }
1127
1128    writeln!(writer)?;
1129
1130    Ok(())
1131}
1132
1133/// Write dialect-specific header
1134fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1135    match dialect {
1136        SqlDialect::MySql => {
1137            writeln!(writer, "SET NAMES utf8mb4;")?;
1138            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
1139        }
1140        SqlDialect::Postgres => {
1141            writeln!(writer, "SET client_encoding = 'UTF8';")?;
1142            writeln!(writer, "SET session_replication_role = replica;")?;
1143        }
1144        SqlDialect::Sqlite => {
1145            writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
1146        }
1147        SqlDialect::Mssql => {
1148            writeln!(writer, "SET ANSI_NULLS ON;")?;
1149            writeln!(writer, "SET QUOTED_IDENTIFIER ON;")?;
1150            writeln!(writer, "SET NOCOUNT ON;")?;
1151        }
1152    }
1153    writeln!(writer)?;
1154    Ok(())
1155}
1156
1157/// Write dialect-specific footer
1158fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1159    writeln!(writer)?;
1160    match dialect {
1161        SqlDialect::MySql => {
1162            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
1163        }
1164        SqlDialect::Postgres => {
1165            writeln!(writer, "SET session_replication_role = DEFAULT;")?;
1166        }
1167        SqlDialect::Sqlite => {
1168            writeln!(writer, "PRAGMA foreign_keys = ON;")?;
1169        }
1170        SqlDialect::Mssql => {
1171            // No footer needed
1172        }
1173    }
1174    Ok(())
1175}
1176
1177/// Write a chunk of rows as an INSERT statement
1178fn write_insert_chunk<W: Write>(
1179    writer: &mut W,
1180    quoted_name: &str,
1181    chunk: &[(RowFormat, Vec<u8>)],
1182    dialect: SqlDialect,
1183) -> std::io::Result<()> {
1184    writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;
1185
1186    for (i, (format, row_bytes)) in chunk.iter().enumerate() {
1187        if i > 0 {
1188            writer.write_all(b",\n")?;
1189        }
1190
1191        let values = match format {
1192            RowFormat::Insert => match dialect {
1193                SqlDialect::Postgres => convert_row_to_postgres(row_bytes),
1194                _ => row_bytes.clone(),
1195            },
1196            RowFormat::Copy => convert_copy_to_insert_values(row_bytes, dialect),
1197        };
1198        writer.write_all(&values)?;
1199    }
1200
1201    writer.write_all(b";\n")?;
1202    Ok(())
1203}
1204
/// Convert a MySQL-style VALUES tuple to PostgreSQL quoting.
///
/// Rewrites MySQL's escaped quote (`\'`) to the SQL-standard doubled quote
/// (`''`). Escaped backslashes (`\\`) are consumed as a pair and copied
/// through unchanged so the second backslash can never be re-examined as
/// the start of a `\'` escape — previously a value ending in `\\` before a
/// closing quote was rewritten to `\''`, which corrupts the quote structure
/// of the emitted statement.
///
/// NOTE(review): this remains a minimal conversion; other MySQL escapes
/// (`\n`, `\t`, ...) pass through untouched.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(row.len());
    let mut i = 0;

    while i < row.len() {
        if row[i] == b'\\' && i + 1 < row.len() && row[i + 1] == b'\'' {
            // MySQL: \' -> PostgreSQL: ''
            result.push(b'\'');
            result.push(b'\'');
            i += 2;
        } else if row[i] == b'\\' && i + 1 < row.len() && row[i + 1] == b'\\' {
            // Escaped backslash: copy the pair so its second byte is not
            // misread as the start of a quote escape.
            result.push(b'\\');
            result.push(b'\\');
            i += 2;
        } else {
            result.push(row[i]);
            i += 1;
        }
    }

    result
}
1226
1227/// Convert PostgreSQL COPY format (tab-separated) to INSERT VALUES format
1228fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
1229    let mut result = Vec::with_capacity(row.len() + 20);
1230    result.push(b'(');
1231
1232    let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
1233
1234    for (i, field) in fields.iter().enumerate() {
1235        if i > 0 {
1236            result.extend_from_slice(b", ");
1237        }
1238
1239        // Check for NULL marker
1240        if *field == b"\\N" {
1241            result.extend_from_slice(b"NULL");
1242        } else if field.is_empty() {
1243            // Empty string
1244            match dialect {
1245                SqlDialect::MySql => result.extend_from_slice(b"''"),
1246                SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {
1247                    result.extend_from_slice(b"''")
1248                }
1249            }
1250        } else if is_numeric(field) {
1251            // Numeric value - no quotes needed
1252            result.extend_from_slice(field);
1253        } else {
1254            // String value - needs quoting
1255            result.push(b'\'');
1256            for &b in *field {
1257                match b {
1258                    b'\'' => {
1259                        // Escape single quote
1260                        match dialect {
1261                            SqlDialect::MySql => result.extend_from_slice(b"\\'"),
1262                            SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {
1263                                result.extend_from_slice(b"''")
1264                            }
1265                        }
1266                    }
1267                    b'\\' if dialect == SqlDialect::MySql => {
1268                        // Escape backslash in MySQL
1269                        result.extend_from_slice(b"\\\\");
1270                    }
1271                    _ => result.push(b),
1272                }
1273            }
1274            result.push(b'\'');
1275        }
1276    }
1277
1278    result.push(b')');
1279    result
1280}
1281
/// Check if a byte slice represents a numeric SQL literal.
///
/// Accepts an optional leading sign, a mantissa of digits with at most one
/// decimal point (at least one digit required), and an optional exponent
/// (`e`/`E`, optional sign, one or more digits).
///
/// Fixes the previous implementation, which skipped validation after an
/// `e`/`E` (accepting invalid tokens like `1e2e3`, `.e5`, or a trailing
/// `1e` — emitted unquoted, producing broken SQL) while rejecting valid
/// signed exponents such as `1e-5`.
fn is_numeric(s: &[u8]) -> bool {
    let mut i = 0;

    // Optional leading sign.
    if i < s.len() && (s[i] == b'+' || s[i] == b'-') {
        i += 1;
    }

    // Mantissa: digits, optionally with one decimal point; at least one
    // digit must appear somewhere in the mantissa.
    let mut mantissa_digits = false;
    while i < s.len() && s[i].is_ascii_digit() {
        mantissa_digits = true;
        i += 1;
    }
    if i < s.len() && s[i] == b'.' {
        i += 1;
        while i < s.len() && s[i].is_ascii_digit() {
            mantissa_digits = true;
            i += 1;
        }
    }
    if !mantissa_digits {
        return false;
    }

    // Optional exponent: e/E, optional sign, then one or more digits.
    if i < s.len() && (s[i] == b'e' || s[i] == b'E') {
        i += 1;
        if i < s.len() && (s[i] == b'+' || s[i] == b'-') {
            i += 1;
        }
        let mut exp_digits = false;
        while i < s.len() && s[i].is_ascii_digit() {
            exp_digits = true;
            i += 1;
        }
        if !exp_digits {
            return false;
        }
    }

    // Valid only if the whole slice was consumed.
    i == s.len()
}